示例#1
0
// Train fits a NaiveBayes model to data.
//
// For every datum, the datum's class label is counted once in a class
// counter. For each feature on the datum, the class label is also counted in
// a per-feature counter. Every counter is then passed through LogNormalize and
// frozen.
//
// NOTE(review): each per-feature counter is keyed by class label and
// normalized per feature. That looks like log P(class | feature) rather than
// the log P(feature | class) a textbook Naive Bayes uses. Confirm this matches
// what the prediction code expects.
//
// The class prior is frozen with the key set of one of the frozen feature
// distributions. Presumably FreezeMap gives them all the same key set, so the
// prior and the feature distributions index labels consistently (TODO
// confirm). When there are no feature distributions at all (empty data, or no
// datum has any feature), no such key set exists, so the prior is frozen on
// its own instead of passing a nil key set to FreezeWithKeySet.
func Train(data []Datum) *NaiveBayes {
	class := counter.New(0.0)
	features := make(map[string]*counter.Counter)

	for _, datum := range data {
		class.Incr(datum.class)
		for _, f := range datum.features {
			dist, ok := features[f]
			if !ok {
				dist = counter.New(0.0)
				features[f] = dist
			}
			dist.Incr(datum.class)
		}
	}

	class.LogNormalize()
	for _, dist := range features {
		dist.LogNormalize()
	}

	frozenFeatures := frozencounter.FreezeMap(features)

	// Any single distribution's key set serves; the original picked an
	// arbitrary one too (iteration order is unspecified), so stop at the first.
	var keyset *frozencounter.KeySet
	for _, dist := range frozenFeatures {
		keyset = dist.Keys
		break
	}

	if keyset == nil {
		// No feature distributions exist, so there is no shared key set to
		// align with; freeze the prior by itself.
		return &NaiveBayes{FeatureLogDistributions: frozenFeatures, ClassLogPrior: frozencounter.Freeze(class)}
	}

	return &NaiveBayes{FeatureLogDistributions: frozenFeatures, ClassLogPrior: frozencounter.FreezeWithKeySet(class, keyset)}
}
示例#2
0
// tally counts feature occurrences in data, both per class label and per datum.
//
// Side effect: every data[i].featureCounts is overwritten with a frozen
// counter of datum i's own feature counts.
//
// Return values:
//   - counts: a CounterVector built from one feature counter per class label.
//   - features: the key set of data[0]'s frozen feature counts, or nil when
//     data is empty. This assumes FreezeMany gives every datum the same key
//     set (TODO confirm).
//   - labels: the class labels that key counts.Extract(), collected in map
//     iteration order, which is not deterministic.
func tally(data []Datum) (counts *frozencounter.CounterVector, features *frozencounter.KeySet, labels []string) {
	rawCounts := map[string]*counter.Counter{}
	datumCounts := make([]*counter.Counter, 0, len(data))

	for _, datum := range data {
		// Look the per-class counter up once and reuse it for every feature.
		classCounts := rawCounts[datum.class]
		if classCounts == nil {
			classCounts = counter.New(0.0)
			rawCounts[datum.class] = classCounts
		}

		c := counter.New(0.0)
		for _, f := range datum.features {
			classCounts.Incr(f)
			c.Incr(f)
		}
		datumCounts = append(datumCounts, c)
	}

	for idx, c := range frozencounter.FreezeMany(datumCounts) {
		data[idx].featureCounts = c
	}

	counts = frozencounter.NewCounterVector(frozencounter.FreezeMap(rawCounts))

	// Indexing data[0] on an empty slice would panic; leave features nil.
	if len(data) > 0 {
		features = data[0].featureCounts.Keys
	}
	for label := range counts.Extract() {
		labels = append(labels, label)
	}
	return counts, features, labels
}
示例#3
0
// Thaw returns a new, mutable counter.Counter holding the same entries as c.
//
// The result is created with counter.New(c.Keys.Base). Then, for every key
// recorded in c.Keys.Positions, the value stored at that key's position in c
// is copied across with Set. c itself is only read, never modified.
func (c *Counter) Thaw() *counter.Counter {
	keys := c.Keys
	thawed := counter.New(keys.Base)

	for key, pos := range keys.Positions {
		thawed.Set(key, c.values[pos])
	}

	return thawed
}
示例#4
0
// labelDistribution scores every label in weights against the feature counts
// and returns the scores as a newly built frozen counter.
//
// Each label's score is the dot product of that label's feature weights with
// counts. The scores are then passed through LogNormalize before being frozen.
// Presumably LogNormalize yields a normalized log-probability distribution
// over labels; confirm against the counter package. The receiver w is not
// read.
func (w *maxentWeights) labelDistribution(counts *frozencounter.Counter, weights *frozencounter.CounterVector) *frozencounter.Counter {
	scores := counter.New(0.0)

	for label, labelWeights := range weights.Extract() {
		score := labelWeights.DotProduct(counts)
		scores.Set(label, score)
	}
	scores.LogNormalize()

	dist := frozencounter.Freeze(scores)
	return dist
}