diff --git a/.github/actions/setup-and-build/action.yml b/.github/actions/setup-and-build/action.yml index 5e45e45e0c..4c705e6cb5 100644 --- a/.github/actions/setup-and-build/action.yml +++ b/.github/actions/setup-and-build/action.yml @@ -118,7 +118,7 @@ runs: uses: actions/cache@v4 with: path: prover/server/proving-keys - key: ${{ runner.os }}-proving-keys-${{ inputs.cache-suffix }}${{ inputs.cache-suffix && '-' || '' }}${{ hashFiles('prover/server/scripts/download_keys.sh') }} + key: ${{ runner.os }}-proving-keys-${{ inputs.cache-suffix }}${{ inputs.cache-suffix && '-' || '' }}${{ hashFiles('prover/server/prover/common/key_downloader.go', 'prover/server/prover/common/proving_keys_utils.go') }} - name: Download proving keys if: "!contains(inputs.skip-components, 'proving-keys') && steps.cache-keys.outputs.cache-hit != 'true'" diff --git a/.github/workflows/cli-v1.yml b/.github/workflows/cli-v1.yml index 354a1f1d04..10c7806628 100644 --- a/.github/workflows/cli-v1.yml +++ b/.github/workflows/cli-v1.yml @@ -49,13 +49,13 @@ jobs: skip-components: "redis,disk-cleanup" cache-suffix: "js" - - name: Build CLI with V1 (CI mode - Linux x64 only) + - name: Build CLI run: | - npx nx build-ci @lightprotocol/zk-compression-cli + npx nx build @lightprotocol/zk-compression-cli - name: Run CLI tests with V1 run: | - npx nx test-ci @lightprotocol/zk-compression-cli + npx nx test @lightprotocol/zk-compression-cli - name: Display prover logs on failure if: failure() diff --git a/.github/workflows/cli-v2.yml b/.github/workflows/cli-v2.yml index cd5baa5211..41e1855688 100644 --- a/.github/workflows/cli-v2.yml +++ b/.github/workflows/cli-v2.yml @@ -59,13 +59,13 @@ jobs: cd js/compressed-token pnpm build:v2 - - name: Build CLI with V2 (CI mode - Linux x64 only) + - name: Build CLI with V2 run: | - npx nx build-ci @lightprotocol/zk-compression-cli + npx nx build @lightprotocol/zk-compression-cli - name: Run CLI tests with V2 run: | - npx nx test-ci @lightprotocol/zk-compression-cli + npx nx test @lightprotocol/zk-compression-cli - name: Display prover logs on failure if: failure() diff --git a/.github/workflows/forester-tests.yml b/.github/workflows/forester-tests.yml index d06f72c76b..f38c02b042 100644 --- a/.github/workflows/forester-tests.yml +++ b/.github/workflows/forester-tests.yml @@ -74,8 +74,8 @@ jobs: df -h / du -sh /home/runner/work/* | sort -hr | head -n 10 - - name: Build CLI (CI mode - Linux x64 only) - run: npx nx build-ci @lightprotocol/zk-compression-cli + - name: Build CLI + run: npx nx build @lightprotocol/zk-compression-cli - name: Test run: cargo test --package forester e2e_test -- --nocapture diff --git a/.github/workflows/js-v2.yml b/.github/workflows/js-v2.yml index 5208ca36f1..c89ffd79b2 100644 --- a/.github/workflows/js-v2.yml +++ b/.github/workflows/js-v2.yml @@ -59,16 +59,16 @@ jobs: cd js/compressed-token pnpm build:v2 - - name: Build CLI (CI mode - Linux x64 only) + - name: Build CLI run: | - npx nx build-ci @lightprotocol/zk-compression-cli + npx nx build @lightprotocol/zk-compression-cli - name: Run stateless.js tests with V2 run: | echo "Running stateless.js tests with retry logic (max 2 attempts)..." attempt=1 max_attempts=2 - until npx nx test-ci @lightprotocol/stateless.js; do + until npx nx test @lightprotocol/stateless.js; do attempt=$((attempt + 1)) if [ $attempt -gt $max_attempts ]; then echo "Tests failed after $max_attempts attempts" @@ -84,7 +84,7 @@ jobs: echo "Running compressed-token tests with retry logic (max 2 attempts)..." 
attempt=1 max_attempts=2 - until npx nx test-ci @lightprotocol/compressed-token; do + until npx nx test @lightprotocol/compressed-token; do attempt=$((attempt + 1)) if [ $attempt -gt $max_attempts ]; then echo "Tests failed after $max_attempts attempts" diff --git a/.github/workflows/js.yml b/.github/workflows/js.yml index 7fffa4ad6d..a353b4e30e 100644 --- a/.github/workflows/js.yml +++ b/.github/workflows/js.yml @@ -59,16 +59,16 @@ jobs: cd js/compressed-token pnpm build:v1 - - name: Build CLI (CI mode - Linux x64 only) + - name: Build CLI run: | - npx nx build-ci @lightprotocol/zk-compression-cli + npx nx build @lightprotocol/zk-compression-cli - name: Run stateless.js tests with V1 run: | echo "Running stateless.js tests with retry logic (max 2 attempts)..." attempt=1 max_attempts=2 - until npx nx test-ci @lightprotocol/stateless.js; do + until npx nx test @lightprotocol/stateless.js; do attempt=$((attempt + 1)) if [ $attempt -gt $max_attempts ]; then echo "Tests failed after $max_attempts attempts" @@ -84,7 +84,7 @@ jobs: echo "Running compressed-token tests with retry logic (max 2 attempts)..." attempt=1 max_attempts=2 - until npx nx test-ci @lightprotocol/compressed-token; do + until npx nx test @lightprotocol/compressed-token; do attempt=$((attempt + 1)) if [ $attempt -gt $max_attempts ]; then echo "Tests failed after $max_attempts attempts" diff --git a/.github/workflows/light-system-programs-tests.yml b/.github/workflows/light-system-programs-tests.yml index cd4c0f73e0..44192a04ea 100644 --- a/.github/workflows/light-system-programs-tests.yml +++ b/.github/workflows/light-system-programs-tests.yml @@ -84,9 +84,9 @@ jobs: skip-components: "redis,disk-cleanup" cache-suffix: "system-programs" - - name: Build CLI (CI mode - Linux x64 only) + - name: Build CLI run: | - npx nx build-ci @lightprotocol/zk-compression-cli + npx nx build @lightprotocol/zk-compression-cli - name: ${{ matrix.program }} run: | diff --git a/.github/workflows/prover-release.yml b/.github/workflows/prover-release.yml index 96b443b274..85406325b2 100644 --- a/.github/workflows/prover-release.yml +++ b/.github/workflows/prover-release.yml @@ -14,17 +14,23 @@ jobs: - name: Set up Go uses: actions/setup-go@v6 with: - go-version: 1.21 + go-version-file: "./prover/server/go.mod" - name: Build artifacts run: | cd prover/server - for cfgstr in "darwin amd64" "darwin arm64" "linux amd64" "windows amd64"; do + for cfgstr in "darwin amd64" "darwin arm64" "linux amd64" "linux arm64" "windows amd64"; do IFS=' ' read -r -a cfg <<< "$cfgstr" export GOOS="${cfg[0]}" export GOARCH="${cfg[1]}" export CGO_ENABLED=0 - go build -o prover-"$GOOS"-"$GOARCH" + + ext="" + if [ "$GOOS" = "windows" ]; then + ext=".exe" + fi + + go build -o prover-"$GOOS"-"$GOARCH""$ext" done - name: Create Release @@ -34,4 +40,5 @@ jobs: prover/server/prover-darwin-amd64 prover/server/prover-darwin-arm64 prover/server/prover-linux-amd64 - prover/server/prover-windows-amd64 + prover/server/prover-linux-arm64 + prover/server/prover-windows-amd64.exe diff --git a/.github/workflows/prover-test.yml b/.github/workflows/prover-test.yml index 1305d21c2e..dd09859f61 100644 --- a/.github/workflows/prover-test.yml +++ b/.github/workflows/prover-test.yml @@ -1,4 +1,6 @@ name: Test gnark prover +permissions: + contents: read on: push: branches: @@ -22,10 +24,96 @@ on: - ready_for_review jobs: - build-and-test: - if: github.event.pull_request.draft == false + build: + if: github.event_name == 'push' || github.event.pull_request.draft == false runs-on: 
buildjet-8vcpu-ubuntu-2204 - timeout-minutes: 120 + timeout-minutes: 15 + steps: + - name: Checkout sources + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v6 + with: + go-version-file: "./prover/server/go.mod" + + - name: Cache Go modules + uses: actions/cache@v4 + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('prover/server/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + + - name: Build Go project + run: | + cd prover/server + go build ./... + go build -o light-prover + + - name: Upload prover binary + uses: actions/upload-artifact@v4 + with: + name: light-prover-binary + retention-days: 1 + path: prover/server/light-prover + + test-no-redis: + needs: build + runs-on: buildjet-8vcpu-ubuntu-2204 + timeout-minutes: 90 + strategy: + fail-fast: false + matrix: + test-suite: + - name: "Unit tests" + command: "go test ./prover/... -timeout 60m" + - name: "Worker selection tests" + command: "go test -v -run TestWorkerSelection -timeout 5m" + - name: "Batch operations queue routing tests" + command: "go test -v -run TestBatchOperations -timeout 5m" + steps: + - name: Checkout sources + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v6 + with: + go-version-file: "./prover/server/go.mod" + + - name: Restore Go cache + uses: actions/cache@v4 + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('prover/server/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + + - name: Run ${{ matrix.test-suite.name }} + run: | + cd prover/server + ${{ matrix.test-suite.command }} + + test-with-redis: + needs: build + runs-on: buildjet-8vcpu-ubuntu-2204 + timeout-minutes: 30 + strategy: + fail-fast: false + matrix: + test-suite: + - name: "Redis Queue tests" + command: "go test -v -run TestRedis -timeout 10m" + - name: "Queue cleanup tests" + command: "go test -v -run TestCleanup -timeout 5m" + - name: "Queue processing flow tests" + command: "go test -v -run TestJobProcessingFlow -timeout 5m" + - name: "Failed job status tests" + command: "go test -v -run TestFailedJobStatus -timeout 5m" services: redis: image: redis:7-alpine @@ -45,97 +133,167 @@ jobs: with: go-version-file: "./prover/server/go.mod" - - name: Install Elan - run: | - curl https://raw.githubusercontent.com/leanprover/elan/master/elan-init.sh -sSf | sh -s -- -y -v --default-toolchain leanprover/lean4:v4.2.0 - echo "LAKE_VERSION=$(~/.elan/bin/lake --version)" >> "$GITHUB_ENV" - - - name: Cache dependencies + - name: Restore Go cache uses: actions/cache@v4 with: - path: prover/server/formal-verification/lake-packages - key: "${{ env.LAKE_VERSION }}" + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('prover/server/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- - - name: Download keys for lightweight tests - if: ${{ github.event.pull_request.base.ref == 'main' }} + - name: Run ${{ matrix.test-suite.name }} + env: + TEST_REDIS_URL: redis://localhost:6379/15 run: | cd prover/server - ./scripts/download_keys.sh full + ${{ matrix.test-suite.command }} - - name: Download keys for full tests - if: ${{ github.event.pull_request.base.ref == 'release' }} - run: | - cd prover/server - ./scripts/download_keys.sh full + integration-test-lightweight: + needs: build + if: | + github.event_name == 'push' || + (github.event_name == 'pull_request' && + github.event.pull_request.draft == false && + github.event.pull_request.base.ref == 'main') + runs-on: buildjet-8vcpu-ubuntu-2204 + 
timeout-minutes: 30 + steps: + - name: Checkout sources + uses: actions/checkout@v4 - - name: Build - run: | - cd prover/server - go build ./... + - name: Set up Go + uses: actions/setup-go@v6 + with: + go-version-file: "./prover/server/go.mod" - - name: Unit tests - run: | - cd prover/server - go test ./prover/... -timeout 60m + - name: Restore Go cache + uses: actions/cache@v4 + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('prover/server/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- - - name: Redis Queue tests - env: - TEST_REDIS_URL: redis://localhost:6379/15 + - name: Lightweight integration tests run: | cd prover/server - go test -v -run TestRedis -timeout 10m + go test -run TestLightweight -timeout 15m - - name: Queue cleanup tests - env: - TEST_REDIS_URL: redis://localhost:6379/15 - run: | - cd prover/server - go test -v -run TestCleanup -timeout 5m + integration-test-lightweight-lazy: + needs: build + if: | + github.event_name == 'push' || + (github.event_name == 'pull_request' && + github.event.pull_request.draft == false && + github.event.pull_request.base.ref == 'main') + runs-on: buildjet-8vcpu-ubuntu-2204 + timeout-minutes: 30 + steps: + - name: Checkout sources + uses: actions/checkout@v4 - - name: Worker selection tests - run: | - cd prover/server - go test -v -run TestWorkerSelection -timeout 5m + - name: Set up Go + uses: actions/setup-go@v6 + with: + go-version-file: "./prover/server/go.mod" - - name: Batch operations queue routing tests - run: | - cd prover/server - go test -v -run TestBatchOperations -timeout 5m + - name: Restore Go cache + uses: actions/cache@v4 + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('prover/server/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- - - name: Queue processing flow tests - env: - TEST_REDIS_URL: redis://localhost:6379/15 + - name: Lightweight lazy loading integration tests run: | cd prover/server - go test -v -run TestJobProcessingFlow -timeout 5m + go test -run TestLightweightLazy -timeout 15m - - name: Failed job status tests - env: - TEST_REDIS_URL: redis://localhost:6379/15 - run: | - cd prover/server - go test -v -run TestFailedJobStatus -timeout 5m + integration-test-full: + needs: build + if: | + github.event_name == 'pull_request' && + github.event.pull_request.draft == false && + startsWith(github.event.pull_request.base.ref, 'release') + runs-on: buildjet-8vcpu-ubuntu-2204 + timeout-minutes: 120 + steps: + - name: Checkout sources + uses: actions/checkout@v4 - - name: Lightweight integration tests - if: ${{ github.event.pull_request.base.ref == 'main' }} - run: | - cd prover/server - go test -run TestLightweight -timeout 15m + - name: Set up Go + uses: actions/setup-go@v6 + with: + go-version-file: "./prover/server/go.mod" + + - name: Restore Go cache + uses: actions/cache@v4 + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('prover/server/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- - name: Full integration tests - if: ${{ github.event.pull_request.base.ref == 'release' }} run: | cd prover/server go test -run TestFull -timeout 120m + lean-verification: + needs: build + runs-on: buildjet-8vcpu-ubuntu-2204 + timeout-minutes: 30 + steps: + - name: Checkout sources + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v6 + with: + go-version-file: "./prover/server/go.mod" + + - name: Download prover binary + uses: actions/download-artifact@v4 + with: 
+ name: light-prover-binary + path: prover/server/ + + - name: Make binary executable + run: chmod +x prover/server/light-prover + + - name: Install Elan + run: | + curl https://raw.githubusercontent.com/leanprover/elan/master/elan-init.sh -sSf | sh -s -- -y -v --default-toolchain leanprover/lean4:v4.2.0 + + - name: Get Lake version for cache key + id: lake-version + run: | + echo "version=$(~/.elan/bin/lake --version)" >> "$GITHUB_OUTPUT" + + - name: Cache Lean dependencies + uses: actions/cache@v4 + with: + path: prover/server/formal-verification/lake-packages + key: lean-${{ steps.lake-version.outputs.version }}-${{ hashFiles('prover/server/formal-verification/lakefile.lean') }} + restore-keys: | + lean-${{ steps.lake-version.outputs.version }}- + - name: Extract circuit to Lean run: | cd prover/server - go build ./light-prover extract-circuit --output formal-verification/FormalVerification/Circuit.lean --address-tree-height 40 --compressed-accounts 8 --state-tree-height 32 - - name: Build lean project + - name: Build Lean project run: | cd prover/server/formal-verification ~/.elan/bin/lake exe cache get - ~/.elan/bin/lake build + ~/.elan/bin/lake build \ No newline at end of file diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index c2097d3e6e..fb187c2ad9 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -92,9 +92,9 @@ jobs: skip-components: "redis,disk-cleanup" cache-suffix: "rust" - - name: Build CLI (CI mode - Linux x64 only) + - name: Build CLI run: | - npx nx build-ci @lightprotocol/zk-compression-cli + npx nx build @lightprotocol/zk-compression-cli - name: Run tests for ${{ matrix.group.name }} run: | diff --git a/.github/workflows/sdk-tests.yml b/.github/workflows/sdk-tests.yml index 07961016ba..d1c573306b 100644 --- a/.github/workflows/sdk-tests.yml +++ b/.github/workflows/sdk-tests.yml @@ -71,9 +71,9 @@ jobs: skip-components: "redis,disk-cleanup" cache-suffix: "sdk-tests" - - name: Build CLI (CI mode - Linux x64 only) + - name: Build CLI run: | - npx nx build-ci @lightprotocol/zk-compression-cli + npx nx build @lightprotocol/zk-compression-cli - name: Run sub-tests for ${{ matrix.program }} if: matrix.sub-tests != null diff --git a/cli/package.json b/cli/package.json index e16a60c05c..949dd9d85e 100644 --- a/cli/package.json +++ b/cli/package.json @@ -18,25 +18,8 @@ "/bin", "!bin/cargo-generate", "!/bin/**/*.vkey", - "!/bin/proving-keys/*.key", - "/bin/proving-keys/combined_26_1_1.key", - "/bin/proving-keys/combined_26_1_2.key", - "/bin/proving-keys/combined_26_2_1.key", - "/bin/proving-keys/combined_32_40_1_1.key", - "/bin/proving-keys/combined_32_40_1_2.key", - "/bin/proving-keys/combined_32_40_2_1.key", - "/bin/proving-keys/inclusion_32_1.key", - "/bin/proving-keys/inclusion_32_2.key", - "/bin/proving-keys/inclusion_32_3.key", - "/bin/proving-keys/inclusion_32_4.key", - "/bin/proving-keys/mainnet_inclusion_26_1.key", - "/bin/proving-keys/mainnet_inclusion_26_2.key", - "/bin/proving-keys/mainnet_inclusion_26_3.key", - "/bin/proving-keys/mainnet_inclusion_26_4.key", - "/bin/proving-keys/non-inclusion_26_1.key", - "/bin/proving-keys/non-inclusion_26_2.key", - "/bin/proving-keys/non-inclusion_40_1.key", - "/bin/proving-keys/non-inclusion_40_2.key", + "!/bin/proving-keys", + "!/bin/prover-*", "/dist", "/test_bin", "./config.json", @@ -109,13 +92,10 @@ "topicSeparator": "" }, "scripts": { - "add-bins": "./scripts/copyLocalProgramBinaries.sh && scripts/buildProver.sh", - "add-bins-ci": "./scripts/copyLocalProgramBinaries.sh && 
scripts/buildProver.sh --ci", - "add-bins-release": "./scripts/copyLocalProgramBinaries.sh && scripts/buildProver.sh --release-only", + "add-bins": "./scripts/copyLocalProgramBinaries.sh", "postinstall": "[ -d ./bin ] && find ./bin -type f -exec chmod +x {} + || echo 'No bin directory found, skipping chmod'", "build": "shx rm -rf dist && pnpm tsc -p tsconfig.json && pnpm tsc -p tsconfig.test.json && pnpm add-bins", - "build-ci": "shx rm -rf dist && pnpm tsc -p tsconfig.json && pnpm tsc -p tsconfig.test.json && pnpm add-bins-ci", - "build-release": "shx rm -rf dist && pnpm tsc -p tsconfig.json && pnpm tsc -p tsconfig.test.json && pnpm add-bins-release", + "build-release": "shx rm -rf dist && pnpm tsc -p tsconfig.json && pnpm tsc -p tsconfig.test.json && pnpm add-bins", "format": "pnpm prettier --write \"src/**/*.{ts,js}\" \"test/**/*.{ts,js}\" -w", "format:check": "pnpm prettier \"src/**/*{ts,js}\" \"test/**/*.{ts,js}\" --check", "lint": "eslint .", @@ -140,7 +120,6 @@ "kill": "killall solana-test-validator || true && killall solana-test-val || true && sleep 1", "test-cli": "pnpm test-config && pnpm kill", "test": "pnpm kill && pnpm test-cli && pnpm test-utils && pnpm test-create-mint && pnpm test-mint-to && pnpm test-transfer && pnpm test-merge-token-accounts && pnpm test-create-token-pool && pnpm test-compress-spl && pnpm test-decompress-spl && pnpm test-token-balance && pnpm test-compress-sol && pnpm test-balance && pnpm test-decompress-sol && pnpm test-approve-and-mint-to && pnpm test-test-validator", - "test-ci": "pnpm kill && pnpm test-cli && pnpm test-utils && pnpm test-create-mint && pnpm test-mint-to && pnpm test-transfer && pnpm test-merge-token-accounts && pnpm test-create-token-pool && pnpm test-compress-spl && pnpm test-decompress-spl && pnpm test-token-balance && pnpm test-compress-sol && pnpm test-balance && pnpm test-decompress-sol && pnpm test-approve-and-mint-to && pnpm test-test-validator", "install-local": "pnpm build && pnpm global remove @lightprotocol/zk-compression-cli || true && pnpm global add $PWD", "version": "oclif readme && git add README.md" }, diff --git a/cli/scripts/buildProver.sh b/cli/scripts/buildProver.sh deleted file mode 100755 index 4bffd6c0ea..0000000000 --- a/cli/scripts/buildProver.sh +++ /dev/null @@ -1,101 +0,0 @@ -#!/usr/bin/env sh - -set -eux - -build_prover() { - GOOS=$1 GOARCH=$2 go build -o "$3" -} - -# Parse command line arguments -RELEASE_ONLY=false -CI_MODE=false -while [ $# -gt 0 ]; do - case $1 in - --release-only) - RELEASE_ONLY=true - shift - ;; - --ci) - CI_MODE=true - shift - ;; - *) - echo "Unknown option: $1" - echo "Usage: $0 [--release-only] [--ci]" - exit 1 - ;; - esac -done - -root_dir="$(git rev-parse --show-toplevel)" -gnark_dir="${root_dir}/prover/server" -out_dir="${root_dir}/cli/bin" -cli_dir="${root_dir}/cli" - -if [ ! -e "$out_dir" ]; then - mkdir -p "$out_dir" -fi - -# Check if proving keys exist before copying -if [ ! 
-d "${gnark_dir}/proving-keys" ] || [ -z "$(ls -A "${gnark_dir}/proving-keys" 2>/dev/null)" ]; then - echo "ERROR: Proving keys not found at ${gnark_dir}/proving-keys" - echo "Please run: ./prover/server/scripts/download_keys.sh light" - exit 1 -fi - -# Create proving-keys directory in output -mkdir -p "$out_dir/proving-keys" - -if [ "$RELEASE_ONLY" = true ]; then - echo "Release mode: copying only keys listed in package.json" - # Dynamically read .key files from package.json files field - # Extract all lines containing "/bin/proving-keys/" and ".key" - key_files=$(node -e " -const pkg = require('${cli_dir}/package.json'); -const keyFiles = pkg.files - .filter(f => f.includes('/bin/proving-keys/') && f.endsWith('.key')) - .map(f => f.split('/').pop()); -console.log(keyFiles.join(' ')); -") - - # Copy only the specified .key files - for key_file in $key_files; do - if [ -f "${gnark_dir}/proving-keys/${key_file}" ]; then - cp "${gnark_dir}/proving-keys/${key_file}" "$out_dir/proving-keys/${key_file}" - echo "Copied (release): ${key_file}" - else - echo "WARNING: ${key_file} not found in ${gnark_dir}/proving-keys" - fi - done -else - echo "Development mode: copying ALL .key files" - # Copy ALL .key files from prover directory - for key_file in "${gnark_dir}/proving-keys"/*.key; do - if [ -f "$key_file" ]; then - filename=$(basename "$key_file") - cp "$key_file" "$out_dir/proving-keys/$filename" - echo "Copied (all): $filename" - fi - done -fi - -cd "$gnark_dir" - -if [ "$CI_MODE" = true ]; then - echo "CI mode: Building only Linux x64 prover" - # Linux x64 only (for CI) - build_prover linux amd64 "$out_dir"/prover-linux-x64 -else - echo "Building all prover binaries for release" - # Windows - build_prover windows amd64 "$out_dir"/prover-windows-x64.exe - build_prover windows arm64 "$out_dir"/prover-windows-arm64.exe - - # MacOS - build_prover darwin amd64 "$out_dir"/prover-darwin-x64 - build_prover darwin arm64 "$out_dir"/prover-darwin-arm64 - - # Linux - build_prover linux amd64 "$out_dir"/prover-linux-x64 - build_prover linux arm64 "$out_dir"/prover-linux-arm64 -fi diff --git a/cli/src/commands/start-prover/index.ts b/cli/src/commands/start-prover/index.ts index ea8e54947e..07c29bee6d 100644 --- a/cli/src/commands/start-prover/index.ts +++ b/cli/src/commands/start-prover/index.ts @@ -15,41 +15,6 @@ class StartProver extends Command { required: false, default: 3001, }), - "run-mode": Flags.string({ - description: - "Specify the running mode (local-rpc, forester, forester-test, rpc, or full). Default: local-rpc", - options: [ - "local-rpc", - "rpc", - "forester", - "forester-test", - "full", - "full-test", - ], - required: false, - }), - circuit: Flags.string({ - description: "Specify individual circuits to enable.", - options: [ - "inclusion", - "non-inclusion", - "combined", - "append", - "update", - "address-append", - "append-test", - "update-test", - "address-append-test", - ], - multiple: true, - required: false, - }), - force: Flags.boolean({ - description: - "Force restart the prover even if one is already running with the same flags.", - required: false, - default: false, - }), redisUrl: Flags.string({ description: "Redis URL to use for the prover (e.g. 
redis://localhost:6379)", @@ -62,24 +27,10 @@ class StartProver extends Command { const loader = new CustomLoader("Performing setup tasks...\n"); loader.start(); - if (!flags["run-mode"] && !flags["circuit"]) { - this.log("Please specify --run-mode or --circuit."); - return; - } - const proverPort = flags["prover-port"] || 3001; - const force = flags["force"] || false; const redisUrl = flags["redisUrl"] || process.env.REDIS_URL || undefined; - // TODO: remove this workaround. - // Force local-rpc mode when rpc is specified - let runMode = flags["run-mode"]; - if (runMode === "rpc") { - runMode = "local-rpc"; - this.log("Note: Running in local-rpc mode instead of rpc mode"); - } - - await startProver(proverPort, runMode, flags["circuit"], force, redisUrl); + await startProver(proverPort, redisUrl); const healthy = await healthCheck(proverPort, 10, 1000); loader.stop(); diff --git a/cli/src/utils/downloadProverBinary.ts b/cli/src/utils/downloadProverBinary.ts new file mode 100644 index 0000000000..cf690ba01b --- /dev/null +++ b/cli/src/utils/downloadProverBinary.ts @@ -0,0 +1,155 @@ +import fs from "fs"; +import path from "path"; +import https from "https"; +import http from "http"; +import { pipeline } from "stream/promises"; + +const PROVER_VERSION = "1.0.4"; +const GITHUB_RELEASES_BASE_URL = `https://github.com/Lightprotocol/light-protocol/releases/download/light-prover-v${PROVER_VERSION}`; +const MAX_REDIRECTS = 10; + +interface DownloadOptions { + maxRetries?: number; + retryDelay?: number; +} + +export async function downloadProverBinary( + binaryPath: string, + binaryName: string, + options: DownloadOptions = {}, +): Promise { + const { maxRetries = 3, retryDelay = 2000 } = options; + const url = `${GITHUB_RELEASES_BASE_URL}/${binaryName}`; + + console.log(`\nDownloading prover binary: ${binaryName}`); + console.log(` From: ${url}`); + console.log(` To: ${binaryPath}\n`); + + const dir = path.dirname(binaryPath); + if (!fs.existsSync(dir)) { + fs.mkdirSync(dir, { recursive: true }); + } + + let lastError: Error | null = null; + + for (let attempt = 1; attempt <= maxRetries; attempt++) { + try { + await downloadFile(url, binaryPath); + + if (process.platform !== "win32") { + fs.chmodSync(binaryPath, 0o755); + } + + console.log("\nProver binary downloaded.\n"); + return; + } catch (error) { + lastError = error as Error; + console.error( + `\nDownload attempt ${attempt}/${maxRetries} failed: ${lastError.message}`, + ); + + if (attempt < maxRetries) { + console.log(` Retrying in ${retryDelay / 1000}s...\n`); + await new Promise((resolve) => setTimeout(resolve, retryDelay)); + } + } + } + + throw new Error( + `Failed to download prover binary after ${maxRetries} attempts: ${lastError?.message}`, + ); +} + +async function downloadFile( + url: string, + outputPath: string, + redirectDepth: number = 0, +): Promise { + return new Promise((resolve, reject) => { + const protocol = url.startsWith("https") ? 
https : http; + + const request = protocol.get(url, (response) => { + if ( + response.statusCode === 301 || + response.statusCode === 302 || + response.statusCode === 307 || + response.statusCode === 308 + ) { + const redirectUrl = response.headers.location; + if (!redirectUrl) { + return reject(new Error("Redirect without location header")); + } + if (redirectDepth >= MAX_REDIRECTS) { + return reject( + new Error( + `Too many redirects: exceeded maximum of ${MAX_REDIRECTS} redirects`, + ), + ); + } + return downloadFile(redirectUrl, outputPath, redirectDepth + 1).then( + resolve, + reject, + ); + } + + if (response.statusCode !== 200) { + return reject( + new Error(`HTTP ${response.statusCode}: ${response.statusMessage}`), + ); + } + + const totalBytes = parseInt( + response.headers["content-length"] || "0", + 10, + ); + let downloadedBytes = 0; + let lastProgress = 0; + + const fileStream = fs.createWriteStream(outputPath); + + response.on("data", (chunk: Buffer) => { + downloadedBytes += chunk.length; + + if (totalBytes > 0) { + const progress = Math.floor((downloadedBytes / totalBytes) * 100); + if (progress >= lastProgress + 5) { + lastProgress = progress; + const mb = (downloadedBytes / 1024 / 1024).toFixed(1); + const totalMb = (totalBytes / 1024 / 1024).toFixed(1); + process.stdout.write( + `\r Progress: ${progress}% (${mb}MB / ${totalMb}MB)`, + ); + } + } + }); + + pipeline(response, fileStream) + .then(() => { + if (totalBytes > 0) { + process.stdout.write("\r Progress: 100% - Download complete\n"); + } + resolve(); + }) + .catch((error) => { + fs.unlinkSync(outputPath); + reject(error); + }); + }); + + request.on("error", (error) => { + if (fs.existsSync(outputPath)) { + fs.unlinkSync(outputPath); + } + reject(error); + }); + + request.setTimeout(60000, () => { + request.destroy(); + reject(new Error("Download timeout")); + }); + }); +} + +export function getProverVersion(): string { + return PROVER_VERSION; +} diff --git a/cli/src/utils/initTestEnv.ts b/cli/src/utils/initTestEnv.ts index b500ef82a7..cfb5138144 100644 --- a/cli/src/utils/initTestEnv.ts +++ b/cli/src/utils/initTestEnv.ts @@ -152,7 +152,7 @@ export async function initTestEnv({ setConfig(config); try { // TODO: check if using redisUrl is better here. 
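
Note on the hunk below: `startProver` now takes only the port and an optional Redis URL; run-mode and circuit selection are gone, and the prover binary is fetched lazily. A minimal usage sketch, assuming only the helpers introduced in this diff (`startProver` from `processProverServer.ts`, which calls `downloadProverBinary` when the binary is missing); the port value and env-var fallback shown are illustrative:

```typescript
import { startProver } from "./processProverServer";

// Start the prover on port 3001. If the prover binary is missing on disk,
// startProver downloads it from the light-prover-v1.0.4 GitHub release
// (with retries and redirect handling), marks it executable, then spawns
// it with --keys-dir, --prover-address 0.0.0.0:3001 and --auto-download true.
async function main(): Promise<void> {
  const proverPort = 3001;
  // Optional: forwarded to the prover as --redis-url when set.
  const redisUrl = process.env.REDIS_URL; // e.g. redis://localhost:6379
  await startProver(proverPort, redisUrl);
}

main().catch((err) => {
  console.error("Failed to start prover:", err);
  process.exit(1);
});
```
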
- await startProver(proverPort, proverRunMode, circuits); + await startProver(proverPort); } catch (error) { console.error("Failed to start prover:", error); // Prover logs will be automatically displayed by spawnBinary in process.ts diff --git a/cli/src/utils/processProverServer.ts b/cli/src/utils/processProverServer.ts index e46447e2d0..9b6852f8bd 100644 --- a/cli/src/utils/processProverServer.ts +++ b/cli/src/utils/processProverServer.ts @@ -1,4 +1,5 @@ import path from "path"; +import fs from "fs"; import { killProcess, killProcessByPort, @@ -7,6 +8,7 @@ import { } from "./process"; import { LIGHT_PROVER_PROCESS_NAME, BASE_PATH } from "./constants"; import find from "find-process"; +import { downloadProverBinary } from "./downloadProverBinary"; const KEYS_DIR = "proving-keys/"; @@ -85,47 +87,41 @@ export async function isProverRunningWithFlags( return found; } -export async function startProver( - proverPort: number, - runMode: string | undefined, - circuits: string[] | undefined = [], - force: boolean = false, - redisUrl?: string, -) { - if ( - !force && - (await isProverRunningWithFlags(runMode, circuits, proverPort)) - ) { +/** + * Ensures the prover binary exists, downloading it if necessary + */ +async function ensureProverBinary(): Promise { + const binaryPath = getProverPathByArch(); + const binaryName = getProverNameByArch(); + + if (fs.existsSync(binaryPath)) { return; } - console.log("Kill existing prover process..."); + console.log("Prover binary not found. Downloading..."); + + try { + await downloadProverBinary(binaryPath, binaryName); + } catch (error) { + throw new Error( + `Failed to download prover binary: ${error instanceof Error ? error.message : String(error)}\n` + + `Please download manually from: https://github.com/Lightprotocol/light-protocol/releases`, + ); + } +} + +export async function startProver(proverPort: number, redisUrl?: string) { + await ensureProverBinary(); + await killProver(); await killProcessByPort(proverPort); const keysDir = path.join(path.resolve(__dirname, BASE_PATH), KEYS_DIR); const args = ["start"]; + args.push("--keys-dir", keysDir); args.push("--prover-address", `0.0.0.0:${proverPort}`); - if (runMode != null) { - args.push("--run-mode", runMode); - } - - for (const circuit of circuits) { - args.push("--circuit", circuit); - } - - if (runMode != null) { - console.log(`Starting prover in ${runMode} mode...`); - } else if (circuits && circuits.length > 0) { - console.log(`Starting prover with circuits: ${circuits.join(", ")}...`); - } - - if ((!circuits || circuits.length === 0) && runMode == null) { - runMode = "local-rpc"; - args.push("--run-mode", runMode); - console.log(`Starting prover with fallback ${runMode} mode...`); - } + args.push("--auto-download", "true"); if (redisUrl) { args.push("--redis-url", redisUrl); @@ -137,16 +133,26 @@ export async function startProver( } export function getProverNameByArch(): string { - const platform = process.platform; - const arch = process.arch; + const nodePlatform = process.platform; + const nodeArch = process.arch; - if (!platform || !arch) { + if (!nodePlatform || !nodeArch) { throw new Error("Unsupported platform or architecture"); } - let binaryName = `prover-${platform}-${arch}`; + let goPlatform: string = nodePlatform; + let goArch: string = nodeArch; + + if (nodeArch === "x64") { + goArch = "amd64"; + } + if (nodePlatform === "win32") { + goPlatform = "windows"; + } + + let binaryName = `prover-${goPlatform}-${goArch}`; - if (platform.toString() === "windows") { + if (goPlatform 
=== "windows") { binaryName += ".exe"; } return binaryName; diff --git a/forester/tests/e2e_test.rs b/forester/tests/e2e_test.rs index dc41962442..526533970f 100644 --- a/forester/tests/e2e_test.rs +++ b/forester/tests/e2e_test.rs @@ -18,7 +18,7 @@ use light_batched_merkle_tree::{ }; use light_client::{ indexer::{AddressWithTree, GetCompressedTokenAccountsByOwnerOrDelegateOptions, Indexer}, - local_test_validator::{LightValidatorConfig, ProverConfig}, + local_test_validator::LightValidatorConfig, rpc::{LightClient, LightClientConfig, Rpc}, }; use light_compressed_account::{ @@ -243,8 +243,8 @@ async fn e2e_test() { if test_mode == TestMode::Local { init(Some(LightValidatorConfig { enable_indexer: true, + enable_prover: false, wait_time: 60, - prover_config: None, sbf_programs: vec![( "FNt7byTHev1k5x2cXZLBr8TdWiC3zoP5vcnZR4P682Uy".to_string(), "../target/deploy/create_address_test_program.so".to_string(), @@ -252,7 +252,7 @@ async fn e2e_test() { limit_ledger_size: None, })) .await; - spawn_prover(ProverConfig::default()).await; + spawn_prover().await; } let mut rpc = setup_rpc_connection(&env.protocol.forester).await; diff --git a/forester/tests/legacy/address_v2_test.rs b/forester/tests/legacy/address_v2_test.rs index 96ed4e2c58..22c8ad6700 100644 --- a/forester/tests/legacy/address_v2_test.rs +++ b/forester/tests/legacy/address_v2_test.rs @@ -10,7 +10,7 @@ use light_batched_merkle_tree::{ }; use light_client::{ indexer::AddressWithTree, - local_test_validator::{LightValidatorConfig, ProverConfig}, + local_test_validator::LightValidatorConfig, rpc::{merkle_tree::MerkleTreeExt, LightClient, LightClientConfig, Rpc}, }; use light_compressed_account::{ @@ -54,8 +54,8 @@ async fn test_create_v2_address() { init(Some(LightValidatorConfig { enable_indexer: true, + enable_prover: true, wait_time: 90, - prover_config: Some(ProverConfig::default()), sbf_programs: vec![( "FNt7byTHev1k5x2cXZLBr8TdWiC3zoP5vcnZR4P682Uy".to_string(), "../target/deploy/create_address_test_program.so".to_string(), diff --git a/forester/tests/legacy/batched_address_test.rs b/forester/tests/legacy/batched_address_test.rs index fe53f8c82e..7b6db499d7 100644 --- a/forester/tests/legacy/batched_address_test.rs +++ b/forester/tests/legacy/batched_address_test.rs @@ -8,7 +8,7 @@ use light_batched_merkle_tree::{ }; use light_client::{ indexer::{photon_indexer::PhotonIndexer, AddressMerkleTreeAccounts, Indexer}, - local_test_validator::{LightValidatorConfig, ProverConfig}, + local_test_validator::LightValidatorConfig, rpc::{client::RpcUrl, LightClient, LightClientConfig, Rpc}, }; use light_program_test::{accounts::test_accounts::TestAccounts, indexer::TestIndexer}; @@ -33,8 +33,8 @@ mod test_utils; async fn test_address_batched() { init(Some(LightValidatorConfig { enable_indexer: true, + enable_prover: true, wait_time: 90, - prover_config: Some(ProverConfig::default()), sbf_programs: vec![( "FNt7byTHev1k5x2cXZLBr8TdWiC3zoP5vcnZR4P682Uy".to_string(), "../target/deploy/create_address_test_program.so".to_string(), diff --git a/forester/tests/legacy/batched_state_async_indexer_test.rs b/forester/tests/legacy/batched_state_async_indexer_test.rs index f6fc8d6ef1..38c7b95362 100644 --- a/forester/tests/legacy/batched_state_async_indexer_test.rs +++ b/forester/tests/legacy/batched_state_async_indexer_test.rs @@ -11,7 +11,7 @@ use light_client::{ photon_indexer::PhotonIndexer, AddressWithTree, GetCompressedTokenAccountsByOwnerOrDelegateOptions, Indexer, }, - local_test_validator::{LightValidatorConfig, ProverConfig}, + 
local_test_validator::LightValidatorConfig, rpc::{LightClient, LightClientConfig, Rpc}, }; use light_compressed_account::{ @@ -78,13 +78,13 @@ async fn test_state_indexer_async_batched() { init(Some(LightValidatorConfig { enable_indexer: true, + enable_prover: true, wait_time: 30, - prover_config: None, sbf_programs: vec![], limit_ledger_size: None, })) .await; - spawn_prover(ProverConfig::default()).await; + spawn_prover().await; let env = TestAccounts::get_local_test_validator_accounts(); let mut config = forester_config(); diff --git a/forester/tests/legacy/batched_state_indexer_test.rs b/forester/tests/legacy/batched_state_indexer_test.rs index 89d3b1faf9..32d3a229a8 100644 --- a/forester/tests/legacy/batched_state_indexer_test.rs +++ b/forester/tests/legacy/batched_state_indexer_test.rs @@ -11,7 +11,7 @@ use light_batched_merkle_tree::{ }; use light_client::{ indexer::{photon_indexer::PhotonIndexer, Indexer, IndexerRpcConfig, RetryConfig}, - local_test_validator::{LightValidatorConfig, ProverConfig}, + local_test_validator::LightValidatorConfig, rpc::{LightClient, LightClientConfig, Rpc}, }; use light_compressed_account::TreeType; @@ -39,8 +39,8 @@ async fn test_state_indexer_batched() { init(Some(LightValidatorConfig { enable_indexer: true, + enable_prover: true, wait_time: 90, - prover_config: Some(ProverConfig::default()), sbf_programs: vec![], limit_ledger_size: None, })) diff --git a/forester/tests/legacy/batched_state_test.rs b/forester/tests/legacy/batched_state_test.rs index 89fb391f11..559af5b181 100644 --- a/forester/tests/legacy/batched_state_test.rs +++ b/forester/tests/legacy/batched_state_test.rs @@ -10,7 +10,7 @@ use light_batched_merkle_tree::{ merkle_tree::BatchedMerkleTreeAccount, queue::BatchedQueueAccount, }; use light_client::{ - local_test_validator::{LightValidatorConfig, ProverConfig}, + local_test_validator::LightValidatorConfig, rpc::{client::RpcUrl, LightClient, LightClientConfig, Rpc}, }; use light_compressed_account::TreeType; @@ -43,8 +43,8 @@ async fn test_state_batched() { init(Some(LightValidatorConfig { enable_indexer: false, + enable_prover: true, wait_time: 30, - prover_config: Some(ProverConfig::default()), sbf_programs: vec![], limit_ledger_size: None, })) diff --git a/forester/tests/legacy/e2e_test.rs b/forester/tests/legacy/e2e_test.rs index 386e31cb8f..9dc712eb3f 100644 --- a/forester/tests/legacy/e2e_test.rs +++ b/forester/tests/legacy/e2e_test.rs @@ -11,7 +11,7 @@ use forester_utils::{ }; use light_client::{ indexer::{AddressMerkleTreeAccounts, StateMerkleTreeAccounts}, - local_test_validator::{LightValidatorConfig, ProverConfig}, + local_test_validator::LightValidatorConfig, rpc::{client::RpcUrl, LightClient, LightClientConfig, Rpc, RpcError}, }; use light_program_test::{accounts::test_accounts::TestAccounts, indexer::TestIndexer}; @@ -35,8 +35,8 @@ use test_utils::*; async fn test_epoch_monitor_with_2_foresters() { init(Some(LightValidatorConfig { enable_indexer: false, + enable_prover: true, wait_time: 90, - prover_config: Some(ProverConfig::default()), sbf_programs: vec![], limit_ledger_size: None, })) @@ -381,8 +381,8 @@ async fn test_epoch_double_registration() { init(Some(LightValidatorConfig { enable_indexer: false, + enable_prover: true, wait_time: 90, - prover_config: Some(ProverConfig::default()), sbf_programs: vec![], limit_ledger_size: None, })) diff --git a/forester/tests/legacy/e2e_v1_test.rs b/forester/tests/legacy/e2e_v1_test.rs index a706094eca..050ece14af 100644 --- a/forester/tests/legacy/e2e_v1_test.rs +++ 
b/forester/tests/legacy/e2e_v1_test.rs @@ -11,7 +11,7 @@ use forester_utils::{ }; use light_client::{ indexer::{AddressMerkleTreeAccounts, StateMerkleTreeAccounts}, - local_test_validator::{LightValidatorConfig, ProverConfig}, + local_test_validator::LightValidatorConfig, rpc::{client::RpcUrl, LightClient, LightClientConfig, Rpc, RpcError}, }; use light_program_test::{accounts::test_accounts::TestAccounts, indexer::TestIndexer}; @@ -36,8 +36,8 @@ use test_utils::*; async fn test_e2e_v1() { init(Some(LightValidatorConfig { enable_indexer: false, + enable_prover: true, wait_time: 90, - prover_config: Some(ProverConfig::default()), sbf_programs: vec![], limit_ledger_size: None, })) @@ -378,8 +378,8 @@ async fn test_epoch_double_registration() { init(Some(LightValidatorConfig { enable_indexer: false, + enable_prover: true, wait_time: 90, - prover_config: Some(ProverConfig::default()), sbf_programs: vec![], limit_ledger_size: None, })) diff --git a/forester/tests/test_batch_append_spent.rs b/forester/tests/test_batch_append_spent.rs index 34abb37c60..38caf3ab02 100644 --- a/forester/tests/test_batch_append_spent.rs +++ b/forester/tests/test_batch_append_spent.rs @@ -46,8 +46,8 @@ async fn test_batch_sequence() { init(Some(LightValidatorConfig { enable_indexer: false, + enable_prover: true, wait_time: 10, - prover_config: None, sbf_programs: vec![], limit_ledger_size: None, })) diff --git a/program-tests/batched-merkle-tree-test/tests/merkle_tree.rs b/program-tests/batched-merkle-tree-test/tests/merkle_tree.rs index 24f41dce39..f06ad2e867 100644 --- a/program-tests/batched-merkle-tree-test/tests/merkle_tree.rs +++ b/program-tests/batched-merkle-tree-test/tests/merkle_tree.rs @@ -37,7 +37,7 @@ use light_compressed_account::{ }; use light_hasher::{Hasher, Poseidon}; use light_merkle_tree_reference::MerkleTree; -use light_prover_client::prover::{spawn_prover, ProverConfig}; +use light_prover_client::prover::spawn_prover; use light_test_utils::mock_batched_forester::{ MockBatchedAddressForester, MockBatchedForester, MockTxEvent, }; @@ -446,7 +446,7 @@ pub fn simulate_transaction( #[serial] #[tokio::test] async fn test_simulate_transactions() { - spawn_prover(ProverConfig::default()).await; + spawn_prover().await; let mut mock_indexer = MockBatchedForester::<{ DEFAULT_BATCH_STATE_TREE_HEIGHT as usize }>::default(); @@ -885,7 +885,7 @@ pub fn get_random_leaf(rng: &mut StdRng, active_leaves: &mut Vec<[u8; 32]>) -> ( #[serial] #[tokio::test] async fn test_e2e() { - spawn_prover(ProverConfig::default()).await; + spawn_prover().await; let mut mock_indexer = MockBatchedForester::<{ DEFAULT_BATCH_STATE_TREE_HEIGHT as usize }>::default(); @@ -1580,7 +1580,7 @@ pub fn get_rnd_bytes(rng: &mut StdRng) -> [u8; 32] { #[serial] #[tokio::test] async fn test_fill_state_queues_completely() { - spawn_prover(ProverConfig::default()).await; + spawn_prover().await; let mut current_slot = 1; let roothistory_capacity = vec![17, 80]; for root_history_capacity in roothistory_capacity { @@ -1980,7 +1980,7 @@ async fn test_fill_state_queues_completely() { #[serial] #[tokio::test] async fn test_fill_address_tree_completely() { - spawn_prover(ProverConfig::default()).await; + spawn_prover().await; let mut current_slot = 1; let roothistory_capacity = vec![17, 80]; // for root_history_capacity in roothistory_capacity { diff --git a/program-tests/compressed-token-test/tests/test.rs b/program-tests/compressed-token-test/tests/test.rs index 29f45912a0..f5b4d221af 100644 --- a/program-tests/compressed-token-test/tests/test.rs +++ 
b/program-tests/compressed-token-test/tests/test.rs @@ -44,7 +44,7 @@ use light_program_test::{ utils::assert::assert_rpc_error, LightProgramTest, ProgramTestConfig, }; -use light_prover_client::prover::{spawn_prover, ProverConfig, ProverMode}; +use light_prover_client::prover::spawn_prover; use light_sdk::token::{AccountState, TokenDataWithMerkleContext}; use light_system_program::{errors::SystemProgramError, utils::get_sol_pool_pda}; use light_test_utils::{ @@ -594,11 +594,7 @@ pub async fn add_token_pool( #[serial] #[tokio::test] async fn test_wrapped_sol() { - spawn_prover(ProverConfig { - run_mode: Some(ProverMode::ForesterTest), - circuits: vec![], - }) - .await; + spawn_prover().await; // is token 22 fails with Instruction: InitializeAccount, Program log: Error: Invalid Mint line 216 for is_token_22 in [false] { let mut rpc = LightProgramTest::new(ProgramTestConfig::new(false, None)) @@ -888,7 +884,7 @@ async fn test_5_mint_to() { #[tokio::test] async fn test_10_mint_to() { let mut rng = thread_rng(); - // Make sure that the tokal token supply does not exceed `u64::MAX`. + // Make sure that the total token supply does not exceed `u64::MAX`. let amounts: Vec = (0..10).map(|_| rng.gen_range(0..(u64::MAX / 10))).collect(); test_mint_to(amounts, 1, Some(1_000_000)).await } @@ -1254,11 +1250,7 @@ async fn test_mint_to_failing() { #[serial] #[tokio::test] async fn test_transfers() { - spawn_prover(ProverConfig { - run_mode: Some(ProverMode::ForesterTest), - circuits: vec![], - }) - .await; + spawn_prover().await; let possible_inputs = [1, 2, 3, 4, 8]; for input_num in possible_inputs { for output_num in 1..8 { @@ -1337,7 +1329,7 @@ async fn perform_transfer_test( perform_transfer_22_test(inputs, outputs, amount, false, start_prover_server, false).await; } -// TODO: reexport these types from ligth-program test. +// TODO: reexport these types from light-program test. 
use light_batched_merkle_tree::{ initialize_address_tree::InitAddressTreeAccountsInstructionData, initialize_state_tree::InitStateTreeAccountsInstructionData, @@ -1421,12 +1413,7 @@ async fn perform_transfer_22_test( #[serial] #[tokio::test] async fn test_decompression() { - spawn_prover(ProverConfig { - // is overkill but we run everything on ForesterTest - run_mode: Some(ProverMode::ForesterTest), - circuits: vec![], - }) - .await; + spawn_prover().await; for is_token_22 in [false, true] { println!("is_token_22: {}", is_token_22); let mut context = LightProgramTest::new(ProgramTestConfig::new(false, None)) @@ -1570,11 +1557,7 @@ pub async fn assert_minted_to_all_token_pools( #[serial] #[tokio::test] async fn test_mint_to_and_burn_from_all_token_pools() { - spawn_prover(ProverConfig { - run_mode: Some(ProverMode::ForesterTest), - circuits: vec![], - }) - .await; + spawn_prover().await; for is_token_22 in [false, true] { let mut rpc = LightProgramTest::new(ProgramTestConfig::new(false, None)) .await @@ -1649,7 +1632,7 @@ async fn test_mint_to_and_burn_from_all_token_pools() { #[serial] #[tokio::test] async fn test_multiple_decompression() { - spawn_prover(ProverConfig::default()).await; + spawn_prover().await; let rng = &mut thread_rng(); for is_token_22 in [false, true] { println!("is_token_22: {}", is_token_22); @@ -2837,11 +2820,7 @@ async fn test_revoke_failing() { #[serial] #[tokio::test] async fn test_burn() { - spawn_prover(ProverConfig { - run_mode: Some(ProverMode::ForesterTest), - circuits: vec![], - }) - .await; + spawn_prover().await; for is_token_22 in [false, true] { println!("is_token_22: {}", is_token_22); let mut rpc = LightProgramTest::new(ProgramTestConfig::new(false, None)) @@ -3119,11 +3098,7 @@ async fn test_burn() { #[serial] #[tokio::test] async fn failing_tests_burn() { - spawn_prover(ProverConfig { - run_mode: Some(ProverMode::ForesterTest), - circuits: vec![], - }) - .await; + spawn_prover().await; for is_token_22 in [false, true] { let mut rpc = LightProgramTest::new(ProgramTestConfig::new(false, None)) .await @@ -3487,11 +3462,7 @@ async fn failing_tests_burn() { /// 4. Freeze delegated tokens /// 5. 
Thaw delegated tokens async fn test_freeze_and_thaw(mint_amount: u64, delegated_amount: u64) { - spawn_prover(ProverConfig { - run_mode: Some(ProverMode::ForesterTest), - circuits: vec![], - }) - .await; + spawn_prover().await; for is_token_22 in [false, true] { let mut rpc = LightProgramTest::new(ProgramTestConfig::new(false, None)) .await @@ -3680,11 +3651,7 @@ async fn test_freeze_and_thaw_10000() { #[serial] #[tokio::test] async fn test_failing_freeze() { - spawn_prover(ProverConfig { - run_mode: Some(ProverMode::ForesterTest), - circuits: vec![], - }) - .await; + spawn_prover().await; for is_token_22 in [false, true] { let mut rpc = LightProgramTest::new(ProgramTestConfig::new(false, None)) .await @@ -3948,11 +3915,7 @@ async fn test_failing_freeze() { #[serial] #[tokio::test] async fn test_failing_thaw() { - spawn_prover(ProverConfig { - run_mode: Some(ProverMode::ForesterTest), - circuits: vec![], - }) - .await; + spawn_prover().await; for is_token_22 in [false, true] { let mut rpc = LightProgramTest::new(ProgramTestConfig::new(false, None)) .await @@ -4246,11 +4209,7 @@ async fn test_failing_thaw() { #[serial] #[tokio::test] async fn test_failing_decompression() { - spawn_prover(ProverConfig { - run_mode: Some(ProverMode::ForesterTest), - circuits: vec![], - }) - .await; + spawn_prover().await; for is_token_22 in [false, true] { let mut context = LightProgramTest::new(ProgramTestConfig::new(false, None)) .await @@ -5486,8 +5445,8 @@ async fn test_transfer_with_transaction_hash() { async fn test_transfer_with_photon_and_batched_tree() { spawn_validator(LightValidatorConfig { enable_indexer: false, + enable_prover: true, wait_time: 15, - prover_config: Some(ProverConfig::default()), sbf_programs: vec![], limit_ledger_size: None, }) diff --git a/program-tests/e2e-test/tests/test.rs b/program-tests/e2e-test/tests/test.rs index dcf570738a..317d302526 100644 --- a/program-tests/e2e-test/tests/test.rs +++ b/program-tests/e2e-test/tests/test.rs @@ -5,7 +5,6 @@ use light_batched_merkle_tree::{ initialize_state_tree::InitStateTreeAccountsInstructionData, }; use light_program_test::{indexer::TestIndexer, LightProgramTest, ProgramTestConfig}; -use light_prover_client::prover::ProverConfig; use light_registry::protocol_config::state::ProtocolConfig; use light_test_utils::{ e2e_test_env::{E2ETestEnv, GeneralActionConfig, KeypairActionConfig}, @@ -31,7 +30,6 @@ async fn test_10_all() { config.v2_address_tree_config = Some(address_params); config.protocol_config = protocol_config; config.with_prover = true; - config.prover_config = Some(ProverConfig::default()); let rpc = LightProgramTest::new(config).await.unwrap(); let indexer: TestIndexer = TestIndexer::init_from_acounts( @@ -81,7 +79,6 @@ async fn test_batched_only() { config.v2_address_tree_config = Some(address_params); config.protocol_config = protocol_config; config.with_prover = true; - config.prover_config = Some(ProverConfig::default()); config.additional_programs = Some(vec![( "create_address_test_program", CREATE_ADDRESS_TEST_PROGRAM_ID, @@ -153,7 +150,6 @@ async fn test_10000_all() { config.v2_address_tree_config = Some(address_params); config.protocol_config = protocol_config; config.with_prover = true; - config.prover_config = Some(ProverConfig::default()); config.additional_programs = Some(vec![( "create_address_test_program", CREATE_ADDRESS_TEST_PROGRAM_ID, diff --git a/program-tests/system-cpi-v2-test/tests/event.rs b/program-tests/system-cpi-v2-test/tests/event.rs index 17a3e3b10a..c96746fc02 100644 --- 
a/program-tests/system-cpi-v2-test/tests/event.rs +++ b/program-tests/system-cpi-v2-test/tests/event.rs @@ -531,8 +531,8 @@ async fn generate_photon_test_data_multiple_events() { for num_expected_events in 4..5 { spawn_validator(LightValidatorConfig { enable_indexer: false, + enable_prover: true, wait_time: 10, - prover_config: None, sbf_programs: vec![( create_address_test_program::ID.to_string(), "../../target/deploy/create_address_test_program.so".to_string(), diff --git a/program-tests/system-cpi-v2-test/tests/invoke_cpi_with_read_only.rs b/program-tests/system-cpi-v2-test/tests/invoke_cpi_with_read_only.rs index 1f0b19dddb..8472abdb7a 100644 --- a/program-tests/system-cpi-v2-test/tests/invoke_cpi_with_read_only.rs +++ b/program-tests/system-cpi-v2-test/tests/invoke_cpi_with_read_only.rs @@ -27,7 +27,7 @@ use light_program_test::{ utils::assert::assert_rpc_error, LightProgramTest, ProgramTestConfig, }; -use light_prover_client::prover::{spawn_prover, ProverConfig}; +use light_prover_client::prover::spawn_prover; use light_sdk::{ address::{NewAddressParamsAssigned, ReadOnlyAddress}, instruction::ValidityProof, @@ -47,7 +47,7 @@ use solana_sdk::pubkey::Pubkey; #[serial] #[tokio::test] async fn functional_read_only() { - spawn_prover(ProverConfig::default()).await; + spawn_prover().await; for (batched, is_v2_ix) in [(true, false), (true, true), (false, false), (false, true)] { let config = if batched { let mut config = ProgramTestConfig::default_with_batched_trees(false); @@ -346,7 +346,7 @@ async fn functional_read_only() { #[serial] #[tokio::test] async fn functional_account_infos() { - spawn_prover(ProverConfig::default()).await; + spawn_prover().await; for (batched, is_v2_ix) in [(true, false), (true, true), (false, false), (false, true)].into_iter() { @@ -660,7 +660,7 @@ async fn functional_account_infos() { #[serial] #[tokio::test] async fn create_addresses_with_account_info() { - spawn_prover(ProverConfig::default()).await; + spawn_prover().await; let with_transaction_hash = true; for (batched, is_v2_ix) in [(true, false), (true, true), (false, false), (false, true)].into_iter() @@ -1260,7 +1260,7 @@ async fn create_addresses_with_account_info() { #[serial] #[tokio::test] async fn create_addresses_with_read_only() { - spawn_prover(ProverConfig::default()).await; + spawn_prover().await; let with_transaction_hash = true; for (batched, is_v2_ix) in [(true, false), (true, true), (false, false), (false, true)].into_iter() @@ -2020,7 +2020,7 @@ async fn compress_sol_with_account_info() { #[serial] #[tokio::test] async fn cpi_context_with_read_only() { - spawn_prover(ProverConfig::default()).await; + spawn_prover().await; let with_transaction_hash = false; let batched = true; for is_v2_ix in [true, false].into_iter() { @@ -2318,7 +2318,7 @@ async fn cpi_context_with_read_only() { #[serial] #[tokio::test] async fn cpi_context_with_account_info() { - spawn_prover(ProverConfig::default()).await; + spawn_prover().await; let with_transaction_hash = false; let batched = true; for is_v2_ix in [true, false].into_iter() { @@ -2837,7 +2837,7 @@ fn get_output_account_info(output_merkle_tree_index: u8) -> OutAccountInfo { #[serial] #[tokio::test] async fn test_duplicate_account_in_inputs_and_read_only() { - spawn_prover(ProverConfig::default()).await; + spawn_prover().await; let mut config = ProgramTestConfig::default_with_batched_trees(false); config.with_prover = false; diff --git a/program-tests/utils/src/e2e_test_env.rs b/program-tests/utils/src/e2e_test_env.rs index 3fcfea43c2..f9c5a8a42a 100644 
--- a/program-tests/utils/src/e2e_test_env.rs +++ b/program-tests/utils/src/e2e_test_env.rs @@ -136,9 +136,7 @@ use light_program_test::{ use light_prover_client::{ constants::{PROVE_PATH, SERVER_ADDRESS}, proof::{compress_proof, deserialize_gnark_proof_json, proof_from_json_struct}, - proof_type::ProofType, proof_types::batch_address_append::{get_batch_address_append_circuit_inputs, to_json}, - prover::ProverConfig, }; use light_registry::{ account_compression_cpi::sdk::create_batch_update_address_tree_instruction, @@ -3196,20 +3194,6 @@ pub struct KeypairActionConfig { } impl KeypairActionConfig { - pub fn prover_config(&self) -> ProverConfig { - let mut config = ProverConfig::default(); - - if self.inclusion() { - config.circuits.push(ProofType::Inclusion); - } - - if self.non_inclusion() { - config.circuits.push(ProofType::NonInclusion); - } - - config - } - pub fn inclusion(&self) -> bool { self.transfer_sol.is_some() || self.transfer_spl.is_some() } diff --git a/prover/client/src/proof_client.rs b/prover/client/src/proof_client.rs index 89215f7d3f..a4f0de334d 100644 --- a/prover/client/src/proof_client.rs +++ b/prover/client/src/proof_client.rs @@ -21,7 +21,7 @@ use crate::{ const MAX_RETRIES: u32 = 10; const BASE_RETRY_DELAY_SECS: u64 = 1; const DEFAULT_POLLING_INTERVAL_SECS: u64 = 1; -const DEFAULT_MAX_WAIT_TIME_SECS: u64 = 120; +const DEFAULT_MAX_WAIT_TIME_SECS: u64 = 600; const DEFAULT_LOCAL_SERVER: &str = "http://localhost:3001"; #[derive(Debug, Deserialize)] diff --git a/prover/client/src/prover.rs b/prover/client/src/prover.rs index ec6c8cc7a6..3bf1bab785 100644 --- a/prover/client/src/prover.rs +++ b/prover/client/src/prover.rs @@ -1,5 +1,4 @@ use std::{ - fmt::{Display, Formatter}, process::Command, sync::atomic::{AtomicBool, Ordering}, thread::sleep, @@ -11,68 +10,11 @@ use tracing::info; use crate::{ constants::{HEALTH_CHECK, SERVER_ADDRESS}, helpers::get_project_root, - proof_type::ProofType, }; static IS_LOADING: AtomicBool = AtomicBool::new(false); -#[derive(Debug, Clone, Copy, PartialEq)] -pub enum ProverMode { - Rpc, - Forester, - ForesterTest, - Full, - FullTest, -} - -impl Display for ProverMode { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self { - ProverMode::Rpc => "rpc", - ProverMode::Forester => "forester", - ProverMode::ForesterTest => "forester-test", - ProverMode::Full => "full", - ProverMode::FullTest => "full-test", - } - ) - } -} - -#[derive(Debug, Clone)] -pub struct ProverConfig { - pub run_mode: Option, - pub circuits: Vec, -} - -impl Default for ProverConfig { - #[cfg(feature = "devenv")] - fn default() -> Self { - Self { - run_mode: Some(ProverMode::ForesterTest), - circuits: vec![], - } - } - #[cfg(not(feature = "devenv"))] - fn default() -> Self { - Self { - run_mode: Some(ProverMode::Rpc), - circuits: vec![], - } - } -} -impl ProverConfig { - pub fn rpc_no_restart() -> Self { - Self { - run_mode: Some(ProverMode::Rpc), - circuits: vec![], - } - } -} - -pub async fn spawn_prover(config: ProverConfig) { +pub async fn spawn_prover() { if let Some(_project_root) = get_project_root() { let prover_path: &str = { #[cfg(feature = "devenv")] @@ -89,23 +31,12 @@ pub async fn spawn_prover(config: ProverConfig) { if !health_check(10, 1).await && !IS_LOADING.load(Ordering::Relaxed) { IS_LOADING.store(true, Ordering::Relaxed); - let mut command = Command::new(prover_path); - command.arg("start-prover"); - - if let Some(ref mode) = config.run_mode { - command.arg("--run-mode").arg(mode.to_string()); - } - - for 
circuit in config.circuits.clone() { - command.arg("--circuit").arg(circuit.to_string()); - } - - println!("Starting prover with command: {:?}", command); - - let _ = command + let command = Command::new(prover_path) + .arg("start-prover") .spawn() - .expect("Failed to start prover process") - .wait(); + .expect("Failed to start prover process"); + + let _ = command.wait_with_output(); let health_result = health_check(120, 1).await; if health_result { @@ -114,10 +45,6 @@ pub async fn spawn_prover(config: ProverConfig) { panic!("Failed to start prover, health check failed."); } } - #[cfg(not(feature = "devenv"))] - { - "light" - } } else { panic!("Failed to find project root."); }; diff --git a/prover/client/tests/batch_address_append.rs b/prover/client/tests/batch_address_append.rs index 483b022de9..22f58d5362 100644 --- a/prover/client/tests/batch_address_append.rs +++ b/prover/client/tests/batch_address_append.rs @@ -6,7 +6,7 @@ use light_prover_client::{ proof_types::batch_address_append::{ get_batch_address_append_circuit_inputs, to_json, BatchAddressAppendInputs, }, - prover::{spawn_prover, ProverConfig}, + prover::spawn_prover, }; use light_sparse_merkle_tree::{ changelog::ChangelogEntry, indexed_changelog::IndexedChangelogEntry, SparseMerkleTree, @@ -23,7 +23,7 @@ async fn prove_batch_address_append() { use light_merkle_tree_reference::indexed::IndexedMerkleTree; println!("spawning prover"); - spawn_prover(ProverConfig::default()).await; + spawn_prover().await; // Initialize test data let mut new_element_values = vec![]; diff --git a/prover/client/tests/batch_append.rs b/prover/client/tests/batch_append.rs index 523d4486bb..801e155e92 100644 --- a/prover/client/tests/batch_append.rs +++ b/prover/client/tests/batch_append.rs @@ -3,7 +3,7 @@ use light_merkle_tree_reference::MerkleTree; use light_prover_client::{ constants::{DEFAULT_BATCH_STATE_TREE_HEIGHT, PROVE_PATH, SERVER_ADDRESS}, proof_types::batch_append::{get_batch_append_inputs, BatchAppendInputsJson}, - prover::{spawn_prover, ProverConfig}, + prover::spawn_prover, }; use reqwest::Client; use serial_test::serial; @@ -12,8 +12,7 @@ mod init_merkle_tree; #[serial] #[tokio::test] async fn prove_batch_append_with_proofs() { - // Spawn the prover with specific configuration - spawn_prover(ProverConfig::default()).await; + spawn_prover().await; const HEIGHT: usize = DEFAULT_BATCH_STATE_TREE_HEIGHT as usize; const CANOPY: usize = 0; diff --git a/prover/client/tests/batch_update.rs b/prover/client/tests/batch_update.rs index 90efc5132a..31b7a809c4 100644 --- a/prover/client/tests/batch_update.rs +++ b/prover/client/tests/batch_update.rs @@ -3,7 +3,7 @@ use light_merkle_tree_reference::MerkleTree; use light_prover_client::{ constants::{DEFAULT_BATCH_STATE_TREE_HEIGHT, PROVE_PATH, SERVER_ADDRESS}, proof_types::batch_update::{get_batch_update_inputs, update_inputs_string}, - prover::{spawn_prover, ProverConfig}, + prover::spawn_prover, }; use reqwest::Client; use serial_test::serial; @@ -12,7 +12,7 @@ mod init_merkle_tree; #[serial] #[tokio::test] async fn prove_batch_update() { - spawn_prover(ProverConfig::default()).await; + spawn_prover().await; const HEIGHT: usize = DEFAULT_BATCH_STATE_TREE_HEIGHT as usize; const CANOPY: usize = 0; let num_insertions = 10; diff --git a/prover/client/tests/combined.rs b/prover/client/tests/combined.rs index 2520e7d077..92b9cc2644 100644 --- a/prover/client/tests/combined.rs +++ b/prover/client/tests/combined.rs @@ -1,6 +1,6 @@ use light_prover_client::{ constants::{PROVE_PATH, SERVER_ADDRESS}, - 
prover::{spawn_prover, ProverConfig}, + prover::spawn_prover, }; use reqwest::Client; use serial_test::serial; @@ -10,7 +10,7 @@ use crate::init_merkle_tree::{combined_inputs_string_v1, combined_inputs_string_ #[serial] #[tokio::test] async fn prove_combined() { - spawn_prover(ProverConfig::default()).await; + spawn_prover().await; let client = Client::new(); { for i in 1..=4 { diff --git a/prover/client/tests/inclusion.rs b/prover/client/tests/inclusion.rs index 9ff510e6f8..be50a6d3cb 100644 --- a/prover/client/tests/inclusion.rs +++ b/prover/client/tests/inclusion.rs @@ -1,6 +1,6 @@ use light_prover_client::{ constants::{PROVE_PATH, SERVER_ADDRESS}, - prover::{spawn_prover, ProverConfig}, + prover::spawn_prover, }; use reqwest::Client; use serial_test::serial; @@ -10,7 +10,7 @@ use crate::init_merkle_tree::{inclusion_inputs_string_v1, inclusion_inputs_strin #[serial] #[tokio::test] async fn prove_inclusion() { - spawn_prover(ProverConfig::default()).await; + spawn_prover().await; let client = Client::new(); // v2 - test all keys from 1 to 20 diff --git a/prover/client/tests/non_inclusion.rs b/prover/client/tests/non_inclusion.rs index 43f348f6e7..39fc0d21dd 100644 --- a/prover/client/tests/non_inclusion.rs +++ b/prover/client/tests/non_inclusion.rs @@ -1,6 +1,6 @@ use light_prover_client::{ constants::{PROVE_PATH, SERVER_ADDRESS}, - prover::{spawn_prover, ProverConfig}, + prover::spawn_prover, }; use reqwest::Client; use serial_test::serial; @@ -10,7 +10,7 @@ use crate::init_merkle_tree::{non_inclusion_inputs_string_v1, non_inclusion_inpu #[serial] #[tokio::test] async fn prove_non_inclusion() { - spawn_prover(ProverConfig::default()).await; + spawn_prover().await; let client = Client::new(); // legacy height 26 { diff --git a/prover/server/README.md b/prover/server/README.md index 82e2ac05ca..501c56e421 100644 --- a/prover/server/README.md +++ b/prover/server/README.md @@ -111,9 +111,6 @@ The Docker image is configured to include only the necessary proving-key files: ### Building the Docker Image ```shell -# First ensure you have the proving keys downloaded -./scripts/download_keys.sh light - # Build the Docker image with the selected proving keys # Make sure to run this command from the prover/server directory docker build -t light-prover . 
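Note: the README drops the manual download_keys.sh step because the server now fetches and checksum-verifies missing keys on demand (see the key_downloader.go and lazy_key_manager.go additions below). For images that should ship with keys already present, the new `download` subcommand added to main.go in this diff can pre-fetch them. A minimal sketch, assuming it is run from the prover/server directory via `go run .`; the flag names and defaults are taken from the main.go hunk below, everything else is illustrative:

```shell
# Pre-fetch every key required by the "full" run mode into the default keys dir
go run . download --run-mode full --keys-dir ./proving-keys/

# Or fetch only selected circuits instead of a whole run mode
go run . download --circuit inclusion --circuit non-inclusion

# Verify already-present keys against the published CHECKSUM file, no downloads
go run . download --run-mode full --verify-only
```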
diff --git a/prover/server/integration_test.go b/prover/server/integration_test.go index f954a91ce2..fcfcc2adb9 100644 --- a/prover/server/integration_test.go +++ b/prover/server/integration_test.go @@ -19,72 +19,70 @@ import ( ) var isLightweightMode bool +var preloadKeys bool const ProverAddress = "localhost:8081" const MetricsAddress = "localhost:9999" var instance server.RunningJob +var serverStopped bool func proveEndpoint() string { return "http://" + ProverAddress + "/prove" } func StartServer(isLightweight bool) { + StartServerWithPreload(isLightweight, true) +} + +func StartServerWithPreload(isLightweight bool, preload bool) { logging.Logger().Info().Msg("Setting up the prover") - var keys []string var runMode common.RunMode if isLightweight { - keys = common.GetKeys("./proving-keys/", common.FullTest, []string{}) runMode = common.FullTest } else { - keys = common.GetKeys("./proving-keys/", common.Full, []string{}) runMode = common.Full } - var pssv1 []*common.MerkleProofSystem - var pssv2 []*common.BatchProofSystem - missingKeys := []string{} + downloadConfig := common.DefaultDownloadConfig() + downloadConfig.AutoDownload = true + + keyManager := common.NewLazyKeyManager("./proving-keys/", downloadConfig) - for _, key := range keys { - system, err := common.ReadSystemFromFile(key) + if preload { + // Preload keys for the test run mode + err := keyManager.PreloadForRunMode(runMode) if err != nil { - if os.IsNotExist(err) { - logging.Logger().Warn().Msgf("Key file not found: %s. Skipping this key.", key) - missingKeys = append(missingKeys, key) - continue + logging.Logger().Fatal().Err(err).Msg("Failed to preload proving keys") + return + } + } else { + var testCircuits []string + if isLightweight { + testCircuits = []string{ + "inclusion", "non-inclusion", "combined", + "append-test", "update-test", "address-append-test", + } + } else { + testCircuits = []string{ + "inclusion", "non-inclusion", "combined", + "append", "update", "address-append", } - logging.Logger().Error().Msgf("Error reading proving system from file: %s. Error: %v", key, err) - continue } - switch s := system.(type) { - case *common.MerkleProofSystem: - pssv1 = append(pssv1, s) - case *common.BatchProofSystem: - pssv2 = append(pssv2, s) - default: - logging.Logger().Info().Msgf("Unknown proving system type for file: %s", key) - panic("Unknown proving system type") + err := keyManager.PreloadCircuits(testCircuits) + if err != nil { + logging.Logger().Warn().Err(err).Msg("Failed to preload some test keys, will download on-demand") } } - if len(missingKeys) > 0 { - logging.Logger().Warn().Msgf("Some key files are missing. To download %s keys, run: ./scripts/download_keys.sh %s", - map[bool]string{true: "lightweight", false: "full"}[isLightweight], - map[bool]string{true: "lightweight", false: "full"}[isLightweight]) - } - - if len(pssv1) == 0 && len(pssv2) == 0 { - logging.Logger().Fatal().Msg("No valid proving systems found. Cannot start the server. 
Please ensure you have downloaded the necessary key files.") - return - } - serverCfg := server.Config{ ProverAddress: ProverAddress, MetricsAddress: MetricsAddress, } logging.Logger().Info().Msg("Starting the server") - instance = server.Run(&serverCfg, []string{}, runMode, pssv1, pssv2) + instance = server.Run(&serverCfg, keyManager) + serverStopped = false // sleep for 1 sec to ensure that the server is up and running before running the tests time.Sleep(1 * time.Second) @@ -93,31 +91,74 @@ func StartServer(isLightweight bool) { } func StopServer() { + if serverStopped { + return + } instance.RequestStop() instance.AwaitStop() + serverStopped = true } func TestMain(m *testing.M) { gnarkLogger.Set(*logging.Logger()) + + runIntegrationTests := false isLightweightMode = true + preloadKeys = true + for _, arg := range os.Args { - if arg == "-test.run=TestFull" { + if strings.Contains(arg, "-test.run=TestFull") { isLightweightMode = false + runIntegrationTests = true + break + } + if strings.Contains(arg, "-test.run=TestLightweightLazy") { + runIntegrationTests = true + preloadKeys = false + break + } + if strings.Contains(arg, "-test.run=TestLightweight") { + runIntegrationTests = true break } } - if isLightweightMode { - logging.Logger().Info().Msg("Running in lightweight mode") - logging.Logger().Info().Msg("If you encounter missing key errors, run: ./scripts/download_keys.sh light") - } else { - logging.Logger().Info().Msg("Running in full mode") - logging.Logger().Info().Msg("If you encounter missing key errors, run: ./scripts/download_keys.sh full") + if !runIntegrationTests { + hasTestRunFlag := false + for _, arg := range os.Args { + if strings.HasPrefix(arg, "-test.run=") { + hasTestRunFlag = true + pattern := strings.TrimPrefix(arg, "-test.run=") + if pattern == "" || pattern == "^Test" || strings.Contains(pattern, "Lightweight") || strings.Contains(pattern, "Full") { + runIntegrationTests = true + } + break + } + } + if !hasTestRunFlag { + runIntegrationTests = true + } } - StartServer(isLightweightMode) - m.Run() - StopServer() + if runIntegrationTests { + if isLightweightMode { + if preloadKeys { + logging.Logger().Info().Msg("Running in lightweight mode - preloading keys") + } else { + logging.Logger().Info().Msg("Running in lazy lightweight mode") + } + } else { + logging.Logger().Info().Msg("Running in full mode - preloading keys") + } + + StartServerWithPreload(isLightweightMode, preloadKeys) + code := m.Run() + StopServer() + os.Exit(code) + } else { + logging.Logger().Info().Msg("Skipping key loading - no integration tests in this run") + os.Exit(m.Run()) + } } func TestLightweight(t *testing.T) { @@ -128,6 +169,19 @@ func TestLightweight(t *testing.T) { runLightweightOnlyTests(t) } +func TestLightweightLazy(t *testing.T) { + if preloadKeys { + t.Skip("This test only runs when preloadKeys is false (lazy mode)") + } + + logging.Logger().Info().Msg("TestLightweightLazy: Running tests with lazy key loading") + + runCommonTests(t) + runLightweightOnlyTests(t) + + logging.Logger().Info().Msg("TestLightweightLazy: All tests passed with lazy loading") +} + func TestFull(t *testing.T) { if isLightweightMode { t.Skip("This test only runs in full mode") diff --git a/prover/server/main.go b/prover/server/main.go index 7f9c80f37e..948718a7a9 100644 --- a/prover/server/main.go +++ b/prover/server/main.go @@ -473,6 +473,93 @@ func runCli() { return nil }, }, + { + Name: "download", + Usage: "Download proving keys", + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: "run-mode", + Usage: 
"Download keys for specific run mode (rpc, forester, forester-test, full, full-test, local-rpc)", + Value: "local-rpc", + }, + &cli.StringSliceFlag{ + Name: "circuit", + Usage: "Download keys for specific circuits (inclusion, non-inclusion, combined, append, update, append-test, update-test, address-append, address-append-test)", + }, + &cli.StringFlag{ + Name: "keys-dir", + Usage: "Directory where key files will be stored", + Value: "./proving-keys/", + }, + &cli.StringFlag{ + Name: "download-url", + Usage: "Base URL for downloading key files", + Value: common.DefaultBaseURL, + }, + &cli.IntFlag{ + Name: "max-retries", + Usage: "Maximum number of retries for downloading keys", + Value: common.DefaultMaxRetries, + }, + &cli.BoolFlag{ + Name: "verify-only", + Usage: "Only verify existing keys without downloading", + Value: false, + }, + }, + Action: func(context *cli.Context) error { + circuits := context.StringSlice("circuit") + runMode, err := parseRunMode(context.String("run-mode")) + if err != nil { + return err + } + + keysDirPath := context.String("keys-dir") + verifyOnly := context.Bool("verify-only") + + // Configure download settings + downloadConfig := &common.DownloadConfig{ + BaseURL: context.String("download-url"), + MaxRetries: context.Int("max-retries"), + RetryDelay: common.DefaultRetryDelay, + MaxRetryDelay: common.DefaultMaxRetryDelay, + AutoDownload: !verifyOnly, + } + + logging.Logger().Info(). + Str("run_mode", string(runMode)). + Strs("circuits", circuits). + Str("keys_dir", keysDirPath). + Bool("verify_only", verifyOnly). + Str("download_url", downloadConfig.BaseURL). + Int("max_retries", downloadConfig.MaxRetries). + Msg("Download configuration") + + // Get required keys + keys := common.GetKeys(keysDirPath, runMode, circuits) + + if len(keys) == 0 { + return fmt.Errorf("no keys to download for run-mode=%s circuits=%v", runMode, circuits) + } + + logging.Logger().Info(). + Int("total_keys", len(keys)). 
+ Msg("Starting key download/verification") + + // Download/verify keys + if err := common.EnsureKeysExist(keys, downloadConfig); err != nil { + return fmt.Errorf("failed to ensure keys exist: %w", err) + } + + if verifyOnly { + logging.Logger().Info().Msg("All keys verified successfully") + } else { + logging.Logger().Info().Msg("All keys downloaded and verified successfully") + } + + return nil + }, + }, { Name: "start", Flags: []cli.Flag{ @@ -488,6 +575,15 @@ func runCli() { Name: "run-mode", Usage: "Specify the running mode (rpc, forester, forester-test, full, or full-test)", }, + &cli.StringFlag{ + Name: "preload-keys", + Usage: "Preload keys: none (lazy load all), all (preload everything), or a run mode (rpc, forester, forester-test, full, full-test, local-rpc)", + Value: "none", + }, + &cli.StringSliceFlag{ + Name: "preload-circuits", + Usage: "Preload specific circuits, e.g.: update,append,batch_update_32_500,batch_append_32_500)", + }, &cli.StringFlag{ Name: "redis-url", Usage: "Redis URL for queue processing (e.g., redis://localhost:6379)", @@ -503,31 +599,77 @@ func runCli() { Usage: "Run only HTTP server (no queue workers)", Value: false, }, + &cli.BoolFlag{ + Name: "auto-download", + Usage: "Automatically download missing key files", + Value: true, + }, + &cli.StringFlag{ + Name: "download-url", + Usage: "Base URL for downloading key files", + Value: common.DefaultBaseURL, + }, + &cli.IntFlag{ + Name: "download-max-retries", + Usage: "Maximum number of retries for downloading keys", + Value: common.DefaultMaxRetries, + }, }, Action: func(context *cli.Context) error { if context.Bool("json-logging") { logging.SetJSONOutput() } - circuits := context.StringSlice("circuit") - runMode, err := parseRunMode(context.String("run-mode")) - if err != nil { - if len(circuits) == 0 { - return err - } + var keysDirPath = context.String("keys-dir") + + // Configure download settings + downloadConfig := &common.DownloadConfig{ + BaseURL: context.String("download-url"), + MaxRetries: context.Int("download-max-retries"), + RetryDelay: common.DefaultRetryDelay, + MaxRetryDelay: common.DefaultMaxRetryDelay, + AutoDownload: context.Bool("auto-download"), } - var keysDirPath = context.String("keys-dir") - debugProvingSystemKeys(keysDirPath, runMode, circuits) - psv1, psv2, err := common.LoadKeys(keysDirPath, runMode, circuits) - if err != nil { - return err + keyManager := common.NewLazyKeyManager(keysDirPath, downloadConfig) + + preloadKeys := context.String("preload-keys") + preloadCircuits := context.StringSlice("preload-circuits") + + logging.Logger().Info(). + Str("preload_keys", preloadKeys). + Strs("preload_circuits", preloadCircuits). + Str("keys_dir", keysDirPath). 
+ Msg("Initializing lazy key manager") + + if preloadKeys == "all" { + logging.Logger().Info().Msg("Preloading all keys") + if err := keyManager.PreloadAll(); err != nil { + return fmt.Errorf("failed to preload all keys: %w", err) + } + } else if preloadKeys != "none" { + preloadRunMode, err := parseRunMode(preloadKeys) + if err != nil { + return fmt.Errorf("invalid --preload-keys value: %s (must be none, all, or a valid run mode: rpc, forester, forester-test, full, full-test, local-rpc)", preloadKeys) + } + logging.Logger().Info().Str("run_mode", string(preloadRunMode)).Msg("Preloading keys for run mode") + if err := keyManager.PreloadForRunMode(preloadRunMode); err != nil { + return fmt.Errorf("failed to preload keys for run mode: %w", err) + } } - if len(psv1) == 0 && len(psv2) == 0 { - return fmt.Errorf("no proving systems loaded") + if len(preloadCircuits) > 0 { + logging.Logger().Info().Strs("circuits", preloadCircuits).Msg("Preloading specific circuits") + if err := keyManager.PreloadCircuits(preloadCircuits); err != nil { + return fmt.Errorf("failed to preload circuits: %w", err) + } } + stats := keyManager.GetStats() + logging.Logger().Info(). + Interface("stats", stats). + Msg("Key manager initialized") + redisURL := context.String("redis-url") if redisURL == "" { redisURL = os.Getenv("REDIS_URL") @@ -561,6 +703,7 @@ func runCli() { return fmt.Errorf("Redis URL is required for queue mode. Use --redis-url or set REDIS_URL environment variable") } + var err error redisQueue, err = server.NewRedisQueue(redisURL) if err != nil { return fmt.Errorf("failed to connect to Redis: %w", err) @@ -574,47 +717,21 @@ func runCli() { logging.Logger().Info().Msg("Starting queue workers") - startAllWorkers := runMode == common.Forester || runMode == common.ForesterTest || runMode == common.Full || runMode == common.FullTest + updateWorker := server.NewUpdateQueueWorker(redisQueue, keyManager) + workers = append(workers, updateWorker) + go updateWorker.Start() - var workersStarted []string + appendWorker := server.NewAppendQueueWorker(redisQueue, keyManager) + workers = append(workers, appendWorker) + go appendWorker.Start() - logging.Logger().Info().Bool("startAllWorkers", startAllWorkers) + addressAppendWorker := server.NewAddressAppendQueueWorker(redisQueue, keyManager) + workers = append(workers, addressAppendWorker) + go addressAppendWorker.Start() - for _, circuit := range circuits { - logging.Logger().Info().Str("circuit", circuit) - } - // Start update worker for batch-update circuits or forester modes - if startAllWorkers || containsCircuit(circuits, "update") || containsCircuit(circuits, "update-test") { - updateWorker := server.NewUpdateQueueWorker(redisQueue, psv1, psv2) - workers = append(workers, updateWorker) - go updateWorker.Start() - workersStarted = append(workersStarted, "update") - } - - // Start append worker for batch-append circuits or forester modes - if startAllWorkers || containsCircuit(circuits, "append") || containsCircuit(circuits, "append-test") { - appendWorker := server.NewAppendQueueWorker(redisQueue, psv1, psv2) - workers = append(workers, appendWorker) - go appendWorker.Start() - workersStarted = append(workersStarted, "append") - } - - // Start address append worker for address-append circuits or forester modes - if startAllWorkers || containsCircuit(circuits, "address-append") || containsCircuit(circuits, "address-append-test") { - addressAppendWorker := server.NewAddressAppendQueueWorker(redisQueue, psv1, psv2) - workers = append(workers, addressAppendWorker) 
- go addressAppendWorker.Start() - workersStarted = append(workersStarted, "address-append") - } - - if len(workersStarted) == 0 { - logging.Logger().Warn().Msg("No queue workers started - no matching circuits found") - } else { - logging.Logger().Info(). - Strs("workers_started", workersStarted). - Bool("forester_mode", startAllWorkers). - Msg("Queue workers started") - } + logging.Logger().Info(). + Strs("workers_started", []string{"update", "append", "address-append"}). + Msg("Queue workers started") } if enableServer { @@ -624,13 +741,13 @@ func runCli() { } if redisQueue != nil { - instance = server.RunWithQueue(&config, redisQueue, circuits, runMode, psv1, psv2) + instance = server.RunWithQueue(&config, redisQueue, keyManager) logging.Logger().Info(). Str("prover_address", config.ProverAddress). Str("metrics_address", config.MetricsAddress). Msg("Started enhanced server with Redis queue support") } else { - instance = server.Run(&config, circuits, runMode, psv1, psv2) + instance = server.Run(&config, keyManager) logging.Logger().Info(). Str("prover_address", config.ProverAddress). Str("metrics_address", config.MetricsAddress). diff --git a/prover/server/prover/common/key_downloader.go b/prover/server/prover/common/key_downloader.go new file mode 100644 index 0000000000..c4ad334823 --- /dev/null +++ b/prover/server/prover/common/key_downloader.go @@ -0,0 +1,413 @@ +package common + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "light/light-prover/logging" + "net/http" + "os" + "path/filepath" + "strings" + "sync" + "time" +) + +const ( + DefaultBaseURL = "https://storage.googleapis.com/light-protocol-proving-keys/proving-keys-06-10-25" + DefaultMaxRetries = 10 + DefaultRetryDelay = 5 * time.Second + DefaultMaxRetryDelay = 5 * time.Minute +) + +type DownloadConfig struct { + BaseURL string + MaxRetries int + RetryDelay time.Duration + MaxRetryDelay time.Duration + AutoDownload bool +} + +func DefaultDownloadConfig() *DownloadConfig { + return &DownloadConfig{ + BaseURL: DefaultBaseURL, + MaxRetries: DefaultMaxRetries, + RetryDelay: DefaultRetryDelay, + MaxRetryDelay: DefaultMaxRetryDelay, + AutoDownload: true, + } +} + +type checksumCacheEntry struct { + checksums map[string]string + loaded bool +} + +type checksumCacheManager struct { + mu sync.RWMutex + caches map[string]*checksumCacheEntry +} + +var globalChecksumCaches = &checksumCacheManager{ + caches: make(map[string]*checksumCacheEntry), +} + +func downloadChecksum(config *DownloadConfig) error { + globalChecksumCaches.mu.RLock() + if entry, exists := globalChecksumCaches.caches[config.BaseURL]; exists && entry.loaded { + globalChecksumCaches.mu.RUnlock() + return nil + } + globalChecksumCaches.mu.RUnlock() + + globalChecksumCaches.mu.Lock() + defer globalChecksumCaches.mu.Unlock() + + if entry, exists := globalChecksumCaches.caches[config.BaseURL]; exists && entry.loaded { + return nil + } + + checksumURL := config.BaseURL + "/CHECKSUM" + logging.Logger().Info(). + Str("url", checksumURL). 
+ Msg("Downloading CHECKSUM file") + + resp, err := http.Get(checksumURL) + if err != nil { + return fmt.Errorf("failed to download CHECKSUM file: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to download CHECKSUM file: HTTP %d", resp.StatusCode) + } + + content, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read CHECKSUM file: %w", err) + } + + entry := &checksumCacheEntry{ + checksums: make(map[string]string), + loaded: false, + } + + // Parse CHECKSUM file (format: "checksum filename") + lines := strings.Split(string(content), "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + parts := strings.Fields(line) + if len(parts) >= 2 { + checksum := parts[0] + filename := parts[1] + entry.checksums[filename] = checksum + } + } + + entry.loaded = true + globalChecksumCaches.caches[config.BaseURL] = entry + + logging.Logger().Info(). + Int("count", len(entry.checksums)). + Str("base_url", config.BaseURL). + Msg("Loaded checksums") + + return nil +} + +func verifyChecksum(filepath string, expectedChecksum string) (bool, error) { + file, err := os.Open(filepath) + if err != nil { + return false, err + } + defer file.Close() + + hash := sha256.New() + if _, err := io.Copy(hash, file); err != nil { + return false, err + } + + actualChecksum := hex.EncodeToString(hash.Sum(nil)) + return actualChecksum == expectedChecksum, nil +} + +func calculateBackoff(attempt int, initialDelay, maxDelay time.Duration) time.Duration { + delay := initialDelay * time.Duration(1< maxDelay { + return maxDelay + } + return delay +} + +func downloadFileWithResume(url, outputPath string, config *DownloadConfig) error { + tempPath := outputPath + ".tmp" + + for attempt := 1; attempt <= config.MaxRetries; attempt++ { + var existingSize int64 = 0 + if fileInfo, err := os.Stat(tempPath); err == nil { + existingSize = fileInfo.Size() + } + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + if existingSize > 0 { + req.Header.Set("Range", fmt.Sprintf("bytes=%d-", existingSize)) + logging.Logger().Info(). + Str("url", url). + Int64("resume_from", existingSize). + Int("attempt", attempt). + Int("max_retries", config.MaxRetries). + Msg("Resuming download") + } else { + logging.Logger().Info(). + Str("url", url). + Int("attempt", attempt). + Int("max_retries", config.MaxRetries). + Msg("Starting download") + } + + client := &http.Client{ + Timeout: 60 * time.Minute, + } + resp, err := client.Do(req) + if err != nil { + if attempt < config.MaxRetries { + delay := calculateBackoff(attempt, config.RetryDelay, config.MaxRetryDelay) + logging.Logger().Warn(). + Err(err). + Dur("retry_delay", delay). + Msg("Download failed, retrying") + time.Sleep(delay) + continue + } + return fmt.Errorf("failed to download after %d attempts: %w", config.MaxRetries, err) + } + + // Check response status + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent { + resp.Body.Close() + if attempt < config.MaxRetries { + delay := calculateBackoff(attempt, config.RetryDelay, config.MaxRetryDelay) + logging.Logger().Warn(). + Int("status_code", resp.StatusCode). + Dur("retry_delay", delay). 
+ Msg("Unexpected status code, retrying") + time.Sleep(delay) + continue + } + return fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + var file *os.File + if existingSize > 0 && resp.StatusCode == http.StatusPartialContent { + file, err = os.OpenFile(tempPath, os.O_APPEND|os.O_WRONLY, 0644) + } else { + file, err = os.Create(tempPath) + existingSize = 0 + } + if err != nil { + resp.Body.Close() + return fmt.Errorf("failed to open file: %w", err) + } + + totalSize := existingSize + resp.ContentLength + downloadedBytes := existingSize + lastLogTime := time.Now() + logInterval := 5 * time.Second + + buffer := make([]byte, 32*1024) + for { + n, err := resp.Body.Read(buffer) + if n > 0 { + if _, writeErr := file.Write(buffer[:n]); writeErr != nil { + file.Close() + resp.Body.Close() + return fmt.Errorf("failed to write to file: %w", writeErr) + } + downloadedBytes += int64(n) + + if time.Since(lastLogTime) >= logInterval { + if totalSize > 0 { + progress := float64(downloadedBytes) / float64(totalSize) * 100 + logging.Logger().Info(). + Int64("downloaded", downloadedBytes). + Int64("total", totalSize). + Float64("progress", progress). + Msg("Download progress") + } + lastLogTime = time.Now() + } + } + if err == io.EOF { + break + } + if err != nil { + file.Close() + resp.Body.Close() + if attempt < config.MaxRetries { + delay := calculateBackoff(attempt, config.RetryDelay, config.MaxRetryDelay) + logging.Logger().Warn(). + Err(err). + Dur("retry_delay", delay). + Msg("Download interrupted, retrying") + time.Sleep(delay) + continue + } + return fmt.Errorf("download failed: %w", err) + } + } + + file.Close() + resp.Body.Close() + + if err := os.Rename(tempPath, outputPath); err != nil { + return fmt.Errorf("failed to rename temp file: %w", err) + } + + logging.Logger().Info(). + Str("file", filepath.Base(outputPath)). + Int64("size", downloadedBytes). + Msg("Download completed successfully") + + return nil + } + + return fmt.Errorf("failed to download after %d attempts", config.MaxRetries) +} + +func DownloadKey(keyPath string, config *DownloadConfig) error { + filename := filepath.Base(keyPath) + + if err := downloadChecksum(config); err != nil { + return fmt.Errorf("failed to load checksums: %w", err) + } + + globalChecksumCaches.mu.RLock() + entry, exists := globalChecksumCaches.caches[config.BaseURL] + if !exists { + globalChecksumCaches.mu.RUnlock() + return fmt.Errorf("checksum cache not found for BaseURL: %s", config.BaseURL) + } + expectedChecksum, checksumExists := entry.checksums[filename] + globalChecksumCaches.mu.RUnlock() + + if !checksumExists { + return fmt.Errorf("no checksum found for %s", filename) + } + + if fileInfo, err := os.Stat(keyPath); err == nil { + logging.Logger().Info(). + Str("file", filename). + Int64("size", fileInfo.Size()). + Msg("Verifying existing key file") + + valid, err := verifyChecksum(keyPath, expectedChecksum) + if err != nil { + if !config.AutoDownload { + return fmt.Errorf("key file %s exists but failed verification (auto-download disabled): %w", filename, err) + } + logging.Logger().Warn(). + Err(err). + Str("file", filename). + Msg("Failed to verify checksum, will re-download") + } else if valid { + logging.Logger().Info(). + Str("file", filename). + Msg("Key file is valid, skipping download") + return nil + } else { + if !config.AutoDownload { + return fmt.Errorf("key file %s checksum mismatch (auto-download disabled)", filename) + } + logging.Logger().Warn(). + Str("file", filename). 
+ Msg("Checksum mismatch, re-downloading") + os.Remove(keyPath) + } + } else if os.IsNotExist(err) { + if !config.AutoDownload { + return fmt.Errorf("required key file not found: %s (auto-download disabled)", filename) + } + } else { + return fmt.Errorf("failed to check key file %s: %w", filename, err) + } + + if err := os.MkdirAll(filepath.Dir(keyPath), 0755); err != nil { + return fmt.Errorf("failed to create directory: %w", err) + } + + url := fmt.Sprintf("%s/%s", config.BaseURL, filename) + logging.Logger().Info(). + Str("file", filename). + Str("url", url). + Msg("Downloading key file") + + if err := downloadFileWithResume(url, keyPath, config); err != nil { + return err + } + + valid, err := verifyChecksum(keyPath, expectedChecksum) + if err != nil { + return fmt.Errorf("failed to verify downloaded file: %w", err) + } + if !valid { + os.Remove(keyPath) + return fmt.Errorf("downloaded file checksum mismatch") + } + + logging.Logger().Info(). + Str("file", filename). + Msg("Key file downloaded and verified successfully") + + return nil +} + +func EnsureKeysExist(keys []string, config *DownloadConfig) error { + if !config.AutoDownload { + for _, key := range keys { + if _, err := os.Stat(key); os.IsNotExist(err) { + return fmt.Errorf("required key file not found: %s (auto-download disabled)", key) + } + } + return nil + } + + if err := downloadChecksum(config); err != nil { + return fmt.Errorf("failed to download checksums: %w", err) + } + + var missingKeys []string + for _, key := range keys { + if _, err := os.Stat(key); os.IsNotExist(err) { + missingKeys = append(missingKeys, key) + } + } + + if len(missingKeys) > 0 { + logging.Logger().Info(). + Int("missing_count", len(missingKeys)). + Int("total_count", len(keys)). + Msg("Found missing key files, will download") + + for i, key := range missingKeys { + logging.Logger().Info(). + Int("current", i+1). + Int("total", len(missingKeys)). + Str("file", filepath.Base(key)). 
+ Msg("Downloading missing key") + + if err := DownloadKey(key, config); err != nil { + return fmt.Errorf("failed to download key %s: %w", filepath.Base(key), err) + } + } + } + + return nil +} diff --git a/prover/server/prover/common/lazy_key_manager.go b/prover/server/prover/common/lazy_key_manager.go new file mode 100644 index 0000000000..13556bcdbe --- /dev/null +++ b/prover/server/prover/common/lazy_key_manager.go @@ -0,0 +1,438 @@ +package common + +import ( + "fmt" + "light/light-prover/logging" + "strings" + "sync" +) + +type LazyKeyManager struct { + mu sync.RWMutex + merkleSystems map[string]*MerkleProofSystem + batchSystems map[string]*BatchProofSystem + keysDir string + downloadConfig *DownloadConfig + loadingInProgress map[string]chan struct{} +} + +func NewLazyKeyManager(keysDir string, downloadConfig *DownloadConfig) *LazyKeyManager { + if downloadConfig == nil { + downloadConfig = DefaultDownloadConfig() + } + return &LazyKeyManager{ + merkleSystems: make(map[string]*MerkleProofSystem), + batchSystems: make(map[string]*BatchProofSystem), + keysDir: keysDir, + downloadConfig: downloadConfig, + loadingInProgress: make(map[string]chan struct{}), + } +} + +func (m *LazyKeyManager) GetMerkleSystem( + inclusionTreeHeight uint32, + inclusionCompressedAccounts uint32, + nonInclusionTreeHeight uint32, + nonInclusionCompressedAccounts uint32, + version uint32, +) (*MerkleProofSystem, error) { + var key string + if inclusionCompressedAccounts > 0 && nonInclusionCompressedAccounts > 0 { + key = fmt.Sprintf("comb_%d_%d_%d_%d_v%d", inclusionTreeHeight, inclusionCompressedAccounts, nonInclusionTreeHeight, nonInclusionCompressedAccounts, version) + } else if inclusionCompressedAccounts > 0 { + key = fmt.Sprintf("inc_%d_%d_v%d", inclusionTreeHeight, inclusionCompressedAccounts, version) + } else if nonInclusionCompressedAccounts > 0 { + key = fmt.Sprintf("non_%d_%d_v%d", nonInclusionTreeHeight, nonInclusionCompressedAccounts, version) + } else { + return nil, fmt.Errorf("invalid parameters: must specify either inclusion or non-inclusion accounts") + } + + m.mu.RLock() + if ps, exists := m.merkleSystems[key]; exists { + m.mu.RUnlock() + logging.Logger().Debug(). + Str("key", key). + Msg("Found cached MerkleProofSystem") + return ps, nil + } + m.mu.RUnlock() + + return m.loadMerkleSystem(key, inclusionTreeHeight, inclusionCompressedAccounts, nonInclusionTreeHeight, nonInclusionCompressedAccounts, version) +} + +func (m *LazyKeyManager) GetBatchSystem(circuitType CircuitType, treeHeight uint32, batchSize uint32) (*BatchProofSystem, error) { + key := fmt.Sprintf("%s_%d_%d", circuitType, treeHeight, batchSize) + + m.mu.RLock() + if ps, exists := m.batchSystems[key]; exists { + m.mu.RUnlock() + logging.Logger().Debug(). + Str("key", key). 
+ Msg("Found cached BatchProofSystem") + return ps, nil + } + m.mu.RUnlock() + + return m.loadBatchSystem(key, circuitType, treeHeight, batchSize) +} + +func (m *LazyKeyManager) loadMerkleSystem( + key string, + inclusionTreeHeight uint32, + inclusionCompressedAccounts uint32, + nonInclusionTreeHeight uint32, + nonInclusionCompressedAccounts uint32, + version uint32, +) (*MerkleProofSystem, error) { + loadChan := m.acquireLoadingLock(key) + if loadChan == nil { + m.waitForLoading(key) + m.mu.RLock() + ps, exists := m.merkleSystems[key] + m.mu.RUnlock() + if exists { + return ps, nil + } + return nil, fmt.Errorf("loading completed but system not found in cache") + } + defer m.releaseLoadingLock(key, loadChan) + + keyPath := m.determineMerkleKeyPath(inclusionTreeHeight, inclusionCompressedAccounts, nonInclusionTreeHeight, nonInclusionCompressedAccounts, version) + if keyPath == "" { + return nil, fmt.Errorf("no key file mapping for parameters: inc(%d,%d) non(%d,%d) v%d", + inclusionTreeHeight, inclusionCompressedAccounts, nonInclusionTreeHeight, nonInclusionCompressedAccounts, version) + } + + logging.Logger().Info(). + Str("key_path", keyPath). + Str("cache_key", key). + Msg("Loading MerkleProofSystem") + + if err := DownloadKey(keyPath, m.downloadConfig); err != nil { + return nil, fmt.Errorf("failed to download key %s: %w", keyPath, err) + } + + system, err := ReadSystemFromFile(keyPath) + if err != nil { + return nil, fmt.Errorf("failed to load key %s: %w", keyPath, err) + } + + ps, ok := system.(*MerkleProofSystem) + if !ok { + return nil, fmt.Errorf("expected MerkleProofSystem but got different type") + } + + m.mu.Lock() + m.merkleSystems[key] = ps + m.mu.Unlock() + + logging.Logger().Info(). + Str("cache_key", key). + Uint32("inc_height", ps.InclusionTreeHeight). + Uint32("inc_accounts", ps.InclusionNumberOfCompressedAccounts). + Uint32("non_height", ps.NonInclusionTreeHeight). + Uint32("non_accounts", ps.NonInclusionNumberOfCompressedAccounts). + Msg("MerkleProofSystem loaded and cached successfully") + + return ps, nil +} + +func (m *LazyKeyManager) loadBatchSystem(key string, circuitType CircuitType, treeHeight uint32, batchSize uint32) (*BatchProofSystem, error) { + loadChan := m.acquireLoadingLock(key) + if loadChan == nil { + m.waitForLoading(key) + m.mu.RLock() + ps, exists := m.batchSystems[key] + m.mu.RUnlock() + if exists { + return ps, nil + } + return nil, fmt.Errorf("loading completed but system not found in cache") + } + defer m.releaseLoadingLock(key, loadChan) + + keyPath := m.determineBatchKeyPath(circuitType, treeHeight, batchSize) + if keyPath == "" { + return nil, fmt.Errorf("no key file mapping for %s with height %d and batch size %d", circuitType, treeHeight, batchSize) + } + + logging.Logger().Info(). + Str("key_path", keyPath). + Str("cache_key", key). + Msg("Loading BatchProofSystem") + + if err := DownloadKey(keyPath, m.downloadConfig); err != nil { + return nil, fmt.Errorf("failed to download key %s: %w", keyPath, err) + } + + system, err := ReadSystemFromFile(keyPath) + if err != nil { + return nil, fmt.Errorf("failed to load key %s: %w", keyPath, err) + } + + ps, ok := system.(*BatchProofSystem) + if !ok { + return nil, fmt.Errorf("expected BatchProofSystem but got different type") + } + + m.mu.Lock() + m.batchSystems[key] = ps + m.mu.Unlock() + + logging.Logger().Info(). + Str("cache_key", key). + Uint32("tree_height", ps.TreeHeight). + Uint32("batch_size", ps.BatchSize). + Str("circuit_type", string(ps.CircuitType)). 
+ Msg("BatchProofSystem loaded and cached successfully") + + return ps, nil +} + +func (m *LazyKeyManager) acquireLoadingLock(key string) chan struct{} { + m.mu.Lock() + defer m.mu.Unlock() + + if _, loading := m.loadingInProgress[key]; loading { + return nil + } + + ch := make(chan struct{}) + m.loadingInProgress[key] = ch + return ch +} + +func (m *LazyKeyManager) waitForLoading(key string) { + m.mu.RLock() + ch := m.loadingInProgress[key] + m.mu.RUnlock() + + if ch != nil { + <-ch + } +} + +func (m *LazyKeyManager) releaseLoadingLock(key string, ch chan struct{}) { + m.mu.Lock() + delete(m.loadingInProgress, key) + m.mu.Unlock() + close(ch) +} + +func (m *LazyKeyManager) determineMerkleKeyPath( + inclusionTreeHeight uint32, + inclusionCompressedAccounts uint32, + nonInclusionTreeHeight uint32, + nonInclusionCompressedAccounts uint32, + version uint32, +) string { + if inclusionCompressedAccounts > 0 && nonInclusionCompressedAccounts > 0 { + if version == 1 && inclusionTreeHeight == 26 && nonInclusionTreeHeight == 26 { + return fmt.Sprintf("%sv1_combined_26_26_%d_%d.key", m.keysDir, inclusionCompressedAccounts, nonInclusionCompressedAccounts) + } else if version == 2 && inclusionTreeHeight == 32 && nonInclusionTreeHeight == 40 { + return fmt.Sprintf("%sv2_combined_32_40_%d_%d.key", m.keysDir, inclusionCompressedAccounts, nonInclusionCompressedAccounts) + } + } else if inclusionCompressedAccounts > 0 { + if version == 1 && inclusionTreeHeight == 26 { + return fmt.Sprintf("%sv1_inclusion_26_%d.key", m.keysDir, inclusionCompressedAccounts) + } else if version == 2 && inclusionTreeHeight == 32 { + return fmt.Sprintf("%sv2_inclusion_32_%d.key", m.keysDir, inclusionCompressedAccounts) + } + } else if nonInclusionCompressedAccounts > 0 { + if version == 1 && nonInclusionTreeHeight == 26 { + return fmt.Sprintf("%sv1_non-inclusion_26_%d.key", m.keysDir, nonInclusionCompressedAccounts) + } else if version == 2 && nonInclusionTreeHeight == 40 { + return fmt.Sprintf("%sv2_non-inclusion_40_%d.key", m.keysDir, nonInclusionCompressedAccounts) + } + } + + return "" +} + +func (m *LazyKeyManager) determineBatchKeyPath(circuitType CircuitType, treeHeight uint32, batchSize uint32) string { + switch circuitType { + case BatchAppendCircuitType: + if treeHeight == 32 && batchSize == 500 { + return fmt.Sprintf("%sbatch_append_32_500.key", m.keysDir) + } else if treeHeight == 32 && batchSize == 10 { + return fmt.Sprintf("%sbatch_append_32_10.key", m.keysDir) + } + case BatchUpdateCircuitType: + if treeHeight == 32 && batchSize == 500 { + return fmt.Sprintf("%sbatch_update_32_500.key", m.keysDir) + } else if treeHeight == 32 && batchSize == 10 { + return fmt.Sprintf("%sbatch_update_32_10.key", m.keysDir) + } + case BatchAddressAppendCircuitType: + if treeHeight == 40 && batchSize == 250 { + return fmt.Sprintf("%sbatch_address-append_40_250.key", m.keysDir) + } else if treeHeight == 40 && batchSize == 10 { + return fmt.Sprintf("%sbatch_address-append_40_10.key", m.keysDir) + } + } + + return "" +} + +func (m *LazyKeyManager) GetStats() map[string]interface{} { + m.mu.RLock() + defer m.mu.RUnlock() + + return map[string]interface{}{ + "merkle_systems_loaded": len(m.merkleSystems), + "batch_systems_loaded": len(m.batchSystems), + "keys_loading": len(m.loadingInProgress), + } +} + +func (m *LazyKeyManager) PreloadForRunMode(runMode RunMode) error { + logging.Logger().Info(). + Str("run_mode", string(runMode)). 
+ Msg("Preloading keys for run mode") + + keys := GetKeys(m.keysDir, runMode, nil) + return m.preloadKeys(keys) +} + +func (m *LazyKeyManager) PreloadAll() error { + logging.Logger().Info().Msg("Preloading all keys") + + allKeys := make(map[string]bool) + runModes := []RunMode{Full, FullTest} + for _, runMode := range runModes { + keys := GetKeys(m.keysDir, runMode, nil) + for _, key := range keys { + allKeys[key] = true + } + } + + keySlice := make([]string, 0, len(allKeys)) + for key := range allKeys { + keySlice = append(keySlice, key) + } + + return m.preloadKeys(keySlice) +} + +func (m *LazyKeyManager) PreloadCircuits(circuits []string) error { + logging.Logger().Info(). + Strs("circuits", circuits). + Msg("Preloading keys for circuits") + + var keyPaths []string + seen := make(map[string]bool) + + for _, circuit := range circuits { + if specificPath := m.tryParseSpecificConfig(circuit); specificPath != "" { + if !seen[specificPath] { + keyPaths = append(keyPaths, specificPath) + seen[specificPath] = true + } + continue + } + + circuitKeys := GetKeys(m.keysDir, "", []string{circuit}) + for _, key := range circuitKeys { + if !seen[key] { + keyPaths = append(keyPaths, key) + seen[key] = true + } + } + } + + return m.preloadKeys(keyPaths) +} + +func (m *LazyKeyManager) tryParseSpecificConfig(config string) string { + if strings.HasPrefix(config, "batch_") || + strings.HasPrefix(config, "v1_") || + strings.HasPrefix(config, "v2_") { + return fmt.Sprintf("%s%s.key", m.keysDir, config) + } + return "" +} + +func (m *LazyKeyManager) preloadKeys(keyPaths []string) error { + if len(keyPaths) == 0 { + logging.Logger().Info().Msg("No keys to preload") + return nil + } + + logging.Logger().Info(). + Int("count", len(keyPaths)). + Msg("Starting to preload keys") + + for i, keyPath := range keyPaths { + logging.Logger().Info(). + Int("current", i+1). + Int("total", len(keyPaths)). + Str("key_path", keyPath). + Msg("Preloading key") + + if err := DownloadKey(keyPath, m.downloadConfig); err != nil { + return fmt.Errorf("failed to download key %s: %w", keyPath, err) + } + + system, err := ReadSystemFromFile(keyPath) + if err != nil { + return fmt.Errorf("failed to load key %s: %w", keyPath, err) + } + + if err := m.cacheSystem(system); err != nil { + return fmt.Errorf("failed to cache key %s: %w", keyPath, err) + } + } + + logging.Logger().Info(). + Int("count", len(keyPaths)). + Msg("Successfully preloaded all keys") + + return nil +} + +func (m *LazyKeyManager) cacheSystem(system interface{}) error { + m.mu.Lock() + defer m.mu.Unlock() + + switch ps := system.(type) { + case *MerkleProofSystem: + var key string + if ps.InclusionNumberOfCompressedAccounts > 0 && ps.NonInclusionNumberOfCompressedAccounts > 0 { + key = fmt.Sprintf("comb_%d_%d_%d_%d_v%d", + ps.InclusionTreeHeight, + ps.InclusionNumberOfCompressedAccounts, + ps.NonInclusionTreeHeight, + ps.NonInclusionNumberOfCompressedAccounts, + ps.Version) + } else if ps.InclusionNumberOfCompressedAccounts > 0 { + key = fmt.Sprintf("inc_%d_%d_v%d", + ps.InclusionTreeHeight, + ps.InclusionNumberOfCompressedAccounts, + ps.Version) + } else if ps.NonInclusionNumberOfCompressedAccounts > 0 { + key = fmt.Sprintf("non_%d_%d_v%d", + ps.NonInclusionTreeHeight, + ps.NonInclusionNumberOfCompressedAccounts, + ps.Version) + } else { + return fmt.Errorf("invalid MerkleProofSystem: no compressed accounts specified") + } + + m.merkleSystems[key] = ps + logging.Logger().Debug(). + Str("cache_key", key). 
+ Msg("Cached MerkleProofSystem") + + case *BatchProofSystem: + key := fmt.Sprintf("%s_%d_%d", ps.CircuitType, ps.TreeHeight, ps.BatchSize) + m.batchSystems[key] = ps + logging.Logger().Debug(). + Str("cache_key", key). + Msg("Cached BatchProofSystem") + + default: + return fmt.Errorf("unknown system type: %T", system) + } + + return nil +} diff --git a/prover/server/prover/common/proving_keys_utils.go b/prover/server/prover/common/proving_keys_utils.go index c04cff2bab..6d12244e5c 100644 --- a/prover/server/prover/common/proving_keys_utils.go +++ b/prover/server/prover/common/proving_keys_utils.go @@ -6,6 +6,7 @@ import ( "io" "light/light-prover/logging" "os" + "path/filepath" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend/groth16" @@ -116,22 +117,22 @@ func GetKeys(keysDir string, runMode RunMode, circuits []string) []string { var inclusionKeys []string // V2 inclusion keys (height 32) - 1 to 20 for i := 1; i <= 20; i++ { - inclusionKeys = append(inclusionKeys, fmt.Sprintf("%sv2_inclusion_32_%d.key", keysDir, i)) + inclusionKeys = append(inclusionKeys, filepath.Join(keysDir, fmt.Sprintf("v2_inclusion_32_%d.key", i))) } // V1 inclusion keys (legacy, height 26) for _, i := range []int{1, 2, 3, 4, 8} { - inclusionKeys = append(inclusionKeys, fmt.Sprintf("%sv1_inclusion_26_%d.key", keysDir, i)) + inclusionKeys = append(inclusionKeys, filepath.Join(keysDir, fmt.Sprintf("v1_inclusion_26_%d.key", i))) } // Build non-inclusion keys var nonInclusionKeys []string // V1 non-inclusion keys (legacy, height 26) for i := 1; i <= 2; i++ { - nonInclusionKeys = append(nonInclusionKeys, fmt.Sprintf("%sv1_non-inclusion_26_%d.key", keysDir, i)) + nonInclusionKeys = append(nonInclusionKeys, filepath.Join(keysDir, fmt.Sprintf("v1_non-inclusion_26_%d.key", i))) } // V2 non-inclusion keys (height 40) - 1 to 32 for i := 1; i <= 32; i++ { - nonInclusionKeys = append(nonInclusionKeys, fmt.Sprintf("%sv2_non-inclusion_40_%d.key", keysDir, i)) + nonInclusionKeys = append(nonInclusionKeys, filepath.Join(keysDir, fmt.Sprintf("v2_non-inclusion_40_%d.key", i))) } // Build combined keys @@ -139,66 +140,66 @@ func GetKeys(keysDir string, runMode RunMode, circuits []string) []string { // V1 combined keys (legacy, heights 26/26) for i := 1; i <= 4; i++ { for j := 1; j <= 2; j++ { - combinedKeys = append(combinedKeys, fmt.Sprintf("%sv1_combined_26_26_%d_%d.key", keysDir, i, j)) + combinedKeys = append(combinedKeys, filepath.Join(keysDir, fmt.Sprintf("v1_combined_26_26_%d_%d.key", i, j))) } } // V2 combined keys (heights 32/40) for i := 1; i <= 4; i++ { for j := 1; j <= 4; j++ { - combinedKeys = append(combinedKeys, fmt.Sprintf("%sv2_combined_32_40_%d_%d.key", keysDir, i, j)) + combinedKeys = append(combinedKeys, filepath.Join(keysDir, fmt.Sprintf("v2_combined_32_40_%d_%d.key", i, j))) } } // Keys for local-rpc mode - matching the 18 keys in cli/package.json var localRpcKeys []string = []string{ // V1 combined keys - keysDir + "v1_combined_26_26_1_1.key", - keysDir + "v1_combined_26_26_1_2.key", - keysDir + "v1_combined_26_26_2_1.key", + filepath.Join(keysDir, "v1_combined_26_26_1_1.key"), + filepath.Join(keysDir, "v1_combined_26_26_1_2.key"), + filepath.Join(keysDir, "v1_combined_26_26_2_1.key"), // V2 combined keys - keysDir + "v2_combined_32_40_1_1.key", - keysDir + "v2_combined_32_40_1_2.key", - keysDir + "v2_combined_32_40_2_1.key", + filepath.Join(keysDir, "v2_combined_32_40_1_1.key"), + filepath.Join(keysDir, "v2_combined_32_40_1_2.key"), + filepath.Join(keysDir, 
"v2_combined_32_40_2_1.key"), // V2 inclusion keys - keysDir + "v2_inclusion_32_1.key", - keysDir + "v2_inclusion_32_2.key", - keysDir + "v2_inclusion_32_3.key", - keysDir + "v2_inclusion_32_4.key", + filepath.Join(keysDir, "v2_inclusion_32_1.key"), + filepath.Join(keysDir, "v2_inclusion_32_2.key"), + filepath.Join(keysDir, "v2_inclusion_32_3.key"), + filepath.Join(keysDir, "v2_inclusion_32_4.key"), // V1 inclusion keys - keysDir + "v1_inclusion_26_1.key", - keysDir + "v1_inclusion_26_2.key", - keysDir + "v1_inclusion_26_3.key", - keysDir + "v1_inclusion_26_4.key", + filepath.Join(keysDir, "v1_inclusion_26_1.key"), + filepath.Join(keysDir, "v1_inclusion_26_2.key"), + filepath.Join(keysDir, "v1_inclusion_26_3.key"), + filepath.Join(keysDir, "v1_inclusion_26_4.key"), // V1 non-inclusion keys - keysDir + "v1_non-inclusion_26_1.key", - keysDir + "v1_non-inclusion_26_2.key", + filepath.Join(keysDir, "v1_non-inclusion_26_1.key"), + filepath.Join(keysDir, "v1_non-inclusion_26_2.key"), // V2 non-inclusion keys - keysDir + "v2_non-inclusion_40_1.key", - keysDir + "v2_non-inclusion_40_2.key", + filepath.Join(keysDir, "v2_non-inclusion_40_1.key"), + filepath.Join(keysDir, "v2_non-inclusion_40_2.key"), } var appendKeys []string = []string{ - keysDir + "batch_append_32_500.key", + filepath.Join(keysDir, "batch_append_32_500.key"), } var updateKeys []string = []string{ - keysDir + "batch_update_32_500.key", + filepath.Join(keysDir, "batch_update_32_500.key"), } var appendTestKeys []string = []string{ - keysDir + "batch_append_32_10.key", + filepath.Join(keysDir, "batch_append_32_10.key"), } var updateTestKeys []string = []string{ - keysDir + "batch_update_32_10.key", + filepath.Join(keysDir, "batch_update_32_10.key"), } var addressAppendKeys []string = []string{ - keysDir + "batch_address-append_40_250.key", + filepath.Join(keysDir, "batch_address-append_40_250.key"), } var addressAppendTestKeys []string = []string{ - keysDir + "batch_address-append_40_10.key", + filepath.Join(keysDir, "batch_address-append_40_10.key"), } switch runMode { @@ -276,10 +277,19 @@ func GetKeys(keysDir string, runMode RunMode, circuits []string) []string { } func LoadKeys(keysDirPath string, runMode RunMode, circuits []string) ([]*MerkleProofSystem, []*BatchProofSystem, error) { + return LoadKeysWithConfig(keysDirPath, runMode, circuits, DefaultDownloadConfig()) +} + +func LoadKeysWithConfig(keysDirPath string, runMode RunMode, circuits []string, config *DownloadConfig) ([]*MerkleProofSystem, []*BatchProofSystem, error) { var pssv1 []*MerkleProofSystem var pssv2 []*BatchProofSystem keys := GetKeys(keysDirPath, runMode, circuits) + // Ensure all required keys exist (download if necessary) + if err := EnsureKeysExist(keys, config); err != nil { + return nil, nil, fmt.Errorf("failed to ensure keys exist: %w", err) + } + for _, key := range keys { logging.Logger().Info().Msg("Reading proving system from file " + key + "...") system, err := ReadSystemFromFile(key) diff --git a/prover/server/redis_queue_test.go b/prover/server/redis_queue_test.go index 7af0847bfb..56157e81a4 100644 --- a/prover/server/redis_queue_test.go +++ b/prover/server/redis_queue_test.go @@ -676,20 +676,19 @@ func TestWorkerCreation(t *testing.T) { rq := setupRedisQueue(t) defer teardownRedisQueue(t, rq) - var psv1 []*common.MerkleProofSystem - var psv2 []*common.BatchProofSystem + keyManager := common.NewLazyKeyManager("./proving-keys/", common.DefaultDownloadConfig()) - updateWorker := server.NewUpdateQueueWorker(rq, psv1, psv2) + updateWorker := 
server.NewUpdateQueueWorker(rq, keyManager) if updateWorker == nil { t.Errorf("Expected update worker to be created, got nil") } - appendWorker := server.NewAppendQueueWorker(rq, psv1, psv2) + appendWorker := server.NewAppendQueueWorker(rq, keyManager) if appendWorker == nil { t.Errorf("Expected append worker to be created, got nil") } - addressAppendWorker := server.NewAddressAppendQueueWorker(rq, psv1, psv2) + addressAppendWorker := server.NewAddressAppendQueueWorker(rq, keyManager) if addressAppendWorker == nil { t.Errorf("Expected address append worker to be created, got nil") } @@ -846,8 +845,7 @@ func TestFailedJobStatusHTTPEndpoint(t *testing.T) { rq := setupRedisQueue(t) defer teardownRedisQueue(t, rq) - var psv1 []*common.MerkleProofSystem - var psv2 []*common.BatchProofSystem + keyManager := common.NewLazyKeyManager("./proving-keys/", common.DefaultDownloadConfig()) config := &server.EnhancedConfig{ ProverAddress: "localhost:8082", @@ -858,7 +856,7 @@ func TestFailedJobStatusHTTPEndpoint(t *testing.T) { }, } - serverJob := server.RunEnhanced(config, rq, []string{}, common.FullTest, psv1, psv2) + serverJob := server.RunEnhanced(config, rq, keyManager) defer serverJob.RequestStop() time.Sleep(100 * time.Millisecond) diff --git a/prover/server/scripts/download_keys.sh b/prover/server/scripts/download_keys.sh deleted file mode 100755 index def6771d64..0000000000 --- a/prover/server/scripts/download_keys.sh +++ /dev/null @@ -1,323 +0,0 @@ -#!/usr/bin/env bash - -set -e - -# Configuration with environment variable support -ROOT_DIR="$(git rev-parse --show-toplevel)" -KEYS_DIR="${ROOT_DIR}/prover/server/proving-keys" -BASE_URL="https://storage.googleapis.com/light-protocol-proving-keys/proving-keys-06-10-25" -CHECKSUM_URL="${BASE_URL}/CHECKSUM" - -# Configurable parameters for poor connections -MAX_RETRIES=${DOWNLOAD_MAX_RETRIES:-10} # Default 10, can be overridden -INITIAL_RETRY_DELAY=${DOWNLOAD_RETRY_DELAY:-5} # Initial delay in seconds -MAX_RETRY_DELAY=${DOWNLOAD_MAX_RETRY_DELAY:-300} # Max delay (5 minutes) -BANDWIDTH_LIMIT=${DOWNLOAD_BANDWIDTH_LIMIT:-} # Optional bandwidth limit (e.g., "500K") -PARALLEL_DOWNLOADS=${DOWNLOAD_PARALLEL:-1} # Number of parallel downloads -STATUS_FILE="${KEYS_DIR}/.download_status" - -# Colors for output -RED='\033[0;31m' -GREEN='\033[0;32m' -YELLOW='\033[1;33m' -NC='\033[0m' # No Color - -# Create keys directory -mkdir -p "$KEYS_DIR" - -# Function to calculate exponential backoff -calculate_backoff() { - local attempt=$1 - local delay=$((INITIAL_RETRY_DELAY * (2 ** (attempt - 1)))) - if [ $delay -gt $MAX_RETRY_DELAY ]; then - delay=$MAX_RETRY_DELAY - fi - echo $delay -} - -# Function to format bytes for human reading -format_bytes() { - local bytes=$1 - if [ $bytes -gt 1073741824 ]; then - echo "$(echo "scale=2; $bytes/1073741824" | bc) GB" - elif [ $bytes -gt 1048576 ]; then - echo "$(echo "scale=2; $bytes/1048576" | bc) MB" - else - echo "$(echo "scale=2; $bytes/1024" | bc) KB" - fi -} - -# Function to save download status -save_status() { - local file="$1" - local status="$2" - local timestamp=$(date +%s) - echo "${file}|${status}|${timestamp}" >> "$STATUS_FILE" -} - -# Function to check if file was already completed -is_completed() { - local file="$1" - [ -f "$STATUS_FILE" ] && grep -q "^${file}|completed" "$STATUS_FILE" -} - -# Enhanced download function with progress tracking -download_file() { - local url="$1" - local output="$2" - local attempt=1 - local temp_output="${output}.tmp" - local progress_file="${output}.progress" - - while [ $attempt 
-le $MAX_RETRIES ]; do - local retry_delay=$(calculate_backoff $attempt) - echo -e "${YELLOW}Downloading $url (attempt $attempt/$MAX_RETRIES)${NC}" - - # Build curl command with optional bandwidth limit - local curl_cmd="curl -L --fail -H 'Accept: */*' -H 'Accept-Encoding: identity'" - curl_cmd+=" -A 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'" - curl_cmd+=" --connect-timeout 30 --max-time 0" # No overall timeout - curl_cmd+=" --progress-bar" - - if [ -n "$BANDWIDTH_LIMIT" ]; then - curl_cmd+=" --limit-rate $BANDWIDTH_LIMIT" - echo -e "${YELLOW}Bandwidth limited to $BANDWIDTH_LIMIT${NC}" - fi - - # Check if partial download exists - if [ -f "$temp_output" ]; then - local resume_pos=$(stat -f%z "$temp_output" 2>/dev/null || stat -c%s "$temp_output" 2>/dev/null || echo "0") - local formatted_pos=$(format_bytes $resume_pos) - echo -e "${GREEN}Resuming from $formatted_pos${NC}" - curl_cmd+=" -H 'Range: bytes=${resume_pos}-' -C -" - fi - - curl_cmd+=" --output '$temp_output' '$url'" - - # Execute download and capture progress - if eval "$curl_cmd" 2>&1 | tee "$progress_file"; then - mv "$temp_output" "$output" - rm -f "$progress_file" - save_status "${output##*/}" "completed" - return 0 - fi - - local exit_code=$? - echo -e "${RED}Download failed (exit code: $exit_code)${NC}" - - # Check if it's a connection error vs other errors - if [ $exit_code -eq 56 ] || [ $exit_code -eq 18 ] || [ $exit_code -eq 28 ]; then - echo -e "${YELLOW}Connection issue detected. Will retry with exponential backoff.${NC}" - fi - - if [ $attempt -lt $MAX_RETRIES ]; then - echo -e "${YELLOW}Retrying in $retry_delay seconds...${NC}" - echo "Tip: You can also manually resume by running this script again" - sleep $retry_delay - fi - - attempt=$((attempt + 1)) - done - - # Save failed status but keep partial file - save_status "${output##*/}" "failed" - return 1 -} - -verify_checksum() { - local file="$1" - local checksum_file="$2" - local expected - local actual - - if command -v sha256sum >/dev/null 2>&1; then - CHECKSUM_CMD="sha256sum" - else - CHECKSUM_CMD="shasum -a 256" - fi - - expected=$(grep "${file##*/}" "$checksum_file" | cut -d' ' -f1) - actual=$($CHECKSUM_CMD "$file" | cut -d' ' -f1) - - echo "Expected checksum: $expected" - echo "Actual checksum: $actual" - - [ "$expected" = "$actual" ] -} - -# Show current configuration -echo "=========================================" -echo "Download Configuration:" -echo " Max retries: $MAX_RETRIES" -echo " Initial retry delay: ${INITIAL_RETRY_DELAY}s" -echo " Max retry delay: ${MAX_RETRY_DELAY}s" -if [ -n "$BANDWIDTH_LIMIT" ]; then - echo " Bandwidth limit: $BANDWIDTH_LIMIT" -fi -echo " Parallel downloads: $PARALLEL_DOWNLOADS" -echo "" -echo "To customize, set environment variables:" -echo " DOWNLOAD_MAX_RETRIES=20" -echo " DOWNLOAD_RETRY_DELAY=10" -echo " DOWNLOAD_BANDWIDTH_LIMIT=500K" -echo " DOWNLOAD_PARALLEL=2" -echo "=========================================" -echo "" - -# Download checksum file -CHECKSUM_FILE="${KEYS_DIR}/CHECKSUM" -if ! 
download_file "$CHECKSUM_URL" "$CHECKSUM_FILE"; then - echo -e "${RED}Failed to download checksum file${NC}" - exit 1 -fi - -echo "Content of CHECKSUM file:" -cat "$CHECKSUM_FILE" - -case "$1" in - "light") - SUFFIXES=( - # V1 keys (height 26) - "v1_inclusion_26:1 2 3 4 8" - "v1_non-inclusion_26:1 2 3 4 8" - "v1_combined_26_26:1_1 1_2 1_4 1_8 2_1 2_2 2_4 2_8 3_1 3_2 3_4 3_8 4_1 4_2 4_4 4_8 8_1 8_2 8_4 8_8" - - # V2 keys (heights 32/40) - "v2_inclusion_32:1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20" - "v2_non-inclusion_40:1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32" - "v2_combined_32_40:1_1 1_2 1_3 1_4 2_1 2_2 2_3 2_4 3_1 3_2 3_3 3_4 4_1 4_2 4_3 4_4" - "batch_append_32:10" - "batch_update_32:10" - "batch_address-append_40:10" - ) - ;; - "full") - SUFFIXES=( - # V1 keys (height 26) - "v1_inclusion_26:1 2 3 4 8" - "v1_non-inclusion_26:1 2 3 4 8" - "v1_combined_26_26:1_1 1_2 1_4 1_8 2_1 2_2 2_4 2_8 3_1 3_2 3_4 3_8 4_1 4_2 4_4 4_8 8_1 8_2 8_4 8_8" - - # V2 keys (heights 32/40) - "v2_inclusion_32:1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20" - "v2_non-inclusion_40:1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32" - "v2_combined_32_40:1_1 1_2 1_3 1_4 2_1 2_2 2_3 2_4 3_1 3_2 3_3 3_4 4_1 4_2 4_3 4_4" - "batch_append_32:10 500" - "batch_update_32:10 500" - "batch_address-append_40:10 250" - ) - echo -e "${YELLOW}WARNING: Full keys include files >6GB. Ensure stable connection!${NC}" - ;; - *) - echo "Usage: $0 [light|full]" - exit 1 - ;; -esac - -# Count total files and calculate total size -total_files=0 -completed_files=0 -skipped_files=0 -failed_files=0 -total_size=0 - -# Build file list -declare -a FILE_LIST -for group in "${SUFFIXES[@]}"; do - base=${group%:*} - suffixes=${group#*:} - for suffix in $suffixes; do - for ext in key vkey; do - FILE_LIST+=("${base}_${suffix}.${ext}") - total_files=$((total_files + 1)) - done - done -done - -echo "Total files to process: $total_files" -echo "" - -# Process downloads (with simple parallel support if requested) -process_download() { - local file="$1" - local index="$2" - local output="${KEYS_DIR}/${file}" - local temp_output="${output}.tmp" - - # Check if already completed in previous run - if is_completed "$file"; then - echo -e "${GREEN}[$index/$total_files] Skipping $file (marked as completed)${NC}" - return 0 - fi - - # Check if file already exists and is valid - if [ -f "$output" ] && verify_checksum "$output" "$CHECKSUM_FILE" 2>/dev/null; then - echo -e "${GREEN}[$index/$total_files] Skipping $file (already downloaded and verified)${NC}" - save_status "$file" "completed" - return 0 - fi - - # Check if partial download exists - if [ -f "$temp_output" ]; then - local partial_size=$(stat -f%z "$temp_output" 2>/dev/null || stat -c%s "$temp_output" 2>/dev/null || echo "0") - local formatted_size=$(format_bytes $partial_size) - echo -e "${YELLOW}Found partial download for $file ($formatted_size)${NC}" - fi - - echo -e "${YELLOW}[$index/$total_files] Downloading $file...${NC}" - if download_file "${BASE_URL}/${file}" "$output"; then - echo "Verifying checksum for $file..." - if ! 
verify_checksum "$output" "$CHECKSUM_FILE"; then - echo -e "${RED}Checksum verification failed for $file${NC}" - rm -f "$output" - rm -f "$temp_output" - save_status "$file" "checksum_failed" - return 1 - fi - echo -e "${GREEN}[$index/$total_files] Successfully downloaded and verified $file${NC}" - return 0 - else - echo -e "${RED}Failed to download $file after $MAX_RETRIES attempts${NC}" - echo "You can resume the download by running this script again" - return 1 - fi -} - -# Execute downloads -index=0 -for file in "${FILE_LIST[@]}"; do - index=$((index + 1)) - if process_download "$file" "$index"; then - completed_files=$((completed_files + 1)) - else - failed_files=$((failed_files + 1)) - # On mobile connections, offer to continue with remaining files - if [ $failed_files -gt 0 ]; then - echo "" - echo -e "${YELLOW}Download failed. Continue with remaining files? (y/n)${NC}" - read -r -n 1 response - echo "" - if [[ ! "$response" =~ ^[Yy]$ ]]; then - break - fi - fi - fi -done - -# Summary -echo "" -echo "=========================================" -if [ $failed_files -eq 0 ]; then - echo -e "${GREEN}All files downloaded and verified successfully!${NC}" -else - echo -e "${YELLOW}Download session completed with errors${NC}" - echo -e " Successful: ${GREEN}$completed_files${NC}" - echo -e " Failed: ${RED}$failed_files${NC}" - echo "" - echo "To resume failed downloads, run this script again." - echo "Partial downloads will be automatically resumed." -fi -echo "=========================================" - -# Exit with appropriate code -[ $failed_files -eq 0 ] && exit 0 || exit 1 diff --git a/prover/server/server/queue_job.go b/prover/server/server/queue_job.go index b30141ff17..e8d3074570 100644 --- a/prover/server/server/queue_job.go +++ b/prover/server/server/queue_job.go @@ -30,8 +30,7 @@ type QueueWorker interface { type BaseQueueWorker struct { queue *RedisQueue - provingSystemsV1 []*common.MerkleProofSystem - provingSystemsV2 []*common.BatchProofSystem + keyManager *common.LazyKeyManager stopChan chan struct{} queueName string processingQueueName string @@ -49,12 +48,11 @@ type AddressAppendQueueWorker struct { *BaseQueueWorker } -func NewUpdateQueueWorker(redisQueue *RedisQueue, psv1 []*common.MerkleProofSystem, psv2 []*common.BatchProofSystem) *UpdateQueueWorker { +func NewUpdateQueueWorker(redisQueue *RedisQueue, keyManager *common.LazyKeyManager) *UpdateQueueWorker { return &UpdateQueueWorker{ BaseQueueWorker: &BaseQueueWorker{ queue: redisQueue, - provingSystemsV1: psv1, - provingSystemsV2: psv2, + keyManager: keyManager, stopChan: make(chan struct{}), queueName: "zk_update_queue", processingQueueName: "zk_update_processing_queue", @@ -62,12 +60,11 @@ func NewUpdateQueueWorker(redisQueue *RedisQueue, psv1 []*common.MerkleProofSyst } } -func NewAppendQueueWorker(redisQueue *RedisQueue, psv1 []*common.MerkleProofSystem, psv2 []*common.BatchProofSystem) *AppendQueueWorker { +func NewAppendQueueWorker(redisQueue *RedisQueue, keyManager *common.LazyKeyManager) *AppendQueueWorker { return &AppendQueueWorker{ BaseQueueWorker: &BaseQueueWorker{ queue: redisQueue, - provingSystemsV1: psv1, - provingSystemsV2: psv2, + keyManager: keyManager, stopChan: make(chan struct{}), queueName: "zk_append_queue", processingQueueName: "zk_append_processing_queue", @@ -75,12 +72,11 @@ func NewAppendQueueWorker(redisQueue *RedisQueue, psv1 []*common.MerkleProofSyst } } -func NewAddressAppendQueueWorker(redisQueue *RedisQueue, psv1 []*common.MerkleProofSystem, psv2 []*common.BatchProofSystem) 
+func NewAddressAppendQueueWorker(redisQueue *RedisQueue, keyManager *common.LazyKeyManager) *AddressAppendQueueWorker {
     return &AddressAppendQueueWorker{
         BaseQueueWorker: &BaseQueueWorker{
             queue:               redisQueue,
-            provingSystemsV1:    psv1,
-            provingSystemsV2:    psv2,
+            keyManager:          keyManager,
             stopChan:            make(chan struct{}),
             queueName:           "zk_address_append_queue",
             processingQueueName: "zk_address_append_processing_queue",
@@ -268,19 +264,15 @@ func (w *BaseQueueWorker) processProofJob(job *ProofJob) error {
 }
 
 func (w *BaseQueueWorker) processInclusionProof(payload json.RawMessage, meta common.ProofRequestMeta) (*common.Proof, error) {
-    var ps *common.MerkleProofSystem
-    for _, provingSystem := range w.provingSystemsV1 {
-        if provingSystem.InclusionNumberOfCompressedAccounts == meta.NumInputs &&
-            provingSystem.InclusionTreeHeight == meta.StateTreeHeight &&
-            provingSystem.Version == meta.Version &&
-            provingSystem.NonInclusionNumberOfCompressedAccounts == uint32(0) {
-            ps = provingSystem
-            break
-        }
-    }
-
-    if ps == nil {
-        return nil, fmt.Errorf("no proving system found for inclusion proof with meta: %+v", meta)
+    ps, err := w.keyManager.GetMerkleSystem(
+        meta.StateTreeHeight,
+        meta.NumInputs,
+        0,
+        0,
+        meta.Version,
+    )
+    if err != nil {
+        return nil, fmt.Errorf("inclusion proof: %w", err)
     }
 
     if meta.Version == 1 {
@@ -301,18 +293,15 @@ func (w *BaseQueueWorker) processInclusionProof(payload json.RawMessage, meta co
 }
 
 func (w *BaseQueueWorker) processNonInclusionProof(payload json.RawMessage, meta common.ProofRequestMeta) (*common.Proof, error) {
-    var ps *common.MerkleProofSystem
-    for _, provingSystem := range w.provingSystemsV1 {
-        if provingSystem.NonInclusionNumberOfCompressedAccounts == meta.NumAddresses &&
-            provingSystem.NonInclusionTreeHeight == meta.AddressTreeHeight &&
-            provingSystem.InclusionNumberOfCompressedAccounts == uint32(0) {
-            ps = provingSystem
-            break
-        }
-    }
-
-    if ps == nil {
-        return nil, fmt.Errorf("no proving system found for non-inclusion proof with meta: %+v", meta)
+    ps, err := w.keyManager.GetMerkleSystem(
+        0,
+        0,
+        meta.AddressTreeHeight,
+        meta.NumAddresses,
+        meta.Version,
+    )
+    if err != nil {
+        return nil, fmt.Errorf("non-inclusion proof: %w", err)
     }
 
     if meta.AddressTreeHeight == 26 {
@@ -333,19 +322,15 @@ func (w *BaseQueueWorker) processNonInclusionProof(payload json.RawMessage, meta
 }
 
 func (w *BaseQueueWorker) processCombinedProof(payload json.RawMessage, meta common.ProofRequestMeta) (*common.Proof, error) {
-    var ps *common.MerkleProofSystem
-    for _, provingSystem := range w.provingSystemsV1 {
-        if provingSystem.InclusionNumberOfCompressedAccounts == meta.NumInputs &&
-            provingSystem.NonInclusionNumberOfCompressedAccounts == meta.NumAddresses &&
-            provingSystem.InclusionTreeHeight == meta.StateTreeHeight &&
-            provingSystem.NonInclusionTreeHeight == meta.AddressTreeHeight {
-            ps = provingSystem
-            break
-        }
-    }
-
-    if ps == nil {
-        return nil, fmt.Errorf("no proving system found for combined proof with meta: %+v", meta)
+    ps, err := w.keyManager.GetMerkleSystem(
+        meta.StateTreeHeight,
+        meta.NumInputs,
+        meta.AddressTreeHeight,
+        meta.NumAddresses,
+        meta.Version,
+    )
+    if err != nil {
+        return nil, fmt.Errorf("combined proof: %w", err)
     }
 
     if meta.AddressTreeHeight == 26 {
@@ -371,15 +356,16 @@ func (w *BaseQueueWorker) processBatchUpdateProof(payload json.RawMessage) (*com
         return nil, fmt.Errorf("failed to unmarshal batch update parameters: %w", err)
     }
 
-    for _, provingSystem := range w.provingSystemsV2 {
-        if provingSystem.CircuitType == common.BatchUpdateCircuitType &&
-            provingSystem.TreeHeight == params.Height &&
-            provingSystem.BatchSize == params.BatchSize {
-            return v2.ProveBatchUpdate(provingSystem, &params)
-        }
+    ps, err := w.keyManager.GetBatchSystem(
+        common.BatchUpdateCircuitType,
+        params.Height,
+        params.BatchSize,
+    )
+    if err != nil {
+        return nil, fmt.Errorf("batch update proof: %w", err)
     }
 
-    return nil, fmt.Errorf("no proving system found for batch update with height %d and batch size %d", params.Height, params.BatchSize)
+    return v2.ProveBatchUpdate(ps, &params)
 }
 
 func (w *BaseQueueWorker) processBatchAppendProof(payload json.RawMessage) (*common.Proof, error) {
@@ -388,15 +374,16 @@ func (w *BaseQueueWorker) processBatchAppendProof(payload json.RawMessage) (*com
         return nil, fmt.Errorf("failed to unmarshal batch append parameters: %w", err)
     }
 
-    for _, provingSystem := range w.provingSystemsV2 {
-        if provingSystem.CircuitType == common.BatchAppendCircuitType &&
-            provingSystem.TreeHeight == params.Height &&
-            provingSystem.BatchSize == params.BatchSize {
-            return v2.ProveBatchAppend(provingSystem, &params)
-        }
+    ps, err := w.keyManager.GetBatchSystem(
+        common.BatchAppendCircuitType,
+        params.Height,
+        params.BatchSize,
+    )
+    if err != nil {
+        return nil, fmt.Errorf("batch append proof: %w", err)
     }
 
-    return nil, fmt.Errorf("no proving system found for batch append with height %d and batch size %d", params.Height, params.BatchSize)
+    return v2.ProveBatchAppend(ps, &params)
 }
 
 func (w *BaseQueueWorker) processBatchAddressAppendProof(payload json.RawMessage) (*common.Proof, error) {
@@ -405,17 +392,17 @@ func (w *BaseQueueWorker) processBatchAddressAppendProof(payload json.RawMessage
         return nil, fmt.Errorf("failed to unmarshal batch address append parameters: %w", err)
     }
 
-    for _, provingSystem := range w.provingSystemsV2 {
-        logging.Logger().Info().Str(string(provingSystem.CircuitType), "proving system")
-        if provingSystem.CircuitType == common.BatchAddressAppendCircuitType &&
-            provingSystem.TreeHeight == params.TreeHeight &&
-            provingSystem.BatchSize == params.BatchSize {
-            logging.Logger().Info().Msg("Processing batch address append proof")
-            return v2.ProveBatchAddressAppend(provingSystem, &params)
-        }
+    ps, err := w.keyManager.GetBatchSystem(
+        common.BatchAddressAppendCircuitType,
+        params.TreeHeight,
+        params.BatchSize,
+    )
+    if err != nil {
+        return nil, fmt.Errorf("batch address append proof: %w", err)
     }
 
-    return nil, fmt.Errorf("no proving system found for batch address append with height %d and batch size %d", params.TreeHeight, params.BatchSize)
+    logging.Logger().Info().Msg("Processing batch address append proof")
+    return v2.ProveBatchAddressAppend(ps, &params)
 }
 
 func (w *BaseQueueWorker) removeFromProcessingQueue(jobID string) {
diff --git a/prover/server/server/server.go b/prover/server/server/server.go
index 031ccfac96..23b089c815 100644
--- a/prover/server/server/server.go
+++ b/prover/server/server/server.go
@@ -268,12 +268,9 @@ type EnhancedConfig struct {
 }
 
 type proveHandler struct {
-    provingSystemsV1 []*common.MerkleProofSystem
-    provingSystemsV2 []*common.BatchProofSystem
-    redisQueue       *RedisQueue
-    enableQueue      bool
-    runMode          common.RunMode
-    circuits         []string
+    keyManager  *common.LazyKeyManager
+    redisQueue  *RedisQueue
+    enableQueue bool
 }
 
 func (handler proveHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@@ -459,17 +456,17 @@ func (handler queueCleanupHandler) ServeHTTP(w http.ResponseWriter, r *http.Requ
     }
 }
 
-func RunWithQueue(config *Config, redisQueue *RedisQueue, circuits []string, runMode common.RunMode, provingSystemsV1 []*common.MerkleProofSystem, provingSystemsV2 []*common.BatchProofSystem) RunningJob {
+func RunWithQueue(config *Config, redisQueue *RedisQueue, keyManager *common.LazyKeyManager) RunningJob {
     return RunEnhanced(&EnhancedConfig{
         ProverAddress:  config.ProverAddress,
         MetricsAddress: config.MetricsAddress,
         Queue: &QueueConfig{
             Enabled: redisQueue != nil,
         },
-    }, redisQueue, circuits, runMode, provingSystemsV1, provingSystemsV2)
+    }, redisQueue, keyManager)
 }
 
-func RunEnhanced(config *EnhancedConfig, redisQueue *RedisQueue, circuits []string, runMode common.RunMode, provingSystemsV1 []*common.MerkleProofSystem, provingSystemsV2 []*common.BatchProofSystem) RunningJob {
+func RunEnhanced(config *EnhancedConfig, redisQueue *RedisQueue, keyManager *common.LazyKeyManager) RunningJob {
     apiKey := getAPIKeyFromEnv()
     if apiKey != "" {
         logging.Logger().Info().Msg("API key authentication enabled for prover server")
@@ -485,12 +482,9 @@ func RunEnhanced(config *EnhancedConfig, redisQueue *RedisQueue, circuits []stri
 
     proverMux := http.NewServeMux()
     proverMux.Handle("/prove", proveHandler{
-        provingSystemsV1: provingSystemsV1,
-        provingSystemsV2: provingSystemsV2,
-        redisQueue:       redisQueue,
-        enableQueue:      config.Queue != nil && config.Queue.Enabled,
-        runMode:          runMode,
-        circuits:         circuits,
+        keyManager:  keyManager,
+        redisQueue:  redisQueue,
+        enableQueue: config.Queue != nil && config.Queue.Enabled,
     })
 
     proverMux.Handle("/health", healthHandler{})
@@ -585,8 +579,8 @@ func RunEnhanced(config *EnhancedConfig, redisQueue *RedisQueue, circuits []stri
     return CombineJobs(metricsJob, proverJob)
 }
 
-func Run(config *Config, circuits []string, runMode common.RunMode, provingSystemsV1 []*common.MerkleProofSystem, provingSystemsV2 []*common.BatchProofSystem) RunningJob {
-    return RunWithQueue(config, nil, circuits, runMode, provingSystemsV1, provingSystemsV2)
+func Run(config *Config, keyManager *common.LazyKeyManager) RunningJob {
+    return RunWithQueue(config, nil, keyManager)
 }
 
 type Error struct {
@@ -898,16 +892,9 @@ func (handler proveHandler) batchAddressAppendProof(buf []byte) (*common.Proof,
     treeHeight := params.TreeHeight
     batchSize := params.BatchSize
 
-    var ps *common.BatchProofSystem
-    for _, provingSystem := range handler.provingSystemsV2 {
-        if provingSystem.CircuitType == common.BatchAddressAppendCircuitType && provingSystem.TreeHeight == treeHeight && provingSystem.BatchSize == batchSize {
-            ps = provingSystem
-            break
-        }
-    }
-
-    if ps == nil {
-        return nil, provingError(fmt.Errorf("batch address append: no proving system for tree height %d and batch size %d", treeHeight, batchSize))
+    ps, err := handler.keyManager.GetBatchSystem(common.BatchAddressAppendCircuitType, treeHeight, batchSize)
+    if err != nil {
+        return nil, provingError(fmt.Errorf("batch address append: %w", err))
     }
 
     proof, err := v2.ProveBatchAddressAppend(ps, &params)
@@ -928,16 +915,9 @@ func (handler proveHandler) batchAppendHandler(buf []byte) (*common.Proof, *Erro
     treeHeight := params.Height
     batchSize := params.BatchSize
 
-    var ps *common.BatchProofSystem
-    for _, provingSystem := range handler.provingSystemsV2 {
-        if provingSystem.CircuitType == common.BatchAppendCircuitType && provingSystem.TreeHeight == treeHeight && provingSystem.BatchSize == batchSize {
-            ps = provingSystem
-            break
-        }
-    }
-
-    if ps == nil {
-        return nil, provingError(fmt.Errorf("no proving system for tree height %d and batch size %d", treeHeight, batchSize))
+    ps, err := handler.keyManager.GetBatchSystem(common.BatchAppendCircuitType, treeHeight, batchSize)
+    if err != nil {
+        return nil, provingError(fmt.Errorf("batch append: %w", err))
     }
 
     proof, err := v2.ProveBatchAppend(ps, &params)
@@ -959,16 +939,9 @@ func (handler proveHandler) batchUpdateProof(buf []byte) (*common.Proof, *Error)
     treeHeight := params.Height
     batchSize := params.BatchSize
 
-    var ps *common.BatchProofSystem
-    for _, provingSystem := range handler.provingSystemsV2 {
-        if provingSystem.CircuitType == common.BatchUpdateCircuitType && provingSystem.TreeHeight == treeHeight && provingSystem.BatchSize == batchSize {
-            ps = provingSystem
-            break
-        }
-    }
-
-    if ps == nil {
-        return nil, provingError(fmt.Errorf("no proving system for tree height %d and batch size %d", treeHeight, batchSize))
+    ps, err := handler.keyManager.GetBatchSystem(common.BatchUpdateCircuitType, treeHeight, batchSize)
+    if err != nil {
+        return nil, provingError(fmt.Errorf("batch update: %w", err))
     }
 
     proof, err := v2.ProveBatchUpdate(ps, &params)
@@ -980,16 +953,15 @@ func (handler proveHandler) batchUpdateProof(buf []byte) (*common.Proof, *Error)
 }
 
 func (handler proveHandler) inclusionProof(buf []byte, proofRequestMeta common.ProofRequestMeta) (*common.Proof, *Error) {
-    var ps *common.MerkleProofSystem
-    for _, provingSystem := range handler.provingSystemsV1 {
-        if provingSystem.InclusionNumberOfCompressedAccounts == proofRequestMeta.NumInputs && provingSystem.InclusionTreeHeight == proofRequestMeta.StateTreeHeight && provingSystem.Version == proofRequestMeta.Version && provingSystem.NonInclusionNumberOfCompressedAccounts == uint32(0) {
-            ps = provingSystem
-            break
-        }
-    }
-
-    if ps == nil {
-        return nil, provingError(fmt.Errorf("no proving system for %+v proofRequest", proofRequestMeta))
+    ps, err := handler.keyManager.GetMerkleSystem(
+        proofRequestMeta.StateTreeHeight,
+        proofRequestMeta.NumInputs,
+        0,
+        0,
+        proofRequestMeta.Version,
+    )
+    if err != nil {
+        return nil, provingError(fmt.Errorf("inclusion proof: %w", err))
     }
 
     if proofRequestMeta.Version == 1 {
@@ -1019,17 +991,20 @@ func (handler proveHandler) inclusionProof(buf []byte, proofRequestMeta common.P
 }
 
 func (handler proveHandler) nonInclusionProof(buf []byte, proofRequestMeta common.ProofRequestMeta) (*common.Proof, *Error) {
-
-    var ps *common.MerkleProofSystem
-    for _, provingSystem := range handler.provingSystemsV1 {
-        if provingSystem.NonInclusionNumberOfCompressedAccounts == uint32(proofRequestMeta.NumAddresses) && provingSystem.NonInclusionTreeHeight == uint32(proofRequestMeta.AddressTreeHeight) && provingSystem.InclusionNumberOfCompressedAccounts == uint32(0) {
-            ps = provingSystem
-            break
-        }
+    version := uint32(1)
+    if proofRequestMeta.AddressTreeHeight == 40 {
+        version = 2
     }
 
-    if ps == nil {
-        return nil, provingError(fmt.Errorf("no proving system for %+v proofRequest", proofRequestMeta))
+    ps, err := handler.keyManager.GetMerkleSystem(
+        0,
+        0,
+        proofRequestMeta.AddressTreeHeight,
+        proofRequestMeta.NumAddresses,
+        version,
+    )
+    if err != nil {
+        return nil, provingError(fmt.Errorf("non-inclusion proof: %w", err))
     }
 
     if proofRequestMeta.AddressTreeHeight == 26 {
@@ -1068,16 +1043,20 @@ func (handler proveHandler) nonInclusionProof(buf []byte, proofRequestMeta commo
 }
 
 func (handler proveHandler) combinedProof(buf []byte, proofRequestMeta common.ProofRequestMeta) (*common.Proof, *Error) {
-    var ps *common.MerkleProofSystem
-    for _, provingSystem := range handler.provingSystemsV1 {
-        if provingSystem.InclusionNumberOfCompressedAccounts == proofRequestMeta.NumInputs && provingSystem.NonInclusionNumberOfCompressedAccounts == proofRequestMeta.NumAddresses && provingSystem.InclusionTreeHeight == proofRequestMeta.StateTreeHeight && provingSystem.NonInclusionTreeHeight == proofRequestMeta.AddressTreeHeight {
-            ps = provingSystem
-            break
-        }
+    version := uint32(1)
+    if proofRequestMeta.AddressTreeHeight == 40 {
+        version = 2
     }
 
-    if ps == nil {
-        return nil, provingError(fmt.Errorf("no proving system for %+v proofRequest", proofRequestMeta))
+    ps, err := handler.keyManager.GetMerkleSystem(
+        proofRequestMeta.StateTreeHeight,
+        proofRequestMeta.NumInputs,
+        proofRequestMeta.AddressTreeHeight,
+        proofRequestMeta.NumAddresses,
+        version,
+    )
+    if err != nil {
+        return nil, provingError(fmt.Errorf("combined proof: %w", err))
     }
 
     if proofRequestMeta.AddressTreeHeight == 26 {
diff --git a/scripts/devenv/download-gnark-keys.sh b/scripts/devenv/download-gnark-keys.sh
index 52373e183c..0413445663 100755
--- a/scripts/devenv/download-gnark-keys.sh
+++ b/scripts/devenv/download-gnark-keys.sh
@@ -6,11 +6,38 @@ source "${SCRIPT_DIR}/shared.sh"
 download_gnark_keys() {
     local key_type=${1:-light}
     ROOT_DIR="$(git rev-parse --show-toplevel)"
+    PROVER_DIR="${ROOT_DIR}/prover/server"
+    KEYS_DIR="${ROOT_DIR}/prover/server/proving-keys"
 
-    if [ ! -d "${ROOT_DIR}/prover/server/proving-keys" ] || [ -z "$(ls -A "${ROOT_DIR}/prover/server/proving-keys" 2>/dev/null)" ]; then
-        echo "Downloading gnark keys..."
-        "${ROOT_DIR}/prover/server/scripts/download_keys.sh" "$key_type"
-        log "gnark_keys"
+    case "$key_type" in
+        "light")
+            RUN_MODE="forester-test"
+            ;;
+        "full")
+            RUN_MODE="full"
+            ;;
+        *)
+            echo "Invalid key type: $key_type (expected 'light' or 'full')"
+            exit 1
+            ;;
+    esac
+
+    if [ ! -d "${KEYS_DIR}" ] || [ -z "$(ls -A "${KEYS_DIR}" 2>/dev/null)" ]; then
+        echo "Downloading gnark keys (run-mode: ${RUN_MODE})..."
+        cd "${PROVER_DIR}" || {
+            echo "Error: Failed to change directory to ${PROVER_DIR}" >&2
+            exit 1
+        }
+        if go run . download \
+            --run-mode="${RUN_MODE}" \
+            --keys-dir="${KEYS_DIR}" \
+            --max-retries=10; then
+            log "gnark_keys"
+        else
+            exit_code=$?
+            echo "Error: Failed to download gnark keys (exit code: ${exit_code})" >&2
+            exit ${exit_code}
+        fi
     else
         echo "Gnark keys already exist, skipping download..."
     fi
diff --git a/sdk-libs/client/src/lib.rs b/sdk-libs/client/src/lib.rs
index 9726c9b0fe..5ab761c25e 100644
--- a/sdk-libs/client/src/lib.rs
+++ b/sdk-libs/client/src/lib.rs
@@ -30,7 +30,6 @@
 //!     indexer::{Indexer, IndexerRpcConfig, RetryConfig},
 //!     local_test_validator::{spawn_validator, LightValidatorConfig},
 //! };
-//! use light_prover_client::prover::ProverConfig;
 //! use solana_pubkey::Pubkey;
 //!
 //! #[tokio::main]
@@ -38,7 +37,7 @@
 //!     // Start local test validator with Light Protocol programs
 //!     let config = LightValidatorConfig {
 //!         enable_indexer: true,
-//!         prover_config: Some(ProverConfig::default()),
+//!         enable_prover: true,
 //!         wait_time: 75,
 //!         sbf_programs: vec![],
 //!         limit_ledger_size: None,
@@ -89,5 +88,4 @@ pub mod indexer;
 pub mod local_test_validator;
 pub mod rpc;
 
-/// Reexport for ProverConfig and other types.
 pub use light_prover_client;
diff --git a/sdk-libs/client/src/local_test_validator.rs b/sdk-libs/client/src/local_test_validator.rs
index ec284984f8..2d46ba7e72 100644
--- a/sdk-libs/client/src/local_test_validator.rs
+++ b/sdk-libs/client/src/local_test_validator.rs
@@ -1,12 +1,11 @@
 use std::process::{Command, Stdio};
 
 use light_prover_client::helpers::get_project_root;
-pub use light_prover_client::prover::ProverConfig;
 
 #[derive(Debug)]
 pub struct LightValidatorConfig {
     pub enable_indexer: bool,
-    pub prover_config: Option<ProverConfig>,
+    pub enable_prover: bool,
     pub wait_time: u64,
     pub sbf_programs: Vec<(String, String)>,
     pub limit_ledger_size: Option<u64>,
@@ -16,7 +15,7 @@ impl Default for LightValidatorConfig {
     fn default() -> Self {
         Self {
             enable_indexer: false,
-            prover_config: None,
+            enable_prover: false,
             wait_time: 35,
             sbf_programs: vec![],
             limit_ledger_size: None,
@@ -43,14 +42,7 @@ pub async fn spawn_validator(config: LightValidatorConfig) {
         ));
     }
 
-    if let Some(prover_config) = config.prover_config {
-        prover_config.circuits.iter().for_each(|circuit| {
-            path.push_str(&format!(" --circuit {}", circuit));
-        });
-        if let Some(prover_mode) = prover_config.run_mode {
-            path.push_str(&format!(" --prover-run-mode {}", prover_mode));
-        }
-    } else {
+    if !config.enable_prover {
         path.push_str(" --skip-prover");
     }
diff --git a/sdk-libs/program-test/src/program_test/config.rs b/sdk-libs/program-test/src/program_test/config.rs
index 65a95e7161..e65d29e42a 100644
--- a/sdk-libs/program-test/src/program_test/config.rs
+++ b/sdk-libs/program-test/src/program_test/config.rs
@@ -7,7 +7,6 @@ use light_batched_merkle_tree::{
     initialize_address_tree::InitAddressTreeAccountsInstructionData,
     initialize_state_tree::InitStateTreeAccountsInstructionData,
 };
-use light_prover_client::prover::ProverConfig;
 #[cfg(feature = "devenv")]
 use light_registry::protocol_config::state::ProtocolConfig;
 use solana_sdk::pubkey::Pubkey;
@@ -21,7 +20,6 @@ pub struct ProgramTestConfig {
     #[cfg(feature = "devenv")]
     pub protocol_config: ProtocolConfig,
     pub with_prover: bool,
-    pub prover_config: Option<ProverConfig>,
     #[cfg(feature = "devenv")]
     pub skip_register_programs: bool,
     #[cfg(feature = "devenv")]
@@ -73,7 +71,6 @@ impl ProgramTestConfig {
     ) -> Self {
         Self {
             additional_programs,
-            prover_config: Some(ProverConfig::default()),
             with_prover,
             #[cfg(feature = "devenv")]
             v2_state_tree_config: Some(InitStateTreeAccountsInstructionData::test_default()),
@@ -87,7 +84,6 @@ impl ProgramTestConfig {
     pub fn default_with_batched_trees(with_prover: bool) -> Self {
         Self {
             additional_programs: None,
-            prover_config: Some(ProverConfig::default()),
             with_prover,
             v2_state_tree_config: Some(InitStateTreeAccountsInstructionData::test_default()),
             v2_address_tree_config: Some(InitAddressTreeAccountsInstructionData::test_default()),
@@ -102,7 +98,6 @@ impl ProgramTestConfig {
             with_prover,
             v2_state_tree_config: Some(InitStateTreeAccountsInstructionData::test_default()),
             v2_address_tree_config: Some(InitAddressTreeAccountsInstructionData::test_default()),
-            prover_config: Some(ProverConfig::default()),
             ..Default::default()
         }
     }
@@ -134,7 +129,6 @@ impl Default for ProgramTestConfig {
                 ..Default::default()
             },
             with_prover: true,
-            prover_config: None,
             #[cfg(feature = "devenv")]
             skip_second_v1_tree: false,
             #[cfg(feature = "devenv")]
diff --git a/sdk-libs/program-test/src/program_test/light_program_test.rs b/sdk-libs/program-test/src/program_test/light_program_test.rs
index 6661d370b4..b55d0e2a90 100644
--- a/sdk-libs/program-test/src/program_test/light_program_test.rs
+++ b/sdk-libs/program-test/src/program_test/light_program_test.rs
@@ -8,7 +8,7 @@ use light_client::{
 };
 #[cfg(feature = "devenv")]
 use light_compressed_account::hash_to_bn254_field_size_be;
-use light_prover_client::prover::{spawn_prover, ProverConfig};
+use light_prover_client::prover::spawn_prover;
 use litesvm::LiteSVM;
 #[cfg(feature = "devenv")]
 use solana_account::WritableAccount;
@@ -318,22 +318,16 @@ impl LightProgramTest {
         // reset tx counter after program setup.
         context.transaction_counter = 0;
 
-        // Will always start a prover server.
+
         #[cfg(feature = "devenv")]
-        let prover_config = if config.prover_config.is_none() {
-            Some(ProverConfig::default())
-        } else {
-            config.prover_config
-        };
+        {
+            spawn_prover().await;
+        }
 
         #[cfg(not(feature = "devenv"))]
-        let prover_config = if config.with_prover && config.prover_config.is_none() {
-            Some(ProverConfig::default())
-        } else {
-            config.prover_config
-        };
-        if let Some(ref prover_config) = prover_config {
-            spawn_prover(prover_config.clone()).await;
+        if config.with_prover {
+            spawn_prover().await;
         }
+
         Ok(context)
     }
diff --git a/sdk-tests/client-test/tests/light_client.rs b/sdk-tests/client-test/tests/light_client.rs
index 5bcf2074bc..7ee8ba84a2 100644
--- a/sdk-tests/client-test/tests/light_client.rs
+++ b/sdk-tests/client-test/tests/light_client.rs
@@ -14,7 +14,6 @@ use light_compressed_token::mint_sdk::{
 use light_hasher::Poseidon;
 use light_merkle_tree_reference::{indexed::IndexedMerkleTree, MerkleTree};
 use light_program_test::accounts::test_accounts::TestAccounts;
-use light_prover_client::prover::ProverConfig;
 use light_sdk::{
     address::{v1::derive_address, NewAddressParams},
     token::{AccountState, TokenData},
@@ -53,8 +52,8 @@ const LAMPORTS_PER_SOL: u64 = 1_000_000_000;
 async fn test_all_endpoints() {
     let config = LightValidatorConfig {
         enable_indexer: true,
-        prover_config: Some(ProverConfig::default()),
-        wait_time: 75,
+        enable_prover: true,
+        wait_time: 10,
         sbf_programs: vec![],
         limit_ledger_size: None,
     };