Skip to content
583 changes: 366 additions & 217 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp

Large diffs are not rendered by default.

51 changes: 51 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#version 450

#include "generic_head.comp"
#include "types.comp"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

layout (constant_id = 0) const uint BLOCK_SIZE = 32;

shared FLOAT_TYPE tmpmax[BLOCK_SIZE];
shared uint tmp[BLOCK_SIZE];

void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint col = gl_LocalInvocationID.x;

if (col >= p.KX) {
return;
}
A_TYPE amax = data_a[row*p.KX + col];
tmp[col] = col;

for (uint i = col + BLOCK_SIZE; i < p.KX; i += BLOCK_SIZE) {
A_TYPE val = data_a[row*p.KX + i];
if (val > amax) {
amax = val;
tmp[col] = i;
}
}
tmpmax[col] = amax;

barrier();
[[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) {
if (col < s && col + s < p.KX) {
if (tmpmax[col] < tmpmax[col + s]) {
tmpmax[col] = tmpmax[col + s];
tmp[col] = tmp[col + s];
}
}
barrier();
}

if (col == 0) {
data_d[row] = D_TYPE(tmp[0]);
}
}
31 changes: 31 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#version 450

#extension GL_EXT_control_flow_attributes : enable

#include "types.comp"
#include "generic_head.comp"

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
layout (binding = 2) buffer D {D_TYPE data_d[];};

const uint CHUNK_SIZE = 512;

void main() {
const uint base = gl_WorkGroupID.x * CHUNK_SIZE;
const uint col = gl_LocalInvocationID.x;

uint count = 0;
[[unroll]]
for (uint i = 0; i < CHUNK_SIZE; i += gl_WorkGroupSize.x) {
const uint idx = base + i + col;
if (idx >= p.KX) {
break;
}
count += uint(data_a[idx] == data_b[idx]);
}

atomicAdd(data_d[0], D_TYPE(count));
}
42 changes: 42 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#version 450

#include "generic_head.comp"
#include "types.comp"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) buffer X {A_TYPE x[];};
layout (binding = 1) readonly buffer G {A_TYPE grad[];};
layout (binding = 2) buffer GM {A_TYPE gradm[];};
layout (binding = 3) buffer GV {A_TYPE gradv[];};
layout (binding = 4) readonly buffer P {float params[7];};

void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

if (i >= p.KX) {
return;
}

const float alpha = params[0];
const float beta1 = params[1];
const float beta2 = params[2];
const float eps = params[3];
const float wd = params[4];
const float beta1h = params[5];
const float beta2h = params[6];

const float gi = grad[i];
const float gmi = gradm[i]*beta1 + gi*(1.0f - beta1);
const float gvi = gradv[i]*beta2 + gi*gi*(1.0f - beta2);

gradm[i] = gmi;
gradv[i] = gvi;

const float mh = gmi*beta1h;
const float vh = sqrt(gvi*beta2h) + eps;

x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh;
}
37 changes: 37 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#version 450

#include "types.comp"
#include "generic_unary_head.comp"

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

void main() {
const uint idx = get_idx();

if (idx >= p.ne) {
return;
}

// Destination multi-index (inlined dst_idx)
const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
const uint i12_offset = i12*p.ne11*p.ne10;
const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
const uint d_idx = i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;

// Accumulate from sources
A_TYPE acc = A_TYPE(0);
for (uint i3 = i13; i3 < p.ne03; i3 += p.ne13) {
for (uint i2 = i12; i2 < p.ne02; i2 += p.ne12) {
for (uint i1 = i11; i1 < p.ne01; i1 += p.ne11) {
for (uint i0 = i10; i0 < p.ne00; i0 += p.ne10) {
acc += data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00];
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is get_aoffset() needed here? (I don't know)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably.

}
}
}
}

data_d[get_doffset() + d_idx] = D_TYPE(acc);
}
29 changes: 29 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/sub.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#version 450

#extension GL_EXT_shader_16bit_storage : require

#include "types.comp"
#include "generic_binary_head.comp"

const uint num_threads = 256;

layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;

void main() {
uint idx = get_idx();

// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
const uint num_iter = 2;

[[unroll]] for (uint i = 0; i < num_iter; ++i) {
if (idx >= p.ne) {
continue;
}
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);

data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) - FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));

idx += num_threads;
}
}
7 changes: 7 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,8 @@ void process_shaders() {
string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});

string_to_spv("sub_f32", "sub.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});

string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});

string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
Expand All @@ -452,6 +454,7 @@ void process_shaders() {
string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});

string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});

string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});

Expand Down Expand Up @@ -501,7 +504,9 @@ void process_shaders() {

string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});

string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));

string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
Expand All @@ -513,6 +518,8 @@ void process_shaders() {

string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));

string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));

for (auto &c : compiles) {
c.wait();
}
Expand Down
10 changes: 5 additions & 5 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1254,7 +1254,7 @@ struct test_count_equal : public test_case {
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(b, "b");

ggml_tensor * b_argmax = ggml_argmax(ctx, a);
ggml_tensor * b_argmax = ggml_argmax(ctx, b);
ggml_set_name(b_argmax, "b_argmax");

ggml_tensor * out = ggml_count_equal(ctx, a_argmax, b_argmax);
Expand Down Expand Up @@ -1511,6 +1511,7 @@ struct test_cont : public test_case {
};

// GGML_OP_ADD
// GGML_OP_SUB
// GGML_OP_MUL
// GGML_OP_DIV
struct test_bin_bcast : public test_case {
Expand Down Expand Up @@ -3860,7 +3861,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));

test_cases.emplace_back(new test_count_equal());
test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 500, 1, 1}));
test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 5000, 1, 1}));

test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1}));
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100, 10, 1, 1}));
Expand All @@ -3885,8 +3887,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 2, 1, 1}, view));
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 2, 1}, view));
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I32, {8, 6, 4, 2}, {2, 1, 1, 1}, view));
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I16, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
}

test_cases.emplace_back(new test_dup(GGML_TYPE_F32));
Expand Down Expand Up @@ -3938,7 +3938,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 3, 5 ,7}));

auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) {
for (auto op : {ggml_add, ggml_mul, ggml_div}) {
for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) {
test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr));
}
};
Expand Down
Loading