diff --git a/basicsr/data/ffhq_blind_dataset.py b/basicsr/data/ffhq_blind_dataset.py index 9f90060..e4bc6fa 100755 --- a/basicsr/data/ffhq_blind_dataset.py +++ b/basicsr/data/ffhq_blind_dataset.py @@ -46,11 +46,7 @@ def __init__(self, opt): else: self.crop_components = False - if self.latent_gt_path is not None: - self.load_latent_gt = True - self.latent_gt_dict = torch.load(self.latent_gt_path) - else: - self.load_latent_gt = False + self.latent_gt_dict = None if self.io_backend_opt['type'] == 'lmdb': self.io_backend_opt['db_paths'] = self.gt_folder @@ -177,6 +173,14 @@ def get_component_locations(self, name, status): def __getitem__(self, index): + + if self.latent_gt_path is not None: + self.load_latent_gt = True + if self.latent_gt_dict is None: + self.latent_gt_dict = torch.load(self.latent_gt_path) + else: + self.load_latent_gt = False + if self.file_client is None: self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) diff --git a/basicsr/data/ffhq_blind_joint_dataset.py b/basicsr/data/ffhq_blind_joint_dataset.py index 0dc845f..0bd00be 100755 --- a/basicsr/data/ffhq_blind_joint_dataset.py +++ b/basicsr/data/ffhq_blind_joint_dataset.py @@ -45,11 +45,7 @@ def __init__(self, opt): else: self.crop_components = False - if self.latent_gt_path is not None: - self.load_latent_gt = True - self.latent_gt_dict = torch.load(self.latent_gt_path) - else: - self.load_latent_gt = False + self.latent_gt_dict = None if self.io_backend_opt['type'] == 'lmdb': self.io_backend_opt['db_paths'] = self.gt_folder @@ -169,6 +165,14 @@ def get_component_locations(self, name, status): def __getitem__(self, index): + + if self.latent_gt_path is not None: + self.load_latent_gt = True + if self.latent_gt_dict is None: + self.latent_gt_dict = torch.load(self.latent_gt_path) + else: + self.load_latent_gt = False + if self.file_client is None: self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)