diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 048ea43a..970eff1a 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -222,7 +222,7 @@ def _get_numeric_data(self, param): numeric_data.val = val return numeric_data - def _map_sql_type(self, param, parameters_list, i): + def _map_sql_type(self, param, parameters_list, i, min_val=None, max_val=None): """ Map a Python data type to the corresponding SQL type, C type, Column size, and Decimal digits. @@ -246,7 +246,11 @@ def _map_sql_type(self, param, parameters_list, i): return ddbc_sql_const.SQL_BIT.value, ddbc_sql_const.SQL_C_BIT.value, 1, 0, False if isinstance(param, int): - if 0 <= param <= 255: + # Use min_val/max_val if available + value_to_check = max_val if max_val is not None else param + min_to_check = min_val if min_val is not None else param + + if 0 <= min_to_check and value_to_check <= 255: return ( ddbc_sql_const.SQL_TINYINT.value, ddbc_sql_const.SQL_C_TINYINT.value, @@ -254,7 +258,7 @@ def _map_sql_type(self, param, parameters_list, i): 0, False, ) - if -32768 <= param <= 32767: + if -32768 <= min_to_check and value_to_check <= 32767: return ( ddbc_sql_const.SQL_SMALLINT.value, ddbc_sql_const.SQL_C_SHORT.value, @@ -262,7 +266,7 @@ def _map_sql_type(self, param, parameters_list, i): 0, False, ) - if -2147483648 <= param <= 2147483647: + if -2147483648 <= min_to_check and value_to_check <= 2147483647: return ( ddbc_sql_const.SQL_INTEGER.value, ddbc_sql_const.SQL_C_LONG.value, @@ -514,20 +518,18 @@ def _check_closed(self): driver_error="Operation cannot be performed: The cursor is closed.", ddbc_error="" ) - - def _create_parameter_types_list(self, parameter, param_info, parameters_list, i): + + def _create_parameter_types_list(self, parameter, param_info, parameters_list, i, min_val=None, max_val=None): """ Maps parameter types for the given parameter. - Args: parameter: parameter to bind. - Returns: paraminfo. """ paraminfo = param_info() sql_type, c_type, column_size, decimal_digits, is_dae = self._map_sql_type( - parameter, parameters_list, i + parameter, parameters_list, i, min_val=min_val, max_val=max_val ) paraminfo.paramCType = c_type paraminfo.paramSQLType = sql_type @@ -833,37 +835,6 @@ def execute( # Return self for method chaining return self - @staticmethod - def _select_best_sample_value(column): - """ - Selects the most representative non-null value from a column for type inference. - - This is used during executemany() to infer SQL/C types based on actual data, - preferring a non-null value that is not the first row to avoid bias from placeholder defaults. - - Args: - column: List of values in the column. - """ - non_nulls = [v for v in column if v is not None] - if not non_nulls: - return None - if all(isinstance(v, int) for v in non_nulls): - # Pick the value with the widest range (min/max) - return max(non_nulls, key=lambda v: abs(v)) - if all(isinstance(v, float) for v in non_nulls): - return 0.0 - if all(isinstance(v, decimal.Decimal) for v in non_nulls): - return max(non_nulls, key=lambda d: len(d.as_tuple().digits)) - if all(isinstance(v, str) for v in non_nulls): - return max(non_nulls, key=lambda s: len(str(s))) - if all(isinstance(v, datetime.datetime) for v in non_nulls): - return datetime.datetime.now() - if all(isinstance(v, (bytes, bytearray)) for v in non_nulls): - return max(non_nulls, key=lambda b: len(b)) - if all(isinstance(v, datetime.date) for v in non_nulls): - return datetime.date.today() - return non_nulls[0] # fallback - def _transpose_rowwise_to_columnwise(self, seq_of_parameters: list) -> list: """ Convert list of rows (row-wise) into list of columns (column-wise), @@ -882,6 +853,32 @@ def _transpose_rowwise_to_columnwise(self, seq_of_parameters: list) -> list: for i, val in enumerate(row): columnwise[i].append(val) return columnwise + + def _compute_column_type(self, column): + """ + Determine representative value and integer min/max for a column. + + Returns: + sample_value: Representative value for type inference and modified_row. + min_val: Minimum for integers (None otherwise). + max_val: Maximum for integers (None otherwise). + """ + non_nulls = [v for v in column if v is not None] + if not non_nulls: + return None, None, None + + int_values = [v for v in non_nulls if isinstance(v, int)] + if int_values: + min_val, max_val = min(int_values), max(int_values) + sample_value = max(int_values, key=abs) + return sample_value, min_val, max_val + + sample_value = None + for v in non_nulls: + if not sample_value or (hasattr(v, '__len__') and len(v) > len(sample_value)): + sample_value = v + + return sample_value, None, None def executemany(self, operation: str, seq_of_parameters: list) -> None: """ @@ -896,10 +893,10 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None: """ self._check_closed() self._reset_cursor() - + # Clear any previous messages self.messages = [] - + if not seq_of_parameters: self.rowcount = 0 return @@ -910,14 +907,17 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None: for col_index in range(param_count): column = [row[col_index] for row in seq_of_parameters] - sample_value = self._select_best_sample_value(column) - dummy_row = list(seq_of_parameters[0]) - parameters_type.append( - self._create_parameter_types_list(sample_value, param_info, dummy_row, col_index) + sample_value, min_val, max_val = self._compute_column_type(column) + modified_row = list(seq_of_parameters[0]) + modified_row[col_index] = sample_value + # sending original values for all rows here, we may change this if any inconsistent behavior is observed + paraminfo = self._create_parameter_types_list( + sample_value, param_info, modified_row, col_index, min_val=min_val, max_val=max_val ) + parameters_type.append(paraminfo) columnwise_params = self._transpose_rowwise_to_columnwise(seq_of_parameters) - log('info', "Executing batch query with %d parameter sets:\n%s", + log('debug', "Executing batch query with %d parameter sets:\n%s", len(seq_of_parameters), "\n".join(f" {i+1}: {tuple(p) if isinstance(p, (list, tuple)) else p}" for i, p in enumerate(seq_of_parameters)) ) @@ -934,11 +934,11 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None: # Capture any diagnostic messages after execution if self.hstmt: self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt)) - + self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt) self.last_executed_stmt = operation self._initialize_description() - + if self.description: self.rowcount = -1 self._reset_rownumber()