forked from danieldk/golinear
/
validation.go
44 lines (36 loc) · 1.09 KB
/
validation.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
package golinear
/*
#include <stdlib.h>
#include <linear.h>
#include "wrap.h"
*/
import "C"
import (
"errors"
"unsafe"
)
// CrossValidation performs n-fold cross validation on problem. The
// instances in the problem are separated into nFolds folds, and each fold
// is evaluated in turn using a model trained on the remaining folds.
//
// The returned slice holds one predicted value per instance of the
// problem. An error is returned when problem is nil or contains no
// instances, when nFolds is smaller than 2, or when liblinear rejects the
// combination of problem and param.
func CrossValidation(problem *Problem, param Parameters, nFolds uint) ([]float64, error) {
	// Largest fold count that can be handed to C without wrapping around;
	// C int is 32 bits wide on the platforms cgo targets.
	const maxCInt = 1<<31 - 1

	if problem == nil {
		return nil, errors.New("cross validation requires a non-nil problem")
	}

	// Fewer than two folds leaves nothing to train on, and zero folds makes
	// liblinear divide by zero while partitioning the instances.
	if nFolds < 2 {
		return nil, errors.New("cross validation requires at least 2 folds")
	}

	cParam := toCParameter(param)
	defer func() {
		C.destroy_param_wrap(cParam)
		C.free(unsafe.Pointer(cParam))
	}()

	// Let liblinear validate the parameters against the problem.
	if r := C.check_parameter_wrap(problem.problem, cParam); r != nil {
		return nil, errors.New(C.GoString(r))
	}

	nInstances := uint(problem.problem.l)
	if nInstances == 0 {
		// liblinear caps the fold count at the number of instances, so an
		// empty problem would end up with zero folds.
		return nil, errors.New("cross validation requires a non-empty problem")
	}

	// Avoid wrapping to a negative C int; liblinear still caps an oversized
	// fold count at the number of instances on its side.
	if nFolds > maxCInt {
		nFolds = maxCInt
	}

	// The predictions are written into a C-allocated buffer that is copied
	// into a Go slice below and released on return.
	target := newDouble(C.size_t(nInstances))
	defer C.free(unsafe.Pointer(target))

	C.cross_validation_wrap(problem.problem, cParam, C.int(nFolds), target)

	classifications := make([]float64, nInstances)
	for idx := range classifications {
		classifications[idx] = float64(C.get_double_idx(target, C.int(idx)))
	}

	return classifications, nil
}