Skip to content

Commit a1db275

Browse files
authored
FEAT: adding cursor.tables (#185)
### Work Item / Issue Reference <!-- IMPORTANT: Please follow the PR template guidelines below. For mssql-python maintainers: Insert your ADO Work Item ID below (e.g. AB#37452) For external contributors: Insert Github Issue number below (e.g. #149) Only one reference is required - either GitHub issue OR ADO Work Item. --> <!-- mssql-python maintainers: ADO Work Item --> > [AB#34926](https://sqlclientdrivers.visualstudio.com/c6d89619-62de-46a0-8b46-70b92a84d85e/_workitems/edit/34926) ------------------------------------------------------------------- ### Summary This pull request adds a new `tables()` method to the `Cursor` class in `mssql_python/cursor.py`, providing a way to query metadata about tables in the database, including support for filtering by name, schema, catalog, and table type. It also introduces comprehensive test coverage for this new method in `tests/test_004_cursor.py`. Additionally, the `skip()` method in the cursor is simplified by delegating to the existing `scroll()` method. **New feature: Table metadata querying** - Added a `tables()` method to the `Cursor` class, enabling users to retrieve information about tables with support for filtering by table name (including temporary tables), schema, catalog, and table type (supports both string and list input). The method returns the cursor itself for easy chaining and iteration. **Testing improvements** - Introduced a suite of tests for the new `tables()` method, covering basic usage, filtering by name, schema, and type, wildcard support, combined filters, empty results, iteration, method chaining, and existence checks. These tests ensure the method works as intended and handles edge cases. **Code simplification** - Refactored the `skip()` method in the cursor to delegate to the `scroll()` method in 'relative' mode, removing redundant validation and manual row skipping logic. --------- Co-authored-by: Jahnvi Thakkar <[email protected]>
1 parent fcbdd71 commit a1db275

File tree

5 files changed

+631
-35
lines changed

5 files changed

+631
-35
lines changed

mssql_python/cursor.py

Lines changed: 140 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -908,8 +908,9 @@ def fetchone(self) -> Union[None, Row]:
908908
# Update internal position after successful fetch
909909
self._increment_rownumber()
910910

911-
# Create and return a Row object
912-
return Row(row_data, self.description)
911+
# Create and return a Row object, passing column name map if available
912+
column_map = getattr(self, '_column_name_map', None)
913+
return Row(row_data, self.description, column_map)
913914
except Exception as e:
914915
# On error, don't increment rownumber - rethrow the error
915916
raise e
@@ -948,7 +949,8 @@ def fetchmany(self, size: int = None) -> List[Row]:
948949
self._rownumber = self._next_row_index - 1
949950

950951
# Convert raw data to Row objects
951-
return [Row(row_data, self.description) for row_data in rows_data]
952+
column_map = getattr(self, '_column_name_map', None)
953+
return [Row(row_data, self.description, column_map) for row_data in rows_data]
952954
except Exception as e:
953955
# On error, don't increment rownumber - rethrow the error
954956
raise e
@@ -977,7 +979,8 @@ def fetchall(self) -> List[Row]:
977979
self._rownumber = self._next_row_index - 1
978980

979981
# Convert raw data to Row objects
980-
return [Row(row_data, self.description) for row_data in rows_data]
982+
column_map = getattr(self, '_column_name_map', None)
983+
return [Row(row_data, self.description, column_map) for row_data in rows_data]
981984
except Exception as e:
982985
# On error, don't increment rownumber - rethrow the error
983986
raise e
@@ -1258,19 +1261,139 @@ def skip(self, count: int) -> None:
12581261
# Clear messages
12591262
self.messages = []
12601263

1261-
# Validate arguments
1262-
if not isinstance(count, int):
1263-
raise ProgrammingError("Count must be an integer", "Invalid argument type")
1264+
# Simply delegate to the scroll method with 'relative' mode
1265+
self.scroll(count, 'relative')
1266+
1267+
def _execute_tables(self, stmt_handle, catalog_name=None, schema_name=None, table_name=None,
1268+
table_type=None, search_escape=None):
1269+
"""
1270+
Execute SQLTables ODBC function to retrieve table metadata.
12641271
1265-
if count < 0:
1266-
raise NotSupportedError("Negative skip values are not supported", "Backward scrolling not supported")
1272+
Args:
1273+
stmt_handle: ODBC statement handle
1274+
catalog_name: The catalog name pattern
1275+
schema_name: The schema name pattern
1276+
table_name: The table name pattern
1277+
table_type: The table type filter
1278+
search_escape: The escape character for pattern matching
1279+
"""
1280+
# Convert None values to empty strings for ODBC
1281+
catalog = "" if catalog_name is None else catalog_name
1282+
schema = "" if schema_name is None else schema_name
1283+
table = "" if table_name is None else table_name
1284+
types = "" if table_type is None else table_type
12671285

1268-
# Skip zero is a no-op
1269-
if count == 0:
1270-
return
1286+
# Call the ODBC SQLTables function
1287+
retcode = ddbc_bindings.DDBCSQLTables(
1288+
stmt_handle,
1289+
catalog,
1290+
schema,
1291+
table,
1292+
types
1293+
)
1294+
1295+
# Check return code and handle errors
1296+
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, stmt_handle, retcode)
1297+
1298+
# Capture any diagnostic messages
1299+
if stmt_handle:
1300+
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(stmt_handle))
1301+
1302+
def tables(self, table=None, catalog=None, schema=None, tableType=None):
1303+
"""
1304+
Returns information about tables in the database that match the given criteria using
1305+
the SQLTables ODBC function.
1306+
1307+
Args:
1308+
table (str, optional): The table name pattern. Default is None (all tables).
1309+
catalog (str, optional): The catalog name. Default is None.
1310+
schema (str, optional): The schema name pattern. Default is None.
1311+
tableType (str or list, optional): The table type filter. Default is None.
1312+
Example: "TABLE" or ["TABLE", "VIEW"]
1313+
1314+
Returns:
1315+
list: A list of Row objects containing table information with these columns:
1316+
- table_cat: Catalog name
1317+
- table_schem: Schema name
1318+
- table_name: Table name
1319+
- table_type: Table type (e.g., "TABLE", "VIEW")
1320+
- remarks: Comments about the table
1321+
1322+
Notes:
1323+
This method only processes the standard five columns as defined in the ODBC
1324+
specification. Any additional columns that might be returned by specific ODBC
1325+
drivers are not included in the result set.
1326+
1327+
Example:
1328+
# Get all tables in the database
1329+
tables = cursor.tables()
1330+
1331+
# Get all tables in schema 'dbo'
1332+
tables = cursor.tables(schema='dbo')
1333+
1334+
# Get table named 'Customers'
1335+
tables = cursor.tables(table='Customers')
1336+
1337+
# Get all views
1338+
tables = cursor.tables(tableType='VIEW')
1339+
"""
1340+
self._check_closed()
1341+
1342+
# Clear messages
1343+
self.messages = []
1344+
1345+
# Always reset the cursor first to ensure clean state
1346+
self._reset_cursor()
1347+
1348+
# Format table_type parameter - SQLTables expects comma-separated string
1349+
table_type_str = None
1350+
if tableType is not None:
1351+
if isinstance(tableType, (list, tuple)):
1352+
table_type_str = ",".join(tableType)
1353+
else:
1354+
table_type_str = str(tableType)
1355+
1356+
# Call SQLTables via the helper method
1357+
self._execute_tables(
1358+
self.hstmt,
1359+
catalog_name=catalog,
1360+
schema_name=schema,
1361+
table_name=table,
1362+
table_type=table_type_str
1363+
)
1364+
1365+
# Initialize description from column metadata
1366+
column_metadata = []
1367+
try:
1368+
ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata)
1369+
self._initialize_description(column_metadata)
1370+
except Exception:
1371+
# If describe fails, create a manual description for the standard columns
1372+
column_types = [str, str, str, str, str]
1373+
self.description = [
1374+
("table_cat", column_types[0], None, 128, 128, 0, True),
1375+
("table_schem", column_types[1], None, 128, 128, 0, True),
1376+
("table_name", column_types[2], None, 128, 128, 0, False),
1377+
("table_type", column_types[3], None, 128, 128, 0, False),
1378+
("remarks", column_types[4], None, 254, 254, 0, True)
1379+
]
1380+
1381+
# Define column names in ODBC standard order
1382+
column_names = [
1383+
"table_cat", "table_schem", "table_name", "table_type", "remarks"
1384+
]
1385+
1386+
# Fetch all rows
1387+
rows_data = []
1388+
ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data)
1389+
1390+
# Create a column map for attribute access
1391+
column_map = {name: i for i, name in enumerate(column_names)}
1392+
1393+
# Create Row objects with the column map
1394+
result_rows = []
1395+
for row_data in rows_data:
1396+
row = Row(row_data, self.description, column_map)
1397+
result_rows.append(row)
12711398

1272-
# Skip the rows by fetching and discarding
1273-
for _ in range(count):
1274-
row = self.fetchone()
1275-
if row is None:
1276-
raise IndexError("Cannot skip beyond the end of the result set")
1399+
return result_rows

mssql_python/pybind/ddbc_bindings.cpp

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ SQLFreeStmtFunc SQLFreeStmt_ptr = nullptr;
134134

135135
// Diagnostic APIs
136136
SQLGetDiagRecFunc SQLGetDiagRec_ptr = nullptr;
137+
SQLTablesFunc SQLTables_ptr = nullptr;
137138

138139
namespace {
139140

@@ -786,6 +787,7 @@ DriverHandle LoadDriverOrThrowException() {
786787
SQLFreeStmt_ptr = GetFunctionPointer<SQLFreeStmtFunc>(handle, "SQLFreeStmt");
787788

788789
SQLGetDiagRec_ptr = GetFunctionPointer<SQLGetDiagRecFunc>(handle, "SQLGetDiagRecW");
790+
SQLTables_ptr = GetFunctionPointer<SQLTablesFunc>(handle, "SQLTablesW");
789791

790792
bool success =
791793
SQLAllocHandle_ptr && SQLSetEnvAttr_ptr && SQLSetConnectAttr_ptr &&
@@ -796,7 +798,7 @@ DriverHandle LoadDriverOrThrowException() {
796798
SQLGetData_ptr && SQLNumResultCols_ptr && SQLBindCol_ptr &&
797799
SQLDescribeCol_ptr && SQLMoreResults_ptr && SQLColAttribute_ptr &&
798800
SQLEndTran_ptr && SQLDisconnect_ptr && SQLFreeHandle_ptr &&
799-
SQLFreeStmt_ptr && SQLGetDiagRec_ptr;
801+
SQLFreeStmt_ptr && SQLGetDiagRec_ptr && SQLTables_ptr;
800802

801803
if (!success) {
802804
ThrowStdException("Failed to load required function pointers from driver.");
@@ -982,6 +984,91 @@ SQLRETURN SQLExecDirect_wrap(SqlHandlePtr StatementHandle, const std::wstring& Q
982984
return ret;
983985
}
984986

987+
// Wrapper for SQLTables
988+
SQLRETURN SQLTables_wrap(SqlHandlePtr StatementHandle,
989+
const std::wstring& catalog,
990+
const std::wstring& schema,
991+
const std::wstring& table,
992+
const std::wstring& tableType) {
993+
994+
if (!SQLTables_ptr) {
995+
LOG("Function pointer not initialized. Loading the driver.");
996+
DriverLoader::getInstance().loadDriver();
997+
}
998+
999+
SQLWCHAR* catalogPtr = nullptr;
1000+
SQLWCHAR* schemaPtr = nullptr;
1001+
SQLWCHAR* tablePtr = nullptr;
1002+
SQLWCHAR* tableTypePtr = nullptr;
1003+
SQLSMALLINT catalogLen = 0;
1004+
SQLSMALLINT schemaLen = 0;
1005+
SQLSMALLINT tableLen = 0;
1006+
SQLSMALLINT tableTypeLen = 0;
1007+
1008+
std::vector<SQLWCHAR> catalogBuffer;
1009+
std::vector<SQLWCHAR> schemaBuffer;
1010+
std::vector<SQLWCHAR> tableBuffer;
1011+
std::vector<SQLWCHAR> tableTypeBuffer;
1012+
1013+
#if defined(__APPLE__) || defined(__linux__)
1014+
// On Unix platforms, convert wstring to SQLWCHAR array
1015+
if (!catalog.empty()) {
1016+
catalogBuffer = WStringToSQLWCHAR(catalog);
1017+
catalogPtr = catalogBuffer.data();
1018+
catalogLen = SQL_NTS;
1019+
}
1020+
if (!schema.empty()) {
1021+
schemaBuffer = WStringToSQLWCHAR(schema);
1022+
schemaPtr = schemaBuffer.data();
1023+
schemaLen = SQL_NTS;
1024+
}
1025+
if (!table.empty()) {
1026+
tableBuffer = WStringToSQLWCHAR(table);
1027+
tablePtr = tableBuffer.data();
1028+
tableLen = SQL_NTS;
1029+
}
1030+
if (!tableType.empty()) {
1031+
tableTypeBuffer = WStringToSQLWCHAR(tableType);
1032+
tableTypePtr = tableTypeBuffer.data();
1033+
tableTypeLen = SQL_NTS;
1034+
}
1035+
#else
1036+
// On Windows, direct assignment works
1037+
if (!catalog.empty()) {
1038+
catalogPtr = const_cast<SQLWCHAR*>(catalog.c_str());
1039+
catalogLen = SQL_NTS;
1040+
}
1041+
if (!schema.empty()) {
1042+
schemaPtr = const_cast<SQLWCHAR*>(schema.c_str());
1043+
schemaLen = SQL_NTS;
1044+
}
1045+
if (!table.empty()) {
1046+
tablePtr = const_cast<SQLWCHAR*>(table.c_str());
1047+
tableLen = SQL_NTS;
1048+
}
1049+
if (!tableType.empty()) {
1050+
tableTypePtr = const_cast<SQLWCHAR*>(tableType.c_str());
1051+
tableTypeLen = SQL_NTS;
1052+
}
1053+
#endif
1054+
1055+
SQLRETURN ret = SQLTables_ptr(
1056+
StatementHandle->get(),
1057+
catalogPtr, catalogLen,
1058+
schemaPtr, schemaLen,
1059+
tablePtr, tableLen,
1060+
tableTypePtr, tableTypeLen
1061+
);
1062+
1063+
if (!SQL_SUCCEEDED(ret)) {
1064+
LOG("SQLTables failed with return code: {}", ret);
1065+
} else {
1066+
LOG("SQLTables succeeded");
1067+
}
1068+
1069+
return ret;
1070+
}
1071+
9851072
// Executes the provided query. If the query is parametrized, it prepares the statement and
9861073
// binds the parameters. Otherwise, it executes the query directly.
9871074
// 'usePrepare' parameter can be used to disable the prepare step for queries that might already
@@ -2616,6 +2703,12 @@ PYBIND11_MODULE(ddbc_bindings, m) {
26162703
m.def("DDBCSQLGetAllDiagRecords", &SQLGetAllDiagRecords,
26172704
"Get all diagnostic records for a handle",
26182705
py::arg("handle"));
2706+
// Add to PYBIND11_MODULE section
2707+
m.def("DDBCSQLTables", &SQLTables_wrap,
2708+
"Get table information using ODBC SQLTables",
2709+
py::arg("StatementHandle"), py::arg("catalog") = std::wstring(),
2710+
py::arg("schema") = std::wstring(), py::arg("table") = std::wstring(),
2711+
py::arg("tableType") = std::wstring());
26192712

26202713
// Add a version attribute
26212714
m.attr("__version__") = "1.0.0";

mssql_python/pybind/ddbc_bindings.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,18 @@ typedef SQLRETURN (SQL_API* SQLDescribeColFunc)(SQLHSTMT, SQLUSMALLINT, SQLWCHAR
105105
typedef SQLRETURN (SQL_API* SQLMoreResultsFunc)(SQLHSTMT);
106106
typedef SQLRETURN (SQL_API* SQLColAttributeFunc)(SQLHSTMT, SQLUSMALLINT, SQLUSMALLINT, SQLPOINTER,
107107
SQLSMALLINT, SQLSMALLINT*, SQLPOINTER);
108-
108+
typedef SQLRETURN (*SQLTablesFunc)(
109+
SQLHSTMT StatementHandle,
110+
SQLWCHAR* CatalogName,
111+
SQLSMALLINT NameLength1,
112+
SQLWCHAR* SchemaName,
113+
SQLSMALLINT NameLength2,
114+
SQLWCHAR* TableName,
115+
SQLSMALLINT NameLength3,
116+
SQLWCHAR* TableType,
117+
SQLSMALLINT NameLength4
118+
);
119+
109120
// Transaction APIs
110121
typedef SQLRETURN (SQL_API* SQLEndTranFunc)(SQLSMALLINT, SQLHANDLE, SQLSMALLINT);
111122

@@ -148,6 +159,7 @@ extern SQLBindColFunc SQLBindCol_ptr;
148159
extern SQLDescribeColFunc SQLDescribeCol_ptr;
149160
extern SQLMoreResultsFunc SQLMoreResults_ptr;
150161
extern SQLColAttributeFunc SQLColAttribute_ptr;
162+
extern SQLTablesFunc SQLTables_ptr;
151163

152164
// Transaction APIs
153165
extern SQLEndTranFunc SQLEndTran_ptr;

mssql_python/row.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,27 +9,27 @@ class Row:
99
print(row.column_name) # Access by column name
1010
"""
1111

12-
def __init__(self, values, cursor_description):
12+
def __init__(self, values, description, column_map=None):
1313
"""
14-
Initialize a Row object with values and cursor description.
14+
Initialize a Row object with values and description.
1515
1616
Args:
17-
values: List of values for this row
18-
cursor_description: The cursor description containing column metadata
17+
values: List of values for this row.
18+
description: Description of the columns (from cursor.description).
19+
column_map: Optional mapping of column names to indices.
1920
"""
2021
self._values = values
22+
self._description = description
2123

22-
# TODO: ADO task - Optimize memory usage by sharing column map across rows
23-
# Instead of storing the full cursor_description in each Row object:
24-
# 1. Build the column map once at the cursor level after setting description
25-
# 2. Pass only this map to each Row instance
26-
# 3. Remove cursor_description from Row objects entirely
27-
28-
# Create mapping of column names to indices
29-
self._column_map = {}
30-
for i, desc in enumerate(cursor_description):
31-
if desc and desc[0]: # Ensure column name exists
32-
self._column_map[desc[0]] = i
24+
# Build column map if not provided
25+
if column_map is None:
26+
self._column_map = {}
27+
for i, desc in enumerate(description):
28+
col_name = desc[0]
29+
self._column_map[col_name] = i
30+
self._column_map[col_name.lower()] = i # Add lowercase for case-insensitivity
31+
else:
32+
self._column_map = column_map
3333

3434
def __getitem__(self, index):
3535
"""Allow accessing by numeric index: row[0]"""

0 commit comments

Comments
 (0)