Skip to content

Commit 00c717a

Browse files
authored
Fix redis balance (#108)
1 parent 51627fb commit 00c717a

File tree

9 files changed

+204
-135
lines changed

9 files changed

+204
-135
lines changed

docker/Dockerfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,4 @@ RUN mkdir -p /tmp/protoc && cd /tmp/protoc && \
3737
RUN echo "go env -w GO111MODULE=on && go env -w GOPROXY=https://goproxy.io,direct" >> /root/.bashrc
3838
ENV GO111MODULE=on
3939
ENV GOPROXY=https://goproxy.io,direct
40+

docker/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,5 @@ flask==1.1.2
55
pathlib2==2.3.5
66
protobuf==3.8.0
77
kubernetes
8+
redis
9+
paddle-serving-client

python/paddle_edl/distill/distill_reader.py

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
# -*- coding: utf-8 -*-
22
import logging
33
import multiprocessing as mps
4-
import socket
5-
import time
6-
import threading
74
import os
8-
9-
from contextlib import closing
10-
from six.moves import queue
5+
import threading
6+
import time
117

128
from . import distill_worker
139

@@ -17,22 +13,6 @@
1713
logger = logging.getLogger(__name__)
1814

1915

20-
def is_server_alive(server):
21-
# FIXME. only for test, need find a better test method
22-
if distill_worker._NOP_PREDICT_TEST:
23-
return True
24-
alive = True
25-
ip, port = server.split(":")
26-
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
27-
try:
28-
s.settimeout(1.5)
29-
s.connect((ip, int(port)))
30-
s.shutdown(socket.SHUT_RDWR)
31-
except:
32-
alive = False
33-
return alive
34-
35-
3616
class ServiceDiscover(object):
3717
def get_servers(self):
3818
pass

python/paddle_edl/distill/redis/balance_server.py

Lines changed: 56 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,6 @@ def __init__(self, ip='127.0.0.1', port=9379, table=None):
208208
self._table = table
209209
self._handle_func = {
210210
'register': self._handle_register,
211-
'require_service': self._handle_require_service,
212211
'heartbeat': self._handle_heartbeat
213212
}
214213

@@ -233,26 +232,20 @@ def _handle_register(self, fd, msg):
233232
}
234233
self._enqueue_response(fd, msg)
235234

236-
def _handle_require_service(self, fd, msg):
237-
require_num = int(msg['num'])
238-
print('fd={}, require_num={}'.format(fd, require_num))
239-
240-
# for debug
241-
if self._table is None:
242-
teacher_list = ['127.0.0.1:0001', '127.0.0.1:0002']
243-
num = len(teacher_list)
244-
else:
245-
servers = self._table.get_servers(fd, require_num)
246-
teacher_list = servers
247-
num = len(teacher_list)
248-
249-
msg = {'type': 'response_service', 'servers': teacher_list, 'num': num}
250-
self._enqueue_response(fd, msg)
251-
252235
def _handle_heartbeat(self, fd, msg):
253-
is_update, servers = self._table.is_servers_update(fd)
254-
if is_update:
255-
msg = {'type': 'servers_change', 'servers': servers}
236+
version = 0
237+
try:
238+
version = int(msg['version'])
239+
except KeyError:
240+
# compatible old client
241+
pass
242+
new_version, servers = self._table.is_servers_update(fd, version)
243+
if new_version > version:
244+
msg = {
245+
'type': 'servers_change',
246+
'servers': servers,
247+
'version': new_version
248+
}
256249
else:
257250
msg = {'type': 'heartbeat'}
258251
self._enqueue_response(fd, msg)
@@ -274,7 +267,47 @@ def server_forever(self):
274267
if __name__ == '__main__':
275268
from service_table import ServiceTable
276269

277-
# service_name = 'TestService'
278-
table = ServiceTable('127.0.0.1', 6379) # connect redis ip:port
279-
balance_server = BalanceServer('0.0.0.0', 7001, table) # listen
270+
import argparse
271+
parser = argparse.ArgumentParser(
272+
description='Discovery server with balance')
273+
parser.add_argument(
274+
'--server',
275+
type=str,
276+
default='0.0.0.0:7001',
277+
help='endpoint of the server, e.g. 127.0.0.1:8888 [default: %(default)s]'
278+
)
279+
parser.add_argument(
280+
'--worker_num',
281+
type=int,
282+
default=1,
283+
help='worker num of server [default: %(default)s]')
284+
parser.add_argument(
285+
'--db_endpoints',
286+
type=str,
287+
default='127.0.0.1:6379',
288+
help='database endpoints, e.g. 127.0.0.1:2379,127.0.0.1:2380 [default: %(default)s]'
289+
)
290+
parser.add_argument(
291+
'--db_passwd',
292+
type=str,
293+
default=None,
294+
help='detabase password [default: %(default)s]')
295+
parser.add_argument(
296+
'--db_type',
297+
type=str,
298+
default='redis',
299+
help='database type, only support redis for now [default: %(default)s]')
300+
301+
args = parser.parse_args()
302+
server = args.server
303+
worker_num = args.worker_num
304+
db_endpoints = args.db_endpoints.split(',')
305+
306+
redis_ip_port = db_endpoints[0].split(':')
307+
server_ip_port = server.split(':')
308+
309+
table = ServiceTable(redis_ip_port[0],
310+
int(redis_ip_port[1])) # connect redis ip:port
311+
balance_server = BalanceServer(server_ip_port[0],
312+
int(server_ip_port[1]), table) # listen
280313
balance_server.server_forever()

python/paddle_edl/distill/redis/client.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def __init__(self, endpoints, service_name, require_num, token=None):
4646
self._balance_list = endpoints
4747
self.teacher_list = []
4848
self._is_update = False
49+
self._version = 0
4950

5051
self._thread = None
5152
self.client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -86,22 +87,10 @@ def _register(self, require_num=1):
8687
servers = msg['servers']
8788
return servers
8889

89-
def _require(self, require_num):
90-
# require service
91-
msg = {'type': 'require_service', 'num': require_num}
92-
self._send_msg(msg)
93-
94-
# get servers
95-
msg = self._recv_msg()
96-
assert msg['type'] == 'response_service'
97-
response_num = msg['num']
98-
servers = msg['servers']
99-
return servers
100-
10190
def _heartbeat(self):
10291
while self._need_stop is False:
10392
time.sleep(2)
104-
msg = {'type': 'heartbeat'}
93+
msg = {'type': 'heartbeat', 'version': self._version}
10594
self._send_msg(msg)
10695
msg = self._recv_msg()
10796

@@ -110,11 +99,15 @@ def _heartbeat(self):
11099
continue
111100
elif msg['type'] == 'servers_change':
112101
self._is_update = True
102+
try:
103+
self._version = int(msg['version'])
104+
except KeyError:
105+
# compatible with old balance server
106+
pass
113107
self.teacher_list = msg['servers']
114-
sys.stderr.write('servers_change: ' + str(msg['servers']) +
115-
'\n')
116-
# Todo
117-
pass
108+
sys.stderr.write(
109+
'[INFO] service change version={} teachers={}\n'.format(
110+
self._version, str(self.teacher_list)))
118111

119112
def get_teacher_list(self):
120113
'''

python/paddle_edl/distill/redis/server_register.py

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -88,22 +88,49 @@ def start(self, daemon=False):
8888
import sys
8989
from redis_store import RedisStore
9090

91-
ip = '127.0.0.1'
92-
port = 5458
93-
service_name = 'TestService'
94-
service_name = 'DistillService'
95-
96-
if len(sys.argv) == 2:
97-
port = int(sys.argv[1])
98-
elif len(sys.argv) >= 3:
99-
ip = sys.argv[1]
100-
port = int(sys.argv[2])
101-
if len(sys.argv) == 4:
102-
service_name = sys.argv[3]
103-
104-
print('register {}:{} service_name={}'.format(ip, port, service_name))
105-
106-
#store = RedisStore('127.0.0.1', 6379)
107-
store = RedisStore('10.255.100.13', 6379)
108-
register = ServerRegister(ip, port, service_name, store)
91+
import argparse
92+
parser = argparse.ArgumentParser(description='Server Register')
93+
parser.add_argument(
94+
'--db_endpoints',
95+
type=str,
96+
default='127.0.0.1:6379',
97+
help='database endpoints, e.g. 127.0.0.1:6379 [default: %(default)s]')
98+
parser.add_argument(
99+
'--db_passwd',
100+
type=str,
101+
default=None,
102+
help='detabase password [default: %(default)s]')
103+
parser.add_argument(
104+
'--db_type',
105+
type=str,
106+
default='redis',
107+
help='database type, only support redis for now [default: %(default)s]')
108+
parser.add_argument(
109+
'--service_name',
110+
type=str,
111+
help='service name where the server is located',
112+
required=True)
113+
parser.add_argument(
114+
'--server',
115+
type=str,
116+
help='endpoint of the server, e.g. 127.0.0.1:8888',
117+
required=True)
118+
# TODO. service_token
119+
parser.add_argument(
120+
'--service_token',
121+
type=str,
122+
default=None,
123+
help='service token, which the same can register [default: %(default)s]'
124+
)
125+
126+
args = parser.parse_args()
127+
server = args.server
128+
db_endpoints = args.db_endpoints.split(',')
129+
130+
redis_ip_port = db_endpoints[0].split(':')
131+
server_ip_port = server.split(':')
132+
133+
store = RedisStore(redis_ip_port[0], int(redis_ip_port[1]))
134+
register = ServerRegister(server_ip_port[0],
135+
int(server_ip_port[1]), args.service_name, store)
109136
register.start()

python/paddle_edl/distill/redis/service_table.py

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,16 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import logging
1516
import threading
1617
import time
1718
import sys
18-
import heapq
1919
from redis_store import RedisStore
2020

21+
logging.basicConfig(
22+
level=logging.DEBUG,
23+
format="[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s")
24+
2125

2226
class ServiceTable(object):
2327
def __init__(self, ip='127.0.0.1', port=6379, passwd=None,
@@ -43,15 +47,16 @@ def __init__(self, ip='127.0.0.1', port=6379, passwd=None,
4347
# {fd: set(servers), }
4448
self._fd_to_servers = {}
4549
self._server_to_fds = {}
46-
self._fd_to_update = {}
50+
self._fd_to_version = {}
4751
self._fd_to_max_num = {}
4852

49-
def is_servers_update(self, fd):
50-
ret = self._fd_to_update[fd]
51-
self._fd_to_update[fd] = False
52-
if ret is True:
53-
return ret, list(self._fd_to_servers[fd])
54-
return ret, []
53+
def is_servers_update(self, fd, version):
54+
new_version = self._fd_to_version[fd]
55+
56+
if new_version > version: # is update
57+
return new_version, list(self._fd_to_servers[fd])
58+
else:
59+
return new_version, None
5560

5661
def get_servers(self, fd, num):
5762
if fd not in self._fd_to_service_name:
@@ -62,13 +67,12 @@ def get_servers(self, fd, num):
6267
if service_name not in self._service_name_to_servers or \
6368
self._service_name_to_update[service_name] is True:
6469
self._refresh_service(service_name)
65-
self._fd_to_update[fd] = False
6670

6771
return list(self._fd_to_servers[fd])
6872

6973
def add_service_name(self, fd, service_name, num):
7074
self._fd_to_servers[fd] = set()
71-
self._fd_to_update[fd] = False
75+
self._fd_to_version[fd] = 0
7276
self._fd_to_max_num[fd] = num
7377
print('fd={}, service_name={}, max_num={}'.format(fd, service_name,
7478
num))
@@ -103,7 +107,7 @@ def rm_service_name(self, fd):
103107
del self._fd_to_service_name[fd]
104108

105109
del self._fd_to_max_num[fd]
106-
del self._fd_to_update[fd]
110+
del self._fd_to_version[fd]
107111

108112
for server in self._fd_to_servers[fd]:
109113
self._server_to_fds[server].remove(fd)
@@ -163,7 +167,7 @@ def _refresh_service(self, service_name):
163167
with self._mutex:
164168
if service_name not in self._service_name_to_fds:
165169
for fd in update_fd:
166-
self._fd_to_update[fd] = True
170+
self._fd_to_version[fd] += 1
167171
return
168172
fd_num = len(self._service_name_to_fds[service_name])
169173

@@ -172,7 +176,7 @@ def _refresh_service(self, service_name):
172176
if server_num == 0:
173177
print('service={} server_num=0'.format(service_name))
174178
for fd in update_fd:
175-
self._fd_to_update[fd] = True
179+
self._fd_to_version[fd] += 1
176180
return
177181
# assume: fd_num=3, server_num=97
178182
# assign: {fd0:32, fd1:32, fd2:32}
@@ -193,24 +197,23 @@ def _refresh_service(self, service_name):
193197
self._fd_to_servers[fd].remove(server)
194198
update_fd.add(fd)
195199
print('pop fd={} server={}'.format(fd, server))
196-
# add heap
197-
# heapq.heappush(server_conn, (len(self._server_to_fds[server]), server))
198-
199-
# Todo. use ReadWrite Lock
200200
try:
201-
# fd greed connect with server
202-
for fd in self._service_name_to_fds[service_name]:
201+
fds = self._service_name_to_fds[service_name]
202+
for fd in fds:
203203
max_connect = min(fd_max_connect, self._fd_to_max_num[fd])
204-
print('fd={} max_connect={}'.format(fd, max_connect))
204+
logging.info('fd={} max_connect={}'.format(fd, max_connect))
205205
if fd not in self._fd_to_servers:
206206
self._fd_to_servers[fd] = set()
207207
# limit connect of fd
208208
while len(self._fd_to_servers[fd]) > max_connect:
209209
server = self._fd_to_servers[fd].pop()
210210
self._server_to_fds[server].remove(fd)
211211
update_fd.add(fd)
212-
print('pop1 fd={} server={}'.format(fd, server))
212+
logging.info('pop1 fd={} server={}'.format(fd, server))
213213

214+
# fd greed connect with server
215+
for fd in fds:
216+
max_connect = min(fd_max_connect, self._fd_to_max_num[fd])
214217
for server in servers:
215218
if len(self._fd_to_servers[fd]) >= max_connect:
216219
break
@@ -222,12 +225,12 @@ def _refresh_service(self, service_name):
222225
self._fd_to_servers[fd].add(server)
223226
self._server_to_fds[server].add(fd)
224227
update_fd.add(fd)
225-
print('add fd={} server={}'.format(fd, server))
228+
logging.info('add fd={} server={}'.format(fd, server))
226229
except Exception, e:
227230
sys.stderr.write(str(e) + '\n')
228231

229232
for fd in update_fd:
230-
self._fd_to_update[fd] = True
233+
self._fd_to_version[fd] += 1
231234

232235
def _refresh(self):
233236
while True:

0 commit comments

Comments
 (0)