diff --git a/.changeset/fair-pumpkins-compete.md b/.changeset/fair-pumpkins-compete.md new file mode 100644 index 00000000000..1022617190f --- /dev/null +++ b/.changeset/fair-pumpkins-compete.md @@ -0,0 +1,5 @@ +--- +'openzeppelin-solidity': minor +--- + +`Math`: Add saturating arithmetic operations `saturatingAdd`, `saturatingSub` and `saturatingMul`. diff --git a/contracts/utils/math/Math.sol b/contracts/utils/math/Math.sol index 045d310cf11..46fb66562d9 100644 --- a/contracts/utils/math/Math.sol +++ b/contracts/utils/math/Math.sol @@ -51,8 +51,8 @@ library Math { function tryAdd(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) { unchecked { uint256 c = a + b; - if (c < a) return (false, 0); - return (true, c); + success = c >= a; + result = c * SafeCast.toUint(success); } } @@ -61,8 +61,9 @@ library Math { */ function trySub(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) { unchecked { - if (b > a) return (false, 0); - return (true, a - b); + uint256 c = a - b; + success = c <= a; + result = c * SafeCast.toUint(success); } } @@ -71,13 +72,14 @@ library Math { */ function tryMul(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) { unchecked { - // Gas optimization: this is cheaper than requiring 'a' not being zero, but the - // benefit is lost if 'b' is also tested. - // See: https://github.com/OpenZeppelin/openzeppelin-contracts/pull/522 - if (a == 0) return (true, 0); uint256 c = a * b; - if (c / a != b) return (false, 0); - return (true, c); + assembly ("memory-safe") { + // Only true when the multiplication doesn't overflow + // (c / a == b) || (a == 0) + success := or(eq(div(c, a), b), iszero(a)) + } + // equivalent to: success ? c : 0 + result = c * SafeCast.toUint(success); } } @@ -86,8 +88,11 @@ library Math { */ function tryDiv(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) { unchecked { - if (b == 0) return (false, 0); - return (true, a / b); + success = b > 0; + assembly ("memory-safe") { + // The `DIV` opcode returns zero when the denominator is 0. + result := div(a, b) + } } } @@ -96,11 +101,38 @@ library Math { */ function tryMod(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) { unchecked { - if (b == 0) return (false, 0); - return (true, a % b); + success = b > 0; + assembly ("memory-safe") { + // The `MOD` opcode returns zero when the denominator is 0. + result := mod(a, b) + } } } + /** + * @dev Unsigned saturating addition, bounds to `2²⁵⁶ - 1` instead of overflowing. + */ + function saturatingAdd(uint256 a, uint256 b) internal pure returns (uint256) { + (bool success, uint256 result) = tryAdd(a, b); + return ternary(success, result, type(uint256).max); + } + + /** + * @dev Unsigned saturating subtraction, bounds to zero instead of overflowing. + */ + function saturatingSub(uint256 a, uint256 b) internal pure returns (uint256) { + (, uint256 result) = trySub(a, b); + return result; + } + + /** + * @dev Unsigned saturating multiplication, bounds to `2²⁵⁶ - 1` instead of overflowing. + */ + function saturatingMul(uint256 a, uint256 b) internal pure returns (uint256) { + (bool success, uint256 result) = tryMul(a, b); + return ternary(success, result, type(uint256).max); + } + /** * @dev Branchless ternary evaluation for `a ? b : c`. Gas costs are constant. * @@ -192,7 +224,7 @@ library Math { // Make division exact by subtracting the remainder from [high low]. uint256 remainder; - assembly { + assembly ("memory-safe") { // Compute remainder using mulmod. remainder := mulmod(x, y, denominator) @@ -205,7 +237,7 @@ library Math { // Always >= 1. See https://cs.stackexchange.com/q/138556/92363. uint256 twos = denominator & (0 - denominator); - assembly { + assembly ("memory-safe") { // Divide denominator by twos. denominator := div(denominator, twos) diff --git a/test/utils/math/Math.test.js b/test/utils/math/Math.test.js index b2d7cd7ea2f..6a09938148a 100644 --- a/test/utils/math/Math.test.js +++ b/test/utils/math/Math.test.js @@ -168,6 +168,62 @@ describe('Math', function () { }); }); + describe('saturatingAdd', function () { + it('adds correctly', async function () { + const a = 5678n; + const b = 1234n; + await testCommutative(this.mock.$saturatingAdd, a, b, a + b); + await testCommutative(this.mock.$saturatingAdd, a, 0n, a); + await testCommutative(this.mock.$saturatingAdd, ethers.MaxUint256, 0n, ethers.MaxUint256); + }); + + it('bounds on addition overflow', async function () { + await testCommutative(this.mock.$saturatingAdd, ethers.MaxUint256, 1n, ethers.MaxUint256); + await expect(this.mock.$saturatingAdd(ethers.MaxUint256, ethers.MaxUint256)).to.eventually.equal( + ethers.MaxUint256, + ); + }); + }); + + describe('saturatingSub', function () { + it('subtracts correctly', async function () { + const a = 5678n; + const b = 1234n; + await expect(this.mock.$saturatingSub(a, b)).to.eventually.equal(a - b); + await expect(this.mock.$saturatingSub(a, a)).to.eventually.equal(0n); + await expect(this.mock.$saturatingSub(a, 0n)).to.eventually.equal(a); + await expect(this.mock.$saturatingSub(0n, a)).to.eventually.equal(0n); + await expect(this.mock.$saturatingSub(ethers.MaxUint256, 1n)).to.eventually.equal(ethers.MaxUint256 - 1n); + }); + + it('bounds on subtraction overflow', async function () { + await expect(this.mock.$saturatingSub(0n, 1n)).to.eventually.equal(0n); + await expect(this.mock.$saturatingSub(1n, 2n)).to.eventually.equal(0n); + await expect(this.mock.$saturatingSub(1n, ethers.MaxUint256)).to.eventually.equal(0n); + await expect(this.mock.$saturatingSub(ethers.MaxUint256 - 1n, ethers.MaxUint256)).to.eventually.equal(0n); + }); + }); + + describe('saturatingMul', function () { + it('multiplies correctly', async function () { + const a = 1234n; + const b = 5678n; + await testCommutative(this.mock.$saturatingMul, a, b, a * b); + }); + + it('multiplies by zero correctly', async function () { + const a = 0n; + const b = 5678n; + await testCommutative(this.mock.$saturatingMul, a, b, 0n); + }); + + it('bounds on multiplication overflow', async function () { + const a = ethers.MaxUint256; + const b = 2n; + await testCommutative(this.mock.$saturatingMul, a, b, ethers.MaxUint256); + }); + }); + describe('max', function () { it('is correctly detected in both position', async function () { await testCommutative(this.mock.$max, 1234n, 5678n, max(1234n, 5678n));