template <int block_size>
static __global__ void ssm_conv_f32(
-        const float * src0, const float * src1, const float * src2, const float * src3,
-        const int src0_ne0, const int src0_nb1, const int src0_nb2,
-        const int src1_nb0, const int src1_nb1,
-        const int src2_nb1, const int src2_nb2,
-        const int src3_nb1,
+        const float * src0, const float * src1, const float * src2,
+        const int src0_nb1, const int src0_nb2,
+        const int src1_nb0, const int src1_nb1, const int src1_nb2,
+        const int src2_nb1,
        float * dst,
-        const int nc, const int nr, const int n_t, const int n_kv) {
+        const int dst_nb0, const int dst_nb1, const int dst_nb2,
+        const int nc, const int nr, const int n_t, const int n_s) {

//  const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;
@@ -24,136 +24,118 @@ static __global__ void ssm_conv_f32(
    const int ir1 = min(ir0 + dr, nr);
    const int ir  = ir1 - ir0;

-    if (n_kv > 1) {
-        // multiple sequences means it's hard to know when it's the first time a state is read,
-        // so copy them all over to the destination, just to be sure.
-        for (int i3 = 0; i3 < n_kv; ++i3) {
-            float * s0 = (float *) ((char *) src0 + ir0*src0_nb1 + i3*src0_nb2);
-            float * s  = (float *) ((char *) dst + ir0*src2_nb1 + i3*src2_nb2 + nr*n_t*sizeof(float));
-            // can't use memcpy because of d_conv vs d_conv - 1
-            for (int i1 = 0; i1 < ir; ++i1) {
-                for (int i0 = 0; i0 < nc - 1; ++i0) {
-                    // copy s0 to last (d_conv - 1) columns of s
-                    s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)];
-                }
-            }
-        }
-    }
+    // TODO: maybe require src0 to have d_conv columns instead of (d_conv - 1)?
+    // This would avoid having to copy into an intermediate buffer, but the state would be bigger.

-    for (int i2 = 0; i2 < n_t; ++i2) {
-        int32_t * sq = (int32_t *) ((char *) src3 + i2*src3_nb1); // {n_kv, n_tokens}
-        float *   x  = (float *) ((char *) dst + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens}
-        float *   s  = (float *) ((char *) dst + ir0*src2_nb1 + sq[0]*src2_nb2 + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv}
-        float *   s0; // {d_conv - 1, d_inner, n_kv}
-        float *   x0 = (float *) ((char *) src1 + ir0*src1_nb0 + i2*src1_nb1); // {d_inner, n_tokens}
-        float *   c  = (float *) ((char *) src2 + ir0*src2_nb1); // {d_conv, d_inner}
-        int ne0s0;
-
-        // avoid needing to copy the state for the first token
-        if (i2 == 0) {
-            s0 = (float *) ((char *) src0 + ir0*src0_nb1 + sq[0]*src0_nb2); // {d_conv - 1, d_inner, n_kv}
-            ne0s0 = src0_ne0;
-        } else {
-            // the source is the last (d_conv - 1) columns of the destination
-            s0 = s + 1;
-            ne0s0 = nc;
-        }
+    // float * s = (float *) params->wdata + (nc*dr + CACHE_LINE_SIZE_F32) * ith;
+    extern __shared__ float wdata_f32[]; // work buffer for all threads
+    float * s = (float *) wdata_f32 + nc*dr*ith;

-        // d_inner
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        float * s0 = (float *) ((char *) src0 + ir0*src0_nb1) + i3*src0_nb2; // {d_conv, d_inner, n_s}
+
+        // copy the state into working memory
+        // can't use memcpy because (d_conv) != (d_conv - 1)
        for (int i1 = 0; i1 < ir; ++i1) {
-            // shift state left
            for (int i0 = 0; i0 < nc - 1; ++i0) {
-                s[i0 + i1*nc] = s0[i0 + i1*ne0s0];
+                s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)];
            }
-            // insert x on the last column
-            s[(nc - 1) + i1*nc] = x0[i1];
        }

-        // handle copies when there are multiple output states
-        for (int i3 = 1; i3 < n_kv; ++i3) {
-            int32_t seq = sq[i3];
-            if (0 <= seq && seq < n_kv) {
-                float * s1 = s + (seq - sq[0])*nc*nr;
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            float * x  = (float *) ((char *) dst  + ir0*dst_nb0  + i2*dst_nb1  + i3*dst_nb2);  // {d_inner, n_t, n_s}
+            float * x0 = (float *) ((char *) src1 + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s}
+            float * c  = (float *) ((char *) src2 + ir0*src2_nb1); // {d_conv, d_inner}

-                // memcpy(s1, s, nc*ir*sizeof(float));
-                for (int i4 = 0; i4 < nc*ir; i4++) {
-                    s1[i4] = s[i4];
+            // shift state left
+            // memmove(s, s + 1, (nc*ir - 1) * sizeof(float));
+            for (int i4 = 0; i4 < nc*ir - 1; ++i4) {
+                s[i4] = s[i4 + 1];
+            }
+
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                // insert x on the last column
+                s[(nc - 1) + i1*nc] = x0[i1];
+            }
+
+            // it seems a little faster when this is separate from the state shift
+            for (int i1 = 0; i1 < ir; ++i1) {
+                // rowwise dot product
+                // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
+                float sumf = 0.0f;
+                for (int i0 = 0; i0 < nc; ++i0) {
+                    int i = i0 + i1*nc;
+                    sumf += s[i] * c[i];
                }
-            } else {
-                // stop at negative or too big seq_ids
-                break;
+                x[i1] = sumf;
            }
        }

-        // it seems a little faster when this is separate from the state shift
+        // copy the state out of it
        for (int i1 = 0; i1 < ir; ++i1) {
-            // rowwise dot product
-            float sumf = 0.0f;
-            for (int i0 = 0; i0 < nc; ++i0) {
-                int i = i0 + i1*nc;
-                sumf += s[i] * c[i];
+            for (int i0 = 0; i0 < nc - 1; ++i0) {
+                s0[i0 + i1*(nc - 1)] = s[1 + i0 + i1*nc];
            }
-            x[i1] = sumf;
        }
    }
}
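The per-thread slice of the shared work buffer deserves spelling out, since `ith` and `dr` are defined in the part of the kernel the hunk elides. A sketch of the layout, assuming the usual `ith = threadIdx.x` and the ceil-division `dr` that the `ir0`/`ir1` math above implies (illustrative, not part of the commit):

// Layout sketch for wdata_f32.
// Assumed definitions from the elided kernel prologue:
//   const int ith = threadIdx.x;                          // thread index within the block
//   const int dr  = (nr + block_size - 1) / block_size;   // rows per thread, rounded up
// Each thread owns a contiguous nc*dr-float slice:
//   thread 0: wdata_f32[0       .. nc*dr - 1]
//   thread i: wdata_f32[nc*dr*i .. nc*dr*(i+1) - 1]
// Worked example: nc = 4 (d_conv), nr = 96 (d_inner), block_size = 32:
//   dr = (96 + 31)/32 = 3, slice = nc*dr = 12 floats, whole block = 12*32 = 384 floats.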

static void ssm_conv_f32_cuda(
-        const float * src0, const float * src1, const float * src2, const float * src3,
-        const int src0_ne0, const int src0_nb1, const int src0_nb2,
-        const int src1_nb0, const int src1_nb1,
-        const int src2_nb1, const int src2_nb2,
-        const int src3_nb1,
+        const float * src0, const float * src1, const float * src2,
+        const int src0_nb1, const int src0_nb2,
+        const int src1_nb0, const int src1_nb1, const int src1_nb2,
+        const int src2_nb1,
        float * dst,
-        const int nc, const int nr, const int n_t, const int n_kv, cudaStream_t stream) {
+        const int dst_nb0, const int dst_nb1, const int dst_nb2,
+        const int nc, const int nr, const int n_t, const int n_s,
+        cudaStream_t stream) {

    const dim3 block_dims(WARP_SIZE, 1, 1);
    const int nblocks = 1; // TODO
+    const int shmem_size = nc * (nr + WARP_SIZE - 1) * sizeof(float); // TODO

-    ssm_conv_f32<WARP_SIZE><<<nblocks, block_dims, 0, stream>>>(
-        src0, src1, src2, src3,
-        src0_ne0, src0_nb1, src0_nb2,
-        src1_nb0, src1_nb1,
-        src2_nb1, src2_nb2,
-        src3_nb1,
+    ssm_conv_f32<WARP_SIZE><<<nblocks, block_dims, shmem_size, stream>>>(
+        src0, src1, src2,
+        src0_nb1, src0_nb2,
+        src1_nb0, src1_nb1, src1_nb2,
+        src2_nb1,
        dst,
-        nc, nr, n_t, n_kv);
+        dst_nb0, dst_nb1, dst_nb2,
+        nc, nr, n_t, n_s);
}
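One note on the `shmem_size` expression marked TODO: it is a safe upper bound on what the kernel actually touches. A hedged sketch of the bound, plus the standard CUDA runtime opt-in that a large `d_inner` would require (the opt-in is not part of this commit):

// Why nc * (nr + WARP_SIZE - 1) * sizeof(float) suffices (sketch):
//   rows per thread        dr = ceil(nr / WARP_SIZE)
//   total floats touched      = nc * dr * WARP_SIZE
//   and ceil(nr / WARP_SIZE) * WARP_SIZE <= nr + WARP_SIZE - 1,
// so the allocation covers every thread's slice (exact when nr % WARP_SIZE == 1).
//
// If shmem_size ever exceeds the default 48 KB dynamic shared memory limit,
// the launch would additionally need the usual opt-in before the kernel call:
//   cudaFuncSetAttribute(ssm_conv_f32<WARP_SIZE>,
//                        cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);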

void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const struct ggml_tensor * src0 = dst->src[0]; // conv_state
    const struct ggml_tensor * src1 = dst->src[1]; // x
    const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight
-    const struct ggml_tensor * src3 = dst->src[3]; // state_seq

-    const int nc   = src2->ne[0]; // d_conv
-    const int nr   = src0->ne[1]; // d_inner
-    const int n_t  = src1->ne[1]; // n_tokens
-    const int n_kv = src0->ne[2]; // max number of sequences in the batch
+    const int nc  = src2->ne[0]; // d_conv
+    const int nr  = src0->ne[1]; // d_inner
+    const int n_t = src1->ne[1]; // tokens per sequence
+    const int n_s = src0->ne[2]; // number of sequences in the batch

-    GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst));
+    GGML_ASSERT(ggml_are_same_shape(src1, dst));
    GGML_ASSERT(src0->nb[0] == sizeof(float));
    GGML_ASSERT(src1->nb[0] == sizeof(float));
    GGML_ASSERT(src2->nb[0] == sizeof(float));
-    GGML_ASSERT(src3->nb[0] == sizeof(int32_t));
    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
-    // for use with the destination state offset between sequences
-    GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float));

    const float * src0_d = (const float *)src0->data;
    const float * src1_d = (const float *)src1->data;
    const float * src2_d = (const float *)src2->data;
-    const float * src3_d = (const float *)src3->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

-    ssm_conv_f32_cuda(src0_d, src1_d, src2_d, src3_d,
-        src0->ne[0], src0->nb[1], src0->nb[2],
-        src1->nb[0], src1->nb[1],
-        src2->nb[1], src2->nb[2],
-        src3->nb[1],
-        dst_d, nc, nr, n_t, n_kv, stream);
+    ssm_conv_f32_cuda(src0_d, src1_d, src2_d,
+        src0->nb[1], src0->nb[2],
+        src1->nb[0], src1->nb[1], src1->nb[2],
+        src2->nb[1],
+        dst_d,
+        dst->nb[0], dst->nb[1], dst->nb[2],
+        nc, nr, n_t, n_s,
+        stream);
}
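
For readers who want to check the arithmetic, here is a minimal single-sequence CPU sketch of what the new kernel computes per sequence (state copy-in, per-token shift/insert/dot, state copy-out). Contiguous layouts are assumed and all names are hypothetical; this is an illustration, not code from the commit.

// ssm_conv reference sketch (illustrative only; contiguous layouts assumed).
#include <cstdio>
#include <vector>

// state:  {d_conv - 1, d_inner}  rolling input history, updated in place
// x:      {d_inner, n_t}         one column of d_inner inputs per token
// weight: {d_conv, d_inner}      conv1d.weight
// out:    {d_inner, n_t}
static void ssm_conv_ref(std::vector<float> & state, const std::vector<float> & x,
                         const std::vector<float> & weight, std::vector<float> & out,
                         const int d_conv, const int d_inner, const int n_t) {
    const int nc = d_conv;
    std::vector<float> s(nc*d_inner, 0.0f);

    // copy the (d_conv - 1)-wide state into the d_conv-wide work buffer
    for (int i1 = 0; i1 < d_inner; ++i1) {
        for (int i0 = 0; i0 < nc - 1; ++i0) {
            s[1 + i0 + i1*nc] = state[i0 + i1*(nc - 1)];
        }
    }

    for (int t = 0; t < n_t; ++t) {
        // shift the whole buffer left; the stale last column is overwritten below
        for (int i = 0; i < nc*d_inner - 1; ++i) {
            s[i] = s[i + 1];
        }
        // insert the new input on the last column of each row
        for (int i1 = 0; i1 < d_inner; ++i1) {
            s[(nc - 1) + i1*nc] = x[i1 + t*d_inner];
        }
        // row-wise dot product against the conv weights
        for (int i1 = 0; i1 < d_inner; ++i1) {
            float sumf = 0.0f;
            for (int i0 = 0; i0 < nc; ++i0) {
                sumf += s[i0 + i1*nc] * weight[i0 + i1*nc];
            }
            out[i1 + t*d_inner] = sumf;
        }
    }

    // write the updated history back, dropping the oldest column
    for (int i1 = 0; i1 < d_inner; ++i1) {
        for (int i0 = 0; i0 < nc - 1; ++i0) {
            state[i0 + i1*(nc - 1)] = s[1 + i0 + i1*nc];
        }
    }
}

int main() {
    const int d_conv = 4, d_inner = 2, n_t = 3;
    std::vector<float> state((d_conv - 1)*d_inner, 0.0f);
    std::vector<float> x(d_inner*n_t), weight(d_conv*d_inner, 0.25f), out(d_inner*n_t);
    for (int i = 0; i < d_inner*n_t; ++i) {
        x[i] = (float) (i + 1);
    }
    ssm_conv_ref(state, x, weight, out, d_conv, d_inner, n_t);
    for (int t = 0; t < n_t; ++t) {
        printf("t=%d: %f %f\n", t, out[0 + t*d_inner], out[1 + t*d_inner]);
    }
    return 0;
}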