package train

import (
	"sync"

	"github.com/gonum/floats"
	"github.com/gonum/matrix/mat64"

	"github.com/reggo/common"
	"github.com/reggo/loss"
	"github.com/reggo/regularize"
)
// BatchGradBased is a wrapper for training a Trainable on a fixed
// set of samples.
// TODO: Maybe want to move Featurize out of batch grad (or let pass in features as an input)
type BatchGradBased struct {
	t             Trainable
	losser        loss.DerivLosser
	regularizer   regularize.Regularizer
	inputDim      int
	outputDim     int
	nTrain        int
	nParameters   int
	grainSize     int
	inputs        common.RowMatrix
	outputs       common.RowMatrix
	features      *mat64.Dense
	lossDerivFunc func(start, end int, c chan lossDerivStruct, parameters []float64)
}
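
// Note: lossDerivStruct is used throughout this file but is declared
// elsewhere in the package. A minimal declaration consistent with its
// use here would be (an assumption, shown for reference only):
//
//	type lossDerivStruct struct {
//		loss  float64
//		deriv []float64
//	}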

// NewBatchGradBased creates a new BatchGradBased for the given trainable
// and samples. If cacheFeatures is true, the features of all the inputs
// are computed once up front and reused on every objective evaluation.
func NewBatchGradBased(trainable Trainable, cacheFeatures bool, inputs, outputs common.RowMatrix, losser loss.DerivLosser, regularizer regularize.Regularizer) *BatchGradBased {
	var features *mat64.Dense
	if cacheFeatures {
		features = FeaturizeTrainable(trainable, inputs, nil)
	}
	// TODO: Add in error checking
	nTrain, outputDim := outputs.Dims()
	_, inputDim := inputs.Dims()
	g := &BatchGradBased{
		t:           trainable,
		inputs:      inputs,
		outputs:     outputs,
		features:    features,
		losser:      losser,
		regularizer: regularizer,
		nTrain:      nTrain,
		outputDim:   outputDim,
		inputDim:    inputDim,
		nParameters: trainable.NumParameters(),
		grainSize:   trainable.GrainSize(),
	}
	// TODO: Add in row viewer stuff
	// TODO: Create a different function for computing just the loss
	//inputRowViewer, ok := inputs.(mat64.RowViewer)
	//outputRowViewer, ok := outputs.(mat64.RowViewer)

	// TODO: Move this to its own function
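	// The chunk worker computes the total loss and its derivative with
	// respect to the parameters for samples [start, end), then sends a
	// single result on c. Two variants are built below: one reads
	// precomputed features from the cache, the other featurizes each
	// input on the fly.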
	var f func(start, end int, c chan lossDerivStruct, parameters []float64)
	switch {
	default:
		panic("Shouldn't be here")
	case cacheFeatures:
		f = func(start, end int, c chan lossDerivStruct, parameters []float64) {
			lossDeriver := g.t.NewLossDeriver()
			prediction := make([]float64, g.outputDim)
			dLossDPred := make([]float64, g.outputDim)
			dLossDWeight := make([]float64, g.nParameters)
			totalDLossDWeight := make([]float64, g.nParameters)
			var loss float64
			output := make([]float64, g.outputDim)
			for i := start; i < end; i++ {
				// Compute the prediction
				lossDeriver.Predict(parameters, g.features.RowView(i), prediction)
				// Compute the loss
				g.outputs.Row(output, i)
				loss += g.losser.LossDeriv(prediction, output, dLossDPred)
				// Compute the derivative
				lossDeriver.Deriv(parameters, g.features.RowView(i), prediction, dLossDPred, dLossDWeight)
				floats.Add(totalDLossDWeight, dLossDWeight)
			}
			// Send the value back on the channel
			c <- lossDerivStruct{
				loss:  loss,
				deriv: totalDLossDWeight,
			}
		}
	case !cacheFeatures:
		f = func(start, end int, c chan lossDerivStruct, parameters []float64) {
			lossDeriver := g.t.NewLossDeriver()
			prediction := make([]float64, g.outputDim)
			dLossDPred := make([]float64, g.outputDim)
			dLossDWeight := make([]float64, g.nParameters)
			totalDLossDWeight := make([]float64, g.nParameters)
			var loss float64
			output := make([]float64, g.outputDim)
			input := make([]float64, g.inputDim)
			features := make([]float64, g.t.NumFeatures())
			featurizer := g.t.NewFeaturizer()
			for i := start; i < end; i++ {
				// Compute the features for this sample
				g.inputs.Row(input, i)
				featurizer.Featurize(input, features)
				// Compute the prediction
				lossDeriver.Predict(parameters, features, prediction)
				// Compute the loss
				g.outputs.Row(output, i)
				loss += g.losser.LossDeriv(prediction, output, dLossDPred)
				// Compute the derivative
				lossDeriver.Deriv(parameters, features, prediction, dLossDPred, dLossDWeight)
				// Add to the total derivative
				floats.Add(totalDLossDWeight, dLossDWeight)
			}
			// Send the value back on the channel
			c <- lossDerivStruct{
				loss:  loss,
				deriv: totalDLossDWeight,
			}
		}
	}
	g.lossDerivFunc = f
	return g
}

// Dimension returns the dimension of the optimization problem, i.e. the
// number of parameters being trained.
func (g *BatchGradBased) Dimension() int {
	return g.t.NumParameters()
}

// ObjGrad computes the objective value at parameters and stores the
// gradient in place in derivative. Both are normalized by the number of
// training samples.
func (g *BatchGradBased) ObjGrad(parameters []float64, derivative []float64) (loss float64) {
	c := make(chan lossDerivStruct, 10)

	// Set the function for the parallel for loop
	f := func(start, end int) {
		g.lossDerivFunc(start, end, c, parameters)
	}

	go func() {
		wg := &sync.WaitGroup{}
		// Compute the losses and the derivatives all in parallel
		wg.Add(2)
		go func() {
			common.ParallelFor(g.nTrain, g.grainSize, f)
			wg.Done()
		}()
		// Compute the regularization
		go func() {
			deriv := make([]float64, g.nParameters)
			loss := g.regularizer.LossDeriv(parameters, deriv)
			c <- lossDerivStruct{
				loss:  loss,
				deriv: deriv,
			}
			wg.Done()
		}()
		// Wait for all of the results to be sent on the channel
		wg.Wait()
		// Close the channel so the receiving range loop terminates
		close(c)
	}()

	// Zero the derivative
	for i := range derivative {
		derivative[i] = 0
	}

	// Range over the channel, incrementing the loss and derivative
	// as they come in
	for l := range c {
		loss += l.loss
		floats.Add(derivative, l.deriv)
	}
	// Normalize by the number of training samples
	loss /= float64(g.nTrain)
	floats.Scale(1/float64(g.nTrain), derivative)
	return loss
}
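
// gradientDescend is an illustrative helper, not part of the original
// API: a minimal sketch of how ObjGrad plugs into a fixed-step gradient
// descent loop. The step size and iteration count are assumptions; a
// real caller would hand ObjGrad to a proper optimizer instead.
func gradientDescend(g *BatchGradBased, parameters []float64, stepSize float64, iterations int) (loss float64) {
	deriv := make([]float64, g.Dimension())
	for k := 0; k < iterations; k++ {
		// Evaluate the objective and its gradient at the current parameters.
		loss = g.ObjGrad(parameters, deriv)
		// Step downhill: parameters -= stepSize * deriv.
		floats.AddScaled(parameters, -stepSize, deriv)
	}
	return loss
}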