Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,16 @@ def _map_sql_type(self, param, parameters_list, i, min_val=None, max_val=None):
parameters_list[i].scale,
False,
)

if isinstance(param, uuid.UUID):
parameters_list[i] = param.bytes_le
return (
ddbc_sql_const.SQL_GUID.value,
ddbc_sql_const.SQL_C_GUID.value,
16,
0,
False,
)

if isinstance(param, str):
if (
Expand All @@ -352,6 +362,20 @@ def _map_sql_type(self, param, parameters_list, i, min_val=None, max_val=None):
0,
False,
)

try:
val = uuid.UUID(param)
parameters_list[i] = val.bytes_le
return (
ddbc_sql_const.SQL_GUID.value,
ddbc_sql_const.SQL_C_GUID.value,
16,
0,
False
)
except ValueError:
pass


# Attempt to parse as date, datetime, datetime2, timestamp, smalldatetime or time
if self._parse_date(param):
Expand Down
83 changes: 65 additions & 18 deletions mssql_python/pybind/ddbc_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,33 @@
break;
}
case SQL_C_GUID: {
// TODO
if (!py::isinstance<py::bytes>(param)) {
ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex));
}
py::bytes uuid_bytes = param.cast<py::bytes>();
const unsigned char* uuid_data = reinterpret_cast<const unsigned char*>(PyBytes_AS_STRING(uuid_bytes.ptr()));
if (PyBytes_GET_SIZE(uuid_bytes.ptr()) != 16) {
LOG("Invalid UUID parameter at index {}: expected 16 bytes, got {} bytes, type {}", paramIndex, PyBytes_GET_SIZE(uuid_bytes.ptr()), paramInfo.paramCType);
ThrowStdException("UUID binary data must be exactly 16 bytes long.");
}
SQLGUID* guid_data_ptr = AllocateParamBuffer<SQLGUID>(paramBuffers);
guid_data_ptr->Data1 =
(static_cast<uint32_t>(uuid_data[3]) << 24) |
(static_cast<uint32_t>(uuid_data[2]) << 16) |
(static_cast<uint32_t>(uuid_data[1]) << 8) |
(static_cast<uint32_t>(uuid_data[0]));
guid_data_ptr->Data2 =
(static_cast<uint16_t>(uuid_data[5]) << 8) |
(static_cast<uint16_t>(uuid_data[4]));
guid_data_ptr->Data3 =
(static_cast<uint16_t>(uuid_data[7]) << 8) |
(static_cast<uint16_t>(uuid_data[6]));
std::memcpy(guid_data_ptr->Data4, &uuid_data[8], 8);
dataPtr = static_cast<void*>(guid_data_ptr);
bufferLength = sizeof(SQLGUID);
strLenOrIndPtr = AllocateParamBuffer<SQLLEN>(paramBuffers);
*strLenOrIndPtr = sizeof(SQLGUID);
break;
}
default: {
std::ostringstream errorString;
Expand Down Expand Up @@ -2553,20 +2579,27 @@
#if (ODBCVER >= 0x0350)
case SQL_GUID: {
SQLGUID guidValue;
ret = SQLGetData_ptr(hStmt, i, SQL_C_GUID, &guidValue, sizeof(guidValue), NULL);
if (SQL_SUCCEEDED(ret)) {
std::ostringstream oss;
oss << std::hex << std::setfill('0') << std::setw(8) << guidValue.Data1 << '-'
<< std::setw(4) << guidValue.Data2 << '-' << std::setw(4) << guidValue.Data3
<< '-' << std::setw(2) << static_cast<int>(guidValue.Data4[0])
<< std::setw(2) << static_cast<int>(guidValue.Data4[1]) << '-' << std::hex
<< std::setw(2) << static_cast<int>(guidValue.Data4[2]) << std::setw(2)
<< static_cast<int>(guidValue.Data4[3]) << std::setw(2)
<< static_cast<int>(guidValue.Data4[4]) << std::setw(2)
<< static_cast<int>(guidValue.Data4[5]) << std::setw(2)
<< static_cast<int>(guidValue.Data4[6]) << std::setw(2)
<< static_cast<int>(guidValue.Data4[7]);
row.append(oss.str()); // Append GUID as a string
SQLLEN indicator;
ret = SQLGetData_ptr(hStmt, i, SQL_C_GUID, &guidValue, sizeof(guidValue), &indicator);

if (SQL_SUCCEEDED(ret) && indicator != SQL_NULL_DATA) {
std::vector<char> guid_bytes(16);
guid_bytes[0] = ((char*)&guidValue.Data1)[3];
guid_bytes[1] = ((char*)&guidValue.Data1)[2];
guid_bytes[2] = ((char*)&guidValue.Data1)[1];
guid_bytes[3] = ((char*)&guidValue.Data1)[0];
guid_bytes[4] = ((char*)&guidValue.Data2)[1];
guid_bytes[5] = ((char*)&guidValue.Data2)[0];
guid_bytes[6] = ((char*)&guidValue.Data3)[1];
guid_bytes[7] = ((char*)&guidValue.Data3)[0];
std::memcpy(&guid_bytes[8], guidValue.Data4, sizeof(guidValue.Data4));

py::bytes py_guid_bytes(guid_bytes.data(), guid_bytes.size());
py::object uuid_module = py::module_::import("uuid");
py::object uuid_obj = uuid_module.attr("UUID")(py::arg("bytes")=py_guid_bytes);
row.append(uuid_obj);
} else if (indicator == SQL_NULL_DATA) {
row.append(py::none());
} else {
LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return "
"code - {}. Returning NULL value instead",
Expand Down Expand Up @@ -2957,9 +2990,23 @@
break;
}
case SQL_GUID: {
row.append(
py::bytes(reinterpret_cast<const char*>(&buffers.guidBuffers[col - 1][i]),
sizeof(SQLGUID)));
SQLGUID* guidValue = &buffers.guidBuffers[col - 1][i];
uint8_t reordered[16];
reordered[0] = ((char*)&guidValue->Data1)[3];
reordered[1] = ((char*)&guidValue->Data1)[2];
reordered[2] = ((char*)&guidValue->Data1)[1];
reordered[3] = ((char*)&guidValue->Data1)[0];
reordered[4] = ((char*)&guidValue->Data2)[1];
reordered[5] = ((char*)&guidValue->Data2)[0];
reordered[6] = ((char*)&guidValue->Data3)[1];
reordered[7] = ((char*)&guidValue->Data3)[0];
std::memcpy(reordered + 8, guidValue->Data4, 8);

py::bytes py_guid_bytes(reinterpret_cast<char*>(reordered), 16);
py::dict kwargs;
kwargs["bytes"] = py_guid_bytes;
py::object uuid_obj = py::module_::import("uuid").attr("UUID")(**kwargs);
row.append(uuid_obj);
break;
}
case SQL_BINARY:
Expand Down
205 changes: 204 additions & 1 deletion tests/test_004_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import decimal
from contextlib import closing
import mssql_python
import uuid


# Setup test table
TEST_TABLE = """
Expand Down Expand Up @@ -6942,6 +6944,208 @@ def test_money_smallmoney_invalid_values(cursor, db_connection):
drop_table_if_exists(cursor, "dbo.money_test")
db_connection.commit()

def test_uuid_insert_and_select_none(cursor, db_connection):
"""Test inserting and retrieving None in a nullable UUID column."""
table_name = "#pytest_uuid_nullable"
try:
cursor.execute(f"DROP TABLE IF EXISTS {table_name}")
cursor.execute(f"""
CREATE TABLE {table_name} (
id UNIQUEIDENTIFIER,
name NVARCHAR(50)
)
""")
db_connection.commit()

# Insert a row with None for the UUID
cursor.execute(f"INSERT INTO {table_name} (id, name) VALUES (?, ?)", [None, "Bob"])
db_connection.commit()

# Fetch the row
cursor.execute(f"SELECT id, name FROM {table_name}")
retrieved_uuid, retrieved_name = cursor.fetchone()

# Assert correct results
assert retrieved_uuid is None, f"Expected None, got {retrieved_uuid}"
assert retrieved_name == "Bob"
finally:
cursor.execute(f"DROP TABLE IF EXISTS {table_name}")
db_connection.commit()


def test_insert_multiple_uuids(cursor, db_connection):
"""Test inserting multiple UUIDs and verifying retrieval."""
table_name = "#pytest_uuid_multiple"
try:
cursor.execute(f"DROP TABLE IF EXISTS {table_name}")
cursor.execute(f"""
CREATE TABLE {table_name} (
id UNIQUEIDENTIFIER PRIMARY KEY,
description NVARCHAR(50)
)
""")
db_connection.commit()

# Prepare test data
uuids_to_insert = {f"Item {i}": uuid.uuid4() for i in range(5)}

# Insert UUIDs and descriptions
for desc, uid in uuids_to_insert.items():
cursor.execute(f"INSERT INTO {table_name} (id, description) VALUES (?, ?)", [uid, desc])
db_connection.commit()

# Fetch all rows
cursor.execute(f"SELECT id, description FROM {table_name}")
rows = cursor.fetchall()

# Verify each fetched row
assert len(rows) == len(uuids_to_insert), "Fetched row count mismatch"

for retrieved_uuid, retrieved_desc in rows:
assert isinstance(retrieved_uuid, uuid.UUID), f"Expected uuid.UUID, got {type(retrieved_uuid)}"
expected_uuid = uuids_to_insert[retrieved_desc]
assert retrieved_uuid == expected_uuid, f"UUID mismatch for '{retrieved_desc}': expected {expected_uuid}, got {retrieved_uuid}"
finally:
cursor.execute(f"DROP TABLE IF EXISTS {table_name}")
db_connection.commit()


def test_fetchmany_uuids(cursor, db_connection):
"""Test fetching multiple UUID rows with fetchmany()."""
table_name = "#pytest_uuid_fetchmany"
try:
cursor.execute(f"DROP TABLE IF EXISTS {table_name}")
cursor.execute(f"""
CREATE TABLE {table_name} (
id UNIQUEIDENTIFIER PRIMARY KEY,
description NVARCHAR(50)
)
""")
db_connection.commit()

uuids_to_insert = {f"Item {i}": uuid.uuid4() for i in range(10)}

for desc, uid in uuids_to_insert.items():
cursor.execute(f"INSERT INTO {table_name} (id, description) VALUES (?, ?)", [uid, desc])
db_connection.commit()

cursor.execute(f"SELECT id, description FROM {table_name}")

# Fetch in batches of 3
batch_size = 3
fetched_rows = []
while True:
batch = cursor.fetchmany(batch_size)
if not batch:
break
fetched_rows.extend(batch)

# Verify all rows
assert len(fetched_rows) == len(uuids_to_insert), "Fetched row count mismatch"
for retrieved_uuid, retrieved_desc in fetched_rows:
assert isinstance(retrieved_uuid, uuid.UUID)
expected_uuid = uuids_to_insert[retrieved_desc]
assert retrieved_uuid == expected_uuid
finally:
cursor.execute(f"DROP TABLE IF EXISTS {table_name}")
db_connection.commit()


def test_uuid_insert_with_none(cursor, db_connection):
"""Test inserting None into a UUID column results in a NULL value."""
table_name = "#pytest_uuid_none"
try:
cursor.execute(f"DROP TABLE IF EXISTS {table_name}")
cursor.execute(f"""
CREATE TABLE {table_name} (
id UNIQUEIDENTIFIER,
name NVARCHAR(50)
)
""")
db_connection.commit()

cursor.execute(f"INSERT INTO {table_name} (id, name) VALUES (?, ?)", [None, "Alice"])
db_connection.commit()

cursor.execute(f"SELECT id, name FROM {table_name}")
retrieved_uuid, retrieved_name = cursor.fetchone()

assert retrieved_uuid is None, f"Expected NULL UUID, got {retrieved_uuid}"
assert retrieved_name == "Alice"
finally:
cursor.execute(f"DROP TABLE IF EXISTS {table_name}")
db_connection.commit()

def test_invalid_uuid_inserts(cursor, db_connection):
"""Test inserting invalid UUID values raises appropriate errors."""
table_name = "#pytest_uuid_invalid"
try:
cursor.execute(f"DROP TABLE IF EXISTS {table_name}")
cursor.execute(f"CREATE TABLE {table_name} (id UNIQUEIDENTIFIER)")
db_connection.commit()

invalid_values = [
"12345", # Too short
"not-a-uuid", # Not a UUID string
123456789, # Integer
12.34, # Float
object() # Arbitrary object
]

for val in invalid_values:
with pytest.raises(Exception):
cursor.execute(f"INSERT INTO {table_name} (id) VALUES (?)", [val])
db_connection.commit()
finally:
cursor.execute(f"DROP TABLE IF EXISTS {table_name}")
db_connection.commit()

def test_duplicate_uuid_inserts(cursor, db_connection):
"""Test that inserting duplicate UUIDs into a PK column raises an error."""
table_name = "#pytest_uuid_duplicate"
try:
cursor.execute(f"DROP TABLE IF EXISTS {table_name}")
cursor.execute(f"CREATE TABLE {table_name} (id UNIQUEIDENTIFIER PRIMARY KEY)")
db_connection.commit()

uid = uuid.uuid4()
cursor.execute(f"INSERT INTO {table_name} (id) VALUES (?)", [uid])
db_connection.commit()

with pytest.raises(Exception):
cursor.execute(f"INSERT INTO {table_name} (id) VALUES (?)", [uid])
db_connection.commit()
finally:
cursor.execute(f"DROP TABLE IF EXISTS {table_name}")
db_connection.commit()

def test_extreme_uuids(cursor, db_connection):
"""Test inserting extreme but valid UUIDs."""
table_name = "#pytest_uuid_extreme"
try:
cursor.execute(f"DROP TABLE IF EXISTS {table_name}")
cursor.execute(f"CREATE TABLE {table_name} (id UNIQUEIDENTIFIER)")
db_connection.commit()

extreme_uuids = [
uuid.UUID(int=0), # All zeros
uuid.UUID(int=(1 << 128) - 1), # All ones
]

for uid in extreme_uuids:
cursor.execute(f"INSERT INTO {table_name} (id) VALUES (?)", [uid])
db_connection.commit()

cursor.execute(f"SELECT id FROM {table_name}")
rows = cursor.fetchall()
fetched_uuids = [row[0] for row in rows]

for uid in extreme_uuids:
assert uid in fetched_uuids, f"Extreme UUID {uid} not retrieved correctly"
finally:
cursor.execute(f"DROP TABLE IF EXISTS {table_name}")
db_connection.commit()

def test_decimal_separator_with_multiple_values(cursor, db_connection):
"""Test decimal separator with multiple different decimal values"""
original_separator = mssql_python.getDecimalSeparator()
Expand Down Expand Up @@ -10193,7 +10397,6 @@ def test_decimal_separator_calculations(cursor, db_connection):

# Cleanup
cursor.execute("DROP TABLE IF EXISTS #pytest_decimal_calc_test")
db_connection.commit()

def test_close(db_connection):
"""Test closing the cursor"""
Expand Down
Loading