Source code for skfdiff.plugins.displays

#!/usr/bin/env python
# coding=utf-8

from collections import defaultdict
from functools import partial

from holoviews import Curve, DynamicMap, Image, Layout, streams
from loguru import logger
from path import Path  # noqa


class Display:
    """Live plot fed by a stream and rendered through a holoviews DynamicMap.

    A ``streams.Pipe`` holds the most recent data item; every time a new item
    is sent to the pipe, the ``DynamicMap`` calls ``plot_function`` again to
    rebuild the plot.
    """

    def __init__(self, skel_data, plot_function):
        """Create the pipe and the dynamic map.

        Parameters
        ----------
        skel_data
            Initial content of the pipe (passed as ``Pipe(data=skel_data)``).
        plot_function : callable
            Callback handed to ``DynamicMap``; it receives the pipe content.
        """
        self._plot_pipe = streams.Pipe(data=skel_data)
        self._dynmap = DynamicMap(plot_function, streams=[self._plot_pipe])

    def _repr_mimebundle_(self, *args, **kwargs):
        """Delegate rich display to the collated holoviews object."""
        return self.hv_curve._repr_mimebundle_(*args, **kwargs)

    def connect(self, stream, n=1):
        """Feed the display from ``stream``.

        The stream is chained through ``partition(n)`` and ``pluck(-1)``
        before being sunk into the pipe, i.e. only the last item of each
        group of ``n`` is forwarded to the plot.
        """
        stream.partition(n).pluck(-1).sink(self._plot_pipe.send)

    @property
    def hv_curve(self):
        """Return the collated dynamic map."""
        return self._dynmap.collate()

    def show(self):
        """Return the holoviews object to display (same as ``hv_curve``)."""
        return self.hv_curve

    def __add__(self, other):
        """Combine with ``+``; another ``Display`` contributes its dynamic map."""
        if isinstance(other, Display):
            return self._dynmap + other._dynmap
        return self._dynmap + other

    def __mul__(self, other):
        """Combine with ``*``; another ``Display`` contributes its dynamic map."""
        if isinstance(other, Display):
            return self._dynmap * other._dynmap
        # The original forgot this ``return``, so ``display * element``
        # silently evaluated to ``None``.
        return self._dynmap * other

    @staticmethod
    def display_custom(simul, plot_function, n=1):
        """Build a display driven by a user-supplied plotting callback.

        Parameters
        ----------
        simul
            Object used as initial pipe content; its ``stream`` attribute
            feeds the display.
        plot_function : callable
            Called with the current pipe content; its result is plotted.
        n : int, optional
            Forward one item out of every ``n`` (see ``connect``).

        Returns
        -------
        Display
        """

        # Wrapper exposing a single parameter named ``data``.
        # NOTE(review): presumably needed so the callback signature matches
        # the Pipe's ``data`` parameter whatever the user's argument name is
        # -- confirm against holoviews.
        def wrapped_plot_function(data):
            return plot_function(data)

        display = Display(simul, wrapped_plot_function)
        display.connect(simul.stream, n)
        return display

    @staticmethod
    def display_fields(simul, keys="unknowns", n=1, dim_allowed=(0, 1, 2)):
        """Build a display plotting the fields of the streamed data.

        Parameters
        ----------
        simul
            Object providing ``stream``, ``t`` and ``model.pdesys.parameters``.
        keys : "all", "unknowns" or iterable of str, optional
            ``"all"`` plots every entry of ``data.fields.data_vars``;
            ``"unknowns"`` plots those whose name is not a parameter name of
            ``simul.model.pdesys``; anything else is used as the list of
            variable names.
        n : int, optional
            Forward one item out of every ``n`` (see ``connect``).
        dim_allowed : tuple of int, optional
            Number of dimensions (0, 1 or 2) a field must have to be plotted.

        Returns
        -------
        Display

        Raises
        ------
        ValueError
            Raised by the plotting callback when no key is left to plot.
        """
        # History of (t, value) pairs for scalar fields, kept across calls.
        _0D_history = defaultdict(
            list
        )  # mypy: Dict[str, Sequence[Tuple[float, float]]]

        def plot_function(keys, _0D_history, data):
            if keys == "all":
                keys = data.fields.data_vars
            elif keys == "unknowns":
                # Build the exclusion set once instead of once per key.
                parameters = {
                    par.name for par in simul.model.pdesys.parameters
                }
                keys = [
                    key
                    for key in data.fields.data_vars
                    if key not in parameters
                ]
            if not keys:
                raise ValueError("keys are empty.")

            curves = []
            for var in keys:
                displayed_field = data.fields[var]
                ndim = len(displayed_field.dims)
                if ndim == 0 and 0 in dim_allowed:
                    # Scalars are plotted as their history against time.
                    _0D_history[var].append((simul.t, float(displayed_field)))
                    curves.append(
                        Curve(
                            _0D_history[var], kdims="t", vdims=var, label=var
                        )
                    )
                elif ndim == 1 and 1 in dim_allowed:
                    curves.append(Curve(displayed_field.squeeze(), label=var))
                elif ndim == 2 and 2 in dim_allowed:
                    curves.append(Image(displayed_field.squeeze(), label=var))
                # Any other dimensionality is skipped.
            return Layout(curves)

        display = Display(simul, partial(plot_function, keys, _0D_history))
        display.connect(simul.stream, n)
        display.cols = display.hv_curve.cols
        return display