Search Apps Documentation Source Content File Folder Download Copy Actions Download

arithmetic.gno

11.62 Kb · 487 lines
// arithmetic provides arithmetic operations for Uint objects.
// This includes the basic binary operations (addition, subtraction, multiplication,
// division, and modulo) as well as overflow checks and negation. These functions are
// essential for numeric calculations using 256-bit unsigned integers.
  5package uint256
  6
  7import "math/bits"
  8
  9// Add sets z to the sum x+y and returns z.
 10func (z *Uint) Add(x, y *Uint) *Uint {
 11	var carry uint64
 12	z[0], carry = bits.Add64(x[0], y[0], 0)
 13	z[1], carry = bits.Add64(x[1], y[1], carry)
 14	z[2], carry = bits.Add64(x[2], y[2], carry)
 15	z[3], _ = bits.Add64(x[3], y[3], carry)
 16	return z
 17}
 18
 19// AddOverflow sets z to the sum x+y and returns z and true if overflow occurred.
 20func (z *Uint) AddOverflow(x, y *Uint) (*Uint, bool) {
 21	var carry uint64
 22	z[0], carry = bits.Add64(x[0], y[0], 0)
 23	z[1], carry = bits.Add64(x[1], y[1], carry)
 24	z[2], carry = bits.Add64(x[2], y[2], carry)
 25	z[3], carry = bits.Add64(x[3], y[3], carry)
 26	return z, carry != 0
 27}
 28
 29// Sub sets z to the difference x-y and returns z.
 30func (z *Uint) Sub(x, y *Uint) *Uint {
 31	var carry uint64
 32	z[0], carry = bits.Sub64(x[0], y[0], 0)
 33	z[1], carry = bits.Sub64(x[1], y[1], carry)
 34	z[2], carry = bits.Sub64(x[2], y[2], carry)
 35	z[3], _ = bits.Sub64(x[3], y[3], carry)
 36	return z
 37}
 38
 39// SubOverflow sets z to the difference x-y and returns z and true if underflow occurred.
 40func (z *Uint) SubOverflow(x, y *Uint) (*Uint, bool) {
 41	var carry uint64
 42	z[0], carry = bits.Sub64(x[0], y[0], 0)
 43	z[1], carry = bits.Sub64(x[1], y[1], carry)
 44	z[2], carry = bits.Sub64(x[2], y[2], carry)
 45	z[3], carry = bits.Sub64(x[3], y[3], carry)
 46	return z, carry != 0
 47}
 48
 49// Neg returns -x mod 2^256.
 50func (z *Uint) Neg(x *Uint) *Uint {
 51	return z.Sub(Zero(), x)
 52}
 53
// Mul sets z to the product x*y modulo 2^256 (the low 256 bits of the full
// product) and returns z. Bits at or above 2^256 are silently discarded; use
// MulOverflow or MustMul to detect that case.
//
// Only partial products x[i]*y[j] with i+j <= 3 can affect the low 256 bits,
// so just those are accumulated. Partial products that land in the top word
// (res3) only need their low 64 bits, so they use plain wrapping uint64
// multiplication instead of bits.Mul64.
// The result is built in the local res and copied into z at the end, so every
// read of x and y happens before z is written.
func (z *Uint) Mul(x, y *Uint) *Uint {
	var (
		res              Uint
		carry            uint64
		res1, res2, res3 uint64
	)

	// Row for y[0]: x[0..3] * y[0].
	carry, res[0] = bits.Mul64(x[0], y[0])
	carry, res1 = umulHop(carry, x[1], y[0])
	carry, res2 = umulHop(carry, x[2], y[0])
	res3 = x[3]*y[0] + carry

	// Row for y[1], shifted by one word; word 3 only needs wrapping arithmetic.
	carry, res[1] = umulHop(res1, x[0], y[1])
	carry, res2 = umulStep(res2, x[1], y[1], carry)
	res3 = res3 + x[2]*y[1] + carry

	// Row for y[2], shifted by two words.
	carry, res[2] = umulHop(res2, x[0], y[2])
	res3 = res3 + x[1]*y[2] + carry

	// Row for y[3]: only x[0]*y[3] reaches a word below 2^256.
	res[3] = res3 + x[0]*y[3]

	return z.Set(&res)
}
 78
 79// MulOverflow sets z to the product x*y and returns z and true if overflow occurred.
 80func (z *Uint) MulOverflow(x, y *Uint) (*Uint, bool) {
 81	p := umul(x, y)
 82	copy(z[:], p[:4])
 83	return z, (p[4] | p[5] | p[6] | p[7]) != 0
 84}
 85
 86// Div sets z to the quotient x/y and returns z.
 87// If y == 0, z is set to 0.
 88func (z *Uint) Div(x, y *Uint) *Uint {
 89	if y.IsZero() || y.Gt(x) {
 90		return z.Clear()
 91	}
 92	if x.Eq(y) {
 93		return z.SetOne()
 94	}
 95	// Shortcut some cases
 96	if x.IsUint64() {
 97		return z.SetUint64(x.Uint64() / y.Uint64())
 98	}
 99
100	// At this point, we know
101	// x/y ; x > y > 0
102
103	var quot Uint
104	udivrem(quot[:], x[:], y)
105	return z.Set(&quot)
106}
107
108// Mod sets z to the modulus x%y for y != 0 and returns z.
109// If y == 0, z is set to 0 (this differs from big.Int behavior).
110func (z *Uint) Mod(x, y *Uint) *Uint {
111	if x.IsZero() || y.IsZero() {
112		return z.Clear()
113	}
114	switch x.Cmp(y) {
115	case -1:
116		// x < y
117		copy(z[:], x[:])
118		return z
119	case 0:
120		// x == y
121		return z.Clear() // They are equal
122	}
123
124	// At this point:
125	// x != 0
126	// y != 0
127	// x > y
128
129	// Shortcut trivial case
130	if x.IsUint64() {
131		return z.SetUint64(x.Uint64() % y.Uint64())
132	}
133
134	var quot Uint
135	*z = udivrem(quot[:], x[:], y)
136	return z
137}
138
139// MulMod sets z to (x * y) mod m and returns z.
140// If m == 0, z is set to 0 (this differs from big.Int behavior).
141func (z *Uint) MulMod(x, y, m *Uint) *Uint {
142	if x.IsZero() || y.IsZero() || m.IsZero() {
143		return z.Clear()
144	}
145	p := umul(x, y)
146
147	if m[3] != 0 {
148		mu := Reciprocal(m)
149		r := reduce4(p, m, mu)
150		return z.Set(&r)
151	}
152
153	var (
154		pl Uint
155		ph Uint
156	)
157
158	pl[0], pl[1], pl[2], pl[3] = p[0], p[1], p[2], p[3]
159	ph[0], ph[1], ph[2], ph[3] = p[4], p[5], p[6], p[7]
160
161	// If the multiplication is within 256 bits use Mod().
162	if ph.IsZero() {
163		return z.Mod(&pl, m)
164	}
165
166	var quot [8]uint64
167	rem := udivrem(quot[:], p[:], m)
168	return z.Set(&rem)
169}
170
171// DivMod sets z to the quotient x/y and m to the modulus x%y, returning the pair (z, m).
172// If y == 0, both z and m are set to 0 (this differs from big.Int behavior).
173func (z *Uint) DivMod(x, y, m *Uint) (*Uint, *Uint) {
174	if y.IsZero() {
175		return z.Clear(), m.Clear()
176	}
177
178	switch x.Cmp(y) {
179	case -1:
180		// x < y
181		return z.Clear(), m.Set(x)
182	case 0:
183		// x == y
184		return z.SetOne(), m.Clear()
185	}
186
187	// At this point:
188	// x != 0
189	// y != 0
190	// x > y
191
192	// Shortcut trivial case
193	if x.IsUint64() {
194		x0, y0 := x.Uint64(), y.Uint64()
195		return z.SetUint64(x0 / y0), m.SetUint64(x0 % y0)
196	}
197
198	var quot Uint
199	*m = udivrem(quot[:], x[:], y)
200	*z = quot
201	return z, m
202}
203
// udivrem divides the multi-word number u (64-bit words, least significant
// first, at most 8 significant words) by d, returns the remainder, and writes
// the quotient words into quot.
//
// quot must have room for (significant words of u) - (significant words of d) + 1
// words. It should be zeroed by the caller: only the computed quotient digits
// are written, and nothing is written when u has fewer significant words than d.
//
// d must be non-zero: with d == 0, dLen stays 0 and d[dLen-1] panics.
//
// It loosely follows Knuth's division algorithm (sometimes called "schoolbook"
// division) using 64-bit words; see Knuth, TAOCP Volume 2, section 4.3.1,
// Algorithm D. Both operands are normalized by shifting left until the
// divisor's top bit is set, the division is performed, and the remainder is
// shifted back.
func udivrem(quot, u []uint64, d *Uint) (rem Uint) {
	// dLen = number of significant words in d.
	var dLen int
	for i := len(d) - 1; i >= 0; i-- {
		if d[i] != 0 {
			dLen = i + 1
			break
		}
	}

	// Normalization shift: makes the most significant bit of the divisor's top
	// word 1, as required by the reciprocal-based 2-by-1 division. When
	// shift == 0, `>> (64 - shift)` shifts by 64, which yields 0 for uint64 in
	// Go, so the cross-word terms below vanish.
	shift := uint(bits.LeadingZeros64(d[dLen-1]))

	// dn = d << shift, restricted to its dLen significant words.
	var dnStorage Uint
	dn := dnStorage[:dLen]
	for i := dLen - 1; i > 0; i-- {
		dn[i] = (d[i] << shift) | (d[i-1] >> (64 - shift))
	}
	dn[0] = d[0] << shift

	// uLen = number of significant words in u.
	var uLen int
	for i := len(u) - 1; i >= 0; i-- {
		if u[i] != 0 {
			uLen = i + 1
			break
		}
	}

	// u has fewer significant words than d, so u < d: the remainder is u and
	// quot is left untouched.
	if uLen < dLen {
		copy(rem[:], u)
		return rem
	}

	// un = u << shift, with one extra top word to hold the bits shifted out.
	// The fixed 9-word storage limits u to 8 significant words.
	var unStorage [9]uint64
	un := unStorage[:uLen+1]
	un[uLen] = u[uLen-1] >> (64 - shift)
	for i := uLen - 1; i > 0; i-- {
		un[i] = (u[i] << shift) | (u[i-1] >> (64 - shift))
	}
	un[0] = u[0] << shift

	// Single-word divisor: a 2-by-1 division loop suffices; undo the
	// normalization shift on the one-word remainder.
	if dLen == 1 {
		r := udivremBy1(quot, un, dn[0])
		rem.SetUint64(r >> shift)
		return rem
	}

	// Multi-word divisor: afterwards the low dLen words of un hold the
	// normalized remainder.
	udivremKnuth(quot, un, dn)

	// Denormalize the remainder: shift right by `shift` across word boundaries.
	for i := 0; i < dLen-1; i++ {
		rem[i] = (un[i] >> shift) | (un[i+1] << (64 - shift))
	}
	rem[dLen-1] = un[dLen-1] >> shift

	return rem
}
262
263// umul computes full 256 x 256 -> 512 multiplication.
264func umul(x, y *Uint) [8]uint64 {
265	var res [8]uint64
266
267	topX := highestNonZeroWord(x)
268	topY := highestNonZeroWord(y)
269
270	if topX < 0 || topY < 0 {
271		return res
272	}
273
274	lenX := topX + 1
275	lenY := topY + 1
276
277	for i := 0; i < lenX; i++ {
278		xi := x[i]
279		if xi == 0 {
280			continue
281		}
282		var carry uint64
283		k := i
284		for j := 0; j < lenY; j++ {
285			hi, lo := bits.Mul64(xi, y[j])
286			lo, c := bits.Add64(lo, res[k], 0)
287			hi += c
288			lo, c = bits.Add64(lo, carry, 0)
289			hi += c
290			res[k] = lo
291			carry = hi
292			k++
293		}
294		res[i+lenY] = carry
295	}
296
297	return res
298}
299
300// highestNonZeroWord returns the highest index with non-zero value or -1 if the Uint is zero.
301func highestNonZeroWord(u *Uint) int {
302	for i := 3; i >= 0; i-- {
303		if u[i] != 0 {
304			return i
305		}
306	}
307	return -1
308}
309
// umulStep returns the 128-bit value z + x*y + carry as (hi, lo), so that
// hi*2^64 + lo == z + x*y + carry. The sum always fits in 128 bits.
func umulStep(z, x, y, carry uint64) (hi, lo uint64) {
	hi, lo = bits.Mul64(x, y)
	var c1, c2 uint64
	lo, c1 = bits.Add64(lo, carry, 0)
	lo, c2 = bits.Add64(lo, z, 0)
	return hi + c1 + c2, lo
}
319
// umulHop returns the 128-bit value z + x*y as (hi, lo), so that
// hi*2^64 + lo == z + x*y.
func umulHop(z, x, y uint64) (hi, lo uint64) {
	ph, pl := bits.Mul64(x, y)
	sum, c := bits.Add64(pl, z, 0)
	return ph + c, sum
}
327
328// udivremBy1 divides u by single normalized word d and produces both quotient and remainder.
329// The quotient is stored in provided quot.
330func udivremBy1(quot, u []uint64, d uint64) (rem uint64) {
331	reciprocal := reciprocal2by1(d)
332	rem = u[len(u)-1] // Set the top word as remainder.
333	for j := len(u) - 2; j >= 0; j-- {
334		quot[j], rem = udivrem2by1(rem, u[j], d, reciprocal)
335	}
336	return rem
337}
338
// udivremKnuth implements the core loop of Knuth's division algorithm
// (TAOCP Vol. 2, 4.3.1, Algorithm D): it divides the normalized multi-word
// dividend u by the normalized multi-word divisor d, least significant word first.
//
// Requirements visible in the code: len(d) >= 2 (d[dLen-2] is read); the top
// bit of d[len(d)-1] must be set, because reciprocal2by1 calls bits.Div64 with
// ^dh as the high word, which panics unless ^dh < dh; quot must have at least
// len(u)-len(d) words.
//
// The quotient digits are written to quot[0 : len(u)-len(d)]. u is updated in
// place; on return its low len(d) words hold the (still normalized) remainder.
func udivremKnuth(quot, u, d []uint64) {
	dLen := len(d)
	dh := d[dLen-1]
	dl := d[dLen-2]
	reciprocal := reciprocal2by1(dh)

	// Produce one quotient digit per iteration, most significant first.
	for j := len(u) - dLen - 1; j >= 0; j-- {
		// Top three words of the current partial-remainder window u[j : j+dLen+1].
		u2 := u[j+dLen]
		u1 := u[j+dLen-1]
		u0 := u[j+dLen-2]

		var qhat, rhat uint64
		if u2 >= dh { // Division overflows.
			// <u2, u1> / dh would not fit in 64 bits; use the maximum digit.
			qhat = MAX_UINT64
			// TODO: Add "qhat one too big" adjustment (not needed for correctness, but helps avoid the "add back" case).
		} else {
			// Estimate the digit from the top two words, then refine it once
			// using the second divisor word dl.
			qhat, rhat = udivrem2by1(u2, u1, dh, reciprocal)
			ph, pl := bits.Mul64(qhat, dl)
			if ph > rhat || (ph == rhat && pl > u0) {
				qhat--
				// TODO: Add "qhat one too big" adjustment (not needed for correctness, but helps avoid the "add back" case).
			}
		}

		// Multiply and subtract: u[j : j+dLen+1] -= qhat * d.
		borrow := subMulTo(u[j:], d, qhat)
		u[j+dLen] = u2 - borrow
		if u2 < borrow { // Too much subtracted, add back.
			// qhat was one too large: undo one multiple of d.
			qhat--
			u[j+dLen] += addTo(u[j:], d)
		}

		quot[j] = qhat // Store quotient digit.
	}
}
377
378// isBitSet returns true if bit n-th is set, where n = 0 is LSB.
379// The n must be <= 255.
380func (z *Uint) isBitSet(n uint) bool {
381	return (z[n/64] & (1 << (n % 64))) != 0
382}
383
384func (z *Uint) IsOverflow() bool {
385	return z.isBitSet(255)
386}
387
// addTo adds y into x in place (x += y over the first len(y) words) and returns
// the carry out of the last added word. Words of x beyond len(y) are not
// touched. It requires len(x) >= len(y).
func addTo(x, y []uint64) uint64 {
	var c uint64
	for i, yi := range y {
		x[i], c = bits.Add64(x[i], yi, c)
	}
	return c
}
397
// subMulTo computes x -= y * multiplier in place over the first len(y) words
// of x and returns the borrow that must be subtracted from the next-higher
// word. It requires len(x) >= len(y).
func subMulTo(x, y []uint64, multiplier uint64) uint64 {
	var borrow uint64
	for i, yi := range y {
		prodHi, prodLo := bits.Mul64(yi, multiplier)
		diff, b1 := bits.Sub64(x[i], borrow, 0)
		diff, b2 := bits.Sub64(diff, prodLo, 0)
		x[i] = diff
		borrow = prodHi + b1 + b2
	}
	return borrow
}
411
412// reciprocal2by1 computes <^d, ^0> / d.
413func reciprocal2by1(d uint64) uint64 {
414	reciprocal, _ := bits.Div64(^d, MAX_UINT64, d)
415	return reciprocal
416}
417
// udivrem2by1 divides the two-word value <uh, ul> (uh*2^64 + ul) by the
// single word d and returns the quotient and remainder. It multiplies by the
// precomputed reciprocal of d (see reciprocal2by1) instead of using a
// hardware division.
//
// Preconditions (from the referenced algorithm; not checked here): d is
// normalized (top bit set) and uh < d, so the quotient fits in 64 bits.
// udivremKnuth only calls it in the branch where u2 < dh.
//
// Implementation ported from https://github.com/chfast/intx and is based on
// "Improved division by invariant integers", Algorithm 4.
func udivrem2by1(uh, ul, d, reciprocal uint64) (quot, rem uint64) {
	// Candidate quotient <qh, ql> = reciprocal*uh + <uh, ul>, then qh + 1.
	qh, ql := bits.Mul64(reciprocal, uh)
	ql, carry := bits.Add64(ql, ul, 0)
	qh, _ = bits.Add64(qh, uh, carry)
	qh++

	// Candidate remainder, computed modulo 2^64 (wrap-around is intended).
	r := ul - qh*d

	// The candidate quotient may be one too large: step it back.
	if r > ql {
		qh--
		r += d
	}

	// The quotient may still be one too small: step it forward.
	if r >= d {
		qh++
		r -= d
	}

	return qh, r
}
442
443// MustDiv sets z to the quotient x/y and returns z.
444// It panics if y == 0. Used in critical AMM paths where division by zero represents a programming error.
445func (z *Uint) MustDiv(x, y *Uint) *Uint {
446	if y.IsZero() {
447		panic("division by zero")
448	}
449	return z.Div(x, y)
450}
451
452// MustMod sets z to the modulus x%y and returns z.
453// It panics if y == 0. Used in critical AMM paths where modulo by zero represents a programming error.
454func (z *Uint) MustMod(x, y *Uint) *Uint {
455	if y.IsZero() {
456		panic("modulo by zero")
457	}
458	return z.Mod(x, y)
459}
460
461// MustMulMod sets z to (x * y) mod m and returns z.
462// It panics if m == 0. Used in critical AMM paths where modulo by zero represents a programming error.
463func (z *Uint) MustMulMod(x, y, m *Uint) *Uint {
464	if m.IsZero() {
465		panic("modulo by zero")
466	}
467	return z.MulMod(x, y, m)
468}
469
470// MustDivMod sets z to the quotient x/y and m to the modulus x%y, returning the pair (z, m).
471// It panics if y == 0. Used in critical AMM paths where division by zero represents a programming error.
472func (z *Uint) MustDivMod(x, y, m *Uint) (*Uint, *Uint) {
473	if y.IsZero() {
474		panic("division by zero")
475	}
476	return z.DivMod(x, y, m)
477}
478
479// MustMul sets z to the product x*y and returns z.
480// It panics on overflow. Used in critical AMM calculations where overflow represents a programming error.
481func (z *Uint) MustMul(x, y *Uint) *Uint {
482	result, overflow := z.MulOverflow(x, y)
483	if overflow {
484		panic("uint256: multiplication overflow")
485	}
486	return result
487}