/
alias.go
172 lines (138 loc) · 3.64 KB
/
alias.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
// Copyright (c) 2012-2015, Jack Christopher Kastorff
// All rights reserved.
// BSD Licensed, see LICENSE for details.
// Package alias picks items from a discrete distribution
// efficiently using the alias method.
package alias
import (
	"encoding/binary"
	"errors"
	"math"
	"math/rand"
)
type (
	// Alias samples indices from a fixed discrete distribution using
	// Vose's alias method. Build one with New or UnmarshalBinary; the zero
	// value has an empty table and cannot be sampled.
	Alias struct {
		table []ipiece // one column per outcome
	}

	// fpiece is a column while the table is being built: its weight scaled
	// so that all weights average 1, and the outcome index it belongs to.
	fpiece struct {
		prob  float64
		alias uint32
	}

	// ipiece is a finished column of the alias table.
	ipiece struct {
		prob  uint32 // keep threshold in [0, 2^31); 2^31-1 means always keep
		alias uint32 // outcome returned when the coin exceeds prob
	}
)
// New builds an Alias that returns index i with probability proportional to
// prob[i]; the weights need not sum to 1. For example,
//
//	v, err := alias.New([]float64{8, 10, 2})
//
// creates an alias that returns 0 40% of the time, 1 50% of the time, and
// 2 10% of the time.
//
// New returns an error if prob is empty, has more elements than fit in a
// uint32, or contains a weight that is not a positive finite number.
func New(prob []float64) (*Alias, error) {
	// This implementation is based on
	// http://www.keithschwarz.com/darts-dice-coins/
	n := len(prob)
	if n < 1 {
		return nil, errors.New("too few probabilities")
	}
	if int(uint32(n)) != n {
		return nil, errors.New("too many probabilities")
	}

	total := float64(0)
	for _, v := range prob {
		if v <= 0 {
			return nil, errors.New("a probability is non-positive")
		}
		// NaN fails every comparison and so slips past the check above.
		// NaN or +Inf would turn every scaled weight into NaN or 0, which
		// the loops below would silently convert into a uniform table.
		if math.IsNaN(v) || math.IsInf(v, 1) {
			return nil, errors.New("a probability is not finite")
		}
		total += v
	}

	// Individually finite weights can still overflow when summed. In that
	// case shrink every weight by scale so the sum is representable: each
	// shrunk weight is at most MaxFloat64/(2n), so the new sum stays below
	// roughly MaxFloat64/2. For ordinary inputs scale is exactly 1, and
	// multiplying by it changes nothing.
	scale := float64(1)
	if math.IsInf(total, 1) {
		scale = 0.5 / float64(n)
		total = 0
		for _, v := range prob {
			total += v * scale
		}
	}

	var al Alias
	al.table = make([]ipiece, n)

	// Michael Vose's algorithm.
	// Both work lists share one array: the "small" stack (scaled weight < 1)
	// grows up from index 0 and the "large" stack (>= 1) grows down from n.
	twins := make([]fpiece, n)
	smTop := -1
	lgBot := n
	// invariant: smTop < lgBot, that is, the twin stacks don't collide
	mult := float64(n) / total // rescales the weights so they average 1
	for i, p := range prob {
		p = p * scale * mult
		if p >= 1 {
			lgBot--
			twins[lgBot] = fpiece{p, uint32(i)}
		} else {
			smTop++
			twins[smTop] = fpiece{p, uint32(i)}
		}
	}
	for smTop >= 0 && lgBot < n {
		// Pair a small column with a large one: the small column keeps its
		// own mass and fills the remainder (1 - l.prob) from the large one.
		l := twins[smTop]
		smTop--
		g := twins[lgBot]
		lgBot++
		al.table[l.alias].prob = uint32(l.prob * (1<<31 - 1))
		al.table[l.alias].alias = g.alias
		g.prob = (g.prob + l.prob) - 1
		// Return what is left of the large column to the matching stack.
		if g.prob < 1 {
			smTop++
			twins[smTop] = g
		} else {
			lgBot--
			twins[lgBot] = g
		}
	}
	// Leftover large columns are full: they always return themselves.
	for i := n - 1; i >= lgBot; i-- {
		al.table[twins[i].alias].prob = 1<<31 - 1
	}
	// In exact arithmetic the small stack is empty here, but floating-point
	// error can leave a column just under 1; treat it as full as well.
	for i := 0; i <= smTop; i++ {
		al.table[twins[i].alias].prob = 1<<31 - 1
	}
	return &al, nil
}
// Gen draws an index according to the distribution al was built from,
// using rng as the source of randomness. Each call consumes a single
// Int63 value from rng.
//
// Gen panics if al's table is empty (for example, the zero Alias).
func (al *Alias) Gen(rng *rand.Rand) uint32 {
	r := uint64(rng.Int63()) // 63 uniformly random bits

	// The low 31 bits are the biased coin, compared against the column's
	// keep threshold in [0, 2^31).
	coin := uint32(r & (1<<31 - 1))

	// The high 32 bits pick the column by multiply-shift, which maps
	// [0, 2^32) onto [0, n) without any modulo. Deriving the column and the
	// coin from the same bits would correlate them and coarsen the coin to a
	// resolution of roughly n/2^31.
	w := uint32(((r >> 31) * uint64(len(al.table))) >> 32)

	if coin > al.table[w].prob {
		return al.table[w].alias
	}
	return w
}
// MarshalBinary implements encoding.BinaryMarshaler.
// Each column is written as 8 bytes: the keep threshold followed by the
// alias target, both as little-endian uint32 values. It never returns an
// error.
func (al *Alias) MarshalBinary() ([]byte, error) {
	buf := make([]byte, 8*len(al.table))
	off := 0
	for _, col := range al.table {
		binary.LittleEndian.PutUint32(buf[off:off+4], col.prob)
		binary.LittleEndian.PutUint32(buf[off+4:off+8], col.alias)
		off += 8
	}
	return buf, nil
}
// UnmarshalBinary implements encoding.BinaryUnmarshaler.
// It decodes the format written by MarshalBinary: one 8-byte column per
// outcome, holding a little-endian keep threshold (which must be below
// 2^31) followed by a little-endian alias target (which must index a
// column). Empty data is rejected because it would describe an Alias that
// cannot be sampled.
//
// The whole input is validated before al is touched; on error al is left
// unchanged.
func (al *Alias) UnmarshalBinary(p []byte) error {
	if len(p) == 0 || len(p)%8 != 0 {
		return errors.New("bad data length")
	}
	n := len(p) / 8
	if int(uint32(n)) != n {
		return errors.New("data too large")
	}
	// Decode into a fresh table so a failure part-way through cannot leave
	// al holding a partially filled table.
	table := make([]ipiece, n)
	for i := range table {
		bin := p[i*8 : i*8+8]
		prob := binary.LittleEndian.Uint32(bin[0:4])
		alias := binary.LittleEndian.Uint32(bin[4:8])
		if prob >= 1<<31 {
			return errors.New("bad data: probability out of range")
		}
		if alias >= uint32(n) {
			return errors.New("bad data: alias target out of range")
		}
		table[i] = ipiece{prob: prob, alias: alias}
	}
	al.table = table
	return nil
}