diff --git a/system_tests/vision.py b/system_tests/vision.py index 81cfa1f71257..120db0325e6d 100644 --- a/system_tests/vision.py +++ b/system_tests/vision.py @@ -64,6 +64,14 @@ def _assert_coordinate(self, coordinate): self.assertIsInstance(coordinate, (int, float)) self.assertNotEqual(coordinate, 0.0) + def _assert_likelihood(self, likelihood): + from google.cloud.vision.likelihood import Likelihood + + levels = [Likelihood.UNKNOWN, Likelihood.VERY_LIKELY, + Likelihood.UNLIKELY, Likelihood.POSSIBLE, Likelihood.LIKELY, + Likelihood.VERY_UNLIKELY] + self.assertIn(likelihood, levels) + class TestVisionClientLogo(unittest.TestCase): def setUp(self): @@ -130,14 +138,6 @@ def tearDown(self): for value in self.to_delete_by_case: value.delete() - def _assert_likelihood(self, likelihood): - from google.cloud.vision.likelihood import Likelihood - - levels = [Likelihood.UNKNOWN, Likelihood.VERY_LIKELY, - Likelihood.UNLIKELY, Likelihood.POSSIBLE, Likelihood.LIKELY, - Likelihood.VERY_UNLIKELY] - self.assertIn(likelihood, levels) - def _assert_landmarks(self, landmarks): from google.cloud.vision.face import Landmark from google.cloud.vision.face import LandmarkTypes @@ -340,3 +340,55 @@ def test_detect_landmark_filename(self): self.assertEqual(len(landmarks), 1) landmark = landmarks[0] self._assert_landmark(landmark) + + +class TestVisionClientSafeSearch(BaseVisionTestCase): + def setUp(self): + self.to_delete_by_case = [] + + def tearDown(self): + for value in self.to_delete_by_case: + value.delete() + + def _assert_safe_search(self, safe_search): + from google.cloud.vision.safe import SafeSearchAnnotation + + self.assertIsInstance(safe_search, SafeSearchAnnotation) + self._assert_likelihood(safe_search.adult) + self._assert_likelihood(safe_search.spoof) + self._assert_likelihood(safe_search.medical) + self._assert_likelihood(safe_search.violence) + + def test_detect_safe_search_content(self): + client = Config.CLIENT + with open(FACE_FILE, 'rb') as image_file: + image = client.image(content=image_file.read()) + safe_searches = image.detect_safe_search() + self.assertEqual(len(safe_searches), 1) + safe_search = safe_searches[0] + self._assert_safe_search(safe_search) + + def test_detect_safe_search_gcs(self): + bucket_name = Config.TEST_BUCKET.name + blob_name = 'faces.jpg' + blob = Config.TEST_BUCKET.blob(blob_name) + self.to_delete_by_case.append(blob) # Clean-up. + with open(FACE_FILE, 'rb') as file_obj: + blob.upload_from_file(file_obj) + + source_uri = 'gs://%s/%s' % (bucket_name, blob_name) + + client = Config.CLIENT + image = client.image(source_uri=source_uri) + safe_searches = image.detect_safe_search() + self.assertEqual(len(safe_searches), 1) + safe_search = safe_searches[0] + self._assert_safe_search(safe_search) + + def test_detect_safe_search_filename(self): + client = Config.CLIENT + image = client.image(filename=FACE_FILE) + safe_searches = image.detect_safe_search() + self.assertEqual(len(safe_searches), 1) + safe_search = safe_searches[0] + self._assert_safe_search(safe_search)