Exemplo n.º 1
0
// readShuffled reads all training instances from in and returns them in a
// pseudo-random order driven by rand.
//
// The shuffle happens while reading (inside-out Fisher-Yates): for every new
// instance an index in [0, len] is drawn, the instance is appended, and it is
// then swapped into the drawn position. Read errors other than io.EOF are
// handed to common.ExitIfError.
func readShuffled(in io.Reader, rand *rand.Rand) []TrainingInstance {
	reader, err := input.NewTrainDataReader(in)
	common.ExitIfError("Error reading data: ", err)

	shuffled := make([]TrainingInstance, 0)

	for {
		scanErr := reader.Scan()
		if scanErr == io.EOF {
			return shuffled
		}
		common.ExitIfError("Error reading data: ", scanErr)

		// Dup is called so the stored vector is not the one held by the reader.
		next := TrainingInstance{
			X: reader.InputVector().Dup(),
			Y: reader.Label(),
		}

		// Draw the target slot before growing the slice, so that the slot
		// one past the current end (i.e. "stay at the end") is possible.
		slot := int(rand.Int63n(int64(len(shuffled)) + 1))

		shuffled = append(shuffled, next)
		last := len(shuffled) - 1

		// When slot == last this swap is a no-op; otherwise the previous
		// occupant of slot moves to the end and the new instance takes its place.
		shuffled[slot], shuffled[last] = shuffled[last], shuffled[slot]
	}
}
Exemplo n.º 2
0
// main prints the training instances stored in the data file named by the
// single command-line argument. Each instance is written to standard output
// as "<label> <vector>"; when the features flag is set only the
// addr.FEATURE layer of the input vector is printed, otherwise all of it.
func main() {
	flag.Parse()

	if flag.NArg() != 1 {
		flag.Usage()
		os.Exit(1)
	}

	dataFile, err := os.Open(flag.Arg(0))
	common.ExitIfError("Could not open data file: ", err)
	defer dataFile.Close()

	reader, err := input.NewTrainDataReader(dataFile)
	common.ExitIfError("Error reading data: ", err)

	for {
		scanErr := reader.Scan()
		if scanErr == io.EOF {
			// End of data: nothing left to print.
			return
		}
		common.ExitIfError("Error reading data: ", scanErr)

		label := reader.Label()
		vector := reader.InputVector()

		if !*features {
			fmt.Printf("%d %s\n", label, floatSliceToString(vector.All()))
		} else {
			fmt.Printf("%d %s\n", label, floatSliceToString(vector.Layer(addr.FEATURE)))
		}
	}
}
Exemplo n.º 3
0
// writeTransitions stores the contents of labelNumberer in the file
// transitionsFilename, using ts to serialize transitions.
//
// The transition system must implement system.TransitionSerializer; if it
// does not, the program is terminated through log.Fatal. File creation and
// write errors are passed to common.ExitIfError.
func writeTransitions(ts system.TransitionSystem, labelNumberer *system.LabelNumberer, transitionsFilename string) {
	serializer, canSerialize := ts.(system.TransitionSerializer)
	if !canSerialize {
		log.Fatal("Transition system does not implement transition serialization")
	}

	out, createErr := os.Create(transitionsFilename)
	common.ExitIfError("Cannot create transition file:", createErr)
	defer out.Close()

	common.ExitIfError("Cannot create label file:", labelNumberer.WriteLabelNumberer(out, serializer))
}
Exemplo n.º 4
0
// main normalizes a data file. It expects three arguments: the configuration
// file, the input data file and the output data file.
//
// If the normalization file named in the configuration exists, its
// parameters are read; otherwise they are extracted from the input data and
// written to that file. The input data is then normalized and written to the
// output file.
func main() {
	flag.Parse()

	if flag.NArg() != 3 {
		flag.Usage()
		os.Exit(1)
	}

	config := common.ReadConfigOrExit(flag.Arg(0))

	// Report every layer for which input normalization is switched off.
	layerNotices := []struct {
		normalizeInput bool
		notice         string
	}{
		{config.Embeddings.Word.NormalizeInput, "Token layer inputs will not be normalized"},
		{config.Embeddings.Tag.NormalizeInput, "Tag layer inputs will not be normalized"},
		{config.Embeddings.DepRel.NormalizeInput, "Dependency layer inputs will not be normalized"},
		{config.Embeddings.Feature.NormalizeInput, "Feature layer inputs will not be normalized"},
		{config.Embeddings.Char.NormalizeInput, "Character layer inputs will not be normalized"},
	}
	for _, layer := range layerNotices {
		if !layer.normalizeInput {
			log.Println(layer.notice)
		}
	}

	normFilename := config.Parser.Normalisation
	dataFilename, outputFilename := flag.Arg(1), flag.Arg(2)

	var normalizer *input.Normalizer
	if !fileExists(normFilename) {
		log.Print("Extracting normalization parameters from data")

		newAccumulator := func() normalization.Accumulator {
			return normalization.NewVarianceAccumulator()
		}
		acc := input.NewAccumulator(normLayers(config), newAccumulator)

		extractErr := extractParameters(dataFilename, acc)
		common.ExitIfError("Error extracting normalizer parameters: ", extractErr)

		normalizer = acc.Normalizer()

		// Persist the freshly extracted parameters so later runs can reuse them.
		writeErr := writeNormalizer(normFilename, normalizer)
		common.ExitIfError("Error writing normalizer parameters: ", writeErr)
	} else {
		log.Printf("Read normalization parameters from %s", normFilename)
		normalizer = common.ReadNormalizerOrExit(normFilename)
	}

	log.Printf("Normalizing data from %s and writing to %s", dataFilename, outputFilename)
	normalizeData(normalizer, dataFilename, outputFilename)
}
Exemplo n.º 5
0
// writeData writes every instance, in slice order, to out through a training
// data writer. A failing write is passed to common.ExitIfError.
func writeData(out io.Writer, instances []TrainingInstance) {
	writer := input.NewTrainDataWriter(out)

	for i := range instances {
		common.ExitIfError("Error writing data: ", writer.Write(instances[i].Y, instances[i].X))
	}
}
Exemplo n.º 6
0
// main shuffles a training data file. It expects two arguments: the input
// data file and the output file. The random number generator is seeded from
// the seed flag.
func main() {
	flag.Parse()

	if flag.NArg() != 2 {
		flag.Usage()
		os.Exit(1)
	}

	in, err := os.Open(flag.Arg(0))
	common.ExitIfError("Could not open data file: ", err)
	defer in.Close()

	out, err := os.Create(flag.Arg(1))
	common.ExitIfError("Could not open output data file for writing: ", err)
	defer out.Close()

	rng := rand.New(rand.NewSource(*seed))

	writeData(out, readShuffled(in, rng))
}
Exemplo n.º 7
0
// main builds a new embeddings file from an existing embeddings file and a
// single-layer model. It expects three arguments: the input embeddings, the
// model, and the output embeddings file.
//
// For the word at position i of the input embeddings (in iteration order),
// the output vector consists of the weights at index o*Inputs()+i for every
// output o of the layer. The program exits with status 1 when the model does
// not have exactly one layer or when the layer's input count differs from the
// number of input embeddings.
func main() {
	flag.Parse()
	if flag.NArg() != 3 {
		flag.Usage()
		os.Exit(1)
	}

	vecs := common.ReadEmbeddingsOrExit(common.Embedding{flag.Arg(0), false, false})
	network := common.ReadModelOrExit(flag.Arg(1), cblas.Implementation{})

	out, err := os.Create(flag.Arg(2))
	common.ExitIfError("Cannot open output vectors for writing: ", err)
	defer out.Close()

	if network.Layers() != 1 {
		fmt.Fprintf(os.Stderr, "Weight file contains %d layers, expected 1", network.Layers())
		os.Exit(1)
	}

	layer := network.Layer(0)
	weights := layer.W()

	if layer.Inputs() != uint(vecs.Size()) {
		fmt.Fprintf(os.Stderr, "Embedding layer and one-hot size mismatch: %d - %d", layer.Inputs(), vecs.Size())
		os.Exit(1)
	}

	mergedVecs := go2vec.NewEmbeddings(int(layer.Outputs()))

	wordIdx := 0
	// NOTE(review): vec is reused for every word, which is only correct if
	// Put copies its argument - confirm against go2vec.
	vec := make([]float32, layer.Outputs())
	vecs.Iterate(func(word string, vector []float32) bool {
		for idx := range vec {
			vec[idx] = weights[uint(idx)*layer.Inputs()+uint(wordIdx)]
		}

		mergedVecs.Put(word, vec)
		wordIdx++
		return true
	})

	writer := bufio.NewWriter(out)
	mergedVecs.Write(writer)

	// A bufio.Writer keeps the first write error and returns it from Flush,
	// so checking Flush also reports failures of the buffered writes above
	// instead of silently leaving a truncated output file.
	common.ExitIfError("Cannot write output vectors: ", writer.Flush())
}
Exemplo n.º 8
0
// run parses CoNLL-X sentences with parser and writes the result to standard
// output. Input is read from the file named by the second command-line
// argument when exactly two arguments are given, and from standard input
// otherwise.
//
// NOTE(review): the loop stops on any error returned by ReadSentence, not
// only at end of input, and the result of WriteSentence is not inspected.
func run(parser system.Parser) {
	inputFile := os.Stdin
	if flag.NArg() == 2 {
		opened, err := os.Open(flag.Arg(1))
		common.ExitIfError("Cannot open data:", err)
		defer opened.Close()
		inputFile = opened
	}

	sentences := conllx.NewReader(bufio.NewReader(inputFile))
	output := conllx.NewWriter(os.Stdout)

	for {
		sentence, readErr := sentences.ReadSentence()
		if readErr != nil {
			return
		}

		deps, parseErr := parser.Parse(sentence)
		if parseErr != nil {
			log.Fatal(parseErr)
		}

		// Reset head and relation of every token first, so that
		// dependencies already present in the input cannot leak into the
		// output.
		for i := range sentence {
			sentence[i].SetHead(0)
			sentence[i].SetHeadRel("NULL")
		}

		// Dependent is used minus one to index the token slice.
		for dep := range deps {
			tokenIdx := dep.Dependent - 1
			sentence[tokenIdx].SetHead(dep.Head)
			sentence[tokenIdx].SetHeadRel(dep.Relation)
		}

		output.WriteSentence(sentence)
	}
}
Exemplo n.º 9
0
// main creates training instances for the parser. It expects three
// arguments: the configuration file, the training data and the file the
// instances are written to.
//
// If the configuration names a transitions file that already exists, its
// label numbering is reused; if it names one that does not exist, the
// numbering collected during this run is written to it afterwards.
func main() {
	flag.Parse()

	if flag.NArg() != 3 {
		flag.Usage()
		os.Exit(1)
	}

	config := common.ReadConfigOrExit(flag.Arg(0))

	transitionSystem, ok := common.TransitionSystems[config.Parser.System]
	if !ok {
		log.Fatalf("Unknown transition system: %s", config.Parser.System)
	}

	oracleConstructor, ok := common.Oracles[config.Parser.System]
	if !ok {
		log.Fatalf("Unknown transition system: %s", config.Parser.System)
	}

	log.Printf("Transition system: %s", config.Parser.System)

	ilas := common.ReadIlasOrExit(config.Parser.Inputs)

	// labelNumberer stays nil unless an existing transitions file is reused.
	var labelNumberer *system.LabelNumberer
	if config.Parser.Transitions != "" {
		if _, err := os.Stat(config.Parser.Transitions); err == nil {
			log.Printf("Transitions filename %s exists, reusing...", config.Parser.Transitions)
			labelNumberer = common.ReadTransitionsOrExit(config.Parser.Transitions, transitionSystem)
		}
	}

	instanceWriter, err := os.Create(flag.Arg(2))
	common.ExitIfError("Cannot open instance file for writing:", err)
	defer instanceWriter.Close()
	trainDataWriter := input.NewTrainDataWriter(instanceWriter)

	layerEmbeddings := common.MustReadAllEmbeddings(config.Embeddings)

	realizer := input.NewInputVectorRealizer(ilas, layerEmbeddings, nil)

	var collector *common.WritingCollector
	if labelNumberer == nil {
		collector = common.NewWritingCollector(realizer, trainDataWriter)
	} else {
		collector = common.NewWritingCollectorWithLabelNumberer(realizer, labelNumberer, trainDataWriter)
	}

	trainer := system.NewGreedyTrainer(transitionSystem, collector)

	f, err := os.Open(flag.Arg(1))
	common.ExitIfError("Cannot open training data:", err)
	defer f.Close()

	log.Println("Creating training instances...")

	// Keep the result of ProcessData: previously it was discarded and the
	// check below looked at the stale error from os.Open, so processing
	// failures were never reported.
	// NOTE(review): assumes ProcessData returns an error - confirm.
	err = common.ProcessData(f, func(s []conllx.Token) error {
		goldDependencies, err := system.SentenceToDependencies(s)
		if err != nil {
			return fmt.Errorf("Cannot extract dependencies: %s", err.Error())
		}
		trainer.Parse(s, oracleConstructor(goldDependencies))

		return nil
	})
	common.ExitIfError("Cannot process data:", err)

	// Only write the transitions when the configured file does not exist.
	if config.Parser.Transitions != "" {
		if _, err := os.Stat(config.Parser.Transitions); err != nil {
			writeTransitions(transitionSystem, collector.LabelNumberer(), config.Parser.Transitions)
		}
	}
}