예제 #1
0
// Run is the block's main loop. It listens on the block's channels and draws
// samples from a categorical distribution.
//
//   - inrule: expects a map[string]interface{} with a "Weights" array. The
//     weights are normalised to sum to one and replace the current
//     distribution. A rule that is not a map, cannot be parsed, contains a
//     negative weight, or whose weights sum to zero is reported via b.Error
//     and leaves the current distribution (and the weights reported by
//     queryrule) unchanged.
//   - quit: returns from the loop.
//   - inpoll: emits {"sample": v} on b.out, where v is the sampler's result
//     converted to float64.
//   - queryrule: replies with the current normalised weights.
//
// Until a valid rule arrives the distribution has a single weight of 1.
func (b *Categorical) Run() {
	θ := []float64{1.0}
	sampler := NewCategoricalSampler(θ)
	for {
		select {
		case ruleI := <-b.inrule:
			rule, ok := ruleI.(map[string]interface{})
			if !ok {
				b.Error(errors.New("couldn't assert rule to map"))
				continue
			}
			// Parse into a fresh slice and only commit it once validated, so
			// a bad rule can't clobber the weights the live sampler uses.
			weights, err := util.ParseArrayFloat(rule, "Weights")
			if err != nil {
				b.Error(err)
				continue
			}
			Z := 0.0
			negative := false
			for _, w := range weights {
				if w < 0 {
					negative = true
				}
				Z += w
			}
			if negative {
				b.Error(errors.New("Weights must not be negative"))
				continue
			}
			if Z == 0 {
				b.Error(errors.New("Weights must not sum to zero"))
				continue
			}
			// normalise so the weights form a probability distribution
			for i := range weights {
				weights[i] /= Z
			}
			θ = weights
			sampler = NewCategoricalSampler(θ)
		case <-b.quit:
			// quit the block
			return
		case <-b.inpoll:
			// deal with a poll request
			b.out <- map[string]interface{}{
				"sample": float64(sampler()),
			}
		case c := <-b.queryrule:
			// deal with a query request
			c <- map[string]interface{}{
				"Weights": θ,
			}
		}
	}
}
예제 #2
0
// Run is the block's main loop. Here we listen on the different channels we set up.
//
//   - inrule: a rule with "Weights" (β) and "FeaturePaths". Each path is
//     compiled into a jee expression tree. The rule is applied only if every
//     field parses, every path compiles, and there is exactly one weight per
//     feature path. Otherwise the error is reported via b.Error and the
//     previous rule stays in effect.
//   - in: evaluates each feature path against the message (every feature must
//     be a float64), computes μ = Σ βi·xi, and emits {"Response": y} where
//     y is 1 if rand.Float64() <= logit(μ) and 0 otherwise. logit is
//     presumably the logistic sigmoid — NOTE(review): confirm. Messages are
//     ignored until a rule has been applied. A message whose features cannot
//     be extracted is reported and dropped.
//   - queryrule: replies with the current Weights and FeaturePaths.
//   - quit: returns from the loop.
func (b *LogisticModel) Run() {
	var β []float64
	var featurePaths []string
	var featureTrees []*jee.TokenTree

	for {
	Loop:
		select {
		case rule := <-b.inrule:
			// Parse the whole rule into locals first. Committing field by
			// field would leave β and featureTrees out of step on error.
			newβ, err := util.ParseArrayFloat(rule, "Weights")
			if err != nil {
				b.Error(err)
				continue
			}
			newPaths, err := util.ParseArrayString(rule, "FeaturePaths")
			if err != nil {
				b.Error(err)
				continue
			}
			// The linear predictor indexes x by the position of each weight,
			// so a length mismatch would either panic or silently ignore
			// features.
			if len(newβ) != len(newPaths) {
				b.Error(errors.New("Weights and FeaturePaths must have the same length"))
				continue
			}
			newTrees := make([]*jee.TokenTree, len(newPaths))
			for i, path := range newPaths {
				token, err := jee.Lexer(path)
				if err != nil {
					b.Error(err)
					// abandon the whole rule rather than keep a nil tree
					break Loop
				}
				tree, err := jee.Parser(token)
				if err != nil {
					b.Error(err)
					break Loop
				}
				newTrees[i] = tree
			}
			β, featurePaths, featureTrees = newβ, newPaths, newTrees
		case <-b.quit:
			// quit the block
			return
		case msg := <-b.in:
			if featureTrees == nil {
				// no rule applied yet
				continue
			}
			x := make([]float64, len(featureTrees))
			for i, tree := range featureTrees {
				feature, err := jee.Eval(tree, msg)
				if err != nil {
					b.Error(err)
					break Loop
				}
				fi, ok := feature.(float64)
				if !ok {
					b.Error(errors.New("features must be float64"))
					break Loop
				}
				x[i] = fi
			}
			μ := 0.0
			for i, βi := range β {
				μ += βi * x[i]
			}
			var y float64
			if rand.Float64() <= logit(μ) {
				y = 1
			}
			b.out <- map[string]interface{}{
				"Response": y,
			}
		case MsgChan := <-b.queryrule:
			// deal with a query request
			MsgChan <- map[string]interface{}{
				"Weights":      β,
				"FeaturePaths": featurePaths,
			}
		}
	}
}
예제 #3
0
// Run is the block's main loop. It drives an SGD kernel (sgd.SgdKernel)
// running in its own goroutine and talks to it over unbuffered channels.
//
//   - inrule: a rule with "FeaturePaths", "ResponsePath", "Lossfunc"
//     ("linear" | "logistic"), "Stepfunc" ("inverse" | "constant" | "bottou")
//     and "InitialState". The rule is validated in full before anything
//     changes. On success, any running kernel is stopped and a new one is
//     started with the parsed configuration. On failure, the error is reported
//     via b.Error and the previous configuration and kernel stay in effect.
//   - in: extracts the features and response (all must be float64) and sends
//     them to the kernel as an sgd.Obs. Ignored while no kernel is running.
//   - inpoll: emits {"params": [...]} with the state reported by the kernel,
//     or {"params": nil} if no kernel has been started.
//   - queryrule: replies with the configuration of the running kernel.
//   - quit: stops the kernel, if one is running, and returns.
//
// Because the kernel channels are unbuffered, every send to the kernel is
// guarded by kernelStarted. Without that guard, a send with no receiver
// blocks this loop forever.
func (b *Learn) Run() {
	dataChan := make(chan sgd.Obs)
	paramChan := make(chan sgd.Params)
	stateChan := make(chan chan []float64)
	kernelQuitChan := make(chan bool)

	lossfuncs := map[string]sgd.LossFunc{
		"linear":   sgd.GradLinearLoss,
		"logistic": sgd.GradLogisticLoss,
	}
	stepfuncs := map[string]sgd.StepFunc{
		"inverse":  sgd.EtaInverse,
		"constant": sgd.EtaConstant,
		"bottou":   sgd.EtaBottou,
	}
	kernelStarted := false

	// Configuration of the running kernel. These are only assigned together,
	// after a rule has been fully validated.
	var responsePath, lossfuncString, stepfuncString string
	var featurePaths []string
	var θ_0 []float64
	var featureTrees []*jee.TokenTree
	var responseTree *jee.TokenTree

	for {
	Loop:
		select {
		case rule := <-b.inrule:
			newFeaturePaths, err := util.ParseArrayString(rule, "FeaturePaths")
			if err != nil {
				b.Error(err)
				continue
			}
			newFeatureTrees := make([]*jee.TokenTree, len(newFeaturePaths))
			for i, path := range newFeaturePaths {
				token, err := jee.Lexer(path)
				if err != nil {
					b.Error(err)
					// abandon the whole rule rather than keep a nil tree
					break Loop
				}
				tree, err := jee.Parser(token)
				if err != nil {
					b.Error(err)
					break Loop
				}
				newFeatureTrees[i] = tree
			}
			newResponsePath, err := util.ParseString(rule, "ResponsePath")
			if err != nil {
				b.Error(err)
				continue
			}
			token, err := jee.Lexer(newResponsePath)
			if err != nil {
				b.Error(err)
				continue
			}
			newResponseTree, err := jee.Parser(token)
			if err != nil {
				b.Error(err)
				continue
			}
			newLossfunc, err := util.ParseString(rule, "Lossfunc")
			if err != nil {
				b.Error(err)
				continue
			}
			newStepfunc, err := util.ParseString(rule, "Stepfunc")
			if err != nil {
				b.Error(err)
				continue
			}
			// An unknown name would otherwise start the kernel with a nil func.
			grad, ok := lossfuncs[newLossfunc]
			if !ok {
				b.Error(errors.New("Unknown loss function: " + newLossfunc))
				continue
			}
			step, ok := stepfuncs[newStepfunc]
			if !ok {
				b.Error(errors.New("Unknown step function: " + newStepfunc))
				continue
			}
			initialState, err := util.ParseArrayFloat(rule, "InitialState")
			if err != nil {
				b.Error(err)
				continue
			}

			// The rule is valid. Only now replace the running kernel, if any.
			if kernelStarted {
				kernelQuitChan <- true
			}
			featurePaths, featureTrees = newFeaturePaths, newFeatureTrees
			responsePath, responseTree = newResponsePath, newResponseTree
			lossfuncString, stepfuncString = newLossfunc, newStepfunc
			θ_0 = initialState
			// Hand the kernel its own copy so that nothing it does to its
			// state can alias the InitialState reported by queryrule.
			kernelθ := append([]float64(nil), θ_0...)
			go sgd.SgdKernel(dataChan, paramChan, stateChan, kernelQuitChan, grad, step, kernelθ)
			kernelStarted = true

		case <-b.quit:
			// Only signal a kernel that exists. Otherwise the unbuffered send
			// would block forever.
			if kernelStarted {
				kernelQuitChan <- true
			}
			return
		case msg := <-b.in:
			// featureTrees/responseTree are set iff a kernel is running. With
			// no kernel, nothing would receive on dataChan.
			if !kernelStarted {
				continue
			}
			x := make([]float64, len(featureTrees))
			for i, tree := range featureTrees {
				feature, err := jee.Eval(tree, msg)
				if err != nil {
					b.Error(err)
					break Loop
				}
				fi, ok := feature.(float64)
				if !ok {
					b.Error(errors.New("features must be float64"))
					break Loop
				}
				x[i] = fi
			}
			responseI, err := jee.Eval(responseTree, msg)
			if err != nil {
				b.Error(err)
				continue
			}
			y, ok := responseI.(float64)
			if !ok {
				b.Error(errors.New("response must be float64"))
				continue
			}
			dataChan <- sgd.Obs{
				X: x,
				Y: y,
			}
		case <-b.inpoll:
			var params []interface{}
			if kernelStarted {
				kernelMsgChan := make(chan []float64)
				stateChan <- kernelMsgChan
				model := <-kernelMsgChan
				params = make([]interface{}, len(model))
				for i, p := range model {
					params[i] = p
				}
			}
			b.out <- map[string]interface{}{
				"params": params,
			}
		case c := <-b.queryrule:
			c <- map[string]interface{}{
				"Lossfunc":     lossfuncString,
				"Stepfunc":     stepfuncString,
				"FeaturePaths": featurePaths,
				"ResponsePath": responsePath,
				"InitialState": θ_0,
			}
		}
	}
}