perf: Add FP16 GEMM MMUL Reshaped Only Rhs Support #1181
Conversation
Looks fine to me
Force-pushed from 9181c67 to 25eea17
src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp (outdated; resolved)
const bool is_fp16 = (src0->data_type() == DataType::F16);

// These error messages are for FP16 acc.
ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_fp16 && (n < rhs_info.n0 * mmul_n0), "N must be greater than N0 * MMUL_N0");
First suggestion: put the fp16-related validations into an if(is_fp16) block and remove is_fp16 from every check.
Also, the messages should be specific to the fp16 kernel, e.g. "K must be a multiple of 4 in the fp16 mmul kernel".
How about this? (same change as for the comment before it)
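For illustration, the restructuring being suggested might look roughly like the sketch below; the K check is reconstructed from the example message in the comment above, not taken from the actual diff:

// Sketch only: group the fp16-accumulator checks behind one branch instead of
// prefixing every macro with `is_fp16 &&`.
if(is_fp16)
{
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(n < rhs_info.n0 * mmul_n0, "N must be greater than N0 * MMUL_N0 in the fp16 mmul kernel");
    // Hypothetical check reconstructed from the reviewer's example message
    ARM_COMPUTE_RETURN_ERROR_ON_MSG((k % 4) != 0, "K must be a multiple of 4 in the fp16 mmul kernel");
}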
const unsigned int m = gemm_info.m;
const unsigned int n = gemm_info.n;
const unsigned int k = gemm_info.k;
const bool is_fp16 = (src0->data_type() == DataType::F16);
We check all these validations for fp16, but at this point we do not know whether we'll be using the fp16 mmul kernel. So we might be validating the fp32 mmul kernel running on fp16 input against all these constraints, which is wrong, right?
How about this?
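Loosely sketched, the concern could be addressed by gating the checks on the kernel variant that will actually run rather than on the input type; fp16_acc_kernel_selected below is a made-up placeholder for whatever the heuristic decides:

// Sketch only: validate against fp16-accumulator constraints only when the
// fp16 mmul kernel would actually be chosen, not merely when the input is F16.
const bool fp16_acc_kernel_selected = is_fp16 /* && heuristic picks fp16 accumulation */;
if(fp16_acc_kernel_selected)
{
    // ... fp16-accumulator-specific validations go here ...
}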
src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp (outdated; resolved)
src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp (outdated; resolved)
@@ -132,15 +139,15 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRhsMMULFixture<

 TEST_SUITE_END() // FP32

-TEST_SUITE(FP16)
+TEST_SUITE(MMUL_FP16)
The fp32 mmul kernel still supports fp16 although the heuristics don't choose it (same goes for block sizes), therefore we shouldn't remove this test.
Also, we should replace the combine(combine(... patterns with a single combine(...) wherever we touch.
How about this?
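Concretely, the combine clean-up reads something like this (the dataset names are illustrative, and this assumes the test framework's combine() accepts any number of datasets, as the comment above implies):

// Before: nested two-argument combines
combine(combine(combine(m_values, n_values), k_values), b_values)
// After: a single variadic combine
combine(m_values, n_values, k_values, b_values)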
 /** N values to test */
 const auto n_values = framework::dataset::make("N", {257});
+const auto n_values_fp16 = framework::dataset::make("N", {79});
I think one of the mistakes we made was to test on a single shape. While we test different block sizes on a single shape, we should also test on small shapes with a subset of block sizes (while keeping in mind the test time required, of course).
How about this? These are values from a number of models of interest. Given the combinatorial expansion, I think 3 is a reasonable number of values for each of the datasets. I've tried to pick shapes on the smaller side to keep the runtime reasonable too.
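As a sketch of what such a dataset might look like (79 comes from the earlier diff; the other two values are placeholders, not the actual model-derived numbers):

// Hypothetical example: three small N values instead of a single one
const auto n_values_fp16 = framework::dataset::make("N", {79, 257, 384});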
 /** K0 values to test - Precommit */
 const auto k0_values_precommit = framework::dataset::make("K0", { 1 });

 /** Broadcast bias from vector to matrix */
-const auto broadcast_bias_values = framework::dataset::make("broadcast_bias", { false, true } );
+const auto broadcast_bias_values = framework::dataset::make("broadcast_bias", { false } );
This change reduces fp32 tests as well
@@ -160,7 +167,38 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRhsMMULFixture<
         framework::ARM_COMPUTE_PRINT_INFO();
     }
 }
 TEST_SUITE_END() // FP16

 TEST_SUITE(ExportToCLImage)
If there are going to be any, they should be under the TEST_SUITE(ExportToCLImage) below.
How's this?
act_values))
{
    // Validate output
    if(validate_result)
I think we did this the wrong way 4 years ago, and the culprit is me :) We shouldn't make the same mistake here. If you look at GEMMFixture.h and the relevant fixture class, this is set to false when the gemm or the reshape does not validate to true. So, if we provide a faulty configuration to test, it'll skip the test. This is definitely not the right thing to do.
I hope the context is clear. Let me know if it's not. Now, what do we need to do?
We change GEMMMatrixMultiplyReshapedOnlyRhsMMULValidationFixture so that validate_result is set to false only if the hardware features relevant to the test in question are not supported. I.e. we have fp32 and fp16 tests using mmul with fp32 accumulators, and we have fp16 tests using mmul with fp16 accumulators.
I've made a change in GEMMFixture.h to how this value is set: it now corresponds to whether the target hardware supports arm_matrix_multiply.
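In sketch form, and assuming the library exposes a capability query along the lines of arm_matrix_multiply_supported() (the exact helper and call site in GEMMFixture.h may differ):

// Sketch only: tie validate_result to a hardware capability check instead of
// to whether configure()/validate() happened to succeed for the given config.
validate_result = arm_matrix_multiply_supported(CLKernelLibrary::get().get_device());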
Force-pushed from 25eea17 to 5f4e692
This patch introduces a GEMM routine that is optimized for Arm(R) Mali(TM)-G1

Resolves: [COMPMID-8311], [COMPMID-8312]
Signed-off-by: Omar Al Khatib <[email protected]>
Change-Id: I84e685f0314da9af1c3fbb50d83e68b355727770
Force-pushed from 5f4e692 to 5e9d919
This patch introduces a GEMM routine that is optimized for Arm(R) Mali(TM)-G1
Resolves: [COMPMID-8311], [COMPMID-8312]
Change-Id: I84e685f0314da9af1c3fbb50d83e68b355727770