Draft
Commits (33)
0422f62
initial implementation of implementation input
Masoudvahid Dec 28, 2025
31124b1
merged main into async_inputs
Mr-DarkTesla Dec 30, 2025
535decd
Yet not working runner
Mr-DarkTesla Dec 30, 2025
28047c5
Requires datasets 4.0.0
Mr-DarkTesla Dec 30, 2025
08d4348
Fixed queue push hook
Mr-DarkTesla Jan 1, 2026
9f6b79b
Added shard push to different targets
Mr-DarkTesla Jan 1, 2026
29fd350
typo
Mr-DarkTesla Jan 1, 2026
345d116
merged main into async_input_runner
Mr-DarkTesla Jan 6, 2026
6289ce6
merged main into async_input_runner
Mr-DarkTesla Jan 6, 2026
6843443
live context queue now in solver
Mr-DarkTesla Jan 6, 2026
8896e9e
small renames
Mr-DarkTesla Jan 6, 2026
e8a322b
Dataset path
Mr-DarkTesla Jan 6, 2026
5f39231
input_block support for async_input & insert after eof paragraph
Masoudvahid Jan 8, 2026
e7d2076
merged async_input into async_input_runner (shards into user inputs)
Mr-DarkTesla Jan 9, 2026
0f4c224
Reverted imports
Mr-DarkTesla Jan 9, 2026
a54107f
New hooks + bugfix
Mr-DarkTesla Jan 9, 2026
2a99fa6
Unified insertions
Mr-DarkTesla Jan 9, 2026
3889a4c
More options for ablation
Mr-DarkTesla Jan 9, 2026
42d7d54
Merge branch 'main' into async_input_235b
justheuristic Jan 11, 2026
ae2f741
device_map
justheuristic Jan 11, 2026
2060ec4
Merge branch 'async_input_runner' into async_input_235b
justheuristic Jan 11, 2026
4942672
HUGE BUGFIX!
Mr-DarkTesla Jan 13, 2026
c46ad1f
Merge branch 'async_input_runner' into async_input_235b
justheuristic Jan 13, 2026
b62dcd1
Support for multishard
Mr-DarkTesla Jan 13, 2026
bb08e8c
llms get lost dataset support
Mr-DarkTesla Jan 13, 2026
2a844ca
save path
Mr-DarkTesla Jan 13, 2026
8204788
Correct shards handling
Mr-DarkTesla Jan 13, 2026
55276ef
Apperently that was incorrect
Mr-DarkTesla Jan 15, 2026
a8b6b37
Merge remote-tracking branch 'origin/async_input_multishard' into asy…
justheuristic Jan 16, 2026
c4d0492
device_map
justheuristic Jan 16, 2026
1573750
revert math
justheuristic Jan 16, 2026
176068a
device_map
justheuristic Jan 16, 2026
b0be2d1
fix edge case for dollars
justheuristic Jan 16, 2026
828 changes: 828 additions & 0 deletions async_query.ipynb

Large diffs are not rendered by default.

26 changes: 26 additions & 0 deletions async_reasoning/async_input_hook.py
@@ -0,0 +1,26 @@

def async_input_hook_constructor(solver, shard_to_target, target_reminders, next_shard_every_steps, problem_shards, defer_until_boundary=False):
def on_token(writer_tokens, thinker_tokens, token_times, eos, state):
for shard_idx, problem_shard in enumerate(problem_shards):
if next_shard_every_steps <= 0 or len(thinker_tokens) < next_shard_every_steps * (shard_idx + 1):
return

for target in shard_to_target:
if solver.live_context_queue.push_counter_per_target[target] == shard_idx:
print(end=f"Sent shard {shard_idx} to {target} on step {len(thinker_tokens)}.\n", flush=True)
solver.live_context_queue.push_text(
f"\n\nADDITIONAL USER INPUT: {problem_shard}\n\n",
target=target,
defer_until_boundary=defer_until_boundary and (target != "input")
)
for target in target_reminders:
assert target != "input", "can't remind to input"
assert target not in shard_to_target, f"Can't send reminder to {target}; already in shard_to_target"
if solver.live_context_queue.push_counter_per_target[target] == shard_idx:
print(end=f"Sent reminder {shard_idx} to {target} on step {len(thinker_tokens)}.\n", flush=True)
solver.live_context_queue.push_text(
f" ... [SYSTEM: additional user input detected]\n",
target=target,
defer_until_boundary=defer_until_boundary
)
return on_token
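
For orientation, here is a minimal sketch of how this hook might be wired into generation. It assumes a solver object exposing `live_context_queue` and `solve(..., on_new_tokens_generated=...)` as in `async_reasoning/solver.py` below; the shard texts, step count, `problem`, and the positional `solve()` arguments are illustrative, not taken from this diff.

```python
# Hypothetical wiring of the async-input hook; concrete values are illustrative.
problem_shards = [
    "The bridge must also carry pedestrian traffic.",
    "The budget was cut by 20%; revise the estimate.",
]

on_token = async_input_hook_constructor(
    solver=solver,                   # an already-constructed solver (assumed)
    shard_to_target=["input"],       # push each shard into the shared input block
    target_reminders=["thinker"],    # and nudge the thinker that new input arrived
    next_shard_every_steps=256,      # shard i is released once the thinker has ~256*(i+1) tokens
    problem_shards=problem_shards,
    defer_until_boundary=True,       # splice reminders only at sentence/paragraph boundaries
)

writer_out, thinker_out, token_times, eos = solver.solve(
    problem, budget=2048,            # argument names here are a guess; see solve() below
    on_new_tokens_generated=on_token,
)
```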
47 changes: 39 additions & 8 deletions async_reasoning/cache.py
@@ -1,6 +1,7 @@
from enum import Enum
import torch
import shared_cache
import transformers


import logging
@@ -28,32 +29,48 @@ def __init__(self, model, tokenizer, prompting: AsyncReasoningPrompting, tokeniz
self.state = starting_state

# Init all needed cache blocks
(self.input_prompt, self.thinker_output, self.writer_output, self.mode_switching_prompt, self.mode_switching_question
) = (shared_cache.CacheBlock(config=self.model.config) for _ in range(5))
(
self.input_prompt,
self.input_block,
self.thinker_output,
self.writer_output,
self.mode_switching_prompt,
self.mode_switching_question,
) = (shared_cache.CacheBlock(config=self.model.config) for _ in range(6))

def prefill_cache_block(text: str, blocks, write_to=None):
write_to = blocks[-1] if write_to is None else write_to
tmp_cm = shared_cache.SharedCacheManager(cache_structure=[blocks], write_to=[write_to])
encoded = self.tokenizer(text, **self.tokenizer_kwargs)["input_ids"].to(self.device)
with torch.inference_mode():
self.model(**tmp_cm.get_input_kwargs(encoded))

def init_empty_block(block: shared_cache.CacheBlock):
"""Populate block with zero-length caches so it participates in structures without assertions."""
tmp_cm = shared_cache.SharedCacheManager(cache_structure=[[block]], write_to=[block])
dummy = self.tokenizer(" ", **self.tokenizer_kwargs)["input_ids"].to(self.device)
with torch.inference_mode():
self.model(**tmp_cm.get_input_kwargs(dummy))
block.trim_keep_first(0)

# encode each prompt section as LLM KV cache for use in generation
prefill_cache_block(self.prompting.input_prompt, [self.input_prompt]) # <-- writes KV entries to last cache in list
prefill_cache_block(self.prompting.thinker_output_prefix, [self.input_prompt, self.thinker_output])
prefill_cache_block(self.prompting.writer_output_prefix, [self.input_prompt, self.thinker_output, self.writer_output])
init_empty_block(self.input_block)
prefill_cache_block(self.prompting.thinker_output_prefix, [self.input_prompt, self.input_block, self.thinker_output])
prefill_cache_block(self.prompting.writer_output_prefix, [self.input_prompt, self.input_block, self.thinker_output, self.writer_output])
prefill_cache_block(self.prompting.mode_switching_prompt, [self.mode_switching_prompt])
# note: mode_switching_question is re-encoded every time it is asked - no need to fill it here

thinker_view = (self.input_prompt, self.thinker_output)
writer_view = (self.input_prompt, self.thinker_output, self.writer_output)
mode_switching_view = (self.mode_switching_prompt, self.thinker_output, self.writer_output, self.mode_switching_question)
thinker_view = (self.input_prompt, self.input_block, self.thinker_output)
writer_view = (self.input_prompt, self.input_block, self.thinker_output, self.writer_output)
mode_switching_view = (self.mode_switching_prompt, self.input_block, self.thinker_output, self.writer_output, self.mode_switching_question)

# prepare a cache manager for each mode: thinker-only, writer-only, thinker+writer, mode switching, and input-only
self.cm_thinker_only = shared_cache.SharedCacheManager(cache_structure=[thinker_view])
self.cm_writer_only = shared_cache.SharedCacheManager(cache_structure=[writer_view])
self.cm_thinker_and_writer = shared_cache.SharedCacheManager(cache_structure=[thinker_view, writer_view])
self.cm_mode_switching = shared_cache.SharedCacheManager(cache_structure=[mode_switching_view])
self.cm_input_only = shared_cache.SharedCacheManager(cache_structure=[[self.input_prompt, self.input_block]], write_to=[self.input_block])

# Catch and log state changes
def __setattr__(self, name, value):
@@ -74,4 +91,18 @@ def cache_manager(self):
raise ValueError(f"Unexpected state {self.state}")

def get_input_kwargs(self, **kwargs):
return self.cache_manager.get_input_kwargs(**kwargs)

def append_tokens(self, target: str, token_ids: torch.Tensor):
"""Append pre-tokenized ids to writer, thinker, or input caches so generation can consume them mid-stream."""
if target not in ("writer", "thinker", "input"):
raise ValueError(f"target must be 'writer', 'thinker', or 'input', got {target}")
token_ids = token_ids.to(self.device)
if target == "writer":
input_kwargs = self.cm_writer_only.get_input_kwargs(token_ids)
elif target == "input":
input_kwargs = self.cm_input_only.get_input_kwargs(token_ids)
else:
input_kwargs = self.cm_thinker_only.get_input_kwargs(token_ids)
with torch.inference_mode():
self.model(**input_kwargs)
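
As a rough standalone sketch of the new `append_tokens` path (assuming the constructor keywords shown in `solver.solve` below, that `tokenizer_kwargs` includes `return_tensors="pt"` so encoded ids are tensors, and with placeholder text):

```python
# Illustrative only: push a pre-tokenized user note into the shared input block,
# then leave a short marker in the thinker stream so generation notices the new input.
cache = AsyncReasoningCache(
    model, tokenizer, prompting,
    tokenizer_kwargs=dict(return_tensors="pt"),   # assumed; encoding must return tensors
    starting_state=State.thinker_only,
)

note_ids = tokenizer("\n\nADDITIONAL USER INPUT: the deadline moved to Friday.\n\n",
                     return_tensors="pt")["input_ids"]
cache.append_tokens("input", note_ids)            # prefilled through cm_input_only

marker_ids = tokenizer(" ... [SYSTEM: additional user input detected]\n",
                       return_tensors="pt")["input_ids"]
cache.append_tokens("thinker", marker_ids)        # appended through cm_thinker_only
```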
48 changes: 40 additions & 8 deletions async_reasoning/cache_fast_kernels.py
@@ -19,32 +19,50 @@ def __init__(self, model, tokenizer, prompting, tokenizer_kwargs=dict(), startin
self.state = starting_state

# Init all needed cache blocks
(self.input_prompt, self.thinker_output, self.writer_output,
self.mode_switching_prompt, self.mode_switching_question
) = (transformers.DynamicCache() for _ in range(5))
(
self.input_prompt,
self.input_block,
self.thinker_output,
self.writer_output,
self.mode_switching_prompt,
self.mode_switching_question,
) = (transformers.DynamicCache() for _ in range(6))

def prefill_cache_block(text: str, blocks, write_to=None):
write_to = blocks[-1] if write_to is None else write_to
tmp_cm = HogwildCache(cache_structure=[blocks], write_to=[write_to], model=model)
encoded = self.tokenizer(text, **self.tokenizer_kwargs)["input_ids"].to(self.device)
with torch.inference_mode():
self.model(**tmp_cm.get_input_kwargs(encoded))

def init_empty_block(block: transformers.DynamicCache):
"""Populate block with zero-length caches so it participates in structures without assertions."""
tmp_cm = HogwildCache(cache_structure=[[block]], write_to=[block], model=model)
dummy = self.tokenizer(" ", **self.tokenizer_kwargs)["input_ids"].to(self.device)
with torch.inference_mode():
self.model(**tmp_cm.get_input_kwargs(dummy))
for i in range(len(block.key_cache)):
block.key_cache[i] = block.key_cache[i][..., :0, :].contiguous()
block.value_cache[i] = block.value_cache[i][..., :0, :].contiguous()
block._seen_tokens = 0

# encode each prompt section as LLM KV cache for use in generation
prefill_cache_block(self.prompting.input_prompt, [self.input_prompt]) # <-- writes KV entries to last cache in list
prefill_cache_block(self.prompting.thinker_output_prefix, [self.input_prompt, self.thinker_output])
prefill_cache_block(self.prompting.writer_output_prefix, [self.input_prompt, self.thinker_output, self.writer_output])
init_empty_block(self.input_block)
prefill_cache_block(self.prompting.thinker_output_prefix, [self.input_prompt, self.input_block, self.thinker_output])
prefill_cache_block(self.prompting.writer_output_prefix, [self.input_prompt, self.input_block, self.thinker_output, self.writer_output])
prefill_cache_block(self.prompting.mode_switching_prompt, [self.mode_switching_prompt])

thinker_view = (self.input_prompt, self.thinker_output)
writer_view = (self.input_prompt, self.thinker_output, self.writer_output)
mode_switching_view = (self.mode_switching_prompt, self.thinker_output, self.writer_output, self.mode_switching_question)
thinker_view = (self.input_prompt, self.input_block, self.thinker_output)
writer_view = (self.input_prompt, self.input_block, self.thinker_output, self.writer_output)
mode_switching_view = (self.mode_switching_prompt, self.input_block, self.thinker_output, self.writer_output, self.mode_switching_question)

# prepare a cache manager for each mode: thinker-only, writer-only, thinker+writer, mode switching, and input-only
self.cm_thinker_only = HogwildCache(cache_structure=[thinker_view], model=model)
self.cm_writer_only = HogwildCache(cache_structure=[writer_view], model=model)
self.cm_thinker_and_writer = HogwildCache(cache_structure=[thinker_view, writer_view], model=model)
self.cm_mode_switching = HogwildCache(cache_structure=[mode_switching_view], model=model)
self.cm_input_only = HogwildCache(cache_structure=[[self.input_prompt, self.input_block]], write_to=[self.input_block], model=model)

# Catch and log state changes
def __setattr__(self, name, value):
@@ -66,3 +84,17 @@ def cache_manager(self):

def get_input_kwargs(self, **kwargs):
return self.cache_manager.get_input_kwargs(**kwargs)

def append_tokens(self, target: str, token_ids: torch.Tensor):
"""Append pre-tokenized ids to writer, thinker, or input caches so generation can consume them mid-stream."""
if target not in ("writer", "thinker", "input"):
raise ValueError(f"target must be 'writer', 'thinker', or 'input', got {target}")
token_ids = token_ids.to(self.device)
if target == "writer":
input_kwargs = self.cm_writer_only.get_input_kwargs(token_ids)
elif target == "input":
input_kwargs = self.cm_input_only.get_input_kwargs(token_ids)
else:
input_kwargs = self.cm_thinker_only.get_input_kwargs(token_ids)
with torch.inference_mode():
self.model(**input_kwargs)
93 changes: 92 additions & 1 deletion async_reasoning/solver.py
@@ -4,7 +4,8 @@
import warnings
import transformers
from IPython.display import display, Markdown, clear_output
from typing import Sequence, Union, Callable, Optional
from typing import Sequence, Union, Callable, Optional, List, Tuple
import queue

from async_reasoning.prompting import AsyncReasoningPrompting
from async_reasoning.cache import State, AsyncReasoningCache
@@ -48,6 +49,7 @@ def __init__(self,
self.thinker_forbidden_token_ix, self.writer_forbidden_token_ix = thinker_forbidden_token_ix, writer_forbidden_token_ix
self.end_of_think_token_ix = end_of_think_token_ix
self.use_fast_kernel = use_fast_kernel
self.live_context_queue = LiveContextQueue(tokenizer, model.device)

@torch.inference_mode()
def check_if_should_continue_writing(self,
@@ -101,11 +103,14 @@ def solve(
token_times = []
writer_output_tokens = self.tokenizer.encode(prompting.writer_output_prefix, **self.tokenizer_kwargs).flatten().tolist()
thinker_output_tokens = self.tokenizer.encode(prompting.thinker_output_prefix, **self.tokenizer_kwargs).flatten().tolist()
input_tokens: List[int] = []

writer_output_tokens.append(self.tokenizer.encode("\n\n", **self.tokenizer_kwargs).item())
thinker_output_tokens.append(self.tokenizer.encode("\n\n", **self.tokenizer_kwargs).item())
eos_generated = False
cache = self.Cache(self.model, self.tokenizer, prompting, tokenizer_kwargs=self.tokenizer_kwargs, starting_state=State.thinker_only)
pending_injections: List["QueuedInjection"] = []
self.live_context_queue.zero_counter()
with torch.inference_mode():
starting_time = time.perf_counter()
for step in range(budget):
@@ -147,6 +152,17 @@
if writer_output_tokens[-1] == self.tokenizer.eos_token_id:
eos_generated = True

# Inject any user-provided context mid-generation
pending_injections.extend(self.live_context_queue.pop_all())
if pending_injections:
pending_injections = self._apply_pending_injections(
pending_injections,
cache,
writer_output_tokens,
thinker_output_tokens,
input_tokens,
)

if on_new_tokens_generated is not None:
on_new_tokens_generated(
writer_output_tokens,
@@ -163,3 +179,78 @@
writer_output_str, thinker_output_str = self.tokenizer.decode(writer_output_tokens), self.tokenizer.decode(thinker_output_tokens)

return writer_output_str, thinker_output_str, token_times, eos_generated

def _apply_pending_injections(
self,
pending_injections: List["QueuedInjection"],
cache: Union['AsyncReasoningCache', 'AsyncReasoningCacheFastKernels'],
writer_output_tokens: List[int],
thinker_output_tokens: List[int],
input_tokens: List[int],
) -> List["QueuedInjection"]:
remaining: List["QueuedInjection"] = []
for inj in pending_injections:
if inj.target == "writer":
token_stream = writer_output_tokens
elif inj.target == "thinker":
token_stream = thinker_output_tokens
else:
token_stream = thinker_output_tokens # defer based on thinker stream for input block
if inj.defer_until_boundary and not self._is_boundary(token_stream):
remaining.append(inj)
continue
tokens_tensor = torch.tensor([inj.tokens], device=self.device)
cache.append_tokens(inj.target, tokens_tensor)
if inj.target == "writer":
writer_output_tokens.extend([int(t) for t in inj.tokens])
elif inj.target == "thinker":
thinker_output_tokens.extend([int(t) for t in inj.tokens])
else:
input_tokens.extend([int(t) for t in inj.tokens])
return remaining

def _is_boundary(self, tokens: Sequence[int]) -> bool:
tail = self.tokenizer.decode(tokens[-12:]) if tokens else ""
if tail.endswith("\n\n"):
return True
return any(tail.rstrip().endswith(mark) for mark in (".", "!", "?", "…"))


class QueuedInjection:
def __init__(self, target: str, tokens: List[int], defer_until_boundary: bool):
self.target = target
self.tokens = tokens
self.defer_until_boundary = defer_until_boundary


class LiveContextQueue:
"""Thread-safe queue for feeding extra context tokens/text mid-generation."""
def __init__(self, tokenizer: transformers.PreTrainedTokenizer, device: torch.device):
self._queue: queue.Queue[QueuedInjection] = queue.Queue()
self.tokenizer = tokenizer
self.device = device
self.zero_counter()

def zero_counter(self):
self.push_counter_per_target = {"writer": 0, "thinker": 0, "input": 0}

def push_text(self, text: str, target: str = "thinker", defer_until_boundary: bool = False):
tokens = self.tokenizer.encode(text, add_special_tokens=False)
self.push_tokens(tokens, target=target, defer_until_boundary=defer_until_boundary)
self.push_counter_per_target[target] += 1

def push_tokens(
self,
tokens: Sequence[int],
target: str = "thinker",
defer_until_boundary: bool = False,
):
if target not in ("writer", "thinker", "input"):
raise ValueError(f"target must be 'writer', 'thinker', or 'input', got {target}")
self._queue.put(QueuedInjection(target, list(tokens), defer_until_boundary))

def pop_all(self) -> List[QueuedInjection]:
items: List[QueuedInjection] = []
while not self._queue.empty():
items.append(self._queue.get())
return items
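
Putting the queue and the solver together, a hedged end-to-end sketch (the threading setup, the `problem` prompt, and the exact `solve()` signature beyond `budget` are assumptions, not part of this diff):

```python
import threading
import time

def feed_live_context(solver):
    # Simulate a user sending extra context a few seconds into generation.
    time.sleep(5.0)
    solver.live_context_queue.push_text(
        "\n\nADDITIONAL USER INPUT: please also report results in SI units.\n\n",
        target="input",
    )
    solver.live_context_queue.push_text(
        " ... [SYSTEM: additional user input detected]\n",
        target="thinker",
        defer_until_boundary=True,   # spliced in only at a sentence/paragraph boundary
    )

threading.Thread(target=feed_live_context, args=(solver,), daemon=True).start()
writer_out, thinker_out, token_times, eos = solver.solve(problem, budget=2048)
```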