diff --git a/hydra/_internal/instantiate/_instantiate2.py b/hydra/_internal/instantiate/_instantiate2.py index 2f09ece868b..4b1ccd9aa7c 100644 --- a/hydra/_internal/instantiate/_instantiate2.py +++ b/hydra/_internal/instantiate/_instantiate2.py @@ -4,7 +4,7 @@ import functools from enum import Enum from textwrap import dedent -from typing import Any, Callable, Dict, List, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union from omegaconf import OmegaConf, SCMode from omegaconf._utils import is_structured_config @@ -145,7 +145,12 @@ def _resolve_target( return target -def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any: +def instantiate( + config: Any, + *args: Any, + target_wrapper: Optional[Callable[..., Any]] = None, + **kwargs: Any, +) -> Any: """ :param config: An config object describing what to call and what params to use. In addition to the parameters, the config must contain: @@ -168,6 +173,7 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any: are converted to dicts / lists too. _partial_: If True, return functools.partial wrapped method or object False by default. Configure per target. + :param target_wrapper: Optional callable wrap _target_ with before it itself is called. :param args: Optional positional parameters pass-through :param kwargs: Optional named parameters to override parameters in the config object. Parameters not present @@ -224,7 +230,12 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any: _partial_ = config.pop(_Keys.PARTIAL, False) return instantiate_node( - config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_ + config, + *args, + recursive=_recursive_, + convert=_convert_, + partial=_partial_, + target_wrapper=target_wrapper, ) elif OmegaConf.is_list(config): # Finalize config (convert targets to strings, merge with kwargs) @@ -247,7 +258,12 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any: ) return instantiate_node( - config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_ + config, + *args, + recursive=_recursive_, + convert=_convert_, + partial=_partial_, + target_wrapper=target_wrapper, ) else: raise InstantiationException( @@ -281,6 +297,7 @@ def instantiate_node( convert: Union[str, ConvertMode] = ConvertMode.NONE, recursive: bool = True, partial: bool = False, + target_wrapper: Optional[Callable[..., Any]] = None, ) -> Any: # Return None if config is None if node is None or (OmegaConf.is_config(node) and node._is_none()): @@ -314,7 +331,12 @@ def instantiate_node( # If OmegaConf list, create new list of instances if recursive if OmegaConf.is_list(node): items = [ - instantiate_node(item, convert=convert, recursive=recursive) + instantiate_node( + item, + convert=convert, + recursive=recursive, + target_wrapper=target_wrapper, + ) for item in node._iter_ex(resolve=True) ] @@ -331,6 +353,9 @@ def instantiate_node( exclude_keys = set({"_target_", "_convert_", "_recursive_", "_partial_"}) if _is_target(node): _target_ = _resolve_target(node.get(_Keys.TARGET), full_key) + if target_wrapper: + _target_ = target_wrapper(_target_) + kwargs = {} is_partial = node.get("_partial_", False) or partial for key in node.keys(): @@ -340,7 +365,10 @@ def instantiate_node( value = node[key] if recursive: value = instantiate_node( - value, convert=convert, recursive=recursive + value, + convert=convert, + recursive=recursive, + target_wrapper=target_wrapper, ) kwargs[key] = _convert_node(value, convert)