From f4dc02c0f9ce1ccb6c66254b98ba96c02a6dee33 Mon Sep 17 00:00:00 2001 From: vlado ovtcharov Date: Wed, 22 Jan 2025 17:39:59 -0500 Subject: [PATCH] Add singleton option --- hydra/_internal/instantiate/_instantiate2.py | 43 ++++++++++++++++---- 1 file changed, 35 insertions(+), 8 deletions(-) diff --git a/hydra/_internal/instantiate/_instantiate2.py b/hydra/_internal/instantiate/_instantiate2.py index bc279182749..08508064384 100644 --- a/hydra/_internal/instantiate/_instantiate2.py +++ b/hydra/_internal/instantiate/_instantiate2.py @@ -22,6 +22,7 @@ class _Keys(str, Enum): RECURSIVE = "_recursive_" ARGS = "_args_" PARTIAL = "_partial_" + SINGLETON = "_singleton_" def _is_target(x: Any) -> bool: @@ -31,6 +32,12 @@ def _is_target(x: Any) -> bool: return "_target_" in x return False +def _is_singleton(x: Any) -> bool: + if isinstance(x, dict): + return _Keys.SINGLETON in x + if OmegaConf.is_dict(x): + return _Keys.SINGLETON in x + return False def _extract_pos_args(input_args: Any, kwargs: Any) -> Tuple[Any, Any]: config_args = kwargs.pop(_Keys.ARGS, ()) @@ -239,6 +246,7 @@ def instantiate( if is_structured_config(config) or isinstance(config, (dict, list)): config = OmegaConf.structured(config, flags={"allow_objects": True}) + singleton_registry = {} if OmegaConf.is_dict(config): # Finalize config (convert targets to strings, merge with kwargs) # Create copy to avoid mutating original @@ -262,7 +270,8 @@ def instantiate( _partial_ = config.pop(_Keys.PARTIAL, False) return instantiate_node( - config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_ + config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_, + singleton_registry=singleton_registry ) elif OmegaConf.is_list(config): # Finalize config (convert targets to strings, merge with kwargs) @@ -289,7 +298,8 @@ def instantiate( ) return instantiate_node( - config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_ + config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_, + singleton_registry=singleton_registry ) else: raise InstantiationException( @@ -323,6 +333,7 @@ def instantiate_node( convert: Union[str, ConvertMode] = ConvertMode.NONE, recursive: bool = True, partial: bool = False, + singleton_registry: dict = {} ) -> Any: # Return None if config is None if node is None or (OmegaConf.is_config(node) and node._is_none()): @@ -356,7 +367,8 @@ def instantiate_node( # If OmegaConf list, create new list of instances if recursive if OmegaConf.is_list(node): items = [ - instantiate_node(item, convert=convert, recursive=recursive) + instantiate_node(item, convert=convert, recursive=recursive, + singleton_registry=singleton_registry) for item in node._iter_ex(resolve=True) ] @@ -370,9 +382,18 @@ def instantiate_node( return lst elif OmegaConf.is_dict(node): - exclude_keys = set({"_target_", "_convert_", "_recursive_", "_partial_"}) + exclude_keys = set({"_target_", "_convert_", "_recursive_", "_partial_", "_singleton_"}) if _is_target(node): _target_ = _resolve_target(node.get(_Keys.TARGET), full_key) + # check if this singleton has already been initialized + is_singleton = _is_singleton(node) + singleton_name = None + if is_singleton: + singleton_name = instantiate_node(node.get(_Keys.SINGLETON), + singleton_registry=singleton_registry) + if singleton_name in singleton_registry: + return singleton_registry[singleton_name] + kwargs = {} is_partial = node.get("_partial_", False) or partial for key in node.keys(): @@ -382,11 +403,15 @@ def instantiate_node( value = node[key] if recursive: value = instantiate_node( - value, convert=convert, recursive=recursive + value, convert=convert, recursive=recursive, + singleton_registry=singleton_registry ) kwargs[key] = _convert_node(value, convert) - return _call_target(_target_, partial, args, kwargs, full_key) + value = _call_target(_target_, partial, args, kwargs, full_key) + if is_singleton: + singleton_registry[singleton_name] = value + return value else: # If ALL or PARTIAL non structured or OBJECT non structured, # instantiate in dict and resolve interpolations eagerly. @@ -398,7 +423,8 @@ def instantiate_node( for key, value in node.items(): # list items inherits recursive flag from the containing dict. dict_items[key] = instantiate_node( - value, convert=convert, recursive=recursive + value, convert=convert, recursive=recursive, + singleton_registry=singleton_registry ) return dict_items else: @@ -406,7 +432,8 @@ def instantiate_node( cfg = OmegaConf.create({}, flags={"allow_objects": True}) for key, value in node.items(): cfg[key] = instantiate_node( - value, convert=convert, recursive=recursive + value, convert=convert, recursive=recursive, + singleton_registry=singleton_registry ) cfg._set_parent(node) cfg._metadata.object_type = node._metadata.object_type