-
Notifications
You must be signed in to change notification settings - Fork 306
Add CLIP and T5XXL for StableDiffusionV3 #1790
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 3 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
44783ea
Add `CLIPTokenizer`, `T5XXLTokenizer`, `CLIPTextEncoder` and `T5XXLTe…
james77777778 f4b9b4a
Merge remote-tracking branch 'upstream/keras-hub' into add-sdv3
james77777778 dcf3ec6
Make CLIPTextEncoder as Backbone
james77777778 c789236
Add `T5XXLPreprocessor` and remove `T5XXLTokenizer`
james77777778 7ddf4ec
Use `tf = None` at the top
james77777778 6f38cb4
Replace manual implementation of `CLIPAttention` with `MultiHeadAtten…
james77777778 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| # Copyright 2024 The KerasNLP Authors | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. |
111 changes: 111 additions & 0 deletions
111
keras_nlp/src/models/stable_diffusion_v3/clip_attention.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,111 @@ | ||
| # Copyright 2024 The KerasNLP Authors | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| from keras import layers | ||
| from keras import ops | ||
|
|
||
|
|
||
| class CLIPAttention(layers.Layer): | ||
| def __init__(self, num_heads, hidden_dim, dropout=0.0, **kwargs): | ||
| super().__init__(**kwargs) | ||
| if hidden_dim % num_heads != 0: | ||
| raise ValueError( | ||
| "`hidden_dim` must be divisible by num_heads. " | ||
| f"Received: num_heads={num_heads}, hidden_dim={hidden_dim}" | ||
| ) | ||
| self.num_heads = num_heads | ||
| self.hidden_dim = hidden_dim | ||
| self.dropout = dropout | ||
| self.head_dim = self.hidden_dim // self.num_heads | ||
|
|
||
| self.dropout_layer = layers.Dropout(self.dropout) | ||
| self.scale = self.head_dim**-0.5 | ||
| self.query_dense = layers.Dense( | ||
| units=self.hidden_dim, dtype=self.dtype_policy, name="query" | ||
| ) | ||
| self.key_dense = layers.Dense( | ||
| units=self.hidden_dim, dtype=self.dtype_policy, name="key" | ||
| ) | ||
| self.value_dense = layers.Dense( | ||
| units=self.hidden_dim, dtype=self.dtype_policy, name="value" | ||
| ) | ||
| self.softmax = layers.Softmax(dtype="float32") | ||
| self.output_dense = layers.Dense( | ||
| units=self.hidden_dim, | ||
| dtype=self.dtype_policy, | ||
| name="attention_output", | ||
| ) | ||
|
|
||
| def build(self, input_shape): | ||
| self.query_dense.build(input_shape) | ||
| self.key_dense.build(input_shape) | ||
| self.value_dense.build(input_shape) | ||
| self.output_dense.build([None, None, self.hidden_dim]) | ||
|
|
||
| def compute_output_shape(self, input_shape): | ||
| output_shape = list(input_shape) | ||
| output_shape[-1] = self.hidden_dim | ||
| return output_shape | ||
|
|
||
| def _transpose_for_scores(self, inputs): | ||
| batch_size = ops.shape(inputs)[0] | ||
| inputs = ops.reshape( | ||
| inputs, (batch_size, -1, self.num_heads, self.head_dim) | ||
| ) | ||
| return ops.transpose(inputs, axes=[0, 2, 1, 3]) | ||
|
|
||
| def call(self, x, attention_mask=None, training=None): | ||
| batch_size = ops.shape(x)[0] | ||
| query = self.query_dense(x) | ||
| key = self.key_dense(x) | ||
| value = self.value_dense(x) | ||
| query = self._transpose_for_scores(query) | ||
| key = self._transpose_for_scores(key) | ||
| value = self._transpose_for_scores(value) | ||
|
|
||
| attention_logits = ops.matmul( | ||
| query, ops.transpose(key, axes=[0, 1, 3, 2]) | ||
| ) | ||
| dk = ops.cast(ops.sqrt(self.head_dim), dtype=attention_logits.dtype) | ||
| attention_logits = ops.divide(attention_logits, dk) | ||
|
|
||
| if attention_mask is not None: | ||
| attention_logits = ops.add(attention_logits, attention_mask) | ||
|
|
||
| orig_dtype = attention_logits.dtype | ||
| attention_softmax = self.softmax(attention_logits) | ||
| attention_softmax = ops.cast(attention_softmax, orig_dtype) | ||
|
|
||
| if self.dropout: | ||
| attention_softmax = self.dropout_layer( | ||
| attention_softmax, training=training | ||
| ) | ||
|
|
||
| attention_output = ops.matmul(attention_softmax, value) | ||
| attention_output = ops.transpose(attention_output, axes=[0, 2, 1, 3]) | ||
| attention_output = ops.reshape( | ||
| attention_output, (batch_size, -1, self.hidden_dim) | ||
| ) | ||
| attention_output = self.output_dense(attention_output) | ||
| return attention_output | ||
|
|
||
| def get_config(self): | ||
| config = super().get_config() | ||
| config.update( | ||
| { | ||
| "num_heads": self.num_heads, | ||
| "hidden_dim": self.hidden_dim, | ||
| "dropout": self.dropout, | ||
| } | ||
| ) | ||
| return config |
115 changes: 115 additions & 0 deletions
115
keras_nlp/src/models/stable_diffusion_v3/clip_encoder_block.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,115 @@ | ||
| # Copyright 2024 The KerasNLP Authors | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| from keras import layers | ||
| from keras import ops | ||
|
|
||
| from keras_nlp.src.models.stable_diffusion_v3.clip_attention import ( | ||
| CLIPAttention, | ||
| ) | ||
|
|
||
|
|
||
| def quick_gelu(x): | ||
| return x * ops.sigmoid(1.702 * x) | ||
|
|
||
|
|
||
| class CLIPEncoderBlock(layers.Layer): | ||
| def __init__( | ||
| self, | ||
| hidden_dim, | ||
| num_heads, | ||
| intermediate_dim, | ||
| intermediate_activation="quick_gelu", | ||
| **kwargs, | ||
| ): | ||
| super().__init__(**kwargs) | ||
| self.hidden_dim = hidden_dim | ||
| self.num_heads = num_heads | ||
| self.intermediate_dim = intermediate_dim | ||
| self.intermediate_activation = intermediate_activation | ||
|
|
||
| if intermediate_activation == "quick_gelu": | ||
| intermediate_activation = quick_gelu | ||
|
|
||
| self.layer_norm_1 = layers.LayerNormalization( | ||
| epsilon=0.00001, dtype=self.dtype_policy, name="layer_norm_1" | ||
| ) | ||
| self.attention = CLIPAttention( | ||
| self.num_heads, | ||
| self.hidden_dim, | ||
| dtype=self.dtype_policy, | ||
| name="attention", | ||
| ) | ||
| self.layer_norm_2 = layers.LayerNormalization( | ||
| epsilon=0.00001, dtype=self.dtype_policy, name="layer_norm_2" | ||
| ) | ||
| self.dense_1 = layers.Dense( | ||
| self.intermediate_dim, dtype=self.dtype_policy, name="dense_1" | ||
| ) | ||
| self.activation = layers.Activation( | ||
| intermediate_activation, dtype=self.dtype_policy, name="activation" | ||
| ) | ||
| self.dense_2 = layers.Dense( | ||
| self.hidden_dim, dtype=self.dtype_policy, name="dense_2" | ||
| ) | ||
|
|
||
| def build(self, input_shape): | ||
| self.layer_norm_1.build(input_shape) | ||
| self.attention.build(input_shape) | ||
| self.layer_norm_2.build(input_shape) | ||
| self.dense_1.build(input_shape) | ||
| input_shape = self.dense_1.compute_output_shape(input_shape) | ||
| self.dense_2.build(input_shape) | ||
|
|
||
| def compute_output_shape(self, inputs_shape): | ||
| outputs_shape = list(inputs_shape) | ||
| outputs_shape[-1] = self.hidden_dim | ||
| return outputs_shape | ||
|
|
||
| def _compute_attention(self, x, attention_mask=None, training=None): | ||
| mask = None | ||
| if attention_mask is not None: | ||
| attention_mask = ( | ||
| ops.cast(attention_mask, dtype=x.dtype) | ||
| if attention_mask is not None | ||
| else None | ||
| ) | ||
| mask = attention_mask | ||
| return self.attention(x, attention_mask=mask, training=training) | ||
|
|
||
| def call(self, x, attention_mask=None, training=None): | ||
| residual = x | ||
| x = self.layer_norm_1(x) | ||
| x = self._compute_attention( | ||
| x, attention_mask=attention_mask, training=training | ||
| ) | ||
| x = ops.add(residual, x) | ||
|
|
||
| residual = x | ||
| x = self.dense_1(self.layer_norm_2(residual)) | ||
| x = self.activation(x) | ||
| x = self.dense_2(x) | ||
| x = ops.add(residual, x) | ||
| return x | ||
|
|
||
| def get_config(self): | ||
| config = super().get_config() | ||
| config.update( | ||
| { | ||
| "hidden_dim": self.hidden_dim, | ||
| "num_heads": self.num_heads, | ||
| "intermediate_dim": self.intermediate_dim, | ||
| "intermediate_activation": self.intermediate_activation, | ||
| } | ||
| ) | ||
| return config |
147 changes: 147 additions & 0 deletions
147
keras_nlp/src/models/stable_diffusion_v3/clip_text_encoder.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,147 @@ | ||
| # Copyright 2024 The KerasNLP Authors | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| from keras import layers | ||
| from keras import ops | ||
|
|
||
| from keras_nlp.src.layers.modeling.token_and_position_embedding import ( | ||
| TokenAndPositionEmbedding, | ||
| ) | ||
| from keras_nlp.src.models.backbone import Backbone | ||
| from keras_nlp.src.models.stable_diffusion_v3.clip_encoder_block import ( | ||
| CLIPEncoderBlock, | ||
| ) | ||
|
|
||
|
|
||
| class CLIPTextEncoder(Backbone): | ||
| def __init__( | ||
| self, | ||
| embedding_dim, | ||
| hidden_dim, | ||
| num_layers, | ||
| num_heads, | ||
| intermediate_dim, | ||
| intermediate_activation="quick_gelu", | ||
| intermediate_output_index=None, | ||
| vocabulary_size=49408, | ||
| sequence_length=77, | ||
| dtype=None, | ||
| **kwargs, | ||
| ): | ||
| if ( | ||
| intermediate_output_index is not None | ||
| and intermediate_output_index < 0 | ||
| ): | ||
| intermediate_output_index += num_layers | ||
|
|
||
| # === Layers === | ||
| self.embedding = TokenAndPositionEmbedding( | ||
| vocabulary_size=vocabulary_size, | ||
| sequence_length=sequence_length, | ||
| embedding_dim=embedding_dim, | ||
| dtype=dtype, | ||
| name="embedding", | ||
| ) | ||
| self.encoder_layers = [ | ||
| CLIPEncoderBlock( | ||
| hidden_dim, | ||
| num_heads, | ||
| intermediate_dim, | ||
| intermediate_activation, | ||
| dtype=dtype, | ||
| ) | ||
| for _ in range(num_layers) | ||
| ] | ||
| self.layer_norm = layers.LayerNormalization( | ||
| epsilon=0.00001, dtype=dtype, name="layer_norm" | ||
| ) | ||
| self.text_projection = layers.Dense( | ||
| hidden_dim, | ||
| use_bias=False, | ||
| dtype=dtype, | ||
| name="text_projection", | ||
| ) | ||
|
|
||
| # === Functional Model === | ||
| encoder_token_ids = layers.Input( | ||
| shape=(sequence_length,), dtype="int32", name="encoder_token_ids" | ||
| ) | ||
| causal_mask = layers.Input( | ||
| batch_shape=(sequence_length, sequence_length), name="causal_mask" | ||
| ) | ||
| x = self.embedding(encoder_token_ids) | ||
| encoder_intermediate_output = None | ||
| # Encoder. | ||
| for i, block in enumerate(self.encoder_layers): | ||
| x = block(x, attention_mask=causal_mask) | ||
| if i == intermediate_output_index: | ||
| encoder_intermediate_output = x | ||
| x = self.layer_norm(x) | ||
| encoder_output = x | ||
| if encoder_intermediate_output is not None: | ||
| encoder_intermediate_output = self.layer_norm( | ||
| encoder_intermediate_output | ||
| ) | ||
| # Projection. | ||
| indices = ops.expand_dims( | ||
| ops.cast(ops.argmax(encoder_token_ids, axis=-1), "int32"), axis=-1 | ||
| ) | ||
| pooled_output = ops.take_along_axis(x, indices[:, :, None], axis=1) | ||
| pooled_output = ops.squeeze(pooled_output, axis=1) | ||
| projection_output = self.text_projection(pooled_output) | ||
|
|
||
| outputs = { | ||
| "encoder_sequence_output": encoder_output, | ||
| "encoder_pooled_output": pooled_output, | ||
| "encoder_projection_output": projection_output, | ||
| } | ||
| if intermediate_output_index is not None: | ||
| outputs["encoder_intermediate_output"] = encoder_intermediate_output | ||
|
|
||
| super().__init__( | ||
| inputs={ | ||
| "encoder_token_ids": encoder_token_ids, | ||
| "causal_mask": causal_mask, | ||
| }, | ||
| outputs=outputs, | ||
| dtype=dtype, | ||
| **kwargs, | ||
| ) | ||
|
|
||
| # === Config === | ||
| self.embedding_dim = embedding_dim | ||
| self.hidden_dim = hidden_dim | ||
| self.num_layers = num_layers | ||
| self.num_heads = num_heads | ||
| self.intermediate_dim = intermediate_dim | ||
| self.intermediate_activation = intermediate_activation | ||
| self.intermediate_output_index = intermediate_output_index | ||
| self.vocabulary_size = vocabulary_size | ||
| self.sequence_length = sequence_length | ||
|
|
||
| def get_config(self): | ||
| config = super().get_config() | ||
| config.update( | ||
| { | ||
| "embedding_dim": self.embedding_dim, | ||
| "hidden_dim": self.hidden_dim, | ||
| "num_layers": self.num_layers, | ||
| "num_heads": self.num_heads, | ||
| "intermediate_dim": self.intermediate_dim, | ||
| "intermediate_activation": self.intermediate_activation, | ||
| "intermediate_output_index": self.intermediate_output_index, | ||
| "vocabulary_size": self.vocabulary_size, | ||
| "sequence_length": self.sequence_length, | ||
| } | ||
| ) | ||
| return config | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can land all of this in SD3 folder for now.
But we might eventually want clip in it's own folder here. Usable with SD3. WDYT @divyashreepathihalli ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The implementation will be cleaner if we have a CLIP model. I can try adding it after landing SD3.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@james77777778, CLIP is in KerasCV - it needs to be added to keras_hub still. That would make this implementation cleaner for sure.
Given that this is only using the text encoder part and not the whole clip model. We can land this as is now and clean up once CLIP is added.