Skip to content

Solvers

Interface to nonlinear and ODE solvers

factory(module, name) ¤

Factory method for the solvers

Parameters:

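| Name   | Type | Description                                                  | Default  |
|--------|------|--------------------------------------------------------------|----------|
| module | str  | Name of library providing solvers (diffrax, runge_kutta...)  | required |
| name   | str  | Name of function to be used (ode, newton...)                 | required |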

Returns:

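| Type                 | Description                                                                                                                               |
|----------------------|-------------------------------------------------------------------------------------------------------------------------------------------|
| (Callable, Callable) | Two functions, in this order: one that extracts the states (qs) from a solution object (`pull_<name>`), and one that builds the solution object (`<name>`) |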

Source code in feniax/systems/sollibs/__init__.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def factory(module: str, name: str) -> tuple[Callable, Callable]:
    """Factory method for the solvers.

    Imports the solver submodule ``module`` relative to this package and
    looks up both the solver function ``name`` and its companion state
    extractor ``pull_<name>``.

    Parameters
    ----------
    module : str
        Name of library providing solvers (diffrax, runge_kutta...)
    name : str
        Name of function to be used (ode, newton...)

    Returns
    -------
    (Callable, Callable)
        Two functions, in this order: ``pull_<name>``, which extracts the
        states (qs) from a solution object, and ``<name>``, which builds
        the solution object.

    Raises
    ------
    ImportError
        (``ModuleNotFoundError``) if ``module`` cannot be imported relative
        to this package.
    AttributeError
        If the submodule does not define both ``name`` and ``pull_<name>``.
    """

    library = importlib.import_module(f".{module}", __name__)
    try:
        function = getattr(library, name)
        states_puller = getattr(library, f"pull_{name}")
    except AttributeError as err:
        # Same exception type as before, but say which pair was expected.
        raise AttributeError(
            f"Solver library '{module}' must define both '{name}' and "
            f"'pull_{name}'"
        ) from err
    # NOTE: the extractor comes first; callers unpack in this order.
    return states_puller, function

Diffrax¤

ode(F, args, sett, q0, t0, t1, tn, dt, **kwargs) ¤

Diffrax ODE solver

Source code in feniax/systems/sollibs/diffrax.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def _instantiate_from_spec(namespace, spec):
    """Instantiate the class named by the first entry of a ``{ClassName: kwargs}`` spec.

    The class is looked up by name as an attribute of ``namespace`` (a module
    such as ``diffrax`` or ``optx``) and called with the entry's value as
    keyword arguments.
    """
    class_name, class_kwargs = list(spec.items())[0]
    return getattr(namespace, class_name)(**class_kwargs)


def ode(
    F: callable,
    args,
    sett,
    q0,
    t0,
    t1,
    tn,
    dt,
    **kwargs,
) -> diffrax.Solution:
    """Diffrax ODE solver.

    Wraps ``F`` in a ``diffrax.ODETerm`` and integrates it from ``t0`` to
    ``t1`` with ``diffrax.diffeqsolve``.

    Parameters
    ----------
    F : callable
        Vector field passed to ``diffrax.ODETerm`` (presumably following
        diffrax's ``F(t, y, args)`` signature).
    args :
        Forwarded unchanged as ``args`` to ``diffrax.diffeqsolve``.
    sett :
        Settings object. Reads ``save_at``, ``solver_name``, ``root_finder``,
        ``stepsize_controller`` and ``max_steps``. ``root_finder`` and
        ``stepsize_controller`` are either None or a ``{ClassName: kwargs}``
        dict. Only the first entry is used.
    q0 :
        Initial state (``y0``).
    t0, t1 :
        Start and end times of the integration.
    tn : int
        Number of equally spaced output times in [t0, t1], used only when
        ``sett.save_at`` is None.
    dt :
        Initial step size (``dt0``).
    **kwargs :
        Accepted for interface compatibility and ignored.

    Returns
    -------
    diffrax.Solution
        The object returned by ``diffrax.diffeqsolve``.
    """

    # TODO: Logic should be extended and improved

    term = diffrax.ODETerm(F)
    if sett.save_at is None:
        # Default output grid: tn equally spaced times over [t0, t1].
        saveat = diffrax.SaveAt(ts=jnp.linspace(t0, t1, tn))
    else:
        saveat = sett.save_at

    solver_cls = getattr(diffrax, sett.solver_name)
    solver_kwargs = {}
    if (root := sett.root_finder) is not None:
        # Root finders (e.g. Newton) are looked up in optimistix.
        solver_kwargs["root_finder"] = _instantiate_from_spec(optx, root)
    solver = solver_cls(**solver_kwargs)

    diffeqsolve_kwargs = {}
    if (stepsize := sett.stepsize_controller) is not None:
        # Step-size controllers (ConstantStepSize, PIDController, ...) are
        # diffrax classes; looking them up in optx always failed.
        diffeqsolve_kwargs["stepsize_controller"] = _instantiate_from_spec(
            diffrax, stepsize
        )

    sol = diffrax.diffeqsolve(
        term,
        solver,
        t0=t0,
        t1=t1,
        dt0=dt,
        y0=q0,
        args=args,
        max_steps=sett.max_steps,
        saveat=saveat,
        **diffeqsolve_kwargs,
    )
    return sol

pull_newton(sol) ¤

Extract states from diffrax Newton solution object

Source code in feniax/systems/sollibs/diffrax.py
108
109
110
111
112
113
def pull_newton(sol):
    """Return the states (qs) held by a Newton-type solution object.

    The solution's ``value`` attribute is converted to a JAX array.
    """
    return jnp.array(sol.value)

pull_ode(sol) ¤

Extract states from diffrax ODE solution object

Source code in feniax/systems/sollibs/diffrax.py
100
101
102
103
104
105
def pull_ode(sol):
    """Return the states (qs) stored in a diffrax ODE solution object.

    The solution's ``ys`` attribute is converted to a JAX array.
    """
    return jnp.array(sol.ys)

Runge-Kutta¤