Greeks via Autodiff
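VALAX computes Greeks by differentiating pricing functions with jax.grad. The derivative is taken through whatever computation the pricer performs. As a result, the same approach works for analytic, Monte Carlo, PDE, and lattice pricers without any method-specific Greek code. The one exception is theta, which uses a finite-difference bump.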
Quick Start
```python
import jax.numpy as jnp

from valax.greeks import greek, greeks
from valax.instruments import EuropeanOption
from valax.pricing.analytic import black_scholes_price

option = EuropeanOption(strike=jnp.array(100.0), expiry=jnp.array(1.0), is_call=True)

# Market inputs, in the order black_scholes_price takes them after the option:
# spot, vol, rate, dividend
args = (jnp.array(100.0), jnp.array(0.20), jnp.array(0.05), jnp.array(0.02))

# All Greeks at once, as a dict keyed by Greek name
g = greeks(black_scholes_price, option, *args)
# g["delta"], g["gamma"], g["vega"], g["theta"], g["rho"], ...

# A single Greek by name
delta = greek(black_scholes_price, "delta", option, *args)
```
Available Greeks
First Order
| Name | Definition | Differentiated argument (argnums index) |
|---|---|---|
| delta | \(\partial V / \partial S\) | spot (1) |
| vega | \(\partial V / \partial \sigma\) | vol (2) |
| rho | \(\partial V / \partial r\) | rate (3) |
| dividend_rho | \(\partial V / \partial q\) | dividend (4) |
| theta | \(\partial V / \partial T\) | expiry (finite-difference bump, not autodiff) |
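These closed-form first-order Greeks are useful for checking the values above. They apply to a Black-Scholes call with continuous dividend yield \(q\). They are standard textbook results, independent of VALAX. \(N\) and \(\varphi\) are the standard normal CDF and density.

\[
\begin{aligned}
d_1 &= \frac{\ln(S/K) + \left(r - q + \tfrac{1}{2}\sigma^2\right)T}{\sigma\sqrt{T}}, \qquad d_2 = d_1 - \sigma\sqrt{T},\\
V &= S e^{-qT} N(d_1) - K e^{-rT} N(d_2),\\
\frac{\partial V}{\partial S} &= e^{-qT} N(d_1), \qquad
\frac{\partial V}{\partial \sigma} = S e^{-qT}\varphi(d_1)\sqrt{T},\\
\frac{\partial V}{\partial r} &= K T e^{-rT} N(d_2), \qquad
\frac{\partial V}{\partial q} = -S T e^{-qT} N(d_1).
\end{aligned}
\]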
Second Order
| Name | Definition |
|---|---|
| gamma | \(\partial^2 V / \partial S^2\) |
| vanna | \(\partial^2 V / \partial S \partial \sigma\) |
| volga | \(\partial^2 V / \partial \sigma^2\) |
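Under the same Black-Scholes model (standard results, not VALAX code), the second-order Greeks above also have closed forms. Here \(d_1 = (\ln(S/K) + (r - q + \sigma^2/2)T)/(\sigma\sqrt{T})\), \(d_2 = d_1 - \sigma\sqrt{T}\), and \(\varphi\) is the standard normal density.

\[
\begin{aligned}
\Gamma = \frac{\partial^2 V}{\partial S^2} &= \frac{e^{-qT}\varphi(d_1)}{S\sigma\sqrt{T}},\\
\text{vanna} = \frac{\partial^2 V}{\partial S\,\partial\sigma} &= -\frac{e^{-qT}\varphi(d_1)\,d_2}{\sigma},\\
\text{volga} = \frac{\partial^2 V}{\partial\sigma^2} &= S e^{-qT}\varphi(d_1)\sqrt{T}\,\frac{d_1 d_2}{\sigma}.
\end{aligned}
\]

Put-call parity is linear in \(S\) and independent of \(\sigma\). So all three are identical for a call and a put with the same strike and expiry.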
How It Works
First-order Greeks use a single jax.grad call:
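```python
import jax

from valax.pricing.analytic import black_scholes_price

# option and args as in the Quick Start: argument 0 is the option,
# arguments 1-4 are spot, vol, rate, dividend.

# Delta: differentiate price w.r.t. spot (argument index 1)
delta_fn = jax.grad(black_scholes_price, argnums=1)
delta = delta_fn(option, *args)

# All first-order Greeks in one backward pass; jax.grad returns a tuple ordered like argnums
all_grads = jax.grad(black_scholes_price, argnums=(1, 2, 3, 4))
delta, vega, rho, div_rho = all_grads(option, *args)
```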
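The sketch below tries the mechanism without VALAX installed. It relies on two hypothetical stand-ins, neither of which is VALAX's API:

- a hand-written Black-Scholes call pricer, bs_call;
- an Option NamedTuple in place of the instrument pytree.

It uses jax.value_and_grad to get the price and all four first-order Greeks in one pass. It then uses jax.vmap to evaluate delta over a strip of spots.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

jax.config.update("jax_enable_x64", True)  # float64 for tight comparisons


class Option(NamedTuple):
    """Hypothetical stand-in for an instrument pytree (not VALAX's EuropeanOption)."""
    strike: jax.Array
    expiry: jax.Array


def bs_call(option, spot, vol, rate, dividend):
    """Black-Scholes call price with continuous dividend yield (illustrative)."""
    sqrt_t = jnp.sqrt(option.expiry)
    d1 = (jnp.log(spot / option.strike)
          + (rate - dividend + 0.5 * vol**2) * option.expiry) / (vol * sqrt_t)
    d2 = d1 - vol * sqrt_t
    return (spot * jnp.exp(-dividend * option.expiry) * norm.cdf(d1)
            - option.strike * jnp.exp(-rate * option.expiry) * norm.cdf(d2))


option = Option(strike=jnp.array(100.0), expiry=jnp.array(1.0))
args = (jnp.array(100.0), jnp.array(0.20), jnp.array(0.05), jnp.array(0.02))

# Price plus all first-order Greeks: argnums (1, 2, 3, 4) = spot, vol, rate, dividend.
price, (delta, vega, rho, div_rho) = jax.value_and_grad(
    bs_call, argnums=(1, 2, 3, 4))(option, *args)

# Delta across a strip of spots: vmap maps only over the spot argument.
delta_fn = jax.grad(bs_call, argnums=1)
spots = jnp.linspace(80.0, 120.0, 5)
delta_strip = jax.vmap(delta_fn, in_axes=(None, 0, None, None, None))(
    option, spots, *args[1:])
```

Only the arguments listed in argnums are differentiated, so the option pytree passes through untouched.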
Second-order Greeks use nested jax.grad:
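```python
# Gamma = d(delta)/d(spot): differentiate the delta function w.r.t. spot again
gamma_fn = jax.grad(jax.grad(black_scholes_price, argnums=1), argnums=1)
# Vanna = d(delta)/d(vol)
vanna_fn = jax.grad(jax.grad(black_scholes_price, argnums=1), argnums=2)
# Volga = d(vega)/d(vol)
volga_fn = jax.grad(jax.grad(black_scholes_price, argnums=2), argnums=2)
gamma = gamma_fn(option, *args)
```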
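This standalone sketch uses a hypothetical bs_call, not VALAX code. It compares nested jax.grad with jax.hessian, which returns every second derivative over the chosen arguments in one call (forward-over-reverse). It then checks both against the closed-form Black-Scholes gamma, vanna, and volga. There is no option argument here, so spot is argnums 0 and vol is argnums 1.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

jax.config.update("jax_enable_x64", True)

STRIKE, EXPIRY = 100.0, 1.0  # fixed contract terms for this illustration


def bs_call(spot, vol, rate, dividend):
    """Black-Scholes call price with continuous dividend yield (illustrative)."""
    sqrt_t = jnp.sqrt(EXPIRY)
    d1 = (jnp.log(spot / STRIKE) + (rate - dividend + 0.5 * vol**2) * EXPIRY) / (vol * sqrt_t)
    d2 = d1 - vol * sqrt_t
    return (spot * jnp.exp(-dividend * EXPIRY) * norm.cdf(d1)
            - STRIKE * jnp.exp(-rate * EXPIRY) * norm.cdf(d2))


args = (jnp.array(100.0), jnp.array(0.20), jnp.array(0.05), jnp.array(0.02))

# Nested reverse mode, one second derivative per call.
gamma = jax.grad(jax.grad(bs_call, argnums=0), argnums=0)(*args)
vanna = jax.grad(jax.grad(bs_call, argnums=0), argnums=1)(*args)
volga = jax.grad(jax.grad(bs_call, argnums=1), argnums=1)(*args)

# The full 2x2 block over (spot, vol) in one call: hess[i][j] = d2V / dx_i dx_j.
hess = jax.hessian(bs_call, argnums=(0, 1))(*args)
# hess[0][0] = gamma, hess[0][1] = hess[1][0] = vanna, hess[1][1] = volga
```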
Theta
Theta is computed by finite difference (bumping expiry by 1/365, i.e. one day) rather than by autodiff. This is because expiry lives inside the instrument pytree rather than being passed as a separate market argument. The returned value is theta per year; divide by 365 for an approximate per-day figure.
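The sketch below illustrates a one-day expiry bump using a hypothetical bs_call whose expiry is a plain argument. It assumes a forward difference, \((V(T+h) - V(T))/h\), consistent with the \(\partial V / \partial T\) definition in the table. This section does not specify VALAX's exact scheme. Because expiry is a plain argument here, jax.grad can also differentiate it directly, which gives an exact reference value.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

jax.config.update("jax_enable_x64", True)


def bs_call(expiry, spot, vol, rate, dividend, strike=100.0):
    """Black-Scholes call price; expiry is a plain argument here (illustrative)."""
    sqrt_t = jnp.sqrt(expiry)
    d1 = (jnp.log(spot / strike) + (rate - dividend + 0.5 * vol**2) * expiry) / (vol * sqrt_t)
    d2 = d1 - vol * sqrt_t
    return (spot * jnp.exp(-dividend * expiry) * norm.cdf(d1)
            - strike * jnp.exp(-rate * expiry) * norm.cdf(d2))


expiry = jnp.array(1.0)
market = (jnp.array(100.0), jnp.array(0.20), jnp.array(0.05), jnp.array(0.02))

# Assumed scheme: forward difference in expiry with a one-day bump.
h = 1.0 / 365.0
theta_fd = (bs_call(expiry + h, *market) - bs_call(expiry, *market)) / h  # per year
theta_per_day = theta_fd / 365.0

# Exact dV/dT via autodiff, possible here only because expiry is argument 0.
theta_ad = jax.grad(bs_call, argnums=0)(expiry, *market)
```

The forward difference carries an \(O(h)\) truncation error. With a one-day bump, that error is small relative to theta itself for a one-year option.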
Greeks Through Any Pricing Method
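The same autodiff approach, and the same greeks() function, works with MC, PDE, and lattice pricers. For pricers whose signatures differ from black_scholes_price, you can also call jax.grad directly on a closure over the fixed inputs: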
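```python
import jax
import jax.numpy as jnp

# mc_price, pde_price, model, config, key, vol, rate and dividend are assumed to be
# defined elsewhere (their imports are not shown here); option is the Quick Start option.

# MC delta: autodiff through the full simulation (a fixed key reuses the same paths)
def mc_price_fn(spot):
    return mc_price(option, spot, model, config, key)

mc_delta = jax.grad(mc_price_fn)(jnp.array(100.0))

# PDE delta: autodiff through the PDE solve
def pde_price_fn(spot):
    return pde_price(option, spot, vol, rate, dividend, config)

pde_delta = jax.grad(pde_price_fn)(jnp.array(100.0))
```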
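This self-contained sketch shows what differentiating through a simulation means. It does not use VALAX's mc_price. A hypothetical mc_call prices a European call from exact GBM terminal values. jax.grad with respect to spot then yields the pathwise delta estimator. Because the PRNG key is fixed, every evaluation reuses the same paths, so the gradient is that of the sample price. The path count and market inputs are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def mc_call(spot, key, vol=0.20, rate=0.05, dividend=0.02, strike=100.0, expiry=1.0,
            n_paths=200_000):
    """Monte Carlo call price from exact GBM terminal values (illustrative)."""
    z = jax.random.normal(key, (n_paths,))
    s_t = spot * jnp.exp((rate - dividend - 0.5 * vol**2) * expiry
                         + vol * jnp.sqrt(expiry) * z)
    payoff = jnp.maximum(s_t - strike, 0.0)
    return jnp.exp(-rate * expiry) * jnp.mean(payoff)


key = jax.random.PRNGKey(0)  # fixed key: the same paths on every call

# Pathwise delta: jax.grad differentiates through the simulated paths w.r.t. spot.
mc_delta = jax.grad(mc_call)(jnp.array(100.0), key)

# Closed-form Black-Scholes delta e^{-qT} N(d1) for comparison.
d1 = (jnp.log(100.0 / 100.0) + (0.05 - 0.02 + 0.5 * 0.20**2) * 1.0) / (0.20 * 1.0)
bs_delta = jnp.exp(-0.02) * norm.cdf(d1)
```

Pathwise differentiation needs a payoff that is continuous in the simulated paths. For a digital payoff, the derivative of the indicator is zero almost everywhere, so jax.grad would return 0 rather than the true delta.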
Accuracy
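For the analytic Black-Scholes pricer, autodiff Greeks agree with the closed-form analytical Greeks to about 1e-10 in float64. They have none of the step-size error that finite-difference bumps introduce. The test suite checks this across several moneyness, volatility, and rate regimes. JAX uses float32 by default, so 64-bit mode must be enabled with jax.config.update("jax_enable_x64", True) to reach this level of agreement.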
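A check of this kind could look like the sketch below; it is not the actual test suite. It runs autodiff delta, vega, and gamma through jax.vmap over a small grid of spots, vols, and rates, holding strike, expiry, and dividend fixed. It then reports the largest deviation from the closed-form values. The grid values and the bs_call pricer are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

jax.config.update("jax_enable_x64", True)

STRIKE, EXPIRY, DIVIDEND = 100.0, 1.0, 0.02  # fixed for the illustration


def bs_call(spot, vol, rate):
    """Black-Scholes call price with continuous dividend yield (illustrative)."""
    sqrt_t = jnp.sqrt(EXPIRY)
    d1 = (jnp.log(spot / STRIKE) + (rate - DIVIDEND + 0.5 * vol**2) * EXPIRY) / (vol * sqrt_t)
    d2 = d1 - vol * sqrt_t
    return (spot * jnp.exp(-DIVIDEND * EXPIRY) * norm.cdf(d1)
            - STRIKE * jnp.exp(-rate * EXPIRY) * norm.cdf(d2))


def closed_form(spot, vol, rate):
    """Closed-form Black-Scholes delta, vega and gamma for the same call."""
    sqrt_t = jnp.sqrt(EXPIRY)
    d1 = (jnp.log(spot / STRIKE) + (rate - DIVIDEND + 0.5 * vol**2) * EXPIRY) / (vol * sqrt_t)
    disc_q = jnp.exp(-DIVIDEND * EXPIRY)
    delta = disc_q * norm.cdf(d1)
    vega = spot * disc_q * norm.pdf(d1) * sqrt_t
    gamma = disc_q * norm.pdf(d1) / (spot * vol * sqrt_t)
    return delta, vega, gamma


def autodiff(spot, vol, rate):
    """The same three Greeks via jax.grad (first order) and nested jax.grad (gamma)."""
    delta, vega = jax.grad(bs_call, argnums=(0, 1))(spot, vol, rate)
    gamma = jax.grad(jax.grad(bs_call, argnums=0), argnums=0)(spot, vol, rate)
    return delta, vega, gamma


# Grid: OTM / ATM / ITM spots, low / mid / high vols, zero / positive rates.
grid = jnp.meshgrid(jnp.array([70.0, 100.0, 130.0]),
                    jnp.array([0.1, 0.3, 0.6]),
                    jnp.array([0.0, 0.05]), indexing="ij")
spots, vols, rates = (x.ravel() for x in grid)

ad = jax.vmap(autodiff)(spots, vols, rates)
cf = jax.vmap(closed_form)(spots, vols, rates)
max_err = max(float(jnp.max(jnp.abs(a - c))) for a, c in zip(ad, cf))
print(f"max |autodiff - closed form| = {max_err:.1e}")
```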