import numpy as np
import networkx as nx
import pandas as pd
import torch
import torch_geometric as pyg
from torch_geometric.utils.convert import to_networkx
from tqdm import tqdm
import os
from ..utils import describe_one_dataset, vis_grid
from torch_geometric.data import InMemoryDataset, Data
import inspect
from ..utils import ESWR
from littleballoffur.exploration_sampling import *
import littleballoffur.exploration_sampling as samplers
import sys
import pickle
import zipfile
import wget
def four_cycles(g):
    """
    Count the simple cycles of ``g`` found with a length bound of 4.

    The graph is passed to ``nx.simple_cycles`` with ``length_bound=4``, so
    every simple cycle of length *at most* 4 is counted, not only cycles of
    exactly four nodes. The result is the raw count; it is NOT normalised by
    the number of nodes.

    Args:
        g: A networkx graph.

    Returns:
        int: Number of cycles yielded by ``nx.simple_cycles``.
    """
    # Count lazily instead of materialising every cycle in a list first.
    return sum(1 for _ in nx.simple_cycles(g, length_bound=4))
def load_fly(return_tensor = False):
    """
    Load the fruit fly connectome as a networkx graph, building a cache on first use.

    Works inside ``bgd_files/fruit_fly`` relative to the current working
    directory (``bgd_files`` must already exist; ``fruit_fly`` is created if
    missing). If ``fly_graph.npz`` is present there it is unpickled and
    returned. Otherwise the connectivity matrix CSV is downloaded, turned into
    an undirected weighted graph, reduced to its largest connected component
    with nodes relabelled to consecutive integers, pickled to
    ``fly_graph.npz`` and returned.

    The working directory is changed while the function runs and is always
    restored afterwards, including when an exception is raised.

    Args:
        return_tensor (bool): Currently unused; kept so existing callers
            that pass it keep working.

    Returns:
        networkx.Graph: The cached or freshly built connectome graph.
    """
    data_url = "https://raw.githubusercontent.com/alexodavies/general-gcl/main/fruit_fly/Supplementary-Data-S1/all-all_connectivity_matrix.csv"
    # NOTE: despite the ".npz" extension this file is a pickle, not a numpy archive.
    cache_name = "fly_graph.npz"

    start_dir = os.getcwd()
    print(start_dir)
    try:
        os.chdir("bgd_files")
        os.makedirs("fruit_fly", exist_ok=True)
        os.chdir("fruit_fly")

        if os.path.isfile(cache_name):
            print("fly dataset already exists")
            # SECURITY NOTE: pickle.load executes arbitrary code from the file.
            # Only safe because this cache is written locally by this function;
            # never point it at a file from an untrusted source.
            with open(cache_name, "rb") as f:
                return pickle.load(f)

        print("Downloading fly brain graph")
        # wget saves into the current working directory.
        _ = wget.download(data_url)

        data_path = os.path.join(os.getcwd(), "all-all_connectivity_matrix.csv")
        fly_mat = pd.read_csv(data_path).drop(columns=['Unnamed: 0']).to_numpy()

        # Could be fun to trim only to multiple-synapse connections?
        # Zero the diagonal so no self-loops are created.
        np.fill_diagonal(fly_mat, 0)

        full_graph = nx.from_numpy_array(fly_mat, create_using=nx.Graph)

        # Keep only the largest connected component.
        largest_nodes = max(nx.connected_components(full_graph), key=len)
        nx_graph = nx.convert_node_labels_to_integers(full_graph.subgraph(largest_nodes))
        # Defensive: the diagonal was already zeroed, so this is normally a no-op.
        nx_graph.remove_edges_from(nx.selfloop_edges(nx_graph))

        with open(cache_name, "wb") as f:
            pickle.dump(nx_graph, f)

        return nx_graph
    finally:
        # Restore the caller's working directory on every exit path.
        os.chdir(start_dir)
def specific_from_networkx(graph):
    """
    Convert a weighted networkx graph into a PyTorch Geometric ``Data`` object.

    The ``"weight"`` attribute of every edge becomes the target ``y`` (one
    value per edge), so this is an edge-level regression target. No node or
    edge features are set.

    Args:
        graph: A networkx graph whose edges all carry a ``"weight"`` attribute.
            Node identifiers are written into ``edge_index`` unchanged, so
            they must be integers.

    Returns:
        torch_geometric.data.Data: ``x=None``, ``edge_attr=None``,
        ``edge_index`` of shape ``(2, n_edges)`` (dtype long) and ``y`` of
        shape ``(n_edges, 1)`` in the default floating dtype, in the same
        edge order as ``edge_index``.

    Raises:
        KeyError: If an edge has no ``"weight"`` attribute.
    """
    edge_pairs = []
    edge_weights = []
    # graph.edges(data=True) yields (node_id1, node_id2, {attribute dictionary}).
    for source, target, attrs in graph.edges(data=True):
        edge_pairs.append((source, target))
        edge_weights.append(float(attrs["weight"]))

    # Build the target from plain floats rather than a list of 1-element tensors.
    edge_labels = torch.tensor(edge_weights, dtype=torch.get_default_dtype()).reshape(-1, 1)

    # reshape(-1, 2) keeps the result (2, 0) instead of (0,) when there are no edges.
    edge_indices = torch.tensor(edge_pairs, dtype=torch.long).reshape(-1, 2).t().contiguous()

    # Data fields:
    #   x: node features, shape (n nodes x n features) -- unused here
    #   edge_index: the list of edges, shape (2, n_edges)
    #   edge_attr: edge features, shape (n_edges, n_features) -- unused here
    #   y: targets; for this edge-level task, shape (n_edges, 1)
    data = Data(x=None, edge_index=edge_indices, edge_attr=None, y=edge_labels)
    return data
def get_fly_dataset(num = 2000):
    """
    Build a list of PyTorch Geometric ``Data`` objects sampled from the fly connectome.

    Loads the connectome with ``load_fly``, passes it to ``ESWR`` together
    with ``num`` and the constant 96, and converts every returned graph with
    ``specific_from_networkx``.

    Args:
        num (int): Forwarded to ``ESWR`` as its second argument (presumably
            the number of samples to draw -- confirm against ``ESWR``).

    Returns:
        list: One ``Data`` object per graph returned by ``ESWR``.
    """
    connectome = load_fly()
    # NOTE(review): 96 is presumably the per-sample size; verify against ESWR.
    sampled_graphs = ESWR(connectome, num, 96)
    return list(map(specific_from_networkx, sampled_graphs))
class NeuralDataset(InMemoryDataset):
    r"""
    Contributor: Alex O. Davies

    Contributor email: `alexander.davies@bristol.ac.uk`

    A dataset of the connectome of a fruit fly larva.

    The original graph is sourced from:

    `Michael Winding et al., The connectome of an insect brain. Science 379, eadd9330 (2023). DOI:10.1126/science.add9330`

    We process the original multigraph into ESWR samples of this neural network, with predicting the strength of the connection (number of synapses) between two neurons as the target.

    - Task: Edge regression
    - Num node features: 0
    - Num edge features: 0
    - Num target values: 1
    - Target shape: N Edges
    - Num graphs: Parameterised by `num`

    Args:
        root (str): Root directory where the dataset should be saved.
        stage (str): The stage of the dataset to load. One of "train", "val", "test". (default: :obj:`"train"`)
        transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`)
        num (int): The number of samples to take from the original dataset. (default: :obj:`2000`).
    """

    def __init__(self, root, stage="train", transform=None, pre_transform=None, pre_filter=None, num = 2000):
        self.num = num
        self.stage = stage
        # Position of each stage's file inside ``processed_file_names``.
        self.stage_to_index = {"train":0,
                               "val":1,
                               "test":2}
        self.task = "edge-regression"
        # Called only for its side effect: the connectome graph is downloaded
        # and cached if needed; the returned graph is discarded.
        load_fly()
        # The attributes above must be set before this call because
        # ``process`` reads them.
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[self.stage_to_index[self.stage]])

    @property
    def raw_file_names(self):
        """No raw files are tracked; the source data is handled by ``load_fly``."""
        return []

    @property
    def processed_file_names(self):
        """Processed file names, ordered to match ``stage_to_index``."""
        return ['train.pt',
                'val.pt',
                'test.pt']

    def process(self):
        """
        Sample, filter, transform and save the data for the current stage.

        Only the file belonging to ``self.stage`` is written. If it already
        exists nothing is regenerated.
        """
        stage_path = self.processed_paths[self.stage_to_index[self.stage]]
        if os.path.isfile(stage_path):
            print("Connectome files exist")
            return

        # Read data into huge `Data` list.
        data_list = get_fly_dataset(num=self.num)

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), stage_path)
if __name__ == "__main__":
    # Build, describe and plot the first 16 graphs of each split in turn.
    for split in ("train", "val", "test"):
        dataset = NeuralDataset(os.getcwd()+'/bgd_files/'+'fruit_fly', stage = split)
        describe_one_dataset(dataset)
        vis_grid(dataset[:16], os.getcwd() + "/bgd_files/fruit_fly/" + split + ".png")