Commit 61694eb

[experimental][kleidi] linter
1 parent 8bad3e6 commit 61694eb

File tree: 4 files changed, +34 -41 lines changed


torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h

Lines changed: 17 additions & 20 deletions
@@ -1,4 +1,3 @@
-// namespace example
 // Copyright (c) Meta Platforms, Inc. and affiliates.
 // All rights reserved.
 //
@@ -42,8 +41,9 @@ const Ukernel get_ukernel() {
 }
 
 int activation_data_size(int m, int k, int group_size) {
-  (void) group_size; // unused
-  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(get_ukernel(), m, k);
+  (void)group_size; // unused
+  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(
+      get_ukernel(), m, k);
 }
 
 void prepare_activation_data(
@@ -52,17 +52,14 @@ void prepare_activation_data(
     int k,
     int group_size,
     const float* activations) {
-  (void) group_size; // unused
+  (void)group_size; // unused
   kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data(
-      get_ukernel(),
-      activation_data,
-      m,
-      k,
-      activations);
+      get_ukernel(), activation_data, m, k, activations);
 }
 
 int weight_data_size(int n, int k, int group_size) {
-  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(get_ukernel(), n, k, group_size);
+  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(
+      get_ukernel(), n, k, group_size);
 }
 
 void prepare_weight_data(
@@ -96,24 +93,24 @@ void kernel(
     const float* bias,
     float clamp_min,
     float clamp_max) {
-  (void) bias; // unused - needs API fixing
-  assert(output_m_stride == n);
-  if (clamp_min == 0 && clamp_max == 0) {
-    clamp_min = std::numeric_limits<float>::lowest();
-    clamp_max = std::numeric_limits<float>::max();
-  }
+  (void)bias; // unused - needs API fixing
+  assert(output_m_stride == n);
+  if (clamp_min == 0 && clamp_max == 0) {
+    clamp_min = std::numeric_limits<float>::lowest();
+    clamp_max = std::numeric_limits<float>::max();
+  }
 
-  auto ukernel = get_ukernel();
-  ukernel.run_matmul(
+  auto ukernel = get_ukernel();
+  ukernel.run_matmul(
       m,
       n,
       k,
       group_size,
      activation_data,
       weight_data,
       output,
-      /*dst_stride_row=*/ n * sizeof(float),
-      /*dst_stride_col=*/ sizeof(float),
+      /*dst_stride_row=*/n * sizeof(float),
+      /*dst_stride_col=*/sizeof(float),
       clamp_min,
       clamp_max);
 }
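Note on the last hunk above: kernel() treats clamp_min == clamp_max == 0 as "no clamp requested" and widens the bounds to the full float range before calling run_matmul. A minimal standalone sketch of that defaulting logic (resolve_clamp is an illustrative name, not part of the kernel API):

#include <iostream>
#include <limits>

// Standalone illustration of the clamp defaulting in kernel() above:
// (0, 0) is treated as "no clamp" and widened to the full float range.
void resolve_clamp(float& clamp_min, float& clamp_max) {
  if (clamp_min == 0 && clamp_max == 0) {
    clamp_min = std::numeric_limits<float>::lowest(); // most negative finite float
    clamp_max = std::numeric_limits<float>::max();    // largest finite float
  }
}

int main() {
  float lo = 0.0f, hi = 0.0f;
  resolve_clamp(lo, hi);
  std::cout << lo << " " << hi << "\n"; // prints the full float range
}

Using std::numeric_limits<float>::lowest() here matters: min() would be the smallest positive normal value, which would silently clamp away all negative outputs.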

torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h

Lines changed: 0 additions & 1 deletion
@@ -1,4 +1,3 @@
-// namespace example
 // Copyright (c) Meta Platforms, Inc. and affiliates.
 // All rights reserved.
 //

torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h

Lines changed: 8 additions & 7 deletions
@@ -1,4 +1,3 @@
-// namespace example
 // Copyright (c) Meta Platforms, Inc. and affiliates.
 // All rights reserved.
 //
@@ -7,10 +6,10 @@
 
 #pragma once
 
-#include <cstdint>
+#include <cassert>
 #include <cstddef>
+#include <cstdint>
 #include <cstring>
-#include <cassert>
 #include <limits>
 #include <vector>
 
@@ -25,7 +24,7 @@ namespace torchao::kernels::cpu::aarch64::kleidi {
 // TODO: find a better place for these?
 
 size_t roundup(size_t a, size_t b) {
-    return ((a + b - 1) / b) * b;
+  return ((a + b - 1) / b) * b;
 }
 
 uint16_t get_bf16_from_float(float f) {
@@ -111,13 +110,15 @@ void prepare_weight_data(
   uint8_t wzp = 8;
   for (size_t i = 0; i < n * k; i += 2) {
     const uint8_t low = static_cast<uint8_t>(weight_qvals[i] + wzp);
-    const uint8_t high = static_cast<uint8_t>(weight_qvals[i+1] + wzp);
+    const uint8_t high = static_cast<uint8_t>(weight_qvals[i + 1] + wzp);
     packed_weight_qvals[i / 2] = ((high << 4) | (low & 0xF));
   }
 
   // Parameters for packing
   rhs_packing::qparams_t qparams{
-      .lhs_zero_point=1, .rhs_zero_point=wzp, .scale_dt = kai_datatype::kai_dt_bf16};
+      .lhs_zero_point = 1,
+      .rhs_zero_point = wzp,
+      .scale_dt = kai_datatype::kai_dt_bf16};
 
   auto rhs_pack = get_rhs_packing();
 
@@ -133,7 +134,7 @@ void prepare_weight_data(
       /*rhs_stride=*/roundup(k, 2) / 2,
       /*bias=*/nullptr, // TODO fix APIs to move bias here
       /*scale=*/reinterpret_cast<const uint16_t*>(weight_scales_bf16.data()),
-      /*scale_stride=*/ sizeof(uint16_t) * (roundup(k, group_size) / group_size),
+      /*scale_stride=*/sizeof(uint16_t) * (roundup(k, group_size) / group_size),
       /*rhs_packed=*/weight_data,
       /*extra_bytes=*/0,
       /*qparams=*/&qparams);
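The packing loop in this hunk maps each signed int4 weight from [-8, 7] to an unsigned nibble in [0, 15] by adding the zero point wzp = 8, then stores two nibbles per byte, low nibble first. A self-contained sketch of the same arithmetic; roundup matches the helper in the diff, while pack_pair is an illustrative name:

#include <cstdint>
#include <cstdio>

// Same arithmetic as the roundup helper above: round a up to a multiple of b.
size_t roundup(size_t a, size_t b) { return ((a + b - 1) / b) * b; }

// Two signed int4 values -> one byte, low nibble first, zero point 8.
uint8_t pack_pair(int8_t q0, int8_t q1) {
  const uint8_t wzp = 8; // shifts the signed range [-8, 7] into [0, 15]
  const uint8_t low = static_cast<uint8_t>(q0 + wzp);
  const uint8_t high = static_cast<uint8_t>(q1 + wzp);
  return static_cast<uint8_t>((high << 4) | (low & 0xF));
}

int main() {
  // -8 -> nibble 0x0, 7 -> nibble 0xF, so the packed byte is 0xf0.
  std::printf("%#x\n", static_cast<unsigned>(pack_pair(-8, 7)));
  // A row of k int4 values occupies roundup(k, 2) / 2 bytes once packed.
  std::printf("%zu\n", roundup(5, 2) / 2); // 3 bytes for k = 5
}

This is also where the /*rhs_stride=*/roundup(k, 2) / 2 argument above comes from: the byte stride of one packed row of k nibbles.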

torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp

Lines changed: 9 additions & 13 deletions
@@ -6,8 +6,8 @@
 
 #if defined(__aarch64__) || defined(__ARM_NEON)
 
-#include <vector>
 #include <arm_neon.h>
+#include <vector>
 
 #include <gtest/gtest.h>
 #include <torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h>
@@ -375,10 +375,10 @@ void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
     has_clamp,
     /*weight_scale_bf16_round_trip=*/true);
 
-  using namespace torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32;
+  using namespace torchao::kernels::cpu::aarch64::kleidi::
+      kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32;
 
-  std::vector<char> activation_data(
-      activation_data_size(m, k, group_size));
+  std::vector<char> activation_data(activation_data_size(m, k, group_size));
 
   prepare_activation_data(
       (void*)activation_data.data(),
@@ -387,8 +387,7 @@ void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod(
       group_size,
       test_case.activations.data());
 
-  std::vector<char> weight_data(
-      weight_data_size(n, k, group_size));
+  std::vector<char> weight_data(weight_data_size(n, k, group_size));
 
   prepare_weight_data(
       (void*)weight_data.data(),
@@ -462,8 +461,6 @@ TEST(
     /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128);
 }
 
-
-
 template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
 void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(
     int m,
@@ -482,10 +479,10 @@ void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(
     has_clamp,
     /*weight_scale_bf16_round_trip=*/true);
 
-  using namespace torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32;
+  using namespace torchao::kernels::cpu::aarch64::kleidi::
+      kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32;
 
-  std::vector<char> activation_data(
-      activation_data_size(m, k, group_size));
+  std::vector<char> activation_data(activation_data_size(m, k, group_size));
 
   prepare_activation_data(
      (void*)activation_data.data(),
@@ -494,8 +491,7 @@ void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod(
       group_size,
       test_case.activations.data());
 
-  std::vector<char> weight_data(
-      weight_data_size(n, k, group_size));
+  std::vector<char> weight_data(weight_data_size(n, k, group_size));
 
   prepare_weight_data(
       (void*)weight_data.data(),
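The test changes above keep the usual query-then-allocate pattern: each packing routine has a matching *_data_size query, and the caller passes an opaque byte buffer of that size into prepare_*_data as void*. A minimal sketch of the pattern; fake_packed_size is a stand-in for activation_data_size()/weight_data_size(), not the kleidi API:

#include <cstdio>
#include <vector>

// Stand-in for a packed-size query: any function that reports how many
// bytes the packed layout needs for a given shape.
int fake_packed_size(int m, int k, int group_size) {
  (void)group_size; // some layouts ignore group_size, as in the diff above
  return m * k;     // placeholder formula, not the real packed size
}

int main() {
  const int m = 1, k = 128, group_size = 32;
  // Opaque byte buffer sized by the query; the packer receives it as void*.
  std::vector<char> activation_data(fake_packed_size(m, k, group_size));
  std::printf("%zu bytes reserved\n", activation_data.size());
}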
