Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions deepretro/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@
and model wrappers for reaction-step data.
"""

__all__ = ["ReactionStepFeaturizer", "configure_logging", "HallucinationClassifier"]

__all__ = [
"ReactionStepFeaturizer",
"configure_logging",
"HallucinationClassifier",
"autosolve",
"autosolve_async",
]

def __getattr__(name: str):
if name == "ReactionStepFeaturizer":
Expand All @@ -21,4 +26,10 @@ def __getattr__(name: str):
from deepretro.models.hallucination_classifier import HallucinationClassifier

return HallucinationClassifier
if name == "autosolve":
from deepretro.utils.autosolve import autosolve
return autosolve
if name == "autosolve_async":
from deepretro.utils.autosolve import autosolve_async
return autosolve_async
raise AttributeError(f"module 'deepretro' has no attribute {name!r}")
8 changes: 4 additions & 4 deletions deepretro/featurizers/reactionstep.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class ReactionStepFeaturizer(Featurizer):

1. CircularFingerprint (Morgan/ECFP) for the product — ``size`` bits
2. CircularFingerprint (Morgan/ECFP) for the reactants — ``size`` bits
3. 15 hand-crafted domain features (optional)
3. 27 domain features (optional)

Parameters
----------
Expand All @@ -23,7 +23,7 @@ class ReactionStepFeaturizer(Featurizer):
size : int, optional (default 2048)
Fingerprint bit length for each molecule.
use_domain_features : bool, optional (default True)
If True, appends 15 domain features (atom/bond/ring/MW deltas).
If True, appends 27 domain features.

Notes
-----
Expand All @@ -36,7 +36,7 @@ class ReactionStepFeaturizer(Featurizer):
>>> reactions = [("CCO", "CC.O"), ("c1ccccc1", "c1ccccc1.Cl")]
>>> X = featurizer.featurize(reactions)
>>> X.shape
(2, 4111)
(2, 4123)
"""

def __init__(
Expand All @@ -55,7 +55,7 @@ def feature_dim(self) -> int:
Returns
-------
dim : int
``2 * size + 15`` when ``use_domain_features=True``,
``2 * size + 27`` when ``use_domain_features=True``,
``2 * size`` otherwise.
"""
return 2 * self.size + (NUM_DOMAIN_FEATURES if self.use_domain_features else 0)
Expand Down
95 changes: 95 additions & 0 deletions deepretro/models/hallucination_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""Helpers for wiring hallucination checkers into the retrosynthesis pipeline.

Provides:

* :func:`build_ml_checker` — wrap a classifier into the callable
signature the pipeline expects.
* :func:`resolve_hallucination_args` — turn a user-friendly mode string
into the ``(hallucination_check, hallucination_checker_fn)`` pair
consumed by ``llm_pipeline``.
"""

from __future__ import annotations

from pathlib import Path
from typing import Any

from deepretro.utils.utils_molecule import is_valid_smiles

VALID_MODES = ("heuristic", "ml", "none")


def build_ml_checker(clf: Any):
    """Adapt a ``HallucinationClassifier`` to the pipeline's checker interface.

    The returned callable mirrors the signature of
    ``src.utils.hallucination_checks.hallucination_checker``::

        (product: str, pathways: list) -> (int, list)

    Parameters
    ----------
    clf : HallucinationClassifier
        Object exposing ``predict_single(product, reactants_smiles)`` that
        returns a mapping with an ``"is_hallucination"`` entry.

    Returns
    -------
    checker : callable
        Function that filters *pathways* and returns ``(200, kept)``.
        A pathway is kept only when its reactant SMILES is valid and the
        classifier does not flag it as a hallucination; everything else is
        dropped, in the same way the heuristic checker drops rejected
        pathways, so ``llm_pipeline``'s retry loop can issue a new LLM call.
    """

    def _checker(product: str, pathways: list) -> tuple[int, list]:
        kept = []
        for candidate in pathways:
            # A pathway given as a list of reactants is collapsed into one
            # dot-separated SMILES string; anything else is used as-is.
            smiles = ".".join(candidate) if isinstance(candidate, list) else candidate

            if is_valid_smiles(smiles):
                verdict = clf.predict_single(product, smiles)
                # A missing "is_hallucination" key counts as hallucinated,
                # so such a pathway is rejected.
                if not verdict.get("is_hallucination", True):
                    kept.append(candidate)

        # The status element is always 200; only the list varies.
        return 200, kept

    return _checker


def resolve_hallucination_args(
    hallucination_mode: str,
    hallucination_classifier: Any = None,
) -> tuple[str, Any]:
    """Return ``(hallucination_check, hallucination_checker_fn)`` for the
    pipeline based on *hallucination_mode*.

    Parameters
    ----------
    hallucination_mode : str
        One of ``"heuristic"``, ``"ml"``, or ``"none"``.
    hallucination_classifier : HallucinationClassifier or str or Path or None
        Only used (and required) when *hallucination_mode* is ``"ml"``.
        Pass a fitted ``HallucinationClassifier`` instance or a ``str`` /
        ``Path`` pointing to a saved model directory. Defaults to ``None``
        so the other modes can be selected without supplying it.

    Returns
    -------
    hallucination_check : str
        ``"False"`` for mode ``"none"``; ``"True"`` for ``"heuristic"`` and
        ``"ml"``. Note these are strings, not booleans.
    hallucination_checker_fn : callable or None
        ``None`` for ``"none"`` and ``"heuristic"``; for ``"ml"`` the
        callable produced by :func:`build_ml_checker`.

    Raises
    ------
    ValueError
        If *hallucination_mode* is not a recognised mode, or if mode is
        ``"ml"`` and *hallucination_classifier* is neither a
        ``HallucinationClassifier`` nor a ``str`` / ``Path``.
    """
    # Modes that need no classifier are resolved first.
    if hallucination_mode == "none":
        return "False", None

    if hallucination_mode == "heuristic":
        return "True", None

    if hallucination_mode != "ml":
        raise ValueError(
            f"hallucination_mode must be one of {VALID_MODES}, "
            f"got {hallucination_mode!r}"
        )

    # mode == "ml" -- resolve the classifier. Imported here rather than at
    # module level, so the other modes never trigger this import.
    from deepretro.models.hallucination_classifier import HallucinationClassifier

    if isinstance(hallucination_classifier, (str, Path)):
        clf = HallucinationClassifier()
        clf.load(str(hallucination_classifier))
    elif isinstance(hallucination_classifier, HallucinationClassifier):
        clf = hallucination_classifier
    else:
        raise ValueError(
            "hallucination_mode='ml' requires hallucination_classifier "
            "to be a HallucinationClassifier instance or a path to a "
            f"saved model directory — got {type(hallucination_classifier)}"
        )

    return "True", build_ml_checker(clf)
Loading
Loading