diff --git a/src/merkle-tree-large.ts b/src/merkle-tree-large.ts index 459d531..560c63d 100644 --- a/src/merkle-tree-large.ts +++ b/src/merkle-tree-large.ts @@ -1,11 +1,7 @@ import { Field } from "./field"; import { poseidon } from "./poseidon"; -import { - default_snapshot_id, - local_uri, - MerkleTreeDb -} from "./db"; -import LRUCache = require('lru-cache'); +import { default_snapshot_id, local_uri, MerkleTreeDb } from "./db"; +import LRUCache = require("lru-cache"); const hash = poseidon; export const MaxHeight = 16; @@ -22,23 +18,60 @@ const cacheOptions: any = {}; cacheOptions.max = 100; cacheOptions.maxAge = 60 * 1000; -export class MerkleTree { +interface MerkleTreeStorage { + getNode(mtIndex: string): Promise; + setNode(mtIndex: string, value: Field): void; + startSnapshot(id: string): void; + endSnapshot(): void; + lastestSnapshot(): Promise; + loadSnapshot(latestSnapshot: string): void; + close(): void; +} + +class MerkleTreeMemory implements MerkleTreeStorage { + private inMemoryMerkleTree: Map; + + constructor() { + this.inMemoryMerkleTree = new Map(); + } + + async getNode(mtIndex: string) { + return this.inMemoryMerkleTree.get(mtIndex + "I"); + } + + setNode(mtIndex: string, value: Field) { + return this.inMemoryMerkleTree.set(mtIndex + "I", value); + } + + startSnapshot(id: string) { + return; + } + + endSnapshot() { + return; + } + + async lastestSnapshot() { + return "0"; + } + + loadSnapshot(_latestSnapshot: string) { + return; + } + + close() { + return; + } +} + +class MerkleTreeDB implements MerkleTreeStorage { + private cache: LRUCache; + private db: MerkleTreeDb; private currentSnapshotIdx: string | undefined = undefined; - private cache = new LRUCache(10000); - private db_name = "delphinus"; - private db = new MerkleTreeDb(local_uri, this.db_name); - static emptyHashes: Field[] = []; - static emptyNodeHash(height: number) { - if (this.emptyHashes.length === 0) { - this.emptyHashes.push(new Field(0)); - for (let i = 0; i < MaxHeight; i++) { - const last = this.emptyHashes[i]; - this.emptyHashes.push(hash([last, last, last, last])); - } - this.emptyHashes = this.emptyHashes.reverse(); - } - return this.emptyHashes[height]; + constructor(dbName: string, cacheSize: number) { + this.cache = new LRUCache(cacheSize); + this.db = new MerkleTreeDb(local_uri, dbName); } private async getRawNode(mtIndex: string) { @@ -46,9 +79,6 @@ export class MerkleTree { } async getNode(mtIndex: string) { - if (mtIndex.startsWith("-")) { - throw new Error(mtIndex); - } let field = this.cache.get(mtIndex); if (field !== undefined) { return field; @@ -77,12 +107,16 @@ export class MerkleTree { this.cache.set(mtIndex, value); } - async startSnapshot(id: string) { + startSnapshot(id: string) { this.currentSnapshotIdx = id; } async endSnapshot() { - this.db.updateLatestSnapshotId(this.currentSnapshotIdx!); + if (this.currentSnapshotIdx === undefined) { + throw new Error("snapshot not set"); + } + + await this.db.updateLatestSnapshotId(this.currentSnapshotIdx); this.currentSnapshotIdx = undefined; } @@ -90,14 +124,71 @@ export class MerkleTree { return this.db.queryLatestSnapshotId(); } - async loadSnapshot(latest_snapshot: string) { - await this.db.restoreMerkleTree(latest_snapshot); + async loadSnapshot(latestSnapshot: string) { + await this.db.restoreMerkleTree(latestSnapshot); this.cache.reset(); } - async closeDb() { + async close() { await this.db.closeMongoClient(); } +} + +export class MerkleTree { + private storage: MerkleTreeStorage; + + constructor(isMemData = false) { + if (isMemData) { + this.storage = new MerkleTreeMemory(); + } else { + this.storage = new MerkleTreeDB("delphinus", 10000); + } + } + + static emptyHashes: Field[] = []; + static emptyNodeHash(height: number) { + if (this.emptyHashes.length === 0) { + this.emptyHashes.push(new Field(0)); + for (let i = 0; i < MaxHeight; i++) { + const last = this.emptyHashes[i]; + this.emptyHashes.push(hash([last, last, last, last])); + } + this.emptyHashes = this.emptyHashes.reverse(); + } + return this.emptyHashes[height]; + } + + async getNode(mtIndex: string) { + if (mtIndex.startsWith("-")) { + throw new Error(mtIndex); + } + + return await this.storage.getNode(mtIndex); + } + + async setNode(mtIndex: string, value: Field) { + await this.storage.setNode(mtIndex, value); + } + + async startSnapshot(id: string) { + this.storage.startSnapshot(id); + } + + async endSnapshot() { + await this.storage.endSnapshot(); + } + + async lastestSnapshot() { + return await this.storage.lastestSnapshot(); + } + + async loadSnapshot(latestSnapshot: string) { + await this.storage.loadSnapshot(latestSnapshot); + } + + async closeDb() { + await this.storage.close(); + } private async getNodeOrDefault(mtIndex: string) { let value = await this.getNode(mtIndex); diff --git a/test/db.test.ts b/test/db.test.ts index 4adfc6d..b83de1d 100644 --- a/test/db.test.ts +++ b/test/db.test.ts @@ -3,7 +3,7 @@ import { MerkleTree } from "../src/merkle-tree-large"; var assert = require('assert'); -async function main() { +async function testDBMerkleTree() { const merkle_tree = new MerkleTree(); await merkle_tree.loadSnapshot("0"); @@ -55,4 +55,32 @@ async function main() { await merkle_tree.closeDb(); } -main().then(() => console.log("done")) \ No newline at end of file +async function testInMemoryMerkleTree() { + const merkle_tree = new MerkleTree(true); + + await merkle_tree.getNode("0001").then((node) => { + assert.ok(node === undefined) + }); + + await merkle_tree.setNode("0001", new Field(1)); + await merkle_tree.getNode("0001").then((node) => { + assert.ok(node!.v.eq(new Field(1).v)) + }); + + // lastestSnapshot should always return string "0" + await merkle_tree.lastestSnapshot().then((node) => { + assert.equal(node, "0"); + }); + + // should do nothing + await merkle_tree.loadSnapshot("0"); + await merkle_tree.endSnapshot(); + await merkle_tree.closeDb(); +} + +async function main() { + await testDBMerkleTree(); + await testInMemoryMerkleTree(); +} + +main().then(() => console.log("done"))