Skip to content

Commit 8436c7e

Browse files
author
Orbax Authors
committed
Add a multihost build job for Orbax Checkpoint.
PiperOrigin-RevId: 819744217
1 parent b9c9f5d commit 8436c7e

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

.github/workflows/build.yml

Lines changed: 54 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -81,6 +81,60 @@ jobs:
8181
"context": "github-actions/build"
8282
}'
8383
84+
# Job: install Orbax Checkpoint from the `checkpoint` directory inside the
# container image below, run the broadcast multislice test with pytest, then
# report the job result as a commit status.
build-checkpoint-multihost:
  name: "build-checkpoint-multihost (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})"
  runs-on: linux-g2-16-l4-1gpu-x4
  container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e
  defaults:
    run:
      # Container jobs run `run` steps with `sh` by default; the install step
      # relies on the bash-only `[[ ... ]]` test, so select bash explicitly.
      shell: bash
      working-directory: checkpoint
  strategy:
    matrix:
      python-version: ["3.12"]
      jax-version: ["newest"]
  steps:
    - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b  # v5.3.0
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      # TODO(b/275613424): remove `pip install -e .` and `pip uninstall -y orbax`.
      # Currently in place to override remote orbax import due to flax dependency.
      #
      # The jax install is selected by the `jax-version` matrix value:
      # "newest" upgrades to the latest release, "nightly" pulls pre-releases
      # from the extra index, and any other value is used as a `>=` lower bound.
      # Requirement specs with extras are quoted so the shell never treats the
      # `[...]` part as a glob pattern.
      run: |
        pip install -e .
        pip install -e ".[testing,gcs]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
        pip uninstall -y orbax
        if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
          pip install -U "jax[k8s,cuda12]" jaxlib
        elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then
          pip install -U --pre "jax[k8s,cuda12]" jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
        else
          pip install "jax[k8s,cuda12]>=${{ matrix.jax-version }}" "jaxlib>=${{ matrix.jax-version }}"
        fi
    - name: Test with pytest
      # TODO(yaning): Move these to an exclude target within pytest.ini.
      run: |
        python -m pytest orbax/checkpoint/experimental/emergency/broadcast_multislice_test.py
    # The below step just reports the success or failure of tests as a "commit status".
    # This is needed for copybara integration.
    # NOTE(review): `job.status` can also be `cancelled`, which does not look
    # like a valid commit-status `state` - confirm that case is acceptable.
    - name: Report success or failure as github status
      if: always()
      shell: bash
      run: |
        status="${{ job.status }}"
        lowercase_status=$(echo "$status" | tr '[:upper:]' '[:lower:]')
        curl -sS --request POST \
        --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \
        --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \
        --header 'content-type: application/json' \
        --data '{
        "state": "'$lowercase_status'",
        "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}",
        "description": "'$status'",
        "context": "github-actions/build"
        }'
137+
84138
build-export:
85139
name: "build-export (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})"
86140
runs-on: ubuntu-latest

0 commit comments

Comments
 (0)