Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions msal/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,8 @@ def _main():
logging.error("Invalid input: %s", e)
except KeyboardInterrupt: # Useful for bailing out a stuck interactive flow
print("Aborted")
except Exception as e:
logging.error("Error: %s", e)

if __name__ == "__main__":
_main()
Expand Down
10 changes: 9 additions & 1 deletion msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,7 @@ def initiate_auth_code_flow(

:param str response_mode:
OPTIONAL. Specifies the method with which response parameters should be returned.
The default value is equivalent to ``query``, which is still secure enough in MSAL Python
The default value is equivalent to ``query``, which was still secure enough in MSAL Python
(because MSAL Python does not transfer tokens via query parameter in the first place).
For even better security, we recommend using the value ``form_post``.
In "form_post" mode, response parameters
Expand All @@ -973,6 +973,11 @@ def initiate_auth_code_flow(
`here <https://openid.net/specs/oauth-v2-multiple-response-types-1_0.html#ResponseModes>`__
and `here <https://openid.net/specs/oauth-v2-form-post-response-mode-1_0.html#FormPostResponseMode>`__

.. note::
You should configure your web framework to accept form_post responses instead of query responses.
While this parameter still works, it will be removed in a future version.
Using query-based response modes is less secure and should be avoided.

:return:
The auth code flow. It is a dict in this form::

Expand All @@ -991,6 +996,9 @@ def initiate_auth_code_flow(
3. and then relay this dict and subsequent auth response to
:func:`~acquire_token_by_auth_code_flow()`.
"""
# Note to maintainers: Do not emit warning for the use of response_mode here,
# because response_mode=form_post is still the recommended usage for MSAL Python 1.x.
# App developers making the right call shall not be disturbed by unactionable warnings.
client = _ClientWithCcsRoutingInfo(
{"authorization_endpoint": self.authority.authorization_endpoint},
self.client_id,
Expand Down
81 changes: 58 additions & 23 deletions msal/oauth2cli/authcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
It optionally opens a browser window to guide a human user to manually login.
After obtaining an auth code, the web server will automatically shut down.
"""
from collections import defaultdict
import logging
import os
import socket
Expand Down Expand Up @@ -109,29 +110,49 @@ def _printify(text):

class _AuthCodeHandler(BaseHTTPRequestHandler):
def do_GET(self):
# For flexibility, we choose to not check self.path matching redirect_uri
#assert self.path.startswith('/THE_PATH_REGISTERED_BY_THE_APP')
qs = parse_qs(urlparse(self.path).query)
if qs.get('code') or qs.get("error"): # So, it is an auth response
auth_response = _qs2kv(qs)
logger.debug("Got auth response: %s", auth_response)
if self.server.auth_state and self.server.auth_state != auth_response.get("state"):
# OAuth2 successful and error responses contain state when it was used
# https://www.rfc-editor.org/rfc/rfc6749#section-4.2.2.1
self._send_full_response("State mismatch") # Possibly an attack
else:
template = (self.server.success_template
if "code" in qs else self.server.error_template)
if _is_html(template.template):
safe_data = _escape(auth_response) # Foiling an XSS attack
else:
safe_data = auth_response
self._send_full_response(template.safe_substitute(**safe_data))
self.server.auth_response = auth_response # Set it now, after the response is likely sent
if qs:
# GET request with auth code or error - reject for security (form_post only)
self._send_full_response(
"response_mode=query is not supported for authentication responses. "
"This application operates in response_mode=form_post mode only.",
is_ok=False)
else:
# Other GET requests - show welcome page
self._send_full_response(self.server.welcome_page)
# NOTE: Don't do self.server.shutdown() here. It'll halt the server.

def do_POST(self):  # Handle form_post response where auth code is in body
    """Receive an auth response delivered in a form-encoded POST body.

    The body is read according to the ``Content-Length`` header and parsed
    as a query string. When it carries ``code`` or ``error`` it is handed to
    ``self._process_auth_response()``; otherwise, including when the header
    is malformed, negative or oversized, or the body is not valid UTF-8,
    an "Invalid POST request" error response is sent instead.
    """
    # For flexibility, we choose to not check self.path matching redirect_uri
    #assert self.path.startswith('/THE_PATH_REGISTERED_BY_THE_APP')
    max_body_size = 1024 * 1024  # Generous cap; the sender is untrusted
    try:
        content_length = int(self.headers.get('Content-Length', 0))
    except (TypeError, ValueError):  # Header is not a number
        content_length = -1  # Falls through to the invalid-request branch
    qs = {}
    if 0 <= content_length <= max_body_size:
        # A negative length is never passed to read(), because read(-1)
        # would read until EOF rather than a bounded number of bytes.
        try:
            qs = parse_qs(self.rfile.read(content_length).decode('utf-8'))
        except UnicodeDecodeError:
            qs = {}  # Undecodable body is treated as an invalid request
    if qs.get('code') or qs.get('error'):  # So, it is an auth response
        self._process_auth_response(_qs2kv(qs))
    else:
        self._send_full_response("Invalid POST request", is_ok=False)
    # NOTE: Don't do self.server.shutdown() here. It'll halt the server.

def _process_auth_response(self, auth_response):
    """Validate an auth response, send the result page, and record the response.

    :param dict auth_response:
        Key-value pairs of the auth response extracted from the request.

    If the server holds an expected ``auth_state`` and it differs from the
    response's ``state``, a "State mismatch" error page is sent and the
    response is not recorded. Otherwise the success template (when ``code``
    is present) or the error template is rendered and sent, and the response
    is stored in ``self.server.auth_response``.
    """
    logger.debug("Got auth response: %s", auth_response)
    if self.server.auth_state and self.server.auth_state != auth_response.get("state"):
        # OAuth2 successful and error responses contain state when it was used
        # https://www.rfc-editor.org/rfc/rfc6749#section-4.2.2.1
        self._send_full_response(  # Possibly an attack
            "State mismatch. Waiting for next response... or you may abort.", is_ok=False)
    else:
        template = (self.server.success_template
            if "code" in auth_response else self.server.error_template)
        if _is_html(template.template):
            safe_data = _escape(auth_response)  # Foiling an XSS attack
        else:
            safe_data = auth_response
        # So that missing keys will be empty string. The mapping must be passed
        # positionally: ``**`` unpacking would copy it into a plain dict and drop
        # the default factory, leaving e.g. a literal "$error_uri" in the page.
        filled_data = defaultdict(str, safe_data)
        self._send_full_response(template.safe_substitute(filled_data))
        self.server.auth_response = auth_response  # Set it now, after the response is likely sent

def _send_full_response(self, body, is_ok=True):
self.send_response(200 if is_ok else 400)
content_type = 'text/html' if _is_html(body) else 'text/plain'
Expand Down Expand Up @@ -215,6 +236,7 @@ def get_auth_response(self, timeout=None, **kwargs):

:param str auth_uri:
If provided, this function will try to open a local browser.
Starting from 2026, the built-in http server will require response_mode=form_post.
:param int timeout: In seconds. None means wait indefinitely.
:param str state:
You may provide the state you used in auth_uri,
Expand Down Expand Up @@ -287,8 +309,20 @@ def _get_auth_response(self, result, auth_uri=None, timeout=None, state=None,
welcome_uri = "http://localhost:{p}".format(p=self.get_port())
abort_uri = "{loc}?error=abort".format(loc=welcome_uri)
logger.debug("Abort by visit %s", abort_uri)
self._server.welcome_page = Template(welcome_template or "").safe_substitute(
auth_uri=auth_uri, abort_uri=abort_uri)

if auth_uri:
# Note to maintainers:
# Do not enforce response_mode=form_post by secretly hardcoding it here.
# Just validate it here, so we won't surprise caller by changing their auth_uri behind the scene.
params = parse_qs(urlparse(auth_uri).query)
assert params.get('response_mode', [None])[0] == 'form_post', (
"The built-in http server supports HTTP POST only. "
"The auth_uri must be built with response_mode=form_post")

self._server.welcome_page = Template(
welcome_template or
"<a href='$auth_uri'>Sign In</a>, or <a href='$abort_uri'>Abort</a>"
).safe_substitute(auth_uri=auth_uri, abort_uri=abort_uri)
if auth_uri: # Now attempt to open a local browser to visit it
_uri = welcome_uri if welcome_template else auth_uri
logger.info("Open a browser on this device to visit: %s" % _uri)
Expand Down Expand Up @@ -317,8 +351,11 @@ def _get_auth_response(self, result, auth_uri=None, timeout=None, state=None,
auth_uri_callback(_uri)

self._server.success_template = Template(success_template or
"Authentication completed. You can close this window now.")
"Authentication complete. You can return to the application. Please close this browser tab.")
self._server.error_template = Template(error_template or
# Do NOT invent new placeholders in this template. Just use standard keys defined in OAuth2 RFC.
# Otherwise there is no obvious canonical way for caller to know what placeholders are supported.
# Besides, we have been using these standard keys for years. Changing now would break backward compatibility.
"Authentication failed. $error: $error_description. ($error_uri)")

self._server.timeout = timeout # Otherwise its handle_timeout() won't work
Expand Down Expand Up @@ -370,8 +407,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
)
print(json.dumps(receiver.get_auth_response(
auth_uri=flow["auth_uri"],
welcome_template=
"<a href='$auth_uri'>Sign In</a>, or <a href='$abort_uri'>Abort</a",
error_template="<html>Oh no. $error</html>",
success_template="Oh yeah. Got $code",
timeout=args.timeout,
Expand Down
12 changes: 11 additions & 1 deletion msal/oauth2cli/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,12 +396,21 @@ def obtain_token_by_device_flow(self,

def _build_auth_request_uri(
self,
response_type, redirect_uri=None, scope=None, state=None, **kwargs):
response_type,
*,
redirect_uri=None, scope=None, state=None, response_mode=None,
**kwargs):
if "authorization_endpoint" not in self.configuration:
raise ValueError("authorization_endpoint not found in configuration")
authorization_endpoint = self.configuration["authorization_endpoint"]
if response_mode != 'form_post':
warnings.warn(
"response_mode='form_post' is recommended for better security. "
"See https://www.rfc-editor.org/rfc/rfc9700.html#section-4.3.1"
)
params = self._build_auth_request_params(
response_type, redirect_uri=redirect_uri, scope=scope, state=state,
response_mode=response_mode,
**kwargs)
sep = '&' if '?' in authorization_endpoint else '?'
return "%s%s%s" % (authorization_endpoint, sep, urlencode(params))
Expand Down Expand Up @@ -669,6 +678,7 @@ def _obtain_token_by_browser(
flow = self.initiate_auth_code_flow(
redirect_uri=redirect_uri,
scope=_scope_set(scope) | _scope_set(extra_scope_to_consent),
response_mode='form_post', # The auth_code_receiver has been changed to require it
**(auth_params or {}))
auth_response = auth_code_receiver.get_auth_response(
auth_uri=flow["auth_uri"],
Expand Down
72 changes: 68 additions & 4 deletions tests/test_authcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,81 @@ def test_no_two_concurrent_receivers_can_listen_on_same_port(self):
pass

def test_template_should_escape_input(self):
"""Test that HTML in error response is properly escaped"""
with AuthCodeReceiver() as receiver:
receiver._scheduled_actions = [( # Injection happens here when the port is known
1, # Delay it until the receiver is activated by get_auth_response()
lambda: self.assertEqual(
"<html>&lt;tag&gt;foo&lt;/tag&gt;</html>",
requests.get("http://localhost:{}?error=<tag>foo</tag>".format(
receiver.get_port())).text,
"Unsafe data in HTML should be escaped",
"<html>&lt;script&gt;alert(&#x27;xss&#x27;);&lt;/script&gt;</html>",
requests.post(
"http://localhost:{}".format(receiver.get_port()),
data={"error": "<script>alert('xss');</script>"},
).text,
))]
receiver.get_auth_response( # Starts server and hang until timeout
timeout=3,
error_template="<html>$error</html>",
)

def test_get_request_with_auth_code_is_rejected(self):
    """An auth code arriving in a GET query string must be answered with 400 and ignored."""
    expected_state = "test_state_67890"
    with AuthCodeReceiver() as receiver:

        def send_code_via_get():
            # Scheduled action: the port is looked up only when this runs
            reply = requests.get(
                "http://localhost:{}".format(receiver.get_port()),
                params={"code": "test_auth_code_12345", "state": expected_state},
            )
            self.assertEqual(400, reply.status_code)

        receiver._scheduled_actions = [(1, send_code_via_get)]
        outcome = receiver.get_auth_response(timeout=3, state=expected_state)
        self.assertIsNone(outcome, "Should not receive auth response via GET")

def test_post_request_with_auth_code(self):
    """A POSTed code and state (form_post response mode) are returned by the receiver."""
    form = {"code": "test_auth_code_12345", "state": "test_state_67890"}
    with AuthCodeReceiver() as receiver:

        def post_auth_code():
            # Scheduled action: the port is looked up only when this runs
            return requests.post(
                "http://localhost:{}".format(receiver.get_port()),
                data=form,
            )

        receiver._scheduled_actions = [(1, post_auth_code)]
        result = receiver.get_auth_response(timeout=3, state=form["state"])
        self.assertIsNotNone(result, "Should receive auth response via POST")
        self.assertEqual(result.get("code"), form["code"])
        self.assertEqual(result.get("state"), form["state"])

def test_post_request_with_error(self):
    """A POSTed error and error_description are returned by the receiver."""
    form = {"error": "access_denied", "error_description": "User denied access"}
    with AuthCodeReceiver() as receiver:

        def post_error():
            # Scheduled action: the port is looked up only when this runs
            return requests.post(
                "http://localhost:{}".format(receiver.get_port()),
                data=form,
            )

        receiver._scheduled_actions = [(1, post_error)]
        result = receiver.get_auth_response(timeout=3)
        self.assertIsNotNone(result, "Should receive auth response via POST")
        self.assertEqual(result.get("error"), form["error"])
        self.assertEqual(result.get("error_description"), form["error_description"])

def test_post_request_state_mismatch(self):
    """A POST whose state differs from the expected one yields no auth response."""
    with AuthCodeReceiver() as receiver:

        def post_with_wrong_state():
            # Scheduled action: the port is looked up only when this runs
            return requests.post(
                "http://localhost:{}".format(receiver.get_port()),
                data={"code": "test_code", "state": "wrong_state"},
            )

        receiver._scheduled_actions = [(1, post_with_wrong_state)]
        result = receiver.get_auth_response(timeout=3, state="expected_state")
        self.assertIsNone(result, "Should not receive auth response due to state mismatch")
Loading
Loading