コード例 #1
0
ファイル: smote_test.go プロジェクト: shuLhan/go-mining
func TestSmote(t *testing.T) {
	smot := smote.New(PercentOver, K, 5)

	// Read samples.
	dataset := tabula.Claset{}

	_, e := dsv.SimpleRead(fcfg, &dataset)
	if nil != e {
		t.Fatal(e)
	}

	fmt.Println("[smote_test] Total samples:", dataset.Len())

	minorset := dataset.GetMinorityRows()

	fmt.Println("[smote_test] # minority samples:", minorset.Len())

	e = smot.Resampling(*minorset)
	if e != nil {
		t.Fatal(e)
	}

	fmt.Println("[smote_test] # synthetic:", smot.GetSynthetics().Len())

	e = smot.Write("phoneme_smote.csv")
	if e != nil {
		t.Fatal(e)
	}
}
コード例 #2
0
ファイル: lnsmote_test.go プロジェクト: shuLhan/go-mining
func TestLNSmote(t *testing.T) {
	// Read sample dataset.
	dataset := tabula.Claset{}
	_, e := dsv.SimpleRead(fcfg, &dataset)
	if nil != e {
		t.Fatal(e)
	}

	fmt.Println("[lnsmote_test] Total samples:", dataset.GetNRow())

	// Write original samples.
	writer, e := dsv.NewWriter("")

	if nil != e {
		t.Fatal(e)
	}

	e = writer.OpenOutput("phoneme_lnsmote.csv")
	if e != nil {
		t.Fatal(e)
	}

	sep := dsv.DefSeparator
	_, e = writer.WriteRawRows(dataset.GetRows(), &sep)
	if e != nil {
		t.Fatal(e)
	}

	// Initialize LN-SMOTE.
	lnsmoteRun := lnsmote.New(100, 5, 5, "1", "lnsmote.outliers")

	e = lnsmoteRun.Resampling(&dataset)

	fmt.Println("[lnsmote_test] # synthetic:", lnsmoteRun.Synthetics.Len())

	sep = dsv.DefSeparator
	_, e = writer.WriteRawRows(lnsmoteRun.Synthetics.GetRows(), &sep)
	if e != nil {
		t.Fatal(e)
	}

	e = writer.Close()
	if e != nil {
		t.Fatal(e)
	}
}
コード例 #3
0
ファイル: main.go プロジェクト: shuLhan/go-mining
//
// runSmote will select minority class from dataset and run oversampling.
//
func runSmote(smote *smote.Runtime, dataset *tabula.Claset) (e error) {
	minorset := dataset.GetMinorityRows()

	if DEBUG >= 1 {
		fmt.Println("[smote] # minority samples:", minorset.Len())
	}

	e = smote.Resampling(*minorset)
	if e != nil {
		return
	}

	if DEBUG >= 1 {
		fmt.Println("[smote] # synthetics:", smote.Synthetics.Len())
	}

	return
}
コード例 #4
0
ファイル: main.go プロジェクト: shuLhan/go-mining
func main() {
	defer un(trace("smote"))

	flag.Parse()

	if len(flag.Args()) <= 0 {
		usage()
		os.Exit(1)
	}

	fcfg := flag.Arg(0)

	// Parsing config file and parameter.
	smote, e := createSmote(fcfg)
	if e != nil {
		panic(e)
	}

	// Get dataset.
	dataset := tabula.Claset{}
	_, e = dsv.SimpleRead(fcfg, &dataset)
	if e != nil {
		panic(e)
	}

	fmt.Println("[smote] Dataset:", &dataset)

	row := dataset.GetRow(0)
	fmt.Println("[smote] sample:", row)

	e = runSmote(smote, &dataset)
	if e != nil {
		panic(e)
	}

	if !merge {
		return
	}

	e = runMerge(smote, &dataset)
	if e != nil {
		panic(e)
	}
}
コード例 #5
0
ファイル: main.go プロジェクト: shuLhan/go-mining
func test() {
	testset := tabula.Claset{}
	_, e := dsv.SimpleRead(testCfg, &testset)
	if e != nil {
		panic(e)
	}

	fmt.Println(tag, "Test set:", &testset)
	fmt.Println(tag, "Sample test set:", testset.GetRow(0))

	predicts, cm, probs := crforest.ClassifySetByWeight(&testset, nil)

	fmt.Println("[crf] Test set CM:", cm)

	crforest.Performance(&testset, predicts, probs)

	e = crforest.WritePerformance()
	if e != nil {
		panic(e)
	}
}
コード例 #6
0
ファイル: cart.go プロジェクト: shuLhan/go-mining
/*
CountOOBError process out-of-bag data on tree and return error value.
*/
func (runtime *Runtime) CountOOBError(oob tabula.Claset) (
	errval float64,
	e error,
) {
	// save the original target to be compared later.
	origTarget := oob.GetClassAsStrings()

	if DEBUG >= 2 {
		fmt.Println("[cart] OOB:", oob.Columns)
		fmt.Println("[cart] TREE:", &runtime.Tree)
	}

	// reset the target.
	oobtarget := oob.GetClassColumn()
	oobtarget.ClearValues()

	e = runtime.ClassifySet(&oob)

	if e != nil {
		// set original target values back.
		oobtarget.SetValues(origTarget)
		return
	}

	target := oobtarget.ToStringSlice()

	if DEBUG >= 2 {
		fmt.Println("[cart] original target:", origTarget)
		fmt.Println("[cart] classify target:", target)
	}

	// count how many target value is miss-classified.
	runtime.OOBErrVal, _, _ = tekstus.WordsCountMissRate(origTarget, target)

	// set original target values back.
	oobtarget.SetValues(origTarget)

	return runtime.OOBErrVal, nil
}
コード例 #7
0
ファイル: main.go プロジェクト: shuLhan/go-mining
func main() {
	defer un(trace("cart"))

	flag.Parse()

	if len(flag.Args()) <= 0 {
		usage()
		os.Exit(1)
	}

	fcfg := flag.Arg(0)

	// Parsing config file and check command parameter values.
	cartrt, e := createCart(fcfg)
	if e != nil {
		panic(e)
	}

	// Get dataset
	dataset := tabula.Claset{}
	_, e = dsv.SimpleRead(fcfg, &dataset)
	if e != nil {
		panic(e)
	}

	if DEBUG >= 1 {
		fmt.Printf("[cart] Class index: %v\n", dataset.GetClassIndex())
	}

	e = cartrt.Build(&dataset)
	if e != nil {
		panic(e)
	}

	if DEBUG >= 1 {
		fmt.Println("[cart] CART tree:\n", cartrt)
	}
}
コード例 #8
0
ファイル: claset_test.go プロジェクト: shuLhan/dsv
func TestReaderWithClaset(t *testing.T) {
	fcfg := "testdata/claset.dsv"

	claset := tabula.Claset{}

	_, e := dsv.NewReader(fcfg, &claset)
	if e != nil {
		t.Fatal(e)
	}

	assert(t, 3, claset.GetClassIndex(), true)

	claset.SetMajorityClass("regular")
	claset.SetMinorityClass("vandalism")

	clone := claset.Clone().(tabula.ClasetInterface)

	assert(t, 3, clone.GetClassIndex(), true)
	assert(t, "regular", clone.MajorityClass(), true)
	assert(t, "vandalism", clone.MinorityClass(), true)
}
コード例 #9
0
ファイル: rf_test.go プロジェクト: shuLhan/go-mining
func getSamples() (train, test tabula.ClasetInterface) {
	samples := tabula.Claset{}
	_, e := dsv.SimpleRead(SampleDsvFile, &samples)
	if nil != e {
		log.Fatal(e)
	}

	if !DoTest {
		return &samples, nil
	}

	ntrain := int(float32(samples.Len()) * (float32(NBootstrap) / 100.0))

	bag, oob, _, _ := tabula.RandomPickRows(&samples, ntrain, false)

	train = bag.(tabula.ClasetInterface)
	test = oob.(tabula.ClasetInterface)

	train.SetClassIndex(samples.GetClassIndex())
	test.SetClassIndex(samples.GetClassIndex())

	return train, test
}
コード例 #10
0
ファイル: cart_test.go プロジェクト: shuLhan/go-mining
func TestCART(t *testing.T) {
	fds := "../../testdata/iris/iris.dsv"

	ds := tabula.Claset{}

	_, e := dsv.SimpleRead(fds, &ds)
	if nil != e {
		t.Fatal(e)
	}

	fmt.Println("[cart_test] class index:", ds.GetClassIndex())

	// copy target to be compared later.
	targetv := ds.GetClassAsStrings()

	assert(t, NRows, ds.GetNRow(), true)

	// Build CART tree.
	CART, e := cart.New(&ds, cart.SplitMethodGini, 0)
	if e != nil {
		t.Fatal(e)
	}

	fmt.Println("[cart_test] CART Tree:\n", CART)

	// Create test set
	testset := tabula.Claset{}
	_, e = dsv.SimpleRead(fds, &testset)

	if nil != e {
		t.Fatal(e)
	}

	testset.GetClassColumn().ClearValues()

	// Classifiy test set
	e = CART.ClassifySet(&testset)
	if nil != e {
		t.Fatal(e)
	}

	assert(t, targetv, testset.GetClassAsStrings(), true)
}