From 2dce7c4847dd26b43483c3178344259058fea3b2 Mon Sep 17 00:00:00 2001 From: Silencio Date: Thu, 16 May 2024 12:29:30 +0800 Subject: [PATCH] Add punica dimension 27648 --- csrc/punica/bgmv/bgmv_config.h | 2 ++ tests/lora/test_punica.py | 1 + 2 files changed, 3 insertions(+) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index 19c058cacfbc..98ac8de779e1 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -53,6 +53,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 22016) \ f(in_T, out_T, W_T, narrow, 24576) \ f(in_T, out_T, W_T, narrow, 27392) \ + f(in_T, out_T, W_T, narrow, 27648) \ f(in_T, out_T, W_T, narrow, 28672) \ f(in_T, out_T, W_T, narrow, 32000) \ f(in_T, out_T, W_T, narrow, 32256) \ @@ -121,6 +122,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, 22016, narrow) \ f(in_T, out_T, W_T, 24576, narrow) \ f(in_T, out_T, W_T, 27392, narrow) \ + f(in_T, out_T, W_T, 27648, narrow) \ f(in_T, out_T, W_T, 28672, narrow) \ f(in_T, out_T, W_T, 32000, narrow) \ f(in_T, out_T, W_T, 32256, narrow) \ diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py index fd2a1b75f460..193e3906997c 100644 --- a/tests/lora/test_punica.py +++ b/tests/lora/test_punica.py @@ -79,6 +79,7 @@ def _lora_ref_impl( 22016, 24576, 27392, + 27648, 32000, 32256, 32512,