Skip to content

Commit fc16f72

Browse files
authored
Support gRPC communication with SMG (Shepherd Model Gateway) workers (#3946)
* Support grpc communication with smg workers Minor Change Support grpc communication with smg router * Resolve Review Comments --------- Co-authored-by: Bihan Rana
1 parent c5fa9f0 commit fc16f72

5 files changed

Lines changed: 627 additions & 40 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,8 @@ server = [
201201
"python-json-logger>=3.1.0",
202202
"prometheus-client",
203203
"grpcio>=1.50",
204+
"protobuf>=6.33.5",
205+
"smg-grpc-proto>=0.4.7",
204206
]
205207
aws = [
206208
"boto3>=1.38.13",

src/dstack/_internal/cli/commands/server.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ def _command(self, args: argparse.Namespace):
8080
os.environ["DSTACK_DO_NOT_UPDATE_DEFAULT_PROJECT"] = "1"
8181
if args.token:
8282
os.environ["DSTACK_SERVER_ADMIN_TOKEN"] = args.token
83+
# Hide noisy "Other threads are currently calling into gRPC, skipping fork() handlers"
84+
# messages in server logs. Users can still change this with GRPC_VERBOSITY.
85+
os.environ.setdefault("GRPC_VERBOSITY", "ERROR")
8386
uvicorn_log_level = os.getenv("DSTACK_SERVER_UVICORN_LOG_LEVEL", "ERROR").lower()
8487
reload_disabled = os.getenv("DSTACK_SERVER_RELOAD_DISABLED") is not None
8588

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
"""SSH-tunneled gRPC channel target to a job's service port (UDS)."""
2+
3+
from collections.abc import AsyncGenerator
4+
from contextlib import asynccontextmanager
5+
from datetime import timedelta
6+
from pathlib import Path
7+
from tempfile import TemporaryDirectory
8+
from typing import Any
9+
10+
import grpc
11+
12+
from dstack._internal.core.services.ssh.tunnel import (
13+
SSH_DEFAULT_OPTIONS,
14+
IPSocket,
15+
SocketPair,
16+
UnixSocket,
17+
)
18+
from dstack._internal.server.models import JobModel
19+
from dstack._internal.server.services.jobs import get_job_spec
20+
from dstack._internal.server.services.ssh import container_ssh_tunnel
21+
from dstack._internal.utils.common import get_or_error
22+
23+
SSH_CONNECT_TIMEOUT = timedelta(seconds=10)
24+
# Match router_worker_sync HTTP server_info cap (_MAX_SERVER_INFO_RESPONSE_BYTES).
25+
_MAX_GRPC_MESSAGE_BYTES = 256 * 1024
26+
_GRPC_CHANNEL_OPTIONS = (
27+
("grpc.max_receive_message_length", _MAX_GRPC_MESSAGE_BYTES),
28+
("grpc.max_send_message_length", _MAX_GRPC_MESSAGE_BYTES),
29+
)
30+
31+
32+
@asynccontextmanager
33+
async def get_service_replica_grpc_client(job: JobModel) -> AsyncGenerator[Any, None]:
34+
options = {
35+
**SSH_DEFAULT_OPTIONS,
36+
"ConnectTimeout": str(int(SSH_CONNECT_TIMEOUT.total_seconds())),
37+
}
38+
job_spec = get_job_spec(job)
39+
with TemporaryDirectory() as temp_dir:
40+
# Keep the same socket file name as the HTTP helper for consistency.
41+
app_socket_path = (Path(temp_dir) / "replica.sock").absolute()
42+
async with container_ssh_tunnel(
43+
job=job,
44+
forwarded_sockets=[
45+
SocketPair(
46+
remote=IPSocket("localhost", get_or_error(job_spec.service_port)),
47+
local=UnixSocket(app_socket_path),
48+
),
49+
],
50+
options=options,
51+
):
52+
target = f"unix://{app_socket_path}"
53+
channel = grpc.aio.insecure_channel(target, options=_GRPC_CHANNEL_OPTIONS)
54+
try:
55+
yield channel
56+
finally:
57+
await channel.close()

0 commit comments

Comments
 (0)