From 2c4a52408ae2f2ccbaad75cd1a5ff8afa0febe8d Mon Sep 17 00:00:00 2001 From: David Yu Yang Date: Wed, 6 Oct 2021 13:13:55 +0800 Subject: [PATCH 1/2] Type hinting automaton.py Used ansmachine.py as initial reference, grepped into some other files along the way. Some notes taken while adding the hints: add_breakpoints returns None Unsure of object type that could have fileno() on 609 645 assumed from rd and wr from the above class --- .config/mypy/mypy_enabled.txt | 1 + scapy/automaton.py | 424 +++++++++++++++++++++++++--------- scapy/compat.py | 6 + 3 files changed, 318 insertions(+), 113 deletions(-) diff --git a/.config/mypy/mypy_enabled.txt b/.config/mypy/mypy_enabled.txt index 107dd82d797..291e0962f1a 100644 --- a/.config/mypy/mypy_enabled.txt +++ b/.config/mypy/mypy_enabled.txt @@ -18,6 +18,7 @@ scapy/asn1/ber.py scapy/asn1/mib.py scapy/asn1fields.py scapy/asn1packet.py +scapy/automaton.py scapy/base_classes.py scapy/compat.py scapy/config.py diff --git a/scapy/automaton.py b/scapy/automaton.py index 91465dae426..35563fbc54f 100644 --- a/scapy/automaton.py +++ b/scapy/automaton.py @@ -24,6 +24,7 @@ import select from collections import deque +from typing import Dict from scapy.config import conf from scapy.utils import do_graph @@ -31,11 +32,28 @@ from scapy.plist import PacketList from scapy.data import MTU from scapy.supersocket import SuperSocket +from scapy.packet import Packet from scapy.consts import WINDOWS import scapy.modules.six as six +from scapy.compat import ( + Union, + List, + Optional, + Any, + Type, + Callable, + Protocol, + Iterator, + Tuple, + Set, + TypeVar, + Deque, +) + def select_objects(inputs, remain): + # type: (List[Any], Union[float, int, None]) -> List[Any] """ Select objects. Same than: ``select.select(inputs, [], [], remain)`` @@ -58,6 +76,8 @@ def select_objects(inputs, remain): :param inputs: objects to process :param remain: timeout. If 0, return []. """ + if not remain: + return [] if not WINDOWS: return select.select(inputs, [], [], remain)[0] natives = [] @@ -75,23 +95,25 @@ def select_objects(inputs, remain): if events: remainms = int((remain or 0) * 1000) if len(events) == 1: - res = ctypes.windll.kernel32.WaitForSingleObject( - ctypes.c_void_p(events[0].fileno()), - remainms - ) + if sys.platform == "win32": + res = ctypes.windll.kernel32.WaitForSingleObject( + ctypes.c_void_p(events[0].fileno()), + remainms + ) else: # Sadly, the only way to emulate select() is to first check # if any object is available using WaitForMultipleObjects # then poll the others. - res = ctypes.windll.kernel32.WaitForMultipleObjects( - len(events), - (ctypes.c_void_p * len(events))( - *[x.fileno() for x in events] - ), - False, - remainms - ) - if res != 0xFFFFFFFF and res != 0x00000102: # Failed or Timeout + if sys.platform == "win32": + res = ctypes.windll.kernel32.WaitForMultipleObjects( + len(events), + (ctypes.c_void_p * len(events))( + *[x.fileno() for x in events] + ), + False, + remainms + ) + if sys.platform == "win32" and res != 0xFFFFFFFF and res != 0x00000102: # Failed or Timeout # noqa: E501 results.add(events[res]) if len(events) > 1: # Now poll the others, if any @@ -107,69 +129,90 @@ def select_objects(inputs, remain): class ObjectPipe: def __init__(self, name=None): + # type: (Optional[str]) -> None self.name = name or "ObjectPipe" self._closed = False self.__rd, self.__wr = os.pipe() - self.__queue = deque() + self.__queue = deque() # type: Deque[Union[str, bytes, Message]] if WINDOWS: self._wincreate() def _wincreate(self): - self._fd = ctypes.windll.kernel32.CreateEventA( - None, True, False, - ctypes.create_string_buffer(b"ObjectPipe %f" % random.random()) - ) + # type: () -> None + if sys.platform == "win32": + self._fd = ctypes.windll.kernel32.CreateEventA( + None, True, False, + ctypes.create_string_buffer(b"ObjectPipe %f" % random.random()) + ) def _winset(self): - if ctypes.windll.kernel32.SetEvent( - ctypes.c_void_p(self._fd)) == 0: - warning(ctypes.FormatError()) + # type: () -> None + if sys.platform == "win32": + if ctypes.windll.kernel32.SetEvent( + ctypes.c_void_p(self._fd)) == 0: + warning(ctypes.FormatError()) def _winreset(self): - if ctypes.windll.kernel32.ResetEvent( - ctypes.c_void_p(self._fd)) == 0: - warning(ctypes.FormatError()) + # type: () -> None + if sys.platform == "win32": + if ctypes.windll.kernel32.ResetEvent( + ctypes.c_void_p(self._fd)) == 0: + warning(ctypes.FormatError()) def _winclose(self): - if self._fd and ctypes.windll.kernel32.CloseHandle( - ctypes.c_void_p(self._fd)) == 0: - warning(ctypes.FormatError()) - self._fd = None + # type: () -> None + if sys.platform == "win32": + if self._fd and ctypes.windll.kernel32.CloseHandle( + ctypes.c_void_p(self._fd)) == 0: + warning(ctypes.FormatError()) + self._fd = None def fileno(self): + # type: () -> int if WINDOWS: - return self._fd - else: - return self.__rd + if sys.platform == "win32": + return self._fd + return self.__rd def send(self, obj): + # type: (Union[str, bytes, Message]) -> int self.__queue.append(obj) if WINDOWS: self._winset() os.write(self.__wr, b"X") + return 0 def write(self, obj): + # type: (str) -> None self.send(obj) def empty(self): + # type: () -> bool return not bool(self.__queue) def flush(self): + # type: () -> None pass def recv(self, n=0): + # type: (Optional[int]) -> Optional[Message] if self._closed: return None os.read(self.__rd, 1) elt = self.__queue.popleft() if WINDOWS and not self.__queue: self._winreset() - return elt + if isinstance(elt, Message): + return elt + else: + return Message(elt) def read(self, n=0): + # type: (Optional[int]) -> Any return self.recv(n) def close(self): + # type: () -> None if not self._closed: self._closed = True os.close(self.__rd) @@ -179,14 +222,16 @@ def close(self): self._winclose() def __repr__(self): + # type: () -> str return "<%s at %s>" % (self.name, id(self)) def __del__(self): + # type: () -> None self.close() @staticmethod def select(sockets, remain=conf.recv_poll_rate): - # Only handle ObjectPipes + # type: (List[SuperSocket], float) -> List[SuperSocket] results = [] for s in sockets: if s.closed: @@ -196,11 +241,19 @@ def select(sockets, remain=conf.recv_poll_rate): return select_objects(sockets, remain) -class Message: - def __init__(self, **args): - self.__dict__.update(args) +class Message(str): + type = None # type: str + pkt = None # type: Packet + result = None # type: str + state = None # type: Message + exc_info = None # type: Union[Tuple[None, None, None], Tuple[BaseException, Exception, types.TracebackType]] # noqa: E501 + + def __new__(cls, *args): + # type: (Any) -> Any + return str.__new__(cls, *args) def __repr__(self): + # type: () -> str return "" % " ".join("%s=%r" % (k, v) for (k, v) in six.iteritems(self.__dict__) # noqa: E501 if not k.startswith("_")) @@ -208,26 +261,33 @@ def __repr__(self): class _instance_state: def __init__(self, instance): + # type: (Any) -> None self.__self__ = instance.__self__ self.__func__ = instance.__func__ self.__self__.__class__ = instance.__self__.__class__ def __getattr__(self, attr): + # type: (str) -> Any return getattr(self.__func__, attr) def __call__(self, *args, **kargs): + # type: (Any, Any) -> Any return self.__func__(self.__self__, *args, **kargs) def breaks(self): + # type: () -> Any return self.__self__.add_breakpoints(self.__func__) def intercepts(self): + # type: () -> Any return self.__self__.add_interception_points(self.__func__) def unbreaks(self): + # type: () -> Any return self.__self__.remove_breakpoints(self.__func__) def unintercepts(self): + # type: () -> Any return self.__self__.remove_interception_points(self.__func__) @@ -235,6 +295,29 @@ def unintercepts(self): # Automata # ############## +class StateWrapper(Protocol[TypeVar("F", bound=Callable[..., object])]): + __name__ = None # type: str + atmt_type = None # type: str + atmt_state = None # type: str + atmt_initial = None # type: int + atmt_final = None # type: int + atmt_stop = None # type: int + atmt_error = None # type: int + atmt_origfunc = None # type: StateWrapper + atmt_prio = None # type: int + atmt_as_supersocket = None # type: str + atmt_condname = None # type: str + atmt_ioname = None # type: str + atmt_timeout = None # type: int + atmt_cond = None # type: Dict[str, int] + __call__ = None # type: TypeVar("F", bound=Callable[..., object]) # noqa: E501 + + +def state_wrapper_decorator(f): + # type: (Callable) -> StateWrapper + return f + + class ATMT: STATE = "State" ACTION = "Action" @@ -245,6 +328,7 @@ class ATMT: class NewStateRequested(Exception): def __init__(self, state_func, automaton, *args, **kargs): + # type: (Any, ATMT, Any, Any) -> None self.func = state_func self.state = state_func.atmt_state self.initial = state_func.atmt_initial @@ -258,19 +342,28 @@ def __init__(self, state_func, automaton, *args, **kargs): self.action_parameters() # init action parameters def action_parameters(self, *args, **kargs): + # type: (Any, Any) -> NewStateRequested self.action_args = args self.action_kargs = kargs return self def run(self): + # type: () -> Any return self.func(self.automaton, *self.args, **self.kargs) def __repr__(self): + # type: () -> str return "NewStateRequested(%s)" % self.state - @staticmethod - def state(initial=0, final=0, stop=0, error=0): + def state(self, + initial=0, # type: int + final=0, # type: int + stop=0, # type: int + error=0 # type: int + ): + # type: (...) -> Callable[[StateWrapper, int, int], Callable[[Any, Any], NewStateRequested]] # noqa: E501 def deco(f, initial=initial, final=final): + # type: (StateWrapper, int, int) -> Callable[[Any, Any], NewStateRequested] # noqa: E501 f.atmt_type = ATMT.STATE f.atmt_state = f.__name__ f.atmt_initial = initial @@ -278,7 +371,9 @@ def deco(f, initial=initial, final=final): f.atmt_stop = stop f.atmt_error = error + @state_wrapper_decorator def state_wrapper(self, *args, **kargs): + # type: (ATMT, Any, Any) -> NewStateRequested return ATMT.NewStateRequested(f, self, *args, **kargs) state_wrapper.__name__ = "%s_wrapper" % f.__name__ @@ -294,7 +389,9 @@ def state_wrapper(self, *args, **kargs): @staticmethod def action(cond, prio=0): + # type: (Any, int) -> Callable[[StateWrapper, StateWrapper], StateWrapper] # noqa: E501 def deco(f, cond=cond): + # type: (StateWrapper, StateWrapper) -> StateWrapper if not hasattr(f, "atmt_type"): f.atmt_cond = {} f.atmt_type = ATMT.ACTION @@ -304,7 +401,9 @@ def deco(f, cond=cond): @staticmethod def condition(state, prio=0): + # type: (Any, int) -> Callable[[StateWrapper, StateWrapper], StateWrapper] # noqa: E501 def deco(f, state=state): + # type: (StateWrapper, StateWrapper) -> Any f.atmt_type = ATMT.CONDITION f.atmt_state = state.atmt_state f.atmt_condname = f.__name__ @@ -314,7 +413,9 @@ def deco(f, state=state): @staticmethod def receive_condition(state, prio=0): + # type: (StateWrapper, int) -> Callable[[StateWrapper, StateWrapper], StateWrapper] # noqa: E501 def deco(f, state=state): + # type: (StateWrapper, StateWrapper) -> StateWrapper f.atmt_type = ATMT.RECV f.atmt_state = state.atmt_state f.atmt_condname = f.__name__ @@ -323,20 +424,28 @@ def deco(f, state=state): return deco @staticmethod - def ioevent(state, name, prio=0, as_supersocket=None): + def ioevent(state, # type: StateWrapper + name, # type: str + prio=0, # type: int + as_supersocket=None # type: Optional[str] + ): + # type: (...) -> Callable[[StateWrapper, StateWrapper], StateWrapper] def deco(f, state=state): + # type: (StateWrapper, StateWrapper) -> StateWrapper f.atmt_type = ATMT.IOEVENT f.atmt_state = state.atmt_state f.atmt_condname = f.__name__ f.atmt_ioname = name f.atmt_prio = prio - f.atmt_as_supersocket = as_supersocket + f.atmt_as_supersocket = as_supersocket if as_supersocket else "" return f return deco @staticmethod def timeout(state, timeout): + # type: (StateWrapper, int) -> Callable[[StateWrapper, StateWrapper, int], StateWrapper] # noqa: E501 def deco(f, state=state, timeout=timeout): + # type: (StateWrapper, StateWrapper, int) -> StateWrapper f.atmt_type = ATMT.TIMEOUT f.atmt_state = state.atmt_state f.atmt_timeout = timeout @@ -362,7 +471,15 @@ class _ATMT_Command: class _ATMT_supersocket(SuperSocket): - def __init__(self, name, ioevent, automaton, proto, *args, **kargs): + def __init__(self, + name, # type: str + ioevent, # type: str + automaton, # type: Type[Automaton] + proto, # type: Callable[[Message], Any] + *args, # type: Any + **kargs # type: Any + ): + # type: (...) -> None self.name = name self.ioevent = ioevent self.proto = proto @@ -374,20 +491,24 @@ def __init__(self, name, ioevent, automaton, proto, *args, **kargs): self.atmt.runbg() def send(self, s): + # type: (bytes) -> int if not isinstance(s, bytes): s = bytes(s) - self.spa.send(s) + return self.spa.send(s) def fileno(self): + # type: () -> int return self.spb.fileno() def recv(self, n=MTU): + # type: (Optional[int]) -> Any r = self.spb.recv(n) - if self.proto is not None: + if self.proto is not None and r is not None: r = self.proto(r) return r def close(self): + # type: () -> None if not self.closed: self.atmt.stop() self.spa.close() @@ -396,16 +517,19 @@ def close(self): @staticmethod def select(sockets, remain=conf.recv_poll_rate): + # type: (List[SuperSocket], Optional[float]) -> List[SuperSocket] return select_objects(sockets, remain) class _ATMT_to_supersocket: def __init__(self, name, ioevent, automaton): + # type: (str, str, Type[Automaton]) -> None self.name = name self.ioevent = ioevent self.automaton = automaton def __call__(self, proto, *args, **kargs): + # type: (Callable[[Message], Any], Any, Any) -> _ATMT_supersocket return _ATMT_supersocket( self.name, self.ioevent, self.automaton, proto, *args, **kargs @@ -414,16 +538,17 @@ def __call__(self, proto, *args, **kargs): class Automaton_metaclass(type): def __new__(cls, name, bases, dct): + # type: (str, Tuple[Any], Dict[str, Any]) -> Automaton_metaclass cls = super(Automaton_metaclass, cls).__new__(cls, name, bases, dct) cls.states = {} - cls.state = None - cls.recv_conditions = {} - cls.conditions = {} - cls.ioevents = {} - cls.timeout = {} - cls.actions = {} - cls.initial_states = [] - cls.stop_states = [] + cls.state = StateWrapper() # type: StateWrapper + cls.recv_conditions = {} # type: Dict[str, List[StateWrapper]] + cls.conditions = {} # type: Dict[str, List[StateWrapper]] + cls.ioevents = {} # type: Dict[str, List[StateWrapper]] + cls.timeout = {} # type: Dict[str, List[Tuple[int, StateWrapper]]] # noqa: E501 + cls.actions = {} # type: Dict[str, List[StateWrapper]] + cls.initial_states = [] # type: List[StateWrapper] + cls.stop_states = [] # type: List[StateWrapper] cls.ionames = [] cls.iosupersockets = [] @@ -437,7 +562,7 @@ def __new__(cls, name, bases, dct): members[k] = v decorated = [v for v in six.itervalues(members) - if isinstance(v, types.FunctionType) and hasattr(v, "atmt_type")] # noqa: E501 + if isinstance(v, StateWrapper) and hasattr(v, "atmt_type")] # noqa: E501 for m in decorated: if m.atmt_type == ATMT.STATE: @@ -467,8 +592,8 @@ def __new__(cls, name, bases, dct): elif m.atmt_type == ATMT.TIMEOUT: cls.timeout[m.atmt_state].append((m.atmt_timeout, m)) elif m.atmt_type == ATMT.ACTION: - for c in m.atmt_cond: - cls.actions[c].append(m) + for co in m.atmt_cond: + cls.actions[co].append(m) for v in six.itervalues(cls.timeout): v.sort(key=lambda x: x[0]) @@ -486,6 +611,7 @@ def __new__(cls, name, bases, dct): return cls def build_graph(self): + # type: () -> str s = 'digraph "%s" {\n' % self.__class__.__name__ se = "" # Keep initial nodes at the beginning for better rendering @@ -529,12 +655,64 @@ def build_graph(self): return s def graph(self, **kargs): + # type: (Any) -> Optional[str] s = self.build_graph() return do_graph(s, **kargs) class Automaton(six.with_metaclass(Automaton_metaclass)): + # Internals + def __init__(self, *args, **kargs): + # type: (Any, Any) -> None + external_fd = kargs.pop("external_fd", {}) + self.send_sock_class = kargs.pop("ll", conf.L3socket) + self.recv_sock_class = kargs.pop("recvsock", conf.L2listen) + self.is_atmt_socket = kargs.pop("is_atmt_socket", False) + self.started = threading.Lock() + self.threadid = None # type: Optional[int] + self.breakpointed = None + self.breakpoints = set() # type: Set[StateWrapper] + self.interception_points = set() # type: Set[StateWrapper] + self.intercepted_packet = None # type: Union[None, Packet] + self.debug_level = 0 + self.init_args = args + self.init_kargs = kargs + self.io = type.__new__(type, "IOnamespace", (), {}) + self.oi = type.__new__(type, "IOnamespace", (), {}) + self.cmdin = ObjectPipe("cmdin") + self.cmdout = ObjectPipe("cmdout") + self.ioin = {} + self.ioout = {} + self.packets = PacketList() # type: PacketList + for n in self.ionames: + extfd = external_fd.get(n) + if not isinstance(extfd, tuple): + extfd = (extfd, extfd) + ioin, ioout = extfd + if ioin is None: + ioin = ObjectPipe("ioin") + else: + ioin = self._IO_fdwrapper(ioin, None) + if ioout is None: + ioout = ObjectPipe("ioout") + else: + ioout = self._IO_fdwrapper(None, ioout) + + self.ioin[n] = ioin + self.ioout[n] = ioout + ioin.ioname = n + ioout.ioname = n + setattr(self.io, n, self._IO_mixer(ioout, ioin)) + setattr(self.oi, n, self._IO_mixer(ioin, ioout)) + + for stname in self.states: + setattr(self, stname, + _instance_state(getattr(self, stname))) + + self.start() + def parse_args(self, debug=0, store=1, **kargs): + # type: (int, int, Any) -> None self.debug_level = debug if debug: conf.logLevel = logging.DEBUG @@ -542,14 +720,17 @@ def parse_args(self, debug=0, store=1, **kargs): self.store_packets = store def master_filter(self, pkt): + # type: (Packet) -> bool return True def my_send(self, pkt): + # type: (Packet) -> None self.send_sock.send(pkt) # Utility classes and exceptions class _IO_fdwrapper: def __init__(self, rd, wr): + # type: (Union[int, ObjectPipe, None], Union[int, ObjectPipe, None]) -> None # noqa: E501 if rd is not None and not isinstance(rd, (int, ObjectPipe)): rd = rd.fileno() if wr is not None and not isinstance(wr, (int, ObjectPipe)): @@ -558,50 +739,74 @@ def __init__(self, rd, wr): self.wr = wr def fileno(self): + # type: () -> int if isinstance(self.rd, ObjectPipe): return self.rd.fileno() - return self.rd + elif isinstance(self.rd, int): + return self.rd + return 0 def read(self, n=65535): + # type: (int) -> Union[bytes, Message, None] if isinstance(self.rd, ObjectPipe): return self.rd.recv(n) - return os.read(self.rd, n) + elif isinstance(self.rd, int): + return os.read(self.rd, n) + return b'' def write(self, msg): + # type: (bytes) -> int if isinstance(self.wr, ObjectPipe): return self.wr.send(msg) - return os.write(self.wr, msg) + elif isinstance(self.wr, int): + return os.write(self.wr, msg) + return 0 def recv(self, n=65535): + # type: (int) -> Union[bytes, Message, None] return self.read(n) def send(self, msg): + # type: (bytes) -> int return self.write(msg) class _IO_mixer: def __init__(self, rd, wr): + # type: (Union[int, ObjectPipe], Union[int, ObjectPipe]) -> None self.rd = rd self.wr = wr def fileno(self): + # type: () -> Any if isinstance(self.rd, ObjectPipe): return self.rd.fileno() return self.rd def recv(self, n=None): - return self.rd.recv(n) + # type: (Optional[int]) -> Any + if isinstance(self.rd, ObjectPipe): + return self.rd.recv(n) + else: + return None def read(self, n=None): + # type: (Optional[int]) -> Any return self.recv(n) def send(self, msg): - return self.wr.send(msg) + # type: (str) -> int + if isinstance(self.wr, ObjectPipe): + return self.wr.send(msg) + else: + return 0 def write(self, msg): + # type: (str) -> int return self.send(msg) class AutomatonException(Exception): def __init__(self, msg, state=None, result=None): + # type: (str, Optional[Message], Optional[str]) -> None Exception.__init__(self, msg) self.state = state self.result = result @@ -626,6 +831,7 @@ class Singlestep(AutomatonStopped): class InterceptionPoint(AutomatonStopped): def __init__(self, msg, state=None, result=None, packet=None): + # type: (str, Optional[Message], Optional[str], Optional[str]) -> None # noqa: E501 Automaton.AutomatonStopped.__init__(self, msg, state=state, result=result) # noqa: E501 self.packet = packet @@ -634,16 +840,20 @@ class CommandMessage(AutomatonException): # Services def debug(self, lvl, msg): + # type: (int, str) -> None if self.debug_level >= lvl: log_runtime.debug(msg) def send(self, pkt): + # type: (Packet) -> None if self.state.state in self.interception_points: self.debug(3, "INTERCEPT: packet intercepted: %s" % pkt.summary()) self.intercepted_packet = pkt cmd = Message(type=_ATMT_Command.INTERCEPT, state=self.state, pkt=pkt) # noqa: E501 self.cmdout.send(cmd) - cmd = self.cmdin.recv() + temp = self.cmdin.recv() + if temp: + cmd = temp self.intercepted_packet = None if cmd.type == _ATMT_Command.REJECT: self.debug(3, "INTERCEPT: packet rejected") @@ -661,61 +871,16 @@ def send(self, pkt): if self.store_packets: self.packets.append(pkt.copy()) - # Internals - def __init__(self, *args, **kargs): - external_fd = kargs.pop("external_fd", {}) - self.send_sock_class = kargs.pop("ll", conf.L3socket) - self.recv_sock_class = kargs.pop("recvsock", conf.L2listen) - self.is_atmt_socket = kargs.pop("is_atmt_socket", False) - self.started = threading.Lock() - self.threadid = None - self.breakpointed = None - self.breakpoints = set() - self.interception_points = set() - self.intercepted_packet = None - self.debug_level = 0 - self.init_args = args - self.init_kargs = kargs - self.io = type.__new__(type, "IOnamespace", (), {}) - self.oi = type.__new__(type, "IOnamespace", (), {}) - self.cmdin = ObjectPipe("cmdin") - self.cmdout = ObjectPipe("cmdout") - self.ioin = {} - self.ioout = {} - for n in self.ionames: - extfd = external_fd.get(n) - if not isinstance(extfd, tuple): - extfd = (extfd, extfd) - ioin, ioout = extfd - if ioin is None: - ioin = ObjectPipe("ioin") - else: - ioin = self._IO_fdwrapper(ioin, None) - if ioout is None: - ioout = ObjectPipe("ioout") - else: - ioout = self._IO_fdwrapper(None, ioout) - - self.ioin[n] = ioin - self.ioout[n] = ioout - ioin.ioname = n - ioout.ioname = n - setattr(self.io, n, self._IO_mixer(ioout, ioin)) - setattr(self.oi, n, self._IO_mixer(ioin, ioout)) - - for stname in self.states: - setattr(self, stname, - _instance_state(getattr(self, stname))) - - self.start() - def __iter__(self): + # type: () -> Automaton return self def __del__(self): + # type: () -> None self.stop() def _run_condition(self, cond, *args, **kargs): + # type: (StateWrapper, Any, Any) -> None try: self.debug(5, "Trying %s [%s]" % (cond.atmt_type, cond.atmt_condname)) # noqa: E501 cond(self, *args, **kargs) @@ -735,6 +900,7 @@ def _run_condition(self, cond, *args, **kargs): self.debug(2, "%s [%s] not taken" % (cond.atmt_type, cond.atmt_condname)) # noqa: E501 def _do_start(self, *args, **kargs): + # type: (Any, Any) -> None ready = threading.Event() _t = threading.Thread( target=self._do_control, @@ -747,8 +913,11 @@ def _do_start(self, *args, **kargs): ready.wait() def _do_control(self, ready, *args, **kargs): + # type: (threading.Event, Any, Any) -> None with self.started: self.threadid = threading.current_thread().ident + if self.threadid is None: + self.threadid = 0 # Update default parameters a = args + self.init_args[len(args):] @@ -770,6 +939,8 @@ def _do_control(self, ready, *args, **kargs): try: while True: c = self.cmdin.recv() + if c is None: + return None self.debug(5, "Received command %s" % c.type) if c.type == _ATMT_Command.RUN: singlestep = False @@ -812,6 +983,7 @@ def _do_control(self, ready, *args, **kargs): self.threadid = None def _do_iter(self): + # type: () -> Iterator[Union[Automaton.AutomatonException, Automaton.AutomatonStopped, NewStateRequested, None]] # noqa: E501 while True: try: self.debug(1, "## state=[%s]" % self.state.state) @@ -867,7 +1039,7 @@ def _do_iter(self): self._run_condition(timeout_func, *state_output) next_timeout, timeout_func = next(expirations) if next_timeout is None: - remain = None + remain = 0 else: remain = next_timeout - t @@ -899,6 +1071,7 @@ def _do_iter(self): yield state_req def __repr__(self): + # type: () -> str return "" % ( self.__class__.__name__, ["HALTED", "RUNNING"][self.started.locked()] @@ -906,43 +1079,54 @@ def __repr__(self): # Public API def add_interception_points(self, *ipts): + # type: (Any) -> None for ipt in ipts: if hasattr(ipt, "atmt_state"): ipt = ipt.atmt_state self.interception_points.add(ipt) def remove_interception_points(self, *ipts): + # type: (Any) -> None for ipt in ipts: if hasattr(ipt, "atmt_state"): ipt = ipt.atmt_state self.interception_points.discard(ipt) def add_breakpoints(self, *bps): + # type: (Any) -> None for bp in bps: if hasattr(bp, "atmt_state"): bp = bp.atmt_state self.breakpoints.add(bp) def remove_breakpoints(self, *bps): + # type: (Any) -> None for bp in bps: if hasattr(bp, "atmt_state"): bp = bp.atmt_state self.breakpoints.discard(bp) def start(self, *args, **kargs): + # type: (Any, Any) -> None if not self.started.locked(): self._do_start(*args, **kargs) - def run(self, resume=None, wait=True): + def run(self, + resume=None, # type: Optional[Message] + wait=True # type: Optional[bool] + ): + # type: (...) -> Union[Automaton.AutomatonException, Automaton.AutomatonStopped, str, None] # noqa: E501 if resume is None: resume = Message(type=_ATMT_Command.RUN) self.cmdin.send(resume) if wait: try: c = self.cmdout.recv() + if c is None: + return None except KeyboardInterrupt: self.cmdin.send(Message(type=_ATMT_Command.FREEZE)) - return + return None if c.type == _ATMT_Command.END: return c.result elif c.type == _ATMT_Command.INTERCEPT: @@ -953,15 +1137,19 @@ def run(self, resume=None, wait=True): raise self.Breakpoint("breakpoint triggered on state [%s]" % c.state.state, state=c.state.state) # noqa: E501 elif c.type == _ATMT_Command.EXCEPTION: six.reraise(c.exc_info[0], c.exc_info[1], c.exc_info[2]) + return None def runbg(self, resume=None, wait=False): + # type: (Optional[Message], Optional[bool]) -> None self.run(resume, wait) def next(self): + # type: () -> Union[Automaton.AutomatonException, Automaton.AutomatonStopped, str, None] # noqa: E501 return self.run(resume=Message(type=_ATMT_Command.NEXT)) __next__ = next def _flush_inout(self): + # type: () -> None with self.started: # Flush command pipes while True: @@ -972,18 +1160,25 @@ def _flush_inout(self): fd.recv() def stop(self): + # type: () -> None self.cmdin.send(Message(type=_ATMT_Command.STOP)) self._flush_inout() def forcestop(self): + # type: () -> None self.cmdin.send(Message(type=_ATMT_Command.FORCESTOP)) self._flush_inout() def restart(self, *args, **kargs): + # type: (Any, Any) -> None self.stop() self.start(*args, **kargs) - def accept_packet(self, pkt=None, wait=False): + def accept_packet(self, + pkt=None, # type: Optional[Packet] + wait=False # type: Optional[bool] + ): + # type: (...) -> Union[Automaton.AutomatonException, Automaton.AutomatonStopped, str, None] # noqa: E501 rsm = Message() if pkt is None: rsm.type = _ATMT_Command.ACCEPT @@ -992,6 +1187,9 @@ def accept_packet(self, pkt=None, wait=False): rsm.pkt = pkt return self.run(resume=rsm, wait=wait) - def reject_packet(self, wait=False): + def reject_packet(self, + wait=False # type: Optional[bool] + ): + # type: (...) -> Union[Automaton.AutomatonException, Automaton.AutomatonStopped, str, None] # noqa: E501 rsm = Message(type=_ATMT_Command.REJECT) return self.run(resume=rsm, wait=wait) diff --git a/scapy/compat.py b/scapy/compat.py index 5b16aaa6849..a9da28fa8b6 100644 --- a/scapy/compat.py +++ b/scapy/compat.py @@ -25,6 +25,7 @@ 'AnyStr', 'Callable', 'DefaultDict', + 'Deque', 'Dict', 'Generic', 'IO', @@ -38,6 +39,7 @@ 'NoReturn', 'Optional', 'Pattern', + 'Protocol', 'Sequence', 'Set', 'Sized', @@ -122,6 +124,7 @@ def __repr__(self): AnyStr, Callable, DefaultDict, + Deque, Dict, Generic, IO, @@ -132,6 +135,7 @@ def __repr__(self): NoReturn, Optional, Pattern, + Protocol, Sequence, Set, Sized, @@ -152,6 +156,7 @@ def cast(_type, obj): # type: ignore Callable = _FakeType("Callable") DefaultDict = _FakeType("DefaultDict", # type: ignore collections.defaultdict) + Deque = _FakeType("Deque") # type: ignore Dict = _FakeType("Dict", dict) # type: ignore Generic = _FakeType("Generic") IO = _FakeType("IO") # type: ignore @@ -162,6 +167,7 @@ def cast(_type, obj): # type: ignore NoReturn = _FakeType("NoReturn") # type: ignore Optional = _FakeType("Optional") Pattern = _FakeType("Pattern") # type: ignore + Protocol = _FakeType("Protocol") Sequence = _FakeType("Sequence") # type: ignore Sequence = _FakeType("Sequence", list) # type: ignore Set = _FakeType("Set", set) # type: ignore From 5e0cf48348dc7a2b4f8d2016975883e2c8ba2235 Mon Sep 17 00:00:00 2001 From: gpotter2 Date: Mon, 1 Nov 2021 21:38:06 +0100 Subject: [PATCH 2/2] Various typing fixes --- .config/mypy/mypy_check.py | 22 ++ scapy/automaton.py | 288 ++++++++++++----------- scapy/compat.py | 3 - scapy/contrib/isotp/isotp_soft_socket.py | 2 +- scapy/pipetool.py | 4 +- scapy/sendrecv.py | 10 +- scapy/supersocket.py | 5 +- scapy/utils.py | 2 +- test/testsocket.py | 32 +-- 9 files changed, 189 insertions(+), 179 deletions(-) diff --git a/.config/mypy/mypy_check.py b/.config/mypy/mypy_check.py index 543bb1a2fe2..4d4bdf27644 100644 --- a/.config/mypy/mypy_check.py +++ b/.config/mypy/mypy_check.py @@ -62,6 +62,28 @@ "--show-traceback", ] + [os.path.abspath(f) for f in FILES] +if sys.platform.startswith("linux"): + ARGS.extend([ + "--always-true=LINUX", + "--always-false=OPENBSD", + "--always-false=FREEBSD", + "--always-false=NETBSD", + "--always-false=DARWIN", + "--always-false=WINDOWS", + "--always-false=BSD", + ]) +if sys.platform.startswith("win32"): + ARGS.extend([ + "--always-false=LINUX", + "--always-false=OPENBSD", + "--always-false=FREEBSD", + "--always-false=NETBSD", + "--always-false=DARWIN", + "--always-true=WINDOWS", + "--always-false=WINDOWS_XP", + "--always-false=BSD", + ]) + # Run mypy over the files mypy_main(None, sys.stdout, sys.stderr, ARGS) diff --git a/scapy/automaton.py b/scapy/automaton.py index 35563fbc54f..fb3bacd2b21 100644 --- a/scapy/automaton.py +++ b/scapy/automaton.py @@ -24,7 +24,6 @@ import select from collections import deque -from typing import Dict from scapy.config import conf from scapy.utils import do_graph @@ -37,18 +36,22 @@ import scapy.modules.six as six from scapy.compat import ( - Union, - List, - Optional, Any, - Type, Callable, - Protocol, + DecoratorCallable, + Deque, + Dict, + Generic, Iterator, - Tuple, + List, + Optional, Set, + Tuple, + Type, TypeVar, - Deque, + Union, + _Generic_metaclass, + cast, ) @@ -76,8 +79,6 @@ def select_objects(inputs, remain): :param inputs: objects to process :param remain: timeout. If 0, return []. """ - if not remain: - return [] if not WINDOWS: return select.select(inputs, [], [], remain)[0] natives = [] @@ -95,25 +96,23 @@ def select_objects(inputs, remain): if events: remainms = int((remain or 0) * 1000) if len(events) == 1: - if sys.platform == "win32": - res = ctypes.windll.kernel32.WaitForSingleObject( - ctypes.c_void_p(events[0].fileno()), - remainms - ) + res = ctypes.windll.kernel32.WaitForSingleObject( + ctypes.c_void_p(events[0].fileno()), + remainms + ) else: # Sadly, the only way to emulate select() is to first check # if any object is available using WaitForMultipleObjects # then poll the others. - if sys.platform == "win32": - res = ctypes.windll.kernel32.WaitForMultipleObjects( - len(events), - (ctypes.c_void_p * len(events))( - *[x.fileno() for x in events] - ), - False, - remainms - ) - if sys.platform == "win32" and res != 0xFFFFFFFF and res != 0x00000102: # Failed or Timeout # noqa: E501 + res = ctypes.windll.kernel32.WaitForMultipleObjects( + len(events), + (ctypes.c_void_p * len(events))( + *[x.fileno() for x in events] + ), + False, + remainms + ) + if res != 0xFFFFFFFF and res != 0x00000102: # Failed or Timeout results.add(events[res]) if len(events) > 1: # Now poll the others, if any @@ -127,41 +126,42 @@ def select_objects(inputs, remain): return list(results) -class ObjectPipe: +_T = TypeVar("_T") + + +@six.add_metaclass(_Generic_metaclass) +class ObjectPipe(Generic[_T]): def __init__(self, name=None): # type: (Optional[str]) -> None self.name = name or "ObjectPipe" self._closed = False self.__rd, self.__wr = os.pipe() - self.__queue = deque() # type: Deque[Union[str, bytes, Message]] + self.__queue = deque() # type: Deque[_T] if WINDOWS: self._wincreate() - def _wincreate(self): - # type: () -> None - if sys.platform == "win32": + if WINDOWS: + def _wincreate(self): + # type: () -> None self._fd = ctypes.windll.kernel32.CreateEventA( None, True, False, ctypes.create_string_buffer(b"ObjectPipe %f" % random.random()) ) - def _winset(self): - # type: () -> None - if sys.platform == "win32": + def _winset(self): + # type: () -> None if ctypes.windll.kernel32.SetEvent( ctypes.c_void_p(self._fd)) == 0: warning(ctypes.FormatError()) - def _winreset(self): - # type: () -> None - if sys.platform == "win32": + def _winreset(self): + # type: () -> None if ctypes.windll.kernel32.ResetEvent( ctypes.c_void_p(self._fd)) == 0: warning(ctypes.FormatError()) - def _winclose(self): - # type: () -> None - if sys.platform == "win32": + def _winclose(self): + # type: () -> None if self._fd and ctypes.windll.kernel32.CloseHandle( ctypes.c_void_p(self._fd)) == 0: warning(ctypes.FormatError()) @@ -170,20 +170,19 @@ def _winclose(self): def fileno(self): # type: () -> int if WINDOWS: - if sys.platform == "win32": - return self._fd + return self._fd return self.__rd def send(self, obj): - # type: (Union[str, bytes, Message]) -> int + # type: (Union[_T]) -> int self.__queue.append(obj) if WINDOWS: self._winset() os.write(self.__wr, b"X") - return 0 + return 1 def write(self, obj): - # type: (str) -> None + # type: (_T) -> None self.send(obj) def empty(self): @@ -195,20 +194,17 @@ def flush(self): pass def recv(self, n=0): - # type: (Optional[int]) -> Optional[Message] + # type: (Optional[int]) -> Optional[_T] if self._closed: return None os.read(self.__rd, 1) elt = self.__queue.popleft() if WINDOWS and not self.__queue: self._winreset() - if isinstance(elt, Message): - return elt - else: - return Message(elt) + return elt def read(self, n=0): - # type: (Optional[int]) -> Any + # type: (Optional[int]) -> Optional[_T] return self.recv(n) def close(self): @@ -231,7 +227,8 @@ def __del__(self): @staticmethod def select(sockets, remain=conf.recv_poll_rate): - # type: (List[SuperSocket], float) -> List[SuperSocket] + # type: (List[SuperSocket], Optional[float]) -> List[SuperSocket] + # Only handle ObjectPipes results = [] for s in sockets: if s.closed: @@ -241,16 +238,16 @@ def select(sockets, remain=conf.recv_poll_rate): return select_objects(sockets, remain) -class Message(str): +class Message: type = None # type: str pkt = None # type: Packet result = None # type: str state = None # type: Message exc_info = None # type: Union[Tuple[None, None, None], Tuple[BaseException, Exception, types.TracebackType]] # noqa: E501 - def __new__(cls, *args): - # type: (Any) -> Any - return str.__new__(cls, *args) + def __init__(self, **args): + # type: (Any) -> None + self.__dict__.update(args) def __repr__(self): # type: () -> str @@ -295,7 +292,7 @@ def unintercepts(self): # Automata # ############## -class StateWrapper(Protocol[TypeVar("F", bound=Callable[..., object])]): +class _StateWrapper: __name__ = None # type: str atmt_type = None # type: str atmt_state = None # type: str @@ -303,19 +300,15 @@ class StateWrapper(Protocol[TypeVar("F", bound=Callable[..., object])]): atmt_final = None # type: int atmt_stop = None # type: int atmt_error = None # type: int - atmt_origfunc = None # type: StateWrapper + atmt_origfunc = None # type: _StateWrapper atmt_prio = None # type: int - atmt_as_supersocket = None # type: str + atmt_as_supersocket = None # type: Optional[str] atmt_condname = None # type: str atmt_ioname = None # type: str atmt_timeout = None # type: int atmt_cond = None # type: Dict[str, int] - __call__ = None # type: TypeVar("F", bound=Callable[..., object]) # noqa: E501 - - -def state_wrapper_decorator(f): - # type: (Callable) -> StateWrapper - return f + __code__ = None # type: types.CodeType + __call__ = None # type: Callable[..., ATMT.NewStateRequested] class ATMT: @@ -342,7 +335,7 @@ def __init__(self, state_func, automaton, *args, **kargs): self.action_parameters() # init action parameters def action_parameters(self, *args, **kargs): - # type: (Any, Any) -> NewStateRequested + # type: (Any, Any) -> ATMT.NewStateRequested self.action_args = args self.action_kargs = kargs return self @@ -355,15 +348,15 @@ def __repr__(self): # type: () -> str return "NewStateRequested(%s)" % self.state - def state(self, - initial=0, # type: int + @staticmethod + def state(initial=0, # type: int final=0, # type: int stop=0, # type: int error=0 # type: int ): - # type: (...) -> Callable[[StateWrapper, int, int], Callable[[Any, Any], NewStateRequested]] # noqa: E501 + # type: (...) -> Callable[[DecoratorCallable], DecoratorCallable] def deco(f, initial=initial, final=final): - # type: (StateWrapper, int, int) -> Callable[[Any, Any], NewStateRequested] # noqa: E501 + # type: (_StateWrapper, int, int) -> _StateWrapper f.atmt_type = ATMT.STATE f.atmt_state = f.__name__ f.atmt_initial = initial @@ -371,11 +364,11 @@ def deco(f, initial=initial, final=final): f.atmt_stop = stop f.atmt_error = error - @state_wrapper_decorator - def state_wrapper(self, *args, **kargs): - # type: (ATMT, Any, Any) -> NewStateRequested + def _state_wrapper(self, *args, **kargs): + # type: (ATMT, Any, Any) -> ATMT.NewStateRequested return ATMT.NewStateRequested(f, self, *args, **kargs) + state_wrapper = cast(_StateWrapper, _state_wrapper) state_wrapper.__name__ = "%s_wrapper" % f.__name__ state_wrapper.atmt_type = ATMT.STATE state_wrapper.atmt_state = f.__name__ @@ -385,13 +378,13 @@ def state_wrapper(self, *args, **kargs): state_wrapper.atmt_error = error state_wrapper.atmt_origfunc = f return state_wrapper - return deco + return deco # type: ignore @staticmethod def action(cond, prio=0): - # type: (Any, int) -> Callable[[StateWrapper, StateWrapper], StateWrapper] # noqa: E501 + # type: (Any, int) -> Callable[[_StateWrapper, _StateWrapper], _StateWrapper] # noqa: E501 def deco(f, cond=cond): - # type: (StateWrapper, StateWrapper) -> StateWrapper + # type: (_StateWrapper, _StateWrapper) -> _StateWrapper if not hasattr(f, "atmt_type"): f.atmt_cond = {} f.atmt_type = ATMT.ACTION @@ -401,9 +394,9 @@ def deco(f, cond=cond): @staticmethod def condition(state, prio=0): - # type: (Any, int) -> Callable[[StateWrapper, StateWrapper], StateWrapper] # noqa: E501 + # type: (Any, int) -> Callable[[_StateWrapper, _StateWrapper], _StateWrapper] # noqa: E501 def deco(f, state=state): - # type: (StateWrapper, StateWrapper) -> Any + # type: (_StateWrapper, _StateWrapper) -> Any f.atmt_type = ATMT.CONDITION f.atmt_state = state.atmt_state f.atmt_condname = f.__name__ @@ -413,9 +406,9 @@ def deco(f, state=state): @staticmethod def receive_condition(state, prio=0): - # type: (StateWrapper, int) -> Callable[[StateWrapper, StateWrapper], StateWrapper] # noqa: E501 + # type: (_StateWrapper, int) -> Callable[[_StateWrapper, _StateWrapper], _StateWrapper] # noqa: E501 def deco(f, state=state): - # type: (StateWrapper, StateWrapper) -> StateWrapper + # type: (_StateWrapper, _StateWrapper) -> _StateWrapper f.atmt_type = ATMT.RECV f.atmt_state = state.atmt_state f.atmt_condname = f.__name__ @@ -424,28 +417,28 @@ def deco(f, state=state): return deco @staticmethod - def ioevent(state, # type: StateWrapper + def ioevent(state, # type: _StateWrapper name, # type: str prio=0, # type: int as_supersocket=None # type: Optional[str] ): - # type: (...) -> Callable[[StateWrapper, StateWrapper], StateWrapper] + # type: (...) -> Callable[[_StateWrapper, _StateWrapper], _StateWrapper] # noqa: E501 def deco(f, state=state): - # type: (StateWrapper, StateWrapper) -> StateWrapper + # type: (_StateWrapper, _StateWrapper) -> _StateWrapper f.atmt_type = ATMT.IOEVENT f.atmt_state = state.atmt_state f.atmt_condname = f.__name__ f.atmt_ioname = name f.atmt_prio = prio - f.atmt_as_supersocket = as_supersocket if as_supersocket else "" + f.atmt_as_supersocket = as_supersocket return f return deco @staticmethod def timeout(state, timeout): - # type: (StateWrapper, int) -> Callable[[StateWrapper, StateWrapper, int], StateWrapper] # noqa: E501 + # type: (_StateWrapper, int) -> Callable[[_StateWrapper, _StateWrapper, int], _StateWrapper] # noqa: E501 def deco(f, state=state, timeout=timeout): - # type: (StateWrapper, StateWrapper, int) -> StateWrapper + # type: (_StateWrapper, _StateWrapper, int) -> _StateWrapper f.atmt_type = ATMT.TIMEOUT f.atmt_state = state.atmt_state f.atmt_timeout = timeout @@ -475,7 +468,7 @@ def __init__(self, name, # type: str ioevent, # type: str automaton, # type: Type[Automaton] - proto, # type: Callable[[Message], Any] + proto, # type: Callable[[bytes], Any] *args, # type: Any **kargs # type: Any ): @@ -484,7 +477,8 @@ def __init__(self, self.ioevent = ioevent self.proto = proto # write, read - self.spa, self.spb = ObjectPipe("spa"), ObjectPipe("spb") + self.spa, self.spb = ObjectPipe[bytes]("spa"), \ + ObjectPipe[bytes]("spb") kargs["external_fd"] = {ioevent: (self.spa, self.spb)} kargs["is_atmt_socket"] = True self.atmt = automaton(*args, **kargs) @@ -529,7 +523,7 @@ def __init__(self, name, ioevent, automaton): self.automaton = automaton def __call__(self, proto, *args, **kargs): - # type: (Callable[[Message], Any], Any, Any) -> _ATMT_supersocket + # type: (Callable[[bytes], Any], Any, Any) -> _ATMT_supersocket return _ATMT_supersocket( self.name, self.ioevent, self.automaton, proto, *args, **kargs @@ -537,18 +531,17 @@ def __call__(self, proto, *args, **kargs): class Automaton_metaclass(type): - def __new__(cls, name, bases, dct): - # type: (str, Tuple[Any], Dict[str, Any]) -> Automaton_metaclass + def __new__(cls, name, bases, dct): # type: ignore + # type: (str, Tuple[Any], Dict[str, Any]) -> Type[Automaton] cls = super(Automaton_metaclass, cls).__new__(cls, name, bases, dct) cls.states = {} - cls.state = StateWrapper() # type: StateWrapper - cls.recv_conditions = {} # type: Dict[str, List[StateWrapper]] - cls.conditions = {} # type: Dict[str, List[StateWrapper]] - cls.ioevents = {} # type: Dict[str, List[StateWrapper]] - cls.timeout = {} # type: Dict[str, List[Tuple[int, StateWrapper]]] # noqa: E501 - cls.actions = {} # type: Dict[str, List[StateWrapper]] - cls.initial_states = [] # type: List[StateWrapper] - cls.stop_states = [] # type: List[StateWrapper] + cls.recv_conditions = {} # type: Dict[str, List[_StateWrapper]] + cls.conditions = {} # type: Dict[str, List[_StateWrapper]] + cls.ioevents = {} # type: Dict[str, List[_StateWrapper]] + cls.timeout = {} # type: Dict[str, List[Tuple[int, _StateWrapper]]] # noqa: E501 + cls.actions = {} # type: Dict[str, List[_StateWrapper]] + cls.initial_states = [] # type: List[_StateWrapper] + cls.stop_states = [] # type: List[_StateWrapper] cls.ionames = [] cls.iosupersockets = [] @@ -562,7 +555,7 @@ def __new__(cls, name, bases, dct): members[k] = v decorated = [v for v in six.itervalues(members) - if isinstance(v, StateWrapper) and hasattr(v, "atmt_type")] # noqa: E501 + if hasattr(v, "atmt_type")] for m in decorated: if m.atmt_type == ATMT.STATE: @@ -606,9 +599,12 @@ def __new__(cls, name, bases, dct): actlst.sort(key=lambda x: x.atmt_cond[condname]) for ioev in cls.iosupersockets: - setattr(cls, ioev.atmt_as_supersocket, _ATMT_to_supersocket(ioev.atmt_as_supersocket, ioev.atmt_ioname, cls)) # noqa: E501 - - return cls + setattr(cls, ioev.atmt_as_supersocket, + _ATMT_to_supersocket( + ioev.atmt_as_supersocket, + ioev.atmt_ioname, + cast(Type["Automaton"], cls))) + return cast(Type["Automaton"], cls) def build_graph(self): # type: () -> str @@ -641,8 +637,8 @@ def build_graph(self): for x in self.actions[f.atmt_condname]: line += "\\l>[%s]" % x.__name__ s += '\t"%s" -> "%s" [label="%s", color=%s];\n' % (k, n, line, c) # noqa: E501 - for k, v in six.iteritems(self.timeout): - for t, f in v: + for k, v2 in six.iteritems(self.timeout): + for t, f in v2: if f is None: continue for n in f.__code__.co_names + f.__code__.co_consts: @@ -660,7 +656,20 @@ def graph(self, **kargs): return do_graph(s, **kargs) -class Automaton(six.with_metaclass(Automaton_metaclass)): +@six.add_metaclass(Automaton_metaclass) +class Automaton: + states = {} # type: Dict[str, _StateWrapper] + state = None # type: ATMT.NewStateRequested + recv_conditions = {} # type: Dict[str, List[_StateWrapper]] + conditions = {} # type: Dict[str, List[_StateWrapper]] + ioevents = {} # type: Dict[str, List[_StateWrapper]] + timeout = {} # type: Dict[str, List[Tuple[int, _StateWrapper]]] # noqa: E501 + actions = {} # type: Dict[str, List[_StateWrapper]] + initial_states = [] # type: List[_StateWrapper] + stop_states = [] # type: List[_StateWrapper] + ionames = [] # type: List[str] + iosupersockets = [] # type: List[SuperSocket] + # Internals def __init__(self, *args, **kargs): # type: (Any, Any) -> None @@ -671,20 +680,20 @@ def __init__(self, *args, **kargs): self.started = threading.Lock() self.threadid = None # type: Optional[int] self.breakpointed = None - self.breakpoints = set() # type: Set[StateWrapper] - self.interception_points = set() # type: Set[StateWrapper] + self.breakpoints = set() # type: Set[_StateWrapper] + self.interception_points = set() # type: Set[_StateWrapper] self.intercepted_packet = None # type: Union[None, Packet] self.debug_level = 0 self.init_args = args self.init_kargs = kargs self.io = type.__new__(type, "IOnamespace", (), {}) self.oi = type.__new__(type, "IOnamespace", (), {}) - self.cmdin = ObjectPipe("cmdin") - self.cmdout = ObjectPipe("cmdout") + self.cmdin = ObjectPipe[Message]("cmdin") + self.cmdout = ObjectPipe[Message]("cmdout") self.ioin = {} self.ioout = {} self.packets = PacketList() # type: PacketList - for n in self.ionames: + for n in self.__class__.ionames: extfd = external_fd.get(n) if not isinstance(extfd, tuple): extfd = (extfd, extfd) @@ -729,12 +738,15 @@ def my_send(self, pkt): # Utility classes and exceptions class _IO_fdwrapper: - def __init__(self, rd, wr): - # type: (Union[int, ObjectPipe, None], Union[int, ObjectPipe, None]) -> None # noqa: E501 + def __init__(self, + rd, # type: Union[int, ObjectPipe[bytes], None] + wr # type: Union[int, ObjectPipe[bytes], None] + ): + # type: (...) -> None if rd is not None and not isinstance(rd, (int, ObjectPipe)): - rd = rd.fileno() + rd = rd.fileno() # type: ignore if wr is not None and not isinstance(wr, (int, ObjectPipe)): - wr = wr.fileno() + wr = wr.fileno() # type: ignore self.rd = rd self.wr = wr @@ -747,12 +759,12 @@ def fileno(self): return 0 def read(self, n=65535): - # type: (int) -> Union[bytes, Message, None] + # type: (int) -> Optional[bytes] if isinstance(self.rd, ObjectPipe): return self.rd.recv(n) elif isinstance(self.rd, int): return os.read(self.rd, n) - return b'' + return None def write(self, msg): # type: (bytes) -> int @@ -763,7 +775,7 @@ def write(self, msg): return 0 def recv(self, n=65535): - # type: (int) -> Union[bytes, Message, None] + # type: (int) -> Optional[bytes] return self.read(n) def send(self, msg): @@ -771,8 +783,11 @@ def send(self, msg): return self.write(msg) class _IO_mixer: - def __init__(self, rd, wr): - # type: (Union[int, ObjectPipe], Union[int, ObjectPipe]) -> None + def __init__(self, + rd, # type: ObjectPipe[Any] + wr, # type: ObjectPipe[Any] + ): + # type: (...) -> None self.rd = rd self.wr = wr @@ -784,10 +799,7 @@ def fileno(self): def recv(self, n=None): # type: (Optional[int]) -> Any - if isinstance(self.rd, ObjectPipe): - return self.rd.recv(n) - else: - return None + return self.rd.recv(n) def read(self, n=None): # type: (Optional[int]) -> Any @@ -795,10 +807,7 @@ def read(self, n=None): def send(self, msg): # type: (str) -> int - if isinstance(self.wr, ObjectPipe): - return self.wr.send(msg) - else: - return 0 + return self.wr.send(msg) def write(self, msg): # type: (str) -> int @@ -849,11 +858,14 @@ def send(self, pkt): if self.state.state in self.interception_points: self.debug(3, "INTERCEPT: packet intercepted: %s" % pkt.summary()) self.intercepted_packet = pkt - cmd = Message(type=_ATMT_Command.INTERCEPT, state=self.state, pkt=pkt) # noqa: E501 - self.cmdout.send(cmd) - temp = self.cmdin.recv() - if temp: - cmd = temp + self.cmdout.send( + Message(type=_ATMT_Command.INTERCEPT, + state=self.state, pkt=pkt) + ) + cmd = self.cmdin.recv() + if not cmd: + self.debug(3, "CANCELLED") + return self.intercepted_packet = None if cmd.type == _ATMT_Command.REJECT: self.debug(3, "INTERCEPT: packet rejected") @@ -880,7 +892,7 @@ def __del__(self): self.stop() def _run_condition(self, cond, *args, **kargs): - # type: (StateWrapper, Any, Any) -> None + # type: (_StateWrapper, Any, Any) -> None try: self.debug(5, "Trying %s [%s]" % (cond.atmt_type, cond.atmt_condname)) # noqa: E501 cond(self, *args, **kargs) @@ -983,7 +995,7 @@ def _do_control(self, ready, *args, **kargs): self.threadid = None def _do_iter(self): - # type: () -> Iterator[Union[Automaton.AutomatonException, Automaton.AutomatonStopped, NewStateRequested, None]] # noqa: E501 + # type: () -> Iterator[Union[Automaton.AutomatonException, Automaton.AutomatonStopped, ATMT.NewStateRequested, None]] # noqa: E501 while True: try: self.debug(1, "## state=[%s]" % self.state.state) @@ -1115,7 +1127,7 @@ def run(self, resume=None, # type: Optional[Message] wait=True # type: Optional[bool] ): - # type: (...) -> Union[Automaton.AutomatonException, Automaton.AutomatonStopped, str, None] # noqa: E501 + # type: (...) -> Any if resume is None: resume = Message(type=_ATMT_Command.RUN) self.cmdin.send(resume) @@ -1144,7 +1156,7 @@ def runbg(self, resume=None, wait=False): self.run(resume, wait) def next(self): - # type: () -> Union[Automaton.AutomatonException, Automaton.AutomatonStopped, str, None] # noqa: E501 + # type: () -> Any return self.run(resume=Message(type=_ATMT_Command.NEXT)) __next__ = next @@ -1178,7 +1190,7 @@ def accept_packet(self, pkt=None, # type: Optional[Packet] wait=False # type: Optional[bool] ): - # type: (...) -> Union[Automaton.AutomatonException, Automaton.AutomatonStopped, str, None] # noqa: E501 + # type: (...) -> Any rsm = Message() if pkt is None: rsm.type = _ATMT_Command.ACCEPT @@ -1190,6 +1202,6 @@ def accept_packet(self, def reject_packet(self, wait=False # type: Optional[bool] ): - # type: (...) -> Union[Automaton.AutomatonException, Automaton.AutomatonStopped, str, None] # noqa: E501 + # type: (...) -> Any rsm = Message(type=_ATMT_Command.REJECT) return self.run(resume=rsm, wait=wait) diff --git a/scapy/compat.py b/scapy/compat.py index a9da28fa8b6..12a4fd68379 100644 --- a/scapy/compat.py +++ b/scapy/compat.py @@ -39,7 +39,6 @@ 'NoReturn', 'Optional', 'Pattern', - 'Protocol', 'Sequence', 'Set', 'Sized', @@ -135,7 +134,6 @@ def __repr__(self): NoReturn, Optional, Pattern, - Protocol, Sequence, Set, Sized, @@ -167,7 +165,6 @@ def cast(_type, obj): # type: ignore NoReturn = _FakeType("NoReturn") # type: ignore Optional = _FakeType("Optional") Pattern = _FakeType("Pattern") # type: ignore - Protocol = _FakeType("Protocol") Sequence = _FakeType("Sequence") # type: ignore Sequence = _FakeType("Sequence", list) # type: ignore Set = _FakeType("Set", set) # type: ignore diff --git a/scapy/contrib/isotp/isotp_soft_socket.py b/scapy/contrib/isotp/isotp_soft_socket.py index 7d55738d969..ddcd56e8ba6 100644 --- a/scapy/contrib/isotp/isotp_soft_socket.py +++ b/scapy/contrib/isotp/isotp_soft_socket.py @@ -523,7 +523,7 @@ def __init__(self, self.rxfc_bs = rx_block_size self.rxfc_stmin = stmin - self.rx_queue = ObjectPipe() + self.rx_queue = ObjectPipe[Tuple[bytes, Union[float, EDecimal]]]() self.rx_len = -1 self.rx_buf = None # type: Optional[bytes] self.rx_sn = 0 diff --git a/scapy/pipetool.py b/scapy/pipetool.py index 4097c0c4eca..da02ede8a12 100644 --- a/scapy/pipetool.py +++ b/scapy/pipetool.py @@ -20,6 +20,8 @@ from scapy.config import conf from scapy.utils import get_temp_file, do_graph +from scapy.compat import _Generic_metaclass + class PipeEngine(ObjectPipe): pipes = {} @@ -250,7 +252,7 @@ def __hash__(self): return object.__hash__(self) -class _PipeMeta(type): +class _PipeMeta(_Generic_metaclass): def __new__(cls, name, bases, dct): c = type.__new__(cls, name, bases, dct) PipeEngine.pipes[name] = c diff --git a/scapy/sendrecv.py b/scapy/sendrecv.py index c3e53b21770..0a14eeaf9e8 100644 --- a/scapy/sendrecv.py +++ b/scapy/sendrecv.py @@ -1174,15 +1174,16 @@ def _run(self, "The used select function " "will be the one of the first socket") + close_pipe = None # type: Optional[ObjectPipe[None]] if not nonblocking_socket: # select is blocking: Add special control socket from scapy.automaton import ObjectPipe - close_pipe = ObjectPipe() - sniff_sockets[close_pipe] = "control_socket" + close_pipe = ObjectPipe[None]() + sniff_sockets[close_pipe] = "control_socket" # type: ignore def stop_cb(): # type: () -> None - if self.running: + if self.running and close_pipe: close_pipe.send(None) self.continue_sniff = False self.stop_cb = stop_cb @@ -1192,7 +1193,6 @@ def stop_cb(): # type: () -> None self.continue_sniff = False self.stop_cb = stop_cb - close_pipe = None try: if started_callback: @@ -1212,7 +1212,7 @@ def stop_cb(): sockets = select_func(list(sniff_sockets.keys()), remain) dead_sockets = [] for s in sockets: - if s is close_pipe: + if s is close_pipe: # type: ignore break try: p = s.recv() diff --git a/scapy/supersocket.py b/scapy/supersocket.py index 591cc142f8a..fa0229758e2 100644 --- a/scapy/supersocket.py +++ b/scapy/supersocket.py @@ -40,13 +40,14 @@ Optional, Tuple, Type, - cast + cast, + _Generic_metaclass ) # Utils -class _SuperSocket_metaclass(type): +class _SuperSocket_metaclass(_Generic_metaclass): desc = None # type: Optional[str] def __repr__(self): diff --git a/scapy/utils.py b/scapy/utils.py index 6ed622bd3d3..3cca5779293 100644 --- a/scapy/utils.py +++ b/scapy/utils.py @@ -2362,7 +2362,7 @@ def get_terminal_width(): return sizex # Backups / Python 2.7 if WINDOWS: - from ctypes import windll, create_string_buffer # type: ignore + from ctypes import windll, create_string_buffer # http://code.activestate.com/recipes/440694-determine-size-of-console-window-on-windows/ h = windll.kernel32.GetStdHandle(-12) csbi = create_string_buffer(22) diff --git a/test/testsocket.py b/test/testsocket.py index 2d06e4023b1..29fdbc7b4a7 100644 --- a/test/testsocket.py +++ b/test/testsocket.py @@ -21,7 +21,7 @@ open_test_sockets = list() # type: List[TestSocket] -class TestSocket(ObjectPipe, object): +class TestSocket(ObjectPipe[Packet], SuperSocket): nonblocking_socket = False # type: bool def __init__(self, basecls=None): @@ -56,7 +56,7 @@ def send(self, x): # type: (Packet) -> int sx = bytes(x) for r in self.paired_sockets: - super(TestSocket, r).send(sx) + super(TestSocket, r).send(sx) # type: ignore try: x.sent_time = time.time() except AttributeError: @@ -70,33 +70,9 @@ def recv_raw(self, x=MTU): super(TestSocket, self).recv(), \ time.time() - def recv(self, x=MTU): + def recv(self, x=MTU): # type: ignore # type: (int) -> Optional[Packet] - if six.PY3: - return SuperSocket.recv(self, x) - else: - return SuperSocket.recv.im_func(self, x) - - def sr1(self, *args, **kargs): - # type: (Any, Any) -> Optional[Packet] - if six.PY3: - return SuperSocket.sr1(self, *args, **kargs) - else: - return SuperSocket.sr1.im_func(self, *args, **kargs) - - def sr(self, *args, **kargs): - # type: (Any, Any) -> Tuple[SndRcvList, PacketList] - if six.PY3: - return SuperSocket.sr(self, *args, **kargs) - else: - return SuperSocket.sr.im_func(self, *args, **kargs) - - def sniff(self, *args, **kargs): - # type: (Any, Any) -> PacketList - if six.PY3: - return SuperSocket.sniff(self, *args, **kargs) - else: - return SuperSocket.sniff.im_func(self, *args, **kargs) + return SuperSocket.recv(self, x=x) @staticmethod def select(sockets, remain=conf.recv_poll_rate):