Components#

The components package contains all swappable building blocks for the ChemTorch pipeline.

Data Pipeline#

Data Source#

class chemtorch.components.data_pipeline.data_source.abstract_data_source.AbstractDataSource[source]#

Bases: ABC

Abstract base class for data sources.

This class defines the interface for loading data from various sources. Subclasses should implement the load method to provide specific data loading functionality.

abstract load()[source]#
class chemtorch.components.data_pipeline.data_source.single_csv_source.SingleCSVSource(data_path: str)[source]#

Bases: AbstractDataSource

__init__(data_path: str)[source]#
load() DataFrame[source]#

Load data from a single CSV file.

class chemtorch.components.data_pipeline.data_source.pre_split_csv_source.PreSplitCSVSource(data_folder: str)[source]#

Bases: AbstractDataSource

__init__(data_folder: str)[source]#
load() DataSplit[source]#

Load pre-split data from CSV files in a specified folder. The files must be named ‘train.csv’, ‘val.csv’, and ‘test.csv’.
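A loader with this behaviour can be sketched with pandas. The `DataSplit` container below is a stand-in for illustration; the actual chemtorch `DataSplit` class may differ:

```python
import os
from dataclasses import dataclass

import pandas as pd


@dataclass
class DataSplit:
    # Hypothetical container for illustration; chemtorch's DataSplit may differ.
    train: pd.DataFrame
    val: pd.DataFrame
    test: pd.DataFrame


def load_pre_split_csvs(data_folder: str) -> DataSplit:
    """Read train.csv, val.csv, and test.csv from data_folder."""
    frames = {}
    for name in ("train", "val", "test"):
        path = os.path.join(data_folder, f"{name}.csv")
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Expected split file missing: {path}")
        frames[name] = pd.read_csv(path)
    return DataSplit(frames["train"], frames["val"], frames["test"])
```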

Column Mapper#

class chemtorch.components.data_pipeline.column_mapper.abstract_column_mapper.AbstractColumnMapper[source]#

Bases: ABC

class chemtorch.components.data_pipeline.column_mapper.column_filter_rename.ColumnFilterAndRename(**column_mappings)[source]#

Bases: AbstractColumnMapper

A pipeline component that filters and renames columns in a DataFrame based on provided column mappings.

Usage:

mapper = ColumnFilterAndRename(smiles="smiles_column", label="target_column")  # renames "smiles_column" to "smiles" and "target_column" to "label"

__init__(**column_mappings)[source]#

Initialize the ColumnFilterAndRename.

Parameters:

**column_mappings – Keyword arguments where the key is the new column name and the value is the existing column name in the DataFrame. Example: smiles="smiles_column", label="target_column"
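The filter-and-rename semantics can be sketched with pandas; this is a minimal illustration of the documented behaviour, not the actual implementation:

```python
import pandas as pd


def filter_and_rename(df: pd.DataFrame, **column_mappings: str) -> pd.DataFrame:
    """Keep only the mapped columns, renamed so that new_name <- old_name."""
    missing = [old for old in column_mappings.values() if old not in df.columns]
    if missing:
        raise KeyError(f"Columns not found in DataFrame: {missing}")
    # Select only the mapped columns, then rename old -> new.
    selected = df[list(column_mappings.values())]
    return selected.rename(columns={old: new for new, old in column_mappings.items()})


df = pd.DataFrame({"smiles_column": ["CCO"], "target_column": [0.5], "extra": [1]})
out = filter_and_rename(df, smiles="smiles_column", label="target_column")
```

Unmapped columns (here, "extra") are dropped rather than passed through.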

Data Splitter#

class chemtorch.components.data_pipeline.data_splitter.abstract_data_splitter.AbstractDataSplitter[source]#

Bases: ABC

Abstract base class for data splitting strategies.

Subclasses should implement the __call__ method to define the splitting logic.

class chemtorch.components.data_pipeline.data_splitter.data_splitter_base.DataSplitterBase(save_path: str | None = None)[source]#

Bases: AbstractDataSplitter

Base class for data splitting strategies. A callable that takes a DataFrame, executes the splitting logic, and returns a DataSplit object. Subclasses should implement the private _split method to define specific splitting logic.

__init__(save_path: str | None = None) None[source]#

Initializes the DataSplitter.

Parameters:

save_path (str | None, optional) – If provided, saves split indices as a pickle file to this path. Must end with ‘.pkl’. Defaults to None.

Raises:

ValueError – If the provided save_path is not a pickle file (i.e. does not end with ‘.pkl’).

class chemtorch.components.data_pipeline.data_splitter.ratio_splitter.RatioSplitter(train_ratio: float = 0.8, val_ratio: float = 0.1, test_ratio: float = 0.1, save_path: str | None = None)[source]#

Bases: DataSplitterBase

Splits data into training, validation, and test sets based on specified ratios.

Subclasses should override the _split method to implement custom splitting logic. By default the RatioSplitter randomly shuffles the data and splits it according to the specified ratios.

__init__(train_ratio: float = 0.8, val_ratio: float = 0.1, test_ratio: float = 0.1, save_path: str | None = None)[source]#

Initializes a RatioSplitter.

Parameters:
  • train_ratio (float) – The ratio of data for training.

  • val_ratio (float) – The ratio of data for validation.

  • test_ratio (float) – The ratio of data for testing.

  • save_path (str | None, optional) – If provided, saves split indices as pickle file.
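The default behaviour described above (shuffle, then cut by ratio) can be sketched in plain Python; this illustrates the idea and is not the package's code:

```python
import random


def ratio_split(items, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, seed=0):
    """Shuffle indices, then cut into train/val/test according to the ratios."""
    if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-8:
        raise ValueError("train_ratio + val_ratio + test_ratio must equal 1")
    idx = list(range(len(items)))
    random.Random(seed).shuffle(idx)
    n_train = int(train_ratio * len(items))
    n_val = int(val_ratio * len(items))
    pick = lambda ids: [items[i] for i in ids]
    return (pick(idx[:n_train]),
            pick(idx[n_train:n_train + n_val]),
            pick(idx[n_train + n_val:]))
```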

class chemtorch.components.data_pipeline.data_splitter.scaffold_splitter.ScaffoldSplitter(train_ratio: float = 0.8, val_ratio: float = 0.1, test_ratio: float = 0.1, include_chirality: bool = False, split_on: str | None = None, mol_idx: int | None = None, save_path: str | None = None)[source]#

Bases: SMILESGroupSplitterBase

__init__(train_ratio: float = 0.8, val_ratio: float = 0.1, test_ratio: float = 0.1, include_chirality: bool = False, split_on: str | None = None, mol_idx: int | None = None, save_path: str | None = None)[source]#

Initializes the ScaffoldSplitter.

Splits data by grouping molecules based on their Murcko scaffold, ensuring that all molecules with the same scaffold are in the same split (train, val, or test). This is a standard method to test a model’s ability to generalize to new chemical scaffolds.

Parameters:
  • train_ratio (float) – The desired ratio of data for the training set.

  • val_ratio (float) – The desired ratio of data for the validation set.

  • test_ratio (float) – The desired ratio of data for the test set.

  • include_chirality (bool) – If True, includes chirality in the scaffold generation.

  • split_on (str | None) – Specifies whether to use the ‘reactant’ or ‘product’ for scaffold generation when processing reaction SMILES. Required for reaction SMILES, ignored for single molecules. Defaults to None.

  • mol_idx (int | None) – Zero-based index specifying which molecule to use if multiple are present (e.g., ‘A.B>>C’ or ‘A.B’). Required when multiple molecules are present in the selected part of the reaction. Defaults to None.

  • save_path (str | None, optional) – If provided, saves split indices as pickle file.
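The group-keeping guarantee (every molecule with a given scaffold lands in exactly one split) can be sketched with a generic key function; in the real splitter the key would be a Murcko scaffold SMILES. The greedy fill below is an illustration, not necessarily the algorithm used by SMILESGroupSplitterBase:

```python
from collections import defaultdict


def group_split(items, key, train_ratio=0.8, val_ratio=0.1):
    """Assign whole groups (e.g. one scaffold each) to train/val/test."""
    groups = defaultdict(list)
    for i, item in enumerate(items):
        groups[key(item)].append(i)
    # Place whole groups, largest first: train until full, then val, then test.
    ordered = sorted(groups.values(), key=len, reverse=True)
    n = len(items)
    train, val, test = [], [], []
    for members in ordered:
        if len(train) + len(members) <= train_ratio * n:
            train.extend(members)
        elif len(val) + len(members) <= val_ratio * n:
            val.extend(members)
        else:
            test.extend(members)
    return train, val, test
```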

class chemtorch.components.data_pipeline.data_splitter.target_splitter.TargetSplitter(train_ratio: float = 0.8, val_ratio: float = 0.1, test_ratio: float = 0.1, sort_order: str = 'ascending', save_path: str | None = None)[source]#

Bases: RatioSplitter

Splits data into training, validation, and test sets after sorting the data by target value in the specified order.

__init__(train_ratio: float = 0.8, val_ratio: float = 0.1, test_ratio: float = 0.1, sort_order: str = 'ascending', save_path: str | None = None)[source]#

Initializes the TargetSplitter.

Parameters:
  • train_ratio (float) – The ratio of data for the training set.

  • val_ratio (float) – The ratio of data for the validation set.

  • test_ratio (float) – The ratio of data for the test set.

  • sort_order (str) – ‘ascending’ or ‘descending’.

  • save_path (str | None, optional) – Path to save split indices.

class chemtorch.components.data_pipeline.data_splitter.size_splitter.SizeSplitter(train_ratio: float = 0.8, val_ratio: float = 0.1, test_ratio: float = 0.1, sort_order: str = 'ascending', save_path: str | None = None)[source]#

Bases: RatioSplitter

Splits data into training, validation, and test sets based on molecular size (number of heavy atoms). For single molecules, sorts by the number of heavy atoms in each molecule. For reactions, sorts by the sum of heavy atoms in reactants and products.

__init__(train_ratio: float = 0.8, val_ratio: float = 0.1, test_ratio: float = 0.1, sort_order: str = 'ascending', save_path: str | None = None)[source]#

Initializes the SizeSplitter.

Parameters:
  • train_ratio (float) – The ratio of data for the training set.

  • val_ratio (float) – The ratio of data for the validation set.

  • test_ratio (float) – The ratio of data for the test set.

  • sort_order (str) – ‘ascending’ or ‘descending’.

  • save_path (str | None, optional) – If provided, saves split indices as pickle file.

class chemtorch.components.data_pipeline.data_splitter.reaction_core_splitter.ReactionCoreSplitter(train_ratio: float = 0.8, val_ratio: float = 0.1, test_ratio: float = 0.1, include_chirality: bool = False, save_path: str | None = None)[source]#

Bases: SMILESGroupSplitterBase

__init__(train_ratio: float = 0.8, val_ratio: float = 0.1, test_ratio: float = 0.1, include_chirality: bool = False, save_path: str | None = None)[source]#

Initializes the ReactionCoreSplitter.

Splits data by grouping reactions based on their reaction core/template, ensuring that all reactions with the same reaction core are in the same split (train, val, or test). This is a method to test a model’s ability to generalize to new reaction types.

Parameters:
  • train_ratio (float) – The desired ratio of data for the training set.

  • val_ratio (float) – The desired ratio of data for the validation set.

  • test_ratio (float) – The desired ratio of data for the test set.

  • include_chirality (bool) – If True, includes chirality in the reaction core generation.

  • save_path (str | None, optional) – If provided, saves split indices as pickle file.

class chemtorch.components.data_pipeline.data_splitter.index_splitter.IndexSplitter(split_index_path: str, save_path: str | None = None)[source]#

Bases: DataSplitterBase

__init__(split_index_path: str, save_path: str | None = None)[source]#

Initializes the IndexSplitter with the specified index path.

Parameters:
  • split_index_path (str) – The path to the pickle file containing the split indices.

  • save_path (str | None, optional) – If provided, saves split indices as pickle file.

Pipeline#

class chemtorch.components.data_pipeline.simple_data_pipeline.SimpleDataPipeline(data_source: AbstractDataSource, column_mapper: AbstractColumnMapper, data_splitter: AbstractDataSplitter | None = None)[source]#

Bases: object

A simple data pipeline that orchestrates data loading, column mapping, and data splitting.

The ingestion process is as follows:

  1. Load data using the data_source. This can result in a single DataFrame or an already split DataSplit object.

  2. Apply column transformations (filtering, renaming) using the column_mapper. This mapper can operate on both single DataFrames and DataSplit objects.

  3. If the data after mapping is a single DataFrame, split it using the data_splitter. If it’s already a DataSplit, this step is skipped.
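The ingestion steps above can be sketched as a thin orchestrator. The callables and the tuple-based "already split" check are stand-ins for illustration; the real pipeline operates on pandas DataFrames and DataSplit objects:

```python
def run_pipeline(data_source, column_mapper, data_splitter=None,
                 is_split=lambda d: isinstance(d, tuple)):
    """Load -> map columns -> split only if the data is still a single table."""
    data = data_source()           # step 1: load
    data = column_mapper(data)     # step 2: filter/rename columns
    if is_split(data):             # already split: skip the splitter
        return data
    if data_splitter is None:
        raise ValueError("Unsplit data requires a data_splitter.")
    return data_splitter(data)     # step 3: split a single table
```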

__init__(data_source: AbstractDataSource, column_mapper: AbstractColumnMapper, data_splitter: AbstractDataSplitter | None = None)[source]#

Initializes the SimpleDataPipeline.

Parameters:
  • data_source (AbstractDataSource) – The component responsible for loading the initial data.

  • column_mapper (AbstractColumnMapper) – The component for transforming columns. It should handle both pd.DataFrame and DataSplit inputs.

  • data_splitter (Optional[AbstractDataSplitter]) – The component for splitting a single DataFrame into train, validation, and test sets. This is not used if data_source already provides split data.

Representation#

Base Classes#

class chemtorch.components.representation.abstract_representation.AbstractRepresentation[source]#

Bases: ABC, Generic[T]

Abstract base class for all chemistry representation creators.

All representations in ChemTorch must subclass this class and implement the construct method. This class can also be used for type hinting where AbstractRepresentation[T] means that the __call__ and construct methods will return an object of type T.

The representation should be stateless - it should not hold any mutable state and the same input should always produce the same output.

Raises:

TypeError – If the subclass does not implement the construct method.

Example (correct usage):
>>> class MyRepresentation(AbstractRepresentation[torch.Tensor]):
...     def construct(self, smiles: str) -> torch.Tensor:
...         # Convert SMILES to tensor representation
...         return torch.tensor([1, 2, 3])
>>> r = MyRepresentation()
>>> r("CCO")  # ethanol
tensor([1, 2, 3])
Example (incorrect usage, raises TypeError):
>>> class BadRepresentation(AbstractRepresentation[torch.Tensor]):
...     pass
>>> r = BadRepresentation()
Traceback (most recent call last):
    ...
TypeError: Can't instantiate abstract class BadRepresentation with abstract method construct
abstract construct(smiles: str, **kwargs) T[source]#

Construct a representation from a SMILES string.

Parameters:
  • smiles (str) – A SMILES string representing a molecule or reaction. For reactions, typically in the format “reactants>reagents>products” or “reactants>>products” (e.g., “CCO>>CC=O”). For molecules, a standard SMILES string (e.g., “CCO”).

  • **kwargs – Additional keyword arguments that may be required by specific representation implementations (e.g., reaction_dir for Reaction3DGraph).

Returns:

The constructed representation of the specified type.

Common types include torch.Tensor for token representations, torch_geometric.data.Data for graph representations, etc.

Return type:

T

Raises:
  • ValueError – If the SMILES string is invalid or cannot be processed.

  • RuntimeError – If representation construction fails for any other reason.

Graph Representations#

class chemtorch.components.representation.graph.cgr.CGR(atom_featurizer: FeaturizerBase[Atom] | FeaturizerCompose, bond_featurizer: FeaturizerBase[Bond] | FeaturizerCompose, **kwargs)[source]#

Bases: AbstractRepresentation[Data]

Stateless class for constructing Condensed Graph of Reaction (CGR) representations.

This class does not hold any data itself. Instead, it provides a construct() method that takes a reaction SMILES string and returns a PyTorch Geometric Data object representing the reaction as a graph.

Usage:
>>> from chemtorch.components.representation.graph.featurizer import FeaturizerCompose
>>> from chemtorch.components.representation.graph.featurizer.atom_featurizer import (
...     AtomicNumber, AtomDegree, AtomFormalCharge, AtomIsAromatic
... )
>>> from chemtorch.components.representation.graph.featurizer.bond_featurizer import (
...     BondType, BondIsInRing
... )
>>>
>>> atom_featurizer = FeaturizerCompose([
...     AtomicNumber(),
...     AtomDegree(),
...     AtomFormalCharge(),
...     AtomIsAromatic(),
... ])
>>> bond_featurizer = FeaturizerCompose([
...     BondType(),
...     BondIsInRing(),
... ])
>>>
>>> cgr = CGR(atom_featurizer, bond_featurizer)
>>> data = cgr.construct("CC>>CCO")  # reaction SMILES
>>> data = cgr("CC>>CCO")  # equivalent to above line
__init__(atom_featurizer: FeaturizerBase[Atom] | FeaturizerCompose, bond_featurizer: FeaturizerBase[Bond] | FeaturizerCompose, **kwargs)[source]#

Initialize the CGR representation with atom and bond featurizers.

Parameters:
  • atom_featurizer (FeaturizerBase[Atom] | FeaturizerCompose) – A featurizer for atom features, which can be a single featurizer or a composed one.

  • bond_featurizer (FeaturizerBase[Bond] | FeaturizerCompose) – A featurizer for bond features, which can also be a single featurizer or a composed one.

construct(smiles: str) Data[source]#

Construct a CGR graph from a reaction SMILES string.

chemtorch.components.representation.graph.reaction_3d_graph.read_xyz(file_path: str) Tuple[List[str], Tensor][source]#

Reads a standard XYZ file and returns atomic symbols and coordinates.

Parameters:

file_path (str) – The path to the .xyz file.

Returns:

A tuple containing a list of atomic symbols (str) and a tensor of atomic coordinates ([num_atoms, 3]).

chemtorch.components.representation.graph.reaction_3d_graph.symbols_to_atomic_numbers(symbols: List[str]) Tensor[source]#

Converts a list of atomic symbols (e.g., [‘C’, ‘H’]) to a tensor of atomic numbers.

Parameters:

symbols (List[str]) – A list of atomic symbols.

Returns:

A tensor of atomic numbers.
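The two helpers above can be sketched in plain Python. The symbol table here is deliberately tiny and the functions return plain lists; the real helpers cover the full periodic table and return torch Tensors:

```python
# Minimal symbol -> atomic number table, for illustration only.
ATOMIC_NUMBERS = {"H": 1, "C": 6, "N": 7, "O": 8}


def parse_xyz(text: str):
    """Parse XYZ-format text: atom count, comment line, then 'symbol x y z' rows."""
    lines = text.strip().splitlines()
    n_atoms = int(lines[0])
    symbols, coords = [], []
    for line in lines[2:2 + n_atoms]:
        sym, x, y, z = line.split()
        symbols.append(sym)
        coords.append((float(x), float(y), float(z)))
    return symbols, coords


def symbols_to_numbers(symbols):
    """Map atomic symbols to atomic numbers."""
    return [ATOMIC_NUMBERS[s] for s in symbols]
```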

class chemtorch.components.representation.graph.reaction_3d_graph.Reaction3DData(z_r: Tensor, pos_r: Tensor, z_ts: Tensor, pos_ts: Tensor, smiles: str | None = None, num_nodes: int | None = None, **kwargs)[source]#

Bases: Data

Custom PyG Data class for reaction 3D graphs.

z_r#

Atomic numbers for the reactant.

Type:

torch.Tensor

pos_r#

Atomic positions for the reactant.

Type:

torch.Tensor

z_ts#

Atomic numbers for the transition state.

Type:

torch.Tensor

pos_ts#

Atomic positions for the transition state.

Type:

torch.Tensor

__init__(z_r: Tensor, pos_r: Tensor, z_ts: Tensor, pos_ts: Tensor, smiles: str | None = None, num_nodes: int | None = None, **kwargs)[source]#
z_r: Tensor#
pos_r: Tensor#
z_ts: Tensor#
pos_ts: Tensor#
class chemtorch.components.representation.graph.reaction_3d_graph.Reaction3DGraph(root_dir: str)[source]#

Bases: AbstractRepresentation[Data]

Constructs a 3D representation of a reaction from XYZ files.

This representation reads the 3D structures for a reactant (r) and transition state (ts) from their respective .xyz files and packages them into a single PyTorch Geometric Data object (specifically, a Reaction3DData object).

The resulting Reaction3DData object for each sample contains:
  • z_r, pos_r: Atomic numbers and 3D coordinates for the reactant

  • z_ts, pos_ts: Atomic numbers and 3D coordinates for the transition state

  • smiles: The reaction SMILES string (for reference)

  • num_nodes: Total number of atoms in the structure

This representation is particularly useful for models like DimeReaction that operate on 3D geometries of chemical reactions.

__init__(root_dir: str)[source]#
Parameters:

root_dir (str) – The root directory where reaction subfolders (e.g., ‘reaction_1’, ‘reaction_2’) are located.

construct(smiles: str, reaction_dir: str) Data[source]#

Constructs a single reaction graph from its corresponding XYZ files.

This method is called by DatasetBase for each row in the DataFrame and reads the reactant and transition state XYZ files from the specified reaction directory.

Parameters:
  • smiles (str) – The reaction SMILES string.

  • reaction_dir (str) – The name/ID of the subdirectory within root_dir containing the XYZ files for this reaction (e.g., ‘1’, ‘42’). Will be zero-padded to 6 digits (e.g., ‘000001’, ‘000042’).

Returns:

A Reaction3DData object containing the 3D structures, with attributes:

  • z_r, pos_r: Atomic numbers and positions for reactant

  • z_ts, pos_ts: Atomic numbers and positions for transition state

  • smiles: The reaction SMILES

  • num_nodes: Number of atoms

Return type:

Reaction3DData

Raises:
  • FileNotFoundError – If the reaction directory or any required .xyz file is not found.

  • ValueError – If the number of atoms is inconsistent between reactant and TS structures.
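The directory lookup described above (zero-padding the reaction ID to six digits) can be sketched as follows. The file names "r.xyz" and "ts.xyz" are assumptions for illustration; the actual layout inside each reaction folder may differ:

```python
import os


def reaction_xyz_paths(root_dir: str, reaction_dir: str):
    """Zero-pad the reaction ID to six digits and build the expected XYZ paths."""
    padded = f"{int(reaction_dir):06d}"   # e.g. '42' -> '000042'
    folder = os.path.join(root_dir, padded)
    # Hypothetical file names for reactant and transition state.
    return os.path.join(folder, "r.xyz"), os.path.join(folder, "ts.xyz")
```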

Fingerprint Representations#

class chemtorch.components.representation.fingerprint.drfp.DRFP(n_folded_length: int = 2048, min_radius: int = 0, radius: int = 3, rings: bool = True, root_central_atom: bool = True, include_hydrogens: bool = False)[source]#

Bases: AbstractRepresentation[Tensor]

Stateless class for constructing DRFP (Differential Reaction Fingerprints).

This class provides a construct() method that takes a reaction SMILES string and returns a PyTorch Tensor representing the folded DRFP fingerprint. It utilizes the DRFPUtil class for the core fingerprint generation logic.

__init__(n_folded_length: int = 2048, min_radius: int = 0, radius: int = 3, rings: bool = True, root_central_atom: bool = True, include_hydrogens: bool = False)[source]#

Initializes the DRFP representation creator.

Parameters:
  • n_folded_length (int, optional) – The length of the folded fingerprint. Default is 2048.

  • min_radius (int, optional) – The minimum radius for substructure extraction. Default is 0 (includes single atoms).

  • radius (int, optional) – The maximum radius for substructure extraction. Default is 3 (corresponds to DRFP6).

  • rings (bool, optional) – Whether to include full rings as substructures. Default is True.

  • root_central_atom (bool, optional) – Whether to root the central atom of substructures when generating SMILES. Default is True.

  • include_hydrogens (bool, optional) – Whether to include hydrogens in the molecular representation. Default is False.

construct(smiles: str) Tensor[source]#

Generates a DRFP fingerprint for a single reaction SMILES string.

The method uses DRFPUtil.internal_encode to obtain the hashed difference features of the reaction and then DRFPUtil.fold to create the final binary fingerprint vector as a NumPy array, which is then converted to a PyTorch Tensor.

Parameters:

smiles – The reaction SMILES string (e.g., “R1.R2>A>P1.P2”).

Returns:

A PyTorch Tensor (dtype=torch.float32) representing the folded DRFP fingerprint.

Raises:
  • NoReactionError – If the input SMILES is not a valid reaction SMILES (as detected by DRFPUtil.internal_encode).

  • RuntimeError – For other errors encountered during fingerprint generation, wrapping the original exception.

class chemtorch.components.representation.fingerprint.drfp.DRFPUtil[source]#

Bases: object

A utility class for encoding SMILES as drfp fingerprints.

static shingling_from_mol(in_mol: Mol, radius: int = 3, rings: bool = True, min_radius: int = 0, get_atom_indices: bool = False, root_central_atom: bool = True, include_hydrogens: bool = False) List[str] | Tuple[List[str], Dict[str, List[Set[int]]]][source]#

Creates a molecular shingling from a RDKit molecule (rdkit.Chem.rdchem.Mol).

Parameters:
  • in_mol – A RDKit molecule instance

  • radius – The drfp radius (a radius of 3 corresponds to drfp6)

  • rings – Whether or not to include rings in the shingling

  • min_radius – The minimum radius that is used to extract n-grams

Returns:

The molecular shingling.

static internal_encode(in_smiles: str, radius: int = 3, min_radius: int = 0, rings: bool = True, get_atom_indices: bool = False, root_central_atom: bool = True, include_hydrogens: bool = False) Tuple[ndarray, ndarray] | Tuple[ndarray, ndarray, Dict[str, List[Dict[str, List[Set[int]]]]]][source]#

Creates a drfp array from a reaction SMILES string.

Parameters:
  • in_smiles – A valid reaction SMILES string

  • radius – The drfp radius (a radius of 3 corresponds to drfp6)

  • min_radius – The minimum radius that is used to extract n-grams

  • rings – Whether or not to include rings in the shingling

Returns:

A tuple with two arrays, the first containing the drfp hash values, the second the substructure SMILES

static hash(shingling: List[str]) ndarray[source]#

Directly hashes all the SMILES in a shingling to 32-bit integers.

Parameters:

shingling – A list of n-grams

Returns:

An array of hashed n-grams

static fold(hash_values: ndarray, length: int = 2048) Tuple[ndarray, ndarray][source]#

Folds the hash values to a binary vector of a given length.

Parameters:
  • hash_values – An array containing the hash values

  • length – The length of the folded fingerprint

Returns:

A tuple containing the folded fingerprint and the indices of the on bits
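Folding by modulo hashing can be sketched with NumPy; this illustrates the idea behind fold, not DRFPUtil's exact code:

```python
import numpy as np


def fold_hashes(hash_values: np.ndarray, length: int = 2048):
    """Map 32-bit hash values into a binary vector via modulo indexing."""
    folded = np.zeros(length, dtype=np.uint8)
    on_bits = hash_values % length   # colliding hashes map to the same bit
    folded[on_bits] = 1
    return folded, on_bits
```

Note that distinct substructures whose hashes agree modulo `length` collide onto the same bit, which is why longer fingerprints trade memory for fewer collisions.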

static encode(X: Iterable | str, n_folded_length: int = 2048, min_radius: int = 0, radius: int = 3, rings: bool = True, mapping: bool = False, atom_index_mapping: bool = False, root_central_atom: bool = True, include_hydrogens: bool = False, show_progress_bar: bool = False) List[ndarray] | Tuple[List[ndarray], Dict[int, Set[str]]] | List[Dict[str, List[Dict[str, List[Set[int]]]]]][source]#

Encodes a list of reaction SMILES using the drfp fingerprint.

Parameters:
  • X – An iterable (e.g. List) of reaction SMILES or a single reaction SMILES to be encoded

  • n_folded_length – The folded length of the fingerprint (the parameter for the modulo hashing)

  • min_radius – The minimum radius of a substructure (0 includes single atoms)

  • radius – The maximum radius of a substructure

  • rings – Whether to include full rings as substructures

  • mapping – Return a feature to substructure mapping in addition to the fingerprints

  • atom_index_mapping – Return the atom indices of mapped substructures for each reaction

  • root_central_atom – Whether to root the central atom of substructures when generating SMILES

  • include_hydrogens – Whether to include hydrogens in the molecular representation

  • show_progress_bar – Whether to show a progress bar when encoding reactions

Returns:

A list of drfp fingerprints or, if mapping is enabled, a tuple containing a list of drfp fingerprints and a mapping dict.

Transform#

Base Classes#

class chemtorch.components.transform.abstract_transform.AbstractTransform[source]#

Bases: ABC, Generic[T]

Abstract base class for transforms in the chemtorch framework. This class serves as a base for creating transforms that operate on single data objects.

Raises:

TypeError – If the subclass does not implement the __call__ method.

Example (correct usage):
>>> class MyTransform(AbstractTransform[int]):
...     def __call__(self, obj: int) -> int:
...         return obj * 2
>>> t = MyTransform()
>>> t(3)
6
Example (incorrect usage, raises TypeError):
>>> class BadTransform(AbstractTransform[int]):
...     pass
>>> t = BadTransform()
Traceback (most recent call last):
    ...
TypeError: Can't instantiate abstract class BadTransform with abstract method __call__

Graph Transforms#

Augmentation#

class chemtorch.components.augmentation.abstract_augmentation.AbstractAugmentation[source]#

Bases: ABC, Generic[T]

Abstract base class for data augmentations in the chemtorch framework.

This class serves as a base for creating data augmentations that operate on data points. A data point (or sample) consists of an object of type T returned by the representation and an optional label tensor.

Subclasses must implement the __call__ method to define the augmentation logic.

Raises:
  • TypeError – If the subclass does not implement the __call__ method.

  • RuntimeError – If the input object has a label, but not all augmented objects have labels, or vice versa.
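A subclass sketch showing the label-consistency rule above (either every augmented sample carries a label or none does). The class and its behaviour are illustrative, not part of chemtorch:

```python
class DuplicateAugmentation:
    """Toy augmentation: returns the sample n times, preserving the label."""

    def __init__(self, n: int = 2):
        self.n = n

    def __call__(self, obj, label=None):
        out = [(obj, label) for _ in range(self.n)]
        # Enforce the documented rule: label presence must be uniform.
        has_label = [lbl is not None for _, lbl in out]
        if any(has_label) and not all(has_label):
            raise RuntimeError("All augmented samples must carry a label, or none.")
        return out
```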

Model#

Graph Neural Networks#

class chemtorch.components.model.gnn.gnn.GNN(encoder: Callable[[Batch], Batch], layer_stack: Callable[[Batch], Batch], pool: Callable[[Batch], Tensor], head: Callable[[Tensor], Tensor])[source]#

Bases: Module

__init__(encoder: Callable[[Batch], Batch], layer_stack: Callable[[Batch], Batch], pool: Callable[[Batch], Tensor], head: Callable[[Tensor], Tensor])[source]#
Parameters:
  • encoder (Callable[[Batch], Batch]) – The encoder function that processes the input batch.

  • layer_stack (Callable[[Batch], Batch]) – The GNN layer stack that processes the batch.

  • pool (Callable[[Batch], torch.Tensor]) – The pooling function that converts the graphs to a tensor.

  • head (Callable[[torch.Tensor], torch.Tensor]) – The head function for final prediction.

forward(batch: Batch) Tensor[source]#

Forward pass through the GNN model.

Parameters:

batch (Batch) – The input batch of graphs.

Returns:

The output predictions.

Return type:

torch.Tensor
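The forward pass is a straight composition of the four callables. With plain functions standing in for the torch modules and lists standing in for PyG Batch objects (a sketch of the data flow, not the actual class):

```python
def gnn_forward(batch, encoder, layer_stack, pool, head):
    """encoder -> layer_stack -> pool -> head, mirroring GNN.forward."""
    batch = encoder(batch)        # embed raw node features
    batch = layer_stack(batch)    # message passing layers
    graph_repr = pool(batch)      # graphs -> fixed-size representation
    return head(graph_repr)       # final prediction


# Toy stand-ins operating on lists of numbers instead of PyG Batch objects.
encoder = lambda b: [x + 1 for x in b]
layer_stack = lambda b: [x * 2 for x in b]
pool = lambda b: sum(b)
head = lambda v: v - 1
```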

Encoder#

Pooling#

Other Models#

class chemtorch.components.model.mlp.MLP(in_channels: int, out_channels: int, hidden_dims: List[int] | None = None, hidden_size: int | None = None, num_hidden_layers: int | None = None, dropout: float = 0.0, act: str | Callable | None = 'relu', act_kwargs: Dict[str, Any] | None = None, norm: str | Callable | None = None, norm_kwargs: Dict[str, Any] | None = None)[source]#

Bases: Module

Multi-Layer Perceptron (MLP) with configurable layers and activation functions.

__init__(in_channels: int, out_channels: int, hidden_dims: List[int] | None = None, hidden_size: int | None = None, num_hidden_layers: int | None = None, dropout: float = 0.0, act: str | Callable | None = 'relu', act_kwargs: Dict[str, Any] | None = None, norm: str | Callable | None = None, norm_kwargs: Dict[str, Any] | None = None)[source]#

Initializes an MLP. The architecture is built to match FFNHead’s structure and instantiation order precisely.

  • A Dropout layer is placed before every Linear layer.

  • Activation is applied after every hidden Linear layer.

  • Normalization is applied after every hidden Linear layer (if specified).

Parameters:
  • in_channels (int) – Input dimension.

  • out_channels (int) – Output dimension.

  • hidden_dims (List[int], optional) – List of hidden layer dimensions (preferred).

  • hidden_size (int, optional) – Hidden layer size (used if hidden_dims is not provided).

  • num_hidden_layers (int, optional) – Number of hidden layers (used if hidden_dims is not provided).

  • dropout (float, optional) – Dropout rate. Defaults to 0.

  • act (str or Callable, optional) – Activation function. Defaults to “relu”.

  • act_kwargs (Dict[str, Any], optional) – Arguments for the activation function.

  • norm (str or Callable, optional) – Normalization layer. Defaults to None.

  • norm_kwargs (Dict[str, Any], optional) – Arguments for the normalization layer.

forward(x: Tensor) Tensor[source]#

Forward pass of the MLP.
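The layer ordering described above (a Dropout before every Linear, activation and optional normalization after each hidden Linear) can be sketched symbolically. This builds the sequence of layer names rather than real torch modules, so the ordering is easy to inspect:

```python
def mlp_layer_plan(in_channels, hidden_dims, out_channels, dropout=0.0, norm=None):
    """List the layer types in the order the MLP docstring describes."""
    plan, dims = [], [in_channels] + list(hidden_dims)
    for i in range(len(hidden_dims)):
        if dropout > 0:
            plan.append("Dropout")                    # before every Linear
        plan.append(f"Linear({dims[i]}->{dims[i + 1]})")
        plan.append("Activation")                     # after every hidden Linear
        if norm:
            plan.append(norm)                         # after every hidden Linear
    if dropout > 0:
        plan.append("Dropout")                        # before the output Linear too
    plan.append(f"Linear({dims[-1]}->{out_channels})")
    return plan
```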

class chemtorch.components.model.dimenet.Envelope(exponent: int)[source]#

Bases: Module

__init__(exponent: int)[source]#

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x: Tensor) Tensor[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
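For context on this otherwise undocumented module: DimeNet-style models typically use a polynomial envelope u(x) = 1/x + a·x^(p-1) + b·x^p + c·x^(p+1) (with x = distance / cutoff) whose value and low-order derivatives vanish at x = 1, so basis functions decay smoothly to zero at the cutoff. The sketch below follows that standard form, including the common convention p = exponent + 1; the details of this particular Envelope implementation may differ:

```python
def envelope(x: float, exponent: int = 5) -> float:
    """Smooth polynomial cutoff: positive for 0 < x < 1, exactly 0 for x >= 1."""
    p = exponent + 1  # assumed convention, as in common DimeNet implementations
    if x >= 1.0:
        return 0.0
    # Coefficients chosen so u(1) = u'(1) = u''(1) = 0.
    a = -(p + 1) * (p + 2) / 2
    b = p * (p + 2)
    c = -p * (p + 1) / 2
    return 1.0 / x + a * x ** (p - 1) + b * x ** p + c * x ** (p + 1)
```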

class chemtorch.components.model.dimenet.BesselBasisLayer(num_radial: int, cutoff: float = 5.0, envelope_exponent: int = 5)[source]#

Bases: Module

__init__(num_radial: int, cutoff: float = 5.0, envelope_exponent: int = 5)[source]#


reset_parameters()[source]#
forward(dist: Tensor) Tensor[source]#


class chemtorch.components.model.dimenet.SphericalBasisLayer(num_spherical: int, num_radial: int, cutoff: float = 5.0, envelope_exponent: int = 5)[source]#

Bases: Module

__init__(num_spherical: int, num_radial: int, cutoff: float = 5.0, envelope_exponent: int = 5)[source]#


forward(dist: Tensor, angle: Tensor, idx_kj: Tensor) Tensor[source]#


class chemtorch.components.model.dimenet.EmbeddingBlock(num_radial: int, hidden_channels: int, act: Callable)[source]#

Bases: Module

__init__(num_radial: int, hidden_channels: int, act: Callable)[source]#


reset_parameters()[source]#
forward(x: Tensor, rbf: Tensor, i: Tensor, j: Tensor) Tensor[source]#


class chemtorch.components.model.dimenet.ResidualLayer(hidden_channels: int, act: Callable)[source]#

Bases: Module

__init__(hidden_channels: int, act: Callable)[source]#


reset_parameters()[source]#
forward(x: Tensor) Tensor[source]#


class chemtorch.components.model.dimenet.InteractionBlock(hidden_channels: int, num_bilinear: int, num_spherical: int, num_radial: int, num_before_skip: int, num_after_skip: int, act: Callable)[source]#

Bases: Module

__init__(hidden_channels: int, num_bilinear: int, num_spherical: int, num_radial: int, num_before_skip: int, num_after_skip: int, act: Callable)[source]#


reset_parameters()[source]#
forward(x: Tensor, rbf: Tensor, sbf: Tensor, idx_kj: Tensor, idx_ji: Tensor) Tensor[source]#


class chemtorch.components.model.dimenet.InteractionPPBlock(hidden_channels: int, int_emb_size: int, basis_emb_size: int, num_spherical: int, num_radial: int, num_before_skip: int, num_after_skip: int, act: Callable)[source]#

Bases: Module

__init__(hidden_channels: int, int_emb_size: int, basis_emb_size: int, num_spherical: int, num_radial: int, num_before_skip: int, num_after_skip: int, act: Callable)[source]#


reset_parameters()[source]#
forward(x: Tensor, rbf: Tensor, sbf: Tensor, idx_kj: Tensor, idx_ji: Tensor) Tensor[source]#


class chemtorch.components.model.dimenet.OutputBlock(num_radial: int, hidden_channels: int, out_channels: int, num_layers: int, act: Callable, output_initializer: str = 'zeros')[source]#

Bases: Module

__init__(num_radial: int, hidden_channels: int, out_channels: int, num_layers: int, act: Callable, output_initializer: str = 'zeros')[source]#


reset_parameters()[source]#
forward(x: Tensor, rbf: Tensor, i: Tensor, num_nodes: int | None = None) Tensor[source]#


class chemtorch.components.model.dimenet.OutputPPBlock(num_radial: int, hidden_channels: int, out_emb_channels: int, out_channels: int, num_layers: int, act: Callable, output_initializer: str = 'zeros')[source]#

Bases: Module

__init__(num_radial: int, hidden_channels: int, out_emb_channels: int, out_channels: int, num_layers: int, act: Callable, output_initializer: str = 'zeros')[source]#


reset_parameters()[source]#
forward(x: Tensor, rbf: Tensor, i: Tensor, num_nodes: int | None = None) Tensor[source]#


chemtorch.components.model.dimenet.triplets(edge_index: Tensor, num_nodes: int) Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor][source]#
class chemtorch.components.model.dimenet.DimeNet(hidden_channels: int, out_channels: int, num_blocks: int, num_bilinear: int, num_spherical: int, num_radial: int, head, cutoff: float = 5.0, max_num_neighbors: int = 32, envelope_exponent: int = 5, num_before_skip: int = 1, num_after_skip: int = 2, num_output_layers: int = 3, act: str | Callable = 'swish', output_initializer: str = 'zeros')[source]#

Bases: Module

The directional message passing neural network (DimeNet) from the “Directional Message Passing for Molecular Graphs” paper. DimeNet transforms messages based on the angle between them in a rotation-equivariant fashion.

Note

For an example of using a pretrained DimeNet variant, see examples/qm9_pretrained_dimenet.py.

Parameters:
  • hidden_channels (int) – Hidden embedding size.

  • out_channels (int) – Size of each output sample.

  • num_blocks (int) – Number of building blocks.

  • num_bilinear (int) – Size of the bilinear layer tensor.

  • num_spherical (int) – Number of spherical harmonics.

  • num_radial (int) – Number of radial basis functions.

  • head – The head module applied to the embedded molecular representation to produce the final output.

  • cutoff (float, optional) – Cutoff distance for interatomic interactions. (default: 5.0)

  • max_num_neighbors (int, optional) – The maximum number of neighbors to collect for each node within the cutoff distance. (default: 32)

  • envelope_exponent (int, optional) – Shape of the smooth cutoff. (default: 5)

  • num_before_skip (int, optional) – Number of residual layers in the interaction blocks before the skip connection. (default: 1)

  • num_after_skip (int, optional) – Number of residual layers in the interaction blocks after the skip connection. (default: 2)

  • num_output_layers (int, optional) – Number of linear layers for the output blocks. (default: 3)

  • act (str or Callable, optional) – The activation function. (default: "swish")

  • output_initializer (str, optional) – The initialization method for the output layer ("zeros", "glorot_orthogonal"). (default: "zeros")

url = 'https://github.com/klicperajo/dimenet/raw/master/pretrained/dimenet'#
__init__(hidden_channels: int, out_channels: int, num_blocks: int, num_bilinear: int, num_spherical: int, num_radial: int, head, cutoff: float = 5.0, max_num_neighbors: int = 32, envelope_exponent: int = 5, num_before_skip: int = 1, num_after_skip: int = 2, num_output_layers: int = 3, act: str | Callable = 'swish', output_initializer: str = 'zeros')[source]#


reset_parameters()[source]#

Resets all learnable parameters of the module.

embed(z: Tensor, pos: Tensor, batch_indices: Tensor) Tensor[source]#

Embed node features and perform multiple interaction and output blocks.

Parameters:
  • z (Tensor) – Atomic numbers of shape [num_atoms].

  • pos (Tensor) – Atom positions of shape [num_atoms, 3].

  • batch_indices (Tensor) – Batch indices assigning each atom to a separate molecule of shape [num_atoms].

Returns:

The output node features of shape [num_molecules, out_channels].

Return type:

Tensor
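A minimal shape sketch of the inputs embed expects, using plain tensors for a hypothetical two-molecule batch (a water molecule and H2); model below stands for an already-constructed DimeNet instance:

```python
import torch

# Hypothetical batch of two molecules: H2O (atoms 0-2) and H2 (atoms 3-4).
z = torch.tensor([8, 1, 1, 1, 1])              # atomic numbers, shape [num_atoms]
pos = torch.randn(5, 3)                        # 3D coordinates, shape [num_atoms, 3]
batch_indices = torch.tensor([0, 0, 0, 1, 1])  # molecule assignment per atom

num_molecules = int(batch_indices.max()) + 1
# out = model.embed(z, pos, batch_indices)     # -> shape [num_molecules, out_channels]
```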

forward(batch) Tensor[source]#

Forward pass.

Parameters:

batch (Data) –

A batch of torch_geometric.data.Data objects holding multiple molecular graphs. Must contain the following attributes:

  • z (torch.Tensor) – Atomic number of each atom with shape [num_atoms].

  • pos (torch.Tensor) – Coordinates of each atom with shape [num_atoms, 3].

  • batch (torch.Tensor, optional) – Batch indices assigning each atom to a separate molecule with shape [num_atoms]. (default: None)

class chemtorch.components.model.dimenet.DimeNetPlusPlus(hidden_channels: int, out_channels: int, num_blocks: int, int_emb_size: int, basis_emb_size: int, out_emb_channels: int, num_spherical: int, num_radial: int, head, cutoff: float = 5.0, max_num_neighbors: int = 32, envelope_exponent: int = 5, num_before_skip: int = 1, num_after_skip: int = 2, num_output_layers: int = 3, act: str | Callable = 'swish', output_initializer: str = 'zeros')[source]#

Bases: DimeNet

The DimeNet++ from the “Fast and Uncertainty-Aware Directional Message Passing for Non-Equilibrium Molecules” paper.

DimeNetPlusPlus is an upgrade to the DimeNet model that is 8x faster and 10% more accurate than DimeNet.

Parameters:
  • hidden_channels (int) – Hidden embedding size.

  • out_channels (int) – Size of each output sample.

  • num_blocks (int) – Number of building blocks.

  • int_emb_size (int) – Size of embedding in the interaction block.

  • basis_emb_size (int) – Size of basis embedding in the interaction block.

  • out_emb_channels (int) – Size of embedding in the output block.

  • num_spherical (int) – Number of spherical harmonics.

  • num_radial (int) – Number of radial basis functions.

  • head – The head module applied to the embedded molecular representation to produce the final output.

  • cutoff (float, optional) – Cutoff distance for interatomic interactions. (default: 5.0)

  • max_num_neighbors (int, optional) – The maximum number of neighbors to collect for each node within the cutoff distance. (default: 32)

  • envelope_exponent (int, optional) – Shape of the smooth cutoff. (default: 5)

  • num_before_skip (int, optional) – Number of residual layers in the interaction blocks before the skip connection. (default: 1)

  • num_after_skip (int, optional) – Number of residual layers in the interaction blocks after the skip connection. (default: 2)

  • num_output_layers (int, optional) – Number of linear layers for the output blocks. (default: 3)

  • act (str or Callable, optional) – The activation function. (default: "swish")

  • output_initializer (str, optional) – The initialization method for the output layer ("zeros", "glorot_orthogonal"). (default: "zeros")

url = 'https://raw.githubusercontent.com/gasteigerjo/dimenet/master/pretrained/dimenet_pp'#
__init__(hidden_channels: int, out_channels: int, num_blocks: int, int_emb_size: int, basis_emb_size: int, out_emb_channels: int, num_spherical: int, num_radial: int, head, cutoff: float = 5.0, max_num_neighbors: int = 32, envelope_exponent: int = 5, num_before_skip: int = 1, num_after_skip: int = 2, num_output_layers: int = 3, act: str | Callable = 'swish', output_initializer: str = 'zeros')[source]#


class chemtorch.components.model.dimenet.DimeReaction(hidden_channels: int, out_channels: int, num_blocks: int, int_emb_size: int, basis_emb_size: int, out_emb_channels: int, num_spherical: int, num_radial: int, head, cutoff: float = 5.0, max_num_neighbors: int = 32, envelope_exponent: int = 5, num_before_skip: int = 1, num_after_skip: int = 2, num_output_layers: int = 3, act: str | Callable = 'swish', output_initializer: str = 'zeros')[source]#

Bases: DimeNetPlusPlus

DimeReaction model used by Spiekerman et al. in https://pubs.acs.org/doi/10.1021/acs.jpca.2c02614.

forward(batch: Batch) Tensor[source]#

Forward pass.

Parameters:

batch (Batch) –

A batch of torch_geometric.data.Data objects holding multiple molecular graphs. Must contain the following attributes:

  • z_r (torch.Tensor) – Atomic number of each atom in the reactant with shape [num_atoms].

  • pos_r (torch.Tensor) – Coordinates of each atom in the reactant with shape [num_atoms, 3].

  • z_ts (torch.Tensor) – Atomic number of each atom in the transition state with shape [num_atoms].

  • pos_ts (torch.Tensor) – Coordinates of each atom in the transition state with shape [num_atoms, 3].

  • batch (torch.Tensor, optional) – Batch indices assigning each atom to a separate molecule with shape [num_atoms]. (default: None)

class chemtorch.components.model.han.HAN(embedding_in_channels: int, embedding_hidden_channels: int, gru_hidden_channels: int, class_num: int, dropout=0.2)[source]#

Bases: Module

__init__(embedding_in_channels: int, embedding_hidden_channels: int, gru_hidden_channels: int, class_num: int, dropout=0.2)[source]#


forward(x)[source]#


Layers#

Layer Stack#

class chemtorch.components.layer.layer_stack.LayerStack(layer: DictConfig, depth: int, share_weights: bool = False)[source]#

Bases: Module, Generic[T]

A utility class for stacking a layer multiple times.

This class is useful for creating deep neural networks by stacking the same layer multiple times.

Note that the input and output types of the layer must be the same.

__init__(layer: DictConfig, depth: int, share_weights: bool = False)[source]#

Initialize the Stack using Hydra for instantiation.

Parameters:
  • layer (DictConfig) – The configuration for the layer to be stacked.

  • depth (int) – The number of times to repeat the layer.

  • share_weights (bool) – If True, share weights between the stacked layers.

forward(x: T) T[source]#

Apply the stacked layers to the input sequentially.
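The stacking semantics can be sketched without Hydra; make_stack below is an illustrative helper, not part of chemtorch (the real class instantiates its layer from a DictConfig):

```python
import torch
import torch.nn as nn

# Illustrative sketch of LayerStack semantics: repeat one layer
# `depth` times, optionally sharing weights between the copies.
def make_stack(layer_factory, depth: int, share_weights: bool = False) -> nn.Sequential:
    if share_weights:
        layer = layer_factory()                  # one module, reused at every depth
        return nn.Sequential(*([layer] * depth))
    return nn.Sequential(*(layer_factory() for _ in range(depth)))

stack = make_stack(lambda: nn.Linear(8, 8), depth=3, share_weights=True)
out = stack(torch.randn(4, 8))  # input and output shapes match: [4, 8]
```

With share_weights=True all entries of the stack refer to the same module object, so the parameter count stays that of a single layer.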

GNN Layers#

class chemtorch.components.layer.gnn_layer.dmpnn_stack.EdgeToNodeEmbedding(embedding_size: int, num_node_features: int, aggr: str | Callable = 'add', aggr_kwargs: Dict[str, Any] | None = None, act: str | Callable | None = 'relu', act_kwargs: Dict[str, Any] | None = None)[source]#

Bases: Module

EdgeToNodeEmbedding is a neural network layer that aggregates edge embeddings by their target node index, concatenates the aggregated embeddings with the original node features, and passes the result through a linear layer followed by an activation function to produce new node embeddings.

__init__(embedding_size: int, num_node_features: int, aggr: str | Callable = 'add', aggr_kwargs: Dict[str, Any] | None = None, act: str | Callable | None = 'relu', act_kwargs: Dict[str, Any] | None = None)[source]#

Initialize the EdgeToNodeEmbedding layer.

Parameters:
  • embedding_size (int) – Size of the edge embeddings (also size of the new node embeddings).

  • num_node_features (int) – Number of input features per node.

  • aggr (Union[str, Callable], optional) – Aggregation method. Defaults to “add”.

  • aggr_kwargs (Optional[Dict[str, Any]], optional) – Additional arguments for aggregation. Defaults to None.

  • act (Union[str, Callable, None], optional) – Activation function. Defaults to “relu”.

  • act_kwargs (Optional[Dict[str, Any]], optional) – Additional arguments for activation function. Defaults to None.

forward(batch: Batch) Batch[source]#

Forward pass through the EdgeToNodeEmbedding layer.

Parameters:

batch (Batch) – The input batch of graphs containing node features and edge embeddings.

Returns:

The output batch with updated node features.

Return type:

Batch
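The aggregation this layer performs can be sketched with plain tensor ops; this is an illustrative reconstruction from the description above, not chemtorch's code:

```python
import torch

# Toy graph: 3 nodes, 4 directed edges, 4-dim edge embeddings.
edge_index = torch.tensor([[0, 1, 2, 2],   # source node of each edge
                           [1, 2, 0, 1]])  # target node of each edge
edge_emb = torch.ones(4, 4)
node_feats = torch.randn(3, 2)

# Sum edge embeddings into their target nodes ("add" aggregation),
# then concatenate with the original node features.
agg = torch.zeros(3, 4).index_add_(0, edge_index[1], edge_emb)
h = torch.cat([node_feats, agg], dim=-1)   # shape [3, 2 + 4]
# In the layer, a Linear + activation then maps h to [3, embedding_size].
```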

class chemtorch.components.layer.gnn_layer.dmpnn_stack.DMPNNStack(dmpnn_blocks: LayerStack[Batch], edge_to_node_embedding: EdgeToNodeEmbedding)[source]#

Bases: Module

DMPNNStack is a neural network layer that applies a sequence of directed message-passing steps, followed by an edge-to-node embedding layer that generates node embeddings from the original node features and the edge embeddings produced by the message-passing steps.

__init__(dmpnn_blocks: LayerStack[Batch], edge_to_node_embedding: EdgeToNodeEmbedding)[source]#

Initialize the DMPNNStack.

Parameters:
  • dmpnn_blocks (LayerStack[Batch]) – A stack of DMPNN blocks that perform directed message passing.

  • edge_to_node_embedding (EdgeToNodeEmbedding) – The layer that converts edge features to node features.

forward(batch: Batch) Tensor[source]#

Forward pass through the DMPNN layer.

Parameters:

batch (Batch) – The input batch of graphs.

Returns:

The output predictions.

Return type:

torch.Tensor
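One directed message-passing step of the kind the DMPNN blocks perform can be sketched as follows (assumed semantics from the D-MPNN literature, not chemtorch's implementation): each directed edge u->v gathers the embeddings of the edges entering u, excluding its own reverse edge v->u.

```python
import torch

# Toy path graph 0-1-2 as directed edges; rev[e] is the index of
# edge e's reverse edge.
edge_index = torch.tensor([[0, 1, 1, 2],   # source
                           [1, 0, 2, 1]])  # target
rev = torch.tensor([1, 0, 3, 2])
h = torch.randn(4, 4)                      # directed edge embeddings

src, dst = edge_index
# Sum of edge embeddings arriving at every node...
node_sum = torch.zeros(3, 4).index_add_(0, dst, h)
# ...then, for edge u->v, drop the contribution of the reverse edge v->u.
msg = node_sum[src] - h[rev]
```

For edge 2 (the directed edge 1->2), the message is the sum of embeddings of edges entering node 1 minus the embedding of the reverse edge 2->1, which here leaves only the embedding of edge 0->1.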