Skip to content

Commit 66db70e

Browse files
feat: Add thread-safe caching to InferredSchemaLoader
- Add internal memoization with threading.Lock to prevent duplicate schema inference
- Cache schema after first call to avoid re-reading records on subsequent calls
- This addresses the issue where get_json_schema() is called during read operations (in DeclarativePartition.read()), not just during discover
- Add unit test to verify caching behavior (schema inference happens only once)

Fixes performance issue identified by @maxi297 where InferredSchemaLoader would read up to record_sample_size records for every partition/slice during a sync.

Co-Authored-By: AJ Steers <[email protected]>
1 parent e118a20 commit 66db70e

File tree

2 files changed: +65 −24 lines changed

2 files changed

+65
-24
lines changed

airbyte_cdk/sources/declarative/schema/inferred_schema_loader.py

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# Copyright (c) 2025 Airbyte, Inc., all rights reserved.
33
#
44

5+
import threading
56
from collections.abc import Mapping as ABCMapping
67
from collections.abc import Sequence
78
from dataclasses import InitVar, dataclass
@@ -69,44 +70,54 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
6970
raise ValueError(
7071
"stream_name must be provided either directly or via the 'name' parameter"
7172
)
73+
self._cached_schema: Mapping[str, Any] | None = None
74+
self._lock = threading.Lock()
7275

7376
def get_json_schema(self) -> Mapping[str, Any]:
7477
"""
7578
Infers and returns a JSON schema by reading a sample of records from the stream.
7679
7780
This method reads up to `record_sample_size` records from the stream and uses
78-
the SchemaInferrer to generate a JSON schema. If no records are available,
79-
it returns an empty schema.
81+
the SchemaInferrer to generate a JSON schema. The schema is cached after the first
82+
call to avoid re-reading records on subsequent calls (e.g., during partition reads).
8083
8184
Returns:
8285
A mapping representing the inferred JSON schema for the stream
8386
"""
84-
schema_inferrer = SchemaInferrer()
87+
if self._cached_schema is not None:
88+
return self._cached_schema
8589

86-
record_count = 0
87-
for stream_slice in self.retriever.stream_slices():
88-
for record in self.retriever.read_records(records_schema={}, stream_slice=stream_slice):
89-
if record_count >= self.record_sample_size:
90-
break
90+
with self._lock:
91+
if self._cached_schema is not None:
92+
return self._cached_schema
9193

92-
# Convert all Mapping-like and Sequence-like objects to plain Python types
93-
# This is necessary because genson doesn't handle custom implementations properly
94-
record = _to_builtin_types(record)
94+
schema_inferrer = SchemaInferrer()
9595

96-
airbyte_record = AirbyteRecordMessage(
97-
stream=self.stream_name,
98-
data=record, # type: ignore[arg-type]
99-
emitted_at=0,
100-
)
96+
record_count = 0
97+
for stream_slice in self.retriever.stream_slices():
98+
for record in self.retriever.read_records(records_schema={}, stream_slice=stream_slice):
99+
if record_count >= self.record_sample_size:
100+
break
101101

102-
schema_inferrer.accumulate(airbyte_record)
103-
record_count += 1
102+
# Convert all Mapping-like and Sequence-like objects to plain Python types
103+
# This is necessary because genson doesn't handle custom implementations properly
104+
record = _to_builtin_types(record)
104105

105-
if record_count >= self.record_sample_size:
106-
break
106+
airbyte_record = AirbyteRecordMessage(
107+
stream=self.stream_name,
108+
data=record, # type: ignore[arg-type]
109+
emitted_at=0,
110+
)
107111

108-
inferred_schema: Mapping[str, Any] | None = schema_inferrer.get_stream_schema(
109-
self.stream_name
110-
)
112+
schema_inferrer.accumulate(airbyte_record)
113+
record_count += 1
114+
115+
if record_count >= self.record_sample_size:
116+
break
117+
118+
inferred_schema: Mapping[str, Any] | None = schema_inferrer.get_stream_schema(
119+
self.stream_name
120+
)
111121

112-
return inferred_schema if inferred_schema else {}
122+
self._cached_schema = inferred_schema if inferred_schema else {}
123+
return self._cached_schema

unit_tests/sources/declarative/schema/test_inferred_schema_loader.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,33 @@ def test_inferred_schema_loader_with_arrays():
184184
assert "properties" in schema
185185
assert "tags" in schema["properties"]
186186
assert "array" in schema["properties"]["tags"]["type"]
187+
188+
189+
def test_inferred_schema_loader_caches_schema():
190+
"""Test that InferredSchemaLoader caches the schema and doesn't re-read records on subsequent calls."""
191+
retriever = MagicMock()
192+
retriever.stream_slices.return_value = iter([None])
193+
retriever.read_records.return_value = iter(
194+
[
195+
{"id": 1, "name": "Alice"},
196+
{"id": 2, "name": "Bob"},
197+
]
198+
)
199+
200+
config = MagicMock()
201+
parameters = {"name": "users"}
202+
loader = InferredSchemaLoader(
203+
retriever=retriever,
204+
config=config,
205+
parameters=parameters,
206+
record_sample_size=2,
207+
stream_name="users",
208+
)
209+
210+
schema1 = loader.get_json_schema()
211+
schema2 = loader.get_json_schema()
212+
schema3 = loader.get_json_schema()
213+
214+
assert schema1 == schema2 == schema3
215+
assert retriever.stream_slices.call_count == 1
216+
assert retriever.read_records.call_count == 1

0 commit comments

Comments (0)