Skip to content

Implement aten.linear.default #3594

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions backends/vulkan/runtime/graph/ops/glsl/addmm_naive.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

#define PRECISION ${PRECISION}

$if MAT2_IS_TRANSPOSED:
#define MAT2_IS_TRANSPOSED

#include "indexing_utils.h"
#include "matmul.h"

Expand Down Expand Up @@ -45,24 +48,20 @@ void main() {
}

vec4 texel = vec4(0);
ivec3 mat1_pos = ivec3(0, pos.y, pos.z);

$if MAT1_PACKING == "W_packed":
$if MAT2_PACKING == "H_packed":
ivec3 mat2_pos = ivec3(pos.x * 4, 0, pos.z);
texel = matmul_naive_W_packed_H_packed(
im_mat1,
im_mat2,
mat1_pos,
mat2_pos,
pos,
in_sizes[0]);
$elif MAT2_PACKING == "W_packed":
ivec3 mat2_pos = ivec3(pos.x, 0, pos.z);
texel = matmul_naive_W_packed_W_packed(
im_mat1,
im_mat2,
mat1_pos,
mat2_pos,
pos,
in_sizes[0]);
$else:
$raise Exception("Unsupported value for MAT2_PACKING")
Expand Down
4 changes: 4 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/addmm_naive.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ addmm_naive:
NDIM: 3
MAT1_PACKING: W_packed
MAT2_PACKING: H_packed
MAT2_IS_TRANSPOSED: false
generate_variant_forall:
DTYPE:
- VALUE: float
Expand All @@ -18,3 +19,6 @@ addmm_naive:
- NAME: addmm_naive_W_packed_H_packed
- NAME: addmm_naive_W_packed_W_packed
MAT2_PACKING: W_packed
- NAME: linear_naive_W_packed_W_packed
MAT2_PACKING: W_packed
MAT2_IS_TRANSPOSED: true
22 changes: 10 additions & 12 deletions backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

#define PRECISION ${PRECISION}

$if MAT2_IS_TRANSPOSED:
#define MAT2_IS_TRANSPOSED

#include "indexing_utils.h"
#include "matmul.h"

Expand All @@ -31,11 +34,8 @@ layout(set = 0, binding = 6) uniform PRECISION restrict SelfSizes {
ivec4 self_sizes;
};

layout(set = 0, binding = 7) uniform PRECISION restrict PackedDimMeta {
int packed_dim_size;
int packed_dim_size_padded;
int packed_dim_texel_len;
int packed_dim_padding;
layout(set = 0, binding = 7) uniform PRECISION restrict InLimits {
ivec3 in_limits;
};

layout(set = 0, binding = 8) uniform PRECISION restrict Params {
Expand All @@ -57,8 +57,7 @@ void main() {
im_mat2,
pos,
out_sizes[2],
packed_dim_texel_len,
packed_dim_padding);
in_limits[0]);

for (int idx_c = 0; idx_c < FOUR; idx_c++) {
for (int idx_r = 0; idx_r < FOUR; idx_r++) {
Expand All @@ -70,17 +69,16 @@ void main() {
out_pos,
self_sizes.x == 1,
self_sizes.y == 1);
results.data[idx_c][idx_r][0] = beta * self_texel.x + alpha * results.data[idx_c][idx_r][0];

// results is in transposed order w.r.t. the desired output
imageStore(
im_out,
out_pos,
vec4(
results.data[idx_c][idx_r][0],
results.data[idx_c][idx_r][1],
results.data[idx_c][idx_r][2],
results.data[idx_c][idx_r][3]));
beta * self_texel.x + alpha * results.data[idx_c][idx_r][0],
beta * self_texel.x + alpha * results.data[idx_c][idx_r][1],
beta * self_texel.x + alpha * results.data[idx_c][idx_r][2],
beta * self_texel.x + alpha * results.data[idx_c][idx_r][3]));
}
}
}
3 changes: 3 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ addmm_optimized:
DTYPE: float
NDIM: 3
PACKING: C_packed
MAT2_IS_TRANSPOSED: false
generate_variant_forall:
DTYPE:
- VALUE: float
- VALUE: half
shader_variants:
- NAME: addmm_optimized
- NAME: linear_optimized
MAT2_IS_TRANSPOSED: true
102 changes: 61 additions & 41 deletions backends/vulkan/runtime/graph/ops/glsl/matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,38 +16,66 @@ struct FloatMatrix {
float data[FOUR][FOUR][FOUR];
};

#ifdef MAT2_IS_TRANSPOSED
vec4 matmul_naive_W_packed_W_packed(
#else
vec4 matmul_naive_W_packed_H_packed(
sampler3D im_mat1,
sampler3D im_mat2,
ivec3 mat1_pos,
ivec3 mat2_pos,
#endif
const sampler3D im_mat1,
const sampler3D im_mat2,
const ivec3 out_pos,
const int width) {
ivec3 mat1_pos = ivec3(0, out_pos.y, out_pos.z);
#ifdef MAT2_IS_TRANSPOSED
ivec3 mat2_pos = ivec3(0, out_pos.x * 4, 0);
#else
ivec3 mat2_pos = ivec3(out_pos.x * 4, 0, out_pos.z);
#endif

vec4 texel = vec4(0);
int K = (width + 3) / 4;
const int K = (width + 3) / 4;

for (int i = 0; i < K; ++i) {
vec4 mat1_tex = texelFetch(im_mat1, mat1_pos, 0);
vec4 sums = vec4(
const vec4 mat1_tex = texelFetch(im_mat1, mat1_pos, 0);
#ifdef MAT2_IS_TRANSPOSED
const vec4 sums = vec4(
dot(mat1_tex, texelFetch(im_mat2, mat2_pos, 0)),
dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(0, 1, 0), 0)),
dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(0, 2, 0), 0)),
dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(0, 3, 0), 0)));
#else
const vec4 sums = vec4(
dot(mat1_tex, texelFetch(im_mat2, mat2_pos, 0)),
dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(1, 0, 0), 0)),
dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(2, 0, 0), 0)),
dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(3, 0, 0), 0)));
#endif

texel += sums;

mat1_pos.x++;
#ifdef MAT2_IS_TRANSPOSED
mat2_pos.x++;
#else
mat2_pos.y++;
#endif
}

return texel;
}

#ifdef MAT2_IS_TRANSPOSED
vec4 matmul_naive_W_packed_H_packed(
#else
vec4 matmul_naive_W_packed_W_packed(
sampler3D im_mat1,
sampler3D im_mat2,
ivec3 mat1_pos,
ivec3 mat2_pos,
#endif
const sampler3D im_mat1,
const sampler3D im_mat2,
const ivec3 out_pos,
const int width) {
ivec3 mat1_pos = ivec3(0, out_pos.y, out_pos.z);
ivec3 mat2_pos = ivec3(out_pos.x, 0, out_pos.z);

vec4 texel = vec4(0);
int K = divup4(width);

Expand Down Expand Up @@ -87,7 +115,7 @@ vec4 get_texel_W_packed(
else if (broadcast_at_height) {
self_texel = texelFetch(im_self, ivec3(pos.x, 0, 0), 0);
} else {
self_texel = texelFetch(im_self, pos, 0);
self_texel = texelFetch(im_self, ivec3(pos.x, pos.y, 0), 0);
}

return self_texel;
Expand All @@ -112,7 +140,7 @@ vec4 get_texel_C_packed(
else if (broadcast_at_height) {
self_texel = texelFetch(im_self, ivec3(pos.x, 0, 0), 0);
} else {
self_texel = texelFetch(im_self, pos, 0);
self_texel = texelFetch(im_self, ivec3(pos.x, pos.y, 0), 0);
}

return self_texel;
Expand All @@ -123,8 +151,7 @@ FloatMatrix matmul_partial_4x4(
sampler3D im_mat2,
const ivec3 pos,
const int batch_size,
const int K_texel_len,
const int packed_dim_padding) {
const int K_texel_len) {
FloatMatrix results;
for (int i = 0; i < FOUR; i++) {
for (int j = 0; j < FOUR; j++) {
Expand All @@ -133,43 +160,36 @@ FloatMatrix matmul_partial_4x4(
}
}
}
vec4 im_mat1_partial_rows[FOUR];
vec4 im_mat2_partial_cols[FOUR];
vec4 im_mat1_partial_load[FOUR];
vec4 im_mat2_partial_load[FOUR];

for (int batch_idx = 0; batch_idx < FOUR; batch_idx++) {
if (FOUR * pos.z + batch_idx >= batch_size) {
break;
}
// read and cache 4x4 tile of im_mat1 (4 adjacent rows)
int mat_z = FOUR * pos.z + batch_idx;
for (int mat1_x = 0; mat1_x < K_texel_len; mat1_x++) {
for (int mat1_row = 0; mat1_row < FOUR; mat1_row++) {
const int mat1_y = (FOUR * pos.y) + mat1_row;
const ivec3 mat1_pos = ivec3(mat1_x, mat1_y, FOUR * pos.z + batch_idx);
im_mat1_partial_rows[mat1_row] = texelFetch(im_mat1, mat1_pos, 0);
// set the value out of the boundary to be 0
if (mat1_x == K_texel_len - 1 && packed_dim_padding > 0) {
for (int kk = 0; kk < packed_dim_padding; kk++) {
im_mat1_partial_rows[mat1_row][3 - kk] = 0;
}
}
}
// read and cache 4x4 tile of im_mat2 (4 adjacent columns)
for (int mat2_col = 0; mat2_col < FOUR; mat2_col++) {
const int mat2_x = (FOUR * pos.x) + mat2_col;
const ivec3 pos_rd = ivec3(mat2_x, mat1_x, FOUR * pos.z + batch_idx);
im_mat2_partial_cols[mat2_col] = texelFetch(im_mat2, pos_rd, 0);
// set the value out of the boundary to be 0
if (mat1_x == K_texel_len - 1 && packed_dim_padding > 0) {
for (int kk = 0; kk < packed_dim_padding; kk++) {
im_mat2_partial_cols[mat2_col][3 - kk] = 0;
}
}
for (int offset = 0; offset < FOUR; offset++) {
// read and cache 4x4 tile of im_mat1
const int mat1_y = (FOUR * pos.y) + offset;
const ivec3 mat1_pos = ivec3(mat1_x, mat1_y, mat_z);
im_mat1_partial_load[offset] = texelFetch(im_mat1, mat1_pos, 0);
// read and cache 4x4 tile of im_mat2
#ifdef MAT2_IS_TRANSPOSED
const int mat2_y = (FOUR * pos.x) + offset;
const ivec3 mat2_pos = ivec3(mat1_x, mat2_y, 0);
im_mat2_partial_load[offset] = texelFetch(im_mat2, mat2_pos, 0);
#else
const int mat2_x = (FOUR * pos.x) + offset;
const ivec3 mat2_pos = ivec3(mat2_x, mat1_x, mat_z);
im_mat2_partial_load[offset] = texelFetch(im_mat2, mat2_pos, 0);
#endif
}
// perform partial dot products and add partial result to results
for (int out_row = 0; out_row < FOUR; out_row++) {
for (int out_col = 0; out_col < FOUR; out_col++) {
results.data[out_row][out_col][batch_idx] +=
dot(im_mat1_partial_rows[out_row], im_mat2_partial_cols[out_col]);
dot(im_mat1_partial_load[out_row], im_mat2_partial_load[out_col]);
}
}
}
Expand Down
12 changes: 5 additions & 7 deletions backends/vulkan/runtime/graph/ops/glsl/matmul_naive.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

#define PRECISION ${PRECISION}

$if MAT2_IS_TRANSPOSED:
#define MAT2_IS_TRANSPOSED

#include "indexing_utils.h"
#include "matmul.h"

Expand All @@ -35,24 +38,19 @@ void main() {
}

vec4 texel = vec4(0);
ivec3 mat1_pos = ivec3(0, pos.y, pos.z);

$if MAT1_PACKING == "W_packed":
$if MAT2_PACKING == "H_packed":
ivec3 mat2_pos = ivec3(pos.x * 4, 0, pos.z);
texel = matmul_naive_W_packed_H_packed(
im_mat1,
im_mat2,
mat1_pos,
mat2_pos,
pos,
in_sizes[0]);
$elif MAT2_PACKING == "W_packed":
ivec3 mat2_pos = ivec3(pos.x, 0, pos.z);
texel = matmul_naive_W_packed_W_packed(
im_mat1,
im_mat2,
mat1_pos,
mat2_pos,
pos,
in_sizes[0]);
$else:
$raise Exception("Unsupported value for MAT2_PACKING")
Expand Down
4 changes: 4 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/matmul_naive.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ matmul_naive:
NDIM: 3
MAT1_PACKING: W_packed
MAT2_PACKING: H_packed
MAT2_IS_TRANSPOSED: false
generate_variant_forall:
DTYPE:
- VALUE: float
Expand All @@ -18,3 +19,6 @@ matmul_naive:
- NAME: matmul_naive_W_packed_H_packed
- NAME: matmul_naive_W_packed_W_packed
MAT2_PACKING: W_packed
- NAME: matmul_transposed_naive_W_packed_W_packed
MAT2_PACKING: W_packed
MAT2_IS_TRANSPOSED: true
13 changes: 6 additions & 7 deletions backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

#define PRECISION ${PRECISION}

$if MAT2_IS_TRANSPOSED:
#define MAT2_IS_TRANSPOSED

#include "indexing_utils.h"
#include "matmul.h"

Expand All @@ -25,11 +28,8 @@ layout(set = 0, binding = 4) uniform PRECISION restrict OutSizes {
ivec4 out_sizes;
};

layout(set = 0, binding = 5) uniform PRECISION restrict PackedDimMeta {
int packed_dim_size;
int packed_dim_size_padded;
int packed_dim_texel_len;
int packed_dim_padding;
layout(set = 0, binding = 5) uniform PRECISION restrict InLimits {
ivec3 in_limits;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
Expand All @@ -46,8 +46,7 @@ void main() {
im_mat2,
pos,
out_sizes[2],
packed_dim_texel_len,
packed_dim_padding);
in_limits[0]);

for (int idx_c = 0; idx_c < FOUR; idx_c++) {
for (int idx_r = 0; idx_r < FOUR; idx_r++) {
Expand Down
3 changes: 3 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ matmul_optimized:
DTYPE: float
NDIM: 3
PACKING: C_packed
MAT2_IS_TRANSPOSED: false
generate_variant_forall:
DTYPE:
- VALUE: float
- VALUE: half
shader_variants:
- NAME: matmul_optimized
- NAME: matmul_transposed_optimized
MAT2_IS_TRANSPOSED: true
Loading
Loading