Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial plants model draft #296

Merged
merged 24 commits into from
Sep 1, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
628d407
First draft of plants_model using new structures
davidorme Aug 23, 2023
e47266c
Placeholder PFT definition
davidorme Aug 23, 2023
42bbc0b
Placeholder plants constants dataclass
davidorme Aug 23, 2023
b93c39f
Tests of plant model PFT classes
davidorme Aug 23, 2023
40618c6
Rejigged required __init__ vars
davidorme Aug 24, 2023
534f2e9
Local shared fixtures for plant model testing
davidorme Aug 25, 2023
3dc2ae3
Updated plants functional_types test
davidorme Aug 25, 2023
049e826
Plant community object, plant cohorts and tests
davidorme Aug 25, 2023
827c2de
Updating plant model to build the community, docs and tidying
davidorme Aug 25, 2023
f6edd88
More informative name for Plants class
davidorme Aug 25, 2023
513d327
Unnecessary test __init__ file
davidorme Aug 25, 2023
6d9a4c8
Clearer name for Plants data containing attribute
davidorme Aug 25, 2023
62f12bc
Updated plant model init vars description
davidorme Aug 29, 2023
6f2738b
Fixed dumb constants bug, added two more update vars to list
davidorme Aug 29, 2023
b7fc21f
Updated PFT config with better variable name for maxh
davidorme Aug 29, 2023
0b3c6d1
Renamed PlantFunctionalTypes to Flora, defended against duplicate PFT…
davidorme Aug 29, 2023
f8da1dd
Bug fix in plant functional type checking
davidorme Aug 29, 2023
470146f
Updating PlantCommunities tests to new flora argument
davidorme Aug 29, 2023
b74c916
Added default PlantConsts to PlantsModel __init__
davidorme Aug 29, 2023
0b7dc32
Added minimal PlantsModel tests, aligned test variable names with req…
davidorme Aug 29, 2023
4bff259
Last tweaks from @vgro on #296
davidorme Aug 30, 2023
d599554
Merge branch 'develop' into feature/plants_model
davidorme Sep 1, 2023
101939f
Updating PR to use new model registration process
davidorme Sep 1, 2023
9c2e25f
Typo in consts class name
davidorme Sep 1, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions tests/models/plants/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Fixtures for plants model testing."""

import numpy as np
import pytest
from xarray import DataArray


@pytest.fixture()
def plant_config(shared_datadir):
"""Simple configuration fixture for use in tests."""

from virtual_rainforest.core.config import Config

return Config(shared_datadir / "all_config.toml")


@pytest.fixture()
def pfts(plant_config):
"""Construct a minimal PlantFunctionalType object."""
from virtual_rainforest.models.plants.functional_types import PlantFunctionalTypes

pfts = PlantFunctionalTypes.from_config(plant_config)

return pfts


@pytest.fixture()
def plants_data():
"""Construct a minimal data object with plant cohort data."""
from virtual_rainforest.core.data import Data
from virtual_rainforest.core.grid import Grid

data = Data(grid=Grid(cell_ny=2, cell_nx=2))
data["plant_cohort_n"] = DataArray(np.array([5] * 4))
data["plant_cohort_pft"] = DataArray(np.array(["tree"] * 4))
data["plant_cohort_cell_id"] = DataArray(np.arange(4))
data["plant_cohort_dbh"] = DataArray(np.array([0.1] * 4))

return data
51 changes: 51 additions & 0 deletions tests/models/plants/data/all_config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# This file contains every tag required to build the config for the plants and soil
# modules. Each value has been chosen to be different from the default value, so that
# this file can be used to test that providing non-default values works.
[core]
modules = ["plants", "soil", "animals"]

[core.grid]
cell_nx = 10
cell_ny = 10

[core.timing]
start_date = "2020-01-01"
update_interval = "2 weeks"
run_length = "50 years"

[core.data_output_options]
save_initial_state = true
save_final_state = true
out_initial_file_name = "model_at_start.nc"
out_final_file_name = "model_at_end.nc"

[plants]
a_plant_integer = 12

[[plants.ftypes]]
pft_name = "shrub"
maxh = 1.0

[[plants.ftypes]]
pft_name = "broadleaf"
maxh = 50.0

[[animals.functional_groups]]
name = "carnivorous_bird"
taxa = "bird"
diet = "carnivore"

[[animals.functional_groups]]
name = "herbivorous_bird"
taxa = "bird"
diet = "herbivore"

[[animals.functional_groups]]
name = "carnivorous_mammal"
taxa = "mammal"
diet = "carnivore"

[[animals.functional_groups]]
name = "herbivorous_mammal"
taxa = "mammal"
diet = "herbivore"
105 changes: 105 additions & 0 deletions tests/models/plants/test_community.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""Tests the plant community model code."""

from contextlib import nullcontext as does_not_raise
from logging import CRITICAL, INFO

import numpy as np
import pytest
from xarray import DataArray

from tests.conftest import log_check


@pytest.mark.parametrize(
argnames="vars,raises,exp_log",
argvalues=[
pytest.param(
(("plant_cohort_n", DataArray(np.array([5] * 4))),),
pytest.raises(ValueError),
((CRITICAL, "Missing plant cohort variables"),),
id="missing var",
),
pytest.param(
(
("plant_cohort_n", DataArray(np.array([5] * 9), dims="toolong")),
("plant_cohort_pft", DataArray(np.array(["shrub"] * 4))),
("plant_cohort_cell_id", DataArray(np.arange(4))),
("plant_cohort_dbh", DataArray(np.array([0.1] * 4))),
),
pytest.raises(ValueError),
((CRITICAL, "Unequal plant cohort variable dimensions"),),
id="unequal sizes",
),
pytest.param(
(
("plant_cohort_n", DataArray(np.array([5] * 4).reshape(2, 2))),
("plant_cohort_pft", DataArray(np.array(["shrub"] * 4).reshape(2, 2))),
("plant_cohort_cell_id", DataArray(np.arange(4).reshape(2, 2))),
("plant_cohort_dbh", DataArray(np.array([0.1] * 4).reshape(2, 2))),
),
pytest.raises(ValueError),
((CRITICAL, "Plant cohort variable data is not one dimensional"),),
id="not 1D",
),
pytest.param(
(
("plant_cohort_n", DataArray(np.array([5] * 4))),
("plant_cohort_pft", DataArray(np.array(["shrub"] * 4))),
("plant_cohort_cell_id", DataArray(DataArray(np.arange(2, 6)))),
("plant_cohort_dbh", DataArray(np.array([0.1] * 4))),
),
pytest.raises(ValueError),
((CRITICAL, "Plant cohort cell ids not in grid cell ids"),),
id="bad cell ids",
),
pytest.param(
(
("plant_cohort_n", DataArray(np.array([5] * 4))),
("plant_cohort_pft", DataArray(np.array(["tree"] * 4))),
("plant_cohort_cell_id", DataArray(DataArray(np.arange(4)))),
("plant_cohort_dbh", DataArray(np.array([0.1] * 4))),
),
pytest.raises(ValueError),
((CRITICAL, "Plant cohort PFTs ids not in configured PFTs"),),
id="bad pfts",
),
pytest.param(
(
("plant_cohort_n", DataArray(np.array([5] * 4))),
("plant_cohort_pft", DataArray(np.array(["shrub"] * 4))),
("plant_cohort_cell_id", DataArray(DataArray(np.arange(4)))),
("plant_cohort_dbh", DataArray(np.array([0.1] * 4))),
),
does_not_raise(),
((INFO, "Plant cohort data loaded"),),
id="all good",
),
],
)
def test_PlantCommunities__init__(caplog, vars, pfts, raises, exp_log):
"""Test the data handling of the plants __init__."""

from virtual_rainforest.core.data import Data
from virtual_rainforest.core.grid import Grid
from virtual_rainforest.models.plants.community import PlantCommunities

data = Data(grid=Grid(cell_ny=2, cell_nx=2))

for var, value in vars:
data[var] = value

# Clear any data loading log entries
caplog.clear()

with raises:
plants_obj = PlantCommunities(data, pfts=pfts)

if isinstance(raises, does_not_raise):
# Check the expected contents of plants_obj
assert len(plants_obj.communities) == 4
cids = set([0, 1, 2, 3])
assert set(plants_obj.communities.keys()) == cids
for cid in cids:
assert len(plants_obj[cid]) == 1

log_check(caplog, expected_log=exp_log)
43 changes: 43 additions & 0 deletions tests/models/plants/test_functional_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Test module for plants.functional_types.py.

This module tests the functionality of the plant functional types submodule.
"""


def test_plant_functional_type():
"""Simple test of PlantFunctionalType dataclass."""
from virtual_rainforest.models.plants.functional_types import PlantFunctionalType

pft = PlantFunctionalType(pft_name="tree", maxh=12.0)

assert pft.pft_name == "tree"
assert pft.maxh == 12.0


def test_plant_functional_types__init__():
"""Simple test of PlantFunctionalTypes __init__."""
from virtual_rainforest.models.plants.functional_types import (
PlantFunctionalType,
PlantFunctionalTypes,
)

pfts = PlantFunctionalTypes(
{
"shrub": PlantFunctionalType(pft_name="shrub", maxh=1.0),
"broadleaf": PlantFunctionalType(pft_name="broadleaf", maxh=50.0),
}
)

assert len(pfts) == 2
assert tuple(pfts.keys()) == ("shrub", "broadleaf")


def test_plant_functional_types_from_config(plant_config):
"""Simple test of PlantFunctionalTypes from_config factory method."""

from virtual_rainforest.models.plants.functional_types import PlantFunctionalTypes

pfts = PlantFunctionalTypes.from_config(plant_config)

assert len(pfts) == 2
assert tuple(pfts.keys()) == ("shrub", "broadleaf")
17 changes: 16 additions & 1 deletion virtual_rainforest/models/plants/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,23 @@
"""The :mod:`~virtual_rainforest.models.plants` module provides a plants model for use
in the Virtual Rainforest. The submodules provide:

* The :mod:`~virtual_rainforest.models.plants.plants_model` submodule provides the
PlantsModel class as the main API for interacting with the plants model.
* The :mod:`~virtual_rainforest.models.plants.constants` submodule provides dataclasses
containing constants used in the model.
* The :mod:`~virtual_rainforest.models.plants.community` submodule provides the
PlantCohort dataclass that records the details of an individual cohort and the
PlantCommunities class that records list of plant cohorts by grid cell.
""" # noqa: D205, D415

from importlib import resources

from virtual_rainforest.core.config import register_schema
from virtual_rainforest.models.plants.plants_model import PlantsModel

with resources.path(
"virtual_rainforest.models.plants", "plants_schema.json"
) as schema_file_path:
register_schema(module_name="plants", schema_file_path=schema_file_path)
register_schema(
module_name=PlantsModel.model_name, schema_file_path=schema_file_path
)
121 changes: 121 additions & 0 deletions virtual_rainforest/models/plants/community.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""The :mod:`~virtual_rainforest.models.plants.community` submodule provides a simple
dataclass to hold plant cohort information and then the ``PlantCommunities`` class that
is used to hold list of plant cohorts by grid cell and generate those lists from
variables provided in the data object.

NOTE - much of this will be outsourced to pyrealm.

""" # noqa: D205, D415


from dataclasses import dataclass

from virtual_rainforest.core.data import Data
from virtual_rainforest.core.logger import LOGGER
from virtual_rainforest.models.plants.functional_types import (
PlantFunctionalType,
PlantFunctionalTypes,
)


@dataclass
class PlantCohort:
"""A dataclass describing a plant cohort.

The cohort is defined by the plant functional type, the number of individuals in the
cohort and the diameter at breast height for the cohort.
"""

pft: PlantFunctionalType
dbh: float
n: int


class PlantCommunities:
"""A dictionary of plant cohorts keyed by grid cell id.

An instance of this class is initialised from a
:class:`~virtual_rainforest.core.data.Data` object that must contain the variables
``plant_cohort_cell_id``, ``plant_cohort_pft``, ``plant_cohort_n`` and
``plant_cohort_dbh``. These are required to be equal length, one-dimensional arrays
that provide the data to initialise each plant cohort. The data are validated and
then compiled into lists of cohorts keyed by grid cell id. The class provides a
__getitem__ method to allow the list of cohorts for a grid cell to be accessed using
``plants_inst[cell_id]``.

Args:
data: A data instance containing the required plant cohort data.
pfts: The plant functional types to be used.
"""

def __init__(self, data: Data, pfts: PlantFunctionalTypes):
self.communities: dict = dict()
"""A dictionary holding the lists of plant cohorts keyed by cell id."""

# Validate the data being used to generate the Plants object
cohort_data_vars = [
"plant_cohort_n",
"plant_cohort_pft",
"plant_cohort_cell_id",
"plant_cohort_dbh",
]

# All vars present
missing_var = [v for v in cohort_data_vars if v not in data]

if missing_var:
msg = f"Missing plant cohort variables: {', '.join(missing_var)}"
LOGGER.critical(msg)
raise ValueError(msg)

# All vars identically sized and 1D
data_shapes = [data[var].shape for var in cohort_data_vars]

if len(set(data_shapes)) != 1:
msg = (
f"Unequal plant cohort variable dimensions:"
f" {','.join([str(v) for v in set(data_shapes)])}"
)
LOGGER.critical(msg)
raise ValueError(msg)

if len(data_shapes[0]) != 1:
msg = "Plant cohort variable data is not one dimensional"
LOGGER.critical(msg)
raise ValueError(msg)

# Check the grid cell id and pft values are all known
bad_cid = set(data["plant_cohort_cell_id"].data).difference(data.grid.cell_id)
if bad_cid:
msg = (
f"Plant cohort cell ids not in grid cell "
f"ids: {','.join([str(c) for c in bad_cid])}"
)
LOGGER.critical(msg)
raise ValueError(msg)

bad_pft = set(data["plant_cohort_pft"].data).difference(pfts.keys())
if bad_pft:
msg = f"Plant cohort PFTs ids not in configured PFTs: {','.join(bad_pft)}"
LOGGER.critical(msg)
raise ValueError(msg)

# Now compile the plant cohorts adding each cohort to a list keyed by cell id
for cid in data.grid.cell_id:
self.communities[cid] = []

for cid, chrt_pft, chrt_dbh, chrt_n in zip(
data["plant_cohort_cell_id"].data,
data["plant_cohort_pft"].data,
data["plant_cohort_dbh"].data,
data["plant_cohort_n"].data,
):
self.communities[cid].append(
PlantCohort(pft=pfts[chrt_pft], dbh=chrt_dbh, n=chrt_n)
)

LOGGER.info("Plant cohort data loaded")

def __getitem__(self, key: int) -> list[PlantCohort]:
"""Extracts the cohort list for a given cell id."""
return self.communities[key]
Loading