Skip to content
This repository was archived by the owner on Jul 31, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,22 @@ df.tensorflow.to_tfr(output_dir='gs://my/bucket')

##### Running on Cloud Dataflow

Google Cloud Platform Dataflow workers need to be supplied with the tfrecorder
package that you would like to run remotely. To do so first download or build
the package (a python wheel file) and then specify the path to the file when
tfrecorder is called.

Step 1: Download or create the wheel file.

To download the wheel from pip:
`pip download tfrecorder --no-deps`

To build from source/git:
`python setup.py bdist_wheel`

Step 2:
Specify the project, region, and path to the tfrecorder wheel for remote execution.

```python
import pandas as pd
import tfrecorder
Expand All @@ -44,9 +60,11 @@ df.tensorflow.to_tfr(
output_dir='gs://my/bucket',
runner='DataFlowRunner',
project='my-project',
region='us-central1')
    region='us-central1',
tfrecorder_wheel='/path/to/my/tfrecorder.whl')
```


#### From CSV

Using Python interpreter:
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ pylint >= 2.5.3
fire >= 0.3.1
jupyter >= 1.0.0
tensorflow >= 2.2.0
pyarrow < 0.17
frozendict >= 1.2
pyarrow >= 0.17
frozendict >= 1.2
5 changes: 5 additions & 0 deletions tfrecorder/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def to_tfr(
runner: str = 'DirectRunner',
project: Optional[str] = None,
region: Optional[str] = None,
tfrecorder_wheel: Optional[str] = None,
dataflow_options: Union[Dict[str, Any], None] = None,
job_label: str = 'to-tfr',
compression: Optional[str] = 'gzip',
Expand All @@ -66,6 +67,9 @@ def to_tfr(
runner: Beam runner. Can be DirectRunner or DataFlowRunner.
project: GCP project name (Required if DataFlowRunner).
region: GCP region name (Required if DataFlowRunner).
tfrecorder_wheel: Path to the tfrecorder wheel DataFlow will run.
(create with 'python setup.py sdist' or
'pip download tfrecorder --no-deps')
dataflow_options: Optional dictionary containing DataFlow options.
job_label: User supplied description for the beam job name.
compression: Can be 'gzip' or None for no compression.
Expand All @@ -84,6 +88,7 @@ def to_tfr(
runner=runner,
project=project,
region=region,
tfrecorder_wheel=tfrecorder_wheel,
dataflow_options=dataflow_options,
job_label=job_label,
compression=compression,
Expand Down
4 changes: 2 additions & 2 deletions tfrecorder/beam_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def load(image_uri):
try:
with tf.io.gfile.GFile(image_uri, 'rb') as f:
return Image.open(f)
except tf.python.framework.errors_impl.NotFoundError:
raise OSError('File {} was not found.'.format(image_uri))
except tf.python.framework.errors_impl.NotFoundError as e:
raise OSError('File {} was not found.'.format(image_uri)) from e


# pylint: disable=abstract-method
Expand Down
16 changes: 5 additions & 11 deletions tfrecorder/beam_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,6 @@
from tfrecorder import constants


def _get_setup_py_filepath() -> str:
"""Returns the file path to the setup.py file.

The location of the setup.py file is needed to run Dataflow jobs.
"""

return os.path.join(
os.path.dirname(os.path.abspath(__file__)), '..', 'setup.py')


def _get_job_name(job_label: str = None) -> str:
"""Returns Beam runner job name.

Expand Down Expand Up @@ -76,6 +66,7 @@ def _get_pipeline_options(
job_dir: str,
project: str,
region: str,
tfrecorder_wheel: str,
dataflow_options: Union[Dict[str, Any], None]
) -> beam.pipeline.PipelineOptions:
"""Returns Beam pipeline options."""
Expand All @@ -95,7 +86,7 @@ def _get_pipeline_options(
if region:
options_dict['region'] = region
if runner == 'DataflowRunner':
options_dict['setup_file'] = _get_setup_py_filepath()
options_dict['extra_packages'] = tfrecorder_wheel
if dataflow_options:
options_dict.update(dataflow_options)

Expand Down Expand Up @@ -199,6 +190,7 @@ def build_pipeline(
output_dir: str,
compression: str,
num_shards: int,
tfrecorder_wheel: str,
dataflow_options: dict,
integer_label: bool) -> beam.Pipeline:
"""Runs TFRecorder Beam Pipeline.
Expand All @@ -212,6 +204,7 @@ def build_pipeline(
output_dir: GCS or Local Path for output.
compression: gzip or None.
num_shards: Number of shards.
tfrecorder_wheel: Path to TFRecorder wheel for DataFlow
dataflow_options: Dataflow Runner Options (optional)
integer_label: Flags if label is already an integer.

Expand All @@ -229,6 +222,7 @@ def build_pipeline(
job_dir,
project,
region,
tfrecorder_wheel,
dataflow_options)

#with beam.Pipeline(runner, options=options) as p:
Expand Down
7 changes: 0 additions & 7 deletions tfrecorder/beam_pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

"""Tests for beam_pipeline."""

import os
import unittest
from unittest import mock

Expand Down Expand Up @@ -78,12 +77,6 @@ def test_partition_fn(self):
index, i,
'{} should be index {} but was index {}'.format(part, i, index))

def test_get_setup_py_filepath(self):
"""Tests `_get_setup_py_filepath`."""
filepath = beam_pipeline._get_setup_py_filepath()
self.assertTrue(os.path.isfile(filepath))
self.assertTrue(os.path.isabs(filepath))


if __name__ == '__main__':
unittest.main()
11 changes: 9 additions & 2 deletions tfrecorder/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def _validate_runner(
df: pd.DataFrame,
runner: str,
project: str,
region: str):
region: str,
tfrecorder_wheel: str):
"""Validates an appropriate beam runner is chosen."""
if runner not in ['DataflowRunner', 'DirectRunner']:
raise AttributeError('Runner {} is not supported.'.format(runner))
Expand All @@ -70,6 +71,9 @@ def _validate_runner(
'DataflowRunner requires valid `project` and `region` to be specified.'
'The `project` is {} and `region` is {}'.format(project, region))

if (runner == 'DataflowRunner') & (not tfrecorder_wheel):
raise AttributeError(
'DataflowRunner requires a tfrecorder whl file for remote execution.')
# def read_image_directory(dirpath) -> pd.DataFrame:
# """Reads image data from a directory into a Pandas DataFrame."""
#
Expand Down Expand Up @@ -164,6 +168,7 @@ def create_tfrecords(
runner: str = 'DirectRunner',
project: Optional[str] = None,
region: Optional[str] = None,
tfrecorder_wheel: Optional[str] = None,
dataflow_options: Optional[Dict[str, Any]] = None,
job_label: str = 'create-tfrecords',
compression: Optional[str] = 'gzip',
Expand All @@ -190,6 +195,7 @@ def create_tfrecords(
runner: Beam runner. Can be 'DirectRunner' or 'DataFlowRunner'
project: GCP project name (Required if DataflowRunner)
region: GCP region name (Required if DataflowRunner)
tfrecorder_wheel: Required for GCP Runs, path to the tfrecorder whl.
dataflow_options: Options dict for DataflowRunner
job_label: User supplied description for the Beam job name.
compression: Can be 'gzip' or None for no compression.
Expand All @@ -206,7 +212,7 @@ def create_tfrecords(
df = to_dataframe(input_data, header, names)

_validate_data(df)
_validate_runner(df, runner, project, region)
_validate_runner(df, runner, project, region, tfrecorder_wheel)

logfile = os.path.join('/tmp', constants.LOGFILE)
_configure_logging(logfile)
Expand All @@ -222,6 +228,7 @@ def create_tfrecords(
output_dir=output_dir,
compression=compression,
num_shards=num_shards,
tfrecorder_wheel=tfrecorder_wheel,
dataflow_options=dataflow_options,
integer_label=integer_label)

Expand Down
35 changes: 29 additions & 6 deletions tfrecorder/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def setUp(self):
self.test_df = test_utils.get_test_df()
self.test_region = 'us-central1'
self.test_project = 'foo'
self.test_wheel = '/my/path/wheel.whl'

@mock.patch('tfrecorder.client.beam_pipeline')
def test_create_tfrecords_direct_runner(self, mock_beam):
Expand Down Expand Up @@ -71,7 +72,8 @@ def test_create_tfrecords_dataflow_runner(self, mock_beam):
runner='DataflowRunner',
output_dir=outdir,
region=self.test_region,
project=self.test_project)
project=self.test_project,
tfrecorder_wheel=self.test_wheel)
self.assertEqual(r, expected)


Expand All @@ -84,6 +86,7 @@ def setUp(self):
self.test_df = test_utils.get_test_df()
self.test_region = 'us-central1'
self.test_project = 'foo'
self.test_wheel = '/my/path/wheel.whl'

def test_valid_dataframe(self):
"""Tests valid DataFrame input."""
Expand Down Expand Up @@ -126,7 +129,8 @@ def test_valid_runner(self):
self.test_df,
runner='DirectRunner',
project=self.test_project,
region=self.test_region))
region=self.test_region,
tfrecorder_wheel=None))

def test_invalid_runner(self):
"""Tests invalid runner."""
Expand All @@ -135,7 +139,8 @@ def test_invalid_runner(self):
self.test_df,
runner='FooRunner',
project=self.test_project,
region=self.test_region)
region=self.test_region,
tfrecorder_wheel=None)

def test_local_path_with_dataflow_runner(self):
"""Tests DataflowRunner conflict with local path."""
Expand All @@ -144,7 +149,8 @@ def test_local_path_with_dataflow_runner(self):
self.df_test,
runner='DataflowRunner',
project=self.test_project,
region=self.test_region)
region=self.test_region,
tfrecorder_wheel=self.test_wheel)

def test_gcs_path_with_dataflow_runner(self):
"""Tests DataflowRunner with GCS path."""
Expand All @@ -155,7 +161,8 @@ def test_gcs_path_with_dataflow_runner(self):
df2,
runner='DataflowRunner',
project=self.test_project,
region=self.test_region))
region=self.test_region,
tfrecorder_wheel=self.test_wheel))

def test_gcs_path_with_dataflow_runner_missing_param(self):
"""Tests DataflowRunner with missing required parameter."""
Expand All @@ -168,11 +175,27 @@ def test_gcs_path_with_dataflow_runner_missing_param(self):
df2,
runner='DataflowRunner',
project=p,
region=r)
region=r,
tfrecorder_wheel=self.test_wheel)
self.assertTrue('DataflowRunner requires valid `project` and `region`'
in repr(context.exception))


def test_gcs_path_with_dataflow_runner_missing_wheel(self):
  """Tests DataflowRunner with missing required whl path."""
  gcs_df = self.test_df.copy()
  # Point image URIs at GCS so only the missing wheel triggers the error.
  gcs_df[constants.IMAGE_URI_KEY] = 'gs://' + gcs_df[constants.IMAGE_URI_KEY]
  with self.assertRaises(AttributeError) as context:
    client._validate_runner(
        gcs_df,
        runner='DataflowRunner',
        project=self.test_project,
        region=self.test_region,
        tfrecorder_wheel=None)
  self.assertIn(
      'requires a tfrecorder whl file for remote execution.',
      repr(context.exception))


def _make_csv_tempfile(data: List[List[str]]) -> tempfile.NamedTemporaryFile:
"""Returns `NamedTemporaryFile` representing an image CSV."""

Expand Down
2 changes: 1 addition & 1 deletion tfrecorder/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@ def copy_logfile_to_gcs(logfile: str, output_dir: str):
gcs_logfile.write(log)
except FileNotFoundError as e:
raise FileNotFoundError("Unable to copy log file {} to gcs.".format(
e.filename))
e.filename)) from e