@@ -54,6 +54,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
54
54
uvec2 qs0 = uvec2(unpack8(qs0_u16));
55
55
uvec2 qs16 = uvec2(unpack8(qs16_u16));
56
56
57
+ FLOAT_TYPE sc_q[2][8];
58
+ [[unroll]] for (int l = 0; l < 2; ++l) {
59
+ sc_q[l][0] = sccache[ix][v_im][0] * FLOAT_TYPE((qs0[l] ) & 3);
60
+ sc_q[l][1] = sccache[ix][v_im][1] * FLOAT_TYPE((qs16[l] ) & 3);
61
+ sc_q[l][2] = sccache[ix][v_im][2] * FLOAT_TYPE((qs0[l] >> 2) & 3);
62
+ sc_q[l][3] = sccache[ix][v_im][3] * FLOAT_TYPE((qs16[l] >> 2) & 3);
63
+ sc_q[l][4] = sccache[ix][v_im][4] * FLOAT_TYPE((qs0[l] >> 4) & 3);
64
+ sc_q[l][5] = sccache[ix][v_im][5] * FLOAT_TYPE((qs16[l] >> 4) & 3);
65
+ sc_q[l][6] = sccache[ix][v_im][6] * FLOAT_TYPE((qs0[l] >> 6) & 3);
66
+ sc_q[l][7] = sccache[ix][v_im][7] * FLOAT_TYPE((qs16[l] >> 6) & 3);
67
+ }
68
+
57
69
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
58
70
B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2];
59
71
B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
@@ -67,14 +79,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
67
79
FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
68
80
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
69
81
[[unroll]] for (int l = 0; l < 2; ++l) {
70
- sum1 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][0] * FLOAT_TYPE((qs0[l] ) & 3) ,
71
- fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][1] * FLOAT_TYPE((qs16[l] ) & 3) ,
72
- fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][2] * FLOAT_TYPE((qs0[l] >> 2) & 3) ,
73
- fma(FLOAT_TYPE(b48[l]), sccache[ix][v_im][3] * FLOAT_TYPE((qs16[l] >> 2) & 3) ,
74
- fma(FLOAT_TYPE(b64[l]), sccache[ix][v_im][4] * FLOAT_TYPE((qs0[l] >> 4) & 3) ,
75
- fma(FLOAT_TYPE(b80[l]), sccache[ix][v_im][5] * FLOAT_TYPE((qs16[l] >> 4) & 3) ,
76
- fma(FLOAT_TYPE(b96[l]), sccache[ix][v_im][6] * FLOAT_TYPE((qs0[l] >> 6) & 3) ,
77
- fma(FLOAT_TYPE(b112[l]), sccache[ix][v_im][7] * FLOAT_TYPE((qs16[l] >> 6) & 3) , sum1))))))));
82
+ sum1 = fma(FLOAT_TYPE(b0[l]), sc_q[l][0] ,
83
+ fma(FLOAT_TYPE(b16[l]), sc_q[l][1] ,
84
+ fma(FLOAT_TYPE(b32[l]), sc_q[l][2] ,
85
+ fma(FLOAT_TYPE(b48[l]), sc_q[l][3] ,
86
+ fma(FLOAT_TYPE(b64[l]), sc_q[l][4] ,
87
+ fma(FLOAT_TYPE(b80[l]), sc_q[l][5] ,
88
+ fma(FLOAT_TYPE(b96[l]), sc_q[l][6] ,
89
+ fma(FLOAT_TYPE(b112[l]), sc_q[l][7] , sum1))))))));
78
90
sum2 = fma(FLOAT_TYPE(b0[l]), sccache[ix][v_im][ 8],
79
91
fma(FLOAT_TYPE(b16[l]), sccache[ix][v_im][ 9],
80
92
fma(FLOAT_TYPE(b32[l]), sccache[ix][v_im][10],
0 commit comments