63 changes: 30 additions & 33 deletions cacheflow/models/llama.py
@@ -90,22 +90,21 @@ def __init__(
hidden_act: str,
):
super().__init__()
# TODO: Merge the gate and down linear layers.
self.gate_proj = ColumnParallelLinear(hidden_size, intermediate_size,
bias=False, gather_output=False,
perform_initialization=False)
self.gate_up_proj = ColumnParallelLinear(hidden_size, 2 * intermediate_size,
bias=False, gather_output=False,
perform_initialization=False)
self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
bias=False, input_is_parallel=True,
perform_initialization=False)
self.up_proj = ColumnParallelLinear(hidden_size, intermediate_size,
bias=False, gather_output=False,
perform_initialization=False)
assert hidden_act == 'silu'
self.act_fn = nn.SiLU()

def forward(self, x):
gate, _ = self.gate_proj(x)
up, _ = self.up_proj(x)
gate_up, _ = self.gate_up_proj(x)
gate_up = gate_up.reshape(gate_up.shape[:-1] + (-1, 2))
gate, up = torch.split(gate_up, 1, dim=-1)
gate = gate.squeeze(dim=-1).contiguous()
up = up.squeeze(dim=-1).contiguous()
x = self.act_fn(gate) * up
x, _ = self.down_proj(x)
return x
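A note on the fused split: the reshape to (..., -1, 2) followed by torch.split only recovers the right tensors if the rows of the fused weight are interleaved gate/up, which is how load_weights below stacks them. A minimal sketch of that equivalence, using a plain nn.Linear as a stand-in for ColumnParallelLinear and hypothetical sizes (no tensor parallelism):

import torch
import torch.nn as nn

hidden, inter = 8, 16                     # hypothetical sizes
gate = nn.Linear(hidden, inter, bias=False)
up = nn.Linear(hidden, inter, bias=False)

# Fused layer whose weight rows interleave gate/up: [g0, u0, g1, u1, ...]
fused = nn.Linear(hidden, 2 * inter, bias=False)
with torch.no_grad():
    fused.weight.copy_(
        torch.stack([gate.weight, up.weight], dim=1).reshape(-1, hidden))

x = torch.randn(4, hidden)
gate_up = fused(x)
gate_up = gate_up.reshape(gate_up.shape[:-1] + (-1, 2))
g, u = torch.split(gate_up, 1, dim=-1)
g, u = g.squeeze(-1), u.squeeze(-1)

assert torch.allclose(g, gate(x), atol=1e-6)   # matches the separate gate_proj
assert torch.allclose(u, up(x), atol=1e-6)     # matches the separate up_proj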
@@ -127,24 +126,9 @@ def __init__(
self.head_dim = hidden_size // self.total_num_heads
self.scaling = self.head_dim ** -0.5

# TODO: Merge the QKV linear layers.
self.q_proj = ColumnParallelLinear(
hidden_size,
self.total_num_heads * self.head_dim,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.k_proj = ColumnParallelLinear(
self.qkv_proj = ColumnParallelLinear(
hidden_size,
self.total_num_heads * self.head_dim,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.v_proj = ColumnParallelLinear(
hidden_size,
self.total_num_heads * self.head_dim,
3 * self.total_num_heads * self.head_dim,
bias=False,
gather_output=False,
perform_initialization=False,
@@ -168,9 +152,12 @@ def forward(
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
q, _ = self.q_proj(hidden_states)
k, _ = self.k_proj(hidden_states)
v, _ = self.v_proj(hidden_states)
qkv, _ = self.qkv_proj(hidden_states)
qkv = qkv.reshape(qkv.shape[:-1] + (-1, 3))
q, k, v = torch.split(qkv, 1, dim=-1)
q = q.squeeze(dim=-1).contiguous()
k = k.squeeze(dim=-1).contiguous()
v = v.squeeze(dim=-1).contiguous()

        # Apply rotary embedding.
# TODO: Optimize.
@@ -299,17 +286,27 @@ def forward(
return next_tokens

_column_parallel_weights = ["embed_tokens.weight", "lm_head.weight",
"q_proj.weight", "k_proj.weight",
"v_proj.weight", "gate_proj.weight",
"qkv_proj.weight", "gate_proj.weight",
"up_proj.weight"]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]

def load_weights(self, weights_path: str):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, param in state_dict.items():
loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path,
name)))
if "qkv_proj.weight" in name:
q_weight = np.load(os.path.join(weights_path, name.replace("qkv_proj", "q_proj")))
k_weight = np.load(os.path.join(weights_path, name.replace("qkv_proj", "k_proj")))
v_weight = np.load(os.path.join(weights_path, name.replace("qkv_proj", "v_proj")))
loaded_weight = np.stack([q_weight, k_weight, v_weight]).transpose(1, 0, 2)
loaded_weight = torch.from_numpy(loaded_weight.reshape(-1, loaded_weight.shape[-1]))
elif "gate_up_proj.weight" in name:
gate_weight = np.load(os.path.join(weights_path, name.replace("gate_up_proj", "gate_proj")))
up_weight = np.load(os.path.join(weights_path, name.replace("gate_up_proj", "up_proj")))
loaded_weight = np.stack([gate_weight, up_weight]).transpose(1, 0, 2)
loaded_weight = torch.from_numpy(loaded_weight.reshape(-1, loaded_weight.shape[-1]))
else:
loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path, name)))
for p in self._column_parallel_weights:
if p in name:
shard_size = param.shape[0]
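The stack/transpose/reshape above builds a row-interleaved fused weight ([q0, k0, v0, q1, k1, v1, ...]), which is exactly the layout the forward-pass reshape to (..., -1, 3) undoes. A small NumPy round-trip check with hypothetical shapes, for illustration only:

import numpy as np

hidden, num_heads, head_dim = 8, 2, 4    # hypothetical sizes
H = num_heads * head_dim
q_w = np.random.randn(H, hidden).astype(np.float32)
k_w = np.random.randn(H, hidden).astype(np.float32)
v_w = np.random.randn(H, hidden).astype(np.float32)

# Same recipe as load_weights: (3, H, hidden) -> (H, 3, hidden) -> (3H, hidden)
fused = np.stack([q_w, k_w, v_w]).transpose(1, 0, 2).reshape(-1, hidden)

# Same recipe as forward: view the output dim as (H, 3) and peel q/k/v back out
unfused = fused.reshape(H, 3, hidden)
assert np.allclose(unfused[:, 0], q_w)   # q rows come first in each triple
assert np.allclose(unfused[:, 1], k_w)
assert np.allclose(unfused[:, 2], v_w)

Because the q, k, and v rows for a given output feature sit next to each other, the contiguous per-rank shard taken below (shard_size = param.shape[0]) hands each tensor-parallel rank a matched block of q/k/v features, so the per-rank reshape/split in forward still pairs up the right heads.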
48 changes: 27 additions & 21 deletions cacheflow/models/opt.py
@@ -53,16 +53,9 @@ def __init__(
self.head_dim = embed_dim // total_num_heads
self.scaling = self.head_dim**-0.5

# TODO(woosuk): Fuse the three linear layers into one QKV linear layer.
self.k_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
gather_output=False,
perform_initialization=False)
self.v_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
gather_output=False,
perform_initialization=False)
self.q_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
gather_output=False,
perform_initialization=False)
self.qkv_proj = ColumnParallelLinear(embed_dim, 3 * embed_dim, bias=bias,
gather_output=False,
perform_initialization=False)
self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
input_is_parallel=True,
perform_initialization=False)
@@ -76,16 +69,18 @@ def forward(
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
q, _ = self.q_proj(hidden_states)
k, _ = self.k_proj(hidden_states)
v, _ = self.v_proj(hidden_states)
qkv, _ = self.qkv_proj(hidden_states)
qkv = qkv.reshape(qkv.shape[:-1] + (-1, 3))
q, k, v = torch.split(qkv, 1, dim=-1)
q = q.squeeze(dim=-1).contiguous()
k = k.squeeze(dim=-1).contiguous()
v = v.squeeze(dim=-1).contiguous()
key_cache, value_cache = kv_cache
attn_output = self.attn(
q, k, v, key_cache, value_cache, input_metadata, cache_event)
output, _ = self.out_proj(attn_output)
return output


class OPTDecoderLayer(nn.Module):

def __init__(self, config: OPTConfig):
@@ -263,11 +258,9 @@ def forward(
self.lm_head_weight, hidden_states, input_metadata)
return next_tokens

_column_parallel_weights = ["embed_tokens.weight",
"q_proj.weight", "k_proj.weight",
"v_proj.weight", "fc1.weight"]
_column_parallel_biases = ["q_proj.bias", "k_proj.bias",
"v_proj.bias", "fc1.bias"]
_column_parallel_weights = ["embed_tokens.weight", "qkv_proj.weight",
"fc1.weight"]
_column_parallel_biases = ["qkv_proj.bias", "fc1.bias"]
_row_parallel_weights = ["out_proj.weight", "fc2.weight"]

def load_weights(self, weights_path: str):
@@ -276,8 +269,21 @@ def load_weights(self, weights_path: str):
for name, param in state_dict.items():
if "lm_head_weight" in name:
continue
loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path,
name)))
if "qkv_proj.weight" in name:
q_weight = np.load(os.path.join(weights_path, name.replace("qkv_proj", "q_proj")))
k_weight = np.load(os.path.join(weights_path, name.replace("qkv_proj", "k_proj")))
v_weight = np.load(os.path.join(weights_path, name.replace("qkv_proj", "v_proj")))
loaded_weight = np.stack([q_weight, k_weight, v_weight]).transpose(1, 0, 2)
loaded_weight = torch.from_numpy(loaded_weight.reshape(-1, loaded_weight.shape[-1]))
elif "qkv_proj.bias" in name:
q_bias = np.load(os.path.join(weights_path, name.replace("qkv_proj", "q_proj")))
k_bias = np.load(os.path.join(weights_path, name.replace("qkv_proj", "k_proj")))
v_bias = np.load(os.path.join(weights_path, name.replace("qkv_proj", "v_proj")))
loaded_weight = np.stack([q_bias, k_bias, v_bias]).transpose(1, 0)
loaded_weight = torch.from_numpy(loaded_weight.reshape(-1))
else:
loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path, name)))

for p in (self._column_parallel_weights
+ self._column_parallel_biases):
if p in name:
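The bias branch applies the same interleaving in one dimension: stacking to (3, embed_dim), transposing to (embed_dim, 3), and flattening yields [q0, k0, v0, q1, ...], matching the row order of the fused weight so every output feature keeps its own bias. A minimal check with a hypothetical embed_dim, for illustration only:

import numpy as np

embed_dim = 4                                           # hypothetical size
q_b = np.arange(0, embed_dim, dtype=np.float32)         # [0, 1, 2, 3]
k_b = np.arange(10, 10 + embed_dim, dtype=np.float32)   # [10, 11, 12, 13]
v_b = np.arange(20, 20 + embed_dim, dtype=np.float32)   # [20, 21, 22, 23]

# Same recipe as the qkv_proj.bias branch above
fused_bias = np.stack([q_b, k_b, v_b]).transpose(1, 0).reshape(-1)
print(fused_bias)   # [ 0. 10. 20.  1. 11. 21.  2. 12. 22.  3. 13. 23.]

# The forward-pass reshape to (..., -1, 3) recovers each bias
assert np.allclose(fused_bias.reshape(-1, 3)[:, 0], q_b)
assert np.allclose(fused_bias.reshape(-1, 3)[:, 1], k_b)
assert np.allclose(fused_bias.reshape(-1, 3)[:, 2], v_b)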