Source code for bgd.synthetic.community_dataset

import numpy as np
import networkx as nx
import torch
import torch_geometric as pyg
from torch_geometric.utils.convert import to_networkx
from tqdm import tqdm
import os
from torch_geometric.data import InMemoryDataset

def get_community_graph(size=48, proportions=(0.25, 0.25, 0.25, 0.25), P_intra=0.5, P_inter=None):
    """Build a random graph made of several densely connected communities.

    Nodes are numbered consecutively from 0 and split into one community per
    entry of ``proportions``. Every unordered pair of distinct nodes is
    sampled exactly once: pairs inside a community are joined with
    probability ``P_intra``, pairs spanning two communities with probability
    ``P_inter``.

    Args:
        size (int): Nominal total number of nodes.
        proportions (sequence of float): Fraction of ``size`` assigned to each
            community. Each community size is ``int(proportion * size)``, so
            fractional parts are truncated and the graph may contain fewer
            than ``size`` nodes.
        P_intra (float): Edge probability for node pairs in the same community.
        P_inter (float, optional): Edge probability for node pairs in
            different communities. When ``None`` a fresh value is drawn
            uniformly from [0.05, 0.15) on every call.

    Returns:
        tuple: ``(G, P_inter)`` where ``G`` is the generated
        :class:`networkx.Graph` and ``P_inter`` is the inter-community edge
        probability that was actually used.
    """
    # A random default written in the signature is evaluated only once, at
    # definition time, which would freeze it for every later call. Draw here.
    if P_inter is None:
        P_inter = 0.05 + 0.1 * np.random.random()

    community_sizes = (np.array(proportions) * size).astype(int).tolist()

    # Consecutive, non-overlapping node id ranges, one list per community.
    communities = []
    offset = 0
    for community_size in community_sizes:
        communities.append(list(range(offset, offset + community_size)))
        offset += community_size

    G = nx.Graph()
    for nodes in communities:
        G.add_nodes_from(nodes)

    # Intra-community edges. Iterating only over pairs with n1 before n2
    # avoids self-loops and avoids giving each undirected pair two chances,
    # which would raise the effective probability to 1 - (1 - P_intra)**2.
    for nodes in communities:
        for position, n1 in enumerate(nodes):
            for n2 in nodes[position + 1:]:
                if np.random.random() <= P_intra:
                    G.add_edge(n1, n2)

    # Inter-community edges. Each unordered pair of communities is visited
    # once so the realised density matches the returned P_inter.
    for index, ids_1 in enumerate(communities):
        for ids_2 in communities[index + 1:]:
            for n1 in ids_1:
                for n2 in ids_2:
                    if np.random.random() <= P_inter:
                        G.add_edge(n1, n2)

    return G, P_inter


def get_community_dataset(num=1000):
    """Generate ``num`` random community graphs as PyG data objects.

    Each graph comes from ``get_community_graph`` with an inter-community
    edge probability of ``0.05 + 0.05 * u``, where ``u`` is a fresh
    ``np.random.random()`` draw. That probability is stored on the converted
    data object as a one-element tensor ``y``.

    Args:
        num (int): Number of graphs to generate.

    Returns:
        list: The objects returned by ``pyg.utils.from_networkx``, one per
        generated graph, each with ``y`` set.
    """
    generated = [
        get_community_graph(P_inter=0.05 + 0.05*np.random.random())
        for _ in tqdm(range(num), leave=False)
    ]
    graphs = [graph for graph, _ in generated]
    rhos = [rho for _, rho in generated]

    datalist = [pyg.utils.from_networkx(graph) for graph in tqdm(graphs)]

    # Attach the sampled inter-community probability as the target.
    for data, rho in zip(datalist, rhos):
        data.y = torch.Tensor([rho])

    return datalist

class CommunityDataset(InMemoryDataset):
    r"""
    Contributor: Alex O. Davies

    Contributor email: `alexander.davies@bristol.ac.uk`

    Dataset of random community-like graphs, composed of four evenly sized
    subgraphs with a random inter-subgraph edge probability.
    Graphs are generated with an intra-subgraph edge probability parameter of
    0.5 and an inter-subgraph edge probability parameter drawn at random
    between 0.05 and 0.10.

    The target is the inter-subgraph edge probability parameter used to
    generate each graph.

    Args:
        root (str): Root directory where the dataset should be saved.
        stage (str): The stage of the dataset to load. One of "train", "val",
            "test"; any other value raises :obj:`KeyError`.
            (default: :obj:`"train"`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in
            the final dataset. (default: :obj:`None`)
        num (int): The number of graphs to generate for the requested stage,
            before any ``pre_filter`` is applied. (default: :obj:`2000`)
    """

    def __init__(self, root, stage="train", transform=None, pre_transform=None, pre_filter=None, num = 2000):
        # These attributes must be set before the base constructor runs,
        # because `process` reads them.
        # NOTE(review): presumably the base class calls `process` from its
        # constructor when processed files are missing — confirm against the
        # installed torch_geometric version.
        self.num = num
        self.stage = stage
        # Maps a stage name to its position in `processed_file_names`.
        self.stage_to_index = {"train":0, "val":1, "test":2}
        super().__init__(root, transform, pre_transform, pre_filter)
        # NOTE(review): recent PyTorch releases default `torch.load` to
        # `weights_only=True`, which may reject pickled graph data — confirm
        # against the pinned torch version.
        self.data, self.slices = torch.load(self.processed_paths[self.stage_to_index[self.stage]])

    @property
    def raw_file_names(self):
        """No raw files are needed; the graphs are generated synthetically."""
        return []

    @property
    def processed_file_names(self):
        """Processed file names, one per stage, ordered as in ``stage_to_index``."""
        return ['train.pt', 'val.pt', 'test.pt']

    def process(self):
        """Generate and save the graphs for the current stage.

        Only the file belonging to ``self.stage`` is written. If that file
        already exists, nothing is regenerated.
        """
        stage_path = self.processed_paths[self.stage_to_index[self.stage]]

        if os.path.isfile(stage_path):
            print("Community files exist")
            return

        # Read data into huge `Data` list.
        data_list = get_community_dataset(num=self.num)

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), stage_path)