Datamodule not calling load_state_dict() when loading from checkpoint #14842

@dconathan

Description

First check

  • I'm sure this is a bug.
  • I've added a descriptive title to this bug.
  • I've provided clear instructions on how to reproduce the bug.
  • I've added a code sample.
  • I've provided any other important info that is required.

Bug description

Sorry if this is covered by #14841, but I wanted to make sure it gets fixed as part of that work if not!

In short, when you call MyDataModule.load_from_checkpoint(checkpoint_path), MyDataModule.load_state_dict() does not appear to be called.
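
For contrast, resuming training through the Trainer does seem to hit the hook. A minimal sketch of that path (ckpt_path is the standard Trainer.fit resume argument):

# Resuming via the Trainer restores the datamodule state and, as far as I can
# tell, calls datamodule.load_state_dict() with the state saved in the checkpoint:
trainer = Trainer(max_epochs=2)
trainer.fit(model, datamodule=datamodule, ckpt_path=checkpoint_path)

# ...whereas this constructs the datamodule without ever calling load_state_dict():
loaded_datamodule = MyDataModule.load_from_checkpoint(checkpoint_path)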

How to reproduce the bug

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer, LightningDataModule


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringDataModule(LightningDataModule):
    def state_dict(self):
        print("state_dict()")
        return dict()

    def load_state_dict(self, state_dict):
        print("load_state_dict()")
        raise RuntimeError("this should be raised!")

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def val_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def test_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    datamodule = BoringDataModule()
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
    )
    trainer.fit(model, datamodule)
    trainer.test(model, datamodule)
    checkpoint_path = os.path.join(trainer.log_dir, "checkpoints", "epoch=0-step=1.ckpt")
    # should raise the RuntimeError from BoringDataModule.load_state_dict()?
    loaded_datamodule = BoringDataModule.load_from_checkpoint(checkpoint_path)
    assert isinstance(loaded_datamodule, BoringDataModule)


if __name__ == "__main__":
    run()
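
In the meantime, restoring the state by hand seems possible. This is a minimal sketch (continuing from run() above), assuming the checkpoint stores the datamodule state under the datamodule's class name, which is how the checkpoint connector appears to save it; since state_dict() above returns an empty dict, that key may not even be present, hence the guard:

import torch

checkpoint = torch.load(checkpoint_path)
loaded_datamodule = BoringDataModule()
# Assumption: the datamodule state is keyed by the class qualname in the checkpoint.
dm_state = checkpoint.get(BoringDataModule.__qualname__)
if dm_state is not None:
    loaded_datamodule.load_state_dict(dm_state)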

Error messages and logs

No response

Important info

  • Lightning Component: LightningDataModule
  • PyTorch Lightning Version: 1.7.6
  • PyTorch Version: 1.12.1
  • Python version: 3.10
  • OS: macOS
  • How you installed Lightning: pip
  • Running environment of LightningApp: local

More info

No response
