File tree Expand file tree Collapse file tree 2 files changed +12
-9
lines changed Expand file tree Collapse file tree 2 files changed +12
-9
lines changed Original file line number Diff line number Diff line change 66import numpy as np
77from torch import Tensor
88
9- np .random .seed (0 )
10-
119
1210@dataclass
1311class InMemoryDataLoader :
@@ -36,24 +34,25 @@ def __post_init__(self):
3634
3735 def __iter__ (self ) -> Iterator [tuple [Tensor , ...]]:
3836 self .indices = np .random .permutation (self .dataset_len ) if self .shuffle else None
39- self .idx = 0
37+ self .current_idx = 0
4038 return self
4139
4240 def __next__ (self ) -> tuple [Tensor , ...]:
43- if self .idx >= self .dataset_len :
41+ start_idx = self .current_idx
42+ if start_idx >= self .dataset_len :
4443 raise StopIteration
4544
46- end_idx = self . idx + self .batch_size
45+ end_idx = start_idx + self .batch_size
4746
4847 if self .indices is None : # shuffle=False
49- slices = (t [self . idx : end_idx ] for t in self .tensors )
48+ slices = (t [start_idx : end_idx ] for t in self .tensors )
5049 else :
51- idx = self .indices [self . idx : end_idx ]
50+ idx = self .indices [start_idx : end_idx ]
5251 slices = (t [idx ] for t in self .tensors )
5352
5453 batch = self .collate_fn (* slices )
5554
56- self .idx += self .batch_size
55+ self .current_idx += self .batch_size
5756 return batch
5857
5958 def __len__ (self ) -> int :
Original file line number Diff line number Diff line change 66import numpy as np
77import pandas as pd
88import torch
9- import wandb
109from torch import nn
1110from torch .optim .swa_utils import SWALR , AveragedModel
1211
2322from aviary .wrenformer .model import Wrenformer
2423from aviary .wrenformer .utils import print_walltime
2524
25+ try :
26+ import wandb
27+ except ImportError :
28+ pass
29+
2630__author__ = "Janosh Riebesell"
2731__date__ = "2022-06-12"
2832
You can’t perform that action at this time.
0 commit comments