Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
malaschitz committed Jun 6, 2019
1 parent 9d56511 commit f19b647
Show file tree
Hide file tree
Showing 8 changed files with 173 additions and 391 deletions.
175 changes: 96 additions & 79 deletions forest.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,17 @@ var mux = &sync.Mutex{}

func (forest *Forest) Train(trees int) {
forest.NSize = len(forest.Data.X)
forest.NAttrs = len(forest.Data.X[0])
forest.Features = len(forest.Data.X[0])
forest.NTrees = trees
forest.Trees = make([]Tree, forest.NTrees)
forest.ClassFunction = gini2
if forest.MAttrs == 0 {
forest.MAttrs = int(math.Sqrt(float64(forest.NAttrs)))
forest.Classes = 0
for _, c := range forest.Data.Class {
if c >= forest.Classes {
forest.Classes = c + 1
}
}
if forest.MFeatures == 0 {
forest.MFeatures = int(math.Sqrt(float64(forest.Features)))
}
if forest.LeafSize == 0 {
forest.LeafSize = forest.NSize / 20
Expand All @@ -33,38 +38,58 @@ func (forest *Forest) Train(trees int) {
wg.Add(trees)
for i := 0; i < trees; i++ {
go forest.newTree(i, &wg)
//forest.newTree(i, &wg)
//fmt.Println(i)
}
wg.Wait()
imp := make([]float64, forest.NAttrs)
imp := make([]float64, forest.Features)
for i := 0; i < trees; i++ {
z := forest.Trees[i].importance(forest)
for i := 0; i < forest.NAttrs; i++ {
for i := 0; i < forest.Features; i++ {
imp[i] += z[i]
}
//forest.Trees[i].Root.print()
}
for i := 0; i < forest.NAttrs; i++ {
for i := 0; i < forest.Features; i++ {
imp[i] = imp[i] / float64(trees)
}
forest.FeatureImportance = imp
}

func (forest *Forest) Vote(x []float64) float64 {
votes := 0.0
func (forest *Forest) Vote(x []float64) []float64 {
votes := make([]float64, forest.Classes)
for i := 0; i < forest.NTrees; i++ {
votes += forest.Trees[i].vote(x)
v := forest.Trees[i].vote(x)
for j := 0; j < forest.Classes; j++ {
votes[j] += v[j]
}
}
for j := 0; j < forest.Classes; j++ {
votes[j] = votes[j] / float64(forest.NTrees)
}
return votes / float64(forest.NTrees)
return votes
}

func (forest *Forest) WeightVote(x []float64) float64 {
votes := 0.0
func (forest *Forest) WeightVote(x []float64) []float64 {
votes := make([]float64, forest.Classes)
total := 0.0
for i := 0; i < forest.NTrees; i++ {
votes += forest.Trees[i].vote(x) * forest.Trees[i].Validation
total += forest.Trees[i].Validation
e := 1.0001 - forest.Trees[i].Validation
w := 0.5 * math.Log(float64(forest.Classes-1)*(1-e)/e)
if w > 0 {
v := forest.Trees[i].vote(x)
for j := 0; j < forest.Classes; j++ {
votes[j] += v[j] * w
}
total += w
} else {
fmt.Println("wv", e, w, total)
}
}
for j := 0; j < forest.Classes; j++ {
votes[j] = votes[j] / total
}
return votes / total
return votes
}

// Calculate a new tree in forest.
Expand All @@ -73,11 +98,11 @@ func (forest *Forest) newTree(index int, wg *sync.WaitGroup) {
//data
used := make([]bool, forest.NSize)
x := make([][]float64, forest.NSize)
results := make([]bool, forest.NSize)
results := make([]int, forest.NSize)
for i := 0; i < forest.NSize; i++ {
k := rand.Intn(forest.NSize)
x[i] = forest.Data.X[k]
results[i] = forest.Data.Results[k]
results[i] = forest.Data.Class[k]
used[k] = true
}
// build Root
Expand All @@ -86,14 +111,15 @@ func (forest *Forest) newTree(index int, wg *sync.WaitGroup) {
tree := Tree{Root: root}
// validation test tree
count := 0
right := 0.0
e := 0.0
for i := 0; i < forest.NSize; i++ {
if !used[i] {
count++
right += root.vote(forest.Data.X[i])
v := root.vote(forest.Data.X[i])
e += v[forest.Data.Class[i]]
}
}
tree.Validation = right / float64(count)
tree.Validation = e / float64(count)

// add tree
mux.Lock()
Expand All @@ -103,7 +129,7 @@ func (forest *Forest) newTree(index int, wg *sync.WaitGroup) {

func (forest *Forest) PrintFeatureImportance() {
fmt.Println("-------- feature importance")
for i := 0; i < forest.NAttrs; i++ {
for i := 0; i < forest.Features; i++ {
fmt.Println(i, forest.FeatureImportance[i])
}
fmt.Println("-------- cross validation")
Expand All @@ -127,28 +153,26 @@ func (forest *Forest) PrintFeatureImportance() {
fmt.Println("--------")
}

func (branch *Branch) build(forest *Forest, x [][]float64, results []bool, depth int) {
func (branch *Branch) build(forest *Forest, x [][]float64, class []int, depth int) {
//fmt.Println(repeat(".", depth), depth, len(x))
vTrue := 0
vFalse := 0
for _, r := range results {
if r {
vTrue++
} else {
vFalse++
}
classCount := make([]int, forest.Classes)
for _, r := range class {
classCount[r]++
}
branch.Gini = forest.ClassFunction(vTrue, vFalse)
branch.Size = len(results)
branch.Gini = gini(classCount)
branch.Size = len(class)
branch.Depth = depth

if (len(x) <= forest.LeafSize) || (branch.Gini == 0) {
branch.IsLeaf = true
branch.LeafValue = float64(vTrue) / float64(vTrue+vFalse)
branch.LeafValue = make([]float64, forest.Classes)
for i, r := range classCount {
branch.LeafValue[i] = float64(r) / float64(branch.Size)
}
return
}
//find best split
attrsRandom := rand.Perm(forest.NAttrs)[:forest.MAttrs]
attrsRandom := rand.Perm(forest.Features)[:forest.MFeatures]
//fmt.Println(repeat(".", depth), "ATRR", attrsRandom)
var bestAtrr int
var bestValue float64
Expand All @@ -166,27 +190,24 @@ func (branch *Branch) build(forest *Forest, x [][]float64, results []bool, depth
})
//go throuh data
v := x[srt[0]][a]
t := 0
f := 0
s1 := make([]int, forest.Classes)
s2 := make([]int, forest.Classes)
copy(s2, classCount)
for i := 0; i < branch.Size; i++ {
index := srt[i]
if x[index][a] > v {
g1 := forest.ClassFunction(t, f)
g2 := forest.ClassFunction(vTrue-t, vFalse-f)
wg := (g1*float64(t+f) + g2*float64(branch.Size-t-f)) / float64(branch.Size)
//fmt.Println(repeat(".", depth), g1, g2)
g1 := gini(s1)
g2 := gini(s2)
wg := (g1*float64(i) + g2*float64(branch.Size-i)) / float64(branch.Size)
if wg < bestGini {
bestGini = wg
bestValue = v
bestAtrr = a
}
v = x[index][a]
}
if results[index] {
t++
} else {
f++
}
s1[class[index]]++
s2[class[index]]--
}
}
//split it
Expand All @@ -195,38 +216,38 @@ func (branch *Branch) build(forest *Forest, x [][]float64, results []bool, depth
branch.Value = bestValue
x0 := make([][]float64, 0)
x1 := make([][]float64, 0)
r0 := make([]bool, 0)
r1 := make([]bool, 0)
c0 := make([]int, 0)
c1 := make([]int, 0)
for i := 0; i < branch.Size; i++ {
if x[i][branch.Atribute] > branch.Value {
x1 = append(x1, x[i])
r1 = append(r1, results[i])
c1 = append(c1, class[i])
} else {
x0 = append(x0, x[i])
r0 = append(r0, results[i])
c0 = append(c0, class[i])
}
}
//create branches
//fmt.Println(repeat(".", depth), "SPLIT", len(x0), len(x1))
branch.Branch0 = &Branch{}
branch.Branch1 = &Branch{}
branch.Branch0.build(forest, x0, r0, depth+1)
branch.Branch1.build(forest, x1, r1, depth+1)
branch.Branch0.build(forest, x0, c0, depth+1)
branch.Branch1.build(forest, x1, c1, depth+1)
}

func (tree *Tree) vote(x []float64) float64 {
func (tree *Tree) vote(x []float64) []float64 {
return tree.Root.vote(x)
}

func (tree *Tree) importance(forest *Forest) []float64 {
imp := make([]float64, forest.NAttrs)
imp := make([]float64, forest.Features)
tree.Root.importance(imp)
//normalize
sum := 0.0
for i := 0; i < forest.NAttrs; i++ {
for i := 0; i < forest.Features; i++ {
sum += imp[i]
}
for i := 0; i < forest.NAttrs; i++ {
for i := 0; i < forest.Features; i++ {
imp[i] = imp[i] / sum
}
return imp
Expand All @@ -242,7 +263,7 @@ func (branch *Branch) importance(imp []float64) {
}
}

func (branch *Branch) vote(x []float64) float64 {
func (branch *Branch) vote(x []float64) []float64 {
if branch.IsLeaf {
return branch.LeafValue
} else {
Expand All @@ -256,11 +277,11 @@ func (branch *Branch) vote(x []float64) float64 {

func (branch *Branch) print() {
if branch.IsLeaf {
fmt.Printf("%s\tLEAF %t\tsize: %6d\tgini: %5.4f\n",
fmt.Printf("%s ... LEAF %v\tsize: %6d\tgini: %5.4f\n",
repeat("_", branch.Depth*3), branch.LeafValue, branch.Size, branch.Gini)
} else {
fmt.Printf("%s\tsize: %6d\tattr: %3d\tvalue: %4.3f\tgini: %5.4f %5.4f\n",
repeat("_", branch.Depth*3), branch.Size, branch.Atribute, branch.Value, branch.Gini, branch.GiniGain)
fmt.Printf("%s ... size: %6d\tattr: %3d\tgini: %5.4f %5.4f \t\tvalue: %4.3f\n",
repeat("_", branch.Depth*3), branch.Size, branch.Atribute, branch.Gini, branch.GiniGain, branch.Value)
branch.Branch0.print()
branch.Branch1.print()
fmt.Printf("%s\n", repeat("_", branch.Depth*3))
Expand All @@ -283,38 +304,34 @@ func repeat(s string, n int) string {
return z
}

func gini2(a, b int) float64 {
sum := float64(a + b)
g := 1.0 - ((float64(a)/sum)*(float64(a)/sum) + (float64(b)/sum)*(float64(b)/sum))
return g
}

func entropy2(a, b int) float64 {
sum := float64(a + b)
ap := (float64(a) / sum)
bp := (float64(b) / sum)
g := -ap*math.Log2(ap) - bp*math.Log2(bp)
if math.IsNaN(g) {
return 0
func gini(data []int) float64 {
sum := 0
for _, a := range data {
sum += a
}
sumF := float64(sum)
g := 1.0
for _, a := range data {
g = g - (float64(a)/sumF)*(float64(a)/sumF)
}
return g
}

type Forest struct {
Data ForestData
Trees []Tree
Features int // number of attributes
Classes int // number of classes
LeafSize int // leaf size
MAttrs int // attributes for choose proper split
MFeatures int // attributes for choose proper split
NTrees int // number of trees
NAttrs int // number of attributes
NSize int // len of data
ClassFunction func(a, b int) float64
FeatureImportance []float64
}

type ForestData struct {
X [][]float64
Results []bool
X [][]float64
Class []int
}

type Tree struct {
Expand All @@ -326,7 +343,7 @@ type Branch struct {
Atribute int
Value float64
IsLeaf bool
LeafValue float64
LeafValue []float64
Gini float64
GiniGain float64
Size int
Expand Down
6 changes: 6 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
module github.com/malaschitz/randomForest

require (
github.com/petar/GoMNIST v0.0.0-20150320212226-2fbe10d0fa63
gonum.org/v1/gonum v0.0.0-20181107204152-48288cca5b5e
)
6 changes: 6 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
github.com/petar/GoMNIST v0.0.0-20150320212226-2fbe10d0fa63 h1:xS51uYfMeRA+pLKKsXaM/UI06LkyRQ1ItZgOvrOgIu8=
github.com/petar/GoMNIST v0.0.0-20150320212226-2fbe10d0fa63/go.mod h1:d7fwuOuDrb75/3iplL4oWbe4MBZgWil/pSVR3ItECVU=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
gonum.org/v1/gonum v0.0.0-20181107204152-48288cca5b5e h1:rFmnPtoaQiuDLenvVepoEGyL56C1PxwVhH5fMHUA1WI=
gonum.org/v1/gonum v0.0.0-20181107204152-48288cca5b5e/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo=
Loading

0 comments on commit f19b647

Please sign in to comment.