examples/offline_distributed_inference_npu.py (3 changes: 1 addition & 2 deletions)

@@ -29,11 +29,10 @@
 # Create a sampling params object.
 sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
 # Create an LLM.
-# TODO (cmq): ray is not supported currently, need some fixes
 llm = LLM(
     model="facebook/opt-125m",
     tensor_parallel_size=2,
-    distributed_executor_backend="mp",
+    distributed_executor_backend="ray",
     trust_remote_code=True,
 )
 
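For context, here is a minimal sketch of how the updated example reads end to end, assuming the usual vLLM offline-inference flow; the prompt list and the generate/print loop below are illustrative additions, not part of this hunk:

from vllm import LLM, SamplingParams

# Illustrative prompts (not part of the patch).
prompts = [
    "Hello, my name is",
    "The future of AI is",
]

# Same construction as in the example above: greedy decoding, tensor
# parallelism across two NPUs, and Ray as the distributed executor backend.
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
llm = LLM(
    model="facebook/opt-125m",
    tensor_parallel_size=2,
    distributed_executor_backend="ray",
    trust_remote_code=True,
)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(f"Prompt: {output.prompt!r}, Generated: {output.outputs[0].text!r}")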
vllm_ascend/attention.py (21 changes: 11 additions & 10 deletions)
@@ -457,9 +457,7 @@ def __init__(
         self.kv_cache_dtype = kv_cache_dtype
         self.sliding_window = sliding_window
         if alibi_slopes is not None:
-            alibi_slopes = torch.tensor(alibi_slopes,
-                                        dtype=torch.float32,
-                                        device="npu")
+            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
         self.alibi_slopes = alibi_slopes
         self.attn_type = attn_type
 
@@ -520,7 +518,7 @@ def forward(
             attn_metadata.sparse_mode = 2
             attention_mask = gen_input_mask(
                 attn_metadata.max_prefill_seq_len, self.sliding_window,
-                num_tokens)
+                num_tokens, query.device)
             attn_metadata.attn_mask = attention_mask
 
         if (self.alibi_slopes is not None
@@ -531,6 +529,7 @@
                 dtype=query.dtype,
                 seq_len=attn_metadata.max_prefill_seq_len,
                 batch_size=num_tokens,
+                device=query.device,
             )
 
         if (len(kv_cache) == 0 or attn_metadata.block_tables is None
@@ -571,7 +570,7 @@ def forward(
             query = query.view(query.shape[0], -1,
                                self.num_heads * self.head_size)
             output = torch.zeros(query.shape,
-                                 device="npu",
+                                 device=query.device,
                                  dtype=query.dtype)
             # TODO (Mengqing Cao): torch_npu.npu_incre_flash_attention
             # support only when `S == 1`, OPTIMIZE ME when prefix caching
@@ -621,7 +620,7 @@ def forward(
         return output
 
 
-def gen_input_mask(seq_len, sliding_window, len):
+def gen_input_mask(seq_len, sliding_window, len, device):
     """
     Generating lower triangular matrix
     """
@@ -630,15 +629,15 @@ def gen_input_mask(seq_len, sliding_window, len):
         global SHARE_MASK_TRIL_PREFIX_CACHE
         if SHARE_MASK_TRIL_PREFIX_CACHE is None:
             SHARE_MASK_TRIL_PREFIX_CACHE = torch.triu(
-                torch.ones(1, 1, 2048, 2048, dtype=bool, device="npu"),
+                torch.ones(1, 1, 2048, 2048, dtype=bool, device=device),
                 diagonal=1,
             )
         attention_mask = SHARE_MASK_TRIL_PREFIX_CACHE
     else:
         global SHARE_MASK_TRIL
         if SHARE_MASK_TRIL is None or SHARE_MASK_TRIL.shape[0] < seq_len:
             SHARE_MASK_TRIL = ~torch.tril(
-                torch.ones(seq_len, seq_len, dtype=bool, device="npu"))
+                torch.ones(seq_len, seq_len, dtype=bool, device=device))
 
         attention_mask = SHARE_MASK_TRIL
         if sliding_window is not None:
@@ -656,8 +655,10 @@ def _make_alibi_bias(
     dtype: torch.dtype,
     seq_len: int,
     batch_size: int,
+    device: torch.device,
 ):
-    bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device)
+    alibi_slopes = alibi_slopes.to(device)
+    bias = torch.arange(seq_len, dtype=dtype, device=device)
     # NOTE(zhuohan): HF uses
     # `bias = bias[None, :].repeat(seq_len, 1)`
     # here. We find that both biases give the same results, but
@@ -674,7 +675,7 @@ def _make_alibi_bias(
         num_heads,
         seq_len,
         padded_len,
-        device=alibi_slopes.device,
+        device=device,
         dtype=dtype,
     )[:, :, :, :seq_len].copy_(bias)
     bias.mul_(alibi_slopes[:, None, None])
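The common thread through these hunks is that mask and bias tensors are now allocated on the device of the incoming query tensor instead of a hard-coded "npu" string. A minimal, self-contained sketch of that pattern, with an illustrative helper name that is not part of the patch:

import torch


def build_causal_mask(query: torch.Tensor) -> torch.Tensor:
    """Build an inverted lower-triangular (causal) mask on the query's device."""
    seq_len = query.shape[0]
    # Allocating with device=query.device keeps the helper usable on CPU,
    # CUDA, or NPU without baking a device string into the attention backend.
    return ~torch.tril(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=query.device))


# Usage: the mask lands on whatever device the query already lives on.
query = torch.randn(8, 16)  # a CPU tensor in this sketch
mask = build_causal_mask(query)
assert mask.device == query.device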