Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 46 additions & 2 deletions ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import inspect
import logging
import warnings
from collections import defaultdict
from collections import defaultdict, OrderedDict
from collections.abc import Hashable, Iterable, Mapping, Sequence
from copy import deepcopy
from datetime import datetime
Expand All @@ -31,7 +31,7 @@
from ax.core.batch_trial import BatchTrial
from ax.core.data import combine_data_rows_favoring_recent, Data
from ax.core.experiment_status import ExperimentStatus
from ax.core.generator_run import GeneratorRun
from ax.core.generator_run import ArmWeight, GeneratorRun
from ax.core.llm_provider import LLMMessage
from ax.core.metric import Metric, MetricFetchE, MetricFetchResult
from ax.core.objective import MultiObjective
Expand Down Expand Up @@ -1949,6 +1949,7 @@ def clone_with(
properties_to_keep: list[str] | None = None,
trial_indices: list[int] | None = None,
clear_trial_type: bool = False,
filter_arm_params_to_search_space: bool = False,
) -> Experiment:
r"""
Return a copy of this experiment with some attributes replaced.
Expand Down Expand Up @@ -1977,6 +1978,10 @@ def clone_with(
clones all trials.
clear_trial_type: If True, all cloned trials on the cloned experiment have
`trial_type` set to `None`.
filter_arm_params_to_search_space: If True and a new search_space is
provided, filter each cloned arm's parameters to only include
parameters present in the new search space. This enables reducing
the search space while keeping arm data compatible.
"""
if properties_to_keep is None:
properties_to_keep = ["owners"]
Expand Down Expand Up @@ -2009,6 +2014,19 @@ def clone_with(
if (status_quo is None and self.status_quo is not None)
else status_quo
)
# Filter status_quo params when reducing search space
if (
filter_arm_params_to_search_space
and status_quo is not None
and search_space is not None
):
filtered_sq_params = {
k: v
for k, v in status_quo.parameters.items()
if k in search_space.parameters
}
status_quo = Arm(parameters=filtered_sq_params, name=status_quo.name)

description = self.description if description is None else description
is_test = self.is_test if is_test is None else is_test

Expand All @@ -2035,6 +2053,10 @@ def clone_with(
properties=properties,
)

params_to_keep: set[str] | None = None
if filter_arm_params_to_search_space and search_space is not None:
params_to_keep = set(search_space.parameters.keys())

# Clone only the specified trials.
original_trial_indices = self.trials.keys()
trial_indices_to_keep = (
Expand All @@ -2060,6 +2082,8 @@ def clone_with(
new_trial = trial.clone_to(
cloned_experiment, clear_trial_type=clear_trial_type
)
if params_to_keep is not None:
_filter_trial_arm_params(new_trial, params_to_keep)
new_index = new_trial.index
old_index_to_new_index[trial_index] = new_index

Expand Down Expand Up @@ -2399,6 +2423,26 @@ def auxiliary_experiments_by_purpose_for_storage(
return result


def _filter_trial_arm_params(
trial: Trial | BatchTrial,
params_to_keep: set[str],
) -> None:
"""Filter arm parameters in-place to only include specified parameters.

Replaces each arm in every generator run with a new Arm whose parameters
contain only the keys present in ``params_to_keep``.
"""
for gr in trial.generator_runs:
new_table: OrderedDict[str, ArmWeight] = OrderedDict()
for _sig, aw in gr._arm_weight_table.items():
filtered = {
k: v for k, v in aw.arm.parameters.items() if k in params_to_keep
}
new_arm = Arm(parameters=filtered, name=aw.arm.name)
new_table[new_arm.signature] = ArmWeight(arm=new_arm, weight=aw.weight)
gr._arm_weight_table = new_table


def add_arm_and_prevent_naming_collision(
new_trial: Trial, old_trial: Trial, old_experiment_name: str | None = None
) -> None:
Expand Down
38 changes: 38 additions & 0 deletions ax/core/tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -1454,6 +1454,44 @@ def test_clone_with(self) -> None:
cloned_experiment = experiment.clone_with(clear_trial_type=True)
self.assertIsNone(cloned_experiment.trials[0].trial_type)

# Test cloning with filter_arm_params_to_search_space
experiment = get_branin_experiment(
with_completed_trial=True,
)
reduced_search_space = SearchSpace(
parameters=[
RangeParameter(
name="x1",
parameter_type=ParameterType.FLOAT,
lower=-5.0,
upper=10.0,
),
],
)
cloned_experiment = experiment.clone_with(
search_space=reduced_search_space,
filter_arm_params_to_search_space=True,
)
for trial in cloned_experiment.trials.values():
for arm in trial.arms:
self.assertEqual(set(arm.parameters.keys()), {"x1"})
self.assertNotIn("x2", arm.parameters)
self.assertEqual(set(cloned_experiment.search_space.parameters.keys()), {"x1"})
# Verify data is preserved
original_data = experiment.lookup_data()
cloned_data = cloned_experiment.lookup_data()
self.assertEqual(len(cloned_data.df), len(original_data.df))

# Verify that filter_arm_params_to_search_space=False preserves all params
cloned_no_filter = experiment.clone_with(
search_space=reduced_search_space,
filter_arm_params_to_search_space=False,
)
for trial in cloned_no_filter.trials.values():
for arm in trial.arms:
self.assertIn("x1", arm.parameters)
self.assertIn("x2", arm.parameters)

# Test cloning with specific properties to keep
experiment_w_props = get_branin_experiment()
experiment_w_props._properties = {
Expand Down