diff --git a/scripts/foundry/DeployL1BridgeContracts.s.sol b/scripts/foundry/DeployL1BridgeContracts.s.sol index 4ae847f0..81eb5d11 100644 --- a/scripts/foundry/DeployL1BridgeContracts.s.sol +++ b/scripts/foundry/DeployL1BridgeContracts.s.sol @@ -156,7 +156,8 @@ contract DeployL1BridgeContracts is Script { L2_SCROLL_MESSENGER_PROXY_ADDR, L1_SCROLL_CHAIN_PROXY_ADDR, L1_MESSAGE_QUEUE_V1_PROXY_ADDR, - L1_MESSAGE_QUEUE_V2_PROXY_ADDR + L1_MESSAGE_QUEUE_V2_PROXY_ADDR, + address(enforcedTxGateway) ); logAddress("L1_SCROLL_MESSENGER_IMPLEMENTATION_ADDR", address(impl)); diff --git a/src/L1/L1ScrollMessenger.sol b/src/L1/L1ScrollMessenger.sol index a7452921..c9700f3a 100644 --- a/src/L1/L1ScrollMessenger.sol +++ b/src/L1/L1ScrollMessenger.sol @@ -53,6 +53,9 @@ contract L1ScrollMessenger is ScrollMessengerBase, IL1ScrollMessenger { /// @notice The address of L1MessageQueueV2 contract. address public immutable messageQueueV2; + /// @notice The address of `EnforcedTxGateway`. + address public immutable enforcedTxGateway; + /*********** * Structs * ***********/ @@ -110,17 +113,15 @@ contract L1ScrollMessenger is ScrollMessengerBase, IL1ScrollMessenger { address _counterpart, address _rollup, address _messageQueueV1, - address _messageQueueV2 + address _messageQueueV2, + address _enforcedTxGateway ) ScrollMessengerBase(_counterpart) { - if (_rollup == address(0) || _messageQueueV1 == address(0) || _messageQueueV2 == address(0)) { - revert ErrorZeroAddress(); - } - _disableInitializers(); rollup = _rollup; messageQueueV1 = _messageQueueV1; messageQueueV2 = _messageQueueV2; + enforcedTxGateway = _enforcedTxGateway; } /// @notice Initialize the storage of L1ScrollMessenger. @@ -193,7 +194,7 @@ contract L1ScrollMessenger is ScrollMessengerBase, IL1ScrollMessenger { } // @note check more `_to` address to avoid attack in the future when we add more gateways. - if (_to == messageQueueV1 || _to == messageQueueV2) { + if (_to == messageQueueV1 || _to == messageQueueV2 || _to == enforcedTxGateway) { revert ErrorForbidToCallMessageQueue(); } _validateTargetAddress(_to); diff --git a/src/test/L1GatewayTestBase.t.sol b/src/test/L1GatewayTestBase.t.sol index fd43715a..26618168 100644 --- a/src/test/L1GatewayTestBase.t.sol +++ b/src/test/L1GatewayTestBase.t.sol @@ -121,7 +121,8 @@ abstract contract L1GatewayTestBase is ScrollTestBase { address(l2Messenger), address(rollup), address(messageQueueV1), - address(messageQueueV2) + address(messageQueueV2), + address(enforcedTxGateway) ) ) ); diff --git a/src/test/L1ScrollMessengerTest.t.sol b/src/test/L1ScrollMessengerTest.t.sol index 50544681..2f911bf0 100644 --- a/src/test/L1ScrollMessengerTest.t.sol +++ b/src/test/L1ScrollMessengerTest.t.sol @@ -62,6 +62,26 @@ contract L1ScrollMessengerTest is L1GatewayTestBase { l1Messenger.relayMessageWithProof(address(this), address(messageQueueV2), 0, 0, new bytes(0), proof); } + function testForbidCallEnforcedGatewayFromL2() external { + bytes32 _xDomainCalldataHash = keccak256( + abi.encodeWithSignature( + "relayMessage(address,address,uint256,uint256,bytes)", + address(this), + address(enforcedTxGateway), + 0, + 0, + new bytes(0) + ) + ); + prepareL2MessageRoot(_xDomainCalldataHash); + + IL1ScrollMessenger.L2MessageProof memory proof; + proof.batchIndex = rollup.lastFinalizedBatchIndex(); + + hevm.expectRevert(L1ScrollMessenger.ErrorForbidToCallMessageQueue.selector); + l1Messenger.relayMessageWithProof(address(this), address(enforcedTxGateway), 0, 0, new bytes(0), proof); + } + function testForbidCallSelfFromL2() external { bytes32 _xDomainCalldataHash = keccak256( abi.encodeWithSignature( diff --git a/src/test/L2ScrollMessenger.t.sol b/src/test/L2ScrollMessenger.t.sol index b0c8c2f9..2bdb6d5d 100644 --- a/src/test/L2ScrollMessenger.t.sol +++ b/src/test/L2ScrollMessenger.t.sol @@ -28,7 +28,7 @@ contract L2ScrollMessengerTest is DSTestPlus { function setUp() public { // Deploy L1 contracts - l1Messenger = new L1ScrollMessenger(address(1), address(1), address(1), address(1)); + l1Messenger = new L1ScrollMessenger(address(1), address(1), address(1), address(1), address(1)); // Deploy L2 contracts whitelist = new Whitelist(address(this));