"""
Import a dataset that can be spread over multiple files, only including specified variables
and/or vegetation types and/or timesteps, concatenating by time.
- DOES actually read the dataset into memory, but only AFTER dropping unwanted variables and/or
vegetation types.
"""
import re
import warnings
from importlib.util import find_spec
import numpy as np
import xarray as xr
import ctsm.crop_calendars.cropcal_utils as utils
from ctsm.crop_calendars.xr_flexsel import xr_flexsel


def compute_derived_vars(ds_in, var):
"""
Compute derived variables
"""
if (
var == "HYEARS"
and "HDATES" in ds_in
and ds_in.HDATES.dims == ("time", "mxharvests", "patch")
):
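        # Assumption (inferred from the "- 1" below): each time stamp marks the end of
        # a year, so the harvest year is the time-axis year minus 1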
year_list = np.array([np.float32(x.year - 1) for x in ds_in.time.values])
hyears = ds_in["HDATES"].copy()
        # ds.sizes is preferred over indexing ds.dims (mapping-style access to
        # Dataset.dims is deprecated in newer xarray)
        hyears.values = np.tile(
            np.expand_dims(year_list, (1, 2)),
            (1, ds_in.sizes["mxharvests"], ds_in.sizes["patch"]),
        )
with np.errstate(invalid="ignore"):
is_le_zero = ~np.isnan(ds_in.HDATES.values) & (ds_in.HDATES.values <= 0)
hyears.values[is_le_zero] = ds_in.HDATES.values[is_le_zero]
hyears.values[np.isnan(ds_in.HDATES.values)] = np.nan
hyears.attrs["long_name"] = "DERIVED: actual crop harvest years"
hyears.attrs["units"] = "year"
ds_in["HYEARS"] = hyears
else:
raise RuntimeError(f"Unable to compute derived variable {var}")
    return ds_in


def mfdataset_preproc(ds_in, vars_to_import, vegtypes_to_import, time_slice):
"""
Function to drop unwanted variables in preprocessing of open_mfdataset().
- Makes sure to NOT drop any unspecified variables that will be useful in gridding.
- Also adds vegetation type info in the form of a DataArray of strings.
- Also renames "pft" dimension (and all like-named variables, e.g., pft1d_itype_veg_str) to be
named like "patch". This can later be reversed, for compatibility with other code, using
patch2pft().
"""
# Rename "pft" dimension and variables to "patch", if needed
if "pft" in ds_in.dims:
pattern = re.compile("pft.*1d")
matches = [x for x in list(ds_in.keys()) if pattern.search(x) is not None]
pft2patch_dict = {"pft": "patch"}
for match in matches:
pft2patch_dict[match] = match.replace("pft", "patch").replace("patchs", "patches")
ds_in = ds_in.rename(pft2patch_dict)
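        # Illustrative examples of the renaming (actual names depend on the file):
        #   dimension "pft"     -> "patch"
        #   "pfts1d_itype_veg"  -> "patches1d_itype_veg"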
derived_vars = []
if vars_to_import is not None:
# Split vars_to_import into variables that are vs. aren't already in ds
derived_vars = [v for v in vars_to_import if v not in ds_in]
present_vars = [v for v in vars_to_import if v in ds_in]
vars_to_import = present_vars
# Get list of dimensions present in variables in vars_to_import.
dim_list = []
for var in vars_to_import:
# list(set(x)) returns a list of the unique items in x
dim_list = list(set(dim_list + list(ds_in.variables[var].dims)))
# Get any _1d variables that are associated with those dimensions. These will be useful in
# gridding. Also, if any dimension is "pft", set up to rename it and all like-named
# variables to "patch"
oned_vars = []
for dim in dim_list:
pattern = re.compile(f"{dim}.*1d")
matches = [x for x in list(ds_in.keys()) if pattern.search(x) is not None]
oned_vars = list(set(oned_vars + matches))
# Add dimensions and _1d variables to vars_to_import
vars_to_import = list(set(vars_to_import + list(ds_in.dims) + oned_vars))
# Add any _bounds variables
bounds_vars = []
for var in vars_to_import:
bounds_var = var + "_bounds"
if bounds_var in ds_in:
bounds_vars = bounds_vars + [bounds_var]
vars_to_import = vars_to_import + bounds_vars
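        # Illustrative net effect: asking for a single history variable also keeps its
        # dimension coordinates, any matching *1d gridding helpers (e.g., a
        # hypothetical patches1d_lon), and any *_bounds variables (e.g., time_bounds)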
# Get list of variables to drop
varlist = list(ds_in.variables)
vars_to_drop = list(np.setdiff1d(varlist, vars_to_import))
# Drop them
ds_in = ds_in.drop_vars(vars_to_drop)
# Add vegetation type info
if "patches1d_itype_veg" in list(ds_in):
this_pftlist = utils.define_pftlist()
utils.get_patch_ivts(
ds_in, this_pftlist
) # Includes check of whether vegtype changes over time anywhere
vegtype_da = utils.get_vegtype_str_da(this_pftlist)
patches1d_itype_veg_str = vegtype_da.values[
ds_in.isel(time=0).patches1d_itype_veg.values.astype(int)
]
npatch = len(patches1d_itype_veg_str)
patches1d_itype_veg_str = xr.DataArray(
patches1d_itype_veg_str,
coords={"patch": np.arange(0, npatch)},
dims=["patch"],
name="patches1d_itype_veg_str",
)
ds_in = xr.merge([ds_in, vegtype_da, patches1d_itype_veg_str])
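        # ds_in now includes patches1d_itype_veg_str, mapping each patch index to a
        # vegetation type name (e.g., a hypothetical "temperate_corn")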
# Restrict to veg. types of interest, if any
if vegtypes_to_import is not None:
ds_in = xr_flexsel(ds_in, vegtype=vegtypes_to_import)
# Restrict to time slice, if any
if time_slice:
ds_in = utils.safer_timeslice(ds_in, time_slice)
# Finish import
ds_in = xr.decode_cf(ds_in, decode_times=True)
# Compute derived variables
for var in derived_vars:
ds_in = compute_derived_vars(ds_in, var)
    return ds_in


def process_inputs(filelist, my_vars, my_vegtypes, my_vars_missing_ok):
"""
Process inputs to import_ds()
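
    A hedged sketch of the normalization (file and variable names hypothetical):
        process_inputs("hist.nc", "HDATES", "temperate_corn", None)
        # -> (["hist.nc"], ["HDATES"], [<int index of temperate_corn>], [])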
"""
if my_vars_missing_ok is None:
my_vars_missing_ok = []
# Convert my_vegtypes here, if needed, to avoid repeating the process each time you read a file
# in xr.open_mfdataset().
if my_vegtypes is not None:
if not isinstance(my_vegtypes, list):
my_vegtypes = [my_vegtypes]
if isinstance(my_vegtypes[0], str):
my_vegtypes = utils.vegtype_str2int(my_vegtypes)
    # Same for these variables.
    if my_vars is not None:
        if not isinstance(my_vars, list):
            my_vars = [my_vars]
    # Make sure lists are actually lists
    if not isinstance(filelist, list):
        filelist = [filelist]
    if not isinstance(my_vars_missing_ok, list):
        my_vars_missing_ok = [my_vars_missing_ok]
    return filelist, my_vars, my_vegtypes, my_vars_missing_ok


def import_ds(
filelist,
my_vars=None,
my_vegtypes=None,
time_slice=None,
my_vars_missing_ok=None,
rename_lsmlatlon=False,
chunks=None,
):
"""
Import a dataset that can be spread over multiple files, only including specified variables
and/or vegetation types and/or timesteps, concatenating by time.
- DOES actually read the dataset into memory, but only AFTER dropping unwanted variables and/or
vegetation types.
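
    A minimal usage sketch (file names and argument values hypothetical):
        this_ds = import_ds(
            ["hist.2000.nc", "hist.2001.nc"],
            my_vars=["HDATES"],
            my_vegtypes="temperate_corn",
            time_slice=slice("2000-01-01", "2001-12-31"),
        )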
"""
filelist, my_vars, my_vegtypes, my_vars_missing_ok = process_inputs(
filelist, my_vars, my_vegtypes, my_vars_missing_ok
)
# Remove files from list if they don't contain requested timesteps.
# time_slice should be in the format slice(start,end[,step]). start or end can be None to be
# unbounded on one side. Note that the standard slice() documentation suggests that only
# elements through end-1 will be selected, but that seems not to be the case in the xarray
# implementation.
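    # E.g., time_slice=slice("2000-01-01", "2001-12-31") would keep only files whose
    # time axis overlaps calendar years 2000-2001 (date strings illustrative; exact
    # handling is delegated to utils.safer_timeslice())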
if time_slice:
new_filelist = []
for file in sorted(filelist):
filetime = xr.open_dataset(file).time
filetime_sel = utils.safer_timeslice(filetime, time_slice)
include_this_file = filetime_sel.size
if include_this_file:
new_filelist.append(file)
# If you found some matching files, but then you find one that doesn't, stop going
# through the list.
elif new_filelist:
break
if not new_filelist:
raise RuntimeError(f"No files found in time_slice {time_slice}")
filelist = new_filelist
# The xarray open_mfdataset() "preprocess" argument requires a function that takes exactly one
# variable (an xarray.Dataset object). Wrapping mfdataset_preproc() in this lambda function
# allows this. Could also just allow mfdataset_preproc() to access my_vars and my_vegtypes
# directly, but that's bad practice as it could lead to scoping issues.
mfdataset_preproc_closure = lambda ds: mfdataset_preproc(ds, my_vars, my_vegtypes, time_slice)
# Import
if isinstance(filelist, list) and len(filelist) == 1:
filelist = filelist[0]
if isinstance(filelist, list):
with warnings.catch_warnings():
warnings.filterwarnings(action="ignore", category=DeprecationWarning)
if find_spec("dask") is None:
raise ModuleNotFoundError(
"You have asked xarray to import a list of files as a single Dataset using"
" open_mfdataset(), but this requires dask, which is not available.\nFile"
f" list: {filelist}"
)
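            # data_vars="minimal" concatenates only variables that actually include
            # the "time" dimension, and compat="override" skips costly cross-file
            # equality checks by taking non-concatenated variables from the first file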
this_ds = xr.open_mfdataset(
sorted(filelist),
data_vars="minimal",
preprocess=mfdataset_preproc_closure,
compat="override",
coords="all",
concat_dim="time",
combine="nested",
chunks=chunks,
)
elif isinstance(filelist, str):
this_ds = xr.open_dataset(filelist, chunks=chunks)
this_ds = mfdataset_preproc(this_ds, my_vars, my_vegtypes, time_slice)
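    # Read everything that survived preprocessing into memory (see module docstring)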
this_ds = this_ds.compute()
# Warn and/or error about variables that couldn't be imported or derived
if my_vars:
missing_vars = [v for v in my_vars if v not in this_ds]
ok_missing_vars = [v for v in missing_vars if v in my_vars_missing_ok]
bad_missing_vars = [v for v in missing_vars if v not in my_vars_missing_ok]
if ok_missing_vars:
print(
"Could not import some variables; either not present or not deriveable:"
f" {ok_missing_vars}"
)
if bad_missing_vars:
raise RuntimeError(
"Could not import some variables; either not present or not deriveable:"
f" {bad_missing_vars}"
)
if rename_lsmlatlon:
if "lsmlat" in this_ds.dims:
this_ds = this_ds.rename({"lsmlat": "lat"})
if "lsmlon" in this_ds.dims:
this_ds = this_ds.rename({"lsmlon": "lon"})
return this_ds