"""Source code for ``bgd.synthetic.random_dataset``: a synthetic dataset of random Erdős–Rényi graphs."""

import numpy as np
import networkx as nx
import torch
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.utils import erdos_renyi_graph

def vis_from_pyg(data, filename = None):
    """Draw a PyG-style graph with nodes coloured by the first feature column.

    Args:
        data: object exposing ``edge_index`` (2 x E tensor of node-index
            pairs) and ``x``. The value ``x[:, 0]`` is used as each node's
            colour index on the ``tab20`` colormap (range 0-20).
        filename: if ``None`` the figure is shown interactively. Otherwise it
            is saved to this path and closed.
    """
    edges = data.edge_index.T.cpu().numpy()
    labels = data.x[:, 0].cpu().numpy()
    num_nodes = labels.shape[0]

    g = nx.Graph()
    # Insert every node first, in index order, so isolated nodes are kept and
    # the node order matches ``labels``. If edges were added first, nodes would
    # be ordered by first appearance in ``edges``, with isolated ones last, and
    # the colours would be assigned to the wrong nodes.
    g.add_nodes_from(range(num_nodes))
    g.add_edges_from(edges.tolist())

    fig, ax = plt.subplots(figsize = (6, 6))

    pos = nx.kamada_kawai_layout(g)

    nx.draw_networkx_edges(g, pos = pos, ax = ax)
    # Explicit nodelist guarantees node i is drawn with colour labels[i].
    nx.draw_networkx_nodes(g, pos = pos, nodelist = list(range(num_nodes)),
                           node_color = labels, cmap = "tab20",
                           vmin = 0, vmax = 20, ax = ax)

    ax.axis('off')

    fig.tight_layout()
    if filename is None:
        plt.show()
    else:
        fig.savefig(filename)
        plt.close(fig)


def get_random_graph():
    """Sample one Erdős–Rényi graph with a random size and edge probability.

    Returns:
        tuple: ``(G, rho)``, where

        - ``G`` is a :class:`Data` object holding the sampled ``edge_index``
          and an explicit ``num_nodes``.
        - ``rho`` is the edge probability used for sampling, drawn uniformly
          from ``[0.05, 0.3)``.
    """
    # randint's upper bound is exclusive, so sizes range over 12..127.
    size = int(np.random.randint(low = 12, high = 128))

    rho = 0.05 + 0.25 * np.random.random()

    edge_index = erdos_renyi_graph(size, rho)

    # Store num_nodes explicitly. Otherwise PyG infers it from
    # edge_index.max() + 1, which undercounts whenever the highest-index
    # nodes are isolated (frequent for small, sparse graphs).
    G = Data(edge_index = edge_index, num_nodes = size)

    return G, rho

def get_random_dataset(num = 1000):
    """Generate ``num`` random graphs, each labelled with its edge probability.

    Args:
        num: number of graphs to sample.

    Returns:
        list: the sampled ``Data`` objects. Each has ``y`` set to a
        one-element float tensor holding the ``rho`` it was generated with.
    """
    samples = [get_random_graph() for _ in tqdm(range(num), leave=False)]

    graphs = []
    for graph, density in samples:
        graph.y = torch.Tensor([density])
        graphs.append(graph)

    return graphs





class RandomDataset(InMemoryDataset):
    r"""
    Contributor: Alex O. Davies

    Contributor email: `alexander.davies@bristol.ac.uk`

    Dataset of random Erdős–Rényi graphs with between 12 and 127 nodes
    (inclusive). Each graph's edge probability is drawn uniformly from
    [0.05, 0.3). The target ``y`` of each graph is that edge probability,
    i.e. its expected edge density rather than the realised one.

    Args:
        root (str): Root directory where the dataset should be saved.
        stage (str): The stage of the dataset to load. One of "train", "val",
            "test". (default: :obj:`"train"`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in
            the final dataset. (default: :obj:`None`)
        num (int): The number of random graphs to generate for this stage
            (before ``pre_filter`` is applied). (default: :obj:`2000`)
    """

    def __init__(self, root, stage="train", transform=None, pre_transform=None,
                 pre_filter=None, num = 2000):
        self.num = num
        self.stage = stage
        # Set before super().__init__(): PyG's base constructor may call
        # process(), which reads these attributes.
        self.stage_to_index = {"train": 0, "val": 1, "test": 2}
        super().__init__(root, transform, pre_transform, pre_filter)
        # NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True,
        # which may refuse to unpickle PyG Data objects. Confirm against the
        # pinned torch version.
        self.data, self.slices = torch.load(
            self.processed_paths[self.stage_to_index[self.stage]])

    @property
    def raw_file_names(self):
        # Graphs are generated on the fly; there is nothing to download.
        return []

    @property
    def processed_file_names(self):
        return ['train.pt', 'val.pt', 'test.pt']

    def process(self):
        """Generate, filter, transform and save the graphs for ``self.stage``.

        All three stage files are listed in ``processed_file_names``. The base
        class may therefore call this whenever any of them is missing, so an
        already-generated file for the current stage is left untouched.
        """
        out_path = self.processed_paths[self.stage_to_index[self.stage]]
        if os.path.isfile(out_path):
            print("Random files exist")
            return

        data_list = get_random_dataset(num=self.num)

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), out_path)