@@ -90,16 +90,19 @@ def forward(self,
9090 chklowtinv = torch .linalg .inv (chklowt )
9191 data [self .h_out_field ] = (chklowtinv @ data [self .h_out_field ] @ torch .transpose (chklowtinv ,dim0 = 1 ,dim1 = 2 ).conj ())
9292 elif eig_solver == 'numpy' :
93- chklowt = np .linalg .cholesky (data [self .s_out_field ].detach ().numpy ())
93+ chklowt = np .linalg .cholesky (data [self .s_out_field ].detach ().cpu (). numpy ())
9494 chklowtinv = np .linalg .inv (chklowt )
95- data [self .h_out_field ] = (chklowtinv @ data [self .h_out_field ].detach ().numpy () @ np .transpose (chklowtinv ,(0 ,2 ,1 )).conj ())
96- else :
97- data [self .h_out_field ] = data [self .h_out_field ]
98-
95+ data [self .h_out_field ] = (chklowtinv @ data [self .h_out_field ].detach ().cpu ().numpy () @ np .transpose (chklowtinv ,(0 ,2 ,1 )).conj ())
96+ elif eig_solver == 'numpy' :
97+ # Convert to numpy when using numpy solver without overlap
98+ data [self .h_out_field ] = data [self .h_out_field ].detach ().cpu ().numpy ()
99+
99100 if eig_solver == 'torch' :
100101 eigvals .append (torch .linalg .eigvalsh (data [self .h_out_field ]))
101102 elif eig_solver == 'numpy' :
102- eigvals .append (torch .from_numpy (np .linalg .eigvalsh (a = data [self .h_out_field ])))
103+ eigvals_np = np .linalg .eigvalsh (a = data [self .h_out_field ])
104+ # Preserve dtype by converting to the Hamiltonian's original dtype
105+ eigvals .append (torch .from_numpy (eigvals_np ).to (dtype = self .h2k .dtype ))
103106
104107 data [self .out_field ] = torch .nested .as_nested_tensor ([torch .cat (eigvals , dim = 0 )])
105108 if nested :
0 commit comments