Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 35 additions & 8 deletions hydra/_internal/instantiate/_instantiate2.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class _Keys(str, Enum):
RECURSIVE = "_recursive_"
ARGS = "_args_"
PARTIAL = "_partial_"
SINGLETON = "_singleton_"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: "singleton_id" feels more descriptive that it should be a string value, not bool



def _is_target(x: Any) -> bool:
Expand All @@ -31,6 +32,12 @@ def _is_target(x: Any) -> bool:
return "_target_" in x
return False

def _is_singleton(x: Any) -> bool:
if isinstance(x, dict):
return _Keys.SINGLETON in x
if OmegaConf.is_dict(x):
return _Keys.SINGLETON in x
return False

def _extract_pos_args(input_args: Any, kwargs: Any) -> Tuple[Any, Any]:
config_args = kwargs.pop(_Keys.ARGS, ())
Expand Down Expand Up @@ -239,6 +246,7 @@ def instantiate(
if is_structured_config(config) or isinstance(config, (dict, list)):
config = OmegaConf.structured(config, flags={"allow_objects": True})

singleton_registry = {}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you consider making the registry more global so they can be saved across instantiate calls? Would be interested to know if anyone has strong opinions on which behavior is preferred

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I though about adding it to hydra global but that caused a circular reference in the modules. It could be passed in as an optional parameter to instantiate so that someone has the option to re-use it

if OmegaConf.is_dict(config):
# Finalize config (convert targets to strings, merge with kwargs)
# Create copy to avoid mutating original
Expand All @@ -262,7 +270,8 @@ def instantiate(
_partial_ = config.pop(_Keys.PARTIAL, False)

return instantiate_node(
config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_
config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_,
singleton_registry=singleton_registry
)
elif OmegaConf.is_list(config):
# Finalize config (convert targets to strings, merge with kwargs)
Expand All @@ -289,7 +298,8 @@ def instantiate(
)

return instantiate_node(
config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_
config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_,
singleton_registry=singleton_registry
)
else:
raise InstantiationException(
Expand Down Expand Up @@ -323,6 +333,7 @@ def instantiate_node(
convert: Union[str, ConvertMode] = ConvertMode.NONE,
recursive: bool = True,
partial: bool = False,
singleton_registry: dict = {}
) -> Any:
# Return None if config is None
if node is None or (OmegaConf.is_config(node) and node._is_none()):
Expand Down Expand Up @@ -356,7 +367,8 @@ def instantiate_node(
# If OmegaConf list, create new list of instances if recursive
if OmegaConf.is_list(node):
items = [
instantiate_node(item, convert=convert, recursive=recursive)
instantiate_node(item, convert=convert, recursive=recursive,
singleton_registry=singleton_registry)
for item in node._iter_ex(resolve=True)
]

Expand All @@ -370,9 +382,18 @@ def instantiate_node(
return lst

elif OmegaConf.is_dict(node):
exclude_keys = set({"_target_", "_convert_", "_recursive_", "_partial_"})
exclude_keys = set({"_target_", "_convert_", "_recursive_", "_partial_", "_singleton_"})
if _is_target(node):
_target_ = _resolve_target(node.get(_Keys.TARGET), full_key)
# check if this singleton has already been initialized
is_singleton = _is_singleton(node)
singleton_name = None
if is_singleton:
singleton_name = instantiate_node(node.get(_Keys.SINGLETON),
singleton_registry=singleton_registry)
if singleton_name in singleton_registry:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can/should we enforce singleton_name is a string?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that makes sense, will update

return singleton_registry[singleton_name]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm concerned this could get confusing if two different dicts have the same singleton_name and it's not totally obvious which will be used. Since this primarily seems to be wanted with omegaconf resolution, the dicts should be identical - what do you think of storing the original node in singleton_registry and doing an extra check that this node's fields also match the original?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ya that is a bit concerning and tricky to catch, checking the dicts are the same feels wrong but would be good to check to give atleast a warning.


kwargs = {}
is_partial = node.get("_partial_", False) or partial
for key in node.keys():
Expand All @@ -382,11 +403,15 @@ def instantiate_node(
value = node[key]
if recursive:
value = instantiate_node(
value, convert=convert, recursive=recursive
value, convert=convert, recursive=recursive,
singleton_registry=singleton_registry
)
kwargs[key] = _convert_node(value, convert)

return _call_target(_target_, partial, args, kwargs, full_key)
value = _call_target(_target_, partial, args, kwargs, full_key)
if is_singleton:
singleton_registry[singleton_name] = value
return value
else:
# If ALL or PARTIAL non structured or OBJECT non structured,
# instantiate in dict and resolve interpolations eagerly.
Expand All @@ -398,15 +423,17 @@ def instantiate_node(
for key, value in node.items():
# list items inherits recursive flag from the containing dict.
dict_items[key] = instantiate_node(
value, convert=convert, recursive=recursive
value, convert=convert, recursive=recursive,
singleton_registry=singleton_registry
)
return dict_items
else:
# Otherwise use DictConfig and resolve interpolations lazily.
cfg = OmegaConf.create({}, flags={"allow_objects": True})
for key, value in node.items():
cfg[key] = instantiate_node(
value, convert=convert, recursive=recursive
value, convert=convert, recursive=recursive,
singleton_registry=singleton_registry
)
cfg._set_parent(node)
cfg._metadata.object_type = node._metadata.object_type
Expand Down