Skip to content

Commit c409c16

Browse files
fix: address DuckDB store critical issues
- Fix import error: use public duckdb module instead of private _duckdb
- Fix naive timestamp handling: use replace(tzinfo=UTC) for naive datetimes
- Add SQL injection protection: validate table names with regex
- Change JSON storage to use dict directly instead of json.dumps()

Fixes all critical issues identified in code review:
- Resolves ModuleNotFoundError that was causing all tests to fail
- Prevents ValueError when DuckDB returns naive timestamps
- Protects against SQL injection via malicious table names
- Improves JSON queryability in DuckDB

Co-authored-by: William Easton <[email protected]>
1 parent 8acf655 commit c409c16

File tree

4 files changed

+56
-22
lines changed

4 files changed

+56
-22
lines changed

key-value/key-value-aio/src/key_value/aio/stores/duckdb/store.py

Lines changed: 27 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import json
2+
import re
23
from datetime import datetime, timezone
34
from pathlib import Path
45
from typing import Any, cast, overload
@@ -52,9 +53,9 @@ def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]:
5253
if collection_name is not None:
5354
json_document["collection"] = collection_name
5455

55-
# Store as JSON string for DuckDB's JSON column
56-
# DuckDB will parse it and store it as native JSON
57-
data["value_dict"] = json.dumps(json_document)
56+
# Store the dict directly for DuckDB's JSON column
57+
# DuckDB's Python API accepts dict objects and stores them as queryable JSON
58+
data["value_dict"] = json_document
5859

5960
return data
6061

@@ -103,12 +104,18 @@ def _parse_json_column(self, value_dict: Any) -> dict[str, Any]:
103104
def _convert_timestamps_to_utc(self, data: dict[str, Any]) -> None:
104105
"""Convert naive timestamps to UTC timezone-aware timestamps."""
105106
created_at = data.get("created_at")
106-
if created_at is not None and isinstance(created_at, datetime) and created_at.tzinfo is None:
107-
data["created_at"] = created_at.astimezone(tz=timezone.utc)
107+
if created_at is not None and isinstance(created_at, datetime):
108+
if created_at.tzinfo is None:
109+
data["created_at"] = created_at.replace(tzinfo=timezone.utc)
110+
else:
111+
data["created_at"] = created_at.astimezone(tz=timezone.utc)
108112

109113
expires_at = data.get("expires_at")
110-
if expires_at is not None and isinstance(expires_at, datetime) and expires_at.tzinfo is None:
111-
data["expires_at"] = expires_at.astimezone(tz=timezone.utc)
114+
if expires_at is not None and isinstance(expires_at, datetime):
115+
if expires_at.tzinfo is None:
116+
data["expires_at"] = expires_at.replace(tzinfo=timezone.utc)
117+
else:
118+
data["expires_at"] = expires_at.astimezone(tz=timezone.utc)
112119

113120

114121
class DuckDBStore(BaseContextManagerStore, BaseStore):
@@ -214,6 +221,11 @@ def __init__(
214221

215222
self._is_closed = False
216223
self._adapter = DuckDBSerializationAdapter()
224+
225+
# Validate table name to prevent SQL injection
226+
if not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", table_name):
227+
msg = "Table name must start with a letter or underscore and contain only letters, digits, or underscores"
228+
raise ValueError(msg)
217229
self._table_name = table_name
218230
self._stable_api = False
219231

@@ -345,9 +357,15 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
345357
}
346358

347359
if created_at is not None and isinstance(created_at, datetime):
348-
document["created_at"] = created_at.astimezone(tz=timezone.utc)
360+
if created_at.tzinfo is None:
361+
document["created_at"] = created_at.replace(tzinfo=timezone.utc)
362+
else:
363+
document["created_at"] = created_at.astimezone(tz=timezone.utc)
349364
if expires_at is not None and isinstance(expires_at, datetime):
350-
document["expires_at"] = expires_at.astimezone(tz=timezone.utc)
365+
if expires_at.tzinfo is None:
366+
document["expires_at"] = expires_at.replace(tzinfo=timezone.utc)
367+
else:
368+
document["expires_at"] = expires_at.astimezone(tz=timezone.utc)
351369

352370
try:
353371
return self._adapter.load_dict(data=document)

key-value/key-value-aio/tests/stores/duckdb/test_duckdb.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
from tempfile import TemporaryDirectory
44

55
import pytest
6-
from _duckdb import DuckDBPyConnection
7-
from duckdb import CatalogException
6+
from duckdb import CatalogException, DuckDBPyConnection
87
from inline_snapshot import snapshot
98
from typing_extensions import override
109

key-value/key-value-sync/src/key_value/sync/code_gen/stores/duckdb/store.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# from the original file 'store.py'
33
# DO NOT CHANGE! Change the original file instead.
44
import json
5+
import re
56
from datetime import datetime, timezone
67
from pathlib import Path
78
from typing import Any, cast, overload
@@ -55,9 +56,9 @@ def prepare_dump(self, data: dict[str, Any]) -> dict[str, Any]:
5556
if collection_name is not None:
5657
json_document["collection"] = collection_name
5758

58-
# Store as JSON string for DuckDB's JSON column
59-
# DuckDB will parse it and store it as native JSON
60-
data["value_dict"] = json.dumps(json_document)
59+
# Store the dict directly for DuckDB's JSON column
60+
# DuckDB's Python API accepts dict objects and stores them as queryable JSON
61+
data["value_dict"] = json_document
6162

6263
return data
6364

@@ -106,12 +107,18 @@ def _parse_json_column(self, value_dict: Any) -> dict[str, Any]:
106107
def _convert_timestamps_to_utc(self, data: dict[str, Any]) -> None:
107108
"""Convert naive timestamps to UTC timezone-aware timestamps."""
108109
created_at = data.get("created_at")
109-
if created_at is not None and isinstance(created_at, datetime) and (created_at.tzinfo is None):
110-
data["created_at"] = created_at.astimezone(tz=timezone.utc)
110+
if created_at is not None and isinstance(created_at, datetime):
111+
if created_at.tzinfo is None:
112+
data["created_at"] = created_at.replace(tzinfo=timezone.utc)
113+
else:
114+
data["created_at"] = created_at.astimezone(tz=timezone.utc)
111115

112116
expires_at = data.get("expires_at")
113-
if expires_at is not None and isinstance(expires_at, datetime) and (expires_at.tzinfo is None):
114-
data["expires_at"] = expires_at.astimezone(tz=timezone.utc)
117+
if expires_at is not None and isinstance(expires_at, datetime):
118+
if expires_at.tzinfo is None:
119+
data["expires_at"] = expires_at.replace(tzinfo=timezone.utc)
120+
else:
121+
data["expires_at"] = expires_at.astimezone(tz=timezone.utc)
115122

116123

117124
class DuckDBStore(BaseContextManagerStore, BaseStore):
@@ -216,6 +223,11 @@ def __init__(
216223

217224
self._is_closed = False
218225
self._adapter = DuckDBSerializationAdapter()
226+
227+
# Validate table name to prevent SQL injection
228+
if not re.fullmatch("[A-Za-z_][A-Za-z0-9_]*", table_name):
229+
msg = "Table name must start with a letter or underscore and contain only letters, digits, or underscores"
230+
raise ValueError(msg)
219231
self._table_name = table_name
220232
self._stable_api = False
221233

@@ -315,9 +327,15 @@ def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry | Non
315327
document: dict[str, Any] = {"value_dict": value_dict}
316328

317329
if created_at is not None and isinstance(created_at, datetime):
318-
document["created_at"] = created_at.astimezone(tz=timezone.utc)
330+
if created_at.tzinfo is None:
331+
document["created_at"] = created_at.replace(tzinfo=timezone.utc)
332+
else:
333+
document["created_at"] = created_at.astimezone(tz=timezone.utc)
319334
if expires_at is not None and isinstance(expires_at, datetime):
320-
document["expires_at"] = expires_at.astimezone(tz=timezone.utc)
335+
if expires_at.tzinfo is None:
336+
document["expires_at"] = expires_at.replace(tzinfo=timezone.utc)
337+
else:
338+
document["expires_at"] = expires_at.astimezone(tz=timezone.utc)
321339

322340
try:
323341
return self._adapter.load_dict(data=document)

key-value/key-value-sync/tests/code_gen/stores/duckdb/test_duckdb.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
from tempfile import TemporaryDirectory
77

88
import pytest
9-
from _duckdb import DuckDBPyConnection
10-
from duckdb import CatalogException
9+
from duckdb import CatalogException, DuckDBPyConnection
1110
from inline_snapshot import snapshot
1211
from typing_extensions import override
1312

0 commit comments

Comments
 (0)