JaxEngine
A pure-JAX implementation supporting both harmonic (OPLS-AA-style) and MM3
functional forms, with bond, angle, torsion, and vdW energy terms.
It is best suited to small-to-medium molecules where periodic boundaries
and neighbor lists are not needed. All energy functions are differentiable
via jax.grad, so gradients are computed exactly by automatic
differentiation rather than estimated by finite differences.
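To illustrate what differentiability via jax.grad provides, here is a minimal sketch on a single harmonic bond. It is not JaxEngine's internal code: the function name, the OPLS-style convention E = k (r - r0)^2, and the numeric values are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


def harmonic_bond_energy(coords, k, r0):
    """Energy of one bond between atoms 0 and 1, assuming E = k * (r - r0)**2."""
    r = jnp.linalg.norm(coords[1] - coords[0])
    return k * (r - r0) ** 2


# Illustrative geometry (angstrom) and parameters; not taken from any real force field.
coords = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 1.5]])
k, r0 = 300.0, 1.4

energy = harmonic_bond_energy(coords, k, r0)

# Gradient with respect to the Cartesian coordinates (argument 0).
grad_coords = jax.grad(harmonic_bond_energy, argnums=0)(coords, k, r0)

# Gradients with respect to the force-field parameters k and r0 (arguments 1 and 2).
dE_dk, dE_dr0 = jax.grad(harmonic_bond_energy, argnums=(1, 2))(coords, k, r0)

print(float(energy), float(dE_dk), float(dE_dr0))
```

The same function yields both coordinate gradients (useful for minimization) and parameter gradients (useful for force-field fitting) simply by changing argnums.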
Installation
For GPU support, install the CUDA-enabled jaxlib with
pip install "jax[cuda12]" (the same command listed under GPU Support).
Configuration
JaxEngine has no constructor parameters. It runs on whichever JAX backend
is active (cpu or gpu), detected via jax.default_backend().
Supported Energy Terms
| Term | Supported |
|---|---|
| Bonds (harmonic + MM3) | ✅ |
| Angles (harmonic + MM3) | ✅ |
| Torsions (cosine) | ✅ |
| Improper torsions | ❌ |
| vdW (LJ 12-6 + Buckingham exp-6) | ✅ |
| Electrostatics | ❌ |
| 1-4 scaling | ❌ Not implemented |
Functional forms: Harmonic and MM3.
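As a rough illustration of how the two functional forms differ, the sketch below compares a harmonic bond stretch with the standard published MM3 stretch expression, which adds cubic and quartic anharmonic corrections. Whether JaxEngine uses exactly these constants and unit conventions is an assumption; the function names and parameter values are hypothetical.

```python
# Standard MM3 stretch constants (assumed here, not read from the q2mm source):
# 71.94 converts mdyn/angstrom * angstrom^2 to kcal/mol, 2.55 is the cubic coefficient.
MM3_STRETCH_UNITS = 71.94
MM3_CUBIC = 2.55


def harmonic_bond(r, k, r0):
    """OPLS-style harmonic stretch, E = k * (r - r0)**2 (k in kcal/mol/angstrom^2)."""
    return k * (r - r0) ** 2


def mm3_bond(r, ks, r0):
    """MM3 stretch with cubic and quartic terms (ks in mdyn/angstrom)."""
    d = r - r0
    anharmonic = 1.0 - MM3_CUBIC * d + (7.0 / 12.0) * MM3_CUBIC**2 * d**2
    return MM3_STRETCH_UNITS * ks * d**2 * anharmonic


# Illustrative parameters only.
r0, ks = 1.4, 4.5
stretched = mm3_bond(r0 + 0.1, ks, r0)
compressed = mm3_bond(r0 - 0.1, ks, r0)

# The harmonic form is symmetric about r0; the MM3 form penalizes compression more.
print(f"MM3 stretched:  {stretched:.4f} kcal/mol")
print(f"MM3 compressed: {compressed:.4f} kcal/mol")
print(f"Harmonic (either side): {harmonic_bond(r0 + 0.1, 300.0, r0):.4f} kcal/mol")
```

Because the two forms use different force-constant units and conventions, parameters are not interchangeable between them.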
Capabilities
| Method | Supported | Notes |
|---|---|---|
| energy() | ✅ | Pure JAX |
| minimize() | ✅ | JAX gradients + SciPy L-BFGS-B |
| hessian() | ✅ | Analytical via jax.hessian |
| frequencies() | ✅ | From analytical Hessian |
| energy_and_param_grad() | ✅ | Analytical via jax.grad |
| batched_energy() | ✅ | Vectorized via jax.vmap |
| supports_runtime_params() | ✅ | - |
| supports_analytical_gradients() | ✅ | - |
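The techniques named in the Notes column can be sketched on a toy energy function. This is not the JaxEngine implementation: toy_energy, the constants, and the geometry are hypothetical, and the sketch only shows how jax.value_and_grad, SciPy's L-BFGS-B, jax.hessian, and jax.vmap fit together.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize

# SciPy works in float64, so enable double precision in JAX for this sketch.
jax.config.update("jax_enable_x64", True)

K, R0 = 300.0, 1.4  # illustrative harmonic bond parameters


def toy_energy(flat_coords):
    """Harmonic bond energy of a diatomic, taking a flat (6,) coordinate vector."""
    coords = flat_coords.reshape(-1, 3)
    r = jnp.linalg.norm(coords[1] - coords[0])
    return K * (r - R0) ** 2


# Energy and exact gradient in one call, compiled once with jit.
value_and_grad = jax.jit(jax.value_and_grad(toy_energy))


def scipy_objective(x):
    """Adapter: SciPy expects a Python float and a writable float64 gradient."""
    e, g = value_and_grad(jnp.asarray(x))
    return float(e), np.array(g, dtype=np.float64)


# Minimization: JAX gradients driving SciPy L-BFGS-B.
x0 = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 1.6])
result = minimize(scipy_objective, x0, jac=True, method="L-BFGS-B")
r_opt = np.linalg.norm(result.x[3:] - result.x[:3])

# Hessian: second derivatives of the energy at the optimized geometry, shape (6, 6).
hess = jax.hessian(toy_energy)(jnp.asarray(result.x))

# Batched energies: vmap maps the scalar energy over a stack of geometries.
batch = jnp.stack([jnp.asarray(x0), jnp.asarray(result.x)])
batch_energies = jax.vmap(toy_energy)(batch)

print(f"optimized bond length: {r_opt:.4f}")
print(f"Hessian shape: {hess.shape}, batched energies shape: {batch_energies.shape}")
```

Vibrational frequencies are conventionally obtained from the eigenvalues of the mass-weighted Hessian, which is why an exact Hessian is the starting point for frequencies().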
GPU Support
JaxEngine runs on whichever device JAX selects. To use a GPU:
- Install the CUDA-enabled JAX (quote the extra so shells such as zsh do not expand the brackets):
  pip install "jax[cuda12]"
- Verify that the active backend is gpu:
  python -c "import jax; print(jax.default_backend())"
The engine name includes the backend string (e.g., JAX (harmonic, gpu)
or JAX (harmonic, cpu)).
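The backend check and the name format described above can be reproduced in a few lines. The engine_label helper is hypothetical and only mirrors the documented name format; it is not part of the q2mm API.

```python
import jax

# Platform JAX will use by default, for example "cpu" or "gpu".
backend = jax.default_backend()

# All devices visible on that backend.
devices = jax.devices()


def engine_label(functional_form: str, backend_name: str) -> str:
    """Hypothetical helper that mirrors the documented engine name format."""
    return f"JAX ({functional_form}, {backend_name})"


label = engine_label("harmonic", backend)
print(label)
print(f"{len(devices)} device(s): {devices}")
```

If the label reports cpu on a machine with a GPU, the CUDA-enabled JAX install is the first thing to check.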
Performance
In the current benchmark set, JaxEngine is one of the fastest in-process backends for harmonic CH₃F optimization and offers analytical gradients for energy-based evaluators. Exact speedups depend on system size, objective, and device, so use the benchmark overview and GPU benchmarks for workload-specific numbers.
Limitations
- No 1-4 pair scaling — non-bonded energies differ from OpenMM/JAX-MD for molecules with 1-4 interactions. See the compatibility notes.
- No periodic boundaries — gas-phase only.
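To see why missing 1-4 scaling changes non-bonded energies, the sketch below evaluates a Lennard-Jones 12-6 term for one close 1-4 contact with and without a scale factor. The factor of 0.5 follows the OPLS-AA convention and is an assumption about what the other engines apply for a given force field; the parameters are illustrative only.

```python
def lj_12_6(r, epsilon, sigma):
    """Lennard-Jones 12-6 energy, E = 4 * eps * ((sigma/r)**12 - (sigma/r)**6)."""
    sr6 = (sigma / r) ** 6
    return 4.0 * epsilon * (sr6 * sr6 - sr6)


# Illustrative 1-4 contact: the atoms sit closer than sigma, so the term is repulsive.
r14, epsilon, sigma = 2.5, 0.066, 3.5

unscaled = lj_12_6(r14, epsilon, sigma)  # no 1-4 scaling, as in JaxEngine
scaled = 0.5 * unscaled                  # assumed OPLS-AA-style 1-4 scale factor
difference = unscaled - scaled

print(f"unscaled 1-4 LJ: {unscaled:.3f} kcal/mol")
print(f"scaled 1-4 LJ:   {scaled:.3f} kcal/mol")
print(f"difference:      {difference:.3f} kcal/mol per pair")
```

The discrepancy accumulates over every 1-4 pair in the molecule, so total non-bonded energies should not be compared across engines without accounting for it.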
Example
from q2mm.backends.mm import JaxEngine
from q2mm.models.forcefield import ForceField
from q2mm.models.molecule import Q2MMMolecule

# Load a geometry and create a force field for it
mol = Q2MMMolecule.from_xyz("molecule.xyz")
ff = ForceField.create_for_molecule(mol)

# Single-point energy on the active JAX backend
engine = JaxEngine()
e = engine.energy(mol, ff)
print(f"JAX energy: {e:.4f} kcal/mol")

# Analytical parameter gradients
e, grad = engine.energy_and_param_grad(mol, ff)
print(f"Energy: {e:.4f}, grad shape: {grad.shape}")
See Also
- JaxMDEngine — periodic boundaries, neighbor lists, 1-4 scaling
- Engine comparison table
- GPU benchmarks
- API Reference: JaxEngine