diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py
index 2a4eb520c796..8d9e4193616c 100644
--- a/src/diffusers/models/_modeling_parallel.py
+++ b/src/diffusers/models/_modeling_parallel.py
@@ -90,10 +90,6 @@ def __post_init__(self):
         )
         if self.ring_degree < 1 or self.ulysses_degree < 1:
             raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
-        if self.ring_degree > 1 and self.ulysses_degree > 1:
-            raise ValueError(
-                "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
-            )
         if self.rotate_method != "allgather":
             raise NotImplementedError(
                 f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 0c247b76d039..823d119c9753 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -1025,6 +1025,78 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
     x = _wait_tensor(x)
     return x
 
+
+def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
+    """All-to-all exchange that scatters one tensor dimension across the group while gathering another."""
+    group_world_size = torch.distributed.get_world_size(group)
+
+    if scatter_idx == 2 and gather_idx == 1:
+        # Scatter heads, gather sequence: (B, S_LOCAL, H, D) -> (B, S, H_LOCAL, D).
+        B, S_LOCAL, H, D = x.shape
+        S = S_LOCAL * group_world_size
+        H_LOCAL = H // group_world_size
+
+        # (B, S_LOCAL, H, D) -> (group_world_size, S_LOCAL, B, H_LOCAL, D)
+        x_temp = x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D).transpose(0, 2).contiguous()
+
+        if group_world_size > 1:
+            # Route through the _all_to_all_single helper so contiguity and
+            # waiting on the async handle are dealt with in one place.
+            out = _all_to_all_single(x_temp, group=group)
+        else:
+            out = x_temp
+        # (group_world_size, S_LOCAL, B, H_LOCAL, D) -> (B, S, H_LOCAL, D)
+        out = out.reshape(S, B, H_LOCAL, D).permute(1, 0, 2, 3).contiguous()
+        out = out.reshape(B, S, H_LOCAL, D)
+        return out
+    elif scatter_idx == 1 and gather_idx == 2:
+        # Scatter sequence, gather heads: (B, S, H_LOCAL, D) -> (B, S_LOCAL, H, D).
+        B, S, H_LOCAL, D = x.shape
+        H = H_LOCAL * group_world_size
+        S_LOCAL = S // group_world_size
+
+        # (B, S, H_LOCAL, D) -> (group_world_size, H_LOCAL, S_LOCAL, B, D)
+        x_temp = x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 2, 0, 4).contiguous()
+
+        if group_world_size > 1:
+            output = _all_to_all_single(x_temp, group=group)
+        else:
+            output = x_temp
+        # (group_world_size, H_LOCAL, S_LOCAL, B, D) -> (B, S_LOCAL, H, D)
+        output = output.reshape(H, S_LOCAL, B, D).transpose(0, 2).contiguous()
+        output = output.reshape(B, S_LOCAL, H, D)
+        return output
+    else:
+        raise ValueError("Invalid scatter/gather indices for `_all_to_all_dim_exchange`.")
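+
+
+# Illustrative shape walkthrough for `_all_to_all_dim_exchange` (the numbers are
+# chosen for this comment, not taken from elsewhere in the library): with a group
+# of size 4 and a local tensor of shape (B, S_LOCAL, H, D) = (2, 8, 16, 64),
+# scatter_idx=2 / gather_idx=1 splits the 16 heads into 4 per rank and gathers the
+# sequence, returning (2, 32, 4, 64); scatter_idx=1 / gather_idx=2 is the exact
+# inverse, mapping (2, 32, 4, 64) back to (2, 8, 16, 64). `SeqAllToAllDim` below
+# wraps the exchange in an autograd.Function so backward runs the inverse exchange.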
+class SeqAllToAllDim(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, group, x, scatter_id=2, gather_id=1):
+        ctx.group = group
+        ctx.scatter_id = scatter_id
+        ctx.gather_id = gather_id
+        return _all_to_all_dim_exchange(x, scatter_id, gather_id, group)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # The gradient of an all-to-all is the inverse all-to-all: swap the
+        # scatter and gather dimensions used in the forward pass.
+        grad_input = SeqAllToAllDim.apply(
+            ctx.group,
+            grad_output,
+            ctx.gather_id,  # reversed
+            ctx.scatter_id,  # reversed
+        )
+        return (None, grad_input, None, None)
+
 
 class TemplatedRingAttention(torch.autograd.Function):
     @staticmethod
@@ -1147,7 +1219,7 @@ def backward(
 
         grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))
 
-        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
+        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
 
 
 class TemplatedUlyssesAttention(torch.autograd.Function):
@@ -1244,6 +1316,72 @@ def backward(
         return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
 
 
+def TemplatedUnifiedAttention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor],
+    dropout_p: float,
+    is_causal: bool,
+    scale: Optional[float],
+    enable_gqa: bool,
+    return_lse: bool,
+    forward_op,
+    backward_op,
+    _parallel_config: Optional["ParallelConfig"] = None,
+):
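+    # Unified Ulysses + ring flow. The shape comments assume each rank enters
+    # with (B, S_LOCAL, H, D) sharded over the full context-parallel mesh (an
+    # assumption based on the existing Ulysses path):
+    #   1. all-to-all over the Ulysses group scatters heads and gathers sequence,
+    #      giving (B, S_LOCAL * ulysses_degree, H // ulysses_degree, D);
+    #   2. TemplatedRingAttention runs over the ring group on those head shards;
+    #   3. the inverse all-to-all restores the original (B, S_LOCAL, H, D) layout.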
+    ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
+    ulysses_group = ulysses_mesh.get_group()
+    # Hardcoded for now: scatter heads (dim 2) and gather sequence (dim 1) on the way in.
+    scatter_idx = 2
+    gather_idx = 1
+
+    query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
+    key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
+    value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx)
+    out = TemplatedRingAttention.apply(
+        query,
+        key,
+        value,
+        attn_mask,
+        dropout_p,
+        is_causal,
+        scale,
+        enable_gqa,
+        return_lse,
+        forward_op,
+        backward_op,
+        _parallel_config,
+    )
+    if return_lse:
+        context_layer, lse, *_ = out
+    else:
+        context_layer = out
+    # Based on the forward op implementations, context_layer has shape (B, S, H_LOCAL, D).
+    output = SeqAllToAllDim.apply(
+        ulysses_group,
+        context_layer,
+        gather_idx,  # reversed: scatter the sequence back ...
+        scatter_idx,  # ... and gather the full set of heads
+    )
+    if return_lse:
+        # The LSE from the ring attention ops has shape (B, S, H_LOCAL); give it a
+        # trailing singleton dim so it takes the same reversed exchange as the
+        # output, then drop that dim again.
+        lse = lse.unsqueeze(-1)  # (B, S, H_LOCAL, 1)
+        lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
+        lse = lse.squeeze(-1)
+        return (output, lse)
+    return output
+
+
 def _templated_context_parallel_attention(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -1268,7 +1406,24 @@ def _templated_context_parallel_attention(
         raise ValueError("GQA is not yet supported for templated attention.")
 
-    # TODO: add support for unified attention with ring/ulysses degree both being > 1
-    if _parallel_config.context_parallel_config.ring_degree > 1:
+    if (
+        _parallel_config.context_parallel_config.ring_degree > 1
+        and _parallel_config.context_parallel_config.ulysses_degree > 1
+    ):
+        return TemplatedUnifiedAttention(
+            query,
+            key,
+            value,
+            attn_mask,
+            dropout_p,
+            is_causal,
+            scale,
+            enable_gqa,
+            return_lse,
+            forward_op,
+            backward_op,
+            _parallel_config,
+        )
+    elif _parallel_config.context_parallel_config.ring_degree > 1:
         return TemplatedRingAttention.apply(
             query,
             key,
diff --git a/tests/others/test_unified_sp_attention.py b/tests/others/test_unified_sp_attention.py
new file mode 100644
index 000000000000..4c0621999bd0
--- /dev/null
+++ b/tests/others/test_unified_sp_attention.py
@@ -0,0 +1,128 @@
+import math
+import os
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+from diffusers.models._modeling_parallel import ContextParallelConfig, ParallelConfig
+from diffusers.models.attention_dispatch import TemplatedUnifiedAttention
+
+
+def run(rank, world_size):
+    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
+
+    torch.manual_seed(0)
+
+    B, S, H, D = 2, 8, 4, 16  # small toy problem
+    q = torch.randn(B, S, H, D)
+    k = torch.randn(B, S, H, D)
+    v = torch.randn(B, S, H, D)
+
+    q.requires_grad_(True)
+
+    pc = ParallelConfig(
+        context_parallel_config=ContextParallelConfig(
+            ring_degree=2,
+            ulysses_degree=2,
+        )
+    )
+    pc.context_parallel_config.setup(
+        rank=rank,
+        world_size=world_size,
+        device=torch.device("cpu"),
+        mesh=dist.device_mesh.init_device_mesh(
+            "cpu",
+            (2, 2),
+            mesh_dim_names=("ring", "ulysses"),
+        ),
+    )
+
+    def dummy_forward_op(
+        ctx,
+        q,
+        k,
+        v,
+        attn_mask,
+        dropout_p,
+        is_causal,
+        scale,
+        enable_gqa,
+        return_lse,
+        *,
+        _save_ctx=True,
+        _parallel_config=None,
+    ):
+        # Toy shape-preserving stand-in for attention (no softmax), so the
+        # backward formulas below stay simple and exact.
+        head_scale = math.sqrt(D)
+        attn = (q @ k.transpose(-1, -2)) / head_scale
+        out = attn @ v
+        lse = torch.logsumexp(attn, dim=-1)
+
+        if _save_ctx:
+            ctx.save_for_backward(q, k, v)
+            ctx._cached_qkv = []
+            ctx._cached_iter = 0
+        if not hasattr(ctx, "_cached_qkv"):
+            ctx._cached_qkv = []
+        # Ring attention invokes the forward op once per ring step, so cache the
+        # per-step (q, k, v) for the matching backward calls.
+        ctx._cached_qkv.append((q.detach(), k.detach(), v.detach()))
+
+        return (out, lse) if return_lse else out
+
+    def dummy_backward_op(ctx, grad_out, *args, **kwargs):
+        if not hasattr(ctx, "_cached_qkv"):
+            raise RuntimeError("No cached tensors for backward.")
+        if not hasattr(ctx, "_cached_iter"):
+            ctx._cached_iter = 0
+        if ctx._cached_iter >= len(ctx._cached_qkv):
+            raise RuntimeError("Backward called more times than cached forwards.")
+
+        q, k, v = ctx._cached_qkv[ctx._cached_iter]
+        ctx._cached_iter += 1
+
+        head_scale = math.sqrt(D)
+        attn = (q @ k.transpose(-1, -2)) / head_scale
+
+        grad_v = attn.transpose(-1, -2) @ grad_out
+        grad_attn = grad_out @ v.transpose(-1, -2)
+        grad_q = (grad_attn @ k) / head_scale
+        grad_k = (grad_attn.transpose(-1, -2) @ q) / head_scale
+
+        return grad_q, grad_k, grad_v
+
+    out = TemplatedUnifiedAttention(
+        q,
+        k,
+        v,
+        None,
+        dropout_p=0.0,
+        is_causal=False,
+        scale=None,
+        enable_gqa=False,
+        return_lse=False,
+        forward_op=dummy_forward_op,
+        backward_op=dummy_backward_op,
+        _parallel_config=pc,
+    )
+    print(f"[RANK {rank}] output:", out.shape)
+
+    out.sum().backward()
+    print(f"[RANK {rank}] grad:", q.grad.shape)
+
+    dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+    world_size = 4
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12355"
+    mp.spawn(run, args=(world_size,), nprocs=world_size)
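+
+# Rough expectation when run as `python tests/others/test_unified_sp_attention.py`
+# (an assumption for this sketch, not a captured log): each of the 4 gloo CPU
+# ranks spawned above should print an output of shape torch.Size([2, 8, 4, 16])
+# and a grad of the same shape, i.e. the Ulysses head/sequence round trip
+# preserves the local (B, S, H, D) layout end to end.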