Skip to content
This repository was archived by the owner on Jul 31, 2023. It is now read-only.

Commit f03063c

Browse files
authored
Merge pull request #13 from google/remove-uri-path
Change image_uri to image_name (file basename) in TFRecord output.
2 parents 33d6c96 + c8e8467 commit f03063c

File tree

13 files changed

+23 -18 lines changed

13 files changed

+23 -18 lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
all: init pylint coverage test
1+
all: init test pylint
22

33
init:
44
pip install -r requirements.txt

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ pylint >= 2.5.3
99
fire >= 0.3.1
1010
jupyter >= 1.0.0
1111
tensorflow >= 2.2.0
12-
pyarrow < 0.17
12+
pyarrow < 0.17
13+
frozendict >= 1.2

tfrecorder/beam_image.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import base64
1919
import logging
20+
import os
2021
from typing import Any, Dict, Generator, Tuple
2122

2223
import apache_beam as beam
@@ -73,10 +74,10 @@ def load(image_uri):
7374
class ExtractImagesDoFn(beam.DoFn):
7475
"""Adds image to PCollection."""
7576

76-
def __init__(self, image_key: str):
77+
def __init__(self, image_uri_key: str):
7778
"""Constructor."""
7879
super().__init__()
79-
self.image_key = image_key
80+
self.image_uri_key = image_uri_key
8081
self.image_good_counter = Metrics.counter(self.__class__, 'image_good')
8182
self.image_bad_counter = Metrics.counter(self.__class__, 'image_bad')
8283

@@ -95,9 +96,9 @@ def process(
9596
d = {}
9697

9798
try:
98-
image_uri = element[self.image_key]
99+
image_uri = element.pop(self.image_uri_key)
99100
image = load(image_uri)
100-
# TODO(cezequiel): Remove path from image_uri -> image_name
101+
element['image_name'] = os.path.split(image_uri)[-1]
101102
d['image'] = encode(image)
102103
d['image_width'], d['image_height'] = image.size
103104
d['image_channels'] = mode_to_channel(image.mode)

tfrecorder/beam_image_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,6 @@ def _equal(actual):
112112
.format(actual_keys, expected_keys_))
113113
return _equal
114114

115-
expected_keys = ['image_uri', 'label', 'split', 'image',
115+
expected_keys = ['image_name', 'label', 'split', 'image',
116116
'image_height', 'image_width', 'image_channels']
117117
util.assert_that(data, key_matcher(expected_keys))

tfrecorder/check.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ def check_tfrecords(
100100
writer.writerow(row)
101101

102102
# Save image data to a file
103-
if 'image_uri' in r:
104-
_, image_filename = os.path.split(_stringify(r['image_uri']))
103+
if 'image_name' in r:
104+
_, image_filename = os.path.split(_stringify(r['image_name']))
105105
image_path = os.path.join(data_dir, image_filename)
106106
_save_image_from_record(r, image_path)
107107

tfrecorder/check_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ def setUp(self):
8686

8787
data = test_utils.get_test_data()
8888
num_records = len(data[constants.IMAGE_URI_KEY])
89+
image_uris = data.pop(constants.IMAGE_URI_KEY)
90+
data['image_name'] = [os.path.split(uri)[-1] for uri in image_uris]
8991
data.update({
9092
'image': [beam_image.encode(image_fn())
9193
for _ in range(num_records)],
@@ -123,8 +125,7 @@ def test_valid_records(self, mock_fn):
123125
# Check output images
124126
actual_image_files = [
125127
f for f in os.listdir(actual_dir) if f.endswith('.jpg')]
126-
expected_image_files = [
127-
os.path.split(f)[-1] for f in self.data['image_uri']]
128+
expected_image_files = self.data['image_name']
128129
self.assertCountEqual(actual_image_files, expected_image_files)
129130

130131

tfrecorder/constants.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616

1717
"""Global constants."""
1818

19-
import collections
2019
import logging
2120

21+
import frozendict
2222
import tensorflow as tf
2323
from tensorflow_transform.tf_metadata import dataset_metadata
2424
from tensorflow_transform.tf_metadata import schema_utils
@@ -31,14 +31,16 @@
3131
LABEL_KEY = 'label'
3232
IMAGE_CSV_COLUMNS = [SPLIT_KEY, IMAGE_URI_KEY, LABEL_KEY]
3333

34-
IMAGE_CSV_FEATURE_SPEC = {
34+
IMAGE_CSV_FEATURE_SPEC = frozendict.FrozenOrderedDict({
3535
SPLIT_KEY: tf.io.FixedLenFeature([], tf.string),
3636
IMAGE_URI_KEY: tf.io.FixedLenFeature([], tf.string),
3737
LABEL_KEY: tf.io.FixedLenFeature([], tf.string),
38-
}
38+
})
3939

40-
RAW_FEATURE_SPEC = collections.OrderedDict(IMAGE_CSV_FEATURE_SPEC)
41-
RAW_FEATURE_SPEC.update({
40+
RAW_FEATURE_SPEC = frozendict.FrozenOrderedDict({
41+
SPLIT_KEY: tf.io.FixedLenFeature([], tf.string),
42+
LABEL_KEY: tf.io.FixedLenFeature([], tf.string),
43+
'image_name': tf.io.FixedLenFeature([], tf.string),
4244
'image': tf.io.FixedLenFeature([], tf.string),
4345
'image_height': tf.io.FixedLenFeature([], tf.int64),
4446
'image_width': tf.io.FixedLenFeature([], tf.int64),

tfrecorder/test_data/sample_tfrecords/schema.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ feature {
2626
}
2727
}
2828
feature {
29-
name: "image_uri"
29+
name: "image_name"
3030
type: BYTES
3131
presence {
3232
min_fraction: 1.0
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)