From c2414597a6caa92029bfe332f4ff81ee4012d1b9 Mon Sep 17 00:00:00 2001 From: joehart2001 Date: Tue, 21 Apr 2026 15:07:43 +0100 Subject: [PATCH 1/2] add head name to test parity plots --- mace/cli/run_train.py | 2 +- mace/cli/visualise_train.py | 13 ++++++++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 5212269c6..f1d13ad48 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -993,7 +993,7 @@ def run(args) -> None: for head_config in head_configs: if all(check_path_ase_read(f) for f in head_config.train_file): for name, subset in head_config.collections.tests: - test_sets[name] = [ + test_sets[head_config.head_name + "_" + name] = [ data.AtomicData.from_config( config, z_table=z_table, cutoff=args.r_max, heads=heads ) diff --git a/mace/cli/visualise_train.py b/mace/cli/visualise_train.py index d5bab6a92..42cbb3d57 100644 --- a/mace/cli/visualise_train.py +++ b/mace/cli/visualise_train.py @@ -386,8 +386,11 @@ def plot_inference_from_results( # Plot test data (single legend entry) for name, result in test_dict.items(): + if head not in name: + continue # Initialize scatter to None to avoid possibly used before assignment scatter = None + test_label = f"Test ({name})" if key == "energy" and "energy" in result: e_key = "energy" if not plot_interaction_e else "interaction_energy" @@ -396,7 +399,7 @@ def plot_inference_from_results( result[e_key]["predicted_per_atom"], marker="o", color=fixed_color_test, - label="Test", + label=test_label, ) elif key == "force" and "forces" in result: @@ -405,7 +408,7 @@ def plot_inference_from_results( result["forces"]["predicted"], marker="o", color=fixed_color_test, - label="Test", + label=test_label, ) elif key == "stress" and "stress" in result: @@ -414,7 +417,7 @@ def plot_inference_from_results( result["stress"]["predicted"], marker="o", color=fixed_color_test, - label="Test", + label=test_label, ) elif key == "virials" and "virials" in result: @@ -423,7 +426,7 @@ def plot_inference_from_results( result["virials"]["predicted_per_atom"], marker="o", color=fixed_color_test, - label="Test", + label=test_label, ) elif key == "dipole" and "dipole" in result: @@ -432,7 +435,7 @@ def plot_inference_from_results( result["dipole"]["predicted_per_atom"], marker="o", color=fixed_color_test, - label="Test", + label=test_label, ) # Only add to legend_labels if scatter was assigned From 9d8ef53b850c84976b506b7bb06660105bad671b Mon Sep 17 00:00:00 2001 From: joehart2001 Date: Tue, 21 Apr 2026 18:33:49 +0100 Subject: [PATCH 2/2] remove redundant scatter labels and label test head --- mace/cli/visualise_train.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/mace/cli/visualise_train.py b/mace/cli/visualise_train.py index 42cbb3d57..b1eb24c52 100644 --- a/mace/cli/visualise_train.py +++ b/mace/cli/visualise_train.py @@ -384,13 +384,12 @@ def plot_inference_from_results( fixed_color_test = colors[2] # Color for test dataset - # Plot test data (single legend entry) + # Plot test data (single legend entry per head) for name, result in test_dict.items(): if head not in name: continue # Initialize scatter to None to avoid possibly used before assignment scatter = None - test_label = f"Test ({name})" if key == "energy" and "energy" in result: e_key = "energy" if not plot_interaction_e else "interaction_energy" @@ -399,7 +398,6 @@ def plot_inference_from_results( result[e_key]["predicted_per_atom"], marker="o", color=fixed_color_test, - label=test_label, ) elif key == "force" and "forces" in result: @@ -408,7 +406,6 @@ def plot_inference_from_results( result["forces"]["predicted"], marker="o", color=fixed_color_test, - label=test_label, ) elif key == "stress" and "stress" in result: @@ -417,7 +414,6 @@ def plot_inference_from_results( result["stress"]["predicted"], marker="o", color=fixed_color_test, - label=test_label, ) elif key == "virials" and "virials" in result: @@ -426,7 +422,6 @@ def plot_inference_from_results( result["virials"]["predicted_per_atom"], marker="o", color=fixed_color_test, - label=test_label, ) elif key == "dipole" and "dipole" in result: @@ -435,12 +430,11 @@ def plot_inference_from_results( result["dipole"]["predicted_per_atom"], marker="o", color=fixed_color_test, - label=test_label, ) # Only add to legend_labels if scatter was assigned if scatter is not None: - legend_labels["Test"] = scatter + legend_labels[f"Test {head}"] = scatter # Add diagonal line for guide min_val = min(ax.get_xlim()[0], ax.get_ylim()[0])