@@ -113,10 +113,7 @@ void CudaEvent::init_pool() {
113113 cuda_event_pool ();
114114}
115115
116- // Wraps CudaEvent with a few features:
117- // 1. The class can be copied.
118- // 2. Make wait/record work with CPU streams.
119- // 3. Add checks for waiting on un-recorded event.
116+ // Wraps CudaEvent so it can be copied.
120117class CopyableCudaEvent {
121118 public:
122119 explicit CopyableCudaEvent (Device& d)
@@ -126,32 +123,24 @@ class CopyableCudaEvent {
126123 cudaEventDisableTiming | cudaEventBlockingSync)) {}
127124
128125 void wait () {
126+ check_recorded ();
129127 event_->wait ();
130128 }
131129
132130 void wait (Stream s) {
133- if (s.device == mlx::core::Device::cpu) {
134- scheduler::enqueue (s, [*this ]() mutable {
135- check_recorded ();
136- event_->wait ();
137- });
138- } else {
139- check_recorded ();
140- auto & encoder = cu::get_command_encoder (s);
141- encoder.commit ();
142- event_->wait (encoder.stream ());
143- }
131+ assert (s.device == mlx::core::Device::gpu);
132+ check_recorded ();
133+ auto & encoder = cu::get_command_encoder (s);
134+ encoder.commit ();
135+ event_->wait (encoder.stream ());
144136 }
145137
146138 void record (Stream s) {
147- if (s.device == mlx::core::Device::cpu) {
148- throw std::runtime_error (" CudaEvent can not wait on CPU stream." );
149- } else {
150- auto & encoder = cu::get_command_encoder (s);
151- encoder.commit ();
152- event_->record (encoder.stream ());
153- recorded_ = true ;
154- }
139+ assert (s.device == mlx::core::Device::gpu);
140+ auto & encoder = cu::get_command_encoder (s);
141+ encoder.commit ();
142+ event_->record (encoder.stream ());
143+ recorded_ = true ;
155144 }
156145
157146 bool is_signaled () const {
@@ -213,6 +202,11 @@ auto check_gpu_coherency() {
213202 return coherency;
214203}
215204
205+ const CudaStream& signal_stream () {
206+ static CudaStream stream (device (0 ));
207+ return stream;
208+ }
209+
216210AtomicEvent::AtomicEvent (Device& d) {
217211 void * buf;
218212 cudaError_t (*cuda_free)(void *);
@@ -264,14 +258,11 @@ void AtomicEvent::wait(cudaStream_t stream, uint32_t value) {
264258
265259void AtomicEvent::wait (Stream s, uint32_t value) {
266260 nvtx3::scoped_range r (" cu::AtomicEvent::wait(s)" );
267- if (s.device == mlx::core::Device::cpu) {
268- scheduler::enqueue (s, [*this , value]() mutable { wait (value); });
269- } else {
270- auto & encoder = get_command_encoder (s);
271- encoder.commit ();
272- wait (encoder.stream (), value);
273- encoder.add_completed_handler ([buf = buf_]() {});
274- }
261+ assert (s.device == mlx::core::Device::gpu);
262+ auto & encoder = get_command_encoder (s);
263+ encoder.commit ();
264+ wait (encoder.stream (), value);
265+ encoder.add_completed_handler ([buf = buf_]() {});
275266}
276267
277268void AtomicEvent::signal (uint32_t value) {
@@ -289,17 +280,11 @@ void AtomicEvent::signal(cudaStream_t stream, uint32_t value) {
289280
290281void AtomicEvent::signal (Stream s, uint32_t value) {
291282 nvtx3::scoped_range r (" cu::AtomicEvent::signal(s)" );
292- if (s.device == mlx::core::Device::cpu) {
293- // Signal through a GPU stream so the atomic is updated in GPU - updating
294- // the atomic in CPU sometimes does not get GPU notified.
295- scheduler::enqueue (
296- s, [*this , value]() mutable { signal (signal_stream (), value); });
297- } else {
298- auto & encoder = get_command_encoder (s);
299- encoder.commit ();
300- signal (encoder.stream (), value);
301- encoder.add_completed_handler ([buf = buf_]() {});
302- }
283+ assert (s.device == mlx::core::Device::gpu);
284+ auto & encoder = get_command_encoder (s);
285+ encoder.commit ();
286+ signal (encoder.stream (), value);
287+ encoder.add_completed_handler ([buf = buf_]() {});
303288}
304289
305290bool AtomicEvent::is_signaled (uint32_t val) const {
@@ -319,9 +304,21 @@ uint32_t AtomicEvent::value() const {
319304 }
320305}
321306
322- const CudaStream& AtomicEvent::signal_stream () {
323- static CudaStream stream (device (0 ));
324- return stream;
307+ // /////////////////////////////////////////////////////////////////////////////
308+ // EventImpl implementations
309+ // /////////////////////////////////////////////////////////////////////////////
310+
311+ void EventImpl::ensure_created (Stream s, uint64_t signal_value) {
312+ if (is_created ()) {
313+ return ;
314+ }
315+ auto & d = cu::device (s.device );
316+ if (s.device == mlx::core::Device::cpu || signal_value > 1 ) {
317+ nvtx3::mark (" Using slow AtomicEvent" );
318+ atomic = std::make_unique<cu::AtomicEvent>(d);
319+ } else {
320+ cuda = std::make_unique<cu::CopyableCudaEvent>(d);
321+ }
325322}
326323
327324} // namespace cu
@@ -330,86 +327,85 @@ const CudaStream& AtomicEvent::signal_stream() {
330327// Event implementations
331328// /////////////////////////////////////////////////////////////////////////////
332329
333- namespace {
334-
335- struct EventImpl {
336- // CudaEvent is preferred when possible because it is fast, however we have
337- // to fallback to AtomicEvent in following cases:
338- // 1. the event is used to wait/signal a cpu stream;
339- // 2. signal value other than 1 has been specified.
340- std::unique_ptr<cu::CopyableCudaEvent> cuda;
341- std::unique_ptr<cu::AtomicEvent> atomic;
342-
343- bool is_created () const {
344- return cuda || atomic;
345- }
346-
347- void ensure_created (Stream s, uint64_t signal_value) {
348- if (is_created ()) {
349- return ;
350- }
351- auto & d = cu::device (s.device );
352- if (s.device == mlx::core::Device::cpu || signal_value > 1 ) {
353- nvtx3::mark (" Using slow AtomicEvent" );
354- atomic = std::make_unique<cu::AtomicEvent>(d);
355- } else {
356- cuda = std::make_unique<cu::CopyableCudaEvent>(d);
357- }
358- }
359- };
360-
361- } // namespace
362-
363330Event::Event (Stream s) : stream_(s) {
364- event_ = std::shared_ptr<void >(
365- new EventImpl (), [](void * ptr) { delete static_cast <EventImpl*>(ptr); });
331+ event_ = std::make_shared<cu::EventImpl>();
366332}
367333
368334void Event::wait () {
369- auto * event = static_cast <EventImpl*>(event_.get ());
370- assert (event->is_created ());
371- if (event->cuda ) {
335+ check_error ();
336+ auto & event = cast<cu::EventImpl>();
337+ assert (event.is_created ());
338+ if (event.cuda ) {
372339 assert (value () == 1 );
373- event-> cuda ->wait ();
340+ event. cuda ->wait ();
374341 } else {
375- event-> atomic ->wait (value ());
342+ event. atomic ->wait (value ());
376343 }
377344 CHECK_CUDA_ERROR (cudaPeekAtLastError ());
345+ check_error ();
378346}
379347
380348void Event::wait (Stream s) {
381- auto * event = static_cast < EventImpl*>(event_. get () );
382- assert (event-> is_created ());
383- if (event-> cuda ) {
349+ auto & event = cast<cu:: EventImpl>( );
350+ assert (event. is_created ());
351+ if (event. cuda ) {
384352 assert (value () == 1 );
385- event->cuda ->wait (s);
353+ if (s.device == mlx::core::Device::cpu) {
354+ scheduler::wait_event (s, *this , [value = value ()](Event& self) {
355+ self.cast <cu::EventImpl>().cuda ->wait ();
356+ });
357+ } else {
358+ event.cuda ->wait (s);
359+ }
386360 } else {
387- event->atomic ->wait (s, value ());
361+ if (s.device == mlx::core::Device::cpu) {
362+ scheduler::wait_event (s, *this , [value = value ()](Event& self) {
363+ self.cast <cu::EventImpl>().atomic ->wait (value);
364+ });
365+ } else {
366+ event.atomic ->wait (s, value ());
367+ }
388368 }
389369}
390370
391371void Event::signal (Stream s) {
392- auto * event = static_cast < EventImpl*>(event_. get () );
393- event-> ensure_created (s, value ());
394- if (event-> cuda ) {
372+ auto & event = cast<cu:: EventImpl>( );
373+ event. ensure_created (s, value ());
374+ if (event. cuda ) {
395375 assert (value () == 1 );
396- event->cuda ->record (s);
376+ if (s.device == mlx::core::Device::cpu) {
377+ throw std::runtime_error (" CudaEvent can not wait on CPU stream." );
378+ } else {
379+ event.cuda ->record (s);
380+ }
397381 } else {
398- event->atomic ->signal (s, value ());
382+ if (s.device == mlx::core::Device::cpu) {
383+ // Signal through a GPU stream so the atomic is updated in GPU - updating
384+ // the atomic in CPU sometimes does not get GPU notified.
385+ scheduler::signal_event (s, *this , [value = value ()](Event& self) {
386+ self.cast <cu::EventImpl>().atomic ->signal (cu::signal_stream (), value);
387+ });
388+ } else {
389+ event.atomic ->signal (s, value ());
390+ }
399391 }
400392}
401393
402394bool Event::is_signaled () const {
403- auto * event = static_cast < EventImpl*>(event_. get () );
404- if (!event-> is_created ()) {
395+ auto & event = cast<cu:: EventImpl>( );
396+ if (!event. is_created ()) {
405397 return false ;
406398 }
407- if (event-> cuda ) {
399+ if (event. cuda ) {
408400 assert (value () == 1 );
409- return event-> cuda ->is_signaled ();
401+ return event. cuda ->is_signaled ();
410402 } else {
411- return event-> atomic ->is_signaled (value ());
403+ return event. atomic ->is_signaled (value ());
412404 }
413405}
414406
407+ Event::Error& Event::error () {
408+ return cast<cu::EventImpl>().error ;
409+ }
410+
415411} // namespace mlx::core
0 commit comments