55#include " torch/csrc/autograd/anomaly_mode.h"
66#include " torch/csrc/autograd/profiler.h"
77#include " torch/csrc/autograd/saved_variable.h"
8- #include " torch/csrc/autograd/type_and_shape .h"
8+ #include " torch/csrc/autograd/input_metadata .h"
99#include " torch/csrc/autograd/variable.h"
1010#include " torch/csrc/utils/python_stub.h"
1111#include " torch/csrc/utils/variadic.h"
@@ -128,9 +128,18 @@ struct TORCH_API Function : std::enable_shared_from_this<Function> {
128128
129129 // / Adds the type and shape metadata for a new input. Returns the index of
130130 // / of the new input.
131- uint32_t add_input_metadata (const at::Type& type, at::IntList shape) noexcept {
131+ uint32_t add_input_metadata (
132+ const at::Type& type
133+ , at::IntList shape
134+ , const int64_t device) noexcept {
132135 uint32_t input_nr = input_metadata_.size ();
133- input_metadata_.emplace_back (type, shape);
136+ input_metadata_.emplace_back (type, shape, device);
137+ return input_nr;
138+ }
139+
140+ uint32_t add_input_metadata (const at::Tensor& t) noexcept {
141+ uint32_t input_nr = input_metadata_.size ();
142+ input_metadata_.emplace_back (t);
134143 return input_nr;
135144 }
136145
@@ -145,7 +154,7 @@ struct TORCH_API Function : std::enable_shared_from_this<Function> {
145154 return input_metadata_.size ();
146155 }
147156
148- const TypeAndShape & input_metadata (size_t index) const {
157+ const InputMetadata & input_metadata (size_t index) const {
149158 return input_metadata_[index];
150159 }
151160
@@ -322,7 +331,7 @@ struct TORCH_API Function : std::enable_shared_from_this<Function> {
322331 std::unique_ptr<AnomalyMetadata> anomaly_metadata_ = nullptr ;
323332 std::vector<std::unique_ptr<FunctionPreHook>> pre_hooks_;
324333 std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_;
325- at::SmallVector<TypeAndShape , 2 > input_metadata_;
334+ at::SmallVector<InputMetadata , 2 > input_metadata_;
326335};
327336
328337// / See Function::is_traceable() for definition.
@@ -367,7 +376,7 @@ inline void create_gradient_edge(
367376 Variable& variable,
368377 std::shared_ptr<Function> function) {
369378 // Copy before move.
370- const auto input_nr = function->add_input_metadata (variable. type (), variable. sizes () );
379+ const auto input_nr = function->add_input_metadata (variable);
371380 variable.set_gradient_edge ({std::move (function), input_nr});
372381}
373382
0 commit comments