
Restoring SequentialLR has undocumented side-effects on Optimizer #119168

@ceisenach

🐛 Describe the bug

When saving and restoring optimizer and LRScheduler states, the order in which the state_dicts are restored determines whether or not the restored optimizer behaves correctly.

Consider the following example:

```python
import math

import torch


# Hypothetical: cosine_decay_linear_warmup is not defined in this report. This version is an
# assumption chosen so that, with a base lr of 3e-5, LambdaLR follows the same schedule as the
# SequentialLR below: linear warmup to 3e-4 over 20 steps, then cosine decay back to 3e-5.
def cosine_decay_linear_warmup(step, warmup_steps=20, total_steps=100, peak_factor=10.0):
    if step < warmup_steps:
        return 1.0 + (peak_factor - 1.0) * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 1.0 + (peak_factor - 1.0) * 0.5 * (1.0 + math.cos(math.pi * progress))


# Reference: uninterrupted 100-step run with LambdaLR
model = torch.nn.Linear(10, 10)
optim = torch.optim.SGD(model.parameters(), lr=3e-5)
lr = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=cosine_decay_linear_warmup)
lrs = []

for i in range(100):
    optim.step()
    lr.step()
    lrs.append(lr.get_last_lr())

# SequentialLR: 20 steps of linear warmup, then cosine annealing; run 25 steps and save
model2 = torch.nn.Linear(10, 10)
optim2 = torch.optim.SGD(model2.parameters(), lr=3e-4)
scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optim2, T_max=80, eta_min=3e-5)
scheduler1 = torch.optim.lr_scheduler.LinearLR(optim2, start_factor=0.1, end_factor=1, total_iters=20)
lr2 = torch.optim.lr_scheduler.SequentialLR(optim2, schedulers=[scheduler1, scheduler2], milestones=[20])
lrs2 = []
lrs3 = []

for i in range(25):
    optim2.step()
    lr2.step()
    lrs2.append(lr2.get_last_lr())
    lrs3.append(lr2.get_last_lr())

torch.save(lr2.state_dict(), '/home/ubuntu/save_seq2.pt')
torch.save(optim2.state_dict(), '/home/ubuntu/save_optim2.pt')

# Correct behavior: schedulers built and restored first, optimizer restored last
model2 = torch.nn.Linear(10, 10)
optim2 = torch.optim.SGD(model2.parameters(), lr=3e-4)
scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optim2, T_max=80, eta_min=3e-5)
scheduler1 = torch.optim.lr_scheduler.LinearLR(optim2, start_factor=0.1, end_factor=1, total_iters=20)
lr2 = torch.optim.lr_scheduler.SequentialLR(optim2, schedulers=[scheduler1, scheduler2], milestones=[20])
lr2.load_state_dict(torch.load('/home/ubuntu/save_seq2.pt'))
optim2.load_state_dict(torch.load('/home/ubuntu/save_optim2.pt'))

for i in range(25, 100):
    lr2.step()
    lrs2.append(lr2.get_last_lr())

# Incorrect behavior: optimizer restored before the schedulers are constructed
model2 = torch.nn.Linear(10, 10)
optim2 = torch.optim.SGD(model2.parameters(), lr=3e-4)
optim2.load_state_dict(torch.load('/home/ubuntu/save_optim2.pt'))
scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optim2, T_max=80, eta_min=3e-5)
scheduler1 = torch.optim.lr_scheduler.LinearLR(optim2, start_factor=0.1, end_factor=1, total_iters=20)
lr2 = torch.optim.lr_scheduler.SequentialLR(optim2, schedulers=[scheduler1, scheduler2], milestones=[20])
lr2.load_state_dict(torch.load('/home/ubuntu/save_seq2.pt'))

for i in range(25, 100):
    lr2.step()
    lrs3.append(lr2.get_last_lr())
```

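The first example (with no restore) produces the following learning-rate schedule:
[Figure: learning rate per step, uninterrupted run]

In the second example, where the optimizer is restored last, the behavior is also correct:
[Figure: learning rate per step, optimizer restored last]

In the third example, where the optimizer is restored first, the behavior is incorrect:
[Figure: learning rate per step, optimizer restored first]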
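For comparison, the schedule that all three runs are meant to follow can be written in closed form. The sketch below is reconstructed from the documented closed forms of LinearLR (with end_factor=1) and CosineAnnealingLR, using the parameters from the example as defaults; expected_lr is a hypothetical helper, not a PyTorch API, and it assumes the cosine phase restarts at the milestone (step 20).

```python
import math


def expected_lr(step, base_lr=3e-4, start_factor=0.1, warmup_iters=20,
                t_max=80, eta_min=3e-5):
    """Closed-form LR after `step` SequentialLR steps (LinearLR, then CosineAnnealingLR)."""
    if step < warmup_iters:
        # LinearLR phase: the factor ramps linearly from start_factor to 1.0
        factor = start_factor + (1.0 - start_factor) * step / warmup_iters
        return base_lr * factor
    # CosineAnnealingLR phase, counted from the milestone
    t = step - warmup_iters
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2


print([round(expected_lr(s), 8) for s in (0, 10, 20, 60, 100)])
```

In the loops above, lrs2[k] is recorded after k + 1 scheduler steps, so it should match expected_lr(k + 1); a restored run that drifts away from this curve after step 25 exhibits the bug.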

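This happens because SequentialLR has side effects on the optimizer when it is initialized: its constructor resets every param group's lr to that group's initial_lr and then re-runs the first scheduler's initial step, and SequentialLR.load_state_dict() does not restore the optimizer's lr afterwards. If the optimizer state is loaded before the schedulers are constructed, the restored learning rate is therefore overwritten. Other LRSchedulers do not cause the same side effects (i.e., for them, the order in which objects are restored does not matter).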
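A minimal sketch of the side effect and of one possible workaround, assuming PyTorch 2.2 behavior where SequentialLR.load_state_dict() leaves the optimizer's lr untouched. CosineAnnealingLR computes each new LR recursively from the optimizer's current lr, so an overwritten value keeps propagating. The helpers make_optimizer, attach_schedulers, sync_optimizer_lr and run are hypothetical names for this sketch, and copying get_last_lr() back into optimizer.param_groups is a suggested workaround, not a documented fix. State dicts are kept in memory instead of being written to disk.

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR


def make_optimizer():
    model = torch.nn.Linear(10, 10)
    return torch.optim.SGD(model.parameters(), lr=3e-4)


def attach_schedulers(optimizer):
    # Same configuration as the report: 20-step linear warmup, then cosine decay.
    cosine = CosineAnnealingLR(optimizer, T_max=80, eta_min=3e-5)
    warmup = LinearLR(optimizer, start_factor=0.1, end_factor=1, total_iters=20)
    return SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[20])


def sync_optimizer_lr(optimizer, scheduler):
    # Hypothetical workaround: overwrite each param group's lr with the scheduler's
    # restored last LR, undoing whatever the scheduler constructors wrote.
    for group, lr in zip(optimizer.param_groups, scheduler.get_last_lr()):
        group["lr"] = lr


def run(optimizer, scheduler, start, stop):
    # Step optimizer and scheduler together and record the LR after each step.
    lrs = []
    for _ in range(start, stop):
        optimizer.step()
        scheduler.step()
        lrs.append(scheduler.get_last_lr()[0])
    return lrs


# Uninterrupted reference run.
ref_opt = make_optimizer()
reference = run(ref_opt, attach_schedulers(ref_opt), 0, 100)

# Train for 25 steps, then capture both state dicts.
opt = make_optimizer()
sched = attach_schedulers(opt)
run(opt, sched, 0, 25)
saved_optim, saved_sched = opt.state_dict(), sched.state_dict()

# The "incorrect" order: load the optimizer before building the schedulers.
optimizer = make_optimizer()
optimizer.load_state_dict(saved_optim)
scheduler = attach_schedulers(optimizer)  # constructors overwrite param_groups[...]["lr"]
scheduler.load_state_dict(saved_sched)
print("lr after restore:", optimizer.param_groups[0]["lr"], "expected:", reference[24])

# Re-sync the optimizer from the restored scheduler, then continue training.
sync_optimizer_lr(optimizer, scheduler)
restored = run(optimizer, scheduler, 25, 100)
max_diff = max(abs(a - b) for a, b in zip(reference[25:], restored))
print("max |reference - restored| after sync:", max_diff)
```

The simpler alternative is the order from the second example: construct all schedulers, load the scheduler state, and load the optimizer state last, so the optimizer's saved lr is the final value written.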

Versions

PyTorch version: 2.2.0+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.8.18 | packaged by conda-forge | (default, Dec 23 2023, 17:21:28) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-1026-aws-x86_64-with-glibc2.10
Is CUDA available: True
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB
GPU 2: NVIDIA A100-SXM4-40GB
GPU 3: NVIDIA A100-SXM4-40GB
GPU 4: NVIDIA A100-SXM4-40GB
GPU 5: NVIDIA A100-SXM4-40GB
GPU 6: NVIDIA A100-SXM4-40GB
GPU 7: NVIDIA A100-SXM4-40GB

Nvidia driver version: 515.65.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 96
On-line CPU(s) list: 0-95
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz
Stepping: 7
CPU MHz: 3000.000
BogoMIPS: 6000.00
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 1.5 MiB
L1i cache: 1.5 MiB
L2 cache: 48 MiB
L3 cache: 71.5 MiB
NUMA node0 CPU(s): 0-23,48-71
NUMA node1 CPU(s): 24-47,72-95
Vulnerability Itlb multihit: KVM: Mitigation: VMX unsupported
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed: Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke

Versions of relevant libraries:
[pip3] numpy==1.24.1
[pip3] torch==2.2.0+cu118
[pip3] torchaudio==2.2.0+cu118
[pip3] torchvision==0.17.0+cu118
[pip3] triton==2.2.0
[conda] numpy 1.24.1 pypi_0 pypi
[conda] torch 2.2.0+cu118 pypi_0 pypi
[conda] torchaudio 2.2.0+cu118 pypi_0 pypi
[conda] torchvision 0.17.0+cu118 pypi_0 pypi
[conda] triton 2.2.0 pypi_0 pypi

cc @vincentqb @jbschlosser @albanD @janeyx99 @crcrpar

Metadata

Assignees: no one assigned
Labels: actionable, module: LrScheduler, module: optimizer, triaged
Type: none
Projects: none
Milestone: none
Relationships: none yet
Development: no branches or pull requests
