Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 18 additions & 21 deletions pymodbus/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def __str__(self):
)


class ModbusBaseSyncClient(ModbusClientMixin[ModbusResponse], ModbusProtocol):
class ModbusBaseSyncClient(ModbusClientMixin[ModbusResponse]):
"""**ModbusBaseClient**.

Fixed parameters:
Expand Down Expand Up @@ -308,26 +308,22 @@ def __init__(
) -> None:
"""Initialize a client instance."""
ModbusClientMixin.__init__(self) # type: ignore[arg-type]
ModbusProtocol.__init__(
self,
CommParams(
comm_type=kwargs.get("CommType"),
comm_name="comm",
source_address=kwargs.get("source_address", None),
reconnect_delay=reconnect_delay,
reconnect_delay_max=reconnect_delay_max,
timeout_connect=timeout,
host=kwargs.get("host", None),
port=kwargs.get("port", 0),
sslctx=kwargs.get("sslctx", None),
baudrate=kwargs.get("baudrate", None),
bytesize=kwargs.get("bytesize", None),
parity=kwargs.get("parity", None),
stopbits=kwargs.get("stopbits", None),
handle_local_echo=kwargs.get("handle_local_echo", False),
on_reconnect_callback=on_reconnect_callback,
),
False,
self.comm_params = CommParams(
comm_type=kwargs.get("CommType"),
comm_name="comm",
source_address=kwargs.get("source_address", None),
reconnect_delay=reconnect_delay,
reconnect_delay_max=reconnect_delay_max,
timeout_connect=timeout,
host=kwargs.get("host", None),
port=kwargs.get("port", 0),
sslctx=kwargs.get("sslctx", None),
baudrate=kwargs.get("baudrate", None),
bytesize=kwargs.get("bytesize", None),
parity=kwargs.get("parity", None),
stopbits=kwargs.get("stopbits", None),
handle_local_echo=kwargs.get("handle_local_echo", False),
on_reconnect_callback=on_reconnect_callback,
)
self.params = self._params()
self.params.retries = int(retries)
Expand All @@ -349,6 +345,7 @@ def __init__(
self.state = ModbusTransactionState.IDLE
self.last_frame_end: float | None = 0
self.silent_interval: float = 0
self.transport = None

# ----------------------------------------------------------------------- #
# Client external interface
Expand Down
6 changes: 1 addition & 5 deletions pymodbus/transport/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def __init__(
self.is_closing = False

self.transport: asyncio.BaseTransport = None # type: ignore[assignment]
self.loop: asyncio.AbstractEventLoop = None # type: ignore[assignment]
self.loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
self.recv_buffer: bytes = b""
self.call_create: Callable[[], Coroutine[Any, Any, Any]] = None # type: ignore[assignment]
if self.is_server:
Expand Down Expand Up @@ -237,8 +237,6 @@ def init_setup_connect_listen(self, host: str, port: int) -> None:
async def transport_connect(self) -> bool:
"""Handle generic connect and call on to specific transport connect."""
Log.debug("Connecting {}", self.comm_params.comm_name)
if not self.loop:
self.loop = asyncio.get_running_loop()
self.is_closing = False
try:
self.transport, _protocol = await asyncio.wait_for(
Expand All @@ -253,8 +251,6 @@ async def transport_connect(self) -> bool:
async def transport_listen(self) -> bool:
"""Handle generic listen and call on to specific transport listen."""
Log.debug("Awaiting connections {}", self.comm_params.comm_name)
if not self.loop:
self.loop = asyncio.get_running_loop()
self.is_closing = False
try:
self.transport = await self.call_create()
Expand Down
2 changes: 1 addition & 1 deletion test/test_framers.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def test_rtu_incoming_packet(rtu_framer, data):
assert mock_reset.call_count == (1 if reset_called else 0)


def test_send_packet(rtu_framer):
async def test_send_packet(rtu_framer):
"""Test send packet."""
message = TEST_MESSAGE
client = ModbusBaseClient(
Expand Down
6 changes: 3 additions & 3 deletions test/transport/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@ def data_received(self, data):


@pytest.fixture(name="dummy_protocol")
def prepare_dummy_protocol():
async def prepare_dummy_protocol():
"""Return transport object."""
return DummyProtocol


@pytest.fixture(name="client")
def prepare_protocol(use_clc):
async def prepare_protocol(use_clc):
"""Prepare transport object."""
transport = ModbusProtocol(use_clc, False)
transport.callback_connected = mock.Mock()
Expand All @@ -66,7 +66,7 @@ def prepare_protocol(use_clc):


@pytest.fixture(name="server")
def prepare_transport_server(use_cls):
async def prepare_transport_server(use_cls):
"""Prepare transport object."""
transport = ModbusProtocol(use_cls, True)
transport.callback_connected = mock.Mock()
Expand Down
6 changes: 2 additions & 4 deletions test/transport/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,8 @@ async def test_init_source_addr_none(self, use_clc):

async def test_loop_connect(self, client, dummy_protocol):
"""Test properties."""
client.loop = None
client.call_create = mock.AsyncMock(return_value=(dummy_protocol(), None))
assert await client.transport_connect()
assert client.loop

async def test_loop_listen(self, server, dummy_protocol):
"""Test properties."""
Expand Down Expand Up @@ -342,7 +340,7 @@ def test_generate_ssl_no_file(self, use_clc):

@pytest.mark.parametrize("use_host", ["socket://localhost:5005", "/dev/tty"])
@pytest.mark.parametrize("use_comm_type", [CommType.SERIAL])
def test_init_serial(self, use_cls):
async def test_init_serial(self, use_cls):
"""Test server serial with socket."""
ModbusProtocol(use_cls, True)

Expand All @@ -356,7 +354,7 @@ async def test_init_create_serial(self, use_cls):
@pytest.mark.parametrize("use_host", ["localhost"])
@pytest.mark.parametrize("use_comm_type", [CommType.UDP])
@pytest.mark.parametrize("is_server", [True, False])
def test_init_udp(self, is_server, use_cls, use_clc):
async def test_init_udp(self, is_server, use_cls, use_clc):
"""Test server/client udp."""
if is_server:
ModbusProtocol(use_cls, True)
Expand Down