
Commit 2294981: add pytree support (1 parent: 88cef1d)

7 files changed, +140 -26 lines changed

CHANGELOG.md

Lines changed: 12 additions & 0 deletions
@@ -6,6 +6,18 @@
 
 - Add more type auto conversion for `tc.gates.Gate` as inputs
 
+- Add `tree_flatten` and `tree_unflatten` method on backends
+
+- Add torch optimizer to the backend agnostic optimizer abstraction
+
+### Changed
+
+- Refactor the tree utils, add native torch support for pytree utils
+
+### Fixed
+
+- grad in torch backend now support pytrees
+
 ## 0.1.2
 
 ### Added
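
A minimal sketch of how the new pytree helpers on the backend object are meant to be used, mirroring the `test_treeutils` case added in this commit (shapes and values are illustrative; the generic fallback on backends without native pytree support requires TensorFlow to be installed):

import numpy as np
import tensorcircuit as tc

params = {"a": np.ones([2]), "b": [tc.backend.zeros([]), tc.backend.ones([1, 1])]}

leaves, treedef = tc.backend.tree_flatten(params)      # flat list of leaves plus structure info
restored = tc.backend.tree_unflatten(treedef, leaves)  # rebuild the original nested structure
doubled = tc.backend.tree_map(lambda x: 2 * x, restored)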

docs/source/textbook/chap4.ipynb

Lines changed: 12 additions & 1 deletion
@@ -743,7 +743,9 @@
     "\n",
     "\n",
     "\n",
-    "* > A further note on Grover search: from the derivation above we can see that Grover search is applied to finding the solution of a problem. In classical computing, for some problems the best one can do is to brute-force all $O(N)$ inputs and evaluate $f(x)$ to find a solution, whereas Grover search reduces the number of trials to $O(\\sqrt{N})$. If a single classical trial and a single quantum trial take comparable time, quantum computing is then much faster than classical computing. Note that this kind of search is not the same as a database search; it is an unordered, unstructured search. In principle Grover search can also be used for database search, but its advantage over classical algorithms then depends on how the data are stored."
+    "* > A further note on Grover search: from the derivation above we can see that Grover search is applied to finding the solution of a problem. In classical computing, for some problems the best one can do is to brute-force all $O(N)$ inputs and evaluate $f(x)$ to find a solution, whereas Grover search reduces the number of trials to $O(\\sqrt{N})$. If a single classical trial and a single quantum trial take comparable time, quantum computing is then much faster than classical computing. Note that this kind of search is not the same as a database search; it is an unordered, unstructured search. In principle Grover search can also be used for database search, but its advantage over classical algorithms then depends on how the data are stored.\n",
+    "\n",
+    "For more extensions and proofs of the basic Grover search algorithm, see 【1】."
   ]
  },
  {
@@ -838,6 +840,15 @@
   "source": [
    "c.sample()"
   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## References\n",
+    "\n",
+    "【1】https://arxiv.org/pdf/quant-ph/9605034.pdf"
+   ]
   }
  ],
 "metadata": {

tensorcircuit/backends/abstract_backend.py

Lines changed: 41 additions & 15 deletions
@@ -672,26 +672,52 @@ def tree_map(self: Any, f: Callable[..., Any], *pytrees: Any) -> Any:
         :rtype: Any
         """
         try:
-            import jax as libjax
+            import tensorflow as tf
 
-            has_jax = True
         except ImportError:
-            has_jax = False
-        try:
-            import tensorflow as tf
+            raise NotImplementedError("No installed ML backend for `tree_map`")
 
-            has_tf = True
-        except ImportError:
-            has_tf = False
+        return tf.nest.map_structure(f, *pytrees)
 
-        if has_jax:
-            r = libjax.tree_map(f, *pytrees)
-        elif has_tf:
-            r = tf.nest.map_structure(f, *pytrees)
-        else:
-            raise NotImplementedError("Only tensorflow and jax support `tree_map`")
+    def tree_flatten(self: Any, pytree: Any) -> Tuple[Any, Any]:
+        """
+        Flatten python structure to 1D list
 
-        return r
+        :param pytree: python structure to be flattened
+        :type pytree: Any
+        :return: The 1D list of flattened structure and treedef
+            which can be used for later unflatten
+        :rtype: Tuple[Any, Any]
+        """
+        try:
+            import tensorflow as tf
+
+        except ImportError:
+            raise NotImplementedError("No installed ML backend for `tree_flatten`")
+
+        leaves = tf.nest.flatten(pytree)
+        treedef = pytree
+
+        return leaves, treedef
+
+    def tree_unflatten(self: Any, treedef: Any, leaves: Any) -> Any:
+        """
+        Pack 1D list to pytree defined via ``treedef``
+
+        :param treedef: Def of pytree structure, the second return from ``tree_flatten``
+        :type treedef: Any
+        :param leaves: the 1D list of flattened data structure
+        :type leaves: Any
+        :return: Packed pytree
+        :rtype: Any
+        """
+        try:
+            import tensorflow as tf
+
+        except ImportError:
+            raise NotImplementedError("No installed ML backend for `tree_unflatten`")
+
+        return tf.nest.pack_sequence_as(treedef, leaves)
 
     def set_random_state(
         self: Any, seed: Optional[int] = None, get_only: bool = False
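
For readers unfamiliar with tf.nest, a rough sketch of what the generic fallback above relies on (illustrative values; this only restates the calls made by the new methods):

import tensorflow as tf

pytree = {"a": tf.zeros([2]), "b": [tf.ones([1]), tf.ones([3])]}

# tree_flatten: tf.nest has no separate treedef object, so the original
# structure itself is kept and later reused as the "treedef"
leaves = tf.nest.flatten(pytree)
treedef = pytree

# tree_unflatten: pack the flat leaf list back into the same nested structure
rebuilt = tf.nest.pack_sequence_as(treedef, leaves)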

tensorcircuit/backends/jax_backend.py

Lines changed: 9 additions & 0 deletions
@@ -378,6 +378,15 @@ def is_tensor(self, a: Any) -> bool:
     def solve(self, A: Tensor, b: Tensor, assume_a: str = "gen") -> Tensor:
         return jsp.linalg.solve(A, b, assume_a)
 
+    def tree_map(self, f: Callable[..., Any], *pytrees: Any) -> Any:
+        return libjax.tree_map(f, *pytrees)
+
+    def tree_flatten(self: Any, pytree: Any) -> Tuple[Any, Any]:
+        return libjax.tree_flatten(pytree)  # type: ignore
+
+    def tree_unflatten(self: Any, treedef: Any, leaves: Any) -> Any:
+        return libjax.tree_unflatten(treedef, leaves)
+
     def set_random_state(
         self, seed: Optional[Union[int, PRNGKeyArray]] = None, get_only: bool = False
     ) -> Any:

tensorcircuit/backends/pytorch_backend.py

Lines changed: 44 additions & 2 deletions
@@ -23,6 +23,7 @@
 
 dtypestr: str
 Tensor = Any
+pytree = Any
 
 torchlib: Any
 
@@ -34,6 +35,28 @@
 # To be added once pytorch backend is ready
 
 
+class torch_optimizer:
+    def __init__(self, optimizer: Any) -> None:
+        self.optimizer = optimizer
+        self.is_init = False
+
+    def update(self, grads: pytree, params: pytree) -> pytree:
+        # flatten grad and param
+        params, treedef = PyTorchBackend.tree_flatten(None, params)
+        grads, _ = PyTorchBackend.tree_flatten(None, grads)
+        if self.is_init is False:
+            self.optimizer = self.optimizer(params)
+            self.is_init = True
+        with torchlib.no_grad():
+            for g, p in zip(grads, params):
+                p.grad = g
+        self.optimizer.step()
+        self.optimizer.zero_grad()
+        # reorg the param
+        params = PyTorchBackend.tree_unflatten(None, treedef, params)
+        return params
+
+
 def _conj_torch(self: Any, tensor: Tensor) -> Tensor:
     t = torchlib.conj(tensor)
     return t.resolve_conj()  # any side effect?
@@ -355,6 +378,16 @@ def cast(self, a: Tensor, dtype: str) -> Tensor:
     def solve(self, A: Tensor, b: Tensor, **kws: Any) -> Tensor:
         return torchlib.linalg.solve(A, b)
 
+    def tree_map(self, f: Callable[..., Any], *pytrees: Any) -> Any:
+        # TODO(@refraction-ray): torch not support multiple pytree args
+        return torchlib.utils._pytree.tree_map(f, *pytrees)
+
+    def tree_flatten(self: Any, pytree: Any) -> Tuple[Any, Any]:
+        return torchlib.utils._pytree.tree_flatten(pytree)  # type: ignore
+
+    def tree_unflatten(self: Any, treedef: Any, leaves: Any) -> Any:
+        return torchlib.utils._pytree.tree_unflatten(leaves, treedef)
+
     def cond(
         self,
         pred: bool,
@@ -413,6 +446,13 @@ def value_and_grad(
         argnums: Union[int, Sequence[int]] = 0,
         has_aux: bool = False,
     ) -> Callable[..., Tuple[Any, Any]]:
+        def ask_require(t: Tensor) -> Any:
+            t.requires_grad_(True)
+            return t
+
+        def get_grad(t: Tensor) -> Tensor:
+            return t.grad
+
         def wrapper(*args: Any, **kws: Any) -> Any:
             x = []
             if isinstance(argnums, int):
@@ -423,15 +463,15 @@ def wrapper(*args: Any, **kws: Any) -> Any:
                 argnumsl = argnums  # type: ignore
             for i, arg in enumerate(args):
                 if i in argnumsl:
-                    x.append(arg.requires_grad_(True))
+                    x.append(self.tree_map(ask_require, arg))
                 else:
                     x.append(arg)
             y = f(*x, **kws)
             if has_aux:
                 y[0].backward()
             else:
                 y.backward()
-            gs = [x[i].grad for i in argnumsl]
+            gs = [self.tree_map(get_grad, x[i]) for i in argnumsl]
             if len(gs) == 1:
                 gs = gs[0]
             return y, gs
@@ -532,3 +572,5 @@ def vectorized_value_and_grad(
         return f
 
     vvag = vectorized_value_and_grad
+
+    optimizer = torch_optimizer
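
A rough end-to-end sketch of the new PyTorch pieces, following the pattern in the updated test_optimizers below (the loss function and hyperparameters are placeholders, not part of the commit):

from functools import partial
import torch
import tensorcircuit as tc

tc.set_backend("pytorch")

def loss_fn(params):
    # placeholder scalar loss over a pytree of parameters
    return (params["a"] ** 2).sum() + (params["b"] ** 2).sum()

# grads returned by value_and_grad now mirror the pytree structure of params
vg = tc.backend.value_and_grad(loss_fn, argnums=0)
# backend-agnostic optimizer wrapping a torch optimizer factory
opt = tc.backend.optimizer(partial(torch.optim.Adam, lr=5e-2))

params = {
    "a": tc.backend.ones([4], dtype="float32"),
    "b": tc.backend.ones([4], dtype="float32"),
}
for _ in range(10):
    loss, grads = vg(params)
    # update flattens grads/params, applies one optimizer step, and repacks
    params = opt.update(grads, params)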

tests/conftest.py

Lines changed: 0 additions & 3 deletions
@@ -40,14 +40,11 @@ def jaxb():
 def torchb():
     try:
         tc.set_backend("pytorch")
-        tc.set_dtype("float64")
         yield
         tc.set_backend("numpy")
-        tc.set_dtype("complex64")
     except ImportError as e:
         print(e)
         tc.set_backend("numpy")
-        tc.set_dtype("complex64")
         pytest.skip("****** No torch backend found, skipping test suit *******")
 
 

tests/test_backends.py

Lines changed: 22 additions & 5 deletions
@@ -779,14 +779,29 @@ def test_solve(backend):
     np.testing.assert_allclose(xp, x[:, 0], atol=1e-5)
 
 
-@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
+@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
+def test_treeutils(backend):
+    d0 = {"a": np.ones([2]), "b": [tc.backend.zeros([]), tc.backend.ones([1, 1])]}
+    leaves, treedef = tc.backend.tree_flatten(d0)
+    d1 = tc.backend.tree_unflatten(treedef, leaves)
+    d2 = tc.backend.tree_map(lambda x: 2 * x, d1)
+    np.testing.assert_allclose(2 * np.ones([1, 1]), d2["b"][1])
+
+
+@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
 def test_optimizers(backend):
     if tc.backend.name == "jax":
         try:
             import optax
         except ImportError:
             pytest.skip("optax is not installed")
 
+    if tc.backend.name == "pytorch":
+        try:
+            import torch
+        except ImportError:
+            pytest.skip("torch is not installed")
+
     def f(params, n):
         c = tc.Circuit(n)
         c = tc.templates.blocks.example_block(c, params["a"])
@@ -802,6 +817,9 @@ def get_opt():
         elif tc.backend.name == "jax":
             optimizer2 = optax.adam(5e-2)
             opt = tc.backend.optimizer(optimizer2)
+        elif tc.backend.name == "pytorch":
+            optimizer3 = partial(torch.optim.Adam, lr=5e-2)
+            opt = tc.backend.optimizer(optimizer3)
         else:
             raise ValueError("%s doesn't support optimizer interface" % tc.backend.name)
         return opt
@@ -810,8 +828,8 @@ def get_opt():
     opt = get_opt()
 
     params = {
-        "a": tc.backend.implicit_randn([4, n]),
-        "b": tc.backend.implicit_randn([4, n]),
+        "a": tc.backend.ones([4, n], dtype="float32"),
+        "b": tc.backend.ones([4, n], dtype="float32"),
     }
 
     for _ in range(20):
@@ -828,12 +846,11 @@ def f2(params, n):
 
     vgs2 = tc.backend.jit(tc.backend.value_and_grad(f2, argnums=0), static_argnums=1)
 
-    params = tc.backend.implicit_randn([4, n])
+    params = tc.backend.ones([4, n], dtype="float32")
     opt = get_opt()
 
     for _ in range(20):
         loss, grads = vgs2(params, n)
-        print(grads, params)
         params = opt.update(grads, params)
     print(loss)
