
JAX Engine

jax_engine

JAX-based differentiable MM backend for analytical gradients.

Provides a pure-JAX molecular mechanics engine using harmonic bond/angle and 12-6 Lennard-Jones energy functions. Torsion energy functions are implemented but not yet active for Q2MMMolecule, which does not yet detect torsions; they will take effect once torsion matching is added. All energy functions are differentiable via jax.grad, enabling analytical gradient computation for force field parameter optimization.

ForceField stores parameters in canonical units (kcal/mol/Ų for bond_k, kcal/mol/rad² for angle_k) with energy convention E = k·(x − x₀)². The JAX energy functions use the same convention, so no unit conversion is needed at the engine boundary.
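To make the convention concrete, here is a minimal sketch of a harmonic bond term written as E = k·(r − r₀)², with no 1/2 prefactor. The function name `bond_energy` and the numeric values are assumptions made for illustration; the engine's own energy functions are not shown on this page and may be organized differently.

```python
import jax
import jax.numpy as jnp


def bond_energy(k, r0, coords, pairs):
    """Sum of k * (r - r0)**2 over bonded pairs (hypothetical helper)."""
    ri = coords[pairs[:, 0]]
    rj = coords[pairs[:, 1]]
    r = jnp.linalg.norm(ri - rj, axis=1)
    return jnp.sum(k * (r - r0) ** 2)


coords = jnp.array([[0.0, 0.0, 0.0], [1.1, 0.0, 0.0]])  # Å
pairs = jnp.array([[0, 1]])
k = jnp.array([300.0])  # kcal/mol/Å^2 (illustrative value)
r0 = jnp.array([1.0])   # Å (illustrative value)

energy = bond_energy(k, r0, coords, pairs)
# Differentiate with respect to the parameters rather than the coordinates.
dE_dk, dE_dr0 = jax.grad(bond_energy, argnums=(0, 1))(k, r0, coords, pairs)
```

Under this convention dE/dk is (r − r₀)² and dE/dr₀ is −2k·(r − r₀), which is what `jax.grad` returns here.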

For MM3-specific JAX forms, see issue #91.

JaxHandle dataclass

JaxHandle(molecule: Q2MMMolecule, bond_indices: ndarray, angle_indices: ndarray, torsion_indices: ndarray, vdw_pair_indices: ndarray, bond_param_map: ndarray, angle_param_map: ndarray, torsion_param_map: ndarray, atom_vdw_map: ndarray, n_bond_types: int, n_angle_types: int, n_torsion_types: int, n_vdw_types: int, _energy_fn: Callable | None = None, _grad_fn: Callable | None = None, _coord_hess_fn: Callable | None = None)

Cached topology and parameter mapping for a molecule.

Created once per molecule, reused across parameter updates.

Attributes:

molecule (Q2MMMolecule): Deep copy of the input molecule.
bond_indices (ndarray): Atom index pairs, shape (n_matched_bonds, 2).
angle_indices (ndarray): Atom index triples, shape (n_matched_angles, 3).
torsion_indices (ndarray): Atom index quadruples, shape (n_matched_torsions, 4).
vdw_pair_indices (ndarray): Non-excluded pairs, shape (n_vdw_pairs, 2).
bond_param_map (ndarray): Maps each matched bond → index into ForceField.bonds.
angle_param_map (ndarray): Maps each matched angle → index into ForceField.angles.
torsion_param_map (ndarray): Maps each matched torsion → index into ForceField.torsions.
atom_vdw_map (ndarray): Maps each atom → index into ForceField.vdws.
n_bond_types (int): Number of unique bond parameter types.
n_angle_types (int): Number of unique angle parameter types.
n_torsion_types (int): Number of unique torsion parameter types.
n_vdw_types (int): Number of unique vdW parameter types.
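The index arrays and parameter maps work together as a gather. The sketch below shows one plausible way a per-type parameter table is expanded to per-bond values. The arrays `bond_k_by_type` and `bond_r0_by_type` are hypothetical names, and how the engine lays out its parameter vector internally is not documented here.

```python
import numpy as np

# Hypothetical per-type parameter tables: one entry per force field bond type.
bond_k_by_type = np.array([300.0, 450.0])   # kcal/mol/Å^2 (illustrative)
bond_r0_by_type = np.array([1.09, 1.53])    # Å (illustrative)

# Topology arrays in the shapes documented for JaxHandle.
bond_indices = np.array([[0, 1], [0, 2], [0, 3]], dtype=np.int32)  # (n_matched_bonds, 2)
bond_param_map = np.array([0, 0, 1], dtype=np.int32)               # bond -> type index

# Fancy indexing expands per-type parameters to one value per matched bond.
k_per_bond = bond_k_by_type[bond_param_map]
r0_per_bond = bond_r0_by_type[bond_param_map]
```

Because the maps depend only on topology, they can be built once and reused while the per-type values change between optimization steps.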

JaxEngine

JaxEngine()

Bases: MMEngine

Differentiable MM backend using JAX with OPLSAA-style energy functions.

Provides analytical gradients of the energy with respect to force field parameters via jax.grad, eliminating the need for finite differences in parameter optimization.

The energy functions use standard harmonic/LJ forms (not MM3). Near equilibrium, results are similar to MM3 but not identical. For exact MM3 parity, use q2mm.backends.mm.openmm.OpenMMEngine.

Example

engine = JaxEngine()
energy = engine.energy(molecule, forcefield)
energy, grad = engine.energy_and_param_grad(molecule, forcefield)

Initialize the JAX engine.

Raises:

Type Description
ImportError

If JAX is not installed.

Source code in q2mm/backends/mm/jax_engine.py
def __init__(self):
    """Initialize the JAX engine.

    Raises:
        ImportError: If JAX is not installed.
    """
    _ensure_jax()

name property

name: str

Human-readable engine name.

Returns:

Name Type Description
str str

"JAX (harmonic)".

supported_functional_forms

supported_functional_forms() -> frozenset[str]

JAX currently supports harmonic forms only (see issue #91 for MM3).

Returns:

Type Description
frozenset[str]

frozenset[str]: {"harmonic"}.

Source code in q2mm/backends/mm/jax_engine.py
def supported_functional_forms(self) -> frozenset[str]:
    """JAX currently supports harmonic forms only (see issue #91 for MM3).

    Returns:
        frozenset[str]: ``{"harmonic"}``.
    """
    return frozenset({"harmonic"})

is_available

is_available() -> bool

Check if JAX is installed.

Returns:

Name Type Description
bool bool

True if the jax package is importable.

Source code in q2mm/backends/mm/jax_engine.py
def is_available(self) -> bool:
    """Check if JAX is installed.

    Returns:
        bool: ``True`` if the ``jax`` package is importable.
    """
    return _HAS_JAX

supports_runtime_params

supports_runtime_params() -> bool

Whether parameters can be updated without rebuilding the system.

Returns:

Name Type Description
bool bool

Always True for JAX.

Source code in q2mm/backends/mm/jax_engine.py
def supports_runtime_params(self) -> bool:
    """Whether parameters can be updated without rebuilding the system.

    Returns:
        bool: Always ``True`` for JAX.
    """
    return True

supports_analytical_gradients

supports_analytical_gradients() -> bool

Whether this engine provides analytical parameter gradients.

Returns:

Name Type Description
bool bool

Always True for JAX.

Source code in q2mm/backends/mm/jax_engine.py
def supports_analytical_gradients(self) -> bool:
    """Whether this engine provides analytical parameter gradients.

    Returns:
        bool: Always ``True`` for JAX.
    """
    return True

create_context

create_context(structure, forcefield: ForceField | None = None) -> JaxHandle

Build topology and compile energy function for a molecule.

Parameters:

Name Type Description Default
structure Q2MMMolecule | JaxHandle

A Q2MMMolecule or JaxHandle.

required
forcefield ForceField | None

Force field to apply. Auto-generated from the molecule if None.

None

Returns:

Name Type Description
JaxHandle JaxHandle

Compiled handle for energy evaluation and gradient computation.

Raises:

Type Description
ValueError

If vdW parameters are defined but not all atoms have matching entries.
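The reason for this check is that an unmatched atom is recorded as index -1, and negative indices wrap around instead of failing. A small sketch of the hazard, using made-up per-type values (`vdw_epsilon` is a hypothetical array, not an attribute of the engine):

```python
import jax.numpy as jnp

vdw_epsilon = jnp.array([0.10, 0.20, 0.30])  # hypothetical per-type values
atom_vdw_map = jnp.array([0, 2, -1])         # -1 marks an unmatched atom

# No error is raised: the -1 entry wraps around to the last vdW type.
per_atom = vdw_epsilon[atom_vdw_map]

# The same scan the validation performs: collect atoms still marked -1.
unmatched = [i for i, idx in enumerate(atom_vdw_map.tolist()) if idx == -1]
```

Here the third atom would silently receive the parameters of the last vdW type, so rejecting the input up front is safer than evaluating a corrupted energy.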

Source code in q2mm/backends/mm/jax_engine.py
def create_context(self, structure, forcefield: ForceField | None = None) -> JaxHandle:
    """Build topology and compile energy function for a molecule.

    Args:
        structure (Q2MMMolecule | JaxHandle): A :class:`Q2MMMolecule` or :class:`JaxHandle`.
        forcefield: Force field to apply. Auto-generated from the
            molecule if ``None``.

    Returns:
        JaxHandle: Compiled handle for energy evaluation and gradient
            computation.

    Raises:
        ValueError: If vdW parameters are defined but not all atoms
            have matching entries.
    """
    if forcefield is not None:
        self._validate_forcefield(forcefield)
    molecule = _as_molecule(structure)
    if forcefield is None:
        forcefield = ForceField.create_for_molecule(molecule)

    # Match bonds
    bond_atom_indices = []
    bond_param_map = []
    for bond in molecule.bonds:
        idx, param = _match_bond(forcefield, bond.elements, env_id=bond.env_id, ff_row=bond.ff_row)
        if param is not None:
            bond_atom_indices.append((bond.atom_i, bond.atom_j))
            bond_param_map.append(idx)

    # Match angles
    angle_atom_indices = []
    angle_param_map = []
    for angle in molecule.angles:
        idx, param = _match_angle(forcefield, angle.elements, env_id=angle.env_id, ff_row=angle.ff_row)
        if param is not None:
            angle_atom_indices.append((angle.atom_i, angle.atom_j, angle.atom_k))
            angle_param_map.append(idx)

    # Match torsions (NOTE: Q2MMMolecule does not yet detect torsions,
    # so this loop will be empty until torsion matching is added.)
    torsion_atom_indices = []
    torsion_param_map = []
    for _i_tor, torsion in enumerate(molecule.torsions if hasattr(molecule, "torsions") else []):
        for j_ff, ff_tor in enumerate(forcefield.torsions):
            if ff_tor.ff_row is not None and hasattr(torsion, "ff_row") and torsion.ff_row == ff_tor.ff_row:
                torsion_atom_indices.append((torsion.atom_i, torsion.atom_j, torsion.atom_k, torsion.atom_l))
                torsion_param_map.append(j_ff)
                break

    # Match vdW
    atom_vdw_map = []
    for atom_index, (symbol, atom_type) in enumerate(zip(molecule.symbols, molecule.atom_types, strict=False)):
        idx, param = _match_vdw(forcefield, atom_type=atom_type, element=symbol)
        if param is not None:
            atom_vdw_map.append(idx)
        else:
            atom_vdw_map.append(-1)

    # Validate vdW mapping: if the force field defines vdW parameters,
    # disallow any unmatched atoms.  Using -1 as an index would silently
    # select the last vdW entry via JAX negative indexing, corrupting
    # the energy and gradients.  When the force field defines no vdW
    # terms at all, vdW energy is effectively disabled and unmatched
    # atoms are safe.
    unmatched = [i for i, idx in enumerate(atom_vdw_map) if idx == -1]
    if getattr(forcefield, "vdws", None) and unmatched:
        raise ValueError(
            f"Unmatched vdW parameters for atoms at indices {unmatched}. "
            "Ensure all atom types/elements have corresponding vdW "
            "parameters in the force field, or remove vdW terms from "
            "the force field if vdW interactions are not intended."
        )

    # Build vdW pair list with 1-2 and 1-3 exclusions
    vdw_pairs = _build_vdw_pairs(
        len(molecule.symbols),
        [(b.atom_i, b.atom_j) for b in molecule.bonds],
    )

    bond_indices_arr = (
        np.array(bond_atom_indices, dtype=np.int32) if bond_atom_indices else np.empty((0, 2), dtype=np.int32)
    )
    angle_indices_arr = (
        np.array(angle_atom_indices, dtype=np.int32) if angle_atom_indices else np.empty((0, 3), dtype=np.int32)
    )
    torsion_indices_arr = (
        np.array(torsion_atom_indices, dtype=np.int32) if torsion_atom_indices else np.empty((0, 4), dtype=np.int32)
    )

    handle = JaxHandle(
        molecule=copy.deepcopy(molecule),
        bond_indices=bond_indices_arr,
        angle_indices=angle_indices_arr,
        torsion_indices=torsion_indices_arr,
        vdw_pair_indices=vdw_pairs,
        bond_param_map=np.array(bond_param_map, dtype=np.int32),
        angle_param_map=np.array(angle_param_map, dtype=np.int32),
        torsion_param_map=np.array(torsion_param_map, dtype=np.int32),
        atom_vdw_map=np.array(atom_vdw_map, dtype=np.int32),
        n_bond_types=len(forcefield.bonds),
        n_angle_types=len(forcefield.angles),
        n_torsion_types=len(forcefield.torsions),
        n_vdw_types=len(forcefield.vdws),
    )

    # Compile energy function
    handle._energy_fn = _compile_energy_fn(handle)
    return handle
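The source above calls `_build_vdw_pairs` with the atom count and the bond list, but its body is not shown. As a rough illustration of what a pair list with 1-2 and 1-3 exclusions involves, here is a sketch under that assumption; `build_pairs_excluding_12_13` is a hypothetical stand-in and may differ from the real helper.

```python
import itertools

import numpy as np


def build_pairs_excluding_12_13(n_atoms, bonds):
    """All atom pairs except bonded (1-2) and angle-separated (1-3) pairs."""
    neighbors = {i: set() for i in range(n_atoms)}
    excluded = set()
    for i, j in bonds:
        neighbors[i].add(j)
        neighbors[j].add(i)
        excluded.add((min(i, j), max(i, j)))  # 1-2 exclusion

    # Two atoms sharing a common neighbor form a 1-3 pair.
    for nbrs in neighbors.values():
        for a, b in itertools.combinations(sorted(nbrs), 2):
            excluded.add((a, b))

    pairs = [p for p in itertools.combinations(range(n_atoms), 2) if p not in excluded]
    # reshape keeps the (n_pairs, 2) shape even when the list is empty.
    return np.array(pairs, dtype=np.int32).reshape(-1, 2)


# Linear four-atom chain 0-1-2-3: only the 1-4 pair (0, 3) remains.
pairs = build_pairs_excluding_12_13(4, [(0, 1), (1, 2), (2, 3)])
```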

energy

energy(structure, forcefield) -> float

Calculate energy in kcal/mol.

Parameters:

Name Type Description Default
structure Q2MMMolecule | JaxHandle

A Q2MMMolecule or JaxHandle.

required
forcefield ForceField

Force field parameters.

required

Returns:

Name Type Description
float float

Potential energy in kcal/mol.

Source code in q2mm/backends/mm/jax_engine.py
def energy(self, structure, forcefield) -> float:
    """Calculate energy in kcal/mol.

    Args:
        structure (Q2MMMolecule | JaxHandle): A :class:`Q2MMMolecule` or :class:`JaxHandle`.
        forcefield (ForceField): Force field parameters.

    Returns:
        float: Potential energy in kcal/mol.
    """
    handle = self._get_handle(structure, forcefield)
    params, coords = self._params_and_coords(handle, forcefield)
    return float(handle._energy_fn(params, coords))

energy_and_param_grad

energy_and_param_grad(structure, forcefield) -> tuple[float, ndarray]

Compute energy and analytical gradient w.r.t. FF parameters.

Parameters:

Name Type Description Default
structure Q2MMMolecule | JaxHandle

A Q2MMMolecule or JaxHandle.

required
forcefield ForceField

Force field parameters.

required

Returns:

Type Description
tuple[float, ndarray]

tuple[float, np.ndarray]: (energy, grad) where energy is in kcal/mol and grad has the same shape as forcefield.get_param_vector().

Source code in q2mm/backends/mm/jax_engine.py
def energy_and_param_grad(self, structure, forcefield) -> tuple[float, np.ndarray]:
    """Compute energy and analytical gradient w.r.t. FF parameters.

    Args:
        structure (Q2MMMolecule | JaxHandle): A :class:`Q2MMMolecule` or :class:`JaxHandle`.
        forcefield (ForceField): Force field parameters.

    Returns:
        tuple[float, np.ndarray]: ``(energy, grad)`` where ``energy``
            is in kcal/mol and ``grad`` has the same shape as
            ``forcefield.get_param_vector()``.
    """
    handle = self._get_handle(structure, forcefield)
    params, coords = self._params_and_coords(handle, forcefield)

    if handle._grad_fn is None:
        handle._grad_fn = jax.jit(jax.value_and_grad(handle._energy_fn, argnums=0))

    val, grad = handle._grad_fn(params, coords)
    return float(val), np.asarray(grad)
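The pattern in this method, `jax.value_and_grad` with `argnums=0` wrapped in `jax.jit`, can be tried in isolation. The sketch below substitutes a toy single-bond energy for the compiled energy function; `toy_energy` and its two-element parameter vector are assumptions made for the example.

```python
import jax
import jax.numpy as jnp
import numpy as np


def toy_energy(params, coords):
    """One harmonic bond between atoms 0 and 1; params = [k, r0] (toy model)."""
    k, r0 = params[0], params[1]
    r = jnp.linalg.norm(coords[0] - coords[1])
    return k * (r - r0) ** 2


# argnums=0 differentiates with respect to params, leaving coords fixed.
value_and_param_grad = jax.jit(jax.value_and_grad(toy_energy, argnums=0))

params = jnp.array([300.0, 1.0])
coords = jnp.array([[0.0, 0.0, 0.0], [1.2, 0.0, 0.0]])

val, grad = value_and_param_grad(params, coords)
energy, grad = float(val), np.asarray(grad)  # plain Python/NumPy for callers
```

The gradient has the same shape as the parameter vector, so it can be passed directly to a gradient-based optimizer. Keeping the jitted function around, as the handle does with `_grad_fn`, avoids recompiling on later calls with inputs of the same shape.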

hessian

hessian(structure, forcefield) -> ndarray

Compute Hessian via jax.hessian (d²E/dcoords²) in Hartree/Bohr².

Parameters:

Name Type Description Default
structure Q2MMMolecule | JaxHandle

A Q2MMMolecule or JaxHandle.

required
forcefield ForceField

Force field parameters.

required

Returns:

Type Description
ndarray

np.ndarray: Shape (3N, 3N) Hessian in Hartree/Bohr².

Source code in q2mm/backends/mm/jax_engine.py
def hessian(self, structure, forcefield) -> np.ndarray:
    """Compute Hessian via ``jax.hessian`` (d²E/dcoords²) in Hartree/Bohr².

    Args:
        structure (Q2MMMolecule | JaxHandle): A :class:`Q2MMMolecule` or :class:`JaxHandle`.
        forcefield (ForceField): Force field parameters.

    Returns:
        np.ndarray: Shape ``(3N, 3N)`` Hessian in Hartree/Bohr².
    """
    handle = self._get_handle(structure, forcefield)
    params, coords = self._params_and_coords(handle, forcefield)

    if handle._coord_hess_fn is None:

        def _energy_of_flat_coords(flat_coords, params_):
            return handle._energy_fn(params_, flat_coords.reshape(-1, 3))

        handle._coord_hess_fn = jax.jit(jax.hessian(_energy_of_flat_coords, argnums=0))

    flat_coords = coords.flatten()
    hess_kcal_a2 = handle._coord_hess_fn(flat_coords, params)
    return np.asarray(hess_kcal_a2) * _KCALMOLA2_TO_HESSIAN_AU
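The flatten-then-differentiate step can be illustrated with a toy energy. In the sketch below, `toy_energy` is a stand-in for the compiled energy function. The conversion factor is built from assumed constants (Hartree to kcal/mol, Bohr to Å) and may not match the value of `_KCALMOLA2_TO_HESSIAN_AU` to the last digit.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)  # double precision for second derivatives

# Assumed constants for this sketch.
HARTREE_TO_KCALMOL = 627.5094740631
BOHR_TO_ANG = 0.529177210903
KCALMOLA2_TO_AU = BOHR_TO_ANG**2 / HARTREE_TO_KCALMOL  # kcal/mol/Å^2 -> Hartree/Bohr^2


def toy_energy(params, coords):
    """One harmonic bond between atoms 0 and 1; params = [k, r0] (toy model)."""
    r = jnp.linalg.norm(coords[0] - coords[1])
    return params[0] * (r - params[1]) ** 2


def energy_of_flat_coords(flat_coords, params):
    # jax.hessian of a flat vector yields a 2-D (3N, 3N) matrix directly.
    return toy_energy(params, flat_coords.reshape(-1, 3))


hess_fn = jax.jit(jax.hessian(energy_of_flat_coords, argnums=0))

params = jnp.array([300.0, 1.0])
flat = jnp.array([0.0, 0.0, 0.0, 1.0, 0.0, 0.0])  # bond at its reference length
hess_kcal = np.asarray(hess_fn(flat, params))      # kcal/mol/Å^2
hess_au = hess_kcal * KCALMOLA2_TO_AU               # Hartree/Bohr^2
```

With E = k·(r − r₀)², the second derivative along the bond at r = r₀ is 2k, so the x-x block of this Hessian is ±600 kcal/mol/Å² before conversion.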

frequencies

frequencies(structure, forcefield) -> list[float]

Compute vibrational frequencies in cm⁻¹ from the Hessian.

Parameters:

Name Type Description Default
structure Q2MMMolecule | JaxHandle

A Q2MMMolecule or JaxHandle.

required
forcefield ForceField

Force field parameters.

required

Returns:

Type Description
list[float]

list[float]: Vibrational frequencies in cm⁻¹, sorted ascending.

Source code in q2mm/backends/mm/jax_engine.py
def frequencies(self, structure, forcefield) -> list[float]:
    """Compute vibrational frequencies in cm⁻¹ from the Hessian.

    Args:
        structure (Q2MMMolecule | JaxHandle): A :class:`Q2MMMolecule` or :class:`JaxHandle`.
        forcefield (ForceField): Force field parameters.

    Returns:
        list[float]: Vibrational frequencies in cm⁻¹, sorted ascending.
    """
    handle = self._get_handle(structure, forcefield)
    hess_au = self.hessian(handle, forcefield)
    n_atoms = len(handle.molecule.symbols)

    masses = np.array([MASSES[s] for s in handle.molecule.symbols])
    mass_weights = np.repeat(masses, 3)
    sqrt_inv_mass = 1.0 / np.sqrt(mass_weights)
    mw_hess = hess_au * np.outer(sqrt_inv_mass, sqrt_inv_mass)

    eigenvalues = np.linalg.eigvalsh(mw_hess)

    # Convert eigenvalues (Hartree / (amu * Bohr²)) → cm⁻¹
    hartree_to_j = 4.359744650e-18
    bohr_to_m = BOHR_TO_ANG * 1e-10
    factor = hartree_to_j / (AMU_TO_KG * bohr_to_m**2)

    freqs = []
    for ev in eigenvalues:
        val = ev * factor
        if val < 0:
            freq_hz = -math.sqrt(-val)
        else:
            freq_hz = math.sqrt(val)
        freq_cm = freq_hz / (2.0 * math.pi * SPEED_OF_LIGHT_MS * 100.0)
        freqs.append(freq_cm)

    return sorted(freqs)
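The mass-weighting and unit conversion can be checked on a diatomic, where the single stretching frequency has the closed form ω = √(k/μ). The sketch below is a vectorized NumPy version of the same steps. The constants are defined locally as assumptions rather than imported from the package, and the force constant is an arbitrary illustrative value.

```python
import math

import numpy as np

# Physical constants, defined locally for this sketch.
HARTREE_TO_J = 4.3597447222071e-18
BOHR_TO_M = 0.529177210903e-10
AMU_TO_KG = 1.66053906660e-27
SPEED_OF_LIGHT_MS = 299792458.0


def frequencies_cm1(hess_au, masses_amu):
    """Signed harmonic frequencies (cm^-1) from a (3N, 3N) Hessian in Hartree/Bohr^2."""
    inv_sqrt_m = 1.0 / np.sqrt(np.repeat(masses_amu, 3))
    mw_hess = hess_au * np.outer(inv_sqrt_m, inv_sqrt_m)
    eigenvalues = np.linalg.eigvalsh(mw_hess)
    # Hartree / (amu * Bohr^2) -> s^-2
    si = eigenvalues * HARTREE_TO_J / (AMU_TO_KG * BOHR_TO_M**2)
    # Negative eigenvalues are reported as negative frequencies (imaginary modes).
    omega = np.sign(si) * np.sqrt(np.abs(si))
    return np.sort(omega / (2.0 * math.pi * SPEED_OF_LIGHT_MS * 100.0))


# Diatomic along x with curvature k = d²E/dr² in Hartree/Bohr^2 (illustrative).
k = 0.37
hess = np.zeros((6, 6))
hess[0, 0] = hess[3, 3] = k
hess[0, 3] = hess[3, 0] = -k
masses = np.array([1.00782503, 1.00782503])  # amu
freqs = frequencies_cm1(hess, masses)
```

The largest entry of `freqs` is the stretch; the remaining five are the translations and rotations, which come out at essentially zero for this model Hessian.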

minimize

minimize(structure, forcefield, max_iterations=200) -> tuple

Minimize energy w.r.t. coordinates using analytical JAX gradients.

Uses scipy.optimize.minimize with the L-BFGS-B method.

Parameters:

Name Type Description Default
structure Q2MMMolecule | JaxHandle

A Q2MMMolecule or JaxHandle.

required
forcefield ForceField

Force field parameters.

required
max_iterations int

Maximum number of L-BFGS-B iterations.

200

Returns:

Type Description
tuple

tuple[float, list[str], np.ndarray]: (energy, atoms, coords) where energy is in kcal/mol and coords are in Å.

Source code in q2mm/backends/mm/jax_engine.py
def minimize(self, structure, forcefield, max_iterations=200) -> tuple:
    """Minimize energy w.r.t. coordinates using analytical JAX gradients.

    Uses ``scipy.optimize.minimize`` with the L-BFGS-B method.

    Args:
        structure (Q2MMMolecule | JaxHandle): A :class:`Q2MMMolecule` or :class:`JaxHandle`.
        forcefield (ForceField): Force field parameters.
        max_iterations (int): Maximum number of L-BFGS-B iterations.

    Returns:
        tuple[float, list[str], np.ndarray]: ``(energy, atoms, coords)``
            where energy is in kcal/mol and coords are in Å.
    """
    from scipy.optimize import minimize as scipy_minimize

    handle = self._get_handle(structure, forcefield)
    params, coords = self._params_and_coords(handle, forcefield)

    energy_fn = handle._energy_fn
    coord_grad_fn = jax.jit(jax.grad(lambda c, p: energy_fn(p, c.reshape(-1, 3)), argnums=0))

    x0 = np.asarray(coords.flatten())

    def objective(x):
        return float(energy_fn(params, jnp.array(x).reshape(-1, 3)))

    def gradient(x):
        return np.asarray(coord_grad_fn(jnp.array(x), params))

    result = scipy_minimize(
        objective,
        x0,
        jac=gradient,
        method="L-BFGS-B",
        options={"maxiter": max_iterations},
    )

    opt_coords = result.x.reshape(-1, 3)
    opt_energy = float(result.fun)
    return opt_energy, list(handle.molecule.symbols), opt_coords