clm5/python/ctsm/crop_calendars/cropcal_utils.py
2024-05-09 15:14:01 +08:00

398 lines
13 KiB
Python

"""
utility functions
copied from klindsay, https://github.com/klindsay28/CESM2_coup_carb_cycle_JAMES/blob/master/utils.py
"""
import numpy as np
import xarray as xr
def define_pftlist():
    """
    Return the list of plant functional type (PFT) names used in CLM.

    A name's position in the returned list is its CLM ivt number. A new list is built on each
    call, so callers may modify it freely.
    """
    # Natural vegetation and the two unmanaged crop types (ivt 0-16), in index order
    natural_and_unmanaged = [
        "not_vegetated",
        "needleleaf_evergreen_temperate_tree",
        "needleleaf_evergreen_boreal_tree",
        "needleleaf_deciduous_boreal_tree",
        "broadleaf_evergreen_tropical_tree",
        "broadleaf_evergreen_temperate_tree",
        "broadleaf_deciduous_tropical_tree",
        "broadleaf_deciduous_temperate_tree",
        "broadleaf_deciduous_boreal_tree",
        "broadleaf_evergreen_shrub",
        "broadleaf_deciduous_temperate_shrub",
        "broadleaf_deciduous_boreal_shrub",
        "c3_arctic_grass",
        "c3_non-arctic_grass",
        "c4_grass",
        "unmanaged_c3_crop",
        "unmanaged_c3_irrigated",
    ]

    # Managed crops: each rainfed type is immediately followed by its "irrigated_" twin
    managed_crops = (
        "temperate_corn",
        "spring_wheat",
        "winter_wheat",
        "soybean",
        "barley",
        "winter_barley",
        "rye",
        "winter_rye",
        "cassava",
        "citrus",
        "cocoa",
        "coffee",
        "cotton",
        "datepalm",
        "foddergrass",
        "grapes",
        "groundnuts",
        "millet",
        "oilpalm",
        "potatoes",
        "pulses",
        "rapeseed",
        "rice",
        "sorghum",
        "sugarbeet",
        "sugarcane",
        "sunflower",
        "miscanthus",
        "switchgrass",
        "tropical_corn",
        "tropical_soybean",
    )
    rainfed_then_irrigated = [
        name for crop in managed_crops for name in (crop, "irrigated_" + crop)
    ]

    return natural_and_unmanaged + rainfed_then_irrigated
def ivt_str2int(ivt_str, pftlist=None):
    """
    Get CLM ivt number corresponding to a given name

    ivt_str: A PFT name (str), or a list/tuple/np.ndarray of names (may be nested). A 0-d
        np.ndarray is treated as a single name.
    pftlist: Optional sequence of PFT names to look names up in. Defaults to define_pftlist().

    Returns an int for a single name, a list of ints for a list or tuple, and an np.ndarray for
    an np.ndarray.

    Raises ValueError if a name is not in pftlist, and RuntimeError for unsupported input types.
    """
    # Build the lookup list once, rather than once per element during recursion
    names = define_pftlist() if pftlist is None else list(pftlist)

    def _convert(item):
        # Unwrap 0-d arrays so they are handled as scalars
        if isinstance(item, np.ndarray) and item.ndim == 0:
            item = item.item()
        if isinstance(item, str):
            return names.index(item)
        if isinstance(item, (list, tuple, np.ndarray)):
            converted = [_convert(x) for x in item]
            return np.array(converted) if isinstance(item, np.ndarray) else converted
        raise RuntimeError(
            f"Update ivt_str2int() to handle input of type {type(item)} (if possible)"
        )

    return _convert(ivt_str)
def ivt_int2str(ivt_int, pftlist=None):
    """
    Get CLM ivt name corresponding to a given number

    ivt_int: An ivt number (Python or numpy integer, or an integer-valued float), or a
        list/tuple/np.ndarray of them (may be nested). A 0-d np.ndarray is treated as a scalar.
    pftlist: Optional sequence of PFT names to index into. Defaults to define_pftlist().

    Returns a str for a single number, a list of str for a list or tuple, and an np.ndarray of str
    for an np.ndarray.

    Raises RuntimeError for non-integer floats (including NaN/inf) and unsupported input types,
    and IndexError if a number is out of range for pftlist.
    """
    # Build the lookup list once, rather than once per element during recursion
    names = define_pftlist() if pftlist is None else list(pftlist)

    def _convert(item):
        # Unwrap 0-d arrays so they are handled as scalars
        if isinstance(item, np.ndarray) and item.ndim == 0:
            item = item.item()
        # Check containers before scalars: int() on a list or multi-element array raises
        # TypeError, which previously made this branch unreachable.
        if isinstance(item, (list, tuple, np.ndarray)):
            converted = [_convert(x) for x in item]
            return np.array(converted) if isinstance(item, np.ndarray) else converted
        if isinstance(item, (int, np.integer)):
            return names[int(item)]
        if isinstance(item, (float, np.floating)):
            if not float(item).is_integer():
                raise RuntimeError("List indices must be integers")
            return names[int(item)]
        raise RuntimeError(
            f"Update ivt_int2str() to handle input of type {type(item)} (if possible)"
        )

    return _convert(ivt_int)
def is_this_vegtype(this_vegtype, this_filter, this_method):
    """
    Does this vegetation type's name match (for a given comparison method) any member of a filtering
    list?
    Methods:
    ok_contains: True if any member of this_filter is found in this_vegtype.
    notok_contains: True if no member of this_filter is found in this_vegtype.
    ok_exact: True if this_vegtype matches any member of this_filter
    exactly.
    notok_exact: True if this_vegtype does not match any member of
    this_filter exactly.

    this_vegtype: A single string or integer (Python or numpy). Integer-valued floats are
        converted to int. An xarray DataArray, np.ndarray, or list holding exactly one string or
        integer (or a 0-d array) is unwrapped to that value. The "contains" methods use the `in`
        operator, so they need a string this_vegtype.
    this_filter: An iterable of values to compare against. A bare string is iterable too, so it
        would be compared character by character.
    this_method: One of the four method names above.

    Raises TypeError for an unsupported this_vegtype (including a list of several vegetation
    types; use is_each_vegtype() for that) or a non-iterable this_filter, and ValueError for an
    unknown this_method.
    """

    def data_type_ok(value):
        # Any numpy integer type is accepted, not just np.int64
        return isinstance(value, (int, np.integer, str))

    # Treat integer-valued floats as integers
    if isinstance(this_vegtype, float) and this_vegtype.is_integer():
        this_vegtype = int(this_vegtype)

    # Unwrap containers that hold a single vegetation type
    if not data_type_ok(this_vegtype):
        if isinstance(this_vegtype, xr.DataArray):
            this_vegtype = this_vegtype.values
        if isinstance(this_vegtype, np.ndarray) and this_vegtype.ndim == 0:
            # e.g. .values of a scalar DataArray; len() would fail on it
            this_vegtype = this_vegtype.item()
        elif isinstance(this_vegtype, (list, np.ndarray)) and len(this_vegtype) > 0:
            if len(this_vegtype) == 1 and data_type_ok(this_vegtype[0]):
                this_vegtype = this_vegtype[0]
            elif data_type_ok(this_vegtype[0]):
                raise TypeError(
                    "is_this_vegtype(): this_vegtype must be a single string or integer, not a list"
                    " of them. Did you mean to call is_each_vegtype() instead?"
                )
        if not data_type_ok(this_vegtype):
            raise TypeError(
                "is_this_vegtype(): First argument (this_vegtype) must be a string or integer, not"
                f" {type(this_vegtype)}"
            )

    # Make sure data type of this_filter is acceptable
    if not np.iterable(this_filter):
        raise TypeError(
            "is_this_vegtype(): Second argument (this_filter) must be iterable (e.g., a list), not"
            f" {type(this_filter)}"
        )

    # Perform the comparison
    if this_method == "ok_contains":
        return any(n in this_vegtype for n in this_filter)
    if this_method == "notok_contains":
        return not any(n in this_vegtype for n in this_filter)
    if this_method == "ok_exact":
        return any(n == this_vegtype for n in this_filter)
    if this_method == "notok_exact":
        return not any(n == this_vegtype for n in this_filter)
    raise ValueError(f"Unknown comparison method: '{this_method}'")
def is_each_vegtype(this_vegtypelist, this_filter, this_method):
    """
    Apply is_this_vegtype() to every member of a collection of vegetation types.

    this_vegtypelist: The vegetation types to test (e.g. a list or np.ndarray). An xarray
        DataArray is tested via its underlying .values.
    this_filter: The list of strings against which each member of this_vegtypelist is compared.
    this_method: How to do the comparison. See is_this_vegtype() for the options.

    Returns a list with one is_this_vegtype() result per member of this_vegtypelist.
    """
    candidates = (
        this_vegtypelist.values
        if isinstance(this_vegtypelist, xr.DataArray)
        else this_vegtypelist
    )
    return [
        is_this_vegtype(candidate, this_filter, this_method) for candidate in candidates
    ]
def define_mgdcrop_list():
    """
    Return the names (strings) of the managed crops in CLM, in define_pftlist() order.

    A PFT counts as a managed crop when its name contains none of the substrings in notcrop_list
    below. Because this is substring matching, crop names that contain "grass" ("foddergrass",
    "switchgrass" and their irrigated versions) are excluded as well.
    NOTE(review): confirm that excluding those crops is intended.
    """
    notcrop_list = ["tree", "grass", "shrub", "unmanaged", "not_vegetated"]
    all_pfts = define_pftlist()
    crop_flags = is_each_vegtype(all_pfts, notcrop_list, "notok_contains")
    return [pft for pft, is_crop in zip(all_pfts, crop_flags) if is_crop]
def vegtype_str2int(vegtype_str, vegtype_mainlist=None):
    """
    Convert list of vegtype strings to integer index equivalents.

    vegtype_str: A vegetation type name, or a list/np.ndarray of names (any shape).
    vegtype_mainlist: The names to index into. It may be any of:
        - None, meaning use define_pftlist();
        - a list, tuple, or np.ndarray of strings;
        - an xarray DataArray of strings;
        - an xarray Dataset with a vegtype_str variable.

    Returns the index in vegtype_mainlist (first occurrence) of each name.
        - If vegtype_str was an np.ndarray, the result is an integer np.ndarray of the same shape.
          A 0-d input gives shape (1,).
        - Otherwise the result is a (possibly nested) list of Python ints. A single string gives
          a one-element list.

    Raises TypeError if vegtype_mainlist is not a collection of strings, and ValueError if a name
    is not in vegtype_mainlist.
    """
    return_list = not isinstance(vegtype_str, np.ndarray)
    # At least 1-D, so the boolean-mask assignment below works for scalar input too
    vegtype_str = np.atleast_1d(np.asarray(vegtype_str))

    if vegtype_mainlist is None:
        vegtype_mainlist = define_pftlist()
    elif not isinstance(vegtype_mainlist, (list, tuple, np.ndarray)):
        # Pull the array of names out of xarray containers
        if isinstance(vegtype_mainlist, xr.Dataset):
            vegtype_mainlist = vegtype_mainlist.vegtype_str.values
        elif isinstance(vegtype_mainlist, xr.DataArray):
            vegtype_mainlist = vegtype_mainlist.values

    # Normalize to a list of str, which provides .index(). Previously an np.ndarray of strings
    # (including every xarray input) was rejected with TypeError.
    if isinstance(vegtype_mainlist, np.ndarray):
        vegtype_mainlist = vegtype_mainlist.tolist()
    elif isinstance(vegtype_mainlist, tuple):
        vegtype_mainlist = list(vegtype_mainlist)

    if not isinstance(vegtype_mainlist, list):
        raise TypeError(
            f"Not sure how to handle vegtype_mainlist as type {type(vegtype_mainlist)}"
        )
    for name in vegtype_mainlist:
        if not isinstance(name, str):
            raise TypeError(
                f"Not sure how to handle vegtype_mainlist as list of {type(name)}"
            )

    indices = np.full(vegtype_str.shape, -1, dtype=int)
    for name in np.unique(vegtype_str):
        indices[vegtype_str == name] = vegtype_mainlist.index(name)

    return indices.tolist() if return_list else indices
def get_patch_ivts(this_ds, this_pftlist):
    """
    Get PFT of each patch, in both integer and string forms.

    this_ds: Object with a patches1d_itype_veg variable holding each patch's vegetation type
        number (presumably an xarray Dataset).
    this_pftlist: Sequence of PFT names, indexed by vegetation type number.

    Returns a dict with these keys:
        "int": patches1d_itype_veg converted to integer dtype. This is a new object; this_ds is
            not modified.
        "str": list of the corresponding names from this_pftlist. If patches1d_itype_veg has
            more than one dimension, this is a list of per-row arrays.
        "all_str": this_pftlist itself.
    """
    # Convert with astype() instead of assigning to .values. For an xarray Dataset, the
    # DataArray shares its Variable with this_ds, so assigning .values would rewrite the
    # caller's data.
    vegtype_int = this_ds.patches1d_itype_veg.astype(int)

    # Convert to strings
    vegtype_str = list(np.array(this_pftlist)[vegtype_int.values])

    return {"int": vegtype_int, "str": vegtype_str, "all_str": this_pftlist}
def get_vegtype_str_da(vegtype_str):
    """
    Wrap a list of vegetation type names in a 1-D DataArray named "vegtype_str".

    The single dimension is "ivt", with integer coordinates 0 .. len(vegtype_str) - 1.
    """
    ivt_coord = np.arange(len(vegtype_str))
    return xr.DataArray(
        vegtype_str,
        dims=["ivt"],
        coords={"ivt": ivt_coord},
        name="vegtype_str",
    )
def safer_timeslice(ds_in, time_slice, time_var="time"):
    """
    ctsm_pylib can't handle time slicing like Dataset.sel(time=slice("1998-01-01", "2005-12-31"))
    for some reason. This function tries to fall back to slicing by integers. It should work with
    both Datasets and DataArrays.

    The fallback is used only when all of the following hold:
        - time_slice is a slice;
        - its start is a "YYYY-01-01" string;
        - its stop is a "YYYY-12-31" or "YYYY-01-01" string;
        - no year appears more than once on the time axis.
    It then keeps every timestep whose year lies in [start year, stop year].

    ds_in: Dataset or DataArray with a time coordinate named time_var.
    time_slice: Passed straight to .sel(); normally a slice.
    time_var: Name of the time dimension/coordinate.

    Returns the selected object. If .sel() fails and the fallback cannot be used, a message is
    printed and the exception from .sel() is re-raised.
    """
    try:
        return ds_in.sel({time_var: time_slice})
    except Exception:  # pylint: disable=broad-except
        # If the issue might have been slicing using strings, try to fall back to integer slicing.
        # getattr() so a non-slice time_slice re-raises the original error instead of an
        # AttributeError from here.
        start = getattr(time_slice, "start", None)
        stop = getattr(time_slice, "stop", None)
        can_try_integer_slicing = (
            isinstance(start, str)
            and isinstance(stop, str)
            and len(start.split("-")) == 3
            and start.split("-")[1:] == ["01", "01"]
            and len(stop.split("-")) == 3
            and stop.split("-")[1:] in (["12", "31"], ["01", "01"])
        )
        if not can_try_integer_slicing:
            print(f"Could not fall back to integer slicing for time_slice {time_slice}")
            raise

        # Read the requested time coordinate (previously ds_in.time was used regardless of
        # time_var). NOTE(review): assumes time values expose .year (e.g. cftime objects);
        # numpy datetime64 values do not.
        fileyears = np.array([x.year for x in ds_in[time_var].values])
        if len(np.unique(fileyears)) != len(fileyears):
            print("Could not fall back to integer slicing of years: Time axis not annual")
            raise

        y_start = int(start.split("-")[0])
        y_stop = int(stop.split("-")[0])
        where_in_timeslice = np.where((fileyears >= y_start) & (fileyears <= y_stop))[0]
        return ds_in.isel({time_var: where_in_timeslice})
def lon_idl2pm(lons_in, fail_silently=False):
    """
    Convert a longitude axis that's -180 to 180 around the international date line to one that's 0
    to 360 around the prime meridian.
    - If you pass in a Dataset or DataArray, the "lon" coordinates will be changed, and the object
      is then reordered with make_lon_increasing(). Otherwise, it assumes you're passing in
      numeric data: a scalar, an np.ndarray, or a list/tuple (converted to an np.ndarray).

    fail_silently: If True, return lons_in unchanged when any longitude is > 180 or < -180.
        If False, raise ValueError in that case.
    """

    def check_ok(tmp, fail_silently):
        # Longitudes must lie within [-180, 180] to be converted
        msg = ""
        if np.any(tmp > 180):
            msg = f"Maximum longitude is already > 180 ({np.max(tmp)})"
        elif np.any(tmp < -180):
            msg = f"Minimum longitude is < -180 ({np.min(tmp)})"
        if msg == "":
            return True
        if fail_silently:
            return False
        raise ValueError(msg)

    def do_it(tmp):
        # Negative longitudes gain 360; values in [0, 180] are unchanged
        return np.mod(tmp + 360, 360)

    if isinstance(lons_in, (xr.DataArray, xr.Dataset)):
        if not check_ok(lons_in.lon.values, fail_silently):
            return lons_in
        lons_out = lons_in.assign_coords(lon=do_it(lons_in.lon.values))
        return make_lon_increasing(lons_out)

    # Numeric input. Lists/tuples don't support the arithmetic or comparisons, so convert them.
    lons = np.asarray(lons_in) if isinstance(lons_in, (list, tuple)) else lons_in
    if not check_ok(lons, fail_silently):
        return lons_in
    lons_out = do_it(lons)
    # Ordering only applies to sequences; is_strictly_increasing() cannot take a scalar
    if np.ndim(lons_out) > 0 and not is_strictly_increasing(lons_out):
        print(
            "WARNING: You passed in numeric longitudes to lon_idl2pm() and these have been"
            " converted, but they're not strictly increasing."
        )
    print(
        "To assign the new longitude coordinates to an Xarray object, use"
        " xarrayobject.assign_coords()! (Pass the object directly in to lon_idl2pm() in"
        " order to suppress this message.)"
    )
    return lons_out
def is_strictly_increasing(this_list):
    """
    Return True if every element of this_list is strictly less than the one after it.

    Empty and single-element sequences count as strictly increasing. this_list must support
    slicing (e.g. a list, tuple, or 1-D np.ndarray).
    https://stackoverflow.com/a/4983359/2965321
    """
    successors = this_list[1:]
    neighbor_pairs = zip(this_list, successors)
    return all(earlier < later for earlier, later in neighbor_pairs)
def make_lon_increasing(xr_obj):
    """
    Ensure that longitude axis coordinates are monotonically increasing

    If xr_obj has a "lon" dimension whose coordinate values are not strictly increasing, the
    object is rolled along "lon" (coordinates included) so that they are. Objects without a "lon"
    dimension, or whose longitudes already increase, are returned unchanged.

    Raises RuntimeError if no rotation of the longitude axis is strictly increasing.
    """
    if "lon" not in xr_obj.dims:
        return xr_obj

    lons = xr_obj.lon.values
    if is_strictly_increasing(lons):
        return xr_obj

    # A rotation can only be strictly increasing if there is exactly one place where the values
    # fail to increase (the wrap point). Rolling by the number of elements after that point moves
    # them to the front. This finds the rotation directly, instead of trying every shift one at a
    # time (O(n^2)). The ~(a > b) form treats NaN as a failure, like is_strictly_increasing().
    drops = np.flatnonzero(~(lons[1:] > lons[:-1]))
    if drops.size == 1:
        shift = int(lons.size - (drops[0] + 1))
        if is_strictly_increasing(np.roll(lons, shift, axis=0)):
            return xr_obj.roll(lon=shift, roll_coords=True)

    raise RuntimeError("Unable to rearrange longitude axis so it's monotonically increasing")