From dc64627be5200d28bb41a98b1d6ee65056d80cf6 Mon Sep 17 00:00:00 2001 From: Hadrien Croubois Date: Mon, 24 Feb 2025 20:46:35 +0100 Subject: [PATCH 1/5] Add saturating math operations (unsigned) --- contracts/utils/math/Math.sol | 178 +++++++++++++++++++++------------- test/utils/math/Math.test.js | 56 +++++++++++ 2 files changed, 165 insertions(+), 69 deletions(-) diff --git a/contracts/utils/math/Math.sol b/contracts/utils/math/Math.sol index 2acd0754060..101ec964216 100644 --- a/contracts/utils/math/Math.sol +++ b/contracts/utils/math/Math.sol @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// OpenZeppelin Contracts (last updated v5.1.0) (utils/math/Math.sol) +// OpenZeppelin Contracts (last updated v5.0.0) (utils/math/Math.sol) pragma solidity ^0.8.20; @@ -23,8 +23,8 @@ library Math { function tryAdd(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) { unchecked { uint256 c = a + b; - if (c < a) return (false, 0); - return (true, c); + success = c >= a; + result = c * SafeCast.toUint(success); } } @@ -33,8 +33,9 @@ library Math { */ function trySub(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) { unchecked { - if (b > a) return (false, 0); - return (true, a - b); + uint256 c = a - b; + success = c <= a; + result = c * SafeCast.toUint(success); } } @@ -43,13 +44,14 @@ library Math { */ function tryMul(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) { unchecked { - // Gas optimization: this is cheaper than requiring 'a' not being zero, but the - // benefit is lost if 'b' is also tested. - // See: https://github.com/OpenZeppelin/openzeppelin-contracts/pull/522 - if (a == 0) return (true, 0); uint256 c = a * b; - if (c / a != b) return (false, 0); - return (true, c); + assembly ("memory-safe") { + // Only true when the multiplication doesn't overflow + // (c / a == b) || (a == 0) + success := or(eq(div(c, a), b), iszero(a)) + } + // equivalent to: success ? c : 0 + result = c * SafeCast.toUint(success); } } @@ -58,8 +60,11 @@ library Math { */ function tryDiv(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) { unchecked { - if (b == 0) return (false, 0); - return (true, a / b); + success = b > 0; + assembly ("memory-safe") { + // In EVM any value divided by zero is zero. + result := div(a, b) + } } } @@ -68,11 +73,38 @@ library Math { */ function tryMod(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) { unchecked { - if (b == 0) return (false, 0); - return (true, a % b); + success = b > 0; + assembly ("memory-safe") { + // In EVM a value modulus zero is equal to zero. + result := mod(a, b) + } } } + /** + * @dev Unsigned saturating addition, bounds to `2 ** 256 - 1` instead of overflowing. + */ + function saturatingAdd(uint256 a, uint256 b) internal pure returns (uint256) { + (bool success, uint256 result) = tryAdd(a, b); + return ternary(success, result, type(uint256).max); + } + + /** + * @dev Unsigned saturating subtraction, bounds to zero instead of overflowing. + */ + function saturatingSub(uint256 a, uint256 b) internal pure returns (uint256) { + (, uint256 result) = trySub(a, b); + return result; + } + + /** + * @dev Unsigned saturating multiplication, bounds to `2 ** 256 - 1` instead of overflowing. + */ + function saturatingMul(uint256 a, uint256 b) internal pure returns (uint256) { + (bool success, uint256 result) = tryMul(a, b); + return ternary(success, result, type(uint256).max); + } + /** * @dev Branchless ternary evaluation for `a ? b : c`. Gas costs are constant. * @@ -144,11 +176,11 @@ library Math { function mulDiv(uint256 x, uint256 y, uint256 denominator) internal pure returns (uint256 result) { unchecked { // 512-bit multiply [prod1 prod0] = x * y. Compute the product mod 2²⁵⁶ and mod 2²⁵⁶ - 1, then use - // the Chinese Remainder Theorem to reconstruct the 512 bit result. The result is stored in two 256 + // use the Chinese Remainder Theorem to reconstruct the 512 bit result. The result is stored in two 256 // variables such that product = prod1 * 2²⁵⁶ + prod0. uint256 prod0 = x * y; // Least significant 256 bits of the product uint256 prod1; // Most significant 256 bits of the product - assembly { + assembly ("memory-safe") { let mm := mulmod(x, y, not(0)) prod1 := sub(sub(mm, prod0), lt(mm, prod0)) } @@ -172,7 +204,7 @@ library Math { // Make division exact by subtracting the remainder from [prod1 prod0]. uint256 remainder; - assembly { + assembly ("memory-safe") { // Compute remainder using mulmod. remainder := mulmod(x, y, denominator) @@ -185,7 +217,7 @@ library Math { // Always >= 1. See https://cs.stackexchange.com/q/138556/92363. uint256 twos = denominator & (0 - denominator); - assembly { + assembly ("memory-safe") { // Divide denominator by twos. denominator := div(denominator, twos) @@ -232,7 +264,7 @@ library Math { /** * @dev Calculate the modular multiplicative inverse of a number in Z/nZ. * - * If n is a prime, then Z/nZ is a field. In that case all elements are inversible, except 0. + * If n is a prime, then Z/nZ is a field. In that case all elements are inversible, expect 0. * If n is not a prime, then Z/nZ is not a field, and some elements might not be inversible. * * If the input value is not inversible, 0 is returned. @@ -537,45 +569,41 @@ library Math { * @dev Return the log in base 2 of a positive value rounded towards zero. * Returns 0 if given 0. */ - function log2(uint256 x) internal pure returns (uint256 r) { - // If value has upper 128 bits set, log2 result is at least 128 - r = SafeCast.toUint(x > 0xffffffffffffffffffffffffffffffff) << 7; - // If upper 64 bits of 128-bit half set, add 64 to result - r |= SafeCast.toUint((x >> r) > 0xffffffffffffffff) << 6; - // If upper 32 bits of 64-bit half set, add 32 to result - r |= SafeCast.toUint((x >> r) > 0xffffffff) << 5; - // If upper 16 bits of 32-bit half set, add 16 to result - r |= SafeCast.toUint((x >> r) > 0xffff) << 4; - // If upper 8 bits of 16-bit half set, add 8 to result - r |= SafeCast.toUint((x >> r) > 0xff) << 3; - // If upper 4 bits of 8-bit half set, add 4 to result - r |= SafeCast.toUint((x >> r) > 0xf) << 2; - - // Shifts value right by the current result and use it as an index into this lookup table: - // - // | x (4 bits) | index | table[index] = MSB position | - // |------------|---------|-----------------------------| - // | 0000 | 0 | table[0] = 0 | - // | 0001 | 1 | table[1] = 0 | - // | 0010 | 2 | table[2] = 1 | - // | 0011 | 3 | table[3] = 1 | - // | 0100 | 4 | table[4] = 2 | - // | 0101 | 5 | table[5] = 2 | - // | 0110 | 6 | table[6] = 2 | - // | 0111 | 7 | table[7] = 2 | - // | 1000 | 8 | table[8] = 3 | - // | 1001 | 9 | table[9] = 3 | - // | 1010 | 10 | table[10] = 3 | - // | 1011 | 11 | table[11] = 3 | - // | 1100 | 12 | table[12] = 3 | - // | 1101 | 13 | table[13] = 3 | - // | 1110 | 14 | table[14] = 3 | - // | 1111 | 15 | table[15] = 3 | - // - // The lookup table is represented as a 32-byte value with the MSB positions for 0-15 in the last 16 bytes. - assembly ("memory-safe") { - r := or(r, byte(shr(r, x), 0x0000010102020202030303030303030300000000000000000000000000000000)) + function log2(uint256 value) internal pure returns (uint256) { + uint256 result = 0; + uint256 exp; + unchecked { + exp = 128 * SafeCast.toUint(value > (1 << 128) - 1); + value >>= exp; + result += exp; + + exp = 64 * SafeCast.toUint(value > (1 << 64) - 1); + value >>= exp; + result += exp; + + exp = 32 * SafeCast.toUint(value > (1 << 32) - 1); + value >>= exp; + result += exp; + + exp = 16 * SafeCast.toUint(value > (1 << 16) - 1); + value >>= exp; + result += exp; + + exp = 8 * SafeCast.toUint(value > (1 << 8) - 1); + value >>= exp; + result += exp; + + exp = 4 * SafeCast.toUint(value > (1 << 4) - 1); + value >>= exp; + result += exp; + + exp = 2 * SafeCast.toUint(value > (1 << 2) - 1); + value >>= exp; + result += exp; + + result += SafeCast.toUint(value > 1); } + return result; } /** @@ -644,17 +672,29 @@ library Math { * * Adding one to the result gives the number of pairs of hex symbols needed to represent `value` as a hex string. */ - function log256(uint256 x) internal pure returns (uint256 r) { - // If value has upper 128 bits set, log2 result is at least 128 - r = SafeCast.toUint(x > 0xffffffffffffffffffffffffffffffff) << 7; - // If upper 64 bits of 128-bit half set, add 64 to result - r |= SafeCast.toUint((x >> r) > 0xffffffffffffffff) << 6; - // If upper 32 bits of 64-bit half set, add 32 to result - r |= SafeCast.toUint((x >> r) > 0xffffffff) << 5; - // If upper 16 bits of 32-bit half set, add 16 to result - r |= SafeCast.toUint((x >> r) > 0xffff) << 4; - // Add 1 if upper 8 bits of 16-bit half set, and divide accumulated result by 8 - return (r >> 3) | SafeCast.toUint((x >> r) > 0xff); + function log256(uint256 value) internal pure returns (uint256) { + uint256 result = 0; + uint256 isGt; + unchecked { + isGt = SafeCast.toUint(value > (1 << 128) - 1); + value >>= isGt * 128; + result += isGt * 16; + + isGt = SafeCast.toUint(value > (1 << 64) - 1); + value >>= isGt * 64; + result += isGt * 8; + + isGt = SafeCast.toUint(value > (1 << 32) - 1); + value >>= isGt * 32; + result += isGt * 4; + + isGt = SafeCast.toUint(value > (1 << 16) - 1); + value >>= isGt * 16; + result += isGt * 2; + + result += SafeCast.toUint(value > (1 << 8) - 1); + } + return result; } /** diff --git a/test/utils/math/Math.test.js b/test/utils/math/Math.test.js index f38f2f3184f..025eb121687 100644 --- a/test/utils/math/Math.test.js +++ b/test/utils/math/Math.test.js @@ -147,6 +147,62 @@ describe('Math', function () { }); }); + describe('saturatingAdd', function () { + it('adds correctly', async function () { + const a = 5678n; + const b = 1234n; + await testCommutative(this.mock.$saturatingAdd, a, b, a + b); + await testCommutative(this.mock.$saturatingAdd, a, 0n, a); + await testCommutative(this.mock.$saturatingAdd, ethers.MaxUint256, 0n, ethers.MaxUint256); + }); + + it('bounds on addition overflow', async function () { + await testCommutative(this.mock.$saturatingAdd, ethers.MaxUint256, 1n, ethers.MaxUint256); + await expect(this.mock.$saturatingAdd(ethers.MaxUint256, ethers.MaxUint256)).to.eventually.equal( + ethers.MaxUint256, + ); + }); + }); + + describe('saturatingSub', function () { + it('subtracts correctly', async function () { + const a = 5678n; + const b = 1234n; + await expect(this.mock.$saturatingSub(a, b)).to.eventually.equal(a - b); + await expect(this.mock.$saturatingSub(a, a)).to.eventually.equal(0n); + await expect(this.mock.$saturatingSub(a, 0n)).to.eventually.equal(a); + await expect(this.mock.$saturatingSub(0n, a)).to.eventually.equal(0n); + await expect(this.mock.$saturatingSub(ethers.MaxUint256, 1n)).to.eventually.equal(ethers.MaxUint256 - 1n); + }); + + it('bounds on subtraction overflow', async function () { + await expect(this.mock.$saturatingSub(0n, 1n)).to.eventually.equal(0n); + await expect(this.mock.$saturatingSub(1n, 2n)).to.eventually.equal(0n); + await expect(this.mock.$saturatingSub(1n, ethers.MaxUint256)).to.eventually.equal(0n); + await expect(this.mock.$saturatingSub(ethers.MaxUint256 - 1n, ethers.MaxUint256)).to.eventually.equal(0n); + }); + }); + + describe('saturatingMul', function () { + it('multiplies correctly', async function () { + const a = 1234n; + const b = 5678n; + await testCommutative(this.mock.$saturatingMul, a, b, a * b); + }); + + it('multiplies by zero correctly', async function () { + const a = 0n; + const b = 5678n; + await testCommutative(this.mock.$saturatingMul, a, b, 0n); + }); + + it('bounds on multiplication overflow', async function () { + const a = ethers.MaxUint256; + const b = 2n; + await testCommutative(this.mock.$saturatingMul, a, b, ethers.MaxUint256); + }); + }); + describe('max', function () { it('is correctly detected in both position', async function () { await testCommutative(this.mock.$max, 1234n, 5678n, max(1234n, 5678n)); From ad9e6221b4829a6894afa1b32dc646734083684a Mon Sep 17 00:00:00 2001 From: Hadrien Croubois Date: Mon, 24 Feb 2025 20:48:12 +0100 Subject: [PATCH 2/5] changeset --- .changeset/fair-pumpkins-compete.md | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .changeset/fair-pumpkins-compete.md diff --git a/.changeset/fair-pumpkins-compete.md b/.changeset/fair-pumpkins-compete.md new file mode 100644 index 00000000000..1022617190f --- /dev/null +++ b/.changeset/fair-pumpkins-compete.md @@ -0,0 +1,5 @@ +--- +'openzeppelin-solidity': minor +--- + +`Math`: Add saturating arithmetic operations `saturatingAdd`, `saturatingSub` and `saturatingMul`. From 1a07c5a4566e89c82e706131bfe86c5b16b1c500 Mon Sep 17 00:00:00 2001 From: Hadrien Croubois Date: Mon, 24 Feb 2025 20:54:15 +0100 Subject: [PATCH 3/5] Update Math.sol --- contracts/utils/math/Math.sol | 112 ++++++++++++++++------------------ 1 file changed, 52 insertions(+), 60 deletions(-) diff --git a/contracts/utils/math/Math.sol b/contracts/utils/math/Math.sol index 101ec964216..11b0f90d60d 100644 --- a/contracts/utils/math/Math.sol +++ b/contracts/utils/math/Math.sol @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// OpenZeppelin Contracts (last updated v5.0.0) (utils/math/Math.sol) +// OpenZeppelin Contracts (last updated v5.1.0) (utils/math/Math.sol) pragma solidity ^0.8.20; @@ -176,7 +176,7 @@ library Math { function mulDiv(uint256 x, uint256 y, uint256 denominator) internal pure returns (uint256 result) { unchecked { // 512-bit multiply [prod1 prod0] = x * y. Compute the product mod 2²⁵⁶ and mod 2²⁵⁶ - 1, then use - // use the Chinese Remainder Theorem to reconstruct the 512 bit result. The result is stored in two 256 + // the Chinese Remainder Theorem to reconstruct the 512 bit result. The result is stored in two 256 // variables such that product = prod1 * 2²⁵⁶ + prod0. uint256 prod0 = x * y; // Least significant 256 bits of the product uint256 prod1; // Most significant 256 bits of the product @@ -264,7 +264,7 @@ library Math { /** * @dev Calculate the modular multiplicative inverse of a number in Z/nZ. * - * If n is a prime, then Z/nZ is a field. In that case all elements are inversible, expect 0. + * If n is a prime, then Z/nZ is a field. In that case all elements are inversible, except 0. * If n is not a prime, then Z/nZ is not a field, and some elements might not be inversible. * * If the input value is not inversible, 0 is returned. @@ -569,41 +569,45 @@ library Math { * @dev Return the log in base 2 of a positive value rounded towards zero. * Returns 0 if given 0. */ - function log2(uint256 value) internal pure returns (uint256) { - uint256 result = 0; - uint256 exp; - unchecked { - exp = 128 * SafeCast.toUint(value > (1 << 128) - 1); - value >>= exp; - result += exp; - - exp = 64 * SafeCast.toUint(value > (1 << 64) - 1); - value >>= exp; - result += exp; - - exp = 32 * SafeCast.toUint(value > (1 << 32) - 1); - value >>= exp; - result += exp; - - exp = 16 * SafeCast.toUint(value > (1 << 16) - 1); - value >>= exp; - result += exp; - - exp = 8 * SafeCast.toUint(value > (1 << 8) - 1); - value >>= exp; - result += exp; - - exp = 4 * SafeCast.toUint(value > (1 << 4) - 1); - value >>= exp; - result += exp; - - exp = 2 * SafeCast.toUint(value > (1 << 2) - 1); - value >>= exp; - result += exp; - - result += SafeCast.toUint(value > 1); + function log2(uint256 x) internal pure returns (uint256 r) { + // If value has upper 128 bits set, log2 result is at least 128 + r = SafeCast.toUint(x > 0xffffffffffffffffffffffffffffffff) << 7; + // If upper 64 bits of 128-bit half set, add 64 to result + r |= SafeCast.toUint((x >> r) > 0xffffffffffffffff) << 6; + // If upper 32 bits of 64-bit half set, add 32 to result + r |= SafeCast.toUint((x >> r) > 0xffffffff) << 5; + // If upper 16 bits of 32-bit half set, add 16 to result + r |= SafeCast.toUint((x >> r) > 0xffff) << 4; + // If upper 8 bits of 16-bit half set, add 8 to result + r |= SafeCast.toUint((x >> r) > 0xff) << 3; + // If upper 4 bits of 8-bit half set, add 4 to result + r |= SafeCast.toUint((x >> r) > 0xf) << 2; + + // Shifts value right by the current result and use it as an index into this lookup table: + // + // | x (4 bits) | index | table[index] = MSB position | + // |------------|---------|-----------------------------| + // | 0000 | 0 | table[0] = 0 | + // | 0001 | 1 | table[1] = 0 | + // | 0010 | 2 | table[2] = 1 | + // | 0011 | 3 | table[3] = 1 | + // | 0100 | 4 | table[4] = 2 | + // | 0101 | 5 | table[5] = 2 | + // | 0110 | 6 | table[6] = 2 | + // | 0111 | 7 | table[7] = 2 | + // | 1000 | 8 | table[8] = 3 | + // | 1001 | 9 | table[9] = 3 | + // | 1010 | 10 | table[10] = 3 | + // | 1011 | 11 | table[11] = 3 | + // | 1100 | 12 | table[12] = 3 | + // | 1101 | 13 | table[13] = 3 | + // | 1110 | 14 | table[14] = 3 | + // | 1111 | 15 | table[15] = 3 | + // + // The lookup table is represented as a 32-byte value with the MSB positions for 0-15 in the last 16 bytes. + assembly ("memory-safe") { + r := or(r, byte(shr(r, x), 0x0000010102020202030303030303030300000000000000000000000000000000)) } - return result; } /** @@ -672,29 +676,17 @@ library Math { * * Adding one to the result gives the number of pairs of hex symbols needed to represent `value` as a hex string. */ - function log256(uint256 value) internal pure returns (uint256) { - uint256 result = 0; - uint256 isGt; - unchecked { - isGt = SafeCast.toUint(value > (1 << 128) - 1); - value >>= isGt * 128; - result += isGt * 16; - - isGt = SafeCast.toUint(value > (1 << 64) - 1); - value >>= isGt * 64; - result += isGt * 8; - - isGt = SafeCast.toUint(value > (1 << 32) - 1); - value >>= isGt * 32; - result += isGt * 4; - - isGt = SafeCast.toUint(value > (1 << 16) - 1); - value >>= isGt * 16; - result += isGt * 2; - - result += SafeCast.toUint(value > (1 << 8) - 1); - } - return result; + function log256(uint256 x) internal pure returns (uint256 r) { + // If value has upper 128 bits set, log2 result is at least 128 + r = SafeCast.toUint(x > 0xffffffffffffffffffffffffffffffff) << 7; + // If upper 64 bits of 128-bit half set, add 64 to result + r |= SafeCast.toUint((x >> r) > 0xffffffffffffffff) << 6; + // If upper 32 bits of 64-bit half set, add 32 to result + r |= SafeCast.toUint((x >> r) > 0xffffffff) << 5; + // If upper 16 bits of 32-bit half set, add 16 to result + r |= SafeCast.toUint((x >> r) > 0xffff) << 4; + // Add 1 if upper 8 bits of 16-bit half set, and divide accumulated result by 8 + return (r >> 3) | SafeCast.toUint((x >> r) > 0xff); } /** From 00ba3366f858da4c776fe89dcf221641c0f58502 Mon Sep 17 00:00:00 2001 From: Hadrien Croubois Date: Wed, 26 Feb 2025 09:47:37 +0100 Subject: [PATCH 4/5] Apply suggestions from code review --- contracts/utils/math/Math.sol | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/contracts/utils/math/Math.sol b/contracts/utils/math/Math.sol index 11b0f90d60d..d6f16d1a042 100644 --- a/contracts/utils/math/Math.sol +++ b/contracts/utils/math/Math.sol @@ -62,7 +62,7 @@ library Math { unchecked { success = b > 0; assembly ("memory-safe") { - // In EVM any value divided by zero is zero. + // The `DIV` opcode returns zero when the denominator is 0. result := div(a, b) } } @@ -75,7 +75,7 @@ library Math { unchecked { success = b > 0; assembly ("memory-safe") { - // In EVM a value modulus zero is equal to zero. + // The `MOD` opcode returns zero when the denominator is 0. result := mod(a, b) } } From 260eb7b053d6eec6d560e68f163c7a052a0e274f Mon Sep 17 00:00:00 2001 From: Hadrien Croubois Date: Wed, 26 Feb 2025 14:19:30 +0100 Subject: [PATCH 5/5] Update Math.sol --- contracts/utils/math/Math.sol | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/contracts/utils/math/Math.sol b/contracts/utils/math/Math.sol index 8c99df5f5cc..46fb66562d9 100644 --- a/contracts/utils/math/Math.sol +++ b/contracts/utils/math/Math.sol @@ -110,7 +110,7 @@ library Math { } /** - * @dev Unsigned saturating addition, bounds to `2 ** 256 - 1` instead of overflowing. + * @dev Unsigned saturating addition, bounds to `2²⁵⁶ - 1` instead of overflowing. */ function saturatingAdd(uint256 a, uint256 b) internal pure returns (uint256) { (bool success, uint256 result) = tryAdd(a, b); @@ -126,7 +126,7 @@ library Math { } /** - * @dev Unsigned saturating multiplication, bounds to `2 ** 256 - 1` instead of overflowing. + * @dev Unsigned saturating multiplication, bounds to `2²⁵⁶ - 1` instead of overflowing. */ function saturatingMul(uint256 a, uint256 b) internal pure returns (uint256) { (bool success, uint256 result) = tryMul(a, b);