8 changes: 4 additions & 4 deletions src/diffusers/models/_modeling_parallel.py
@@ -90,10 +90,10 @@ def __post_init__(self):
)
if self.ring_degree < 1 or self.ulysses_degree < 1:
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
if self.ring_degree > 1 and self.ulysses_degree > 1:
raise ValueError(
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
)
# if self.ring_degree > 1 and self.ulysses_degree > 1:
# raise ValueError(
# "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
# )
if self.rotate_method != "allgather":
raise NotImplementedError(
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
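With the guard above commented out, a config can now request both parallelism styles at once. A minimal sketch of such a config, mirroring the dataclasses touched in this file and the setup used in the new test below (the 2x2 split of a 4-rank world is an illustrative assumption, not part of this diff):

from diffusers.models._modeling_parallel import ContextParallelConfig, ParallelConfig

# Hypothetical usage: both degrees > 1 selects the unified Ulysses-Ring path
# added in attention_dispatch.py below; 2 x 2 assumes a 4-rank process group.
parallel_config = ParallelConfig(
    context_parallel_config=ContextParallelConfig(
        ring_degree=2,     # ring attention across 2 ranks
        ulysses_degree=2,  # Ulysses head scatter across the other 2 ranks
    )
)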
139 changes: 137 additions & 2 deletions src/diffusers/models/attention_dispatch.py
@@ -1025,6 +1025,68 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
x = _wait_tensor(x)
return x


def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
    # Ulysses-style exchange: trade a sharded dimension for a gathered one across the group.
    #   scatter_idx=2, gather_idx=1: (B, S_LOCAL, H, D) -> (B, S, H_LOCAL, D)
    #   scatter_idx=1, gather_idx=2: (B, S, H_LOCAL, D) -> (B, S_LOCAL, H, D)
    group_world_size = torch.distributed.get_world_size(group)

    if scatter_idx == 2 and gather_idx == 1:
        B, S_LOCAL, H, D = x.shape
        S = S_LOCAL * group_world_size
        H_LOCAL = H // group_world_size

        # (B, S_LOCAL, H, D) -> (group_world_size, S_LOCAL, B, H_LOCAL, D)
        x_temp = x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D).transpose(0, 2).contiguous()

        if group_world_size > 1:
            # _all_to_all_single exchanges the leading-dimension chunks between ranks
            # and already waits on the async work before returning.
            out = _all_to_all_single(x_temp, group=group)
        else:
            out = x_temp
        # (group_world_size, S_LOCAL, B, H_LOCAL, D) -> (B, S, H_LOCAL, D)
        out = out.reshape(S, B, H_LOCAL, D).permute(1, 0, 2, 3).contiguous()
        out = out.reshape(B, S, H_LOCAL, D)
        return out
    elif scatter_idx == 1 and gather_idx == 2:
        B, S, H_LOCAL, D = x.shape
        H = H_LOCAL * group_world_size
        S_LOCAL = S // group_world_size

        # (B, S, H_LOCAL, D) -> (group_world_size, H_LOCAL, S_LOCAL, B, D)
        # contiguous() matches the first branch and gives the collective a dense layout.
        x_temp = x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 2, 0, 4).contiguous()

        if group_world_size > 1:
            output = _all_to_all_single(x_temp, group=group)
        else:
            output = x_temp
        # (group_world_size, H_LOCAL, S_LOCAL, B, D) -> (B, S_LOCAL, H, D)
        output = output.reshape(H, S_LOCAL, B, D).transpose(0, 2).contiguous()
        output = output.reshape(B, S_LOCAL, H, D)
        return output
    else:
        raise ValueError("Invalid scatter/gather indices for _all_to_all_dim_exchange.")


class SeqAllToAllDim(torch.autograd.Function):
    """Differentiable wrapper around `_all_to_all_dim_exchange`.

    The backward pass is the same exchange with the scatter and gather dims swapped.
    """

    @staticmethod
    def forward(ctx, group, input, scatter_id=2, gather_id=1):
        ctx.group = group
        ctx.scatter_id = scatter_id
        ctx.gather_id = gather_id
        return _all_to_all_dim_exchange(input, scatter_id, gather_id, group)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = SeqAllToAllDim.apply(
            ctx.group,
            grad_output,
            ctx.gather_id,  # reversed
            ctx.scatter_id,  # reversed
        )
        return (None, grad_input, None, None)



class TemplatedRingAttention(torch.autograd.Function):
@staticmethod
@@ -1147,7 +1209,7 @@ def backward(

grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))

return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None


class TemplatedUlyssesAttention(torch.autograd.Function):
@@ -1244,6 +1306,64 @@ def backward(

return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None

def TemplatedUnifiedAttention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor],
    dropout_p: float,
    is_causal: bool,
    scale: Optional[float],
    enable_gqa: bool,
    return_lse: bool,
    forward_op,
    backward_op,
    _parallel_config: Optional["ParallelConfig"] = None,
):
    # Unified Ulysses + Ring attention: an all-to-all over the Ulysses group scatters heads
    # and gathers the sequence, ring attention runs on the head-sharded tensors, and a
    # reverse all-to-all restores the original sharding.
    ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
    ulysses_group = ulysses_mesh.get_group()
    ring_mesh = _parallel_config.context_parallel_config._ring_mesh
    ring_group = ring_mesh.get_group()
    # Hardcoded for now: scatter heads (dim 2), gather sequence (dim 1).
    scatter_idx = 2
    gather_idx = 1

    query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
    key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
    value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx)
    out = TemplatedRingAttention.apply(
        query,
        key,
        value,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
        enable_gqa,
        return_lse,
        forward_op,
        backward_op,
        _parallel_config,
    )
    if return_lse:
        context_layer, lse, *_ = out
    else:
        context_layer = out
    # Assuming (based on the forward op implementations) context_layer has shape (B, S, H_LOCAL, D).
    output = SeqAllToAllDim.apply(
        ulysses_group,
        context_layer,
        gather_idx,
        scatter_idx,
    )
    if return_lse:
        # Assuming (based on the ring attention forward ops) lse has shape (B, S, H_LOCAL):
        # give it a trailing singleton dim so it can go through the same head-gathering
        # exchange as the output, then drop that dim again. Function.apply does not accept
        # keyword arguments, so the scatter/gather ids are passed positionally.
        lse = lse.unsqueeze(-1)  # (B, S, H_LOCAL, 1)
        lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
        lse = lse.squeeze(-1)
        return (output, lse)
    return output

def _templated_context_parallel_attention(
query: torch.Tensor,
Expand All @@ -1268,7 +1388,22 @@ def _templated_context_parallel_attention(
raise ValueError("GQA is not yet supported for templated attention.")

# TODO: add support for unified attention with ring/ulysses degree both being > 1
if _parallel_config.context_parallel_config.ring_degree > 1:
if (
    _parallel_config.context_parallel_config.ring_degree > 1
    and _parallel_config.context_parallel_config.ulysses_degree > 1
):
return TemplatedUnifiedAttention(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
return_lse,
forward_op,
backward_op,
_parallel_config,
)
elif _parallel_config.context_parallel_config.ring_degree > 1:
return TemplatedRingAttention.apply(
query,
key,
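For reference, a minimal single-process sketch of the shape bookkeeping that `_all_to_all_dim_exchange` performs for scatter_idx=2 / gather_idx=1; the inter-rank `all_to_all_single` exchange is only indicated by a comment, and the W/B/S/H/D values are illustrative assumptions, not taken from the diff:

import torch

W = 2                                # ulysses group size (illustrative)
B, S_LOCAL, H, D = 2, 4, 8, 16       # per-rank batch, local sequence, total heads, head dim
S, H_LOCAL = S_LOCAL * W, H // W

x = torch.randn(B, S_LOCAL, H, D)                          # per-rank input
x = x.reshape(B, S_LOCAL, W, H_LOCAL, D).transpose(0, 2)   # (W, S_LOCAL, B, H_LOCAL, D)
# all_to_all_single would swap the leading W chunks between ranks here, so that
# afterwards dim 0 indexes source ranks, i.e. shards of the full sequence.
x = x.contiguous().reshape(S, B, H_LOCAL, D).permute(1, 0, 2, 3)
print(x.shape)  # torch.Size([2, 8, 4, 16]) == (B, S, H_LOCAL, D)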
129 changes: 129 additions & 0 deletions tests/others/test_unified_sp_attention.py
@@ -0,0 +1,129 @@
import math
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from diffusers.models.attention_dispatch import TemplatedUnifiedAttention

def run(rank, world_size):
dist.init_process_group(
backend="gloo",
rank=rank,
world_size=world_size
)

torch.manual_seed(0)

B, S, H, D = 2, 8, 4, 16 # small toy
q = torch.randn(B, S, H, D)
k = torch.randn(B, S, H, D)
v = torch.randn(B, S, H, D)

q.requires_grad_(True)

from diffusers.models._modeling_parallel import (
ParallelConfig,
ContextParallelConfig
)

pc = ParallelConfig(
context_parallel_config=ContextParallelConfig(
ring_degree=2,
ulysses_degree=2,
)
)

pc.context_parallel_config.setup(
    rank=rank,
    world_size=world_size,
    device=torch.device("cpu"),
    mesh=dist.device_mesh.init_device_mesh(
        "cpu",
        (2, 2),
        mesh_dim_names=("ring", "ulysses"),
    ),
)

def dummy_forward_op(
ctx,
q,
k,
v,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
return_lse,
*,
_save_ctx=True,
_parallel_config=None,
):
head_scale = math.sqrt(D)
attn = (q @ k.transpose(-1, -2)) / head_scale
out = attn @ v
lse = torch.logsumexp(attn, dim=-1)

if _save_ctx:
ctx.save_for_backward(q, k, v)
ctx._cached_qkv = []
ctx._cached_iter = 0

if not hasattr(ctx, "_cached_qkv"):
ctx._cached_qkv = []

ctx._cached_qkv.append((q.detach(), k.detach(), v.detach()))

return (out, lse) if return_lse else out

def dummy_backward_op(ctx, grad_out, *args, **kwargs):
if not hasattr(ctx, "_cached_qkv"):
raise RuntimeError("No cached tensors for backward.")

if not hasattr(ctx, "_cached_iter"):
ctx._cached_iter = 0

if ctx._cached_iter >= len(ctx._cached_qkv):
raise RuntimeError("Backward called more times than cached forwards.")

q, k, v = ctx._cached_qkv[ctx._cached_iter]
ctx._cached_iter += 1

head_scale = math.sqrt(D)
attn = (q @ k.transpose(-1, -2)) / head_scale

grad_v = attn.transpose(-1, -2) @ grad_out
grad_attn = grad_out @ v.transpose(-1, -2)
grad_q = (grad_attn @ k) / head_scale
grad_k = (grad_attn.transpose(-1, -2) @ q) / head_scale

return (
grad_q,
grad_k,
grad_v,
)


out = TemplatedUnifiedAttention(
q, k, v, None,
dropout_p=0.0,
is_causal=False,
scale=None,
enable_gqa=False,
return_lse=False,
forward_op=dummy_forward_op,
backward_op=dummy_backward_op,
_parallel_config=pc,
)

print(f"[RANK {rank}] output:", out.shape)

out.sum().backward()
print(f"[RANK {rank}] grad:", q.grad.shape)

dist.destroy_process_group()

if __name__ == "__main__":
world_size = 4
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
mp.spawn(run, args=(world_size,), nprocs=world_size)
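
Because the script spawns its own 4-process gloo group via mp.spawn, it can be launched directly rather than through torchrun, e.g. (path as added in this PR):

python tests/others/test_unified_sp_attention.py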