Source code for chemtorch.components.model.mlp

from typing import Any, Callable, Dict, List, Optional, Union

import torch
from torch import nn
from torch_geometric.nn.resolver import activation_resolver

from chemtorch.components.layer.utils import init_norm


class MLP(nn.Module):
    """
    Multi-Layer Perceptron (MLP) with configurable layers and activation functions.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_dims: Optional[List[int]] = None,
        hidden_size: Optional[int] = None,
        num_hidden_layers: Optional[int] = None,
        dropout: float = 0.0,
        act: Union[str, Callable, None] = "relu",
        act_kwargs: Optional[Dict[str, Any]] = None,
        norm: Union[str, Callable, None] = None,
        norm_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Initializes an MLP.

        The architecture is built to match FFNHead's structure and instantiation
        order precisely.

        - A Dropout layer is placed before every Linear layer.
        - Activation is applied after every hidden Linear layer.
        - Normalization is applied after every hidden Linear layer (if specified).

        Args:
            in_channels (int): Input dimension.
            out_channels (int): Output dimension.
            hidden_dims (List[int], optional): List of hidden layer dimensions (preferred).
            hidden_size (int, optional): Hidden layer size (used if `hidden_dims` is not provided).
            num_hidden_layers (int, optional): Number of hidden layers (used if `hidden_dims` is not provided).
            dropout (float, optional): Dropout rate. Defaults to 0.
            act (str or Callable, optional): Activation function. Defaults to "relu".
            act_kwargs (Dict[str, Any], optional): Arguments for the activation function.
            norm (str or Callable, optional): Normalization layer. Defaults to None.
            norm_kwargs (Dict[str, Any], optional): Arguments for the normalization layer.
        """
        super().__init__()

        # Validate and resolve hidden_dims
        hidden_dims = self._resolve_hidden_dims(
            hidden_dims, hidden_size, num_hidden_layers
        )

        self.activation = activation_resolver(act, **(act_kwargs or {}))
        self.norm = norm
        self.norm_kwargs = norm_kwargs or {}
        self.dropout = dropout

        layers = []
        current_dim = in_channels

        # Hidden blocks: [Dropout, Linear, Activation, Normalization]
        for hidden_dim in hidden_dims:
            if dropout > 0.0:
                layers.append(nn.Dropout(dropout))
            layers.append(nn.Linear(current_dim, hidden_dim))
            if self.activation is not None:
                layers.append(self.activation)
            if self.norm is not None:
                layers.append(
                    init_norm(self.norm, hidden_dim, **self.norm_kwargs)
                )
            current_dim = hidden_dim

        # Output block: [Dropout, Linear]
        if dropout > 0.0:
            layers.append(nn.Dropout(dropout))
        layers.append(nn.Linear(current_dim, out_channels))

        self.layers = nn.Sequential(*layers)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the MLP."""
        return self.layers(x)
    @staticmethod
    def _resolve_hidden_dims(
        hidden_dims: Optional[List[int]],
        hidden_size: Optional[int],
        num_hidden_layers: Optional[int],
    ) -> List[int]:
        # Predefined error message
        val_err_misspecified_args = ValueError(
            "Specify either hidden_dims OR hidden_size and num_hidden_layers, not both."
        )

        # No hidden-layer arguments at all: the MLP has no hidden layers.
        if hidden_dims is None and hidden_size is None and num_hidden_layers is None:
            return []

        if hidden_dims is not None:
            # hidden_dims is exclusive with hidden_size/num_hidden_layers.
            if hidden_size is not None or num_hidden_layers is not None:
                raise val_err_misspecified_args
            if not isinstance(hidden_dims, list):
                raise ValueError("hidden_dims must be a list or None")
            return hidden_dims
        else:
            # hidden_size and num_hidden_layers must be given together.
            if num_hidden_layers is not None:
                if num_hidden_layers < 1:
                    return []
                elif hidden_size is None:
                    raise val_err_misspecified_args
                return [hidden_size] * num_hidden_layers
            else:
                raise val_err_misspecified_args
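
A minimal usage sketch of the class above. It assumes the module is importable under the path shown in the page title and that the resolved activations ("relu", "gelu") are available through torch_geometric's activation_resolver; the layer widths, batch size, and variable names below are illustrative only.

    import torch

    from chemtorch.components.model.mlp import MLP

    # Two equivalent ways to request three hidden layers of width 128.
    mlp_a = MLP(in_channels=16, out_channels=1, hidden_dims=[128, 128, 128])
    mlp_b = MLP(in_channels=16, out_channels=1, hidden_size=128, num_hidden_layers=3)

    # With no hidden-layer arguments, _resolve_hidden_dims returns [] and the
    # model reduces to a single Linear(in_channels, out_channels) layer.
    linear_only = MLP(in_channels=16, out_channels=1)

    # Dropout places an nn.Dropout before every Linear layer, per the docstring.
    mlp_dropout = MLP(in_channels=16, out_channels=1, hidden_dims=[64, 64],
                      dropout=0.1, act="gelu")

    x = torch.randn(32, 16)   # batch of 32 input feature vectors
    y = mlp_a(x)              # shape: (32, 1)
    print(y.shape)

Note that passing hidden_dims together with hidden_size or num_hidden_layers raises a ValueError, as does passing only one of hidden_size/num_hidden_layers.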