2424import pytest
2525import torch
2626from fsspec .implementations .local import LocalFileSystem
27- from omegaconf import Container , OmegaConf
28- from omegaconf .dictconfig import DictConfig
2927from torch .utils .data import DataLoader
3028
3129from pytorch_lightning import LightningModule , Trainer
3230from pytorch_lightning .callbacks import ModelCheckpoint
3331from pytorch_lightning .core .datamodule import LightningDataModule
3432from pytorch_lightning .core .saving import load_hparams_from_yaml , save_hparams_to_yaml
35- from pytorch_lightning .utilities import _HYDRA_EXPERIMENTAL_AVAILABLE , AttributeDict , is_picklable
33+ from pytorch_lightning .utilities import _HYDRA_EXPERIMENTAL_AVAILABLE , _OMEGACONF_AVAILABLE , AttributeDict , is_picklable
3634from pytorch_lightning .utilities .exceptions import MisconfigurationException
3735from tests .helpers import BoringModel , RandomDataset
36+ from tests .helpers .runif import RunIf
3837
3938if _HYDRA_EXPERIMENTAL_AVAILABLE :
4039 from hydra .experimental import compose , initialize
4140
41+ if _OMEGACONF_AVAILABLE :
42+ from omegaconf import Container , OmegaConf
43+ from omegaconf .dictconfig import DictConfig
44+
4245
4346class SaveHparamsModel (BoringModel ):
4447 """Tests that a model can take an object."""
@@ -117,6 +120,7 @@ def test_dict_hparams(tmpdir, cls):
117120 _run_standard_hparams_test (tmpdir , model , cls )
118121
119122
123+ @RunIf (omegaconf = True )
120124@pytest .mark .parametrize ("cls" , [SaveHparamsModel , SaveHparamsDecoratedModel ])
121125def test_omega_conf_hparams (tmpdir , cls ):
122126 # init model
@@ -275,10 +279,18 @@ def __init__(obj, *more_args, other_arg=300, **more_kwargs):
275279 obj .save_hyperparameters ()
276280
277281
278- class DictConfSubClassBoringModel (SubClassBoringModel ):
279- def __init__ (self , * args , dict_conf = OmegaConf .create (dict (my_param = "something" )), ** kwargs ):
280- super ().__init__ (* args , ** kwargs )
281- self .save_hyperparameters ()
282+ if _OMEGACONF_AVAILABLE :
283+
284+ class DictConfSubClassBoringModel (SubClassBoringModel ):
285+ def __init__ (self , * args , dict_conf = OmegaConf .create (dict (my_param = "something" )), ** kwargs ):
286+ super ().__init__ (* args , ** kwargs )
287+ self .save_hyperparameters ()
288+
289+
290+ else :
291+
292+ class DictConfSubClassBoringModel :
293+ ...
282294
283295
284296@pytest .mark .parametrize (
@@ -290,7 +302,7 @@ def __init__(self, *args, dict_conf=OmegaConf.create(dict(my_param="something"))
290302 SubSubClassBoringModel ,
291303 AggSubClassBoringModel ,
292304 UnconventionalArgsBoringModel ,
293- DictConfSubClassBoringModel ,
305+ pytest . param ( DictConfSubClassBoringModel , marks = RunIf ( omegaconf = True )) ,
294306 ],
295307)
296308def test_collect_init_arguments (tmpdir , cls ):
@@ -383,31 +395,6 @@ def test_collect_init_arguments_with_local_vars(cls):
383395 assert model .hparams ["arg2" ] == 2
384396
385397
386- # @pytest.mark.parametrize("cls,config", [
387- # (SaveHparamsModel, Namespace(my_arg=42)),
388- # (SaveHparamsModel, dict(my_arg=42)),
389- # (SaveHparamsModel, OmegaConf.create(dict(my_arg=42))),
390- # (AssignHparamsModel, Namespace(my_arg=42)),
391- # (AssignHparamsModel, dict(my_arg=42)),
392- # (AssignHparamsModel, OmegaConf.create(dict(my_arg=42))),
393- # ])
394- # def test_single_config_models(tmpdir, cls, config):
395- # """ Test that the model automatically saves the arguments passed into the constructor """
396- # model = cls(config)
397- #
398- # # no matter how you do it, it should be assigned
399- # assert model.hparams.my_arg == 42
400- #
401- # # verify that the checkpoint saved the correct values
402- # trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5)
403- # trainer.fit(model)
404- #
405- # # verify that model loads correctly
406- # raw_checkpoint_path = _raw_checkpoint_path(trainer)
407- # model = cls.load_from_checkpoint(raw_checkpoint_path)
408- # assert model.hparams.my_arg == 42
409-
410-
411398class AnotherArgModel (BoringModel ):
412399 def __init__ (self , arg1 ):
413400 super ().__init__ ()
@@ -511,8 +498,9 @@ def _compare_params(loaded_params, default_params: dict):
511498 save_hparams_to_yaml (path_yaml , AttributeDict (hparams ))
512499 _compare_params (load_hparams_from_yaml (path_yaml , use_omegaconf = False ), hparams )
513500
514- save_hparams_to_yaml (path_yaml , OmegaConf .create (hparams ))
515- _compare_params (load_hparams_from_yaml (path_yaml ), hparams )
501+ if _OMEGACONF_AVAILABLE :
502+ save_hparams_to_yaml (path_yaml , OmegaConf .create (hparams ))
503+ _compare_params (load_hparams_from_yaml (path_yaml ), hparams )
516504
517505
518506class NoArgsSubClassBoringModel (CustomBoringModel ):
0 commit comments