JaxMDEngine¶
Built on the JAX-MD library, this engine adds periodic boundary conditions, neighbor lists, and configurable 1-4 scaling on top of JAX's differentiable energy functions.
Installation¶
For GPU support:
Configuration¶
from q2mm.backends.mm import JaxMDEngine
engine = JaxMDEngine(
box=(100.0, 100.0, 100.0), # simulation box dimensions (Å)
coulomb=None, # CoulombHandler; default: CutoffCoulomb(r_cut=12.0)
nb_options=None, # NonbondedOptions; default: r_cut=12.0
)
| Parameter | Type | Default | Description |
|---|---|---|---|
box |
tuple[float, float, float] |
(100.0, 100.0, 100.0) |
Periodic box dimensions in Å |
coulomb |
CoulombHandler \| None |
CutoffCoulomb(r_cut=12.0) |
Electrostatic handler |
nb_options |
NonbondedOptions \| None |
NonbondedOptions(r_cut=12.0) |
Non-bonded cutoff and options |
Supported Energy Terms¶
| Term | Supported |
|---|---|
| Bonds (harmonic) | ✅ |
| Angles (harmonic) | ✅ |
| Torsions (cosine) | ✅ |
| Improper torsions | ❌ |
| vdW (LJ 12-6) | ✅ |
| Electrostatics | Infrastructure only (charges zeroed) |
| 1-4 scaling | ✅ Configurable (default: AMBER 0.5) |
| Periodic boundaries | ✅ |
| Neighbor lists | ✅ (jax-md native) |
Functional forms: Harmonic only.
Capabilities¶
| Method | Supported | Notes |
|---|---|---|
energy() |
✅ | — |
energy_breakdown() |
✅ | Per-term decomposition |
minimize() |
✅ | JAX gradients + SciPy L-BFGS-B |
hessian() |
✅ | Analytical via jax.hessian |
frequencies() |
✅ | From analytical Hessian |
energy_and_param_grad() |
✅ | Analytical via jax.grad |
batched_energy() |
✅ | Vectorized via jax.vmap |
supports_runtime_params() |
✅ | — |
supports_analytical_gradients() |
✅ | — |
GPU Support¶
JaxMDEngine runs on whichever device JAX selects. To use a GPU:
- Install the CUDA-enabled JAX:
pip install jax[cuda12] - Verify:
python -c "import jax; print(jax.default_backend())"
The engine name includes the backend string (e.g., JAX-MD (OPLSAA, gpu)).
Limitations¶
- Harmonic only — MM3 support is tracked in #91.
- Electrostatics zeroed — Coulomb energy is computed with zero charges; charge optimization is not yet supported.
- No improper torsions — topology arrays are empty.
- 64-bit mode forced — importing this module enables
jax_enable_x64globally, which affects all JAX code in the process.
Example¶
from q2mm.backends.mm import JaxMDEngine
from q2mm.models.forcefield import ForceField
from q2mm.models.molecule import Q2MMMolecule
mol = Q2MMMolecule.from_xyz("molecule.xyz")
ff = ForceField.create_for_molecule(mol)
engine = JaxMDEngine(box=(50.0, 50.0, 50.0))
e = engine.energy(mol, ff)
print(f"JAX-MD energy: {e:.4f} kcal/mol")
# Analytical parameter gradients
e, grad = engine.energy_and_param_grad(mol, ff)
print(f"Energy: {e:.4f}, grad shape: {grad.shape}")
See Also¶
- JaxEngine — simpler, no periodic boundaries
- Engine comparison table
- GPU benchmarks
- API Reference: JaxMDEngine