diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py
index 78a011421..5bb191ca2 100755
--- a/src/databricks/sql/client.py
+++ b/src/databricks/sql/client.py
@@ -200,6 +200,12 @@ def read(self) -> Optional[OAuthToken]:
                 STRUCT is returned as Dict[str, Any]
                 ARRAY is returned as numpy.ndarray
             When False, complex types are returned as a strings. These are generally deserializable as JSON.
+        :param enable_metric_view_metadata: `bool`, optional (default is False)
+            When True, enables metric view metadata support by setting the
+            spark.sql.thriftserver.metadata.metricview.enabled session configuration.
+            This allows:
+            1. cursor.tables() to return the METRIC_VIEW table type
+            2. cursor.columns() to return the "measure" column type
         """

         # Internal arguments in **kwargs:
@@ -248,6 +254,14 @@ def read(self) -> Optional[OAuthToken]:
             access_token_kv = {"access_token": access_token}
             kwargs = {**kwargs, **access_token_kv}

+        enable_metric_view_metadata = kwargs.get("enable_metric_view_metadata", False)
+        if enable_metric_view_metadata:
+            if session_configuration is None:
+                session_configuration = {}
+            session_configuration[
+                "spark.sql.thriftserver.metadata.metricview.enabled"
+            ] = "true"
+
         self.disable_pandas = kwargs.get("_disable_pandas", False)
         self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True)
         self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py
index c135a846b..1d70ec4c4 100644
--- a/tests/unit/test_session.py
+++ b/tests/unit/test_session.py
@@ -163,6 +163,17 @@ def test_configuration_passthrough(self, mock_client_class):
         call_kwargs = mock_client_class.return_value.open_session.call_args[1]
         assert call_kwargs["session_configuration"] == mock_session_config

+    @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
+    def test_enable_metric_view_metadata_parameter(self, mock_client_class):
+        """Test that the enable_metric_view_metadata parameter sets the correct session configuration."""
+        databricks.sql.connect(
+            enable_metric_view_metadata=True, **self.DUMMY_CONNECTION_ARGS
+        )
+
+        call_kwargs = mock_client_class.return_value.open_session.call_args[1]
+        expected_config = {"spark.sql.thriftserver.metadata.metricview.enabled": "true"}
+        assert call_kwargs["session_configuration"] == expected_config
+
     @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
     def test_initial_namespace_passthrough(self, mock_client_class):
         mock_cat = Mock()
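
For context, a minimal usage sketch of the new connection flag. The hostname, HTTP path, and token values are placeholders, and the keyword names passed to cursor.tables() are assumptions for illustration; only enable_metric_view_metadata comes from this change.

    # Sketch only: credentials are placeholders; metadata-call keyword names are assumed.
    from databricks import sql

    with sql.connect(
        server_hostname="<workspace-host>",     # placeholder
        http_path="<warehouse-http-path>",      # placeholder
        access_token="<token>",                 # placeholder
        enable_metric_view_metadata=True,       # flag introduced by this change
    ) as connection:
        with connection.cursor() as cursor:
            # With the session conf set, metric views should surface in metadata
            # results with a METRIC_VIEW table type, and measures with a
            # "measure" column type.
            cursor.tables(catalog_name="main", schema_name="default")  # assumed keyword names
            for row in cursor.fetchall():
                print(row)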