Skip to content

Commit bcb5d63

Browse files
committed
rewrite platform and device selection
1 parent bb5c3e2 commit bcb5d63

File tree

1 file changed

+116
-27
lines changed

1 file changed

+116
-27
lines changed

ggml-opencl.c

Lines changed: 116 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -143,7 +143,7 @@ __kernel void dequantize_row_q8_0(__global struct block_q8_0* x, __global float*
143143
do { \
144144
cl_int err_ = (err); \
145145
if (err_ != CL_SUCCESS) { \
146-
fprintf(stderr, "OpenCL %s error %d at %s:%d\n", name, err_, __FILE__, __LINE__); \
146+
fprintf(stderr, "ggml_opencl: %s error %d at %s:%d\n", name, err_, __FILE__, __LINE__); \
147147
exit(1); \
148148
} \
149149
} while (0)
@@ -152,6 +152,7 @@ static cl_platform_id platform;
152152
static cl_device_id device;
153153
static cl_context context;
154154
static cl_command_queue queue;
155+
static cl_bool out_of_order;
155156
static cl_program program;
156157
static cl_kernel kernel_q4_0, kernel_q4_1, kernel_q5_0, kernel_q5_1, kernel_q8_0;
157158
static cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c;
@@ -188,35 +189,123 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co
188189

189190
void ggml_cl_init(void) {
190191
cl_int err = 0;
191-
char * GGML_CLBLAST_PLATFORM = getenv("GGML_CLBLAST_PLATFORM");
192-
char * GGML_CLBLAST_DEVICE = getenv("GGML_CLBLAST_DEVICE");
193-
int plat_num = (GGML_CLBLAST_PLATFORM == NULL ? 0 : atoi(GGML_CLBLAST_PLATFORM));
194-
int dev_num = (GGML_CLBLAST_DEVICE == NULL ? 0 : atoi(GGML_CLBLAST_DEVICE));
195-
printf("\nInitializing CLBlast (First Run)...");
196-
printf("\nAttempting to use: Platform=%d, Device=%d (If invalid, program will crash)\n",plat_num,dev_num);
197-
cl_uint num_platforms;
198-
clGetPlatformIDs(0, NULL, &num_platforms);
199-
cl_platform_id* platforms = (cl_platform_id*)malloc(num_platforms*sizeof(cl_platform_id));
200-
clGetPlatformIDs(num_platforms, platforms, NULL);
201-
platform = platforms[plat_num];
202-
char platform_buffer[1024];
203-
clGetPlatformInfo(platform, CL_PLATFORM_NAME, sizeof(platform_buffer), &platform_buffer, NULL);
204-
cl_uint num_devices;
205-
clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, &num_devices);
206-
cl_device_id* devices = (cl_device_id*)malloc(num_devices*sizeof(cl_device_id));
207-
clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_devices, devices, NULL);
208-
device = devices[dev_num];
209-
char device_buffer[1024];
210-
clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(device_buffer), &device_buffer, NULL);
211-
printf("Using Platform: %s Device: %s\n", platform_buffer, device_buffer);
212-
context = clCreateContext(NULL, 1, &device, NULL, NULL, &err);
213-
CL_CHECK(err, "clCreateContext");
192+
193+
enum { NPLAT = 16, NDEV = 16 };
194+
195+
char text_buffer[1024] = {0};
196+
197+
platform = NULL;
198+
char * GGML_OPENCL_PLATFORM = getenv("GGML_OPENCL_PLATFORM");
199+
if (GGML_OPENCL_PLATFORM != NULL) {
200+
cl_platform_id platforms[NPLAT];
201+
cl_uint num_platforms;
202+
err = clGetPlatformIDs(NPLAT, platforms, &num_platforms);
203+
CL_CHECK(err, "clGetPlatformIDs");
204+
205+
unsigned plat_num;
206+
if (sscanf(GGML_OPENCL_PLATFORM, " %u", &plat_num) == 1) {
207+
if (plat_num >= num_platforms) {
208+
fprintf(stderr, "ggml_opencl: There is no platform %d\n", plat_num);
209+
exit(1);
210+
} else {
211+
platform = platforms[plat_num];
212+
clGetPlatformInfo(platform, CL_PLATFORM_NAME, sizeof(text_buffer), &text_buffer, NULL);
213+
}
214+
} else {
215+
for (unsigned i = 0; i < num_platforms; i++) {
216+
clGetPlatformInfo(platforms[i], CL_PLATFORM_NAME, sizeof(text_buffer), &text_buffer, NULL);
217+
if (strstr(text_buffer, GGML_OPENCL_PLATFORM) != NULL) {
218+
platform = platforms[i];
219+
break;
220+
}
221+
}
222+
}
223+
if (platform == NULL) {
224+
fprintf(stderr, "ggml_opencl: no platform matching '%s' was found.\n", GGML_OPENCL_PLATFORM);
225+
exit(1);
226+
} else {
227+
fprintf(stderr, "ggml_opencl: selecting platform: '%s'\n", text_buffer);
228+
}
229+
}
230+
231+
text_buffer[0] = 0;
232+
device = NULL;
233+
char * GGML_OPENCL_DEVICE = getenv("GGML_OPENCL_DEVICE");
234+
if (GGML_OPENCL_DEVICE != NULL) {
235+
cl_device_id devices[16];
236+
cl_uint num_devices;
237+
clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, NDEV, devices, &num_devices);
238+
239+
unsigned dev_num;
240+
if (sscanf(GGML_OPENCL_DEVICE, " %u", &dev_num) == 1) {
241+
if (dev_num >= num_devices) {
242+
fprintf(stderr, "ggml_opencl: There is no device %d\n", dev_num);
243+
exit(1);
244+
} else {
245+
device = devices[dev_num];
246+
clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(text_buffer), &text_buffer, NULL);
247+
}
248+
} else {
249+
for (unsigned i = 0; i < num_devices; i++) {
250+
clGetDeviceInfo(devices[i], CL_DEVICE_NAME, sizeof(text_buffer), &text_buffer, NULL);
251+
if (strstr(text_buffer, GGML_OPENCL_DEVICE) != NULL) {
252+
device = devices[i];
253+
break;
254+
}
255+
}
256+
}
257+
if (device == NULL) {
258+
fprintf(stderr, "ggml_opencl: no device matching '%s' was found.\n", GGML_OPENCL_DEVICE);
259+
exit(1);
260+
} else {
261+
fprintf(stderr, "ggml_opencl: selecting device: '%s'\n", text_buffer);
262+
}
263+
}
264+
265+
cl_context_properties *properties = platform == NULL ? NULL : (cl_context_properties[]){
266+
(intptr_t)CL_CONTEXT_PLATFORM, (intptr_t)platform, 0
267+
};
268+
269+
if (device != NULL) {
270+
context = clCreateContext(properties, 1, &device, NULL, NULL, &err);
271+
CL_CHECK(err, "clCreateContext");
272+
} else {
273+
context = clCreateContextFromType(properties, CL_DEVICE_TYPE_GPU, NULL, NULL, &err);
274+
if (err == CL_DEVICE_NOT_AVAILABLE || err == CL_DEVICE_NOT_FOUND) {
275+
context = clCreateContextFromType(properties, CL_DEVICE_TYPE_DEFAULT, NULL, NULL, &err);
276+
if (err == CL_DEVICE_NOT_AVAILABLE || err == CL_DEVICE_NOT_FOUND) {
277+
context = clCreateContextFromType(properties, CL_DEVICE_TYPE_ALL, NULL, NULL, &err);
278+
}
279+
}
280+
CL_CHECK(err, "clCreateContextFromType");
281+
}
282+
283+
if (device == NULL) {
284+
err = clGetContextInfo(context, CL_CONTEXT_DEVICES, sizeof(&device), &device, NULL);
285+
CL_CHECK(err, "clGetContextInfo");
286+
if (platform == NULL) {
287+
err = clGetDeviceInfo(device, CL_DEVICE_PLATFORM, sizeof(&platform), &platform, NULL);
288+
CL_CHECK(err, "clGetDeviceInfo");
289+
}
290+
}
291+
292+
if (platform != NULL) {
293+
clGetPlatformInfo(platform, CL_PLATFORM_NAME, sizeof(text_buffer), &text_buffer, NULL);
294+
fprintf(stderr, "ggml_opencl: using platform: '%s'\n", text_buffer);
295+
}
296+
if (device != NULL) {
297+
clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(text_buffer), &text_buffer, NULL);
298+
fprintf(stderr, "ggml_opencl: using device: '%s'\n", text_buffer);
299+
}
300+
301+
out_of_order = CL_TRUE;
214302
queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err);
303+
if (err == CL_INVALID_PROPERTY) {
304+
out_of_order = CL_FALSE;
305+
queue = clCreateCommandQueue(context, device, 0, &err);
306+
}
215307
CL_CHECK(err, "clCreateCommandQueue");
216308

217-
free(platforms);
218-
free(devices);
219-
220309
program = build_program_from_source(context, device, clblast_dequant);
221310

222311
// Prepare dequantize kernels

0 commit comments

Comments
 (0)