예제 #1
0
// Predict scores sample with the current model.
//
// The weights of the sample's non-zero features are combined into a Gaussian
// over the linear score: mean sum(x_i*mu_i), variance sum(x_i^2*var_i).
// Features missing from algo.Model contribute as N(0, params.init_var), but
// unlike in Train they are not inserted into the model. The score variance is
// then widened by params.beta, and that Gaussian's Integral evaluated at
// mean/stddev is returned.
func (algo *EPLogisticRegression) Predict(sample *core.Sample) float64 {
	var mean, vari float64
	for _, f := range sample.Features {
		x := f.Value
		if x == 0.0 {
			// Zero-valued features cannot move the score.
			continue
		}
		wMean, wVari := 0.0, algo.params.init_var
		if w, found := algo.Model[f.Id]; found {
			wMean, wVari = w.Mean, w.Vari
		}
		mean += x * wMean
		vari += x * x * wVari
	}

	score := util.Gaussian{Mean: mean, Vari: vari + algo.params.beta}
	return score.Integral(score.Mean / math.Sqrt(score.Vari))
}
예제 #2
0
// Train makes one online pass over dataset, updating the Gaussian weight
// beliefs in algo.Model in place (presumably an expectation-propagation
// scheme, as the type name suggests — the util helpers are not visible here).
//
// For each sample:
//  1. The weights of its non-zero features are combined into a Gaussian s over
//     the linear score (mean sum(x_i*mu_i), variance sum(x_i^2*var_i)). Features
//     not yet in algo.Model are inserted with prior N(0, params.init_var).
//  2. s, widened by params.beta, is combined with the label-dependent Gaussian
//     produced by UpperTruncateGaussian (Label > 0) or LowerTruncateGaussian
//     (otherwise) at threshold 0. The result is widened by params.beta again
//     and multiplied into s.
//  3. Each weight is multiplied by the message implied by the updated s. It is
//     then blended back toward the prior in precision space (99% current
//     belief, 1% prior), and its variance is floored at 1% of
//     params.init_var.
//
// Features with Value == 0 are skipped in every step.
func (algo *EPLogisticRegression) Train(dataset *core.DataSet) {
	// The weight prior. It is only read, so it is built once and shared
	// across all samples.
	prior := util.Gaussian{Mean: 0.0, Vari: algo.params.init_var}
	minVari := algo.params.init_var * 0.01

	for _, sample := range dataset.Samples {
		// Gaussian over the linear score of this sample.
		s := util.Gaussian{Mean: 0.0, Vari: 0.0}
		for _, feature := range sample.Features {
			if feature.Value == 0.0 {
				continue
			}
			wi, ok := algo.Model[feature.Id]
			if !ok {
				wi = &(util.Gaussian{Mean: 0.0, Vari: algo.params.init_var})
				algo.Model[feature.Id] = wi
			}
			s.Mean += feature.Value * wi.Mean
			s.Vari += feature.Value * feature.Value * wi.Vari
		}

		t := s
		t.Vari += algo.params.beta

		// Label-dependent truncation at 0. The exact semantics of these
		// helpers live in util — NOTE(review): confirm whether t2 is the
		// truncated posterior or the message (posterior / prior).
		t2 := util.Gaussian{Mean: 0.0, Vari: 0.0}
		if sample.Label > 0.0 {
			t2.UpperTruncateGaussian(t.Mean, t.Vari, 0.0)
		} else {
			t2.LowerTruncateGaussian(t.Mean, t.Vari, 0.0)
		}
		t.MultGaussian(&t2)
		s2 := t
		s2.Vari += algo.params.beta
		s0 := s // score Gaussian before this sample's update
		s.MultGaussian(&s2)

		for _, feature := range sample.Features {
			if feature.Value == 0.0 {
				continue
			}
			x := feature.Value
			// The first loop inserted every non-zero feature, so the entry
			// exists. wi points into algo.Model, so updating *wi updates the
			// model; no write-back is needed.
			wi := algo.Model[feature.Id]

			// Remove this weight's own contribution from the pre-update
			// score, then rescale by x to express the message in weight space.
			w2 := util.Gaussian{
				Mean: (s.Mean - (s0.Mean - wi.Mean*x)) / x,
				Vari: (s.Vari + (s0.Vari - wi.Vari*x*x)) / (x * x),
			}
			wi.MultGaussian(&w2)

			// Decay toward the prior: N(mu, var)^0.99 * N(mu0, var0)^0.01.
			// New precision = 0.99/var + 0.01/var0.
			// New mean = newVar * (0.99*mu/var + 0.01*mu0/var0).
			// The prior term must use the prior's variance, not the freshly
			// updated wi.Vari.
			oldVari := wi.Vari
			wi.Vari = oldVari * prior.Vari / (0.99*prior.Vari + 0.01*oldVari)
			wi.Mean = wi.Vari * (0.99*wi.Mean/oldVari + 0.01*prior.Mean/prior.Vari)
			if wi.Vari < minVari {
				wi.Vari = minVari
			}
		}
	}
}