1616
1717from collections .abc import Mapping , Sequence
1818import logging
19+ from typing import Any
1920
21+ from absl import flags
22+ from google .protobuf import descriptor
2023from google .protobuf import text_format
2124import jax
2225from orbax .experimental .model .core .protos import manifest_pb2
2831from tensorflow .compiler .xla import xla_pb2
2932from tensorflow .compiler .xla .pjrt .proto import compile_options_pb2
3033
34+ # A mapping between XLAflag names and protobuf field names.
35+ _XLA_FLAG_TO_FIELD_MAP = {
36+ field .name : field
37+ for field in tpu_comp_env_pb2 .TpuCompilationEnvironment .DESCRIPTOR .fields
38+ }
39+
3140
3241def generate_tpu_compilation_env (
3342 xla_flags : Sequence [str ] | None = None ,
@@ -39,16 +48,21 @@ def generate_tpu_compilation_env(
3948 tpu_compilation_env_str
4049 )
4150 # Override with supplied XLA flags if any is provided.
42- if xla_flags is not None :
43- env_override = tpu_comp_env_pb2 .TpuCompilationEnvironment ()
44- xla_flags_str = '\n ' .join (xla_flags )
45- try :
46- text_format .Parse (xla_flags_str , env_override )
47- except text_format .ParseError as e :
48- raise ValueError (
49- f'Error parsing supplied XLA flag overrides { xla_flags_str } .'
50- ) from e
51- env .MergeFrom (env_override )
51+ if xla_flags :
52+ is_proto_formatted = False if xla_flags [0 ].startswith ('--' ) else True
53+ if is_proto_formatted :
54+ merge_proto_formatted_flags_compile_option (xla_flags , env )
55+ else :
56+ parsed_flags = {}
57+ for flag in xla_flags :
58+ if not flag .startswith ('--' ):
59+ raise ValueError (
60+ f"Flag { flag } does not start with '--'. All flags must be in the"
61+ ' format of --flag_name=flag_value.'
62+ )
63+ flag_name , flag_value = flag [2 :].split ('=' , 1 )
64+ parsed_flags [flag_name ] = flag_value
65+ merge_flags_into_compile_options (parsed_flags , env )
5266
5367 # Pack the TPU compilation environment into a compilation env proto.
5468 any_proto = any_pb2 .Any ()
@@ -109,8 +123,13 @@ def generate_xla_compile_options(
109123 """Sets the XLA compilation options.
110124
111125 Args:
126+ native_serialization_platforms: A sequence of platform names that the
127+ compile options will be set for. If None, the compile options will be set
128+ for TPU only.
112129 xla_flags_per_platform: A mapping from platform name to a list of xla flags
113130 which will be used to override the default XLA compilation flags.
131+ jax_mesh: The JAX mesh used for sharding. If None, the compile options will
132+ be set for a default single-replica.
114133
115134 Returns:
116135 A `CompileOptionsProtoMap` containing the XLA compilation options per
@@ -156,3 +175,103 @@ def generate_xla_compile_options(
156175 generate_compilation_options (compile_environment , jax_mesh )
157176 )
158177 return compile_options_map
178+
179+
180+ def get_field_for_flag (flag_name : str ) -> descriptor .FieldDescriptor :
181+ """Gets the protobuf field descriptor for a given flag name."""
182+ if flag_name not in _XLA_FLAG_TO_FIELD_MAP :
183+ raise ValueError (
184+ f'No TpuCompilationEnvironment field matching flag { flag_name } '
185+ )
186+ return _XLA_FLAG_TO_FIELD_MAP [flag_name ]
187+
188+
189+ def parse_flag_from_string (flag_name : str , value : str ) -> Any :
190+ """Parses a string value for a given flag and normalizes it for a proto field.
191+
192+ This is a Python implementation of the C++ function
193+ TpuCompEnvReflection::ParseFlagFromString.
194+
195+ Args:
196+ flag_name: The name of the flag.
197+ value: The string value of the flag.
198+
199+ Returns:
200+ The parsed and normalized value suitable for setting the corresponding field
201+ in `TpuCompilationEnvironment`. This can be a primitive type (int, bool,
202+ str), float, an enum's integer value, or a proto message instance.
203+
204+ Raises:
205+ ValueError: If the flag is not found, or if a proto message value cannot
206+ be parsed.
207+ """
208+ try :
209+ flag_holder = flags .FLAGS [flag_name ]
210+ except KeyError :
211+ raise ValueError (f'Flag not found: { flag_name } ' )
212+
213+ parsed_value = flag_holder .parser .parse (value )
214+ field = get_field_for_flag (flag_name )
215+
216+ if field .type == descriptor .FieldDescriptor .TYPE_MESSAGE :
217+ message_instance = field .message_type ._concrete_class ()
218+ try :
219+ text_format .Parse (value , message_instance )
220+ return message_instance
221+ except text_format .ParseError as e :
222+ raise ValueError (
223+ f'Error parsing proto value for flag { flag_name } : { value } '
224+ ) from e
225+ if field .type == descriptor .FieldDescriptor .TYPE_ENUM :
226+ if isinstance (parsed_value , str ):
227+ return field .enum_type .values_by_name [parsed_value ].number
228+ # If it's already an int, assume it's the correct value.
229+ return parsed_value
230+ if field .type in (
231+ descriptor .FieldDescriptor .TYPE_FLOAT ,
232+ descriptor .FieldDescriptor .TYPE_DOUBLE ,
233+ ):
234+ return float (parsed_value )
235+ return parsed_value
236+
237+
238+ def merge_flags_into_compile_options (
239+ xla_flags : Mapping [str , str ],
240+ env : tpu_comp_env_pb2 .TpuCompilationEnvironment ,
241+ ):
242+ """Merges flags into a TpuCompilationEnvironment proto.
243+
244+ Args:
245+ xla_flags: A mapping of XLA flag names to their string values. These flags
246+ will be parsed and merged into the `env` proto.
247+ env: The TpuCompilationEnvironment proto to merge the flags into. This
248+ proto will be modified in place.
249+ """
250+ env_override = tpu_comp_env_pb2 .TpuCompilationEnvironment ()
251+ for flag_name , value in xla_flags .items ():
252+ field_descriptor = get_field_for_flag (flag_name )
253+ parsed_value = parse_flag_from_string (flag_name , value )
254+ if field_descriptor .type == descriptor .FieldDescriptor .TYPE_MESSAGE :
255+ # For message types, we need to copy the parsed message.
256+ getattr (env_override , field_descriptor .name ).CopyFrom (parsed_value )
257+ else :
258+ # For scalar types, we can set the attribute directly.
259+ setattr (env_override , field_descriptor .name , parsed_value )
260+ env .MergeFrom (env_override )
261+
262+
263+ # TODO(b/438187387): remove this path and only allow the "--flag=value" format.
264+ def merge_proto_formatted_flags_compile_option (
265+ xla_flags : Sequence [str ],
266+ env : tpu_comp_env_pb2 .TpuCompilationEnvironment ,
267+ ):
268+ """Merges flags into a proto."""
269+ env_override = tpu_comp_env_pb2 .TpuCompilationEnvironment ()
270+ xla_flags_str = '\n ' .join (xla_flags )
271+ try :
272+ text_format .Parse (xla_flags_str , env_override )
273+ except text_format .ParseError as e :
274+ raise ValueError (
275+ f'Error parsing supplied XLA flag overrides { xla_flags_str } .'
276+ ) from e
277+ env .MergeFrom (env_override )
0 commit comments