Skip to content

Commit 143b0a1

Browse files
authored
Merge pull request #2237 from daspecster/vision-label-detection
Add vision label detection
2 parents 1f7ed9d + 408f929 commit 143b0a1

File tree

5 files changed

+67
-7
lines changed

5 files changed

+67
-7
lines changed

google/cloud/vision/entity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def from_api_repr(cls, response):
5555
:rtype: :class:`~google.cloud.vision.entity.EntityAnnotation`
5656
:returns: Instance of ``EntityAnnotation``.
5757
"""
58-
bounds = Bounds.from_api_repr(response['boundingPoly'])
58+
bounds = Bounds.from_api_repr(response.get('boundingPoly'))
5959
description = response['description']
6060
locations = [LocationInformation.from_api_repr(location)
6161
for location in response.get('locations', [])]

google/cloud/vision/geometry.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,14 @@ def from_api_repr(cls, response_vertices):
3131
:type response_vertices: dict
3232
:param response_vertices: Dictionary holding a ``vertices`` list.
3333
34-
:rtype: :class:`~google.cloud.vision.geometry.BoundsBase`
35-
:returns: Instance of BoundsBase with populated verticies.
34+
:rtype: :class:`~google.cloud.vision.geometry.BoundsBase` or None
35+
:returns: Instance of BoundsBase with populated vertices or None.
3636
"""
37-
vertices = []
38-
for vertex in response_vertices['vertices']:
39-
vertices.append(Vertex(vertex.get('x', None),
40-
vertex.get('y', None)))
37+
if not response_vertices:
38+
return None
39+
40+
vertices = [Vertex(vertex.get('x', None), vertex.get('y', None)) for
41+
vertex in response_vertices.get('vertices', [])]
4142
return cls(vertices)
4243

4344
@property

google/cloud/vision/image.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def _detect_annotation(self, feature):
9494
:class:`~google.cloud.vision.entity.EntityAnnotation`.
9595
"""
9696
reverse_types = {
97+
'LABEL_DETECTION': 'labelAnnotations',
9798
'LANDMARK_DETECTION': 'landmarkAnnotations',
9899
'LOGO_DETECTION': 'logoAnnotations',
99100
}
@@ -122,6 +123,18 @@ def detect_faces(self, limit=10):
122123

123124
return faces
124125

126+
def detect_labels(self, limit=10):
127+
"""Detect labels that describe objects in an image.
128+
129+
:type limit: int
130+
:param limit: The maximum number of labels to try to detect.
131+
132+
:rtype: list
133+
:returns: List of :class:`~google.cloud.vision.entity.EntityAnnotation`
134+
"""
135+
feature = Feature(FeatureTypes.LABEL_DETECTION, limit)
136+
return self._detect_annotation(feature)
137+
125138
def detect_landmarks(self, limit=10):
126139
"""Detect landmarks in an image.
127140

unit_tests/vision/_fixtures.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,28 @@
1+
LABEL_DETECTION_RESPONSE = {
2+
'responses': [
3+
{
4+
'labelAnnotations': [
5+
{
6+
'mid': '/m/0k4j',
7+
'description': 'automobile',
8+
'score': 0.9776855
9+
},
10+
{
11+
'mid': '/m/07yv9',
12+
'description': 'vehicle',
13+
'score': 0.947987
14+
},
15+
{
16+
'mid': '/m/07r04',
17+
'description': 'truck',
18+
'score': 0.88429511
19+
}
20+
]
21+
}
22+
]
23+
}
24+
25+
126
LANDMARK_DETECTION_RESPONSE = {
227
'responses': [
328
{

unit_tests/vision/test_client.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,27 @@ def test_face_detection_from_content(self):
114114
image_request['image']['content'])
115115
self.assertEqual(5, image_request['features'][0]['maxResults'])
116116

117+
def test_label_detection_from_source(self):
118+
from google.cloud.vision.entity import EntityAnnotation
119+
from unit_tests.vision._fixtures import (LABEL_DETECTION_RESPONSE as
120+
RETURNED)
121+
credentials = _Credentials()
122+
client = self._makeOne(project=self.PROJECT, credentials=credentials)
123+
client.connection = _Connection(RETURNED)
124+
125+
image = client.image(source_uri=_IMAGE_SOURCE)
126+
labels = image.detect_labels(limit=3)
127+
self.assertEqual(3, len(labels))
128+
self.assertTrue(isinstance(labels[0], EntityAnnotation))
129+
image_request = client.connection._requested[0]['data']['requests'][0]
130+
self.assertEqual(_IMAGE_SOURCE,
131+
image_request['image']['source']['gcs_image_uri'])
132+
self.assertEqual(3, image_request['features'][0]['maxResults'])
133+
self.assertEqual('automobile', labels[0].description)
134+
self.assertEqual('vehicle', labels[1].description)
135+
self.assertEqual('/m/0k4j', labels[0].mid)
136+
self.assertEqual('/m/07yv9', labels[1].mid)
137+
117138
def test_landmark_detection_from_source(self):
118139
from google.cloud.vision.entity import EntityAnnotation
119140
from unit_tests.vision._fixtures import (LANDMARK_DETECTION_RESPONSE as

0 commit comments

Comments
 (0)