diff --git a/linebot/webhook.py b/linebot/webhook.py index 035b3ef6a..0009e8b51 100644 --- a/linebot/webhook.py +++ b/linebot/webhook.py @@ -267,11 +267,16 @@ def __add_handler(self, func, event, message=None): @staticmethod def __get_args_count(func): + HANDLER_ARGSIZE_MAXIMAM = 2 if PY3: arg_spec = inspect.getfullargspec(func) + if arg_spec.varargs is not None: + return HANDLER_ARGSIZE_MAXIMAM return len(arg_spec.args) else: arg_spec = inspect.getargspec(func) + if arg_spec.varargs is not None: + return HANDLER_ARGSIZE_MAXIMAM return len(arg_spec.args) @staticmethod diff --git a/tests/test_webhook.py b/tests/test_webhook.py index 51b884fd4..c216b79cf 100644 --- a/tests/test_webhook.py +++ b/tests/test_webhook.py @@ -17,6 +17,7 @@ import os import unittest from builtins import open +import inspect from linebot import ( SignatureValidator, WebhookParser, WebhookHandler @@ -29,6 +30,7 @@ LocationMessage, StickerMessage, FileMessage, SourceUser, SourceRoom, SourceGroup, DeviceLink, DeviceUnlink, ScenarioResult, ActionResult) +from linebot.utils import PY3 class TestSignatureValidator(unittest.TestCase): @@ -527,5 +529,96 @@ def test_handler(self): self.handler.handle(body, 'signature') +def wrap(func): + def wrapper(*args): + if PY3: + arg_spec = inspect.getfullargspec(func) + else: + arg_spec = inspect.getargspec(func) + return func(*args[0:len(arg_spec.args)]) + return wrapper + + +class TestWebhookHandlerWithWrappedFunction(unittest.TestCase): + def setUp(self): + self.handler = WebhookHandler('channel_secret') + + @self.handler.add(MessageEvent, message=TextMessage) + @wrap + def message_text(event, destination): + self.assertEqual('message', event.type) + self.assertEqual('text', event.message.type) + self.assertEqual('U123', destination) + + @self.handler.add(MessageEvent, + message=(ImageMessage, VideoMessage, AudioMessage)) + @wrap + def message_content(event): + self.assertEqual('message', event.type) + self.assertIn( + event.message.type, + ['image', 'video', 'audio'] + ) + + @self.handler.add(MessageEvent, message=StickerMessage) + @wrap + def message_sticker(event): + self.assertEqual('message', event.type) + self.assertEqual('sticker', event.message.type) + + @self.handler.add(MessageEvent) + @wrap + def message(event): + self.assertEqual('message', event.type) + self.assertNotIn( + event.message.type, + ['text', 'image', 'video', 'audio', 'sticker'] + ) + + @self.handler.add(FollowEvent) + @wrap + def follow(event, destination): + self.assertEqual('follow', event.type) + self.assertEqual('U123', destination) + + @self.handler.add(JoinEvent) + @wrap + def join(event): + self.assertEqual('join', event.type) + + @self.handler.add(PostbackEvent) + @wrap + def postback(event): + self.assertEqual('postback', event.type) + + @self.handler.add(BeaconEvent) + @wrap + def beacon(event): + self.assertEqual('beacon', event.type) + + @self.handler.add(AccountLinkEvent) + @wrap + def account_link(event): + self.assertEqual('accountLink', event.type) + + @self.handler.default() + def default(event): + self.assertNotIn( + event.type, + ['message', 'follow', 'join', 'postback', 'beacon', 'accountLink'] + ) + + def test_handler_with_wrapped_function(self): + file_dir = os.path.dirname(__file__) + webhook_sample_json_path = os.path.join(file_dir, 'text', 'webhook.json') + with open(webhook_sample_json_path) as fp: + body = fp.read() + + # mock + self.handler.parser.signature_validator.validate = lambda a, b: True + + self.handler.handle(body, 'signature') + + if __name__ == '__main__': unittest.main()