"""
Tools to connect to the outside world and get/create xarray Datasets
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field
from datetime import datetime, timedelta
from os import PathLike
from os.path import exists as path_exists
from pathlib import Path
from hapiclient import hapi, hapitime2datetime
from numpy.typing import ArrayLike
from pandas import to_datetime as to_pandas_datetime
# from xarray.core.extension_array import PandasExtensionArray
from pandas.core.arrays.categorical import Categorical
from viresclient import SwarmRequest
from xarray import Dataset, open_dataset
from swarmpal.io._cdf_interface import cdf_to_xarray
[docs]
@dataclass
class Parameters:
"""Control which dataset is accessed, and how the fetcher behaves"""
pad_times: tuple[timedelta] = ()
[docs]
@dataclass
class ViresParameters(Parameters):
collection: str = ""
measurements: list[str] = field(default_factory=list)
start_time: str | datetime = ""
end_time: str | datetime = ""
server_url: str = "https://vires.services/ows"
models: list[str] = field(default_factory=list)
auxiliaries: list[str] = field(default_factory=list)
sampling_step: str | None = None
filters: list[str] = field(default_factory=list)
options: dict = field(default_factory=dict)
[docs]
@dataclass
class HapiParameters(Parameters):
server: str = ""
dataset: str = ""
parameters: str = ""
start: str = ""
stop: str = ""
options: dict = field(default_factory=dict)
[docs]
@dataclass
class FileParameters(Parameters):
filename: PathLike | None = None
[docs]
@dataclass
class CDFFileParameters(FileParameters): ...
[docs]
@dataclass
class NetCDFFileParameters(FileParameters):
group: str | None = None
[docs]
@dataclass
class ManualParameters(Parameters): ...
[docs]
class DataFetcherBase(ABC):
"""Interface with an external data source"""
@property
@abstractmethod
def source(self) -> str:
"""String to identify the data source type (e.g. 'vires', 'hapi')"""
...
@property
@abstractmethod
def parameters(self) -> Parameters:
"""Set of parameters to control how/what data is accessed"""
...
[docs]
@abstractmethod
def fetch_data(self) -> Dataset:
"""Command to get data as an xarray Dataset"""
...
[docs]
def config(self) -> dict:
config = asdict(self._parameters)
config["provider"] = self.source
return config
[docs]
class ViresDataFetcher(DataFetcherBase):
"""Connects to and retrieves data from VirES through viresclient"""
@property
def source(self) -> str:
return "vires"
@property
def parameters(self) -> ViresParameters:
return self._parameters
@parameters.setter
def parameters(self, parameters: dict) -> None:
self._parameters = ViresParameters(**parameters)
def __init__(self, **parameters) -> None:
"""Create connection to VirES and initialise the request
Available parameters are listed in ViresParameters
Next run .fetch_data() to trigger the request
"""
self.parameters = parameters
self.vires_request = self._initialise_request()
[docs]
def _initialise_request(self) -> SwarmRequest:
"""Use the set parameters to initialise the request"""
vires_request = SwarmRequest(self.parameters.server_url)
vires_request.set_collection(self.parameters.collection)
vires_request.set_products(
measurements=self.parameters.measurements,
models=self.parameters.models,
auxiliaries=self.parameters.auxiliaries,
sampling_step=self.parameters.sampling_step,
)
for filter in self.parameters.filters:
vires_request.add_filter(filter)
return vires_request
[docs]
def fetch_data(self) -> Dataset:
"""Process the request on VirES and load an xarray Dataset"""
result = self.vires_request.get_between(
self.parameters.start_time,
self.parameters.end_time,
**self.parameters.options,
).as_xarray()
# Convert PandasExtensionArray to numpy.ndarray
for var in result.variables:
if isinstance(result[var].data, Categorical):
result[var].data = result[var].data.to_numpy()
return result
[docs]
class HapiDataFetcher(DataFetcherBase):
"""Connects to and retrieves data from a HAPI server through hapiclient"""
@property
def source(self) -> str:
return "hapi"
@property
def parameters(self) -> HapiParameters:
return self._parameters
@parameters.setter
def parameters(self, parameters: dict) -> None:
self._parameters = HapiParameters(**parameters)
def __init__(self, **parameters) -> None:
"""Prepare inputs for hapi & test connection
Available parameters are listed in HapiParameters
Next run .fetch_data() to trigger the request
"""
self.parameters = parameters
self._get_hapi_info()
[docs]
@staticmethod
def _hapi_to_xarray(data: ArrayLike, meta: dict) -> Dataset:
# Separate the time variable name from the other varnames
# (assuming time variable comes first in the list)
varnames = [p["name"] for p in meta["parameters"]]
timevar, varnames = varnames[0], varnames[1:]
# Generate dimension labels for each parameter
dims = ()
for p in meta["parameters"][1:]:
n_extra_dims = len(p.get("size", []))
extra_dims = (f"{p['name']}_dim{i + 1}" for i in range(n_extra_dims))
dims = (*dims, (timevar, *extra_dims))
# Convert time data to timezone-naive DatetimeIndex
tdata = to_pandas_datetime(hapitime2datetime(data[timevar]))
tdata = tdata.tz_convert("UTC").tz_convert(None)
# Assuming we now have ordered lists of varnames, dims,
# assemble a Dataset from the data & meta
ds = Dataset(
data_vars={
timevar: (timevar, tdata),
**{_name: (_dim, data[_name]) for _name, _dim in zip(varnames, dims)},
}
)
# Assign metadata for each data variable
for p in meta["parameters"][1:]:
ds[p["name"]].attrs = {
"units": p.get("units"),
"description": p.get("description"),
}
return ds
[docs]
def _get_hapi_info(self) -> dict:
# Get info response from HAPI server
return hapi(
self.parameters.server,
self.parameters.dataset,
self.parameters.parameters,
**self.parameters.options,
)
[docs]
def fetch_data(self) -> Dataset:
"""Make a HAPI query and load an xarray Dataset"""
data, meta = hapi(
self.parameters.server,
self.parameters.dataset,
self.parameters.parameters,
self.parameters.start,
self.parameters.stop,
**self.parameters.options,
)
return self._hapi_to_xarray(data, meta)
[docs]
class NetCDFfileDataFetcher(DataFetcherBase):
@property
def source(self) -> str:
return "netcdf_file"
@property
def parameters(self) -> NetCDFFileParameters:
return self._parameters
@parameters.setter
def parameters(self, parameters: dict) -> None:
self._parameters = NetCDFFileParameters(**parameters)
def __init__(self, filename: PathLike, group: str | None = None) -> None:
self.parameters = dict(filename=filename, group=group)
if not path_exists(self.parameters.filename):
raise FileNotFoundError(self.parameters.filename)
[docs]
def fetch_data(self) -> Dataset:
# filename_or_obj and group are kwargs for xarray.open_dataset
kwargs = {"filename_or_obj": self.parameters.filename}
if self.parameters.group:
kwargs["group"] = self.parameters.group
ds = open_dataset(**kwargs)
try:
ds.attrs["Sources"]
except KeyError:
ds.attrs["Sources"] = [Path(self.parameters.filename).name]
return ds
[docs]
class CDFfileDataFetcher(DataFetcherBase):
@property
def source(self) -> str:
return "cdf_file"
@property
def parameters(self) -> CDFFileParameters:
return self._parameters
@parameters.setter
def parameters(self, parameters: dict) -> None:
self._parameters = CDFFileParameters(**parameters)
def __init__(self, **parameters) -> None:
self.parameters = parameters
if not path_exists(self.parameters.filename):
raise FileNotFoundError(self.parameters.filename)
[docs]
def fetch_data(self) -> Dataset:
kwargs = {"cdf_file": self.parameters.filename}
ds = cdf_to_xarray(**kwargs)
try:
ds.attrs["Sources"]
except KeyError:
ds.attrs["Sources"] = [Path(self.parameters.filename).name]
return ds
[docs]
class ManualDataFetcher(DataFetcherBase):
@property
def source(self) -> str:
return "manual"
@property
def parameters(self) -> FileParameters:
return self._parameters
@parameters.setter
def parameters(self, parameters: dict) -> None:
self._parameters = ManualParameters(**parameters)
def __init__(self, xarray_dataset: Dataset) -> None:
self.parameters = dict()
if not isinstance(xarray_dataset, Dataset):
raise ValueError("Given data must be xarray.Dataset")
self._xarray = xarray_dataset.copy()
[docs]
def fetch_data(self) -> Dataset:
return self._xarray
[docs]
def get_fetcher(source) -> DataFetcherBase:
fetchers = {
"ViresDataFetcher": ViresDataFetcher,
"HapiDataFetcher": HapiDataFetcher,
"NetCDFfileDataFetcher": NetCDFfileDataFetcher,
"CDFfileDataFetcher": CDFfileDataFetcher,
"ManualDataFetcher": ManualDataFetcher,
}
try:
return fetchers[source]
except KeyError:
raise KeyError(f"Data source '{source}' not found")