diff --git a/pixi.lock b/pixi.lock index 6a966d57..633d1ff7 100644 --- a/pixi.lock +++ b/pixi.lock @@ -257,6 +257,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-ng-2.3.3-hceb46e0_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b7/9f/d73dfb85d7a5b1a56a99adc50f2074029468168c970ff5daeade4ad819e4/choreographer-1.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/98/78/01c019cdb5d6498122777c1a43056ebb3ebfeef2076d9d026bfe15583b2b/click-8.3.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -524,6 +525,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-ng-2.3.3-hceb46e0_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b7/9f/d73dfb85d7a5b1a56a99adc50f2074029468168c970ff5daeade4ad819e4/choreographer-1.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/98/78/01c019cdb5d6498122777c1a43056ebb3ebfeef2076d9d026bfe15583b2b/click-8.3.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -771,6 +773,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zlib-ng-2.3.3-hed4e4f5_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-hbf9d68e_6.conda - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b7/9f/d73dfb85d7a5b1a56a99adc50f2074029468168c970ff5daeade4ad819e4/choreographer-1.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/98/78/01c019cdb5d6498122777c1a43056ebb3ebfeef2076d9d026bfe15583b2b/click-8.3.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -1018,6 +1021,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/win-64/zlib-ng-2.3.3-h0261ad2_1.conda - conda: https://conda.anaconda.org/conda-forge/win-64/zstd-1.5.7-h534d264_6.conda - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b7/9f/d73dfb85d7a5b1a56a99adc50f2074029468168c970ff5daeade4ad819e4/choreographer-1.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/98/78/01c019cdb5d6498122777c1a43056ebb3ebfeef2076d9d026bfe15583b2b/click-8.3.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -1284,6 +1288,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-ng-2.3.3-hceb46e0_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b7/9f/d73dfb85d7a5b1a56a99adc50f2074029468168c970ff5daeade4ad819e4/choreographer-1.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/98/78/01c019cdb5d6498122777c1a43056ebb3ebfeef2076d9d026bfe15583b2b/click-8.3.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -1532,6 +1537,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zlib-ng-2.3.3-hed4e4f5_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-hbf9d68e_6.conda - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b7/9f/d73dfb85d7a5b1a56a99adc50f2074029468168c970ff5daeade4ad819e4/choreographer-1.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/98/78/01c019cdb5d6498122777c1a43056ebb3ebfeef2076d9d026bfe15583b2b/click-8.3.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -1776,6 +1782,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/win-64/zlib-ng-2.3.3-h0261ad2_1.conda - conda: https://conda.anaconda.org/conda-forge/win-64/zstd-1.5.7-h534d264_6.conda - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b7/9f/d73dfb85d7a5b1a56a99adc50f2074029468168c970ff5daeade4ad819e4/choreographer-1.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/98/78/01c019cdb5d6498122777c1a43056ebb3ebfeef2076d9d026bfe15583b2b/click-8.3.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -2066,6 +2073,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-ng-2.3.3-hceb46e0_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b7/9f/d73dfb85d7a5b1a56a99adc50f2074029468168c970ff5daeade4ad819e4/choreographer-1.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/98/78/01c019cdb5d6498122777c1a43056ebb3ebfeef2076d9d026bfe15583b2b/click-8.3.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -2326,6 +2334,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zlib-ng-2.3.3-hed4e4f5_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-hbf9d68e_6.conda - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b7/9f/d73dfb85d7a5b1a56a99adc50f2074029468168c970ff5daeade4ad819e4/choreographer-1.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/98/78/01c019cdb5d6498122777c1a43056ebb3ebfeef2076d9d026bfe15583b2b/click-8.3.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -2575,6 +2584,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/win-64/zlib-ng-2.3.3-h0261ad2_1.conda - conda: https://conda.anaconda.org/conda-forge/win-64/zstd-1.5.7-h534d264_6.conda - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b7/9f/d73dfb85d7a5b1a56a99adc50f2074029468168c970ff5daeade4ad819e4/choreographer-1.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/98/78/01c019cdb5d6498122777c1a43056ebb3ebfeef2076d9d026bfe15583b2b/click-8.3.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -2897,6 +2907,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-ng-2.3.3-hceb46e0_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b7/9f/d73dfb85d7a5b1a56a99adc50f2074029468168c970ff5daeade4ad819e4/choreographer-1.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/98/78/01c019cdb5d6498122777c1a43056ebb3ebfeef2076d9d026bfe15583b2b/click-8.3.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -3230,6 +3241,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-ng-2.3.3-hceb46e0_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b7/9f/d73dfb85d7a5b1a56a99adc50f2074029468168c970ff5daeade4ad819e4/choreographer-1.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/98/78/01c019cdb5d6498122777c1a43056ebb3ebfeef2076d9d026bfe15583b2b/click-8.3.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -3563,6 +3575,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-ng-2.3.3-hceb46e0_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ba/6c/ff8bf52315064dbeb55cb5067e191120a5b2e58bb648d0d34cf7969dc2c2/choreographer-1.3.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ae/44/c1221527f6a71a01ec6fbad7fa78f1d50dfa02217385cf0fa3eec7087d59/click-8.3.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -3840,6 +3853,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-ng-2.3.3-hceb46e0_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b7/9f/d73dfb85d7a5b1a56a99adc50f2074029468168c970ff5daeade4ad819e4/choreographer-1.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/98/78/01c019cdb5d6498122777c1a43056ebb3ebfeef2076d9d026bfe15583b2b/click-8.3.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -4091,6 +4105,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zlib-ng-2.3.3-hed4e4f5_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-hbf9d68e_6.conda - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b7/9f/d73dfb85d7a5b1a56a99adc50f2074029468168c970ff5daeade4ad819e4/choreographer-1.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/98/78/01c019cdb5d6498122777c1a43056ebb3ebfeef2076d9d026bfe15583b2b/click-8.3.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -4338,6 +4353,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/win-64/zlib-ng-2.3.3-h0261ad2_1.conda - conda: https://conda.anaconda.org/conda-forge/win-64/zstd-1.5.7-h534d264_6.conda - pypi: https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b7/9f/d73dfb85d7a5b1a56a99adc50f2074029468168c970ff5daeade4ad819e4/choreographer-1.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/98/78/01c019cdb5d6498122777c1a43056ebb3ebfeef2076d9d026bfe15583b2b/click-8.3.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2c/1a/aff8bb287a4b1400f69e09a53bd65de96aa5cee5691925b38731c67fc695/click_default_group-1.2.4-py2.py3-none-any.whl @@ -5084,6 +5100,111 @@ packages: purls: [] size: 7546 timestamp: 1777848733980 +- pypi: https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl + name: beartype + version: 0.22.9 + sha256: d16c9bbc61ea14637596c5f6fbff2ee99cbe3573e46a716401734ef50c3060c2 + requires_dist: + - autoapi>=0.9.0 ; extra == 'dev' + - celery ; extra == 'dev' + - click ; extra == 'dev' + - coverage>=5.5 ; extra == 'dev' + - docutils>=0.22.0 ; extra == 'dev' + - equinox ; python_full_version < '3.15' and sys_platform == 'linux' and extra == 'dev' + - fastmcp ; python_full_version < '3.14' and extra == 'dev' + - jax[cpu] ; python_full_version < '3.15' and sys_platform == 'linux' and extra == 'dev' + - jaxtyping ; sys_platform == 'linux' and extra == 'dev' + - langchain ; python_full_version < '3.14' and platform_python_implementation != 'PyPy' and sys_platform != 'darwin' and extra == 'dev' + - mypy>=0.800 ; platform_python_implementation != 'PyPy' and extra == 'dev' + - nuitka>=1.2.6 ; python_full_version < '3.14' and sys_platform == 'linux' and extra == 'dev' + - numba ; python_full_version < '3.14' and extra == 'dev' + - numpy ; python_full_version < '3.15' and platform_python_implementation != 'PyPy' and sys_platform != 'darwin' and extra == 'dev' + - pandera>=0.26.0 ; python_full_version < '3.14' and extra == 'dev' + - poetry ; extra == 'dev' + - polars ; python_full_version < '3.14' and extra == 'dev' + - pydata-sphinx-theme<=0.7.2 ; extra == 'dev' + - pygments ; extra == 'dev' + - pyinstaller ; extra == 'dev' + - pyright>=1.1.370 ; extra == 'dev' + - pytest>=6.2.0 ; extra == 'dev' + - redis ; extra == 'dev' + - rich-click ; extra == 'dev' + - setuptools ; extra == 'dev' + - sphinx ; extra == 'dev' + - sphinx>=4.2.0,<6.0.0 ; extra == 'dev' + - sphinxext-opengraph>=0.7.5 ; extra == 'dev' + - sqlalchemy ; extra == 'dev' + - torch ; python_full_version < '3.14' and sys_platform == 'linux' and extra == 'dev' + - tox>=3.20.1 ; extra == 'dev' + - typer ; extra == 'dev' + - typing-extensions>=3.10.0.0 ; extra == 'dev' + - xarray ; python_full_version < '3.15' and extra == 'dev' + - mkdocs-material[imaging]>=9.6.0 ; extra == 'doc-ghp' + - mkdocstrings-python-xref>=1.16.0 ; extra == 'doc-ghp' + - mkdocstrings-python>=1.16.0 ; extra == 'doc-ghp' + - autoapi>=0.9.0 ; extra == 'doc-rtd' + - pydata-sphinx-theme<=0.7.2 ; extra == 'doc-rtd' + - setuptools ; extra == 'doc-rtd' + - sphinx>=4.2.0,<6.0.0 ; extra == 'doc-rtd' + - sphinxext-opengraph>=0.7.5 ; extra == 'doc-rtd' + - celery ; extra == 'test' + - click ; extra == 'test' + - coverage>=5.5 ; extra == 'test' + - docutils>=0.22.0 ; extra == 'test' + - equinox ; python_full_version < '3.15' and sys_platform == 'linux' and extra == 'test' + - fastmcp ; python_full_version < '3.14' and extra == 'test' + - jax[cpu] ; python_full_version < '3.15' and sys_platform == 'linux' and extra == 'test' + - jaxtyping ; sys_platform == 'linux' and extra == 'test' + - langchain ; python_full_version < '3.14' and platform_python_implementation != 'PyPy' and sys_platform != 'darwin' and extra == 'test' + - mypy>=0.800 ; platform_python_implementation != 'PyPy' and extra == 'test' + - nuitka>=1.2.6 ; python_full_version < '3.14' and sys_platform == 'linux' and extra == 'test' + - numba ; python_full_version < '3.14' and extra == 'test' + - numpy ; python_full_version < '3.15' and platform_python_implementation != 'PyPy' and sys_platform != 'darwin' and extra == 'test' + - pandera>=0.26.0 ; python_full_version < '3.14' and extra == 'test' + - poetry ; extra == 'test' + - polars ; python_full_version < '3.14' and extra == 'test' + - pygments ; extra == 'test' + - pyinstaller ; extra == 'test' + - pyright>=1.1.370 ; extra == 'test' + - pytest>=6.2.0 ; extra == 'test' + - redis ; extra == 'test' + - rich-click ; extra == 'test' + - sphinx ; extra == 'test' + - sqlalchemy ; extra == 'test' + - torch ; python_full_version < '3.14' and sys_platform == 'linux' and extra == 'test' + - tox>=3.20.1 ; extra == 'test' + - typer ; extra == 'test' + - typing-extensions>=3.10.0.0 ; extra == 'test' + - xarray ; python_full_version < '3.15' and extra == 'test' + - celery ; extra == 'test-tox' + - click ; extra == 'test-tox' + - docutils>=0.22.0 ; extra == 'test-tox' + - equinox ; python_full_version < '3.15' and sys_platform == 'linux' and extra == 'test-tox' + - fastmcp ; python_full_version < '3.14' and extra == 'test-tox' + - jax[cpu] ; python_full_version < '3.15' and sys_platform == 'linux' and extra == 'test-tox' + - jaxtyping ; sys_platform == 'linux' and extra == 'test-tox' + - langchain ; python_full_version < '3.14' and platform_python_implementation != 'PyPy' and sys_platform != 'darwin' and extra == 'test-tox' + - mypy>=0.800 ; platform_python_implementation != 'PyPy' and extra == 'test-tox' + - nuitka>=1.2.6 ; python_full_version < '3.14' and sys_platform == 'linux' and extra == 'test-tox' + - numba ; python_full_version < '3.14' and extra == 'test-tox' + - numpy ; python_full_version < '3.15' and platform_python_implementation != 'PyPy' and sys_platform != 'darwin' and extra == 'test-tox' + - pandera>=0.26.0 ; python_full_version < '3.14' and extra == 'test-tox' + - poetry ; extra == 'test-tox' + - polars ; python_full_version < '3.14' and extra == 'test-tox' + - pygments ; extra == 'test-tox' + - pyinstaller ; extra == 'test-tox' + - pyright>=1.1.370 ; extra == 'test-tox' + - pytest>=6.2.0 ; extra == 'test-tox' + - redis ; extra == 'test-tox' + - rich-click ; extra == 'test-tox' + - sphinx ; extra == 'test-tox' + - sqlalchemy ; extra == 'test-tox' + - torch ; python_full_version < '3.14' and sys_platform == 'linux' and extra == 'test-tox' + - typer ; extra == 'test-tox' + - typing-extensions>=3.10.0.0 ; extra == 'test-tox' + - xarray ; python_full_version < '3.15' and extra == 'test-tox' + - coverage>=5.5 ; extra == 'test-tox-coverage' + requires_python: '>=3.10' - conda: https://conda.anaconda.org/conda-forge/noarch/beautifulsoup4-4.14.3-pyha770c72_0.conda sha256: bf1e71c3c0a5b024e44ff928225a0874fc3c3356ec1a0b6fe719108e6d1288f6 md5: 5267bef8efea4127aacd1f4e1f149b6e @@ -13231,9 +13352,10 @@ packages: timestamp: 1753199211006 - pypi: ./ name: skillmodels - version: 0.0.24.dev333+g5f274a41a.d20260514 - sha256: 0013bc372ff433bdb373ce83593ec898cbac9ba97ce606b6c10a519c58dbf15d + version: 0.0.24.dev338+g220621200.d20260514 + sha256: d6e4277c6d291d2758728c64f846f7413ce3f40fca556e1040e279a3e6c23d41 requires_dist: + - beartype>=0.22 - dags>=0.5.1 - jax>=0.9 - jaxopt>=0.8.5 diff --git a/pyproject.toml b/pyproject.toml index cf306036..3b8bffde 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ classifiers = [ ] dynamic = [ "version" ] dependencies = [ + "beartype>=0.22", "dags>=0.5.1", "jax>=0.9", "jaxopt>=0.8.5", diff --git a/src/skillmodels/__init__.py b/src/skillmodels/__init__.py index 49fe0423..f97183d5 100644 --- a/src/skillmodels/__init__.py +++ b/src/skillmodels/__init__.py @@ -1,21 +1,30 @@ """Skillmodels: A Python package for estimating latent factor models.""" -# Enable 64-bit JAX before any skillmodels submodule -- and crucially before -# any transitive `import jaxopt` -- so jaxopt's module-level jit/sort -# kernels see int64 as the default integer type. Without this, jaxopt's -# `argsort` inside `LBFGSB.update` emits an `s32` accumulator into an -# `s64` scatter operand and XLA's permutation_sort_simplifier verifier -# rejects it on JAX >= 0.10 / cuda13. The package has always assumed -# x64 (every CHS / AF / AMN entry point sets it inside the function); -# centralising it at import time fixes the jaxopt path too and is a -# no-op for callers who already enable it. +# Enable 64-bit JAX before any skillmodels submodule. Every CHS / AF / AMN +# entry point already sets this inside its function body; centralising it +# here makes the package behave consistently for direct callers. import os os.environ.setdefault("JAX_ENABLE_X64", "1") -import contextlib - -import jax +# Workaround for a JAX 0.10 XLA bug surfaced by jaxopt's `LBFGSB.update`. +# The `permutation_sort_simplifier` HLO pass mis-lowers the `argsort` +# inside `update`: it emits an s32 reduction accumulator into the s64 +# scatter operand built by the rest of the optimizer, and the HLO +# verifier rejects the resulting mismatch with `INVALID_ARGUMENT: +# Reduction function's accumulator shape at index 0 differs from the +# init_value shape: s32[] vs s64[]`. Disabling just that one pass via +# `XLA_FLAGS` keeps every other XLA optimisation intact and is a no-op +# on JAX < 0.10 (pre-0.10 lacks the pass). Must be set *before* `import +# jax` because XLA reads `XLA_FLAGS` once at backend init. +_xla_pass_disable = "--xla_disable_hlo_passes=permutation_sort_simplifier" # noqa: S105 +_existing_xla_flags = os.environ.get("XLA_FLAGS", "") +if _xla_pass_disable not in _existing_xla_flags: + os.environ["XLA_FLAGS"] = f"{_existing_xla_flags} {_xla_pass_disable}".strip() + +import contextlib # noqa: E402 + +import jax # noqa: E402 jax.config.update("jax_enable_x64", True) # noqa: FBT003 diff --git a/src/skillmodels/_beartype_conf.py b/src/skillmodels/_beartype_conf.py new file mode 100644 index 00000000..dc50e6b4 --- /dev/null +++ b/src/skillmodels/_beartype_conf.py @@ -0,0 +1,87 @@ +"""Per-exception `BeartypeConf` instances used at the skillmodels perimeter. + +Decorators at user-facing entry points configure beartype to raise the +existing project exception class on parameter-type violations, +preserving the documented exception hierarchy in +`skillmodels.exceptions`. + +The constructors and call sites decorated through this module are the +"perimeter": ModelSpec / FactorSpec / AnchoringSpec / Normalizations, +the three estimation-options dataclasses, and every public function +exposed from the top-level package or the subpackage `__init__`s. The +internal helpers below the perimeter are unannotated for beartype and +trust the perimeter to have already validated parameter types. +""" + +from collections.abc import Callable + +from beartype import BeartypeConf, BeartypeStrategy, beartype + +from skillmodels.exceptions import ( + DiagnosticsCallError, + EstimationCallError, + InferenceCallError, + ModelSpecInitializationError, + OptionsInitializationError, + SimulationCallError, +) + + +def _conf(exc: type[Exception]) -> BeartypeConf: + """Build a `BeartypeConf` that raises `exc` on parameter-type violations. + + `On` strategy: full O(n) container validation so every bad entry in + a mapping/sequence is reported, not just one sampled element. The + decorated entry points are called rarely (construction, estimate, + simulate, plot), so per-call cost is invisible compared to the + JIT-compiled hot path each one kicks off. + + `is_pep484_tower=True`: respect the PEP-484 numeric tower so `int` + satisfies `float`-typed parameters (matches the implicit numeric + conversion that Python and ruff's PYI041 both assume). + """ + return BeartypeConf( + violation_param_type=exc, + strategy=BeartypeStrategy.On, + is_pep484_tower=True, + ) + + +def beartype_init(conf: BeartypeConf) -> Callable[[type], type]: + """Class decorator that wraps only `__init__` with `@beartype(conf=conf)`. + + Bare `@beartype` on a class wraps every method, which surfaces + non-public annotation drift on instance methods that has nothing + to do with parameter validation at construction time (e.g. a + helper method that takes a JAX array typed loosely as `Any`). The + only annotations we actively curate at the perimeter are the + public-facing `__init__` parameters; restrict to those. + """ + + def wrap(cls: type) -> type: + cls.__init__ = beartype(conf=conf)(cls.__init__) # ty: ignore[invalid-assignment] + return cls + + return wrap + + +# Construction of the four user-facing model-spec dataclasses. +MODEL_SPEC_CONF = _conf(ModelSpecInitializationError) + +# Construction of CHSEstimationOptions, AFEstimationOptions, +# AMNEstimationOptions. +OPTIONS_CONF = _conf(OptionsInitializationError) + +# `get_maximization_inputs`, `get_filtered_states`, `estimate_af`, +# `estimate_amn`, `get_af_posterior_states`, +# `get_amn_posterior_states`. +ESTIMATION_CONF = _conf(EstimationCallError) + +# `compute_af_standard_errors`, `compute_amn_standard_errors`. +INFERENCE_CONF = _conf(InferenceCallError) + +# `simulate_dataset`, `simulate_policy_effect`. +SIMULATION_CONF = _conf(SimulationCallError) + +# Diagnostics + visualisation entry points. +DIAGNOSTICS_CONF = _conf(DiagnosticsCallError) diff --git a/src/skillmodels/af/estimate.py b/src/skillmodels/af/estimate.py index d16650c1..7da40958 100644 --- a/src/skillmodels/af/estimate.py +++ b/src/skillmodels/af/estimate.py @@ -9,8 +9,10 @@ import numpy as np import optimagic as om import pandas as pd +from beartype import beartype from jax import Array +from skillmodels._beartype_conf import ESTIMATION_CONF from skillmodels.af.initial_period import estimate_initial_period from skillmodels.af.params import get_measurements_per_factor from skillmodels.af.transition_period import estimate_transition_period @@ -29,6 +31,7 @@ from skillmodels.common.process_model import process_model +@beartype(conf=ESTIMATION_CONF) def estimate_af( # noqa: PLR0915 model_spec: ModelSpec, data: pd.DataFrame, diff --git a/src/skillmodels/af/inference.py b/src/skillmodels/af/inference.py index 2f644b27..05211674 100644 --- a/src/skillmodels/af/inference.py +++ b/src/skillmodels/af/inference.py @@ -39,8 +39,10 @@ import jax.numpy as jnp import numpy as np import pandas as pd +from beartype import beartype from jax import Array +from skillmodels._beartype_conf import INFERENCE_CONF from skillmodels.af.batching import auto_n_obs_per_batch from skillmodels.af.estimate import _extract_period_data from skillmodels.af.halton import create_halton_nodes_and_weights @@ -119,6 +121,7 @@ class AFInferenceResult: """Number of bootstrap replicates drawn.""" +@beartype(conf=INFERENCE_CONF) def compute_af_standard_errors( result: AFEstimationResult, data: pd.DataFrame, diff --git a/src/skillmodels/af/jaxopt_backend.py b/src/skillmodels/af/jaxopt_backend.py index 5c73c4e0..0d5302da 100644 --- a/src/skillmodels/af/jaxopt_backend.py +++ b/src/skillmodels/af/jaxopt_backend.py @@ -15,17 +15,25 @@ import os -# Ensure x64 is on *before* `from jaxopt import LBFGSB` -- jaxopt's -# module-level jit kernels resolve the default integer dtype at import -# time. With x64 off, `jnp.argsort` inside `LBFGSB.update` emits int32 -# indices that scatter into the int64 operand the rest of the optimizer -# builds, and XLA's permutation_sort_simplifier verifier rejects the -# resulting mismatch on JAX >= 0.10. `skillmodels/__init__.py` sets the -# same flag at package import; this is a belt-and-suspenders guard for -# callers that import this module directly. +# Belt-and-suspenders for callers that import this module directly without +# going through `skillmodels/__init__.py`. Two things must be set before +# `import jax` / `from jaxopt import LBFGSB`: +# +# 1. `JAX_ENABLE_X64=1` — the AF pipeline assumes float64 throughout. +# 2. `XLA_FLAGS=--xla_disable_hlo_passes=permutation_sort_simplifier` — +# works around a JAX 0.10 bug where the `argsort` inside +# `LBFGSB.update` emits an s32 reduction accumulator into an s64 +# scatter operand, and XLA's `permutation_sort_simplifier` pass +# rejects the mismatch. See `skillmodels/__init__.py` for the full +# explanation. os.environ.setdefault("JAX_ENABLE_X64", "1") -import jax +_xla_pass_disable = "--xla_disable_hlo_passes=permutation_sort_simplifier" # noqa: S105 +_existing_xla_flags = os.environ.get("XLA_FLAGS", "") +if _xla_pass_disable not in _existing_xla_flags: + os.environ["XLA_FLAGS"] = f"{_existing_xla_flags} {_xla_pass_disable}".strip() + +import jax # noqa: E402 jax.config.update("jax_enable_x64", True) # noqa: FBT003 @@ -136,24 +144,64 @@ def objective_and_grad(free_vec: Array) -> tuple[Array, Array]: val, grad = loglike_and_grad(full_vec) return val, grad[free_idx] + # Match scipy_lbfgsb's stopping rule: stop when EITHER + # * max|projected_grad| < gtol_abs ("gtol channel"), OR + # * (f_k - f_{k+1}) / max(|f_k|, |f_{k+1}|, 1) < ftol_rel + # ("ftol channel"; this is the criterion that typically fires in + # practice for skill-formation likelihoods that go locally flat + # before the gradient does). + # Accept the canonical scipy keys so the same `optimizer_options` + # dict works for both backends; fall back to historical jaxopt + # names for compatibility. + gtol_abs = float(options.pop("convergence_gtol_abs", options.pop("tol", 1e-5))) + ftol_rel = float(options.pop("convergence_ftol_rel", 2.22e-9)) + maxiter = int(options.pop("stopping_maxiter", options.pop("maxiter", 15_000))) + history_size = int(options.pop("history_size", 10)) + solver = LBFGSB( fun=objective_and_grad, value_and_grad=True, - maxiter=int(options.pop("maxiter", 500)), - tol=float(options.pop("tol", 1e-6)), - history_size=int(options.pop("history_size", 10)), + # `maxiter` here is jaxopt's *internal* fail-safe cap; the outer + # Python loop below drives stopping. Set huge so jaxopt never + # interrupts us mid-iteration. + maxiter=maxiter, + tol=gtol_abs, + history_size=history_size, **options, ) - opt_step = solver.run(free_initial, bounds=(free_lower, free_upper)) - final_full = full_template.at[free_idx].set(opt_step.params) # noqa: PD008 + bounds = (free_lower, free_upper) + state = solver.init_state(free_initial, bounds=bounds) + params = free_initial + prev_val = jnp.inf + stopped_on = "maxiter" + n_iter = 0 + # fallback if `maxiter == 0` and the loop body never executes. + for n_iter in range(1, maxiter + 1): # noqa: B007 + params, state = solver.update(params, state, bounds=bounds) + cur_val = state.value + # gtol channel + if bool(state.error < gtol_abs): + stopped_on = "gtol" + break + # ftol channel (skip first iteration where prev_val == inf) + denom = jnp.maximum( + jnp.maximum(jnp.abs(prev_val), jnp.abs(cur_val)), + 1.0, + ) + rel_drop = jnp.abs(prev_val - cur_val) / denom + if bool(jnp.isfinite(prev_val)) and bool(rel_drop < ftol_rel): + stopped_on = "ftol" + break + prev_val = cur_val + + final_full = full_template.at[free_idx].set(params) # noqa: PD008 result_df = full_params_df.copy() result_df["value"] = np.asarray(jax.device_get(final_full)) - n_iter = int(opt_step.state.iter_num) return JaxoptResult( params=result_df, - fun=float(jax.device_get(opt_step.state.value)), - success=n_iter < solver.maxiter, + fun=float(jax.device_get(state.value)), + success=stopped_on != "maxiter", n_iter=n_iter, ) diff --git a/src/skillmodels/af/likelihood.py b/src/skillmodels/af/likelihood.py index 10506e71..c1bbd41f 100644 --- a/src/skillmodels/af/likelihood.py +++ b/src/skillmodels/af/likelihood.py @@ -4,11 +4,12 @@ """ import functools -from collections.abc import Callable +from collections.abc import Callable, Mapping from typing import Any import jax import jax.numpy as jnp +import numpy as np from jax import Array from skillmodels.af.types import ChainLink @@ -230,7 +231,7 @@ def _parse_initial_params( def _map_over_obs( f: Callable, - *xs: Array, + *xs: Array | np.ndarray, n_obs_per_batch: int | None, ) -> Array: """Map ``f`` over the leading axis of ``xs``, optionally in batches. @@ -557,7 +558,7 @@ def af_per_obs_loglike_transition( prev_control_params: Array, prev_loadings_flat: Array, prev_meas_sds: Array, - prev_distribution: dict[str, Array], + prev_distribution: Mapping[str, Array | np.ndarray], chain_links: tuple[ChainLink, ...], obs_factor_values_chain: Array, joint_nodes: Array, @@ -660,7 +661,7 @@ def af_loglike_transition( prev_control_params: Array, prev_loadings_flat: Array, prev_meas_sds: Array, - prev_distribution: dict[str, Array], + prev_distribution: Mapping[str, Array | np.ndarray], chain_links: tuple[ChainLink, ...], obs_factor_values_chain: Array, joint_nodes: Array, @@ -868,7 +869,7 @@ def _transition_loglike_per_obs( prev_meas_mask: Array, prev_full_loadings: Array, prev_meas_sds: Array, - prev_distribution: dict[str, Array], + prev_distribution: Mapping[str, Array | np.ndarray], chain_links: tuple[ChainLink, ...], obs_factor_values_chain: Array, joint_nodes: Array, @@ -967,8 +968,8 @@ def _single_obs( def _compute_investment( theta_prev: Array, obs_factor_values: Array, - inv_eq_params: Array, - inv_sds: Array, + inv_eq_params: Array | np.ndarray, + inv_sds: Array | np.ndarray, eps_i: Array, n_endogenous_factors: int, n_state_factors: int, @@ -1001,8 +1002,8 @@ def _rebuild_chain_at_period( z_state: Array, z_inv_per_step: Array, z_shock_per_step: Array, - initial_mean: Array, - initial_chol: Array, + initial_mean: Array | np.ndarray, + initial_chol: Array | np.ndarray, chain_links: tuple[ChainLink, ...], obs_factor_values_at_obs_per_step: Array, n_state_factors: int, @@ -1043,7 +1044,7 @@ def _rebuild_chain_at_period( step), shape (n_state_factors,). When `chain_links` is empty, returns the period-0 state directly. """ - theta = initial_mean + initial_chol @ z_state + theta = jnp.asarray(initial_mean + initial_chol @ z_state) for step_idx, link in enumerate(chain_links): z_inv = z_inv_per_step[step_idx] z_shock = z_shock_per_step[step_idx] @@ -1080,9 +1081,9 @@ def _integrate_transition_single_obs( prev_meas_mask: Array, prev_full_loadings: Array, prev_meas_sds: Array, - obs_cond_weights: Array, - obs_cond_means: Array, - cond_chols: Array, + obs_cond_weights: Array | np.ndarray, + obs_cond_means: Array | np.ndarray, + cond_chols: Array | np.ndarray, chain_links: tuple[ChainLink, ...], obs_factor_values_chain: Array, joint_nodes: Array, diff --git a/src/skillmodels/af/posterior_states.py b/src/skillmodels/af/posterior_states.py index a0146de5..a1e9a241 100644 --- a/src/skillmodels/af/posterior_states.py +++ b/src/skillmodels/af/posterior_states.py @@ -10,8 +10,10 @@ import jax.numpy as jnp import numpy as np import pandas as pd +from beartype import beartype from jax import Array +from skillmodels._beartype_conf import ESTIMATION_CONF from skillmodels.af.halton import create_halton_nodes_and_weights from skillmodels.af.initial_period import _build_loading_mask, _get_ordered_measures from skillmodels.af.likelihood import _log_normal_pdf @@ -21,6 +23,7 @@ from skillmodels.common.state_ranges import create_state_ranges +@beartype(conf=ESTIMATION_CONF) def get_af_posterior_states( af_result: AFEstimationResult, model_spec: ModelSpec, diff --git a/src/skillmodels/af/transition_period.py b/src/skillmodels/af/transition_period.py index d8015b85..1b5e204c 100644 --- a/src/skillmodels/af/transition_period.py +++ b/src/skillmodels/af/transition_period.py @@ -366,7 +366,7 @@ def _run_transition_optimization( prev_measurements: Array, prev_controls: Array, loading_mask: np.ndarray, - prev_dist_arrays: dict[str, Array], + prev_dist_arrays: dict[str, Array | np.ndarray], chain_links: tuple[ChainLink, ...], obs_factor_values_chain: Array, joint_nodes: Array, @@ -685,7 +685,7 @@ def _prepare_transition_inputs( transition_info: TransitionInfo, factors: tuple[str, ...], n_obs: int, -) -> tuple[dict[str, Array], int]: +) -> tuple[dict[str, Array | np.ndarray], int]: """Pack the period-0 conditional distribution payload for the likelihood. Returns a dict the transition likelihood reads to seed its on-demand @@ -1181,10 +1181,10 @@ def _update_conditional_distribution( # `joint_nodes.shape[0]` rows. z_block_curr = n_shock + n_endog - def _chain_one_component(prev_sample: Array) -> Array: + def _chain_one_component(prev_sample: Array | np.ndarray) -> Array: """Map (j, i) -> theta_t given prev_sample (n_halton, n_obs, n_state).""" - def _at_node(j_idx: int, i_idx: int) -> Array: + def _at_node(j_idx: int | Array, i_idx: int | Array) -> Array: theta_prev = prev_sample[j_idx, i_idx] obs_y = ( observed_factor_values[i_idx] diff --git a/src/skillmodels/af/types.py b/src/skillmodels/af/types.py index d45064de..15af8ed7 100644 --- a/src/skillmodels/af/types.py +++ b/src/skillmodels/af/types.py @@ -3,18 +3,19 @@ from collections.abc import Callable, Mapping from dataclasses import dataclass, field from types import MappingProxyType -from typing import TYPE_CHECKING, Any, Literal +from typing import Any, Literal import jax +import numpy as np import pandas as pd from jax import Array +from skillmodels._beartype_conf import OPTIONS_CONF, beartype_init +from skillmodels.common.model_spec import ModelSpec from skillmodels.common.types import ensure_containers_are_immutable -if TYPE_CHECKING: - from skillmodels.common.model_spec import ModelSpec - +@beartype_init(OPTIONS_CONF) @dataclass(frozen=True, init=False) class AFEstimationOptions: """Configuration options for the AF estimator.""" @@ -183,10 +184,10 @@ def __init__( # noqa: D107 class MixtureComponent: """Single component of a Gaussian mixture distribution.""" - mean: Array + mean: Array | np.ndarray """Mean vector, shape (n_factors,).""" - chol_cov: Array + chol_cov: Array | np.ndarray """Lower-triangular Cholesky factor of covariance, shape (n_factors, n_factors).""" @@ -208,30 +209,30 @@ class ChainLink: transition_func: Callable """Combined per-factor transition function f(full_states, params).""" - transition_params: Array + transition_params: Array | np.ndarray """Flat transition parameter vector for this period, shape ``(total_n_transition_params,)``.""" - shock_sds: Array + shock_sds: Array | np.ndarray """Production shock SDs for shock-bearing state factors, shape ``(n_shock_factors,)``.""" - shock_factor_indices: Array + shock_factor_indices: Array | np.ndarray """Mapping each shock slot to its position in the state-factor ordering, shape ``(n_shock_factors,)`` int.""" - inv_eq_params: Array + inv_eq_params: Array | np.ndarray """Flat investment-equation parameters, shape ``(n_endogenous * n_inv_eq_params_per,)``.""" - inv_sds: Array + inv_sds: Array | np.ndarray """Investment shock SDs, shape ``(n_endogenous,)``.""" n_inv_eq_params_per: int """Investment equation parameters per endogenous factor (1 + n_state + n_observed_factors when n_endogenous > 0; 0 otherwise).""" - obs_factor_values: Array + obs_factor_values: Array | np.ndarray """Observed factor values at this link's source period (i.e. period - 1), shape ``(n_obs, n_observed_factors)``. Used in the chain rebuild for the inv equation and the transition function.""" @@ -280,7 +281,7 @@ class ConditionalDistribution: inside the transition likelihood (which rebuilds the chain on-demand). """ - mixture_weights: Array + mixture_weights: Array | np.ndarray """Mixture weights, shape (n_components,).""" components: tuple[MixtureComponent, ...] @@ -288,27 +289,27 @@ class ConditionalDistribution: importance sample. Used by `posterior_states` and `inference`; not used in the transition likelihood itself.""" - samples_per_component: tuple[Array, ...] + samples_per_component: tuple[Array | np.ndarray, ...] """One importance-sample array per mixture component, each shape ``(n_halton, n_obs, n_state)``. Retained for posterior-state summary statistics; not consumed by the transition likelihood (which rebuilds the chain on-demand from a joint Halton). May use a smaller Halton count than the likelihood's `n_halton_points`.""" - conditional_weights: Array | None = None + conditional_weights: Array | np.ndarray | None = None """Individual-specific conditional mixture weights, shape (n_obs, n_components). When not None, these override `mixture_weights` for each observation (computed from Bayes' rule using data from previous periods). """ - cond_means: Array | None = None + cond_means: Array | np.ndarray | None = None """Per-obs Schur-conditional means of the latent state given observed factors at period 0, shape ``(n_components, n_obs, n_state)``. Built by the initial period only. None for transition-period distributions. """ - cond_chols: Array | None = None + cond_chols: Array | np.ndarray | None = None """Per-component Schur-conditional Cholesky factors at period 0, shape ``(n_components, n_state, n_state)``. Shared across observations because the conditional covariance does not depend on Y_i (it's the diff --git a/src/skillmodels/amn/estimate.py b/src/skillmodels/amn/estimate.py index 3914156d..bc410450 100644 --- a/src/skillmodels/amn/estimate.py +++ b/src/skillmodels/amn/estimate.py @@ -12,7 +12,9 @@ import optimagic as om import pandas as pd +from beartype import beartype +from skillmodels._beartype_conf import ESTIMATION_CONF from skillmodels.amn.minimum_distance import solve_minimum_distance from skillmodels.amn.mixture_em import ( build_augmented_measure_layout, @@ -90,6 +92,7 @@ def _apply_overrides( return out.sort_index() +@beartype(conf=ESTIMATION_CONF) def estimate_amn( model_spec: ModelSpec, data: pd.DataFrame, diff --git a/src/skillmodels/amn/inference.py b/src/skillmodels/amn/inference.py index 2c119c08..90a89f24 100644 --- a/src/skillmodels/amn/inference.py +++ b/src/skillmodels/amn/inference.py @@ -24,7 +24,9 @@ import numpy as np import pandas as pd +from beartype import beartype +from skillmodels._beartype_conf import INFERENCE_CONF from skillmodels.amn.estimate import estimate_amn from skillmodels.amn.types import ( AMNEstimationOptions, @@ -54,6 +56,7 @@ def _resample_by_caseid(data: pd.DataFrame, rng: np.random.Generator) -> pd.Data return pd.concat(pieces) +@beartype(conf=INFERENCE_CONF) def compute_amn_standard_errors( result: AMNEstimationResult, data: pd.DataFrame, diff --git a/src/skillmodels/amn/posterior_states.py b/src/skillmodels/amn/posterior_states.py index 242f626f..9e9d9247 100644 --- a/src/skillmodels/amn/posterior_states.py +++ b/src/skillmodels/amn/posterior_states.py @@ -28,13 +28,16 @@ import numpy as np import pandas as pd +from beartype import beartype +from skillmodels._beartype_conf import ESTIMATION_CONF from skillmodels.amn.mixture_em import build_augmented_measure_matrix from skillmodels.amn.types import AMNEstimationResult from skillmodels.common.process_model import process_model from skillmodels.common.state_ranges import create_state_ranges +@beartype(conf=ESTIMATION_CONF) def get_amn_posterior_states( # noqa: C901, PLR0912, PLR0915 amn_result: AMNEstimationResult, data: pd.DataFrame, diff --git a/src/skillmodels/amn/types.py b/src/skillmodels/amn/types.py index 4d9da73f..17ea0806 100644 --- a/src/skillmodels/amn/types.py +++ b/src/skillmodels/amn/types.py @@ -10,17 +10,17 @@ from collections.abc import Mapping from dataclasses import dataclass from types import MappingProxyType -from typing import TYPE_CHECKING, Any, Literal +from typing import Any, Literal import numpy as np import pandas as pd +from skillmodels._beartype_conf import OPTIONS_CONF, beartype_init +from skillmodels.common.model_spec import ModelSpec from skillmodels.common.types import ensure_containers_are_immutable -if TYPE_CHECKING: - from skillmodels.common.model_spec import ModelSpec - +@beartype_init(OPTIONS_CONF) @dataclass(frozen=True, init=False) class AMNEstimationOptions: """Configuration options for the AMN estimator.""" diff --git a/src/skillmodels/chs/filtered_states.py b/src/skillmodels/chs/filtered_states.py index 7178794d..d2c7c87a 100644 --- a/src/skillmodels/chs/filtered_states.py +++ b/src/skillmodels/chs/filtered_states.py @@ -1,20 +1,26 @@ """Functions to compute and process filtered latent states.""" -from typing import TYPE_CHECKING, Any +from typing import Any import pandas as pd +from beartype import beartype +from skillmodels._beartype_conf import ESTIMATION_CONF + +# Runtime imports (not `TYPE_CHECKING`-guarded) so that the beartype +# perimeter at `get_filtered_states` can resolve the annotation +# without a forward-ref string. The two `types` modules are +# negligible-cost; AMN's already pulls sklearn lazily. +from skillmodels.af.types import AFEstimationResult +from skillmodels.amn.types import AMNEstimationResult from skillmodels.chs.maximization_inputs import get_maximization_inputs from skillmodels.common.anchoring import anchor_states_df from skillmodels.common.model_spec import ModelSpec from skillmodels.common.process_model import process_model from skillmodels.common.state_ranges import create_state_ranges -if TYPE_CHECKING: - from skillmodels.af.types import AFEstimationResult - from skillmodels.amn.types import AMNEstimationResult - +@beartype(conf=ESTIMATION_CONF) def get_filtered_states( model_spec: ModelSpec, data: pd.DataFrame, diff --git a/src/skillmodels/chs/kalman_filters.py b/src/skillmodels/chs/kalman_filters.py index 01786217..72360cd9 100644 --- a/src/skillmodels/chs/kalman_filters.py +++ b/src/skillmodels/chs/kalman_filters.py @@ -29,7 +29,7 @@ def kalman_update( upper_chols: Array, loadings: Array, control_params: Array, - meas_sd: Array, + meas_sd: float | Array, measurements: Array, controls: Array, log_mixture_weights: Array, @@ -168,7 +168,7 @@ def kalman_predict( transition_func: Callable, states: Array, upper_chols: Array, - sigma_scaling_factor: float, + sigma_scaling_factor: float | Array, sigma_weights: Array, trans_coeffs: dict[str, Array], shock_sds: Array, @@ -241,7 +241,7 @@ def linear_kalman_predict( transition_func: Callable | None, # noqa: ARG001 states: Array, upper_chols: Array, - sigma_scaling_factor: float, # noqa: ARG001 + sigma_scaling_factor: float | Array, # noqa: ARG001 sigma_weights: Array, # noqa: ARG001 trans_coeffs: dict[str, Array], shock_sds: Array, @@ -375,7 +375,7 @@ def _build_f_and_c( def _calculate_sigma_points( states: Array, upper_chols: Array, - scaling_factor: float, + scaling_factor: float | Array, observed_factors: Array, ) -> Array: """Calculate the array of sigma_points for the unscented transform. diff --git a/src/skillmodels/chs/kalman_filters_debug.py b/src/skillmodels/chs/kalman_filters_debug.py index 230b768d..c7dd20a8 100644 --- a/src/skillmodels/chs/kalman_filters_debug.py +++ b/src/skillmodels/chs/kalman_filters_debug.py @@ -14,7 +14,7 @@ def kalman_update( upper_chols: Array, loadings: Array, control_params: Array, - meas_sd: float, + meas_sd: float | Array, measurements: Array, controls: Array, log_mixture_weights: Array, diff --git a/src/skillmodels/chs/likelihood.py b/src/skillmodels/chs/likelihood.py index c53fc039..2f9a5225 100644 --- a/src/skillmodels/chs/likelihood.py +++ b/src/skillmodels/chs/likelihood.py @@ -6,6 +6,7 @@ import jax import jax.numpy as jnp +import numpy as np from jax import Array from skillmodels.chs.clipping import soft_clipping @@ -23,18 +24,18 @@ def log_likelihood( params: Array, parsing_info: ParsingInfo, - measurements: Array, + measurements: Array | np.ndarray, controls: Array, predict_func: Callable, - sigma_scaling_factor: float, + sigma_scaling_factor: float | Array, sigma_weights: Array, dimensions: Dimensions, labels: Labels, chs_estimation_options: CHSEstimationOptions, - is_measurement_iteration: Array, - is_predict_iteration: Array, - iteration_to_period: Array, - observed_factors: Array, + is_measurement_iteration: Array | np.ndarray, + is_predict_iteration: Array | np.ndarray, + iteration_to_period: Array | np.ndarray, + observed_factors: Array | np.ndarray, ) -> Array: """Aggregated log likelihood of a skill formation model. @@ -91,18 +92,18 @@ def log_likelihood( def log_likelihood_obs( params: Array, parsing_info: ParsingInfo, - measurements: Array, + measurements: Array | np.ndarray, controls: Array, predict_func: Callable, - sigma_scaling_factor: float, + sigma_scaling_factor: float | Array, sigma_weights: Array, dimensions: Dimensions, labels: Labels, chs_estimation_options: CHSEstimationOptions, - is_measurement_iteration: Array, - is_predict_iteration: Array, - iteration_to_period: Array, - observed_factors: Array, + is_measurement_iteration: Array | np.ndarray, + is_predict_iteration: Array | np.ndarray, + iteration_to_period: Array | np.ndarray, + observed_factors: Array | np.ndarray, ) -> Array: """Log likelihood of a skill formation model. @@ -200,10 +201,10 @@ def _scan_body( loop_args: dict[str, Array], controls: Array, parsed_params: ParsedParams, - sigma_scaling_factor: float, + sigma_scaling_factor: float | Array, sigma_weights: Array, predict_func: Callable, - observed_factors: Array, + observed_factors: Array | np.ndarray, ) -> tuple[dict[str, Array], dict[str, Array]]: # ================================================================================== # create arguments needed for update diff --git a/src/skillmodels/chs/likelihood_debug.py b/src/skillmodels/chs/likelihood_debug.py index 8f729174..ff22194c 100644 --- a/src/skillmodels/chs/likelihood_debug.py +++ b/src/skillmodels/chs/likelihood_debug.py @@ -6,6 +6,7 @@ import jax import jax.numpy as jnp +import numpy as np from jax import Array from skillmodels.chs.clipping import soft_clipping @@ -23,18 +24,18 @@ def log_likelihood( params: Array, parsing_info: ParsingInfo, - measurements: Array, + measurements: Array | np.ndarray, controls: Array, predict_func: Callable[..., tuple[Array, Array]], - sigma_scaling_factor: float, + sigma_scaling_factor: float | Array, sigma_weights: Array, dimensions: Dimensions, labels: Labels, chs_estimation_options: CHSEstimationOptions, - is_measurement_iteration: Array, - is_predict_iteration: Array, - iteration_to_period: Array, - observed_factors: Array, + is_measurement_iteration: Array | np.ndarray, + is_predict_iteration: Array | np.ndarray, + iteration_to_period: Array | np.ndarray, + observed_factors: Array | np.ndarray, ) -> dict[str, Any]: """Log likelihood of a skill formation model, returning debug data on top. @@ -151,10 +152,10 @@ def _scan_body( loop_args: dict[str, Array], controls: Array, parsed_params: ParsedParams, - sigma_scaling_factor: float, + sigma_scaling_factor: float | Array, sigma_weights: Array, predict_func: Callable[..., tuple[Array, Array]], - observed_factors: Array, + observed_factors: Array | np.ndarray, ) -> tuple[dict[str, Array], dict[str, Any]]: # ================================================================================== # create arguments needed for update diff --git a/src/skillmodels/chs/maximization_inputs.py b/src/skillmodels/chs/maximization_inputs.py index f07853f7..f8f99fb7 100644 --- a/src/skillmodels/chs/maximization_inputs.py +++ b/src/skillmodels/chs/maximization_inputs.py @@ -8,11 +8,13 @@ import jax.numpy as jnp import numpy as np import pandas as pd +from beartype import beartype from jax import Array from numpy.typing import NDArray import skillmodels.chs.likelihood as lf import skillmodels.chs.likelihood_debug as lfd +from skillmodels._beartype_conf import ESTIMATION_CONF from skillmodels.amn.estimate import estimate_amn from skillmodels.amn.start_values import get_spearman_start_params from skillmodels.chs.kalman_filters import ( @@ -41,6 +43,7 @@ jax.config.update("jax_enable_x64", True) # noqa: FBT003 +@beartype(conf=ESTIMATION_CONF) def get_maximization_inputs( # noqa: C901, PLR0915 model_spec: ModelSpec, data: pd.DataFrame, diff --git a/src/skillmodels/chs/options.py b/src/skillmodels/chs/options.py index f53f650e..47c2351b 100644 --- a/src/skillmodels/chs/options.py +++ b/src/skillmodels/chs/options.py @@ -3,7 +3,10 @@ from dataclasses import dataclass from typing import Literal +from skillmodels._beartype_conf import OPTIONS_CONF, beartype_init + +@beartype_init(OPTIONS_CONF) @dataclass(frozen=True) class CHSEstimationOptions: """Tuning parameters for the CHS Kalman-MLE estimator.""" diff --git a/src/skillmodels/chs/process_debug_data.py b/src/skillmodels/chs/process_debug_data.py index 828e8f83..5cbbe81a 100644 --- a/src/skillmodels/chs/process_debug_data.py +++ b/src/skillmodels/chs/process_debug_data.py @@ -110,7 +110,7 @@ def process_debug_data( def _create_post_update_states( - filtered_states: Array, + filtered_states: Array | np.ndarray, factors: tuple[str, ...], update_info: pd.DataFrame, ) -> pd.DataFrame: @@ -129,7 +129,7 @@ def _create_post_update_states( def _convert_state_array_to_df( - arr: NDArray[np.floating[Any]], + arr: NDArray[np.float64], factor_names: tuple[str, ...], ) -> pd.DataFrame: """Convert a 3d state array into a 2d DataFrame. @@ -145,8 +145,8 @@ def _convert_state_array_to_df( def _create_filtered_states( - filtered_states: Array, - log_mixture_weights: Array, + filtered_states: Array | np.ndarray, + log_mixture_weights: Array | np.ndarray, update_info: pd.DataFrame, factors: tuple[str, ...], ) -> pd.DataFrame: @@ -176,7 +176,7 @@ def _create_filtered_states( def _process_residuals( - residuals: Array, + residuals: Array | np.ndarray | list, update_info: pd.DataFrame, ) -> pd.DataFrame: to_concat = [] @@ -192,14 +192,14 @@ def _process_residuals( def _process_residual_sds( - residual_sds: Array, + residual_sds: Array | np.ndarray, update_info: pd.DataFrame, ) -> pd.DataFrame: return _process_residuals(residuals=residual_sds, update_info=update_info) def _process_all_contributions( - all_contributions: Array, + all_contributions: Array | np.ndarray, update_info: pd.DataFrame, ) -> pd.DataFrame: to_concat = [] diff --git a/src/skillmodels/common/check_model.py b/src/skillmodels/common/check_model.py index e6c94c93..90d853cd 100644 --- a/src/skillmodels/common/check_model.py +++ b/src/skillmodels/common/check_model.py @@ -1,6 +1,7 @@ """Functions to validate model specifications.""" from collections.abc import Mapping +from typing import Any import numpy as np @@ -90,7 +91,12 @@ def check_stagemap( return report -def _check_anchoring(anchoring: Anchoring) -> list[str]: +def _check_anchoring(anchoring: Any) -> list[str]: # noqa: ANN401 + """Validate anchoring attributes. + + Runtime-typed because callers may pass duck-typed namespaces or + partially-built objects. + """ report = [] if not isinstance(anchoring.anchoring, bool): report.append("anchoring.anchoring must be a bool.") diff --git a/src/skillmodels/common/constraints.py b/src/skillmodels/common/constraints.py index 01b50af5..0207d888 100644 --- a/src/skillmodels/common/constraints.py +++ b/src/skillmodels/common/constraints.py @@ -3,7 +3,6 @@ import functools import warnings from collections.abc import Iterable, Mapping -from dataclasses import dataclass from typing import Any import numpy as np @@ -11,6 +10,7 @@ import pandas as pd import skillmodels.common.transition_functions as t_f_module +from skillmodels.common.fixed_constraint import FixedConstraintWithValue from skillmodels.common.selector import align_index_names, select_by_loc from skillmodels.common.types import ( Anchoring, @@ -99,35 +99,6 @@ def reconcile_start_to_equality( return out -@dataclass(frozen=True) -class FixedConstraintWithValue(om.FixedConstraint): - """Fixed constraint that carries the target value and parameter location. - - `om.FixedConstraint` fixes parameters at their start values but does not carry a - target value. This wrapper adds `loc` (the parameter location in the params - DataFrame) and `value` (the value to set before optimization). - """ - - loc: pd.MultiIndex | tuple | str | None = None - """Parameter location in the params DataFrame.""" - value: float | None = None - """Value to enforce on the parameter.""" - - def __post_init__(self) -> None: - """Validate that `loc` and `value` are not None and derive `selector`.""" - if self.loc is None: - msg = "loc must not be None" - raise TypeError(msg) - if self.value is None: - msg = "value must not be None" - raise TypeError(msg) - object.__setattr__( - self, - "selector", - functools.partial(select_by_loc, loc=self.loc), - ) - - def collect_fixed_locs( constraints: Iterable[om.constraints.Constraint], ) -> set[tuple[Any, ...]]: @@ -308,7 +279,7 @@ def add_bounds(params: pd.DataFrame, bounds_distance: float) -> pd.DataFrame: return df -def _is_diagonal_entry(ind_tup: tuple[str, ...]) -> bool: +def _is_diagonal_entry(ind_tup: tuple[Any, ...]) -> bool: name2 = ind_tup[-1] middle_pos = int(len(name2) // 2) if ( diff --git a/src/skillmodels/common/correlation_heatmap.py b/src/skillmodels/common/correlation_heatmap.py index e8642312..9225e4b0 100644 --- a/src/skillmodels/common/correlation_heatmap.py +++ b/src/skillmodels/common/correlation_heatmap.py @@ -4,15 +4,18 @@ import numpy as np import pandas as pd +from beartype import beartype from numpy.typing import NDArray from plotly import graph_objects as go +from skillmodels._beartype_conf import DIAGNOSTICS_CONF from skillmodels.common.model_spec import ModelSpec from skillmodels.common.process_data import pre_process_data from skillmodels.common.process_model import process_model from skillmodels.common.types import ProcessedModel +@beartype(conf=DIAGNOSTICS_CONF) def plot_correlation_heatmap( corr: pd.DataFrame, heatmap_kwargs: dict[str, Any] | None = None, @@ -132,6 +135,7 @@ def plot_correlation_heatmap( return fig +@beartype(conf=DIAGNOSTICS_CONF) def get_measurements_corr( data: pd.DataFrame, model_spec: ModelSpec, @@ -175,6 +179,7 @@ def get_measurements_corr( return df.corr() +@beartype(conf=DIAGNOSTICS_CONF) def get_quasi_scores_corr( data: pd.DataFrame, model_spec: ModelSpec, @@ -221,6 +226,7 @@ def get_quasi_scores_corr( return df.corr() +@beartype(conf=DIAGNOSTICS_CONF) def get_scores_corr( data: pd.DataFrame, params: pd.DataFrame, @@ -786,10 +792,15 @@ def _get_factor_scores_data_for_multiple_periods( def _process_factors( - model: ProcessedModel, + model: Any, # noqa: ANN401 factors: list[str] | tuple[str, ...] | str | None, ) -> tuple[tuple[str, ...], tuple[str, ...]]: - """Process factors to get a tuple of tuples.""" + """Process factors to get a tuple of tuples. + + `model` is annotated `Any` because tests pass minimal duck-typed + namespaces that only expose `.labels.latent_factors` / + `.labels.observed_factors`. Production callers pass a `ProcessedModel`. + """ if not factors: latent_factors = model.labels.latent_factors observed_factors = model.labels.observed_factors diff --git a/src/skillmodels/common/diagnostic_plots.py b/src/skillmodels/common/diagnostic_plots.py index 574656f3..66fb8e84 100644 --- a/src/skillmodels/common/diagnostic_plots.py +++ b/src/skillmodels/common/diagnostic_plots.py @@ -5,11 +5,14 @@ import numpy as np import pandas as pd import plotly.graph_objects as go +from beartype import beartype +from skillmodels._beartype_conf import DIAGNOSTICS_CONF from skillmodels.common.model_spec import ModelSpec from skillmodels.common.process_model import process_model +@beartype(conf=DIAGNOSTICS_CONF) def plot_residual_boxplots( model_spec: ModelSpec, *, @@ -90,7 +93,7 @@ def plot_residual_boxplots( def _create_residual_boxplot_for_period( residuals_df: pd.DataFrame, - period: int, + period: int | np.integer, period_col: str, *, show_reference_line: bool, @@ -133,6 +136,7 @@ def _create_residual_boxplot_for_period( return fig +@beartype(conf=DIAGNOSTICS_CONF) def plot_likelihood_contributions( model_spec: ModelSpec, *, @@ -206,7 +210,7 @@ def plot_likelihood_contributions( def _create_likelihood_boxplot_for_period( contributions_df: pd.DataFrame, - period: int, + period: int | np.integer, period_col: str, layout_kwargs: dict[str, Any] | None, ) -> go.Figure: diff --git a/src/skillmodels/common/fixed_constraint.py b/src/skillmodels/common/fixed_constraint.py new file mode 100644 index 00000000..0296f7d2 --- /dev/null +++ b/src/skillmodels/common/fixed_constraint.py @@ -0,0 +1,46 @@ +"""`FixedConstraintWithValue`: leaf data type used across constraint code. + +Lives in its own module so that low-level callers (`transition_functions`, +`af/params`, etc.) can import it without triggering the heavier +`skillmodels.common.constraints` module — `constraints.py` imports +`transition_functions`, which would otherwise force a circular import or +a `TYPE_CHECKING` guard that beartype.claw cannot resolve at decoration +time. +""" + +import functools +from dataclasses import dataclass + +import optimagic as om +import pandas as pd + +from skillmodels.common.selector import select_by_loc + + +@dataclass(frozen=True) +class FixedConstraintWithValue(om.FixedConstraint): + """Fixed constraint that carries the target value and parameter location. + + `om.FixedConstraint` fixes parameters at their start values but does not carry a + target value. This wrapper adds `loc` (the parameter location in the params + DataFrame) and `value` (the value to set before optimization). + """ + + loc: pd.MultiIndex | tuple | str | None = None + """Parameter location in the params DataFrame.""" + value: float | None = None + """Value to enforce on the parameter.""" + + def __post_init__(self) -> None: + """Validate that `loc` and `value` are not None and derive `selector`.""" + if self.loc is None: + msg = "loc must not be None" + raise TypeError(msg) + if self.value is None: + msg = "value must not be None" + raise TypeError(msg) + object.__setattr__( + self, + "selector", + functools.partial(select_by_loc, loc=self.loc), + ) diff --git a/src/skillmodels/common/model_spec.py b/src/skillmodels/common/model_spec.py index 708ccada..53ccba78 100644 --- a/src/skillmodels/common/model_spec.py +++ b/src/skillmodels/common/model_spec.py @@ -11,12 +11,14 @@ from types import MappingProxyType from typing import Any, Self +from skillmodels._beartype_conf import MODEL_SPEC_CONF, beartype_init from skillmodels.common.types import ( Normalizations, ensure_containers_are_immutable, ) +@beartype_init(MODEL_SPEC_CONF) @dataclass(frozen=True) class FactorSpec: """Specification for a single latent factor.""" @@ -62,6 +64,7 @@ def with_normalizations(self, normalizations: Normalizations) -> Self: return replace(self, normalizations=normalizations) +@beartype_init(MODEL_SPEC_CONF) @dataclass(frozen=True) class AnchoringSpec: """Specification for anchoring latent factors to outcomes.""" @@ -83,6 +86,7 @@ def __post_init__(self) -> None: # noqa: D105 ) +@beartype_init(MODEL_SPEC_CONF) @dataclass(frozen=True, init=False) class ModelSpec: """Complete model specification. diff --git a/src/skillmodels/common/process_model.py b/src/skillmodels/common/process_model.py index dc5f2498..6bdb61ab 100644 --- a/src/skillmodels/common/process_model.py +++ b/src/skillmodels/common/process_model.py @@ -130,7 +130,7 @@ def get_has_endogenous_factors(factors: Mapping[str, FactorSpec]) -> bool: "A factor cannot be a correction and not endogenous, got:\n" f"{endogenous_factors}" ) - return endogenous_factors["is_endogenous"].any() # ty: ignore[invalid-return-type] + return bool(endogenous_factors["is_endogenous"].any()) def get_dimensions( @@ -173,7 +173,7 @@ def _get_aug_periods_to_periods( def _aug_periods_from_period( - period: int, aug_periods_to_periods: dict[int, int] + period: int, aug_periods_to_periods: Mapping[int, int] ) -> list[int]: """The inverse of the the aug_periods_to_periods mapper.""" return [ap for ap, p in aug_periods_to_periods.items() if p == period] @@ -389,7 +389,7 @@ def _extract_factor(states: Array, pos: int) -> Array: return TransitionInfo( func=transition_function, param_names=MappingProxyType( - dict(zip(latent_factors, param_names, strict=False)) + dict(zip(latent_factors, (tuple(p) for p in param_names), strict=False)) ), individual_functions=MappingProxyType(individual_functions), function_names=MappingProxyType( diff --git a/src/skillmodels/common/simulate_data.py b/src/skillmodels/common/simulate_data.py index 0ea88ad8..67060b7b 100644 --- a/src/skillmodels/common/simulate_data.py +++ b/src/skillmodels/common/simulate_data.py @@ -6,9 +6,11 @@ import jax.numpy as jnp import numpy as np import pandas as pd +from beartype import beartype from jax import Array from numpy.typing import NDArray +from skillmodels._beartype_conf import SIMULATION_CONF from skillmodels.common.anchoring import anchor_states_df from skillmodels.common.model_spec import ModelSpec from skillmodels.common.params_index import get_params_index @@ -27,6 +29,7 @@ ) +@beartype(conf=SIMULATION_CONF) def simulate_dataset( model_spec: ModelSpec, params: pd.DataFrame, @@ -108,7 +111,7 @@ def simulate_dataset( params = params.reindex(params_index) parsing_info = create_parsing_info( - params_index=params.index, # ty: ignore[invalid-argument-type] + params_index=params_index, update_info=processed_model.update_info, labels=processed_model.labels, anchoring=processed_model.anchoring, @@ -144,7 +147,7 @@ def simulate_dataset( update_info=processed_model.update_info, control_data=control_data, observed_factors=observed_factors, - policies=policies, # ty: ignore[invalid-argument-type] + policies=policies, transition_info=processed_model.transition_info, rng=rng, ) @@ -203,7 +206,7 @@ def _simulate_dataset( update_info: pd.DataFrame, control_data: Array, observed_factors: Array, - policies: list[dict], + policies: list[dict] | None, transition_info: TransitionInfo, rng: np.random.Generator, ) -> tuple[pd.DataFrame, pd.DataFrame]: @@ -329,8 +332,8 @@ def _simulate_dataset( meas = pd.DataFrame( data=measurements_from_states( rng=rng, - states=latent_states[t], # ty: ignore[invalid-argument-type] - controls=control_data[t], # ty: ignore[invalid-argument-type] + states=latent_states[t], + controls=control_data[t], loadings=loadings_df.loc[t].to_numpy(), control_params=control_params_df.loc[t].to_numpy(), sds=meas_sds.loc[t].to_numpy().flatten(), @@ -430,7 +433,7 @@ def _get_shock( mean: float, sd: float, size: int, -) -> NDArray[np.floating]: +) -> NDArray[np.floating] | Array: """Add stochastic effect to a factor of length n_obs. Args: @@ -457,8 +460,8 @@ def generate_start_states( n_obs: int, dimensions: Dimensions, dist_args: list[dict], - weights: NDArray[np.floating], -) -> NDArray[np.floating]: + weights: NDArray[np.floating] | Array, +) -> NDArray[np.floating] | Array: """Draw initial states and control variables from a (mixture of) normals. Args: @@ -489,12 +492,12 @@ def generate_start_states( def measurements_from_states( rng: np.random.Generator, - states: NDArray[np.floating], - controls: NDArray[np.floating], - loadings: NDArray[np.floating], - control_params: NDArray[np.floating], - sds: NDArray[np.floating], -) -> NDArray[np.floating]: + states: NDArray[np.floating] | Array, + controls: NDArray[np.floating] | Array, + loadings: NDArray[np.floating] | Array, + control_params: NDArray[np.floating] | Array, + sds: NDArray[np.floating] | Array, +) -> NDArray[np.floating] | Array: """Generate the variables that would be observed in practice. This generates the data for only one period. Let n_meas be the number @@ -523,6 +526,7 @@ def measurements_from_states( return states_part + control_part + epsilon +@beartype(conf=SIMULATION_CONF) def simulate_policy_effect( model_spec: ModelSpec, params: pd.DataFrame, diff --git a/src/skillmodels/common/state_ranges.py b/src/skillmodels/common/state_ranges.py index 4203c628..af7b88fc 100644 --- a/src/skillmodels/common/state_ranges.py +++ b/src/skillmodels/common/state_ranges.py @@ -14,8 +14,12 @@ """ import pandas as pd +from beartype import beartype +from skillmodels._beartype_conf import DIAGNOSTICS_CONF + +@beartype(conf=DIAGNOSTICS_CONF) def create_state_ranges( filtered_states: pd.DataFrame, factors: tuple[str, ...] | list[str], diff --git a/src/skillmodels/common/transition_functions.py b/src/skillmodels/common/transition_functions.py index d829fdae..7ca4d7d5 100644 --- a/src/skillmodels/common/transition_functions.py +++ b/src/skillmodels/common/transition_functions.py @@ -29,18 +29,15 @@ import functools from itertools import combinations -from typing import TYPE_CHECKING import jax import jax.numpy as jnp import optimagic as om from jax import Array +from skillmodels.common.fixed_constraint import FixedConstraintWithValue from skillmodels.common.selector import select_by_loc -if TYPE_CHECKING: - from skillmodels.common.constraints import FixedConstraintWithValue - def linear(states: Array, params: Array) -> Array: """Linear production function where the constant is the last parameter.""" @@ -60,8 +57,6 @@ def identity_constraints_linear( all_factors: tuple[str, ...], ) -> list[FixedConstraintWithValue]: """Identity constraints for linear transition function.""" - from skillmodels.common.constraints import FixedConstraintWithValue # noqa: PLC0415 - constraints: list[FixedConstraintWithValue] = [] for regressor in params_linear(all_factors): val = 1.0 if factor == regressor else 0.0 @@ -108,8 +103,6 @@ def identity_constraints_translog( all_factors: tuple[str, ...], ) -> list[FixedConstraintWithValue]: """Identity constraints for translog transition function.""" - from skillmodels.common.constraints import FixedConstraintWithValue # noqa: PLC0415 - constraints: list[FixedConstraintWithValue] = [] for regressor in params_translog(all_factors): val = 1.0 if factor == regressor else 0.0 @@ -296,8 +289,6 @@ def identity_constraints_linear_and_squares( all_factors: tuple[str, ...], ) -> list[FixedConstraintWithValue]: """Identity constraints for linear_and_squares transition function.""" - from skillmodels.common.constraints import FixedConstraintWithValue # noqa: PLC0415 - constraints: list[FixedConstraintWithValue] = [] for regressor in params_linear_and_squares(all_factors): val = 1.0 if factor == regressor else 0.0 diff --git a/src/skillmodels/common/types.py b/src/skillmodels/common/types.py index 2af70121..e2d70c24 100644 --- a/src/skillmodels/common/types.py +++ b/src/skillmodels/common/types.py @@ -10,6 +10,8 @@ import pandas as pd from jax import Array +from skillmodels._beartype_conf import MODEL_SPEC_CONF, beartype_init + def _make_immutable(value: Any) -> Any: # noqa: ANN401 """Recursively convert a value to its immutable equivalent.""" @@ -274,6 +276,7 @@ class EndogenousFactorsInfo: """Mapping from factor name to its `FactorInfo`.""" +@beartype_init(MODEL_SPEC_CONF) @dataclass(frozen=True) class Normalizations: """Normalizations for factor identification.""" diff --git a/src/skillmodels/common/variance_decomposition.py b/src/skillmodels/common/variance_decomposition.py index ee12eca6..383c1648 100644 --- a/src/skillmodels/common/variance_decomposition.py +++ b/src/skillmodels/common/variance_decomposition.py @@ -8,11 +8,14 @@ from collections.abc import Mapping import pandas as pd +from beartype import beartype +from skillmodels._beartype_conf import DIAGNOSTICS_CONF from skillmodels.common.model_spec import ModelSpec from skillmodels.common.process_model import process_model +@beartype(conf=DIAGNOSTICS_CONF) def decompose_measurement_variance( model_spec: ModelSpec, params: pd.DataFrame, @@ -179,6 +182,7 @@ def _compute_variance_decomposition( ] +@beartype(conf=DIAGNOSTICS_CONF) def summarize_measurement_reliability( variance_decomposition: pd.DataFrame, ) -> pd.DataFrame: diff --git a/src/skillmodels/common/visualize_factor_distributions.py b/src/skillmodels/common/visualize_factor_distributions.py index e25514af..b8494bad 100644 --- a/src/skillmodels/common/visualize_factor_distributions.py +++ b/src/skillmodels/common/visualize_factor_distributions.py @@ -10,16 +10,19 @@ import plotly.express as px import plotly.figure_factory as ff import plotly.graph_objects as go +from beartype import beartype from numpy.typing import NDArray from plotly.subplots import make_subplots from scipy.stats import gaussian_kde +from skillmodels._beartype_conf import DIAGNOSTICS_CONF from skillmodels.common.model_spec import ModelSpec from skillmodels.common.process_model import process_model from skillmodels.common.types import ProcessedModel from skillmodels.common.utils_plotting import get_layout_kwargs, get_make_subplot_kwargs +@beartype(conf=DIAGNOSTICS_CONF) def combine_distribution_plots( kde_plots: dict[str, go.Figure], contour_plots: dict[tuple[str, str], go.Figure], @@ -158,6 +161,7 @@ def combine_distribution_plots( return fig +@beartype(conf=DIAGNOSTICS_CONF) def univariate_densities( data: pd.DataFrame, model_spec: ModelSpec, @@ -264,6 +268,7 @@ def univariate_densities( return plots_dict +@beartype(conf=DIAGNOSTICS_CONF) def bivariate_density_contours( data: pd.DataFrame, model_spec: ModelSpec, @@ -384,6 +389,7 @@ def bivariate_density_contours( return plots_dict +@beartype(conf=DIAGNOSTICS_CONF) def bivariate_density_surfaces( data: pd.DataFrame, model_spec: ModelSpec, @@ -575,7 +581,7 @@ def _process_distplot_kwargs( show_rug: bool, curve_type: str, bin_size: float, - scenarios: NDArray[Any], + scenarios: NDArray[Any] | pd.api.extensions.ExtensionArray, colorscale: str, distplot_kwargs: dict[str, Any] | None, ) -> dict[str, Any]: @@ -598,9 +604,7 @@ def _calculate_kde_for_3d( data: pd.DataFrame, factors: tuple[str, str], n_points: int, -) -> tuple[ - NDArray[np.floating[Any]], NDArray[np.floating[Any]], NDArray[np.floating[Any]] -]: +) -> tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.float64]]: """Create grid mesh and calculate Gaussian kernel over the grid.""" x = data[factors[0]] y = data[factors[1]] diff --git a/src/skillmodels/common/visualize_transition_equations.py b/src/skillmodels/common/visualize_transition_equations.py index 8b0db522..b53ec819 100644 --- a/src/skillmodels/common/visualize_transition_equations.py +++ b/src/skillmodels/common/visualize_transition_equations.py @@ -8,11 +8,13 @@ import jax.numpy as jnp import numpy as np import pandas as pd +from beartype import beartype from jax import Array from plotly import express as px from plotly import graph_objects as go from plotly.subplots import make_subplots +from skillmodels._beartype_conf import DIAGNOSTICS_CONF from skillmodels.common.model_spec import ModelSpec from skillmodels.common.params_index import get_params_index from skillmodels.common.parse_params import create_parsing_info, parse_params @@ -23,6 +25,7 @@ from skillmodels.common.utils_plotting import get_layout_kwargs, get_make_subplot_kwargs +@beartype(conf=DIAGNOSTICS_CONF) def combine_transition_plots( plots_dict: dict[tuple[str, str], go.Figure], column_order: list[str] | tuple[str, ...] | str | None = None, @@ -139,6 +142,7 @@ def combine_transition_plots( return fig +@beartype(conf=DIAGNOSTICS_CONF) def get_transition_plots( model_spec: ModelSpec, params: pd.DataFrame, diff --git a/src/skillmodels/exceptions.py b/src/skillmodels/exceptions.py new file mode 100644 index 00000000..c2ddfaf6 --- /dev/null +++ b/src/skillmodels/exceptions.py @@ -0,0 +1,66 @@ +"""Project-specific exception types raised by skillmodels' user-facing API. + +The beartype decorators applied at the public entry points +(`skillmodels._beartype_conf`) route parameter-type violations through +one of the classes defined here, so callers can write narrowly-scoped +`except` clauses against a stable skillmodels-specific hierarchy +instead of catching the framework-supplied `BeartypeCallHintParamViolation`. + +All classes inherit from `TypeError` so existing `except TypeError` +handlers continue to fire; the subclasses are additive. +""" + + +class SkillmodelsInputError(TypeError): + """Base class for all skillmodels parameter-validation errors.""" + + +class ModelSpecInitializationError(SkillmodelsInputError): + """Bad argument to a model-spec dataclass. + + Raised on construction of `ModelSpec`, `FactorSpec`, + `AnchoringSpec`, or `Normalizations`. + """ + + +class OptionsInitializationError(SkillmodelsInputError): + """Bad argument to an estimation-options dataclass. + + Raised on construction of `CHSEstimationOptions`, + `AFEstimationOptions`, or `AMNEstimationOptions`. + """ + + +class EstimationCallError(SkillmodelsInputError): + """Bad argument to an estimation entry point. + + Raised by `get_maximization_inputs`, `get_filtered_states`, + `estimate_af`, `estimate_amn`, `get_af_posterior_states`, or + `get_amn_posterior_states` when arguments don't match the + declared types. + """ + + +class InferenceCallError(SkillmodelsInputError): + """Bad argument to a standard-error / bootstrap helper. + + Raised by `compute_af_standard_errors` and + `compute_amn_standard_errors`. + """ + + +class SimulationCallError(SkillmodelsInputError): + """Bad argument to a simulation helper. + + Raised by `simulate_dataset` and `simulate_policy_effect`. + """ + + +class DiagnosticsCallError(SkillmodelsInputError): + """Bad argument to a diagnostics / visualisation helper. + + Raised by `decompose_measurement_variance`, + `summarize_measurement_reliability`, `plot_residual_boxplots`, + `plot_likelihood_contributions`, `create_state_ranges`, and the + factor-distribution / transition-equation plotting helpers. + """ diff --git a/tests/conftest.py b/tests/conftest.py index d3a2043d..3268d472 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,40 @@ -"""Shared test fixtures and helpers.""" +"""Shared test fixtures and helpers. -from dataclasses import replace -from pathlib import Path +Tests opt in to whole-package beartype via `beartype.claw.beartype_package` +here so that annotation drift on *internal* helpers surfaces as a +`BeartypeCallHintParamViolation` during the test run. The perimeter +decorators (`skillmodels._beartype_conf`) keep raising project-specific +exception classes for *user-facing* parameter violations; the claw-installed +checks below are for everything in between. -import pandas as pd -import pytest +`skillmodels.chs.qr` is skipped because it relies on JAX's `@custom_jvp` +decorator, which beartype.claw wraps in a way that strips the +`.defjvp` attribute that the second-stage `@qr_gpu.defjvp` decoration +needs. No annotations in that module are user-facing. +""" -from skillmodels.common.config import TEST_DATA_DIR -from skillmodels.test_data.model2 import MODEL2 +from beartype import BeartypeConf +from beartype.claw import beartype_package + +# Mirror the perimeter conf's PEP-484 numeric tower so `int` satisfies +# `float`-typed parameters. Without this every `value=1` call site +# (e.g. `FixedConstraintWithValue(value=1)`) trips the claw checker. +beartype_package( + "skillmodels", + conf=BeartypeConf( + is_pep484_tower=True, + claw_skip_package_names=("skillmodels.chs.qr",), + ), +) + +from dataclasses import replace # noqa: E402 +from pathlib import Path # noqa: E402 + +import pandas as pd # noqa: E402 +import pytest # noqa: E402 + +from skillmodels.common.config import TEST_DATA_DIR # noqa: E402 +from skillmodels.test_data.model2 import MODEL2 # noqa: E402 REGRESSION_VAULT = Path(__file__).parent / "regression_vault" diff --git a/tests/test_af_estimate.py b/tests/test_af_estimate.py index b9282b17..3c342b2d 100644 --- a/tests/test_af_estimate.py +++ b/tests/test_af_estimate.py @@ -1022,7 +1022,7 @@ def fn(full_states: jax.Array, params: jax.Array) -> jax.Array: ) inv = jnp.array([inv_val]) full = jnp.concatenate([theta_0, inv, obs_y]) - theta_next_det = transition_func(full, link.transition_params) + theta_next_det = transition_func(full, link.transition_params) # ty: ignore[invalid-argument-type] theta_0 = theta_next_det + jnp.array([link.shock_sds[0] * z_shock[0]]) expected = theta_0 # θ at the last link's target period diff --git a/tests/test_af_jaxopt_backend.py b/tests/test_af_jaxopt_backend.py index 4772f19e..eb9c5b91 100644 --- a/tests/test_af_jaxopt_backend.py +++ b/tests/test_af_jaxopt_backend.py @@ -73,8 +73,10 @@ def test_optimizer_backend_defaults_to_auto() -> None: def test_optimizer_backend_rejects_unknown_value() -> None: - """Typos in the backend name fail fast.""" - with pytest.raises(ValueError, match="optimizer_backend"): + """Typos in the backend name fail fast via the beartype perimeter.""" + from skillmodels.exceptions import OptionsInitializationError # noqa: PLC0415 + + with pytest.raises(OptionsInitializationError, match="optimizer_backend"): AFEstimationOptions(optimizer_backend="lbfgsb") # ty: ignore[invalid-argument-type] diff --git a/tests/test_amn_plot_harmonization.py b/tests/test_amn_plot_harmonization.py index 7946f648..2192d4a1 100644 --- a/tests/test_amn_plot_harmonization.py +++ b/tests/test_amn_plot_harmonization.py @@ -76,8 +76,21 @@ def test_get_filtered_states_dispatches_to_amn(amn_fit): def test_get_filtered_states_rejects_both_af_and_amn_results(amn_fit): + """Passing an AMN result to `af_result=` triggers the beartype perimeter. + + Pre-beartype, the test passed `fit` (an `AMNEstimationResult`) to + both `af_result=` and `amn_result=` and the function body's + `only one of` `ValueError` fired. Beartype now intercepts first + because `AMNEstimationResult` is not assignable to + `AFEstimationResult | None`. The body-level guard remains in + place for the still-valid case of two real results of the right + type; that combination requires fitting both estimators, which + this fixture deliberately skips. + """ + from skillmodels.exceptions import EstimationCallError # noqa: PLC0415 + fit, data = amn_fit - with pytest.raises(ValueError, match="only one of"): + with pytest.raises(EstimationCallError, match="af_result"): get_filtered_states( model_spec=fit.model_spec, data=data, diff --git a/tests/test_check_model.py b/tests/test_check_model.py index 8fbdb411..b0f6dc42 100644 --- a/tests/test_check_model.py +++ b/tests/test_check_model.py @@ -2,10 +2,11 @@ from types import SimpleNamespace +import pytest + from skillmodels.common.check_model import ( _check_anchoring, _check_loadings_are_not_normalized_to_zero, - _check_measurements, _check_normalized_variables_are_present, check_stagemap, ) @@ -30,7 +31,7 @@ def test_invalid_anchoring_non_bool() -> None: free_constant=False, free_loadings=False, ) - result = _check_anchoring(anchoring) # ty: ignore[invalid-argument-type] + result = _check_anchoring(anchoring) assert any("bool" in msg for msg in result) @@ -42,7 +43,7 @@ def test_invalid_anchoring_non_mapping_outcomes() -> None: free_constant=False, free_loadings=False, ) - result = _check_anchoring(anchoring) # ty: ignore[invalid-argument-type] + result = _check_anchoring(anchoring) assert any("Mapping" in msg for msg in result) @@ -54,7 +55,7 @@ def test_invalid_anchoring_outcome_type() -> None: free_constant=False, free_loadings=False, ) - result = _check_anchoring(anchoring) # ty: ignore[invalid-argument-type] + result = _check_anchoring(anchoring) assert any("variable" in msg.lower() for msg in result) @@ -66,34 +67,31 @@ def test_invalid_anchoring_free_controls_type() -> None: free_constant=False, free_loadings=False, ) - result = _check_anchoring(anchoring) # ty: ignore[invalid-argument-type] + result = _check_anchoring(anchoring) assert any("free_controls" in msg for msg in result) def test_invalid_measurements_not_tuples() -> None: - spec = ModelSpec( - factors={ - "f1": FactorSpec( - measurements=(["y1", "y2"],), # ty: ignore[invalid-argument-type] - ), - }, - ) - result = _check_measurements(model_spec=spec, factors=("f1",)) - assert any("tuples" in msg for msg in result) + """Bad measurements shape is caught at the FactorSpec beartype perimeter. + + Pre-beartype, the spec built and the model-check aggregator + surfaced a soft error message. Now the construction itself + raises `ModelSpecInitializationError`. The soft-check arm of + `_check_measurements` is dead code (kept only for non-type + issues that beartype can't see). + """ + from skillmodels.exceptions import ModelSpecInitializationError # noqa: PLC0415 + + with pytest.raises(ModelSpecInitializationError, match="measurements"): + FactorSpec(measurements=(["y1", "y2"],)) # ty: ignore[invalid-argument-type] def test_invalid_measurement_type() -> None: - spec = ModelSpec( - factors={ - "f1": FactorSpec( - measurements=((["nested_list"],),), # ty: ignore[invalid-argument-type] - ), - }, - ) - result = _check_measurements(model_spec=spec, factors=("f1",)) - assert any( - "column names" in msg.lower() or "tuples" in msg.lower() for msg in result - ) + """Bad measurement element type fails at `FactorSpec.__init__` (beartype).""" + from skillmodels.exceptions import ModelSpecInitializationError # noqa: PLC0415 + + with pytest.raises(ModelSpecInitializationError, match="measurements"): + FactorSpec(measurements=((["nested_list"],),)) # ty: ignore[invalid-argument-type] def test_normalized_variable_not_in_measurements() -> None: @@ -124,7 +122,7 @@ def test_invalid_anchoring_free_constant_type() -> None: free_constant="yes", free_loadings=False, ) - result = _check_anchoring(anchoring) # ty: ignore[invalid-argument-type] + result = _check_anchoring(anchoring) assert any("free_constant" in msg for msg in result) @@ -136,7 +134,7 @@ def test_invalid_anchoring_free_loadings_type() -> None: free_constant=False, free_loadings="yes", ) - result = _check_anchoring(anchoring) # ty: ignore[invalid-argument-type] + result = _check_anchoring(anchoring) assert any("free_loadings" in msg for msg in result) diff --git a/tests/test_correlation_heatmap.py b/tests/test_correlation_heatmap.py index 5b95e852..df515409 100644 --- a/tests/test_correlation_heatmap.py +++ b/tests/test_correlation_heatmap.py @@ -275,12 +275,12 @@ def test_process_factors() -> None: observed_factor = "g" factors = ["b", "d", "g"] all_factors = None - assert tuple("abcd") == _process_factors(model, all_factors)[0] # ty: ignore[invalid-argument-type] - assert tuple("efg") == _process_factors(model, all_factors)[1] # ty: ignore[invalid-argument-type] - assert (latent_factor,) == _process_factors(model, latent_factor)[0] # ty: ignore[invalid-argument-type] - assert (observed_factor,) == _process_factors(model, observed_factor)[1] # ty: ignore[invalid-argument-type] - assert tuple(factors[:-1]) == _process_factors(model, factors)[0] # ty: ignore[invalid-argument-type] - assert (factors[-1],) == _process_factors(model, factors)[1] # ty: ignore[invalid-argument-type] + assert tuple("abcd") == _process_factors(model, all_factors)[0] + assert tuple("efg") == _process_factors(model, all_factors)[1] + assert (latent_factor,) == _process_factors(model, latent_factor)[0] + assert (observed_factor,) == _process_factors(model, observed_factor)[1] + assert tuple(factors[:-1]) == _process_factors(model, factors)[0] + assert (factors[-1],) == _process_factors(model, factors)[1] def test_get_mask_lower_triangle_only() -> None: diff --git a/tests/test_process_debug_data.py b/tests/test_process_debug_data.py index 107c148d..eea1c593 100644 --- a/tests/test_process_debug_data.py +++ b/tests/test_process_debug_data.py @@ -26,7 +26,7 @@ def test_process_residuals_ids_with_mixtures() -> None: index=pd.MultiIndex.from_tuples([(0, "m1"), (0, "m2")]), ) - result = _process_residuals(residuals=residuals, update_info=update_info) # ty: ignore[invalid-argument-type] + result = _process_residuals(residuals=residuals, update_info=update_info) # For each update, ids should be [0, 0, 1, 1, 2, 2] not [0, 1, 2, 3, 4, 5] for _, group in result.groupby(["aug_period", "measurement"]): @@ -49,7 +49,7 @@ def test_create_post_update_states_ids_with_mixtures() -> None: ) result = _create_post_update_states( - filtered_states=filtered_states, # ty: ignore[invalid-argument-type] + filtered_states=filtered_states, factors=factors, update_info=update_info, ) diff --git a/tests/test_transition_functions.py b/tests/test_transition_functions.py index b85195d0..b6146e75 100644 --- a/tests/test_transition_functions.py +++ b/tests/test_transition_functions.py @@ -99,7 +99,9 @@ def test_where_all_but_one_gammas_are_zero() -> None: def test_constant() -> None: - assert constant("bla", "blubb") == "bla" # ty: ignore[invalid-argument-type] + state = jnp.array([1.0, 2.0, 3.0]) + params = jnp.array([]) + aaae(constant(state, params), state) def test_robust_translog() -> None: