Source code for swarmpal.toolboxes.tfa.processes

from __future__ import annotations

import numpy as np
from xarray import DataArray, Dataset, DataTree

from swarmpal.io import PalProcess
from swarmpal.toolboxes.tfa import tfalib
from swarmpal.utils.exceptions import PalError

# Default flag-cleaning parameters, keyed by variable name.
# Each entry gives the flag variable to test ("flag_name") and the largest flag
# value that is still accepted ("max_val"); samples whose flag exceeds it are
# set to NaN by Preprocess._flag_cleaning.
FLAG_THRESHOLDS = {
    variable: {"flag_name": flag_variable, "max_val": threshold}
    for variable, flag_variable, threshold in (
        ("B_NEC", "Flags_B", 30),
        ("F", "Flags_F", 63),
        ("E_NEC", "Flags_TII", 23),
        ("n", "Flags_LP", 63),
        ("Bubble_Probability", "Flags_F", 1000),
        ("FAC", "Flags_F", 1000),
        ("EEF", "Flags", 1000),
        ("Eh_XYZ", "Quality_flags", 0),
        ("Ev_XYZ", "Quality_flags", 0),
    )
}


class Preprocess(PalProcess):
    """Prepare data for input to other TFA tools"""

    @property
    def process_name(self) -> str:
        return "TFA_Preprocess"

    def set_config(
        self,
        dataset: str = "",
        timevar: str = "Timestamp",
        active_variable: str = "",
        active_component: int | None = None,
        sampling_rate: float = 1,
        remove_model: bool = False,
        model: str = "",
        convert_to_mfa: bool = False,
        use_magnitude: bool = False,
        clean_by_flags: bool = False,
        flagclean_varname: str = "",
        flagclean_flagname: str = "",
        flagclean_maxval: int | None = None,
        output_dataset: str = "PAL_TFA",
    ) -> None:
        """Set the process configuration

        Parameters
        ----------
        dataset : str
            Selects this dataset from the datatree
        timevar : str
            Identifies the name of the time variable, usually "Timestamp" or "Time"
        active_variable : str
            Selects the variable to use from within the dataset
        active_component : int, optional
            Selects the component to use (if active_variable is a vector)
        sampling_rate : float, optional
            Sampling rate of the data (in Hz); the data are placed onto a
            constant cadence at this rate, by default 1
        remove_model : bool, optional
            Remove a magnetic model prediction or not, by default False
        model : str, optional
            The name of the model
        convert_to_mfa : bool, optional
            Rotate B to mean-field aligned (MFA) coordinates, by default False
        use_magnitude : bool, optional
            Use the magnitude of a vector instead, by default False
        clean_by_flags : bool, optional
            Whether to apply additional flag cleaning or not, by default False
        flagclean_varname : str, optional
            Name of the variable to clean, by default the ``active_variable``
        flagclean_flagname : str, optional
            Name of the flag variable to clean by; by default taken from
            ``FLAG_THRESHOLDS`` for the variable being cleaned
        flagclean_maxval : int, optional
            Maximum allowable flag value (larger values are set to NaN); by
            default taken from ``FLAG_THRESHOLDS``. An explicit 0 is honoured.
        output_dataset : str
            Sets the name of the dataset in the data tree that TFA processes
            will write results to, by default "PAL_TFA"

        Notes
        -----
        Some special ``active_variable`` names exist which are added to the
        dataset on-the-fly:

        * "B_NEC_res_Model" where a model prediction must be available in the
          data, like ``"B_NEC_<Model>"``, and ``remove_model`` has been set.
          The name of the model can be set with, for example,
          ``model="CHAOS"``.
        * "B_MFA" when ``convert_to_mfa`` has been set.
        * "Eh_XYZ" and "Ev_XYZ" when using the TCT datasets, with vectors
          defined in ``("Ehx", "Ehy", "Ehz")`` and ``("Evx", "Evy", "Evz")``
          respectively.
        """
        super().set_config(
            dataset=dataset,
            timevar=timevar,
            active_variable=active_variable,
            active_component=active_component,
            sampling_rate=sampling_rate,
            remove_model=remove_model,
            model=model,
            convert_to_mfa=convert_to_mfa,
            use_magnitude=use_magnitude,
            clean_by_flags=clean_by_flags,
            flagclean_varname=flagclean_varname,
            flagclean_flagname=flagclean_flagname,
            flagclean_maxval=flagclean_maxval,
            output_dataset=output_dataset,
        )

    @property
    def active_variable(self):
        return self.config.get("active_variable", "")

    @property
    def active_component(self):
        # Default to None (not "") so it agrees with set_config and with the
        # ``is not None`` test used in _call
        return self.config.get("active_component", None)

    def _call(self, datatree: DataTree) -> DataTree:
        """Build "TFA_Variable" and store it under ``output_dataset``"""
        self._validate_inputs(datatree)
        # Select the datatree to work on
        self.subtree = datatree[self.config.get("dataset")]
        # Prepare data depending on the content
        if self.active_variable in ("Eh_XYZ", "Ev_XYZ"):
            ds = self._prep_efi_expt_data(self.subtree.ds)
        else:
            ds = self._prep_magnetic_data(self.subtree.ds)
        # Optionally clean according to flag values
        if self.config.get("clean_by_flags", False):
            ds = self._flag_cleaning(ds)
        # Identify and assign "TFA_Variable" (to be used in next processes)
        if self.active_component is not None:
            da = ds[self.active_variable][:, self.active_component].copy(deep=True)
        elif self.config["use_magnitude"]:
            # skipna=False: a flagged (NaN) sample must stay NaN so that Clean
            # can interpolate over it; xarray's default would turn it into 0
            da = (
                (ds[self.active_variable] ** 2)
                .sum(axis=1, skipna=False)
                .pipe(np.sqrt)
            )
        else:
            da = ds[self.active_variable].copy(deep=True)
        # Rename (Timestamp/Time) to TFA_Time to avoid collision
        da = da.rename({self.config["timevar"]: "TFA_Time"})
        da = self._constant_cadence(da)
        ds_out = Dataset(data_vars={"TFA_Variable": da, "TFA_Time": da["TFA_Time"]})
        # Append the specially created variables (e.g. B_NEC_res_Model, B_MFA,
        # Eh_XYZ, Ev_XYZ), created in "ds" by "_prep_magnetic_data" and
        # "_prep_efi_expt_data", so that they are preserved in the output
        special_vars = {"B_NEC_res_Model", "B_MFA", "Eh_XYZ", "Ev_XYZ"}.intersection(
            set(ds.data_vars)
        )
        for special_var in special_vars:
            ds_out = ds_out.assign({special_var: ds[special_var]})
        # Remove attrs, because .to_netcdf() is failing when blank units are set here
        ds_out["TFA_Time"].attrs = {}
        # Assign dataset back into the datatree to return
        datatree[self.output_dataset] = ds_out
        return datatree

    def _validate_inputs(self, datatree):
        """Some checks that the inputs and config are valid

        Raises
        ------
        PalError
            If required config is missing or inconsistent with the shape of
            the selected variable.
        """
        dataset = self.config.get("dataset")
        active_variable = self.config.get("active_variable")
        active_component = self.config.get("active_component")
        use_magnitude = self.config.get("use_magnitude")
        timevar = self.config.get("timevar")
        if not all((dataset, active_variable)):
            raise PalError("TFA Preprocess: dataset and/or active_variable not set")
        if timevar not in datatree[dataset].coords:
            raise PalError(f"TFA Preprocess: {timevar=} not available in dataset")
        # Catch the cases with special names that aren't initially available
        # in the dataset (they are set later)
        if any(x in active_variable for x in ("B_NEC_res_", "MFA", "Eh_XYZ", "Ev_XYZ")):
            target_shape = (len(datatree[dataset][timevar]), 3)
        else:
            target_shape = datatree[dataset][active_variable].shape
        is_vector = len(target_shape) > 1
        # Check if active_component / use_magnitude are set appropriately,
        # according to the shape of the active_variable
        if is_vector and (active_component is None) and (not use_magnitude):
            raise PalError("TFA Preprocess: active_component not set")
        if (not is_vector) and (active_component is not None):
            raise PalError("TFA Preprocess: active_component set, but no vector found")
        if (not is_vector) and use_magnitude:
            # The magnitude is a sum over axis 1, which a scalar series lacks
            raise PalError("TFA Preprocess: use_magnitude set, but no vector found")

    def _prep_magnetic_data(self, ds: Dataset) -> Dataset:
        """Subtract model and/or rotate to MFA

        Returns ``ds`` with "B_NEC_res_Model" (if ``remove_model``) and/or
        "B_MFA" (if ``convert_to_mfa``) assigned.
        """
        remove_model = self.config.get("remove_model", False)
        convert_to_mfa = self.config.get("convert_to_mfa", False)
        timevar = self.config.get("timevar")
        # Identify model name from config or from PAL meta
        model = self.config.get("model", "")
        try:
            model = model or self.subtree.swarmpal.magnetic_model_name
        except PalError:
            model = ""
        # Optionally assign residuals to dataset
        if remove_model:
            ds = ds.assign(
                {"B_NEC_res_Model": self.subtree.swarmpal.magnetic_residual(model)},
            )
        # Optionally rotate to MFA (Mean-field aligned coordinates), using the
        # model prediction as the mean field
        if convert_to_mfa:
            model_var = f"B_NEC_{model}"
            if model_var not in ds:
                raise PalError(
                    f"TFA Preprocess: {model_var} not available for MFA rotation"
                )
            B_source = ds["B_NEC_res_Model"] if remove_model else ds["B_NEC"]
            B_MFA = tfalib.mfa(B_source.data, ds[model_var].data)
            ds = ds.assign_coords({"MFA": [0, 1, 2]})
            ds = ds.assign({"B_MFA": ((timevar, "MFA"), B_MFA)})
            ds["B_MFA"].attrs = {
                "units": "nT",
                "description": "Magnetic field in Mean-field aligned coordinates",
            }
        return ds

    def _prep_efi_expt_data(self, ds: Dataset) -> Dataset:
        """Assign the Eh_XYZ or Ev_XYZ vector data variable

        Stacks the three scalar components into a (time, XYZ) variable named
        after ``active_variable``.
        """
        # Validate input data
        timevar = self.config.get("timevar")
        available_vars = set(ds.data_vars)
        vectors = {
            "Eh_XYZ": ("Ehx", "Ehy", "Ehz"),
            "Ev_XYZ": ("Evx", "Evy", "Evz"),
        }
        vectors = vectors[self.active_variable]
        required_vars = {*vectors, "Quality_flags"}
        if not required_vars.issubset(available_vars):
            raise PalError(f"Not all available: {required_vars}")
        # Create and assign the vector parameter
        E_XYZ = np.vstack([ds[i] for i in vectors]).T
        ds = ds.assign_coords({"XYZ": ["X", "Y", "Z"]})
        ds = ds.assign({self.active_variable: ((timevar, "XYZ"), E_XYZ)})
        return ds

    def _flag_cleaning(self, ds: Dataset) -> Dataset:
        """Set values to NaN where flags exceed a threshold

        The variable to clean, the flag variable and the threshold come from
        the config. Any of them left unset fall back to ``active_variable`` and
        to the ``FLAG_THRESHOLDS`` entry (ignoring a "_res_Model" suffix).

        Returns a new Dataset; ``ds`` itself (whose arrays may be shared with
        the input DataTree) is not written to.

        Raises
        ------
        PalError
            If defaults are needed but none exist for the variable, or if the
            flag variable is not in the dataset.
        """
        varname = self.config.get("flagclean_varname") or self.active_variable
        flagname = self.config.get("flagclean_flagname")
        max_val = self.config.get("flagclean_maxval")
        # Only consult the defaults for what was not configured; compare
        # max_val against None so that an explicit threshold of 0 is respected
        if not flagname or max_val is None:
            defaults = FLAG_THRESHOLDS.get(varname.replace("_res_Model", ""))
            if defaults is None:
                raise PalError(
                    f"TFA Preprocess: no default flag threshold for {varname}; "
                    "set flagclean_flagname and flagclean_maxval"
                )
            flagname = flagname or defaults["flag_name"]
            if max_val is None:
                max_val = defaults["max_val"]
        if flagname not in ds:
            raise PalError(f"TFA Preprocess: flag variable {flagname} not available")
        flagged = ds[flagname] > max_val
        # .where broadcasts the per-time mask over any vector dimension and
        # returns new data instead of writing NaN into shared arrays
        return ds.assign({varname: ds[varname].where(~flagged)})

    def _constant_cadence(self, da):
        """Convert array to that of constant cadence

        Resamples ``da`` (with a "TFA_Time" dimension) via
        ``tfalib.constant_cadence`` at the configured ``sampling_rate``.
        """
        # Convert time to seconds for tfalib.constant_cadence
        t_old = da["TFA_Time"].data
        t_old_sec = (t_old - t_old[0]) / np.timedelta64(1, "s")
        new_t_sec, new_X = tfalib.constant_cadence(
            t_old_sec, da.data, self.config["sampling_rate"], interp=False
        )[0:2]
        # Round to the nearest nanosecond before converting: a plain astype
        # truncates, so float error just below an integer would shift a
        # timestamp 1 ns early
        offsets_ns = np.rint(np.asarray(new_t_sec) * 1e9).astype("int64")
        new_t = t_old[0] + offsets_ns.astype("timedelta64[ns]")
        # Assign into new array to return
        da_new = DataArray(
            data=new_X,
            dims=("TFA_Time",),
        )
        da_new = da_new.assign_coords({"TFA_Time": new_t})
        da_new.attrs = {
            "units": da.attrs.get("units", ""),
            "description": da.attrs.get("description", ""),
        }
        da_new["TFA_Time"].attrs = {
            "units": da["TFA_Time"].attrs.get("units", ""),
            "description": da["TFA_Time"].attrs.get("description", ""),
        }
        return da_new
[docs] def _get_tfa_active_subtree(datatree, output_dataset): """Returns the relevant subtree when Preprocess has been applied""" # Scan the tree based on previous preprocess application pal_processes_meta = datatree.swarmpal.pal_meta.get(output_dataset, {}) tfa_preprocess_meta = pal_processes_meta.get("TFA_Preprocess") if not tfa_preprocess_meta: raise PalError("Must first run tfa.processes.Preprocess") return datatree[tfa_preprocess_meta.get("output_dataset")]
[docs] def _get_sampling_rate(datatree, output_dataset): """Get the sampling rate set by Preprocess""" pal_processes_meta = datatree.swarmpal.pal_meta.get(output_dataset, {}) tfa_preprocess_meta = pal_processes_meta.get("TFA_Preprocess") return tfa_preprocess_meta["sampling_rate"]
class Clean(PalProcess):
    """Clean TFA_Variable by removing outliers and interpolating over gaps"""

    @property
    def process_name(self) -> str:
        return "TFA_Clean"

    def set_config(
        self,
        window_size: int = 10,
        method: str = "iqr",
        multiplier: float = 0.5,
        output_dataset: str = "PAL_TFA",
    ) -> None:
        """Set the process configuration

        Parameters
        ----------
        window_size : int, optional
            The size (number of points) of the rolling window, by default 10
        method : str, optional
            "normal" or "iqr", by default "iqr"
        multiplier : float, optional
            Indicates the spread of the zone of accepted values, by default 0.5
        output_dataset : str
            Sets the name of the dataset in the data tree that TFA processes
            will write results to, by default "PAL_TFA"
        """
        super().set_config(
            window_size=window_size,
            method=method,
            multiplier=multiplier,
            output_dataset=output_dataset,
        )

    def _call(self, datatree) -> DataTree:
        """Clean "TFA_Variable" in the Preprocess output node"""
        # Identify the DataArray to modify
        subtree = _get_tfa_active_subtree(datatree, self.output_dataset)
        target_var = subtree["TFA_Variable"]
        # Apply cleaning routine inplace
        self._clean_variable(target_var)
        return datatree

    def _clean_variable(self, target_var) -> DataArray:
        """Set outliers to NaN, then linearly interpolate over all NaN gaps

        Works on 1-D data or, column by column, on 2-D (time, component) data.
        The data of ``target_var`` is updated and ``target_var`` returned.
        """
        # Remove outliers
        inds = tfalib.outliers(
            target_var.data,
            self.config.get("window_size"),
            method=self.config.get("method"),
            multiplier=self.config.get("multiplier"),
        )
        target_var.data[inds] = np.nan
        # Interpolate over gaps
        data = target_var.data
        if data.ndim == 1:
            target_var.data = self._interpolate_gaps(data)
        else:
            for i in range(data.shape[1]):
                target_var.data[:, i] = self._interpolate_gaps(data[:, i])
        return target_var

    @staticmethod
    def _interpolate_gaps(x):
        """Linearly interpolate NaN entries of 1-D ``x`` over sample index"""
        valid = ~np.isnan(x)
        if not valid.any():
            # Empty or entirely NaN: np.interp would raise on an empty set of
            # sample points, and there is nothing to interpolate from
            return x
        t_ind = np.arange(x.size)
        return np.interp(t_ind, t_ind[valid], x[valid])
class Filter(PalProcess):
    """High-pass filter the TFA_Variable, using the SciPy Chebyshev Type II filter"""

    @property
    def process_name(self) -> str:
        return "TFA_Filter"

    def set_config(
        self,
        cutoff_frequency: float = 20 / 1000,
        output_dataset: str = "PAL_TFA",
    ) -> None:
        """Set the process configuration

        Parameters
        ----------
        cutoff_frequency : float, optional
            The cutoff frequency (in Hz), by default 20/1000
        output_dataset : str
            Sets the name of the dataset in the data tree that TFA processes
            will write results to, by default "PAL_TFA"
        """
        super().set_config(
            cutoff_frequency=cutoff_frequency,
            output_dataset=output_dataset,
        )

    def _call(self, datatree) -> DataTree:
        """Filter "TFA_Variable" in the Preprocess output node"""
        # Identify the DataArray to modify
        subtree = _get_tfa_active_subtree(datatree, self.output_dataset)
        target_var = subtree["TFA_Variable"]
        # _filter replaces the DataArray's data, which the tree node sees
        self._filter(target_var, _get_sampling_rate(datatree, self.output_dataset))
        return datatree

    def _filter(self, target_var, sampling_rate) -> DataArray:
        """Replace ``target_var.data`` with its filtered version

        ``sampling_rate`` (Hz) and the configured ``cutoff_frequency`` are
        passed to ``tfalib.filter``; ``target_var`` is returned.
        """
        target_var.data = tfalib.filter(
            target_var.data,
            sampling_rate,
            self.config.get("cutoff_frequency"),
        )
        return target_var
class Wavelet(PalProcess):
    """Apply wavelet analysis"""

    @property
    def process_name(self) -> str:
        return "TFA_Wavelet"

    def set_config(
        self,
        min_frequency: float | None = None,
        max_frequency: float | None = None,
        min_scale: float | None = None,
        max_scale: float | None = None,
        dj: float = 0.1,
        output_dataset: str = "PAL_TFA",
    ) -> None:
        """Set the process configuration

        Either give the scales directly, or give frequencies from which the
        missing scales are derived (scale = 1 / frequency).

        Parameters
        ----------
        min_frequency : float | None, optional
            Lowest frequency of interest (Hz, consistent with the sampling
            rate). Used as ``max_scale = 1 / min_frequency`` when
            ``max_scale`` is not given, by default None
        max_frequency : float | None, optional
            Highest frequency of interest (Hz). Used as
            ``min_scale = 1 / max_frequency`` when ``min_scale`` is not given,
            by default None
        min_scale : float | None, optional
            Smallest wavelet scale (seconds); takes precedence over
            ``max_frequency``, by default None
        max_scale : float | None, optional
            Largest wavelet scale (seconds); takes precedence over
            ``min_frequency``, by default None
        dj : float, optional
            Spacing between consecutive scales, passed to
            ``tfalib.wavelet_scales`` and ``tfalib.wavelet_transform``,
            by default 0.1
        output_dataset : str
            Sets the name of the dataset in the data tree that TFA processes
            will write results to, by default "PAL_TFA"
        """
        super().set_config(
            min_frequency=min_frequency,
            max_frequency=max_frequency,
            min_scale=min_scale,
            max_scale=max_scale,
            dj=dj,
            output_dataset=output_dataset,
        )

    def _call(self, datatree: DataTree) -> DataTree:
        """Add "wavelet_power" (dims scale, TFA_Time) to the Preprocess node"""
        self._configure(datatree)
        # Identify the DataArray to use
        subtree = _get_tfa_active_subtree(datatree, self.output_dataset)
        ds = subtree.to_dataset()
        target_var = ds["TFA_Variable"]
        # Apply wavelet routine
        norm, scale = self._wavelets(target_var)
        # Assign new array to dataset
        ds = ds.assign_coords({"scale": scale})
        ds["wavelet_power"] = DataArray(
            data=norm,
            dims=("scale", "TFA_Time"),
        )
        # Return datatree containing the updated dataset
        datatree[subtree.name] = DataTree(dataset=ds)
        return datatree

    def _configure(self, datatree):
        """Resolve min/max scales and the sampling rate into ``self.config``

        Raises
        ------
        PalError
            If a scale is unset and the frequency needed to derive it is too.
        """
        # Resolve each scale independently, so a user-given scale is never
        # overwritten and a half-specified configuration still works
        if self.config.get("min_scale") is None:
            max_frequency = self.config.get("max_frequency")
            if max_frequency is None:
                raise PalError("TFA Wavelet: set either min_scale or max_frequency")
            self.config["min_scale"] = 1 / max_frequency
        if self.config.get("max_scale") is None:
            min_frequency = self.config.get("min_frequency")
            if min_frequency is None:
                raise PalError("TFA Wavelet: set either max_scale or min_frequency")
            self.config["max_scale"] = 1 / min_frequency
        # Always read the rate from the tree being processed: a value cached
        # from an earlier call would be stale for a different tree
        self.config["sampling_rate"] = _get_sampling_rate(datatree, self.output_dataset)

    def _wavelets(self, target_var: DataArray):
        """Return (normalised wavelet power, scales) for ``target_var``

        The power is built from ``abs(wave) ** 2`` of
        ``tfalib.wavelet_transform``; presumably shaped (n_scales, n_times),
        matching the dims assigned in _call — confirm against tfalib.
        """
        dx = 1 / self.config.get("sampling_rate")
        scale = tfalib.wavelet_scales(
            self.config.get("min_scale"),
            self.config.get("max_scale"),
            self.config.get("dj"),
        )
        wave = tfalib.wavelet_transform(
            target_var.data,
            dx=dx,
            minScale=self.config.get("min_scale"),
            maxScale=self.config.get("max_scale"),
            dj=self.config.get("dj"),
        )[0]
        norm = tfalib.wavelet_normalize(
            np.abs(wave) ** 2,
            scale,
            dx=dx,
            dj=self.config.get("dj"),
            # NOTE(review): constant presumably tied to the mother wavelet
            # used by tfalib — confirm its origin
            wavelet_norm_factor=0.74044116,
        )
        return norm, scale
class WaveDetection(PalProcess):
    """Screen out potential false waves

    Removes part of the wavelet spectrum that might be due to spikes, data
    gaps, ESFs or trailing parts of wave activity from either above or below
    the range of frequencies that were used to perform the wavelet transform.
    """

    @property
    def process_name(self) -> str:
        return "TFA_WaveDetection"

    def set_config(self):
        """Placeholder: accepts no configuration and does nothing yet."""

    def _call(self, datatree):
        # Not implemented yet: applying this process always fails
        raise NotImplementedError

    def _attach_ibi(self):
        """Placeholder with no behaviour yet."""