Skip to content

Commit 4201cef

Browse files
authored
set AliasAnalysisKind of embedding_bag and interaction to PURE_FUNCTION (#163)
1 parent 8d52321 commit 4201cef

File tree

2 files changed

+17
-9
lines changed

2 files changed

+17
-9
lines changed

torch_ipex/csrc/cpu/embeddingbag.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -494,9 +494,13 @@ at::Tensor AtenIpexJITDev::dil_qembeddingbag(
494494
} // namespace torch_ipex
495495

496496
namespace {
497-
static auto dispatch =
498-
torch::RegisterOperators()
499-
.op("torch_ipex::embedding_bag", &torch_ipex::AtenIpexTypeExt::embedding_bag);
497+
TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
498+
m.def(torch::schema(
499+
"torch_ipex::embedding_bag(Tensor weight, Tensor indices, Tensor "
500+
"offsets, bool sparse, bool include_last_offset) -> Tensor",
501+
c10::AliasAnalysisKind::PURE_FUNCTION),
502+
torch_ipex::AtenIpexTypeExt::embedding_bag);
503+
}
500504
}
501505

502506
namespace torch_ipex {

torch_ipex/csrc/cpu/interaction.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -422,12 +422,16 @@ at::Tensor AtenIpexJITDev::dil_qinteraction(const std::vector<at::Tensor> input,
422422
} // namespace torch_ipex
423423

424424
namespace {
425-
static auto dispatch =
426-
torch::RegisterOperators()
427-
.op("torch_ipex::interaction_forward",
428-
&torch_ipex::AtenIpexTypeExt::interaction_forward)
429-
.op("torch_ipex::interaction_backward",
430-
&torch_ipex::AtenIpexTypeExt::interaction_backward);
425+
TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
426+
m.def(
427+
torch::schema("torch_ipex::interaction_forward(Tensor[] input) -> Tensor",
428+
c10::AliasAnalysisKind::PURE_FUNCTION),
429+
torch_ipex::AtenIpexTypeExt::interaction_forward);
430+
m.def(torch::schema("torch_ipex::interaction_backward(Tensor grad_out, "
431+
"Tensor[] input) -> Tensor[]",
432+
c10::AliasAnalysisKind::PURE_FUNCTION),
433+
torch_ipex::AtenIpexTypeExt::interaction_backward);
434+
}
431435
}
432436

433437
namespace torch_ipex {

0 commit comments

Comments
 (0)