diff --git a/docs/pr_async_completion_and_notification.md b/docs/pr_async_completion_and_notification.md new file mode 100644 index 00000000..64d567f6 --- /dev/null +++ b/docs/pr_async_completion_and_notification.md @@ -0,0 +1,646 @@ +# PR: 异步任务完成与跨卡通知机制 — 实现与设计对齐 + +## 概述 + +本 PR 是 [`docs/runtime_async.md`](runtime_async.md) 设计文档的首轮落地实现,在 `tensormap_and_ringbuffer` 运行时中新增: + +1. **Deferred Completion(延迟完成)**— 对应设计文档 §2 `complete_in_future` 机制 +2. **Notification Counter Gating(通知计数器门控)**— 对应设计文档 §2.4.2 通知计数器协议 +3. 两个硬件双卡 demo — 对应设计文档 §3(SDMA 场景)和 §4(AllReduce 通知场景) + +## 设计对齐详细对照 + +### §2.1 `complete_in_future` 属性 + +| 设计文档 | 本 PR 实现 | +|---|---| +| Task descriptor 新增 `complete_in_future: bool` 字段 | `PTO2TaskPayload::complete_in_future`(`pto_runtime2_types.h`)| +| 默认 `false`,标记为 `true` 的任务延迟完成 | 编排层通过 `PTOParam::complete_in_future = true` 设置,orchestrator 提交任务时写入 payload | + +### §2.2 Modified Worker Return Behavior + +| 设计文档 | 本 PR 实现 | +|---|---| +| 函数返回时释放 core,但 `complete_in_future` 任务**不调用** `on_task_complete` | `aicpu_executor.cpp` Phase 1:`mixed_complete` 后先调 `async_wait_list.register_deferred()`,若注册成功则跳过 `on_mixed_task_complete`,core 照常释放给下一个 ready task | +| 任务保持 RUNNING 状态 | 已注册到 `PTO2AsyncWaitList` 的任务不进入完成流程,不触发 fanout 传播,不释放 buffer | +| `completed_this_turn` 不计入 | `if (mixed_complete && !slot_state.payload->complete_in_future)` 才递增 `completed_this_turn` | + +### §2.3 `waiting_completion_count` + +| 设计文档 | 本 PR 实现 | +|---|---| +| 每次注册预期完成事件时递增,每次匹配时递减,归零时调用 `on_task_complete` | `PTO2AsyncWaitEntry::waiting_completion_count`,由 `register_deferred` 从 CQ 读取条件数初始化 | +| 支持多个完成条件组合 | `PTO2AsyncWaitEntry::conditions[PTO2_MAX_COMPLETIONS_PER_TASK]` 数组,每个条件独立轮询、独立递减 | + +### §2.4.1 Request/Completion Queue 协议 + +| 设计文档 API | 本 PR 实现 | +|---|---| +| `pto2_send_request_entry(RQ_TYPE, RQ_ID, *descriptor) → tag` | `pto_rq_kernel_api.h` 中 `pto2_send_request_entry()` — Kernel 侧向硬件引擎提交请求 | +| `pto2_save_expected_completion(CQ_TYPE, CQ_ID, tag, task_id)` | `pto_cq_kernel_api.h` 中 `pto2_save_expected_completion()` — Kernel 侧将 {type, tag, expected_value} 写入 `PTO2CompletionQueue`(设备内存),调度器轮询读取 | + +**实现细化**:设计文档中 CQ 条目由调度器直接管理。实现中采用 Kernel→共享内存→调度器 的间接传递: +- Kernel 将完成条件写入 `PTO2CompletionQueue`(`pto_cq_types.h`,在 GM 上) +- 调度器在 `register_deferred` 中读取 CQ 内容,解析为 `PTO2CompletionCondition` 注册到 `PTO2AsyncWaitList` + +### §2.4.2 Notification Counter 协议 + +| 设计文档 API | 本 PR 实现 | +|---|---| +| `pto2_send_notification(REMOTE_ADDR, atomic_op)` | `pto_notify_kernel_api.h` 中 `pto2_send_notification()` — 封装 pto-comm-isa `TNOTIFY` 指令,向远端 RDMA 窗口原子操作 | +| `pto2_save_expected_notification_counter(LOCAL_ADDR, expected, task_id)` | **拆分为两个变体**(见下文) | + +**设计细化**:设计文档的 notification counter 是 deferred completion 语义(任务已运行,等计数器到达后完成)。实现中将其拆分为两个独立机制: + +| 变体 | 语义 | 实现 | +|---|---|---| +| **Deferred Completion** | 任务已执行,等 CQ 中的 COUNTER 条件满足后完成 | `pto2_save_expected_completion(CQ_TYPE=COUNTER, CQ_ID, counter_addr, expected)` → 通过 `PTO2AsyncWaitList` 轮询 | +| **Pre-launch Gating** | 任务尚未执行,等本地计数器到达后才放行启动 | `pto2_rt_expect_notification_counter(params, counter_addr, expected)` → 通过 `PTO2NotificationWaitList` 轮询 | + +Pre-launch gating 是对设计文档的增量扩展:设计文档假设任务先运行再等完成条件,而 `async_notify_demo` 的场景是消费者**在远端通知到达之前不应启动**(避免读到脏数据)。 + +### §2.5 Scheduler Polling and Completion Resolution + +| 设计文档 | 本 PR 实现 | +|---|---| +| 调度器维护两个 watch list,轮询 CQ 和计数器 | 调度循环新增 Phase 0 和 Phase 0b | +| CQ 匹配后 `waiting_completion_count--`,归零调 `on_task_complete` | `PTO2AsyncWaitList::poll_and_complete` — 逐条检查条件,全部满足后调 `on_mixed_task_complete` + deferred release | +| Counter 到达 `expected_value` 后同上 | **Phase 0**: counter 类型条件在 `poll_and_complete` 中一起处理(deferred completion 语义)| +| | **Phase 0b**: `PTO2NotificationWaitList::poll_and_enqueue` — counter 到达后直接入 ReadyQueue(pre-launch gating 语义)| + +调度循环整体结构: + +``` +每次迭代: + Phase 0: async_wait_list.poll_and_complete() ← 设计文档 §2.5 expected_completion_list + Phase 0b: notification_wait_list.poll_and_enqueue() ← 设计文档 §2.5 expected_notification_counter_list + Phase 1: check_running_cores_for_completion() ← 已有,+register_deferred 判断 + Phase 2: dispatch_ready_tasks_to_idle_cores() ← 已有,无修改 +``` + +### §3 Example: SDMA Completion → `async_completion_demo` + +| 设计文档 §3 场景 | `async_completion_demo` 实现 | +|---|---| +| Task A: SDMA prefetch,`complete_in_future=True` | Producer (func_id=2): `kernel_producer_async.cpp` — 通过 `TGET_ASYNC` 发起异步远程 RDMA 读取,将完成条件写入 CQ,立即返回 | +| Task B: 消费 tensor_X,等 A 完成后才就绪 | Consumer (func_id=1): `kernel_consumer.cpp` — 依赖 producer 输出,调度器在 SDMA 完成后才放行 | +| Core 释放后可执行其他任务 | Producer 返回即释放 AICore,调度器继续分派其他 ready task | +| 调度器轮询 CQ,匹配后完成 A | Phase 0 `poll_and_complete` 检测 SDMA event 完成 → `on_mixed_task_complete(producer)` → consumer 变 READY | + +### §4 Example: Notification Counter → `async_notify_demo` + +| 设计文档 §4 场景 | `async_notify_demo` 实现 | +|---|---| +| Task AR: 写本地数据到共享 GM,向所有 peer 发 `ATOMIC_INCREMENT` | Producer (func_id=0): `kernel_producer_notify.cpp` — 计算 `out = in * 2`,然后 `TNOTIFY(AtomicAdd)` 向对端窗口计数器 +1 | +| 等本地计数器达到 expected_value 后完成 | **变体**:本 PR 采用 pre-launch gating 而非 deferred completion — Consumer 提交时声明 `expect_notification_counter(addr, 1)`,调度器在计数器 ≥ 1 前不启动 Consumer | +| 调度器轮询计数器 | Phase 0b `poll_and_enqueue` 每次迭代 `cache_invalidate_range` + 读计数器,到达 1 后将 Consumer 入 ReadyQueue | +| Rank 1 故意延迟 2M 次循环再通知 | `kernel_producer_notify.cpp` line 96-98: `if (my_rank == 1) { for (volatile int i = 0; i < 2000000; ++i) {} }` — 使得时序差异可见 | + +## 实现文件清单 + +### Runtime 核心(对应设计文档 §2) + +| 文件 | 设计文档对应 | 职责 | +|---|---|---| +| `runtime/pto_async_wait.h` | §2.3 + §2.5 | `PTO2AsyncWaitList`(wait list + polling)、`PTO2CompletionCondition`、`poll_and_complete`、`register_deferred` | +| `runtime/pto_cq_types.h` | §2.4.1 | `PTO2CompletionQueue`、`PTO2CompletionEntry` — CQ 数据结构 | +| `runtime/pto_cq_kernel_api.h` | §2.4.1 | `pto2_save_expected_completion` — Kernel 侧 CQ 写入 API | +| `runtime/pto_rq_kernel_api.h` | §2.4.1 | `pto2_send_request_entry` — Kernel 侧 RQ 提交 API | +| `runtime/pto_notify_kernel_api.h` | §2.4.2 | `pto2_send_notification`、`pto2_save_expected_notification_counter` — Kernel 侧通知 API | +| `runtime/pto_scheduler.h` | §2.5 | `PTO2NotificationWaitList` — pre-launch gating watch list | +| `runtime/pto_shared_memory.h` | — | `pto2_record_scheduler_error` — 错误报告辅助 | +| `runtime/pto_types.h` | §2.3 | `PTO2AsyncEngine` 枚举、async context 地址 | +| `runtime/pto_runtime2_types.h` | §2.1 | `PTO2TaskPayload::complete_in_future` 字段 | +| `aicpu/aicpu_executor.cpp` | §2.2 + §2.5 | Phase 0/0b 轮询逻辑、`register_deferred` 调用点、core 释放行为修改 | + +### 编排层 API(对应设计文档 §2.4.2 扩展) + +| 文件 | 新增 API | 说明 | +|---|---|---| +| `orchestration/pto_orchestration_api.h` | `pto2_rt_expect_notification_counter(params, addr, expected)` | 编排侧 pre-launch gating 声明(设计文档的增量扩展)| + +### 平台层与 Python 层 + +| 文件 | 职责 | +|---|---| +| `platform/include/host/comm.h` | 后端无关通信 C API(5 函数)| +| `platform/onboard/host/comm_hccl.cpp` | HCCL 硬件后端(RDMA 窗口分配)| +| `platform/sim/host/comm_sim.cpp` | POSIX 共享内存仿真后端 | +| `platform/include/common/comm_context.h` | `CommDeviceContext` — 设备侧 RDMA 窗口上下文 | +| `examples/scripts/distributed_code_runner.py` | L3 编排器:compile → run → verify | +| `examples/scripts/distributed_worker.py` | per-rank 独立进程 | +| `python/bindings.py` | 分布式 launch 接口 + `RUNTIME_ENV` 注入 | + +### 示例 + +| Demo | 设计文档对应 | 说明 | +|---|---|---| +| `async_completion_demo` | §3 SDMA 场景 | 2 卡,Producer 异步 RDMA 读取 + CQ 延迟完成 | +| `async_notify_demo` | §4 通知计数器场景 | 2 卡,Producer TNOTIFY 跨卡通知 + Consumer pre-launch gating | + +## 与设计文档的差异总结 + +| 设计文档描述 | 实际实现 | 原因 | +|---|---|---| +| Notification counter 统一为 deferred completion(任务已运行,等计数器后完成)| 拆分为 deferred completion + pre-launch gating 两个独立机制 | `async_notify_demo` 场景中消费者应在远端通知到达**之前**就不启动,避免读脏数据。Pre-launch gating 比 deferred completion 更高效(任务无需先运行再等待)| +| CQ 轮询由调度器直接读硬件 CQ | Kernel 将条件写入 GM 上的 `PTO2CompletionQueue`,调度器读 GM | AICPU 调度器无法直接访问 AICore 的硬件 CQ 寄存器,需通过 GM 中转 | +| `pto2_save_expected_notification_counter` 在 Kernel 内调用 | Pre-launch gating 变体 `pto2_rt_expect_notification_counter` 在编排层调用 | 编排层提交任务时即可声明门控条件,无需在 Kernel 内处理 | +| 设计文档 §5.1 列出 4 个新 API | 全部实现,另外新增 `pto2_rt_expect_notification_counter` 编排层 API | 编排层 pre-launch gating 是对设计的补充 | + +## 面向后续重构的设计结论:移除 Pre-launch Gating,统一到 Deferred Completion + +本节不是对当前 PR 实现的描述,而是针对协作者评审的**后续重构设计**。目标是回答两个问题: + +1. 当前 `notification_wait_list` 方案为什么不够理想? +2. 如果要尽量少改 runtime,notification wait 应该建模成什么任务? + +### 需要解决的问题 + +当前实现为了支持 `async_notify_demo`,引入了与 deferred completion 并列的第二套机制: + +- 编排层通过 `pto2_rt_expect_notification_counter(params, addr, expected)` 给 task 打上 launch gating 元数据 +- orchestrator / scheduler 在 fanin 满足后,不是将 task 放入 ReadyQueue,而是放入 `PTO2NotificationWaitList` +- AICPU 调度循环新增 Phase 0b,轮询计数器并在满足条件后再将 task 入队 + +这个方案能工作,但有三个问题: + +1. **调度语义分裂** + - 当前 runtime 原本只有一种核心语义:task 是否 READY 由 fanin 决定;task 是否 COMPLETE 由 worker return 或 deferred completion 决定 + - pre-launch gating 额外引入了“fanin 已满足但仍不可调度”的隐藏状态,导致 READY 判定不再统一 + +2. **重复了一套 polling / wakeup 机制** + - `PTO2AsyncWaitList` 已经负责“task return 后等待外部条件满足再 complete” + - `PTO2NotificationWaitList` 又负责“task launch 前等待外部条件满足再 enqueue” + - 两者都在做“轮询某个外部条件,然后触发 task 状态推进” + +3. **侵入 scheduler 热路径** + - payload / params 里增加了 `has_launch_counter` / `launch_counter_addr` / `launch_counter_expected` + - orchestrator 和 scheduler 的 ready 判定路径都被改写 + - AICPU 主循环被迫引入 Phase 0b 和单独的 notification polling 锁 + +从架构分层上看,这不符合 `architecture.md` 中“orchestrator 负责建图、scheduler 负责执行已有图和完成已有 task”的简洁边界。 + +### 两种候选建模方式 + +#### 候选 A:保留当前方案,notification wait 作为 launch gating + +语义: + +- 真正的 consumer task 在编排时就提交 +- 但它在 notify 到达前不会被 AICPU dispatch +- notify 到达后,scheduler 再把该 task 放入 ReadyQueue + +执行流程: + +1. orchestrator 提交 producer 和 consumer +2. consumer 的 fanin 满足后,不进入 ReadyQueue,而进入 `notification_wait_list` +3. AICPU Phase 0b 轮询本地 counter +4. counter 达标后,AICPU 将 consumer 入队 +5. consumer 被 dispatch 并运行 + +优点: + +- consumer 自身不会在 notify 前被 dispatch +- 没有额外的 proxy task + +缺点: + +- scheduler 需要继续维护专用 launch-gating 机制 +- READY 判定不再只由 fanin 决定 +- notification 相关逻辑散落在 params、payload、orchestrator、scheduler、executor 多处 + +#### 候选 B:引入显式 notification-wait task,并统一使用 deferred completion + +语义: + +- orchestrator 显式提交一个 `wait_task` +- `wait_task` 是一个非常短的 task:它只负责注册“本地 counter 达到 expected_value 才 complete”的 deferred completion 条件,然后立即 return +- consumer 不再直接依赖 launch gating,而是依赖 `wait_task` 的真正完成 + +这里有一个关键约束:**当前 runtime 的依赖关系是通过 TensorMap 上的 tensor producer/consumer 边来发现的**,不是通过显式 `task_id -> task_id` API 建边。 + +因此,consumer 对 `wait_task` 的依赖,不能只在概念上写成“depends on wait_task”,而要在编排时通过一个**显式 dependency token tensor** 来表达: + +- `wait_task` 输出一个 token tensor(可以是 1-element `int32`,仅作为依赖 token 使用) +- `consumer` 把这个 token tensor 作为一个额外 input +- 这样 orchestrator 就会通过 TensorMap 自动建立 `wait_task -> consumer` 的 fanin/fanout 关系 + +推荐的 DAG 形态: + +```text +producer --------\ + \ + -> consumer + / +wait_task --------/ +``` + +其中: + +- `producer` 负责写本地输出并通过 `TNOTIFY` 更新对端 counter +- `wait_task` 负责“把 notify 条件转换成一个 deferred-completion task” +- `consumer` 同时依赖 `producer` 和 `wait_task`,其中对 `wait_task` 的依赖通过 token tensor 表达 + +执行流程: + +1. orchestrator 预先提交 `producer`、`wait_task`、`consumer` +2. `wait_task` 一旦 READY,就像普通 task 一样被 AICPU dispatch +3. `wait_task` 在 worker 内调用 `pto2_save_expected_notification_counter(...)` 或等价接口,把本地 counter 条件写入 CQ +4. `wait_task` 立即 return,但由于 `complete_in_future=true`,其逻辑状态仍未 COMPLETE +5. AICPU Phase 0 通过现有 `async_wait_list` 轮询该 counter 条件 +6. counter 达标后,AICPU 走现有 deferred-completion 路径将 `wait_task` complete +7. `consumer` 只有在 `producer` 和 `wait_task` 都完成后才会 READY + +优点: + +- 不需要单独的 pre-launch gating 语义 +- scheduler 继续只理解两件事:`READY` 和 `deferred completion` +- notification 被建模成 task graph 里的显式节点,依赖关系更清晰 +- 更符合“用已有机制表达新语义,而不是额外引入状态机分支”的原则 + +代价: + +- 图里会多一个 proxy task +- 需要定义 notification-wait task 的提交方式和示例写法 + +### 建议的新 API:用“提交 wait task”替代“给 consumer 打 launch gating 标记” + +当前 API: + +```cpp +pto2_rt_expect_notification_counter(params_consumer, counter_addr, expected); +pto2_rt_submit_aiv_task(CONSUMER_KERNEL_ID, params_consumer); +``` + +这个 API 的问题是:它把 notification wait 绑定在 consumer 本身上,因此 runtime 只能在 scheduler 内部增加一个“launch 前门控”的特殊分支。 + +推荐替换为两层 API: + +#### 1. Worker 侧 API:保留现有 counter-based deferred completion + +worker 侧不需要引入新的调度语义,只需要继续使用现有 API: + +```cpp +pto2_save_expected_notification_counter(cq, local_counter_addr, expected_value); +``` + +也就是说,worker 的职责只是: + +- 接收本地 counter 地址 +- 接收 expected value +- 接收 cq 地址 +- 将该条件写入 CQ +- 立即 return + +#### 2. Orchestration 侧 API:新增“提交 notification-wait task”的 helper + +推荐新增一个编排 helper,语义类似: + +```cpp +Tensor pto2_rt_submit_notification_wait_task( + int32_t kernel_id, + uint64_t local_counter_addr, + uint32_t expected_value); +``` + +返回值: + +- 一个 token tensor +- 该 token tensor 只用于建依赖边,不承载业务数据 + +helper 内部做的事情: + +1. 创建一个 1-element token tensor +2. 创建 `PTOParam params_wait` +3. `params_wait.add_output(token_tensor)` +4. `params_wait.add_scalar(local_counter_addr)` +5. `params_wait.add_scalar(expected_value)` +6. 分配 `cq_addr = pto2_rt_alloc_cq()` +7. 调用 `pto2_rt_submit_aiv_task_deferred(kernel_id, params_wait, cq_addr)` +8. 返回 token tensor 给 orchestrator + +这样,编排侧就从: + +```cpp +pto2_rt_expect_notification_counter(params_consumer, counter_addr, expected); +``` + +变成: + +```cpp +Tensor notify_token = pto2_rt_submit_notification_wait_task( + WAIT_KERNEL_ID, counter_addr, expected); +params_consumer.add_input(notify_token); +``` + +这比给 consumer 打 launch-gating 标记更符合当前 runtime 的依赖表达方式,因为它显式产出了一个 producer tensor,并让 consumer 像依赖其他中间 tensor 一样依赖它。 + +### `async_notify_demo` 中推荐的新 API 写法 + +当前 `async_notify_demo` 的核心逻辑是: + +- `producer` 写 `out = in * 2` +- `producer` 对 peer 发 `TNOTIFY(AtomicAdd)` +- `consumer` 读取 `out` 和本地 `notify_counter` +- 现实现通过 `pto2_rt_expect_notification_counter(...)` 对 consumer 做 launch gating + +推荐重构后的编排代码形态: + +```cpp +uint32_t data_shapes[1] = {128 * 128}; +uint32_t token_shapes[1] = {1}; + +Tensor ext_in = make_tensor_external(in_ptr, data_shapes, 1, DataType::FLOAT32); +Tensor ext_out = make_tensor_external(out_ptr, data_shapes, 1, DataType::FLOAT32); +Tensor ext_result = make_tensor_external(result_ptr, data_shapes, 1, DataType::FLOAT32); + +PTOParam params_producer; +params_producer.add_input(ext_in); +params_producer.add_output(ext_out); +params_producer.add_scalar((uint64_t)(uintptr_t)notify_counter_ptr); +params_producer.add_scalar((uint64_t)(uintptr_t)comm_ctx); +pto2_rt_submit_aiv_task(PRODUCER_KERNEL_ID, params_producer); + +Tensor notify_token = pto2_rt_submit_notification_wait_task( + WAIT_KERNEL_ID, + (uint64_t)(uintptr_t)notify_counter_ptr, + 1); + +PTOParam params_consumer; +params_consumer.add_input(ext_out); +params_consumer.add_input(notify_token); +params_consumer.add_output(ext_result); +params_consumer.add_scalar((uint64_t)(uintptr_t)notify_counter_ptr); +pto2_rt_submit_aiv_task(CONSUMER_KERNEL_ID, params_consumer); +``` + +推荐的 wait kernel 参数形态: + +```cpp +args[0] = &Tensor(token_tensor) // 仅用于依赖建图,可不实际消费 +args[1] = local_counter_addr +args[2] = expected_value +args[3] = cq_addr +``` + +wait kernel 的逻辑: + +```cpp +extern "C" __aicore__ void kernel_entry(__gm__ int64_t* args) { + __gm__ Tensor* token_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + auto* counter_addr = reinterpret_cast(args[1]); + uint32_t expected = static_cast(args[2]); + uint64_t cq_addr = static_cast(args[3]); + + volatile __gm__ PTO2CompletionQueue* cq = pto2_cq_get(cq_addr); + pto2_cq_reset(cq); + pto2_save_expected_notification_counter(cq, counter_addr, expected); + pto2_cq_flush(cq); + + (void)token_tensor; +} +``` + +这里 token tensor 的作用只有一个:让 consumer 能在图上依赖 `wait_task`。它不是业务数据,不要求 consumer kernel 真正读取它。 + +### AllReduce / Barrier 场景中的 API 形态 + +在 distributed allreduce/barrier 场景中,这个 API 更容易看清价值。 + +假设每个 rank 都有: + +- `local_reduce_task`:完成本地 partial reduce +- `notify_peers_task`:对所有 peer 执行 `TNOTIFY(AtomicAdd)` +- `barrier_wait_task`:等待本地 counter 达到 `nranks` +- `post_reduce_task`:继续执行后续计算 + +推荐 DAG: + +```text +local_reduce_task --> notify_peers_task --------\ + \ + -> post_reduce_task + / +barrier_wait_task ------------------------------/ +``` + +其中: + +- `notify_peers_task` 负责向所有 peer 发通知 +- `barrier_wait_task` 输出一个 barrier token tensor +- `post_reduce_task` 通过 `add_input(barrier_token)` 依赖 barrier 完成 + +推荐的编排代码形态: + +```cpp +Tensor barrier_token = pto2_rt_submit_notification_wait_task( + BARRIER_WAIT_KERNEL_ID, + local_barrier_counter_addr, + nranks); + +PTOParam params_post; +params_post.add_input(reduced_tensor); +params_post.add_input(barrier_token); +params_post.add_output(next_tensor); +pto2_rt_submit_aiv_task(POST_KERNEL_ID, params_post); +``` + +这样 barrier 的语义就很直观: + +- `barrier_wait_task` 完成之前,`post_reduce_task` 永远不会 READY +- 但 scheduler 不需要理解“barrier task”这种特殊任务 +- 它只需要理解一个普通的 deferred-completion task 和一条普通的 tensor dependency + +### 为什么推荐 token tensor,而不是新建 task-to-task 依赖 API + +这一节的结论**只针对 `tensormap_and_ringbuffer` runtime**。 + +理论上也可以设计一个新的 orchestrator API,显式声明: + +```cpp +pto2_rt_add_dependency(consumer_task, wait_task); +``` + +但这会明显更侵入 runtime,因为当前实现: + +- 没有稳定暴露“提交后 task handle”的上层 API +- fanin/fanout 的发现依赖 TensorMap 查找 producer tensor +- task ring / slot state / dependency pool 都是围绕 tensor producer lookup 建的 + +相比之下,token tensor 方案: + +- 完全复用现有 TensorMap 依赖构建逻辑 +- 不需要引入新的 task handle API +- 不需要引入 scheduler 侧新的依赖边类型 + +因此,从“最小侵入”原则看,推荐先采用 **wait_task + token tensor**,而不是额外设计 task-to-task dependency 接口。 + +### `aicpu_build_graph` 的情况(简述) + +上面的 token tensor 方案主要是 `tensormap_and_ringbuffer` 的约束结果。`aicpu_build_graph` 已经有显式依赖 API: + +- submit 返回 `PTO2TaskId` +- orchestration 可直接调用 `pto2_rt_add_dependency(rt, producer, consumer)` + +因此在 `aicpu_build_graph` 中,不需要 token tensor。推荐直接写成: + +```cpp +PTO2TaskId t_wait = pto2_rt_submit_notification_wait_task( + rt, WAIT_KERNEL_ID, local_counter_addr, expected_value); +pto2_rt_add_dependency(rt, t_wait, t_consumer); +``` + +AllReduce / barrier 场景也同理:`wait_task` 作为普通 deferred-completion task 提交,然后通过 `pto2_rt_add_dependency(...)` 把它连到 post-barrier task 即可。 + +### 结论:两种 runtime 的推荐表达不同 + +因此,这一设计结论应该拆成两层: + +1. **共性结论** + - 都推荐把 notification wait 建模成显式 `wait_task + deferred completion` + - 都不推荐继续依赖 scheduler 专用的 pre-launch gating 语义 + +2. **runtime-specific 表达方式** + - 对 `tensormap_and_ringbuffer`:推荐 `wait_task + token tensor` + - 原因:依赖通过 TensorMap/tensor producer 自动发现 + - 对 `aicpu_build_graph`:推荐 `wait_task + explicit dependency` + - 原因:submit 返回 `PTO2TaskId`,且已有 `pto2_rt_add_dependency(...)` + +也就是说,**wait_task 本身是通用设计,token tensor 不是**。token tensor 只是 `tensormap_and_ringbuffer` 在当前机制下的最小侵入落地方式。 + +### 为什么推荐候选 B + +关键点在于:我们真正需要保证的是**consumer 不得早于 notify 条件满足而运行**,而不是“wait_task 本身绝不能在 notify 前被 dispatch”。 + +只要 `wait_task`: + +- 被标记为 `complete_in_future=true` +- 在 worker 内只做“注册等待条件”而不做实际数据消费 +- return 后不触发 `on_task_complete` + +那么它即使被提前 dispatch,也不会让 consumer 提前运行。因为 consumer 依赖的是 `wait_task` 的**真正完成**,而不是它的 dispatch 或 return。 + +换句话说: + +- `wait_task` 提前 dispatch:可以 +- `wait_task` 提前 return:可以 +- `wait_task` 提前 complete:不可以 +- `consumer` 提前 READY:不可以 + +这恰好就是 deferred completion 已经提供的语义。 + +### AICPU / Scheduler 在推荐方案中的职责 + +推荐方案下,AICPU **不负责动态创建 task**。它只做两件事: + +1. 像平常一样 dispatch 已经由 orchestrator 提交好的 `wait_task` +2. 像平常一样在 async wait path 中轮询它的 completion condition,并在满足时 complete 它 + +职责分工保持为: + +- **Orchestrator**:建图,预先提交 `producer` / `wait_task` / `consumer` +- **AICPU Scheduler**:dispatch ready task;轮询 deferred completions;推动 fanin/fanout +- **AICore Worker**:执行 `wait_task` 的短函数体,只注册 completion condition 后返回 + +不推荐的做法是: + +- AICPU 在 observe 到 notify 后再“新建一个 task” +- 或者 scheduler 维护一条“event → create task → attach dependency → complete task”的新链路 + +那会引入新的 task creation / dependency hookup / flow-control 复杂度,明显比复用已有 deferred-completion 机制更侵入 runtime。 + +### 推荐方案对 runtime 的实际改动范围 + +如果采用候选 B,目标应当是**删除 launch gating 机制,保留并小幅扩展 deferred completion**。 + +#### 可以删除的内容 + +- `PTOParam::has_launch_counter` +- `PTOParam::launch_counter_addr` +- `PTOParam::launch_counter_expected` +- `PTO2TaskPayload` 中对应字段 +- `pto2_rt_expect_notification_counter(...)` +- `PTO2NotificationWaitList` +- scheduler 中“fanin 满足但不入 ReadyQueue”的分支 +- executor 的 Phase 0b notification polling 逻辑 + +#### 需要保留并扩展的内容 + +- `complete_in_future` +- per-task CQ +- `PTO2AsyncWaitList` +- `PTO2CompletionType::COUNTER` +- `pto2_save_expected_notification_counter(...)` + +#### 需要补强的点 + +当前 deferred `COUNTER` 路径在语义上已经存在,但要用于 remote notify,还需要补一个关键点: + +- **在 deferred-completion 的 counter polling 路径中加入 cache invalidation** + +原因: + +- `TNOTIFY` / RDMA 对本地 counter 的更新可能绕过 AICPU cache +- 当前 `notification_wait_list` 路径在读 counter 前会显式 `cache_invalidate_range(...)` +- 但 `PTO2AsyncWaitList::test_notification()` 目前只是直接读 `*counter_addr` + +因此,如果要把 notify waiting 统一迁移到 deferred completion,必须把 cache visibility 逻辑迁到 `PTO2AsyncWaitList` 的 `COUNTER` polling 路径里。否则 remote notify 可能已经发生,但 AICPU 仍读到 stale counter。 + +### 推荐方案的示例执行时序 + +以当前 `async_notify_demo` 为例,推荐的重构后时序应为: + +1. Rank N orchestrator 提交: + - `producer` + - `wait_task(complete_in_future=true)` + - `consumer(depends on producer output + wait_task token)` +2. `wait_task` 被 AICPU 像普通 ready task 一样 dispatch +3. `wait_task` worker 将 `{type=COUNTER, addr=local_notify_counter, expected=1}` 写入 CQ +4. `wait_task` return,但保持 incomplete +5. `producer` 正常运行,完成本地输出,并通过 `TNOTIFY` 更新 peer 的 counter +6. Peer 侧 AICPU 在 Phase 0 中轮询 `async_wait_list` +7. 观察到本地 counter 已达标后,complete `wait_task` +8. `consumer` 的 fanin 全部满足,进入 READY +9. `consumer` 被 dispatch,读取的本地数据和 notify 条件都已有效 + +### 结论 + +面向后续重构,建议将 notification wait 统一收敛到**显式 wait_task + deferred completion** 模型: + +- 不再把 notify 建模成“launch 前的特殊门控状态” +- 而是把 notify 建模成“一个普通 task 的延迟完成条件” + +这样做的核心收益是: + +- 复用已有 async-completion 机制 +- 删除 scheduler 专用分支 +- 将复杂度从 scheduler 状态机迁回 task graph 建模 +- 更符合本项目现有 runtime 的职责划分 + +### 供协作者评审的重点问题 + +在正式改代码前,建议协作者重点确认以下几点: + +1. 是否接受 graph 中新增一个显式 `wait_task`,以换取 scheduler 机制简化? +2. `wait_task` 是否应当直接零 fanin 提前提交,还是仍依赖某个本地前置 task? +3. 是否接受使用 token tensor 来表达 `wait_task -> consumer` 的依赖,而不是新增 task-to-task dependency API? +4. 是否同意删除 `notification_wait_list`,将 notify waiting 全部并入 `PTO2AsyncWaitList::COUNTER`? +5. `COUNTER` polling 的 cache invalidation 应放在 `test_notification()` 内部,还是抽成统一 helper? +6. 对 `aicpu_build_graph`,是否应直接采用 `wait_task + pto2_rt_add_dependency(...)`,而不引入 token tensor? +7. `async_notify_demo` 是否应该在重构后作为“deferred counter completion”的主示例,而不是“launch gating”示例? + +## 测试 + +```bash +# 运行两个 async demo(需要 2 张 Ascend 卡) +CANN_ENV_SCRIPT=/path/to/set_env.sh \ + examples/scripts/run_async_tests.sh --devices 6,7 +``` + +## 变更统计 + +- **41 文件变更**,+4291 / -15 行 +- Runtime 核心:~1200 行(async wait list, scheduler, CQ/RQ/notify API, shared memory) +- 平台层:~850 行(comm.h + HCCL/sim 实现 + CommDeviceContext) +- Python 基础设施:~970 行(distributed runner + worker + bindings) +- 示例代码:~1100 行(两个完整 demo) diff --git a/docs/runtime_async.md b/docs/runtime_async.md new file mode 100644 index 00000000..8723d67b --- /dev/null +++ b/docs/runtime_async.md @@ -0,0 +1,344 @@ +# Runtime Extension for Asynchronous Hardware Engine Requests + +## 1. Background: Current Runtime Model + +In the current pypto runtime design, each level of the hierarchy has three roles: + +| Role | Responsibility | +|---|---| +| **Orchestrator** | Submits tasks to workers at the current level, or to lower-level orchestrators | +| **Scheduler** | Manages task readiness, dependency tracking, and buffer lifecycle | +| **Worker** | Executes worker functions (tasks), typically run-to-completion | + +### 1.1 Run-to-Completion Task Lifecycle + +A typical task is a **run-to-completion function**. When the worker function returns, the scheduler calls `pto2_scheduler_on_task_complete`, which performs two dependency-tracking operations: + +1. **Fanout propagation (consumer readiness)**: Walks the completing task's fanout list and increments each consumer task's `fanin_refcount`. When a consumer's `fanin_refcount` equals its `fanin_count`, that consumer transitions to READY and is placed in the ready queue. + +2. **Fanin retirement (producer release)**: Walks the completing task's fanin list and increments each producer's `fanout_refcount`. When a producer's `fanout_refcount` equals its `fanout_count`, the producer transitions to CONSUMED and its output buffers become eligible for release. + +Note that **fanin, fanout, and `ref_count` are tracked at the task level** (producers and consumers), not at the individual tensor level. + +### 1.2 The Asynchronous Hardware Engine Challenge + +This run-to-completion model assumes that **function return = task complete**. This creates a fundamental problem for worker functions that submit requests to asynchronous hardware engines: + +| Hardware Engine | Function | +|---|---| +| **SDMA** | System DMA — bulk data movement between memory regions | +| **RoCE** | RDMA over Converged Ethernet — inter-node network data transfer | +| **UMA** | Unified Memory Access — cross-die or cross-chip memory operations | +| **CCU** | Cache Coherence Unit — cache management and coherence operations | + +When a worker function submits a request to one of these engines and then returns, the hardware engine may still be: +- **Reading** from the task's IN parameters (the buffer must not be released yet), or +- **Writing** to the task's INOUT/OUT parameters (the data is not yet valid for consumers). + +If the scheduler calls `pto2_scheduler_on_task_complete` at function return time, it would prematurely release producer buffers and unblock consumer tasks before the hardware operation finishes — leading to data corruption or races. + +**Goal**: Keep the existing runtime mechanisms for task lifecycle management, buffer management, dependency resolution, and task scheduling intact, while adding the ability to defer task completion until asynchronous hardware operations finish. + +--- + +## 2. Design: `pl.complete_in_future` and Deferred Completion + +### 2.1 New Function Attribute: `pl.complete_in_future` + +A new optional attribute is added to the `pl.function` definition: + +```python +@pl.function(complete_in_future=True) +def sdma_prefetch(src_tensor, dst_tensor): + ... +``` + +By default, `complete_in_future` is `False` (standard run-to-completion). Functions that submit asynchronous hardware requests or rely on external completion signals should be marked with `complete_in_future=True`. + +### 2.2 Modified Worker Return Behavior + +When a worker function returns, the runtime performs the following: + +``` +on worker function return(task_id): + (a) Release the core / worker thread → available to execute the next ready task + + (b) if task.complete_in_future: + // Do NOT call pto2_scheduler_on_task_complete. + // Task remains in RUNNING state (logically incomplete). + else: + pto2_scheduler_on_task_complete(sched, task_id) // standard path +``` + +A `complete_in_future` task's function return **releases the core** but does **not complete the task**. The scheduler keeps the task in RUNNING state. Dependency propagation and buffer release are deferred. + +### 2.3 Task Descriptor Extensions + +Two fields are added to `PTO2TaskDescriptor`: + +| Field | Type | Default | Description | +|---|---|---|---| +| `complete_in_future` | `bool` | `false` | Whether this task defers completion beyond function return | +| `waiting_completion_count` | `int32_t` | `0` | Number of outstanding completion events before the task is truly complete | + +The `waiting_completion_count` is incremented each time the task registers an expected completion event (via the APIs below). When the count reaches zero, the runtime calls `pto2_scheduler_on_task_complete`. + +### 2.4 New Runtime APIs + +Four new APIs are introduced, called from within the worker function body: + +#### 2.4.1 Request/Completion Queue Protocol + +```c +tag = pto2_send_request_entry(RQ_TYPE, RQ_ID, *descriptor); +success = pto2_save_expected_completion(CQ_TYPE, CQ_ID, tag, task_id); +``` + +| Parameter | Description | +|---|---| +| `RQ_TYPE` / `CQ_TYPE` | Engine type: `SDMA`, `RoCE`, `UMA`, `CCU`, etc. | +| `RQ_ID` | Index of the request queue for the given engine type | +| `CQ_ID` | Index of the completion queue for the given engine type | +| `descriptor` | Engine-specific request descriptor (DMA address, length, etc.) | +| `tag` | Unique handle returned by `pto2_send_request_entry`, used to match the completion entry | +| `task_id` | The task that should be completed when this tag appears in the CQ | + +**Workflow**: The worker function calls `pto2_send_request_entry` to submit a request to a hardware engine. The returned `tag` uniquely identifies this request. The worker then calls `pto2_save_expected_completion` to register this tag in the scheduler's **expected completion list**. This also increments `waiting_completion_count` for the task. + +#### 2.4.2 Notification Counter Protocol + +```c +pto2_send_notification(REMOTE_NOTIFICATION_COUNTER_ADDRESS, atomic_op); +pto2_save_expected_notification_counter(LOCAL_NOTIFICATION_COUNTER_ADDRESS, expected_value, task_id); +``` + +| Parameter | Description | +|---|---| +| `REMOTE_NOTIFICATION_COUNTER_ADDRESS` | Memory address of a counter on a remote node (or local) | +| `atomic_op` | Atomic operation to perform (e.g., `ATOMIC_INCREMENT`) | +| `LOCAL_NOTIFICATION_COUNTER_ADDRESS` | Memory address of a counter on the local node | +| `expected_value` | The value at which the counter triggers completion | +| `task_id` | The task that should be completed when the counter reaches `expected_value` | + +**Workflow**: `pto2_send_notification` performs a remote atomic memory operation on the target counter address. `pto2_save_expected_notification_counter` registers the local counter in the scheduler's **expected notification counter list** and increments `waiting_completion_count` for the task. + +### 2.5 Scheduler Polling and Completion Resolution + +The runtime scheduler maintains two watch lists: + +1. **Expected completion list**: Entries of the form `{CQ_TYPE, CQ_ID, tag, task_id}` +2. **Expected notification counter list**: Entries of the form `{counter_address, expected_value, task_id}` + +When either list is non-empty, the scheduler **polls** the corresponding completion queues and counter addresses: + +``` +scheduler_poll_loop: + for each entry in expected_completion_list: + if CQ[entry.CQ_TYPE][entry.CQ_ID] contains entry.tag: + remove entry from list + task = get_task(entry.task_id) + task.waiting_completion_count-- + if task.waiting_completion_count == 0: + pto2_scheduler_on_task_complete(sched, entry.task_id) + + for each entry in expected_notification_counter_list: + if *entry.counter_address >= entry.expected_value: + remove entry from list + task = get_task(entry.task_id) + task.waiting_completion_count-- + if task.waiting_completion_count == 0: + pto2_scheduler_on_task_complete(sched, entry.task_id) +``` + +**Multiple completion conditions**: A single task may register multiple expected completions and/or notification counters. Each registration increments `waiting_completion_count`. Each match decrements it. The task completes only when `waiting_completion_count` reaches zero — i.e., **all** registered conditions are satisfied. + +--- + +## 3. Example: Task Waiting for SDMA Completion + +A data prefetch scenario: Task A uses SDMA to move a tensor from host memory to device memory. Task B is a compute task that consumes this tensor. + +### 3.1 Task DAG + +``` +Task A (SDMA prefetch, complete_in_future=True) + OUT: tensor_X ───fanout───▶ Task B (compute, run-to-completion) + IN: tensor_X +``` + +### 3.2 Timeline + +``` +Worker Core +──────────────────────────────────────────────────────────────────▶ time +│ Task A │ Task C (unrelated) │ +│ 1. Build SDMA descriptor │ (core reused) │ +│ 2. tag = pto2_send_request_entry( │ │ +│ SDMA, rq_id, &desc) │ │ +│ 3. pto2_save_expected_completion( │ │ +│ SDMA, cq_id, tag, A) │ │ +│ 4. return → core released │ │ +├────────────────────────────────────┼───────────────────────────┤ + Task A status: RUNNING (not COMPLETED) + +SDMA Engine +──────────────────────────────────────────────────────────────────▶ time +│ DMA transfer: host mem ──────────▶ device mem │ +│ completion entry posted ──┤ + +Scheduler +──────────────────────────────────────────────────────────────────▶ time +│ Watch list: [{SDMA, cq_id, tag, A}] │ +│ ...poll... CQ match found! │ +│ A.waiting_completion_count-- → 0 │ +│ → pto2_scheduler_on_task_complete(A) │ +│ fanout: B.fanin_refcount++ → B becomes READY │ + +Worker Core +──────────────────────────────────────────────────────────────────▶ time + │ Task B │ + │ compute on │ + │ tensor_X (valid)│ +``` + +### 3.3 Key Observations + +1. **Core reuse**: Task A returns at step 4 and the core immediately picks up Task C. No core time is wasted waiting for DMA. + +2. **Task A stays in RUNNING state**: The scheduler does not call `pto2_scheduler_on_task_complete` at function return. Task A's OUT buffer (`tensor_X`) is not yet marked valid. + +3. **Task B is safely blocked**: Task B's `fanin_refcount < fanin_count` because Task A has not completed. Task B remains PENDING — it cannot execute on stale or partially-written data. + +4. **Deferred completion**: Only when the scheduler's polling loop detects the SDMA completion entry matching `tag` does it trigger `pto2_scheduler_on_task_complete(A)`, which propagates readiness to Task B through the standard fanout mechanism. + +5. **Data integrity**: The DMA transfer is guaranteed complete before Task B reads `tensor_X`. The runtime's existing dependency tracking is fully preserved — only the **timing** of the completion call is changed. + +--- + +## 4. Example: Task Waiting for Notification Counter (AllReduce) + +An AllReduce operation across 4 nodes. Each node contributes a partial result and must wait for all peers to finish before the reduced result is valid. Each node sends a notification to all peers upon completing its local contribution, and waits for its local counter to reach the expected value (4 = one increment per node). + +### 4.1 Task DAG (Node 0) + +``` +Task P (local partial reduce, run-to-completion) + OUT: partial ───fanout───▶ Task AR (allreduce exchange, complete_in_future=True) + INOUT: partial + OUT: reduced ───fanout───▶ Task C (post-reduce compute) + IN: reduced +``` + +### 4.2 Timeline (Node 0) + +``` +Worker Core +──────────────────────────────────────────────────────────────────▶ time +│ Task P │ Task AR │ Task D (unrelated) +│ local partial reduce │ 1. Write local partial to shared GM │ (core reused) +│ → on_task_complete(P) │ 2. For each peer (including self): │ +│ → AR becomes READY │ pto2_send_notification( │ +│ │ peer.COUNTER_ADDR, │ +│ │ ATOMIC_INCREMENT) │ +│ │ 3. pto2_save_expected_notification_ │ +│ │ counter(MY_COUNTER_ADDR, 4, AR) │ +│ │ 4. return → core released │ +├───────────────────────┼────────────────────────────────────────┤ + Task AR status: RUNNING (not COMPLETED) + +Node 0's Notification Counter +──────────────────────────────────────────────────────────────────▶ time + value: 0 + +1 (self) → 1 + +1 (Node 1) → 2 + +1 (Node 2) → 3 + +1 (Node 3) → 4 ← matches expected_value + +Scheduler +──────────────────────────────────────────────────────────────────▶ time +│ Watch list: [{MY_COUNTER_ADDR, expected=4, AR}] │ +│ ...poll... counter == 4, match found! │ +│ AR.waiting_completion_count-- → 0 │ +│ → pto2_scheduler_on_task_complete(AR) │ +│ fanout: C.fanin_refcount++ → C becomes READY │ +│ fanin: P.fanout_refcount++ → P may become CONSUMED │ + +Worker Core +──────────────────────────────────────────────────────────────────▶ time + │ Task C │ + │ post-reduce │ + │ compute on │ + │ reduced (valid│ + │ from all 4 │ + │ nodes) │ +``` + +### 4.3 Key Observations + +1. **Distributed barrier without blocking**: The notification counter acts as a barrier across 4 nodes. Each node atomically increments counters on all peers when its local work is done. No core spins or blocks — the scheduler polls the counter asynchronously. + +2. **Core reuse**: Task AR returns immediately after sending notifications and registering the expected counter. The core proceeds to execute Task D. + +3. **Task C is safely gated**: `pto2_scheduler_on_task_complete(AR)` is not called until the counter reaches 4. Task C cannot become READY until all nodes have contributed. The `reduced` output is guaranteed to reflect the fully-reduced result. + +4. **Producer buffer lifetime**: Task P's `partial` buffer is used as INOUT by Task AR. The runtime does not release this buffer until Task AR completes (via fanin retirement in `pto2_scheduler_on_task_complete`). This keeps the buffer alive while remote nodes may still be reading it via RDMA. + +5. **Composable with CQ events**: If Task AR also submits an SDMA request (e.g., to move the reduced result to a different memory region), it can call both `pto2_save_expected_completion` and `pto2_save_expected_notification_counter`. The `waiting_completion_count` starts at 2. Each event independently decrements the count. Task AR completes only when **both** the notification counter reaches 4 **and** the SDMA CQ entry arrives. + +--- + +## 5. Conclusion + +### 5.1 Summary of Changes + +This extension introduces a minimal set of additions to the pypto system, spanning the frontend language and the runtime layer: + +**PyPTO Frontend (`pl.` DSL)**: + +| Addition | Description | +|---|---| +| `pl.function(complete_in_future=True)` | New optional attribute on `pl.function`. Marks a worker function whose task completion is deferred beyond function return. Default is `False` (standard run-to-completion). | + +**PTO Runtime — New APIs**: + +| API | Purpose | +|---|---| +| `pto2_send_request_entry(RQ_TYPE, RQ_ID, *descriptor) → tag` | Submit a request to a hardware engine's request queue. Returns a unique `tag` identifying the request. | +| `pto2_save_expected_completion(CQ_TYPE, CQ_ID, tag, task_id)` | Register an expected completion queue entry. The scheduler polls the CQ for the matching `tag` and defers task completion until it arrives. | +| `pto2_send_notification(REMOTE_COUNTER_ADDR, atomic_op)` | Perform a remote atomic operation on a notification counter (e.g., increment a counter on a peer node). | +| `pto2_save_expected_notification_counter(LOCAL_COUNTER_ADDR, expected_value, task_id)` | Register an expected notification counter value. The scheduler polls the local counter and defers task completion until it reaches the expected value. | + +**PTO Runtime — Task Descriptor Extension**: + +| Field | Type | Default | Description | +|---|---|---|---| +| `complete_in_future` | `bool` | `false` | If `true`, the task's function return releases the core but does not trigger `pto2_scheduler_on_task_complete`. | +| `waiting_completion_count` | `int32_t` | `0` | Number of outstanding completion conditions. Incremented by each `pto2_save_expected_completion` or `pto2_save_expected_notification_counter` call. Decremented by the scheduler when a match is detected. Task completes when this reaches zero. | + +### 5.2 Core Runtime Mechanisms Remain Intact + +This enhancement is designed to be **non-invasive** to the existing pypto runtime architecture. The following core mechanisms are completely unchanged: + +- **Task ring** — allocation, slot management, and ring pointer advancement +- **Dependency tracking** — fanin/fanout linked lists, `fanin_refcount` / `fanout_refcount` protocol +- **`pto2_scheduler_on_task_complete`** — the three-step completion propagation (fanout propagation → fanin retirement → consumed check) is identical; only the **trigger point** is changed from "function return" to "all completion conditions satisfied" +- **Buffer lifecycle** — heap ring allocation, scope-based lifetime, and fanout-driven release +- **Ready queue** — enqueue/dequeue protocol and worker dispatch +- **Orchestrator/Scheduler/Worker roles** — unchanged at every level + +The only behavioral change is: for `complete_in_future` tasks, the call to `pto2_scheduler_on_task_complete` is **deferred** from the worker thread (at function return) to the scheduler thread (when all registered completion conditions are met). The function itself is called with the same arguments and produces the same effects. + +### 5.3 Applicability Across All Hierarchy Levels (L2–L7) + +This scheme is not limited to L2 (hardware core level). It applies uniformly across the entire hierarchical runtime system: + +| Level | Typical Async Use Case | +|---|---| +| **L2** (InCore) | SDMA transfers, CCU cache operations, TPUSH/TPOP hardware flag waits | +| **L3** (Server) | UMA cross-die memory operations, local SDMA between NPU chips | +| **L4** (Pod) | RoCE RDMA transfers between servers within a pod | +| **L5** (Service Pool) | Cross-pod data movement, distributed AllReduce notification barriers | +| **L6** (Cluster) | Inter-cluster data synchronization, federated aggregation barriers | +| **L7** (Global) | Cross-datacenter transfers, global notification barriers | + +At every level, the runtime's `LevelRuntime::on_task_complete` (L3–L7) or `pto2_scheduler_on_task_complete` (L2) follows the same deferred-completion protocol: check `complete_in_future`, poll the watch lists, decrement `waiting_completion_count`, and trigger the standard completion path only when all conditions are satisfied. This provides a **single, unified mechanism** for managing asynchronous hardware operations at any scale — from a single DMA transfer on one chip to a global multi-node synchronization barrier. diff --git a/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/golden.py b/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/golden.py new file mode 100644 index 00000000..4d6dad80 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/golden.py @@ -0,0 +1,73 @@ +""" +Golden script for async_completion_demo. + +Single-card / sim path keeps the original producer-consumer pipeline: + producer: out[i] = in[i] * 2.0 + consumer: result[i] = out[i] + 1.0 + +Hardware 2-card path validates `out` and `result`: + each rank TGET_ASYNCs the peer rank's `in` into local `out`, then the + normal consumer computes `result = out + 1`. +""" + +import ctypes +import torch + +__outputs__ = ["result", "out"] + +RTOL = 1e-5 +ATOL = 1e-5 + + +def generate_inputs(params: dict) -> list: + SIZE = 128 * 128 + + inp = torch.full((SIZE,), 3.0, dtype=torch.float32) + out = torch.zeros(SIZE, dtype=torch.float32) + result = torch.zeros(SIZE, dtype=torch.float32) + event_handle_output = torch.zeros(4, dtype=torch.int32) + + return [ + ("in", inp), + ("out", out), + ("result", result), + ("event_handle_output", event_handle_output), + ("size_in", ctypes.c_int64(inp.nbytes)), + ("size_out", ctypes.c_int64(out.nbytes)), + ("size_result", ctypes.c_int64(result.nbytes)), + ("size_event_handle_output", ctypes.c_int64(event_handle_output.nbytes)), + ("SIZE", ctypes.c_int64(SIZE)), + ] + + +def generate_distributed_inputs(rank: int, nranks: int, root: int, + comm_ctx=None) -> list: + del comm_ctx + del nranks + del root + + size = 128 * 128 + inp = [float(i % 251) / 10.0 for i in range(size)] + out = [0.0] * size + result = [0.0] * size + + return [ + ("in", inp), + ("out", out), + ("result", result), + ] + + +def compute_golden(tensors: dict, params: dict) -> None: + if "in" in tensors: + inp = torch.as_tensor(tensors["in"]) + tensors["result"][:] = inp * 2.0 + 1.0 + tensors["out"][:] = inp * 2.0 + return + + out = tensors["out"] + result = tensors["result"] + for i in range(len(out)): + value = float(i % 251) / 10.0 + out[i] = value + result[i] = value + 1.0 diff --git a/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/aiv/kernel_consumer.cpp b/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/aiv/kernel_consumer.cpp new file mode 100644 index 00000000..f206bf3a --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/aiv/kernel_consumer.cpp @@ -0,0 +1,66 @@ +/** + * Async Completion Demo - Consumer Kernel (func_id=1) + * + * Implements: result[i] = src[i] + 1.0 + * + * This kernel executes as a normal run-to-completion task. It depends on the + * producer's output tensor; the scheduler only dispatches it after the + * producer's deferred completion (event flag) is resolved. + * + * Kernel args layout (packed by scheduler): + * args[0] = &Tensor(src) — input tensor struct pointer (producer's output) + * args[1] = &Tensor(result) — output tensor struct pointer + */ + +#include +#include + +#include "tensor.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ Tensor* src_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ Tensor* result_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + + __gm__ float* src = reinterpret_cast<__gm__ float*>(src_tensor->buffer.addr) + src_tensor->start_offset; + __gm__ float* result = reinterpret_cast<__gm__ float*>(result_tensor->buffer.addr) + result_tensor->start_offset; + + constexpr int kTRows_ = 128; + constexpr int kTCols_ = 128; + constexpr int vRows = 128; + constexpr int vCols = 128; + + using DynShapeDim5 = Shape<1, 1, 1, vRows, vCols>; + using DynStridDim5 = Stride<1, 1, 1, kTCols_, 1>; + using GlobalData = GlobalTensor; + using TileData = Tile; + + TileData srcTile(vRows, vCols); + TileData dstTile(vRows, vCols); + TASSIGN(srcTile, 0x0); + TASSIGN(dstTile, 0x10000); + + GlobalData srcGlobal(src); + GlobalData dstGlobal(result); + + TLOAD(srcTile, srcGlobal); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + TADDS(dstTile, srcTile, 1.0f); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + + TSTORE(dstGlobal, dstTile); + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); +} diff --git a/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/aiv/kernel_producer.cpp b/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/aiv/kernel_producer.cpp new file mode 100644 index 00000000..ccbc41e9 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/aiv/kernel_producer.cpp @@ -0,0 +1,91 @@ +/** + * Async Completion Demo - Simulated Producer Kernel (func_id=0) + * + * Implements: out[i] = in[i] * 2.0 + * + * After storing the output, writes 1 to a GM completion flag, then registers + * the completion via the CQ. The scheduler reads the CQ after this kernel + * returns and polls the flag address. + * + * Kernel args layout (packed by scheduler): + * args[0] = &Tensor(in) — input tensor struct pointer + * args[1] = &Tensor(out) — output tensor struct pointer + * args[2] = event_flag_gm_addr — GM flag addr (pre-allocated by golden.py) + * args[3] = cq_addr — completion queue (appended by submit_deferred) + */ + +#include +#include + +#include "tensor.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +#include "pto_cq_types.h" +#include "pto_cq_kernel_api.h" + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ Tensor* in_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ Tensor* out_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + uint64_t event_flag_addr = static_cast(args[2]); + uint64_t cq_addr = static_cast(args[3]); + + __gm__ float* in_data = reinterpret_cast<__gm__ float*>(in_tensor->buffer.addr) + in_tensor->start_offset; + __gm__ float* out_data = reinterpret_cast<__gm__ float*>(out_tensor->buffer.addr) + out_tensor->start_offset; + + constexpr int kTRows_ = 128; + constexpr int kTCols_ = 128; + constexpr int vRows = 128; + constexpr int vCols = 128; + + using DynShapeDim5 = Shape<1, 1, 1, vRows, vCols>; + using DynStridDim5 = Stride<1, 1, 1, kTCols_, 1>; + using GlobalData = GlobalTensor; + using TileData = Tile; + + TileData inTile(vRows, vCols); + TileData outTile(vRows, vCols); + TASSIGN(inTile, 0x0); + TASSIGN(outTile, 0x10000); + + GlobalData inGlobal(in_data); + GlobalData outGlobal(out_data); + + TLOAD(inTile, inGlobal); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // out = in + in = in * 2.0 + TADD(outTile, inTile, inTile); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + + TSTORE(outGlobal, outTile); + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); + + // Signal async completion: write non-zero flag to GM + volatile __gm__ int32_t* flag = reinterpret_cast( + static_cast(event_flag_addr)); +#if defined(SINGLE_CACHE_LINE) && defined(DSB_DDR) + dcci((__gm__ int32_t*)flag, SINGLE_CACHE_LINE); + *flag = 1; + dcci((__gm__ int32_t*)flag, SINGLE_CACHE_LINE); + dsb(DSB_DDR); +#else + *flag = 1; +#endif + + volatile __gm__ PTO2CompletionQueue* cq = pto2_cq_get(cq_addr); + pto2_cq_reset(cq); + pto2_save_expected_completion(PTO2_ENGINE_SDMA, cq, event_flag_addr); + pto2_cq_flush(cq); +} diff --git a/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/aiv/kernel_producer_async.cpp b/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/aiv/kernel_producer_async.cpp new file mode 100644 index 00000000..00bcab0c --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/aiv/kernel_producer_async.cpp @@ -0,0 +1,89 @@ +/** + * Async Completion Demo - Hardware 2P SDMA TGET Producer Kernel (func_id=2) + * + * Implements: + * 1. Read peer rank's input buffer via TGET_ASYNC into local out + * 2. Register the async event in the CQ + * 3. Return immediately so the runtime completes the task asynchronously + * + * This kernel is only compiled for real hardware (a2a3), not for simulation. + * + * Kernel args layout (packed by scheduler): + * args[0] = &Tensor(in) — input tensor struct pointer + * args[1] = &Tensor(out) — output tensor struct pointer + * args[2] = CommDeviceContext* — distributed communication context + * args[3] = sdma_context_addr — SDMA async context + * args[4] = cq_addr — completion queue (appended by submit_deferred) + */ + +#include +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +#include +#include "pto/comm/pto_comm_inst.hpp" +#include "pto/npu/comm/async/sdma/sdma_types.hpp" +#include "pto/common/pto_tile.hpp" + +#include "common/comm_context.h" +#include "tensor.h" + +using namespace pto; + +#include "pto_sq_kernel_api.h" + +template +AICORE inline __gm__ T* CommRemotePtr(__gm__ CommDeviceContext* ctx, __gm__ T* local_ptr, + int peer_rank) { + uint64_t local_base = ctx->windowsIn[ctx->rankId]; + uint64_t offset = (uint64_t)local_ptr - local_base; + return (__gm__ T*)(ctx->windowsIn[peer_rank] + offset); +} + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ Tensor* in_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ Tensor* out_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ CommDeviceContext* comm_ctx = + reinterpret_cast<__gm__ CommDeviceContext*>(args[2]); + uint64_t sdma_context = static_cast(args[3]); + uint64_t cq_addr = static_cast(args[4]); + + __gm__ float* in_data = reinterpret_cast<__gm__ float*>(in_tensor->buffer.addr) + in_tensor->start_offset; + __gm__ float* out_data = reinterpret_cast<__gm__ float*>(out_tensor->buffer.addr) + out_tensor->start_offset; + volatile __gm__ PTO2CompletionQueue* cq = pto2_cq_get(cq_addr); + pto2_cq_reset(cq); + + int my_rank = static_cast(comm_ctx->rankId); + int nranks = static_cast(comm_ctx->rankNum); + if (nranks != 2) { + pipe_barrier(PIPE_ALL); + return; + } + int peer_rank = 1 - my_rank; + + constexpr int kTotalElems = 128 * 128; + + using FlatShape = Shape<1, 1, 1, 1, kTotalElems>; + using FlatStride = Stride; + using FlatGlobalData = GlobalTensor; + FlatGlobalData outGlobalFlat(out_data); + __gm__ float* remote_in_data = CommRemotePtr(comm_ctx, in_data, peer_rank); + FlatGlobalData remoteInGlobalFlat(remote_in_data); + + using ScratchTile = pto::Tile; + ScratchTile scratchTile; + TASSIGN(scratchTile, 0x20000); + + __gm__ uint8_t* context = reinterpret_cast<__gm__ uint8_t*>(static_cast(sdma_context)); + + auto desc = pto2_sdma_tget_descriptor(outGlobalFlat, remoteInGlobalFlat, scratchTile, context); + uint64_t tag = pto2_send_request_entry(PTO2_ENGINE_SDMA, PTO2_SQ_ID_AUTO, desc); + pto2_save_expected_completion(PTO2_ENGINE_SDMA, cq, tag); + + pto2_cq_flush(cq); +} diff --git a/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/kernel_config.py b/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/kernel_config.py new file mode 100644 index 00000000..d56faa32 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/kernel_config.py @@ -0,0 +1,56 @@ +""" +Async Completion Demo - Kernel and Orchestration Configuration + +Two hardware cards use the existing deferred-completion producer API to +demonstrate a real 2P TGET_ASYNC remote read. The legacy single-card / sim +path stays available for local debugging. +""" + +import os +from pathlib import Path + +_KERNELS_ROOT = Path(__file__).parent + +ORCHESTRATION = { + "source": str(_KERNELS_ROOT / "orchestration" / "async_demo_orchestration.cpp"), + "function_name": "aicpu_orchestration_entry", +} + +_platform = os.environ.get("PTO_PLATFORM", "a2a3sim") + +KERNELS = [ + {"func_id": 0, "source": str(_KERNELS_ROOT / "aiv" / "kernel_producer.cpp"), "core_type": "aiv"}, + {"func_id": 1, "source": str(_KERNELS_ROOT / "aiv" / "kernel_consumer.cpp"), "core_type": "aiv"}, +] + +if _platform == "a2a3": + KERNELS.append( + {"func_id": 2, "source": str(_KERNELS_ROOT / "aiv" / "kernel_producer_async.cpp"), "core_type": "aiv"}, + ) + +RUNTIME_CONFIG = { + "runtime": "tensormap_and_ringbuffer", + "aicpu_thread_num": 4, + "orch_thread_num": 1, + "block_dim": 3, + "rounds": 1, +} + +if _platform == "a2a3": + RUNTIME_ENV = { + "PTO2_ENABLE_SDMA": "1", + } + + DISTRIBUTED_CONFIG = { + "nranks": 2, + "root": 0, + "win_sync_prefix": 256, + "buffers": [ + {"name": "in", "dtype": "float32", "count": 128 * 128, "placement": "window"}, + {"name": "out", "dtype": "float32", "count": 128 * 128, "placement": "window"}, + {"name": "result", "dtype": "float32", "count": 128 * 128, "placement": "device"}, + ], + "inputs": ["in"], + "outputs": ["out", "result"], + "args": ["in", "out", "result", "deviceCtx"], + } diff --git a/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/orchestration/async_demo_orchestration.cpp b/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/orchestration/async_demo_orchestration.cpp new file mode 100644 index 00000000..2cf6ba4d --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels/orchestration/async_demo_orchestration.cpp @@ -0,0 +1,135 @@ +/** + * Async Completion Demo - Device-side orchestration (CQ model) + * + * Two execution modes share this file: + * + * 1. Single-card / sim mode (legacy demo): + * t0 (producer): out = in * 2.0 [deferred completion via CQ] + * t1 (consumer): result = out + 1.0 [run-to-completion] + * + * 2. Two-card hardware mode: + * both ranks submit one deferred producer task that TGET_ASYNCs the peer + * rank's input buffer into local out, then run the normal consumer on out. + * + * CQ model: + * Orchestration marks t0 as complete_in_future and passes a CQ address. + * The producer kernel decides at runtime what completions it needs and writes + * them into the completion queue. The scheduler reads the CQ after the kernel + * returns and registers completions dynamically. + */ + +#include +#include + +#include "common/comm_context.h" +#include "pto_orchestration_api.h" + +#define ARG_PTR_IN 0 +#define ARG_PTR_OUT 1 +#define ARG_PTR_RESULT 2 +#define ARG_PTR_EVENT_HANDLE_OUTPUT 3 + +#define ARG_SIZE_IN 4 +#define ARG_SIZE_OUT 5 +#define ARG_SIZE_RESULT 6 +#define ARG_SIZE_EVENT_HANDLE_OUTPUT 7 + +#define ARG_SIZE 8 + +extern "C" { + +__attribute__((visibility("default"))) +PTO2OrchestrationConfig aicpu_orchestration_config(uint64_t* args, int arg_count) { + (void)args; + return PTO2OrchestrationConfig{ + .expected_arg_count = (arg_count >= 9) ? 9 : 4, + }; +} + +__attribute__((visibility("default"))) +void aicpu_orchestration_entry(uint64_t* args, int arg_count, + int orch_thread_num, int orch_thread_index) { + (void)arg_count; + (void)orch_thread_num; + if (orch_thread_index != 0) return; + + if (arg_count == 4) { + void* in_ptr = (void*)(uintptr_t)args[0]; + void* out_ptr = (void*)(uintptr_t)args[1]; + void* result_ptr = (void*)(uintptr_t)args[2]; + auto* comm_ctx = reinterpret_cast((uintptr_t)args[3]); + int my_rank = (int)comm_ctx->rankId; + + uint32_t shapes[1] = {128 * 128}; + Tensor ext_in = make_tensor_external(in_ptr, shapes, 1, DataType::FLOAT32); + Tensor ext_out = make_tensor_external(out_ptr, shapes, 1, DataType::FLOAT32); + Tensor ext_result = make_tensor_external(result_ptr, shapes, 1, DataType::FLOAT32); + + uint64_t sdma_context = pto2_rt_get_sdma_context(); + uint64_t cq = pto2_rt_alloc_cq(); + if (sdma_context == 0 || cq == 0) { + LOG_ERROR("async_demo 2P: rank %d failed to get SDMA context or CQ (sdma=0x%lx, cq=0x%lx)", + my_rank, sdma_context, cq); + return; + } + + PTOParam params_producer; + params_producer.add_input(ext_in); + params_producer.add_output(ext_out); + params_producer.add_scalar((uint64_t)(uintptr_t)comm_ctx); + params_producer.add_scalar(sdma_context); + pto2_rt_submit_aiv_task_deferred(2, params_producer, cq); + + PTOParam params_consumer; + params_consumer.add_input(ext_out); + params_consumer.add_output(ext_result); + pto2_rt_submit_aiv_task(1, params_consumer); + + LOG_INFO("async_demo 2P: rank %d submitted TGET_ASYNC producer with CQ", my_rank); + return; + } + + void* in_ptr = (void*)(uintptr_t)args[ARG_PTR_IN]; + void* out_ptr = (void*)(uintptr_t)args[ARG_PTR_OUT]; + void* result_ptr = (void*)(uintptr_t)args[ARG_PTR_RESULT]; + uint64_t event_handle_output_gm = args[ARG_PTR_EVENT_HANDLE_OUTPUT]; + int SIZE = (int)(args[ARG_SIZE] & 0x7FFFFFFF); + + uint64_t sdma_context = pto2_rt_get_sdma_context(); + uint64_t cq = pto2_rt_alloc_cq(); + + LOG_INFO("async_demo: SIZE=%d, event_handle_output=0x%lx, sdma_context=0x%lx, cq=0x%lx", + SIZE, event_handle_output_gm, sdma_context, cq); + + uint32_t shapes[1] = {(uint32_t)SIZE}; + Tensor ext_in = make_tensor_external(in_ptr, shapes, 1, DataType::FLOAT32); + Tensor ext_out = make_tensor_external(out_ptr, shapes, 1, DataType::FLOAT32); + Tensor ext_result = make_tensor_external(result_ptr, shapes, 1, DataType::FLOAT32); + + if (sdma_context != 0) { + // HW mode: kernel issues async SDMA request and puts event.handle directly in CQ entry. + PTOParam params_producer; + params_producer.add_input(ext_in); + params_producer.add_output(ext_out); + params_producer.add_scalar(sdma_context); + pto2_rt_submit_aiv_task_deferred(2, params_producer, cq); + + LOG_INFO("async_demo: HW mode - submitted async SDMA producer (func_id=2) with CQ"); + } else { + PTOParam params_producer; + params_producer.add_input(ext_in); + params_producer.add_output(ext_out); + params_producer.add_scalar(event_handle_output_gm); + pto2_rt_submit_aiv_task_deferred(0, params_producer, cq); + + LOG_INFO("async_demo: Sim mode - submitted producer (func_id=0) with CQ"); + } + + // t1 (consumer): result = out + 1.0 — normal run-to-completion + PTOParam params_consumer; + params_consumer.add_input(ext_out); + params_consumer.add_output(ext_result); + pto2_rt_submit_aiv_task(1, params_consumer); +} + +} // extern "C" diff --git a/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/golden.py b/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/golden.py new file mode 100644 index 00000000..1d8ffa87 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/golden.py @@ -0,0 +1,70 @@ +""" +Golden script for async_notify_demo. + +Two hardware ranks each produce `out = in * 2` and TNOTIFY the peer. +The consumer is launch-gated on the local notification counter >= 1. +When the consumer runs, it reads notify_counter (must be 1) and computes +`result = out + notify_counter = in*2 + 1`. +""" + +import torch + +__outputs__ = ["result", "out"] + +RTOL = 1e-5 +ATOL = 1e-5 + + +def generate_distributed_inputs(rank: int, nranks: int, root: int, + comm_ctx=None) -> list: + del rank + del nranks + del root + del comm_ctx + + size = 128 * 128 + inp = [float(i % 251) / 10.0 for i in range(size)] + out = [0.0] * size + result = [0.0] * size + notify_counter = [0] + + return [ + ("in", inp), + ("out", out), + ("result", result), + ("notify_counter", notify_counter), + ] + + +def generate_inputs(params: dict) -> list: + del params + + size = 128 * 128 + inp = torch.tensor([float(i % 251) / 10.0 for i in range(size)], dtype=torch.float32) + out = torch.zeros(size, dtype=torch.float32) + result = torch.zeros(size, dtype=torch.float32) + notify_counter = torch.zeros(1, dtype=torch.int32) + + return [ + ("in", inp), + ("out", out), + ("result", result), + ("notify_counter", notify_counter), + ] + + +def compute_golden(tensors: dict, params: dict) -> None: + del params + + if "in" in tensors: + inp = torch.as_tensor(tensors["in"]) + tensors["out"][:] = inp * 2.0 + tensors["result"][:] = tensors["out"] + 1.0 + return + + out = tensors["out"] + result = tensors["result"] + for i in range(len(out)): + value = float(i % 251) / 10.0 + out[i] = value * 2.0 + result[i] = out[i] + 1.0 diff --git a/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/aiv/kernel_consumer.cpp b/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/aiv/kernel_consumer.cpp new file mode 100644 index 00000000..28380969 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/aiv/kernel_consumer.cpp @@ -0,0 +1,72 @@ +/** + * Async Notify Demo - Consumer Kernel (func_id=1) + * + * Implements: result[i] = src[i] + notify_counter[0] + * + * Depends on NotifyWait completing (via dummy tensor), guaranteeing + * the local notification counter >= 1 before this kernel runs. + * + * Kernel args layout (packed by scheduler): + * args[0] = &Tensor(dummy_notify) — input (dependency token from NotifyWait) + * args[1] = &Tensor(src) — input tensor struct pointer (producer's output) + * args[2] = &Tensor(result) — output tensor struct pointer + * args[3] = notify_counter_addr — local notify counter (window memory) + */ + +#include +#include + +#include "tensor.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + // args[0] = dummy_notify tensor (dependency token, unused) + __gm__ Tensor* src_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ Tensor* result_tensor = reinterpret_cast<__gm__ Tensor*>(args[2]); + __gm__ int32_t* notify_counter = reinterpret_cast<__gm__ int32_t*>(args[3]); + + __gm__ float* src = + reinterpret_cast<__gm__ float*>(src_tensor->buffer.addr) + src_tensor->start_offset; + __gm__ float* result = + reinterpret_cast<__gm__ float*>(result_tensor->buffer.addr) + result_tensor->start_offset; + + constexpr int kTRows_ = 128; + constexpr int kTCols_ = 128; + constexpr int vRows = 128; + constexpr int vCols = 128; + + using DynShapeDim5 = Shape<1, 1, 1, vRows, vCols>; + using DynStridDim5 = Stride<1, 1, 1, kTCols_, 1>; + using GlobalData = GlobalTensor; + using TileData = Tile; + + TileData srcTile(vRows, vCols); + TileData dstTile(vRows, vCols); + TASSIGN(srcTile, 0x0); + TASSIGN(dstTile, 0x10000); + + GlobalData srcGlobal(src); + GlobalData dstGlobal(result); + + TLOAD(srcTile, srcGlobal); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + float notify_value = static_cast(*notify_counter); + TADDS(dstTile, srcTile, notify_value); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + + TSTORE(dstGlobal, dstTile); + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); +} diff --git a/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/aiv/kernel_notify_wait.cpp b/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/aiv/kernel_notify_wait.cpp new file mode 100644 index 00000000..36d68086 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/aiv/kernel_notify_wait.cpp @@ -0,0 +1,46 @@ +/** + * NotifyWait Kernel — register notification counter as CQ condition (func_id=2) + * + * Trivial deferred-completion kernel: registers a COUNTER wait condition + * for the notification counter, then returns immediately. The scheduler + * polls the counter via the CQ mechanism and completes this task once + * *notify_counter >= expected_value. + * + * Kernel args layout: + * args[0] = &Tensor(dummy_notify) — output (dependency token for downstream) + * args[1] = notify_counter_addr — scalar (GM int32* to poll) + * args[2] = expected_value — scalar (threshold) + * args[3] = cq_addr — scalar (auto-appended by deferred submit) + */ + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +#include +#include "tensor.h" +#include "pto_cq_kernel_api.h" + +extern "C" __aicore__ __attribute__((always_inline)) +void kernel_entry(__gm__ int64_t* args) { + uint64_t notify_counter_addr = static_cast(args[1]); + uint32_t expected_value = static_cast(args[2]); + uint64_t cq_addr = static_cast(args[3]); + + volatile __gm__ PTO2CompletionQueue* cq = pto2_cq_get(cq_addr); + pto2_cq_reset(cq); + pto2_save_expected_completion(PTO2_ENGINE_SDMA, cq, + notify_counter_addr, expected_value); + // Flush CQ writes from AICore data cache to GM so the AICPU scheduler + // can read them. pto2_cq_flush's #if-defined guards don't fire because + // the constants are C++ enums, not macros — call intrinsics directly. + dcci((__gm__ int32_t*)cq, cache_line_t::ENTIRE_DATA_CACHE, dcci_dst_t::CACHELINE_OUT); + dsb(DSB_DDR); + pipe_barrier(PIPE_ALL); +} diff --git a/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/aiv/kernel_producer_notify.cpp b/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/aiv/kernel_producer_notify.cpp new file mode 100644 index 00000000..a7547c33 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/aiv/kernel_producer_notify.cpp @@ -0,0 +1,103 @@ +/** + * Async Notify Demo - Hardware 2P Notify Producer Kernel (func_id=0) + * + * Implements: + * 1. Local compute: out[i] = in[i] * 2.0 + * 2. Notify peer rank via TNOTIFY(AtomicAdd) on the peer's window counter + * 3. Return normally (run-to-completion, no deferred completion) + * + * Rank 1 inserts a deliberate delay before notifying. This makes missing + * launch-gating on the consumer side visible in the example output. + * + * Kernel args layout (packed by scheduler): + * args[0] = &Tensor(in) — input tensor struct pointer + * args[1] = &Tensor(out) — output tensor struct pointer + * args[2] = notify_counter_addr — local notify counter (window memory) + * args[3] = CommDeviceContext* — distributed communication context + */ + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +#include +#include "pto/common/pto_tile.hpp" + +#include "common/comm_context.h" +#include "tensor.h" + +using namespace pto; + +#include "pto_notify_kernel_api.h" + +template +AICORE inline __gm__ T* CommRemotePtr(__gm__ CommDeviceContext* ctx, __gm__ T* local_ptr, + int peer_rank) { + uint64_t local_base = ctx->windowsIn[ctx->rankId]; + uint64_t offset = (uint64_t)local_ptr - local_base; + return (__gm__ T*)(ctx->windowsIn[peer_rank] + offset); +} + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t* args) { + __gm__ Tensor* in_tensor = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ Tensor* out_tensor = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ int32_t* local_counter = reinterpret_cast<__gm__ int32_t*>(args[2]); + __gm__ CommDeviceContext* comm_ctx = reinterpret_cast<__gm__ CommDeviceContext*>(args[3]); + + __gm__ float* in_data = + reinterpret_cast<__gm__ float*>(in_tensor->buffer.addr) + in_tensor->start_offset; + __gm__ float* out_data = + reinterpret_cast<__gm__ float*>(out_tensor->buffer.addr) + out_tensor->start_offset; + + int my_rank = static_cast(comm_ctx->rankId); + int nranks = static_cast(comm_ctx->rankNum); + if (nranks != 2) { + pipe_barrier(PIPE_ALL); + return; + } + int peer_rank = 1 - my_rank; + + constexpr int kTRows_ = 128; + constexpr int kTCols_ = 128; + constexpr int vRows = 128; + constexpr int vCols = 128; + + using DynShapeDim5 = Shape<1, 1, 1, vRows, vCols>; + using DynStridDim5 = Stride<1, 1, 1, kTCols_, 1>; + using GlobalData = GlobalTensor; + using TileData = Tile; + + TileData inTile(vRows, vCols); + TileData outTile(vRows, vCols); + TASSIGN(inTile, 0x0); + TASSIGN(outTile, 0x10000); + + GlobalData inGlobal(in_data); + GlobalData outGlobal(out_data); + + TLOAD(inTile, inGlobal); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + TADD(outTile, inTile, inTile); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + + TSTORE(outGlobal, outTile); + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID7); + + if (my_rank == 1) { + for (volatile int i = 0; i < 2000000; ++i) { + } + } + + __gm__ int32_t* remote_counter = CommRemotePtr(comm_ctx, local_counter, peer_rank); + pto2_send_notification(remote_counter, 1, PTO2NotifyOp::AtomicAdd); +} diff --git a/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/kernel_config.py b/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/kernel_config.py new file mode 100644 index 00000000..d0a3eda0 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/kernel_config.py @@ -0,0 +1,50 @@ +""" +Async Notify Demo - Kernel and Orchestration Configuration + +Two hardware cards use TNOTIFY(AtomicAdd) for inter-rank notification. +The consumer depends on a deferred NotifyWait task that polls the +local notification counter >= 1 via the CQ mechanism. +""" + +import os +from pathlib import Path + +_KERNELS_ROOT = Path(__file__).parent +_platform = os.environ.get("PTO_PLATFORM", "a2a3sim") + +if _platform != "a2a3": + raise RuntimeError("async_notify_demo currently requires PTO_PLATFORM=a2a3") + +ORCHESTRATION = { + "source": str(_KERNELS_ROOT / "orchestration" / "async_notify_orchestration.cpp"), + "function_name": "aicpu_orchestration_entry", +} + +KERNELS = [ + {"func_id": 0, "source": str(_KERNELS_ROOT / "aiv" / "kernel_producer_notify.cpp"), "core_type": "aiv"}, + {"func_id": 1, "source": str(_KERNELS_ROOT / "aiv" / "kernel_consumer.cpp"), "core_type": "aiv"}, + {"func_id": 2, "source": str(_KERNELS_ROOT / "aiv" / "kernel_notify_wait.cpp"), "core_type": "aiv"}, +] + +RUNTIME_CONFIG = { + "runtime": "tensormap_and_ringbuffer", + "aicpu_thread_num": 4, + "orch_thread_num": 1, + "block_dim": 3, + "rounds": 1, +} + +DISTRIBUTED_CONFIG = { + "nranks": 2, + "root": 0, + "win_sync_prefix": 256, + "buffers": [ + {"name": "in", "dtype": "float32", "count": 128 * 128, "placement": "window"}, + {"name": "out", "dtype": "float32", "count": 128 * 128, "placement": "device"}, + {"name": "result", "dtype": "float32", "count": 128 * 128, "placement": "device"}, + {"name": "notify_counter", "dtype": "int32", "count": 1, "placement": "window"}, + ], + "inputs": ["in", "notify_counter"], + "outputs": ["out", "result"], + "args": ["in", "out", "result", "notify_counter", "deviceCtx"], +} diff --git a/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/orchestration/async_notify_orchestration.cpp b/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/orchestration/async_notify_orchestration.cpp new file mode 100644 index 00000000..4b5b3992 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels/orchestration/async_notify_orchestration.cpp @@ -0,0 +1,80 @@ +/** + * Async Notify Demo - Device-side orchestration + * + * Two-card hardware mode: + * t0 (producer, func_id=0): out = in * 2, then TNOTIFY(AtomicAdd) the + * peer's window counter. Completes normally (RTC). + * t1 (notify_wait, func_id=2, deferred): registers notification counter + * condition (counter >= 1) via CQ, returns immediately. + * Produces dummy_notify tensor for dependency chain. + * t2 (consumer, func_id=1): result = out + notify_counter. + * Depends on both producer (via ext_out) and notify_wait + * (via dummy_notify), ensuring counter >= 1 before reading. + * + * The notify counter is pre-zeroed by the distributed runner input loader. + */ + +#include + +#include "common/comm_context.h" +#include "pto_orchestration_api.h" + +extern "C" { + +__attribute__((visibility("default"))) +PTO2OrchestrationConfig aicpu_orchestration_config(uint64_t* args, int arg_count) { + (void)args; + return PTO2OrchestrationConfig{ + .expected_arg_count = 5, + }; +} + +__attribute__((visibility("default"))) +void aicpu_orchestration_entry(uint64_t* args, int arg_count, + int orch_thread_num, int orch_thread_index) { + (void)orch_thread_num; + if (orch_thread_index != 0) return; + + if (arg_count != 5) { + LOG_ERROR("async_notify_demo: expected 5 args, got %d", arg_count); + return; + } + + void* in_ptr = (void*)(uintptr_t)args[0]; + void* out_ptr = (void*)(uintptr_t)args[1]; + void* result_ptr = (void*)(uintptr_t)args[2]; + void* notify_counter_ptr = (void*)(uintptr_t)args[3]; + auto* comm_ctx = reinterpret_cast((uintptr_t)args[4]); + int my_rank = (int)comm_ctx->rankId; + + uint32_t shapes[1] = {128 * 128}; + Tensor ext_in = make_tensor_external(in_ptr, shapes, 1, DataType::FLOAT32); + Tensor ext_out = make_tensor_external(out_ptr, shapes, 1, DataType::FLOAT32); + Tensor ext_result = make_tensor_external(result_ptr, shapes, 1, DataType::FLOAT32); + + // Producer: normal run-to-completion task (sends TNOTIFY to peer) + PTOParam params_producer; + params_producer.add_input(ext_in); + params_producer.add_output(ext_out); + params_producer.add_scalar((uint64_t)(uintptr_t)notify_counter_ptr); + params_producer.add_scalar((uint64_t)(uintptr_t)comm_ctx); + pto2_rt_submit_aiv_task(0, params_producer); + + // NotifyWait: deferred task that waits for notification counter >= 1. + // Returns a dependency token tensor for downstream tasks. + Tensor notify_token = pto2_rt_submit_notification_wait_task( + 2, (uint64_t)(uintptr_t)notify_counter_ptr, 1); + + // Consumer: depends on producer (via ext_out) and notify_wait (via token). + PTOParam params_consumer; + params_consumer.add_input(notify_token); + params_consumer.add_input(ext_out); + params_consumer.add_output(ext_result); + params_consumer.add_scalar((uint64_t)(uintptr_t)notify_counter_ptr); + pto2_rt_submit_aiv_task(1, params_consumer); + + LOG_INFO("async_notify_demo: rank %d producer=RTC, notify_wait=deferred(counter=0x%lx), consumer=RTC", + my_rank, (uint64_t)(uintptr_t)notify_counter_ptr); +} + +} // extern "C" diff --git a/examples/a2a3/tensormap_and_ringbuffer/moe_dispatch/golden.py b/examples/a2a3/tensormap_and_ringbuffer/moe_dispatch/golden.py new file mode 100644 index 00000000..d1cac620 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/moe_dispatch/golden.py @@ -0,0 +1,100 @@ +""" +Golden script for MOE Dispatch V2 — 8-rank multi-expert dispatch. + +Routing: expert_ids[i] = i % TOTAL_EXPERTS (deterministic round-robin) +Tokens: tokens[i][j] = ((rank * NUM_TOKENS + i) * HIDDEN_DIM + j) / 1e5 + +Per rank, the prepare kernel partitions tokens: + - Local expert tokens -> written directly to shmem_data[slot] + - Remote expert tokens -> packed into send_staging[target_rank][expert_offset] + +SendData TPUT_ASYNCs each (peer, expert) staging buffer to the peer's shmem_data. +SendCount TPUT_ASYNCs per-peer counts to the peer's recv_counts + TNOTIFY. +RecvAssemble reads shmem_data + counts after 7 notifications, assembles expandX. + +Verified outputs (symmetric across all ranks): + - expert_token_nums: each local expert receives 1 token from each of 8 ranks = [8, 8] + - local_counts: each rank has 2 local tokens, 1 per local expert = [1, 1, 0, ...] +""" + +NUM_TOKENS = 16 +HIDDEN_DIM = 128 +NUM_RANKS = 8 +EXPERTS_PER_RANK = 2 +TOTAL_EXPERTS = NUM_RANKS * EXPERTS_PER_RANK +NUM_EXPERT_SLOTS = EXPERTS_PER_RANK * NUM_RANKS +EXPAND_X_ROWS = NUM_TOKENS * NUM_RANKS +COUNT_PAD = 32 + +__outputs__ = ["expert_token_nums", "local_counts"] + +RTOL = 1e-5 +ATOL = 1e-5 + + +def _make_tokens(rank): + tokens = [0.0] * (NUM_TOKENS * HIDDEN_DIM) + for i in range(NUM_TOKENS): + for j in range(HIDDEN_DIM): + tokens[i * HIDDEN_DIM + j] = float( + (rank * NUM_TOKENS + i) * HIDDEN_DIM + j) / 100000.0 + return tokens + + +def _route_expert_ids(): + return [i % TOTAL_EXPERTS for i in range(NUM_TOKENS)] + + +def generate_distributed_inputs(rank: int, nranks: int, root: int, + comm_ctx=None) -> list: + del root, comm_ctx + + tokens = _make_tokens(rank) + expert_ids = _route_expert_ids() + + return [ + ("tokens", tokens), + ("expert_ids", expert_ids), + ("shmem_data", [0.0] * (NUM_EXPERT_SLOTS * NUM_TOKENS * HIDDEN_DIM)), + ("send_staging", [0.0] * (NUM_RANKS * EXPERTS_PER_RANK * NUM_TOKENS * HIDDEN_DIM)), + ("local_counts", [0] * COUNT_PAD), + ("send_counts", [0] * (NUM_RANKS * COUNT_PAD)), + ("recv_counts", [0] * (NUM_RANKS * COUNT_PAD)), + ("notify_counter", [0]), + ("expand_x", [0.0] * (EXPAND_X_ROWS * HIDDEN_DIM)), + ("expert_token_nums", [0] * EXPERTS_PER_RANK), + ] + + +def compute_golden(tensors: dict, params: dict) -> None: + nranks = params.get("nranks", NUM_RANKS) + my_rank = params.get("root", 0) + + expert_ids = _route_expert_ids() + local_expert_start = my_rank * EXPERTS_PER_RANK + + local_counts = tensors["local_counts"] + for k in range(EXPERTS_PER_RANK): + local_counts[k] = 0 + + for i in range(NUM_TOKENS): + eid = expert_ids[i] + target_rank = eid // EXPERTS_PER_RANK + expert_offset = eid % EXPERTS_PER_RANK + + if target_rank == my_rank: + local_counts[expert_offset] += 1 + + expert_token_nums = tensors["expert_token_nums"] + + for exp_off in range(EXPERTS_PER_RANK): + expert_total = 0 + for src_rank in range(nranks): + src_expert_id = local_expert_start + exp_off + src_expert_ids = _route_expert_ids() + + for i in range(NUM_TOKENS): + if src_expert_ids[i] == src_expert_id: + expert_total += 1 + + expert_token_nums[exp_off] = expert_total diff --git a/examples/a2a3/tensormap_and_ringbuffer/moe_dispatch/kernels/aiv/kernel_moe_prepare.cpp b/examples/a2a3/tensormap_and_ringbuffer/moe_dispatch/kernels/aiv/kernel_moe_prepare.cpp new file mode 100644 index 00000000..5fcf2bf6 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/moe_dispatch/kernels/aiv/kernel_moe_prepare.cpp @@ -0,0 +1,112 @@ +/** + * MOE Prepare Kernel — route tokens and pack per-rank staging buffers (func_id=0) + * + * For each token, determine the target expert/rank, then: + * - Local expert -> shmem_data[slot] (in-place, no SDMA needed) + * - Remote expert -> send_staging[target_rank][expert_offset] + * + * Computes per-expert local_counts and per-(target_rank, expert_offset) send_counts. + * + * Kernel args layout: + * args[0] = &Tensor(tokens) — input [NUM_TOKENS * HIDDEN_DIM] float + * args[1] = &Tensor(expert_ids) — input [NUM_TOKENS] int32 + * args[2] = &Tensor(send_staging) — output [NUM_RANKS * EXPERTS_PER_RANK * NUM_TOKENS * HIDDEN_DIM] float + * args[3] = &Tensor(local_counts) — output [COUNT_PAD] int32 + * args[4] = shmem_data_addr — scalar (GM float* base) + * args[5] = send_counts_addr — scalar (GM int32* base, [NUM_RANKS * COUNT_PAD]) + * args[6] = CommDeviceContext* — scalar + */ + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +#include "common/comm_context.h" +#include "tensor.h" + +static constexpr int NUM_TOKENS = 16; +static constexpr int HIDDEN_DIM = 128; +static constexpr int NUM_RANKS = 8; +static constexpr int EXPERTS_PER_RANK = 2; +static constexpr int NUM_EXPERT_SLOTS = EXPERTS_PER_RANK * NUM_RANKS; +static constexpr int COUNT_PAD = 32; +static constexpr int SLOT_ELEMS = NUM_TOKENS * HIDDEN_DIM; + +extern "C" __aicore__ __attribute__((always_inline)) +void kernel_entry(__gm__ int64_t* args) { + __gm__ Tensor* tokens_t = reinterpret_cast<__gm__ Tensor*>(args[0]); + __gm__ Tensor* expert_ids_t = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ Tensor* send_stg_t = reinterpret_cast<__gm__ Tensor*>(args[2]); + __gm__ Tensor* local_cnt_t = reinterpret_cast<__gm__ Tensor*>(args[3]); + + __gm__ float* shmem_data = + reinterpret_cast<__gm__ float*>(static_cast(args[4])); + __gm__ int32_t* send_counts = + reinterpret_cast<__gm__ int32_t*>(static_cast(args[5])); + __gm__ CommDeviceContext* comm_ctx = + reinterpret_cast<__gm__ CommDeviceContext*>(static_cast(args[6])); + + __gm__ float* tokens = + reinterpret_cast<__gm__ float*>(tokens_t->buffer.addr) + tokens_t->start_offset; + __gm__ int32_t* expert_ids = + reinterpret_cast<__gm__ int32_t*>(expert_ids_t->buffer.addr) + expert_ids_t->start_offset; + __gm__ float* send_staging = + reinterpret_cast<__gm__ float*>(send_stg_t->buffer.addr) + send_stg_t->start_offset; + __gm__ int32_t* local_counts = + reinterpret_cast<__gm__ int32_t*>(local_cnt_t->buffer.addr) + local_cnt_t->start_offset; + + int my_rank = static_cast(comm_ctx->rankId); + + pipe_barrier(PIPE_ALL); + + int l_counts[EXPERTS_PER_RANK] = {}; + int s_counts[NUM_RANKS * EXPERTS_PER_RANK] = {}; + + for (int i = 0; i < NUM_TOKENS; i++) { + int eid = static_cast(expert_ids[i]); + int target_rank = eid / EXPERTS_PER_RANK; + int expert_offset = eid % EXPERTS_PER_RANK; + + __gm__ float* src_ptr = tokens + i * HIDDEN_DIM; + + if (target_rank == my_rank) { + int slot = expert_offset * NUM_RANKS + my_rank; + int idx = l_counts[expert_offset]; + __gm__ float* dst_ptr = shmem_data + + (slot * NUM_TOKENS + idx) * HIDDEN_DIM; + for (int j = 0; j < HIDDEN_DIM; j++) { + dst_ptr[j] = src_ptr[j]; + } + l_counts[expert_offset]++; + } else { + int staging_idx = target_rank * EXPERTS_PER_RANK + expert_offset; + int idx = s_counts[staging_idx]; + __gm__ float* dst_ptr = send_staging + + (staging_idx * NUM_TOKENS + idx) * HIDDEN_DIM; + for (int j = 0; j < HIDDEN_DIM; j++) { + dst_ptr[j] = src_ptr[j]; + } + s_counts[staging_idx]++; + } + } + + pipe_barrier(PIPE_ALL); + + for (int k = 0; k < EXPERTS_PER_RANK; k++) { + local_counts[k] = static_cast(l_counts[k]); + } + for (int r = 0; r < NUM_RANKS; r++) { + for (int e = 0; e < EXPERTS_PER_RANK; e++) { + send_counts[r * COUNT_PAD + e] = + static_cast(s_counts[r * EXPERTS_PER_RANK + e]); + } + } + + pipe_barrier(PIPE_ALL); +} diff --git a/examples/a2a3/tensormap_and_ringbuffer/moe_dispatch/kernels/aiv/kernel_moe_recv_assemble.cpp b/examples/a2a3/tensormap_and_ringbuffer/moe_dispatch/kernels/aiv/kernel_moe_recv_assemble.cpp new file mode 100644 index 00000000..ade1eb7d --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/moe_dispatch/kernels/aiv/kernel_moe_recv_assemble.cpp @@ -0,0 +1,113 @@ +/** + * MOE RecvAssemble Kernel — cumsum + assemble expandX (func_id=2) + * + * Depends on NotifyWait completing (via dummy tensor input), + * guaranteeing notify_counter >= NUM_RANKS-1 (7 peers done). + * + * Reads local_counts + per-source-rank recv_counts, computes cumulative + * sums for assembly offsets, copies token data from shmem_data slots + * into expandX, and writes expert_token_nums. + * + * Slot ordering: for each expert_offset, enumerate source ranks: + * slot = expert_offset * NUM_RANKS + src_rank + * count = local_counts[expert_offset] if src_rank == my_rank + * = recv_counts[src_rank * COUNT_PAD + expert_offset] otherwise + * + * Kernel args layout: + * args[0] = &Tensor(dummy_notify) — input (dependency token from NotifyWait) + * args[1] = &Tensor(local_counts) — input [COUNT_PAD] int32 + * args[2] = &Tensor(expand_x) — output [EXPAND_X_ROWS * HIDDEN_DIM] float + * args[3] = &Tensor(expert_token_nums) — output [EXPERTS_PER_RANK] int32 + * args[4] = shmem_data_addr — scalar (GM float* base) + * args[5] = recv_counts_addr — scalar (GM int32*, [NUM_RANKS * COUNT_PAD]) + * args[6] = CommDeviceContext* — scalar + */ + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +#include "common/comm_context.h" +#include "tensor.h" + +static constexpr int NUM_TOKENS = 16; +static constexpr int HIDDEN_DIM = 128; +static constexpr int NUM_RANKS = 8; +static constexpr int EXPERTS_PER_RANK = 2; +static constexpr int NUM_EXPERT_SLOTS = EXPERTS_PER_RANK * NUM_RANKS; +static constexpr int COUNT_PAD = 32; + +extern "C" __aicore__ __attribute__((always_inline)) +void kernel_entry(__gm__ int64_t* args) { + // args[0] = dummy_notify tensor (dependency token, unused) + __gm__ Tensor* local_cnt_t = reinterpret_cast<__gm__ Tensor*>(args[1]); + __gm__ Tensor* expand_x_t = reinterpret_cast<__gm__ Tensor*>(args[2]); + __gm__ Tensor* etn_t = reinterpret_cast<__gm__ Tensor*>(args[3]); + + __gm__ float* shmem_data = + reinterpret_cast<__gm__ float*>(static_cast(args[4])); + __gm__ int32_t* recv_counts = + reinterpret_cast<__gm__ int32_t*>(static_cast(args[5])); + __gm__ CommDeviceContext* comm_ctx = + reinterpret_cast<__gm__ CommDeviceContext*>(static_cast(args[6])); + + __gm__ int32_t* local_counts = + reinterpret_cast<__gm__ int32_t*>(local_cnt_t->buffer.addr) + local_cnt_t->start_offset; + __gm__ float* expand_x = + reinterpret_cast<__gm__ float*>(expand_x_t->buffer.addr) + expand_x_t->start_offset; + __gm__ int32_t* expert_token_nums = + reinterpret_cast<__gm__ int32_t*>(etn_t->buffer.addr) + etn_t->start_offset; + + int my_rank = static_cast(comm_ctx->rankId); + + pipe_barrier(PIPE_ALL); + + int slot_counts[NUM_EXPERT_SLOTS]; + for (int exp_off = 0; exp_off < EXPERTS_PER_RANK; exp_off++) { + for (int src_rank = 0; src_rank < NUM_RANKS; src_rank++) { + int slot = exp_off * NUM_RANKS + src_rank; + if (src_rank == my_rank) { + slot_counts[slot] = static_cast(local_counts[exp_off]); + } else { + slot_counts[slot] = static_cast( + recv_counts[src_rank * COUNT_PAD + exp_off]); + } + } + } + + int cumsum[NUM_EXPERT_SLOTS + 1]; + cumsum[0] = 0; + for (int s = 0; s < NUM_EXPERT_SLOTS; s++) { + cumsum[s + 1] = cumsum[s] + slot_counts[s]; + } + + for (int exp_off = 0; exp_off < EXPERTS_PER_RANK; exp_off++) { + int total = 0; + for (int src_rank = 0; src_rank < NUM_RANKS; src_rank++) { + int slot = exp_off * NUM_RANKS + src_rank; + total += slot_counts[slot]; + } + expert_token_nums[exp_off] = static_cast(total); + } + pipe_barrier(PIPE_ALL); + + for (int s = 0; s < NUM_EXPERT_SLOTS; s++) { + int count = slot_counts[s]; + int out_offset = cumsum[s]; + for (int t = 0; t < count; t++) { + __gm__ float* src_ptr = shmem_data + (s * NUM_TOKENS + t) * HIDDEN_DIM; + __gm__ float* dst_ptr = expand_x + (out_offset + t) * HIDDEN_DIM; + for (int j = 0; j < HIDDEN_DIM; j++) { + dst_ptr[j] = src_ptr[j]; + } + } + } + + pipe_barrier(PIPE_ALL); +} diff --git a/examples/a2a3/tensormap_and_ringbuffer/moe_dispatch/kernels/aiv/kernel_moe_send_data.cpp b/examples/a2a3/tensormap_and_ringbuffer/moe_dispatch/kernels/aiv/kernel_moe_send_data.cpp new file mode 100644 index 00000000..a4d5384a --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/moe_dispatch/kernels/aiv/kernel_moe_send_data.cpp @@ -0,0 +1,170 @@ +/** + * MOE Send Kernel — TPUT_ASYNC data + counts + TNOTIFY to all peers (func_id=1) + * + * All SDMA operations share the same channel (channelGroupIdx = block_idx). + * The SDMA engine processes SQEs strictly in order within a channel, so each + * TPUT_ASYNC appends [data SQEs][flag SQE] to the tail. When the last flag + * is set, all previous transfers are guaranteed complete. + * + * Therefore we only register ONE CQ entry — for the very last TPUT_ASYNC. + * The AICPU scheduler polls that single flag; once it flips, the entire + * batch (14 data + 7 count transfers) is done. + * + * Steps: + * 1. 14 × TPUT_ASYNC — per-(peer, expert) token data → peer shmem_data slots + * 2. 7 × TPUT_ASYNC — per-peer count block → peer recv_counts + * 3. Register CQ entry for the LAST TPUT_ASYNC only (1 entry total) + * 4. 7 × TNOTIFY — AtomicAdd(1) → peer notify_counter + * + * Kernel args layout: + * args[0] = &Tensor(send_staging) — input [STAGING_ELEMS] float + * args[1] = shmem_data_addr — scalar + * args[2] = send_counts_addr — scalar (GM int32*, [NUM_RANKS * COUNT_PAD]) + * args[3] = recv_counts_addr — scalar (local addr → CommRemotePtr) + * args[4] = notify_counter_addr — scalar (local addr → CommRemotePtr) + * args[5] = CommDeviceContext* — scalar + * args[6] = sdma_context — scalar + * args[7] = cq_addr — scalar (auto-appended by deferred submit) + */ + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +#include +#include "pto/comm/pto_comm_inst.hpp" +#include "pto/npu/comm/async/sdma/sdma_types.hpp" +#include "pto/common/pto_tile.hpp" + +#include "common/comm_context.h" +#include "tensor.h" + +using namespace pto; + +#include "pto_sq_kernel_api.h" +#include "pto_notify_kernel_api.h" + +static constexpr int NUM_TOKENS = 16; +static constexpr int HIDDEN_DIM = 128; +static constexpr int NUM_RANKS = 8; +static constexpr int EXPERTS_PER_RANK = 2; +static constexpr int COUNT_PAD = 32; +static constexpr int SLOT_ELEMS = NUM_TOKENS * HIDDEN_DIM; + +template +AICORE inline __gm__ T* CommRemotePtr( + __gm__ CommDeviceContext* ctx, __gm__ T* local_ptr, int peer_rank) { + uint64_t local_base = ctx->windowsIn[ctx->rankId]; + uint64_t offset = reinterpret_cast(local_ptr) - local_base; + return reinterpret_cast<__gm__ T*>(ctx->windowsIn[peer_rank] + offset); +} + +extern "C" __aicore__ __attribute__((always_inline)) +void kernel_entry(__gm__ int64_t* args) { + __gm__ Tensor* send_stg_t = reinterpret_cast<__gm__ Tensor*>(args[0]); + + __gm__ float* shmem_data = + reinterpret_cast<__gm__ float*>(static_cast(args[1])); + __gm__ int32_t* send_counts = + reinterpret_cast<__gm__ int32_t*>(static_cast(args[2])); + __gm__ int32_t* local_recv_counts = + reinterpret_cast<__gm__ int32_t*>(static_cast(args[3])); + __gm__ int32_t* local_notify_counter = + reinterpret_cast<__gm__ int32_t*>(static_cast(args[4])); + __gm__ CommDeviceContext* comm_ctx = + reinterpret_cast<__gm__ CommDeviceContext*>(static_cast(args[5])); + uint64_t sdma_context = static_cast(args[6]); + uint64_t cq_addr = static_cast(args[7]); + + __gm__ float* send_staging = + reinterpret_cast<__gm__ float*>(send_stg_t->buffer.addr) + send_stg_t->start_offset; + + int my_rank = static_cast(comm_ctx->rankId); + + volatile __gm__ PTO2CompletionQueue* cq = pto2_cq_get(cq_addr); + pto2_cq_reset(cq); + + using ScratchTile = pto::Tile; + ScratchTile scratchTile; + TASSIGN(scratchTile, 0x20000); + + __gm__ uint8_t* context = + reinterpret_cast<__gm__ uint8_t*>(static_cast(sdma_context)); + + uint64_t last_tag = 0; + + // --- Step 1: 14 × TPUT_ASYNC token data to peer shmem_data slots --- + { + using SlotShape = Shape<1, 1, 1, 1, SLOT_ELEMS>; + using SlotStride = Stride; + using SlotGlobal = GlobalTensor; + + for (int peer = 0; peer < NUM_RANKS; peer++) { + if (peer == my_rank) continue; + for (int exp_off = 0; exp_off < EXPERTS_PER_RANK; exp_off++) { + int staging_idx = peer * EXPERTS_PER_RANK + exp_off; + __gm__ float* local_src = send_staging + staging_idx * SLOT_ELEMS; + + int slot = exp_off * NUM_RANKS + my_rank; + __gm__ float* peer_dst = CommRemotePtr( + comm_ctx, shmem_data + slot * SLOT_ELEMS, peer); + + SlotGlobal dstGlobal(peer_dst); + SlotGlobal srcGlobal(local_src); + + auto desc = pto2_sdma_descriptor( + dstGlobal, srcGlobal, scratchTile, context); + last_tag = pto2_send_request_entry( + PTO2_ENGINE_SDMA, PTO2_SQ_ID_AUTO, desc); + pipe_barrier(PIPE_ALL); + } + } + } + + // --- Step 2: 7 × TPUT_ASYNC per-peer count blocks --- + { + using CountShape = Shape<1, 1, 1, 1, COUNT_PAD>; + using CountStride = Stride; + using CountGlobal = GlobalTensor; + + for (int peer = 0; peer < NUM_RANKS; peer++) { + if (peer == my_rank) continue; + + __gm__ int32_t* peer_recv = CommRemotePtr( + comm_ctx, local_recv_counts + my_rank * COUNT_PAD, peer); + + CountGlobal dstGlobal(peer_recv); + CountGlobal srcGlobal(send_counts + peer * COUNT_PAD); + + auto desc = pto2_sdma_descriptor( + dstGlobal, srcGlobal, scratchTile, context); + last_tag = pto2_send_request_entry( + PTO2_ENGINE_SDMA, PTO2_SQ_ID_AUTO, desc); + pipe_barrier(PIPE_ALL); + } + } + + // --- Step 3: Register only the LAST flag (1 CQ entry) --- + // SDMA channel ordering: when this flag is set, all 21 preceding + // transfers (14 data + 7 count) are guaranteed complete. + if (last_tag != 0) { + pto2_save_expected_completion(PTO2_ENGINE_SDMA, cq, last_tag); + } + + // --- Step 4: 7 × TNOTIFY to peer notify_counter --- + for (int peer = 0; peer < NUM_RANKS; peer++) { + if (peer == my_rank) continue; + __gm__ int32_t* peer_counter = CommRemotePtr( + comm_ctx, local_notify_counter, peer); + pto2_send_notification(peer_counter, 1, PTO2NotifyOp::AtomicAdd); + } + + pto2_cq_flush(cq); +} diff --git a/examples/a2a3/tensormap_and_ringbuffer/moe_dispatch/kernels/aiv/kernel_notify_wait.cpp b/examples/a2a3/tensormap_and_ringbuffer/moe_dispatch/kernels/aiv/kernel_notify_wait.cpp new file mode 100644 index 00000000..12f3f283 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/moe_dispatch/kernels/aiv/kernel_notify_wait.cpp @@ -0,0 +1,47 @@ +/** + * NotifyWait Kernel — register notification counter as CQ condition (func_id=3) + * + * Trivial deferred-completion kernel: registers a COUNTER wait condition + * for the notification counter, then returns immediately. The scheduler + * polls the counter via the CQ mechanism and completes this task once + * *notify_counter >= expected_value. + * + * Kernel args layout: + * args[0] = &Tensor(dummy_notify) — output (dependency token for downstream) + * args[1] = notify_counter_addr — scalar (GM int32* to poll) + * args[2] = expected_value — scalar (threshold) + * args[3] = cq_addr — scalar (auto-appended by deferred submit) + */ + +#include + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +#include +#include "tensor.h" +#include "pto_cq_kernel_api.h" +#include "pto_notify_kernel_api.h" + +extern "C" __aicore__ __attribute__((always_inline)) +void kernel_entry(__gm__ int64_t* args) { + // args[0] = dummy_notify tensor (output, unused by kernel) + uint64_t notify_counter_addr = static_cast(args[1]); + uint32_t expected_value = static_cast(args[2]); + uint64_t cq_addr = static_cast(args[3]); + + volatile __gm__ PTO2CompletionQueue* cq = pto2_cq_get(cq_addr); + pto2_cq_reset(cq); + pto2_save_expected_notification_counter( + cq, + reinterpret_cast(static_cast(notify_counter_addr)), + expected_value); + dcci((__gm__ int32_t*)cq, cache_line_t::ENTIRE_DATA_CACHE, dcci_dst_t::CACHELINE_OUT); + dsb(DSB_DDR); + pipe_barrier(PIPE_ALL); +} diff --git a/examples/a2a3/tensormap_and_ringbuffer/moe_dispatch/kernels/kernel_config.py b/examples/a2a3/tensormap_and_ringbuffer/moe_dispatch/kernels/kernel_config.py new file mode 100644 index 00000000..4f7fe0e9 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/moe_dispatch/kernels/kernel_config.py @@ -0,0 +1,89 @@ +""" +MOE Dispatch V2 — 8-rank multi-expert example following pypto V2 pattern. + +Phases: + 0. Prepare — route tokens by expertId, pack per-rank staging buffers, + write local shmem_data slots, compute per-expert counts + 1. Send — TPUT_ASYNC data + counts to all peers, TNOTIFY each peer + 2. RecvAssemble — cumsum received counts, assemble expandX, compute expertTokenNums + +Window memory layout per rank (shared across ranks via RDMA): + shmem_data[NUM_EXPERT_SLOTS][NUM_TOKENS][HIDDEN_DIM] — token data + Slot index: expert_local_offset * NUM_RANKS + src_rank + recv_counts[NUM_RANKS][COUNT_PAD] — per-source-rank counts + +Staging layout: + send_staging[NUM_RANKS][EXPERTS_PER_RANK][NUM_TOKENS][HIDDEN_DIM] + Indexed by (target_rank, expert_offset) + +Requires PTO_PLATFORM=a2a3 (hardware with SDMA support). +""" + +import os +from pathlib import Path + +_KERNELS_ROOT = Path(__file__).parent +_platform = os.environ.get("PTO_PLATFORM", "a2a3sim") + +if _platform != "a2a3": + raise RuntimeError("moe_dispatch requires PTO_PLATFORM=a2a3") + +NUM_TOKENS = 16 +HIDDEN_DIM = 128 +NUM_RANKS = 8 +EXPERTS_PER_RANK = 2 +TOTAL_EXPERTS = NUM_RANKS * EXPERTS_PER_RANK +NUM_EXPERT_SLOTS = EXPERTS_PER_RANK * NUM_RANKS +EXPAND_X_ROWS = NUM_TOKENS * NUM_RANKS +COUNT_PAD = 32 + +ORCHESTRATION = { + "source": str(_KERNELS_ROOT / "orchestration" / "moe_dispatch_orchestration.cpp"), + "function_name": "aicpu_orchestration_entry", +} + +KERNELS = [ + {"func_id": 0, "source": str(_KERNELS_ROOT / "aiv" / "kernel_moe_prepare.cpp"), "core_type": "aiv"}, + {"func_id": 1, "source": str(_KERNELS_ROOT / "aiv" / "kernel_moe_send_data.cpp"), "core_type": "aiv"}, + {"func_id": 2, "source": str(_KERNELS_ROOT / "aiv" / "kernel_moe_recv_assemble.cpp"), "core_type": "aiv"}, + {"func_id": 3, "source": str(_KERNELS_ROOT / "aiv" / "kernel_notify_wait.cpp"), "core_type": "aiv"}, +] + +RUNTIME_CONFIG = { + "runtime": "tensormap_and_ringbuffer", + "aicpu_thread_num": 4, + "orch_thread_num": 1, + "block_dim": 3, + "rounds": 1, +} + +RUNTIME_ENV = { + "PTO2_ENABLE_SDMA": "1", +} + +STAGING_ELEMS = NUM_RANKS * EXPERTS_PER_RANK * NUM_TOKENS * HIDDEN_DIM + +DISTRIBUTED_CONFIG = { + "nranks": NUM_RANKS, + "root": 0, + "win_sync_prefix": 256, + "buffers": [ + {"name": "tokens", "dtype": "float32", "count": NUM_TOKENS * HIDDEN_DIM, "placement": "window"}, + {"name": "expert_ids", "dtype": "int32", "count": NUM_TOKENS, "placement": "window"}, + {"name": "shmem_data", "dtype": "float32", "count": NUM_EXPERT_SLOTS * NUM_TOKENS * HIDDEN_DIM, "placement": "window"}, + {"name": "send_staging", "dtype": "float32", "count": STAGING_ELEMS, "placement": "window"}, + {"name": "local_counts", "dtype": "int32", "count": COUNT_PAD, "placement": "window"}, + {"name": "send_counts", "dtype": "int32", "count": NUM_RANKS * COUNT_PAD, "placement": "window"}, + {"name": "recv_counts", "dtype": "int32", "count": NUM_RANKS * COUNT_PAD, "placement": "window"}, + {"name": "notify_counter", "dtype": "int32", "count": 1, "placement": "window"}, + {"name": "expand_x", "dtype": "float32", "count": EXPAND_X_ROWS * HIDDEN_DIM, "placement": "device"}, + {"name": "expert_token_nums", "dtype": "int32", "count": EXPERTS_PER_RANK, "placement": "device"}, + ], + "inputs": ["tokens", "expert_ids", "notify_counter"], + "outputs": ["expert_token_nums", "local_counts"], + "args": [ + "tokens", "expert_ids", "shmem_data", "send_staging", + "local_counts", "send_counts", "recv_counts", "notify_counter", + "expand_x", "expert_token_nums", "deviceCtx", + ], +} diff --git a/examples/a2a3/tensormap_and_ringbuffer/moe_dispatch/kernels/orchestration/moe_dispatch_orchestration.cpp b/examples/a2a3/tensormap_and_ringbuffer/moe_dispatch/kernels/orchestration/moe_dispatch_orchestration.cpp new file mode 100644 index 00000000..0da16ac2 --- /dev/null +++ b/examples/a2a3/tensormap_and_ringbuffer/moe_dispatch/kernels/orchestration/moe_dispatch_orchestration.cpp @@ -0,0 +1,154 @@ +/** + * MOE Dispatch V2 Orchestration — 8-rank, 4-phase task DAG + * + * Task DAG per rank: + * + * Phase 0: Prepare (func_id=0, RTC) + * IN: tokens, expert_ids + * OUT: send_staging, local_counts + * Side: writes shmem_data[local slots], send_counts + * | + * +-- send_staging --> Phase 1: Send (func_id=1, deferred CQ) + * | 14 × TPUT_ASYNC data → peer shmem_data + * | 7 × TPUT_ASYNC counts → peer recv_counts + * | 7 × TNOTIFY → peer notify_counter + * | + * +-- local_counts --+ + * | + * Phase 1.5: NotifyWait (func_id=3, deferred CQ) + * OUT: dummy_notify (dependency token) + * Waits for notify_counter >= NUM_RANKS-1 via CQ poll + * | + * Phase 2: RecvAssemble (func_id=2, RTC) + * IN: local_counts, dummy_notify + * OUT: expand_x, expert_token_nums + * Reads shmem_data + recv_counts after NotifyWait completes + * + * args layout (from DISTRIBUTED_CONFIG): + * [0] = tokens (window, float*) + * [1] = expert_ids (window, int32*) + * [2] = shmem_data (window, float*) + * [3] = send_staging (window, float*) + * [4] = local_counts (window, int32*) + * [5] = send_counts (window, int32*) + * [6] = recv_counts (window, int32*) + * [7] = notify_counter (window, int32*) + * [8] = expand_x (device, float*) + * [9] = expert_token_nums (device, int32*) + * [10] = CommDeviceContext* + */ + +#include + +#include "common/comm_context.h" +#include "pto_orchestration_api.h" + +static constexpr int NUM_TOKENS = 16; +static constexpr int HIDDEN_DIM = 128; +static constexpr int NUM_RANKS = 8; +static constexpr int EXPERTS_PER_RANK = 2; +static constexpr int EXPAND_X_ROWS = NUM_TOKENS * NUM_RANKS; +static constexpr int COUNT_PAD = 32; +static constexpr int SLOT_ELEMS = NUM_TOKENS * HIDDEN_DIM; +static constexpr int STAGING_ELEMS = NUM_RANKS * EXPERTS_PER_RANK * SLOT_ELEMS; + +extern "C" { + +__attribute__((visibility("default"))) +PTO2OrchestrationConfig aicpu_orchestration_config(uint64_t* args, int arg_count) { + (void)args; + (void)arg_count; + return PTO2OrchestrationConfig{ + .expected_arg_count = 11, + }; +} + +__attribute__((visibility("default"))) +void aicpu_orchestration_entry(uint64_t* args, int arg_count, + int orch_thread_num, int orch_thread_index) { + (void)orch_thread_num; + if (orch_thread_index != 0) return; + + if (arg_count != 11) { + LOG_ERROR("moe_dispatch_v2: expected 11 args, got %d", arg_count); + return; + } + + void* tokens_ptr = (void*)(uintptr_t)args[0]; + void* expert_ids_ptr = (void*)(uintptr_t)args[1]; + uint64_t shmem_data_addr = args[2]; + void* send_staging_ptr = (void*)(uintptr_t)args[3]; + void* local_counts_ptr = (void*)(uintptr_t)args[4]; + uint64_t send_counts_addr = args[5]; + uint64_t recv_counts_addr = args[6]; + uint64_t notify_counter_addr = args[7]; + void* expand_x_ptr = (void*)(uintptr_t)args[8]; + void* etn_ptr = (void*)(uintptr_t)args[9]; + auto* comm_ctx = reinterpret_cast((uintptr_t)args[10]); + + int my_rank = (int)comm_ctx->rankId; + + uint32_t tokens_shape[1] = { (uint32_t)(NUM_TOKENS * HIDDEN_DIM) }; + uint32_t expert_ids_shape[1] = { (uint32_t)NUM_TOKENS }; + uint32_t send_stg_shape[1] = { (uint32_t)STAGING_ELEMS }; + uint32_t count_shape[1] = { (uint32_t)COUNT_PAD }; + uint32_t expand_x_shape[1] = { (uint32_t)(EXPAND_X_ROWS * HIDDEN_DIM) }; + uint32_t etn_shape[1] = { (uint32_t)EXPERTS_PER_RANK }; + + Tensor ext_tokens = make_tensor_external(tokens_ptr, tokens_shape, 1, DataType::FLOAT32); + Tensor ext_expert_ids = make_tensor_external(expert_ids_ptr, expert_ids_shape, 1, DataType::INT32); + Tensor ext_send_stg = make_tensor_external(send_staging_ptr, send_stg_shape, 1, DataType::FLOAT32); + Tensor ext_local_counts = make_tensor_external(local_counts_ptr, count_shape, 1, DataType::INT32); + Tensor ext_expand_x = make_tensor_external(expand_x_ptr, expand_x_shape, 1, DataType::FLOAT32); + Tensor ext_etn = make_tensor_external(etn_ptr, etn_shape, 1, DataType::INT32); + + uint64_t sdma_context = pto2_rt_get_sdma_context(); + uint64_t cq_send = pto2_rt_alloc_cq(); + if (sdma_context == 0 || cq_send == 0) { + LOG_ERROR("moe_dispatch_v2: rank %d failed SDMA context or CQ alloc", my_rank); + return; + } + + // Phase 0: Prepare + PTOParam params_prepare; + params_prepare.add_input(ext_tokens); + params_prepare.add_input(ext_expert_ids); + params_prepare.add_output(ext_send_stg); + params_prepare.add_output(ext_local_counts); + params_prepare.add_scalar(shmem_data_addr); + params_prepare.add_scalar(send_counts_addr); + params_prepare.add_scalar((uint64_t)(uintptr_t)comm_ctx); + pto2_rt_submit_aiv_task(0, params_prepare); + + // Phase 1: Send — data + counts + notify (single deferred CQ) + PTOParam params_send; + params_send.add_input(ext_send_stg); + params_send.add_scalar(shmem_data_addr); + params_send.add_scalar(send_counts_addr); + params_send.add_scalar(recv_counts_addr); + params_send.add_scalar(notify_counter_addr); + params_send.add_scalar((uint64_t)(uintptr_t)comm_ctx); + params_send.add_scalar(sdma_context); + pto2_rt_submit_aiv_task_deferred(1, params_send, cq_send); + + // Phase 1.5: NotifyWait — deferred wait for notification counter >= NUM_RANKS-1. + // Returns a dependency token for RecvAssemble via TensorMap. + Tensor notify_token = pto2_rt_submit_notification_wait_task( + 3, notify_counter_addr, NUM_RANKS - 1); + + // Phase 2: RecvAssemble (depends on NotifyWait via notify_token) + PTOParam params_recv; + params_recv.add_input(notify_token); + params_recv.add_input(ext_local_counts); + params_recv.add_output(ext_expand_x); + params_recv.add_output(ext_etn); + params_recv.add_scalar(shmem_data_addr); + params_recv.add_scalar(recv_counts_addr); + params_recv.add_scalar((uint64_t)(uintptr_t)comm_ctx); + pto2_rt_submit_aiv_task(2, params_recv); + + LOG_INFO("moe_dispatch_v2: rank %d submitted 4-phase DAG (8-rank, expect %d notifs)", + my_rank, NUM_RANKS - 1); +} + +} // extern "C" diff --git a/examples/scripts/README.md b/examples/scripts/README.md index 0afcb07c..6e634214 100644 --- a/examples/scripts/README.md +++ b/examples/scripts/README.md @@ -42,6 +42,32 @@ python examples/scripts/run_example.py \ -p a2a3sim ``` +#### Running Distributed (Multi-Rank) Tests + +Distributed examples are auto-detected when `kernel_config.py` contains a `DISTRIBUTED_CONFIG` dictionary. No separate script is needed — `run_example.py` handles it automatically: + +```bash +# Simulation (no hardware required, 8 ranks by default from kernel_config) +python examples/scripts/run_example.py \ + -k path/to/distributed_test/kernels \ + -g path/to/distributed_test/golden.py \ + -p a2a3sim + +# Hardware platform — pick specific devices (nranks inferred from device count) +python examples/scripts/run_example.py \ + -k path/to/distributed_test/kernels \ + -g path/to/distributed_test/golden.py \ + -p a2a3 --devices 0,1,2,3,4,5,6,7 + +# Hardware platform — non-contiguous devices +python examples/scripts/run_example.py \ + -k path/to/distributed_test/kernels \ + -g path/to/distributed_test/golden.py \ + -p a2a3 --devices 2,4,5,7 +``` + +The framework spawns one worker process per rank, each using the backend-neutral `comm_*` API. On simulation (`a2a3sim`), ranks communicate via POSIX shared memory; on hardware (`a2a3`), they use HCCL over RDMA. + ## Command Line Arguments ### `run_example.py` Parameters @@ -56,6 +82,7 @@ python examples/scripts/run_example.py \ | `--verbose` | `-v` | Enable verbose output (equivalent to `--log-level debug`) | False | | `--silent` | | Enable silent mode (equivalent to `--log-level error`) | False | | `--log-level` | | Set log level: `error`, `warn`, `info`, `debug` | `info` | +| `--nranks` | | Number of ranks for distributed tests | From `DISTRIBUTED_CONFIG` | | `--clone-protocol` | | Git protocol for cloning pto-isa: `ssh` or `https` | `ssh` | ### Platform Description @@ -161,7 +188,54 @@ ORCHESTRATION = { } ``` -### 3. `golden.py` Format +### 3. Distributed `kernel_config.py` Format + +To make a test distributed, add a `DISTRIBUTED_CONFIG` dictionary alongside the standard `KERNELS` and `ORCHESTRATION` fields: + +```python +DISTRIBUTED_CONFIG = { + "nranks": 8, # Number of ranks + "root": 0, # Root rank for collective ops + "comm_include_dirs": [...], # Extra include dirs for kernel compilation + "win_sync_prefix": 256, # Bytes reserved before window buffers + "buffers": [ + {"name": "input", "dtype": "float32", "count": 256, "placement": "window"}, + {"name": "output", "dtype": "float32", "count": 256, "placement": "device"}, + ], + "inputs": ["input"], # Buffers to load from .bin files + "outputs": ["output"], # Buffers to save after execution + "args": ["input", "output", "nranks", "root", "deviceCtx"], +} +``` + +- **`placement: "window"`** — Buffer is allocated in the RDMA window region (accessible by all ranks). +- **`placement: "device"`** — Buffer is allocated via `device_malloc` (local to each rank). +- **`args`** — Tokens passed as orchestration function arguments. Special tokens: `nranks`, `root`, `deviceCtx` (pointer to `CommDeviceContext`). + +### 4. Distributed `golden.py` Format + +The golden script for distributed tests uses `generate_distributed_inputs` instead of `generate_inputs`: + +```python +def generate_distributed_inputs(rank: int, nranks: int, root: int, + comm_ctx=None) -> list: + """Return a list of (name, data) tuples for this rank.""" + input_data = [float(i + rank * 100) for i in range(256)] + output_data = [0.0] * 256 + return [ + ("input", input_data), + ("output", output_data), + ] + +def compute_golden(tensors: dict, params: dict) -> None: + """Compute expected output for the root rank (in-place).""" + nranks = params.get("nranks", 8) + output = tensors["output"] + for i in range(256): + output[i] = float(nranks * i + 100 * nranks * (nranks - 1) // 2) +``` + +### 5. Standard `golden.py` Format ```python import torch @@ -365,6 +439,25 @@ TEST PASSED ============================================================ ``` +### Distributed Test Success Example + +``` +[INFO] Detected DISTRIBUTED_CONFIG — using distributed runner +[INFO] === Phase 1: Building runtime === +... +[INFO] === Launching 8 workers === +[INFO] Rank 0: OK +[INFO] Rank 1: OK +... +[INFO] Rank 7: OK +[INFO] VERIFY PASSED: output — 256 elements correct +[INFO] Sample: [2800.0, 2808.0, 2816.0, 2824.0, 2832.0] + +============================================================ +TEST PASSED +============================================================ +``` + ### Failure Example ``` @@ -378,8 +471,9 @@ TEST FAILED: Output 'f' does not match golden ## Reference Examples -- **Hardware Example**: [examples/host_build_graph/vector_example/](../host_build_graph/vector_example/) -- **Simulation Example**: [examples/host_build_graph/vector_example/](../host_build_graph/vector_example/) +- **Single-Card Example**: [examples/a2a3/host_build_graph/vector_example/](../a2a3/host_build_graph/vector_example/) +- **Async Completion Demo** (2-card, deferred RDMA read): [examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/](../a2a3/tensormap_and_ringbuffer/async_completion_demo/) +- **Async Notify Demo** (2-card, TNOTIFY launch gating): [examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/](../a2a3/tensormap_and_ringbuffer/async_notify_demo/) ## FAQ @@ -521,6 +615,21 @@ runner = create_code_runner( runner.run() # Execute test ``` +### Distributed Programmatic Usage + +```python +from distributed_code_runner import DistributedCodeRunner + +runner = DistributedCodeRunner( + kernels_dir="path/to/distributed_test/kernels", + golden_path="path/to/distributed_test/golden.py", + platform="a2a3sim", + nranks=8, +) + +runner.run_all() # compile, prepare data, launch workers, verify +``` + ## Related Documentation - [Main Project README](../../README.md) diff --git a/examples/scripts/distributed_code_runner.py b/examples/scripts/distributed_code_runner.py new file mode 100644 index 00000000..a29ddd0a --- /dev/null +++ b/examples/scripts/distributed_code_runner.py @@ -0,0 +1,465 @@ +""" +DistributedCodeRunner — compile, prepare data, launch workers, and verify +results for distributed (multi-card) PTO kernel tests. + +Parallel to CodeRunner, but handles DISTRIBUTED_CONFIG and spawns N +Python worker processes (one per rank) via distributed_worker.py. + +Usage: + runner = DistributedCodeRunner( + kernels_dir="path/to/distributed_test/kernels", + golden_path="path/to/distributed_test/golden.py", + platform="a2a3", nranks=8, + ) + runner.run() +""" + +import importlib.util +import logging +import os +import shutil +import struct +import subprocess +import sys +from pathlib import Path +from typing import Optional + +logger = logging.getLogger(__name__) + +SIMPLER_ROOT = Path(__file__).resolve().parent.parent.parent +SCRIPTS_DIR = Path(__file__).resolve().parent + +DTYPE_FORMAT = { + "float32": ("f", 4), + "float64": ("d", 8), + "int32": ("i", 4), + "int64": ("q", 8), + "uint32": ("I", 4), + "uint64": ("Q", 8), + "float16": ("e", 2), + "int16": ("h", 2), + "uint16": ("H", 2), + "int8": ("b", 1), + "uint8": ("B", 1), +} + + +def _load_module(path, name="mod"): + spec = importlib.util.spec_from_file_location(name, path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +class DistributedCodeRunner: + + def __init__( + self, + kernels_dir: str, + golden_path: Optional[str] = None, + platform: str = "a2a3", + nranks: Optional[int] = None, + device_ids: Optional[list[int]] = None, + root: Optional[int] = None, + build_dir: Optional[str] = None, + artifact_dir: Optional[str] = None, + orch_func: Optional[str] = None, + pto_isa_commit: Optional[str] = None, + clone_protocol: str = "ssh", + ): + self.kernels_dir = Path(kernels_dir).resolve() + self.platform = platform + self.build_dir = Path(build_dir).resolve() if build_dir else \ + SIMPLER_ROOT / "build" / "distributed" / "cache" + self.artifact_dir = Path(artifact_dir).resolve() if artifact_dir else \ + SIMPLER_ROOT / "build" / "distributed" / "artifacts" + self.pto_isa_commit = pto_isa_commit + self.clone_protocol = clone_protocol + + self._load_kernel_config() + dist = getattr(self.kcfg, "DISTRIBUTED_CONFIG", {}) + + self.nranks = nranks if nranks is not None else dist.get("nranks", 8) + self.root = root if root is not None else dist.get("root", 0) + self.orch_func = orch_func or self.kcfg.ORCHESTRATION["function_name"] + if self.nranks <= 0: + raise ValueError(f"Distributed nranks must be positive, got {self.nranks}") + if self.root < 0 or self.root >= self.nranks: + raise ValueError( + f"Distributed root must be in [0, {self.nranks}), got {self.root}" + ) + + if device_ids is None: + self.device_ids = list(range(self.nranks)) + else: + if len(device_ids) != self.nranks: + raise ValueError( + f"Expected {self.nranks} device ids, got {len(device_ids)}: {device_ids}" + ) + self.device_ids = list(device_ids) + + self.golden_path = Path(golden_path).resolve() if golden_path else None + self.golden_mod = None + + def _load_kernel_config(self): + config_path = self.kernels_dir / "kernel_config.py" + if not config_path.exists(): + raise FileNotFoundError(f"kernel_config.py not found in {self.kernels_dir}") + self.kcfg = _load_module(config_path, "kernel_config") + + def _load_golden(self): + if self.golden_mod is None and self.golden_path and self.golden_path.exists(): + self.golden_mod = _load_module(self.golden_path, "golden") + return self.golden_mod + + def _orch_artifact_name(self): + src = Path(self.kcfg.ORCHESTRATION["source"]) + return src.stem + ".so" + + def _kernel_artifact_name(self, kernel_cfg): + src = Path(kernel_cfg["source"]) + return src.stem + ".bin" + + def _get_buffer_config(self, name: str): + dist = getattr(self.kcfg, "DISTRIBUTED_CONFIG", {}) + for buf_cfg in dist.get("buffers", []): + if buf_cfg["name"] == name: + return buf_cfg + raise ValueError( + f"Buffer '{name}' from golden.py not found in DISTRIBUTED_CONFIG['buffers']" + ) + + def _get_dtype_format(self, dtype: str, buffer_name: str): + fmt = DTYPE_FORMAT.get(dtype) + if fmt is None: + raise ValueError( + f"Unsupported dtype '{dtype}' for buffer '{buffer_name}'" + ) + return fmt + + # ------------------------------------------------------------------ + # compile() + # ------------------------------------------------------------------ + + def compile(self): + self.artifact_dir.mkdir(parents=True, exist_ok=True) + for sub in ("aicore", "aicpu", "host"): + p = self.build_dir / sub + if p.exists(): + shutil.rmtree(p) + self.build_dir.mkdir(parents=True, exist_ok=True) + + python_dir = SIMPLER_ROOT / "python" + sys.path.insert(0, str(python_dir)) + sys.path.insert(0, str(SCRIPTS_DIR)) + + from runtime_builder import RuntimeBuilder + from elf_parser import extract_text_section + from code_runner import _ensure_pto_isa_root + + pto_isa_root = _ensure_pto_isa_root( + verbose=True, commit=self.pto_isa_commit, + clone_protocol=self.clone_protocol) + if pto_isa_root is None: + raise EnvironmentError("PTO_ISA_ROOT could not be resolved.") + + runtime_name = self.kcfg.RUNTIME_CONFIG.get("runtime", "host_build_graph") + builder = RuntimeBuilder(platform=self.platform) + kernel_compiler = builder.get_kernel_compiler() + + logger.info("=== Phase 1: Building runtime ===") + host_binary, aicpu_binary, aicore_binary = builder.build( + runtime_name, str(self.build_dir)) + + logger.info("=== Phase 2: Compiling orchestration ===") + orch_source = self.kcfg.ORCHESTRATION["source"] + if not os.path.isabs(orch_source): + orch_source = str(self.kernels_dir / orch_source) + orch_binary = kernel_compiler.compile_orchestration( + runtime_name, orch_source, build_dir=str(self.build_dir)) + + logger.info("=== Phase 3: Compiling kernels ===") + if self.platform in ("a2a3", "a2a3sim"): + arch = "a2a3" + elif self.platform in ("a5", "a5sim"): + arch = "a5" + else: + arch = "a2a3" + + runtime_include_dirs = [ + str(SIMPLER_ROOT / "src" / arch / "runtime" / runtime_name / "runtime") + ] + + dist_config = getattr(self.kcfg, "DISTRIBUTED_CONFIG", {}) + extra_includes = list(runtime_include_dirs) + [ + str(SIMPLER_ROOT / "src" / arch / "platform" / "include"), + ] + for d in dist_config.get("comm_include_dirs", []): + p = Path(pto_isa_root) / d if not os.path.isabs(d) else Path(d) + extra_includes.append(str(p)) + + kernel_bins = {} + for k in self.kcfg.KERNELS: + src = k["source"] + if not os.path.isabs(src): + src = str(self.kernels_dir / src) + incore_o = kernel_compiler.compile_incore( + src, + core_type=k.get("core_type", "aiv"), + pto_isa_root=pto_isa_root, + extra_include_dirs=extra_includes, + build_dir=str(self.build_dir), + ) + if self.platform.endswith("sim"): + kernel_bins[k["func_id"]] = (k, incore_o) + else: + kernel_bins[k["func_id"]] = (k, extract_text_section(incore_o)) + + logger.info("=== Phase 4: Saving artifacts ===") + + def save(name, data): + path = self.artifact_dir / name + path.write_bytes(data) + logger.info(f" {name}: {len(data)} bytes") + + save("libhost_runtime.so", host_binary) + save("libaicpu_kernel.so", aicpu_binary) + save("aicore_kernel.o", aicore_binary) + save(self._orch_artifact_name(), orch_binary) + for func_id, (kcfg, data) in kernel_bins.items(): + save(self._kernel_artifact_name(kcfg), data) + + logger.info(f"All artifacts saved to {self.artifact_dir}") + + # ------------------------------------------------------------------ + # prepare_data() + # ------------------------------------------------------------------ + + def prepare_data(self): + golden = self._load_golden() + if not golden or not hasattr(golden, "generate_distributed_inputs"): + logger.info("No golden.py or generate_distributed_inputs — skipping data prep") + return + + for r in range(self.nranks): + rank_dir = self.artifact_dir / f"rank_{r}" + rank_dir.mkdir(parents=True, exist_ok=True) + + inputs = golden.generate_distributed_inputs(r, self.nranks, self.root) + for name, data in inputs: + if isinstance(data, (list, tuple)): + buf_cfg = self._get_buffer_config(name) + fmt_char, _ = self._get_dtype_format(buf_cfg["dtype"], name) + bin_data = struct.pack(f"<{len(data)}{fmt_char}", *data) + path = rank_dir / f"{name}.bin" + path.write_bytes(bin_data) + logger.debug(f" rank_{r}/{name}.bin: {len(bin_data)} bytes") + + logger.info(f"Prepared data for {self.nranks} ranks in {self.artifact_dir}") + + # ------------------------------------------------------------------ + # run() + # ------------------------------------------------------------------ + + def _build_worker_cmd(self, r): + dist = getattr(self.kcfg, "DISTRIBUTED_CONFIG", {}) + rootinfo_file = self.artifact_dir / "rootinfo.bin" + + cmd = [ + sys.executable, + str(SCRIPTS_DIR / "distributed_worker.py"), + "--device-id", str(self.device_ids[r]), + "--rank", str(r), + "--nranks", str(self.nranks), + "--root", str(self.root), + "--artifact-dir", str(self.artifact_dir), + "--rootinfo-file", str(rootinfo_file), + "--data-dir", str(self.artifact_dir / f"rank_{r}"), + "--orch-file", self._orch_artifact_name(), + "--orch-func", self.orch_func, + ] + + rt_cfg = getattr(self.kcfg, "RUNTIME_CONFIG", {}) + cmd += ["--aicpu-thread-num", str(rt_cfg.get("aicpu_thread_num", 1))] + cmd += ["--block-dim", str(rt_cfg.get("block_dim", 1))] + cmd += ["--orch-thread-num", str(rt_cfg.get("orch_thread_num", 0))] + + win_sync = dist.get("win_sync_prefix", 0) + if win_sync: + cmd += ["--win-sync-prefix", str(win_sync)] + + for buf in dist.get("buffers", []): + spec = f"{buf['name']}:{buf['dtype']}:{buf['count']}" + if buf["placement"] == "window": + cmd += ["--win-buffer", spec] + else: + cmd += ["--dev-buffer", spec] + + for name in dist.get("inputs", []): + cmd += ["--load", name] + + for name in dist.get("outputs", []): + cmd += ["--save", name] + + for tok in dist.get("args", []): + cmd += ["--arg", tok] + + for k in self.kcfg.KERNELS: + cmd += ["--kernel-bin", + f"{k['func_id']}:{self._kernel_artifact_name(k)}"] + + return cmd + + def run(self): + rootinfo_file = self.artifact_dir / "rootinfo.bin" + + for f in self.artifact_dir.glob("barrier_*.ready"): + f.unlink() + if rootinfo_file.exists(): + rootinfo_file.unlink() + + shm_dir = Path("/dev/shm") + if shm_dir.is_dir(): + for f in shm_dir.glob("simpler_comm_*"): + try: + f.unlink() + except OSError: + pass + + logger.info(f"=== Launching {self.nranks} workers ===") + + procs = [] + log_files = [] + for r in range(self.nranks): + log_path = self.artifact_dir / f"rank{r}.log" + log_f = open(log_path, "w") + log_files.append(log_f) + + cmd = self._build_worker_cmd(r) + env = os.environ.copy() + runtime_env = getattr(self.kcfg, "RUNTIME_ENV", None) + if isinstance(runtime_env, dict): + env.update(runtime_env) + + proc = subprocess.Popen(cmd, stdout=log_f, stderr=subprocess.STDOUT, env=env) + procs.append(proc) + + fail_count = 0 + for r, proc in enumerate(procs): + proc.wait() + log_files[r].close() + if proc.returncode != 0: + fail_count += 1 + logger.error(f"Rank {r}: FAILED (exit code {proc.returncode})") + else: + logger.info(f"Rank {r}: OK") + + print() + for r in range(self.nranks): + log_path = self.artifact_dir / f"rank{r}.log" + lines = log_path.read_text().strip().split("\n") + print(f"--- RANK {r} (last 5 lines) ---") + for line in lines[-5:]: + print(line) + + print() + if fail_count == 0: + print(f"=== ALL {self.nranks} RANKS COMPLETED ===") + else: + print(f"=== {fail_count}/{self.nranks} RANKS FAILED ===") + + for f in self.artifact_dir.glob("barrier_*.ready"): + f.unlink() + + self._run_ok = (fail_count == 0) + return self._run_ok + + # ------------------------------------------------------------------ + # verify() + # ------------------------------------------------------------------ + + def verify(self): + golden = self._load_golden() + if not golden or not hasattr(golden, "compute_golden"): + logger.info("No golden.py or compute_golden — skipping verification") + return True + + dist = getattr(self.kcfg, "DISTRIBUTED_CONFIG", {}) + output_names = dist.get("outputs", []) + buf_map = {b["name"]: b for b in dist.get("buffers", [])} + + # Compute expected outputs once for the distributed verification step. + seed_dir = self.artifact_dir / f"rank_{self.root}" + seed_outputs = {} + for name in output_names: + path = seed_dir / f"{name}.bin" + if not path.exists(): + logger.error(f"Output file not found: {path}") + return False + raw = path.read_bytes() + dtype = buf_map.get(name, {}).get("dtype", "float32") + fmt_char, elem_sz = DTYPE_FORMAT.get(dtype, ("f", 4)) + count = len(raw) // elem_sz + seed_outputs[name] = list(struct.unpack(f"<{count}{fmt_char}", raw)) + + expected_outputs = {n: v.copy() for n, v in seed_outputs.items()} + params = {"nranks": self.nranks, "root": self.root} + golden.compute_golden(expected_outputs, params) + + rtol = getattr(golden, "RTOL", 1e-5) + atol = getattr(golden, "ATOL", 1e-5) + + all_ok = True + for rank in range(self.nranks): + rank_dir = self.artifact_dir / f"rank_{rank}" + for name in output_names: + path = rank_dir / f"{name}.bin" + if not path.exists(): + logger.error(f"Output file not found: {path}") + all_ok = False + continue + raw = path.read_bytes() + dtype = buf_map.get(name, {}).get("dtype", "float32") + fmt_char, elem_sz = DTYPE_FORMAT.get(dtype, ("f", 4)) + count = len(raw) // elem_sz + actual = list(struct.unpack(f"<{count}{fmt_char}", raw)) + expected = expected_outputs[name] + + mismatches = 0 + for i, (a, e) in enumerate(zip(actual, expected)): + if abs(a - e) > atol + rtol * abs(e): + if mismatches < 3: + logger.error(f" rank {rank} {name}[{i}]: got {a}, expected {e}") + mismatches += 1 + if mismatches > 0: + logger.error(f"VERIFY FAILED: rank {rank} {name} — {mismatches}/{len(actual)} mismatches") + all_ok = False + else: + logger.info(f"VERIFY PASSED: rank {rank} {name} — {len(actual)} elements correct") + if rank == 0 and len(actual) >= 5: + logger.info(f" Sample: {actual[:5]}") + + if all_ok: + print("\n=== VERIFICATION PASSED ===\n") + else: + print("\n=== VERIFICATION FAILED ===\n") + + return all_ok + + # ------------------------------------------------------------------ + # Full pipeline + # ------------------------------------------------------------------ + + def run_all(self, skip_compile=False, skip_verify=False): + if not skip_compile: + self.compile() + + if self.golden_path: + self.prepare_data() + + success = self.run() + + if success and self.golden_path and not skip_verify: + success = self.verify() + + return success diff --git a/examples/scripts/distributed_worker.py b/examples/scripts/distributed_worker.py new file mode 100644 index 00000000..5579a89a --- /dev/null +++ b/examples/scripts/distributed_worker.py @@ -0,0 +1,268 @@ +#!/usr/bin/env python3 +""" +Per-rank Python worker for distributed (multi-card) kernel execution. + +Replaces the monolithic C++ distributed_worker binary. Each rank runs +as a separate process, using the comm_* C API (via ctypes bindings) for +HCCL / sim communication and the existing PTO runtime C API for kernel +execution. + +Spawned by DistributedCodeRunner — not intended for direct invocation. +""" + +import argparse +import struct +import sys +from pathlib import Path + +script_dir = Path(__file__).parent.resolve() +project_root = script_dir.parent.parent +sys.path.insert(0, str(project_root / "python")) +sys.path.insert(0, str(script_dir)) + + +DTYPE_FORMAT = { + "float32": ("f", 4), + "float64": ("d", 8), + "int32": ("i", 4), + "int64": ("q", 8), + "uint32": ("I", 4), + "uint64": ("Q", 8), + "float16": ("e", 2), + "int16": ("h", 2), + "uint16": ("H", 2), + "int8": ("b", 1), + "uint8": ("B", 1), +} + + +def parse_buffer_spec(spec): + parts = spec.split(":") + return {"name": parts[0], "dtype": parts[1], "count": int(parts[2])} + + +def parse_kernel_spec(spec): + p = spec.index(":") + return {"func_id": int(spec[:p]), "filename": spec[p + 1:]} + + +def main(): + parser = argparse.ArgumentParser(description="Distributed per-rank worker") + parser.add_argument("--device-id", type=int, required=True) + parser.add_argument("--rank", type=int, required=True) + parser.add_argument("--nranks", type=int, required=True) + parser.add_argument("--root", type=int, default=0) + parser.add_argument("--artifact-dir", required=True) + parser.add_argument("--rootinfo-file", required=True) + parser.add_argument("--data-dir", default=None) + parser.add_argument("--orch-file", required=True) + parser.add_argument("--orch-func", required=True) + parser.add_argument("--win-sync-prefix", type=int, default=0) + parser.add_argument("--aicpu-thread-num", type=int, default=1) + parser.add_argument("--block-dim", type=int, default=1) + parser.add_argument("--orch-thread-num", type=int, default=0) + parser.add_argument("--win-buffer", action="append", default=[]) + parser.add_argument("--dev-buffer", action="append", default=[]) + parser.add_argument("--load", action="append", default=[], dest="loads") + parser.add_argument("--save", action="append", default=[], dest="saves") + parser.add_argument("--arg", action="append", default=[], dest="args") + parser.add_argument("--kernel-bin", action="append", default=[]) + args = parser.parse_args() + + artifact_dir = Path(args.artifact_dir) + data_dir = Path(args.data_dir) if args.data_dir else artifact_dir / f"rank_{args.rank}" + + buffers = [] + for spec in args.win_buffer: + b = parse_buffer_spec(spec) + b["placement"] = "window" + buffers.append(b) + for spec in args.dev_buffer: + b = parse_buffer_spec(spec) + b["placement"] = "device" + buffers.append(b) + + kernel_bins = [parse_kernel_spec(s) for s in args.kernel_bin] + + buf_by_name = {b["name"]: b for b in buffers} + + def elem_size(dtype): + return DTYPE_FORMAT.get(dtype, ("f", 4))[1] + + def buf_bytes(b): + return b["count"] * elem_size(b["dtype"]) + + # ---------------------------------------------------------------- + # 1. Load library + # ---------------------------------------------------------------- + from bindings import ( + bind_host_binary, set_device, launch_runtime, + device_malloc, device_free, copy_to_device, copy_from_device, + comm_init, comm_alloc_windows, comm_get_local_window_base, + comm_barrier, comm_destroy, + ARG_SCALAR, ARG_INPUT_PTR, ARG_OUTPUT_PTR, ARG_INOUT_PTR, + ) + + lib_path = artifact_dir / "libhost_runtime.so" + Runtime = bind_host_binary(str(lib_path)) + sys.stderr.write(f"[rank {args.rank}] Library loaded\n") + + set_device(args.device_id) + sys.stderr.write(f"[rank {args.rank}] Device {args.device_id} set for runtime\n") + + # ---------------------------------------------------------------- + # 2. Comm init + alloc windows + # ---------------------------------------------------------------- + comm = comm_init(args.rank, args.nranks, args.device_id, args.rootinfo_file) + + total_win = args.win_sync_prefix + for b in buffers: + if b["placement"] == "window": + total_win += buf_bytes(b) + + device_ctx_ptr = comm_alloc_windows(comm, total_win) + local_base = comm_get_local_window_base(comm) + + sys.stderr.write(f"[rank {args.rank}] Comm initialized, local_base=0x{local_base:x}\n") + + # ---------------------------------------------------------------- + # 3. Allocate buffers + # ---------------------------------------------------------------- + win_offset = args.win_sync_prefix + + for b in buffers: + nbytes = buf_bytes(b) + if b["placement"] == "window": + b["dev_ptr"] = local_base + win_offset + win_offset += nbytes + else: + ptr = device_malloc(nbytes) + if not ptr: + sys.stderr.write(f"[rank {args.rank}] device_malloc failed for '{b['name']}'\n") + return 3 + b["dev_ptr"] = ptr + sys.stderr.write( + f"[rank {args.rank}] Buffer '{b['name']}': {b['placement']} " + f"{b['count']}x{b['dtype']}={nbytes}B @ 0x{b['dev_ptr']:x}\n" + ) + + # ---------------------------------------------------------------- + # 4. Load inputs + # ---------------------------------------------------------------- + for name in args.loads: + b = buf_by_name.get(name) + if not b: + sys.stderr.write(f"[rank {args.rank}] --load: buffer '{name}' not found\n") + return 1 + path = data_dir / f"{name}.bin" + host_data = path.read_bytes() + if len(host_data) != buf_bytes(b): + sys.stderr.write( + f"[rank {args.rank}] Size mismatch for '{name}': " + f"file={len(host_data)}, expected={buf_bytes(b)}\n" + ) + return 2 + import ctypes + host_buf = (ctypes.c_uint8 * len(host_data)).from_buffer_copy(host_data) + copy_to_device(b["dev_ptr"], ctypes.addressof(host_buf), len(host_data)) + + # ---------------------------------------------------------------- + # 5. Barrier before kernel execution + # ---------------------------------------------------------------- + comm_barrier(comm) + + # ---------------------------------------------------------------- + # 6. Run simpler runtime + # ---------------------------------------------------------------- + orch_binary = (artifact_dir / args.orch_file).read_bytes() + aicpu_binary = (artifact_dir / "libaicpu_kernel.so").read_bytes() + aicore_binary = (artifact_dir / "aicore_kernel.o").read_bytes() + + kernel_binaries = [] + for k in kernel_bins: + data = (artifact_dir / k["filename"]).read_bytes() + kernel_binaries.append((k["func_id"], data)) + + func_args = [] + arg_types = [] + arg_sizes = [] + for tok in args.args: + if tok == "nranks": + func_args.append(args.nranks) + elif tok == "root": + func_args.append(args.root) + elif tok == "deviceCtx": + func_args.append(device_ctx_ptr) + else: + b = buf_by_name.get(tok) + if not b: + sys.stderr.write(f"[rank {args.rank}] --arg: unknown token '{tok}'\n") + return 1 + func_args.append(b["dev_ptr"]) + # In distributed mode, all memory is pre-allocated by the worker + # (RDMA windows / device_malloc). Pass everything as scalar so + # the runtime doesn't try to re-allocate or copy. + arg_types.append(ARG_SCALAR) + arg_sizes.append(0) + + sys.stderr.write( + f"[rank {args.rank}] Launching kernel: {len(func_args)} args, " + f"{len(kernel_binaries)} kernels\n" + ) + + runtime = Runtime() + runtime.initialize( + orch_binary, + args.orch_func, + func_args, + arg_types=arg_types, + arg_sizes=arg_sizes, + kernel_binaries=kernel_binaries, + ) + + launch_runtime( + runtime, + aicpu_thread_num=args.aicpu_thread_num, + block_dim=args.block_dim, + device_id=args.device_id, + aicpu_binary=aicpu_binary, + aicore_binary=aicore_binary, + orch_thread_num=args.orch_thread_num, + ) + + runtime.finalize() + sys.stderr.write(f"[rank {args.rank}] Kernel execution complete\n") + + # ---------------------------------------------------------------- + # 7. Barrier + save outputs + # ---------------------------------------------------------------- + comm_barrier(comm) + + import ctypes + for name in args.saves: + b = buf_by_name.get(name) + if not b: + sys.stderr.write(f"[rank {args.rank}] --save: buffer '{name}' not found\n") + continue + nbytes = buf_bytes(b) + host_buf = (ctypes.c_uint8 * nbytes)() + copy_from_device(ctypes.addressof(host_buf), b["dev_ptr"], nbytes) + path = data_dir / f"{name}.bin" + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(bytes(host_buf)) + sys.stderr.write(f"[rank {args.rank}] Saved '{name}' to {path} ({nbytes}B)\n") + + # ---------------------------------------------------------------- + # 8. Cleanup + # ---------------------------------------------------------------- + for b in buffers: + if b["placement"] == "device" and b.get("dev_ptr"): + device_free(b["dev_ptr"]) + + comm_destroy(comm) + sys.stderr.write(f"[rank {args.rank}] Done\n") + return 0 + + +if __name__ == "__main__": + sys.exit(main() or 0) diff --git a/examples/scripts/run_async_tests.sh b/examples/scripts/run_async_tests.sh new file mode 100755 index 00000000..542f7850 --- /dev/null +++ b/examples/scripts/run_async_tests.sh @@ -0,0 +1,121 @@ +#!/usr/bin/env bash + +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/../.." && pwd)" + +DEVICE_LIST="6,7" +PLATFORM="a2a3" +PTO_ISA_COMMIT="" +CLONE_PROTOCOL="https" +CANN_ENV_SCRIPT="${CANN_ENV_SCRIPT:-}" + +usage() { + cat <<'EOF' +Usage: examples/scripts/run_async_tests.sh [options] + +Run the two async distributed hardware test cases: + 1. async_completion_demo + 2. async_notify_demo + +Options: + --devices Device list for distributed run (default: 6,7) + --platform Platform passed to run_example.py (default: a2a3) + --pto-isa-commit PTO-ISA commit to pin (default: latest) + --clone-protocol PTO-ISA clone protocol: ssh|https (default: https) + --cann-env CANN environment script to source + -h, --help Show this help + +Examples: + CANN_ENV_SCRIPT=/path/to/set_env.sh examples/scripts/run_async_tests.sh + examples/scripts/run_async_tests.sh --devices 6,7 --cann-env /path/to/set_env.sh +EOF +} + +while [[ $# -gt 0 ]]; do + case "$1" in + --devices) + DEVICE_LIST="$2" + shift 2 + ;; + --platform) + PLATFORM="$2" + shift 2 + ;; + --pto-isa-commit) + PTO_ISA_COMMIT="$2" + shift 2 + ;; + --clone-protocol) + CLONE_PROTOCOL="$2" + shift 2 + ;; + --cann-env) + CANN_ENV_SCRIPT="$2" + shift 2 + ;; + -h|--help) + usage + exit 0 + ;; + *) + echo "Unknown argument: $1" >&2 + usage >&2 + exit 1 + ;; + esac +done + +if [[ -z "${CANN_ENV_SCRIPT}" ]]; then + echo "CANN env script is required. Pass --cann-env or set CANN_ENV_SCRIPT." >&2 + exit 1 +fi + +if [[ ! -f "${CANN_ENV_SCRIPT}" ]]; then + echo "CANN env script not found: ${CANN_ENV_SCRIPT}" >&2 + exit 1 +fi + +source "${CANN_ENV_SCRIPT}" +export PTO_PLATFORM="${PLATFORM}" + +run_case() { + local name="$1" + local kernels="$2" + local golden="$3" + local -a cmd + + echo + echo "============================================================" + echo "Running ${name}" + echo "============================================================" + + cmd=( + python "${REPO_ROOT}/examples/scripts/run_example.py" + -k "${kernels}" + -g "${golden}" + -p "${PLATFORM}" + --devices "${DEVICE_LIST}" + --clone-protocol "${CLONE_PROTOCOL}" + ) + + if [[ -n "${PTO_ISA_COMMIT}" ]]; then + cmd+=(-c "${PTO_ISA_COMMIT}") + fi + + "${cmd[@]}" +} + +run_case \ + "async_completion_demo" \ + "${REPO_ROOT}/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/kernels" \ + "${REPO_ROOT}/examples/a2a3/tensormap_and_ringbuffer/async_completion_demo/golden.py" + +run_case \ + "async_notify_demo" \ + "${REPO_ROOT}/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/kernels" \ + "${REPO_ROOT}/examples/a2a3/tensormap_and_ringbuffer/async_notify_demo/golden.py" + +echo +echo "All async tests passed." diff --git a/examples/scripts/run_example.py b/examples/scripts/run_example.py index 7a6727be..5ed12bc1 100644 --- a/examples/scripts/run_example.py +++ b/examples/scripts/run_example.py @@ -73,6 +73,36 @@ def _wait_for_new_device_log(log_dir, pre_run_logs, timeout=15, interval=0.5): return None +def _parse_device_spec(spec): + """Expand a device spec like '4-7' or '0,1,3,5' into device ids.""" + if spec is None: + return None + + spec = spec.strip() + if not spec: + raise ValueError("Device spec must not be empty") + + device_ids = [] + for item in spec.split(","): + item = item.strip() + if not item: + continue + if "-" in item: + start_str, end_str = item.split("-", 1) + start = int(start_str) + end = int(end_str) + if end < start: + raise ValueError(f"Invalid device range '{item}': end < start") + device_ids.extend(range(start, end + 1)) + else: + device_ids.append(int(item)) + + if not device_ids: + raise ValueError("Device spec must contain at least one device") + + return device_ids + + def main(): parser = argparse.ArgumentParser( description="Run PTO runtime test with kernel config and golden script", @@ -198,10 +228,33 @@ def compute_golden(tensors: dict, params: dict) -> None: help="Skip golden computation and comparison (for benchmarking)" ) + parser.add_argument( + "--nranks", + type=int, + default=None, + help="Override number of ranks for distributed tests (default: from kernel_config)" + ) + + parser.add_argument( + "--device-range", + type=str, + default=None, + help="Explicit device range for distributed tests (e.g., 4-7)" + ) + + parser.add_argument( + "--devices", + type=str, + default=None, + help="Explicit distributed device list, supports comma lists/ranges (e.g., 0,1,3,5 or 4-7)" + ) + args = parser.parse_args() if args.all and args.case: parser.error("--all and --case are mutually exclusive") + if args.device_range and args.devices: + parser.error("--device-range and --devices are mutually exclusive") # Determine log level from arguments log_level_str = None @@ -252,6 +305,55 @@ def compute_golden(tensors: dict, params: dict) -> None: # Import and run try: + # Detect DISTRIBUTED_CONFIG to choose runner + import importlib.util as _ilu + _kc_spec = _ilu.spec_from_file_location("_kc_check", kernel_config_path) + _kc_mod = _ilu.module_from_spec(_kc_spec) + _kc_spec.loader.exec_module(_kc_mod) + is_distributed = hasattr(_kc_mod, "DISTRIBUTED_CONFIG") + + if is_distributed: + from distributed_code_runner import DistributedCodeRunner + + logger.info("Detected DISTRIBUTED_CONFIG — using distributed runner") + dist_cfg = getattr(_kc_mod, "DISTRIBUTED_CONFIG", {}) + + if args.devices is not None: + device_ids = _parse_device_spec(args.devices) + effective_nranks = len(device_ids) + elif args.device_range is not None: + device_ids = _parse_device_spec(args.device_range) + effective_nranks = len(device_ids) + else: + effective_nranks = args.nranks if args.nranks is not None else dist_cfg.get("nranks", 8) + device_ids = [args.device + i for i in range(effective_nranks)] + + if args.nranks is not None and args.nranks != effective_nranks: + raise ValueError( + f"--nranks={args.nranks} conflicts with device list " + f"({effective_nranks} devices)" + ) + + runner = DistributedCodeRunner( + kernels_dir=str(args.kernels), + golden_path=str(args.golden), + platform=args.platform, + nranks=effective_nranks, + device_ids=device_ids, + build_dir=args.savetemp, + pto_isa_commit=args.pto_isa_commit, + clone_protocol=args.clone_protocol, + ) + success = runner.run_all() + if success: + logger.info("=" * 60) + logger.info("TEST PASSED") + logger.info("=" * 60) + else: + logger.error("TEST FAILED") + return 1 + return 0 + from code_runner import create_code_runner runner = create_code_runner( diff --git a/python/bindings.py b/python/bindings.py index 6474f35d..b1a86440 100644 --- a/python/bindings.py +++ b/python/bindings.py @@ -180,6 +180,22 @@ def _setup_functions(self): self.lib.enable_runtime_profiling.argtypes = [c_void_p, c_int] self.lib.enable_runtime_profiling.restype = c_int + # --- Distributed communication API (comm_*) --- + self.lib.comm_init.argtypes = [c_int, c_int, c_int, c_char_p] + self.lib.comm_init.restype = c_void_p + + self.lib.comm_alloc_windows.argtypes = [c_void_p, c_size_t, POINTER(c_uint64)] + self.lib.comm_alloc_windows.restype = c_int + + self.lib.comm_get_local_window_base.argtypes = [c_void_p, POINTER(c_uint64)] + self.lib.comm_get_local_window_base.restype = c_int + + self.lib.comm_barrier.argtypes = [c_void_p] + self.lib.comm_barrier.restype = c_int + + self.lib.comm_destroy.argtypes = [c_void_p] + self.lib.comm_destroy.restype = c_int + # ============================================================================ # Python Wrapper Classes @@ -522,6 +538,124 @@ def launch_runtime( raise RuntimeError(f"launch_runtime failed: {rc}") +# ============================================================================ +# Distributed Communication Functions +# ============================================================================ + + +def comm_init(rank: int, nranks: int, device_id: int, rootinfo_path: str) -> int: + """ + Initialize a distributed communicator for the given rank. + + Args: + rank: This process's rank (0-based) + nranks: Total number of ranks + device_id: Physical device ID used by this process + rootinfo_path: Filesystem path for root info exchange + + Returns: + Opaque comm handle (as integer) + + Raises: + RuntimeError: If not loaded or initialization fails + """ + global _lib + if _lib is None: + raise RuntimeError("Runtime not loaded. Call bind_host_binary() first.") + + handle = _lib.comm_init(rank, nranks, device_id, rootinfo_path.encode('utf-8')) + if not handle: + raise RuntimeError(f"comm_init failed for rank {rank}") + return handle + + +def comm_alloc_windows(handle: int, win_size: int) -> int: + """ + Allocate RDMA / shared-memory windows. + + Args: + handle: Comm handle from comm_init() + win_size: Window size hint (bytes per rank) + + Returns: + Device pointer to CommDeviceContext struct + + Raises: + RuntimeError: If allocation fails + """ + global _lib + if _lib is None: + raise RuntimeError("Runtime not loaded. Call bind_host_binary() first.") + + device_ctx = c_uint64(0) + rc = _lib.comm_alloc_windows(ctypes.c_void_p(handle), win_size, ctypes.byref(device_ctx)) + if rc != 0: + raise RuntimeError(f"comm_alloc_windows failed: {rc}") + return device_ctx.value + + +def comm_get_local_window_base(handle: int) -> int: + """ + Get the base address of this rank's local window. + + Args: + handle: Comm handle from comm_init() + + Returns: + Device-pointer base address + + Raises: + RuntimeError: If query fails + """ + global _lib + if _lib is None: + raise RuntimeError("Runtime not loaded. Call bind_host_binary() first.") + + base = c_uint64(0) + rc = _lib.comm_get_local_window_base(ctypes.c_void_p(handle), ctypes.byref(base)) + if rc != 0: + raise RuntimeError(f"comm_get_local_window_base failed: {rc}") + return base.value + + +def comm_barrier(handle: int) -> None: + """ + Synchronize all ranks in the communicator. + + Args: + handle: Comm handle from comm_init() + + Raises: + RuntimeError: If barrier fails + """ + global _lib + if _lib is None: + raise RuntimeError("Runtime not loaded. Call bind_host_binary() first.") + + rc = _lib.comm_barrier(ctypes.c_void_p(handle)) + if rc != 0: + raise RuntimeError(f"comm_barrier failed: {rc}") + + +def comm_destroy(handle: int) -> None: + """ + Destroy the communicator and release all resources. + + Args: + handle: Comm handle from comm_init() + + Raises: + RuntimeError: If destruction fails + """ + global _lib + if _lib is None: + raise RuntimeError("Runtime not loaded. Call bind_host_binary() first.") + + rc = _lib.comm_destroy(ctypes.c_void_p(handle)) + if rc != 0: + raise RuntimeError(f"comm_destroy failed: {rc}") + + # ============================================================================ # Compile Strategy Functions # ============================================================================ diff --git a/src/a2a3/platform/include/common/comm_context.h b/src/a2a3/platform/include/common/comm_context.h new file mode 100644 index 00000000..d3b74c8b --- /dev/null +++ b/src/a2a3/platform/include/common/comm_context.h @@ -0,0 +1,30 @@ +/** + * CommDeviceContext — device-side distributed communication context. + * + * This struct is the ABI contract between host (comm_hccl.cpp / comm_sim.cpp) + * and device kernels. PTO communication instructions (TREDUCE, TGET, TPUT) + * access remote data through the GVA addresses in windowsIn[]/windowsOut[] + * via MTE2 DMA. + * + * On HCCL MESH topology the struct layout matches what HCCL returns directly. + * On RING topology the host builds it by extracting remote RDMA addresses + * from HcclOpResParam's remoteRes array. + * On simulation the host fills it with malloc'd pointers. + */ + +#pragma once + +#include + +static constexpr uint32_t COMM_MAX_RANK_NUM = 64; + +struct CommDeviceContext { + uint64_t workSpace; + uint64_t workSpaceSize; + + uint32_t rankId; + uint32_t rankNum; + uint64_t winSize; + uint64_t windowsIn[COMM_MAX_RANK_NUM]; + uint64_t windowsOut[COMM_MAX_RANK_NUM]; +}; diff --git a/src/a2a3/platform/include/host/comm.h b/src/a2a3/platform/include/host/comm.h new file mode 100644 index 00000000..4c2c624c --- /dev/null +++ b/src/a2a3/platform/include/host/comm.h @@ -0,0 +1,93 @@ +/** + * Backend-neutral distributed communication C API. + * + * Provides five primitives for multi-rank communication: init, allocate + * shared windows, query local window base, barrier, and destroy. + * + * Implementations: + * onboard/host/comm_hccl.cpp — HCCL backend (links CANN hccl/hccl_fwk) + * sim/host/comm_sim.cpp — malloc-based simulation + * + * All functions are compiled into libhost_runtime.so. The linker selects + * the implementation at build time (onboard vs sim), with no runtime + * dispatch or virtual functions. + */ + +#pragma once + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct CommHandle_* CommHandle; + +/** + * Initialize a communicator for the given rank. + * + * On the HCCL backend this performs ACL init, RootInfo exchange (rank 0 + * writes the file, others wait), and HcclCommInitRootInfo. + * + * @param rank This process's rank (0-based). + * @param nranks Total number of ranks. + * @param device_id Physical device ID used by this process. + * @param rootinfo_path Filesystem path used to exchange root info between + * ranks (rank 0 writes, others read). + * @return Opaque handle, or NULL on failure. + */ +CommHandle comm_init(int rank, int nranks, int device_id, const char* rootinfo_path); + +/** + * Allocate RDMA / shared-memory windows and populate the device context. + * + * On HCCL this calls HcclAllocComResourceByTiling and extracts per-rank + * window addresses (MESH or RING topology). On sim it mallocs a shared + * region and partitions it. + * + * @param h Handle from comm_init(). + * @param win_size Window size hint (bytes per rank). The backend + * may allocate more; actual size is stored in the + * returned device context. + * @param device_ctx_out Receives a device pointer to a CommDeviceContext + * struct that can be passed to device kernels. + * @return 0 on success, non-zero on failure. + */ +int comm_alloc_windows(CommHandle h, size_t win_size, uint64_t* device_ctx_out); + +/** + * Get the base address of this rank's local window. + * + * Window buffers allocated via comm_alloc_windows() are contiguous per + * rank. This returns the start of the local rank's region. + * + * @param h Handle from comm_init(). + * @param base_out Receives the device-pointer base address. + * @return 0 on success, non-zero on failure. + */ +int comm_get_local_window_base(CommHandle h, uint64_t* base_out); + +/** + * Synchronize all ranks. + * + * Blocks until every rank in the communicator has called comm_barrier(). + * + * @param h Handle from comm_init(). + * @return 0 on success, non-zero on failure. + */ +int comm_barrier(CommHandle h); + +/** + * Destroy the communicator and release all resources. + * + * After this call the handle is invalid. + * + * @param h Handle from comm_init(). + * @return 0 on success, non-zero on failure. + */ +int comm_destroy(CommHandle h); + +#ifdef __cplusplus +} +#endif diff --git a/src/a2a3/platform/onboard/host/CMakeLists.txt b/src/a2a3/platform/onboard/host/CMakeLists.txt index d83ecae7..6fac488f 100644 --- a/src/a2a3/platform/onboard/host/CMakeLists.txt +++ b/src/a2a3/platform/onboard/host/CMakeLists.txt @@ -28,6 +28,7 @@ list(APPEND HOST_RUNTIME_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/../../src/host/host_log.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/../../src/host/unified_log_host.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/../../src/host/performance_collector.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/comm_hccl.cpp" ) if(DEFINED CUSTOM_SOURCE_DIRS) foreach(SRC_DIR ${CUSTOM_SOURCE_DIRS}) @@ -80,6 +81,25 @@ target_link_directories(host_runtime ${ASCEND_HOME_PATH}/runtime/lib64 ) +# CANN 9.x exposes the working non-V2 HCCL entry points through libhcomm. +# Link it explicitly so comm_hccl.cpp can follow the same initialization path +# as the pto-isa communication tests. +find_library(HCOMM_LIB NAMES hcomm PATHS ${ASCEND_HOME_PATH}/lib64 NO_DEFAULT_PATH) +if(HCOMM_LIB) + set(HCCL_LINK_TARGETS hcomm) + message(STATUS "Using HCCL library: hcomm") +else() + message(FATAL_ERROR "libhcomm not found under ${ASCEND_HOME_PATH}/lib64") +endif() + +# Optionally link nnopbase (provides aclCreateTensor/aclDestroyTensor for SdmaWorkspaceManager) +find_library(NNOPBASE_LIB NAMES nnopbase PATHS ${ASCEND_HOME_PATH}/lib64 NO_DEFAULT_PATH) +if(NNOPBASE_LIB) + set(NNOPBASE_LINK nnopbase) +else() + set(NNOPBASE_LINK "") +endif() + # Link against CANN runtime libraries # ascend_hal is dynamically loaded at runtime via dlopen in device_runner # when performance profiling is enabled @@ -87,6 +107,8 @@ target_link_libraries(host_runtime PRIVATE runtime ascendcl + ${HCCL_LINK_TARGETS} + ${NNOPBASE_LINK} dl ) diff --git a/src/a2a3/platform/onboard/host/comm_hccl.cpp b/src/a2a3/platform/onboard/host/comm_hccl.cpp new file mode 100644 index 00000000..5236dde8 --- /dev/null +++ b/src/a2a3/platform/onboard/host/comm_hccl.cpp @@ -0,0 +1,519 @@ +/** + * HCCL backend for the comm_* distributed communication API. + * + * Implements the five functions declared in host/comm.h using Ascend + * HCCL (bundled with CANN). Handles both MESH and RING topologies + * when extracting per-rank RDMA window addresses. + */ + +#include "host/comm.h" +#include "common/comm_context.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "acl/acl.h" +#include "hccl/hccl_comm.h" +#include "hccl/hccl_types.h" + +using CommTopo = uint32_t; + +// Internal HCCL helpers are exported by libhcomm on CANN 9.x. The public +// HCCL APIs below intentionally use the standard, non-V2 entry points to match +// the working pto-isa initialization sequence. +extern "C" HcclResult HcclAllocComResourceByTiling(HcclComm comm, void* stream, + void* mc2Tiling, void** commContext); +extern "C" HcclResult HcomGetCommHandleByGroup(const char* group, HcclComm* commHandle); +extern "C" HcclResult HcomGetL0TopoTypeEx(const char* group, CommTopo* topoType, + uint32_t isSetDevice); + +static inline HcclResult hccl_get_root_info(HcclRootInfo* ri) + { return HcclGetRootInfo(ri); } +static inline HcclResult hccl_comm_init_root_info(uint32_t n, const HcclRootInfo* ri, uint32_t r, HcclComm* c) + { return HcclCommInitRootInfo(n, ri, r, c); } +static inline HcclResult hccl_get_comm_name(HcclComm c, char* name) + { return HcclGetCommName(c, name); } +static inline HcclResult hccl_barrier(HcclComm c, aclrtStream s) + { return HcclBarrier(c, s); } +static inline HcclResult hccl_comm_destroy(HcclComm c) + { return HcclCommDestroy(c); } +static inline HcclResult hccl_alloc_com_resource(HcclComm c, void* s, void* t, void** ctx) + { return HcclAllocComResourceByTiling(c, s, t, ctx); } +static inline HcclResult hccl_get_comm_handle_by_group(const char* g, HcclComm* c) + { return HcomGetCommHandleByGroup(g, c); } +static inline HcclResult hccl_get_l0_topo_type_ex(const char* g, CommTopo* t, uint32_t f) + { return HcomGetL0TopoTypeEx(g, t, f); } + +static constexpr uint32_t COMM_IS_NOT_SET_DEVICE = 0; +static constexpr uint32_t COMM_TOPO_MESH = 0b1u; + +using rtStream_t = void*; +static constexpr int32_t RT_STREAM_PRIORITY_DEFAULT = 0; +extern "C" int32_t rtSetDevice(int32_t device); +extern "C" int32_t rtStreamCreate(rtStream_t* stream, int32_t priority); +extern "C" int32_t rtStreamDestroy(rtStream_t stream); + +// ============================================================================ +// HCCL tiling structures (required by HcclAllocComResourceByTiling) +// ============================================================================ + +namespace { + +static constexpr uint32_t MAX_CC_TILING_NUM = 8U; +static constexpr uint32_t GROUP_NAME_SIZE = 128U; +static constexpr uint32_t ALG_CONFIG_SIZE = 128U; + +struct Mc2InitTilingInner { + uint32_t version; + uint32_t mc2HcommCnt; + uint32_t offset[MAX_CC_TILING_NUM]; + uint8_t debugMode; + uint8_t preparePosition; + uint16_t queueNum; + uint16_t commBlockNum; + uint8_t devType; + char reserved[17]; +}; + +struct Mc2cCTilingInner { + uint8_t skipLocalRankCopy; + uint8_t skipBufferWindowCopy; + uint8_t stepSize; + uint8_t version; + char reserved[9]; + uint8_t commEngine; + uint8_t srcDataType; + uint8_t dstDataType; + char groupName[GROUP_NAME_SIZE]; + char algConfig[ALG_CONFIG_SIZE]; + uint32_t opType; + uint32_t reduceType; +}; + +struct Mc2CommConfigV2 { + Mc2InitTilingInner init; + Mc2cCTilingInner inner; +}; + +// HCCL compat structs for RING topology parsing +struct HcclSignalInfo { + uint64_t resId; + uint64_t addr; + uint32_t devId; + uint32_t tsId; + uint32_t rankId; + uint32_t flag; +}; + +struct HcclStreamInfo { + int32_t streamIds; + uint32_t sqIds; + uint32_t cqIds; + uint32_t logicCqids; +}; + +struct ListCommon { + uint64_t nextHost; + uint64_t preHost; + uint64_t nextDevice; + uint64_t preDevice; +}; + +static constexpr uint32_t COMPAT_LOCAL_NOTIFY_MAX_NUM = 64; +static constexpr uint32_t COMPAT_LOCAL_STREAM_MAX_NUM = 19; +static constexpr uint32_t COMPAT_AICPU_OP_NOTIFY_MAX_NUM = 2; + +struct LocalResInfoV2 { + uint32_t streamNum; + uint32_t signalNum; + HcclSignalInfo localSignals[COMPAT_LOCAL_NOTIFY_MAX_NUM]; + HcclStreamInfo streamInfo[COMPAT_LOCAL_STREAM_MAX_NUM]; + HcclStreamInfo mainStreamInfo; + HcclSignalInfo aicpuOpNotify[COMPAT_AICPU_OP_NOTIFY_MAX_NUM]; + ListCommon nextTagRes; +}; + +struct AlgoTopoInfo { + uint32_t userRank; + uint32_t userRankSize; + int32_t deviceLogicId; + bool isSingleMeshAggregation; + uint32_t deviceNumPerAggregation; + uint32_t superPodNum; + uint32_t devicePhyId; + uint32_t topoType; + uint32_t deviceType; + uint32_t serverNum; + uint32_t meshAggregationRankSize; + uint32_t multiModuleDiffDeviceNumMode; + uint32_t multiSuperPodDiffServerNumMode; + uint32_t realUserRank; + bool isDiffDeviceModule; + bool isDiffDeviceType; + uint32_t gcdDeviceNumPerAggregation; + uint32_t moduleNum; + uint32_t isUsedRdmaRankPairNum; + uint64_t isUsedRdmaRankPair; + uint32_t pairLinkCounterNum; + uint64_t pairLinkCounter; + uint32_t nicNum; + uint64_t nicList; + uint64_t complanRankLength; + uint64_t complanRank; + uint64_t bridgeRankNum; + uint64_t bridgeRank; + uint64_t serverAndsuperPodRankLength; + uint64_t serverAndsuperPodRank; +}; + +struct HcclOpConfig { + uint8_t deterministic; + uint8_t retryEnable; + uint8_t highPerfEnable; + uint8_t padding[5]; + uint8_t linkTimeOut[8]; + uint64_t notifyWaitTime; + uint32_t retryHoldTime; + uint32_t retryIntervalTime; + bool interXLinkDisable; + uint32_t floatOverflowMode; + uint32_t multiQpThreshold; +}; + +struct RemoteResPtr { + uint64_t nextHostPtr; + uint64_t nextDevicePtr; +}; + +struct HcclMC2WorkSpace { + uint64_t workspace; + uint64_t workspaceSize; +}; + +struct HcclRankRelationResV2 { + uint32_t remoteUsrRankId; + uint32_t remoteWorldRank; + uint64_t windowsIn; + uint64_t windowsOut; + uint64_t windowsExp; + ListCommon nextTagRes; +}; + +struct HcclOpResParamHead { + uint32_t localUsrRankId; + uint32_t rankSize; + uint64_t winSize; + uint64_t localWindowsIn; + uint64_t localWindowsOut; + char hcomId[128]; + uint64_t winExpSize; + uint64_t localWindowsExp; +}; + +struct HcclOpResParam { + HcclMC2WorkSpace mc2WorkSpace; + uint32_t localUsrRankId; + uint32_t rankSize; + uint64_t winSize; + uint64_t localWindowsIn; + uint64_t localWindowsOut; + char hcomId[128]; + uint64_t winExpSize; + uint64_t localWindowsExp; + uint32_t rWinStart; + uint32_t rWinOffset; + uint64_t version; + LocalResInfoV2 localRes; + AlgoTopoInfo topoInfo; + HcclOpConfig config; + uint64_t hostStateInfo; + uint64_t aicpuStateInfo; + uint64_t lockAddr; + uint32_t rsv[16]; + uint32_t notifysize; + uint32_t remoteResNum; + RemoteResPtr remoteRes[1]; +}; + +} // anonymous namespace + +// ============================================================================ +// Internal state +// ============================================================================ + +struct CommHandle_ { + int rank; + int nranks; + std::string rootinfo_path; + + rtStream_t stream = nullptr; + HcclComm hccl_comm = nullptr; + + CommDeviceContext host_ctx{}; + CommDeviceContext* device_ctx = nullptr; + bool owns_device_ctx = false; +}; + +// ============================================================================ +// Helpers +// ============================================================================ + +static bool wait_for_file(const std::string& path, int timeout_sec = 120) { + for (int i = 0; i < timeout_sec * 10; ++i) { + std::ifstream f(path, std::ios::binary); + if (f.good()) { + auto sz = f.seekg(0, std::ios::end).tellg(); + if (sz >= static_cast(HCCL_ROOT_INFO_BYTES)) return true; + } + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + return false; +} + +static void file_barrier(const std::string& dir, int rank, int nranks, const std::string& tag) { + std::string my_marker = dir + "/barrier_" + tag + "_" + std::to_string(rank) + ".ready"; + { std::ofstream(my_marker) << "1"; } + + for (int r = 0; r < nranks; ++r) { + std::string marker = dir + "/barrier_" + tag + "_" + std::to_string(r) + ".ready"; + while (true) { + std::ifstream f(marker); + if (f.good()) break; + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + } +} + +// ============================================================================ +// API implementation +// ============================================================================ + +extern "C" CommHandle comm_init(int rank, int nranks, int device_id, const char* rootinfo_path) { + auto* h = new (std::nothrow) CommHandle_{}; + if (!h) return nullptr; + + h->rank = rank; + h->nranks = nranks; + h->rootinfo_path = rootinfo_path; + + // ACL init + constexpr int kAclRepeatInit = 100002; + aclError aRet = aclInit(nullptr); + if (aRet != ACL_SUCCESS && static_cast(aRet) != kAclRepeatInit) { + fprintf(stderr, "[comm rank %d] aclInit failed: %d\n", rank, (int)aRet); + delete h; + return nullptr; + } + + if (rank == 0) { + int32_t rtRet = rtSetDevice(device_id); + if (rtRet != 0) { + fprintf(stderr, "[comm rank %d] rtSetDevice(%d) failed: %d\n", + rank, device_id, rtRet); + delete h; + return nullptr; + } + } + + // HCCL requires an ACL runtime context bound to the physical device. + // This cannot be inferred from rank because distributed runs may map + // ranks to arbitrary device lists (for example devices=[2,4,5,7]). + aRet = aclrtSetDevice(device_id); + if (aRet != ACL_SUCCESS) { + fprintf(stderr, "[comm rank %d] aclrtSetDevice(%d) failed: %d\n", + rank, device_id, (int)aRet); + delete h; + return nullptr; + } + + // RootInfo exchange + HcclRootInfo rootInfo{}; + if (rank == 0) { + HcclResult hret = hccl_get_root_info(&rootInfo); + if (hret != HCCL_SUCCESS) { + fprintf(stderr, "[comm rank 0] HcclGetRootInfo failed: %d\n", (int)hret); + delete h; + return nullptr; + } + std::ofstream fout(rootinfo_path, std::ios::binary); + fout.write(rootInfo.internal, HCCL_ROOT_INFO_BYTES); + fout.close(); + } else { + if (!wait_for_file(rootinfo_path)) { + fprintf(stderr, "[comm rank %d] Timeout waiting for rootinfo\n", rank); + delete h; + return nullptr; + } + std::ifstream fin(rootinfo_path, std::ios::binary); + fin.read(rootInfo.internal, HCCL_ROOT_INFO_BYTES); + } + + std::string barrier_dir = h->rootinfo_path; + auto last_slash = barrier_dir.rfind('/'); + if (last_slash != std::string::npos) { + barrier_dir = barrier_dir.substr(0, last_slash); + } + file_barrier(barrier_dir, h->rank, h->nranks, "rootinfo_ready"); + + // Create stream for HCCL operations + rtStreamCreate(&h->stream, RT_STREAM_PRIORITY_DEFAULT); + + // Init communicator + HcclResult hret = hccl_comm_init_root_info( + static_cast(nranks), &rootInfo, static_cast(rank), &h->hccl_comm); + if (hret != HCCL_SUCCESS) { + fprintf(stderr, "[comm rank %d] HcclCommInitRootInfo failed: %d\n", rank, (int)hret); + if (h->stream) rtStreamDestroy(h->stream); + delete h; + return nullptr; + } + + return h; +} + +extern "C" int comm_alloc_windows(CommHandle h, size_t /*win_size*/, uint64_t* device_ctx_out) { + if (!h || !device_ctx_out) return -1; + + char group[128] = {}; + HcclResult hret = hccl_get_comm_name(h->hccl_comm, group); + if (hret != HCCL_SUCCESS) return -1; + + CommTopo topoType = 0; + hret = hccl_get_l0_topo_type_ex(group, &topoType, COMM_IS_NOT_SET_DEVICE); + if (hret != HCCL_SUCCESS) return -1; + + HcclComm commHandle = nullptr; + hret = hccl_get_comm_handle_by_group(group, &commHandle); + if (hret != HCCL_SUCCESS) return -1; + + // File barrier so all ranks have completed HcclCommInitRootInfo + std::string barrier_dir = h->rootinfo_path; + auto last_slash = barrier_dir.rfind('/'); + if (last_slash != std::string::npos) { + barrier_dir = barrier_dir.substr(0, last_slash); + } + file_barrier(barrier_dir, h->rank, h->nranks, "hccl_init"); + + // Tiling configuration for HcclAllocComResourceByTiling + Mc2CommConfigV2 tiling{}; + memset(&tiling, 0, sizeof(tiling)); + tiling.init.version = 100U; + tiling.init.mc2HcommCnt = 1U; + tiling.init.commBlockNum = 48U; + tiling.init.devType = 4U; + tiling.init.offset[0] = static_cast( + reinterpret_cast(&tiling.inner) - reinterpret_cast(&tiling.init)); + tiling.inner.opType = 18U; + tiling.inner.commEngine = 3U; + tiling.inner.version = 1U; + strncpy(tiling.inner.groupName, group, GROUP_NAME_SIZE - 1); + strncpy(tiling.inner.algConfig, "BatchWrite=level0:fullmesh", ALG_CONFIG_SIZE - 1); + + void* ctxPtr = nullptr; + hret = hccl_alloc_com_resource(commHandle, h->stream, &tiling, &ctxPtr); + if (hret != HCCL_SUCCESS || ctxPtr == nullptr) return -1; + + // Extract CommDeviceContext (topology-dependent) + aclError aRet; + if (topoType == COMM_TOPO_MESH) { + h->device_ctx = reinterpret_cast(ctxPtr); + aRet = aclrtMemcpy(&h->host_ctx, sizeof(h->host_ctx), + h->device_ctx, sizeof(h->host_ctx), ACL_MEMCPY_DEVICE_TO_HOST); + if (aRet != ACL_SUCCESS) return -1; + } else { + // RING topology: parse HcclOpResParam structure on device + auto* rawCtx = reinterpret_cast(ctxPtr); + + HcclOpResParamHead head{}; + const size_t headOff = offsetof(HcclOpResParam, localUsrRankId); + aRet = aclrtMemcpy(&head, sizeof(head), rawCtx + headOff, sizeof(head), + ACL_MEMCPY_DEVICE_TO_HOST); + if (aRet != ACL_SUCCESS) return -1; + + const size_t remoteResOff = offsetof(HcclOpResParam, remoteRes); + const size_t remoteResBytes = head.rankSize * sizeof(RemoteResPtr); + std::vector remoteResArr(head.rankSize); + aRet = aclrtMemcpy(remoteResArr.data(), remoteResBytes, + rawCtx + remoteResOff, remoteResBytes, ACL_MEMCPY_DEVICE_TO_HOST); + if (aRet != ACL_SUCCESS) return -1; + + memset(&h->host_ctx, 0, sizeof(h->host_ctx)); + + uint64_t wsFields[2] = {0, 0}; + aclrtMemcpy(wsFields, sizeof(wsFields), rawCtx, sizeof(wsFields), ACL_MEMCPY_DEVICE_TO_HOST); + h->host_ctx.workSpace = wsFields[0]; + h->host_ctx.workSpaceSize = wsFields[1]; + h->host_ctx.rankId = head.localUsrRankId; + h->host_ctx.rankNum = head.rankSize; + h->host_ctx.winSize = head.winSize; + + for (uint32_t i = 0; i < head.rankSize; ++i) { + if (i == head.localUsrRankId) { + h->host_ctx.windowsIn[i] = head.localWindowsIn; + continue; + } + uint64_t devPtr = remoteResArr[i].nextDevicePtr; + if (devPtr == 0) return -1; + + HcclRankRelationResV2 remoteInfo{}; + aRet = aclrtMemcpy(&remoteInfo, sizeof(remoteInfo), + reinterpret_cast(devPtr), sizeof(remoteInfo), + ACL_MEMCPY_DEVICE_TO_HOST); + if (aRet != ACL_SUCCESS) return -1; + h->host_ctx.windowsIn[i] = remoteInfo.windowsIn; + } + + void* newDevMem = nullptr; + aRet = aclrtMalloc(&newDevMem, sizeof(CommDeviceContext), ACL_MEM_MALLOC_HUGE_FIRST); + if (aRet != ACL_SUCCESS) return -1; + + aRet = aclrtMemcpy(newDevMem, sizeof(CommDeviceContext), + &h->host_ctx, sizeof(CommDeviceContext), ACL_MEMCPY_HOST_TO_DEVICE); + if (aRet != ACL_SUCCESS) { + aclrtFree(newDevMem); + return -1; + } + h->device_ctx = reinterpret_cast(newDevMem); + h->owns_device_ctx = true; + } + + *device_ctx_out = reinterpret_cast(h->device_ctx); + return 0; +} + +extern "C" int comm_get_local_window_base(CommHandle h, uint64_t* base_out) { + if (!h || !base_out) return -1; + *base_out = h->host_ctx.windowsIn[h->rank]; + return 0; +} + +extern "C" int comm_barrier(CommHandle h) { + if (!h) return -1; + hccl_barrier(h->hccl_comm, (aclrtStream)h->stream); + aclrtSynchronizeStream((aclrtStream)h->stream); + return 0; +} + +extern "C" int comm_destroy(CommHandle h) { + if (!h) return -1; + + if (h->owns_device_ctx && h->device_ctx) { + aclrtFree(h->device_ctx); + } + if (h->stream) rtStreamDestroy(h->stream); + if (h->hccl_comm) hccl_comm_destroy(h->hccl_comm); + + // NOTE: Do NOT call aclrtResetDevice / aclFinalize here. + // Device lifecycle is owned by DeviceRunner (static singleton) whose + // destructor frees all tracked device memory before resetting the device. + // Resetting early would invalidate pointers still held by MemoryAllocator. + + delete h; + return 0; +} diff --git a/src/a2a3/platform/onboard/host/device_runner.cpp b/src/a2a3/platform/onboard/host/device_runner.cpp index 4e0431bd..50f9b577 100644 --- a/src/a2a3/platform/onboard/host/device_runner.cpp +++ b/src/a2a3/platform/onboard/host/device_runner.cpp @@ -9,6 +9,8 @@ #include +#include "acl/acl.h" + // Include HAL constants from CANN (header only, library loaded dynamically) #include "ascend_hal.h" #include "host/host_regs.h" // Register address retrieval @@ -577,10 +579,19 @@ int DeviceRunner::finalize() { // Free all remaining allocations (including handshake buffer and binGmAddr) mem_alloc_.finalize(); + int saved_device_id = device_id_; device_id_ = -1; worker_count_ = 0; aicore_kernel_binary_.clear(); + // Reset device and finalize ACL AFTER all device memory is freed. + // This was previously done in comm_destroy, but that ran before the + // static DeviceRunner destructor, causing rtFree failures (107000). + if (saved_device_id >= 0) { + aclrtResetDevice(saved_device_id); + aclFinalize(); + } + LOG_INFO("DeviceRunner finalized"); return 0; } diff --git a/src/a2a3/platform/sim/host/CMakeLists.txt b/src/a2a3/platform/sim/host/CMakeLists.txt index 1e304455..11f9dad8 100644 --- a/src/a2a3/platform/sim/host/CMakeLists.txt +++ b/src/a2a3/platform/sim/host/CMakeLists.txt @@ -33,6 +33,7 @@ list(APPEND HOST_RUNTIME_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/../../src/host/unified_log_host.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/../../src/host/performance_collector.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/../aicpu/platform_aicpu_affinity.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/comm_sim.cpp" ) if(DEFINED CUSTOM_SOURCE_DIRS) @@ -76,6 +77,7 @@ target_link_libraries(host_runtime PRIVATE pthread dl + rt ) set_target_properties(host_runtime PROPERTIES diff --git a/src/a2a3/platform/sim/host/comm_sim.cpp b/src/a2a3/platform/sim/host/comm_sim.cpp new file mode 100644 index 00000000..33c53791 --- /dev/null +++ b/src/a2a3/platform/sim/host/comm_sim.cpp @@ -0,0 +1,199 @@ +/** + * Simulation backend for the comm_* distributed communication API. + * + * Uses POSIX shared memory (shm_open + mmap) so that multiple *processes* + * (one per rank, spawned by DistributedCodeRunner) share the same RDMA + * window region. Synchronization primitives (barrier counters) live in + * the shared region itself, using GCC __atomic builtins which are safe + * on lock-free-capable types in mmap'd memory. + * + * Shared memory layout (page-aligned header + per-rank windows): + * [ SharedHeader (4096 bytes) ][ rank-0 window ][ rank-1 window ] ... + */ + +#include "host/comm.h" +#include "common/comm_context.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static constexpr size_t HEADER_SIZE = 4096; + +namespace { + +struct SharedHeader { + volatile int nranks; + volatile int alloc_done; + volatile int ready_count; + volatile int barrier_count; + volatile int barrier_phase; + volatile int destroy_count; + size_t per_rank_win_size; +}; + +std::string make_shm_name(const char* rootinfo_path) { + size_t h = std::hash{}(rootinfo_path ? rootinfo_path : "default"); + char buf[64]; + std::snprintf(buf, sizeof(buf), "/simpler_comm_%zx", h); + return buf; +} + +} // anonymous namespace + +// ============================================================================ +// Per-handle state (process-local) +// ============================================================================ + +struct CommHandle_ { + int rank; + int nranks; + std::string shm_name; + + void* mmap_base = nullptr; + size_t mmap_size = 0; + bool is_creator = false; + + CommDeviceContext host_ctx{}; +}; + +// ============================================================================ +// API implementation +// ============================================================================ + +extern "C" CommHandle comm_init(int rank, int nranks, int device_id, const char* rootinfo_path) { + auto* h = new (std::nothrow) CommHandle_{}; + if (!h) return nullptr; + (void)device_id; + + h->rank = rank; + h->nranks = nranks; + h->shm_name = make_shm_name(rootinfo_path); + return h; +} + +extern "C" int comm_alloc_windows(CommHandle h, size_t win_size, uint64_t* device_ctx_out) { + if (!h || !device_ctx_out) return -1; + + size_t total = HEADER_SIZE + win_size * static_cast(h->nranks); + + int fd = shm_open(h->shm_name.c_str(), O_CREAT | O_EXCL | O_RDWR, 0600); + if (fd >= 0) { + h->is_creator = true; + if (ftruncate(fd, static_cast(total)) != 0) { + std::perror("comm_sim: ftruncate"); + close(fd); + shm_unlink(h->shm_name.c_str()); + return -1; + } + } else if (errno == EEXIST) { + fd = shm_open(h->shm_name.c_str(), O_RDWR, 0600); + if (fd < 0) { std::perror("comm_sim: shm_open"); return -1; } + + // Wait for creator to finish ftruncate by checking file size + for (int i = 0; i < 5000; ++i) { + struct stat st; + if (fstat(fd, &st) == 0 && static_cast(st.st_size) >= total) break; + usleep(1000); + } + } else { + std::perror("comm_sim: shm_open O_EXCL"); + return -1; + } + + void* base = mmap(nullptr, total, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + close(fd); + if (base == MAP_FAILED) { std::perror("comm_sim: mmap"); return -1; } + + h->mmap_base = base; + h->mmap_size = total; + + auto* hdr = static_cast(base); + + if (h->is_creator) { + hdr->per_rank_win_size = win_size; + hdr->ready_count = 0; + hdr->barrier_count = 0; + hdr->barrier_phase = 0; + hdr->destroy_count = 0; + __atomic_store_n(&hdr->nranks, h->nranks, __ATOMIC_RELEASE); + __atomic_store_n(&hdr->alloc_done, 1, __ATOMIC_RELEASE); + } else { + while (__atomic_load_n(&hdr->alloc_done, __ATOMIC_ACQUIRE) == 0) { + usleep(100); + } + } + + auto* win_base = static_cast(base) + HEADER_SIZE; + + auto& ctx = h->host_ctx; + ctx.workSpace = 0; + ctx.workSpaceSize = 0; + ctx.rankId = static_cast(h->rank); + ctx.rankNum = static_cast(h->nranks); + ctx.winSize = win_size; + for (int i = 0; i < h->nranks; ++i) { + ctx.windowsIn[i] = reinterpret_cast( + win_base + static_cast(i) * win_size); + } + + *device_ctx_out = reinterpret_cast(&h->host_ctx); + + __atomic_add_fetch(&hdr->ready_count, 1, __ATOMIC_ACQ_REL); + while (__atomic_load_n(&hdr->ready_count, __ATOMIC_ACQUIRE) < h->nranks) { + usleep(100); + } + + return 0; +} + +extern "C" int comm_get_local_window_base(CommHandle h, uint64_t* base_out) { + if (!h || !base_out) return -1; + *base_out = h->host_ctx.windowsIn[h->rank]; + return 0; +} + +extern "C" int comm_barrier(CommHandle h) { + if (!h || !h->mmap_base) return -1; + + auto* hdr = static_cast(h->mmap_base); + int phase = __atomic_load_n(&hdr->barrier_phase, __ATOMIC_ACQUIRE); + int arrived = __atomic_add_fetch(&hdr->barrier_count, 1, __ATOMIC_ACQ_REL); + + if (arrived == h->nranks) { + __atomic_store_n(&hdr->barrier_count, 0, __ATOMIC_RELEASE); + __atomic_add_fetch(&hdr->barrier_phase, 1, __ATOMIC_ACQ_REL); + } else { + while (__atomic_load_n(&hdr->barrier_phase, __ATOMIC_ACQUIRE) == phase) { + usleep(50); + } + } + + return 0; +} + +extern "C" int comm_destroy(CommHandle h) { + if (!h) return -1; + + if (h->mmap_base) { + auto* hdr = static_cast(h->mmap_base); + int gone = __atomic_add_fetch(&hdr->destroy_count, 1, __ATOMIC_ACQ_REL); + + munmap(h->mmap_base, h->mmap_size); + h->mmap_base = nullptr; + + if (gone >= h->nranks) { + shm_unlink(h->shm_name.c_str()); + } + } + + delete h; + return 0; +} diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/aicpu/aicpu_executor.cpp b/src/a2a3/runtime/tensormap_and_ringbuffer/aicpu/aicpu_executor.cpp index 74cd2c91..dfd32d4a 100644 --- a/src/a2a3/runtime/tensormap_and_ringbuffer/aicpu/aicpu_executor.cpp +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/aicpu/aicpu_executor.cpp @@ -22,6 +22,7 @@ #include "pto_runtime2.h" #include "pto_shared_memory.h" #include "pto_runtime2_types.h" +#include "pto_async_wait.h" // Performance profiling headers #include "aicpu/performance_collector_aicpu.h" @@ -68,6 +69,7 @@ constexpr int32_t STALL_DUMP_WAIT_MAX = 4; constexpr int32_t STALL_DUMP_CORE_MAX = 8; constexpr int32_t PROGRESS_VERBOSE_THRESHOLD = 10; // log every completion for the first N tasks constexpr int32_t PROGRESS_LOG_INTERVAL = 250; // log every N completions after threshold +constexpr int32_t MAX_DEFERRED_RELEASES = 256; static PTO2Runtime *rt{nullptr}; @@ -341,9 +343,12 @@ struct AicpuExecutor { int32_t& completed_this_turn, int32_t& cur_thread_completed, bool& made_progress, + bool& fatal_error, + int32_t& fatal_error_code, PTO2TaskSlotState* deferred_release_slot_states[], int32_t& deferred_release_count, - PTO2LocalReadyBuffer* local_bufs + PTO2LocalReadyBuffer* local_bufs, + PTO2AsyncWaitList& async_wait_list #if PTO2_PROFILING , bool profiling_enabled, @@ -393,6 +398,16 @@ struct AicpuExecutor { // Two-stage completion: mark subtask done, then handle mixed-task completion bool mixed_complete = rt->scheduler.on_subtask_complete(slot_state, subslot); if (mixed_complete) { + int32_t registration_error = PTO2_ERROR_NONE; + if (async_wait_list.register_deferred(slot_state, thread_idx, + registration_error)) { + // Deferred completion is now tracked by async_wait_list. + } else { + if (registration_error != PTO2_ERROR_NONE) { + fatal_error = true; + fatal_error_code = registration_error; + return; + } #if PTO2_SCHED_PROFILING PTO2CompletionStats cstats = rt->scheduler.on_mixed_task_complete(slot_state, thread_idx, local_bufs); notify_edges_total += cstats.fanout_edges; @@ -405,7 +420,7 @@ struct AicpuExecutor { phase_complete_count++; #endif #endif - if (deferred_release_count < 256) { + if (deferred_release_count < MAX_DEFERRED_RELEASES) { deferred_release_slot_states[deferred_release_count++] = &slot_state; } else { DEV_ALWAYS("Thread %d: release", thread_idx); @@ -425,6 +440,7 @@ struct AicpuExecutor { } deferred_release_slot_states[deferred_release_count++] = &slot_state; } + } } tracker.change_core_state(bit_pos); #if PTO2_PROFILING @@ -475,7 +491,7 @@ struct AicpuExecutor { expected_reg_task_id, mixed_complete ? 1 : 0); cur_thread_completed++; - if (mixed_complete) { + if (mixed_complete && slot_state.payload != nullptr && !slot_state.payload->complete_in_future) { completed_this_turn++; } made_progress = true; @@ -1065,6 +1081,8 @@ int32_t AicpuExecutor::resolve_and_dispatch_pto2(Runtime* runtime, int32_t threa PTO2TaskSlotState* deferred_release_slot_states[256]; int32_t deferred_release_count = 0; + PTO2AsyncWaitList async_wait_list; + bool cores_released = false; #if PTO2_PROFILING @@ -1079,7 +1097,7 @@ int32_t AicpuExecutor::resolve_and_dispatch_pto2(Runtime* runtime, int32_t threa uint64_t _t0_phase = _t0; #endif int32_t task_count = 0; - if (!tracker.has_any_running_cores()) { + if (!tracker.has_any_running_cores() && async_wait_list.count == 0) { bool orch_done = orchestrator_done_; if (orch_done) { // Check for orchestrator fatal error — exit immediately @@ -1130,8 +1148,43 @@ int32_t AicpuExecutor::resolve_and_dispatch_pto2(Runtime* runtime, int32_t threa // Sched time = finish_ts - dispatch_ts; recording finish_ts here at loop start reduces // tail overhead (time from AICore done to AICPU recording finish). + // Invariant: previous iteration fully consumed local_bufs + always_assert(local_bufs[0].count == 0 && local_bufs[1].count == 0); + + // Phase 0: Poll async completion conditions (deferred-completion tasks) + int32_t async_completed_this_turn = 0; + if (async_wait_list.count > 0) { + PTO2AsyncPollResult poll_result = async_wait_list.poll_and_complete( + &rt->scheduler, local_bufs, + deferred_release_slot_states, deferred_release_count, MAX_DEFERRED_RELEASES +#if PTO2_SCHED_PROFILING + , thread_idx +#endif + ); + if (poll_result.error_code != PTO2_ERROR_NONE) { + int32_t failed_task = -1; + if (poll_result.failed_slot_state != nullptr + && poll_result.failed_slot_state->task != nullptr) { + failed_task = static_cast( + poll_result.failed_slot_state->task->mixed_task_id.local()); + } + DEV_ERROR("Thread %d: async poll failed for task %d with error code %d", + thread_idx, failed_task, poll_result.error_code); + pto2_record_scheduler_error(header, thread_idx, poll_result.error_code); + emergency_shutdown(runtime); + completed_.store(true, std::memory_order_release); + return -1; + } + async_completed_this_turn = poll_result.completed; + if (async_completed_this_turn > 0) { + made_progress = true; + } + } + // Phase 1: Check running cores for completion, process and move to idle - int32_t completed_this_turn = 0; + int32_t completed_this_turn = async_completed_this_turn; + bool fatal_error = false; + int32_t fatal_error_code = PTO2_ERROR_NONE; // Check AIC running cores bool try_completed = false; @@ -1140,8 +1193,9 @@ int32_t AicpuExecutor::resolve_and_dispatch_pto2(Runtime* runtime, int32_t threa check_running_cores_for_completion( thread_idx, hank, completed_this_turn, cur_thread_completed, made_progress, + fatal_error, fatal_error_code, deferred_release_slot_states, deferred_release_count, - local_bufs + local_bufs, async_wait_list #if PTO2_PROFILING , profiling_enabled, phase_complete_count #endif @@ -1151,6 +1205,14 @@ int32_t AicpuExecutor::resolve_and_dispatch_pto2(Runtime* runtime, int32_t threa fanin_edges_total, fanin_max_degree, sched_complete_perf_cycle #endif ); + if (fatal_error) { + DEV_ERROR("Thread %d: async registration failed with error code %d", + thread_idx, fatal_error_code); + pto2_record_scheduler_error(header, thread_idx, fatal_error_code); + emergency_shutdown(runtime); + completed_.store(true, std::memory_order_release); + return -1; + } } // Check AIV running cores @@ -1159,8 +1221,9 @@ int32_t AicpuExecutor::resolve_and_dispatch_pto2(Runtime* runtime, int32_t threa check_running_cores_for_completion( thread_idx, hank, completed_this_turn, cur_thread_completed, made_progress, + fatal_error, fatal_error_code, deferred_release_slot_states, deferred_release_count, - local_bufs + local_bufs, async_wait_list #if PTO2_PROFILING , profiling_enabled, phase_complete_count #endif @@ -1170,6 +1233,14 @@ int32_t AicpuExecutor::resolve_and_dispatch_pto2(Runtime* runtime, int32_t threa fanin_edges_total, fanin_max_degree, sched_complete_perf_cycle #endif ); + if (fatal_error) { + DEV_ERROR("Thread %d: async registration failed with error code %d", + thread_idx, fatal_error_code); + pto2_record_scheduler_error(header, thread_idx, fatal_error_code); + emergency_shutdown(runtime); + completed_.store(true, std::memory_order_release); + return -1; + } } if (completed_this_turn > 0) { #if PTO2_SCHED_PROFILING @@ -1350,6 +1421,7 @@ int32_t AicpuExecutor::resolve_and_dispatch_pto2(Runtime* runtime, int32_t threa int32_t c = completed_tasks_.load(std::memory_order_relaxed); DEV_ALWAYS("PTO2 stall: no progress for %d iterations, completed=%d total=%d (last progress at %d)", idle_iterations, c, task_count, last_progress_count); + async_wait_list.dump(thread_idx, STALL_DUMP_WAIT_MAX); // Scan all task slots to find truly stuck tasks using scheduler state PTO2SchedulerState* sched = &rt->scheduler; PTO2SharedMemoryHeader* sm_header_diag = static_cast(sm_base); @@ -1767,6 +1839,12 @@ int32_t AicpuExecutor::run(Runtime* runtime) { // Fanout fill-in in complete_perf_records is disabled (slot_states_ptr = nullptr). runtime->set_pto2_slot_states_ptr(nullptr); + // Pass async context addresses from host-side Runtime to device PTO2Runtime + for (int e = 0; e < PTO2_NUM_ASYNC_ENGINES; e++) { + rt->async_context_addrs[e] = runtime->get_async_context_addr( + static_cast(e)); + } + // Store shared state for other orchestrator threads orch_func_ = orch_func; orch_bind_runtime_ = bind_runtime_func; diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/build_config.py b/src/a2a3/runtime/tensormap_and_ringbuffer/build_config.py index cb0758f7..7935a9f1 100644 --- a/src/a2a3/runtime/tensormap_and_ringbuffer/build_config.py +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/build_config.py @@ -1,3 +1,5 @@ +import os + # Tensormap and Ringbuffer Runtime build configuration # All paths are relative to this file's directory (src/runtime/tensormap_and_ringbuffer/) # @@ -10,6 +12,32 @@ # runtime targets AND the orchestration .so (e.g., tensor methods needed # by the Tensor constructor's validation logic). +def _resolve_pto_isa_include_dir() -> str: + env_root = os.environ.get("PTO_ISA_ROOT") + if env_root: + include_dir = os.path.join(env_root, "include") + if not os.path.isdir(include_dir): + raise RuntimeError( + f"PTO_ISA_ROOT is set but include directory does not exist: {include_dir}\n" + "Please point PTO_ISA_ROOT to the pto-isa repository root." + ) + return include_dir + + fallback_root = os.path.join(os.path.dirname(__file__), "../../../../3rd/pto-isa") + fallback_include = os.path.join(fallback_root, "include") + if os.path.isdir(fallback_include): + return "../../../../3rd/pto-isa/include" + + raise RuntimeError( + "PTO_ISA_ROOT is not set and the default fallback path does not exist:\n" + f" {fallback_include}\n" + "Please export PTO_ISA_ROOT to the pto-isa repository root, for example:\n" + " export PTO_ISA_ROOT=/path/to/pto-isa" + ) + + +PTO_ISA_INCLUDE_DIR = _resolve_pto_isa_include_dir() + BUILD_CONFIG = { "aicore": { "include_dirs": ["runtime"], @@ -20,7 +48,7 @@ "source_dirs": ["aicpu", "runtime", "orchestration"] }, "host": { - "include_dirs": ["runtime"], + "include_dirs": ["runtime", PTO_ISA_INCLUDE_DIR], "source_dirs": ["host", "runtime", "orchestration"] }, "orchestration": { diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/host/runtime_maker.cpp b/src/a2a3/runtime/tensormap_and_ringbuffer/host/runtime_maker.cpp index 3b3c5e33..e901d7ff 100644 --- a/src/a2a3/runtime/tensormap_and_ringbuffer/host/runtime_maker.cpp +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/host/runtime_maker.cpp @@ -28,6 +28,13 @@ #include #include +#if __has_include("pto/npu/comm/async/sdma/sdma_workspace_manager.hpp") && __has_include("acl/acl.h") +#include "pto/npu/comm/async/sdma/sdma_workspace_manager.hpp" +#define PTO2_HAS_SDMA_WORKSPACE_MANAGER 1 +#else +#define PTO2_HAS_SDMA_WORKSPACE_MANAGER 0 +#endif + // Helper: return current time in milliseconds static long long _now_ms() { struct timeval tv; @@ -311,6 +318,25 @@ extern "C" int init_runtime_impl(Runtime *runtime, runtime->set_pto2_gm_sm_ptr(sm_ptr); runtime->record_tensor_pair(nullptr, sm_ptr, static_cast(sm_size)); + // SDMA async context initialization (controlled by PTO2_ENABLE_SDMA env var) +#if PTO2_HAS_SDMA_WORKSPACE_MANAGER + { + const char* env_sdma = std::getenv("PTO2_ENABLE_SDMA"); + if (env_sdma && env_sdma[0] == '1' && env_sdma[1] == '\0') { + LOG_INFO("SDMA async context init requested (PTO2_ENABLE_SDMA=1)"); + static pto::comm::sdma::SdmaWorkspaceManager sdma_manager; + if (sdma_manager.Init()) { + uint64_t ws_addr = reinterpret_cast(sdma_manager.GetWorkspaceAddr()); + runtime->set_async_context_addr(PTO2_ASYNC_ENGINE_SDMA, ws_addr); + LOG_INFO("SDMA async context initialized: addr=0x%lx", (unsigned long)ws_addr); + } else { + LOG_WARN("SDMA async context initialization failed, continuing without SDMA support"); + runtime->set_async_context_addr(PTO2_ASYNC_ENGINE_SDMA, 0); + } + } + } +#endif + // Set up device orchestration state runtime->set_orch_built_on_host(false); runtime->set_orch_args(device_args, func_args_count); diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/orchestration/pto_orchestration_api.h b/src/a2a3/runtime/tensormap_and_ringbuffer/orchestration/pto_orchestration_api.h index 04bafa31..05cfe2de 100644 --- a/src/a2a3/runtime/tensormap_and_ringbuffer/orchestration/pto_orchestration_api.h +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/orchestration/pto_orchestration_api.h @@ -23,7 +23,7 @@ // Type headers needed by orchestration #include "tensor.h" // Tensor -#include "pto_types.h" // PTOParam, PTOTensorEntry, PTOParamType +#include "pto_types.h" // PTOParam, PTOTensorEntry, PTOParamType, PTO2AsyncEngine #include "pto_submit_types.h" // MixedKernels, INVALID_KERNEL_ID, subtask slots // ============================================================================= @@ -104,6 +104,8 @@ void pto2_framework_bind_runtime(PTO2Runtime* rt); typedef struct PTO2RuntimeOps { void (*submit_task)(PTO2Runtime* rt, const MixedKernels& mixed_kernels, const PTOParam& params); + uint64_t (*get_async_context)(PTO2Runtime* rt, PTO2AsyncEngine engine); + uint64_t (*alloc_cq)(PTO2Runtime* rt); void (*scope_begin)(PTO2Runtime* rt); void (*scope_end)(PTO2Runtime* rt); void (*orchestration_done)(PTO2Runtime* rt); @@ -161,6 +163,129 @@ static inline void pto2_rt_submit_aiv_task(int32_t kernel_id, const PTOParam& pa mk.aiv0_kernel_id = kernel_id; rt->ops->submit_task(rt, mk, params); } +static inline uint64_t pto2_rt_get_async_context(PTO2AsyncEngine engine) { + PTO2Runtime* rt = pto2_current_runtime(); + return rt->ops->get_async_context(rt, engine); +} + +static inline uint64_t pto2_rt_get_async_context(PTO2Runtime* rt, PTO2AsyncEngine engine) { + return rt->ops->get_async_context(rt, engine); +} + +static inline uint64_t pto2_rt_get_sdma_context() { + return pto2_rt_get_async_context(PTO2_ASYNC_ENGINE_SDMA); +} + +static inline uint64_t pto2_rt_get_sdma_context(PTO2Runtime* rt) { + return pto2_rt_get_async_context(rt, PTO2_ASYNC_ENGINE_SDMA); +} + +// ============================================================================= +// CQ Model: Deferred Completion Wrappers +// ============================================================================= + +/** + * Allocate a zeroed per-task completion queue from the runtime pool. + * Returns the GM address (cast to uint64_t) or 0 on failure. + */ +static inline uint64_t pto2_rt_alloc_cq() { + PTO2Runtime* rt = pto2_current_runtime(); + return rt->ops->alloc_cq(rt); +} + +static inline uint64_t pto2_rt_alloc_cq(PTO2Runtime* rt) { + return rt->ops->alloc_cq(rt); +} + +/** + * Submit an AIV task with deferred completion (CQ model). + * + * The kernel decides at runtime how many async completions it has + * and writes them into the completion queue. The scheduler discovers + * completions by reading the CQ after all subtasks return. + * + * The CQ address is automatically appended as the last scalar + * so the kernel can access it from args[]. + */ +static inline void pto2_rt_submit_aiv_task_deferred(int32_t kernel_id, + PTOParam& params, + uint64_t cq_addr) { + PTO2Runtime* rt = pto2_current_runtime(); + params.complete_in_future = true; + params.cq_addr = cq_addr; + params.add_scalar(cq_addr); + MixedKernels mk; + mk.aiv0_kernel_id = kernel_id; + rt->ops->submit_task(rt, mk, params); +} + +static inline void pto2_rt_submit_aiv_task_deferred(PTO2Runtime* rt, + int32_t kernel_id, + PTOParam& params, + uint64_t cq_addr) { + params.complete_in_future = true; + params.cq_addr = cq_addr; + params.add_scalar(cq_addr); + MixedKernels mk; + mk.aiv0_kernel_id = kernel_id; + rt->ops->submit_task(rt, mk, params); +} + +static inline void pto2_rt_submit_aic_task_deferred(int32_t kernel_id, + PTOParam& params, + uint64_t cq_addr) { + PTO2Runtime* rt = pto2_current_runtime(); + params.complete_in_future = true; + params.cq_addr = cq_addr; + params.add_scalar(cq_addr); + MixedKernels mk; + mk.aic_kernel_id = kernel_id; + rt->ops->submit_task(rt, mk, params); +} + +static inline void pto2_rt_submit_task_deferred(const MixedKernels& mixed_kernels, + PTOParam& params, + uint64_t cq_addr) { + PTO2Runtime* rt = pto2_current_runtime(); + params.complete_in_future = true; + params.cq_addr = cq_addr; + params.add_scalar(cq_addr); + rt->ops->submit_task(rt, mixed_kernels, params); +} + +/** + * Submit a notification-wait deferred task and return a dependency token. + * + * Encapsulates the boilerplate for creating a NotifyWait task: + * 1. Allocate a CQ + * 2. Create a 1-element dummy output tensor (dependency token) + * 3. Submit a deferred AIV task with (counter_addr, expected_value, cq_addr) + * + * The returned token tensor should be added as an input to any downstream + * task that depends on the notification completing. + * + * @param kernel_id func_id of the NotifyWait kernel + * @param counter_addr GM address of the notification counter (int32*) + * @param expected_value threshold: task completes when *counter >= expected + * @return dependency token tensor (add as input to downstream tasks) + */ +static inline Tensor pto2_rt_submit_notification_wait_task( + int32_t kernel_id, + uint64_t counter_addr, + uint32_t expected_value) { + uint64_t cq_addr = pto2_rt_alloc_cq(); + + uint32_t dummy_shape[1] = { 1 }; + Tensor token = make_tensor(dummy_shape, 1, DataType::INT32); + + PTOParam params; + params.add_output(token); + params.add_scalar(counter_addr); + params.add_scalar(static_cast(expected_value)); + pto2_rt_submit_aiv_task_deferred(kernel_id, params, cq_addr); + + return token; +} static inline void pto2_rt_scope_begin() { PTO2Runtime* rt = pto2_current_runtime(); diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_async_wait.h b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_async_wait.h new file mode 100644 index 00000000..55618043 --- /dev/null +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_async_wait.h @@ -0,0 +1,359 @@ +/** + * PTO Runtime2 - Async Completion Wait List + * + * Lightweight watch-list abstraction for deferred task completion. + * + * The scheduler polls two logical protocols described in docs/runtime_async.md: + * - CQ protocol: poll *counter_addr >= expected_value (unified COUNTER type) + * - Notification protocol: poll a GM counter until it reaches expected_value + * + * All completion conditions use a single COUNTER type. Hardware event flags + * (e.g. SDMA completion flags) are the special case where expected_value = 1. + * + * The scheduler polls this list each iteration (Phase 0) and triggers + * on_mixed_task_complete for tasks whose conditions are all satisfied. + * + * Design reference: docs/runtime_async.md + */ + +#ifndef PTO_ASYNC_WAIT_H +#define PTO_ASYNC_WAIT_H + +#include +#include "pto_runtime2_types.h" +#include "pto_scheduler.h" + +extern void cache_invalidate_range(const void* addr, size_t size); + +inline constexpr int32_t PTO2_MAX_ASYNC_WAITS = 64; + +enum class PTO2CompletionPollState : uint8_t { + PENDING = 0, + READY = 1, + FAILED = 2, +}; + +struct PTO2CompletionPollResult { + PTO2CompletionPollState state{PTO2CompletionPollState::PENDING}; + int32_t error_code{PTO2_ERROR_NONE}; +}; + +struct PTO2CompletionCondition { + PTO2AsyncEngine engine{PTO2_ASYNC_ENGINE_SDMA}; + bool satisfied{false}; + volatile uint32_t* counter_addr{nullptr}; + uint32_t expected_value{0}; + + PTO2CompletionPollResult test() const { + if (satisfied) { + return {PTO2CompletionPollState::READY, PTO2_ERROR_NONE}; + } + if (counter_addr == nullptr) { + return {PTO2CompletionPollState::FAILED, PTO2_ERROR_ASYNC_COMPLETION_INVALID}; + } + return {*counter_addr >= expected_value ? PTO2CompletionPollState::READY + : PTO2CompletionPollState::PENDING, + PTO2_ERROR_NONE}; + } +}; + +template +#if PTO2_SCHED_PROFILING +static inline PTO2CompletionStats pto2_complete_task( +#else +static inline void pto2_complete_task( +#endif + PTO2SchedulerState* sched, + PTO2TaskSlotState& slot_state, + PTO2LocalReadyBuffer* local_bufs, + PTO2TaskSlotState** deferred_release_slot_states, + int32_t& deferred_release_count +#if PTO2_SCHED_PROFILING + , int thread_idx +#endif + ) { +#if PTO2_SCHED_PROFILING + PTO2CompletionStats stats = sched->on_mixed_task_complete(slot_state, thread_idx, local_bufs); +#else + sched->on_mixed_task_complete(slot_state, local_bufs); +#endif + deferred_release_slot_states[deferred_release_count++] = &slot_state; +#if PTO2_SCHED_PROFILING + return stats; +#endif +} + +// ============================================================================= +// Async Wait Entry (one per deferred task) +// ============================================================================= + +struct PTO2AsyncWaitEntry { + PTO2TaskSlotState* slot_state{nullptr}; + PTO2CompletionCondition conditions[PTO2_MAX_COMPLETIONS_PER_TASK]; + int32_t condition_count{0}; + int32_t waiting_completion_count{0}; +}; + +struct PTO2AsyncPollResult { + int32_t completed{0}; + int32_t error_code{PTO2_ERROR_NONE}; + PTO2TaskSlotState* failed_slot_state{nullptr}; +}; + +// ============================================================================= +// Name helpers (used by dump / diagnostics) +// ============================================================================= + +inline const char* pto2_async_engine_name(PTO2AsyncEngine engine) { + switch (engine) { + case PTO2_ASYNC_ENGINE_SDMA: return "SDMA"; + case PTO2_ASYNC_ENGINE_ROCE: return "ROCE"; + case PTO2_ASYNC_ENGINE_URMA: return "URMA"; + case PTO2_ASYNC_ENGINE_CCU: return "CCU"; + default: return "UNKNOWN"; + } +} + +// ============================================================================= +// Async Wait List (managed by scheduler thread) +// ============================================================================= + +struct PTO2AsyncWaitList { + PTO2AsyncWaitEntry entries[PTO2_MAX_ASYNC_WAITS]; + int32_t count{0}; + + /** + * Find or create an entry for the given slot_state. + * Returns pointer to the entry, or nullptr if full. + */ + PTO2AsyncWaitEntry* find_or_create(PTO2TaskSlotState* slot_state) { + for (int32_t i = 0; i < count; i++) { + if (entries[i].slot_state == slot_state) { + return &entries[i]; + } + } + if (count >= PTO2_MAX_ASYNC_WAITS) { + return nullptr; + } + PTO2AsyncWaitEntry& e = entries[count++]; + e.slot_state = slot_state; + e.condition_count = 0; + e.waiting_completion_count = 0; + return &e; + } + + bool add_counter(PTO2TaskSlotState* slot_state, + volatile uint32_t* counter_addr, + uint32_t expected_value, + PTO2AsyncEngine engine = PTO2_ASYNC_ENGINE_SDMA) { + PTO2AsyncWaitEntry* entry = find_or_create(slot_state); + if (!entry || counter_addr == nullptr + || entry->condition_count >= PTO2_MAX_COMPLETIONS_PER_TASK) { + return false; + } + PTO2CompletionCondition& cond = entry->conditions[entry->condition_count++]; + cond.engine = engine; + cond.satisfied = false; + cond.counter_addr = counter_addr; + cond.expected_value = expected_value; + entry->waiting_completion_count++; + return true; + } + + /** + * Poll all entries. For each satisfied condition, decrement waiting_completion_count. + * When an entry's count reaches zero, call on_mixed_task_complete and add to + * deferred_release. Remove completed entries by swap-with-last. + * + * Returns the number of tasks that completed this call. + */ + template + PTO2AsyncPollResult poll_and_complete( + PTO2SchedulerState* sched, + PTO2LocalReadyBuffer* local_bufs, + PTO2TaskSlotState** deferred_release_slot_states, + int32_t& deferred_release_count, + int32_t deferred_release_capacity +#if PTO2_SCHED_PROFILING + , int thread_idx +#endif + ) { + PTO2AsyncPollResult result; + for (int32_t i = count - 1; i >= 0; --i) { + PTO2AsyncWaitEntry& entry = entries[i]; + + for (int32_t c = 0; c < entry.condition_count; c++) { + PTO2CompletionCondition& cond = entry.conditions[c]; + if (!cond.satisfied) { + // All current counter writers (SDMA engine flags, TNOTIFY + // RDMA atomics) bypass AICPU data cache. Invalidation is + // needed so the poll reads the true GM value. For any + // hypothetical CPU-written counter this is a harmless no-op. + if (cond.counter_addr) { + cache_invalidate_range( + reinterpret_cast(const_cast(cond.counter_addr)), + sizeof(uint32_t)); + } + PTO2CompletionPollResult poll = cond.test(); + if (poll.state == PTO2CompletionPollState::FAILED) { + result.error_code = poll.error_code; + result.failed_slot_state = entry.slot_state; + return result; + } + if (poll.state == PTO2CompletionPollState::READY) { + cond.satisfied = true; + entry.waiting_completion_count--; + } + } + } + + if (entry.waiting_completion_count <= 0) { + if (deferred_release_count >= deferred_release_capacity) { + result.error_code = PTO2_ERROR_ASYNC_WAIT_OVERFLOW; + result.failed_slot_state = entry.slot_state; + return result; + } +#if PTO2_SCHED_PROFILING + auto stats = pto2_complete_task( + sched, + *entry.slot_state, + local_bufs, + deferred_release_slot_states, + deferred_release_count, + thread_idx + ); + (void)stats; +#else + pto2_complete_task( + sched, + *entry.slot_state, + local_bufs, + deferred_release_slot_states, + deferred_release_count + ); +#endif + result.completed++; + + // Swap-remove: replace with last entry + int32_t last = count - 1; + if (i != last) { + entries[i] = entries[last]; + } + count = last; + } + } + return result; + } + /** + * Register deferred completions for a task from its CQ. + * + * Reads the kernel-written PTO2CompletionQueue and registers each entry + * as a COUNTER wait condition. Returns true when at least one condition + * was registered (task is now tracked by the wait list). On error, + * error_code is set to a non-zero PTO2_ERROR_* value. + */ + bool register_deferred(PTO2TaskSlotState& slot_state, + int32_t thread_idx, int32_t& error_code) { + (void)thread_idx; + error_code = PTO2_ERROR_NONE; + PTO2TaskPayload* payload = slot_state.payload; + if (payload == nullptr || !payload->complete_in_future) return false; + + if (payload->cq_addr == 0) { +#ifdef DEV_ERROR + DEV_ERROR("Thread %d: complete_in_future=true but no CQ entries for task %d", + thread_idx, + static_cast(slot_state.task->mixed_task_id.local())); +#endif + error_code = PTO2_ERROR_ASYNC_COMPLETION_INVALID; + return false; + } + + volatile PTO2CompletionQueue* cq = reinterpret_cast( + static_cast(payload->cq_addr)); + // AICore kernel flushes its cache (dcci) before returning, but the + // AICPU may still hold a stale cache line for this CQ. Invalidate + // before reading so we see the kernel's writes. + cache_invalidate_range( + const_cast(reinterpret_cast(cq)), + sizeof(PTO2CompletionQueue)); + int32_t cq_count = cq->count; + if (cq_count <= 0) { +#ifdef DEV_ALWAYS + DEV_ALWAYS("Thread %d: task %d CQ addr=0x%lx count=0, completing immediately", + thread_idx, + static_cast(slot_state.task->mixed_task_id.local()), + payload->cq_addr); +#endif + return false; + } + if (cq_count > PTO2_CQ_MAX_ENTRIES) { +#ifdef DEV_ERROR + DEV_ERROR("Thread %d: CQ count=%d exceeds max %d for task %d", + thread_idx, cq_count, PTO2_CQ_MAX_ENTRIES, + static_cast(slot_state.task->mixed_task_id.local())); +#endif + error_code = PTO2_ERROR_ASYNC_COMPLETION_INVALID; + return false; + } +#ifdef DEV_ALWAYS + DEV_ALWAYS("Thread %d: task %d reading CQ addr=0x%lx count=%d", + thread_idx, static_cast(slot_state.task->mixed_task_id.local()), + payload->cq_addr, cq_count); +#endif + for (int32_t i = 0; i < cq_count; ++i) { + const volatile PTO2CQEntry& entry = cq->entries[i]; +#ifdef DEV_ALWAYS + DEV_ALWAYS("Thread %d: task %d CQ[%d] engine=%s(%d) addr=0x%lx expected=%u", + thread_idx, + static_cast(slot_state.task->mixed_task_id.local()), + i, + pto2_async_engine_name(static_cast(entry.engine)), + static_cast(entry.engine), + entry.addr, + entry.expected_value); +#endif + volatile uint32_t* counter_addr = reinterpret_cast( + static_cast(entry.addr)); + if (!add_counter(&slot_state, counter_addr, entry.expected_value, + static_cast(entry.engine))) { + error_code = PTO2_ERROR_ASYNC_REGISTRATION_FAILED; + return false; + } + } + return true; + } + + /** + * Dump wait list state for stall diagnostics. + */ + void dump(int32_t thread_idx, int32_t max_entries = 4) const { +#ifdef DEV_ALWAYS + DEV_ALWAYS("Thread %d: async_wait_list pending entries=%d", thread_idx, count); + int32_t dump_count = count < max_entries ? count : max_entries; + for (int32_t i = 0; i < dump_count; ++i) { + const PTO2AsyncWaitEntry& entry = entries[i]; + int32_t task_id = -1; + if (entry.slot_state != nullptr && entry.slot_state->task != nullptr) { + task_id = static_cast(entry.slot_state->task->mixed_task_id.local()); + } + DEV_ALWAYS("Thread %d: async_wait[%d] task=%d waiting=%d conditions=%d", + thread_idx, i, task_id, entry.waiting_completion_count, entry.condition_count); + for (int32_t c = 0; c < entry.condition_count; ++c) { + const PTO2CompletionCondition& cond = entry.conditions[c]; + uint32_t value = cond.counter_addr == nullptr ? 0 : *cond.counter_addr; + DEV_ALWAYS("Thread %d: cond[%d] engine=%s satisfied=%d counter_addr=0x%lx value=%u expected=%u", + thread_idx, c, pto2_async_engine_name(cond.engine), + cond.satisfied ? 1 : 0, + static_cast(reinterpret_cast(cond.counter_addr)), + value, cond.expected_value); + } + } +#else + (void)thread_idx; + (void)max_entries; +#endif + } +}; + +#endif // PTO_ASYNC_WAIT_H diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_cq_kernel_api.h b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_cq_kernel_api.h new file mode 100644 index 00000000..03d594ac --- /dev/null +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_cq_kernel_api.h @@ -0,0 +1,136 @@ +/** + * PTO CQ Kernel API — inline functions for AICore kernels. + * + * These are NOT AICPU function calls. They are structured GM writes + * that the AICPU scheduler reads after the kernel returns. + * + * All overloads follow the (ENGINE, QUEUE, data...) parameter convention, + * symmetric with pto2_send_request_entry(ENGINE, SQ_ID, desc) in the SQ API. + * + * Usage in kernel code: + * + * auto* cq = pto2_cq_get(args[CQ_ARG_IDX]); + * pto2_save_expected_completion(PTO2_ENGINE_SDMA, cq, tag); // flag: expected=1 + * pto2_cq_flush(); + */ + +#ifndef PTO_CQ_KERNEL_API_H +#define PTO_CQ_KERNEL_API_H + +#include "pto_cq_types.h" + +// Requires __gm__ and __aicore__ to be defined before including this header. +// Kernel sources should define them (or include PTO-ISA headers) first. + +// Unified engine constants — shared by SQ and CQ APIs. +// Must match PTO2AsyncEngine in pto_types.h. +#define PTO2_ENGINE_SDMA 0 +#define PTO2_ENGINE_ROCE 1 +#define PTO2_ENGINE_URMA 2 +#define PTO2_ENGINE_CCU 3 + +// Completion type constants (must match PTO2CompletionType in pto_types.h) +#define PTO2_CQ_COMPLETION_COUNTER 0 + +inline __aicore__ void pto2_cq_writeback_gm_line(volatile __gm__ void* addr) { + __gm__ int32_t* gm_addr = (__gm__ int32_t*)addr; +#if defined(SINGLE_CACHE_LINE) && defined(CACHELINE_OUT) + dcci(gm_addr, SINGLE_CACHE_LINE, CACHELINE_OUT); +#elif defined(SINGLE_CACHE_LINE) + dcci(gm_addr, SINGLE_CACHE_LINE); +#endif +#if defined(DSB_DDR) + dsb(DSB_DDR); +#endif +} + +/** + * Obtain the completion queue pointer from a kernel scalar arg. + */ +inline __aicore__ volatile __gm__ PTO2CompletionQueue* pto2_cq_get(uint64_t addr) { + return reinterpret_cast( + static_cast(addr)); +} + +/** + * Reset the CQ header before the kernel appends completion entries. + * + * Runtime-owned CQ buffers may be reused across tasks, so kernels should + * explicitly republish an empty header before the first append. + */ +inline __aicore__ void pto2_cq_reset(volatile __gm__ PTO2CompletionQueue* cq) { + // Republish the header line even when the queue was already zeroed in a + // reused runtime buffer. Some hardware paths were observed to require an + // explicit header-state transition before the subsequent count increment + // became visible to the AICPU scheduler. + cq->count = -1; + pto2_cq_writeback_gm_line(&cq->count); + cq->count = 0; + pto2_cq_writeback_gm_line(&cq->count); +} + +/** + * Register one expected completion condition in the CQ. + * + * All completion conditions are COUNTER type: the scheduler polls + * *addr >= expected_value. Hardware flags (SDMA event flags) are + * the special case where expected_value = 1 (flag goes 0 → non-zero). + * + * Parameter order: (ENGINE, QUEUE, addr, expected) — symmetric with SQ API. + * Each call appends an entry and increments cq->count. + * The caller must ensure total calls per task <= PTO2_CQ_MAX_ENTRIES. + */ +inline __aicore__ void pto2_save_expected_completion( + uint32_t engine, + volatile __gm__ PTO2CompletionQueue* cq, + uint64_t addr, + uint32_t expected_value) +{ + int32_t idx = cq->count; + volatile __gm__ PTO2CQEntry* entry = + const_cast(&cq->entries[idx]); + entry->engine = engine; + entry->completion_type = PTO2_CQ_COMPLETION_COUNTER; + entry->addr = addr; + entry->expected_value = expected_value; + pto2_cq_writeback_gm_line(entry); + + cq->count = idx + 1; + pto2_cq_writeback_gm_line(&cq->count); +} + +/** + * Simplified overload for hardware flags: (ENGINE, CQ, tag). + * + * Registers a COUNTER condition with expected_value=1. + * Equivalent to polling *tag_addr >= 1 (i.e. flag != 0). + * Symmetric with pto2_send_request_entry(ENGINE, SQ_ID, desc). + */ +inline __aicore__ void pto2_save_expected_completion( + uint32_t engine, + volatile __gm__ PTO2CompletionQueue* cq, + uint64_t tag) +{ + pto2_save_expected_completion(engine, cq, tag, 1); +} + +/** + * Final flush before kernel returns. Ensures all CQ writes + * are visible to the AICPU scheduler. + * + * Uses CCE compiler built-in enum constants (cache_line_t, dcci_dst_t, + * dsb_mode_t, pipe_t) which are available when compiling for AICore + * via the bisheng/CCE toolchain. Previous #if-defined guards broke + * because these are C++ enums, not preprocessor macros. + */ +inline __aicore__ void pto2_cq_flush() { + pipe_barrier(PIPE_ALL); +} + +inline __aicore__ void pto2_cq_flush(volatile __gm__ PTO2CompletionQueue* cq) { + dcci((__gm__ int32_t*)cq, cache_line_t::ENTIRE_DATA_CACHE, dcci_dst_t::CACHELINE_OUT); + dsb(DSB_DDR); + pipe_barrier(PIPE_ALL); +} + +#endif // PTO_CQ_KERNEL_API_H diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_cq_types.h b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_cq_types.h new file mode 100644 index 00000000..e0c571a8 --- /dev/null +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_cq_types.h @@ -0,0 +1,47 @@ +/** + * PTO Completion Queue Types — shared between AICore kernels and AICPU runtime. + * + * This header must remain simple and C-compatible. AICore compilation + * environments have restricted standard library access. + */ + +#ifndef PTO_CQ_TYPES_H +#define PTO_CQ_TYPES_H + +#include + +#define PTO2_CQ_MAX_ENTRIES 64 + +/** + * Single CQ entry written by a kernel via pto2_save_expected_completion(). + * The scheduler reads these after the worker returns. + */ +struct PTO2CQEntry { + uint32_t engine; // PTO2AsyncEngine value + int32_t completion_type; // PTO2CompletionType value + uint64_t addr; // completion token (flag/handle/counter GM address) + uint32_t expected_value; // for COUNTER completions + uint32_t _pad; +}; + +/** + * Per-task completion queue. + * + * Allocated by the runtime and passed to the kernel as a scalar arg. + * The kernel calls pto2_save_expected_completion() to append entries + * and increment `count`. The scheduler reads the CQ after all + * subtasks have returned and creates completion conditions accordingly. + * + * Memory ordering contract: + * - Kernel writes entries[i] fields BEFORE incrementing count. + * - Kernel flushes caches (dcci+dsb on HW) before returning. + * - Scheduler reads only after detecting task_status==0, + * which implies all kernel writes are visible. + */ +struct PTO2CompletionQueue { + volatile int32_t count; // entries written so far (kernel increments) + int32_t _pad; + PTO2CQEntry entries[PTO2_CQ_MAX_ENTRIES]; +}; + +#endif // PTO_CQ_TYPES_H diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_notify_kernel_api.h b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_notify_kernel_api.h new file mode 100644 index 00000000..77110559 --- /dev/null +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_notify_kernel_api.h @@ -0,0 +1,52 @@ +/** + * PTO Notify Kernel API — notification counter abstraction for AICore kernels. + * + * This wraps PTO-ISA TNOTIFY and maps the local counter wait condition onto + * the runtime's existing COUNTER deferred-completion path. + * + * Requires: + * - PTO-ISA headers included before this header + * - __gm__ and __aicore__ defined before this header + */ + +#ifndef PTO_NOTIFY_KERNEL_API_H +#define PTO_NOTIFY_KERNEL_API_H + +#include "pto_cq_kernel_api.h" + +#include + +enum class PTO2NotifyOp : uint32_t { + Set = 0, + AtomicAdd = 1, +}; + +inline __aicore__ pto::comm::NotifyOp pto2_to_notify_op(PTO2NotifyOp op) { + return op == PTO2NotifyOp::Set + ? pto::comm::NotifyOp::Set + : pto::comm::NotifyOp::AtomicAdd; +} + +inline __aicore__ void pto2_send_notification( + volatile __gm__ int32_t* remote_counter_addr, + int32_t value = 1, + PTO2NotifyOp op = PTO2NotifyOp::AtomicAdd) +{ + pto::comm::Signal signal((__gm__ int32_t*)remote_counter_addr); + pto::comm::TNOTIFY(signal, value, pto2_to_notify_op(op)); +#if defined(PIPE_ALL) + pipe_barrier(PIPE_ALL); +#endif +} + +inline __aicore__ void pto2_save_expected_notification_counter( + volatile __gm__ PTO2CompletionQueue* cq, + volatile __gm__ int32_t* local_counter_addr, + uint32_t expected_value) +{ + pto2_save_expected_completion(PTO2_ENGINE_SDMA, cq, + (uint64_t)local_counter_addr, + expected_value); +} + +#endif // PTO_NOTIFY_KERNEL_API_H diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2.cpp b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2.cpp index 19807408..567c0a4c 100644 --- a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2.cpp +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2.cpp @@ -7,6 +7,7 @@ */ #include "pto_runtime2.h" +#include "pto_async_wait.h" #include #include #include @@ -32,6 +33,21 @@ static void submit_task_impl(PTO2Runtime* rt, const MixedKernels& mixed_kernels, params); } +static uint64_t get_async_context_impl(PTO2Runtime* rt, PTO2AsyncEngine engine) { + if (engine >= PTO2_NUM_ASYNC_ENGINES) return 0; + return rt->async_context_addrs[engine]; +} + +static uint64_t alloc_cq_impl(PTO2Runtime* rt) { + if (!rt->cq_pool || rt->cq_pool_next >= rt->cq_pool_size) { + return 0; + } + int32_t idx = rt->cq_pool_next++; + PTO2CompletionQueue* cq = &rt->cq_pool[idx]; + memset(cq, 0, sizeof(PTO2CompletionQueue)); + return reinterpret_cast(cq); +} + void pto2_rt_scope_begin(PTO2Runtime* rt) { pto2_scope_begin(&rt->orchestrators[pto2_current_orch_idx]); } @@ -50,6 +66,8 @@ static bool is_fatal_impl(PTO2Runtime* rt) { static const PTO2RuntimeOps s_runtime_ops = { .submit_task = submit_task_impl, + .get_async_context = get_async_context_impl, + .alloc_cq = alloc_cq_impl, .scope_begin = pto2_rt_scope_begin, .scope_end = pto2_rt_scope_end, .orchestration_done = pto2_rt_orchestration_done, @@ -130,6 +148,12 @@ PTO2Runtime* pto2_runtime_create_custom(PTO2RuntimeMode mode, // Connect orchestrator to scheduler (for simulated mode) pto2_orchestrator_set_scheduler(&rt->orchestrators[0], &rt->scheduler); + // Allocate per-task completion queue pool + rt->cq_pool_size = PTO2_MAX_ASYNC_WAITS; + rt->cq_pool = static_cast( + calloc(rt->cq_pool_size, sizeof(PTO2CompletionQueue))); + rt->cq_pool_next = 0; + return rt; } @@ -180,6 +204,12 @@ PTO2Runtime* pto2_runtime_create_from_sm(PTO2RuntimeMode mode, pto2_orchestrator_set_scheduler(&rt->orchestrators[i], &rt->scheduler); } + // Allocate per-task completion queue pool + rt->cq_pool_size = PTO2_MAX_ASYNC_WAITS; + rt->cq_pool = static_cast( + calloc(rt->cq_pool_size, sizeof(PTO2CompletionQueue))); + rt->cq_pool_next = 0; + return rt; } @@ -199,6 +229,7 @@ void pto2_runtime_destroy(PTO2Runtime* rt) { pto2_sm_destroy(rt->sm_handle); } + free(rt->cq_pool); free(rt); } diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2.h b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2.h index bec4d2f8..b2de940e 100644 --- a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2.h +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2.h @@ -32,6 +32,7 @@ #include "pto_tensormap.h" #include "pto_scheduler.h" #include "pto_orchestrator.h" +#include "pto_cq_types.h" // Maximum number of orchestrator threads supported constexpr int PTO2_MAX_ORCH_THREADS = 4; @@ -61,6 +62,8 @@ typedef struct PTO2Runtime PTO2Runtime; // forward declare for ops signatures struct PTO2RuntimeOps { void (*submit_task)(PTO2Runtime* rt, const MixedKernels& mixed_kernels, const PTOParam& params); + uint64_t (*get_async_context)(PTO2Runtime* rt, PTO2AsyncEngine engine); + uint64_t (*alloc_cq)(PTO2Runtime* rt); void (*scope_begin)(PTO2Runtime* rt); void (*scope_end)(PTO2Runtime* rt); void (*orchestration_done)(PTO2Runtime* rt); @@ -98,6 +101,14 @@ struct PTO2Runtime { // Mode PTO2RuntimeMode mode; + // Per-engine async context addresses (0 = not available) + uint64_t async_context_addrs[PTO2_NUM_ASYNC_ENGINES]{}; + + // Per-task completion queue pool (pre-allocated, bump-allocated per deferred task) + PTO2CompletionQueue* cq_pool{nullptr}; + int32_t cq_pool_size{0}; + int32_t cq_pool_next{0}; + // Statistics int64_t total_cycles; }; diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2_types.h b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2_types.h index 2d6566b1..35af8f44 100644 --- a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2_types.h +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_runtime2_types.h @@ -69,6 +69,10 @@ // Scheduler errors (100+): detected in scheduler threads #define PTO2_ERROR_SCHEDULER_TIMEOUT 100 +#define PTO2_ERROR_ASYNC_REGISTRATION_FAILED 101 +#define PTO2_ERROR_ASYNC_WAIT_OVERFLOW 102 +#define PTO2_ERROR_ASYNC_COMPLETION_INVALID 103 +#define PTO2_ERROR_ASYNC_COMPLETION_FAILED 104 // ============================================================================= // Configuration Constants @@ -359,7 +363,8 @@ struct PTO2TaskPayload { int32_t tensor_count{0}; int32_t scalar_count{0}; int32_t fanin_actual_count{0}; // Actual fanin count (without the +1 redundance) - int32_t _reserved{0}; // Reserved (dep_pool_mark moved to SlotState for local access) + bool complete_in_future{false}; // CQ model: kernel decides completions at runtime + uint64_t cq_addr{0}; // CQ model: completion queue address for kernel to write PTO2TaskSlotState* fanin_slot_states[PTO2_MAX_INPUTS]; // Producer slot states (used by on_task_release) // === Tensors (2048B) — alignas(64) Tensor forces alignment === Tensor tensors[PTO2_MAX_TENSOR_PARAMS]; @@ -374,6 +379,8 @@ struct PTO2TaskPayload { void init(const PTOParam& params) { tensor_count = params.tensor_count; scalar_count = params.scalar_count; + complete_in_future = params.complete_in_future; + cq_addr = params.cq_addr; // 1. Copy tensors from PTOParam auto src_tensors = params.tensors; diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_scheduler.h b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_scheduler.h index 8cf5ce0e..9b6be052 100644 --- a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_scheduler.h +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_scheduler.h @@ -27,6 +27,8 @@ #include "common/core_type.h" +struct PTO2SchedulerState; + #if PTO2_SCHED_PROFILING #include "aicpu/device_time.h" #define PTO2_SCHED_CYCLE_START() uint64_t _st0 = get_sys_cnt_aicpu(), _st1 @@ -430,8 +432,6 @@ struct PTO2SchedulerState { int32_t new_refcount = slot_state.fanin_refcount.fetch_add(1, std::memory_order_acq_rel) + 1; if (new_refcount == slot_state.fanin_count) { - // Local-first: try per-CoreType thread-local buffer before global queue - // Route by active_mask: AIC-containing tasks → buf[0], AIV-only → buf[1] PTO2ResourceShape shape = pto2_active_mask_to_shape(slot_state.active_mask); if (!local_bufs || !local_bufs[static_cast(shape)].try_push(&slot_state)) { ready_queues[static_cast(shape)].push(&slot_state); diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_shared_memory.h b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_shared_memory.h index e3ada51f..c8dbb3a3 100644 --- a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_shared_memory.h +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_shared_memory.h @@ -112,6 +112,15 @@ struct alignas(PTO2_ALIGN_SIZE) PTO2SharedMemoryHeader { static_assert(sizeof(PTO2SharedMemoryHeader) % PTO2_ALIGN_SIZE == 0, "PTO2SharedMemoryHeader must be aligned to cache line (PTO2_ALIGN_SIZE)"); +static inline void pto2_record_scheduler_error( + PTO2SharedMemoryHeader* header, int32_t thread_idx, int32_t error_code) { + if (header == nullptr) return; + header->sched_error_bitmap.fetch_or(1u << thread_idx, std::memory_order_acq_rel); + header->sched_error_code.store(error_code, std::memory_order_release); + header->sched_error_thread.store(thread_idx, std::memory_order_release); + header->orch_error_code.store(error_code, std::memory_order_release); +} + // ============================================================================= // Shared Memory Handle // ============================================================================= diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_sq_kernel_api.h b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_sq_kernel_api.h new file mode 100644 index 00000000..1c7d8f25 --- /dev/null +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_sq_kernel_api.h @@ -0,0 +1,164 @@ +/** + * PTO SQ Kernel API — send queue abstraction for AICore kernels. + * + * Two usage paths, both ending with CQ registration: + * + * Path 1 — High-level (send_request_entry, one-stop): + * + * auto desc = pto2_sdma_descriptor(dst, src, scratch, context); + * uint64_t tag = pto2_send_request_entry(PTO2_ENGINE_SDMA, sq_id, desc); + * pto2_save_expected_completion(PTO2_ENGINE_SDMA, cq, tag); + * pto2_cq_flush(); + * + * Path 2 — Low-level (sq_open + direct ISA instruction): + * + * auto session = pto2_sq_open(PTO2_ENGINE_SDMA, sq_id, scratch, context); + * AsyncEvent event = TPUT_ASYNC(dst, src, session); // or TGET_ASYNC + * pto2_save_expected_completion(cq, event); + * pto2_cq_flush(); + * + * Layering: + * send_request_entry = sq_open + ISA instruction (syntactic sugar) + * sq_open = session management (BuildAsyncSession wrapper) + * + * Requires: + * - PTO-ISA headers included before this header + * - __gm__ and __aicore__ defined before this header + * - HW build only (uses PTO-ISA async instructions) + */ + +#ifndef PTO_SQ_KERNEL_API_H +#define PTO_SQ_KERNEL_API_H + +#include "pto_cq_types.h" +#include "pto_cq_kernel_api.h" + +#include +#include +#include + +// SQ engine types — aliases for the unified PTO2_ENGINE_* constants +#define PTO2_SQ_ENGINE_SDMA PTO2_ENGINE_SDMA +// #define PTO2_SQ_ENGINE_CCU PTO2_ENGINE_CCU // future +// #define PTO2_SQ_ENGINE_URMA PTO2_ENGINE_URMA // future + +#define PTO2_SQ_ID_AUTO UINT32_MAX + +// ============================================================================ +// pto2_sq_open — build async session for a hardware engine queue +// +// This is the foundation layer. Both send_request_entry (high-level) +// and direct ISA usage (low-level) go through this to obtain a session. +// ============================================================================ + +template +inline __aicore__ pto::comm::AsyncSession pto2_sq_open( + uint32_t sq_type, + uint32_t sq_id, + ScratchTile& scratch, + __gm__ uint8_t* context, + uint32_t sync_id = 0, + const pto::comm::sdma::SdmaBaseConfig& base_config = + {pto::comm::sdma::kDefaultSdmaBlockBytes, 0, 1}) +{ + pto::comm::AsyncSession session; + pto::comm::BuildAsyncSession( + scratch, context, session, sync_id, base_config, sq_id); + return session; +} + +// ============================================================================ +// pto2_save_expected_completion — AsyncEvent overload +// +// Accepts a PTO-ISA AsyncEvent directly, auto-extracting engine and handle. +// For the low-level path where the user calls ISA instructions directly. +// ============================================================================ + +inline __aicore__ void pto2_save_expected_completion( + volatile __gm__ PTO2CompletionQueue* cq, + const pto::comm::AsyncEvent& event) +{ + uint32_t engine = static_cast(event.engine); + pto2_save_expected_completion(engine, cq, event.handle); +} + +enum class PTO2SdmaRequestOp : uint32_t { + TPut = 0, + TGet = 1, +}; + +// ============================================================================ +// SDMA descriptor + factories (for high-level path) +// ============================================================================ + +template +struct PTO2SdmaDescriptor { + GlobalDstData& dst; + GlobalSrcData& src; + ScratchTile& scratch; + __gm__ uint8_t* context; + uint32_t sync_id; + pto::comm::sdma::SdmaBaseConfig base_config; + PTO2SdmaRequestOp op; +}; + +template +inline __aicore__ PTO2SdmaDescriptor +pto2_sdma_descriptor( + GlobalDstData& dst, + GlobalSrcData& src, + ScratchTile& scratch, + __gm__ uint8_t* context, + uint32_t sync_id = 0, + const pto::comm::sdma::SdmaBaseConfig& base_config = + {pto::comm::sdma::kDefaultSdmaBlockBytes, 0, 1}) +{ + return {dst, src, scratch, context, sync_id, base_config, + PTO2SdmaRequestOp::TPut}; +} + +template +inline __aicore__ PTO2SdmaDescriptor +pto2_sdma_tget_descriptor( + GlobalDstData& dst, + GlobalSrcData& src, + ScratchTile& scratch, + __gm__ uint8_t* context, + uint32_t sync_id = 0, + const pto::comm::sdma::SdmaBaseConfig& base_config = + {pto::comm::sdma::kDefaultSdmaBlockBytes, 0, 1}) +{ + return {dst, src, scratch, context, sync_id, base_config, + PTO2SdmaRequestOp::TGet}; +} + +// ============================================================================ +// pto2_send_request_entry — high-level, sugar over sq_open + async ISA op +// +// Original design: tag = pto2_send_request_entry(SQ_TYPE, SQ_ID, descriptor) +// Internally: sq_open(session params from desc) → async ISA op → tag +// ============================================================================ + +template +inline __aicore__ uint64_t pto2_send_request_entry( + uint32_t sq_type, + uint32_t sq_id, + PTO2SdmaDescriptor& desc) +{ + pto::comm::AsyncSession session = pto2_sq_open( + sq_type, sq_id, desc.scratch, desc.context, + desc.sync_id, desc.base_config); + if (!session.valid) return 0; + + pto::comm::AsyncEvent event; + if (desc.op == PTO2SdmaRequestOp::TGet) { + event = pto::comm::TGET_ASYNC( + desc.dst, desc.src, session); + } else { + event = pto::comm::TPUT_ASYNC( + desc.dst, desc.src, session); + } + return event.valid() ? event.handle : 0; +} + +#endif // PTO_SQ_KERNEL_API_H diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_types.h b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_types.h index 7f1e282e..76b66835 100644 --- a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_types.h +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_types.h @@ -22,6 +22,7 @@ #endif #include "tensor.h" +#include "pto_cq_types.h" // Task parameters #define PTO2_MAX_TENSOR_PARAMS 16 // Maximum tensor parameters per task @@ -29,6 +30,8 @@ #define PTO2_MAX_OUTPUTS 16 // Maximum outputs per task #define PTO2_MAX_INPUTS 16 // Maximum inputs per task #define PTO2_MAX_INOUTS 8 // Maximum in-out params per task +// Max completion conditions per deferred task (matches CQ capacity) +#define PTO2_MAX_COMPLETIONS_PER_TASK PTO2_CQ_MAX_ENTRIES // ============================================================================= // Parameter Types (for pto_submit_task API) @@ -43,6 +46,18 @@ enum class PTOParamType : int32_t { INOUT = 2, // Read-then-write: consumer of prior producer + modifier for downstream }; +typedef enum { + PTO2_ASYNC_ENGINE_SDMA = 0, // System DMA + PTO2_ASYNC_ENGINE_ROCE = 1, // RoCE RDMA (reserved) + PTO2_ASYNC_ENGINE_URMA = 2, // URMA cross-die memory ops (reserved) + PTO2_ASYNC_ENGINE_CCU = 3, // Cache coherence unit (reserved) + PTO2_NUM_ASYNC_ENGINES = 4 +} PTO2AsyncEngine; + +enum class PTO2CompletionType : int32_t { + COUNTER = 0, +}; + /** * Aggregated parameter container for pto_submit_task * @@ -67,12 +82,16 @@ struct PTOParam { uint64_t scalars[PTO2_MAX_SCALAR_PARAMS]; int32_t tensor_count{0}; int32_t scalar_count{0}; + bool complete_in_future{false}; + uint64_t cq_addr{0}; bool has_error{false}; const char* error_msg{nullptr}; void reset() { tensor_count = 0; scalar_count = 0; + complete_in_future = false; + cq_addr = 0; has_error = false; error_msg = nullptr; } diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/runtime.cpp b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/runtime.cpp index c39f8d81..7d9099df 100644 --- a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/runtime.cpp +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/runtime.cpp @@ -37,6 +37,7 @@ Runtime::Runtime() { pto2_gm_sm_ptr_ = nullptr; pto2_gm_heap_ptr_ = nullptr; pto2_slot_states_ptr_ = nullptr; + memset(async_context_addrs_, 0, sizeof(async_context_addrs_)); orch_args_ = nullptr; orch_arg_count_ = 0; @@ -97,6 +98,17 @@ void Runtime::set_orch_built_on_host(bool v) { orch_built_on_host_ = v; } void Runtime::set_pto2_gm_sm_ptr(void* p) { pto2_gm_sm_ptr_ = p; } void Runtime::set_pto2_gm_heap(void* p) { pto2_gm_heap_ptr_ = p; } void Runtime::set_pto2_slot_states_ptr(void* p) { pto2_slot_states_ptr_ = p; } +void Runtime::set_async_context_addr(PTO2AsyncEngine engine, uint64_t addr) { + if (engine < PTO2_NUM_ASYNC_ENGINES) { + async_context_addrs_[engine] = addr; + } +} +uint64_t Runtime::get_async_context_addr(PTO2AsyncEngine engine) const { + if (engine < PTO2_NUM_ASYNC_ENGINES) { + return async_context_addrs_[engine]; + } + return 0; +} void Runtime::set_orch_args(uint64_t* args, int count) { orch_arg_count_ = count <= RUNTIME_MAX_ARGS ? count : RUNTIME_MAX_ARGS; if (args && orch_arg_count_ > 0) { diff --git a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/runtime.h b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/runtime.h index 35f37abd..836e4693 100644 --- a/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/runtime.h +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/runtime/runtime.h @@ -27,6 +27,7 @@ #include "common/perf_profiling.h" #include "common/platform_config.h" #include "pto2_dispatch_payload.h" +#include "pto_runtime2_types.h" // ============================================================================= // Configuration Macros @@ -182,6 +183,7 @@ class Runtime { void* pto2_gm_sm_ptr_; // GM pointer to PTO2 shared memory (device) void* pto2_gm_heap_ptr_; // GM heap for orchestrator output buffers (device) void* pto2_slot_states_ptr_; // Pointer to PTO2TaskSlotState array (scheduler-private, for profiling) + uint64_t async_context_addrs_[PTO2_NUM_ASYNC_ENGINES]; // Per-engine async context (0 = not available) uint64_t* orch_args_; // Arguments for device orchestration int orch_arg_count_; uint64_t orch_args_storage_[RUNTIME_MAX_ARGS]; // Copy of args for device @@ -248,6 +250,8 @@ class Runtime { void set_pto2_gm_sm_ptr(void* p); void set_pto2_gm_heap(void* p); void set_pto2_slot_states_ptr(void* p); + void set_async_context_addr(PTO2AsyncEngine engine, uint64_t addr); + uint64_t get_async_context_addr(PTO2AsyncEngine engine) const; void set_orch_args(uint64_t* args, int count); // Device orchestration SO binary (for dlopen on AICPU thread 3)