
Commit b9eaba9

Refactor(examples/mnist): Clarify loss reduction and correct dataset length
This commit addresses two potential points of confusion and error in the MNIST example, as detailed in issue #623.

1. Explicit loss reduction: the train() function's use of F.nll_loss implicitly defaults to reduction='mean', whereas the test() function uses reduction='sum'. This change makes reduction='mean' explicit in train(), improving code clarity.

2. Correct dataset size with samplers: using len(loader.dataset) to get the number of samples is incorrect when a Sampler (e.g., SubsetRandomSampler for a validation split) is in use, because it reports the full dataset size rather than the subset size. The logic now checks len(loader.sampler) first: if a sampler exists, its length is used; otherwise it falls back to len(loader.dataset). This ensures the correct number of samples is used for logging and calculations.

Fixes #623
1 parent acc295d commit b9eaba9
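
To illustrate point 2, here is a minimal sketch that is not part of the commit (the dataset shape, split indices, and variable names below are invented for the example): once a SubsetRandomSampler restricts a loader to a validation split, len(loader.dataset) still reports the full dataset size, while len(loader.sampler) reports the number of samples the loader actually yields.

import torch
from torch.utils.data import DataLoader, TensorDataset, SubsetRandomSampler

# Fake stand-in for MNIST: 1000 samples of shape 1x28x28 with integer labels.
full_set = TensorDataset(torch.randn(1000, 1, 28, 28), torch.randint(0, 10, (1000,)))
val_indices = list(range(900, 1000))  # hypothetical 10% validation split
val_loader = DataLoader(full_set, batch_size=64,
                        sampler=SubsetRandomSampler(val_indices))

print(len(val_loader.dataset))  # 1000 -- full dataset size, misleading for the split
print(len(val_loader.sampler))  # 100  -- samples this loader will actually yield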

File tree: 1 file changed (+10, -5 lines)

mnist/main.py

Lines changed: 10 additions & 5 deletions
@@ -35,16 +35,20 @@ def forward(self, x):
 
 def train(args, model, device, train_loader, optimizer, epoch):
     model.train()
+    # Get the correct number of samples for logging.
+    # Use len(train_loader.sampler) if a sampler is provided (e.g., SubsetRandomSampler),
+    # otherwise, use the full dataset length.
+    data_len = len(train_loader.sampler) if train_loader.sampler is not None else len(train_loader.dataset)
     for batch_idx, (data, target) in enumerate(train_loader):
         data, target = data.to(device), target.to(device)
         optimizer.zero_grad()
         output = model(data)
-        loss = F.nll_loss(output, target)
+        loss = F.nll_loss(output, target, reduction='mean')  # get batch average loss
         loss.backward()
         optimizer.step()
         if batch_idx % args.log_interval == 0:
             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
-                epoch, batch_idx * len(data), len(train_loader.dataset),
+                epoch, batch_idx * len(data), data_len,
                 100. * batch_idx / len(train_loader), loss.item()))
             if args.dry_run:
                 break
@@ -54,6 +58,7 @@ def test(model, device, test_loader):
     model.eval()
     test_loss = 0
     correct = 0
+    data_len = len(test_loader.sampler) if test_loader.sampler is not None else len(test_loader.dataset)
     with torch.no_grad():
         for data, target in test_loader:
             data, target = data.to(device), target.to(device)
@@ -62,11 +67,11 @@ def test(model, device, test_loader):
             pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
             correct += pred.eq(target.view_as(pred)).sum().item()
 
-    test_loss /= len(test_loader.dataset)
+    test_loss /= data_len  # get average loss in test set
 
     print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
-        test_loss, correct, len(test_loader.dataset),
-        100. * correct / len(test_loader.dataset)))
+        test_loss, correct, data_len,
+        100. * correct / data_len))
 
 
 def main():
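
As a quick sanity check on point 1 (again a sketch, not part of the commit; tensor shapes and values are invented), reduction='mean' as now written in train() and the reduction='sum' followed by division used in test() agree on the per-sample average for a single batch:

import torch
import torch.nn.functional as F

# Fake log-probabilities for a batch of 8 samples over 10 classes.
output = torch.log_softmax(torch.randn(8, 10), dim=1)
target = torch.randint(0, 10, (8,))

mean_loss = F.nll_loss(output, target, reduction='mean')                      # what train() logs per batch
sum_then_divide = F.nll_loss(output, target, reduction='sum') / len(target)   # test()'s pattern
assert torch.allclose(mean_loss, sum_then_divide)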
