Skip to content

Commit 922115c

Browse files
committed
vulkan : handle ggml_scale for n%8 != 0
ref ggml-org#3754
1 parent 0aa04a2 commit 922115c

File tree

4 files changed

+56
-16
lines changed

4 files changed

+56
-16
lines changed

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,7 @@ if (LLAMA_KOMPUTE)
476476
# Compile our shaders
477477
compile_shader(SOURCES
478478
kompute/op_scale.comp
479+
kompute/op_scale_8.comp
479480
kompute/op_add.comp
480481
kompute/op_addrow.comp
481482
kompute/op_mul.comp
@@ -508,6 +509,7 @@ if (LLAMA_KOMPUTE)
508509
# Create a custom target for our generated shaders
509510
add_custom_target(generated_shaders DEPENDS
510511
shaderop_scale.h
512+
shaderop_scale_8.h
511513
shaderop_add.h
512514
shaderop_addrow.h
513515
shaderop_mul.h

ggml-vulkan.cpp

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
// These are generated at build time by cmake custom command
1313
#include "shaderop_scale.h"
14+
#include "shaderop_scale_8.h"
1415
#include "shaderop_add.h"
1516
#include "shaderop_addrow.h"
1617
#include "shaderop_mul.h"
@@ -724,8 +725,12 @@ void ggml_vk_scale(kp::Sequence& seq,
724725
const std::shared_ptr<kp::Tensor>& out,
725726
uint32_t inOff, uint32_t outOff,
726727
uint32_t size, float scale) {
727-
const static auto spirv = getSpirvShader(kp::shader_data::op_scale_comp_spv,
728-
kp::shader_data::op_scale_comp_spv_len);
728+
const static auto spirv_1 = getSpirvShader(
729+
kp::shader_data::op_scale_comp_spv, kp::shader_data::op_scale_comp_spv_len
730+
);
731+
const static auto spirv_8 = getSpirvShader(
732+
kp::shader_data::op_scale_8_comp_spv, kp::shader_data::op_scale_8_comp_spv_len
733+
);
729734

730735
struct PushConstants {
731736
uint32_t inOff, outOff;
@@ -735,11 +740,19 @@ void ggml_vk_scale(kp::Sequence& seq,
735740
scale
736741
};
737742

743+
const auto * spirv = &spirv_1;
744+
std::string name(__func__);
745+
if (size % 8 == 0) {
746+
size /= 8;
747+
name += "_8";
748+
spirv = &spirv_8;
749+
}
750+
738751
std::shared_ptr<kp::Algorithm> s_algo = nullptr;
739-
if (!komputeManager()->hasAlgorithm(__func__))
740-
s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {in, out}, spirv, {size}, {}, {pushConsts});
741-
else {
742-
s_algo = komputeManager()->getAlgorithm(__func__);
752+
if (!komputeManager()->hasAlgorithm(name)) {
753+
s_algo = komputeManager()->algorithm<float, PushConstants>(__func__, s_kompute_context->pool.get(), {in, out}, *spirv, {size}, {}, {pushConsts});
754+
} else {
755+
s_algo = komputeManager()->getAlgorithm(name);
743756
s_algo->setTensors({in, out});
744757
s_algo->setWorkgroup({size});
745758
s_algo->setPushConstants<PushConstants>({pushConsts});
@@ -1416,9 +1429,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
14161429
case GGML_OP_SCALE:
14171430
{
14181431
const float scale = *(const float *) src1->data;
1419-
int64_t n = ggml_nelements(dst);
1420-
GGML_ASSERT(n % 8 == 0);
1421-
ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, n/8, scale);
1432+
ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst), scale);
14221433
} break;
14231434
case GGML_OP_UNARY:
14241435
{

kompute/op_scale.comp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,6 @@ layout(push_constant) uniform PushConstants {
2222
} pcs;
2323

2424
void main() {
25-
const uint baseIndex = gl_WorkGroupID.x * 8;
26-
27-
for (uint x = 0; x < 8; x++) {
28-
const uint i = baseIndex + x;
29-
out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
30-
}
31-
}
25+
const uint i = gl_WorkGroupID.x;
26+
out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
27+
}

kompute/op_scale_8.comp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/**
2+
* Copyright (c) 2023 Nomic, Inc. All rights reserved.
3+
*
4+
* This software is licensed under the terms of the Software for Open Models License (SOM),
5+
* version 1.0, as detailed in the LICENSE_SOM.txt file. A copy of this license should accompany
6+
* this software. Except as expressly granted in the SOM license, all rights are reserved by Nomic, Inc.
7+
*/
8+
9+
#version 450
10+
11+
#include "common.comp"
12+
13+
layout(local_size_x = 1) in;
14+
15+
layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
16+
layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
17+
18+
layout(push_constant) uniform PushConstants {
19+
uint inOff;
20+
uint outOff;
21+
float scale;
22+
} pcs;
23+
24+
void main() {
25+
const uint baseIndex = gl_WorkGroupID.x * 8;
26+
27+
for (uint x = 0; x < 8; x++) {
28+
const uint i = baseIndex + x;
29+
out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
30+
}
31+
}

0 commit comments

Comments
 (0)