/
mod_math.go
343 lines (321 loc) · 9.16 KB
/
mod_math.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
// Routines for arithmetic mod p = 2^64-2^32+1
// -------------------------------------------
//
// The special form of p means that we can make a mod p routine which only
// involves a few shifts and arithmetic operations. We use these facts
//
// 2^64 == 2^32 -1 mod p
// 2^96 == -1 mod p
// 2^128 == -2^32 mod p
// 2^192 == 1 mod p
// 2^n * 2^(192-n) = 1 mod p
//
// Also we use the fact that 2^64 - p is 2^32-1 ie 00000000FFFFFFFF so instead
// of adding FFFFFFFF00000001 mod 2^64 we subtract 00000000FFFFFFFF. This is
// convenient because the ARM carry flag is inverted for ADD and SUBtract so
// it is best (conditional execution wise) to follow an ADD with an ADDCS and
// a SUB with a SUBCC. Read on and you will understand!
//
// Note that some of the comments use positional notation where (a,b,c) is
// 2^64 * a + 2^32 * b + c
// Note also that p = (2^32 -1, 1) in this notation
//
// -----
//
// Factors of p-1 are 2^32 * 3 * 5 * 17 * 257 * 65537
//
// Note that a 64-th root of unity is 8 mod p.
//
// This means that practically an FFT can be defined of length up to 2^32 with
// optional factors of 3 and 5.
//
// For the Discrete Weighted Transform we need an n-th root of 2. 2 has order
// 192 mod p (ie 2^192 mod p = 1) so we can have the
//
// (p-1)/192 = (2^58 - 2^26) / 3 = 2^26 * 5 * 17 * 257 * 65537 th root of 2.
//
// This means that we can do the DWT for lengths up to 2^26 with an optional
// factor of 5
//
// 7 is a primitive root mod p
//
// An n-th root of unity can be generated by 7^(5*(p-1)/n) mod p.
//
// An n-th root of two can be generated by 7^(5*(p-1)/192/n) mod p
//
// So a suitable 5 * 2^26-th root of 1 is 0xED41D05B78D6E286 and the 5 * 2^26-th
// root of 2 is 0xC47FC73D33F80E14
//
// -----
//
// Many thanks to Peter-Lawrence Montgomery for working out the maths behind
// how to do the 128 bit to 64 bit reduction mod p and the shifts mod p so
// efficiently. Peter also suggested the idea of using shifts in the
// transform which really makes a lot of difference in execution speed on ARM
//
// -----
//
// We could probably shave a few cycles off here and there by using a
// redundant representation, where numbers are kept anywhere in the range
// 0..2^64-1 rather than being fully reduced to 0..p-1.
//
//
// Detecting carry on addition in C
// --------------------------------
//
// Using integers in range
//
// 0 <= x,y,z < N
//
// x + y produces a carry if x + y >= N
//
// However we calculate z = (x + y) % N
//
// if (z < x)
//
// Is this test exact, i.e. is z < x true if and only if there was a carry?
//
// => (x + y) % N < x
//
// 0 <= x+y <= 2*N-2
//
// if x+y < N:
//
// => (x + y) < x
// => false
//
// if x+y >= N && x+y < 2N:
//
// => (x + y - N) < x
// => y - N < 0
// => y < N
// => true
//
// Detecting carry on subtraction in C
// -----------------------------------
//
// This is much easier
//
// Using integers in range
//
// 0 <= x,y,z < N
//
// x - y produces a borrow (carry) if x - y < 0
//
// We can test for this directly and hope the compiler optimises it into
// a test of the carry flag.
package main
import (
	"fmt"
	"math/bits"
	"math/rand"
)
const (
	// ROOT_ORDER is 5 * 2^26, the order of the roots ROOT_ONE and ROOT_TWO.
	ROOT_ORDER uint64 = 5 * (1 << 26)

	// MOD_P is the prime modulus p = 2^64 - 2^32 + 1.
	MOD_P uint64 = 1<<64 - 1<<32 + 1

	// ROOT_ONE is the ROOT_ORDER-th root of 1 mod p.
	ROOT_ONE uint64 = 0xED41D05B78D6E286

	// ROOT_TWO is the ROOT_ORDER-th root of 2 mod p.
	ROOT_TWO uint64 = 0xC47FC73D33F80E14
)
// mod_adc adds *carry onto x using plain (non-modular) 64-bit addition,
// treating the result as a width-bit number stored in a uint64.
//
// It returns the low width bits of x + *carry and overwrites *carry with
// everything above those bits, including any bit lost to 64-bit wrap-around.
//
// This is used to do arithmetic with numbers that are width bits wide
// but stored in a uint64.
func mod_adc(x uint64, width uint8, carry *uint64) uint64 {
	total := x + *carry

	// The addition wrapped past 2^64 exactly when total < x (see the proof
	// at the top of the file). The lost bit has weight 2^64, which lands on
	// bit (64 - width) once everything is shifted down by width.
	var overflow uint64
	if total < x {
		overflow = uint64(1) << (64 - width)
	}
	*carry = total>>width + overflow

	mask := uint64(1)<<width - 1
	return total & mask
}
// mod_shift0to31 returns x * 2^shift mod p for shifts 0..31.
//
// Write x << shift as three 32-bit limbs (hi, mid, lo), where hi holds the
// bits pushed out of the top of the 64-bit word. Using 2^64 == 2^32 - 1:
//
//	(hi, mid, lo) mod p = (mid, lo) + (hi, -hi)
//
// Instead of adding (hi, -hi) we subtract its negation
// p - (hi, -hi) = (2^32 - 1 - hi, hi + 1), which lies in 2^32..p. The
// hi + 1 cannot overflow because hi < 2^31 for these shifts. Since
// (mid, lo) is in 0..2^64-1, the difference lies in -p..p-2, so one
// conditional add of p brings it into 0..p-1.
func mod_shift0to31(x uint64, shift uint8) uint64 {
	low := x << shift               // (mid, lo)
	top := uint32(x >> (64 - shift)) // hi; Go yields 0 when shift == 0
	neg := uint64(^top)<<32 | uint64(top+1)

	if low >= neg {
		return low - neg
	}
	// Borrow: the wrapped difference plus p is the reduced result.
	return low - neg + MOD_P
}
// mod_shift32to63 returns x * 2^shift mod p for shifts 32..63.
//
// Let (hi, mid, lo) be the three 32-bit limbs of x << (shift - 32) viewed as
// a 96-bit number; hi < 2^31. Then x << shift = (hi, mid, lo, 0), and using
// 2^96 == -1 and 2^64 == 2^32 - 1:
//
//	(hi, mid, lo, 0) mod p = (mid, lo, -hi)
//	                       = (lo + mid, -mid - hi)
//
// This is evaluated in two modular steps:
//  1. (mid, -mid) = mid * (2^32 - 1) <= (2^32 - 1)^2 < p. Subtract (0, hi) and
//     add p back on borrow, giving a value in 0..p-1.
//  2. Add (lo, 0) by subtracting its negation p - (lo, 0) = (2^32 - 1 - lo, 1),
//     again adding p back on borrow.
func mod_shift32to63(x uint64, shift uint8) uint64 {
	hi := uint64(uint32(x >> (96 - shift)))
	mid := uint64(uint32(x >> (64 - shift)))
	lo := uint64(uint32(x << (shift - 32)))

	// Step 1: (mid, -mid) - (0, hi).
	acc := mid<<32 - mid
	res := acc - hi
	if acc < hi {
		res += MOD_P
	}

	// Step 2: + (lo, 0), done as - (2^32 - 1 - lo, 1).
	negLo := (0xFFFFFFFF-lo)<<32 | 1
	out := res - negLo
	if res < negLo {
		out += MOD_P
	}
	return out
}
// mod_shift64to95 returns x * 2^shift mod p for shifts 64..95.
//
// Let (hi, mid, lo) be the three 32-bit limbs of x << (shift - 64) viewed as
// a 96-bit number; hi < 2^31. Then x << shift = (hi, mid, lo, 0, 0), and
//
//	(hi, mid, lo, 0, 0) mod p = (mid, lo, -hi, 0)
//	                          = (lo, -hi, -mid)
//	                          = (lo - hi, -mid - lo)
//	                          = (lo, -lo) - (hi, mid)
//
// Both operands are already below p: (lo, -lo) = lo * (2^32 - 1), and
// (hi, mid) < p because hi < 2^31. So a single subtract with a conditional
// add of p yields a result in 0..p-1.
func mod_shift64to95(x uint64, shift uint8) uint64 {
	hi := uint64(uint32(x >> (128 - shift)))
	mid := uint64(uint32(x >> (96 - shift)))
	lo := uint64(uint32(x << (shift - 64)))

	minuend := lo<<32 - lo     // (lo, -lo); cannot wrap
	subtrahend := hi<<32 | mid // (hi, mid)

	if minuend >= subtrahend {
		return minuend - subtrahend
	}
	return minuend - subtrahend + MOD_P
}
// Shift a value (multiply by 2^shift), 0 <= shift <= 96.
//
// mod_shift returns x * 2^shift mod p. It dispatches to a helper specialised
// for each 32-bit range of shift. A shift of 96 is a negation because
// 2^96 == -1 mod p. Any shift greater than 96 panics.
//
// FIXME specialising the code for the shifts would make them much
// quicker (not a register controlled shift). Also note that in the
// FFTs we only need shifts of multiples of 3, eg 3,6,9,12...
func mod_shift(x uint64, shift uint8) uint64 {
	switch {
	case shift == 0:
		return x
	case shift < 32:
		return mod_shift0to31(x, shift)
	case shift < 64:
		return mod_shift32to63(x, shift)
	case shift < 96:
		return mod_shift64to95(x, shift)
	case shift == 96:
		// shift of 96 is negate: 2^96 == -1 mod p.
		// NOTE(review): relies on mod_sub (defined elsewhere) computing
		// (a - b) mod p.
		return mod_sub(0, x)
	default:
		panic(fmt.Sprintf("Bad shift value %d in mod_shift", shift))
	}
	// Every case above returns or panics. The former trailing `return x`
	// was unreachable code, which go vet reports, so it has been removed.
}
// mod_top_bit returns the index (0..63) of the most significant set bit of x.
//
// It returns -1 if there are no bits set.
func mod_top_bit(x uint64) int {
	// bits.Len64 is the minimum number of bits needed to represent x, and
	// 0 for x == 0. The top set bit is therefore one below it.
	return bits.Len64(x) - 1
}
// mod_pow returns a^b mod p.
//
// It uses left-to-right binary exponentiation. For each bit of b, from the
// most significant set bit down to bit 0, it squares the accumulator and then
// multiplies by a if that bit is set. When b == 0 the result is 1.
// NOTE(review): correctness relies on mod_mul (defined elsewhere) being
// multiplication mod p.
func mod_pow(a, b uint64) uint64 {
	result := uint64(1)
	top := mod_top_bit(b)
	if top < 0 {
		return result
	}
	for mask := uint64(1) << uint(top); mask != 0; mask >>= 1 {
		result = mod_mul(result, result)
		if b&mask != 0 {
			result = mod_mul(result, a)
		}
	}
	return result
}
// Calculate 1/a mod p
//
// We do this by the easy and not very efficient method below
//
// a^(p-1) = 1 mod p for any field element except 0
// => a * a^(p-2) = 1 mod p
// therefore a^(p-2) mod p is 1/a
func mod_inv(a uint64) uint64 {
return mod_pow(a, MOD_P-2)
}
// mod_vector_mul multiplies one array by another, element-wise and in place.
//
// Entry
//
//	n = number of elements to process
//	x = array to be multiplied, overwritten with x[i] * y[i] via mod_mul
//	y = multiplier array (read only)
//
// Exit
//
//	x[0..n-1] hold the products. It panics (index out of range) if x or y
//	has fewer than n elements.
func mod_vector_mul(n uint, x []uint64, y []uint64) {
	for k := uint(0); k != n; k++ {
		x[k] = mod_mul(x[k], y[k])
	}
}
// mod_vector_sqr squares an array element-wise and in place.
//
// Entry
//
//	n = number of elements to process
//	x = array to be squared, overwritten with mod_sqr(x[i])
//
// Exit
//
//	x[0..n-1] hold the squares. It panics (index out of range) if x has
//	fewer than n elements.
func mod_vector_sqr(n uint, x []uint64) {
	for k := uint(0); k != n; k++ {
		x[k] = mod_sqr(x[k])
	}
}
// mod_rnd returns a random integer 0 <= x < p.
//
// It draws 64 random bits from the global math/rand source, using two
// rand.Uint32 calls with the high word drawn first. Draws >= p are rejected
// and retried. Since 2^64 - p = 2^32 - 1, a retry happens with probability
// of roughly 2^-32.
func mod_rnd() uint64 {
	x := MOD_P // forces at least one draw
	for x >= MOD_P {
		hi := uint64(rand.Uint32())
		lo := uint64(rand.Uint32())
		x = hi<<32 | lo
	}
	return x
}