Core#
The core package contains the PyTorch Lightning components that orchestrate training: the data module and dataset base class, the supervised and regression routines, learning-rate schedulers, and the dataset property system.
Data Module#
- class chemtorch.core.data_module.DataModule(data_pipeline: Callable[[...], DataSplit | DataFrame], representation: AbstractRepresentation, dataloader_factory: DataLoaderFactoryProtocol, transform: Any | Callable[[Any], Any] | Dict[Literal['train', 'val', 'test', 'predict'], Any | Callable[[Any], Any]] | Dict[Literal['train', 'val', 'test', 'predict'], List[Any | Callable[[Any], Any]] | Dict[str, Any | Callable[[Any], Any]]] | None = None, augmentations: List[Any | Callable[[Any, Tensor | None], List[tuple[Any, Tensor | None]]]] | Dict[str, Any | Callable[[Any, Tensor | None], List[tuple[Any, Tensor | None]]]] | None = None, subsample: int | float | Dict[Literal['train', 'val', 'test', 'predict'], int | float] | None = None, precompute_all: bool = True, cache: bool = False, max_cache_size: int | None = None, integer_labels: bool = False)[source]#
Bases: LightningDataModule
- __init__(data_pipeline: Callable[[...], DataSplit | DataFrame], representation: AbstractRepresentation, dataloader_factory: DataLoaderFactoryProtocol, transform: Any | Callable[[Any], Any] | Dict[Literal['train', 'val', 'test', 'predict'], Any | Callable[[Any], Any]] | Dict[Literal['train', 'val', 'test', 'predict'], List[Any | Callable[[Any], Any]] | Dict[str, Any | Callable[[Any], Any]]] | None = None, augmentations: List[Any | Callable[[Any, Tensor | None], List[tuple[Any, Tensor | None]]]] | Dict[str, Any | Callable[[Any, Tensor | None], List[tuple[Any, Tensor | None]]]] | None = None, subsample: int | float | Dict[Literal['train', 'val', 'test', 'predict'], int | float] | None = None, precompute_all: bool = True, cache: bool = False, max_cache_size: int | None = None, integer_labels: bool = False) None[source]#
Initialize the DataModule with a data pipeline, dataset factory, and dataloader factory.
- Parameters:
data_pipeline (Callable) – A callable that returns a DataSplit or a pandas DataFrame.
representation (AbstractRepresentation) – An instance of a representation class to convert raw data into model-ready format.
dataloader_factory (DataLoaderFactoryProtocol) – A DataLoader class or factory function used to create DataLoader instances compatible with the data format returned by the representation. Typically, this is a partially initialized subclass of torch.utils.data.DataLoader (e.g., created with functools.partial).
transform (Optional[Union[TransformType, Dict[DatasetKey, TransformType], Dict[DatasetKey, Union[List[TransformType], Dict[str, TransformType]]]]]) – An optional transform or a dictionary of transforms for each stage (‘train’, ‘val’, ‘test’, ‘predict’). If a single transform is provided, it will be applied to all stages. If a dictionary is provided, it should map each dataset key to its corresponding transform. For the test set, a list or dict of transforms can be provided to create multiple test datasets. If a dict is provided, the keys should be the names of the test datasets.
augmentations (Optional[List[AbstractAugmentation]] | Dict[str, AbstractAugmentation]) – An optional list or dictionary of augmentations to be applied to the training dataset.
subsample (Optional[Union[int, float, Dict[DatasetKey, Union[int, float]]]]) – An optional integer or float to subsample the datasets. If a float is provided, it should be between 0 and 1 and represents the fraction of the dataset to keep. If an integer is provided, it represents the exact number of samples to keep. If a dictionary is provided, it should map each dataset key to its corresponding subsample fraction or count.
precompute_all (bool) – If True, precompute all samples of the dataset. Default is True.
cache (bool) – If True, enable caching of dataset samples. Default is False.
max_cache_size (Optional[int]) – Maximum number of samples to cache. If None, cache size is unlimited. Default is None.
integer_labels (bool) – If True, convert labels to integers. Default is False.
- Raises:
TypeError – If the output of the data pipeline is not a DataSplit or a pandas DataFrame.
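Putting the pieces together, a minimal construction might look like the following sketch. load_reactions and MyRepresentation are hypothetical placeholders; a real representation must be an AbstractRepresentation instance, and the dataloader factory follows the partially initialized pattern described above.

# Minimal sketch, not the library's canonical setup: `load_reactions` and
# `MyRepresentation` are hypothetical placeholders for illustration only.
from functools import partial

import pandas as pd
import torch
from torch.utils.data import DataLoader

from chemtorch.core.data_module import DataModule


def load_reactions() -> pd.DataFrame:
    # Any callable returning a DataSplit or a pandas DataFrame qualifies as a data pipeline.
    return pd.read_csv("data/reactions.csv")


class MyRepresentation:
    # Placeholder: a real representation must subclass AbstractRepresentation and
    # map the fields of one dataframe row (passed as keyword arguments) to a model-ready object.
    def __call__(self, **row):
        return torch.zeros(16)


dm = DataModule(
    data_pipeline=load_reactions,
    representation=MyRepresentation(),
    dataloader_factory=partial(DataLoader, batch_size=64),  # partially initialized DataLoader
    subsample={"train": 0.1},                               # keep 10% of the training split
    precompute_all=True,
)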
- get_dataset(key: str) DatasetBase[source]#
Retrieve the dataset for the specified key.
- Parameters:
key (str) – The key for which to retrieve the dataset (‘train’, ‘val’, ‘test’, ‘predict’, ‘test_<name>’).
- Returns:
The dataset corresponding to the specified key.
- Return type:
DatasetBase
- Raises:
ValueError – If the dataset for the specified key is not initialized.
- get_dataset_names() List[str][source]#
Get all available dataset names/keys.
- Returns:
List of all dataset keys that can be used with get_dataset() and make_dataloader().
- Return type:
List[str]
- make_dataloader(key: str) DataLoader[source]#
Create a dataloader for the specified key.
- Parameters:
key (str) – The key for which to create the dataloader (‘train’, ‘val’, ‘test’, ‘predict’, ‘test_<name>’).
- Returns:
The created dataloader for the specified dataset key.
- Return type:
DataLoader
- Raises:
ValueError – If the dataset for the specified key is not initialized.
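A possible access pattern for these helpers, assuming the `dm` instance from the sketch above and that its datasets have already been initialized:

# Hypothetical access pattern; if a dataset is not initialized, these calls raise ValueError.
print(dm.get_dataset_names())           # e.g. ['train', 'val', 'test'] or ['train', 'val', 'test_scaffold', ...]

train_ds = dm.get_dataset("train")      # a DatasetBase instance
val_loader = dm.make_dataloader("val")  # a DataLoader built by the dataloader_factory

for batch in val_loader:
    ...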
- maybe_get_test_dataloader_idx_to_suffix() Dict[int, str] | None[source]#
If multiple named test datasets are initialized by passing a dict with named test set transforms, return a mapping from dataloader index to suffix. Otherwise, return None.
- train_dataloader()[source]#
An iterable or collection of iterables specifying training samples.
For more information about multiple dataloaders, see this section.
The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
do not assign state in prepare_data
- fit()
- prepare_data()
- setup()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- val_dataloader()[source]#
An iterable or collection of iterables specifying validation samples.
For more information about multiple dataloaders, see this section.
The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.
It’s recommended that all data downloads and preparation happen in
prepare_data().
- fit()
- validate()
- prepare_data()
- setup()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a validation dataset and a
validation_step(), you don’t need to implement this method.
- test_dataloader()[source]#
An iterable or collection of iterables specifying test samples.
For more information about multiple dataloaders, see this section.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
do not assign state in prepare_data
- test()
- prepare_data()
- setup()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a test dataset and a
test_step(), you don’t need to implement this method.
- predict_dataloader()[source]#
An iterable or collection of iterables specifying prediction samples.
For more information about multiple dataloaders, see this section.
It’s recommended that all data downloads and preparation happen in
prepare_data().
- predict()
- prepare_data()
- setup()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Returns:
A torch.utils.data.DataLoader or a sequence of them specifying prediction samples.
Dataset#
- class chemtorch.core.dataset_base.DatasetBase(dataframe: DataFrame, representation: R, name: str | None = None, transform: AbstractTransform[T] | Callable[[T], T] | None = None, augmentation_list: List[Any | Callable[[Any, Tensor | None], List[tuple[Any, Tensor | None]]]] | None = None, subsample: int | float | None = None, precompute_all: bool = True, cache: bool = False, max_cache_size: int | None = None, integer_labels: bool = False)[source]#
Bases: Generic[T, R]
Base class for ChemTorch datasets with type-safe representations.
- Type parameters:
T: The data type produced by the representation and returned by the dataset.
R: The representation type (bounded by AbstractRepresentation).
- Type safety features:
The dataset returns objects of type T (or Tuple[T, torch.Tensor] with labels)
The representation is of type R (bounded by AbstractRepresentation)
Transforms must be compatible with type T
Static type checker can verify basic type relationships
- Usage examples:
# Graph dataset: DatasetBase[Data, GraphRepresentation]
# Token dataset: DatasetBase[torch.Tensor, TokenRepresentation]
# Fingerprint dataset: DatasetBase[torch.Tensor, FingerprintRepresentation]
The dataset can handle both labeled and unlabeled data. If the input DataFrame contains a ‘label’ column, the dataset will return tuples of (data_object, label). Otherwise, it will return only the data objects.
Warning: If the subclass inherits from multiple classes, make sure that DatasetBase is the first class in the inheritance list so that the method resolution order (MRO) is correct.
- Raises:
RuntimeError – If the subclass does not call super().__init__() in its __init__() method.
- __init__(dataframe: DataFrame, representation: R, name: str | None = None, transform: AbstractTransform[T] | Callable[[T], T] | None = None, augmentation_list: List[Any | Callable[[Any, Tensor | None], List[tuple[Any, Tensor | None]]]] | None = None, subsample: int | float | None = None, precompute_all: bool = True, cache: bool = False, max_cache_size: int | None = None, integer_labels: bool = False)[source]#
Initialize the DatasetBase.
- Parameters:
dataframe (pd.DataFrame) – The input data as a pandas DataFrame. Each row represents a single sample. If the dataset contains a label column, it will be returned alongside the computed representation. Otherwise, only the representation will be returned.
representation (R) – A representation instance that constructs the data object consumed by the model. Must take in the fields of a single sample from the dataframe (row) as keyword arguments and return an object of type T.
name (Optional[str]) – An optional name for the dataset to be used in logging (<name> dataset). Default is None.
transform (Optional[AbstractTransform[T] | Callable[[T], T]]) – An optional transformation function or a composition thereof (Compose) that takes in an object of type T and returns a (possibly modified) object of the same type.
augmentation_list (Optional[List[AbstractAugmentation[T] | Callable[[T, torch.Tensor], List[Tuple[T, torch.Tensor]]]]]) – An optional list of data augmentation functions that take in an object of type T and return a (possibly modified) object of the same type. Note: Augmentations are only applied to the training partition, and only if precompute_all is True.
subsample (Optional[int | float]) – The subsample size or fraction. If None, no subsampling is done. If an int, it specifies the number of samples to take. If a float, it specifies the fraction of samples to take. Default is None.
precompute_all (bool) – If True, precompute all samples in the dataset. Default is True.
cache (bool) – If True, cache the processed samples. Default is False.
max_cache_size (Optional[int]) – Maximum size of the cache. Default is None.
integer_labels (bool) – Whether to use integer labels (for classification) or float labels (for regression). If True, labels will be torch.int64. If False, labels will be torch.float. Default is False.
- Raises:
ValueError – If the dataframe is not a pandas DataFrame.
ValueError – If the representation is not an AbstractRepresentation instance.
ValueError – If the transform is not a TransformBase, a callable, or None.
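For illustration, a hedged sketch of constructing a dataset directly; TokenRep is a hypothetical placeholder for an AbstractRepresentation subclass that returns torch.Tensor objects.

# Hedged sketch: `TokenRep` stands in for a real AbstractRepresentation subclass.
import pandas as pd

from chemtorch.core.dataset_base import DatasetBase

df = pd.DataFrame({
    "smiles": ["CCO", "c1ccccc1", "CC(=O)O"],
    "label": [0.51, 1.23, 0.87],   # optional; if present, items are (data, label) tuples
})

dataset = DatasetBase(
    dataframe=df,
    representation=TokenRep(),      # called with each row's fields as keyword arguments
    name="train",
    precompute_all=True,
    integer_labels=False,           # float labels, i.e. regression-style targets
)
data, label = dataset[0]            # label is a torch.float tensor here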
- get_labels()[source]#
Retrieve the labels for the dataset.
- Returns:
The labels for the dataset if they exist.
- Return type:
pd.Series
- Raises:
RuntimeError – If the dataset does not contain labels.
Routine#
Supervised Routine#
- class chemtorch.core.routine.supervised_routine.SupervisedRoutine(model: Module, loss: Callable | None = None, optimizer: Callable[[Iterator[Parameter]], Optimizer] | None = None, lr_scheduler: Callable[[Optimizer], LRScheduler] | Dict[str, Any] | None = None, ckpt_path: str | None = None, resume_training: bool = False, metrics: Metric | MetricCollection | Dict[str, Metric | MetricCollection] | None = None, test_dataloader_idx_to_suffix: Dict[int, str] | None = None)[source]#
Bases: LightningModule
A flexible LightningModule wrapper for supervised tasks, supporting both training and inference.
- This class can be used for:
Full training/validation/testing with loss, optimizer, scheduler, and metrics.
Inference-only (prediction), requiring only the model.
Example usage:
>>> # Training usage
>>> routine = SupervisedRoutine(
...     model=my_model,
...     loss=my_loss_fn,
...     optimizer=lambda params: torch.optim.Adam(params, lr=1e-3),
...     lr_scheduler=lambda opt: torch.optim.lr_scheduler.StepLR(opt, step_size=10),
...     metrics=my_metrics,
... )
>>> trainer = pl.Trainer(...)
>>> trainer.fit(routine, datamodule=my_datamodule)
>>> # Inference-only usage
>>> routine = SupervisedRoutine(model=my_model)
>>> preds = routine(torch.randn(8, 16))  # Forward pass for prediction
- __init__(model: Module, loss: Callable | None = None, optimizer: Callable[[Iterator[Parameter]], Optimizer] | None = None, lr_scheduler: Callable[[Optimizer], LRScheduler] | Dict[str, Any] | None = None, ckpt_path: str | None = None, resume_training: bool = False, metrics: Metric | MetricCollection | Dict[str, Metric | MetricCollection] | None = None, test_dataloader_idx_to_suffix: Dict[int, str] | None = None)[source]#
Initialize the SupervisedRoutine.
- Parameters:
model (nn.Module) – The model to be trained or used for inference.
loss (Callable, optional) – The loss function to be used. Required for training/validation/testing.
optimizer (Callable, optional) –
A factory function that takes in the model’s parameters and returns an optimizer instance. Required for training/validation/testing.
Example
optimizer=lambda params: torch.optim.Adam(params, lr=1e-3)
lr_scheduler (Callable or Dict, optional) –
Either a factory function that takes in the optimizer and returns a learning rate scheduler instance, or a Lightning config dictionary containing a “scheduler” key with the partially instantiated scheduler factory and optional Lightning-specific keys. Only needed for training.
Examples
# Factory function approach
lr_scheduler=lambda opt: torch.optim.lr_scheduler.StepLR(opt, step_size=10)

# Lightning config dictionary approach
lr_scheduler={
    "scheduler": partial_scheduler_factory,  # e.g., functools.partial(torch.optim.lr_scheduler.StepLR, step_size=10)
    # Lightning-specific keys (optional):
    "interval": "epoch",
    "frequency": 1,
    "monitor": "val_loss",
    # etc.
}
ckpt_path (str, optional) – Path to a pre-trained model checkpoint.
resume_training (bool, optional) – Whether to resume training from a checkpoint.
metrics (Metric, MetricCollection or Dict[str, Metric/MetricCollection], optional) –
Metrics to use for evaluation.
- If a single Metric is provided, it will be cloned for the 'train', 'val', and 'test' stages.
- If a single MetricCollection is provided, it will be cloned for the 'train', 'val', and 'test' stages.
- If a dictionary is provided, it must map the keys 'train', 'val', and/or 'test' to Metric or MetricCollection instances. This allows you to specify different metrics for each stage.
In all cases, the metrics will be registered as attributes of the LightningModule for proper logging.
- Example usage:
>>> from torchmetrics import MetricCollection, MeanAbsoluteError, MeanSquaredError
...
>>> # Single Metric for all stages
>>> metric = MeanAbsoluteError()
>>> routine = SupervisedRoutine(
...     model=my_model,
...     loss=my_loss_fn,
...     optimizer=lambda params: torch.optim.Adam(params, lr=1e-3),
...     metrics=metric,
... )
>>> # Single MetricCollection for all stages
>>> metrics = MetricCollection({
...     "mae": MeanAbsoluteError(),
...     "rmse": MeanSquaredError(squared=False),
... })
>>> routine = SupervisedRoutine(
...     model=my_model,
...     loss=my_loss_fn,
...     optimizer=lambda params: torch.optim.Adam(params, lr=1e-3),
...     metrics=metrics,
... )
>>> # Distinct metrics for each stage (mix of single metrics and collections)
>>> metrics_dict = {
...     "train": MeanAbsoluteError(),  # Single metric
...     "val": MetricCollection({"rmse": MeanSquaredError(squared=False)}),  # Collection
...     "test": MeanSquaredError(),  # Single metric
... }
>>> routine = SupervisedRoutine(
...     model=my_model,
...     loss=my_loss_fn,
...     optimizer=lambda params: torch.optim.Adam(params, lr=1e-3),
...     metrics=metrics_dict,
... )
test_dataloader_idx_to_suffix (Dict[int, str], optional) – A mapping from test dataloader index to its suffix. This is used to suffix test metrics like test_<metric_name>/<suffix>. If no suffix mapping is provided, the dataloader index is used as the suffix. The keys should be integers corresponding to dataloader indices (1, 2, …). No suffix mapping is needed for the first test dataloader (index 0) since it is assumed to be the main test set and is not suffixed.
- Raises:
TypeError – If metrics is not a Metric, MetricCollection, or a dictionary of Metrics/MetricCollections.
ValueError – If metrics is a dictionary, but its keys are not ‘train’, ‘val’, or ‘test’, or if the keys are not unique.
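A hedged sketch of wiring named test sets to metric suffixes; my_model is a placeholder nn.Module, and the suffix mapping could also be taken from DataModule.maybe_get_test_dataloader_idx_to_suffix() described above.

# Hedged sketch; `my_model` is a placeholder nn.Module.
import torch
from torchmetrics import MeanAbsoluteError

from chemtorch.core.routine.supervised_routine import SupervisedRoutine

routine = SupervisedRoutine(
    model=my_model,
    loss=torch.nn.MSELoss(),
    optimizer=lambda params: torch.optim.Adam(params, lr=1e-3),
    metrics=MeanAbsoluteError(),          # cloned for 'train', 'val', and 'test'
    # Dataloader index 0 is the unsuffixed main test set; index 1 logs as
    # test_<metric_name>/scaffold. The mapping could also come from
    # DataModule.maybe_get_test_dataloader_idx_to_suffix().
    test_dataloader_idx_to_suffix={1: "scaffold"},
)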
- setup(stage: Literal['fit', 'validate', 'test', 'predict'] | None = None)[source]#
Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
- Parameters:
stage – either 'fit', 'validate', 'test', or 'predict'
Example:
class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this
        self.something = else

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
- configure_optimizers()[source]#
Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.
- Returns:
Any of these options:
Single optimizer.
List or Tuple of optimizers.
Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).
Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
None - Fit will run without any optimizer.
The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.
lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}
When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.
Metrics can be made available to monitor by simply logging them using self.log('metric_to_track', metric_val) in your LightningModule.
Note
Some things to know:
Lightning calls .backward() and .step() automatically in case of automatic optimization.
If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default "epoch") in the scheduler configuration, Lightning will call the scheduler's .step() method automatically in case of automatic optimization.
If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.
If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.
If you use multiple optimizers, you will have to switch to 'manual optimization' mode and step them yourself.
If you need to control how often the optimizer steps, override the optimizer_step() hook.
- forward(inputs: Tensor) Tensor[source]#
Same as torch.nn.Module.forward().
- Parameters:
*args – Whatever you decide to pass into the forward method.
**kwargs – Keyword arguments are also possible.
- Returns:
Your model’s output
- training_step(batch: Tuple[Tensor, Tensor]) Tensor[source]#
Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.
- Parameters:
batch – The output of your data iterable, normally a DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)
- Returns:
Tensor - The loss tensor.
dict - A dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.
None - In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.
In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.
Example:
def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss
To use multiple optimizers, you can switch to ‘manual optimization’ and control their stepping:
def __init__(self):
    super().__init__()
    self.automatic_optimization = False


# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()

    # do training_step with encoder
    ...
    opt1.step()

    # do training_step with decoder
    ...
    opt2.step()
Note
When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.
- validation_step(batch: Tuple[Tensor, Tensor]) Tensor[source]#
Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.
- Parameters:
batch – The output of your data iterable, normally a DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)
- Returns:
Tensor - The loss tensor.
dict - A dictionary. Can include any keys, but must include the key 'loss'.
None - Skip to the next batch.
# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...


# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...
Examples:
# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})
If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.
# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    x, y = batch

    # implement your own
    out = self(x)
    if dataloader_idx == 0:
        loss = self.loss0(out, y)
    else:
        loss = self.loss1(out, y)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs separately for each dataloader
    self.log_dict({f"val_loss_{dataloader_idx}": loss, f"val_acc_{dataloader_idx}": acc})
Note
If you don’t need to validate you don’t need to implement this method.
Note
When validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.
- test_step(batch, batch_idx, dataloader_idx=0)[source]#
Operates on a single batch of data from the test set. In this step you’d normally generate examples or calculate anything of interest such as accuracy.
- Parameters:
batch – The output of your data iterable, normally a DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)
- Returns:
Tensor - The loss tensor.
dict - A dictionary. Can include any keys, but must include the key 'loss'.
None - Skip to the next batch.
# if you have one test dataloader:
def test_step(self, batch, batch_idx): ...


# if you have multiple test dataloaders:
def test_step(self, batch, batch_idx, dataloader_idx=0): ...
Examples:
# CASE 1: A single test dataset
def test_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'test_loss': loss, 'test_acc': test_acc})
If you pass in multiple test dataloaders, test_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.
# CASE 2: multiple test dataloaders
def test_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    x, y = batch

    # implement your own
    out = self(x)
    if dataloader_idx == 0:
        loss = self.loss0(out, y)
    else:
        loss = self.loss1(out, y)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs separately for each dataloader
    self.log_dict({f"test_loss_{dataloader_idx}": loss, f"test_acc_{dataloader_idx}": acc})
Note
If you don’t need to test you don’t need to implement this method.
Note
When test_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of the test epoch, the model goes back to training mode and gradients are enabled.
- predict_step(batch: Tensor | Tuple[Tensor, Tensor], batch_idx: int, dataloader_idx: int = 0) Tensor[source]#
Perform a prediction step.
This method handles batch unpacking for prediction, supporting both:
1. Batches with targets: (inputs, targets) - extracts only the inputs for prediction
2. Batches without targets: inputs only - used directly
This allows calling trainer.predict() with training, validation, and test dataloaders (which contain targets) as well as prediction dataloaders (which contain just the inputs). Otherwise, an error would be raised when a dataloader yields batches with targets.
- Parameters:
batch – Either a tuple (inputs, targets) or just inputs
batch_idx – Index of the current batch
dataloader_idx – Index of the current dataloader
- Returns:
Model predictions
- Return type:
Tensor
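As a hedged usage sketch (assuming `routine` and `dm` are configured as in the examples above), trainer.predict() can therefore be pointed at labeled and unlabeled dataloaders alike:

# Hedged sketch; assumes `routine` and `dm` are configured as in the examples above.
import lightning.pytorch as pl

trainer = pl.Trainer(accelerator="auto", devices=1, logger=False)

# Labeled batches (inputs, targets): only the inputs are used for prediction.
test_preds = trainer.predict(routine, dataloaders=dm.make_dataloader("test"))

# Unlabeled batches (inputs only) are consumed directly, if a predict split exists.
preds = trainer.predict(routine, dataloaders=dm.make_dataloader("predict"))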
Regression Routine#
- class chemtorch.core.routine.regression_routine.RegressionRoutine(standardizer: Standardizer | None = None, *args, **kwargs)[source]#
Bases: SupervisedRoutine
Extends SupervisedRoutine for regression tasks by allowing the use of an optional standardizer.
This class is intended for regression models where outputs may need to be destandardized (e.g., to return predictions in the original scale). If a Standardizer is provided, predictions are automatically destandardized in the forward pass.
- Parameters:
standardizer (Standardizer, optional) – An instance of Standardizer for output destandardization.
*args – Additional positional arguments for SupervisedRoutine.
**kwargs – Additional keyword arguments for SupervisedRoutine.
See also
chemtorch.core.routine.supervised_routine.SupervisedRoutine
Examples
>>> # With standardizer (for regression)
>>> routine = RegressionRoutine(
...     model=my_model,
...     standardizer=my_standardizer,
...     loss=my_loss_fn,
...     optimizer=lambda params: torch.optim.Adam(params, lr=1e-3),
...     metrics=my_metrics,
... )
>>> preds = routine(torch.randn(8, 16))  # Returns destandardized predictions
>>> # Without standardizer (raw model output)
>>> routine = RegressionRoutine(
...     model=my_model,
...     loss=my_loss_fn,
...     optimizer=lambda params: torch.optim.Adam(params, lr=1e-3),
...     metrics=my_metrics,
... )
>>> preds = routine(torch.randn(8, 16))  # Returns raw model predictions
- __init__(standardizer: Standardizer | None = None, *args, **kwargs)[source]#
Initialize a regression routine with an optional standardizer.
- Parameters:
standardizer (Standardizer, optional) – An instance of Standardizer for data normalization.
*args – Additional positional arguments for SupervisedRoutine.
**kwargs – Additional keyword arguments for SupervisedRoutine.
- forward(inputs: Tensor) Tensor[source]#
Same as torch.nn.Module.forward().
- Parameters:
*args – Whatever you decide to pass into the forward method.
**kwargs – Keyword arguments are also possible.
- Returns:
Your model’s output
- predict_step(*args: Any, **kwargs: Any) Any[source]#
Perform a prediction step.
This method handles batch unpacking for prediction, supporting both:
1. Batches with targets: (inputs, targets) - extracts only the inputs for prediction
2. Batches without targets: inputs only - used directly
This allows calling trainer.predict() with training, validation, and test dataloaders (which contain targets) as well as prediction dataloaders (which contain just the inputs). Otherwise, an error would be raised when a dataloader yields batches with targets.
- Parameters:
batch – Either a tuple (inputs, targets) or just inputs
batch_idx – Index of the current batch
dataloader_idx – Index of the current dataloader
- Returns:
Model predictions
- Return type:
Any
- on_save_checkpoint(checkpoint: dict) None[source]#
Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.
- Parameters:
checkpoint – The full checkpoint dictionary before it gets dumped to a file. Implementations of this hook can insert additional data into this dictionary.
Example:
def on_save_checkpoint(self, checkpoint):
    # 99% of use cases you don't need to implement this method
    checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object
Note
Lightning saves all aspects of training (epoch, global step, etc…) including amp scaling. There is no need for you to store anything about training.
- on_load_checkpoint(checkpoint: dict) None[source]#
Called by Lightning to restore your model. If you saved something with on_save_checkpoint(), this is your chance to restore it.
- Parameters:
checkpoint – Loaded checkpoint
Example:
def on_load_checkpoint(self, checkpoint):
    # 99% of the time you don't need to implement this method
    self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']
Note
Lightning auto-restores global step, epoch, and train state including amp scaling. There is no need for you to restore anything regarding training.
Scheduler#
- class chemtorch.core.scheduler.cosine_with_warmup_lr.CosineWithWarmupLR(optimizer, num_warmup_steps: int, num_training_steps: int, eta_min: float = 0.0, start_factor: float = 1e-06, end_factor: float = 1.0, **kwargs)[source]#
Bases: LRScheduler
Cosine annealing with optional linear warmup. If num_warmup_steps > 0, uses SequentialLRWrapper to combine LinearLR and CosineAnnealingLR. Otherwise, uses CosineAnnealingLR directly.
- __init__(optimizer, num_warmup_steps: int, num_training_steps: int, eta_min: float = 0.0, start_factor: float = 1e-06, end_factor: float = 1.0, **kwargs)[source]#
- state_dict()[source]#
Return the state of the scheduler as a dict. It contains an entry for every variable in self.__dict__ which is not the optimizer.
- load_state_dict(state_dict)[source]#
Load the scheduler’s state.
- Parameters:
state_dict (dict) – scheduler state. Should be an object returned from a call to state_dict().
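A hedged construction sketch on a toy optimizer, both directly and as a factory for SupervisedRoutine's lr_scheduler argument:

# Hedged sketch of constructing the scheduler on a toy optimizer.
import torch

from chemtorch.core.scheduler.cosine_with_warmup_lr import CosineWithWarmupLR

model = torch.nn.Linear(16, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

scheduler = CosineWithWarmupLR(
    optimizer,
    num_warmup_steps=500,        # linear warmup governed by start_factor/end_factor
    num_training_steps=10_000,   # total steps; the cosine anneals toward eta_min over the remainder
)

# As a factory for SupervisedRoutine:
# lr_scheduler=lambda opt: CosineWithWarmupLR(opt, num_warmup_steps=500, num_training_steps=10_000)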
- chemtorch.core.scheduler.graphgps_cosine_with_warmup_lr.get_cosine_scheduler_with_warmup(num_warmup_steps: int, num_training_steps: int, num_cycles: int = 0.5)[source]#
- class chemtorch.core.scheduler.sequential_lr_wrapper.SequentialLRWrapper(optimizer, schedulers, milestones, **kwargs)[source]#
Bases: SequentialLR
A wrapper around SequentialLR that recursively instantiates any schedulers or scheduler factories with the shared optimizer passed to this wrapper.
- __init__(optimizer, schedulers, milestones, **kwargs)[source]#
Initialize the SequentialLR and pass the optimizer to each scheduler or scheduler factory.
- Parameters:
optimizer (Optimizer) – The optimizer to be used with the schedulers.
schedulers (list) – A list of scheduler instances or factory functions that return scheduler instances.
milestones (list) – A list of milestones for the SequentialLR.
**kwargs – Additional keyword arguments to be passed to the SequentialLR constructor.
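A hedged sketch passing scheduler factories instead of instances, so the wrapper can instantiate each one with the shared optimizer:

# Hedged sketch: factories (partials) are instantiated by the wrapper with the shared optimizer.
from functools import partial

import torch

from chemtorch.core.scheduler.sequential_lr_wrapper import SequentialLRWrapper

model = torch.nn.Linear(16, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

scheduler = SequentialLRWrapper(
    optimizer,
    schedulers=[
        partial(torch.optim.lr_scheduler.LinearLR, start_factor=1e-6, total_iters=500),
        partial(torch.optim.lr_scheduler.CosineAnnealingLR, T_max=9_500),
    ],
    milestones=[500],   # switch from the warmup scheduler to cosine annealing at step 500
)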
Property System#
Simplified dataset property calculation system.
This module provides a clean way to compute properties needed for model configuration at runtime, with proper handling of partition-dependent vs partition-independent properties.
- chemtorch.core.property_system.compute_property_with_dataset_handling(property_instance: DatasetProperty, dataset: DatasetBase | Dict[str, DatasetBase] | List[DatasetBase]) Any[source]#
Compute a property while handling edge cases where dataset might be a dict or list.
For dict datasets (multiple named test datasets), computes the property for the first dataset. For list datasets (multiple test datasets), computes the property for the first dataset. For single datasets, computes the property directly.
- Parameters:
property_instance – The property to compute
dataset – The dataset(s) to compute the property for
- Returns:
The computed property value
- chemtorch.core.property_system.resolve_sources(source: Literal['any', 'all', 'train', 'val', 'test', 'predict'], tasks: List[Literal['fit', 'validate', 'test', 'predict']]) List[Literal['train', 'val', 'test', 'predict']][source]#
Resolve a PropertySource into a list of DatasetKeys.
- class chemtorch.core.property_system.DatasetProperty(name: str, source: Literal['any', 'all', 'train', 'val', 'test', 'predict'], log: bool = False, add_to_cfg: bool = False)[source]#
Bases: ABC
Base class for dataset properties that can be computed at runtime.
- __init__(name: str, source: Literal['any', 'all', 'train', 'val', 'test', 'predict'], log: bool = False, add_to_cfg: bool = False) None[source]#
- abstract compute(dataset: DatasetBase) Any[source]#
Compute the property value from the dataset.
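For illustration, a hedged sketch of a custom property used together with the helper above; NumSamples is hypothetical and assumes the dataset supports len():

# Hedged sketch; `NumSamples` is a hypothetical property, and `train_dataset` stands in
# for a DatasetBase instance (e.g. dm.get_dataset("train")).
from chemtorch.core.dataset_base import DatasetBase
from chemtorch.core.property_system import (
    DatasetProperty,
    compute_property_with_dataset_handling,
)


class NumSamples(DatasetProperty):
    def compute(self, dataset: DatasetBase) -> int:
        return len(dataset)


prop = NumSamples(name="num_samples", source="train", log=True, add_to_cfg=False)
# Handles single datasets as well as dicts/lists of (test) datasets.
value = compute_property_with_dataset_handling(prop, train_dataset)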
- class chemtorch.core.property_system.PrecomputeTime(name: str, source: Literal['any', 'all', 'train', 'val', 'test', 'predict'], log: bool = False, add_to_cfg: bool = False)[source]#
Bases: DatasetProperty
- compute(dataset: DatasetBase) float[source]#
Compute the property value from the dataset.
- class chemtorch.core.property_system.NumNodeFeatures(name: str, source: Literal['any', 'all', 'train', 'val', 'test', 'predict'], log: bool = False, add_to_cfg: bool = False)[source]#
Bases: DatasetProperty
- compute(dataset: DatasetBase[Data, Any]) int[source]#
Compute the property value from the dataset.
- class chemtorch.core.property_system.NumEdgeFeatures(name: str, source: Literal['any', 'all', 'train', 'val', 'test', 'predict'], log: bool = False, add_to_cfg: bool = False)[source]#
Bases: DatasetProperty
- compute(dataset: DatasetBase) int[source]#
Compute the property value from the dataset.
- class chemtorch.core.property_system.FingerprintLength(name: str, source: Literal['any', 'all', 'train', 'val', 'test', 'predict'], log: bool = False, add_to_cfg: bool = False)[source]#
Bases: DatasetProperty
- compute(dataset: DatasetBase) int[source]#
Compute the property value from the dataset.
- class chemtorch.core.property_system.VocabSize(name: str, source: Literal['any', 'all', 'train', 'val', 'test', 'predict'], log: bool = False, add_to_cfg: bool = False)[source]#
Bases: DatasetProperty
- compute(dataset: DatasetBase[Any, AbstractTokenRepresentation]) int[source]#
Compute the property value from the dataset.
- class chemtorch.core.property_system.LabelMean(name: str, source: Literal['any', 'all', 'train', 'val', 'test', 'predict'], log: bool = False, add_to_cfg: bool = False)[source]#
Bases: DatasetProperty
- compute(dataset: DatasetBase) float[source]#
Compute the property value from the dataset.
- class chemtorch.core.property_system.LabelStd(name: str, source: Literal['any', 'all', 'train', 'val', 'test', 'predict'], log: bool = False, add_to_cfg: bool = False)[source]#
Bases: DatasetProperty
- compute(dataset: DatasetBase) float[source]#
Compute the property value from the dataset.