Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Agents

## Linting and type checking

Before committing any changes, always run:

```
poetry run ruff format .
poetry run ruff check .
poetry run mypy --show-error-codes .
```

Fix all errors before committing. Do not commit code with unused imports, formatting issues, or type errors.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "upstash-workflow"
version = "0.1.4"
version = "0.1.5"
description = "Python SDK for Upstash Workflow"
license = "MIT"
authors = ["Upstash <support@upstash.com>"]
Expand Down Expand Up @@ -29,7 +29,7 @@ packages = [{ include = "upstash_workflow" }]

[tool.poetry.dependencies]
python = "^3.8"
qstash = "^2.0.3"
qstash = "^3.4.0"

[tool.poetry.group.fastapi.dependencies]
fastapi = "^0.115.0"
Expand Down
37 changes: 37 additions & 0 deletions tests/asyncio/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from qstash import AsyncQStash
from upstash_workflow import AsyncWorkflowContext
from upstash_workflow.error import WorkflowAbort
from upstash_workflow.types import Redact
from upstash_workflow.asyncio.workflow_requests import _trigger_first_invocation
from tests.utils import (
RequestFields,
ResponseFields,
Expand Down Expand Up @@ -103,3 +105,38 @@ async def execute() -> None:
],
),
)


@pytest.mark.asyncio
async def test_trigger_workflow_with_redact(qstash_client: AsyncQStash) -> None:
redact: Redact = {"body": True, "header": ["Authorization"]}

context = AsyncWorkflowContext(
qstash_client=qstash_client,
workflow_run_id="wfr-id",
headers={},
steps=[],
url=WORKFLOW_ENDPOINT,
initial_payload="my-payload",
env=None,
retries=3,
failure_url=None,
redact=redact,
)

async def execute() -> None:
await _trigger_first_invocation(context, retries=3, redact=redact)

await mock_qstash_server(
execute=execute,
response_fields=ResponseFields(status=200, body="msgId"),
receives_request=RequestFields(
method="POST",
url=f"{MOCK_QSTASH_SERVER_URL}/v2/publish/{WORKFLOW_ENDPOINT}",
token="mock-token",
body="my-payload",
headers={
"Upstash-Redact-Fields": "body,header[Authorization]",
},
),
)
8 changes: 7 additions & 1 deletion tests/asyncio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,14 @@ async def handler(request: web.Request) -> web.Response:
text=f"assertion in mock QStash failed: {str(error)}", status=400
)

data: Any
if "/v2/batch" in str(request.url):
data = [{"messageId": response_fields.body, "deduplicated": False}]
else:
data = {"messageId": response_fields.body, "deduplicated": False}

return web.json_response(
data=[{"messageId": response_fields.body, "deduplicated": False}],
data=data,
status=response_fields.status,
)

Expand Down
36 changes: 36 additions & 0 deletions tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from qstash import QStash
from upstash_workflow import WorkflowContext
from upstash_workflow.error import WorkflowAbort
from upstash_workflow.types import Redact
from upstash_workflow.workflow_requests import _trigger_first_invocation
from tests.utils import (
mock_qstash_server,
RequestFields,
Expand Down Expand Up @@ -102,3 +104,37 @@ def execute() -> None:
],
),
)


def test_trigger_workflow_with_redact(qstash_client: QStash) -> None:
redact: Redact = {"body": True, "header": ["Authorization"]}

context = WorkflowContext(
qstash_client=qstash_client,
workflow_run_id="wfr-id",
headers={},
steps=[],
url=WORKFLOW_ENDPOINT,
initial_payload="my-payload",
env=None,
retries=3,
failure_url=None,
redact=redact,
)

def execute() -> None:
_trigger_first_invocation(context, retries=3, redact=redact)

mock_qstash_server(
execute=execute,
response_fields=ResponseFields(status=200, body="msgId"),
receives_request=RequestFields(
method="POST",
url=f"{MOCK_QSTASH_SERVER_URL}/v2/publish/{WORKFLOW_ENDPOINT}",
token="mock-token",
body="my-payload",
headers={
"Upstash-Redact-Fields": "body,header[Authorization]",
},
),
)
82 changes: 82 additions & 0 deletions tests/test_redact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Tests for redact parameter being passed to qstash client."""

from unittest.mock import MagicMock
from upstash_workflow.workflow_requests import _trigger_first_invocation
from upstash_workflow.types import Redact


def _make_workflow_context():
ctx = MagicMock()
ctx.workflow_run_id = "wfr-test-id"
ctx.url = "https://example.com"
ctx.headers = {}
ctx.request_payload = '{"test": true}'
ctx.qstash_client.message.publish_json = MagicMock()
return ctx


def test_trigger_passes_redact_body() -> None:
"""Test that redact with body is passed to publish_json."""
ctx = _make_workflow_context()
redact: Redact = {"body": True}

_trigger_first_invocation(ctx, retries=3, redact=redact)

ctx.qstash_client.message.publish_json.assert_called_once()
call_kwargs = ctx.qstash_client.message.publish_json.call_args
assert call_kwargs.kwargs["redact"] == {"body": True}


def test_trigger_passes_redact_header_all() -> None:
"""Test that redact with all headers is passed to publish_json."""
ctx = _make_workflow_context()
redact: Redact = {"header": True}

_trigger_first_invocation(ctx, retries=3, redact=redact)

call_kwargs = ctx.qstash_client.message.publish_json.call_args
assert call_kwargs.kwargs["redact"] == {"header": True}


def test_trigger_passes_redact_specific_headers() -> None:
"""Test that redact with specific headers is passed to publish_json."""
ctx = _make_workflow_context()
redact: Redact = {"header": ["Authorization", "X-API-Key"]}

_trigger_first_invocation(ctx, retries=3, redact=redact)

call_kwargs = ctx.qstash_client.message.publish_json.call_args
assert call_kwargs.kwargs["redact"] == {"header": ["Authorization", "X-API-Key"]}


def test_trigger_passes_redact_body_and_headers() -> None:
"""Test that redact with body and specific headers is passed to publish_json."""
ctx = _make_workflow_context()
redact: Redact = {"body": True, "header": ["Authorization"]}

_trigger_first_invocation(ctx, retries=3, redact=redact)

call_kwargs = ctx.qstash_client.message.publish_json.call_args
assert call_kwargs.kwargs["redact"] == {"body": True, "header": ["Authorization"]}


def test_trigger_passes_no_redact() -> None:
"""Test that redact=None is passed when no redact specified."""
ctx = _make_workflow_context()

_trigger_first_invocation(ctx, retries=3, redact=None)

call_kwargs = ctx.qstash_client.message.publish_json.call_args
assert call_kwargs.kwargs["redact"] is None


def test_trigger_no_redact_headers_in_headers() -> None:
"""Test that Upstash-Redact-Fields is NOT in the headers (qstash client handles it)."""
ctx = _make_workflow_context()
redact: Redact = {"body": True, "header": ["Authorization"]}

_trigger_first_invocation(ctx, retries=3, redact=redact)

call_kwargs = ctx.qstash_client.message.publish_json.call_args
headers = call_kwargs.kwargs["headers"]
assert "Upstash-Redact-Fields" not in headers
11 changes: 8 additions & 3 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,14 @@ def handle_request(self) -> None:
)
return

response_data = json.dumps(
[{"messageId": response_fields.body, "deduplicated": False}]
)
if "/v2/batch" in self.path:
response_data = json.dumps(
[{"messageId": response_fields.body, "deduplicated": False}]
)
else:
response_data = json.dumps(
{"messageId": response_fields.body, "deduplicated": False}
)

self.send_response(response_fields.status)
self.send_header("Content-type", "application/json")
Expand Down
2 changes: 1 addition & 1 deletion upstash_workflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.1.4"
__version__ = "0.1.5"

from upstash_workflow.context.context import WorkflowContext
from upstash_workflow.serve.serve import serve
Expand Down
95 changes: 60 additions & 35 deletions upstash_workflow/asyncio/context/auto_executor.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from __future__ import annotations
from typing import TYPE_CHECKING, List, Union, Literal, cast, Any, TypeVar
import json
from qstash.message import BatchJsonRequest
from upstash_workflow.constants import NO_CONCURRENCY
from upstash_workflow.error import WorkflowError, WorkflowAbort
from upstash_workflow.workflow_requests import _get_headers
from upstash_workflow.types import DefaultStep, HTTPMethods
from upstash_workflow.types import DefaultStep
from upstash_workflow.asyncio.context.steps import _BaseLazyStep, _LazyCallStep

if TYPE_CHECKING:
Expand Down Expand Up @@ -72,7 +71,7 @@ async def submit_steps_to_qstash(
f"Unable to submit steps to QStash. Provided list is empty. Current step: {self.step_count}"
)

batch_requests = []
batch_body = []
for index, single_step in enumerate(steps):
lazy_step = lazy_steps[index]
headers = _get_headers(
Expand All @@ -93,41 +92,67 @@ async def submit_steps_to_qstash(

single_step.out = json.dumps(single_step.out)

batch_requests.append(
BatchJsonRequest(
headers=headers,
method=cast(HTTPMethods, single_step.call_method),
body=single_step.call_body,
url=single_step.call_url,
)
if single_step.call_url
else (
BatchJsonRequest(
headers=headers,
body={
"method": "POST",
"stepId": single_step.step_id,
"stepName": single_step.step_name,
"stepType": single_step.step_type,
"out": single_step.out,
"sleepFor": single_step.sleep_for,
"sleepUntil": single_step.sleep_until,
"concurrent": single_step.concurrent,
"targetStep": single_step.target_step,
"callUrl": single_step.call_url,
"callMethod": single_step.call_method,
"callBody": single_step.call_body,
"callHeaders": single_step.call_headers,
if single_step.call_url:
batch_body.append(
{
"destination": single_step.call_url,
"headers": {
"Content-Type": "application/json",
"Upstash-Method": single_step.call_method,
**headers,
},
url=self.context.url,
not_before=cast( # TODO: Change not_before type in BatchJsonRequest
Any, single_step.sleep_until if will_wait else None
"body": json.dumps(single_step.call_body),
"queue": None,
}
)
else:
step_headers = {
"Content-Type": "application/json",
**headers,
}

sleep_until = single_step.sleep_until if will_wait else None
sleep_for = single_step.sleep_for if will_wait else None

if sleep_until is not None:
step_headers["Upstash-Not-Before"] = str(sleep_until)
if sleep_for is not None:
if isinstance(sleep_for, int):
step_headers["Upstash-Delay"] = f"{sleep_for}s"
else:
step_headers["Upstash-Delay"] = str(sleep_for)

batch_body.append(
{
"destination": self.context.url,
"headers": step_headers,
"body": json.dumps(
{
"method": "POST",
"stepId": single_step.step_id,
"stepName": single_step.step_name,
"stepType": single_step.step_type,
"out": single_step.out,
"sleepFor": single_step.sleep_for,
"sleepUntil": single_step.sleep_until,
"concurrent": single_step.concurrent,
"targetStep": single_step.target_step,
"callUrl": single_step.call_url,
"callMethod": single_step.call_method,
"callBody": single_step.call_body,
"callHeaders": single_step.call_headers,
}
),
delay=cast(Any, single_step.sleep_for if will_wait else None),
)
"queue": None,
}
)
)
await self.context.qstash_client.message.batch_json(batch_requests)

await self.context.qstash_client.http.request(
path="/v2/batch",
body=json.dumps(batch_body),
headers={"Content-Type": "application/json"},
method="POST",
)
raise WorkflowAbort(steps[0].step_name, steps[0])


Expand Down
3 changes: 3 additions & 0 deletions upstash_workflow/asyncio/context/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
HTTPMethods,
CallResponse,
CallResponseDict,
Redact,
)

TInitialPayload = TypeVar("TInitialPayload")
Expand All @@ -51,6 +52,7 @@ def __init__(
initial_payload: TInitialPayload,
env: Optional[Dict[str, Optional[str]]] = None,
retries: Optional[int] = None,
redact: Optional[Redact] = None,
):
self.qstash_client: AsyncQStash = qstash_client
self.workflow_run_id: str = workflow_run_id
Expand All @@ -61,6 +63,7 @@ def __init__(
self.request_payload: TInitialPayload = initial_payload
self.env: Dict[str, Optional[str]] = env or {}
self.retries: int = DEFAULT_RETRIES if retries is None else retries
self.redact: Optional[Redact] = redact
self._executor: _AutoExecutor = _AutoExecutor(self, self._steps)

async def run(
Expand Down
Loading
Loading