Skip to content

Commit 5e3a8cc

Browse files
committed
[Scheduler] WIP (8)
1 parent cf2e9d0 commit 5e3a8cc

4 files changed

Lines changed: 34 additions & 38 deletions

File tree

PyTorchSimFrontend/extension_codecache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ def run_kernel_simulation(*args, **kwargs):
274274
# Dump arguments and meta data
275275
dump_metadata(args, arg_attributes, result_path)
276276
runtime_path = FunctionalSimulator.get_runtime_dump_path(result_path)
277-
if extension_config.pytorchsim_functional_mode:
277+
if extension_config.pytorchsim_functional_mode and not autotune:
278278
funcsim = FunctionalSimulator(result_path, key)
279279
funcsim.run_spike(args, arg_attributes,
280280
runtime_path, self.validation_binary_name,

PyTorchSimFrontend/mlir/mlir_autotune.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,24 @@ def make_run_fn(
6161
# Check already cached result.
6262
write_path = get_write_path(self.source_code)
6363
key, _ = write(self.source_code, "mlir", specified_dir=write_path)
64-
result_path = os.path.join(extension_config.CONFIG_TORCHSIM_DUMP_PATH, "outputs", hash_prefix(key), "togsim_result/0")
65-
if os.path.exists(result_path):
66-
result = TOGSimulator.get_result_from_file(result_path)
67-
def cached_run_fn(*args, **kwargs):
68-
return result
69-
return cached_run_fn
64+
result_dir = os.path.join(extension_config.CONFIG_TORCHSIM_DUMP_PATH, "outputs", hash_prefix(key), "togsim_result")
65+
66+
# Find the most recent .log file in the result directory
67+
if os.path.exists(result_dir) and os.path.isdir(result_dir):
68+
log_files = [f for f in os.listdir(result_dir) if f.endswith('.log')]
69+
if log_files:
70+
# Sort by modification time, get the most recent file
71+
log_files_with_time = [
72+
(f, os.path.getmtime(os.path.join(result_dir, f)))
73+
for f in log_files
74+
]
75+
log_files_with_time.sort(key=lambda x: x[1], reverse=True)
76+
latest_log_file = log_files_with_time[0][0]
77+
result_path = os.path.join(result_dir, latest_log_file)
78+
result = TOGSimulator.get_result_from_file(result_path)
79+
def cached_run_fn(*args, **kwargs):
80+
return result
81+
return cached_run_fn
7082

7183
# Run a candidate code
7284
run_method = custom_async_compile.mlir(

Simulator/simulator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def _send_command(self, command_type, device_index, stream_index, tog_path="", a
308308
timestamp: Timestamp in nanoseconds (default: 0)
309309
310310
Returns:
311-
int: The kernel ID assigned to this command (or -1 for DEVICE_SYNC)
311+
int: The kernel ID assigned to this command
312312
"""
313313
if self.process is None:
314314
raise RuntimeError("[TOGSim] Simulator process is not running")

tests/test_scheduler.py

Lines changed: 14 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,41 +3,25 @@
33
import torch
44
from torchvision.models import resnet18 as model1
55
from test_transformer import EncoderBlock as model2
6+
from Simulator.simulator import TOGSimulator
67

78
base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')
8-
sys.path.append(base_path)
9-
from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request
109
config = f'{base_path}/configs/systolic_ws_128x128_c2_simple_noc_tpuv3_partition.yml'
10+
os.environ['TOGSIM_CONFIG'] = config
1111

1212
target_model1 = model1().eval()
1313
target_model2 = model2(768, 12).eval()
1414

15-
# Init scheduler
16-
scheduler = Scheduler(num_request_queue=2, engine_select=Scheduler.FIFO_ENGINE, togsim_config=config)
17-
# Register compiled model
18-
opt_model1 = torch.compile(target_model1.to(device=scheduler.execution_engine.module.custom_device(), memory_format=torch.channels_last))
19-
opt_model2 = torch.compile(target_model2.to(device=scheduler.execution_engine.module.custom_device()))
20-
SchedulerDNNModel.register_model("resnet18", opt_model1)
21-
SchedulerDNNModel.register_model("bert", opt_model2)
22-
23-
# Init input data
24-
model_input1 = torch.randn(1, 3, 224, 224)
25-
model_input2 = torch.randn(128, 768)
26-
27-
# Init request
28-
new_request1 = Request("resnet18", [model_input1], [], request_queue_idx=0)
29-
new_request2 = Request("bert", [model_input2], [], request_queue_idx=1)
30-
new_request3 = Request("resnet18", [model_input1], [], request_queue_idx=0)
31-
new_request4 = Request("bert", [model_input2], [], request_queue_idx=1)
32-
33-
# Add request to scheduler
34-
scheduler.add_request(new_request1, request_time=0)
35-
scheduler.add_request(new_request2, request_time=0)
36-
scheduler.add_request(new_request3, request_time=0)
37-
scheduler.add_request(new_request4, request_time=0)
38-
39-
# Run scheduler
40-
while not scheduler.is_finished():
41-
scheduler.schedule()
42-
15+
device = torch.device("npu:0")
16+
opt_model1 = torch.compile(target_model1.to(device=device, memory_format=torch.channels_last))
17+
opt_model2 = torch.compile(target_model2.to(device=device))
18+
model_input1 = torch.randn(1, 3, 224, 224).to(device=device)
19+
model_input2 = torch.randn(128, 768).to(device=device)
20+
21+
with TOGSimulator(config_path=config):
22+
torch.npu.launch_model(opt_model1, model_input1, stream_index=0, timestamp=0)
23+
torch.npu.launch_model(opt_model2, model_input2, stream_index=1, timestamp=0)
24+
torch.npu.synchronize()
25+
torch.npu.launch_model(opt_model1, model_input1, stream_index=0, timestamp=0)
26+
torch.npu.launch_model(opt_model2, model_input2, stream_index=1, timestamp=0)
4327
print("Done")

0 commit comments

Comments (0)