Skip to content

Commit b72d328

Browse files
committed
guard against wandb not installed in examples/wrenformer.py
remove global numpy random seed in aviary/data.py
1 parent bd08f56 commit b72d328

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

aviary/data.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
import numpy as np
77
from torch import Tensor
88

9-
np.random.seed(0)
10-
119

1210
@dataclass
1311
class InMemoryDataLoader:
@@ -36,24 +34,25 @@ def __post_init__(self):
3634

3735
def __iter__(self) -> Iterator[tuple[Tensor, ...]]:
3836
self.indices = np.random.permutation(self.dataset_len) if self.shuffle else None
39-
self.idx = 0
37+
self.current_idx = 0
4038
return self
4139

4240
def __next__(self) -> tuple[Tensor, ...]:
43-
if self.idx >= self.dataset_len:
41+
start_idx = self.current_idx
42+
if start_idx >= self.dataset_len:
4443
raise StopIteration
4544

46-
end_idx = self.idx + self.batch_size
45+
end_idx = start_idx + self.batch_size
4746

4847
if self.indices is None: # shuffle=False
49-
slices = (t[self.idx : end_idx] for t in self.tensors)
48+
slices = (t[start_idx:end_idx] for t in self.tensors)
5049
else:
51-
idx = self.indices[self.idx : end_idx]
50+
idx = self.indices[start_idx:end_idx]
5251
slices = (t[idx] for t in self.tensors)
5352

5453
batch = self.collate_fn(*slices)
5554

56-
self.idx += self.batch_size
55+
self.current_idx += self.batch_size
5756
return batch
5857

5958
def __len__(self) -> int:

examples/wrenformer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import numpy as np
77
import pandas as pd
88
import torch
9-
import wandb
109
from torch import nn
1110
from torch.optim.swa_utils import SWALR, AveragedModel
1211

@@ -23,6 +22,11 @@
2322
from aviary.wrenformer.model import Wrenformer
2423
from aviary.wrenformer.utils import print_walltime
2524

25+
try:
26+
import wandb
27+
except ImportError:
28+
pass
29+
2630
__author__ = "Janosh Riebesell"
2731
__date__ = "2022-06-12"
2832

0 commit comments

Comments
 (0)