diff --git a/pinotdb/sqlalchemy.py b/pinotdb/sqlalchemy.py index a5786d1..62daa56 100644 --- a/pinotdb/sqlalchemy.py +++ b/pinotdb/sqlalchemy.py @@ -368,6 +368,7 @@ def get_type(data_type, field_size): "long": types.BigInteger, "float": types.Float, "double": types.Numeric, + "big_decimal": types.Numeric, # BOOLEAN, is added after release 0.7.1. # In release 0.7.1 and older releases, BOOLEAN is equivalent to STRING. "boolean": types.Boolean, diff --git a/tests/unit/test_sqlalchemy.py b/tests/unit/test_sqlalchemy.py index de346fa..0752f22 100644 --- a/tests/unit/test_sqlalchemy.py +++ b/tests/unit/test_sqlalchemy.py @@ -11,11 +11,12 @@ """ from unittest import TestCase +import json import responses from sqlalchemy import ( BigInteger, Column, Integer, MetaData, String, Table, - column, select, + column, select, types, ) from sqlalchemy.engine import make_url @@ -215,6 +216,31 @@ def test_gets_columns_with_different_default_values(self): }, ]) + @responses.activate + def test_gets_columns_with_big_decimal_type(self): + table_name = 'some-table' + url = f'{self.dialect._controller}/tables/{table_name}/schema' + responses.get(url, json={ + 'tables': [table_name], + 'timeFieldSpec': {}, + 'dimensionFieldSpecs': [{ + 'name': 'price', + 'dataType': 'BIG_DECIMAL', + 'defaultNullValue': '0.00', + }], + }) + + columns = self.dialect.get_columns('conn', table_name) + + self.assertEqual(columns, [ + { + 'default': '0.00', + 'name': 'price', + 'nullable': True, + 'type': types.Numeric, + }, + ]) + @responses.activate def test_gets_columns_with_time_spec(self): table_name = 'some-table' @@ -269,18 +295,29 @@ def test_gets_unique_constraints(self): self.assertEqual(result, []) - def test_gets_view_definition(self): - self.assertIsNone(self.dialect.get_view_definition('conn', 'table')) - - def test_cannot_rollback(self): - self.assertIsNone(self.dialect.do_rollback('conn')) - def test_checks_unicode_returns(self): self.assertTrue(self.dialect._check_unicode_returns('conn')) def test_checks_unicode_description(self): self.assertTrue(self.dialect._check_unicode_description('conn')) + def test_json_deserializer(self): + # Test with string input + self.assertEqual(self.dialect._json_deserializer('{"key": "value"}'), {"key": "value"}) + # Test with bytes input + self.assertEqual(self.dialect._json_deserializer(b'{"key": "value"}'), {"key": "value"}) + # Test with already parsed JSON + self.assertEqual(self.dialect._json_deserializer({"key": "value"}), {"key": "value"}) + # Test with non-JSON string - should raise JSONDecodeError + with self.assertRaises(json.JSONDecodeError): + self.dialect._json_deserializer("not json") + + def test_get_view_definition(self): + self.assertIsNone(self.dialect.get_view_definition('conn', 'view_name')) + + def test_do_rollback(self): + self.assertIsNone(self.dialect.do_rollback('conn')) + class PinotMultiStageDialectTest(PinotTestCase): def setUp(self) -> None: