Modes

axis_tilde(tensor)

Apply tilde0010 to a tensor

The input tensor is iterated over axis 2 first and over axis 1 subsequently; tilde0010 is applied to the vectors along axis 0.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tensor | ndarray | 3xN1xN2 tensor | required |

Returns:

| Type | Description |
|---|---|
| ndarray | 6x6xN1xN2 tensor |

Source code in feniax/intrinsic/modes.py
@jit
def axis_tilde(tensor: jnp.ndarray) -> jnp.ndarray:
    """Apply tilde0010 to a tensor

    The input tensor is iterated over axis 2 first and over axis 1
    subsequently; tilde0010 is applied to the vectors along axis 0.

    Parameters
    ----------
    tensor : jnp.ndarray
        3xN1xN2 tensor

    Returns
    -------
    jnp.ndarray
        6x6xN1xN2 tensor

    """

    f1 = jax.vmap(tilde0010, in_axes=1, out_axes=2)
    f2 = jax.vmap(f1, in_axes=2, out_axes=3)
    f = f2(tensor)

    return f
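
A minimal, self-contained sketch of what axis_tilde computes. The module's tilde helper is not shown on this page, so the tilde below is an assumed definition (the skew-symmetric cross-product matrix, tilde(a) @ b == cross(a, b)); tilde0010 and axis_tilde are copied from the listings without jit. The check confirms that slice out[:, :, i, j] is tilde0010 applied to the column tensor[:, i, j].

```python
import jax
import jax.numpy as jnp


def tilde(v):
    # Assumed definition: skew-symmetric matrix with tilde(a) @ b == cross(a, b)
    return jnp.array([[0.0, -v[2], v[1]],
                      [v[2], 0.0, -v[0]],
                      [-v[1], v[0], 0.0]])


def tilde0010(vector):
    # tilde(vector) in rows 3:6, columns 0:3 of a 6x6 zero matrix
    return jnp.vstack([jnp.zeros((3, 6)),
                       jnp.hstack([tilde(vector), jnp.zeros((3, 3))])])


def axis_tilde(tensor):
    f1 = jax.vmap(tilde0010, in_axes=1, out_axes=2)  # inner loop: axis 1 (N1)
    f2 = jax.vmap(f1, in_axes=2, out_axes=3)         # outer loop: axis 2 (N2)
    return f2(tensor)


tensor = jnp.arange(3 * 4 * 5, dtype=jnp.float32).reshape(3, 4, 5)
out = axis_tilde(tensor)
print(out.shape)  # (6, 6, 4, 5)

i, j = 2, 3
print(bool(jnp.allclose(out[:, :, i, j], tilde0010(tensor[:, i, j]))))  # True
```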

contraction(moments, loadpaths, precision)

Sums the moments from the nodal forces along the corresponding load path

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| moments | ndarray | num_modes x 6 x num_nodes x num_nodes tensor; entry [..., i, j] is the moment at index i due to the forces at node j | required |
| loadpaths | ndarray | num_nodes x num_nodes matrix such that [ni, nj] is 1 if ni is a node in the load path of nj and 0 otherwise | required |
| precision | | Precision passed to jnp.tensordot (static argument) | required |

Returns:

| Type | Description |
|---|---|
| ndarray | num_modes x 6 x num_nodes tensor: at each node, the sum of the moments due to the forces at the nodes in its load path |

Source code in feniax/intrinsic/modes.py
@partial(jit, static_argnames=["precision"])
def contraction(moments: jnp.ndarray, loadpaths: jnp.ndarray, precision) -> jnp.ndarray:
    """Sums the moments from the nodal forces along the corresponding load path

    Parameters
    ----------
    moments : jnp.ndarray
        num_modes x 6 x num_nodes x num_nodes tensor; entry [..., i, j]
        is the moment at index i due to the forces at node j
    loadpaths : jnp.ndarray
        num_nodes x num_nodes such that [ni, nj] is 1 if ni is a node in
        the load path of nj and 0 otherwise
    precision :
        Precision passed to jnp.tensordot (static argument)

    Returns
    -------
    jnp.ndarray
        num_modes x 6 x num_nodes tensor: at each node, the sum of the
        moments due to the forces at the nodes in its load path

    """

    f = jax.vmap(
        lambda u, v: jnp.tensordot(u, v, axes=(2, 0), precision=precision),
        in_axes=(2, 1),
        out_axes=2,
    )
    fuv = f(moments, loadpaths)
    return fuv
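
As a sketch, the vmap in contraction is equivalent to a single einsum, out[m, c, j] = sum_k moments[m, c, j, k] * loadpaths[k, j]. The example checks this on random data; the load-path matrix is a hypothetical cantilever chain clamped at node 0, where node k lies in the load path of node j when k >= j.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=["precision"])
def contraction(moments, loadpaths, precision):
    # Same vmap structure as the listing above
    f = jax.vmap(
        lambda u, v: jnp.tensordot(u, v, axes=(2, 0), precision=precision),
        in_axes=(2, 1),
        out_axes=2,
    )
    return f(moments, loadpaths)


num_modes, num_nodes = 2, 4
moments = jax.random.normal(jax.random.PRNGKey(0), (num_modes, 6, num_nodes, num_nodes))
# Hypothetical chain clamped at node 0: loadpaths[k, j] = 1 when k >= j
loadpaths = jnp.tril(jnp.ones((num_nodes, num_nodes)))

out = contraction(moments, loadpaths, precision=jax.lax.Precision.HIGHEST)
ref = jnp.einsum("mcjk,kj->mcj", moments, loadpaths)
print(out.shape)  # (2, 6, 4)
print(bool(jnp.allclose(out, ref, atol=1e-5)))  # True
```

With this load-path matrix, the tip node (j = 3) only collects its own contribution, so out[:, :, 3] equals moments[:, :, 3, 3].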

coordinates_difftensor(X, Xm, precision)

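Computes the coordinate-difference tensor between element middle points and nodes

The tensor's axes are [coordinate, middle point of each element, node]; each entry is the middle point of an element minus the position of a node in the structure.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| X | ndarray | Grid coordinates (3xNn); the number of nodes Nn is taken from X.shape[1] | required |
| Xm | ndarray | Middle points of the elements (3xNn), e.g. X @ Mavg with Mavg the matrix that calculates the average point between nodes | required |
| precision | | Precision passed to jnp.tensordot (static argument) | required |

Returns:

| Name | Type | Description |
|---|---|---|
| X3 | jnp.ndarray (3xNnxNn) | Tensor `Xm*1 - (X*1)'`, i.e. X3[:, i, j] = Xm[:, i] - X[:, j] |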

Source code in feniax/intrinsic/modes.py
@partial(jit, static_argnames=["precision"])
def coordinates_difftensor(X: jnp.ndarray, Xm: jnp.ndarray, precision) -> jnp.ndarray:
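    """Computes the coordinate-difference tensor between element middle points and nodes

    The tensor's axes are [coordinate, middle point of each element, node];
    each entry is the middle point of an element minus the position of a node
    in the structure.

    Parameters
    ----------
    X : jnp.ndarray
        Grid coordinates (3xNn); the number of nodes Nn is taken from X.shape[1]
    Xm : jnp.ndarray
        Middle points of the elements (3xNn), e.g. X @ Mavg with Mavg the
        matrix that calculates the average point between nodes
    precision :
        Precision passed to jnp.tensordot (static argument)

    Returns
    -------
    X3 : jnp.ndarray (3xNnxNn)
        Tensor Xm*1 - (X*1)', i.e. X3[:, i, j] = Xm[:, i] - X[:, j]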


    """

    # Xm = jnp.matmul(X, Mavg, precision=precision)
    num_nodes = X.shape[1]
    ones = jnp.ones(num_nodes)
    Xm3 = jnp.tensordot(
        Xm, ones, axes=0, precision=precision
    )  # copy Xm along a 3rd dimension
    Xn3 = jnp.transpose(
        jnp.tensordot(X, ones, axes=0, precision=precision), axes=[0, 2, 1]
    )  # copy X along the 2nd dimension
    X3 = Xm3 - Xn3
    return X3
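
A sketch showing that the two tensordot/transpose steps amount to plain broadcasting, X3[:, i, j] = Xm[:, i] - X[:, j]. The Xm used here (X shifted by 0.5) is an arbitrary stand-in, not the module's element middle points, and the precision argument is dropped.

```python
import jax.numpy as jnp


def coordinates_difftensor(X, Xm):
    # Logic copied from the listing above (precision argument dropped)
    num_nodes = X.shape[1]
    ones = jnp.ones(num_nodes)
    Xm3 = jnp.tensordot(Xm, ones, axes=0)  # Xm3[c, i, j] = Xm[c, i]
    Xn3 = jnp.transpose(jnp.tensordot(X, ones, axes=0), axes=[0, 2, 1])  # Xn3[c, i, j] = X[c, j]
    return Xm3 - Xn3


X = jnp.arange(12.0).reshape(3, 4)  # 3 coordinates x 4 nodes
Xm = X + 0.5                        # arbitrary stand-in for element middle points
X3 = coordinates_difftensor(X, Xm)

X3_broadcast = Xm[:, :, None] - X[:, None, :]
print(X3.shape)  # (3, 4, 4)
print(bool(jnp.allclose(X3, X3_broadcast)))  # True
print(X3[:, 1, 2])  # [-0.5 -0.5 -0.5]
```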

eigh(a, b)

Compute the solution to the symmetrized generalized eigenvalue problem.

a_s @ w = b_s @ w @ np.diag(v)

where a_s = (a + a.H) / 2, b_s = (b + b.H) / 2 are the symmetrized versions of the inputs and H is the Hermitian (conjugate transpose) operator.

For self-adjoint inputs the solution should be consistent with scipy.linalg.eigh, i.e.

    v, w = eigh(a, b)
    v_sp, w_sp = scipy.linalg.eigh(a, b)
    np.testing.assert_allclose(v, v_sp)
    np.testing.assert_allclose(w, standardize_angle(w_sp))

Note this currently applies jax.numpy.linalg.eig (pinned to the CPU backend) to b^{-1} a, computed with a Cholesky solve (jax.scipy.linalg.cho_solve). This will be slow because there is no GPU implementation of eig, and it is a generally inefficient way of doing it. Future implementations should wrap CUDA primitives. This implementation is provided primarily as a means to test eigh_jvp_rule.

Args:
    a: [n, n] float self-adjoint matrix (i.e. conj(transpose(a)) == a)
    b: [n, n] float self-adjoint positive-definite matrix (i.e. conj(transpose(b)) == b); positive definiteness is required by the Cholesky factorization

Returns:
    v: eigenvalues of the generalized problem in ascending order.
    w: eigenvectors of the generalized problem, normalized such that w.H @ b @ w = I.

Source code in feniax/intrinsic/modes.py
@jax.custom_jvp  # jax.scipy.linalg.eigh doesn't support general problem i.e. b not None
def eigh(a, b):
    """
    Compute the solution to the symmetrized generalized eigenvalue problem.

    a_s @ w = b_s @ w @ np.diag(v)

    where a_s = (a + a.H) / 2, b_s = (b + b.H) / 2 are the symmetrized versions of the
    inputs and H is the Hermitian (conjugate transpose) operator.

    For self-adjoint inputs the solution should be consistent with `scipy.linalg.eigh`
    i.e.

    v, w = eigh(a, b)
    v_sp, w_sp = scipy.linalg.eigh(a, b)
    np.testing.assert_allclose(v, v_sp)
    np.testing.assert_allclose(w, standardize_angle(w_sp))

    Note this currently applies `jax.numpy.linalg.eig` (pinned to the CPU backend)
    to `b^{-1} a`, computed with a Cholesky solve (`jax.scipy.linalg.cho_solve`).
    This will be slow because there is no GPU implementation of `eig`, and it is a
    generally inefficient way of doing it. Future implementations should wrap CUDA
    primitives. This implementation is provided primarily as a means to test
    `eigh_jvp_rule`.

    Args:
        a: [n, n] float self-adjoint matrix (i.e. conj(transpose(a)) == a)
        b: [n, n] float self-adjoint positive-definite matrix
            (i.e. conj(transpose(b)) == b); positive definiteness is required
            by the Cholesky factorization

    Returns:
        v: eigenvalues of the generalized problem in ascending order.
        w: eigenvectors of the generalized problem, normalized such that
            w.H @ b @ w = I.
    """
    a = symmetrize(a)
    b = symmetrize(b)
    b_inv_a = jax.scipy.linalg.cho_solve(jax.scipy.linalg.cho_factor(b), a)
    v, w = jax.jit(jax.numpy.linalg.eig, backend="cpu")(b_inv_a)
    v = v.real
    if jnp.isrealobj(a) and jnp.isrealobj(b):
        w = w.real
    # reorder eigenpairs so that the eigenvalues v are ascending
    order = jnp.argsort(v)
    v = v.take(order, axis=0)
    w = w.take(order, axis=1)
    # renormalize so that w.H @ b @ w == I
    norm2 = jax.vmap(lambda wi: (wi.conj() @ b @ wi).real, in_axes=1)(w)
    norm = jnp.sqrt(norm2)
    w = w / norm
    w = standardize_angle(w, b)
    return v, w
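
The note above asks for a faster route than eig on the CPU. One standard alternative for real symmetric a and symmetric positive-definite b (not what the module currently does) is Cholesky reduction: with b = L L^T, the matrix C = L^{-1} a L^{-T} is symmetric and has the same generalized eigenvalues, so the symmetric solver jnp.linalg.eigh applies, and w = L^{-T} y satisfies w^T b w = I. This sketch omits the module's standardize_angle sign convention and enables float64 for the checks.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsl

jax.config.update("jax_enable_x64", True)  # float64 for tighter checks


def eigh_cholesky(a, b):
    """Solve a @ w = b @ w @ diag(v) for real symmetric a and SPD b."""
    a = 0.5 * (a + a.T)  # symmetrize the inputs
    b = 0.5 * (b + b.T)
    L = jnp.linalg.cholesky(b)  # b = L @ L.T
    # C = L^{-1} a L^{-T}, built with two triangular solves
    C = jsl.solve_triangular(L, jsl.solve_triangular(L, a, lower=True).T, lower=True)
    v, y = jnp.linalg.eigh(0.5 * (C + C.T))  # eigenvalues in ascending order
    w = jsl.solve_triangular(L.T, y, lower=False)  # w = L^{-T} y
    return v, w


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
n = 5
m = jax.random.normal(key_a, (n, n))
a = m + m.T
q = jax.random.normal(key_b, (n, n))
b = q @ q.T + n * jnp.eye(n)  # symmetric positive definite

v, w = eigh_cholesky(a, b)
print(bool(jnp.allclose(a @ w, b @ w * v[None, :])))  # True
print(bool(jnp.allclose(w.T @ b @ w, jnp.eye(n))))    # True
```

Both steps (Cholesky and symmetric eigh) have GPU implementations in JAX, which is the motivation for this reduction; whether it is preferable for the module's matrices is untested here.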

eigh_jvp_rule(primals, tangents)

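JVP (forward-mode derivative) rule for eigh, registered with eigh.defjvp; only real inputs are supported. Derivation based on Boeddeker et al.

https://arxiv.org/pdf/1701.00392.pdf

Note that, contrary to their claim, the diagonal entries of Winv dW/dt are not zero.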

Source code in feniax/intrinsic/modes.py
@eigh.defjvp
def eigh_jvp_rule(primals, tangents):
    """
    JVP (forward-mode derivative) rule for `eigh`, real inputs only.

    Derivation based on Boeddeker et al.

    https://arxiv.org/pdf/1701.00392.pdf

    Note that, contrary to their claim, the diagonal entries of Winv dW/dt
    are not zero.
    """
    a, b = primals
    da, db = tangents
    if not all(jnp.isrealobj(x) for x in (a, b, da, db)):
        raise NotImplementedError("jvp only implemented for real inputs.")
    da = symmetrize(da)
    db = symmetrize(db)

    v, w = eigh(a, b)

    # compute only the diagonal entries
    dv = jax.vmap(
        lambda vi, wi: -wi.conj() @ db @ wi * vi + wi.conj() @ da @ wi,
        in_axes=(0, 1),
    )(v, w)

    dv = dv.real

    E = v[jnp.newaxis, :] - v[:, jnp.newaxis]

    # diagonal entries: compute as column then put into diagonals
    diags = jnp.diag(-0.5 * jax.vmap(lambda wi: wi.conj() @ db @ wi, in_axes=1)(w))
    # off-diagonals: there will be NaNs/infs on the diagonal, but these aren't used
    off_diags = jnp.reciprocal(E) * (_H(w) @ (da @ w - db @ w * v[jnp.newaxis, :]))

    dw = w @ jnp.where(jnp.eye(a.shape[0], dtype=bool), diags, off_diags)

    return (v, w), (dv, dw)
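
A finite-difference check of the two closed forms used in eigh_jvp_rule for real inputs: dv_i = w_i^T (da - v_i db) w_i, and the diagonal of W^{-1} dW equal to -0.5 w_i^T db w_i, which is nonzero whenever db is (the point of the note above). Because W^T b W = I, W^{-1} = W^T b. The Cholesky-based gen_eigh is a stand-in for the module's eigh (an assumption), and the matrices are chosen near diag(1, 2, 3, 4) so the eigenvalues stay well separated.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsl

jax.config.update("jax_enable_x64", True)


def gen_eigh(a, b):
    # Stand-in generalized solver (Cholesky reduction), w.T @ b @ w == I
    L = jnp.linalg.cholesky(b)
    C = jsl.solve_triangular(L, jsl.solve_triangular(L, a, lower=True).T, lower=True)
    v, y = jnp.linalg.eigh(0.5 * (C + C.T))
    return v, jsl.solve_triangular(L.T, y, lower=False)


def quad(w, m):
    # Column-wise quadratic forms w_i^T m w_i
    return jnp.einsum("ki,kl,li->i", w, m, w)


def sym(m):
    return 0.5 * (m + m.T)


ka, kb, kc, kd = jax.random.split(jax.random.PRNGKey(1), 4)
n = 4
a = jnp.diag(jnp.arange(1.0, n + 1.0)) + 0.1 * sym(jax.random.normal(ka, (n, n)))
q = jax.random.normal(kb, (n, n))
b = jnp.eye(n) + 0.01 * q @ q.T
da = sym(jax.random.normal(kc, (n, n)))
db = sym(jax.random.normal(kd, (n, n)))

v, w = gen_eigh(a, b)
dv = quad(w, da) - v * quad(w, db)  # eigenvalue tangents, as in eigh_jvp_rule
diag = -0.5 * quad(w, db)           # diagonal of W^{-1} dW, as in eigh_jvp_rule

eps = 1e-6
v_p, w_p = gen_eigh(a + eps * da, b + eps * db)
v_m, w_m = gen_eigh(a - eps * da, b - eps * db)
dv_fd = (v_p - v_m) / (2 * eps)


def align(wx):
    # Flip eigenvector signs to match w before differencing
    return wx * jnp.sign(jnp.einsum("ki,kl,li->i", w, b, wx))[None, :]


dw_fd = (align(w_p) - align(w_m)) / (2 * eps)
diag_fd = jnp.einsum("ki,kl,li->i", w, b, dw_fd)  # diagonal of W.T @ b @ dW

print(bool(jnp.allclose(dv, dv_fd, atol=1e-6)))
print(bool(jnp.allclose(diag, diag_fd, atol=1e-5)))
```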

make_C6(v1)

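Given a 3x3xNn tensor, build the block-diagonal 6x6xNn tensor

It iterates over the third dimension (axis 2) of the input tensor and places each 3x3 slice in both diagonal blocks of a 6x6 matrix.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| v1 | ndarray | A tensor of the form (3x3xNn) | required |

Returns:

| Type | Description |
|---|---|
| ndarray | 6x6xNn tensor whose slice [:, :, n] is [[v1[:, :, n], 0], [0, v1[:, :, n]]] |
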
Source code in feniax/intrinsic/modes.py
@jit
def make_C6(v1) -> jnp.ndarray:
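    """Given a 3x3xNn tensor, build the block-diagonal 6x6xNn tensor

    It iterates over the third dimension (axis 2) of the input tensor and
    places each 3x3 slice in both diagonal blocks of a 6x6 matrix.

    Parameters
    ----------
    v1 : jnp.ndarray
        A tensor of the form (3x3xNn)

    Returns
    -------
    jnp.ndarray
        6x6xNn tensor

    """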
    f = jax.vmap(
        lambda v: jnp.vstack(
            [jnp.hstack([v, jnp.zeros((3, 3))]), jnp.hstack([jnp.zeros((3, 3)), v])]
        ),
        in_axes=2,
        out_axes=2,
    )
    fv = f(v1)
    return fv
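
A sketch checking that each output slice of make_C6 equals the Kronecker product kron(I2, v1[:, :, n]), i.e. the 3x3 block repeated on the diagonal. The kron form is only an equivalent restatement for illustration, not what the module uses.

```python
import jax
import jax.numpy as jnp


def make_C6(v1):
    # Copied from the listing above (without jit)
    f = jax.vmap(
        lambda v: jnp.vstack(
            [jnp.hstack([v, jnp.zeros((3, 3))]), jnp.hstack([jnp.zeros((3, 3)), v])]
        ),
        in_axes=2,
        out_axes=2,
    )
    return f(v1)


num_nodes = 3
v1 = jnp.arange(9.0 * num_nodes).reshape(3, 3, num_nodes)
C6 = make_C6(v1)

C6_kron = jnp.stack([jnp.kron(jnp.eye(2), v1[:, :, n]) for n in range(num_nodes)], axis=2)
print(C6.shape)  # (6, 6, 3)
print(bool(jnp.array_equal(C6, C6_kron)))  # True
```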

moment_force(force, X3t, precision)

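Yields the moments associated with each node due to the forces

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| force | ndarray | Force tensor (Nmx6xNn) for which we want to obtain the resultant moments | required |
| X3t | ndarray | Tilde positions tensor (6x6xNnxNn) | required |
| precision | | Precision passed to jnp.tensordot (static argument) | required |

Returns:

| Type | Description |
|---|---|
| jnp.ndarray (Nmx6xNnxNn) | Entry [m, :, i, j] is X3t[:, :, i, j] @ force[m, :, j], the moment at index i due to the force at node j for mode m |
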
Source code in feniax/intrinsic/modes.py
@partial(jit, static_argnames=["precision"])
def moment_force(force: jnp.ndarray, X3t: jnp.ndarray, precision) -> jnp.ndarray:
    """Yields the moments associated with each node due to the forces

    Parameters
    ----------
    force : jnp.ndarray
        Force tensor (Nmx6xNn) for which we want to obtain the
        resultant moments
    X3t : jnp.ndarray
        Tilde positions tensor (6x6xNnxNn)
    precision :
        Precision passed to jnp.tensordot (static argument)

    Returns
    -------
    jnp.ndarray
        Nmx6xNnxNn tensor; entry [m, :, i, j] is X3t[:, :, i, j] @ force[m, :, j]

    """

    f1 = jax.vmap(
        lambda u, v: jnp.tensordot(u, v, axes=(1, 1), precision=precision),
        in_axes=(None, 2),
        out_axes=2,
    )  # tensordot along coordinate axis (len=6)
    f2 = jax.vmap(f1, in_axes=(2, 3), out_axes=3)
    fuv = f2(force, X3t)

    return fuv
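
A sketch showing that the nested vmap in moment_force equals the einsum out[m, a, i, j] = sum_b X3t[a, b, i, j] * force[m, b, j]. Random arrays stand in for real forces and tilde positions; the function is copied without jit.

```python
import jax
import jax.numpy as jnp


def moment_force(force, X3t, precision=None):
    # Copied from the listing above (without jit)
    f1 = jax.vmap(
        lambda u, v: jnp.tensordot(u, v, axes=(1, 1), precision=precision),
        in_axes=(None, 2),
        out_axes=2,
    )  # loop over index i of X3t; contract the 6-component axis
    f2 = jax.vmap(f1, in_axes=(2, 3), out_axes=3)  # loop over node j
    return f2(force, X3t)


num_modes, num_nodes = 2, 3
kf, kx = jax.random.split(jax.random.PRNGKey(0))
force = jax.random.normal(kf, (num_modes, 6, num_nodes))
X3t = jax.random.normal(kx, (6, 6, num_nodes, num_nodes))

out = moment_force(force, X3t)
ref = jnp.einsum("abij,mbj->maij", X3t, force)
print(out.shape)  # (2, 6, 3, 3)
print(bool(jnp.allclose(out, ref, atol=1e-5)))  # True
```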

reshape_modes(_phi, num_modes, num_nodes)

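Reshapes vectors in the input matrix to form a 3rd-order tensor

Each column vector (one mode) is made into a 6xNn matrix

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| _phi | ndarray | Matrix as in the output of eigenvector analysis (6NnxNm) | required |
| num_modes | int | Number of modes | required |
| num_nodes | int | Number of nodes | required |

Returns:

| Type | Description |
|---|---|
| ndarray | Nmx6xNn tensor; row 6n + d of _phi (node n, component d) maps to [:, d, n] |
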
Source code in feniax/intrinsic/modes.py
@partial(jit, static_argnames=["num_modes", "num_nodes"])
def reshape_modes(_phi: jnp.ndarray, num_modes: int, num_nodes: int):
    """Reshapes vectors in the input matrix to form a 3rd-order tensor

    Each column vector (one mode) is made into a 6xNn matrix

    Parameters
    ----------
    _phi : jnp.ndarray
        Matrix as in the output of eigenvector analysis (6NnxNm)
    num_modes : int
        Number of modes
    num_nodes : int
        Number of nodes

    Returns
    -------
    jnp.ndarray
        Nmx6xNn tensor

    """

    phi = jnp.reshape(_phi, (num_nodes, 6, num_modes), order="C")
    return phi.T

scale(phi1, psi1, phi2, phi1l, phi1ml, psi1l, phi2l, psi2l, omega, X_xdelta, C0ab, C06ab, *args, **kwargs)

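Scales the intrinsic modes

The purpose is that the integrals alpha1 and alpha2 become the identity. alpha1 is computed from phi1 and psi1, and alpha2 from phi2l, psi2l and X_xdelta; phi1, psi1, phi1l, phi1ml and psi1l are divided by the square root of the diagonal of alpha1, and phi2, phi2l and psi2l by the square root of the diagonal of alpha2. Diagonal entries of alpha2 not greater than 1e-4 (rigid-body modes) are replaced by 1, so phi2, phi2l and psi2l of those modes are left unscaled. omega, X_xdelta, C0ab and C06ab are returned unchanged.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| phi1 | ndarray | | required |
| psi1 | ndarray | | required |
| phi2 | ndarray | | required |
| phi1l | ndarray | | required |
| phi1ml | ndarray | | required |
| psi1l | ndarray | | required |
| phi2l | ndarray | | required |
| psi2l | ndarray | | required |
| omega | ndarray | | required |
| X_xdelta | ndarray | | required |
| C0ab | ndarray | | required |
| C06ab | ndarray | | required |
| `*args` | | | () |
| `**kwargs` | | | {} |

Returns:

| Type | Description |
|---|---|
| tuple | (phi1, psi1, phi2, phi1l, phi1ml, psi1l, phi2l, psi2l, omega, X_xdelta, C0ab, C06ab) with the modes scaled |
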
Source code in feniax/intrinsic/modes.py
def scale(
    phi1: jnp.ndarray,
    psi1: jnp.ndarray,
    phi2: jnp.ndarray,
    phi1l: jnp.ndarray,
    phi1ml: jnp.ndarray,
    psi1l: jnp.ndarray,
    phi2l: jnp.ndarray,
    psi2l: jnp.ndarray,
    omega: jnp.ndarray,
    X_xdelta: jnp.ndarray,
    C0ab: jnp.ndarray,
    C06ab: jnp.ndarray,
    *args,
    **kwargs,
):
    """Scales the intrinsic modes

    The purpose is that the integrals alpha1 and alpha2 become the
    identity

    Parameters
    ----------
    phi1 : jnp.ndarray
    psi1 : jnp.ndarray
    phi2 : jnp.ndarray
    phi1l : jnp.ndarray
    phi1ml : jnp.ndarray
    psi1l : jnp.ndarray
    phi2l : jnp.ndarray
    psi2l : jnp.ndarray
    omega : jnp.ndarray
    X_xdelta : jnp.ndarray
    C0ab : jnp.ndarray
    C06ab : jnp.ndarray
    *args :
    **kwargs :


    """

    alpha1 = couplings.f_alpha1(phi1, psi1)
    alpha2 = couplings.f_alpha2(phi2l, psi2l, X_xdelta)
    num_modes = len(alpha1)
    # Broadcasting in division
    alpha1_diagonal = alpha1.diagonal()
    alpha2_diagonal = alpha2.diagonal()
    # filter for rigid-body modes
    alpha2d_filtered = jnp.where(alpha2_diagonal > 1e-4, alpha2_diagonal, 1.0)
    phi1 /= jnp.sqrt(alpha1_diagonal).reshape(num_modes, 1, 1)
    psi1 /= jnp.sqrt(alpha1_diagonal).reshape(num_modes, 1, 1)
    phi1l /= jnp.sqrt(alpha1_diagonal).reshape(num_modes, 1, 1)
    phi1ml /= jnp.sqrt(alpha1_diagonal).reshape(num_modes, 1, 1)
    psi1l /= jnp.sqrt(alpha1_diagonal).reshape(num_modes, 1, 1)
    phi2 /= jnp.sqrt(alpha2d_filtered).reshape(num_modes, 1, 1)
    phi2l /= jnp.sqrt(alpha2d_filtered).reshape(num_modes, 1, 1)
    psi2l /= jnp.sqrt(alpha2d_filtered).reshape(num_modes, 1, 1)

    return (
        phi1,
        psi1,
        phi2,
        phi1l,
        phi1ml,
        psi1l,
        phi2l,
        psi2l,
        omega,
        X_xdelta,
        C0ab,
        C06ab,
    )
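
A sketch of the scaling idea, assuming only that the coupling integral is bilinear in its two mode arguments. couplings.f_alpha1 is not shown on this page, so f_alpha_hypothetical below is a hypothetical stand-in (a plain inner product over components and nodes). Dividing both factors of mode i by sqrt(alpha[i, i]) makes the new diagonal entry 1; the off-diagonal entries of alpha1 and alpha2 are zero only through the orthogonality of the modes, not because of this scaling. The second part reproduces the rigid-body filter from the listing.

```python
import jax.numpy as jnp


def f_alpha_hypothetical(phi, psi):
    # Hypothetical stand-in for couplings.f_alpha1: a plain bilinear inner
    # product over components and nodes. The real modal integral differs.
    return jnp.einsum("icn,jcn->ij", phi, psi)


num_modes, num_nodes = 3, 4
phi1 = jnp.arange(1.0, 1.0 + num_modes * 6 * num_nodes).reshape(num_modes, 6, num_nodes)
psi1 = 2.0 * phi1

alpha1 = f_alpha_hypothetical(phi1, psi1)
s = jnp.sqrt(alpha1.diagonal()).reshape(num_modes, 1, 1)  # broadcast over 6 x Nn
phi1_s, psi1_s = phi1 / s, psi1 / s
print(bool(jnp.allclose(f_alpha_hypothetical(phi1_s, psi1_s).diagonal(), 1.0)))  # True

# Rigid-body filter: diagonal entries of alpha2 <= 1e-4 are replaced by 1
alpha2_diagonal = jnp.array([0.0, 2.5e-9, 4.0, 9.0])
alpha2d_filtered = jnp.where(alpha2_diagonal > 1e-4, alpha2_diagonal, 1.0)
print(jnp.sqrt(alpha2d_filtered))  # [1. 1. 2. 3.]
```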

tilde0010(vector)

Tilde matrix for cross product (moments due to forces)

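Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| vector | ndarray | A 3-element array | required |

Returns:

| Type | Description |
|---|---|
| ndarray | 6x6 matrix with the tilde operator of vector in rows 3:6, columns 0:3 and zeros elsewhere |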

Source code in feniax/intrinsic/modes.py
@jit
def tilde0010(vector: jnp.ndarray) -> jnp.ndarray:
    """Tilde matrix for cross product (moments due to forces)

    Parameters
    ----------
    vector : jnp.ndarray
        A 3-element array

    Returns
    -------
    jnp.ndarray
        6x6 matrix with (3:6 x 0:3) tilde operator

    """

    vector_tilde = jnp.vstack(
        [jnp.zeros((3, 6)), jnp.hstack([tilde(vector), jnp.zeros((3, 3))])]
    )
    return vector_tilde
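
A sketch of how tilde0010 turns a force into a moment. The module's tilde is not shown on this page, so the definition below is an assumption (the skew-symmetric matrix with tilde(a) @ b == cross(a, b)). Applied to a 6-component load [force, moment], the result keeps zeros in rows 0:3 and holds cross(r, force) in rows 3:6; the input moment components meet the zero columns and do not contribute.

```python
import jax.numpy as jnp


def tilde(v):
    # Assumed definition: skew-symmetric matrix with tilde(a) @ b == cross(a, b)
    return jnp.array([[0.0, -v[2], v[1]],
                      [v[2], 0.0, -v[0]],
                      [-v[1], v[0], 0.0]])


def tilde0010(vector):
    # Copied from the listing above (without jit)
    return jnp.vstack([jnp.zeros((3, 6)), jnp.hstack([tilde(vector), jnp.zeros((3, 3))])])


r = jnp.array([1.0, 2.0, 3.0])                     # lever arm
load = jnp.array([0.0, 0.0, 10.0, 5.0, 0.0, 0.0])  # force in 0:3, moment in 3:6
out = tilde0010(r) @ load
# Rows 3:6 hold cross(r, force) = [20, -10, 0]
print(bool(jnp.allclose(out, jnp.array([0.0, 0.0, 0.0, 20.0, -10.0, 0.0]))))  # True
```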