Skip to content

Commit 62675e5

Browse files
committed
[Scheduler] WIP (4)
1 parent 6c4ddd6 commit 62675e5

4 files changed

Lines changed: 39 additions & 290 deletions

File tree

PyTorchSimDevice/torch_openreg/csrc/Module.cpp

Lines changed: 1 addition & 178 deletions
Original file line numberDiff line numberDiff line change
@@ -185,171 +185,6 @@ PyObject* _streamDestroy(PyObject* self, PyObject* arg) {
185185
END_HANDLE_TH_ERRORS
186186
}
187187

188-
// Python binding: block until the given stream has finished its work.
//
// `arg` must be a Python int holding a raw orStream_t handle (the value is
// reinterpreted as the handle, it is not looked up or validated here).
// `self` is unused. Returns None on success; any failure is raised through
// TORCH_CHECK inside the HANDLE_TH_ERRORS / END_HANDLE_TH_ERRORS region.
PyObject* _streamSynchronize(PyObject* self, PyObject* arg) {
189-
HANDLE_TH_ERRORS
190-
TORCH_CHECK(THPUtils_checkLong(arg), "stream_synchronize expects an int");
191-
orStream_t stream = reinterpret_cast<orStream_t>(THPUtils_unpackLong(arg));
192-
193-
orError_t err;
194-
// Release the GIL around the blocking call so other Python threads can run
// while this one waits on the stream.
Py_BEGIN_ALLOW_THREADS
195-
err = orStreamSynchronize(stream);
196-
Py_END_ALLOW_THREADS
197-
198-
// The error is checked only after the GIL has been re-acquired.
if (err != orSuccess) {
199-
TORCH_CHECK(false, "Failed to synchronize stream");
200-
}
201-
Py_RETURN_NONE;
202-
END_HANDLE_TH_ERRORS
203-
}
204-
205-
// Python binding: non-blocking status check for a stream.
//
// `arg` must be a Python int holding a raw orStream_t handle. Returns Python
// True when orStreamQuery reports orSuccess and False for every other code.
// NOTE(review): a genuine failure code is indistinguishable from "not finished
// yet" here, since both map to False — confirm that is intended.
PyObject* _streamQuery(PyObject* self, PyObject* arg) {
206-
HANDLE_TH_ERRORS
207-
TORCH_CHECK(THPUtils_checkLong(arg), "stream_query expects an int");
208-
orStream_t stream = reinterpret_cast<orStream_t>(THPUtils_unpackLong(arg));
209-
orError_t err = orStreamQuery(stream);
210-
if (err == orSuccess) {
211-
Py_RETURN_TRUE;
212-
} else {
213-
Py_RETURN_FALSE;
214-
}
215-
END_HANDLE_TH_ERRORS
216-
}
217-
218-
// Python binding: read the priority of a stream.
//
// `arg` must be a Python int holding a raw orStream_t handle. Returns the
// priority written by orStreamGetPriority as a Python int; raises through
// TORCH_CHECK if the call does not return orSuccess.
PyObject* _streamGetPriority(PyObject* self, PyObject* arg) {
219-
HANDLE_TH_ERRORS
220-
TORCH_CHECK(THPUtils_checkLong(arg), "stream_get_priority expects an int");
221-
orStream_t stream = reinterpret_cast<orStream_t>(THPUtils_unpackLong(arg));
222-
// Initialised to 0 so the variable has a defined value before the out-param
// call fills it in.
int priority = 0;
223-
orError_t err = orStreamGetPriority(stream, &priority);
224-
if (err != orSuccess) {
225-
TORCH_CHECK(false, "Failed to get stream priority");
226-
}
227-
return THPUtils_packInt32(priority);
228-
END_HANDLE_TH_ERRORS
229-
}
230-
231-
// Python binding: make a stream wait on an event.
//
// `args` is a 2-tuple `(stream, event)`, both Python ints holding raw
// orStream_t / orEvent_t handles. Returns None on success; raises through
// TORCH_CHECK on a wrong argument count, a non-int argument, or a
// non-success result from orStreamWaitEvent.
PyObject* _streamWaitEvent(PyObject* self, PyObject* args) {
232-
HANDLE_TH_ERRORS
233-
TORCH_CHECK(PyTuple_Size(args) == 2, "stream_wait_event expects 2 arguments");
234-
// PyTuple_GetItem returns borrowed references, so neither object is
// DECREF'd in this function.
PyObject* stream_obj = PyTuple_GetItem(args, 0);
235-
PyObject* event_obj = PyTuple_GetItem(args, 1);
236-
TORCH_CHECK(THPUtils_checkLong(stream_obj), "stream must be an int");
237-
TORCH_CHECK(THPUtils_checkLong(event_obj), "event must be an int");
238-
orStream_t stream = reinterpret_cast<orStream_t>(THPUtils_unpackLong(stream_obj));
239-
orEvent_t event = reinterpret_cast<orEvent_t>(THPUtils_unpackLong(event_obj));
240-
// Third argument is hard-coded to 0 — presumably a flags word; TODO confirm
// against the orStreamWaitEvent declaration.
orError_t err = orStreamWaitEvent(stream, event, 0);
241-
if (err != orSuccess) {
242-
TORCH_CHECK(false, "Failed to wait for event");
243-
}
244-
Py_RETURN_NONE;
245-
END_HANDLE_TH_ERRORS
246-
}
247-
248-
// Event functions
249-
// Python binding: create an event with no explicit flags.
//
// Takes no arguments (`noargs` is unused). Returns the new orEvent_t handle
// packed into a Python int, so the caller is responsible for passing it back
// to the destroy binding; raises through TORCH_CHECK if orEventCreate fails.
PyObject* _eventCreate(PyObject* self, PyObject* noargs) {
250-
HANDLE_TH_ERRORS
251-
// Run the PrivateUse1 backend's lazy initialisation before touching the
// event API.
torch::utils::device_lazy_init(at::kPrivateUse1);
252-
orEvent_t event = nullptr;
253-
orError_t err = orEventCreate(&event);
254-
if (err != orSuccess) {
255-
TORCH_CHECK(false, "Failed to create event");
256-
}
257-
// The handle crosses into Python as a plain 64-bit integer.
return THPUtils_packInt64(reinterpret_cast<int64_t>(event));
258-
END_HANDLE_TH_ERRORS
259-
}
260-
261-
// Python binding: create an event with caller-supplied flags.
//
// `arg` must be a Python int carrying the flags word. Returns the new
// orEvent_t handle packed into a Python int; raises through TORCH_CHECK if
// the argument is not an int or orEventCreateWithFlags fails.
PyObject* _eventCreateWithFlags(PyObject* self, PyObject* arg) {
262-
HANDLE_TH_ERRORS
263-
TORCH_CHECK(THPUtils_checkLong(arg), "event_create_with_flags expects an int");
264-
// NOTE(review): the unpacked value is narrowed to unsigned int with no range
// check, so negative or oversized inputs are converted silently.
unsigned int flags = static_cast<unsigned int>(THPUtils_unpackLong(arg));
265-
266-
// Argument validation happens before backend initialisation, so a bad
// argument fails without initialising the device.
torch::utils::device_lazy_init(at::kPrivateUse1);
267-
orEvent_t event = nullptr;
268-
orError_t err = orEventCreateWithFlags(&event, flags);
269-
if (err != orSuccess) {
270-
TORCH_CHECK(false, "Failed to create event with flags");
271-
}
272-
return THPUtils_packInt64(reinterpret_cast<int64_t>(event));
273-
END_HANDLE_TH_ERRORS
274-
}
275-
276-
// Python binding: destroy an event.
//
// `arg` must be a Python int holding a raw orEvent_t handle. The integer is
// reinterpreted as the handle without any validation in this function.
// Returns None on success; raises through TORCH_CHECK if orEventDestroy does
// not return orSuccess.
PyObject* _eventDestroy(PyObject* self, PyObject* arg) {
277-
HANDLE_TH_ERRORS
278-
TORCH_CHECK(THPUtils_checkLong(arg), "event_destroy expects an int");
279-
orEvent_t event = reinterpret_cast<orEvent_t>(THPUtils_unpackLong(arg));
280-
orError_t err = orEventDestroy(event);
281-
if (err != orSuccess) {
282-
TORCH_CHECK(false, "Failed to destroy event");
283-
}
284-
Py_RETURN_NONE;
285-
END_HANDLE_TH_ERRORS
286-
}
287-
288-
// Python binding: record an event on a stream.
//
// `args` is a 2-tuple `(event, stream)`, both Python ints holding raw
// orEvent_t / orStream_t handles. Note the order: the event comes first here,
// whereas _streamWaitEvent takes `(stream, event)`. Returns None on success;
// raises through TORCH_CHECK on bad arguments or a non-success result.
PyObject* _eventRecord(PyObject* self, PyObject* args) {
289-
HANDLE_TH_ERRORS
290-
TORCH_CHECK(PyTuple_Size(args) == 2, "event_record expects 2 arguments");
291-
// Borrowed references from the argument tuple; no DECREF required.
PyObject* event_obj = PyTuple_GetItem(args, 0);
292-
PyObject* stream_obj = PyTuple_GetItem(args, 1);
293-
TORCH_CHECK(THPUtils_checkLong(event_obj), "event must be an int");
294-
TORCH_CHECK(THPUtils_checkLong(stream_obj), "stream must be an int");
295-
orEvent_t event = reinterpret_cast<orEvent_t>(THPUtils_unpackLong(event_obj));
296-
orStream_t stream = reinterpret_cast<orStream_t>(THPUtils_unpackLong(stream_obj));
297-
orError_t err = orEventRecord(event, stream);
298-
if (err != orSuccess) {
299-
TORCH_CHECK(false, "Failed to record event");
300-
}
301-
Py_RETURN_NONE;
302-
END_HANDLE_TH_ERRORS
303-
}
304-
305-
// Python binding: block until the given event has completed.
//
// `arg` must be a Python int holding a raw orEvent_t handle. Returns None on
// success; raises through TORCH_CHECK if orEventSynchronize does not return
// orSuccess.
PyObject* _eventSynchronize(PyObject* self, PyObject* arg) {
306-
HANDLE_TH_ERRORS
307-
TORCH_CHECK(THPUtils_checkLong(arg), "event_synchronize expects an int");
308-
orEvent_t event = reinterpret_cast<orEvent_t>(THPUtils_unpackLong(arg));
309-
310-
orError_t err;
311-
// Release the GIL for the duration of the blocking wait so other Python
// threads are not stalled.
Py_BEGIN_ALLOW_THREADS
312-
err = orEventSynchronize(event);
313-
Py_END_ALLOW_THREADS
314-
315-
// The error is checked only after the GIL has been re-acquired.
if (err != orSuccess) {
316-
TORCH_CHECK(false, "Failed to synchronize event");
317-
}
318-
Py_RETURN_NONE;
319-
END_HANDLE_TH_ERRORS
320-
}
321-
322-
// Python binding: non-blocking status check for an event.
//
// `arg` must be a Python int holding a raw orEvent_t handle. Returns Python
// True when orEventQuery reports orSuccess and False for every other code.
// NOTE(review): as written, an actual error from orEventQuery is also
// reported as False rather than raised — confirm that is intended.
PyObject* _eventQuery(PyObject* self, PyObject* arg) {
323-
HANDLE_TH_ERRORS
324-
TORCH_CHECK(THPUtils_checkLong(arg), "event_query expects an int");
325-
orEvent_t event = reinterpret_cast<orEvent_t>(THPUtils_unpackLong(arg));
326-
orError_t err = orEventQuery(event);
327-
if (err == orSuccess) {
328-
Py_RETURN_TRUE;
329-
} else {
330-
Py_RETURN_FALSE;
331-
}
332-
END_HANDLE_TH_ERRORS
333-
}
334-
335-
// Python binding: elapsed time between two events.
//
// `args` is a 2-tuple `(start, end)`, both Python ints holding raw orEvent_t
// handles. Returns the float written by orEventElapsedTime as a Python float
// (the local is named `ms`, i.e. presumably milliseconds); raises through
// TORCH_CHECK on bad arguments or a non-success result.
PyObject* _eventElapsedTime(PyObject* self, PyObject* args) {
336-
HANDLE_TH_ERRORS
337-
TORCH_CHECK(PyTuple_Size(args) == 2, "event_elapsed_time expects 2 arguments");
338-
// Borrowed references from the argument tuple; no DECREF required.
PyObject* start_obj = PyTuple_GetItem(args, 0);
339-
PyObject* end_obj = PyTuple_GetItem(args, 1);
340-
TORCH_CHECK(THPUtils_checkLong(start_obj), "start event must be an int");
341-
TORCH_CHECK(THPUtils_checkLong(end_obj), "end event must be an int");
342-
orEvent_t start = reinterpret_cast<orEvent_t>(THPUtils_unpackLong(start_obj));
343-
orEvent_t end = reinterpret_cast<orEvent_t>(THPUtils_unpackLong(end_obj));
344-
float ms = 0.0f;
345-
orError_t err = orEventElapsedTime(&ms, start, end);
346-
if (err != orSuccess) {
347-
TORCH_CHECK(false, "Failed to get elapsed time");
348-
}
349-
// Widened to double because PyFloat_FromDouble takes a double.
return PyFloat_FromDouble(static_cast<double>(ms));
350-
END_HANDLE_TH_ERRORS
351-
}
352-
353188
PyObject* _deviceSynchronize(PyObject* self, PyObject* noargs) {
354189
HANDLE_TH_ERRORS
355190
torch::utils::device_lazy_init(at::kPrivateUse1);
@@ -421,20 +256,8 @@ static PyMethodDef methods[] = {
421256
{"get_amp_supported_dtype", _getAmpSupportedDtype, METH_NOARGS, nullptr},
422257
// Stream functions
423258
{"_stream_create", _streamCreate, METH_NOARGS, nullptr},
424-
{"_stream_create_with_priority", _streamCreateWithPriority, METH_VARARGS, nullptr},
425259
{"_stream_destroy", _streamDestroy, METH_O, nullptr},
426-
{"_stream_synchronize", _streamSynchronize, METH_O, nullptr},
427-
{"_stream_query", _streamQuery, METH_O, nullptr},
428-
{"_stream_get_priority", _streamGetPriority, METH_O, nullptr},
429-
{"_stream_wait_event", _streamWaitEvent, METH_VARARGS, nullptr},
430-
// Event functions
431-
{"_event_create", _eventCreate, METH_NOARGS, nullptr},
432-
{"_event_create_with_flags", _eventCreateWithFlags, METH_O, nullptr},
433-
{"_event_destroy", _eventDestroy, METH_O, nullptr},
434-
{"_event_record", _eventRecord, METH_VARARGS, nullptr},
435-
{"_event_synchronize", _eventSynchronize, METH_O, nullptr},
436-
{"_event_query", _eventQuery, METH_O, nullptr},
437-
{"_event_elapsed_time", _eventElapsedTime, METH_VARARGS, nullptr},
260+
438261
// Device functions
439262
{"_device_synchronize", _deviceSynchronize, METH_NOARGS, nullptr},
440263
// Stream task functions

PyTorchSimDevice/torch_openreg/openreg/__init__.py

Lines changed: 27 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .extension_device_interface import ExtensionDeviceInterface
99

1010
_initialized = False
11+
_default_streams = {} # Dictionary to store default streams per device
1112

1213

1314
class device:
@@ -64,36 +65,21 @@ def _lazy_init():
6465
register_interface_for_device(custom_device(), ExtensionDeviceInterface)
6566
_initialized = True
6667

68+
# Create default streams for all devices
69+
num_devices = device_count()
70+
for device_idx in range(num_devices):
71+
_default_streams[device_idx] = Stream()
6772

6873
class Stream:
6974
"""Wrapper for OpenReg stream."""
7075

71-
def __init__(self, priority=None, flags=0):
72-
if priority is not None:
73-
self._stream = torch_openreg._C._stream_create_with_priority(flags, priority)
74-
else:
75-
self._stream = torch_openreg._C._stream_create()
76+
def __init__(self, flags=0):
77+
self._stream = torch_openreg._C._stream_create()
7678

7779
def __del__(self):
7880
if hasattr(self, '_stream'):
7981
torch_openreg._C._stream_destroy(self._stream)
8082

81-
def synchronize(self):
82-
"""Wait for all operations in the stream to complete."""
83-
torch_openreg._C._stream_synchronize(self._stream)
84-
85-
def query(self):
86-
"""Check if all operations in the stream have completed."""
87-
return torch_openreg._C._stream_query(self._stream)
88-
89-
def wait_event(self, event):
90-
"""Make this stream wait for an event."""
91-
torch_openreg._C._stream_wait_event(self._stream, event._event)
92-
93-
def get_priority(self):
94-
"""Get the priority of the stream."""
95-
return torch_openreg._C._stream_get_priority(self._stream)
96-
9783
def launch_kernel(self, task):
9884
"""Add a Python callable kernel to this stream.
9985
@@ -108,74 +94,43 @@ def cdata(self):
10894
return self._stream
10995

11096

111-
# Thin Python wrapper around an event handle owned by the torch_openreg._C
# extension. The handle is stored in `self._event` and released in __del__.
class Event:
112-
"""Wrapper for OpenReg event."""
113-
114-
# enable_timing=True creates the event through the "with flags" binding
# using the flag value 0x1; otherwise the plain create binding is used.
def __init__(self, enable_timing=False):
115-
if enable_timing:
116-
# orEventEnableTiming = 0x1
117-
self._event = torch_openreg._C._event_create_with_flags(0x1)
118-
else:
119-
self._event = torch_openreg._C._event_create()
120-
121-
# The hasattr guard covers the case where __init__ raised before
# `_event` was assigned, so there is nothing to destroy.
def __del__(self):
122-
if hasattr(self, '_event'):
123-
torch_openreg._C._event_destroy(self._event)
124-
125-
def record(self, stream=None):
126-
"""Record the event in a stream."""
127-
if stream is None:
128-
# NOTE(review): this constructs a brand-new Stream (Stream.__init__
# calls _stream_create); it does not look up a shared default stream.
# The temporary is only referenced locally, and Stream.__del__ destroys
# the underlying stream once the object is collected — confirm this is
# the intended behaviour.
129-
stream = Stream()
130-
torch_openreg._C._event_record(self._event, stream._stream)
131-
132-
def synchronize(self):
133-
"""Wait for the event to complete."""
134-
torch_openreg._C._event_synchronize(self._event)
135-
136-
def query(self):
137-
"""Check if the event has completed."""
138-
return torch_openreg._C._event_query(self._event)
139-
140-
# `self` is the END event: the binding is called as (start, end) with
# start_event first and this event second.
def elapsed_time(self, start_event):
141-
"""Get the elapsed time between two events in milliseconds."""
142-
return torch_openreg._C._event_elapsed_time(start_event._event, self._event)
143-
144-
@property
145-
def cdata(self):
146-
"""Get the underlying event pointer (for internal use)."""
147-
return self._event
148-
149-
15097
# Module-level convenience wrapper: takes no arguments and forwards directly
# to the C extension's `_device_synchronize` binding; nothing is returned.
def synchronize():
15198
"""Synchronize all streams on the current device."""
15299
torch_openreg._C._device_synchronize()
153100

154101

155-
def stream(priority=None, flags=0):
102+
def stream(flags=0):
156103
"""Create a new stream.
157104
158105
Args:
159-
priority: Stream priority (optional)
160106
flags: Stream flags (optional)
161107
162108
Returns:
163109
Stream: A new stream object
164110
"""
165-
return Stream(priority=priority, flags=flags)
111+
return Stream(flags=flags)
166112

113+
def default_stream(device=None):
114+
_lazy_init()
115+
if device is None:
116+
device_idx = current_device()
117+
else:
118+
device_idx = torch.accelerator._get_device_index(device, optional=True)
119+
if device_idx < 0:
120+
device_idx = current_device()
167121

168-
def event(enable_timing=False):
169-
"""Create a new event.
122+
if device_idx not in _default_streams:
123+
# Create default stream if it doesn't exist
124+
_default_streams[device_idx] = Stream()
170125

171-
Args:
172-
enable_timing: Whether to enable timing for the event
126+
return _default_streams[device_idx]
173127

174-
Returns:
175-
Event: A new event object
176-
"""
177-
return Event(enable_timing=enable_timing)
178128

129+
# Submit `task` to a stream.
#   task:   forwarded unchanged to Stream.launch_kernel.
#   stream: Stream to submit to; when None, the stream returned by
#           default_stream() (called with no device argument) is used.
# Returns None (the result of Stream.launch_kernel is not propagated).
def launch_kernel(task, stream=None):
130+
# Make sure the backend is initialised before a stream is touched.
_lazy_init()
131+
if stream is None:
132+
stream = default_stream()
133+
stream.launch_kernel(task)
179134

180135
from .random import * # noqa: F403
181136
from .amp import *
@@ -200,9 +155,7 @@ def event(enable_timing=False):
200155
"get_autocast_dtype",
201156
"set_autocast_dtype",
202157
"get_amp_supported_dtype",
203-
"Stream",
204-
"Event",
205158
"stream",
206-
"event",
159+
"launch_kernel",
207160
"synchronize",
208161
]

0 commit comments

Comments
 (0)