|
4 | 4 | import sys |
5 | 5 | import tempfile |
6 | 6 | import typing as t |
7 | | -from dataclasses import dataclass |
8 | 7 |
|
9 | | -import duckdb |
10 | | -import polars |
11 | 8 | import pytest |
12 | 9 | from sqlmesh.core.config import ( |
13 | 10 | Config as SQLMeshConfig, |
14 | 11 | DuckDBConnectionConfig, |
15 | 12 | GatewayConfig, |
16 | 13 | ModelDefaultsConfig, |
17 | 14 | ) |
18 | | -from sqlmesh.core.console import get_console |
19 | | -from sqlmesh.utils.date import TimeLike |
20 | 15 |
|
21 | 16 | from dagster_sqlmesh.config import SQLMeshContextConfig |
22 | | -from dagster_sqlmesh.console import ConsoleEvent |
23 | | -from dagster_sqlmesh.controller.base import PlanOptions, RunOptions |
24 | | -from dagster_sqlmesh.controller.dagster import DagsterSQLMeshController |
25 | | -from dagster_sqlmesh.events import ConsoleRecorder |
| 17 | +from dagster_sqlmesh.testing import SQLMeshTestContext |
26 | 18 |
|
27 | 19 | logger = logging.getLogger(__name__) |
28 | 20 |
|
@@ -50,110 +42,6 @@ def sample_sqlmesh_project() -> t.Iterator[str]: |
50 | 42 | yield str(project_dir) |
51 | 43 |
|
52 | 44 |
|
53 | | -@dataclass |
54 | | -class SQLMeshTestContext: |
55 | | - """A test context for running SQLMesh""" |
56 | | - |
57 | | - db_path: str |
58 | | - context_config: SQLMeshContextConfig |
59 | | - |
60 | | - def create_controller( |
61 | | - self, enable_debug_console: bool = False |
62 | | - ) -> DagsterSQLMeshController: |
63 | | - console = None |
64 | | - if enable_debug_console: |
65 | | - console = get_console() |
66 | | - return DagsterSQLMeshController.setup_with_config( |
67 | | - self.context_config, debug_console=console |
68 | | - ) |
69 | | - |
70 | | - def query(self, *args: t.Any, **kwargs: t.Any) -> t.Any: |
71 | | - conn = duckdb.connect(self.db_path) |
72 | | - return conn.sql(*args, **kwargs).fetchall() |
73 | | - |
74 | | - def initialize_test_source(self) -> None: |
75 | | - conn = duckdb.connect(self.db_path) |
76 | | - conn.sql( |
77 | | - """ |
78 | | - CREATE SCHEMA sources; |
79 | | - """ |
80 | | - ) |
81 | | - conn.sql( |
82 | | - """ |
83 | | - CREATE TABLE sources.test_source (id INTEGER, name VARCHAR); |
84 | | - """ |
85 | | - ) |
86 | | - conn.sql( |
87 | | - """ |
88 | | - INSERT INTO sources.test_source (id, name) |
89 | | - VALUES (1, 'abc'), (2, 'def'); |
90 | | - """ |
91 | | - ) |
92 | | - conn.close() |
93 | | - |
94 | | - def append_to_test_source(self, df: polars.DataFrame): |
95 | | - logger.debug("appending data to the test source") |
96 | | - conn = duckdb.connect(self.db_path) |
97 | | - conn.sql( |
98 | | - """ |
99 | | - INSERT INTO sources.test_source |
100 | | - SELECT * FROM df |
101 | | - """ |
102 | | - ) |
103 | | - |
104 | | - def plan_and_run( |
105 | | - self, |
106 | | - *, |
107 | | - environment: str, |
108 | | - execution_time: TimeLike | None = None, |
109 | | - enable_debug_console: bool = False, |
110 | | - start: TimeLike | None = None, |
111 | | - end: TimeLike | None = None, |
112 | | - select_models: list[str] | None = None, |
113 | | - restate_selected: bool = False, |
114 | | - skip_run: bool = False, |
115 | | - ) -> t.Iterator[ConsoleEvent] | None: |
116 | | - """Runs plan and run on SQLMesh with the given configuration and record all of the generated events. |
117 | | -
|
118 | | - Args: |
119 | | - environment (str): The environment to run SQLMesh in. |
120 | | - execution_time (TimeLike, optional): The execution timestamp for the run. Defaults to None. |
121 | | - enable_debug_console (bool, optional): Flag to enable debug console. Defaults to False. |
122 | | - start (TimeLike, optional): Start time for the run interval. Defaults to None. |
123 | | - end (TimeLike, optional): End time for the run interval. Defaults to None. |
124 | | - restate_models (List[str], optional): List of models to restate. Defaults to None. |
125 | | -
|
126 | | - Returns: |
127 | | - None: The function records events to a debug console but doesn't return anything. |
128 | | -
|
129 | | - Note: |
130 | | - TimeLike can be any time-like object that SQLMesh accepts (datetime, str, etc.). |
131 | | - The function creates a controller and recorder to capture all SQLMesh events during execution. |
132 | | - """ |
133 | | - controller = self.create_controller(enable_debug_console=enable_debug_console) |
134 | | - recorder = ConsoleRecorder() |
135 | | - # controller.add_event_handler(ConsoleRecorder()) |
136 | | - plan_options = PlanOptions( |
137 | | - enable_preview=True, |
138 | | - ) |
139 | | - run_options = RunOptions() |
140 | | - if execution_time: |
141 | | - plan_options["execution_time"] = execution_time |
142 | | - run_options["execution_time"] = execution_time |
143 | | - |
144 | | - for event in controller.plan_and_run( |
145 | | - environment, |
146 | | - start=start, |
147 | | - end=end, |
148 | | - select_models=select_models, |
149 | | - restate_selected=restate_selected, |
150 | | - plan_options=plan_options, |
151 | | - run_options=run_options, |
152 | | - skip_run=skip_run, |
153 | | - ): |
154 | | - recorder(event) |
155 | | - |
156 | | - |
157 | 45 | @pytest.fixture |
158 | 46 | def sample_sqlmesh_test_context( |
159 | 47 | sample_sqlmesh_project: str, |
|
0 commit comments