Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions sycl/include/sycl/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1544,8 +1544,37 @@ template <typename VecT, typename OperationLeftT, typename OperationRightT,
template <typename> class OperationCurrentT, int... Indexes>
class SwizzleOp {
using DataT = typename VecT::element_type;
using CommonDataT = std::common_type_t<typename OperationLeftT::DataT,
typename OperationRightT::DataT>;
// Certain operators return a vector with a different element type. Also, the
// left and right operand types may differ. CommonDataT selects a result type
// based on these types to ensure that the result value can be represented.
//
// Example 1:
// sycl::vec<unsigned char, 4> vec{...};
// auto result = 300u + vec.x();
//
// CommonDataT is std::common_type_t<OperationLeftT, OperationRightT> since
// it's larger than unsigned char.
//
// Example 2:
// sycl::vec<bool, 1> vec{...};
// auto result = vec.template swizzle<sycl::elem::s0>() && vec;
//
// CommonDataT is DataT since operator&& returns a vector with element type
// int8_t, which is larger than bool.
//
// Example 3:
// sycl::vec<std::byte, 4> vec{...}; auto swlo = vec.lo();
// auto result = swlo == swlo;
//
// CommonDataT is DataT since operator== returns a vector with element type
// int8_t, which is the same size as std::byte. std::common_type_t<DataT, ...>
// can't be used here since there's no type that int8_t and std::byte can both
// be implicitly converted to.
using OpLeftDataT = typename OperationLeftT::DataT;
using OpRightDataT = typename OperationRightT::DataT;
using CommonDataT = std::conditional_t<
sizeof(DataT) >= sizeof(std::common_type_t<OpLeftDataT, OpRightDataT>),
DataT, std::common_type_t<OpLeftDataT, OpRightDataT>>;
static constexpr int getNumElements() { return sizeof...(Indexes); }

using rel_t = detail::rel_t<DataT>;
Expand Down
68 changes: 68 additions & 0 deletions sycl/test-e2e/Regression/vec_rel_swizzle_ops.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

// RUN: %if preview-breaking-changes-supported %{ %clangxx -fsycl -fpreview-breaking-changes %s -o %t2.out %}
// RUN: %if preview-breaking-changes-supported %{ %{run} %t2.out %}

#include <cstdlib>
#include <sycl/sycl.hpp>

template <typename T, typename ResultT>
bool testAndOperator(const std::string &typeName) {
constexpr int N = 5;
std::array<ResultT, N> results{};

sycl::queue q;
sycl::buffer<ResultT, 1> buffer{results.data(), N};
q.submit([&](sycl::handler &cgh) {
sycl::accessor acc{buffer, cgh, sycl::write_only};
cgh.parallel_for(sycl::range<1>{1}, [=](sycl::id<1> id) {
auto testVec1 = sycl::vec<T, 1>(static_cast<T>(1));
auto testVec2 = sycl::vec<T, 1>(static_cast<T>(2));
sycl::vec<ResultT, 1> resVec;

ResultT expected = static_cast<ResultT>(
-(static_cast<ResultT>(1) && static_cast<ResultT>(2)));
acc[0] = expected;

// LHS swizzle
resVec = testVec1.template swizzle<sycl::elem::s0>() && testVec2;
acc[1] = resVec[0];

// RHS swizzle
resVec = testVec1 && testVec2.template swizzle<sycl::elem::s0>();
acc[2] = resVec[0];

// No swizzle
resVec = testVec1 && testVec2;
acc[3] = resVec[0];

// Both swizzle
resVec = testVec1.template swizzle<sycl::elem::s0>() &&
testVec2.template swizzle<sycl::elem::s0>();
acc[4] = resVec[0];
});
}).wait();

bool passed = true;
ResultT expected = results[0];

std::cout << "Testing with T = " << typeName << std::endl;
std::cout << "Expected: " << (int)expected << std::endl;
for (int i = 1; i < N; i++) {
std::cout << "Test " << (i - 1) << ": " << ((int)results[i]) << std::endl;
passed &= expected == results[i];
}
std::cout << std::endl;
return passed;
}

int main() {
bool passed = true;
passed &= testAndOperator<bool, std::int8_t>("bool");
passed &= testAndOperator<std::int8_t, std::int8_t>("std::int8_t");
passed &= testAndOperator<float, std::int32_t>("float");
passed &= testAndOperator<int, std::int32_t>("int");
std::cout << (passed ? "Pass" : "Fail") << std::endl;
return (passed ? EXIT_SUCCESS : EXIT_FAILURE);
}