diff --git a/linebot/webhook.py b/linebot/webhook.py index 035b3ef6a..eae6d23c1 100644 --- a/linebot/webhook.py +++ b/linebot/webhook.py @@ -253,26 +253,30 @@ def handle(self, body, signature): if func is None: LOGGER.info('No handler of ' + key + ' and no default handler') else: - args_count = self.__get_args_count(func) - if args_count == 0: - func() - elif args_count == 1: - func(event) - else: - func(event, payload.destination) + self.__invoke_func(func, event, payload) def __add_handler(self, func, event, message=None): key = self.__get_handler_key(event, message=message) self._handlers[key] = func + @classmethod + def __invoke_func(cls, func, event, payload): + (has_varargs, args_count) = cls.__get_args_count(func) + if has_varargs or args_count == 2: + func(event, payload.destination) + elif args_count == 1: + func(event) + else: + func() + @staticmethod def __get_args_count(func): if PY3: arg_spec = inspect.getfullargspec(func) - return len(arg_spec.args) + return (arg_spec.varargs is not None, len(arg_spec.args)) else: arg_spec = inspect.getargspec(func) - return len(arg_spec.args) + return (arg_spec.varargs is not None, len(arg_spec.args)) @staticmethod def __get_handler_key(event, message=None): diff --git a/tests/test_webhook.py b/tests/test_webhook.py index 51b884fd4..9839a65c4 100644 --- a/tests/test_webhook.py +++ b/tests/test_webhook.py @@ -17,6 +17,7 @@ import os import unittest from builtins import open +import inspect from linebot import ( SignatureValidator, WebhookParser, WebhookHandler @@ -29,6 +30,7 @@ LocationMessage, StickerMessage, FileMessage, SourceUser, SourceRoom, SourceGroup, DeviceLink, DeviceUnlink, ScenarioResult, ActionResult) +from linebot.utils import PY3 class TestSignatureValidator(unittest.TestCase): @@ -527,5 +529,88 @@ def test_handler(self): self.handler.handle(body, 'signature') +class TestInvokeWebhookHandler(unittest.TestCase): + def setUp(self): + def wrap(func): + def wrapper(*args): + if PY3: + arg_spec = inspect.getfullargspec(func) + else: + arg_spec = inspect.getargspec(func) + return func(*args[0:len(arg_spec.args)]) + return wrapper + + def func_with_0_args(): + assert True + + def func_with_1_arg(arg): + assert arg + + def func_with_2_args(arg1, arg2): + assert arg1 and arg2 + + def func_with_1_arg_with_default(arg=False): + assert arg + + def func_with_2_args_with_default(arg1=False, arg2=False): + assert arg1 and arg2 + + def func_with_1_arg_and_1_arg_with_default(arg1, arg2=False): + assert arg1 and arg2 + + @wrap + def wrapped_func_with_0_args(): + assert True + + @wrap + def wrapped_func_with_1_arg(arg): + assert arg + + @wrap + def wrapped_func_with_2_args(arg1, arg2): + assert arg1 and arg2 + + @wrap + def wrapped_func_with_1_arg_with_default(arg=False): + assert arg + + @wrap + def wrapped_func_with_2_args_with_default(arg1=False, arg2=False): + assert arg1 and arg2 + + @wrap + def wrapped_func_with_1_arg_and_1_arg_with_default( + arg1, arg2=False): + assert arg1 and arg2 + + self.functions = [ + func_with_0_args, + func_with_1_arg, + func_with_2_args, + func_with_1_arg_with_default, + func_with_2_args_with_default, + func_with_1_arg_and_1_arg_with_default, + wrapped_func_with_0_args, + wrapped_func_with_1_arg, + wrapped_func_with_2_args, + wrapped_func_with_1_arg_with_default, + wrapped_func_with_2_args_with_default, + wrapped_func_with_1_arg_and_1_arg_with_default, + ] + + def test_invoke_func(self): + class PayloadMock(object): + def __init__(self): + self.destination = True + + event = True + payload = PayloadMock() + + for func in self.functions: + WebhookHandler._WebhookHandler__invoke_func( + func, event, payload + ) + + if __name__ == '__main__': unittest.main()