# Source code for bgd.real.from_ogb_dataset

import os
import torch
from torch_geometric.data import InMemoryDataset, Data
from ..utils import describe_one_dataset

from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims
from ogb.graphproppred import PygGraphPropPredDataset
from tqdm import tqdm


# Number of categories in each integer-coded atom / bond feature column, as reported
# by OGB. Each entry is used as `num_classes` when to_onehot_atoms / to_onehot_bonds
# one-hot encode the corresponding column.
full_atom_feature_dims = get_atom_feature_dims()
full_bond_feature_dims = get_bond_feature_dims()



def _onehot_concat(x, feature_dims):
    """One-hot encode every column of a categorical feature matrix and concatenate them.

    Args:
        x (torch.Tensor): 2-D tensor of integer category indices with one column per
            entry of ``feature_dims``; column ``i`` must hold values in
            ``[0, feature_dims[i])``.
        feature_dims (Sequence[int]): number of categories of each column.

    Returns:
        torch.Tensor: int64 tensor of shape ``(x.shape[0], sum(feature_dims))``.
    """
    # F.one_hot only accepts int64 indices; casting lets int32/uint8 feature tensors work
    # too (a no-op for tensors that are already int64).
    x = x.long()
    return torch.cat(
        [
            torch.nn.functional.one_hot(x[:, col], num_classes=n_classes)
            for col, n_classes in enumerate(feature_dims)
        ],
        dim=1,
    )


def to_onehot_atoms(x):
    """One-hot encode OGB atom features.

    Args:
        x (torch.Tensor): integer atom-feature matrix with one row per atom and one
            column per entry of ``full_atom_feature_dims``.

    Returns:
        torch.Tensor: int64 tensor of shape ``(n_atoms, sum(full_atom_feature_dims))``.
    """
    return _onehot_concat(x, full_atom_feature_dims)


def to_onehot_bonds(x):
    """One-hot encode OGB bond features.

    Args:
        x (torch.Tensor): integer bond-feature matrix with one row per bond and one
            column per entry of ``full_bond_feature_dims``.

    Returns:
        torch.Tensor: int64 tensor of shape ``(n_bonds, sum(full_bond_feature_dims))``.
    """
    return _onehot_concat(x, full_bond_feature_dims)

class FromOGBDataset(InMemoryDataset):
    r"""
    Contributor: Alex O. Davies

    Contributor email: `alexander.davies@bristol.ac.uk`

    Converts an Open Graph Benchmark dataset into a `torch_geometric.data.InMemoryDataset`.
    This allows standard dataset operations like concatenation with other datasets.

    The Open Graph Benchmark project is available here:

    `Hu, Weihua, et al. "Open graph benchmark: Datasets for machine learning on graphs."
    Advances in neural information processing systems 33 (2020): 22118-22133.`

    For molecular datasets (dataset names containing ``"mol"``) atom and bond features are
    converted into one-hot encodings. The resulting shapes are:

     - node (atom) features: (N Atoms, 174)
     - edge (bond) features: (N Bonds, 13)

    For other datasets node features are set to ``None`` and edge features are kept unchanged.

    Args:
        root (str): Root directory where the dataset should be saved. Its final path
            component is used as the dataset name.
        ogb_dataset: a `PygGraphPropPredDataset` (or a split of one) to be converted
            to `InMemoryDataset`.
        stage (str): The stage of the dataset to load. One of "train", "val", "test",
            "train-adgcl". Each stage is cached in its own processed file.
            (default: :obj:`"train"`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
        num (int): The number of samples to take from the original dataset. Any
            negative value (e.g. -1) takes all available samples for that stage.
            (default: :obj:`-1`)
    """

    # Every accepted ``stage``, in the same order as ``processed_file_names``:
    # stage ``s`` is cached in ``f"{s}.pt"``, so each key of ``stage_to_index`` has a file.
    _STAGES = ("train", "val", "test", "train-adgcl")

    def __init__(self, root, ogb_dataset, stage="train", num=-1,
                 transform=None, pre_transform=None, pre_filter=None):
        self.ogb_dataset = ogb_dataset
        self.stage = stage
        self.stage_to_index = {name: index for index, name in enumerate(self._STAGES)}
        self.num = num
        # normpath drops a trailing separator so "data/ogbg-molhiv/" still yields
        # "ogbg-molhiv" (the "mol" check in process() depends on this name).
        self.dataset_name = os.path.basename(os.path.normpath(root))
        print(f"Converting OGB stage {self.stage}")
        super().__init__(root, transform, pre_transform, pre_filter)
        # NOTE(review): torch >= 2.6 defaults torch.load(weights_only=True), which may
        # reject pickled Data objects depending on the PyG version — confirm.
        self.data, self.slices = torch.load(self.processed_paths[self.stage_to_index[self.stage]])

    @property
    def raw_file_names(self):
        # Placeholder: this class defines no download(); graphs come from ``ogb_dataset``.
        return ['dummy.csv']

    @property
    def processed_file_names(self):
        return [f"{stage}.pt" for stage in self._STAGES]

    def process(self):
        """Convert ``ogb_dataset`` for the current stage and save it to that stage's file."""
        target_path = self.processed_paths[self.stage_to_index[self.stage]]
        # Only the current stage is converted; if its file already exists there is nothing to do.
        if os.path.isfile(target_path):
            print(f"\nOGB files exist at {target_path}")
            return

        num_samples = len(self.ogb_dataset)
        # A negative ``num`` means "keep everything" (previously -1 silently dropped the
        # last sample via ``[:-1]``); otherwise cap at the number actually available.
        keep_n = num_samples if self.num < 0 else min(self.num, num_samples)
        print(f"Converting OGB dataset to Big Graph Dataset format, keeping {keep_n} samples out of {num_samples}")

        is_molecular = "mol" in self.dataset_name
        data_list = []
        # Index only the items we keep instead of materialising the whole dataset first.
        for idx in tqdm(range(keep_n), leave=False):
            item = self.ogb_dataset[idx]
            if is_molecular:
                data = Data(x=to_onehot_atoms(item.x),
                            edge_index=item.edge_index,
                            edge_attr=to_onehot_bonds(item.edge_attr),
                            y=item.y)
            else:
                # Without node features the node count cannot be read from ``x``; carry it
                # over explicitly so isolated nodes are not lost.
                data = Data(x=None,
                            edge_index=item.edge_index,
                            edge_attr=item.edge_attr,
                            y=item.y,
                            num_nodes=item.num_nodes)
            data_list.append(data)

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), target_path)
def from_ogb_dataset(root, stage="train", num=-1):
    """
    Load a dataset from the Open Graph Benchmark (OGB) and convert it to the Big Graph Dataset format.

    Args:
        root (str): Path whose final component is the OGB dataset name, e.g.
            ``"data/ogbg-molhiv"``. OGB's own files are stored under the parent
            directory; the converted files are stored under ``root``.
            Available names:
            (Classification: "ogbg-molpcba", "ogbg-molhiv", "ogbg-moltox21", "ogbg-molbace",
            "ogbg-molbbbp", "ogbg-molclintox", "ogbg-molmuv", "ogbg-molsider",
            "ogbg-moltoxcast", "ogbg-ppa")
            (Regression: "ogbg-molesol", "ogbg-molfreesolv", "ogbg-mollipo")
        stage (str, optional): The stage of the dataset to load: "train", "val" or
            "test" ("val" selects OGB's "valid" split). Defaults to "train".
        num (int, optional): The number of samples to load. Set to -1 to load all samples. Defaults to -1.

    Returns:
        FromOGBDataset: The converted dataset in the Big Graph Dataset format.
    """
    # normpath strips a trailing separator, so "data/ogbg-molhiv/" and "data/ogbg-molhiv"
    # both give name "ogbg-molhiv" and parent "data" (plain string slicing produced an
    # empty name and root for the trailing-slash form).
    norm_root = os.path.normpath(root)
    name = os.path.basename(norm_root)
    parent = os.path.dirname(norm_root)

    dataset = PygGraphPropPredDataset(name=name, root=parent)
    split_idx = dataset.get_idx_split()
    # FromOGBDataset calls the validation stage "val"; OGB's split dictionary uses "valid".
    dataset = dataset[split_idx["valid" if stage == "val" else stage]]
    return FromOGBDataset(os.path.join(parent, name), dataset, stage=stage, num=num)
if __name__ == "__main__":
    # Smoke run: convert the ESOL training split and print a summary of it.
    # NOTE(review): the module uses a relative import (`from ..utils import ...`), so this
    # presumably has to be launched as a module (python -m ...) rather than as a script — confirm.
    esol_train = from_ogb_dataset("ogbg-molesol", stage="train")
    describe_one_dataset(esol_train)