forked from akualab/gjoa
/
model.go
303 lines (246 loc) · 7.15 KB
/
model.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
// Copyright (c) 2014 AKUALAB INC., All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package model
import (
"fmt"
"math/rand"
)
// DefaultSeed is the default random seed made available to model
// implementations.
const DefaultSeed = 33
// Modeler is a complete implementation of a statistical model in gjoa: it
// combines the Trainer, Predictor, Scorer and Sampler capabilities with a
// name and an observation dimensionality.
type Modeler interface {
	Trainer
	Predictor
	Scorer
	Sampler

	// Name returns the model name.
	Name() string

	// Dim returns the dimensionality of the observation vector.
	Dim() int
}
// Trainer is implemented by models that can do statistical learning.
type Trainer interface {
	// Update updates the model using weighted samples: each sample x[i]
	// is weighted by w(x[i]).
	Update(x Observer, w func(Obs) float64) error

	// UpdateOne updates the model using the single sample o with weight w.
	UpdateOne(o Obs, w float64)

	// Estimate estimates the model parameters.
	Estimate() error

	// Clear clears all model parameters.
	Clear()
}
var (
	// NoWeight is a weighting function for the Update() method that
	// assigns a weight of one to every observation.
	NoWeight = func(o Obs) float64 { return 1.0 }

	// Weight returns a weighting function for the Update() method that
	// assigns the constant weight w to every observation.
	Weight = func(w float64) func(o Obs) float64 {
		constant := func(_ Obs) float64 { return w }
		return constant
	}
)
type (
	// Predictor returns labels with a hypothesis given the observations.
	Predictor interface {
		Predict(x Observer) ([]Labeler, error)
	}

	// Scorer computes log probabilities.
	Scorer interface {
		// LogProb returns the log probability computed for observation x.
		LogProb(x Obs) float64
	}

	// Labeler manages data labels.
	Labeler interface {
		// String returns a human-readable name for the label.
		String() string
		// IsEqual compares this label with label.
		IsEqual(label Labeler) bool
	}

	// Obs is a generic interface to handle observed data.
	// Each observation may have a value and a label.
	Obs interface {
		// ID returns the observation's id.
		ID() string
		// Value returns the observation's value.
		Value() interface{}
		// Label returns the observation's label.
		Label() Labeler
	}

	// Observer provides streams of observations.
	Observer interface {
		// ObsChan returns a channel of observations; the sequence ends
		// when the channel is closed.
		ObsChan() (<-chan Obs, error)
	}

	// Sampler generates random data using the model.
	Sampler interface {
		// Sample returns a sample drawn from the underlying distribution.
		Sample(*rand.Rand) Obs
		// SampleChan returns a channel carrying a sample of size "size"
		// drawn from the underlying distribution; the sequence ends when
		// the channel is closed.
		SampleChan(r *rand.Rand, size int) <-chan Obs
	}
)
// FloatObs implements the Obs interface for a single observation whose value
// is a []float64.
type FloatObs struct {
	value []float64
	label SimpleLabel
	id    string
}

// NewFloatObs creates a FloatObs holding val and lab. The id is left empty.
func NewFloatObs(val []float64, lab SimpleLabel) Obs {
	var obs FloatObs
	obs.value = val
	obs.label = lab
	return obs
}

// Value returns the observed []float64, wrapped in an interface{}.
func (fo FloatObs) Value() interface{} { return fo.value }

// Label returns the observation's label.
func (fo FloatObs) Label() Labeler { return fo.label }

// ID returns the observation id.
func (fo FloatObs) ID() string { return fo.id }
// FloatObsSequence implements the Obs interface for a sequence of frames,
// stored as a slice of float64 slices.
type FloatObsSequence struct {
	value     [][]float64
	label     SimpleLabel
	id        string
	alignment []*ANode
}

// NewFloatObsSequence creates a FloatObsSequence with the given frames, label
// and id, and no alignment.
func NewFloatObsSequence(val [][]float64, lab SimpleLabel, id string) Obs {
	seq := FloatObsSequence{value: val, label: lab, id: id}
	return seq
}

// Value returns the observed [][]float64, wrapped in an interface{}.
func (fo FloatObsSequence) Value() interface{} { return fo.value }

// ValueAsSlice returns the frames as a slice of interfaces, one element per
// frame. The result is always non-nil, even for an empty sequence.
func (fo FloatObsSequence) ValueAsSlice() []interface{} {
	out := make([]interface{}, 0, len(fo.value))
	for _, frame := range fo.value {
		out = append(out, frame)
	}
	return out
}

// Label returns the sequence label.
func (fo FloatObsSequence) Label() Labeler { return fo.label }

// ID returns the sequence id.
func (fo FloatObsSequence) ID() string { return fo.id }
// Add appends the frame held by obs to the sequence. The frame slice is
// stored as-is, not copied.
//
// A non-empty lab is appended to the sequence label. When the sequence
// already has a label, lab is separated from it by a comma. An empty lab
// leaves the label unchanged.
func (fo *FloatObsSequence) Add(obs FloatObs, lab string) {
	fo.value = append(fo.value, obs.value)
	if lab == "" {
		return
	}
	if fo.label == "" {
		fo.label = SimpleLabel(lab) // first label: no separator
		return
	}
	fo.label = SimpleLabel(string(fo.label) + "," + lab)
}
// JoinFloatObsSequence concatenates the frames of the given FloatObsSequence
// objects, in order, into a new sequence identified by id.
//
// The input labels are joined with commas, so an input with an empty label
// still contributes its separator (e.g. "a", "", "b" gives "a,,b"). The
// alignments of the inputs are not carried over. Nil inputs are skipped and
// contribute neither frames nor a label. The frame slices are shared with the
// inputs, not copied.
func JoinFloatObsSequence(id string, inputs ...*FloatObsSequence) Obs {
	// Size the frame slice once. Leave it nil when there are no frames,
	// which matches what repeated appends to a nil slice would produce.
	total := 0
	for _, fos := range inputs {
		if fos != nil {
			total += len(fos.value)
		}
	}
	var val [][]float64
	if total > 0 {
		val = make([][]float64, 0, total)
	}

	var lab SimpleLabel
	first := true
	for _, fos := range inputs {
		if fos == nil {
			continue // nothing to contribute; avoids a nil dereference
		}
		val = append(val, fos.value...)
		if first {
			lab = fos.label
			first = false
		} else {
			lab += "," + fos.label
		}
	}
	return &FloatObsSequence{
		value: val,
		label: lab,
		id:    id,
	}
}
// SetAlignment stores the alignment a on the sequence.
func (fo *FloatObsSequence) SetAlignment(a []*ANode) { fo.alignment = a }

// Alignment returns the alignment stored on the sequence. It is nil when
// none has been set.
func (fo FloatObsSequence) Alignment() []*ANode { return fo.alignment }
// IntObs implements the Obs interface for a single integer observation.
type IntObs struct {
	value int
	label SimpleLabel
	id    string
}

// NewIntObs creates an IntObs with the given value, label and id.
func NewIntObs(val int, lab SimpleLabel, id string) Obs {
	obs := IntObs{value: val, label: lab, id: id}
	return obs
}

// Value returns the observed int, wrapped in an interface{}.
func (o IntObs) Value() interface{} { return o.value }

// Label returns the observation's label.
func (o IntObs) Label() Labeler { return o.label }

// ID returns the observation id.
func (o IntObs) ID() string { return o.id }
// SimpleLabel implements a basic Labeler backed by a string.
type SimpleLabel string

// String returns the label as a string. Multiple labels must be separated
// using a comma.
func (lab SimpleLabel) String() string {
	return string(lab)
}

// IsEqual reports whether lab and lab2 have the same string form.
// A nil lab2 interface value is reported as not equal instead of causing a
// nil-dereference panic.
func (lab SimpleLabel) IsEqual(lab2 Labeler) bool {
	if lab2 == nil {
		return false
	}
	return lab.String() == lab2.String()
}
// FloatObserver implements Observer by streaming FloatObs objects built from
// parallel slices of values and labels.
// Not safe to use with multiple goroutines.
type FloatObserver struct {
	Values [][]float64
	Labels []SimpleLabel
	length int
}

// NewFloatObserver creates a new FloatObserver over v and lab. It returns an
// error if v and lab do not have the same length.
func NewFloatObserver(v [][]float64, lab []SimpleLabel) (*FloatObserver, error) {
	if len(v) != len(lab) {
		return nil, fmt.Errorf("length of v [%d] and length of lab [%d] don't match", len(v), len(lab))
	}
	return &FloatObserver{
		Values: v,
		Labels: lab,
		length: len(v),
	}, nil
}

// ObsChan implements the Observer interface. It returns a buffered channel
// that receives one FloatObs per element of Values, paired with the label at
// the same index. The channel is closed after the last observation.
//
// The count is taken from the exported slices themselves, not from the length
// recorded at construction. This means a FloatObserver built as a struct
// literal, or whose slices were reassigned, is streamed correctly. If Values
// and Labels differ in length, an error is returned up front instead of
// letting the producer goroutine panic with an index out of range.
//
// NOTE(review): the producer goroutine blocks forever if the consumer stops
// reading before the channel is drained and the 1000-slot buffer is full.
func (fo FloatObserver) ObsChan() (<-chan Obs, error) {
	n := len(fo.Values)
	if len(fo.Labels) != n {
		return nil, fmt.Errorf("length of Values [%d] and length of Labels [%d] don't match", n, len(fo.Labels))
	}
	obsChan := make(chan Obs, 1000)
	go func() {
		defer close(obsChan)
		for i := 0; i < n; i++ {
			obsChan <- NewFloatObs(fo.Values[i], fo.Labels[i])
		}
	}()
	return obsChan, nil
}
// ObsToF64 unpacks o into its value, label string and id. It panics if
// o.Value() does not hold a []float64.
func ObsToF64(o Obs) ([]float64, string, string) {
	vec := o.Value().([]float64)
	label := o.Label().String()
	id := o.ID()
	return vec, label, id
}

// F64ToObs wraps v and the label string into an Obs. The result is a FloatObs
// with an empty id.
func F64ToObs(v []float64, label string) Obs {
	lab := SimpleLabel(label)
	return NewFloatObs(v, lab)
}