Skip to content

Commit 68eaf8a

Browse files
committed
feat: preserve sharding
1 parent aa8142a commit 68eaf8a

File tree

2 files changed

+120
-10
lines changed

2 files changed

+120
-10
lines changed

src/Sharding.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -949,7 +949,7 @@ function HloSharding(sharding::NamedSharding, client::XLA.IFRT.Client, _, x)
949949
data = XLA.IFRT.AsyncArray(client, x, ifrt_sharding)
950950

951951
# XXX: Can we auto-pad this case too? Will think about it later, for now use
952-
# NamedSharidng
952+
# NamedSharding
953953
return data, ShardInfo(hlo_sharding, device_to_array_slices), nothing
954954
end
955955

@@ -997,7 +997,7 @@ function (sharding::HloSharding)(
997997
data = XLA.IFRT.AsyncArray(client, x, ifrt_sharding)
998998

999999
# XXX: Can we auto-pad this case too? Will think about it later, for now use
1000-
# NamedSharidng
1000+
# NamedSharding
10011001
return data, ShardInfo(sharding, device_to_array_slices), nothing
10021002
end
10031003

src/serialization/EnzymeJAX.jl

Lines changed: 118 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,12 @@ function export_to_enzymejax(
9595
# This returns compilation result with traced argument information
9696
argprefix = gensym("exportarg")
9797
mod, mlir_fn_res = Compiler.compile_mlir(
98-
f, args; argprefix, drop_unsupported_attributes=true
98+
f,
99+
args;
100+
argprefix,
101+
drop_unsupported_attributes=true,
102+
# to support older jax versions which don't support shardy
103+
shardy_passes=:to_mhlo_shardings,
99104
)
100105
hlo_code = string(mod)
101106

@@ -120,13 +125,21 @@ function export_to_enzymejax(
120125
# Store input data for the single NPZ file
121126
arr_key = "arr_$input_idx"
122127
input_data[arr_key] = _to_array(concrete_arg)
128+
129+
# Extract sharding information if available and if preserve_sharding is true
130+
sharding_info = nothing
131+
if preserve_sharding && _has_sharding_info(concrete_arg)
132+
sharding_info = _extract_sharding_info(concrete_arg)
133+
end
134+
123135
push!(
124136
input_info,
125137
(
126138
shape=size(concrete_arg),
127139
dtype=Reactant.unwrapped_eltype(concrete_arg),
128140
path="arg." * join(string.(path), "."),
129141
key=arr_key,
142+
sharding=sharding_info,
130143
),
131144
)
132145
input_idx += 1
@@ -138,13 +151,40 @@ function export_to_enzymejax(
138151

139152
# Generate Python script
140153
python_path = joinpath(output_dir, "$(function_name).py")
141-
_generate_python_script(python_path, function_name, mlir_path, input_path, input_info)
154+
_generate_python_script(
155+
python_path, function_name, mlir_path, input_path, input_info; preserve_sharding
156+
)
142157
return python_path
143158
end
144159

145160
_to_array(x::Reactant.ConcreteRArray) = Array(x)
146161
_to_array(x::Reactant.ConcreteRNumber{T}) where {T} = T(x)
147162

163+
_has_sharding_info(x::Reactant.ConcreteRArray) = Reactant.Sharding.is_sharded(x.sharding)
164+
_has_sharding_info(x) = false
165+
166+
function _extract_sharding_info(x::Reactant.ConcreteRArray)
167+
sharding = x.sharding
168+
if sharding isa Reactant.Sharding.ShardInfo
169+
inner_sharding = sharding.sharding
170+
if inner_sharding isa Reactant.Sharding.NamedSharding
171+
# TODO: we need to export is_closed, priority, and subaxes at some point
172+
return (;
173+
type="NamedSharding",
174+
mesh=inner_sharding.mesh,
175+
partition_spec=inner_sharding.partition_spec,
176+
)
177+
elseif inner_sharding isa Reactant.Sharding.Replicated
178+
return (; type="Replicated", mesh=inner_sharding.mesh)
179+
elseif inner_sharding isa Reactant.Sharding.NoSharding
180+
return (; type="NoSharding")
181+
else
182+
error("Unsupported sharding type: $(typeof(inner_sharding))")
183+
end
184+
end
185+
return (; type="NoSharding")
186+
end
187+
148188
function save_inputs_npz(
149189
output_path::String, inputs::Dict{String,<:Union{AbstractArray,Number}}
150190
)
@@ -162,7 +202,8 @@ function _generate_python_script(
162202
function_name::String,
163203
mlir_path::String,
164204
input_path::String,
165-
input_info::Vector,
205+
input_info::Vector;
206+
preserve_sharding::Bool=true,
166207
)
167208
# Get relative paths for the Python script
168209
output_dir = dirname(python_path)
@@ -191,6 +232,74 @@ function _generate_python_script(
191232
for (i, info) in enumerate(input_info)
192233
]
193234

235+
# Generate sharding annotations if available
236+
has_any_sharding =
237+
preserve_sharding && any(info.sharding !== nothing for info in input_info)
238+
239+
device_put_calls = String[]
240+
if has_any_sharding
241+
inserted_meshes = IdDict()
242+
counter = 0
243+
for (i, info) in enumerate(input_info)
244+
if info.sharding !== nothing
245+
if haskey(inserted_meshes, info.sharding.mesh)
246+
pymesh = inserted_meshes[info.sharding.mesh]
247+
else
248+
pymesh = "mesh$counter"
249+
counter += 1
250+
inserted_meshes[info.sharding.mesh] = pymesh
251+
axis_sizes = join(string.(reverse(info.sharding.mesh.axis_sizes)), ", ")
252+
mesh_axes = join(
253+
reverse(["'$(string(x))'" for x in info.sharding.mesh.axis_names]),
254+
", ",
255+
)
256+
257+
push!(
258+
device_put_calls,
259+
"$(pymesh) = jax.make_mesh(($(axis_sizes)), ($(mesh_axes)))",
260+
)
261+
end
262+
263+
push!(
264+
device_put_calls,
265+
"# Set up sharding for $(arg_names[i]): $(info.sharding.type)",
266+
)
267+
268+
# Create device_put call with NamedSharding
269+
if info.sharding.type == "NoSharding"
270+
device_put_calls_str = "$(arg_names[i]) = jnp.asarray($(arg_names[i]))"
271+
elseif info.sharding.type == "NamedSharding"
272+
pstrings = [
273+
if length(p) == 1
274+
p[1] isa Nothing ? "None" : "'$(string(p[1]))'"
275+
else
276+
join(string.(reverse(p)), ", ")
277+
end for p in info.sharding.partition_spec
278+
]
279+
partition_spec = join(reverse(pstrings), ", ")
280+
device_put_calls_str = "$(arg_names[i]) = jax.device_put($(arg_names[i]), jax.sharding.NamedSharding($(pymesh), P($(partition_spec))))"
281+
else
282+
error("Unsupported sharding type: $(info.sharding.type)")
283+
end
284+
push!(device_put_calls, device_put_calls_str)
285+
end
286+
end
287+
end
288+
289+
if has_any_sharding
290+
inputs_to_jax_arrays = """# Apply sharding to inputs using device_put and NamedSharding
291+
$(join(device_put_calls, "\n "))
292+
"""
293+
else
294+
convert_str_list = join(
295+
[" $(argname) = jnp.asarray($(argname))" for argname in arg_names], "\n"
296+
)
297+
inputs_to_jax_arrays = """
298+
# Convert inputs to jax arrays
299+
$(convert_str_list)
300+
"""
301+
end
302+
194303
load_inputs = ["npz_data['$(info.key)']" for info in input_info]
195304

196305
# Build the complete Python script
@@ -203,6 +312,7 @@ function _generate_python_script(
203312
204313
from enzyme_ad.jax import hlo_call
205314
import jax
315+
from jax.sharding import PartitionSpec as P
206316
import jax.numpy as jnp
207317
import numpy as np
208318
import os
@@ -245,11 +355,11 @@ function _generate_python_script(
245355
246356
if __name__ == \"__main__\":
247357
# Load the example inputs
248-
inputs = load_inputs()
249-
250-
# Run the function (with JIT compilation)
251-
print(\"Running $(function_name) with JIT compilation...\")
252-
result = run_$(function_name)(*inputs)
358+
($(arg_list),) = load_inputs()
359+
$(inputs_to_jax_arrays)
360+
# Run the function
361+
print(\"Running $(function_name)...\")
362+
result = run_$(function_name)($(arg_list))
253363
print(\"Result:\", result)
254364
"""
255365

0 commit comments

Comments
 (0)