Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions test/test_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,37 @@ def test_tensordict_sequential_config(self):
assert seq.in_keys == ["observation"]
assert "action" in seq.out_keys

@pytest.mark.skipif(not _has_hydra, reason="Hydra is not installed")
def test_tanh_module_config(self):
    """Check that TanhModuleConfig stores its fields and instantiates a TanhModule."""
    from hydra.utils import instantiate

    from torchrl.modules import TanhModule
    from torchrl.trainers.algorithms.configs.modules import TanhModuleConfig

    keys = ["action"]
    cfg = TanhModuleConfig(
        in_keys=keys,
        out_keys=keys,
        low=-1.0,
        high=1.0,
        clamp=False,
    )

    # The config must point hydra at the TanhModule factory helper.
    expected_target = (
        "torchrl.trainers.algorithms.configs.modules._make_tanh_module"
    )
    assert cfg._target_ == expected_target
    assert cfg.in_keys == keys
    assert cfg.out_keys == keys
    assert cfg.low == -1.0
    assert cfg.high == 1.0
    assert cfg.clamp is False

    # Hydra instantiation must yield a real TanhModule with the configured keys.
    module = instantiate(cfg)
    assert isinstance(module, TanhModule)
    assert module.in_keys == keys
    assert module.out_keys == keys

@pytest.mark.skipif(not _has_hydra, reason="Hydra is not installed")
def test_value_model_config(self):
"""Test ValueModelConfig."""
Expand Down
3 changes: 3 additions & 0 deletions torchrl/trainers/algorithms/configs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
ConvNetConfig,
MLPConfig,
ModelConfig,
TanhModuleConfig,
TanhNormalModelConfig,
TensorDictModuleConfig,
TensorDictSequentialConfig,
Expand Down Expand Up @@ -264,6 +265,7 @@
"ConvNetConfig",
"MLPConfig",
"ModelConfig",
"TanhModuleConfig",
"TanhNormalModelConfig",
"TensorDictModuleConfig",
"TensorDictSequentialConfig",
Expand Down Expand Up @@ -441,6 +443,7 @@ def _register_configs():
cs.store(
group="network", name="tensordict_sequential", node=TensorDictSequentialConfig
)
cs.store(group="model", name="tanh_module", node=TanhModuleConfig)
cs.store(group="model", name="tanh_normal", node=TanhNormalModelConfig)
cs.store(group="model", name="value", node=ValueModelConfig)

Expand Down
39 changes: 39 additions & 0 deletions torchrl/trainers/algorithms/configs/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,29 @@ def __post_init__(self) -> None:
super().__post_init__()


@dataclass
class TanhModuleConfig(ModelConfig):
    """A class to configure a TanhModule.

    Example:
        >>> cfg = TanhModuleConfig(in_keys=["action"], out_keys=["action"], low=-1.0, high=1.0)
        >>> module = instantiate(cfg)
        >>> assert isinstance(module, TanhModule)

    .. seealso:: :class:`torchrl.modules.TanhModule`
    """

    # Optional spec forwarded to TanhModule — presumably a TensorSpec used to
    # derive output bounds; verify against TanhModule's constructor.
    spec: Any = None
    # Lower bound forwarded to TanhModule (scalar or tensor-like).
    low: Any = None
    # Upper bound forwarded to TanhModule (scalar or tensor-like).
    high: Any = None
    # Clamp flag forwarded to TanhModule's ``clamp`` argument.
    clamp: bool = False
    # Hydra target: _make_tanh_module builds the actual TanhModule instance.
    _target_: str = "torchrl.trainers.algorithms.configs.modules._make_tanh_module"

    def __post_init__(self) -> None:
        """Post-initialization hook for TanhModule configurations."""
        super().__post_init__()


def _make_tensordict_module(*args, **kwargs):
"""Helper function to create a TensorDictModule."""
from hydra.utils import instantiate
Expand Down Expand Up @@ -472,3 +495,19 @@ def _make_value_model(*args, **kwargs):
value_operator = value_operator.share_memory()

return value_operator


def _make_tanh_module(*args, **kwargs):
    """Helper function to create a TanhModule from instantiated config kwargs."""
    from omegaconf import ListConfig

    from torchrl.modules import TanhModule

    # Drop the config-level "shared" flag, which TanhModule does not accept
    # (presumably inherited from the generic model config — verify).
    kwargs.pop("shared", False)

    # Hydra may hand key lists over as omegaconf ListConfig objects; convert
    # them to plain Python lists before constructing the module.
    for key_field in ("in_keys", "out_keys"):
        value = kwargs.get(key_field)
        if isinstance(value, ListConfig):
            kwargs[key_field] = list(value)

    return TanhModule(**kwargs)