Defining Custom Components#
This page shows how to extend ChemTorch with custom components, highlights recurring design patterns, and points you to real, well-documented implementations in the codebase and API docs to get you started.
Development Workflow#
Use this lightweight checklist before you start coding:
- Check existing components in the API docs and source.
- Implement by inheriting from and reusing existing components wherever possible.
- Test thoroughly: start with unit tests; add integration tests when helpful.
Component Interfaces#
| Component | Required Method | Input → Output | See Also |
|---|---|---|---|
| Data Source | `load()` | → `pd.DataFrame` | API docs |
| Column Mapper | — | — | API docs |
| Data Splitter | `_split()` | `pd.DataFrame` → `DataSplit` | API docs |
| Representation | `construct()` | SMILES → data object | API docs |
| Transform | `__call__()` | data → data | API docs |
| Model | `forward()` | batch → predictions | API docs |
Refer to the API docs for full signatures, parameters, and existing implementations.
Data Pipeline#
A data pipeline is anything that returns a pd.DataFrame or DataSplit when called.
There is no explicit abstract interface—the contract is implicit.
In practice, data pipelines typically load data, apply transformations, and split into train/val/test.
SimpleDataPipeline provides a reference implementation that composes a data source, column mapper, and splitter.
You can use it as a starting point or build your own from scratch.
See Using Your Own Data for a comprehensive walkthrough of building a complete data pipeline.
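To make the implicit contract concrete, here is a minimal sketch of a from-scratch pipeline: a plain callable class that loads a CSV and slices it by fixed ratios. The CSV path, column names, and ratios are illustrative, and a small `NamedTuple` stands in for `chemtorch.utils.DataSplit` so the sketch is self-contained:

```python
from typing import NamedTuple

import pandas as pd


class DataSplit(NamedTuple):
    # Stand-in mirroring chemtorch.utils.DataSplit, so the sketch runs on its own.
    train: pd.DataFrame
    val: pd.DataFrame
    test: pd.DataFrame


class CSVPipeline:
    """Minimal data pipeline: any callable that returns a DataSplit qualifies."""

    def __init__(self, csv_path: str, train_ratio: float = 0.8, val_ratio: float = 0.1):
        self.csv_path = csv_path
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio

    def __call__(self) -> DataSplit:
        df = pd.read_csv(self.csv_path)
        n = len(df)
        train_end = int(n * self.train_ratio)
        val_end = train_end + int(n * self.val_ratio)
        return DataSplit(
            train=df.iloc[:train_end],
            val=df.iloc[train_end:val_end],
            test=df.iloc[val_end:],
        )
```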
You may want to implement your own data source component for your specific data format, or a custom data splitter tailored to your use case. The examples below illustrate both.
Example: custom data source (SQLite database)

```python
import sqlite3
from typing import Optional

import pandas as pd

from chemtorch.components.data_pipeline.data_source import AbstractDataSource


class SQLiteDataSource(AbstractDataSource):
    """Loads chemical data from a SQLite database."""

    def __init__(self, db_path: str, table_name: str, query: Optional[str] = None):
        """
        Args:
            db_path: Path to SQLite database file
            table_name: Name of the table to query
            query: Optional custom SQL query (overrides table_name)
        """
        self.db_path = db_path
        self.table_name = table_name
        self.query = query or f"SELECT * FROM {table_name}"

    def load(self) -> pd.DataFrame:
        """Load data from the SQLite database."""
        conn = sqlite3.connect(self.db_path)
        try:
            return pd.read_sql_query(self.query, conn)
        finally:
            conn.close()
```
Example: custom data splitter (time-based)
The following example demonstrates inheriting from RatioSplitter rather than implementing AbstractDataSplitter directly.
This approach reuses ratio validation and other infrastructure:
```python
import pandas as pd

from chemtorch.components.data_pipeline.data_splitter import RatioSplitter
from chemtorch.utils import DataSplit


class TimeBasedSplitter(RatioSplitter):
    """
    Splits data chronologically based on a timestamp column.

    Useful for time-series data where you want to predict future outcomes.
    Inherits ratio validation from RatioSplitter.
    """

    def __init__(
        self,
        timestamp_col: str,
        train_ratio: float = 0.7,
        val_ratio: float = 0.15,
        test_ratio: float = 0.15,
    ):
        """
        Args:
            timestamp_col: Column name containing timestamps
            train_ratio: Fraction for training (earliest data)
            val_ratio: Fraction for validation
            test_ratio: Fraction for testing (most recent data)
        """
        super().__init__(
            train_ratio=train_ratio,
            val_ratio=val_ratio,
            test_ratio=test_ratio,
        )
        self.timestamp_col = timestamp_col

    def _split(self, df: pd.DataFrame) -> DataSplit:
        # Sort by timestamp instead of shuffling randomly
        df_sorted = df.sort_values(by=self.timestamp_col).reset_index(drop=True)
        n = len(df_sorted)
        train_end = int(n * self.train_ratio)
        val_end = train_end + int(n * self.val_ratio)
        train_df = df_sorted.iloc[:train_end]
        val_df = df_sorted.iloc[train_end:val_end]
        test_df = df_sorted.iloc[val_end:]
        return DataSplit(train=train_df, val=val_df, test=test_df)
```
To further reduce duplication, `RatioSplitter` could expose a `_sort()` method with random shuffling as the default implementation and call it from `_split()`.
`TimeBasedSplitter` would then only override `_sort()` instead of re-implementing the whole `_split()` method.
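That refactoring could look roughly like this. It is a hypothetical sketch: the real RatioSplitter does not currently expose a `_sort()` hook, and the simplified classes below (including a `NamedTuple` stand-in for `DataSplit`) replace the ChemTorch ones so the example runs on its own:

```python
from typing import NamedTuple

import pandas as pd


class DataSplit(NamedTuple):
    # Stand-in for chemtorch.utils.DataSplit, kept here for a self-contained sketch.
    train: pd.DataFrame
    val: pd.DataFrame
    test: pd.DataFrame


class RatioSplitter:
    """Simplified stand-in; the real class also validates that ratios sum to 1."""

    def __init__(self, train_ratio: float = 0.7, val_ratio: float = 0.15, test_ratio: float = 0.15):
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self.test_ratio = test_ratio

    def _sort(self, df: pd.DataFrame) -> pd.DataFrame:
        # Default ordering: random shuffle (seeded here only for reproducibility).
        return df.sample(frac=1, random_state=0).reset_index(drop=True)

    def _split(self, df: pd.DataFrame) -> DataSplit:
        df = self._sort(df)  # subclasses customize ordering, not slicing
        n = len(df)
        train_end = int(n * self.train_ratio)
        val_end = train_end + int(n * self.val_ratio)
        return DataSplit(df.iloc[:train_end], df.iloc[train_end:val_end], df.iloc[val_end:])


class TimeBasedSplitter(RatioSplitter):
    """Only the ordering changes; the slicing logic is inherited."""

    def __init__(self, timestamp_col: str, **kwargs):
        super().__init__(**kwargs)
        self.timestamp_col = timestamp_col

    def _sort(self, df: pd.DataFrame) -> pd.DataFrame:
        return df.sort_values(by=self.timestamp_col).reset_index(drop=True)
```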
Representation#
Representations convert SMILES strings into data structures suitable for neural networks (graphs, tensors, etc.). Below is a high-level overview of representation classes, how they are implemented, and source code examples to get you started.
| Representation | Description | Examples |
|---|---|---|
| Fingerprint | Fixed-length binary or count vectors encoding molecular substructures occurring in (reaction) SMILES. | — |
| Graph (CGR) | Graph representation encoding molecular connectivity as nodes (atoms) and edges (bonds). Uses composable featurizers to extract node and edge features. For reactions, uses a condensed graph of reaction (CGR) or similar. | — |
| Token | Sequence-based representation encoding molecules/reactions as sequences of discrete tokens (similar to text). Uses an external vocabulary file, which must first be created and validated (see the API docs). | — |
| 3D Graph | Graph representation with 3D spatial coordinates capturing molecular geometry and conformational information. Requires external resources (see the API docs). | — |
See chemtorch.components.representation for full API details and implementations.
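To illustrate the shape of a representation without pulling in cheminformatics dependencies, here is a toy fingerprint that hashes character bigrams of a SMILES string into a fixed-length count vector. The class name and hashing scheme are made up for illustration; a real implementation would subclass `AbstractRepresentation` and use proper molecular featurization:

```python
import zlib


class NGramFingerprint:
    """Toy representation: hashed character-bigram counts of a SMILES string.

    Illustrative only; real ChemTorch representations subclass
    AbstractRepresentation and use chemistry-aware featurizers.
    """

    def __init__(self, n_bits: int = 64):
        self.n_bits = n_bits

    def construct(self, smiles: str) -> list:
        if not smiles:
            raise ValueError(f"Invalid SMILES: {smiles!r}")
        counts = [0] * self.n_bits
        for a, b in zip(smiles, smiles[1:]):
            # Hash each bigram into a fixed-size bucket (crc32 is deterministic).
            counts[zlib.crc32((a + b).encode()) % self.n_bits] += 1
        return counts
```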
Best Practices#
Type Hints: Always specify the generic type parameter indicating the type of the produced data object:
```python
class MyRep(AbstractRepresentation[torch.Tensor]):  # ✓ Good
    ...

class MyRep(AbstractRepresentation):  # ✗ Missing type info
    ...
```
Error Handling: Validate SMILES and provide clear error messages:
```python
def construct(self, smiles: str):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES: {smiles}")
    # ... rest of implementation
```
Statefulness: Keep representations stateless. Don’t store mutable state:
```python
# ✓ Good: parameters are immutable
def __init__(self, radius: int = 2):
    self.radius = radius

# ✗ Bad: mutable state
def __init__(self):
    self.cache = {}  # Don't cache results
```
Transform / Augmentation#
Transforms are commonly used in computer vision to preprocess or augment images (e.g., normalization, rotation, cropping). In ChemTorch, transforms serve a similar purpose for molecular representations. For example, graph transforms modify graph structure or features (e.g., positional encodings, dummy nodes, feature normalization) and 3D transforms modify spatial coordinates (e.g., random noise, rotation, translation).
Transforms can be composed into pipelines via CallableCompose.
You can also leverage transforms to create additional test data loaders by providing a list or dict of test transforms in the DataModule.
Each entry becomes its own test dataset/dataloader with the specified transform applied.
Augmentations are built on top of transforms to augment training data with transformed versions, making models invariant to specific perturbations.
See chemtorch.components.transform and chemtorch.components.augmentation for available transforms and augmentations, and consult the API docs for implementation details.
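The composition idea can be sketched in plain Python. Below, a small dict stands in for a PyG `Data` object, and the `Compose` class mirrors the role of `CallableCompose` (whose actual API may differ; check the API docs). The transform avoids mutating its input, following the immutability best practice:

```python
import random


class AddCoordinateNoise:
    """Toy 3D transform: jitter positions. A dict stands in for a PyG Data object."""

    def __init__(self, sigma: float = 0.01, seed: int = 0):
        self.sigma = sigma
        self.rng = random.Random(seed)

    def __call__(self, data: dict) -> dict:
        data = dict(data)  # copy instead of mutating the input
        data["pos"] = [x + self.rng.gauss(0.0, self.sigma) for x in data["pos"]]
        return data


class Compose:
    """Chains transforms left to right, mirroring the idea behind CallableCompose."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, data):
        for transform in self.transforms:
            data = transform(data)
        return data
```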
Best Practices#
Immutability: Avoid modifying the input in-place. Clone if needed:
```python
def __call__(self, data: Data) -> Data:
    data = data.clone()  # ✓ Safe
    data.x = data.x * 2
    return data
```
Type Consistency: Preserve the object type:
```python
def __call__(self, data: Data) -> Data:
    # Return same type
    return data
```
Docstrings: Document what the transform does:
```python
class MyTransform(AbstractTransform[Data]):
    """
    Brief description of what this transform does.

    Args:
        param1: Description
        param2: Description
    """
```
Model#
ChemTorch models are standard PyTorch modules with no special abstract base class.
The only requirement is that the forward() method accepts batched inputs matching the representation type and returns predictions as torch.Tensor with shape (batch_size, output_dim).
For example, graph models typically accept PyTorch Geometric Batch objects.
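Under that contract, a minimal model is just an `nn.Module`. This illustrative MLP (names and sizes are made up) assumes a fixed-length input such as a fingerprint representation:

```python
import torch
from torch import nn


class MLPModel(nn.Module):
    """Minimal ChemTorch-compatible model: a plain nn.Module, no special base class.

    forward() takes a batched tensor and returns predictions of shape
    (batch_size, output_dim).
    """

    def __init__(self, input_dim: int, hidden_dim: int = 32, output_dim: int = 1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        return self.net(batch)
```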
Best Practices#
- Device Handling: Let PyTorch Lightning handle device placement (don't call `.to(device)` manually).
- Documentation: Document the `__init__()` method and the input/output shapes and requirements of the `forward()` method.
- Modularization: Start simple and prototype your neural network in a single file, but modularize it as variants emerge. The `GNN` is a perfect example since it decomposes into encoder blocks, a configurable layer stack, pooling, and a head. See `chemtorch.components.model.gnn` and `chemtorch.components.layer`. This decomposition makes it easy to swap components, run ablations, and scale complexity while keeping everything testable.
Common Patterns#
| Pattern | Description | Examples |
|---|---|---|
| Composition | Build complex behavior from small, focused components. | — |
| Orchestration & Delegation | Keep top-level components lean; orchestrate and delegate specifics. | — |
| External Artifact Contracts | Components require external data/artifacts referenced via paths. | — |
Next Steps#
- Create default configs for your components to make them swappable from the CLI; see Understanding Configs and the `conf` folder for composition patterns and overrides.
- Add integration tests to validate that your components work with the data pipelines, transforms, and models end-to-end and to detect regressions; see Testing.
- Contribute your tested components to help us grow ChemTorch; see Contributing.
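As a concrete starting point, a default config for the TimeBasedSplitter above might look like this Hydra-style YAML. The file location, config group, and `_target_` module path are assumptions for illustration; mirror the layout you find in the `conf` folder:

```yaml
# conf/data_splitter/time_based.yaml  (hypothetical location)
_target_: my_project.splitters.TimeBasedSplitter  # adjust to your module path
timestamp_col: timestamp
train_ratio: 0.7
val_ratio: 0.15
test_ratio: 0.15
```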