diff --git a/hydra/_internal/core_plugins/basic_sweeper.py b/hydra/_internal/core_plugins/basic_sweeper.py index 3e41fcb5529..c10ca0edc45 100644 --- a/hydra/_internal/core_plugins/basic_sweeper.py +++ b/hydra/_internal/core_plugins/basic_sweeper.py @@ -25,10 +25,11 @@ from typing import Any, Dict, Iterable, List, Optional, Sequence from omegaconf import DictConfig, OmegaConf +from omegaconf._utils import is_structured_config from hydra.core.config_store import ConfigStore from hydra.core.override_parser.overrides_parser import OverridesParser -from hydra.core.override_parser.types import Override +from hydra.core.override_parser.types import Override, QuotedString from hydra.core.utils import JobReturn from hydra.errors import HydraException from hydra.plugins.launcher import Launcher @@ -93,6 +94,78 @@ def setup( config=config, ) + @staticmethod + def simplify_overrides( + overrides: List[Override], + ) -> List[Override]: + # this would simplify the overrides by removing those that are overridden later + # in the list. + # e.g. a=1 and later a=10 would remove the first override. + lists = [] + # NOTE: key -> index of last override with no dict value. (e.g. a=1,2,3) + # any override for key before this would be skipped. + last_primitive = {} + last_dict = {} + last_defaults: Dict[str, int] = {} + + is_defaults: Dict[int, bool] = {} + is_primitive: Dict[int, bool] = {} + has_dict: Dict[int, bool] = {} + + # check value should override earlier ones + # TODO: handle extend_list + def check_write_override(x: Any): + return ( + isinstance(x, (str, int, float, bool, list, QuotedString)) or x is None + ) + + def check_has_dict(x: Any): + return isinstance(x, dict) or is_structured_config(x) + + for i, override in enumerate(overrides): + if override.config_loader is None: + continue + is_group = ( + len(override.config_loader.get_group_options(override.key_or_group)) > 0 + ) + + key = override.get_key_element() + _write = False + _has_dict = False + if override.is_sweep_override(): + if override.is_discrete_sweep(): + _write = all(override.sweep_iterator(check_write_override)) + _has_dict = any(override.sweep_iterator(check_has_dict)) + else: + _write = check_write_override(override.value()) + _has_dict = check_has_dict(override.value()) + + if _write: + if is_group: + is_defaults[i] = True + if override.is_change(): + last_defaults[key] = i + else: + is_primitive[i] = True + last_primitive[key] = i + if _has_dict: + has_dict[i] = True + last_dict[key] = i + + for i, override in enumerate(overrides): + key = override.get_key_element() + if is_primitive.get(i, False) and ( + last_primitive.get(key, -1) != i or last_dict.get(key, -1) > i + ): + continue + if has_dict.get(i, False) and last_primitive.get(key, -1) > i: + continue + if is_defaults.get(i, False) and last_defaults.get(key, -1) > i: + continue + lists.append(override) + + return lists + @staticmethod def split_overrides_to_chunks( lst: List[List[str]], n: Optional[int] @@ -108,13 +181,13 @@ def split_arguments( overrides: List[Override], max_batch_size: Optional[int] ) -> List[List[List[str]]]: lists = [] - final_overrides = OrderedDict() + overrides = BasicSweeper.simplify_overrides(overrides) for override in overrides: if override.is_sweep_override(): if override.is_discrete_sweep(): key = override.get_key_element() sweep = [f"{key}={val}" for val in override.sweep_string_iterator()] - final_overrides[key] = sweep + lists.append(sweep) else: assert override.value_type is not None raise HydraException( @@ -123,10 +196,7 @@ def split_arguments( else: key = override.get_key_element() value = override.get_value_element_as_str() - final_overrides[key] = [f"{key}={value}"] - - for _, v in final_overrides.items(): - lists.append(v) + lists.append([f"{key}={value}"]) all_batches = [list(x) for x in itertools.product(*lists)] assert max_batch_size is None or max_batch_size > 0 diff --git a/hydra/core/override_parser/types.py b/hydra/core/override_parser/types.py index f020ccb6c6b..d813bb7edfb 100644 --- a/hydra/core/override_parser/types.py +++ b/hydra/core/override_parser/types.py @@ -274,6 +274,12 @@ class Override: # Configs repo config_loader: Optional[ConfigLoader] = None + def is_change(self) -> bool: + """ + :return: True if this override represents a change of a config value or config group option + """ + return self.type == OverrideType.CHANGE + def is_delete(self) -> bool: """ :return: True if this override represents a deletion of a config value or config group option diff --git a/tests/test_basic_sweeper.py b/tests/test_basic_sweeper.py index 0d2b0dc89a0..1eb9b2d5fba 100644 --- a/tests/test_basic_sweeper.py +++ b/tests/test_basic_sweeper.py @@ -6,7 +6,9 @@ from pytest import mark, param +from hydra._internal.config_loader_impl import ConfigLoaderImpl from hydra._internal.core_plugins.basic_sweeper import BasicSweeper +from hydra._internal.utils import create_config_search_path from hydra.core.override_parser.overrides_parser import OverridesParser from hydra.test_utils.test_utils import assert_multiline_regex_search, run_process @@ -48,12 +50,65 @@ ), param(["a=range(0,3)"], None, [[["a=0"], ["a=1"], ["a=2"]]], id="range"), param(["a=range(3)"], None, [[["a=0"], ["a=1"], ["a=2"]]], id="range_no_start"), + param(["a=1,2,3", "a=20"], None, [[["a=20"]]], id="override_same_key1"), + param( + ["a=2", "a=10,20"], None, [[["a=10"], ["a=20"]]], id="override_same_key2" + ), + param( + ["a=1,2,3", "a=10,20"], + None, + [[["a=10"], ["a=20"]]], + id="override_same_key3", + ), + param(["a={x:1},{x:2}"], None, [[["a={x:1}"], ["a={x:2}"]]], id="dicts"), + param( + ["a={x:1},{x:2}", "+a={y:10},{y:20}"], + None, + [ + [ + ["a={x:1}", "+a={y:10}"], + ["a={x:1}", "+a={y:20}"], + ["a={x:2}", "+a={y:10}"], + ["a={x:2}", "+a={y:20}"], + ] + ], + id="dicts_multiple_with_plus", + ), + param( + ["a={x:1},{x:2}", "a={y:10},{y:20}"], + None, + [ + [ + ["a={x:1}", "a={y:10}"], + ["a={x:1}", "a={y:20}"], + ["a={x:2}", "a={y:10}"], + ["a={x:2}", "a={y:20}"], + ] + ], + id="dicts_multiple", + ), + param(["a=1,2,3", "a={x:1}"], None, [[["a={x:1}"]]], id="override_with_dict1"), + param( + ["a=1,2,3", "a={x:1},{x:2}"], + None, + [[["a={x:1}"], ["a={x:2}"]]], + id="override_with_dict2", + ), + param( + ["a={x:1}", "a=1,2,3"], + None, + [[["a=1"], ["a=2"], ["a=3"]]], + id="override_with_dict3", + ), + param(["a={x:1},{x:2}", "a=1"], None, [[["a=1"]]], id="override_with_dict4"), ], ) def test_split( args: List[str], max_batch_size: Optional[int], expected: List[List[List[str]]] ) -> None: - parser = OverridesParser.create() + + config_loader = ConfigLoaderImpl(config_search_path=create_config_search_path(None)) + parser = OverridesParser.create(config_loader) ret = BasicSweeper.split_arguments( parser.parse_overrides(args), max_batch_size=max_batch_size ) @@ -61,6 +116,40 @@ def test_split( assert lret == expected +@mark.parametrize( + "args,expected", + [ + param(["a=1", "b=2", "a=3"], ["b=2", "a=3"], id="simple_override"), + param(["a=1,2", "a=3,4"], ["a=3,4"], id="override_split"), + param( + ["a=1", "b=2", "+a={x:10}", "+a={y:20}"], + ["a=1", "b=2", "+a={x:10}", "+a={y:20}"], + id="override_plus", + ), + param( + ["a=1", "b=2", "a={x:10}", "a={y:20}"], + ["b=2", "a={x:10}", "a={y:20}"], + id="override_plus", + ), + param(["a={x:1}", "a={y:2}"], ["a={x:1}", "a={y:2}"], id="override_dict"), + param( + ["a=1,2", "+a={x:10},{y:20}", "a=3,4"], + ["+a={x:10},{y:20}", "a=3,4"], + id="override_mixed", + ), + param(["a=1,2", "a={x:10},{y:20}", "a=3,4"], ["a=3,4"], id="override_mixed"), + param(["+a=xx,yy", "+a=[zz]"], ["+a=[zz]"], id="override_plus_list"), + ], +) +def test_simplify(args: List[str], expected: List[str]) -> None: + config_loader = ConfigLoaderImpl(config_search_path=create_config_search_path(None)) + parser = OverridesParser.create(config_loader) + overrides = parser.parse_overrides(args) + simplified = BasicSweeper.simplify_overrides(overrides) + expected_overrides = parser.parse_overrides(expected) + assert simplified == expected_overrides + + def test_partial_failure( tmpdir: Any, ) -> None: diff --git a/tests/test_examples/test_basic_sweep.py b/tests/test_examples/test_basic_sweep.py index 8e1ac120ea5..1bbf8b2f25e 100644 --- a/tests/test_examples/test_basic_sweep.py +++ b/tests/test_examples/test_basic_sweep.py @@ -37,9 +37,9 @@ dedent( """\ [HYDRA] Launching 2 jobs locally - [HYDRA] \t#0 : db=mysql db.timeout=5 + [HYDRA] \t#0 : db.timeout=5 db=mysql driver=mysql, timeout=5 - [HYDRA] \t#1 : db=mysql db.timeout=10 + [HYDRA] \t#1 : db.timeout=10 db=mysql driver=mysql, timeout=10""" ), ), @@ -48,13 +48,13 @@ dedent( """\ [HYDRA] Launching 4 jobs locally - [HYDRA] \t#0 : db=mysql db.timeout=5 db.user=one + [HYDRA] \t#0 : db.timeout=5 db=mysql db.user=one driver=mysql, timeout=5 - [HYDRA] \t#1 : db=mysql db.timeout=5 db.user=two + [HYDRA] \t#1 : db.timeout=5 db=mysql db.user=two driver=mysql, timeout=5 - [HYDRA] \t#2 : db=mysql db.timeout=10 db.user=one + [HYDRA] \t#2 : db.timeout=10 db=mysql db.user=one driver=mysql, timeout=10 - [HYDRA] \t#3 : db=mysql db.timeout=10 db.user=two + [HYDRA] \t#3 : db.timeout=10 db=mysql db.user=two driver=mysql, timeout=10""" ), ),