diff --git a/.github/workflows/deepretro_tests.yml b/.github/workflows/deepretro_tests.yml index adab4976..41ac5ccf 100644 --- a/.github/workflows/deepretro_tests.yml +++ b/.github/workflows/deepretro_tests.yml @@ -55,9 +55,16 @@ jobs: - name: Install deepretro run: | python -m pip install --upgrade pip setuptools wheel - pip install -e "deepretro/[dev]" + if [ "${{ matrix.python-version }}" = "3.11" ]; then + pip install -e "deepretro/[dev,az]" + else + pip install -e "deepretro/[dev]" + fi - name: Run tests run: | - pytest deepretro/tests/ -v - + if [ "${{ matrix.python-version }}" = "3.11" ]; then + pytest deepretro/tests/ -v + else + pytest deepretro/tests/ -v -m "not slow" + fi diff --git a/deepretro/pyproject.toml b/deepretro/pyproject.toml index 8b16e040..e9b2057e 100644 --- a/deepretro/pyproject.toml +++ b/deepretro/pyproject.toml @@ -29,6 +29,12 @@ docs = [ "sphinx-autodoc-typehints>=3.0", "myst-parser>=4.0", ] +az = [ + "aizynthfinder>=4.4.0", + "numpy>=1.26,<2", + "pandas>=2.2,<3", + "pillow", +] [project.urls] Repository = "https://github.com/deepforestsci/DeepRetro" diff --git a/deepretro/tests/test_az.py b/deepretro/tests/test_az.py index 8fc8902a..4f1191cc 100644 --- a/deepretro/tests/test_az.py +++ b/deepretro/tests/test_az.py @@ -16,13 +16,20 @@ from pathlib import Path import pytest - -# Skip entire module if aizynthfinder is not installed -aizynthfinder = pytest.importorskip("aizynthfinder") +from deepretro.utils.cache import CacheManager PROJECT_ROOT = Path(__file__).resolve().parents[2] AZ_MODULE_PATH = PROJECT_ROOT / "deepretro" / "utils" / "az.py" AZ_MODULE_NAME = "_deepretro_utils_az_under_test" +_BASIC_ROUTE = [ + { + "type": "mol", + "hide": False, + "smiles": "CCO", + "is_chemical": True, + "in_stock": True, + } +] def _download_aizynth_models(models_dir: Path) -> None: @@ -74,16 +81,6 @@ def az_module_with_models(az_module, az_models_dir): return az_module -def _get_run_az_impl(az_module): - """Get the underlying run_az implementation (bypassing cache).""" - return getattr(az_module.run_az, "__wrapped__", az_module.run_az) - - -def _get_run_az_with_img_impl(az_module): - """Get the underlying run_az_with_img implementation (bypassing cache).""" - return getattr(az_module.run_az_with_img, "__wrapped__", az_module.run_az_with_img) - - @pytest.mark.parametrize( ("smiles", "expected"), [ @@ -102,8 +99,7 @@ def test_is_basic_molecule_chemical_thresholds(az_module, smiles, expected): @pytest.mark.slow def test_run_az_returns_valid_result(az_module_with_models): """Test run_az with real AiZynthFinder returns (bool, list) with valid structure.""" - run_az = _get_run_az_impl(az_module_with_models) - status, routes = run_az("C1CCCCC1", az_model="USPTO") + status, routes = az_module_with_models.run_az("C1CCCCC1", az_model="USPTO") assert isinstance(status, bool) assert isinstance(routes, list) # Routes may be empty or contain dicts with expected keys @@ -115,9 +111,8 @@ def test_run_az_returns_valid_result(az_module_with_models): @pytest.mark.slow def test_run_az_uses_fallback_config_when_model_missing(az_module_with_models): """Test run_az falls back to AZ_MODEL_CONFIG_PATH when model-specific config missing.""" - run_az = _get_run_az_impl(az_module_with_models) # MISSING_MODEL has no config in AZ_MODELS_PATH, so uses fallback - status, routes = run_az("C1CCCCC1", az_model="MISSING_MODEL") + status, routes = az_module_with_models.run_az("C1CCCCC1", az_model="MISSING_MODEL") assert isinstance(status, bool) assert isinstance(routes, list) @@ -130,32 +125,46 @@ def test_run_az_raises_if_no_config_available(tmp_path, monkeypatch, az_module): monkeypatch.setattr(az_module, "AZ_MODEL_CONFIG_PATH", str(missing_fallback)) monkeypatch.setattr(az_module, "BASIC_MOLECULES", []) - run_az = _get_run_az_impl(az_module) with pytest.raises(FileNotFoundError, match=re.escape(str(missing_fallback))): - run_az("CCCCC", az_model="MISSING_MODEL") + az_module.run_az("CCCCC", az_model="MISSING_MODEL") -def test_run_az_short_circuits_for_feedstock_smiles(az_module_with_models): +def test_run_az_short_circuits_for_feedstock_smiles(az_module): """Test run_az returns early for basic molecules without calling AiZynthFinder.""" - run_az = _get_run_az_impl(az_module_with_models) - status, routes = run_az("CCO", az_model="USPTO") + status, routes = az_module.run_az("CCO", az_model="USPTO") + assert status is True + assert routes == _BASIC_ROUTE + + +def test_run_az_has_no_cache_side_effect_without_cache(az_module): + """run_az should not create cache state unless the caller passes one in.""" + cache = CacheManager() + + status, routes = az_module.run_az("CCO", az_model="USPTO") + assert status is True - assert routes == [ - { - "type": "mol", - "hide": False, - "smiles": "CCO", - "is_chemical": True, - "in_stock": True, - } - ] + assert routes == _BASIC_ROUTE + assert cache.stats().num_entries == 0 + + +def test_run_az_uses_explicit_cache(az_module): + """run_az should reuse results only when an explicit cache is provided.""" + cache = CacheManager() + + first = az_module.run_az("CCO", az_model="USPTO", cache=cache) + second = az_module.run_az("CCO", az_model="USPTO", cache=cache) + + assert first == second + assert cache.stats().hits == 1 + assert cache.stats().misses == 1 + assert cache.stats().num_entries == 1 + cache.close() @pytest.mark.slow def test_run_az_with_img_returns_valid_result(az_module_with_models): """Test run_az_with_img with real AiZynthFinder returns (bool, list, images).""" - run_az_with_img = _get_run_az_with_img_impl(az_module_with_models) - status, routes, images = run_az_with_img("C1CCCCC1") + status, routes, images = az_module_with_models.run_az_with_img("C1CCCCC1") assert isinstance(status, bool) assert isinstance(routes, list) # images can be list of PIL Images or None @@ -163,18 +172,24 @@ def test_run_az_with_img_returns_valid_result(az_module_with_models): assert isinstance(images, (list, tuple)) -def test_run_az_with_img_short_circuits_for_basic_molecules(az_module_with_models): +def test_run_az_with_img_short_circuits_for_basic_molecules(az_module): """Test run_az_with_img returns early for basic molecules.""" - run_az_with_img = _get_run_az_with_img_impl(az_module_with_models) - status, routes, images = run_az_with_img("CCO") + status, routes, images = az_module.run_az_with_img("CCO") assert status is True - assert routes == [ - { - "type": "mol", - "hide": False, - "smiles": "CCO", - "is_chemical": True, - "in_stock": True, - } - ] + assert routes == _BASIC_ROUTE assert images is None + + +def test_run_az_with_img_uses_explicit_cache(az_module): + """run_az_with_img should also cache only through the provided instance.""" + cache = CacheManager() + + first = az_module.run_az_with_img("CCO", cache=cache) + second = az_module.run_az_with_img("CCO", cache=cache) + + assert first == second + assert first == (True, _BASIC_ROUTE, None) + assert cache.stats().hits == 1 + assert cache.stats().misses == 1 + assert cache.stats().num_entries == 1 + cache.close() diff --git a/deepretro/utils/az.py b/deepretro/utils/az.py index 197eeae6..3c9a3635 100644 --- a/deepretro/utils/az.py +++ b/deepretro/utils/az.py @@ -3,52 +3,63 @@ Runs AiZynthFinder on target molecules, with optional image export. Uses ZINC stock and USPTO expansion/filter policies by default. Requires ``AZ_MODEL_CONFIG_PATH`` or ``AZ_MODELS_PATH`` environment variables. +Caching is opt-in through an explicit ``CacheManager`` argument. """ +from __future__ import annotations + +import importlib import os -from aizynthfinder.aizynthfinder import AiZynthFinder -from typing import Any, Dict, Sequence -from src.variables import BASIC_MOLECULES -from src.cache import cache_results -import rootutils +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, Sequence, cast + +import structlog from rdkit import Chem from rdkit.Chem import rdqueries -from PIL.Image import Image -root_dir = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) +from deepretro.utils.cache import CacheManager, make_cache_key +from deepretro.utils.variables import BASIC_MOLECULES -ENABLE_LOGGING = ( - False if os.getenv("ENABLE_LOGGING", "true").lower() == "false" else True -) +if TYPE_CHECKING: + from PIL.Image import Image -# Paths from env; required for AiZynthFinder config and model files -AZ_MODEL_CONFIG_PATH = f"{root_dir}/{os.getenv('AZ_MODEL_CONFIG_PATH')}" -AZ_MODELS_PATH = f"{root_dir}/{os.getenv('AZ_MODELS_PATH')}" +logger = structlog.get_logger() +PROJECT_ROOT = Path(__file__).resolve().parents[2] -def _log(message: str, logger=None): - """Log the message +AZ_MODEL_CONFIG_PATH = f"{PROJECT_ROOT}/{os.getenv('AZ_MODEL_CONFIG_PATH')}" +AZ_MODELS_PATH = f"{PROJECT_ROOT}/{os.getenv('AZ_MODELS_PATH')}" - Parameters - ---------- - message : str - The message to be logged - logger : _type_, optional - The logger object, by default None - Returns - ------- - None - """ - if logger is not None: - logger.info(message) - else: - print(message) +def _basic_molecule_route(smiles: str) -> list[Dict[str, Any]]: + """Return the solved route payload for a feedstock/basic molecule.""" + return [ + { + "type": "mol", + "hide": False, + "smiles": smiles, + "is_chemical": True, + "in_stock": True, + } + ] + + +def _get_aizynthfinder_cls(): + """Return the AiZynthFinder class or raise a helpful optional-dependency error.""" + try: + module = importlib.import_module("aizynthfinder.aizynthfinder") + except ImportError as exc: + raise ImportError( + "AiZynthFinder support requires optional dependencies. " + "Install the package with `deepretro[az]`." + ) from exc + return module.AiZynthFinder -@cache_results def run_az( - smiles: str, az_model: str = "USPTO" + smiles: str, + az_model: str = "USPTO", + cache: CacheManager | None = None, ) -> tuple[bool, Sequence[Dict[str, Any]]]: """Run the retrosynthesis using AiZynthFinder. @@ -66,18 +77,39 @@ def run_az( az_model : str, optional AiZynthFinder model variant (e.g. ``"USPTO"``, ``"Pistachio_50"``), by default ``"USPTO"``. + cache : CacheManager | None, optional + Explicit cache instance used to memoize results for this call. When + ``None``, no cache is read or written. Returns ------- tuple[bool, Sequence[Dict[str, Any]]] ``(solved, routes)`` — whether a route was found and the route data. + + Notes + ----- + Install the package with ``deepretro[az]``. Caching is disabled unless an + explicit ``cache=CacheManager(...)`` is supplied. """ + cache_key = make_cache_key("run_az", smiles, az_model=az_model, version=1) + cache_miss = object() + if cache is not None: + cached_result = cache.get(cache_key, default=cache_miss) + if cached_result is not cache_miss: + return cast(tuple[bool, Sequence[Dict[str, Any]]], cached_result) + + if smiles in BASIC_MOLECULES or is_basic_molecule(smiles): + result = (True, _basic_molecule_route(smiles)) + if cache is not None: + cache.set(cache_key, result, tag=smiles) + return result + try: config_path = f"{AZ_MODELS_PATH}/{az_model}/config.yml" with open(config_path, "r") as _: config_filename = config_path except FileNotFoundError: - _log(f"AZ_MODEL_CONFIG_PATH not found at {config_path}") + logger.warning("AZ config not found, trying fallback", path=config_path) try: with open(AZ_MODEL_CONFIG_PATH, "r") as _: config_filename = AZ_MODEL_CONFIG_PATH @@ -85,18 +117,8 @@ def run_az( raise FileNotFoundError( f"AZ_MODEL_CONFIG_PATH not found at {AZ_MODEL_CONFIG_PATH}" ) - # if simple molecule, skip the retrosynthesis - if smiles in BASIC_MOLECULES or is_basic_molecule(smiles): - return True, [ - { - "type": "mol", - "hide": False, - "smiles": smiles, - "is_chemical": True, - "in_stock": True, - } - ] - finder = AiZynthFinder(configfile=config_filename) + ai_zynth_finder_cls = _get_aizynthfinder_cls() + finder = ai_zynth_finder_cls(configfile=config_filename) finder.stock.select("zinc") finder.expansion_policy.select("uspto") finder.filter_policy.select("uspto") @@ -108,12 +130,15 @@ def run_az( result_dict = finder.routes.dict_with_extra( include_metadata=True, include_scores=True ) - return status, result_dict + result = (status, result_dict) + if cache is not None: + cache.set(cache_key, result, tag=smiles) + return result -@cache_results def run_az_with_img( smiles: str, + cache: CacheManager | None = None, ) -> tuple[bool, Sequence[Dict[str, Any]], Sequence[Image | None] | None]: """Run the retrosynthesis using AiZynthFinder. @@ -128,29 +153,39 @@ def run_az_with_img( ---------- smiles : str SMILES string of the target molecule. + cache : CacheManager | None, optional + Explicit cache instance used to memoize results for this call. When + ``None``, no cache is read or written. Returns ------- tuple[bool, Sequence[Dict[str, Any]], Sequence[Image] | None] ``(solved, routes, images)`` — solved status, route data, and optional route images (PNG bytes). Uses ``AZ_MODEL_CONFIG_PATH``. + + Notes + ----- + Install the package with ``deepretro[az]``. Caching is disabled unless an + explicit ``cache=CacheManager(...)`` is supplied. """ - # if simple molecule, skip the retrosynthesis + cache_key = make_cache_key("run_az_with_img", smiles, version=1) + cache_miss = object() + if cache is not None: + cached_result = cache.get(cache_key, default=cache_miss) + if cached_result is not cache_miss: + return cast( + tuple[bool, Sequence[Dict[str, Any]], Any], + cached_result, + ) + if smiles in BASIC_MOLECULES or is_basic_molecule(smiles): - return ( - True, - [ - { - "type": "mol", - "hide": False, - "smiles": smiles, - "is_chemical": True, - "in_stock": True, - } - ], - None, - ) - finder = AiZynthFinder(configfile=AZ_MODEL_CONFIG_PATH) + result = (True, _basic_molecule_route(smiles), None) + if cache is not None: + cache.set(cache_key, result, tag=smiles) + return result + + ai_zynth_finder_cls = _get_aizynthfinder_cls() + finder = ai_zynth_finder_cls(configfile=AZ_MODEL_CONFIG_PATH) finder.stock.select("zinc") finder.expansion_policy.select("uspto") finder.filter_policy.select("uspto") @@ -163,7 +198,10 @@ def run_az_with_img( include_metadata=True, include_scores=True ) images: Sequence[Image | None] = finder.routes.images - return status, result_dict, images + result = (status, result_dict, images) + if cache is not None: + cache.set(cache_key, result, tag=smiles) + return result def is_basic_molecule(smiles: str) -> bool: diff --git a/docs/source/package/deepretro.utils.az.rst b/docs/source/package/deepretro.utils.az.rst index 31b7e1d3..9aa4dc6f 100644 --- a/docs/source/package/deepretro.utils.az.rst +++ b/docs/source/package/deepretro.utils.az.rst @@ -3,14 +3,38 @@ deepretro.utils.az AiZynthFinder integration helpers for template-based retrosynthesis. +Installation +------------ + +Install the package with ``deepretro[az]`` before using this module: + +.. code-block:: bash + + uv pip install "deepretro[az]" + What This Module Does --------------------- - Runs AiZynthFinder search for a target SMILES. - Returns route dictionaries with metadata/scores. - Provides optional image outputs for generated routes. +- Supports optional explicit caching through ``CacheManager``. - Short-circuits simple molecules to avoid unnecessary search overhead. +Caching +------- + +``run_az`` and ``run_az_with_img`` do not cache anything unless the caller +passes a cache instance explicitly: + +.. code-block:: python + + from deepretro.utils.az import run_az + from deepretro.utils.cache import CacheManager + + cache = CacheManager() + solved, routes = run_az("C1CCCCC1", az_model="USPTO", cache=cache) + Configuration -------------