@@ -58,19 +58,22 @@ void dumpTensorCout(const Tensor& tensor) {
5858 std::cout << std::endl;
5959}
6060
61- c10::intrusive_ptr<TensorWrapper> makeTensorWrapperPtr (const Tensor& tensor, int64_t level, bool should_be_alive ) {
61+ c10::intrusive_ptr<TensorWrapper> makeTensorWrapperPtr (const Tensor& tensor, int64_t level, const std::shared_ptr< bool >& life_handle ) {
6262 auto keys_to_propagate = kKeysToPropagateToWrapper | DispatchKeySet ({
6363 DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA});
6464 auto key_set = getKeysToPropagateToWrapper (tensor, keys_to_propagate);
6565 key_set = key_set.add (DispatchKey::FuncTorchGradWrapper);
66- if (should_be_alive) {
67- return c10::make_intrusive<TensorWrapper>(key_set, tensor, level, getLifeHandleForLevel (level));
68- } else {
69- return c10::make_intrusive<TensorWrapper>(key_set, tensor, level, std::make_shared<bool >(false ));
70- }
66+ return c10::make_intrusive<TensorWrapper>(key_set, tensor, level, life_handle);
7167}
7268
73- Tensor makeTensorWrapper (const Tensor& tensor, int64_t level, bool is_immutable) {
69+ // use makeTensorWrapper instead to avoid potential footguns:
70+ // unsafeMakeTensorWrapper doesn't check that level and life_handle
71+ // refer to the same interpreter
72+ static Tensor unsafeMakeTensorWrapper (
73+ const Tensor& tensor,
74+ int64_t level,
75+ bool is_immutable,
76+ const std::shared_ptr<bool >& life_handle) {
7477 auto wrapped = maybeGetTensorWrapper (tensor);
7578 if (wrapped) {
7679 TORCH_INTERNAL_ASSERT (wrapped->level () < level);
@@ -80,20 +83,38 @@ Tensor makeTensorWrapper(const Tensor& tensor, int64_t level, bool is_immutable)
8083 DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA});
8184 auto key_set = getKeysToPropagateToWrapper (tensor, keys_to_propagate);
8285 key_set = key_set.add (DispatchKey::FuncTorchGradWrapper);
83- auto life_handle = getLifeHandleForLevel (level);
84- auto result = at::detail::make_tensor<TensorWrapper>( key_set, tensor, level, std::move ( life_handle) , is_immutable);
86+ auto result = at::detail::make_tensor<TensorWrapper>(
87+ key_set, tensor, level, life_handle, is_immutable);
8588 TORCH_INTERNAL_ASSERT (result.key_set ().has (DispatchKey::FuncTorchGradWrapper));
8689 return result;
8790}
8891
92+ Tensor makeTensorWrapper (const Tensor& tensor, int64_t level, bool is_immutable) {
93+ auto life_handle = getLifeHandleForLevel (level);
94+ return unsafeMakeTensorWrapper (
95+ tensor,
96+ level,
97+ is_immutable,
98+ getLifeHandleForLevel (level));
99+ }
100+
101+ Tensor makeTensorWrapper (const Tensor& tensor, const Interpreter& interpreter, bool is_immutable) {
102+ return unsafeMakeTensorWrapper (
103+ tensor,
104+ interpreter.level (),
105+ is_immutable,
106+ interpreter.is_alive_ptr ());
107+ }
108+
109+
89110bool TensorWrapper::is_alive () const {
90111 return *is_alive_;
91112}
92113
93114c10::intrusive_ptr<TensorImpl> TensorWrapper::shallow_copy_and_detach (
94115 const c10::VariableVersion& version_counter,
95116 bool allow_tensor_metadata_change) const {
96- auto dest_impl = makeTensorWrapperPtr (value (), level_, is_alive () );
117+ auto dest_impl = makeTensorWrapperPtr (value (), level_, is_alive_ );
97118 dest_impl->set_version_counter (version_counter);
98119
99120 // TODO: is this even right?
@@ -104,7 +125,7 @@ c10::intrusive_ptr<TensorImpl> TensorWrapper::shallow_copy_and_detach(
104125c10::intrusive_ptr<TensorImpl> TensorWrapper::shallow_copy_and_detach (
105126 c10::VariableVersion&& version_counter,
106127 bool allow_tensor_metadata_change) const {
107- auto dest_impl = makeTensorWrapperPtr (value (), level_, is_alive () );
128+ auto dest_impl = makeTensorWrapperPtr (value (), level_, is_alive_ );
108129 dest_impl->set_version_counter (version_counter);
109130
110131 // TODO: is this even right?
0 commit comments