Esempio n. 1
0
// GenerateCrossFoldValidationConfusionMatrices randomly assigns every row of
// data to one of `folds` folds. Then, for each fold in turn, it fits cls on the
// rows of all the other folds and evaluates it on the held-out fold.
//
// It returns one ConfusionMatrix per fold, indexed by fold number. The first
// error reported by cls.Fit, cls.Predict or GetConfusionMatrix aborts the run
// and is returned together with a nil slice.
//
// Each row's fold is drawn independently with rand.Intn(folds), so fold sizes
// vary from run to run and a fold may end up with no rows at all.
// NOTE(review): assuming rand is math/rand, Intn panics when folds <= 0;
// callers are expected to pass a positive fold count.
func GenerateCrossFoldValidationConfusionMatrices(data base.FixedDataGrid, cls base.Classifier, folds int) ([]ConfusionMatrix, error) {
	_, rows := data.Size()

	// foldRows[f] lists the row indices assigned to fold f. An entry stays nil
	// when no row happens to land in that fold.
	foldRows := make([][]int, folds)
	for i := 0; i < rows; i++ {
		fold := rand.Intn(folds)
		foldRows[fold] = append(foldRows[fold], i)
	}

	ret := make([]ConfusionMatrix, folds)

	// Create training/test views for each fold
	for i := 0; i < folds; i++ {
		// Fold i is held out for testing.
		testRows := foldRows[i]
		testData := base.NewInstancesViewFromVisible(data, testRows, data.AllAttributes())

		// Every row belongs to exactly one fold, so the training set holds
		// all rows except the ones in the test fold.
		trainRows := make([]int, 0, rows-len(testRows))
		for j, fr := range foldRows {
			if j == i {
				continue
			}
			trainRows = append(trainRows, fr...)
		}
		trainData := base.NewInstancesViewFromVisible(data, trainRows, data.AllAttributes())

		// Train
		if err := cls.Fit(trainData); err != nil {
			return nil, err
		}
		// Predict
		pred, err := cls.Predict(testData)
		if err != nil {
			return nil, err
		}
		// Evaluate
		cf, err := GetConfusionMatrix(testData, pred)
		if err != nil {
			return nil, err
		}
		ret[i] = cf
	}
	return ret, nil
}
Esempio n. 2
0
// main loads the iris dataset, discretises its float attributes with a
// Chi-Merge filter, splits the result into training and test sets, and then
// trains and evaluates several tree-based classifiers, printing an evaluation
// summary for each.
//
// Every classifier is fitted on the training split and scored on the test
// split. Any failure while loading, fitting, predicting or building the
// confusion matrix panics.
func main() {
	// Fixed seed, so every run draws the same random numbers.
	rand.Seed(44111342)

	// Load in the iris dataset
	iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}

	// Discretise the iris dataset with Chi-Merge
	filt := filters.NewChiMergeFilter(iris, 0.999)
	for _, a := range base.NonClassFloatAttributes(iris) {
		filt.AddAttribute(a)
	}
	// NOTE(review): anything filt.Train() returns is discarded here; confirm
	// whether it can report an error that should be handled.
	filt.Train()
	irisf := base.NewLazilyFilteredInstances(iris, filt)

	// Create a 60-40 training-test split
	trainData, testData := base.InstancesTrainTestSplit(irisf, 0.60)

	// report fits cls on the training split, predicts the test split, then
	// prints title followed by the evaluation summary of the predictions.
	report := func(title string, cls base.Classifier) {
		// Train
		if err := cls.Fit(trainData); err != nil {
			panic(err)
		}

		// Generate predictions
		predictions, err := cls.Predict(testData)
		if err != nil {
			panic(err)
		}

		// Evaluate
		fmt.Println(title)
		cf, err := evaluation.GetConfusionMatrix(testData, predictions)
		if err != nil {
			panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
		}
		fmt.Println(evaluation.GetSummary(cf))
	}

	//
	// First up, use ID3 with each of the available rule generators.
	// (The 0.6 parameter controls the train-prune split.)
	//
	report("ID3 Performance (information gain)",
		trees.NewID3DecisionTree(0.6))
	report("ID3 Performance (information gain ratio)",
		trees.NewID3DecisionTreeFromRule(0.6, new(trees.InformationGainRatioRuleGenerator)))
	report("ID3 Performance (gini index generator)",
		trees.NewID3DecisionTreeFromRule(0.6, new(trees.GiniCoefficientRuleGenerator)))

	//
	// Next up, Random Trees
	//

	// Consider two randomly-chosen attributes. The tree is fitted on the
	// training split like every other model; fitting it on the test split
	// would score it on data it has already seen.
	report("RandomTree Performance", trees.NewRandomTree(2))

	//
	// Finally, Random Forests
	//
	report("RandomForest Performance", ensemble.NewRandomForest(70, 3))
}
Esempio n. 3
0
// main loads the iris dataset, discretises its float attributes with a
// Chi-Merge filter, splits the result into training and test sets, and then
// trains and evaluates an ID3 tree, a random tree and a random forest,
// printing an evaluation summary for each.
//
// Every classifier is fitted on the training split and scored on the test
// split.
func main() {
	// Seed from the current time, so each run draws different random numbers.
	rand.Seed(time.Now().UTC().UnixNano())

	// Load in the iris dataset
	iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}

	// Discretise the iris dataset with Chi-Merge
	filt := filters.NewChiMergeFilter(iris, 0.99)
	for _, a := range base.NonClassFloatAttributes(iris) {
		filt.AddAttribute(a)
	}
	filt.Train()
	irisf := base.NewLazilyFilteredInstances(iris, filt)

	// Create a 60-40 training-test split
	trainData, testData := base.InstancesTrainTestSplit(irisf, 0.60)

	// report fits cls on the training split, predicts the test split, then
	// prints title followed by the evaluation summary of the predictions.
	// NOTE(review): Fit is called as a bare statement, as in the original;
	// confirm whether this API version returns an error that should be checked.
	report := func(title string, cls base.Classifier) {
		// Train
		cls.Fit(trainData)

		// Generate predictions
		predictions := cls.Predict(testData)

		// Evaluate
		fmt.Println(title)
		cf := eval.GetConfusionMatrix(testData, predictions)
		fmt.Println(eval.GetSummary(cf))
	}

	//
	// First up, use ID3
	// (The 0.6 parameter controls the train-prune split.)
	//
	report("ID3 Performance", trees.NewID3DecisionTree(0.6))

	//
	// Next up, Random Trees
	//

	// Consider two randomly-chosen attributes. The tree is fitted on the
	// training split like every other model; fitting it on the test split
	// would score it on data it has already seen.
	report("RandomTree Performance", trees.NewRandomTree(2))

	//
	// Finally, Random Forests
	//
	report("RandomForest Performance", ensemble.NewRandomForest(100, 3))
}