@@ -16,38 +16,66 @@ struct FloatMatrix {
16
16
float data [FOUR ][FOUR ][FOUR ];
17
17
};
18
18
19
+ #ifdef MAT2_IS_TRANSPOSED
20
+ vec4 matmul_naive_W_packed_W_packed (
21
+ #else
19
22
vec4 matmul_naive_W_packed_H_packed (
20
- sampler3D im_mat1 ,
21
- sampler3D im_mat2 ,
22
- ivec3 mat1_pos ,
23
- ivec3 mat2_pos ,
23
+ #endif
24
+ const sampler3D im_mat1 ,
25
+ const sampler3D im_mat2 ,
26
+ const ivec3 out_pos ,
24
27
const int width ) {
28
+ ivec3 mat1_pos = ivec3 (0 , out_pos .y , out_pos .z );
29
+ #ifdef MAT2_IS_TRANSPOSED
30
+ ivec3 mat2_pos = ivec3 (0 , out_pos .x * 4 , 0 );
31
+ #else
32
+ ivec3 mat2_pos = ivec3 (out_pos .x * 4 , 0 , out_pos .z );
33
+ #endif
34
+
25
35
vec4 texel = vec4 (0 );
26
- int K = (width + 3 ) / 4 ;
36
+ const int K = (width + 3 ) / 4 ;
27
37
28
38
for (int i = 0 ; i < K ; ++ i ) {
29
- vec4 mat1_tex = texelFetch (im_mat1 , mat1_pos , 0 );
30
- vec4 sums = vec4 (
39
+ const vec4 mat1_tex = texelFetch (im_mat1 , mat1_pos , 0 );
40
+ #ifdef MAT2_IS_TRANSPOSED
41
+ const vec4 sums = vec4 (
42
+ dot (mat1_tex , texelFetch (im_mat2 , mat2_pos , 0 )),
43
+ dot (mat1_tex , texelFetch (im_mat2 , mat2_pos + ivec3 (0 , 1 , 0 ), 0 )),
44
+ dot (mat1_tex , texelFetch (im_mat2 , mat2_pos + ivec3 (0 , 2 , 0 ), 0 )),
45
+ dot (mat1_tex , texelFetch (im_mat2 , mat2_pos + ivec3 (0 , 3 , 0 ), 0 )));
46
+ #else
47
+ const vec4 sums = vec4 (
31
48
dot (mat1_tex , texelFetch (im_mat2 , mat2_pos , 0 )),
32
49
dot (mat1_tex , texelFetch (im_mat2 , mat2_pos + ivec3 (1 , 0 , 0 ), 0 )),
33
50
dot (mat1_tex , texelFetch (im_mat2 , mat2_pos + ivec3 (2 , 0 , 0 ), 0 )),
34
51
dot (mat1_tex , texelFetch (im_mat2 , mat2_pos + ivec3 (3 , 0 , 0 ), 0 )));
52
+ #endif
35
53
36
54
texel += sums ;
37
55
38
56
mat1_pos .x ++ ;
57
+ #ifdef MAT2_IS_TRANSPOSED
58
+ mat2_pos .x ++ ;
59
+ #else
39
60
mat2_pos .y ++ ;
61
+ #endif
40
62
}
41
63
42
64
return texel ;
43
65
}
44
66
67
+ #ifdef MAT2_IS_TRANSPOSED
68
+ vec4 matmul_naive_W_packed_H_packed (
69
+ #else
45
70
vec4 matmul_naive_W_packed_W_packed (
46
- sampler3D im_mat1 ,
47
- sampler3D im_mat2 ,
48
- ivec3 mat1_pos ,
49
- ivec3 mat2_pos ,
71
+ #endif
72
+ const sampler3D im_mat1 ,
73
+ const sampler3D im_mat2 ,
74
+ const ivec3 out_pos ,
50
75
const int width ) {
76
+ ivec3 mat1_pos = ivec3 (0 , out_pos .y , out_pos .z );
77
+ ivec3 mat2_pos = ivec3 (out_pos .x , 0 , out_pos .z );
78
+
51
79
vec4 texel = vec4 (0 );
52
80
int K = divup4 (width );
53
81
@@ -87,7 +115,7 @@ vec4 get_texel_W_packed(
87
115
else if (broadcast_at_height ) {
88
116
self_texel = texelFetch (im_self , ivec3 (pos .x , 0 , 0 ), 0 );
89
117
} else {
90
- self_texel = texelFetch (im_self , pos , 0 );
118
+ self_texel = texelFetch (im_self , ivec3 ( pos . x , pos . y , 0 ) , 0 );
91
119
}
92
120
93
121
return self_texel ;
@@ -112,7 +140,7 @@ vec4 get_texel_C_packed(
112
140
else if (broadcast_at_height ) {
113
141
self_texel = texelFetch (im_self , ivec3 (pos .x , 0 , 0 ), 0 );
114
142
} else {
115
- self_texel = texelFetch (im_self , pos , 0 );
143
+ self_texel = texelFetch (im_self , ivec3 ( pos . x , pos . y , 0 ) , 0 );
116
144
}
117
145
118
146
return self_texel ;
@@ -123,8 +151,7 @@ FloatMatrix matmul_partial_4x4(
123
151
sampler3D im_mat2 ,
124
152
const ivec3 pos ,
125
153
const int batch_size ,
126
- const int K_texel_len ,
127
- const int packed_dim_padding ) {
154
+ const int K_texel_len ) {
128
155
FloatMatrix results ;
129
156
for (int i = 0 ; i < FOUR ; i ++ ) {
130
157
for (int j = 0 ; j < FOUR ; j ++ ) {
@@ -133,43 +160,36 @@ FloatMatrix matmul_partial_4x4(
133
160
}
134
161
}
135
162
}
136
- vec4 im_mat1_partial_rows [FOUR ];
137
- vec4 im_mat2_partial_cols [FOUR ];
163
+ vec4 im_mat1_partial_load [FOUR ];
164
+ vec4 im_mat2_partial_load [FOUR ];
138
165
139
166
for (int batch_idx = 0 ; batch_idx < FOUR ; batch_idx ++ ) {
140
167
if (FOUR * pos .z + batch_idx >= batch_size ) {
141
168
break ;
142
169
}
143
- // read and cache 4x4 tile of im_mat1 (4 adjacent rows)
170
+ int mat_z = FOUR * pos . z + batch_idx ;
144
171
for (int mat1_x = 0 ; mat1_x < K_texel_len ; mat1_x ++ ) {
145
- for (int mat1_row = 0 ; mat1_row < FOUR ; mat1_row ++ ) {
146
- const int mat1_y = (FOUR * pos .y ) + mat1_row ;
147
- const ivec3 mat1_pos = ivec3 (mat1_x , mat1_y , FOUR * pos .z + batch_idx );
148
- im_mat1_partial_rows [mat1_row ] = texelFetch (im_mat1 , mat1_pos , 0 );
149
- // set the value out of the boundary to be 0
150
- if (mat1_x == K_texel_len - 1 && packed_dim_padding > 0 ) {
151
- for (int kk = 0 ; kk < packed_dim_padding ; kk ++ ) {
152
- im_mat1_partial_rows [mat1_row ][3 - kk ] = 0 ;
153
- }
154
- }
155
- }
156
- // read and cache 4x4 tile of im_mat2 (4 adjacent columns)
157
- for (int mat2_col = 0 ; mat2_col < FOUR ; mat2_col ++ ) {
158
- const int mat2_x = (FOUR * pos .x ) + mat2_col ;
159
- const ivec3 pos_rd = ivec3 (mat2_x , mat1_x , FOUR * pos .z + batch_idx );
160
- im_mat2_partial_cols [mat2_col ] = texelFetch (im_mat2 , pos_rd , 0 );
161
- // set the value out of the boundary to be 0
162
- if (mat1_x == K_texel_len - 1 && packed_dim_padding > 0 ) {
163
- for (int kk = 0 ; kk < packed_dim_padding ; kk ++ ) {
164
- im_mat2_partial_cols [mat2_col ][3 - kk ] = 0 ;
165
- }
166
- }
172
+ for (int offset = 0 ; offset < FOUR ; offset ++ ) {
173
+ // read and cache 4x4 tile of im_mat1
174
+ const int mat1_y = (FOUR * pos .y ) + offset ;
175
+ const ivec3 mat1_pos = ivec3 (mat1_x , mat1_y , mat_z );
176
+ im_mat1_partial_load [offset ] = texelFetch (im_mat1 , mat1_pos , 0 );
177
+ // read and cache 4x4 tile of im_mat2
178
+ #ifdef MAT2_IS_TRANSPOSED
179
+ const int mat2_y = (FOUR * pos .x ) + offset ;
180
+ const ivec3 mat2_pos = ivec3 (mat1_x , mat2_y , 0 );
181
+ im_mat2_partial_load [offset ] = texelFetch (im_mat2 , mat2_pos , 0 );
182
+ #else
183
+ const int mat2_x = (FOUR * pos .x ) + offset ;
184
+ const ivec3 mat2_pos = ivec3 (mat2_x , mat1_x , mat_z );
185
+ im_mat2_partial_load [offset ] = texelFetch (im_mat2 , mat2_pos , 0 );
186
+ #endif
167
187
}
168
188
// perform partial dot products and add partial result to results
169
189
for (int out_row = 0 ; out_row < FOUR ; out_row ++ ) {
170
190
for (int out_col = 0 ; out_col < FOUR ; out_col ++ ) {
171
191
results .data [out_row ][out_col ][batch_idx ] +=
172
- dot (im_mat1_partial_rows [out_row ], im_mat2_partial_cols [out_col ]);
192
+ dot (im_mat1_partial_load [out_row ], im_mat2_partial_load [out_col ]);
173
193
}
174
194
}
175
195
}
0 commit comments