diff --git a/speech/google/cloud/speech/_gax.py b/speech/google/cloud/speech/_gax.py index 05636a22806c..3c343f2227b5 100644 --- a/speech/google/cloud/speech/_gax.py +++ b/speech/google/cloud/speech/_gax.py @@ -182,9 +182,7 @@ def streaming_recognize(self, sample, language_code=None, .cloud_speech_pb2.StreamingRecognizeResponse` :returns: ``StreamingRecognizeResponse`` instances. """ - if getattr(sample.content, 'closed', None) is None: - raise ValueError('Please use file-like object for data stream.') - if sample.content.closed: + if sample.stream.closed: raise ValueError('Stream is closed.') requests = _stream_requests(sample, language_code=language_code, @@ -252,7 +250,6 @@ def sync_recognize(self, sample, language_code=None, max_alternatives=None, language_code=language_code, max_alternatives=max_alternatives, profanity_filter=profanity_filter, speech_context=SpeechContext(phrases=speech_context)) - audio = RecognitionAudio(content=sample.content, uri=sample.source_uri) api = self._gapic_api @@ -337,7 +334,7 @@ def _stream_requests(sample, language_code=None, max_alternatives=None, yield config_request while True: - data = sample.content.read(sample.chunk_size) + data = sample.stream.read(sample.chunk_size) if not data: break yield StreamingRecognizeRequest(audio_content=data) diff --git a/speech/google/cloud/speech/client.py b/speech/google/cloud/speech/client.py index 828e74c119b1..f6bcdf50b26a 100644 --- a/speech/google/cloud/speech/client.py +++ b/speech/google/cloud/speech/client.py @@ -66,12 +66,12 @@ def __init__(self, credentials=None, http=None, use_gax=None): else: self._use_gax = use_gax - def sample(self, content=None, source_uri=None, encoding=None, + def sample(self, content=None, source_uri=None, stream=None, encoding=None, sample_rate=None): """Factory: construct Sample to use when making recognize requests. :type content: bytes - :param content: (Optional) Byte stream of audio. + :param content: (Optional) Bytes containing audio data. :type source_uri: str :param source_uri: (Optional) URI that points to a file that contains @@ -80,6 +80,9 @@ def sample(self, content=None, source_uri=None, encoding=None, supported, which must be specified in the following format: ``gs://bucket_name/object_name``. + :type stream: file + :param stream: (Optional) File like object to stream. + :type encoding: str :param encoding: encoding of audio data sent in all RecognitionAudio messages, can be one of: :attr:`~.Encoding.LINEAR16`, @@ -97,7 +100,7 @@ def sample(self, content=None, source_uri=None, encoding=None, :rtype: :class:`~google.cloud.speech.sample.Sample` :returns: Instance of ``Sample``. """ - return Sample(content=content, source_uri=source_uri, + return Sample(content=content, source_uri=source_uri, stream=stream, encoding=encoding, sample_rate=sample_rate, client=self) @property diff --git a/speech/google/cloud/speech/sample.py b/speech/google/cloud/speech/sample.py index f1820e148729..fb0ab6ae69de 100644 --- a/speech/google/cloud/speech/sample.py +++ b/speech/google/cloud/speech/sample.py @@ -14,6 +14,7 @@ """Sample class to handle content for Google Cloud Speech API.""" + from google.cloud.speech.encoding import Encoding from google.cloud.speech.result import StreamingSpeechResult @@ -22,7 +23,7 @@ class Sample(object): """Representation of an audio sample to be used with Google Speech API. :type content: bytes - :param content: (Optional) Byte stream of audio. + :param content: (Optional) Bytes containing audio data. :type source_uri: str :param source_uri: (Optional) URI that points to a file that contains @@ -31,6 +32,9 @@ class Sample(object): supported, which must be specified in the following format: ``gs://bucket_name/object_name``. + :type stream: file + :param stream: (Optional) File like object to stream. + :type encoding: str :param encoding: encoding of audio data sent in all RecognitionAudio messages, can be one of: :attr:`~.Encoding.LINEAR16`, @@ -51,17 +55,19 @@ class Sample(object): default_encoding = Encoding.FLAC default_sample_rate = 16000 - def __init__(self, content=None, source_uri=None, + def __init__(self, content=None, source_uri=None, stream=None, encoding=None, sample_rate=None, client=None): self._client = client - no_source = content is None and source_uri is None - both_source = content is not None and source_uri is not None - if no_source or both_source: - raise ValueError('Supply one of \'content\' or \'source_uri\'') + sources = [content is not None, source_uri is not None, + stream is not None] + if sources.count(True) != 1: + raise ValueError('Supply exactly one of ' + '\'content\', \'source_uri\', \'stream\'') self._content = content self._source_uri = source_uri + self._stream = stream if sample_rate is not None and not 8000 <= sample_rate <= 48000: raise ValueError('The value of sample_rate must be between 8000' @@ -109,6 +115,15 @@ def sample_rate(self): """ return self._sample_rate + @property + def stream(self): + """Stream the content when it is a file-like object. + + :rtype: file + :returns: File like object to stream. + """ + return self._stream + @property def encoding(self): """Audio encoding type diff --git a/speech/unit_tests/test__gax.py b/speech/unit_tests/test__gax.py index 2c34b93abd77..043e271d9525 100644 --- a/speech/unit_tests/test__gax.py +++ b/speech/unit_tests/test__gax.py @@ -14,43 +14,6 @@ import unittest -import mock - - -def _make_credentials(): - import google.auth.credentials - return mock.Mock(spec=google.auth.credentials.Credentials) - - -class TestGAPICSpeechAPI(unittest.TestCase): - SAMPLE_RATE = 16000 - - @staticmethod - def _get_target_class(): - from google.cloud.speech._gax import GAPICSpeechAPI - - return GAPICSpeechAPI - - def _make_one(self, *args, **kw): - return self._get_target_class()(*args, **kw) - - def test_use_bytes_instead_of_file_like_object(self): - from google.cloud import speech - from google.cloud.speech.sample import Sample - - credentials = _make_credentials() - client = speech.Client(credentials=credentials, use_gax=True) - client.connection = _Connection() - client.connection.credentials = credentials - - sample = Sample(content=b'', encoding=speech.Encoding.FLAC, - sample_rate=self.SAMPLE_RATE) - - api = self._make_one(client) - with self.assertRaises(ValueError): - api.streaming_recognize(sample) - self.assertEqual(client.connection._requested, []) - class TestSpeechGAXMakeRequests(unittest.TestCase): SAMPLE_RATE = 16000 @@ -143,7 +106,7 @@ def test_stream_requests(self): from google.cloud.grpc.speech.v1beta1.cloud_speech_pb2 import ( StreamingRecognizeRequest) - sample = Sample(content=BytesIO(self.AUDIO_CONTENT), + sample = Sample(stream=BytesIO(self.AUDIO_CONTENT), encoding=speech.Encoding.FLAC, sample_rate=self.SAMPLE_RATE) language_code = 'US-en' @@ -172,10 +135,3 @@ def test_stream_requests(self): self.assertEqual(streaming_request.audio_content, self.AUDIO_CONTENT) self.assertIsInstance(config_request.streaming_config, StreamingRecognitionConfig) - - -class _Connection(object): - - def __init__(self, *responses): - self._responses = responses - self._requested = [] diff --git a/speech/unit_tests/test_client.py b/speech/unit_tests/test_client.py index bf11eb4f7c65..8f5262cc8b32 100644 --- a/speech/unit_tests/test_client.py +++ b/speech/unit_tests/test_client.py @@ -72,7 +72,7 @@ class TestClient(unittest.TestCase): SAMPLE_RATE = 16000 HINTS = ['hi'] AUDIO_SOURCE_URI = 'gs://sample-bucket/sample-recording.flac' - AUDIO_CONTENT = '/9j/4QNURXhpZgAASUkq' + AUDIO_CONTENT = b'testing 1 2 3' @staticmethod def _get_target_class(): @@ -125,14 +125,12 @@ def test_sync_recognize_content_with_optional_params_no_gax(self): from base64 import b64encode from google.cloud._helpers import _bytes_to_unicode - from google.cloud._helpers import _to_bytes from google.cloud import speech from google.cloud.speech.alternative import Alternative from unit_tests._fixtures import SYNC_RECOGNIZE_RESPONSE - _AUDIO_CONTENT = _to_bytes(self.AUDIO_CONTENT) - _B64_AUDIO_CONTENT = _bytes_to_unicode(b64encode(_AUDIO_CONTENT)) + _B64_AUDIO_CONTENT = _bytes_to_unicode(b64encode(self.AUDIO_CONTENT)) RETURNED = SYNC_RECOGNIZE_RESPONSE REQUEST = { 'config': { @@ -325,8 +323,7 @@ def speech_api(channel=None): self.assertIsInstance(low_level, _MockGAPICSpeechAPI) self.assertIs(low_level._channel, channel_obj) self.assertEqual( - channel_args, - [(creds, _gax.DEFAULT_USER_AGENT, host)]) + channel_args, [(creds, _gax.DEFAULT_USER_AGENT, host)]) results = sample.sync_recognize() @@ -462,8 +459,9 @@ def speech_api(channel=None): speech_api.SERVICE_ADDRESS = host stream.close() + self.assertTrue(stream.closed) - sample = client.sample(content=stream, + sample = client.sample(stream=stream, encoding=Encoding.LINEAR16, sample_rate=self.SAMPLE_RATE) @@ -523,7 +521,7 @@ def speech_api(channel=None): make_secure_channel=make_channel): client._speech_api = _gax.GAPICSpeechAPI(client) - sample = client.sample(content=stream, + sample = client.sample(stream=stream, encoding=Encoding.LINEAR16, sample_rate=self.SAMPLE_RATE) @@ -596,7 +594,7 @@ def speech_api(channel=None): make_secure_channel=make_channel): client._speech_api = _gax.GAPICSpeechAPI(client) - sample = client.sample(content=stream, + sample = client.sample(stream=stream, encoding=Encoding.LINEAR16, sample_rate=self.SAMPLE_RATE) @@ -640,7 +638,7 @@ def speech_api(channel=None): make_secure_channel=make_channel): client._speech_api = _gax.GAPICSpeechAPI(client) - sample = client.sample(content=stream, + sample = client.sample(stream=stream, encoding=Encoding.LINEAR16, sample_rate=self.SAMPLE_RATE) diff --git a/speech/unit_tests/test_sample.py b/speech/unit_tests/test_sample.py index 15f3604dfffd..b6a6b094435b 100644 --- a/speech/unit_tests/test_sample.py +++ b/speech/unit_tests/test_sample.py @@ -38,10 +38,44 @@ def test_initialize_sample(self): self.assertEqual(sample.sample_rate, self.SAMPLE_RATE) def test_content_and_source_uri(self): + from io import BytesIO + with self.assertRaises(ValueError): self._make_one(content='awefawagaeragere', source_uri=self.AUDIO_SOURCE_URI) + with self.assertRaises(ValueError): + self._make_one(stream=BytesIO(b'awefawagaeragere'), + source_uri=self.AUDIO_SOURCE_URI) + + with self.assertRaises(ValueError): + self._make_one(content='awefawagaeragere', + stream=BytesIO(b'awefawagaeragere'), + source_uri=self.AUDIO_SOURCE_URI) + + def test_stream_property(self): + from io import BytesIO + from google.cloud.speech.encoding import Encoding + + data = b'abc 1 2 3 4' + stream = BytesIO(data) + sample = self._make_one(stream=stream, encoding=Encoding.FLAC, + sample_rate=self.SAMPLE_RATE) + self.assertEqual(sample.stream, stream) + self.assertEqual(sample.stream.read(), data) + + def test_bytes_converts_to_file_like_object(self): + from google.cloud import speech + from google.cloud.speech.sample import Sample + + test_bytes = b'testing 1 2 3' + + sample = Sample(content=test_bytes, encoding=speech.Encoding.FLAC, + sample_rate=self.SAMPLE_RATE) + self.assertEqual(sample.content, test_bytes) + self.assertEqual(sample.encoding, speech.Encoding.FLAC) + self.assertEqual(sample.sample_rate, self.SAMPLE_RATE) + def test_sample_rates(self): from google.cloud.speech.encoding import Encoding diff --git a/system_tests/speech.py b/system_tests/speech.py index e3685502f77c..832db30ebd4d 100644 --- a/system_tests/speech.py +++ b/system_tests/speech.py @@ -118,7 +118,7 @@ def _make_async_request(self, content=None, source_uri=None, def _make_streaming_request(self, file_obj, single_utterance=True, interim_results=False): client = Config.CLIENT - sample = client.sample(content=file_obj, + sample = client.sample(stream=file_obj, encoding=speech.Encoding.LINEAR16, sample_rate=16000) return sample.streaming_recognize(single_utterance=single_utterance,