diff --git a/check_ci.sh b/check_ci.sh index 291622910..109b947ba 100755 --- a/check_ci.sh +++ b/check_ci.sh @@ -9,5 +9,5 @@ codespell ruff check --fix --exit-non-zero-on-fix . pylint --recursive=y examples pymodbus test mypy pymodbus -pytest --numprocesses auto +pytest --cov --numprocesses auto echo "Ready to push" diff --git a/pymodbus/message/ascii.py b/pymodbus/message/ascii.py index 400d9ad31..322aaa508 100644 --- a/pymodbus/message/ascii.py +++ b/pymodbus/message/ascii.py @@ -22,10 +22,10 @@ class MessageAscii(MessageBase): def reset(self) -> None: """Clear internal handling.""" - def decode(self, _data: bytes) -> tuple[int, int, bytes]: + def decode(self, _data: bytes) -> tuple[int, int, int, bytes]: """Decode message.""" - return 0, 0, b'' + return 0, 0, 0, b'' - def encode(self, data: bytes, device_id: int, tid: int) -> bytes: + def encode(self, _data: bytes, _device_id: int, _tid: int) -> bytes: """Decode message.""" return b'' diff --git a/pymodbus/message/base.py b/pymodbus/message/base.py index c0b9315a0..a5a386023 100644 --- a/pymodbus/message/base.py +++ b/pymodbus/message/base.py @@ -29,12 +29,13 @@ def reset(self) -> None: """Clear internal handling.""" @abstractmethod - def decode(self, _data: bytes) -> tuple[int, int, bytes]: + def decode(self, _data: bytes) -> tuple[int, int, int, bytes]: """Decode message. return: - used_len (int) + used_len (int) or 0 to read more transaction_id (int) or 0 + device_id (int) or 0 modbus request/response (bytes) """ diff --git a/pymodbus/message/message.py b/pymodbus/message/message.py index 7297ba004..3c971e641 100644 --- a/pymodbus/message/message.py +++ b/pymodbus/message/message.py @@ -74,22 +74,18 @@ def __init__(self, MessageType.TLS: MessageTLS(device_ids, is_server), }[message_type] - def callback_disconnected(self, exc: Exception | None) -> None: - """Call when connection is lost.""" - self.msg_handle.reset() - def callback_data(self, data: bytes, addr: tuple | None = None) -> int: """Handle received data.""" - used_len, tid, modbus = self.msg_handle.decode(data) - if modbus: - self.callback_request_response(modbus, tid) + used_len, tid, device_id, data = self.msg_handle.decode(data) + if data: + self.callback_request_response(data, device_id, tid) return used_len # --------------------- # # callbacks and helpers # # --------------------- # @abstractmethod - def callback_request_response(self, data: bytes, tid: int) -> None: + def callback_request_response(self, data: bytes, device_id: int, tid: int) -> None: """Handle received modbus request/response.""" def build_send(self, data: bytes, device_id: int, tid: int, addr: tuple | None = None) -> None: @@ -101,4 +97,8 @@ def build_send(self, data: bytes, device_id: int, tid: int, addr: tuple | None = :param addr: optional addr, only used for UDP server. """ send_data = self.msg_handle.encode(data, device_id, tid) - super().send(send_data, addr) + self.send(send_data, addr) + + def reset(self) -> None: + """Reset handling.""" + self.msg_handle.reset() diff --git a/pymodbus/message/raw.py b/pymodbus/message/raw.py index dbb951bbe..a198edd89 100644 --- a/pymodbus/message/raw.py +++ b/pymodbus/message/raw.py @@ -9,8 +9,8 @@ class MessageRaw(MessageBase): HEADER: byte[0] = device_id - byte[1-2] = length of request/response, NOT converted - byte[3..] = request/response + byte[1] = transaction_id + byte[2..] = request/response This is mainly for test purposes. """ @@ -18,10 +18,12 @@ class MessageRaw(MessageBase): def reset(self) -> None: """Clear internal handling.""" - def decode(self, _data: bytes) -> tuple[int, int, bytes]: + def decode(self, data: bytes) -> tuple[int, int, int, bytes]: """Decode message.""" - return 0, 0, b'' + if len(data) < 3: + return 0, 0, 0, b'' + return len(data), int(data[0]), int(data[1]), data[2:] def encode(self, data: bytes, device_id: int, tid: int) -> bytes: """Decode message.""" - return b'' + return device_id.to_bytes(1, 'big') + tid.to_bytes(1, 'big') + data diff --git a/pymodbus/message/rtu.py b/pymodbus/message/rtu.py index aec9990b3..4e3cf2188 100644 --- a/pymodbus/message/rtu.py +++ b/pymodbus/message/rtu.py @@ -43,10 +43,10 @@ class MessageRTU(MessageBase): def reset(self) -> None: """Clear internal handling.""" - def decode(self, _data: bytes) -> tuple[int, int, bytes]: + def decode(self, _data: bytes) -> tuple[int, int, int, bytes]: """Decode message.""" - return 0, 0, b'' + return 0, 0, 0, b'' - def encode(self, data: bytes, device_id: int, tid: int) -> bytes: + def encode(self, _data: bytes, _device_id: int, _tid: int) -> bytes: """Decode message.""" return b'' diff --git a/pymodbus/message/socket.py b/pymodbus/message/socket.py index 07ad7a198..c057e5622 100644 --- a/pymodbus/message/socket.py +++ b/pymodbus/message/socket.py @@ -22,10 +22,10 @@ class MessageSocket(MessageBase): def reset(self) -> None: """Clear internal handling.""" - def decode(self, _data: bytes) -> tuple[int, int, bytes]: + def decode(self, _data: bytes) -> tuple[int, int, int, bytes]: """Decode message.""" - return 0, 0, b'' + return 0, 0, 0, b'' - def encode(self, data: bytes, device_id: int, tid: int) -> bytes: + def encode(self, _data: bytes, _device_id: int, _tid: int) -> bytes: """Decode message.""" return b'' diff --git a/pymodbus/message/tls.py b/pymodbus/message/tls.py index 28f958d8f..a299b1be6 100644 --- a/pymodbus/message/tls.py +++ b/pymodbus/message/tls.py @@ -19,10 +19,10 @@ class MessageTLS(MessageBase): def reset(self) -> None: """Clear internal handling.""" - def decode(self, _data: bytes) -> tuple[int, int, bytes]: + def decode(self, _data: bytes) -> tuple[int, int, int, bytes]: """Decode message.""" - return 0, 0, b'' + return 0, 0, 0, b'' - def encode(self, data: bytes, device_id: int, tid: int) -> bytes: + def encode(self, _data: bytes, _device_id: int, _tid: int) -> bytes: """Decode message.""" return b'' diff --git a/test/message/conftest.py b/test/message/conftest.py index 570acbe21..cd5345d01 100644 --- a/test/message/conftest.py +++ b/test/message/conftest.py @@ -1,28 +1,38 @@ """Configure pytest.""" from __future__ import annotations +from unittest import mock + import pytest -from pymodbus.logging import Log -from pymodbus.message import Message -from pymodbus.transport import ModbusProtocol +from pymodbus.message import Message, MessageType +from pymodbus.transport import CommParams, ModbusProtocol class DummyMessage(Message): """Implement use of ModbusProtocol.""" + def __init__(self, + message_type: MessageType, + params: CommParams, + is_server: bool, + device_ids: list[int] | None, + ): + """Initialize a message instance.""" + super().__init__(message_type, params, is_server, device_ids) + self.send = mock.Mock() + def callback_new_connection(self) -> ModbusProtocol: """Call when listener receive new connection request.""" - return DummyMessage(self.message_type, self.comm_params, self.is_server, self.device_ids) + return DummyMessage(self.message_type, self.comm_params, self.is_server, self.device_ids) # pragma: no cover def callback_connected(self) -> None: """Call when connection is succcesfull.""" def callback_disconnected(self, exc: Exception | None) -> None: """Call when connection is lost.""" - Log.debug("callback_disconnected called: {}", exc) - def callback_request_response(self, data: bytes, tid: int) -> None: + def callback_request_response(self, data: bytes, device_id: int, tid: int) -> None: """Handle received modbus request/response.""" diff --git a/test/message/test_message.py b/test/message/test_message.py index a7659c66f..638147e9c 100644 --- a/test/message/test_message.py +++ b/test/message/test_message.py @@ -1,19 +1,98 @@ """Test transport.""" +from unittest import mock + import pytest from pymodbus.message import MessageType from pymodbus.transport import CommParams -class TestMessage: # pylint: disable=too-few-public-methods +class TestMessage: """Test message module.""" @pytest.mark.parametrize(("entry"), list(MessageType)) - async def test_message_type(self, entry, dummy_message): + async def test_message_init(self, entry, dummy_message): + """Test message type.""" + msg = dummy_message(entry.value, + CommParams(), + False, + [1], + ) + assert msg.msg_handle + + async def test_message_callback_data(self, dummy_message): + """Test message type.""" + msg = dummy_message(MessageType.RAW, + CommParams(), + False, + [1], + ) + msg.msg_handle.decode = mock.MagicMock(return_value=(5,0,0,b'')) + assert msg.callback_data(b'') == 5 + + async def test_message_callback_data_decode(self, dummy_message): """Test message type.""" - dummy_message(entry.value, + msg = dummy_message(MessageType.RAW, + CommParams(), + False, + [1], + ) + msg.msg_handle.decode = mock.MagicMock(return_value=(17,0,1,b'decode')) + assert msg.callback_data(b'') == 17 + + async def test_message_build_send(self, dummy_message): + """Test message type.""" + msg = dummy_message(MessageType.RAW, + CommParams(), + False, + [1], + ) + msg.msg_handle.encode = mock.MagicMock(return_value=(b'decode')) + msg.build_send(b'decode', 1, 0) + msg.msg_handle.encode.assert_called_once() + msg.send.assert_called_once() + + async def test_message_reset(self, dummy_message): + """Test message type.""" + msg = dummy_message(MessageType.RAW, + CommParams(), + False, + [1], + ) + msg.msg_handle.reset = mock.Mock() + msg.reset() + + @pytest.mark.parametrize( + ("msg_type", "data", "res_len", "res_id", "res_tid", "res_data"), [ + (MessageType.RAW, b'\x00\x01', 0, 0, 0, b''), + (MessageType.RAW, b'\x01\x02\x03', 3, 1, 2, b'\x03'), + (MessageType.RAW, b'\x04\x05\x06\x07\x08\x09\x00\x01\x02\x03', 10, 4, 5, b'\x06\x07\x08\x09\x00\x01\x02\x03'), + ]) + async def test_decode(self, dummy_message, msg_type, data, res_id, res_tid, res_len, res_data): + """Test decode method in all types.""" + msg = dummy_message(msg_type, + CommParams(), + False, + [1], + ) + t_len, t_id, t_tid, t_data = msg.msg_handle.decode(data) + assert res_len == t_len + assert res_id == t_id + assert res_tid == t_tid + assert res_data == t_data + + @pytest.mark.parametrize( + ("msg_type", "data", "dev_id", "tid", "res_data"), [ + (MessageType.RAW, b'\x01\x02', 5, 6, b'\x05\x06\x01\x02'), + (MessageType.RAW, b'\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09', 17, 25, b'\x11\x19\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09'), + ]) + async def test_encode(self, dummy_message, msg_type, data, dev_id, tid, res_data): + """Test decode method in all types.""" + msg = dummy_message(msg_type, CommParams(), False, [1], ) + t_data = msg.msg_handle.encode(data, dev_id, tid) + assert res_data == t_data