# Source code for openfgl.task.graph_cls

import torch
import torch.nn as nn
from openfgl.task.base import BaseTask
from openfgl.utils.basic_utils import extract_floats, idx_to_mask_tensor, mask_tensor_to_idx
from os import path as osp
from openfgl.utils.metrics import compute_supervised_metrics
import os
import torch
from openfgl.utils.task_utils import load_graph_cls_default_model
import pickle
from torch_geometric.loader import DataLoader
import numpy as np
from openfgl.data.processing import processing



class GraphClsTask(BaseTask):
    """
    Task class for graph classification in a federated learning setup.

    Attributes:
        client_id (int): ID of the client (``None`` is treated as the server).
        data_dir (str): Directory containing the data.
        args (Namespace): Arguments containing model and training configurations.
        device (torch.device): Device to run the computations on.
        data (object): Data specific to the task.
        model (torch.nn.Module): Model to be trained.
        optim (torch.optim.Optimizer): Optimizer for the model.
        train_mask (torch.Tensor): Mask for the training set.
        val_mask (torch.Tensor): Mask for the validation set.
        test_mask (torch.Tensor): Mask for the test set.
        train_dataloader (DataLoader): DataLoader for the training set.
        val_dataloader (DataLoader): DataLoader for the validation set.
        test_dataloader (DataLoader): DataLoader for the test set.
        splitted_data (dict): Dictionary containing split data and DataLoaders.
        processed_data (object): Processed data for training.
    """

    # Entries that every externally supplied ``splitted_data`` dict must provide.
    _SPLIT_KEYS = (
        "data",
        "train_dataloader",
        "val_dataloader",
        "test_dataloader",
        "train_mask",
        "val_mask",
        "test_mask",
    )

    def __init__(self, args, client_id, data, data_dir, device):
        """
        Initialize the GraphClsTask with provided arguments, data, and device.

        Args:
            args (Namespace): Arguments containing model and training configurations.
            client_id (int): ID of the client.
            data (object): Data specific to the task.
            data_dir (str): Directory containing the data.
            device (torch.device): Device to run the computations on.
        """
        super().__init__(args, client_id, data, data_dir, device)

    @classmethod
    def _require_split_keys(cls, splitted_data):
        """
        Check that ``splitted_data`` contains every entry train/evaluate rely on.

        Raises:
            KeyError: If one or more required entries are missing. An explicit
                raise is used (instead of ``assert``) so the check also runs
                under ``python -O``.
        """
        missing = [name for name in cls._SPLIT_KEYS if name not in splitted_data]
        if missing:
            raise KeyError(f"splitted_data is missing required entries: {missing}")

    @staticmethod
    def _load_pickle(file_path):
        """
        Load one pickled object from ``file_path``.

        NOTE: ``pickle`` can execute arbitrary code while loading. The files read
        here are the ones this class writes in ``load_train_val_test_split``;
        do not point ``data_dir`` at untrusted content.
        """
        with open(file_path, 'rb') as file:
            return pickle.load(file)

    def _collect_split_outputs(self, dataloader, mask, embedding_all, logits_all, label_all):
        """
        Run the model over one split and write its outputs into the shared buffers.

        Args:
            dataloader (iterable): Yields the batches of this split.
            mask (torch.Tensor): Boolean mask marking the samples of this split.
            embedding_all (torch.Tensor): Buffer receiving the embeddings (modified in place).
            logits_all (torch.Tensor): Buffer receiving the logits (modified in place).
            label_all (torch.Tensor): Buffer receiving the labels (modified in place).
        """
        # Positions of this split's samples inside the full-size buffers.
        # NOTE(review): this assumes the loader yields graphs in the same order
        # as the mask's non-zero entries (the loaders built in
        # load_train_val_test_split use shuffle=False) - confirm for custom loaders.
        positions = mask.nonzero().flatten().tolist()
        offset = 0
        for batch in dataloader:
            embedding, logits = self.model.forward(batch)
            target = positions[offset:offset + batch.num_graphs]
            embedding_all[target] = embedding
            logits_all[target] = logits
            label_all[target] = batch.y
            offset += batch.num_graphs

    def train(self, splitted_data=None):
        """
        Train the model on the provided or processed data.

        Args:
            splitted_data (dict, optional): Dictionary containing split data and
                DataLoaders. Defaults to None, in which case ``self.processed_data``
                is used.

        Raises:
            KeyError: If a supplied ``splitted_data`` lacks a required entry.
        """
        if splitted_data is None:
            splitted_data = self.processed_data  # use processed_data to train
        else:
            self._require_split_keys(splitted_data)

        self.model.train()
        for _ in range(self.args.num_epochs):
            for batch in splitted_data["train_dataloader"]:
                self.optim.zero_grad()
                embedding, logits = self.model.forward(batch)
                # Every graph in a training batch contributes to the loss.
                full_mask = torch.ones_like(batch.y).bool()
                loss_train = self.loss_fn(embedding, logits, batch.y, full_mask)
                loss_train.backward()
                # Optional hook executed between backward() and the optimizer step.
                if self.step_preprocess is not None:
                    self.step_preprocess()
                self.optim.step()

    def evaluate(self, splitted_data=None, mute=False):
        """
        Evaluate the model on the provided or processed data.

        Args:
            splitted_data (dict, optional): Dictionary containing split data and
                DataLoaders. Defaults to None, in which case ``self.splitted_data``
                is used.
            mute (bool, optional): If True, suppress the print statements.
                Defaults to False.

        Returns:
            dict: Dictionary containing evaluation metrics and results
            ("embedding", "logits", "loss_train", "loss_val", "loss_test" plus
            whatever ``compute_supervised_metrics`` returns for each split).

        Raises:
            KeyError: If a supplied ``splitted_data`` lacks a required entry.
        """
        if splitted_data is None:
            splitted_data = self.splitted_data  # use splitted_data to evaluate
        else:
            self._require_split_keys(splitted_data)

        eval_output = {}
        self.model.eval()

        num_samples = len(splitted_data["data"])
        num_global_classes = splitted_data["data"].num_global_classes

        # Full-size buffers; each split fills the rows selected by its mask.
        embedding_all = torch.zeros((num_samples, self.args.hid_dim), device=self.device)
        logits_all = torch.zeros((num_samples, num_global_classes), device=self.device)
        label_all = torch.zeros(num_samples, dtype=torch.long, device=self.device)

        with torch.no_grad():
            for split in ("train", "val", "test"):
                self._collect_split_outputs(
                    splitted_data[f"{split}_dataloader"],
                    splitted_data[f"{split}_mask"],
                    embedding_all,
                    logits_all,
                    label_all,
                )

            loss_train = self.loss_fn(embedding_all, logits_all, label_all, splitted_data["train_mask"])
            loss_val = self.loss_fn(embedding_all, logits_all, label_all, splitted_data["val_mask"])
            loss_test = self.loss_fn(embedding_all, logits_all, label_all, splitted_data["test_mask"])

        eval_output["embedding"] = embedding_all
        eval_output["logits"] = logits_all
        eval_output["loss_train"] = loss_train
        eval_output["loss_val"] = loss_val
        eval_output["loss_test"] = loss_test

        metric_train = compute_supervised_metrics(
            metrics=self.args.metrics,
            logits=logits_all[splitted_data["train_mask"]],
            labels=label_all[splitted_data["train_mask"]],
            suffix="train",
        )
        metric_val = compute_supervised_metrics(
            metrics=self.args.metrics,
            logits=logits_all[splitted_data["val_mask"]],
            labels=label_all[splitted_data["val_mask"]],
            suffix="val",
        )
        metric_test = compute_supervised_metrics(
            metrics=self.args.metrics,
            logits=logits_all[splitted_data["test_mask"]],
            labels=label_all[splitted_data["test_mask"]],
            suffix="test",
        )
        eval_output = {**eval_output, **metric_train, **metric_val, **metric_test}

        info_parts = []
        for key, val in eval_output.items():
            try:
                info_parts.append(f"\t{key}: {val:.4f}")
            except (TypeError, ValueError):
                # Entries that cannot be rendered with ".4f" (e.g. the full
                # embedding/logits tensors) are left out of the summary line.
                continue
        info = "".join(info_parts)

        prefix = f"[client {self.client_id}]" if self.client_id is not None else "[server]"
        if not mute:
            print(prefix + info)
        return eval_output

    def loss_fn(self, embedding, logits, label, mask):
        """
        Calculate the loss for the model.

        Args:
            embedding (torch.Tensor): Embeddings from the model (unused by the
                default loss, kept for signature compatibility).
            logits (torch.Tensor): Logits from the model.
            label (torch.Tensor): Ground truth labels.
            mask (torch.Tensor): Mask to filter the logits and labels.

        Returns:
            torch.Tensor: Calculated loss.
        """
        return self.default_loss_fn(logits[mask], label[mask])

    @property
    def default_model(self):
        """
        Get the default model for graph classification.

        Returns:
            torch.nn.Module: Default model.
        """
        return load_graph_cls_default_model(
            self.args,
            input_dim=self.num_feats,
            output_dim=self.num_global_classes,
            client_id=self.client_id,
        )

    @property
    def default_optim(self):
        """
        Get the default optimizer for the task.

        Returns:
            torch.optim.Optimizer: The ``Adam`` class when ``args.optim == "adam"``.
            NOTE(review): any other value falls through and yields ``None`` -
            confirm whether callers expect that or whether it should be an error.
        """
        if self.args.optim == "adam":
            from torch.optim import Adam
            return Adam

    @property
    def num_samples(self):
        """
        Get the number of samples in the dataset.

        Returns:
            int: Number of samples.
        """
        return len(self.data)

    @property
    def num_feats(self):
        """
        Get the number of features in the dataset.

        Returns:
            int: Number of features (second dimension of the first sample's ``x``).
        """
        return self.data[0].x.shape[1]

    @property
    def num_global_classes(self):
        """
        Get the number of global classes in the dataset.

        Returns:
            int: Number of global classes.
        """
        return self.data.num_global_classes

    @property
    def default_loss_fn(self):
        """
        Get the default loss function for the task.

        Returns:
            function: Default loss function (a new ``nn.CrossEntropyLoss`` instance).
        """
        return nn.CrossEntropyLoss()

    @property
    def default_train_val_test_split(self):
        """
        Get the default train/validation/test split.

        Returns:
            tuple: Default train/validation/test split ratios.
        """
        return 0.8, 0.1, 0.1

    @property
    def train_val_test_path(self):
        """
        Get the path to the train/validation/test split file.

        Returns:
            str: Path to the directory holding the split files.
        """
        return osp.join(self.data_dir, "graph_cls")

    def load_train_val_test_split(self):
        """
        Load the train/validation/test split from disk, creating it if needed.

        For the server of a single-dataset run (``client_id is None``), the masks
        are assembled from the global index files written by the clients.
        Otherwise the cached local masks are loaded, or a new split is generated
        and saved. Afterwards the masks, DataLoaders, ``splitted_data`` and
        ``processed_data`` attributes are set.
        """
        split_dir = self.train_val_test_path
        single_dataset = len(self.args.dataset) == 1

        if self.client_id is None and single_dataset:  # server
            glb_train = []
            glb_val = []
            glb_test = []

            for client_id in range(self.args.num_clients):
                glb_train += self._load_pickle(osp.join(split_dir, f"glb_train_{client_id}.pkl"))
                glb_val += self._load_pickle(osp.join(split_dir, f"glb_val_{client_id}.pkl"))
                glb_test += self._load_pickle(osp.join(split_dir, f"glb_test_{client_id}.pkl"))

            train_mask = idx_to_mask_tensor(glb_train, self.num_samples).bool()
            val_mask = idx_to_mask_tensor(glb_val, self.num_samples).bool()
            test_mask = idx_to_mask_tensor(glb_test, self.num_samples).bool()
        else:  # client
            train_path = osp.join(split_dir, f"train_{self.client_id}.pt")
            val_path = osp.join(split_dir, f"val_{self.client_id}.pt")
            test_path = osp.join(split_dir, f"test_{self.client_id}.pt")

            glb_train_path = osp.join(split_dir, f"glb_train_{self.client_id}.pkl")
            glb_val_path = osp.join(split_dir, f"glb_val_{self.client_id}.pkl")
            glb_test_path = osp.join(split_dir, f"glb_test_{self.client_id}.pkl")

            # The global index files are only written for single-dataset runs,
            # so they are only required then. Requiring them unconditionally
            # would make the cache miss forever and re-draw a random split on
            # every call when several datasets are used.
            required_paths = [train_path, val_path, test_path]
            if single_dataset:
                required_paths += [glb_train_path, glb_val_path, glb_test_path]

            if all(osp.exists(required) for required in required_paths):
                train_mask = torch.load(train_path)
                val_mask = torch.load(val_path)
                test_mask = torch.load(test_path)
            else:
                train_mask, val_mask, test_mask = self.local_graph_train_val_test_split(
                    self.data, self.args.train_val_test
                )

                # exist_ok avoids failing when the directory appears between a
                # separate existence check and its creation.
                os.makedirs(split_dir, exist_ok=True)

                torch.save(train_mask, train_path)
                torch.save(val_mask, val_path)
                torch.save(test_mask, test_path)

                if single_dataset:
                    # map to global
                    for mask, glb_path in (
                        (train_mask, glb_train_path),
                        (val_mask, glb_val_path),
                        (test_mask, glb_test_path),
                    ):
                        glb_ids = [
                            self.data.global_map[local_id]
                            for local_id in mask.nonzero().flatten().tolist()
                        ]
                        with open(glb_path, 'wb') as file:
                            pickle.dump(glb_ids, file)

        self.train_mask = train_mask.to(self.device)
        self.val_mask = val_mask.to(self.device)
        self.test_mask = test_mask.to(self.device)

        # shuffle=False keeps each loader in mask order.
        self.train_dataloader = DataLoader(
            list(self.data[self.train_mask]), batch_size=self.args.batch_size, shuffle=False
        )
        self.val_dataloader = DataLoader(
            list(self.data[self.val_mask]), batch_size=self.args.batch_size, shuffle=False
        )
        self.test_dataloader = DataLoader(
            list(self.data[self.test_mask]), batch_size=self.args.batch_size, shuffle=False
        )

        self.splitted_data = {
            "data": self.data,
            "train_dataloader": self.train_dataloader,
            "val_dataloader": self.val_dataloader,
            "test_dataloader": self.test_dataloader,
            "train_mask": self.train_mask,
            "val_mask": self.val_mask,
            "test_mask": self.test_mask
        }

        self.processed_data = processing(
            args=self.args,
            splitted_data=self.splitted_data,
            processed_dir=self.data_dir,
            client_id=self.client_id,
        )

    def local_graph_train_val_test_split(self, local_graphs, split, shuffle=True):
        """
        Split the local graphs into train, validation, and test sets.

        The split is done class by class, so each class contributes the
        requested fractions of its graphs to every set.

        Args:
            local_graphs (object): Local graphs to be split.
            split (str or tuple): Split ratios or default split identifier
                (``"default_split"`` selects ``default_train_val_test_split``).
            shuffle (bool, optional): If True, shuffle the graphs before
                splitting (uses NumPy's global random state). Defaults to True.

        Returns:
            tuple: Masks for the train, validation, and test sets.
        """
        num_graphs = self.num_samples

        if split == "default_split":
            train_, val_, test_ = self.default_train_val_test_split
        else:
            train_, val_, test_ = extract_floats(split)

        train_mask = idx_to_mask_tensor([], num_graphs)
        val_mask = idx_to_mask_tensor([], num_graphs)
        test_mask = idx_to_mask_tensor([], num_graphs)

        for class_i in range(local_graphs.num_global_classes):
            class_i_graph_mask = local_graphs.y == class_i
            # Kept as a tensor on purpose: the boundary arithmetic below is
            # unchanged, so the truncated split sizes stay exactly the same.
            num_class_i_graphs = class_i_graph_mask.sum()
            class_i_graph_list = mask_tensor_to_idx(class_i_graph_mask)
            if shuffle:
                np.random.shuffle(class_i_graph_list)

            # Boundaries of the three consecutive slices; int() truncates.
            train_end = int(train_ * num_class_i_graphs)
            val_end = int((train_ + val_) * num_class_i_graphs)
            test_end = min(num_class_i_graphs, int((train_ + val_ + test_) * num_class_i_graphs))

            train_mask += idx_to_mask_tensor(class_i_graph_list[:train_end], num_graphs)
            val_mask += idx_to_mask_tensor(class_i_graph_list[train_end:val_end], num_graphs)
            test_mask += idx_to_mask_tensor(class_i_graph_list[val_end:test_end], num_graphs)

        train_mask = train_mask.bool()
        val_mask = val_mask.bool()
        test_mask = test_mask.bool()
        return train_mask, val_mask, test_mask