@@ -54,9 +54,10 @@ ur_program_handle_t_::setMetadata(const ur_program_metadata_t *Metadata,
54
54
55
55
auto [Prefix, Tag] = splitMetadataName (MetadataElementName);
56
56
57
- if (Tag == __SYCL_UR_PROGRAM_METADATA_TAG_REQD_WORK_GROUP_SIZE) {
58
- // If metadata is reqd_work_group_size, record it for the corresponding
59
- // kernel name.
57
+ if (Tag == __SYCL_UR_PROGRAM_METADATA_TAG_REQD_WORK_GROUP_SIZE ||
58
+ Tag == __SYCL_UR_PROGRAM_METADATA_TAG_MAX_WORK_GROUP_SIZE) {
59
+ // If metadata is reqd_work_group_size/max_work_group_size, record it for
60
+ // the corresponding kernel name.
60
61
size_t MDElemsSize = MetadataElement.size - sizeof (std::uint64_t );
61
62
62
63
// Expect between 1 and 3 32-bit integer values.
@@ -69,18 +70,23 @@ ur_program_handle_t_::setMetadata(const ur_program_metadata_t *Metadata,
69
70
reinterpret_cast <const char *>(MetadataElement.value .pData ) +
70
71
sizeof (std::uint64_t );
71
72
// Read values and pad with 1's for values not present.
72
- std::uint32_t ReqdWorkGroupElements[] = {1 , 1 , 1 };
73
- std::memcpy (ReqdWorkGroupElements, ValuePtr, MDElemsSize);
74
- KernelReqdWorkGroupSizeMD[Prefix] =
75
- std::make_tuple (ReqdWorkGroupElements[0 ], ReqdWorkGroupElements[1 ],
76
- ReqdWorkGroupElements[2 ]);
73
+ std::array<uint32_t , 3 > WorkGroupElements = {1 , 1 , 1 };
74
+ std::memcpy (WorkGroupElements.data (), ValuePtr, MDElemsSize);
75
+ (Tag == __SYCL_UR_PROGRAM_METADATA_TAG_REQD_WORK_GROUP_SIZE
76
+ ? KernelReqdWorkGroupSizeMD
77
+ : KernelMaxWorkGroupSizeMD)[Prefix] =
78
+ std::make_tuple (WorkGroupElements[0 ], WorkGroupElements[1 ],
79
+ WorkGroupElements[2 ]);
77
80
} else if (Tag == __SYCL_UR_PROGRAM_METADATA_GLOBAL_ID_MAPPING) {
78
81
const char *MetadataValPtr =
79
82
reinterpret_cast <const char *>(MetadataElement.value .pData ) +
80
83
sizeof (std::uint64_t );
81
84
const char *MetadataValPtrEnd =
82
85
MetadataValPtr + MetadataElement.size - sizeof (std::uint64_t );
83
86
GlobalIDMD[Prefix] = std::string{MetadataValPtr, MetadataValPtrEnd};
87
+ } else if (Tag ==
88
+ __SYCL_UR_PROGRAM_METADATA_TAG_MAX_LINEAR_WORK_GROUP_SIZE) {
89
+ KernelMaxLinearWorkGroupSizeMD[Prefix] = MetadataElement.value .data64 ;
84
90
}
85
91
}
86
92
return UR_RESULT_SUCCESS;
0 commit comments