コード例 #1
0
ファイル: main.go プロジェクト: quantrocket/risk_model
// train fits a logistic-regression model while alternately re-estimating
// which internet user each credit record (borrower) corresponds to. It runs a
// fixed number of EM-style rounds, each of which:
//   - draws a borrower→internet-user assignment from match via
//     sampleBorrower2IUser, builds a training set from it and refits the
//     model;
//   - rescores match with the freshly trained model via updateMatch, which
//     overwrites match in place.
//
// It returns the final model together with the id2f slice produced by
// constructFeatureVectors for iusers.
func train(records []creditRecord, iusers []internetUser, match [][]float64) (*lr.LogisticRegression, []string) {
	const emRounds = 100

	// Named model (not lr) so the lr package stays visible in this scope.
	model := new(lr.LogisticRegression)
	protos, id2f := constructFeatureVectors(iusers)

	// One slot per record, allocated once and handed to
	// sampleBorrower2IUser every round.
	// NOTE(review): presumably that helper overwrites every entry — confirm.
	assignment := make([]int, len(records))

	for round := 0; round < emRounds; round++ {
		// M-step: sample a hard assignment and refit the model on it.
		sampleBorrower2IUser(match, assignment)
		dataset := constructTrainingData(records, assignment, protos)
		// Init is called every round with the same hyper-parameters.
		// NOTE(review): whether Init also resets learned weights depends on
		// the lr package — confirm.
		model.Init(map[string]string{
			"learning-rate":  "0.1",
			"regularization": "1.0",
			"steps":          "20",
		})
		model.Train(dataset)

		// E-step: refresh the soft matching using the new model.
		updateMatch(model, records, protos, match)
	}

	return model, id2f
}
コード例 #2
0
ファイル: main.go プロジェクト: quantrocket/risk_model
// updateMatch re-weights the borrower→internet-user matching using the
// current model's predictions (the E-step of the training loop).
//
// For every borrower row b of match, candidate internet user i gets the
// unnormalised weight
//
//	match[b][i] * (1-p_i)^returned * p_i^(borrowed-returned)
//
// where p_i = lr.Predict(protos[i]). The row is then rescaled to sum to 1.
// match is modified in place, and records[b] supplies the counts for row b.
//
// The computation is done in log space with a max-shift (log-sum-exp). This
// keeps borrowers with large counts from underflowing every weight to 0,
// which would otherwise turn the row into NaN via 0/0. A zero count
// contributes nothing even when the matching probability is exactly 0 or 1,
// avoiding 0*log(0) = NaN. If no candidate in a row has a finite positive
// weight (e.g. all prior weights are 0), the row is left unchanged rather
// than filled with NaN.
//
// Row positions must line up with protos; a row longer than protos panics
// with an index-out-of-range error.
func updateMatch(lr *lr.LogisticRegression, records []creditRecord, protos []*hector.Sample, match [][]float64) {
	// The per-candidate log-probabilities depend only on the model, so
	// compute them once instead of once per borrower.
	logP := make([]float64, len(protos))
	log1mP := make([]float64, len(protos))
	for i, proto := range protos {
		p := lr.Predict(proto)
		logP[i] = math.Log(p)
		log1mP[i] = math.Log1p(-p) // log(1-p), more accurate for small p
	}

	// countTimesLog returns count*logProb, but 0 when count is 0 so that a
	// -Inf log-probability (p exactly 0 or 1) does not produce NaN.
	countTimesLog := func(count, logProb float64) float64 {
		if count == 0 {
			return 0
		}
		return count * logProb
	}

	var logW []float64 // scratch buffer reused across rows
	for borrower, row := range match {
		returned := records[borrower].returned
		notReturned := records[borrower].borrowed - returned
		r, nr := float64(returned), float64(notReturned)

		if cap(logW) < len(row) {
			logW = make([]float64, len(row))
		}
		logW = logW[:len(row)]

		// Log of each unnormalised weight, tracking the maximum for the shift.
		maxW := math.Inf(-1)
		for iuser, gamma := range row {
			w := math.Log(gamma) + countTimesLog(r, log1mP[iuser]) + countTimesLog(nr, logP[iuser])
			logW[iuser] = w
			if w > maxW {
				maxW = w
			}
		}
		if math.IsInf(maxW, 0) {
			// No finite positive weight: the posterior is undefined, so keep
			// the existing row instead of writing NaN.
			continue
		}

		// Shift by the max so the largest term is exp(0) = 1; norm >= 1.
		var norm float64
		for iuser, w := range logW {
			e := math.Exp(w - maxW)
			logW[iuser] = e
			norm += e
		}
		for iuser := range row {
			row[iuser] = logW[iuser] / norm
		}
	}
}