Skip to content

Commit 62d2484

Browse files
committed
Update task.py
1 parent de830b1 commit 62d2484

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed

keras_hub/src/models/task.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,17 @@ def __setattr__(self, name, value):
7676
is_property = isinstance(getattr(type(self), name, None), property)
7777
is_unitialized = not hasattr(self, "_initialized")
7878
is_torch = keras.config.backend() == "torch"
79+
80+
# Prevent _DictWrapper creation for list attributes
81+
if isinstance(value, list) and hasattr(self, "_initialized"):
82+
# Use a trackable list wrapper instead of regular list
83+
try:
84+
from tensorflow.python.trackable.data_structures import ListWrapper
85+
value = ListWrapper(value)
86+
except ImportError:
87+
# Fallback: keep as regular list
88+
pass
89+
7990
if is_torch and (is_property or is_unitialized):
8091
return object.__setattr__(self, name, value)
8192
return super().__setattr__(name, value)
@@ -369,3 +380,37 @@ def add_layer(layer, info):
369380
print_fn=print_fn,
370381
**kwargs,
371382
)
383+
384+
def _trackable_children(self, save_type=None, **kwargs):
385+
"""Override to prevent _DictWrapper issues during TensorFlow export.
386+
387+
This method ensures clean trackable object traversal by avoiding
388+
problematic _DictWrapper objects that cause SavedModel export errors.
389+
"""
390+
try:
391+
children = super()._trackable_children(save_type, **kwargs)
392+
except Exception:
393+
# If parent fails, return minimal trackable children
394+
children = {}
395+
396+
# Import _DictWrapper safely
397+
try:
398+
from tensorflow.python.trackable.data_structures import _DictWrapper
399+
except ImportError:
400+
return children
401+
402+
clean_children = {}
403+
for name, child in children.items():
404+
try:
405+
# Skip _DictWrapper objects entirely to avoid introspection issues
406+
if hasattr(child, '__class__') and '_DictWrapper' in child.__class__.__name__:
407+
continue
408+
409+
# Test if child supports introspection safely
410+
_ = getattr(child, '__dict__', None)
411+
clean_children[name] = child
412+
except (TypeError, AttributeError):
413+
# Skip objects that cause introspection errors
414+
continue
415+
416+
return clean_children

0 commit comments

Comments
 (0)