Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/deepretro_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,3 @@ jobs:
- name: Run tests
run: |
pytest deepretro/tests/ -v

2 changes: 2 additions & 0 deletions deepretro/algorithms/stability_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@
* `is_valid_smiles` — quick check that a SMILES string parses.
"""

from __future__ import annotations

from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem.rdMolDescriptors import (
Expand Down
5 changes: 4 additions & 1 deletion deepretro/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Sequence

import numpy as np
from sklearn.metrics import precision_recall_curve


def find_optimal_threshold(
Expand Down Expand Up @@ -42,6 +41,10 @@ def find_optimal_threshold(
>>> f1 > 0.0
True
"""
# Import lazily so docs and lightweight tooling can import this module
# without pulling in sklearn's full scipy stack at module import time.
from sklearn.metrics import precision_recall_curve

precision, recall, thresholds = precision_recall_curve(y_true, probabilities)
f1_scores = 2 * (precision * recall) / (precision + recall + 1e-10)
best_idx = np.argmax(f1_scores)
Expand Down
6 changes: 6 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,17 @@
'diskcache',
'dotenv',
'jax',
'lightgbm',
'numpy',
'pandas',
'PIL',
'rdkit',
'rootutils',
'sklearn',
'structlog',
'tensorflow',
'torch',
'xgboost',
]

# Todo settings
Expand Down
Loading