Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 200 additions & 0 deletions mssql_python/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import weakref
import re
import codecs
from typing import Any
from mssql_python.cursor import Cursor
from mssql_python.helpers import add_driver_to_connection_str, sanitize_connection_string, sanitize_user_input, log
from mssql_python import ddbc_bindings
Expand Down Expand Up @@ -531,6 +532,174 @@ def cursor(self) -> Cursor:
self._cursors.add(cursor) # Track the cursor
return cursor

def execute(self, sql: str, *args: Any) -> Cursor:
"""
Creates a new Cursor object, calls its execute method, and returns the new cursor.

This is a convenience method that is not part of the DB API. Since a new Cursor
is allocated by each call, this should not be used if more than one SQL statement
needs to be executed on the connection.

Note on cursor lifecycle management:
- Each call creates a new cursor that is tracked by the connection's internal WeakSet
- Cursors are automatically dereferenced/closed when they go out of scope
- For long-running applications or loops, explicitly call cursor.close() when done
to release resources immediately rather than waiting for garbage collection

Args:
sql (str): The SQL query to execute.
*args: Parameters to be passed to the query.

Returns:
Cursor: A new cursor with the executed query.

Raises:
DatabaseError: If there is an error executing the query.
InterfaceError: If the connection is closed.

Example:
# Automatic cleanup (cursor goes out of scope after the operation)
row = connection.execute("SELECT name FROM users WHERE id = ?", 123).fetchone()

# Manual cleanup for more explicit resource management
cursor = connection.execute("SELECT * FROM large_table")
try:
# Use cursor...
rows = cursor.fetchall()
finally:
cursor.close() # Explicitly release resources
"""
cursor = self.cursor()
try:
# Add the cursor to our tracking set BEFORE execution
# This ensures it's tracked even if execution fails
self._cursors.add(cursor)

# Now execute the query
cursor.execute(sql, *args)
return cursor
except Exception:
# If execution fails, close the cursor to avoid leaking resources
cursor.close()
raise

def batch_execute(self, statements, params=None, reuse_cursor=None, auto_close=False):
"""
Execute multiple SQL statements efficiently using a single cursor.

This method allows executing multiple SQL statements in sequence using a single
cursor, which is more efficient than creating a new cursor for each statement.

Args:
statements (list): List of SQL statements to execute
params (list, optional): List of parameter sets corresponding to statements.
Each item can be None, a single parameter, or a sequence of parameters.
If None, no parameters will be used for any statement.
reuse_cursor (Cursor, optional): Existing cursor to reuse instead of creating a new one.
If None, a new cursor will be created.
auto_close (bool): Whether to close the cursor after execution if a new one was created.
Defaults to False. Has no effect if reuse_cursor is provided.

Returns:
tuple: (results, cursor) where:
- results is a list of execution results, one for each statement
- cursor is the cursor used for execution (useful if you want to keep using it)

Raises:
TypeError: If statements is not a list or if params is provided but not a list
ValueError: If params is provided but has different length than statements
DatabaseError: If there is an error executing any of the statements
InterfaceError: If the connection is closed

Example:
# Execute multiple statements with a single cursor
results, _ = conn.batch_execute([
"INSERT INTO users VALUES (?, ?)",
"UPDATE stats SET count = count + 1",
"SELECT * FROM users"
], [
(1, "user1"),
None,
None
])

# Last result contains the SELECT results
for row in results[-1]:
print(row)

# Reuse an existing cursor
my_cursor = conn.cursor()
results, _ = conn.batch_execute([
"SELECT * FROM table1",
"SELECT * FROM table2"
], reuse_cursor=my_cursor)

# Cursor remains open for further use
my_cursor.execute("SELECT * FROM table3")
"""
# Validate inputs
if not isinstance(statements, list):
raise TypeError("statements must be a list of SQL statements")

if params is not None:
if not isinstance(params, list):
raise TypeError("params must be a list of parameter sets")
if len(params) != len(statements):
raise ValueError("params list must have the same length as statements list")
else:
# Create a list of None values with the same length as statements
params = [None] * len(statements)

# Determine which cursor to use
is_new_cursor = reuse_cursor is None
cursor = self.cursor() if is_new_cursor else reuse_cursor

# Execute statements and collect results
results = []
try:
for i, (stmt, param) in enumerate(zip(statements, params)):
try:
# Execute the statement with parameters if provided
if param is not None:
cursor.execute(stmt, param)
else:
cursor.execute(stmt)

# For SELECT statements, fetch all rows
# For other statements, get the row count
if cursor.description is not None:
# This is a SELECT statement or similar that returns rows
results.append(cursor.fetchall())
else:
# This is an INSERT, UPDATE, DELETE or similar that doesn't return rows
results.append(cursor.rowcount)

log('debug', f"Executed batch statement {i+1}/{len(statements)}")

except Exception as e:
# If a statement fails, include statement context in the error
log('error', f"Error executing statement {i+1}/{len(statements)}: {e}")
raise

except Exception as e:
# If an error occurs and auto_close is True, close the cursor
if auto_close:
try:
# Close the cursor regardless of whether it's reused or new
cursor.close()
log('debug', "Automatically closed cursor after batch execution error")
except Exception as close_err:
log('warning', f"Error closing cursor after execution failure: {close_err}")
# Re-raise the original exception
raise

# Close the cursor if requested and we created a new one
if is_new_cursor and auto_close:
cursor.close()
log('debug', "Automatically closed cursor after batch execution")

return results, cursor

def commit(self) -> None:
"""
Commit the current transaction.
Expand All @@ -541,8 +710,16 @@ def commit(self) -> None:
that the changes are saved.

Raises:
InterfaceError: If the connection is closed.
DatabaseError: If there is an error while committing the transaction.
"""
# Check if connection is closed
if self._closed or self._conn is None:
raise InterfaceError(
driver_error="Cannot commit on a closed connection",
ddbc_error="Cannot commit on a closed connection",
)

# Commit the current transaction
self._conn.commit()
log('info', "Transaction committed successfully.")
Expand All @@ -556,8 +733,16 @@ def rollback(self) -> None:
transaction or if the changes should not be saved.

Raises:
InterfaceError: If the connection is closed.
DatabaseError: If there is an error while rolling back the transaction.
"""
# Check if connection is closed
if self._closed or self._conn is None:
raise InterfaceError(
driver_error="Cannot rollback on a closed connection",
ddbc_error="Cannot rollback on a closed connection",
)

# Roll back the current transaction
self._conn.rollback()
log('info', "Transaction rolled back successfully.")
Expand Down Expand Up @@ -623,6 +808,21 @@ def close(self) -> None:
self._closed = True

log('info', "Connection closed successfully.")

def _remove_cursor(self, cursor):
"""
Remove a cursor from the connection's tracking.

This method is called when a cursor is closed to ensure proper cleanup.

Args:
cursor: The cursor to remove from tracking.
"""
if hasattr(self, '_cursors'):
try:
self._cursors.discard(cursor)
except Exception:
pass # Ignore errors during cleanup

def __enter__(self) -> 'Connection':
"""
Expand Down
7 changes: 7 additions & 0 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,13 @@ def close(self) -> None:
# Clear messages per DBAPI
self.messages = []

# Remove this cursor from the connection's tracking
if hasattr(self, 'connection') and self.connection and hasattr(self.connection, '_cursors'):
try:
self.connection._cursors.discard(self)
except Exception as e:
log('warning', "Error removing cursor from connection tracking: %s", e)

if self.hstmt:
self.hstmt.free()
self.hstmt = None
Expand Down
52 changes: 26 additions & 26 deletions tests/test_001_globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,32 @@ def test_lowercase():
# Check if lowercase has the expected default value
assert lowercase is False, "lowercase should default to False"

def test_decimal_separator():
"""Test decimal separator functionality"""

# Check default value
assert getDecimalSeparator() == '.', "Default decimal separator should be '.'"

try:
# Test setting a new value
setDecimalSeparator(',')
assert getDecimalSeparator() == ',', "Decimal separator should be ',' after setting"

# Test invalid input
with pytest.raises(ValueError):
setDecimalSeparator('too long')

with pytest.raises(ValueError):
setDecimalSeparator('')

with pytest.raises(ValueError):
setDecimalSeparator(123) # Non-string input

finally:
# Restore default value
setDecimalSeparator('.')
assert getDecimalSeparator() == '.', "Decimal separator should be restored to '.'"

def test_lowercase_thread_safety_no_db():
"""
Tests concurrent modifications to mssql_python.lowercase without database interaction.
Expand Down Expand Up @@ -152,32 +178,6 @@ def test_lowercase():
# Check if lowercase has the expected default value
assert lowercase is False, "lowercase should default to False"

def test_decimal_separator():
"""Test decimal separator functionality"""

# Check default value
assert getDecimalSeparator() == '.', "Default decimal separator should be '.'"

try:
# Test setting a new value
setDecimalSeparator(',')
assert getDecimalSeparator() == ',', "Decimal separator should be ',' after setting"

# Test invalid input
with pytest.raises(ValueError):
setDecimalSeparator('too long')

with pytest.raises(ValueError):
setDecimalSeparator('')

with pytest.raises(ValueError):
setDecimalSeparator(123) # Non-string input

finally:
# Restore default value
setDecimalSeparator('.')
assert getDecimalSeparator() == '.', "Decimal separator should be restored to '.'"

def test_decimal_separator_edge_cases():
"""Test decimal separator edge cases and boundary conditions"""
import decimal
Expand Down
Loading