diff --git a/README.md b/README.md index af37de2..ffd3188 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,8 @@ pip install jax[tpu] To run: +sudo sh -c "echo always > /sys/kernel/mm/transparent_hugepage/enabled" + ``` python wan_tx.py ``` diff --git a/quantize.py b/quantize.py new file mode 100644 index 0000000..e0ce9df --- /dev/null +++ b/quantize.py @@ -0,0 +1,217 @@ +from torch import nn +import torch +from typing import Tuple, Union + +import jax +import jax.numpy as jnp +import torch +from torch.nn import functional as F +import torchax +import torchax.interop + +class WeightOnlyPerChannelQuantizedLinear(torch.nn.Module): + + def __init__( + self, + in_features, + out_features, + bias=False, + device=None, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + weight = torch.ones( + (out_features, in_features), dtype=torch.int8, device=device + ) + self.register_buffer("weight", weight) + + weight_scaler = torch.ones( + (out_features,), dtype=torch.bfloat16, device=device + ) + self.register_buffer("weight_scaler", weight_scaler) + + self.is_symmetric_weight = True + + if not self.is_symmetric_weight: + zero_point = torch.ones( + (out_features,), dtype=torch.bfloat16, device=device + ) + self.register_buffer("zero_point", zero_point) + else: + self.register_buffer("zero_point", None) + + if bias: + bias_tensor = torch.zeros((out_features, ), dtype=torch.bfloat16, device=device) + self.register_buffer('bias', bias_tensor) + + + # Number of bits of weight tensor + self.n_bit = 8 + + # Quantize activation + self.quantize_activation = True + + # Flag to enable dequantize weight first, then do matmul. Useful for debugging. + self.run_fake_quantize = False + + def _load_quantized_weights(self, w_q, scale, zp=None): + """ + Load weights quantized by 'quantize_tensor'. + """ + self.weight, self.weight_scaler, self.zero_point = load_q_weight_helper( + w_q, scale, zp, block_size=-1 + ) + + def quantize_weight_from_nn_linear(self, weight): + assert weight.dim() == 2, "Expect 2D weight from torch.nn.Linear." + assert weight.shape == ( + self.out_features, + self.in_features, + ), f"Got unexpected weight of shape {weight.shape}, expected weight shape ({self.out_features}, {self.in_features})." + w_q, scale, zp = quantize_tensor( + weight, (1,), self.n_bit, self.is_symmetric_weight, block_size=-1 + ) + self._load_quantized_weights(w_q, scale, zp) + + def forward(self, inputs): + if not self.quantize_activation: + result = F.linear(inputs, self.weight) + result *= self.weight_scaler + if self.bias is not None: + result += self.bias + return result + else: + inputs, act_s, _ = quantize_tensor(inputs, reduce_axis=(2,)) + # We have to call jax because we need to specify the output dtype of dot + # dot(int8, int8)->bf16. + # This semantic cannot be represented in torch. The inferred output dtype + # will be int8 in torch, causing the dot result to overflow. + result = torchax.interop.call_jax( + jax.lax.dot_general, + inputs, + self.weight, + (((2,), (1)), ((), ())), + None, + jnp.bfloat16.dtype, + ) + result = result * self.weight_scaler + if self.quantize_activation: + result = result * act_s + if not self.is_symmetric_weight: + zp_out = torch.einsum("...c,z->...z", inputs, self.zero_point) + result = result - zp_out + return result + +def create_quantized_from_nn_linear( + float_linear: nn.Linear +): + obj = WeightOnlyPerChannelQuantizedLinear( + float_linear.in_features, + float_linear.out_features, + float_linear.bias is not None, + "meta", + ) + obj.quantize_weight_from_nn_linear(float_linear.weight) + if float_linear.bias is not None: + obj.bias = float_linear.bias + return obj + + +EPS = 1e-5 + + +def quantize_tensor( + w: torch.Tensor, + reduce_axis: Union[Tuple[int], int], + n_bit: int = 8, + symmetric: bool = True, + block_size: int = -1, +): + """ + Quantize weight tensor w along 'reduce_axis'. + + Args: + w: weight tensor to be quantized. + reduce_axis: axises along which to quantize. + n_bit: Quantize to n_bit bits. (Use int8 container for n_bits < 8). + symmetric: Whether quantization is symmetric. + block_size: Blocksize for blockwise quantization. -1 for per-channel quant. + + Return: + w_q: Quantized weight in int8 container + scale: scalar for quantized tensor + zero_point: zero_point for quantized tensor, None if symmetric quantization + """ + + assert 0 < n_bit <= 8, "Quantization bits must be between [1, 8]." + if isinstance(reduce_axis, int): + reduce_axis = (reduce_axis,) + + if block_size > 0: + axis = reduce_axis[0] + w_shape = w.shape + assert w_shape[axis] % block_size == 0 + w = w.reshape(w_shape[:axis] + (-1, block_size) + w_shape[axis + 1 :]) + reduce_axis = axis + 1 + + max_int = 2 ** (n_bit - 1) - 1 + min_int = -(2 ** (n_bit - 1)) + if not symmetric: + max_val = w.amax(dim=reduce_axis, keepdim=True) + min_val = w.amin(dim=reduce_axis, keepdim=True) + scales = (max_val - min_val).clamp(min=EPS) / float(max_int - min_int) + zero_point = min_int - min_val / scales + else: + max_val = w.abs().amax(dim=reduce_axis, keepdim=True) + max_val = max_val.clamp(min=EPS) + scales = max_val / max_int + zero_point = 0 + + w = torch.clamp( + torch.round(w * (1.0 / scales) + zero_point), min_int, max_int + ).to(torch.int8) + + return w, scales, zero_point if not symmetric else None + + +def dequantize_tensor(w, scale, zero_point=None): + """Dequantize tensor quantized by quantize_tensor.""" + if zero_point is not None: + return (w - zero_point) * scale + + return w * scale + + +def load_q_weight_helper(w_q, scale, zp=None, block_size=-1): + """Helper function to update the shape of quantized weight to match + what quantized linear layer expects.""" + if block_size < 0: + w_q = w_q.to(torch.int8) + if zp is not None: + zp = (zp * scale).squeeze(-1).to(torch.bfloat16) + scale = scale.squeeze(-1).to(torch.bfloat16) + else: + w_q = w_q.permute(1, 2, 0).to(torch.int8) + if zp is not None: + zp = (zp * scale).transpose(1, 0).squeeze(-1).to(torch.bfloat16) + scale = scale.transpose(1, 0).squeeze(-1).to(torch.bfloat16) + return w_q, scale, zp + + +def quantize_model(float_model): + """Apply quantization to linear layers.""" + + def quantize_nn_mod(float_model): + for name, mod in float_model.named_modules(): + new_mod = None + + if isinstance(mod, torch.nn.Linear): + new_mod = create_quantized_from_nn_linear(mod) + + if new_mod: + setattr(float_model, name, new_mod) + + float_model.apply(quantize_nn_mod) + return float_model diff --git a/run_transformer.py b/run_transformer.py new file mode 100644 index 0000000..117b497 --- /dev/null +++ b/run_transformer.py @@ -0,0 +1,565 @@ +import quantize +import functools +import re +import math +import torch +import torchax +from torchax.ops import ops_registry +import time +import jax +import jax.numpy as jnp + +from jax.experimental.pallas.ops.tpu import splash_attention +from jax.experimental.shard_map import shard_map +from jax.sharding import NamedSharding, PartitionSpec as P + +from diffusers.utils import export_to_video +from diffusers import AutoencoderKLWan, WanPipeline +from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler + +from jax.tree_util import register_pytree_node + +from transformers import modeling_outputs + +from datetime import datetime + +# import torchax.ops.jtorch +import traceback + +#### SETTINGS +# 1.3B +# MODEL_ID = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" +# 14B +MODEL_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers" + +# 384p +# FLOW_SHIFT = 3.0 # 5.0 for 720P, 3.0 for 480P +# WIDTH = 640 +# HEIGHT = 384 +# 480p +# FLOW_SHIFT = 3.0 # 5.0 for 720P, 3.0 for 480P +# WIDTH = 832 +# HEIGHT = 480 +# 720p +FLOW_SHIFT = 5.0 # 5.0 for 720P, 3.0 for 480P +WIDTH = 1280 +HEIGHT = 720 + +# 41 frames +# FRAMES = 41 +# FPS = 8 + +# 81 frames +FRAMES = 81 +FPS = 16 + +# step +NUM_STEP = 50 +# NUM_STEP = 1 + +BQSIZE = 2520 # 2240 # 3024 #2520 +BKVSIZE = 1024 # 2048 # 2304 # 1664 #2048 + +# <--- NEW: Local Attention Window Size Setting ---> +# window_size = (left, right). (128, 0) means each token can attend to itself and the previous 128 tokens. +# Set right=0 to maintain causality for autoregressive models. +# Set to None to use the original full Causal Attention. +WINDOW_SIZE = None + +PROFILE_OUT_PATH = "/tmp/tensorboard" + +#### + + +axis = 'axis' + +# Sharding for tranformers, all the replicated are commented out for speed +transformer_shardings = { +# 'scale_shift_table': (), # (torch.Size([1, 2, 1536]), torch.float32) +# 'patch_embedding.weight': (), # (torch.Size([1536, 16, 1, 2, 2]), torch.bfloat16) +# 'patch_embedding.bias': (), # (torch.Size([1536]), torch.bfloat16) +r'condition_embedder.time_embedder.linear_1.weight': (axis, None), # (torch.Size([1536, 256]), torch.float32) +r'condition_embedder.time_embedder.linear_1.weight_scaler': (axis, ), # (torch.Size([1536, 256]), torch.float32) +r'condition_embedder.time_embedder.linear_1.bias': (axis,), # (torch.Size([1536]), torch.float32) +r'condition_embedder.time_embedder.linear_2.weight': (None, axis), # (torch.Size([1536, 1536]), torch.float32) +r'condition_embedder.time_embedder.linear_2.weight_scaler': (None, ), # (torch.Size([1536, 1536]), torch.float32) +# 'condition_embedder.time_embedder.linear_2.bias': (), # (torch.Size([1536]), torch.float32) +# 'condition_embedder.time_proj.weight': (), # (torch.Size([9216, 1536]), torch.bfloat16) +# 'condition_embedder.time_proj.bias': (), # (torch.Size([9216]), torch.bfloat16) +r'condition_embedder.text_embedder.linear_1.weight': (axis, None), # (torch.Size([1536, 4096]), torch.bfloat16) +r'condition_embedder.text_embedder.linear_1.weight_scaler': (axis, ), # (torch.Size([1536, 4096]), torch.bfloat16) +r'condition_embedder.text_embedder.linear_1.bias': (axis, ), # (torch.Size([1536]), torch.bfloat16) +r'condition_embedder.text_embedder.linear_2.weight': (None, axis), # (torch.Size([1536, 1536]), torch.bfloat16) +r'condition_embedder.text_embedder.linear_2.weight_scaler': (None, ), # (torch.Size([1536, 1536]), torch.bfloat16) +# 'condition_embedder.text_embedder.linear_2.bias': (), # (torch.Size([1536]), torch.bfloat16) +# 'blocks.\d+.scale_shift_table': (), # (torch.Size([1, 6, 1536]), torch.float32) +# 'blocks.\d+.attn1.norm_q.weight': (), # (torch.Size([1536]), torch.bfloat16) +# 'blocks.\d+.attn1.norm_k.weight': (), # (torch.Size([1536]), torch.bfloat16) +r'blocks.\d+.attn1.to_q.weight': (axis, None), # (torch.Size([1536, 1536]), torch.bfloat16) +r'blocks.\d+.attn1.to_q.weight_scaler': (axis, ), # (torch.Size([1536, 1536]), torch.bfloat16) +r'blocks.\d+.attn1.to_q.bias': (axis, ), # (torch.Size([1536]), torch.bfloat16) +r'blocks.\d+.attn1.to_k.weight': (axis, ), # (torch.Size([1536, 1536]), torch.bfloat16) +r'blocks.\d+.attn1.to_k.weight_scaler': (axis, ), # (torch.Size([1536, 1536]), torch.bfloat16) +r'blocks.\d+.attn1.to_k.bias': (axis, ), # (torch.Size([1536]), torch.bfloat16) +r'blocks.\d+.attn1.to_v.weight': (axis, ), # (torch.Size([1536, 1536]), torch.bfloat16) +r'blocks.\d+.attn1.to_v.weight_scaler': (axis, ), # (torch.Size([1536, 1536]), torch.bfloat16) +r'blocks.\d+.attn1.to_v.bias': (axis, ), # (torch.Size([1536]), torch.bfloat16) +# to_out has 2 submodules, the first is the Linear and second is dropout +r'blocks.\d+.attn1.to_out.0.weight': (None, axis), # (torch.Size([1536, 1536]), torch.bfloat16) +r'blocks.\d+.attn1.to_out.0.weight_scaler': (None, ), # (torch.Size([1536, 1536]), torch.bfloat16) +# 'blocks.\d+.attn1.to_out.0.bias': (), # (torch.Size([1536]), torch.bfloat16) +# 'blocks.\d+.attn1.to_out.1.weight': (), # (torch.Size([1536, 1536]), torch.bfloat16) +# 'blocks.\d+.attn1.to_out.1.bias': (), # (torch.Size([1536]), torch.bfloat16) +# 'blocks.\d+.attn2.norm_q.weight': (), # (torch.Size([1536]), torch.bfloat16) +# 'blocks.\d+.attn2.norm_k.weight': (), # (torch.Size([1536]), torch.bfloat16) +r'blocks.\d+.attn2.to_q.weight': (axis, ), # (torch.Size([1536, 1536]), torch.bfloat16) +r'blocks.\d+.attn2.to_q.weight_scaler': (axis, ), # (torch.Size([1536, 1536]), torch.bfloat16) +r'blocks.\d+.attn2.to_q.bias': (axis, ), # (torch.Size([1536]), torch.bfloat16) +r'blocks.\d+.attn2.to_k.weight': (axis, ), # (torch.Size([1536, 1536]), torch.bfloat16) +r'blocks.\d+.attn2.to_k.weight_scaler': (axis, ), # (torch.Size([1536, 1536]), torch.bfloat16) +r'blocks.\d+.attn2.to_k.bias': (axis, ), # (torch.Size([1536]), torch.bfloat16) +r'blocks.\d+.attn2.to_v.weight': (axis, ), # (torch.Size([1536, 1536]), torch.bfloat16) +r'blocks.\d+.attn2.to_v.weight_scaler': (axis, ), # (torch.Size([1536, 1536]), torch.bfloat16) +r'blocks.\d+.attn2.to_v.bias': (axis, ), # (torch.Size([1536]), torch.bfloat16) +r'blocks.\d+.attn2.to_out.0.weight': (None, axis), # (torch.Size([1536, 1536]), torch.bfloat16) +r'blocks.\d+.attn2.to_out.0.weight_scaler': (None, ), # (torch.Size([1536, 1536]), torch.bfloat16) +# 'blocks.\d+.attn2.to_out.0.bias': (), # (torch.Size([1536]), torch.bfloat16) +# 'blocks.\d+.attn2.to_out.1.weight': (), # (torch.Size([1536, 1536]), torch.bfloat16) +# 'blocks.\d+.attn2.to_out.1.bias': (), # (torch.Size([1536]), torch.bfloat16) +# 'blocks.\d+.norm2.weight': (), # (torch.Size([1536]), torch.float32) +# 'blocks.\d+.norm2.bias': (), # (torch.Size([1536]), torch.float32) +r'blocks.\d+.ffn.net.0.proj.weight': (axis,), # (torch.Size([8960, 1536]), torch.bfloat16) +r'blocks.\d+.ffn.net.0.proj.weight_scaler': (axis,), # (torch.Size([8960, 1536]), torch.bfloat16) +r'blocks.\d+.ffn.net.0.proj.bias': (axis, ), # (torch.Size([8960]), torch.bfloat16) +r'blocks.\d+.ffn.net.2.weight': (None, axis), # (torch.Size([1536, 8960]), torch.bfloat16) +r'blocks.\d+.ffn.net.2.weight_scaler': (None, ), # (torch.Size([1536, 8960]), torch.bfloat16) +# 'blocks.\d+.ffn.net.2.bias': (), # (torch.Size([1536]), torch.bfloat16) +# 'proj_out.weight': (), # (torch.Size([64, 1536]), torch.bfloat16) +# 'proj_out.bias': (), # (torch.Size([64]), torch.bfloat16) +} + +text_encoder_shardings = { + 'shared.weight': (axis, ), # (torch.Size([256384, 4096]), torch.bfloat16) + 'encoder.block.*.layer.*.SelfAttention.q.weight': (axis, ), # (torch.Size([4096, 4096]), torch.bfloat16) + 'encoder.block.*.layer.*.SelfAttention.k.weight': (axis, ), # (torch.Size([4096, 4096]), torch.bfloat16) + 'encoder.block.*.layer.*.SelfAttention.v.weight': (axis, ), # (torch.Size([4096, 4096]), torch.bfloat16) + 'encoder.block.*.layer.*.SelfAttention.o.weight': (None, axis), # (torch.Size([4096, 4096]), torch.bfloat16) + # 'encoder.block.*.layer.*.SelfAttention.relative_attention_bias.weight': (), # (torch.Size([32, 64]), torch.bfloat16) + # 'encoder.block.*.layer.*.layer_norm.weight': (), # (torch.Size([4096]), torch.bfloat16) + 'encoder.block.*.layer.*.DenseReluDense.wi_0.weight': (axis, ), # (torch.Size([10240, 4096]), torch.bfloat16) + 'encoder.block.*.layer.*.DenseReluDense.wi_1.weight': (axis, ), # (torch.Size([10240, 4096]), torch.bfloat16) + 'encoder.block.*.layer.*.DenseReluDense.wo.weight': (None, axis), # (torch.Size([4096, 10240]), torch.bfloat16) + # 'encoder.final_layer_norm.weight': (), # (torch.Size([4096]), torch.bfloat16) +} + + +def _shard_weight_dict(weight_dict, sharding_dict, mesh): + result = {} + for k, v in weight_dict.items(): + for target, sharding in sharding_dict.items(): + if re.fullmatch(target, k) is not None: + v.apply_jax_(jax.device_put, NamedSharding(mesh, P(*sharding))) + break + else: + # replicate + v.apply_jax_(jax.device_put, NamedSharding(mesh, P())) + + result[k] = v + return result + + +def flatten_model_output(obj): + return obj.to_tuple(), type(obj) + +def unflatten_model_output(aux, children): + return aux(*children) + +register_pytree_node( + modeling_outputs.BaseModelOutputWithPastAndCrossAttentions, + flatten_model_output, + unflatten_model_output) + +def make_key(name): + return re.sub('\.\d+\.', '.*.', name) + + +def _get_weights_of_linear(module): + + result = {} + + def fn(start_path, module): + if isinstance(module, torch.nn.Linear): + for k, v in module.named_parameters(): + start_path.append(k) + key = '.'.join(start_path) + result[key] = v + start_path.pop() + else: + for name, child in module.named_children(): + start_path.append(name) + fn(start_path, child) + start_path.pop() + fn([], module) + return result + + +def _print_weights(module): + all_buffers = dict(module.named_parameters()) + all_buffers.update(module.named_buffers()) + result = {} + for k, v in all_buffers.items(): + result[make_key(k)] = (v.shape, v.dtype) + print('{') + for k, v in result.items(): + print(f"'{k}': (), # {v}") + print('}') + + +### Splash attention ### + +def _sdpa_reference( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, + enable_gqa=False, +) -> torch.Tensor: + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) + if is_causal: + assert attn_mask is None + temp_mask = torch.ones( + L, S, dtype=torch.bool, device=query.device).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + if enable_gqa: + key = key.repeat_interleave(query.size(-3) // key.size(-3), -3) + value = value.repeat_interleave(query.size(-3) // value.size(-3), -3) + + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + if dropout_p > 0: + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + return attn_weight @ value + +CF_FOR_ALL_REDUCE_AND_ALL_GATHER = ( + " --xla_enable_async_all_reduce=true" + " --xla_enable_async_all_gather=true" + " --xla_tpu_overlap_compute_collective_tc=true" + " --xla_tpu_enable_async_collective_fusion=true" + " --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true" + " --xla_tpu_enable_async_collective_fusion_multiple_steps=true" + " --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true" + " --xla_tpu_decompose_all_gather_einsum=true" + " --xla_tpu_decompose_einsum_reduce_scatter=true" +) +import os +os.environ['LIBTPU_INIT_ARGS'] = CF_FOR_ALL_REDUCE_AND_ALL_GATHER + + + +# <--- MODIFIED: Added window_size parameter to the function signature ---> +def _tpu_splash_attention(query, key, value, env, scale=None, is_causal=False, window_size=None): + # print('TPU flash attention', jax.typeof(query), jax.typeof(key), jax.typeof(value)) + import math + mesh = env._mesh + num_heads = query.shape[1] + + # Debug print to check window_size + # print(f"[DEBUG] _tpu_splash_attention called with window_size={window_size}") + + # The function that will be sharded across devices. + def _attention_on_slices(q, k, v): + import jax.numpy as jnp + # Scale the query tensor. This happens on each device with its slice of data. + scale_factor = 1.0 / math.sqrt(q.shape[-1]) if scale is None else scale + q = q * scale_factor + + # Helper to pad to next multiple + def pad_to_multiple(x, multiple, axis): + seq_len = x.shape[axis] + if seq_len < multiple: + return x, seq_len + pad_len = (multiple - seq_len % multiple) % multiple + if pad_len == 0: + return x, seq_len + pad_width = [(0, 0)] * x.ndim + pad_width[axis] = (0, pad_len) + return jnp.pad(x, pad_width), seq_len + + # This function operates on a single item from the batch. + def kernel_3d(q_3d, k_3d, v_3d): + q_seq_len = q_3d.shape[1] + kv_seq_len = k_3d.shape[1] + num_heads_on_device = q_3d.shape[0] + + # Pad q, k, v to next multiple of BQSIZE/BKVSIZE + q_3d_padded, q_orig_len = pad_to_multiple(q_3d, BQSIZE, axis=1) + k_3d_padded, k_orig_len = pad_to_multiple(k_3d, BKVSIZE, axis=1) + v_3d_padded, v_orig_len = pad_to_multiple(v_3d, BKVSIZE, axis=1) + + padded_q_seq_len = q_3d_padded.shape[1] + padded_kv_seq_len = k_3d_padded.shape[1] + + # ======================= NEW MASK LOGIC ======================= + if window_size is not None: + mask_class = functools.partial(splash_attention.LocalMask, window_size=window_size, offset=0) + else: + mask_class = splash_attention.FullMask + + mask = splash_attention.MultiHeadMask( + [mask_class((padded_q_seq_len, padded_kv_seq_len)) for _ in range(num_heads_on_device)] + ) + # ============================================================= + + block_sizes = splash_attention.BlockSizes( + block_q=min(BQSIZE, padded_q_seq_len), block_kv=min(BKVSIZE, padded_kv_seq_len) + ) + print('===== block ====') + print(block_sizes) + print('===== block ====') + splash_kernel = splash_attention.make_splash_mha( + mask=mask, block_sizes=block_sizes, head_shards=1, q_seq_shards=1 + ) + out = splash_kernel(q_3d_padded, k_3d_padded, v_3d_padded) + # Remove padding if any + return out[:, :q_orig_len, ...] + + # Map the kernel over the batch dimension. + vmapped_kernel = jax.vmap(kernel_3d, in_axes=(0, 0, 0), out_axes=0) + return vmapped_kernel(q, k, v) + + # Determine the partitioning spec based on the number of heads. + if num_heads < mesh.size: + # Replicated case for VAE. All devices get the full tensor. + partition_spec = P() + else: + # Sharded case for Transformer. Split along the heads axis. + partition_spec = P(None, 'axis', None, None) + + # ALWAYS use shard_map. The partition_spec will control the behavior. + sharded_fn = shard_map( + _attention_on_slices, + mesh=mesh, + in_specs=(partition_spec, partition_spec, partition_spec), + out_specs=partition_spec, + check_rep=False, + ) + return sharded_fn(query, key, value) + + +# <--- MODIFIED: Added window_size parameter to the function signature ---> +def scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, + enable_gqa=False, + env=None, + window_size=None, # <--- NEW +) -> torch.Tensor: + # Debug prints to understand what's happening + #print(f"[DEBUG] scaled_dot_product_attention called with:") + #print(f" query.shape={query.shape}") + #print(f" key.shape={key.shape}") + #print(f" value.shape={value.shape}") + #print(f" query.shape[-1]={query.shape[-1]}") + #print(f" window_size={window_size}") + #print(f" env.config.use_tpu_splash_attention={env.config.use_tpu_splash_attention if env else 'None'}") + + # <--- MODIFIED: Disable splash attention for VAE ---> + # VAE typically has different attention patterns, disable splash attention for it + # Check if this is likely VAE attention by looking at the shape + if query.shape[-1] >= 384: # VAE typically has larger hidden dimensions (384) + #print(f"[DEBUG] Using reference implementation (VAE detected)") + # Use reference implementation for VAE + return _sdpa_reference(query, key, value, attn_mask, dropout_p, is_causal, + scale, enable_gqa) + + if env.config.use_tpu_splash_attention: + #print(f"[DEBUG] Using splash attention") + jquery, jkey, jvalue = env.t2j_iso((query, key, value)) + # <--- MODIFIED: Pass window_size to the backend function ---> + res = _tpu_splash_attention(jquery, jkey, jvalue, env, scale=scale, is_causal=is_causal, window_size=window_size) + return env.j2t_iso(res) + + #print(f"[DEBUG] Using reference implementation (fallback)") + return _sdpa_reference(query, key, value, attn_mask, dropout_p, is_causal, + scale, enable_gqa) + +### +def _shard_weight_fsdp(weight_dict, mesh): + result = {} + for k, v in weight_dict.items(): + if len(v.shape) == 2: + v.apply_jax_(jax.device_put, NamedSharding(mesh, P('axis'))) + else: + v.apply_jax_(jax.device_put, NamedSharding(mesh, P())) + result[k] = v + return result + +def main(): + # Set JAX config to enable compilation cache + jax.config.update("jax_compilation_cache_dir", "/dev/shm/jax_cache") + jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) + jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) + jax.config.update("jax_persistent_cache_enable_xla_caches", "xla_gpu_per_fusion_autotune_cache_dir") + + torch.set_default_dtype(torch.bfloat16) + # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers + #model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" + # model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers" + model_id = MODEL_ID + vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16) + # flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P + flow_shift = FLOW_SHIFT + scheduler = UniPCMultistepScheduler(prediction_type='flow_prediction', use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift) + pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) + pipe.scheduler = scheduler + + # print('vae=====') + # _print_weights(pipe.vae) + # print('trans===') + # print(_get_weights_of_linear(pipe.transformer).keys()) + # print('encoder===') + # _print_weights(pipe.text_encoder) + # return + + def _move_module(module): + with jax.default_device('cpu'): + state_dict = module.state_dict() + state_dict = env.to_xla(state_dict) + module.load_state_dict(state_dict, assign=True) + + torchax.enable_globally() + env = torchax.default_env() + mesh = jax.make_mesh((len(jax.devices()),), (axis,)) + env.default_device_or_sharding = NamedSharding(mesh, P()) + + env._mesh = mesh + env.config.use_tpu_splash_attention = True + + # <--- MODIFIED: Override flash attention with custom function, now with window_size ---> + custom_attention = functools.partial( + scaled_dot_product_attention, + env=env, + window_size=WINDOW_SIZE # Inject the global window size setting here + ) + # Workaround for the function lack is_view_op argument + # env.override_op_definition(torch.nn.functional.scaled_dot_product_attention, custom_attention) + op_to_override = torch.nn.functional.scaled_dot_product_attention + op_impl = custom_attention + env._ops[op_to_override] = ops_registry.Operator( + op_to_override, + op_impl, + is_jax_function=False, + is_user_defined=True, + needs_env=False, + is_view_op=False, + ) + + + vae_options = torchax.CompileOptions( + methods_to_compile=['decode'] + ) + _move_module(pipe.vae) + pipe.vae = torchax.compile(pipe.vae) + _move_module(pipe.text_encoder) + pipe.text_encoder = torchax.compile(pipe.text_encoder) + + quantize.quantize_model(pipe.transformer.blocks) + print('Quantization done') + + # the param below is not declared as param or buffer so the module.to('jax') didnt work + _move_module(pipe.transformer) + pipe.transformer.rope.freqs = pipe.transformer.rope.freqs.to('jax') + options = torchax.CompileOptions( + jax_jit_kwargs={'static_argnames': ('return_dict',)} + ) + pipe.transformer = torchax.compile(pipe.transformer, options) + + #pipe.to('jax') + print('Number of devices is:, ', len(jax.devices())) + + + pipe.transformer.params = {k: v.data if isinstance(v, torch.nn.Parameter) else v + for k, v in pipe.transformer.params.items()} + pipe.transformer.params = _shard_weight_fsdp(pipe.transformer.params, + mesh) + pipe.transformer.buffers = _shard_weight_fsdp(pipe.transformer.buffers, + mesh) + pipe.text_encoder.params = _shard_weight_dict(pipe.text_encoder.params, + text_encoder_shardings, + mesh) + pipe.text_encoder.buffers = _shard_weight_dict(pipe.text_encoder.buffers, + text_encoder_shardings, + mesh) + + # NOTE this will effectively replicate vae + pipe.vae.params = _shard_weight_dict(pipe.vae.params, {}, mesh) + pipe.vae.buffers = _shard_weight_dict(pipe.vae.buffers, {}, mesh) + + def move_scheduler(scheduler): + for k, v in scheduler.__dict__.items(): + if isinstance(v, torch.Tensor): + setattr(scheduler, k, v.to('jax')) + + #move_scheduler(pipe.scheduler) + + def module_size(module): + size = 0 + for k, v in module.state_dict().items(): + size += math.prod(v.shape) * v.dtype.itemsize + return size + + for m in dir(pipe): + module = getattr(pipe, m, None) + if isinstance(module, torch.nn.Module): + print(m, module_size(module) / (1024 * 1024 * 1024), 'G') + + + prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." + # prompt = "Drone view of waves crashing against the rugged cliffs along Big Sur's garay point beach.The crashing blue waters create white-tipped waves,while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and greenshrubbery covers the cliffs edge. The steep drop from the road down to the beach is adramatic feat, with the cliff's edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway." + negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + + replicate = NamedSharding(mesh, P()) + + def make_input(): + return (torch.randn((1, 16, 21, 90, 160), device='jax').apply_jax(jax.device_put, replicate), + torch.tensor([1], device='jax').apply_jax(jax.device_put, replicate), + torch.randn((1, 512, 4096), device='jax').apply_jax(jax.device_put, replicate)) + + with mesh, torch.no_grad(): + for i in range(5): + if i == 4: + jax.profiler.start_trace(PROFILE_OUT_PATH) + inputs = make_input() + jax.block_until_ready(inputs[0].jax()) + start = time.perf_counter() + res = pipe.transformer(*inputs, None, return_dict=False, attention_kwargs=None) + res[0].jax().block_until_ready() + end = time.perf_counter() + if i == 4: + jax.profiler.stop_trace() + print(f'Iteration {i}: {end - start:.6f}s') + + print('DONE') + + #print(f'生成视频时长= {(num_frams-1)/fps} - 目前针对1.3B生成5s = (41-1)/8) + + +if __name__ == '__main__': + main() diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 93b11c2..f2c54d0 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -11,11 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import jax from typing import Any, Dict, List, Optional, Tuple +from torch.nn.utils import stateless as torch_stateless + import torch +import torchax +from torchax import interop +from jax.experimental import shard_map import torch.nn.functional as F from torch import nn +from jax.sharding import PartitionSpec as P from ..utils import deprecate, logging from ..utils.torch_utils import maybe_allow_in_graph @@ -1242,10 +1249,27 @@ def __init__( if final_dropout: self.net.append(nn.Dropout(dropout)) + def _forward_single(self, weight, hidden_states): + weight = jax.lax.all_gather(weight, 'axis', axis=0, tiled=True) + tweight = interop.torch_view(weight) + with torch_stateless._reparametrize_module(self, tweight): + for module in self.net: + hidden_states = module(interop.torch_view(hidden_states)) + return hidden_states.jax() + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) - for module in self.net: - hidden_states = module(hidden_states) - return hidden_states + + env = torchax.default_env() + + res = shard_map.shard_map( + self._forward_single, + mesh=env._mesh, + in_specs=(P('axis',), P()), + out_specs=P(), + check_rep=False, + )(interop.jax_view(self.state_dict()), hidden_states.jax()) + return interop.torch_view(res) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 34276a5..63ef65f 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from jax.sharding import PartitionSpec as P import inspect import math from typing import Callable, List, Optional, Tuple, Union @@ -3302,6 +3303,11 @@ def __call__( key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + breakpoint() + + query.shard_(P(None, None, 'axis', None)) + key.shard_(P(None, None, 'axis', None)) + value.shard_(P(None, None, 'axis', None)) if attn.norm_q is not None: query = attn.norm_q(query) @@ -3330,6 +3336,8 @@ def __call__( hidden_states = hidden_states / attn.rescale_output_factor + # replicate everything? + hidden_states.shard(P()) return hidden_states diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index b6e782c..486f8e0 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from jax.sharding import PartitionSpec as P import math from typing import Any, Dict, Optional, Tuple, Union @@ -96,6 +96,7 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) hidden_states = hidden_states.type_as(query) @@ -151,7 +152,7 @@ def forward( ): timestep = self.timesteps_proj(timestep) - time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype + time_embedder_dtype = next(iter(self.time_embedder.state_dict().values())).dtype if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: timestep = timestep.to(time_embedder_dtype) temb = self.time_embedder(timestep).type_as(encoder_hidden_states) @@ -289,6 +290,7 @@ def forward( return hidden_states +# NOTE: this is the one class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): r""" A Transformer model for video-like data used in the Wan model. @@ -396,6 +398,7 @@ def forward( return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + print('transformer', locals()) if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() lora_scale = attention_kwargs.pop("scale", 1.0) @@ -430,15 +433,15 @@ def forward( if encoder_hidden_states_image is not None: encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) - # 4. Transformer blocks - if torch.is_grad_enabled() and self.gradient_checkpointing: - for block in self.blocks: - hidden_states = self._gradient_checkpointing_func( - block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb - ) - else: - for block in self.blocks: - hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + # # 4. Transformer blocks + # if torch.is_grad_enabled() and self.gradient_checkpointing: + # for block in self.blocks: + # hidden_states = self._gradient_checkpointing_func( + # block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb + # ) + # else: + for block in self.blocks: + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) # 5. Output norm, projection & unpatchify shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) diff --git a/test_quantize.py b/test_quantize.py new file mode 100644 index 0000000..ecbe0e5 --- /dev/null +++ b/test_quantize.py @@ -0,0 +1,32 @@ +import jax.numpy as jnp +import torch.nn.functional as F +import time +import torch +import torchax +import jax + +torchax.enable_globally() + + +size = 50000 + + +a = torch.randn((size, size), dtype=torch.bfloat16, device='jax') +b = torch.randn((size, size), dtype=torch.bfloat16, device='jax') + +for i in range(3): + start = time.perf_counter() + jax.lax.dot(a.jax(), b.jax(), preferred_element_type=jnp.bfloat16).block_until_ready() + end = time.perf_counter() + print(i, end - start) + +c = torch.randn((size, size), dtype=torch.int8, device='jax') +d = torch.randn((size, size), dtype=torch.bfloat16, device='jax') + +print(' === int8 ') + +for i in range(3): + start = time.perf_counter() + jax.lax.dot(c.jax(), d.jax(), preferred_element_type=jnp.bfloat16).block_until_ready() + end = time.perf_counter() + print(i, end - start) \ No newline at end of file diff --git a/wan_tx_splash_attn.py b/wan_tx_splash_attn.py index 4df44f5..98dd256 100644 --- a/wan_tx_splash_attn.py +++ b/wan_tx_splash_attn.py @@ -8,10 +8,12 @@ import jax import jax.numpy as jnp import numpy as np +import quantize from jax.experimental.pallas.ops.tpu import splash_attention from jax.experimental.shard_map import shard_map from jax.sharding import NamedSharding, PartitionSpec as P +from torchax import interop # Add JAX VAE imports from flax import nnx @@ -63,7 +65,8 @@ # NUM_STEP = 1 BQSIZE = 2520 # 2240 # 3024 #2520 -BKVSIZE = 2048 # 2304 # 1664 #2048 +BKVSIZE = 1024 # 2048 # 2304 # 1664 #2048 + # <--- NEW: Local Attention Window Size Setting ---> # window_size = (left, right). (128, 0) means each token can attend to itself and the previous 128 tokens. @@ -78,52 +81,65 @@ axis = 'axis' -# Sharding for tranformers, all the replicated are commented out for speed transformer_shardings = { # 'scale_shift_table': (), # (torch.Size([1, 2, 1536]), torch.float32) # 'patch_embedding.weight': (), # (torch.Size([1536, 16, 1, 2, 2]), torch.bfloat16) # 'patch_embedding.bias': (), # (torch.Size([1536]), torch.bfloat16) r'condition_embedder.time_embedder.linear_1.weight': (axis, None), # (torch.Size([1536, 256]), torch.float32) +r'condition_embedder.time_embedder.linear_1.weight_scaler': (axis, ), # (torch.Size([1536, 256]), torch.float32) r'condition_embedder.time_embedder.linear_1.bias': (axis,), # (torch.Size([1536]), torch.float32) r'condition_embedder.time_embedder.linear_2.weight': (None, axis), # (torch.Size([1536, 1536]), torch.float32) +r'condition_embedder.time_embedder.linear_2.weight_scaler': (None, ), # (torch.Size([1536, 1536]), torch.float32) # 'condition_embedder.time_embedder.linear_2.bias': (), # (torch.Size([1536]), torch.float32) # 'condition_embedder.time_proj.weight': (), # (torch.Size([9216, 1536]), torch.bfloat16) # 'condition_embedder.time_proj.bias': (), # (torch.Size([9216]), torch.bfloat16) r'condition_embedder.text_embedder.linear_1.weight': (axis, None), # (torch.Size([1536, 4096]), torch.bfloat16) +r'condition_embedder.text_embedder.linear_1.weight_scaler': (axis, ), # (torch.Size([1536, 4096]), torch.bfloat16) r'condition_embedder.text_embedder.linear_1.bias': (axis, ), # (torch.Size([1536]), torch.bfloat16) r'condition_embedder.text_embedder.linear_2.weight': (None, axis), # (torch.Size([1536, 1536]), torch.bfloat16) +r'condition_embedder.text_embedder.linear_2.weight_scaler': (None, ), # (torch.Size([1536, 1536]), torch.bfloat16) # 'condition_embedder.text_embedder.linear_2.bias': (), # (torch.Size([1536]), torch.bfloat16) # 'blocks.\d+.scale_shift_table': (), # (torch.Size([1, 6, 1536]), torch.float32) # 'blocks.\d+.attn1.norm_q.weight': (), # (torch.Size([1536]), torch.bfloat16) # 'blocks.\d+.attn1.norm_k.weight': (), # (torch.Size([1536]), torch.bfloat16) r'blocks.\d+.attn1.to_q.weight': (axis, None), # (torch.Size([1536, 1536]), torch.bfloat16) +r'blocks.\d+.attn1.to_q.weight_scaler': (axis, ), # (torch.Size([1536, 1536]), torch.bfloat16) r'blocks.\d+.attn1.to_q.bias': (axis, ), # (torch.Size([1536]), torch.bfloat16) r'blocks.\d+.attn1.to_k.weight': (axis, ), # (torch.Size([1536, 1536]), torch.bfloat16) +r'blocks.\d+.attn1.to_k.weight_scaler': (axis, ), # (torch.Size([1536, 1536]), torch.bfloat16) r'blocks.\d+.attn1.to_k.bias': (axis, ), # (torch.Size([1536]), torch.bfloat16) r'blocks.\d+.attn1.to_v.weight': (axis, ), # (torch.Size([1536, 1536]), torch.bfloat16) +r'blocks.\d+.attn1.to_v.weight_scaler': (axis, ), # (torch.Size([1536, 1536]), torch.bfloat16) r'blocks.\d+.attn1.to_v.bias': (axis, ), # (torch.Size([1536]), torch.bfloat16) # to_out has 2 submodules, the first is the Linear and second is dropout r'blocks.\d+.attn1.to_out.0.weight': (None, axis), # (torch.Size([1536, 1536]), torch.bfloat16) +r'blocks.\d+.attn1.to_out.0.weight_scaler': (None, ), # (torch.Size([1536, 1536]), torch.bfloat16) # 'blocks.\d+.attn1.to_out.0.bias': (), # (torch.Size([1536]), torch.bfloat16) # 'blocks.\d+.attn1.to_out.1.weight': (), # (torch.Size([1536, 1536]), torch.bfloat16) # 'blocks.\d+.attn1.to_out.1.bias': (), # (torch.Size([1536]), torch.bfloat16) # 'blocks.\d+.attn2.norm_q.weight': (), # (torch.Size([1536]), torch.bfloat16) # 'blocks.\d+.attn2.norm_k.weight': (), # (torch.Size([1536]), torch.bfloat16) r'blocks.\d+.attn2.to_q.weight': (axis, ), # (torch.Size([1536, 1536]), torch.bfloat16) +r'blocks.\d+.attn2.to_q.weight_scaler': (axis, ), # (torch.Size([1536, 1536]), torch.bfloat16) r'blocks.\d+.attn2.to_q.bias': (axis, ), # (torch.Size([1536]), torch.bfloat16) r'blocks.\d+.attn2.to_k.weight': (axis, ), # (torch.Size([1536, 1536]), torch.bfloat16) +r'blocks.\d+.attn2.to_k.weight_scaler': (axis, ), # (torch.Size([1536, 1536]), torch.bfloat16) r'blocks.\d+.attn2.to_k.bias': (axis, ), # (torch.Size([1536]), torch.bfloat16) r'blocks.\d+.attn2.to_v.weight': (axis, ), # (torch.Size([1536, 1536]), torch.bfloat16) +r'blocks.\d+.attn2.to_v.weight_scaler': (axis, ), # (torch.Size([1536, 1536]), torch.bfloat16) r'blocks.\d+.attn2.to_v.bias': (axis, ), # (torch.Size([1536]), torch.bfloat16) r'blocks.\d+.attn2.to_out.0.weight': (None, axis), # (torch.Size([1536, 1536]), torch.bfloat16) +r'blocks.\d+.attn2.to_out.0.weight_scaler': (None, ), # (torch.Size([1536, 1536]), torch.bfloat16) # 'blocks.\d+.attn2.to_out.0.bias': (), # (torch.Size([1536]), torch.bfloat16) # 'blocks.\d+.attn2.to_out.1.weight': (), # (torch.Size([1536, 1536]), torch.bfloat16) # 'blocks.\d+.attn2.to_out.1.bias': (), # (torch.Size([1536]), torch.bfloat16) # 'blocks.\d+.norm2.weight': (), # (torch.Size([1536]), torch.float32) # 'blocks.\d+.norm2.bias': (), # (torch.Size([1536]), torch.float32) r'blocks.\d+.ffn.net.0.proj.weight': (axis,), # (torch.Size([8960, 1536]), torch.bfloat16) +r'blocks.\d+.ffn.net.0.proj.weight_scaler': (axis,), # (torch.Size([8960, 1536]), torch.bfloat16) r'blocks.\d+.ffn.net.0.proj.bias': (axis, ), # (torch.Size([8960]), torch.bfloat16) r'blocks.\d+.ffn.net.2.weight': (None, axis), # (torch.Size([1536, 8960]), torch.bfloat16) +r'blocks.\d+.ffn.net.2.weight_scaler': (None, ), # (torch.Size([1536, 8960]), torch.bfloat16) # 'blocks.\d+.ffn.net.2.bias': (), # (torch.Size([1536]), torch.bfloat16) # 'proj_out.weight': (), # (torch.Size([64, 1536]), torch.bfloat16) # 'proj_out.bias': (), # (torch.Size([64]), torch.bfloat16) @@ -158,6 +174,18 @@ def _shard_weight_dict(weight_dict, sharding_dict, mesh): result[k] = v return result + +def _shard_weight_fsdp(weight_dict, sh, mesh): + result = {} + for k, v in weight_dict.items(): + if len(v.shape) == 2: + v.apply_jax_(jax.device_put, NamedSharding(mesh, P('axis'))) + else: + v.apply_jax_(jax.device_put, NamedSharding(mesh, P())) + result[k] = v + return result + + def flatten_model_output(obj): return obj.to_tuple(), type(obj) @@ -243,10 +271,22 @@ def _sdpa_reference( attn_weight = torch.dropout(attn_weight, dropout_p, train=True) return attn_weight @ value +CF_FOR_ALL_REDUCE_AND_ALL_GATHER = ( + " --xla_enable_async_all_reduce=true" + " --xla_enable_async_all_gather=true" + " --xla_tpu_overlap_compute_collective_tc=true" + " --xla_tpu_enable_async_collective_fusion=true" + " --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true" + " --xla_tpu_enable_async_collective_fusion_multiple_steps=true" +) +import os +os.environ['LIBTPU_INIT_ARGS'] = CF_FOR_ALL_REDUCE_AND_ALL_GATHER + + # <--- MODIFIED: Added window_size parameter to the function signature ---> def _tpu_splash_attention(query, key, value, env, scale=None, is_causal=False, window_size=None): - import jax + # print('TPU flash attention', jax.typeof(query), jax.typeof(key), jax.typeof(value)) import math mesh = env._mesh num_heads = query.shape[1] @@ -264,6 +304,8 @@ def _attention_on_slices(q, k, v): # Helper to pad to next multiple def pad_to_multiple(x, multiple, axis): seq_len = x.shape[axis] + if seq_len <= multiple: + return x, seq_len pad_len = (multiple - seq_len % multiple) % multiple if pad_len == 0: return x, seq_len @@ -299,6 +341,7 @@ def kernel_3d(q_3d, k_3d, v_3d): block_sizes = splash_attention.BlockSizes( block_q=min(BQSIZE, padded_q_seq_len), block_kv=min(BKVSIZE, padded_kv_seq_len) ) + print(block_sizes) splash_kernel = splash_attention.make_splash_mha( mask=mask, block_sizes=block_sizes, head_shards=1, q_seq_shards=1 ) @@ -502,37 +545,22 @@ def __getitem__(self, key): def __setitem__(self, key, value): setattr(self, key, value) -def to_torch_recursive(x): - import torch - import numpy as np - if 'ArrayImpl' in str(type(x)): - return torch.from_numpy(np.array(x)) - elif isinstance(x, (list, tuple)): - return type(x)(to_torch_recursive(xx) for xx in x) - elif isinstance(x, dict): - return {k: to_torch_recursive(v) for k, v in x.items()} - elif hasattr(x, 'sample'): - sample = to_torch_recursive(x.sample) - if hasattr(x, 'replace'): - return x.replace(sample=sample) - else: - return sample - else: - return x - class VAEProxy: + def __init__(self, vae, vae_cache, dtype, config): self._vae = vae self.vae_cache = vae_cache self.dtype = dtype self.config = config + def __getattr__(self, name): return getattr(self._vae, name) + def decode(self, *args, **kwargs): if 'feat_cache' not in kwargs: kwargs['feat_cache'] = self.vae_cache - out = self._vae.decode(*args, **kwargs) - return to_torch_recursive(out) + out = interop.call_jax(self._vae.decode, *args, **kwargs) + return out def prepare_video_for_export(video): import torch @@ -628,9 +656,7 @@ def main(): # We'll use JAX VAE directly # Temporarily disable torchax to load pipeline components - torchax.disable_globally() - - try: + with torchax.disable_temporarily(): # flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P flow_shift = FLOW_SHIFT scheduler = UniPCMultistepScheduler(prediction_type='flow_prediction', use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift) @@ -638,9 +664,6 @@ def main(): # Load pipeline without VAE to avoid torchax interference pipe = WanPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16, use_safetensors=True) pipe.scheduler = scheduler - finally: - # Re-enable torchax for the rest of the pipeline - torchax.enable_globally() # Replace the VAE in the pipeline with our JAX VAE vae_config = ConfigWrapper( @@ -711,21 +734,26 @@ def _move_module(module): _move_module(pipe.text_encoder) pipe.text_encoder = torchax.compile(pipe.text_encoder) + + #pipe.transformer = quantize.quantize_model(pipe.transformer) # the param below is not declared as param or buffer so the module.to('jax') didnt work _move_module(pipe.transformer) pipe.transformer.rope.freqs = pipe.transformer.rope.freqs.to('jax') options = torchax.CompileOptions( jax_jit_kwargs={'static_argnames': ('return_dict',)} ) + + pipe.transformer = torchax.compile(pipe.transformer, options) #pipe.to('jax') print('Number of devices is:, ', len(jax.devices())) - pipe.transformer.params = _shard_weight_dict(pipe.transformer.params, + #pipe.transformer.params = _shard_weight_dict(pipe.transformer.params, + pipe.transformer.params = _shard_weight_fsdp(pipe.transformer.params, transformer_shardings, mesh) - pipe.transformer.buffers = _shard_weight_dict(pipe.transformer.buffers, + pipe.transformer.buffers = _shard_weight_fsdp(pipe.transformer.buffers, transformer_shardings, mesh) pipe.text_encoder.params = _shard_weight_dict(pipe.text_encoder.params, @@ -794,19 +822,22 @@ def module_size(module): # profile set fewer step and output latent to skip VAE for now # output_type='latent' will skip VAE - #jax.profiler.start_trace(PROFILE_OUT_PATH) - #output = pipe( - # prompt=prompt, - # negative_prompt=negative_prompt, - # height=HEIGHT, - # width=WIDTH, - # num_inference_steps=2, - # num_frames=FRAMES, - # guidance_scale=5.0, - # output_type="latent", - #) - #jax.effects_barrier() - #jax.profiler.stop_trace() + options = jax.profiler.ProfileOptions() + options.advanced_configuration = {"tpu_trace_mode" : "TRACE_COMPUTE_AND_SYNC"} + + jax.profiler.start_trace(PROFILE_OUT_PATH, create_perfetto_link=False, profiler_options=options) + output = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + height=HEIGHT, + width=WIDTH, + num_inference_steps=2, + num_frames=FRAMES, + guidance_scale=5.0, + output_type="latent", + ) + jax.effects_barrier() + jax.profiler.stop_trace() #print("profile done") # Benchmark loop