2017年5月26日 星期五

Decision Tree (決策樹)


using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace MachineLearning
{
    class Program
    {
        static void Main(string[] args)
        {
            Data.ShowTreeValue();
        }
    }

    public class Data
    {
        public static void ShowTreeValue()
        {
            string[,] data = new string[,]
            {
                {"youth","high","no","fair","no"},
                {"youth","high","no","excellent","no"},
                {"middle_aged","high","no","fair","yes"},
                {"senior","medium","no","fair","yes"},
                {"senior","low","yes","fair","yes"},
                {"senior","low","yes","excellent","no"},
                {"middle_aged","low","yes","excellent","yes"},
                {"youth","medium","no","fair","no"},
                {"youth","low","yes","fair","yes"},
                {"senior","medium","yes","fair","yes"},
                {"youth","medium","yes","excellent","yes"},
                {"middle_aged","medium","no","excellent","yes"},
                {"middle_aged","high","yes","fair","yes"},
                {"senior","medium","no","excellent","no"}
            };
            string[] names = new string[] { "age", "income", "student", "credit_rating", "Class: buys_computer" };
            DecisionTreeID3<string> tree = new DecisionTreeID3<string>(data, names, new string[] { "yes", "no" });
            tree.Learn();
            Console.ReadKey();
        }
    }

    public class DecisionTreeID3<T> where T : IEquatable<T>
    {
        T[,] Data;
        string[] Lable;
        int AnswerLength;
        T[] CategoryLabels;
        DecisionTreeNode<T> Root;
        public DecisionTreeID3(T[,] data, string[] names, T[] anwserLabels)
        {
            Data = data;
            Lable = names;
            AnswerLength = data.GetLength(1) - 1;//类别变量需要放在最后一列
            CategoryLabels = anwserLabels;
        }

        public void Learn()
        {
            int NumberRows = Data.GetLength(0);
            int NumberColumn = Data.GetLength(1);
            int[] rows = new int[NumberRows];
            int[] cols = new int[NumberColumn];
            for (int i = 0; i < NumberRows; i++) rows[i] = i;
            for (int i = 0; i < NumberColumn; i++) cols[i] = i;
            Root = new DecisionTreeNode<T>(-1, default(T));
            Calculate(rows, cols, Root);
            DisplayNode(Root);
        }
        private void DisplayNode(DecisionTreeNode<T> Node, int depth = 0)
        {
            if (Node.IntLabel != -1)
                Console.WriteLine("{0} {1}: {2}", new string('-', depth * 3), Lable[Node.IntLabel], Node.Value);
            foreach (DecisionTreeNode<T> item in Node.Children)
                DisplayNode(item, depth + 1);
        }
        private static IEnumerable<T> GetAttribute(T[,] data, int IntAnswer, int[] pnRows)
        {
            foreach (int i_row in pnRows)
                yield return data[i_row, IntAnswer];
        }
        private IEnumerable<Tuple<T, int>> AttributeCount(int col, int[] pnRows)
        {
            IEnumerable<Tuple<T, int>> tuples = from n in GetAttribute(Data, col, pnRows)
                                                group n by n into i
                                                select Tuple.Create(i.First(), i.Count());
            return tuples;
        }
        private static double Log2(double x)
        {
            return x == 0.0 ? 0.0 : Math.Log(x, 2.0);
        }
        private double CategoryInfo(int[] pnRows)
        {
            IEnumerable<Tuple<T, int>> tuples = AttributeCount(AnswerLength, pnRows);
            double sum = (double)pnRows.Length;
            double Entropy = 0.0;
            foreach (Tuple<T, int> tuple in tuples)
            {
                double frequency = tuple.Item2 / sum;
                double t = -frequency * Log2(frequency);
                Entropy += t;
            }
            return Entropy;
        }
        private double AttributeInfo(int attrCol, int[] pnRows)
        {
            IEnumerable<Tuple<T, int>> tuples = AttributeCount(attrCol, pnRows);
            double sum = (double)pnRows.Length;
            double Entropy = 0.0;
            foreach (Tuple<T, int> tuple in tuples)
            {
                int[] count = new int[CategoryLabels.Length];
                foreach (int irow in pnRows)
                    if (Data[irow, attrCol].Equals(tuple.Item1))
                    {
                        int index = Array.IndexOf(CategoryLabels, Data[irow, AnswerLength]);
                        count[index]++;
                    }
                double k = 0.0;
                for (int i = 0; i < count.Length; i++)
                {
                    double frequency = count[i] / (double)tuple.Item2;
                    double t = -frequency * Log2(frequency);
                    k += t;
                }
                double freq = tuple.Item2 / sum;
                Entropy += freq * k;
            }
            return Entropy;
        }
        private int MaxEntropy(int[] pnRows, int[] pnCols)
        {
            double cateEntropy = CategoryInfo(pnRows);
            int maxAttr = 0;
            double max = double.MinValue;
            foreach (int icol in pnCols)
                if (icol != AnswerLength)
                {
                    double Gain = cateEntropy - AttributeInfo(icol, pnRows);
                    if (max < Gain)
                    {
                        max = Gain;
                        maxAttr = icol;
                    }
                }
            return maxAttr;
        }
        private void Calculate(int[] pnRows, int[] pnColumns, DecisionTreeNode<T> Root)
        {
            IEnumerable<T> categoryValues = GetAttribute(Data, AnswerLength, pnRows);
            int categoryCount = categoryValues.Distinct().Count();
            if (categoryCount == 1)
            {
                DecisionTreeNode<T> node = new DecisionTreeNode<T>(AnswerLength, categoryValues.First());
                Root.Children.Add(node);
            }
            else
            {
                if (pnRows.Length == 0) return;
                else if (pnColumns.Length == 1)
                {
                    //投票~
                    //多数票表决制
                    IEnumerable<T> Vote = categoryValues.GroupBy(i => i).OrderBy(i => i.Count()).First();
                    DecisionTreeNode<T> node = new DecisionTreeNode<T>(AnswerLength, Vote.First());
                    Root.Children.Add(node);
                }
                else
                {
                    int maxCol = MaxEntropy(pnRows, pnColumns);
                    IEnumerable<T> attributes = GetAttribute(Data, maxCol, pnRows).Distinct();
                    string currentPrefix = Lable[maxCol];
                    foreach (var attr in attributes)
                    {
                        int[] rows = pnRows.Where(irow => Data[irow, maxCol].Equals(attr)).ToArray();
                        int[] cols = pnColumns.Where(i => i != maxCol).ToArray();
                        DecisionTreeNode<T> node = new DecisionTreeNode<T>(maxCol, attr);
                        Root.Children.Add(node);
                        Calculate(rows, cols, node);//递归生成决策树
                    }
                }
            }
        }
    }

    public sealed class DecisionTreeNode<T>
    {
        public int IntLabel;
        public T Value { get; set; }
        public List<DecisionTreeNode<T>> Children { get; set; }
        public DecisionTreeNode(int label, T value)
        {
            IntLabel = label;
            Value = value;
            Children = new List<DecisionTreeNode<T>>();
        }
    }
}

沒有留言:

張貼留言

WinFormTb02

https://drive.google.com/drive/u/0/folders/1UwS9FZ3ELCOK6SAwirHrkxq3z_RSbxJt https://www.youtube.com/watch?v=k7IkIeww_U0&list=PLumjEWemD...