Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dagster_sqlmesh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@

from .asset import *
from .config import *
from .controller import *
from .resource import *
from .translator import *
76 changes: 64 additions & 12 deletions dagster_sqlmesh/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,61 @@
DagsterSQLMeshController,
)
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator
from dagster_sqlmesh.types import SQLMeshMultiAssetOptions

logger = logging.getLogger(__name__)

def sqlmesh_to_multi_asset_options(
    *,
    environment: str,
    config: SQLMeshContextConfig,
    context_factory: ContextFactory[ContextCls] = lambda **kwargs: Context(**kwargs),
    dagster_sqlmesh_translator: SQLMeshDagsterTranslator | None = None,
) -> SQLMeshMultiAssetOptions:
    """Build the intermediate multi-asset description of a SQLMesh project.

    The returned SQLMeshMultiAssetOptions can later be turned into a dagster
    multi_asset definition.

    Args:
        environment: SQLMesh environment handed to the controller.
        config: Configuration used to set up the controller.
        context_factory: Callable used by the controller to create the
            SQLMesh context; by default it constructs ``Context(**kwargs)``.
        dagster_sqlmesh_translator: Translator used for the conversion. A
            default ``SQLMeshDagsterTranslator`` is created when omitted.

    Returns:
        The options produced by ``DagsterSQLMeshController.to_asset_outs``.
    """
    sqlmesh_controller = DagsterSQLMeshController.setup_with_config(
        config=config,
        context_factory=context_factory,
    )
    translator = dagster_sqlmesh_translator or SQLMeshDagsterTranslator()
    return sqlmesh_controller.to_asset_outs(environment, translator=translator)

def sqlmesh_asset_from_multi_asset_options(
    *,
    sqlmesh_multi_asset_options: SQLMeshMultiAssetOptions,
    name: str | None = None,
    compute_kind: str = "sqlmesh",
    op_tags: t.Mapping[str, t.Any] | None = None,
    required_resource_keys: set[str] | None = None,
    retry_policy: RetryPolicy | None = None,
    enabled_subsetting: bool = False,
) -> t.Callable[[t.Callable[..., t.Any]], AssetsDefinition]:
    """Create a dagster multi_asset decorator from SQLMeshMultiAssetOptions.

    The intermediate options are resolved into real dagster objects
    (outs, deps and internal asset deps) right before being handed to
    ``multi_asset``.

    Args:
        sqlmesh_multi_asset_options: Intermediate representation of the
            SQLMesh project.
        name: Forwarded to ``multi_asset`` as ``name``.
        compute_kind: Forwarded to ``multi_asset`` as ``compute_kind``.
        op_tags: Forwarded to ``multi_asset`` as ``op_tags``.
        required_resource_keys: Forwarded to ``multi_asset``.
        retry_policy: Forwarded to ``multi_asset`` as ``retry_policy``.
        enabled_subsetting: When truthy the multi asset is created with
            ``can_subset=True``; otherwise subsetting stays disabled.

    Returns:
        The decorator returned by ``multi_asset``.
    """
    return multi_asset(
        outs=sqlmesh_multi_asset_options.to_asset_outs(),
        deps=sqlmesh_multi_asset_options.to_asset_deps(),
        internal_asset_deps=sqlmesh_multi_asset_options.to_internal_asset_deps(),
        name=name,
        compute_kind=compute_kind,
        op_tags=op_tags,
        required_resource_keys=required_resource_keys,
        retry_policy=retry_policy,
        # bool() keeps the previous behaviour of only ever passing True/False.
        can_subset=bool(enabled_subsetting),
    )

# Define a SQLMesh Asset
def sqlmesh_assets(
Expand All @@ -30,19 +82,19 @@ def sqlmesh_assets(
# For now we don't set this by default
enabled_subsetting: bool = False,
) -> t.Callable[[t.Callable[..., t.Any]], AssetsDefinition]:
controller = DagsterSQLMeshController.setup_with_config(config=config, context_factory=context_factory)
if not dagster_sqlmesh_translator:
dagster_sqlmesh_translator = SQLMeshDagsterTranslator()
conversion = controller.to_asset_outs(environment, translator=dagster_sqlmesh_translator)

return multi_asset(
conversion = sqlmesh_to_multi_asset_options(
environment=environment,
config=config,
context_factory=context_factory,
dagster_sqlmesh_translator=dagster_sqlmesh_translator,
)

return sqlmesh_asset_from_multi_asset_options(
sqlmesh_multi_asset_options=conversion,
name=name,
outs=conversion.outs,
deps=conversion.deps,
internal_asset_deps=conversion.internal_asset_deps,
op_tags=op_tags,
compute_kind=compute_kind,
retry_policy=retry_policy,
can_subset=enabled_subsetting,
op_tags=op_tags,
required_resource_keys=required_resource_keys,
retry_policy=retry_policy,
enabled_subsetting=enabled_subsetting,
)
10 changes: 8 additions & 2 deletions dagster_sqlmesh/config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from dagster import Config
from pydantic import Field
from sqlmesh.core.config import Config as MeshConfig
from sqlmesh.core.config.loader import load_configs


@dataclass
Expand All @@ -27,7 +29,11 @@ class SQLMeshContextConfig(Config):
config_override: dict[str, Any] | None = Field(default_factory=lambda: None)

@property
def sqlmesh_config(self) -> MeshConfig:
    """Return the SQLMesh configuration for this context.

    When ``config_override`` is set it is parsed into a ``MeshConfig`` and
    returned directly. Otherwise the configuration is loaded from
    ``self.path`` through sqlmesh's ``load_configs``.

    Raises:
        ValueError: If ``load_configs`` returns no entry for ``self.path``.
    """
    if self.config_override:
        return MeshConfig.parse_obj(self.config_override)

    sqlmesh_path = Path(self.path)
    configs = load_configs(None, MeshConfig, [sqlmesh_path])
    if sqlmesh_path not in configs:
        raise ValueError(f"SQLMesh configuration not found at {sqlmesh_path}")
    return configs[sqlmesh_path]
84 changes: 53 additions & 31 deletions dagster_sqlmesh/controller/dagster.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
# pyright: reportPrivateImportUsage=false
import logging
from inspect import signature

from dagster import AssetDep, AssetKey, AssetOut
from dagster._core.definitions.asset_dep import CoercibleToAssetDep

from dagster_sqlmesh.controller.base import ContextCls, SQLMeshController
from dagster_sqlmesh.controller.base import (
ContextCls,
SQLMeshController,
)
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator
from dagster_sqlmesh.types import SQLMeshModelDep, SQLMeshMultiAssetOptions
from dagster_sqlmesh.types import (
ConvertibleToAssetDep,
ConvertibleToAssetOut,
SQLMeshModelDep,
SQLMeshMultiAssetOptions,
)
from dagster_sqlmesh.utils import get_asset_key_str

logger = logging.getLogger(__name__)
Expand All @@ -17,47 +21,65 @@ class DagsterSQLMeshController(SQLMeshController[ContextCls]):
"""An extension of the sqlmesh controller specifically for dagster use"""

def to_asset_outs(
    self,
    environment: str,
    translator: SQLMeshDagsterTranslator,
) -> SQLMeshMultiAssetOptions:
    """Collect the asset outs and dependencies of a sqlmesh environment.

    Every non-external model yields one asset out (created through the
    translator) together with the set of asset keys it depends on.
    Dependencies for which the context returns no model are additionally
    recorded as external deps.

    Args:
        environment: The sqlmesh environment to inspect.
        translator: Translator that produces asset keys, tags, group names
            and the intermediate asset out / asset dep objects.

    Returns:
        A SQLMeshMultiAssetOptions holding the outs, the external deps and
        the per-model internal asset deps (as user-string asset keys).
    """
    internal_asset_deps_map: dict[str, set[str]] = {}
    deps_map: dict[str, ConvertibleToAssetDep] = {}
    asset_outs: dict[str, ConvertibleToAssetOut] = {}

    with self.instance(environment, "to_asset_outs") as instance:
        context = instance.context

        for model, deps in instance.non_external_models_dag():
            asset_key = translator.get_asset_key(context=context, fqn=model.fqn)
            asset_key_str = asset_key.to_user_string()
            model_deps = [
                SQLMeshModelDep(fqn=dep, model=context.get_model(dep))
                for dep in deps
            ]
            internal_asset_deps: set[str] = set()
            asset_tags = translator.get_tags(context, model)

            for dep in model_deps:
                if dep.model:
                    dep_asset_key_str = translator.get_asset_key(
                        context, dep.model.fqn
                    ).to_user_string()

                    internal_asset_deps.add(dep_asset_key_str)
                else:
                    table = get_asset_key_str(dep.fqn)
                    key = translator.get_asset_key(
                        context, dep.fqn
                    ).to_user_string()
                    internal_asset_deps.add(key)

                    # create an external dep; keyed by table so that a
                    # dependency shared by several models is emitted once
                    deps_map[table] = translator.create_asset_dep(key=key)

            model_key = get_asset_key_str(model.fqn)
            asset_outs[model_key] = translator.create_asset_out(
                model_key=model_key,
                asset_key=asset_key_str,
                tags=asset_tags,
                is_required=False,
                group_name=translator.get_group_name(context, model),
                kinds={
                    "sqlmesh",
                    translator.get_context_dialect(context).lower(),
                },
            )
            internal_asset_deps_map[model_key] = internal_asset_deps

    deps = list(deps_map.values())

    return SQLMeshMultiAssetOptions(
        outs=asset_outs,
        deps=deps,
        internal_asset_deps=internal_asset_deps_map,
    )
70 changes: 67 additions & 3 deletions dagster_sqlmesh/translator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,48 @@
import typing as t
from collections.abc import Sequence
from inspect import signature

from dagster import AssetKey
from dagster import AssetDep, AssetKey, AssetOut
from pydantic import BaseModel, Field
from sqlglot import exp
from sqlmesh.core.context import Context
from sqlmesh.core.model import Model

from .types import ConvertibleToAssetDep, ConvertibleToAssetOut


class IntermediateAssetOut(BaseModel):
    """Pydantic representation of the arguments needed to build an AssetOut.

    The asset key is stored as a user string and only converted into an
    ``AssetKey`` when ``to_asset_out`` is called.
    """

    model_key: str
    asset_key: str
    tags: t.Mapping[str, str] | None = None
    is_required: bool = True
    group_name: str | None = None
    kinds: set[str] | None = None
    # Extra keyword arguments forwarded verbatim to AssetOut.
    kwargs: dict[str, t.Any] = Field(default_factory=dict)

    def to_asset_out(self) -> AssetOut:
        """Build the dagster ``AssetOut`` described by this object.

        ``kinds`` is only forwarded when the installed ``AssetOut`` accepts
        that parameter. Passing it unconditionally (even as ``None``) would
        raise a ``TypeError`` on versions without it. The model itself is
        left unmodified by the conversion.
        """
        asset_key = AssetKey.from_user_string(self.asset_key)

        kinds_kwargs: dict[str, t.Any] = {}
        if "kinds" in signature(AssetOut).parameters:
            kinds_kwargs["kinds"] = self.kinds

        return AssetOut(
            key=asset_key,
            tags=self.tags,
            is_required=self.is_required,
            group_name=self.group_name,
            **kinds_kwargs,
            **self.kwargs,
        )


class IntermediateAssetDep(BaseModel):
    """Pydantic representation of the arguments needed to build an AssetDep.

    The asset key is stored as a user string and only converted into an
    ``AssetKey`` when ``to_asset_dep`` is called.
    """

    key: str
    # Extra keyword arguments forwarded verbatim to AssetDep.
    kwargs: dict[str, t.Any] = Field(default_factory=dict)

    def to_asset_dep(self) -> AssetDep:
        """Build the dagster ``AssetDep`` described by this object.

        The stored ``kwargs`` are forwarded to ``AssetDep`` so that options
        given at creation time are not silently discarded.
        """
        return AssetDep(AssetKey.from_user_string(self.key), **self.kwargs)


class SQLMeshDagsterTranslator:
"""Translates sqlmesh objects for dagster"""
Expand All @@ -19,14 +57,40 @@ def get_asset_key_name(self, fqn: str) -> Sequence[str]:
asset_key_name = [table.catalog, table.db, table.name]

return asset_key_name

def get_group_name(self, context: Context, model: Model) -> str:
    """Return the group name for a model.

    The group is the second-to-last element of the sequence returned by
    ``get_asset_key_name`` for the model's fqn. ``context`` is not used by
    this implementation.
    """
    key_parts = self.get_asset_key_name(model.fqn)
    group_name = key_parts[-2]
    return group_name

def get_context_dialect(self, context: Context) -> str:
    """Return the dialect reported by the context's engine adapter."""
    return context.engine_adapter.dialect

def create_asset_dep(self, *, key: str, **kwargs: t.Any) -> ConvertibleToAssetDep:
    """Return an object that can later be resolved into an AssetDep.

    Most users of this library never call this directly; it mainly exists
    so that dagster-sqlmesh can support cacheable assets.

    Args:
        key: Asset key of the dependency.
        **kwargs: Additional options stored on the intermediate object.
    """
    intermediate_dep = IntermediateAssetDep(key=key, kwargs=kwargs)
    return intermediate_dep

def create_asset_out(
    self, *, model_key: str, asset_key: str, **kwargs: t.Any
) -> ConvertibleToAssetOut:
    """Return an object that can later be resolved into an AssetOut.

    Most users of this library never call this directly; it mainly exists
    so that dagster-sqlmesh can support cacheable assets.

    The well-known options ``kinds``, ``tags``, ``group_name`` and
    ``is_required`` are taken out of ``kwargs`` and stored in their own
    fields. Whatever remains is kept as extra keyword arguments.
    """
    # Pull out the dedicated fields first so only the leftovers end up in
    # the generic ``kwargs`` field.
    kinds = kwargs.pop("kinds", None)
    tags = kwargs.pop("tags", None)
    group_name = kwargs.pop("group_name", None)
    is_required = kwargs.pop("is_required", False)

    return IntermediateAssetOut(
        model_key=model_key,
        asset_key=asset_key,
        tags=tags,
        is_required=is_required,
        group_name=group_name,
        kinds=kinds,
        kwargs=kwargs,
    )

def get_tags(self, context: Context, model: Model) -> dict[str, str]:
    """Return the dagster tags for a sqlmesh model.

    Every tag on the model becomes a key whose value is the string
    ``"true"``. ``context`` is not used by this implementation.
    """
    return dict.fromkeys(model.tags, "true")
43 changes: 38 additions & 5 deletions dagster_sqlmesh/types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import typing as t
from dataclasses import dataclass, field

from dagster import AssetCheckResult, AssetKey, AssetMaterialization, AssetOut
from dagster._core.definitions.asset_dep import CoercibleToAssetDep
from dagster import AssetCheckResult, AssetDep, AssetKey, AssetMaterialization, AssetOut
from sqlmesh.core.model import Model

MultiAssetResponse = t.Iterable[AssetCheckResult | AssetMaterialization]
Expand Down Expand Up @@ -30,10 +29,44 @@ class SQLMeshModelDep:

def parse_fqn(self) -> SQLMeshParsedFQN:
return SQLMeshParsedFQN.parse(self.fqn)

class ConvertibleToAssetOut(t.Protocol):
    """Structural type for objects that can produce a dagster AssetOut."""

    def to_asset_out(self) -> AssetOut:
        """Return the AssetOut this object stands for."""
        ...

class ConvertibleToAssetDep(t.Protocol):
    """Structural type for objects that can produce a dagster AssetDep."""

    def to_asset_dep(self) -> AssetDep:
        """Return the AssetDep this object stands for."""
        ...

class ConvertibleToAssetKey(t.Protocol):
    """Structural type for objects that can produce a dagster AssetKey."""

    def to_asset_key(self) -> AssetKey:
        """Return the AssetKey this object stands for."""
        ...

@dataclass(kw_only=True)
class SQLMeshMultiAssetOptions:
outs: dict[str, AssetOut] = field(default_factory=lambda: {})
deps: t.Iterable[CoercibleToAssetDep] = field(default_factory=lambda: {})
internal_asset_deps: dict[str, set[AssetKey]] = field(default_factory=lambda: {})
"""Generic class for returning dagster multi asset options from SQLMesh, the
types used are intentionally generic so to allow for potentially using an
intermediate representation of the dagster asset objects. This is most
useful in caching purposes and is done to allow for users of this library to
manipulate the dagster asset creation process as they see fit."""

outs: t.Mapping[str, ConvertibleToAssetOut] = field(default_factory=lambda: {})
deps: t.Iterable[ConvertibleToAssetDep] = field(default_factory=lambda: [])
internal_asset_deps: t.Mapping[str, set[str]] = field(default_factory=lambda: {})

def to_asset_outs(self) -> t.Mapping[str, AssetOut]:
"""Convert to an iterable of AssetOut objects."""
return {key: out.to_asset_out() for key, out in self.outs.items()}

def to_asset_deps(self) -> t.Iterable[AssetDep]:
"""Convert to an iterable of AssetDep objects."""
return [dep.to_asset_dep() for dep in self.deps]

def to_internal_asset_deps(self) -> dict[str, set[AssetKey]]:
"""Convert to a dictionary of internal asset dependencies."""
return {
key: {AssetKey.from_user_string(dep) for dep in deps}
for key, deps in self.internal_asset_deps.items()
}
Loading