Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/deepretro_lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ jobs:

- name: Run ty check
run: |
uv run --project deepretro --with ty ty check deepretro \
uv run --isolated --no-project --with ty ty check deepretro \
--python-version 3.12 \
--ignore unresolved-import \
--ignore invalid-method-override \
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@ coverage/
docs/build/
docs/_build/
*.egg-info/
deepretro/build/
.worktrees/
6 changes: 4 additions & 2 deletions deepretro/algorithms/stability_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@
* `is_valid_smiles` — quick check that a SMILES string parses.
"""

from __future__ import annotations

from typing import Any

from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem.rdMolDescriptors import (
Expand All @@ -82,8 +86,6 @@
CalcNumBridgeheadAtoms,
)

from typing import Any

# Helpers


Expand Down
4 changes: 4 additions & 0 deletions deepretro/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ docs = [
"sphinx-autodoc-typehints>=3.0",
"myst-parser>=4.0",
]
az = [
"aizynthfinder",
"pillow",
]

[project.urls]
Repository = "https://github.com/deepforestsci/DeepRetro"
Expand Down
97 changes: 57 additions & 40 deletions deepretro/tests/test_az.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,23 @@
from pathlib import Path

import pytest
from deepretro.utils.cache import CacheManager

# Skip entire module if aizynthfinder is not installed
aizynthfinder = pytest.importorskip("aizynthfinder")

PROJECT_ROOT = Path(__file__).resolve().parents[2]
AZ_MODULE_PATH = PROJECT_ROOT / "deepretro" / "utils" / "az.py"
AZ_MODULE_NAME = "_deepretro_utils_az_under_test"
_BASIC_ROUTE = [
{
"type": "mol",
"hide": False,
"smiles": "CCO",
"is_chemical": True,
"in_stock": True,
}
]


def _download_aizynth_models(models_dir: Path) -> None:
Expand Down Expand Up @@ -74,16 +84,6 @@ def az_module_with_models(az_module, az_models_dir):
return az_module


def _get_run_az_impl(az_module):
"""Get the underlying run_az implementation (bypassing cache)."""
return getattr(az_module.run_az, "__wrapped__", az_module.run_az)


def _get_run_az_with_img_impl(az_module):
"""Get the underlying run_az_with_img implementation (bypassing cache)."""
return getattr(az_module.run_az_with_img, "__wrapped__", az_module.run_az_with_img)


@pytest.mark.parametrize(
("smiles", "expected"),
[
Expand All @@ -102,8 +102,7 @@ def test_is_basic_molecule_chemical_thresholds(az_module, smiles, expected):
@pytest.mark.slow
def test_run_az_returns_valid_result(az_module_with_models):
"""Test run_az with real AiZynthFinder returns (bool, list) with valid structure."""
run_az = _get_run_az_impl(az_module_with_models)
status, routes = run_az("C1CCCCC1", az_model="USPTO")
status, routes = az_module_with_models.run_az("C1CCCCC1", az_model="USPTO")
assert isinstance(status, bool)
assert isinstance(routes, list)
# Routes may be empty or contain dicts with expected keys
Expand All @@ -115,9 +114,8 @@ def test_run_az_returns_valid_result(az_module_with_models):
@pytest.mark.slow
def test_run_az_uses_fallback_config_when_model_missing(az_module_with_models):
"""Test run_az falls back to AZ_MODEL_CONFIG_PATH when model-specific config missing."""
run_az = _get_run_az_impl(az_module_with_models)
# MISSING_MODEL has no config in AZ_MODELS_PATH, so uses fallback
status, routes = run_az("C1CCCCC1", az_model="MISSING_MODEL")
status, routes = az_module_with_models.run_az("C1CCCCC1", az_model="MISSING_MODEL")
assert isinstance(status, bool)
assert isinstance(routes, list)

Expand All @@ -130,32 +128,46 @@ def test_run_az_raises_if_no_config_available(tmp_path, monkeypatch, az_module):
monkeypatch.setattr(az_module, "AZ_MODEL_CONFIG_PATH", str(missing_fallback))
monkeypatch.setattr(az_module, "BASIC_MOLECULES", [])

run_az = _get_run_az_impl(az_module)
with pytest.raises(FileNotFoundError, match=re.escape(str(missing_fallback))):
run_az("CCCCC", az_model="MISSING_MODEL")
az_module.run_az("CCCCC", az_model="MISSING_MODEL")


def test_run_az_short_circuits_for_feedstock_smiles(az_module_with_models):
"""Test run_az returns early for basic molecules without calling AiZynthFinder."""
run_az = _get_run_az_impl(az_module_with_models)
status, routes = run_az("CCO", az_model="USPTO")
status, routes = az_module_with_models.run_az("CCO", az_model="USPTO")
assert status is True
assert routes == [
{
"type": "mol",
"hide": False,
"smiles": "CCO",
"is_chemical": True,
"in_stock": True,
}
]
assert routes == _BASIC_ROUTE


def test_run_az_has_no_cache_side_effect_without_cache(
az_module_with_models, tmp_path, monkeypatch
):
"""run_az should not create cache state unless the caller passes one in."""
cache_dir = tmp_path / "cache"
monkeypatch.setenv("DEEPRETRO_CACHE_DIR", str(cache_dir))

status, routes = az_module_with_models.run_az("CCO", az_model="USPTO")

assert status is True
assert routes == _BASIC_ROUTE
assert not cache_dir.exists()


def test_run_az_uses_explicit_cache(az_module_with_models):
"""run_az should reuse results only when an explicit cache is provided."""
cache = CacheManager()

first = az_module_with_models.run_az("CCO", az_model="USPTO", cache=cache)
second = az_module_with_models.run_az("CCO", az_model="USPTO", cache=cache)

assert first == second
assert cache.stats().hits == 1
assert cache.stats().misses == 1
assert cache.stats().num_entries == 1
@pytest.mark.slow
def test_run_az_with_img_returns_valid_result(az_module_with_models):
"""Test run_az_with_img with real AiZynthFinder returns (bool, list, images)."""
run_az_with_img = _get_run_az_with_img_impl(az_module_with_models)
status, routes, images = run_az_with_img("C1CCCCC1")
status, routes, images = az_module_with_models.run_az_with_img("C1CCCCC1")
assert isinstance(status, bool)
assert isinstance(routes, list)
# images can be list of PIL Images or None
Expand All @@ -165,16 +177,21 @@ def test_run_az_with_img_returns_valid_result(az_module_with_models):

def test_run_az_with_img_short_circuits_for_basic_molecules(az_module_with_models):
"""Test run_az_with_img returns early for basic molecules."""
run_az_with_img = _get_run_az_with_img_impl(az_module_with_models)
status, routes, images = run_az_with_img("CCO")
status, routes, images = az_module_with_models.run_az_with_img("CCO")
assert status is True
assert routes == [
{
"type": "mol",
"hide": False,
"smiles": "CCO",
"is_chemical": True,
"in_stock": True,
}
]
assert routes == _BASIC_ROUTE
assert images is None


def test_run_az_with_img_uses_explicit_cache(az_module_with_models):
"""run_az_with_img should also cache only through the provided instance."""
cache = CacheManager()

first = az_module_with_models.run_az_with_img("CCO", cache=cache)
second = az_module_with_models.run_az_with_img("CCO", cache=cache)

assert first == second
assert first == (True, _BASIC_ROUTE, None)
assert cache.stats().hits == 1
assert cache.stats().misses == 1
assert cache.stats().num_entries == 1
Loading