13 changes: 13 additions & 0 deletions keras_nlp/src/models/stable_diffusion_v3/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
111 changes: 111 additions & 0 deletions keras_nlp/src/models/stable_diffusion_v3/clip_attention.py
@@ -0,0 +1,111 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras import layers
from keras import ops


class CLIPAttention(layers.Layer):
def __init__(self, num_heads, hidden_dim, dropout=0.0, **kwargs):
super().__init__(**kwargs)
if hidden_dim % num_heads != 0:
raise ValueError(
"`hidden_dim` must be divisible by num_heads. "
f"Received: num_heads={num_heads}, hidden_dim={hidden_dim}"
)
self.num_heads = num_heads
self.hidden_dim = hidden_dim
self.dropout = dropout
self.head_dim = self.hidden_dim // self.num_heads

self.dropout_layer = layers.Dropout(self.dropout)
self.scale = self.head_dim**-0.5
self.query_dense = layers.Dense(
units=self.hidden_dim, dtype=self.dtype_policy, name="query"
)
self.key_dense = layers.Dense(
units=self.hidden_dim, dtype=self.dtype_policy, name="key"
)
self.value_dense = layers.Dense(
units=self.hidden_dim, dtype=self.dtype_policy, name="value"
)
self.softmax = layers.Softmax(dtype="float32")
self.output_dense = layers.Dense(
units=self.hidden_dim,
dtype=self.dtype_policy,
name="attention_output",
)

def build(self, input_shape):
self.query_dense.build(input_shape)
self.key_dense.build(input_shape)
self.value_dense.build(input_shape)
self.output_dense.build([None, None, self.hidden_dim])

def compute_output_shape(self, input_shape):
output_shape = list(input_shape)
output_shape[-1] = self.hidden_dim
return output_shape

def _transpose_for_scores(self, inputs):
batch_size = ops.shape(inputs)[0]
inputs = ops.reshape(
inputs, (batch_size, -1, self.num_heads, self.head_dim)
)
return ops.transpose(inputs, axes=[0, 2, 1, 3])

def call(self, x, attention_mask=None, training=None):
batch_size = ops.shape(x)[0]
query = self.query_dense(x)
key = self.key_dense(x)
value = self.value_dense(x)
query = self._transpose_for_scores(query)
key = self._transpose_for_scores(key)
value = self._transpose_for_scores(value)

attention_logits = ops.matmul(
query, ops.transpose(key, axes=[0, 1, 3, 2])
)
dk = ops.cast(ops.sqrt(self.head_dim), dtype=attention_logits.dtype)
attention_logits = ops.divide(attention_logits, dk)

if attention_mask is not None:
attention_logits = ops.add(attention_logits, attention_mask)

orig_dtype = attention_logits.dtype
attention_softmax = self.softmax(attention_logits)
attention_softmax = ops.cast(attention_softmax, orig_dtype)

if self.dropout:
attention_softmax = self.dropout_layer(
attention_softmax, training=training
)

attention_output = ops.matmul(attention_softmax, value)
attention_output = ops.transpose(attention_output, axes=[0, 2, 1, 3])
attention_output = ops.reshape(
attention_output, (batch_size, -1, self.hidden_dim)
)
attention_output = self.output_dense(attention_output)
return attention_output

def get_config(self):
config = super().get_config()
config.update(
{
"num_heads": self.num_heads,
"hidden_dim": self.hidden_dim,
"dropout": self.dropout,
}
)
return config
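
A minimal, hypothetical usage sketch for the CLIPAttention layer above (not part of this PR; the sizes and the mask values are illustrative assumptions). The mask is additive: it is added to the attention logits before the softmax, so positions to be ignored should carry a large negative value.

import numpy as np

layer = CLIPAttention(num_heads=4, hidden_dim=64)
x = np.random.uniform(size=(2, 8, 64)).astype("float32")  # (batch, seq, hidden)
# Additive causal mask: future positions receive a large negative bias so the
# softmax assigns them (near) zero weight; allowed positions receive 0.
causal_mask = np.triu(np.full((8, 8), -1e9, dtype="float32"), k=1)
out = layer(x, attention_mask=causal_mask)  # -> (2, 8, 64)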
115 changes: 115 additions & 0 deletions keras_nlp/src/models/stable_diffusion_v3/clip_encoder_block.py
@@ -0,0 +1,115 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras import layers
from keras import ops

from keras_nlp.src.models.stable_diffusion_v3.clip_attention import (
CLIPAttention,
)


def quick_gelu(x):
return x * ops.sigmoid(1.702 * x)


class CLIPEncoderBlock(layers.Layer):
def __init__(
self,
hidden_dim,
num_heads,
intermediate_dim,
intermediate_activation="quick_gelu",
**kwargs,
):
super().__init__(**kwargs)
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.intermediate_dim = intermediate_dim
self.intermediate_activation = intermediate_activation

if intermediate_activation == "quick_gelu":
intermediate_activation = quick_gelu

self.layer_norm_1 = layers.LayerNormalization(
epsilon=0.00001, dtype=self.dtype_policy, name="layer_norm_1"
)
self.attention = CLIPAttention(
self.num_heads,
self.hidden_dim,
dtype=self.dtype_policy,
name="attention",
)
self.layer_norm_2 = layers.LayerNormalization(
epsilon=0.00001, dtype=self.dtype_policy, name="layer_norm_2"
)
self.dense_1 = layers.Dense(
self.intermediate_dim, dtype=self.dtype_policy, name="dense_1"
)
self.activation = layers.Activation(
intermediate_activation, dtype=self.dtype_policy, name="activation"
)
self.dense_2 = layers.Dense(
self.hidden_dim, dtype=self.dtype_policy, name="dense_2"
)

def build(self, input_shape):
self.layer_norm_1.build(input_shape)
self.attention.build(input_shape)
self.layer_norm_2.build(input_shape)
self.dense_1.build(input_shape)
input_shape = self.dense_1.compute_output_shape(input_shape)
self.dense_2.build(input_shape)

def compute_output_shape(self, inputs_shape):
outputs_shape = list(inputs_shape)
outputs_shape[-1] = self.hidden_dim
return outputs_shape

    def _compute_attention(self, x, attention_mask=None, training=None):
        mask = None
        if attention_mask is not None:
            mask = ops.cast(attention_mask, dtype=x.dtype)
        return self.attention(x, attention_mask=mask, training=training)

def call(self, x, attention_mask=None, training=None):
residual = x
x = self.layer_norm_1(x)
x = self._compute_attention(
x, attention_mask=attention_mask, training=training
)
x = ops.add(residual, x)

residual = x
x = self.dense_1(self.layer_norm_2(residual))
x = self.activation(x)
x = self.dense_2(x)
x = ops.add(residual, x)
return x

def get_config(self):
config = super().get_config()
config.update(
{
"hidden_dim": self.hidden_dim,
"num_heads": self.num_heads,
"intermediate_dim": self.intermediate_dim,
"intermediate_activation": self.intermediate_activation,
}
)
return config
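
As an aside, quick_gelu is the sigmoid-based GELU approximation x * sigmoid(1.702 * x) used by the original CLIP. A small, hypothetical smoke test for the block above (shapes and sizes are illustrative assumptions, not from this PR):

import numpy as np

block = CLIPEncoderBlock(hidden_dim=64, num_heads=4, intermediate_dim=256)
x = np.random.uniform(size=(2, 8, 64)).astype("float32")
y = block(x)  # the pre-norm residual block preserves the input shape
assert tuple(y.shape) == (2, 8, 64)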
147 changes: 147 additions & 0 deletions keras_nlp/src/models/stable_diffusion_v3/clip_text_encoder.py
@@ -0,0 +1,147 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras import layers
from keras import ops

from keras_nlp.src.layers.modeling.token_and_position_embedding import (
TokenAndPositionEmbedding,
)
from keras_nlp.src.models.backbone import Backbone
from keras_nlp.src.models.stable_diffusion_v3.clip_encoder_block import (
CLIPEncoderBlock,
)


class CLIPTextEncoder(Backbone):
Review thread on this line:

Member: I think we can land all of this in the SD3 folder for now. But we might eventually want CLIP in its own folder here, usable with SD3. WDYT @divyashreepathihalli?

Collaborator (Author): The implementation will be cleaner if we have a CLIP model. I can try adding it after landing SD3.

divyashreepathihalli (Collaborator), Aug 26, 2024: @james77777778, CLIP is in KerasCV - it needs to be added to keras_hub still. That would make this implementation cleaner for sure. Given that this is only using the text encoder part and not the whole CLIP model, we can land this as is now and clean up once CLIP is added.

def __init__(
self,
embedding_dim,
hidden_dim,
num_layers,
num_heads,
intermediate_dim,
intermediate_activation="quick_gelu",
intermediate_output_index=None,
vocabulary_size=49408,
sequence_length=77,
dtype=None,
**kwargs,
):
if (
intermediate_output_index is not None
and intermediate_output_index < 0
):
intermediate_output_index += num_layers

# === Layers ===
self.embedding = TokenAndPositionEmbedding(
vocabulary_size=vocabulary_size,
sequence_length=sequence_length,
embedding_dim=embedding_dim,
dtype=dtype,
name="embedding",
)
self.encoder_layers = [
CLIPEncoderBlock(
hidden_dim,
num_heads,
intermediate_dim,
intermediate_activation,
dtype=dtype,
)
for _ in range(num_layers)
]
self.layer_norm = layers.LayerNormalization(
epsilon=0.00001, dtype=dtype, name="layer_norm"
)
self.text_projection = layers.Dense(
hidden_dim,
use_bias=False,
dtype=dtype,
name="text_projection",
)

# === Functional Model ===
encoder_token_ids = layers.Input(
shape=(sequence_length,), dtype="int32", name="encoder_token_ids"
)
causal_mask = layers.Input(
batch_shape=(sequence_length, sequence_length), name="causal_mask"
)
x = self.embedding(encoder_token_ids)
encoder_intermediate_output = None
# Encoder.
for i, block in enumerate(self.encoder_layers):
x = block(x, attention_mask=causal_mask)
if i == intermediate_output_index:
encoder_intermediate_output = x
x = self.layer_norm(x)
encoder_output = x
if encoder_intermediate_output is not None:
encoder_intermediate_output = self.layer_norm(
encoder_intermediate_output
)
# Projection.
indices = ops.expand_dims(
ops.cast(ops.argmax(encoder_token_ids, axis=-1), "int32"), axis=-1
)
pooled_output = ops.take_along_axis(x, indices[:, :, None], axis=1)
pooled_output = ops.squeeze(pooled_output, axis=1)
projection_output = self.text_projection(pooled_output)

outputs = {
"encoder_sequence_output": encoder_output,
"encoder_pooled_output": pooled_output,
"encoder_projection_output": projection_output,
}
if intermediate_output_index is not None:
outputs["encoder_intermediate_output"] = encoder_intermediate_output

super().__init__(
inputs={
"encoder_token_ids": encoder_token_ids,
"causal_mask": causal_mask,
},
outputs=outputs,
dtype=dtype,
**kwargs,
)

# === Config ===
self.embedding_dim = embedding_dim
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.num_heads = num_heads
self.intermediate_dim = intermediate_dim
self.intermediate_activation = intermediate_activation
self.intermediate_output_index = intermediate_output_index
self.vocabulary_size = vocabulary_size
self.sequence_length = sequence_length

def get_config(self):
config = super().get_config()
config.update(
{
"embedding_dim": self.embedding_dim,
"hidden_dim": self.hidden_dim,
"num_layers": self.num_layers,
"num_heads": self.num_heads,
"intermediate_dim": self.intermediate_dim,
"intermediate_activation": self.intermediate_activation,
"intermediate_output_index": self.intermediate_output_index,
"vocabulary_size": self.vocabulary_size,
"sequence_length": self.sequence_length,
}
)
return config
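
A hypothetical end-to-end sketch of the encoder above (all sizes are illustrative assumptions, not the configuration used by actual SD3 checkpoints). Note that embedding_dim must equal hidden_dim for the residual connections inside the blocks to line up, and that the causal mask input is a single (sequence_length, sequence_length) additive mask shared across the batch.

import numpy as np

encoder = CLIPTextEncoder(
    embedding_dim=64,
    hidden_dim=64,
    num_layers=2,
    num_heads=4,
    intermediate_dim=256,
    intermediate_output_index=-2,  # also expose the second-to-last block's output
    vocabulary_size=1000,
    sequence_length=77,
)
token_ids = np.random.randint(0, 1000, size=(2, 77)).astype("int32")
# One additive causal mask shared across the batch, shape (77, 77).
causal_mask = np.triu(np.full((77, 77), -1e9, dtype="float32"), k=1)
outputs = encoder(
    {"encoder_token_ids": token_ids, "causal_mask": causal_mask}
)
# outputs["encoder_sequence_output"]     -> (2, 77, 64)
# outputs["encoder_pooled_output"]       -> (2, 64)
# outputs["encoder_projection_output"]   -> (2, 64)
# outputs["encoder_intermediate_output"] -> (2, 77, 64)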