 
 dtypestr: str
 Tensor = Any
+pytree = Any
 
 torchlib: Any
 
@@ -34,6 +35,28 @@
 # To be added once pytorch backend is ready
 
 
+class torch_optimizer:
+    def __init__(self, optimizer: Any) -> None:
+        self.optimizer = optimizer
+        self.is_init = False
+
+    def update(self, grads: pytree, params: pytree) -> pytree:
+        # flatten grads and params into lists of leaves
+        params, treedef = PyTorchBackend.tree_flatten(None, params)
+        grads, _ = PyTorchBackend.tree_flatten(None, grads)
+        if self.is_init is False:
+            self.optimizer = self.optimizer(params)
+            self.is_init = True
+        with torchlib.no_grad():
+            for g, p in zip(grads, params):
+                p.grad = g
+            self.optimizer.step()
+            self.optimizer.zero_grad()
+        # restore the original pytree structure of params
+        params = PyTorchBackend.tree_unflatten(None, treedef, params)
+        return params
+
+
 def _conj_torch(self: Any, tensor: Tensor) -> Tensor:
     t = torchlib.conj(tensor)
     return t.resolve_conj()  # any side effect?
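
Note: a minimal usage sketch of the wrapper above (not part of this commit), assuming it runs inside this module so that torch_optimizer, PyTorchBackend, and torchlib (the imported torch) are in scope; the parameter names are illustrative. The constructor takes an optimizer factory, and the first call to update() instantiates it on the flattened parameter leaves.

import functools
import torch

# factory that maps a list of leaf tensors to a torch optimizer instance
opt = torch_optimizer(functools.partial(torch.optim.SGD, lr=0.1))
params = {"w": torch.ones(2, requires_grad=True), "b": torch.zeros(2, requires_grad=True)}
grads = {"w": torch.full((2,), 0.5), "b": torch.full((2,), 0.1)}
# one SGD step; the returned params keep the dict (pytree) structure
params = opt.update(grads, params)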
@@ -355,6 +378,16 @@ def cast(self, a: Tensor, dtype: str) -> Tensor:
     def solve(self, A: Tensor, b: Tensor, **kws: Any) -> Tensor:
         return torchlib.linalg.solve(A, b)
 
+    def tree_map(self, f: Callable[..., Any], *pytrees: Any) -> Any:
+        # TODO(@refraction-ray): torch does not support multiple pytree args
+        return torchlib.utils._pytree.tree_map(f, *pytrees)
+
+    def tree_flatten(self: Any, pytree: Any) -> Tuple[Any, Any]:
+        return torchlib.utils._pytree.tree_flatten(pytree)  # type: ignore
+
+    def tree_unflatten(self: Any, treedef: Any, leaves: Any) -> Any:
+        return torchlib.utils._pytree.tree_unflatten(leaves, treedef)
+
     def cond(
         self,
         pred: bool,
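
Note: the three methods above are thin delegations to torch.utils._pytree, a private torch utility module; a rough standalone sketch of what they do (illustrative names, not part of this commit):

import torch
from torch.utils import _pytree as pytree_utils

nested = {"a": torch.zeros(2), "b": [torch.ones(3), torch.ones(1)]}
leaves, treedef = pytree_utils.tree_flatten(nested)      # 3 leaf tensors + structure
doubled = [2 * leaf for leaf in leaves]
rebuilt = pytree_utils.tree_unflatten(doubled, treedef)  # same nesting as `nested`
mapped = pytree_utils.tree_map(lambda t: 2 * t, nested)  # one-shot equivalent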
@@ -413,6 +446,13 @@ def value_and_grad(
         argnums: Union[int, Sequence[int]] = 0,
         has_aux: bool = False,
     ) -> Callable[..., Tuple[Any, Any]]:
+        def ask_require(t: Tensor) -> Any:
+            t.requires_grad_(True)
+            return t
+
+        def get_grad(t: Tensor) -> Tensor:
+            return t.grad
+
         def wrapper(*args: Any, **kws: Any) -> Any:
             x = []
             if isinstance(argnums, int):
@@ -423,15 +463,15 @@ def wrapper(*args: Any, **kws: Any) -> Any:
                 argnumsl = argnums  # type: ignore
             for i, arg in enumerate(args):
                 if i in argnumsl:
-                    x.append(arg.requires_grad_(True))
+                    x.append(self.tree_map(ask_require, arg))
                 else:
                     x.append(arg)
             y = f(*x, **kws)
             if has_aux:
                 y[0].backward()
             else:
                 y.backward()
-            gs = [x[i].grad for i in argnumsl]
+            gs = [self.tree_map(get_grad, x[i]) for i in argnumsl]
             if len(gs) == 1:
                 gs = gs[0]
             return y, gs
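
Note: because ask_require and get_grad are applied through tree_map, value_and_grad now accepts pytree-structured arguments (e.g. dicts of tensors) and returns gradients with the same structure. A rough usage sketch (not part of this commit), assuming the pytorch backend is active and using the tensorcircuit-level names set_backend and backend.value_and_grad; the loss function is illustrative:

import torch
import tensorcircuit as tc

tc.set_backend("pytorch")

def loss(params):
    return (params["w"] ** 2).sum() + (params["b"] ** 2).sum()

vg = tc.backend.value_and_grad(loss)
value, grads = vg({"w": torch.ones(2), "b": torch.ones(3)})
# grads is a dict with keys "w" and "b", each leaf holding d(loss)/d(leaf)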
@@ -532,3 +572,5 @@ def vectorized_value_and_grad(
         return f
 
     vvag = vectorized_value_and_grad
+
+    optimizer = torch_optimizer
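
Note: with the optimizer attribute exposed on the backend, a hand-rolled training loop could look roughly like the sketch below; accessing it as tc.backend.optimizer and the toy loss are assumptions for illustration, not part of this commit.

import functools
import torch
import tensorcircuit as tc

tc.set_backend("pytorch")

def loss(params):
    return ((params["w"] - 1.0) ** 2).sum()

vg = tc.backend.value_and_grad(loss)
opt = tc.backend.optimizer(functools.partial(torch.optim.Adam, lr=1e-2))

params = {"w": torch.zeros(3)}
for _ in range(200):
    value, grads = vg(params)       # value and dict-structured gradients
    params = opt.update(grads, params)  # one optimizer step, structure preserved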