Skip to content

Commit 2261abe

Browse files
author
Orbax Authors
committed
update flag parsing logic and support --flag=value format
PiperOrigin-RevId: 809197843
1 parent 47f1859 commit 2261abe

File tree

2 files changed

+259
-10
lines changed

2 files changed

+259
-10
lines changed

model/orbax/experimental/model/core/python/compile_options_util.py

Lines changed: 129 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@
1616

1717
from collections.abc import Mapping, Sequence
1818
import logging
19+
from typing import Any
1920

21+
from absl import flags
22+
from google.protobuf import descriptor
2023
from google.protobuf import text_format
2124
import jax
2225
from orbax.experimental.model.core.protos import manifest_pb2
@@ -28,6 +31,12 @@
2831
from tensorflow.compiler.xla import xla_pb2
2932
from tensorflow.compiler.xla.pjrt.proto import compile_options_pb2
3033

34+
# A mapping between XLAflag names and protobuf field names.
35+
_XLA_FLAG_TO_FIELD_MAP = {
36+
field.name: field
37+
for field in tpu_comp_env_pb2.TpuCompilationEnvironment.DESCRIPTOR.fields
38+
}
39+
3140

3241
def generate_tpu_compilation_env(
3342
xla_flags: Sequence[str] | None = None,
@@ -39,16 +48,21 @@ def generate_tpu_compilation_env(
3948
tpu_compilation_env_str
4049
)
4150
# Override with supplied XLA flags if any is provided.
42-
if xla_flags is not None:
43-
env_override = tpu_comp_env_pb2.TpuCompilationEnvironment()
44-
xla_flags_str = '\n'.join(xla_flags)
45-
try:
46-
text_format.Parse(xla_flags_str, env_override)
47-
except text_format.ParseError as e:
48-
raise ValueError(
49-
f'Error parsing supplied XLA flag overrides {xla_flags_str}.'
50-
) from e
51-
env.MergeFrom(env_override)
51+
if xla_flags:
52+
is_proto_formatted = False if xla_flags[0].startswith('--') else True
53+
if is_proto_formatted:
54+
merge_proto_formatted_flags_compile_option(xla_flags, env)
55+
else:
56+
parsed_flags = {}
57+
for flag in xla_flags:
58+
if not flag.startswith('--'):
59+
raise ValueError(
60+
f"Flag {flag} does not start with '--'. All flags must be in the"
61+
' format of --flag_name=flag_value.'
62+
)
63+
flag_name, flag_value = flag[2:].split('=', 1)
64+
parsed_flags[flag_name] = flag_value
65+
merge_flags_into_compile_options(parsed_flags, env)
5266

5367
# Pack the TPU compilation environment into a compilation env proto.
5468
any_proto = any_pb2.Any()
@@ -109,8 +123,13 @@ def generate_xla_compile_options(
109123
"""Sets the XLA compilation options.
110124
111125
Args:
126+
native_serialization_platforms: A sequence of platform names that the
127+
compile options will be set for. If None, the compile options will be set
128+
for TPU only.
112129
xla_flags_per_platform: A mapping from platform name to a list of xla flags
113130
which will be used to override the default XLA compilation flags.
131+
jax_mesh: The JAX mesh used for sharding. If None, the compile options will
132+
be set for a default single-replica.
114133
115134
Returns:
116135
A `CompileOptionsProtoMap` containing the XLA compilation options per
@@ -156,3 +175,103 @@ def generate_xla_compile_options(
156175
generate_compilation_options(compile_environment, jax_mesh)
157176
)
158177
return compile_options_map
178+
179+
180+
def get_field_for_flag(flag_name: str) -> descriptor.FieldDescriptor:
181+
"""Gets the protobuf field descriptor for a given flag name."""
182+
if flag_name not in _XLA_FLAG_TO_FIELD_MAP:
183+
raise ValueError(
184+
f'No TpuCompilationEnvironment field matching flag {flag_name}'
185+
)
186+
return _XLA_FLAG_TO_FIELD_MAP[flag_name]
187+
188+
189+
def parse_flag_from_string(flag_name: str, value: str) -> Any:
190+
"""Parses a string value for a given flag and normalizes it for a proto field.
191+
192+
This is a Python implementation of the C++ function
193+
TpuCompEnvReflection::ParseFlagFromString.
194+
195+
Args:
196+
flag_name: The name of the flag.
197+
value: The string value of the flag.
198+
199+
Returns:
200+
The parsed and normalized value suitable for setting the corresponding field
201+
in `TpuCompilationEnvironment`. This can be a primitive type (int, bool,
202+
str), float, an enum's integer value, or a proto message instance.
203+
204+
Raises:
205+
ValueError: If the flag is not found, or if a proto message value cannot
206+
be parsed.
207+
"""
208+
try:
209+
flag_holder = flags.FLAGS[flag_name]
210+
except KeyError:
211+
raise ValueError(f'Flag not found: {flag_name}')
212+
213+
parsed_value = flag_holder.parser.parse(value)
214+
field = get_field_for_flag(flag_name)
215+
216+
if field.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
217+
message_instance = field.message_type._concrete_class()
218+
try:
219+
text_format.Parse(value, message_instance)
220+
return message_instance
221+
except text_format.ParseError as e:
222+
raise ValueError(
223+
f'Error parsing proto value for flag {flag_name}: {value}'
224+
) from e
225+
if field.type == descriptor.FieldDescriptor.TYPE_ENUM:
226+
if isinstance(parsed_value, str):
227+
return field.enum_type.values_by_name[parsed_value].number
228+
# If it's already an int, assume it's the correct value.
229+
return parsed_value
230+
if field.type in (
231+
descriptor.FieldDescriptor.TYPE_FLOAT,
232+
descriptor.FieldDescriptor.TYPE_DOUBLE,
233+
):
234+
return float(parsed_value)
235+
return parsed_value
236+
237+
238+
def merge_flags_into_compile_options(
239+
xla_flags: Mapping[str, str],
240+
env: tpu_comp_env_pb2.TpuCompilationEnvironment,
241+
):
242+
"""Merges flags into a TpuCompilationEnvironment proto.
243+
244+
Args:
245+
xla_flags: A mapping of XLA flag names to their string values. These flags
246+
will be parsed and merged into the `env` proto.
247+
env: The TpuCompilationEnvironment proto to merge the flags into. This
248+
proto will be modified in place.
249+
"""
250+
env_override = tpu_comp_env_pb2.TpuCompilationEnvironment()
251+
for flag_name, value in xla_flags.items():
252+
field_descriptor = get_field_for_flag(flag_name)
253+
parsed_value = parse_flag_from_string(flag_name, value)
254+
if field_descriptor.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
255+
# For message types, we need to copy the parsed message.
256+
getattr(env_override, field_descriptor.name).CopyFrom(parsed_value)
257+
else:
258+
# For scalar types, we can set the attribute directly.
259+
setattr(env_override, field_descriptor.name, parsed_value)
260+
env.MergeFrom(env_override)
261+
262+
263+
# TODO(b/438187387): remove this path and only allow the "--flag=value" format.
264+
def merge_proto_formatted_flags_compile_option(
265+
xla_flags: Sequence[str],
266+
env: tpu_comp_env_pb2.TpuCompilationEnvironment,
267+
):
268+
"""Merges flags into a proto."""
269+
env_override = tpu_comp_env_pb2.TpuCompilationEnvironment()
270+
xla_flags_str = '\n'.join(xla_flags)
271+
try:
272+
text_format.Parse(xla_flags_str, env_override)
273+
except text_format.ParseError as e:
274+
raise ValueError(
275+
f'Error parsing supplied XLA flag overrides {xla_flags_str}.'
276+
) from e
277+
env.MergeFrom(env_override)
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# Copyright 2025 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from absl.testing import absltest
16+
from absl.testing import parameterized
17+
from orbax.experimental.model.core.python import compile_options_util
18+
from .platforms.xla.service.jellyfish import tpu_compilation_environment_pb2 as tpu_comp_env_pb2
19+
20+
21+
class CompileOptionsUtilTest(parameterized.TestCase):
22+
23+
def test_parse_flag_from_string_bool(self):
24+
result = compile_options_util.parse_flag_from_string(
25+
'xla_sc_poison_buffers', 'false'
26+
)
27+
self.assertEqual(result, False)
28+
29+
def test_parse_flag_from_string_int(self):
30+
result = compile_options_util.parse_flag_from_string(
31+
'xla_jf_rematerialization_percent_shared_memory_limit', '99'
32+
)
33+
self.assertEqual(result, 99)
34+
35+
def test_parse_flag_from_string_float(self):
36+
result = compile_options_util.parse_flag_from_string(
37+
'xla_tpu_async_copy_bandwidth_scaling_factor', '0.19125064716453793'
38+
)
39+
self.assertEqual(result, 0.19125064716453793)
40+
41+
def test_parse_flag_from_string_string(self):
42+
result = compile_options_util.parse_flag_from_string(
43+
'xla_tpu_alternate_memory_benefit_scaling_factor_for_large_buffers',
44+
'NO_SCALE',
45+
)
46+
self.assertEqual(result, 'NO_SCALE')
47+
48+
def test_parse_flag_from_string_proto(self):
49+
compile_options_util.parse_flag_from_string(
50+
'xla_tpu_memory_bound_loop_optimizer_options', 'enabled:false'
51+
)
52+
53+
def test_parse_flag_from_string_enum(self):
54+
result = compile_options_util.parse_flag_from_string(
55+
'xla_memory_scheduler', 'DFS'
56+
)
57+
expected = tpu_comp_env_pb2.MemorySchedulerProto.DFS
58+
self.assertEqual(result, expected)
59+
60+
def test_parse_flag_from_string_nonexistent_flag(self):
61+
with self.assertRaisesRegex(ValueError, 'Flag not found: nonexistent_flag'):
62+
compile_options_util.parse_flag_from_string('nonexistent_flag', 'value')
63+
64+
@parameterized.named_parameters(
65+
(
66+
'dict_xla_flags',
67+
{
68+
'xla_jf_rematerialization_percent_shared_memory_limit': '99',
69+
'xla_tpu_allocate_scoped_vmem_at_same_offset': 'false',
70+
'xla_tpu_alternate_memory_benefit_scaling_factor_for_large_buffers': (
71+
'NO_SCALE'
72+
),
73+
'xla_tpu_memory_bound_loop_optimizer_options': 'enabled:false',
74+
'xla_tpu_async_copy_bandwidth_scaling_factor': (
75+
'0.19125064716453793'
76+
),
77+
},
78+
compile_options_util.merge_flags_into_compile_options,
79+
),
80+
(
81+
'proto_formatted_xla_flags',
82+
[
83+
'xla_jf_rematerialization_percent_shared_memory_limit: 99',
84+
'xla_tpu_allocate_scoped_vmem_at_same_offset: false',
85+
(
86+
'xla_tpu_alternate_memory_benefit_scaling_factor_for_large_buffers:'
87+
" 'NO_SCALE'"
88+
),
89+
'xla_tpu_memory_bound_loop_optimizer_options: {enabled:false}',
90+
(
91+
'xla_tpu_async_copy_bandwidth_scaling_factor:'
92+
' 0.19125064716453793'
93+
),
94+
],
95+
compile_options_util.merge_proto_formatted_flags_compile_option,
96+
),
97+
)
98+
def test_merge_flags_into_compile_options(self, xla_flags, merge_fn):
99+
# Initialize the environment with some values.
100+
env = tpu_comp_env_pb2.TpuCompilationEnvironment()
101+
# Values that should be overridden.
102+
env.xla_jf_rematerialization_percent_shared_memory_limit = 10
103+
env.xla_tpu_memory_bound_loop_optimizer_options.enabled = True
104+
# Value that should not be overridden.
105+
env.xla_tpu_wait_n_cycles_before_program_termination = 1234
106+
107+
# Merge the flags into the environment.
108+
merge_fn(xla_flags, env)
109+
self.assertEqual(
110+
env.xla_jf_rematerialization_percent_shared_memory_limit, 99
111+
)
112+
self.assertEqual(env.xla_tpu_allocate_scoped_vmem_at_same_offset, False)
113+
self.assertEqual(
114+
env.xla_tpu_alternate_memory_benefit_scaling_factor_for_large_buffers,
115+
'NO_SCALE',
116+
)
117+
self.assertEqual(
118+
env.xla_tpu_memory_bound_loop_optimizer_options.enabled, False
119+
)
120+
self.assertAlmostEqual(
121+
env.xla_tpu_async_copy_bandwidth_scaling_factor,
122+
0.19125064716453793,
123+
)
124+
125+
# Value that should not be overridden.
126+
self.assertEqual(env.xla_tpu_wait_n_cycles_before_program_termination, 1234)
127+
128+
129+
if __name__ == '__main__':
130+
absltest.main()

0 commit comments

Comments
 (0)