@@ -185,171 +185,6 @@ PyObject* _streamDestroy(PyObject* self, PyObject* arg) {
185185 END_HANDLE_TH_ERRORS
186186}
187187
188- PyObject* _streamSynchronize (PyObject* self, PyObject* arg) {
189- HANDLE_TH_ERRORS
190- TORCH_CHECK (THPUtils_checkLong (arg), " stream_synchronize expects an int" );
191- orStream_t stream = reinterpret_cast <orStream_t>(THPUtils_unpackLong (arg));
192-
193- orError_t err;
194- Py_BEGIN_ALLOW_THREADS
195- err = orStreamSynchronize (stream);
196- Py_END_ALLOW_THREADS
197-
198- if (err != orSuccess) {
199- TORCH_CHECK (false , " Failed to synchronize stream" );
200- }
201- Py_RETURN_NONE;
202- END_HANDLE_TH_ERRORS
203- }
204-
205- PyObject* _streamQuery (PyObject* self, PyObject* arg) {
206- HANDLE_TH_ERRORS
207- TORCH_CHECK (THPUtils_checkLong (arg), " stream_query expects an int" );
208- orStream_t stream = reinterpret_cast <orStream_t>(THPUtils_unpackLong (arg));
209- orError_t err = orStreamQuery (stream);
210- if (err == orSuccess) {
211- Py_RETURN_TRUE;
212- } else {
213- Py_RETURN_FALSE;
214- }
215- END_HANDLE_TH_ERRORS
216- }
217-
218- PyObject* _streamGetPriority (PyObject* self, PyObject* arg) {
219- HANDLE_TH_ERRORS
220- TORCH_CHECK (THPUtils_checkLong (arg), " stream_get_priority expects an int" );
221- orStream_t stream = reinterpret_cast <orStream_t>(THPUtils_unpackLong (arg));
222- int priority = 0 ;
223- orError_t err = orStreamGetPriority (stream, &priority);
224- if (err != orSuccess) {
225- TORCH_CHECK (false , " Failed to get stream priority" );
226- }
227- return THPUtils_packInt32 (priority);
228- END_HANDLE_TH_ERRORS
229- }
230-
231- PyObject* _streamWaitEvent (PyObject* self, PyObject* args) {
232- HANDLE_TH_ERRORS
233- TORCH_CHECK (PyTuple_Size (args) == 2 , " stream_wait_event expects 2 arguments" );
234- PyObject* stream_obj = PyTuple_GetItem (args, 0 );
235- PyObject* event_obj = PyTuple_GetItem (args, 1 );
236- TORCH_CHECK (THPUtils_checkLong (stream_obj), " stream must be an int" );
237- TORCH_CHECK (THPUtils_checkLong (event_obj), " event must be an int" );
238- orStream_t stream = reinterpret_cast <orStream_t>(THPUtils_unpackLong (stream_obj));
239- orEvent_t event = reinterpret_cast <orEvent_t>(THPUtils_unpackLong (event_obj));
240- orError_t err = orStreamWaitEvent (stream, event, 0 );
241- if (err != orSuccess) {
242- TORCH_CHECK (false , " Failed to wait for event" );
243- }
244- Py_RETURN_NONE;
245- END_HANDLE_TH_ERRORS
246- }
247-
248- // Event functions
249- PyObject* _eventCreate (PyObject* self, PyObject* noargs) {
250- HANDLE_TH_ERRORS
251- torch::utils::device_lazy_init (at::kPrivateUse1 );
252- orEvent_t event = nullptr ;
253- orError_t err = orEventCreate (&event);
254- if (err != orSuccess) {
255- TORCH_CHECK (false , " Failed to create event" );
256- }
257- return THPUtils_packInt64 (reinterpret_cast <int64_t >(event));
258- END_HANDLE_TH_ERRORS
259- }
260-
261- PyObject* _eventCreateWithFlags (PyObject* self, PyObject* arg) {
262- HANDLE_TH_ERRORS
263- TORCH_CHECK (THPUtils_checkLong (arg), " event_create_with_flags expects an int" );
264- unsigned int flags = static_cast <unsigned int >(THPUtils_unpackLong (arg));
265-
266- torch::utils::device_lazy_init (at::kPrivateUse1 );
267- orEvent_t event = nullptr ;
268- orError_t err = orEventCreateWithFlags (&event, flags);
269- if (err != orSuccess) {
270- TORCH_CHECK (false , " Failed to create event with flags" );
271- }
272- return THPUtils_packInt64 (reinterpret_cast <int64_t >(event));
273- END_HANDLE_TH_ERRORS
274- }
275-
276- PyObject* _eventDestroy (PyObject* self, PyObject* arg) {
277- HANDLE_TH_ERRORS
278- TORCH_CHECK (THPUtils_checkLong (arg), " event_destroy expects an int" );
279- orEvent_t event = reinterpret_cast <orEvent_t>(THPUtils_unpackLong (arg));
280- orError_t err = orEventDestroy (event);
281- if (err != orSuccess) {
282- TORCH_CHECK (false , " Failed to destroy event" );
283- }
284- Py_RETURN_NONE;
285- END_HANDLE_TH_ERRORS
286- }
287-
288- PyObject* _eventRecord (PyObject* self, PyObject* args) {
289- HANDLE_TH_ERRORS
290- TORCH_CHECK (PyTuple_Size (args) == 2 , " event_record expects 2 arguments" );
291- PyObject* event_obj = PyTuple_GetItem (args, 0 );
292- PyObject* stream_obj = PyTuple_GetItem (args, 1 );
293- TORCH_CHECK (THPUtils_checkLong (event_obj), " event must be an int" );
294- TORCH_CHECK (THPUtils_checkLong (stream_obj), " stream must be an int" );
295- orEvent_t event = reinterpret_cast <orEvent_t>(THPUtils_unpackLong (event_obj));
296- orStream_t stream = reinterpret_cast <orStream_t>(THPUtils_unpackLong (stream_obj));
297- orError_t err = orEventRecord (event, stream);
298- if (err != orSuccess) {
299- TORCH_CHECK (false , " Failed to record event" );
300- }
301- Py_RETURN_NONE;
302- END_HANDLE_TH_ERRORS
303- }
304-
305- PyObject* _eventSynchronize (PyObject* self, PyObject* arg) {
306- HANDLE_TH_ERRORS
307- TORCH_CHECK (THPUtils_checkLong (arg), " event_synchronize expects an int" );
308- orEvent_t event = reinterpret_cast <orEvent_t>(THPUtils_unpackLong (arg));
309-
310- orError_t err;
311- Py_BEGIN_ALLOW_THREADS
312- err = orEventSynchronize (event);
313- Py_END_ALLOW_THREADS
314-
315- if (err != orSuccess) {
316- TORCH_CHECK (false , " Failed to synchronize event" );
317- }
318- Py_RETURN_NONE;
319- END_HANDLE_TH_ERRORS
320- }
321-
322- PyObject* _eventQuery (PyObject* self, PyObject* arg) {
323- HANDLE_TH_ERRORS
324- TORCH_CHECK (THPUtils_checkLong (arg), " event_query expects an int" );
325- orEvent_t event = reinterpret_cast <orEvent_t>(THPUtils_unpackLong (arg));
326- orError_t err = orEventQuery (event);
327- if (err == orSuccess) {
328- Py_RETURN_TRUE;
329- } else {
330- Py_RETURN_FALSE;
331- }
332- END_HANDLE_TH_ERRORS
333- }
334-
335- PyObject* _eventElapsedTime (PyObject* self, PyObject* args) {
336- HANDLE_TH_ERRORS
337- TORCH_CHECK (PyTuple_Size (args) == 2 , " event_elapsed_time expects 2 arguments" );
338- PyObject* start_obj = PyTuple_GetItem (args, 0 );
339- PyObject* end_obj = PyTuple_GetItem (args, 1 );
340- TORCH_CHECK (THPUtils_checkLong (start_obj), " start event must be an int" );
341- TORCH_CHECK (THPUtils_checkLong (end_obj), " end event must be an int" );
342- orEvent_t start = reinterpret_cast <orEvent_t>(THPUtils_unpackLong (start_obj));
343- orEvent_t end = reinterpret_cast <orEvent_t>(THPUtils_unpackLong (end_obj));
344- float ms = 0 .0f ;
345- orError_t err = orEventElapsedTime (&ms, start, end);
346- if (err != orSuccess) {
347- TORCH_CHECK (false , " Failed to get elapsed time" );
348- }
349- return PyFloat_FromDouble (static_cast <double >(ms));
350- END_HANDLE_TH_ERRORS
351- }
352-
353188PyObject* _deviceSynchronize (PyObject* self, PyObject* noargs) {
354189 HANDLE_TH_ERRORS
355190 torch::utils::device_lazy_init (at::kPrivateUse1 );
@@ -421,20 +256,8 @@ static PyMethodDef methods[] = {
421256 {" get_amp_supported_dtype" , _getAmpSupportedDtype, METH_NOARGS, nullptr },
422257 // Stream functions
423258 {" _stream_create" , _streamCreate, METH_NOARGS, nullptr },
424- {" _stream_create_with_priority" , _streamCreateWithPriority, METH_VARARGS, nullptr },
425259 {" _stream_destroy" , _streamDestroy, METH_O, nullptr },
426- {" _stream_synchronize" , _streamSynchronize, METH_O, nullptr },
427- {" _stream_query" , _streamQuery, METH_O, nullptr },
428- {" _stream_get_priority" , _streamGetPriority, METH_O, nullptr },
429- {" _stream_wait_event" , _streamWaitEvent, METH_VARARGS, nullptr },
430- // Event functions
431- {" _event_create" , _eventCreate, METH_NOARGS, nullptr },
432- {" _event_create_with_flags" , _eventCreateWithFlags, METH_O, nullptr },
433- {" _event_destroy" , _eventDestroy, METH_O, nullptr },
434- {" _event_record" , _eventRecord, METH_VARARGS, nullptr },
435- {" _event_synchronize" , _eventSynchronize, METH_O, nullptr },
436- {" _event_query" , _eventQuery, METH_O, nullptr },
437- {" _event_elapsed_time" , _eventElapsedTime, METH_VARARGS, nullptr },
260+
438261 // Device functions
439262 {" _device_synchronize" , _deviceSynchronize, METH_NOARGS, nullptr },
440263 // Stream task functions
0 commit comments