From 19ec42a7409dc6f6f7103ecc35bb3e4e2490f290 Mon Sep 17 00:00:00 2001 From: Matthew Grange Date: Thu, 26 Feb 2026 11:10:42 -0800 Subject: [PATCH 1/2] Rename get_max_parallelism to get_max_concurrency in AxClient (#4923) Summary: Renames `AxClient.get_max_parallelism()` to `get_max_concurrency()` and updates internal variable names, comments, and docstrings to use "concurrency" terminology. The old `get_max_parallelism` is preserved as a deprecated stub raising `NotImplementedError`. Also updates `get_recommended_max_parallelism` to point to the new name, and imports `MaxParallelismReachedException` / `MaxGenerationParallelism` under concurrency-named aliases. `get_max_parallelism` is only used directly in ad-hoc notebooks, making this a low-risk rename Differential Revision: D93771849 --- ax/service/ax_client.py | 53 +++++++++++++++++------------- ax/service/tests/test_ax_client.py | 30 +++++++++-------- 2 files changed, 47 insertions(+), 36 deletions(-) diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index cb49ed1e39c..88915410d3a 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -42,10 +42,14 @@ UnsupportedPlotError, UserInputError, ) -from ax.exceptions.generation_strategy import MaxParallelismReachedException +from ax.exceptions.generation_strategy import ( + MaxParallelismReachedException as MaxConcurrencyReachedException, +) from ax.generation_strategy.dispatch_utils import choose_generation_strategy_legacy from ax.generation_strategy.generation_strategy import GenerationStrategy -from ax.generation_strategy.transition_criterion import MaxGenerationParallelism +from ax.generation_strategy.transition_criterion import ( + MaxGenerationParallelism as MaxGenerationConcurrency, +) from ax.global_stopping.strategies.base import BaseGlobalStoppingStrategy from ax.global_stopping.strategies.improvement import constraint_satisfaction from ax.plot.base import AxPlotConfig @@ -570,7 +574,7 @@ def get_next_trial( ), ttl_seconds=ttl_seconds, ) - except MaxParallelismReachedException as e: + except MaxConcurrencyReachedException as e: if self._early_stopping_strategy is not None: e.message += ( # noqa: B306 " When stopping trials early, make sure to call `stop_trial_early` " @@ -836,39 +840,39 @@ def get_trials_data_frame(self) -> pd.DataFrame: """ return self.experiment.to_df() - def get_max_parallelism(self) -> list[tuple[int, int]]: - """Retrieves maximum number of trials that can be scheduled in parallel + def get_max_concurrency(self) -> list[tuple[int, int]]: + """Retrieves maximum number of trials that can be scheduled concurrently at different stages of optimization. Some optimization algorithms profit significantly from sequential optimization (i.e. suggest a few points, get updated with data for them, repeat, see https://ax.dev/docs/bayesopt.html). - Parallelism setting indicates how many trials should be running simulteneously + Concurrency setting indicates how many trials should be running simultaneously (generated, but not yet completed with data). The output of this method is mapping of form - {num_trials -> max_parallelism_setting}, where the max_parallelism_setting - is used for num_trials trials. If max_parallelism_setting is -1, as - many of the trials can be ran in parallel, as necessary. If num_trials - in a tuple is -1, then the corresponding max_parallelism_setting + {num_trials -> max_concurrency_setting}, where the max_concurrency_setting + is used for num_trials trials. If max_concurrency_setting is -1, as + many of the trials can be ran concurrently, as necessary. If num_trials + in a tuple is -1, then the corresponding max_concurrency_setting should be used for all subsequent trials. For example, if the returned list is [(5, -1), (12, 6), (-1, 3)], - the schedule could be: run 5 trials with any parallelism, run 6 trials in - parallel twice, run 3 trials in parallel for as long as needed. Here, + the schedule could be: run 5 trials with any concurrency, run 6 trials + concurrently twice, run 3 trials concurrently for as long as needed. Here, 'running' a trial means obtaining a next trial from `AxClient` through get_next_trials and completing it with data when available. Returns: - Mapping of form {num_trials -> max_parallelism_setting}. + Mapping of form {num_trials -> max_concurrency_setting}. """ - parallelism_settings = [] + concurrency_settings = [] for node in self.generation_strategy._nodes: - # Extract max_parallelism from MaxGenerationParallelism criterion - max_parallelism = None + # Extract max_concurrency from MaxGenerationConcurrency criterion + max_concurrency = None for tc in node.transition_criteria: - if isinstance(tc, MaxGenerationParallelism): - max_parallelism = tc.threshold + if isinstance(tc, MaxGenerationConcurrency): + max_concurrency = tc.threshold break # Try to get num_trials from the node. If there's no MinTrials # criterion (unlimited trials), num_trials will raise UserInputError. @@ -877,13 +881,16 @@ def get_max_parallelism(self) -> list[tuple[int, int]]: num_trials = node.num_trials except UserInputError: num_trials = -1 - parallelism_settings.append( + concurrency_settings.append( ( num_trials, - max_parallelism if max_parallelism is not None else num_trials, + max_concurrency if max_concurrency is not None else num_trials, ) ) - return parallelism_settings + return concurrency_settings + + def get_max_parallelism(self) -> list[tuple[int, int]]: + raise NotImplementedError("Use `get_max_concurrency` instead.") def get_optimization_trace( self, objective_optimum: float | None = None @@ -1702,8 +1709,8 @@ def __repr__(self) -> str: @staticmethod def get_recommended_max_parallelism() -> None: raise NotImplementedError( - "Use `get_max_parallelism` instead; parallelism levels are now " - "enforced in generation strategy, so max parallelism is no longer " + "Use `get_max_concurrency` instead; concurrency levels are now " + "enforced in generation strategy, so max concurrency is no longer " "just recommended." ) diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index cd6b36470d1..3e95ab5f74b 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -49,7 +49,9 @@ UnsupportedPlotError, UserInputError, ) -from ax.exceptions.generation_strategy import MaxParallelismReachedException +from ax.exceptions.generation_strategy import ( + MaxParallelismReachedException as MaxConcurrencyReachedException, +) from ax.generation_strategy.dispatch_utils import DEFAULT_BAYESIAN_CONCURRENCY from ax.generation_strategy.generation_strategy import ( GenerationNode, @@ -58,7 +60,7 @@ ) from ax.generation_strategy.generator_spec import GeneratorSpec from ax.generation_strategy.transition_criterion import ( - MaxGenerationParallelism, + MaxGenerationParallelism as MaxGenerationConcurrency, MinTrials, ) from ax.metrics.branin import branin, BraninMetric @@ -1616,14 +1618,14 @@ def test_keep_generating_without_data(self) -> None: self.assertTrue(len(node0_min_trials) > 0) self.assertFalse(node0_min_trials[0].block_gen_if_met) - # Check that max_parallelism is None by verifying no MaxGenerationParallelism + # Check that max_concurrency is None by verifying no MaxGenerationConcurrency # criterion exists on node 1 - node1_max_parallelism = [ + node1_max_concurrency = [ tc for tc in ax_client.generation_strategy._nodes[1].transition_criteria - if isinstance(tc, MaxGenerationParallelism) + if isinstance(tc, MaxGenerationConcurrency) ] - self.assertEqual(len(node1_max_parallelism), 0) + self.assertEqual(len(node1_max_concurrency), 0) for _ in range(10): ax_client.get_next_trial() @@ -1939,17 +1941,17 @@ def test_relative_oc_without_sq(self) -> None: def test_recommended_parallelism(self) -> None: ax_client = AxClient() with self.assertRaisesRegex(AssertionError, "No generation strategy"): - ax_client.get_max_parallelism() + ax_client.get_max_concurrency() ax_client.create_experiment( parameters=[ {"name": "x", "type": "range", "bounds": [-5.0, 10.0]}, {"name": "y", "type": "range", "bounds": [0.0, 15.0]}, ], ) - self.assertEqual(ax_client.get_max_parallelism(), [(5, 5), (-1, 3)]) + self.assertEqual(ax_client.get_max_concurrency(), [(5, 5), (-1, 3)]) self.assertEqual( run_trials_using_recommended_parallelism( - ax_client, ax_client.get_max_parallelism(), 20 + ax_client, ax_client.get_max_concurrency(), 20 ), 0, ) @@ -2320,6 +2322,8 @@ def test_deprecated_save_load_method_errors(self) -> None: ax_client.load_experiment("test_experiment") with self.assertRaises(NotImplementedError): ax_client.get_recommended_max_parallelism() + with self.assertRaises(NotImplementedError): + ax_client.get_max_parallelism() def test_find_last_trial_with_parameterization(self) -> None: ax_client = AxClient() @@ -2872,7 +2876,7 @@ def test_estimate_early_stopping_savings(self) -> None: self.assertEqual(ax_client.estimate_early_stopping_savings(), 0) - def test_max_parallelism_exception_when_early_stopping(self) -> None: + def test_max_concurrency_exception_when_early_stopping(self) -> None: ax_client = AxClient() ax_client.create_experiment( parameters=[ @@ -2882,7 +2886,7 @@ def test_max_parallelism_exception_when_early_stopping(self) -> None: support_intermediate_data=True, ) - exception = MaxParallelismReachedException(step_index=1, num_running=10) + exception = MaxConcurrencyReachedException(step_index=1, num_running=10) # pyre-fixme[53]: Captured variable `exception` is not annotated. def fake_new_trial(*args: Any, **kwargs: Any) -> None: @@ -2892,7 +2896,7 @@ def fake_new_trial(*args: Any, **kwargs: Any) -> None: ax_client.experiment.new_trial = fake_new_trial # Without early stopping. - with self.assertRaises(MaxParallelismReachedException) as cm: + with self.assertRaises(MaxConcurrencyReachedException) as cm: ax_client.get_next_trial() # Assert Exception's message is unchanged. self.assertEqual(cm.exception.message, exception.message) @@ -2900,7 +2904,7 @@ def fake_new_trial(*args: Any, **kwargs: Any) -> None: # With early stopping. ax_client._early_stopping_strategy = DummyEarlyStoppingStrategy() # Assert Exception's message is augmented to mention early stopping. - with self.assertRaisesRegex(MaxParallelismReachedException, ".*early.*stop"): + with self.assertRaisesRegex(MaxConcurrencyReachedException, ".*early.*stop"): ax_client.get_next_trial() def test_experiment_does_not_support_early_stopping(self) -> None: From b25c1597db3902f904e7d63d0c4f916f3d550c33 Mon Sep 17 00:00:00 2001 From: Matthew Grange Date: Thu, 26 Feb 2026 11:10:42 -0800 Subject: [PATCH 2/2] Documentation update: use concurrency terminology in benchmark docstrings Summary: Updates docstrings in `BenchmarkMethod`, `BenchmarkExecutionSettings`, and `nightly.py` to use "concurrency" terminology instead of "parallelism" where appropriate. Also applies a formatting fix to a multi-line `BenchmarkExecutionSettings(...)` call in `nightly.py`. No interface or behavioral changes. Differential Revision: D93771883 --- ax/benchmark/benchmark_method.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ax/benchmark/benchmark_method.py b/ax/benchmark/benchmark_method.py index bea5463bd2f..1abb4c5c361 100644 --- a/ax/benchmark/benchmark_method.py +++ b/ax/benchmark/benchmark_method.py @@ -16,7 +16,7 @@ class BenchmarkMethod(Base): """Benchmark method, represented in terms of Ax generation strategy (which tells us which models to use when) and Orchestrator options (which tell us extra execution - information like maximum parallelism, early stopping configuration, etc.). + information like maximum concurrency, early stopping configuration, etc.). Args: name: String description. Defaults to the name of the generation strategy.