@@ -41,44 +41,45 @@ torch.jit.save(m,'PyTorchModel.pt')\n";
4141
4242void TMVA_SOFIE_PyTorch (){
4343
44- // Running the Python script to generate PyTorch .pt file
44+ // Running the Python script to generate PyTorch .pt file
4545
46- TMacro m ;
47- m .AddLine (pythonSrc );
48- m .SaveSource ("make_pytorch_model.py" );
49- gSystem -> Exec ("python3 make_pytorch_model.py" );
46+ TMacro m ;
47+ m .AddLine (pythonSrc );
48+ m .SaveSource ("make_pytorch_model.py" );
49+ gSystem -> Exec ("python3 make_pytorch_model.py" );
5050
51- // Parsing a PyTorch model requires the shape and data-type of input tensor
52- // Data-type of input tensor defaults to Float if not specified
53- std ::vector < size_t > inputTensorShapeSequential {2 ,32 };
54- std ::vector < std ::vector < size_t >> inputShapesSequential {inputTensorShapeSequential };
51+ // Parsing a PyTorch model requires the shape and data-type of input tensor
52+ // Data-type of input tensor defaults to Float if not specified
53+ std ::vector < size_t > inputTensorShapeSequential {2 , 32 };
54+ std ::vector < std ::vector < size_t >> inputShapesSequential {inputTensorShapeSequential };
5555
56- // Parsing the saved PyTorch .pt file into RModel object
57- SOFIE ::RModel model = SOFIE ::PyTorch ::Parse ("PyTorchModel.pt" ,inputShapesSequential );
56+ // Parsing the saved PyTorch .pt file into RModel object
57+ SOFIE ::RModel model = SOFIE ::PyTorch ::Parse ("PyTorchModel.pt" , inputShapesSequential );
5858
59- // Generating inference code
60- model .Generate ();
61- model .OutputGenerated ("PyTorchModel.hxx" );
59+ // Generating inference code
60+ model .Generate ();
61+ model .OutputGenerated ("PyTorchModel.hxx" );
6262
63- // Printing required input tensors
64- std ::cout << "\n\n" ;
65- model .PrintRequiredInputTensors ();
63+ // Printing required input tensors
64+ std ::cout << "\n\n" ;
65+ model .PrintRequiredInputTensors ();
6666
67- // Printing initialized tensors (weights)
68- std ::cout << "\n\n" ;
69- model .PrintInitializedTensors ();
67+ // Printing initialized tensors (weights)
68+ std ::cout << "\n\n" ;
69+ model .PrintInitializedTensors ();
7070
71- // Printing intermediate tensors
72- std ::cout << "\n\n" ;
73- model .PrintIntermediateTensors ();
71+ // Printing intermediate tensors
72+ std ::cout << "\n\n" ;
73+ model .PrintIntermediateTensors ();
7474
75- //Checking if tensor already exist in model
76- std ::cout <<"\n\nTensor \"0weight\" already exist: " <<std ::boolalpha <<model .CheckIfTensorAlreadyExist ("0weight" )<<"\n\n" ;
77- std ::vector < size_t > tensorShape = model .GetTensorShape ("0weight" );
78- std ::cout <<"Shape of tensor \"0weight\": " ;
79- for (auto& it :tensorShape ){
80- std ::cout <<it <<",";
81- }
75+ // Checking if tensor already exist in model
76+ std ::cout << "\n\nTensor \"0weight\" already exist: " << std ::boolalpha << model .CheckIfTensorAlreadyExist ("0weight" )
77+ << "\n\n" ;
78+ std ::vector < size_t > tensorShape = model .GetTensorShape ("0weight" );
79+ std ::cout << "Shape of tensor \"0weight\": " ;
80+ for (auto & it : tensorShape ) {
81+ std ::cout << it << ",";
82+ }
8283 std ::cout <<"\n\nData type of tensor \"0weight\": " ;
8384 SOFIE ::ETensorType tensorType = model .GetTensorType ("0weight" );
8485 std ::cout <<SOFIE ::ConvertTypeToString (tensorType );
0 commit comments