Skip to content

Commit 9c1ddc7

Browse files
committed
cuda : fix im2col kernel
1 parent 000b952 commit 9c1ddc7

File tree

3 files changed

+30
-25
lines changed

3 files changed

+30
-25
lines changed

ggml-cuda.cu

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4737,13 +4737,18 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
47374737
}
47384738

47394739
static __global__ void im2col_f32_f16(const float* x, half* dst, int ofs0, int ofs1, int IW,int IH,int CHW,int s0,int s1,int p0,int p1,int d0,int d1) {
4740-
int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
4741-
int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
4742-
__syncthreads();
4740+
const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
4741+
const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
4742+
4743+
const int offset_dst =
4744+
(threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
4745+
(blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
4746+
47434747
if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
4744-
int offset_dst = (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW;
4745-
int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
4746-
dst[offset_dst + (blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z)] = __float2half(x[offset_src + iih * IW + iiw]);
4748+
const int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
4749+
dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
4750+
} else {
4751+
dst[offset_dst] = __float2half(0.0f);
47474752
}
47484753
}
47494754

@@ -5735,7 +5740,7 @@ static void im2col_f32_f16_cuda(const float* x, half* dst,
57355740
int KH, int KW, int N, int ofs0, int ofs1,
57365741
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
57375742
dim3 block_nums(IC, OH, OW);
5738-
dim3 block_dims(N, KH, KW);
5743+
dim3 block_dims(N, KH, KW);
57395744
im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
57405745
}
57415746

@@ -6714,16 +6719,16 @@ inline void ggml_cuda_op_im2col(
67146719

67156720
const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
67166721

6717-
const int64_t N = src1->ne[is_2D ? 3 : 2];
6722+
const int64_t N = src1->ne[is_2D ? 3 : 2];
67186723
const int64_t IC = src1->ne[is_2D ? 2 : 1];
67196724
const int64_t IH = is_2D ? src1->ne[1] : 1;
6720-
const int64_t IW = src1->ne[0];
6725+
const int64_t IW = src1->ne[0];
67216726

67226727
const int64_t KH = is_2D ? src0->ne[1] : 1;
6723-
const int64_t KW = src0->ne[0];
6728+
const int64_t KW = src0->ne[0];
67246729

67256730
const int64_t OH = is_2D ? dst->ne[2] : 1;
6726-
const int64_t OW = dst->ne[1];
6731+
const int64_t OW = dst->ne[1];
67276732

67286733
im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
67296734
OH, IW, IH, OW, IC, KH, KW, N,

ggml.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5227,13 +5227,13 @@ struct ggml_tensor * ggml_im2col(
52275227
}
52285228

52295229
const int64_t OH = is_2D ? ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0;
5230-
const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
5230+
const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
52315231

52325232
const int64_t ne[4] = {
52335233
is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0],
52345234
OW,
52355235
is_2D ? OH : b->ne[2],
5236-
is_2D ? b->ne[3] : 1,
5236+
is_2D ? b->ne[3] : 1,
52375237
};
52385238

52395239
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);

whisper.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1604,22 +1604,22 @@ static struct ggml_cgraph * whisper_build_graph_conv(
16041604
// convolution + gelu
16051605
{
16061606
cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
1607-
//cur = ggml_add(ctx0, cur, model.e_conv_1_b);
1608-
cur = ggml_add(ctx0,
1609-
ggml_repeat(ctx0,
1610-
model.e_conv_1_b,
1611-
cur),
1612-
cur);
1607+
cur = ggml_add(ctx0, cur, model.e_conv_1_b);
1608+
//cur = ggml_add(ctx0,
1609+
// ggml_repeat(ctx0,
1610+
// model.e_conv_1_b,
1611+
// cur),
1612+
// cur);
16131613

16141614
cur = ggml_gelu(ctx0, cur);
16151615

16161616
cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
1617-
//cur = ggml_add(ctx0, cur, model.e_conv_2_b);
1618-
cur = ggml_add(ctx0,
1619-
ggml_repeat(ctx0,
1620-
model.e_conv_2_b,
1621-
cur),
1622-
cur);
1617+
cur = ggml_add(ctx0, cur, model.e_conv_2_b);
1618+
//cur = ggml_add(ctx0,
1619+
// ggml_repeat(ctx0,
1620+
// model.e_conv_2_b,
1621+
// cur),
1622+
// cur);
16231623

16241624
cur = ggml_gelu(ctx0, cur);
16251625
}

0 commit comments

Comments
 (0)