
Commit 85a4487

Streaming support in fetch methods for VARBINARY(MAX)
1 parent 9ad334d commit 85a4487
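
With this change, VARBINARY(MAX) values larger than the old 8000-byte fetch buffer round-trip through the ordinary fetch calls. A minimal usage sketch (assuming the package's DB-API style connect(); the connection string is a placeholder):

    from mssql_python import connect

    conn = connect("Server=localhost;Database=test;Trusted_Connection=yes;")
    cur = conn.cursor()

    payload = b"A" * 20000  # well past the 8000-byte in-row limit
    cur.execute("CREATE TABLE #lob_demo (id INT, blob VARBINARY(MAX))")
    cur.execute("INSERT INTO #lob_demo VALUES (?, ?)", (1, payload))

    cur.execute("SELECT blob FROM #lob_demo WHERE id = 1")
    assert cur.fetchone()[0] == payload  # fetched in chunks, returned whole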

File tree

2 files changed: +244, -74 lines

  mssql_python/pybind/ddbc_bindings.cpp
  tests/test_004_cursor.py

mssql_python/pybind/ddbc_bindings.cpp

Lines changed: 186 additions & 53 deletions
@@ -1751,6 +1751,96 @@ SQLRETURN SQLFetch_wrap(SqlHandlePtr StatementHandle) {
     return SQLFetch_ptr(StatementHandle->get());
 }
 
+static py::object FetchLobColumnData(SQLHSTMT hStmt,
+                                     SQLUSMALLINT colIndex,
+                                     SQLSMALLINT cType,
+                                     bool isWideChar,
+                                     bool isBinary)
+{
+    std::vector<char> buffer;
+    SQLLEN indicator = 0;
+    SQLRETURN ret;
+    int loopCount = 0;
+
+    while (true) {
+        ++loopCount;
+        std::vector<char> chunk(DAE_CHUNK_SIZE);
+        ret = SQLGetData_ptr(
+            hStmt,
+            colIndex,
+            cType,
+            chunk.data(),
+            DAE_CHUNK_SIZE,
+            &indicator
+        );
+        if (indicator == SQL_NULL_DATA) {
+            LOG("Loop {}: Column {} is NULL", loopCount, colIndex);
+            return py::none();
+        }
+        if (!SQL_SUCCEEDED(ret) && ret != SQL_SUCCESS_WITH_INFO) {
+            LOG("Loop {}: Error fetching col={} with cType={} ret={}", loopCount, colIndex, cType, ret);
+            return py::none();
+        }
+        SQLLEN copyCount = 0;
+        if (indicator > 0 && indicator != SQL_NO_TOTAL) {
+            copyCount = std::min<SQLLEN>(indicator, DAE_CHUNK_SIZE);
+        } else {
+            copyCount = DAE_CHUNK_SIZE;
+        }
+
+        // Trim a trailing null terminator for character data only. Binary
+        // payloads may legitimately end in 0x00 bytes and must not be cut.
+        if (copyCount > 0 && !isBinary) {
+            if (!isWideChar && chunk[copyCount - 1] == '\0') {
+                --copyCount;
+                LOG("Loop {}: Trimmed null terminator (narrow)", loopCount);
+            } else if (isWideChar && copyCount >= static_cast<SQLLEN>(sizeof(wchar_t))) {
+                auto wcharBuf = reinterpret_cast<const wchar_t*>(chunk.data());
+                if (wcharBuf[(copyCount / sizeof(wchar_t)) - 1] == L'\0') {
+                    copyCount -= sizeof(wchar_t);
+                    LOG("Loop {}: Trimmed null terminator (wide)", loopCount);
+                }
+            }
+        }
+        if (copyCount > 0 && indicator != 0) {
+            buffer.insert(buffer.end(), chunk.begin(), chunk.begin() + copyCount);
+            LOG("Loop {}: Appended {} bytes", loopCount, copyCount);
+        }
+        if (ret == SQL_SUCCESS) {
+            LOG("Loop {}: SQL_SUCCESS → no more data", loopCount);
+            break;
+        }
+    }
+    LOG("FetchLobColumnData: Total bytes collected = {}", buffer.size());
+
+    // Handle zero-length buffers correctly
+    if (buffer.empty()) {
+        if (isBinary) {
+            LOG("FetchLobColumnData: Returning empty bytes for binary column {}", colIndex);
+            return py::bytes(nullptr, 0);
+        } else if (isWideChar) {
+            LOG("FetchLobColumnData: Returning empty string for wide text column {}", colIndex);
+            return py::str("");
+        } else {
+            LOG("FetchLobColumnData: Returning empty string for narrow text column {}", colIndex);
+            return py::str("");
+        }
+    }
+
+    if (isWideChar) {
+        std::wstring wstr(reinterpret_cast<const wchar_t*>(buffer.data()),
+                          buffer.size() / sizeof(wchar_t));
+        LOG("FetchLobColumnData: Returning wide string of length {}", wstr.length());
+        return py::cast(wstr);
+    }
+    if (isBinary) {
+        LOG("FetchLobColumnData: Returning binary of {} bytes", buffer.size());
+        return py::bytes(buffer.data(), buffer.size());
+    }
+    std::string str(buffer.data(), buffer.size());
+    LOG("FetchLobColumnData: Returning narrow string of length {}", str.length());
+    return py::str(str);
+}
+
 // Helper function to retrieve column data
 // TODO: Handle variable length data correctly
 SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, py::list& row) {
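
The loop in FetchLobColumnData follows ODBC's standard chunked-retrieval contract: each SQLGetData call fills the buffer and returns SQL_SUCCESS_WITH_INFO while more data remains, SQL_SUCCESS on the final chunk, and an indicator of SQL_NO_TOTAL when the remaining length is unknown. Stripped of the ODBC plumbing, the termination rule looks like this (a minimal Python sketch; read_chunk is a hypothetical stand-in for SQLGetData, not a real API):

    def read_lob(read_chunk, chunk_size=8192):
        # read_chunk(n) returns (data, more): more=True corresponds to
        # SQL_SUCCESS_WITH_INFO (buffer full, data remains), more=False
        # to SQL_SUCCESS (final chunk received).
        parts = []
        while True:
            data, more = read_chunk(chunk_size)
            parts.append(data)
            if not more:
                return b"".join(parts)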
@@ -2059,45 +2149,39 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, py::list& row) {
         case SQL_BINARY:
         case SQL_VARBINARY:
         case SQL_LONGVARBINARY: {
-            // TODO: revisit
-            HandleZeroColumnSizeAtFetch(columnSize);
-            std::unique_ptr<SQLCHAR[]> dataBuffer(new SQLCHAR[columnSize]);
-            SQLLEN dataLen;
-            ret = SQLGetData_ptr(hStmt, i, SQL_C_BINARY, dataBuffer.get(), columnSize, &dataLen);
-
-            if (SQL_SUCCEEDED(ret)) {
-                // TODO: Refactor these if's across other switches to avoid code duplication
-                if (dataLen > 0) {
-                    if (static_cast<size_t>(dataLen) <= columnSize) {
-                        row.append(py::bytes(reinterpret_cast<const char*>(
-                            dataBuffer.get()), dataLen));
-                    } else {
-                        // In this case, buffer size is smaller, and data to be retrieved is longer
-                        // TODO: Revisit
-                        std::ostringstream oss;
-                        oss << "Buffer length for fetch (" << columnSize << ") is smaller, & data "
-                            << "to be retrieved is longer (" << dataLen << "). ColumnID - "
-                            << i << ", datatype - " << dataType;
-                        ThrowStdException(oss.str());
+            // Use streaming for large VARBINARY (columnSize unknown or > 8000)
+            if (columnSize == SQL_NO_TOTAL || columnSize == 0 || columnSize > 8000) {
+                LOG("Streaming LOB for column {} (VARBINARY)", i);
+                row.append(FetchLobColumnData(hStmt, i, SQL_C_BINARY, false, true));
+            } else {
+                // Small VARBINARY, fetch directly
+                std::vector<SQLCHAR> dataBuffer(columnSize);
+                SQLLEN dataLen;
+                ret = SQLGetData_ptr(hStmt, i, SQL_C_BINARY, dataBuffer.data(), columnSize, &dataLen);
+
+                if (SQL_SUCCEEDED(ret)) {
+                    if (dataLen > 0) {
+                        if (static_cast<size_t>(dataLen) <= columnSize) {
+                            row.append(py::bytes(reinterpret_cast<const char*>(dataBuffer.data()), dataLen));
+                        } else {
+                            std::ostringstream oss;
+                            oss << "Buffer length for fetch (" << columnSize << ") is smaller than actual data ("
+                                << dataLen << "). ColumnID - " << i << ", datatype - " << dataType;
+                            ThrowStdException(oss.str());
+                        }
+                    } else if (dataLen == SQL_NULL_DATA) {
+                        row.append(py::none());
+                    } else if (dataLen == 0) {
+                        row.append(py::bytes(""));
+                    } else {
+                        LOG("SQLGetData returned unexpected negative length: {}. Column ID - {}", dataLen, i);
+                        ThrowStdException("Unexpected negative SQLGetData length");
                     }
-                } else if (dataLen == SQL_NULL_DATA) {
-                    row.append(py::none());
-                } else if (dataLen == 0) {
-                    // Empty bytes
-                    row.append(py::bytes(""));
-                } else if (dataLen < 0) {
-                    // This is unexpected
-                    LOG("SQLGetData returned an unexpected negative data length. "
-                        "Raising exception. Column ID - {}, Data Type - {}, Data Length - {}",
-                        i, dataType, dataLen);
-                    ThrowStdException("SQLGetData returned an unexpected negative data length");
+                } else {
+                    LOG("Error retrieving VARBINARY data for column {}. SQLGetData rc = {}", i, ret);
+                    row.append(py::none());
                 }
-            } else {
-                LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return "
-                    "code - {}. Returning NULL value instead",
-                    i, dataType, ret);
-                row.append(py::none());
-            }
+            }
             break;
         }
         case SQL_TINYINT: {
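
The 8000-byte cutoff above matches SQL Server's maximum length for a plain varbinary(n) column; varbinary(max) columns typically report a column size of 0 through the driver's metadata. A sketch of the routing rule as used above (SQL_NO_TOTAL is ODBC's "length unknown" sentinel, -4):

    SQL_NO_TOTAL = -4  # ODBC sentinel for "length unknown"

    def should_stream(column_size: int) -> bool:
        # varbinary(max) reports 0 (or an unknown size); anything over
        # 8000 bytes cannot be a plain varbinary(n), so stream it.
        return column_size in (0, SQL_NO_TOTAL) or column_size > 8000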
@@ -2342,7 +2426,7 @@ SQLRETURN SQLBindColums(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& columnNames,
 // Fetch rows in batches
 // TODO: Move to anonymous namespace, since it is not used outside this file
 SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& columnNames,
-                         py::list& rows, SQLUSMALLINT numCols, SQLULEN& numRowsFetched) {
+                         py::list& rows, SQLUSMALLINT numCols, SQLULEN& numRowsFetched, const std::vector<SQLUSMALLINT>& lobColumns) {
     LOG("Fetching data in batches");
     SQLRETURN ret = SQLFetchScroll_ptr(hStmt, SQL_FETCH_NEXT, 0);
     if (ret == SQL_NO_DATA) {
@@ -2539,21 +2623,12 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& columnNames,
         case SQL_BINARY:
         case SQL_VARBINARY:
         case SQL_LONGVARBINARY: {
-            // TODO: variable length data needs special handling, this logic wont suffice
             SQLULEN columnSize = columnMeta["ColumnSize"].cast<SQLULEN>();
-            HandleZeroColumnSizeAtFetch(columnSize);
-            if (static_cast<size_t>(dataLen) <= columnSize) {
-                row.append(py::bytes(reinterpret_cast<const char*>(
-                    &buffers.charBuffers[col - 1][i * columnSize]),
-                    dataLen));
+            bool isLob = std::find(lobColumns.begin(), lobColumns.end(), col) != lobColumns.end();
+            if (!isLob && static_cast<size_t>(dataLen) <= columnSize) {
+                row.append(py::bytes(reinterpret_cast<const char*>(&buffers.charBuffers[col - 1][i * columnSize]), dataLen));
             } else {
-                // In this case, buffer size is smaller, and data to be retrieved is longer
-                // TODO: Revisit
-                std::ostringstream oss;
-                oss << "Buffer length for fetch (" << columnSize << ") is smaller, & data "
-                    << "to be retrieved is longer (" << dataLen << "). ColumnID - "
-                    << col << ", datatype - " << dataType;
-                ThrowStdException(oss.str());
+                row.append(FetchLobColumnData(hStmt, col, SQL_C_BINARY, false, true));
             }
             break;
         }
@@ -2682,6 +2757,35 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetchSize) {
         return ret;
     }
 
+    std::vector<SQLUSMALLINT> lobColumns;
+    for (SQLSMALLINT i = 0; i < numCols; i++) {
+        auto colMeta = columnNames[i].cast<py::dict>();
+        SQLSMALLINT dataType = colMeta["DataType"].cast<SQLSMALLINT>();
+        SQLULEN columnSize = colMeta["ColumnSize"].cast<SQLULEN>();
+
+        if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR ||
+             dataType == SQL_VARCHAR || dataType == SQL_LONGVARCHAR ||
+             dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY) &&
+            (columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > 8000)) {
+            lobColumns.push_back(i + 1); // 1-based
+        }
+    }
+
+    // If we have LOBs → fall back to row-by-row fetch + SQLGetData_wrap
+    if (!lobColumns.empty()) {
+        LOG("LOB columns detected → using per-row SQLGetData path");
+        while (true) {
+            ret = SQLFetch_ptr(hStmt);
+            if (ret == SQL_NO_DATA) break;
+            if (!SQL_SUCCEEDED(ret)) return ret;
+
+            py::list row;
+            SQLGetData_wrap(StatementHandle, numCols, row); // <-- streams LOBs correctly
+            rows.append(row);
+        }
+        return SQL_SUCCESS;
+    }
+
     // Initialize column buffers
     ColumnBuffers buffers(numCols, fetchSize);
26872791

@@ -2696,7 +2800,7 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetchSize) {
     SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)(intptr_t)fetchSize, 0);
     SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, &numRowsFetched, 0);
 
-    ret = FetchBatchData(hStmt, buffers, columnNames, rows, numCols, numRowsFetched);
+    ret = FetchBatchData(hStmt, buffers, columnNames, rows, numCols, numRowsFetched, lobColumns);
     if (!SQL_SUCCEEDED(ret) && ret != SQL_NO_DATA) {
         LOG("Error when fetching data");
         return ret;
@@ -2775,6 +2879,35 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows) {
     }
     LOG("Fetching data in batch sizes of {}", fetchSize);
 
+    std::vector<SQLUSMALLINT> lobColumns;
+    for (SQLSMALLINT i = 0; i < numCols; i++) {
+        auto colMeta = columnNames[i].cast<py::dict>();
+        SQLSMALLINT dataType = colMeta["DataType"].cast<SQLSMALLINT>();
+        SQLULEN columnSize = colMeta["ColumnSize"].cast<SQLULEN>();
+
+        if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR ||
+             dataType == SQL_VARCHAR || dataType == SQL_LONGVARCHAR ||
+             dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY) &&
+            (columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > 8000)) {
+            lobColumns.push_back(i + 1); // 1-based
+        }
+    }
+
+    // If we have LOBs → fall back to row-by-row fetch + SQLGetData_wrap
+    if (!lobColumns.empty()) {
+        LOG("LOB columns detected → using per-row SQLGetData path");
+        while (true) {
+            ret = SQLFetch_ptr(hStmt);
+            if (ret == SQL_NO_DATA) break;
+            if (!SQL_SUCCEEDED(ret)) return ret;
+
+            py::list row;
+            SQLGetData_wrap(StatementHandle, numCols, row); // <-- streams LOBs correctly
+            rows.append(row);
+        }
+        return SQL_SUCCESS;
+    }
+
     ColumnBuffers buffers(numCols, fetchSize);
 
     // Bind columns
@@ -2789,7 +2922,7 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows) {
     SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, &numRowsFetched, 0);
 
     while (ret != SQL_NO_DATA) {
-        ret = FetchBatchData(hStmt, buffers, columnNames, rows, numCols, numRowsFetched);
+        ret = FetchBatchData(hStmt, buffers, columnNames, rows, numCols, numRowsFetched, lobColumns);
         if (!SQL_SUCCEEDED(ret) && ret != SQL_NO_DATA) {
            LOG("Error when fetching data");
            return ret;
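
Design note on the fallback: ODBC only guarantees SQLGetData for the current row of a single-row fetch (block-cursor support via SQL_GD_BLOCK is an optional driver extension), so the buffered SQLFetchScroll batch path cannot stream unbounded columns. When LOB columns are detected, FetchMany_wrap and FetchAll_wrap therefore switch to a per-row SQLFetch + SQLGetData_wrap loop, while FetchBatchData keeps FetchLobColumnData as a defensive path for cells that outgrow their bound buffers.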

tests/test_004_cursor.py

Lines changed: 58 additions & 21 deletions
@@ -6309,34 +6309,71 @@ def test_binary_data_over_8000_bytes(cursor, db_connection):
         drop_table_if_exists(cursor, "#pytest_small_binary")
         db_connection.commit()
 
-def test_binary_data_large(cursor, db_connection):
-    """Test insertion of binary data larger than 8000 bytes with streaming support."""
+def test_varbinarymax_insert_fetch(cursor, db_connection):
+    """Test for VARBINARY(MAX) insert and fetch (streaming support) using execute per row"""
     try:
-        drop_table_if_exists(cursor, "#pytest_large_binary")
+        # Create test table
+        drop_table_if_exists(cursor, "#pytest_varbinarymax")
         cursor.execute("""
-            CREATE TABLE #pytest_large_binary (
-                id INT PRIMARY KEY,
-                large_binary VARBINARY(MAX)
+            CREATE TABLE #pytest_varbinarymax (
+                id INT,
+                binary_data VARBINARY(MAX)
             )
         """)
-
-        # Large binary data > 8000 bytes
-        large_data = b'A' * 10000  # 10 KB
-        cursor.execute("INSERT INTO #pytest_large_binary (id, large_binary) VALUES (?, ?)", (1, large_data))
+
+        # Prepare test data
+        test_data = [
+            (2, b''),            # Empty bytes
+            (3, b'1234567890'),  # Small binary
+            (4, b'A' * 9000),    # Large binary > 8000 (streaming)
+            (5, b'B' * 20000),   # Large binary > 8000 (streaming)
+            (6, b'C' * 8000),    # Edge case: exactly 8000 bytes
+            (7, b'D' * 8001),    # Edge case: just over 8000 bytes
+        ]
+
+        # Insert each row using execute
+        for row_id, binary in test_data:
+            cursor.execute("INSERT INTO #pytest_varbinarymax VALUES (?, ?)", (row_id, binary))
         db_connection.commit()
-        print("Inserted large binary data (>8000 bytes) successfully.")
-
-        # commented out for now
-        # cursor.execute("SELECT large_binary FROM #pytest_large_binary WHERE id=1")
-        # result = cursor.fetchone()
-        # assert result[0] == large_data, f"Large binary data mismatch, got {len(result[0])} bytes"
-
-        # print("Large binary data (>8000 bytes) inserted and verified successfully.")
-
+
+        # ---------- FETCHONE TEST (multi-column) ----------
+        cursor.execute("SELECT id, binary_data FROM #pytest_varbinarymax ORDER BY id")
+        rows = []
+        while True:
+            row = cursor.fetchone()
+            if row is None:
+                break
+            rows.append(row)
+
+        assert len(rows) == len(test_data), f"Expected {len(test_data)} rows, got {len(rows)}"
+
+        # Validate each row
+        for i, (expected_id, expected_data) in enumerate(test_data):
+            fetched_id, fetched_data = rows[i]
+            assert fetched_id == expected_id, f"Row {i+1} ID mismatch: expected {expected_id}, got {fetched_id}"
+            assert isinstance(fetched_data, bytes), f"Row {i+1} expected bytes, got {type(fetched_data)}"
+            assert fetched_data == expected_data, f"Row {i+1} data mismatch"
+
+        # ---------- FETCHALL TEST ----------
+        cursor.execute("SELECT id, binary_data FROM #pytest_varbinarymax ORDER BY id")
+        all_rows = cursor.fetchall()
+        assert len(all_rows) == len(test_data)
+
+        # ---------- FETCHMANY TEST ----------
+        cursor.execute("SELECT id, binary_data FROM #pytest_varbinarymax ORDER BY id")
+        batch_size = 2
+        batches = []
+        while True:
+            batch = cursor.fetchmany(batch_size)
+            if not batch:
+                break
+            batches.extend(batch)
+        assert len(batches) == len(test_data)
+
     except Exception as e:
-        pytest.fail(f"Large binary data insertion test failed: {e}")
+        pytest.fail(f"VARBINARY(MAX) insert/fetch test failed: {e}")
     finally:
-        drop_table_if_exists(cursor, "#pytest_large_binary")
+        drop_table_if_exists(cursor, "#pytest_varbinarymax")
         db_connection.commit()
