Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,12 +342,15 @@ def _map_sql_type(self, param, parameters_list, i):

# String mapping logic here
is_unicode = self._is_unicode_string(param)
if len(param) > MAX_INLINE_CHAR: # Long strings

# Computes UTF-16 code units (handles surrogate pairs)
utf16_len = sum(2 if ord(c) > 0xFFFF else 1 for c in param)
if utf16_len > MAX_INLINE_CHAR: # Long strings -> DAE
if is_unicode:
return (
ddbc_sql_const.SQL_WLONGVARCHAR.value,
ddbc_sql_const.SQL_C_WCHAR.value,
len(param),
utf16_len,
0,
True,
)
Expand All @@ -358,8 +361,9 @@ def _map_sql_type(self, param, parameters_list, i):
0,
True,
)
if is_unicode: # Short Unicode strings
utf16_len = len(param.encode("utf-16-le")) // 2

# Short strings
if is_unicode:
return (
ddbc_sql_const.SQL_WVARCHAR.value,
ddbc_sql_const.SQL_C_WCHAR.value,
Expand All @@ -374,7 +378,7 @@ def _map_sql_type(self, param, parameters_list, i):
0,
False,
)

if isinstance(param, bytes):
if len(param) > 8000: # Assuming VARBINARY(MAX) for long byte arrays
return (
Expand Down
76 changes: 62 additions & 14 deletions mssql_python/pybind/ddbc_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,27 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params,

// TODO: Add more data types like money, guid, interval, TVPs etc.
switch (paramInfo.paramCType) {
case SQL_C_CHAR:
case SQL_C_CHAR: {
if (!py::isinstance<py::str>(param) && !py::isinstance<py::bytearray>(param) &&
!py::isinstance<py::bytes>(param)) {
ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex));
}
if (paramInfo.isDAE) {
LOG("Parameter[{}] is marked for DAE streaming", paramIndex);
dataPtr = const_cast<void*>(reinterpret_cast<const void*>(&paramInfos[paramIndex]));
strLenOrIndPtr = AllocateParamBuffer<SQLLEN>(paramBuffers);
*strLenOrIndPtr = SQL_LEN_DATA_AT_EXEC(0);
bufferLength = 0;
} else {
std::string* strParam =
AllocateParamBuffer<std::string>(paramBuffers, param.cast<std::string>());
dataPtr = const_cast<void*>(static_cast<const void*>(strParam->c_str()));
bufferLength = strParam->size() + 1;
strLenOrIndPtr = AllocateParamBuffer<SQLLEN>(paramBuffers);
*strLenOrIndPtr = SQL_NTS;
}
break;
}
case SQL_C_BINARY: {
if (!py::isinstance<py::str>(param) && !py::isinstance<py::bytearray>(param) &&
!py::isinstance<py::bytes>(param)) {
Expand Down Expand Up @@ -1203,23 +1223,51 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle,
continue;
}
if (py::isinstance<py::str>(pyObj)) {
std::wstring wstr = pyObj.cast<std::wstring>();
if (matchedInfo->paramCType == SQL_C_WCHAR) {
std::wstring wstr = pyObj.cast<std::wstring>();
const SQLWCHAR* dataPtr = nullptr;
size_t totalChars = 0;
#if defined(__APPLE__) || defined(__linux__)
auto utf16Buf = WStringToSQLWCHAR(wstr);
const char* dataPtr = reinterpret_cast<const char*>(utf16Buf.data());
size_t totalBytes = (utf16Buf.size() - 1) * sizeof(SQLWCHAR);
std::vector<SQLWCHAR> sqlwStr = WStringToSQLWCHAR(wstr);
totalChars = sqlwStr.size() - 1;
dataPtr = sqlwStr.data();
#else
const char* dataPtr = reinterpret_cast<const char*>(wstr.data());
size_t totalBytes = wstr.size() * sizeof(wchar_t);
dataPtr = wstr.c_str();
totalChars = wstr.size();
#endif
const size_t chunkSize = DAE_CHUNK_SIZE;
for (size_t offset = 0; offset < totalBytes; offset += chunkSize) {
size_t len = std::min(chunkSize, totalBytes - offset);
rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), static_cast<SQLLEN>(len));
if (!SQL_SUCCEEDED(rc)) {
LOG("SQLPutData failed at offset {} of {}", offset, totalBytes);
return rc;
size_t offset = 0;
size_t chunkChars = DAE_CHUNK_SIZE / sizeof(SQLWCHAR);
while (offset < totalChars) {
size_t len = std::min(chunkChars, totalChars - offset);
size_t lenBytes = len * sizeof(SQLWCHAR);
if (lenBytes > static_cast<size_t>(std::numeric_limits<SQLLEN>::max())) {
ThrowStdException("Chunk size exceeds maximum allowed by SQLLEN");
}
rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), static_cast<SQLLEN>(lenBytes));
if (!SQL_SUCCEEDED(rc)) {
LOG("SQLPutData failed at offset {} of {}", offset, totalChars);
return rc;
}
offset += len;
}
} else if (matchedInfo->paramCType == SQL_C_CHAR) {
std::string s = pyObj.cast<std::string>();
size_t totalBytes = s.size();
const char* dataPtr = s.data();
size_t offset = 0;
size_t chunkBytes = DAE_CHUNK_SIZE;
while (offset < totalBytes) {
size_t len = std::min(chunkBytes, totalBytes - offset);

rc = SQLPutData_ptr(hStmt, (SQLPOINTER)(dataPtr + offset), static_cast<SQLLEN>(len));
if (!SQL_SUCCEEDED(rc)) {
LOG("SQLPutData failed at offset {} of {}", offset, totalBytes);
return rc;
}
offset += len;
}
} else {
ThrowStdException("Unsupported C type for str in DAE");
}
} else {
ThrowStdException("DAE only supported for str or bytes");
Expand Down
182 changes: 181 additions & 1 deletion tests/test_004_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from datetime import datetime, date, time
import decimal
from contextlib import closing
from mssql_python import Connection
from mssql_python import Connection, row

# Setup test table
TEST_TABLE = """
Expand Down Expand Up @@ -5124,6 +5124,186 @@ def test_emoji_round_trip(cursor, db_connection):
except Exception as e:
pytest.fail(f"Error for input {repr(text)}: {e}")

def test_varchar_max_insert_non_lob(cursor, db_connection):
"""Test small VARCHAR(MAX) insert (non-LOB path)."""
try:
cursor.execute("CREATE TABLE #pytest_varchar_nonlob (col VARCHAR(MAX))")
db_connection.commit()

small_str = "Hello, world!" # small, non-LOB
cursor.execute(
"INSERT INTO #pytest_varchar_nonlob (col) VALUES (?)",
[small_str]
)
db_connection.commit()

empty_str = ""
cursor.execute(
"INSERT INTO #pytest_varchar_nonlob (col) VALUES (?)",
[empty_str]
)
db_connection.commit()

# None value
cursor.execute(
"INSERT INTO #pytest_varchar_nonlob (col) VALUES (?)",
[None]
)
db_connection.commit()

# Fetch commented for now
# cursor.execute("SELECT col FROM #pytest_varchar_nonlob")
# rows = cursor.fetchall()
# assert rows == [[small_str], [empty_str], [None]]

finally:
pass


def test_varchar_max_insert_lob(cursor, db_connection):
"""Test large VARCHAR(MAX) insert (LOB path)."""
try:
cursor.execute("CREATE TABLE #pytest_varchar_lob (col VARCHAR(MAX))")
db_connection.commit()

large_str = "A" * 100_000 # > 8k to trigger LOB
cursor.execute(
"INSERT INTO #pytest_varchar_lob (col) VALUES (?)",
[large_str]
)
db_connection.commit()

# Fetch commented for now
# cursor.execute("SELECT col FROM #pytest_varchar_lob")
# rows = cursor.fetchall()
# assert rows == [[large_str]]

finally:
pass


def test_nvarchar_max_insert_non_lob(cursor, db_connection):
"""Test small NVARCHAR(MAX) insert (non-LOB path)."""
try:
cursor.execute("CREATE TABLE #pytest_nvarchar_nonlob (col NVARCHAR(MAX))")
db_connection.commit()

small_str = "Unicode ✨ test"
cursor.execute(
"INSERT INTO #pytest_nvarchar_nonlob (col) VALUES (?)",
[small_str]
)
db_connection.commit()

empty_str = ""
cursor.execute(
"INSERT INTO #pytest_nvarchar_nonlob (col) VALUES (?)",
[empty_str]
)
db_connection.commit()

cursor.execute(
"INSERT INTO #pytest_nvarchar_nonlob (col) VALUES (?)",
[None]
)
db_connection.commit()

# Fetch commented for now
# cursor.execute("SELECT col FROM #pytest_nvarchar_nonlob")
# rows = cursor.fetchall()
# assert rows == [[small_str], [empty_str], [None]]

finally:
pass


def test_nvarchar_max_insert_lob(cursor, db_connection):
"""Test large NVARCHAR(MAX) insert (LOB path)."""
try:
cursor.execute("CREATE TABLE #pytest_nvarchar_lob (col NVARCHAR(MAX))")
db_connection.commit()

large_str = "📝" * 50_000 # each emoji = 2 UTF-16 code units, total > 100k bytes
cursor.execute(
"INSERT INTO #pytest_nvarchar_lob (col) VALUES (?)",
[large_str]
)
db_connection.commit()

# Fetch commented for now
# cursor.execute("SELECT col FROM #pytest_nvarchar_lob")
# rows = cursor.fetchall()
# assert rows == [[large_str]]

finally:
pass

def test_nvarchar_max_boundary(cursor, db_connection):
"""Test NVARCHAR(MAX) at LOB boundary sizes."""
try:
cursor.execute("DROP TABLE IF EXISTS #pytest_nvarchar_boundary")
cursor.execute("CREATE TABLE #pytest_nvarchar_boundary (col NVARCHAR(MAX))")
db_connection.commit()

# 4k BMP chars = 8k bytes
cursor.execute("INSERT INTO #pytest_nvarchar_boundary (col) VALUES (?)", ["A" * 4096])
# 4k emojis = 8k UTF-16 code units (16k bytes)
cursor.execute("INSERT INTO #pytest_nvarchar_boundary (col) VALUES (?)", ["📝" * 4096])
db_connection.commit()

# Fetch commented for now
# cursor.execute("SELECT col FROM #pytest_nvarchar_boundary")
# rows = cursor.fetchall()
# assert rows == [["A" * 4096], ["📝" * 4096]]
finally:
pass


def test_nvarchar_max_chunk_edge(cursor, db_connection):
"""Test NVARCHAR(MAX) insert slightly larger than a chunk."""
try:
cursor.execute("DROP TABLE IF EXISTS #pytest_nvarchar_chunk")
cursor.execute("CREATE TABLE #pytest_nvarchar_chunk (col NVARCHAR(MAX))")
db_connection.commit()

chunk_size = 8192 # bytes
test_str = "📝" * ((chunk_size // 4) + 3) # slightly > 1 chunk
cursor.execute("INSERT INTO #pytest_nvarchar_chunk (col) VALUES (?)", [test_str])
db_connection.commit()

# Fetch commented for now
# cursor.execute("SELECT col FROM #pytest_nvarchar_chunk")
# row = cursor.fetchone()
# assert row[0] == test_str
finally:
pass

def test_empty_string_chunk(cursor, db_connection):
"""Test inserting empty strings into VARCHAR(MAX) and NVARCHAR(MAX)."""
try:
cursor.execute("DROP TABLE IF EXISTS #pytest_empty_string")
cursor.execute("""
CREATE TABLE #pytest_empty_string (
varchar_col VARCHAR(MAX),
nvarchar_col NVARCHAR(MAX)
)
""")
db_connection.commit()

empty_varchar = ""
empty_nvarchar = ""
cursor.execute(
"INSERT INTO #pytest_empty_string (varchar_col, nvarchar_col) VALUES (?, ?)",
[empty_varchar, empty_nvarchar]
)
db_connection.commit()

cursor.execute("SELECT LEN(varchar_col), LEN(nvarchar_col) FROM #pytest_empty_string")
row = tuple(int(x) for x in cursor.fetchone())
assert row == (0, 0), f"Expected lengths (0,0), got {row}"
finally:
cursor.execute("DROP TABLE IF EXISTS #pytest_empty_string")
db_connection.commit()

def test_close(db_connection):
"""Test closing the cursor"""
Expand Down