diff --git a/src/pycea/tl/ancestral_states.py b/src/pycea/tl/ancestral_states.py index 1f8bb33..acccebd 100755 --- a/src/pycea/tl/ancestral_states.py +++ b/src/pycea/tl/ancestral_states.py @@ -181,6 +181,18 @@ def traceback(node, parent_value_index): del tree.nodes[node]["_pointers"] +def _reconstruct_sum(tree: nx.DiGraph, key: str, index: int | None, fixed_nodes: set | None = None) -> None: + """Reconstructs ancestral states by summing leaf values with an iterative bottom-up traversal.""" + for node in reversed(list(nx.topological_sort(tree))): + val = _get_node_value(tree, node, key, index) + is_fixed = fixed_nodes is not None and node in fixed_nodes and val is not None + if tree.out_degree(node) == 0 or is_fixed: + continue + children_values = [_get_node_value(tree, child, key, index) for child in tree.successors(node)] + valid = [v for v in children_values if v is not None] + _set_node_value(tree, node, key, sum(valid) if valid else None, index) + + def _reconstruct_mean(tree: nx.DiGraph, key: str, index: int | None, fixed_nodes: set | None = None) -> None: """Reconstructs ancestral by averaging the values of the children.""" @@ -243,6 +255,8 @@ def _ancestral_states( _reconstruct_fitch_hartigan(tree, key, missing, index, fixed_nodes) elif method == "mean": _reconstruct_mean(tree, key, index, fixed_nodes) + elif method == "sum": + _reconstruct_sum(tree, key, index, fixed_nodes) elif method == "mode": _reconstruct_list(tree, key, _most_common, index, fixed_nodes) elif callable(method): @@ -308,6 +322,7 @@ def ancestral_states( Method to reconstruct ancestral states: * 'mean' : The mean of leaves in subtree. + * 'sum' : The sum of leaves in subtree (iterative bottom-up traversal). * 'mode' : The most common value in the subtree. * 'fitch_hartigan' : The Fitch-Hartigan algorithm. * 'sankoff' : The Sankoff algorithm with specified costs. @@ -365,7 +380,7 @@ def ancestral_states( if method in ["fitch_hartigan", "sankoff"]: raise ValueError(f"Method {method} requires categorical data.") if dtypes.intersection({"O", "S"}): - if method in ["mean"]: + if method in ["mean", "sum"]: raise ValueError(f"Method {method} requires numeric data.") # Determine fixed internal nodes for nodes/subset alignment leaves_set = set(get_leaves(t)) diff --git a/tests/test_ancestral_states.py b/tests/test_ancestral_states.py index c9d7212..64da24e 100755 --- a/tests/test_ancestral_states.py +++ b/tests/test_ancestral_states.py @@ -143,6 +143,37 @@ def test_ancestral_states_nodes_fitch(nodes_tdata): assert tree.nodes["C"]["str_value"] == "1" # C value preserved +def test_ancestral_states_sum(tdata): + # tree1: root -> B(0), C; C -> D(0), E(3) [index order: B=0, D=0, E=3, F=2] + # C sum = 0+3 = 3; root sum = 0+3 = 3 + states = ancestral_states(tdata, "value", method="sum", copy=True) + assert tdata.obst["tree1"].nodes["C"]["value"] == 3 + assert tdata.obst["tree1"].nodes["root"]["value"] == 3 + # tree2: root -> F(2); root sum = 2 + assert tdata.obst["tree2"].nodes["root"]["value"] == 2 + assert states is not None + assert states["value"].loc[("tree1", "root")] == 3 + + +def test_ancestral_states_sum_array(tdata): + # spatial tree1: B=[0,4], D=[1,1], E=[2,1] + # C sum = [1+2, 1+1] = [3, 2]; root sum = [0+3, 4+2] = [3, 6] + states = ancestral_states(tdata, "spatial", method="sum", copy=True) + assert tdata.obst["tree1"].nodes["C"]["spatial"] == [3, 2] + assert tdata.obst["tree1"].nodes["root"]["spatial"] == [3, 6] + assert states is not None + assert states.loc[("tree1", "root"), "spatial"] == [3, 6] + + +def test_ancestral_states_sum_fixed_nodes(nodes_tdata): + # C=5 (fixed), so C is treated as a leaf for reconstruction + # root sum = B(0) + C(5) = 5 + ancestral_states(nodes_tdata, "value", method="sum", copy=False) + tree = nodes_tdata.obst["tree"] + assert tree.nodes["C"]["value"] == 5 # fixed, unchanged + assert tree.nodes["root"]["value"] == pytest.approx(5.0) + + def test_ancestral_states_invalid(tdata): with pytest.raises(ValueError): ancestral_states(tdata, "characters", method="sankoff")