@@ -4737,13 +4737,18 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
4737
4737
}
4738
4738
4739
4739
static __global__ void im2col_f32_f16(const float* x, half* dst, int ofs0, int ofs1, int IW,int IH,int CHW,int s0,int s1,int p0,int p1,int d0,int d1) {
4740
- int iiw = blockIdx .z * s0 + threadIdx .z * d0 - p0;
4741
- int iih = blockIdx .y * s1 + threadIdx .y * d1 - p1;
4742
- __syncthreads ();
4740
+ const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
4741
+ const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
4742
+
4743
+ const int offset_dst =
4744
+ (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
4745
+ (blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
4746
+
4743
4747
if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
4744
- int offset_dst = (threadIdx .x * gridDim .y * gridDim .z + blockIdx .y * gridDim .z + blockIdx .z ) * CHW;
4745
- int offset_src = threadIdx .x * ofs0 + blockIdx .x * ofs1;
4746
- dst[offset_dst + (blockIdx .x * (blockDim .y * blockDim .z ) + threadIdx .y * blockDim .z + threadIdx .z )] = __float2half (x[offset_src + iih * IW + iiw]);
4748
+ const int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
4749
+ dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
4750
+ } else {
4751
+ dst[offset_dst] = __float2half(0.0f);
4747
4752
}
4748
4753
}
4749
4754
@@ -5735,7 +5740,7 @@ static void im2col_f32_f16_cuda(const float* x, half* dst,
5735
5740
int KH, int KW, int N, int ofs0, int ofs1,
5736
5741
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
5737
5742
dim3 block_nums(IC, OH, OW);
5738
- dim3 block_dims (N, KH, KW);
5743
+ dim3 block_dims(N, KH, KW);
5739
5744
im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
5740
5745
}
5741
5746
@@ -6714,16 +6719,16 @@ inline void ggml_cuda_op_im2col(
6714
6719
6715
6720
const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
6716
6721
6717
- const int64_t N = src1->ne [is_2D ? 3 : 2 ];
6722
+ const int64_t N = src1->ne[is_2D ? 3 : 2];
6718
6723
const int64_t IC = src1->ne[is_2D ? 2 : 1];
6719
6724
const int64_t IH = is_2D ? src1->ne[1] : 1;
6720
- const int64_t IW = src1->ne [0 ];
6725
+ const int64_t IW = src1->ne[0];
6721
6726
6722
6727
const int64_t KH = is_2D ? src0->ne[1] : 1;
6723
- const int64_t KW = src0->ne [0 ];
6728
+ const int64_t KW = src0->ne[0];
6724
6729
6725
6730
const int64_t OH = is_2D ? dst->ne[2] : 1;
6726
- const int64_t OW = dst->ne [1 ];
6731
+ const int64_t OW = dst->ne[1];
6727
6732
6728
6733
im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
6729
6734
OH, IW, IH, OW, IC, KH, KW, N,
0 commit comments