Core#

The core package contains PyTorch Lightning modules that orchestrate the training loop.

Data Module#

class chemtorch.core.data_module.DataModule(data_pipeline: Callable[[...], DataSplit | DataFrame], representation: AbstractRepresentation, dataloader_factory: DataLoaderFactoryProtocol, transform: Any | Callable[[Any], Any] | Dict[Literal['train', 'val', 'test', 'predict'], Any | Callable[[Any], Any]] | Dict[Literal['train', 'val', 'test', 'predict'], List[Any | Callable[[Any], Any]] | Dict[str, Any | Callable[[Any], Any]]] | None = None, augmentations: List[Any | Callable[[Any, Tensor | None], List[tuple[Any, Tensor | None]]]] | Dict[str, Any | Callable[[Any, Tensor | None], List[tuple[Any, Tensor | None]]]] | None = None, subsample: int | float | Dict[Literal['train', 'val', 'test', 'predict'], int | float] | None = None, precompute_all: bool = True, cache: bool = False, max_cache_size: int | None = None, integer_labels: bool = False)[source]#

Bases: LightningDataModule

__init__(data_pipeline: Callable[[...], DataSplit | DataFrame], representation: AbstractRepresentation, dataloader_factory: DataLoaderFactoryProtocol, transform: Any | Callable[[Any], Any] | Dict[Literal['train', 'val', 'test', 'predict'], Any | Callable[[Any], Any]] | Dict[Literal['train', 'val', 'test', 'predict'], List[Any | Callable[[Any], Any]] | Dict[str, Any | Callable[[Any], Any]]] | None = None, augmentations: List[Any | Callable[[Any, Tensor | None], List[tuple[Any, Tensor | None]]]] | Dict[str, Any | Callable[[Any, Tensor | None], List[tuple[Any, Tensor | None]]]] | None = None, subsample: int | float | Dict[Literal['train', 'val', 'test', 'predict'], int | float] | None = None, precompute_all: bool = True, cache: bool = False, max_cache_size: int | None = None, integer_labels: bool = False) None[source]#

Initialize the DataModule with a data pipeline, dataset factory, and dataloader factory.

Parameters:
  • data_pipeline (Callable) – A callable that returns a DataSplit or a pandas DataFrame.

  • representation (AbstractRepresentation) – An instance of a representation class to convert raw data into model-ready format.

  • dataloader_factory (DataLoaderFactoryProtocol) – A DataLoader class or factory function to create DataLoader instances that work with the data format returned by the representation. Typically, this will be a partially initialized object that subclasses the torch.utils.data.DataLoader class.

  • transform (Optional[Union[TransformType, Dict[DatasetKey, TransformType], Dict[DatasetKey, Union[List[TransformType], Dict[str, TransformType]]]]]) – An optional transform or a dictionary of transforms for each stage (‘train’, ‘val’, ‘test’, ‘predict’). If a single transform is provided, it will be applied to all stages. If a dictionary is provided, it should map each dataset key to its corresponding transform. For the test set, a list or dict of transforms can be provided to create multiple test datasets. If a dict is provided, the keys should be the names of the test datasets.

  • augmentations (Optional[Union[List[AbstractAugmentation], Dict[str, AbstractAugmentation]]]) – An optional list or dictionary of augmentations to be applied to the training dataset.

  • subsample (Optional[Union[int, float, Dict[DatasetKey, Union[int, float]]]]) – An optional integer or float to subsample the datasets. If a float is provided, it should be between 0 and 1 and represents the fraction of the dataset to keep. If an integer is provided, it represents the exact number of samples to keep. If a dictionary is provided, it should map each dataset key to its corresponding subsample fraction or count.

  • precompute_all (bool) – If True, precompute all samples of the dataset. Default is True.

  • cache (bool) – If True, enable caching of dataset samples. Default is False.

  • max_cache_size (Optional[int]) – Maximum number of samples to cache. If None, cache size is unlimited. Default is None.

  • integer_labels (bool) – If True, convert labels to integers. Default is False.

Raises:

TypeError – If the output of the data pipeline is not a DataSplit or a pandas DataFrame.
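The subsample resolution described above (int = exact count, float = fraction, dict = per-split spec) can be sketched as follows. This is a minimal illustration of the documented semantics; the function name and error message are assumptions, not chemtorch's actual implementation.

```python
def resolve_subsample(spec, split, n_rows):
    """Return the number of rows to keep for `split` out of `n_rows`."""
    if spec is None:
        return n_rows
    if isinstance(spec, dict):              # per-split mapping, e.g. {"train": 0.1}
        spec = spec.get(split)
        if spec is None:
            return n_rows
    if isinstance(spec, float):             # fraction of the dataset to keep
        if not 0.0 < spec <= 1.0:
            raise ValueError("subsample fraction must be in (0, 1]")
        return max(1, int(n_rows * spec))
    return min(int(spec), n_rows)           # exact number of samples to keep
```

A dict spec leaves splits it does not mention untouched, so e.g. `{"train": 0.1}` subsamples only the training split.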

get_dataset(key: str) DatasetBase[source]#

Retrieve the dataset for the specified key.

Parameters:

key (str) – The key for which to retrieve the dataset (‘train’, ‘val’, ‘test’, ‘predict’, ‘test_<name>’).

Returns:

The dataset corresponding to the specified key.

Return type:

DatasetBase

Raises:

ValueError – If the dataset for the specified key is not initialized.

get_dataset_names() List[str][source]#

Get all available dataset names/keys.

Returns:

List of all dataset keys that can be used with get_dataset() and make_dataloader().

Return type:

List[str]

make_dataloader(key: str) DataLoader[source]#

Create a dataloader for the specified key.

Parameters:

key (str) – The key for which to create the dataloader (‘train’, ‘val’, ‘test’, ‘predict’, ‘test_<name>’).

Returns:

The created dataloader for the specified dataset key.

Return type:

DataLoader

Raises:

ValueError – If the dataset for the specified key is not initialized.

maybe_get_test_dataloader_idx_to_suffix() Dict[int, str] | None[source]#

If multiple named test datasets are initialized by passing a dict with named test set transforms, return a mapping from dataloader index to suffix. Otherwise, return None.

Returns:

A mapping from dataloader index to suffix, or None. Index 0 is always the main “test” dataset (no suffix). Indices 1+ correspond to additional test datasets sorted alphabetically.

Return type:

Dict[int, str]

train_dataloader()[source]#

An iterable or collection of iterables specifying training samples.

For more information about multiple dataloaders, see the Lightning documentation on multiple dataloaders.

The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

do not assign state in prepare_data(); it runs in a single process, so state assigned there is not available to the other processes.

These hooks are called in the following order:

  • fit()

  • prepare_data()

  • setup()

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

val_dataloader()[source]#

An iterable or collection of iterables specifying validation samples.

For more information about multiple dataloaders, see the Lightning documentation on multiple dataloaders.

The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data().

These hooks are called in the following order:

  • fit()

  • validate()

  • prepare_data()

  • setup()

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.

test_dataloader()[source]#

An iterable or collection of iterables specifying test samples.

For more information about multiple dataloaders, see the Lightning documentation on multiple dataloaders.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

do not assign state in prepare_data(); it runs in a single process, so state assigned there is not available to the other processes.

These hooks are called in the following order:

  • test()

  • prepare_data()

  • setup()

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a test dataset and a test_step(), you don’t need to implement this method.

predict_dataloader()[source]#

An iterable or collection of iterables specifying prediction samples.

For more information about multiple dataloaders, see the Lightning documentation on multiple dataloaders.

It’s recommended that all data downloads and preparation happen in prepare_data().

These hooks are called in the following order:

  • predict()

  • prepare_data()

  • setup()

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Returns:

A torch.utils.data.DataLoader or a sequence of them specifying prediction samples.

Dataset#

class chemtorch.core.dataset_base.DatasetBase(dataframe: DataFrame, representation: R, name: str | None = None, transform: AbstractTransform[T] | Callable[[T], T] | None = None, augmentation_list: List[Any | Callable[[Any, Tensor | None], List[tuple[Any, Tensor | None]]]] | None = None, subsample: int | float | None = None, precompute_all: bool = True, cache: bool = False, max_cache_size: int | None = None, integer_labels: bool = False)[source]#

Bases: Generic[T, R]

Base class for ChemTorch datasets with type-safe representations.

Type parameters:

  • T: The data type produced by the representation and returned by the dataset

  • R: The representation type (bounded by AbstractRepresentation)

Type safety features:
  • The dataset returns objects of type T (or Tuple[T, torch.Tensor] with labels)

  • The representation is of type R (bounded by AbstractRepresentation)

  • Transforms must be compatible with type T

  • Static type checker can verify basic type relationships

Usage examples:

# Graph dataset
DatasetBase[Data, GraphRepresentation]

# Token dataset
DatasetBase[torch.Tensor, TokenRepresentation]

# Fingerprint dataset
DatasetBase[torch.Tensor, FingerprintRepresentation]

The dataset can handle both labeled and unlabeled data. If the input DataFrame contains a ‘label’ column, the dataset will return tuples of (data_object, label). Otherwise, it will return only the data objects.

Warning: If the subclass inherits from multiple classes, ensure that DatasetBase is the first class in the inheritance list to ensure correct method resolution order (MRO).

Raises:

RuntimeError – If the subclass does not call super().__init__() in its __init__() method.
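The labeled/unlabeled contract described above can be illustrated with a minimal stand-in, where plain dicts play the role of DataFrame rows and the representation receives the row's fields as keyword arguments. All names here are hypothetical; this is not chemtorch's actual implementation.

```python
class TinyDataset:
    """Returns (data, label) tuples when a 'label' field exists, else data only."""

    def __init__(self, rows, representation):
        self.rows = rows
        self.representation = representation
        self.has_labels = all("label" in row for row in rows)

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        row = dict(self.rows[idx])
        label = row.pop("label", None)
        # the representation receives the row's fields as keyword arguments
        data = self.representation(**row)
        return (data, label) if self.has_labels else data
```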

__init__(dataframe: DataFrame, representation: R, name: str | None = None, transform: AbstractTransform[T] | Callable[[T], T] | None = None, augmentation_list: List[Any | Callable[[Any, Tensor | None], List[tuple[Any, Tensor | None]]]] | None = None, subsample: int | float | None = None, precompute_all: bool = True, cache: bool = False, max_cache_size: int | None = None, integer_labels: bool = False)[source]#

Initialize the DatasetBase.

Parameters:
  • dataframe (pd.DataFrame) – The input data as a pandas DataFrame. Each row represents a single sample. If the dataset contains a label column, it will be returned alongside the computed representation. Otherwise, only the representation will be returned.

  • representation (R) – A representation instance that constructs the data object consumed by the model. Must take in the fields of a single sample from the dataframe (row) as keyword arguments and return an object of type T.

  • name (Optional[str]) – An optional name for the dataset to be used in logging (<name> dataset). Default is None.

  • transform (Optional[AbstractTransform[T] | Callable[[T], T]]) – An optional transformation function or a composition thereof (Compose) that takes in an object of type T and returns a (possibly modified) object of the same type.

  • augmentation_list (Optional[List[AbstractAugmentation[T] | Callable[[T, torch.Tensor], List[Tuple[T, torch.Tensor]]]]]) – An optional list of data augmentation functions. Each takes an object of type T together with its label tensor and returns a list of (object, label) tuples (e.g., the original sample plus its augmented variants). Note: Augmentations are only applied to the training partition, and only if precompute_all is True.

  • subsample (Optional[int | float]) – The subsample size or fraction. If None, no subsampling is done. If an int, it specifies the number of samples to take. If a float, it specifies the fraction of samples to take. Default is None.

  • precompute_all (bool) – If True, precompute all samples in the dataset. Default is True.

  • cache (bool) – If True, cache the processed samples. Default is False.

  • max_cache_size (Optional[int]) – Maximum size of the cache. Default is None.

  • integer_labels (bool) – Whether to use integer labels (for classification) or float labels (for regression). If True, labels will be torch.int64. If False, labels will be torch.float. Default is False.

Raises:
  • ValueError – If the dataframe is not a pandas DataFrame.

  • ValueError – If the representation is not an AbstractRepresentation instance.

  • ValueError – If the transform is not a TransformBase, a callable, or None.
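The cache and max_cache_size options amount to an LRU cache over sample indices. A hedged sketch of that behavior, using only the standard library (the function name is illustrative, not chemtorch's actual implementation):

```python
from collections import OrderedDict

def make_cached_getter(compute, max_cache_size=None):
    """Wrap `compute(idx)` with an LRU cache of at most `max_cache_size` entries."""
    cache = OrderedDict()

    def get(idx):
        if idx in cache:
            cache.move_to_end(idx)          # mark as recently used
            return cache[idx]
        value = compute(idx)
        cache[idx] = value
        if max_cache_size is not None and len(cache) > max_cache_size:
            cache.popitem(last=False)       # evict the least-recently-used entry
        return value

    return get
```

With max_cache_size=None the cache grows without bound, matching the documented default.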

get_labels()[source]#

Retrieve the labels for the dataset.

Returns:

The labels for the dataset if they exist.

Return type:

pd.Series

Raises:

RuntimeError – If the dataset does not contain labels.

Routine#

Supervised Routine#

class chemtorch.core.routine.supervised_routine.SupervisedRoutine(model: Module, loss: Callable | None = None, optimizer: Callable[[Iterator[Parameter]], Optimizer] | None = None, lr_scheduler: Callable[[Optimizer], LRScheduler] | Dict[str, Any] | None = None, ckpt_path: str | None = None, resume_training: bool = False, metrics: Metric | MetricCollection | Dict[str, Metric | MetricCollection] | None = None, test_dataloader_idx_to_suffix: Dict[int, str] | None = None)[source]#

Bases: LightningModule

A flexible LightningModule wrapper for supervised tasks, supporting both training and inference.

This class can be used for:
  • Full training/validation/testing with loss, optimizer, scheduler, and metrics.

  • Inference-only (prediction), requiring only the model.

Example usage:

>>> # Training usage
>>> routine = SupervisedRoutine(
...     model=my_model,
...     loss=my_loss_fn,
...     optimizer=lambda params: torch.optim.Adam(params, lr=1e-3),
...     lr_scheduler=lambda opt: torch.optim.lr_scheduler.StepLR(opt, step_size=10),
...     metrics=my_metrics,
... )
>>> trainer = pl.Trainer(...)
>>> trainer.fit(routine, datamodule=my_datamodule)
>>> # Inference-only usage
>>> routine = SupervisedRoutine(model=my_model)
>>> preds = routine(torch.randn(8, 16))  # Forward pass for prediction
__init__(model: Module, loss: Callable | None = None, optimizer: Callable[[Iterator[Parameter]], Optimizer] | None = None, lr_scheduler: Callable[[Optimizer], LRScheduler] | Dict[str, Any] | None = None, ckpt_path: str | None = None, resume_training: bool = False, metrics: Metric | MetricCollection | Dict[str, Metric | MetricCollection] | None = None, test_dataloader_idx_to_suffix: Dict[int, str] | None = None)[source]#

Initialize the SupervisedRoutine.

Parameters:
  • model (nn.Module) – The model to be trained or used for inference.

  • loss (Callable, optional) – The loss function to be used. Required for training/validation/testing.

  • optimizer (Callable, optional) –

    A factory function that takes in the model’s parameters and returns an optimizer instance. Required for training/validation/testing.

    Example

    optimizer=lambda params: torch.optim.Adam(params, lr=1e-3)

  • lr_scheduler (Callable or Dict, optional) –

    Either a factory function that takes in the optimizer and returns a learning rate scheduler instance, or a Lightning config dictionary containing a “scheduler” key with the partially instantiated scheduler factory and optional Lightning-specific keys. Only needed for training.

    Examples

    # Factory function approach
    lr_scheduler=lambda opt: torch.optim.lr_scheduler.StepLR(opt, step_size=10)

    # Lightning config dictionary approach
    lr_scheduler={
        # e.g., functools.partial(torch.optim.lr_scheduler.StepLR, step_size=10)
        "scheduler": partial_scheduler_factory,
        # Lightning-specific keys (optional):
        "interval": "epoch",
        "frequency": 1,
        "monitor": "val_loss",
    }

  • ckpt_path (str, optional) – Path to a pre-trained model checkpoint.

  • resume_training (bool, optional) – Whether to resume training from a checkpoint.

  • metrics (Metric, MetricCollection or Dict[str, Metric/MetricCollection], optional) –

    Metrics to use for evaluation.

      • If a single Metric is provided, it will be cloned for the ‘train’, ‘val’, and ‘test’ stages.

      • If a single MetricCollection is provided, it will be cloned for the ‘train’, ‘val’, and ‘test’ stages.

      • If a dictionary is provided, it must map the keys ‘train’, ‘val’, and/or ‘test’ to Metric or MetricCollection instances. This allows you to specify different metrics for each stage.

    In all cases, the metrics will be registered as attributes of the LightningModule for proper logging.

    Example usage:
    >>> from torchmetrics import MetricCollection, MeanAbsoluteError, MeanSquaredError
    ...
    >>> # Single Metric for all stages
    >>> metric = MeanAbsoluteError()
    >>> routine = SupervisedRoutine(
    ...     model=my_model,
    ...     loss=my_loss_fn,
    ...     optimizer=lambda params: torch.optim.Adam(params, lr=1e-3),
    ...     metrics=metric,
    ... )
    >>> # Single MetricCollection for all stages
    >>> metrics = MetricCollection({
    ...     "mae": MeanAbsoluteError(),
    ...     "rmse": MeanSquaredError(squared=False),
    ... })
    >>> routine = SupervisedRoutine(
    ...     model=my_model,
    ...     loss=my_loss_fn,
    ...     optimizer=lambda params: torch.optim.Adam(params, lr=1e-3),
    ...     metrics=metrics,
    ... )
    >>> # Distinct metrics for each stage (mix of single metrics and collections)
    >>> metrics_dict = {
    ...     "train": MeanAbsoluteError(),  # Single metric
    ...     "val": MetricCollection({"rmse": MeanSquaredError(squared=False)}),  # Collection
    ...     "test": MeanSquaredError(),  # Single metric
    ... }
    >>> routine = SupervisedRoutine(
    ...     model=my_model,
    ...     loss=my_loss_fn,
    ...     optimizer=lambda params: torch.optim.Adam(params, lr=1e-3),
    ...     metrics=metrics_dict,
    ... )
    

  • test_dataloader_idx_to_suffix (Dict[int, str], optional) – A mapping from test dataloader index to its suffix, used to suffix test metrics as test_<metric_name>/<suffix>. If no suffix mapping is provided, the dataloader index is used as the suffix. The keys should be integers corresponding to dataloader indices (1, 2, …). No suffix mapping is needed for the first test dataloader (index 0), since it is assumed to be the main test set and is not suffixed.

Raises:
  • TypeError – If metrics is not a Metric, MetricCollection, or a dictionary of Metrics/MetricCollections.

  • ValueError – If metrics is a dictionary, but its keys are not ‘train’, ‘val’, or ‘test’, or if the keys are not unique.
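The metrics handling described above (a single metric or collection is copied per stage; a dict is validated against the allowed stage keys) can be sketched as follows. The function name is illustrative, and deepcopy stands in for torchmetrics' Metric.clone(); this is not the actual implementation.

```python
import copy

def normalize_metrics(metrics):
    """Expand a single metric into per-stage copies, or validate a per-stage dict."""
    stages = ("train", "val", "test")
    if metrics is None:
        return {}
    if isinstance(metrics, dict):
        unexpected = set(metrics) - set(stages)
        if unexpected:
            raise ValueError(f"metrics keys must be in {stages}, got {sorted(unexpected)}")
        return dict(metrics)
    # torchmetrics objects support .clone(); deepcopy stands in for it here
    return {stage: copy.deepcopy(metrics) for stage in stages}
```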

setup(stage: Literal['fit', 'validate', 'test', 'predict'] | None = None)[source]#

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters:

stage – either 'fit', 'validate', 'test', or 'predict'

Example:

class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this: state assigned here is not shared across processes
        self.something = some_state

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
configure_optimizers()[source]#

Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.

Returns:

Any of these 6 options.

  • Single optimizer.

  • List or Tuple of optimizers.

  • Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).

  • Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.

  • None - Fit will run without any optimizer.

The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified by 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}

When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.

Metrics can be made available to monitor by simply logging it using self.log('metric_to_track', metric_val) in your LightningModule.

Note

Some things to know:

  • Lightning calls .backward() and .step() automatically in case of automatic optimization.

  • If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default “epoch”) in the scheduler configuration, Lightning will call the scheduler’s .step() method automatically in case of automatic optimization.

  • If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.

  • If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.

  • If you use multiple optimizers, you will have to switch to ‘manual optimization’ mode and step them yourself.

  • If you need to control how often the optimizer steps, override the optimizer_step() hook.
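Under the contract above, when SupervisedRoutine receives both an optimizer factory and a scheduler factory, configure_optimizers() plausibly returns the dictionary option. A hedged sketch, with illustrative names (build_optimizer_config and the factory parameters are assumptions, not chemtorch attributes):

```python
def build_optimizer_config(params, optimizer_factory, lr_scheduler_factory=None):
    """Return a single optimizer, or an {"optimizer", "lr_scheduler"} dict."""
    optimizer = optimizer_factory(params)
    if lr_scheduler_factory is None:
        return optimizer            # the single-optimizer option
    scheduler = lr_scheduler_factory(optimizer)
    # the dictionary option: "optimizer" key plus an lr_scheduler_config
    return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler}}
```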

forward(inputs: Tensor) Tensor[source]#

Same as torch.nn.Module.forward().

Parameters:

inputs (Tensor) – The input batch passed to the wrapped model.

Returns:

Your model’s output

training_step(batch: Tuple[Tensor, Tensor]) Tensor[source]#

Here you compute and return the training loss and some additional metrics, e.g. for the progress bar or logger.

Parameters:
  • batch – The output of your data iterable, normally a DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

  • Tensor - The loss tensor

  • dict - A dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.

  • None - In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.

In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.

Example:

def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

To use multiple optimizers, you can switch to ‘manual optimization’ and control their stepping:

def __init__(self):
    super().__init__()
    self.automatic_optimization = False


# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()

    # do training_step with encoder
    ...
    opt1.step()
    # do training_step with decoder
    ...
    opt2.step()

Note

When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.

validation_step(batch: Tuple[Tensor, Tensor]) Tensor[source]#

Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest, such as accuracy.

Parameters:
  • batch – The output of your data iterable, normally a DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

  • Tensor - The loss tensor

  • dict - A dictionary. Can include any keys, but must include the key 'loss'.

  • None - Skip to the next batch.

# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...


# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...

Examples:

# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})

If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    x, y = batch

    # implement your own
    out = self(x)

    if dataloader_idx == 0:
        loss = self.loss0(out, y)
    else:
        loss = self.loss1(out, y)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs separately for each dataloader
    self.log_dict({f"val_loss_{dataloader_idx}": loss, f"val_acc_{dataloader_idx}": acc})

Note

If you don’t need to validate you don’t need to implement this method.

Note

When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.

test_step(batch, batch_idx, dataloader_idx=0)[source]#

Operates on a single batch of data from the test set. In this step you’d normally generate examples or calculate anything of interest such as accuracy.

Parameters:
  • batch – The output of your data iterable, normally a DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

  • Tensor - The loss tensor

  • dict - A dictionary. Can include any keys, but must include the key 'loss'.

  • None - Skip to the next batch.

# if you have one test dataloader:
def test_step(self, batch, batch_idx): ...


# if you have multiple test dataloaders:
def test_step(self, batch, batch_idx, dataloader_idx=0): ...

Examples:

# CASE 1: A single test dataset
def test_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'test_loss': loss, 'test_acc': test_acc})

If you pass in multiple test dataloaders, test_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple test dataloaders
def test_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    x, y = batch

    # implement your own
    out = self(x)

    if dataloader_idx == 0:
        loss = self.loss0(out, y)
    else:
        loss = self.loss1(out, y)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs separately for each dataloader
    self.log_dict({f"test_loss_{dataloader_idx}": loss, f"test_acc_{dataloader_idx}": acc})

Note

If you don’t need to test you don’t need to implement this method.

Note

When the test_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of the test epoch, the model goes back to training mode and gradients are enabled.

predict_step(batch: Tensor | Tuple[Tensor, Tensor], batch_idx: int, dataloader_idx: int = 0) Tensor[source]#

Perform a prediction step.

This method handles batch unpacking for prediction, supporting both:

  1. Batches with targets: (inputs, targets) – only the inputs are extracted for prediction.

  2. Batches without targets: the inputs are used directly.

This allows calling trainer.predict() with training, validation, and test dataloaders (which contain targets) as well as prediction dataloaders (which contain only the inputs). Otherwise, an error would be raised when a dataloader yields batches with targets.

Parameters:
  • batch – Either a tuple (inputs, targets) or just inputs

  • batch_idx – Index of the current batch

  • dataloader_idx – Index of the current dataloader

Returns:

Model predictions

Return type:

torch.Tensor
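The batch unpacking described above amounts to a small dispatch on the batch shape. A minimal sketch of that rule (illustrative only; not chemtorch's actual implementation):

```python
def unpack_prediction_batch(batch):
    """Keep only the inputs from an (inputs, targets) pair; pass bare inputs through."""
    if isinstance(batch, (tuple, list)) and len(batch) == 2:
        inputs, _targets = batch
        return inputs
    return batch
```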

Regression Routine#

class chemtorch.core.routine.regression_routine.RegressionRoutine(standardizer: Standardizer | None = None, *args, **kwargs)[source]#

Bases: SupervisedRoutine

Extends SupervisedRoutine for regression tasks by allowing the use of an optional standardizer.

This class is intended for regression models where outputs may need to be destandardized (e.g., to return predictions in the original scale). If a Standardizer is provided, predictions are automatically destandardized in the forward pass.

Parameters:
  • standardizer (Standardizer, optional) – An instance of Standardizer for output destandardization.

  • *args – Additional positional arguments for SupervisedRoutine.

  • **kwargs – Additional keyword arguments for SupervisedRoutine.

See also

chemtorch.core.routine.supervised_routine.SupervisedRoutine

Examples

>>> # With standardizer (for regression)
>>> routine = RegressionRoutine(
...     model=my_model,
...     standardizer=my_standardizer,
...     loss=my_loss_fn,
...     optimizer=lambda params: torch.optim.Adam(params, lr=1e-3),
...     metrics=my_metrics,
... )
>>> preds = routine(torch.randn(8, 16))  # Returns destandardized predictions
>>> # Without standardizer (raw model output)
>>> routine = RegressionRoutine(
...     model=my_model,
...     loss=my_loss_fn,
...     optimizer=lambda params: torch.optim.Adam(params, lr=1e-3),
...     metrics=my_metrics,
... )
>>> preds = routine(torch.randn(8, 16))  # Returns raw model predictions
__init__(standardizer: Standardizer | None = None, *args, **kwargs)[source]#

Initialize a regression routine with an optional standardizer.

Parameters:
  • standardizer (Standardizer, optional) – An instance of Standardizer for data normalization.

  • *args – Additional positional arguments for SupervisedRoutine.

  • **kwargs – Additional keyword arguments for SupervisedRoutine.
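The destandardization behavior described above can be sketched with a minimal standardizer that inverts z-scoring. The class and method names here are illustrative assumptions; chemtorch's actual Standardizer API may differ.

```python
class MeanStdStandardizer:
    """Illustrative standardizer storing a mean and standard deviation."""

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def destandardize(self, y):
        # invert z-scoring: y_raw = y_std * std + mean
        return y * self.std + self.mean

def regression_forward(model, x, standardizer=None):
    """Apply the model, then destandardize the output if a standardizer is given."""
    out = model(x)
    return standardizer.destandardize(out) if standardizer is not None else out
```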

forward(inputs: Tensor) Tensor[source]#

Same as torch.nn.Module.forward().

Parameters:

inputs (Tensor) – The input batch passed to the wrapped model.

Returns:

Your model’s output
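The destandardization step can be sketched in plain Python. Note that destandardize is a hypothetical method name used here for illustration; the actual Standardizer API may differ.

```python
def regression_forward(model, inputs, standardizer=None):
    """Apply the model, then map predictions back to the original scale."""
    preds = model(inputs)
    if standardizer is not None:
        # Hypothetical API: invert the (x - mean) / std transform.
        preds = standardizer.destandardize(preds)
    return preds
```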

predict_step(*args: Any, **kwargs: Any) Any[source]#

Perform a prediction step.

This method handles batch unpacking for prediction, supporting both:

  1. Batches with targets: (inputs, targets) – only the inputs are extracted for prediction.

  2. Batches without targets: the inputs are used directly.

This allows trainer.predict() to be called with training, validation, and test dataloaders (which yield targets) as well as prediction dataloaders (which yield only inputs). Without this handling, batches that include targets would raise an error.

Parameters:
  • batch – Either a tuple (inputs, targets) or just inputs

  • batch_idx – Index of the current batch

  • dataloader_idx – Index of the current dataloader

Returns:

Model predictions

Return type:

torch.Tensor
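The batch-unpacking logic described above can be sketched in plain Python. The 2-tuple check is an assumption about how (inputs, targets) batches are distinguished from inputs-only batches; the actual implementation may use a different test.

```python
def unpack_prediction_batch(batch):
    """Extract inputs from a batch that may or may not contain targets."""
    # An (inputs, targets) pair arrives as a 2-tuple; inputs-only batches
    # pass through unchanged.
    if isinstance(batch, tuple) and len(batch) == 2:
        inputs, _targets = batch
        return inputs
    return batch
```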

on_save_checkpoint(checkpoint: dict) None[source]#

Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.

Parameters:

checkpoint – The full checkpoint dictionary before it gets dumped to a file. Implementations of this hook can insert additional data into this dictionary.

Example:

def on_save_checkpoint(self, checkpoint):
    # 99% of use cases you don't need to implement this method
    checkpoint['something_cool_i_want_to_save'] = my_cool_picklable_object

Note

Lightning saves all aspects of training (epoch, global step, etc…) including amp scaling. There is no need for you to store anything about training.

on_load_checkpoint(checkpoint: dict) None[source]#

Called by Lightning to restore your model. If you saved something with on_save_checkpoint() this is your chance to restore this.

Parameters:

checkpoint – Loaded checkpoint

Example:

def on_load_checkpoint(self, checkpoint):
    # 99% of the time you don't need to implement this method
    self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']

Note

Lightning auto-restores global step, epoch, and train state including amp scaling. There is no need for you to restore anything regarding training.

Scheduler#

class chemtorch.core.scheduler.cosine_with_warmup_lr.CosineWithWarmupLR(optimizer, num_warmup_steps: int, num_training_steps: int, eta_min: float = 0.0, start_factor: float = 1e-06, end_factor: float = 1.0, **kwargs)[source]#

Bases: LRScheduler

Cosine annealing with optional linear warmup. If num_warmup_steps > 0, uses SequentialLRWrapper to combine LinearLR and CosineAnnealingLR. Otherwise, uses CosineAnnealingLR directly.

__init__(optimizer, num_warmup_steps: int, num_training_steps: int, eta_min: float = 0.0, start_factor: float = 1e-06, end_factor: float = 1.0, **kwargs)[source]#
step(*args, **kwargs)[source]#

Perform a step.

state_dict()[source]#

Return the state of the scheduler as a dict.

It contains an entry for every variable in self.__dict__ which is not the optimizer.

load_state_dict(state_dict)[source]#

Load the scheduler’s state.

Parameters:

state_dict (dict) – scheduler state. Should be an object returned from a call to state_dict().

get_last_lr()[source]#

Return last computed learning rate by current scheduler.
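The schedule this class produces can be sketched in plain Python. This is a minimal sketch of linear warmup followed by cosine annealing, using the constructor's parameter names; the actual scheduler composes torch's LinearLR and CosineAnnealingLR via SequentialLRWrapper rather than computing rates directly.

```python
import math

def cosine_with_warmup(step, base_lr, num_warmup_steps, num_training_steps,
                       eta_min=0.0, start_factor=1e-06):
    """Learning rate at a given step: linear warmup, then cosine decay."""
    if step < num_warmup_steps:
        # Linear ramp from start_factor * base_lr up to base_lr.
        frac = step / max(1, num_warmup_steps)
        return base_lr * (start_factor + (1.0 - start_factor) * frac)
    # Cosine anneal from base_lr down to eta_min over the remaining steps.
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * progress))
```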

chemtorch.core.scheduler.graphgps_cosine_with_warmup_lr.get_cosine_scheduler_with_warmup(num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5)[source]#
class chemtorch.core.scheduler.sequential_lr_wrapper.SequentialLRWrapper(optimizer, schedulers, milestones, **kwargs)[source]#

Bases: SequentialLR

A wrapper around SequentialLR that recursively instantiates any schedulers or scheduler factories, passing the shared optimizer given to this wrapper to each of them.

__init__(optimizer, schedulers, milestones, **kwargs)[source]#

Initialize the SequentialLR and pass the optimizer to each scheduler or scheduler factory.

Parameters:
  • optimizer (Optimizer) – The optimizer to be used with the schedulers.

  • schedulers (list) – A list of scheduler instances or factory functions that return scheduler instances.

  • milestones (list) – A list of milestones for the SequentialLR.

  • **kwargs – Additional keyword arguments to be passed to the SequentialLR constructor.
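The factory-resolution idea can be sketched in plain Python. The function and the heuristic for telling a factory from a constructed scheduler are illustrative assumptions, not the wrapper's actual internals.

```python
def resolve_schedulers(optimizer, schedulers):
    """Instantiate scheduler factories with a shared optimizer.

    Entries that are plain callables (factories) are called with the
    optimizer; already-constructed scheduler instances pass through.
    """
    resolved = []
    for sched in schedulers:
        # Assumption: a constructed scheduler exposes an `optimizer`
        # attribute, while a factory is a bare callable without one.
        if callable(sched) and not hasattr(sched, "optimizer"):
            sched = sched(optimizer)
        resolved.append(sched)
    return resolved
```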

Property System#

Simplified dataset property calculation system.

This module provides a clean way to compute properties needed for model configuration at runtime, with proper handling of partition-dependent vs partition-independent properties.

chemtorch.core.property_system.compute_property_with_dataset_handling(property_instance: DatasetProperty, dataset: DatasetBase | Dict[str, DatasetBase] | List[DatasetBase]) Any[source]#

Compute a property while handling edge cases where dataset might be a dict or list.

For dict datasets (multiple named test datasets) and list datasets (multiple test datasets), the property is computed for the first dataset only. For single datasets, the property is computed directly.

Parameters:
  • property_instance – The property to compute

  • dataset – The dataset(s) to compute the property for

Returns:

The computed property value
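The first-dataset dispatch described above can be sketched in plain Python. The function name and the use of a bare callable in place of a DatasetProperty are simplifications for illustration.

```python
def compute_with_dataset_handling(compute, dataset):
    """Dispatch a property computation over a single dataset,
    a dict of datasets, or a list of datasets.

    For dicts and lists, only the first dataset is used, mirroring
    the documented behaviour.
    """
    if isinstance(dataset, dict):
        dataset = next(iter(dataset.values()))
    elif isinstance(dataset, list):
        dataset = dataset[0]
    return compute(dataset)
```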

chemtorch.core.property_system.resolve_sources(source: Literal['any', 'all', 'train', 'val', 'test', 'predict'], tasks: List[Literal['fit', 'validate', 'test', 'predict']]) List[Literal['train', 'val', 'test', 'predict']][source]#

Resolve a PropertySource into a list of DatasetKeys.

class chemtorch.core.property_system.DatasetProperty(name: str, source: Literal['any', 'all', 'train', 'val', 'test', 'predict'], log: bool = False, add_to_cfg: bool = False)[source]#

Bases: ABC

Base class for dataset properties that can be computed at runtime.

__init__(name: str, source: Literal['any', 'all', 'train', 'val', 'test', 'predict'], log: bool = False, add_to_cfg: bool = False) None[source]#
abstract compute(dataset: DatasetBase) Any[source]#

Compute the property value from the dataset.
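A custom property follows the same pattern as the built-in subclasses below. This sketch mimics the DatasetProperty interface in plain Python; the real base class in chemtorch.core.property_system additionally handles source resolution, logging, and config injection, and the NumSamples property here is hypothetical.

```python
from abc import ABC, abstractmethod

class DatasetPropertySketch(ABC):
    """Minimal stand-in for DatasetProperty: a named, computable value."""

    def __init__(self, name, source="train"):
        self.name = name
        self.source = source

    @abstractmethod
    def compute(self, dataset):
        """Compute the property value from the dataset."""

class NumSamples(DatasetPropertySketch):
    """Example property: the number of samples in a dataset."""

    def compute(self, dataset):
        return len(dataset)
```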

class chemtorch.core.property_system.PrecomputeTime(name: str, source: Literal['any', 'all', 'train', 'val', 'test', 'predict'], log: bool = False, add_to_cfg: bool = False)[source]#

Bases: DatasetProperty

compute(dataset: DatasetBase) float[source]#

Compute the property value from the dataset.

class chemtorch.core.property_system.NumNodeFeatures(name: str, source: Literal['any', 'all', 'train', 'val', 'test', 'predict'], log: bool = False, add_to_cfg: bool = False)[source]#

Bases: DatasetProperty

compute(dataset: DatasetBase[Data, Any]) int[source]#

Compute the property value from the dataset.

class chemtorch.core.property_system.NumEdgeFeatures(name: str, source: Literal['any', 'all', 'train', 'val', 'test', 'predict'], log: bool = False, add_to_cfg: bool = False)[source]#

Bases: DatasetProperty

compute(dataset: DatasetBase) int[source]#

Compute the property value from the dataset.

class chemtorch.core.property_system.FingerprintLength(name: str, source: Literal['any', 'all', 'train', 'val', 'test', 'predict'], log: bool = False, add_to_cfg: bool = False)[source]#

Bases: DatasetProperty

compute(dataset: DatasetBase) int[source]#

Compute the property value from the dataset.

class chemtorch.core.property_system.VocabSize(name: str, source: Literal['any', 'all', 'train', 'val', 'test', 'predict'], log: bool = False, add_to_cfg: bool = False)[source]#

Bases: DatasetProperty

compute(dataset: DatasetBase[Any, AbstractTokenRepresentation]) int[source]#

Compute the property value from the dataset.

class chemtorch.core.property_system.LabelMean(name: str, source: Literal['any', 'all', 'train', 'val', 'test', 'predict'], log: bool = False, add_to_cfg: bool = False)[source]#

Bases: DatasetProperty

compute(dataset: DatasetBase) float[source]#

Compute the property value from the dataset.

class chemtorch.core.property_system.LabelStd(name: str, source: Literal['any', 'all', 'train', 'val', 'test', 'predict'], log: bool = False, add_to_cfg: bool = False)[source]#

Bases: DatasetProperty

compute(dataset: DatasetBase) float[source]#

Compute the property value from the dataset.

class chemtorch.core.property_system.DegreeStatistics(name: str, source: Literal['any', 'all', 'train', 'val', 'test', 'predict'], log: bool = False, add_to_cfg: bool = False)[source]#

Bases: DatasetProperty

compute(dataset: DatasetBase) Dict[str, int | List[int]][source]#

Compute the property value from the dataset.