Skip to content

Commit b5f03ea

Browse files
authored
Merge pull request googleapis#2835 from daspecster/vision-safe-search-system-test
Add Vision system tests for detect_safe_search().
2 parents a4671fd + 148a3a0 commit b5f03ea

File tree

1 file changed

+60
-8
lines changed

1 file changed

+60
-8
lines changed

system_tests/vision.py

Lines changed: 60 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,14 @@ def _assert_coordinate(self, coordinate):
6464
self.assertIsInstance(coordinate, (int, float))
6565
self.assertNotEqual(coordinate, 0.0)
6666

67+
def _assert_likelihood(self, likelihood):
68+
from google.cloud.vision.likelihood import Likelihood
69+
70+
levels = [Likelihood.UNKNOWN, Likelihood.VERY_LIKELY,
71+
Likelihood.UNLIKELY, Likelihood.POSSIBLE, Likelihood.LIKELY,
72+
Likelihood.VERY_UNLIKELY]
73+
self.assertIn(likelihood, levels)
74+
6775

6876
class TestVisionClientLogo(unittest.TestCase):
6977
def setUp(self):
@@ -130,14 +138,6 @@ def tearDown(self):
130138
for value in self.to_delete_by_case:
131139
value.delete()
132140

133-
def _assert_likelihood(self, likelihood):
134-
from google.cloud.vision.likelihood import Likelihood
135-
136-
levels = [Likelihood.UNKNOWN, Likelihood.VERY_LIKELY,
137-
Likelihood.UNLIKELY, Likelihood.POSSIBLE, Likelihood.LIKELY,
138-
Likelihood.VERY_UNLIKELY]
139-
self.assertIn(likelihood, levels)
140-
141141
def _assert_landmarks(self, landmarks):
142142
from google.cloud.vision.face import Landmark
143143
from google.cloud.vision.face import LandmarkTypes
@@ -340,3 +340,55 @@ def test_detect_landmark_filename(self):
340340
self.assertEqual(len(landmarks), 1)
341341
landmark = landmarks[0]
342342
self._assert_landmark(landmark)
343+
344+
345+
class TestVisionClientSafeSearch(BaseVisionTestCase):
346+
def setUp(self):
347+
self.to_delete_by_case = []
348+
349+
def tearDown(self):
350+
for value in self.to_delete_by_case:
351+
value.delete()
352+
353+
def _assert_safe_search(self, safe_search):
354+
from google.cloud.vision.safe import SafeSearchAnnotation
355+
356+
self.assertIsInstance(safe_search, SafeSearchAnnotation)
357+
self._assert_likelihood(safe_search.adult)
358+
self._assert_likelihood(safe_search.spoof)
359+
self._assert_likelihood(safe_search.medical)
360+
self._assert_likelihood(safe_search.violence)
361+
362+
def test_detect_safe_search_content(self):
363+
client = Config.CLIENT
364+
with open(FACE_FILE, 'rb') as image_file:
365+
image = client.image(content=image_file.read())
366+
safe_searches = image.detect_safe_search()
367+
self.assertEqual(len(safe_searches), 1)
368+
safe_search = safe_searches[0]
369+
self._assert_safe_search(safe_search)
370+
371+
def test_detect_safe_search_gcs(self):
372+
bucket_name = Config.TEST_BUCKET.name
373+
blob_name = 'faces.jpg'
374+
blob = Config.TEST_BUCKET.blob(blob_name)
375+
self.to_delete_by_case.append(blob) # Clean-up.
376+
with open(FACE_FILE, 'rb') as file_obj:
377+
blob.upload_from_file(file_obj)
378+
379+
source_uri = 'gs://%s/%s' % (bucket_name, blob_name)
380+
381+
client = Config.CLIENT
382+
image = client.image(source_uri=source_uri)
383+
safe_searches = image.detect_safe_search()
384+
self.assertEqual(len(safe_searches), 1)
385+
safe_search = safe_searches[0]
386+
self._assert_safe_search(safe_search)
387+
388+
def test_detect_safe_search_filename(self):
389+
client = Config.CLIENT
390+
image = client.image(filename=FACE_FILE)
391+
safe_searches = image.detect_safe_search()
392+
self.assertEqual(len(safe_searches), 1)
393+
safe_search = safe_searches[0]
394+
self._assert_safe_search(safe_search)

0 commit comments

Comments
 (0)