Skip to content

[SYCL] Fix handling of multiple usages of composite spec constants #2894

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
; RUN: sycl-post-link -spec-const=rt --ir-output-only %s -S -o - \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest reducing this test to the bare minimum required to verify that the fix makes the SpecConstants pass behave as expected. The test does not need to be a complete real-life test - the latter already exists in .cpp form, which should be enough.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simplified the test in 7de5a3a

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, looks much better now. Nit: @llvm.lifetime.start.p0i8 could be removed as well

; RUN: | FileCheck %s --implicit-check-not __sycl_getCompositeSpecConstantValue
;
; This test is intended to check that the sycl-post-link tool is capable of
; handling situations when the same composite specialization constant is used
; more than once. Unlike the multiple-composite-spec-const-usages.ll test, this
; is a real-life LLVM IR example.
;
; CHECK-LABEL: @_ZTSN4test8kernel_tIfEE
; CHECK: %[[#X1:]] = call float @_Z20__spirv_SpecConstantif(i32 0, float 0
; CHECK: %[[#Y1:]] = call float @_Z20__spirv_SpecConstantif(i32 1, float 0
; CHECK: call {{.*}} @_Z29__spirv_SpecConstantCompositeff(float %[[#X1]], float %[[#Y1]]), !SYCL_SPEC_CONST_SYM_ID ![[#ID:]]
; CHECK-LABEL: @_ZTSN4test8kernel_tIiEE
; CHECK: %[[#X2:]] = call float @_Z20__spirv_SpecConstantif(i32 0, float 0
; CHECK: %[[#Y2:]] = call float @_Z20__spirv_SpecConstantif(i32 1, float 0
; CHECK: call {{.*}} @_Z29__spirv_SpecConstantCompositeff(float %[[#X2]], float %[[#Y2]]), !SYCL_SPEC_CONST_SYM_ID ![[#ID]]
; CHECK: ![[#ID]] = !{!"_ZTS11sc_kernel_t", i32 0, i32 1}

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown-sycldevice"

%"struct._ZTSN4test5pod_tE.test::pod_t" = type { float, float }

$_ZTSN4test8kernel_tIfEE = comdat any

$_ZTSN4test8kernel_tIiEE = comdat any

@__builtin_unique_stable_name._ZNK2cl4sycl6ONEAPI12experimental13spec_constantIN4test5pod_tE11sc_kernel_tE3getIS5_EENSt9enable_ifIXaasr3std8is_classIT_EE5valuesr3std6is_podISA_EE5valueESA_E4typeEv = private unnamed_addr addrspace(1) constant [18 x i8] c"_ZTS11sc_kernel_t\00", align 1

; Function Attrs: convergent norecurse
; First kernel in this module (mangled name corresponds to test::kernel_t<float>).
; It requests the composite specialization constant whose symbolic name is
; stored in the @__builtin_unique_stable_name.* global ("_ZTS11sc_kernel_t")
; and has the value written into a stack temporary of the pod_t struct type.
; The FileCheck directives at the top of the file describe what sycl-post-link
; is expected to emit in place of the getter call.
define weak_odr dso_local spir_kernel void @_ZTSN4test8kernel_tIfEE() local_unnamed_addr #0 comdat !kernel_arg_buffer_location !4 {
entry:
; Stack slot that receives the composite value through the sret argument.
%ref.tmp.i = alloca %"struct._ZTSN4test5pod_tE.test::pod_t", align 4
; i8* view of the slot, used only by the lifetime markers (8 bytes = two floats).
%0 = bitcast %"struct._ZTSN4test5pod_tE.test::pod_t"* %ref.tmp.i to i8*
call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %0) #3
; The getter's sret parameter is an addrspace(4) pointer, so the slot is cast first.
%1 = addrspacecast %"struct._ZTSN4test5pod_tE.test::pod_t"* %ref.tmp.i to %"struct._ZTSN4test5pod_tE.test::pod_t" addrspace(4)*
; Second argument: address of the first character of the spec constant's name string.
call spir_func void @_Z36__sycl_getCompositeSpecConstantValueIN4test5pod_tEET_PKc(%"struct._ZTSN4test5pod_tE.test::pod_t" addrspace(4)* sret(%"struct._ZTSN4test5pod_tE.test::pod_t") align 4 %1, i8 addrspace(4)* addrspacecast (i8 addrspace(1)* getelementptr inbounds ([18 x i8], [18 x i8] addrspace(1)* @__builtin_unique_stable_name._ZNK2cl4sycl6ONEAPI12experimental13spec_constantIN4test5pod_tE11sc_kernel_tE3getIS5_EENSt9enable_ifIXaasr3std8is_classIT_EE5valuesr3std6is_podISA_EE5valueESA_E4typeEv, i64 0, i64 0) to i8 addrspace(4)*)) #4
call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %0) #3
ret void
}

; Function Attrs: argmemonly nofree nosync nounwind willreturn
declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture) #1

; Function Attrs: argmemonly nofree nosync nounwind willreturn
declare void @llvm.lifetime.end.p0i8(i64 immarg, i8* nocapture) #1

; Function Attrs: convergent
declare dso_local spir_func void @_Z36__sycl_getCompositeSpecConstantValueIN4test5pod_tEET_PKc(%"struct._ZTSN4test5pod_tE.test::pod_t" addrspace(4)* sret(%"struct._ZTSN4test5pod_tE.test::pod_t") align 4, i8 addrspace(4)*) local_unnamed_addr #2

; Function Attrs: convergent norecurse
; Second kernel in this module (mangled name corresponds to test::kernel_t<int>).
; Its body has the same instructions as the first kernel: it requests the
; composite specialization constant named by the same
; @__builtin_unique_stable_name.* global ("_ZTS11sc_kernel_t"), which makes
; this the second usage of one and the same spec constant within the module.
define weak_odr dso_local spir_kernel void @_ZTSN4test8kernel_tIiEE() local_unnamed_addr #0 comdat !kernel_arg_buffer_location !4 {
entry:
; Stack slot that receives the composite value through the sret argument.
%ref.tmp.i = alloca %"struct._ZTSN4test5pod_tE.test::pod_t", align 4
; i8* view of the slot, used only by the lifetime markers (8 bytes = two floats).
%0 = bitcast %"struct._ZTSN4test5pod_tE.test::pod_t"* %ref.tmp.i to i8*
call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %0) #3
; The getter's sret parameter is an addrspace(4) pointer, so the slot is cast first.
%1 = addrspacecast %"struct._ZTSN4test5pod_tE.test::pod_t"* %ref.tmp.i to %"struct._ZTSN4test5pod_tE.test::pod_t" addrspace(4)*
; Second argument: address of the first character of the spec constant's name string.
call spir_func void @_Z36__sycl_getCompositeSpecConstantValueIN4test5pod_tEET_PKc(%"struct._ZTSN4test5pod_tE.test::pod_t" addrspace(4)* sret(%"struct._ZTSN4test5pod_tE.test::pod_t") align 4 %1, i8 addrspace(4)* addrspacecast (i8 addrspace(1)* getelementptr inbounds ([18 x i8], [18 x i8] addrspace(1)* @__builtin_unique_stable_name._ZNK2cl4sycl6ONEAPI12experimental13spec_constantIN4test5pod_tE11sc_kernel_tE3getIS5_EENSt9enable_ifIXaasr3std8is_classIT_EE5valuesr3std6is_podISA_EE5valueESA_E4typeEv, i64 0, i64 0) to i8 addrspace(4)*)) #4
call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %0) #3
ret void
}

attributes #0 = { convergent norecurse "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-module-id"="repro-1.cpp" "uniform-work-group-size"="true" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #1 = { argmemonly nofree nosync nounwind willreturn }
attributes #2 = { convergent "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #3 = { nounwind }
attributes #4 = { convergent }

!llvm.module.flags = !{!0}
!opencl.spir.version = !{!1}
!spirv.Source = !{!2}
!llvm.ident = !{!3}

!0 = !{i32 1, !"wchar_size", i32 4}
!1 = !{i32 1, i32 2}
!2 = !{i32 4, i32 100000}
!3 = !{!"clang version 12.0.0 (/data/github.com/intel/llvm/clang 9b7086f7cef079b80ac5e137394f8d77d5d49c3e)"}
!4 = !{}
30 changes: 10 additions & 20 deletions llvm/tools/sycl-post-link/SpecConstants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,12 +318,7 @@ Instruction *emitSpecConstantComposite(Type *Ty,
/// first ID. If \c IsNewSpecConstant is false, this vector is expected to
/// contain enough elements to assign ID to each scalar element encountered in
/// the specified composite type.
/// @param IsNewSpecConstant [in] Flag to specify whether \c IDs vector should
/// be filled with new IDs or it should be used as-is to replicate an existing
/// spec constant
/// @param [in,out] IsFirstElement Flag indicating whether this function is
/// handling the first scalar element encountered in the specified composite
/// type \c Ty or not.
/// @param [in,out] Index Index of scalar element within a composite type
///
/// @returns Instruction* representing specialization constant in LLVM IR, which
/// is in SPIR-V friendly LLVM IR form.
Expand All @@ -335,22 +330,20 @@ Instruction *emitSpecConstantComposite(Type *Ty,
/// encountered scalars and assigns them IDs (or re-uses existing ones).
Instruction *emitSpecConstantRecursiveImpl(Type *Ty, Instruction *InsertBefore,
SmallVectorImpl<unsigned> &IDs,
bool IsNewSpecConstant,
bool &IsFirstElement) {
unsigned &Index) {
if (!Ty->isArrayTy() && !Ty->isStructTy() && !Ty->isVectorTy()) { // Scalar
if (IsNewSpecConstant && !IsFirstElement) {
if (Index >= IDs.size()) {
// If it is a new specialization constant, we need to generate IDs for
// scalar elements, starting with the second one.
IDs.push_back(IDs.back() + 1);
}
IsFirstElement = false;
return emitSpecConstant(IDs.back(), Ty, InsertBefore);
return emitSpecConstant(IDs[Index++], Ty, InsertBefore);
}

SmallVector<Instruction *, 8> Elements;
auto LoopIteration = [&](Type *Ty) {
Elements.push_back(emitSpecConstantRecursiveImpl(
Ty, InsertBefore, IDs, IsNewSpecConstant, IsFirstElement));
Elements.push_back(
emitSpecConstantRecursiveImpl(Ty, InsertBefore, IDs, Index));
};

if (auto *ArrTy = dyn_cast<ArrayType>(Ty)) {
Expand All @@ -374,11 +367,9 @@ Instruction *emitSpecConstantRecursiveImpl(Type *Ty, Instruction *InsertBefore,

/// Wrapper intended to hide IsFirstElement argument from the caller
Instruction *emitSpecConstantRecursive(Type *Ty, Instruction *InsertBefore,
SmallVectorImpl<unsigned> &IDs,
bool IsNewSpecConstant) {
bool IsFirstElement = true;
return emitSpecConstantRecursiveImpl(Ty, InsertBefore, IDs, IsNewSpecConstant,
IsFirstElement);
SmallVectorImpl<unsigned> &IDs) {
unsigned Index = 0;
return emitSpecConstantRecursiveImpl(Ty, InsertBefore, IDs, Index);
}

} // namespace
Expand Down Expand Up @@ -446,8 +437,7 @@ PreservedAnalyses SpecConstantsPass::run(Module &M,

// 3. Transform to spirv intrinsic _Z*__spirv_SpecConstant* or
// _Z*__spirv_SpecConstantComposite
auto *SPIRVCall =
emitSpecConstantRecursive(SCTy, CI, IDs, IsNewSpecConstant);
auto *SPIRVCall = emitSpecConstantRecursive(SCTy, CI, IDs);
if (IsNewSpecConstant) {
// emitSpecConstantRecursive might emit more than one spec constant
// (because of composite types) and therefore, we need to adjust
Expand Down
76 changes: 76 additions & 0 deletions sycl/test/on-device/spec_const/multiple-usages-of-composite.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// UNSUPPORTED: cuda
//
// RUN: %clangxx -fsycl %s -o %t.out
// RUN: %RUN_ON_HOST %t.out | FileCheck %s
// RUN: %CPU_RUN_PLACEHOLDER %t.out %CPU_CHECK_PLACEHOLDER
// RUN: %GPU_RUN_PLACEHOLDER %t.out %GPU_CHECK_PLACEHOLDER
//
// The test checks that multiple usages of the same specialization constant
// work correctly: the toolchain processes them correctly and the runtime can
// correctly execute the program.
//
// CHECK: --------> 1

#include <CL/sycl.hpp>

using namespace cl::sycl;

class sc_kernel_t;

namespace test {

// Two-float aggregate used as the value type of the composite specialization
// constant exercised by this test.
struct pod_t {
float x; // the only field the kernel prints
float y; // set by main() but never read by the kernel
};

template <typename T> class kernel_t {
public:
using sc_t = sycl::ONEAPI::experimental::spec_constant<pod_t, sc_kernel_t>;

kernel_t(const sc_t &sc, cl::sycl::stream &strm) : sc_(sc), strm_(strm) {}

void operator()(cl::sycl::id<1> i) const {
strm_ << "--------> " << sc_.get().x << sycl::endl;
}

sc_t sc_;
cl::sycl::stream strm_;
};

/// Host-side driver that builds and launches kernel_t<T> with the
/// sc_kernel_t specialization constant set to a caller-provided value.
template <typename T> class kernel_driver_t {
public:
  /// Sets the spec constant to \p pod, builds the program for kernel_t<T>,
  /// submits a single work-item launch and waits for it to complete.
  ///
  /// @param pod Value assigned to the sc_kernel_t specialization constant.
  void execute(const pod_t &pod) {
    device dev = sycl::device(default_selector{});
    // The program below is created from the queue's own context, so no
    // separately constructed context object is needed here.
    queue q(dev);

    cl::sycl::program p(q.get_context());
    // The spec constant is set before the program is built.
    auto sc = p.set_spec_constant<sc_kernel_t>(pod);
    p.build_with_kernel_type<kernel_t<T>>();

    q.submit([&](cl::sycl::handler &cgh) {
      cl::sycl::stream strm(1024, 256, cgh);
      kernel_t<T> func(sc, strm);

      // Launch the kernel object taken from the explicitly built program.
      auto sycl_kernel = p.get_kernel<kernel_t<T>>();
      cgh.parallel_for(sycl_kernel, cl::sycl::range<1>(1), func);
    });
    q.wait();
  }
};

// Explicit instantiation of the driver for `float`; this is the variant that
// main() actually runs.
template class kernel_driver_t<float>;

// The line below instantiates the second use of the spec constant named
// `sc_kernel_t`, which used to corrupt the spec constant content
template class kernel_driver_t<int>;
} // namespace test

int main() {
test::pod_t pod = {1, 2};
test::kernel_driver_t<float> kd_float;
kd_float.execute(pod);

return 0;
}