|
9 | 9 | #include "torch/csrc/jit/ir/ir.h" |
10 | 10 | #include "torch/torch.h" |
11 | 11 |
|
| 12 | +#include "core/conversion/converters/converter_util.h" |
12 | 13 | #include "core/conversion/evaluators/eval_macros.h" |
13 | 14 | #include "core/conversion/evaluators/eval_util.h" |
14 | 15 | #include "core/conversion/evaluators/evaluators.h" |
@@ -298,20 +299,22 @@ auto aten_registrations TORCHTRT_UNUSED = |
298 | 299 | } else { |
299 | 300 | auto dim = args.at(n->input(1)).unwrapToInt(); |
300 | 301 | if (tensor_var.isITensor()) { |
301 | | - if (ctx->input_is_dynamic) { |
| 302 | + auto tensor = tensor_var.ITensor(); |
| 303 | + auto dims = util::toVec(tensor->getDimensions()); |
| 304 | + auto nbDims = tensor->getDimensions().nbDims; |
| 305 | + if (dim < 0) { |
| 306 | + dim += nbDims; |
| 307 | + } |
| 308 | + // Check if selected dimension size is -1 else return static size |
| 309 | + if (ctx->input_is_dynamic && dims[dim] == -1) { |
302 | 310 | if (ctx->settings.allow_shape_tensors) { |
303 | 311 | return dynamic_size_layer(ctx, n, args); |
304 | 312 | } else { |
305 | 313 | LOG_WARNING( |
306 | 314 | "There may be undefined behavior using dynamic shape and aten::size without setting allow_shape_tensors"); |
307 | 315 | } |
308 | 316 | } |
309 | | - auto tensor = tensor_var.ITensor(); |
310 | | - auto dims = util::toVec(tensor->getDimensions()); |
311 | | - auto nbDims = tensor->getDimensions().nbDims; |
312 | | - if (dim < 0) { |
313 | | - dim += nbDims; |
314 | | - } |
| 317 | + |
315 | 318 | return dims[dim]; |
316 | 319 | } else if (tensor_var.IValue()->isTensor()) { |
317 | 320 | auto tensor = tensor_var.unwrapToTensor(); |
@@ -677,6 +680,25 @@ auto aten_registrations TORCHTRT_UNUSED = |
677 | 680 | .evaluator( |
678 | 681 | {c10::Symbol::fromQualString("aten::floordiv"), |
679 | 682 | [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { |
| 683 | + // Dynamic version of aten::floordiv |
| 684 | + if (args.at(n->input(0)).isITensor()) { |
| 685 | + if (args.at(n->input(1)).IValue()->isInt()) { |
| 686 | + auto int_tensor = scalar_to_tensor(args.at(n->input(1)).IValue()->toInt()); |
| 687 | + auto int_itensor = converters::tensor_to_const(ctx, int_tensor, util::node_info(n) + "_constant"); |
| 688 | + auto elementwise_layer = converters::add_elementwise( |
| 689 | + ctx, |
| 690 | + nvinfer1::ElementWiseOperation::kFLOOR_DIV, |
| 691 | + args.at(n->input(0)).ITensor(), |
| 692 | + int_itensor, |
| 693 | + util::node_info(n)); |
| 694 | + auto output_tensor = elementwise_layer->getOutput(0); |
| 695 | + auto tensor_holder = TensorContainer(); |
| 696 | + tensor_holder.hold_tensor(output_tensor); |
| 697 | + auto output_ivalue = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder))); |
| 698 | + return output_ivalue; |
| 699 | + } |
| 700 | + } |
| 701 | + // Static version |
680 | 702 | if (args.at(n->input(0)).IValue()->isInt()) { |
681 | 703 | auto a = args.at(n->input(0)).unwrapToInt(); |
682 | 704 | auto b = args.at(n->input(1)).unwrapToInt(); |
|
0 commit comments