|
74 | 74 | "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"], |
75 | 75 | "PreTrainedModel": ["save_pretrained", "from_pretrained"], |
76 | 76 | "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"], |
| 77 | + "ProcessorMixin": ["save_pretrained", "from_pretrained"], |
| 78 | + "ImageProcessingMixin": ["save_pretrained", "from_pretrained"], |
77 | 79 | }, |
78 | 80 | } |
79 | 81 |
|
@@ -190,8 +192,8 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike]): |
190 | 192 | for library_name, library_classes in LOADABLE_CLASSES.items(): |
191 | 193 | library = importlib.import_module(library_name) |
192 | 194 | for base_class, save_load_methods in library_classes.items(): |
193 | | - class_candidate = getattr(library, base_class) |
194 | | - if issubclass(model_cls, class_candidate): |
| 195 | + class_candidate = getattr(library, base_class, None) |
| 196 | + if class_candidate is not None and issubclass(model_cls, class_candidate): |
195 | 197 | # if we found a suitable base class in LOADABLE_CLASSES then grab its save method |
196 | 198 | save_method_name = save_load_methods[0] |
197 | 199 | break |
@@ -543,11 +545,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P |
543 | 545 | library = importlib.import_module(library_name) |
544 | 546 | class_obj = getattr(library, class_name) |
545 | 547 | importable_classes = LOADABLE_CLASSES[library_name] |
546 | | - class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} |
| 548 | + class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} |
547 | 549 |
|
548 | 550 | expected_class_obj = None |
549 | 551 | for class_name, class_candidate in class_candidates.items(): |
550 | | - if issubclass(class_obj, class_candidate): |
| 552 | + if class_candidate is not None and issubclass(class_obj, class_candidate): |
551 | 553 | expected_class_obj = class_candidate |
552 | 554 |
|
553 | 555 | if not issubclass(passed_class_obj[name].__class__, expected_class_obj): |
@@ -577,14 +579,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P |
577 | 579 | else: |
578 | 580 | # else we just import it from the library. |
579 | 581 | library = importlib.import_module(library_name) |
| 582 | + |
580 | 583 | class_obj = getattr(library, class_name) |
581 | 584 | importable_classes = LOADABLE_CLASSES[library_name] |
582 | | - class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} |
| 585 | + class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} |
583 | 586 |
|
584 | 587 | if loaded_sub_model is None and sub_model_should_be_defined: |
585 | 588 | load_method_name = None |
586 | 589 | for class_name, class_candidate in class_candidates.items(): |
587 | | - if issubclass(class_obj, class_candidate): |
| 590 | + if class_candidate is not None and issubclass(class_obj, class_candidate): |
588 | 591 | load_method_name = importable_classes[class_name][1] |
589 | 592 |
|
590 | 593 | if load_method_name is None: |
|
0 commit comments