Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,65 @@ decoded_msg = sdk.parse_jwt_token(access_token) # or sdk.parse_jwt_token(access_

`decoded_msg` is the JSON data decoded from the `access_token`, which contains user info and other useful stuff.

## CSRF Protection with State Parameter

For enhanced security, you should use the `state` parameter to protect against CSRF attacks. The SDK provides helper methods for state generation and validation:

### Using Custom State Parameter

```python
from casdoor import CasdoorSDK

# Initialize SDK
sdk = CasdoorSDK(
endpoint,
client_id,
client_secret,
certificate,
org_name,
application_name,
)

# Step 1: Generate a state token before redirecting to Casdoor
state = sdk.generate_state_token()

# Store the state in your session (implementation depends on your framework)
# For example, using Flask:
# session['oauth_state'] = state

# Step 2: Generate auth URL with custom state
auth_url = sdk.get_auth_link(redirect_uri="http://localhost:8080/callback", state=state)

# Redirect user to auth_url
# ...

# Step 3: In your callback handler, verify the state
def callback_handler():
# Get the state from the callback parameters
received_state = request.args.get('state')

# Get the expected state from session
expected_state = session.get('oauth_state')

# Verify state to prevent CSRF attacks
if not sdk.verify_state_token(received_state, expected_state):
# State validation failed - possible CSRF attack
raise ValueError("Invalid state parameter")

# State is valid, proceed with token exchange
code = request.args.get('code')
token = sdk.get_oauth_token(code=code)
access_token = token.get("access_token")
decoded_msg = sdk.parse_jwt_token(access_token)

# Clear the state from session
session.pop('oauth_state', None)

return decoded_msg
```

**Note:** If you don't provide a custom `state` parameter to `get_auth_link()`, it will default to the `application_name` for backward compatibility. However, for production applications, it's strongly recommended to use a randomly generated state token for CSRF protection.

## Step4. Interact with the users

casdoor-python-sdk support basic user operations, like:
Expand Down
38 changes: 37 additions & 1 deletion src/casdoor/async_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import base64
import json
import secrets
from typing import Dict, List, Optional

import aiohttp
Expand Down Expand Up @@ -138,14 +139,25 @@ async def get_auth_link(
redirect_uri: str,
response_type: str = "code",
scope: str = "read",
state: Optional[str] = None,
) -> str:
"""
Get authorization link for OAuth flow.

:param redirect_uri: The redirect URI for the OAuth callback
:param response_type: OAuth response type (default: "code")
:param scope: OAuth scope (default: "read")
:param state: Custom state parameter for CSRF protection.
If not provided, defaults to application_name for backward compatibility.
:return: Authorization URL
"""
url = self.front_endpoint + "/login/oauth/authorize"
params = {
"client_id": self.client_id,
"response_type": response_type,
"redirect_uri": redirect_uri,
"scope": scope,
"state": self.application_name,
"state": state if state is not None else self.application_name,
}
return str(URL(url).with_query(params))

Expand Down Expand Up @@ -566,3 +578,27 @@ async def get_user_roles(self, username: str) -> List[Dict]:
user_roles.append(role)

return user_roles

@staticmethod
def generate_state_token(length: int = 32) -> str:
"""
Generate a cryptographically secure random state token for CSRF protection.

:param length: Length of the state token in bytes (default: 32)
:return: A random state token as a hex string
"""
return secrets.token_hex(length)

@staticmethod
def verify_state_token(received_state: Optional[str], expected_state: Optional[str]) -> bool:
"""
Verify that the received state token matches the expected state token.
Uses constant-time comparison to prevent timing attacks.

:param received_state: The state parameter received from OAuth callback
:param expected_state: The expected state token (stored in session)
:return: True if states match, False otherwise
"""
if not received_state or not expected_state:
return False
return secrets.compare_digest(received_state, expected_state)
51 changes: 47 additions & 4 deletions src/casdoor/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
# limitations under the License.

import json
import secrets
from typing import Dict, List, Optional
from urllib.parse import urlencode, urlparse, urlunparse

import jwt
import requests
Expand Down Expand Up @@ -133,17 +135,34 @@ def certification(self) -> bytes:
raise TypeError("certificate field must be str type")
return self.certificate.encode("utf-8")

def get_auth_link(self, redirect_uri: str, response_type: str = "code", scope: str = "read"):
def get_auth_link(
self,
redirect_uri: str,
response_type: str = "code",
scope: str = "read",
state: Optional[str] = None,
):
"""
Get authorization link for OAuth flow.

:param redirect_uri: The redirect URI for the OAuth callback
:param response_type: OAuth response type (default: "code")
:param scope: OAuth scope (default: "read")
:param state: Custom state parameter for CSRF protection.
If not provided, defaults to application_name for backward compatibility.
:return: Authorization URL
"""
url = self.front_endpoint + "/login/oauth/authorize"
params = {
"client_id": self.client_id,
"response_type": response_type,
"redirect_uri": redirect_uri,
"scope": scope,
"state": self.application_name,
"state": state if state is not None else self.application_name,
}
r = requests.request("", url, params=params)
return r.url
# Parse the URL and add query parameters
parsed = urlparse(url)
return urlunparse(parsed._replace(query=urlencode(params)))

def get_oauth_token(
self, code: Optional[str] = None, username: Optional[str] = None, password: Optional[str] = None
Expand Down Expand Up @@ -402,3 +421,27 @@ def batch_enforce(
raise ValueError(error_str)

return enforce_results

@staticmethod
def generate_state_token(length: int = 32) -> str:
"""
Generate a cryptographically secure random state token for CSRF protection.

:param length: Length of the state token in bytes (default: 32)
:return: A random state token as a hex string
"""
return secrets.token_hex(length)

@staticmethod
def verify_state_token(received_state: Optional[str], expected_state: Optional[str]) -> bool:
"""
Verify that the received state token matches the expected state token.
Uses constant-time comparison to prevent timing attacks.

:param received_state: The state parameter received from OAuth callback
:param expected_state: The expected state token (stored in session)
:return: True if states match, False otherwise
"""
if not received_state or not expected_state:
return False
return secrets.compare_digest(received_state, expected_state)
53 changes: 53 additions & 0 deletions src/tests/test_async_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,56 @@ async def test_auth_link(self):
response,
f"{sdk.front_endpoint}/login/oauth/authorize?client_id={sdk.client_id}&response_type=code&redirect_uri={redirect_uri}&scope=read&state={sdk.application_name}",
)

def test_generate_state_token(self):
sdk = self.get_sdk()
state1 = sdk.generate_state_token()
state2 = sdk.generate_state_token()

# Check that state tokens are strings
self.assertIsInstance(state1, str)
self.assertIsInstance(state2, str)

# Check that state tokens are not empty
self.assertGreater(len(state1), 0)
self.assertGreater(len(state2), 0)

# Check that each generated token is unique
self.assertNotEqual(state1, state2)

# Default length is 32 bytes = 64 hex characters
self.assertEqual(len(state1), 64)

# Test custom length
state_custom = sdk.generate_state_token(length=16)
self.assertEqual(len(state_custom), 32) # 16 bytes = 32 hex chars

def test_verify_state_token(self):
sdk = self.get_sdk()
state = sdk.generate_state_token()

# Valid state should match
self.assertTrue(sdk.verify_state_token(state, state))

# Different states should not match
state2 = sdk.generate_state_token()
self.assertFalse(sdk.verify_state_token(state, state2))

# Empty or None states should not match
self.assertFalse(sdk.verify_state_token("", state))
self.assertFalse(sdk.verify_state_token(state, ""))
self.assertFalse(sdk.verify_state_token(None, state))
self.assertFalse(sdk.verify_state_token(state, None))

async def test_get_auth_link_with_custom_state(self):
sdk = self.get_sdk()
custom_state = sdk.generate_state_token()
redirect_uri = "http://localhost:9000/callback"

# Test with custom state
auth_url = await sdk.get_auth_link(redirect_uri=redirect_uri, state=custom_state)
self.assertIn("state=" + custom_state, auth_url)

# Test with default state (application_name)
auth_url_default = await sdk.get_auth_link(redirect_uri=redirect_uri)
self.assertIn("state=" + sdk.application_name, auth_url_default)
53 changes: 53 additions & 0 deletions src/tests/test_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,56 @@ def test_modify_user(self):

self.assertIn("status", response)
self.assertIsInstance(response, dict)

def test_generate_state_token(self):
sdk = self.get_sdk()
state1 = sdk.generate_state_token()
state2 = sdk.generate_state_token()

# Check that state tokens are strings
self.assertIsInstance(state1, str)
self.assertIsInstance(state2, str)

# Check that state tokens are not empty
self.assertGreater(len(state1), 0)
self.assertGreater(len(state2), 0)

# Check that each generated token is unique
self.assertNotEqual(state1, state2)

# Default length is 32 bytes = 64 hex characters
self.assertEqual(len(state1), 64)

# Test custom length
state_custom = sdk.generate_state_token(length=16)
self.assertEqual(len(state_custom), 32) # 16 bytes = 32 hex chars

def test_verify_state_token(self):
sdk = self.get_sdk()
state = sdk.generate_state_token()

# Valid state should match
self.assertTrue(sdk.verify_state_token(state, state))

# Different states should not match
state2 = sdk.generate_state_token()
self.assertFalse(sdk.verify_state_token(state, state2))

# Empty or None states should not match
self.assertFalse(sdk.verify_state_token("", state))
self.assertFalse(sdk.verify_state_token(state, ""))
self.assertFalse(sdk.verify_state_token(None, state))
self.assertFalse(sdk.verify_state_token(state, None))

def test_get_auth_link_with_custom_state(self):
sdk = self.get_sdk()
custom_state = sdk.generate_state_token()
redirect_uri = "http://localhost:8080/callback"

# Test with custom state
auth_url = sdk.get_auth_link(redirect_uri=redirect_uri, state=custom_state)
self.assertIn("state=" + custom_state, auth_url)

# Test with default state (application_name)
auth_url_default = sdk.get_auth_link(redirect_uri=redirect_uri)
self.assertIn("state=" + sdk.application_name, auth_url_default)
Loading