/
tensor.go
147 lines (122 loc) · 3.09 KB
/
tensor.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
package tensor2go
import (
"fmt"
"github.com/joernweissenborn/propertygraph2go"
)
// Graph identifiers shared by the Tensor methods.
const (
	// ROOT is the vertex id given to the tensor's root vertex by
	// createTensorRoot; edge lookups query the graph with this id.
	ROOT = "ROOT"
	// DIM_EDGE is declared as a dimension-edge label, but the visible code
	// labels dimension edges with the bare dimension number instead.
	// NOTE(review): confirm whether anything outside this view still uses it.
	DIM_EDGE = "D"
)
// Tensor stores an n-dimensional array as a property graph.
//
// The graph has a single root vertex (id ROOT) whose properties record the
// rank, plus one vertex per element. Each element is linked from the root by
// one edge per dimension, labelled with the decimal dimension number. The
// edge's property is what getItemVertex compares against the requested
// coordinate for that dimension.
type Tensor struct {
	// graph is the backing property graph (an in-memory graph when built by NewTensor).
	graph propertygraph2go.PropertyGraph
	// tensorRoot is the root vertex; it stays nil until the first addDim call creates it.
	tensorRoot propertygraph2go.Vertex
	// indexPool hands out ids for newly created element vertices and edges.
	indexPool *propertygraph2go.IndexPool
}
// NewTensor creates a Tensor backed by a new in-memory property graph and
// index pool, then adds one dimension per argument, in argument order.
//
// Called with no arguments it returns a tensor whose root vertex has not
// been created yet; the root is created by the first addDim call.
func NewTensor(dims ...int) (t *Tensor) {
	t = &Tensor{
		graph:     propertygraph2go.NewInMemoryGraph(),
		indexPool: propertygraph2go.NewIndexPool(),
	}
	// Library constructors must not write debug output to stdout.
	for _, dim := range dims {
		t.addDim(dim)
	}
	return t
}
// createTensorRoot creates the root vertex with the fixed id ROOT and the
// given properties, and stores it on t.
//
// The other methods query the graph by the ROOT id, so this vertex is the one
// they operate on.
func (t *Tensor) createTensorRoot(p interface{}) {
	t.tensorRoot = t.graph.CreateVertex(ROOT, p)
}
// addDim grows the tensor by one dimension of the given size. It creates the
// root vertex (with tensorProperties{0}) on first use.
//
// For the first dimension it creates size items holding Zero{} at coordinates
// 0..size-1.
//
// For every later dimension:
//   - each existing item gets an edge for the new dimension with coordinate 0;
//   - for each coordinate i in 1..size-1, a Zero{} item is created for every
//     existing index, extended with i.
//
// It panics with "Adding dim on scalar" if the root's properties are not
// *tensorProperties.
func (t *Tensor) addDim(size int) {
	if t.tensorRoot == nil {
		t.createTensorRoot(&tensorProperties{0})
	}
	tp, ok := t.tensorRoot.Properties().(*tensorProperties)
	if !ok {
		panic("Adding dim on scalar")
	}
	newDimIndex := tp.dims
	dimLabel := fmt.Sprintf("%d", newDimIndex)

	if newDimIndex == 0 {
		// Raise the rank first so addItem accepts one-element indices.
		tp.raiseDims()
		for i := 0; i < size; i++ {
			t.addItem(Zero{}, []int{i})
		}
		return
	}

	// Snapshot the existing coordinates and their vertices while the rank is
	// still the old one. After raiseDims the new dimension has no edges yet,
	// so enumerating then would always yield an empty product.
	indices := t.Indices()
	items := make([]propertygraph2go.Vertex, len(indices))
	for n, index := range indices {
		item := t.getItemVertex(index)
		if item == nil {
			panic(fmt.Sprintf("tensor item at index %v not found", index))
		}
		items[n] = item
	}
	tp.raiseDims()

	// Existing items sit at coordinate 0 along the new dimension.
	for _, item := range items {
		t.graph.CreateEdge(t.indexPool.GetEdgeIndex(), dimLabel, t.tensorRoot, item, 0)
	}
	for i := 1; i < size; i++ {
		for _, index := range indices {
			// Build a fresh slice: append(index, i) could write into a
			// backing array shared with another index tuple.
			extended := make([]int, len(index)+1)
			copy(extended, index)
			extended[len(index)] = i
			t.addItem(Zero{}, extended)
		}
	}
}
// Indices returns every coordinate tuple of the tensor: the Cartesian product
// [0, DimSize(0)) x [0, DimSize(1)) x ... x [0, DimSize(Dims()-1)).
// Tuples are ordered so that the first dimension varies fastest.
//
// A tensor with no dimensions, or with any dimension of size zero, yields an
// empty, non-nil slice. Every tuple has its own backing array, so callers may
// modify or append to one tuple without affecting the others.
func (t *Tensor) Indices() (indices [][]int) {
	rank := t.Dims()
	if rank == 0 {
		return [][]int{}
	}
	// Seed with the single empty tuple (the product over zero dimensions).
	// Seeding with an empty list would make every product empty.
	indices = [][]int{{}}
	for dim := 0; dim < rank; dim++ {
		size := t.DimSize(dim)
		next := make([][]int, 0, len(indices)*size)
		for i := 0; i < size; i++ {
			for _, prefix := range indices {
				tuple := make([]int, len(prefix)+1)
				copy(tuple, prefix)
				tuple[len(prefix)] = i
				next = append(next, tuple)
			}
		}
		indices = next
	}
	return indices
}
// DimSize returns the extent of dimension dim: the number of distinct integer
// coordinates carried by the root's outgoing edges labelled with dim.
//
// Every item has one such edge, so counting raw edges would return the number
// of items rather than the dimension's size. Edges whose property is not an
// int are ignored. An unknown dimension yields 0.
func (t *Tensor) DimSize(dim int) (size int) {
	seen := make(map[int]struct{})
	for _, e := range t.graph.GetOutgoingEdgesByLabel(ROOT, fmt.Sprintf("%d", dim)) {
		if coord, ok := e.Properties().(int); ok {
			seen[coord] = struct{}{}
		}
	}
	return len(seen)
}
// Dims returns the tensor's rank: the dims counter stored in the root
// vertex's *tensorProperties, which addDim advances via raiseDims.
//
// It returns 0 when the root vertex has not been created yet (for example,
// NewTensor called with no sizes) or when the root's properties are not
// *tensorProperties.
func (t *Tensor) Dims() (dims int) {
	// Guard against a nil root: calling Properties on it would panic.
	if t.tensorRoot == nil {
		return 0
	}
	if tp, ok := t.tensorRoot.Properties().(*tensorProperties); ok {
		return tp.dims
	}
	return 0
}
// getItemVertex returns the item vertex located at the given coordinates, or
// nil if no item matches.
//
// indices must hold exactly one coordinate per dimension; otherwise it
// panics. Candidates are the heads of root edges labelled "0" whose property
// equals indices[0]. A candidate matches if, for every later dimension d, one
// of its incoming edges labelled d carries indices[d]. Edges whose property is
// not an int never match.
func (t *Tensor) getItemVertex(indices []int) (v propertygraph2go.Vertex) {
	if len(indices) == 0 || len(indices) != t.Dims() {
		panic("Tried to get item vertex without suppling enough indices")
	}
candidates:
	for _, first := range t.graph.GetOutgoingEdgesByLabel(ROOT, "0") {
		if coord, ok := first.Properties().(int); !ok || coord != indices[0] {
			continue
		}
		item := first.Head()
		for dim := 1; dim < len(indices); dim++ {
			matched := false
			for _, e := range t.graph.GetIncomingEdgesByLabel(item.Id(), fmt.Sprintf("%d", dim)) {
				if coord, ok := e.Properties().(int); ok && coord == indices[dim] {
					matched = true
					break
				}
			}
			if !matched {
				continue candidates
			}
		}
		return item
	}
	return nil
}
// addItem creates a new item vertex holding val and links it to the root
// with one edge per entry of index. The edge for dimension d is labelled with
// the decimal string of d, and its property is index[d]: the coordinate that
// getItemVertex matches against.
//
// It panics if index has more entries than the tensor has dimensions. The
// check runs before anything is created, so a rejected call leaves the graph
// untouched.
func (t *Tensor) addItem(val interface{}, index []int) {
	// Valid dimensions are 0..rank-1, so the first invalid one is rank itself.
	if rank := t.Dims(); len(index) > rank {
		panic(fmt.Sprintf("Tried to add item on non existend dim. Tensor rank is %d, requested dim is %d", rank, rank))
	}
	item := t.graph.CreateVertex(t.indexPool.GetVertexIndex(), val)
	for dim, coord := range index {
		t.graph.CreateEdge(
			t.indexPool.GetEdgeIndex(),
			fmt.Sprintf("%d", dim),
			t.tensorRoot,
			item,
			coord,
		)
	}
}
//func (Tensor) Contract(Tensor) Tensor {
//}