-
Notifications
You must be signed in to change notification settings - Fork 9
[Perf] Streams 3: Add qd.stream_parallel() context manager #409
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: hp/streams-quadrantsic-2-amdgpu-cpu
Are you sure you want to change the base?
Changes from all commits
a40ed4c
aa2fa2a
be7ad92
ce83281
880abc7
065a3b7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,10 +21,12 @@ | |
| from quadrants.lang.ast.ast_transformer_utils import ( | ||
| ASTTransformerFuncContext, | ||
| ) | ||
| from quadrants.lang.ast.symbol_resolver import ASTResolver | ||
| from quadrants.lang.exception import ( | ||
| QuadrantsSyntaxError, | ||
| ) | ||
| from quadrants.lang.matrix import MatrixType | ||
| from quadrants.lang.stream import stream_parallel | ||
| from quadrants.lang.struct import StructType | ||
| from quadrants.lang.util import to_quadrants_type | ||
| from quadrants.types import annotations, ndarray_type, primitive_types | ||
|
|
@@ -295,7 +297,34 @@ def build_FunctionDef( | |
| else: | ||
| FunctionDefTransformer._transform_as_func(ctx, node, args) | ||
|
|
||
| if ctx.is_kernel: | ||
| FunctionDefTransformer._validate_stream_parallel_exclusivity(node.body, ctx.global_vars) | ||
|
|
||
| with ctx.variable_scope_guard(): | ||
| build_stmts(ctx, node.body) | ||
|
|
||
| return None | ||
|
|
||
| @staticmethod | ||
| def _is_stream_parallel_with(stmt: ast.stmt, global_vars: dict[str, Any]) -> bool: | ||
| if not isinstance(stmt, ast.With): | ||
| return False | ||
| if len(stmt.items) != 1: | ||
| return False | ||
| item = stmt.items[0] | ||
|
Comment on lines
+312
to
+314
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is |
||
| if not isinstance(item.context_expr, ast.Call): | ||
| return False | ||
| return ASTResolver.resolve_to(item.context_expr.func, stream_parallel, global_vars) | ||
|
|
||
| @staticmethod | ||
| def _validate_stream_parallel_exclusivity(body: list[ast.stmt], global_vars: dict[str, Any]) -> None: | ||
| has_sp = any(FunctionDefTransformer._is_stream_parallel_with(s, global_vars) for s in body) | ||
| if not has_sp: | ||
|
Comment on lines
+321
to
+322
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would rather do # <Insert fancy comment explaining what this check is doing>
if not any(FunctionDefTransformer._is_stream_parallel_with(s, global_vars) for s in body):
return |
||
| return | ||
| for stmt in body: | ||
| if not FunctionDefTransformer._is_stream_parallel_with(stmt, global_vars): | ||
| raise QuadrantsSyntaxError( | ||
| "When using qd.stream_parallel(), all top-level statements " | ||
| "in the kernel must be 'with qd.stream_parallel():' blocks. " | ||
| "Move non-parallel code to a separate kernel." | ||
| ) | ||
|
Comment on lines
+327
to
+330
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I still don't understand why you are moving to the next line before you have to. This is weird to me. But I don't care much. |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same. Not clear what
itemsis.