Skip to content

Commit b71f0ec

Browse files
committed
Propagate CPU errors to events
1 parent 713f2b4 commit b71f0ec

17 files changed

Lines changed: 331 additions & 218 deletions

File tree

mlx/array.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,7 @@ class MLX_API array {
426426
}
427427

428428
void detach_event() const {
429+
array_desc_->event.check_error();
429430
array_desc_->event = Event{};
430431
}
431432

mlx/backend/common/load.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
#include <algorithm>
44
#include <utility>
55

6+
#include <fmt/format.h>
7+
68
#include "mlx/primitives.h"
79
#include "mlx/scheduler.h"
810

@@ -51,7 +53,14 @@ void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
5153
}
5254
};
5355
auto fut = io::thread_pool().enqueue(std::move(read_task)).share();
54-
scheduler::enqueue(stream(), [fut = std::move(fut)]() { fut.wait(); });
56+
auto s = stream();
57+
scheduler::enqueue(s, [s, fut = std::move(fut)]() {
58+
try {
59+
fut.get();
60+
} catch (const std::exception& error) {
61+
scheduler::set_error(s, fmt::format("[Load::eval_cpu] {}", error.what()));
62+
}
63+
});
5564
}
5665

5766
} // namespace mlx::core

mlx/backend/cuda/event.cu

Lines changed: 93 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,7 @@ void CudaEvent::init_pool() {
113113
cuda_event_pool();
114114
}
115115

116-
// Wraps CudaEvent with a few features:
117-
// 1. The class can be copied.
118-
// 2. Make wait/record work with CPU streams.
119-
// 3. Add checks for waiting on un-recorded event.
116+
// Wraps CudaEvent so it can be copied.
120117
class CopyableCudaEvent {
121118
public:
122119
explicit CopyableCudaEvent(Device& d)
@@ -126,32 +123,24 @@ class CopyableCudaEvent {
126123
cudaEventDisableTiming | cudaEventBlockingSync)) {}
127124

128125
void wait() {
126+
check_recorded();
129127
event_->wait();
130128
}
131129

132130
void wait(Stream s) {
133-
if (s.device == mlx::core::Device::cpu) {
134-
scheduler::enqueue(s, [*this]() mutable {
135-
check_recorded();
136-
event_->wait();
137-
});
138-
} else {
139-
check_recorded();
140-
auto& encoder = cu::get_command_encoder(s);
141-
encoder.commit();
142-
event_->wait(encoder.stream());
143-
}
131+
assert(s.device == mlx::core::Device::gpu);
132+
check_recorded();
133+
auto& encoder = cu::get_command_encoder(s);
134+
encoder.commit();
135+
event_->wait(encoder.stream());
144136
}
145137

146138
void record(Stream s) {
147-
if (s.device == mlx::core::Device::cpu) {
148-
throw std::runtime_error("CudaEvent can not wait on CPU stream.");
149-
} else {
150-
auto& encoder = cu::get_command_encoder(s);
151-
encoder.commit();
152-
event_->record(encoder.stream());
153-
recorded_ = true;
154-
}
139+
assert(s.device == mlx::core::Device::gpu);
140+
auto& encoder = cu::get_command_encoder(s);
141+
encoder.commit();
142+
event_->record(encoder.stream());
143+
recorded_ = true;
155144
}
156145

157146
bool is_signaled() const {
@@ -213,6 +202,11 @@ auto check_gpu_coherency() {
213202
return coherency;
214203
}
215204

205+
const CudaStream& signal_stream() {
206+
static CudaStream stream(device(0));
207+
return stream;
208+
}
209+
216210
AtomicEvent::AtomicEvent(Device& d) {
217211
void* buf;
218212
cudaError_t (*cuda_free)(void*);
@@ -264,14 +258,11 @@ void AtomicEvent::wait(cudaStream_t stream, uint32_t value) {
264258

265259
void AtomicEvent::wait(Stream s, uint32_t value) {
266260
nvtx3::scoped_range r("cu::AtomicEvent::wait(s)");
267-
if (s.device == mlx::core::Device::cpu) {
268-
scheduler::enqueue(s, [*this, value]() mutable { wait(value); });
269-
} else {
270-
auto& encoder = get_command_encoder(s);
271-
encoder.commit();
272-
wait(encoder.stream(), value);
273-
encoder.add_completed_handler([buf = buf_]() {});
274-
}
261+
assert(s.device == mlx::core::Device::gpu);
262+
auto& encoder = get_command_encoder(s);
263+
encoder.commit();
264+
wait(encoder.stream(), value);
265+
encoder.add_completed_handler([buf = buf_]() {});
275266
}
276267

277268
void AtomicEvent::signal(uint32_t value) {
@@ -289,17 +280,11 @@ void AtomicEvent::signal(cudaStream_t stream, uint32_t value) {
289280

290281
void AtomicEvent::signal(Stream s, uint32_t value) {
291282
nvtx3::scoped_range r("cu::AtomicEvent::signal(s)");
292-
if (s.device == mlx::core::Device::cpu) {
293-
// Signal through a GPU stream so the atomic is updated in GPU - updating
294-
// the atomic in CPU sometimes does not get GPU notified.
295-
scheduler::enqueue(
296-
s, [*this, value]() mutable { signal(signal_stream(), value); });
297-
} else {
298-
auto& encoder = get_command_encoder(s);
299-
encoder.commit();
300-
signal(encoder.stream(), value);
301-
encoder.add_completed_handler([buf = buf_]() {});
302-
}
283+
assert(s.device == mlx::core::Device::gpu);
284+
auto& encoder = get_command_encoder(s);
285+
encoder.commit();
286+
signal(encoder.stream(), value);
287+
encoder.add_completed_handler([buf = buf_]() {});
303288
}
304289

305290
bool AtomicEvent::is_signaled(uint32_t val) const {
@@ -319,9 +304,21 @@ uint32_t AtomicEvent::value() const {
319304
}
320305
}
321306

322-
const CudaStream& AtomicEvent::signal_stream() {
323-
static CudaStream stream(device(0));
324-
return stream;
307+
///////////////////////////////////////////////////////////////////////////////
308+
// EventImpl implementations
309+
///////////////////////////////////////////////////////////////////////////////
310+
311+
void EventImpl::ensure_created(Stream s, uint64_t signal_value) {
312+
if (is_created()) {
313+
return;
314+
}
315+
auto& d = cu::device(s.device);
316+
if (s.device == mlx::core::Device::cpu || signal_value > 1) {
317+
nvtx3::mark("Using slow AtomicEvent");
318+
atomic = std::make_unique<cu::AtomicEvent>(d);
319+
} else {
320+
cuda = std::make_unique<cu::CopyableCudaEvent>(d);
321+
}
325322
}
326323

327324
} // namespace cu
@@ -330,86 +327,85 @@ const CudaStream& AtomicEvent::signal_stream() {
330327
// Event implementations
331328
///////////////////////////////////////////////////////////////////////////////
332329

333-
namespace {
334-
335-
struct EventImpl {
336-
// CudaEvent is preferred when possible because it is fast, however we have
337-
// to fallback to AtomicEvent in following cases:
338-
// 1. the event is used to wait/signal a cpu stream;
339-
// 2. signal value other than 1 has been specified.
340-
std::unique_ptr<cu::CopyableCudaEvent> cuda;
341-
std::unique_ptr<cu::AtomicEvent> atomic;
342-
343-
bool is_created() const {
344-
return cuda || atomic;
345-
}
346-
347-
void ensure_created(Stream s, uint64_t signal_value) {
348-
if (is_created()) {
349-
return;
350-
}
351-
auto& d = cu::device(s.device);
352-
if (s.device == mlx::core::Device::cpu || signal_value > 1) {
353-
nvtx3::mark("Using slow AtomicEvent");
354-
atomic = std::make_unique<cu::AtomicEvent>(d);
355-
} else {
356-
cuda = std::make_unique<cu::CopyableCudaEvent>(d);
357-
}
358-
}
359-
};
360-
361-
} // namespace
362-
363330
Event::Event(Stream s) : stream_(s) {
364-
event_ = std::shared_ptr<void>(
365-
new EventImpl(), [](void* ptr) { delete static_cast<EventImpl*>(ptr); });
331+
event_ = std::make_shared<cu::EventImpl>();
366332
}
367333

368334
void Event::wait() {
369-
auto* event = static_cast<EventImpl*>(event_.get());
370-
assert(event->is_created());
371-
if (event->cuda) {
335+
check_error();
336+
auto& event = cast<cu::EventImpl>();
337+
assert(event.is_created());
338+
if (event.cuda) {
372339
assert(value() == 1);
373-
event->cuda->wait();
340+
event.cuda->wait();
374341
} else {
375-
event->atomic->wait(value());
342+
event.atomic->wait(value());
376343
}
377344
CHECK_CUDA_ERROR(cudaPeekAtLastError());
345+
check_error();
378346
}
379347

380348
void Event::wait(Stream s) {
381-
auto* event = static_cast<EventImpl*>(event_.get());
382-
assert(event->is_created());
383-
if (event->cuda) {
349+
auto& event = cast<cu::EventImpl>();
350+
assert(event.is_created());
351+
if (event.cuda) {
384352
assert(value() == 1);
385-
event->cuda->wait(s);
353+
if (s.device == mlx::core::Device::cpu) {
354+
scheduler::wait_event(s, *this, [value = value()](Event& self) {
355+
self.cast<cu::EventImpl>().cuda->wait();
356+
});
357+
} else {
358+
event.cuda->wait(s);
359+
}
386360
} else {
387-
event->atomic->wait(s, value());
361+
if (s.device == mlx::core::Device::cpu) {
362+
scheduler::wait_event(s, *this, [value = value()](Event& self) {
363+
self.cast<cu::EventImpl>().atomic->wait(value);
364+
});
365+
} else {
366+
event.atomic->wait(s, value());
367+
}
388368
}
389369
}
390370

391371
void Event::signal(Stream s) {
392-
auto* event = static_cast<EventImpl*>(event_.get());
393-
event->ensure_created(s, value());
394-
if (event->cuda) {
372+
auto& event = cast<cu::EventImpl>();
373+
event.ensure_created(s, value());
374+
if (event.cuda) {
395375
assert(value() == 1);
396-
event->cuda->record(s);
376+
if (s.device == mlx::core::Device::cpu) {
377+
throw std::runtime_error("CudaEvent can not wait on CPU stream.");
378+
} else {
379+
event.cuda->record(s);
380+
}
397381
} else {
398-
event->atomic->signal(s, value());
382+
if (s.device == mlx::core::Device::cpu) {
383+
// Signal through a GPU stream so the atomic is updated on the GPU - updating
384+
// the atomic from the CPU sometimes does not notify the GPU.
385+
scheduler::signal_event(s, *this, [value = value()](Event& self) {
386+
self.cast<cu::EventImpl>().atomic->signal(cu::signal_stream(), value);
387+
});
388+
} else {
389+
event.atomic->signal(s, value());
390+
}
399391
}
400392
}
401393

402394
bool Event::is_signaled() const {
403-
auto* event = static_cast<EventImpl*>(event_.get());
404-
if (!event->is_created()) {
395+
auto& event = cast<cu::EventImpl>();
396+
if (!event.is_created()) {
405397
return false;
406398
}
407-
if (event->cuda) {
399+
if (event.cuda) {
408400
assert(value() == 1);
409-
return event->cuda->is_signaled();
401+
return event.cuda->is_signaled();
410402
} else {
411-
return event->atomic->is_signaled(value());
403+
return event.atomic->is_signaled(value());
412404
}
413405
}
414406

407+
Event::Error& Event::error() {
408+
return cast<cu::EventImpl>().error;
409+
}
410+
415411
} // namespace mlx::core

mlx/backend/cuda/event.h

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
namespace mlx::core::cu {
1515

16+
class CopyableCudaEvent;
1617
class Device;
1718

1819
// RAII-managed move-only wrapper of cudaEvent_t.
@@ -66,8 +67,6 @@ class AtomicEvent {
6667
uint32_t value() const;
6768

6869
private:
69-
const CudaStream& signal_stream();
70-
7170
uint32_t* ptr() const {
7271
return static_cast<uint32_t*>(buf_.get());
7372
}
@@ -76,4 +75,21 @@ class AtomicEvent {
7675
std::shared_ptr<void> buf_;
7776
};
7877

78+
struct EventImpl {
79+
Event::Error error;
80+
81+
// CudaEvent is preferred when possible because it is fast, however we have
82+
// to fall back to AtomicEvent in the following cases:
83+
// 1. the event is used to wait/signal a cpu stream;
84+
// 2. signal value other than 1 has been specified.
85+
std::unique_ptr<cu::CopyableCudaEvent> cuda;
86+
std::unique_ptr<cu::AtomicEvent> atomic;
87+
88+
bool is_created() const {
89+
return cuda || atomic;
90+
}
91+
92+
void ensure_created(Stream s, uint64_t signal_value);
93+
};
94+
7995
} // namespace mlx::core::cu

mlx/backend/cuda/fence.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,23 @@ namespace mlx::core {
99

1010
struct FenceImpl {
1111
uint32_t count;
12-
cu::AtomicEvent event;
12+
Event event;
13+
14+
FenceImpl(uint32_t count, Stream s) : count(count), event(s) {}
1315
};
1416

1517
Fence::Fence(Stream s) {
16-
fence_ = std::shared_ptr<void>(
17-
new FenceImpl{0, cu::device(s.device)},
18-
[](void* ptr) { delete static_cast<FenceImpl*>(ptr); });
18+
fence_ = std::make_shared<FenceImpl>(0, s);
19+
// Ensure that we use AtomicEvent.
20+
cast<FenceImpl>().event.cast<cu::EventImpl>().ensure_created(s, 2);
1921
}
2022

2123
void Fence::wait(Stream s, const array&) {
22-
auto* fence = static_cast<FenceImpl*>(fence_.get());
23-
fence->event.wait(fence->count);
24+
cast<FenceImpl>().event.wait();
2425
}
2526

2627
void Fence::update(Stream s, const array& a, bool cross_device) {
27-
auto* fence = static_cast<FenceImpl*>(fence_.get());
28+
auto& f = cast<FenceImpl>();
2829
if (cross_device) {
2930
// Move to managed memory if there is a device switch
3031
auto& cbuf =
@@ -35,8 +36,9 @@ void Fence::update(Stream s, const array& a, bool cross_device) {
3536
cu::allocator().move_to_unified_memory(cbuf, encoder.stream());
3637
}
3738
}
38-
fence->count++;
39-
fence->event.signal(s, fence->count);
39+
f.count++;
40+
f.event.set_value(f.count);
41+
f.event.signal(s);
4042
}
4143

4244
} // namespace mlx::core

0 commit comments

Comments
 (0)