
Commit 0e7955d

SS-JIA authored and facebook-github-bot committed
Implement aten.linear.default (#3594)
Summary:
Pull Request resolved: #3594

As title. The implementation is fairly simple: the shaders only need to accumulate along `mat2`'s width dim rather than its height dim.

Reviewed By: yipjustin

Differential Revision: D57203869

fbshipit-source-id: 08932a75e66924a0dfb0816f8ccefa718a341dd8
1 parent e8a520c commit 0e7955d

15 files changed: +330 -131 lines
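
For context: `aten.linear.default` computes `mat1 @ mat2^T` (plus an optional bias), with the weight stored as `[out_features, in_features]`. That is why the existing matmul shaders can serve linear by accumulating along `mat2`'s width dim: each output element is a dot product of a `mat1` row with a `mat2` row. A minimal PyTorch-level sketch of that equivalence (not part of the commit; sizes are arbitrary and chosen only for the check):

import torch
import torch.nn.functional as F

M, K, N = 4, 16, 8          # arbitrary sizes for this check
x = torch.randn(M, K)       # mat1
w = torch.randn(N, K)       # linear weight, stored as [out_features, in_features]
b = torch.randn(N)

out_linear = F.linear(x, w, b)  # computes x @ w.T + b

# "Transposed matmul" view: every output element is a dot product of a row of
# x with a row of w, i.e. both operands are accumulated along their width (K) dim.
out_manual = torch.empty(M, N)
for m in range(M):
    for n in range(N):
        out_manual[m, n] = torch.dot(x[m], w[n]) + b[n]

assert torch.allclose(out_linear, out_manual, atol=1e-5)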

backends/vulkan/runtime/graph/ops/glsl/addmm_naive.glsl

+5 -6

@@ -10,6 +10,9 @@

 #define PRECISION ${PRECISION}

+$if MAT2_IS_TRANSPOSED:
+  #define MAT2_IS_TRANSPOSED
+
 #include "indexing_utils.h"
 #include "matmul.h"

@@ -45,24 +48,20 @@ void main() {
   }

   vec4 texel = vec4(0);
-  ivec3 mat1_pos = ivec3(0, pos.y, pos.z);

   $if MAT1_PACKING == "W_packed":
     $if MAT2_PACKING == "H_packed":
       ivec3 mat2_pos = ivec3(pos.x * 4, 0, pos.z);
       texel = matmul_naive_W_packed_H_packed(
           im_mat1,
           im_mat2,
-          mat1_pos,
-          mat2_pos,
+          pos,
           in_sizes[0]);
     $elif MAT2_PACKING == "W_packed":
-      ivec3 mat2_pos = ivec3(pos.x, 0, pos.z);
       texel = matmul_naive_W_packed_W_packed(
           im_mat1,
           im_mat2,
-          mat1_pos,
-          mat2_pos,
+          pos,
           in_sizes[0]);
     $else:
       $raise Exception("Unsupported value for MAT2_PACKING")

backends/vulkan/runtime/graph/ops/glsl/addmm_naive.yaml

+4

@@ -10,6 +10,7 @@ addmm_naive:
     NDIM: 3
     MAT1_PACKING: W_packed
     MAT2_PACKING: H_packed
+    MAT2_IS_TRANSPOSED: false
   generate_variant_forall:
     DTYPE:
       - VALUE: float
@@ -18,3 +19,6 @@ addmm_naive:
   - NAME: addmm_naive_W_packed_H_packed
   - NAME: addmm_naive_W_packed_W_packed
     MAT2_PACKING: W_packed
+  - NAME: linear_naive_W_packed_W_packed
+    MAT2_PACKING: W_packed
+    MAT2_IS_TRANSPOSED: true

backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.glsl

+10 -12

@@ -10,6 +10,9 @@

 #define PRECISION ${PRECISION}

+$if MAT2_IS_TRANSPOSED:
+  #define MAT2_IS_TRANSPOSED
+
 #include "indexing_utils.h"
 #include "matmul.h"

@@ -31,11 +34,8 @@ layout(set = 0, binding = 6) uniform PRECISION restrict SelfSizes {
   ivec4 self_sizes;
 };

-layout(set = 0, binding = 7) uniform PRECISION restrict PackedDimMeta {
-  int packed_dim_size;
-  int packed_dim_size_padded;
-  int packed_dim_texel_len;
-  int packed_dim_padding;
+layout(set = 0, binding = 7) uniform PRECISION restrict InLimits {
+  ivec3 in_limits;
 };

 layout(set = 0, binding = 8) uniform PRECISION restrict Params {
@@ -57,8 +57,7 @@ void main() {
       im_mat2,
       pos,
       out_sizes[2],
-      packed_dim_texel_len,
-      packed_dim_padding);
+      in_limits[0]);

   for (int idx_c = 0; idx_c < FOUR; idx_c++) {
     for (int idx_r = 0; idx_r < FOUR; idx_r++) {
@@ -70,17 +69,16 @@
           out_pos,
           self_sizes.x == 1,
           self_sizes.y == 1);
-      results.data[idx_c][idx_r][0] = beta * self_texel.x + alpha * results.data[idx_c][idx_r][0];

       // results is in transposed order w.r.t. the desired output
       imageStore(
           im_out,
           out_pos,
           vec4(
-              results.data[idx_c][idx_r][0],
-              results.data[idx_c][idx_r][1],
-              results.data[idx_c][idx_r][2],
-              results.data[idx_c][idx_r][3]));
+              beta * self_texel.x + alpha * results.data[idx_c][idx_r][0],
+              beta * self_texel.x + alpha * results.data[idx_c][idx_r][1],
+              beta * self_texel.x + alpha * results.data[idx_c][idx_r][2],
+              beta * self_texel.x + alpha * results.data[idx_c][idx_r][3]));
     }
   }
 }
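
Note, not from the commit message: with the separate `results.data[idx_c][idx_r][0] = beta * self_texel.x + alpha * ...` line removed, the optimized shader now folds the `beta`/`alpha` scaling into the store, applying it to every component of the output texel. As a reminder, the `addmm` semantics this maps onto are `beta * self + alpha * (mat1 @ mat2)`; a small PyTorch-level check (sizes arbitrary, for illustration only):

import torch

M, K, N = 4, 16, 8
self_t = torch.randn(M, N)
mat1 = torch.randn(M, K)
mat2 = torch.randn(K, N)
alpha, beta = 0.5, 2.0

# addmm computes beta * self + alpha * (mat1 @ mat2)
out = torch.addmm(self_t, mat1, mat2, beta=beta, alpha=alpha)
assert torch.allclose(out, beta * self_t + alpha * (mat1 @ mat2), atol=1e-5)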

backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.yaml

+3

@@ -9,9 +9,12 @@ addmm_optimized:
     DTYPE: float
     NDIM: 3
     PACKING: C_packed
+    MAT2_IS_TRANSPOSED: false
   generate_variant_forall:
     DTYPE:
       - VALUE: float
       - VALUE: half
   shader_variants:
   - NAME: addmm_optimized
+  - NAME: linear_optimized
+    MAT2_IS_TRANSPOSED: true

backends/vulkan/runtime/graph/ops/glsl/matmul.h

+61 -41

@@ -16,38 +16,66 @@ struct FloatMatrix {
   float data[FOUR][FOUR][FOUR];
 };

+#ifdef MAT2_IS_TRANSPOSED
+vec4 matmul_naive_W_packed_W_packed(
+#else
 vec4 matmul_naive_W_packed_H_packed(
-    sampler3D im_mat1,
-    sampler3D im_mat2,
-    ivec3 mat1_pos,
-    ivec3 mat2_pos,
+#endif
+    const sampler3D im_mat1,
+    const sampler3D im_mat2,
+    const ivec3 out_pos,
     const int width) {
+  ivec3 mat1_pos = ivec3(0, out_pos.y, out_pos.z);
+#ifdef MAT2_IS_TRANSPOSED
+  ivec3 mat2_pos = ivec3(0, out_pos.x * 4, 0);
+#else
+  ivec3 mat2_pos = ivec3(out_pos.x * 4, 0, out_pos.z);
+#endif
+
   vec4 texel = vec4(0);
-  int K = (width + 3) / 4;
+  const int K = (width + 3) / 4;

   for (int i = 0; i < K; ++i) {
-    vec4 mat1_tex = texelFetch(im_mat1, mat1_pos, 0);
-    vec4 sums = vec4(
+    const vec4 mat1_tex = texelFetch(im_mat1, mat1_pos, 0);
+#ifdef MAT2_IS_TRANSPOSED
+    const vec4 sums = vec4(
+        dot(mat1_tex, texelFetch(im_mat2, mat2_pos, 0)),
+        dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(0, 1, 0), 0)),
+        dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(0, 2, 0), 0)),
+        dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(0, 3, 0), 0)));
+#else
+    const vec4 sums = vec4(
         dot(mat1_tex, texelFetch(im_mat2, mat2_pos, 0)),
         dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(1, 0, 0), 0)),
         dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(2, 0, 0), 0)),
         dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(3, 0, 0), 0)));
+#endif

     texel += sums;

     mat1_pos.x++;
+#ifdef MAT2_IS_TRANSPOSED
+    mat2_pos.x++;
+#else
     mat2_pos.y++;
+#endif
   }

   return texel;
 }

+#ifdef MAT2_IS_TRANSPOSED
+vec4 matmul_naive_W_packed_H_packed(
+#else
 vec4 matmul_naive_W_packed_W_packed(
-    sampler3D im_mat1,
-    sampler3D im_mat2,
-    ivec3 mat1_pos,
-    ivec3 mat2_pos,
+#endif
+    const sampler3D im_mat1,
+    const sampler3D im_mat2,
+    const ivec3 out_pos,
     const int width) {
+  ivec3 mat1_pos = ivec3(0, out_pos.y, out_pos.z);
+  ivec3 mat2_pos = ivec3(out_pos.x, 0, out_pos.z);
+
   vec4 texel = vec4(0);
   int K = divup4(width);

@@ -87,7 +115,7 @@ vec4 get_texel_W_packed(
   else if (broadcast_at_height) {
     self_texel = texelFetch(im_self, ivec3(pos.x, 0, 0), 0);
   } else {
-    self_texel = texelFetch(im_self, pos, 0);
+    self_texel = texelFetch(im_self, ivec3(pos.x, pos.y, 0), 0);
   }

   return self_texel;
@@ -112,7 +140,7 @@ vec4 get_texel_C_packed(
   else if (broadcast_at_height) {
     self_texel = texelFetch(im_self, ivec3(pos.x, 0, 0), 0);
   } else {
-    self_texel = texelFetch(im_self, pos, 0);
+    self_texel = texelFetch(im_self, ivec3(pos.x, pos.y, 0), 0);
   }

   return self_texel;
@@ -123,8 +151,7 @@ FloatMatrix matmul_partial_4x4(
     sampler3D im_mat2,
     const ivec3 pos,
     const int batch_size,
-    const int K_texel_len,
-    const int packed_dim_padding) {
+    const int K_texel_len) {
   FloatMatrix results;
   for (int i = 0; i < FOUR; i++) {
     for (int j = 0; j < FOUR; j++) {
@@ -133,43 +160,36 @@ FloatMatrix matmul_partial_4x4(
       }
     }
   }
-  vec4 im_mat1_partial_rows[FOUR];
-  vec4 im_mat2_partial_cols[FOUR];
+  vec4 im_mat1_partial_load[FOUR];
+  vec4 im_mat2_partial_load[FOUR];

   for (int batch_idx = 0; batch_idx < FOUR; batch_idx++) {
     if (FOUR * pos.z + batch_idx >= batch_size) {
       break;
     }
-    // read and cache 4x4 tile of im_mat1 (4 adjacent rows)
+    int mat_z = FOUR * pos.z + batch_idx;
     for (int mat1_x = 0; mat1_x < K_texel_len; mat1_x++) {
-      for (int mat1_row = 0; mat1_row < FOUR; mat1_row++) {
-        const int mat1_y = (FOUR * pos.y) + mat1_row;
-        const ivec3 mat1_pos = ivec3(mat1_x, mat1_y, FOUR * pos.z + batch_idx);
-        im_mat1_partial_rows[mat1_row] = texelFetch(im_mat1, mat1_pos, 0);
-        // set the value out of the boundary to be 0
-        if (mat1_x == K_texel_len - 1 && packed_dim_padding > 0) {
-          for (int kk = 0; kk < packed_dim_padding; kk++) {
-            im_mat1_partial_rows[mat1_row][3 - kk] = 0;
-          }
-        }
-      }
-      // read and cache 4x4 tile of im_mat2 (4 adjacent columns)
-      for (int mat2_col = 0; mat2_col < FOUR; mat2_col++) {
-        const int mat2_x = (FOUR * pos.x) + mat2_col;
-        const ivec3 pos_rd = ivec3(mat2_x, mat1_x, FOUR * pos.z + batch_idx);
-        im_mat2_partial_cols[mat2_col] = texelFetch(im_mat2, pos_rd, 0);
-        // set the value out of the boundary to be 0
-        if (mat1_x == K_texel_len - 1 && packed_dim_padding > 0) {
-          for (int kk = 0; kk < packed_dim_padding; kk++) {
-            im_mat2_partial_cols[mat2_col][3 - kk] = 0;
-          }
-        }
+      for (int offset = 0; offset < FOUR; offset++) {
+        // read and cache 4x4 tile of im_mat1
+        const int mat1_y = (FOUR * pos.y) + offset;
+        const ivec3 mat1_pos = ivec3(mat1_x, mat1_y, mat_z);
+        im_mat1_partial_load[offset] = texelFetch(im_mat1, mat1_pos, 0);
+        // read and cache 4x4 tile of im_mat2
+#ifdef MAT2_IS_TRANSPOSED
+        const int mat2_y = (FOUR * pos.x) + offset;
+        const ivec3 mat2_pos = ivec3(mat1_x, mat2_y, 0);
+        im_mat2_partial_load[offset] = texelFetch(im_mat2, mat2_pos, 0);
+#else
+        const int mat2_x = (FOUR * pos.x) + offset;
+        const ivec3 mat2_pos = ivec3(mat2_x, mat1_x, mat_z);
+        im_mat2_partial_load[offset] = texelFetch(im_mat2, mat2_pos, 0);
+#endif
       }
       // perform partial dot products and add partial result to results
       for (int out_row = 0; out_row < FOUR; out_row++) {
         for (int out_col = 0; out_col < FOUR; out_col++) {
           results.data[out_row][out_col][batch_idx] +=
-              dot(im_mat1_partial_rows[out_row], im_mat2_partial_cols[out_col]);
+              dot(im_mat1_partial_load[out_row], im_mat2_partial_load[out_col]);
         }
       }
     }

backends/vulkan/runtime/graph/ops/glsl/matmul_naive.glsl

+5 -7

@@ -10,6 +10,9 @@

 #define PRECISION ${PRECISION}

+$if MAT2_IS_TRANSPOSED:
+  #define MAT2_IS_TRANSPOSED
+
 #include "indexing_utils.h"
 #include "matmul.h"

@@ -35,24 +38,19 @@ void main() {
   }

   vec4 texel = vec4(0);
-  ivec3 mat1_pos = ivec3(0, pos.y, pos.z);

   $if MAT1_PACKING == "W_packed":
     $if MAT2_PACKING == "H_packed":
-      ivec3 mat2_pos = ivec3(pos.x * 4, 0, pos.z);
       texel = matmul_naive_W_packed_H_packed(
           im_mat1,
           im_mat2,
-          mat1_pos,
-          mat2_pos,
+          pos,
           in_sizes[0]);
     $elif MAT2_PACKING == "W_packed":
-      ivec3 mat2_pos = ivec3(pos.x, 0, pos.z);
       texel = matmul_naive_W_packed_W_packed(
           im_mat1,
           im_mat2,
-          mat1_pos,
-          mat2_pos,
+          pos,
           in_sizes[0]);
     $else:
       $raise Exception("Unsupported value for MAT2_PACKING")

backends/vulkan/runtime/graph/ops/glsl/matmul_naive.yaml

+4

@@ -10,6 +10,7 @@ matmul_naive:
     NDIM: 3
     MAT1_PACKING: W_packed
     MAT2_PACKING: H_packed
+    MAT2_IS_TRANSPOSED: false
   generate_variant_forall:
     DTYPE:
       - VALUE: float
@@ -18,3 +19,6 @@ matmul_naive:
   - NAME: matmul_naive_W_packed_H_packed
   - NAME: matmul_naive_W_packed_W_packed
     MAT2_PACKING: W_packed
+  - NAME: matmul_transposed_naive_W_packed_W_packed
+    MAT2_PACKING: W_packed
+    MAT2_IS_TRANSPOSED: true

backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.glsl

+6 -7

@@ -10,6 +10,9 @@

 #define PRECISION ${PRECISION}

+$if MAT2_IS_TRANSPOSED:
+  #define MAT2_IS_TRANSPOSED
+
 #include "indexing_utils.h"
 #include "matmul.h"

@@ -25,11 +28,8 @@ layout(set = 0, binding = 4) uniform PRECISION restrict OutSizes {
   ivec4 out_sizes;
 };

-layout(set = 0, binding = 5) uniform PRECISION restrict PackedDimMeta {
-  int packed_dim_size;
-  int packed_dim_size_padded;
-  int packed_dim_texel_len;
-  int packed_dim_padding;
+layout(set = 0, binding = 5) uniform PRECISION restrict InLimits {
+  ivec3 in_limits;
 };

 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
@@ -46,8 +46,7 @@ void main() {
      im_mat2,
      pos,
      out_sizes[2],
-      packed_dim_texel_len,
-      packed_dim_padding);
+      in_limits[0]);

   for (int idx_c = 0; idx_c < FOUR; idx_c++) {
     for (int idx_r = 0; idx_r < FOUR; idx_r++) {

backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.yaml

+3

@@ -9,9 +9,12 @@ matmul_optimized:
     DTYPE: float
     NDIM: 3
     PACKING: C_packed
+    MAT2_IS_TRANSPOSED: false
   generate_variant_forall:
     DTYPE:
       - VALUE: float
       - VALUE: half
   shader_variants:
   - NAME: matmul_optimized
+  - NAME: matmul_transposed_optimized
+    MAT2_IS_TRANSPOSED: true
