Skip to content

Commit 03ac94a

Browse files
nabenabe0928 authored and ravinkohli committed
[fix] Address Ravin's comments and fix range issues in row cut
1 parent 2248047 commit 03ac94a

File tree

4 files changed

+21
-18
lines changed

4 files changed

+21
-18
lines changed

autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,10 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]:
6464
underlying model and returns the transformed array.
6565
6666
Args:
67-
X (Dict[str, Any])): 'X' dictionary
67+
X (Dict[str, Any])): fit dictionary
6868
6969
Returns:
70-
(Dict[str, Any]): the updated 'X' dictionary
70+
(Dict[str, Any]): the updated fit dictionary
7171
"""
7272
X.update({'train_data_loader': self.train_data_loader,
7373
'val_data_loader': self.val_data_loader,

autoPyTorch/pipeline/components/training/trainer/GridCutMixTrainer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
3838
lam = self.random_state.beta(alpha, beta)
3939
batch_size, _, W, H = X.shape
4040
device = torch.device('cuda' if X.is_cuda else 'cpu')
41-
batch_indices = torch.randperm(batch_size).to(device)
41+
permed_indices = torch.randperm(batch_size).to(device)
4242

4343
r = self.random_state.rand(1)
4444
if beta <= 0 or r > self.alpha:
45-
return X, {'y_a': y, 'y_b': y[batch_indices], 'lam': 1}
45+
return X, {'y_a': y, 'y_b': y[permed_indices], 'lam': 1}
4646

4747
# Draw parameters of a random bounding box
4848
# Where to cut basically
@@ -56,13 +56,13 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
5656
bbx2 = np.clip(cx + cut_w // 2, 0, W)
5757
bby2 = np.clip(cy + cut_h // 2, 0, H)
5858

59-
X[:, :, bbx1:bbx2, bby1:bby2] = X[batch_indices, :, bbx1:bbx2, bby1:bby2]
59+
X[:, :, bbx1:bbx2, bby1:bby2] = X[permed_indices, :, bbx1:bbx2, bby1:bby2]
6060

6161
# Adjust lam
6262
pixel_size = W * H
6363
lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / pixel_size)
6464

65-
y_a, y_b = y, y[batch_indices]
65+
y_a, y_b = y, y[permed_indices]
6666

6767
return X, {'y_a': y_a, 'y_b': y_b, 'lam': lam}
6868

autoPyTorch/pipeline/components/training/trainer/RowCutMixTrainer.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,27 +30,29 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
3030
lam = self.random_state.beta(alpha, beta)
3131
batch_size = X.shape[0]
3232
device = torch.device('cuda' if X.is_cuda else 'cpu')
33-
batch_indices = torch.randperm(batch_size).to(device)
33+
permed_indices = torch.randperm(batch_size).to(device)
3434

3535
r = self.random_state.rand(1)
3636
if beta <= 0 or r > self.alpha:
37-
return X, {'y_a': y, 'y_b': y[batch_indices], 'lam': 1}
37+
return X, {'y_a': y, 'y_b': y[permed_indices], 'lam': 1}
3838

39-
row_size = X.shape[1]
40-
row_indices = torch.tensor(
39+
# batch_size (permutation of rows), col_size = X.shape
40+
col_size = X.shape[1]
41+
col_indices = torch.tensor(
4142
self.random_state.choice(
42-
range(1, row_size),
43-
max(1, int(row_size * lam)),
43+
range(col_size),
44+
max(1, int(col_size * lam)),
4445
replace=False
4546
)
4647
)
4748

48-
X[:, row_indices] = X[batch_indices, :][:, row_indices]
49+
# Replace selected columns with columns from another data point
50+
X[:, col_indices] = X[permed_indices, :][:, col_indices]
4951

5052
# Adjust lam
51-
lam = 1 - len(row_indices) / X.shape[1]
53+
lam = 1 - len(col_indices) / X.shape[1]
5254

53-
y_a, y_b = y, y[batch_indices]
55+
y_a, y_b = y, y[permed_indices]
5456

5557
return X, {'y_a': y_a, 'y_b': y_b, 'lam': lam}
5658

autoPyTorch/pipeline/components/training/trainer/RowCutOutTrainer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,9 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
4646
lam = 1
4747
return X, {'y_a': y_a, 'y_b': y_b, 'lam': lam}
4848

49-
row_size = X.shape[1]
50-
row_indices = self.random_state.choice(range(1, row_size), max(1, int(row_size * self.patch_ratio)),
49+
# (batch_size (permutation of rows), col_size) = X.shape
50+
col_size = X.shape[1]
51+
col_indices = self.random_state.choice(range(col_size), max(1, int(col_size * self.patch_ratio)),
5152
replace=False)
5253

5354
if not isinstance(self.numerical_columns, typing.Iterable):
@@ -56,7 +57,7 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
5657
self.numerical_columns))
5758

5859
numerical_indices = torch.tensor(self.numerical_columns)
59-
categorical_indices = torch.tensor([idx for idx in row_indices if idx not in self.numerical_columns])
60+
categorical_indices = torch.tensor([idx for idx in col_indices if idx not in self.numerical_columns])
6061

6162
X[:, categorical_indices.long()] = self.CATEGORICAL_VALUE
6263
X[:, numerical_indices.long()] = self.NUMERICAL_VALUE

0 commit comments

Comments
 (0)