Source code for swarmpal.experimental.dsecs_plotting

from __future__ import annotations

import datetime as dt
import logging
from contextlib import contextmanager
from functools import wraps

import cartopy.crs as ccrs
import ipywidgets as widgets
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import clear_output, display
from matplotlib.colors import Normalize

from swarmpal.utils.exceptions import PalError

logger = logging.getLogger(name=__name__)

# Public entry points exported by ``from ... import *``
__all__ = ("plot_analysed_pass", "quicklook", "quicklook_animated", "quicklook_subtree")


@contextmanager
def _disable_mpl_interactive_mode():
    """Context manager that switches matplotlib interactive mode off.

    The interactive state seen on entry is recorded, interactive mode is
    turned off for the duration of the ``with`` body, and on exit (normal or
    via an exception) the recorded state is re-applied: interactive mode is
    switched back on if it was on, or explicitly kept off if it was off.
    """
    was_interactive = plt.isinteractive()
    try:
        plt.ioff()
        yield
    finally:
        # Re-apply whichever state was active on entry
        restore = plt.ion if was_interactive else plt.ioff
        restore()


def _turn_off_interactive_mode(func):
    """Decorator that runs ``func`` with matplotlib interactive mode off.

    The interactive state in effect before the call is restored afterwards,
    including when ``func`` raises, so a failed plot cannot leave
    interactive mode switched off for the rest of the session.

    Parameters
    ----------
    func : callable
        The function to wrap.

    Returns
    -------
    callable
        A wrapper with the same signature and metadata as ``func``.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # The context manager saves the current state, disables interactive
        # mode, and restores the saved state in a ``finally`` clause.
        with _disable_mpl_interactive_mode():
            return func(*args, **kwargs)

    return wrapper


def _get_dsecs_meta(datatree, output_dataset, check_analysis=False):
    """Return the SwarmPAL process metadata recorded for ``output_dataset``.

    Parameters
    ----------
    datatree : DataTree
        Datatree exposing ``datatree.swarmpal.pal_meta``.
    output_dataset : str
        Key into ``pal_meta``; a missing key is treated as empty metadata.
    check_analysis : bool, default False
        Also require that a ``"DSECS_Analysis"`` entry is present.

    Returns
    -------
    dict
        The metadata mapping for ``output_dataset``.

    Raises
    ------
    PalError
        If the ``"DSECS_Preprocess"`` entry is missing or falsy, or, when
        ``check_analysis`` is set, if the ``"DSECS_Analysis"`` entry is
        missing (None).
    """
    meta = datatree.swarmpal.pal_meta.get(output_dataset, {})
    preprocess_done = bool(meta.get("DSECS_Preprocess"))
    if not preprocess_done:
        raise PalError("Must first run dsecs.processes.Preprocess")
    # Only consult the Analysis entry when asked to
    if check_analysis and meta.get("DSECS_Analysis") is None:
        raise PalError("Must first run dsecs.processes.Analysis")
    return meta


def _get_dsecs_pass_time_interval(datatree, output_dataset, pass_no=0):
    """Return the (start, end) datetimes of a given pass.

    The ``"Time interval"`` attribute of the pass's ``currents`` node is
    expected to read ``"<start> - <end>"`` with ISO-format timestamps.
    Everything from the first ``.`` onwards in each timestamp (presumably
    fractional seconds) is dropped before parsing.
    """
    pass_node = datatree[f"{output_dataset}/{pass_no}"]
    interval = pass_node["currents"].attrs["Time interval"]
    start_text, end_text = interval.split(" - ")
    start, end = (
        dt.datetime.fromisoformat(text.split(".")[0])
        for text in (start_text, end_text)
    )
    return start, end


@_turn_off_interactive_mode
def plot_analysed_pass(datatree, output_dataset, pass_no=0, extent="global"):
    """Plot a figure showing currents from one orbital pass

    The figure has three orthographic map panels: horizontal DF current and
    horizontal CF current (drawn as arrows), and radial current (drawn as a
    colour mesh). Every panel also shows coastlines and, in grey, the
    positions from the two input datasets.

    Parameters
    ----------
    datatree : DataTree
        A datatree processed with the DSECS toolbox
    output_dataset: String
        The subtree in datatree to plot; should be an output produced by the
        DSECS toolbox.
    pass_no : int
        Number of the pass to plot, i.e. the name of a child node of
        ``datatree[output_dataset]``.
    extent : str | tuple, default "global"
        "global" or "automatic", or a tuple to be provided to ax.set_extent()

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    PalError
        If the DSECS Preprocess or Analysis process has not been run.
    ValueError
        If ``extent`` is a string other than "global" or "automatic".
    """
    # Select the inputs we'll need for the figure
    pal_processes_meta = _get_dsecs_meta(datatree, output_dataset, check_analysis=True)
    dataset_name_alpha = pal_processes_meta["DSECS_Preprocess"]["dataset_alpha"]
    dataset_name_charlie = pal_processes_meta["DSECS_Preprocess"]["dataset_charlie"]
    data_a = datatree[dataset_name_alpha]
    data_c = datatree[dataset_name_charlie]
    data_currents = datatree[f"{output_dataset}/{pass_no}/currents"]

    # Resolve the map extent once, before any figure is created, so that a bad
    # value does not leave a half-built figure behind.
    global_view = False
    map_extent = extent
    if isinstance(extent, str):
        if extent == "global":
            global_view = True
        elif extent == "automatic":
            # Longitude range of the current grid padded by 10 degrees each
            # side; plain floats rather than 0-d DataArrays for set_extent.
            lon = data_currents["Longitude"]
            map_extent = [float(lon.min()) - 10, float(lon.max()) + 10, -50, 50]
        else:
            raise ValueError(
                "extent should be 'global', 'automatic', "
                "or a tuple to be provided to ax.set_extent()"
            )

    # Orthographic projection centred on the median longitude of the first
    # row (y=0) of the current grid
    central_lon = float(data_currents["Longitude"].isel(y=0).median())
    fig, axes = plt.subplots(
        nrows=1,
        ncols=3,
        figsize=(15, 5),
        subplot_kw={
            "projection": ccrs.Orthographic(
                central_longitude=central_lon, central_latitude=0
            )
        },
    )

    # Set common view on each subplot, with spacecraft tracks
    track_kwargs = dict(color="grey", transform=ccrs.PlateCarree(), s=0.1, alpha=0.1)
    for ax in axes:
        if global_view:
            ax.set_global()
        else:
            ax.set_extent(map_extent)
        ax.coastlines(color="purple", alpha=0.5)
        for track in (data_a, data_c):
            ax.scatter(track["Longitude"], track["Latitude"], **track_kwargs)

    # Arrows for the DF and CF currents, one quiver call per y-slice of the grid
    quiver_kwargs = dict(
        transform=ccrs.PlateCarree(),
        pivot="tail",
        angles="uv",
        scale=0.0001,
        scale_units="xy",
        width=10000,
        units="xy",
        color="blue",
    )
    # ``sizes`` is the mapping of dimension name -> length; indexing
    # ``.dims`` by name is deprecated for xarray Datasets.
    for slicenum in range(data_currents.sizes["y"]):
        current_slice = data_currents.isel(y=slicenum)
        lon = current_slice["Longitude"]
        lat = current_slice["Latitude"]
        axes[0].quiver(
            lon, lat, current_slice["JEastDf"], current_slice["JNorthDf"], **quiver_kwargs
        )
        axes[1].quiver(
            lon, lat, current_slice["JEastCf"], current_slice["JNorthCf"], **quiver_kwargs
        )
    axes[0].set_title("Horizontal DF current")
    axes[1].set_title("Horizontal CF current")

    # Colour range symmetric around 0: the larger magnitude of the 1st and
    # 99th percentiles of Jr, rounded up to the nearest 10
    q_low, q_high = np.nanquantile(data_currents["Jr"], (0.01, 0.99))
    vminmax = np.ceil(np.max(np.abs((q_low, q_high))) / 10) * 10
    norm = Normalize(vmin=-vminmax, vmax=vminmax)
    axes[2].pcolormesh(
        data_currents["Longitude"],
        data_currents["Latitude"],
        data_currents["Jr"],
        transform=ccrs.PlateCarree(),
        cmap="RdYlBu",
        norm=norm,
    )
    axes[2].set_title("Radial current")

    # Add time start and end of pass, and dataset sources
    # TODO: Add product version numbers
    title_text = f"{dataset_name_alpha}\n{dataset_name_charlie}"
    t1, t2 = _get_dsecs_pass_time_interval(datatree, output_dataset, pass_no=pass_no)
    title_text += f"\nStart: {t1.isoformat()}\nEnd: {t2.isoformat()}"
    fig.suptitle(title_text, x=0.9, ha="right", va="bottom")
    return fig
@_turn_off_interactive_mode
def quicklook_subtree(datatree, output_dataset, frame_select="all"):
    """Returns figures overviewing the outputs of the analysis for a specific
    subtree in a datatree.

    One figure is produced with plot_analysed_pass for each selected pass.
    Passes are the child nodes of ``datatree[output_dataset]``, numbered from 0.

    Parameters
    ----------
    datatree : DataTree
        A datatree from the DSECS toolbox
    output_dataset: String
        The subtree in datatree to plot; should be an output produced by the
        DSECS toolbox.
    frame_select : str, default "all"
        "all", "odd", "even" to limit the frame numbers displayed

    Returns
    -------
    dict[int, matplotlib.figure.Figure]
        Figures keyed by pass number, in increasing order.

    Raises
    ------
    PalError
        If the DSECS Preprocess/Analysis processes have not been run for
        ``output_dataset``.
    ValueError
        If ``frame_select`` is not "all", "odd" or "even".
    """
    try:
        _get_dsecs_meta(datatree, output_dataset, check_analysis=True)
    except PalError as err:
        # Keep the specific missing-process error as the explicit cause
        raise PalError("No quicklook available before analysis has been run") from err

    # Identify number of analysed passes and generate a fig for each one
    num_passes = len(datatree[output_dataset].children)

    # Config to select which frames to generate
    if frame_select == "all":
        frames = range(num_passes)
    elif frame_select == "odd":
        frames = range(1, num_passes, 2)
    elif frame_select == "even":
        frames = range(0, num_passes, 2)
    else:
        raise ValueError("frame_select should be 'all', 'odd', or 'even'")

    return {i: plot_analysed_pass(datatree, output_dataset, pass_no=i) for i in frames}
@_turn_off_interactive_mode
def quicklook(datatree, frame_select="all"):
    """Returns quicklook figures for the first analysed output in a datatree.

    The output datasets listed in ``pal_meta["."]["output_datasets"]`` are
    scanned in order. The first one whose metadata contains a
    ``"DSECS_Analysis"`` entry is passed to quicklook_subtree.

    Parameters
    ----------
    datatree : DataTree
        A datatree from the DSECS toolbox
    frame_select : str, default "all"
        "all", "odd", "even" to limit the frame numbers displayed

    Returns
    -------
    dict[int, matplotlib.figure.Figure]
        Figures keyed by pass number, as returned by quicklook_subtree.

    Raises
    ------
    PalError
        If no output dataset in ``datatree`` has had the DSECS Analysis run.
    """
    pal_meta = datatree.swarmpal.pal_meta
    output_datasets = pal_meta.get(".", {}).get("output_datasets", [])
    for output_dataset in output_datasets:
        # Missing metadata is treated as "not analysed", as in _get_dsecs_meta
        if "DSECS_Analysis" in pal_meta.get(output_dataset, {}):
            return quicklook_subtree(datatree, output_dataset, frame_select=frame_select)
    # Previously this fell through and returned None, which only surfaced
    # later as an obscure AttributeError in callers.
    raise PalError("No quicklook available before analysis has been run")
@_turn_off_interactive_mode
def quicklook_animated(datatree, frame_select="all"):
    """Creates an animation of quicklook plots, using ipywidgets

    All figures are rendered up front with quicklook. A Play button and a
    slider, linked together, step through the pass numbers, and the figure
    for the current pass is shown in an Output widget.

    Parameters
    ----------
    datatree : DataTree
        A datatree from the DSECS toolbox
    frame_select : str, default "all"
        "all", "odd", "even" to limit the frame numbers displayed

    Returns
    -------
    ipywidgets.VBox
        The Play/slider controls above the figure output area.

    Raises
    ------
    PalError
        If there are no analysed passes to display.
    """
    # Prerender the figures for each frame
    fig_collection = quicklook(datatree, frame_select=frame_select)
    # Guards against both None and an empty collection (no passes)
    if not fig_collection:
        raise PalError("No analysed passes available to animate")

    # Generate widgets to use for the output and control
    output = widgets.Output()

    def update_figure(frame):
        with output:
            clear_output(wait=True)
            display(fig_collection.get(frame))

    # Frames are evenly spaced pass numbers (1 apart for "all", 2 apart for
    # "odd"/"even"); with a single frame there is no spacing, so use 1.
    frames = tuple(fig_collection)
    frame_start = frames[0]
    frame_end = frames[-1]
    frame_step = frames[1] - frames[0] if len(frames) > 1 else 1

    slider = widgets.IntSlider(
        value=frame_start, min=frame_start, max=frame_end, step=frame_step
    )
    play = widgets.Play(
        value=frame_start,
        min=frame_start,
        max=frame_end,
        step=frame_step,
        interval=1200,
        description="Press play",
        disabled=False,
        playing=False,
        repeat=True,
    )

    # Link the widgets and create the display
    widgets.jslink((play, "value"), (slider, "value"))
    play.observe(lambda change: update_figure(change["new"]), "value")

    # Show the first frame immediately instead of an empty output area
    update_figure(frame_start)

    return widgets.VBox(
        (
            widgets.HBox([play, slider]),
            output,
        )
    )