Skip to content

Commit 697fab6

Browse files
committed
Update CUDA ops ssm_conv and ssm_scan to match CPU implementation from PR #7531 (as per eb589d5)
1 parent 9d0ccf8 commit 697fab6

File tree

3 files changed

+136
-220
lines changed

3 files changed

+136
-220
lines changed

ggml-cuda/ssm_conv.cu

+73-91
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22

33
template <int block_size>
44
static __global__ void ssm_conv_f32(
5-
const float * src0, const float * src1, const float * src2, const float * src3,
6-
const int src0_ne0, const int src0_nb1, const int src0_nb2,
7-
const int src1_nb0, const int src1_nb1,
8-
const int src2_nb1, const int src2_nb2,
9-
const int src3_nb1,
5+
const float * src0, const float * src1, const float * src2,
6+
const int src0_nb1, const int src0_nb2,
7+
const int src1_nb0, const int src1_nb1, const int src1_nb2,
8+
const int src2_nb1,
109
float * dst,
11-
const int nc, const int nr, const int n_t, const int n_kv) {
10+
const int dst_nb0, const int dst_nb1, const int dst_nb2,
11+
const int nc, const int nr, const int n_t, const int n_s) {
1212

1313
// const int row = blockIdx.x*blockDim.y + threadIdx.y;
1414
const int tid = threadIdx.x;
@@ -24,136 +24,118 @@ static __global__ void ssm_conv_f32(
2424
const int ir1 = min(ir0 + dr, nr);
2525
const int ir = ir1 - ir0;
2626

27-
if (n_kv > 1) {
28-
// multiple sequences means it's hard to know when it's the first time a state is read,
29-
// so copy them all over to the destination, just to be sure.
30-
for (int i3 = 0; i3 < n_kv; ++i3) {
31-
float * s0 = (float *) ((char *) src0 + ir0*src0_nb1 + i3*src0_nb2);
32-
float * s = (float *) ((char *) dst + ir0*src2_nb1 + i3*src2_nb2 + nr*n_t*sizeof(float));
33-
// can't use memcpy because of d_conv vs d_conv - 1
34-
for (int i1 = 0; i1 < ir; ++i1) {
35-
for (int i0 = 0; i0 < nc - 1; ++i0) {
36-
// copy s0 to last (d_conv - 1) columns of s
37-
s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)];
38-
}
39-
}
40-
}
41-
}
27+
// TODO: maybe require src0 to have d_conv columns instead of (d_conv - 1)?
28+
// This would avoid having to copy into an intermediate buffer, but the state would be bigger.
4229

43-
for (int i2 = 0; i2 < n_t; ++i2) {
44-
int32_t * sq = (int32_t *) ((char *) src3 + i2*src3_nb1); // {n_kv, n_tokens}
45-
float * x = (float *) ((char *) dst + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens}
46-
float * s = (float *) ((char *) dst + ir0*src2_nb1 + sq[0]*src2_nb2 + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv}
47-
float * s0; // {d_conv - 1, d_inner, n_kv}
48-
float * x0 = (float *) ((char *) src1 + ir0*src1_nb0 + i2*src1_nb1); // {d_inner, n_tokens}
49-
float * c = (float *) ((char *) src2 + ir0*src2_nb1); // {d_conv, d_inner}
50-
int ne0s0;
51-
52-
// avoid needing to copy the state for the first token
53-
if (i2 == 0) {
54-
s0 = (float *) ((char *) src0 + ir0*src0_nb1 + sq[0]*src0_nb2); // {d_conv - 1, d_inner, n_kv}
55-
ne0s0 = src0_ne0;
56-
} else {
57-
// the source is the last (d_conv - 1) columns of the destination
58-
s0 = s + 1;
59-
ne0s0 = nc;
60-
}
30+
// float * s = (float *) params->wdata + (nc*dr + CACHE_LINE_SIZE_F32) * ith;
31+
extern __shared__ float wdata_f32[]; // work buffer for all threads
32+
float * s = (float *) wdata_f32 + nc*dr*ith;
6133

62-
// d_inner
34+
for (int i3 = 0; i3 < n_s; ++i3) {
35+
float * s0 = (float *) ((char *) src0 + ir0*src0_nb1) + i3*src0_nb2; // {d_conv, d_inner, n_s}
36+
37+
// copy the state into working memory
38+
// can't use memcpy because (d_conv) != (d_conv - 1)
6339
for (int i1 = 0; i1 < ir; ++i1) {
64-
// shift state left
6540
for (int i0 = 0; i0 < nc - 1; ++i0) {
66-
s[i0 + i1*nc] = s0[i0 + i1*ne0s0];
41+
s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)];
6742
}
68-
// insert x on the last column
69-
s[(nc - 1) + i1*nc] = x0[i1];
7043
}
7144

72-
// handle copies when there are multiple output states
73-
for (int i3 = 1; i3 < n_kv; ++i3) {
74-
int32_t seq = sq[i3];
75-
if (0 <= seq && seq < n_kv) {
76-
float * s1 = s + (seq - sq[0])*nc*nr;
45+
for (int i2 = 0; i2 < n_t; ++i2) {
46+
float * x = (float *) ((char *) dst + ir0* dst_nb0 + i2* dst_nb1 + i3* dst_nb2); // {d_inner, n_t, n_s}
47+
float * x0 = (float *) ((char *) src1 + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s}
48+
float * c = (float *) ((char *) src2 + ir0*src2_nb1); // {d_conv, d_inner}
7749

78-
//memcpy(s1, s, nc*ir*sizeof(float));
79-
for (int i4 = 0; i4 < nc*ir; i4++) {
80-
s1[i4] = s[i4];
50+
// shift state left
51+
//memmove(s, s + 1, (nc*ir - 1) * sizeof(float));
52+
for (int i4 = 0; i4 < nc*ir - 1; ++i4) {
53+
s[i4] = s[i4+1];
54+
}
55+
56+
// d_inner
57+
for (int i1 = 0; i1 < ir; ++i1) {
58+
// insert x on the last column
59+
s[(nc - 1) + i1*nc] = x0[i1];
60+
}
61+
62+
// it seems a little faster when this is separate from the state shift
63+
for (int i1 = 0; i1 < ir; ++i1) {
64+
// rowwise dot product
65+
// NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
66+
float sumf = 0.0f;
67+
for (int i0 = 0; i0 < nc; ++i0) {
68+
int i = i0 + i1*nc;
69+
sumf += s[i] * c[i];
8170
}
82-
} else {
83-
// stop at negative or too big seq_ids
84-
break;
71+
x[i1] = sumf;
8572
}
8673
}
8774

88-
// it seems a little faster when this is separate from the state shift
75+
// copy the state out of it
8976
for (int i1 = 0; i1 < ir; ++i1) {
90-
// rowwise dot product
91-
float sumf = 0.0f;
92-
for (int i0 = 0; i0 < nc; ++i0) {
93-
int i = i0 + i1*nc;
94-
sumf += s[i] * c[i];
77+
for (int i0 = 0; i0 < nc - 1; ++i0) {
78+
s0[i0 + i1*(nc - 1)] = s[1 + i0 + i1*nc];
9579
}
96-
x[i1] = sumf;
9780
}
9881
}
9982
}
10083

10184
static void ssm_conv_f32_cuda(
102-
const float * src0, const float * src1, const float * src2, const float * src3,
103-
const int src0_ne0, const int src0_nb1, const int src0_nb2,
104-
const int src1_nb0, const int src1_nb1,
105-
const int src2_nb1, const int src2_nb2,
106-
const int src3_nb1,
85+
const float * src0, const float * src1, const float * src2,
86+
const int src0_nb1, const int src0_nb2,
87+
const int src1_nb0, const int src1_nb1, const int src1_nb2,
88+
const int src2_nb1,
10789
float * dst,
108-
const int nc, const int nr, const int n_t, const int n_kv, cudaStream_t stream) {
90+
const int dst_nb0, const int dst_nb1, const int dst_nb2,
91+
const int nc, const int nr, const int n_t, const int n_s,
92+
cudaStream_t stream) {
10993

11094
const dim3 block_dims(WARP_SIZE, 1, 1);
11195
const int nblocks = 1; // TODO
96+
const int shmem_size = nc * (nr + WARP_SIZE - 1) * sizeof(float); // TODO
11297

113-
ssm_conv_f32<WARP_SIZE><<<nblocks, block_dims, 0, stream>>>(
114-
src0, src1, src2, src3,
115-
src0_ne0, src0_nb1, src0_nb2,
116-
src1_nb0, src1_nb1,
117-
src2_nb1, src2_nb2,
118-
src3_nb1,
98+
ssm_conv_f32<WARP_SIZE><<<nblocks, block_dims, shmem_size, stream>>>(
99+
src0, src1, src2,
100+
src0_nb1, src0_nb2,
101+
src1_nb0, src1_nb1, src1_nb2,
102+
src2_nb1,
119103
dst,
120-
nc, nr, n_t, n_kv);
104+
dst_nb0, dst_nb1, dst_nb2,
105+
nc, nr, n_t, n_s);
121106
}
122107

123108
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
124109
const struct ggml_tensor * src0 = dst->src[0]; // conv_state
125110
const struct ggml_tensor * src1 = dst->src[1]; // x
126111
const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight
127-
const struct ggml_tensor * src3 = dst->src[3]; // state_seq
128112

129-
const int nc = src2->ne[0]; // d_conv
130-
const int nr = src0->ne[1]; // d_inner
131-
const int n_t = src1->ne[1]; // n_tokens
132-
const int n_kv = src0->ne[2]; // max number of sequences in the batch
113+
const int nc = src2->ne[0]; // d_conv
114+
const int nr = src0->ne[1]; // d_inner
115+
const int n_t = src1->ne[1]; // tokens per sequence
116+
const int n_s = src0->ne[2]; // number of sequences in the batch
133117

134-
GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst));
118+
GGML_ASSERT(ggml_are_same_shape(src1, dst));
135119
GGML_ASSERT(src0->nb[0] == sizeof(float));
136120
GGML_ASSERT(src1->nb[0] == sizeof(float));
137121
GGML_ASSERT(src2->nb[0] == sizeof(float));
138-
GGML_ASSERT(src3->nb[0] == sizeof(int32_t));
139122
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
140-
// for use with the destination state offset between sequences
141-
GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float));
142123

143124
const float * src0_d = (const float *)src0->data;
144125
const float * src1_d = (const float *)src1->data;
145126
const float * src2_d = (const float *)src2->data;
146-
const float * src3_d = (const float *)src3->data;
147127
float * dst_d = (float *)dst->data;
148128
cudaStream_t stream = ctx.stream();
149129

150130
GGML_ASSERT(src0->type == GGML_TYPE_F32);
151131
GGML_ASSERT( dst->type == GGML_TYPE_F32);
152132

153-
ssm_conv_f32_cuda(src0_d, src1_d, src2_d, src3_d,
154-
src0->ne[0], src0->nb[1], src0->nb[2],
155-
src1->nb[0], src1->nb[1],
156-
src2->nb[1], src2->nb[2],
157-
src3->nb[1],
158-
dst_d, nc, nr, n_t, n_kv, stream);
133+
ssm_conv_f32_cuda(src0_d, src1_d, src2_d,
134+
src0->nb[1], src0->nb[2],
135+
src1->nb[0], src1->nb[1], src1->nb[2],
136+
src2->nb[1],
137+
dst_d,
138+
dst->nb[0], dst->nb[1], dst->nb[2],
139+
nc, nr, n_t, n_s,
140+
stream);
159141
}

0 commit comments

Comments
 (0)