Utilities#
Helper functions and utilities used throughout ChemTorch.
Types#
- class chemtorch.utils.types.DataLoaderFactoryProtocol(*args, **kwargs)[source]#
Bases: Protocol
Protocol defining the interface for dataloader factory functions.
- __init__(*args, **kwargs)#
- class chemtorch.utils.types.RoutineFactoryProtocol(*args, **kwargs)[source]#
Bases: Protocol
Protocol defining the interface for routine factory functions.
- __init__(*args, **kwargs)#
Hydra Utilities#
- chemtorch.utils.hydra.resolve_target(target_str)[source]#
Resolve a string like ‘module.submodule.ClassName’ to the actual class.
- chemtorch.utils.hydra.order_config_by_signature(cfg)[source]#
Recursively reorder configuration dictionaries to match the argument order of their specified _target_ class or function’s constructor signature.
- Parameters:
cfg (dict or DictConfig) – The configuration dictionary or OmegaConf DictConfig.
- Returns:
- A reordered configuration dictionary with keys ordered according to the
constructor signature of the _target_ if present, otherwise recursively processes child dictionaries.
- Return type:
dict or DictConfig
Why is this important?#
By default, Hydra instantiates objects in the order that keys appear in the configuration dictionary. However, the order of keys in a YAML or Python dict is not guaranteed to match the order of arguments in the target class or function’s constructor. In frameworks like PyTorch, the order in which submodules (layers, blocks, etc.) are instantiated directly affects the order in which random numbers are consumed for parameter initialization. This means that simply reordering keys in the config file can lead to different random initializations and thus different results, even if the model architecture and random seed remain unchanged.
This function ensures that the instantiation order of all objects (and their subcomponents) is invariant to the order of keys in the configuration file. It does so by reordering each config dictionary to match the argument order of the corresponding _target_'s constructor. This guarantees reproducible model initialization and results, regardless of how the config is written.
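The core reordering idea can be sketched in plain Python with inspect.signature. This is a simplified illustration, not the actual ChemTorch implementation: order_by_signature and Linear are hypothetical names, and the real function additionally recurses into nested configs and resolves the _target_ string itself.

```python
import inspect

def order_by_signature(cfg: dict, target) -> dict:
    """Reorder cfg keys to match the parameter order of target's constructor."""
    params = list(inspect.signature(target).parameters)
    # Keys that match constructor arguments, in signature order:
    ordered = {k: cfg[k] for k in params if k in cfg}
    # Any remaining keys keep their original relative order:
    ordered.update({k: v for k, v in cfg.items() if k not in ordered})
    return ordered

class Linear:  # hypothetical target class
    def __init__(self, in_features, out_features, bias=True): ...

cfg = {"bias": False, "out_features": 8, "in_features": 4}
print(list(order_by_signature(cfg, Linear)))  # ['in_features', 'out_features', 'bias']
```

However the YAML author orders the keys, instantiation sees them in signature order, so parameter-initialization RNG draws happen in a fixed sequence.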
- chemtorch.utils.hydra.filter_config_by_signature(cfg)[source]#
Recursively filter configuration dictionaries to only include keys that match the argument names of their specified _target_ class or function’s constructor signature. If the signature supports **kwargs, do not filter. Special Hydra keys (from _Keys) are always preserved.
- Parameters:
cfg (dict or DictConfig) – The configuration dictionary or OmegaConf DictConfig.
- Returns:
A filtered configuration dictionary.
- Return type:
dict or DictConfig
Why is this important?#
This is useful for convenient key interpolation in configs. If multiple lower-level configs in the hierarchy depend on a shared value, it is convenient to define it in a higher-level config and use a universal interpolation syntax to reference that key, regardless of the lower-level config structure (which can differ). For example, the hidden_channels argument is shared across multiple components of a GNN, such as encoders, convolution layers, and higher-level blocks. However, the higher-level GNN class that assembles these components might not have a hidden_channels argument itself. If you specify a hidden_channels key in the higher-level config and use top-level interpolation, Hydra will raise a ValueError when it tries to instantiate the GNN with the config, since the argument is unexpected. The use case for this function is to resolve the config (so all interpolations are applied), then filter it so that only valid keys for each _target_ remain, avoiding instantiation errors and enabling flexible, DRY config design.
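The filtering logic can likewise be sketched with inspect.signature. Again a hypothetical, simplified stand-in (filter_by_signature and GNN are illustrative names): the real function recurses through the config tree, resolves _target_ strings, and always preserves Hydra's special keys.

```python
import inspect

def filter_by_signature(cfg: dict, target) -> dict:
    """Drop cfg keys that target's constructor does not accept."""
    sig = inspect.signature(target)
    # If the constructor accepts **kwargs, nothing can be safely dropped.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()):
        return dict(cfg)
    return {k: v for k, v in cfg.items() if k in sig.parameters}

class GNN:  # hypothetical target class with no hidden_channels argument
    def __init__(self, encoder, num_layers): ...

cfg = {"encoder": "gcn", "num_layers": 3, "hidden_channels": 64}  # shared key for interpolation
print(sorted(filter_by_signature(cfg, GNN)))  # ['encoder', 'num_layers']
```

The shared hidden_channels key can stay in the higher-level config for interpolation, yet never reaches a constructor that would reject it.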
- chemtorch.utils.hydra.safe_instantiate(cfg, *args, **kwargs)[source]#
Safely instantiate an object from a configuration dictionary, ensuring that the configuration is ordered and filtered according to the target’s constructor signature.
- Parameters:
cfg (dict or DictConfig) – The configuration dictionary or OmegaConf DictConfig.
*args – Positional arguments to pass to the target’s constructor.
**kwargs – Keyword arguments to pass to the target’s constructor.
- Returns:
An instance of the target class specified in the configuration.
- Return type:
Any
- chemtorch.utils.hydra.get_num_workers(slurm_env: str = 'SLURM_CPUS_PER_TASK', leave_free: int = 1) int[source]#
Return a safe number of PyTorch DataLoader workers.
Behavior:
- If the SLURM env var (default: SLURM_CPUS_PER_TASK) is present and an integer, use min(slurm_value, os.cpu_count()) - leave_free.
- Otherwise, use os.cpu_count() - leave_free.
- Never return a negative number; the result is clamped to 0.
leave_free defaults to 1 to leave one CPU for the main process. This is a common pattern but optional depending on your workload.
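The behavior described above can be sketched as follows (a hypothetical re-implementation named suggested_num_workers, for illustration only):

```python
import os

def suggested_num_workers(slurm_env: str = "SLURM_CPUS_PER_TASK", leave_free: int = 1) -> int:
    """Return a safe DataLoader worker count, per the documented behavior."""
    cpu_count = os.cpu_count() or 1
    value = os.environ.get(slurm_env)
    if value is not None and value.isdigit():
        available = min(int(value), cpu_count)  # respect the SLURM allocation
    else:
        available = cpu_count
    return max(available - leave_free, 0)  # clamp: never negative

os.environ["SLURM_CPUS_PER_TASK"] = "4"
print(suggested_num_workers())  # min(4, os.cpu_count()) - 1 on this machine
```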
CLI Utilities#
Standardizer#
- class chemtorch.utils.standardizer.Standardizer(mean: float | Tensor | ndarray, std: float | Tensor | ndarray)[source]#
Bases: object
- __init__(mean: float | Tensor | ndarray, std: float | Tensor | ndarray) None[source]#
Create a standardizer that standardizes samples using the given mean and standard deviation.
- Parameters:
mean (Union[float, torch.Tensor, np.ndarray]) – Mean value(s) for standardization.
std (Union[float, torch.Tensor, np.ndarray]) – Standard deviation value(s) for standardization.
- Raises:
TypeError – If mean or std are not of type torch.Tensor or np.ndarray, or if they are not of the same type.
ValueError – If mean and std do not have the same shape.
- standardize(x: Tensor) Tensor[source]#
Standardize the input data by subtracting the mean and dividing by the standard deviation.
- Parameters:
x (torch.Tensor) – Input data to standardize.
- Returns:
Standardized data.
- Return type:
torch.Tensor
- destandardize(x: Tensor) Tensor[source]#
Reverse the standardization of the input data by multiplying by the standard deviation and adding the mean.
- Parameters:
x (torch.Tensor) – Input data to reverse standardize.
- Returns:
Reverse standardized data.
- Return type:
torch.Tensor
- static validate(mean: float | Tensor | ndarray, std: float | Tensor | ndarray) None[source]#
Validate the mean and standard deviation values.
- Parameters:
mean (Union[float, torch.Tensor, np.ndarray]) – Mean value(s) for standardization.
std (Union[float, torch.Tensor, np.ndarray]) – Standard deviation value(s) for standardization.
- Raises:
TypeError – If mean or std are not of type torch.Tensor or np.ndarray, or if they are not of the same type.
ValueError – If mean and std do not have the same shape.
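The transformation itself is simple; as a scalar sketch (the class operates elementwise on tensors or arrays in exactly the same way):

```python
def standardize(x: float, mean: float, std: float) -> float:
    """z = (x - mean) / std"""
    return (x - mean) / std

def destandardize(z: float, mean: float, std: float) -> float:
    """Inverse transform: x = z * std + mean"""
    return z * std + mean

z = standardize(7.0, mean=5.0, std=2.0)
print(z)                                    # 1.0
print(destandardize(z, mean=5.0, std=2.0))  # 7.0 -- round-trips exactly
```

In practice, mean and std are computed on the training targets, and predictions are destandardized before reporting metrics in the original units.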
Reaction Utilities#
- chemtorch.utils.reaction_utils.get_atom_index_by_mapnum(mol: Mol, mapnum: int) int | None[source]#
Get the atom index for an atom with a specific atom map number.
- chemtorch.utils.reaction_utils.unmap_smarts(smarts: str) str[source]#
Remove atom map numbers from a SMARTS string.
- Parameters:
smarts (str) – The input SMARTS string with atom map numbers.
- Returns:
The SMARTS string with atom map numbers removed.
- Return type:
str
- Raises:
ValueError – If the SMARTS string cannot be parsed.
- chemtorch.utils.reaction_utils.unmap_smiles(smiles: str) str[source]#
Remove atom map numbers from a SMILES string.
- Parameters:
smiles (str) – The input SMILES string with atom map numbers.
- Returns:
The SMILES string with atom map numbers removed.
- Return type:
str
- Raises:
ValueError – If the SMILES string cannot be parsed.
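The robust way to unmap is what this function does: parse with RDKit, clear each atom's map number (atom.SetAtomMapNum(0)), and re-serialize, which also re-canonicalizes the string. Purely for illustration, the textual effect on bracket atoms can be sketched with a regex (a naive stand-in, not the actual implementation, and not a substitute for parsing):

```python
import re

def unmap_smiles_textual(smiles: str) -> str:
    """Strip ':<mapnum>' inside bracket atoms, e.g. [CH3:1] -> [CH3].

    Toy sketch only: it does not validate the SMILES or canonicalize it,
    which the RDKit-based unmap_smiles does.
    """
    return re.sub(r":\d+(?=\])", "", smiles)

print(unmap_smiles_textual("[CH3:1][OH:2]"))  # [CH3][OH]
print(unmap_smiles_textual("CCO"))            # CCO (unmapped input unchanged)
```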
- chemtorch.utils.reaction_utils.smarts2smarts(smarts: str) str[source]#
Parse and reformat a SMARTS string to ensure proper formatting.
- Parameters:
smarts (str) – The input SMARTS string.
- Returns:
The reformatted SMARTS string.
- Return type:
str
- Raises:
ValueError – If the SMARTS string cannot be parsed.
- chemtorch.utils.reaction_utils.smiles2smiles(smiles_str: str) str[source]#
Parse and reformat a SMILES string to ensure proper formatting.
- Parameters:
smiles_str (str) – The input SMILES string.
- Returns:
The reformatted SMILES string.
- Return type:
str
- Raises:
ValueError – If the SMILES string cannot be parsed.
- chemtorch.utils.reaction_utils.bondtypes(atom: Atom) List[BondType][source]#
Get a sorted list of bond types for all bonds connected to an atom.
- Parameters:
atom (Chem.Atom) – The RDKit atom object.
- Returns:
A sorted list of bond types connected to the atom.
- Return type:
List[Chem.BondType]
- chemtorch.utils.reaction_utils.neighbors(atom: Atom) List[int][source]#
Get a sorted list of atom map numbers for all neighboring atoms.
- Parameters:
atom (Chem.Atom) – The RDKit atom object.
- Returns:
A sorted list of atom map numbers of neighboring atoms.
- Return type:
List[int]
- chemtorch.utils.reaction_utils.neighbors_and_bondtypes(atom: Atom) List[int | BondType][source]#
Get a combined list of neighboring atom map numbers and bond types.
- Parameters:
atom (Chem.Atom) – The RDKit atom object.
- Returns:
A combined list containing neighbor map numbers and bond types.
- Return type:
List[Union[int, Chem.BondType]]
- chemtorch.utils.reaction_utils.remove_atoms_from_rxn(mr: Mol, mp: Mol, atoms_to_remove: ndarray) List[str][source]#
Remove specified atoms from reactant and product molecules and return SMILES.
- Parameters:
mr (Chem.Mol) – The reactant molecule.
mp (Chem.Mol) – The product molecule.
atoms_to_remove (np.ndarray) – A 2D numpy array of shape (N, 2) where each row contains a pair of atom indices [reactant_idx, product_idx] representing corresponding atoms to be removed from the reactant and product molecules respectively. The indices are 0-based atom indices within each molecule.
- Returns:
A list containing the SMILES strings of the modified reactant and product.
- Return type:
List[str]
- chemtorch.utils.reaction_utils.get_reaction_core(r_smiles: str, p_smiles: str) Tuple[str, List[int]][source]#
Extract the reaction core by removing atoms that don’t change during the reaction.
This function identifies atoms in the reactant and product that have identical neighborhoods (same neighbors and bond types) and removes them to extract only the reaction core - the atoms that actually participate in the transformation.
- Parameters:
r_smiles (str) – The reactant SMILES string.
p_smiles (str) – The product SMILES string.
- Returns:
- A tuple containing:
The reaction SMILES with only the changing atoms (reaction core)
A list of atom map numbers that were removed (0-indexed)
- Return type:
Tuple[str, List[int]]
- Raises:
ValueError – If SMILES cannot be parsed or processed.
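The core comparison can be illustrated on plain dictionaries. In this toy sketch (hypothetical names; the real function builds these neighborhoods from RDKit atoms via neighbors_and_bondtypes), each atom-map number points to its sorted list of (neighbor map number, bond type) pairs, and an atom is "unchanged" when that list is identical on both sides of the reaction:

```python
def unchanged_atoms(react_nbrs: dict, prod_nbrs: dict) -> set:
    """Map numbers of atoms whose neighborhood is identical in reactant and product."""
    return {m for m in react_nbrs
            if m in prod_nbrs and react_nbrs[m] == prod_nbrs[m]}

# Toy example: atom 3 loses a neighbor during the reaction; atoms 1 and 2 do not change.
reactant = {1: [(2, "SINGLE")],
            2: [(1, "SINGLE"), (3, "SINGLE")],
            3: [(2, "SINGLE"), (4, "SINGLE")]}
product  = {1: [(2, "SINGLE")],
            2: [(1, "SINGLE"), (3, "SINGLE")],
            3: [(2, "SINGLE")]}
print(sorted(unchanged_atoms(reactant, product)))  # [1, 2]
```

Atoms found unchanged are the candidates for removal; what remains is the reaction core.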
Atom Mapping#
- class chemtorch.utils.atom_mapping.AtomOriginType(value)[source]#
Bases: IntEnum
An enumeration.
- REACTANT = 0#
- PRODUCT = 1#
- DUMMY = 2#
- REACTANT_PRODUCT = 3#
- class chemtorch.utils.atom_mapping.EdgeOriginType(value)[source]#
Bases: IntEnum
An enumeration.
- REACTANT = 0#
- PRODUCT = 1#
- DUMMY = 2#
- REACTANT_PRODUCT = 3#
- chemtorch.utils.atom_mapping.make_mol(smi: str) Tuple[Mol, List[int]][source]#
Create RDKit mol with atom mapping.
- Parameters:
smi – SMILES string
- Returns:
Tuple containing the RDKit molecule and a list of atom origins
Callable Composition#
- class chemtorch.utils.callable_compose.CallableCompose(callables: Sequence[Callable])[source]#
Bases: Generic[R, T]
CallableCompose composes a sequence of callables into a single callable object.
This class is generic over input type R and output type T, and allows you to chain multiple callables (functions, transforms, etc.) so that the output of each is passed as the input to the next. The composed object itself is callable and applies all callables in order.
Typical use cases include composing data transforms, preprocessing steps, or any sequence of operations that should be applied in a pipeline fashion.
- Type parameters:
R: The input type to the first callable.
T: The output type of the final callable.
Example
>>> def double(x: int) -> int:
...     return x * 2
>>> def stringify(x: int) -> str:
...     return str(x)
>>> composed = CallableCompose([double, stringify])
>>> composed(3)
'6'
You can also use the static method compose for a more concise syntax:
>>> composed = CallableCompose.compose(double, stringify)
>>> composed(4)
'8'
- __init__(callables: Sequence[Callable]) None[source]#
Initializes the CallableCompose.
- Parameters:
callables (Sequence[Callable]) – A sequence of callables to be composed. Each callable should accept the output type of the previous callable (the first should accept type R, the last should return type T).
- static compose(*callables: Callable[[T], T]) CallableCompose[T, T][source]#
Convenience method to compose multiple callables of the same input/output type.
- Parameters:
*callables – Callables to compose, each of type Callable[[T], T].
- Returns:
A composed callable applying all in order.
- Return type:
CallableCompose[T, T]
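The pattern itself is a few lines; a minimal self-contained sketch (ComposeSketch is a hypothetical stand-in, without CallableCompose's generic typing or its compose helper):

```python
from typing import Callable, Sequence

class ComposeSketch:
    """Chain callables so each one's output feeds the next."""
    def __init__(self, callables: Sequence[Callable]):
        self.callables = list(callables)

    def __call__(self, x):
        for fn in self.callables:  # apply in order: output of each is input to the next
            x = fn(x)
        return x

composed = ComposeSketch([lambda x: x * 2, str])
print(composed(3))  # '6'
```

This is the same shape as torchvision-style transform composition: a pipeline object that is itself callable.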
Miscellaneous#
- chemtorch.utils.misc.save_predictions(preds: List[Any], reference_df: DataFrame, save_path: str | Path | None, log_func: Callable[[...], Any] | None = None, root_dir: Path | None = None) DataFrame | None[source]#
Process predictions and save them to a dataframe with the original data.
- Parameters:
preds – List of predictions from trainer.predict() or trainer.test() (can be tensors or other types). Each element in the list is typically a batch of predictions.
reference_df – Original dataframe to add predictions to
save_path – Path to save the predictions CSV file (relative to root_dir if provided)
log_func – Optional logging function (e.g., wandb.log)
root_dir – Root directory for resolving relative save_path
- Returns:
DataFrame with predictions added, or None if processing failed
- chemtorch.utils.misc.handle_prediction_saving(get_preds_func: Callable[[str], List[Any]], get_reference_df_func: Callable[[str], DataFrame], get_dataset_names_func: Callable[[], List[str]], predictions_save_dir: str | None = None, predictions_save_path: str | None = None, save_predictions_for: str | List[str] | ListConfig | None = None, tasks: List[str] | ListConfig | None = None, log_func: Callable | None = None, root_dir: Path | None = None) None[source]#
Handle prediction saving logic with configurable prediction and reference data retrieval.
- Parameters:
get_preds_func – Function that takes a dataset_key and returns predictions
get_reference_df_func – Function that takes a dataset_key and returns reference dataframe
get_dataset_names_func – Function that returns all available dataset names/keys
predictions_save_dir – Directory to save predictions (for multiple partitions)
predictions_save_path – Specific path to save predictions (for single partition)
save_predictions_for – String or list of dataset keys to save predictions for
tasks – List of tasks being executed (used for determining available partitions)
log_func – Optional logging function (e.g., wandb.log)
root_dir – Root directory for relative path resolution
Decorators#
- chemtorch.utils.decorators.enforce_base_init.enforce_base_init(base_cls: type) Callable[source]#
Decorator to enforce that subclasses of a given base class call the base class’s __init__ method. This is useful for ensuring that the base class’s initialization logic is always executed.
- Parameters:
base_cls (type) – The base class whose __init__ method must be called.
- Returns:
A decorator that enforces the base class’s __init__ method call.
- Return type:
Callable
- Raises:
RuntimeError – If the subclass’s __init__ method does not call the base class’s __init__ method.
Example

>>> class Base:
...     def __init__(self):
...         print("Base init called")
...         self._initialized_by_base = True
...
...     def __init_subclass__(cls):
...         enforce_base_init(Base)(cls)
...         return super().__init_subclass__()
...
>>> class GoodSubClass(Base):
...     def __init__(self):
...         super().__init__()  # Must call base class's __init__
...         print("GoodSubClass init called")
...
>>> class BadSubClass(Base):
...     def __init__(self):
...         print("BadSubClass init called")
...         # super().__init__() is missing -- this will raise an error
...
>>> sub = GoodSubClass()  # This will work
Base init called
GoodSubClass init called
>>> other_sub = BadSubClass()  # This will raise an error
Traceback (most recent call last):
...
RuntimeError: BadSubClass.__init__() must call super().__init__() from Base
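One way such a decorator can be built (a hypothetical sketch, not necessarily how ChemTorch implements it) is to wrap the subclass's __init__ and verify afterwards that a flag set by the base __init__ is present:

```python
import functools

def enforce_base_init_sketch(base_cls):
    """Sketch: wrap __init__ and check that the base initializer ran."""
    def decorator(cls):
        original_init = cls.__init__

        @functools.wraps(original_init)
        def checked_init(self, *args, **kwargs):
            original_init(self, *args, **kwargs)
            # The base __init__ sets this flag; its absence means it was skipped.
            if not getattr(self, "_initialized_by_base", False):
                raise RuntimeError(
                    f"{cls.__name__}.__init__() must call super().__init__() "
                    f"from {base_cls.__name__}"
                )
        cls.__init__ = checked_init
        return cls
    return decorator

class Base:
    def __init__(self):
        self._initialized_by_base = True

    def __init_subclass__(cls):
        enforce_base_init_sketch(Base)(cls)

class Good(Base):
    def __init__(self):
        super().__init__()

class Bad(Base):
    def __init__(self):
        pass  # forgot super().__init__()

Good()  # fine
try:
    Bad()
except RuntimeError as exc:
    print("caught:", exc)
```

A flag-based check like this only detects the violation at instantiation time, not at class definition time; that trade-off keeps the sketch simple.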