Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mssql_python/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ class ConstantsDDBC(Enum):
SQL_FETCH_ABSOLUTE = 5
SQL_FETCH_RELATIVE = 6
SQL_FETCH_BOOKMARK = 8
SQL_DATETIMEOFFSET = -155
SQL_C_SS_TIMESTAMPOFFSET = 0x4001
SQL_SCOPE_CURROW = 0
SQL_BEST_ROWID = 1
SQL_ROWVER = 2
Expand Down
25 changes: 18 additions & 7 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,13 +443,24 @@ def _map_sql_type(self, param, parameters_list, i, min_val=None, max_val=None):
)

if isinstance(param, datetime.datetime):
return (
ddbc_sql_const.SQL_TIMESTAMP.value,
ddbc_sql_const.SQL_C_TYPE_TIMESTAMP.value,
26,
6,
False,
)
if param.tzinfo is not None:
# Timezone-aware datetime -> DATETIMEOFFSET
return (
ddbc_sql_const.SQL_DATETIMEOFFSET.value,
ddbc_sql_const.SQL_C_SS_TIMESTAMPOFFSET.value,
34,
7,
False,
)
else:
# Naive datetime -> TIMESTAMP
return (
ddbc_sql_const.SQL_TIMESTAMP.value,
ddbc_sql_const.SQL_C_TYPE_TIMESTAMP.value,
26,
6,
False,
)

if isinstance(param, datetime.date):
return (
Expand Down
111 changes: 108 additions & 3 deletions mssql_python/pybind/ddbc_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
#include <iostream>
#include <utility> // std::forward
#include <filesystem>

//-------------------------------------------------------------------------------------------------
// Macro definitions
//-------------------------------------------------------------------------------------------------

// This constant is not exposed via sql.h, hence define it here
#define SQL_SS_TIME2 (-154)

#define SQL_SS_TIMESTAMPOFFSET (-155)
#define SQL_C_SS_TIMESTAMPOFFSET (0x4001)
#define MAX_DIGITS_IN_NUMERIC 64

#define STRINGIFY_FOR_CASE(x) \
Expand Down Expand Up @@ -94,6 +94,20 @@ struct ColumnBuffers {
indicators(numCols, std::vector<SQLLEN>(fetchSize)) {}
};

// Struct to hold the DateTimeOffset structure
struct DateTimeOffset
{
SQLSMALLINT year;
SQLUSMALLINT month;
SQLUSMALLINT day;
SQLUSMALLINT hour;
SQLUSMALLINT minute;
SQLUSMALLINT second;
SQLUINTEGER fraction; // Nanoseconds
SQLSMALLINT timezone_hour; // Offset hours from UTC
SQLSMALLINT timezone_minute; // Offset minutes from UTC
};

//-------------------------------------------------------------------------------------------------
// Function pointer initialization
//-------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -463,6 +477,49 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params,
dataPtr = static_cast<void*>(sqlTimePtr);
break;
}
case SQL_C_SS_TIMESTAMPOFFSET: {
py::object datetimeType = py::module_::import("datetime").attr("datetime");
if (!py::isinstance(param, datetimeType)) {
ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex));
}
// Checking if the object has a timezone
py::object tzinfo = param.attr("tzinfo");
if (tzinfo.is_none()) {
ThrowStdException("Datetime object must have tzinfo for SQL_C_SS_TIMESTAMPOFFSET at paramIndex " + std::to_string(paramIndex));
}

DateTimeOffset* dtoPtr = AllocateParamBuffer<DateTimeOffset>(paramBuffers);

dtoPtr->year = static_cast<SQLSMALLINT>(param.attr("year").cast<int>());
dtoPtr->month = static_cast<SQLUSMALLINT>(param.attr("month").cast<int>());
dtoPtr->day = static_cast<SQLUSMALLINT>(param.attr("day").cast<int>());
dtoPtr->hour = static_cast<SQLUSMALLINT>(param.attr("hour").cast<int>());
dtoPtr->minute = static_cast<SQLUSMALLINT>(param.attr("minute").cast<int>());
dtoPtr->second = static_cast<SQLUSMALLINT>(param.attr("second").cast<int>());
dtoPtr->fraction = static_cast<SQLUINTEGER>(param.attr("microsecond").cast<int>() * 1000);

py::object utcoffset = tzinfo.attr("utcoffset")(param);
if (utcoffset.is_none()) {
ThrowStdException("Datetime object's tzinfo.utcoffset() returned None at paramIndex " + std::to_string(paramIndex));
}

int total_seconds = static_cast<int>(utcoffset.attr("total_seconds")().cast<double>());
const int MAX_OFFSET = 14 * 3600;
const int MIN_OFFSET = -14 * 3600;

if (total_seconds > MAX_OFFSET || total_seconds < MIN_OFFSET) {
ThrowStdException("Datetimeoffset tz offset out of SQL Server range (-14h to +14h) at paramIndex " + std::to_string(paramIndex));
}
std::div_t div_result = std::div(total_seconds, 3600);
dtoPtr->timezone_hour = static_cast<SQLSMALLINT>(div_result.quot);
dtoPtr->timezone_minute = static_cast<SQLSMALLINT>(div(div_result.rem, 60).quot);

dataPtr = static_cast<void*>(dtoPtr);
bufferLength = sizeof(DateTimeOffset);
strLenOrIndPtr = AllocateParamBuffer<SQLLEN>(paramBuffers);
*strLenOrIndPtr = bufferLength;
break;
}
case SQL_C_TYPE_TIMESTAMP: {
py::object datetimeType = py::module_::import("datetime").attr("datetime");
if (!py::isinstance(param, datetimeType)) {
Expand Down Expand Up @@ -514,7 +571,6 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params,
}
}
assert(SQLBindParameter_ptr && SQLGetStmtAttr_ptr && SQLSetDescField_ptr);

RETCODE rc = SQLBindParameter_ptr(
hStmt,
static_cast<SQLUSMALLINT>(paramIndex + 1), /* 1-based indexing */
Expand Down Expand Up @@ -2485,6 +2541,55 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
}
break;
}
case SQL_SS_TIMESTAMPOFFSET: {
DateTimeOffset dtoValue;
SQLLEN indicator;
ret = SQLGetData_ptr(
hStmt,
i, SQL_C_SS_TIMESTAMPOFFSET,
&dtoValue,
sizeof(dtoValue),
&indicator
);
if (SQL_SUCCEEDED(ret) && indicator != SQL_NULL_DATA) {
LOG("[Fetch] Retrieved DTO: {}-{}-{} {}:{}:{}, fraction(ns)={}, tz_hour={}, tz_minute={}",
dtoValue.year, dtoValue.month, dtoValue.day,
dtoValue.hour, dtoValue.minute, dtoValue.second,
dtoValue.fraction,
dtoValue.timezone_hour, dtoValue.timezone_minute
);

int totalMinutes = dtoValue.timezone_hour * 60 + dtoValue.timezone_minute;
// Validating offset
if (totalMinutes < -24 * 60 || totalMinutes > 24 * 60) {
std::ostringstream oss;
oss << "Invalid timezone offset from SQL_SS_TIMESTAMPOFFSET_STRUCT: "
<< totalMinutes << " minutes for column " << i;
ThrowStdException(oss.str());
}
// Convert fraction from ns to µs
int microseconds = dtoValue.fraction / 1000;
py::object datetime = py::module_::import("datetime");
py::object tzinfo = datetime.attr("timezone")(
datetime.attr("timedelta")(py::arg("minutes") = totalMinutes)
);
py::object py_dt = datetime.attr("datetime")(
dtoValue.year,
dtoValue.month,
dtoValue.day,
dtoValue.hour,
dtoValue.minute,
dtoValue.second,
microseconds,
tzinfo
);
row.append(py_dt);
} else {
LOG("Error fetching DATETIMEOFFSET for column {}, ret={}", i, ret);
row.append(py::none());
}
break;
}
case SQL_BINARY:
case SQL_VARBINARY:
case SQL_LONGVARBINARY: {
Expand Down
170 changes: 168 additions & 2 deletions tests/test_004_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"""

import pytest
from datetime import datetime, date, time
from datetime import datetime, date, time, timedelta, timezone
import time as time_module
import decimal
from contextlib import closing
Expand Down Expand Up @@ -6470,7 +6470,7 @@ def test_only_null_and_empty_binary(cursor, db_connection):
finally:
drop_table_if_exists(cursor, "#pytest_null_empty_binary")
db_connection.commit()

# ---------------------- VARCHAR(MAX) ----------------------

def test_varcharmax_short_fetch(cursor, db_connection):
Expand Down Expand Up @@ -7356,6 +7356,172 @@ def test_decimal_separator_calculations(cursor, db_connection):
cursor.execute("DROP TABLE IF EXISTS #pytest_decimal_calc_test")
db_connection.commit()

def test_datetimeoffset_read_write(cursor, db_connection):
"""Test reading and writing timezone-aware DATETIMEOFFSET values."""
try:
test_cases = [
# Valid timezone-aware datetimes
datetime(2023, 10, 26, 10, 30, 0, tzinfo=timezone(timedelta(hours=5, minutes=30))),
datetime(2023, 10, 27, 15, 45, 10, 123456, tzinfo=timezone(timedelta(hours=-8))),
datetime(2023, 10, 28, 20, 0, 5, 987654, tzinfo=timezone.utc)
]

cursor.execute("IF OBJECT_ID('tempdb..#dto_test', 'U') IS NOT NULL DROP TABLE #dto_test;")
cursor.execute("CREATE TABLE #dto_test (id INT PRIMARY KEY, dto_column DATETIMEOFFSET);")
db_connection.commit()

insert_stmt = "INSERT INTO #dto_test (id, dto_column) VALUES (?, ?);"
for i, dt in enumerate(test_cases):
cursor.execute(insert_stmt, i, dt)
db_connection.commit()

cursor.execute("SELECT id, dto_column FROM #dto_test ORDER BY id;")
for i, dt in enumerate(test_cases):
row = cursor.fetchone()
assert row is not None
fetched_id, fetched_dt = row
assert fetched_dt.tzinfo is not None
expected_utc = dt.astimezone(timezone.utc)
fetched_utc = fetched_dt.astimezone(timezone.utc)
# Ignore sub-microsecond differences
expected_utc = expected_utc.replace(microsecond=int(expected_utc.microsecond / 1000) * 1000)
fetched_utc = fetched_utc.replace(microsecond=int(fetched_utc.microsecond / 1000) * 1000)
assert fetched_utc == expected_utc
finally:
cursor.execute("DROP TABLE IF EXISTS #dto_test;")
db_connection.commit()

def test_datetimeoffset_max_min_offsets(cursor, db_connection):
"""
Test inserting and retrieving DATETIMEOFFSET with maximum and minimum allowed offsets (+14:00 and -14:00).
Uses fetchone() for retrieval.
"""
try:
cursor.execute("IF OBJECT_ID('tempdb..#dto_offsets', 'U') IS NOT NULL DROP TABLE #dto_offsets;")
cursor.execute("CREATE TABLE #dto_offsets (id INT PRIMARY KEY, dto_column DATETIMEOFFSET);")
db_connection.commit()

test_cases = [
(1, datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone(timedelta(hours=14)))), # max offset
(2, datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone(timedelta(hours=-14)))), # min offset
]

insert_stmt = "INSERT INTO #dto_offsets (id, dto_column) VALUES (?, ?);"
for row_id, dt in test_cases:
cursor.execute(insert_stmt, row_id, dt)
db_connection.commit()

cursor.execute("SELECT id, dto_column FROM #dto_offsets ORDER BY id;")

for expected_id, expected_dt in test_cases:
row = cursor.fetchone()
assert row is not None, f"No row fetched for id {expected_id}."
fetched_id, fetched_dt = row

assert fetched_id == expected_id, f"ID mismatch: expected {expected_id}, got {fetched_id}"
assert fetched_dt.tzinfo is not None, f"Fetched datetime object is naive for id {fetched_id}"

# Compare in UTC to avoid offset differences
expected_utc = expected_dt.astimezone(timezone.utc).replace(tzinfo=None)
fetched_utc = fetched_dt.astimezone(timezone.utc).replace(tzinfo=None)
assert fetched_utc == expected_utc, (
f"Value mismatch for id {expected_id}: expected UTC {expected_utc}, got {fetched_utc}"
)

finally:
cursor.execute("IF OBJECT_ID('tempdb..#dto_offsets', 'U') IS NOT NULL DROP TABLE #dto_offsets;")
db_connection.commit()

def test_datetimeoffset_invalid_offsets(cursor, db_connection):
"""Verify driver rejects offsets beyond ±14 hours."""
try:
cursor.execute("CREATE TABLE #dto_invalid (id INT PRIMARY KEY, dto_column DATETIMEOFFSET);")
db_connection.commit()

with pytest.raises(Exception):
cursor.execute("INSERT INTO #dto_invalid (id, dto_column) VALUES (?, ?);",
1, datetime(2025, 1, 1, 12, 0, tzinfo=timezone(timedelta(hours=15))))

with pytest.raises(Exception):
cursor.execute("INSERT INTO #dto_invalid (id, dto_column) VALUES (?, ?);",
2, datetime(2025, 1, 1, 12, 0, tzinfo=timezone(timedelta(hours=-15))))
finally:
cursor.execute("DROP TABLE IF EXISTS #dto_invalid;")
db_connection.commit()

def test_datetimeoffset_dst_transitions(cursor, db_connection):
"""
Test inserting and retrieving DATETIMEOFFSET values around DST transitions.
Ensures that driver handles DST correctly and does not crash.
"""
try:
cursor.execute("IF OBJECT_ID('tempdb..#dto_dst', 'U') IS NOT NULL DROP TABLE #dto_dst;")
cursor.execute("CREATE TABLE #dto_dst (id INT PRIMARY KEY, dto_column DATETIMEOFFSET);")
db_connection.commit()

# Example DST transition dates (replace with actual region offset if needed)
dst_test_cases = [
(1, datetime(2025, 3, 9, 1, 59, 59, tzinfo=timezone(timedelta(hours=-5)))), # Just before spring forward
(2, datetime(2025, 3, 9, 3, 0, 0, tzinfo=timezone(timedelta(hours=-4)))), # Just after spring forward
(3, datetime(2025, 11, 2, 1, 59, 59, tzinfo=timezone(timedelta(hours=-4)))), # Just before fall back
(4, datetime(2025, 11, 2, 1, 0, 0, tzinfo=timezone(timedelta(hours=-5)))), # Just after fall back
]

insert_stmt = "INSERT INTO #dto_dst (id, dto_column) VALUES (?, ?);"
for row_id, dt in dst_test_cases:
cursor.execute(insert_stmt, row_id, dt)
db_connection.commit()

cursor.execute("SELECT id, dto_column FROM #dto_dst ORDER BY id;")

for expected_id, expected_dt in dst_test_cases:
row = cursor.fetchone()
assert row is not None, f"No row fetched for id {expected_id}."
fetched_id, fetched_dt = row

assert fetched_id == expected_id, f"ID mismatch: expected {expected_id}, got {fetched_id}"
assert fetched_dt.tzinfo is not None, f"Fetched datetime object is naive for id {fetched_id}"

# Compare UTC time to avoid issues due to offsets changing in DST
expected_utc = expected_dt.astimezone(timezone.utc).replace(tzinfo=None)
fetched_utc = fetched_dt.astimezone(timezone.utc).replace(tzinfo=None)
assert fetched_utc == expected_utc, (
f"Value mismatch for id {expected_id}: expected UTC {expected_utc}, got {fetched_utc}"
)

finally:
cursor.execute("IF OBJECT_ID('tempdb..#dto_dst', 'U') IS NOT NULL DROP TABLE #dto_dst;")
db_connection.commit()

def test_datetimeoffset_leap_second(cursor, db_connection):
"""Ensure driver handles leap-second-like microsecond edge cases without crashing."""
try:
cursor.execute("CREATE TABLE #dto_leap (id INT PRIMARY KEY, dto_column DATETIMEOFFSET);")
db_connection.commit()

leap_second_sim = datetime(2023, 12, 31, 23, 59, 59, 999999, tzinfo=timezone.utc)
cursor.execute("INSERT INTO #dto_leap (id, dto_column) VALUES (?, ?);", 1, leap_second_sim)
db_connection.commit()

row = cursor.execute("SELECT dto_column FROM #dto_leap;").fetchone()
assert row[0].tzinfo is not None
finally:
cursor.execute("DROP TABLE IF EXISTS #dto_leap;")
db_connection.commit()

def test_datetimeoffset_malformed_input(cursor, db_connection):
"""Verify driver raises error for invalid datetimeoffset strings."""
try:
cursor.execute("CREATE TABLE #dto_malformed (id INT PRIMARY KEY, dto_column DATETIMEOFFSET);")
db_connection.commit()

with pytest.raises(Exception):
cursor.execute("INSERT INTO #dto_malformed (id, dto_column) VALUES (?, ?);",
1, "2023-13-45 25:61:00 +99:99") # invalid string
finally:
cursor.execute("DROP TABLE IF EXISTS #dto_malformed;")
db_connection.commit()

def test_lowercase_attribute(cursor, db_connection):
"""Test that the lowercase attribute properly converts column names to lowercase"""

Expand Down