diff --git a/guidance/models/_model.py b/guidance/models/_model.py index 74f3c8af0..40b35ddc2 100644 --- a/guidance/models/_model.py +++ b/guidance/models/_model.py @@ -21,6 +21,17 @@ ipython_is_imported = True except ImportError: ipython_is_imported = False +try: + import torch + + torch_is_imported = True +except ImportError: + torch_is_imported = False +try: + import marimo + marimo_is_imported = True +except ImportError: + marimo_is_imported = False logger = logging.getLogger(__name__) @@ -357,6 +368,8 @@ def _update_display(self, throttle=True): if ipython_is_imported: clear_output(wait=True) display(HTML(self._html())) + elif marimo_is_imported: + marimo.output.replace(marimo.Html(self._html())) else: pprint(self._state) @@ -907,3 +920,4 @@ def __init__(self, *args, **kwargs): self.prompt = kwargs.pop("prompt", None) self.data = kwargs.pop("data", None) super().__init__(*args, **kwargs) +