Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions keras_nlp/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
)
from keras_nlp.src.models.falcon.falcon_preprocessor import FalconPreprocessor
from keras_nlp.src.models.falcon.falcon_tokenizer import FalconTokenizer
from keras_nlp.src.models.feature_pyramid_backbone import FeaturePyramidBackbone
from keras_nlp.src.models.gemma.gemma_backbone import GemmaBackbone
from keras_nlp.src.models.gemma.gemma_causal_lm import GemmaCausalLM
from keras_nlp.src.models.gemma.gemma_causal_lm_preprocessor import (
Expand Down Expand Up @@ -182,6 +183,9 @@
from keras_nlp.src.models.phi3.phi3_tokenizer import Phi3Tokenizer
from keras_nlp.src.models.preprocessor import Preprocessor
from keras_nlp.src.models.resnet.resnet_backbone import ResNetBackbone
from keras_nlp.src.models.resnet.resnet_feature_pyramid_backbone import (
ResNetFeaturePyramidBackbone,
)
from keras_nlp.src.models.resnet.resnet_image_classifier import (
ResNetImageClassifier,
)
Expand Down
3 changes: 3 additions & 0 deletions keras_nlp/src/models/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from keras_nlp.src.utils.preset_utils import save_metadata
from keras_nlp.src.utils.preset_utils import save_serialized_object
from keras_nlp.src.utils.python_utils import classproperty
from keras_nlp.src.utils.timm.convert import load_timm_backbone
from keras_nlp.src.utils.transformers.convert import load_transformers_backbone


Expand Down Expand Up @@ -204,6 +205,8 @@ class like `keras_nlp.models.Backbone.from_preset()`, or from

if format == "transformers":
return load_transformers_backbone(cls, preset, load_weights)
elif format == "timm":
return load_timm_backbone(cls, preset, load_weights, **kwargs)

preset_cls = check_config_class(preset)
if not issubclass(preset_cls, cls):
Expand Down
53 changes: 53 additions & 0 deletions keras_nlp/src/models/feature_pyramid_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keras

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.models.backbone import Backbone


@keras_nlp_export("keras_nlp.models.FeaturePyramidBackbone")
class FeaturePyramidBackbone(Backbone):
@property
def pyramid_outputs(self):
"""A dict for feature pyramid outputs.

The key is a string represents the name of the feature output and the
value is a `keras.KerasTensor`. A typical feature pyramid has multiple
levels corresponding to scales such as `["P2", "P3", "P4", "P5"]`. Scale
`Pn` represents a feature map `2^n` times smaller in width and height
than the inputs.
"""
return getattr(self, "_pyramid_outputs", {})

@pyramid_outputs.setter
def pyramid_outputs(self, value):
if not isinstance(value, dict):
raise TypeError(
"`pyramid_outputs` must be a dictionary. "
f"Received: value={value} of type {type(value)}"
)
for k, v in value.items():
if not isinstance(k, str):
raise TypeError(
"The key of `pyramid_outputs` must be a string. "
f"Received: key={k} of type {type(k)}"
)
if not isinstance(v, keras.KerasTensor):
raise TypeError(
"The value of `pyramid_outputs` must be a "
"`keras.KerasTensor`. "
f"Received: value={v} of type {type(v)}"
)
self._pyramid_outputs = value
Loading