/
redsvd_test.go
125 lines (97 loc) · 2.63 KB
/
redsvd_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
/*
* Go integration of RedSVD.h - tests
*
* Copyright (c) 2016 Rocket Internet
* For license, see included LICENSE file.
*/
package redsvd
import (
"github.com/stretchr/testify/assert"
"math/rand"
"testing"
)
// TestSparseMatrix64 builds a random rows x cols sparse float64 matrix
// (each entry is set with probability `density`, values drawn from
// rand.ExpFloat64) and runs RedSVD at rank `rank` on it. It checks the
// shapes of U, V and the singular values, first with the default settings
// and then again after SetUnnormalized(true) using the *NotNormalized
// accessors.
func TestSparseMatrix64(t *testing.T) {
	const (
		rows    = 100
		cols    = 100
		rank    = 10
		density = 0.2
	)

	svd := NewGoRedSVD()
	if !assert.NotNil(t, svd) {
		return
	}
	// Deferred so DeleteGoRedSVD still runs if a later statement panics.
	defer DeleteGoRedSVD(svd)

	// Fixed seed so a failing run can be reproduced exactly.
	rng := rand.New(rand.NewSource(1))
	sparseMatrix64 := make(map[int]map[int]float64, rows)
	for row := 0; row < rows; row++ {
		sparseMatrix64[row] = make(map[int]float64)
		for col := 0; col < cols; col++ {
			if rng.Float64() < density {
				sparseMatrix64[row][col] = rng.ExpFloat64()
			}
		}
	}
	assert.Len(t, sparseMatrix64, rows)
	svd.SetMatrix64(rows, cols, sparseMatrix64)

	svd.RedSVD(rank)
	u := svd.MatrixU()
	assert.NotNil(t, u)
	// Only index the first row when the outer length is right; otherwise
	// u[0] would panic and mask the actual assertion failure.
	if assert.Len(t, u, rows) {
		assert.Len(t, u[0], rank)
	}
	v := svd.MatrixV()
	assert.NotNil(t, v)
	if assert.Len(t, v, cols) {
		assert.Len(t, v[0], rank)
	}
	singularValues := svd.SingularValues()
	assert.NotNil(t, singularValues)
	assert.Len(t, singularValues, rank)

	// Repeat the decomposition with the unnormalized option enabled.
	svd.SetUnnormalized(true)
	svd.RedSVD(rank)
	u = svd.MatrixUNotNormalized()
	assert.NotNil(t, u)
	if assert.Len(t, u, rows) {
		assert.Len(t, u[0], rank)
	}
	v = svd.MatrixVNotNormalized()
	assert.NotNil(t, v)
	if assert.Len(t, v, cols) {
		assert.Len(t, v[0], rank)
	}
	singularValues = svd.SingularValues()
	assert.NotNil(t, singularValues)
	assert.Len(t, singularValues, rank)
}
// TestSparseMatrix32 builds a random rows x cols sparse float32 matrix
// (each entry is set with probability `density`, values drawn from
// rand.ExpFloat64 and converted to float32) and runs RedSVD at rank `rank`
// on it. It checks the shapes of U, V and the singular values, first with
// the default settings and then again after SetUnnormalized(true) using the
// *NotNormalized accessors. It also checks the length of GetColumnNorms.
func TestSparseMatrix32(t *testing.T) {
	const (
		rows    = 100
		cols    = 100
		rank    = 10
		density = 0.2
	)

	svd := NewGoRedSVD()
	if !assert.NotNil(t, svd) {
		return
	}
	// Deferred so DeleteGoRedSVD still runs if a later statement panics.
	defer DeleteGoRedSVD(svd)

	// Fixed seed so a failing run can be reproduced exactly.
	rng := rand.New(rand.NewSource(1))
	sparseMatrix32 := make(map[int]map[int]float32, rows)
	for row := 0; row < rows; row++ {
		sparseMatrix32[row] = make(map[int]float32)
		for col := 0; col < cols; col++ {
			if rng.Float64() < density {
				sparseMatrix32[row][col] = float32(rng.ExpFloat64())
			}
		}
	}
	assert.Len(t, sparseMatrix32, rows)
	svd.SetMatrix32(rows, cols, sparseMatrix32)

	svd.RedSVD(rank)
	u := svd.MatrixU()
	assert.NotNil(t, u)
	// Only index the first row when the outer length is right; otherwise
	// u[0] would panic and mask the actual assertion failure.
	if assert.Len(t, u, rows) {
		assert.Len(t, u[0], rank)
	}
	v := svd.MatrixV()
	assert.NotNil(t, v)
	if assert.Len(t, v, cols) {
		assert.Len(t, v[0], rank)
	}
	singularValues := svd.SingularValues()
	assert.NotNil(t, singularValues)
	assert.Len(t, singularValues, rank)

	// Repeat the decomposition with the unnormalized option enabled.
	svd.SetUnnormalized(true)
	svd.RedSVD(rank)
	u = svd.MatrixUNotNormalized()
	assert.NotNil(t, u)
	if assert.Len(t, u, rows) {
		assert.Len(t, u[0], rank)
	}
	v = svd.MatrixVNotNormalized()
	assert.NotNil(t, v)
	if assert.Len(t, v, cols) {
		assert.Len(t, v[0], rank)
	}
	singularValues = svd.SingularValues()
	assert.NotNil(t, singularValues)
	assert.Len(t, singularValues, rank)

	norms := svd.GetColumnNorms()
	assert.NotNil(t, norms)
	assert.Len(t, norms, rank)
}