🐛 Bug
As I understand it, the recommendation is to log a TorchMetrics instance rather than its `.compute()` output (only then is DDP synchronization possible, should you want it):

```python
def __init__(self):
    self.val_acc = Accuracy()
    ...

def validation_step_end(self, outputs):
    self.log("val/acc", self.val_acc)
```

However, I was surprised to learn that this also quietly resets the metric, even when the metric should not be reset:
```python
def __init__(self):
    self.val_acc = Accuracy()
    self.val_acc_best = MaxMetric()
    ...

def validation_epoch_end(self, outputs):
    self.val_acc_best.update(self.val_acc.compute())
    self.log("val/acc_best", self.val_acc_best)
    # after this line, self.val_acc_best.value = -Inf
```

Imagine my surprise that my "best accuracy" was decreasing every few epochs.
To Reproduce
```python
import os

import torch
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Accuracy, MaxMetric

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.val_acc = Accuracy()
        self.val_acc_best = MaxMetric()

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        val_loss = self(batch).sum()
        self.log("valid_loss", val_loss)
        return {"loss": val_loss}

    def validation_step_end(self, outputs):
        print(self.val_acc_best.value)
        self.val_acc_best.update(outputs["loss"])  # get val accuracy from current epoch
        print(self.val_acc_best.value)
        self.log("valid_loss_best", self.val_acc_best)
        print(self.val_acc_best.value)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=2,
        enable_model_summary=False,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


if __name__ == "__main__":
    run()
```

As you will see in the output, `self.val_acc_best.value` is -Inf again in the second epoch, and the metric therefore fails to track the best loss.
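Until this is resolved, a workaround is to log the computed value instead of the metric instance, so that `self.log()` never gets hold of the metric and cannot reset it. A minimal sketch (note this forgoes the DDP synchronization that logging the instance would provide):

```python
def validation_epoch_end(self, outputs):
    self.val_acc_best.update(self.val_acc.compute())
    # Log the plain tensor returned by compute(), not the Metric instance;
    # Lightning then has no handle on val_acc_best and will not reset it.
    self.log("val/acc_best", self.val_acc_best.compute())
```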
Expected behavior
Logging a TorchMetrics instance via `self.log()` should not silently reset it; a metric whose state the user manages manually, such as `MaxMetric` here, should keep its state across epochs, or at minimum the implicit reset should be documented.
Environment
* CUDA:
- GPU:
- NVIDIA GeForce GTX 1060 6GB
- available: True
- version: 10.2
* Packages:
- numpy: 1.20.3
- pyTorch_debug: False
- pyTorch_version: 1.10.0+cu102
- pytorch-lightning: 1.6.0dev
- tqdm: 4.62.3
* System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor:
- python: 3.8.12
- version: #1 ZEN SMP PREEMPT Wed, 22 Dec 2021 09:23:53 +0000