diff --git a/nomenclature/processor/region.py b/nomenclature/processor/region.py index e853e648..c3255836 100644 --- a/nomenclature/processor/region.py +++ b/nomenclature/processor/region.py @@ -13,10 +13,12 @@ AfterValidator, BaseModel, ConfigDict, + Field, ValidationInfo, field_validator, model_validator, validate_call, + field_serializer, ) from pydantic.types import DirectoryPath, FilePath from pydantic_core import PydanticCustomError @@ -109,23 +111,11 @@ class RegionAggregationMapping(BaseModel): model: list[str] file: FilePath - native_regions: list[NativeRegion] | None = None - common_regions: list[CommonRegion] | None = None - exclude_regions: list[str] | None = None + native_regions: list[NativeRegion] = Field(default_factory=list) + common_regions: list[CommonRegion] = Field(default_factory=list) + exclude_regions: list[str] = Field(default_factory=list) - @model_validator(mode="before") - @classmethod - def check_no_additional_attributes(cls, v): - if illegal_additional_attributes := [ - input_attribute - for input_attribute in v.keys() - if input_attribute not in cls.model_fields - ]: - raise ValueError( - "Illegal attributes in 'RegionAggregationMapping': " - f"{illegal_additional_attributes} (file {v['file']})" - ) - return v + model_config = ConfigDict(extra="forbid") @field_validator("model", mode="before") @classmethod @@ -188,7 +178,7 @@ def check_native_or_common_regions( cls, v: "RegionAggregationMapping" ) -> "RegionAggregationMapping": # Check that we have at least one of the two: native and common regions - if v.native_regions is None and v.common_regions is None: + if not v.native_regions and not v.common_regions: raise ValueError( "At least one of 'native_regions' and 'common_regions' must be " f"provided in {v.file}" @@ -201,9 +191,7 @@ def check_illegal_renaming( cls, v: "RegionAggregationMapping" ) -> "RegionAggregationMapping": """Check if any renaming overlaps with common regions""" - # Skip if only either native-regions or common-regions are specified - if v.native_regions is None or v.common_regions is None: - return v + native_region_names = {nr.target_native_region for nr in v.native_regions} common_region_names = {cr.name for cr in v.common_regions} overlap = list(native_region_names & common_region_names) @@ -408,28 +396,35 @@ def check_unexpected_regions(self, df: IamDataFrame) -> None: def __eq__(self, other: "RegionAggregationMapping") -> bool: return self.model_dump(exclude={"file"}) == other.model_dump(exclude={"file"}) + @field_serializer("model", when_used="json") + def serialize_model(self, model) -> str | list[str]: + return model[0] if len(model) == 1 else model + + @field_serializer("native_regions", when_used="json") + def serialize_native_regions(self, native_regions) -> list: + return [ + ( + {native_region.name: native_region.rename} + if native_region.rename + else native_region.name + ) + for native_region in native_regions + ] + + @field_serializer("common_regions", when_used="json") + def serialize_common_regions(self, common_regions) -> list: + return [ + {common_region.name: common_region.constituent_regions} + for common_region in common_regions + ] + def to_yaml(self, file) -> None: - dict_representation = { - "model": self.model[0] if len(self.model) == 1 else self.model - } - if self.native_regions: - dict_representation["native_regions"] = [ - ( - {native_region.name: native_region.rename} - if native_region.rename - else native_region.name - ) - for native_region in self.native_regions - ] - if self.common_regions: - dict_representation["common_regions"] = [ - {common_region.name: common_region.constituent_regions} - for common_region in self.common_regions - ] - if self.exclude_regions: - dict_representation["exclude_regions"] = self.exclude_regions with open(file, "w", encoding="utf-8") as f: - yaml.dump(dict_representation, f, sort_keys=False) + yaml.dump( + self.model_dump(mode="json", exclude_defaults=True, exclude={"file"}), + f, + sort_keys=False, + ) def validate_with_definition(v: RegionAggregationMapping, info: ValidationInfo): @@ -634,70 +629,64 @@ def _apply_region_processing( # silence pyam's empty filter warnings with adjust_log_level(logger="pyam", level="ERROR"): # rename native regions - if self.mappings[model].native_regions is not None: - _df = model_df.filter( - region=self.mappings[model].model_native_region_names + _df = model_df.filter(region=self.mappings[model].model_native_region_names) + if not _df.empty: + _processed_data.append( + _df.rename(region=self.mappings[model].rename_mapping)._data ) - if not _df.empty: - _processed_data.append( - _df.rename(region=self.mappings[model].rename_mapping)._data - ) # aggregate common regions - if self.mappings[model].common_regions is not None: - for common_region in self.mappings[model].common_regions: - # if a common region is consists of a single native region, rename - if common_region.is_single_constituent_region: - _df = model_df.filter( - region=common_region.constituent_regions[0] - ).rename(region=common_region.rename_dict) - if not _df.empty: - _processed_data.append(_df._data) - continue + for common_region in self.mappings[model].common_regions: + # if a common region is consists of a single native region, rename + if common_region.is_single_constituent_region: + _df = model_df.filter( + region=common_region.constituent_regions[0] + ).rename(region=common_region.rename_dict) + if not _df.empty: + _processed_data.append(_df._data) + continue - # if there are multiple constituent regions, aggregate - regions = [common_region.name, common_region.constituent_regions] + # if there are multiple constituent regions, aggregate + regions = [common_region.name, common_region.constituent_regions] - # first, perform 'simple' aggregation (no arguments) - simple_vars = [ - var - for var in self.variable_codelist.vars_default_args( - model_df.variable - ) - ] - _df = model_df.aggregate_region( - simple_vars, - *regions, + # first, perform 'simple' aggregation (no arguments) + simple_vars = [ + var + for var in self.variable_codelist.vars_default_args( + model_df.variable ) - if _df is not None and not _df.empty: - _processed_data.append(_df._data) - - # second, special weighted aggregation - for var in self.variable_codelist.vars_kwargs(model_df.variable): - if var.region_aggregation is None: - _df = _aggregate_region( - model_df, - var.name, - *regions, - **var.pyam_agg_kwargs, - ) - if _df is not None and not _df.empty: - _processed_data.append(_df._data) - else: - for rename_var in var.region_aggregation: - for _rename, _kwargs in rename_var.items(): - _df = _aggregate_region( - model_df, - var.name, - *regions, - **_kwargs, + ] + _df = model_df.aggregate_region( + simple_vars, + *regions, + ) + if _df is not None and not _df.empty: + _processed_data.append(_df._data) + + # second, special weighted aggregation + for var in self.variable_codelist.vars_kwargs(model_df.variable): + if var.region_aggregation is None: + _df = _aggregate_region( + model_df, + var.name, + *regions, + **var.pyam_agg_kwargs, + ) + if _df is not None and not _df.empty: + _processed_data.append(_df._data) + else: + for rename_var in var.region_aggregation: + for _rename, _kwargs in rename_var.items(): + _df = _aggregate_region( + model_df, + var.name, + *regions, + **_kwargs, + ) + if _df is not None and not _df.empty: + _processed_data.append( + _df.rename(variable={var.name: _rename})._data ) - if _df is not None and not _df.empty: - _processed_data.append( - _df.rename( - variable={var.name: _rename} - )._data - ) common_region_df = model_df.filter( region=self.mappings[model].common_region_names, diff --git a/tests/test_region_aggregation.py b/tests/test_region_aggregation.py index 5640b71f..85b16555 100644 --- a/tests/test_region_aggregation.py +++ b/tests/test_region_aggregation.py @@ -43,7 +43,7 @@ def test_mapping(): "constituent_regions": ["region_c"], }, ], - "exclude_regions": None, + "exclude_regions": [], } assert obs.model_dump() == exp @@ -51,10 +51,6 @@ def test_mapping(): @pytest.mark.parametrize( "file, error_msg_pattern", [ - ( - "illegal_mapping_illegal_attribute.yaml", - "Illegal attributes in 'RegionAggregationMapping'", - ), ( "illegal_mapping_conflict_regions.yaml", "Name collision in native and common regions.*common_region_1", @@ -92,6 +88,15 @@ def test_illegal_mappings(file, error_msg_pattern): RegionAggregationMapping.from_file(TEST_FOLDER_REGION_AGGREGATION / file) +def test_illegal_additional_attribute(): + with pytest.raises( + pydantic.ValidationError, match="Extra inputs are not permitted" + ): + RegionAggregationMapping.from_file( + TEST_FOLDER_REGION_AGGREGATION / "illegal_mapping_illegal_attribute.yaml" + ) + + def test_mapping_parsing_error(): with pytest.raises(ValueError, match="string indices must be integers"): RegionAggregationMapping.from_file( @@ -119,15 +124,15 @@ def test_region_processor_working(region_processor_path, simple_definition): "native_regions": [ {"name": "World", "rename": None}, ], - "common_regions": None, - "exclude_regions": None, + "common_regions": [], + "exclude_regions": [], }, { "model": ["model_b"], "file": ( TEST_FOLDER_REGION_PROCESSING / "regionprocessor_working/mapping_2.yaml" ).relative_to(Path.cwd()), - "native_regions": None, + "native_regions": [], "common_regions": [ { "name": "World",