示例#1
0
// CreateState creates a new shared state holding an AROW classifier.
//
// Parameters read from params:
//   - label_field (string, default "label"): stored as the state's labelField.
//   - feature_vector_field (string, default "feature_vector"): stored as the
//     state's featureVectorField.
//   - regularization_weight (number, no default supplied here): must be
//     strictly greater than zero; it is converted to float32 for NewAROW.
//
// ctx is not used. An error is returned when a parameter cannot be extracted,
// when regularization_weight is not strictly positive (NaN included), or when
// NewAROW fails. The NewAROW error is wrapped with %w so callers can still
// inspect the cause with errors.Is / errors.As.
func (c *AROWStateCreator) CreateState(ctx *core.Context, params data.Map) (core.SharedState, error) {
	label, err := pluginutil.ExtractParamAsStringWithDefault(params, "label_field", "label")
	if err != nil {
		return nil, err
	}
	fv, err := pluginutil.ExtractParamAsStringWithDefault(params, "feature_vector_field", "feature_vector")
	if err != nil {
		return nil, err
	}
	rw, err := pluginutil.ExtractParamAndConvertToFloat(params, "regularization_weight")
	if err != nil {
		return nil, err
	}
	// Written as !(rw > 0) instead of rw <= 0 so that NaN, for which every
	// ordered comparison is false, is rejected too.
	if !(rw > 0) {
		return nil, errors.New("regularization_weight parameter must be greater than zero")
	}

	a, err := NewAROW(float32(rw))
	if err != nil {
		return nil, fmt.Errorf("failed to initialize AROW: %w", err)
	}

	return &AROWState{
		arow:               a,
		labelField:         label,
		featureVectorField: fv,
	}, nil
}
// CreateState creates a new shared state holding a model built by
// NewPassiveAggressive.
//
// Parameters read from params:
//   - value_field (string, default "value"): stored as the state's valueField.
//   - feature_vector_field (string, default "feature_vector"): stored as the
//     state's featureVectorField.
//   - regularization_weight (number, no default supplied here): must be
//     strictly greater than zero.
//   - sensitivity (number, no default supplied here): must be zero or greater.
//
// Both numbers are converted to float32 before being passed to
// NewPassiveAggressive. ctx is not used. Any extraction, validation or
// construction error is returned to the caller.
func (c *PassiveAggressiveStateCreator) CreateState(ctx *core.Context, params data.Map) (core.SharedState, error) {
	value, err := pluginutil.ExtractParamAsStringWithDefault(params, "value_field", "value")
	if err != nil {
		return nil, err
	}
	fv, err := pluginutil.ExtractParamAsStringWithDefault(params, "feature_vector_field", "feature_vector")
	if err != nil {
		return nil, err
	}

	rw, err := pluginutil.ExtractParamAndConvertToFloat(params, "regularization_weight")
	if err != nil {
		return nil, err
	}
	// The negated comparisons below also reject NaN: every ordered comparison
	// with NaN is false, so "rw <= 0" / "sen < 0" would let it through.
	if !(rw > 0) {
		return nil, errors.New("regularization_weight parameter must be greater than zero")
	}

	sen, err := pluginutil.ExtractParamAndConvertToFloat(params, "sensitivity")
	if err != nil {
		return nil, err
	}
	if !(sen >= 0) {
		return nil, errors.New("sensitivity parameter must be not less than zero")
	}

	pa, err := NewPassiveAggressive(float32(rw), float32(sen))
	if err != nil {
		return nil, err
	}

	return &PassiveAggressiveState{
		pa:                 pa,
		valueField:         value,
		featureVectorField: fv,
	}, nil
}
示例#3
0
// CreateState creates a new shared state holding a LightLOF model built by
// NewLightLOF.
//
// Parameters read from params:
//   - feature_vector_field (string, default "feature_vector"): stored as the
//     state's featureVectorField.
//   - nearest_neighbor_algorithm (string): "lsh", "minhash" or "euclid_lsh",
//     matched case-insensitively.
//   - hash_num, nearest_neighbor_num, reverse_nearest_neighbor_num (int).
//   - unlearner (string, default "no"): "no" or "random", matched exactly.
//   - max_size (int): read only when unlearner is "random".
//   - seed (int, default 0): read only when unlearner is "random".
//
// When unlearner is "no", both max_size and seed are passed to NewLightLOF as
// zero. Integer parameters that do not fit in the platform's int type are
// rejected with an error instead of being silently truncated. ctx is not used.
func (c *LightLOFStateCreator) CreateState(ctx *core.Context, params data.Map) (core.SharedState, error) {
	// extractInt reads an integer parameter and narrows it to int. int is only
	// 32 bits wide on some platforms, so values that would not survive the
	// conversion are reported instead of being truncated.
	extractInt := func(name string) (int, error) {
		v, err := pluginutil.ExtractParamAsInt(params, name)
		if err != nil {
			return 0, err
		}
		if int64(int(v)) != v {
			return 0, fmt.Errorf("%s is out of range: %d", name, v)
		}
		return int(v), nil
	}

	fv, err := pluginutil.ExtractParamAsStringWithDefault(params, "feature_vector_field", "feature_vector")
	if err != nil {
		return nil, err
	}

	nnAlgoName, err := pluginutil.ExtractParamAsString(params, "nearest_neighbor_algorithm")
	if err != nil {
		return nil, err
	}

	var nnAlgo NNAlgorithm
	switch strings.ToLower(nnAlgoName) {
	case "lsh":
		nnAlgo = LSH
	case "minhash":
		nnAlgo = Minhash
	case "euclid_lsh":
		nnAlgo = EuclidLSH
	default:
		return nil, fmt.Errorf("invalid nearest_neighbor_algorithm: %s", nnAlgoName)
	}

	hashNum, err := extractInt("hash_num")
	if err != nil {
		return nil, err
	}
	nnNum, err := extractInt("nearest_neighbor_num")
	if err != nil {
		return nil, err
	}
	rnnNum, err := extractInt("reverse_nearest_neighbor_num")
	if err != nil {
		return nil, err
	}

	unlearn, err := pluginutil.ExtractParamAsStringWithDefault(params, "unlearner", "no")
	if err != nil {
		return nil, err
	}
	var maxSize int
	var seed int64
	switch unlearn {
	case "no":
		// maxSize and seed keep their zero values.
	case "random":
		// Assign to the outer err rather than shadowing it in this case body.
		if maxSize, err = extractInt("max_size"); err != nil {
			return nil, err
		}
		if seed, err = pluginutil.ExtractParamAsIntWithDefault(params, "seed", 0); err != nil {
			return nil, err
		}
	default:
		return nil, fmt.Errorf("invalid unlearner: %v", unlearn)
	}

	llof, err := NewLightLOF(nnAlgo, hashNum, nnNum, rnnNum, maxSize, seed)
	if err != nil {
		return nil, err
	}
	return &lightLOFState{
		lightLOF:           llof,
		featureVectorField: fv,
	}, nil
}