Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).

## [Unreleased] - 2025-08-15
### Fixed
- Fix public routes being protected when passing `url_base_pathname` or `routes_pathname_prefix` to app
- Fix OIDC redirects after login and logout when passing `url_base_pathname` or `routes_pathname_prefix` to app

## [2.3.0] - 2024-03-18
### Added
- OIDCAuth allows to authenticate via OIDC
Expand Down
21 changes: 13 additions & 8 deletions dash_auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,15 @@
from flask import request

from .public_routes import (
add_public_routes, get_public_callbacks, get_public_routes
add_public_routes,
get_public_callbacks,
get_public_routes,
)


class Auth(ABC):
def __init__(
self,
app: Dash,
public_routes: Optional[list] = None,
**obsolete
self, app: Dash, public_routes: Optional[list] = None, **obsolete
):
"""Auth base class for authentication in Dash.

Expand Down Expand Up @@ -47,14 +46,19 @@ def _protect(self):

@server.before_request
def before_request_auth():

public_routes = get_public_routes(self.app)
public_callbacks = get_public_callbacks(self.app)
url_base = (
self.app.config.get("url_base_pathname", "")
or self.app.config.get("requests_pathname_prefix", "")
or self.app.config.get("routes_pathname_prefix", "")
)
# Handle Dash's callback route:
# * Check whether the callback is marked as public
# * Check whether the callback is performed on route change in
# which case the path should be checked against the public routes
if request.path == "/_dash-update-component":
callback_path = f"{url_base.rstrip('/')}/_dash-update-component"
if request.path == callback_path:
body = request.get_json()

# Check whether the callback is marked as public
Expand All @@ -66,7 +70,8 @@ def before_request_auth():
# should be checked against the public routes
pathname = next(
(
inp.get("value") for inp in body["inputs"]
inp.get("value")
for inp in body["inputs"]
if isinstance(inp, dict)
and inp.get("property") == "pathname"
),
Expand Down
36 changes: 22 additions & 14 deletions dash_auth/oidc_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

if TYPE_CHECKING:
from authlib.integrations.flask_client.apps import (
FlaskOAuth1App, FlaskOAuth2App
FlaskOAuth1App,
FlaskOAuth2App,
)


Expand Down Expand Up @@ -175,28 +176,24 @@ def register_provider(self, idp_name: str, **kwargs):
)
client_kwargs = kwargs.pop("client_kwargs", {})
client_kwargs.setdefault("scope", "openid email")
self.oauth.register(
idp_name, client_kwargs=client_kwargs, **kwargs
)
self.oauth.register(idp_name, client_kwargs=client_kwargs, **kwargs)

def get_oauth_client(self, idp: str):
"""Get the OAuth client."""
if idp not in self.oauth._registry:
raise ValueError(f"'{idp}' is not a valid registered idp")

client: Union[FlaskOAuth1App, FlaskOAuth2App] = (
self.oauth.create_client(idp)
)
client: Union[
FlaskOAuth1App, FlaskOAuth2App
] = self.oauth.create_client(idp)
return client

def get_oauth_kwargs(self, idp: str):
"""Get the OAuth kwargs."""
if idp not in self.oauth._registry:
raise ValueError(f"'{idp}' is not a valid registered idp")

kwargs: dict = (
self.oauth._registry[idp][1]
)
kwargs: dict = self.oauth._registry[idp][1]
return kwargs

def _create_redirect_uri(self, idp: str):
Expand Down Expand Up @@ -242,14 +239,21 @@ def login_request(self, idp: str = None):
def logout(self): # pylint: disable=C0116
"""Logout the user."""
session.clear()
base_url = self.app.config.get("url_base_pathname") or "/"
page = self.logout_page or f"""
base_url = (
self.app.config.get("url_base_pathname")
or self.app.config.get("routes_pathname_prefix")
or "/"
)
page = (
self.logout_page
or f"""
<div style="display: flex; flex-direction: column;
gap: 0.75rem; padding: 3rem 5rem;">
<div>Logged out successfully</div>
<div><a href="{base_url}">Go back</a></div>
</div>
"""
)
return page

def callback(self, idp: str): # pylint: disable=C0116
Expand All @@ -269,7 +273,7 @@ def callback(self, idp: str): # pylint: disable=C0116
user = token.get("userinfo")
return self.after_logged_in(user, idp, token)

def after_logged_in(self, user: Optional[dict], idp: str, token: dict):
def after_logged_in(self, user: Optional[dict], idp: str, token: dict):
"""
Post-login actions after successful OIDC authentication.
For example, allows to pass custom attributes to the user session:
Expand All @@ -288,7 +292,11 @@ def after_logged_in(self, user, idp, token):
if self.log_signins:
logging.info("User %s is logging in.", user.get("email"))

return redirect(self.app.config.get("url_base_pathname") or "/")
return redirect(
self.app.config.get("url_base_pathname")
or self.app.config.get("routes_pathname_prefix")
or "/"
)

def is_authorized(self): # pylint: disable=C0116
"""Check whether ther user is authenticated."""
Expand Down
7 changes: 7 additions & 0 deletions dash_auth/public_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,18 @@ def add_public_routes(app: Dash, routes: list):
"""

public_routes = get_public_routes(app)
url_base = (
app.config.get("url_base_pathname", "")
or app.config.get("requests_pathname_prefix", "")
or app.config.get("routes_pathname_prefix", "")
)

if not public_routes.map._rules:
routes = BASE_PUBLIC_ROUTES + routes

for route in routes:
if url_base and not route.startswith(url_base):
route = url_base.rstrip("/") + route
public_routes.map.add(Rule(route))

app.server.config[PUBLIC_ROUTES] = public_routes
Expand Down
42 changes: 35 additions & 7 deletions tests/test_basic_auth_integration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dash import Dash, Input, Output, dcc, html
import requests

import pytest
from dash_auth import BasicAuth, add_public_routes, protected


Expand All @@ -15,8 +15,17 @@
}


def test_ba001_basic_auth_login_flow(dash_br, dash_thread_server):
app = Dash(__name__)
@pytest.mark.parametrize(
"kwargs",
[
{},
{"url_base_pathname": "/app/"},
{"routes_pathname_prefix": "/app/"},
{"routes_pathname_prefix": "/app/", "requests_pathname_prefix": "/app/"},
],
)
def test_ba001_basic_auth_login_flow(dash_br, dash_thread_server, kwargs):
app = Dash(__name__, **kwargs)
app.layout = html.Div([
dcc.Input(id="input", value="initial value"),
html.Div(id="output")
Expand All @@ -30,7 +39,12 @@ def update_output(new_value):
add_public_routes(app, ["/user/<user_id>/public"])

dash_thread_server(app)
base_url = dash_thread_server.url
path_prefix = (
app.config.get("url_base_pathname", "")
or app.config.get("requests_pathname_prefix", "")
or app.config.get("routes_pathname_prefix", "")
)
base_url = dash_thread_server.url + path_prefix

def test_failed_views(url):
assert requests.get(url).status_code == 401
Expand Down Expand Up @@ -60,8 +74,17 @@ def test_successful_views(url):
dash_br.wait_for_text_to_equal("#output", "initial value")


def test_ba002_basic_auth_groups(dash_br, dash_thread_server):
app = Dash(__name__)
@pytest.mark.parametrize(
"kwargs",
[
{},
{"url_base_pathname": "/app/"},
{"routes_pathname_prefix": "/app/"},
{"routes_pathname_prefix": "/app/", "requests_pathname_prefix": "/app/"},
],
)
def test_ba002_basic_auth_groups(dash_br, dash_thread_server, kwargs):
app = Dash(__name__, **kwargs)
app.layout = html.Div([
dcc.Input(id="input", value="initial value"),
html.Div(id="output")
Expand Down Expand Up @@ -89,7 +112,12 @@ def update_output(new_value):
)

dash_thread_server(app)
base_url = dash_thread_server.url
path_prefix = (
app.config.get("url_base_pathname", "")
or app.config.get("requests_pathname_prefix", "")
or app.config.get("routes_pathname_prefix", "")
)
base_url = dash_thread_server.url + path_prefix

for user, password in TEST_USERS["valid"]:
# login using the URL instead of the alert popup
Expand Down
74 changes: 58 additions & 16 deletions tests/test_oidc_auth.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from unittest.mock import patch

import requests
Expand All @@ -9,6 +8,7 @@
protected_callback,
OIDCAuth,
)
import pytest


def valid_authorize_redirect(_, redirect_uri, *args, **kwargs):
Expand All @@ -27,10 +27,19 @@ def valid_authorize_access_token(*args, **kwargs):
}


@pytest.mark.parametrize(
"kwargs",
[
{},
{"url_base_pathname": "/app/"},
{"routes_pathname_prefix": "/app/"},
{"routes_pathname_prefix": "/app/", "requests_pathname_prefix": "/app/"},
],
)
@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect", valid_authorize_redirect)
@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_access_token", valid_authorize_access_token)
def test_oa001_oidc_auth_login_flow_success(dash_br, dash_thread_server):
app = Dash(__name__)
def test_oa001_oidc_auth_login_flow_success(dash_br, dash_thread_server, kwargs):
app = Dash(__name__, **kwargs)
app.layout = html.Div([
dcc.Input(id="input", value="initial value"),
html.Div(id="output1"),
Expand Down Expand Up @@ -89,7 +98,12 @@ def update_output5(new_value):
server_metadata_url="https://idp.com/oidc/2/.well-known/openid-configuration",
)
dash_thread_server(app)
base_url = dash_thread_server.url
path_prefix = (
app.config.get("url_base_pathname", "")
or app.config.get("requests_pathname_prefix", "")
or app.config.get("routes_pathname_prefix", "")
)
base_url = dash_thread_server.url + path_prefix

assert requests.get(base_url).status_code == 200

Expand All @@ -101,9 +115,18 @@ def update_output5(new_value):
dash_br.wait_for_text_to_equal("#output5", "initial value")


@pytest.mark.parametrize(
"kwargs",
[
{},
{"url_base_pathname": "/app/"},
{"routes_pathname_prefix": "/app/"},
{"routes_pathname_prefix": "/app/", "requests_pathname_prefix": "/app/"},
],
)
@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect", invalid_authorize_redirect)
def test_oa002_oidc_auth_login_fail(dash_thread_server):
app = Dash(__name__)
def test_oa002_oidc_auth_login_fail(dash_thread_server, kwargs):
app = Dash(__name__, **kwargs)
app.layout = html.Div([
dcc.Input(id="input", value="initial value"),
html.Div(id="output")
Expand All @@ -122,7 +145,12 @@ def update_output(new_value):
server_metadata_url="https://idp.com/oidc/2/.well-known/openid-configuration",
)
dash_thread_server(app)
base_url = dash_thread_server.url
path_prefix = (
app.config.get("url_base_pathname", "")
or app.config.get("requests_pathname_prefix", "")
or app.config.get("routes_pathname_prefix", "")
)
base_url = dash_thread_server.url + path_prefix

def test_unauthorized(url):
r = requests.get(url)
Expand All @@ -133,13 +161,22 @@ def test_authorized(url):
assert requests.get(url).status_code == 200

test_unauthorized(base_url)
test_authorized(os.path.join(base_url, "public"))
test_authorized("/".join([base_url, "public"]))


@pytest.mark.parametrize(
"kwargs",
[
{},
{"url_base_pathname": "/app/"},
{"routes_pathname_prefix": "/app/"},
{"routes_pathname_prefix": "/app/", "requests_pathname_prefix": "/app/"},
],
)
@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_redirect", valid_authorize_redirect)
@patch("authlib.integrations.flask_client.apps.FlaskOAuth2App.authorize_access_token", valid_authorize_access_token)
def test_oa003_oidc_auth_login_several_idp(dash_br, dash_thread_server):
app = Dash(__name__)
def test_oa003_oidc_auth_login_several_idp(dash_br, dash_thread_server, kwargs):
app = Dash(__name__, **kwargs)
app.layout = html.Div([
dcc.Input(id="input", value="initial value"),
html.Div(id="output1"),
Expand Down Expand Up @@ -168,21 +205,26 @@ def update_output1(new_value):
)

dash_thread_server(app)
path_prefix = (
app.config.get("url_base_pathname", "")
or app.config.get("requests_pathname_prefix", "")
or app.config.get("routes_pathname_prefix", "")
)
base_url = dash_thread_server.url

base_url_prefix = (base_url + path_prefix).strip("/")
assert requests.get(base_url).status_code == 400

# Login with IDP1
assert requests.get(os.path.join(base_url, "oidc/idp1/login")).status_code == 200
assert requests.get(base_url + "/oidc/idp1/login").status_code == 200

# Logout
assert requests.get(os.path.join(base_url, "oidc/logout")).status_code == 200
assert requests.get(base_url + "/oidc/logout").status_code == 200

assert requests.get(base_url).status_code == 400

# Login with IDP2
assert requests.get(os.path.join(base_url, "oidc/idp2/login")).status_code == 200
assert requests.get(base_url + "/oidc/idp2/login").status_code == 200

dash_br.driver.get(os.path.join(base_url, "oidc/idp2/login"))
dash_br.driver.get(base_url)
dash_br.driver.get(base_url + "/oidc/idp2/login")
dash_br.driver.get(base_url_prefix)
dash_br.wait_for_text_to_equal("#output1", "initial value")