func ClassifyCmd(netPath, imgPath string) { networkData, err := ioutil.ReadFile(netPath) if err != nil { fmt.Fprintln(os.Stderr, "Error reading network:", err) os.Exit(1) } network, err := neuralnet.DeserializeNetwork(networkData) if err != nil { fmt.Fprintln(os.Stderr, "Error deserializing network:", err) os.Exit(1) } img, width, height, err := ReadImageFile(imgPath) if err != nil { fmt.Fprintln(os.Stderr, "Error reading image:", err) os.Exit(1) } firstLayer := network[1].(*neuralnet.ConvLayer) if width != firstLayer.InputWidth || height != firstLayer.InputHeight { fmt.Fprintf(os.Stderr, "Expected dimensions %dx%d but got %dx%d\n", firstLayer.InputWidth, firstLayer.InputHeight, width, height) } output := network.Apply(&autofunc.Variable{Vector: img}).Output() for i, x := range output { fmt.Printf("Class %d: probability %f\n", i, math.Exp(x)) } }
// DeserializeNetworkSeqFunc deserializes a NetworkSeqFunc // that was previously serialized. func DeserializeNetworkSeqFunc(d []byte) (*NetworkSeqFunc, error) { net, err := neuralnet.DeserializeNetwork(d) if err != nil { return nil, err } return &NetworkSeqFunc{Network: net}, nil }
func DreamCmd(netPath, imgPath string) { networkData, err := ioutil.ReadFile(netPath) if err != nil { fmt.Fprintln(os.Stderr, "Error reading network:", err) os.Exit(1) } network, err := neuralnet.DeserializeNetwork(networkData) if err != nil { fmt.Fprintln(os.Stderr, "Error deserializing network:", err) os.Exit(1) } convIn := network[1].(*neuralnet.ConvLayer) inputImage := &autofunc.Variable{ Vector: make(linalg.Vector, convIn.InputWidth*convIn.InputHeight* convIn.InputDepth), } for i := range inputImage.Vector { inputImage.Vector[i] = rand.Float64()*0.01 + 0.5 } desiredOut := linalg.Vector{0, 1} cost := neuralnet.DotCost{} grad := autofunc.NewGradient([]*autofunc.Variable{inputImage}) for i := 0; i < 1000; i++ { output := network.Apply(inputImage) costOut := cost.Cost(desiredOut, output) grad.Zero() log.Println("cost is", costOut.Output()[0]) costOut.PropagateGradient(linalg.Vector{1}, grad) grad.AddToVars(-0.01) } newImage := image.NewRGBA(image.Rect(0, 0, convIn.InputWidth, convIn.InputHeight)) var idx int for y := 0; y < convIn.InputHeight; y++ { for x := 0; x < convIn.InputWidth; x++ { r := uint8(0xff * inputImage.Vector[idx]) g := uint8(0xff * inputImage.Vector[idx+1]) b := uint8(0xff * inputImage.Vector[idx+2]) newImage.SetRGBA(x, y, color.RGBA{ R: r, G: g, B: b, A: 0xff, }) idx += 3 } } output, err := os.Create(imgPath) if err != nil { fmt.Fprintln(os.Stderr, "Failed to create output file:", err) os.Exit(1) } defer output.Close() png.Encode(output, newImage) }
func Run() { encoderPath := os.Args[2] encoderData, err := ioutil.ReadFile(encoderPath) if err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } network, err := neuralnet.DeserializeNetwork(encoderData) if err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } inputPath := os.Args[3] outputPath := os.Args[4] f, err := os.Open(inputPath) if err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } defer f.Close() inputImage, _, err := image.Decode(f) if err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } res := network.Apply(&autofunc.Variable{Vector: ImageTensor(inputImage).Data}) tensor := &neuralnet.Tensor3{ Width: inputImage.Bounds().Dx(), Height: inputImage.Bounds().Dy(), Depth: 3, Data: res.Output(), } image := ImageFromTensor(tensor) outFile, err := os.Create(outputPath) if err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } defer outFile.Close() if err := png.Encode(outFile, image); err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } }
func TrainCmd(netPath, dirPath string) { log.Println("Loading samples...") images, width, height, err := LoadTrainingImages(dirPath) if err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } log.Println("Creating network...") var network neuralnet.Network networkData, err := ioutil.ReadFile(netPath) if err == nil { network, err = neuralnet.DeserializeNetwork(networkData) if err != nil { fmt.Fprintln(os.Stderr, "Failed to load network:", err) os.Exit(1) } log.Println("Loaded network from file.") } else { mean, stddev := sampleStatistics(images) convLayer := &neuralnet.ConvLayer{ FilterCount: FilterCount, FilterWidth: 4, FilterHeight: 4, Stride: 2, InputWidth: width, InputHeight: height, InputDepth: ImageDepth, } maxLayer := &neuralnet.MaxPoolingLayer{ XSpan: 3, YSpan: 3, InputWidth: convLayer.OutputWidth(), InputHeight: convLayer.OutputHeight(), InputDepth: convLayer.OutputDepth(), } convLayer1 := &neuralnet.ConvLayer{ FilterCount: FilterCount1, FilterWidth: 3, FilterHeight: 3, Stride: 2, InputWidth: maxLayer.OutputWidth(), InputHeight: maxLayer.OutputHeight(), InputDepth: maxLayer.InputDepth, } network = neuralnet.Network{ &neuralnet.RescaleLayer{ Bias: -mean, Scale: 1 / stddev, }, convLayer, neuralnet.HyperbolicTangent{}, maxLayer, neuralnet.HyperbolicTangent{}, convLayer1, neuralnet.HyperbolicTangent{}, &neuralnet.DenseLayer{ InputCount: convLayer1.OutputWidth() * convLayer1.OutputHeight() * convLayer1.OutputDepth(), OutputCount: HiddenSize, }, neuralnet.HyperbolicTangent{}, &neuralnet.DenseLayer{ InputCount: HiddenSize, OutputCount: len(images), }, &neuralnet.LogSoftmaxLayer{}, } network.Randomize() log.Println("Created new network.") } samples := neuralSamples(images) sgd.ShuffleSampleSet(samples) validationCount := int(ValidationFraction * float64(samples.Len())) validationSamples := samples.Subset(0, validationCount) trainingSamples := samples.Subset(validationCount, samples.Len()) costFunc := neuralnet.DotCost{} gradienter := &sgd.Adam{ Gradienter: &neuralnet.BatchRGradienter{ Learner: network.BatchLearner(), CostFunc: &neuralnet.RegularizingCost{ Variables: network.Parameters(), Penalty: Regularization, CostFunc: costFunc, }, }, } sgd.SGDInteractive(gradienter, trainingSamples, StepSize, BatchSize, func() bool { log.Printf("Costs: validation=%d/%d cost=%f", countCorrect(network, validationSamples), validationSamples.Len(), neuralnet.TotalCost(costFunc, network, trainingSamples)) return true }) data, _ := network.Serialize() if err := ioutil.WriteFile(netPath, data, 0755); err != nil { fmt.Fprintln(os.Stderr, "Failed to save:", err) os.Exit(1) } }