@@ -454,9 +454,9 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
454454 return Res;
455455 };
456456
457- QualType FieldType = Field-> getType ();
458- CXXRecordDecl *CRD = FieldType-> getAsCXXRecordDecl ();
459- if ( CRD && Util::isSyclAccessorType (FieldType) ) {
457+ auto getExprForAccessorInit = [&]( const QualType ¶mTy,
458+ FieldDecl *Field,
459+ const CXXRecordDecl * CRD, Expr *Base ) {
460460 // Since this is an accessor next 4 TargetFuncParams including current
461461 // should be set in __init method: _ValueType*, range<int>, range<int>,
462462 // id<int>
@@ -472,9 +472,9 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
472472 std::advance (TargetFuncParam, NumParams - 1 );
473473
474474 DeclAccessPair FieldDAP = DeclAccessPair::make (Field, AS_none);
475- // kernel_obj .accessor
475+ // [kenrel_obj or wrapper object] .accessor
476476 auto AccessorME = MemberExpr::Create (
477- S.Context , CloneRef , false , SourceLocation (),
477+ S.Context , Base , false , SourceLocation (),
478478 NestedNameSpecifierLoc (), SourceLocation (), Field, FieldDAP,
479479 DeclarationNameInfo (Field->getDeclName (), SourceLocation ()),
480480 nullptr , Field->getType (), VK_LValue, OK_Ordinary);
@@ -488,7 +488,7 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
488488 }
489489 assert (InitMethod && " The accessor must have the __init method" );
490490
491- // kernel_obj .accessor.__init
491+ // [kenrel_obj or wrapper object] .accessor.__init
492492 DeclAccessPair MethodDAP = DeclAccessPair::make (InitMethod, AS_none);
493493 auto ME = MemberExpr::Create (
494494 S.Context , AccessorME, false , SourceLocation (),
@@ -515,11 +515,52 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
515515 S, ((*ParamItr++))->getOriginalType (), ParamDREs[2 ]));
516516 ParamStmts.push_back (getExprForRangeOrOffset (
517517 S, ((*ParamItr++))->getOriginalType (), ParamDREs[3 ]));
518- // kernel_obj .accessor.__init(_ValueType*, range<int>, range<int> ,
519- // id<int>)
518+ // [kenrel_obj or wrapper object] .accessor.__init(_ValueType*,
519+ // range<int>, range<int>, id<int>)
520520 CXXMemberCallExpr *Call = CXXMemberCallExpr::Create (
521521 S.Context , ME, ParamStmts, ResultTy, VK, SourceLocation ());
522522 BodyStmts.push_back (Call);
523+ };
524+
525+ // Recursively search for accessor fields to initialize them with kernel
526+ // parameters
527+ std::function<void (const CXXRecordDecl *, Expr *)>
528+ getExprForWrappedAccessorInit = [&](const CXXRecordDecl *CRD,
529+ Expr *Base) {
530+ for (auto *WrapperFld : CRD->fields ()) {
531+ QualType FldType = WrapperFld->getType ();
532+ CXXRecordDecl *WrapperFldCRD = FldType->getAsCXXRecordDecl ();
533+ if (FldType->isStructureOrClassType ()) {
534+ if (Util::isSyclAccessorType (FldType)) {
535+ // Accessor field found - create expr to initialize this
536+ // accessor object. Need to start from the next target
537+ // function parameter, since current one is the wrapper object
538+ // or parameter of the previous processed accessor object.
539+ TargetFuncParam++;
540+ getExprForAccessorInit (FldType, WrapperFld, WrapperFldCRD,
541+ Base);
542+ } else {
543+ // Field is a structure or class so change the wrapper object
544+ // and recursively search for accessor field.
545+ DeclAccessPair WrapperFieldDAP =
546+ DeclAccessPair::make (WrapperFld, AS_none);
547+ auto NewBase = MemberExpr::Create (
548+ S.Context , Base, false , SourceLocation (),
549+ NestedNameSpecifierLoc (), SourceLocation (), WrapperFld,
550+ WrapperFieldDAP,
551+ DeclarationNameInfo (WrapperFld->getDeclName (),
552+ SourceLocation ()),
553+ nullptr , WrapperFld->getType (), VK_LValue, OK_Ordinary);
554+ getExprForWrappedAccessorInit (WrapperFldCRD, NewBase);
555+ }
556+ }
557+ }
558+ };
559+
560+ QualType FieldType = Field->getType ();
561+ CXXRecordDecl *CRD = FieldType->getAsCXXRecordDecl ();
562+ if (Util::isSyclAccessorType (FieldType)) {
563+ getExprForAccessorInit (FieldType, Field, CRD, CloneRef);
523564 } else if (CRD && Util::isSyclSamplerType (FieldType)) {
524565
525566 // Sampler has only one TargetFuncParam, which should be set in
@@ -596,6 +637,12 @@ CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelCallerFunc, DeclContext *DC) {
596637 BinaryOperator (Lhs, Rhs, BO_Assign, FieldType, VK_LValue,
597638 OK_Ordinary, SourceLocation (), FPOptions ());
598639 BodyStmts.push_back (Res);
640+
641+ // If a structure/class type has accessor fields then we need to
642+ // initialize these accessors in proper way by calling __init method of
643+ // the accessor and passing corresponding kernel parameters.
644+ if (CRD)
645+ getExprForWrappedAccessorInit (CRD, Lhs);
599646 } else {
600647 llvm_unreachable (" unsupported field type" );
601648 }
@@ -675,56 +722,78 @@ static void buildArgTys(ASTContext &Context, CXXRecordDecl *KernelObj,
675722 // create a parameter descriptor and append it to the result
676723 ParamDescs.push_back (makeParamDesc (Fld, ArgType));
677724 };
725+
726+ auto createAccessorParamDesc = [&](const FieldDecl *Fld,
727+ const QualType &ArgTy) {
728+ // the parameter is a SYCL accessor object
729+ const auto *RecordDecl = ArgTy->getAsCXXRecordDecl ();
730+ assert (RecordDecl && " accessor must be of a record type" );
731+ const auto *TemplateDecl =
732+ cast<ClassTemplateSpecializationDecl>(RecordDecl);
733+ // First accessor template parameter - data type
734+ QualType PointeeType = TemplateDecl->getTemplateArgs ()[0 ].getAsType ();
735+ // Fourth parameter - access target
736+ target AccessTarget = getAccessTarget (TemplateDecl);
737+ Qualifiers Quals = PointeeType.getQualifiers ();
738+ // TODO: Support all access targets
739+ switch (AccessTarget) {
740+ case target::global_buffer:
741+ Quals.setAddressSpace (LangAS::opencl_global);
742+ break ;
743+ case target::constant_buffer:
744+ Quals.setAddressSpace (LangAS::opencl_constant);
745+ break ;
746+ case target::local:
747+ Quals.setAddressSpace (LangAS::opencl_local);
748+ break ;
749+ default :
750+ llvm_unreachable (" Unsupported access target" );
751+ }
752+ PointeeType =
753+ Context.getQualifiedType (PointeeType.getUnqualifiedType (), Quals);
754+ QualType PointerType = Context.getPointerType (PointeeType);
755+
756+ CreateAndAddPrmDsc (Fld, PointerType);
757+
758+ FieldDecl *AccessRangeFld =
759+ getFieldDeclByName (RecordDecl, {" impl" , " AccessRange" });
760+ assert (AccessRangeFld &&
761+ " The accessor.impl must contain the AccessRange field" );
762+ CreateAndAddPrmDsc (AccessRangeFld, AccessRangeFld->getType ());
763+
764+ FieldDecl *MemRangeFld =
765+ getFieldDeclByName (RecordDecl, {" impl" , " MemRange" });
766+ assert (MemRangeFld && " The accessor.impl must contain the MemRange field" );
767+ CreateAndAddPrmDsc (MemRangeFld, MemRangeFld->getType ());
768+
769+ FieldDecl *OffsetFld = getFieldDeclByName (RecordDecl, {" impl" , " Offset" });
770+ assert (OffsetFld && " The accessor.impl must contain the Offset field" );
771+ CreateAndAddPrmDsc (OffsetFld, OffsetFld->getType ());
772+ };
773+
774+ std::function<void (const FieldDecl *, const QualType &ArgTy)>
775+ createParamDescForWrappedAccessors =
776+ [&](const FieldDecl *Fld, const QualType &ArgTy) {
777+ const auto *Wrapper = ArgTy->getAsCXXRecordDecl ();
778+ for (const auto *WrapperFld : Wrapper->fields ()) {
779+ QualType FldType = WrapperFld->getType ();
780+ if (FldType->isStructureOrClassType ()) {
781+ if (Util::isSyclAccessorType (FldType)) {
782+ // accessor field is found - create descriptor
783+ createAccessorParamDesc (WrapperFld, FldType);
784+ } else {
785+ // field is some class or struct - recursively check for
786+ // accessor fields
787+ createParamDescForWrappedAccessors (WrapperFld, FldType);
788+ }
789+ }
790+ }
791+ };
792+
678793 for (const auto *Fld : KernelObj->fields ()) {
679794 QualType ArgTy = Fld->getType ();
680795 if (Util::isSyclAccessorType (ArgTy)) {
681- // the parameter is a SYCL accessor object
682- const auto *RecordDecl = ArgTy->getAsCXXRecordDecl ();
683- assert (RecordDecl && " accessor must be of a record type" );
684- const auto *TemplateDecl =
685- cast<ClassTemplateSpecializationDecl>(RecordDecl);
686- // First accessor template parameter - data type
687- QualType PointeeType = TemplateDecl->getTemplateArgs ()[0 ].getAsType ();
688- // Fourth parameter - access target
689- target AccessTarget = getAccessTarget (TemplateDecl);
690- Qualifiers Quals = PointeeType.getQualifiers ();
691- // TODO: Support all access targets
692- switch (AccessTarget) {
693- case target::global_buffer:
694- Quals.setAddressSpace (LangAS::opencl_global);
695- break ;
696- case target::constant_buffer:
697- Quals.setAddressSpace (LangAS::opencl_constant);
698- break ;
699- case target::local:
700- Quals.setAddressSpace (LangAS::opencl_local);
701- break ;
702- default :
703- llvm_unreachable (" Unsupported access target" );
704- }
705- // TODO: get address space from accessor template parameter.
706- PointeeType =
707- Context.getQualifiedType (PointeeType.getUnqualifiedType (), Quals);
708- QualType PointerType = Context.getPointerType (PointeeType);
709-
710- CreateAndAddPrmDsc (Fld, PointerType);
711-
712- FieldDecl *AccessRangeFld =
713- getFieldDeclByName (RecordDecl, {" impl" , " AccessRange" });
714- assert (AccessRangeFld &&
715- " The accessor.impl must contain the AccessRange field" );
716- CreateAndAddPrmDsc (AccessRangeFld, AccessRangeFld->getType ());
717-
718- FieldDecl *MemRangeFld =
719- getFieldDeclByName (RecordDecl, {" impl" , " MemRange" });
720- assert (MemRangeFld &&
721- " The accessor.impl must contain the MemRange field" );
722- CreateAndAddPrmDsc (MemRangeFld, MemRangeFld->getType ());
723-
724- FieldDecl *OffsetFld =
725- getFieldDeclByName (RecordDecl, {" impl" , " Offset" });
726- assert (OffsetFld && " The accessor.impl must contain the Offset field" );
727- CreateAndAddPrmDsc (OffsetFld, OffsetFld->getType ());
796+ createAccessorParamDesc (Fld, ArgTy);
728797 } else if (Util::isSyclSamplerType (ArgTy)) {
729798 // the parameter is a SYCL sampler object
730799 const auto *RecordDecl = ArgTy->getAsCXXRecordDecl ();
@@ -747,6 +816,8 @@ static void buildArgTys(ASTContext &Context, CXXRecordDecl *KernelObj,
747816 }
748817 // structure or class typed parameter - the same handling as a scalar
749818 CreateAndAddPrmDsc (Fld, ArgTy);
819+ // create descriptors for each accessor field in the class or struct
820+ createParamDescForWrappedAccessors (Fld, ArgTy);
750821 } else if (ArgTy->isScalarType ()) {
751822 // scalar typed parameter
752823 CreateAndAddPrmDsc (Fld, ArgTy);
@@ -770,14 +841,7 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name,
770841 const ASTRecordLayout &Layout = Ctx.getASTRecordLayout (KernelObjTy);
771842 H.startKernel (Name, NameType);
772843
773- for (const auto Fld : KernelObjTy->fields ()) {
774- QualType ActualArgType;
775- QualType ArgTy = Fld->getType ();
776-
777- // Get offset in bytes
778- uint64_t Offset = Layout.getFieldOffset (Fld->getFieldIndex ()) / 8 ;
779-
780- if (Util::isSyclAccessorType (ArgTy)) {
844+ auto populateHeaderForAccessor = [&](const QualType &ArgTy, uint64_t Offset) {
781845 // The parameter is a SYCL accessor object.
782846 // The Info field of the parameter descriptor for accessor contains
783847 // two template parameters packed into thid integer field:
@@ -790,6 +854,43 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name,
790854 AccTmplTy->getTemplateArgs ()[1 ].getAsIntegral ().getExtValue ());
791855 int Info = getAccessTarget (AccTmplTy) | (Dims << 11 );
792856 H.addParamDesc (SYCLIntegrationHeader::kind_accessor, Info, Offset);
857+ };
858+
859+ std::function<void (const QualType &, uint64_t Offset)>
860+ populateHeaderForWrappedAccessors = [&](const QualType &ArgTy,
861+ uint64_t Offset) {
862+ const auto *Wrapper = ArgTy->getAsCXXRecordDecl ();
863+ for (const auto *WrapperFld : Wrapper->fields ()) {
864+ QualType FldType = WrapperFld->getType ();
865+ if (FldType->isStructureOrClassType ()) {
866+ ASTContext &WrapperCtx = Wrapper->getASTContext ();
867+ const ASTRecordLayout &WrapperLayout =
868+ WrapperCtx.getASTRecordLayout (Wrapper);
869+ // Get offset (in bytes) of the field in wrapper class or struct
870+ uint64_t OffsetInWrapper =
871+ WrapperLayout.getFieldOffset (WrapperFld->getFieldIndex ()) / 8 ;
872+ if (Util::isSyclAccessorType (FldType)) {
873+ // This is an accesor - populate the header appropriately
874+ populateHeaderForAccessor (FldType, Offset + OffsetInWrapper);
875+ } else {
876+ // This is an other class or struct - recursively search for an
877+ // accessor field
878+ populateHeaderForWrappedAccessors (FldType,
879+ Offset + OffsetInWrapper);
880+ }
881+ }
882+ }
883+ };
884+
885+ for (const auto Fld : KernelObjTy->fields ()) {
886+ QualType ActualArgType;
887+ QualType ArgTy = Fld->getType ();
888+
889+ // Get offset in bytes
890+ uint64_t Offset = Layout.getFieldOffset (Fld->getFieldIndex ()) / 8 ;
891+
892+ if (Util::isSyclAccessorType (ArgTy)) {
893+ populateHeaderForAccessor (ArgTy, Offset);
793894 } else if (Util::isSyclSamplerType (ArgTy)) {
794895 // The parameter is a SYCL sampler object
795896 // It has only one descriptor, "m_Sampler"
@@ -810,6 +911,12 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name,
810911 uint64_t Sz = Ctx.getTypeSizeInChars (Fld->getType ()).getQuantity ();
811912 H.addParamDesc (SYCLIntegrationHeader::kind_std_layout,
812913 static_cast <unsigned >(Sz), static_cast <unsigned >(Offset));
914+
915+ // check for accessor fields in structure or class and populate the
916+ // integration header appropriately
917+ if (ArgTy->isStructureOrClassType ()) {
918+ populateHeaderForWrappedAccessors (ArgTy, Offset);
919+ }
813920 } else {
814921 llvm_unreachable (" unsupported kernel parameter type" );
815922 }
0 commit comments