You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Refactor(examples/mnist): Clarify loss reduction and correct dataset length
This commit addresses two potential points of confusion and error in the MNIST example, as detailed in issue #623.
1. Explicit Loss Reduction:
The train() function's use of F.nll_loss implicitly defaults to reduction='mean', whereas the test() function uses reduction='sum'.
This change makes the `reduction='mean'` explicit in the train() function. This improves code clarity.
2. Correct Dataset Size with Samplers:
Using `len(loader.dataset)` to get the number of samples is incorrect when a Sampler (e.g., SubsetRandomSampler for a validation split) is used. It incorrectly reports the full dataset size, not the subset size.
The logic is updated to first check `len(loader.sampler)`. If a sampler exists, its length is used. Otherwise, it falls back to `len(loader.dataset)`. This ensures the correct number of samples is used for logging and calculations.
Fixes#623
0 commit comments