func init() { // Use the BLAS implementation specified in CGO_LDFLAGS. This line can be // commented out to use the native Go BLAS implementation found in // github.com/gonum/blas/native. blas64.Use(cgo.Implementation{}) // These are here so that toggling the BLAS implementation does not make imports unused _ = cgo.Implementation{} _ = blas64.General{} }
func init() { blas64.Use(cgo.Implementation{}) }
func main() { flag.Parse() blas64.Use(cgo.Implementation{}) if *cpuprofile != "" { f, err := os.Create(*cpuprofile) if err != nil { log.Fatal(err) } pprof.StartCPUProfile(f) defer pprof.StopCPUProfile() } http.HandleFunc("/Weights", func(w http.ResponseWriter, r *http.Request) { c := make(chan []byte) weightsChan <- c w.Write(<-c) }) http.HandleFunc("/Loss", func(w http.ResponseWriter, r *http.Request) { c := make(chan []float64) lossChan <- c json.NewEncoder(w).Encode(<-c) }) http.HandleFunc("/PrintDebug", func(w http.ResponseWriter, r *http.Request) { printDebugChan <- struct{}{} }) port := 8085 go func() { log.Printf("Listening on port %d", port) if err := http.ListenAndServe(fmt.Sprintf(":%d", port), nil); err != nil { log.Fatalf("%v", err) } }() var seed int64 = 5 rand.Seed(seed) log.Printf("seed: %d", seed) gen, err := poem.NewGenerator("data/quantangshi3000.int") if err != nil { log.Fatalf("%v", err) } h1Size := 512 numHeads := 8 n := 128 m := 32 c := ntm.NewEmptyController1(gen.InputSize(), gen.OutputSize(), h1Size, numHeads, n, m) weights := c.WeightsVal() for i := range weights { weights[i] = 1 * (rand.Float64() - 0.5) } losses := make([]float64, 0) doPrint := false rmsp := ntm.NewRMSProp(c) log.Printf("numweights: %d", len(c.WeightsVal())) var bpcSum float64 = 0 for i := 1; ; i++ { x, y := gen.GenSeq() machines := rmsp.Train(x, &ntm.MultinomialModel{Y: y}, 0.95, 0.5, 1e-3, 1e-3) numChar := len(y) / 2 l := (&ntm.MultinomialModel{Y: y[numChar+1:]}).Loss(ntm.Predictions(machines[numChar+1:])) bpc := l / float64(numChar) bpcSum += bpc acc := 100 if i%acc == 0 { bpc := bpcSum / float64(acc) bpcSum = 0 losses = append(losses, bpc) log.Printf("%d, bpc: %f, seq length: %d", i, bpc, len(y)) } handleHTTP(c, losses, &doPrint) if i%10 == 0 && doPrint { printDebug(y, machines) } } }
func main() { flag.Parse() if flag.NArg() == 0 { log.Fatal("missing file name") } name := flag.Args()[0] f, err := os.Open(name) if err != nil { log.Fatal(err) } defer f.Close() var r io.Reader if path.Ext(name) == ".gz" { gz, err := gzip.NewReader(f) if err != nil { log.Fatal(err) } name = strings.TrimSuffix(name, ".gz") r = gz } else { r = f } // blas64.Use(cgo.Implementation{}) blas64.Use(native.Implementation{}) var aDok *sparse.DOK switch path.Ext(name) { case ".mtx": aDok, err = readMatrixMarket(r) case ".rsa": log.Fatal("reading of Harwell-Boeing format not yet implemented") default: log.Fatal("unknown file extension") } if err != nil { log.Fatal(err) } a := sparse.NewCSR(aDok) n, _ := a.Dims() // Create the right-hand side so that the solution is [1 1 ... 1]. x := make([]float64, n) for i := range x { x[i] = 1 } xVec := mat64.NewVector(n, x) bVec := mat64.NewVector(n, make([]float64, n)) sparse.MulMatVec(bVec, 1, false, a, xVec) result, err := iterative.Solve(a, bVec, nil, nil, &iterative.CG{}) if err != nil { log.Fatal(err) } if result.X.Len() > 10 { fmt.Println("Solution[:10]:", result.X.RawVector().Data[:10]) } else { fmt.Println("Solution:", result.X.RawVector()) } }