diff --git a/test/test_client.py b/test/test_client.py index 0db37f2c7..451455cd2 100755 --- a/test/test_client.py +++ b/test/test_client.py @@ -509,6 +509,24 @@ def test_client_tcp_connect(): assert not client.connect() +def test_client_tcp_reuse(): + """Test the tcp client connection method""" + with mock.patch.object(socket, "create_connection") as mock_method: + _socket = mock.MagicMock() + mock_method.return_value = _socket + client = lib_client.ModbusTcpClient("127.0.0.1") + _socket.getsockname.return_value = ("dmmy", 1234) + assert client.connect() + client.close() + with mock.patch.object(socket, "create_connection") as mock_method: + _socket = mock.MagicMock() + mock_method.return_value = _socket + client = lib_client.ModbusTcpClient("127.0.0.1") + _socket.getsockname.return_value = ("dmmy", 1234) + assert client.connect() + client.close() + + def test_client_tls_connect(): """Test the tls client connection method""" with mock.patch.object(ssl.SSLSocket, "connect") as mock_method: diff --git a/test/test_server_task.py b/test/test_server_task.py index c1511e12c..633757b8b 100755 --- a/test/test_server_task.py +++ b/test/test_server_task.py @@ -179,6 +179,40 @@ async def test_async_task_ok(comm): await task +@pytest.mark.xdist_group(name="server_serialize") +@pytest.mark.parametrize("comm", TEST_TYPES) +async def test_async_task_reuse(comm): + """Test normal client/server handling.""" + run_server, server_args, run_client, client_args = helper_config(comm, "async") + + task = asyncio.create_task(run_server(**server_args)) + await asyncio.sleep(0.1) + client = run_client(**client_args) + await client.connect() + await asyncio.sleep(0.1) + assert client._connected # pylint: disable=protected-access + rr = await client.read_coils(1, 1, slave=0x01) + assert len(rr.bits) == 8 + + await client.close() + await asyncio.sleep(0.1) + assert not client._connected # pylint: disable=protected-access + + await client.connect() + await asyncio.sleep(0.1) + assert client._connected # pylint: disable=protected-access + rr = await client.read_coils(1, 1, slave=0x01) + assert len(rr.bits) == 8 + + await client.close() + await asyncio.sleep(0.1) + assert not client._connected # pylint: disable=protected-access + + await server.ServerAsyncStop() + task.cancel() + await task + + @pytest.mark.xdist_group(name="server_serialize") @pytest.mark.parametrize("comm", TEST_TYPES) async def test_async_task_server_stop(comm):