diff --git a/docs/book/component-guide/experiment-trackers/README.md b/docs/book/component-guide/experiment-trackers/README.md index 5c8d5e878c0..bcb9ead5384 100644 --- a/docs/book/component-guide/experiment-trackers/README.md +++ b/docs/book/component-guide/experiment-trackers/README.md @@ -37,6 +37,7 @@ Experiment Trackers are optional stack components provided by integrations: | [Comet](comet.md) | `comet` | `comet` | Add Comet experiment tracking and visualization capabilities to your ZenML pipelines | | [MLflow](mlflow.md) | `mlflow` | `mlflow` | Add MLflow experiment tracking and visualization capabilities to your ZenML pipelines | | [Neptune](neptune.md) | `neptune` | `neptune` | Add Neptune experiment tracking and visualization capabilities to your ZenML pipelines | +| [Trackio](trackio.md) | `trackio` | `trackio` | Add Trackio experiment tracking and visualization capabilities to your ZenML pipelines | | [Weights & Biases](wandb.md) | `wandb` | `wandb` | Add Weights & Biases experiment tracking and visualization capabilities to your ZenML pipelines | | [Custom Implementation](custom.md) | _custom_ | | _custom_ | diff --git a/docs/book/component-guide/experiment-trackers/trackio.md b/docs/book/component-guide/experiment-trackers/trackio.md new file mode 100644 index 00000000000..81fcc92a0be --- /dev/null +++ b/docs/book/component-guide/experiment-trackers/trackio.md @@ -0,0 +1,154 @@ +--- +description: Logging and visualizing experiments with Trackio. +--- + +# Trackio + +The Trackio Experiment Tracker is an [Experiment Tracker](./) flavor provided by the Trackio-ZenML integration. It connects your ZenML pipelines to [Trackio](https://github.com/gradio-app/trackio), a lightweight experiment tracking solution maintained by the Gradio team, so you can visualize and compare the artifacts and metrics generated by your runs. + +### When would you want to use it? 
+ +Trackio focuses on offering a simple workflow for collecting model metrics and artifacts while staying close to the developer experience that Gradio users are familiar with. You should consider the Trackio Experiment Tracker if: + +* you already use Trackio (or plan to) to keep a history of training runs and evaluations and would like to surface the results of your automated ZenML pipelines in the same place +* you prefer a managed experiment tracking option that integrates nicely with the Gradio/Spaces ecosystem without the complexity of heavier experiment tracking suites +* you want to experiment locally with an easy to self-host service before committing to a larger-scale tracking solution + +You may want to use one of the other [Experiment Tracker flavors](./#experiment-tracker-flavors) if you rely on advanced governance features (e.g. model registry, lineage dashboards) that are better covered by tools like MLflow, Neptune, or Weights & Biases. + +### How do you deploy it? + +The Trackio Experiment Tracker flavor is provided through the Trackio ZenML integration. Install the integration to make the component available for registration: + +```shell +zenml integration install trackio -y +``` + +Once installed, you can register the experiment tracker and add it to a stack. The Trackio tracker accepts the following configuration attributes: + +* `workspace` (optional): Trackio workspace/organization slug. Leave empty to use the default workspace associated with your API key. +* `project` (optional): Project identifier inside Trackio where new runs should be stored. +* `api_key` (optional): API token used to authenticate against Trackio. When omitted, the library falls back to the environment or CLI authentication. +* `base_url` (optional): Override the default Trackio endpoint, e.g. when talking to a self-hosted Trackio server. + +### Authentication methods + +You can authenticate with Trackio using the same mechanisms recommended for other experiment trackers. 
Storing credentials inside a [ZenML Secret](https://docs.zenml.io/how-to/project-setup-and-management/interact-with-secrets) keeps your configuration secure and auditable. + +{% tabs %} +{% tab title="ZenML Secret (Recommended)" %} + +```shell +zenml secret create trackio_secret --api_key= + +zenml experiment-tracker register trackio_tracker \ + --flavor=trackio \ + --workspace= \ + --project= \ + --api_key={{trackio_secret.api_key}} + +zenml stack register trackio_stack -e trackio_tracker ... --set +``` + +{% endtab %} +{% tab title="Basic authentication" %} + +You can also provide the credentials directly on the command line (useful for quick experiments): + +```shell +zenml experiment-tracker register trackio_tracker \ + --flavor=trackio \ + --workspace= \ + --project= \ + --api_key= + +zenml stack register trackio_stack -e trackio_tracker ... --set +``` + +{% hint style="warning" %} +Providing credentials inline is not recommended for production use because they will be stored in plain text in the stack configuration. +{% endhint %} + +{% endtab %} +{% endtabs %} + +For more information about the configuration parameters supported by the Trackio experiment tracker, refer to the [SDK docs](https://sdkdocs.zenml.io/latest/integration_code_docs/integrations-trackio.html). + +### How do you use it? + +Once the experiment tracker is part of the active stack, enable it for the steps that should log to Trackio by using the `experiment_tracker` argument of the `@step` decorator. 
Inside the step you can fetch the Trackio run object and use the regular Trackio API to log metrics, files or metadata: + +```python +from zenml import step +from zenml.integrations.trackio.experiment_trackers import get_trackio_run + +@step(experiment_tracker="trackio_tracker") +def evaluate_model(accuracy: float, params: dict) -> None: + run = get_trackio_run() + + # Log high level metrics / metadata using the Trackio client API + if hasattr(run, "log"): + run.log({"accuracy": accuracy}) + elif hasattr(run, "log_metrics"): + run.log_metrics({"accuracy": accuracy}) + + if hasattr(run, "log_params"): + run.log_params(params) + elif hasattr(run, "__setitem__"): + run["parameters"] = params +``` + +Any metadata logged during the step is persisted by Trackio and the resulting dashboard URL is attached to the ZenML run metadata. You can retrieve it programmatically: + +```python +from zenml.client import Client + +pipeline_run = Client().get_pipeline_run("") +step = pipeline_run.steps["evaluate_model"] +experiment_tracker_url = step.run_metadata["experiment_tracker_url"].value +``` + +{% hint style="info" %} +Instead of hardcoding the experiment tracker name, you can obtain it dynamically from the active stack: + +```python +from zenml.client import Client + +tracker_name = Client().active_stack.experiment_tracker.name + +@step(experiment_tracker=tracker_name) +def my_step(...): + ... +``` +{% endhint %} + +### Additional settings + +The Trackio experiment tracker exposes additional runtime settings that can be provided via the `@step` decorator: + +* `run_name`: Override the auto-generated run name that ZenML derives from the pipeline and step names. +* `tags`: Attach custom tags to the Trackio run. ZenML always appends the pipeline name, pipeline run name, and step name to help with filtering. +* `metadata`: Provide a dictionary of values that should be logged immediately after the run is created (e.g. static parameters, environment metadata). 
+ +These settings can be declared when defining the step: + +```python +from zenml import step +from zenml.integrations.trackio.experiment_trackers import get_trackio_run +from zenml.integrations.trackio.flavors import TrackioExperimentTrackerSettings + +@step( + experiment_tracker="trackio_tracker", + settings={ + "experiment_tracker": TrackioExperimentTrackerSettings( + run_name="nightly-eval", + tags=["nightly", "evaluation"], + metadata={"dataset": "imagenet", "split": "validation"}, + ) + }, +) +def nightly_evaluation() -> None: + run = get_trackio_run() + run.log({"f1": 0.81}) +``` + +As with other experiment trackers, Trackio runs are automatically marked as failed when the corresponding ZenML pipeline step fails. diff --git a/docs/book/component-guide/toc.md b/docs/book/component-guide/toc.md index 53a867ed184..a6d54e04766 100644 --- a/docs/book/component-guide/toc.md +++ b/docs/book/component-guide/toc.md @@ -52,6 +52,7 @@ * [Comet](experiment-trackers/comet.md) * [MLflow](experiment-trackers/mlflow.md) * [Neptune](experiment-trackers/neptune.md) + * [Trackio](experiment-trackers/trackio.md) * [Weights & Biases](experiment-trackers/wandb.md) * [Google Cloud VertexAI Experiment Tracker](experiment-trackers/vertexai.md) * [Develop a custom experiment tracker](experiment-trackers/custom.md) diff --git a/src/zenml/integrations/constants.py b/src/zenml/integrations/constants.py index fc079665ebb..b264eac8834 100644 --- a/src/zenml/integrations/constants.py +++ b/src/zenml/integrations/constants.py @@ -70,6 +70,7 @@ SKYPILOT_KUBERNETES = "skypilot_kubernetes" SLACK = "slack" SPARK = "spark" +TRACKIO = "trackio" TEKTON = "tekton" TENSORBOARD = "tensorboard" TENSORFLOW = "tensorflow" diff --git a/src/zenml/integrations/trackio/__init__.py b/src/zenml/integrations/trackio/__init__.py new file mode 100644 index 00000000000..dd2985f8a2d --- /dev/null +++ b/src/zenml/integrations/trackio/__init__.py @@ -0,0 +1,42 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Module containing the Trackio integration.""" + +from typing import List, Type + +from zenml.integrations.constants import TRACKIO +from zenml.integrations.integration import Integration +from zenml.stack import Flavor + +TRACKIO_EXPERIMENT_TRACKER_FLAVOR = "trackio" + + +class TrackioIntegration(Integration): + """Definition of the Trackio integration for ZenML.""" + + NAME = TRACKIO + REQUIREMENTS = ["trackio>=0.1.0"] + + @classmethod + def flavors(cls) -> List[Type[Flavor]]: + """Declare the stack component flavors for the Trackio integration. + + Returns: + List of stack component flavors for this integration. 
+ """ + from zenml.integrations.trackio.flavors import ( + TrackioExperimentTrackerFlavor, + ) + + return [TrackioExperimentTrackerFlavor] diff --git a/src/zenml/integrations/trackio/experiment_trackers/__init__.py b/src/zenml/integrations/trackio/experiment_trackers/__init__.py new file mode 100644 index 00000000000..a163ba14b09 --- /dev/null +++ b/src/zenml/integrations/trackio/experiment_trackers/__init__.py @@ -0,0 +1,10 @@ +"""Trackio experiment tracker implementation.""" + +from zenml.integrations.trackio.experiment_trackers.run_state import ( + get_trackio_run, +) +from zenml.integrations.trackio.experiment_trackers.trackio_experiment_tracker import ( + TrackioExperimentTracker, +) + +__all__ = ["TrackioExperimentTracker", "get_trackio_run"] diff --git a/src/zenml/integrations/trackio/experiment_trackers/run_state.py b/src/zenml/integrations/trackio/experiment_trackers/run_state.py new file mode 100644 index 00000000000..a6f2d7514c1 --- /dev/null +++ b/src/zenml/integrations/trackio/experiment_trackers/run_state.py @@ -0,0 +1,57 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Utilities to access the active Trackio run from pipeline steps.""" + +from typing import Any + +from zenml.client import Client + + +def get_trackio_run() -> Any: + """Return the Trackio run associated with the currently executing step. + + Returns: + The Trackio run instance created for the active step. 
+ + Raises: + RuntimeError: If the active stack does not contain a Trackio experiment + tracker or the tracker has not yet been initialized. + """ + + from zenml.integrations.trackio.experiment_trackers import ( + TrackioExperimentTracker, + ) + + experiment_tracker = Client().active_stack.experiment_tracker + + if experiment_tracker is None: + raise RuntimeError( + "Unable to access a Trackio run: No experiment tracker is " + "configured for the active stack." + ) + + if not isinstance(experiment_tracker, TrackioExperimentTracker): + raise RuntimeError( + "Unable to access a Trackio run: The experiment tracker in the " + "active stack is not the Trackio integration." + ) + + try: + return experiment_tracker.active_run + except RuntimeError as exc: + raise RuntimeError( + "Unable to access a Trackio run: The tracker was not initialized. " + "Ensure that the step enabling the experiment tracker is running " + "when calling this helper." + ) from exc diff --git a/src/zenml/integrations/trackio/experiment_trackers/trackio_experiment_tracker.py b/src/zenml/integrations/trackio/experiment_trackers/trackio_experiment_tracker.py new file mode 100644 index 00000000000..49789fe5cbc --- /dev/null +++ b/src/zenml/integrations/trackio/experiment_trackers/trackio_experiment_tracker.py @@ -0,0 +1,496 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""Implementation of the Trackio experiment tracker.""" + +from __future__ import annotations + +import inspect +import os +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Optional, + Tuple, + Type, + cast, +) +from urllib.parse import urljoin + +from zenml.constants import METADATA_EXPERIMENT_TRACKER_URL +from zenml.experiment_trackers.base_experiment_tracker import ( + BaseExperimentTracker, +) +from zenml.integrations.trackio.flavors import ( + TrackioExperimentTrackerConfig, + TrackioExperimentTrackerSettings, +) +from zenml.logger import get_logger +from zenml.metadata.metadata_types import Uri + +if TYPE_CHECKING: + from zenml.config.base_settings import BaseSettings + from zenml.config.step_run_info import StepRunInfo + from zenml.metadata.metadata_types import MetadataType + +logger = get_logger(__name__) + + +class TrackioExperimentTracker(BaseExperimentTracker): + """Experiment tracker implementation that delegates to Trackio.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize the Trackio experiment tracker.""" + + super().__init__(*args, **kwargs) + self._active_run: Optional[Any] = None + self._run_initialized = False + self._run_url: Optional[str] = None + self._run_identifier: Optional[str] = None + + @property + def config(self) -> TrackioExperimentTrackerConfig: + """Return the configuration of the experiment tracker.""" + + return cast(TrackioExperimentTrackerConfig, self._config) + + @property + def settings_class(self) -> Optional[Type["BaseSettings"]]: + """Return the settings class for the experiment tracker.""" + + return TrackioExperimentTrackerSettings + + @property + def active_run(self) -> Any: + """Return the active Trackio run. + + Returns: + The Trackio run instance. + + Raises: + RuntimeError: If the run is not initialized. + """ + + if not self._run_initialized or self._active_run is None: + raise RuntimeError( + "Trackio run is not initialized. 
Make sure to access the run " + "from within a step that uses the Trackio experiment tracker." + ) + return self._active_run + + def prepare_step_run(self, info: "StepRunInfo") -> None: + """Create and configure a Trackio run for the upcoming step.""" + + settings = cast( + TrackioExperimentTrackerSettings, self.get_settings(info) + ) + run_name = ( + settings.run_name or f"{info.run_name}_{info.pipeline_step_name}" + ) + tags = self._build_tags(settings.tags, info) + metadata = settings.metadata + + run, run_url, identifier = self._start_trackio_run( + run_name=run_name, + tags=tags, + metadata=metadata, + ) + + self._active_run = run + self._run_initialized = True + self._run_url = run_url + self._run_identifier = identifier or run_name + + def get_step_run_metadata( + self, info: "StepRunInfo" + ) -> Dict[str, "MetadataType"]: + """Return metadata generated during the step run.""" + + metadata: Dict[str, "MetadataType"] = {} + + run_url = self._run_url or self._build_fallback_run_url() + if run_url: + metadata[METADATA_EXPERIMENT_TRACKER_URL] = Uri(run_url) + + if self._run_identifier: + metadata["trackio_run_id"] = self._run_identifier + + return metadata + + def cleanup_step_run(self, info: "StepRunInfo", step_failed: bool) -> None: + """Finalize the Trackio run associated with the step.""" + + try: + self._finish_run(step_failed) + finally: + self._active_run = None + self._run_initialized = False + self._run_url = None + self._run_identifier = None + + # --------------------------------------------------------------------- + # Helper utilities + # --------------------------------------------------------------------- + def _build_tags( + self, configured_tags: Iterable[str], info: "StepRunInfo" + ) -> List[str]: + """Merge configured tags with contextual information.""" + + tag_set = set(configured_tags) + tag_set.update(filter(None, [info.run_name, info.pipeline.name])) + tag_set.add(info.pipeline_step_name) + return sorted(tag_set) + + def 
_start_trackio_run( + self, + run_name: str, + tags: List[str], + metadata: Dict[str, Any], + ) -> Tuple[Any, Optional[str], Optional[str]]: + """Start a Trackio run via the available Python API.""" + + trackio_module = self._import_trackio() + self._authenticate(trackio_module) + + init_kwargs = self._build_init_kwargs(run_name, tags, metadata) + + run = self._try_initialize_run(trackio_module, init_kwargs) + if run is None: + raise RuntimeError( + "Unable to create a Trackio run. Please ensure that the " + "Trackio package is up to date and compatible with the " + "ZenML Trackio integration." + ) + + run_url = self._extract_run_url(run) + identifier = self._extract_run_identifier(run) + + self._attach_tags(run, tags) + self._log_initial_metadata(run, metadata) + + return run, run_url, identifier + + def _import_trackio(self) -> Any: + """Import the Trackio Python package.""" + + try: + import trackio # type: ignore[import-not-found] + except ImportError as exc: + raise RuntimeError( + "The Trackio integration requires the `trackio` package to be " + "installed. Run `zenml integration install trackio` to " + "install the dependency." 
+ ) from exc + + return trackio + + def _authenticate(self, trackio_module: Any) -> None: + """Authenticate with Trackio if an API key has been provided.""" + + api_key = self.config.api_key + if not api_key: + return + + for attr in ("login", "authenticate", "set_token"): + maybe_callable = getattr(trackio_module, attr, None) + if callable(maybe_callable): + success, _ = self._call_with_supported_kwargs( + maybe_callable, + { + "api_key": api_key, + "token": api_key, + "key": api_key, + }, + ) + if success: + return + + os.environ.setdefault("TRACKIO_API_KEY", api_key) + + def _build_init_kwargs( + self, + run_name: str, + tags: List[str], + metadata: Dict[str, Any], + ) -> Dict[str, Any]: + """Collect potential keyword arguments for Trackio run creation.""" + + kwargs: Dict[str, Any] = { + "workspace": self.config.workspace, + "organization": self.config.workspace, + "project": self.config.project, + "project_name": self.config.project, + "base_url": self.config.base_url, + "api_url": self.config.base_url, + "url": self.config.base_url, + "run_name": run_name, + "name": run_name, + "display_name": run_name, + "title": run_name, + "tags": tags, + "tag_names": tags, + "labels": tags, + } + + if metadata: + kwargs.update( + { + "metadata": metadata, + "properties": metadata, + "params": metadata, + } + ) + + return {key: value for key, value in kwargs.items() if value} + + def _try_initialize_run( + self, trackio_module: Any, kwargs: Dict[str, Any] + ) -> Any: + """Try several possible Trackio APIs to initialize a run.""" + + for candidate in ("init", "start_run", "start", "create_run"): + maybe_callable = getattr(trackio_module, candidate, None) + if callable(maybe_callable): + success, result = self._call_with_supported_kwargs( + maybe_callable, kwargs + ) + if success: + if result is not None: + return result + current = self._fetch_current_run(trackio_module) + if current is not None: + return current + + run_class = getattr(trackio_module, "Run", None) + if 
run_class is not None: + try: + success, run_instance = self._call_with_supported_kwargs( + run_class, kwargs + ) + except Exception: # noqa: BLE001 + logger.debug( + "Failed to instantiate Trackio Run via constructor.", + exc_info=True, + ) + else: + if success and run_instance is not None: + starter = getattr(run_instance, "start", None) + if callable(starter): + started_success, started = ( + self._call_with_supported_kwargs( + starter, + { + "run_name": kwargs.get("run_name"), + "name": kwargs.get("name"), + }, + ) + ) + if started_success: + if started is not None: + return started + current = self._fetch_current_run(trackio_module) + if current is not None: + return current + return run_instance + + return None + + def _fetch_current_run(self, trackio_module: Any) -> Any: + """Try to retrieve the currently active Trackio run from the module.""" + + for attribute in ("get_current_run", "current_run", "active_run"): + value = getattr(trackio_module, attribute, None) + if callable(value): + success, run = self._call_with_supported_kwargs(value, {}) + if success and run is not None: + return run + elif value is not None: + return value + + return None + + def _call_with_supported_kwargs( + self, callable_obj: Any, potential_kwargs: Dict[str, Any] + ) -> Tuple[bool, Any]: + """Invoke a callable with the subset of supported keyword arguments.""" + + try: + signature = inspect.signature(callable_obj) + except (TypeError, ValueError): + signature = None + + filtered_kwargs: Dict[str, Any] = {} + if signature is None: + filtered_kwargs = potential_kwargs + else: + for name, value in potential_kwargs.items(): + if name in signature.parameters: + filtered_kwargs[name] = value + + try: + result = callable_obj(**filtered_kwargs) + except Exception as exc: # noqa: BLE001 + logger.debug( + "Failed to call %s with kwargs %s: %s", + callable_obj, + filtered_kwargs, + exc, + ) + return False, None + + return True, result + + def _extract_run_url(self, run: Any) -> 
Optional[str]: + """Try to extract a URL pointing to the Trackio run.""" + + for attribute in ("url", "ui_url", "dashboard_url", "app_url"): + value = getattr(run, attribute, None) + if isinstance(value, str) and value: + return value + + get_url = getattr(run, "get_url", None) + if callable(get_url): + try: + value = get_url() + except TypeError: + value = None + if isinstance(value, str) and value: + return value + + return None + + def _extract_run_identifier(self, run: Any) -> Optional[str]: + """Extract a stable identifier for the run.""" + + for attribute in ("id", "run_id", "name", "identifier", "slug"): + value = getattr(run, attribute, None) + if isinstance(value, str) and value: + return value + + return None + + def _attach_tags(self, run: Any, tags: List[str]) -> None: + """Attach tags to the Trackio run if the API supports it.""" + + if not tags: + return + + for attribute in ("set_tags", "add_tags", "with_tags"): + maybe_callable = getattr(run, attribute, None) + if callable(maybe_callable): + success, _ = self._call_with_supported_kwargs( + maybe_callable, {"tags": tags, "tag_names": tags} + ) + if success: + return + + add_tag = getattr(run, "add_tag", None) + if callable(add_tag): + for tag in tags: + self._call_with_supported_kwargs(add_tag, {"tag": tag}) + + def _log_initial_metadata( + self, run: Any, metadata: Dict[str, Any] + ) -> None: + """Log static metadata right after the run is created.""" + + if not metadata: + return + + if hasattr(run, "__setitem__"): + try: + for key, value in metadata.items(): + run[key] = value + return + except Exception: # noqa: BLE001 + logger.debug( + "Falling back to explicit logging for metadata.", + exc_info=True, + ) + + for attribute in ( + "log", + "log_metrics", + "set_params", + "update_metadata", + ): + maybe_callable = getattr(run, attribute, None) + if callable(maybe_callable): + success, _ = self._call_with_supported_kwargs( + maybe_callable, + { + "metrics": metadata, + "params": metadata, + 
"metadata": metadata, + }, + ) + if success: + return + + logger.debug("Unable to automatically log metadata to Trackio run.") + + def _build_fallback_run_url(self) -> Optional[str]: + """Build a fallback Trackio run URL based on the configuration.""" + + if not self.config.base_url or not self._run_identifier: + return None + + path_parts = [ + part + for part in ( + self.config.workspace, + self.config.project, + "runs", + self._run_identifier, + ) + if part + ] + + if not path_parts: + return None + + base = self.config.base_url.rstrip("/") + "/" + return urljoin(base, "/".join(path_parts)) + + def _finish_run(self, step_failed: bool) -> None: + """Finish the active Trackio run.""" + + if self._active_run is None: + return + + success = not step_failed + status_kwargs = { + "status": "failed" if step_failed else "success", + "state": "failed" if step_failed else "completed", + "success": success, + "exit_code": 1 if step_failed else 0, + } + + for attribute in ("finish", "end", "stop", "close", "complete"): + maybe_callable = getattr(self._active_run, attribute, None) + if callable(maybe_callable): + success, _ = self._call_with_supported_kwargs( + maybe_callable, status_kwargs + ) + if success: + break + else: + context_exit = getattr(self._active_run, "__exit__", None) + if callable(context_exit): + try: + context_exit(None, None, None) + except TypeError: + context_exit() diff --git a/src/zenml/integrations/trackio/flavors/__init__.py b/src/zenml/integrations/trackio/flavors/__init__.py new file mode 100644 index 00000000000..3ad7a5c2d29 --- /dev/null +++ b/src/zenml/integrations/trackio/flavors/__init__.py @@ -0,0 +1,13 @@ +"""Trackio experiment tracker flavors.""" + +from zenml.integrations.trackio.flavors.trackio_experiment_tracker_flavor import ( + TrackioExperimentTrackerConfig, + TrackioExperimentTrackerFlavor, + TrackioExperimentTrackerSettings, +) + +__all__ = [ + "TrackioExperimentTrackerConfig", + "TrackioExperimentTrackerFlavor", + 
"TrackioExperimentTrackerSettings", +] diff --git a/src/zenml/integrations/trackio/flavors/trackio_experiment_tracker_flavor.py b/src/zenml/integrations/trackio/flavors/trackio_experiment_tracker_flavor.py new file mode 100644 index 00000000000..d3b8e9ee39c --- /dev/null +++ b/src/zenml/integrations/trackio/flavors/trackio_experiment_tracker_flavor.py @@ -0,0 +1,133 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Trackio experiment tracker flavor.""" + +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type + +from pydantic import Field + +from zenml.config.base_settings import BaseSettings +from zenml.experiment_trackers.base_experiment_tracker import ( + BaseExperimentTrackerConfig, + BaseExperimentTrackerFlavor, +) +from zenml.integrations.trackio import TRACKIO_EXPERIMENT_TRACKER_FLAVOR +from zenml.utils.secret_utils import SecretField + +if TYPE_CHECKING: + from zenml.integrations.trackio.experiment_trackers import ( + TrackioExperimentTracker, + ) + +__all__ = [ + "TrackioExperimentTrackerConfig", + "TrackioExperimentTrackerFlavor", + "TrackioExperimentTrackerSettings", +] + + +class TrackioExperimentTrackerConfig(BaseExperimentTrackerConfig): + """Configuration for the Trackio experiment tracker.""" + + workspace: Optional[str] = Field( + default=None, + description=( + "Optional workspace or organization to use for the Trackio run." 
+ ), + ) + project: Optional[str] = Field( + default=None, + description="Trackio project slug where new runs should be created.", + ) + api_key: Optional[str] = SecretField( + default=None, + description="API key used to authenticate with the Trackio service.", + ) + base_url: Optional[str] = Field( + default=None, + description=( + "Override the default Trackio service base URL (e.g. for a " + "self-hosted deployment)." + ), + ) + + +class TrackioExperimentTrackerSettings(BaseSettings): + """Runtime settings for the Trackio experiment tracker.""" + + run_name: Optional[str] = Field( + default=None, + description="Explicit name to assign to the Trackio run.", + ) + tags: List[str] = Field( + default_factory=list, + description="List of tags to attach to the Trackio run.", + ) + metadata: Dict[str, Any] = Field( + default_factory=dict, + description=( + "Static metadata to log to Trackio as soon as the run is created." + ), + ) + + +class TrackioExperimentTrackerFlavor(BaseExperimentTrackerFlavor): + """Flavor for the Trackio experiment tracker component.""" + + @property + def name(self) -> str: + """Name of the flavor. + + Returns: + The flavor name. 
+ """ + + return TRACKIO_EXPERIMENT_TRACKER_FLAVOR + + @property + def docs_url(self) -> Optional[str]: + """URL pointing to the user documentation for this flavor.""" + + return self.generate_default_docs_url() + + @property + def sdk_docs_url(self) -> Optional[str]: + """URL pointing to the SDK documentation for this flavor.""" + + return self.generate_default_sdk_docs_url() + + @property + def logo_url(self) -> str: + """URL of an image that represents the flavor in the dashboard.""" + + return ( + "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/" + "experiment_tracker/trackio.png" + ) + + @property + def config_class(self) -> Type[TrackioExperimentTrackerConfig]: + """Return the configuration class associated with this flavor.""" + + return TrackioExperimentTrackerConfig + + @property + def implementation_class(self) -> Type["TrackioExperimentTracker"]: + """Implementation class for the Trackio experiment tracker flavor.""" + + from zenml.integrations.trackio.experiment_trackers import ( + TrackioExperimentTracker, + ) + + return TrackioExperimentTracker diff --git a/tests/unit/stack/test_trackio_experiment_tracker_flavor.py b/tests/unit/stack/test_trackio_experiment_tracker_flavor.py new file mode 100644 index 00000000000..1c184de28d8 --- /dev/null +++ b/tests/unit/stack/test_trackio_experiment_tracker_flavor.py @@ -0,0 +1,44 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+"""Unit tests for the Trackio experiment tracker flavor.""" + +from zenml.enums import StackComponentType +from zenml.integrations.trackio import TRACKIO_EXPERIMENT_TRACKER_FLAVOR +from zenml.integrations.trackio.experiment_trackers import ( + TrackioExperimentTracker, +) +from zenml.integrations.trackio.flavors import ( + TrackioExperimentTrackerConfig, + TrackioExperimentTrackerFlavor, +) + + +def test_trackio_flavor_metadata() -> None: + """Trackio experiment tracker flavor advertises the expected metadata.""" + + flavor = TrackioExperimentTrackerFlavor() + + assert flavor.name == TRACKIO_EXPERIMENT_TRACKER_FLAVOR + assert flavor.type == StackComponentType.EXPERIMENT_TRACKER + assert flavor.config_class is TrackioExperimentTrackerConfig + assert flavor.implementation_class is TrackioExperimentTracker + + +def test_trackio_config_defaults() -> None: + """Trackio configuration defaults to optional fields.""" + + config = TrackioExperimentTrackerConfig() + + assert config.workspace is None + assert config.project is None + assert config.api_key is None + assert config.base_url is None
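
Note on the kwarg-filtering approach: the tracker's `_call_with_supported_kwargs` helper uses `inspect.signature` to pass only those keyword arguments that a callable declares by name, which is how the integration probes several possible Trackio entry points. The block below is a minimal standalone sketch of that technique, not the integration's code. The names `call_with_supported_kwargs`, `fake_init`, and `takes_var_kw` are hypothetical and exist only for illustration; nothing here calls the real Trackio API.

```python
import inspect
from typing import Any, Callable, Dict, Tuple


def call_with_supported_kwargs(
    func: Callable[..., Any], candidates: Dict[str, Any]
) -> Tuple[bool, Any]:
    """Call `func` with only the candidate kwargs it declares by name."""
    try:
        params = inspect.signature(func).parameters
    except (TypeError, ValueError):
        # Some builtins expose no signature; pass everything through.
        params = None

    if params is None:
        accepted = dict(candidates)
    else:
        accepted = {k: v for k, v in candidates.items() if k in params}

    try:
        return True, func(**accepted)
    except Exception:
        # Any failure is reported as "not successful" instead of raising.
        return False, None


# Hypothetical stand-in for a tracking client's init function.
def fake_init(project: str, name: str = "run") -> str:
    return f"{project}/{name}"


# Hypothetical callable that only declares **kwargs.
def takes_var_kw(**kw: Any) -> Dict[str, Any]:
    return kw


ok, result = call_with_supported_kwargs(
    fake_init,
    {"project": "demo", "project_name": "demo", "name": "eval", "title": "eval"},
)
print(ok, result)  # True demo/eval
```

Two consequences of filtering by declared parameter names are worth keeping in mind when reviewing the helper. A callable that accepts only `**kwargs` receives no arguments at all, because none of the candidate names appear in its signature. And since every exception is converted into a `(False, None)` result, a genuine failure inside the callee (for example a network error) is indistinguishable from an unsupported signature unless debug logging is enabled.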