diff --git a/dagster_sqlmesh/console.py b/dagster_sqlmesh/console.py
index e334ca1..e1ca596 100644
--- a/dagster_sqlmesh/console.py
+++ b/dagster_sqlmesh/console.py
@@ -10,7 +10,7 @@
 from sqlmesh.core.console import Console
 from sqlmesh.core.context_diff import ContextDiff
 from sqlmesh.core.environment import EnvironmentNamingInfo
-from sqlmesh.core.plan import EvaluatablePlan, PlanBuilder
+from sqlmesh.core.plan import EvaluatablePlan, Plan as SQLMeshPlan, PlanBuilder
 from sqlmesh.core.snapshot import Snapshot, SnapshotChangeCategory, SnapshotInfoLike
 from sqlmesh.core.table_diff import RowDiff, SchemaDiff, TableDiff
 from sqlmesh.utils.concurrency import NodeExecutionFailedError
@@ -219,6 +219,10 @@ class PrintEnvironments(BaseConsoleEvent):
 class ShowTableDiffSummary(BaseConsoleEvent):
     table_diff: TableDiff
 
+@dataclass(kw_only=True)
+class PlanBuilt(BaseConsoleEvent):
+    plan: SQLMeshPlan
+
 ConsoleEvent = (
     StartPlanEvaluation
     | StopPlanEvaluation
@@ -263,6 +267,7 @@ class ShowTableDiffSummary(BaseConsoleEvent):
     | ConsoleException
     | PrintEnvironments
     | ShowTableDiffSummary
+    | PlanBuilt
 )
 
 ConsoleEventHandler = t.Callable[[ConsoleEvent], None]
@@ -424,7 +429,22 @@ def add_handler(self, handler: ConsoleEventHandler) -> str:
     def remove_handler(self, handler_id: str) -> None:
         del self._handlers[handler_id]
-
+
+    def plan(self, plan_builder: PlanBuilder, auto_apply: bool, default_catalog: str | None, no_diff: bool = False, no_prompts: bool = False) -> None:
+        """Plan is not a console event. This method triggers building a plan
+        and applying it.
+
+        SQLMesh calls this method to start the plan process (when you call
+        Context#plan).
+
+        This overridden method ignores the options passed in at this time.
+        """
+
+        plan_builder.apply()
+
+    def capture_built_plan(self, plan: SQLMeshPlan) -> None:
+        """Capture the built plan and publish a PlanBuilt event."""
+        self.publish(PlanBuilt(plan=plan))
 
 class EventConsole(IntrospectingConsole):
     """
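For reference, a handler registered via `add_handler` can now observe the new `PlanBuilt` event alongside the existing console events. A minimal sketch (the handler body and the `event_console` variable are hypothetical; `add_handler`, `remove_handler`, and the event types are from this diff):

    from dagster_sqlmesh import console

    def log_built_plans(event: console.ConsoleEvent) -> None:
        # React only to the new PlanBuilt event; ignore everything else
        match event:
            case console.PlanBuilt(plan=plan):
                print(f"plan built: requires_backfill={plan.requires_backfill}")
            case _:
                pass

    # event_console: an already-constructed EventConsole instance
    handler_id = event_console.add_handler(log_built_plans)
    # ... run a plan ...
    event_console.remove_handler(handler_id)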
diff --git a/dagster_sqlmesh/controller/base.py b/dagster_sqlmesh/controller/base.py
index 51e671b..02635a2 100644
--- a/dagster_sqlmesh/controller/base.py
+++ b/dagster_sqlmesh/controller/base.py
@@ -9,7 +9,7 @@
 from sqlmesh.core.console import set_console
 from sqlmesh.core.context import Context
 from sqlmesh.core.model import Model
-from sqlmesh.core.plan import PlanBuilder
+from sqlmesh.core.plan import Plan as SQLMeshPlan, PlanBuilder
 from sqlmesh.utils.dag import DAG
 from sqlmesh.utils.date import TimeLike
 
@@ -19,7 +19,6 @@
     ConsoleEventHandler,
     ConsoleException,
     EventConsole,
-    Plan,
     SnapshotCategorizer,
 )
 from dagster_sqlmesh.events import ConsoleGenerator
@@ -179,16 +178,7 @@ def run_sqlmesh_thread(
 ) -> None:
     logger.debug("dagster-sqlmesh: thread started")
 
-    def auto_execute_plan(event: ConsoleEvent):
-        if isinstance(event, Plan):
-            try:
-                event.plan_builder.apply()
-            except Exception as e:
-                controller.console.exception(e)
-        return None
-
     try:
-        controller.console.add_handler(auto_execute_plan)
         builder = t.cast(
             PlanBuilder,
             context.plan_builder(
@@ -381,6 +371,11 @@ def non_external_models_dag(self) -> t.Iterable[tuple[Model, set[str]]]:
                 continue
             yield (model, deps)
 
+
+class ContextApplyFunction(t.Protocol):
+    def __call__(self, plan: SQLMeshPlan, *args: t.Any, **kwargs: t.Any) -> None:
+        ...
+
 
 class SQLMeshController(t.Generic[ContextCls]):
     """Allows control of sqlmesh via a python interface.
 
     It is not suggested to use the constructor of this class directly, but instead use the provided
@@ -480,7 +475,21 @@ def _create_context(self) -> ContextCls:
         if self.config.sqlmesh_config:
             options["config"] = self.config.sqlmesh_config
         set_console(self.console)
-        return self._context_factory(**options)
+        context = self._context_factory(**options)
+
+        # The context defines an "apply" method that we want to introspect.
+        # To do so, we replace "apply" with a wrapper that publishes a console
+        # event unique to dagster-sqlmesh before delegating to the original
+        # method.
+        def wrap_apply_event(f: ContextApplyFunction) -> ContextApplyFunction:
+            def wrapped_apply(plan: SQLMeshPlan, *args: t.Any, **kwargs: t.Any):
+                self.logger.debug("capturing plan as event")
+                self.console.capture_built_plan(plan)
+                result = f(plan, *args, **kwargs)
+                return result
+            return wrapped_apply
+        context.apply = wrap_apply_event(context.apply)
+        return context
 
     @contextmanager
     def instance(
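The wrapping above is a standard decorate-and-delegate pattern; a standalone sketch of the same shape (the names here are stand-ins, not the library's API):

    import typing as t

    class ApplyFn(t.Protocol):
        def __call__(self, plan: object, *args: t.Any, **kwargs: t.Any) -> None: ...

    def wrap_apply(f: ApplyFn, on_plan: t.Callable[[object], None]) -> ApplyFn:
        def wrapped(plan: object, *args: t.Any, **kwargs: t.Any) -> None:
            on_plan(plan)  # observe the plan before the real apply runs
            return f(plan, *args, **kwargs)  # delegate to the original method
        return wrapped

Because the wrapper preserves the original signature and return value, callers of `Context.apply` are unaffected; the only observable change is the extra event.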
diff --git a/dagster_sqlmesh/resource.py b/dagster_sqlmesh/resource.py
index fdd3f35..bc5f037 100644
--- a/dagster_sqlmesh/resource.py
+++ b/dagster_sqlmesh/resource.py
@@ -1,16 +1,17 @@
 import logging
 import typing as t
+from datetime import UTC, datetime
 from types import MappingProxyType
 
-from dagster import (
-    AssetExecutionContext,
-    ConfigurableResource,
-    MaterializeResult,
-)
+import dagster as dg
+import sqlglot
 from dagster._core.errors import DagsterInvalidPropertyError
+from pydantic import BaseModel, Field
+from sqlglot import exp
 from sqlmesh import Model
 from sqlmesh.core.context import Context as SQLMeshContext
-from sqlmesh.core.snapshot import Snapshot, SnapshotInfoLike, SnapshotTableInfo
+from sqlmesh.core.plan import Plan as SQLMeshPlan
+from sqlmesh.core.snapshot import Snapshot, SnapshotInfoLike
 from sqlmesh.utils.dag import DAG
 from sqlmesh.utils.date import TimeLike
 from sqlmesh.utils.errors import SQLMeshError
@@ -26,6 +27,123 @@
 from dagster_sqlmesh.controller.dagster import DagsterSQLMeshController
 from dagster_sqlmesh.utils import get_asset_key_str
 
+logger = logging.getLogger(__name__)
+
+
+def _START_OF_UNIX_TIME():
+    dt = datetime.strptime("1970-01-01T00:00:00Z", "%Y-%m-%dT%H:%M:%SZ")
+    return dt.replace(tzinfo=UTC)
+
+
+class ModelMaterializationStatus(BaseModel):
+    model_fqn: str
+
+    created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
+    snapshot_id: str
+
+    # The last time this model was updated or restated
+    last_updated_or_restated: datetime = Field(default_factory=_START_OF_UNIX_TIME)
+    last_promoted: datetime = Field(default_factory=_START_OF_UNIX_TIME)
+    last_backfill: datetime = Field(default_factory=_START_OF_UNIX_TIME)
+
+    def update_or_restate_now(self):
+        """Shortcut to set the last_updated_or_restated time to now"""
+        self.last_updated_or_restated = datetime.now(UTC)
+
+    def promote_now(self):
+        """Shortcut to set the last_promoted time to now"""
+        self.last_promoted = datetime.now(UTC)
+
+    def backfill_now(self):
+        """Shortcut to set the last_backfill time to now"""
+        self.last_backfill = datetime.now(UTC)
+
+    def as_dagster_metadata(
+        self, previous: "ModelMaterializationStatus | None"
+    ) -> dict[str, dg.MetadataValue]:
+        if previous:
+            # If a previous materialization status exists, compare the dates
+            # and take the _largest_ of each, except for `created_at`, which
+            # is carried over from the previous status
+            last_updated_or_restated = dg.MetadataValue.timestamp(
+                max(previous.last_updated_or_restated, self.last_updated_or_restated)
+            )
+            last_promoted = dg.MetadataValue.timestamp(
+                max(previous.last_promoted, self.last_promoted)
+            )
+            last_backfill = dg.MetadataValue.timestamp(
+                max(previous.last_backfill, self.last_backfill)
+            )
+            created_at = dg.MetadataValue.timestamp(previous.created_at)
+        else:
+            # If there is no previous materialization status, all dates can
+            # use the created_at timestamp
+            created_at = dg.MetadataValue.timestamp(self.created_at)
+            last_updated_or_restated = dg.MetadataValue.timestamp(self.created_at)
+            last_promoted = dg.MetadataValue.timestamp(self.created_at)
+            last_backfill = dg.MetadataValue.timestamp(self.created_at)
+
+        return {
+            "snapshot_id": dg.MetadataValue.text(self.snapshot_id),
+            "model_fqn": dg.MetadataValue.text(self.model_fqn),
+            "created_at": created_at,
+            "last_updated_or_restated": last_updated_or_restated,
+            "last_promoted": last_promoted,
+            "last_backfill": last_backfill,
+        }
+
+    @classmethod
+    def from_dagster_metadata(
+        cls, metadata: dict[str, t.Any]
+    ) -> "ModelMaterializationStatus":
+        # Convert the metadata values
+        converted: dict[str, dg.MetadataValue] = {}
+        for key, value in metadata.items():
+            assert isinstance(
+                value, dg.MetadataValue
+            ), f"Expected MetadataValue for {key}, got {type(value)}"
+            converted[key] = value
+
+        return cls.model_validate(
+            dict(
+                model_fqn=converted["model_fqn"].value,
+                snapshot_id=converted["snapshot_id"].value,
+                created_at=converted["created_at"].value,
+                last_updated_or_restated=converted["last_updated_or_restated"].value,
+                last_promoted=converted["last_promoted"].value,
+                last_backfill=converted["last_backfill"].value,
+            )
+        )
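The two methods above form a round trip through Dagster metadata. A sketch of that round trip (the fqn and snapshot id values are made up):

    from dagster_sqlmesh.resource import ModelMaterializationStatus

    status = ModelMaterializationStatus(
        model_fqn='"memory"."sqlmesh_example"."full_model"',  # hypothetical fqn
        snapshot_id="0000000000",  # hypothetical snapshot id
    )
    metadata = status.as_dagster_metadata(previous=None)  # all dates collapse to created_at
    restored = ModelMaterializationStatus.from_dagster_metadata(dict(metadata))
    assert restored.snapshot_id == status.snapshot_id
    assert restored.last_backfill == restored.created_at

Note that `MetadataValue.timestamp` stores a float POSIX timestamp, which pydantic coerces back into a timezone-aware datetime on validation.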
+ """ + table = self.as_glot_table() + + input_as_table = sqlglot.to_table(input) + + if input_as_table.name != table.name: + return False + if input_as_table.db != table.db: + return False + + if not ignore_catalog: + if input_as_table.catalog != table.catalog: + return False + return True + class MaterializationTracker: """Tracks sqlmesh materializations and notifies dagster in the correct @@ -36,24 +154,68 @@ def __init__(self, sorted_dag: list[str], logger: logging.Logger) -> None: self.logger = logger self._batches: dict[Snapshot, int] = {} self._count: dict[Snapshot, int] = {} - self._complete_update_status: dict[str, bool] = {} + self._model_metadata: dict[str, ModelMaterializationStatus] = {} + self._non_model_names: set[str] = set() self._sorted_dag = sorted_dag self._current_index = 0 self.finished_promotion = False - def init_complete_update_status(self, snapshots: list[SnapshotTableInfo]) -> None: - planned_model_names = set() - for snapshot in snapshots: - planned_model_names.add(snapshot.name) - - # Anything not in the plan should be listed as completed and queued for - # notification - self._complete_update_status = { - name: False for name in (set(self._sorted_dag) - planned_model_names) + def initialize_from_plan(self, plan: SQLMeshPlan): + # Initialize all of the model materialization statuses + # Existing snapshots + snapshots_by_name = { + snapshot.name: snapshot for snapshot in plan.snapshots.values() } + created_at = datetime.now(UTC) + + # Include new snapshots + for snapshot in plan.context_diff.new_snapshots.values(): + snapshots_by_name[snapshot.name] = snapshot + + for model_fqn in self._sorted_dag: + snapshot = snapshots_by_name.get(model_fqn) + + if not snapshot: + self._non_model_names.add(model_fqn) + continue + + if snapshot.is_external: + self._non_model_names.add(model_fqn) + continue + + self._model_metadata[snapshot.name] = ModelMaterializationStatus( + model_fqn=snapshot.model.fqn, + snapshot_id=snapshot.identifier, + created_at=created_at, + ) + + # Update all of the model status that are to be updated or restated in this plan + # This condition was taken from a condition found in sqlmesh's `Context` + # object. 
 class MaterializationTracker:
     """Tracks sqlmesh materializations and notifies dagster in the correct
@@ -36,24 +154,68 @@ def __init__(self, sorted_dag: list[str], logger: logging.Logger) -> None:
         self.logger = logger
         self._batches: dict[Snapshot, int] = {}
         self._count: dict[Snapshot, int] = {}
-        self._complete_update_status: dict[str, bool] = {}
+        self._model_metadata: dict[str, ModelMaterializationStatus] = {}
+        self._non_model_names: set[str] = set()
         self._sorted_dag = sorted_dag
         self._current_index = 0
         self.finished_promotion = False
 
-    def init_complete_update_status(self, snapshots: list[SnapshotTableInfo]) -> None:
-        planned_model_names = set()
-        for snapshot in snapshots:
-            planned_model_names.add(snapshot.name)
-
-        # Anything not in the plan should be listed as completed and queued for
-        # notification
-        self._complete_update_status = {
-            name: False for name in (set(self._sorted_dag) - planned_model_names)
-        }
+    def initialize_from_plan(self, plan: SQLMeshPlan):
+        # Initialize all of the model materialization statuses
+        # Existing snapshots
+        snapshots_by_name = {
+            snapshot.name: snapshot for snapshot in plan.snapshots.values()
+        }
+        created_at = datetime.now(UTC)
+
+        # Include new snapshots
+        for snapshot in plan.context_diff.new_snapshots.values():
+            snapshots_by_name[snapshot.name] = snapshot
+
+        for model_fqn in self._sorted_dag:
+            snapshot = snapshots_by_name.get(model_fqn)
+
+            if not snapshot:
+                self._non_model_names.add(model_fqn)
+                continue
+
+            if snapshot.is_external:
+                self._non_model_names.add(model_fqn)
+                continue
+
+            self._model_metadata[snapshot.name] = ModelMaterializationStatus(
+                model_fqn=snapshot.model.fqn,
+                snapshot_id=snapshot.identifier,
+                created_at=created_at,
+            )
+
+        # Update the status of every model that is updated or restated in
+        # this plan. This condition mirrors one found in sqlmesh's `Context`
+        # object; it determines whether the plan contains any changes
+        if (
+            not plan.context_diff.has_changes
+            and not plan.requires_backfill
+            and not plan.has_unmodified_unpromoted
+        ):
+            self.logger.info("No changes detected, adding all models to the queue")
+        else:
+            context_diff = plan.context_diff
+            for snapshot in context_diff.new_snapshots.values():
+                self._model_metadata[snapshot.name].update_or_restate_now()
+            for snapshots in context_diff.modified_snapshots.values():
+                self._model_metadata[snapshots[0].name].update_or_restate_now()
+            for snapshot_id in plan.restatements.keys():
+                self._model_metadata[
+                    plan.snapshots[snapshot_id].name
+                ].update_or_restate_now()
+
     def update_promotion(self, snapshot: SnapshotInfoLike, promoted: bool) -> None:
-        self._complete_update_status[snapshot.name] = promoted
+        if promoted:
+            self._model_metadata[snapshot.name].promote_now()
+
+    def update_run(self, snapshot: SnapshotInfoLike) -> None:
+        self._model_metadata[snapshot.name].backfill_now()
 
     def stop_promotion(self) -> None:
         self.finished_promotion = True
@@ -64,6 +226,7 @@ def plan(self, batches: dict[Snapshot, int]) -> None:
 
         for snapshot, _ in self._batches.items():
             self._count[snapshot] = 0
+            self._model_metadata[snapshot.name].backfill_now()
 
     def update_plan(self, snapshot: Snapshot, _batch_idx: int) -> tuple[int, int]:
         self._count[snapshot] += 1
@@ -71,14 +234,40 @@
         expected_count = self._batches[snapshot]
         return (current_count, expected_count)
 
-    def notify_queue_next(self) -> tuple[str, bool] | None:
+    def notify_queue_next(self) -> tuple[str, ModelMaterializationStatus] | None:
+        """Notifies about the next materialization in the queue.
+
+        Returns:
+            A tuple containing the name of the materialized model and its
+            status, or None if there are no more model statuses left in the
+            queue
+        """
         if self._current_index >= len(self._sorted_dag):
             return None
-        check_name = self._sorted_dag[self._current_index]
-        if check_name in self._complete_update_status:
-            self._current_index += 1
-            return (check_name, self._complete_update_status[check_name])
-        return None
+        self.logger.debug(
+            f"MaterializationTracker index {self._current_index}",
+            extra=dict(
+                current_index=self._current_index,
+            ),
+        )
+
+        while self._current_index < len(self._sorted_dag):
+            model_name_for_notification = self._sorted_dag[self._current_index]
+
+            if model_name_for_notification in self._non_model_names:
+                self._current_index += 1
+                self.logger.debug(
+                    f"skipping non-model snapshot {model_name_for_notification}"
+                )
+                continue
+
+            if model_name_for_notification in self._model_metadata:
+                self._current_index += 1
+                return (
+                    model_name_for_notification,
+                    self._model_metadata[model_name_for_notification],
+                )
+            return None
+        return None
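Callers are expected to drain the tracker's queue until it returns None, receiving statuses in DAG order while non-model names are skipped. A sketch of the consuming pattern (this mirrors notify_success below; `tracker` is a MaterializationTracker):

    notify = tracker.notify_queue_next()
    while notify is not None:
        model_name, status = notify
        # ... build a MaterializeResult for this model from `status` ...
        notify = tracker.notify_queue_next()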
 class SQLMeshEventLogContext:
@@ -136,12 +325,28 @@ def __init__(self, stage: str, message: str, errors: list[Exception]) -> None:
 class DagsterSQLMeshEventHandler:
     def __init__(
         self,
-        context: AssetExecutionContext,
+        context: dg.AssetExecutionContext,
         models_map: dict[str, Model],
         dag: DAG[t.Any],
         prefix: str,
         is_testing: bool = False,
+        materializations_enabled: bool = True,
     ) -> None:
+        """Dagster event handler for SQLMesh models.
+
+        The handler is responsible for reporting events from sqlmesh to dagster.
+
+        Args:
+            context: The Dagster asset execution context.
+            models_map: A mapping of model names to their SQLMesh model instances.
+            dag: The directed acyclic graph representing the SQLMesh models.
+            prefix: A prefix to use for all asset keys generated by this handler.
+            is_testing: Whether the handler is being used in a testing context.
+            materializations_enabled: Whether the handler should generate
+                materializations. Disable this if you wish to run a sqlmesh
+                plan or run in an environment different from the normal
+                target environment.
+        """
         self._models_map = models_map
         self._prefix = prefix
         self._context = context
@@ -152,16 +357,18 @@
         self._stage = "plan"
         self._errors: list[Exception] = []
         self._is_testing = is_testing
+        self._materializations_enabled = materializations_enabled
 
     def process_events(self, event: console.ConsoleEvent) -> None:
         self.report_event(event)
 
     def notify_success(
         self, sqlmesh_context: SQLMeshContext
-    ) -> t.Iterator[MaterializeResult]:
+    ) -> t.Iterator[dg.MaterializeResult]:
         notify = self._tracker.notify_queue_next()
+
         while notify is not None:
-            completed_name, update_status = notify
+            completed_name, materialization_status = notify
 
             # If the model is not in the context, we can skip any notification
             # This will happen for external models
@@ -176,25 +383,79 @@
             if model:
                 # Passing model.fqn to get internal unique asset key
                 output_key = get_asset_key_str(model.fqn)
-                if not self._is_testing:
-                    # Stupidly dagster when testing cannot use the following
-                    # method so we must specifically skip this when testing
+                if self._is_testing:
+                    asset_key = dg.AssetKey(["testing", output_key])
+                    self._logger.warning(
+                        f"Generated fake asset key for testing: {asset_key.to_user_string()}"
+                    )
+                else:
                     asset_key = self._context.asset_key_for_output(output_key)
-                    yield MaterializeResult(
-                        asset_key=asset_key,
-                        metadata={
-                            "updated": update_status,
-                            "duration_ms": 0,
-                        },
+                if self._materializations_enabled:
+                    yield self.create_materialize_result(
+                        self._context, asset_key, materialization_status
+                    )
+                else:
+                    self._logger.debug(
+                        f"Materializations disabled. Would have materialized for {asset_key.to_user_string()}"
                     )
             notify = self._tracker.notify_queue_next()
+        else:
+            self._logger.debug("No more materializations to process")
+
+    def create_materialize_result(
+        self,
+        context: dg.AssetExecutionContext,
+        asset_key: dg.AssetKey,
+        current_materialization_status: ModelMaterializationStatus,
+    ) -> dg.MaterializeResult:
+        last_materialization = context.instance.get_latest_materialization_event(
+            asset_key
+        )
+
+        if not last_materialization:
+            self._logger.debug(
+                f"No materialization found for {asset_key.to_user_string()}, all dates will be set to now."
+            )
+            last_materialization_status = None
+        else:
+            assert (
+                last_materialization.asset_materialization is not None
+            ), "Expected asset materialization to be present."
+            try:
+                last_materialization_status = (
+                    ModelMaterializationStatus.from_dagster_metadata(
+                        dict(last_materialization.asset_materialization.metadata)
+                    )
+                )
+            except Exception as e:
+                self._logger.warning(
+                    f"Failed to validate last materialization for {asset_key.to_user_string()}: {e}. Ignoring and using the current status"
+                )
+                last_materialization_status = None
+
+        return dg.MaterializeResult(
+            asset_key=asset_key,
+            metadata=current_materialization_status.as_dagster_metadata(
+                last_materialization_status
+            ),
+        )
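create_materialize_result reads the prior status back off the Dagster instance before merging. The same mechanics can be exercised directly, mirroring what the regression test below does (sketch; the asset key and fqn are made up):

    import dagster as dg
    from dagster_sqlmesh.resource import ModelMaterializationStatus

    instance = dg.DagsterInstance.local_temp()
    asset_key = dg.AssetKey(["sqlmesh", "full_model"])  # hypothetical key
    status = ModelMaterializationStatus(
        model_fqn='"memory"."sqlmesh_example"."full_model"',
        snapshot_id="0000000000",
    )
    # Record a materialization carrying the status as metadata
    instance.report_runless_asset_event(
        dg.AssetMaterialization(
            asset_key=asset_key,
            metadata=status.as_dagster_metadata(None),
        )
    )
    # Read it back the way create_materialize_result does
    event = instance.get_latest_materialization_event(asset_key)
    assert event is not None and event.asset_materialization is not None
    previous = ModelMaterializationStatus.from_dagster_metadata(
        dict(event.asset_materialization.metadata)
    )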
 
     def report_event(self, event: console.ConsoleEvent) -> None:
         log_context = self.log_context(event)
 
         match event:
+            case console.PlanBuilt(plan=plan):
+                log_context.info(
+                    "Plan built",
+                    {
+                        "snapshots": [s.name for s in plan.environment.snapshots],
+                        "models_to_backfill": plan.models_to_backfill,
+                        "empty_backfill": plan.empty_backfill,
+                        "requires_backfill": plan.requires_backfill,
+                    },
+                )
+                self._tracker.initialize_from_plan(plan)
             case console.StartPlanEvaluation(plan=plan):
-                self._tracker.init_complete_update_status(plan.environment.snapshots)
                 log_context.info(
                     "Starting Plan Evaluation",
                     {
@@ -226,14 +487,23 @@
             ):
                 done, expected = self._tracker.update_plan(snapshot, batch_idx)
 
-                log_context.info(
-                    "Snapshot progress update",
-                    {
-                        "asset_key": get_asset_key_str(snapshot.model.name),
-                        "progress": f"{done}/{expected}",
-                        "duration_ms": duration_ms,
-                    },
-                )
+                if done == expected:
+                    log_context.info(
+                        "Snapshot progress complete",
+                        {
+                            "asset_key": get_asset_key_str(snapshot.model.name),
+                        },
+                    )
+                    self._tracker.update_run(snapshot)
+                else:
+                    log_context.info(
+                        "Snapshot progress update",
+                        {
+                            "asset_key": get_asset_key_str(snapshot.model.name),
+                            "progress": f"{done}/{expected}",
+                            "duration_ms": duration_ms,
+                        },
+                    )
             case console.LogSuccess(success=success):
                 self.update_stage("done")
                 if success:
@@ -304,13 +574,13 @@ def errors(self) -> list[Exception]:
         return self._errors[:]
 
 
-class SQLMeshResource(ConfigurableResource):
+class SQLMeshResource(dg.ConfigurableResource):
     config: SQLMeshContextConfig
     is_testing: bool = False
 
     def run(
         self,
-        context: AssetExecutionContext,
+        context: dg.AssetExecutionContext,
         *,
         context_factory: ContextFactory[ContextCls] = DEFAULT_CONTEXT_FACTORY,
         environment: str = "dev",
@@ -322,7 +592,8 @@
         skip_run: bool = False,
         plan_options: PlanOptions | None = None,
         run_options: RunOptions | None = None,
-    ) -> t.Iterable[MaterializeResult]:
+        materializations_enabled: bool = True,
+    ) -> t.Iterable[dg.MaterializeResult]:
         """Execute SQLMesh based on the configuration given"""
         plan_options = plan_options or {}
         run_options = run_options or {}
@@ -360,6 +631,7 @@
             dag=dag,
             prefix="sqlmesh: ",
             is_testing=self.is_testing,
+            materializations_enabled=materializations_enabled,
         )
 
         def raise_for_sqlmesh_errors(
@@ -396,20 +668,24 @@
         except SQLMeshError as e:
             logger.error(f"sqlmesh error: {e}")
             raise_for_sqlmesh_errors(event_handler, [GenericSQLMeshError(str(e))])
+        logger.info(f"sqlmesh run completed for {len(models_map)} models")
 
         # Some errors do not raise exceptions immediately, so we need to check
         # the event handler for any errors that may have been collected.
         raise_for_sqlmesh_errors(event_handler)
 
         yield from event_handler.notify_success(mesh.context)
 
+        logger.debug("all selected sqlmesh models notified of completion")
+
     def create_event_handler(
         self,
         *,
-        context: AssetExecutionContext,
+        context: dg.AssetExecutionContext,
         dag: DAG[str],
         models_map: dict[str, Model],
         prefix: str,
         is_testing: bool,
+        materializations_enabled: bool,
     ) -> DagsterSQLMeshEventHandler:
         return DagsterSQLMeshEventHandler(
             context=context,
@@ -417,10 +693,11 @@
             models_map=models_map,
             prefix=prefix,
             is_testing=is_testing,
+            materializations_enabled=materializations_enabled,
         )
 
     def _get_selected_models_from_context(
-        self, context: AssetExecutionContext, models: MappingProxyType[str, Model]
+        self, context: dg.AssetExecutionContext, models: MappingProxyType[str, Model]
     ) -> tuple[set[str], dict[str, Model], list[str] | None]:
         models_map = models.copy()
         try:
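With the new flag, callers can run sqlmesh without emitting materialization results. A minimal usage sketch (the asset definition here is illustrative only, not the project's recommended factory pattern):

    import dagster as dg
    from dagster_sqlmesh.resource import SQLMeshResource

    @dg.asset
    def sqlmesh_project(context: dg.AssetExecutionContext, sqlmesh: SQLMeshResource):
        # Run sqlmesh but skip MaterializeResult emission, e.g. when planning
        # against a non-production environment
        yield from sqlmesh.run(
            context,
            environment="dev",
            materializations_enabled=False,
        )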
diff --git a/dagster_sqlmesh/test_resource.py b/dagster_sqlmesh/test_resource.py
index 38a3924..3563978 100644
--- a/dagster_sqlmesh/test_resource.py
+++ b/dagster_sqlmesh/test_resource.py
@@ -1,22 +1,64 @@
 import typing as t
+from dataclasses import dataclass
 
 import dagster as dg
+import pytest
 
-from dagster_sqlmesh.resource import DagsterSQLMeshEventHandler, PlanOrRunFailedError
+from dagster_sqlmesh.resource import (
+    DagsterSQLMeshEventHandler,
+    ModelMaterializationStatus,
+    PlanOrRunFailedError,
+)
 from dagster_sqlmesh.testing import setup_testing_sqlmesh_test_context
+from dagster_sqlmesh.testing.context import SQLMeshTestContext, TestSQLMeshResource
 
 
-def test_sqlmesh_resource_should_report_no_errors(
-    sample_sqlmesh_project: str, sample_sqlmesh_db_path: str
+@dataclass(kw_only=True)
+class SQLMeshResourceInitialization:
+    dagster_instance: dg.DagsterInstance
+    dagster_context: dg.AssetExecutionContext
+    test_context: SQLMeshTestContext
+    resource: TestSQLMeshResource
+
+
+def setup_sqlmesh_resource(
+    sample_sqlmesh_project: str, sample_sqlmesh_db_path: str, enable_model_failure: bool
 ):
-    dg_context = dg.build_asset_context()
+    dg_instance = dg.DagsterInstance.local_temp()
+    dg_context = dg.build_asset_context(
+        instance=dg_instance,
+    )
     test_context = setup_testing_sqlmesh_test_context(
         db_path=sample_sqlmesh_db_path,
         project_path=sample_sqlmesh_project,
-        variables={"enable_model_failure": False}
+        variables={"enable_model_failure": enable_model_failure},
     )
     test_context.initialize_test_source()
-    resource = test_context.create_resource()
+    resource = test_context.create_resource()
+    return SQLMeshResourceInitialization(
+        dagster_context=dg_context,
+        test_context=test_context,
+        resource=resource,
+        dagster_instance=dg_instance,
+    )
+
+
+@pytest.fixture
+def sample_sqlmesh_resource_initialization(
+    sample_sqlmesh_project: str, sample_sqlmesh_db_path: str
+):
+    return setup_sqlmesh_resource(
+        sample_sqlmesh_project=sample_sqlmesh_project,
+        sample_sqlmesh_db_path=sample_sqlmesh_db_path,
+        enable_model_failure=False,
+    )
+
+
+def test_sqlmesh_resource_should_report_no_errors(
+    sample_sqlmesh_resource_initialization: SQLMeshResourceInitialization,
+):
+    resource = sample_sqlmesh_resource_initialization.resource
+    dg_context = sample_sqlmesh_resource_initialization.dagster_context
 
     success = True
     try:
@@ -34,14 +76,13 @@
 def test_sqlmesh_resource_properly_reports_errors(
     sample_sqlmesh_project: str, sample_sqlmesh_db_path: str
 ):
-    dg_context = dg.build_asset_context()
-    test_context = setup_testing_sqlmesh_test_context(
-        db_path=sample_sqlmesh_db_path,
-        project_path=sample_sqlmesh_project,
-        variables={"enable_model_failure": True}
+    sqlmesh_resource_initialization = setup_sqlmesh_resource(
+        sample_sqlmesh_project=sample_sqlmesh_project,
+        sample_sqlmesh_db_path=sample_sqlmesh_db_path,
+        enable_model_failure=True,
     )
-    test_context.initialize_test_source()
-    resource = test_context.create_resource()
+    resource = sqlmesh_resource_initialization.resource
+    dg_context = sqlmesh_resource_initialization.dagster_context
 
     caught_failure = False
     try:
@@ -56,27 +97,25 @@
             expected_error_found = True
             break
         assert expected_error_found, "Expected error not found in the error list."
-
+
     assert caught_failure, "Expected an error to be raised, but it was not."
 
 
 def test_sqlmesh_resource_properly_reports_errors_not_thrown(
-    sample_sqlmesh_project: str, sample_sqlmesh_db_path: str
+    sample_sqlmesh_resource_initialization: SQLMeshResourceInitialization,
 ):
-    dg_context = dg.build_asset_context()
-    test_context = setup_testing_sqlmesh_test_context(
-        db_path=sample_sqlmesh_db_path,
-        project_path=sample_sqlmesh_project,
-        variables={"enable_model_failure": False}
-    )
-    test_context.initialize_test_source()
-    resource = test_context.create_resource()
-    def event_handler_factory(*args: t.Any, **kwargs: t.Any) -> DagsterSQLMeshEventHandler:
+    dg_context = sample_sqlmesh_resource_initialization.dagster_context
+    resource = sample_sqlmesh_resource_initialization.resource
+
+    def event_handler_factory(
+        *args: t.Any, **kwargs: t.Any
+    ) -> DagsterSQLMeshEventHandler:
         """Custom event handler factory for the SQLMesh resource."""
         handler = DagsterSQLMeshEventHandler(*args, **kwargs)
         # Load it with an error
        handler._errors = [Exception("testerror")]
         return handler
+
     resource.set_event_handler_factory(event_handler_factory)
 
     caught_failure = False
@@ -92,7 +131,115 @@
         if "testerror" in str(err):
             expected_error_found = True
             break
-    assert expected_error_found, "Expected error 'testerror' not found in the error list."
+    assert (
+        expected_error_found
+    ), "Expected error 'testerror' not found in the error list."
 
     assert caught_failure, "Expected an error to be raised, but it was not."
 
 
+def test_sqlmesh_resource_should_properly_materialize_results_when_no_plan_is_run(
+    sample_sqlmesh_resource_initialization: SQLMeshResourceInitialization,
+):
+    """We had an issue where sqlmesh did not properly materialize models when
+    the plan ended up having no changes and all models were already up to
+    date.
+
+    This test ensures that doesn't regress.
+    """
+
+    resource = sample_sqlmesh_resource_initialization.resource
+    dg_context = sample_sqlmesh_resource_initialization.dagster_context
+    dg_instance = sample_sqlmesh_resource_initialization.dagster_instance
+
+    # First run should materialize all models
+    initial_results: list[dg.MaterializeResult] = []
+    for result in resource.run(dg_context):
+        initial_results.append(result)
+        assert result.asset_key is not None, "Expected asset key to be present."
+        dg_instance.report_runless_asset_event(dg.AssetMaterialization(
+            asset_key=result.asset_key,
+            metadata=result.metadata,
+        ))
+
+    # All metadata times should be set to the same time
+    initial_times: set[float] = set()
+    for result in initial_results:
+        assert result.metadata is not None, "Expected metadata to be present."
+        status = ModelMaterializationStatus.from_dagster_metadata(dict(result.metadata))
+        initial_times.add(status.created_at.timestamp())
+        initial_times.add(status.last_backfill.timestamp())
+        initial_times.add(status.last_updated_or_restated.timestamp())
+        initial_times.add(status.last_promoted.timestamp())
+    assert len(initial_times) == 1, "Expected all metadata times to be the same."
+
+    assert len(initial_results) > 0, "Expected initial results to be non-empty."
+
+    original_created_at = next(iter(initial_times))
+
+    # Second run should also materialize all models
+    second_results: list[dg.MaterializeResult] = []
+    for result in resource.run(dg_context):
+        second_results.append(result)
+
+    assert len(second_results) > 0, "Expected second results to be non-empty."
+    assert len(initial_results) == len(
+        second_results
+    ), "Expected initial and second results to have the same number of materialized assets"
+
+    # Assert that no models were updated
+    second_times: set[float] = set()
+    for result in second_results:
+        assert result.metadata is not None, "Expected metadata to be present."
+        status = ModelMaterializationStatus.from_dagster_metadata(dict(result.metadata))
+        second_times.add(status.created_at.timestamp())
+        second_times.add(status.last_backfill.timestamp())
+        second_times.add(status.last_updated_or_restated.timestamp())
+        second_times.add(status.last_promoted.timestamp())
+    assert (
+        len(second_times) == 1
+    ), "Expected all metadata times to be the same in the second run."
+
+    # Third run will restate the full model
+    third_results: list[dg.MaterializeResult] = []
+    for result in resource.run(
+        dg_context, restate_models=["sqlmesh_example.full_model"]
+    ):
+        third_results.append(result)
+
+    assert len(third_results) > 0, "Expected third results to be non-empty."
+    assert len(third_results) == len(
+        initial_results
+    ), "Expected third results to have the same number of materialized assets as initial results"
+
+    # The full model's metadata should indicate it was updated; all other
+    # models should only be promoted
+    for result in third_results:
+        assert result.metadata is not None, "Expected metadata to be present."
+        status = ModelMaterializationStatus.from_dagster_metadata(dict(result.metadata))
+        assert (
+            status.created_at.timestamp() == original_created_at
+        ), "Expected created_at to be unchanged for all models"
+
+        # Oddly, sqlmesh promotes everything during a restatement
+        assert (
+            status.last_promoted.timestamp() > status.created_at.timestamp()
+        ), "Expected last_promoted to be updated for all models"
+
+        if status.is_match("sqlmesh_example.full_model", ignore_catalog=True):
+            assert (
+                status.last_updated_or_restated.timestamp()
+                > status.created_at.timestamp()
+            ), "Expected full model to be updated."
+            assert (
+                status.last_backfill.timestamp() > status.created_at.timestamp()
+            ), "Expected only full model to be backfilled"
+        else:
+            assert (
+                status.last_updated_or_restated.timestamp()
+                == status.created_at.timestamp()
+            ), f"{status.model_fqn} was updated. Expected only the full model to be updated"
+            assert (
+                status.last_backfill.timestamp() == status.created_at.timestamp()
+            ), f"{status.model_fqn} was backfilled. Expected only the full model to be backfilled"
diff --git a/pyproject.toml b/pyproject.toml
index 1526385..f523205 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,6 +13,7 @@ dependencies = [
     "sqlmesh<0.188",
     "pytest>=8.3.2",
     "pyarrow>=18.0.0",
+    "pydantic>=2.11.5",
 ]
 
 [dependency-groups]
diff --git a/uv.lock b/uv.lock
index 3b69fc2..66101f8 100644
--- a/uv.lock
+++ b/uv.lock
@@ -291,11 +291,12 @@ wheels = [
 [[package]]
 name = "dagster-sqlmesh"
-version = "0.18.0"
+version = "0.19.0"
 source = { editable = "." }
 dependencies = [
     { name = "dagster" },
     { name = "pyarrow" },
+    { name = "pydantic" },
     { name = "pytest" },
     { name = "sqlmesh" },
 ]
@@ -316,6 +317,7 @@ dev = [
 requires-dist = [
     { name = "dagster", specifier = ">=1.7.8" },
     { name = "pyarrow", specifier = ">=18.0.0" },
+    { name = "pydantic", specifier = ">=2.11.5" },
     { name = "pytest", specifier = ">=8.3.2" },
     { name = "sqlmesh", specifier = "<0.188" },
 ]