diff --git a/dagster_sqlmesh/asset.py b/dagster_sqlmesh/asset.py index 8f614bc..992d99f 100644 --- a/dagster_sqlmesh/asset.py +++ b/dagster_sqlmesh/asset.py @@ -2,8 +2,13 @@ import typing as t from dagster import AssetsDefinition, RetryPolicy, multi_asset +from sqlmesh import Context -from dagster_sqlmesh.controller import DagsterSQLMeshController +from dagster_sqlmesh.controller import ( + ContextCls, + ContextFactory, + DagsterSQLMeshController, +) from dagster_sqlmesh.translator import SQLMeshDagsterTranslator from .config import SQLMeshContextConfig @@ -16,6 +21,7 @@ def sqlmesh_assets( *, environment: str, config: SQLMeshContextConfig, + context_factory: ContextFactory[ContextCls] = lambda **kwargs: Context(**kwargs), name: str | None = None, dagster_sqlmesh_translator: SQLMeshDagsterTranslator | None = None, compute_kind: str = "sqlmesh", @@ -25,7 +31,7 @@ def sqlmesh_assets( # For now we don't set this by default enabled_subsetting: bool = False, ) -> t.Callable[[t.Callable[..., t.Any]], AssetsDefinition]: - controller = DagsterSQLMeshController.setup_with_config(config) + controller = DagsterSQLMeshController.setup_with_config(config=config, context_factory=context_factory) if not dagster_sqlmesh_translator: dagster_sqlmesh_translator = SQLMeshDagsterTranslator() conversion = controller.to_asset_outs(environment, dagster_sqlmesh_translator) diff --git a/dagster_sqlmesh/controller/__init__.py b/dagster_sqlmesh/controller/__init__.py index 150e8f0..4b66be4 100644 --- a/dagster_sqlmesh/controller/__init__.py +++ b/dagster_sqlmesh/controller/__init__.py @@ -1,3 +1,11 @@ # ruff: noqa: F403 F401 -from .base import PlanOptions, RunOptions, SQLMeshController, SQLMeshInstance +from .base import ( + DEFAULT_CONTEXT_FACTORY, + ContextCls, + ContextFactory, + PlanOptions, + RunOptions, + SQLMeshController, + SQLMeshInstance, +) from .dagster import DagsterSQLMeshController diff --git a/dagster_sqlmesh/controller/base.py b/dagster_sqlmesh/controller/base.py index 25db81a..2a36554 100644 --- a/dagster_sqlmesh/controller/base.py +++ b/dagster_sqlmesh/controller/base.py @@ -4,7 +4,6 @@ from contextlib import contextmanager from dataclasses import dataclass from types import MappingProxyType -from typing import TypeVar from sqlmesh.core.config import CategorizerConfig from sqlmesh.core.console import set_console @@ -27,8 +26,14 @@ logger = logging.getLogger(__name__) -T = TypeVar("T", bound="SQLMeshController") +T = t.TypeVar("T", bound="SQLMeshController") +ContextCls = t.TypeVar("ContextCls", bound=Context) +ContextFactory = t.Callable[..., ContextCls] +def default_context_factory(**kwargs: t.Any) -> Context: + return Context(**kwargs) + +DEFAULT_CONTEXT_FACTORY: ContextFactory[Context] = default_context_factory class PlanOptions(t.TypedDict): start: t.NotRequired[TimeLike] @@ -88,7 +93,7 @@ def parse_fqn(self) -> SQLMeshParsedFQN: return parse_fqn(self.fqn) -class SQLMeshInstance: +class SQLMeshInstance(t.Generic[ContextCls]): """ A class that manages sqlmesh operations and context within a specific environment. This class will run sqlmesh in a separate thread. @@ -110,7 +115,7 @@ class SQLMeshInstance: config: SQLMeshContextConfig console: EventConsole logger: logging.Logger - context: Context + context: ContextCls environment: str def __init__( @@ -118,7 +123,7 @@ def __init__( environment: str, console: EventConsole, config: SQLMeshContextConfig, - context: Context, + context: ContextCls, logger: logging.Logger, ): self.environment = environment @@ -167,7 +172,7 @@ def plan( def run_sqlmesh_thread( logger: logging.Logger, context: Context, - controller: SQLMeshController, + controller: SQLMeshController[ContextCls], environment: str, plan_options: PlanOptions, default_catalog: str, @@ -251,7 +256,7 @@ def run(self, **run_options: t.Unpack[RunOptions]) -> t.Iterator[ConsoleEvent]: def run_sqlmesh_thread( logger: logging.Logger, context: Context, - controller: SQLMeshController, + controller: SQLMeshController[ContextCls], environment: str, run_options: RunOptions, ) -> None: @@ -362,8 +367,7 @@ def non_external_models_dag(self) -> t.Iterable[tuple[Model, set[str]]]: continue yield (model, deps) - -class SQLMeshController: +class SQLMeshController(t.Generic[ContextCls]): """Allows control of sqlmesh via a python interface. It is not suggested to use the constructor of this class directly, but instead use the provided `setup` or `setup_with_config` class methods. @@ -403,25 +407,31 @@ class SQLMeshController: def setup( cls, path: str, + *, + context_factory: ContextFactory[ContextCls], gateway: str = "local", log_override: logging.Logger | None = None, - ) -> "SQLMeshController": + ) -> t.Self: return cls.setup_with_config( config=SQLMeshContextConfig(path=path, gateway=gateway), log_override=log_override, + context_factory=context_factory, ) @classmethod def setup_with_config( - cls: type[T], + cls, + *, config: SQLMeshContextConfig, + context_factory: ContextFactory[ContextCls] = DEFAULT_CONTEXT_FACTORY, log_override: logging.Logger | None = None, - ) -> T: + ) -> t.Self: console = EventConsole(log_override=log_override) # type: ignore controller = cls( console=console, config=config, log_override=log_override, + context_factory=context_factory, ) return controller @@ -429,11 +439,13 @@ def __init__( self, config: SQLMeshContextConfig, console: EventConsole, + context_factory: ContextFactory[ContextCls], log_override: logging.Logger | None = None, ) -> None: self.config = config self.console = console self.logger = log_override or logger + self._context_factory = context_factory self._context_open = False def set_logger(self, logger: logging.Logger) -> None: @@ -446,7 +458,7 @@ def add_event_handler(self, handler: ConsoleEventHandler) -> str: def remove_event_handler(self, handler_id: str) -> None: self.console.remove_handler(handler_id) - def _create_context(self) -> Context: + def _create_context(self) -> ContextCls: options: dict[str, t.Any] = dict( paths=self.config.path, gateway=self.config.gateway, @@ -454,12 +466,12 @@ def _create_context(self) -> Context: if self.config.sqlmesh_config: options["config"] = self.config.sqlmesh_config set_console(self.console) - return Context(**options) + return self._context_factory(**options) @contextmanager def instance( self, environment: str, component: str = "unknown" - ) -> t.Iterator[SQLMeshInstance]: + ) -> t.Iterator[SQLMeshInstance[ContextCls]]: self.logger.info( f"Opening sqlmesh instance for env={environment} component={component}" ) diff --git a/dagster_sqlmesh/controller/dagster.py b/dagster_sqlmesh/controller/dagster.py index 87cd3b0..6978be8 100644 --- a/dagster_sqlmesh/controller/dagster.py +++ b/dagster_sqlmesh/controller/dagster.py @@ -7,12 +7,12 @@ from ..translator import SQLMeshDagsterTranslator from ..types import SQLMeshModelDep, SQLMeshMultiAssetOptions from ..utils import sqlmesh_model_name_to_key -from .base import SQLMeshController +from .base import ContextCls, SQLMeshController logger = logging.getLogger(__name__) -class DagsterSQLMeshController(SQLMeshController): +class DagsterSQLMeshController(SQLMeshController[ContextCls]): """An extension of the sqlmesh controller specifically for dagster use""" def to_asset_outs( diff --git a/dagster_sqlmesh/resource.py b/dagster_sqlmesh/resource.py index f192849..a5223ee 100644 --- a/dagster_sqlmesh/resource.py +++ b/dagster_sqlmesh/resource.py @@ -12,6 +12,12 @@ from sqlmesh.utils.dag import DAG from sqlmesh.utils.date import TimeLike +from dagster_sqlmesh.controller.base import ( + DEFAULT_CONTEXT_FACTORY, + ContextCls, + ContextFactory, +) + from . import console from .config import SQLMeshContextConfig from .controller import PlanOptions, RunOptions @@ -169,7 +175,9 @@ def report_event(self, event: console.ConsoleEvent) -> None: case console.StopPlanEvaluation: log_context.info("Plan evaluation completed") case console.StartEvaluationProgress( - batched_intervals=batches, environment_naming_info=environment_naming_info, default_catalog=default_catalog + batched_intervals=batches, + environment_naming_info=environment_naming_info, + default_catalog=default_catalog, ): self.update_stage("run") log_context.info( @@ -263,6 +271,7 @@ def run( self, context: AssetExecutionContext, *, + context_factory: ContextFactory[ContextCls] = DEFAULT_CONTEXT_FACTORY, environment: str = "dev", start: TimeLike | None = None, end: TimeLike | None = None, @@ -277,7 +286,7 @@ def run( logger = context.log - controller = self.get_controller(logger) + controller = self.get_controller(context_factory, logger) with controller.instance(environment) as mesh: dag = mesh.models_dag() @@ -322,13 +331,18 @@ def run( plan_options=plan_options, run_options=run_options, ): + logger.debug(f"sqlmesh event: {event}") event_handler.process_events(event) yield from event_handler.notify_success(mesh.context) def get_controller( - self, log_override: logging.Logger | None = None - ) -> DagsterSQLMeshController: + self, + context_factory: ContextFactory[ContextCls], + log_override: logging.Logger | None = None, + ) -> DagsterSQLMeshController[ContextCls]: return DagsterSQLMeshController.setup_with_config( - self.config, log_override=log_override + config=self.config, + context_factory=context_factory, + log_override=log_override, ) diff --git a/dagster_sqlmesh/testing/context.py b/dagster_sqlmesh/testing/context.py index 5ea6363..9f42ac9 100644 --- a/dagster_sqlmesh/testing/context.py +++ b/dagster_sqlmesh/testing/context.py @@ -4,6 +4,7 @@ import duckdb import polars +from sqlmesh import Context from sqlmesh.utils.date import TimeLike from dagster_sqlmesh.config import SQLMeshContextConfig @@ -21,9 +22,9 @@ class SQLMeshTestContext: db_path: str context_config: SQLMeshContextConfig - def create_controller(self): + def create_controller(self) -> DagsterSQLMeshController[Context]: return DagsterSQLMeshController.setup_with_config( - self.context_config, + config=self.context_config, ) def query(self, *args: t.Any, **kwargs: t.Any) -> list[t.Any]: