
Commit c46ad06

feat: Support write api as loading option (#1617)
1 parent: 0895ef8

File tree: 10 files changed, +205 -78 lines changed


bigframes/core/local_data.py

Lines changed: 66 additions & 12 deletions

@@ -97,27 +97,46 @@ def from_pyarrow(self, table: pa.Table) -> ManagedArrowTable:
         mat.validate()
         return mat
 
-    def to_pyarrow_table(
+    def to_arrow(
         self,
         *,
         offsets_col: Optional[str] = None,
         geo_format: Literal["wkb", "wkt"] = "wkt",
         duration_type: Literal["int", "duration"] = "duration",
         json_type: Literal["string"] = "string",
-    ) -> pa.Table:
-        pa_table = self.data
-        if offsets_col is not None:
-            pa_table = pa_table.append_column(
-                offsets_col, pa.array(range(pa_table.num_rows), type=pa.int64())
-            )
+    ) -> tuple[pa.Schema, Iterable[pa.RecordBatch]]:
         if geo_format != "wkt":
             raise NotImplementedError(f"geo format {geo_format} not yet implemented")
-        if duration_type != "duration":
-            raise NotImplementedError(
-                f"duration as {duration_type} not yet implemented"
-            )
         assert json_type == "string"
-        return pa_table
+
+        batches = self.data.to_batches()
+        schema = self.data.schema
+        if duration_type == "int":
+            schema = _schema_durations_to_ints(schema)
+            batches = map(functools.partial(_cast_pa_batch, schema=schema), batches)
+
+        if offsets_col is not None:
+            return schema.append(pa.field(offsets_col, pa.int64())), _append_offsets(
+                batches, offsets_col
+            )
+        else:
+            return schema, batches
+
+    def to_pyarrow_table(
+        self,
+        *,
+        offsets_col: Optional[str] = None,
+        geo_format: Literal["wkb", "wkt"] = "wkt",
+        duration_type: Literal["int", "duration"] = "duration",
+        json_type: Literal["string"] = "string",
+    ) -> pa.Table:
+        schema, batches = self.to_arrow(
+            offsets_col=offsets_col,
+            geo_format=geo_format,
+            duration_type=duration_type,
+            json_type=json_type,
+        )
+        return pa.Table.from_batches(batches, schema)
 
     def to_parquet(
         self,

@@ -391,6 +410,41 @@ def _physical_type_replacements(dtype: pa.DataType) -> pa.DataType:
     return dtype
 
 
+def _append_offsets(
+    batches: Iterable[pa.RecordBatch], offsets_col_name: str
+) -> Iterable[pa.RecordBatch]:
+    offset = 0
+    for batch in batches:
+        offsets = pa.array(range(offset, offset + batch.num_rows), type=pa.int64())
+        batch_w_offsets = pa.record_batch(
+            [*batch.columns, offsets],
+            schema=batch.schema.append(pa.field(offsets_col_name, pa.int64())),
+        )
+        offset += batch.num_rows
+        yield batch_w_offsets
+
+
+@_recursive_map_types
+def _durations_to_ints(type: pa.DataType) -> pa.DataType:
+    if pa.types.is_duration(type):
+        return pa.int64()
+    return type
+
+
+def _schema_durations_to_ints(schema: pa.Schema) -> pa.Schema:
+    return pa.schema(
+        pa.field(field.name, _durations_to_ints(field.type)) for field in schema
+    )
+
+
+# TODO: Use RecordBatch.cast once min pyarrow>=16.0
+def _cast_pa_batch(batch: pa.RecordBatch, schema: pa.Schema) -> pa.RecordBatch:
+    return pa.record_batch(
+        [arr.cast(type) for arr, type in zip(batch.columns, schema.types)],
+        schema=schema,
+    )
+
+
 def _pairwise(iterable):
     do_yield = False
     a = None
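
For reference, this is the pattern the new _append_offsets helper implements: each RecordBatch gets an extra int64 column whose values keep counting across batches, so the reassembled table carries a stable 0..n-1 offsets column. A minimal self-contained sketch; the append_offsets name and the toy data are hypothetical, only pyarrow is required.

import pyarrow as pa

def append_offsets(batches, offsets_col_name):
    # Continue the offset count across batches so the combined table
    # carries a stable 0..n-1 ordering column.
    offset = 0
    for batch in batches:
        offsets = pa.array(range(offset, offset + batch.num_rows), type=pa.int64())
        yield pa.record_batch(
            [*batch.columns, offsets],
            schema=batch.schema.append(pa.field(offsets_col_name, pa.int64())),
        )
        offset += batch.num_rows

table = pa.table({"x": [10, 20, 30, 40]})
schema = table.schema.append(pa.field("offsets", pa.int64()))
batches = append_offsets(table.to_batches(max_chunksize=2), "offsets")
print(pa.Table.from_batches(list(batches), schema))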

bigframes/core/utils.py

Lines changed: 3 additions & 0 deletions

@@ -142,6 +142,9 @@ def label_to_identifier(label: typing.Hashable, strict: bool = False) -> str:
         identifier = re.sub(r"[^a-zA-Z0-9_]", "", identifier)
         if not identifier:
             identifier = "id"
+        elif identifier[0].isdigit():
+            # first character must be letter or underscore
+            identifier = "_" + identifier
     return identifier
 
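
The new branch guards against labels such as "2024 sales" producing an identifier that starts with a digit, which BigQuery column names do not allow. A standalone sketch of the strict-mode sanitization; the to_identifier wrapper here is hypothetical, the real logic lives in label_to_identifier.

import re

def to_identifier(label) -> str:
    # Replace spaces, then strip characters BigQuery identifiers don't allow.
    identifier = re.sub(r"[^a-zA-Z0-9_]", "", str(label).replace(" ", "_"))
    if not identifier:
        identifier = "id"
    elif identifier[0].isdigit():
        # First character must be a letter or underscore.
        identifier = "_" + identifier
    return identifier

print(to_identifier("2024 sales"))  # _2024_sales
print(to_identifier("!!!"))         # id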

bigframes/session/__init__.py

Lines changed: 8 additions & 1 deletion

@@ -255,6 +255,7 @@ def __init__(
             session=self,
             bqclient=self._clients_provider.bqclient,
             storage_manager=self._temp_storage_manager,
+            write_client=self._clients_provider.bqstoragewriteclient,
             default_index_type=self._default_index_type,
             scan_index_uniqueness=self._strictly_ordered,
             force_total_order=self._strictly_ordered,

@@ -731,7 +732,9 @@ def read_pandas(
                 workload is such that you exhaust the BigQuery load job
                 quota and your data cannot be embedded in SQL due to size or
                 data type limitations.
-
+            * "bigquery_write":
+                [Preview] Use the BigQuery Storage Write API. This feature
+                is in public preview.
         Returns:
             An equivalent bigframes.pandas.(DataFrame/Series/Index) object
 

@@ -805,6 +808,10 @@ def _read_pandas(
             return self._loader.read_pandas(
                 pandas_dataframe, method="stream", api_name=api_name
            )
+        elif write_engine == "bigquery_write":
+            return self._loader.read_pandas(
+                pandas_dataframe, method="write", api_name=api_name
+            )
         else:
             raise ValueError(f"Got unexpected write_engine '{write_engine}'")
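
With this wired through, the new engine can be selected from user code. A hedged usage sketch, assuming the public bigframes.pandas.read_pandas forwards write_engine to the session method the same way Session.read_pandas does; the sample data is illustrative.

import pandas as pd
import bigframes.pandas as bpd

pdf = pd.DataFrame({"id": [1, 2, 3], "name": ["a", "b", "c"]})

# Preview: load the local frame through the BigQuery Storage Write API
# rather than a load job or the legacy streaming API.
df = bpd.read_pandas(pdf, write_engine="bigquery_write")
print(df.shape)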

bigframes/session/clients.py

Lines changed: 31 additions & 0 deletions

@@ -134,6 +134,9 @@ def __init__(
         self._bqstoragereadclient: Optional[
             google.cloud.bigquery_storage_v1.BigQueryReadClient
         ] = None
+        self._bqstoragewriteclient: Optional[
+            google.cloud.bigquery_storage_v1.BigQueryWriteClient
+        ] = None
         self._cloudfunctionsclient: Optional[
             google.cloud.functions_v2.FunctionServiceClient
         ] = None

@@ -238,6 +241,34 @@ def bqstoragereadclient(self):
 
         return self._bqstoragereadclient
 
+    @property
+    def bqstoragewriteclient(self):
+        if not self._bqstoragewriteclient:
+            bqstorage_options = None
+            if "bqstoragewriteclient" in self._client_endpoints_override:
+                bqstorage_options = google.api_core.client_options.ClientOptions(
+                    api_endpoint=self._client_endpoints_override["bqstoragewriteclient"]
+                )
+            elif self._use_regional_endpoints:
+                bqstorage_options = google.api_core.client_options.ClientOptions(
+                    api_endpoint=_BIGQUERYSTORAGE_REGIONAL_ENDPOINT.format(
+                        location=self._location
+                    )
+                )
+
+            bqstorage_info = google.api_core.gapic_v1.client_info.ClientInfo(
+                user_agent=self._application_name
+            )
+            self._bqstoragewriteclient = (
+                google.cloud.bigquery_storage_v1.BigQueryWriteClient(
+                    client_info=bqstorage_info,
+                    client_options=bqstorage_options,
+                    credentials=self._credentials,
+                )
+            )
+
+        return self._bqstoragewriteclient
+
     @property
     def cloudfunctionsclient(self):
         if not self._cloudfunctionsclient:
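
Outside the session plumbing, the same client can be constructed directly. A minimal sketch, assuming Application Default Credentials are configured; the endpoint and user agent values are placeholders, whereas the session supplies its own endpoint overrides, regional endpoint, and credentials.

import google.api_core.client_options
import google.api_core.gapic_v1.client_info
from google.cloud import bigquery_storage_v1

client_options = google.api_core.client_options.ClientOptions(
    api_endpoint="bigquerystorage.googleapis.com"  # placeholder endpoint
)
client_info = google.api_core.gapic_v1.client_info.ClientInfo(
    user_agent="my-application/0.1"  # placeholder user agent
)

# Credentials fall back to Application Default Credentials when omitted.
write_client = bigquery_storage_v1.BigQueryWriteClient(
    client_info=client_info,
    client_options=client_options,
)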

bigframes/session/loader.py

Lines changed: 66 additions & 5 deletions

@@ -23,6 +23,7 @@
 import typing
 from typing import (
     Dict,
+    Generator,
     Hashable,
     IO,
     Iterable,

@@ -36,12 +37,13 @@
 import bigframes_vendored.constants as constants
 import bigframes_vendored.pandas.io.gbq as third_party_pandas_gbq
 import google.api_core.exceptions
+from google.cloud import bigquery_storage_v1
 import google.cloud.bigquery as bigquery
-import google.cloud.bigquery.table
+from google.cloud.bigquery_storage_v1 import types as bq_storage_types
 import pandas
 import pyarrow as pa
 
-from bigframes.core import local_data, utils
+from bigframes.core import guid, local_data, utils
 import bigframes.core as core
 import bigframes.core.blocks as blocks
 import bigframes.core.schema as schemata

@@ -142,13 +144,15 @@ def __init__(
         self,
         session: bigframes.session.Session,
         bqclient: bigquery.Client,
+        write_client: bigquery_storage_v1.BigQueryWriteClient,
         storage_manager: bigframes.session.temporary_storage.TemporaryStorageManager,
         default_index_type: bigframes.enums.DefaultIndexKind,
         scan_index_uniqueness: bool,
         force_total_order: bool,
         metrics: Optional[bigframes.session.metrics.ExecutionMetrics] = None,
     ):
         self._bqclient = bqclient
+        self._write_client = write_client
         self._storage_manager = storage_manager
         self._default_index_type = default_index_type
         self._scan_index_uniqueness = scan_index_uniqueness

@@ -165,7 +169,7 @@ def __init__(
     def read_pandas(
         self,
         pandas_dataframe: pandas.DataFrame,
-        method: Literal["load", "stream"],
+        method: Literal["load", "stream", "write"],
         api_name: str,
     ) -> dataframe.DataFrame:
         # TODO: Push this into from_pandas, along with index flag

@@ -183,6 +187,8 @@ def read_pandas(
             array_value = self.load_data(managed_data, api_name=api_name)
         elif method == "stream":
             array_value = self.stream_data(managed_data)
+        elif method == "write":
+            array_value = self.write_data(managed_data)
         else:
             raise ValueError(f"Unsupported read method {method}")
 

@@ -198,7 +204,7 @@ def load_data(
         self, data: local_data.ManagedArrowTable, api_name: Optional[str] = None
     ) -> core.ArrayValue:
         """Load managed data into bigquery"""
-        ordering_col = "bf_load_job_offsets"
+        ordering_col = guid.generate_guid("load_offsets_")
 
         # JSON support incomplete
         for item in data.schema.items:

@@ -244,7 +250,7 @@
 
     def stream_data(self, data: local_data.ManagedArrowTable) -> core.ArrayValue:
         """Load managed data into bigquery"""
-        ordering_col = "bf_stream_job_offsets"
+        ordering_col = guid.generate_guid("stream_offsets_")
         schema_w_offsets = data.schema.append(
             schemata.SchemaItem(ordering_col, bigframes.dtypes.INT_DTYPE)
         )

@@ -277,6 +283,61 @@ def stream_data(self, data: local_data.ManagedArrowTable) -> core.ArrayValue:
             n_rows=data.data.num_rows,
         ).drop_columns([ordering_col])
 
+    def write_data(self, data: local_data.ManagedArrowTable) -> core.ArrayValue:
+        """Load managed data into bigquery"""
+        ordering_col = guid.generate_guid("stream_offsets_")
+        schema_w_offsets = data.schema.append(
+            schemata.SchemaItem(ordering_col, bigframes.dtypes.INT_DTYPE)
+        )
+        bq_schema = schema_w_offsets.to_bigquery(_STREAM_JOB_TYPE_OVERRIDES)
+        bq_table_ref = self._storage_manager.create_temp_table(
+            bq_schema, [ordering_col]
+        )
+
+        requested_stream = bq_storage_types.stream.WriteStream()
+        requested_stream.type_ = bq_storage_types.stream.WriteStream.Type.COMMITTED  # type: ignore
+
+        stream_request = bq_storage_types.CreateWriteStreamRequest(
+            parent=bq_table_ref.to_bqstorage(), write_stream=requested_stream
+        )
+        stream = self._write_client.create_write_stream(request=stream_request)
+
+        def request_gen() -> Generator[bq_storage_types.AppendRowsRequest, None, None]:
+            schema, batches = data.to_arrow(
+                offsets_col=ordering_col, duration_type="int"
+            )
+            offset = 0
+            for batch in batches:
+                request = bq_storage_types.AppendRowsRequest(
+                    write_stream=stream.name, offset=offset
+                )
+                request.arrow_rows.writer_schema.serialized_schema = (
+                    schema.serialize().to_pybytes()
+                )
+                request.arrow_rows.rows.serialized_record_batch = (
+                    batch.serialize().to_pybytes()
+                )
+                offset += batch.num_rows
+                yield request
+
+        for response in self._write_client.append_rows(requests=request_gen()):
+            if response.row_errors:
+                raise ValueError(
+                    f"Problem loading at least one row from DataFrame: {response.row_errors}. {constants.FEEDBACK_LINK}"
+                )
+        # This step isn't strictly necessary in COMMITTED mode, but avoids max active stream limits
+        response = self._write_client.finalize_write_stream(name=stream.name)
+        assert response.row_count == data.data.num_rows
+
+        destination_table = self._bqclient.get_table(bq_table_ref)
+        return core.ArrayValue.from_table(
+            table=destination_table,
+            schema=schema_w_offsets,
+            session=self._session,
+            offsets_col=ordering_col,
+            n_rows=data.data.num_rows,
+        ).drop_columns([ordering_col])
+
     def _start_generic_job(self, job: formatting_helpers.GenericJob):
         if bigframes.options.display.progress_bar is not None:
             formatting_helpers.wait_for_job(
formatting_helpers.wait_for_job(

setup.py

Lines changed: 2 additions & 0 deletions

@@ -42,6 +42,8 @@
     "google-cloud-bigtable >=2.24.0",
     "google-cloud-pubsub >=2.21.4",
     "google-cloud-bigquery[bqstorage,pandas] >=3.31.0",
+    # 2.30 needed for arrow support.
+    "google-cloud-bigquery-storage >= 2.30.0, < 3.0.0",
     "google-cloud-functions >=1.12.0",
     "google-cloud-bigquery-connection >=1.12.0",
     "google-cloud-iam >=2.12.1",

tests/system/small/test_dataframe.py

Lines changed: 2 additions & 1 deletion

@@ -83,6 +83,7 @@ def test_df_construct_pandas_default(scalars_dfs):
         ("bigquery_inline"),
         ("bigquery_load"),
         ("bigquery_streaming"),
+        ("bigquery_write"),
     ],
 )
 def test_read_pandas_all_nice_types(

@@ -1772,7 +1773,7 @@ def test_len(scalars_dfs):
 )
 @pytest.mark.parametrize(
     "write_engine",
-    ["bigquery_load", "bigquery_streaming"],
+    ["bigquery_load", "bigquery_streaming", "bigquery_write"],
 )
 def test_df_len_local(session, n_rows, write_engine):
     assert (
