Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions dagster_sqlmesh/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@
import typing as t

from dagster import AssetsDefinition, RetryPolicy, multi_asset
from sqlmesh import Context

from dagster_sqlmesh.controller import DagsterSQLMeshController
from dagster_sqlmesh.controller import (
ContextCls,
ContextFactory,
DagsterSQLMeshController,
)
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator

from .config import SQLMeshContextConfig
Expand All @@ -16,6 +21,7 @@ def sqlmesh_assets(
*,
environment: str,
config: SQLMeshContextConfig,
context_factory: ContextFactory[ContextCls] = lambda **kwargs: Context(**kwargs),
name: str | None = None,
dagster_sqlmesh_translator: SQLMeshDagsterTranslator | None = None,
compute_kind: str = "sqlmesh",
Expand All @@ -25,7 +31,7 @@ def sqlmesh_assets(
# For now we don't set this by default
enabled_subsetting: bool = False,
) -> t.Callable[[t.Callable[..., t.Any]], AssetsDefinition]:
controller = DagsterSQLMeshController.setup_with_config(config)
controller = DagsterSQLMeshController.setup_with_config(config=config, context_factory=context_factory)
if not dagster_sqlmesh_translator:
dagster_sqlmesh_translator = SQLMeshDagsterTranslator()
conversion = controller.to_asset_outs(environment, dagster_sqlmesh_translator)
Expand Down
10 changes: 9 additions & 1 deletion dagster_sqlmesh/controller/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
# ruff: noqa: F403 F401
from .base import PlanOptions, RunOptions, SQLMeshController, SQLMeshInstance
from .base import (
DEFAULT_CONTEXT_FACTORY,
ContextCls,
ContextFactory,
PlanOptions,
RunOptions,
SQLMeshController,
SQLMeshInstance,
)
from .dagster import DagsterSQLMeshController
42 changes: 27 additions & 15 deletions dagster_sqlmesh/controller/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from contextlib import contextmanager
from dataclasses import dataclass
from types import MappingProxyType
from typing import TypeVar

from sqlmesh.core.config import CategorizerConfig
from sqlmesh.core.console import set_console
Expand All @@ -27,8 +26,14 @@

logger = logging.getLogger(__name__)

# Type variable bound to ``SQLMeshController`` (forward reference, the class
# is defined further down in this module).
T = t.TypeVar("T", bound="SQLMeshController")
# The concrete sqlmesh ``Context`` (sub)class a controller/instance is
# parameterized with.
ContextCls = t.TypeVar("ContextCls", bound=Context)
# Any callable returning a ``ContextCls``. The controller invokes it with
# keyword arguments only (see ``SQLMeshController._create_context``).
ContextFactory = t.Callable[..., ContextCls]

def default_context_factory(**kwargs: t.Any) -> Context:
    """Build a plain sqlmesh ``Context`` from keyword arguments.

    Args:
        **kwargs: Forwarded unchanged to the ``Context`` constructor.

    Returns:
        The newly constructed ``Context``.
    """
    context = Context(**kwargs)
    return context


# Factory used as the default wherever a ``context_factory`` parameter is
# optional.
DEFAULT_CONTEXT_FACTORY: ContextFactory[Context] = default_context_factory

class PlanOptions(t.TypedDict):
start: t.NotRequired[TimeLike]
Expand Down Expand Up @@ -88,7 +93,7 @@ def parse_fqn(self) -> SQLMeshParsedFQN:
return parse_fqn(self.fqn)


class SQLMeshInstance:
class SQLMeshInstance(t.Generic[ContextCls]):
"""
A class that manages sqlmesh operations and context within a specific
environment. This class will run sqlmesh in a separate thread.
Expand All @@ -110,15 +115,15 @@ class SQLMeshInstance:
config: SQLMeshContextConfig
console: EventConsole
logger: logging.Logger
context: Context
context: ContextCls
environment: str

def __init__(
self,
environment: str,
console: EventConsole,
config: SQLMeshContextConfig,
context: Context,
context: ContextCls,
logger: logging.Logger,
):
self.environment = environment
Expand Down Expand Up @@ -167,7 +172,7 @@ def plan(
def run_sqlmesh_thread(
logger: logging.Logger,
context: Context,
controller: SQLMeshController,
controller: SQLMeshController[ContextCls],
environment: str,
plan_options: PlanOptions,
default_catalog: str,
Expand Down Expand Up @@ -251,7 +256,7 @@ def run(self, **run_options: t.Unpack[RunOptions]) -> t.Iterator[ConsoleEvent]:
def run_sqlmesh_thread(
logger: logging.Logger,
context: Context,
controller: SQLMeshController,
controller: SQLMeshController[ContextCls],
environment: str,
run_options: RunOptions,
) -> None:
Expand Down Expand Up @@ -362,8 +367,7 @@ def non_external_models_dag(self) -> t.Iterable[tuple[Model, set[str]]]:
continue
yield (model, deps)


class SQLMeshController:
class SQLMeshController(t.Generic[ContextCls]):
"""Allows control of sqlmesh via a python interface. It is not suggested to
use the constructor of this class directly, but instead use the provided
`setup` or `setup_with_config` class methods.
Expand Down Expand Up @@ -403,37 +407,45 @@ class SQLMeshController:
def setup(
    cls,
    path: str,
    *,
    context_factory: ContextFactory[ContextCls] = DEFAULT_CONTEXT_FACTORY,
    gateway: str = "local",
    log_override: logging.Logger | None = None,
) -> t.Self:
    """Create a controller from a path and a gateway name.

    Convenience wrapper that wraps ``path`` and ``gateway`` in a
    ``SQLMeshContextConfig`` and delegates to ``setup_with_config``.

    Args:
        path: Passed as ``SQLMeshContextConfig.path``.
        context_factory: Callable used to build the sqlmesh context.
            Defaults to ``DEFAULT_CONTEXT_FACTORY``, matching
            ``setup_with_config``.
        gateway: Passed as ``SQLMeshContextConfig.gateway``.
        log_override: Optional logger forwarded to ``setup_with_config``.

    Returns:
        A new instance of ``cls``.
    """
    return cls.setup_with_config(
        config=SQLMeshContextConfig(path=path, gateway=gateway),
        log_override=log_override,
        context_factory=context_factory,
    )

@classmethod
def setup_with_config(
    cls,
    *,
    config: SQLMeshContextConfig,
    context_factory: ContextFactory[ContextCls] = DEFAULT_CONTEXT_FACTORY,
    log_override: logging.Logger | None = None,
) -> t.Self:
    """Create a controller from an existing ``SQLMeshContextConfig``.

    A fresh ``EventConsole`` is created for the controller and handed to
    the constructor together with the other arguments.

    Args:
        config: The sqlmesh context configuration to use.
        context_factory: Callable used to build the sqlmesh context.
            Defaults to ``DEFAULT_CONTEXT_FACTORY``.
        log_override: Optional logger; passed to both the console and the
            controller.

    Returns:
        A new instance of ``cls``.
    """
    console = EventConsole(log_override=log_override)  # type: ignore
    return cls(
        console=console,
        config=config,
        log_override=log_override,
        context_factory=context_factory,
    )

def __init__(
    self,
    config: SQLMeshContextConfig,
    console: EventConsole,
    context_factory: ContextFactory[ContextCls],
    log_override: logging.Logger | None = None,
) -> None:
    """Store the collaborators this controller works with.

    Args:
        config: The sqlmesh context configuration.
        console: The event console associated with this controller.
        context_factory: Callable kept for later creation of the sqlmesh
            context.
        log_override: Optional logger; the module-level ``logger`` is used
            when this is not given (or is falsy).
    """
    self.config = config
    self.console = console
    # Fall back to the module logger when no override was supplied.
    self.logger = log_override if log_override else logger
    self._context_factory = context_factory
    # No context has been opened yet.
    self._context_open = False

def set_logger(self, logger: logging.Logger) -> None:
Expand All @@ -446,20 +458,20 @@ def add_event_handler(self, handler: ConsoleEventHandler) -> str:
def remove_event_handler(self, handler_id: str) -> None:
    """Remove a previously registered event handler from the console.

    Args:
        handler_id: Identifier of the handler, passed straight to
            ``self.console.remove_handler``.
    """
    console = self.console
    console.remove_handler(handler_id)

def _create_context(self) -> ContextCls:
    """Build a sqlmesh context through the configured context factory.

    The factory receives ``paths`` and ``gateway`` from ``self.config``,
    plus ``config`` when ``self.config.sqlmesh_config`` is truthy.

    Returns:
        Whatever ``self._context_factory`` returns for those options.
    """
    options: dict[str, t.Any] = {
        "paths": self.config.path,
        "gateway": self.config.gateway,
    }
    if self.config.sqlmesh_config:
        options["config"] = self.config.sqlmesh_config
    # The console is installed via sqlmesh's ``set_console`` before the
    # context is constructed.
    set_console(self.console)
    return self._context_factory(**options)

@contextmanager
def instance(
self, environment: str, component: str = "unknown"
) -> t.Iterator[SQLMeshInstance]:
) -> t.Iterator[SQLMeshInstance[ContextCls]]:
self.logger.info(
f"Opening sqlmesh instance for env={environment} component={component}"
)
Expand Down
4 changes: 2 additions & 2 deletions dagster_sqlmesh/controller/dagster.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from ..translator import SQLMeshDagsterTranslator
from ..types import SQLMeshModelDep, SQLMeshMultiAssetOptions
from ..utils import sqlmesh_model_name_to_key
from .base import SQLMeshController
from .base import ContextCls, SQLMeshController

logger = logging.getLogger(__name__)


class DagsterSQLMeshController(SQLMeshController):
class DagsterSQLMeshController(SQLMeshController[ContextCls]):
"""An extension of the sqlmesh controller specifically for dagster use"""

def to_asset_outs(
Expand Down
24 changes: 19 additions & 5 deletions dagster_sqlmesh/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
from sqlmesh.utils.dag import DAG
from sqlmesh.utils.date import TimeLike

from dagster_sqlmesh.controller.base import (
DEFAULT_CONTEXT_FACTORY,
ContextCls,
ContextFactory,
)

from . import console
from .config import SQLMeshContextConfig
from .controller import PlanOptions, RunOptions
Expand Down Expand Up @@ -169,7 +175,9 @@ def report_event(self, event: console.ConsoleEvent) -> None:
case console.StopPlanEvaluation:
log_context.info("Plan evaluation completed")
case console.StartEvaluationProgress(
batched_intervals=batches, environment_naming_info=environment_naming_info, default_catalog=default_catalog
batched_intervals=batches,
environment_naming_info=environment_naming_info,
default_catalog=default_catalog,
):
self.update_stage("run")
log_context.info(
Expand Down Expand Up @@ -263,6 +271,7 @@ def run(
self,
context: AssetExecutionContext,
*,
context_factory: ContextFactory[ContextCls] = DEFAULT_CONTEXT_FACTORY,
environment: str = "dev",
start: TimeLike | None = None,
end: TimeLike | None = None,
Expand All @@ -277,7 +286,7 @@ def run(

logger = context.log

controller = self.get_controller(logger)
controller = self.get_controller(context_factory, logger)

with controller.instance(environment) as mesh:
dag = mesh.models_dag()
Expand Down Expand Up @@ -322,13 +331,18 @@ def run(
plan_options=plan_options,
run_options=run_options,
):
logger.debug(f"sqlmesh event: {event}")
event_handler.process_events(event)

yield from event_handler.notify_success(mesh.context)

def get_controller(
    self,
    context_factory: ContextFactory[ContextCls] = DEFAULT_CONTEXT_FACTORY,
    log_override: logging.Logger | None = None,
) -> DagsterSQLMeshController[ContextCls]:
    """Create a ``DagsterSQLMeshController`` for this resource's config.

    Args:
        context_factory: Callable used to build the sqlmesh context.
            Defaults to ``DEFAULT_CONTEXT_FACTORY``, matching ``run``.
        log_override: Optional logger forwarded to the controller.

    Returns:
        The controller returned by
        ``DagsterSQLMeshController.setup_with_config``.
    """
    return DagsterSQLMeshController.setup_with_config(
        config=self.config,
        context_factory=context_factory,
        log_override=log_override,
    )
5 changes: 3 additions & 2 deletions dagster_sqlmesh/testing/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import duckdb
import polars
from sqlmesh import Context
from sqlmesh.utils.date import TimeLike

from dagster_sqlmesh.config import SQLMeshContextConfig
Expand All @@ -21,9 +22,9 @@ class SQLMeshTestContext:
db_path: str
context_config: SQLMeshContextConfig

def create_controller(self) -> DagsterSQLMeshController[Context]:
    """Create a ``DagsterSQLMeshController`` from ``self.context_config``.

    No ``context_factory`` is passed, so ``setup_with_config`` falls back
    to its own default.
    """
    return DagsterSQLMeshController.setup_with_config(
        config=self.context_config,
    )

def query(self, *args: t.Any, **kwargs: t.Any) -> list[t.Any]:
Expand Down