Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
732 changes: 0 additions & 732 deletions dpctl/tensor/libtensor/include/kernels/boolean_reductions.hpp

This file was deleted.

253 changes: 127 additions & 126 deletions dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp

Large diffs are not rendered by default.

45 changes: 10 additions & 35 deletions dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2584,11 +2584,7 @@ gemm_batch_tree_nm_impl(sycl::queue &exec_q,
(reduction_nelems + preferred_reductions_per_wi * wg - 1) /
(preferred_reductions_per_wi * wg);

// max_max_wg prevents running out of resources on CPU
constexpr size_t max_max_wg = 2048;
size_t max_wg = std::min(
max_max_wg,
dev.get_info<sycl::info::device::max_work_group_size>() / 2);
size_t max_wg = reduction_detail::get_work_group_size(dev);

if (reduction_nelems <= preferred_reductions_per_wi * max_wg) {
resTy *tmp = sycl::malloc_device<resTy>(
Expand Down Expand Up @@ -2666,17 +2662,20 @@ gemm_batch_tree_nm_impl(sycl::queue &exec_q,
using OuterInnerDimsIndexerT =
dpctl::tensor::offset_utils::StridedIndexer;
using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;

const OuterInnerDimsIndexerT lhs_indexer(
inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides);
const OuterInnerDimsIndexerT rhs_indexer(
inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides);
constexpr TmpIndexerT res_indexer{};

using dpctl::tensor::offset_utils::Strided1DIndexer;
using dpctl::tensor::offset_utils::StridedIndexer;
using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer;
using dpctl::tensor::offset_utils::UnpackedStridedIndexer;
using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer<
StridedIndexer, UnpackedStridedIndexer, Strided1DIndexer>;

const StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset,
batch_shape_strides);
const UnpackedStridedIndexer rhs_batch_indexer(
Expand Down Expand Up @@ -2969,11 +2968,7 @@ gemm_batch_contig_tree_k_impl(sycl::queue &exec_q,
(reduction_nelems + preferred_reductions_per_wi * wg - 1) /
(preferred_reductions_per_wi * wg);

// max_max_wg prevents running out of resources on CPU
constexpr size_t max_max_wg = 2048;
size_t max_wg = std::min(
max_max_wg,
dev.get_info<sycl::info::device::max_work_group_size>() / 2);
size_t max_wg = reduction_detail::get_work_group_size(dev);

if (reduction_nelems <= preferred_reductions_per_wi * max_wg) {
resTy *tmp = sycl::malloc_device<resTy>(
Expand Down Expand Up @@ -3172,11 +3167,7 @@ gemm_batch_contig_tree_nm_impl(sycl::queue &exec_q,
(reduction_nelems + preferred_reductions_per_wi * wg - 1) /
(preferred_reductions_per_wi * wg);

// max_max_wg prevents running out of resources on CPU
constexpr size_t max_max_wg = 2048;
size_t max_wg = std::min(
max_max_wg,
dev.get_info<sycl::info::device::max_work_group_size>() / 2);
size_t max_wg = reduction_detail::get_work_group_size(dev);

if (reduction_nelems <= preferred_reductions_per_wi * max_wg) {
resTy *tmp = sycl::malloc_device<resTy>(
Expand Down Expand Up @@ -3558,11 +3549,7 @@ sycl::event gemm_tree_k_impl(sycl::queue &exec_q,
(reduction_nelems + preferred_reductions_per_wi * wg - 1) /
(preferred_reductions_per_wi * wg);

// max_max_wg prevents running out of resources on CPU
constexpr size_t max_max_wg = 2048;
size_t max_wg = std::min(
max_max_wg,
dev.get_info<sycl::info::device::max_work_group_size>() / 2);
size_t max_wg = reduction_detail::get_work_group_size(dev);

if (reduction_nelems <= preferred_reductions_per_wi * max_wg) {
resTy *tmp = sycl::malloc_device<resTy>(
Expand Down Expand Up @@ -3728,11 +3715,7 @@ sycl::event gemm_tree_nm_impl(sycl::queue &exec_q,
(reduction_nelems + preferred_reductions_per_wi * wg - 1) /
(preferred_reductions_per_wi * wg);

// max_max_wg prevents running out of resources on CPU
constexpr size_t max_max_wg = 2048;
size_t max_wg = std::min(
max_max_wg,
dev.get_info<sycl::info::device::max_work_group_size>() / 2);
size_t max_wg = reduction_detail::get_work_group_size(dev);

if (reduction_nelems <= preferred_reductions_per_wi * max_wg) {
resTy *tmp = sycl::malloc_device<resTy>(
Expand Down Expand Up @@ -3982,11 +3965,7 @@ sycl::event gemm_contig_tree_k_impl(sycl::queue &exec_q,
(reduction_nelems + preferred_reductions_per_wi * wg - 1) /
(preferred_reductions_per_wi * wg);

// max_max_wg prevents running out of resources on CPU
constexpr size_t max_max_wg = 2048;
size_t max_wg = std::min(
max_max_wg,
dev.get_info<sycl::info::device::max_work_group_size>() / 2);
size_t max_wg = reduction_detail::get_work_group_size(dev);

if (reduction_nelems <= preferred_reductions_per_wi * max_wg) {
resTy *tmp = sycl::malloc_device<resTy>(
Expand Down Expand Up @@ -4139,11 +4118,7 @@ sycl::event gemm_contig_tree_nm_impl(sycl::queue &exec_q,
(reduction_nelems + preferred_reductions_per_wi * wg - 1) /
(preferred_reductions_per_wi * wg);

// max_max_wg prevents running out of resources on CPU
constexpr size_t max_max_wg = 2048;
size_t max_wg = std::min(
max_max_wg,
dev.get_info<sycl::info::device::max_work_group_size>() / 2);
size_t max_wg = reduction_detail::get_work_group_size(dev);

if (reduction_nelems <= preferred_reductions_per_wi * max_wg) {
resTy *tmp = sycl::malloc_device<resTy>(
Expand Down
Loading