Utilities#

Helper functions and utilities used throughout ChemTorch.

Types#

class chemtorch.utils.types.DataLoaderFactoryProtocol(*args, **kwargs)[source]#

Bases: Protocol

Protocol defining the interface for dataloader factory functions.

__init__(*args, **kwargs)#
class chemtorch.utils.types.RoutineFactoryProtocol(*args, **kwargs)[source]#

Bases: Protocol

Protocol defining the interface for routine factory functions.

__init__(*args, **kwargs)#
class chemtorch.utils.types.DataSplit(train: T, val: T, test: T)[source]#

Bases: Generic[T]

A data structure to hold the data splits for training, validation, and testing. Provides named tuple-like access with generic typing support.

__init__(train: T, val: T, test: T) None[source]#
to_dict() dict[str, T][source]#
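For illustration, a minimal stand-in with the same shape as the documented interface (this is a sketch, not the chemtorch class itself):

```python
from dataclasses import dataclass
from typing import Generic, TypeVar

T = TypeVar("T")

# Minimal stand-in mirroring the documented DataSplit interface.
@dataclass
class DataSplit(Generic[T]):
    train: T
    val: T
    test: T

    def to_dict(self) -> dict:
        return {"train": self.train, "val": self.val, "test": self.test}

# Named access keeps the three partitions together under one generic type.
split = DataSplit(train=[1, 2, 3], val=[4], test=[5])
print(split.val)        # [4]
print(split.to_dict())  # {'train': [1, 2, 3], 'val': [4], 'test': [5]}
```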

Hydra Utilities#

chemtorch.utils.hydra.resolve_target(target_str)[source]#

Resolve a string like ‘module.submodule.ClassName’ to the actual class.

chemtorch.utils.hydra.order_config_by_signature(cfg)[source]#

Recursively reorder configuration dictionaries to match the argument order of their specified _target_ class or function’s constructor signature.

Parameters:

cfg (dict or DictConfig) – The configuration dictionary or OmegaConf DictConfig.

Returns:

A reordered configuration dictionary with keys ordered according to the

constructor signature of the _target_ if present, otherwise recursively processes child dictionaries.

Return type:

dict

Why is this important?#

By default, Hydra instantiates objects in the order that keys appear in the configuration dictionary. However, the order of keys in a YAML or Python dict is not guaranteed to match the order of arguments in the target class or function’s constructor. In frameworks like PyTorch, the order in which submodules (layers, blocks, etc.) are instantiated directly affects the order in which random numbers are consumed for parameter initialization. This means that simply reordering keys in the config file can lead to different random initializations and thus different results, even if the model architecture and random seed remain unchanged.

This function ensures that the instantiation order of all objects (and their subcomponents) is invariant to the order of keys in the configuration file. It does so by reordering each config dictionary to match the argument order of the corresponding `_target_`'s constructor. This guarantees reproducible model initialization and results, regardless of how the config is written.
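The effect can be sketched for a single, non-recursive level using inspect.signature. GNNLayer and order_by_signature below are illustrative stand-ins, not the chemtorch implementation:

```python
import inspect

class GNNLayer:
    def __init__(self, in_channels, hidden_channels, out_channels):
        pass

def order_by_signature(cfg: dict, target: type) -> dict:
    """Reorder cfg keys to match the target constructor's argument order
    (single-level sketch; the real function also recurses into children)."""
    params = list(inspect.signature(target.__init__).parameters)  # includes 'self'
    rank = {name: i for i, name in enumerate(params)}
    # Unknown keys sort last, preserving their relative order.
    return dict(sorted(cfg.items(), key=lambda kv: rank.get(kv[0], len(rank))))

cfg = {"out_channels": 1, "in_channels": 16, "hidden_channels": 64}
print(list(order_by_signature(cfg, GNNLayer)))
# ['in_channels', 'hidden_channels', 'out_channels']
```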

chemtorch.utils.hydra.filter_config_by_signature(cfg)[source]#

Recursively filter configuration dictionaries to only include keys that match the argument names of their specified _target_ class or function’s constructor signature. If the signature supports **kwargs, do not filter. Special Hydra keys (from _Keys) are always preserved.

Parameters:

cfg (dict or DictConfig) – The configuration dictionary or OmegaConf DictConfig.

Returns:

A filtered configuration dictionary.

Return type:

dict

Why is this important?#

This is useful for convenient key interpolation in configs. If multiple lower-level configs in the hierarchy depend on a shared value, it is convenient to define it in a higher-level config and use a universal interpolation syntax to reference that key, regardless of the lower-level config structure (which can differ). For example, the hidden_channels argument is shared across multiple components of a GNN, such as encoders, convolution layers, and higher-level blocks. However, the higher-level GNN class that assembles these components might not have a hidden_channels argument itself. If you specify a hidden_channels key in the higher-level config and use top-level interpolation, Hydra will raise a ValueError when it tries to instantiate the GNN with the config, since the argument is unexpected. The use case for this function is to resolve the config (so all interpolations are applied), then filter it so that only valid keys for each _target_ remain, avoiding instantiation errors and enabling flexible, DRY config design.
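A single-level sketch of the filtering idea (the real function also recurses, skips filtering when the signature accepts **kwargs, and preserves special Hydra keys; GNN and filter_by_signature here are illustrative stand-ins):

```python
import inspect

class GNN:
    def __init__(self, encoder, num_layers):  # note: no hidden_channels argument
        pass

def filter_by_signature(cfg: dict, target: type) -> dict:
    """Keep only keys the target's constructor accepts (single-level sketch)."""
    params = inspect.signature(target.__init__).parameters
    return {k: v for k, v in cfg.items() if k in params}

# hidden_channels exists only so nested configs can interpolate it;
# it must be dropped before instantiating GNN itself.
cfg = {"encoder": "linear", "num_layers": 3, "hidden_channels": 64}
print(filter_by_signature(cfg, GNN))  # {'encoder': 'linear', 'num_layers': 3}
```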

chemtorch.utils.hydra.safe_instantiate(cfg, *args, **kwargs)[source]#

Safely instantiate an object from a configuration dictionary, ensuring that the configuration is ordered and filtered according to the target’s constructor signature.

Parameters:
  • cfg (dict or DictConfig) – The configuration dictionary or OmegaConf DictConfig.

  • *args – Positional arguments to pass to the target’s constructor.

  • **kwargs – Keyword arguments to pass to the target’s constructor.

Returns:

An instance of the target class specified in the configuration.

Return type:

object

chemtorch.utils.hydra.get_num_workers(slurm_env: str = 'SLURM_CPUS_PER_TASK', leave_free: int = 1) int[source]#

Return a safe number of PyTorch DataLoader workers.

Behavior:
  • If the SLURM env var (default: SLURM_CPUS_PER_TASK) is present and an integer, use min(slurm_value, os.cpu_count()) - leave_free.

  • Otherwise use os.cpu_count() - leave_free.

  • Never return a negative number; the result is clamped to 0.

leave_free defaults to 1 to leave one CPU for the main process. This is a common pattern but optional depending on your workload.
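The documented behavior can be sketched as follows (an illustration of the rules above, not the actual implementation):

```python
import os

def get_num_workers_sketch(slurm_env: str = "SLURM_CPUS_PER_TASK",
                           leave_free: int = 1) -> int:
    """Return a safe DataLoader worker count per the documented behavior."""
    cpu_count = os.cpu_count() or 1
    slurm_value = os.environ.get(slurm_env)
    if slurm_value is not None and slurm_value.isdigit():
        workers = min(int(slurm_value), cpu_count) - leave_free
    else:
        workers = cpu_count - leave_free
    return max(workers, 0)  # clamp: never return a negative number

os.environ["SLURM_CPUS_PER_TASK"] = "4"
print(get_num_workers_sketch())  # min(4, os.cpu_count()) - 1, clamped at 0
```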

CLI Utilities#

Standardizer#

class chemtorch.utils.standardizer.Standardizer(mean: float | Tensor | ndarray, std: float | Tensor | ndarray)[source]#

Bases: object

__init__(mean: float | Tensor | ndarray, std: float | Tensor | ndarray) None[source]#

Create a standardizer to standardize samples using the given mean and standard deviation.

Parameters:
  • mean (Union[float, torch.Tensor, np.ndarray]) – Mean value(s) for standardization.

  • std (Union[float, torch.Tensor, np.ndarray]) – Standard deviation value(s) for standardization.

Raises:
  • TypeError – If mean or std are not of type torch.Tensor or np.ndarray, or if they are not of the same type.

  • ValueError – If mean and std do not have the same shape.

standardize(x: Tensor) Tensor[source]#

Standardize the input data by subtracting the mean and dividing by the standard deviation.

Parameters:

x (torch.Tensor) – Input data to standardize.

Returns:

Standardized data.

Return type:

torch.Tensor

destandardize(x: Tensor) Tensor[source]#

Reverse the standardization of the input data by multiplying by the standard deviation and adding the mean.

Parameters:

x (torch.Tensor) – Input data to reverse standardize.

Returns:

Reverse standardized data.

Return type:

torch.Tensor

static validate(mean: float | Tensor | ndarray, std: float | Tensor | ndarray) None[source]#

Validate the mean and standard deviation values.

Parameters:
  • mean (Union[float, torch.Tensor, np.ndarray]) – Mean value(s) for standardization.

  • std (Union[float, torch.Tensor, np.ndarray]) – Standard deviation value(s) for standardization.

Raises:
  • TypeError – If mean or std are not of type torch.Tensor or np.ndarray, or if they are not of the same type.

  • ValueError – If mean and std do not have the same shape.
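A float-only sketch of the standardize/destandardize round trip (the real class also accepts tensors and arrays and validates their types and shapes; StandardizerSketch is an illustrative stand-in):

```python
# Minimal float-only stand-in showing the arithmetic of the round trip.
class StandardizerSketch:
    def __init__(self, mean: float, std: float) -> None:
        self.mean = mean
        self.std = std

    def standardize(self, x: float) -> float:
        # Subtract the mean, divide by the standard deviation.
        return (x - self.mean) / self.std

    def destandardize(self, x: float) -> float:
        # Multiply by the standard deviation, add the mean back.
        return x * self.std + self.mean

s = StandardizerSketch(mean=10.0, std=2.0)
z = s.standardize(14.0)
print(z)                   # 2.0
print(s.destandardize(z))  # 14.0
```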

Reaction Utilities#

chemtorch.utils.reaction_utils.get_atom_index_by_mapnum(mol: Mol, mapnum: int) int | None[source]#

Get the atom index for an atom with a specific atom map number.

Parameters:
  • mol (Chem.Mol) – The RDKit molecule object to search.

  • mapnum (int) – The atom map number to search for.

Returns:

The atom index if found, None otherwise.

Return type:

Optional[int]

chemtorch.utils.reaction_utils.unmap_smarts(smarts: str) str[source]#

Remove atom map numbers from a SMARTS string.

Parameters:

smarts (str) – The input SMARTS string with atom map numbers.

Returns:

The SMARTS string with atom map numbers removed.

Return type:

str

Raises:

ValueError – If the SMARTS string cannot be parsed.

chemtorch.utils.reaction_utils.unmap_smiles(smiles: str) str[source]#

Remove atom map numbers from a SMILES string.

Parameters:

smiles (str) – The input SMILES string with atom map numbers.

Returns:

The SMILES string with atom map numbers removed.

Return type:

str

Raises:

ValueError – If the SMILES string cannot be parsed.
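The underlying RDKit pattern looks roughly like this (an illustrative sketch, not the chemtorch implementation; unmap_smiles_sketch is a hypothetical name):

```python
from rdkit import Chem

def unmap_smiles_sketch(smiles: str) -> str:
    """Remove atom map numbers from a SMILES string (illustrative sketch)."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Cannot parse SMILES: {smiles!r}")
    for atom in mol.GetAtoms():
        atom.SetAtomMapNum(0)  # 0 clears the atom map number
    return Chem.MolToSmiles(mol)

print(unmap_smiles_sketch("[CH3:1][OH:2]"))  # 'CO'
```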

chemtorch.utils.reaction_utils.smarts2smarts(smarts: str) str[source]#

Parse and reformat a SMARTS string to ensure proper formatting.

Parameters:

smarts (str) – The input SMARTS string.

Returns:

The reformatted SMARTS string.

Return type:

str

Raises:

ValueError – If the SMARTS string cannot be parsed.

chemtorch.utils.reaction_utils.smiles2smiles(smiles_str: str) str[source]#

Parse and reformat a SMILES string to ensure proper formatting.

Parameters:

smiles_str (str) – The input SMILES string.

Returns:

The reformatted SMILES string.

Return type:

str

Raises:

ValueError – If the SMILES string cannot be parsed.

chemtorch.utils.reaction_utils.bondtypes(atom: Atom) List[BondType][source]#

Get a sorted list of bond types for all bonds connected to an atom.

Parameters:

atom (Chem.Atom) – The RDKit atom object.

Returns:

A sorted list of bond types connected to the atom.

Return type:

List[Chem.BondType]

chemtorch.utils.reaction_utils.neighbors(atom: Atom) List[int][source]#

Get a sorted list of atom map numbers for all neighboring atoms.

Parameters:

atom (Chem.Atom) – The RDKit atom object.

Returns:

A sorted list of atom map numbers of neighboring atoms.

Return type:

List[int]

chemtorch.utils.reaction_utils.neighbors_and_bondtypes(atom: Atom) List[int | BondType][source]#

Get a combined list of neighboring atom map numbers and bond types.

Parameters:

atom (Chem.Atom) – The RDKit atom object.

Returns:

A combined list containing neighbor map numbers and bond types.

Return type:

List[Union[int, Chem.BondType]]

chemtorch.utils.reaction_utils.remove_atoms_from_rxn(mr: Mol, mp: Mol, atoms_to_remove: ndarray) List[str][source]#

Remove specified atoms from reactant and product molecules and return SMILES.

Parameters:
  • mr (Chem.Mol) – The reactant molecule.

  • mp (Chem.Mol) – The product molecule.

  • atoms_to_remove (np.ndarray) – A 2D numpy array of shape (N, 2) where each row contains a pair of atom indices [reactant_idx, product_idx] representing corresponding atoms to be removed from the reactant and product molecules respectively. The indices are 0-based atom indices within each molecule.

Returns:

A list containing the SMILES strings of the modified reactant and product.

Return type:

List[str]

chemtorch.utils.reaction_utils.get_reaction_core(r_smiles: str, p_smiles: str) Tuple[str, List[int]][source]#

Extract the reaction core by removing atoms that don’t change during the reaction.

This function identifies atoms in the reactant and product that have identical neighborhoods (same neighbors and bond types) and removes them to extract only the reaction core - the atoms that actually participate in the transformation.

Parameters:
  • r_smiles (str) – The reactant SMILES string.

  • p_smiles (str) – The product SMILES string.

Returns:

A tuple containing:
  • The reaction SMILES with only the changing atoms (reaction core)

  • A list of atom map numbers that were removed (0-indexed)

Return type:

Tuple[str, List[int]]

Raises:

ValueError – If SMILES cannot be parsed or processed.

Atom Mapping#

class chemtorch.utils.atom_mapping.AtomOriginType(value)[source]#

Bases: IntEnum

An enumeration.

REACTANT = 0#
PRODUCT = 1#
DUMMY = 2#
REACTANT_PRODUCT = 3#
class chemtorch.utils.atom_mapping.EdgeOriginType(value)[source]#

Bases: IntEnum

An enumeration.

REACTANT = 0#
PRODUCT = 1#
DUMMY = 2#
REACTANT_PRODUCT = 3#
chemtorch.utils.atom_mapping.make_mol(smi: str) Tuple[Mol, List[int]][source]#

Create RDKit mol with atom mapping.

Parameters:

smi – SMILES string

Returns:

Tuple containing the RDKit molecule and a list of atom origins

chemtorch.utils.atom_mapping.map_reac_to_prod(mol_reac: Mol, mol_prod: Mol) Dict[int, int][source]#

Map reactant atom indices to product atom indices.

Parameters:
  • mol_reac – Reactant molecule

  • mol_prod – Product molecule

Returns:

Dictionary mapping reactant atom indices to product atom indices

chemtorch.utils.atom_mapping.remove_atom_mapping(smiles: str)[source]#

Callable Composition#

class chemtorch.utils.callable_compose.CallableCompose(callables: Sequence[Callable])[source]#

Bases: Generic[R, T]

CallableCompose composes a sequence of callables into a single callable object.

This class is generic over input type R and output type T, and allows you to chain multiple callables (functions, transforms, etc.) so that the output of each is passed as the input to the next. The composed object itself is callable and applies all callables in order.

Typical use cases include composing data transforms, preprocessing steps, or any sequence of operations that should be applied in a pipeline fashion.

Type parameters:

R: The input type to the first callable. T: The output type of the final callable.

Example

>>> def double(x: int) -> int:
...     return x * 2
>>> def stringify(x: int) -> str:
...     return str(x)
>>> composed = CallableCompose([double, stringify])
>>> composed(3)
'6'

You can also use the static method compose for a more concise syntax:

>>> composed = CallableCompose.compose(double, stringify)
>>> composed(4)
'8'
__init__(callables: Sequence[Callable]) None[source]#

Initializes the CallableCompose.

Parameters:

callables (Sequence[Callable]) – A sequence of callables to be composed. Each callable should accept the output type of the previous callable (the first should accept type R, the last should return type T).

static compose(*callables: Callable[[T], T]) CallableCompose[T, T][source]#

Convenience method to compose multiple callables of the same input/output type.

Parameters:

*callables – Callables to compose, each of type Callable[[T], T].

Returns:

A composed callable applying all in order.

Return type:

CallableCompose[T, T]
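The composition pattern itself can be sketched in a few lines (illustrative, not the chemtorch implementation):

```python
from functools import reduce
from typing import Callable, Sequence

# Minimal sketch of the composition pattern.
class ComposeSketch:
    def __init__(self, callables: Sequence[Callable]) -> None:
        self.callables = list(callables)

    def __call__(self, x):
        # Feed the output of each callable into the next, in order.
        return reduce(lambda acc, f: f(acc), self.callables, x)

composed = ComposeSketch([lambda x: x * 2, str])
print(composed(3))  # '6'
```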

Miscellaneous#

chemtorch.utils.misc.save_predictions(preds: List[Any], reference_df: DataFrame, save_path: str | Path | None, log_func: Callable[[...], Any] | None = None, root_dir: Path | None = None) DataFrame | None[source]#

Process predictions and save them to a dataframe with the original data.

Parameters:
  • preds – List of predictions from trainer.predict() or trainer.test() (can be tensors or other types). Each element in the list is typically a batch of predictions.

  • reference_df – Original dataframe to add predictions to

  • save_path – Path to save the predictions CSV file (relative to root_dir if provided)

  • log_func – Optional logging function (e.g., wandb.log)

  • root_dir – Root directory for resolving relative save_path

Returns:

DataFrame with predictions added, or None if processing failed

chemtorch.utils.misc.handle_prediction_saving(get_preds_func: Callable[[str], List[Any]], get_reference_df_func: Callable[[str], DataFrame], get_dataset_names_func: Callable[[], List[str]], predictions_save_dir: str | None = None, predictions_save_path: str | None = None, save_predictions_for: str | List[str] | ListConfig | None = None, tasks: List[str] | ListConfig | None = None, log_func: Callable | None = None, root_dir: Path | None = None) None[source]#

Handle prediction saving logic with configurable prediction and reference data retrieval.

Parameters:
  • get_preds_func – Function that takes a dataset_key and returns predictions

  • get_reference_df_func – Function that takes a dataset_key and returns reference dataframe

  • get_dataset_names_func – Function that returns all available dataset names/keys

  • predictions_save_dir – Directory to save predictions (for multiple partitions)

  • predictions_save_path – Specific path to save predictions (for single partition)

  • save_predictions_for – String or list of dataset keys to save predictions for

  • tasks – List of tasks being executed (used for determining available partitions)

  • log_func – Optional logging function (e.g., wandb.log)

  • root_dir – Root directory for relative path resolution

Decorators#

chemtorch.utils.decorators.enforce_base_init.enforce_base_init(base_cls: type) Callable[source]#

Decorator to enforce that subclasses of a given base class call the base class’s __init__ method. This is useful for ensuring that the base class’s initialization logic is always executed.

Parameters:

base_cls (type) – The base class whose __init__ method must be called.

Returns:

A decorator that enforces the base class’s __init__ method call.

Return type:

Callable

Raises:

RuntimeError – If the subclass’s __init__ method does not call the base class’s __init__ method.


Example

>>> class Base:
...     def __init__(self):
...         print("Base init called")
...         self._initialized_by_base = True
...
...     def __init_subclass__(cls):
...         enforce_base_init(Base)(cls)
...         return super().__init_subclass__()
...
>>> class GoodSubClass(Base):
...     def __init__(self):
...         super().__init__()  # Must call base class's __init__
...         print("SubClass init called")
...
>>> class BadSubClass(Base):
...     def __init__(self):
...         print("OtherSubClass init called")
...         # super().__init__()  # This will raise an error
...
>>> sub = GoodSubClass()  # This will work
Base init called
GoodSubClass init called
>>> other_sub = BadSubClass()  # This will raise an error
Traceback (most recent call last):
    ...
RuntimeError: BadSubClass.__init__() must call super().__init__() from Base
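
One way such a decorator can be implemented is to wrap the subclass's __init__ and check a flag set by the base class's __init__ (an illustrative sketch, not the chemtorch implementation; enforce_base_init_sketch and _base_init_called are hypothetical names):

```python
import functools

def enforce_base_init_sketch(base_cls):
    def decorator(sub_cls):
        orig_init = sub_cls.__init__

        @functools.wraps(orig_init)
        def wrapped(self, *args, **kwargs):
            orig_init(self, *args, **kwargs)
            # The base __init__ sets this flag; if it is missing after the
            # subclass __init__ ran, super().__init__() was never called.
            if not getattr(self, "_base_init_called", False):
                raise RuntimeError(
                    f"{sub_cls.__name__}.__init__() must call "
                    f"super().__init__() from {base_cls.__name__}"
                )
        sub_cls.__init__ = wrapped
        return sub_cls
    return decorator

class Base:
    def __init__(self):
        self._base_init_called = True

    def __init_subclass__(cls):
        enforce_base_init_sketch(Base)(cls)
        return super().__init_subclass__()

class Good(Base):
    def __init__(self):
        super().__init__()

class Bad(Base):
    def __init__(self):
        pass  # forgot super().__init__()

Good()  # fine
try:
    Bad()
except RuntimeError as e:
    print(e)  # Bad.__init__() must call super().__init__() from Base
```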