@@ -1120,7 +1120,6 @@ class Task(MetaflowObject):
1120
1120
def __init__ (self , * args , ** kwargs ):
1121
1121
super (Task , self ).__init__ (* args , ** kwargs )
1122
1122
# We want to cache metadata dictionary since it's used in many places
1123
- self ._metadata_dict = None
1124
1123
1125
1124
def _iter_filter (self , x ):
1126
1125
# exclude private data artifacts
@@ -1140,6 +1139,7 @@ def _get_metadata_query_vals(
1140
1139
cur_foreach_stack_len : int ,
1141
1140
steps : List [str ],
1142
1141
is_ancestor : bool ,
1142
+ metadata_dict : Dict [str , Any ],
1143
1143
):
1144
1144
"""
1145
1145
Returns the field name and field value to be used for querying metadata of successor or ancestor tasks.
@@ -1157,7 +1157,10 @@ def _get_metadata_query_vals(
1157
1157
ancestors and successors across multiple steps.
1158
1158
is_ancestor : bool
1159
1159
If we are querying for ancestor tasks, set this to True.
1160
+ metadata_dict : Dict[str, Any]
1161
+ Metadata dictionary of the current task, read once by the caller and passed in so it is not rebuilt on every lookup
1160
1162
"""
1163
+
1161
1164
# For each task, we also log additional metadata fields such as foreach-indices and foreach-indices-truncated
1162
1165
# which help us in querying ancestor and successor tasks.
1163
1166
# `foreach-indices`: contains the indices of the foreach stack at the time of task execution.
@@ -1181,7 +1184,7 @@ def _get_metadata_query_vals(
1181
1184
if query_foreach_stack_len == cur_foreach_stack_len :
1182
1185
# The successor or ancestor tasks belong to the same foreach stack level
1183
1186
field_name = "foreach-indices"
1184
- field_value = self . metadata_dict .get (field_name )
1187
+ field_value = metadata_dict .get (field_name )
1185
1188
elif is_ancestor :
1186
1189
if query_foreach_stack_len > cur_foreach_stack_len :
1187
1190
# This is a foreach join
@@ -1190,15 +1193,15 @@ def _get_metadata_query_vals(
1190
1193
# We will compare the foreach-indices-truncated value of ancestor task with the
1191
1194
# foreach-indices value of current task
1192
1195
field_name = "foreach-indices-truncated"
1193
- field_value = self . metadata_dict .get ("foreach-indices" )
1196
+ field_value = metadata_dict .get ("foreach-indices" )
1194
1197
else :
1195
1198
# This is a foreach split
1196
1199
# Current Task: foreach-indices = [0, 1, 2], foreach-indices-truncated = [0, 1]
1197
1200
# Ancestor Task: foreach-indices = [0, 1], foreach-indices-truncated = [0]
1198
1201
# We will compare the foreach-indices value of ancestor task with the
1199
1202
# foreach-indices-truncated value of current task
1200
1203
field_name = "foreach-indices"
1201
- field_value = self . metadata_dict .get ("foreach-indices-truncated" )
1204
+ field_value = metadata_dict .get ("foreach-indices-truncated" )
1202
1205
else :
1203
1206
if query_foreach_stack_len > cur_foreach_stack_len :
1204
1207
# This is a foreach split
@@ -1207,34 +1210,40 @@ def _get_metadata_query_vals(
1207
1210
# We will compare the foreach-indices value of current task with the
1208
1211
# foreach-indices-truncated value of successor tasks
1209
1212
field_name = "foreach-indices-truncated"
1210
- field_value = self . metadata_dict .get ("foreach-indices" )
1213
+ field_value = metadata_dict .get ("foreach-indices" )
1211
1214
else :
1212
1215
# This is a foreach join
1213
1216
# Current Task: foreach-indices = [0, 1, 2], foreach-indices-truncated = [0, 1]
1214
1217
# Successor Task: foreach-indices = [0, 1], foreach-indices-truncated = [0]
1215
1218
# We will compare the foreach-indices-truncated value of current task with the
1216
1219
# foreach-indices value of successor tasks
1217
1220
field_name = "foreach-indices"
1218
- field_value = self . metadata_dict .get ("foreach-indices-truncated" )
1221
+ field_value = metadata_dict .get ("foreach-indices-truncated" )
1219
1222
return field_name , field_value
1220
1223
1221
1224
def _get_related_tasks (self , is_ancestor : bool ) -> Dict [str , List [str ]]:
1222
1225
flow_id , run_id , _ , _ = self .path_components
1226
+ metadata_dict = self .metadata_dict
1223
1227
steps = (
1224
- self . metadata_dict .get ("previous-steps" )
1228
+ metadata_dict .get ("previous-steps" )
1225
1229
if is_ancestor
1226
- else self . metadata_dict .get ("successor-steps" )
1230
+ else metadata_dict .get ("successor-steps" )
1227
1231
)
1228
1232
1229
1233
if not steps :
1230
1234
return {}
1231
1235
1236
+ # Convert steps to a list if it's stored as a string in the metadata
1237
+ if is_stringish (steps ):
1238
+ steps = [steps ]
1239
+
1232
1240
field_name , field_value = self ._get_metadata_query_vals (
1233
1241
flow_id ,
1234
1242
run_id ,
1235
- len (self . metadata_dict .get ("foreach-indices" , [])),
1243
+ len (metadata_dict .get ("foreach-indices" , [])),
1236
1244
steps ,
1237
1245
is_ancestor = is_ancestor ,
1246
+ metadata_dict = metadata_dict ,
1238
1247
)
1239
1248
1240
1249
return {
@@ -1419,12 +1428,9 @@ def metadata_dict(self) -> Dict[str, str]:
1419
1428
Dictionary mapping metadata name with value
1420
1429
"""
1421
1430
# use the newest version of each key, hence sorting
1422
- if self ._metadata_dict is None :
1423
- self ._metadata_dict = {
1424
- m .name : m .value
1425
- for m in sorted (self .metadata , key = lambda m : m .created_at )
1426
- }
1427
- return self ._metadata_dict
1431
+ return {
1432
+ m .name : m .value for m in sorted (self .metadata , key = lambda m : m .created_at )
1433
+ }
1428
1434
1429
1435
@property
1430
1436
def index (self ) -> Optional [int ]:
0 commit comments