diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py
index 1a47d728c2f9..4f78b324a8e2 100644
--- a/src/diffusers/models/attention_flax.py
+++ b/src/diffusers/models/attention_flax.py
@@ -12,10 +12,110 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import functools
+import math
+
 import flax.linen as nn
+import jax
 import jax.numpy as jnp
 
 
+def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
+    """Multi-head dot product attention with a limited number of queries."""
+    num_kv, num_heads, k_features = key.shape[-3:]
+    v_features = value.shape[-1]
+    key_chunk_size = min(key_chunk_size, num_kv)
+    query = query / jnp.sqrt(k_features)
+
+    @functools.partial(jax.checkpoint, prevent_cse=False)
+    def summarize_chunk(query, key, value):
+        attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)
+
+        max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
+        max_score = jax.lax.stop_gradient(max_score)
+        exp_weights = jnp.exp(attn_weights - max_score)
+
+        exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
+        max_score = jnp.einsum("...qhk->...qh", max_score)
+
+        return (exp_values, exp_weights.sum(axis=-1), max_score)
+
+    def chunk_scanner(chunk_idx):
+        # julienne key array
+        key_chunk = jax.lax.dynamic_slice(
+            operand=key,
+            start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0],  # [...,k,h,d]
+            slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features],  # [...,k,h,d]
+        )
+
+        # julienne value array
+        value_chunk = jax.lax.dynamic_slice(
+            operand=value,
+            start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0],  # [...,v,h,d]
+            slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features],  # [...,v,h,d]
+        )
+
+        return summarize_chunk(query, key_chunk, value_chunk)
+
+    chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))
+
+    global_max = jnp.max(chunk_max, axis=0, keepdims=True)
+    max_diffs = jnp.exp(chunk_max - global_max)
+
+    chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
+    chunk_weights *= max_diffs
+
+    all_values = chunk_values.sum(axis=0)
+    all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
+
+    return all_values / all_weights
+
+
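The recombination above relies on the streaming-softmax identity: each chunk's exponentials are computed against a chunk-local max, then rescaled by exp(chunk_max - global_max) before summing. A minimal sketch (not part of the patch; the values are arbitrary) verifying that identity:

    import jax.numpy as jnp

    x = jnp.linspace(-3.0, 5.0, 8)
    chunks = jnp.split(x, 2)  # two "key chunks" of four scores each

    # per-chunk stabilized sums, as summarize_chunk computes them
    maxes = jnp.stack([c.max() for c in chunks])
    exp_sums = jnp.stack([jnp.exp(c - m).sum() for c, m in zip(chunks, maxes)])

    # rescale each chunk by exp(chunk_max - global_max) and combine
    global_max = maxes.max()
    combined = (exp_sums * jnp.exp(maxes - global_max)).sum()

    # identical to the softmax denominator computed in one pass
    assert jnp.allclose(combined, jnp.exp(x - global_max).sum())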
+def jax_memory_efficient_attention(
+    query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
+):
+    r"""
+    Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
+    https://github.com/AminRezaei0x443/memory-efficient-attention
+
+    Args:
+        query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
+        key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
+        value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
+        precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
+            numerical precision for the computation
+        query_chunk_size (`int`, *optional*, defaults to 1024):
+            chunk size to divide the query array; the value must divide query_length equally without remainder
+        key_chunk_size (`int`, *optional*, defaults to 4096):
+            chunk size to divide the key and value arrays; the value must divide key_value_length equally without
+            remainder
+
+    Returns:
+        (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
+    """
+    num_q, num_heads, q_features = query.shape[-3:]
+
+    def chunk_scanner(chunk_idx, _):
+        # julienne query array
+        query_chunk = jax.lax.dynamic_slice(
+            operand=query,
+            start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0],  # [...,q,h,d]
+            slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features],  # [...,q,h,d]
+        )
+
+        return (
+            chunk_idx + query_chunk_size,  # unused; ignore it
+            _query_chunk_attention(
+                query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
+            ),
+        )
+
+    _, res = jax.lax.scan(
+        f=chunk_scanner,
+        init=0,  # start counter
+        xs=None,
+        length=math.ceil(num_q / query_chunk_size),  # stop counter
+    )
+
+    return jnp.concatenate(res, axis=-3)  # fuse the chunked result back together
+
+
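A quick numerical check (separate from the patch; shapes, chunk sizes, and the tolerance are illustrative) that the chunked routine matches naive softmax attention once the patch is applied:

    import jax
    import jax.numpy as jnp
    from diffusers.models.attention_flax import jax_memory_efficient_attention

    q_rng, k_rng, v_rng = jax.random.split(jax.random.PRNGKey(0), 3)
    # (batch, length, heads, depth); both lengths are divisible by the chunk sizes
    q = jax.random.normal(q_rng, (1, 256, 8, 64))
    k = jax.random.normal(k_rng, (1, 256, 8, 64))
    v = jax.random.normal(v_rng, (1, 256, 8, 64))

    # naive reference: materialize the full attention matrix in one shot
    scores = jnp.einsum("bqhd,bkhd->bhqk", q / jnp.sqrt(64), k, precision=jax.lax.Precision.HIGHEST)
    probs = jax.nn.softmax(scores, axis=-1)
    reference = jnp.einsum("bhqk,bkhd->bqhd", probs, v, precision=jax.lax.Precision.HIGHEST)

    chunked = jax_memory_efficient_attention(q, k, v, query_chunk_size=64, key_chunk_size=128)
    assert jnp.allclose(chunked, reference, atol=1e-4)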
 class FlaxAttention(nn.Module):
     r"""
     A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
@@ -29,6 +129,8 @@ class FlaxAttention(nn.Module):
             Hidden states dimension inside each head
         dropout (:obj:`float`, *optional*, defaults to 0.0):
             Dropout rate
+        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
+            enable memory efficient attention https://arxiv.org/abs/2112.05682
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
 
@@ -37,6 +139,7 @@ class FlaxAttention(nn.Module):
     heads: int = 8
     dim_head: int = 64
     dropout: float = 0.0
+    use_memory_efficient_attention: bool = False
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
@@ -77,13 +180,38 @@ def __call__(self, hidden_states, context=None, deterministic=True):
         key_states = self.reshape_heads_to_batch_dim(key_proj)
         value_states = self.reshape_heads_to_batch_dim(value_proj)
 
-        # compute attentions
-        attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
-        attention_scores = attention_scores * self.scale
-        attention_probs = nn.softmax(attention_scores, axis=2)
+        if self.use_memory_efficient_attention:
+            query_states = query_states.transpose(1, 0, 2)
+            key_states = key_states.transpose(1, 0, 2)
+            value_states = value_states.transpose(1, 0, 2)
+
+            # this if statement creates a chunk size for each layer of the unet
+            # the chunk size is equal to the query_length dimension of the deepest layer of the unet
+
+            flatten_latent_dim = query_states.shape[-3]
+            if flatten_latent_dim % 64 == 0:
+                query_chunk_size = int(flatten_latent_dim / 64)
+            elif flatten_latent_dim % 16 == 0:
+                query_chunk_size = int(flatten_latent_dim / 16)
+            elif flatten_latent_dim % 4 == 0:
+                query_chunk_size = int(flatten_latent_dim / 4)
+            else:
+                query_chunk_size = int(flatten_latent_dim)
+
+            hidden_states = jax_memory_efficient_attention(
+                query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
+            )
+
+            hidden_states = hidden_states.transpose(1, 0, 2)
+        else:
+            # compute attentions
+            attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
+            attention_scores = attention_scores * self.scale
+            attention_probs = nn.softmax(attention_scores, axis=2)
+
+            # attend to values
+            hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
 
-        # attend to values
-        hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         hidden_states = self.proj_attn(hidden_states)
         return hidden_states
@@ -108,6 +236,8 @@ class FlaxBasicTransformerBlock(nn.Module):
             Whether to only apply cross attention.
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
+        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
+            enable memory efficient attention https://arxiv.org/abs/2112.05682
     """
     dim: int
     n_heads: int
@@ -115,12 +245,17 @@
     dropout: float = 0.0
     only_cross_attention: bool = False
     dtype: jnp.dtype = jnp.float32
+    use_memory_efficient_attention: bool = False
 
     def setup(self):
         # self attention (or cross_attention if only_cross_attention is True)
-        self.attn1 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+        self.attn1 = FlaxAttention(
+            self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
+        )
         # cross attention
-        self.attn2 = FlaxAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+        self.attn2 = FlaxAttention(
+            self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
+        )
         self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
         self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
@@ -169,6 +304,8 @@ class FlaxTransformer2DModel(nn.Module):
         only_cross_attention (`bool`, defaults to `False`): tbd
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
+        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
+            enable memory efficient attention https://arxiv.org/abs/2112.05682
     """
     in_channels: int
     n_heads: int
@@ -178,6 +315,7 @@
     use_linear_projection: bool = False
     only_cross_attention: bool = False
     dtype: jnp.dtype = jnp.float32
+    use_memory_efficient_attention: bool = False
 
     def setup(self):
         self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
@@ -202,6 +340,7 @@ def setup(self):
                 dropout=self.dropout,
                 only_cross_attention=self.only_cross_attention,
                 dtype=self.dtype,
+                use_memory_efficient_attention=self.use_memory_efficient_attention,
             )
             for _ in range(self.depth)
         ]
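For intuition, here is what the divisibility ladder above selects at each UNet level; the dims assume Stable Diffusion v1 at 512x512, i.e. latents from 64x64 down to 8x8 (a sketch, not part of the patch):

    def pick_query_chunk_size(flatten_latent_dim: int) -> int:
        # mirrors the if/elif ladder in FlaxAttention.__call__
        for divisor in (64, 16, 4):
            if flatten_latent_dim % divisor == 0:
                return flatten_latent_dim // divisor
        return flatten_latent_dim

    for dim in (4096, 1024, 256, 64):  # query_length per level
        print(dim, "->", pick_query_chunk_size(dim))
    # 4096 -> 64, 1024 -> 16, 256 -> 4, 64 -> 1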
diff --git a/src/diffusers/models/unet_2d_blocks_flax.py b/src/diffusers/models/unet_2d_blocks_flax.py
index 8e9690d332c9..b8126c5f5930 100644
--- a/src/diffusers/models/unet_2d_blocks_flax.py
+++ b/src/diffusers/models/unet_2d_blocks_flax.py
@@ -37,6 +37,8 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
             Number of attention heads of each spatial transformer block
         add_downsample (:obj:`bool`, *optional*, defaults to `True`):
             Whether to add downsampling layer before each final output
+        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
+            enable memory efficient attention https://arxiv.org/abs/2112.05682
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
@@ -48,6 +50,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
     add_downsample: bool = True
     use_linear_projection: bool = False
     only_cross_attention: bool = False
+    use_memory_efficient_attention: bool = False
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
@@ -72,6 +75,7 @@ def setup(self):
                 depth=1,
                 use_linear_projection=self.use_linear_projection,
                 only_cross_attention=self.only_cross_attention,
+                use_memory_efficient_attention=self.use_memory_efficient_attention,
                 dtype=self.dtype,
             )
             attentions.append(attn_block)
@@ -172,6 +176,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
             Number of attention heads of each spatial transformer block
         add_upsample (:obj:`bool`, *optional*, defaults to `True`):
             Whether to add upsampling layer before each final output
+        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
+            enable memory efficient attention https://arxiv.org/abs/2112.05682
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
@@ -184,6 +190,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
     add_upsample: bool = True
     use_linear_projection: bool = False
     only_cross_attention: bool = False
+    use_memory_efficient_attention: bool = False
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
@@ -209,6 +216,7 @@ def setup(self):
                 depth=1,
                 use_linear_projection=self.use_linear_projection,
                 only_cross_attention=self.only_cross_attention,
+                use_memory_efficient_attention=self.use_memory_efficient_attention,
                 dtype=self.dtype,
             )
             attentions.append(attn_block)
@@ -311,6 +319,8 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
             Number of attention blocks layers
         attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
             Number of attention heads of each spatial transformer block
+        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
+            enable memory efficient attention https://arxiv.org/abs/2112.05682
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
@@ -319,6 +329,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
     num_layers: int = 1
     attn_num_head_channels: int = 1
     use_linear_projection: bool = False
+    use_memory_efficient_attention: bool = False
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
@@ -341,6 +352,7 @@ def setup(self):
                 d_head=self.in_channels // self.attn_num_head_channels,
                 depth=1,
                 use_linear_projection=self.use_linear_projection,
+                use_memory_efficient_attention=self.use_memory_efficient_attention,
                 dtype=self.dtype,
             )
             attentions.append(attn_block)
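These blocks only thread the new flag into their inner FlaxTransformer2DModel. A minimal instantiation sketch (channel counts, head count, and shapes are arbitrary assumptions, not values from the patch):

    import jax
    import jax.numpy as jnp
    from diffusers.models.unet_2d_blocks_flax import FlaxCrossAttnDownBlock2D

    block = FlaxCrossAttnDownBlock2D(
        in_channels=32,
        out_channels=32,
        attn_num_head_channels=8,
        use_memory_efficient_attention=True,  # forwarded to every attention layer
    )
    params = block.init(
        jax.random.PRNGKey(0),
        jnp.zeros((1, 16, 16, 32)),  # hidden states, NHWC
        jnp.zeros((1, 32)),  # time embedding
        jnp.zeros((1, 77, 768)),  # encoder hidden states (text)
    )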
diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py
index 812ca079db38..3c2f4a88ab7f 100644
--- a/src/diffusers/models/unet_2d_condition_flax.py
+++ b/src/diffusers/models/unet_2d_condition_flax.py
@@ -88,6 +88,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
             Whether to flip the sin to cos in the time embedding.
         freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
+            enable memory efficient attention https://arxiv.org/abs/2112.05682
     """
 
@@ -111,6 +113,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
     dtype: jnp.dtype = jnp.float32
     flip_sin_to_cos: bool = True
     freq_shift: int = 0
+    use_memory_efficient_attention: bool = False
 
     def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
         # init input tensors
@@ -169,6 +172,7 @@ def setup(self):
                     add_downsample=not is_final_block,
                     use_linear_projection=self.use_linear_projection,
                     only_cross_attention=only_cross_attention[i],
+                    use_memory_efficient_attention=self.use_memory_efficient_attention,
                     dtype=self.dtype,
                 )
             else:
@@ -190,6 +194,7 @@ def setup(self):
             dropout=self.dropout,
             attn_num_head_channels=attention_head_dim[-1],
             use_linear_projection=self.use_linear_projection,
+            use_memory_efficient_attention=self.use_memory_efficient_attention,
             dtype=self.dtype,
         )
 
@@ -217,6 +222,7 @@ def setup(self):
                     dropout=self.dropout,
                     use_linear_projection=self.use_linear_projection,
                     only_cross_attention=only_cross_attention[i],
+                    use_memory_efficient_attention=self.use_memory_efficient_attention,
                     dtype=self.dtype,
                 )
             else:
diff --git a/src/diffusers/pipelines/pipeline_flax_utils.py b/src/diffusers/pipelines/pipeline_flax_utils.py
index 9d91ff757799..6ab0b80ee655 100644
--- a/src/diffusers/pipelines/pipeline_flax_utils.py
+++ b/src/diffusers/pipelines/pipeline_flax_utils.py
@@ -296,6 +296,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
         from_pt = kwargs.pop("from_pt", False)
+        use_memory_efficient_attention = kwargs.pop("use_memory_efficient_attention", False)
         dtype = kwargs.pop("dtype", None)
 
         # 1. Download the checkpoints and configs
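With the kwarg popped above and forwarded to each Flax sub-model's load_method, enabling the feature from a pipeline looks like this (mirroring the values used in the test below):

    import jax.numpy as jnp
    from diffusers import FlaxStableDiffusionPipeline

    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        revision="bf16",
        dtype=jnp.bfloat16,
        use_memory_efficient_attention=True,  # new kwarg introduced by this patch
    )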
@@ -451,7 +452,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 loaded_sub_model = cached_folder
 
                 if issubclass(class_obj, FlaxModelMixin):
-                    loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
+                    loaded_sub_model, loaded_params = load_method(
+                        loadable_folder,
+                        from_pt=from_pt,
+                        use_memory_efficient_attention=use_memory_efficient_attention,
+                        dtype=dtype,
+                    )
                     params[name] = loaded_params
                 elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
                     if from_pt:
diff --git a/tests/test_pipelines_flax.py b/tests/test_pipelines_flax.py
index a461930f3a83..33f3aa671b3e 100644
--- a/tests/test_pipelines_flax.py
+++ b/tests/test_pipelines_flax.py
@@ -224,3 +224,47 @@ def test_stable_diffusion_v1_4_bfloat_16_ddim(self):
         if jax.device_count() == 8:
             assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.045043945)) < 1e-3
             assert np.abs((np.abs(images, dtype=np.float32).sum() - 2347693.5)) < 5e-1
+
+    def test_jax_memory_efficient_attention(self):
+        prompt = (
+            "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
+            " field, close up, split lighting, cinematic"
+        )
+
+        num_samples = jax.device_count()
+        prompt = num_samples * [prompt]
+        prng_seed = jax.random.split(jax.random.PRNGKey(0), num_samples)
+
+        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
+            "CompVis/stable-diffusion-v1-4",
+            revision="bf16",
+            dtype=jnp.bfloat16,
+            safety_checker=None,
+        )
+
+        params = replicate(params)
+        prompt_ids = pipeline.prepare_inputs(prompt)
+        prompt_ids = shard(prompt_ids)
+        images = pipeline(prompt_ids, params, prng_seed, jit=True).images
+        assert images.shape == (num_samples, 1, 512, 512, 3)
+        slice = images[2, 0, 256, 10:17, 1]
+
+        # With memory efficient attention
+        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
+            "CompVis/stable-diffusion-v1-4",
+            revision="bf16",
+            dtype=jnp.bfloat16,
+            safety_checker=None,
+            use_memory_efficient_attention=True,
+        )
+
+        params = replicate(params)
+        prompt_ids = pipeline.prepare_inputs(prompt)
+        prompt_ids = shard(prompt_ids)
+        images_eff = pipeline(prompt_ids, params, prng_seed, jit=True).images
+        assert images_eff.shape == (num_samples, 1, 512, 512, 3)
+        slice_eff = images_eff[2, 0, 256, 10:17, 1]
+
+        # I checked the results visually and they are very similar. However, I saw that the max diff is `1` and the
+        # `sum` over the 8 images is exactly `256`, which is very suspicious. Testing a random slice for now.
+        assert abs(slice_eff - slice).max() < 1e-2
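Given the suspicion recorded in that comment, a fuller follow-up could compare the complete outputs rather than a single slice; a sketch (the helper and its threshold are illustrative, not part of the patch):

    import numpy as np

    def report_image_diff(images_a, images_b, threshold: float = 0.5) -> None:
        """Print max/mean absolute difference and how many values exceed a threshold."""
        a = np.asarray(images_a, dtype=np.float32)
        b = np.asarray(images_b, dtype=np.float32)
        diff = np.abs(a - b)
        print("max:", diff.max(), "mean:", diff.mean(), f"values > {threshold}:", int((diff > threshold).sum()))

    # e.g. report_image_diff(images, images_eff) with the arrays from the test above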