#!/usr/bin/env python
# coding=utf-8
import re
import typing
from collections import defaultdict, deque
from copy import deepcopy
from functools import partial, reduce
from itertools import chain, count, product
from operator import and_, mul
from queue import Queue

import attr
import cloudpickle
import numpy as np
from cached_property import cached_property
from loguru import logger
from more_itertools import unique_everseen
from sympy import (
    Derivative,
    Eq,
    Function,
    Indexed,
    KroneckerDelta,
    Number,
    Symbol,
    Wild,
    Dummy,
    oo,
    solve,
    sympify,
)
from sympy.utilities.lambdify import MODULES

from .spatial_schemes import FiniteDifferenceScheme, chain_schemes, upwind
from .variables import Coordinate, Unknown, _convert_coord_list, _convert_unk_list
def _convert_pde_list(pdes):
if isinstance(pdes, str):
return [pdes]
else:
return pdes
def _partial_derivative(expr, symbolic_coordinates):
    """Return ``Derivative(expr, *symbolic_coordinates)``.

    Thin wrapper whose second argument can be bound by keyword with
    ``functools.partial``. When *symbolic_coordinates* is a string such as
    ``"xx"``, each character is passed as a separate differentiation variable.
    """
    variables = tuple(symbolic_coordinates)
    return Derivative(expr, *variables)
def _build_sympy_namespace(
equation, coordinates, unknowns, parameters, custom_functions=None
):
""" Check the equation, find all the derivative in Euler notation
(see https://en.wikipedia.org/wiki/Notation_for_differentiation#Euler's_notation)
the way that dxxU will be equal to Derivative(U(x), x, x).
All the derivative found that way are add to a subsitution rule as dict and
applied when the equation is sympified.
"""
# look at all the dxxU, dxyV... and dx(...), dxy(...) and so on in the equation
if custom_functions is None:
custom_functions = {}
spatial_derivative_re = re.compile(
r"d(?P<derargs>\w+?)(?:(?P<depder>(?:%s)+)|\((?P<inder>.*?)\))"
% "|".join([var.name for var in chain(unknowns, parameters)])
)
spatial_derivatives = spatial_derivative_re.findall(str(equation))
# they can be derivatives inside the dx(...), we check it until there is no more
queue = Queue()
[queue.put(sder[2]) for sder in spatial_derivatives if sder[2]]
while not queue.empty():
inside_derivative = queue.get()
new_derivatives = spatial_derivative_re.findall(inside_derivative)
[queue.put(sder[2]) for sder in new_derivatives if sder[2]]
spatial_derivatives.extend(new_derivatives)
# The sympy namespace is built with...
namespace = deepcopy(custom_functions)
# All the coordinates
namespace.update({coord.name: coord.symbol for coord in coordinates})
# All the dependent variables: unk and parameters
namespace.update(
{
dvar.name: (
dvar.symbol(*[coord.symbol for coord in dvar.coordinates])
if dvar.coordinates
else dvar.symbol
)
for dvar in chain(unknowns, parameters)
}
)
# All the harversted derivatives
for coord, unk, _ in spatial_derivatives:
if unk:
namespace["d%s%s" % (coord, unk)] = _partial_derivative(
namespace[unk], coord
)
else:
namespace["d%s" % coord] = partial(
_partial_derivative, symbolic_coordinates=coord
)
return namespace
def _build_keep_derivs_namespace(equation):
    """Collect the kept-in-form derivatives ``D<coords>(...)`` of *equation*.

    Every applied function of *equation* whose name is ``D`` followed by word
    characters is treated as a derivative kept in compact form, e.g.
    ``Dxy(f)`` is a derivative along ``x`` then ``y``.
    NOTE(review): any function named like that (e.g. a user ``Damp``) is
    caught as well.

    Returns
    -------
    list of (list of str, sympy expression)
        For each such function: the characters of its name after ``D`` (one
        differentiation variable per character, so coordinate names are
        assumed to be single characters), and the function application. The
        caller rebuilds the function name as ``"D" + "".join(wrts)``.
    """
    # Greedy + fullmatch so the whole suffix is captured; a lazy ``\w+?``
    # would keep only its first character (``Dxx`` -> ``["x"]``), and the
    # rebuilt name would then not match the function in the expression.
    spatial_derivative_re = re.compile(r"D(\w+)")
    spatial_derivatives = []
    for function in equation.atoms(Function):
        match = spatial_derivative_re.fullmatch(str(function.func))
        if match is None:
            continue
        wrts = list(match.group(1))
        spatial_derivatives.append((wrts, function.func(*function.args)))
    return spatial_derivatives
def list_node_coords(domain):
    """Return, per coordinate, its index symbol followed by its boundary nodes.

    *domain* maps each coordinate to a ``(left_cond, right_cond)`` pair.

    - When ``left_cond`` exposes ``rhs``, the nodes from ``coord.idx.lower``
      up to (excluding) ``left_cond.rhs`` are appended.
    - When ``right_cond`` exposes ``rhs``, the nodes derived from it near
      ``coord.idx.upper`` are appended.

    A condition without ``rhs`` (e.g. a plain ``True``, meaning that side is
    entirely bulk) contributes nothing: the AttributeError is swallowed.
    """

    def left_side(coord, cond):
        # Whole computation inside the try, exactly like the right side.
        try:
            return list(np.arange(coord.idx.lower, cond.rhs))
        except AttributeError:
            return []

    def right_side(coord, cond):
        try:
            offsets = np.arange(
                cond.rhs - coord.N + 1, coord.idx.upper - coord.N + 1
            )
            return list(offsets + coord.N)
        except AttributeError:
            return []

    result = {}
    for coord, (left_cond, right_cond) in domain.items():
        result[coord] = [
            coord.idx,
            *left_side(coord, left_cond),
            *right_side(coord, right_cond),
        ]
    return result
def list_conditions(domain):
    """Return, per coordinate, the conditions describing its nodes.

    *domain* maps each coordinate to a ``(left_cond, right_cond)`` pair. The
    first entry of each list is ``left_cond & right_cond``. It is followed
    by one ``Eq(coord.idx, node)`` per boundary node, computed as in
    ``list_node_coords``. A side whose condition has no ``rhs`` (e.g.
    ``True``, that side is all bulk) adds no equality.
    """

    def equalities(coord, nodes):
        # One ``Eq(index, node)`` per boundary node.
        return [Eq(coord.idx, node) for node in nodes]

    conditions_per_coord = {}
    for coord, (left_cond, right_cond) in domain.items():
        conds = [left_cond & right_cond]
        try:
            conds += equalities(coord, np.arange(coord.idx.lower, left_cond.rhs))
        except AttributeError:
            pass
        try:
            right_nodes = (
                np.arange(
                    right_cond.rhs - coord.N + 1, coord.idx.upper - coord.N + 1
                )
                + coord.N
            )
            conds += equalities(coord, right_nodes)
        except AttributeError:
            pass
        conditions_per_coord[coord] = conds
    return conditions_per_coord
@attr.s
class PDEquation:
    """A single PDE, its symbolic form and its finite-difference discretization.

    ``__attrs_post_init__`` does the following:

    - completes the coordinates from the unknowns;
    - sympifies ``equation`` (with ``subs`` and ``custom_functions``
      available) into ``symbolic_equation``;
    - discretizes it with ``schemes`` into ``fdiff``.

    With ``raw=True`` the equation string is taken as already discrete. It is
    sympified straight into ``fdiff`` and the symbolic steps are skipped.

    Calling the instance with a node (one index value per coordinate) returns
    ``fdiff`` with the coordinate indices substituted by those values.
    """

    equation = attr.ib(type=str)
    unknowns = attr.ib(type=typing.Sequence[Unknown], converter=_convert_unk_list)
    # ``factory=list`` rather than ``default=[]``: a literal default is one list
    # object shared by every instance (attrs hands it to the converter as is).
    parameters = attr.ib(
        type=typing.Sequence[Unknown], converter=_convert_unk_list, factory=list
    )
    subs = attr.ib(type=dict, factory=dict)
    custom_functions = attr.ib(type=dict, factory=dict)
    boundary_conditions = attr.ib(type=dict, factory=dict)
    schemes = attr.ib(
        type=typing.Sequence[FiniteDifferenceScheme],
        default=(FiniteDifferenceScheme(),),
        repr=False,
    )
    # Computed in __attrs_post_init__; not constructor arguments.
    symbolic_equation = attr.ib(init=False, repr=False)
    fdiff = attr.ib(init=False, repr=False)
    raw = attr.ib(type=bool, default=False, repr=False)
    # Nodes whose discrete equation is replaced by 0 in ``piecewise_system``.
    dirichlet_nodes = attr.ib(
        type=typing.Sequence[typing.Tuple[int, ...]], factory=list, repr=False
    )

    def __attrs_post_init__(self):
        self._t = Symbol("t")
        logger.info("complete coordinate")
        self._complete_coordinates()
        if self.raw:
            # For "raw" equations already in discretized form as periodic bc
            self.fdiff = sympify(
                self.equation, locals={unk.name: unk.discrete for unk in self.unknowns}
            )
            return
        logger.info("fill incomplete unks")
        self._fill_incomplete_unknowns()
        logger.info("build sympy namespace")
        self._sympy_namespace = _build_sympy_namespace(
            self.equation,
            self.coordinates,
            self.unknowns,
            self.parameters,
            self.custom_functions,
        )
        sympified_subs = {}
        logger.info("deal with subs")
        # Sympify each auxiliary definition, first adding to the namespace the
        # derivatives found in its own text.
        for subs_key, subs_value in self.subs.items():
            local_namespace = _build_sympy_namespace(
                subs_value,
                self.coordinates,
                self.unknowns,
                self.parameters,
                self.custom_functions,
            )
            self._sympy_namespace.update(local_namespace)
            sympified_subs[subs_key] = sympify(subs_value, locals=self._sympy_namespace)
        # A single substitution pass, so a definition may use another one.
        sympified_subs = {
            str(key): value.subs(sympified_subs)
            for key, value in sympified_subs.items()
        }
        # NOTE: ``dict(**a, **b)`` raises TypeError if a subs key is already a
        # namespace name.
        self._sympy_namespace = dict(**self._sympy_namespace, **sympified_subs)
        logger.info("translate to symbolic equation")
        self._sympify_equation()
        self.symbolic_equation = self.symbolic_equation.subs(sympified_subs)
        self._check_symbolic_equation(self.symbolic_equation)
        logger.info("translate to discrete equation")
        self._as_finite_diff()
        self._reduce_coordinates()

    def _check_symbolic_equation(self, symbolic_equation):
        """Raise ValueError if *symbolic_equation* uses an unknown name.

        Allowed names are:

        - the symbols of the unknowns, coordinates and parameters;
        - ``t`` and ``upwind``;
        - ``D...`` names found by the regex below;
        - the keys of sympy's lambdify ``scipy`` namespace.
        """
        available_symbols = [
            str(var.symbol)
            for var in [*self.unknowns, *self.coordinates, *self.parameters]
        ]
        available_symbols.append("t")
        available_symbols.append("upwind")
        # NOTE(review): the lazy ``\w+?`` captures "D" plus one character, so a
        # multi-letter name such as ``Dxx`` is not added here and is reported
        # as missing below.
        available_symbols.extend(set(re.findall(r"(D\w+?)", str(symbolic_equation))))
        available_symbols.extend(MODULES["scipy"][0].keys())
        orphaned_symbols = {
            str(token) for token in symbolic_equation.atoms(Symbol)
        }.difference(available_symbols)
        orphaned_functions = {
            str(token.func) for token in symbolic_equation.atoms(Function)
        }.difference(available_symbols)
        orphans = {*orphaned_symbols, *orphaned_functions}
        if orphans:
            raise ValueError(
                f"One or more symbols ({', '.join(orphans)}) are missing. "
                "You may have forgotten to include them into the unknowns, "
                "parameters or substitution."
            )

    def __getstate__(self):
        # Each attribute is serialized separately with cloudpickle.
        return {key: cloudpickle.dumps(value) for key, value in self.__dict__.items()}

    def __setstate__(self, state):
        self.__dict__.update(
            {key: cloudpickle.loads(value) for key, value in state.items()}
        )

    @property
    def parsed_boundary_conditions(self):
        """``{(variable, axis): BoundaryCondition}``, rebuilt on every access."""
        return dict(self._parse_bc(self.boundary_conditions))

    def _reduce_coordinates(self):
        """Keep only the coordinates of the unknowns actually present in ``fdiff``."""
        mapper = self.mapper  # property: build it once, not once per Indexed
        variables = [mapper[str(indexed.base)] for indexed in self.fdiff.atoms(Indexed)]
        available_unks = [
            variable for variable in variables if isinstance(variable, Unknown)
        ]
        real_coordinates = set().union(*[unk.coordinates for unk in available_unks])
        self.coordinates = sorted(real_coordinates)

    def _non_expendable_deriv(self, wrts, arg):
        """Discretize the kept derivative of *arg* along *wrts*.

        The schemes are applied to the derivative of a placeholder function,
        and each placeholder application is then replaced by *arg*, with the
        coordinate names substituted by the application's arguments.
        """
        arg = sympify(arg)
        dummy = Function(str(Dummy()))
        dummy_deriv = Derivative(dummy(*self.coordinates), *wrts)
        fdiff_dummy = chain_schemes(self.schemes, dummy_deriv)

        def replace_dummy(*coords):
            return arg.subs(
                {
                    symbolic_coord.name: coord
                    for symbolic_coord, coord in zip(self.coordinates, coords)
                }
            )

        return fdiff_dummy.replace(dummy, replace_dummy)

    @property
    def stencils(self):
        """Boolean stencil array per unknown name, read from ``fdiff`` at node 0."""
        stencils = {}
        ref_pde = self([0] * len(self.coordinates))
        indexed = ref_pde.atoms(Indexed)
        for unk in self.unknowns:
            coords = set(
                [
                    tuple(map(int, index.indices))
                    for index in indexed
                    if unk.discrete == index.args[0]
                ]
            )
            # ``np.ptp`` rather than ``ndarray.ptp``: the method was removed in
            # NumPy 2.0, the function exists in both 1.x and 2.x.
            ptp = np.ptp(np.array(list(coords))) + 1
            center = ptp // 2
            stencil_array = np.zeros([ptp] * len(unk.coords), dtype=bool)
            idxs = [np.array(coord + center, dtype=int) for coord in zip(*coords)]
            stencil_array[tuple(idxs)] = True
            stencils[unk.name] = stencil_array
        return stencils

    def _fill_incomplete_unknowns(self):
        """fill every dependent variables that lack information on
        independent variables with the global independent variables
        """
        for i, unk in enumerate(self.unknowns):
            if not unk.coordinates:
                # object.__setattr__ bypasses a possibly frozen Unknown; note
                # that it mutates the (possibly shared) Unknown in place.
                object.__setattr__(self.unknowns[i], "coordinates", self.coordinates)

    def _complete_coordinates(self):
        """if independent variables is not set, extract them from
        the dependent variables. If not set in dependent variables,
        1D resolution with "x" as independent variable is assumed.
        """
        harvested_coords = list(
            chain(
                *[
                    dep_var.coordinates
                    for dep_var in self.unknowns
                    if dep_var.coordinates is not None
                ]
            )
        )
        if harvested_coords:
            self.coordinates = harvested_coords
        else:
            self.coordinates = _convert_coord_list(["x"])
        # Drop duplicates while keeping first-seen order.
        self.coordinates = list(unique_everseen(self.coordinates))

    def _sympify_equation(self):
        """Parse ``equation`` with the built namespace into ``symbolic_equation``."""
        self.symbolic_equation = sympify(self.equation, locals=self._sympy_namespace)
        for unk in self.unknowns:
            self.symbolic_equation = self.symbolic_equation.subs(unk.name, unk.symbol)

    def _as_finite_diff(self):
        """Turn ``symbolic_equation`` into the discrete expression ``fdiff``."""
        logger.debug("expand derivs")
        fdiff = self.symbolic_equation.doit()
        logger.debug("apply finite difference")
        fdiff = chain_schemes(self.schemes, fdiff)
        logger.debug("replace upwinds")
        fdiff = fdiff.replace(Function("upwind"), upwind)
        logger.debug("deal with kept-in-form derivatives")
        to_keep_derivs = _build_keep_derivs_namespace(fdiff)
        for wrts, _ in to_keep_derivs:
            fdiff = fdiff.replace(
                Function("D%s" % "".join(wrts)),
                partial(self._non_expendable_deriv, wrts),
            )
        logger.debug("replace continuous coordinate to discrete idx.")
        for coord in self.coordinates:
            # ``x + a*dx`` inside function arguments becomes ``x_idx + a``.
            a = Wild("a", exclude=[coord.step, coord.symbol, 0])
            for func in fdiff.atoms(Function):
                new_func = func.replace(coord.symbol + a * coord.step, coord.idx + a)
                fdiff = fdiff.subs(func, new_func)
        logger.debug("Replace symbols to indexed.")
        for var in chain(self.unknowns, self.parameters):

            # Closure over ``var``: only used within this iteration.
            def replacement(*args):
                return var.discrete[args]

            if var.coordinates:
                for func in fdiff.atoms(var.symbol):
                    new = func.replace(var.symbol, replacement)
                    fdiff = fdiff.subs(func, new)
        logger.debug("Replace indices.")
        for indexed in fdiff.atoms(Indexed):
            new_indexed = indexed.subs(
                {coord.symbol: coord.idx for coord in self.coordinates}
            )
            fdiff = fdiff.subs(indexed, new_indexed)
        fdiff = fdiff.subs(
            {coord.symbol: coord.discrete[coord.idx] for coord in self.coordinates}
        )
        self.fdiff = fdiff

    def __call__(self, coordinates):
        """Return ``fdiff`` with each coordinate index replaced by *coordinates*."""
        logger.trace("evaluate pde %s at %s" % (self.equation, coordinates))
        subs = {
            coord.idx: coord_value
            for coord, coord_value in zip(self.coordinates, coordinates)
        }
        logger.trace("subs: %s" % subs)
        return self.fdiff.subs(subs)

    def __str__(self):
        return self.__repr__()

    def domains(self, *node_coords):
        """Domain of each node coordinate, as given by ``Coordinate.domain``."""
        return tuple(
            [
                coord.domain(node_coord)
                for coord, node_coord in zip(self.coordinates, node_coords)
            ]
        )

    @property
    def unknowns_dict(self):
        return {unk.name: unk for unk in self.unknowns}

    @property
    def coordinates_dict(self):
        return {
            coord.name: coord
            for coord in set(chain(*[unk.coords for unk in self.unknowns]))
        }

    @property
    def parameters_dict(self):
        return {par.name: par for par in self.parameters}

    @property
    def mapper(self):
        """Name -> unknown / parameter / coordinate (raises on duplicate names)."""
        return dict(
            **self.unknowns_dict, **self.parameters_dict, **self.coordinates_dict
        )

    @property
    def physical_domains(self):
        """``{("left"|"bulk"|"right", ...): condition}`` for every side combination."""
        domains = product(*[("left", "bulk", "right") for coord in self.coordinates])
        conds = product(
            *[
                (
                    Eq(coord.idx, coord.idx.lower),
                    (coord.idx.lower < coord.idx) & (coord.idx < coord.idx.upper),
                    Eq(coord.idx, coord.idx.upper),
                )
                for coord in self.coordinates
            ]
        )
        return dict(zip(domains, [reduce(and_, cond) for cond in conds]))

    @property
    def node_coords(self):
        """One representative node per side combination (same order as
        ``physical_domains``)."""
        domains = product(*[("left", "bulk", "right") for coord in self.coordinates])
        return [
            tuple(
                coord.get_node_coord(side)
                for coord, side in zip(self.coordinates, domain)
            )
            for domain in domains
        ]

    @property
    def node_coords_subs(self):
        return [
            {
                coord.idx: node_coord
                for coord, node_coord in zip(self.coordinates, node_coords)
            }
            for node_coords in self.node_coords
        ]

    def get_domains(self, indexed):
        unk = self.mapper[str(indexed.base)]
        return unk.domains(*indexed.indices)

    def get_distances(self, indexed):
        unk = self.mapper[str(indexed.base)]
        return unk.distances(*indexed.indices)

    def is_in_bulk(self, indexed):
        return not self.is_outside(indexed)

    def is_outside(self, indexed):
        # True when any distance to the domain is non-zero.
        return any(self.get_distances(indexed))

    def _distance_node_to_other(self, node, other_node):
        return tuple(np.abs(np.array(node) - np.array(other_node)))

    def _walk_in_domain(self, node, domains):
        """Yield the extra nodes needed near *node* in side combination *domains*.

        Step from *node* one index at a time (per ``sign_distance``: +1 on
        "left", -1 on "right", 0 on "bulk") until the evaluated equation has no
        indexed value outside the domain. ``distance`` ends as the number of
        steps taken. Then yield every node offset by 0..distance-1 steps on
        each axis.
        NOTE(review): the walk never terminates if stepping never brings every
        indexed value inside the domain.
        """
        next_node = tuple(node)
        for distance in count():
            local_pde = self(next_node)
            outside_indexed = list(filter(self.is_outside, local_pde.atoms(Indexed)))
            if not outside_indexed:
                break
            next_node = [
                idx + sign_distance(1, domain)
                for idx, domain in zip(next_node, domains)
            ]
        product_distances = product(*[range(distance)] * len(node))
        for distances in product_distances:
            yield tuple(
                [
                    idx + sign_distance(distance, domain)
                    for idx, domain, distance in zip(node, domains, distances)
                ]
            )

    def _parse_bc(self, bcs):
        """Yield ``((variable, axis), BoundaryCondition)`` for every variable/axis.

        *bcs* is either one spec applied everywhere (a string) or a dict keyed
        by ``(variable_name, axis_name)``; missing keys default to "noflux".
        """
        if isinstance(bcs, str):
            default_bcs = defaultdict(lambda: bcs)
        else:
            default_bcs = defaultdict(lambda: "noflux", bcs)
        unks = [*self.unknowns, *self.parameters]
        for unk, axis in chain(*[product([unk], unk.coords) for unk in unks]):
            yield (unk, axis), BoundaryCondition(
                unk, axis, self, default_bcs[(unk.name, axis.name)]
            )

    # long running, should not be in property
    @cached_property
    def computation_nodes(self):
        """Representative nodes plus all nodes found by ``_walk_in_domain``."""
        extended_computation_nodes = []
        for new_node in chain(
            *[
                self._walk_in_domain(node, domains)
                for node, domains in zip(self.node_coords, self.physical_domains)
            ]
        ):
            extended_computation_nodes.append(new_node)
        return list(set(extended_computation_nodes).union(self.node_coords))

    def _get_filtered_cross(self, coord, node_coords):
        """Yield the *coord* component of computation nodes on the same "cross".

        These are the computation nodes whose other components match
        *node_coords*, and whose *coord* component does not contain
        ``coord.idx``.
        """
        coord_idx = self.coordinates.index(coord)
        cross_node_coords = tuple(
            [node for i, node in enumerate(node_coords) if i != coord_idx]
        )
        cross_computation_nodes = [
            tuple([node for i, node in enumerate(nodes) if i != coord_idx])
            for nodes in self.computation_nodes
        ]
        for computation_node_coords, cross_computation_node in zip(
            self.computation_nodes, cross_computation_nodes
        ):
            ziped_crosses = list(zip(cross_node_coords, cross_computation_node))
            same_nodes = [
                node == computation_node for node, computation_node in ziped_crosses
            ]
            bulk_nodes = [
                str(coord.idx) in map(str, computation_node.atoms(Indexed))
                for node, computation_node in ziped_crosses
            ]
            relevant_nodes = [
                (same_node or bulk_node)
                for same_node, bulk_node in zip(same_nodes, bulk_nodes)
            ]
            this_node = computation_node_coords[coord_idx]
            if all(relevant_nodes) and str(coord.idx) not in map(
                str, this_node.atoms()
            ):
                yield this_node

    def _node_to_domain(self, node_coords):
        """Build the condition selecting the nodes represented by *node_coords*.

        A component without ``coord.idx`` gives ``Eq(idx, value)``. A
        component containing it gives the open interval between the closest
        fixed nodes of the same cross, on the left and on the right.
        Raises ValueError (from ``max``) if no left node exists.
        """
        domain = []
        for coord, node_coord in zip(self.coordinates, node_coords):
            if str(coord.idx) not in map(str, node_coord.atoms()):
                domain.append(Eq(coord.idx, node_coord))
            else:
                # get all the nodes on the same "cross"
                same_cross_nodes = list(self._get_filtered_cross(coord, node_coords))
                right_nodes = set(
                    [
                        node
                        for node in same_cross_nodes
                        if str(coord.N) in map(str, node.atoms())
                    ]
                )
                left_nodes = set(same_cross_nodes) - right_nodes
                left_node = max(left_nodes)
                right_node = (
                    min([node - coord.idx.upper for node in right_nodes])
                    + coord.idx.upper
                )
                domain.append((left_node < coord.idx) & (coord.idx < right_node))
        return reduce(and_, domain)

    @property
    def computation_domains(self):
        """Condition of each computation node (same order as ``computation_nodes``)."""
        domains = []
        for coord_nodes in self.computation_nodes:
            domains.append(self._node_to_domain(coord_nodes))
        return domains

    def get_ghost_equation(self, ghost, node):
        """Return the boundary equation that determines the ghost value *ghost*.

        Only the axes on which *ghost* lies outside the domain are considered.
        NOTE(review): both branches of the final loop return, so only the
        first such axis is used. If no axis is outside, the ``zip(*[])``
        unpacking raises ValueError.
        """
        unk = self.mapper[str(ghost.base)]
        indices = ghost.indices
        # tuple(): these are iterated more than once below.
        all_domains = tuple(self.get_domains(ghost))
        distances = self.get_distances(ghost)
        coords, domains, distances = zip(
            *[
                (coord, domain, distance)
                for coord, domain, distance in zip(unk.coords, all_domains, distances)
                if distance > 0
            ]
        )
        # The property rebuilds every BoundaryCondition; fetch it once.
        parsed_bcs = self.parsed_boundary_conditions
        boundaries = [parsed_bcs[unk, coord] for coord in coords]
        domains_by_coord = dict(zip(unk.coords, all_domains))
        eqs, kinds = zip(
            *[
                (
                    bc.get(
                        domain,
                        domains=domains_by_coord,
                        evaluation_node=ghost,
                        offset=-distance,
                    ),
                    getattr(bc, "kind_%s" % domain),
                )
                for bc, domain, distance in zip(boundaries, domains, distances)
            ]
        )
        for eq, kind, domain, coord in zip(eqs, kinds, domains, coords):
            if kind != "periodic":
                # Evaluate at the boundary node itself along ``coord``.
                eval_indices = [
                    index
                    if coord_idx != coord
                    else (coord.idx.lower if domain == "left" else coord.idx.upper)
                    for index, coord_idx in zip(indices, unk.coordinates)
                ]
                return eq(eval_indices)
            else:
                return eq(indices)

    def _extrapolate_coords(self, local_eq):
        """Replace coordinate values indexed outside the domain by a linear
        extrapolation from the nearest bound (using ``coord.step``)."""
        mapper = self.mapper
        outside_coords = [
            idx
            for idx in filter(self.is_outside, local_eq.atoms(Indexed))
            if isinstance(mapper[str(idx.base)], Coordinate)
        ]
        for indexed_coord in outside_coords:
            coord = mapper[str(indexed_coord.base)]
            index = indexed_coord.indices[0]
            distance = coord.distance_from_domain(index)
            domain = coord.domain(index)
            if domain == "left":
                rhs = coord.discrete[coord.idx.lower] - coord.step * distance
            else:
                rhs = coord.discrete[coord.idx.upper] + coord.step * distance
            logger.debug(
                "extrapolate outside coordinate, subs: {%s: %s}" % (indexed_coord, rhs)
            )
            local_eq = local_eq.subs(indexed_coord, rhs)
        return local_eq

    def _get_ghosts(self, local_eq):
        """Indexed unknown values of *local_eq* lying outside the domain."""
        mapper = self.mapper
        return [
            idx
            for idx in filter(self.is_outside, local_eq.atoms(Indexed))
            if isinstance(mapper[str(idx.base)], Unknown)
        ]

    # long running, should not be in property
    @cached_property
    def piecewise_system(self):
        """Discrete equation of each computation node, ghosts eliminated.

        Dirichlet nodes get ``Number(0)``. Elsewhere, ghost values are solved
        from their boundary equations and substituted repeatedly until none
        remain. Raises IndexError if ``solve`` finds no solution.
        """
        piecewise_system = []
        for node_coord, domain in zip(self.computation_nodes, self.computation_domains):
            if node_coord in self.dirichlet_nodes:
                piecewise_system.append(Number(0))
                continue
            local_eq = self._extrapolate_coords(self(node_coord))
            ghosts = self._get_ghosts(local_eq)
            while ghosts:
                logger.debug("ghosts: %s" % ghosts)
                eqs = [self.get_ghost_equation(ghost, node_coord) for ghost in ghosts]
                solved = solve(eqs, ghosts, dict=True)
                local_eq = self._extrapolate_coords(local_eq.subs(solved[0]))
                ghosts = self._get_ghosts(local_eq)
            piecewise_system.append(local_eq)
        return tuple(piecewise_system)
def sign_distance(distance, domain):
    """Signed step that moves an index from the *domain* side toward the bulk.

    Returns ``0`` for ``"bulk"`` and ``+distance`` for ``"left"``. Any other
    value (i.e. ``"right"``) gives ``-distance``.
    """
    if domain == "bulk":
        return 0
    return distance if domain == "left" else -distance
def dict_product(d):
    """Yield every combination of *d*'s values as a dict with *d*'s keys.

    This is the Cartesian product in ``itertools.product`` order; for example,
    ``dict_product({"a": [1, 2], "b": [3]})`` yields ``{"a": 1, "b": 3}``, then
    ``{"a": 2, "b": 3}``. An empty *d* yields a single empty dict.
    """
    yield from (dict(zip(d, combination)) for combination in product(*d.values()))
@attr.s
class PDESys:
    """System of PDEs, one evolution equation per unknown (paired by position).

    Builds one ``PDEquation`` per equation string and records their Dirichlet
    nodes. It also exposes the concatenated ``piecewise_system`` and, per row,
    the Jacobian entries (``jacobian_values``) with their positions
    (``jacobian_columns``).
    """

    evolution_equations = attr.ib(
        type=typing.Sequence[str], converter=_convert_pde_list
    )
    unknowns = attr.ib(type=typing.Sequence[Unknown], converter=_convert_unk_list)
    parameters = attr.ib(
        type=typing.Sequence[Unknown], converter=_convert_unk_list, factory=list
    )
    # ``factory=list`` instead of a shared ``default=[]``. The value is
    # overwritten in ``_coerce_equations`` anyway.
    coordinates = attr.ib(
        type=typing.Sequence[Coordinate],
        factory=list,
        converter=_convert_coord_list,
        repr=False,
    )
    boundary_conditions = attr.ib(type=dict, factory=dict)
    subs = attr.ib(type=dict, factory=dict)
    custom_functions = attr.ib(type=dict, factory=dict)

    def _coerce_equations(self):
        """Replace the equation strings by ``PDEquation`` objects.

        Then mark their Dirichlet nodes, and set ``coordinates`` to the sorted
        union of the equations' coordinates.
        """
        # Iterate the user's sequence directly: it may be a tuple, which has no
        # ``.copy()``. The attribute is rebound only once the list is built.
        self.evolution_equations = [
            PDEquation(
                eq,
                self.unknowns,
                self.parameters,
                subs=self.subs,
                custom_functions=self.custom_functions,
                boundary_conditions=self.boundary_conditions,
            )
            for eq in self.evolution_equations
        ]
        self._apply_dirichlet()
        self.coordinates = sorted(
            set(chain(*[eq.coordinates for eq in self.evolution_equations]))
        )

    def _apply_dirichlet(self):
        """Add to each PDE's ``dirichlet_nodes`` the physical nodes that touch a
        side whose boundary kind is "Dirichlet"."""
        for unk, pde in zip(self.unknowns, self.evolution_equations):
            # ``parsed_boundary_conditions`` rebuilds every BoundaryCondition on
            # each access, and ``bcs`` does not depend on the node: compute it
            # once per PDE.
            parsed_bcs = pde.parsed_boundary_conditions
            bcs = [parsed_bcs[(unk, coord)] for coord in unk.coordinates]
            for domains, coord_node in zip(pde.physical_domains, pde.node_coords):
                kinds = [
                    getattr(bc, "kind_%s" % domain)
                    for bc, domain in zip(bcs, domains)
                    if domain != "bulk"
                ]
                if "Dirichlet" in kinds:
                    pde.dirichlet_nodes.append(coord_node)

    def _get_shapes(self):
        """Compute the grid shape and size of each unknown, and the total size.

        Axes an unknown does not depend on have length 1.
        """
        self.pivot = None
        gridinfo = dict(zip(self.coordinates, [coord.N for coord in self.coordinates]))
        shapes = []
        for unk in self.unknowns:
            gridshape = [
                (gridinfo[coord] if coord in unk.coordinates else 1)
                for coord in self.coordinates
            ]
            shapes.append(tuple(gridshape))
        sizes = [reduce(mul, shape) for shape in shapes]
        self.size = sum(sizes)
        self.shapes = dict(zip(self.unknowns, shapes))
        self.sizes = dict(zip(self.unknowns, sizes))

    def __attrs_post_init__(self):
        logger.info("processing pde system")
        logger.info("coerce equations...")
        self._t = Symbol("t")
        self._coerce_equations()
        self._get_shapes()
        logger.info("done")

    def _sort_indexed(self, indexed):
        """``[unknown position, index along each system coordinate]`` of *indexed*
        (0 for the coordinates its unknown does not depend on)."""
        unk_idx = [unk.name for unk in self.unknowns].index(str(indexed.args[0]))
        coords = self.unknowns[unk_idx].coordinates
        idxs = indexed.args[1:]
        idxs = [
            idxs[coords.index(coord)] if coord in coords else 0
            for coord in self.coordinates
        ]
        return [unk_idx, *idxs]

    def _filter_unk_indexed(self, indexed):
        return indexed.base in [unk.discrete for unk in self.unknowns]

    def _simplify_kron(self, *kron_args):
        # Grid sizes are sent to infinity so deltas involving them simplify.
        kron = KroneckerDelta(*kron_args)
        return kron.subs({coord.N: oo for coord in self.coordinates})

    @cached_property
    def piecewise_system(self):
        return list(chain(*(pde.piecewise_system for pde in self)))

    @cached_property
    def jacobian_values(self):
        # Same ``wrts`` construction as ``jacobian_columns`` so both stay
        # aligned entry by entry.
        jacobian_values = []
        for expr in self.piecewise_system:
            wrts = list(filter(self._filter_unk_indexed, expr.atoms(Indexed)))
            diffs = [
                expr.diff(wrt).replace(KroneckerDelta, self._simplify_kron).n()
                for wrt in wrts
            ]
            jacobian_values.append(diffs)
        return jacobian_values

    @cached_property
    def jacobian_columns(self):
        jacobian_columns = []
        for expr in self.piecewise_system:
            wrts = list(filter(self._filter_unk_indexed, expr.atoms(Indexed)))
            grids = list(map(self._sort_indexed, wrts))
            jacobian_columns.append(grids)
        return jacobian_columns

    @property
    def unknowns_dict(self):
        return {unk.name: unk for unk in self.unknowns}

    @property
    def coordinates_dict(self):
        return {
            coord.name: coord
            for coord in set(chain(*[unk.coords for unk in self.unknowns]))
        }

    @property
    def parameters_dict(self):
        return {par.name: par for par in self.parameters}

    @property
    def computation_domains(self):
        return [pde.computation_domains for pde in self]

    @property
    def mapper(self):
        return dict(
            **self.unknowns_dict, **self.parameters_dict, **self.coordinates_dict
        )

    @property
    def equation_dict(self):
        return {
            unk.name: equation
            for unk, equation in zip(self.unknowns, self.evolution_equations)
        }

    def __getitem__(self, key):
        """Equation by position (int) or by unknown name (str)."""
        if isinstance(key, int):
            return self.evolution_equations[key]
        if isinstance(key, str):
            return self.equation_dict[key]
        raise KeyError(key)

    def __iter__(self):
        """Iterate over the ``PDEquation`` objects, in order."""
        return iter(self.evolution_equations)
def target_axis(derivative, wrt, axis):
    """Tell whether *wrt* is the symbol of *axis*.

    Used with ``axis`` bound by ``functools.partial`` as the ``pattern`` of a
    ``FiniteDifferenceScheme``; *derivative* is not used here.
    """
    axis_symbol = axis.symbol
    return wrt == axis_symbol
@attr.s
class BoundaryCondition:
    """Boundary condition of one unknown along one axis, for both sides.

    ``bcs`` may take one of these forms:

    - ``"periodic"``;
    - ``"noflux"``;
    - a ``(left, right)`` pair whose items are each an equation string,
      ``"noflux"``/``None`` (replaced by the equation ``d<axis><unknown>``),
      or ``"dirichlet"``.

    After init, ``left``/``right`` hold the per-side equation, and
    ``kind_left``/``kind_right`` hold one of ``"periodic"``, ``"Dirichlet"``
    or ``"Ghost"``.
    """

    unknown = attr.ib(type=Unknown, converter=Unknown)
    axis = attr.ib(type=Coordinate, converter=Coordinate)
    # NOTE(review): this was written ``attr.ib(PDEquation)``, whose first
    # positional argument is the *default*: ``pde`` defaults to the PDEquation
    # class itself. Kept as is (explicitly) for backward compatibility.
    pde = attr.ib(default=PDEquation)
    bcs = attr.ib(type=typing.Optional[typing.Union[str, typing.Tuple[str, str]]])
    left = attr.ib(type=PDEquation, init=False, repr=False)
    right = attr.ib(type=PDEquation, init=False, repr=False)
    kind_left = attr.ib(type=str, init=False)
    kind_right = attr.ib(type=typing.Union[typing.Tuple[str, str], str], init=False)

    @bcs.default
    def noflux_default(self):
        """Default ``bcs``: "noflux" on both sides."""
        return ("noflux", "noflux")

    @bcs.validator
    def bcs_validator(self, attribute, value):
        """Check that ``bcs`` is a spec ``_build_bcs`` can handle.

        Accepted values are: a falsy value (turned into ``"noflux"`` by
        ``__attrs_post_init__``), ``"periodic"``, ``"noflux"``, or a
        non-string sequence of exactly two per-side specs.

        Raises
        ------
        ValueError
            Any other string, or a sequence whose length is not 2.
        TypeError
            A value that is neither a string nor sized.
        """
        # attrs ignores a validator's return value: invalid input must raise.
        if not value or value in ("periodic", "noflux"):
            return True
        if isinstance(value, str):
            raise ValueError(
                "%s must be 'periodic', 'noflux' or a (left, right) pair, got %r"
                % (attribute.name, value)
            )
        try:
            size = len(value)
        except TypeError:
            raise TypeError(
                "%s must be a string or a (left, right) pair, got %r"
                % (attribute.name, value)
            ) from None
        if size != 2:
            raise ValueError(
                "%s must be a (left, right) pair, got %d items: %r"
                % (attribute.name, size, value)
            )
        return True

    def _build_bcs(self, bcs):
        """Return ``((left_eq, left_kind), (right_eq, right_kind))``."""
        if bcs == "periodic":
            left = self._build_periodic("left")
            right = self._build_periodic("right")
            return (left, "periodic"), (right, "periodic")
        if bcs == "noflux":
            bcs = ("noflux", "noflux")
        return (self._build_bc(bc) for bc in bcs)

    def _build_bc(self, eq):
        """Return ``(equation, kind)`` for one side spec.

        ``None``/"noflux" and "dirichlet" both use the equation
        ``d<axis><unknown>``. The kind is "Dirichlet" for "dirichlet" and
        "Ghost" otherwise.
        """
        kind = None
        if eq is None or eq == "noflux":
            eq = "d%s%s" % (self.axis.name, self.unknown.name)
        if eq == "dirichlet":
            eq = "d%s%s" % (self.axis.name, self.unknown.name)
            kind = "Dirichlet"
        return eq, kind or "Ghost"

    def _build_periodic(self, side):
        """Discrete periodicity relation ``u_i - u_{i±N}`` for *side*."""
        unk = self.unknown
        axis = self.axis
        if side == "left":
            substitute = axis.idx + axis.N
        else:
            substitute = axis.idx - axis.N
        return unk.discrete_i - unk.discrete_i.subs(axis.idx, substitute)

    def __attrs_post_init__(self):
        self.bcs = self.bcs or "noflux"
        (self.left, self.kind_left), (self.right, self.kind_right) = self._build_bcs(
            self.bcs
        )

    def evaluate_side_patterns(self, node, coord, domain):
        """Return an extra one-sided scheme along *coord*, or ``None``.

        No scheme is returned in these cases: *coord* is this condition's own
        axis, *node* is exactly ``coord.idx``, or *node* is in the bulk
        without depending on ``coord.idx``. A bulk node offset from
        ``coord.idx`` gets "left" for a positive offset and "right" otherwise.
        A node on the "right" side gets "left", otherwise it gets "right".
        """
        if coord == self.axis:
            return
        if str(node) == str(coord.idx):
            return
        if domain == "bulk" and str(coord.idx) in map(str, node.atoms(Symbol)):
            this_offset = node - coord.idx
            if this_offset > 0:
                scheme = "left"
            else:
                scheme = "right"
        elif domain == "bulk":
            return
        else:
            if domain == "right":
                scheme = "left"
            else:
                scheme = "right"
        return FiniteDifferenceScheme(
            scheme=scheme, pattern=partial(target_axis, axis=coord)
        )

    def get(self, side, domains, evaluation_node, offset=None):
        """Return the ``PDEquation`` of this condition on *side*.

        For a periodic side, this is the raw periodicity relation. Otherwise
        the side equation is discretized with a one-sided scheme along the
        axis ("right" on the left side, "left" on the right side), shifted by
        *offset* (default 1, with its sign flipped on the right side). The
        extra schemes from ``evaluate_side_patterns`` are added for the other
        axes of *evaluation_node*, and the scheme list is passed in reverse
        order.
        """
        if getattr(self, "kind_%s" % side) == "periodic":
            return PDEquation(self._build_periodic(side), self.unknown, raw=True)
        if offset is None:
            offset = 1
        if side == "left":
            scheme = "right"
            main_offset = offset
        else:
            scheme = "left"
            main_offset = -offset
        main_scheme = FiniteDifferenceScheme(
            scheme=scheme,
            offset=main_offset,
            pattern=partial(target_axis, axis=self.axis),
        )
        schemes = [main_scheme]
        for node, (coord, domain) in zip(evaluation_node.indices, domains.items()):
            scheme = self.evaluate_side_patterns(node, coord, domain)
            if scheme:
                schemes.append(scheme)
        return PDEquation(
            getattr(self, side),
            self.pde.unknowns,
            self.pde.parameters,
            schemes=schemes[::-1],
        )