Skip to content

Commit a7284e2

Browse files
yuvalkirstainpre-commit-ci[bot]Bordamergify[bot]
authored
enable specifying weights path for fid (#2867)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka B <[email protected]> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent 3ff199c commit a7284e2

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1717

1818
### Changed
1919

20-
-
20+
- Enabled specifying weights path for FID ([#2867](https://github.com/PyTorchLightning/metrics/pull/2867))
2121

2222

2323
### Removed

src/torchmetrics/image/fid.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,7 @@ def __init__(
300300
reset_real_features: bool = True,
301301
normalize: bool = False,
302302
input_img_size: tuple[int, int, int] = (3, 299, 299),
303+
feature_extractor_weights_path: Optional[str] = None,
303304
**kwargs: Any,
304305
) -> None:
305306
super().__init__(**kwargs)
@@ -322,7 +323,11 @@ def __init__(
322323
f"Integer input to argument `feature` must be one of {valid_int_input}, but got {feature}."
323324
)
324325

325-
self.inception = NoTrainInceptionV3(name="inception-v3-compat", features_list=[str(feature)])
326+
self.inception = NoTrainInceptionV3(
327+
name="inception-v3-compat",
328+
features_list=[str(feature)],
329+
feature_extractor_weights_path=feature_extractor_weights_path,
330+
)
326331

327332
elif isinstance(feature, Module):
328333
self.inception = feature

0 commit comments

Comments
 (0)