diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index 928becf3c0..fe3d181d7e 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -42,6 +42,9 @@ BASNetImageConverter, ) from keras_hub.src.models.clip.clip_image_converter import CLIPImageConverter +from keras_hub.src.models.cspnet.cspnet_image_converter import ( + CSPNetImageConverter, +) from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import ( DeepLabV3ImageConverter, ) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 248232312d..59682dc4e5 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -61,11 +61,12 @@ from keras_hub.src.models.clip.clip_text_encoder import CLIPTextEncoder from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer from keras_hub.src.models.clip.clip_vision_encoder import CLIPVisionEncoder -from keras_hub.src.models.csp_darknet.csp_darknet_backbone import ( - CSPDarkNetBackbone, +from keras_hub.src.models.cspnet.cspnet_backbone import CSPNetBackbone +from keras_hub.src.models.cspnet.cspnet_image_classifier import ( + CSPNetImageClassifier, ) -from keras_hub.src.models.csp_darknet.csp_darknet_image_classifier import ( - CSPDarkNetImageClassifier, +from keras_hub.src.models.cspnet.cspnet_image_classifier_preprocessor import ( + CSPNetImageClassifierPreprocessor, ) from keras_hub.src.models.deberta_v3.deberta_v3_backbone import ( DebertaV3Backbone, diff --git a/keras_hub/src/models/csp_darknet/__init__.py b/keras_hub/src/models/csp_darknet/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/keras_hub/src/models/csp_darknet/csp_darknet_backbone.py b/keras_hub/src/models/csp_darknet/csp_darknet_backbone.py deleted file mode 100644 index a6ca9cb2fc..0000000000 --- a/keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +++ /dev/null @@ -1,427 +0,0 @@ -import keras -from keras import layers - -from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone - - -@keras_hub_export("keras_hub.models.CSPDarkNetBackbone") -class CSPDarkNetBackbone(FeaturePyramidBackbone): - """This class represents Keras Backbone of CSPDarkNet model. - - This class implements a CSPDarkNet backbone as described in - [CSPNet: A New Backbone that can Enhance Learning Capability of CNN]( - https://arxiv.org/abs/1911.11929). - - Args: - stackwise_num_filters: A list of ints, filter size for each dark - level in the model. - stackwise_depth: A list of ints, the depth for each dark level in the - model. - block_type: str. One of `"basic_block"` or `"depthwise_block"`. - Use `"depthwise_block"` for depthwise conv block - `"basic_block"` for basic conv block. - Defaults to "basic_block". - image_shape: tuple. The input shape without the batch size. - Defaults to `(None, None, 3)`. 
- - Examples: - ```python - input_data = np.ones(shape=(8, 224, 224, 3)) - - # Pretrained backbone - model = keras_hub.models.CSPDarkNetBackbone.from_preset( - "csp_darknet_tiny_imagenet" - ) - model(input_data) - - # Randomly initialized backbone with a custom config - model = keras_hub.models.CSPDarkNetBackbone( - stackwise_num_filters=[128, 256, 512, 1024], - stackwise_depth=[3, 9, 9, 3], - ) - model(input_data) - ``` - """ - - def __init__( - self, - stackwise_num_filters, - stackwise_depth, - block_type="basic_block", - image_shape=(None, None, 3), - **kwargs, - ): - # === Functional Model === - channel_axis = ( - -1 if keras.config.image_data_format() == "channels_last" else 1 - ) - apply_ConvBlock = ( - apply_darknet_conv_block_depthwise - if block_type == "depthwise_block" - else apply_darknet_conv_block - ) - base_channels = stackwise_num_filters[0] // 2 - - image_input = layers.Input(shape=image_shape) - x = image_input # Intermediate result. - x = apply_focus(channel_axis, name="stem_focus")(x) - x = apply_darknet_conv_block( - base_channels, - channel_axis, - kernel_size=3, - strides=1, - name="stem_conv", - )(x) - - pyramid_outputs = {} - for index, (channels, depth) in enumerate( - zip(stackwise_num_filters, stackwise_depth) - ): - x = apply_ConvBlock( - channels, - channel_axis, - kernel_size=3, - strides=2, - name=f"dark{index + 2}_conv", - )(x) - - if index == len(stackwise_depth) - 1: - x = apply_spatial_pyramid_pooling_bottleneck( - channels, - channel_axis, - hidden_filters=channels // 2, - name=f"dark{index + 2}_spp", - )(x) - - x = apply_cross_stage_partial( - channels, - channel_axis, - num_bottlenecks=depth, - block_type="basic_block", - residual=(index != len(stackwise_depth) - 1), - name=f"dark{index + 2}_csp", - )(x) - pyramid_outputs[f"P{index + 2}"] = x - - super().__init__(inputs=image_input, outputs=x, **kwargs) - - # === Config === - self.stackwise_num_filters = stackwise_num_filters - self.stackwise_depth = stackwise_depth - self.block_type = block_type - self.image_shape = image_shape - self.pyramid_outputs = pyramid_outputs - - def get_config(self): - config = super().get_config() - config.update( - { - "stackwise_num_filters": self.stackwise_num_filters, - "stackwise_depth": self.stackwise_depth, - "block_type": self.block_type, - "image_shape": self.image_shape, - } - ) - return config - - -def apply_focus(channel_axis, name=None): - """A block used in CSPDarknet to focus information into channels of the - image. - - If the dimensions of a batch input is (batch_size, width, height, channels), - this layer converts the image into size (batch_size, width/2, height/2, - 4*channels). See [the original discussion on YoloV5 Focus Layer](https://github.com/ultralytics/yolov5/discussions/3181). - - Args: - name: the name for the lambda layer used in the block. - - Returns: - a function that takes an input Tensor representing a Focus layer. - """ - - def apply(x): - return layers.Concatenate(axis=channel_axis, name=name)( - [ - x[..., ::2, ::2, :], - x[..., 1::2, ::2, :], - x[..., ::2, 1::2, :], - x[..., 1::2, 1::2, :], - ], - ) - - return apply - - -def apply_darknet_conv_block( - filters, - channel_axis, - kernel_size, - strides, - use_bias=False, - activation="silu", - name=None, -): - """ - The basic conv block used in Darknet. Applies Conv2D followed by a - BatchNorm. - - Args: - filters: Integer, the dimensionality of the output space (i.e. the - number of output filters in the convolution). 
- kernel_size: An integer or tuple/list of 2 integers, specifying the - height and width of the 2D convolution window. Can be a single - integer to specify the same value both dimensions. - strides: An integer or tuple/list of 2 integers, specifying the strides - of the convolution along the height and width. Can be a single - integer to the same value both dimensions. - use_bias: Boolean, whether the layer uses a bias vector. - activation: the activation applied after the BatchNorm layer. One of - "silu", "relu" or "leaky_relu", defaults to "silu". - name: the prefix for the layer names used in the block. - """ - if name is None: - name = f"conv_block{keras.backend.get_uid('conv_block')}" - - def apply(inputs): - x = layers.Conv2D( - filters, - kernel_size, - strides, - padding="same", - data_format=keras.config.image_data_format(), - use_bias=use_bias, - name=name + "_conv", - )(inputs) - - x = layers.BatchNormalization(axis=channel_axis, name=name + "_bn")(x) - - if activation == "silu": - x = layers.Lambda(lambda x: keras.activations.silu(x))(x) - elif activation == "relu": - x = layers.ReLU()(x) - elif activation == "leaky_relu": - x = layers.LeakyReLU(0.1)(x) - - return x - - return apply - - -def apply_darknet_conv_block_depthwise( - filters, channel_axis, kernel_size, strides, activation="silu", name=None -): - """ - The depthwise conv block used in CSPDarknet. - - Args: - filters: Integer, the dimensionality of the output space (i.e. the - number of output filters in the final convolution). - kernel_size: An integer or tuple/list of 2 integers, specifying the - height and width of the 2D convolution window. Can be a single - integer to specify the same value both dimensions. - strides: An integer or tuple/list of 2 integers, specifying the strides - of the convolution along the height and width. Can be a single - integer to the same value both dimensions. - activation: the activation applied after the final layer. One of "silu", - "relu" or "leaky_relu", defaults to "silu". - name: the prefix for the layer names used in the block. - - """ - if name is None: - name = f"conv_block{keras.backend.get_uid('conv_block')}" - - def apply(inputs): - x = layers.DepthwiseConv2D( - kernel_size, - strides, - padding="same", - data_format=keras.config.image_data_format(), - use_bias=False, - )(inputs) - x = layers.BatchNormalization(axis=channel_axis)(x) - - if activation == "silu": - x = layers.Lambda(lambda x: keras.activations.swish(x))(x) - elif activation == "relu": - x = layers.ReLU()(x) - elif activation == "leaky_relu": - x = layers.LeakyReLU(0.1)(x) - - x = apply_darknet_conv_block( - filters, - channel_axis, - kernel_size=1, - strides=1, - activation=activation, - )(x) - - return x - - return apply - - -def apply_spatial_pyramid_pooling_bottleneck( - filters, - channel_axis, - hidden_filters=None, - kernel_sizes=(5, 9, 13), - activation="silu", - name=None, -): - """ - Spatial pyramid pooling layer used in YOLOv3-SPP - - Args: - filters: Integer, the dimensionality of the output spaces (i.e. the - number of output filters in used the blocks). - hidden_filters: Integer, the dimensionality of the intermediate - bottleneck space (i.e. the number of output filters in the - bottleneck convolution). If None, it will be equal to filters. - Defaults to None. - kernel_sizes: A list or tuple representing all the pool sizes used for - the pooling layers, defaults to (5, 9, 13). - activation: Activation for the conv layers, defaults to "silu". 
- name: the prefix for the layer names used in the block. - - Returns: - a function that takes an input Tensor representing an - SpatialPyramidPoolingBottleneck. - """ - if name is None: - name = f"spp{keras.backend.get_uid('spp')}" - - if hidden_filters is None: - hidden_filters = filters - - def apply(x): - x = apply_darknet_conv_block( - hidden_filters, - channel_axis, - kernel_size=1, - strides=1, - activation=activation, - name=f"{name}_conv1", - )(x) - x = [x] - - for kernel_size in kernel_sizes: - x.append( - layers.MaxPooling2D( - kernel_size, - strides=1, - padding="same", - data_format=keras.config.image_data_format(), - name=f"{name}_maxpool_{kernel_size}", - )(x[0]) - ) - - x = layers.Concatenate(axis=channel_axis, name=f"{name}_concat")(x) - x = apply_darknet_conv_block( - filters, - channel_axis, - kernel_size=1, - strides=1, - activation=activation, - name=f"{name}_conv2", - )(x) - - return x - - return apply - - -def apply_cross_stage_partial( - filters, - channel_axis, - num_bottlenecks, - residual=True, - block_type="basic_block", - activation="silu", - name=None, -): - """A block used in Cross Stage Partial Darknet. - - Args: - filters: Integer, the dimensionality of the output space (i.e. the - number of output filters in the final convolution). - num_bottlenecks: an integer representing the number of blocks added in - the layer bottleneck. - residual: a boolean representing whether the value tensor before the - bottleneck should be added to the output of the bottleneck as a - residual, defaults to True. - block_type: str. One of `"basic_block"` or `"depthwise_block"`. - Use `"depthwise_block"` for depthwise conv block - `"basic_block"` for basic conv block. - Defaults to "basic_block". - activation: the activation applied after the final layer. One of "silu", - "relu" or "leaky_relu", defaults to "silu". 
- """ - - if name is None: - uid = keras.backend.get_uid("cross_stage_partial") - name = f"cross_stage_partial_{uid}" - - def apply(inputs): - hidden_channels = filters // 2 - ConvBlock = ( - apply_darknet_conv_block_depthwise - if block_type == "basic_block" - else apply_darknet_conv_block - ) - - x1 = apply_darknet_conv_block( - hidden_channels, - channel_axis, - kernel_size=1, - strides=1, - activation=activation, - name=f"{name}_conv1", - )(inputs) - - x2 = apply_darknet_conv_block( - hidden_channels, - channel_axis, - kernel_size=1, - strides=1, - activation=activation, - name=f"{name}_conv2", - )(inputs) - - for i in range(num_bottlenecks): - residual_x = x1 - x1 = apply_darknet_conv_block( - hidden_channels, - channel_axis, - kernel_size=1, - strides=1, - activation=activation, - name=f"{name}_bottleneck_{i}_conv1", - )(x1) - x1 = ConvBlock( - hidden_channels, - channel_axis, - kernel_size=3, - strides=1, - activation=activation, - name=f"{name}_bottleneck_{i}_conv2", - )(x1) - if residual: - x1 = layers.Add(name=f"{name}_bottleneck_{i}_add")( - [residual_x, x1] - ) - - x = layers.Concatenate(name=f"{name}_concat")([x1, x2]) - x = apply_darknet_conv_block( - filters, - channel_axis, - kernel_size=1, - strides=1, - activation=activation, - name=f"{name}_conv3", - )(x) - - return x - - return apply diff --git a/keras_hub/src/models/csp_darknet/csp_darknet_backbone_test.py b/keras_hub/src/models/csp_darknet/csp_darknet_backbone_test.py deleted file mode 100644 index 35b601987d..0000000000 --- a/keras_hub/src/models/csp_darknet/csp_darknet_backbone_test.py +++ /dev/null @@ -1,41 +0,0 @@ -import numpy as np -import pytest - -from keras_hub.src.models.csp_darknet.csp_darknet_backbone import ( - CSPDarkNetBackbone, -) -from keras_hub.src.tests.test_case import TestCase - - -class CSPDarkNetBackboneTest(TestCase): - def setUp(self): - self.init_kwargs = { - "stackwise_num_filters": [2, 4, 6, 8], - "stackwise_depth": [1, 3, 3, 1], - "block_type": "basic_block", - "image_shape": (32, 32, 3), - } - self.input_size = 32 - self.input_data = np.ones( - (2, self.input_size, self.input_size, 3), dtype="float32" - ) - - def test_backbone_basics(self): - self.run_vision_backbone_test( - cls=CSPDarkNetBackbone, - init_kwargs=self.init_kwargs, - input_data=self.input_data, - expected_output_shape=(2, 1, 1, 8), - expected_pyramid_output_keys=["P2", "P3", "P4", "P5"], - expected_pyramid_image_sizes=[(8, 8), (4, 4), (2, 2), (1, 1)], - run_mixed_precision_check=False, - run_data_format_check=False, - ) - - @pytest.mark.large - def test_saved_model(self): - self.run_model_saving_test( - cls=CSPDarkNetBackbone, - init_kwargs=self.init_kwargs, - input_data=self.input_data, - ) diff --git a/keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py b/keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py deleted file mode 100644 index c0a5b1bb3e..0000000000 --- a/keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +++ /dev/null @@ -1,10 +0,0 @@ -from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.models.csp_darknet.csp_darknet_backbone import ( - CSPDarkNetBackbone, -) -from keras_hub.src.models.image_classifier import ImageClassifier - - -@keras_hub_export("keras_hub.models.CSPDarkNetImageClassifier") -class CSPDarkNetImageClassifier(ImageClassifier): - backbone_cls = CSPDarkNetBackbone diff --git a/keras_hub/src/models/csp_darknet/csp_darknet_image_classifier_test.py b/keras_hub/src/models/csp_darknet/csp_darknet_image_classifier_test.py deleted file mode 
100644 index 78eddcbbac..0000000000 --- a/keras_hub/src/models/csp_darknet/csp_darknet_image_classifier_test.py +++ /dev/null @@ -1,51 +0,0 @@ -import numpy as np -import pytest - -from keras_hub.src.models.csp_darknet.csp_darknet_backbone import ( - CSPDarkNetBackbone, -) -from keras_hub.src.models.csp_darknet.csp_darknet_image_classifier import ( - CSPDarkNetImageClassifier, -) -from keras_hub.src.tests.test_case import TestCase - - -class CSPDarkNetImageClassifierTest(TestCase): - def setUp(self): - # Setup model. - self.images = np.ones((2, 16, 16, 3), dtype="float32") - self.labels = [0, 3] - self.backbone = CSPDarkNetBackbone( - stackwise_num_filters=[2, 16, 16], - stackwise_depth=[1, 3, 3, 1], - block_type="basic_block", - image_shape=(16, 16, 3), - ) - self.init_kwargs = { - "backbone": self.backbone, - "num_classes": 2, - "activation": "softmax", - } - self.train_data = ( - self.images, - self.labels, - ) - - def test_classifier_basics(self): - pytest.skip( - reason="TODO: enable after preprocessor flow is figured out" - ) - self.run_task_test( - cls=CSPDarkNetImageClassifier, - init_kwargs=self.init_kwargs, - train_data=self.train_data, - expected_output_shape=(2, 2), - ) - - @pytest.mark.large - def test_saved_model(self): - self.run_model_saving_test( - cls=CSPDarkNetImageClassifier, - init_kwargs=self.init_kwargs, - input_data=self.images, - ) diff --git a/keras_hub/src/models/cspnet/__init__.py b/keras_hub/src/models/cspnet/__init__.py new file mode 100644 index 0000000000..2f870dd75f --- /dev/null +++ b/keras_hub/src/models/cspnet/__init__.py @@ -0,0 +1,5 @@ +from keras_hub.src.models.cspnet.cspnet_backbone import CSPNetBackbone +from keras_hub.src.models.cspnet.cspnet_presets import backbone_presets +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(backbone_presets, CSPNetBackbone) diff --git a/keras_hub/src/models/cspnet/cspnet_backbone.py b/keras_hub/src/models/cspnet/cspnet_backbone.py new file mode 100644 index 0000000000..b66425feba --- /dev/null +++ b/keras_hub/src/models/cspnet/cspnet_backbone.py @@ -0,0 +1,1279 @@ +import keras +from keras import layers +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone +from keras_hub.src.utils.keras_utils import standardize_data_format + + +@keras_hub_export("keras_hub.models.CSPNetBackbone") +class CSPNetBackbone(FeaturePyramidBackbone): + """This class represents the Keras Backbone of the CSPNet model. + + This class implements a CSPNet backbone as described in + [CSPNet: A New Backbone that can Enhance Learning Capability of CNN]( + https://arxiv.org/abs/1911.11929). + + Args: + stem_filters: int or list of ints, filter size for the stem. + stem_kernel_size: int or tuple/list of 2 integers, kernel size for the + stem. + stem_strides: int or tuple/list of 2 integers, stride length of the + convolution for the stem. + stackwise_num_filters: A list of ints, the number of filters for each + block level in the model. + stackwise_strides: int or tuple/list of ints, strides for each block + level in the model. + stackwise_depth: A list of ints, representing the depth + (number of blocks) for each block level in the model. + block_type: str. One of `"bottleneck_block"`, `"dark_block"`, or + `"edge_block"`. Use `"dark_block"` for DarkNet blocks, + `"edge_block"` for EdgeResidual / Fused-MBConv blocks, and + `"bottleneck_block"` for 1x1-3x3-1x1 bottleneck blocks. + groups: int, specifying the number of groups into which the input is + split along the channel axis.
Defaults to `1`. + stage_type: str. One of `"csp"`, `"dark"`, or `"cs3"`. Use `"dark"` for + DarkNet stages, `"csp"` for Cross Stage, and `"cs3"` for Cross Stage + with only one transition conv. Defaults to `None`, which is treated + as `"cs3"`. + activation: str. Activation function for the model. + output_strides: int, output stride length of the backbone model. Must be + one of `(8, 16, 32)`. Defaults to `32`. + bottle_ratio: float or tuple/list of floats. The dimensionality of the + intermediate bottleneck space (i.e., the number of output filters in + the bottleneck convolution), calculated as + `(filters * bottle_ratio)` and applied to: + - the first convolution of `"dark_block"` and `"edge_block"` + - the first two convolutions of `"bottleneck_block"` + of each stage. Defaults to `1.0`. + block_ratio: float or tuple/list of floats. Filter size for each block, + calculated as `(stackwise_num_filters * block_ratio)` for each + stage. Defaults to `1.0`. + expand_ratio: float or tuple/list of floats. Channel expansion ratio of + the expansion convolution in `"csp"` and `"cs3"` stages, calculated + as `(filters * expand_ratio)` per stage. Defaults to `1.0`. + stem_padding: str, padding value for the stem, either `"valid"` or + `"same"`. Defaults to `"valid"`. + stem_pooling: str, pooling applied after the stem convolutions; + currently only `"max"` is supported. Defaults to `None`. + avg_down: bool, if `True`, `AveragePooling2D` is applied at the + beginning of each stage when `strides == 2`. Defaults to `False`. + down_growth: bool, grow downsample channels to output channels. Applies + to Cross Stage only. Defaults to `False`. + cross_linear: bool, if `True`, activation will not be applied after the + expansion convolution. Applies to Cross Stage only. Defaults to + `False`. + data_format: `None` or str. If specified, either `"channels_last"` or + `"channels_first"`. The ordering of the dimensions in the inputs. + `"channels_last"` corresponds to inputs with shape + `(batch_size, height, width, channels)` while `"channels_first"` + corresponds to inputs with shape + `(batch_size, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + image_shape: tuple. The input shape without the batch size. + Defaults to `(None, None, 3)`. + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the model's computations and weights.
+ + Examples: + ```python + input_data = np.ones(shape=(8, 224, 224, 3)) + + # Pretrained backbone + model = keras_hub.models.CSPNetBackbone.from_preset( + "csp_darknet_53_ra_imagenet" + ) + model(input_data) + + # Randomly initialized backbone with a custom config + model = keras_hub.models.CSPNetBackbone( + stem_filters=32, + stem_kernel_size=3, + stem_strides=1, + stackwise_depth=[1, 2, 4], + stackwise_strides=[1, 2, 2], + stackwise_num_filters=[32, 64, 128], + block_type="dark_block", + stage_type="csp", + down_growth=True, + bottle_ratio=(0.5,) + (1.0,), + block_ratio=(1.0,) + (0.5,), + expand_ratio=(2.0,) + (1.0,), + ) + model(input_data) + ``` + """ + + def __init__( + self, + stem_filters, + stem_kernel_size, + stem_strides, + stackwise_depth, + stackwise_strides, + stackwise_num_filters, + block_type, + groups=1, + stage_type=None, + activation="leaky_relu", + output_strides=32, + bottle_ratio=[1.0], + block_ratio=[1.0], + expand_ratio=[1.0], + stem_padding="valid", + stem_pooling=None, + avg_down=False, + down_growth=False, + cross_linear=False, + image_shape=(None, None, 3), + data_format=None, + dtype=None, + **kwargs, + ): + if block_type not in ( + "dark_block", + "edge_block", + "bottleneck_block", + ): + raise ValueError( + '`block_type` must be either `"dark_block"`, ' + '`"edge_block"`, or `"bottleneck_block"`. ' + f"Received block_type={block_type}." + ) + + if stage_type is None: + stage_type = "cs3" + if stage_type not in ( + "dark", + "csp", + "cs3", + ): + raise ValueError( + '`stage_type` must be either `"dark"`, `"csp"`, or `"cs3"`. ' + f"Received stage_type={stage_type}." + ) + data_format = standardize_data_format(data_format) + channel_axis = -1 if data_format == "channels_last" else 1 + + # === Functional Model === + image_input = layers.Input(shape=image_shape) + x = image_input # Intermediate result. + stem, stem_feat_info = create_csp_stem( + data_format=data_format, + channel_axis=channel_axis, + filters=stem_filters, + kernel_size=stem_kernel_size, + strides=stem_strides, + pooling=stem_pooling, + padding=stem_padding, + activation=activation, + dtype=dtype, + )(x) + + stages, pyramid_outputs = create_csp_stages( + inputs=stem, + filters=stackwise_num_filters, + data_format=data_format, + channel_axis=channel_axis, + stackwise_depth=stackwise_depth, + reduction=stem_feat_info, + groups=groups, + block_ratio=block_ratio, + bottle_ratio=bottle_ratio, + expand_ratio=expand_ratio, + strides=stackwise_strides, + avg_down=avg_down, + down_growth=down_growth, + cross_linear=cross_linear, + activation=activation, + output_strides=output_strides, + stage_type=stage_type, + block_type=block_type, + dtype=dtype, + name="csp_stage", + ) + + super().__init__( + inputs=image_input, outputs=stages, dtype=dtype, **kwargs + ) + + # === Config === + self.stem_filters = stem_filters + self.stem_kernel_size = stem_kernel_size + self.stem_strides = stem_strides + self.stackwise_depth = stackwise_depth + self.stackwise_strides = stackwise_strides + self.stackwise_num_filters = stackwise_num_filters + self.stage_type = stage_type + self.block_type = block_type + self.output_strides = output_strides + self.groups = groups + self.activation = activation + self.bottle_ratio = bottle_ratio + self.block_ratio = block_ratio + self.expand_ratio = expand_ratio + self.stem_padding = stem_padding + self.stem_pooling = stem_pooling + self.avg_down = avg_down + self.down_growth = down_growth + self.cross_linear = cross_linear + self.image_shape = image_shape + self.data_format = data_format + self.pyramid_outputs = pyramid_outputs + + def get_config(self): + config = super().get_config() + config.update( + { + "stem_filters": self.stem_filters, + "stem_kernel_size":
self.stem_kernel_size, + "stem_strides": self.stem_strides, + "stackwise_depth": self.stackwise_depth, + "stackwise_strides": self.stackwise_strides, + "stackwise_num_filters": self.stackwise_num_filters, + "stage_type": self.stage_type, + "block_type": self.block_type, + "output_strides": self.output_strides, + "groups": self.groups, + "activation": self.activation, + "bottle_ratio": self.bottle_ratio, + "block_ratio": self.block_ratio, + "expand_ratio": self.expand_ratio, + "stem_padding": self.stem_padding, + "stem_pooling": self.stem_pooling, + "avg_down": self.avg_down, + "down_growth": self.down_growth, + "cross_linear": self.cross_linear, + "image_shape": self.image_shape, + "data_format": self.data_format, + } + ) + return config + + +def bottleneck_block( + filters, + channel_axis, + data_format, + bottle_ratio, + dilation=1, + groups=1, + activation="relu", + dtype=None, + name=None, +): + """ + BottleNeck block. + + Args: + filters: Integer, the dimensionality of the output space (i.e. the + number of output filters used in the blocks). + channel_axis: int, the axis of the channel dimension (`-1` for + `"channels_last"`, `1` for `"channels_first"`). + data_format: `None` or str. The ordering of the dimensions in the + inputs. Can be `"channels_last"` + (`(batch_size, height, width, channels)`) or `"channels_first"` + (`(batch_size, channels, height, width)`). + bottle_ratio: float, ratio for the bottleneck filters: + `hidden_filters = filters * bottle_ratio`. + dilation: int or tuple/list of 2 integers, specifying the dilation rate + to use for dilated convolution, defaults to `1`. + groups: A positive int specifying the number of groups in which the + input is split along the channel axis. + activation: Activation for the conv layers, defaults to "relu". + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the model's computations and weights. + name: str. A prefix for the layer names used in the block. + + Returns: + A function that applies the block to an input tensor and returns the + output tensor.
+ """ + if name is None: + name = f"bottleneck{keras.backend.get_uid('bottleneck')}" + + hidden_filters = int(round(filters * bottle_ratio)) + + def apply(x): + shortcut = x + x = layers.Conv2D( + filters=hidden_filters, + kernel_size=1, + use_bias=False, + data_format=data_format, + dtype=dtype, + name=f"{name}_bottleneck_block_conv_1", + )(x) + x = layers.BatchNormalization( + epsilon=1e-05, + axis=channel_axis, + dtype=dtype, + name=f"{name}_bottleneck_block_bn_1", + )(x) + if activation == "leaky_relu": + x = layers.LeakyReLU( + negative_slope=0.01, + dtype=dtype, + name=f"{name}_bottleneck_block_activation_1", + )(x) + else: + x = layers.Activation( + activation, + dtype=dtype, + name=f"{name}_bottleneck_block_activation_1", + )(x) + + x = layers.Conv2D( + filters=hidden_filters, + kernel_size=3, + dilation_rate=dilation, + groups=groups, + padding="same", + use_bias=False, + data_format=data_format, + dtype=dtype, + name=f"{name}_bottleneck_block_conv_2", + )(x) + x = layers.BatchNormalization( + epsilon=1e-05, + axis=channel_axis, + dtype=dtype, + name=f"{name}_bottleneck_block_bn_2", + )(x) + if activation == "leaky_relu": + x = layers.LeakyReLU( + negative_slope=0.01, + dtype=dtype, + name=f"{name}_bottleneck_block_activation_2", + )(x) + else: + x = layers.Activation( + activation, + dtype=dtype, + name=f"{name}_bottleneck_block_activation_2", + )(x) + + x = layers.Conv2D( + filters=filters, + kernel_size=1, + use_bias=False, + data_format=data_format, + dtype=dtype, + name=f"{name}_bottleneck_block_conv_3", + )(x) + x = layers.BatchNormalization( + epsilon=1e-05, + axis=channel_axis, + dtype=dtype, + name=f"{name}_bottleneck_block_bn_3", + )(x) + if activation == "leaky_relu": + x = layers.LeakyReLU( + negative_slope=0.01, + dtype=dtype, + name=f"{name}_bottleneck_block_activation_3", + )(x) + else: + x = layers.Activation( + activation, + dtype=dtype, + name=f"{name}_bottleneck_block_activation_3", + )(x) + + x = layers.add( + [x, shortcut], dtype=dtype, name=f"{name}_bottleneck_block_add" + ) + if activation == "leaky_relu": + x = layers.LeakyReLU( + negative_slope=0.01, + dtype=dtype, + name=f"{name}_bottleneck_block_activation_4", + )(x) + else: + x = layers.Activation( + activation, + dtype=dtype, + name=f"{name}_bottleneck_block_activation_4", + )(x) + return x + + return apply + + +def dark_block( + filters, + data_format, + channel_axis, + dilation, + bottle_ratio, + groups, + activation, + dtype=None, + name=None, +): + """ + DarkNet block. + + Args: + filters: Integer, the dimensionality of the output spaces (i.e. the + number of output filters in used the blocks). + data_format: `None` or str. the ordering of the dimensions in the + inputs. Can be `"channels_last"` + (`(batch_size, height, width, channels)`) or`"channels_first"` + (`(batch_size, channels, height, width)`). + bottle_ratio: float, ratio for darknet filters. Number of darknet + `filters = filters * bottle_ratio`. + dilation: int or tuple/list of 2 integers, specifying the dilation rate + to use for dilated convolution, defaults to `1`. + groups: A positive int specifying the number of groups in which the + input is split along the channel axis + activation: Activation for the conv layers, defaults to "relu". + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the models computations and weights. + name: str. A prefix for the layer names used in the block. + + Returns: + Output tensor of block. 
+ """ + if name is None: + name = f"dark{keras.backend.get_uid('dark')}" + + hidden_filters = int(round(filters * bottle_ratio)) + + def apply(x): + shortcut = x + x = layers.Conv2D( + filters=hidden_filters, + kernel_size=1, + use_bias=False, + data_format=data_format, + dtype=dtype, + name=f"{name}_dark_block_conv_1", + )(x) + x = layers.BatchNormalization( + epsilon=1e-05, + axis=channel_axis, + dtype=dtype, + name=f"{name}_dark_block_bn_1", + )(x) + if activation == "leaky_relu": + x = layers.LeakyReLU( + negative_slope=0.01, + dtype=dtype, + name=f"{name}_dark_block_activation_1", + )(x) + else: + x = layers.Activation( + activation, + dtype=dtype, + name=f"{name}_dark_block_activation_1", + )(x) + + x = layers.Conv2D( + filters=filters, + kernel_size=3, + dilation_rate=dilation, + groups=groups, + padding="same", + use_bias=False, + data_format=data_format, + dtype=dtype, + name=f"{name}_dark_block_conv_2", + )(x) + x = layers.BatchNormalization( + epsilon=1e-05, + axis=channel_axis, + dtype=dtype, + name=f"{name}_dark_block_bn_2", + )(x) + if activation == "leaky_relu": + x = layers.LeakyReLU( + negative_slope=0.01, + dtype=dtype, + name=f"{name}_dark_block_activation_2", + )(x) + else: + x = layers.Activation( + activation, + dtype=dtype, + name=f"{name}_dark_block_activation_2", + )(x) + + x = layers.add( + [x, shortcut], dtype=dtype, name=f"{name}_dark_block_add" + ) + return x + + return apply + + +def edge_block( + filters, + data_format, + channel_axis, + dilation=1, + bottle_ratio=0.5, + groups=1, + activation="relu", + dtype=None, + name=None, +): + """ + EdgeResidual / Fused-MBConv blocks. + + Args: + filters: Integer, the dimensionality of the output spaces (i.e. the + number of output filters in used the blocks). + data_format: `None` or str. the ordering of the dimensions in the + inputs. Can be `"channels_last"` + (`(batch_size, height, width, channels)`) or`"channels_first"` + (`(batch_size, channels, height, width)`). + bottle_ratio: float, ratio for edge_block filters. Number of edge_block + `filters = filters * bottle_ratio`. + dilation: int or tuple/list of 2 integers, specifying the dilation rate + to use for dilated convolution, defaults to `1`. + groups: A positive int specifying the number of groups in which the + input is split along the channel axis + activation: Activation for the conv layers, defaults to "relu". + dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype + to use for the models computations and weights. + name: str. A prefix for the layer names used in the block. + + Returns: + Output tensor of block. 
+ """ + if name is None: + name = f"edge{keras.backend.get_uid('edge')}" + + hidden_filters = int(round(filters * bottle_ratio)) + + def apply(x): + shortcut = x + x = layers.Conv2D( + filters=hidden_filters, + kernel_size=3, + use_bias=False, + dilation_rate=dilation, + groups=groups, + padding="same", + data_format=data_format, + dtype=dtype, + name=f"{name}_edge_block_conv_1", + )(x) + x = layers.BatchNormalization( + epsilon=1e-05, + axis=channel_axis, + dtype=dtype, + name=f"{name}_edge_block_bn_1", + )(x) + if activation == "leaky_relu": + x = layers.LeakyReLU( + negative_slope=0.01, + dtype=dtype, + name=f"{name}_edge_block_activation_1", + )(x) + else: + x = layers.Activation( + activation, + dtype=dtype, + name=f"{name}_edge_block_activation_1", + )(x) + + x = layers.Conv2D( + filters=filters, + kernel_size=1, + use_bias=False, + data_format=data_format, + dtype=dtype, + name=f"{name}_edge_block_conv_2", + )(x) + x = layers.BatchNormalization( + epsilon=1e-05, + axis=channel_axis, + dtype=dtype, + name=f"{name}_edge_block_bn_2", + )(x) + if activation == "leaky_relu": + x = layers.LeakyReLU( + negative_slope=0.01, + dtype=dtype, + name=f"{name}_edge_block_activation_2", + )(x) + else: + x = layers.Activation( + activation, + dtype=dtype, + name=f"{name}_edge_block_activation_2", + )(x) + + x = layers.add( + [x, shortcut], dtype=dtype, name=f"{name}_edge_block_add" + ) + return x + + return apply + + +def cross_stage( + filters, + strides, + dilation, + depth, + data_format, + channel_axis, + block_ratio=1.0, + bottle_ratio=1.0, + expand_ratio=1.0, + groups=1, + first_dilation=None, + avg_down=False, + activation="relu", + down_growth=False, + cross_linear=False, + block_fn=bottleneck_block, + dtype=None, + name=None, +): + """ " + Cross Stage. 
+ """ + if name is None: + name = f"cross_stage_{keras.backend.get_uid('cross_stage')}" + + first_dilation = first_dilation or dilation + + def apply(x): + prev_filters = keras.ops.shape(x)[channel_axis] + down_chs = filters if down_growth else prev_filters + expand_chs = int(round(filters * expand_ratio)) + block_channels = int(round(filters * block_ratio)) + + if strides != 1 or first_dilation != dilation: + if avg_down: + if strides == 2: + x = layers.AveragePooling2D( + 2, dtype=dtype, name=f"{name}_csp_avg_pool" + )(x) + x = layers.Conv2D( + filters=filters, + kernel_size=1, + strides=1, + use_bias=False, + groups=groups, + data_format=data_format, + dtype=dtype, + name=f"{name}_csp_conv_down_1", + )(x) + x = layers.BatchNormalization( + epsilon=1e-05, + axis=channel_axis, + dtype=dtype, + name=f"{name}_csp_bn_1", + )(x) + if activation == "leaky_relu": + x = layers.LeakyReLU( + negative_slope=0.01, + dtype=dtype, + name=f"{name}_csp_activation_1", + )(x) + else: + x = layers.Activation( + activation, + dtype=dtype, + name=f"{name}_csp_activation_1", + )(x) + else: + x = layers.Conv2D( + filters=down_chs, + kernel_size=3, + strides=strides, + dilation_rate=first_dilation, + use_bias=False, + groups=groups, + data_format=data_format, + dtype=dtype, + name=f"{name}_csp_conv_down_1", + )(x) + x = layers.BatchNormalization( + epsilon=1e-05, + axis=channel_axis, + dtype=dtype, + name=f"{name}_csp_bn_1", + )(x) + if activation == "leaky_relu": + x = layers.LeakyReLU( + negative_slope=0.01, + dtype=dtype, + name=f"{name}_csp_activation_1", + )(x) + else: + x = layers.Activation( + activation, + dtype=dtype, + name=f"{name}_csp_activation_1", + )(x) + + x = layers.Conv2D( + filters=expand_chs, + kernel_size=1, + use_bias=False, + data_format=data_format, + dtype=dtype, + name=f"{name}_csp_conv_exp", + )(x) + x = layers.BatchNormalization( + epsilon=1e-05, + axis=channel_axis, + dtype=dtype, + name=f"{name}_csp_bn_2", + )(x) + if not cross_linear: + if activation == "leaky_relu": + x = layers.LeakyReLU( + negative_slope=0.01, + dtype=dtype, + name=f"{name}_csp_activation_2", + )(x) + else: + x = layers.Activation( + activation, + dtype=dtype, + name=f"{name}_csp_activation_2", + )(x) + prev_filters = keras.ops.shape(x)[channel_axis] + xs, xb = ops.split( + x, + indices_or_sections=prev_filters // (expand_chs // 2), + axis=channel_axis, + ) + + for i in range(depth): + xb = block_fn( + filters=block_channels, + dilation=dilation, + bottle_ratio=bottle_ratio, + groups=groups, + activation=activation, + data_format=data_format, + channel_axis=channel_axis, + dtype=dtype, + name=f"{name}_block_{i}", + )(xb) + + xb = layers.Conv2D( + filters=expand_chs // 2, + kernel_size=1, + use_bias=False, + data_format=data_format, + dtype=dtype, + name=f"{name}_csp_conv_transition_b", + )(xb) + xb = layers.BatchNormalization( + epsilon=1e-05, + axis=channel_axis, + dtype=dtype, + name=f"{name}_csp_transition_b_bn", + )(xb) + if activation == "leaky_relu": + xb = layers.LeakyReLU( + negative_slope=0.01, + dtype=dtype, + name=f"{name}_csp_transition_b_activation", + )(xb) + else: + xb = layers.Activation( + activation, + dtype=dtype, + name=f"{name}_csp_transition_b_activation", + )(xb) + + out = layers.Concatenate( + axis=channel_axis, dtype=dtype, name=f"{name}_csp_conv_concat" + )([xs, xb]) + out = layers.Conv2D( + filters=filters, + kernel_size=1, + use_bias=False, + data_format=data_format, + dtype=dtype, + name=f"{name}_csp_conv_transition", + )(out) + out = layers.BatchNormalization( + epsilon=1e-05, + 
axis=channel_axis, + dtype=dtype, + name=f"{name}_csp_transition_bn", + )(out) + if activation == "leaky_relu": + out = layers.LeakyReLU( + negative_slope=0.01, + dtype=dtype, + name=f"{name}_csp_transition_activation", + )(out) + else: + out = layers.Activation( + activation, + dtype=dtype, + name=f"{name}_csp_transition_activation", + )(out) + return out + + return apply + + +def cross_stage3( + data_format, + channel_axis, + filters, + strides, + dilation, + depth, + block_ratio, + bottle_ratio, + expand_ratio, + avg_down, + activation, + first_dilation, + down_growth, + cross_linear, + block_fn, + groups, + name=None, + dtype=None, +): + """ + Cross Stage 3. + + Similar to Cross Stage, but with only one transition conv in the output. + """ + if name is None: + name = f"cross_stage3_{keras.backend.get_uid('cross_stage3')}" + + first_dilation = first_dilation or dilation + + def apply(x): + prev_filters = keras.ops.shape(x)[channel_axis] + down_chs = filters if down_growth else prev_filters + expand_chs = int(round(filters * expand_ratio)) + block_filters = int(round(filters * block_ratio)) + + if strides != 1 or first_dilation != dilation: + if avg_down: + if strides == 2: + x = layers.AveragePooling2D( + 2, dtype=dtype, name=f"{name}_cross_stage3_avg_pool" + )(x) + x = layers.Conv2D( + filters=filters, + kernel_size=1, + strides=1, + use_bias=False, + groups=groups, + data_format=data_format, + dtype=dtype, + name=f"{name}_cs3_conv_down_1", + )(x) + x = layers.BatchNormalization( + epsilon=1e-05, + axis=channel_axis, + dtype=dtype, + name=f"{name}_cs3_bn_1", + )(x) + if activation == "leaky_relu": + x = layers.LeakyReLU( + negative_slope=0.01, + dtype=dtype, + name=f"{name}_cs3_activation_1", + )(x) + else: + x = layers.Activation( + activation, + dtype=dtype, + name=f"{name}_cs3_activation_1", + )(x) + else: + x = layers.Conv2D( + filters=down_chs, + kernel_size=3, + strides=strides, + dilation_rate=first_dilation, + use_bias=False, + groups=groups, + data_format=data_format, + dtype=dtype, + name=f"{name}_cs3_conv_down_1", + )(x) + x = layers.BatchNormalization( + epsilon=1e-05, + axis=channel_axis, + dtype=dtype, + name=f"{name}_cs3_bn_1", + )(x) + if activation == "leaky_relu": + x = layers.LeakyReLU( + negative_slope=0.01, + dtype=dtype, + name=f"{name}_cs3_activation_1", + )(x) + else: + x = layers.Activation( + activation, + dtype=dtype, + name=f"{name}_cs3_activation_1", + )(x) + + x = layers.Conv2D( + filters=expand_chs, + kernel_size=1, + use_bias=False, + data_format=data_format, + dtype=dtype, + name=f"{name}_cs3_conv_exp", + )(x) + x = layers.BatchNormalization( + epsilon=1e-05, + axis=channel_axis, + dtype=dtype, + name=f"{name}_cs3_bn_2", + )(x) + if not cross_linear: + if activation == "leaky_relu": + x = layers.LeakyReLU( + negative_slope=0.01, + dtype=dtype, + name=f"{name}_cs3_activation_2", + )(x) + else: + x = layers.Activation( + activation, + dtype=dtype, + name=f"{name}_cs3_activation_2", + )(x) + + prev_filters = keras.ops.shape(x)[channel_axis] + x1, x2 = ops.split( + x, + indices_or_sections=prev_filters // (expand_chs // 2), + axis=channel_axis, + ) + + for i in range(depth): + x1 = block_fn( + filters=block_filters, + dilation=dilation, + bottle_ratio=bottle_ratio, + groups=groups, + activation=activation, + data_format=data_format, + channel_axis=channel_axis, + dtype=dtype, + name=f"{name}_block_{i}", + )(x1) + + out = layers.Concatenate( + axis=channel_axis, + dtype=dtype, + name=f"{name}_cs3_conv_transition_concat", + )([x1, x2]) + out = layers.Conv2D(
+ filters=expand_chs // 2, + kernel_size=1, + use_bias=False, + data_format=data_format, + dtype=dtype, + name=f"{name}_cs3_conv_transition", + )(out) + out = layers.BatchNormalization( + epsilon=1e-05, + axis=channel_axis, + dtype=dtype, + name=f"{name}_cs3_transition_bn", + )(out) + if activation == "leaky_relu": + out = layers.LeakyReLU( + negative_slope=0.01, + dtype=dtype, + name=f"{name}_cs3_activation_3", + )(out) + else: + out = layers.Activation( + activation, + dtype=dtype, + name=f"{name}_cs3_activation_3", + )(out) + return out + + return apply + + +def dark_stage( + data_format, + channel_axis, + filters, + strides, + dilation, + depth, + block_ratio, + bottle_ratio, + avg_down, + activation, + first_dilation, + block_fn, + groups, + expand_ratio=None, + down_growth=None, + cross_linear=None, + name=None, + dtype=None, +): + """ + DarkNet stage. + + A downsampling convolution (or average pooling when `avg_down=True`) + followed by a sequence of blocks, with no cross-stage split. + """ + if name is None: + name = f"dark_stage_{keras.backend.get_uid('dark_stage')}" + + first_dilation = first_dilation or dilation + + def apply(x): + block_channels = int(round(filters * block_ratio)) + if avg_down: + if strides == 2: + x = layers.AveragePooling2D( + 2, dtype=dtype, name=f"{name}_dark_avg_pool" + )(x) + x = layers.Conv2D( + filters=filters, + kernel_size=1, + strides=1, + use_bias=False, + groups=groups, + data_format=data_format, + dtype=dtype, + name=f"{name}_dark_conv_down_1", + )(x) + x = layers.BatchNormalization( + epsilon=1e-05, + axis=channel_axis, + dtype=dtype, + name=f"{name}_dark_bn_1", + )(x) + if activation == "leaky_relu": + x = layers.LeakyReLU( + negative_slope=0.01, + dtype=dtype, + name=f"{name}_dark_activation_1", + )(x) + else: + x = layers.Activation( + activation, + dtype=dtype, + name=f"{name}_dark_activation_1", + )(x) + else: + x = layers.Conv2D( + filters=filters, + kernel_size=3, + strides=strides, + dilation_rate=first_dilation, + use_bias=False, + groups=groups, + data_format=data_format, + dtype=dtype, + name=f"{name}_dark_conv_down_1", + )(x) + x = layers.BatchNormalization( + epsilon=1e-05, + axis=channel_axis, + dtype=dtype, + name=f"{name}_dark_bn_1", + )(x) + if activation == "leaky_relu": + x = layers.LeakyReLU( + negative_slope=0.01, + dtype=dtype, + name=f"{name}_dark_activation_1", + )(x) + else: + x = layers.Activation( + activation, + dtype=dtype, + name=f"{name}_dark_activation_1", + )(x) + for i in range(depth): + x = block_fn( + filters=block_channels, + dilation=dilation, + bottle_ratio=bottle_ratio, + groups=groups, + activation=activation, + data_format=data_format, + channel_axis=channel_axis, + dtype=dtype, + name=f"{name}_block_{i}", + )(x) + return x + + return apply + + +def create_csp_stem( + data_format, + channel_axis, + activation, + padding, + filters=32, + kernel_size=3, + strides=2, + pooling=None, + dtype=None, +): + if not isinstance(filters, (tuple, list)): + filters = [filters] + stem_depth = len(filters) + assert stem_depth + assert strides in (1, 2, 4) + last_idx = stem_depth - 1 + + def apply(x): + stem_strides = 1 + for i, chs in enumerate(filters): + conv_strides = ( + 2 + if (i == 0 and strides > 1) + or (i == last_idx and strides > 2 and not pooling) + else 1 + ) + x = layers.Conv2D( + filters=chs, + kernel_size=kernel_size, + strides=conv_strides, + padding=padding if i == 0 else "valid", + use_bias=False, + data_format=data_format, + dtype=dtype, + name=f"csp_stem_conv_{i}", + )(x) + x = layers.BatchNormalization( + epsilon=1e-05, + axis=channel_axis, + dtype=dtype,
+ name=f"csp_stem_bn_{i}", + )(x) + if activation == "leaky_relu": + x = layers.LeakyReLU( + negative_slope=0.01, + dtype=dtype, + name=f"csp_stem_activation_{i}", + )(x) + else: + x = layers.Activation( + activation, + dtype=dtype, + name=f"csp_stem_activation_{i}", + )(x) + stem_strides *= conv_strides + + if pooling == "max": + assert strides > 2 + x = layers.MaxPooling2D( + pool_size=3, + strides=2, + padding="same", + data_format=data_format, + dtype=dtype, + name="csp_stem_pool", + )(x) + stem_strides *= 2 + return x, stem_strides + + return apply + + +def create_csp_stages( + inputs, + filters, + data_format, + channel_axis, + stackwise_depth, + reduction, + block_ratio, + bottle_ratio, + expand_ratio, + strides, + groups, + avg_down, + down_growth, + cross_linear, + activation, + output_strides, + stage_type, + block_type, + dtype, + name, +): + if name is None: + name = f"csp_stage_{keras.backend.get_uid('csp_stage')}" + + num_stages = len(stackwise_depth) + dilation = 1 + net_strides = reduction + strides = _pad_arg(strides, num_stages) + expand_ratio = _pad_arg(expand_ratio, num_stages) + bottle_ratio = _pad_arg(bottle_ratio, num_stages) + block_ratio = _pad_arg(block_ratio, num_stages) + + if stage_type == "dark": + stage_fn = dark_stage + elif stage_type == "csp": + stage_fn = cross_stage + else: + stage_fn = cross_stage3 + + if block_type == "dark_block": + block_fn = dark_block + elif block_type == "edge_block": + block_fn = edge_block + else: + block_fn = bottleneck_block + + stages = inputs + pyramid_outputs = {} + for stage_idx, _ in enumerate(stackwise_depth): + if net_strides >= output_strides and strides[stage_idx] > 1: + dilation *= strides[stage_idx] + strides = 1 + net_strides *= strides[stage_idx] + first_dilation = 1 if dilation in (1, 2) else 2 + stages = stage_fn( + data_format=data_format, + channel_axis=channel_axis, + filters=filters[stage_idx], + depth=stackwise_depth[stage_idx], + strides=strides[stage_idx], + dilation=dilation, + block_ratio=block_ratio[stage_idx], + bottle_ratio=bottle_ratio[stage_idx], + expand_ratio=expand_ratio[stage_idx], + groups=groups, + first_dilation=first_dilation, + avg_down=avg_down, + activation=activation, + down_growth=down_growth, + cross_linear=cross_linear, + block_fn=block_fn, + dtype=dtype, + name=f"stage_{stage_idx}", + )(stages) + pyramid_outputs[f"P{stage_idx + 2}"] = stages + return stages, pyramid_outputs + + +def _pad_arg(x, n): + """ + pads an argument tuple to specified n by padding with last value + """ + if not isinstance(x, (tuple, list)): + x = (x,) + curr_n = len(x) + pad_n = n - curr_n + if pad_n <= 0: + return x[:n] + return tuple( + list(x) + + [ + x[-1], + ] + * pad_n + ) diff --git a/keras_hub/src/models/cspnet/cspnet_backbone_test.py b/keras_hub/src/models/cspnet/cspnet_backbone_test.py new file mode 100644 index 0000000000..3b8681d3d9 --- /dev/null +++ b/keras_hub/src/models/cspnet/cspnet_backbone_test.py @@ -0,0 +1,61 @@ +import pytest +from absl.testing import parameterized +from keras import ops + +from keras_hub.src.models.cspnet.cspnet_backbone import CSPNetBackbone +from keras_hub.src.tests.test_case import TestCase + + +class CSPNetBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "stem_filters": 32, + "stem_kernel_size": 3, + "stem_strides": 1, + "stackwise_strides": 2, + "stackwise_depth": [1, 2, 8], + "stackwise_num_filters": [16, 24, 48], + "image_shape": (None, None, 3), + "down_growth": True, + "bottle_ratio": (0.5,) + (1.0,), + "block_ratio": (1.0,) + (0.5,), + 
"expand_ratio": (2.0,) + (1.0,), + "block_type": "dark_block", + "stage_type": "csp", + } + self.input_size = 64 + self.input_data = ops.ones((2, self.input_size, self.input_size, 3)) + + @parameterized.named_parameters( + ("cspnet", "csp", "dark_block"), + ) + def test_backbone_basics(self, stage_type, block_type): + self.run_vision_backbone_test( + cls=CSPNetBackbone, + init_kwargs={ + **self.init_kwargs, + "block_type": block_type, + "stage_type": stage_type, + }, + input_data=self.input_data, + expected_output_shape=(2, 6, 6, 48), + expected_pyramid_output_keys=["P2", "P3", "P4"], + expected_pyramid_image_sizes=[(30, 30), (14, 14), (6, 6)], + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=CSPNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in CSPNetBackbone.presets: + self.run_preset_test( + cls=CSPNetBackbone, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/cspnet/cspnet_image_classifier.py b/keras_hub/src/models/cspnet/cspnet_image_classifier.py new file mode 100644 index 0000000000..e74f41887e --- /dev/null +++ b/keras_hub/src/models/cspnet/cspnet_image_classifier.py @@ -0,0 +1,12 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.cspnet.cspnet_backbone import CSPNetBackbone +from keras_hub.src.models.cspnet.cspnet_image_classifier_preprocessor import ( + CSPNetImageClassifierPreprocessor, +) +from keras_hub.src.models.image_classifier import ImageClassifier + + +@keras_hub_export("keras_hub.models.CSPNetImageClassifier") +class CSPNetImageClassifier(ImageClassifier): + backbone_cls = CSPNetBackbone + preprocessor_cls = CSPNetImageClassifierPreprocessor diff --git a/keras_hub/src/models/cspnet/cspnet_image_classifier_preprocessor.py b/keras_hub/src/models/cspnet/cspnet_image_classifier_preprocessor.py new file mode 100644 index 0000000000..cbe70b1683 --- /dev/null +++ b/keras_hub/src/models/cspnet/cspnet_image_classifier_preprocessor.py @@ -0,0 +1,14 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.cspnet.cspnet_backbone import CSPNetBackbone +from keras_hub.src.models.cspnet.cspnet_image_converter import ( + CSPNetImageConverter, +) +from keras_hub.src.models.image_classifier_preprocessor import ( + ImageClassifierPreprocessor, +) + + +@keras_hub_export("keras_hub.models.CSPNetImageClassifierPreprocessor") +class CSPNetImageClassifierPreprocessor(ImageClassifierPreprocessor): + backbone_cls = CSPNetBackbone + image_converter_cls = CSPNetImageConverter diff --git a/keras_hub/src/models/cspnet/cspnet_image_classifier_test.py b/keras_hub/src/models/cspnet/cspnet_image_classifier_test.py new file mode 100644 index 0000000000..9e26aaf65e --- /dev/null +++ b/keras_hub/src/models/cspnet/cspnet_image_classifier_test.py @@ -0,0 +1,78 @@ +import numpy as np +import pytest + +from keras_hub.src.models.cspnet.cspnet_backbone import CSPNetBackbone +from keras_hub.src.models.cspnet.cspnet_image_classifier import ( + CSPNetImageClassifier, +) +from keras_hub.src.models.cspnet.cspnet_image_classifier_preprocessor import ( + CSPNetImageClassifierPreprocessor, +) +from keras_hub.src.models.cspnet.cspnet_image_converter import ( + CSPNetImageConverter, +) +from keras_hub.src.tests.test_case import TestCase + + +class CSPNetImageClassifierTest(TestCase): + def setUp(self): + # Setup model. 
+ self.images = np.ones((2, 32, 32, 3), dtype="float32") + self.labels = [0, 2] + self.backbone = CSPNetBackbone( + stem_filters=32, + stem_kernel_size=3, + stem_strides=1, + stackwise_strides=2, + stackwise_depth=[1, 2, 8], + stackwise_num_filters=[16, 24, 48], + image_shape=(None, None, 3), + down_growth=True, + bottle_ratio=(0.5,) + (1.0,), + block_ratio=(1.0,) + (0.5,), + expand_ratio=(2.0,) + (1.0,), + block_type="dark_block", + stage_type="csp", + ) + self.image_converter = CSPNetImageConverter( + height=32, width=32, scale=1 / 255.0 + ) + self.preprocessor = CSPNetImageClassifierPreprocessor( + self.image_converter + ) + self.init_kwargs = { + "backbone": self.backbone, + "preprocessor": self.preprocessor, + "num_classes": 3, + } + self.train_data = ( + self.images, + self.labels, + ) + + def test_classifier_basics(self): + self.run_task_test( + cls=CSPNetImageClassifier, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 3), + ) + + @pytest.mark.large + def test_smallest_preset(self): + image_batch = self.load_test_image()[None, ...] / 255.0 + self.run_preset_test( + cls=CSPNetImageClassifier, + preset="hf://timm/cspdarknet53.ra_in1k", + input_data=image_batch, + expected_output_shape=(1, 1000), + expected_labels=[85], + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=CSPNetImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) diff --git a/keras_hub/src/models/cspnet/cspnet_image_converter.py b/keras_hub/src/models/cspnet/cspnet_image_converter.py new file mode 100644 index 0000000000..f121637aa1 --- /dev/null +++ b/keras_hub/src/models/cspnet/cspnet_image_converter.py @@ -0,0 +1,8 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.models.cspnet.cspnet_backbone import CSPNetBackbone + + +@keras_hub_export("keras_hub.layers.CSPNetImageConverter") +class CSPNetImageConverter(ImageConverter): + backbone_cls = CSPNetBackbone diff --git a/keras_hub/src/models/cspnet/cspnet_presets.py b/keras_hub/src/models/cspnet/cspnet_presets.py new file mode 100644 index 0000000000..8b090b56bf --- /dev/null +++ b/keras_hub/src/models/cspnet/cspnet_presets.py @@ -0,0 +1,16 @@ +"""CSPNet preset configurations.""" + +backbone_presets = { + "csp_darknet_53_ra_imagenet": { + "metadata": { + "description": ( + "A CSP-DarkNet (Cross-Stage-Partial) image classification model" + " pre-trained on the ImageNet 1k dataset with a RandAugment " + "recipe at a 224x224 resolution."
+ ), + "params": 26652512, + "path": "cspnet", + }, + "kaggle_handle": "kaggle://keras/cspdarknet/keras/csp_darknet_53_ra_imagenet/1", + }, +} diff --git a/keras_hub/src/utils/timm/convert_cspnet.py b/keras_hub/src/utils/timm/convert_cspnet.py new file mode 100644 index 0000000000..161edab23f --- /dev/null +++ b/keras_hub/src/utils/timm/convert_cspnet.py @@ -0,0 +1,165 @@ +import numpy as np + +from keras_hub.src.models.cspnet.cspnet_backbone import CSPNetBackbone + +backbone_cls = CSPNetBackbone + + +def convert_backbone_config(timm_config): + timm_architecture = timm_config["architecture"] + + if timm_architecture == "cspdarknet53": + stem_filters = 32 + stem_kernel_size = 3 + stem_strides = 1 + stackwise_depth = [1, 2, 8, 8, 4] + stackwise_num_filters = [64, 128, 256, 512, 1024] + bottle_ratio = (0.5,) + (1.0,) + block_ratio = (1.0,) + (0.5,) + expand_ratio = (2.0,) + (1.0,) + stage_type = "csp" + block_type = "dark_block" + down_growth = True + stackwise_strides = 2 + else: + raise ValueError( + f"Currently, the architecture {timm_architecture} is not supported." + ) + return dict( + stem_filters=stem_filters, + stem_kernel_size=stem_kernel_size, + stem_strides=stem_strides, + stackwise_depth=stackwise_depth, + stackwise_num_filters=stackwise_num_filters, + bottle_ratio=bottle_ratio, + block_ratio=block_ratio, + expand_ratio=expand_ratio, + stage_type=stage_type, + block_type=block_type, + stackwise_strides=stackwise_strides, + down_growth=down_growth, + ) + + +def convert_weights(backbone, loader, timm_config): + def port_conv2d(hf_weight_prefix, keras_layer_name): + loader.port_weight( + backbone.get_layer(keras_layer_name).kernel, + hf_weight_key=f"{hf_weight_prefix}.weight", + hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)), + ) + + def port_batch_normalization(hf_weight_prefix, keras_layer_name): + loader.port_weight( + backbone.get_layer(keras_layer_name).gamma, + hf_weight_key=f"{hf_weight_prefix}.weight", + ) + loader.port_weight( + backbone.get_layer(keras_layer_name).beta, + hf_weight_key=f"{hf_weight_prefix}.bias", + ) + loader.port_weight( + backbone.get_layer(keras_layer_name).moving_mean, + hf_weight_key=f"{hf_weight_prefix}.running_mean", + ) + loader.port_weight( + backbone.get_layer(keras_layer_name).moving_variance, + hf_weight_key=f"{hf_weight_prefix}.running_var", + ) + + # Stem + + stem_filter = backbone.stem_filters + if not isinstance(stem_filter, (tuple, list)): + stem_filter = [stem_filter] + + for i in range(len(stem_filter)): + port_conv2d(f"stem.conv{i + 1}.conv", f"csp_stem_conv_{i}") + port_batch_normalization(f"stem.conv{i + 1}.bn", f"csp_stem_bn_{i}") + + # Stages + stackwise_depth = backbone.stackwise_depth + stage_type = backbone.stage_type + block_type = backbone.block_type + + for idx, block in enumerate(stackwise_depth): + port_conv2d( + f"stages.{idx}.conv_down.conv", + f"stage_{idx}_{stage_type}_conv_down_1", + ) + port_batch_normalization( + f"stages.{idx}.conv_down.bn", f"stage_{idx}_{stage_type}_bn_1" + ) + port_conv2d( + f"stages.{idx}.conv_exp.conv", f"stage_{idx}_{stage_type}_conv_exp" + ) + port_batch_normalization( + f"stages.{idx}.conv_exp.bn", f"stage_{idx}_{stage_type}_bn_2" + ) + + for i in range(block): + port_conv2d( + f"stages.{idx}.blocks.{i}.conv1.conv", + f"stage_{idx}_block_{i}_{block_type}_conv_1", + ) + port_batch_normalization( + f"stages.{idx}.blocks.{i}.conv1.bn", + f"stage_{idx}_block_{i}_{block_type}_bn_1", + ) + port_conv2d( + f"stages.{idx}.blocks.{i}.conv2.conv", + f"stage_{idx}_block_{i}_{block_type}_conv_2", + ) + 
+            port_batch_normalization(
+                f"stages.{idx}.blocks.{i}.conv2.bn",
+                f"stage_{idx}_block_{i}_{block_type}_bn_2",
+            )
+            if block_type == "bottleneck_block":
+                port_conv2d(
+                    f"stages.{idx}.blocks.{i}.conv3.conv",
+                    f"stage_{idx}_block_{i}_{block_type}_conv_3",
+                )
+                port_batch_normalization(
+                    f"stages.{idx}.blocks.{i}.conv3.bn",
+                    f"stage_{idx}_block_{i}_{block_type}_bn_3",
+                )
+
+        if stage_type == "csp":
+            port_conv2d(
+                f"stages.{idx}.conv_transition_b.conv",
+                f"stage_{idx}_{stage_type}_conv_transition_b",
+            )
+            port_batch_normalization(
+                f"stages.{idx}.conv_transition_b.bn",
+                f"stage_{idx}_{stage_type}_transition_b_bn",
+            )
+            port_conv2d(
+                f"stages.{idx}.conv_transition.conv",
+                f"stage_{idx}_{stage_type}_conv_transition",
+            )
+            port_batch_normalization(
+                f"stages.{idx}.conv_transition.bn",
+                f"stage_{idx}_{stage_type}_transition_bn",
+            )
+
+        else:
+            port_conv2d(
+                f"stages.{idx}.conv_transition.conv",
+                f"stage_{idx}_{stage_type}_conv_transition",
+            )
+            port_batch_normalization(
+                f"stages.{idx}.conv_transition.bn",
+                f"stage_{idx}_{stage_type}_transition_bn",
+            )
+
+
+def convert_head(task, loader, timm_config):
+    # The torch fc weight is (out, in); squeeze + transpose gives (in, out).
+    loader.port_weight(
+        task.output_dense.kernel,
+        hf_weight_key="head.fc.weight",
+        hook_fn=lambda x, _: np.transpose(np.squeeze(x)),
+    )
+    loader.port_weight(
+        task.output_dense.bias,
+        hf_weight_key="head.fc.bias",
+    )
diff --git a/keras_hub/src/utils/timm/convert_cspnet_test.py b/keras_hub/src/utils/timm/convert_cspnet_test.py
new file mode 100644
index 0000000000..dcddca8ae5
--- /dev/null
+++ b/keras_hub/src/utils/timm/convert_cspnet_test.py
@@ -0,0 +1,20 @@
+import pytest
+from keras import ops
+
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.image_classifier import ImageClassifier
+from keras_hub.src.tests.test_case import TestCase
+
+
+class TimmCSPNetBackboneTest(TestCase):
+    @pytest.mark.large
+    def test_convert_cspnet_backbone(self):
+        model = Backbone.from_preset("hf://timm/cspdarknet53.ra_in1k")
+        outputs = model.predict(ops.ones((1, 224, 224, 3)))
+        self.assertEqual(outputs.shape, (1, 5, 5, 1024))
+
+    @pytest.mark.large
+    def test_convert_cspnet_classifier(self):
+        model = ImageClassifier.from_preset("hf://timm/cspdarknet53.ra_in1k")
+        outputs = model.predict(ops.ones((1, 512, 512, 3)))
+        self.assertEqual(outputs.shape, (1, 1000))
diff --git a/keras_hub/src/utils/timm/preset_loader.py b/keras_hub/src/utils/timm/preset_loader.py
index 069a6425e4..368f1be68a 100644
--- a/keras_hub/src/utils/timm/preset_loader.py
+++ b/keras_hub/src/utils/timm/preset_loader.py
@@ -3,6 +3,7 @@
 from keras_hub.src.models.image_classifier import ImageClassifier
 from keras_hub.src.utils.preset_utils import PresetLoader
 from keras_hub.src.utils.preset_utils import jax_memory_cleanup
+from keras_hub.src.utils.timm import convert_cspnet
 from keras_hub.src.utils.timm import convert_densenet
 from keras_hub.src.utils.timm import convert_efficientnet
 from keras_hub.src.utils.timm import convert_resnet
@@ -16,6 +17,8 @@ def __init__(self, preset, config):
         architecture = self.config["architecture"]
         if "resnet" in architecture:
             self.converter = convert_resnet
+        elif "csp" in architecture:
+            self.converter = convert_cspnet
         elif "densenet" in architecture:
             self.converter = convert_densenet
         elif "vgg" in architecture:
diff --git a/tools/checkpoint_conversion/convert_cspnet_checkpoints.py b/tools/checkpoint_conversion/convert_cspnet_checkpoints.py
new file mode 100644
index 0000000000..56d18486e9
--- /dev/null
+++ b/tools/checkpoint_conversion/convert_cspnet_checkpoints.py
@@ -0,0 +1,114 @@
+"""Convert cspnet checkpoints.
+
+python tools/checkpoint_conversion/convert_cspnet_checkpoints.py \
+    --preset csp_darknet_53_ra_imagenet --upload_uri kaggle://keras/cspdarknet/keras/csp_darknet_53_ra_imagenet
+"""
+
+import os
+import shutil
+
+import keras
+import numpy as np
+import PIL
+import timm
+import torch
+from absl import app
+from absl import flags
+
+import keras_hub
+
+PRESET_MAP = {
+    "csp_darknet_53_ra_imagenet": "timm/cspdarknet53.ra_in1k",
+}
+FLAGS = flags.FLAGS
+
+
+flags.DEFINE_string(
+    "preset",
+    None,
+    "Must be a valid `CSPNet` preset from KerasHub",
+    required=True,
+)
+flags.DEFINE_string(
+    "upload_uri",
+    None,
+    'Could be "kaggle://keras/{variant}/keras/{preset}_int8"',
+    required=False,
+)
+
+
+def validate_output(keras_model, timm_model):
+    file = keras.utils.get_file(
+        origin=(
+            "https://storage.googleapis.com/keras-cv/"
+            "models/paligemma/cow_beach_1.png"
+        )
+    )
+    image = PIL.Image.open(file)
+    batch = np.array([image])
+
+    # Preprocess with Timm.
+    data_config = timm.data.resolve_model_data_config(timm_model)
+    data_config["crop_pct"] = 1.0  # Stop timm from cropping.
+    transforms = timm.data.create_transform(**data_config, is_training=False)
+    timm_preprocessed = transforms(image)
+    timm_preprocessed = keras.ops.transpose(timm_preprocessed, axes=(1, 2, 0))
+    timm_preprocessed = keras.ops.expand_dims(timm_preprocessed, 0)
+
+    # Preprocess with Keras.
+    keras_preprocessed = keras_model.preprocessor(batch)
+
+    # Call with Timm. Use the keras preprocessed image so we can keep modeling
+    # and preprocessing comparisons independent.
+    timm_batch = keras.ops.transpose(keras_preprocessed, axes=(0, 3, 1, 2))
+    timm_batch = torch.from_numpy(np.array(timm_batch))
+    timm_outputs = timm_model(timm_batch).detach().numpy()
+    timm_label = np.argmax(timm_outputs[0])
+
+    # Call with Keras.
+    keras_outputs = keras_model.predict(batch)
+    keras_label = np.argmax(keras_outputs[0])
+
+    print("🔶 Keras output:", keras_outputs[0, :10])
+    print("🔶 TIMM output:", timm_outputs[0, :10])
+    print("🔶 Keras label:", keras_label)
+    print("🔶 TIMM label:", timm_label)
+    modeling_diff = np.mean(np.abs(keras_outputs - timm_outputs))
+    print("🔶 Modeling difference:", modeling_diff)
+    preprocessing_diff = np.mean(
+        np.abs(keras_preprocessed - timm_preprocessed)
+    )
+    print("🔶 Preprocessing difference:", preprocessing_diff)
+
+
+def main(_):
+    preset = FLAGS.preset
+    if os.path.exists(preset):
+        shutil.rmtree(preset)
+    os.makedirs(preset)
+
+    timm_name = PRESET_MAP[preset]
+
+    timm_model = timm.create_model(timm_name, pretrained=True)
+    timm_model = timm_model.eval()
+    print("✅ Loaded TIMM model.")
+
+    keras_model = keras_hub.models.ImageClassifier.from_preset(
+        "hf://" + timm_name,
+    )
+    print("✅ Loaded KerasHub model.")
+
+    keras_model.save_to_preset(f"./{preset}")
+    print(f"🏁 Preset saved to ./{preset}")
+
+    validate_output(keras_model, timm_model)
+
+    upload_uri = FLAGS.upload_uri
+    if upload_uri:
+        keras_hub.upload_preset(uri=upload_uri, preset=f"./{preset}")
+        print(f"🏁 Preset uploaded to {upload_uri}")
+
+
+if __name__ == "__main__":
+    flags.mark_flag_as_required("preset")
+    app.run(main)
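For reference, a minimal end-to-end sketch of the loading path this diff wires up, mirroring the large tests in convert_cspnet_test.py; it assumes network access to the Hugging Face Hub and availability of the timm/cspdarknet53.ra_in1k checkpoint:

import numpy as np

import keras_hub

# The "hf://timm/..." scheme routes through TimmPresetLoader, which now
# selects convert_cspnet because "csp" appears in the architecture name.
model = keras_hub.models.ImageClassifier.from_preset(
    "hf://timm/cspdarknet53.ra_in1k"
)
preds = model.predict(np.ones((1, 224, 224, 3)))
print(preds.shape)  # (1, 1000)

Once the outputs check out, the same model can be written out with `model.save_to_preset(...)` and published with `keras_hub.upload_preset(...)`, which is what convert_cspnet_checkpoints.py automates.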