import os
from torch_geometric.data import DataLoader
from ogb.graphproppred import PygGraphPropPredDataset
from bgd.real.facebook_dataset import FacebookDataset
from bgd.real.twitch_ego_dataset import TwitchEgoDataset
from bgd.real.cora_dataset import CoraDataset
from bgd.real.neural_dataset import NeuralDataset
from bgd.real.pennsylvania_road_dataset import PennsylvaniaRoadDataset
from bgd.real.reddit_dataset import RedditDataset
from bgd.real.from_ogb_dataset import FromOGBDataset, from_ogb_dataset
from bgd.real.from_tu_dataset import from_tu_dataset
from bgd.real.livejournal_dataset import LivejournalDataset
from bgd.synthetic.random_dataset import RandomDataset
from bgd.synthetic.tree_dataset import TreeDataset
from bgd.synthetic.lattice_dataset import LatticeDataset
from bgd.synthetic.community_dataset import CommunityDataset
[docs]
def get_datasets(transforms, num, stage="train", exclude=None, include=None):
    r"""
    Build the registered datasets, filtered by name, and apply ``transforms`` to each.

    Parameters:
        transforms (callable): Called once with every constructed dataset; the
            value it returns is what ends up in the returned ``datasets`` list.
        num (int): Forwarded to every dataset constructor as ``num`` (the number
            of data points to include in each dataset).
        stage (str, optional): Forwarded to every dataset constructor as ``stage``
            (the helpers in this module use "train", "val" and "test").
            Default is "train".
        exclude (list or str, optional): Dataset name(s) to drop from the
            selection. A single string is treated as a one-element list.
            If None or empty, nothing is excluded. Default is None.
        include (list or str, optional): Dataset name(s) to keep. A single string
            is treated as a one-element list. If None or empty, every dataset
            that survived ``exclude`` is kept. Default is None.

    Returns:
        tuple: ``(datasets, names)`` where ``datasets`` is the list of transformed
        datasets and ``names`` is the list of their names, in the same order.

    Notes:
        - ``exclude`` is applied first, then ``include``; a name present in both
          is therefore dropped.
        - Names that are not registered here are silently ignored in both filters.
        - The output follows the registration order below, not the order of
          ``include``.
        - A "bgd_files" directory is created in the current working directory if
          it does not exist; each dataset is given ``bgd_files/<name>`` as its
          first constructor argument.
        - Besides the predefined datasets, a fixed set of Open Graph Benchmark
          (OGB) and TU dataset names is registered.

    Example:
        >>> def dummy_transform(dataset):
        ...     return dataset
        >>> datasets, names = get_datasets(dummy_transform, num=100, stage="train", exclude=["reddit"], include=["cora", "trees"])
        >>> print(names)
        ['cora', 'trees']
    """
    # makedirs(exist_ok=True) replaces the former listdir()/mkdir() pair: the
    # check-then-create sequence could raise FileExistsError when several
    # processes started at the same time.
    data_root = os.path.join(os.getcwd(), "bgd_files")
    os.makedirs(data_root, exist_ok=True)

    all_datasets = {
        "facebook_large": FacebookDataset,
        "reddit": RedditDataset,
        "roads": PennsylvaniaRoadDataset,
        "twitch_egos": TwitchEgoDataset,
        "cora": CoraDataset,
        "fruit_fly": NeuralDataset,
        "livejournal": LivejournalDataset,
        "trees": TreeDataset,
        "random": RandomDataset,
        "community": CommunityDataset,
    }

    # Every OGB / TU name shares one factory; the dataset name reaches the
    # factory only as the last component of the path argument.
    ogb_names = ["ogbg-molpcba", "ogbg-molesol", "ogbg-molclintox",
                 "ogbg-molfreesolv", "ogbg-mollipo", "ogbg-molhiv",
                 "ogbg-molbbbp", "ogbg-molbace"]
    all_datasets.update({name: from_ogb_dataset for name in ogb_names})

    tu_names = ["MUTAG", "ENZYMES", "PROTEINS", "COLLAB", "IMDB-BINARY", "REDDIT-BINARY"]
    all_datasets.update({name: from_tu_dataset for name in tu_names})

    names = list(all_datasets)

    if exclude:
        if isinstance(exclude, str):
            exclude = [exclude]
        names = [name for name in names if name not in exclude]

    if include:
        if isinstance(include, str):
            include = [include]
        names = [name for name in names if name in include]

    datasets = [
        transforms(all_datasets[name](os.path.join(data_root, name), num=num, stage=stage))
        for name in names
    ]

    return datasets, names
[docs]
def get_node_task_datasets(transforms, num=5000, stage="train"):
    """
    Select the datasets with node-level tasks, both classification and regression.

    Thin wrapper around ``get_datasets`` that restricts the selection to
    "facebook_large" and "cora".

    Args:
        transforms (callable): Passed through to ``get_datasets``, which calls it
            on each constructed dataset.
        num (int, optional): Passed through to ``get_datasets`` as ``num``.
            Defaults to 5000.
        stage (str, optional): Passed through to ``get_datasets`` as ``stage``.
            Defaults to "train".

    Returns:
        tuple: The ``(datasets, names)`` pair produced by ``get_datasets`` for
        this selection.
    """
    node_level_names = ["facebook_large", "cora"]
    return get_datasets(transforms, num, stage=stage, include=node_level_names)
[docs]
def get_edge_task_datasets(transforms, num=5000, stage="train"):
    """
    Select the datasets with edge-level tasks, both regression and classification.

    Thin wrapper around ``get_datasets`` that restricts the selection to
    "fruit_fly" and "reddit".

    Args:
        transforms (callable): Passed through to ``get_datasets``, which calls it
            on each constructed dataset.
        num (int, optional): Passed through to ``get_datasets`` as ``num``.
            Defaults to 5000.
        stage (str, optional): Passed through to ``get_datasets`` as ``stage``.
            Defaults to "train".

    Returns:
        tuple: The ``(datasets, names)`` pair produced by ``get_datasets`` for
        this selection.
    """
    edge_level_names = ["fruit_fly", "reddit"]
    return get_datasets(transforms, num, stage=stage, include=edge_level_names)
[docs]
def get_graph_task_datasets(transforms, num=5000, stage="train"):
    """
    Select the datasets with graph-level tasks, both regression and classification.

    Thin wrapper around ``get_datasets`` that removes the node-level and
    edge-level datasets ("facebook_large", "cora", "fruit_fly", "reddit") and
    keeps everything else that ``get_datasets`` registers.

    Args:
        transforms (callable): Passed through to ``get_datasets``, which calls it
            on each constructed dataset.
        num (int, optional): Passed through to ``get_datasets`` as ``num``.
            Defaults to 5000.
        stage (str, optional): Passed through to ``get_datasets`` as ``stage``.
            Defaults to "train".

    Returns:
        tuple: The ``(datasets, names)`` pair produced by ``get_datasets`` for
        this selection.
    """
    node_and_edge_level_names = ["facebook_large", "cora", "fruit_fly", "reddit"]
    return get_datasets(transforms, num, stage=stage, exclude=node_and_edge_level_names)
[docs]
def get_graph_classification_datasets(transforms, num=5000, stage="train"):
    """
    Select the datasets with graph classification tasks.

    Thin wrapper around ``get_datasets``: everything it registers is kept except
    the names listed below.

    Args:
        transforms (callable): Passed through to ``get_datasets``, which calls it
            on each constructed dataset.
        num (int, optional): Passed through to ``get_datasets`` as ``num``.
            Defaults to 5000.
        stage (str, optional): Passed through to ``get_datasets`` as ``stage``.
            Defaults to "train".

    Returns:
        tuple: The ``(datasets, names)`` pair produced by ``get_datasets`` for
        this selection.
    """
    not_graph_classification = [
        # Datasets dropped for being node-level or edge-level.
        "facebook_large", "cora", "fruit_fly", "reddit",
        # OGB datasets dropped from the classification selection.
        "ogbg-molesol", "ogbg-molfreesolv", "ogbg-mollipo",
        # Synthetic datasets dropped from the classification selection.
        "community", "trees", "random",
    ]
    return get_datasets(transforms, num, stage=stage, exclude=not_graph_classification)
[docs]
def get_graph_regression_datasets(transforms, num=5000, stage="train"):
    """
    Select the datasets with graph regression tasks.

    Thin wrapper around ``get_datasets``: everything it registers is kept except
    the names listed below.

    Args:
        transforms (callable): Passed through to ``get_datasets``, which calls it
            on each constructed dataset.
        num (int, optional): Passed through to ``get_datasets`` as ``num``.
            Defaults to 5000.
        stage (str, optional): Passed through to ``get_datasets`` as ``stage``.
            Defaults to "train".

    Returns:
        tuple: The ``(datasets, names)`` pair produced by ``get_datasets`` for
        this selection.
    """
    # Several names here (e.g. "ogbg-moltox21", "ogbg-molmuv") are not registered
    # by get_datasets; excluding them is a harmless no-op.
    # NOTE(review): "twitch_egos", "roads" and "livejournal" are excluded neither
    # here nor in get_graph_classification_datasets, so both selections return
    # them - confirm that is intended.
    not_graph_regression = [
        "facebook_large", "cora", "fruit_fly", "reddit",
        "ogbg-molpcba", "ogbg-molhiv", "ogbg-moltox21",
        "ogbg-molbace", "ogbg-molbbbp", "ogbg-molclintox",
        "ogbg-molmuv", "ogbg-molsider", "ogbg-moltoxcast",
        "MUTAG", "ENZYMES", "PROTEINS",
        "COLLAB", "IMDB-BINARY", "REDDIT-BINARY",
    ]
    return get_datasets(transforms, num, stage=stage, exclude=not_graph_regression)
[docs]
def get_test_datasets(transforms, num=2000, mol_only=False, exclude=("community", "trees", "random"), include=None):
    """
    Get the test split of each dataset.

    Calls ``get_datasets`` with ``stage="test"``.

    Args:
        transforms (callable): Passed through to ``get_datasets``, which calls it
            on each constructed dataset.
        num (int): Number of samples in datasets to include (default is 2000).
        mol_only (bool): Intended to restrict the selection to chemical datasets,
            but currently never read by this function (default is False).
        exclude (sequence or str, optional): Dataset name(s) to leave out.
            Defaults to the synthetic datasets "community", "trees" and "random".
            Pass None to exclude nothing.
        include (list or str, optional): Dataset name(s) to keep; None keeps
            everything that was not excluded (default is None).

    Returns:
        tuple: A tuple containing two elements:
            - datasets (list): List of test datasets.
            - names (list): List of dataset names.
    """
    # The default for `exclude` is a tuple rather than a list so the object
    # shared between calls cannot be mutated; get_datasets only tests membership.
    return get_datasets(transforms, num, stage="test", include=include, exclude=exclude)
[docs]
def get_val_datasets(transforms, num=2000, mol_only=False, exclude=("community", "trees", "random"), include=None):
    """
    Get validation splits for each dataset.

    Calls ``get_datasets`` with ``stage="val"``.

    Args:
        transforms (callable): Passed through to ``get_datasets``, which calls it
            on each constructed dataset.
        num (int, optional): Number of samples in datasets to include.
            Defaults to 2000.
        mol_only (bool, optional): Intended to restrict the selection to chemical
            datasets, but currently never read by this function.
            Defaults to False.
        exclude (sequence or str, optional): Dataset name(s) to leave out.
            Defaults to the synthetic datasets "community", "trees" and "random".
            Pass None to exclude nothing.
        include (list or str, optional): Dataset name(s) to keep; None keeps
            everything that was not excluded. Defaults to None.

    Returns:
        tuple: A tuple containing two elements:
            - datasets (list): List of validation datasets.
            - names (list): List of dataset names.
    """
    # The default for `exclude` is a tuple rather than a list so the object
    # shared between calls cannot be mutated; get_datasets only tests membership.
    return get_datasets(transforms, num, stage="val", include=include, exclude=exclude)
[docs]
def get_train_datasets(transforms, num=2000, mol_only=False, exclude=("ogbg-molpcba",), include=None):
    """
    Get the training splits of each dataset.

    Calls ``get_datasets`` with ``stage="train"``.

    Args:
        transforms (callable): Passed through to ``get_datasets``, which calls it
            on each constructed dataset.
        num (int): Passed through to ``get_datasets`` as ``num`` (default is 2000).
        mol_only (bool): Intended to restrict the selection to chemical datasets,
            but currently never read by this function (default is False).
        exclude (sequence or str, optional): Dataset name(s) to leave out.
            Defaults to "ogbg-molpcba" only. Pass None to exclude nothing.
        include (list or str, optional): Dataset name(s) to keep; None keeps
            everything that was not excluded (default is None).

    Returns:
        tuple: A tuple containing two elements:
            - datasets (list): A list of the training datasets.
            - names (list): A list of names corresponding to each dataset.
    """
    # The default for `exclude` is a tuple rather than a list so the object
    # shared between calls cannot be mutated; get_datasets only tests membership.
    return get_datasets(transforms, num, stage="train", exclude=exclude, include=include)
def remove_duplicates_in_place(list1, list2):
    """
    Drop repeated entries of ``list2`` and the same positions of ``list1``, in place.

    ``list2`` supplies the keys: the first occurrence of each value is kept and
    every later occurrence is deleted. Whenever position ``i`` is deleted from
    ``list2``, position ``i`` is deleted from ``list1`` as well, so the two lists
    stay aligned.

    Args:
        list1 (list): Companion list, pruned at the same positions as ``list2``.
        list2 (list): Key list; its elements must be hashable.

    Returns:
        None. Both lists are modified in place.
    """
    already_seen = set()
    position = 0
    while position < len(list2):
        key = list2[position]
        if key not in already_seen:
            already_seen.add(key)
            position += 1
            continue
        # Repeated key: delete from both lists and stay on the same index,
        # because the following elements have just shifted left by one.
        del list1[position]
        del list2[position]
[docs]
def get_all_datasets(transforms, num=5000, mol_only=False):
    """
    Get one dataset per name, drawn from the train, validation and test selections.

    The three selections are concatenated in the order train, validation, test
    and then de-duplicated by dataset name, keeping the first occurrence. A name
    present in the training selection is therefore returned only as its training
    split; a validation or test entry survives only when its name did not appear
    in an earlier selection.

    Args:
        transforms (callable): Passed through to the per-split helpers, which
            hand it to ``get_datasets`` to be called on each dataset.
        num (int, optional): ``num`` used for the training selection.
            Defaults to 5000. The validation and test selections are always
            requested with ``num=-1``.
        mol_only (bool, optional): Intended to restrict the selection to chemical
            datasets, but currently never read by this function.
            Defaults to False.

    Returns:
        tuple: A tuple containing two elements:
            - datasets (list): A list of all the datasets.
            - all_names (list): A list of names corresponding to each dataset.
    """
    train_datasets, train_names = get_train_datasets(transforms, num)
    # NOTE(review): -1 presumably means "no sample limit" to the dataset
    # classes - confirm against their implementations.
    val_datasets, val_names = get_val_datasets(transforms, -1)
    # This used to call get_val_datasets a second time (copy-paste slip), so the
    # test split was never actually requested.
    test_datasets, test_names = get_test_datasets(transforms, -1)

    datasets = train_datasets + val_datasets + test_datasets
    all_names = train_names + val_names + test_names

    # Prunes both lists in step so that each dataset name appears exactly once.
    remove_duplicates_in_place(datasets, all_names)

    return datasets, all_names