
Commit 1b4303c

howard0su authored and junsu0ms committed
Leverage mmap for CUDA loading
Fix a typo when calculating VRAM size
Normalize OpenCL loading code to match CUDA
Fix clang-tidy warnings
Avoid mlock of offloaded tensors
Avoid allocating a buffer for offloaded tensors when using no-mmap
Address review comments
1 parent dcb2ed4 commit 1b4303c
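
The core idea, sketched below: instead of fopen/fseek/fread into a temporary host buffer, the loader mmaps the model file once and hands pointers into that mapping straight to the GPU backends, so the OS page cache backs the weights and no per-tensor staging allocation is needed. The following is a stand-alone, hypothetical illustration of that pattern (file name, offset, and size are placeholders), not code from this commit.

// Stand-alone sketch: upload a region of an mmap'd file straight to VRAM.
// Placeholder names throughout; error handling kept minimal.
#include <cuda_runtime.h>
#include <sys/mman.h>
#include <fcntl.h>
#include <unistd.h>
#include <cstdio>

int main() {
    const char * fname  = "model.bin"; // placeholder model file
    const size_t offset = 0;           // placeholder tensor offset within the file
    const size_t nbytes = 4096;        // placeholder tensor size

    int fd = open(fname, O_RDONLY);
    if (fd < 0) { perror("open"); return 1; }

    // One read-only mapping of the file; the page cache backs the weights,
    // so there is no malloc'd staging buffer and no explicit fread.
    void * addr = mmap(NULL, offset + nbytes, PROT_READ, MAP_SHARED, fd, 0);
    if (addr == MAP_FAILED) { perror("mmap"); return 1; }

    void * dev = NULL;
    if (cudaMalloc(&dev, nbytes) != cudaSuccess) { fprintf(stderr, "cudaMalloc failed\n"); return 1; }

    // Copy directly from the mapping to device memory; pages are faulted in on demand.
    if (cudaMemcpy(dev, (const char *) addr + offset, nbytes, cudaMemcpyHostToDevice) != cudaSuccess) {
        fprintf(stderr, "cudaMemcpy failed\n");
        return 1;
    }

    cudaFree(dev);
    munmap(addr, offset + nbytes);
    close(fd);
    return 0;
}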

5 files changed: +51 -125 lines

ggml-cuda.cu

Lines changed: 3 additions & 32 deletions
@@ -900,7 +900,7 @@ size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct
     }
 }
 
-void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
+void ggml_cuda_transform_tensor(const void * data, ggml_tensor * tensor) {
     const int64_t ne0 = tensor->ne[0];
     const int64_t ne1 = tensor->ne[1];
     const int64_t ne2 = tensor->ne[2];
@@ -913,6 +913,7 @@ void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
     char * dst = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
 
     cudaStream_t cudaStream2 = g_cudaStreams2[0];
+    tensor->data = (void*)data;
 
     // copy tensor to device
     for (int64_t i3 = 0; i3 < ne3; i3++) {
@@ -923,35 +924,5 @@ void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
     }
 
     tensor->data = dst;
-    tensor->backend = GGML_BACKEND_CUDA;
-}
-
-void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset) {
-    FILE * fp = fopen(fname, "rb");
-
-    const size_t size = ggml_nbytes(tensor);
-
-    void * buf;
-    CUDA_CHECK(cudaMalloc(&buf, size));
-    void * buf_host = malloc(size);
-
-#ifdef _WIN32
-    int ret = _fseeki64(fp, (__int64) offset, SEEK_SET);
-#else
-    int ret = fseek(fp, (long) offset, SEEK_SET);
-#endif
-    GGML_ASSERT(ret == 0); // same
-
-    size_t ret2 = fread(buf_host, size, 1, fp);
-    if (ret2 != 1) {
-        fprintf(stderr, "unexpectedly reached end of file");
-        exit(1);
-    }
-
-    cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
-    cudaDeviceSynchronize();
-
-    tensor->data = buf;
-    free(buf_host);
-    fclose(fp);
+    GGML_ASSERT(tensor->backend == GGML_BACKEND_CUDA);
 }
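
With the fopen/fread path gone, the caller supplies the host pointer and must set the backend field before the call (the function now only asserts it). A minimal, hypothetical call site is sketched below; the helper name and the local extern declaration are illustrative, assuming the two-argument signature introduced above.

// Hypothetical call site for the new two-argument entry point (illustrative names).
// `weights` would typically point into the mmap'd model file, e.g. mapping + file offset.
#include "ggml.h"

extern "C" void ggml_cuda_transform_tensor(const void * data, struct ggml_tensor * tensor);

static void offload_to_cuda(const void * weights, struct ggml_tensor * t) {
    t->backend = GGML_BACKEND_CUDA;         // the callee asserts this instead of setting it
    ggml_cuda_transform_tensor(weights, t); // copies the data to VRAM;
                                            // t->data then refers to device memory
}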

ggml-cuda.h

Lines changed: 0 additions & 1 deletion
@@ -16,7 +16,6 @@ void * ggml_cuda_host_malloc(size_t size);
 void ggml_cuda_host_free(void * ptr);
 
 void ggml_cuda_transform_tensor(struct ggml_tensor * tensor);
-void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensors, size_t offset);
 
 #ifdef __cplusplus
 }

ggml-opencl.cpp

Lines changed: 4 additions & 32 deletions
@@ -1133,7 +1133,7 @@ size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct g
     return 0;
 }
 
-void ggml_cl_transform_tensor(ggml_tensor * tensor) {
+void ggml_cl_transform_tensor(const void * data, ggml_tensor * tensor) {
     const int64_t ne0 = tensor->ne[0];
     const int64_t ne1 = tensor->ne[1];
     const int64_t ne2 = tensor->ne[2];
@@ -1145,6 +1145,8 @@ void ggml_cl_transform_tensor(ggml_tensor * tensor) {
     size_t q_size;
     cl_mem dst = ggml_cl_pool_malloc(q_sz, &q_size, CL_MEM_READ_ONLY);
 
+    tensor->data = (void*)data;
+
     // copy tensor to device
     for (int64_t i3 = 0; i3 < ne3; i3++) {
         for (int64_t i2 = 0; i2 < ne2; i2++) {
@@ -1156,35 +1158,5 @@ void ggml_cl_transform_tensor(ggml_tensor * tensor) {
     CL_CHECK(clFinish(queue));
 
     tensor->data = dst;
-    tensor->backend = GGML_BACKEND_CL;
-}
-
-void ggml_cl_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset) {
-    cl_int err;
-    FILE * fp = fopen(fname, "rb");
-
-    const size_t size = ggml_nbytes(tensor);
-
-    cl_mem dst;
-    CL_CHECK((dst = clCreateBuffer(context, CL_MEM_READ_ONLY, size, nullptr, &err), err));
-    void * buf_host = malloc(size);
-
-#ifdef _WIN32
-    int ret = _fseeki64(fp, (__int64) offset, SEEK_SET);
-#else
-    int ret = fseek(fp, (long) offset, SEEK_SET);
-#endif
-    GGML_ASSERT(ret == 0); // same
-
-    size_t ret2 = fread(buf_host, size, 1, fp);
-    if (ret2 != 1) {
-        fprintf(stderr, "unexpectedly reached end of file");
-        exit(1);
-    }
-
-    clEnqueueWriteBuffer(queue, dst, CL_TRUE, 0, size, buf_host, 0, nullptr, nullptr);
-
-    tensor->data = dst;
-    free(buf_host);
-    fclose(fp);
+    GGML_ASSERT(tensor->backend == GGML_BACKEND_CL);
 }

ggml-opencl.h

Lines changed: 1 addition & 2 deletions
@@ -16,8 +16,7 @@ void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor
 void * ggml_cl_host_malloc(size_t size);
 void ggml_cl_host_free(void * ptr);
 
-void ggml_cl_transform_tensor(struct ggml_tensor * tensor);
-void ggml_cl_load_data(const char * fname, struct ggml_tensor * tensor, size_t offset);
+void ggml_cl_transform_tensor(const void * data, struct ggml_tensor * tensor);
 
 #ifdef __cplusplus
 }

llama.cpp

Lines changed: 43 additions & 58 deletions
@@ -673,13 +673,21 @@ struct llama_model_loader {
 
     struct ggml_tensor * get_tensor_for(llama_load_tensor & lt, ggml_backend backend) {
         struct ggml_tensor * tensor;
+
+        if (backend != GGML_BACKEND_CPU) {
+            ggml_set_no_alloc(ggml_ctx, true);
+        }
         if (lt.ne.size() == 2) {
             tensor = ggml_new_tensor_2d(ggml_ctx, lt.type, lt.ne.at(0), lt.ne.at(1));
         } else {
             LLAMA_ASSERT(lt.ne.size() == 1);
             tensor = ggml_new_tensor_1d(ggml_ctx, lt.type, lt.ne.at(0));
         }
         ggml_set_name(tensor, lt.name.c_str());
+
+        if (backend != GGML_BACKEND_CPU) {
+            ggml_set_no_alloc(ggml_ctx, use_mmap);
+        }
         LLAMA_ASSERT(lt.ggml_tensor == NULL); // if this fails, we called get_tensor twice on the same tensor
         tensor->backend = backend;
         lt.ggml_tensor = tensor;
@@ -696,6 +704,7 @@ struct llama_model_loader {
     void load_all_data(llama_progress_callback progress_callback, void * progress_callback_user_data, llama_mlock * lmlock) {
         size_t data_size = 0;
         size_t prefetch_size = 0;
+        size_t lock_size = 0;
         for (const llama_load_tensor & lt : tensors_map.tensors) {
             data_size += lt.size;
             if (lt.ggml_tensor->backend == GGML_BACKEND_CPU) {
@@ -705,32 +714,52 @@ struct llama_model_loader {
 
         if (use_mmap) {
             mapping.reset(new llama_mmap(&file_loaders.at(0)->file, prefetch_size));
-            if (!lmlock) {
-                // Don't call the callback since the actual loading will be lazy
-                // and we can't measure it.
-                progress_callback = NULL;
-            }
             if (lmlock) {
                 lmlock->init(mapping->addr);
             }
         }
 
         size_t done_size = 0;
         for (llama_load_tensor & lt : tensors_map.tensors) {
-            if (lt.ggml_tensor->backend != GGML_BACKEND_CPU) {
-                continue;
-            }
             if (progress_callback) {
                 progress_callback((float) done_size / data_size, progress_callback_user_data);
            }
             LLAMA_ASSERT(lt.ggml_tensor); // unused tensors should have been caught by load_data already
             lt.data = (uint8_t *) lt.ggml_tensor->data;
+            // allocate temp buffer if not using mmap
+            if (!use_mmap && lt.data == NULL) {
+                lt.data = (uint8_t*)malloc(ggml_nbytes(lt.ggml_tensor));
+            }
+
             load_data_for(lt);
-            lt.ggml_tensor->data = lt.data;
-            done_size += lt.size;
-            if (use_mmap && lmlock) {
-                lmlock->grow_to(done_size);
+            switch(lt.ggml_tensor->backend) {
+                case GGML_BACKEND_CPU:
+                    lt.ggml_tensor->data = lt.data;
+                    if (use_mmap && lmlock) {
+                        lock_size += lt.size;
+                        lmlock->grow_to(lock_size);
+                    }
+                    break;
+#ifdef GGML_USE_CUBLAS
+                case GGML_BACKEND_CUDA:
+                    ggml_cuda_transform_tensor(lt.data, lt.ggml_tensor);
+                    if (!use_mmap) {
+                        free(lt.data);
+                    }
+                    break;
+#endif
+#ifdef GGML_USE_CLBLAST
+                case GGML_BACKEND_CL:
+                    ggml_cl_transform_tensor(lt.data, lt.ggml_tensor);
+                    if (!use_mmap) {
+                        free(lt.data);
+                    }
+                    break;
+#endif
+                default:
+                    continue;
             }
+            done_size += lt.size;
         }
     }
 
@@ -1069,8 +1098,8 @@ static void llama_model_load_internal(
 
             if (backend == LLAMA_BACKEND_OFFLOAD) {
                 vram_total +=
-                    ggml_nbytes(layer.attention_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) +
-                    ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.attention_norm) +
+                    ggml_nbytes(layer.attention_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) +
+                    ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) +
                     ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3);
             }
         }
@@ -1117,50 +1146,6 @@ static void llama_model_load_internal(
 
     ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : NULL);
 
-#if defined(GGML_USE_CUBLAS)
-    {
-        size_t done_size = 0;
-        size_t data_size = 0;
-        for (llama_load_tensor & lt : ml->tensors_map.tensors) {
-            data_size += lt.size;
-            if (lt.ggml_tensor->backend == GGML_BACKEND_CPU) {
-                done_size += lt.size;
-            }
-        }
-        for (llama_load_tensor & lt : ml->tensors_map.tensors) {
-            if (lt.ggml_tensor->backend != GGML_BACKEND_CUDA) {
-                continue;
-            }
-            if (progress_callback) {
-                progress_callback((float) done_size / data_size, progress_callback_user_data);
-            }
-            ggml_cuda_load_data(fname.c_str(), lt.ggml_tensor, lt.shards.at(0).file_off);
-            done_size += lt.size;
-        }
-    }
-#elif defined(GGML_USE_CLBLAST)
-    {
-        size_t done_size = 0;
-        size_t data_size = 0;
-        for (llama_load_tensor & lt : ml->tensors_map.tensors) {
-            data_size += lt.size;
-            if (lt.ggml_tensor->backend == GGML_BACKEND_CPU) {
-                done_size += lt.size;
-            }
-        }
-        for (llama_load_tensor & lt : ml->tensors_map.tensors) {
-            if (lt.ggml_tensor->backend != GGML_BACKEND_CL) {
-                continue;
-            }
-            if (progress_callback) {
-                progress_callback((float) done_size / data_size, progress_callback_user_data);
-            }
-            ggml_cl_load_data(fname.c_str(), lt.ggml_tensor, lt.shards.at(0).file_off);
-            done_size += lt.size;
-        }
-    }
-#endif
-
     if (progress_callback) {
         progress_callback(1.0f, progress_callback_user_data);
     }
