Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ async def run_async_client(client, modbus_calls=None):
assert client.connected
if modbus_calls:
await modbus_calls(client)
await client.close()
client.close()
_logger.info("### End of Program")


Expand Down
88 changes: 39 additions & 49 deletions pymodbus/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
from pymodbus.logging import Log
from pymodbus.pdu import ModbusRequest, ModbusResponse
from pymodbus.transaction import DictTransactionManager
from pymodbus.transport import BaseTransport
from pymodbus.utilities import ModbusTransactionState


class ModbusBaseClient(ModbusClientMixin):
class ModbusBaseClient(ModbusClientMixin, BaseTransport):
"""**ModbusBaseClient**

**Parameters common to all clients**:
Expand Down Expand Up @@ -63,8 +64,6 @@ class _params: # pylint: disable=too-many-instance-attributes
broadcast_enable: bool = None
kwargs: dict = None
reconnect_delay: int = None
reconnect_delay_max: int = None
on_reconnect_callback: Callable[[], None] | None = None

baudrate: int = None
bytesize: int = None
Expand Down Expand Up @@ -95,6 +94,7 @@ def __init__( # pylint: disable=too-many-arguments
**kwargs: Any,
) -> None:
"""Initialize a client instance."""
BaseTransport.__init__(self)
self.params = self._params()
self.params.framer = framer
self.params.timeout = float(timeout)
Expand All @@ -104,8 +104,8 @@ def __init__( # pylint: disable=too-many-arguments
self.params.strict = bool(strict)
self.params.broadcast_enable = bool(broadcast_enable)
self.params.reconnect_delay = int(reconnect_delay)
self.params.reconnect_delay_max = int(reconnect_delay_max)
self.params.on_reconnect_callback = on_reconnect_callback
self.reconnect_delay_max = int(reconnect_delay_max)
self.on_reconnect_callback = on_reconnect_callback
self.params.kwargs = kwargs

# Common variables.
Expand All @@ -115,15 +115,14 @@ def __init__( # pylint: disable=too-many-arguments
)
self.delay_ms = self.params.reconnect_delay
self.use_protocol = False
self._connected = False
self.use_udp = False
self.state = ModbusTransactionState.IDLE
self.last_frame_end: float = 0
self.silent_interval: float = 0
self.transport = None
self._reconnect_task = None

# Initialize mixin
super().__init__()
ModbusClientMixin.__init__(self)

# ----------------------------------------------------------------------- #
# Client external interface
Expand Down Expand Up @@ -174,30 +173,16 @@ def execute(self, request: ModbusRequest = None) -> ModbusResponse:
:raises ConnectionException: Check exception text.
"""
if self.use_protocol:
if not self._connected:
if not self.transport:
raise ConnectionException(f"Not connected[{str(self)}]")
return self.async_execute(request)
if not self.connect():
raise ConnectionException(f"Failed to connect[{str(self)}]")
return self.transaction.execute(request)

def close(self) -> None:
"""Close the underlying socket connection (call **sync/async**)."""
raise NotImplementedException

# ----------------------------------------------------------------------- #
# Merged client methods
# ----------------------------------------------------------------------- #
def client_made_connection(self, protocol):
"""Run transport specific connection."""

def client_lost_connection(self, protocol):
"""Run transport specific connection lost."""

def datagram_received(self, data, _addr):
"""Receive datagram."""
self.data_received(data)

async def async_execute(self, request=None):
"""Execute requests asynchronously."""
request.transaction_id = self.transaction.getNextTID()
Expand All @@ -218,29 +203,13 @@ async def async_execute(self, request=None):
raise
return resp

def connection_made(self, transport):
"""Call when a connection is made.

The transport argument is the transport representing the connection.
"""
self.transport = transport
Log.debug("Client connected to modbus server")
self._connected = True
self.client_made_connection(self)

def connection_lost(self, reason):
"""Call when the connection is lost or closed.

The argument is either an exception object or None
"""
if self.transport:
self.transport.abort()
if hasattr(self.transport, "_sock"):
self.transport._sock.close() # pylint: disable=protected-access
self.transport = None
self.client_lost_connection(self)
Log.debug("Client disconnected from modbus server: {}", reason)
self._connected = False
self.close(reconnect=True)
for tid in list(self.transaction):
self.raise_future(
self.transaction.getTransaction(tid),
Expand Down Expand Up @@ -277,22 +246,43 @@ def _handle_response(self, reply, **_kwargs):
def _build_response(self, tid):
"""Return a deferred response for the current request."""
my_future = self.create_future()
if not self._connected:
if not self.transport:
self.raise_future(my_future, ConnectionException("Client is not connected"))
else:
self.transaction.addTransaction(my_future, tid)
return my_future

@property
def async_connected(self):
"""Return connection status."""
return self._connected
def close(self, reconnect: bool = False) -> None:
"""Close connection.

async def async_close(self):
"""Close connection."""
:param reconnect: (default false), try to reconnect
"""
if self.transport:
if hasattr(self.transport, "_sock"):
self.transport._sock.close() # pylint: disable=protected-access
self.transport.abort()
self.transport.close()
self._connected = False
self.transport = None
if self._reconnect_task:
self._reconnect_task.cancel()
self._reconnect_task = None

if not reconnect or not self.delay_ms:
self.delay_ms = 0
return

self._reconnect_task = asyncio.create_task(self._reconnect())

async def _reconnect(self):
"""Reconnect."""
Log.debug("Waiting {} ms before next connection attempt.", self.delay_ms)
await asyncio.sleep(self.delay_ms / 1000)
self.delay_ms = min(2 * self.delay_ms, self.reconnect_delay_max)

self._reconnect_task = None
if self.on_reconnect_callback:
self.on_reconnect_callback()
return await self.connect()

# ----------------------------------------------------------------------- #
# Internal methods
Expand Down Expand Up @@ -353,7 +343,7 @@ def __exit__(self, klass, value, traceback):

async def __aexit__(self, klass, value, traceback):
"""Implement the client with exit block."""
await self.close()
self.close()

def __str__(self):
"""Build a string representation of the connection.
Expand Down
92 changes: 16 additions & 76 deletions pymodbus/client/serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def run():

await client.connect()
...
await client.close()
client.close()
"""

transport = None
Expand All @@ -68,25 +68,6 @@ def __init__(
self.params.parity = parity
self.params.stopbits = stopbits
self.params.handle_local_echo = handle_local_echo
self.loop = None
self._connected_event = asyncio.Event()
self._reconnect_task = None

async def close(self): # pylint: disable=invalid-overridden-method
"""Stop connection."""

# prevent reconnect:
self.delay_ms = 0
if self.connected:
if self.transport:
self.transport.close()
await self.async_close()
await asyncio.sleep(0.1)

# if there is an unfinished delayed reconnection attempt pending, cancel it
if self._reconnect_task:
self._reconnect_task.cancel()
self._reconnect_task = None

def _create_protocol(self):
"""Create a protocol instance."""
Expand All @@ -95,74 +76,33 @@ def _create_protocol(self):
@property
def connected(self):
"""Connect internal."""
return self._connected_event.is_set()
return self.transport is not None

async def connect(self): # pylint: disable=invalid-overridden-method
"""Connect Async client."""
# get current loop, if there are no loop a RuntimeError will be raised
self.loop = asyncio.get_running_loop()

Log.debug("Starting serial connection")
try:
await create_serial_connection(
self.loop,
self._create_protocol,
self.params.port,
baudrate=self.params.baudrate,
bytesize=self.params.bytesize,
stopbits=self.params.stopbits,
parity=self.params.parity,
await asyncio.wait_for(
create_serial_connection(
self.loop,
self._create_protocol,
self.params.port,
baudrate=self.params.baudrate,
bytesize=self.params.bytesize,
stopbits=self.params.stopbits,
parity=self.params.parity,
timeout=self.params.timeout,
**self.params.kwargs,
),
timeout=self.params.timeout,
**self.params.kwargs,
)
await self._connected_event.wait()
Log.info("Connected to {}", self.params.port)
except Exception as exc: # pylint: disable=broad-except
Log.warning("Failed to connect: {}", exc)
if self.delay_ms > 0:
self._launch_reconnect()
self.close(reconnect=True)
return self.connected

def client_made_connection(self, protocol):
"""Notify successful connection."""
Log.info("Serial connected.")
if not self.connected:
self._connected_event.set()
else:
Log.error("Factory protocol connect callback called while connected.")

def client_lost_connection(self, protocol):
"""Notify lost connection."""
Log.info("Serial lost connection.")
if protocol is not self:
Log.error("Serial: protocol is not self.")

self._connected_event.clear()
if self.delay_ms:
self._launch_reconnect()

def _launch_reconnect(self):
"""Launch delayed reconnection coroutine"""
if self._reconnect_task:
Log.warning(
"Ignoring launch of delayed reconnection, another is in progress"
)
else:
# store the future in a member variable so we know we have a pending reconnection attempt
# also prevents its garbage collection
self._reconnect_task = asyncio.create_task(self._reconnect())

async def _reconnect(self):
"""Reconnect."""
Log.debug("Waiting {} ms before next connection attempt.", self.delay_ms)
await asyncio.sleep(self.delay_ms / 1000)
self.delay_ms = min(2 * self.delay_ms, self.params.reconnect_delay_max)

self._reconnect_task = None
if self.params.on_reconnect_callback:
self.params.on_reconnect_callback()
return await self.connect()


class ModbusSerialClient(ModbusBaseClient):
"""**ModbusSerialClient**.
Expand Down Expand Up @@ -267,7 +207,7 @@ def connect(self):
self.close()
return self.socket is not None

def close(self):
def close(self): # pylint: disable=arguments-differ
"""Close the underlying socket connection."""
if self.socket:
self.socket.close()
Expand Down
Loading