-
-
Notifications
You must be signed in to change notification settings - Fork 870
Add singleton option #3013
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add singleton option #3013
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,6 +22,7 @@ class _Keys(str, Enum): | |
| RECURSIVE = "_recursive_" | ||
| ARGS = "_args_" | ||
| PARTIAL = "_partial_" | ||
| SINGLETON = "_singleton_" | ||
|
|
||
|
|
||
| def _is_target(x: Any) -> bool: | ||
|
|
@@ -31,6 +32,12 @@ def _is_target(x: Any) -> bool: | |
| return "_target_" in x | ||
| return False | ||
|
|
||
| def _is_singleton(x: Any) -> bool: | ||
| if isinstance(x, dict): | ||
| return _Keys.SINGLETON in x | ||
| if OmegaConf.is_dict(x): | ||
| return _Keys.SINGLETON in x | ||
| return False | ||
|
|
||
| def _extract_pos_args(input_args: Any, kwargs: Any) -> Tuple[Any, Any]: | ||
| config_args = kwargs.pop(_Keys.ARGS, ()) | ||
|
|
@@ -239,6 +246,7 @@ def instantiate( | |
| if is_structured_config(config) or isinstance(config, (dict, list)): | ||
| config = OmegaConf.structured(config, flags={"allow_objects": True}) | ||
|
|
||
| singleton_registry = {} | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did you consider making the registry more global so they can be saved across instantiate calls? Would be interested to know if anyone has strong opinions on which behavior is preferred
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I though about adding it to hydra global but that caused a circular reference in the modules. It could be passed in as an optional parameter to instantiate so that someone has the option to re-use it |
||
| if OmegaConf.is_dict(config): | ||
| # Finalize config (convert targets to strings, merge with kwargs) | ||
| # Create copy to avoid mutating original | ||
|
|
@@ -262,7 +270,8 @@ def instantiate( | |
| _partial_ = config.pop(_Keys.PARTIAL, False) | ||
|
|
||
| return instantiate_node( | ||
| config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_ | ||
| config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_, | ||
| singleton_registry=singleton_registry | ||
| ) | ||
| elif OmegaConf.is_list(config): | ||
| # Finalize config (convert targets to strings, merge with kwargs) | ||
|
|
@@ -289,7 +298,8 @@ def instantiate( | |
| ) | ||
|
|
||
| return instantiate_node( | ||
| config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_ | ||
| config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_, | ||
| singleton_registry=singleton_registry | ||
| ) | ||
| else: | ||
| raise InstantiationException( | ||
|
|
@@ -323,6 +333,7 @@ def instantiate_node( | |
| convert: Union[str, ConvertMode] = ConvertMode.NONE, | ||
| recursive: bool = True, | ||
| partial: bool = False, | ||
| singleton_registry: dict = {} | ||
| ) -> Any: | ||
| # Return None if config is None | ||
| if node is None or (OmegaConf.is_config(node) and node._is_none()): | ||
|
|
@@ -356,7 +367,8 @@ def instantiate_node( | |
| # If OmegaConf list, create new list of instances if recursive | ||
| if OmegaConf.is_list(node): | ||
| items = [ | ||
| instantiate_node(item, convert=convert, recursive=recursive) | ||
| instantiate_node(item, convert=convert, recursive=recursive, | ||
| singleton_registry=singleton_registry) | ||
| for item in node._iter_ex(resolve=True) | ||
| ] | ||
|
|
||
|
|
@@ -370,9 +382,18 @@ def instantiate_node( | |
| return lst | ||
|
|
||
| elif OmegaConf.is_dict(node): | ||
| exclude_keys = set({"_target_", "_convert_", "_recursive_", "_partial_"}) | ||
| exclude_keys = set({"_target_", "_convert_", "_recursive_", "_partial_", "_singleton_"}) | ||
| if _is_target(node): | ||
| _target_ = _resolve_target(node.get(_Keys.TARGET), full_key) | ||
| # check if this singleton has already been initialized | ||
| is_singleton = _is_singleton(node) | ||
| singleton_name = None | ||
| if is_singleton: | ||
| singleton_name = instantiate_node(node.get(_Keys.SINGLETON), | ||
| singleton_registry=singleton_registry) | ||
| if singleton_name in singleton_registry: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can/should we enforce
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. that makes sense, will update |
||
| return singleton_registry[singleton_name] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm concerned this could get confusing if two different dicts have the same
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ya that is a bit concerning and tricky to catch, checking the dicts are the same feels wrong but would be good to check to give atleast a warning. |
||
|
|
||
| kwargs = {} | ||
| is_partial = node.get("_partial_", False) or partial | ||
| for key in node.keys(): | ||
|
|
@@ -382,11 +403,15 @@ def instantiate_node( | |
| value = node[key] | ||
| if recursive: | ||
| value = instantiate_node( | ||
| value, convert=convert, recursive=recursive | ||
| value, convert=convert, recursive=recursive, | ||
| singleton_registry=singleton_registry | ||
| ) | ||
| kwargs[key] = _convert_node(value, convert) | ||
|
|
||
| return _call_target(_target_, partial, args, kwargs, full_key) | ||
| value = _call_target(_target_, partial, args, kwargs, full_key) | ||
| if is_singleton: | ||
| singleton_registry[singleton_name] = value | ||
| return value | ||
| else: | ||
| # If ALL or PARTIAL non structured or OBJECT non structured, | ||
| # instantiate in dict and resolve interpolations eagerly. | ||
|
|
@@ -398,15 +423,17 @@ def instantiate_node( | |
| for key, value in node.items(): | ||
| # list items inherits recursive flag from the containing dict. | ||
| dict_items[key] = instantiate_node( | ||
| value, convert=convert, recursive=recursive | ||
| value, convert=convert, recursive=recursive, | ||
| singleton_registry=singleton_registry | ||
| ) | ||
| return dict_items | ||
| else: | ||
| # Otherwise use DictConfig and resolve interpolations lazily. | ||
| cfg = OmegaConf.create({}, flags={"allow_objects": True}) | ||
| for key, value in node.items(): | ||
| cfg[key] = instantiate_node( | ||
| value, convert=convert, recursive=recursive | ||
| value, convert=convert, recursive=recursive, | ||
| singleton_registry=singleton_registry | ||
| ) | ||
| cfg._set_parent(node) | ||
| cfg._metadata.object_type = node._metadata.object_type | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: "singleton_id" feels more descriptive that it should be a string value, not bool