from __future__ import annotations
import numpy as np
from xarray import DataArray, Dataset, DataTree
from swarmpal.io import PalProcess
from swarmpal.toolboxes.tfa import tfalib
from swarmpal.utils.exceptions import PalError
# Default flag-cleaning settings, keyed by variable name.
# Each entry gives the flag variable consulted for that variable ("flag_name")
# and the largest flag value that is still accepted ("max_val").
# Preprocess._flag_cleaning falls back to these when the process config does
# not supply a flag name / maximum value, and sets samples whose flag exceeds
# "max_val" to NaN.
FLAG_THRESHOLDS = {
    "B_NEC": dict(flag_name="Flags_B", max_val=30),
    "F": dict(flag_name="Flags_F", max_val=63),
    "E_NEC": dict(flag_name="Flags_TII", max_val=23),
    "n": dict(flag_name="Flags_LP", max_val=63),
    "Bubble_Probability": dict(flag_name="Flags_F", max_val=1000),
    "FAC": dict(flag_name="Flags_F", max_val=1000),
    "EEF": dict(flag_name="Flags", max_val=1000),
    "Eh_XYZ": dict(flag_name="Quality_flags", max_val=0),
    "Ev_XYZ": dict(flag_name="Quality_flags", max_val=0),
}
[docs]
class Preprocess(PalProcess):
    """Prepare data for input to other TFA tools

    Selects one variable from a dataset in the datatree, optionally derives
    special variables (model residual, MFA rotation, TCT electric field
    vectors), optionally NaNs flagged samples, resamples to a constant cadence
    and stores the result as ``TFA_Variable`` in the output dataset.
    """

    @property
    def process_name(self) -> str:
        """Name under which this process is identified"""
        return "TFA_Preprocess"

    def set_config(
        self,
        dataset: str = "",
        timevar: str = "Timestamp",
        active_variable: str = "",
        active_component: int | None = None,
        sampling_rate: float = 1,
        remove_model: bool = False,
        model: str = "",
        convert_to_mfa: bool = False,
        use_magnitude: bool = False,
        clean_by_flags: bool = False,
        flagclean_varname: str = "",
        flagclean_flagname: str = "",
        flagclean_maxval: int | None = None,
        output_dataset: str = "PAL_TFA",
    ) -> None:
        """Set the process configuration

        Parameters
        ----------
        dataset : str
            Selects this dataset from the datatree
        timevar : str
            Identifies the name of the time variable, usually "Timestamp" or "Time"
        active_variable : str
            Selects the variable to use from within the dataset
        active_component : int, optional
            Selects the component to use (if active_variable is a vector)
        sampling_rate : float, optional
            Identify the sampling rate of the data input (in Hz), by default 1
        remove_model : bool, optional
            Remove a magnetic model prediction or not, by default False
        model : str, optional
            The name of the model
        convert_to_mfa : bool, optional
            Rotate B to mean-field aligned (MFA) coordinates, by default False
        use_magnitude : bool, optional
            Use the magnitude of a vector instead, by default False
        clean_by_flags : bool, optional
            Whether to apply additional flag cleaning or not, by default False
        flagclean_varname : str, optional
            Name of the variable to clean, by default the active variable
        flagclean_flagname : str, optional
            Name of the flag to use to clean by, by default taken from
            ``FLAG_THRESHOLDS``
        flagclean_maxval : int, optional
            Maximum allowable flag value (``0`` is honoured as an explicit
            threshold), by default taken from ``FLAG_THRESHOLDS``
        output_dataset : str
            Sets the name of the dataset in the data tree that TFA processes
            will write results to, by default "PAL_TFA"

        Notes
        -----
        Some special ``active_variable`` names exist which are added to the
        dataset on-the-fly:

        * "B_NEC_res_Model"
            where a model prediction must be available in the data, like
            ``"B_NEC_<Model>"``, and ``remove_model`` has been set. The name of
            the model can be set with, for example, ``model="CHAOS"``.
        * "B_MFA"
            when ``convert_to_mfa`` has been set.
        * "Eh_XYZ" and "Ev_XYZ"
            when using the TCT datasets, with vectors defined in
            ``("Ehx", "Ehy", "Ehz")`` and ``("Evx", "Evy", "Evz")`` respectively.
        """
        super().set_config(
            dataset=dataset,
            timevar=timevar,
            active_variable=active_variable,
            active_component=active_component,
            sampling_rate=sampling_rate,
            remove_model=remove_model,
            model=model,
            convert_to_mfa=convert_to_mfa,
            use_magnitude=use_magnitude,
            clean_by_flags=clean_by_flags,
            flagclean_varname=flagclean_varname,
            flagclean_flagname=flagclean_flagname,
            flagclean_maxval=flagclean_maxval,
            output_dataset=output_dataset,
        )

    @property
    def active_variable(self):
        """Configured name of the variable to process ("" if unset)"""
        return self.config.get("active_variable", "")

    @property
    def active_component(self) -> int | None:
        """Configured vector component index, or None when no component is selected"""
        # Default must be None (not ""), because _call tests "is not None" to
        # decide whether a component has been selected.
        return self.config.get("active_component")

    def _call(self, datatree: DataTree) -> DataTree:
        """Build the TFA output dataset and store it in the datatree

        Parameters
        ----------
        datatree : DataTree
            Tree containing the dataset named by the ``dataset`` config entry

        Returns
        -------
        DataTree
            The same tree, with the output dataset assigned under
            ``self.output_dataset``
        """
        # NOTE(review): _validate_inputs is defined outside the code shown here
        self._validate_inputs(datatree)
        # Select the datatree to work on
        self.subtree = datatree[self.config.get("dataset")]
        # Prepare data depending on the content
        if self.active_variable in ("Eh_XYZ", "Ev_XYZ"):
            ds = self._prep_efi_expt_data(self.subtree.ds)
        else:
            ds = self._prep_magnetic_data(self.subtree.ds)
        # Optionally clean according to flag values
        if self.config.get("clean_by_flags", False):
            ds = self._flag_cleaning(ds)
        # Identify and assign "TFA_Variable" (to be used in next processes)
        if self.active_component is not None:
            da = ds[self.active_variable][:, self.active_component].copy(deep=True)
        elif self.config["use_magnitude"]:
            da = (ds[self.active_variable] ** 2).sum(axis=1).pipe(np.sqrt)
        else:
            da = ds[self.active_variable].copy(deep=True)
        # Rename (Timestamp/Time) to TFA_Time to avoid collision
        da = da.rename({self.config["timevar"]: "TFA_Time"})
        da = self._constant_cadence(da)
        ds_out = Dataset(data_vars={"TFA_Variable": da, "TFA_Time": da["TFA_Time"]})
        # Append the specially created variables
        # (created in "ds" by "_prep_magnetic_data" & "_prep_efi_expt_data")
        # so that they are preserved in the output. A fixed tuple (rather than
        # a set) keeps the order of the output variables deterministic.
        for special_var in ("B_NEC_res_Model", "B_MFA", "Eh_XYZ", "Ev_XYZ"):
            if special_var in ds.data_vars:
                ds_out = ds_out.assign({special_var: ds[special_var]})
        # Remove attrs, because .to_netcdf() is failing when blank units are set here
        ds_out["TFA_Time"].attrs = {}
        # Assign dataset back into the datatree to return
        datatree[self.output_dataset] = ds_out
        return datatree

    def _prep_magnetic_data(self, ds: Dataset) -> Dataset:
        """Subtract model and/or rotate to MFA

        Adds "B_NEC_res_Model" when ``remove_model`` is configured and "B_MFA"
        when ``convert_to_mfa`` is configured; otherwise returns ``ds`` as given.
        """
        remove_model = self.config.get("remove_model", False)
        convert_to_mfa = self.config.get("convert_to_mfa", False)
        timevar = self.config.get("timevar")
        # Identify model name from config or from PAL meta
        model = self.config.get("model", "")
        try:
            model = model or self.subtree.swarmpal.magnetic_model_name
        except PalError:
            model = ""
        # Optionally assign residuals to dataset
        if remove_model:
            ds = ds.assign(
                {"B_NEC_res_Model": self.subtree.swarmpal.magnetic_residual(model)},
            )
        # Optionally rotate to MFA (Mean-field aligned coordinates)
        if convert_to_mfa:
            # Rotate the residual when the model has been removed, else the
            # measured field; the model field is the second argument either way.
            source_name = "B_NEC_res_Model" if remove_model else "B_NEC"
            B_MFA = tfalib.mfa(ds[source_name].data, ds[f"B_NEC_{model}"].data)
            ds = ds.assign_coords({"MFA": [0, 1, 2]})
            ds = ds.assign({"B_MFA": ((timevar, "MFA"), B_MFA)})
            ds["B_MFA"].attrs = {
                "units": "nT",
                "description": "Magnetic field in Mean-field aligned coordinates",
            }
        return ds

    def _prep_efi_expt_data(self, ds: Dataset) -> Dataset:
        """Assign the Eh_XYZ or Ev_XYZ vector data variable

        Raises
        ------
        PalError
            If the three component variables or "Quality_flags" are missing
        """
        timevar = self.config.get("timevar")
        component_names = {
            "Eh_XYZ": ("Ehx", "Ehy", "Ehz"),
            "Ev_XYZ": ("Evx", "Evy", "Evz"),
        }[self.active_variable]
        # Validate input data
        required_vars = {*component_names, "Quality_flags"}
        if not required_vars.issubset(set(ds.data_vars)):
            raise PalError(f"Not all available: {required_vars}")
        # Create and assign the vector parameter (one column per component)
        E_XYZ = np.vstack([ds[name] for name in component_names]).T
        ds = ds.assign_coords({"XYZ": ["X", "Y", "Z"]})
        ds = ds.assign({self.active_variable: ((timevar, "XYZ"), E_XYZ)})
        return ds

    def _flag_cleaning(self, ds):
        """Set values to NaN where flags exceed a threshold

        The variable, flag name and maximum flag value come from the config
        (``flagclean_varname``, ``flagclean_flagname``, ``flagclean_maxval``).
        Missing settings fall back to the active variable and to the
        ``FLAG_THRESHOLDS`` entry for that variable (with any "_res_Model"
        suffix removed).

        Raises
        ------
        KeyError
            If a default is needed but the variable has no ``FLAG_THRESHOLDS`` entry
        """
        varname = self.config.get("flagclean_varname") or self.active_variable
        flagname = self.config.get("flagclean_flagname")
        max_val = self.config.get("flagclean_maxval")
        # Only consult the defaults when something is actually missing.
        # "max_val" is compared against None so that an explicit threshold of 0
        # is respected rather than being replaced by the default.
        if not flagname or max_val is None:
            defaults_key = varname.replace("_res_Model", "")
            try:
                defaults = FLAG_THRESHOLDS[defaults_key]
            except KeyError:
                raise KeyError(
                    f"No default flag thresholds for {defaults_key!r}; "
                    "set flagclean_flagname and flagclean_maxval explicitly"
                ) from None
            flagname = flagname or defaults["flag_name"]
            if max_val is None:
                max_val = defaults["max_val"]
        # Set flagged values to NaN
        # NOTE(review): this writes into ds[varname] in place; if ds shares its
        # data with the input tree, the input is modified too - confirm intended.
        inds_to_remove = ds[flagname] > max_val
        ds[varname][inds_to_remove, ...] = np.nan
        return ds

    def _constant_cadence(self, da):
        """Convert array to that of constant cadence

        Parameters
        ----------
        da : DataArray
            Array with a "TFA_Time" coordinate

        Returns
        -------
        DataArray
            New array along "TFA_Time", carrying over "units" and "description"
            attributes (blank if absent)
        """
        # Convert time to seconds (relative to first sample) for tfalib.constant_cadence
        t_old = da["TFA_Time"].data
        t_old_sec = (t_old - t_old[0]) / np.timedelta64(1, "s")
        new_t_sec, new_X = tfalib.constant_cadence(
            t_old_sec, da.data, self.config["sampling_rate"], interp=False
        )[0:2]
        # Back to absolute times, via nanoseconds
        new_t = t_old[0] + (new_t_sec * 1e9).astype("timedelta64[ns]")
        # Assign into new array to return
        da_new = DataArray(
            data=new_X,
            dims=("TFA_Time",),
        )
        da_new = da_new.assign_coords({"TFA_Time": new_t})
        da_new.attrs = {
            "units": da.attrs.get("units", ""),
            "description": da.attrs.get("description", ""),
        }
        da_new["TFA_Time"].attrs = {
            "units": da["TFA_Time"].attrs.get("units", ""),
            "description": da["TFA_Time"].attrs.get("description", ""),
        }
        return da_new
[docs]
def _get_tfa_active_subtree(datatree, output_dataset):
    """Returns the relevant subtree when Preprocess has been applied

    Looks up the "TFA_Preprocess" entry recorded for ``output_dataset`` in the
    tree's PAL metadata and returns the subtree named by that entry's
    "output_dataset" value.

    Raises
    ------
    PalError
        If no (non-empty) "TFA_Preprocess" entry is recorded
    """
    preprocess_meta = datatree.swarmpal.pal_meta.get(output_dataset, {}).get(
        "TFA_Preprocess"
    )
    if preprocess_meta:
        return datatree[preprocess_meta.get("output_dataset")]
    raise PalError("Must first run tfa.processes.Preprocess")
[docs]
def _get_sampling_rate(datatree, output_dataset):
    """Get the sampling rate set by Preprocess

    Reads the "sampling_rate" value from the "TFA_Preprocess" entry recorded
    for ``output_dataset`` in the tree's PAL metadata.

    Raises
    ------
    PalError
        If no (non-empty) "TFA_Preprocess" entry is recorded, i.e. Preprocess
        has not been run for ``output_dataset``
    """
    pal_processes_meta = datatree.swarmpal.pal_meta.get(output_dataset, {})
    tfa_preprocess_meta = pal_processes_meta.get("TFA_Preprocess")
    # Fail with the same explicit error as _get_tfa_active_subtree instead of
    # an opaque TypeError from subscripting None
    if not tfa_preprocess_meta:
        raise PalError("Must first run tfa.processes.Preprocess")
    return tfa_preprocess_meta["sampling_rate"]
[docs]
class Clean(PalProcess):
    """Clean TFA_Variable by removing outliers and interpolate gaps"""

    @property
    def process_name(self) -> str:
        """Name under which this process is identified"""
        return "TFA_Clean"

    def set_config(
        self,
        window_size: int = 10,
        method: str = "iqr",
        multiplier: float = 0.5,
        output_dataset: str = "PAL_TFA",
    ) -> None:
        """Set the process configuration

        Parameters
        ----------
        window_size : int, optional
            The size (number of points) of the rolling window, by default 10
        method : str, optional
            "normal" or "iqr", by default "iqr"
        multiplier : float, optional
            Indicates the spread of the zone of accepted values, by default 0.5
        output_dataset : str
            Sets the name of the dataset in the data tree that TFA processes
            will write results to, by default "PAL_TFA"
        """
        super().set_config(
            window_size=window_size,
            method=method,
            multiplier=multiplier,
            output_dataset=output_dataset,
        )

    def _call(self, datatree) -> DataTree:
        """Clean the TFA_Variable of the preprocessed subtree and return the tree"""
        # Identify the DataArray to modify
        subtree = _get_tfa_active_subtree(datatree, self.output_dataset)
        target_var = subtree["TFA_Variable"]
        # Apply cleaning routine inplace
        self._clean_variable(target_var)
        return datatree

    @staticmethod
    def _interpolate_gaps(series):
        """Linearly interpolate over the NaNs of a 1-D array

        Sample positions (0..N-1) are used as the x-axis. A series with no
        non-NaN samples is returned unchanged, since there is nothing to
        interpolate from (np.interp would raise on the empty sample set).
        """
        positions = np.arange(series.shape[0])
        valid = ~np.isnan(series)
        if not valid.any():
            return series
        return np.interp(positions, positions[valid], series[valid])

    def _clean_variable(self, target_var) -> DataArray:
        """Replace outliers with NaN, then fill all NaNs by interpolation

        Parameters
        ----------
        target_var : DataArray
            1-D array, or 2-D array whose columns are cleaned independently;
            modified in place

        Returns
        -------
        DataArray
            The same ``target_var`` object
        """
        # Remove outliers
        inds = tfalib.outliers(
            target_var.data,
            self.config.get("window_size"),
            method=self.config.get("method"),
            multiplier=self.config.get("multiplier"),
        )
        target_var.data[inds] = np.nan
        # Interpolate over gaps
        if len(target_var.data.shape) == 1:
            target_var.data = self._interpolate_gaps(target_var.data)
        else:
            N, D = target_var.data.shape
            for i in range(D):
                column = np.reshape(target_var.data[:, i], (N,))
                target_var.data[:, i] = self._interpolate_gaps(column)
        return target_var
[docs]
class Filter(PalProcess):
    """High-pass filter the TFA_Variable, using the SciPy Chebyshev Type II filter"""

    @property
    def process_name(self) -> str:
        """Name under which this process is identified"""
        return "TFA_Filter"

    def set_config(
        self,
        cutoff_frequency: float = 20 / 1000,
        output_dataset: str = "PAL_TFA",
    ) -> None:
        """Set the process configuration

        Parameters
        ----------
        cutoff_frequency : float, optional
            The cutoff frequency (in Hz), by default 20/1000
        output_dataset : str
            Sets the name of the dataset in the data tree that TFA processes
            will write results to, by default "PAL_TFA"
        """
        settings = {
            "cutoff_frequency": cutoff_frequency,
            "output_dataset": output_dataset,
        }
        super().set_config(**settings)

    def _call(self, datatree) -> DataTree:
        """Filter the TFA_Variable of the preprocessed subtree and return the tree"""
        # Locate the array first, so that a missing Preprocess step is reported
        # before the sampling rate is looked up
        tfa_variable = _get_tfa_active_subtree(datatree, self.output_dataset)[
            "TFA_Variable"
        ]
        rate = _get_sampling_rate(datatree, self.output_dataset)
        # The filter overwrites the array's data; the return value is not needed
        self._filter(tfa_variable, rate)
        return datatree

    def _filter(self, target_var, sampling_rate) -> DataArray:
        """Overwrite ``target_var.data`` with its filtered version and return the array"""
        cutoff = self.config.get("cutoff_frequency")
        filtered = tfalib.filter(target_var.data, sampling_rate, cutoff)
        target_var.data = filtered
        return target_var
[docs]
class Wavelet(PalProcess):
    """Apply wavelet analysis"""

    @property
    def process_name(self) -> str:
        """Name under which this process is identified"""
        return "TFA_Wavelet"

    def set_config(
        self,
        min_frequency: float | None = None,
        max_frequency: float | None = None,
        min_scale: float | None = None,
        max_scale: float | None = None,
        dj: float = 0.1,
        output_dataset: str = "PAL_TFA",
    ) -> None:
        """Set the process configuration

        Parameters
        ----------
        min_frequency : float | None, optional
            Lower frequency bound, by default None. Stored in the config only;
            presumably converted to a scale when the process is configured
            (TODO confirm).
        max_frequency : float | None, optional
            Upper frequency bound, by default None. Stored in the config only;
            presumably converted to a scale when the process is configured
            (TODO confirm).
        min_scale : float | None, optional
            Minimum scale handed to the tfalib wavelet routines, by default None
        max_scale : float | None, optional
            Maximum scale handed to the tfalib wavelet routines, by default None
        dj : float, optional
            The ``dj`` value handed to the tfalib wavelet routines, by default 0.1
        output_dataset : str
            Sets the name of the dataset in the data tree that TFA processes
            will write results to, by default "PAL_TFA"
        """
        settings = {
            "min_frequency": min_frequency,
            "max_frequency": max_frequency,
            "min_scale": min_scale,
            "max_scale": max_scale,
            "dj": dj,
            "output_dataset": output_dataset,
        }
        super().set_config(**settings)

    def _call(self, datatree: DataTree) -> DataTree:
        """Compute the wavelet power of TFA_Variable and store it in the tree

        The preprocessed subtree is replaced by a copy of its dataset extended
        with a "scale" coordinate and a "wavelet_power" variable on
        ("scale", "TFA_Time").
        """
        # NOTE(review): _configure is defined outside the code shown here
        self._configure(datatree)
        # Work on a Dataset copy of the subtree produced by Preprocess
        subtree = _get_tfa_active_subtree(datatree, self.output_dataset)
        working_ds = subtree.to_dataset()
        # Apply wavelet routine
        power, scales = self._wavelets(working_ds["TFA_Variable"])
        # Attach the scale axis and the power array
        working_ds = working_ds.assign_coords({"scale": scales})
        working_ds["wavelet_power"] = DataArray(data=power, dims=("scale", "TFA_Time"))
        # Return datatree containing the updated dataset
        datatree[subtree.name] = DataTree(dataset=working_ds)
        return datatree

    def _wavelets(self, target_var: DataArray):
        """Run the tfalib wavelet routines on ``target_var``

        Returns
        -------
        tuple
            ``(norm, scale)``: the normalised squared magnitude of the wavelet
            transform, and the scales from ``tfalib.wavelet_scales``
        """
        cfg = self.config
        min_scale = cfg.get("min_scale")
        max_scale = cfg.get("max_scale")
        dj = cfg.get("dj")
        scale = tfalib.wavelet_scales(min_scale, max_scale, dj)
        # "sampling_rate" is not set by set_config; presumably _configure puts
        # it into the config (TODO confirm)
        sample_spacing = 1 / cfg.get("sampling_rate")
        # Only the first element of the transform result is used
        wave = tfalib.wavelet_transform(
            target_var.data,
            dx=sample_spacing,
            minScale=min_scale,
            maxScale=max_scale,
            dj=dj,
        )[0]
        squared_magnitude = np.abs(wave) ** 2
        norm = tfalib.wavelet_normalize(
            squared_magnitude,
            scale,
            dx=sample_spacing,
            dj=dj,
            wavelet_norm_factor=0.74044116,
        )
        return norm, scale
[docs]
class WaveDetection(PalProcess):
    """Screen out potential false waves

    Removes part of the wavelet spectrum that might be due to spikes, data
    gaps, ESFs or trailing parts of wave activity from either above or below
    the range of frequencies that were used to perform the wavelet transform.

    This process is a placeholder: ``_call`` raises NotImplementedError.
    """

    @property
    def process_name(self) -> str:
        """Name under which this process is identified"""
        return "TFA_WaveDetection"

    # NOTE(review): the "def set_config(" header was missing above the
    # parameter list; restored to match the process_name / set_config / _call
    # layout of the other process classes in this module - confirm.
    def set_config(
        self,
    ): ...

    def _call(self, datatree):
        """Not implemented yet"""
        raise NotImplementedError

    def _attach_ibi(self): ...