Source code for skfdiff.plugins.container

#!/usr/bin/env python
# coding=utf-8

from functools import wraps
import uuid
from collections import deque
from typing import Union
import threading

from loguru import logger
from path import Path
from streamz import collect
from xarray import concat, open_dataset, open_mfdataset, open_zarr
from abc import ABC, abstractmethod

try:
    import zarr
except ModuleNotFoundError:
    print("Zarr module not found, ZarrContainer not available.")


@wraps(open_dataset)
def _safe_open_dataset(*args, **kwargs):
    # Open the dataset, copy it into memory and release the file handle right away.
    with open_dataset(*args, **kwargs) as ds:
        data = ds.copy()
    return data


@wraps(open_mfdataset)
def _safe_open_mfdataset(*args, **kwargs):
    # Same for multi-file datasets: load everything, copy, then close the files.
    with open_mfdataset(*args, **kwargs) as ds:
        data = ds.compute().copy()
    return data


class AttrDict(dict):
    """Dictionary whose items are also reachable as attributes."""

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def coerce_attr(key, value):
    value_type = type(value)
    if value_type in [int, float, str]:
        return value
    for cast in (int, float, str):
        try:
            value = cast(value)
            logger.debug(
                "Illegal netCDF type ({}) of attribute for {}, "
                "cast to {}".format(value_type, key, cast)
            )
            return value
        except TypeError:
            pass
    raise TypeError(
        "Illegal netCDF type ({}) of attribute for {}, "
        "auto-casting failed, tried to cast to "
        "int, float and str".format(value_type, key)
    )
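

# A minimal, hedged sketch of how ``coerce_attr`` behaves: the attribute names
# and values below are made up for illustration.  Values that netCDF cannot
# store directly are cast to the first legal type that accepts them, while
# already-legal values are returned untouched.
def _example_coerce_attr():
    assert coerce_attr("niter", 100) == 100            # already a legal netCDF type
    assert coerce_attr("verbose", True) == 1           # bool is coerced to int
    assert coerce_attr("tolerance", "1e-6") == "1e-6"  # str kept as-is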


def retrieve_container(
    path: Path, isel: Union[str, dict, int] = "all", lazy: bool = False
):
    if Path(path).dirname().ext == ".zarr":
        return ZarrContainer.retrieve(path, isel, lazy)
    else:
        return NetCDFContainer.retrieve(path, isel, lazy)
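

# A hedged usage sketch for ``retrieve_container``.  The directory name
# "output/my_simul" is hypothetical; the helper only dispatches to
# ``ZarrContainer.retrieve`` or ``NetCDFContainer.retrieve`` and returns an
# :py:class:`xarray.Dataset` indexed by the time coordinate ``t``.
def _example_retrieve_container():
    # Open the whole history lazily, then materialise only the last step.
    data = retrieve_container("output/my_simul", isel="all", lazy=True)
    return data.isel(t=-1).compute()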


class Container:
    def __init__(
        self,
        path=None,
        mode="a",
        *,
        save="all",
        force=False,
        nbuffer=None,
        save_interval=None,
        background_write=False
    ):
        if save_interval is None and nbuffer is None:
            save_interval = 10
        if save_interval is not None and nbuffer is not None:
            raise ValueError("Provide either nbuffer or save_interval, not both.")
        self.nbuffer = nbuffer
        self.save_interval = save_interval
        self._mode = mode
        self.save = save
        self._cached_data = deque([], self._n_save)
        self._collector = None
        self._count_iter = 0
        self.path = path = Path(path).abspath() if path else None
        self._writer_thread_name = "writer_%s" % uuid.uuid1()

        self.background_write = background_write
        self.write = self._bg_write if background_write else self._write

        if not path:
            return

        if self._mode == "w" and force:
            path.rmtree_p()

        if self._mode == "w" and not force and path.exists():
            raise FileExistsError(
                "Directory %s exists, set force=True to override it" % path
            )

        if self._mode == "r" and not path.exists():
            raise FileNotFoundError("Container not found.")
        path.makedirs_p()

    @property
    def save(self):
        return "last" if self._n_save else "all"

    @save.setter
    def save(self, value):
        if value == "all":
            self._n_save = None
        elif value == "last" or value == -1:
            self._n_save = 1
        else:
            raise ValueError(
                'save argument accepts only "all", "last" or -1 '
                "as value, not %s" % value
            )

    def _expand_fields(self, t, fields):
        fields = fields.assign_coords(t=float(t)).expand_dims("t")
        self._cached_data.append(fields)
        return fields

    def _concat_fields(self, fields):
        if fields:
            return concat(fields, dim="t")

    def connect(self, stream):
        def get_t_fields(simul):
            return simul.t, simul.fields

        def expand_fields(inps):
            return self._expand_fields(*inps)

        def get_last(list_fields):
            try:
                return list_fields[-1]
            except IndexError:
                pass

        # Turn every simulation state into a one-step Dataset indexed by "t".
        accumulation_stream = stream.map(get_t_fields).map(expand_fields)
        self._collector = collect(accumulation_stream)

        # On every flush, write either the whole buffer or only its last step.
        if self.save == "all":
            self._collector.map(self._concat_fields).sink(self.write)
        else:
            self._collector.map(get_last).sink(self.write)

        # Flush the collector every ``nbuffer`` steps or every
        # ``save_interval`` seconds.
        if self.nbuffer is not None:
            (
                accumulation_stream.partition(self.nbuffer)
                .sink(self._collector.flush)
            )
        if self.save_interval is not None:
            (
                accumulation_stream.timed_window(self.save_interval)
                .filter(bool)
                .sink(self._collector.flush)
            )

        return self._collector

    @property
    def writers(self):
        return [
            thread
            for thread in threading.enumerate()
            if thread.name == self._writer_thread_name
        ]

    @property
    def is_writing(self):
        return bool(self.writers)

    def flush(self):
        if self._collector:
            self._collector.flush()
        # Block until any background writer thread has finished.
        while self.is_writing:
            pass

    def _bg_write(self, concatenated_fields):
        thread = threading.Thread(
            name=self._writer_thread_name,
            target=self._write,
            args=(concatenated_fields,),
        )
        thread.start()

    def __repr__(self):
        repr = """path: {path}
{data}""".format(
            path=self.path, data=self.data if self.data is not None else "Empty"
        )
        return repr

    def __del__(self):
        self.flush()

    @abstractmethod
    def _write(self, concatenated_fields):
        pass

    @property
    @abstractmethod
    def data(self):
        pass
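

# A short sketch of the buffering contract enforced by ``Container``:
# ``save="all"`` keeps every time step, ``save="last"`` (or ``-1``) keeps only
# the most recent one, and ``nbuffer`` (flush every N steps) and
# ``save_interval`` (flush every N seconds) are mutually exclusive.
# ``MemoryContainer`` (defined below) is used here only because it needs no
# on-disk path.
def _example_container_options():
    keep_all = MemoryContainer(save="all", nbuffer=10)
    keep_last = MemoryContainer(save="last", save_interval=5)
    assert keep_all.save == "all" and keep_last.save == "last"
    try:
        MemoryContainer(nbuffer=10, save_interval=5)
    except ValueError:
        pass  # asking for both buffering strategies is rejected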


class MemoryContainer(Container):
    def _write(self, concatenated_fields):
        pass

    @property
    def data(self):
        try:
            return self._concat_fields(self._cached_data).sortby("t")
        except (OSError, AttributeError):
            return
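

# A minimal end-to-end sketch of the accumulation pipeline built by
# ``Container.connect``.  The ``FakeSimul`` namedtuple below is hypothetical:
# it only mimics the ``t`` and ``fields`` attributes that ``connect`` reads
# from the objects emitted on the stream.
def _example_connect():
    from collections import namedtuple

    import numpy as np
    import xarray as xr
    from streamz import Stream

    FakeSimul = namedtuple("FakeSimul", ["t", "fields"])

    container = MemoryContainer(nbuffer=2)  # flush the collector every 2 steps
    stream = Stream()
    container.connect(stream)

    x = np.linspace(0, 1, 16)
    for i in range(4):
        fields = xr.Dataset({"U": ("x", np.sin(x + i))}, coords={"x": x})
        stream.emit(FakeSimul(t=float(i), fields=fields))

    # Every emitted state has been expanded along "t" and cached.
    return container.data  # Dataset with a time dimension of size 4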


class NetCDFContainer(Container):
    def __init__(
        self,
        path,
        mode="a",
        *,
        save="all",
        force=False,
        nbuffer=None,
        save_interval=None,
        background_write=True
    ):
        super().__init__(
            path,
            mode,
            save=save,
            force=force,
            nbuffer=nbuffer,
            save_interval=save_interval,
            background_write=background_write,
        )

    def _write(self, concatenated_fields):
        if concatenated_fields is not None:
            len_concat_fields = concatenated_fields.t.size
            target_file = self.path / "data_%i-%i.nc" % (
                self._count_iter,
                self._count_iter + len_concat_fields - 1,
            )
            self._count_iter = len_concat_fields + self._count_iter
            concatenated_fields.to_netcdf(target_file)
            concatenated_fields.close()
            self._cached_data = deque([], self._n_save)
            if self.save == "last":
                for file in self.path.glob("data_*.nc"):
                    if file != target_file:
                        file.remove()

    @staticmethod
    def retrieve(path: Path, isel: Union[str, dict, int] = "all", lazy: bool = False):
        """Retrieve the data of a persistent container.

        Parameters
        ----------
        path : Path or str
            The folder where the persistent container lives.
        isel : Union[str, dict, int], optional
            Either "all", "last", an integer or a sequence of integers
            (the default is "all").
        lazy : bool, optional
            If True, return a lazy xarray Dataset that will only be loaded
            when requested by the user (with its :py:func:`compute` method).
            Useful when the data are too big to fit in memory
            (the default is False).

        Returns
        -------
        :py:class:`xarray.Dataset`
            The requested data as a time-dependent xarray Dataset.
        """
        path = Path(path)
        if isel == "last":
            last_file = sorted(
                [filename for filename in path.files("data*.nc")],
                key=lambda filename: int(
                    filename.basename().stripext().split("_")[-1].split("-")[-1]
                ),
            )[-1]
            data = _safe_open_dataset(last_file).isel(t=-1)
        else:
            if lazy:
                data = _safe_open_mfdataset(
                    path / "data*.nc", combine="by_coords"
                ).sortby("t")
            else:
                data = concat(
                    [open_dataset(filename) for filename in path.files("data*.nc")],
                    dim="t",
                ).sortby("t")
            if isel not in ["all", "last"]:
                data = data.isel(t=isel)
        return data

    @staticmethod
    def merge_datafiles(path, override=False):
        path = Path(path)

        if (path / "data.nc").exists() and not override:
            raise FileExistsError(path / "merged_data.nc")
        (path / "data.nc").remove_p()

        split_data = _safe_open_mfdataset(
            path / "data*.nc", combine="by_coords"
        ).sortby("t")
        split_data.to_netcdf(path / "merged_data.nc")
        merged_data = _safe_open_dataset(
            path / "merged_data.nc",
            chunks=split_data.chunks if split_data.chunks else None,
        )

        if not split_data.equals(merged_data):
            (path / "merged_data.nc").remove()
            raise IOError("Unable to merge data")

        split_data.close()
        merged_data.close()

        return path / "merged_data.nc"

    @property
    def data(self):
        try:
            self.flush()
            return _safe_open_mfdataset(
                self.path / "data*.nc", combine="by_coords"
            ).sortby("t")
        except (OSError, AttributeError):
            return

    def merge(self, override=True):
        if self.path:
            return self.merge_datafiles(self.path, override=override)
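

# A hedged sketch of the on-disk workflow for ``NetCDFContainer``.  The
# directory "output/my_simul" is hypothetical.  During a run the data are
# split across ``data_<i>-<j>.nc`` files; ``retrieve`` reassembles them and
# ``merge`` writes a single ``merged_data.nc`` next to them.
def _example_netcdf_container():
    # Reload only the last saved time step of a previous run.
    last_step = NetCDFContainer.retrieve("output/my_simul", isel="last")

    # Reopen the container in append mode and merge its chunked files.
    container = NetCDFContainer("output/my_simul", mode="a")
    merged_file = container.merge()
    return last_step, merged_file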


class ZarrContainer(Container):
    def __init__(
        self,
        path,
        mode="a",
        *,
        save="all",
        force=False,
        nbuffer=None,
        save_interval=None,
        background_write=False
    ):
        super().__init__(
            path,
            mode,
            save=save,
            force=force,
            nbuffer=nbuffer,
            save_interval=save_interval,
            background_write=background_write,
        )

    def _write(self, concatenated_fields):
        if concatenated_fields is not None:
            concatenated_fields.to_zarr(self.path, append_dim="t", mode="a")

    @property
    def data(self):
        self.flush()
        return open_zarr(self.path)

    @staticmethod
    def retrieve(path: Path, isel: Union[str, dict, int] = "all", lazy: bool = False):
        """Retrieve the data of a persistent container.

        Parameters
        ----------
        path : Path or str
            The folder where the persistent container lives.
        isel : Union[str, dict, int], optional
            Either "all", "last", an integer or a sequence of integers
            (the default is "all").
        lazy : bool, optional
            If True, return a lazy xarray Dataset that will only be loaded
            when requested by the user (with its :py:func:`compute` method).
            Useful when the data are too big to fit in memory
            (the default is False).

        Returns
        -------
        :py:class:`xarray.Dataset`
            The requested data as a time-dependent xarray Dataset.
        """
        path = Path(path)
        data = open_zarr(path)
        if not lazy:
            data = data.compute()
        if isel == "all":
            return data
        if isel == "last":
            isel = -1
        return data.isel(t=isel)
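

# A hedged sketch for ``ZarrContainer``.  The store name
# "output/my_simul.zarr" is hypothetical; each flush appends along "t" to the
# zarr store, so the whole history can be reopened lazily.
def _example_zarr_container():
    history = ZarrContainer.retrieve("output/my_simul.zarr", lazy=True)
    last_step = ZarrContainer.retrieve("output/my_simul.zarr", isel="last")
    return history, last_step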