Skip to content

Commit 6b990f3

Browse files
authored
Add artifcact_location arg to MLFlow logger (#6677)
* Add artifcact_location arg to MLFlow logger * Add CHANGELOG URL * Update test
1 parent 0ea8f39 commit 6b990f3

File tree

3 files changed

+40
-1
lines changed

3 files changed

+40
-1
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
6767
- Added support for `precision=64`, enabling training with double precision ([#6595](https://github.com/PyTorchLightning/pytorch-lightning/pull/6595))
6868

6969

70+
- Added `artifact_location` argument to `MLFlowLogger` which will be passed to the `MlflowClient.create_experiment` call ([#6677](https://github.com/PyTorchLightning/pytorch-lightning/pull/6677))
71+
72+
7073
### Changed
7174

7275
- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))

pytorch_lightning/loggers/mlflow.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ def any_lightning_module_function_or_hook(self):
7878
Defaults to `./mlflow` if `tracking_uri` is not provided.
7979
Has no effect if `tracking_uri` is provided.
8080
prefix: A string to put at the beginning of metric keys.
81+
artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate
82+
default.
8183
8284
Raises:
8385
ImportError:
@@ -93,6 +95,7 @@ def __init__(
9395
tags: Optional[Dict[str, Any]] = None,
9496
save_dir: Optional[str] = './mlruns',
9597
prefix: str = '',
98+
artifact_location: Optional[str] = None,
9699
):
97100
if mlflow is None:
98101
raise ImportError(
@@ -109,6 +112,8 @@ def __init__(
109112
self._run_id = None
110113
self.tags = tags
111114
self._prefix = prefix
115+
self._artifact_location = artifact_location
116+
112117
self._mlflow_client = MlflowClient(tracking_uri)
113118

114119
@property
@@ -129,7 +134,10 @@ def experiment(self) -> MlflowClient:
129134
self._experiment_id = expt.experiment_id
130135
else:
131136
log.warning(f'Experiment with name {self._experiment_name} not found. Creating it.')
132-
self._experiment_id = self._mlflow_client.create_experiment(name=self._experiment_name)
137+
self._experiment_id = self._mlflow_client.create_experiment(
138+
name=self._experiment_name,
139+
artifact_location=self._artifact_location,
140+
)
133141

134142
if self._run_id is None:
135143
run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=self.tags)

tests/loggers/test_mlflow.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,3 +199,31 @@ def test_mlflow_logger_with_long_param_value(client, mlflow, tmpdir):
199199

200200
with pytest.warns(RuntimeWarning, match=f'Discard {key}={value}'):
201201
logger.log_hyperparams(params)
202+
203+
204+
@mock.patch('pytorch_lightning.loggers.mlflow.time')
205+
@mock.patch('pytorch_lightning.loggers.mlflow.mlflow')
206+
@mock.patch('pytorch_lightning.loggers.mlflow.MlflowClient')
207+
def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir):
208+
"""
209+
Test that the logger calls methods on the mlflow experiment correctly.
210+
"""
211+
time.return_value = 1
212+
213+
logger = MLFlowLogger('test', save_dir=tmpdir, artifact_location='my_artifact_location')
214+
logger._mlflow_client.get_experiment_by_name.return_value = None
215+
216+
params = {'test': 'test_param'}
217+
logger.log_hyperparams(params)
218+
219+
logger.experiment.log_param.assert_called_once_with(logger.run_id, 'test', 'test_param')
220+
221+
metrics = {'some_metric': 10}
222+
logger.log_metrics(metrics)
223+
224+
logger.experiment.log_metric.assert_called_once_with(logger.run_id, 'some_metric', 10, 1000, None)
225+
226+
logger._mlflow_client.create_experiment.assert_called_once_with(
227+
name='test',
228+
artifact_location='my_artifact_location',
229+
)

0 commit comments

Comments
 (0)