Example #1
0
func init() {
	// Reference both packages unconditionally so that commenting out the Use
	// call below (to switch BLAS backends) never leaves an import unused.
	var (
		_ = blas64.General{}
		_ = cgo.Implementation{}
	)

	// Select the cgo-backed BLAS implementation, i.e. the one specified in
	// CGO_LDFLAGS. Comment this call out to fall back to the native Go BLAS
	// implementation found in github.com/gonum/blas/native.
	blas64.Use(cgo.Implementation{})
}
Example #2
0
func init() {
	blas64.Use(cgo.Implementation{})
}
Example #3
0
File: main.go Project: philipz/ntm
// main trains an ntm controller on sequences produced by the poem generator
// while serving a small HTTP interface for inspecting the run.
//
// Endpoints registered on the default mux (port 8085):
//   - /Weights:    writes the bytes received back over weightsChan
//   - /Loss:       JSON-encodes the []float64 received back over lossChan
//   - /PrintDebug: sends a signal on printDebugChan
//
// NOTE(review): the training loop has no exit condition, so main never
// returns normally and the deferred calls below (closing the profile file,
// stopping the CPU profile) only run if the loop is given a way to stop.
func main() {
	flag.Parse()
	blas64.Use(cgo.Implementation{})

	if *cpuprofile != "" {
		f, err := os.Create(*cpuprofile)
		if err != nil {
			log.Fatal(err)
		}
		// Deferred first so the file is closed only after profiling has stopped.
		defer f.Close()
		if err := pprof.StartCPUProfile(f); err != nil {
			log.Fatal(err)
		}
		defer pprof.StopCPUProfile()
	}

	http.HandleFunc("/Weights", func(w http.ResponseWriter, r *http.Request) {
		// Hand a reply channel to the receiver of weightsChan and wait for the answer.
		c := make(chan []byte)
		weightsChan <- c
		if _, err := w.Write(<-c); err != nil {
			log.Printf("writing /Weights response: %v", err)
		}
	})
	http.HandleFunc("/Loss", func(w http.ResponseWriter, r *http.Request) {
		// Same request/reply pattern as /Weights, with the reply JSON-encoded.
		c := make(chan []float64)
		lossChan <- c
		if err := json.NewEncoder(w).Encode(<-c); err != nil {
			log.Printf("encoding /Loss response: %v", err)
		}
	})
	http.HandleFunc("/PrintDebug", func(w http.ResponseWriter, r *http.Request) {
		printDebugChan <- struct{}{}
	})
	port := 8085
	go func() {
		log.Printf("Listening on port %d", port)
		if err := http.ListenAndServe(fmt.Sprintf(":%d", port), nil); err != nil {
			log.Fatalf("%v", err)
		}
	}()

	// Fixed seed so the random weight initialisation below is repeatable.
	var seed int64 = 5
	rand.Seed(seed)
	log.Printf("seed: %d", seed)

	gen, err := poem.NewGenerator("data/quantangshi3000.int")
	if err != nil {
		log.Fatalf("%v", err)
	}
	h1Size := 512
	numHeads := 8
	n := 128
	m := 32
	c := ntm.NewEmptyController1(gen.InputSize(), gen.OutputSize(), h1Size, numHeads, n, m)

	// Initialise every weight uniformly in [-0.5, 0.5).
	weights := c.WeightsVal()
	for i := range weights {
		weights[i] = rand.Float64() - 0.5
	}

	// Deliberately a non-nil empty slice: presumably this is what /Loss ends up
	// encoding, and a nil slice would serialise as null instead of [] — confirm
	// against handleHTTP before changing.
	losses := make([]float64, 0)
	doPrint := false

	rmsp := ntm.NewRMSProp(c)
	log.Printf("numweights: %d", len(c.WeightsVal()))

	// reportEvery is the number of training sequences averaged into each
	// reported bpc value.
	const reportEvery = 100
	var bpcSum float64
	for i := 1; ; i++ {
		x, y := gen.GenSeq()
		machines := rmsp.Train(x, &ntm.MultinomialModel{Y: y}, 0.95, 0.5, 1e-3, 1e-3)

		// The loss is measured only on the part of the sequence after index
		// numChar and normalised by numChar.
		numChar := len(y) / 2
		l := (&ntm.MultinomialModel{Y: y[numChar+1:]}).Loss(ntm.Predictions(machines[numChar+1:]))
		bpcSum += l / float64(numChar)

		if i%reportEvery == 0 {
			avgBPC := bpcSum / float64(reportEvery)
			bpcSum = 0
			losses = append(losses, avgBPC)
			log.Printf("%d, bpc: %f, seq length: %d", i, avgBPC, len(y))
		}

		handleHTTP(c, losses, &doPrint)

		if i%10 == 0 && doPrint {
			printDebug(y, machines)
		}
	}
}
Example #4
0
// main reads a sparse matrix from the file named by the first command-line
// argument (optionally gzip-compressed, detected by a ".gz" suffix), builds a
// right-hand side whose exact solution is the all-ones vector, solves the
// system with the iterative CG method and prints (a prefix of) the solution.
//
// Only the ".mtx" extension is handled; any other extension, and every error,
// terminates the program via log.Fatal.
func main() {
	flag.Parse()
	if flag.NArg() == 0 {
		log.Fatal("missing file name")
	}
	name := flag.Arg(0)

	f, err := os.Open(name)
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	// Read through a gzip decompressor when the file name says so, and strip
	// the suffix so the format switch below sees the inner extension.
	var r io.Reader = f
	if path.Ext(name) == ".gz" {
		gz, err := gzip.NewReader(f)
		if err != nil {
			log.Fatal(err)
		}
		defer gz.Close()
		name = strings.TrimSuffix(name, ".gz")
		r = gz
	}

	// blas64.Use(cgo.Implementation{})
	blas64.Use(native.Implementation{})

	var aDok *sparse.DOK
	switch path.Ext(name) {
	case ".mtx":
		aDok, err = readMatrixMarket(r)
	case ".rsa":
		log.Fatal("reading of Harwell-Boeing format not yet implemented")
	default:
		log.Fatal("unknown file extension")
	}
	if err != nil {
		log.Fatal(err)
	}

	a := sparse.NewCSR(aDok)
	rows, cols := a.Dims()
	// The vectors below are all sized by the row count, so a non-square matrix
	// cannot be multiplied or solved consistently; fail with a clear message.
	if rows != cols {
		log.Fatalf("matrix must be square, got %v x %v", rows, cols)
	}
	n := rows

	// Create the right-hand side so that the solution is [1 1 ... 1].
	x := make([]float64, n)
	for i := range x {
		x[i] = 1
	}
	xVec := mat64.NewVector(n, x)
	bVec := mat64.NewVector(n, make([]float64, n))
	sparse.MulMatVec(bVec, 1, false, a, xVec)

	result, err := iterative.Solve(a, bVec, nil, nil, &iterative.CG{})
	if err != nil {
		log.Fatal(err)
	}

	// Print the solution data in both cases; long solutions are cut to the
	// first ten entries.
	sol := result.X.RawVector().Data
	if result.X.Len() > 10 {
		fmt.Println("Solution[:10]:", sol[:10])
	} else {
		fmt.Println("Solution:", sol)
	}
}