diff --git a/modules/engine/src/errors.ts b/modules/engine/src/errors.ts index d635865ac..43525bdf3 100644 --- a/modules/engine/src/errors.ts +++ b/modules/engine/src/errors.ts @@ -46,36 +46,6 @@ export class CheckInError extends EngineError { } } -export class RestoreError extends EngineError { - static readonly type = "RestoreError"; - - static readonly reasons = { - AckFailed: "Could not send restore ack", - AcquireLockError: "Failed to acquire restore lock", - ChannelNotFound: "Channel not found", - CouldNotGetActiveTransfers: "Failed to retrieve active transfers from store", - CouldNotGetChannel: "Failed to retrieve channel from store", - GetChannelAddressFailed: "Failed to calculate channel address for verification", - InvalidChannelAddress: "Failed to verify channel address", - InvalidMerkleRoot: "Failed to validate merkleRoot for restoration", - InvalidSignatures: "Failed to validate sigs on latestUpdate", - NoData: "No data sent from counterparty to restore", - ReceivedError: "Got restore error from counterparty", - ReleaseLockError: "Failed to release restore lock", - SaveChannelFailed: "Failed to save channel state", - SyncableState: "Cannot restore, state is syncable. Try reconcileDeposit", - } as const; - - constructor( - public readonly message: Values, - channelAddress: string, - publicIdentifier: string, - context: any = {}, - ) { - super(message, channelAddress, publicIdentifier, context, RestoreError.type); - } -} - export class IsAliveError extends EngineError { static readonly type = "IsAliveError"; diff --git a/modules/engine/src/index.ts b/modules/engine/src/index.ts index f62f679d0..d061b1417 100644 --- a/modules/engine/src/index.ts +++ b/modules/engine/src/index.ts @@ -1,3 +1,4 @@ +import { WithdrawCommitment } from "@connext/vector-contracts"; import { Vector } from "@connext/vector-protocol"; import { ChainAddresses, @@ -19,31 +20,20 @@ import { IExternalValidation, AUTODEPLOY_CHAIN_IDS, EngineError, - UpdateType, - Values, VectorError, jsonifyError, MinimalTransaction, WITHDRAWAL_RESOLVED_EVENT, VectorErrorJson, - ProtocolError, } from "@connext/vector-types"; -import { - generateMerkleRoot, - validateChannelUpdateSignatures, - getSignerAddressFromPublicIdentifier, - getRandomBytes32, - getParticipant, - hashWithdrawalQuote, - delay, -} from "@connext/vector-utils"; +import { getRandomBytes32, getParticipant, hashWithdrawalQuote, delay } from "@connext/vector-utils"; import pino from "pino"; import Ajv from "ajv"; import { Evt } from "evt"; import { version } from "../package.json"; -import { DisputeError, IsAliveError, RestoreError, RpcError } from "./errors"; +import { DisputeError, IsAliveError, RpcError } from "./errors"; import { convertConditionalTransferParams, convertResolveConditionParams, @@ -53,8 +43,6 @@ import { import { setupEngineListeners } from "./listeners"; import { getEngineEvtContainer, withdrawRetryForTransferId, addTransactionToCommitment } from "./utils"; import { sendIsAlive } from "./isAlive"; -import { WithdrawCommitment } from "@connext/vector-contracts"; -import { FullChannelState } from "../../types/dist/src"; export const ajv = new Ajv(); @@ -578,10 +566,7 @@ export class VectorEngine implements IVectorEngine { if (setupParamsResult.isError) { return Result.fail(setupParamsResult.getError()!); } - const setupRes = await this.runProtocolMethodWithRetries( - () => this.vector.setup(setupParamsResult.getValue()), - "", - ); + const setupRes = await this.vector.setup(setupParamsResult.getValue()); if (setupRes.isError) { return Result.fail(setupRes.getError()!); @@ -683,10 +668,30 @@ export class VectorEngine implements IVectorEngine { // leaving all 8 out of the channel. // This race condition should be handled by the protocol retries - const depositRes = await this.runProtocolMethodWithRetries( - () => this.vector.deposit(params), - params.channelAddress, - ); + const timeout = 500; + let depositRes = await this.vector.deposit(params); + let count = 1; + for (const _ of Array(3).fill(0)) { + // If its not an error, do not retry + if (!depositRes.isError) { + break; + } + const error = depositRes.getError()!; + // IFF deposit fails because you or the counterparty fails to recover + // signatures, retry + // This should be the message from *.reasons.BadSignatures in the protocol + // errors + const recoveryErr = "Could not recover signers"; + const recoveryFailed = error.message === recoveryErr || error.context?.counterpartyError?.message === recoveryErr; + + if (!recoveryFailed) { + break; + } + this.logger.warn({ attempt: count, channelAddress: params.channelAddress }, "Retrying deposit reconciliation"); + depositRes = await this.vector.deposit(params); + count++; + await delay(timeout); + } this.logger.info( { result: depositRes.isError ? jsonifyError(depositRes.getError()!) : depositRes.getValue().channelAddress, @@ -786,10 +791,7 @@ export class VectorEngine implements IVectorEngine { } const createParams = createResult.getValue(); this.logger.info({ transferParams: createParams, method, methodId }, "Created conditional transfer params"); - const protocolRes = await this.runProtocolMethodWithRetries( - () => this.vector.create(createParams), - createParams.channelAddress, - ); + const protocolRes = await this.vector.create(createParams); if (protocolRes.isError) { return Result.fail(protocolRes.getError()!); } @@ -835,10 +837,7 @@ export class VectorEngine implements IVectorEngine { return Result.fail(resolveResult.getError()!); } const resolveParams = resolveResult.getValue(); - const protocolRes = await this.runProtocolMethodWithRetries( - () => this.vector.resolve(resolveParams), - resolveParams.channelAddress, - ); + const protocolRes = await this.vector.resolve(resolveParams); if (protocolRes.isError) { return Result.fail(protocolRes.getError()!); } @@ -902,10 +901,7 @@ export class VectorEngine implements IVectorEngine { ); // create withdrawal transfer - const protocolRes = await this.runProtocolMethodWithRetries( - () => this.vector.create(createParams), - createParams.channelAddress, - ); + const protocolRes = await this.vector.create(createParams); if (protocolRes.isError) { return Result.fail(protocolRes.getError()!); } @@ -1195,119 +1191,25 @@ export class VectorEngine implements IVectorEngine { ); } - // Send message to counterparty, they will grab lock and - // return information under lock, initiator will update channel, - // then send confirmation message to counterparty, who will release the lock - const { chainId, counterpartyIdentifier } = params; - const restoreDataRes = await this.messaging.sendRestoreStateMessage( - Result.ok({ chainId }), - counterpartyIdentifier, - this.signer.publicIdentifier, - ); - if (restoreDataRes.isError) { - return Result.fail(restoreDataRes.getError()!); + // Request protocol restore + const restoreResult = await this.vector.restoreState(params); + if (restoreResult.isError) { + return Result.fail(restoreResult.getError()!); } - const { channel, activeTransfers } = restoreDataRes.getValue() ?? ({} as any); - - // Create helper to generate error - const generateRestoreError = ( - error: Values, - context: any = {}, - ): Result => { - // handle error by returning it to counterparty && returning result - const err = new RestoreError(error, channel?.channelAddress ?? "", this.publicIdentifier, { - ...context, - method, - params, - }); - return Result.fail(err); - }; - - // Verify data exists - if (!channel || !activeTransfers) { - return generateRestoreError(RestoreError.reasons.NoData); - } - - // Verify channel address is same as calculated - const counterparty = getSignerAddressFromPublicIdentifier(counterpartyIdentifier); - const calculated = await this.chainService.getChannelAddress( - channel.alice === this.signer.address ? this.signer.address : counterparty, - channel.bob === this.signer.address ? this.signer.address : counterparty, - channel.networkContext.channelFactoryAddress, - chainId, - ); - if (calculated.isError) { - return generateRestoreError(RestoreError.reasons.GetChannelAddressFailed, { - getChannelAddressError: jsonifyError(calculated.getError()!), - }); - } - if (calculated.getValue() !== channel.channelAddress) { - return generateRestoreError(RestoreError.reasons.InvalidChannelAddress, { - calculated: calculated.getValue(), - }); - } - - // Verify signatures on latest update - const sigRes = await validateChannelUpdateSignatures( - channel, - channel.latestUpdate.aliceSignature, - channel.latestUpdate.bobSignature, - "both", - ); - if (sigRes.isError) { - return generateRestoreError(RestoreError.reasons.InvalidSignatures, { - recoveryError: sigRes.getError().message, - }); - } - - // Verify transfers match merkleRoot - const root = generateMerkleRoot(activeTransfers); - if (root !== channel.merkleRoot) { - return generateRestoreError(RestoreError.reasons.InvalidMerkleRoot, { - calculated: root, - merkleRoot: channel.merkleRoot, - activeTransfers: activeTransfers.map((t) => t.transferId), - }); - } - - // Verify nothing with a sync-able nonce exists in store - const existing = await this.getChannelState({ channelAddress: channel.channelAddress }); - if (existing.isError) { - return generateRestoreError(RestoreError.reasons.CouldNotGetChannel, { - getChannelStateError: jsonifyError(existing.getError()!), - }); - } - const nonce = existing.getValue()?.nonce ?? 0; - const diff = channel.nonce - nonce; - if (diff <= 1 && channel.latestUpdate.type !== UpdateType.setup) { - return generateRestoreError(RestoreError.reasons.SyncableState, { - existing: nonce, - toRestore: channel.nonce, - }); - } - - // Save channel - try { - await this.store.saveChannelStateAndTransfers(channel, activeTransfers); - } catch (e) { - return generateRestoreError(RestoreError.reasons.SaveChannelFailed, { - saveChannelStateAndTransfersError: e.message, - }); - } + const channel = restoreResult.getValue(); // Post to evt this.evts[EngineEvents.RESTORE_STATE_EVENT].post({ channelAddress: channel.channelAddress, aliceIdentifier: channel.aliceIdentifier, bobIdentifier: channel.bobIdentifier, - chainId, + chainId: channel.networkContext.chainId, }); this.logger.info( { - channel, - transfers: activeTransfers.map((t) => t.transferId), + channel: channel.channelAddress, method, methodId, }, @@ -1585,24 +1487,6 @@ export class VectorEngine implements IVectorEngine { } } - private async runProtocolMethodWithRetries( - fn: () => Promise>, - channelAddress: string, - retryCount = 5, - ) { - const result = await fn(); - // let result: Result | undefined; - // for (let i = 0; i < retryCount; i++) { - // result = await fn(); - // if (!result.isError) { - // return result; - // } - // this.logger.warn({ attempt: i, error: result.getError().message, channelAddress }, "Protocol method failed"); - // await delay(500); - // } - return result as Result; - } - // JSON RPC interface -- this will accept: // - "chan_deposit" // - "chan_createTransfer" diff --git a/modules/engine/src/listeners.ts b/modules/engine/src/listeners.ts index 9df66dfef..502af2bd9 100644 --- a/modules/engine/src/listeners.ts +++ b/modules/engine/src/listeners.ts @@ -44,7 +44,7 @@ import { BigNumber } from "@ethersproject/bignumber"; import { Zero } from "@ethersproject/constants"; import Pino, { BaseLogger } from "pino"; -import { IsAliveError, RestoreError, WithdrawQuoteError } from "./errors"; +import { IsAliveError, WithdrawQuoteError } from "./errors"; import { EngineEvtContainer } from "./index"; import { submitUnsubmittedWithdrawals } from "./utils"; @@ -169,79 +169,6 @@ export async function setupEngineListeners( }, ); - await messaging.onReceiveRestoreStateMessage( - signer.publicIdentifier, - async (restoreData: Result<{ chainId: number }, EngineError>, from: string, inbox: string) => { - // If it is from yourself, do nothing - if (from === signer.publicIdentifier) { - return; - } - const method = "onReceiveRestoreStateMessage"; - logger.warn({ method, data: restoreData.toJson(), inbox }, "Handling message"); - - // Received error from counterparty - if (restoreData.isError) { - logger.error({ message: restoreData.getError()!.message, method }, "Error received from counterparty restore"); - return; - } - - const data = restoreData.getValue(); - const [key] = Object.keys(data ?? []); - if (key !== "chainId") { - logger.error({ data }, "Message malformed"); - return; - } - - // Counterparty looking to initiate a restore - let channel: FullChannelState | undefined; - const sendCannotRestoreFromError = (error: Values, context: any = {}) => { - return messaging.respondToRestoreStateMessage( - inbox, - Result.fail( - new RestoreError(error, channel?.channelAddress ?? "", signer.publicIdentifier, { ...context, method }), - ), - ); - }; - - // Get info from store to send to counterparty - const { chainId } = data as any; - try { - channel = await store.getChannelStateByParticipants(signer.publicIdentifier, from, chainId); - } catch (e) { - return sendCannotRestoreFromError(RestoreError.reasons.CouldNotGetChannel, { - storeMethod: "getChannelStateByParticipants", - chainId, - identifiers: [signer.publicIdentifier, from], - }); - } - if (!channel) { - return sendCannotRestoreFromError(RestoreError.reasons.ChannelNotFound, { chainId }); - } - let activeTransfers: FullTransferState[]; - try { - activeTransfers = await store.getActiveTransfers(channel.channelAddress); - } catch (e) { - return sendCannotRestoreFromError(RestoreError.reasons.CouldNotGetActiveTransfers, { - storeMethod: "getActiveTransfers", - chainId, - channelAddress: channel.channelAddress, - }); - } - - // Send info to counterparty - logger.warn( - { - method, - channel: channel.channelAddress, - nonce: channel.nonce, - activeTransfers: activeTransfers.map((a) => a.transferId), - }, - "Sending counterparty state to sync", - ); - await messaging.respondToRestoreStateMessage(inbox, Result.ok({ channel, activeTransfers })); - }, - ); - await messaging.onReceiveIsAliveMessage( signer.publicIdentifier, async ( diff --git a/modules/engine/src/testing/listeners.spec.ts b/modules/engine/src/testing/listeners.spec.ts index 5031c9f84..708b55880 100644 --- a/modules/engine/src/testing/listeners.spec.ts +++ b/modules/engine/src/testing/listeners.spec.ts @@ -100,8 +100,6 @@ describe(testName, () => { let store: Sinon.SinonStubbedInstance; let chainService: Sinon.SinonStubbedInstance; let messaging: Sinon.SinonStubbedInstance; - let acquireRestoreLockStub: Sinon.SinonStub; - let releaseRestoreLockStub: Sinon.SinonStub; // Create an EVT to post to, that can be aliased as a // vector instance @@ -131,10 +129,6 @@ describe(testName, () => { vector = Sinon.createStubInstance(Vector); messaging = Sinon.createStubInstance(MemoryMessagingService); vector.on = on as any; - - // By default acquire/release for restore succeeds - acquireRestoreLockStub = Sinon.stub().resolves(Result.ok(undefined)); - releaseRestoreLockStub = Sinon.stub().resolves(Result.ok(undefined)); }); afterEach(() => { diff --git a/modules/engine/src/testing/utils.spec.ts b/modules/engine/src/testing/utils.spec.ts index 180edd006..a2352c450 100644 --- a/modules/engine/src/testing/utils.spec.ts +++ b/modules/engine/src/testing/utils.spec.ts @@ -59,8 +59,6 @@ describe(testName, () => { let store: Sinon.SinonStubbedInstance; let chainService: Sinon.SinonStubbedInstance; let messaging: Sinon.SinonStubbedInstance; - let acquireRestoreLockStub: Sinon.SinonStub; - let releaseRestoreLockStub: Sinon.SinonStub; let withdrawRetryForTrasferIdStub: Sinon.SinonStub; // Create an EVT to post to, that can be aliased as a @@ -92,9 +90,6 @@ describe(testName, () => { messaging = Sinon.createStubInstance(MemoryMessagingService); vector.on = on as any; - // By default acquire/release for restore succeeds - acquireRestoreLockStub = Sinon.stub().resolves(Result.ok(undefined)); - releaseRestoreLockStub = Sinon.stub().resolves(Result.ok(undefined)); withdrawRetryForTrasferIdStub = Sinon.stub(utils, "withdrawRetryForTransferId"); }); diff --git a/modules/protocol/src/errors.ts b/modules/protocol/src/errors.ts index 05b4c3749..078714fd8 100644 --- a/modules/protocol/src/errors.ts +++ b/modules/protocol/src/errors.ts @@ -11,6 +11,36 @@ import { Result, } from "@connext/vector-types"; +export class RestoreError extends ProtocolError { + static readonly type = "RestoreError"; + + static readonly reasons = { + AckFailed: "Could not send restore ack", + AcquireLockError: "Failed to acquire restore lock", + ChannelNotFound: "Channel not found", + CouldNotGetActiveTransfers: "Failed to retrieve active transfers from store", + CouldNotGetChannel: "Failed to retrieve channel from store", + GetChannelAddressFailed: "Failed to calculate channel address for verification", + InvalidChannelAddress: "Failed to verify channel address", + InvalidMerkleRoot: "Failed to validate merkleRoot for restoration", + InvalidSignatures: "Failed to validate sigs on latestUpdate", + NoData: "No data sent from counterparty to restore", + ReceivedError: "Got restore error from counterparty", + ReleaseLockError: "Failed to release restore lock", + SaveChannelFailed: "Failed to save channel state", + SyncableState: "Cannot restore, state is syncable. Try reconcileDeposit", + } as const; + + constructor( + public readonly message: Values, + channel: FullChannelState, + publicIdentifier: string, + context: any = {}, + ) { + super(message, channel, undefined, undefined, { publicIdentifier, ...context }, RestoreError.type); + } +} + export class ValidationError extends ProtocolError { static readonly type = "ValidationError"; @@ -114,6 +144,7 @@ export class QueuedUpdateError extends ProtocolError { Cancelled: "Queued update was cancelled", CannotSyncSetup: "Cannot sync a setup update, must restore", // TODO: remove ChannelNotFound: "Channel not found", + ChannelRestoring: "Channel is restoring, cannot update", CouldNotGetParams: "Could not generate params from update", CouldNotGetResolvedBalance: "Could not retrieve resolved balance from chain", CounterpartyFailure: "Counterparty failed to apply update", diff --git a/modules/protocol/src/testing/integration/restore.spec.ts b/modules/protocol/src/testing/integration/restore.spec.ts new file mode 100644 index 000000000..72a5b7264 --- /dev/null +++ b/modules/protocol/src/testing/integration/restore.spec.ts @@ -0,0 +1,92 @@ +import { delay, expect, getTestLoggers } from "@connext/vector-utils"; +import { FullChannelState, IChannelSigner, IVectorProtocol, IVectorStore, Result } from "@connext/vector-types"; +import { AddressZero } from "@ethersproject/constants"; + +import { createTransfer, getFundedChannel } from "../utils"; +import { env } from "../env"; +import { QueuedUpdateError } from "../../errors"; + +const testName = "Restore Integrations"; +const { log } = getTestLoggers(testName, env.logLevel); + +describe(testName, () => { + let alice: IVectorProtocol; + let bob: IVectorProtocol; + + let abChannelAddress: string; + let aliceSigner: IChannelSigner; + let aliceStore: IVectorStore; + let bobSigner: IChannelSigner; + let bobStore: IVectorStore; + let chainId: number; + + afterEach(async () => { + await alice.off(); + await bob.off(); + }); + + beforeEach(async () => { + const setup = await getFundedChannel(testName, [ + { + assetId: AddressZero, + amount: ["100", "100"], + }, + ]); + alice = setup.alice.protocol; + bob = setup.bob.protocol; + abChannelAddress = setup.channel.channelAddress; + aliceSigner = setup.alice.signer; + bobSigner = setup.bob.signer; + aliceStore = setup.alice.store; + bobStore = setup.bob.store; + chainId = setup.channel.networkContext.chainId; + + log.info({ + alice: alice.publicIdentifier, + bob: bob.publicIdentifier, + }); + }); + + it("should work with no transfers", async () => { + // remove channel + await bobStore.clear(); + + // bob should restore + const restore = await bob.restoreState({ counterpartyIdentifier: alice.publicIdentifier, chainId }); + expect(restore.getError()).to.be.undefined; + expect(restore.getValue()).to.be.deep.eq(await aliceStore.getChannelState(abChannelAddress)); + }); + + it("should work with transfers", async () => { + // install transfer + const { transfer } = await createTransfer(abChannelAddress, bob, alice, AddressZero, "1"); + + // remove channel + await bobStore.clear(); + + // bob should restore + const restore = await bob.restoreState({ counterpartyIdentifier: alice.publicIdentifier, chainId }); + + // verify results + expect(restore.getError()).to.be.undefined; + expect(restore.getValue()).to.be.deep.eq(await aliceStore.getChannelState(abChannelAddress)); + expect(await bob.getActiveTransfers(abChannelAddress)).to.be.deep.eq( + await alice.getActiveTransfers(abChannelAddress), + ); + }); + + it("should block updates when restoring", async () => { + // remove channel + await bobStore.clear(); + + // bob should restore, alice should attempt something + const [_, update] = (await Promise.all([ + bob.restoreState({ counterpartyIdentifier: alice.publicIdentifier, chainId }), + bob.deposit({ channelAddress: abChannelAddress, assetId: AddressZero }), + ])) as [Result, Result]; + + // verify update failed + expect(update.isError).to.be.true; + expect(update.getError()?.message).to.be.eq(QueuedUpdateError.reasons.ChannelRestoring); + }); +}); diff --git a/modules/protocol/src/testing/vector.spec.ts b/modules/protocol/src/testing/vector.spec.ts index 34bb77189..f3ed447fe 100644 --- a/modules/protocol/src/testing/vector.spec.ts +++ b/modules/protocol/src/testing/vector.spec.ts @@ -10,6 +10,7 @@ import { MemoryStoreService, expect, MemoryMessagingService, + mkPublicIdentifier, } from "@connext/vector-utils"; import pino from "pino"; import { @@ -20,14 +21,19 @@ import { Result, CreateTransferParams, ChainError, + MessagingError, + FullChannelState, + IChannelSigner, } from "@connext/vector-types"; import Sinon from "sinon"; -import { QueuedUpdateError } from "../errors"; +import { QueuedUpdateError, RestoreError } from "../errors"; import { Vector } from "../vector"; import * as vectorSync from "../sync"; +import * as vectorUtils from "../utils"; import { env } from "./env"; +import { chainId } from "./constants"; describe("Vector", () => { let chainReader: Sinon.SinonStubbedInstance; @@ -450,4 +456,250 @@ describe("Vector", () => { } }); }); + + describe("Vector.restore", () => { + let vector: Vector; + const channelAddress: string = mkAddress("0xccc"); + let counterpartyIdentifier: string; + let channel: FullChannelState; + let sigValidationStub: Sinon.SinonStub; + + beforeEach(async () => { + const signer = getRandomChannelSigner(); + const counterparty = getRandomChannelSigner(); + counterpartyIdentifier = counterparty.publicIdentifier; + + vector = await Vector.connect( + messagingService, + storeService, + signer, + chainReader as IVectorChainReader, + pino(), + false, + ); + + sigValidationStub = Sinon.stub(vectorUtils, "validateChannelSignatures"); + + channel = createTestChannelState(UpdateType.deposit, { + channelAddress, + aliceIdentifier: counterpartyIdentifier, + networkContext: { chainId }, + nonce: 5, + }).channel; + messagingService.sendRestoreStateMessage.resolves( + Result.ok({ + channel, + activeTransfers: [], + }), + ); + chainReader.getChannelAddress.resolves(Result.ok(channel.channelAddress)); + sigValidationStub.resolves(Result.ok(undefined)); + }); + + // UNIT TESTS + describe("should fail if the parameters are malformed", () => { + const paramTests: ParamValidationTest[] = [ + { + name: "should fail if parameters.chainId is invalid", + params: { + chainId: "fail", + counterpartyIdentifier: mkPublicIdentifier(), + }, + error: "should be number", + }, + { + name: "should fail if parameters.chainId is undefined", + params: { + chainId: undefined, + counterpartyIdentifier: mkPublicIdentifier(), + }, + error: "should have required property 'chainId'", + }, + { + name: "should fail if parameters.counterpartyIdentifier is invalid", + params: { + chainId, + counterpartyIdentifier: 1, + }, + error: "should be string", + }, + { + name: "should fail if parameters.counterpartyIdentifier is undefined", + params: { + chainId, + counterpartyIdentifier: undefined, + }, + error: "should have required property 'counterpartyIdentifier'", + }, + ]; + for (const { name, error, params } of paramTests) { + it(name, async () => { + const result = await vector.restoreState(params); + expect(result.isError).to.be.true; + expect(result.getError()?.message).to.be.eq(QueuedUpdateError.reasons.InvalidParams); + expect(result.getError()?.context.paramsError).to.be.eq(error); + }); + } + }); + + describe("restore initiator side", () => { + const runWithFailure = async (message: string) => { + const result = await vector.restoreState({ chainId, counterpartyIdentifier }); + expect(result.getError()).to.not.be.undefined; + expect(result.getError()?.message).to.be.eq(message); + }; + it("should fail if it receives an error", async () => { + messagingService.sendRestoreStateMessage.resolves( + Result.fail(new MessagingError(MessagingError.reasons.Timeout)), + ); + + await runWithFailure(MessagingError.reasons.Timeout); + }); + + it("should fail if there is no channel or active transfers provided", async () => { + messagingService.sendRestoreStateMessage.resolves( + Result.ok({ channel: undefined, activeTransfers: undefined }) as any, + ); + + await runWithFailure(RestoreError.reasons.NoData); + }); + + it("should fail if chainReader.geChannelAddress fails", async () => { + chainReader.getChannelAddress.resolves(Result.fail(new ChainError("fail"))); + + await runWithFailure(RestoreError.reasons.GetChannelAddressFailed); + }); + + it("should fail if it gives the wrong channel by channel address", async () => { + chainReader.getChannelAddress.resolves(Result.ok(mkAddress("0x334455666666ccccc"))); + + await runWithFailure(RestoreError.reasons.InvalidChannelAddress); + }); + + it("should fail if channel.latestUpdate is malsigned", async () => { + sigValidationStub.resolves(Result.fail(new Error("fail"))); + + await runWithFailure(RestoreError.reasons.InvalidSignatures); + }); + + it("should fail if channel.merkleRoot is incorrect", async () => { + messagingService.sendRestoreStateMessage.resolves( + Result.ok({ + channel: { ...channel, merkleRoot: mkHash("0xddddeeefffff") }, + activeTransfers: [], + }), + ); + + await runWithFailure(RestoreError.reasons.InvalidMerkleRoot); + }); + + it("should fail if the state is syncable", async () => { + storeService.getChannelState.resolves(channel); + + await runWithFailure(RestoreError.reasons.SyncableState); + }); + + it("should fail if store.saveChannelStateAndTransfers fails", async () => { + storeService.getChannelState.resolves(undefined); + storeService.saveChannelStateAndTransfers.rejects(new Error("fail")); + + await runWithFailure(RestoreError.reasons.SaveChannelFailed); + }); + }); + + describe("restore responder side", () => { + // Test with memory messaging service + stubs to properly trigger + // callback + let memoryMessaging: MemoryMessagingService; + let signer: IChannelSigner; + beforeEach(async () => { + memoryMessaging = new MemoryMessagingService(); + signer = getRandomChannelSigner(); + vector = await Vector.connect( + // Use real messaging service to test properly + memoryMessaging, + storeService, + signer, + chainReader as IVectorChainReader, + pino(), + false, + ); + }); + + it("should do nothing if it receives message from itself", async () => { + const response = await memoryMessaging.sendRestoreStateMessage( + Result.ok({ chainId }), + signer.publicIdentifier, + signer.publicIdentifier, + 500, + ); + expect(response.getError()?.message).to.be.eq(MessagingError.reasons.Timeout); + expect(storeService.getChannelStateByParticipants.callCount).to.be.eq(0); + }); + + it("should do nothing if it receives an error", async () => { + const response = await memoryMessaging.sendRestoreStateMessage( + Result.fail(new Error("fail") as any), + signer.publicIdentifier, + mkPublicIdentifier(), + 500, + ); + expect(response.getError()?.message).to.be.eq(MessagingError.reasons.Timeout); + expect(storeService.getChannelStateByParticipants.callCount).to.be.eq(0); + }); + + // Hard to test because of messaging service implementation + it.skip("should do nothing if message is malformed", async () => { + const response = await memoryMessaging.sendRestoreStateMessage( + Result.ok({ test: "test" } as any), + signer.publicIdentifier, + mkPublicIdentifier(), + 500, + ); + expect(response.getError()?.message).to.be.eq(MessagingError.reasons.Timeout); + expect(storeService.getChannelStateByParticipants.callCount).to.be.eq(0); + }); + + it("should send error if it cannot get channel", async () => { + storeService.getChannelStateByParticipants.rejects(new Error("fail")); + const response = await memoryMessaging.sendRestoreStateMessage( + Result.ok({ chainId }), + signer.publicIdentifier, + mkPublicIdentifier(), + ); + expect(response.getError()?.message).to.be.eq(RestoreError.reasons.CouldNotGetChannel); + expect(storeService.getChannelStateByParticipants.callCount).to.be.eq(1); + }); + + it("should send error if it cannot get active transfers", async () => { + storeService.getChannelStateByParticipants.resolves(createTestChannelState(UpdateType.deposit).channel); + storeService.getActiveTransfers.rejects(new Error("fail")); + const response = await memoryMessaging.sendRestoreStateMessage( + Result.ok({ chainId }), + signer.publicIdentifier, + mkPublicIdentifier(), + ); + expect(response.getError()?.message).to.be.eq(RestoreError.reasons.CouldNotGetActiveTransfers); + expect(storeService.getChannelStateByParticipants.callCount).to.be.eq(1); + }); + + it("should send correct information", async () => { + const channel = createTestChannelState(UpdateType.deposit).channel; + storeService.getChannelStateByParticipants.resolves(channel); + storeService.getActiveTransfers.resolves([]); + const response = await memoryMessaging.sendRestoreStateMessage( + Result.ok({ chainId }), + signer.publicIdentifier, + mkPublicIdentifier(), + ); + expect(response.getValue()).to.be.deep.eq({ channel, activeTransfers: [] }); + }); + }); + + it("should work", async () => { + const result = await vector.restoreState({ chainId, counterpartyIdentifier }); + expect(result.getError()).to.be.undefined; + expect(result.getValue()).to.be.deep.eq(channel); + }); + }); }); diff --git a/modules/protocol/src/vector.ts b/modules/protocol/src/vector.ts index c4e6de499..38aa6970e 100644 --- a/modules/protocol/src/vector.ts +++ b/modules/protocol/src/vector.ts @@ -22,14 +22,26 @@ import { UpdateIdentifier, } from "@connext/vector-types"; import { v4 as uuidV4 } from "uuid"; -import { getCreate2MultisigAddress, getRandomBytes32, delay } from "@connext/vector-utils"; +import { + getCreate2MultisigAddress, + getRandomBytes32, + delay, + getSignerAddressFromPublicIdentifier, + generateMerkleRoot, +} from "@connext/vector-utils"; import { Evt } from "evt"; import pino from "pino"; -import { QueuedUpdateError } from "./errors"; +import { QueuedUpdateError, RestoreError } from "./errors"; import { Cancellable, OtherUpdate, SelfUpdate, SerializedQueue } from "./queue"; import { outbound, inbound, OtherUpdateResult, SelfUpdateResult } from "./sync"; -import { extractContextFromStore, persistChannel, validateParamSchema } from "./utils"; +import { + extractContextFromStore, + getNextNonceForUpdate, + persistChannel, + validateChannelSignatures, + validateParamSchema, +} from "./utils"; type EvtContainer = { [K in keyof ProtocolEventPayloadsMap]: Evt }; @@ -42,6 +54,9 @@ export class Vector implements IVectorProtocol { // Do not interact with this directly. Always use getQueueAsync() private queues: Map | undefined>> = new Map(); + // Hold a flag to indicate whether or not a channel is being restored + private restorations: Map = new Map(); + // make it private so the only way to create the class is to use `connect` private constructor( private readonly messagingService: IMessagingService, @@ -176,10 +191,25 @@ export class Vector implements IVectorProtocol { : undefined; return resolve({ cancelled: false, - value: { updatedTransfer: transfer, updatedChannel: channelState, updatedTransfers: activeTransfers }, + value: Result.ok({ + updatedTransfer: transfer, + updatedChannel: channelState, + updatedTransfers: activeTransfers, + }), successfullyApplied: "previouslyExecuted", }); } + + // Make sure channel isnt being restored + if (this.restorations.get(initiated.params.channelAddress)) { + return resolve({ + cancelled: false, + value: Result.fail( + new QueuedUpdateError(QueuedUpdateError.reasons.ChannelRestoring, initiated.params, channelState), + ), + successfullyApplied: "executed", + }); + } try { const ret = await outbound( initiated.params, @@ -316,6 +346,16 @@ export class Vector implements IVectorProtocol { storeError: storeRes.getError()?.message, }); } + // Make sure channel isnt being restored + if (this.restorations.get(received.update.channelAddress)) { + return resolve({ + cancelled: false, + value: Result.fail( + new QueuedUpdateError(QueuedUpdateError.reasons.ChannelRestoring, received.update, channelState), + ), + }); + } + // NOTE: no need to validate that the update has already been executed // because that is asserted on sync, where as an initiator you dont have // that certainty @@ -562,7 +602,7 @@ export class Vector implements IVectorProtocol { return; } - // TODO: why in the world is this causing it to fail + // // TODO: why in the world is this causing it to fail // // Previous update may be undefined, but if it exists, validate // console.log("******** validating schema"); // const previousError = validateParamSchema(received.previousUpdate, TChannelUpdate); @@ -611,6 +651,81 @@ export class Vector implements IVectorProtocol { }, ); + // response to restore messages + await this.messagingService.onReceiveRestoreStateMessage( + this.publicIdentifier, + async (restoreData: Result<{ chainId: number }, ProtocolError>, from: string, inbox: string) => { + // If it is from yourself, do nothing + if (from === this.publicIdentifier) { + return; + } + const method = "onReceiveRestoreStateMessage"; + this.logger.debug({ method, data: restoreData.toJson(), inbox }, "Handling restore message"); + + // Received error from counterparty + if (restoreData.isError) { + this.logger.error( + { message: restoreData.getError()!.message, method }, + "Error received from counterparty restore", + ); + return; + } + + const data = restoreData.getValue(); + const [key] = Object.keys(data ?? []); + if (key !== "chainId") { + this.logger.error({ data }, "Message malformed"); + return; + } + + // Counterparty looking to initiate a restore + let channel: FullChannelState | undefined; + const sendCannotRestoreFromError = (error: Values, context: any = {}) => { + return this.messagingService.respondToRestoreStateMessage( + inbox, + Result.fail(new RestoreError(error, channel!, this.publicIdentifier, { ...context, method })), + ); + }; + + // Get info from store to send to counterparty + const { chainId } = data as any; + try { + channel = await this.storeService.getChannelStateByParticipants(this.publicIdentifier, from, chainId); + } catch (e) { + return sendCannotRestoreFromError(RestoreError.reasons.CouldNotGetChannel, { + storeMethod: "getChannelStateByParticipants", + chainId, + identifiers: [this.publicIdentifier, from], + }); + } + if (!channel) { + return sendCannotRestoreFromError(RestoreError.reasons.ChannelNotFound, { chainId }); + } + let activeTransfers: FullTransferState[]; + try { + activeTransfers = await this.storeService.getActiveTransfers(channel.channelAddress); + } catch (e) { + return sendCannotRestoreFromError(RestoreError.reasons.CouldNotGetActiveTransfers, { + storeMethod: "getActiveTransfers", + chainId, + channelAddress: channel.channelAddress, + }); + } + + // Send info to counterparty + this.logger.info( + { + method, + channel: channel.channelAddress, + nonce: channel.nonce, + activeTransfers: activeTransfers.map((a) => a.transferId), + }, + "Sending counterparty state to sync", + ); + await this.messagingService.respondToRestoreStateMessage(inbox, Result.ok({ channel, activeTransfers })); + }, + ); + // Handle disputes // TODO: if this is awaited, then it may cause problems with the // server-node startup (double check on prod). If it is *not* awaited @@ -804,6 +919,128 @@ export class Vector implements IVectorProtocol { return returnVal; } + public async restoreState( + params: ProtocolParams.Restore, + ): Promise> { + const method = "restoreState"; + const methodId = getRandomBytes32(); + this.logger.debug({ method, methodId }, "Method start"); + // Validate all input + const error = validateParamSchema(params, ProtocolParams.RestoreSchema); + if (error) { + return Result.fail(error); + } + + // Send message to counterparty, they will grab lock and + // return information under lock, initiator will update channel, + // then send confirmation message to counterparty, who will release the lock + const { chainId, counterpartyIdentifier } = params; + const restoreDataRes = await this.messagingService.sendRestoreStateMessage( + Result.ok({ chainId }), + counterpartyIdentifier, + this.signer.publicIdentifier, + ); + if (restoreDataRes.isError) { + return Result.fail(restoreDataRes.getError() as RestoreError); + } + + const { channel, activeTransfers } = restoreDataRes.getValue() ?? ({} as any); + + // Create helper to generate error + const generateRestoreError = ( + error: Values, + context: any = {}, + ): Result => { + // handle error by returning it to counterparty && returning result + const err = new RestoreError(error, channel, this.publicIdentifier, { + ...context, + method, + params, + }); + channel && this.restorations.set(channel.channelAddress, false); + return Result.fail(err); + }; + + // Verify data exists + if (!channel || !activeTransfers) { + return generateRestoreError(RestoreError.reasons.NoData); + } + + // Set restoration for channel to true + this.restorations.set(channel.channelAddress, true); + + // Verify channel address is same as calculated + const counterparty = getSignerAddressFromPublicIdentifier(counterpartyIdentifier); + const calculated = await this.chainReader.getChannelAddress( + channel.alice === this.signer.address ? this.signer.address : counterparty, + channel.bob === this.signer.address ? this.signer.address : counterparty, + channel.networkContext.channelFactoryAddress, + chainId, + ); + if (calculated.isError) { + return generateRestoreError(RestoreError.reasons.GetChannelAddressFailed, { + getChannelAddressError: jsonifyError(calculated.getError()!), + }); + } + if (calculated.getValue() !== channel.channelAddress) { + return generateRestoreError(RestoreError.reasons.InvalidChannelAddress, { + calculated: calculated.getValue(), + }); + } + + // Verify signatures on latest update + const sigRes = await validateChannelSignatures( + channel, + channel.latestUpdate.aliceSignature, + channel.latestUpdate.bobSignature, + "both", + ); + if (sigRes.isError) { + return generateRestoreError(RestoreError.reasons.InvalidSignatures, { + recoveryError: sigRes.getError()!.message, + }); + } + + // Verify transfers match merkleRoot + const root = generateMerkleRoot(activeTransfers); + if (root !== channel.merkleRoot) { + return generateRestoreError(RestoreError.reasons.InvalidMerkleRoot, { + calculated: root, + merkleRoot: channel.merkleRoot, + activeTransfers: activeTransfers.map((t) => t.transferId), + }); + } + + // Verify nothing with a sync-able nonce exists in store + const existing = await this.getChannelState(channel.channelAddress); + const nonce = existing?.nonce ?? 0; + const next = getNextNonceForUpdate(nonce, channel.latestUpdate.fromIdentifier === channel.aliceIdentifier); + if (next === channel.nonce && channel.latestUpdate.type !== UpdateType.setup) { + return generateRestoreError(RestoreError.reasons.SyncableState, { + existing: nonce, + toRestore: channel.nonce, + }); + } + if (nonce >= channel.nonce) { + return generateRestoreError(RestoreError.reasons.SyncableState, { + existing: nonce, + toRestore: channel.nonce, + }); + } + + // Save channel + try { + await this.storeService.saveChannelStateAndTransfers(channel, activeTransfers); + } catch (e) { + return generateRestoreError(RestoreError.reasons.SaveChannelFailed, { + saveChannelStateAndTransfersError: e.message, + }); + } + + this.restorations.set(channel.channelAddress, false); + return Result.ok(channel); + } + /////////////////////////////////// // STORE METHODS public async getChannelState(channelAddress: string): Promise { diff --git a/modules/types/src/channel.ts b/modules/types/src/channel.ts index 155d6123c..80c9ff45b 100644 --- a/modules/types/src/channel.ts +++ b/modules/types/src/channel.ts @@ -61,6 +61,12 @@ export interface UpdateParamsMap { [UpdateType.setup]: SetupParams; } +// Not exactly a channel update, but another protocol method +export type RestoreParams = { + counterpartyIdentifier: string; + chainId: number; +}; + // When generating an update from params, you need to create an // identifier to make sure the update remains idempotent. Imagine // without this and you are trying to apply a `create` update. diff --git a/modules/types/src/messaging.ts b/modules/types/src/messaging.ts index 4716d49fd..f45b6b28d 100644 --- a/modules/types/src/messaging.ts +++ b/modules/types/src/messaging.ts @@ -72,17 +72,17 @@ export interface IMessagingService extends IBasicMessaging { // - counterparty responds // - restore-r restores sendRestoreStateMessage( - restoreData: Result<{ chainId: number }, EngineError>, + restoreData: Result<{ chainId: number }, ProtocolError>, to: string, from: string, timeout?: number, numRetries?: number, ): Promise< - Result<{ channel: FullChannelState; activeTransfers: FullTransferState[] } | void, EngineError | MessagingError> + Result<{ channel: FullChannelState; activeTransfers: FullTransferState[] } | void, ProtocolError | MessagingError> >; onReceiveRestoreStateMessage( publicIdentifier: string, - callback: (restoreData: Result<{ chainId: number }, EngineError>, from: string, inbox: string) => void, + callback: (restoreData: Result<{ chainId: number }, ProtocolError>, from: string, inbox: string) => void, ): Promise; respondToRestoreStateMessage( inbox: string, diff --git a/modules/types/src/protocol.ts b/modules/types/src/protocol.ts index f82bfe0bb..39ef3d60f 100644 --- a/modules/types/src/protocol.ts +++ b/modules/types/src/protocol.ts @@ -7,6 +7,7 @@ import { SetupParams, UpdateType, FullChannelState, + RestoreParams, } from "./channel"; import { ProtocolError, Result } from "./error"; import { ProtocolEventName, ProtocolEventPayloadsMap } from "./event"; @@ -18,6 +19,7 @@ export interface IVectorProtocol { deposit(params: DepositParams): Promise>; create(params: CreateTransferParams): Promise>; resolve(params: ResolveTransferParams): Promise>; + on( event: T, callback: (payload: ProtocolEventPayloadsMap[T]) => void | Promise, @@ -41,6 +43,7 @@ export interface IVectorProtocol { getTransferState(transferId: string): Promise; getActiveTransfers(channelAddress: string): Promise; syncDisputes(): Promise; + restoreState(params: RestoreParams): Promise>; } type VectorChannelMessageData = { diff --git a/modules/types/src/schemas/engine.ts b/modules/types/src/schemas/engine.ts index 655c73640..556b222e1 100644 --- a/modules/types/src/schemas/engine.ts +++ b/modules/types/src/schemas/engine.ts @@ -15,6 +15,7 @@ import { WithdrawalQuoteSchema, TransferQuoteSchema, } from "./basic"; +import { ProtocolParams } from "./protocol"; //////////////////////////////////////// // Engine API Parameter schemas @@ -228,11 +229,11 @@ const SignUtilityMessageParamsSchema = Type.Object({ // Ping-pong const SendIsAliveParamsSchema = Type.Object({ channelAddress: TAddress, skipCheckIn: Type.Boolean() }); -// Restore channel from counterparty -const RestoreStateParamsSchema = Type.Object({ - counterpartyIdentifier: TPublicIdentifier, - chainId: TChainId, -}); +// // Restore channel from counterparty +// const RestoreStateParamsSchema = Type.Object({ +// counterpartyIdentifier: TPublicIdentifier, +// chainId: TChainId, +// }); // Rpc request schema const RpcRequestEngineParamsSchema = Type.Object({ @@ -299,8 +300,8 @@ export namespace EngineParams { export const SetupSchema = SetupEngineParamsSchema; export type Setup = Static; - export const RestoreStateSchema = RestoreStateParamsSchema; - export type RestoreState = Static; + export const RestoreStateSchema = ProtocolParams.RestoreSchema; + export type RestoreState = ProtocolParams.Restore; export const DepositSchema = DepositEngineParamsSchema; export type Deposit = Static; diff --git a/modules/types/src/schemas/protocol.ts b/modules/types/src/schemas/protocol.ts index d8e0c5fcf..178b20f17 100644 --- a/modules/types/src/schemas/protocol.ts +++ b/modules/types/src/schemas/protocol.ts @@ -5,6 +5,7 @@ import { TBalance, TBasicMeta, TBytes32, + TChainId, TIntegerString, TNetworkContext, TPublicIdentifier, @@ -52,6 +53,12 @@ const ResolveProtocolParamsSchema = Type.Object({ meta: Type.Optional(TBasicMeta), }); +// Restore +const RestoreProtocolParamsSchema = Type.Object({ + counterpartyIdentifier: TPublicIdentifier, + chainId: TChainId, +}); + // Namespace export // eslint-disable-next-line @typescript-eslint/no-namespace export namespace ProtocolParams { @@ -63,4 +70,6 @@ export namespace ProtocolParams { export type Create = Static; export const ResolveSchema = ResolveProtocolParamsSchema; export type Resolve = Static; + export const RestoreSchema = RestoreProtocolParamsSchema; + export type Restore = Static; } diff --git a/modules/types/src/store.ts b/modules/types/src/store.ts index 68509adc7..0a19a0e59 100644 --- a/modules/types/src/store.ts +++ b/modules/types/src/store.ts @@ -32,6 +32,8 @@ export interface IVectorStore { // Setters saveChannelState(channelState: FullChannelState, transfer?: FullTransferState): Promise; + // Used for restore + saveChannelStateAndTransfers(channelState: FullChannelState, activeTransfers: FullTransferState[]): Promise; /** * Saves information about a channel dispute from the onchain record @@ -175,8 +177,6 @@ export interface IEngineStore extends IVectorStore, IChainServiceStore { // Setters saveWithdrawalCommitment(transferId: string, withdrawCommitment: WithdrawCommitmentJson): Promise; - // Used for restore - saveChannelStateAndTransfers(channelState: FullChannelState, activeTransfers: FullTransferState[]): Promise; } export interface IServerNodeStore extends IEngineStore { diff --git a/modules/utils/src/test/services/messaging.ts b/modules/utils/src/test/services/messaging.ts index 14f6de919..64ac992de 100644 --- a/modules/utils/src/test/services/messaging.ts +++ b/modules/utils/src/test/services/messaging.ts @@ -19,7 +19,7 @@ import { Evt } from "evt"; import { getRandomBytes32 } from "../../hexStrings"; export class MemoryMessagingService implements IMessagingService { - private readonly evt: Evt<{ + private readonly protocolEvt: Evt<{ to?: string; from: string; inbox?: string; @@ -37,6 +37,24 @@ export class MemoryMessagingService implements IMessagingService { replyTo?: string; }>(); + private readonly restoreEvt: Evt<{ + to?: string; + from?: string; + chainId?: number; + channel?: FullChannelState; + activeTransfers?: FullTransferState[]; + error?: ProtocolError; + inbox?: string; + }> = Evt.create<{ + to?: string; + from?: string; + chainId?: number; + channel?: FullChannelState; + activeTransfers?: FullTransferState[]; + error?: ProtocolError; + inbox?: string; + }>(); + flush(): Promise { throw new Error("Method not implemented."); } @@ -46,7 +64,7 @@ export class MemoryMessagingService implements IMessagingService { } async disconnect(): Promise { - this.evt.detach(); + this.protocolEvt.detach(); } async sendProtocolMessage( @@ -56,8 +74,8 @@ export class MemoryMessagingService implements IMessagingService { numRetries = 0, ): Promise; previousUpdate: ChannelUpdate }, ProtocolError>> { const inbox = getRandomBytes32(); - const responsePromise = this.evt.pipe((e) => e.inbox === inbox).waitFor(timeout); - this.evt.post({ + const responsePromise = this.protocolEvt.pipe((e) => e.inbox === inbox).waitFor(timeout); + this.protocolEvt.post({ to: channelUpdate.toIdentifier, from: channelUpdate.fromIdentifier, replyTo: inbox, @@ -75,7 +93,7 @@ export class MemoryMessagingService implements IMessagingService { channelUpdate: ChannelUpdate, previousUpdate?: ChannelUpdate, ): Promise { - this.evt.post({ + this.protocolEvt.post({ inbox, data: { update: channelUpdate, previousUpdate }, from: channelUpdate.toIdentifier, @@ -83,7 +101,7 @@ export class MemoryMessagingService implements IMessagingService { } async respondWithProtocolError(inbox: string, error: ProtocolError): Promise { - this.evt.post({ + this.protocolEvt.post({ inbox, data: { error }, from: error.context.update.toIdentifier, @@ -98,7 +116,7 @@ export class MemoryMessagingService implements IMessagingService { inbox: string, ) => void, ): Promise { - this.evt + this.protocolEvt .pipe(({ to }) => to === myPublicIdentifier) .attach(({ data, replyTo, from }) => { callback( @@ -112,6 +130,59 @@ export class MemoryMessagingService implements IMessagingService { }); } + async onReceiveRestoreStateMessage( + publicIdentifier: string, + callback: (restoreData: Result<{ chainId: number }, EngineError>, from: string, inbox: string) => void, + ): Promise { + this.restoreEvt + .pipe(({ to }) => to === publicIdentifier) + .attach(({ inbox, from, chainId, error }) => { + callback(!!error ? Result.fail(error) : Result.ok({ chainId }), from, inbox); + }); + } + + async sendRestoreStateMessage( + restoreData: Result<{ chainId: number }, EngineError>, + to: string, + from: string, + timeout?: number, + numRetries?: number, + ): Promise> { + const inbox = getRandomBytes32(); + this.restoreEvt.post({ + to, + from, + error: restoreData.isError ? restoreData.getError() : undefined, + chainId: restoreData.isError ? undefined : restoreData.getValue().chainId, + inbox, + }); + try { + const response = await this.restoreEvt.waitFor((data) => { + return data.inbox === inbox; + }, timeout); + return response.error + ? Result.fail(response.error) + : Result.ok({ channel: response.channel!, activeTransfers: response.activeTransfers! }); + } catch (e) { + if (e.message.includes("Evt timeout")) { + return Result.fail(new MessagingError(MessagingError.reasons.Timeout)); + } + return Result.fail(e); + } + } + + async respondToRestoreStateMessage( + inbox: string, + restoreData: Result<{ channel: FullChannelState; activeTransfers: FullTransferState[] }, EngineError>, + ): Promise { + this.restoreEvt.post({ + inbox, + error: restoreData.getError(), + channel: restoreData.isError ? undefined : restoreData.getValue().channel, + activeTransfers: restoreData.isError ? undefined : restoreData.getValue().activeTransfers, + }); + } + sendSetupMessage( setupInfo: Result, Error>, to: string, @@ -158,28 +229,6 @@ export class MemoryMessagingService implements IMessagingService { throw new Error("Method not implemented."); } - sendRestoreStateMessage( - restoreData: Result<{ chainId: number }, EngineError>, - to: string, - from: string, - timeout?: number, - numRetries?: number, - ): Promise> { - throw new Error("Method not implemented."); - } - onReceiveRestoreStateMessage( - publicIdentifier: string, - callback: (restoreData: Result<{ chainId: number }, EngineError>, from: string, inbox: string) => void, - ): Promise { - throw new Error("Method not implemented."); - } - respondToRestoreStateMessage( - inbox: string, - restoreData: Result<{ channel: FullChannelState; activeTransfers: FullTransferState[] } | void, EngineError>, - ): Promise { - throw new Error("Method not implemented."); - } - sendIsAliveMessage( isAlive: Result<{ channelAddress: string }, VectorError>, to: string, diff --git a/modules/utils/src/test/services/store.ts b/modules/utils/src/test/services/store.ts index 0c01bece1..659ce0659 100644 --- a/modules/utils/src/test/services/store.ts +++ b/modules/utils/src/test/services/store.ts @@ -130,17 +130,19 @@ export class MemoryStoreService implements IEngineStore { } getChannelStateByParticipants( - participantA: string, - participantB: string, + publicIdentifierA: string, + publicIdentifierB: string, chainId: number, ): Promise { - return Promise.resolve( - [...this.channelStates.values()].find((channelState) => { - channelState.alice === participantA && - channelState.bob === participantB && - channelState.networkContext.chainId === chainId; - }), - ); + const channel = [...this.channelStates.values()].find((channelState) => { + const identifiers = [channelState.aliceIdentifier, channelState.bobIdentifier]; + return ( + identifiers.includes(publicIdentifierA) && + identifiers.includes(publicIdentifierB) && + channelState.networkContext.chainId === chainId + ); + }); + return Promise.resolve(channel); } getChannelStates(): Promise { @@ -178,7 +180,24 @@ export class MemoryStoreService implements IEngineStore { } saveChannelStateAndTransfers(channelState: FullChannelState, activeTransfers: FullTransferState[]): Promise { - return Promise.reject("Method not implemented"); + // remove all previous + this.channelStates.delete(channelState.channelAddress); + activeTransfers.map((transfer) => { + this.transfers.delete(transfer.transferId); + }); + this.transfersInChannel.delete(channelState.channelAddress); + + // add in new records + this.channelStates.set(channelState.channelAddress, channelState); + activeTransfers.map((transfer) => { + this.transfers.set(transfer.transferId, transfer); + }); + this.transfersInChannel.set( + channelState.channelAddress, + activeTransfers.map((t) => t.transferId), + ); + + return Promise.resolve(); } getActiveTransfers(channelAddress: string): Promise {