import os.path as osp
from typing import List, Tuple, Optional
import torch
from rdkit import Chem
from torch_geometric.data import Data
try:
# Python ≥ 3.12
from typing import override # type: ignore
except ImportError:
# Python < 3.12
from typing_extensions import override # type: ignore
from chemtorch.components.representation.abstract_representation import (
AbstractRepresentation,
)
[docs]
def read_xyz(file_path: str) -> Tuple[List[str], torch.Tensor]:
    """
    Reads a standard XYZ file and returns atomic symbols and coordinates.

    The first two lines of the file (atom count and comment line) are skipped
    without being interpreted. Every following line that has at least four
    whitespace-separated fields is read as ``symbol x y z``; lines with fewer
    fields (for example blank trailing lines) are ignored.

    Args:
        file_path (str): The path to the .xyz file.

    Returns:
        A tuple containing a list of atomic symbols (str) and a float tensor
        of atomic coordinates ([num_atoms, 3]). When the file contains no
        atom lines the tensor has shape [0, 3].

    Raises:
        FileNotFoundError: If the file does not exist.
        ValueError: If the file is malformed or cannot be parsed. The
            underlying exception is attached as ``__cause__``.
    """
    if not osp.exists(file_path):
        raise FileNotFoundError(f"XYZ file not found at: {file_path}")

    atomic_symbols: List[str] = []
    coords: List[List[float]] = []
    try:
        with open(file_path, "r") as f:
            # Stream the file instead of materialising every line up front.
            for line_index, line in enumerate(f):
                if line_index < 2:
                    # Header: atom count and free-form comment.
                    continue
                parts = line.split()
                if len(parts) >= 4:
                    atomic_symbols.append(parts[0])
                    coords.append([float(p) for p in parts[1:4]])
    except (OSError, IndexError, ValueError) as e:
        # Chain the cause so the original traceback is not lost.
        raise ValueError(
            f"Error reading or parsing XYZ file {file_path}: {e}"
        ) from e

    # An explicit target shape keeps the result two-dimensional even when no
    # atoms were read (torch.tensor([]) alone would have shape [0]).
    positions = torch.tensor(coords, dtype=torch.float).reshape(len(coords), 3)
    return atomic_symbols, positions
[docs]
def symbols_to_atomic_numbers(symbols: List[str]) -> torch.Tensor:
    """
    Converts a list of atomic symbols (e.g., ['C', 'H']) to a tensor of atomic numbers.

    Each symbol is looked up through RDKit's periodic table
    (``Chem.GetPeriodicTable().GetAtomicNumber``).

    Args:
        symbols (List[str]): A list of atomic symbols.

    Returns:
        A ``torch.long`` tensor of atomic numbers, one entry per symbol.

    Raises:
        ValueError: If any lookup or the tensor construction fails. The
            underlying exception is attached as ``__cause__``.
    """
    pt = Chem.GetPeriodicTable()
    try:
        atomic_nums = [pt.GetAtomicNumber(s) for s in symbols]
        return torch.tensor(atomic_nums, dtype=torch.long)
    except Exception as e:
        # NOTE(review): the exception type RDKit raises for an unknown symbol
        # is not visible from this file, so the catch stays broad. Chaining
        # with ``from e`` keeps the original traceback available to callers.
        raise ValueError(
            f"Error converting symbols to atomic numbers: {e}"
        ) from e
[docs]
class Reaction3DData(Data):
    """
    Custom PyG Data class for reaction 3D graphs.

    Attributes:
        z_r (torch.Tensor): Atomic numbers for the reactant.
        pos_r (torch.Tensor): Atomic positions for the reactant.
        z_ts (torch.Tensor): Atomic numbers for the transition state.
        pos_ts (torch.Tensor): Atomic positions for the transition state.
    """

    z_r: torch.Tensor
    pos_r: torch.Tensor
    z_ts: torch.Tensor
    pos_ts: torch.Tensor

    def __init__(
        self,
        z_r: Optional[torch.Tensor] = None,
        pos_r: Optional[torch.Tensor] = None,
        z_ts: Optional[torch.Tensor] = None,
        pos_ts: Optional[torch.Tensor] = None,
        smiles: Optional[str] = None,
        num_nodes: Optional[int] = None,
        **kwargs
    ):
        """
        Stores reactant and transition-state geometry on the data object.

        All arguments default to ``None`` so that the class can also be
        instantiated without arguments.
        NOTE(review): PyG helpers that rebuild individual samples from a
        batch appear to call the data class with no arguments; confirm
        against the installed torch_geometric version.

        Args:
            z_r: Atomic numbers for the reactant.
            pos_r: Atomic positions for the reactant.
            z_ts: Atomic numbers for the transition state.
            pos_ts: Atomic positions for the transition state.
            smiles: Optional reaction SMILES string kept for reference.
            num_nodes: Optional explicit number of nodes (atoms).
            **kwargs: Additional attributes forwarded to ``Data.__init__``.
        """
        # Data.__init__ receives every field as a keyword argument and is
        # relied on to store them, so they are not re-assigned afterwards.
        super().__init__(
            z_r=z_r,
            pos_r=pos_r,
            z_ts=z_ts,
            pos_ts=pos_ts,
            smiles=smiles,
            num_nodes=num_nodes,
            **kwargs
        )
[docs]
class Reaction3DGraph(AbstractRepresentation[Data]):
    """
    Builds a 3D reaction representation from XYZ files on disk.

    For every sample, the reactant (r) and transition-state (ts) geometries
    are read from their .xyz files and bundled into one PyTorch Geometric
    `Data` object (concretely a `Reaction3DData` instance) holding:

    - `z_r`, `pos_r`: atomic numbers and 3D coordinates of the reactant
    - `z_ts`, `pos_ts`: atomic numbers and 3D coordinates of the transition state
    - `smiles`: the reaction SMILES string (for reference)
    - `num_nodes`: number of atoms, taken from the transition state

    Intended for models that operate on 3D geometries of chemical reactions
    (e.g. DimeReaction).
    """

    def __init__(self, root_dir: str):
        """
        Args:
            root_dir (str): Directory that contains one subfolder per
                reaction, named ``rxn`` followed by the zero-padded reaction
                ID (e.g. 'rxn000001', 'rxn000042').

        Raises:
            FileNotFoundError: If `root_dir` is not an existing directory.
        """
        root_is_dir = osp.isdir(root_dir)
        if not root_is_dir:
            message = f"The specified root directory does not exist: {root_dir}"
            raise FileNotFoundError(message)
        self.root_dir = root_dir

    @override
    def construct(self, smiles: str, reaction_dir: str) -> Data:
        """
        Builds the reaction graph for one reaction from its XYZ files.

        The reaction ID is converted to a string and zero-padded to six
        characters. The files ``r<id>.xyz`` and ``ts<id>.xyz`` are then read
        from ``<root_dir>/rxn<id>``, reactant first.

        Args:
            smiles (str): The reaction SMILES string.
            reaction_dir (str): The name/ID of the reaction (e.g. '1', '42');
                it is zero-padded to 6 digits (e.g. '000001', '000042').

        Returns:
            A `Reaction3DData` object with the attributes:
            - z_r, pos_r: atomic numbers and positions for the reactant
            - z_ts, pos_ts: atomic numbers and positions for the transition state
            - smiles: the reaction SMILES
            - num_nodes: number of atoms

        Raises:
            FileNotFoundError: If the reaction directory or a required .xyz
                file is not found.
            ValueError: If reactant and transition state differ in their
                number of atoms, or if an XYZ file cannot be parsed.
        """
        padded_id = str(reaction_dir).zfill(6)
        rxn_folder = osp.join(self.root_dir, f"rxn{padded_id}")
        if not osp.isdir(rxn_folder):
            raise FileNotFoundError(f"Reaction directory not found: {rxn_folder}")

        def load_state(state: str):
            # Read one geometry and translate its symbols to atomic numbers.
            xyz_path = osp.join(rxn_folder, f"{state}{padded_id}.xyz")
            symbols, positions = read_xyz(xyz_path)
            return symbols_to_atomic_numbers(symbols), positions

        z_r, pos_r = load_state("r")
        z_ts, pos_ts = load_state("ts")

        # The transition state defines the reference atom count.
        num_atoms = pos_ts.shape[0]
        if pos_r.shape[0] != num_atoms:
            raise ValueError(
                f"Inconsistent number of atoms in reaction {padded_id}."
            )

        return Reaction3DData(
            z_r=z_r,
            pos_r=pos_r,
            z_ts=z_ts,
            pos_ts=pos_ts,
            smiles=smiles,
            num_nodes=num_atoms,
        )