diff --git a/.gitignore b/.gitignore index a0b1ec8..aba0860 100644 --- a/.gitignore +++ b/.gitignore @@ -118,3 +118,4 @@ outputs/ src/ dev_scripts/ .claude/ +CLAUDE.md diff --git a/openquake/hme/utils/io/source_processing.py b/openquake/hme/utils/io/source_processing.py index f474dfe..b0bce09 100644 --- a/openquake/hme/utils/io/source_processing.py +++ b/openquake/hme/utils/io/source_processing.py @@ -195,6 +195,7 @@ def _process_source_chunk(source_chunk_w_args) -> list: rups = flatten_list(rups) + pbar.close() return rups @@ -446,7 +447,13 @@ def rupture_dict_to_gdf( dfs = [] for branch, branch_df in rupture_dict.items(): - branch_df["occurrence_rate"] *= weights[branch] + branch_df = branch_df.copy() + w = weights[branch] + if isinstance(w, dict): + src_ids = branch_df.index.str.rsplit("_", n=1).str[0] + branch_df["occurrence_rate"] *= src_ids.map(w) + else: + branch_df["occurrence_rate"] *= w branch_df.index = branch_df.index.values + f"_{branch}" branch_df["branch"] = branch diff --git a/openquake/hme/utils/io/source_reader.py b/openquake/hme/utils/io/source_reader.py index 9330d6c..3ce371e 100644 --- a/openquake/hme/utils/io/source_reader.py +++ b/openquake/hme/utils/io/source_reader.py @@ -11,15 +11,18 @@ from openquake.engine.engine import create_jobs, run_jobs from openquake.hazardlib.gsim_lt import GsimLogicTree +from openquake.hazardlib.source import MultiPointSource from openquake.hme.utils.utils import _get_class_name, breakpoint try: from openquake.hazardlib.source_group import read_csm + csm_new_flag = True except ImportError: csm_new_flag = False + def csm_from_job_ini(job_ini, get_gsim_lt: bool = False): if not isinstance(job_ini, dict) and os.path.isfile(job_ini): job_ini = get_params(job_ini) @@ -43,7 +46,7 @@ def csm_from_job_ini(job_ini, get_gsim_lt: bool = False): with job, datastore.read(job.calc_id) as dstore: if csm_new_flag: csm = read_csm(dstore) - else: # older OQ + else: # older OQ csm = dstore['_csm'] sources = csm.get_sources() logging.debug("\tgot csm from dstore") @@ -145,7 +148,7 @@ def process_source_logic_tree_oq( gmm_lt_file: str = "gmmLT.xml", sites_file: Optional[str] = None, branch: Optional[str] = None, - collapse_lt: Optional[bool] = False, + collapse_lt: Optional[bool] = True, source_types: Optional[Sequence] = None, tectonic_region_types: Optional[Sequence] = None, description: Optional[str] = None, @@ -218,19 +221,17 @@ def process_source_logic_tree_oq( s.num_ruptures for s in ssm_lt_sources["composite"] ] } - source_weights = list(sources_w_weights.values()) - ssm_lt_weights = {"composite": []} - - for i, rup_count in enumerate(ssm_lt_rup_counts["composite"]): - ssm_lt_weights["composite"].append( - np.ones(rup_count) * source_weights[i] - ) - - ssm_lt_weights["composite"] = np.hstack( - ssm_lt_weights["composite"] - ) + src_weight_dict = {} + for src, w in sources_w_weights.items(): + if isinstance(src, MultiPointSource): + for sub_src in src: + src_weight_dict[sub_src.source_id] = w + else: + src_weight_dict[src.source_id] = w + + ssm_lt_weights = {"composite": src_weight_dict} logging.info( - f"{len(ssm_lt_weights['composite']):_} rups in composite model" + f"{len(ssm_lt_weights['composite']):_} sources in composite model" ) else: ssm_lt_sources = branch_sources diff --git a/openquake/hme/utils/io/tests/test_io.py b/openquake/hme/utils/io/tests/test_io.py index 8e3aed2..a4bd3d7 100644 --- a/openquake/hme/utils/io/tests/test_io.py +++ b/openquake/hme/utils/io/tests/test_io.py @@ -1,13 +1,30 @@ +import os +import pathlib import unittest + import numpy as np import pandas as pd +from openquake.hme.core.core import read_yaml_config from openquake.hme.utils.io import read_rupture_file from openquake.hme.utils.simple_rupture import SimpleRupture from openquake.hme.utils.tests.load_sm1 import cfg, input_data, eq_gdf, rup_gdf +from openquake.hme.utils.io.source_reader import ( + process_source_logic_tree_oq, +) + +from openquake.hme.utils.io.source_processing import ( + rupture_dict_from_logic_tree_dict, + rupture_dict_to_gdf, +) + +from openquake.hme.utils.utils import breakpoint + +BASE_PATH = pathlib.Path(os.path.dirname(__file__)) + def test_read_rupture_file(): rup_fp = cfg["input"]["rupture_file"]["rupture_file_path"] @@ -24,4 +41,58 @@ def test_read_rupture_file(): if isinstance(param_r1, str): assert param_r1 == param_r2 else: - np.testing.assert_almost_equal(param_r1, param_r2, decimal=2) + np.testing.assert_almost_equal( + param_r1, param_r2, decimal=2 + ) + + +def test_2_branches(): + test_dir = ( + BASE_PATH + / '..' + / '..' + / 'tests' + / 'data' + / 'source_models' + / '2_branches' + ) + cfg = read_yaml_config(test_dir / 'test_2_ssm_branches.yaml') + source_cfg = cfg['input']['ssm'] + ( + b_ssm_lt_sources, + b_ssm_lt_weights, + b_ssm_lt_rup_counts, + b_gsim_lt, + ) = process_source_logic_tree_oq( + source_cfg["job_ini_file"], + test_dir / source_cfg["ssm_dir"], + ) + + ( + c_ssm_lt_sources, + c_ssm_lt_weights, + c_ssm_lt_rup_counts, + c_gsim_lt, + ) = process_source_logic_tree_oq( + source_cfg["job_ini_file"], + test_dir / source_cfg["ssm_dir"], + collapse_lt=True, + ) + + branch_rdf = rupture_dict_to_gdf( + rupture_dict_from_logic_tree_dict( + b_ssm_lt_sources, + b_ssm_lt_rup_counts, + ), + b_ssm_lt_weights, + ) + + collapse_rdf = rupture_dict_to_gdf( + rupture_dict_from_logic_tree_dict( + c_ssm_lt_sources, + c_ssm_lt_rup_counts, + ), + c_ssm_lt_weights, + ) + + breakpoint() diff --git a/openquake/hme/utils/io/tests/test_source_reader.py b/openquake/hme/utils/io/tests/test_source_reader.py index 18dbda0..4fbb628 100644 --- a/openquake/hme/utils/io/tests/test_source_reader.py +++ b/openquake/hme/utils/io/tests/test_source_reader.py @@ -6,6 +6,7 @@ from openquake.hme.core.core import read_yaml_config from openquake.hme.utils.tests import load_sm1 +from openquake.hme.utils.utils import breakpoint from openquake.hme.utils.io.source_reader import ( csm_from_job_ini, # get_csm_rlzs, @@ -19,8 +20,6 @@ BASE_PATH = pathlib.Path(os.path.dirname(__file__)) -source_cfg = load_sm1.cfg["input"]["ssm"] - # w/ job ini # get job ini # get csm @@ -32,6 +31,8 @@ def test_single_branch_without_job_ini(): + source_cfg = load_sm1.cfg["input"]["ssm"] + def test_make_job_ini(): job_ini = make_job_ini( source_cfg["ssm_dir"], @@ -98,6 +99,7 @@ def test_csm_from_job_ini(): def test_process_source_logic_tree_oq(): + source_cfg = load_sm1.cfg["input"]["ssm"] ( ssm_lt_sources, ssm_lt_weights, @@ -119,11 +121,18 @@ def test_process_source_logic_tree_oq(): assert list(ssm_lt_weights.keys()) == [0] assert ssm_lt_weights == {0: 1.0} - + def test_2_branches_compound(): - test_dir = (BASE_PATH / '..' / '..' / 'tests' / 'data' / 'source_models' / - '2_branches') + test_dir = ( + BASE_PATH + / '..' + / '..' + / 'tests' + / 'data' + / 'source_models' + / '2_branches' + ) cfg = read_yaml_config(test_dir / 'test_2_ssm_branches.yaml') source_cfg = cfg['input']['ssm'] ( @@ -134,24 +143,55 @@ def test_2_branches_compound(): ) = process_source_logic_tree_oq( source_cfg["job_ini_file"], test_dir / source_cfg["ssm_dir"], - ) + ) - assert tuple(ssm_lt_sources.keys()) == (0,1) + assert tuple(ssm_lt_sources.keys()) == (0, 1) assert len(ssm_lt_sources[0]) == 2 assert ssm_lt_sources[0][0].__class__.__name__ == 'PointSource' assert ssm_lt_weights == {0: 0.75, 1: 0.25} assert ssm_lt_rup_counts == {0: [1, 1], 1: [1, 1]} # no need to test gsim_lt + # breakpoint() -@unittest.skip("not implemented correctly") + +# @unittest.skip("not implemented correctly") def test_2_branches_collapse(): - pass + test_dir = ( + BASE_PATH + / '..' + / '..' + / 'tests' + / 'data' + / 'source_models' + / '2_branches' + ) + cfg = read_yaml_config(test_dir / 'test_2_ssm_branches.yaml') + source_cfg = cfg['input']['ssm'] + ( + ssm_lt_sources, + ssm_lt_weights, + ssm_lt_rup_counts, + gsim_lt, + ) = process_source_logic_tree_oq( + source_cfg["job_ini_file"], + test_dir / source_cfg["ssm_dir"], + collapse_lt=True, + ) + + # breakpoint() def test_2_branches_1_branch(): - test_dir = (BASE_PATH / '..' / '..' / 'tests' / 'data' / 'source_models' / - '2_branches') + test_dir = ( + BASE_PATH + / '..' + / '..' + / 'tests' + / 'data' + / 'source_models' + / '2_branches' + ) cfg = read_yaml_config(test_dir / 'test_2_ssm_branches.yaml') source_cfg = cfg['input']['ssm'] source_cfg['branch'] = 1 @@ -164,9 +204,9 @@ def test_2_branches_1_branch(): source_cfg["job_ini_file"], test_dir / source_cfg["ssm_dir"], branch=source_cfg['branch'], - ) + ) assert tuple(ssm_lt_sources.keys()) == (1,) assert len(ssm_lt_sources[1]) == 2 assert ssm_lt_weights == {1: 1.0} - assert ssm_lt_rup_counts == {1: [1,1]} + assert ssm_lt_rup_counts == {1: [1, 1]}