Source code for bgd.real.twitch_ego_dataset

import json
import os
import networkx as nx
import numpy as np
import pandas as pd
import torch
import torch_geometric as pyg
from torch_geometric.data import InMemoryDataset
import wget
from tqdm import tqdm
from ..utils import describe_one_dataset
import zipfile

def get_twitch(num = 49152, include_targets = False):
    """Load Twitch ego networks as a list of PyG data objects.

    The raw files are read from ``bgd_files/twitch_egos`` relative to the
    current working directory. If ``twitch_edges.json`` is not there, the
    SNAP archive is downloaded into ``bgd_files`` and extracted (the archive
    is assumed to unpack into a ``twitch_egos`` directory). The working
    directory is changed while files are read and is always restored before
    returning, including when an error is raised.

    Args:
        num (int): Maximum number of graphs to load, taken from the start of
            the key order of ``twitch_edges.json``.
        include_targets (bool): Currently unused; kept so existing callers
            keep working.

    Returns:
        list: One object per graph as produced by
        ``pyg.utils.from_networkx``, each with ``y`` set to a one-element
        ``torch.Tensor`` holding that graph's target.

    Raises:
        FileNotFoundError: If ``bgd_files`` does not exist in the current
            working directory.
        KeyError: If a graph id has no row in ``twitch_target.csv``.
    """
    print("\nProcessing twitch egos dataset")
    zip_url = "https://snap.stanford.edu/data/twitch_egos.zip"
    start_dir = os.getcwd()

    # try/finally guarantees the caller's working directory is restored even
    # if the download, extraction or parsing fails part-way through.
    try:
        os.chdir("bgd_files")

        # isfile() also covers the case where the twitch_egos directory does
        # not exist yet (os.listdir would raise FileNotFoundError there).
        if not os.path.isfile(os.path.join("twitch_egos", "twitch_edges.json")):
            print("Downloading Twitch Egos")
            _ = wget.download(zip_url)
            with zipfile.ZipFile("twitch_egos.zip", 'r') as zip_ref:
                zip_ref.extractall(".")
            os.remove("twitch_egos.zip")
        os.chdir("twitch_egos")

        with open("twitch_edges.json", "r") as f:
            all_edges = json.load(f)

        twitch_targets = pd.read_csv("twitch_target.csv")
    finally:
        os.chdir(start_dir)

    # Map graph id -> target using plain Python scalars as keys/values.
    id_to_target = dict(zip(twitch_targets["id"].tolist(),
                            twitch_targets["target"].tolist()))

    # JSON object keys are strings; keep only the first `num` graphs.
    graph_ids = list(all_edges)[:num]

    graphs = []

    print("Entering ego processing loop")
    for graph_id in tqdm(graph_ids, leave = False):
        edges = all_edges[graph_id]

        g = nx.Graph()

        # Insert nodes in sorted order first so the node ordering does not
        # depend on the order in which edges are listed.
        g.add_nodes_from(np.unique(edges).tolist())
        g.add_edges_from((edge[0], edge[1]) for edge in edges)
        graphs.append(g)

    data_objects = [pyg.utils.from_networkx(g) for g in graphs]

    # Look the label up by the graph's own id rather than by its position in
    # the list, so labels stay correct even if the JSON keys are not stored
    # in the order 0, 1, 2, ...
    for graph_id, data in zip(graph_ids, data_objects):
        data.y = torch.Tensor([id_to_target[int(graph_id)]])

    return data_objects

class TwitchEgoDataset(InMemoryDataset):
    r"""
    Contributor: Alex O. Davies

    Contributor email: `alexander.davies@bristol.ac.uk`

    Ego networks from the streaming platform Twitch.
    The original graph is sourced from:

    `B. Rozemberczki, O. Kiss, R. Sarkar: An API Oriented Open-source Python Framework for Unsupervised Learning on Graphs 2019.`

    The task is predicting whether a given streamer plays multiple different games.

    - Task: Graph classification
    - Num node features: None
    - Num edge features: None
    - Num target values: 1
    - Target shape: 1
    - Num graphs: 127094

    Args:
        root (str): Root directory where the dataset should be saved.
        stage (str): The stage of the dataset to load. One of "train", "val", "test". Any other value raises a :obj:`KeyError`. (default: :obj:`"train"`)
        transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`)
        num (int): The number of samples to take from the original dataset. (default: :obj:`5000`).
    """

    def __init__(self, root, stage = "train", transform=None, pre_transform=None, pre_filter=None, num = 5000):
        # These attributes must be set before super().__init__(), because
        # process() reads them.
        self.num = num
        self.stage = stage
        # Position of each stage's file within processed_file_names.
        self.stage_to_index = {"train":0, "val":1, "test":2}
        self.task = "graph-classification"
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[self.stage_to_index[self.stage]])

    @property
    def raw_file_names(self):
        """Names of the raw files the loader reads."""
        # The targets are read from a CSV file by get_twitch, not a JSON one.
        return ['twitch_edges.json', 'twitch_target.csv']

    @property
    def processed_file_names(self):
        """Names of the processed files, one per stage, in stage_to_index order."""
        return ['train.pt', 'val.pt', 'test.pt']

    def process(self):
        """Build and save the processed file for the current stage only."""
        stage_path = self.processed_paths[self.stage_to_index[self.stage]]

        # Nothing to do when this stage's file has already been written.
        if os.path.isfile(stage_path):
            print("Ego files exist")
            return

        # Read data into huge `Data` list.
        # NOTE(review): the stage only influences `include_targets`, so every
        # stage appears to be built from the same first `num` graphs --
        # confirm whether distinct splits are intended.
        data_list = get_twitch(num=self.num, include_targets=self.stage != "train")

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), stage_path)
if __name__ == "__main__":
    # Script entry point: build the training-stage dataset rooted under the
    # current working directory and hand it to describe_one_dataset.
    dataset_root = os.getcwd() + '/bgd_files/' + 'twitch_egos'
    dataset = TwitchEgoDataset(dataset_root, stage = "train")
    describe_one_dataset(dataset)