diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index e795947add373..30ea342073209 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -969,6 +969,7 @@ static CompoundStmt *CreateOpenCLKernelBody(Sema &S, // need to replace all refs to this kernel oject with refs to our clone // declared inside kernel body. Stmt *FunctionBody = KernelCallerFunc->getBody(); + ParmVarDecl *KernelObjParam = *(KernelCallerFunc->param_begin()); // DeclRefExpr with valid source location but with decl which is not marked @@ -1001,311 +1002,580 @@ static ParamDesc makeParamDesc(const FieldDecl *Src, QualType Ty) { Ctx.getTrivialTypeSourceInfo(Ty)); } +static ParamDesc makeParamDesc(ASTContext &Ctx, const CXXBaseSpecifier &Src, + QualType Ty) { + // TODO: There is no name for the base available, but duplicate names are + // seemingly already possible, so we'll give them all the same name for now. + // This only happens with the accessor types. + std::string Name = "_arg__base"; + return std::make_tuple(Ty, &Ctx.Idents.get(Name), + Ctx.getTrivialTypeSourceInfo(Ty)); +} + /// \return the target of given SYCL accessor type static target getAccessTarget(const ClassTemplateSpecializationDecl *AccTy) { return static_cast( AccTy->getTemplateArgs()[3].getAsIntegral().getExtValue()); } -// Creates list of kernel parameters descriptors using KernelObj (kernel object) -// Fields of kernel object must be initialized with SYCL kernel arguments so -// in the following function we extract types of kernel object fields and add it -// to the array with kernel parameters descriptors. -// Returns true if all arguments are successfully built. -static bool buildArgTys(ASTContext &Context, CXXRecordDecl *KernelObj, - SmallVectorImpl &ParamDescs) { - auto CreateAndAddPrmDsc = [&](const FieldDecl *Fld, const QualType &ArgType) { - // Create a parameter descriptor and append it to the result - ParamDescs.push_back(makeParamDesc(Fld, ArgType)); - }; - - // Creates a parameter descriptor for SYCL special object - SYCL accessor or - // sampler. +// The first template argument to the kernel function is used to identify the +// kernel itself. +static QualType calculateKernelNameType(ASTContext &Ctx, + FunctionDecl *KernelCallerFunc) { + // TODO: Not sure what the 'fully qualified type's purpose is here, the type + // itself should have its full qualified name, so figure out what the purpose + // is. + const TemplateArgumentList *TAL = + KernelCallerFunc->getTemplateSpecializationArgs(); + return TypeName::getFullyQualifiedType(TAL->get(0).getAsType(), Ctx, + /*WithGlobalNSPrefix=*/true); +} + +// Gets a name for the kernel caller func, calculated from the first template +// argument. +static std::string constructKernelName(Sema &S, FunctionDecl *KernelCallerFunc, + MangleContext &MC, bool StableName) { + QualType KernelNameType = + calculateKernelNameType(S.getASTContext(), KernelCallerFunc); + + if (StableName) + return PredefinedExpr::ComputeName(S.getASTContext(), + PredefinedExpr::UniqueStableNameType, + KernelNameType); + SmallString<256> Result; + llvm::raw_svector_ostream Out(Result); + + MC.mangleTypeName(KernelNameType, Out); + return std::string(Out.str()); +} + +// anonymous namespace so these don't get linkage. +namespace { + +QualType getItemType(const FieldDecl *FD) { return FD->getType(); } +QualType getItemType(const CXXBaseSpecifier &BS) { return BS.getType(); } + +// Implements the 'for-each-visitor' pattern. +template +static void VisitAccessorWrapper(CXXRecordDecl *Owner, ParentTy &Parent, + CXXRecordDecl *Wrapper, + Handlers &... handlers); + +template +static void VisitAccessorWrapperHelper(CXXRecordDecl *Owner, RangeTy Range, + Handlers &... handlers) { + for (const auto &Item : Range) { + QualType ItemTy = getItemType(Item); + if (Util::isSyclAccessorType(ItemTy)) + (void)std::initializer_list{ + (handlers.handleSyclAccessorType(Item, ItemTy), 0)...}; + else if (Util::isSyclStreamType(ItemTy)) + (void)std::initializer_list{ + (handlers.handleSyclStreamType(Item, ItemTy), 0)...}; + else if (ItemTy->isStructureOrClassType()) { + VisitAccessorWrapper(Owner, Item, ItemTy->getAsCXXRecordDecl(), + handlers...); + } + } +} + +// poorly named Parent is the 'how we got here', basically just enough info for +// the offset adjustment to know what to do about the enter-struct info. +template +static void VisitAccessorWrapper(CXXRecordDecl *Owner, ParentTy &Parent, + CXXRecordDecl *Wrapper, + Handlers &... handlers) { + (void)std::initializer_list{(handlers.enterStruct(Owner, Parent), 0)...}; + VisitAccessorWrapperHelper(Wrapper, Wrapper->bases(), handlers...); + VisitAccessorWrapperHelper(Wrapper, Wrapper->fields(), handlers...); + (void)std::initializer_list{(handlers.leaveStruct(Owner, Parent), 0)...}; +} + +// A visitor function that dispatches to functions as defined in +// SyclKernelFieldHandler for the purposes of kernel generation. +template +static void VisitRecordFields(RecordDecl::field_range Fields, + Handlers &... handlers) { +#define KF_FOR_EACH(FUNC) \ + (void)std::initializer_list { (handlers.FUNC(Field, FieldTy), 0)... } + + for (const auto &Field : Fields) { + QualType FieldTy = Field->getType(); + + if (Util::isSyclAccessorType(FieldTy)) + KF_FOR_EACH(handleSyclAccessorType); + else if (Util::isSyclSamplerType(FieldTy)) + KF_FOR_EACH(handleSyclSamplerType); + else if (Util::isSyclSpecConstantType(FieldTy)) + KF_FOR_EACH(handleSyclSpecConstantType); + else if (Util::isSyclStreamType(FieldTy)) + KF_FOR_EACH(handleSyclStreamType); + else if (FieldTy->isStructureOrClassType()) { + KF_FOR_EACH(handleStructType); + CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl(); + VisitAccessorWrapper(nullptr, Field, RD, handlers...); + } else if (FieldTy->isReferenceType()) + KF_FOR_EACH(handleReferenceType); + else if (FieldTy->isPointerType()) + KF_FOR_EACH(handlePointerType); + else if (FieldTy->isArrayType()) + KF_FOR_EACH(handleArrayType); + else if (FieldTy->isScalarType()) + KF_FOR_EACH(handleScalarType); + else + KF_FOR_EACH(handleOtherType); + } +#undef KF_FOR_EACH +} + +// A base type that the SYCL OpenCL Kernel construction task uses to implement +// individual tasks. +template class SyclKernelFieldHandler { +protected: + Sema &SemaRef; + SyclKernelFieldHandler(Sema &S) : SemaRef(S) {} + +public: + // Mark these virutal so that we can use override in the implementer classes, + // despite virtual dispatch never being used. + + //// TODO: Can these return 'bool' and we can short-circuit the handling? That + // way the field checker cna return true/false based on whether the rest + // should be still working. + + // Accessor can be a base class or a field decl, so both must be handled. + virtual void handleSyclAccessorType(const CXXBaseSpecifier &, QualType) {} + virtual void handleSyclAccessorType(const FieldDecl *, QualType) {} + virtual void handleSyclSamplerType(const FieldDecl *, QualType) {} + virtual void handleSyclSpecConstantType(const FieldDecl *, QualType) {} + virtual void handleSyclStreamType(const CXXBaseSpecifier &, QualType) {} + virtual void handleSyclStreamType(const FieldDecl *, QualType) {} + virtual void handleStructType(const FieldDecl *, QualType) {} + virtual void handleReferenceType(const FieldDecl *, QualType) {} + virtual void handlePointerType(const FieldDecl *, QualType) {} + virtual void handleArrayType(const FieldDecl *, QualType) {} + virtual void handleScalarType(const FieldDecl *, QualType) {} + // Most handlers shouldn't be handling this, just the field checker. + virtual void handleOtherType(const FieldDecl *, QualType) {} + + // The following are only used for keeping track of where we are in the base + // class/field graph. Int Headers use this to calculate offset, most others + // don't have a need for these. + + virtual void enterStruct(const CXXRecordDecl *, const FieldDecl *) {} + virtual void leaveStruct(const CXXRecordDecl *, const FieldDecl *) {} + virtual void enterStruct(const CXXRecordDecl *, const CXXBaseSpecifier &) {} + virtual void leaveStruct(const CXXRecordDecl *, const CXXBaseSpecifier &) {} + // virtual void enterStruct(const FieldDecl *, CXXRecordDecl *Struct); + // virtual void leaveStruct(const FieldDecl *, CXXRecordDecl *Struct); +}; + +// A type to check the valididty of all of the argument types. +class SyclKernelFieldChecker + : public SyclKernelFieldHandler { + bool IsInvalid = false; + DiagnosticsEngine &Diag; + +public: + SyclKernelFieldChecker(Sema &S) + : SyclKernelFieldHandler(S), Diag(S.getASTContext().getDiagnostics()) {} + bool isValid() { return !IsInvalid; } + + void handleReferenceType(const FieldDecl *FD, QualType ArgTy) final { + IsInvalid = Diag.Report(FD->getLocation(), diag::err_bad_kernel_param_type) + << ArgTy; + } + void handleStructType(const FieldDecl *FD, QualType ArgTy) final { + if (SemaRef.getASTContext().getLangOpts().SYCLStdLayoutKernelParams && + !ArgTy->isStandardLayoutType()) + IsInvalid = + Diag.Report(FD->getLocation(), diag::err_sycl_non_std_layout_type) + << ArgTy; + else { + CXXRecordDecl *RD = ArgTy->getAsCXXRecordDecl(); + if (!RD->hasTrivialCopyConstructor()) + + IsInvalid = + Diag.Report(FD->getLocation(), + diag::err_sycl_non_trivially_copy_ctor_dtor_type) + << 0 << ArgTy; + else if (!RD->hasTrivialDestructor()) + IsInvalid = + Diag.Report(FD->getLocation(), + diag::err_sycl_non_trivially_copy_ctor_dtor_type) + << 1 << ArgTy; + } + } + + // We should be able to ahndle this, so we made it part of the visitor, but + // this is 'to be implemented'. + void handleArrayType(const FieldDecl *FD, QualType ArgTy) final { + IsInvalid = Diag.Report(FD->getLocation(), diag::err_bad_kernel_param_type) + << ArgTy; + } + + void handleOtherType(const FieldDecl *FD, QualType ArgTy) final { + IsInvalid = Diag.Report(FD->getLocation(), diag::err_bad_kernel_param_type) + << ArgTy; + } +}; + +// A type to Create and own the FunctionDecl for the kernel. +class SyclKernelDeclCreator + : public SyclKernelFieldHandler { + FunctionDecl *KernelDecl; + llvm::SmallVector Params; + SyclKernelFieldChecker &ArgChecker; + Sema::ContextRAII FuncContext; + // Holds the last handled field's first parameter. This doesn't store an + // iterator as push_back invalidates iterators. + size_t LastParamIndex = 0; + + void addParam(const FieldDecl *FD, QualType ArgTy) { + ParamDesc newParamDesc = makeParamDesc(FD, ArgTy); + addParam(newParamDesc, ArgTy); + } + + void addParam(const CXXBaseSpecifier &BS, QualType ArgTy) { + ParamDesc newParamDesc = makeParamDesc(SemaRef.getASTContext(), BS, ArgTy); + addParam(newParamDesc, ArgTy); + } + + void addParam(ParamDesc newParamDesc, QualType ArgTy) { + // Create a new ParmVarDecl based on the new info. + auto *NewParam = ParmVarDecl::Create( + SemaRef.getASTContext(), KernelDecl, SourceLocation(), SourceLocation(), + std::get<1>(newParamDesc), std::get<0>(newParamDesc), + std::get<2>(newParamDesc), SC_None, /*DefArg*/ nullptr); + + NewParam->setScopeInfo(0, Params.size()); + NewParam->setIsUsed(); + + LastParamIndex = Params.size(); + Params.push_back(NewParam); + } + // All special SYCL objects must have __init method. We extract types for // kernel parameters from __init method parameters. We will use __init method // and kernel parameters which we build here to initialize special objects in // the kernel body. - auto createSpecialSYCLObjParamDesc = [&](const FieldDecl *Fld, - const QualType &ArgTy) { + void handleSpecialType(const FieldDecl *FD, QualType ArgTy) { const auto *RecordDecl = ArgTy->getAsCXXRecordDecl(); - assert(RecordDecl && "Special SYCL object must be of a record type"); + assert(RecordDecl && "The accessor/sampler must be a RecordDecl"); + CXXMethodDecl *InitMethod = getMethodByName(RecordDecl, InitMethodName); + assert(InitMethod && "The accessor/sampler must have the __init method"); + // Don't do -1 here because we count on this to be the first parameter added + // (if any). + size_t ParamIndex = Params.size(); + for (const ParmVarDecl *Param : InitMethod->parameters()) + addParam(FD, Param->getType().getCanonicalType()); + LastParamIndex = ParamIndex; + } + + static void setKernelImplicitAttrs(ASTContext &Context, FunctionDecl *FD, + StringRef Name) { + // Set implict attributes. + FD->addAttr(OpenCLKernelAttr::CreateImplicit(Context)); + FD->addAttr(AsmLabelAttr::CreateImplicit(Context, Name)); + FD->addAttr(ArtificialAttr::CreateImplicit(Context)); + } + + static FunctionDecl *createKernelDecl(ASTContext &Ctx, StringRef Name, + SourceLocation Loc, bool IsInline) { + // Create this with no prototype, and we can fix this up after we've seen + // all the params. + FunctionProtoType::ExtProtoInfo Info(CC_OpenCLKernel); + QualType FuncType = Ctx.getFunctionType(Ctx.VoidTy, {}, Info); + + FunctionDecl *FD = FunctionDecl::Create( + Ctx, Ctx.getTranslationUnitDecl(), Loc, Loc, &Ctx.Idents.get(Name), + FuncType, Ctx.getTrivialTypeSourceInfo(Ctx.VoidTy), SC_None); + FD->setImplicitlyInline(IsInline); + setKernelImplicitAttrs(Ctx, FD, Name); + + // Add kernel to translation unit to see it in AST-dump. + Ctx.getTranslationUnitDecl()->addDecl(FD); + return FD; + } + +public: + SyclKernelDeclCreator(Sema &S, SyclKernelFieldChecker &ArgChecker, + StringRef Name, SourceLocation Loc, bool IsInline) + : SyclKernelFieldHandler(S), + KernelDecl(createKernelDecl(S.getASTContext(), Name, Loc, IsInline)), + ArgChecker(ArgChecker), FuncContext(SemaRef, KernelDecl) {} + + ~SyclKernelDeclCreator() { + ASTContext &Ctx = SemaRef.getASTContext(); + FunctionProtoType::ExtProtoInfo Info(CC_OpenCLKernel); + + SmallVector ArgTys; + std::transform(std::begin(Params), std::end(Params), + std::back_inserter(ArgTys), + [](const ParmVarDecl *PVD) { return PVD->getType(); }); + + QualType FuncType = Ctx.getFunctionType(Ctx.VoidTy, ArgTys, Info); + KernelDecl->setType(FuncType); + KernelDecl->setParams(Params); + + if (ArgChecker.isValid()) + SemaRef.addSyclDeviceDecl(KernelDecl); + } + + void handleSyclAccessorType(const CXXBaseSpecifier &BS, + QualType ArgTy) final { + const auto *RecordDecl = ArgTy->getAsCXXRecordDecl(); + assert(RecordDecl && "The accessor/sampler must be a RecordDecl"); CXXMethodDecl *InitMethod = getMethodByName(RecordDecl, InitMethodName); assert(InitMethod && "The accessor/sampler must have the __init method"); - unsigned NumParams = InitMethod->getNumParams(); - for (size_t I = 0; I < NumParams; ++I) { - ParmVarDecl *PD = InitMethod->getParamDecl(I); - CreateAndAddPrmDsc(Fld, PD->getType().getCanonicalType()); - } - }; - - // Create parameter descriptor for accessor in case when it's wrapped with - // some class. - // TODO: Do we need support case when sampler is wrapped with some class or - // struct? - std::function - createParamDescForWrappedAccessors = - [&](const FieldDecl *Fld, const QualType &ArgTy) { - const auto *Wrapper = ArgTy->getAsCXXRecordDecl(); - for (const auto *WrapperFld : Wrapper->fields()) { - QualType FldType = WrapperFld->getType(); - if (FldType->isStructureOrClassType()) { - if (Util::isSyclAccessorType(FldType)) { - // Accessor field is found - create descriptor. - createSpecialSYCLObjParamDesc(WrapperFld, FldType); - } else if (Util::isSyclSpecConstantType(FldType)) { - // Don't try recursive search below. - } else { - // Field is some class or struct - recursively check for - // accessor fields. - createParamDescForWrappedAccessors(WrapperFld, FldType); - } - } - } - }; - - bool AllArgsAreValid = true; - // Run through kernel object fields and create corresponding kernel - // parameters descriptors. There are a several possible cases: - // - Kernel object field is a SYCL special object (SYCL accessor or SYCL - // sampler). These objects has a special initialization scheme - using - // __init method. - // - Kernel object field has a scalar type. In this case we should add - // kernel parameter with the same type. - // - Kernel object field has a structure or class type. Same handling as a - // scalar but we should check if this structure/class contains accessors - // and add parameter decriptor for them properly. - for (const auto *Fld : KernelObj->fields()) { - QualType ArgTy = Fld->getType(); - if (Util::isSyclAccessorType(ArgTy) || Util::isSyclSamplerType(ArgTy)) { - createSpecialSYCLObjParamDesc(Fld, ArgTy); - } else if (Util::isSyclSpecConstantType(ArgTy)) { - // Specialization constants are not added as arguments. - } else if (ArgTy->isStructureOrClassType()) { - if (Context.getLangOpts().SYCLStdLayoutKernelParams) { - if (!ArgTy->isStandardLayoutType()) { - Context.getDiagnostics().Report(Fld->getLocation(), - diag::err_sycl_non_std_layout_type) - << ArgTy; - AllArgsAreValid = false; - continue; - } - } - CXXRecordDecl *RD = - cast(ArgTy->getAs()->getDecl()); - if (!RD->hasTrivialCopyConstructor()) { - Context.getDiagnostics().Report( - Fld->getLocation(), - diag::err_sycl_non_trivially_copy_ctor_dtor_type) - << 0 << ArgTy; - AllArgsAreValid = false; - continue; - } - if (!RD->hasTrivialDestructor()) { - Context.getDiagnostics().Report( - Fld->getLocation(), - diag::err_sycl_non_trivially_copy_ctor_dtor_type) - << 1 << ArgTy; - AllArgsAreValid = false; - continue; - } + // Don't do -1 here because we count on this to be the first parameter added + // (if any). + size_t ParamIndex = Params.size(); + for (const ParmVarDecl *Param : InitMethod->parameters()) + addParam(BS, Param->getType().getCanonicalType()); + LastParamIndex = ParamIndex; + } - CreateAndAddPrmDsc(Fld, ArgTy); - - // Create descriptors for each accessor field in the class or struct - createParamDescForWrappedAccessors(Fld, ArgTy); - } else if (ArgTy->isReferenceType()) { - Context.getDiagnostics().Report( - Fld->getLocation(), diag::err_bad_kernel_param_type) << ArgTy; - AllArgsAreValid = false; - } else if (ArgTy->isPointerType()) { - // Pointer Arguments need to be in the global address space - QualType PointeeTy = ArgTy->getPointeeType(); - Qualifiers Quals = PointeeTy.getQualifiers(); - Quals.setAddressSpace(LangAS::opencl_global); - PointeeTy = - Context.getQualifiedType(PointeeTy.getUnqualifiedType(), Quals); - QualType ModTy = Context.getPointerType(PointeeTy); - - CreateAndAddPrmDsc(Fld, ModTy); - } else if (ArgTy->isScalarType()) { - CreateAndAddPrmDsc(Fld, ArgTy); - } else { - llvm_unreachable("Unsupported kernel parameter type"); - } + void handleSyclAccessorType(const FieldDecl *FD, QualType ArgTy) final { + handleSpecialType(FD, ArgTy); } - return AllArgsAreValid; -} + void handleSyclSamplerType(const FieldDecl *FD, QualType ArgTy) final { + handleSpecialType(FD, ArgTy); + } + + void handlePointerType(const FieldDecl *FD, QualType ArgTy) final { + // TODO: Can we document what the heck this is doing?! + QualType PointeeTy = ArgTy->getPointeeType(); + Qualifiers Quals = PointeeTy.getQualifiers(); + Quals.setAddressSpace(LangAS::opencl_global); + PointeeTy = SemaRef.getASTContext().getQualifiedType( + PointeeTy.getUnqualifiedType(), Quals); + QualType ModTy = SemaRef.getASTContext().getPointerType(PointeeTy); + addParam(FD, ModTy); + } + + void handleScalarType(const FieldDecl *FD, QualType ArgTy) final { + addParam(FD, ArgTy); + } -/// Adds necessary data describing given kernel to the integration header. -/// \param H the integration header object -/// \param Name kernel name -/// \param NameType type representing kernel name (first template argument -/// of single_task, parallel_for, etc) -/// \param KernelObjTy kernel object type -static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name, - QualType NameType, CXXRecordDecl *KernelObjTy) { - - ASTContext &Ctx = KernelObjTy->getASTContext(); - const ASTRecordLayout &Layout = Ctx.getASTRecordLayout(KernelObjTy); - const std::string StableName = PredefinedExpr::ComputeName( - Ctx, PredefinedExpr::UniqueStableNameExpr, NameType); - H.startKernel(Name, NameType, StableName, KernelObjTy->getLocation()); - - auto populateHeaderForAccessor = [&](const QualType &ArgTy, uint64_t Offset) { - // The parameter is a SYCL accessor object. - // The Info field of the parameter descriptor for accessor contains - // two template parameters packed into an integer field: - // - target (e.g. global_buffer, constant_buffer, local); - // - dimension of the accessor. - const auto *AccTy = ArgTy->getAsCXXRecordDecl(); - assert(AccTy && "accessor must be of a record type"); - const auto *AccTmplTy = cast(AccTy); + // This is implemented here because this is the only case where the recurse + // object is required. The base type is pretty cheap, so we might opt + // to just always create it (the way this one is implemented) and just put + // this implementation in the base. + void handleStructType(const FieldDecl *FD, QualType ArgTy) final { + addParam(FD, ArgTy); + } + + void setBody(CompoundStmt *KB) { KernelDecl->setBody(KB); } + + FunctionDecl *getKernelDecl() { return KernelDecl; } + + llvm::ArrayRef getParamVarDeclsForCurrentField() { + return ArrayRef(std::begin(Params) + LastParamIndex, + std::end(Params)); + } +}; + +class SyclKernelBodyCreator + : public SyclKernelFieldHandler { + SyclKernelDeclCreator &DeclCreator; + llvm::SmallVector BodyStmts; + llvm::SmallVector FinalizeStmts; + llvm::SmallVector InitExprs; + + // Using the statements/init expressions that we've created, this generates + // the kernel body compound stmt. CompoundStmt needs to know its number of + // statements in advance to allocate it, so we cannot do this as we go along. + CompoundStmt *createKernelBody() { + // TODO: Can we hold off on creating KernelObjClone to here? + + Expr *ILE = new (SemaRef.getASTContext()) InitListExpr( + SemaRef.getASTContext(), SourceLocation(), InitExprs, SourceLocation()); + // TODO!!! ILE->setType(QualType(LC->getTypeForDecl(), 0)); + // KernelObjectClone->setInit(ILE); + + // TODO: More kernel object init with KernelBodyTransform. + + BodyStmts.insert(std::end(BodyStmts), std::begin(FinalizeStmts), + std::begin(FinalizeStmts)); + return CompoundStmt::Create(SemaRef.getASTContext(), BodyStmts, {}, {}); + } + + // TODO: not sure what this does yet, name is a placeholder for future use. + void doSomethingForParallelForWorkGroup() { + } + +public: + SyclKernelBodyCreator(Sema &S, SyclKernelDeclCreator &DC, + KernelInvocationKind K) + : SyclKernelFieldHandler(S), DeclCreator(DC) { + // TODO: Something special with the lambda when InvokeParallelForWorkGroup. + if (K == InvokeParallelForWorkGroup) + doSomethingForParallelForWorkGroup(); + } + ~SyclKernelBodyCreator() { + CompoundStmt *KernelBody = createKernelBody(); + DeclCreator.setBody(KernelBody); + } + + void handleSyclAccessorType(const FieldDecl *FD, QualType Ty) final { + // TODO: Creates init sequence and inits special sycl obj + } + + void handleSyclAccessorType(const CXXBaseSpecifier &BS, QualType Ty) final { + // TODO: Creates init sequence and inits special sycl obj + } + + + void handleSyclSamplerType(const FieldDecl *FD, QualType Ty) final { + // TODO: Creates init sequence and inits special sycl obj + } + + void handleSyclStreamType(const FieldDecl *FD, QualType Ty) final { + // TODO: Creates init/finalize sequence and inits special sycl obj + } + + void handleSyclStreamType(const CXXBaseSpecifier &BS, QualType Ty) final { + // TODO: Creates init/finalize sequence and inits special sycl obj + } + + + void handleStructType(const FieldDecl *FD, QualType Ty) final { + // TODO: a bunch of work doing inits, note this has a little more than + // scalar. + } + void handleScalarType(const FieldDecl *FD, QualType Ty) final { + // TODO: a bunch of work doing inits. + } +}; + +class SyclKernelIntHeaderCreator + : public SyclKernelFieldHandler { + SYCLIntegrationHeader &Header; + const CXXRecordDecl *KernelLambda; + int64_t CurOffset = 0; + + uint64_t getOffset(const CXXRecordDecl *RD) const { + // TODO: Figure this out! Offset of a base class. + return 0; + } + uint64_t getOffset(const FieldDecl *FD) const { + // TODO: Figure out how to calc lower down the structs, currently only gives + // the 'base' value. + return CurOffset + SemaRef.getASTContext().getFieldOffset(FD) / 8; + } + + void addParam(const FieldDecl *FD, QualType ArgTy, + SYCLIntegrationHeader::kernel_param_kind_t Kind) { + uint64_t Size = + SemaRef.getASTContext().getTypeSizeInChars(ArgTy).getQuantity(); + Header.addParamDesc(Kind, static_cast(Size), + static_cast(getOffset(FD))); + } + +public: + SyclKernelIntHeaderCreator(Sema &S, SYCLIntegrationHeader &H, + const CXXRecordDecl *KernelLambda, + QualType NameType, StringRef Name, + StringRef StableName) + : SyclKernelFieldHandler(S), Header(H), KernelLambda(KernelLambda) { + Header.startKernel(Name, NameType, StableName, KernelLambda->getLocation()); + } + + void handleSyclAccessorType(const CXXBaseSpecifier &BC, + QualType ArgTy) final { + const auto *AccTy = + cast(ArgTy->getAsRecordDecl()); + assert(AccTy->getTemplateArgs().size() >= 2 && + "Incorrect template args for Accessor Type"); int Dims = static_cast( - AccTmplTy->getTemplateArgs()[1].getAsIntegral().getExtValue()); - int Info = getAccessTarget(AccTmplTy) | (Dims << 11); - H.addParamDesc(SYCLIntegrationHeader::kind_accessor, Info, Offset); - }; - - std::function - populateHeaderForWrappedAccessors = [&](const QualType &ArgTy, - uint64_t Offset) { - const auto *Wrapper = ArgTy->getAsCXXRecordDecl(); - for (const auto *WrapperFld : Wrapper->fields()) { - QualType FldType = WrapperFld->getType(); - if (FldType->isStructureOrClassType()) { - ASTContext &WrapperCtx = Wrapper->getASTContext(); - const ASTRecordLayout &WrapperLayout = - WrapperCtx.getASTRecordLayout(Wrapper); - // Get offset (in bytes) of the field in wrapper class or struct - uint64_t OffsetInWrapper = - WrapperLayout.getFieldOffset(WrapperFld->getFieldIndex()) / 8; - if (Util::isSyclAccessorType(FldType)) { - // This is an accesor - populate the header appropriately - populateHeaderForAccessor(FldType, Offset + OffsetInWrapper); - } else { - // This is an other class or struct - recursively search for an - // accessor field - populateHeaderForWrappedAccessors(FldType, - Offset + OffsetInWrapper); - } - } - } - }; + AccTy->getTemplateArgs()[1].getAsIntegral().getExtValue()); + int Info = getAccessTarget(AccTy) | (Dims << 11); + Header.addParamDesc(SYCLIntegrationHeader::kind_accessor, Info, + // TODO: is this the right way? + getOffset(BC.getType()->getAsCXXRecordDecl())); + } - for (const auto Fld : KernelObjTy->fields()) { - QualType ActualArgType; - QualType ArgTy = Fld->getType(); - - // Get offset in bytes - uint64_t Offset = Layout.getFieldOffset(Fld->getFieldIndex()) / 8; - - if (Util::isSyclAccessorType(ArgTy)) { - populateHeaderForAccessor(ArgTy, Offset); - } else if (Util::isSyclSamplerType(ArgTy)) { - // The parameter is a SYCL sampler object - const auto *SamplerTy = ArgTy->getAsCXXRecordDecl(); - assert(SamplerTy && "sampler must be of a record type"); - - CXXMethodDecl *InitMethod = getMethodByName(SamplerTy, InitMethodName); - assert(InitMethod && "sampler must have __init method"); - - // sampler __init method has only one argument - auto *FuncDecl = cast(InitMethod); - ParmVarDecl *SamplerArg = FuncDecl->getParamDecl(0); - assert(SamplerArg && "sampler __init method must have sampler parameter"); - uint64_t Sz = Ctx.getTypeSizeInChars(SamplerArg->getType()).getQuantity(); - H.addParamDesc(SYCLIntegrationHeader::kind_sampler, - static_cast(Sz), static_cast(Offset)); - } else if (ArgTy->isPointerType()) { - uint64_t Sz = Ctx.getTypeSizeInChars(Fld->getType()).getQuantity(); - H.addParamDesc(SYCLIntegrationHeader::kind_pointer, - static_cast(Sz), static_cast(Offset)); - } else if (Util::isSyclSpecConstantType(ArgTy)) { - // Add specialization constant ID to the header. - auto *TmplSpec = - cast(ArgTy->getAsCXXRecordDecl()); - const TemplateArgumentList *TemplateArgs = - &TmplSpec->getTemplateInstantiationArgs(); - // Get specialization constant ID type, which is the second template - // argument. - QualType SpecConstIDTy = TypeName::getFullyQualifiedType( - TemplateArgs->get(1).getAsType(), Ctx, true) - .getCanonicalType(); - const std::string SpecConstName = PredefinedExpr::ComputeName( - Ctx, PredefinedExpr::UniqueStableNameExpr, SpecConstIDTy); - H.addSpecConstant(SpecConstName, SpecConstIDTy); - // Spec constant lambda capture does not become a kernel argument. - } else if (ArgTy->isStructureOrClassType() || ArgTy->isScalarType()) { - // the parameter is an object of standard layout type or scalar; - // the check for standard layout is done elsewhere - uint64_t Sz = Ctx.getTypeSizeInChars(Fld->getType()).getQuantity(); - H.addParamDesc(SYCLIntegrationHeader::kind_std_layout, - static_cast(Sz), static_cast(Offset)); - - // check for accessor fields in structure or class and populate the - // integration header appropriately - if (ArgTy->isStructureOrClassType()) { - populateHeaderForWrappedAccessors(ArgTy, Offset); - } - } else { - llvm_unreachable("unsupported kernel parameter type"); - } + void handleSyclAccessorType(const FieldDecl *FD, QualType ArgTy) final { + const auto *AccTy = + cast(ArgTy->getAsRecordDecl()); + assert(AccTy->getTemplateArgs().size() >= 2 && + "Incorrect template args for Accessor Type"); + int Dims = static_cast( + AccTy->getTemplateArgs()[1].getAsIntegral().getExtValue()); + int Info = getAccessTarget(AccTy) | (Dims << 11); + Header.addParamDesc(SYCLIntegrationHeader::kind_accessor, Info, + getOffset(FD)); } -} -// Creates a mangled kernel name for given kernel name type -static std::string constructKernelName(QualType KernelNameType, - MangleContext &MC) { - SmallString<256> Result; - llvm::raw_svector_ostream Out(Result); + void handleSyclSamplerType(const FieldDecl *FD, QualType ArgTy) final { + const auto *SamplerTy = ArgTy->getAsCXXRecordDecl(); + assert(SamplerTy && "Sampler type must be a C++ record type"); + CXXMethodDecl *InitMethod = getMethodByName(SamplerTy, InitMethodName); + assert(InitMethod && "sampler must have __init method"); - MC.mangleTypeName(KernelNameType, Out); - return std::string(Out.str()); -} + // sampler __init method has only one argument + const ParmVarDecl *SamplerArg = InitMethod->getParamDecl(0); + assert(SamplerArg && "sampler __init method must have sampler parameter"); -static FunctionDecl * -CreateOpenCLKernelDeclaration(ASTContext &Context, StringRef Name, - ArrayRef ParamDescs) { - - DeclContext *DC = Context.getTranslationUnitDecl(); - QualType RetTy = Context.VoidTy; - SmallVector ArgTys; - - // Extract argument types from the descriptor array: - std::transform( - ParamDescs.begin(), ParamDescs.end(), std::back_inserter(ArgTys), - [](const ParamDesc &PD) -> QualType { return std::get<0>(PD); }); - FunctionProtoType::ExtProtoInfo Info(CC_OpenCLKernel); - QualType FuncTy = Context.getFunctionType(RetTy, ArgTys, Info); - DeclarationName DN = DeclarationName(&Context.Idents.get(Name)); - - FunctionDecl *OpenCLKernel = FunctionDecl::Create( - Context, DC, SourceLocation(), SourceLocation(), DN, FuncTy, - Context.getTrivialTypeSourceInfo(RetTy), SC_None); - - llvm::SmallVector Params; - int i = 0; - for (const auto &PD : ParamDescs) { - auto P = ParmVarDecl::Create(Context, OpenCLKernel, SourceLocation(), - SourceLocation(), std::get<1>(PD), - std::get<0>(PD), std::get<2>(PD), SC_None, 0); - P->setScopeInfo(0, i++); - P->setIsUsed(); - Params.push_back(P); - } - OpenCLKernel->setParams(Params); - - OpenCLKernel->addAttr(OpenCLKernelAttr::CreateImplicit(Context)); - OpenCLKernel->addAttr(AsmLabelAttr::CreateImplicit(Context, Name)); - OpenCLKernel->addAttr(ArtificialAttr::CreateImplicit(Context)); - - // Add kernel to translation unit to see it in AST-dump - DC->addDecl(OpenCLKernel); - return OpenCLKernel; -} + addParam(FD, SamplerArg->getType(), SYCLIntegrationHeader::kind_sampler); + } + + void handleSyclSpecConstantType(const FieldDecl *FD, QualType ArgTy) final { + const TemplateArgumentList &TemplateArgs = + cast(ArgTy->getAsRecordDecl()) + ->getTemplateInstantiationArgs(); + assert(TemplateArgs.size() == 2 && + "Incorrect template args for Accessor Type"); + // Get specialization constant ID type, which is the second template + // argument. + QualType SpecConstIDTy = + TypeName::getFullyQualifiedType(TemplateArgs.get(1).getAsType(), + SemaRef.getASTContext(), true) + .getCanonicalType(); + const std::string SpecConstName = PredefinedExpr::ComputeName( + SemaRef.getASTContext(), PredefinedExpr::UniqueStableNameType, + SpecConstIDTy); + Header.addSpecConstant(SpecConstName, SpecConstIDTy); + } + + void handlePointerType(const FieldDecl *FD, QualType ArgTy) final { + addParam(FD, ArgTy, SYCLIntegrationHeader::kind_pointer); + } + void handleStructType(const FieldDecl *FD, QualType ArgTy) final { + addParam(FD, ArgTy, SYCLIntegrationHeader::kind_std_layout); + } + void handleScalarType(const FieldDecl *FD, QualType ArgTy) final { + addParam(FD, ArgTy, SYCLIntegrationHeader::kind_std_layout); + } + + // Keep track of the current struct offset. + void enterStruct(const CXXRecordDecl *, const FieldDecl *FD) final { + CurOffset += SemaRef.getASTContext().getFieldOffset(FD) / 8; + } + + void leaveStruct(const CXXRecordDecl *, const FieldDecl *FD) final { + CurOffset -= SemaRef.getASTContext().getFieldOffset(FD) / 8; + } + + void enterStruct(const CXXRecordDecl *RD, const CXXBaseSpecifier &BS) final { + const ASTRecordLayout &Layout = + SemaRef.getASTContext().getASTRecordLayout(RD); + CurOffset += Layout.getBaseClassOffset(BS.getType()->getAsCXXRecordDecl()) + .getQuantity(); + } + + void leaveStruct(const CXXRecordDecl *RD, const CXXBaseSpecifier &BS) final { + const ASTRecordLayout &Layout = + SemaRef.getASTContext().getASTRecordLayout(RD); + CurOffset -= Layout.getBaseClassOffset(BS.getType()->getAsCXXRecordDecl()) + .getQuantity(); + } +}; +} // namespace // Generates the OpenCL kernel using KernelCallerFunc (kernel caller // function) defined is SYCL headers. @@ -1331,52 +1601,42 @@ CreateOpenCLKernelDeclaration(ASTContext &Context, StringRef Name, // void Sema::ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc, MangleContext &MC) { - CXXRecordDecl *LE = getKernelObjectType(KernelCallerFunc); - assert(LE && "invalid kernel caller"); - - // Build list of kernel arguments - llvm::SmallVector ParamDescs; - if (!buildArgTys(getASTContext(), LE, ParamDescs)) - return; - - // Extract name from kernel caller parameters and mangle it. - const TemplateArgumentList *TemplateArgs = - KernelCallerFunc->getTemplateSpecializationArgs(); - assert(TemplateArgs && "No template argument info"); - QualType KernelNameType = TypeName::getFullyQualifiedType( - TemplateArgs->get(0).getAsType(), getASTContext(), true); - - std::string Name; - // TODO SYCLIntegrationHeader also computes a unique stable name. It should - // probably lose this responsibility and only use the name provided here. - if (getLangOpts().SYCLUnnamedLambda) - Name = PredefinedExpr::ComputeName( - getASTContext(), PredefinedExpr::UniqueStableNameExpr, KernelNameType); - else - Name = constructKernelName(KernelNameType, MC); - - // TODO Maybe don't emit integration header inside the Sema? - populateIntHeader(getSyclIntegrationHeader(), Name, KernelNameType, LE); - - FunctionDecl *OpenCLKernel = - CreateOpenCLKernelDeclaration(getASTContext(), Name, ParamDescs); - - ContextRAII FuncContext(*this, OpenCLKernel); - - // Let's copy source location of a functor/lambda to emit nicer diagnostics - OpenCLKernel->setLocation(LE->getLocation()); - - // If the source function is implicitly inline, the kernel should be marked - // such as well. This allows the kernel to be ODR'd if there are multiple uses - // in different translation units. - OpenCLKernel->setImplicitlyInline(KernelCallerFunc->isInlined()); + // The first argument to the KernelCallerFunc is the lambda object. + CXXRecordDecl *KernelLambda = getKernelObjectType(KernelCallerFunc); + assert(KernelLambda && "invalid kernel caller"); + + // Calculate both names, since Integration headers need both. + std::string CalculatedName = + constructKernelName(*this, KernelCallerFunc, MC, /*StableName*/ false); + std::string StableName = + constructKernelName(*this, KernelCallerFunc, MC, /*StableName*/ true); + StringRef KernelName(getLangOpts().SYCLUnnamedLambda ? StableName + : CalculatedName); + + SyclKernelFieldChecker checker(*this); + SyclKernelDeclCreator kernel_decl(*this, checker, KernelName, + KernelLambda->getLocation(), + KernelCallerFunc->isInlined()); + SyclKernelBodyCreator kernel_body(*this, kernel_decl, + getKernelInvocationKind(KernelCallerFunc)); + SyclKernelIntHeaderCreator int_header( + *this, getSyclIntegrationHeader(), KernelLambda, + calculateKernelNameType(Context, KernelCallerFunc), CalculatedName, + StableName); ConstructingOpenCLKernel = true; - CompoundStmt *OpenCLKernelBody = - CreateOpenCLKernelBody(*this, KernelCallerFunc, OpenCLKernel); + VisitRecordFields(KernelLambda->fields(), checker, kernel_decl, kernel_body, + int_header); ConstructingOpenCLKernel = false; - OpenCLKernel->setBody(OpenCLKernelBody); - addSyclDeviceDecl(OpenCLKernel); + + /* + //ConstructingOpenCLKernel = true; + ****CompoundStmt *OpenCLKernelBody = + CreateOpenCLKernelBody(*this, KernelCallerFunc, OpenCLKernel); + //ConstructingOpenCLKernel = false; + //OpenCLKernel->setBody(OpenCLKernelBody); + //addSyclDeviceDecl(OpenCLKernel); + */ } void Sema::MarkDevice(void) {