func ReadSamples(dir string) (sgd.SampleSet, error) { index, err := speechdata.LoadIndex(dir) if err != nil { return nil, err } var samples sgd.SliceSampleSet for _, sample := range index.Samples { if sample.File == "" { continue } label := cubewhisper.LabelsForMoveString(sample.Label) wavPath := filepath.Join(index.DirPath, sample.File) sampleSeq, err := cubewhisper.ReadAudioFile(wavPath) if err != nil { return nil, err } intLabel := make([]int, len(label)) for i, x := range label { intLabel[i] = int(x) } samples = append(samples, ctc.Sample{Input: sampleSeq, Label: intLabel}) } return samples, nil }
func main() { if len(os.Args) != 3 { fmt.Fprintln(os.Stderr, "Usage: rate <rnn> <sample dir>") os.Exit(1) } rnnData, err := ioutil.ReadFile(os.Args[1]) if err != nil { die("Read RNN", err) } seqFunc, err := rnn.DeserializeBidirectional(rnnData) if err != nil { die("Deserialize RNN", err) } index, err := speechdata.LoadIndex(os.Args[2]) if err != nil { die("Load speech index", err) } log.Println("Crunching numbers...") var res results for _, sample := range index.Samples { if sample.File == "" { continue } label := cubewhisper.LabelsForMoveString(sample.Label) wavPath := filepath.Join(index.DirPath, sample.File) sampleSeq, err := cubewhisper.ReadAudioFile(wavPath) if err != nil { die("Load sample audio", err) } intLabel := make([]int, len(label)) for i, x := range label { intLabel[i] = int(x) } output := evalSample(seqFunc, sampleSeq) likelihood := ctc.LogLikelihood(output, intLabel).Output()[0] res.Likelihoods = append(res.Likelihoods, likelihood) res.SampleIDs = append(res.SampleIDs, sample.ID) } sort.Sort(&res) for i, id := range res.SampleIDs { likelihood := res.Likelihoods[i] fmt.Printf("%d. %s - %e\n", i, id, likelihood) } }