Components#
The components package contains all swappable building blocks for the ChemTorch pipeline.
Data Pipeline#
Data Source#
- class chemtorch.components.data_pipeline.data_source.abstract_data_source.AbstractDataSource[source]#
Bases:
ABC
Abstract base class for data sources.
This class defines the interface for loading data from various sources. Subclasses should implement the load method to provide specific data loading functionality.
- class chemtorch.components.data_pipeline.data_source.single_csv_source.SingleCSVSource(data_path: str)[source]#
Bases:
AbstractDataSource
- class chemtorch.components.data_pipeline.data_source.pre_split_csv_source.PreSplitCSVSource(data_folder: str)[source]#
Bases:
AbstractDataSource
Column Mapper#
- class chemtorch.components.data_pipeline.column_mapper.abstract_column_mapper.AbstractColumnMapper[source]#
Bases:
ABC
- class chemtorch.components.data_pipeline.column_mapper.column_filter_rename.ColumnFilterAndRename(**column_mappings)[source]#
Bases:
AbstractColumnMapper
A pipeline component that filters and renames columns in a DataFrame based on provided column mappings.
- Usage:
mapper = ColumnFilterAndRename(smiles="smiles_column", label="target_column")
# Renames "smiles_column" to "smiles" and "target_column" to "label"
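The same filter-and-rename semantics can be sketched with plain dicts (an illustrative stand-in, not the ChemTorch implementation, which operates on a pandas DataFrame):

```python
# Illustrative sketch of ColumnFilterAndRename semantics (hypothetical helper,
# not ChemTorch code): keep only the mapped source columns and rename them.
def filter_and_rename(rows, **column_mappings):
    """column_mappings maps new_name -> old_name, as in ColumnFilterAndRename."""
    return [
        {new: row[old] for new, old in column_mappings.items()}
        for row in rows
    ]

rows = [{"smiles_column": "CCO", "target_column": 0.5, "extra": 1}]
print(filter_and_rename(rows, smiles="smiles_column", label="target_column"))
# [{'smiles': 'CCO', 'label': 0.5}]
```

Columns not named in the mapping (here `extra`) are dropped, which is the filtering half of the component's behavior.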
Data Splitter#
- class chemtorch.components.data_pipeline.data_splitter.abstract_data_splitter.AbstractDataSplitter[source]#
Bases:
ABC
Abstract base class for data splitting strategies.
Subclasses should implement the __call__ method to define the splitting logic.
- class chemtorch.components.data_pipeline.data_splitter.data_splitter_base.DataSplitterBase(save_path: str | None = None)[source]#
Bases:
AbstractDataSplitter
Base class for data splitting strategies. A callable that takes a DataFrame, executes the splitting logic, and returns a DataSplit object. Subclasses should implement the private _split method to define specific splitting logic.
- __init__(save_path: str | None = None) None[source]#
Initializes the DataSplitter.
- Parameters:
save_path (str | None, optional) – If provided, saves split indices as a pickle file to this path. Must end with ‘.pkl’. Defaults to None.
- Raises:
ValueError – If the provided save_path is not a pickle file (i.e. does not end with ‘.pkl’).
- class chemtorch.components.data_pipeline.data_splitter.ratio_splitter.RatioSplitter(train_ratio: float = 0.8, val_ratio: float = 0.1, test_ratio: float = 0.1, save_path: str | None = None)[source]#
Bases:
DataSplitterBase
Splits data into training, validation, and test sets based on specified ratios.
By default, RatioSplitter randomly shuffles the data and splits it according to the specified ratios. Subclasses may override the _split method to implement custom splitting logic.
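The default behavior amounts to shuffling row indices and cutting the shuffled list at the ratio boundaries. A minimal sketch of that logic, operating on plain index lists rather than a DataFrame (`ratio_split` is a hypothetical helper, not part of ChemTorch):

```python
import random

def ratio_split(n, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, seed=0):
    """Shuffle indices 0..n-1 and cut them at the ratio boundaries."""
    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-9
    idx = list(range(n))
    random.Random(seed).shuffle(idx)  # seeded for reproducibility
    n_train = int(n * train_ratio)
    n_val = int(n * val_ratio)
    return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]

train, val, test = ratio_split(100)
print(len(train), len(val), len(test))  # 80 10 10
```

Every index lands in exactly one split, and the test set absorbs any rounding remainder.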
- class chemtorch.components.data_pipeline.data_splitter.scaffold_splitter.ScaffoldSplitter(train_ratio: float = 0.8, val_ratio: float = 0.1, test_ratio: float = 0.1, include_chirality: bool = False, split_on: str | None = None, mol_idx: int | None = None, save_path: str | None = None)[source]#
Bases:
SMILESGroupSplitterBase
- __init__(train_ratio: float = 0.8, val_ratio: float = 0.1, test_ratio: float = 0.1, include_chirality: bool = False, split_on: str | None = None, mol_idx: int | None = None, save_path: str | None = None)[source]#
Initializes the ScaffoldSplitter.
Splits data by grouping molecules based on their Murcko scaffold, ensuring that all molecules with the same scaffold are in the same split (train, val, or test). This is a standard method to test a model’s ability to generalize to new chemical scaffolds.
- Parameters:
train_ratio (float) – The desired ratio of data for the training set.
val_ratio (float) – The desired ratio of data for the validation set.
test_ratio (float) – The desired ratio of data for the test set.
include_chirality (bool) – If True, includes chirality in the scaffold generation.
split_on (str | None) – Specifies whether to use the ‘reactant’ or ‘product’ for scaffold generation when processing reaction SMILES. Required for reaction SMILES, ignored for single molecules. Defaults to None.
mol_idx (int | None) – Zero-based index specifying which molecule to use if multiple are present (e.g., ‘A.B>>C’ or ‘A.B’). Required when multiple molecules are present in the selected part of the reaction. Defaults to None.
save_path (str | None, optional) – If provided, saves split indices as pickle file.
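Conceptually, scaffold splitting groups row indices by a key (the Murcko scaffold) and assigns whole groups to splits so that no scaffold leaks across them. A sketch with a generic key list, omitting the RDKit scaffold computation; `group_split` and its greedy fill order are illustrative assumptions, not the exact ChemTorch algorithm:

```python
from collections import defaultdict

def group_split(keys, train_ratio=0.8, val_ratio=0.1):
    """Assign whole groups (indices sharing a key) to train/val/test.
    keys[i] is the group key of sample i, e.g. its Murcko scaffold SMILES."""
    groups = defaultdict(list)
    for i, k in enumerate(keys):
        groups[k].append(i)
    # Place larger groups first so the realized ratios stay near the targets.
    ordered = sorted(groups.values(), key=len, reverse=True)
    n = len(keys)
    train, val, test = [], [], []
    for g in ordered:
        if len(train) + len(g) <= train_ratio * n:
            train += g
        elif len(val) + len(g) <= val_ratio * n:
            val += g
        else:
            test += g
    return train, val, test

keys = ["s1"] * 8 + ["s2"] + ["s3"]   # three scaffolds, sizes 8/1/1
train, val, test = group_split(keys)
print(len(train), len(val), len(test))  # 8 1 1
```

Because groups are assigned atomically, the realized split sizes only approximate the requested ratios.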
- class chemtorch.components.data_pipeline.data_splitter.target_splitter.TargetSplitter(train_ratio: float = 0.8, val_ratio: float = 0.1, test_ratio: float = 0.1, sort_order: str = 'ascending', save_path: str | None = None)[source]#
Bases:
RatioSplitter
- class chemtorch.components.data_pipeline.data_splitter.size_splitter.SizeSplitter(train_ratio: float = 0.8, val_ratio: float = 0.1, test_ratio: float = 0.1, sort_order: str = 'ascending', save_path: str | None = None)[source]#
Bases:
RatioSplitter
Splits data into training, validation, and test sets based on molecular size (number of heavy atoms). For single molecules, sorts by the number of heavy atoms in each molecule. For reactions, sorts by the sum of heavy atoms in reactants and products.
- __init__(train_ratio: float = 0.8, val_ratio: float = 0.1, test_ratio: float = 0.1, sort_order: str = 'ascending', save_path: str | None = None)[source]#
Initializes the SizeSplitter.
- Parameters:
train_ratio (float) – The ratio of data for the training set.
val_ratio (float) – The ratio of data for the validation set.
test_ratio (float) – The ratio of data for the test set.
sort_order (str) – ‘ascending’ or ‘descending’.
save_path (str | None, optional) – If provided, saves split indices as pickle file.
- class chemtorch.components.data_pipeline.data_splitter.reaction_core_splitter.ReactionCoreSplitter(train_ratio: float = 0.8, val_ratio: float = 0.1, test_ratio: float = 0.1, include_chirality: bool = False, save_path: str | None = None)[source]#
Bases:
SMILESGroupSplitterBase
- __init__(train_ratio: float = 0.8, val_ratio: float = 0.1, test_ratio: float = 0.1, include_chirality: bool = False, save_path: str | None = None)[source]#
Initializes the ReactionCoreSplitter.
Splits data by grouping reactions based on their reaction core/template, ensuring that all reactions with the same reaction core are in the same split (train, val, or test). This is a method to test a model’s ability to generalize to new reaction types.
- Parameters:
train_ratio (float) – The desired ratio of data for the training set.
val_ratio (float) – The desired ratio of data for the validation set.
test_ratio (float) – The desired ratio of data for the test set.
include_chirality (bool) – If True, includes chirality in the reaction core generation.
save_path (str | None, optional) – If provided, saves split indices as pickle file.
- class chemtorch.components.data_pipeline.data_splitter.index_splitter.IndexSplitter(split_index_path: str, save_path: str | None = None)[source]#
Bases:
DataSplitterBase
- __init__(split_index_path: str, save_path: str | None = None)[source]#
Initializes the IndexSplitter with the specified index path.
- Parameters:
split_index_path (str) – The path to the pickle file containing the split indices.
save_path (str | None, optional) – If provided, saves the split indices as a pickle file to this path. Defaults to None.
Pipeline#
- class chemtorch.components.data_pipeline.simple_data_pipeline.SimpleDataPipeline(data_source: AbstractDataSource, column_mapper: AbstractColumnMapper, data_splitter: AbstractDataSplitter | None = None)[source]#
Bases:
object
A simple data pipeline that orchestrates data loading, column mapping, and data splitting.
The ingestion process is as follows:
1. Load data using the data_source. This can result in a single DataFrame or an already split DataSplit object.
2. Apply column transformations (filtering, renaming) using the column_mapper. This mapper can operate on both single DataFrames and DataSplit objects.
3. If the data after mapping is a single DataFrame, split it using the data_splitter. If it is already a DataSplit, this step is skipped.
- __init__(data_source: AbstractDataSource, column_mapper: AbstractColumnMapper, data_splitter: AbstractDataSplitter | None = None)[source]#
Initializes the SimpleDataPipeline.
- Parameters:
data_source (AbstractDataSource) – The component responsible for loading the initial data.
column_mapper (AbstractColumnMapper) – The component for transforming columns. It should handle both pd.DataFrame and DataSplit inputs.
data_splitter (Optional[AbstractDataSplitter]) – The component for splitting a single DataFrame into train, validation, and test sets. This is not used if data_source already provides split data.
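The three-step orchestration above can be sketched with plain callables standing in for the pipeline components (an illustrative simplification; a "split" result is modeled here as a tuple rather than a DataSplit object):

```python
def run_pipeline(data_source, column_mapper, data_splitter=None):
    """Load -> map columns -> split (only if not already split).
    A pre-split result is modeled as a (train, val, test) tuple."""
    data = data_source()
    data = column_mapper(data)
    already_split = isinstance(data, tuple)
    if not already_split and data_splitter is not None:
        data = data_splitter(data)
    return data

# Toy stand-ins for AbstractDataSource / AbstractColumnMapper / AbstractDataSplitter:
source = lambda: [{"smiles_column": "CCO", "y": 1.0}] * 10
mapper = lambda rows: [{"smiles": r["smiles_column"], "label": r["y"]} for r in rows]
splitter = lambda rows: (rows[:8], rows[8:9], rows[9:])

train, val, test = run_pipeline(source, mapper, splitter)
print(len(train), len(val), len(test))  # 8 1 1
```

If the source already yields split data (a tuple here), the splitter is never invoked, matching step 3 of the process.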
Representation#
Base Classes#
- class chemtorch.components.representation.abstract_representation.AbstractRepresentation[source]#
Abstract base class for all chemistry representation creators.
All representations in ChemTorch must subclass this class and implement the construct method. This class can also be used for type hinting where AbstractRepresentation[T] means that the __call__ and construct methods will return an object of type T.
The representation should be stateless - it should not hold any mutable state and the same input should always produce the same output.
- Raises:
TypeError – If the subclass does not implement the construct method.
- Example (correct usage):
>>> class MyRepresentation(AbstractRepresentation[torch.Tensor]):
...     def construct(self, smiles: str) -> torch.Tensor:
...         # Convert SMILES to tensor representation
...         return torch.tensor([1, 2, 3])
>>> r = MyRepresentation()
>>> r("CCO")  # ethanol
tensor([1, 2, 3])
- Example (incorrect usage, raises TypeError):
>>> class BadRepresentation(AbstractRepresentation[torch.Tensor]):
...     pass
>>> r = BadRepresentation()
Traceback (most recent call last):
    ...
TypeError: Can't instantiate abstract class BadRepresentation with abstract method construct
- abstract construct(smiles: str, **kwargs) T[source]#
Construct a representation from a SMILES string.
- Parameters:
smiles (str) – A SMILES string representing a molecule or reaction. For reactions, typically in the format “reactants>reagents>products” or “reactants>>products” (e.g., “CCO>>CC=O”). For molecules, a standard SMILES string (e.g., “CCO”).
**kwargs – Additional keyword arguments that may be required by specific representation implementations (e.g., reaction_dir for Reaction3DGraph).
- Returns:
- The constructed representation of the specified type.
Common types include torch.Tensor for token representations, torch_geometric.data.Data for graph representations, etc.
- Return type:
T
- Raises:
ValueError – If the SMILES string is invalid or cannot be processed.
RuntimeError – If representation construction fails for any other reason.
Graph Representations#
- class chemtorch.components.representation.graph.cgr.CGR(atom_featurizer: FeaturizerBase[Atom] | FeaturizerCompose, bond_featurizer: FeaturizerBase[Bond] | FeaturizerCompose, **kwargs)[source]#
Bases:
AbstractRepresentation[Data]
Stateless class for constructing Condensed Graph of Reaction (CGR) representations.
This class does not hold any data itself. Instead, it provides a construct method (also invokable via __call__) that takes a reaction SMILES string and returns a PyTorch Geometric Data object representing the reaction as a graph.
- Usage:
>>> from chemtorch.components.representation.graph.featurizer import FeaturizerCompose
>>> from chemtorch.components.representation.graph.featurizer.atom_featurizer import (
...     AtomicNumber, AtomDegree, AtomFormalCharge, AtomIsAromatic
... )
>>> from chemtorch.components.representation.graph.featurizer.bond_featurizer import (
...     BondType, BondIsInRing
... )
>>>
>>> atom_featurizer = FeaturizerCompose([
...     AtomicNumber(),
...     AtomDegree(),
...     AtomFormalCharge(),
...     AtomIsAromatic(),
... ])
>>> bond_featurizer = FeaturizerCompose([
...     BondType(),
...     BondIsInRing(),
... ])
>>>
>>> cgr = CGR(atom_featurizer, bond_featurizer)
>>> data = cgr.construct("CC>>CCO")  # reaction SMILES
>>> data = cgr("CC>>CCO")  # equivalent to the line above
- __init__(atom_featurizer: FeaturizerBase[Atom] | FeaturizerCompose, bond_featurizer: FeaturizerBase[Bond] | FeaturizerCompose, **kwargs)[source]#
Initialize the CGR representation with atom and bond featurizers.
- Parameters:
atom_featurizer (FeaturizerBase[Atom] | FeaturizerCompose) – A featurizer for atom features, which can be a single featurizer or a composed one.
bond_featurizer (FeaturizerBase[Bond] | FeaturizerCompose) – A featurizer for bond features, which can also be a single featurizer or a composed one.
- chemtorch.components.representation.graph.reaction_3d_graph.read_xyz(file_path: str) Tuple[List[str], Tensor][source]#
Reads a standard XYZ file and returns atomic symbols and coordinates.
- Parameters:
file_path (str) – The path to the .xyz file.
- Returns:
A tuple containing a list of atomic symbols (str) and a tensor of atomic coordinates ([num_atoms, 3]).
- Raises:
FileNotFoundError – If the file does not exist.
ValueError – If the file is malformed or cannot be parsed.
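The XYZ format read here is simple: an atom count, a comment line, then one `symbol x y z` row per atom. A minimal pure-Python parser sketch (returning plain lists instead of a torch.Tensor, unlike the real read_xyz):

```python
def read_xyz_sketch(file_path):
    """Parse a standard XYZ file into (symbols, coords) as plain lists.
    Illustrative only; chemtorch's read_xyz returns a coordinate tensor."""
    with open(file_path) as f:
        lines = f.read().splitlines()
    num_atoms = int(lines[0])   # line 1: atom count
    # line 2 is a free-form comment; atom rows follow.
    symbols, coords = [], []
    for line in lines[2:2 + num_atoms]:
        parts = line.split()
        if len(parts) < 4:
            raise ValueError(f"Malformed atom line: {line!r}")
        symbols.append(parts[0])
        coords.append([float(x) for x in parts[1:4]])
    if len(symbols) != num_atoms:
        raise ValueError("Atom count does not match header")
    return symbols, coords
```

Malformed rows and a header/body count mismatch both raise ValueError, mirroring the error contract documented above.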
- chemtorch.components.representation.graph.reaction_3d_graph.symbols_to_atomic_numbers(symbols: List[str]) Tensor[source]#
Converts a list of atomic symbols (e.g., ['C', 'H']) to a tensor of atomic numbers.
- Parameters:
symbols (List[str]) – A list of atomic symbols.
- Returns:
A tensor of atomic numbers.
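The conversion is a table lookup from element symbol to atomic number, sketched here with an abbreviated table and plain lists (the real function returns a tensor and covers the full periodic table):

```python
# Abbreviated symbol -> atomic number table (illustrative subset only).
ATOMIC_NUMBERS = {"H": 1, "C": 6, "N": 7, "O": 8}

def symbols_to_z(symbols):
    """Map element symbols to atomic numbers via dictionary lookup."""
    return [ATOMIC_NUMBERS[s] for s in symbols]

print(symbols_to_z(["C", "H", "O"]))  # [6, 1, 8]
```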
- class chemtorch.components.representation.graph.reaction_3d_graph.Reaction3DData(z_r: Tensor, pos_r: Tensor, z_ts: Tensor, pos_ts: Tensor, smiles: str | None = None, num_nodes: int | None = None, **kwargs)[source]#
Bases:
Data
Custom PyG Data class for reaction 3D graphs.
- z_r#
Atomic numbers for the reactant.
- Type:
Tensor
- pos_r#
Atomic positions for the reactant.
- Type:
Tensor
- z_ts#
Atomic numbers for the transition state.
- Type:
Tensor
- pos_ts#
Atomic positions for the transition state.
- Type:
Tensor
- class chemtorch.components.representation.graph.reaction_3d_graph.Reaction3DGraph(root_dir: str)[source]#
Bases:
AbstractRepresentation[Data]
Constructs a 3D representation of a reaction from XYZ files.
This representation reads the 3D structures for a reactant (r) and transition state (ts) from their respective .xyz files and packages them into a single PyTorch Geometric Data object (specifically, a Reaction3DData object).
The resulting Reaction3DData object for each sample contains:
- z_r, pos_r: Atomic numbers and 3D coordinates for the reactant
- z_ts, pos_ts: Atomic numbers and 3D coordinates for the transition state
- smiles: The reaction SMILES string (for reference)
- num_nodes: Total number of atoms in the structure
This representation is particularly useful for models like DimeReaction that operate on 3D geometries of chemical reactions.
- __init__(root_dir: str)[source]#
- Parameters:
root_dir (str) – The root directory where reaction subfolders (e.g., ‘reaction_1’, ‘reaction_2’) are located.
- construct(smiles: str, reaction_dir: str) Data[source]#
Constructs a single reaction graph from its corresponding XYZ files.
This method is called by DatasetBase for each row in the DataFrame and reads the reactant and transition state XYZ files from the specified reaction directory.
- Parameters:
smiles (str) – The reaction SMILES string.
reaction_dir (str) – The reaction subfolder (under root_dir) containing the reactant and transition state .xyz files.
- Returns:
A Reaction3DData object containing the 3D structures, with attributes:
- z_r, pos_r: Atomic numbers and positions for the reactant
- z_ts, pos_ts: Atomic numbers and positions for the transition state
- smiles: The reaction SMILES
- num_nodes: Number of atoms
- Return type:
Reaction3DData
- Raises:
FileNotFoundError – If the reaction directory or any required .xyz file is not found.
ValueError – If the number of atoms is inconsistent between reactant and TS structures.
Fingerprint Representations#
- class chemtorch.components.representation.fingerprint.drfp.DRFP(n_folded_length: int = 2048, min_radius: int = 0, radius: int = 3, rings: bool = True, root_central_atom: bool = True, include_hydrogens: bool = False)[source]#
Bases:
AbstractRepresentation[Tensor]
Stateless class for constructing DRFP (Differential Reaction Fingerprint) representations.
This class provides a construct method (also invokable via __call__) that takes a reaction SMILES string and returns a PyTorch Tensor representing the folded DRFP fingerprint. It utilizes the DrfpEncoder class for the core fingerprint generation logic.
- __init__(n_folded_length: int = 2048, min_radius: int = 0, radius: int = 3, rings: bool = True, root_central_atom: bool = True, include_hydrogens: bool = False)[source]#
Initializes the DRFP representation creator.
- Parameters:
n_folded_length (int, optional) – The length of the folded fingerprint. Default is 2048.
min_radius (int, optional) – The minimum radius for substructure extraction. Default is 0 (includes single atoms).
radius (int, optional) – The maximum radius for substructure extraction. Default is 3 (corresponds to DRFP6).
rings (bool, optional) – Whether to include full rings as substructures. Default is True.
root_central_atom (bool, optional) – Whether to root the central atom of substructures when generating SMILES. Default is True.
include_hydrogens (bool, optional) – Whether to include hydrogens in the molecular representation. Default is False.
- construct(smiles: str) Tensor[source]#
Generates a DRFP fingerprint for a single reaction SMILES string.
The method uses DrfpEncoder.internal_encode to get the hashed difference features of the reaction and then DrfpEncoder.fold to create the final binary fingerprint vector as a NumPy array, which is then converted to a PyTorch Tensor.
- Parameters:
smiles – The reaction SMILES string (e.g., “R1.R2>A>P1.P2”).
- Returns:
A PyTorch Tensor (dtype=torch.float32) representing the folded DRFP fingerprint.
- Raises:
NoReactionError – If the input SMILES is not a valid reaction SMILES (as detected by DrfpEncoder.internal_encode).
RuntimeError – For other errors encountered during fingerprint generation, wrapping the original exception.
- class chemtorch.components.representation.fingerprint.drfp.DRFPUtil[source]#
Bases:
object
A utility class for encoding reaction SMILES as DRFP fingerprints.
- static shingling_from_mol(in_mol: Mol, radius: int = 3, rings: bool = True, min_radius: int = 0, get_atom_indices: bool = False, root_central_atom: bool = True, include_hydrogens: bool = False) List[str] | Tuple[List[str], Dict[str, List[Set[int]]]][source]#
Creates a molecular shingling from a RDKit molecule (rdkit.Chem.rdchem.Mol).
- Parameters:
in_mol – A RDKit molecule instance
radius – The drfp radius (a radius of 3 corresponds to drfp6)
rings – Whether or not to include rings in the shingling
min_radius – The minimum radius that is used to extract n-grams
- Returns:
The molecular shingling.
- static internal_encode(in_smiles: str, radius: int = 3, min_radius: int = 0, rings: bool = True, get_atom_indices: bool = False, root_central_atom: bool = True, include_hydrogens: bool = False) Tuple[ndarray, ndarray] | Tuple[ndarray, ndarray, Dict[str, List[Dict[str, List[Set[int]]]]]][source]#
Creates a drfp array from a reaction SMILES string.
- Parameters:
in_smiles – A valid reaction SMILES string
radius – The drfp radius (a radius of 3 corresponds to drfp6)
min_radius – The minimum radius that is used to extract n-grams
rings – Whether or not to include rings in the shingling
- Returns:
A tuple with two arrays, the first containing the drfp hash values, the second the substructure SMILES
- static hash(shingling: List[str]) ndarray[source]#
Directly hash all the SMILES in a shingling to a 32-bit integer.
- Parameters:
shingling – A list of n-grams
- Returns:
An array of hashed n-grams
- static fold(hash_values: ndarray, length: int = 2048) Tuple[ndarray, ndarray][source]#
Folds the hash values to a binary vector of a given length.
- Parameters:
hash_values – An array containing the hash values
length – The length of the folded fingerprint
- Returns:
A tuple containing the folded fingerprint and the indices of the on bits
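Folding is modulo hashing: each hash value sets bit `h % length` of a binary vector, so distinct substructures can collide on the same bit. A sketch with plain lists (`fold_sketch` is illustrative; the real method returns NumPy arrays):

```python
def fold_sketch(hash_values, length=2048):
    """Fold hash values into a binary vector of `length` bits via modulo."""
    on_bits = sorted({h % length for h in hash_values})
    folded = [0] * length
    for b in on_bits:
        folded[b] = 1
    return folded, on_bits

folded, on_bits = fold_sketch([7, 2055, 13], length=2048)
print(on_bits)  # [7, 13] -- 7 and 2055 collide at bit 7
```

Shorter fingerprints trade memory for a higher collision rate, which is why n_folded_length defaults to 2048.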
- static encode(X: Iterable | str, n_folded_length: int = 2048, min_radius: int = 0, radius: int = 3, rings: bool = True, mapping: bool = False, atom_index_mapping: bool = False, root_central_atom: bool = True, include_hydrogens: bool = False, show_progress_bar: bool = False) List[ndarray] | Tuple[List[ndarray], Dict[int, Set[str]]] | List[Dict[str, List[Dict[str, List[Set[int]]]]]][source]#
Encodes a list of reaction SMILES using the drfp fingerprint.
- Parameters:
X – An iterable (e.g. List) of reaction SMILES or a single reaction SMILES to be encoded
n_folded_length – The folded length of the fingerprint (the parameter for the modulo hashing)
min_radius – The minimum radius of a substructure (0 includes single atoms)
radius – The maximum radius of a substructure
rings – Whether to include full rings as substructures
mapping – Return a feature to substructure mapping in addition to the fingerprints
atom_index_mapping – Return the atom indices of mapped substructures for each reaction
root_central_atom – Whether to root the central atom of substructures when generating SMILES
show_progress_bar – Whether to show a progress bar when encoding reactions
- Returns:
A list of drfp fingerprints or, if mapping is enabled, a tuple containing a list of drfp fingerprints and a mapping dict.
Transform#
Base Classes#
- class chemtorch.components.transform.abstract_transform.AbstractTransform[source]#
Abstract base class for transforms in the chemtorch framework. This class serves as a base for creating transforms that operate on single data objects.
- Raises:
TypeError – If the subclass does not implement the __call__ method.
- Example (correct usage):
>>> class MyTransform(AbstractTransform[int]):
...     def __call__(self, obj: int) -> int:
...         return obj * 2
>>> t = MyTransform()
>>> t(3)
6
- Example (incorrect usage, raises TypeError):
>>> class BadTransform(AbstractTransform[int]):
...     pass
>>> t = BadTransform()
Traceback (most recent call last):
    ...
TypeError: Can't instantiate abstract class BadTransform with abstract method __call__
Graph Transforms#
Augmentation#
- class chemtorch.components.augmentation.abstract_augmentation.AbstractAugmentation[source]#
Abstract base class for data augmentations in the chemtorch framework.
This class serves as a base for creating data augmentations that operate on data points. A data point (or sample) is an object of type T returned by the representation and an optional label tensor.
Subclasses must implement the __call__ method to define the augmentation logic.
- Raises:
TypeError – If the subclass does not implement the __call__ method.
RuntimeError – If the input object has a label, but not all augmented objects have labels, or vice versa.
Model#
Graph Neural Networks#
- class chemtorch.components.model.gnn.gnn.GNN(encoder: Callable[[Batch], Batch], layer_stack: Callable[[Batch], Batch], pool: Callable[[Batch], Tensor], head: Callable[[Tensor], Tensor])[source]#
Bases:
Module
- __init__(encoder: Callable[[Batch], Batch], layer_stack: Callable[[Batch], Batch], pool: Callable[[Batch], Tensor], head: Callable[[Tensor], Tensor])[source]#
- Parameters:
encoder (Callable[[Batch], Batch]) – The encoder function that processes the input batch.
layer_stack (Callable[[Batch], Batch]) – The GNN layer stack that processes the batch.
pool (Callable[[Batch], torch.Tensor]) – The pooling function that converts the graphs to a tensor.
head (Callable[[torch.Tensor], torch.Tensor]) – The head function for final prediction.
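The forward pass presumably composes these four callables in order: encode the batch, run the layer stack, pool graphs to a tensor, then apply the head. A toy sketch with plain functions standing in for the Batch-typed components (`gnn_forward` is an illustrative helper, not the module's actual code):

```python
def gnn_forward(batch, encoder, layer_stack, pool, head):
    """encoder -> layer_stack -> pool -> head, mirroring the GNN composition."""
    batch = encoder(batch)
    batch = layer_stack(batch)
    graph_embedding = pool(batch)
    return head(graph_embedding)

# Toy stand-ins: a "batch" is just a list of node features.
out = gnn_forward(
    [1.0, 2.0, 3.0],
    encoder=lambda b: [x * 2 for x in b],       # embed nodes
    layer_stack=lambda b: [x + 1 for x in b],   # message passing
    pool=lambda b: sum(b) / len(b),             # mean pooling over nodes
    head=lambda h: h * 10,                      # prediction head
)
print(out)  # 50.0
```

Because each stage is an injected callable, encoders, layer stacks, pooling functions, and heads can be swapped independently.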
Encoder#
Pooling#
Other Models#
- class chemtorch.components.model.mlp.MLP(in_channels: int, out_channels: int, hidden_dims: List[int] | None = None, hidden_size: int | None = None, num_hidden_layers: int | None = None, dropout: float = 0.0, act: str | Callable | None = 'relu', act_kwargs: Dict[str, Any] | None = None, norm: str | Callable | None = None, norm_kwargs: Dict[str, Any] | None = None)[source]#
Bases:
Module
Multi-Layer Perceptron (MLP) with configurable layers and activation functions.
- __init__(in_channels: int, out_channels: int, hidden_dims: List[int] | None = None, hidden_size: int | None = None, num_hidden_layers: int | None = None, dropout: float = 0.0, act: str | Callable | None = 'relu', act_kwargs: Dict[str, Any] | None = None, norm: str | Callable | None = None, norm_kwargs: Dict[str, Any] | None = None)[source]#
Initializes an MLP. The architecture is built to match FFNHead’s structure and instantiation order precisely.
A Dropout layer is placed before every Linear layer.
Activation is applied after every hidden Linear layer.
Normalization is applied after every hidden Linear layer (if specified).
- Parameters:
in_channels (int) – Input dimension.
out_channels (int) – Output dimension.
hidden_dims (List[int], optional) – List of hidden layer dimensions (preferred).
hidden_size (int, optional) – Hidden layer size (used if hidden_dims is not provided).
num_hidden_layers (int, optional) – Number of hidden layers (used if hidden_dims is not provided).
dropout (float, optional) – Dropout rate. Defaults to 0.
act (str or Callable, optional) – Activation function. Defaults to “relu”.
act_kwargs (Dict[str, Any], optional) – Arguments for the activation function.
norm (str or Callable, optional) – Normalization layer. Defaults to None.
norm_kwargs (Dict[str, Any], optional) – Arguments for the normalization layer.
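The layer ordering described above (Dropout before every Linear; activation, then optional norm, after every hidden Linear) can be sketched as a plan builder that lists layer names; `mlp_layer_plan` is an illustrative helper, not ChemTorch code:

```python
def mlp_layer_plan(in_channels, out_channels, hidden_dims, dropout=0.0, norm=None):
    """List the layer sequence the MLP docstring describes, as name strings."""
    plan, dims = [], [in_channels] + hidden_dims + [out_channels]
    for i in range(len(dims) - 1):
        if dropout > 0:
            plan.append(f"Dropout({dropout})")      # before every Linear
        plan.append(f"Linear({dims[i]}->{dims[i+1]})")
        is_hidden = i < len(dims) - 2               # last Linear is the output layer
        if is_hidden:
            plan.append("ReLU")                     # activation after hidden Linear
            if norm:
                plan.append(norm)                   # optional norm after activation
    return plan

print(mlp_layer_plan(16, 1, [32], dropout=0.5))
# ['Dropout(0.5)', 'Linear(16->32)', 'ReLU', 'Dropout(0.5)', 'Linear(32->1)']
```

Note that the output Linear gets a preceding Dropout but no activation or norm, matching the rules listed above.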
- class chemtorch.components.model.dimenet.Envelope(exponent: int)[source]#
Bases:
Module
- __init__(exponent: int)[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor) Tensor[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class chemtorch.components.model.dimenet.BesselBasisLayer(num_radial: int, cutoff: float = 5.0, envelope_exponent: int = 5)[source]#
Bases:
Module
- __init__(num_radial: int, cutoff: float = 5.0, envelope_exponent: int = 5)[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(dist: Tensor) Tensor[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class chemtorch.components.model.dimenet.SphericalBasisLayer(num_spherical: int, num_radial: int, cutoff: float = 5.0, envelope_exponent: int = 5)[source]#
Bases:
Module
- __init__(num_spherical: int, num_radial: int, cutoff: float = 5.0, envelope_exponent: int = 5)[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(dist: Tensor, angle: Tensor, idx_kj: Tensor) Tensor[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class chemtorch.components.model.dimenet.EmbeddingBlock(num_radial: int, hidden_channels: int, act: Callable)[source]#
Bases:
Module
- __init__(num_radial: int, hidden_channels: int, act: Callable)[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor, rbf: Tensor, i: Tensor, j: Tensor) Tensor[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class chemtorch.components.model.dimenet.ResidualLayer(hidden_channels: int, act: Callable)[source]#
Bases:
Module
- __init__(hidden_channels: int, act: Callable)[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor) Tensor[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class chemtorch.components.model.dimenet.InteractionBlock(hidden_channels: int, num_bilinear: int, num_spherical: int, num_radial: int, num_before_skip: int, num_after_skip: int, act: Callable)[source]#
Bases:
Module
- __init__(hidden_channels: int, num_bilinear: int, num_spherical: int, num_radial: int, num_before_skip: int, num_after_skip: int, act: Callable)[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor, rbf: Tensor, sbf: Tensor, idx_kj: Tensor, idx_ji: Tensor) Tensor[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class chemtorch.components.model.dimenet.InteractionPPBlock(hidden_channels: int, int_emb_size: int, basis_emb_size: int, num_spherical: int, num_radial: int, num_before_skip: int, num_after_skip: int, act: Callable)[source]#
Bases:
Module
- __init__(hidden_channels: int, int_emb_size: int, basis_emb_size: int, num_spherical: int, num_radial: int, num_before_skip: int, num_after_skip: int, act: Callable)[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor, rbf: Tensor, sbf: Tensor, idx_kj: Tensor, idx_ji: Tensor) Tensor[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class chemtorch.components.model.dimenet.OutputBlock(num_radial: int, hidden_channels: int, out_channels: int, num_layers: int, act: Callable, output_initializer: str = 'zeros')[source]#
Bases:
Module
- __init__(num_radial: int, hidden_channels: int, out_channels: int, num_layers: int, act: Callable, output_initializer: str = 'zeros')[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor, rbf: Tensor, i: Tensor, num_nodes: int | None = None) Tensor[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class chemtorch.components.model.dimenet.OutputPPBlock(num_radial: int, hidden_channels: int, out_emb_channels: int, out_channels: int, num_layers: int, act: Callable, output_initializer: str = 'zeros')[source]#
Bases:
Module
- __init__(num_radial: int, hidden_channels: int, out_emb_channels: int, out_channels: int, num_layers: int, act: Callable, output_initializer: str = 'zeros')[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor, rbf: Tensor, i: Tensor, num_nodes: int | None = None) Tensor[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- chemtorch.components.model.dimenet.triplets(edge_index: Tensor, num_nodes: int) Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor][source]#
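To illustrate what a triplet routine like this computes, the following pure-Python sketch enumerates all pairs of directed edges (k→j, j→i) with k ≠ i, which are the angle-bearing triplets DimeNet passes messages over. This is a hypothetical illustration only; the real `triplets` function operates on torch index tensors and returns seven tensors rather than a list of tuples.

```python
# Hypothetical pure-Python sketch of triplet enumeration; the actual
# chemtorch.components.model.dimenet.triplets works on torch tensors.

def enumerate_triplets(edge_index):
    """Yield (k, j, i) for every pair of directed edges k->j and j->i with k != i."""
    # Group source nodes by their target: incoming[j] = [k for each edge k->j]
    incoming = {}
    for src, dst in edge_index:
        incoming.setdefault(dst, []).append(src)
    triplets = []
    for j, i in edge_index:            # edge j->i
        for k in incoming.get(j, []):  # edge k->j
            if k != i:                 # skip the backtracking triplet i->j->i
                triplets.append((k, j, i))
    return triplets

# A triangle graph with both edge directions: 0<->1, 1<->2, 2<->0
edges = [(0, 1), (1, 0), (1, 2), (2, 1), (2, 0), (0, 2)]
print(enumerate_triplets(edges))  # six angle triplets for the triangle
```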
- class chemtorch.components.model.dimenet.DimeNet(hidden_channels: int, out_channels: int, num_blocks: int, num_bilinear: int, num_spherical: int, num_radial: int, head, cutoff: float = 5.0, max_num_neighbors: int = 32, envelope_exponent: int = 5, num_before_skip: int = 1, num_after_skip: int = 2, num_output_layers: int = 3, act: str | Callable = 'swish', output_initializer: str = 'zeros')[source]#
Bases:
Module
The directional message passing neural network (DimeNet) from the “Directional Message Passing for Molecular Graphs” paper. DimeNet transforms messages based on the angle between them in a rotation-equivariant fashion.
Note
For an example of using a pretrained DimeNet variant, see examples/qm9_pretrained_dimenet.py.
- Parameters:
hidden_channels (int) – Hidden embedding size.
out_channels (int) – Size of each output sample.
num_blocks (int) – Number of building blocks.
num_bilinear (int) – Size of the bilinear layer tensor.
num_spherical (int) – Number of spherical harmonics.
num_radial (int) – Number of radial basis functions.
cutoff (float, optional) – Cutoff distance for interatomic interactions. (default: 5.0)
max_num_neighbors (int, optional) – The maximum number of neighbors to collect for each node within the cutoff distance. (default: 32)
envelope_exponent (int, optional) – Shape of the smooth cutoff. (default: 5)
num_before_skip (int, optional) – Number of residual layers in the interaction blocks before the skip connection. (default: 1)
num_after_skip (int, optional) – Number of residual layers in the interaction blocks after the skip connection. (default: 2)
num_output_layers (int, optional) – Number of linear layers for the output blocks. (default: 3)
act (str or Callable, optional) – The activation function. (default: "swish")
output_initializer (str, optional) – The initialization method for the output layer ("zeros", "glorot_orthogonal"). (default: "zeros")
- url = 'https://github.com/klicperajo/dimenet/raw/master/pretrained/dimenet'#
- __init__(hidden_channels: int, out_channels: int, num_blocks: int, num_bilinear: int, num_spherical: int, num_radial: int, head, cutoff: float = 5.0, max_num_neighbors: int = 32, envelope_exponent: int = 5, num_before_skip: int = 1, num_after_skip: int = 2, num_output_layers: int = 3, act: str | Callable = 'swish', output_initializer: str = 'zeros')[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- embed(z: Tensor, pos: Tensor, batch_indices: Tensor) Tensor[source]#
Embed node features and perform multiple interaction and output blocks.
- Parameters:
z (Tensor) – Atomic numbers of shape [num_atoms].
pos (Tensor) – Atom positions of shape [num_atoms, 3].
batch_indices (Tensor) – Batch indices assigning each atom to a separate molecule of shape [num_atoms].
- Returns:
The output node features of shape [num_molecules, out_channels].
- Return type:
Tensor
- forward(batch) Tensor[source]#
Forward pass.
- Parameters:
batch (Data) –
A batch of torch_geometric.data.Data objects holding multiple molecular graphs. Must contain the following attributes:
- z (torch.Tensor) – Atomic number of each atom with shape [num_atoms].
- pos (torch.Tensor) – Coordinates of each atom with shape [num_atoms, 3].
- batch (torch.Tensor, optional) – Batch indices assigning each atom to a separate molecule with shape [num_atoms]. (default: None)
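The batch attribute drives the per-molecule readout: node outputs are scatter-summed into one row per molecule, turning [num_atoms, out_channels] node features into [num_molecules, out_channels] molecule features. A hypothetical plain-Python sketch of that pooling step:

```python
# Hypothetical sketch of per-molecule readout via batch indices, mirroring
# the scatter-sum that pools node features into molecule features.

def pool_by_batch(node_feats, batch_indices):
    """Sum node feature vectors belonging to the same molecule."""
    num_molecules = max(batch_indices) + 1
    width = len(node_feats[0])
    out = [[0.0] * width for _ in range(num_molecules)]
    for feats, mol in zip(node_feats, batch_indices):
        for c, value in enumerate(feats):
            out[mol][c] += value
    return out

# Two molecules: atoms 0-1 belong to molecule 0, atom 2 to molecule 1.
feats = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(pool_by_batch(feats, [0, 0, 1]))  # [[4.0, 6.0], [5.0, 6.0]]
```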
- class chemtorch.components.model.dimenet.DimeNetPlusPlus(hidden_channels: int, out_channels: int, num_blocks: int, int_emb_size: int, basis_emb_size: int, out_emb_channels: int, num_spherical: int, num_radial: int, head, cutoff: float = 5.0, max_num_neighbors: int = 32, envelope_exponent: int = 5, num_before_skip: int = 1, num_after_skip: int = 2, num_output_layers: int = 3, act: str | Callable = 'swish', output_initializer: str = 'zeros')[source]#
Bases:
DimeNet
The DimeNet++ model from the “Fast and Uncertainty-Aware Directional Message Passing for Non-Equilibrium Molecules” paper. DimeNetPlusPlus is an upgrade to the DimeNet model that is 8x faster and 10% more accurate than DimeNet.
- Parameters:
hidden_channels (int) – Hidden embedding size.
out_channels (int) – Size of each output sample.
num_blocks (int) – Number of building blocks.
int_emb_size (int) – Size of embedding in the interaction block.
basis_emb_size (int) – Size of basis embedding in the interaction block.
out_emb_channels (int) – Size of embedding in the output block.
num_spherical (int) – Number of spherical harmonics.
num_radial (int) – Number of radial basis functions.
cutoff (float, optional) – Cutoff distance for interatomic interactions. (default: 5.0)
max_num_neighbors (int, optional) – The maximum number of neighbors to collect for each node within the cutoff distance. (default: 32)
envelope_exponent (int, optional) – Shape of the smooth cutoff. (default: 5)
num_before_skip (int, optional) – Number of residual layers in the interaction blocks before the skip connection. (default: 1)
num_after_skip (int, optional) – Number of residual layers in the interaction blocks after the skip connection. (default: 2)
num_output_layers (int, optional) – Number of linear layers for the output blocks. (default: 3)
act (str or Callable, optional) – The activation function. (default: "swish")
output_initializer (str, optional) – The initialization method for the output layer ("zeros", "glorot_orthogonal"). (default: "zeros")
- url = 'https://raw.githubusercontent.com/gasteigerjo/dimenet/master/pretrained/dimenet_pp'#
- __init__(hidden_channels: int, out_channels: int, num_blocks: int, int_emb_size: int, basis_emb_size: int, out_emb_channels: int, num_spherical: int, num_radial: int, head, cutoff: float = 5.0, max_num_neighbors: int = 32, envelope_exponent: int = 5, num_before_skip: int = 1, num_after_skip: int = 2, num_output_layers: int = 3, act: str | Callable = 'swish', output_initializer: str = 'zeros')[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- class chemtorch.components.model.dimenet.DimeReaction(hidden_channels: int, out_channels: int, num_blocks: int, int_emb_size: int, basis_emb_size: int, out_emb_channels: int, num_spherical: int, num_radial: int, head, cutoff: float = 5.0, max_num_neighbors: int = 32, envelope_exponent: int = 5, num_before_skip: int = 1, num_after_skip: int = 2, num_output_layers: int = 3, act: str | Callable = 'swish', output_initializer: str = 'zeros')[source]#
Bases:
DimeNetPlusPlus
The DimeReaction model used by Spiekerman et al. in https://pubs.acs.org/doi/10.1021/acs.jpca.2c02614.
- forward(batch: Batch) Tensor[source]#
Forward pass.
- Parameters:
batch (Batch) –
A batch of torch_geometric.data.Data objects holding multiple molecular graphs. Must contain the following attributes:
- z_r (torch.Tensor) – Atomic number of each atom in the reactant with shape [num_atoms].
- pos_r (torch.Tensor) – Coordinates of each atom in the reactant with shape [num_atoms, 3].
- z_ts (torch.Tensor) – Atomic number of each atom in the transition state with shape [num_atoms].
- pos_ts (torch.Tensor) – Coordinates of each atom in the transition state with shape [num_atoms, 3].
- batch (torch.Tensor, optional) – Batch indices assigning each atom to a separate molecule with shape [num_atoms]. (default: None)
- class chemtorch.components.model.han.HAN(embedding_in_channels: int, embedding_hidden_channels: int, gru_hidden_channels: int, class_num: int, dropout=0.2)[source]#
Bases:
Module
- __init__(embedding_in_channels: int, embedding_hidden_channels: int, gru_hidden_channels: int, class_num: int, dropout=0.2)[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
Layers#
Layer Stack#
- class chemtorch.components.layer.layer_stack.LayerStack(layer: DictConfig, depth: int, share_weights: bool = False)[source]#
A utility class for stacking a layer multiple times.
This class is useful for creating deep neural networks by stacking the same layer multiple times. Note that the input and output types of the layer must be the same.
- __init__(layer: DictConfig, depth: int, share_weights: bool = False)[source]#
Initialize the Stack using Hydra for instantiation.
- forward(x: T) T[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
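The stacking behavior, including the share_weights option, can be sketched in plain Python. This is a hypothetical illustration (the real LayerStack instantiates layers from a Hydra DictConfig): with share_weights, one layer object is reused at every depth, so all positions share parameters; without it, each position gets an independent copy.

```python
# Hypothetical sketch of layer stacking with optional weight sharing.
# AddOne stands in for an arbitrary layer whose input/output types match.

class AddOne:
    def __call__(self, x):
        return x + 1

def build_stack(layer_factory, depth, share_weights=False):
    if share_weights:
        layer = layer_factory()
        return [layer] * depth                       # the same object repeated
    return [layer_factory() for _ in range(depth)]   # independent copies

def forward(stack, x):
    for layer in stack:   # sequential application; types must line up
        x = layer(x)
    return x

shared = build_stack(AddOne, depth=3, share_weights=True)
print(forward(shared, 0))      # 3
print(shared[0] is shared[2])  # True: one set of weights
```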
GNN Layers#
- class chemtorch.components.layer.gnn_layer.dmpnn_stack.EdgeToNodeEmbedding(embedding_size: int, num_node_features: int, aggr: str | Callable = 'add', aggr_kwargs: Dict[str, Any] | None = None, act: str | Callable | None = 'relu', act_kwargs: Dict[str, Any] | None = None)[source]#
Bases:
Module
EdgeToNodeEmbedding is a neural network layer that takes edge embeddings and node features, aggregates the edge embeddings based on the node indices, concatenates them with the node features, and passes the result through a linear layer followed by an activation function to produce node embeddings.
- __init__(embedding_size: int, num_node_features: int, aggr: str | Callable = 'add', aggr_kwargs: Dict[str, Any] | None = None, act: str | Callable | None = 'relu', act_kwargs: Dict[str, Any] | None = None)[source]#
Initialize the EdgeToNodeEmbedding layer.
- Parameters:
embedding_size (int) – Size of the edge embeddings (also size of the new node embeddings).
num_node_features (int) – Number of input node features, concatenated with the aggregated edge embeddings.
aggr (Union[str, Callable], optional) – Aggregation method. Defaults to “add”.
aggr_kwargs (Optional[Dict[str, Any]], optional) – Additional arguments for aggregation. Defaults to None.
act (Union[str, Callable, None], optional) – Activation function. Defaults to “relu”.
act_kwargs (Optional[Dict[str, Any]], optional) – Additional arguments for activation function. Defaults to None.
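The aggregate-then-concatenate step can be sketched in plain Python. This is a hypothetical illustration that omits the linear layer and activation applied afterwards in the real module; it shows only the "add" aggregation of each node's incoming edge embeddings followed by concatenation with the node's own features.

```python
# Hypothetical sketch of the edge-to-node aggregation step: sum each node's
# incoming edge embeddings, then concatenate with the node's own features.
# The real layer then applies a linear projection and an activation.

def edge_to_node(edge_index, edge_embeds, node_feats):
    width = len(edge_embeds[0])
    agg = [[0.0] * width for _ in node_feats]     # one accumulator per node
    for (src, dst), emb in zip(edge_index, edge_embeds):
        for c, value in enumerate(emb):
            agg[dst][c] += value                  # "add" aggregation at the target
    # Concatenate aggregated edge embedding with the original node features
    return [a + f for a, f in zip(agg, node_feats)]

edges = [(0, 1), (2, 1)]                # both edges point at node 1
edge_embeds = [[1.0, 1.0], [2.0, 2.0]]
node_feats = [[0.5], [0.5], [0.5]]
print(edge_to_node(edges, edge_embeds, node_feats))
```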
- class chemtorch.components.layer.gnn_layer.dmpnn_stack.DMPNNStack(dmpnn_blocks: LayerStack[Batch], edge_to_node_embedding: EdgeToNodeEmbedding)[source]#
Bases:
Module
DMPNNStack is a neural network layer that implements a sequence of directed message passing steps, followed by an edge-to-node embedding layer, which generates node embeddings from the original node features and the edge embeddings obtained from the directed message passing steps.
- __init__(dmpnn_blocks: LayerStack[Batch], edge_to_node_embedding: EdgeToNodeEmbedding)[source]#
Initialize the DMPNNStack.
- Parameters:
dmpnn_blocks (Stack[DMPNNBlock]) – A stack of DMPNN blocks that perform directed message passing.
edge_to_node_embedding (EdgeToNodeEmbedding) – The layer that converts edge features to node features.
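For orientation, the directed message passing performed by the DMPNN blocks follows the D-MPNN update rule: the message for a directed edge v→w aggregates the hidden states of edges entering v, excluding the reverse edge w→v, so information never flows straight back where it came from. A hypothetical plain-Python sketch of one such step (scalar hidden states, no learned transform) looks like this:

```python
# Hypothetical sketch of one directed message passing step (D-MPNN style).
# Each directed edge v->w sums the hidden states of edges k->v with k != w.

def dmpnn_step(edges, hidden):
    # Index edges entering each node: incoming[v] = list of edge ids k->v
    incoming = {}
    for eid, (src, dst) in enumerate(edges):
        incoming.setdefault(dst, []).append(eid)
    messages = []
    for v, w in edges:
        msg = 0.0
        for in_eid in incoming.get(v, []):
            k = edges[in_eid][0]
            if k != w:                 # exclude the reverse edge w->v
                msg += hidden[in_eid]
        messages.append(msg)
    return messages

# Path graph 0-1-2 with both edge directions.
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
hidden = [1.0, 2.0, 3.0, 4.0]
print(dmpnn_step(edges, hidden))  # [0.0, 4.0, 1.0, 0.0]
```

In the real DMPNNStack these messages would be transformed by learned weights inside each DMPNN block before the EdgeToNodeEmbedding layer converts the final edge states into node embeddings.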