diff --git a/system_tests/data/text.jpg b/system_tests/data/text.jpg new file mode 100644 index 000000000000..39c601c8d315 Binary files /dev/null and b/system_tests/data/text.jpg differ diff --git a/system_tests/vision.py b/system_tests/vision.py index 62950c12ddda..277bcd9d657e 100644 --- a/system_tests/vision.py +++ b/system_tests/vision.py @@ -32,6 +32,7 @@ FACE_FILE = os.path.join(_SYS_TESTS_DIR, 'data', 'faces.jpg') LABEL_FILE = os.path.join(_SYS_TESTS_DIR, 'data', 'car.jpg') LANDMARK_FILE = os.path.join(_SYS_TESTS_DIR, 'data', 'landmark.jpg') +TEXT_FILE = os.path.join(_SYS_TESTS_DIR, 'data', 'text.jpg') class Config(object): @@ -285,7 +286,7 @@ def test_detect_labels_filename(self): class TestVisionClientLandmark(BaseVisionTestCase): - DESCRIPTIONS = ('Mount Rushmore') + DESCRIPTIONS = ('Mount Rushmore',) def setUp(self): self.to_delete_by_case = [] @@ -394,6 +395,67 @@ def test_detect_safe_search_filename(self): self._assert_safe_search(safe_search) +class TestVisionClientText(unittest.TestCase): + DESCRIPTIONS = ( + 'Do', + 'what', + 'is', + 'right,', + 'not', + 'what', + 'is', + 'easy', + 'Do what is\nright, not\nwhat is easy\n', + ) + + def setUp(self): + self.to_delete_by_case = [] + + def tearDown(self): + for value in self.to_delete_by_case: + value.delete() + + def _assert_text(self, text): + self.assertIsInstance(text, EntityAnnotation) + self.assertIn(text.description, self.DESCRIPTIONS) + self.assertIn(text.locale, (None, 'en')) + self.assertNotEqual(text.score, 0.0) + + def test_detect_text_content(self): + client = Config.CLIENT + with open(TEXT_FILE, 'rb') as image_file: + image = client.image(content=image_file.read()) + texts = image.detect_text() + self.assertEqual(len(texts), 9) + for text in texts: + self._assert_text(text) + + def test_detect_text_gcs(self): + bucket_name = Config.TEST_BUCKET.name + blob_name = 'text.jpg' + blob = Config.TEST_BUCKET.blob(blob_name) + self.to_delete_by_case.append(blob) # Clean-up. + with open(TEXT_FILE, 'rb') as file_obj: + blob.upload_from_file(file_obj) + + source_uri = 'gs://%s/%s' % (bucket_name, blob_name) + + client = Config.CLIENT + image = client.image(source_uri=source_uri) + texts = image.detect_text() + self.assertEqual(len(texts), 9) + for text in texts: + self._assert_text(text) + + def test_detect_text_filename(self): + client = Config.CLIENT + image = client.image(filename=TEXT_FILE) + texts = image.detect_text() + self.assertEqual(len(texts), 9) + for text in texts: + self._assert_text(text) + + class TestVisionClientImageProperties(BaseVisionTestCase): def setUp(self): self.to_delete_by_case = []