/
main.go
57 lines (52 loc) · 1.13 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
package main
import (
	"flag"
	"fmt"
	"log"
)
// main is the command-line entry point. It either trains a model
// ("-m learn") or evaluates a previously saved model ("-m test").
//
// Flags:
//
//	-f  data file (required)
//	-m  mode: "learn" or "test" (required)
//	-w  model file: saved to in learn mode, loaded from in test mode (required)
//	-l  number of iterations, passed to NewAdaGrad (default 10)
//	-c  regularization parameter, passed to NewAdaGrad (default 0.1)
//	-v  in test mode, print each prediction as it is made
//
// Invalid or missing flags terminate the program via log.Fatal (exit status 1)
// rather than panicking, since they are user errors, not programmer bugs.
func main() {
	var (
		loop    int
		file    string
		mode    string
		model   string
		lambda  float64
		verbose bool
	)
	flag.IntVar(&loop, "l", 10, "number of iterations")
	flag.Float64Var(&lambda, "c", 0.1, "regularization parameter")
	flag.StringVar(&file, "f", "", "data file")
	flag.StringVar(&mode, "m", "", "mode {learn, test}")
	flag.StringVar(&model, "w", "", "model file")
	flag.BoolVar(&verbose, "v", false, "verbose mode")
	// Parse exactly once; a second call is redundant.
	flag.Parse()

	if file == "" {
		log.Fatal("Data must be specified")
	}

	switch mode {
	case "learn", "test":
		// Both modes need a model path: learn writes it, test reads it.
		if model == "" {
			log.Fatal("Model file must be specified (-w)")
		}
	default:
		log.Fatalf("Invalid mode %q: want \"learn\" or \"test\"", mode)
	}

	if mode == "learn" {
		runLearn(file, model, lambda, loop)
		return
	}
	runTest(file, model, verbose)
}

// runLearn loads training data from file, fits an AdaGrad model built with
// the given regularization parameter and iteration count, and saves the
// fitted model to the path model.
func runLearn(file, model string, lambda float64, loop int) {
	xTrain, yTrain := LoadFromFile(file)
	p := NewAdaGrad(lambda, loop)
	p.Fit(xTrain, yTrain)
	SaveModel(p, model)
}

// runTest loads the model stored at model, predicts a label for every example
// in file, and prints a confusion matrix followed by the overall accuracy.
// When verbose is true, each predicted label is printed as it is produced.
// An empty data file is reported as an error instead of printing "Acc: NaN".
func runTest(file, model string, verbose bool) {
	p := LoadModel(model)
	xTest, yTest := LoadFromFile(file)
	if len(xTest) == 0 {
		log.Fatalf("no test examples in %s", file)
	}

	pred := make([]string, 0, len(xTest))
	correct := 0
	for i, x := range xTest {
		label := p.Predict(x)
		pred = append(pred, label)
		if verbose {
			fmt.Println(label)
		}
		if label == yTest[i] {
			correct++
		}
	}

	confusionMatrix(yTest, pred)
	fmt.Println("Acc:", float64(correct)/float64(len(xTest)))
}