Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,10 @@ async def open(self) -> None:

async def close(self) -> None:
"""Closes the bidi-gRPC connection."""
raise NotImplementedError(
"close() is not implemented yet in _AsyncWriteObjectStream"
)
if not self._is_stream_open:
raise ValueError("Stream is not open")
await self.socket_like_rpc.close()
self._is_stream_open = False

async def send(
self, bidi_write_object_request: _storage_v2.BidiWriteObjectRequest
Expand Down
69 changes: 63 additions & 6 deletions tests/unit/asyncio/test_async_write_object_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import pytest
from unittest import mock

from unittest.mock import AsyncMock
from google.cloud.storage._experimental.asyncio.async_write_object_stream import (
_AsyncWriteObjectStream,
)
Expand Down Expand Up @@ -43,6 +44,27 @@ def mock_client():
return client


async def instantiate_write_obj_stream(mock_client, mock_cls_async_bidi_rpc, open=True):
"""Helper to create an instance of _AsyncWriteObjectStream and open it by default."""
socket_like_rpc = AsyncMock()
mock_cls_async_bidi_rpc.return_value = socket_like_rpc
socket_like_rpc.open = AsyncMock()
socket_like_rpc.close = AsyncMock()

mock_response = mock.MagicMock(spec=_storage_v2.BidiWriteObjectResponse)
mock_response.resource = mock.MagicMock(spec=_storage_v2.Object)
mock_response.resource.generation = GENERATION
mock_response.write_handle = WRITE_HANDLE
socket_like_rpc.recv = AsyncMock(return_value=mock_response)

write_obj_stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT)

if open:
await write_obj_stream.open()

return write_obj_stream


def test_async_write_object_stream_init(mock_client):
"""Test the constructor of _AsyncWriteObjectStream."""
stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT)
Expand Down Expand Up @@ -228,7 +250,6 @@ async def test_open_raises_error_on_missing_generation(
ValueError, match="Failed to obtain object generation after opening the stream"
):
await stream.open()
# assert stream.generation_number is None


@pytest.mark.asyncio
Expand All @@ -252,13 +273,49 @@ async def test_open_raises_error_on_missing_write_handle(


@pytest.mark.asyncio
async def test_unimplemented_methods_raise_error(mock_client):
"""Test that unimplemented methods raise NotImplementedError."""
stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT)
@mock.patch(
"google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc"
)
async def test_close(mock_cls_async_bidi_rpc, mock_client):
"""Test that close successfully closes the stream."""
# Arrange
write_obj_stream = await instantiate_write_obj_stream(
mock_client, mock_cls_async_bidi_rpc, open=True
)

with pytest.raises(NotImplementedError):
await stream.close()
# Act
await write_obj_stream.close()

# Assert
write_obj_stream.socket_like_rpc.close.assert_called_once()
assert not write_obj_stream.is_stream_open


@pytest.mark.asyncio
@mock.patch(
"google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc"
)
async def test_close_without_open_should_raise_error(
mock_cls_async_bidi_rpc, mock_client
):
"""Test that closing a stream that is not open raises a ValueError."""
# Arrange
write_obj_stream = await instantiate_write_obj_stream(
mock_client, mock_cls_async_bidi_rpc, open=False
)

# Act & Assert
with pytest.raises(ValueError, match="Stream is not open"):
await write_obj_stream.close()


@pytest.mark.asyncio
@mock.patch(
"google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc"
)
async def test_unimplemented_methods_raise_error(mock_async_bidi_rpc, mock_client):
"""Test that unimplemented methods (send, recv) raise NotImplementedError."""
stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT)
with pytest.raises(NotImplementedError):
await stream.send(_storage_v2.BidiWriteObjectRequest())

Expand Down