JAX Engine
jax_engine
JAX-based differentiable MM backend for analytical gradients.
Provides a pure-JAX molecular mechanics engine using harmonic bond/angle
and 12-6 Lennard-Jones energy functions. Torsion energy functions are
implemented but not yet wired for Q2MMMolecule (which lacks torsion
detection); they will activate once torsion matching is added.
All energy functions are differentiable via jax.grad, enabling analytical
gradient computation for force field parameter optimization.
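The three energy terms named above can be sketched as pure JAX functions. They follow the E = k·(x − x₀)² convention that this module uses. The function names, the argument layouts, and the σ-based 12-6 form are assumptions made for illustration. The engine's own functions, and its LJ parameterization (σ vs. r_min), may differ.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64, as is usual for MM energies


def harmonic_bond_energy(coords, bond_indices, k, r0):
    """Sum of k * (r - r0)**2 over bonds (no 1/2 prefactor)."""
    d = coords[bond_indices[:, 0]] - coords[bond_indices[:, 1]]
    r = jnp.linalg.norm(d, axis=1)
    return jnp.sum(k * (r - r0) ** 2)


def harmonic_angle_energy(coords, angle_indices, k, theta0):
    """Sum of k * (theta - theta0)**2, theta in radians; column 1 is the vertex atom."""
    a = coords[angle_indices[:, 0]] - coords[angle_indices[:, 1]]
    b = coords[angle_indices[:, 2]] - coords[angle_indices[:, 1]]
    cos_t = jnp.sum(a * b, axis=1) / (
        jnp.linalg.norm(a, axis=1) * jnp.linalg.norm(b, axis=1)
    )
    # Clipping guards against round-off; the gradient is still singular at exactly 180 degrees.
    theta = jnp.arccos(jnp.clip(cos_t, -1.0, 1.0))
    return jnp.sum(k * (theta - theta0) ** 2)


def lj_energy(coords, pair_indices, epsilon, sigma):
    """12-6 Lennard-Jones, 4*eps*((sigma/r)**12 - (sigma/r)**6), summed over pairs."""
    d = coords[pair_indices[:, 0]] - coords[pair_indices[:, 1]]
    r = jnp.linalg.norm(d, axis=1)
    sr6 = (sigma / r) ** 6
    return jnp.sum(4.0 * epsilon * (sr6**2 - sr6))


# Three atoms: bonds 0-1 (1.0 Å) and 0-2 (1.1 Å) at a right angle.
coords = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.1, 0.0]])
bonds = jnp.array([[0, 1], [0, 2]])
k_bond = jnp.array([300.0, 300.0])
r0_bond = jnp.array([0.96, 0.96])

# Parameter gradient dE/dk_i = (r_i - r0_i)**2, obtained without finite differences.
dE_dk = jax.grad(harmonic_bond_energy, argnums=2)(coords, bonds, k_bond, r0_bond)
```

Because each term is an ordinary traced function, `jax.grad` can target any argument. Here it targets the per-bond force constants.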
ForceField stores parameters in canonical units (kcal/mol/Å² for bond_k,
kcal/mol/rad² for angle_k) with the energy convention E = k·(x − x₀)².
The JAX energy functions use the same convention, so no unit conversion is
needed at the engine boundary.
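Under the E = k·(x − x₀)² convention, the derivatives that `jax.grad` returns have simple closed forms. The curvature is 2k, not k. A force constant read off a coordinate Hessian must therefore be halved to recover k in this convention:

$$
\frac{\partial E}{\partial k} = (x - x_0)^2, \qquad
\frac{\partial E}{\partial x_0} = -2k\,(x - x_0), \qquad
\frac{\partial E}{\partial x} = 2k\,(x - x_0), \qquad
\frac{\partial^2 E}{\partial x^2} = 2k.
$$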
For MM3-specific JAX forms, see issue #91.
JaxHandle
dataclass
JaxHandle(molecule: Q2MMMolecule, bond_indices: ndarray, angle_indices: ndarray, torsion_indices: ndarray, vdw_pair_indices: ndarray, bond_param_map: ndarray, angle_param_map: ndarray, torsion_param_map: ndarray, atom_vdw_map: ndarray, n_bond_types: int, n_angle_types: int, n_torsion_types: int, n_vdw_types: int, _energy_fn: Callable | None = None, _grad_fn: Callable | None = None, _coord_hess_fn: Callable | None = None)
Cached topology and parameter mapping for a molecule.
Created once per molecule, reused across parameter updates.
Attributes:
| Name | Type | Description |
|---|---|---|
| molecule | Q2MMMolecule | Deep copy of the input molecule. |
| bond_indices | ndarray | Atom index pairs, shape |
| angle_indices | ndarray | Atom index triples, shape |
| torsion_indices | ndarray | Atom index quadruples, shape |
| vdw_pair_indices | ndarray | Non-excluded pairs, shape |
| bond_param_map | ndarray | Maps each matched bond → index into |
| angle_param_map | ndarray | Maps each matched angle → index into |
| torsion_param_map | ndarray | Maps each matched torsion → index into |
| atom_vdw_map | ndarray | Maps each atom → index into |
| n_bond_types | int | Number of unique bond parameter types. |
| n_angle_types | int | Number of unique angle parameter types. |
| n_torsion_types | int | Number of unique torsion parameter types. |
| n_vdw_types | int | Number of unique vdW parameter types. |
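The `*_param_map` arrays let per-type parameter arrays be expanded to per-term arrays by integer indexing. Below is a minimal NumPy sketch for bonds. It assumes per-type k and r₀ arrays ordered by type index; all array names and values are hypothetical. The sketch also shows why the type counts matter. The gradient with respect to a per-type parameter is the sum over every term of that type, which `np.bincount(..., minlength=n_bond_types)` computes. Reverse-mode autodiff through the same gather accumulates it the same way.

```python
import numpy as np

# Hypothetical per-type parameters (one entry per unique bond type).
bond_k_by_type = np.array([340.0, 450.0])  # kcal/mol/Å²
bond_r0_by_type = np.array([1.09, 1.53])   # Å
n_bond_types = 2

# Hypothetical topology: three matched bonds, two of type 0 and one of type 1.
bond_indices = np.array([[0, 1], [0, 2], [0, 3]])
bond_param_map = np.array([0, 0, 1])

# Gather: expand per-type parameters to one value per matched bond.
k_per_bond = bond_k_by_type[bond_param_map]
r0_per_bond = bond_r0_by_type[bond_param_map]

coords = np.array([[0.0, 0.0, 0.0], [1.10, 0.0, 0.0], [0.0, 1.08, 0.0], [0.0, 0.0, 1.55]])
r = np.linalg.norm(coords[bond_indices[:, 0]] - coords[bond_indices[:, 1]], axis=1)
energy = np.sum(k_per_bond * (r - r0_per_bond) ** 2)  # E = sum k (r - r0)^2

# dE/dk for each *type*: scatter-add (r - r0)^2 back onto the type index.
dE_dk_by_type = np.bincount(bond_param_map, weights=(r - r0_per_bond) ** 2, minlength=n_bond_types)
```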
JaxEngine
Bases: MMEngine
Differentiable MM backend using JAX with OPLS-AA-style energy functions.
Provides analytical gradients of the energy with respect to force field
parameters via jax.grad, eliminating the need for finite differences
in parameter optimization.
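The sketch below illustrates the difference by comparing `jax.grad` with central finite differences. It uses a toy two-bond energy and a flat parameter vector [k₁, r₀,₁, k₂, r₀,₂]. Both the toy energy and the parameter layout are hypothetical and do not reflect the engine's internal layout. Reverse-mode autodiff costs a small constant multiple of one energy evaluation, whatever the parameter count. Central differences need 2P evaluations for P parameters and carry truncation and round-off error.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

r = jnp.array([1.02, 1.48])  # fixed bond lengths in Å (toy data)


def energy(params):
    """Toy energy: params = [k1, r0_1, k2, r0_2], E = sum k * (r - r0)**2."""
    k = params[0::2]
    r0 = params[1::2]
    return jnp.sum(k * (r - r0) ** 2)


params = jnp.array([340.0, 1.00, 300.0, 1.50])

# One reverse-mode pass gives every component, exact up to round-off.
analytic = jax.grad(energy)(params)


def central_difference(f, x, h=1e-5):
    """Needs 2 * len(x) energy evaluations; truncation error is O(h**2)."""
    grads = []
    for i in range(x.shape[0]):
        step = jnp.zeros_like(x).at[i].set(h)
        grads.append((f(x + step) - f(x - step)) / (2 * h))
    return jnp.array(grads)


numeric = central_difference(energy, params)
```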
The energy functions use standard harmonic and 12-6 Lennard-Jones forms, not MM3's. Near
equilibrium, results are close to MM3 but not identical. For exact
MM3 parity, use `q2mm.backends.mm.openmm.OpenMMEngine`.
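The sketch below makes "close near equilibrium" concrete for bond stretching. It compares the harmonic term with MM3's anharmonic stretch, which multiplies k·Δr² by 1 + c_s·Δr + (7/12)(c_s·Δr)², with c_s = −2.55 Å⁻¹. MM3's unit prefactor is assumed to be folded into k, so both terms share units. Only stretching is compared; MM3's bending and van der Waals forms also differ from the harmonic/LJ forms used here.

```python
CS = -2.55  # MM3 cubic stretch constant, 1/Å; the quartic coefficient is (7/12) * CS**2


def harmonic_stretch(k, dr):
    """Harmonic term in this module's convention: E = k * dr**2."""
    return k * dr**2


def mm3_stretch(k, dr):
    """MM3-style stretch, assuming MM3's unit prefactor is already folded into k."""
    return k * dr**2 * (1.0 + CS * dr + (7.0 / 12.0) * (CS * dr) ** 2)


def anharmonic_ratio(dr):
    """MM3 / harmonic energy ratio for a displacement dr (Å); independent of k."""
    return mm3_stretch(1.0, dr) / harmonic_stretch(1.0, dr)


for dr in (0.005, 0.05, 0.2, -0.2):
    print(f"dr = {dr:+.3f} Å  MM3/harmonic = {anharmonic_ratio(dr):.4f}")
```

The ratio tends to 1 as Δr → 0. At a few tenths of an ångström it departs strongly and asymmetrically, because the cubic term softens stretching and stiffens compression.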
Example
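```python
from q2mm.backends.mm.jax_engine import JaxEngine

# `molecule` (a Q2MMMolecule) and `forcefield` (a ForceField) are assumed to exist already.
engine = JaxEngine()
energy = engine.energy(molecule, forcefield)  # kcal/mol
energy, grad = engine.energy_and_param_grad(molecule, forcefield)
```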
Initialize the JAX engine.
Raises:
| Type | Description |
|---|---|
| ImportError | If JAX is not installed. |
Source code in q2mm/backends/mm/jax_engine.py
name
property
Human-readable engine name.
Returns:
| Type | Description |
|---|---|
| str | Human-readable engine name. |
supported_functional_forms
JAX currently supports harmonic forms only (see issue #91 for MM3).
Returns:
| Type | Description |
|---|---|
| frozenset[str] | Functional forms supported by this engine. |
is_available
Check if JAX is installed.
Returns:
| Type | Description |
|---|---|
| bool | Whether JAX is installed. |
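A common way to check whether a package is installed, without importing it, is `importlib.util.find_spec`. The sketch below uses that approach. It is not necessarily how `is_available` is implemented, and `jax_is_available` is a hypothetical name.

```python
import importlib.util


def jax_is_available() -> bool:
    """Return True if a top-level `jax` package can be located on the import path."""
    # find_spec only locates the package; it does not execute it.
    return importlib.util.find_spec("jax") is not None


print(jax_is_available())
```

`find_spec` is cheap, but it cannot detect a package that is present yet fails at import time (for example, a mismatched `jaxlib`). An engine that needs that guarantee would instead attempt the import inside `try/except ImportError`.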
supports_runtime_params
Whether parameters can be updated without rebuilding the system.
Returns:
| Type | Description |
|---|---|
| bool | Always |
supports_analytical_gradients
Whether this engine provides analytical parameter gradients.
Returns:
| Type | Description |
|---|---|
| bool | Always |
create_context
create_context(structure, forcefield: ForceField | None = None) -> JaxHandle
Build topology and compile energy function for a molecule.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| structure | Q2MMMolecule \| JaxHandle | A `Q2MMMolecule` or `JaxHandle`. | required |
| forcefield | ForceField \| None | Force field to apply. Auto-generated from the molecule if `None`. | None |
Returns:
| Type | Description |
|---|---|
| JaxHandle | Compiled handle for energy evaluation and gradient computation. |
Raises:
| Type | Description |
|---|---|
| ValueError | If vdW parameters are defined but not all atoms have matching entries. |
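`JaxHandle` caches `_energy_fn`, `_grad_fn` and `_coord_hess_fn`. One way such a compile-once pattern can look is a closure that captures the fixed topology and returns jitted functions of `(params, coords)`. This sketch assumes a bonds-only energy and a `(n_types, 2)` parameter layout of `[k, r0]` rows. `build_bond_functions` is a hypothetical name, not the module's API.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def build_bond_functions(bond_indices, bond_param_map):
    """Capture fixed topology; return jitted energy, param-gradient and coord-Hessian functions."""
    bond_indices = jnp.asarray(bond_indices)
    bond_param_map = jnp.asarray(bond_param_map)

    def energy(params, coords):
        k = params[bond_param_map, 0]   # per-bond force constants
        r0 = params[bond_param_map, 1]  # per-bond equilibrium lengths
        d = coords[bond_indices[:, 0]] - coords[bond_indices[:, 1]]
        r = jnp.linalg.norm(d, axis=1)
        return jnp.sum(k * (r - r0) ** 2)

    energy_fn = jax.jit(energy)
    grad_fn = jax.jit(jax.grad(energy, argnums=0))             # dE/dparams
    coord_hess_fn = jax.jit(jax.hessian(energy, argnums=1))    # d2E/dcoords2
    return energy_fn, grad_fn, coord_hess_fn


energy_fn, grad_fn, coord_hess_fn = build_bond_functions([[0, 1]], [0])
coords = jnp.array([[0.0, 0.0, 0.0], [1.1, 0.0, 0.0]])

energies = []
for k in (300.0, 350.0):  # new parameter values with the same shape reuse the compiled code
    params = jnp.array([[k, 1.0]])  # one bond type: [k, r0]
    energies.append(float(energy_fn(params, coords)))

param_grad = grad_fn(params, coords)  # shape (1, 2): [dE/dk, dE/dr0]
hess = coord_hess_fn(params, coords)  # shape (2, 3, 2, 3)
```

`jax.jit` caches compilations keyed on input shapes and dtypes. Parameter updates that keep the same array shapes therefore skip retracing, which is the point of building the handle once per molecule.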
energy
Calculate energy in kcal/mol.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| structure | Q2MMMolecule \| JaxHandle | A `Q2MMMolecule`, or a `JaxHandle` as returned by `create_context`. | required |
| forcefield | ForceField | Force field parameters. | required |
Returns:
| Type | Description |
|---|---|
| float | Potential energy in kcal/mol. |
energy_and_param_grad
Compute energy and analytical gradient w.r.t. FF parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| structure | Q2MMMolecule \| JaxHandle | A `Q2MMMolecule`, or a `JaxHandle` as returned by `create_context`. | required |
| forcefield | ForceField | Force field parameters. | required |
Returns:
| Type | Description |
|---|---|
| tuple[float, ndarray] | Energy in kcal/mol and its analytical gradient with respect to the force field parameters. |
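`jax.value_and_grad` is one plausible way to obtain both return values from a single forward-plus-backward pass. The sketch below converts them to the documented `(float, np.ndarray)` types. `toy_energy` and `energy_and_param_grad_sketch` are hypothetical names, and the one-bond energy stands in for a real force field.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)


def toy_energy(params, coords):
    """Hypothetical one-bond energy: params = [k, r0], E = k * (r - r0)**2."""
    r = jnp.linalg.norm(coords[1] - coords[0])
    return params[0] * (r - params[1]) ** 2


def energy_and_param_grad_sketch(energy_fn, params, coords):
    """Energy and dE/dparams in one pass, converted to (float, np.ndarray)."""
    e, g = jax.value_and_grad(energy_fn, argnums=0)(params, coords)
    return float(e), np.asarray(g)


energy, grad = energy_and_param_grad_sketch(
    toy_energy,
    jnp.array([100.0, 1.0]),
    jnp.array([[0.0, 0.0, 0.0], [1.2, 0.0, 0.0]]),
)
```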
hessian
Compute Hessian via jax.hessian (d²E/dcoords²) in Hartree/Bohr².
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| structure | Q2MMMolecule \| JaxHandle | A `Q2MMMolecule`, or a `JaxHandle` as returned by `create_context`. | required |
| forcefield | ForceField | Force field parameters. | required |
Returns:
| Type | Description |
|---|---|
| ndarray | np.ndarray: Shape |
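Energies are in kcal/mol, so returning a Hessian in Hartree/Bohr² implies a unit conversion after `jax.hessian`. The minimal sketch below assumes coordinates in Å (not stated above) and an energy function of the coordinates alone. `coordinate_hessian_au` is a hypothetical helper. The factor is (0.529177 Å/Bohr)² / (627.5095 kcal/mol per Hartree) ≈ 4.4625 × 10⁻⁴.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)

HARTREE_TO_KCAL_MOL = 627.509474   # kcal/mol per Hartree
BOHR_TO_ANGSTROM = 0.529177210903  # Å per Bohr

# kcal/mol/Å² -> Hartree/Bohr²: multiply by (Å per Bohr)² and divide by (kcal/mol per Hartree).
KCAL_A2_TO_HARTREE_BOHR2 = BOHR_TO_ANGSTROM**2 / HARTREE_TO_KCAL_MOL


def coordinate_hessian_au(energy_fn, coords_angstrom):
    """energy_fn maps (N, 3) coords in Å to kcal/mol; returns a (3N, 3N) Hessian in Eh/a0²."""
    n = coords_angstrom.shape[0]
    h = jax.hessian(energy_fn)(coords_angstrom)  # shape (n, 3, n, 3)
    return np.asarray(h).reshape(3 * n, 3 * n) * KCAL_A2_TO_HARTREE_BOHR2


def bond_energy(coords):
    """Hypothetical one-bond energy: k = 350 kcal/mol/Å², r0 = 1.0 Å."""
    r = jnp.linalg.norm(coords[1] - coords[0])
    return 350.0 * (r - 1.0) ** 2


H = coordinate_hessian_au(bond_energy, jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]))
```

At the minimum, the along-bond diagonal element is 2k (see the E = k·(x − x₀)² convention), here 700 kcal/mol/Å², before conversion.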
frequencies
Compute vibrational frequencies in cm⁻¹ from the Hessian.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| structure | Q2MMMolecule \| JaxHandle | A `Q2MMMolecule`, or a `JaxHandle` as returned by `create_context`. | required |
| forcefield | ForceField | Force field parameters. | required |
Returns:
| Type | Description |
|---|---|
| list[float] | Vibrational frequencies in cm⁻¹, sorted ascending. |
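Below is a NumPy sketch of the standard harmonic-analysis route from a Cartesian Hessian in Hartree/Bohr² to wavenumbers: mass-weight, diagonalize, convert to SI, then divide by 2πc. Several details are assumptions of this sketch and may differ in the engine: masses in amu, negative eigenvalues reported as negative frequencies, and no projection of translations and rotations. `frequencies_cm1` is a hypothetical name.

```python
import numpy as np

# CODATA 2018 values in SI units
HARTREE_J = 4.3597447222071e-18  # J per Hartree
BOHR_M = 5.29177210903e-11       # m per Bohr
AMU_KG = 1.66053906660e-27       # kg per atomic mass unit
C_CM_PER_S = 2.99792458e10       # speed of light in cm/s


def frequencies_cm1(hessian_au, masses_amu):
    """Harmonic wavenumbers (cm^-1, ascending) from a (3N, 3N) Hessian in Eh/a0^2.

    Imaginary modes come back as negative numbers; translations and rotations
    are not projected out.
    """
    m = np.repeat(np.asarray(masses_amu, dtype=float), 3)  # one mass per Cartesian coordinate
    inv_sqrt_m = 1.0 / np.sqrt(m)
    mass_weighted = hessian_au * np.outer(inv_sqrt_m, inv_sqrt_m)  # Eh / (a0^2 amu)
    eigvals = np.linalg.eigvalsh(mass_weighted)
    eigvals_si = eigvals * HARTREE_J / (BOHR_M**2 * AMU_KG)  # s^-2
    omega = np.sign(eigvals_si) * np.sqrt(np.abs(eigvals_si))  # angular frequency, rad/s
    wavenumbers = omega / (2.0 * np.pi * C_CM_PER_S)  # cm^-1
    return sorted(wavenumbers.tolist())


# Diatomic along x with a hypothetical stretch constant (Eh/a0²), H-like masses.
k_au = 0.37
H = np.zeros((6, 6))
H[0, 0] = H[3, 3] = k_au
H[0, 3] = H[3, 0] = -k_au
freqs = frequencies_cm1(H, [1.008, 1.008])
```

For this diatomic, five eigenvalues are zero and the remaining one equals k/μ, giving a single stretch near 4400 cm⁻¹.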
minimize
Minimize energy w.r.t. coordinates using analytical JAX gradients.
Uses scipy.optimize.minimize with the L-BFGS-B method.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| structure | Q2MMMolecule \| JaxHandle | A `Q2MMMolecule`, or a `JaxHandle` as returned by `create_context`. | required |
| forcefield | ForceField | Force field parameters. | required |
| max_iterations | int | Maximum number of L-BFGS-B iterations. | 200 |
Returns:
| Type | Description |
|---|---|
| tuple | tuple[float, list[str], np.ndarray] |