Search Apps Documentation Source Content File Folder Download Copy Actions Download

sqrt_price_math.gno

11.08 KB · 329 lines
  1package gnsmath
  2
  3import (
  4	i256 "gno.land/p/gnoswap/int256"
  5	u256 "gno.land/p/gnoswap/uint256"
  6)
  7
  // Q96_RESOLUTION is the number of fractional bits in a Q64.96 fixed-point
// value: shifting left by it multiplies by 2^96.
const Q96_RESOLUTION uint = 96

// Q160_RESOLUTION is a 160-bit width; 2^Q160_RESOLUTION - 1 is the largest
// value representable in 160 bits.
const Q160_RESOLUTION uint = 160
 12
 13var (
 14	q96       = u256.Zero().Lsh(u256.One(), 96)                               // 2^96
 15	max160    = u256.Zero().Sub(u256.Zero().Lsh(u256.One(), 160), u256.One()) // 2^160 - 1
 16	maxInt256 = u256.Zero().Sub(u256.Zero().Lsh(u256.One(), 255), u256.One()) // 2^255 - 1
 17
 18	MIN_SQRT_RATIO = u256.MustFromDecimal("4295128739")
 19	MAX_SQRT_RATIO = u256.MustFromDecimal("1461446703485210103287273052203988822378723970342")
 20)
 21
 22// getNextPriceAmount0Add calculates the next sqrt price when adding token0 liquidity,
 23// rounding up to ensure conservative pricing for the protocol.
 24// This internal function handles the case where token0 is being added to the pool.
 25func getNextPriceAmount0Add(
 26	currentSqrtPriceX96, liquidity, amountToAdd *u256.Uint,
 27) *u256.Uint {
 28	// liquidityShifted = liquidity << 96
 29	liquidityShifted := u256.Zero().Lsh(liquidity, Q96_RESOLUTION)
 30	// amountTimesSqrtPrice = amount * sqrtPrice
 31	amountTimesSqrtPrice := u256.Zero().Mul(amountToAdd, currentSqrtPriceX96)
 32
 33	// Overflow check: Ensure (amountTimesSqrtPrice / amountToAdd) == currentSqrtPriceX96
 34	quotientCheck := u256.Zero().Div(amountTimesSqrtPrice, amountToAdd)
 35	if quotientCheck.Eq(currentSqrtPriceX96) {
 36		// denominator = liquidityShifted + amountTimesSqrtPrice
 37		denominator := u256.Zero().Add(liquidityShifted, amountTimesSqrtPrice)
 38		// only take this path when denominator >= liquidityShifted
 39		if denominator.Gte(liquidityShifted) {
 40			return u256.MulDivRoundingUp(liquidityShifted, currentSqrtPriceX96, denominator)
 41		}
 42	}
 43
 44	// fallback: liquidityShifted / ((liquidityShifted / sqrtPrice) + amount)
 45	divValue := u256.Zero().Div(liquidityShifted, currentSqrtPriceX96)
 46	denominator := u256.Zero().Add(divValue, amountToAdd)
 47	return u256.DivRoundingUp(liquidityShifted, denominator)
 48}
 49
 50// getNextPriceAmount0Remove calculates the next sqrt price when removing token0 liquidity,
 51// rounding up to ensure conservative pricing for the protocol.
 52// This internal function handles the case where token0 is being removed from the pool.
 53// Panics if validation checks fail (invalid pool sqrt price calculation).
 54func getNextPriceAmount0Remove(
 55	currentSqrtPriceX96, liquidity, amountToRemove *u256.Uint,
 56) *u256.Uint {
 57	// liquidityShifted = liquidity << 96
 58	liquidityShifted := u256.Zero().Lsh(liquidity, Q96_RESOLUTION)
 59	// amountTimesSqrtPrice = amountToRemove * currentSqrtPriceX96
 60	amountTimesSqrtPrice := u256.Zero().Mul(amountToRemove, currentSqrtPriceX96)
 61
 62	// Validation checks
 63	quotientCheck := u256.Zero().Div(amountTimesSqrtPrice, amountToRemove)
 64	if !quotientCheck.Eq(currentSqrtPriceX96) || !liquidityShifted.Gt(amountTimesSqrtPrice) {
 65		panic(errInvalidPoolSqrtPrice)
 66	}
 67
 68	denominator := u256.Zero().Sub(liquidityShifted, amountTimesSqrtPrice)
 69	return u256.MulDivRoundingUp(liquidityShifted, currentSqrtPriceX96, denominator)
 70}
 71
 72// getNextSqrtPriceFromAmount0RoundingUp calculates the next sqrt price based on token0 amount,
 73// always rounding up to ensure conservative pricing in both exact output and exact input cases.
 74// The add parameter determines whether liquidity is being added (true) or removed (false).
 75func getNextSqrtPriceFromAmount0RoundingUp(
 76	sqrtPX96 *u256.Uint,
 77	liquidity *u256.Uint,
 78	amount *u256.Uint,
 79	add bool,
 80) *u256.Uint {
 81	// Shortcut: if no amount, return original price
 82	if amount.IsZero() {
 83		return sqrtPX96
 84	}
 85
 86	if add {
 87		return getNextPriceAmount0Add(sqrtPX96, liquidity, amount)
 88	}
 89	return getNextPriceAmount0Remove(sqrtPX96, liquidity, amount)
 90}
 91
 92// getNextPriceAmount1Add calculates the next sqrt price when adding token1,
 93// preserving rounding-down logic for the final result.
 94// This internal function handles the case where token1 is being added to the pool.
 95func getNextPriceAmount1Add(
 96	sqrtPX96, liquidity, amount *u256.Uint,
 97) *u256.Uint {
 98	var quotient *u256.Uint
 99
100	if amount.Lte(max160) {
101		// Use local variables to avoid allocation conflicts
102		shifted := u256.Zero().Lsh(amount, Q96_RESOLUTION)
103		quotient = u256.Zero().MustDiv(shifted, liquidity)
104	} else {
105		quotient = u256.MulDiv(amount, q96, liquidity)
106	}
107
108	return u256.Zero().Add(sqrtPX96, quotient)
109}
110
111// getNextPriceAmount1Remove calculates the next sqrt price when removing token1,
112// preserving rounding-down logic for the final result.
113// This internal function handles the case where token1 is being removed from the pool.
114// Panics if sqrt price would exceed quotient.
115func getNextPriceAmount1Remove(
116	sqrtPX96, liquidity, amount *u256.Uint,
117) *u256.Uint {
118	var quotient *u256.Uint
119
120	if amount.Lte(max160) {
121		shifted := u256.Zero().Lsh(amount, Q96_RESOLUTION)
122		quotient = u256.DivRoundingUp(shifted, liquidity)
123	} else {
124		quotient = u256.MulDivRoundingUp(amount, q96, liquidity)
125	}
126
127	if !sqrtPX96.Gt(quotient) {
128		panic(errSqrtPriceExceedsQuotient)
129	}
130
131	return u256.Zero().Sub(sqrtPX96, quotient)
132}
133
134// getNextSqrtPriceFromAmount1RoundingDown calculates the next sqrt price based on token1 amount,
135// always rounding down to ensure conservative pricing in both exact output and exact input cases.
136// The add parameter determines whether liquidity is being added (true) or removed (false).
137func getNextSqrtPriceFromAmount1RoundingDown(
138	sqrtPX96,
139	liquidity,
140	amount *u256.Uint,
141	add bool,
142) *u256.Uint {
143	// Shortcut: if no amount, return original price
144	if amount.IsZero() {
145		return sqrtPX96
146	}
147
148	if add {
149		return getNextPriceAmount1Add(sqrtPX96, liquidity, amount)
150	}
151	return getNextPriceAmount1Remove(sqrtPX96, liquidity, amount)
152}
153
154// getNextSqrtPriceFromInput calculates the next sqrt price after adding tokens to the pool,
155// rounding up for conservative pricing in both swap directions.
156// The zeroForOne parameter indicates swap direction (token0 for token1 when true).
157// Panics if sqrtPX96 or liquidity is zero.
158func getNextSqrtPriceFromInput(
159	sqrtPX96, liquidity, amountIn *u256.Uint,
160	zeroForOne bool,
161) *u256.Uint {
162	if sqrtPX96.IsZero() {
163		panic(errSqrtPriceZero)
164	}
165
166	if liquidity.IsZero() {
167		panic(errLiquidityZero)
168	}
169
170	if zeroForOne {
171		return getNextSqrtPriceFromAmount0RoundingUp(sqrtPX96, liquidity, amountIn, true)
172	}
173
174	return getNextSqrtPriceFromAmount1RoundingDown(sqrtPX96, liquidity, amountIn, true)
175}
176
177// getNextSqrtPriceFromOutput calculates the next sqrt price after removing tokens from the pool,
178// using different rounding directions based on swap direction.
179// The zeroForOne parameter indicates swap direction (token0 for token1 when true).
180// Panics if sqrtPX96 or liquidity is zero.
181func getNextSqrtPriceFromOutput(
182	sqrtPX96, liquidity, amountOut *u256.Uint,
183	zeroForOne bool,
184) *u256.Uint {
185	if sqrtPX96.IsZero() {
186		panic(errSqrtPriceZero)
187	}
188
189	if liquidity.IsZero() {
190		panic(errLiquidityZero)
191	}
192
193	if zeroForOne {
194		return getNextSqrtPriceFromAmount1RoundingDown(sqrtPX96, liquidity, amountOut, false)
195	}
196
197	return getNextSqrtPriceFromAmount0RoundingUp(sqrtPX96, liquidity, amountOut, false)
198}
199
200// getAmount0DeltaHelper calculates the absolute token0 amount difference between two price ranges,
201// automatically swapping inputs to ensure correct ordering. The roundUp parameter controls
202// rounding direction for the final result to ensure conservative AMM calculations.
203// Panics if sqrtRatioAX96 is zero.
204func getAmount0DeltaHelper(
205	sqrtRatioAX96, sqrtRatioBX96, liquidity *u256.Uint,
206	roundUp bool,
207) *u256.Uint {
208	if sqrtRatioAX96.Gt(sqrtRatioBX96) {
209		sqrtRatioAX96, sqrtRatioBX96 = sqrtRatioBX96, sqrtRatioAX96
210	}
211
212	// Use local variables for thread safety
213	numerator := u256.Zero().Lsh(liquidity, Q96_RESOLUTION)
214	difference := u256.Zero().Sub(sqrtRatioBX96, sqrtRatioAX96)
215
216	if sqrtRatioAX96.IsZero() {
217		panic(errSqrtRatioAX96Zero)
218	}
219
220	if roundUp {
221		intermediate := u256.MulDivRoundingUp(numerator, difference, sqrtRatioBX96)
222		return u256.DivRoundingUp(intermediate, sqrtRatioAX96)
223	}
224
225	intermediate := u256.MulDiv(numerator, difference, sqrtRatioBX96)
226	return u256.Zero().Div(intermediate, sqrtRatioAX96)
227}
228
229// getAmount1DeltaHelper calculates the absolute token1 amount difference between two price ranges,
230// automatically swapping inputs to ensure correct ordering. The roundUp parameter controls
231// rounding direction for the final result to ensure conservative AMM calculations.
232func getAmount1DeltaHelper(
233	sqrtRatioAX96, sqrtRatioBX96, liquidity *u256.Uint,
234	roundUp bool,
235) *u256.Uint {
236	if sqrtRatioAX96.Gt(sqrtRatioBX96) {
237		sqrtRatioAX96, sqrtRatioBX96 = sqrtRatioBX96, sqrtRatioAX96
238	}
239
240	// amount1 = liquidity * (sqrtB - sqrtA) / 2^96
241	// Use local variable for thread safety
242	difference := u256.Zero().Sub(sqrtRatioBX96, sqrtRatioAX96)
243
244	if roundUp {
245		return u256.MulDivRoundingUp(liquidity, difference, q96)
246	}
247
248	return u256.MulDiv(liquidity, difference, q96)
249}
250
251// GetAmount0Delta calculates the token0 amount difference within a price range, returning
252// a signed int256 value that is negative when liquidity is negative. Rounds down for
253// negative liquidity and up for positive liquidity.
254//
255// Parameters:
256//   - sqrtRatioAX96: first sqrt price in Q96 format
257//   - sqrtRatioBX96: second sqrt price in Q96 format
258//   - liquidity: signed liquidity value
259//
260// Returns the token0 amount difference as a signed int256 value.
261//
262// Panics if any input is nil or if the result overflows int256.
263func GetAmount0Delta(
264	sqrtRatioAX96, sqrtRatioBX96 *u256.Uint,
265	liquidity *i256.Int,
266) *i256.Int {
267	if sqrtRatioAX96 == nil || sqrtRatioBX96 == nil || liquidity == nil {
268		panic(errGetAmount0DeltaNilInput)
269	}
270
271	if liquidity.IsNeg() {
272		u := getAmount0DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), false)
273		if u.Gt(maxInt256) {
274			// if u > (2**255 - 1), cannot cast to int256
275			panic(errAmount0DeltaOverflow)
276		}
277
278		// Convert to i256 and negate properly
279		return i256.Zero().Neg(i256.FromUint256(u))
280	}
281
282	u := getAmount0DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), true)
283	if u.Gt(maxInt256) {
284		// if u > (2**255 - 1), cannot cast to int256
285		panic(errAmount0DeltaOverflow)
286	}
287
288	return i256.FromUint256(u)
289}
290
291// GetAmount1Delta calculates the token1 amount difference within a price range, returning
292// a signed int256 value that is negative when liquidity is negative. Rounds down for
293// negative liquidity and up for positive liquidity.
294//
295// Parameters:
296//   - sqrtRatioAX96: first sqrt price in Q96 format
297//   - sqrtRatioBX96: second sqrt price in Q96 format
298//   - liquidity: signed liquidity value
299//
300// Returns the token1 amount difference as a signed int256 value.
301//
302// Panics if any input is nil or if the result overflows int256.
303func GetAmount1Delta(
304	sqrtRatioAX96, sqrtRatioBX96 *u256.Uint,
305	liquidity *i256.Int,
306) *i256.Int {
307	if sqrtRatioAX96 == nil || sqrtRatioBX96 == nil || liquidity == nil {
308		panic(errGetAmount1DeltaNilInput)
309	}
310
311	if liquidity.IsNeg() {
312		u := getAmount1DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), false)
313		if u.Gt(maxInt256) {
314			// if u > (2**255 - 1), cannot cast to int256
315			panic(errAmount1DeltaOverflow)
316		}
317
318		// Convert to i256 and negate properly
319		return i256.Zero().Neg(i256.FromUint256(u))
320	}
321
322	u := getAmount1DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), true)
323	if u.Gt(maxInt256) {
324		// if u > (2**255 - 1), cannot cast to int256
325		panic(errAmount1DeltaOverflow)
326	}
327
328	return i256.FromUint256(u)
329}