arithmetic.gno
11.62 Kb · 487 lines
// arithmetic provides arithmetic operations for Uint objects.
// This includes basic binary operations such as addition, subtraction, multiplication, division, and modulo,
// as well as overflow checks and negation. These functions are essential for numeric
// calculations using 256-bit unsigned integers.
5package uint256
6
7import "math/bits"
8
9// Add sets z to the sum x+y and returns z.
10func (z *Uint) Add(x, y *Uint) *Uint {
11 var carry uint64
12 z[0], carry = bits.Add64(x[0], y[0], 0)
13 z[1], carry = bits.Add64(x[1], y[1], carry)
14 z[2], carry = bits.Add64(x[2], y[2], carry)
15 z[3], _ = bits.Add64(x[3], y[3], carry)
16 return z
17}
18
19// AddOverflow sets z to the sum x+y and returns z and true if overflow occurred.
20func (z *Uint) AddOverflow(x, y *Uint) (*Uint, bool) {
21 var carry uint64
22 z[0], carry = bits.Add64(x[0], y[0], 0)
23 z[1], carry = bits.Add64(x[1], y[1], carry)
24 z[2], carry = bits.Add64(x[2], y[2], carry)
25 z[3], carry = bits.Add64(x[3], y[3], carry)
26 return z, carry != 0
27}
28
29// Sub sets z to the difference x-y and returns z.
30func (z *Uint) Sub(x, y *Uint) *Uint {
31 var carry uint64
32 z[0], carry = bits.Sub64(x[0], y[0], 0)
33 z[1], carry = bits.Sub64(x[1], y[1], carry)
34 z[2], carry = bits.Sub64(x[2], y[2], carry)
35 z[3], _ = bits.Sub64(x[3], y[3], carry)
36 return z
37}
38
39// SubOverflow sets z to the difference x-y and returns z and true if underflow occurred.
40func (z *Uint) SubOverflow(x, y *Uint) (*Uint, bool) {
41 var carry uint64
42 z[0], carry = bits.Sub64(x[0], y[0], 0)
43 z[1], carry = bits.Sub64(x[1], y[1], carry)
44 z[2], carry = bits.Sub64(x[2], y[2], carry)
45 z[3], carry = bits.Sub64(x[3], y[3], carry)
46 return z, carry != 0
47}
48
49// Neg returns -x mod 2^256.
50func (z *Uint) Neg(x *Uint) *Uint {
51 return z.Sub(Zero(), x)
52}
53
54// Mul sets z to the product x*y and returns z.
55func (z *Uint) Mul(x, y *Uint) *Uint {
56 var (
57 res Uint
58 carry uint64
59 res1, res2, res3 uint64
60 )
61
62 carry, res[0] = bits.Mul64(x[0], y[0])
63 carry, res1 = umulHop(carry, x[1], y[0])
64 carry, res2 = umulHop(carry, x[2], y[0])
65 res3 = x[3]*y[0] + carry
66
67 carry, res[1] = umulHop(res1, x[0], y[1])
68 carry, res2 = umulStep(res2, x[1], y[1], carry)
69 res3 = res3 + x[2]*y[1] + carry
70
71 carry, res[2] = umulHop(res2, x[0], y[2])
72 res3 = res3 + x[1]*y[2] + carry
73
74 res[3] = res3 + x[0]*y[3]
75
76 return z.Set(&res)
77}
78
79// MulOverflow sets z to the product x*y and returns z and true if overflow occurred.
80func (z *Uint) MulOverflow(x, y *Uint) (*Uint, bool) {
81 p := umul(x, y)
82 copy(z[:], p[:4])
83 return z, (p[4] | p[5] | p[6] | p[7]) != 0
84}
85
86// Div sets z to the quotient x/y and returns z.
87// If y == 0, z is set to 0.
88func (z *Uint) Div(x, y *Uint) *Uint {
89 if y.IsZero() || y.Gt(x) {
90 return z.Clear()
91 }
92 if x.Eq(y) {
93 return z.SetOne()
94 }
95 // Shortcut some cases
96 if x.IsUint64() {
97 return z.SetUint64(x.Uint64() / y.Uint64())
98 }
99
100 // At this point, we know
101 // x/y ; x > y > 0
102
103 var quot Uint
104 udivrem(quot[:], x[:], y)
105 return z.Set(")
106}
107
108// Mod sets z to the modulus x%y for y != 0 and returns z.
109// If y == 0, z is set to 0 (this differs from big.Int behavior).
110func (z *Uint) Mod(x, y *Uint) *Uint {
111 if x.IsZero() || y.IsZero() {
112 return z.Clear()
113 }
114 switch x.Cmp(y) {
115 case -1:
116 // x < y
117 copy(z[:], x[:])
118 return z
119 case 0:
120 // x == y
121 return z.Clear() // They are equal
122 }
123
124 // At this point:
125 // x != 0
126 // y != 0
127 // x > y
128
129 // Shortcut trivial case
130 if x.IsUint64() {
131 return z.SetUint64(x.Uint64() % y.Uint64())
132 }
133
134 var quot Uint
135 *z = udivrem(quot[:], x[:], y)
136 return z
137}
138
139// MulMod sets z to (x * y) mod m and returns z.
140// If m == 0, z is set to 0 (this differs from big.Int behavior).
141func (z *Uint) MulMod(x, y, m *Uint) *Uint {
142 if x.IsZero() || y.IsZero() || m.IsZero() {
143 return z.Clear()
144 }
145 p := umul(x, y)
146
147 if m[3] != 0 {
148 mu := Reciprocal(m)
149 r := reduce4(p, m, mu)
150 return z.Set(&r)
151 }
152
153 var (
154 pl Uint
155 ph Uint
156 )
157
158 pl[0], pl[1], pl[2], pl[3] = p[0], p[1], p[2], p[3]
159 ph[0], ph[1], ph[2], ph[3] = p[4], p[5], p[6], p[7]
160
161 // If the multiplication is within 256 bits use Mod().
162 if ph.IsZero() {
163 return z.Mod(&pl, m)
164 }
165
166 var quot [8]uint64
167 rem := udivrem(quot[:], p[:], m)
168 return z.Set(&rem)
169}
170
171// DivMod sets z to the quotient x/y and m to the modulus x%y, returning the pair (z, m).
172// If y == 0, both z and m are set to 0 (this differs from big.Int behavior).
173func (z *Uint) DivMod(x, y, m *Uint) (*Uint, *Uint) {
174 if y.IsZero() {
175 return z.Clear(), m.Clear()
176 }
177
178 switch x.Cmp(y) {
179 case -1:
180 // x < y
181 return z.Clear(), m.Set(x)
182 case 0:
183 // x == y
184 return z.SetOne(), m.Clear()
185 }
186
187 // At this point:
188 // x != 0
189 // y != 0
190 // x > y
191
192 // Shortcut trivial case
193 if x.IsUint64() {
194 x0, y0 := x.Uint64(), y.Uint64()
195 return z.SetUint64(x0 / y0), m.SetUint64(x0 % y0)
196 }
197
198 var quot Uint
199 *m = udivrem(quot[:], x[:], y)
200 *z = quot
201 return z, m
202}
203
204// udivrem divides u by d and produces both quotient and remainder.
205// The quotient is stored in provided quot - len(u)-len(d)+1 words.
206// It loosely follows the Knuth's division algorithm (sometimes referenced as "schoolbook" division) using 64-bit words.
207// See Knuth, Volume 2, section 4.3.1, Algorithm D.
208func udivrem(quot, u []uint64, d *Uint) (rem Uint) {
209 var dLen int
210 for i := len(d) - 1; i >= 0; i-- {
211 if d[i] != 0 {
212 dLen = i + 1
213 break
214 }
215 }
216
217 shift := uint(bits.LeadingZeros64(d[dLen-1]))
218
219 var dnStorage Uint
220 dn := dnStorage[:dLen]
221 for i := dLen - 1; i > 0; i-- {
222 dn[i] = (d[i] << shift) | (d[i-1] >> (64 - shift))
223 }
224 dn[0] = d[0] << shift
225
226 var uLen int
227 for i := len(u) - 1; i >= 0; i-- {
228 if u[i] != 0 {
229 uLen = i + 1
230 break
231 }
232 }
233
234 if uLen < dLen {
235 copy(rem[:], u)
236 return rem
237 }
238
239 var unStorage [9]uint64
240 un := unStorage[:uLen+1]
241 un[uLen] = u[uLen-1] >> (64 - shift)
242 for i := uLen - 1; i > 0; i-- {
243 un[i] = (u[i] << shift) | (u[i-1] >> (64 - shift))
244 }
245 un[0] = u[0] << shift
246
247 if dLen == 1 {
248 r := udivremBy1(quot, un, dn[0])
249 rem.SetUint64(r >> shift)
250 return rem
251 }
252
253 udivremKnuth(quot, un, dn)
254
255 for i := 0; i < dLen-1; i++ {
256 rem[i] = (un[i] >> shift) | (un[i+1] << (64 - shift))
257 }
258 rem[dLen-1] = un[dLen-1] >> shift
259
260 return rem
261}
262
263// umul computes full 256 x 256 -> 512 multiplication.
264func umul(x, y *Uint) [8]uint64 {
265 var res [8]uint64
266
267 topX := highestNonZeroWord(x)
268 topY := highestNonZeroWord(y)
269
270 if topX < 0 || topY < 0 {
271 return res
272 }
273
274 lenX := topX + 1
275 lenY := topY + 1
276
277 for i := 0; i < lenX; i++ {
278 xi := x[i]
279 if xi == 0 {
280 continue
281 }
282 var carry uint64
283 k := i
284 for j := 0; j < lenY; j++ {
285 hi, lo := bits.Mul64(xi, y[j])
286 lo, c := bits.Add64(lo, res[k], 0)
287 hi += c
288 lo, c = bits.Add64(lo, carry, 0)
289 hi += c
290 res[k] = lo
291 carry = hi
292 k++
293 }
294 res[i+lenY] = carry
295 }
296
297 return res
298}
299
300// highestNonZeroWord returns the highest index with non-zero value or -1 if the Uint is zero.
301func highestNonZeroWord(u *Uint) int {
302 for i := 3; i >= 0; i-- {
303 if u[i] != 0 {
304 return i
305 }
306 }
307 return -1
308}
309
// umulStep computes (hi * 2^64 + lo) = z + (x * y) + carry.
// The sum always fits in 128 bits: (2^64-1)^2 + 2*(2^64-1) = 2^128-1.
func umulStep(z, x, y, carry uint64) (hi, lo uint64) {
	hi, lo = bits.Mul64(x, y)
	lo, carry = bits.Add64(lo, carry, 0)
	hi += carry
	lo, carry = bits.Add64(lo, z, 0)
	hi += carry
	return hi, lo
}
319
// umulHop computes (hi * 2^64 + lo) = z + (x * y).
// The sum always fits in 128 bits.
func umulHop(z, x, y uint64) (hi, lo uint64) {
	hi, lo = bits.Mul64(x, y)
	lo, carry := bits.Add64(lo, z, 0)
	hi += carry
	return hi, lo
}
327
328// udivremBy1 divides u by single normalized word d and produces both quotient and remainder.
329// The quotient is stored in provided quot.
330func udivremBy1(quot, u []uint64, d uint64) (rem uint64) {
331 reciprocal := reciprocal2by1(d)
332 rem = u[len(u)-1] // Set the top word as remainder.
333 for j := len(u) - 2; j >= 0; j-- {
334 quot[j], rem = udivrem2by1(rem, u[j], d, reciprocal)
335 }
336 return rem
337}
338
339// udivremKnuth implements the division of u by normalized multiple word d from the Knuth's division algorithm.
340// The quotient is stored in provided quot - len(u)-len(d) words.
341// Updates u to contain the remainder - len(d) words.
342func udivremKnuth(quot, u, d []uint64) {
343 dLen := len(d)
344 dh := d[dLen-1]
345 dl := d[dLen-2]
346 reciprocal := reciprocal2by1(dh)
347
348 for j := len(u) - dLen - 1; j >= 0; j-- {
349 u2 := u[j+dLen]
350 u1 := u[j+dLen-1]
351 u0 := u[j+dLen-2]
352
353 var qhat, rhat uint64
354 if u2 >= dh { // Division overflows.
355 qhat = MAX_UINT64
356 // NOTE: Add "qhat one to big" adjustment (not needed for correctness, but helps avoiding "add back" case).
357 } else {
358 qhat, rhat = udivrem2by1(u2, u1, dh, reciprocal)
359 ph, pl := bits.Mul64(qhat, dl)
360 if ph > rhat || (ph == rhat && pl > u0) {
361 qhat--
362 // NOTE: Add "qhat one to big" adjustment (not needed for correctness, but helps avoiding "add back" case).
363 }
364 }
365
366 // Multiply and subtract.
367 borrow := subMulTo(u[j:], d, qhat)
368 u[j+dLen] = u2 - borrow
369 if u2 < borrow { // Too much subtracted, add back.
370 qhat--
371 u[j+dLen] += addTo(u[j:], d)
372 }
373
374 quot[j] = qhat // Store quotient digit.
375 }
376}
377
378// isBitSet returns true if bit n-th is set, where n = 0 is LSB.
379// The n must be <= 255.
380func (z *Uint) isBitSet(n uint) bool {
381 return (z[n/64] & (1 << (n % 64))) != 0
382}
383
384func (z *Uint) IsOverflow() bool {
385 return z.isBitSet(255)
386}
387
// addTo computes x += y over the first len(y) words of x and returns the carry
// out of the last added word. Words of x beyond len(y) are not touched.
// Requires len(x) >= len(y).
func addTo(x, y []uint64) uint64 {
	var carry uint64
	for i := 0; i < len(y); i++ {
		x[i], carry = bits.Add64(x[i], y[i], carry)
	}
	return carry
}
397
// subMulTo computes x -= y * multiplier over the first len(y) words of x and
// returns the final borrow. This is the amount that still has to be subtracted
// from x[len(y)], which this function does not touch.
// Requires len(x) >= len(y).
func subMulTo(x, y []uint64, multiplier uint64) uint64 {
	var borrow uint64
	for i := 0; i < len(y); i++ {
		s, carry1 := bits.Sub64(x[i], borrow, 0)
		ph, pl := bits.Mul64(y[i], multiplier)
		t, carry2 := bits.Sub64(s, pl, 0)
		x[i] = t
		// The high product word plus both borrows becomes the next borrow.
		borrow = ph + carry1 + carry2
	}
	return borrow
}
411
412// reciprocal2by1 computes <^d, ^0> / d.
413func reciprocal2by1(d uint64) uint64 {
414 reciprocal, _ := bits.Div64(^d, MAX_UINT64, d)
415 return reciprocal
416}
417
// udivrem2by1 divides the two-word number <uh, ul> by d and returns the
// quotient and remainder, using d's precomputed reciprocal (see reciprocal2by1).
// Preconditions: d is normalized (most significant bit set) and uh < d, so the
// quotient fits in 64 bits.
// Implementation ported from https://github.com/chfast/intx and based on
// "Improved division by invariant integers", Algorithm 4.
func udivrem2by1(uh, ul, d, reciprocal uint64) (quot, rem uint64) {
	// Candidate quotient from <uh, ul> * reciprocal, plus the correction terms.
	qh, ql := bits.Mul64(reciprocal, uh)
	ql, carry := bits.Add64(ql, ul, 0)
	qh, _ = bits.Add64(qh, uh, carry)
	qh++

	r := ul - qh*d // candidate remainder, computed modulo 2^64

	// First correction: the candidate quotient was one too large.
	if r > ql {
		qh--
		r += d
	}

	// Second (rare) correction: the remainder is still >= d.
	if r >= d {
		qh++
		r -= d
	}

	return qh, r
}
442
443// MustDiv sets z to the quotient x/y and returns z.
444// It panics if y == 0. Used in critical AMM paths where division by zero represents a programming error.
445func (z *Uint) MustDiv(x, y *Uint) *Uint {
446 if y.IsZero() {
447 panic("division by zero")
448 }
449 return z.Div(x, y)
450}
451
452// MustMod sets z to the modulus x%y and returns z.
453// It panics if y == 0. Used in critical AMM paths where modulo by zero represents a programming error.
454func (z *Uint) MustMod(x, y *Uint) *Uint {
455 if y.IsZero() {
456 panic("modulo by zero")
457 }
458 return z.Mod(x, y)
459}
460
461// MustMulMod sets z to (x * y) mod m and returns z.
462// It panics if m == 0. Used in critical AMM paths where modulo by zero represents a programming error.
463func (z *Uint) MustMulMod(x, y, m *Uint) *Uint {
464 if m.IsZero() {
465 panic("modulo by zero")
466 }
467 return z.MulMod(x, y, m)
468}
469
470// MustDivMod sets z to the quotient x/y and m to the modulus x%y, returning the pair (z, m).
471// It panics if y == 0. Used in critical AMM paths where division by zero represents a programming error.
472func (z *Uint) MustDivMod(x, y, m *Uint) (*Uint, *Uint) {
473 if y.IsZero() {
474 panic("division by zero")
475 }
476 return z.DivMod(x, y, m)
477}
478
479// MustMul sets z to the product x*y and returns z.
480// It panics on overflow. Used in critical AMM calculations where overflow represents a programming error.
481func (z *Uint) MustMul(x, y *Uint) *Uint {
482 result, overflow := z.MulOverflow(x, y)
483 if overflow {
484 panic("uint256: multiplication overflow")
485 }
486 return result
487}