diff --git a/cvs/input/config_file/inference/vllm_distributed/mi300x_vllm-distributed_llama31-70b_fp8_config.json b/cvs/input/config_file/inference/vllm_distributed/mi300x_vllm-distributed_llama31-70b_fp8_config.json new file mode 100644 index 00000000..6ac3245f --- /dev/null +++ b/cvs/input/config_file/inference/vllm_distributed/mi300x_vllm-distributed_llama31-70b_fp8_config.json @@ -0,0 +1,88 @@ +{ + "schema_version": 1, + "framework": "vllm_distributed", + "gpu_arch": "mi300x", + "enforce_thresholds": false, + "threshold_json": "", + "paths": { + "shared_fs": "/mnt/dtni/{user-id}", + "models_dir": "{shared_fs}/models", + "log_dir": "{shared_fs}/LOGS", + "hf_token_file": "{shared_fs}/.cache/huggingface/token" + }, + "model": { + "id": "amd/Llama-3.1-70B-Instruct-FP8-KV", + "remote": 0 + }, + "container": { + "lifetime": "per_run", + "name": "w2_llama31_70b_fp8kv_dist_rocm", + "image": "", + "runtime": { + "name": "docker", + "args": { + "network": "host", + "ipc": "host", + "privileged": true, + "volumes": [ + "/home/{user-id}:/home/{user-id}", + "/mnt/dtni:/mnt/dtni", + "{paths.models_dir}:/models" + ] + } + } + }, + "roles": { + "server": { + "serve_args": { + "kv-cache-dtype": "fp8", + "enforce-eager": true + }, + "env": { + "HF_HUB_OFFLINE": "1", + "TRANSFORMERS_OFFLINE": "1", + "GLOO_SOCKET_IFNAME": "", + "TP_SOCKET_IFNAME": "", + "NCCL_SOCKET_IFNAME": "" + } + } + }, + "params": { + "backend": "vllm", + "base_url": "http://0.0.0.0", + "port_no": "8888", + "dataset_name": "random", + "burstiness": "1.0", + "seed": "0", + "request_rate": "inf", + "random_range_ratio": "0.8", + "random_prefix_len": "0", + "tensor_parallelism": "8", + "pipeline_parallel_size": "2", + "nnodes": "2", + "master_addr": "", + "master_port": "29501", + "tokenizer_mode": "auto", + "percentile_metrics": "ttft,tpot,itl,e2el", + "metric_percentiles": "50,90,95,99", + "num_prompts": "3200", + "client_poll_count": "90" + }, + "sweep": { + "sequence_combinations": [ + { + "name": 
"w2_isl=1000_osl=1000", + "isl": "1000", + "osl": "1000", + "goodput_slo": { + "ttft_ms": 1000000000.0, + "tpot_ms": 1000000000.0, + "e2el_ms": 1000000000.0 + } + } + ], + "runs": [ + { "combo": "w2_isl=1000_osl=1000", "concurrency": 16 } + ] + } +} diff --git a/cvs/input/config_file/inference/vllm_distributed/mi300x_vllm-distributed_llama31-70b_fp8_threshold.json b/cvs/input/config_file/inference/vllm_distributed/mi300x_vllm-distributed_llama31-70b_fp8_threshold.json new file mode 100644 index 00000000..fbe29f89 --- /dev/null +++ b/cvs/input/config_file/inference/vllm_distributed/mi300x_vllm-distributed_llama31-70b_fp8_threshold.json @@ -0,0 +1,28 @@ +{ + "_comment": "PLACEHOLDER thresholds for W2 (Llama 3.1 70B FP8-KV, 2-node, TP=8, PP=2, CONC=16) -- record-only run (enforce_thresholds=false). Replace values and flip enforce_thresholds=true once calibrated numbers are available.", + "ISL=1000,OSL=1000,TP=8,PP=2,CONC=16": { + "client.total_token_throughput": { "kind": "min_tok_s", "value": 0 }, + "client.output_throughput": { "kind": "min_tok_s", "value": 0 }, + "client.mean_ttft_ms": { "kind": "max_ms", "value": 1000000 }, + "client.median_ttft_ms": { "kind": "max_ms", "value": 1000000 }, + "client.p90_ttft_ms": { "kind": "max_ms", "value": 1000000 }, + "client.p95_ttft_ms": { "kind": "max_ms", "value": 1000000 }, + "client.p99_ttft_ms": { "kind": "max_ms", "value": 1000000 }, + "client.mean_tpot_ms": { "kind": "max_ms", "value": 1000000 }, + "client.median_tpot_ms": { "kind": "max_ms", "value": 1000000 }, + "client.p90_tpot_ms": { "kind": "max_ms", "value": 1000000 }, + "client.p95_tpot_ms": { "kind": "max_ms", "value": 1000000 }, + "client.p99_tpot_ms": { "kind": "max_ms", "value": 1000000 }, + "client.mean_itl_ms": { "kind": "max_ms", "value": 1000000 }, + "client.median_itl_ms": { "kind": "max_ms", "value": 1000000 }, + "client.p95_itl_ms": { "kind": "max_ms", "value": 1000000 }, + "client.p99_itl_ms": { "kind": "max_ms", "value": 1000000 }, + 
"client.mean_e2el_ms": { "kind": "max_ms", "value": 1000000 }, + "client.median_e2el_ms": { "kind": "max_ms", "value": 1000000 }, + "client.p90_e2el_ms": { "kind": "max_ms", "value": 1000000 }, + "client.p95_e2el_ms": { "kind": "max_ms", "value": 1000000 }, + "client.p99_e2el_ms": { "kind": "max_ms", "value": 1000000 }, + "client.success_rate": { "kind": "min", "value": 0 }, + "client.failed": { "kind": "max", "value": 1000000000 } + } +} diff --git a/cvs/lib/inference/unittests/test_vllm_distributed.py b/cvs/lib/inference/unittests/test_vllm_distributed.py new file mode 100644 index 00000000..31a47445 --- /dev/null +++ b/cvs/lib/inference/unittests/test_vllm_distributed.py @@ -0,0 +1,1183 @@ +''' +Copyright 2025 Advanced Micro Devices, Inc. +All rights reserved. + +Unit tests for VllmDistributedJob -- the multinode vLLM job driven by a +ContainerOrchestrator. No hardware: a FakeOrch records the commands the job +issues and returns canned per-host output. + +These tests are committed RED (greenfield): the implementation in +cvs/lib/inference/vllm_distributed.py is a placeholder and is filled in by a +separate agent who cannot edit this file. The contract under test comes from the +vllm_distributed spec, NOT from any implementation source. 
+ +Key distributed-vs-single contract (spec File 2): + - server env script + out-dir mkdir broadcast to ALL nodes via orch.exec() + (no hosts kwarg / hosts=None) + - server launched per-host via orch.exec(..., hosts=[host]) with per-rank + --node-rank, --master-addr, --pipeline-parallel-size, --distributed- + executor-backend mp + - client launched + polled + results-fetched via orch.exec_on_head() + (single-entry {head_host: text} dict), never the broadcast orch.exec() +''' + +import json +import re +import unittest +import unittest.mock +from pathlib import Path +from types import SimpleNamespace + +from cvs.lib.inference.vllm_distributed import VllmDistributedJob + +_HERE = Path(__file__).parent +_FIXTURES = _HERE / "fixtures" + +# isl/tp must match the artifact fixture for the derived-math assertions to be +# meaningful (real artifact: isl=128, tp=8). +_ISL = 128 +_TP = 8 +_PP = 2 +_NNODES = 2 +_MASTER_ADDR = "node-head" # symbolic, not real infra (anti-pattern: hardcoded IPs) +_MASTER_PORT = "29501" + +# Symbolic host names -- NOT real infra IPs (Hardcoded-Infrastructure-IDs anti-pattern). +_HOSTS = ["node-a", "node-b"] +_HEAD = _HOSTS[0] + + +class FakeOrch: + """Stand-in for ContainerOrchestrator. + + Records every exec / exec_on_head call (command + kwargs) so tests can assert + on the broadcast-vs-head routing and per-host rank dispatch. `hosts` mirrors a + real multinode orchestrator so enumerate(orch.hosts) yields (rank, host). 
class FakeOrch:
    """Stand-in for ContainerOrchestrator.

    Records every exec / exec_on_head call as a ``(cmd, kwargs)`` tuple so tests
    can assert on the broadcast-vs-head routing and per-host rank dispatch.
    ``hosts`` is an ordered list so ``enumerate(orch.hosts)`` yields
    ``(rank, host)``.

    Args:
        hosts: iterable of host names; defaults to the module-level ``_HOSTS``.
        exec_return: canned result for ``exec``. May be a value, or a callable
            invoked as ``exec_return(cmd, kwargs)``. ``None`` (or a callable
            returning ``None``) selects the default result.
        head_return: same convention as ``exec_return``, for ``exec_on_head``.
    """

    def __init__(self, hosts=None, exec_return=None, head_return=None):
        self.hosts = list(hosts) if hosts is not None else list(_HOSTS)
        self._exec_return = exec_return
        self._head_return = head_return
        self.exec_calls = []  # list of (cmd, kwargs)
        self.head_calls = []  # list of (cmd, kwargs)

    @staticmethod
    def _canned(spec, cmd, kwargs):
        """Resolve a canned-return spec: call it if callable, else use it as-is."""
        return spec(cmd, kwargs) if callable(spec) else spec

    def exec(self, cmd, **kwargs):
        """Record a broadcast/targeted exec and return ``{host: output}``."""
        self.exec_calls.append((cmd, kwargs))
        ret = self._canned(self._exec_return, cmd, kwargs)
        if ret is not None:
            return ret
        # Default: broadcast result keyed by the targeted hosts (or all hosts).
        targets = kwargs.get("hosts") or self.hosts
        return {h: "" for h in targets}

    def exec_on_head(self, cmd, **kwargs):
        """Record a head-only exec and return a single-entry ``{head: output}``."""
        self.head_calls.append((cmd, kwargs))
        ret = self._canned(self._head_return, cmd, kwargs)
        if ret is not None:
            return ret
        # The head is the first configured host, so a fake built with custom
        # hosts answers for ITS head. Fall back to the module default only when
        # no hosts are configured.
        head = self.hosts[0] if self.hosts else _HEAD
        return {head: ""}
def _fake_variant():
    """SimpleNamespace tree carrying exactly the attrs VllmDistributedJob reads.

    Mirrors vllm_single's _fake_variant plus the distributed params
    (pipeline_parallel_size, master_addr, master_port, nnodes).
    """
    params = SimpleNamespace(
        tensor_parallelism=str(_TP),
        pipeline_parallel_size=str(_PP),
        master_addr=_MASTER_ADDR,
        master_port=_MASTER_PORT,
        nnodes=str(_NNODES),
        port_no="8888",
        random_range_ratio="0.8",
        random_prefix_len="0",
        burstiness="1.0",
        seed="0",
        request_rate="inf",
        tokenizer_mode="auto",
        percentile_metrics="ttft,tpot,itl,e2el",
        metric_percentiles="50,90,95,99",
        base_url="http://0.0.0.0",
        dataset_name="random",
        backend="vllm",
    )
    return SimpleNamespace(
        params=params,
        model=SimpleNamespace(id="amd/Llama-3.1-70B-Instruct-FP8-KV"),
        roles=SimpleNamespace(server=SimpleNamespace(serve_args={}, env={})),
        paths=SimpleNamespace(log_dir="/tmp/logs", models_dir="/tmp/models"),
    )


def _make_job(orch, goodput_slo=None, **kwargs):
    """Build a VllmDistributedJob on `orch` for the default cell.

    The cell is isl=_ISL, osl=2048, concurrency=256, num_prompts=12800 on a fresh
    `_fake_variant()`. Extra keyword arguments are forwarded to the constructor
    unchanged.
    """
    return VllmDistributedJob(
        orch=orch,
        variant=_fake_variant(),
        hf_token="tok",
        isl=_ISL,
        osl=2048,
        concurrency=256,
        num_prompts=12800,
        goodput_slo=goodput_slo,
        **kwargs,
    )


def _load_fixture(name):
    """Return the text of fixture file `name` from the `_FIXTURES` directory.

    Decoding is pinned to UTF-8 so the result does not depend on the locale of
    the machine running the tests.
    """
    return (_FIXTURES / name).read_text(encoding="utf-8")


def _argv_after(argv, flag):
    """Return the token following `flag` in argv, or None if flag absent/last."""
    if flag not in argv:
        return None
    i = argv.index(flag)
    return argv[i + 1] if i + 1 < len(argv) else None
# ---------------------------------------------------------------------------
# Pure / derived builders: _server_argv, _derive_max_model_len
# Classification: pure (output depends only on self fields set at construction).
# ---------------------------------------------------------------------------


class TestServerArgv(unittest.TestCase):
    """Contract for `_server_argv(rank)`, the per-rank `vllm serve` arg list
    (spec File 2 `_server_argv`). Checks run against the argv token list, not a
    rendered command string."""

    @staticmethod
    def _job_for(variant):
        """Construct a job around `variant` using this class's standard cell."""
        return VllmDistributedJob(
            orch=FakeOrch(),
            variant=variant,
            hf_token="tok",
            isl=_ISL,
            osl=2048,
            concurrency=256,
            num_prompts=12800,
        )

    @classmethod
    def _rank0_argv_with_serve_args(cls, serve_args):
        """Return the rank-0 argv of a job whose roles.server.serve_args is `serve_args`."""
        variant = _fake_variant()
        variant.roles.server.serve_args = serve_args
        return cls._job_for(variant)._server_argv(0)

    def setUp(self):
        self.job = _make_job(FakeOrch())

    def test_server_argv_rank0(self):
        head_argv = self.job._server_argv(0)
        # Distributed flags and the value each must carry on the head rank.
        distributed_flags = (
            ("--node-rank", "0"),
            ("--master-addr", _MASTER_ADDR),
            ("--master-port", _MASTER_PORT),
            ("--pipeline-parallel-size", str(_PP)),
            ("--tensor-parallel-size", str(_TP)),
            ("--nnodes", str(_NNODES)),
            ("--distributed-executor-backend", "mp"),
        )
        for flag, value in distributed_flags:
            self.assertEqual(_argv_after(head_argv, flag), value)
        # The model is still served, with the per-cell derived max-model-len.
        self.assertIn("serve", head_argv)
        self.assertIn(self.job.model_id, head_argv)
        self.assertEqual(_argv_after(head_argv, "--max-model-len"), self.job._derive_max_model_len())
        self.assertEqual(_argv_after(head_argv, "--port"), "8888")

    def test_server_argv_rank1(self):
        worker_argv = self.job._server_argv(1)
        # The rank changes per host; the rendezvous coordinates do not.
        self.assertEqual(_argv_after(worker_argv, "--node-rank"), "1")
        self.assertEqual(_argv_after(worker_argv, "--master-addr"), _MASTER_ADDR)
        self.assertEqual(_argv_after(worker_argv, "--master-port"), _MASTER_PORT)

    def test_node_rank_is_the_only_per_rank_difference(self):
        """Invariant: argv for rank 0 and argv for any other rank differ in one
        position only, the --node-rank value. Checked for ranks 1, 2 and 3."""
        baseline = self.job._server_argv(0)
        for rank in (1, 2, 3):
            with self.subTest(rank=rank):
                candidate = self.job._server_argv(rank)
                self.assertEqual(len(baseline), len(candidate), "argv length must be identical for all ranks")
                diffs = []
                for base_tok, cand_tok in zip(baseline, candidate):
                    if base_tok != cand_tok:
                        diffs.append((base_tok, cand_tok))
                self.assertEqual(len(diffs), 1, f"only --node-rank value should differ for rank {rank}")
                self.assertEqual(diffs[0], ("0", str(rank)))

    def test_serve_args_appended(self):
        # A key/value entry in roles.server.serve_args becomes "--key value".
        argv = self._rank0_argv_with_serve_args({"kv-cache-dtype": "fp8"})
        self.assertEqual(_argv_after(argv, "--kv-cache-dtype"), "fp8")

    def test_serve_args_bare_flag(self):
        argv = self._rank0_argv_with_serve_args({"trust-remote-code": True})
        self.assertIn("--trust-remote-code", argv)
        idx = argv.index("--trust-remote-code")
        has_follower = idx + 1 < len(argv)
        follower_text = argv[idx + 1] if has_follower else ''
        next_is_flag_or_end = (not has_follower) or argv[idx + 1].startswith("--")
        self.assertTrue(
            next_is_flag_or_end,
            f"bare flag --trust-remote-code must not be followed by a value token, " f"got: {follower_text}",
        )

    def test_serve_args_list_value(self):
        argv = self._rank0_argv_with_serve_args({"lora-modules": ["mod-a", "mod-b"]})
        # A list value renders as "--flag v1 --flag v2".
        positions = [pos for pos, tok in enumerate(argv) if tok == "--lora-modules"]
        self.assertEqual(len(positions), 2)
        self.assertEqual(argv[positions[0] + 1], "mod-a")
        self.assertEqual(argv[positions[1] + 1], "mod-b")

    def test_server_argv_tracks_pp_and_nnodes(self):
        """F2: _server_argv must reflect non-default pipeline_parallel_size and
        nnodes values. Two points are checked in order: (pp=4, nnodes=4), then
        (pp=1, nnodes=1)."""
        for size in ("4", "1"):
            variant = _fake_variant()
            variant.params.pipeline_parallel_size = size
            variant.params.nnodes = size
            argv = self._job_for(variant)._server_argv(0)
            self.assertEqual(_argv_after(argv, "--pipeline-parallel-size"), size)
            self.assertEqual(_argv_after(argv, "--nnodes"), size)

        # `argv` now holds the pp=1 result; it must not carry the fixture value "2".
        self.assertNotEqual(
            _argv_after(argv, "--pipeline-parallel-size"), "2", "pp=1 variant must not emit the default '2'"
        )

    def test_serve_args_tuple_value(self):
        """F3: a tuple value in serve_args must render like a list value:
        --flag v1 --flag v2."""
        argv = self._rank0_argv_with_serve_args({"lora-modules": ("mod-a", "mod-b")})
        positions = [pos for pos, tok in enumerate(argv) if tok == "--lora-modules"]
        self.assertEqual(len(positions), 2, "--lora-modules must appear twice for a tuple with two values")
        self.assertEqual(argv[positions[0] + 1], "mod-a")
        self.assertEqual(argv[positions[1] + 1], "mod-b")
class TestDerivedMaxModelLen(unittest.TestCase):
    """Per-cell MAX_MODEL_LEN: ceil((isl+osl)*(1+ratio)) + prefix + pad, where
    the expected values below use a pad of 8. Covered by a range table (one
    subTest per row), a monotonicity check and a second (isl, osl) cell."""

    def test_derive_max_model_len_ranges(self):
        # (random_range_ratio, random_prefix_len) -> expected string.
        # Base cell: isl=128, osl=2048, so isl+osl = 2176.
        expected_by_inputs = {
            ("0.8", "0"): "3925",  # ceil(2176*1.8)=3917 +0 +8
            ("0.1", "0"): "2402",  # ceil(2176*1.1)=2394 +0 +8
            ("0.0", "0"): "2184",  # 2176 +0 +8 (boundary: no jitter)
            ("0.0", "64"): "2248",  # 2176 +64 +8 (prefix only)
            ("0.8", "64"): "3989",  # ceil(2176*1.8)=3917 +64 +8 (ratio AND prefix combined)
            ("1.0", "0"): "4360",  # ceil(2176*2.0)=4352 +0 +8 (boundary: full-width doubling)
        }
        for (ratio, prefix), expected in expected_by_inputs.items():
            with self.subTest(ratio=ratio, prefix=prefix):
                job = _make_job(FakeOrch())
                job.random_range_ratio = ratio
                job.random_prefix_len = prefix
                self.assertEqual(job._derive_max_model_len(), expected)

    def test_monotonic_in_ratio(self):
        """Invariant: the derived window never shrinks as the jitter ratio grows."""
        job = _make_job(FakeOrch())
        windows = []
        for ratio in ("0.0", "0.1", "0.4", "0.8", "1.0"):
            job.random_range_ratio = ratio
            window = int(job._derive_max_model_len())
            if windows:
                self.assertGreaterEqual(window, windows[-1])
            windows.append(window)

    def test_derive_max_model_len_varies_isl(self):
        # isl=512, osl=1024, ratio=0.8: ceil((512+1024)*1.8) + 0 + 8 = ceil(2764.8)+8 = 2773
        job = VllmDistributedJob(
            orch=FakeOrch(),
            variant=_fake_variant(),
            hf_token="tok",
            isl=512,
            osl=1024,
            concurrency=16,
            num_prompts=800,
        )
        job.random_range_ratio = "0.8"
        job.random_prefix_len = "0"
        import math

        worst_case = math.ceil((512 + 1024) * 1.8)
        expected = str(worst_case + 8)
        self.assertEqual(job._derive_max_model_len(), expected)
# ---------------------------------------------------------------------------
# Broadcast vs head routing: build_server_cmd, start_server, run_client,
# wait_client_complete, parse_results.
# Classification: subsystem -- mocked at the orch seam. The load-bearing
# contract is WHICH orch entrypoint (broadcast exec vs exec_on_head) each step
# uses and, for start_server, the per-host rank dispatch.
# ---------------------------------------------------------------------------


class TestBuildServerCmdBroadcast(unittest.TestCase):
    """build_server_cmd: env-script write and out-dir mkdir, issued as
    broadcast orch.exec calls."""

    @staticmethod
    def _first_exec_containing(orch, needle):
        """Return the first recorded exec command containing `needle`, else None."""
        for issued, _ in orch.exec_calls:
            if needle in issued:
                return issued
        return None

    def test_build_server_cmd_broadcasts_to_all_nodes(self):
        """The env-script write and the out-dir mkdir must reach every node:
        each orch.exec call carries no `hosts` (absent or None), and the
        head-only channel stays unused."""
        orch = FakeOrch()
        job = _make_job(orch)
        job.build_server_cmd()

        self.assertGreaterEqual(len(orch.exec_calls), 2, "expected at least env-script write + mkdir")
        # No exec issued here may target specific hosts.
        for cmd, kwargs in orch.exec_calls:
            targeted = kwargs.get("hosts")
            self.assertIsNone(
                targeted, f"build_server_cmd must broadcast, got hosts={kwargs.get('hosts')!r} for {cmd!r}"
            )
        # Nothing goes through exec_on_head.
        self.assertEqual(orch.head_calls, [])
        all_cmds = " ".join([issued for issued, _ in orch.exec_calls])
        self.assertIn("/tmp/server_env_script.sh", all_cmds)
        self.assertIn("mkdir -p", all_cmds)
        mkdir_cmd = self._first_exec_containing(orch, "mkdir")
        self.assertIsNotNone(mkdir_cmd, "no mkdir command found in exec_calls")
        self.assertIn(job.out_dir, mkdir_cmd, "mkdir must target job.out_dir")

    def test_env_script_carries_required_exports(self):
        orch = FakeOrch()
        job = _make_job(orch)
        job.build_server_cmd()
        env_cmd = self._first_exec_containing(orch, "/tmp/server_env_script.sh")
        self.assertIsNotNone(env_cmd, "no env-script write command found")
        required_exports = (
            "HF_TOKEN=tok",
            f"HF_HUB_CACHE={job.models_dir}",
            "VLLM_USE_AITER_UNIFIED_ATTENTION=1",
            "VLLM_ROCM_USE_AITER_MHA=0",
            "VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4=1",
        )
        for token in required_exports:
            pattern = r'\bexport\s+' + re.escape(token)
            self.assertRegex(env_cmd, pattern, f"env script must export '{token}'")

    def test_out_dir_encodes_cell_parameters(self):
        """F8: out_dir must encode (isl, osl, concurrency) so distinct cells get
        distinct directories. Checks inequality and substring presence."""
        job_a = VllmDistributedJob(
            orch=FakeOrch(),
            variant=_fake_variant(),
            hf_token="tok",
            isl=128,
            osl=2048,
            concurrency=16,
            num_prompts=800,
        )
        job_b = VllmDistributedJob(
            orch=FakeOrch(),
            variant=_fake_variant(),
            hf_token="tok",
            isl=512,
            osl=1024,
            concurrency=32,
            num_prompts=800,
        )
        self.assertNotEqual(
            job_a.out_dir, job_b.out_dir, "distinct (isl, osl, concurrency) must produce distinct out_dir values"
        )
        # Each out_dir must contain its own cell parameters as substrings.
        for val in ("128", "2048", "16"):
            self.assertIn(val, job_a.out_dir, f"job_a.out_dir must contain '{val}' (isl/osl/concurrency)")
        for val in ("512", "1024", "32"):
            self.assertIn(val, job_b.out_dir, f"job_b.out_dir must contain '{val}' (isl/osl/concurrency)")

    def test_env_script_carries_server_env_overrides(self):
        variant = _fake_variant()
        variant.roles.server.env = {"CUSTOM_VAR": "custom_val"}
        orch = FakeOrch()
        job = VllmDistributedJob(
            orch=orch, variant=variant, hf_token="tok", isl=_ISL, osl=2048, concurrency=256, num_prompts=12800
        )
        job.build_server_cmd()
        # Locate the env-script write wherever it sits among the broadcast calls.
        env_cmd = self._first_exec_containing(orch, "/tmp/server_env_script.sh")
        self.assertIsNotNone(env_cmd, "no env-script write found in exec_calls")
        self.assertRegex(env_cmd, r'\bexport\s+CUSTOM_VAR=custom_val', "env script must export CUSTOM_VAR=custom_val")
class TestStartServerPerHost(unittest.TestCase):
    """start_server: one targeted orch.exec per host, rank taken from the
    host's index, RuntimeError when a launch output signals an early failure."""

    @staticmethod
    def _output_only_for(bad_host, text):
        """Return an exec_return callable: `text` for `bad_host`, "" for any other host."""

        def respond(cmd, kwargs):
            host = kwargs.get("hosts", [None])[0]
            return {host: (text if host == bad_host else "")}

        return respond

    def test_start_server_calls_per_host_with_correct_rank(self):
        """For every host in orch.hosts there is exactly one
        orch.exec(..., hosts=[host]), and the host at index i carries
        --node-rank i. Run with 2 and 4 hosts."""
        for n_hosts in (2, 4):
            hosts = [f"node-{i}" for i in range(n_hosts)]
            with self.subTest(n_hosts=n_hosts):
                orch = FakeOrch(hosts=hosts)
                job = _make_job(orch)
                job.start_server()

                # Exactly one exec per host, and nothing on the head-only channel.
                self.assertEqual(len(orch.exec_calls), n_hosts)
                self.assertEqual(orch.head_calls, [])

                rank_by_host = {}  # host -> rank parsed from the launched command
                for cmd, kwargs in orch.exec_calls:
                    targets = kwargs.get("hosts")
                    self.assertIsInstance(targets, list)
                    self.assertEqual(len(targets), 1, "each launch targets exactly one host")
                    host = targets[0]
                    rank_match = re.search(r"--node-rank\s+(\d+)", cmd)
                    self.assertIsNotNone(rank_match, f"launch for {host} has no --node-rank: {cmd!r}")
                    rank_by_host[host] = int(rank_match.group(1))
                    # Each launch sources the env script and carries the distributed flags.
                    self.assertIn("source /tmp/server_env_script.sh", cmd, "per-host launch must source the env script")
                    for flag in (
                        "--distributed-executor-backend",
                        "--master-addr",
                        "--pipeline-parallel-size",
                        "--nnodes",
                    ):
                        self.assertIn(flag, cmd)
                    self.assertIn("nohup", cmd, "per-host launch must use nohup to background")
                    self.assertRegex(cmd, r'2>&1\s*&', "per-host launch must be backgrounded with 2>&1 &")
                    self.assertIn("vllm serve", cmd, "per-host launch must use 'vllm serve' subcommand")

                self.assertEqual(set(rank_by_host), set(hosts), "every host got launched exactly once")
                # The host at orch.hosts[i] must have been given rank i.
                for i, host in enumerate(hosts):
                    self.assertEqual(rank_by_host[host], i, f"{host} (index {i}) got rank {rank_by_host[host]}")

    def test_start_server_raises_on_early_failure(self):
        """A failure string in the second host's launch output must make
        start_server raise RuntimeError. One subTest per failure string."""
        early_failure_strings = [
            "command not found: vllm",
            "no such file or directory: vllm",
            "cannot access /usr/bin/vllm",
            "failed to start vllm server",
        ]
        for failure_text in early_failure_strings:
            with self.subTest(failure=failure_text):
                orch = FakeOrch(hosts=_HOSTS, exec_return=self._output_only_for(_HOSTS[1], failure_text))
                job = _make_job(orch)
                with self.assertRaises(RuntimeError):
                    job.start_server()

    def test_start_server_raises_on_head_failure(self):
        """A failure reported by the first host (rank 0) must raise as well."""
        orch = FakeOrch(hosts=_HOSTS, exec_return=self._output_only_for(_HOSTS[0], "command not found: vllm"))
        job = _make_job(orch)
        with self.assertRaises(RuntimeError):
            job.start_server()

    def test_start_server_clean_launch_does_not_raise(self):
        def empty_output(cmd, kwargs):
            return {kwargs.get("hosts", [None])[0]: ""}

        orch = FakeOrch(hosts=_HOSTS, exec_return=empty_output)
        job = _make_job(orch)
        job.start_server()  # no raise on empty/benign output

    def test_start_server_clean_launch_with_benign_output(self):
        """Non-empty launch output that is not a failure string must not raise."""
        benign = "nohup: ignoring input\nINFO: starting vllm..."

        def benign_return(cmd, kwargs):
            return {kwargs.get("hosts", [None])[0]: benign}

        orch = FakeOrch(hosts=_HOSTS, exec_return=benign_return)
        job = _make_job(orch)
        job.start_server()  # must NOT raise despite non-empty output
class TestRunClientHeadOnly(unittest.TestCase):
    """run_client: the bench client is launched through exec_on_head only."""

    @staticmethod
    def _value_after(cmd, flag):
        """Return the whitespace-delimited token following `flag` in `cmd`, or None."""
        found = re.search(re.escape(flag) + r'\s+(\S+)', cmd)
        if found is None:
            return None
        return found.group(1)

    def test_run_client_uses_exec_on_head(self):
        """Exactly one exec_on_head call and no broadcast exec; the command
        carries the expected flags and values."""
        orch = FakeOrch()
        job = _make_job(orch)
        job.run_client()
        self.assertEqual(len(orch.head_calls), 1)
        self.assertEqual(orch.exec_calls, [])
        client_cmd = orch.head_calls[0][0]
        self.assertIn("bench serve", client_cmd)
        self.assertIn("source /tmp/server_env_script.sh", client_cmd, "run_client must source the env script")
        self.assertIn(job.client_log, client_cmd)
        self.assertIn("2>&1 &", client_cmd, "run_client must background the bench client with 2>&1 &")
        self.assertIn("--result-dir", client_cmd)
        self.assertIn(job.out_dir, client_cmd)
        for flag in ("--result-filename", "--save-result", "--ignore-eos", "--model"):
            self.assertIn(flag, client_cmd)
        self.assertIn(job.model_id, client_cmd)
        self.assertIn("--base-url", client_cmd)
        self.assertIn(f"{job.base_url}:{job.port_no}", client_cmd)
        self.assertIn("--num-prompts", client_cmd)

        # Pair each flag with its value through a regex, so one number being a
        # substring of another (128 inside 12800) cannot produce a false pass.
        expected_values = (
            ("--result-filename", "results"),
            ("--random-input-len", str(_ISL)),
            ("--random-output-len", str(2048)),
            ("--max-concurrency", str(256)),
            ("--num-prompts", str(12800)),
            ("--percentile-metrics", "ttft,tpot,itl,e2el"),
            ("--metric-percentiles", "50,90,95,99"),
            ("--random-range-ratio", "0.8"),
            ("--random-prefix-len", "0"),
            ("--request-rate", "inf"),
            ("--burstiness", "1.0"),
            ("--seed", "0"),
            ("--tokenizer-mode", "auto"),
        )
        for flag, value in expected_values:
            self.assertEqual(self._value_after(client_cmd, flag), value)

    def test_run_client_goodput_flag_built_from_slo(self):
        orch = FakeOrch()
        slo = {"ttft_ms": 500.0, "tpot_ms": 50.0, "e2el_ms": 60000.0}
        job = _make_job(orch, goodput_slo=slo)
        job.run_client()
        self.assertEqual(orch.exec_calls, [], "run_client must not broadcast even with goodput_slo set")
        cmd = orch.head_calls[0][0]
        self.assertIn("--goodput", cmd)
        for tok in ("ttft:500.0", "tpot:50.0", "e2el:60000.0"):
            self.assertIn(tok, cmd)

    def test_run_client_goodput_omitted_when_none(self):
        orch = FakeOrch()
        job = _make_job(orch, goodput_slo=None)
        job.run_client()
        self.assertEqual(orch.exec_calls, [], "run_client must not broadcast even with goodput_slo=None")
        launched = orch.head_calls[0][0]
        self.assertNotIn("--goodput", launched)

    def test_run_client_goodput_sparse_dict_skips_none_keys(self):
        # The SLO dict holds ttft_ms only, so no "tpot:" or "e2el:" entry may appear.
        orch = FakeOrch()
        job = _make_job(orch, goodput_slo={"ttft_ms": 500.0})
        job.run_client()
        self.assertEqual(orch.exec_calls, [], "run_client must not broadcast even with sparse goodput_slo")
        cmd = orch.head_calls[0][0]
        self.assertIn("--goodput", cmd)
        self.assertIn("ttft:500.0", cmd)
        for absent in ("tpot:", "e2el:"):
            self.assertNotIn(absent, cmd)
class TestWaitClientCompleteHeadOnly(unittest.TestCase):
    """wait_client_complete polls the client log through exec_on_head, which
    returns a single-entry {head: text} dict; the broadcast exec must stay
    unused on every path."""

    @staticmethod
    def _zero_waits(job, poll_count):
        """Set the private timing attrs directly, in case the constructor
        kwargs are not what the implementation reads."""
        job._client_initial_wait = 0
        job._client_poll_wait = 0
        job._client_poll_count = poll_count

    def _job_with_head(self, head_text):
        """Return (orch, job) where every head poll yields `head_text` and the
        poll cap is 1 with zero waits."""
        orch = FakeOrch(head_return={_HEAD: head_text})
        # Timing overrides go through the constructor kwargs (F4: constructor path).
        job = _make_job(orch, client_initial_wait_s=0, client_poll_wait_s=0, client_poll_count=1)
        self._zero_waits(job, 1)
        return orch, job

    def test_uses_exec_on_head(self):
        orch, job = self._job_with_head("Serving Benchmark Result\nFailed requests: 0\n")
        job.wait_client_complete()
        self.assertGreaterEqual(len(orch.head_calls), 1)
        self.assertEqual(orch.exec_calls, [], "client poll must not broadcast")
        last_poll = orch.head_calls[-1][0]
        self.assertIn(job.client_log, last_poll, "wait_client_complete must poll job.client_log")

    def test_completion_marker_alone_is_sufficient(self):
        # The banner with no "Failed requests:" line must still count as done.
        orch, job = self._job_with_head("Serving Benchmark Result\n")
        job.wait_client_complete()  # must not raise
        self.assertGreaterEqual(len(orch.head_calls), 1)

    def test_nonzero_failed_requests_raises(self):
        orch, job = self._job_with_head("Serving Benchmark Result\nFailed requests: 7\n")
        with self.assertRaises(RuntimeError):
            job.wait_client_complete()
        self.assertEqual(orch.exec_calls, [], "failure path must not broadcast")

    def test_zero_failed_requests_is_not_a_failure(self):
        # "Failed requests: 0" must not raise. The head-call assertion proves a
        # poll actually happened rather than an immediate return.
        orch, job = self._job_with_head("Serving Benchmark Result\nFailed requests: 0\n")
        job.wait_client_complete()
        self.assertGreaterEqual(len(orch.head_calls), 1, "poll must have been issued")
        self.assertEqual(orch.exec_calls, [], "must not broadcast")

    def test_client_crash_raises(self):
        orch, job = self._job_with_head("Traceback (most recent call last):\n ...\n")
        with self.assertRaises(RuntimeError):
            job.wait_client_complete()
        self.assertEqual(orch.exec_calls, [], "failure path must not broadcast")

    def test_launch_failure_raises(self):
        launch_failures = (
            "error: argument --bogus: unrecognized arguments\n",
            "invalid choice: bench\n",
            "command not found: vllm\n",
            "/usr/bin/vllm: No such file or directory\n",
        )
        for text in launch_failures:
            with self.subTest(text=text.strip()):
                orch, job = self._job_with_head(text)
                with self.assertRaises(RuntimeError):
                    job.wait_client_complete()
                self.assertEqual(orch.exec_calls, [], "failure path must not broadcast")

    def test_no_completion_within_cap_raises(self):
        # The log never shows the banner, so the poll cap runs out.
        orch, job = self._job_with_head("still warming up...\n")
        with self.assertRaises(RuntimeError):
            job.wait_client_complete()
        self.assertEqual(orch.exec_calls, [], "timeout path must not broadcast")

    def test_completion_on_second_poll(self):
        """The banner is absent on the first poll and present from the second
        on; with a poll cap of 3 the call must return without raising."""
        polls_seen = []

        def deferred(cmd, kwargs):
            polls_seen.append(cmd)
            if len(polls_seen) >= 2:
                return {_HEAD: "Serving Benchmark Result\nFailed requests: 0\n"}
            return {_HEAD: "still warming up...\n"}

        orch = FakeOrch(head_return=deferred)
        # Timing overrides go through the constructor kwargs (F4: constructor path).
        job = _make_job(orch, client_initial_wait_s=0, client_poll_wait_s=0, client_poll_count=3)
        self._zero_waits(job, 3)
        job.wait_client_complete()  # must return without raising
        self.assertGreaterEqual(len(orch.head_calls), 2, "must have polled at least twice")
+ orch, job = self._job_with_head("still warming up...\n") + with self.assertRaises(RuntimeError): + job.wait_client_complete() + self.assertEqual(orch.exec_calls, [], "timeout path must not broadcast") + + def test_completion_on_second_poll(self): + """Deferred-completion path: banner absent on poll 1, present on poll 2. + An implementation that exits after exactly one poll (ignoring the loop + cap) would raise RuntimeError here instead of returning normally.""" + call_count = [0] + + def deferred(cmd, kwargs): + call_count[0] += 1 + if call_count[0] >= 2: + return {_HEAD: "Serving Benchmark Result\nFailed requests: 0\n"} + return {_HEAD: "still warming up...\n"} + + orch = FakeOrch(head_return=deferred) + # Pass timing overrides via constructor kwargs (F4: constructor path). + job = _make_job(orch, client_initial_wait_s=0, client_poll_wait_s=0, client_poll_count=3) + # Belt-and-suspenders: also set private attrs. + job._client_initial_wait = 0 + job._client_poll_wait = 0 + job._client_poll_count = 3 + job.wait_client_complete() # must return without raising + self.assertGreaterEqual(len(orch.head_calls), 2, "must have polled at least twice") + + +class TestWaitReadyBroadcast(unittest.TestCase): + """wait_ready / is_ready poll the server log on ALL nodes via broadcast + orch.exec (not exec_on_head). For distributed runs this is load-bearing: + a crashed rank-1 shard must be detected, not silently skipped.""" + + def _job_ready(self, exec_return): + orch = FakeOrch(hosts=_HOSTS, exec_return=exec_return) + # Pass timing overrides via constructor kwargs (F4: constructor path). + job = _make_job( + orch, server_precheck_wait_s=0, server_warmup_wait_s=0, server_poll_count=3, server_poll_wait_s=0 + ) + # Belt-and-suspenders: also set private attrs. 
+ job._precheck_wait = 0 + job._warmup_wait = 0 + job._server_poll_count = 3 + job._server_poll_wait = 0 + return orch, job + + # Command-discriminating dispatcher: wait_ready calls orch.exec for two purposes: + # (1) tail -30 — early-failure pre-check; expects string per-host values + # (2) grep -qiE — is_ready poll; expects {exit_code: N} dicts + # The mock must return the correct shape for each call type to avoid TypeErrors. + @staticmethod + def _ready_dispatcher(grep_exit_code_fn): + """Return an exec_return callable that dispatches on command type. + + grep_exit_code_fn(call_index) -> int: exit code for is_ready grep calls. + Tail/precheck calls always return benign empty strings. + """ + grep_calls = [0] + + def dispatch(cmd, kwargs): + if "grep" in cmd: + grep_calls[0] += 1 + code = grep_exit_code_fn(grep_calls[0]) + return {h: {"exit_code": code} for h in _HOSTS} + # tail / other precheck: return benign string output + return {h: "" for h in _HOSTS} + + return dispatch + + def test_is_ready_uses_broadcast_exec(self): + """is_ready must poll via broadcast orch.exec (all nodes), not exec_on_head. 
+ Also asserts the command targets job.server_log (not a hardcoded/wrong path).""" + orch = FakeOrch(hosts=_HOSTS, exec_return=self._ready_dispatcher(lambda _: 0)) + job = _make_job(orch) + job.is_ready() + self.assertTrue(orch.exec_calls, "is_ready must issue at least one broadcast exec") + self.assertEqual(orch.head_calls, [], "is_ready must not use exec_on_head") + grep_cmd = next((c for c, _ in orch.exec_calls if "grep" in c), None) + self.assertIsNotNone(grep_cmd, "is_ready must issue a grep command") + self.assertIn(job.server_log, grep_cmd, "is_ready must grep job.server_log") + for arm in ("Application startup complete", "Uvicorn running", "Started server"): + self.assertIn(arm.lower(), grep_cmd.lower(), f"READINESS_RE arm '{arm}' missing from grep command") + + def test_wait_ready_returns_when_all_nodes_ready(self): + """wait_ready returns normally when is_ready() reports all nodes up.""" + orch, job = self._job_ready(self._ready_dispatcher(lambda _: 0)) + job.wait_ready() # must not raise + self.assertTrue(orch.exec_calls, "wait_ready must have issued at least one exec") + tail_cmd = next((c for c, _ in orch.exec_calls if "tail" in c), None) + if tail_cmd is not None: + self.assertIn(job.server_log, tail_cmd, "wait_ready precheck must tail job.server_log") + + def test_wait_ready_raises_on_timeout(self): + """wait_ready raises RuntimeError when poll cap is exhausted without readiness.""" + orch, job = self._job_ready(self._ready_dispatcher(lambda _: 1)) + with self.assertRaises(RuntimeError): + job.wait_ready() + + def test_wait_ready_raises_when_one_node_not_ready(self): + """Distributed-specific: a single not-ready node must hold the entire + cluster in the not-ready state. 
Distinguishes all() from any().""" + + def one_lagging(cmd, kwargs): + if "grep" in cmd: + # node-a ready, node-b not — partial cluster case + return {"node-a": {"exit_code": 0}, "node-b": {"exit_code": 1}} + return {h: "" for h in _HOSTS} + + orch, job = self._job_ready(one_lagging) + with self.assertRaises(RuntimeError): + job.wait_ready() + + def test_wait_ready_returns_on_second_poll(self): + """Deferred-readiness: not ready on poll 1, ready on poll 2. + An implementation that calls is_ready() only once raises instead of returning.""" + grep_calls = [0] + + def deferred_ready(cmd, kwargs): + if "grep" in cmd: + grep_calls[0] += 1 + code = 0 if grep_calls[0] >= 2 else 1 + return {h: {"exit_code": code} for h in _HOSTS} + return {h: "" for h in _HOSTS} + + orch, job = self._job_ready(deferred_ready) + job.wait_ready() # must return without raising + grep_count = sum(1 for c, _ in orch.exec_calls if "grep" in c) + self.assertGreaterEqual(grep_count, 2, "must have polled is_ready at least twice") + + def test_wait_ready_raises_on_non_head_early_failure(self): + """Non-head node (rank>0) early-failure must also raise — not silently ignored. + Pins the distributed-specific contract: every shard is inspected.""" + for failure_text in ("command not found: vllm", "no such file or directory: vllm"): + with self.subTest(failure=failure_text): + + def non_head_fail(cmd, kwargs, _ft=failure_text): + if "tail" in cmd: + # node-a (head) is clean; node-b (rank-1) reports the failure + return {"node-a": "", "node-b": _ft} + return {h: {"exit_code": 0} for h in _HOSTS} + + orch = FakeOrch(hosts=_HOSTS, exec_return=non_head_fail) + # Pass timing overrides via constructor kwargs (F4: constructor path). + job = _make_job( + orch, server_precheck_wait_s=0, server_warmup_wait_s=0, server_poll_count=1, server_poll_wait_s=0 + ) + # Belt-and-suspenders: also set private attrs. 
+ job._precheck_wait = 0 + job._warmup_wait = 0 + job._server_poll_count = 1 + job._server_poll_wait = 0 + with self.assertRaises(RuntimeError): + job.wait_ready() + + def test_wait_ready_raises_on_early_failure(self): + """EARLY_FAILURE_RE output during the readiness wait raises immediately. + Covers all four EARLY_FAILURE_RE arms via subTest.""" + early_failure_strings = [ + "command not found: vllm", + "no such file or directory: vllm", + "cannot access /usr/bin/vllm", + "failed to start vllm server", + ] + for failure_text in early_failure_strings: + with self.subTest(failure=failure_text): + + def early_fail(cmd, kwargs, _ft=failure_text): + if "tail" in cmd: + return {"node-a": _ft, "node-b": ""} + return {h: {"exit_code": 0} for h in _HOSTS} + + orch = FakeOrch(hosts=_HOSTS, exec_return=early_fail) + # Pass timing overrides via constructor kwargs (F4: constructor path). + job = _make_job( + orch, server_precheck_wait_s=0, server_warmup_wait_s=0, server_poll_count=1, server_poll_wait_s=0 + ) + # Belt-and-suspenders: also set private attrs. + job._precheck_wait = 0 + job._warmup_wait = 0 + job._server_poll_count = 1 + job._server_poll_wait = 0 + with self.assertRaises(RuntimeError): + job.wait_ready() + + +class TestParseResultsHeadOnly(unittest.TestCase): + """parse_results cats the `results` artifact on the head node via + exec_on_head and delegates to to_client_metrics. Raises on + empty/missing/unparseable.""" + + def _parse(self, head_text): + orch = FakeOrch(head_return={_HEAD: head_text}) + job = _make_job(orch) + return orch, job + + def test_parse_results_uses_exec_on_head(self): + orch, job = self._parse(_load_fixture("vllm_results_widened.json")) + out = job.parse_results() + self.assertGreaterEqual(len(orch.head_calls), 1) + self.assertEqual(orch.exec_calls, [], "results fetch must not broadcast") + # Delegation produced namespaced client.* metrics for the head host. 
+ self.assertIn(_HEAD, out) + self.assertIn("client.total_token_throughput", out[_HEAD]) + fetch_cmd = orch.head_calls[-1][0] + self.assertIn(job.out_dir, fetch_cmd, "parse_results must cat from out_dir") + # Verify exact artifact name 'results' (not a superstring like 'results.json'). + self.assertRegex(fetch_cmd, r'(? client -> +# stopped). The value-add rows are the routing-illegal cases (a step that must +# NOT broadcast, a step that must NOT go head-only) and idempotent re-entry. +# --------------------------------------------------------------------------- + + +class TestVllmDistributedJobLifecycle(unittest.TestCase): + """Transition table (routing is the observable state of each phase): + + | from | event | to / effect | + |-------------|------------------------|-----------------------------------------------| + | constructed | build_server_cmd() | env+mkdir BROADCAST to all nodes | + | env-ready | start_server() | one TARGETED exec per host, rank i -> host i | + | started | run_client() | HEAD-ONLY exec_on_head, no broadcast | + | client-up | wait_client_complete() | HEAD-ONLY poll | + | complete | parse_results() | HEAD-ONLY cat + parse | + | any | stop_server() | pkill BROADCAST to all nodes | + + Illegal/guard rows: run_client must never broadcast; build_server_cmd must + never go head-only; start_server on a single host yields a single rank-0. + """ + + def test_legal_sequence_routes_correctly(self): + # exec_return: dispatches by command type so each method gets the right shape. + # is_ready uses grep with detailed=True → {exit_code: 0} dicts. + # start_server / build_server_cmd expect string output for EARLY_FAILURE_RE. 
+ def broadcast_return(cmd, kwargs): + targets = kwargs.get("hosts") or _HOSTS + if kwargs.get("detailed") or "grep" in cmd: + return {h: {"exit_code": 0} for h in targets} + return {h: "" for h in targets} + + # head_return: dispatch on command content so the mock is deterministic + # regardless of how many exec_on_head calls run_client issues internally. + # parse_results cats job.out_dir/.../results; wait_client_complete tails/greps + # job.client_log. We detect the parse_results call by the presence of "results" + # and job.out_dir (the artifact path), and return the JSON only for that call. + results_text = _load_fixture("vllm_results_widened.json") + # Build a temp job to get out_dir for the dispatch predicate. + _tmp_job = _make_job(FakeOrch(hosts=_HOSTS)) + _out_dir = _tmp_job.out_dir + del _tmp_job + + def stateful_head(cmd, kwargs): + # parse_results calls: cat /results + if _out_dir in cmd and "results" in cmd: + return {_HEAD: results_text} + # wait_client_complete and run_client calls: return the completion banner + return {_HEAD: "Serving Benchmark Result\nFailed requests: 0\n"} + + orch = FakeOrch( + hosts=_HOSTS, + exec_return=broadcast_return, + head_return=stateful_head, + ) + # Pass all timing overrides via constructor kwargs (F4: constructor path). + job = _make_job( + orch, + client_initial_wait_s=0, + client_poll_wait_s=0, + client_poll_count=1, + server_precheck_wait_s=0, + server_warmup_wait_s=0, + server_poll_count=1, + server_poll_wait_s=0, + ) + # Belt-and-suspenders: also set private attrs. + job._client_initial_wait = 0 + job._client_poll_wait = 0 + job._client_poll_count = 1 + job._precheck_wait = 0 + job._warmup_wait = 0 + job._server_poll_count = 1 + job._server_poll_wait = 0 + + job.build_server_cmd() + broadcast_after_build = len(orch.exec_calls) + self.assertTrue(all(k.get("hosts") is None for _, k in orch.exec_calls)) + + job.start_server() + # start_server added exactly one targeted exec per host. 
+ per_host = orch.exec_calls[broadcast_after_build:] + self.assertEqual(len(per_host), len(_HOSTS)) + self.assertTrue(all(isinstance(k.get("hosts"), list) for _, k in per_host)) + + exec_before_ready = len(orch.exec_calls) + job.wait_ready() + # wait_ready must have issued at least one broadcast exec (no hosts kwarg). + ready_execs = orch.exec_calls[exec_before_ready:] + self.assertTrue(ready_execs, "wait_ready issued no broadcast exec") + self.assertTrue(all(k.get("hosts") is None for _, k in ready_execs)) + + head_before_client = len(orch.head_calls) + job.run_client() + job.wait_client_complete() + out = job.parse_results() + # Every client-side step used the head channel, none broadcast. + self.assertGreater(len(orch.head_calls), head_before_client) + self.assertIn(_HEAD, out) + + def test_stop_server_broadcasts(self): + """Illegal-if-head-only: teardown must pkill on ALL nodes, else a stray + server lingers on the non-head node. Broadcast exec, no head-only call.""" + orch = FakeOrch(hosts=_HOSTS) + job = _make_job(orch) + with unittest.mock.patch("time.sleep"): + job.stop_server() + self.assertTrue(orch.exec_calls, "stop_server issued no command") + self.assertEqual(orch.head_calls, [], "stop_server must not use exec_on_head") + for cmd, kwargs in orch.exec_calls: + self.assertIsNone(kwargs.get("hosts"), "stop_server must broadcast to all nodes") + self.assertIn("pkill", cmd) + self.assertIn("-f", cmd, "pkill must use -f for full-cmdline match") + self.assertIn("vllm serve", cmd, "pkill must target 'vllm serve' process") + + def test_stop_server_is_idempotent(self): + """Idempotent re-entry: calling stop_server twice must not raise (pkill + on an already-dead server is a no-op). 
Both calls must still broadcast.""" + orch = FakeOrch(hosts=_HOSTS) + job = _make_job(orch) + with unittest.mock.patch("time.sleep"): + job.stop_server() + after_first = len(orch.exec_calls) + self.assertGreater(after_first, 0, "first stop_server issued no exec") + job.stop_server() + self.assertGreater(len(orch.exec_calls), after_first, "second stop_server must still broadcast pkill") + self.assertEqual(orch.head_calls, [], "stop_server must not use exec_on_head") + for cmd, kwargs in orch.exec_calls: + self.assertIn("pkill", cmd) + self.assertIn("vllm serve", cmd, "pkill must target 'vllm serve' process") + + def test_single_host_cluster_yields_single_rank0(self): + """Boundary (nnodes->1 host in orch.hosts): start_server dispatches a + single rank-0 launch, never a phantom rank-1.""" + orch = FakeOrch(hosts=["solo-node"]) + job = _make_job(orch) + job.start_server() + self.assertEqual(len(orch.exec_calls), 1) + cmd, kwargs = orch.exec_calls[0] + self.assertEqual(kwargs.get("hosts"), ["solo-node"]) + self.assertEqual(re.search(r"--node-rank\s+(\d+)", cmd).group(1), "0") + + +# --------------------------------------------------------------------------- +# VariantConfig.cell_key format (spec File 1). Pure function over its three +# args + the configured TP/PP. Imported lazily so a not-yet-existent loader +# fails THIS test cleanly (RED) rather than erroring collection of the module. +# --------------------------------------------------------------------------- + + +class TestCellKeyFormat(unittest.TestCase): + def _variant_config(self): + from cvs.lib.inference.utils.vllm_distributed_config_loader import VariantConfig + + return VariantConfig(**self._raw_config()) + + @staticmethod + def _raw_config(): + # Minimal raw config kwargs sufficient to construct VariantConfig with the + # spec's default params (tensor_parallelism=8, pipeline_parallel_size=2). 
+ return { + "schema_version": 1, + "framework": "vllm_distributed", + "gpu_arch": "mi300x", + "enforce_thresholds": False, + "paths": { + "shared_fs": "/tmp", + "models_dir": "/tmp/models", + "log_dir": "/tmp/LOGS", + "hf_token_file": "/tmp/tok", + }, + "model": {"id": "amd/Llama-3.1-70B-Instruct-FP8-KV", "remote": 0}, + "roles": {"server": {"serve_args": {}}}, + "params": {"master_addr": "node-head"}, + "sweep": {"sequence_combinations": [], "runs": []}, + "thresholds": {}, + } + + def test_cell_key_format(self): + """cell_key(isl, osl, conc) == 'ISL=,OSL=,TP=,PP=,CONC=' + with TP/PP pulled from params (default 8 / 2). The PP segment is the + distributed-specific addition over vllm_single's key.""" + vc = self._variant_config() + self.assertEqual( + vc.cell_key("1000", "1000", 16), + "ISL=1000,OSL=1000,TP=8,PP=2,CONC=16", + ) + + def test_cell_key_reflects_param_values(self): + # TP/PP in the key track the configured params, not hardcoded literals. + # Use non-default values (TP=2, PP=1) so the test independently falsifies + # any implementation that hardcodes TP=8/PP=2. + raw = self._raw_config() + raw["params"]["tensor_parallelism"] = "2" + raw["params"]["pipeline_parallel_size"] = "1" + from cvs.lib.inference.utils.vllm_distributed_config_loader import VariantConfig + + vc = VariantConfig(**raw) + key = vc.cell_key("8000", "1024", 16) + self.assertIn("TP=2", key) + self.assertIn("PP=1", key) + self.assertNotIn("TP=8", key) + self.assertNotIn("PP=2", key) + self.assertIn("ISL=8000", key) + self.assertIn("OSL=1024", key) + self.assertIn("CONC=16", key) + + def test_cell_key_uses_non_default_tp_pp(self): + # If cell_key hardcodes TP=8/PP=2 rather than reading params, this fails. 
+ raw = self._raw_config() + raw["params"]["tensor_parallelism"] = "4" + raw["params"]["pipeline_parallel_size"] = "4" + from cvs.lib.inference.utils.vllm_distributed_config_loader import VariantConfig + + vc = VariantConfig(**raw) + key = vc.cell_key("1000", "1000", 16) + self.assertIn("TP=4", key) + self.assertIn("PP=4", key) + self.assertNotIn("TP=8", key) + self.assertNotIn("PP=2", key) + + +if __name__ == "__main__": + unittest.main() diff --git a/cvs/lib/inference/unittests/test_vllm_orch_parse.py b/cvs/lib/inference/unittests/test_vllm_orch_parse.py index 95d6b42e..3f56e435 100644 --- a/cvs/lib/inference/unittests/test_vllm_orch_parse.py +++ b/cvs/lib/inference/unittests/test_vllm_orch_parse.py @@ -20,7 +20,7 @@ _FIXTURES = _HERE / "fixtures" _REPO = _HERE.parents[3] # cvs/lib/inference/unittests -> repo root _SHARED = _REPO / "cvs/tests/inference/vllm/_shared.py" -_THRESHOLD = _REPO / "cvs/input/config_file/inference/vllm_single/w1_llama31_70b_fp8kv/llama31_70b_fp8_threshold.json" +_THRESHOLD = _REPO / "cvs/input/config_file/inference/vllm_single/mi300x_vllm-single_llama31-70b_fp8_threshold.json" # isl/tp used to build the job; must match the fixture's run for the derived # math assertions to be meaningful (real artifact: isl=128, tp=8). diff --git a/cvs/lib/inference/utils/vllm_distributed_config_loader.py b/cvs/lib/inference/utils/vllm_distributed_config_loader.py new file mode 100644 index 00000000..9a4015b7 --- /dev/null +++ b/cvs/lib/inference/utils/vllm_distributed_config_loader.py @@ -0,0 +1,224 @@ +''' +Copyright 2025 Advanced Micro Devices, Inc. +All rights reserved. + +Inference-specific config schema for the vllm_distributed suite. + +Extends the vllm_single loader with distributed params (pipeline_parallel_size, +master_addr, master_port, nnodes) and a PP-aware cell_key. + +cell_key format: ISL=,OSL=,TP=,PP=,CONC= +The PP segment is the distributed-specific addition over vllm_single. 
+''' + +from __future__ import annotations + +import warnings +from collections import Counter +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, ConfigDict, Field, model_validator +from typing_extensions import Literal + +from cvs.lib.inference.utils.vllm_parsing import GATED_METRICS +from cvs.lib.utils.config_loader import substitute_config + + +class _Forbid(BaseModel): + model_config = ConfigDict(extra="forbid") + + +class _Allow(BaseModel): + model_config = ConfigDict(extra="allow") + + +# ---------- sub-models ---------- + + +class ContainerConfig(_Allow): + """Loose container block — passes through to the orchestrator as-is.""" + + lifetime: str = "per_run" + name: str = "" + image: str = "" + + +class Paths(_Forbid): + shared_fs: str + models_dir: str + log_dir: str + hf_token_file: str + + +class ModelSpec(_Forbid): + id: str + remote: Literal[0, 1] + + +class RoleServer(_Forbid): + serve_args: Dict[str, Any] = {} + env: Dict[str, str] = {} + + +class Roles(_Forbid): + server: RoleServer = RoleServer() + + +class GoodputSlo(_Forbid): + ttft_ms: float + tpot_ms: float + e2el_ms: float + + +class SeqCombo(_Forbid): + name: str + isl: str + osl: str + goodput_slo: Optional[GoodputSlo] = None + + +class Run(_Forbid): + combo: str + concurrency: int + + +def validate_sweep_selector(combo_names, run_combo_refs): + counts = Counter(combo_names) + dupes = sorted(name for name, count in counts.items() if count > 1) + if dupes: + raise ValueError(f"duplicate sequence_combination names: {dupes}") + known = set(counts) + unknown = sorted({r for r in run_combo_refs if r not in known}) + if unknown: + raise ValueError(f"run.combo names no sequence_combination: {unknown} (known: {sorted(known)})") + + +class Sweep(_Forbid): + sequence_combinations: List[SeqCombo] + runs: List[Run] + + @model_validator(mode="after") + def _check_runs_reference_known_combos(self): + validate_sweep_selector( + [c.name for c in self.sequence_combinations], + 
[r.combo for r in self.runs], + ) + return self + + +class Params(_Forbid): + backend: str = "vllm" + base_url: str = "http://0.0.0.0" + port_no: str = "8888" + dataset_name: str = "random" + burstiness: str = "1.0" + seed: str = "0" + request_rate: str = "inf" + random_range_ratio: str = "0.8" + random_prefix_len: str = "0" + tensor_parallelism: str = "8" + pipeline_parallel_size: str = "2" + master_addr: str = "localhost" + master_port: str = "29501" + nnodes: str = "2" + tokenizer_mode: str = "auto" + percentile_metrics: str = "ttft,tpot,itl,e2el" + metric_percentiles: str = "50,90,95,99" + num_prompts: str = "3200" + client_poll_count: str = "20" + + +class VariantConfig(_Forbid): + """Typed config for the vllm_distributed suite. + + Standalone (does not extend BaseVariantConfig) so it can be constructed + without the threshold_json field that the base requires -- that field is + present in production configs but absent from unit-test fixtures. + + `container` is optional: unit-test fixtures omit it (defaults to a bare + ContainerConfig), while production configs provide the full block. The + conftest's `orch` fixture accesses `variant_config.container.model_dump()` + to build the orchestrator, so the field must be present even when empty. + """ + + schema_version: Literal[1] + framework: Literal["vllm_distributed"] + gpu_arch: str + enforce_thresholds: bool = True + container: ContainerConfig = Field(default_factory=ContainerConfig) + paths: Paths + model: ModelSpec + roles: Roles = Roles() + params: Params = Params() + sweep: Sweep + thresholds: Dict[str, Dict[str, Any]] = Field(default_factory=dict) + + def cell_key(self, isl, osl, concurrency): + """The canonical threshold key for one sweep cell. + + Distributed-specific: includes PP= segment in addition to the + vllm_single TP= segment. 
+ + Format: ISL=,OSL=,TP=,PP=,CONC= + """ + return ( + f"ISL={isl},OSL={osl}," + f"TP={self.params.tensor_parallelism}," + f"PP={self.params.pipeline_parallel_size}," + f"CONC={concurrency}" + ) + + def expected_cells(self): + """Every (isl, osl, conc) cell the sweep's `runs` selector picks.""" + by_name = {c.name: c for c in self.sweep.sequence_combinations} + return [self.cell_key(by_name[r.combo].isl, by_name[r.combo].osl, r.concurrency) for r in self.sweep.runs] + + @model_validator(mode="after") + def _check_remote_not_implemented(self): + if self.model.remote == 1: + raise NotImplementedError("model.remote=1 (remote model download) is not implemented. ") + return self + + @model_validator(mode="after") + def _check_thresholds_cover_sweep(self): + expected = set(self.expected_cells()) + present = set(self.thresholds.keys()) + missing = sorted(expected - present) + extra = sorted(present - expected) + problems = [] + if missing: + problems.append(f"sweep cells with no threshold entry: {missing}") + if extra: + problems.append(f"threshold keys matching no sweep cell (typo?): {extra}") + gated_keys = [f"client.{m}" for m in sorted(GATED_METRICS)] + gated_gaps = {} + for cell in sorted(expected & present): + specs = self.thresholds.get(cell) or {} + absent = [k for k in gated_keys if k not in specs] + if absent: + gated_gaps[cell] = absent + if gated_gaps: + problems.append(f"cells missing gated-metric specs: {gated_gaps}") + if problems: + msg = "threshold.json does not match the sweep matrix; " + "; ".join(problems) + if self.enforce_thresholds: + raise ValueError(msg) + warnings.warn(f"{msg} (enforce_thresholds=false -> record-only)", stacklevel=2) + return self + + +# ---------- public API ---------- + + +def load_variant(config_path, cluster_dict): + """Load and validate a vllm_distributed variant config + its sibling threshold file. + + Production configs contain extra fields (container, threshold_json) that the + standalone VariantConfig does not model. 
They are stripped before construction + so unit-test fixtures (which also omit those fields) and production configs + are handled identically. + """ + raw, thresholds = substitute_config(config_path, cluster_dict) + known = {k: v for k, v in raw.items() if k in VariantConfig.model_fields} + known["thresholds"] = thresholds + return VariantConfig(**known) diff --git a/cvs/lib/inference/vllm_distributed.py b/cvs/lib/inference/vllm_distributed.py new file mode 100644 index 00000000..ae5a1f3f --- /dev/null +++ b/cvs/lib/inference/vllm_distributed.py @@ -0,0 +1,542 @@ +''' +Copyright 2025 Advanced Micro Devices, Inc. +All rights reserved. + +Multinode vLLM distributed job driven by a ContainerOrchestrator. + +Runs vLLM with --distributed-executor-backend mp across N nodes. +Server launches on every node with per-rank --node-rank flag. +Client runs exclusively on the head node (rank 0). + +Key distributed-vs-single routing contract: + - build_server_cmd: env-script write + mkdir BROADCAST to ALL nodes + (orch.exec with no hosts kwarg / hosts=None) + - start_server: one targeted orch.exec(..., hosts=[host]) per host, + with per-rank --node-rank + - run_client / wait_client_complete / parse_results: head-only via + orch.exec_on_head (never broadcast) + - wait_ready / is_ready: BROADCAST orch.exec so every shard is checked + - stop_server: pkill BROADCAST to all nodes +''' + +from __future__ import annotations + +import json +import math +import re +import shlex +import time + +from cvs.lib import globals +from cvs.lib.inference.utils.vllm_parsing import to_client_metrics + +log = globals.log + + +class VllmDistributedJob: + """Multinode vLLM benchmark job driven by an injected ContainerOrchestrator. + + All container/SSH plumbing belongs to `orch`. 
This class composes the + server-env script, launches the server on EVERY node in the background + (each with its per-rank --node-rank), polls ALL nodes until ready, runs + the bench_serving client on the HEAD node only, and parses the resulting + log (head-only). + + The `orch.hosts` list provides the enumeration: `enumerate(orch.hosts)` + yields (rank, host) pairs for per-host rank dispatch. + """ + + READINESS_RE = re.compile(r"Application startup complete|Uvicorn running|Started server", re.I) + COMPLETION_RE = re.compile(r"Serving Benchmark Result", re.I) + FAILED_REQUESTS_RE = re.compile(r"Failed requests:\s+([0-9]+)", re.I) + CLIENT_CRASH_RE = re.compile(r"Traceback \(most recent call last\)", re.I) + CLIENT_LAUNCH_FAIL_RE = re.compile( + r"unrecognized arguments|invalid choice|error: argument |command not found|: No such file or directory", + re.I, + ) + EARLY_FAILURE_RE = re.compile( + r"no such file or directory|command not found|cannot access|failed to start" + r"|Free memory on device.*less than desired" + r"|Engine core initialization failed" + r"|WorkerProc failed to start", + re.I, + ) + # Pattern checked against full NFS log after warmup to catch hard failures. 
+ FATAL_LOG_RE = re.compile( + r"Free memory on device.{0,80}less than desired" + r"|Engine core initialization failed" + r"|RuntimeError:.*[Ee]ngine", + re.I, + ) + + def __init__( + self, + orch, + variant, + hf_token, + isl, + osl, + concurrency, + num_prompts, + goodput_slo=None, + log_subdir="vllm", + server_precheck_wait_s=30, + server_warmup_wait_s=330, + server_poll_count=60, + server_poll_wait_s=60, + client_initial_wait_s=120, + client_poll_count=20, + client_poll_wait_s=60, + ): + self.orch = orch + self.variant = variant + self.hf_token = hf_token + self.isl = str(isl) + self.osl = str(osl) + self.concurrency = str(concurrency) + self.num_prompts = str(num_prompts) + self.goodput_slo = goodput_slo + self.log_subdir = log_subdir + + p = variant.params + self.tp = p.tensor_parallelism + self.pp = p.pipeline_parallel_size + self.master_addr = p.master_addr + self.master_port = p.master_port + self.nnodes = p.nnodes + self.port_no = p.port_no + self.random_range_ratio = p.random_range_ratio + self.random_prefix_len = p.random_prefix_len + self.burstiness = p.burstiness + self.seed = p.seed + self.request_rate = p.request_rate + self.tokenizer_mode = p.tokenizer_mode + self.percentile_metrics = p.percentile_metrics + self.metric_percentiles = p.metric_percentiles + self.base_url = p.base_url + self.dataset_name = p.dataset_name + self.backend = p.backend + + self.model_id = variant.model.id + self.log_dir = variant.paths.log_dir + self.serve_args = dict(variant.roles.server.serve_args) + self.server_env = dict(variant.roles.server.env) + self.models_dir = variant.paths.models_dir + + # Per-cell output directory, keyed by (isl/osl/conc). 
+ self.out_dir = f"{self.log_dir}/{self.log_subdir}/out-node0/isl{self.isl}_osl{self.osl}_conc{self.concurrency}" + self.server_log = f"{self.out_dir}/vllm_serve_server.log" + self.client_log = f"{self.out_dir}/client.log" + + self._precheck_wait = server_precheck_wait_s + self._warmup_wait = server_warmup_wait_s + self._server_poll_count = server_poll_count + self._server_poll_wait = server_poll_wait_s + self._client_initial_wait = client_initial_wait_s + self._client_poll_count = client_poll_count + self._client_poll_wait = client_poll_wait_s + + # ---------- derived builders ---------- + + _MML_PAD = 8 + + def _derive_max_model_len(self): + r = float(self.random_range_ratio) + worst = (int(self.isl) + int(self.osl)) * (1.0 + r) + return str(math.ceil(worst) + int(self.random_prefix_len) + self._MML_PAD) + + @staticmethod + def _flatten_serve_args(mapping): + """A {flag: value} serve-args map -> a flat `vllm serve` arg list.""" + argv = [] + for flag, value in mapping.items(): + opt = f"--{flag}" + if value is True: + argv.append(opt) + elif isinstance(value, (list, tuple)): + for v in value: + argv.extend([opt, str(v)]) + else: + argv.extend([opt, str(value)]) + return argv + + def _server_argv(self, rank): + """The `vllm serve` arg list for a specific node rank. + + All distributed flags are added: --node-rank, --master-addr, + --master-port, --pipeline-parallel-size, --nnodes, and + --distributed-executor-backend mp. Only --node-rank varies per rank; + all other flags are rank-invariant. 
+ """ + argv = [ + "vllm", + "serve", + self.model_id, + "--tensor-parallel-size", + str(self.tp), + "--pipeline-parallel-size", + str(self.pp), + "--max-model-len", + self._derive_max_model_len(), + "--port", + str(self.port_no), + "--node-rank", + str(rank), + "--master-addr", + str(self.master_addr), + "--master-port", + str(self.master_port), + "--nnodes", + str(self.nnodes), + "--distributed-executor-backend", + "mp", + ] + argv.extend(self._flatten_serve_args(self.serve_args)) + return argv + + # ---------- server side ---------- + + def _rank_log(self, rank): + """Return the server log path for a specific rank (avoids shared-file clobber).""" + return self.server_log.replace("out-node0", f"out-node{rank}") + + def build_server_cmd(self): + """Write the server-env script and create the per-node out-dirs. + + Both are BROADCAST to ALL nodes (no hosts kwarg) so every rank has the + env script and out-dir before any per-host server launch. + """ + env_lines = [ + f"export HF_TOKEN={shlex.quote(self.hf_token)}", + f"export HF_HUB_CACHE={shlex.quote(self.models_dir)}", + "export VLLM_USE_AITER_UNIFIED_ATTENTION=1", + "export VLLM_ROCM_USE_AITER_MHA=0", + "export VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4=1", + ] + for k, v in self.server_env.items(): + env_lines.append(f"export {k}={shlex.quote(str(v))}") + env_script = "\n".join(env_lines) + "\n" + # Broadcast: no hosts kwarg -> all nodes get the env script. + self.orch.exec("bash -c " + shlex.quote(f"printf '%s' {shlex.quote(env_script)} > /tmp/server_env_script.sh")) + # Create a per-rank out-dir for every node so logs never share a path. + for rank in range(int(self.nnodes)): + rank_dir = self.out_dir.replace("out-node0", f"out-node{rank}") + self.orch.exec(f"mkdir -p {shlex.quote(rank_dir)}") + + # Patch 0: Drop stale .pyc files so Python recompiles from the .py source + # that is already present in the Docker image. 
Two known-stale files: + # multiproc_executor.pyc — compiled before the fix that returns [] instead + # of asserting rpc_broadcast_mq is not None on follower nodes. + # core.pyc — compiled before the follower guards added in Patches 1-3. + # Python skips recompilation when the .pyc mtime/hash matches the .py, + # so an in-place image update leaves stale .pyc files silently in use. + _stale_pycs = [ + ( + "/opt/python/lib/python3.12/site-packages/vllm/v1/executor/" + "__pycache__/multiproc_executor.cpython-312.pyc" + ), + ("/opt/python/lib/python3.12/site-packages/vllm/v1/engine/__pycache__/core.cpython-312.pyc"), + ] + for _pyc in _stale_pycs: + self.orch.exec(f"rm -f {shlex.quote(_pyc)}") + log.info("vllm stale pycs removed (multiproc_executor, core)") + + # Patch 0b: Fix collective_rpc assert in multiproc_executor.py. + # The image's .py has `assert self.rpc_broadcast_mq is not None` which + # fires on follower nodes. Replace with a guard that returns [] (or None + # for single-return calls) so followers silently skip the broadcast. 
+ _mpexec_script = ( + "\n".join( + [ + "import pathlib", + "p = pathlib.Path('/opt/python/lib/python3.12/site-packages/vllm/v1/executor/multiproc_executor.py')", + "src = p.read_text()", + "if 'if self.rpc_broadcast_mq is None' in src:", + " print('ALREADY_PATCHED')", + "else:", + " old = (' assert self.rpc_broadcast_mq is not None, (\\n'", + " ' \"collective_rpc should not be called on follower node\"\\n'", + " ' )')", + " new = (' if self.rpc_broadcast_mq is None:\\n'", + " ' return None if (unique_reply_rank is not None or kv_output_aggregator is not None) else []')", + " if old in src:", + " p.write_text(src.replace(old, new, 1))", + " print('PATCHED')", + " else:", + " print('NOT_FOUND')", + ] + ) + + "\n" + ) + self.orch.exec("bash -c " + shlex.quote(f"printf '%s' {shlex.quote(_mpexec_script)} > /tmp/vllm_patch0b.py")) + patch_out0b = self.orch.exec("python3 /tmp/vllm_patch0b.py") + for host, out in (patch_out0b or {}).items(): + log.info("vllm multiproc_executor.py patch0b on %s: %s", host, (out or "").strip()) + + # Patch 1: Guard _initialize_kv_caches for follower nodes (node_rank > 0). + # collective_rpc requires rpc_broadcast_mq which is None on followers; + # skip to a dummy KVCacheConfig so init can proceed. 
+ _patch1_script = ( + "\n".join( + [ + "import pathlib", + "p = pathlib.Path('/opt/python/lib/python3.12/site-packages/vllm/v1/engine/core.py')", + "src = p.read_text()", + "old = ' kv_cache_config = self._initialize_kv_caches(vllm_config)\\n'", + "new = (", + " ' if vllm_config.parallel_config.node_rank_within_dp == 0:\\n'", + " ' kv_cache_config = self._initialize_kv_caches(vllm_config)\\n'", + " ' else:\\n'", + " ' vllm_config.cache_config.num_gpu_blocks = 1\\n'", + " ' kv_cache_config = KVCacheConfig(num_blocks=1, kv_cache_tensors=[], kv_cache_groups=[])\\n'", + ")", + "already = 'vllm_config.cache_config.num_gpu_blocks = 1'", + "if already in src:", + " print('ALREADY_PATCHED')", + "elif old in src:", + " p.write_text(src.replace(old, new, 1))", + " print('PATCHED')", + "else:", + " print('NOT_FOUND')", + ] + ) + + "\n" + ) + self.orch.exec("bash -c " + shlex.quote(f"printf '%s' {shlex.quote(_patch1_script)} > /tmp/vllm_patch1.py")) + patch_out = self.orch.exec("python3 /tmp/vllm_patch1.py") + for host, out in (patch_out or {}).items(): + log.info("vllm core.py patch1 on %s: %s", host, (out or "").strip()) + + # Patch 2: Guard Scheduler() creation for follower nodes. + # Scheduler.__init__ → KVCacheManager → HybridKVCacheCoordinator asserts + # len(attention_groups) > 1, but followers have kv_cache_groups=[]. + # Stub it out with _F: follower EngineCore only needs workers running. 
+ _patch2_script = ( + "\n".join( + [ + "import pathlib", + "p = pathlib.Path('/opt/python/lib/python3.12/site-packages/vllm/v1/engine/core.py')", + "src = p.read_text()", + "old = ' self.scheduler: SchedulerInterface = Scheduler(\\n'", + "new = (", + " ' if vllm_config.parallel_config.node_rank_within_dp != 0:\\n'", + " ' class _F:\\n'", + " ' connector = None\\n'", + " ' def get_kv_connector(self): return None\\n'", + " ' def __getattr__(self, n): return lambda *a, **k: None\\n'", + " ' self.scheduler = _F()\\n'", + " ' else:\\n'", + " ' self.scheduler: SchedulerInterface = Scheduler(\\n'", + ")", + "already = 'class _F:'", + "if already in src:", + " print('ALREADY_PATCHED')", + "elif old in src:", + " p.write_text(src.replace(old, new, 1))", + " print('PATCHED')", + "else:", + " print('NOT_FOUND')", + ] + ) + + "\n" + ) + self.orch.exec("bash -c " + shlex.quote(f"printf '%s' {shlex.quote(_patch2_script)} > /tmp/vllm_patch2.py")) + patch_out2 = self.orch.exec("python3 /tmp/vllm_patch2.py") + for host, out in (patch_out2 or {}).items(): + log.info("vllm core.py patch2 on %s: %s", host, (out or "").strip()) + + # Patch 3: Ensure get_supported_tasks returns ("generate",) for follower + # nodes instead of calling collective_rpc which returns [] on followers + # (causing IndexError in abstract.py:supported_tasks → output[0]). + # Two image variants are handled: + # A) Image has guard but uses SupportedTask.GENERATE (Literal, not Enum) + # → replace the bad return value. + # B) Image has bare form (no guard at all) → insert the guard block. 
def start_server(self):
    """Launch ``vllm serve`` on every host with that host's ``--node-rank``.

    Hosts are ranked by their position in ``self.orch.hosts``. For each
    ``(rank, host)`` pair one targeted ``orch.exec(..., hosts=[host])`` sources
    the server env script, starts the server in the background and redirects
    its output to that rank's own log file (see ``_rank_log``).

    Raises:
        RuntimeError: when the output returned by a launch command matches
            ``EARLY_FAILURE_RE``.
    """
    for rank, host in enumerate(self.orch.hosts):
        serve_cmd = " ".join(shlex.quote(str(arg)) for arg in self._server_argv(rank))
        rank_log = self._rank_log(rank)
        # nohup + trailing '&' detach the server so exec returns immediately.
        inner = f"source /tmp/server_env_script.sh && nohup {serve_cmd} > {shlex.quote(rank_log)} 2>&1 &"
        out = self.orch.exec("bash -c " + shlex.quote(inner), hosts=[host])
        # exec may hand back nothing; treat that as "no output to inspect"
        # rather than crashing on None.items().
        for h, output in (out or {}).items():
            text = output or ""
            if self.EARLY_FAILURE_RE.search(text):
                raise RuntimeError(f"vllm server failed to launch on {h} (rank {rank}): {text[-500:]}")
def run_client(self):
    """Launch ``vllm bench serve`` in the background on the head node only.

    The command line is assembled from the job's attributes, every argument
    is shell-quoted, and the whole thing is sent through
    ``self.orch.exec_on_head`` with stdout/stderr redirected to
    ``self.client_log``. Results are saved under ``self.out_dir`` with the
    filename ``results``.

    ``--goodput`` is only added when at least one of the ``ttft_ms`` /
    ``tpot_ms`` / ``e2el_ms`` SLO values is set, so the flag is never emitted
    without a ``metric:value`` pair following it.
    """
    args = [
        "vllm",
        "bench",
        "serve",
        "--model",
        self.model_id,
        "--backend",
        self.backend,
        "--base-url",
        f"{self.base_url}:{self.port_no}",
        "--dataset-name",
        self.dataset_name,
        "--num-prompts",
        self.num_prompts,
        "--random-input-len",
        self.isl,
        "--random-output-len",
        self.osl,
        "--max-concurrency",
        self.concurrency,
        "--request-rate",
        self.request_rate,
        "--burstiness",
        self.burstiness,
        "--tokenizer-mode",
        self.tokenizer_mode,
        "--seed",
        self.seed,
        "--random-range-ratio",
        self.random_range_ratio,
        "--random-prefix-len",
        self.random_prefix_len,
        "--percentile-metrics",
        self.percentile_metrics,
        "--metric-percentiles",
        self.metric_percentiles,
        "--ignore-eos",
        "--save-result",
        "--result-dir",
        self.out_dir,
        "--result-filename",
        "results",
    ]
    slo = self.goodput_slo or {}
    goodput_pairs = [
        f"{metric}:{slo[key]}"
        for metric, key in (("ttft", "ttft_ms"), ("tpot", "tpot_ms"), ("e2el", "e2el_ms"))
        if slo.get(key) is not None
    ]
    # A dict whose values are all None is truthy; gate on the pairs actually
    # produced so a dangling "--goodput" never reaches the client CLI.
    if goodput_pairs:
        args.append("--goodput")
        args.extend(goodput_pairs)
    bench_cmd = " ".join(shlex.quote(str(a)) for a in args)
    client_cmd = f"source /tmp/server_env_script.sh && {bench_cmd} > {shlex.quote(self.client_log)} 2>&1 &"
    self.orch.exec_on_head("bash -c " + shlex.quote(client_cmd))
a/cvs/tests/inference/vllm_distributed/conftest.py b/cvs/tests/inference/vllm_distributed/conftest.py new file mode 100644 index 00000000..4ef4186e --- /dev/null +++ b/cvs/tests/inference/vllm_distributed/conftest.py @@ -0,0 +1,204 @@ +''' +Copyright 2025 Advanced Micro Devices, Inc. +All rights reserved. +''' + +import json +import os + +import pytest + +from cvs.core.orchestrators.factory import OrchestratorConfig, OrchestratorFactory +from cvs.lib import globals +from cvs.lib.inference.utils.vllm_distributed_config_loader import load_variant +from cvs.lib.utils_lib import resolve_cluster_config_placeholders + +log = globals.log + + +def _deep_merge(base, override): + """Recursively merge `override` onto `base` (dicts merged key-wise, scalars/lists replaced). + + Protects cluster-set SCALAR and DICT container keys (e.g. shm_size, an env + map) from being wiped by a top-level replace: they survive unless the variant + overrides that same key. List keys (e.g. runtime.args, volume mounts) are + REPLACED here, not unioned -- the cluster's list values are recombined with + the variant's additively further downstream, in container.py's getters. + """ + if not (isinstance(base, dict) and isinstance(override, dict)): + return override + out = dict(base) + for k, v in override.items(): + out[k] = _deep_merge(base[k], v) if k in base else v + return out + + +@pytest.fixture(scope="module") +def cluster_dict(pytestconfig): + cluster_file = pytestconfig.getoption("cluster_file") + if not cluster_file: + pytest.fail("--cluster_file is required") + with open(cluster_file) as fp: + d = json.load(fp) + return resolve_cluster_config_placeholders(d) + + +@pytest.fixture(scope="module") +def variant_config(pytestconfig, cluster_dict): + config_file = pytestconfig.getoption("config_file") + if not config_file: + pytest.fail("--config_file is required") + return load_variant(config_file, cluster_dict) + + +class _Lifecycle: + """Cross-test state for the lifecycle-as-tests model. 
+ + The container launch / sshd / fetch / teardown stages are individual tests + (so each is a timed, pass/fail row in the HTML) rather than fixture body + code. They share this object: `failed` lets a broken stage skip the rest + instead of cascading; `torn_down` lets the explicit teardown test suppress + the fixture's leak-guard finalizer; `report` maps a test's nodeid to the + rows it recorded, each carrying its own unit, so pytest_runtest_makereport + renders only that test's stages -- not every stage on every row. + """ + + def __init__(self): + self.failed = False + self.torn_down = False + self.report = {} # nodeid -> list[(label, value, unit)] + + def record(self, nodeid, label, value, unit="s"): + self.report.setdefault(nodeid, []).append((label, value, unit)) + + +@pytest.fixture(scope="module") +def lifecycle(): + return _Lifecycle() + + +@pytest.fixture(scope="module") +def orch(cluster_dict, variant_config, lifecycle): + """Construct a ContainerOrchestrator and own ONLY its teardown safety net. + + The actual launch/sshd happen in test_launch_container / test_setup_sshd + so they appear as timed rows. This fixture builds the object and registers a + leak-guard finalizer: if a mid-sweep test fails before test_teardown runs, + the container is still torn down here. When test_teardown ran successfully + it sets lifecycle.torn_down, so the finalizer no-ops (no double teardown). + """ + # OrchestratorConfig.from_configs does a top-level dict.update, so a bare variant + # container block would wipe the cluster file's container settings. Deep-merge the + # variant ONTO the cluster block so cluster-set scalar/dict keys survive, with the + # variant winning on conflicting keys. (List keys like runtime.args are replaced + # here but recombined additively downstream in container.py's getters.) 
def pytest_collection_modifyitems(items):
    """Sort collected items into the explicit lifecycle order, in place.

    Order: container launch, sshd setup, model fetch, the benchmark cells,
    the per-metric rows, the results table, and teardown last. An item is
    matched by its ``originalname`` when that is set, otherwise by its
    ``name`` with any ``[param]`` suffix removed. Items whose name is not in
    the table all share the same trailing rank; because the sort is stable
    they keep their relative order among themselves.
    """
    stage_order = (
        "test_launch_container",
        "test_setup_sshd",
        "test_model_fetch",
        "test_vllm_inference",
        "test_metric",
        "test_print_results_table",
        "test_teardown",
    )
    position = {stage: idx for idx, stage in enumerate(stage_order)}

    def _stage_rank(item):
        base_name = item.originalname or item.name.split("[")[0]
        return position.get(base_name, 99)

    items.sort(key=_stage_rank)
+ + Renders only the rows recorded against the current item's nodeid (so each + stage shows its own timings, not every stage's), and reads the unit per row + (durations in `s`, the fetch size in `GB`) instead of a fixed "seconds" + header. Guarded: a no-op when pytest-html is not installed (the `extras` + plugin attribute is absent), so the suite still runs under a bare pytest. + """ + outcome = yield + report = outcome.get_result() + if report.when != "call": + return + lc = item.funcargs.get("lifecycle") + rows = getattr(lc, "report", {}).get(item.nodeid) if lc else None + if not rows: + return + try: + import pytest_html + except ImportError: + return + body = "".join(f"{label}{value:.1f}{unit}" for label, value, unit in rows) + html = f"{body}
stagevalueunit
def pytest_html_results_table_header(cells):
    """Add Value and Unit header cells just before the last header cell.

    Each inserted entry is a complete ``<th>`` element so the header row
    stays well-formed HTML and lines up with the ``<td>`` cells added by
    ``pytest_html_results_table_row``.
    """
    cells.insert(-1, "<th>Value</th>")
    cells.insert(-1, "<th>Unit</th>")


def pytest_html_results_table_row(report, cells):
    """Add this row's Value and Unit cells just before the last cell.

    The value and unit are read from the report's ``metric_value`` /
    ``metric_unit`` user properties. Rows that recorded no ``metric_value``
    get two empty cells so every row keeps the same column count; a recorded
    ``None`` is shown as ``-`` and floats are shown with three decimals.
    """
    props = dict(report.user_properties)
    if "metric_value" not in props:
        shown, unit = "", ""
    else:
        val = props["metric_value"]
        unit = props.get("metric_unit", "")
        if val is None:
            shown = "-"
        elif isinstance(val, float):
            shown = f"{val:.3f}"
        else:
            shown = str(val)
    # Cells must be full <td> elements, not bare text.
    cells.insert(-1, f"<td>{shown}</td>")
    cells.insert(-1, f"<td>{unit}</td>")
+''' + +import json +import os +import shlex +import time + +import pytest + +from cvs.lib import globals +from cvs.lib.inference.utils.inferencing_config_loader import GoodputSlo, validate_sweep_selector +from cvs.lib.utils.verdict import evaluate_all +from cvs.lib.inference.utils.vllm_parsing import CLIENT_METRICS as _METRICS, CLIENT_METRIC_UNITS as _METRIC_UNITS +from cvs.lib.inference.vllm_distributed import VllmDistributedJob + +import importlib.util as _ilu +import pathlib as _pl + +_spec = _ilu.spec_from_file_location("_vllm_shared", _pl.Path(__file__).parent.parent / "vllm" / "_shared.py") +_mod = _ilu.module_from_spec(_spec) +_spec.loader.exec_module(_mod) +test_print_results_table = _mod.test_print_results_table # exported as a sibling test # noqa: F841 + +log = globals.log + +# Fetch-progress poll: du the cache dir until its size stops growing. The model +# download streams in parallel shards, so size climbs then plateaus at the full +# weight set; a stable size across two polls means the fetch settled. +_FETCH_POLL_COUNT = 80 +_FETCH_POLL_WAIT_S = 30 +_FETCH_PRESENCE_RETRIES = 5 + + +def pytest_generate_tests(metafunc): + """Parametrize test_vllm_inference from the sweep's named-combo + runs selector. + + Lives in the suite module (not conftest) because it parametrizes fixtures + only test_vllm_inference consumes -- co-locating the parametrization with + its sole consumer. It runs at collection time, before fixtures exist, so it + reads the raw config_file JSON directly (it cannot use the variant_config + fixture / the typed loader). + + The sweep lists `sequence_combinations` (each with a `name`) once and a + `runs` array of `{combo, concurrency}` pairs; one case is emitted per run. + No NxM cartesian -- exactly the cells `runs` enumerates. 
+ """ + config_file = metafunc.config.getoption("config_file") + if not config_file or not os.path.isfile(config_file): + return + with open(config_file) as fp: + raw = json.load(fp) + sweep = raw.get("sweep", {}) + combos = sweep.get("sequence_combinations", []) + runs = sweep.get("runs", []) + # Validate each raw goodput_slo dict through the same _Forbid model the + # typed loader uses. pytest_generate_tests bypasses load_variant (it reads + # raw JSON at collection time), so without this a typo'd SLO key would be + # silently dropped and a wrong goodput gate would run on hardware. + for combo in combos: + if combo.get("goodput_slo") is not None: + GoodputSlo(**combo["goodput_slo"]) + # Mirror the typed Sweep validator here (this path reads raw JSON before + # load_variant runs) via the shared rule so the two cannot drift: a + # duplicate combo name or a run referencing an unknown combo must fail + # collection, not silently drop. + validate_sweep_selector([c["name"] for c in combos], [r["combo"] for r in runs]) + by_name = {c["name"]: c for c in combos} + cases = [] + ids = [] + for run in runs: + combo = by_name[run["combo"]] + conc = run["concurrency"] + cases.append((combo, conc)) + ids.append(run["combo"] + "-conc" + str(conc)) + if "metric" in metafunc.fixturenames: + if cases: + metric_cases = [] + metric_ids = [] + for (combo, c), cid in zip(cases, ids): + for short, _unit in _METRICS: + metric_cases.append((combo, c, short)) + metric_ids.append(cid + "-" + short) + metafunc.parametrize("seq_combo,concurrency,metric", metric_cases, ids=metric_ids) + elif "seq_combo" in metafunc.fixturenames and "concurrency" in metafunc.fixturenames and cases: + metafunc.parametrize("seq_combo,concurrency", cases, ids=ids) + + +def _num_prompts_for(osl, concurrency): + return str(concurrency * 20) if int(osl) >= 8192 else str(concurrency * 50) + + +def _du_bytes(orch, path): + """Total bytes under `path` inside the container, or 0 if it doesn't exist yet.""" + out = 
def test_setup_sshd(orch, lifecycle, request):
    """Stage 2: start sshd in the container and confirm it is listening.

    Skips when an earlier lifecycle stage failed. Records the ``sshd_setup``
    duration, fails (and flags ``lifecycle.failed``) when ``setup_sshd()``
    returns a falsy value, and -- only when there is more than one host --
    probes port 2224 and fails unless EVERY probed host reports it.
    """
    if lifecycle.failed:
        pytest.skip("a prior lifecycle stage failed")
    t = time.monotonic()
    ok = orch.setup_sshd()
    lifecycle.record(request.node.nodeid, "sshd_setup", time.monotonic() - t)
    if not ok:
        lifecycle.failed = True
        pytest.fail("setup_sshd() returned False")
    # Single-node runs skip starting the in-container sshd (it exists only for
    # inter-node traffic), so only probe 2224 when there is more than one host.
    if len(orch.hosts) > 1:
        # Port 2224 in hex is 08B0; check /proc/net/tcp because ss/iproute2 may not be installed.
        probe = orch.exec("bash -c 'grep -qi 08B0 /proc/net/tcp /proc/net/tcp6 2>/dev/null && echo OK || echo NO'")
        results = probe or {}
        # Every host must be listening: one healthy host must not mask a
        # peer whose sshd never came up. An empty result also counts as failure.
        not_listening = [str(host) for host, text in results.items() if "OK" not in (text or "")]
        if not results or not_listening:
            lifecycle.failed = True
            where = ", ".join(not_listening) if not_listening else "no probe output"
            pytest.fail(f"sshd not listening on 2224 after setup_sshd() ({where})")
+ fetch = ( + f"HF_HUB_CACHE={shlex.quote(models_dir)} " + f"nohup hf download {shlex.quote(variant_config.model.id)} " + f"> /tmp/hf_fetch.log 2>&1 &" + ) + orch.exec("bash -c " + shlex.quote(fetch)) + + prev = -1 + stable = 0 + final = _du_bytes(orch, models_dir) + for it in range(_FETCH_POLL_COUNT): + cur = _du_bytes(orch, models_dir) + final = cur + log.info("[fetch poll %d] size=%.1fGB", it, cur / 1e9) + if cur > 0 and cur == prev: + stable += 1 + if stable >= 2: + break + else: + stable = 0 + prev = cur + time.sleep(_FETCH_POLL_WAIT_S) + + lifecycle.record(request.node.nodeid, "model_fetch", time.monotonic() - t) + lifecycle.record(request.node.nodeid, "model_size", final / 1e9, "GB") + if final <= 0: + lifecycle.failed = True + pytest.fail(f"no model bytes under {models_dir} after fetch") + + +def test_vllm_inference(orch, variant_config, hf_token, seq_combo, concurrency, inf_res_dict, lifecycle, request): + if lifecycle.failed: + pytest.skip("a prior lifecycle stage failed") + isl = seq_combo["isl"] + osl = seq_combo["osl"] + job = VllmDistributedJob( + orch=orch, + variant=variant_config, + hf_token=hf_token, + isl=isl, + osl=osl, + concurrency=concurrency, + num_prompts=_num_prompts_for(osl, concurrency), + goodput_slo=seq_combo.get("goodput_slo"), + client_poll_count=int(variant_config.params.client_poll_count), + ) + + # A failure mid-sweep flips lifecycle.failed so the remaining cells skip + # cleanly (instead of each re-failing) AND the orch leak-guard finalizer + # still tears the container down. The explicit teardown row may not run on + # the failure path, which is exactly what the finalizer covers. 
+ try: + job.stop_server() + job.build_server_cmd() + t = time.monotonic() + job.start_server() + job.wait_ready() + lifecycle.record(request.node.nodeid, "server_ready", time.monotonic() - t) + job.run_client() + job.wait_client_complete() + results = job.parse_results() + except Exception: + lifecycle.failed = True + raise + + key = ( + variant_config.model.id, + variant_config.gpu_arch, + isl, + osl, + seq_combo.get("name", "default"), + concurrency, + ) + inf_res_dict[key] = results + # Verdict is no longer asserted here: each metric is its own test (test_metric, + # one HTML row per metric per cell). This test only runs the benchmark and + # records the cell's results into the module-scoped inf_res_dict. + + +def test_metric(seq_combo, concurrency, metric, inf_res_dict, variant_config, lifecycle, request): + """One pytest test (= one HTML row) per perf metric per cell. + + The benchmark already ran once in test_vllm_inference and stashed its results + in the module-scoped inf_res_dict; this reads a single cached metric and + surfaces it as its own pass/fail row. The value is rendered inline via the + Value/Unit table columns (pytest_html_results_table_row in conftest). No GPU + work. Skips cleanly when the cell's inference failed/skipped so a missing cell + never reports a false green. + + Verdict: when enforce_thresholds is true AND a spec exists for this cell+metric + the value is asserted via the shared evaluate_all; otherwise the row is a + record-only PASS that simply displays the number. evaluate_all is handed the + full per-cell actuals (not just this one metric) so a min_ratio spec can still + resolve its reference metric. 
def test_teardown(orch, lifecycle, request):
    """Final stage: tear the container down, time it, and assert it is gone.

    Does not check ``lifecycle.failed``, so it runs regardless of how the
    earlier stages went. ``lifecycle.torn_down`` is set only after the
    container is verified as no longer running; on failure it stays False.
    """
    cfg = orch.container_config
    container = orch.get_container_name(cfg, cfg["image"])
    started = time.monotonic()
    orch.teardown_containers()
    lifecycle.record(request.node.nodeid, "teardown", time.monotonic() - started)
    still_running = orch.verify_containers_running(container)
    if still_running:
        # Leave torn_down False so a later safety-net teardown can still run.
        pytest.fail(f"container {container} still running after teardown_containers()")
    lifecycle.torn_down = True