1515# Description:
1616# TensorFlow Probability ODE solvers.
1717
18+ load (
19+ "//tensorflow_probability/python:build_defs.bzl" ,
20+ "multi_substrate_py_library" ,
21+ "multi_substrate_py_test" ,
22+ )
23+
1824package (
1925 default_visibility = [
2026 "//tensorflow_probability:__subpackages__" ,
@@ -23,18 +29,20 @@ package(
2329
2430licenses (["notice" ])
2531
26- py_library (
32+ multi_substrate_py_library (
2733 name = "base" ,
2834 srcs = ["base.py" ],
2935 srcs_version = "PY3" ,
3036 deps = [
3137 # six dep,
3238 # tensorflow dep,
39+ "//tensorflow_probability/python/internal:custom_gradient" ,
3340 "//tensorflow_probability/python/internal:dtype_util" ,
41+ "//tensorflow_probability/python/math:gradient" ,
3442 ],
3543)
3644
37- py_library (
45+ multi_substrate_py_library (
3846 name = "bdf" ,
3947 srcs = ["bdf.py" ],
4048 srcs_version = "PY3" ,
@@ -45,10 +53,12 @@ py_library(
4553 # numpy dep,
4654 # tensorflow dep,
4755 "//tensorflow_probability/python/internal:dtype_util" ,
56+ "//tensorflow_probability/python/internal:prefer_static" ,
57+ "//tensorflow_probability/python/internal:tensorshape_util" ,
4858 ],
4959)
5060
51- py_library (
61+ multi_substrate_py_library (
5262 name = "dormand_prince" ,
5363 srcs = ["dormand_prince.py" ],
5464 srcs_version = "PY3" ,
@@ -61,17 +71,20 @@ py_library(
6171 ],
6272)
6373
64- py_library (
74+ multi_substrate_py_library (
6575 name = "bdf_util" ,
6676 srcs = ["bdf_util.py" ],
6777 srcs_version = "PY3" ,
6878 deps = [
6979 # numpy dep,
7080 # tensorflow dep,
81+ "//tensorflow_probability/python/internal:dtype_util" ,
82+ "//tensorflow_probability/python/internal:prefer_static" ,
83+ "//tensorflow_probability/python/internal:tensorshape_util" ,
7184 ],
7285)
7386
74- py_test (
87+ multi_substrate_py_test (
7588 name = "bdf_util_test" ,
7689 size = "small" ,
7790 srcs = ["bdf_util_test.py" ],
@@ -87,7 +100,7 @@ py_test(
87100 ],
88101)
89102
90- py_library (
103+ multi_substrate_py_library (
91104 name = "runge_kutta_util" ,
92105 srcs = ["runge_kutta_util.py" ],
93106 srcs_version = "PY3" ,
@@ -99,7 +112,7 @@ py_library(
99112 ],
100113)
101114
102- py_test (
115+ multi_substrate_py_test (
103116 name = "runge_kutta_util_test" ,
104117 size = "small" ,
105118 srcs = ["runge_kutta_util_test.py" ],
@@ -113,7 +126,7 @@ py_test(
113126 ],
114127)
115128
116- py_library (
129+ multi_substrate_py_library (
117130 name = "ode" ,
118131 srcs = ["__init__.py" ],
119132 srcs_version = "PY3" ,
@@ -124,12 +137,12 @@ py_library(
124137 ],
125138)
126139
127- py_test (
140+ multi_substrate_py_test (
128141 name = "ode_test" ,
129142 size = "large" ,
130143 srcs = ["ode_test.py" ],
131144 python_version = "PY3" ,
132- shard_count = 6 ,
145+ shard_count = 8 ,
133146 srcs_version = "PY3" ,
134147 deps = [
135148 # absl/testing:parameterized dep,
@@ -183,17 +196,18 @@ py_test(
183196 ],
184197)
185198
186- py_library (
199+ multi_substrate_py_library (
187200 name = "util" ,
188201 srcs = ["util.py" ],
189202 deps = [
190203 # tensorflow dep,
204+ "//tensorflow_probability/python/internal:dtype_util" ,
191205 "//tensorflow_probability/python/internal:prefer_static" ,
192206 "//tensorflow_probability/python/math:gradient" ,
193207 ],
194208)
195209
196- py_test (
210+ multi_substrate_py_test (
197211 name = "util_test" ,
198212 size = "small" ,
199213 srcs = ["util_test.py" ],
0 commit comments