Skip to content

Commit a8df33d

Browse files
committed
Fix bug in parsing steps due to different data formats across metadata services
1 parent b895b95 commit a8df33d

File tree

2 files changed

+25
-18
lines changed

2 files changed

+25
-18
lines changed

metaflow/client/core.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,7 +1120,6 @@ class Task(MetaflowObject):
11201120
def __init__(self, *args, **kwargs):
11211121
super(Task, self).__init__(*args, **kwargs)
11221122
# We want to cache metadata dictionary since it's used in many places
1123-
self._metadata_dict = None
11241123

11251124
def _iter_filter(self, x):
11261125
# exclude private data artifacts
@@ -1140,6 +1139,7 @@ def _get_metadata_query_vals(
11401139
cur_foreach_stack_len: int,
11411140
steps: List[str],
11421141
is_ancestor: bool,
1142+
metadata_dict: Dict[str, Any],
11431143
):
11441144
"""
11451145
Returns the field name and field value to be used for querying metadata of successor or ancestor tasks.
@@ -1157,7 +1157,10 @@ def _get_metadata_query_vals(
11571157
ancestors and successors across multiple steps.
11581158
is_ancestor : bool
11591159
If we are querying for ancestor tasks, set this to True.
1160+
metadata_dict : Dict[str, Any]
1161+
Cached metadata dictionary of the current task
11601162
"""
1163+
11611164
# For each task, we also log additional metadata fields such as foreach-indices and foreach-indices-truncated
11621165
# which help us in querying ancestor and successor tasks.
11631166
# `foreach-indices`: contains the indices of the foreach stack at the time of task execution.
@@ -1181,7 +1184,7 @@ def _get_metadata_query_vals(
11811184
if query_foreach_stack_len == cur_foreach_stack_len:
11821185
# The successor or ancestor tasks belong to the same foreach stack level
11831186
field_name = "foreach-indices"
1184-
field_value = self.metadata_dict.get(field_name)
1187+
field_value = metadata_dict.get(field_name)
11851188
elif is_ancestor:
11861189
if query_foreach_stack_len > cur_foreach_stack_len:
11871190
# This is a foreach join
@@ -1190,15 +1193,15 @@ def _get_metadata_query_vals(
11901193
# We will compare the foreach-indices-truncated value of ancestor task with the
11911194
# foreach-indices value of current task
11921195
field_name = "foreach-indices-truncated"
1193-
field_value = self.metadata_dict.get("foreach-indices")
1196+
field_value = metadata_dict.get("foreach-indices")
11941197
else:
11951198
# This is a foreach split
11961199
# Current Task: foreach-indices = [0, 1, 2], foreach-indices-truncated = [0, 1]
11971200
# Ancestor Task: foreach-indices = [0, 1], foreach-indices-truncated = [0]
11981201
# We will compare the foreach-indices value of ancestor task with the
11991202
# foreach-indices-truncated value of current task
12001203
field_name = "foreach-indices"
1201-
field_value = self.metadata_dict.get("foreach-indices-truncated")
1204+
field_value = metadata_dict.get("foreach-indices-truncated")
12021205
else:
12031206
if query_foreach_stack_len > cur_foreach_stack_len:
12041207
# This is a foreach split
@@ -1207,34 +1210,40 @@ def _get_metadata_query_vals(
12071210
# We will compare the foreach-indices value of current task with the
12081211
# foreach-indices-truncated value of successor tasks
12091212
field_name = "foreach-indices-truncated"
1210-
field_value = self.metadata_dict.get("foreach-indices")
1213+
field_value = metadata_dict.get("foreach-indices")
12111214
else:
12121215
# This is a foreach join
12131216
# Current Task: foreach-indices = [0, 1, 2], foreach-indices-truncated = [0, 1]
12141217
# Successor Task: foreach-indices = [0, 1], foreach-indices-truncated = [0]
12151218
# We will compare the foreach-indices-truncated value of current task with the
12161219
# foreach-indices value of successor tasks
12171220
field_name = "foreach-indices"
1218-
field_value = self.metadata_dict.get("foreach-indices-truncated")
1221+
field_value = metadata_dict.get("foreach-indices-truncated")
12191222
return field_name, field_value
12201223

12211224
def _get_related_tasks(self, is_ancestor: bool) -> Dict[str, List[str]]:
12221225
flow_id, run_id, _, _ = self.path_components
1226+
metadata_dict = self.metadata_dict
12231227
steps = (
1224-
self.metadata_dict.get("previous-steps")
1228+
metadata_dict.get("previous-steps")
12251229
if is_ancestor
1226-
else self.metadata_dict.get("successor-steps")
1230+
else metadata_dict.get("successor-steps")
12271231
)
12281232

12291233
if not steps:
12301234
return {}
12311235

1236+
# Convert steps to a list if it's stored as a string in the metadata
1237+
if is_stringish(steps):
1238+
steps = [steps]
1239+
12321240
field_name, field_value = self._get_metadata_query_vals(
12331241
flow_id,
12341242
run_id,
1235-
len(self.metadata_dict.get("foreach-indices", [])),
1243+
len(metadata_dict.get("foreach-indices", [])),
12361244
steps,
12371245
is_ancestor=is_ancestor,
1246+
metadata_dict=metadata_dict,
12381247
)
12391248

12401249
return {
@@ -1419,12 +1428,9 @@ def metadata_dict(self) -> Dict[str, str]:
14191428
Dictionary mapping metadata name with value
14201429
"""
14211430
# use the newest version of each key, hence sorting
1422-
if self._metadata_dict is None:
1423-
self._metadata_dict = {
1424-
m.name: m.value
1425-
for m in sorted(self.metadata, key=lambda m: m.created_at)
1426-
}
1427-
return self._metadata_dict
1431+
return {
1432+
m.name: m.value for m in sorted(self.metadata, key=lambda m: m.created_at)
1433+
}
14281434

14291435
@property
14301436
def index(self) -> Optional[int]:

metaflow/task.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,15 @@ def _dynamic_runtime_metadata(foreach_stack):
4646
foreach_step_names = [foreach_frame.step for foreach_frame in foreach_stack]
4747
return foreach_indices, foreach_indices_truncated, foreach_step_names
4848

49-
def _static_runtime_metadata(self, graph_info, step_name):
49+
@staticmethod
50+
def _static_runtime_metadata(graph_info, step_name):
5051
prev_steps = [
5152
node_name
5253
for node_name, attributes in graph_info["steps"].items()
5354
if step_name in attributes["next"]
5455
]
55-
succesor_steps = graph_info["steps"][step_name]["next"]
56-
return prev_steps, succesor_steps
56+
successor_steps = graph_info["steps"][step_name]["next"]
57+
return prev_steps, successor_steps
5758

5859
def __init__(
5960
self,

0 commit comments

Comments
 (0)