1 change: 0 additions & 1 deletion cmake/cpu_extension.cmake
@@ -83,7 +83,6 @@ message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
# _C extension
#
set(VLLM_EXT_SRC
"csrc/cpu/activation.cpp"
"csrc/cpu/attention.cpp"
"csrc/cpu/cache.cpp"
"csrc/cpu/layernorm.cpp"
144 changes: 0 additions & 144 deletions csrc/cpu/activation.cpp

This file was deleted.

13 changes: 9 additions & 4 deletions csrc/cpu/pos_encoding.cpp
@@ -73,19 +73,24 @@ void rotary_embedding_impl(
}
};

#pragma omp parallel for
#pragma omp parallel for collapse(2)
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
int64_t pos = positions[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

for (int i = 0; i < num_heads; ++i) {
int64_t pos = positions[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

const int head_idx = i;
const int64_t token_head =
token_idx * query_stride + head_idx * head_size;
compute_loop(token_head, cache_ptr, query);
}
}

#pragma omp parallel for collapse(2)
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (int i = 0; i < num_kv_heads; ++i) {
int64_t pos = positions[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
const int head_idx = i;
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
compute_loop(token_head, cache_ptr, key);
22 changes: 0 additions & 22 deletions csrc/cpu/torch_bindings.cpp
@@ -36,28 +36,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2);

// Activation ops

// Activation function used in SwiGLU.
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("silu_and_mul", torch::kCPU, &silu_and_mul);

// Activation function used in GeGLU with `none` approximation.
ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_and_mul", torch::kCPU, &gelu_and_mul);

// Activation function used in GeGLU with `tanh` approximation.
ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_tanh_and_mul", torch::kCPU, &gelu_tanh_and_mul);

// GELU implementation used in GPT-2.
ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_new", torch::kCPU, &gelu_new);

// Approximate GELU implementation.
ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_fast", torch::kCPU, &gelu_fast);

// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
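The removed registrations cover the activation ops whose CPU path now comes from torch.compile over the native PyTorch implementations instead of hand-written C++ kernels. As a rough illustration (a minimal sketch, not vLLM's exact forward_native code), a SwiGLU-style silu_and_mul can be expressed and compiled like this:

import torch
import torch.nn.functional as F

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # SwiGLU-style gate: split the last dimension in half, apply SiLU to the
    # first half, and use it to gate the second half.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

# torch.compile can lower this elementwise graph to fused CPU code, which is
# why the dedicated C++ kernel and its op registration can be dropped.
compiled_silu_and_mul = torch.compile(silu_and_mul_reference, dynamic=True)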
4 changes: 2 additions & 2 deletions vllm/model_executor/custom_op.py
@@ -32,8 +32,8 @@ def forward_xpu(self, *args, **kwargs):
raise NotImplementedError

def forward_cpu(self, *args, **kwargs):
# By default, we assume that CPU ops are compatible with CUDA ops.
return self.forward_cuda(*args, **kwargs)
# By default, we assume that CPU ops are generated by torch.compile.
return self.forward_native(*args, **kwargs)

def forward_tpu(self, *args, **kwargs):
# By default, we assume that TPU ops are compatible with the
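With this change, CPU deployments rely on each op's pure-PyTorch forward_native path by default, which the CPU model runner can then wrap with torch.compile (see cpu_model_runner.py below). A minimal sketch of the dispatch pattern, with a simplified stand-in class rather than the real CustomOp:

import torch
import torch.nn as nn

class CustomOpSketch(nn.Module):
    # Simplified stand-in for vllm.model_executor.custom_op.CustomOp,
    # showing only the CPU dispatch path touched by this diff.
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError  # pure-PyTorch implementation

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError  # custom-kernel implementation

    def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
        # New default: use the native implementation on CPU; the model runner
        # may replace forward_native with a compiled version at load time.
        return self.forward_native(x)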
9 changes: 9 additions & 0 deletions vllm/model_executor/layers/layernorm.py
@@ -91,6 +91,15 @@ def forward_xpu(
)
return out

def forward_cpu(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
# Note: the forward_native() with torch.compile has significant
# performance regression.
return self.forward_cuda(x, residual)

def extra_repr(self) -> str:
s = f"hidden_size={self.weight.data.size(0)}"
s += f", eps={self.variance_epsilon}"
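This override keeps RMSNorm on the existing custom-op path: on CPU, forward_cuda should end up calling the rms_norm op that is still registered for torch::kCPU in csrc/cpu/torch_bindings.cpp, since the compiled native path regressed. For reference, a minimal RMSNorm sketch under the usual definition, illustrative rather than vLLM's exact forward_native:

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor,
                       eps: float = 1e-6) -> torch.Tensor:
    # Root-mean-square normalization over the last dimension, followed by a
    # learned per-channel scale.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    x_normed = x.float() * torch.rsqrt(variance + eps)
    return (x_normed * weight.float()).to(x.dtype)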
11 changes: 11 additions & 0 deletions vllm/model_executor/layers/rotary_embedding.py
@@ -255,6 +255,17 @@ def forward_tpu(
if self.use_native2 else self.forward_native)
return forward_fn(positions, query, key, offsets)

def forward_cpu(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Note: the forward_native() with torch.compile has significant
# performance regression.
return self.forward_cuda(positions, query, key, offsets)

def extra_repr(self) -> str:
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
s += f", max_position_embeddings={self.max_position_embeddings}"
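As with layernorm, rotary embedding falls back to the custom-kernel path on CPU. For reference, a common NeoX-style rotary application (illustrative only; vLLM's forward_native additionally handles the interleaved GPT-J-style layout, offsets, and the cos/sin cache lookup):

import torch

def apply_rotary_neox(x: torch.Tensor, cos: torch.Tensor,
                      sin: torch.Tensor) -> torch.Tensor:
    # x: [..., rotary_dim]; cos/sin: [..., rotary_dim // 2].
    # NeoX style rotates the first half of the rotary dim against the second.
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)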
111 changes: 90 additions & 21 deletions vllm/worker/cpu_model_runner.py
@@ -13,7 +13,8 @@
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad

logger = init_logger(__name__)
@@ -78,6 +79,10 @@ def __init__(
# Lazy initialization.
self.model: nn.Module # Set after init_Model

@property
def vocab_size(self) -> int:
return self.model_config.get_vocab_size()

def load_model(self) -> None:
self.model = get_model(
model_config=self.model_config,
@@ -89,6 +94,25 @@ def load_model(self) -> None:
scheduler_config=self.scheduler_config,
cache_config=self.cache_config)

# Apply torch.compile to custom ops
from vllm.model_executor.custom_op import CustomOp
Comment on lines +97 to +98

Collaborator: Why do we need a lazy import here?

Member (Author): For now we only compile the CustomOps, so the class is imported for type identification.

Collaborator: Oh, actually my question was why we import these "lazily".

Member (Author): No special reason. I think importing related classes at local scope will make maintenance more convenient (for example, adding more transformations or moving the procedure elsewhere).

Collaborator: Thanks for the explanation. I personally think it's always good to avoid lazy imports whenever possible, but I agree that it can be a matter of personal preference. I'm OK with keeping it.


def replace_model(model: torch.nn.Module):
for _, m in model.named_children():
if isinstance(m, CustomOp):
m.forward_native = torch.compile(m.forward_native,
dynamic=True,
options={
"fx_graph_cache":
True,
"cpp_wrapper": True,
"dce": True
})
Collaborator: Add return here for clarity?

Comment on lines +100 to +110

Collaborator: Just wondering, can we do this in CustomOp.forward_cpu instead?

Member (Author): I tried this, but it didn't work.

  • The forward actually uses _forward_method, so we would have to replace _forward_method.
  • If we replace _forward_method, torch.compile raises the error "mutable rms_norm.default is not supported with cpp_wrapper". It seems cpp_wrapper is not compatible with RMSNorm.forward_cuda.

Collaborator: I see. This doesn't look aesthetically good to me, but I don't have an alternative solution... 😞

Collaborator: @youkaichao Could you please take a look at this part of the code if you have time? Just wondering if you have any suggestions, as the code doesn't look ideal to me.


replace_model(m)
return

replace_model(self.model)

def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
@@ -134,26 +158,29 @@ def _prepare_prompt(
multi_modal_kwargs_list[k].append(v)

# Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, seq_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0
if self.sliding_window is not None:
start_idx = max(0, seq_len - self.sliding_window)

for i in range(computed_len, seq_len):
if i < start_idx:
slot_mapping.append(_PAD_SLOT_ID)
continue

block_number = block_table[i //
self.block_size] # type: ignore
block_offset = i % self.block_size # type: ignore
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
if seq_group_metadata.block_tables is None:
block_table = []
else:
block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID
# , where start_idx is max(0, seq_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0
if self.sliding_window is not None:
start_idx = max(0, seq_len - self.sliding_window)

for i in range(computed_len, seq_len):
if i < start_idx:
slot_mapping.append(_PAD_SLOT_ID)
continue

block_number = block_table[i //
self.block_size] # type: ignore
block_offset = i % self.block_size # type: ignore
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)

multi_modal_kwargs = {
k: torch.cat(v, dim=0).to(self.device)
@@ -361,3 +388,45 @@ def execute_model(
sampling_metadata=sampling_metadata,
)
return output

@torch.inference_mode()
def profile_run(self) -> None:
Collaborator: What is the purpose of the profiling run? Is the goal to invoke torch.compile for different input shapes, or is it for measuring CPU memory usage?

Member (Author): Yes, to invoke torch.compile for batch size 1 and for other batch sizes.

Collaborator: Why are we profiling for those particular shapes? IIRC, torch.compile supports dynamic shapes unless some advanced features are used (e.g., CUDA graphs).

Member (Author): Because I noticed torch.compile generates different code for batch size 1 and other batch sizes under dynamic mode, so we should invoke them all.

Collaborator: Got it. Thanks for the clarification.


# Warming up the model with batchsize = [1, max_num_seqs,
# max_num_batched_tokens], to generate corresponding operators
# using torch.compile.

model_config = self.model_config
vlm_config = self.vision_language_config

sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
max_num_batched_tokens = min(
self.scheduler_config.max_num_batched_tokens,
self.model_config.max_model_len,
)
max_num_seqs = self.scheduler_config.max_num_seqs

assert self.lora_config is None

# Run the model with the dummy inputs.
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers

for seq_len in [1, max_num_seqs, max_num_batched_tokens]:
if vlm_config:
seq_data, dummy_multi_modal_data = (
MULTIMODAL_REGISTRY.dummy_data_for_profiling(
seq_len, model_config, vlm_config))
else:
seq_data = SequenceData([0] * seq_len)
dummy_multi_modal_data = None
seq = SequenceGroupMetadata(
request_id=str(0),
is_prompt=True,
seq_data={0: seq_data},
sampling_params=sampling_params,
block_tables=None,
lora_request=None,
multi_modal_data=dummy_multi_modal_data,
)
self.execute_model([seq], kv_caches)
return
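A standalone sketch of the warm-up idea discussed in the review thread above: the author observed that with dynamic=True, torch.compile still produces separate code for batch size 1 versus other sizes, so exercising batch size 1 plus larger batches up front avoids compilation stalls on the first real requests. The function and sizes below are illustrative, not vLLM code:

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor,
                       eps: float = 1e-6) -> torch.Tensor:
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

compiled = torch.compile(rms_norm_reference, dynamic=True)
weight = torch.ones(64)

# Trigger compilation for the size-1 specialization and for a generic
# dynamic batch dimension before serving traffic.
for batch_size in (1, 8, 256):
    compiled(torch.randn(batch_size, 64), weight)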