Skip to content

Nonlinear couplings

f_alpha1(phi1, psi1) ¤

Alpha1 tensor calculation.

Parameters:

Name Type Description Default
phi1 array

Velocity modal shapes (Nmx6xNn)

required
psi1 array

Momentum modal shapes (Nmx6xNn)

required

Returns:

Type Description
array

alpha1 tensor (NmxNm)

Source code in feniax/intrinsic/couplings.py
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
@jax.jit
def f_alpha1(phi1: jnp.array, psi1: jnp.array) -> jnp.array:
    """Alpha1 tensor calculation.

    Contracts the velocity and momentum modal shapes over the six
    components and over all nodes::

        alpha1[i, j] = sum_{s, n} phi1[i, s, n] * psi1[j, s, n]

    Parameters
    ----------
    phi1 : jnp.array
        Velocity modal shapes (Nmx6xNn)
    psi1 : jnp.array
        Momentum modal shapes (Nmx6xNn)

    Returns
    -------
    jnp.array
        alpha1 tensor (NmxNm)

    """

    # Sum over component axis s and node axis n; every node contributes.
    alpha1 = jnp.einsum("isn,jsn->ij", phi1, psi1)
    return alpha1

f_alpha2(phi2, psi2, delta_s) ¤

Alpha2 tensor calculation.

Parameters:

Name Type Description Default
phi2 array

Internal force modal shapes (Nmx6xNn)

required
psi2 array

Strain modal shapes (Nmx6xNn)

required
delta_s array

1D differential path increments (Nn)

required

Returns:

Type Description
array

Alpha2 tensor (NmxNm)

Source code in feniax/intrinsic/couplings.py
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
@jax.jit
def f_alpha2(phi2: jnp.array, psi2: jnp.array, delta_s: jnp.array) -> jnp.array:
    """Alpha2 tensor calculation.

    Path-weighted contraction of internal-force and strain modal shapes::

        alpha2[i, j] = sum_{s, n>=1} phi2[i, s, n] * psi2[j, s, n] * delta_s[n]

    Node index 0 is excluded from the sum.

    Parameters
    ----------
    phi2 : jnp.array
        Internal force modal shapes (Nmx6xNn)
    psi2 : jnp.array
        Strain modal shapes (Nmx6xNn)
    delta_s : jnp.array
        1D differential path increments (Nn)

    Returns
    -------
    jnp.array
        Alpha2 tensor (NmxNm)

    """

    # Drop node 0 from every operand. NOTE(review): presumably because
    # delta_s[0] has no associated path segment -- confirm with the caller.
    phi2i = phi2[:, :, 1:]
    psi2i = psi2[:, :, 1:]
    delta_si = delta_s[1:]
    alpha2 = jnp.einsum("isn,jsn,n->ij", phi2i, psi2i, delta_si)
    return alpha2

f_gamma1(phi1, psi1) ¤

Gamma1 tensor calculation.

Parameters:

Name Type Description Default
phi1 array

Velocity modal shapes (Nmx6xNn)

required
psi1 array

Momentum modal shapes (Nmx6xNn)

required

Returns:

Type Description
array

Gamma1 tensor (NmxNmxNm)

Source code in feniax/intrinsic/couplings.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
@jax.jit
def f_gamma1(phi1: jnp.array, psi1: jnp.array) -> jnp.array:
    """Gamma1 tensor calculation.

    Computes::

        gamma1[i, j, k] = sum_{s, n} phi1[i, s, n]
                          * (functions.L1(phi1[j, :, n]) @ psi1[k, :, n])[s]

    Parameters
    ----------
    phi1 : jnp.array
        Velocity modal shapes (Nmx6xNn)
    psi1 : jnp.array
        Momentum modal shapes (Nmx6xNn)

    Returns
    -------
    jnp.array
        Gamma1 tensor (NmxNmxNm)

    """
    # For one mode shape u = phi1[j] (6xNn) and all momentum modes v = psi1
    # (Nmx6xNn): at each node n, apply the 6x6 operator L1(u[:, n]) to every
    # v[k, :, n]. The result is (6 x Nm x Nn).
    per_node = jax.vmap(
        lambda u, v: jnp.tensordot(functions.L1(u), v, axes=(1, 1)),
        in_axes=(1, 2),
        out_axes=2,
    )  # iterate nodes
    per_mode = jax.vmap(per_node, in_axes=(0, None), out_axes=0)  # modes in 1st tensor
    l1_psi = per_mode(phi1, psi1)  # Nm x 6 x Nm x Nn, indexed [j, s, k, n]
    gamma1 = jnp.einsum("isn,jskn->ijk", phi1, l1_psi)
    return gamma1

f_gamma2(phi1m, phi2, psi2, delta_s) ¤

Gamma2 tensor calculation.

Parameters:

Name Type Description Default
phi1m array

Velocity modal shapes at mid-points (Nmx6xNn)

required
phi2 array

Internal force modal shapes (Nmx6xNn)

required
psi2 array

Strain modal shapes (Nmx6xNn)

required
delta_s array

1D differential path increments (Nn)

required

Returns:

Type Description
array

Gamma2 tensor (NmxNmxNm)

Source code in feniax/intrinsic/couplings.py
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
@jax.jit
def f_gamma2(
    phi1m: jnp.array, phi2: jnp.array, psi2: jnp.array, delta_s: jnp.array
) -> jnp.array:
    """Gamma2 tensor calculation.

    Computes, excluding node index 0::

        gamma2[i, j, k] = sum_{s, n>=1} phi1m[i, s, n] * delta_s[n]
                          * (functions.L2(phi2[j, :, n]) @ psi2[k, :, n])[s]

    Parameters
    ----------
    phi1m : jnp.array
        Velocity modal shapes at mid-points (Nmx6xNn)
    phi2 : jnp.array
        Internal force modal shapes (Nmx6xNn)
    psi2 : jnp.array
        Strain modal shapes (Nmx6xNn)
    delta_s : jnp.array
        1D differential path increments (Nn)

    Returns
    -------
    jnp.array
        Gamma2 tensor (NmxNmxNm)

    """

    # Drop node 0 from every operand, as in the alpha2 calculation.
    # NOTE(review): presumably because delta_s[0] has no associated
    # segment -- confirm.
    phi1mi = phi1m[:, :, 1:]
    phi2i = phi2[:, :, 1:]
    psi2i = psi2[:, :, 1:]
    delta_si = delta_s[1:]
    # For one mode shape u = phi2i[j] and all strain modes v = psi2i: at each
    # node n, apply the 6x6 operator L2(u[:, n]) to every v[k, :, n].
    per_node = jax.vmap(
        lambda u, v: jnp.tensordot(functions.L2(u), v, axes=(1, 1)),
        in_axes=(1, 2),
        out_axes=2,
    )  # iterate nodes
    per_mode = jax.vmap(per_node, in_axes=(0, None), out_axes=0)  # modes in 1st tensor
    l2_psi = per_mode(phi2i, psi2i)  # Nm x 6 x Nm x (Nn-1), indexed [j, s, k, n]
    gamma2 = jnp.einsum("isn,jskn,n->ijk", phi1mi, l2_psi, delta_si)
    return gamma2