# When does the analytical path win?
This page answers one question: when is the JAX analytical-gradient path worth its compile-time overhead, and when does single-shot scipy on OpenMM still win? It is interpretive — pulling from the Small Molecules, GPU Acceleration, and Rh-Enamide result tables — and is intended as a decision aid, not a new benchmark.
## TL;DR
- Single-shot, one molecule, frequency-only: scipy on OpenMM is faster. The JAX analytical path pays a one-time JIT compile cost (~1–2 s on CPU, more on GPU first-touch) that a 5-atom system cannot amortize.
- Multi-start (n ≥ 5): the analytical path crosses over. JIT compile amortizes across all starts inside one XLA kernel; per-start cost collapses.
- Many parameters, organometallic-scale (Rh-enamide, 182 params): GPU helps even for single-shot, because per-evaluation cost is now large enough to dominate kernel-launch overhead.
- TS curvature inversion (QFUERZA) inside JIT: now lives inside the analytical path (Phase 5), so SN2/rh-enamide benchmarks no longer pay a Python round-trip per evaluation.
## What "analytical" means here
The analytical path = a JAX-traced loss with `jax.value_and_grad`, JIT compiled once, run inside a jaxopt solver (also JIT compiled). No finite-difference gradients. See Architecture for the residual-kind × analytical-gradient × in-JIT matrix.

The non-analytical baselines are scipy L-BFGS-B with finite-difference gradients on top of OpenMM (CUDA), Tinker (CPU), or JAX/JAX-MD as a function evaluator (no gradients propagated back).
## Crossover, in numbers
These are the rows that bracket the question. For the full matrix see Small Molecules.
### CH₃F (1 mol, 8 params, harmonic + MM3)
| Path | Configuration | RMSD (cm⁻¹) | Wall time |
|---|---|---|---|
| Scipy / OpenMM CUDA | L-BFGS-B (FD) | 59.5 | 4.8 s |
| JAX analytical | L-BFGS-B (analytical) | 579.0 | 2.2 s |
| JAX analytical | optax:adam (analytical) | 56.3 | 25.2 s |
| JAX analytical | jaxopt:lbfgsb | 579.5 | 6.8 s |
| JAX analytical | multi:L-BFGS-B (n=10) | 586.3 | 7.5 s |
| Scipy / OpenMM CUDA | multi:L-BFGS-B (n=10) | 28.7 | 157.4 s |
Reading this table: on CH₃F, the cheapest good answer is OpenMM single-shot L-BFGS-B (4.8 s, RMSD 59.5). The analytical path matches that RMSD only with optax:adam (25.2 s) — slower in absolute terms. But the multi-start row crosses over: jaxopt multi(n=10) finishes in 7.5 s versus OpenMM's 157.4 s for the same n. That is the regime where analytical wins: OpenMM's per-start cost is fixed; JAX's per-start cost approaches zero as n grows because all starts share one JIT-compiled gradient kernel.
What's missing from the table — and from the analytical pipeline today — is the JAX multi:L-BFGS-B (n=10) row reaching the same low RMSD as the OpenMM sequential row (28.7). Two reasons it does not:

- The analytical gradient path with `jaxopt.LBFGSB` converges to a different local minimum than scipy's `L-BFGS-B` on the MM3 landscape. Both are valid optima of the chosen loss; the basins they prefer differ. This is a parameterization issue, not a correctness gap.
- The JAX path now uses topology-grouped `vmap` for Hessian batching (PR #264) — molecules sharing the same topology compute their Hessians in a single vectorised call. This matters for training sets with multiple geometries (e.g., Rh-enamide's 9 molecules).
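The topology-grouped batching idea can be sketched as follows, assuming a toy harmonic bond energy and a hand-written bond list (both invented here; the real grouping lives inside the loss machinery):

```python
import jax
import jax.numpy as jnp

# One shared topology: every molecule in the group has this bond graph.
bonds = jnp.array([[0, 1], [1, 2]])  # invented 3-atom chain

def energy(flat_coords, k, r0):
    coords = flat_coords.reshape(-1, 3)
    d = coords[bonds[:, 0]] - coords[bonds[:, 1]]
    r = jnp.linalg.norm(d, axis=1)
    return jnp.sum(0.5 * k * (r - r0) ** 2)  # harmonic bonds only

# (3N, 3N) Cartesian Hessian, vectorised over the molecule axis.
hess = jax.hessian(energy)
batched_hess = jax.jit(jax.vmap(hess, in_axes=(0, None, None)))

coords = jnp.stack([
    jnp.array([0.0, 0.0, 0.0, 1.1, 0.0, 0.0, 1.1, 1.2, 0.0]),
    jnp.array([0.0, 0.0, 0.0, 0.9, 0.0, 0.0, 0.9, 1.0, 0.0]),
])
H = batched_hess(coords, 300.0, 1.0)  # shape (2, 9, 9)
```

Because every molecule in the group shares one bond array and one coordinate shape, no padding is needed and each Hessian (and any eigendecomposition of it) stays exact.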
### Rh-enamide (9 mols, 182 params, organometallic, frequency-only)
| Path | Configuration | s / eval | Wall time |
|---|---|---|---|
| JAX-MD (OPLSAA) | L-BFGS-B (FD), CPU | 75.4 | 23,819 s |
| JAX-MD (OPLSAA) | L-BFGS-B (FD), GPU | 13.4 | 6,009 s (4.0× faster) |
| JAX (harmonic) | L-BFGS-B (FD), CPU | 26.2 | 550 s |
| JAX (harmonic) | L-BFGS-B (FD), GPU | 12.6 | 391 s (1.4× faster) |
| JAX MM3 grad-simp (analytical) | GPU | — | ~1,500 s (~25 min) |
Source: GPU Acceleration, Rh-Enamide.
For the realistic case study, GPU helps because the per-evaluation arithmetic is finally large enough to dominate kernel launch overhead. The 4.0× JAX-MD speedup is the strongest extant device-level acceleration in the project.
## When you should reach for the analytical path
- You are running a parameter sweep or multi-start with n ≥ 5.
- Your training set is dominated by frequency or eigenmatrix residuals (the analytical Hessian/Jacobian path makes these cheap; the finite-difference baseline pays an `O(n_param)` cost per outer step).
- You are on GPU with ≥ ~100 parameters or ≥ ~50 atoms total across the training set.
- You are iterating on TS systems and need QFUERZA inversion inside the inner loss (Phase 5; pure-JAX path).
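The curvature-inversion step can stay inside JIT because it can be written entirely in traceable `jax.numpy` ops. A minimal sketch of the idea (an illustration only, not q2mm's exact QFUERZA procedure):

```python
import jax
import jax.numpy as jnp

# Pure-JAX curvature inversion: traceable, so it can live inside a jitted
# loss with no Python round-trip per evaluation.
@jax.jit
def invert_ts_mode(hessian):
    w, v = jnp.linalg.eigh(hessian)  # ascending eigenvalues, JIT-friendly
    w = w.at[0].set(jnp.abs(w[0]))   # flip the lowest (TS) curvature
    return (v * w) @ v.T             # reassemble: sum_i w_i v_i v_i^T

H = jnp.array([[-2.0, 0.0], [0.0, 3.0]])  # one negative mode, like a TS
H_pos = invert_ts_mode(H)                 # eigenvalues become [2.0, 3.0]
```

Since `jnp.linalg.eigh` is differentiable and batchable, the inverted Hessian can feed the frequency residual without leaving the compiled graph.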
## When OpenMM single-shot still wins
- One small molecule (≤ ~10 atoms), single-shot, frequency-only, 8–20 parameters. The JIT compile is not amortized.
- You need the cheapest answer above some quality bar and you do not need the full Pareto frontier of optima.
- You are CPU-only and the system is small enough that vectorization doesn't materially help.
## What is not on the path
The following items were originally deferred from the analytical path; the first two have since landed, the rest remain out of scope:
- `vmap` over molecules. ~~Deferred~~ — Done (PR #264). `JaxLoss` now groups molecules by topology and uses `jax.vmap` for Hessian batching within each group. 25% loss-eval speedup on Rh-enamide (632 → 472 ms). Molecules with different topologies still use a small outer loop (no padding, no eigendecomposition corruption).
- Geometry references (bond_length, bond_angle, torsion_angle) via implicit differentiation. ~~Deferred~~ — Done (PR #249). `jaxopt.LBFGS(implicit_diff=True)` relaxes coordinates at the current parameters; the outer `jax.grad` gets exact `∂x*/∂p` via the implicit function theorem. A non-convergence fallback adds a penalty when the inner solver fails to converge (PR #269). Tested on CH₃F with mixed frequency + geometry objectives.
- Stretch-bend cross-term in OpenMM/Tinker. The JAX engine now computes stretch-bend energy; OpenMM and Tinker do not yet. `StretchBendParam` is in `ForceField`, and the MM3 `.fld` loader populates it. Wildcard atom type `00` matching is not yet supported.
- Parameter equivalences / linked params as JIT-traceable projections. Box bounds work today via `jaxopt.LBFGSB`. Anything more — atom-type equivalence groups, linear constraints — would require model-side infrastructure that does not yet exist in q2mm.
The remaining items are explicit non-goals for the current analytical pipeline; they will be revisited when a real workflow makes them necessary.
## Where the numbers come from
- CH₃F rows: `benchmarks/ch3f/` (golden fixtures committed to the repo).
- Rh-enamide rows: `benchmarks/rh-enamide/`.
- Architectural background: Architecture.
- Theory of analytical-gradient observables: Theory & Methods.