From dc40601975173816e3b9c28bd97fd7259116401f Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 23 Jan 2026 13:32:16 +0000 Subject: [PATCH 1/5] Initial plan From 51d881f60da639b1f7dc65cef0a5ece1f8c2a3cc Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 23 Jan 2026 13:36:03 +0000 Subject: [PATCH 2/5] Add custom state parameter support for CSRF protection Co-authored-by: hsluoyz <3787410+hsluoyz@users.noreply.github.com> --- README.md | 59 +++++++++++++++++++++++++++++++++++ src/casdoor/async_main.py | 38 +++++++++++++++++++++- src/casdoor/main.py | 45 ++++++++++++++++++++++++-- src/tests/test_async_oauth.py | 53 +++++++++++++++++++++++++++++++ src/tests/test_oauth.py | 53 +++++++++++++++++++++++++++++++ 5 files changed, 245 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index fb9840e..abe0315 100644 --- a/README.md +++ b/README.md @@ -109,6 +109,65 @@ decoded_msg = sdk.parse_jwt_token(access_token) # or sdk.parse_jwt_token(access_ `decoded_msg` is the JSON data decoded from the `access_token`, which contains user info and other useful stuff. +## CSRF Protection with State Parameter + +For enhanced security, you should use the `state` parameter to protect against CSRF attacks. The SDK provides helper methods for state generation and validation: + +### Using Custom State Parameter + +```python +from casdoor import CasdoorSDK + +# Initialize SDK +sdk = CasdoorSDK( + endpoint, + client_id, + client_secret, + certificate, + org_name, + application_name, +) + +# Step 1: Generate a state token before redirecting to Casdoor +state = sdk.generate_state_token() + +# Store the state in your session (implementation depends on your framework) +# For example, using Flask: +# session['oauth_state'] = state + +# Step 2: Generate auth URL with custom state +auth_url = sdk.get_auth_link(redirect_uri="http://localhost:8080/callback", state=state) + +# Redirect user to auth_url +# ... + +# Step 3: In your callback handler, verify the state +def callback_handler(): + # Get the state from the callback parameters + received_state = request.args.get('state') + + # Get the expected state from session + expected_state = session.get('oauth_state') + + # Verify state to prevent CSRF attacks + if not sdk.verify_state_token(received_state, expected_state): + # State validation failed - possible CSRF attack + raise ValueError("Invalid state parameter") + + # State is valid, proceed with token exchange + code = request.args.get('code') + token = sdk.get_oauth_token(code=code) + access_token = token.get("access_token") + decoded_msg = sdk.parse_jwt_token(access_token) + + # Clear the state from session + session.pop('oauth_state', None) + + return decoded_msg +``` + +**Note:** If you don't provide a custom `state` parameter to `get_auth_link()`, it will default to the `application_name` for backward compatibility. However, for production applications, it's strongly recommended to use a randomly generated state token for CSRF protection. + ## Step4. Interact with the users casdoor-python-sdk support basic user operations, like: diff --git a/src/casdoor/async_main.py b/src/casdoor/async_main.py index 27c070c..f1fc7bb 100644 --- a/src/casdoor/async_main.py +++ b/src/casdoor/async_main.py @@ -14,6 +14,7 @@ import base64 import json +import secrets from typing import Dict, List, Optional import aiohttp @@ -138,14 +139,25 @@ async def get_auth_link( redirect_uri: str, response_type: str = "code", scope: str = "read", + state: Optional[str] = None, ) -> str: + """ + Get authorization link for OAuth flow. + + :param redirect_uri: The redirect URI for the OAuth callback + :param response_type: OAuth response type (default: "code") + :param scope: OAuth scope (default: "read") + :param state: Custom state parameter for CSRF protection. + If not provided, defaults to application_name for backward compatibility. + :return: Authorization URL + """ url = self.front_endpoint + "/login/oauth/authorize" params = { "client_id": self.client_id, "response_type": response_type, "redirect_uri": redirect_uri, "scope": scope, - "state": self.application_name, + "state": state if state is not None else self.application_name, } return str(URL(url).with_query(params)) @@ -566,3 +578,27 @@ async def get_user_roles(self, username: str) -> List[Dict]: user_roles.append(role) return user_roles + + @staticmethod + def generate_state_token(length: int = 32) -> str: + """ + Generate a cryptographically secure random state token for CSRF protection. + + :param length: Length of the state token in bytes (default: 32) + :return: A random state token as a hex string + """ + return secrets.token_hex(length) + + @staticmethod + def verify_state_token(received_state: str, expected_state: str) -> bool: + """ + Verify that the received state token matches the expected state token. + Uses constant-time comparison to prevent timing attacks. + + :param received_state: The state parameter received from OAuth callback + :param expected_state: The expected state token (stored in session) + :return: True if states match, False otherwise + """ + if not received_state or not expected_state: + return False + return secrets.compare_digest(received_state, expected_state) diff --git a/src/casdoor/main.py b/src/casdoor/main.py index a87c5c5..7c41baa 100644 --- a/src/casdoor/main.py +++ b/src/casdoor/main.py @@ -13,6 +13,7 @@ # limitations under the License. import json +import secrets from typing import Dict, List, Optional import jwt @@ -133,14 +134,30 @@ def certification(self) -> bytes: raise TypeError("certificate field must be str type") return self.certificate.encode("utf-8") - def get_auth_link(self, redirect_uri: str, response_type: str = "code", scope: str = "read"): + def get_auth_link( + self, + redirect_uri: str, + response_type: str = "code", + scope: str = "read", + state: Optional[str] = None, + ): + """ + Get authorization link for OAuth flow. + + :param redirect_uri: The redirect URI for the OAuth callback + :param response_type: OAuth response type (default: "code") + :param scope: OAuth scope (default: "read") + :param state: Custom state parameter for CSRF protection. + If not provided, defaults to application_name for backward compatibility. + :return: Authorization URL + """ url = self.front_endpoint + "/login/oauth/authorize" params = { "client_id": self.client_id, "response_type": response_type, "redirect_uri": redirect_uri, "scope": scope, - "state": self.application_name, + "state": state if state is not None else self.application_name, } r = requests.request("", url, params=params) return r.url @@ -402,3 +419,27 @@ def batch_enforce( raise ValueError(error_str) return enforce_results + + @staticmethod + def generate_state_token(length: int = 32) -> str: + """ + Generate a cryptographically secure random state token for CSRF protection. + + :param length: Length of the state token in bytes (default: 32) + :return: A random state token as a hex string + """ + return secrets.token_hex(length) + + @staticmethod + def verify_state_token(received_state: str, expected_state: str) -> bool: + """ + Verify that the received state token matches the expected state token. + Uses constant-time comparison to prevent timing attacks. + + :param received_state: The state parameter received from OAuth callback + :param expected_state: The expected state token (stored in session) + :return: True if states match, False otherwise + """ + if not received_state or not expected_state: + return False + return secrets.compare_digest(received_state, expected_state) diff --git a/src/tests/test_async_oauth.py b/src/tests/test_async_oauth.py index ec5f93d..3e01558 100644 --- a/src/tests/test_async_oauth.py +++ b/src/tests/test_async_oauth.py @@ -214,3 +214,56 @@ async def test_auth_link(self): response, f"{sdk.front_endpoint}/login/oauth/authorize?client_id={sdk.client_id}&response_type=code&redirect_uri={redirect_uri}&scope=read&state={sdk.application_name}", ) + + def test_generate_state_token(self): + sdk = self.get_sdk() + state1 = sdk.generate_state_token() + state2 = sdk.generate_state_token() + + # Check that state tokens are strings + self.assertIsInstance(state1, str) + self.assertIsInstance(state2, str) + + # Check that state tokens are not empty + self.assertGreater(len(state1), 0) + self.assertGreater(len(state2), 0) + + # Check that each generated token is unique + self.assertNotEqual(state1, state2) + + # Default length is 32 bytes = 64 hex characters + self.assertEqual(len(state1), 64) + + # Test custom length + state_custom = sdk.generate_state_token(length=16) + self.assertEqual(len(state_custom), 32) # 16 bytes = 32 hex chars + + def test_verify_state_token(self): + sdk = self.get_sdk() + state = sdk.generate_state_token() + + # Valid state should match + self.assertTrue(sdk.verify_state_token(state, state)) + + # Different states should not match + state2 = sdk.generate_state_token() + self.assertFalse(sdk.verify_state_token(state, state2)) + + # Empty or None states should not match + self.assertFalse(sdk.verify_state_token("", state)) + self.assertFalse(sdk.verify_state_token(state, "")) + self.assertFalse(sdk.verify_state_token(None, state)) + self.assertFalse(sdk.verify_state_token(state, None)) + + async def test_get_auth_link_with_custom_state(self): + sdk = self.get_sdk() + custom_state = sdk.generate_state_token() + redirect_uri = "http://localhost:9000/callback" + + # Test with custom state + auth_url = await sdk.get_auth_link(redirect_uri=redirect_uri, state=custom_state) + self.assertIn("state=" + custom_state, auth_url) + + # Test with default state (application_name) + auth_url_default = await sdk.get_auth_link(redirect_uri=redirect_uri) + self.assertIn("state=" + sdk.application_name, auth_url_default) diff --git a/src/tests/test_oauth.py b/src/tests/test_oauth.py index 1c54fe3..2ea0c5a 100644 --- a/src/tests/test_oauth.py +++ b/src/tests/test_oauth.py @@ -224,3 +224,56 @@ def test_modify_user(self): self.assertIn("status", response) self.assertIsInstance(response, dict) + + def test_generate_state_token(self): + sdk = self.get_sdk() + state1 = sdk.generate_state_token() + state2 = sdk.generate_state_token() + + # Check that state tokens are strings + self.assertIsInstance(state1, str) + self.assertIsInstance(state2, str) + + # Check that state tokens are not empty + self.assertGreater(len(state1), 0) + self.assertGreater(len(state2), 0) + + # Check that each generated token is unique + self.assertNotEqual(state1, state2) + + # Default length is 32 bytes = 64 hex characters + self.assertEqual(len(state1), 64) + + # Test custom length + state_custom = sdk.generate_state_token(length=16) + self.assertEqual(len(state_custom), 32) # 16 bytes = 32 hex chars + + def test_verify_state_token(self): + sdk = self.get_sdk() + state = sdk.generate_state_token() + + # Valid state should match + self.assertTrue(sdk.verify_state_token(state, state)) + + # Different states should not match + state2 = sdk.generate_state_token() + self.assertFalse(sdk.verify_state_token(state, state2)) + + # Empty or None states should not match + self.assertFalse(sdk.verify_state_token("", state)) + self.assertFalse(sdk.verify_state_token(state, "")) + self.assertFalse(sdk.verify_state_token(None, state)) + self.assertFalse(sdk.verify_state_token(state, None)) + + def test_get_auth_link_with_custom_state(self): + sdk = self.get_sdk() + custom_state = sdk.generate_state_token() + redirect_uri = "http://localhost:8080/callback" + + # Test with custom state + auth_url = sdk.get_auth_link(redirect_uri=redirect_uri, state=custom_state) + self.assertIn("state=" + custom_state, auth_url) + + # Test with default state (application_name) + auth_url_default = sdk.get_auth_link(redirect_uri=redirect_uri) + self.assertIn("state=" + sdk.application_name, auth_url_default) From cb09fcb34026e986cec884a4245803d934a35a00 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 23 Jan 2026 13:38:04 +0000 Subject: [PATCH 3/5] Format code with black Co-authored-by: hsluoyz <3787410+hsluoyz@users.noreply.github.com> --- src/casdoor/main.py | 4 ++-- src/tests/test_async_oauth.py | 20 ++++++++++---------- src/tests/test_oauth.py | 20 ++++++++++---------- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/src/casdoor/main.py b/src/casdoor/main.py index 7c41baa..351cb20 100644 --- a/src/casdoor/main.py +++ b/src/casdoor/main.py @@ -15,6 +15,7 @@ import json import secrets from typing import Dict, List, Optional +from urllib.parse import urlencode import jwt import requests @@ -159,8 +160,7 @@ def get_auth_link( "scope": scope, "state": state if state is not None else self.application_name, } - r = requests.request("", url, params=params) - return r.url + return url + "?" + urlencode(params) def get_oauth_token( self, code: Optional[str] = None, username: Optional[str] = None, password: Optional[str] = None diff --git a/src/tests/test_async_oauth.py b/src/tests/test_async_oauth.py index 3e01558..3614c87 100644 --- a/src/tests/test_async_oauth.py +++ b/src/tests/test_async_oauth.py @@ -219,21 +219,21 @@ def test_generate_state_token(self): sdk = self.get_sdk() state1 = sdk.generate_state_token() state2 = sdk.generate_state_token() - + # Check that state tokens are strings self.assertIsInstance(state1, str) self.assertIsInstance(state2, str) - + # Check that state tokens are not empty self.assertGreater(len(state1), 0) self.assertGreater(len(state2), 0) - + # Check that each generated token is unique self.assertNotEqual(state1, state2) - + # Default length is 32 bytes = 64 hex characters self.assertEqual(len(state1), 64) - + # Test custom length state_custom = sdk.generate_state_token(length=16) self.assertEqual(len(state_custom), 32) # 16 bytes = 32 hex chars @@ -241,14 +241,14 @@ def test_generate_state_token(self): def test_verify_state_token(self): sdk = self.get_sdk() state = sdk.generate_state_token() - + # Valid state should match self.assertTrue(sdk.verify_state_token(state, state)) - + # Different states should not match state2 = sdk.generate_state_token() self.assertFalse(sdk.verify_state_token(state, state2)) - + # Empty or None states should not match self.assertFalse(sdk.verify_state_token("", state)) self.assertFalse(sdk.verify_state_token(state, "")) @@ -259,11 +259,11 @@ async def test_get_auth_link_with_custom_state(self): sdk = self.get_sdk() custom_state = sdk.generate_state_token() redirect_uri = "http://localhost:9000/callback" - + # Test with custom state auth_url = await sdk.get_auth_link(redirect_uri=redirect_uri, state=custom_state) self.assertIn("state=" + custom_state, auth_url) - + # Test with default state (application_name) auth_url_default = await sdk.get_auth_link(redirect_uri=redirect_uri) self.assertIn("state=" + sdk.application_name, auth_url_default) diff --git a/src/tests/test_oauth.py b/src/tests/test_oauth.py index 2ea0c5a..2a91f9c 100644 --- a/src/tests/test_oauth.py +++ b/src/tests/test_oauth.py @@ -229,21 +229,21 @@ def test_generate_state_token(self): sdk = self.get_sdk() state1 = sdk.generate_state_token() state2 = sdk.generate_state_token() - + # Check that state tokens are strings self.assertIsInstance(state1, str) self.assertIsInstance(state2, str) - + # Check that state tokens are not empty self.assertGreater(len(state1), 0) self.assertGreater(len(state2), 0) - + # Check that each generated token is unique self.assertNotEqual(state1, state2) - + # Default length is 32 bytes = 64 hex characters self.assertEqual(len(state1), 64) - + # Test custom length state_custom = sdk.generate_state_token(length=16) self.assertEqual(len(state_custom), 32) # 16 bytes = 32 hex chars @@ -251,14 +251,14 @@ def test_generate_state_token(self): def test_verify_state_token(self): sdk = self.get_sdk() state = sdk.generate_state_token() - + # Valid state should match self.assertTrue(sdk.verify_state_token(state, state)) - + # Different states should not match state2 = sdk.generate_state_token() self.assertFalse(sdk.verify_state_token(state, state2)) - + # Empty or None states should not match self.assertFalse(sdk.verify_state_token("", state)) self.assertFalse(sdk.verify_state_token(state, "")) @@ -269,11 +269,11 @@ def test_get_auth_link_with_custom_state(self): sdk = self.get_sdk() custom_state = sdk.generate_state_token() redirect_uri = "http://localhost:8080/callback" - + # Test with custom state auth_url = sdk.get_auth_link(redirect_uri=redirect_uri, state=custom_state) self.assertIn("state=" + custom_state, auth_url) - + # Test with default state (application_name) auth_url_default = sdk.get_auth_link(redirect_uri=redirect_uri) self.assertIn("state=" + sdk.application_name, auth_url_default) From 1d9767c7cd81148e5fa6a3331602e35790917968 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 23 Jan 2026 13:39:45 +0000 Subject: [PATCH 4/5] Improve URL building to use urlparse for robustness Co-authored-by: hsluoyz <3787410+hsluoyz@users.noreply.github.com> --- src/casdoor/main.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/casdoor/main.py b/src/casdoor/main.py index 351cb20..6cd37e5 100644 --- a/src/casdoor/main.py +++ b/src/casdoor/main.py @@ -15,7 +15,7 @@ import json import secrets from typing import Dict, List, Optional -from urllib.parse import urlencode +from urllib.parse import urlencode, urlparse, urlunparse import jwt import requests @@ -160,7 +160,9 @@ def get_auth_link( "scope": scope, "state": state if state is not None else self.application_name, } - return url + "?" + urlencode(params) + # Parse the URL and add query parameters + parsed = urlparse(url) + return urlunparse(parsed._replace(query=urlencode(params))) def get_oauth_token( self, code: Optional[str] = None, username: Optional[str] = None, password: Optional[str] = None From cef6597efacb7c1bafcae3555852371777a78b8b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 23 Jan 2026 13:42:05 +0000 Subject: [PATCH 5/5] Fix type annotations for verify_state_token to use Optional[str] Co-authored-by: hsluoyz <3787410+hsluoyz@users.noreply.github.com> --- src/casdoor/async_main.py | 2 +- src/casdoor/main.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/casdoor/async_main.py b/src/casdoor/async_main.py index f1fc7bb..f26ab5c 100644 --- a/src/casdoor/async_main.py +++ b/src/casdoor/async_main.py @@ -590,7 +590,7 @@ def generate_state_token(length: int = 32) -> str: return secrets.token_hex(length) @staticmethod - def verify_state_token(received_state: str, expected_state: str) -> bool: + def verify_state_token(received_state: Optional[str], expected_state: Optional[str]) -> bool: """ Verify that the received state token matches the expected state token. Uses constant-time comparison to prevent timing attacks. diff --git a/src/casdoor/main.py b/src/casdoor/main.py index 6cd37e5..ce98375 100644 --- a/src/casdoor/main.py +++ b/src/casdoor/main.py @@ -433,7 +433,7 @@ def generate_state_token(length: int = 32) -> str: return secrets.token_hex(length) @staticmethod - def verify_state_token(received_state: str, expected_state: str) -> bool: + def verify_state_token(received_state: Optional[str], expected_state: Optional[str]) -> bool: """ Verify that the received state token matches the expected state token. Uses constant-time comparison to prevent timing attacks.