# Source code for bgd.real.reddit_dataset

import os
import pickle
import wget
import pandas as pd
import networkx as nx
from tqdm import tqdm
import numpy as np
from ..utils import ESWR
import torch
from torch_geometric.data import Data, InMemoryDataset


def fix_property_string(input_string):
    """Parse a comma-separated string of numbers into a 1-D float NumPy array.

    Args:
        input_string: text such as ``"1,2.5,-3"``.

    Returns:
        numpy.ndarray of dtype float64, one entry per comma-separated token.

    Raises:
        ValueError: if any token is not parseable by ``float`` (including an
            empty string).
    """
    tokens = input_string.split(',')
    return np.fromiter(map(float, tokens), dtype=float)

def download_reddit():
    """Build (or load from cache) the largest connected component of the Reddit hyperlink graph.

    All work happens inside ``bgd_files/reddit`` relative to the current
    working directory; the directory is created if missing. The first call
    downloads the SNAP title-hyperlink TSV and the subreddit-embedding CSV
    (unless already present), builds an undirected ``networkx.Graph`` and
    pickles it to ``reddit-graph.npz``. Despite the extension, that file is a
    pickle, not a NumPy archive. Later calls just unpickle that file.

    Graph contents:
        * node attribute ``attrs``: the subreddit's embedding vector (floats).
        * edge attribute ``labels``: the ``LINK_SENTIMENT`` value of the link.
        * edge attribute ``attr``: the parsed ``PROPERTIES`` vector.

    Only links whose two subreddits both have an embedding are kept. Because
    the graph is a simple undirected ``nx.Graph``, repeated links between the
    same pair overwrite one another (the last one wins).

    The caller's working directory is always restored, even on error.

    Returns:
        networkx.Graph: the largest connected component, relabelled to
        consecutive integer node ids.
    """
    print("Getting reddit networkx graph")
    start_dir = os.getcwd()
    work_dir = os.path.join("bgd_files", "reddit")
    os.makedirs(work_dir, exist_ok=True)
    os.chdir(work_dir)
    try:
        return _load_or_build_reddit_graph()
    finally:
        # Avoids weird directory problems: never leave the process chdir'd
        # into bgd_files/reddit, even if a download or parse step raises.
        os.chdir(start_dir)


def _load_or_build_reddit_graph():
    """Return the cached graph from the current directory, or build and cache it."""
    cache_name = "reddit-graph.npz"
    if os.path.exists(cache_name):
        # NOTE: unpickling can execute arbitrary code. This file is only ever
        # written by this module; do not replace it with untrusted data.
        with open(cache_name, "rb") as f:
            return pickle.load(f)

    graph = _build_reddit_graph()

    # Save the graph!
    with open(cache_name, "wb") as f:
        pickle.dump(graph, f)
    return graph


def _build_reddit_graph():
    """Download (if needed) the raw SNAP files in the cwd and build the graph."""
    graph_file = "soc-redditHyperlinks-title.tsv"
    embedding_file = "web-redditEmbeddings-subreddits.csv"

    # wget.download saves into the cwd using the URL's file name, which
    # matches the names checked here.
    if not os.path.exists(graph_file):
        wget.download("https://snap.stanford.edu/data/soc-redditHyperlinks-title.tsv")
    if not os.path.exists(embedding_file):
        wget.download("https://snap.stanford.edu/data/web-redditEmbeddings-subreddits.csv")

    # Each CSV row is a subreddit name (the "COMPONENT" column) followed by
    # 300 embedding components.
    embedding_column_names = ["COMPONENT", *range(300)]
    embeddings = pd.read_csv(embedding_file, names=embedding_column_names).transpose()
    # After transposing, the first row holds the subreddit names: promote it to
    # the header and drop it, leaving one column per subreddit.
    embeddings.columns = embeddings.iloc[0]
    embeddings = embeddings.drop(["COMPONENT"], axis=0)

    graph = nx.Graph()
    for col in tqdm(embeddings.columns, desc="Adding nodes"):
        # Node id is the subreddit name; attrs is its embedding column.
        graph.add_node(col, attrs=embeddings[col].to_numpy().astype(float))

    graph_data = pd.read_csv(graph_file, sep="\t")
    sources = graph_data["SOURCE_SUBREDDIT"].to_numpy()
    targets = graph_data["TARGET_SUBREDDIT"].to_numpy()
    labels = graph_data["LINK_SENTIMENT"].to_numpy()
    # This line can take a while!
    attrs = [
        fix_property_string(properties)
        for properties in tqdm(graph_data["PROPERTIES"].tolist(), desc="Wrangling edge features")
    ]

    # Snapshot of the embedded subreddits as a set: O(1) membership tests
    # instead of scanning a list for every one of the edges.
    known_nodes = set(graph.nodes())
    for source, target, label, attr in tqdm(
        zip(sources, targets, labels, attrs), total=len(sources), desc="Adding edges"
    ):
        if source in known_nodes and target in known_nodes:
            graph.add_edge(source, target, labels=label, attr=attr)

    # Last tidying bits: relabel to integers, keep the largest connected
    # component (first one found on ties), then relabel again so the returned
    # ids are 0..n-1. The first relabel is kept from the original so node
    # ordering is unchanged.
    graph = nx.convert_node_labels_to_integers(graph)
    largest = max(nx.connected_components(graph), key=len)
    return nx.convert_node_labels_to_integers(graph.subgraph(largest))

def get_reddit_dataset(num=2000, max_nodes=96):
    """Sample small graphs from the Reddit hyperlink graph as PyG ``Data`` objects.

    Args:
        num: number of graphs to sample with ESWR (default 2000).
        max_nodes: node budget passed to ESWR for each sampled graph, i.e. the
            maximum graph size (default 96, the previously hard-coded value).

    Returns:
        list of ``torch_geometric.data.Data``, one per sampled graph, built by
        ``specific_from_networkx``.
    """
    graph = download_reddit()

    # Sample `num` graphs of at most `max_nodes` nodes from the big reddit graph.
    nx_graph_list = ESWR(graph, num, max_nodes)

    return [specific_from_networkx(g) for g in nx_graph_list]

def specific_from_networkx(graph):
    """Convert one sampled Reddit subgraph into a PyG ``Data`` object for edge classification.

    Expects each node to carry an ``"attrs"`` feature vector, and each edge an
    ``"attr"`` feature vector and a ``"labels"`` value, as written by
    ``download_reddit``. This conversion is specific to this dataset.

    The resulting ``Data`` has:
        x:          node features, shape (n_nodes, n_node_features), rows in
                    ``graph.nodes()`` order.
        edge_index: shape (2, n_edges); columns are (row of u in x, row of v in x).
                    Each undirected edge appears once, in ``graph.edges()`` order.
        edge_attr:  edge features, shape (n_edges, n_edge_features), same order
                    as edge_index.
        y:          edge targets, shape (n_edges, 1), computed as (label + 1) / 2.
                    This presumably maps sentiment {-1, 1} to {0, 1}.

    Raises:
        RuntimeError: if ``graph`` has no nodes or no edges; ``torch.stack`` of
            an empty list fails. Sampled subgraphs are assumed to be non-empty.
    """
    nodes = list(graph.nodes(data=True))

    # Map node id -> row in x, so edge_index is valid even if the sampled
    # subgraph keeps non-contiguous or out-of-order labels from the parent graph.
    # When nodes are already 0..n-1 in iteration order this is the identity.
    position = {node: i for i, (node, _) in enumerate(nodes)}

    node_attrs = torch.stack([torch.Tensor(data["attrs"]) for _, data in nodes])

    edges = list(graph.edges(data=True))
    edge_indices = torch.tensor(
        [(position[u], position[v]) for u, v, _ in edges], dtype=torch.long
    ).t().contiguous()
    edge_attrs = torch.stack([torch.Tensor(data["attr"]) for _, _, data in edges])

    # Binary edge classification (pos/neg): rescale labels to 0/1 targets.
    edge_labels = ((torch.Tensor([data["labels"] for _, _, data in edges]) + 1) / 2).reshape(-1, 1)

    return Data(x=node_attrs, edge_index=edge_indices, edge_attr=edge_attrs, y=edge_labels)


class RedditDataset(InMemoryDataset):
    r"""Reddit hyperlink graphs, i.e. graphs of subreddits interacting with one another.

    Contributor: Alex O. Davies

    Contributor email: `alexander.davies@bristol.ac.uk`

    The original graph is sourced from:
    `Kumar, Srijan, et al. "Community interaction and conflict on the web."
    Proceedings of the 2018 World Wide Web Conference. 2018.`

    We produce this dataset of small graphs using ESWR.
    The data has text embeddings as node features for each subreddit and text
    features for the cross-post edges. The task is edge classification for the
    sentiment of the interaction between subreddits.

     - Task: Edge classification
     - Num node features: 300
     - Num edge features: 86
     - Num target values: 1
     - Target shape: N Edges
     - Num graphs: Parameterised by `num`

    Args:
        root (str): Root directory where the dataset should be saved.
        stage (str): The stage of the dataset to load. One of "train", "val",
            "test". Each stage is sampled independently with the same
            procedure and cached in its own file. (default: :obj:`"train"`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in
            the final dataset. (default: :obj:`None`)
        num (int): The number of samples to take from the original dataset.
            (default: :obj:`2000`)
    """

    def __init__(self, root, stage="train", transform=None, pre_transform=None, pre_filter=None, num=2000):
        # These must be set before super().__init__: the base constructor may
        # call process(), which reads self.num and self.stage.
        self.num = num
        self.stage = stage
        self.stage_to_index = {"train": 0, "val": 1, "test": 2}
        # Options are node-classification, node-regression, graph-classification,
        # graph-regression, edge-regression, edge-classification.
        self.task = "edge-classification"
        super().__init__(root, transform, pre_transform, pre_filter)
        # An unknown stage raises KeyError here (or earlier, in process()).
        self.data, self.slices = torch.load(self.processed_paths[self.stage_to_index[self.stage]])

    @property
    def raw_file_names(self):
        # Name of the cached networkx pickle produced by download_reddit.
        return ['reddit-graph.npz']

    @property
    def processed_file_names(self):
        # One collated file per stage, indexed by self.stage_to_index.
        return ['train.pt', 'val.pt', 'test.pt']

    def process(self):
        """Sample ``self.num`` graphs and save them for the current stage only.

        Does nothing if this stage's processed file already exists.
        """
        stage_path = self.processed_paths[self.stage_to_index[self.stage]]
        if os.path.isfile(stage_path):
            print("Reddit files exist")
            return

        # Get a list of num pytorch_geometric.data.Data objects.
        data_list = get_reddit_dataset(self.num)

        # Torch geometric stuff.
        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]
        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        # Save data.
        data, slices = self.collate(data_list)
        torch.save((data, slices), stage_path)