using
System;
using
System.Collections.Generic;
using
System.Linq;
using
System.Text;
using
System.Threading.Tasks;
namespace
MachineLearning
{
/// <summary>
/// Console host for the decision-tree demo.
/// </summary>
internal class Program
{
    /// <summary>Entry point: delegates straight to <see cref="Data.ShowTreeValue"/>.</summary>
    /// <param name="args">Command-line arguments (unused).</param>
    private static void Main(string[] args)
    {
        Data.ShowTreeValue();
    }
}
/// <summary>
/// Holds the sample training table used by the demo.
/// </summary>
public class Data
{
    /// <summary>
    /// Builds the 14-row "buys_computer" table, trains an ID3 tree on it
    /// (<see cref="DecisionTreeID3{T}.Learn"/> prints the resulting tree),
    /// then waits for a key press before returning.
    /// </summary>
    public static void ShowTreeValue()
    {
        // Columns: age, income, student, credit_rating, class (last column).
        var samples = new string[,]
        {
            { "youth",       "high",   "no",  "fair",      "no"  },
            { "youth",       "high",   "no",  "excellent", "no"  },
            { "middle_aged", "high",   "no",  "fair",      "yes" },
            { "senior",      "medium", "no",  "fair",      "yes" },
            { "senior",      "low",    "yes", "fair",      "yes" },
            { "senior",      "low",    "yes", "excellent", "no"  },
            { "middle_aged", "low",    "yes", "excellent", "yes" },
            { "youth",       "medium", "no",  "fair",      "no"  },
            { "youth",       "low",    "yes", "fair",      "yes" },
            { "senior",      "medium", "yes", "fair",      "yes" },
            { "youth",       "medium", "yes", "excellent", "yes" },
            { "middle_aged", "medium", "no",  "excellent", "yes" },
            { "middle_aged", "high",   "yes", "fair",      "yes" },
            { "senior",      "medium", "no",  "excellent", "no"  }
        };

        var columnNames = new[] { "age", "income", "student", "credit_rating", "Class: buys_computer" };
        var classLabels = new[] { "yes", "no" };

        var learner = new DecisionTreeID3<string>(samples, columnNames, classLabels);
        learner.Learn();

        // Keep the console window open until the user presses a key.
        Console.ReadKey();
    }
}
/// <summary>
/// ID3 decision-tree learner over a table of categorical values. Each row of the
/// table is one training sample. The class (answer) variable must be stored in the
/// last column. Splits are chosen by maximum information gain.
/// </summary>
/// <typeparam name="T">Cell type; cells are compared with <see cref="IEquatable{T}.Equals(T)"/>.</typeparam>
public class DecisionTreeID3<T> where T : IEquatable<T>
{
    // Training table: rows are samples; the last column holds the class variable.
    readonly T[,] Data;
    // Column names indexed by column number; used only when printing the tree.
    readonly string[] Lable;
    // Index of the class column (always the last column of Data).
    readonly int AnswerLength;
    // Every class value that may appear in the class column.
    readonly T[] CategoryLabels;
    // Synthetic root created by Learn (IntLabel == -1); its children are the top-level splits.
    DecisionTreeNode<T> Root;

    /// <summary>Creates a learner for the given training table.</summary>
    /// <param name="data">Training table; the class variable must be in the last column.</param>
    /// <param name="names">Column names, indexed by column number of <paramref name="data"/>.</param>
    /// <param name="anwserLabels">All class values that can appear in the last column of <paramref name="data"/>.</param>
    /// <exception cref="ArgumentNullException">Any argument is null.</exception>
    public DecisionTreeID3(T[,] data, string[] names, T[] anwserLabels)
    {
        if (data == null) throw new ArgumentNullException("data");
        if (names == null) throw new ArgumentNullException("names");
        if (anwserLabels == null) throw new ArgumentNullException("anwserLabels");

        Data = data;
        Lable = names;
        // The class variable must be placed in the last column.
        AnswerLength = data.GetLength(1) - 1;
        CategoryLabels = anwserLabels;
    }

    /// <summary>
    /// Builds the tree from every row and column of the training table, then prints it
    /// to the console (one line per node, indented by depth).
    /// </summary>
    public void Learn()
    {
        int[] rows = Enumerable.Range(0, Data.GetLength(0)).ToArray();
        int[] cols = Enumerable.Range(0, Data.GetLength(1)).ToArray();

        Root = new DecisionTreeNode<T>(-1, default(T));
        Calculate(rows, cols, Root);
        DisplayNode(Root);
    }

    /// <summary>
    /// Recursively prints <paramref name="node"/> and its descendants as
    /// "&lt;dashes&gt; &lt;column name&gt;: &lt;value&gt;", with three dashes per depth level.
    /// </summary>
    private void DisplayNode(DecisionTreeNode<T> node, int depth = 0)
    {
        // The synthetic root (IntLabel == -1) has no column to print.
        if (node.IntLabel != -1)
        {
            Console.WriteLine("{0} {1}: {2}", new string('-', depth * 3), Lable[node.IntLabel], node.Value);
        }

        foreach (DecisionTreeNode<T> child in node.Children)
        {
            DisplayNode(child, depth + 1);
        }
    }

    /// <summary>Lazily yields the value in column <paramref name="IntAnswer"/> for each row index in <paramref name="pnRows"/>.</summary>
    private static IEnumerable<T> GetAttribute(T[,] data, int IntAnswer, int[] pnRows)
    {
        foreach (int row in pnRows)
        {
            yield return data[row, IntAnswer];
        }
    }

    /// <summary>
    /// Returns (distinct value, occurrence count) pairs for column <paramref name="col"/>
    /// restricted to <paramref name="pnRows"/>, in first-seen order.
    /// </summary>
    private IEnumerable<Tuple<T, int>> AttributeCount(int col, int[] pnRows)
    {
        return GetAttribute(Data, col, pnRows)
            .GroupBy(value => value)
            .Select(group => Tuple.Create(group.Key, group.Count()));
    }

    /// <summary>Base-2 logarithm with the entropy convention log2(0) treated as 0.</summary>
    private static double Log2(double x)
    {
        return x == 0.0 ? 0.0 : Math.Log(x, 2.0);
    }

    /// <summary>Shannon entropy (bits) of the class column over <paramref name="pnRows"/>; 0 for no rows.</summary>
    private double CategoryInfo(int[] pnRows)
    {
        double total = pnRows.Length;
        double entropy = 0.0;
        foreach (Tuple<T, int> tuple in AttributeCount(AnswerLength, pnRows))
        {
            double frequency = tuple.Item2 / total;
            entropy += -frequency * Log2(frequency);
        }
        return entropy;
    }

    /// <summary>
    /// Conditional entropy (bits) of the class column given column <paramref name="attrCol"/>,
    /// over <paramref name="pnRows"/>.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// A class value in the table is not one of the labels passed to the constructor.
    /// </exception>
    private double AttributeInfo(int attrCol, int[] pnRows)
    {
        double total = pnRows.Length;
        double entropy = 0.0;
        foreach (Tuple<T, int> tuple in AttributeCount(attrCol, pnRows))
        {
            // Class histogram for the rows where attrCol == tuple.Item1.
            int[] count = new int[CategoryLabels.Length];
            foreach (int irow in pnRows)
            {
                if (!Data[irow, attrCol].Equals(tuple.Item1))
                {
                    continue;
                }

                int index = Array.IndexOf(CategoryLabels, Data[irow, AnswerLength]);
                if (index < 0)
                {
                    // Previously this surfaced as an opaque IndexOutOfRangeException.
                    throw new InvalidOperationException(string.Format(
                        "Class value '{0}' in row {1} is not one of the category labels passed to the constructor.",
                        Data[irow, AnswerLength], irow));
                }
                count[index]++;
            }

            double subsetEntropy = 0.0;
            for (int i = 0; i < count.Length; i++)
            {
                double frequency = count[i] / (double)tuple.Item2;
                subsetEntropy += -frequency * Log2(frequency);
            }

            double weight = tuple.Item2 / total;
            entropy += weight * subsetEntropy;
        }
        return entropy;
    }

    /// <summary>
    /// Returns the non-class column in <paramref name="pnCols"/> with the highest information
    /// gain (class entropy minus conditional entropy) over <paramref name="pnRows"/>.
    /// Ties keep the earliest column. Returns 0 if no non-class column is present.
    /// </summary>
    private int MaxEntropy(int[] pnRows, int[] pnCols)
    {
        double cateEntropy = CategoryInfo(pnRows);
        int maxAttr = 0;
        double max = double.MinValue;
        foreach (int icol in pnCols)
        {
            if (icol == AnswerLength)
            {
                continue;
            }

            double gain = cateEntropy - AttributeInfo(icol, pnRows);
            if (max < gain)
            {
                max = gain;
                maxAttr = icol;
            }
        }
        return maxAttr;
    }

    /// <summary>
    /// Recursively grows the subtree under <paramref name="parent"/> from the given rows,
    /// using only the given columns (which include the class column).
    /// </summary>
    private void Calculate(int[] pnRows, int[] pnColumns, DecisionTreeNode<T> parent)
    {
        // Nothing to learn from an empty partition.
        if (pnRows.Length == 0)
        {
            return;
        }

        // Materialize once: the class values are inspected several times below.
        T[] categoryValues = GetAttribute(Data, AnswerLength, pnRows).ToArray();

        if (categoryValues.Distinct().Count() == 1)
        {
            // Pure partition: a single leaf carrying the shared class value.
            parent.Children.Add(new DecisionTreeNode<T>(AnswerLength, categoryValues[0]));
            return;
        }

        if (!pnColumns.Any(c => c != AnswerLength))
        {
            // No attributes left to split on: majority vote. OrderByDescending is stable,
            // so ties go to the class seen first.
            T majority = categoryValues
                .GroupBy(v => v)
                .OrderByDescending(g => g.Count())
                .First()
                .Key;
            parent.Children.Add(new DecisionTreeNode<T>(AnswerLength, majority));
            return;
        }

        int bestCol = MaxEntropy(pnRows, pnColumns);
        // Loop-invariant: every child partition drops the column just split on.
        int[] remainingCols = pnColumns.Where(c => c != bestCol).ToArray();

        foreach (T attr in GetAttribute(Data, bestCol, pnRows).Distinct())
        {
            T value = attr;
            int[] rows = pnRows.Where(irow => Data[irow, bestCol].Equals(value)).ToArray();
            DecisionTreeNode<T> node = new DecisionTreeNode<T>(bestCol, value);
            parent.Children.Add(node);
            // Recursively build the decision tree below this branch.
            Calculate(rows, remainingCols, node);
        }
    }
}
/// <summary>
/// A node of a decision tree: the column index it refers to, the value
/// associated with that column, and its child nodes.
/// </summary>
/// <typeparam name="T">Type of the stored value.</typeparam>
public sealed class DecisionTreeNode<T>
{
    /// <summary>
    /// Column index this node refers to (DecisionTreeID3 uses -1 for its synthetic root).
    /// </summary>
    public int IntLabel;

    /// <summary>Value of the referenced column at this node.</summary>
    public T Value { get; set; }

    /// <summary>Child nodes; starts as an empty list.</summary>
    public List<DecisionTreeNode<T>> Children { get; set; }

    /// <summary>Creates a childless node.</summary>
    /// <param name="label">Column index stored in <see cref="IntLabel"/>.</param>
    /// <param name="value">Value stored in <see cref="Value"/>.</param>
    public DecisionTreeNode(int label, T value)
    {
        this.IntLabel = label;
        this.Value = value;
        this.Children = new List<DecisionTreeNode<T>>();
    }
}
}
// Blog page footer residue: "沒有留言:" ("No comments:")
// "張貼留言" ("Post a comment")