"""TODO Docstring for TFA
"""
import datetime as dt
import logging
import re
import sys
from abc import ABC, abstractmethod
import matplotlib.pyplot as plt
import numpy as np
from swarmpal.io import ExternalData
from swarmpal.toolboxes.tfa import tfalib
# Configure the root logger when this module is imported, with the threshold
# set to WARN so that the warnings issued by the classes below are not filtered.
logging.basicConfig(level=logging.WARN)
def isotime2num(iso_str):
    """Convert an ISO 8601 duration string to its length in seconds.

    Years and months are counted with fixed nominal lengths of 365 and 30
    days respectively.

    Parameters
    ----------
    iso_str : str
        Duration such as ``"PT1S"``, ``"P1D"`` or ``"P1DT12H30M"``.

    Returns
    -------
    float
        The duration in seconds; 0.0 when no designator is recognised.
    """
    # Pairs of ISO designators and their values in seconds. "M" means months
    # in the date part and minutes in the time part, hence the two tables.
    date_vals = (("Y", 31536000), ("M", 2592000), ("W", 604800), ("D", 86400))
    time_vals = (("H", 3600), ("M", 60), ("S", 1))

    # Split at the first "T": what precedes it (after the leading "P") is the
    # date part, what follows is the time part. The "T" is optional, so a
    # date-only duration such as "P1D" is still recognised.
    before_t, _, s_time = iso_str.partition("T")
    s_date = before_t.partition("P")[2]

    duration = 0.0
    for part, table in ((s_date, date_vals), (s_time, time_vals)):
        for designator, seconds in table:
            match = re.search(r"([\d\.]+)" + designator, part)
            if match:
                duration += float(match.group(1)) * seconds
    return duration
class TfaProcessor:
    """
    Data to be used by the TFA tools and functions

    On construction the selected variable is read from ``input_data.xarray``,
    passed through ``tfalib.constant_cadence`` and stored back in the dataset
    as the variable ``"X"`` on a new ``"time"`` coordinate.

    Parameters
    ----------
    input_data : TfaInput object
        The object containing the data
    active_variable: dict
        Dictionary of the form {"varname": "F"} which specifies the name of
        the variable that will be used for further processing. If this
        variable is a vector, an additional key is required containing the
        component of this vector that will be processed, e.g.
        {"varname": "B_NEC", "component": 0}. Component number uses the
        Python convention of starting from zero, so 0 is the first component
        (N in the NEC case), 1 is the second (E) and 2 is the third (C).
    params : dict, optional
        Processing parameters. When omitted (or empty), a default "General"
        entry with "freq_lims", "lat_lims" and "maglat_lims" is used.
    """

    # Quoted so that the annotation is not evaluated when the class is created.
    input_data: "TfaInput"
    X_name: str
    X_label: str
    params: dict
    segment_index: list

    def __init__(self, input_data, active_variable, params=None):
        self.input_data = input_data
        self.segment_index = []
        self.Max_Segment = 0
        self.params = (
            params
            if params
            else {
                "General": {
                    "freq_lims": [0.020, 0.100],
                    "lat_lims": None,
                    "maglat_lims": [-75, 75],
                }
            }
        )

        # specify the variable that will be processed
        self.X_name = active_variable["varname"]
        self.X_label = self.X_name
        X = self.input_data.xarray[self.X_name].data

        # if it is a vector, isolate the specified component
        if "component" in active_variable:
            Xi = active_variable["component"]
            X = X[:, Xi]
            # the label is 1-based, the component index 0-based
            self.X_label += "_" + str(Xi + 1)

        # get Timestamps and convert to seconds elapsed since the first one
        t_old = self.input_data.xarray.get("Timestamp").data
        t_old_sec = (t_old - t_old[0]) / np.timedelta64(1, "s")

        # apply "constant_cadence"
        new_t_sec, new_X = tfalib.constant_cadence(
            t_old_sec, X, 1 / self.input_data.SAMPLING_TIME, interp=False
        )[0:2]
        # Convert through nanoseconds: multiplying the float seconds by a
        # one-second timedelta would truncate any sub-second part.
        new_t = t_old[0] + np.round(np.asarray(new_t_sec) * 1e9).astype(
            "timedelta64[ns]"
        )

        # insert new values to xarray Dataset
        self.input_data.xarray = self.input_data.xarray.assign_coords({"time": new_t})
        self.input_data.xarray = self.input_data.xarray.assign({"X": ("time", new_X)})

    @property
    def analysis_window(self):
        """list[datetime]: pair of unpadded times"""
        return self.input_data.TIME_LIMS

    def apply(self, process):
        """Apply a TFA Process (any object exposing ``apply(target)``)"""
        process.apply(self)

    def create_segment_index(self, lat_lims=None):
        """Create an array with the segment index for each time

        Consecutive samples whose interpolated "QDLat" lies strictly inside
        the magnetic latitude limits form one segment. The result is stored in
        ``self.segment_index`` (NaN outside any segment or outside the
        analysis window) and the largest index in ``self.Max_Segment``.

        Parameters
        ----------
        lat_lims : scalar or sequence, optional
            A 2-element sequence is used as [min, max]. A scalar or 1-element
            sequence ``v`` is used as [-|v|, +|v|]. Any other length logs a
            warning and leaves the current limits untouched. ``None`` keeps
            the limits already in ``self.params["General"]["maglat_lims"]``.
        """
        general = self.params["General"]
        if lat_lims is not None:
            is_sequence = hasattr(lat_lims, "__len__") and not isinstance(
                lat_lims, str
            )
            if is_sequence and len(lat_lims) == 2:
                general["maglat_lims"] = lat_lims
            elif is_sequence and len(lat_lims) != 1:
                logging.warning(
                    "create_segment_index: 'lat_lims' not a 2-element "
                    "vector - using default value"
                )
            else:
                # scalar or 1-element sequence: symmetric limits around zero
                limit = np.abs(np.ravel(lat_lims)[0])
                general["maglat_lims"] = [-limit, +limit]

        # interpolate mag_lat to new "time"; both axes are expressed in
        # seconds so the result does not depend on the datetime64 unit
        xr_data = self.input_data.xarray
        one_second = np.timedelta64(1, "s")
        timestamps = xr_data["Timestamp"].data
        times = xr_data["time"].data
        old_t = (timestamps - timestamps[0]) / one_second
        new_t = (times - times[0]) / one_second
        mlat = np.interp(new_t, old_t, xr_data["QDLat"])

        mlat_bool = (mlat > general["maglat_lims"][0]) & (
            mlat < general["maglat_lims"][1]
        )

        # a jump larger than one sample between consecutive in-range samples
        # marks the start of a new segment
        ind = np.arange(len(mlat))
        d = np.hstack(([0], np.diff(ind[mlat_bool])))
        c = np.cumsum(d > 1)
        si = np.full(mlat.shape, np.nan)
        si[mlat_bool] = c

        # remove segments that are in the padded time intervals
        tlims = np.array(self.analysis_window).astype(np.datetime64)
        si[(times < tlims[0]) | (times > tlims[-1])] = np.nan

        # renumber so that the first remaining segment is 0
        si = si - np.nanmin(si)
        self.segment_index = si
        self.Max_Segment = np.nanmax(si)

    def get_segment_inds_and_lims(self, segment):
        """Return the sample mask and time limits of a segment

        Parameters
        ----------
        segment : number or None
            Segment index; values above ``self.Max_Segment`` are clipped to
            it. ``None`` selects every sample.

        Returns
        -------
        tuple
            ``(inds, [t_min, t_max])`` where ``inds`` is a boolean array over
            the "time" coordinate and the limits are the smallest and largest
            selected times.
        """
        times = self.input_data.xarray["time"].data
        if segment is None:
            inds = np.full(times.shape, True)
        else:
            if segment > self.Max_Segment:
                segment = self.Max_Segment
            inds = np.full(times.shape, False)
            inds[self.segment_index == segment] = True
        t_min = np.nanmin(times[inds])
        t_max = np.nanmax(times[inds])
        return (inds, [t_min, t_max])

    def _segment_xlim(self, full, segment, t_min, t_max):
        """Return the x-limits a plot should use, or None to leave them alone"""
        if full:
            return None
        if segment is None:
            return self.input_data.TIME_LIMS
        return [t_min, t_max]

    def plotX(self, full=False, segment=None):
        """Plot the active variable time series"""
        (inds, [t_min, t_max]) = self.get_segment_inds_and_lims(segment)
        plt.plot(
            self.input_data.xarray["time"].data[inds],
            self.input_data.xarray["X"].data[inds],
        )
        plt.title(self.input_data.COLLECTION)
        # labels starting with "E" get mV/m, everything else nT
        if self.X_label[0] == "E":
            y_label = self.X_label + " (mV/m)"
        else:
            y_label = self.X_label + " (nT)"
        plt.ylabel(y_label)
        plt.grid(True)
        xlim = self._segment_xlim(full, segment, t_min, t_max)
        if xlim is not None:
            plt.xlim(xlim)

    def image(self, full=False, segment=None, cbar_lims=None, log=True):
        """
        Plot the dynamic spectrum of the result of the wavelet transform

        The wavelet process must have been successfully applied first.

        Parameters
        ----------
        full : bool
            If True the x-axis limits are left untouched.
        segment : number or None
            Segment to plot; None plots all samples.
        cbar_lims : pair or None
            Lower and upper contour/colorbar limits. When None they are taken
            from the log10 of the wavelet power, with the lower one not
            below -6.
        log : bool
            Plot log10 of the wavelet power instead of the power itself.
        """
        (inds, [t_min, t_max]) = self.get_segment_inds_and_lims(segment)
        freqs = 1000 / self.input_data.xarray["scale"].data
        if cbar_lims is None:
            # NOTE(review): these defaults are in log10 units even when
            # log=False - pass cbar_lims explicitly in that case.
            m = np.max([np.log10(np.min(self.input_data.xarray["wavelet_power"])), -6])
            x = np.log10(np.max(self.input_data.xarray["wavelet_power"]))
        else:
            m, x = cbar_lims
        cb_ticks = np.arange(np.ceil(m), np.floor(x))

        power = self.input_data.xarray["wavelet_power"][:, inds]
        if log:
            power = np.log10(power)
        plt.contourf(
            self.input_data.xarray["time"].data[inds],
            freqs,
            power,
            cmap="jet",
            levels=np.linspace(m, x, 20),
            extend="both",
        )
        plt.ylabel("Freq (mHz)")
        cbh = plt.colorbar(orientation="horizontal", shrink=1, aspect=50)
        cbh.set_ticks(cb_ticks)
        cbh.set_ticklabels([f"{tick:.1f}" for tick in cb_ticks], fontsize=8)
        xlim = self._segment_xlim(full, segment, t_min, t_max)
        if xlim is not None:
            plt.xlim(xlim)

    def plotAUX(self, full=False, segment=None):
        """Plot Mag.Lat. and MLT time series"""
        [t_min, t_max] = self.get_segment_inds_and_lims(segment)[1]
        plt.plot(
            self.input_data.xarray["Timestamp"].data,
            self.input_data.xarray["QDLat"].data,
            "-b",
            label="QD Lat",
        )
        plt.plot(0, 0, "-r", label="MLT")  # this is just for the legend
        ax = plt.gca()
        ax.set_ylabel("QDLat (deg)")
        ax.set_ylim([-90, 90])
        ax.set_yticks([-90, -45, 0, 45, 90])
        # MLT goes on a second y-axis sharing the same time axis
        ax2 = ax.twinx()
        ax2.plot(
            self.input_data.xarray["Timestamp"].data,
            self.input_data.xarray["MLT"].data,
            "-r",
            label="MLT",
        )
        ax2.set_ylabel("MLT (hr)")
        ax2.set_ylim([0, 24])
        ax2.set_yticks([0, 6, 12, 18, 24])
        ax.grid(True)
        ax.legend()
        xlim = self._segment_xlim(full, segment, t_min, t_max)
        if xlim is not None:
            ax.set_xlim(xlim)

    def plotI(self, full=False, segment=None):
        """
        Plot the wave index time series.

        The wavelet process must have been successfully applied first and then
        the wave_index() function has to be executed to produce the index.
        Optionally, the user can run the wave_detection() function, before the
        wave_index() to remove parts of the signal that have been identified
        as suspicious false positives (e.g. spikes) or that might be related
        to ESF signatures (Plasma Bubbles).
        """
        (inds, [t_min, t_max]) = self.get_segment_inds_and_lims(segment)
        plt.plot(
            self.input_data.xarray["time"].data[inds],
            self.input_data.xarray["wavelet_index"].data[inds],
        )
        plt.title(self.input_data.COLLECTION)
        plt.ylabel("Wave Index")
        plt.grid(True)
        xlim = self._segment_xlim(full, segment, t_min, t_max)
        if xlim is not None:
            plt.xlim(xlim)

    def wave_index(self):
        """
        Produce the index of wave activity for the frequencies that were used
        in the wavelet process.

        The wavelet process must have been successfully applied first. The
        index is the NaN-ignoring sum of "wavelet_power" along its first axis
        and is stored in the dataset as "wavelet_index". If "wavelet_power" is
        missing only a warning is logged.
        """
        if "wavelet_power" in self.input_data.xarray:
            wavindex = np.nansum(self.input_data.xarray["wavelet_power"].data, 0)
            self.input_data.xarray = self.input_data.xarray.assign(
                {"wavelet_index": ("time", wavindex)}
            )
        else:
            logging.warning(
                "wave_index(): No wavelet array 'wavelet_power' found! Must apply the Wavelet function first!"
            )

    def interp_nans(self):
        """Interpolate NaN values by a piecewise linear interpolation scheme"""
        i = np.arange(len(self.input_data.xarray["X"]))
        nonNaNinds = np.where(~np.isnan(self.input_data.xarray["X"].data))
        self.input_data.xarray["X"].data = np.interp(
            i, i[nonNaNinds], self.input_data.xarray["X"].data[nonNaNinds]
        )

    def wave_detection(self, threshold=0):
        """
        Remove parts of the wavelet spectrum that might not be true waves.

        The wavelet process must have been successfully applied first.
        This function removes (sets to NaN) parts of the wavelet spectrum that
        might be due to spikes, data gaps, ESFs or trailing parts of wave
        activity from either above or below the range of frequencies that were
        used to perform the wavelet transform.

        Parameters
        ----------
        threshold : float
            Wavelet power values below this are set to NaN.

        Returns
        -------
        numpy.ndarray or None
            For collections whose name starts with "SW": the IBI
            "Bubble_Probability" interpolated to the "time" coordinate.
            Otherwise (or when "wavelet_power" is missing) None.
        """
        if "wavelet_power" not in self.input_data.xarray:
            logging.warning(
                "wave_detection(): No wavelet array 'wavelet_power' found! Must apply the Wavelet function first!"
            )
            return None

        # the array is modified in place
        power = self.input_data.xarray["wavelet_power"].data

        # remove points below the threshold
        power[power < threshold] = np.nan

        # remove points that are outside the segments
        # (if segments have been defined)
        if len(self.segment_index) > 0:
            power[:, np.isnan(self.segment_index)] = np.nan

        # find peak frequency for each time and exclude events with peak
        # at the edges of the frequency range
        # NOTE(review): np.argmax is applied after NaNs were inserted above;
        # confirm that this is the intended peak for partly-NaN columns.
        maxInds = np.argmax(power, 0)
        power[:, maxInds == 0] = np.nan
        nFreqs = len(self.input_data.xarray["scale"])
        power[:, maxInds == nFreqs - 1] = np.nan

        # remove according to IBI (only for Swarm!)
        if self.input_data.COLLECTION[0:2] != "SW":
            return None

        sat = self.input_data.COLLECTION[11]  # get sat char

        # imported here so viresclient is only needed for Swarm collections
        from viresclient import SwarmRequest

        SERVER_URL = "https://vires.services/ows"
        request = SwarmRequest(SERVER_URL)
        request.set_collection(f"SW_OPER_IBI{sat}TMS_2F")
        request.set_products(
            measurements=["Bubble_Probability"], sampling_step="PT1S"
        )
        logging.warning("wave_detection: Retrieving IBI L2 product")
        data = request.get_between(
            self.analysis_window[0].strftime("%Y-%m-%dT%H:%M:%S"),
            self.analysis_window[1].strftime("%Y-%m-%dT%H:%M:%S"),
        )
        ibi = data.as_xarray()

        # interpolate to the times of the wavelet_power array
        bubble = np.interp(
            self.input_data.xarray["time"].data.astype(np.float64),
            ibi["Timestamp"].data.astype(np.float64),
            ibi["Bubble_Probability"].data.astype(np.float64),
        )
        power[:, bubble > 0.20] = np.nan
        return bubble
class TFA_Process(ABC):
    """Abstract base of the processing steps that can be applied to a target.

    Concrete subclasses provide ``apply`` and a class-level ``__name__`` that
    is used as the key under which their parameters are recorded.
    """

    params = None

    def __init__(self, params):
        self.params = params

    @abstractmethod
    def apply(self, target):
        """Run the process on ``target``; must be overridden."""
        return None

    def append_params(self, target):
        """Store this process' parameters in ``target.params`` under its name."""
        process_name = self.__name__
        target.params[process_name] = self.params
class Cleaning(TFA_Process):
    """Process that removes outliers from "X" and fills the gaps linearly.

    Parameters
    ----------
    params : dict, optional
        Must provide "Window_Size", "Method" and "Multiplier", which are
        forwarded to ``tfalib.outliers``. Defaults to
        ``{"Window_Size": 10, "Method": "iqr", "Multiplier": 0.5}``.
    """

    __name__ = "Cleaning"

    def __init__(self, params=None):
        if params is None:
            params = {"Window_Size": 10, "Method": "iqr", "Multiplier": 0.5}
        self.params = params

    @staticmethod
    def _fill_gaps(x):
        """Return a copy of the 1-D array ``x`` with NaNs linearly interpolated"""
        t_ind = np.arange(len(x))
        nonNaN = ~np.isnan(x)
        return np.interp(t_ind, t_ind[nonNaN], x[nonNaN])

    def apply(self, target):
        """Flag outliers in ``target``'s "X" as NaN, then interpolate all NaNs.

        Both the flagged samples and any pre-existing gaps are filled. The
        parameters are recorded in ``target.params["Cleaning"]``.

        Returns
        -------
        The same ``target`` object.
        """
        data = target.input_data.xarray["X"].data
        inds = tfalib.outliers(
            data,
            self.params["Window_Size"],
            method=self.params["Method"],
            multiplier=self.params["Multiplier"],
        )
        data[inds] = np.nan

        # interpolate cleaned values and pre-existing gaps
        if len(data.shape) == 1:
            target.input_data.xarray["X"].data = self._fill_gaps(data)
        else:
            N, D = data.shape
            for i in range(D):
                column = np.reshape(data[:, i], (N,))
                # write back into column i only; assigning the 1-D result to
                # the whole array would discard the other columns
                target.input_data.xarray["X"].data[:, i] = self._fill_gaps(column)

        self.append_params(target)
        return target
class Filtering(TFA_Process):
    """Process that passes "X" through ``tfalib.filter``.

    Parameters
    ----------
    params : dict, optional
        May provide "Cutoff_Frequency" or "Cutoff_Scale" (its reciprocal is
        then used as the cutoff frequency). "Sampling_Rate" is always taken
        from the target. With no params a cutoff frequency of 20 / 1000 is
        used.
    """

    __name__ = "Filtering"

    def __init__(self, params=None):
        self.params = params

    def apply(self, target):
        """Filter ``target``'s "X" in place and record the parameters used.

        Returns
        -------
        The same ``target`` object.
        """
        sampling_rate = 1 / target.input_data.SAMPLING_TIME
        if self.params is not None:
            self.params["Sampling_Rate"] = sampling_rate
            # a cutoff given as a scale overrides any cutoff frequency
            if "Cutoff_Scale" in self.params:
                self.params["Cutoff_Frequency"] = 1 / self.params["Cutoff_Scale"]
        else:
            self.params = {
                "Sampling_Rate": sampling_rate,
                "Cutoff_Frequency": 20 / 1000,
            }

        filtered = tfalib.filter(
            target.input_data.xarray["X"].data,
            self.params["Sampling_Rate"],
            self.params["Cutoff_Frequency"],
        )
        target.input_data.xarray["X"].data = filtered

        self.append_params(target)
        return target
class Wavelet(TFA_Process):
    """Process that computes the normalised wavelet power of "X".

    Parameters
    ----------
    params : dict, optional
        Limits given either as "Min_Frequency"/"Max_Frequency" or as
        "Min_Scale"/"Max_Scale" (never mixed), optionally with "dj". If the
        limits fail validation a warning is logged and the defaults are used
        when the process is applied.

    Raises
    ------
    KeyError
        If no frequency limit is given and "Min_Scale" or "Max_Scale" is
        missing from ``params``.
    """

    __name__ = "Wavelet"

    def __init__(self, params=None):
        # stays None (-> defaults in apply) unless the limits are valid
        self.params = None
        if params is None:
            return

        has_min_freq = "Min_Frequency" in params
        has_max_freq = "Max_Frequency" in params
        if has_min_freq and has_max_freq:
            if params["Min_Frequency"] < params["Max_Frequency"]:
                self.params = params
            else:
                logging.warning("Min_Frequency must be smaller than Max_Frequency")
        elif has_min_freq or has_max_freq:
            logging.warning(
                "The limits must both be in either frequency or scale, "
                "no combinations allowed."
            )
        elif params["Min_Scale"] < params["Max_Scale"]:
            self.params = params
        else:
            logging.warning("Min_Scale must be smaller than Max_Scale")

    def apply(self, target):
        """Compute the wavelet power of ``target``'s "X" and store it.

        The result is written to the dataset as "wavelet_power" with
        dimensions ("scale", "time"), replacing any previous "wavelet_power"
        and "scale", and the parameters are recorded in
        ``target.params["Wavelet"]``.

        Returns
        -------
        The same ``target`` object.

        Raises
        ------
        KeyError
            If "Max_Frequency" exceeds 1/(2*Time_Step) (a warning is logged
            first) and no scale limits are available in the parameters.
        """
        time_step = target.input_data.SAMPLING_TIME
        if self.params is None:
            self.params = {
                "Time_Step": time_step,
                "Min_Scale": 1000 / 100,
                "Max_Scale": 1000 / 1,
                "dj": 0.1,
            }
        else:
            self.params["Time_Step"] = time_step
            # same scale resolution as the defaults when the caller gave none
            self.params.setdefault("dj", 0.1)
            if "Min_Frequency" in self.params and "Max_Frequency" in self.params:
                if self.params["Max_Frequency"] <= 1 / (2 * time_step):
                    # scales are the reciprocals of the frequencies
                    self.params["Min_Scale"] = 1 / self.params["Max_Frequency"]
                    self.params["Max_Scale"] = 1 / self.params["Min_Frequency"]
                else:
                    logging.warning(
                        "Max_Frequency needs to be smaller than 1/(2*Time_Step)"
                    )
            elif self.params["Min_Scale"] < 2 * time_step:
                logging.warning("Min_Scale needs to be bigger or equal to 2*Time_Step")

        self.params["Wavelet_Function"] = "Morlet"
        self.params["Wavelet_Param"] = 6.2036
        self.params["Wavelet_Norm_Factor"] = 0.74044116

        s = tfalib.wavelet_scales(
            self.params["Min_Scale"], self.params["Max_Scale"], self.params["dj"]
        )
        wave = tfalib.wavelet_transform(
            target.input_data.xarray["X"].data,
            dx=self.params["Time_Step"],
            minScale=self.params["Min_Scale"],
            maxScale=self.params["Max_Scale"],
            dj=self.params["dj"],
        )[0]
        norm = tfalib.wavelet_normalize(
            np.abs(wave) ** 2,
            s,
            dx=self.params["Time_Step"],
            dj=self.params["dj"],
            wavelet_norm_factor=self.params["Wavelet_Norm_Factor"],
        )

        # delete old ones, if they exist, first! This is necessary for multiple
        # applications of the Wavelet() process, otherwise it conflicts with
        # the variables that are already in the xarray
        for stale in ("wavelet_power", "scale"):
            if stale in target.input_data.xarray:
                target.input_data.xarray = target.input_data.xarray.drop_vars(stale)

        # insert new values to xarray Dataset
        target.input_data.xarray = target.input_data.xarray.assign_coords({"scale": s})
        target.input_data.xarray = target.input_data.xarray.assign(
            {"wavelet_power": (("scale", "time"), norm)}
        )

        self.append_params(target)
        return target