Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 14 additions & 21 deletions dagster_sqlmesh/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,12 @@
import typing as t

import pytest
from sqlmesh.core.config import (
Config as SQLMeshConfig,
DuckDBConnectionConfig,
GatewayConfig,
ModelDefaultsConfig,
)

from dagster_sqlmesh.config import SQLMeshContextConfig
from dagster_sqlmesh.testing import SQLMeshTestContext
from dagster_sqlmesh.testing import (
SQLMeshTestContext,
setup_testing_sqlmesh_context_config,
)

logger = logging.getLogger(__name__)

Expand All @@ -41,23 +38,19 @@ def sample_sqlmesh_project() -> t.Iterator[str]:
# Initialize the "source" data
yield str(project_dir)

@pytest.fixture
def sample_sqlmesh_db_path(sample_sqlmesh_project: str) -> t.Iterator[str]:
    """Yield the path of the ``db.db`` file inside the sample project directory.

    Only the path is computed here; this fixture does not create the file.
    """
    yield os.path.join(sample_sqlmesh_project, "db.db")

@pytest.fixture
def sample_sqlmesh_test_context_config(
    sample_sqlmesh_project: str, sample_sqlmesh_db_path: str
) -> t.Iterator[SQLMeshContextConfig]:
    """Yield the config returned by ``setup_testing_sqlmesh_context_config``.

    The helper is given the sample project path and the sample database path.
    """
    context_config = setup_testing_sqlmesh_context_config(
        db_path=sample_sqlmesh_db_path,
        project_path=sample_sqlmesh_project,
    )
    yield context_config

@pytest.fixture
def sample_sqlmesh_test_context(
    sample_sqlmesh_project: str,
    sample_sqlmesh_test_context_config: SQLMeshContextConfig,
    sample_sqlmesh_db_path: str,
) -> t.Iterator[SQLMeshTestContext]:
    """Yield a ``SQLMeshTestContext`` after initializing its test source.

    The database path and the context config are taken from the
    ``sample_sqlmesh_db_path`` and ``sample_sqlmesh_test_context_config``
    fixtures instead of being rebuilt inline. ``sample_sqlmesh_project`` is
    not read in the body; it stays in the signature unchanged.
    """
    test_context = SQLMeshTestContext(
        db_path=sample_sqlmesh_db_path,
        context_config=sample_sqlmesh_test_context_config,
    )
    test_context.initialize_test_source()
    yield test_context
4 changes: 2 additions & 2 deletions dagster_sqlmesh/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,14 +405,14 @@ def publish_known_event(self, event_name: str, **kwargs: t.Any) -> None:

def publish(self, event: ConsoleEvent) -> None:
    """Send ``event`` to every registered handler.

    Before dispatch, one debug line records the event's class name and the
    number of registered handlers.
    """
    # A single message: two adjacent f-string literals would be concatenated
    # into one duplicated log line.
    self.logger.debug(
        f"EventConsole[{self.id}]: sending event {event.__class__.__name__} to {len(self._handlers)}"
    )
    for handler in self._handlers.values():
        handler(event)

def publish_unknown_event(self, event_name: str, **kwargs: t.Any) -> None:
self.logger.debug(
f"EventConsole[{self.id}]: sending unknown event to {len(self._handlers)}"
f"EventConsole[{self.id}]: sending unknown '{event_name}' event to {len(self._handlers)} handlers"
)
self.logger.debug(f"EventConsole[{self.id}]: unknown event {event_name} {kwargs}")

Expand Down
34 changes: 23 additions & 11 deletions dagster_sqlmesh/controller/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,17 @@ def run_sqlmesh_thread(
default_catalog: str,
) -> None:
logger.debug("dagster-sqlmesh: thread started")

def auto_execute_plan(event: ConsoleEvent):
    """Apply the plan carried by a ``Plan`` event; ignore every other event.

    An exception raised by ``apply()`` is passed to
    ``controller.console.exception`` instead of propagating to the caller.
    """
    if not isinstance(event, Plan):
        return None
    try:
        event.plan_builder.apply()
    except Exception as apply_error:
        controller.console.exception(apply_error)
    return None

try:
controller.console.add_handler(auto_execute_plan)
builder = t.cast(
PlanBuilder,
context.plan_builder(
Expand All @@ -191,7 +201,7 @@ def run_sqlmesh_thread(
builder,
auto_apply=True,
default_catalog=default_catalog,
)
)
except Exception as e:
controller.console.exception(e)
except: # noqa: E722
Expand All @@ -218,16 +228,18 @@ def run_sqlmesh_thread(
thread.start()

self.logger.debug("waiting for events")
for event in generator.events(thread):
match event:
case ConsoleException(exception=e):
raise e
case Plan(plan_builder=plan_builder, auto_apply=auto_apply):
if auto_apply:
plan_builder.apply()
yield event
case _:
yield event
try:
for event in generator.events(thread):
match event:
case ConsoleException(exception=e):
raise e
case _:
yield event
except Exception as e:
import traceback
print("An exception occurred:")
print(traceback.format_exc())
raise

thread.join()

Expand Down
151 changes: 109 additions & 42 deletions dagster_sqlmesh/resource.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
import logging
import typing as t
from types import MappingProxyType

from dagster import (
AssetExecutionContext,
ConfigurableResource,
MaterializeResult,
)
from dagster._core.errors import DagsterInvalidPropertyError
from sqlmesh import Model
from sqlmesh.core.context import Context as SQLMeshContext
from sqlmesh.core.snapshot import Snapshot, SnapshotInfoLike, SnapshotTableInfo
from sqlmesh.utils.dag import DAG
from sqlmesh.utils.date import TimeLike
from sqlmesh.utils.errors import SQLMeshError

from dagster_sqlmesh.controller.base import (
DEFAULT_CONTEXT_FACTORY,
Expand Down Expand Up @@ -113,20 +116,41 @@ def event_name(self):
return self._event.__class__.__name__


class GenericSQLMeshError(Exception):
    """Plain exception subclass; adds no behavior beyond ``Exception``."""


class FailedModelError(Exception):
def __init__(self, model_name: str, message: str | None) -> None:
super().__init__(message)
self.model_name = model_name
self.message = message


class PlanOrRunFailedError(Exception):
    """Raised when a plan or run fails, bundling the errors collected so far.

    Attributes:
        stage: Name of the stage that was active when the failure occurred.
        errors: The exceptions gathered for this failure.
    """

    def __init__(self, stage: str, message: str, errors: list[Exception]) -> None:
        super().__init__(message)
        self.stage = stage
        # Keep a private copy so later changes to the caller's list cannot
        # alter what this error reports.
        self.errors = list(errors)

    def __reduce__(self):
        """Rebuild with all three constructor arguments when copied or pickled.

        The default exception reduction calls the class with ``args`` only
        (here just the message), which does not match the three-argument
        constructor and raises ``TypeError``.
        """
        return (type(self), (self.stage, self.args[0], self.errors), vars(self))


class DagsterSQLMeshEventHandler:
def __init__(
    self,
    context: AssetExecutionContext,
    models_map: dict[str, Model],
    dag: DAG[t.Any],
    prefix: str,
    is_testing: bool = False,
) -> None:
    """Store the execution context and set up per-run tracking state.

    Args:
        context: Execution context; its ``log`` attribute becomes this
            handler's logger.
        models_map: Models keyed by name, stored as given.
        dag: DAG whose ``sorted`` sequence is copied and handed to the
            ``MaterializationTracker``.
        prefix: Prefix string, stored as given.
        is_testing: Flag stored as given; defaults to ``False``.
    """
    # Values taken straight from the arguments.
    self._context = context
    self._logger = context.log
    self._models_map = models_map
    self._prefix = prefix
    self._is_testing = is_testing

    # State that starts fresh for every handler.
    self._stage = "plan"
    self._errors: list[Exception] = []

    sorted_nodes = dag.sorted[:]
    self._tracker = MaterializationTracker(sorted_nodes, self._logger)

def process_events(self, event: console.ConsoleEvent) -> None:
self.report_event(event)
Expand All @@ -150,14 +174,17 @@ def notify_success(
# If the model is not in models_map, we can skip any notification
if model:
output_key = sqlmesh_model_name_to_key(model.name)
asset_key = self._context.asset_key_for_output(output_key)
yield MaterializeResult(
asset_key=asset_key,
metadata={
"updated": update_status,
"duration_ms": 0,
},
)
if not self._is_testing:
# Stupidly dagster when testing cannot use the following
# method so we must specifically skip this when testing
asset_key = self._context.asset_key_for_output(output_key)
yield MaterializeResult(
asset_key=asset_key,
metadata={
"updated": update_status,
"duration_ms": 0,
},
)
notify = self._tracker.notify_queue_next()

def report_event(self, event: console.ConsoleEvent) -> None:
Expand Down Expand Up @@ -210,19 +237,22 @@ def report_event(self, event: console.ConsoleEvent) -> None:
if success:
log_context.info("sqlmesh ran successfully")
else:
log_context.error("sqlmesh failed")
raise Exception("sqlmesh failed during run")
log_context.error("sqlmesh failed. check collected errors")
case console.LogError(message=message):
log_context.error(
f"sqlmesh reported an error: {message}",
)
case console.LogFailedModels(models=models):
if len(models) != 0:
self._errors.append(GenericSQLMeshError(message))
case console.LogFailedModels(errors=errors):
if len(errors) != 0:
failed_models = "\n".join(
[f"{model!s}\n{model.__cause__!s}" for model in models]
[f"{error.node!s}\n{error.__cause__!s}" for error in errors]
)
log_context.error(f"sqlmesh failed models: {failed_models}")
raise Exception("sqlmesh has failed models")
for error in errors:
self._errors.append(
FailedModelError(error.node, str(error.__cause__))
)
case console.UpdatePromotionProgress(snapshot=snapshot, promoted=promoted):
log_context.info(
"Promotion progress update",
Expand Down Expand Up @@ -263,9 +293,18 @@ def log(
def update_stage(self, stage: str) -> None:
    """Record ``stage`` as this handler's current stage."""
    self._stage = stage

@property
def stage(self) -> str:
    """Return the stage name currently stored on this handler."""
    return self._stage

@property
def errors(self) -> list[Exception]:
    """Return a new list holding the errors collected so far.

    The caller gets a shallow copy, so changing the returned list does not
    change the handler's own list.
    """
    return list(self._errors)


class SQLMeshResource(ConfigurableResource):
config: SQLMeshContextConfig
is_testing: bool = False

def run(
self,
Expand Down Expand Up @@ -293,25 +332,16 @@ def run(
with controller.instance(environment) as mesh:
dag = mesh.models_dag()

select_models = []

models = mesh.models()
models_map = models.copy()
all_available_models = set(
[model.fqn for model, _ in mesh.non_external_models_dag()]
)
if context.selected_output_names:
models_map = {}
for key, model in models.items():
if (
sqlmesh_model_name_to_key(model.name)
in context.selected_output_names
):
models_map[key] = model
select_models.append(model.name)
selected_models_set = set(models_map.keys())

if all_available_models == selected_models_set:
selected_models_set, models_map, select_models = (
self._get_selected_models_from_context(context, models)
)

if all_available_models == selected_models_set or select_models is None:
logger.info("all models selected")

# Setting this to none to allow sqlmesh to select all models and
Expand All @@ -321,24 +351,61 @@ def run(
logger.info(f"selected models: {select_models}")

event_handler = DagsterSQLMeshEventHandler(
context, models_map, dag, "sqlmesh: "
context, models_map, dag, "sqlmesh: ", is_testing=self.is_testing
)

for event in mesh.plan_and_run(
start=start,
end=end,
select_models=select_models,
restate_models=restate_models,
restate_selected=restate_selected,
skip_run=skip_run,
plan_options=plan_options,
run_options=run_options,
):
logger.debug(f"sqlmesh event: {event}")
event_handler.process_events(event)

try:
for event in mesh.plan_and_run(
start=start,
end=end,
select_models=select_models,
restate_models=restate_models,
restate_selected=restate_selected,
skip_run=skip_run,
plan_options=plan_options,
run_options=run_options,
):
logger.debug(f"sqlmesh event: {event}")
event_handler.process_events(event)
except SQLMeshError as e:
logger.error(f"sqlmesh error: {e}")
errors = event_handler.errors
for error in errors:
logger.error(f"sqlmesh encountered the following error during sqlmesh {event_handler.stage}: {error}")
raise PlanOrRunFailedError(
event_handler.stage,
f"sqlmesh failed during {event_handler.stage} with {len(event_handler.errors) + 1} errors",
[e, *event_handler.errors],
)
yield from event_handler.notify_success(mesh.context)

def _get_selected_models_from_context(
    self, context: AssetExecutionContext, models: MappingProxyType[str, Model]
) -> tuple[set[str], dict[str, Model], list[str] | None]:
    """Work out which models the execution context selected.

    Returns:
        A tuple ``(selected_keys, models_map, select_models)``:
        the keys of the selected models, the selected models keyed as in
        ``models``, and the selected models' names. When the context cannot
        report ``selected_output_names`` and the error text mentions
        ``DirectOpExecutionContext``, every model is returned and
        ``select_models`` is ``None``.

    Raises:
        DagsterInvalidPropertyError, AttributeError: Re-raised when reading
            ``selected_output_names`` fails for any other reason.
    """
    every_model = models.copy()
    try:
        wanted_outputs = set(context.selected_output_names)
    except (DagsterInvalidPropertyError, AttributeError) as exc:
        # Special case for direct execution context when testing. This is related to:
        # https://github.com/dagster-io/dagster/issues/23633
        if "DirectOpExecutionContext" not in str(exc):
            raise exc
        context.log.warning("Caught an error that is likely a direct execution")
        return (set(every_model.keys()), every_model, None)

    # Dict insertion order keeps the names in the same order as ``models``.
    chosen = {
        key: model
        for key, model in models.items()
        if sqlmesh_model_name_to_key(model.name) in wanted_outputs
    }
    chosen_names: list[str] = [model.name for model in chosen.values()]
    return (set(chosen.keys()), chosen, chosen_names)

def get_controller(
self,
context_factory: ContextFactory[ContextCls],
Expand Down
2 changes: 1 addition & 1 deletion dagster_sqlmesh/test_asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ def test_sqlmesh_context_to_asset_outs(sample_sqlmesh_test_context: SQLMeshTestC
translator = SQLMeshDagsterTranslator()
outs = controller.to_asset_outs("dev", translator)
assert len(list(outs.deps)) == 1
assert len(outs.outs) == 9
assert len(outs.outs) == 10
Loading