@@ -95,7 +95,12 @@ function export_to_enzymejax(
9595 # This returns compilation result with traced argument information
9696 argprefix = gensym (" exportarg" )
9797 mod, mlir_fn_res = Compiler. compile_mlir (
98- f, args; argprefix, drop_unsupported_attributes= true
98+ f,
99+ args;
100+ argprefix,
101+ drop_unsupported_attributes= true ,
102+ # to support older jax versions which don't support shardy
103+ shardy_passes= :to_mhlo_shardings ,
99104 )
100105 hlo_code = string (mod)
101106
@@ -120,13 +125,21 @@ function export_to_enzymejax(
120125 # Store input data for the single NPZ file
121126 arr_key = " arr_$input_idx "
122127 input_data[arr_key] = _to_array (concrete_arg)
128+
129+ # Extract sharding information if available and if preserve_sharding is true
130+ sharding_info = nothing
131+ if preserve_sharding && _has_sharding_info (concrete_arg)
132+ sharding_info = _extract_sharding_info (concrete_arg)
133+ end
134+
123135 push! (
124136 input_info,
125137 (
126138 shape= size (concrete_arg),
127139 dtype= Reactant. unwrapped_eltype (concrete_arg),
128140 path= " arg." * join (string .(path), " ." ),
129141 key= arr_key,
142+ sharding= sharding_info,
130143 ),
131144 )
132145 input_idx += 1
@@ -138,13 +151,40 @@ function export_to_enzymejax(
138151
139152 # Generate Python script
140153 python_path = joinpath (output_dir, " $(function_name) .py" )
141- _generate_python_script (python_path, function_name, mlir_path, input_path, input_info)
154+ _generate_python_script (
155+ python_path, function_name, mlir_path, input_path, input_info; preserve_sharding
156+ )
142157 return python_path
143158end
144159
145160_to_array (x:: Reactant.ConcreteRArray ) = Array (x)
146161_to_array (x:: Reactant.ConcreteRNumber{T} ) where {T} = T (x)
147162
163+ _has_sharding_info (x:: Reactant.ConcreteRArray ) = Reactant. Sharding. is_sharded (x. sharding)
164+ _has_sharding_info (x) = false
165+
166+ function _extract_sharding_info (x:: Reactant.ConcreteRArray )
167+ sharding = x. sharding
168+ if sharding isa Reactant. Sharding. ShardInfo
169+ inner_sharding = sharding. sharding
170+ if inner_sharding isa Reactant. Sharding. NamedSharding
171+ # TODO : we need to export is_closed, priority, and subaxes at some point
172+ return (;
173+ type= " NamedSharding" ,
174+ mesh= inner_sharding. mesh,
175+ partition_spec= inner_sharding. partition_spec,
176+ )
177+ elseif inner_sharding isa Reactant. Sharding. Replicated
178+ return (; type= " Replicated" , mesh= inner_sharding. mesh)
179+ elseif inner_sharding isa Reactant. Sharding. NoSharding
180+ return (; type= " NoSharding" )
181+ else
182+ error (" Unsupported sharding type: $(typeof (inner_sharding)) " )
183+ end
184+ end
185+ return (; type= " NoSharding" )
186+ end
187+
148188function save_inputs_npz (
149189 output_path:: String , inputs:: Dict{String,<:Union{AbstractArray,Number}}
150190)
@@ -162,7 +202,8 @@ function _generate_python_script(
162202 function_name:: String ,
163203 mlir_path:: String ,
164204 input_path:: String ,
165- input_info:: Vector ,
205+ input_info:: Vector ;
206+ preserve_sharding:: Bool = true ,
166207)
167208 # Get relative paths for the Python script
168209 output_dir = dirname (python_path)
@@ -191,6 +232,74 @@ function _generate_python_script(
191232 for (i, info) in enumerate (input_info)
192233 ]
193234
235+ # Generate sharding annotations if available
236+ has_any_sharding =
237+ preserve_sharding && any (info. sharding != = nothing for info in input_info)
238+
239+ device_put_calls = String[]
240+ if has_any_sharding
241+ inserted_meshes = IdDict ()
242+ counter = 0
243+ for (i, info) in enumerate (input_info)
244+ if info. sharding != = nothing
245+ if haskey (inserted_meshes, info. sharding. mesh)
246+ pymesh = inserted_meshes[info. sharding. mesh]
247+ else
248+ pymesh = " mesh$counter "
249+ counter += 1
250+ inserted_meshes[info. sharding. mesh] = pymesh
251+ axis_sizes = join (string .(reverse (info. sharding. mesh. axis_sizes)), " , " )
252+ mesh_axes = join (
253+ reverse ([" '$(string (x)) '" for x in info. sharding. mesh. axis_names]),
254+ " , " ,
255+ )
256+
257+ push! (
258+ device_put_calls,
259+ " $(pymesh) = jax.make_mesh(($(axis_sizes) ), ($(mesh_axes) ))" ,
260+ )
261+ end
262+
263+ push! (
264+ device_put_calls,
265+ " # Set up sharding for $(arg_names[i]) : $(info. sharding. type) " ,
266+ )
267+
268+ # Create device_put call with NamedSharding
269+ if info. sharding. type == " NoSharding"
270+ device_put_calls_str = " $(arg_names[i]) = jnp.asarray($(arg_names[i]) )"
271+ elseif info. sharding. type == " NamedSharding"
272+ pstrings = [
273+ if length (p) == 1
274+ p[1 ] isa Nothing ? " None" : " '$(string (p[1 ])) '"
275+ else
276+ join (string .(reverse (p)), " , " )
277+ end for p in info. sharding. partition_spec
278+ ]
279+ partition_spec = join (reverse (pstrings), " , " )
280+ device_put_calls_str = " $(arg_names[i]) = jax.device_put($(arg_names[i]) , jax.sharding.NamedSharding($(pymesh) , P($(partition_spec) )))"
281+ else
282+ error (" Unsupported sharding type: $(info. sharding. type) " )
283+ end
284+ push! (device_put_calls, device_put_calls_str)
285+ end
286+ end
287+ end
288+
289+ if has_any_sharding
290+ inputs_to_jax_arrays = """ # Apply sharding to inputs using device_put and NamedSharding
291+ $(join (device_put_calls, " \n " ))
292+ """
293+ else
294+ convert_str_list = join (
295+ [" $(argname) = jnp.asarray($(argname) )" for argname in arg_names], " \n "
296+ )
297+ inputs_to_jax_arrays = """
298+ # Convert inputs to jax arrays
299+ $(convert_str_list)
300+ """
301+ end
302+
194303 load_inputs = [" npz_data['$(info. key) ']" for info in input_info]
195304
196305 # Build the complete Python script
@@ -203,6 +312,7 @@ function _generate_python_script(
203312
204313 from enzyme_ad.jax import hlo_call
205314 import jax
315+ from jax.sharding import PartitionSpec as P
206316 import jax.numpy as jnp
207317 import numpy as np
208318 import os
@@ -245,11 +355,11 @@ function _generate_python_script(
245355
246356 if __name__ == \" __main__\" :
247357 # Load the example inputs
248- inputs = load_inputs()
249-
250- # Run the function (with JIT compilation)
251- print(\" Running $(function_name) with JIT compilation ...\" )
252- result = run_$(function_name) (*inputs )
358+ ( $(arg_list) ,) = load_inputs()
359+ $(inputs_to_jax_arrays)
360+ # Run the function
361+ print(\" Running $(function_name) ...\" )
362+ result = run_$(function_name) ($(arg_list) )
253363 print(\" Result:\" , result)
254364 """
255365
0 commit comments