Skip to content

Commit 82119f7

Browse files
authored
FIX: Improvement to parameter type inference and handling (#215)
### Work Item / Issue Reference <!-- IMPORTANT: Please follow the PR template guidelines below. For mssql-python maintainers: Insert your ADO Work Item ID below (e.g. AB#37452) For external contributors: Insert Github Issue number below (e.g. #149) Only one reference is required - either GitHub issue OR ADO Work Item. --> <!-- mssql-python maintainers: ADO Work Item --> > [AB#38383](https://sqlclientdrivers.visualstudio.com/c6d89619-62de-46a0-8b46-70b92a84d85e/_workitems/edit/38383) <!-- External contributors: GitHub Issue --> > GitHub Issue: #<ISSUE_NUMBER> ------------------------------------------------------------------- ### Summary <!-- Insert your summary of changes below. Minimum 10 characters required. --> This pull request refactors the logic for inferring SQL types from Python parameters in the `mssql_python/cursor.py` module, especially for batch operations using `executemany`. The main improvements are more accurate type detection for integer columns by considering the minimum and maximum values in the data, and a cleaner separation of concerns in the codebase. **Improvements to type inference and parameter handling:** * Refactored the type mapping logic in `_map_sql_type` to use `min_val` and `max_val` for integer columns, allowing for more accurate type selection based on the actual range of values in the data. * Updated `_create_parameter_types_list` to accept and forward `min_val` and `max_val`, supporting the improved type inference for batch operations. * Replaced the static method `_select_best_sample_value` with a new `_compute_column_type` method, which determines a representative sample value and computes min/max for integer columns, enhancing how types are inferred for each parameter column in `executemany`. [[1]](diffhunk://#diff-deceea46ae01082ce8400e14fa02f4b7585afb7b5ed9885338b66494f5f38280L827-L855) [[2]](diffhunk://#diff-deceea46ae01082ce8400e14fa02f4b7585afb7b5ed9885338b66494f5f38280R848-R873) * Modified the `executemany` method to use `_compute_column_type` for each parameter column, passing the computed min/max values to `_create_parameter_types_list` for better type assignment. **Code cleanup:** * Removed the now-unnecessary `_select_best_sample_value` static method, consolidating logic and reducing code duplication. <!-- ### PR Title Guide > For feature requests FEAT: (short-description) > For non-feature requests like test case updates, config updates , dependency updates etc CHORE: (short-description) > For Fix requests FIX: (short-description) > For doc update requests DOC: (short-description) > For Formatting, indentation, or styling update STYLE: (short-description) > For Refactor, without any feature changes REFACTOR: (short-description) > For release related changes, without any feature changes RELEASE: #<RELEASE_VERSION> (short-description) ### Contribution Guidelines External contributors: - Create a GitHub issue first: https://github.com/microsoft/mssql-python/issues/new - Link the GitHub issue in the "GitHub Issue" section above - Follow the PR title format and provide a meaningful summary mssql-python maintainers: - Create an ADO Work Item following internal processes - Link the ADO Work Item in the "ADO Work Item" section above - Follow the PR title format and provide a meaningful summary -->
1 parent 7d06c96 commit 82119f7

File tree

1 file changed

+49
-49
lines changed

1 file changed

+49
-49
lines changed

mssql_python/cursor.py

Lines changed: 49 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def _get_numeric_data(self, param):
222222
numeric_data.val = val
223223
return numeric_data
224224

225-
def _map_sql_type(self, param, parameters_list, i):
225+
def _map_sql_type(self, param, parameters_list, i, min_val=None, max_val=None):
226226
"""
227227
Map a Python data type to the corresponding SQL type,
228228
C type, Column size, and Decimal digits.
@@ -246,23 +246,27 @@ def _map_sql_type(self, param, parameters_list, i):
246246
return ddbc_sql_const.SQL_BIT.value, ddbc_sql_const.SQL_C_BIT.value, 1, 0, False
247247

248248
if isinstance(param, int):
249-
if 0 <= param <= 255:
249+
# Use min_val/max_val if available
250+
value_to_check = max_val if max_val is not None else param
251+
min_to_check = min_val if min_val is not None else param
252+
253+
if 0 <= min_to_check and value_to_check <= 255:
250254
return (
251255
ddbc_sql_const.SQL_TINYINT.value,
252256
ddbc_sql_const.SQL_C_TINYINT.value,
253257
3,
254258
0,
255259
False,
256260
)
257-
if -32768 <= param <= 32767:
261+
if -32768 <= min_to_check and value_to_check <= 32767:
258262
return (
259263
ddbc_sql_const.SQL_SMALLINT.value,
260264
ddbc_sql_const.SQL_C_SHORT.value,
261265
5,
262266
0,
263267
False,
264268
)
265-
if -2147483648 <= param <= 2147483647:
269+
if -2147483648 <= min_to_check and value_to_check <= 2147483647:
266270
return (
267271
ddbc_sql_const.SQL_INTEGER.value,
268272
ddbc_sql_const.SQL_C_LONG.value,
@@ -514,20 +518,18 @@ def _check_closed(self):
514518
driver_error="Operation cannot be performed: The cursor is closed.",
515519
ddbc_error=""
516520
)
517-
518-
def _create_parameter_types_list(self, parameter, param_info, parameters_list, i):
521+
522+
def _create_parameter_types_list(self, parameter, param_info, parameters_list, i, min_val=None, max_val=None):
519523
"""
520524
Maps parameter types for the given parameter.
521-
522525
Args:
523526
parameter: parameter to bind.
524-
525527
Returns:
526528
paraminfo.
527529
"""
528530
paraminfo = param_info()
529531
sql_type, c_type, column_size, decimal_digits, is_dae = self._map_sql_type(
530-
parameter, parameters_list, i
532+
parameter, parameters_list, i, min_val=min_val, max_val=max_val
531533
)
532534
paraminfo.paramCType = c_type
533535
paraminfo.paramSQLType = sql_type
@@ -833,37 +835,6 @@ def execute(
833835
# Return self for method chaining
834836
return self
835837

836-
@staticmethod
837-
def _select_best_sample_value(column):
838-
"""
839-
Selects the most representative non-null value from a column for type inference.
840-
841-
This is used during executemany() to infer SQL/C types based on actual data,
842-
preferring a non-null value that is not the first row to avoid bias from placeholder defaults.
843-
844-
Args:
845-
column: List of values in the column.
846-
"""
847-
non_nulls = [v for v in column if v is not None]
848-
if not non_nulls:
849-
return None
850-
if all(isinstance(v, int) for v in non_nulls):
851-
# Pick the value with the widest range (min/max)
852-
return max(non_nulls, key=lambda v: abs(v))
853-
if all(isinstance(v, float) for v in non_nulls):
854-
return 0.0
855-
if all(isinstance(v, decimal.Decimal) for v in non_nulls):
856-
return max(non_nulls, key=lambda d: len(d.as_tuple().digits))
857-
if all(isinstance(v, str) for v in non_nulls):
858-
return max(non_nulls, key=lambda s: len(str(s)))
859-
if all(isinstance(v, datetime.datetime) for v in non_nulls):
860-
return datetime.datetime.now()
861-
if all(isinstance(v, (bytes, bytearray)) for v in non_nulls):
862-
return max(non_nulls, key=lambda b: len(b))
863-
if all(isinstance(v, datetime.date) for v in non_nulls):
864-
return datetime.date.today()
865-
return non_nulls[0] # fallback
866-
867838
def _transpose_rowwise_to_columnwise(self, seq_of_parameters: list) -> list:
868839
"""
869840
Convert list of rows (row-wise) into list of columns (column-wise),
@@ -882,6 +853,32 @@ def _transpose_rowwise_to_columnwise(self, seq_of_parameters: list) -> list:
882853
for i, val in enumerate(row):
883854
columnwise[i].append(val)
884855
return columnwise
856+
857+
def _compute_column_type(self, column):
858+
"""
859+
Determine representative value and integer min/max for a column.
860+
861+
Returns:
862+
sample_value: Representative value for type inference and modified_row.
863+
min_val: Minimum for integers (None otherwise).
864+
max_val: Maximum for integers (None otherwise).
865+
"""
866+
non_nulls = [v for v in column if v is not None]
867+
if not non_nulls:
868+
return None, None, None
869+
870+
int_values = [v for v in non_nulls if isinstance(v, int)]
871+
if int_values:
872+
min_val, max_val = min(int_values), max(int_values)
873+
sample_value = max(int_values, key=abs)
874+
return sample_value, min_val, max_val
875+
876+
sample_value = None
877+
for v in non_nulls:
878+
if not sample_value or (hasattr(v, '__len__') and len(v) > len(sample_value)):
879+
sample_value = v
880+
881+
return sample_value, None, None
885882

886883
def executemany(self, operation: str, seq_of_parameters: list) -> None:
887884
"""
@@ -896,10 +893,10 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None:
896893
"""
897894
self._check_closed()
898895
self._reset_cursor()
899-
896+
900897
# Clear any previous messages
901898
self.messages = []
902-
899+
903900
if not seq_of_parameters:
904901
self.rowcount = 0
905902
return
@@ -910,14 +907,17 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None:
910907

911908
for col_index in range(param_count):
912909
column = [row[col_index] for row in seq_of_parameters]
913-
sample_value = self._select_best_sample_value(column)
914-
dummy_row = list(seq_of_parameters[0])
915-
parameters_type.append(
916-
self._create_parameter_types_list(sample_value, param_info, dummy_row, col_index)
910+
sample_value, min_val, max_val = self._compute_column_type(column)
911+
modified_row = list(seq_of_parameters[0])
912+
modified_row[col_index] = sample_value
913+
# sending original values for all rows here, we may change this if any inconsistent behavior is observed
914+
paraminfo = self._create_parameter_types_list(
915+
sample_value, param_info, modified_row, col_index, min_val=min_val, max_val=max_val
917916
)
917+
parameters_type.append(paraminfo)
918918

919919
columnwise_params = self._transpose_rowwise_to_columnwise(seq_of_parameters)
920-
log('info', "Executing batch query with %d parameter sets:\n%s",
920+
log('debug', "Executing batch query with %d parameter sets:\n%s",
921921
len(seq_of_parameters), "\n".join(f" {i+1}: {tuple(p) if isinstance(p, (list, tuple)) else p}" for i, p in enumerate(seq_of_parameters))
922922
)
923923

@@ -934,11 +934,11 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None:
934934
# Capture any diagnostic messages after execution
935935
if self.hstmt:
936936
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))
937-
937+
938938
self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt)
939939
self.last_executed_stmt = operation
940940
self._initialize_description()
941-
941+
942942
if self.description:
943943
self.rowcount = -1
944944
self._reset_rownumber()

0 commit comments

Comments
 (0)