clm5/python/ctsm/crop_calendars/check_constant_vars.py
2024-05-09 15:14:01 +08:00

386 lines
13 KiB
Python

"""
For variables that should stay constant, make sure they are
"""
import numpy as np
from ctsm.crop_calendars.cropcal_module import import_rx_dates
def check_one_constant_var_setup(this_ds, case, var):
"""
Various setup steps for check_one_constant_var()
"""
if "gs" in this_ds[var].dims:
time_coord = "gs"
elif "time" in this_ds[var].dims:
time_coord = "time"
else:
raise RuntimeError(f"Which of these is the time coordinate? {this_ds[var].dims}")
i_time_coord = this_ds[var].dims.index(time_coord)
this_da = this_ds[var]
ra_sp = np.moveaxis(this_da.copy().values, i_time_coord, 0)
incl_patches = []
bad_patches = np.array([])
str_list = []
# Read prescription file, if needed
rx_ds = None
if isinstance(case, dict):
if var == "GDDHARV" and "rx_gdds_file" in case:
rx_ds = import_rx_dates(
"gdd", case["rx_gdds_file"], this_ds, set_neg1_to_nan=False
).squeeze()
return time_coord, this_da, ra_sp, incl_patches, str_list, rx_ds, bad_patches
def loop_through_bad_patches(
verbose,
emojus,
var,
everything_ok,
str_list,
rx_ds,
time_1,
t1_yr,
t1_vals,
timestep,
t_yr,
t_vals,
bad_patches_this_time,
found_in_rx,
vary_patches,
vary_lons,
vary_lats,
vary_crops,
vary_crops_int,
any_bad,
):
"""
Loop through and check any patches that were "bad" according to check_constant_vars().
This is pretty inefficient, but it works.
"""
patch = None # In case bad_patches_this_time is empty
for i, patch in enumerate(bad_patches_this_time):
this_patch = vary_patches[i]
this_lon = vary_lons[i]
this_lat = vary_lats[i]
this_crop = vary_crops[i]
this_crop_int = vary_crops_int[i]
# If prescribed input had missing value (-1), it's fine for it to vary.
if rx_ds:
rx_var = f"gs1_{this_crop_int}"
if this_lon in rx_ds.lon.values and this_lat in rx_ds.lat.values:
rx_vals = rx_ds[rx_var].sel(lon=this_lon, lat=this_lat).values
n_unique = len(np.unique(rx_vals))
if n_unique == 1:
found_in_rx[i] = True
if rx_vals == -1:
continue
elif n_unique > 1:
raise RuntimeError(
f"How does lon {this_lon} lat {this_lat} {this_crop} have "
+ f"time-varying {var}?"
)
else:
raise RuntimeError(f"lon {this_lon} lat {this_lat} {this_crop} not in rx dataset?")
# Print info (or save to print later)
any_bad = True
if verbose:
this_str = (
f" Patch {this_patch} (lon {this_lon} lat {this_lat}) "
+ f"{this_crop} ({this_crop_int})"
)
if rx_ds and not found_in_rx[i]:
this_str = this_str.replace("(lon", "* (lon")
if not np.isnan(t1_vals[patch]):
t1_val_print = int(t1_vals[patch])
else:
t1_val_print = "NaN"
if not np.isnan(t_vals[patch]):
t_val_print = int(t_vals[patch])
else:
t_val_print = "NaN"
if var == "SDATES":
str_list.append(
f"{this_str}: Sowing {t1_yr} jday {t1_val_print}, {t_yr} "
+ f"jday {t_val_print}"
)
else:
str_list.append(
f"{this_str}: {t1_yr} {var} {t1_val_print}, {t_yr} {var} " + f"{t_val_print}"
)
else:
if everything_ok:
print(f"{emojus} CLM output {var} unexpectedly vary over time:")
everything_ok = False
print(f"{var} timestep {timestep} does not match timestep {time_1}")
break
return any_bad, patch
def ensure_all_patches_checked(this_ds, this_da, ra_sp, incl_patches):
"""
In check_one_constant_var(), make sure every patch was checked once (or is all-NaN except
possibly final season)
"""
incl_patches = np.sort(incl_patches)
if not np.array_equal(incl_patches, np.unique(incl_patches)):
raise RuntimeError("Patch(es) checked more than once!")
incl_patches = list(incl_patches)
incl_patches += list(
np.where(
np.all(
np.isnan(
ra_sp[
:-1,
]
),
axis=0,
)
)[0]
)
incl_patches = np.sort(incl_patches)
if not np.array_equal(incl_patches, np.unique(incl_patches)):
raise RuntimeError("Patch(es) checked but also all-NaN??")
if not np.array_equal(incl_patches, np.arange(this_ds.dims["patch"])):
for patch in np.arange(this_ds.dims["patch"]):
if patch not in incl_patches:
raise RuntimeError(
f"Not all patches checked! E.g., {patch}: {this_da.isel(patch=patch).values}"
)
def check_one_constant_var_loop_through_timesteps(
this_ds,
ignore_nan,
verbose,
emojus,
var,
everything_ok,
time_coord,
this_da,
str_list,
rx_ds,
time_1,
these_patches,
t1_yr,
t1_vals,
any_bad,
any_bad_before_checking_rx,
bad_patches,
):
"""
In check_one_constant_var(), loop through timesteps
"""
found_in_rx = None
for timestep in np.arange(time_1 + 1, this_ds.dims[time_coord]):
t_yr = this_ds[time_coord].values[timestep]
t_vals = np.squeeze(this_da.isel({time_coord: timestep, "patch": these_patches}).values)
ok_p = t1_vals == t_vals
# If allowed, ignore where either t or t1 is NaN. Should only be used for runs where
# land use varies over time.
if ignore_nan:
ok_p = np.squeeze(np.bitwise_or(ok_p, np.isnan(t1_vals + t_vals)))
if not np.all(ok_p):
any_bad_before_checking_rx = True
bad_patches_this_time = list(np.where(np.bitwise_not(ok_p))[0])
bad_patches = np.concatenate(
(bad_patches, np.array(these_patches)[bad_patches_this_time])
)
if rx_ds:
found_in_rx = np.array([False for x in bad_patches])
vary_patches = list(np.array(these_patches)[bad_patches_this_time])
vary_lons = this_ds.patches1d_lon.values[bad_patches_this_time]
vary_lats = this_ds.patches1d_lat.values[bad_patches_this_time]
vary_crops = this_ds.patches1d_itype_veg_str.values[bad_patches_this_time]
vary_crops_int = this_ds.patches1d_itype_veg.values[bad_patches_this_time]
any_bad_any_crop = False
for crop_int in np.unique(vary_crops_int):
rx_var = f"gs1_{crop_int}"
vary_lons_this_crop = vary_lons[np.where(vary_crops_int == crop_int)]
vary_lats_this_crop = vary_lats[np.where(vary_crops_int == crop_int)]
these_rx_vals = np.diag(
rx_ds[rx_var].sel(lon=vary_lons_this_crop, lat=vary_lats_this_crop).values
)
if len(these_rx_vals) != len(vary_lats_this_crop):
raise RuntimeError(
f"Expected {len(vary_lats_this_crop)} rx values; got "
+ f"{len(these_rx_vals)}"
)
if not np.any(these_rx_vals != -1):
continue
any_bad_any_crop = True
break
if not any_bad_any_crop:
continue
# Loop through and check any patches that were "bad"
any_bad = loop_through_bad_patches(
verbose,
emojus,
var,
everything_ok,
str_list,
rx_ds,
time_1,
t1_yr,
t1_vals,
timestep,
t_yr,
t_vals,
bad_patches_this_time,
found_in_rx,
vary_patches,
vary_lons,
vary_lats,
vary_crops,
vary_crops_int,
any_bad,
)
return any_bad_before_checking_rx, bad_patches, found_in_rx, any_bad
def check_one_constant_var(
this_ds, case, ignore_nan, verbose, emojus, var, any_bad, any_bad_before_checking_rx
):
"""
Ensure that a variable that should be constant actually is
"""
everything_ok = True
(
time_coord,
this_da,
ra_sp,
incl_patches,
str_list,
rx_ds,
bad_patches,
) = check_one_constant_var_setup(this_ds, case, var)
for time_1 in np.arange(this_ds.dims[time_coord] - 1):
condn = ~np.isnan(ra_sp[time_1, ...])
if time_1 > 0:
condn = np.bitwise_and(condn, np.all(np.isnan(ra_sp[:time_1, ...]), axis=0))
these_patches = np.where(condn)[0]
if these_patches.size == 0:
continue
these_patches = list(np.where(condn)[0])
incl_patches += these_patches
t1_yr = this_ds[time_coord].values[time_1]
t1_vals = np.squeeze(this_da.isel({time_coord: time_1, "patch": these_patches}).values)
(
any_bad_before_checking_rx,
bad_patches,
found_in_rx,
any_bad,
) = check_one_constant_var_loop_through_timesteps(
this_ds,
ignore_nan,
verbose,
emojus,
var,
everything_ok,
time_coord,
this_da,
str_list,
rx_ds,
time_1,
these_patches,
t1_yr,
t1_vals,
any_bad,
any_bad_before_checking_rx,
bad_patches,
)
if verbose and any_bad:
print(f"{emojus} CLM output {var} unexpectedly vary over time:")
str_list.sort()
if found_in_rx is None:
raise RuntimeError("Somehow any_bad True but found_in_rx None")
if rx_ds and np.any(~found_in_rx): # pylint: disable=invalid-unary-operand-type
str_list = [
"*: Not found in prescribed input file (maybe minor lon/lat mismatch)"
] + str_list
elif not rx_ds:
str_list = ["(No rx file checked)"] + str_list
print("\n".join(str_list))
# Make sure every patch was checked once (or is all-NaN except possibly final season)
ensure_all_patches_checked(this_ds, this_da, ra_sp, incl_patches)
if not any_bad:
if any_bad_before_checking_rx:
print(
f"✅ CLM output {var} do not vary through {this_ds.dims[time_coord]} growing "
+ "seasons of output (except for patch(es) with missing rx)."
)
else:
print(
f"✅ CLM output {var} do not vary through {this_ds.dims[time_coord]} growing "
+ "seasons of output."
)
return any_bad, any_bad_before_checking_rx, bad_patches
def check_constant_vars(
this_ds, case, ignore_nan, const_growing_seasons=None, verbose=True, throw_error=True
):
"""
For variables that should stay constant, make sure they are
"""
if isinstance(case, str):
const_vars = [case]
elif isinstance(case, list):
const_vars = case
elif isinstance(case, dict):
const_vars = case["const_vars"]
else:
raise TypeError(f"case must be str or dict, not {type(case)}")
if not const_vars:
return None
if const_growing_seasons:
gs_0 = this_ds.gs.values[0]
gs_n = this_ds.gs.values[-1]
if const_growing_seasons.start > gs_0 or const_growing_seasons.stop < gs_n:
print(
f"❗ Only checking const_vars over {const_growing_seasons.start}-"
+ f"{const_growing_seasons.stop} (run includes {gs_0}-{gs_n})"
)
this_ds = this_ds.sel(gs=const_growing_seasons)
any_bad = False
any_bad_before_checking_rx = False
if throw_error:
emojus = ""
else:
emojus = ""
if not isinstance(const_vars, list):
const_vars = [const_vars]
for var in const_vars:
any_bad, any_bad_before_checking_rx, bad_patches = check_one_constant_var(
this_ds, case, ignore_nan, verbose, emojus, var, any_bad, any_bad_before_checking_rx
)
if any_bad and throw_error:
raise RuntimeError("Stopping due to failed check_constant_vars().")
bad_patches = np.unique(bad_patches)
return [int(p) for p in bad_patches]