@@ -218,10 +218,10 @@ const std::unordered_set<c10::Symbol> AtenIntReplacementNodeKinds = {
     torch::jit::aten::floor_divide,
 };
 
-torch::jit::Value* Validate0DTensor(torch::jit::Value* value) {
+c10::optional<torch::jit::Value*> Validate0DTensor(torch::jit::Value* value) {
   // Validates that the input Value* is a 0D Tensor (or int/float)
   // Return the stored int/float Value* if so, otherwise null
-  torch::jit::Value* enclosed_scalar_value = nullptr;
+  c10::optional<torch::jit::Value*> enclosed_scalar_value = {};
 
   // Regular Int/Float case
   if (value->type()->isSubtypeOf(c10::IntType::get()) || value->type()->isSubtypeOf(c10::FloatType::get())) {
@@ -234,7 +234,7 @@ torch::jit::Value* Validate0DTensor(torch::jit::Value* value) {
     // Retrieve the Tensor stored in constant
     at::Tensor t = *torch::jit::constant_as<at::Tensor>(value);
     // Validate the shape of the Tensor is 0D (single-element) and integral
-    if (t.sizes() == std::vector<int64_t>({}) && t.item().isIntegral()) {
+    if (t.sizes() == std::vector<int64_t>({}) && t.item().isIntegral(false)) {
       // Extract the stored value, add it to the graph as a constant
       torch::jit::WithInsertPoint guard(value->node());
       auto new_const_val = value->owningGraph()->insertConstant(t.item(), c10::nullopt, value->node()->scope());
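For context on the isIntegral(false) change in the hunk above: a 0D tensor reports an empty sizes() list, item() extracts its payload as a c10::Scalar, and passing false for the includeBool parameter restricts the check to genuine integer scalars. The standalone sketch below is not part of the PR and only assumes a standard libtorch build; the tensor values are illustrative.

// Sketch only (not from the PR): demonstrates the 0D-tensor and integral-scalar
// checks used by Validate0DTensor above.
#include <torch/torch.h>
#include <iostream>
#include <vector>

int main() {
  at::Tensor zero_dim = torch::scalar_tensor(7, torch::kLong); // 0D integral tensor
  at::Tensor one_dim = torch::ones({1}, torch::kLong);         // 1D tensor of length 1

  // Only the 0D tensor compares equal to an empty shape vector.
  std::cout << (zero_dim.sizes() == std::vector<int64_t>({})) << "\n"; // prints 1
  std::cout << (one_dim.sizes() == std::vector<int64_t>({})) << "\n";  // prints 0

  // item() yields a c10::Scalar; isIntegral(false) is true for integer scalars
  // and excludes bools.
  std::cout << zero_dim.item().isIntegral(/*includeBool=*/false) << "\n"; // prints 1
  return 0;
}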
@@ -257,7 +257,7 @@ torch::jit::Value* Validate0DTensor(torch::jit::Value* value) {
   return enclosed_scalar_value;
 }
 
-torch::jit::Value* TracebackAndEliminate0DTensors(torch::jit::Node* node) {
+c10::optional<torch::jit::Value*> TracebackAndEliminate0DTensors(torch::jit::Node* node) {
   // Trace back through a node and all parents to eliminate 0D Tensors
   // and update schemas to their scalar alternatives, returning final
   // Value* to user
@@ -268,30 +268,30 @@ torch::jit::Value* TracebackAndEliminate0DTensors(torch::jit::Node* node) {
     LOG_DEBUG(
         "Encountered node " << node->kind().toQualString()
                             << " which is unsupported in the aten::Int.Tensor replacement lowering pass.");
-    return nullptr;
+    return {};
   }
 
   // Validate the first and second function inputs are 0D tensors or scalars
-  torch::jit::Value* first_input_scalar_value = Validate0DTensor(node->inputs()[0]);
-  torch::jit::Value* second_input_scalar_value = Validate0DTensor(node->inputs()[1]);
+  c10::optional<torch::jit::Value*> first_input_scalar_value = Validate0DTensor(node->inputs()[0]);
+  c10::optional<torch::jit::Value*> second_input_scalar_value = Validate0DTensor(node->inputs()[1]);
 
   // If the first input is not a scalar, recursively traceback on parent nodes
-  if (!first_input_scalar_value) {
+  if (!first_input_scalar_value.has_value()) {
     LOG_DEBUG("In aten::Int.Tensor lowering, now tracing " << node->inputs()[0]->node()->kind().toQualString());
     first_input_scalar_value = TracebackAndEliminate0DTensors(node->inputs()[0]->node());
   }
 
   // If the second input is not a scalar, recursively traceback on parent nodes
-  if (!second_input_scalar_value) {
+  if (!second_input_scalar_value.has_value()) {
     LOG_DEBUG("In aten::Int.Tensor lowering, now tracing " << node->inputs()[1]->node()->kind().toQualString());
     second_input_scalar_value = TracebackAndEliminate0DTensors(node->inputs()[1]->node());
   }
 
-  if (!first_input_scalar_value || !second_input_scalar_value) {
+  if (!first_input_scalar_value.has_value() || !second_input_scalar_value.has_value()) {
     LOG_DEBUG(
         "In aten::Int.Tensor lowering, recursive trace through node input "
         << "parents failed to return a Scalar value for at least one parent node.");
-    return nullptr;
+    return {};
   }
 
   // Set default insert point at node
@@ -303,15 +303,16 @@ torch::jit::Value* TracebackAndEliminate0DTensors(torch::jit::Node* node) {
     // must be inserted
     case torch::jit::aten::floor_divide:
       new_node = node->owningGraph()->create(
-          torch::jit::aten::floordiv, {first_input_scalar_value, second_input_scalar_value}, 1);
+          torch::jit::aten::floordiv, {first_input_scalar_value.value(), second_input_scalar_value.value()}, 1);
       new_node->insertAfter(node);
       new_node->output()->setType(c10::IntType::get());
       return new_node->output();
 
     // In the aten::mul case, the schema syntax is the same, so we can use the existing schema
    // with new inputs
     default:
-      new_node = node->owningGraph()->create(node->kind(), {first_input_scalar_value, second_input_scalar_value}, 1);
+      new_node = node->owningGraph()->create(
+          node->kind(), {first_input_scalar_value.value(), second_input_scalar_value.value()}, 1);
       new_node->insertAfter(node);
       new_node->output()->setType(c10::IntType::get());
       return new_node->output();
@@ -336,8 +337,8 @@ void ReplaceAtenInt(std::shared_ptr<torch::jit::Graph>& g) {
           "Tracing parent node " << it->input()->node()->kind().toQualString()
                                  << " to eliminate 0D Tensors for aten::Int.Tensor case.");
       auto scalar_input_value = TracebackAndEliminate0DTensors(it->input()->node());
-      if (scalar_input_value) {
-        it->output()->replaceAllUsesWith(scalar_input_value);
+      if (scalar_input_value.has_value()) {
+        it->output()->replaceAllUsesWith(scalar_input_value.value());
         LOG_DEBUG("Tracing parent nodes for aten::Int.Tensor case succeeded.");
       } else {
         LOG_DEBUG("Tracing parent nodes for aten::Int.Tensor case failed.");
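Taken together, the diff replaces a nullable torch::jit::Value* return with c10::optional<torch::jit::Value*>: an empty optional ({}) now signals that no scalar could be recovered, and callers unwrap with has_value()/value() instead of comparing against nullptr. Below is a minimal sketch of the same idiom using a hypothetical FindFirstEven helper that is not part of the codebase; c10::optional follows the std::optional interface.

// Sketch only: the optional-return idiom used throughout the pass, applied to a
// hypothetical helper unrelated to TorchScript lowering.
#include <c10/util/Optional.h>
#include <iostream>
#include <vector>

// Returns the first even element, or an empty optional when none exists.
c10::optional<int> FindFirstEven(const std::vector<int>& values) {
  for (int v : values) {
    if (v % 2 == 0) {
      return v; // implicitly wrapped in the optional
    }
  }
  return {}; // "not found", analogous to `return {};` in the lowering pass
}

int main() {
  auto result = FindFirstEven({3, 5, 8, 9});
  if (result.has_value()) {
    std::cout << "found: " << result.value() << "\n"; // prints "found: 8"
  } else {
    std::cout << "no even value found\n";
  }
  return 0;
}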