Source code for chemtorch.components.representation.graph.reaction_3d_graph

import os.path as osp
from typing import List, Tuple, Optional

import torch
from rdkit import Chem
from torch_geometric.data import Data
try:
    # Python ≥ 3.12
    from typing import override  # type: ignore
except ImportError:
    # Python < 3.12
    from typing_extensions import override  # type: ignore

from chemtorch.components.representation.abstract_representation import (
    AbstractRepresentation,
)


def read_xyz(file_path: str) -> Tuple[List[str], torch.Tensor]:
    """
    Reads a standard XYZ file and returns atomic symbols and coordinates.

    Expected layout (standard XYZ):
      - line 1: the number of atoms,
      - line 2: a free-form comment,
      - one ``<symbol> <x> <y> <z>`` line per atom (extra columns are ignored).

    When line 1 is an integer, only that many atom lines are read and the count
    is validated, so additional frames or trailing text do not leak into the
    result. If line 1 is not an integer, every line after the comment that has
    at least four fields is treated as an atom line.

    Args:
        file_path (str): The path to the .xyz file.

    Returns:
        A tuple containing a list of atomic symbols (str) and a float tensor of
        atomic coordinates with shape ``[num_atoms, 3]`` (``[0, 3]`` when the
        file contains no atoms).

    Raises:
        FileNotFoundError: If the file does not exist.
        ValueError: If the file is malformed or cannot be parsed, including when
            fewer atom lines are found than the header declares.
    """
    if not osp.exists(file_path):
        raise FileNotFoundError(f"XYZ file not found at: {file_path}")

    try:
        with open(file_path, "r", encoding="utf-8") as f:
            lines = f.readlines()

        # The first line declares the atom count; use it (when it is an
        # integer) to bound the atom block to exactly one frame.
        header = lines[0].split() if lines else []
        declared = int(header[0]) if header and header[0].isdigit() else None
        atom_lines = lines[2:] if declared is None else lines[2 : 2 + declared]

        atomic_symbols: List[str] = []
        coords: List[List[float]] = []
        for line in atom_lines:
            parts = line.split()
            if len(parts) >= 4:
                atomic_symbols.append(parts[0])
                coords.append([float(p) for p in parts[1:4]])
    except (IOError, IndexError, ValueError) as e:
        raise ValueError(
            f"Error reading or parsing XYZ file {file_path}: {e}"
        ) from e

    # At most `declared` lines were read, so a mismatch means the file is truncated.
    if declared is not None and len(atomic_symbols) != declared:
        raise ValueError(
            f"Error reading or parsing XYZ file {file_path}: header declares "
            f"{declared} atoms but {len(atomic_symbols)} were found"
        )

    # Keep the documented [num_atoms, 3] shape even when there are no atoms.
    positions = (
        torch.tensor(coords, dtype=torch.float)
        if coords
        else torch.empty((0, 3), dtype=torch.float)
    )
    return atomic_symbols, positions
def symbols_to_atomic_numbers(symbols: List[str]) -> torch.Tensor:
    """
    Converts a list of atomic symbols (e.g., ['C', 'H']) to a tensor of atomic numbers.

    Symbols are matched case-insensitively ('CL' and 'cl' are both read as
    'Cl'). Purely numeric entries (e.g. '6') are taken as atomic numbers
    directly, because some XYZ writers emit that form.

    Args:
        symbols (List[str]): A list of atomic symbols.

    Returns:
        A 1-D ``torch.long`` tensor of atomic numbers, one per input symbol.

    Raises:
        ValueError: If a symbol cannot be mapped to an atomic number.
    """
    pt = Chem.GetPeriodicTable()
    atomic_nums: List[int] = []
    for symbol in symbols:
        token = symbol.strip()
        if token.isdigit():
            atomic_nums.append(int(token))
            continue
        # RDKit's lookup is case-sensitive; normalize to 'Xx' capitalization.
        normalized = token[:1].upper() + token[1:].lower()
        try:
            atomic_nums.append(pt.GetAtomicNumber(normalized))
        except Exception as e:  # RDKit surfaces lookup failures as runtime errors
            raise ValueError(
                f"Error converting symbols to atomic numbers: {e}"
            ) from e
    return torch.tensor(atomic_nums, dtype=torch.long)
class Reaction3DData(Data):
    """
    Custom PyG Data class for reaction 3D graphs.

    Attributes:
        z_r (torch.Tensor): Atomic numbers for the reactant.
        pos_r (torch.Tensor): Atomic positions for the reactant.
        z_ts (torch.Tensor): Atomic numbers for the transition state.
        pos_ts (torch.Tensor): Atomic positions for the transition state.
    """

    z_r: torch.Tensor
    pos_r: torch.Tensor
    z_ts: torch.Tensor
    pos_ts: torch.Tensor

    def __init__(
        self,
        z_r: Optional[torch.Tensor] = None,
        pos_r: Optional[torch.Tensor] = None,
        z_ts: Optional[torch.Tensor] = None,
        pos_ts: Optional[torch.Tensor] = None,
        smiles: Optional[str] = None,
        num_nodes: Optional[int] = None,
        **kwargs,
    ):
        """
        Args:
            z_r: Atomic numbers for the reactant.
            pos_r: Atomic positions for the reactant.
            z_ts: Atomic numbers for the transition state.
            pos_ts: Atomic positions for the transition state.
            smiles: Optional reaction SMILES string kept for reference.
            num_nodes: Optional explicit number of atoms.
            **kwargs: Any additional attributes, forwarded to ``Data``.

        Every argument defaults to ``None``. PyTorch Geometric instantiates Data
        subclasses with no arguments when it splits a batch back into samples
        (e.g. ``Batch.to_data_list()``, ``batch[i]``) and when it reads samples
        from an ``InMemoryDataset``. A constructor with required arguments
        fails in those paths.
        """
        # Data.__init__ assigns every keyword argument as an attribute, so no
        # further per-attribute assignment is needed here.
        super().__init__(
            z_r=z_r,
            pos_r=pos_r,
            z_ts=z_ts,
            pos_ts=pos_ts,
            smiles=smiles,
            num_nodes=num_nodes,
            **kwargs,
        )
class Reaction3DGraph(AbstractRepresentation[Data]):
    """
    Constructs a 3D representation of a reaction from XYZ files.

    This representation reads the 3D structures for a reactant (r) and
    transition state (ts) from their respective .xyz files. It packages them
    into a single PyTorch Geometric `Data` object (specifically, a
    `Reaction3DData` object).

    Expected on-disk layout, shown for reaction ID ``42``::

        <root_dir>/
            rxn000042/
                r000042.xyz
                ts000042.xyz

    The resulting `Reaction3DData` object for each sample contains:
    - `z_r`, `pos_r`: Atomic numbers and 3D coordinates for the reactant
    - `z_ts`, `pos_ts`: Atomic numbers and 3D coordinates for the transition state
    - `smiles`: The reaction SMILES string (for reference)
    - `num_nodes`: Total number of atoms in the structure

    This representation is particularly useful for models like DimeReaction
    that operate on 3D geometries of chemical reactions.
    """

    def __init__(self, root_dir: str):
        """
        Args:
            root_dir (str): The root directory containing one subfolder per
                reaction. Each subfolder is named ``rxn`` followed by the
                6-digit zero-padded reaction ID (e.g. ``rxn000001``,
                ``rxn000042``).

        Raises:
            FileNotFoundError: If ``root_dir`` is not an existing directory.
        """
        if not osp.isdir(root_dir):
            raise FileNotFoundError(
                f"The specified root directory does not exist: {root_dir}"
            )
        self.root_dir = root_dir

    @override
    def construct(self, smiles: str, reaction_dir: str) -> Data:
        """
        Constructs a single reaction graph from its corresponding XYZ files.

        This method is called by `DatasetBase` for each row in the DataFrame.
        It reads ``r<ID>.xyz`` and ``ts<ID>.xyz`` from ``<root_dir>/rxn<ID>/``,
        where ``<ID>`` is ``reaction_dir`` zero-padded to 6 digits.

        Args:
            smiles (str): The reaction SMILES string.
            reaction_dir (str): The reaction ID (e.g. '1', '42'; ints and
                integral floats such as 42.0 are also accepted). It is
                zero-padded to 6 digits (e.g. '000001', '000042').

        Returns:
            A `Reaction3DData` object containing the 3D structures with attributes:
            - z_r, pos_r: Atomic numbers and positions for reactant
            - z_ts, pos_ts: Atomic numbers and positions for transition state
            - smiles: The reaction SMILES
            - num_nodes: Number of atoms

        Raises:
            FileNotFoundError: If the reaction directory or any required .xyz
                file is not found.
            ValueError: If the number of atoms is inconsistent between reactant
                and TS structures, or if an XYZ file cannot be parsed.
        """
        # IDs taken from a DataFrame column that pandas upcast to float
        # (e.g. because of NaNs) arrive as 42.0. str(42.0).zfill(6) would give
        # '0042.0', so convert integral floats back to int first.
        if isinstance(reaction_dir, float) and reaction_dir.is_integer():
            reaction_dir = int(reaction_dir)
        reaction_dir = str(reaction_dir).zfill(6)

        folder_path = osp.join(self.root_dir, f"rxn{reaction_dir}")
        if not osp.isdir(folder_path):
            raise FileNotFoundError(f"Reaction directory not found: {folder_path}")

        structures = {}
        for state in ["r", "ts"]:
            file_path = osp.join(folder_path, f"{state}{reaction_dir}.xyz")
            symbols, pos = read_xyz(file_path)
            z = symbols_to_atomic_numbers(symbols)
            structures[state] = {"z": z, "pos": pos}

        # NOTE(review): only atom counts are compared; the element order of
        # r vs ts (z_r vs z_ts) is not checked. Confirm whether downstream
        # models assume identical atom ordering.
        num_atoms = structures["ts"]["pos"].shape[0]
        if not all(s["pos"].shape[0] == num_atoms for s in structures.values()):
            raise ValueError(
                f"Inconsistent number of atoms in reaction {reaction_dir}."
            )

        return Reaction3DData(
            z_r=structures["r"]["z"],
            pos_r=structures["r"]["pos"],
            z_ts=structures["ts"]["z"],
            pos_ts=structures["ts"]["pos"],
            smiles=smiles,
            num_nodes=num_atoms,
        )