diff --git a/docs/guides/model_selection.md b/docs/guides/model_selection.md
index 1d01c280da..db098a1538 100644
--- a/docs/guides/model_selection.md
+++ b/docs/guides/model_selection.md
@@ -2,7 +2,7 @@
 
 This guide describes how to select specific models to include in a SQLMesh plan, which can be useful when modifying a subset of the models in a SQLMesh project.
 
-Note: the selector syntax described below is also used for the SQLMesh `plan` [`--allow-destructive-model` selector](../concepts/plans.md#destructive-changes).
+Note: the selector syntax described below is also used for the SQLMesh `plan` [`--allow-destructive-model` selector](../concepts/plans.md#destructive-changes) and for the `table_diff` command to [diff a selection of models](./tablediff.md#diffing-multiple-models-across-environments).
 
 ## Background
 
@@ -221,6 +221,81 @@ Models:
     └── sushi.customer_revenue_lifetime
 ```
 
+#### Select with tags
+
+If we specify the `--select-model` option with a tag selector like `"tag:reporting"`, all models with the "reporting" tag will be selected. Tags are case-insensitive and support wildcards:
+
+```bash
+❯ sqlmesh plan dev --select-model "tag:reporting*"
+New environment `dev` will be created from `prod`
+
+Differences from the `prod` environment:
+
+Models:
+├── Directly Modified:
+│   ├── sushi.daily_revenue
+│   └── sushi.monthly_revenue
+└── Indirectly Modified:
+    └── sushi.revenue_dashboard
+```
+
+#### Select with git changes
+
+The git-based selector allows you to select models whose files have changed compared to a target branch (default: main).
+This includes:
+- Untracked files (new files not in git)
+- Uncommitted changes in working directory
+- Committed changes different from the target branch
+
+For example:
+
+```bash
+❯ sqlmesh plan dev --select-model "git:feature"
+New environment `dev` will be created from `prod`
+
+Differences from the `prod` environment:
+
+Models:
+├── Directly Modified:
+│   └── sushi.items # Changed in feature branch
+└── Indirectly Modified:
+    ├── sushi.order_items
+    └── sushi.daily_revenue
+```
+
+You can also combine git selection with upstream/downstream indicators:
+
+```bash
+❯ sqlmesh plan dev --select-model "git:feature+"
+# Selects changed models and their downstream dependencies
+
+❯ sqlmesh plan dev --select-model "+git:feature"
+# Selects changed models and their upstream dependencies
+```
+
+#### Complex selections with logical operators
+
+The model selector supports combining multiple conditions using logical operators:
+
+- `&` (AND): Both conditions must be true
+- `|` (OR): Either condition must be true
+- `^` (NOT): Negates a condition
+
+For example:
+
+```bash
+❯ sqlmesh plan dev --select-model "(tag:finance & ^tag:deprecated)"
+# Selects models with finance tag that don't have deprecated tag
+
+❯ sqlmesh plan dev --select-model "(+model_a | model_b+)"
+# Selects model_a and its upstream deps OR model_b and its downstream deps
+
+❯ sqlmesh plan dev --select-model "(tag:finance & git:main)"
+# Selects changed models that also have the finance tag
+
+❯ sqlmesh plan dev --select-model "^(tag:test) & metrics.*"
+# Selects models in metrics schema that don't have the test tag
+```
+
 ### Backfill examples
 
 #### No backfill selection
diff --git a/docs/guides/tablediff.md b/docs/guides/tablediff.md
index 193c253a35..5c048462c2 100644
--- a/docs/guides/tablediff.md
+++ b/docs/guides/tablediff.md
@@ -122,6 +122,55 @@ Under the hood, SQLMesh stores temporary data in the database to perform the com
 The default schema for these temporary tables is `sqlmesh_temp` but can be changed with the `--temp-schema` option.
 The schema can be specified as a `CATALOG.SCHEMA` or `SCHEMA`.
 
+
+## Diffing multiple models across environments
+
+SQLMesh allows you to compare multiple models across environments at once using model selection expressions. This is useful when you want to validate changes across a set of related models or the entire project.
+
+To diff multiple models, use the `--select-model` (or `-m` for short) option with the table diff command:
+
+```bash
+sqlmesh table_diff prod:dev --select-model "sqlmesh_example.*"
+```
+
+When diffing multiple models, SQLMesh will:
+
+1. Show the models returned by the selector that exist in both environments and have differences
+2. Compare these models and display the data diff of each model
+
+> Note: Models will only be data diffed if there's a breaking change that impacts them.
+
+The `--select-model` option supports a powerful selection syntax that lets you choose models using patterns, tags, dependencies and git status. For complete details, see the [model selection guide](./model_selection.md).
+
+> Note: Surround your selection pattern in single or double quotes. Ex: `'*'`, `"sqlmesh_example.*"`
+
+Here are some common examples:
+
+```bash
+# Select all models in a schema
+sqlmesh table_diff prod:dev -m "sqlmesh_example.*"
+
+# Select a model and its dependencies
+sqlmesh table_diff prod:dev -m "+model_name" # include upstream deps
+sqlmesh table_diff prod:dev -m "model_name+" # include downstream deps
+
+# Select models by tag
+sqlmesh table_diff prod:dev -m "tag:finance"
+
+# Select models with git changes
+sqlmesh table_diff prod:dev -m "git:feature"
+
+# Use logical operators for complex selections
+sqlmesh table_diff prod:dev -m "(metrics.* & ^tag:deprecated)" # models in the metrics schema that aren't deprecated
+
+# Combine multiple selectors
+sqlmesh table_diff prod:dev -m "tag:finance" -m "metrics.*_daily"
+```
+
+When multiple selectors are provided, they are combined with OR logic, meaning a model matching any of the selectors will be included.
+
+> Note: Every model being compared must define a `grain` that is unique and not null, because SQLMesh uses it to join the tables from the two environments.
+
 ## Diffing tables or views
 
 Compare specific tables or views with the SQLMesh CLI interface by using the command `sqlmesh table_diff [source table]:[target table]`.
diff --git a/docs/reference/cli.md b/docs/reference/cli.md
index cd2c852b1b..d7b578116c 100644
--- a/docs/reference/cli.md
+++ b/docs/reference/cli.md
@@ -529,7 +529,7 @@ Options:
 ```
 Usage: sqlmesh table_diff [OPTIONS] SOURCE:TARGET [MODEL]
 
-  Show the diff between two tables.
+  Show the diff between two tables or multiple models across two environments.
 
 Options:
   -o, --on TEXT            The column to join on. Can be specified multiple
@@ -548,6 +548,7 @@ Options:
   --temp-schema TEXT       Schema used for temporary tables. It can be
                            `CATALOG.SCHEMA` or `SCHEMA`. Default:
                            `sqlmesh_temp`
+  -m, --select-model TEXT  Select specific models to table diff.
   --help                   Show this message and exit.
 ```
 
diff --git a/docs/reference/notebook.md b/docs/reference/notebook.md
index 60a7e16e7a..313b7295ee 100644
--- a/docs/reference/notebook.md
+++ b/docs/reference/notebook.md
@@ -293,7 +293,7 @@ Create a schema file containing external model schemas.
 %table_diff [--on [ON ...]] [--skip-columns [SKIP_COLUMNS ...]]
             [--model MODEL] [--where WHERE] [--limit LIMIT]
             [--show-sample] [--decimals DECIMALS] [--skip-grain-check]
-            [--temp-schema SCHEMA]
+            [--temp-schema SCHEMA] [--select-model [SELECT_MODEL ...]]
             SOURCE:TARGET
 
 Show the diff between two tables.
@@ -320,6 +320,8 @@ options:
   --skip-grain-check    Disable the check for a primary key (grain) that is
                         missing or is not unique.
   --temp-schema SCHEMA  The schema to use for temporary tables.
+  --select-model <[SELECT_MODEL ...]>
+                        Select specific models to diff using a pattern.
 ```
 
 #### model
diff --git a/sqlmesh/cli/main.py b/sqlmesh/cli/main.py
index cfec00c9b1..a06f9719a2 100644
--- a/sqlmesh/cli/main.py
+++ b/sqlmesh/cli/main.py
@@ -892,18 +892,26 @@ def create_external_models(obj: Context, **kwargs: t.Any) -> None:
     type=str,
     help="Schema used for temporary tables. It can be `CATALOG.SCHEMA` or `SCHEMA`. Default: `sqlmesh_temp`",
 )
+@click.option(
+    "--select-model",
+    "-m",
+    type=str,
+    multiple=True,
+    help="Specify one or more models to data diff. Use wildcards to diff multiple models. Ex: '*' (all models with applied plan diffs), 'demo.model+' (this and downstream models), 'git:feature_branch' (models with direct modifications in this branch only)",
+)
 @click.pass_obj
 @error_handler
 @cli_analytics
 def table_diff(
     obj: Context, source_to_target: str, model: t.Optional[str], **kwargs: t.Any
 ) -> None:
-    """Show the diff between two tables."""
+    """Show the diff between two tables or a selection of models when they are specified."""
     source, target = source_to_target.split(":")
+    select_models = kwargs.pop("select_model", None)
     obj.table_diff(
         source=source,
         target=target,
-        model_or_snapshot=model,
+        select_models={model} if model else select_models,
         **kwargs,
     )
 
diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py
index 98a49a1e6c..027760d10f 100644
--- a/sqlmesh/core/console.py
+++ b/sqlmesh/core/console.py
@@ -218,6 +218,57 @@ def show_model_difference_summary(
         """Displays a summary of differences for the given models."""
 
 
+class TableDiffConsole(abc.ABC):
+    """Console for displaying table differences"""
+
+    @abc.abstractmethod
+    def show_table_diff(
+        self,
+        table_diffs: t.List[TableDiff],
+        show_sample: bool = True,
+        skip_grain_check: bool = False,
+        temp_schema: t.Optional[str] = None,
+    ) -> None:
+        """Display the table diff between two or multiple tables."""
+
+    @abc.abstractmethod
+    def update_table_diff_progress(self, model: str) -> None:
+        """Update table diff progress bar"""
+
+    @abc.abstractmethod
+    def start_table_diff_progress(self, models_to_diff: int) -> None:
+        """Start table diff progress bar"""
+
+    @abc.abstractmethod
+    def start_table_diff_model_progress(self, model: str) -> None:
+        """Start table diff model progress"""
+
+    @abc.abstractmethod
+    def stop_table_diff_progress(self, success: bool) -> None:
+        """Stop table diff progress bar"""
+
+    @abc.abstractmethod
+    def show_table_diff_details(
+        self,
+        models_to_diff: t.List[str],
+    ) -> None:
+        """Display information about which tables are going to be diffed"""
+
+    @abc.abstractmethod
+    def show_table_diff_summary(self, table_diff: TableDiff) -> None:
+        """Display information about the tables being diffed and how they are being joined"""
+
+    @abc.abstractmethod
+    def show_schema_diff(self, schema_diff: SchemaDiff) -> None:
+        """Show table schema diff."""
+
+    @abc.abstractmethod
+    def show_row_diff(
+        self, row_diff: RowDiff, show_sample: bool = True, skip_grain_check: bool = False
+    ) -> None:
+        """Show table summary diff."""
+
+
 class BaseConsole(abc.ABC):
     @abc.abstractmethod
     def log_error(self, message: str) -> None:
@@ -258,6 +309,7 @@ class Console(
     JanitorConsole,
     EnvironmentsConsole,
     DifferenceConsole,
+    TableDiffConsole,
     BaseConsole,
     abc.ABC,
 ):
@@ -424,20 +476,6 @@ def loading_start(self, message: t.Optional[str] = None) -> uuid.UUID:
     def loading_stop(self, id: uuid.UUID) -> None:
         """Stop loading for the given id."""
 
-    @abc.abstractmethod
-    def show_table_diff_summary(self, table_diff: TableDiff) -> None:
-        """Display information about the tables being diffed and how they are being joined"""
-
-    @abc.abstractmethod
-    def show_schema_diff(self, schema_diff: SchemaDiff) -> None:
-        """Show table schema diff."""
-
-    @abc.abstractmethod
-    def show_row_diff(
-        self, row_diff: RowDiff, show_sample: bool = True, skip_grain_check: bool = False
-    ) -> None:
-        """Show table summary diff."""
-
 
 class NoopConsole(Console):
     def start_plan_evaluation(self, plan: EvaluatablePlan) -> None:
@@ -648,6 +686,40 @@ def loading_start(self, message: t.Optional[str] = None) -> uuid.UUID:
     def loading_stop(self, id: uuid.UUID) -> None:
         pass
 
+    def show_table_diff(
+        self,
+        table_diffs: t.List[TableDiff],
+        show_sample: bool = True,
+        skip_grain_check: bool = False,
+        temp_schema: t.Optional[str] = None,
+    ) -> None:
+        for table_diff in table_diffs:
+            self.show_table_diff_summary(table_diff)
+            self.show_schema_diff(table_diff.schema_diff())
+            self.show_row_diff(
+                table_diff.row_diff(temp_schema=temp_schema, skip_grain_check=skip_grain_check),
+                show_sample=show_sample,
+                skip_grain_check=skip_grain_check,
+            )
+
+    def update_table_diff_progress(self, model: str) -> None:
+        pass
+
+    def start_table_diff_progress(self, models_to_diff: int) -> None:
+        pass
+
+    def start_table_diff_model_progress(self, model: str) -> None:
+        pass
+
+    def stop_table_diff_progress(self, success: bool) -> None:
+        pass
+
+    def show_table_diff_details(
+        self,
+        models_to_diff: t.List[str],
+    ) -> None:
+        pass
+
     def show_table_diff_summary(self, table_diff: TableDiff) -> None:
         pass
 
@@ -697,6 +769,7 @@ class TerminalConsole(Console):
     """A rich based implementation of the console."""
 
     TABLE_DIFF_SOURCE_BLUE = "#0248ff"
+    TABLE_DIFF_TARGET_GREEN = "green"
 
     def __init__(
         self,
@@ -746,6 +819,11 @@ def __init__(
         self.state_import_snapshot_task: t.Optional[TaskID] = None
         self.state_import_environment_task: t.Optional[TaskID] = None
 
+        self.table_diff_progress: t.Optional[Progress] = None
+        self.table_diff_model_progress: t.Optional[Progress] = None
+        self.table_diff_model_tasks: t.Dict[str, TaskID] = {}
+        self.table_diff_progress_live: t.Optional[Live] = None
+
         self.verbosity = verbosity
         self.dialect = dialect
         self.ignore_warnings = ignore_warnings
@@ -1544,39 +1622,58 @@ def _show_summary_tree_for(
                 )
             tree.add(self._limit_model_names(removed_tree, self.verbosity))
         if modified_snapshot_ids:
-            direct = Tree("[bold][direct]Directly Modified:")
-            indirect = Tree("[bold][indirect]Indirectly Modified:")
-            metadata = Tree("[bold][metadata]Metadata Updated:")
-            for s_id in modified_snapshot_ids:
-                name = s_id.name
-                display_name = context_diff.snapshots[s_id].display_name(
-                    environment_naming_info,
-                    default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None,
-                    dialect=self.dialect,
-                )
-                if context_diff.directly_modified(name):
-                    direct.add(
-                        f"[direct]{display_name}"
-                        if no_diff
-                        else Syntax(f"{display_name}\n{context_diff.text_diff(name)}", "sql")
-                    )
-                elif context_diff.indirectly_modified(name):
-                    indirect.add(f"[indirect]{display_name}")
-                elif context_diff.metadata_updated(name):
-                    metadata.add(
-                        f"[metadata]{display_name}"
-                        if no_diff
-                        else Syntax(f"{display_name}\n{context_diff.text_diff(name)}", "sql")
-                    )
+            tree = self._add_modified_models(
+                context_diff,
+                modified_snapshot_ids,
+                tree,
+                environment_naming_info,
+                default_catalog,
+                no_diff,
+            )
 
-            if direct.children:
-                tree.add(direct)
-            if indirect.children:
-                tree.add(self._limit_model_names(indirect, self.verbosity))
-            if metadata.children:
-                tree.add(metadata)
         self._print(tree)
 
+    def _add_modified_models(
+        self,
+        context_diff: ContextDiff,
+        modified_snapshot_ids: t.Set[SnapshotId],
+        tree: Tree,
+        environment_naming_info: EnvironmentNamingInfo,
+        default_catalog: t.Optional[str] = None,
+        no_diff: bool = True,
+    ) -> Tree:
+        direct = Tree("[bold][direct]Directly Modified:")
+        indirect = Tree("[bold][indirect]Indirectly Modified:")
+        metadata = Tree("[bold][metadata]Metadata Updated:")
+        for s_id in modified_snapshot_ids:
+            name = s_id.name
+            display_name = context_diff.snapshots[s_id].display_name(
+                environment_naming_info,
+                default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None,
+                dialect=self.dialect,
+            )
+            if context_diff.directly_modified(name):
+                direct.add(
+                    f"[direct]{display_name}"
+                    if no_diff
+                    else Syntax(f"{display_name}\n{context_diff.text_diff(name)}", "sql")
+                )
+            elif context_diff.indirectly_modified(name):
+                indirect.add(f"[indirect]{display_name}")
+            elif context_diff.metadata_updated(name):
+                metadata.add(
+                    f"[metadata]{display_name}"
+                    if no_diff
+                    else Syntax(f"{display_name}\n{context_diff.text_diff(name)}", "sql")
+                )
+        if direct.children:
+            tree.add(direct)
+        if indirect.children:
+            tree.add(self._limit_model_names(indirect, self.verbosity))
+        if metadata.children:
+            tree.add(metadata)
+        return tree
+
     def _show_options_after_categorization(
         self,
         plan_builder: PlanBuilder,
@@ -1882,6 +1979,71 @@ def loading_stop(self, id: uuid.UUID) -> None:
         self.loading_status[id].stop()
         del self.loading_status[id]
 
+    def show_table_diff_details(
+        self,
+        models_to_diff: t.List[str],
+    ) -> None:
+        """Display information about which tables are going to be diffed"""
+
+        if models_to_diff:
+            m_tree = Tree("\n[b]Models to compare:")
+            for m in models_to_diff:
+                m_tree.add(f"[{self.TABLE_DIFF_SOURCE_BLUE}]{m}[/{self.TABLE_DIFF_SOURCE_BLUE}]")
+            self._print(m_tree)
+            self._print("")
+
+    def start_table_diff_progress(self, models_to_diff: int) -> None:
+        if not self.table_diff_progress:
+            self.table_diff_progress = make_progress_bar(
+                "Calculating model differences", self.console
+            )
+            self.table_diff_model_progress = Progress(
+                TextColumn("{task.fields[view_name]}", justify="right"),
+                SpinnerColumn(spinner_name="simpleDots"),
+                console=self.console,
+            )
+
+            progress_table = Table.grid()
+            progress_table.add_row(self.table_diff_progress)
+            progress_table.add_row(self.table_diff_model_progress)
+
+            self.table_diff_progress_live = Live(progress_table, refresh_per_second=10)
+            self.table_diff_progress_live.start()
+
+            self.table_diff_model_task = self.table_diff_progress.add_task(
+                "Diffing", total=models_to_diff
+            )
+
+    def start_table_diff_model_progress(self, model: str) -> None:
+        if self.table_diff_model_progress and model not in self.table_diff_model_tasks:
+            self.table_diff_model_tasks[model] = self.table_diff_model_progress.add_task(
+                f"Diffing {model}...",
+                view_name=model,
+                total=1,
+            )
+
+    def update_table_diff_progress(self, model: str) -> None:
+        if self.table_diff_progress:
+            self.table_diff_progress.update(self.table_diff_model_task, refresh=True, advance=1)
+        if self.table_diff_model_progress and model in self.table_diff_model_tasks:
+            model_task_id = self.table_diff_model_tasks[model]
+            self.table_diff_model_progress.remove_task(model_task_id)
+
+    def stop_table_diff_progress(self, success: bool) -> None:
+        if self.table_diff_progress_live:
+            self.table_diff_progress_live.stop()
+            self.table_diff_progress_live = None
+            self.log_status_update("")
+
+            if success:
+                self.log_success("Table diff completed successfully!")
+            else:
+                self.log_error("Table diff failed!")
+
+        self.table_diff_progress = None
+        self.table_diff_model_progress = None
+        self.table_diff_model_tasks = {}
+
     def show_table_diff_summary(self, table_diff: TableDiff) -> None:
         tree = Tree("\n[b]Table Diff")
 
@@ -1897,7 +2059,9 @@ def show_table_diff_summary(self, table_diff: TableDiff) -> None:
             )
             envs.add(source)
 
-            target = Tree(f"Target: [green]{table_diff.target_alias}[/green]")
+            target = Tree(
+                f"Target: [{self.TABLE_DIFF_TARGET_GREEN}]{table_diff.target_alias}[/{self.TABLE_DIFF_TARGET_GREEN}]"
+            )
             envs.add(target)
             tree.add(envs)
 
@@ -1907,7 +2071,9 @@ def show_table_diff_summary(self, table_diff: TableDiff) -> None:
         tables.add(
             f"Source: [{self.TABLE_DIFF_SOURCE_BLUE}]{table_diff.source}[/{self.TABLE_DIFF_SOURCE_BLUE}]"
         )
-        tables.add(f"Target: [green]{table_diff.target}[/green]")
+        tables.add(
+            f"Target: [{self.TABLE_DIFF_TARGET_GREEN}]{table_diff.target}[/{self.TABLE_DIFF_TARGET_GREEN}]"
+        )
 
         tree.add(tables)
 
@@ -1928,7 +2094,7 @@ def show_schema_diff(self, schema_diff: SchemaDiff) -> None:
         if schema_diff.target_alias:
             target_name = schema_diff.target_alias.upper()
 
-        first_line = f"\n[b]Schema Diff Between '[{self.TABLE_DIFF_SOURCE_BLUE}]{source_name}[/{self.TABLE_DIFF_SOURCE_BLUE}]' and '[green]{target_name}[/green]'"
+        first_line = f"\n[b]Schema Diff Between '[{self.TABLE_DIFF_SOURCE_BLUE}]{source_name}[/{self.TABLE_DIFF_SOURCE_BLUE}]' and '[{self.TABLE_DIFF_TARGET_GREEN}]{target_name}[/{self.TABLE_DIFF_TARGET_GREEN}]'"
         if schema_diff.model_name:
             first_line = (
                 first_line + f" environments for model '[blue]{schema_diff.model_name}[/blue]'"
@@ -2032,7 +2198,7 @@ def show_row_diff(
 
                     column_styles = {
                         source_name: self.TABLE_DIFF_SOURCE_BLUE,
-                        target_name: "green",
+                        target_name: self.TABLE_DIFF_TARGET_GREEN,
                     }
 
                     for column, [source_column, target_column] in columns.items():
@@ -2089,6 +2255,57 @@ def show_row_diff(
                     self.console.print(f"\n[b][green]{target_name} ONLY[/green] sample rows:[/b]")
                     self.console.print(row_diff.t_sample.to_string(index=False), end="\n\n")
 
+    def show_table_diff(
+        self,
+        table_diffs: t.List[TableDiff],
+        show_sample: bool = True,
+        skip_grain_check: bool = False,
+        temp_schema: t.Optional[str] = None,
+    ) -> None:
+        """
+        Display the table diff between all mismatched tables.
+        """
+        if len(table_diffs) > 1:
+            mismatched_tables = []
+            fully_matched = []
+            for table_diff in table_diffs:
+                if (
+                    table_diff.schema_diff().source_schema == table_diff.schema_diff().target_schema
+                ) and (
+                    table_diff.row_diff(
+                        temp_schema=temp_schema, skip_grain_check=skip_grain_check
+                    ).full_match_pct
+                    == 100
+                ):
+                    fully_matched.append(table_diff)
+                else:
+                    mismatched_tables.append(table_diff)
+            table_diffs = mismatched_tables
+            if fully_matched:
+                m_tree = Tree("\n[b]Identical Tables")
+                for m in fully_matched:
+                    m_tree.add(
+                        f"[{self.TABLE_DIFF_SOURCE_BLUE}]{m.source}[/{self.TABLE_DIFF_SOURCE_BLUE}] - [{self.TABLE_DIFF_TARGET_GREEN}]{m.target}[/{self.TABLE_DIFF_TARGET_GREEN}]"
+                    )
+                self._print(m_tree)
+
+            if mismatched_tables:
+                m_tree = Tree("\n[b]Mismatched Tables")
+                for m in mismatched_tables:
+                    m_tree.add(
+                        f"[{self.TABLE_DIFF_SOURCE_BLUE}]{m.source}[/{self.TABLE_DIFF_SOURCE_BLUE}] - [{self.TABLE_DIFF_TARGET_GREEN}]{m.target}[/{self.TABLE_DIFF_TARGET_GREEN}]"
+                    )
+                self._print(m_tree)
+
+        for table_diff in table_diffs:
+            self.show_table_diff_summary(table_diff)
+            self.show_schema_diff(table_diff.schema_diff())
+            self.show_row_diff(
+                table_diff.row_diff(temp_schema=temp_schema, skip_grain_check=skip_grain_check),
+                show_sample=show_sample,
+                skip_grain_check=skip_grain_check,
+            )
+
     def print_environments(self, environments_summary: t.List[EnvironmentSummary]) -> None:
         """Prints all environment names along with expiry datetime."""
         output = [
@@ -2643,23 +2860,11 @@ def show_model_difference_summary(
             no_diff: Hide the actual SQL differences.
         """
         added_snapshots = {context_diff.snapshots[s_id] for s_id in context_diff.added}
-        added_snapshot_models = {s for s in added_snapshots if s.is_model}
-        if added_snapshot_models:
+        if any(s.is_model for s in added_snapshots):
             self._print("\n**Added Models:**")
-            added_models = sorted(added_snapshot_models)
-            list_length = len(added_models)
-            if (
-                self.verbosity < Verbosity.VERY_VERBOSE
-                and list_length > self.INDIRECTLY_MODIFIED_DISPLAY_THRESHOLD
-            ):
-                self._print(added_models[0])
-                self._print(f"- `.... {list_length - 2} more ....`\n")
-                self._print(added_models[-1])
-            else:
-                for snapshot in added_models:
-                    self._print(
-                        f"- `{snapshot.display_name(environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)}`"
-                    )
+            self._print_models_with_threshold(
+                environment_naming_info, {s for s in added_snapshots if s.is_model}, default_catalog
+            )
 
         added_snapshot_audits = {s for s in added_snapshots if s.is_audit}
         if added_snapshot_audits:
@@ -2670,23 +2875,13 @@ def show_model_difference_summary(
                 )
 
         removed_snapshot_table_infos = set(context_diff.removed_snapshots.values())
-        removed_model_snapshot_table_infos = {s for s in removed_snapshot_table_infos if s.is_model}
-        if removed_model_snapshot_table_infos:
+        if any(s.is_model for s in removed_snapshot_table_infos):
             self._print("\n**Removed Models:**")
-            removed_models = sorted(removed_model_snapshot_table_infos)
-            list_length = len(removed_models)
-            if (
-                self.verbosity < Verbosity.VERY_VERBOSE
-                and list_length > self.INDIRECTLY_MODIFIED_DISPLAY_THRESHOLD
-            ):
-                self._print(removed_models[0])
-                self._print(f"- `.... {list_length - 2} more ....`\n")
-                self._print(removed_models[-1])
-            else:
-                for snapshot_table_info in removed_models:
-                    self._print(
-                        f"- `{snapshot_table_info.display_name(environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)}`"
-                    )
+            self._print_models_with_threshold(
+                environment_naming_info,
+                {s for s in removed_snapshot_table_infos if s.is_model},
+                default_catalog,
+            )
 
         removed_audit_snapshot_table_infos = {s for s in removed_snapshot_table_infos if s.is_audit}
         if removed_audit_snapshot_table_infos:
@@ -2700,48 +2895,72 @@ def show_model_difference_summary(
             current_snapshot for current_snapshot, _ in context_diff.modified_snapshots.values()
         }
         if modified_snapshots:
-            directly_modified = []
-            indirectly_modified = []
-            metadata_modified = []
-            for snapshot in modified_snapshots:
-                if context_diff.directly_modified(snapshot.name):
-                    directly_modified.append(snapshot)
-                elif context_diff.indirectly_modified(snapshot.name):
-                    indirectly_modified.append(snapshot)
-                elif context_diff.metadata_updated(snapshot.name):
-                    metadata_modified.append(snapshot)
-            if directly_modified:
-                self._print("\n**Directly Modified:**")
-                for snapshot in sorted(directly_modified):
-                    self._print(
-                        f"- `{snapshot.display_name(environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)}`"
-                    )
-                    if not no_diff:
-                        self._print(f"```diff\n{context_diff.text_diff(snapshot.name)}\n```")
-            if indirectly_modified:
-                self._print("\n**Indirectly Modified:**")
-                indirectly_modified = sorted(indirectly_modified)
-                modified_length = len(indirectly_modified)
-                if (
-                    self.verbosity < Verbosity.VERY_VERBOSE
-                    and modified_length > self.INDIRECTLY_MODIFIED_DISPLAY_THRESHOLD
-                ):
-                    self._print(
-                        f"- `{indirectly_modified[0].display_name(environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)}`\n"
-                        f"- `.... {modified_length - 2} more ....`\n"
-                        f"- `{indirectly_modified[-1].display_name(environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)}`"
-                    )
-                else:
-                    for snapshot in indirectly_modified:
-                        self._print(
-                            f"- `{snapshot.display_name(environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)}`"
-                        )
-            if metadata_modified:
-                self._print("\n**Metadata Updated:**")
-                for snapshot in sorted(metadata_modified):
-                    self._print(
-                        f"- `{snapshot.display_name(environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)}`"
-                    )
+            self._print_modified_models(
+                context_diff, modified_snapshots, environment_naming_info, default_catalog, no_diff
+            )
+
+    def _print_models_with_threshold(
+        self,
+        environment_naming_info: EnvironmentNamingInfo,
+        snapshot_table_infos: t.Set[SnapshotInfoLike],
+        default_catalog: t.Optional[str] = None,
+    ) -> None:
+        models = sorted(snapshot_table_infos)
+        list_length = len(models)
+        if (
+            self.verbosity < Verbosity.VERY_VERBOSE
+            and list_length > self.INDIRECTLY_MODIFIED_DISPLAY_THRESHOLD
+        ):
+            self._print(
+                f"- `{models[0].display_name(environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)}`"
+            )
+            self._print(f"- `.... {list_length - 2} more ....`\n")
+            self._print(
+                f"- `{models[-1].display_name(environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)}`"
+            )
+        else:
+            for snapshot_table_info in models:
+                self._print(
+                    f"- `{snapshot_table_info.display_name(environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)}`"
+                )
+
+    def _print_modified_models(
+        self,
+        context_diff: ContextDiff,
+        modified_snapshots: t.Set[Snapshot],
+        environment_naming_info: EnvironmentNamingInfo,
+        default_catalog: t.Optional[str] = None,
+        no_diff: bool = True,
+    ) -> None:
+        directly_modified = []
+        indirectly_modified = []
+        metadata_modified = []
+        for snapshot in modified_snapshots:
+            if context_diff.directly_modified(snapshot.name):
+                directly_modified.append(snapshot)
+            elif context_diff.indirectly_modified(snapshot.name):
+                indirectly_modified.append(snapshot)
+            elif context_diff.metadata_updated(snapshot.name):
+                metadata_modified.append(snapshot)
+        if directly_modified:
+            self._print("\n**Directly Modified:**")
+            for snapshot in sorted(directly_modified):
+                self._print(
+                    f"- `{snapshot.display_name(environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)}`"
+                )
+                if not no_diff:
+                    self._print(f"```diff\n{context_diff.text_diff(snapshot.name)}\n```")
+        if indirectly_modified:
+            self._print("\n**Indirectly Modified:**")
+            self._print_models_with_threshold(
+                environment_naming_info, set(indirectly_modified), default_catalog
+            )
+        if metadata_modified:
+            self._print("\n**Metadata Updated:**")
+            for snapshot in sorted(metadata_modified):
+                self._print(
+                    f"- `{snapshot.display_name(environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)}`"
+                )
 
     def _show_missing_dates(self, plan: Plan, default_catalog: t.Optional[str]) -> None:
         """Displays the models with missing dates."""
diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py
index 18aadde20f..6e1ff1eca1 100644
--- a/sqlmesh/core/context.py
+++ b/sqlmesh/core/context.py
@@ -116,6 +116,7 @@
 )
 from sqlmesh.core.user import User
 from sqlmesh.utils import UniqueKeyDict, Verbosity
+from sqlmesh.utils.concurrency import concurrent_apply_to_values
 from sqlmesh.utils.dag import DAG
 from sqlmesh.utils.date import TimeLike, now_ds, to_timestamp, format_tz_datetime, now_timestamp
 from sqlmesh.utils.errors import (
@@ -1568,9 +1569,9 @@ def table_diff(
         self,
         source: str,
         target: str,
-        on: t.List[str] | exp.Condition | None = None,
-        skip_columns: t.List[str] | None = None,
-        model_or_snapshot: t.Optional[ModelOrSnapshot] = None,
+        on: t.Optional[t.List[str] | exp.Condition] = None,
+        skip_columns: t.Optional[t.List[str]] = None,
+        select_models: t.Optional[t.Collection[str]] = None,
         where: t.Optional[str | exp.Condition] = None,
         limit: int = 20,
         show: bool = True,
@@ -1578,7 +1579,7 @@ def table_diff(
         decimals: int = 3,
         skip_grain_check: bool = False,
         temp_schema: t.Optional[str] = None,
-    ) -> TableDiff:
+    ) -> t.List[TableDiff]:
         """Show a diff between two tables.
 
         Args:
@@ -1587,7 +1588,7 @@ def table_diff(
             on: The join condition, table aliases must be "s" and "t" for source and target.
                 If omitted, the table's grain will be used.
             skip_columns: The columns to skip when computing the table diff.
-            model_or_snapshot: The model or snapshot to use when environments are passed in.
+            select_models: The model selection expressions to diff when environments are passed in.
             where: An optional where statement to filter results.
             limit: The limit of the sample dataframe.
             show: Show the table diff output in the console.
@@ -1597,53 +1598,191 @@ def table_diff(
             temp_schema: The schema to use for temporary tables.
 
         Returns:
-            The TableDiff object containing schema and summary differences.
+            The list of TableDiff objects containing schema and summary differences.
         """
-        source_alias, target_alias = source, target
-        adapter = self.engine_adapter
+        table_diffs: t.List[TableDiff] = []
 
-        if model_or_snapshot:
-            model = self.get_model(model_or_snapshot, raise_if_missing=True)
-            adapter = self._get_engine_adapter(model.gateway)
+        # Diffs multiple or a single model across two environments
+        if select_models:
             source_env = self.state_reader.get_environment(source)
             target_env = self.state_reader.get_environment(target)
-
             if not source_env:
                 raise SQLMeshError(f"Could not find environment '{source}'")
             if not target_env:
-                raise SQLMeshError(f"Could not find environment '{target}')")
-
-            # Compare the virtual layer instead of the physical layer because the virtual layer is guaranteed to point
-            # to the correct/active snapshot for the model in the specified environment, taking into account things like dev previews
-            source = next(
-                snapshot for snapshot in source_env.snapshots if snapshot.name == model.fqn
-            ).qualified_view_name.for_environment(source_env.naming_info, adapter.dialect)
-
-            target = next(
-                snapshot for snapshot in target_env.snapshots if snapshot.name == model.fqn
-            ).qualified_view_name.for_environment(target_env.naming_info, adapter.dialect)
-
-            source_alias = source_env.name
-            target_alias = target_env.name
-
-            if not on:
-                on = []
-                for expr in [ref.expression for ref in model.all_references if ref.unique]:
-                    if isinstance(expr, exp.Tuple):
-                        on.extend(
-                            [key.this.sql(dialect=adapter.dialect) for key in expr.expressions]
+                raise SQLMeshError(f"Could not find environment '{target}'")
+            criteria = ", ".join(f"'{c}'" for c in select_models)
+            try:
+                selected_models = self._new_selector().expand_model_selections(select_models)
+                if not selected_models:
+                    self.console.log_status_update(
+                        f"No models matched the selection criteria: {criteria}"
+                    )
+            except Exception as e:
+                raise SQLMeshError(e) from e
+
+            models_to_diff: t.List[
+                t.Tuple[Model, EngineAdapter, str, str, t.Optional[t.List[str] | exp.Condition]]
+            ] = []
+            models_without_grain: t.List[Model] = []
+            source_snapshots_to_name = {
+                snapshot.name: snapshot for snapshot in source_env.snapshots
+            }
+            target_snapshots_to_name = {
+                snapshot.name: snapshot for snapshot in target_env.snapshots
+            }
+
+            for model_fqn in selected_models:
+                model = self._models[model_fqn]
+                adapter = self._get_engine_adapter(model.gateway)
+                source_snapshot = source_snapshots_to_name.get(model.fqn)
+                target_snapshot = target_snapshots_to_name.get(model.fqn)
+
+                if target_snapshot and source_snapshot:
+                    if (
+                        source_snapshot.fingerprint.data_hash
+                        != target_snapshot.fingerprint.data_hash
+                    ):
+                        # Compare the virtual layer instead of the physical layer because the virtual layer is guaranteed to point
+                        # to the correct/active snapshot for the model in the specified environment, taking into account things like dev previews
+                        source = source_snapshot.qualified_view_name.for_environment(
+                            source_env.naming_info, adapter.dialect
                         )
-                    else:
-                        # Handle a single Column or Paren expression
-                        on.append(expr.this.sql(dialect=adapter.dialect))
+                        target = target_snapshot.qualified_view_name.for_environment(
+                            target_env.naming_info, adapter.dialect
+                        )
+                        model_on = on or model.on
+                        models_to_diff.append((model, adapter, source, target, model_on))
+                        if not model_on:
+                            models_without_grain.append(model)
+
+            if models_to_diff:
+                self.console.show_table_diff_details(
+                    [model[0].name for model in models_to_diff],
+                )
+                if models_without_grain:
+                    model_names = "\n".join(
+                        f"─ {model.name} \n at '{model._path}'" for model in models_without_grain
+                    )
+                    raise SQLMeshError(
+                        f"SQLMesh doesn't know how to join the tables for the following models:\n{model_names}\n"
+                        "\nPlease specify the `grain` in each model definition. Must be unique and not null."
+ ) + + self.console.start_table_diff_progress(len(models_to_diff)) + try: + tasks_num = min(len(models_to_diff), self.concurrent_tasks) + table_diffs = concurrent_apply_to_values( + list(models_to_diff), + lambda model_info: self._model_diff( + model=model_info[0], + adapter=model_info[1], + source=model_info[2], + target=model_info[3], + on=model_info[4], + source_alias=source_env.name, + target_alias=target_env.name, + limit=limit, + decimals=decimals, + skip_columns=skip_columns, + where=where, + show=show, + temp_schema=temp_schema, + skip_grain_check=skip_grain_check, + ), + tasks_num=tasks_num, + ) + self.console.stop_table_diff_progress(success=True) + except: + self.console.stop_table_diff_progress(success=False) + raise + elif selected_models: + self.console.log_status_update( + f"No models contain differences with the selection criteria: {criteria}" + ) + + else: + table_diffs = [ + self._table_diff( + source=source, + target=target, + source_alias=source, + target_alias=target, + limit=limit, + decimals=decimals, + adapter=self.engine_adapter, + on=on, + skip_columns=skip_columns, + where=where, + ) + ] + if show: + self.console.show_table_diff(table_diffs, show_sample, skip_grain_check, temp_schema) + + return table_diffs + + def _model_diff( + self, + model: Model, + adapter: EngineAdapter, + source: str, + target: str, + source_alias: str, + target_alias: str, + limit: int, + decimals: int, + on: t.Optional[t.List[str] | exp.Condition] = None, + skip_columns: t.Optional[t.List[str]] = None, + where: t.Optional[str | exp.Condition] = None, + show: bool = True, + temp_schema: t.Optional[str] = None, + skip_grain_check: bool = False, + ) -> TableDiff: + self.console.start_table_diff_model_progress(model.name) + + table_diff = self._table_diff( + on=on, + skip_columns=skip_columns, + where=where, + limit=limit, + decimals=decimals, + model=model, + adapter=adapter, + source=source, + target=target, + source_alias=source_alias, + 
target_alias=target_alias, + ) + + if show: + # Trigger row_diff in parallel execution so it's available for ordered display later + table_diff.row_diff(temp_schema=temp_schema, skip_grain_check=skip_grain_check) + + self.console.update_table_diff_progress(model.name) + + return table_diff + + def _table_diff( + self, + source: str, + target: str, + source_alias: str, + target_alias: str, + limit: int, + decimals: int, + adapter: EngineAdapter, + on: t.Optional[t.List[str] | exp.Condition] = None, + model: t.Optional[Model] = None, + skip_columns: t.Optional[t.List[str]] = None, + where: t.Optional[str | exp.Condition] = None, + ) -> TableDiff: if not on: raise SQLMeshError( "SQLMesh doesn't know how to join the two tables. Specify the `grains` in each model definition or pass join column names in separate `-o` flags." ) - table_diff = TableDiff( + return TableDiff( adapter=adapter.with_log_level(logger.getEffectiveLevel()), source=source, target=target, @@ -1652,20 +1791,11 @@ def table_diff( where=where, source_alias=source_alias, target_alias=target_alias, - model_name=model.name if model_or_snapshot else None, - model_dialect=model.dialect if model_or_snapshot else None, limit=limit, decimals=decimals, + model_name=model.name if model else None, + model_dialect=model.dialect if model else None, ) - if show: - self.console.show_table_diff_summary(table_diff) - self.console.show_schema_diff(table_diff.schema_diff()) - self.console.show_row_diff( - table_diff.row_diff(temp_schema=temp_schema, skip_grain_check=skip_grain_check), - show_sample=show_sample, - skip_grain_check=skip_grain_check, - ) - return table_diff @python_api_analytics def get_dag( diff --git a/sqlmesh/core/model/meta.py b/sqlmesh/core/model/meta.py index 29c82bc33f..85d99992fc 100644 --- a/sqlmesh/core/model/meta.py +++ b/sqlmesh/core/model/meta.py @@ -450,6 +450,20 @@ def all_references(self) -> t.List[Reference]: Reference(model_name=self.name, expression=e, unique=True) for e in 
self.references ] + @property + def on(self) -> t.List[str]: + """The grains to be used as join condition in table_diff.""" + + on: t.List[str] = [] + for expr in [ref.expression for ref in self.all_references if ref.unique]: + if isinstance(expr, exp.Tuple): + on.extend([key.this.sql(dialect=self.dialect) for key in expr.expressions]) + else: + # Handle a single Column or Paren expression + on.append(expr.this.sql(dialect=self.dialect)) + + return on + @property def managed_columns(self) -> t.Dict[str, exp.DataType]: return getattr(self.kind, "managed_columns", {}) diff --git a/sqlmesh/magics.py b/sqlmesh/magics.py index 80554a60a2..4b4ec7f23e 100644 --- a/sqlmesh/magics.py +++ b/sqlmesh/magics.py @@ -686,6 +686,12 @@ def create_external_models(self, context: Context, line: str) -> None: default=3, help="The number of decimal places to keep when comparing floating point columns. Default: 3", ) + @argument( + "--select-model", + type=str, + nargs="*", + help="Specify one or more models to data diff. Use wildcards to diff multiple models. 
Ex: '*' (all models with applied plan diffs), 'demo.model+' (this and downstream models), 'git:feature_branch' (models with direct modifications in this branch only)", + ) @argument( "--skip-grain-check", action="store_true", @@ -700,12 +706,13 @@ def table_diff(self, context: Context, line: str) -> None: """ args = parse_argstring(self.table_diff, line) source, target = args.source_to_target.split(":") + select_models = {args.model} if args.model else args.select_model or None context.table_diff( source=source, target=target, on=args.on, skip_columns=args.skip_columns, - model_or_snapshot=args.model, + select_models=select_models, where=args.where, limit=args.limit, show_sample=args.show_sample, diff --git a/sqlmesh/utils/git.py b/sqlmesh/utils/git.py index 9a558dec9a..00410e776c 100644 --- a/sqlmesh/utils/git.py +++ b/sqlmesh/utils/git.py @@ -27,7 +27,23 @@ def _execute_list_output(self, commands: t.List[str], base_path: Path) -> t.List return [(base_path / o).absolute() for o in self._execute(commands).split("\n") if o] def _execute(self, commands: t.List[str]) -> str: - result = subprocess.run(["git"] + commands, cwd=self._work_dir, stdout=subprocess.PIPE) + result = subprocess.run( + ["git"] + commands, + cwd=self._work_dir, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=False, + ) + + # If the Git command failed, extract and raise the error message in the console + if result.returncode != 0: + stderr_output = result.stderr.decode("utf-8").strip() + error_message = next( + (line for line in stderr_output.splitlines() if line.lower().startswith("fatal:")), + stderr_output, + ) + raise RuntimeError(f"Git error: {error_message}") + return result.stdout.decode("utf-8").strip() @cached_property diff --git a/tests/core/test_table_diff.py b/tests/core/test_table_diff.py index f01ad4a6d7..11a98524e8 100644 --- a/tests/core/test_table_diff.py +++ b/tests/core/test_table_diff.py @@ -110,8 +110,8 @@ def test_data_diff(sushi_context_fixed_date, capsys, 
caplog): source="source_dev", target="target_dev", on=exp.condition("s.customer_id = t.customer_id AND s.event_date = t.event_date"), - model_or_snapshot="sushi.customer_revenue_by_day", - ) + select_models={"sushi.customer_revenue_by_day"}, + )[0] # verify queries were actually logged to the log file, this helps immensely with debugging console_output = capsys.readouterr() @@ -169,7 +169,7 @@ def test_data_diff_decimals(sushi_context_fixed_date): source="table_diff_source", target="table_diff_target", on=["key"], - ) + )[0] assert diff.row_diff().full_match_count == 3 assert diff.row_diff().partial_match_count == 0 @@ -178,7 +178,7 @@ def test_data_diff_decimals(sushi_context_fixed_date): target="table_diff_target", on=["key"], decimals=4, - ) + )[0] row_diff = diff.row_diff() joined_sample_columns = row_diff.joined_sample.columns @@ -291,9 +291,9 @@ def test_grain_check(sushi_context_fixed_date): source="source_dev", target="target_dev", on=["'key_1'", "key_2"], - model_or_snapshot="SUSHI.GRAIN_ITEMS", + select_models={"memory.sushi*"}, skip_grain_check=False, - ) + )[0] row_diff = diff.row_diff() assert row_diff.full_match_count == 7 @@ -420,9 +420,9 @@ def test_tables_and_grain_inferred_from_model(sushi_context_fixed_date: Context) sushi_context_fixed_date.plan(environment="unit_test", auto_apply=True, include_unmodified=True) table_diff = sushi_context_fixed_date.table_diff( - source="unit_test", target="prod", model_or_snapshot="sushi.waiter_revenue_by_day" - ) - + source="unit_test", target="prod", select_models={"sushi.waiter_revenue_by_day"} + )[0] + assert isinstance(table_diff, TableDiff) assert table_diff.source == "memory.sushi__unit_test.waiter_revenue_by_day" assert table_diff.target == "memory.sushi.waiter_revenue_by_day" @@ -562,3 +562,136 @@ def test_data_diff_nullable_booleans(): """ assert strip_ansi_codes(output) == expected_output.strip() + + +@pytest.mark.slow +def test_data_diff_multiple_models(sushi_context_fixed_date, capsys, caplog): + # 
Create first analytics model + expressions = d.parse( + """ + MODEL (name memory.sushi.analytics_1, kind full, grain(key), tags (finance),); + SELECT + key, + value, + FROM + (VALUES + (1, 3), + (2, 4), + ) AS t (key, value) + """ + ) + model_s = load_sql_based_model(expressions, dialect="snowflake") + sushi_context_fixed_date.upsert_model(model_s) + + # Create second analytics model from analytics_1 + expressions_2 = d.parse( + """ + MODEL (name memory.sushi.analytics_2, kind full, grain(key), tags (finance),); + SELECT + key, + value as amount, + FROM + memory.sushi.analytics_1 + """ + ) + model_s2 = load_sql_based_model(expressions_2, dialect="snowflake") + sushi_context_fixed_date.upsert_model(model_s2) + + sushi_context_fixed_date.plan( + "source_dev", + no_prompts=True, + auto_apply=True, + skip_tests=True, + start="2023-01-31", + end="2023-01-31", + ) + + # Modify first model + model = sushi_context_fixed_date.models['"MEMORY"."SUSHI"."ANALYTICS_1"'] + modified_model = model.dict() + modified_model["query"] = ( + exp.select("*") + .from_(model.query.subquery()) + .union("SELECT key, value FROM (VALUES (1, 6),(2,3),) AS t (key, value)") + ) + modified_sqlmodel = SqlModel(**modified_model) + sushi_context_fixed_date.upsert_model(modified_sqlmodel) + + # Modify second model + model2 = sushi_context_fixed_date.models['"MEMORY"."SUSHI"."ANALYTICS_2"'] + modified_model2 = model2.dict() + modified_model2["query"] = ( + exp.select("*") + .from_(model2.query.subquery()) + .union("SELECT key, amount FROM (VALUES (5, 150.2),(6,250.2),) AS t (key, amount)") + ) + modified_sqlmodel2 = SqlModel(**modified_model2) + sushi_context_fixed_date.upsert_model(modified_sqlmodel2) + + sushi_context_fixed_date.auto_categorize_changes = CategorizerConfig( + sql=AutoCategorizationMode.FULL + ) + sushi_context_fixed_date.plan( + "target_dev", + create_from="source_dev", + no_prompts=True, + auto_apply=True, + skip_tests=True, + start="2023-01-31", + end="2023-01-31", + ) + + # Get 
diffs for both models + selector = {"tag:finance & memory.sushi.analytics*"} + diffs = sushi_context_fixed_date.table_diff( + source="source_dev", + target="target_dev", + on=["key"], + select_models=selector, + skip_grain_check=False, + ) + + assert len(diffs) == 2 + + # Check analytics_1 diff + diff1 = next(d for d in diffs if "ANALYTICS_1" in d.source) + row_diff1 = diff1.row_diff() + assert row_diff1.full_match_count == 2 + assert row_diff1.full_match_pct == 50.0 + assert row_diff1.s_only_count == 0 + assert row_diff1.t_only_count == 0 + assert row_diff1.stats["join_count"] == 4 + assert row_diff1.stats["null_grain_count"] == 0 + assert row_diff1.stats["s_count"] == 4 + assert row_diff1.stats["distinct_count_s"] == 2 + assert row_diff1.stats["t_count"] == 4 + assert row_diff1.stats["distinct_count_t"] == 2 + assert row_diff1.s_sample.shape == (0, 2) + assert row_diff1.t_sample.shape == (0, 2) + + # Check analytics_2 diff + diff2 = next(d for d in diffs if "ANALYTICS_2" in d.source) + row_diff2 = diff2.row_diff() + assert row_diff2.full_match_count == 2 + assert row_diff2.full_match_pct == 40.0 + assert row_diff2.s_only_count == 0 + assert row_diff2.t_only_count == 2 + assert row_diff2.stats["join_count"] == 4 + assert row_diff2.stats["null_grain_count"] == 0 + assert row_diff2.stats["s_count"] == 4 + assert row_diff2.stats["distinct_count_s"] == 2 + assert row_diff2.stats["t_count"] == 6 + assert row_diff2.stats["distinct_count_t"] == 4 + assert row_diff2.s_sample.shape == (0, 2) + assert row_diff2.t_sample.shape == (2, 2) + + # This selector shouldn't return any diffs since both models have this tag + selector = {"^tag:finance"} + diffs = sushi_context_fixed_date.table_diff( + source="source_dev", + target="target_dev", + on=["key"], + select_models=selector, + skip_grain_check=False, + ) + assert len(diffs) == 0 diff --git a/tests/integrations/jupyter/test_magics.py b/tests/integrations/jupyter/test_magics.py index 9519668ab3..6bfc4b8df3 100644 --- 
a/tests/integrations/jupyter/test_magics.py +++ b/tests/integrations/jupyter/test_magics.py @@ -641,26 +641,10 @@ def test_table_diff(notebook, loaded_sushi_context, convert_all_html_output_to_t assert not output.stdout assert not output.stderr - assert len(output.outputs) == 5 + + assert len(output.outputs) == 1 assert convert_all_html_output_to_text(output) == [ - """Table Diff -├── Model: -│ └── sushi.top_waiters -├── Environment: -│ ├── Source: dev -│ └── Target: prod -├── Tables: -│ ├── Source: memory.sushi__dev.top_waiters -│ └── Target: memory.sushi.top_waiters -└── Join On: - └── waiter_id""", - """Schema Diff Between 'DEV' and 'PROD' environments for model 'sushi.top_waiters': -└── Schemas match""", - """Row Counts: -└── FULL MATCH: 8 rows (100.0%)""", - """COMMON ROWS column comparison stats:""", - """pct_match -revenue 100.0""", + "No models contain differences with the selection criteria: 'sushi.top_waiters'" ] diff --git a/tests/web/test_main.py b/tests/web/test_main.py index 5ecbba8d1f..99b268cf0d 100644 --- a/tests/web/test_main.py +++ b/tests/web/test_main.py @@ -520,8 +520,7 @@ def test_table_diff(client: TestClient, web_sushi_context: Context) -> None: }, ) assert response.status_code == 200 - assert "schema_diff" in response.json() - assert "row_diff" in response.json() + assert response.json() == None def test_test(client: TestClient, web_sushi_context: Context) -> None: diff --git a/web/server/api/endpoints/table_diff.py b/web/server/api/endpoints/table_diff.py index 7d0da2cc01..3439327102 100644 --- a/web/server/api/endpoints/table_diff.py +++ b/web/server/api/endpoints/table_diff.py @@ -23,17 +23,22 @@ def get_table_diff( temp_schema: t.Optional[str] = None, limit: int = 20, context: Context = Depends(get_loaded_context), -) -> TableDiff: +) -> t.Optional[TableDiff]: """Calculate differences between tables, taking into account schema and row level differences.""" - diff = context.table_diff( + table_diffs = context.table_diff( source=source, 
target=target, on=exp.condition(on) if on else None, - model_or_snapshot=model_or_snapshot, + select_models={model_or_snapshot} if model_or_snapshot else None, where=where, limit=limit, show=False, ) + + if not table_diffs: + return None + diff = table_diffs[0] if isinstance(table_diffs, list) else table_diffs + _schema_diff = diff.schema_diff() _row_diff = diff.row_diff(temp_schema=temp_schema) schema_diff = SchemaDiff(