Skip to content

Commit c2d8ffc

Browse files
authored
fix: compatibility for models with modified tensor shapes (#951)
1 parent fb748bb commit c2d8ffc

File tree

3 files changed

+36
-1
lines changed

3 files changed

+36
-1
lines changed

common.hpp

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -410,6 +410,22 @@ class SpatialTransformer : public GGMLBlock {
410410
int64_t context_dim = 768; // hidden_size, 1024 for VERSION_SD2
411411
bool use_linear = false;
412412

413+
// Reconciles this block's projection layers with the tensor layout found in
// the model being loaded.
//
// Looks up `prefix + "proj_out.weight"` in `tensor_storage_map`. If the entry
// exists and its dimensionality disagrees with the current `use_linear`
// setting, `use_linear` is flipped and the "proj_in"/"proj_out" entries in
// `blocks` are replaced with the matching layer type. If the entry is absent,
// or its n_dims already agrees with `use_linear`, nothing is changed.
//
// ctx                - not used by this method.
// tensor_storage_map - map from full tensor name to its storage description.
// prefix             - name prefix prepended to "proj_out.weight" for lookup.
//
// NOTE(review): replacing the entries in `blocks` presumably only takes effect
// if this runs before the sub-blocks are initialized; the ggml_extend.hpp hunk
// in this same commit moves init_params ahead of init_blocks — confirm.
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") {
414+
auto iter = tensor_storage_map.find(prefix + "proj_out.weight");
415+
if (iter != tensor_storage_map.end()) {
416+
int64_t inner_dim = n_head * d_head;
417+
// Stored weight is 4-D but the block was configured for Linear: switch to Conv2d.
// The std::pair{1, 1} argument is presumably the kernel size (1x1) — TODO confirm.
if (iter->second.n_dims == 4 && use_linear) {
418+
use_linear = false;
419+
blocks["proj_in"] = std::make_shared<Conv2d>(in_channels, inner_dim, std::pair{1, 1});
420+
blocks["proj_out"] = std::make_shared<Conv2d>(inner_dim, in_channels, std::pair{1, 1});
421+
// Stored weight is 2-D but the block was configured for Conv2d: switch to Linear.
} else if (iter->second.n_dims == 2 && !use_linear) {
422+
use_linear = true;
423+
blocks["proj_in"] = std::make_shared<Linear>(in_channels, inner_dim);
424+
blocks["proj_out"] = std::make_shared<Linear>(inner_dim, in_channels);
425+
}
426+
}
427+
}
428+
413429
public:
414430
SpatialTransformer(int64_t in_channels,
415431
int64_t n_head,

ggml_extend.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1926,8 +1926,8 @@ class GGMLBlock {
19261926
if (prefix.size() > 0) {
19271927
prefix = prefix + ".";
19281928
}
1929-
init_blocks(ctx, tensor_storage_map, prefix);
19301929
init_params(ctx, tensor_storage_map, prefix);
1930+
init_blocks(ctx, tensor_storage_map, prefix);
19311931
}
19321932

19331933
size_t get_params_num() {

vae.hpp

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -66,6 +66,25 @@ class AttnBlock : public UnaryBlock {
6666
int64_t in_channels;
6767
bool use_linear;
6868

69+
// Reconciles this attention block's q/k/v/proj_out layers with the tensor
// layout found in the model being loaded.
//
// Looks up `prefix + "proj_out.weight"` in `tensor_storage_map`. If the entry
// exists and its dimensionality disagrees with the current `use_linear`
// setting, `use_linear` is flipped and the "q", "k", "v" and "proj_out"
// entries in `blocks` are replaced with the matching layer type. Only the
// proj_out weight is inspected; q/k/v are assumed to share its layout.
// If the entry is absent, or already agrees with `use_linear`, nothing changes.
//
// ctx                - not used by this method.
// tensor_storage_map - map from full tensor name to its storage description.
// prefix             - name prefix prepended to "proj_out.weight" for lookup.
//
// NOTE(review): replacing the entries in `blocks` presumably only takes effect
// if this runs before the sub-blocks are initialized; the ggml_extend.hpp hunk
// in this same commit moves init_params ahead of init_blocks — confirm.
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") {
70+
auto iter = tensor_storage_map.find(prefix + "proj_out.weight");
71+
if (iter != tensor_storage_map.end()) {
72+
// Stored weight is 4-D but the block was configured for Linear: switch to Conv2d.
// The std::pair{1, 1} argument is presumably the kernel size (1x1) — TODO confirm.
if (iter->second.n_dims == 4 && use_linear) {
73+
use_linear = false;
74+
blocks["q"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
75+
blocks["k"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
76+
blocks["v"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
77+
blocks["proj_out"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
78+
// Stored weight is 2-D but the block was configured for Conv2d: switch to Linear.
} else if (iter->second.n_dims == 2 && !use_linear) {
79+
use_linear = true;
80+
blocks["q"] = std::make_shared<Linear>(in_channels, in_channels);
81+
blocks["k"] = std::make_shared<Linear>(in_channels, in_channels);
82+
blocks["v"] = std::make_shared<Linear>(in_channels, in_channels);
83+
blocks["proj_out"] = std::make_shared<Linear>(in_channels, in_channels);
84+
}
85+
}
86+
}
87+
6988
public:
7089
AttnBlock(int64_t in_channels, bool use_linear)
7190
: in_channels(in_channels), use_linear(use_linear) {

0 commit comments

Comments
 (0)