Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/fair-pumpkins-compete.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'openzeppelin-solidity': minor
---

`Math`: Add saturating arithmetic operations `saturatingAdd`, `saturatingSub` and `saturatingMul`.
64 changes: 48 additions & 16 deletions contracts/utils/math/Math.sol
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ library Math {
function tryAdd(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
unchecked {
uint256 c = a + b;
if (c < a) return (false, 0);
return (true, c);
success = c >= a;
result = c * SafeCast.toUint(success);
}
}

Expand All @@ -61,8 +61,9 @@ library Math {
*/
function trySub(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
unchecked {
if (b > a) return (false, 0);
return (true, a - b);
uint256 c = a - b;
success = c <= a;
result = c * SafeCast.toUint(success);
}
}

Expand All @@ -71,13 +72,14 @@ library Math {
*/
function tryMul(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
unchecked {
// Gas optimization: this is cheaper than requiring 'a' not being zero, but the
// benefit is lost if 'b' is also tested.
// See: https://github.com/OpenZeppelin/openzeppelin-contracts/pull/522
if (a == 0) return (true, 0);
uint256 c = a * b;
if (c / a != b) return (false, 0);
return (true, c);
assembly ("memory-safe") {
// Only true when the multiplication doesn't overflow
// (c / a == b) || (a == 0)
success := or(eq(div(c, a), b), iszero(a))
}
// equivalent to: success ? c : 0
result = c * SafeCast.toUint(success);
}
}

Expand All @@ -86,8 +88,11 @@ library Math {
*/
function tryDiv(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
unchecked {
if (b == 0) return (false, 0);
return (true, a / b);
success = b > 0;
assembly ("memory-safe") {
// The `DIV` opcode returns zero when the denominator is 0.
result := div(a, b)
}
}
}

Expand All @@ -96,11 +101,38 @@ library Math {
*/
function tryMod(uint256 a, uint256 b) internal pure returns (bool success, uint256 result) {
unchecked {
if (b == 0) return (false, 0);
return (true, a % b);
success = b > 0;
assembly ("memory-safe") {
// The `MOD` opcode returns zero when the denominator is 0.
result := mod(a, b)
}
}
}

/**
* @dev Unsigned saturating addition, bounds to `2²⁵⁶ - 1` instead of overflowing.
*/
function saturatingAdd(uint256 a, uint256 b) internal pure returns (uint256) {
(bool success, uint256 result) = tryAdd(a, b);
return ternary(success, result, type(uint256).max);
}

/**
* @dev Unsigned saturating subtraction, bounds to zero instead of overflowing.
*/
function saturatingSub(uint256 a, uint256 b) internal pure returns (uint256) {
(, uint256 result) = trySub(a, b);
return result;
}

/**
* @dev Unsigned saturating multiplication, bounds to `2²⁵⁶ - 1` instead of overflowing.
*/
function saturatingMul(uint256 a, uint256 b) internal pure returns (uint256) {
(bool success, uint256 result) = tryMul(a, b);
return ternary(success, result, type(uint256).max);
}

/**
* @dev Branchless ternary evaluation for `a ? b : c`. Gas costs are constant.
*
Expand Down Expand Up @@ -192,7 +224,7 @@ library Math {

// Make division exact by subtracting the remainder from [high low].
uint256 remainder;
assembly {
assembly ("memory-safe") {
// Compute remainder using mulmod.
remainder := mulmod(x, y, denominator)

Expand All @@ -205,7 +237,7 @@ library Math {
// Always >= 1. See https://cs.stackexchange.com/q/138556/92363.

uint256 twos = denominator & (0 - denominator);
assembly {
assembly ("memory-safe") {
// Divide denominator by twos.
denominator := div(denominator, twos)

Expand Down
56 changes: 56 additions & 0 deletions test/utils/math/Math.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,62 @@ describe('Math', function () {
});
});

describe('saturatingAdd', function () {
it('adds correctly', async function () {
const a = 5678n;
const b = 1234n;
await testCommutative(this.mock.$saturatingAdd, a, b, a + b);
await testCommutative(this.mock.$saturatingAdd, a, 0n, a);
await testCommutative(this.mock.$saturatingAdd, ethers.MaxUint256, 0n, ethers.MaxUint256);
});

it('bounds on addition overflow', async function () {
await testCommutative(this.mock.$saturatingAdd, ethers.MaxUint256, 1n, ethers.MaxUint256);
await expect(this.mock.$saturatingAdd(ethers.MaxUint256, ethers.MaxUint256)).to.eventually.equal(
ethers.MaxUint256,
);
});
});

describe('saturatingSub', function () {
it('subtracts correctly', async function () {
const a = 5678n;
const b = 1234n;
await expect(this.mock.$saturatingSub(a, b)).to.eventually.equal(a - b);
await expect(this.mock.$saturatingSub(a, a)).to.eventually.equal(0n);
await expect(this.mock.$saturatingSub(a, 0n)).to.eventually.equal(a);
await expect(this.mock.$saturatingSub(0n, a)).to.eventually.equal(0n);
await expect(this.mock.$saturatingSub(ethers.MaxUint256, 1n)).to.eventually.equal(ethers.MaxUint256 - 1n);
});

it('bounds on subtraction overflow', async function () {
await expect(this.mock.$saturatingSub(0n, 1n)).to.eventually.equal(0n);
await expect(this.mock.$saturatingSub(1n, 2n)).to.eventually.equal(0n);
await expect(this.mock.$saturatingSub(1n, ethers.MaxUint256)).to.eventually.equal(0n);
await expect(this.mock.$saturatingSub(ethers.MaxUint256 - 1n, ethers.MaxUint256)).to.eventually.equal(0n);
});
});

describe('saturatingMul', function () {
it('multiplies correctly', async function () {
const a = 1234n;
const b = 5678n;
await testCommutative(this.mock.$saturatingMul, a, b, a * b);
});

it('multiplies by zero correctly', async function () {
const a = 0n;
const b = 5678n;
await testCommutative(this.mock.$saturatingMul, a, b, 0n);
});

it('bounds on multiplication overflow', async function () {
const a = ethers.MaxUint256;
const b = 2n;
await testCommutative(this.mock.$saturatingMul, a, b, ethers.MaxUint256);
});
});

describe('max', function () {
it('is correctly detected in both position', async function () {
await testCommutative(this.mock.$max, 1234n, 5678n, max(1234n, 5678n));
Expand Down
Loading