from __future__ import annotations
import logging
import matplotlib.pyplot as plt
from numpy import stack
from xarray import Dataset, DataTree, register_datatree_accessor
from swarmpal.io import PalProcess
from swarmpal.toolboxes.fac.fac_algorithms import fac_single_sat_algo
# Module-level logger, named after this module so records can be filtered per-module
logger = logging.getLogger(__name__)

# Public API exported by this module
__all__ = ("FAC_single_sat", "PalFacDataTreeAccessor")
[docs]
class FAC_single_sat(PalProcess):
    """Provides the process for the classic single-satellite FAC algorithm.

    The process reads measurements, model predictions and positions from one
    dataset of the input DataTree, runs ``fac_single_sat_algo`` on them, and
    writes the resulting FAC and IRC series to a new dataset in the same tree.
    """

    @property
    def process_name(self):
        """Name identifying this process."""
        return "FAC_single_sat"

    def set_config(
        self,
        dataset: str = "SW_OPER_MAGA_LR_1B",
        model_varname: str = "B_NEC_CHAOS",
        measurement_varname: str = "B_NEC",
        inclination_limit: float = 30,
        time_jump_limit: int = 1,
        include_auxiliaries: bool = True,
        output_dataset: str = "PAL_FAC_single_sat",
    ) -> None:
        """Configures the process

        Parameters
        ----------
        dataset : str, optional
            Dataset to use, by default "SW_OPER_MAGA_LR_1B"
        model_varname : str, optional
            Name of the magnetic model predictions, by default "B_NEC_CHAOS"
        measurement_varname : str, optional
            Name of the measurements, by default "B_NEC"
        inclination_limit : float, optional
            Limit of inclination for FAC validity (in degrees), by default 30
        time_jump_limit : int, optional
            Maximum allowable time step in data for FAC validity (in seconds), by default 1
        include_auxiliaries : bool, optional
            Whether to copy auxiliary variables (Latitude, Longitude, Radius,
            Flags_F, Flags_B, Flags_q) from the input into the output, by default True
        output_dataset : str, optional
            Name of the dataset in the data tree that this process writes its
            results to, by default "PAL_FAC_single_sat"
        """
        super().set_config(
            dataset=dataset,
            model_varname=model_varname,
            measurement_varname=measurement_varname,
            inclination_limit=inclination_limit,
            time_jump_limit=time_jump_limit,
            include_auxiliaries=include_auxiliaries,
            output_dataset=output_dataset,
        )

    def _call(self, datatree):
        """Run the FAC algorithm on the configured dataset and store the results.

        Parameters
        ----------
        datatree : DataTree
            Tree containing the input dataset named by the ``dataset`` config.

        Returns
        -------
        DataTree
            The same tree object, with a new node at ``self.output_dataset``
            holding ``FAC`` and ``IRC`` (units "uA/m2") indexed by ``Timestamp``,
            plus auxiliaries when ``include_auxiliaries`` is set.
        """
        # Identify inputs for algorithm
        dataset_in = datatree[self.config.get("dataset")].ds
        # Apply algorithm
        fac_results = fac_single_sat_algo(
            time=self._get_time(dataset_in),
            positions=self._get_positions(dataset_in),
            B_res=self._get_B_res(dataset_in),
            B_model=self._get_B_model(dataset_in),
            inclination_limit=self.config.get("inclination_limit"),
            time_jump_limit=self.config.get("time_jump_limit"),
        )
        # Insert a new output dataset with these results
        ds_out = Dataset(
            data_vars={
                "Timestamp": ("Timestamp", fac_results["time"]),
                "FAC": ("Timestamp", fac_results["fac"]),
                "IRC": ("Timestamp", fac_results["irc"]),
            }
        )
        ds_out["FAC"].attrs = {"units": "uA/m2"}
        ds_out["IRC"].attrs = {"units": "uA/m2"}
        if self.config.get("include_auxiliaries"):
            ds_out = self._append_aux(dataset_in, ds_out)
        datatree[self.output_dataset] = DataTree(dataset=ds_out)
        return datatree

    def _validate(self): ...

    def _get_time(self, dataset):
        """Return the ``Timestamp`` values cast to ``datetime64[ns]``."""
        return dataset["Timestamp"].data.astype("datetime64[ns]")

    def _get_positions(self, dataset):
        """Stack Latitude, Longitude and Radius as columns (``axis=1``).

        Assumes each of the three variables is one-dimensional along time,
        giving an (n, 3) array — TODO confirm against the input products.
        """
        return stack(
            [
                dataset["Latitude"].data,
                dataset["Longitude"].data,
                dataset["Radius"].data,
            ],
            axis=1,
        )

    def _get_B_res(self, dataset):
        """Return the residual field: measurement minus model prediction."""
        measurement_varname = self.config.get("measurement_varname", "B_NEC")
        return dataset[measurement_varname].data - self._get_B_model(dataset)

    def _get_B_model(self, dataset):
        """Return the model prediction values named by ``model_varname``."""
        model_varname = self.config.get("model_varname", "B_NEC_Model")
        return dataset[model_varname].data

    def _append_aux(self, ds_in, ds_out):
        """Extract auxiliary information from inputs and add to output dataset.

        Each available auxiliary variable is mapped onto the output timestamps
        by nearest-neighbour interpolation, cast back to its source dtype, and
        assigned to ``ds_out``. Missing auxiliaries are logged, not raised.
        The input's ``"Sources"`` attribute is copied over when present.

        Returns
        -------
        Dataset
            A new dataset: ``ds_out`` plus the matched auxiliary variables.
        """
        # Ordered (not a set) so the output variable order is deterministic
        aux_desired = (
            "Latitude",
            "Longitude",
            "Radius",
            "Flags_F",
            "Flags_B",
            "Flags_q",
        )
        available = set(ds_in.data_vars)
        aux_matched = [name for name in aux_desired if name in available]
        aux_missing = [name for name in aux_desired if name not in available]
        if aux_missing:
            logger.warning("Missing auxiliaries: %s", aux_missing)
        # Subset only the ones we want to append (a new Dataset, so the
        # dtype reassignment below never touches ds_in)
        ds_aux = ds_in[aux_matched]
        # FAC time series is shorter than the inputs, so need to interpolate
        if len(ds_in["Timestamp"]) > 0:
            ds_aux = ds_aux.interp_like(ds_out, method="nearest")
        # Convert data types back to the source data (interpolation changes it to float64)
        for name in aux_matched:
            ds_aux[name] = ds_aux[name].astype(ds_in[name].dtype)
        # Attach the subselected data variables
        ds_out = ds_out.assign({name: ds_aux[name] for name in aux_matched})
        if "Sources" in ds_in.attrs:
            ds_out.attrs["Sources"] = ds_in.attrs["Sources"]
        return ds_out
[docs]
@register_datatree_accessor("swarmpal_fac")
class PalFacDataTreeAccessor:
    """DataTree accessor (``datatree.swarmpal_fac``) for FAC_single_sat outputs."""

    def __init__(self, datatree) -> None:
        self._datatree = datatree

    def quicklook(self):
        """Plot IRC and FAC from every FAC_single_sat output in the tree.

        Output datasets are discovered via ``swarmpal.pal_meta``; only those
        whose process config contains ``"FAC_single_sat"`` are plotted.

        Returns
        -------
        fig, axes
            The figure and its two shared-x axes (IRC on top, FAC below).
        """
        fig, axes = plt.subplots(nrows=2, sharex=True)
        meta = self._datatree.swarmpal.pal_meta
        output_datasets = meta["."]["output_datasets"]
        # Inputs of the FAC processes actually plotted; tracked explicitly so the
        # title never depends on whichever config the loop happened to see last
        input_datasets = []
        for output_dataset in output_datasets:
            process_config = meta[output_dataset]
            if "FAC_single_sat" not in process_config:
                continue
            dataset_dir = f"./{output_dataset}"
            self._datatree[dataset_dir]["IRC"].plot.line(ax=axes[0])
            self._datatree[dataset_dir]["FAC"].plot.line(ax=axes[1])
            input_datasets.append(process_config["FAC_single_sat"]["dataset"])
        axes[0].set_xlabel("")
        # Explicit True: a bare grid() toggles visibility, so repeated calls
        # would switch it back off
        axes[0].grid(True)
        axes[1].grid(True)
        if input_datasets:
            fig.suptitle(f"Input: {', '.join(input_datasets)}")
        return fig, axes