Skip to content

Commit c9dd51b

Browse files
committed
rate_limit: Stop wrapping rate limited functions.
This refactors `rate_limit` so that we no longer use it as a decorator. This is a workaround to python/mypy#12909 as `rate_limit` previous expects different parameters than its callers. Signed-off-by: Zixuan James Li <[email protected]>
1 parent c011657 commit c9dd51b

File tree

2 files changed

+25
-45
lines changed

2 files changed

+25
-45
lines changed

zerver/decorator.py

Lines changed: 22 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,8 @@ def _wrapped_view_func(
508508
request: HttpRequest, /, *args: ParamT.args, **kwargs: ParamT.kwargs
509509
) -> HttpResponse:
510510
process_client(request, request.user, is_browser_view=True, query=view_func.__name__)
511-
return rate_limit(view_func)(request, *args, **kwargs)
511+
rate_limit(request)
512+
return view_func(request, *args, **kwargs)
512513

513514
return _wrapped_view_func
514515

@@ -681,10 +682,8 @@ def _wrapped_func_arguments(
681682
) -> HttpResponse:
682683
user_profile = validate_api_key(request, None, api_key, False)
683684
if not skip_rate_limiting:
684-
limited_func = rate_limit(view_func)
685-
else:
686-
limited_func = view_func
687-
return limited_func(request, user_profile, *args, **kwargs)
685+
rate_limit(request)
686+
return view_func(request, user_profile, *args, **kwargs)
688687

689688
return _wrapped_func_arguments
690689

@@ -745,10 +744,8 @@ def _wrapped_func_arguments(
745744
try:
746745
if not skip_rate_limiting:
747746
# Apply rate limiting
748-
target_view_func = rate_limit(view_func)
749-
else:
750-
target_view_func = view_func
751-
return target_view_func(request, profile, *args, **kwargs)
747+
rate_limit(request)
748+
return view_func(request, profile, *args, **kwargs)
752749
except Exception as err:
753750
if not webhook_client_name:
754751
raise err
@@ -822,9 +819,7 @@ def authenticate_log_and_execute_json(
822819
**kwargs: object,
823820
) -> HttpResponse:
824821
if not skip_rate_limiting:
825-
limited_view_func = rate_limit(view_func)
826-
else:
827-
limited_view_func = view_func
822+
rate_limit(request)
828823

829824
if not request.user.is_authenticated:
830825
if not allow_unauthenticated:
@@ -835,7 +830,7 @@ def authenticate_log_and_execute_json(
835830
is_browser_view=True,
836831
query=view_func.__name__,
837832
)
838-
return limited_view_func(request, request.user, *args, **kwargs)
833+
return view_func(request, request.user, *args, **kwargs)
839834

840835
user_profile = request.user
841836
validate_account_and_subdomain(request, user_profile)
@@ -844,7 +839,7 @@ def authenticate_log_and_execute_json(
844839
raise JsonableError(_("Webhook bots can only access webhooks"))
845840

846841
process_client(request, user_profile, is_browser_view=True, query=view_func.__name__)
847-
return limited_view_func(request, user_profile, *args, **kwargs)
842+
return view_func(request, user_profile, *args, **kwargs)
848843

849844

850845
# Checks if the user is logged in. If not, return an error (the
@@ -1027,36 +1022,22 @@ def rate_limit_remote_server(
10271022
raise e
10281023

10291024

1030-
def rate_limit(func: ViewFuncT) -> ViewFuncT:
1031-
"""Rate-limits a view."""
1032-
1033-
@wraps(func)
1034-
def wrapped_func(request: HttpRequest, *args: object, **kwargs: object) -> HttpResponse:
1035-
1036-
# It is really tempting to not even wrap our original function
1037-
# when settings.RATE_LIMITING is False, but it would make
1038-
# for awkward unit testing in some situations.
1039-
if not settings.RATE_LIMITING:
1040-
return func(request, *args, **kwargs)
1041-
1042-
if client_is_exempt_from_rate_limiting(request):
1043-
return func(request, *args, **kwargs)
1044-
1045-
user = request.user
1046-
remote_server = RequestNotes.get_notes(request).remote_server
1025+
def rate_limit(request: HttpRequest) -> None:
1026+
if not settings.RATE_LIMITING:
1027+
return
10471028

1048-
if settings.ZILENCER_ENABLED and remote_server is not None:
1049-
rate_limit_remote_server(request, remote_server, domain="api_by_remote_server")
1050-
elif not user.is_authenticated:
1051-
rate_limit_request_by_ip(request, domain="api_by_ip")
1052-
return func(request, *args, **kwargs)
1053-
else:
1054-
assert isinstance(user, UserProfile)
1055-
rate_limit_user(request, user, domain="api_by_user")
1029+
if client_is_exempt_from_rate_limiting(request):
1030+
return
10561031

1057-
return func(request, *args, **kwargs)
1032+
remote_server = RequestNotes.get_notes(request).remote_server
10581033

1059-
return cast(ViewFuncT, wrapped_func) # https://github.com/python/mypy/issues/1927
1034+
if settings.ZILENCER_ENABLED and remote_server is not None:
1035+
rate_limit_remote_server(request, remote_server, domain="api_by_remote_server")
1036+
elif not request.user.is_authenticated:
1037+
rate_limit_request_by_ip(request, domain="api_by_ip")
1038+
else:
1039+
assert isinstance(request.user, UserProfile)
1040+
rate_limit_user(request, request.user, domain="api_by_user")
10601041

10611042

10621043
def return_success_on_head_request(

zerver/tests/test_decorators.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ def my_view(request: HttpRequest, user_profile: UserProfile) -> HttpResponse:
519519
request.method = "POST"
520520
request.user = self.example_user("hamlet")
521521
with mock.patch("zerver.decorator.rate_limit") as rate_limit_mock:
522-
result = my_unlimited_view(request)
522+
result = my_unlimited_view(request, request.user)
523523

524524
self.assert_json_success(result)
525525
self.assertFalse(rate_limit_mock.called)
@@ -528,7 +528,7 @@ def my_view(request: HttpRequest, user_profile: UserProfile) -> HttpResponse:
528528
request.method = "POST"
529529
request.user = self.example_user("hamlet")
530530
with mock.patch("zerver.decorator.rate_limit") as rate_limit_mock:
531-
result = my_rate_limited_view(request)
531+
result = my_rate_limited_view(request, request.user)
532532

533533
# Don't assert json_success, since it'll be the rate_limit mock object
534534
self.assertTrue(rate_limit_mock.called)
@@ -630,10 +630,9 @@ def test_authenticated_rest_api_view_errors(self) -> None:
630630
class RateLimitTestCase(ZulipTestCase):
631631
def get_ratelimited_view(self) -> Callable[..., HttpResponse]:
632632
def f(req: Any) -> HttpResponse:
633+
rate_limit(req)
633634
return json_response(msg="some value")
634635

635-
f = rate_limit(f)
636-
637636
return f
638637

639638
def errors_disallowed(self) -> Any:

0 commit comments

Comments
 (0)