Skip to content
Merged
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 48 additions & 47 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def _get_numeric_data(self, param):
numeric_data.val = val
return numeric_data

def _map_sql_type(self, param, parameters_list, i):
def _map_sql_type(self, param, parameters_list, i, min_val=None, max_val=None):
"""
Map a Python data type to the corresponding SQL type,
C type, Column size, and Decimal digits.
Expand All @@ -242,23 +242,27 @@ def _map_sql_type(self, param, parameters_list, i):
return ddbc_sql_const.SQL_BIT.value, ddbc_sql_const.SQL_C_BIT.value, 1, 0, False

if isinstance(param, int):
if 0 <= param <= 255:
# Use min_val/max_val if available
value_to_check = max_val if max_val is not None else param
min_to_check = min_val if min_val is not None else param

if 0 <= min_to_check and value_to_check <= 255:
return (
ddbc_sql_const.SQL_TINYINT.value,
ddbc_sql_const.SQL_C_TINYINT.value,
3,
0,
False,
)
if -32768 <= param <= 32767:
if -32768 <= min_to_check and value_to_check <= 32767:
return (
ddbc_sql_const.SQL_SMALLINT.value,
ddbc_sql_const.SQL_C_SHORT.value,
5,
0,
False,
)
if -2147483648 <= param <= 2147483647:
if -2147483648 <= min_to_check and value_to_check <= 2147483647:
return (
ddbc_sql_const.SQL_INTEGER.value,
ddbc_sql_const.SQL_C_LONG.value,
Expand Down Expand Up @@ -412,7 +416,7 @@ def _map_sql_type(self, param, parameters_list, i):
0,
False,
)

if isinstance(param, datetime.datetime):
return (
ddbc_sql_const.SQL_TIMESTAMP.value,
Expand Down Expand Up @@ -505,20 +509,18 @@ def _check_closed(self):
driver_error="Operation cannot be performed: the cursor is closed.",
ddbc_error="Operation cannot be performed: the cursor is closed."
)

def _create_parameter_types_list(self, parameter, param_info, parameters_list, i):
def _create_parameter_types_list(self, parameter, param_info, parameters_list, i, min_val=None, max_val=None):
"""
Maps parameter types for the given parameter.

Args:
parameter: parameter to bind.

Returns:
paraminfo.
"""
paraminfo = param_info()
sql_type, c_type, column_size, decimal_digits, is_dae = self._map_sql_type(
parameter, parameters_list, i
parameter, parameters_list, i, min_val=min_val, max_val=max_val
)
paraminfo.paramCType = c_type
paraminfo.paramSQLType = sql_type
Expand Down Expand Up @@ -824,35 +826,6 @@ def execute(
# Return self for method chaining
return self

@staticmethod
def _select_best_sample_value(column):
"""
Selects the most representative non-null value from a column for type inference.

This is used during executemany() to infer SQL/C types based on actual data,
preferring a non-null value that is not the first row to avoid bias from placeholder defaults.

Args:
column: List of values in the column.
"""
non_nulls = [v for v in column if v is not None]
if not non_nulls:
return None
if all(isinstance(v, int) for v in non_nulls):
# Pick the value with the widest range (min/max)
return max(non_nulls, key=lambda v: abs(v))
if all(isinstance(v, float) for v in non_nulls):
return 0.0
if all(isinstance(v, decimal.Decimal) for v in non_nulls):
return max(non_nulls, key=lambda d: len(d.as_tuple().digits))
if all(isinstance(v, str) for v in non_nulls):
return max(non_nulls, key=lambda s: len(str(s)))
if all(isinstance(v, datetime.datetime) for v in non_nulls):
return datetime.datetime.now()
if all(isinstance(v, datetime.date) for v in non_nulls):
return datetime.date.today()
return non_nulls[0] # fallback

def _transpose_rowwise_to_columnwise(self, seq_of_parameters: list) -> list:
"""
Convert list of rows (row-wise) into list of columns (column-wise),
Expand All @@ -871,6 +844,32 @@ def _transpose_rowwise_to_columnwise(self, seq_of_parameters: list) -> list:
for i, val in enumerate(row):
columnwise[i].append(val)
return columnwise

def _compute_column_type(self, column):
"""
Determine representative value and integer min/max for a column.

Returns:
sample_value: Representative value for type inference and modified_row.
min_val: Minimum for integers (None otherwise).
max_val: Maximum for integers (None otherwise).
"""
non_nulls = [v for v in column if v is not None]
if not non_nulls:
return None, None, None

int_values = [v for v in non_nulls if isinstance(v, int)]
if int_values:
min_val, max_val = min(int_values), max(int_values)
sample_value = max(int_values, key=abs)
return sample_value, min_val, max_val

sample_value = None
for v in non_nulls:
if not sample_value or (hasattr(v, '__len__') and len(v) > len(sample_value)):
sample_value = v

return sample_value, None, None

def executemany(self, operation: str, seq_of_parameters: list) -> None:
"""
Expand All @@ -885,10 +884,10 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None:
"""
self._check_closed()
self._reset_cursor()

# Clear any previous messages
self.messages = []

if not seq_of_parameters:
self.rowcount = 0
return
Expand All @@ -899,11 +898,13 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None:

for col_index in range(param_count):
column = [row[col_index] for row in seq_of_parameters]
sample_value = self._select_best_sample_value(column)
dummy_row = list(seq_of_parameters[0])
parameters_type.append(
self._create_parameter_types_list(sample_value, param_info, dummy_row, col_index)
sample_value, min_val, max_val = self._compute_column_type(column)
modified_row = list(seq_of_parameters[0])
modified_row[col_index] = sample_value
paraminfo = self._create_parameter_types_list(
sample_value, param_info, modified_row, col_index, min_val=min_val, max_val=max_val
)
parameters_type.append(paraminfo)

columnwise_params = self._transpose_rowwise_to_columnwise(seq_of_parameters)
log('info', "Executing batch query with %d parameter sets:\n%s",
Expand All @@ -923,11 +924,11 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None:
# Capture any diagnostic messages after execution
if self.hstmt:
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))

self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt)
self.last_executed_stmt = operation
self._initialize_description()

if self.description:
self.rowcount = -1
self._reset_rownumber()
Expand Down