Skip to content

Commit f97d546

Browse files
CUDA: fix LoRAs
1 parent 89e8959 commit f97d546

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

ggml-cuda.cu

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5247,7 +5247,8 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
52475247
if (src->backend == GGML_BACKEND_CPU) {
52485248
kind = cudaMemcpyHostToDevice;
52495249
src_ptr = (char *) src->data;
5250-
} else if (src->backend == GGML_BACKEND_GPU) {
5250+
} else if (src->backend == GGML_BACKEND_GPU || src->backend == GGML_BACKEND_GPU_SPLIT) {
5251+
GGML_ASSERT(src->backend != GGML_BACKEND_GPU_SPLIT || (i1_low == 0 && i1_high == src->ne[1]));
52515252
kind = cudaMemcpyDeviceToDevice;
52525253
struct ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra;
52535254
int id;
@@ -5289,9 +5290,7 @@ inline void ggml_cuda_op_add(
52895290
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
52905291
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
52915292

5292-
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
52935293
GGML_ASSERT(src1->type == GGML_TYPE_F32);
5294-
GGML_ASSERT( dst->type == GGML_TYPE_F32);
52955294

52965295
const int64_t ne10 = src1->ne[0];
52975296
const int64_t ne11 = src1->ne[1];
@@ -5631,10 +5630,15 @@ inline void ggml_cuda_op_mul_mat_cublas(
56315630
const int64_t ne0 = dst->ne[0];
56325631
const int64_t row_diff = row_high - row_low;
56335632

5634-
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
5635-
size_t src0_as;
5636-
float * src0_ddf_i = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as);
5637-
to_fp32_cuda(src0_dd_i, src0_ddf_i, row_diff*ne00, stream);
5633+
float * src0_ddq_as_f32;
5634+
size_t src0_as = 0;
5635+
5636+
if (src0->type != GGML_TYPE_F32) {
5637+
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
5638+
src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
5639+
to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
5640+
}
5641+
const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
56385642

56395643
int id;
56405644
CUDA_CHECK(cudaGetDevice(&id));
@@ -5651,10 +5655,11 @@ inline void ggml_cuda_op_mul_mat_cublas(
56515655
src1_ddf_i, ne10,
56525656
&beta, dst_dd_i, ldc));
56535657

5654-
ggml_cuda_pool_free(src0_ddf_i, src0_as);
5658+
if (src0_as > 0) {
5659+
ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
5660+
}
56555661

56565662
(void) dst;
5657-
(void) src0_dd_i;
56585663
(void) src1_ddq_i;
56595664
(void) src1_padded_row_size;
56605665
}
@@ -5793,7 +5798,6 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
57935798
const bool use_src1 = src1 != nullptr;
57945799
const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
57955800

5796-
GGML_ASSERT( src0->backend != GGML_BACKEND_GPU_SPLIT);
57975801
GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT);
57985802
GGML_ASSERT( dst->backend != GGML_BACKEND_GPU_SPLIT);
57995803

0 commit comments

Comments
 (0)