Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 73 additions & 45 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def _get_numeric_data(self, param):
numeric_data.val = val
return numeric_data

def _map_sql_type(self, param, parameters_list, i):
def _map_sql_type(self, param, parameters_list, i, min_val=None, max_val=None):
"""
Map a Python data type to the corresponding SQL type,
C type, Column size, and Decimal digits.
Expand All @@ -242,23 +242,27 @@ def _map_sql_type(self, param, parameters_list, i):
return ddbc_sql_const.SQL_BIT.value, ddbc_sql_const.SQL_C_BIT.value, 1, 0, False

if isinstance(param, int):
if 0 <= param <= 255:
# Use min_val/max_val if available
value_to_check = max_val if max_val is not None else param
min_to_check = min_val if min_val is not None else param

if 0 <= min_to_check and value_to_check <= 255:
return (
ddbc_sql_const.SQL_TINYINT.value,
ddbc_sql_const.SQL_C_TINYINT.value,
3,
0,
False,
)
if -32768 <= param <= 32767:
if -32768 <= min_to_check and value_to_check <= 32767:
return (
ddbc_sql_const.SQL_SMALLINT.value,
ddbc_sql_const.SQL_C_SHORT.value,
5,
0,
False,
)
if -2147483648 <= param <= 2147483647:
if -2147483648 <= min_to_check and value_to_check <= 2147483647:
return (
ddbc_sql_const.SQL_INTEGER.value,
ddbc_sql_const.SQL_C_LONG.value,
Expand Down Expand Up @@ -412,7 +416,7 @@ def _map_sql_type(self, param, parameters_list, i):
0,
False,
)

if isinstance(param, datetime.datetime):
return (
ddbc_sql_const.SQL_TIMESTAMP.value,
Expand Down Expand Up @@ -505,20 +509,18 @@ def _check_closed(self):
driver_error="Operation cannot be performed: the cursor is closed.",
ddbc_error="Operation cannot be performed: the cursor is closed."
)

def _create_parameter_types_list(self, parameter, param_info, parameters_list, i):
def _create_parameter_types_list(self, parameter, param_info, parameters_list, i, min_val=None, max_val=None):
"""
Maps parameter types for the given parameter.

Args:
parameter: parameter to bind.

Returns:
paraminfo.
"""
paraminfo = param_info()
sql_type, c_type, column_size, decimal_digits, is_dae = self._map_sql_type(
parameter, parameters_list, i
parameter, parameters_list, i, min_val=min_val, max_val=max_val
)
paraminfo.paramCType = c_type
paraminfo.paramSQLType = sql_type
Expand Down Expand Up @@ -824,35 +826,6 @@ def execute(
# Return self for method chaining
return self

@staticmethod
def _select_best_sample_value(column):
"""
Selects the most representative non-null value from a column for type inference.

This is used during executemany() to infer SQL/C types based on actual data,
preferring a non-null value that is not the first row to avoid bias from placeholder defaults.

Args:
column: List of values in the column.
"""
non_nulls = [v for v in column if v is not None]
if not non_nulls:
return None
if all(isinstance(v, int) for v in non_nulls):
# Pick the value with the widest range (min/max)
return max(non_nulls, key=lambda v: abs(v))
if all(isinstance(v, float) for v in non_nulls):
return 0.0
if all(isinstance(v, decimal.Decimal) for v in non_nulls):
return max(non_nulls, key=lambda d: len(d.as_tuple().digits))
if all(isinstance(v, str) for v in non_nulls):
return max(non_nulls, key=lambda s: len(str(s)))
if all(isinstance(v, datetime.datetime) for v in non_nulls):
return datetime.datetime.now()
if all(isinstance(v, datetime.date) for v in non_nulls):
return datetime.date.today()
return non_nulls[0] # fallback

def _transpose_rowwise_to_columnwise(self, seq_of_parameters: list) -> list:
"""
Convert list of rows (row-wise) into list of columns (column-wise),
Expand All @@ -872,6 +845,58 @@ def _transpose_rowwise_to_columnwise(self, seq_of_parameters: list) -> list:
columnwise[i].append(val)
return columnwise

def _compute_column_type(self, column):
    """
    Scan all rows of one parameter column and summarise it for type mapping.

    Args:
        column: List of the values supplied for a single parameter position,
            one entry per row. None entries are ignored.

    Returns:
        A 5-tuple (sample_value, is_dae, final_value, min_val, max_val):
        - sample_value: representative value for type inference. This is the
          int with the largest absolute value, the longest str/bytes, the
          Decimal with the most digits, the placeholder 0.0 for floats, or
          the first value seen for any other type.
        - is_dae: True when any str is longer than MAX_INLINE_CHAR UTF-16
          code units or any bytes/bytearray is longer than 8000 bytes
          (presumably "data-at-execution" -- confirm against the binding layer).
        - final_value: the same object as sample_value, for the dummy row.
        - min_val / max_val: smallest and largest int in the column, or None
          when the column contains no ints.
        A column with no non-null values yields (None, False, None, None, None).
    """
    non_nulls = [v for v in column if v is not None]
    if not non_nulls:
        return None, False, None, None, None

    # Handle integers separately to determine min/max. Any int in the column
    # short-circuits the scan; non-int values in the same column are ignored.
    # bool is a subclass of int, so booleans are counted here too.
    int_values = [v for v in non_nulls if isinstance(v, int)]
    if int_values:
        sample_value = max(int_values, key=abs)
        return sample_value, False, sample_value, min(int_values), max(int_values)

    is_dae = False
    sample_value = None
    sized_types = (str, bytes, bytearray)

    # Handle other types (strings, bytes, float, decimal, datetime).
    # A column may mix types, so the current sample is only measured with
    # len() / as_tuple() when it is known to support that operation; a sample
    # of an incompatible kind is replaced instead of raising.
    for v in non_nulls:
        if isinstance(v, str):
            # Code points above the BMP take two UTF-16 code units.
            utf16_len = sum(2 if ord(c) > 0xFFFF else 1 for c in v)
            if utf16_len > MAX_INLINE_CHAR:
                is_dae = True
            if (
                not sample_value
                or not isinstance(sample_value, sized_types)
                or len(v) > len(sample_value)
            ):
                sample_value = v
        elif isinstance(v, (bytes, bytearray)):
            if len(v) > 8000:
                is_dae = True
            if (
                not sample_value
                or not isinstance(sample_value, sized_types)
                or len(v) > len(sample_value)
            ):
                sample_value = v
        elif isinstance(v, float):
            # Floats use a fixed placeholder; the actual magnitude is not inspected.
            if sample_value is None:
                sample_value = 0.0
        elif isinstance(v, decimal.Decimal):
            if (
                not isinstance(sample_value, decimal.Decimal)
                or len(v.as_tuple().digits) > len(sample_value.as_tuple().digits)
            ):
                sample_value = v
        elif sample_value is None:
            # datetime/date/time and any other type: the first value seen wins.
            sample_value = v

    return sample_value, is_dae, sample_value, None, None


def executemany(self, operation: str, seq_of_parameters: list) -> None:
"""
Prepare a database operation and execute it against all parameter sequences.
Expand All @@ -885,10 +910,10 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None:
"""
self._check_closed()
self._reset_cursor()

# Clear any previous messages
self.messages = []

if not seq_of_parameters:
self.rowcount = 0
return
Expand All @@ -899,11 +924,14 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None:

for col_index in range(param_count):
column = [row[col_index] for row in seq_of_parameters]
sample_value = self._select_best_sample_value(column)
sample_value, is_dae, final_value, min_val, max_val = self._compute_column_type(column)
dummy_row = list(seq_of_parameters[0])
parameters_type.append(
self._create_parameter_types_list(sample_value, param_info, dummy_row, col_index)
dummy_row[col_index] = final_value
paraminfo = self._create_parameter_types_list(
sample_value, param_info, dummy_row, col_index, min_val=min_val, max_val=max_val
)
paraminfo.isDAE = is_dae
parameters_type.append(paraminfo)

columnwise_params = self._transpose_rowwise_to_columnwise(seq_of_parameters)
log('info', "Executing batch query with %d parameter sets:\n%s",
Expand All @@ -923,7 +951,7 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None:
# Capture any diagnostic messages after execution
if self.hstmt:
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))

self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt)
self.last_executed_stmt = operation
self._initialize_description()
Expand Down