Labels
bug (Something isn't working), checkpointing (Related to checkpointing), lightningdatamodule (pl.LightningDataModule)
Description
First check
- I'm sure this is a bug.
- I've added a descriptive title to this bug.
- I've provided clear instructions on how to reproduce the bug.
- I've added a code sample.
- I've provided any other important info that is required.
Bug description
Sorry if this is part of #14841, but I wanted to make sure it gets fixed as part of that if not!
In short, when you call MyDataModule.load_from_checkpoint(checkpoint_path), it doesn't seem like MyDataModule.load_state_dict() is ever being called.
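For context, this is roughly what I would expect load_from_checkpoint to do for a datamodule. This is an illustrative sketch only, not Lightning's actual implementation, and the checkpoint key it looks up is an assumption:

    import torch

    # Sketch of the EXPECTED behavior -- not Lightning's actual code.
    def expected_load_from_checkpoint(cls, checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        datamodule = cls()  # plus re-applying any saved hparams
        # ASSUMPTION: the Trainer keys the datamodule state by the class's qualified name.
        if cls.__qualname__ in checkpoint:
            # This is the call that never seems to happen today.
            datamodule.load_state_dict(checkpoint[cls.__qualname__])
        return datamodule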
How to reproduce the bug
import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer, LightningDataModule


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringDataModule(LightningDataModule):
    def state_dict(self):
        print("state_dict()")
        return dict()

    def load_state_dict(self, state_dict):
        print("load_state_dict()")
        raise RuntimeError("this should be raised!")

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def val_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def test_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    datamodule = BoringDataModule()
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
    )
    trainer.fit(model, datamodule)
    trainer.test(model, datamodule)

    checkpoint_path = os.path.join(trainer.log_dir, "checkpoints", "epoch=0-step=1.ckpt")
    # should raise the RuntimeError from BoringDataModule.load_state_dict()?
    loaded_datamodule = BoringDataModule.load_from_checkpoint(checkpoint_path)
    assert isinstance(loaded_datamodule, BoringDataModule)


if __name__ == "__main__":
    run()
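Until this is fixed, the state can be restored by hand from the raw checkpoint file. A minimal workaround sketch, assuming the datamodule state was saved under the class's qualified name (that key is an assumption; the Trainer also appears to skip writing it when state_dict() returns an empty dict, so the BoringDataModule above would need to return real state):

    import torch

    # Manual workaround: read the raw checkpoint and restore the state ourselves.
    checkpoint = torch.load("path/to/epoch=0-step=1.ckpt")  # the checkpoint_path from run() above
    datamodule = BoringDataModule()
    dm_key = BoringDataModule.__qualname__  # ASSUMPTION: state is keyed by the class's qualified name
    if dm_key in checkpoint:
        datamodule.load_state_dict(checkpoint[dm_key])

Note that resuming training with trainer.fit(model, datamodule, ckpt_path=checkpoint_path) does appear to go through the datamodule restore path, so the problem seems specific to load_from_checkpoint().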
Error messages and logs
No response
Important info
- Lightning Component: LightningDataModule
- PyTorch Lightning Version: 1.7.6
- PyTorch Version: 1.12.1
- Python version: 3.10
- OS: macOS
- How you installed Lightning (`conda`, `pip`, source): `pip`
- Running environment of LightningApp (e.g. local, cloud): `local`
More info
No response