Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ax/benchmark/benchmark_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
class BenchmarkMethod(Base):
"""Benchmark method, represented in terms of Ax generation strategy (which tells us
which models to use when) and Orchestrator options (which tell us extra execution
information like maximum parallelism, early stopping configuration, etc.).
information like maximum concurrency, early stopping configuration, etc.).

Args:
name: String description. Defaults to the name of the generation strategy.
Expand Down
53 changes: 30 additions & 23 deletions ax/service/ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,14 @@
UnsupportedPlotError,
UserInputError,
)
from ax.exceptions.generation_strategy import MaxParallelismReachedException
from ax.exceptions.generation_strategy import (
MaxParallelismReachedException as MaxConcurrencyReachedException,
)
from ax.generation_strategy.dispatch_utils import choose_generation_strategy_legacy
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.generation_strategy.transition_criterion import MaxGenerationParallelism
from ax.generation_strategy.transition_criterion import (
MaxGenerationParallelism as MaxGenerationConcurrency,
)
from ax.global_stopping.strategies.base import BaseGlobalStoppingStrategy
from ax.global_stopping.strategies.improvement import constraint_satisfaction
from ax.plot.base import AxPlotConfig
Expand Down Expand Up @@ -570,7 +574,7 @@ def get_next_trial(
),
ttl_seconds=ttl_seconds,
)
except MaxParallelismReachedException as e:
except MaxConcurrencyReachedException as e:
if self._early_stopping_strategy is not None:
e.message += ( # noqa: B306
" When stopping trials early, make sure to call `stop_trial_early` "
Expand Down Expand Up @@ -836,39 +840,39 @@ def get_trials_data_frame(self) -> pd.DataFrame:
"""
return self.experiment.to_df()

def get_max_parallelism(self) -> list[tuple[int, int]]:
"""Retrieves maximum number of trials that can be scheduled in parallel
def get_max_concurrency(self) -> list[tuple[int, int]]:
"""Retrieves maximum number of trials that can be scheduled concurrently
at different stages of optimization.

Some optimization algorithms profit significantly from sequential
optimization (i.e. suggest a few points, get updated with data for them,
repeat, see https://ax.dev/docs/bayesopt.html).
Parallelism setting indicates how many trials should be running simulteneously
Concurrency setting indicates how many trials should be running simultaneously
(generated, but not yet completed with data).

The output of this method is a mapping of the form
{num_trials -> max_parallelism_setting}, where the max_parallelism_setting
is used for num_trials trials. If max_parallelism_setting is -1, as
many of the trials can be ran in parallel, as necessary. If num_trials
in a tuple is -1, then the corresponding max_parallelism_setting
{num_trials -> max_concurrency_setting}, where the max_concurrency_setting
is used for num_trials trials. If max_concurrency_setting is -1, as
many trials as necessary can be run concurrently. If num_trials
in a tuple is -1, then the corresponding max_concurrency_setting
should be used for all subsequent trials.

For example, if the returned list is [(5, -1), (12, 6), (-1, 3)],
the schedule could be: run 5 trials with any parallelism, run 6 trials in
parallel twice, run 3 trials in parallel for as long as needed. Here,
the schedule could be: run 5 trials with any concurrency, run 6 trials
concurrently twice, run 3 trials concurrently for as long as needed. Here,
'running' a trial means obtaining a next trial from `AxClient` through
get_next_trials and completing it with data when available.

Returns:
Mapping of form {num_trials -> max_parallelism_setting}.
Mapping of form {num_trials -> max_concurrency_setting}.
"""
parallelism_settings = []
concurrency_settings = []
for node in self.generation_strategy._nodes:
# Extract max_parallelism from MaxGenerationParallelism criterion
max_parallelism = None
# Extract max_concurrency from MaxGenerationConcurrency criterion
max_concurrency = None
for tc in node.transition_criteria:
if isinstance(tc, MaxGenerationParallelism):
max_parallelism = tc.threshold
if isinstance(tc, MaxGenerationConcurrency):
max_concurrency = tc.threshold
break
# Try to get num_trials from the node. If there's no MinTrials
# criterion (unlimited trials), num_trials will raise UserInputError.
Expand All @@ -877,13 +881,16 @@ def get_max_parallelism(self) -> list[tuple[int, int]]:
num_trials = node.num_trials
except UserInputError:
num_trials = -1
parallelism_settings.append(
concurrency_settings.append(
(
num_trials,
max_parallelism if max_parallelism is not None else num_trials,
max_concurrency if max_concurrency is not None else num_trials,
)
)
return parallelism_settings
return concurrency_settings

def get_max_parallelism(self) -> list[tuple[int, int]]:
raise NotImplementedError("Use `get_max_concurrency` instead.")

def get_optimization_trace(
self, objective_optimum: float | None = None
Expand Down Expand Up @@ -1702,8 +1709,8 @@ def __repr__(self) -> str:
@staticmethod
def get_recommended_max_parallelism() -> None:
raise NotImplementedError(
"Use `get_max_parallelism` instead; parallelism levels are now "
"enforced in generation strategy, so max parallelism is no longer "
"Use `get_max_concurrency` instead; concurrency levels are now "
"enforced in generation strategy, so max concurrency is no longer "
"just recommended."
)

Expand Down
30 changes: 17 additions & 13 deletions ax/service/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@
UnsupportedPlotError,
UserInputError,
)
from ax.exceptions.generation_strategy import MaxParallelismReachedException
from ax.exceptions.generation_strategy import (
MaxParallelismReachedException as MaxConcurrencyReachedException,
)
from ax.generation_strategy.dispatch_utils import DEFAULT_BAYESIAN_CONCURRENCY
from ax.generation_strategy.generation_strategy import (
GenerationNode,
Expand All @@ -58,7 +60,7 @@
)
from ax.generation_strategy.generator_spec import GeneratorSpec
from ax.generation_strategy.transition_criterion import (
MaxGenerationParallelism,
MaxGenerationParallelism as MaxGenerationConcurrency,
MinTrials,
)
from ax.metrics.branin import branin, BraninMetric
Expand Down Expand Up @@ -1616,14 +1618,14 @@ def test_keep_generating_without_data(self) -> None:
self.assertTrue(len(node0_min_trials) > 0)
self.assertFalse(node0_min_trials[0].block_gen_if_met)

# Check that max_parallelism is None by verifying no MaxGenerationParallelism
# Check that max_concurrency is None by verifying no MaxGenerationConcurrency
# criterion exists on node 1
node1_max_parallelism = [
node1_max_concurrency = [
tc
for tc in ax_client.generation_strategy._nodes[1].transition_criteria
if isinstance(tc, MaxGenerationParallelism)
if isinstance(tc, MaxGenerationConcurrency)
]
self.assertEqual(len(node1_max_parallelism), 0)
self.assertEqual(len(node1_max_concurrency), 0)

for _ in range(10):
ax_client.get_next_trial()
Expand Down Expand Up @@ -1939,17 +1941,17 @@ def test_relative_oc_without_sq(self) -> None:
def test_recommended_parallelism(self) -> None:
ax_client = AxClient()
with self.assertRaisesRegex(AssertionError, "No generation strategy"):
ax_client.get_max_parallelism()
ax_client.get_max_concurrency()
ax_client.create_experiment(
parameters=[
{"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
{"name": "y", "type": "range", "bounds": [0.0, 15.0]},
],
)
self.assertEqual(ax_client.get_max_parallelism(), [(5, 5), (-1, 3)])
self.assertEqual(ax_client.get_max_concurrency(), [(5, 5), (-1, 3)])
self.assertEqual(
run_trials_using_recommended_parallelism(
ax_client, ax_client.get_max_parallelism(), 20
ax_client, ax_client.get_max_concurrency(), 20
),
0,
)
Expand Down Expand Up @@ -2320,6 +2322,8 @@ def test_deprecated_save_load_method_errors(self) -> None:
ax_client.load_experiment("test_experiment")
with self.assertRaises(NotImplementedError):
ax_client.get_recommended_max_parallelism()
with self.assertRaises(NotImplementedError):
ax_client.get_max_parallelism()

def test_find_last_trial_with_parameterization(self) -> None:
ax_client = AxClient()
Expand Down Expand Up @@ -2872,7 +2876,7 @@ def test_estimate_early_stopping_savings(self) -> None:

self.assertEqual(ax_client.estimate_early_stopping_savings(), 0)

def test_max_parallelism_exception_when_early_stopping(self) -> None:
def test_max_concurrency_exception_when_early_stopping(self) -> None:
ax_client = AxClient()
ax_client.create_experiment(
parameters=[
Expand All @@ -2882,7 +2886,7 @@ def test_max_parallelism_exception_when_early_stopping(self) -> None:
support_intermediate_data=True,
)

exception = MaxParallelismReachedException(step_index=1, num_running=10)
exception = MaxConcurrencyReachedException(step_index=1, num_running=10)

# pyre-fixme[53]: Captured variable `exception` is not annotated.
def fake_new_trial(*args: Any, **kwargs: Any) -> None:
Expand All @@ -2892,15 +2896,15 @@ def fake_new_trial(*args: Any, **kwargs: Any) -> None:
ax_client.experiment.new_trial = fake_new_trial

# Without early stopping.
with self.assertRaises(MaxParallelismReachedException) as cm:
with self.assertRaises(MaxConcurrencyReachedException) as cm:
ax_client.get_next_trial()
# Assert Exception's message is unchanged.
self.assertEqual(cm.exception.message, exception.message)

# With early stopping.
ax_client._early_stopping_strategy = DummyEarlyStoppingStrategy()
# Assert Exception's message is augmented to mention early stopping.
with self.assertRaisesRegex(MaxParallelismReachedException, ".*early.*stop"):
with self.assertRaisesRegex(MaxConcurrencyReachedException, ".*early.*stop"):
ax_client.get_next_trial()

def test_experiment_does_not_support_early_stopping(self) -> None:
Expand Down