Skip to content

Commit 29ba11b

Browse files
committed
Add TPU tests for JAX and Tensorflow.
Using the runners.
1 parent eb4a13f commit 29ba11b

File tree

4 files changed

+117
-50
lines changed

4 files changed

+117
-50
lines changed

.github/workflows/actions.yml

Lines changed: 89 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -12,65 +12,118 @@ permissions:
1212

1313
jobs:
1414
run_tests:
15-
name: Test the code
15+
name: Test the code on CPU
16+
runs-on: ubuntu-latest
17+
1618
strategy:
1719
fail-fast: false
1820
matrix:
1921
backend: [tensorflow, jax, torch]
20-
runs-on: ubuntu-latest
22+
2123
env:
2224
KERAS_BACKEND: ${{ matrix.backend }}
25+
2326
steps:
2427
- uses: actions/checkout@v4
25-
- name: Set up Python 3.11
26-
uses: actions/setup-python@v5
28+
29+
- uses: actions/setup-python@v5
2730
with:
2831
python-version: "3.11"
32+
2933
- name: Get pip cache dir
3034
id: pip-cache
3135
run: |
32-
python -m pip install --upgrade pip setuptools
36+
python -m pip install --upgrade pip
3337
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
34-
- name: pip cache
35-
uses: actions/cache@v4
38+
39+
- uses: actions/cache@v4
3640
with:
3741
path: ${{ steps.pip-cache.outputs.dir }}
3842
key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}
3943
restore-keys: |
4044
${{ runner.os }}-pip-
45+
4146
- name: Install dependencies
42-
run: |
43-
pip install -r requirements.txt --progress-bar off
44-
pip install --no-deps -e "." --progress-bar off
47+
run: pip install -r requirements.txt --progress-bar off
48+
4549
- name: Test with pytest
50+
run: pytest keras_rs/
51+
52+
run_tests_in_container:
53+
name: Test the code on TPU
54+
runs-on: linux-x86-ct6e-44-1tpu
55+
56+
strategy:
57+
fail-fast: false
58+
matrix:
59+
backend: [tensorflow, jax]
60+
61+
container:
62+
image: python:3.11-slim
63+
options: --privileged --network host
64+
65+
steps:
66+
- uses: actions/checkout@v4
67+
68+
- name: Install Dependencies
69+
run: |
70+
pip install --no-cache-dir -U pip && \
71+
pip install --no-cache-dir -r requirements-${{ matrix.backend }}-tpu.txt
72+
73+
- name: Set Keras Backend
4674
run: |
47-
pytest keras_rs/
75+
echo "KERAS_BACKEND=${{ matrix.backend }}" >> $GITHUB_ENV
76+
echo "TPU_NAME=local" >> $GITHUB_ENV
77+
78+
- name: Set TF Specific Environment Variables
79+
if: ${{ matrix.backend == 'tensorflow'}}
80+
run: |
81+
echo "PJRT_DEVICE=TPU" >> $GITHUB_ENV
82+
echo "NEXT_PLUGGABLE_DEVICE_USE_C_API=true" >> $GITHUB_ENV
83+
echo "TF_XLA_FLAGS=--tf_mlir_enable_mlir_bridge=true" >> $GITHUB_ENV
84+
pip show libtpu | grep "^Location: " | sed "s/^Location: \(.*\)$/TF_PLUGGABLE_DEVICE_LIBRARY_PATH=\1\/libtpu\/libtpu.so/1" >> $GITHUB_ENV
85+
86+
- name: Verify TF Installation
87+
if: ${{ matrix.backend == 'tensorflow'}}
88+
run: python3 -c "import tensorflow as tf; print('Tensorflow devices:', tf.config.list_logical_devices())"
89+
90+
- name: Verify JAX Installation
91+
if: ${{ matrix.backend == 'jax'}}
92+
run: python3 -c "import jax; print('JAX devices:', jax.devices())"
93+
94+
- name: Test with pytest
95+
run: pytest keras_rs/src/layers/embedding/distributed_embedding_test.py
96+
4897
check_format:
4998
name: Check the code format
5099
runs-on: ubuntu-latest
100+
51101
steps:
52-
- uses: actions/checkout@v4
53-
- name: Set up Python 3.11
54-
uses: actions/setup-python@v5
55-
with:
56-
python-version: "3.11"
57-
- name: Get pip cache dir
58-
id: pip-cache
59-
run: |
60-
python -m pip install --upgrade pip setuptools
61-
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
62-
- name: pip cache
63-
uses: actions/cache@v4
64-
with:
65-
path: ${{ steps.pip-cache.outputs.dir }}
66-
key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}
67-
restore-keys: |
68-
${{ runner.os }}-pip-
69-
- name: Install dependencies
70-
run: |
71-
pip install -r requirements.txt --progress-bar off
72-
pip install --no-deps -e "." --progress-bar off
73-
- name: Install pre-commit
74-
run: pip install pre-commit && pre-commit install
75-
- name: Run pre-commit
76-
run: pre-commit run --all-files --hook-stage manual
102+
- uses: actions/checkout@v4
103+
104+
- uses: actions/setup-python@v5
105+
with:
106+
python-version: "3.11"
107+
108+
- name: Get pip cache dir
109+
id: pip-cache
110+
run: |
111+
python -m pip install --upgrade pip
112+
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
113+
114+
- uses: actions/cache@v4
115+
with:
116+
path: ${{ steps.pip-cache.outputs.dir }}
117+
key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}
118+
restore-keys: |
119+
${{ runner.os }}-pip-
120+
121+
- name: Install dependencies
122+
run: |
123+
pip install -r requirements.txt --progress-bar off
124+
125+
- name: Install pre-commit
126+
run: pip install pre-commit && pre-commit install
127+
128+
- name: Run pre-commit
129+
run: pre-commit run --all-files --hook-stage manual

keras_rs/src/layers/embedding/distributed_embedding_test.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import keras
1111
import numpy as np
1212
import tensorflow as tf
13-
from absl import flags
1413
from absl.testing import absltest
1514
from absl.testing import parameterized
1615

@@ -19,10 +18,6 @@
1918
from keras_rs.src.layers.embedding import distributed_embedding
2019
from keras_rs.src.layers.embedding import distributed_embedding_config as config
2120

22-
FLAGS = flags.FLAGS
23-
_TPU = flags.DEFINE_string("tpu", None, "The TPU to use for TPUStrategy.")
24-
25-
2621
FEATURE1_EMBEDDING_OUTPUT_DIM = 7
2722
FEATURE2_EMBEDDING_OUTPUT_DIM = 11
2823
EMBEDDING_OUTPUT_DIM = 7
@@ -50,29 +45,32 @@ def experimental_distribute_dataset(self, dataset, options=None):
5045
class JaxDummyStrategy(DummyStrategy):
5146
@property
5247
def num_replicas_in_sync(self):
53-
return len(jax.devices("tpu"))
48+
return jax.device_count("tpu")
49+
50+
51+
def ragged_bool_true(self):
52+
return True
5453

5554

5655
class DistributedEmbeddingTest(testing.TestCase, parameterized.TestCase):
5756
def setUp(self):
5857
super().setUp()
59-
try:
60-
self.on_tpu = _TPU.value is not None
61-
except flags.UnparsedFlagAccessError:
62-
self.on_tpu = False
63-
58+
self.on_tpu = "TPU_NAME" in os.environ
6459
self.placement = "sparsecore" if self.on_tpu else "default_device"
6560

6661
if keras.backend.backend() == "tensorflow":
6762
tf.debugging.disable_traceback_filtering()
6863

6964
if keras.backend.backend() == "tensorflow" and self.on_tpu:
65+
# Workaround for a bug preventing weights from being ragged tensors.
66+
# The fix in TensorFlow was added after 2.19.1:
67+
# https://github.com/tensorflow/tensorflow/commit/185f2f58bafc6410125080264d5d7730e1fa1eb2
68+
tf.RaggedTensor.__bool__ = ragged_bool_true
69+
7070
# FLAGS.xla_sparse_core_max_ids_per_partition_per_sample = 16
7171
# FLAGS.xla_sparse_core_max_unique_ids_per_partition_per_sample = 16
7272

73-
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
74-
tpu=_TPU.value
75-
)
73+
resolver = tf.distribute.cluster_resolver.TPUClusterResolver("")
7674
tf.config.experimental_connect_to_cluster(resolver)
7775

7876
topology = tf.tpu.experimental.initialize_tpu_system(resolver)

requirements-jax-tpu.txt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Tensorflow cpu-only version.
2+
tensorflow-cpu>=2.20.0
3+
4+
# Jax with TPU support.
5+
jax[tpu]
6+
7+
# Support for TPU embeddings.
8+
jax-tpu-embedding
9+
10+
-r requirements-common.txt

requirements-tensorflow-tpu.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Tensorflow with TPU support.
2+
tensorflow-tpu==2.19.1
3+
4+
jax[cpu]
5+
6+
-r requirements-common.txt

0 commit comments

Comments
 (0)