
Commit 728e8bb

bugfix: support uint8_t for vec_t class template (#1234)
This PR tries to fix an issue that occurred while enabling fp8 kv-cache support for vllm (vllm-project/vllm#17005). The issue was that, in a generated inc file (e.g. in my case flashinfer/100/generated/batch_decode_with_kv_cache_dtype_q_bf16_dtype_kv_u8_dtype_o_bf16_dtype_idx_i32_head_dim_qk_128_head_dim_vo_128_posenc_0_use_swa_False_use_logits_cap_False/batch_decode_config.inc), we declared DTypeKV to be uint8_t, as shown below:

```
using DTypeKV = uint8_t;
...
struct Params {
  ...
  using DTypeKV = DTypeKV;
  ...
};
```

Consequently, when we instantiate the vec_t from cast_load_impl defined in vec_dtypes.cuh, i.e.

```
vec_t<src_float_t, vec_size> tmp;
```

src_float_t is instantiated to uint8_t through the template parameter T = Params::DTypeKV, where Params::DTypeKV is uint8_t. Because vec_t doesn't have any specialization for uint8_t, we ended up with the following ptxas error:

```
ptxas fatal   : Unresolved extern function '_ZN10flashinfer5vec_tIhLm16EE4loadEPKh'
```

The fix is to add a specialization for uint8_t. However, this may not be the right fix, because the root cause might be that we shouldn't generate `using DTypeKV = uint8_t;` in the first place.

## 📌 Description

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes
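To make the failure path concrete, below is a minimal, hypothetical sketch of how the problem is triggered. It is not code from this commit or from vllm: `ParamsSketch`, `toy_kernel`, and the indexing are invented for illustration, while `flashinfer::vec_t`, `cast_load()`, and the `DTypeKV` typedef follow the names used above.

```
// Simplified reproduction of the failure path (illustrative names only; this
// is not the actual FlashInfer decode kernel).
#include <cstdint>

#include <flashinfer/vec_dtypes.cuh>

struct ParamsSketch {            // stand-in for the generated Params struct
  using DTypeKV = uint8_t;       // what the generated batch_decode_config.inc declares
  const DTypeKV* k;              // u8 (fp8-packed) KV-cache data
};

template <typename Params, size_t vec_size>
__global__ void toy_kernel(Params params, float* out) {
  using DTypeKV = typename Params::DTypeKV;  // uint8_t here
  const DTypeKV* k_ptr = params.k + size_t(threadIdx.x) * vec_size;

  flashinfer::vec_t<float, vec_size> k_vec;
  // cast_load() goes through cast_load_impl(), which creates a temporary
  // vec_t<DTypeKV, vec_size> and calls its load(). With DTypeKV = uint8_t that
  // temporary is vec_t<uint8_t, vec_size>; before this commit no such
  // specialization existed, hence the unresolved vec_t<uint8_t, 16>::load
  // symbol quoted above.
  k_vec.cast_load(k_ptr);
  out[threadIdx.x] = k_vec[0];
}

// Example instantiation that requires vec_t<uint8_t, 16>:
// toy_kernel<ParamsSketch, 16><<<1, 32>>>(params, out);
```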
1 parent 04f9758 commit 728e8bb

1 file changed: +243 -0 lines changed


include/flashinfer/vec_dtypes.cuh

Lines changed: 243 additions & 0 deletions
@@ -1399,6 +1399,249 @@ struct vec_t<nv_bfloat16, vec_size> {
   }
 };
 
+/******************* vec_t<uint8_t> *******************/
+
+// uint8_t x 1
+template <>
+struct vec_t<uint8_t, 1> {
+  uint8_t data;
+
+  FLASHINFER_INLINE uint8_t& operator[](size_t i) { return ((uint8_t*)(&data))[i]; }
+  FLASHINFER_INLINE const uint8_t& operator[](size_t i) const {
+    return ((const uint8_t*)(&data))[i];
+  }
+  FLASHINFER_INLINE uint8_t* ptr() { return reinterpret_cast<uint8_t*>(&data); }
+  FLASHINFER_INLINE void fill(uint8_t val);
+  FLASHINFER_INLINE void load(const uint8_t* ptr);
+  FLASHINFER_INLINE void store(uint8_t* ptr) const;
+  template <typename T>
+  FLASHINFER_INLINE void cast_from(const vec_t<T, 1>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  FLASHINFER_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  FLASHINFER_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+
+  FLASHINFER_INLINE static void memcpy(uint8_t* dst, const uint8_t* src);
+};
+
+FLASHINFER_INLINE void vec_t<uint8_t, 1>::fill(uint8_t val) { data = val; }
+
+FLASHINFER_INLINE void vec_t<uint8_t, 1>::load(const uint8_t* ptr) { data = *ptr; }
+
+FLASHINFER_INLINE void vec_t<uint8_t, 1>::store(uint8_t* ptr) const { *ptr = data; }
+
+FLASHINFER_INLINE void vec_t<uint8_t, 1>::memcpy(uint8_t* dst, const uint8_t* src) { *dst = *src; }
+
+// uint8_t x 2
+template <>
+struct vec_t<uint8_t, 2> {
+  uint16_t data;
+
+  FLASHINFER_INLINE uint8_t& operator[](size_t i) { return ((uint8_t*)(&data))[i]; }
+  FLASHINFER_INLINE const uint8_t& operator[](size_t i) const {
+    return ((const uint8_t*)(&data))[i];
+  }
+  FLASHINFER_INLINE uint8_t* ptr() { return reinterpret_cast<uint8_t*>(&data); }
+  FLASHINFER_INLINE void fill(uint8_t val);
+  FLASHINFER_INLINE void load(const uint8_t* ptr);
+  FLASHINFER_INLINE void store(uint8_t* ptr) const;
+  template <typename T>
+  FLASHINFER_INLINE void cast_from(const vec_t<T, 2>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  FLASHINFER_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  FLASHINFER_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+
+  FLASHINFER_INLINE static void memcpy(uint8_t* dst, const uint8_t* src);
+};
+
+FLASHINFER_INLINE void vec_t<uint8_t, 2>::fill(uint8_t val) {
+  data = (uint16_t(val) << 8) | uint16_t(val);
+}
+
+FLASHINFER_INLINE void vec_t<uint8_t, 2>::load(const uint8_t* ptr) { data = *((uint16_t*)ptr); }
+
+FLASHINFER_INLINE void vec_t<uint8_t, 2>::store(uint8_t* ptr) const { *((uint16_t*)ptr) = data; }
+
+FLASHINFER_INLINE void vec_t<uint8_t, 2>::memcpy(uint8_t* dst, const uint8_t* src) {
+  *((uint16_t*)dst) = *((uint16_t*)src);
+}
+
+// uint8_t x 4
+
+template <>
+struct vec_t<uint8_t, 4> {
+  uint32_t data;
+
+  FLASHINFER_INLINE uint8_t& operator[](size_t i) { return ((uint8_t*)(&data))[i]; }
+  FLASHINFER_INLINE const uint8_t& operator[](size_t i) const {
+    return ((const uint8_t*)(&data))[i];
+  }
+  FLASHINFER_INLINE uint8_t* ptr() { return reinterpret_cast<uint8_t*>(&data); }
+  FLASHINFER_INLINE void fill(uint8_t val);
+  FLASHINFER_INLINE void load(const uint8_t* ptr);
+  FLASHINFER_INLINE void store(uint8_t* ptr) const;
+  template <typename T>
+  FLASHINFER_INLINE void cast_from(const vec_t<T, 4>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  FLASHINFER_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  FLASHINFER_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+
+  FLASHINFER_INLINE static void memcpy(uint8_t* dst, const uint8_t* src);
+};
+
+FLASHINFER_INLINE void vec_t<uint8_t, 4>::fill(uint8_t val) {
+  data = (uint32_t(val) << 24) | (uint32_t(val) << 16) | (uint32_t(val) << 8) | uint32_t(val);
+}
+
+FLASHINFER_INLINE void vec_t<uint8_t, 4>::load(const uint8_t* ptr) { data = *((uint32_t*)ptr); }
+
+FLASHINFER_INLINE void vec_t<uint8_t, 4>::store(uint8_t* ptr) const { *((uint32_t*)ptr) = data; }
+
+FLASHINFER_INLINE void vec_t<uint8_t, 4>::memcpy(uint8_t* dst, const uint8_t* src) {
+  *((uint32_t*)dst) = *((uint32_t*)src);
+}
+
+// uint8_t x 8
+
+template <>
+struct vec_t<uint8_t, 8> {
+  uint2 data;
+
+  FLASHINFER_INLINE uint8_t& operator[](size_t i) { return ((uint8_t*)(&data))[i]; }
+  FLASHINFER_INLINE const uint8_t& operator[](size_t i) const {
+    return ((const uint8_t*)(&data))[i];
+  }
+  FLASHINFER_INLINE uint8_t* ptr() { return reinterpret_cast<uint8_t*>(&data); }
+  FLASHINFER_INLINE void fill(uint8_t val);
+  FLASHINFER_INLINE void load(const uint8_t* ptr);
+  FLASHINFER_INLINE void store(uint8_t* ptr) const;
+  template <typename T>
+  FLASHINFER_INLINE void cast_from(const vec_t<T, 8>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  FLASHINFER_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  FLASHINFER_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+  FLASHINFER_INLINE static void memcpy(uint8_t* dst, const uint8_t* src);
+};
+
+FLASHINFER_INLINE void vec_t<uint8_t, 8>::fill(uint8_t val) {
+  uint32_t val32 =
+      (uint32_t(val) << 24) | (uint32_t(val) << 16) | (uint32_t(val) << 8) | uint32_t(val);
+  data.x = val32;
+  data.y = val32;
+}
+
+FLASHINFER_INLINE void vec_t<uint8_t, 8>::load(const uint8_t* ptr) { data = *((uint2*)ptr); }
+
+FLASHINFER_INLINE void vec_t<uint8_t, 8>::store(uint8_t* ptr) const { *((uint2*)ptr) = data; }
+
+FLASHINFER_INLINE void vec_t<uint8_t, 8>::memcpy(uint8_t* dst, const uint8_t* src) {
+  *((uint2*)dst) = *((uint2*)src);
+}
+
+// uint8_t x 16 or more
+
+template <size_t vec_size>
+struct vec_t<uint8_t, vec_size> {
+  static_assert(vec_size % 16 == 0, "Invalid vector size");
+  int4 data[vec_size / 16];
+
+  FLASHINFER_INLINE uint8_t& operator[](size_t i) { return ((uint8_t*)data)[i]; }
+  FLASHINFER_INLINE const uint8_t& operator[](size_t i) const { return ((const uint8_t*)data)[i]; }
+  FLASHINFER_INLINE uint8_t* ptr() { return reinterpret_cast<uint8_t*>(&data); }
+  FLASHINFER_INLINE void fill(uint8_t val) {
+    uint32_t val32 =
+        (uint32_t(val) << 24) | (uint32_t(val) << 16) | (uint32_t(val) << 8) | uint32_t(val);
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 16; ++i) {
+      data[i].x = val32;
+      data[i].y = val32;
+      data[i].z = val32;
+      data[i].w = val32;
+    }
+  }
+  FLASHINFER_INLINE void load(const uint8_t* ptr) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 16; ++i) {
+      data[i] = ((int4*)ptr)[i];
+    }
+  }
+  FLASHINFER_INLINE void store(uint8_t* ptr) const {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 16; ++i) {
+      ((int4*)ptr)[i] = data[i];
+    }
+  }
+  FLASHINFER_INLINE void load_global_acquire(uint8_t* addr) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 16; ++i) {
+      data[i] = ld_global_acquire((int4*)(addr + i * 16));
+    }
+  }
+  FLASHINFER_INLINE void store_global_release(uint8_t* addr) const {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 16; ++i) {
+      st_global_release(data[i], (int4*)(addr + i * 16));
+    }
+  }
+  FLASHINFER_INLINE void load_global_volatile(uint8_t* addr) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 16; ++i) {
+      data[i] = ld_global_volatile((int4*)(addr + i * 16));
+    }
+  }
+  FLASHINFER_INLINE void store_global_volatile(uint8_t* addr) const {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 16; ++i) {
+      st_global_volatile(data[i], (int4*)(addr + i * 16));
+    }
+  }
+  template <typename T>
+  FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  FLASHINFER_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  FLASHINFER_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+  FLASHINFER_INLINE static void memcpy(uint8_t* dst, const uint8_t* src) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 16; ++i) {
+      ((int4*)dst)[i] = ((int4*)src)[i];
+    }
+  }
+};
+
 /******************* vec_t<float> *******************/
 
 // float x 1
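For reference, a minimal usage sketch (not part of the commit) of the new `vec_t<uint8_t, vec_size>` specialization on its own. The kernel and buffer names are made up, the buffers are assumed 16-byte aligned so the int4-based `load()`/`store()` paths are legal, and the `cast_store()` call assumes the generic element-wise cast path handles the `uint8_t` to `float` conversion.

```
// Toy kernel (hypothetical) exercising vec_t<uint8_t, 16> directly.
#include <cstdint>

#include <flashinfer/vec_dtypes.cuh>

__global__ void u8_copy_and_widen(const uint8_t* __restrict__ in,
                                  uint8_t* __restrict__ out_u8,
                                  float* __restrict__ out_f32) {
  constexpr size_t kVec = 16;  // hits the "uint8_t x 16 or more" branch
  const size_t base = (size_t(blockIdx.x) * blockDim.x + threadIdx.x) * kVec;

  flashinfer::vec_t<uint8_t, kVec> v;
  v.load(in + base);             // one 128-bit (int4) vectorized load
  v.store(out_u8 + base);        // one 128-bit (int4) vectorized store
  v.cast_store(out_f32 + base);  // element-wise widen uint8_t -> float
}
```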
