Skip to content

Commit fcbdd71

Browse files
committed
FEAT: Adding implementation for Cursor.message
1 parent 1b580bc commit fcbdd71

File tree

3 files changed

+334
-18
lines changed

3 files changed

+334
-18
lines changed

mssql_python/cursor.py

Lines changed: 71 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ def __init__(self, connection) -> None:
8080
self._next_row_index = 0 # internal: index of the next row the driver will return (0-based)
8181
self._has_result_set = False # Track if we have an active result set
8282

83+
self.messages = [] # Store diagnostic messages
84+
8385
def _is_unicode_string(self, param):
8486
"""
8587
Check if a string contains non-ASCII characters.
@@ -452,6 +454,9 @@ def close(self) -> None:
452454
if self.closed:
453455
raise Exception("Cursor is already closed.")
454456

457+
# Clear messages per DBAPI
458+
self.messages = []
459+
455460
if self.hstmt:
456461
self.hstmt.free()
457462
self.hstmt = None
@@ -695,6 +700,9 @@ def execute(
695700
if reset_cursor:
696701
self._reset_cursor()
697702

703+
# Clear any previous messages
704+
self.messages = []
705+
698706
param_info = ddbc_bindings.ParamInfo
699707
parameters_type = []
700708

@@ -742,7 +750,14 @@ def execute(
742750
self.is_stmt_prepared,
743751
use_prepare,
744752
)
753+
754+
# Check for errors but don't raise exceptions for info/warning messages
745755
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)
756+
757+
# Capture any diagnostic messages (SQL_SUCCESS_WITH_INFO, etc.)
758+
if self.hstmt:
759+
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))
760+
746761
self.last_executed_stmt = operation
747762

748763
# Update rowcount after execution
@@ -822,7 +837,10 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None:
822837
"""
823838
self._check_closed()
824839
self._reset_cursor()
825-
840+
841+
# Clear any previous messages
842+
self.messages = []
843+
826844
if not seq_of_parameters:
827845
self.rowcount = 0
828846
return
@@ -854,6 +872,10 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None:
854872
)
855873
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)
856874

875+
# Capture any diagnostic messages after execution
876+
if self.hstmt:
877+
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))
878+
857879
self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt)
858880
self.last_executed_stmt = operation
859881
self._initialize_description()
@@ -877,6 +899,9 @@ def fetchone(self) -> Union[None, Row]:
877899
try:
878900
ret = ddbc_bindings.DDBCSQLFetchOne(self.hstmt, row_data)
879901

902+
if self.hstmt:
903+
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))
904+
880905
if ret == ddbc_sql_const.SQL_NO_DATA.value:
881906
return None
882907

@@ -911,6 +936,10 @@ def fetchmany(self, size: int = None) -> List[Row]:
911936
rows_data = []
912937
try:
913938
ret = ddbc_bindings.DDBCSQLFetchMany(self.hstmt, rows_data, size)
939+
940+
if self.hstmt:
941+
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))
942+
914943

915944
# Update rownumber for the number of rows actually fetched
916945
if rows_data and self._has_result_set:
@@ -937,6 +966,10 @@ def fetchall(self) -> List[Row]:
937966
rows_data = []
938967
try:
939968
ret = ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data)
969+
970+
if self.hstmt:
971+
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))
972+
940973

941974
# Update rownumber for the number of rows actually fetched
942975
if rows_data and self._has_result_set:
@@ -961,6 +994,9 @@ def nextset(self) -> Union[bool, None]:
961994
"""
962995
self._check_closed() # Check if the cursor is closed
963996

997+
# Clear messages per DBAPI
998+
self.messages = []
999+
9641000
# Skip to the next result set
9651001
ret = ddbc_bindings.DDBCSQLMoreResults(self.hstmt)
9661002
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)
@@ -1041,6 +1077,9 @@ def commit(self):
10411077
"""
10421078
self._check_closed() # Check if the cursor is closed
10431079

1080+
# Clear messages per DBAPI
1081+
self.messages = []
1082+
10441083
# Delegate to the connection's commit method
10451084
self._connection.commit()
10461085

@@ -1067,6 +1106,9 @@ def rollback(self):
10671106
"""
10681107
self._check_closed() # Check if the cursor is closed
10691108

1109+
# Clear messages per DBAPI
1110+
self.messages = []
1111+
10701112
# Delegate to the connection's rollback method
10711113
self._connection.rollback()
10721114

@@ -1090,6 +1132,10 @@ def scroll(self, value: int, mode: str = 'relative') -> None:
10901132
This implementation emulates scrolling for forward-only cursors by consuming rows.
10911133
"""
10921134
self._check_closed()
1135+
1136+
# Clear messages per DBAPI
1137+
self.messages = []
1138+
10931139
if mode not in ('relative', 'absolute'):
10941140
raise ProgrammingError(
10951141
driver_error="Invalid scroll mode",
@@ -1195,29 +1241,36 @@ def _consume_rows_for_scroll(self, rows_to_consume: int) -> None:
11951241

11961242
def skip(self, count: int) -> None:
11971243
"""
1198-
Skip the next 'count' records in the query result set.
1199-
1200-
This is a convenience method that advances the cursor by 'count'
1201-
positions without returning the skipped rows.
1244+
Skip the next count records in the query result set.
12021245
12031246
Args:
1204-
count: Number of records to skip. Must be non-negative.
1205-
1206-
Returns:
1207-
None
1247+
count: Number of records to skip.
12081248
12091249
Raises:
1210-
ProgrammingError: If the cursor is closed or no result set is available.
1211-
NotSupportedError: If count is negative (backward scrolling not supported).
12121250
IndexError: If attempting to skip past the end of the result set.
1213-
1214-
Note:
1215-
For convenience, skip(0) is accepted and will do nothing.
1251+
ProgrammingError: If count is not an integer.
1252+
NotSupportedError: If attempting to skip backwards.
12161253
"""
1254+
from mssql_python.exceptions import ProgrammingError, NotSupportedError
1255+
12171256
self._check_closed()
12181257

1219-
if count == 0: # Skip 0 is a no-op
1258+
# Clear messages
1259+
self.messages = []
1260+
1261+
# Validate arguments
1262+
if not isinstance(count, int):
1263+
raise ProgrammingError("Count must be an integer", "Invalid argument type")
1264+
1265+
if count < 0:
1266+
raise NotSupportedError("Negative skip values are not supported", "Backward scrolling not supported")
1267+
1268+
# Skip zero is a no-op
1269+
if count == 0:
12201270
return
1221-
1222-
# Use existing scroll method with relative mode
1223-
self.scroll(count, 'relative')
1271+
1272+
# Skip the rows by fetching and discarding
1273+
for _ in range(count):
1274+
row = self.fetchone()
1275+
if row is None:
1276+
raise IndexError("Cannot skip beyond the end of the result set")

mssql_python/pybind/ddbc_bindings.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -901,6 +901,65 @@ ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRET
901901
return errorInfo;
902902
}
903903

904+
py::list SQLGetAllDiagRecords(SqlHandlePtr handle) {
905+
LOG("Retrieving all diagnostic records");
906+
if (!SQLGetDiagRec_ptr) {
907+
LOG("Function pointer not initialized. Loading the driver.");
908+
DriverLoader::getInstance().loadDriver();
909+
}
910+
911+
py::list records;
912+
SQLHANDLE rawHandle = handle->get();
913+
SQLSMALLINT handleType = handle->type();
914+
915+
// Iterate through all available diagnostic records
916+
for (SQLSMALLINT recNumber = 1; ; recNumber++) {
917+
SQLWCHAR sqlState[6] = {0};
918+
SQLWCHAR message[SQL_MAX_MESSAGE_LENGTH] = {0};
919+
SQLINTEGER nativeError = 0;
920+
SQLSMALLINT messageLen = 0;
921+
922+
SQLRETURN diagReturn = SQLGetDiagRec_ptr(
923+
handleType, rawHandle, recNumber, sqlState, &nativeError,
924+
message, SQL_MAX_MESSAGE_LENGTH, &messageLen);
925+
926+
if (diagReturn == SQL_NO_DATA || !SQL_SUCCEEDED(diagReturn))
927+
break;
928+
929+
#if defined(_WIN32)
930+
// On Windows, create a formatted UTF-8 string for state+error
931+
char stateWithError[50];
932+
sprintf(stateWithError, "[%ls] (%d)", sqlState, nativeError);
933+
934+
// Convert wide string message to UTF-8
935+
int msgSize = WideCharToMultiByte(CP_UTF8, 0, message, -1, NULL, 0, NULL, NULL);
936+
std::vector<char> msgBuffer(msgSize);
937+
WideCharToMultiByte(CP_UTF8, 0, message, -1, msgBuffer.data(), msgSize, NULL, NULL);
938+
939+
// Create the tuple with converted strings
940+
records.append(py::make_tuple(
941+
py::str(stateWithError),
942+
py::str(msgBuffer.data())
943+
));
944+
#else
945+
// On Unix, use the SQLWCHARToWString utility and then convert to UTF-8
946+
std::string stateStr = WideToUTF8(SQLWCHARToWString(sqlState));
947+
std::string msgStr = WideToUTF8(SQLWCHARToWString(message, messageLen));
948+
949+
// Format the state string
950+
std::string stateWithError = "[" + stateStr + "] (" + std::to_string(nativeError) + ")";
951+
952+
// Create the tuple with converted strings
953+
records.append(py::make_tuple(
954+
py::str(stateWithError),
955+
py::str(msgStr)
956+
));
957+
#endif
958+
}
959+
960+
return records;
961+
}
962+
904963
// Wrap SQLExecDirect
905964
SQLRETURN SQLExecDirect_wrap(SqlHandlePtr StatementHandle, const std::wstring& Query) {
906965
LOG("Execute SQL query directly - {}", Query.c_str());
@@ -2553,6 +2612,10 @@ PYBIND11_MODULE(ddbc_bindings, m) {
25532612
m.def("DDBCSQLFetchAll", &FetchAll_wrap, "Fetch all rows from the result set");
25542613
m.def("DDBCSQLFreeHandle", &SQLFreeHandle_wrap, "Free a handle");
25552614
m.def("DDBCSQLCheckError", &SQLCheckError_Wrap, "Check for driver errors");
2615+
// Add this to your PYBIND11_MODULE section
2616+
m.def("DDBCSQLGetAllDiagRecords", &SQLGetAllDiagRecords,
2617+
"Get all diagnostic records for a handle",
2618+
py::arg("handle"));
25562619

25572620
// Add a version attribute
25582621
m.attr("__version__") = "1.0.0";

0 commit comments

Comments
 (0)