@@ -66,6 +66,25 @@ class AttnBlock : public UnaryBlock {
6666 int64_t in_channels;
6767 bool use_linear;
6868
69+ void init_params (struct ggml_context * ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = " " ) {
70+ auto iter = tensor_storage_map.find (prefix + " proj_out.weight" );
71+ if (iter != tensor_storage_map.end ()) {
72+ if (iter->second .n_dims == 4 && use_linear) {
73+ use_linear = false ;
74+ blocks[" q" ] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1 , 1 });
75+ blocks[" k" ] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1 , 1 });
76+ blocks[" v" ] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1 , 1 });
77+ blocks[" proj_out" ] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1 , 1 });
78+ } else if (iter->second .n_dims == 2 && !use_linear) {
79+ use_linear = true ;
80+ blocks[" q" ] = std::make_shared<Linear>(in_channels, in_channels);
81+ blocks[" k" ] = std::make_shared<Linear>(in_channels, in_channels);
82+ blocks[" v" ] = std::make_shared<Linear>(in_channels, in_channels);
83+ blocks[" proj_out" ] = std::make_shared<Linear>(in_channels, in_channels);
84+ }
85+ }
86+ }
87+
6988public:
7089 AttnBlock (int64_t in_channels, bool use_linear)
7190 : in_channels(in_channels), use_linear(use_linear) {
0 commit comments