Skip to content

Commit 51627fb

Browse files
authored
Fix distill exit bug (#109)
1 parent 37d850e commit 51627fb

File tree

2 files changed

+112
-27
lines changed

2 files changed

+112
-27
lines changed

python/paddle_edl/distill/distill_reader.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def __init__(self, ins, predicts):
100100
self._service_name = None
101101

102102
# reader worker args
103+
self._reader_worker = None
103104
self._reader = None
104105
self._reader_type = None
105106
self._reader_out_queue = None
@@ -167,6 +168,7 @@ def _start_reader_worker(self):
167168
self._reader_cond, ))
168169
reader_worker.daemon = True
169170
reader_worker.start()
171+
self._reader_worker = reader_worker
170172
self._is_reader_start = True
171173
else:
172174
with self._reader_cond:
@@ -209,7 +211,8 @@ def _start_predict_worker_pool(self):
209211
self._require_num,
210212
self._predict_stop_events,
211213
self._get_servers,
212-
self._predict_manage_stop_event, ))
214+
self._predict_manage_stop_event,
215+
self._predict_cond, ))
213216
self._predict_manage_thread.daemon = True
214217
self._predict_manage_thread.start()
215218

@@ -388,3 +391,19 @@ def __call__(self):
388391
self._reader_type, self._predict_out_queue,
389392
self._fetch_stop_event, self._task_semaphore):
390393
yield data
394+
395+
def __del__(self):
396+
if not self._is_args_init:
397+
return
398+
399+
# stop reader worker
400+
with self._reader_cond:
401+
self._reader_stop_event.set()
402+
self._reader_cond.notify()
403+
404+
self._predict_manage_stop_event.set()
405+
406+
for i in range(20):
407+
if self._reader_worker.is_alive() or \
408+
self._predict_manage_thread.is_alive():
409+
time.sleep(1)

python/paddle_edl/distill/distill_worker.py

Lines changed: 92 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
import logging
1616
import numpy as np
1717
import os
18+
import signal
19+
import six
20+
import sys
1821
import time
1922

2023
from paddle_serving_client import Client
@@ -53,7 +56,7 @@ def __init__(self, server_id, server, stop_event_id, state=PENDING):
5356

5457
def predict_manage_worker(process, server_queue, server_result_queue,
5558
require_num, predict_stop_events, get_servers_fun,
56-
stop_event):
59+
stop_event, predict_cond):
5760
""" thread that manage predict worker """
5861
num_shutdown_process = [0]
5962

@@ -129,6 +132,34 @@ def shutdown_one_process():
129132
except queue.Empty:
130133
pass
131134

135+
def clean_queue(data_queue):
136+
while True:
137+
try:
138+
data_queue.get_nowait()
139+
except Exception:
140+
break
141+
142+
clean_queue(server_queue)
143+
clean_queue(server_result_queue)
144+
145+
with predict_cond:
146+
for predict_stop_event in predict_stop_events:
147+
predict_stop_event.set()
148+
predict_cond.notify_all()
149+
150+
for i in range(require_num):
151+
shutdown_one_process()
152+
clean_queue(server_result_queue)
153+
154+
for i in range(20):
155+
shutdown_process = 0
156+
for p in process:
157+
if not p.is_alive():
158+
shutdown_process += 1
159+
if shutdown_process == len(process):
160+
break
161+
time.sleep(1)
162+
132163

133164
class _PoisonPill:
134165
def __init__(self, feed_count, predict_count=0):
@@ -284,22 +315,38 @@ def __del__(self):
284315
def predict_worker(server_queue, server_result_queue, working_predict_count,
285316
in_queue, out_queue, feeds, fetchs, conf_file, stop_events,
286317
predict_lock, global_finished_task, predict_cond):
287-
while True:
288-
# get server
289-
server_item = server_queue.get()
290-
if server_item is None:
291-
server_queue.put(None) # poison_pill
292-
return
293-
294-
# predict
295-
success = predict_loop(server_item, working_predict_count, in_queue,
296-
out_queue, feeds, fetchs, conf_file,
297-
stop_events, predict_lock, global_finished_task,
298-
predict_cond)
318+
signal_exit = [False, ]
319+
320+
# Define signal handler function
321+
def predict_signal_handle(signum, frame):
322+
signal_exit[0] = True
323+
exit(0)
324+
325+
# register signal.SIGTERM's handler
326+
signal.signal(signal.SIGTERM, predict_signal_handle)
327+
328+
try:
329+
while True:
330+
# get server
331+
server_item = server_queue.get()
332+
if server_item is None:
333+
server_result_queue.put(None)
334+
return
299335

300-
server_item.state = ServerItem.FINISHED if success else ServerItem.ERROR
301-
server_result_queue.put(server_item)
302-
logger.info('Stopped server={}'.format(server_item.server))
336+
# predict
337+
success = predict_loop(server_item, working_predict_count,
338+
in_queue, out_queue, feeds, fetchs,
339+
conf_file, stop_events, predict_lock,
340+
global_finished_task, predict_cond)
341+
342+
server_item.state = ServerItem.FINISHED if success else ServerItem.ERROR
343+
server_result_queue.put(server_item)
344+
logger.info('Stopped server={}'.format(server_item.server))
345+
except Exception as e:
346+
if signal_exit[0] is True:
347+
pass
348+
else:
349+
six.reraise(*sys.exc_info())
303350

304351

305352
def predict_loop(server_item, working_predict_count, in_queue, out_queue,
@@ -365,7 +412,8 @@ def predict_loop(server_item, working_predict_count, in_queue, out_queue,
365412
out_queue.put(poison_pill) # poison consumer
366413
else:
367414
in_queue.put(poison_pill) # poison other predict worker
368-
415+
if stop_event.is_set():
416+
break
369417
# wait next reader iter or last failed predict job
370418
predict_cond.wait()
371419

@@ -435,22 +483,40 @@ def reader_worker(reader, reader_type, teacher_batch_size, out_queue,
435483
# consumer may recv out-of-order task(3, 1, 2) before task(0), consumer will store then,
436484
# when task(0) is completed and consumer recv it, it will release semaphore,
437485
# reader go on working.
486+
487+
signal_exit = [False, ]
488+
489+
def reader_signal_handle(signum, frame):
490+
signal_exit[0] = True
491+
exit(0)
492+
493+
# register signal.SIGTERM's handler
494+
signal.signal(signal.SIGTERM, reader_signal_handle)
495+
438496
read_func_map = {
439497
ReaderType.SAMPLE: read_sample,
440498
ReaderType.SAMPLE_LIST: read_sample_list,
441499
ReaderType.BATCH: read_batch
442500
}
443501
read_func = read_func_map[reader_type]
444502

445-
while not stop_event.is_set():
446-
task_size = read_func(reader, teacher_batch_size, out_queue,
447-
task_semaphore)
448-
449-
poison_pill = _PoisonPill(task_size)
450-
with reader_cond:
451-
out_queue.put(poison_pill)
452-
# wait next reader iter
453-
reader_cond.wait()
503+
try:
504+
while not stop_event.is_set():
505+
task_size = read_func(reader, teacher_batch_size, out_queue,
506+
task_semaphore)
507+
508+
poison_pill = _PoisonPill(task_size)
509+
with reader_cond:
510+
out_queue.put(poison_pill)
511+
if stop_event.is_set():
512+
break
513+
# wait next reader iter
514+
reader_cond.wait()
515+
except Exception as e:
516+
if signal_exit[0] is True:
517+
pass
518+
else:
519+
six.reraise(*sys.exc_info())
454520

455521

456522
def read_sample(reader, teacher_batch_size, out_queue, task_semaphore):

0 commit comments

Comments
 (0)