Source code for swarmpal._api

from __future__ import annotations

from datetime import timedelta

from xarray import DataTree

from swarmpal.io._paldata import PalDataItem, create_paldata
from swarmpal.schema import validate as _validate

processes_by_name = {}


[docs] def make_process(process_name=None, config={}): """Instantiates a PalProcess object with a given name and a configuration. Parameters ---------- process_name: Str The name of the process to apply. Must be one of ['FAC_single_sat']. config: dict Parameters passed to the Toolbox' process. """ if process_name not in processes_by_name: if process_name == "FAC_single_sat": from swarmpal.toolboxes.fac.processes import FAC_single_sat # noqua(I001) processes_by_name["FAC_single_sat"] = FAC_single_sat elif process_name == "DSECS_Preprocess": from swarmpal.toolboxes.dsecs.processes import Preprocess processes_by_name["DSECS_Preprocess"] = Preprocess elif process_name == "DSECS_Analysis": from swarmpal.toolboxes.dsecs.processes import Analysis processes_by_name["DSECS_Analysis"] = Analysis elif process_name == "TFA_Preprocess": from swarmpal.toolboxes.tfa.processes import Preprocess as TFA_Preprocess processes_by_name["TFA_Preprocess"] = TFA_Preprocess elif process_name == "TFA_Clean": from swarmpal.toolboxes.tfa.processes import Clean as TFA_Clean processes_by_name["TFA_Clean"] = TFA_Clean elif process_name == "TFA_Filter": from swarmpal.toolboxes.tfa.processes import Filter as TFA_Filter processes_by_name["TFA_Filter"] = TFA_Filter elif process_name == "TFA_Wavelet": from swarmpal.toolboxes.tfa.processes import Wavelet as TFA_Wavelet processes_by_name["TFA_Wavelet"] = TFA_Wavelet elif process_name == "TFA_WaveDetection": from swarmpal.toolboxes.tfa.processes import ( WaveDetection as TFA_WaveDetection, ) processes_by_name["TFA_WaveDetection"] = TFA_WaveDetection elif process_name == "EXP_LocalForwardMagneticModel": from swarmpal.experimental import LocalForwardMagneticModel processes_by_name["EXP_LocalForwardMagneticModel"] = ( LocalForwardMagneticModel ) else: raise ValueError( f"Unknown process {process_name}. Must be one of ['FAC_single_sat', 'DSECS_Preprocess', 'DSECS_Analysis']" ) return processes_by_name[process_name](config=config)
[docs] def apply_process(data, process_name=None, config={}): """Create a SwarmPAL process and apply it on the given data. Parameters ---------- data: DataTree the data on which the process will be applied to. process_name: Str the name of the process to apply. See ... for a list of Toolboxes and their Processes. config: dict parameters passed to the Toolbox. """ process = make_process(process_name=process_name, config=config) return process(data)
[docs] def apply_processes(data, process_params): """Apply a list of processes to a dataset. Parameters ---------- data: DataTree the data on which the process will be applied to. process_params: a list of processes to apply to the input data. """ for config in process_params: process_name = config.pop("process_name") apply_process(data, process_name=process_name, config=config) return data
[docs] def _str_to_timedelta(time): """Convert strings that match 'HH:MM:SS' to datetime.timedelta objects.""" hours, minutes, seconds = (int(part) for part in time.split(":")) return timedelta(hours=hours, minutes=minutes, seconds=seconds)
[docs] def _fetch_dataset(provider="", config={}, options=None): """Helper that downloads a single dataset from a data provider with a specified configuration. Parameters ---------- provider: Str The name of the data provider. Must be one of ['vires', 'hapi']. config: List A configuration passed to `create_paldata`. TODO: describe the 'schema' options: dict or None The option passed to `create_paldata`. When None, the following defaults are used: .. list-table:: :widths: 30 70 :header-rows: 1 * - Provider - Default Options * - `vires` - `dict(asynchronous=False, show_progress=False)` * - `hapi` - `dict(logging=False)` """ if provider == "file": # TODO: # check that the file exists. # check that all required config keys are present. return create_paldata( **{ config["dataset"]: PalDataItem.from_file( filename=config["filename"], filetype=config["filetype"], group=config.get("dataset", None), ) } ) # Convert pad_times from strings to timedelta objects if "pad_times" in config: config["pad_times"] = [_str_to_timedelta(time) for time in config["pad_times"]] # Download the data if provider == "vires": options = options or dict(asynchronous=False, show_progress=False) return create_paldata(PalDataItem.from_vires(options=options, **config)) if provider == "hapi": options = options or dict(logging=False) return create_paldata(PalDataItem.from_hapi(options=options, **config)) raise ValueError( f"Unknown provider {provider}. Provider must be one of ['vires', 'hapi', 'file']." )
[docs] def fetch_data(configurations): """Downloads list of datasets and returns a unified DataTree. Parameters ---------- configurations: List A list of configurations passed to `create_paldata`. TODO: describe the 'schema' Returns ------- """ _validate(configurations) data = DataTree() for dataset_config in configurations["data_params"]: provider = dataset_config.pop("provider") options = dataset_config.pop("options", {}) item = _fetch_dataset(provider=provider, options=options, config=dataset_config) for key, dt in item.children.items(): data[key] = dt return data
_TOOLBOX_QUICKLOOKS = { "FAC": lambda data: data.swarmpal_fac.quicklook(), "TFA": lambda data: __import__( "swarmpal.toolboxes.tfa.plotting", fromlist=["quicklook"] ).quicklook(data), "DSECS": lambda data: __import__( "swarmpal.experimental.dsecs_plotting", fromlist=["plot_analysed_pass"], ).plot_analysed_pass(data, pass_no=0), }
[docs] def _detect_toolbox(data: DataTree) -> str: """Inspect PAL_meta to determine which toolbox produced the output datasets.""" pal_meta = data.swarmpal.pal_meta output_datasets = pal_meta.get(".", {}).get("output_datasets", []) found = set() for output_dataset in output_datasets: for process_name in pal_meta.get(output_dataset, {}): prefix = process_name.split("_", 1)[0] if prefix in _TOOLBOX_QUICKLOOKS: found.add(prefix) if len(found) == 0: raise RuntimeError( "No quicklook available: no recognised toolbox processes found in PAL_meta" ) if len(found) > 1: raise RuntimeError( f"Ambiguous quicklook: multiple toolboxes detected ({sorted(found)})" ) return found.pop()
[docs] def quicklook(data: DataTree, toolbox: str | None = None): """Create a quicklook plot based on the toolbox processes recorded in PAL_meta. Parameters ---------- data: DataTree Data that has been processed by SwarmPAL toolbox: str, optional Override toolbox selection (one of: 'FAC', 'TFA', 'DSECS'). Case-insensitive. If omitted, the toolbox is auto-detected from PAL_meta. Returns ------- matplotlib.figure.Figure """ if toolbox is None: toolbox = _detect_toolbox(data) else: toolbox = toolbox.upper() if toolbox not in _TOOLBOX_QUICKLOOKS: raise ValueError( f"Unknown toolbox {toolbox!r}. Must be one of {sorted(_TOOLBOX_QUICKLOOKS)}" ) result = _TOOLBOX_QUICKLOOKS[toolbox](data) return result[0] if isinstance(result, tuple) else result