Properly handle SLM memory at linalg-to-xegpu pass #394

Closed
@dchigarev

Description

[SUB-TASK OF #360]

Memrefs with a shared (SLM) memory space require special handling in the xegpu dialect. For example, ND-descriptors do not support SLM memory for 2D loads at all and only support 1D loads for the f32 type.
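
For context, a minimal sketch of the kind of linalg-level input this concerns (the op, value names and shapes are made up for illustration; the "3" in the memref type denotes the shared/SLM address space):

// hypothetical input to linalg-to-xegpu: one operand lives in SLM (memory space 3),
// so the regular 2D ND-descriptor path cannot be used for it
%slm_buf = memref.alloc() : memref<4x32xf16, 3>
linalg.add ins(%slm_buf, %other : memref<4x32xf16, 3>, memref<4x32xf16>)
           outs(%out : memref<4x32xf16>)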

More info on what's supported and what's not for SLM:

Lowering XeGPU to Intrinsics by IMEX:

xegpu.load_nd/store_nd:
1. 2D load: 2D SLM descriptors are not supported (no suitable intrinsics; see the sketch below)
2. 1D load: can only load/store 64 elements of f32 (our pipeline only supports f16)

xegpu.load/store:
1. 2D load: can only load/store blocks of f32 (our pipeline only supports f16)
2. 1D load: can only load/store 32 elements of any type (f16/f32)
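
For reference, the 2D ND path that is not available for SLM would look like this (sketch only, reusing the #slm shorthand from the example below; this is the form that cannot be lowered to intrinsics):

%bad_desc = xegpu.create_nd_tdesc %slm_memref[0, 0] : memref<4x32xf16, 3> -> !xegpu.tensor_desc<4x32xf16, #slm>
%bad_load = xegpu.load_nd %bad_desc : !xegpu.tensor_desc<4x32xf16, #slm> -> vector<4x32xf16>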

Since our pipeline only supports the f16 type, the only remaining option is to load small 32-element chunks from SLM memory and then concatenate them into properly sized vectors:

Example of expected lowering to xegpu:
// EXAMPLE: load and store 4x32xf16 SLM block using xegpu.load/store

// allocate SLM memory
%slm_memref = memref.alloc() : memref<4x32xf16, 3>

// cast the 2D memref to a 1D memref (xegpu.create_tdesc only supports 1D memrefs)
%slm_memref_flat = memref.reinterpret_cast %slm_memref to offset: [0], sizes: [128], strides: [1] : memref<4x32xf16, 3> to memref<128xf16, 3>

// create scattered tensor descriptors (max chunk size = 32 elements)
%desc0 = xegpu.create_tdesc %slm_memref_flat, %offsets0 : memref<128xf16, 3>, vector<32xindex> -> !xegpu.tensor_desc<32xf16, #slm>
%desc1 = xegpu.update_offset %desc0, %offsets1 : !xegpu.tensor_desc<32xf16, #slm>, vector<32xindex>
%desc2 = xegpu.update_offset %desc0, %offsets2 : !xegpu.tensor_desc<32xf16, #slm>, vector<32xindex>
%desc3 = xegpu.update_offset %desc0, %offsets3 : !xegpu.tensor_desc<32xf16, #slm>, vector<32xindex>

// load the chunks
%chunk0 = xegpu.load %desc0, %mask : !xegpu.tensor_desc<32xf16, #slm>, vector<32xi1> -> vector<32xf16>
%chunk1 = xegpu.load %desc1, %mask : !xegpu.tensor_desc<32xf16, #slm>, vector<32xi1> -> vector<32xf16>
%chunk2 = xegpu.load %desc2, %mask : !xegpu.tensor_desc<32xf16, #slm>, vector<32xi1> -> vector<32xf16>
%chunk3 = xegpu.load %desc3, %mask : !xegpu.tensor_desc<32xf16, #slm>, vector<32xi1> -> vector<32xf16>

// concatenate the chunks
%concated0 = arith.constant dense<0.0> : vector<4x32xf16>

%concated1 = vector.insert %chunk0, %concated0[0] : vector<32xf16> into vector<4x32xf16>
%concated2 = vector.insert %chunk1, %concated1[1] : vector<32xf16> into vector<4x32xf16>
%concated3 = vector.insert %chunk2, %concated2[2] : vector<32xf16> into vector<4x32xf16>
%concated4 = vector.insert %chunk3, %concated3[3] : vector<32xf16> into vector<4x32xf16>

// perform some op
%res = arith.addf %concated4, %non_slm_vector : vector<4x32xf16>

// split the result into chunks for storing
%st_chunk0 = vector.extract %res[0] : vector<32xf16> from vector<4x32xf16>
%st_chunk1 = vector.extract %res[1] : vector<32xf16> from vector<4x32xf16>
%st_chunk2 = vector.extract %res[2] : vector<32xf16> from vector<4x32xf16>
%st_chunk3 = vector.extract %res[3] : vector<32xf16> from vector<4x32xf16>

// store the chunks
xegpu.store %st_chunk0, %desc0, %mask : vector<32xf16>, !xegpu.tensor_desc<32xf16, #slm>, vector<32xi1>
xegpu.store %st_chunk1, %desc1, %mask : vector<32xf16>, !xegpu.tensor_desc<32xf16, #slm>, vector<32xi1>
xegpu.store %st_chunk2, %desc2, %mask : vector<32xf16>, !xegpu.tensor_desc<32xf16, #slm>, vector<32xi1>
xegpu.store %st_chunk3, %desc3, %mask : vector<32xf16>, !xegpu.tensor_desc<32xf16, #slm>, vector<32xi1>
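
The %offsetsN and %mask values are not defined in the example above. A minimal sketch of how they could be materialized for this 4x32 row-major tile (the constants are illustrative assumptions, not the output of the actual pass):

// per-lane element offsets for chunk 0 (elements 0..31 of the flattened buffer);
// the remaining chunks cover elements 32..63, 64..95 and 96..127
%offsets0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
                                  16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]> : vector<32xindex>

// all 32 lanes are active
%mask = arith.constant dense<true> : vector<32xi1>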

P.S. This many loads/stores will likely hurt performance, but the use of SLM should be an exception (a functional, correctness-only path) and should ideally be avoided in the first place.
