# Source code for chemtorch.components.transform.graph_transform.randomwalkpe

import torch
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.utils import (
    get_self_loop_attr,
    is_torch_sparse_tensor,
    scatter,
    to_edge_index,
    to_torch_coo_tensor,
    to_torch_csr_tensor,
)

from chemtorch.components.transform.abstract_transform import AbstractTransform


class RandomWalkPETransform(AbstractTransform[Data]):
    """Add random-walk positional encodings (RWPE) to a graph.

    For each node, computes the probability that a length-``k`` random walk
    returns to that node (the diagonal of the ``k``-th power of the
    row-normalized adjacency matrix) for ``k = 1 .. walk_length``, and
    attaches the resulting ``[num_nodes, walk_length]`` tensor to the graph.

    This code includes implementations adapted from PyTorch Geometric
    (https://github.com/pyg-team/pytorch_geometric).
    """

    def __init__(
        self,
        walk_length: int,
        attr_name=None,
        type: str = "graph",
    ) -> None:
        """
        Args:
            walk_length: Number of random-walk steps, i.e. the number of
                positional-encoding channels produced per node.
            attr_name: Name of the attribute on the ``Data`` object under
                which the encoding is stored. If ``None``, the encoding is
                concatenated onto ``data.x`` instead.
            type: Accepted for configuration compatibility; not used by the
                computation itself. NOTE(review): it shadows the builtin
                ``type`` and was previously dropped silently — it is now
                stored as ``self.type`` for introspection; confirm its
                intended semantics.
        """
        super().__init__()
        self.walk_length = walk_length
        self.attr_name = attr_name
        # Keep the value instead of discarding it (backward compatible:
        # callers' signatures are unchanged).
        self.type = type

    def __call__(self, data: Data) -> Data:
        """Compute the random-walk PE for ``data`` and attach it.

        Args:
            data: Graph with a valid ``edge_index`` (and optionally
                ``edge_weight`` and node features ``x``).

        Returns:
            The same ``Data`` object, with the encoding stored under
            ``self.attr_name`` or concatenated to ``data.x``.
        """
        assert data.edge_index is not None
        row, col = data.edge_index
        N = data.num_nodes
        assert N is not None

        # Row-normalize the edge weights so every row of the adjacency
        # matrix sums to 1 (a transition-probability matrix). The clamp
        # protects isolated / degree-0 rows from division by zero.
        if data.edge_weight is None:
            value = torch.ones(data.num_edges, device=row.device)
        else:
            value = data.edge_weight
        value = scatter(value, row, dim_size=N, reduce="sum").clamp(min=1)[row]
        value = 1.0 / value

        if N <= 2_000:  # Dense code path for faster computation:
            adj = torch.zeros((N, N), device=row.device)
            adj[row, col] = value
            loop_index = torch.arange(N, device=row.device)
        elif torch_geometric.typing.NO_MKL:  # pragma: no cover
            # COO fallback when MKL sparse matmul is unavailable.
            adj = to_torch_coo_tensor(data.edge_index, value, size=data.size())
        else:
            adj = to_torch_csr_tensor(data.edge_index, value, size=data.size())

        def get_pe(out: torch.Tensor) -> torch.Tensor:
            # Diagonal of the current walk matrix = per-node return
            # probability; handles both sparse and dense representations.
            # (loop_index is only defined on the dense path, which is the
            # only path that reaches the dense branch below.)
            if is_torch_sparse_tensor(out):
                return get_self_loop_attr(*to_edge_index(out), num_nodes=N)
            return out[loop_index, loop_index]

        # Accumulate diagonals of adj^1 .. adj^walk_length.
        out = adj
        pe_list = [get_pe(out)]
        for _ in range(self.walk_length - 1):
            out = out @ adj
            pe_list.append(get_pe(out))
        pe = torch.stack(pe_list, dim=-1)

        if self.attr_name is None:
            if data.x is not None:
                # Promote 1-D features to a column before concatenating.
                x = data.x.view(-1, 1) if data.x.dim() == 1 else data.x
                data.x = torch.cat([x, pe.to(x.device, x.dtype)], dim=-1)
            else:
                data.x = pe
        else:
            data[self.attr_name] = pe
        return data