diff --git a/docs/examples.md b/docs/examples.md index 1ee072e..8f1fad0 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -8,7 +8,7 @@ The tests show a lot of different use cases that are not all covered here. ```python from starlette.applications import Starlette - from slowapi import Limiter, _rate_limit_exceeded_handler + from slowapi import Limiter, rate_limit_exceeded_handler from slowapi.util import get_remote_address from slowapi.middleware import SlowAPIMiddleware from slowapi.errors import RateLimitExceeded @@ -16,7 +16,7 @@ The tests show a lot of different use cases that are not all covered here. limiter = Limiter(key_func=get_remote_address, default_limits=["1/minute"]) app = Starlette() app.state.limiter = limiter - app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) + app.add_exception_handler(RateLimitExceeded, rate_limit_exceeded_handler) app.add_middleware(SlowAPIMiddleware) # this will be limited by the default_limits diff --git a/docs/index.md b/docs/index.md index 5ec4219..0e92c5b 100644 --- a/docs/index.md +++ b/docs/index.md @@ -20,14 +20,14 @@ $ pip install slowapi from starlette.applications import Starlette from starlette.responses import PlainTextResponse from starlette.requests import Request - from slowapi import Limiter, _rate_limit_exceeded_handler + from slowapi import Limiter, rate_limit_exceeded_handler from slowapi.util import get_remote_address from slowapi.errors import RateLimitExceeded limiter = Limiter(key_func=get_remote_address) app = Starlette() app.state.limiter = limiter - app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) + app.add_exception_handler(RateLimitExceeded, rate_limit_exceeded_handler) @limiter.limit("5/minute") async def homepage(request: Request): @@ -42,14 +42,14 @@ The above app will have a route `t1` that will accept up to 5 requests per minut ```python from fastapi import FastAPI, Request, Response - from slowapi import Limiter, _rate_limit_exceeded_handler + from slowapi import Limiter, rate_limit_exceeded_handler from slowapi.util import get_remote_address from slowapi.errors import RateLimitExceeded limiter = Limiter(key_func=get_remote_address) app = FastAPI() app.state.limiter = limiter - app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) + app.add_exception_handler(RateLimitExceeded, rate_limit_exceeded_handler) # Note: the route decorator must be above the limit decorator, not below it @app.get("/home") diff --git a/slowapi/__init__.py b/slowapi/__init__.py index cfa284e..2a69f0b 100644 --- a/slowapi/__init__.py +++ b/slowapi/__init__.py @@ -1,3 +1,7 @@ -from .extension import Limiter, _rate_limit_exceeded_handler +from .extension import ( + Limiter, + rate_limit_exceeded_handler, + _rate_limit_exceeded_handler, +) -__all__ = ["Limiter", "_rate_limit_exceeded_handler"] +__all__ = ["Limiter", "rate_limit_exceeded_handler", "_rate_limit_exceeded_handler"] diff --git a/slowapi/extension.py b/slowapi/extension.py index 050f882..f856e50 100644 --- a/slowapi/extension.py +++ b/slowapi/extension.py @@ -87,6 +87,15 @@ def _rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded) -> Re return response +def rate_limit_exceeded_handler(request: Request, exc: Exception) -> Response: + """ + Handle rate limit exceeded exceptions. + """ + if isinstance(exc, RateLimitExceeded): + return _rate_limit_exceeded_handler(request, exc) + raise exc + + class Limiter: """ Initializes the slowapi rate limiter.