# Source code for openfgl.flcore.feddep.client

import copy
import torch
import numpy as np
import torch.nn.functional as F
from torch_geometric.loader import NeighborSampler

from openfgl.flcore.base import BaseClient

from openfgl.flcore.feddep.localdep import LocalDGen, FedDEP, Classifier_F
from openfgl.flcore.feddep._utils import HideGraph, LocalRecLoss, FedRecLoss, GraphMender
from openfgl.flcore.feddep.feddep_config import config


class FedDEPClient(BaseClient):
    """Client for FedDEP (Deep Efficient Private Neighbor Generation for
    Subgraph Federated Learning).

    The client works in two phases, selected from ``message_pool["round"]``:

    * Round 0 (phase 0): pre-train a local neighbor generator (``LocalDGen``)
      on the graph with hidden parts, wrap it in ``FedDEP`` and refine it using
      reconstruction gradients computed against the ``embedding`` /
      ``x_missing`` tensors that the other sampled clients published.
    * Every later round (phase 1): mend the local graph with the trained FedDEP
      model (``GraphMender``, done once) and train the classifier on the mended
      graph with neighbor sampling, starting from the server's weights.

    Attributes:
        hide_graph_model (nn.Module): ``HideGraph`` model that hides part of the
            input graph.
        data (torch_geometric.data.Data): Local graph with train/val/test masks
            attached.
        hide_data (torch_geometric.data.Data): Graph returned by
            ``hide_graph_model``; used to train the neighbor generator.
        emb (torch.Tensor): Embedding output of ``hide_graph_model``; published
            through the message pool.
        x_missing (torch.Tensor): Missing-neighbor features produced by
            ``hide_graph_model``; published through the message pool.
        loss_fn_num (callable): Loss between predicted and true ``num_missing``
            (``F.smooth_l1_loss``).
        loss_fn_rec (callable): Local reconstruction loss (``LocalRecLoss``).
        phase (int): 0 during round 0, 1 in every later round.
        feddep_model (FedDEP): Neighbor generator trained in phase 0.
        filled_data (dict): Output of ``GraphMender``; holds ``"data"`` and the
            ``"train_mask"`` / ``"val_mask"`` / ``"test_mask"`` splits.
        fill_dataloader (dict): ``"data"`` (mended graph on ``device``) and
            ``"train"`` (``NeighborSampler`` over the mended training nodes).
    """

    def __init__(self, args, client_id, data, data_dir, message_pool, device):
        """Initialize the client, hide part of its graph and publish its first message.

        Args:
            args (Namespace): Model and training configuration (``hid_dim``,
                ``num_layers``, ``dropout``, ``lr``, ``weight_decay``, ...).
            client_id (int): ID of this client.
            data (object): Data specific to this client, forwarded to ``BaseClient``.
            data_dir (str): Directory containing the data.
            message_pool (object): Pool for messages between clients and server
                (indexed like a dict).
            device (torch.device): Device to run the computations on.
        """
        super().__init__(args, client_id, data, data_dir, message_pool, device)
        # The classifier consumes (node features, mended-neighbor embedding) pairs,
        # hence the two-part input dimension.
        self.task.load_custom_model(Classifier_F(
            input_dim=(self.task.num_feats, self.args.hid_dim),
            hid_dim=self.args.hid_dim,
            output_dim=self.task.num_global_classes,
            num_layers=self.args.num_layers,
            dropout=self.args.dropout,
        ))
        self.hide_graph_model = HideGraph(
            encoder_hid_dim=self.args.hid_dim,
            encoder_output_dim=self.task.num_global_classes,
            encoder_num_layers=self.args.num_layers,
            hidden_portion=config["hide_portion"],
            num_preds=config["num_preds"],
            num_protos=config["num_protos"],
            device=device,
        )
        self.data = self.task.splitted_data["data"]
        self.data.train_mask = self.task.splitted_data["train_mask"]
        self.data.val_mask = self.task.splitted_data["val_mask"]
        self.data.test_mask = self.task.splitted_data["test_mask"]
        self.hide_data, self.emb, self.x_missing = self.hide_graph_model(data=self.data)
        self.loss_fn_num = F.smooth_l1_loss
        self.loss_fn_rec = LocalRecLoss
        self.task.loss_fn = F.cross_entropy
        self.task.override_evaluate = self.get_override_evaluate()
        self.send_message()

    def send_message(self):
        """Publish sample count, model parameters, ``emb`` and ``x_missing`` to the pool."""
        self.message_pool[f"client_{self.client_id}"] = {
            "num_samples": self.task.num_samples,
            "weight": list(self.task.model.parameters()),
            "embedding": self.emb,
            "x_missing": self.x_missing,
        }

    def execute(self):
        """Run this client's local work for the current round.

        Round 0 pre-trains the local neighbor generator and trains the FedDEP
        model. Every later round trains the classifier on the mended graph,
        starting from the server's weights.
        """
        current_round = self.message_pool["round"]
        # Derive the phase from the round instead of only setting it in rounds
        # 0 and 1: otherwise a client that skipped round 1 would repeat phase 0
        # in every later round.
        self.phase = 0 if current_round == 0 else 1
        if self.phase == 1 and (current_round == 1 or not hasattr(self, "fill_dataloader")):
            self._build_fill_dataloader()

        if self.phase == 0:
            pre_train_model = self._pretrain_local_generator()
            self._train_feddep(pre_train_model)
        else:
            self._train_classifier()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _to_numpy(value):
        """Return ``value`` as a numpy array (tensors are detached and moved to CPU)."""
        if torch.is_tensor(value):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    def _train_node_indices(self):
        """Indices of ``hide_data``'s training nodes, as a numpy array."""
        return np.where(self.hide_data.train_mask.cpu().numpy() == True)[0]  # noqa: E712

    def _local_loss(self, pred_missing, pred_emb, nc_pred, mask_true_index):
        """Weighted count + classification + reconstruction loss on ``hide_data``'s training nodes."""
        train_mask = self.hide_data.train_mask
        loss_num = self.loss_fn_num(
            pred_missing[train_mask],
            self.hide_data.num_missing[train_mask],
        )
        loss_rec = self.loss_fn_rec(
            pred_embs=pred_emb[train_mask],
            true_embs=[self.hide_data.x_missing[node] for node in mask_true_index],
            pred_missing=pred_missing[train_mask],
            true_missing=self.hide_data.num_missing[train_mask],
            num_preds=config["num_preds"],
        )
        loss_clf = self.task.loss_fn(
            nc_pred[train_mask],
            self.hide_data.y[train_mask],
        )
        return config["beta_d"] * loss_num + config["beta_c"] * loss_clf + config["beta_n"] * loss_rec

    def _build_fill_dataloader(self):
        """Build ``filled_data`` with ``GraphMender`` and a neighbor sampler over its training nodes."""
        self.filled_data = GraphMender(
            model=self.feddep_model,
            impaired_data=self.hide_data,
            original_data=self.data,
            num_preds=config["num_preds"],
        )
        self.filled_data["data"] = self.filled_data["data"].to(self.device)
        filled_graph = self.filled_data["data"]
        self.fill_dataloader = {
            "data": filled_graph,
            "train": NeighborSampler(
                filled_graph.edge_index,
                num_nodes=filled_graph.num_nodes,
                node_idx=torch.where(self.filled_data["train_mask"] == True)[0],  # noqa: E712
                sizes=[5, 5],
                batch_size=64,
                shuffle=True,
            ),
        }

    def _pretrain_local_generator(self):
        """Pre-train a fresh ``LocalDGen`` on ``hide_data`` and return it."""
        pre_train_model = LocalDGen(
            input_dim=self.task.num_feats,
            emb_shape=self.args.hid_dim,
            output_dim=self.task.num_global_classes,
            hid_dim=self.args.hid_dim,
            gen_dim=config["gen_hidden"],
            dropout=self.args.dropout,
            num_preds=config["num_preds"],
        ).to(self.device)
        print(f"Client {self.client_id} pre-train start...")
        pre_train_model.train()
        pre_train_optim = self.task.default_optim(
            pre_train_model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
        # Move once up front (also covers pre_train_epochs == 0, after which the
        # FedDEP stage still needs hide_data on the device).
        self.hide_data = self.hide_data.to(self.device)
        mask_true_index = self._train_node_indices()
        for epoch in range(config["pre_train_epochs"]):
            pred_missing, pred_emb, nc_pred = pre_train_model(self.hide_data)
            pre_train_loss = self._local_loss(pred_missing, pred_emb, nc_pred, mask_true_index)
            pre_train_optim.zero_grad()
            pre_train_loss.backward()
            pre_train_optim.step()
            print(f"Client {self.client_id} local pre-train @Epoch {epoch}.")
        print(f"Client {self.client_id} pre-train finish!")
        return pre_train_model

    def _cross_client_gradients(self, para_backup):
        """Sum of the ``FedRecLoss`` gradients against every other sampled client's published data.

        Returns a dict mapping parameter name to an independent gradient tensor.
        Parameters that receive no gradient are absent from the dict.
        """
        num_preds = config["num_preds"]
        dep_grad = {}
        for client_id in self.message_pool["sampled_clients"]:
            if client_id == self.client_id:
                continue
            other = self.message_pool[f"client_{client_id}"]
            emb, x_missing = other["embedding"], other["x_missing"]
            self.feddep_model.load_state_dict(para_backup)
            self.feddep_model.train()
            # Drop leftovers from the previous client, epoch or pre-training so
            # that each client's gradient is measured on its own.
            self.feddep_model.zero_grad(set_to_none=True)
            _, embedding = self.feddep_model.encoder_model(self.hide_data)
            pred_missing = self.feddep_model.reg_model(embedding)
            pred_embs = self.feddep_model.gen(embedding)
            emb_len = pred_embs.shape[-1] // num_preds
            choice = np.random.choice(len(x_missing), embedding.shape[0])
            global_target_emb = []
            for c_i in choice:
                choice_i = np.random.choice(len(x_missing[c_i]), num_preds)
                for ch_i in choice_i:
                    candidate = x_missing[c_i][ch_i]
                    # NOTE(review): this presumably detects empty (all-zero)
                    # slots; the signed sum also treats negative-sum slots as
                    # empty — confirm intent.
                    if torch.sum(candidate) < 1e-15:
                        global_target_emb.append(self._to_numpy(emb[c_i]))
                    else:
                        global_target_emb.append(self._to_numpy(candidate))
            global_target_emb = np.asarray(global_target_emb).reshape(
                (embedding.shape[0], num_preds, emb_len))
            loss_emb = FedRecLoss(
                pred_embs=pred_embs,
                true_embs=global_target_emb,
                pred_missing=pred_missing,
                num_preds=num_preds,
            )
            other_loss = (
                1.0 / self.args.num_clients * config["beta_n"] * loss_emb
            ).requires_grad_()
            other_loss.backward()
            # Sum up the gradients from all other clients.
            for name, param in self.feddep_model.named_parameters():
                if param.grad is None:
                    continue
                if name in dep_grad:
                    dep_grad[name] += param.grad
                else:
                    # Clone: param.grad is reused/zeroed by later backward passes.
                    dep_grad[name] = param.grad.detach().clone()
        return dep_grad

    def _train_feddep(self, pre_train_model):
        """Wrap ``pre_train_model`` in ``FedDEP`` and train it with local + cross-client losses."""
        self.feddep_model = FedDEP(pre_train_model).to(self.device)
        feddep_optim = self.task.default_optim(
            self.feddep_model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
        mask_true_index = self._train_node_indices()
        for _ in range(config["feddep_epochs"]):
            para_backup = copy.deepcopy(self.feddep_model.state_dict())
            dep_grad = self._cross_client_gradients(para_backup)

            # Roll back to the pre-epoch parameters before the local update.
            self.feddep_model.load_state_dict(para_backup)
            self.feddep_model.train()
            pred_missing, pred_emb, nc_pred = self.feddep_model.forward(self.hide_data)
            feddep_loss = self._local_loss(pred_missing, pred_emb, nc_pred, mask_true_index)
            feddep_loss = feddep_loss.float() / self.args.num_clients
            feddep_optim.zero_grad()
            feddep_loss.backward()
            # Add the other clients' gradients to the local ones, then step.
            for name, param in self.feddep_model.named_parameters():
                extra = dep_grad.get(name)
                if extra is None:
                    continue
                if param.grad is None:
                    param.grad = extra
                else:
                    param.grad += extra
            feddep_optim.step()

    def _train_classifier(self):
        """Load the server weights and train the classifier on the mended graph for one pass."""
        for local_param, global_param in zip(
                self.task.model.parameters(), self.message_pool["server"]["weight"]):
            local_param.data.copy_(global_param)
        fill_data = self.fill_dataloader["data"]
        # override_evaluate leaves the model in eval mode; re-enable training
        # behaviour (e.g. dropout) before updating it.
        self.task.model.train()
        if "mend_emb" in fill_data:
            mend_emb = fill_data.mend_emb
        else:
            mend_emb = torch.zeros((len(fill_data.x), self.task.model.emb_len)).to(self.device)
        for batch_size, n_id, adjs in self.fill_dataloader["train"]:
            adjs = [adj.to(self.device) for adj in adjs]
            pred = self.task.model.forward((fill_data.x[n_id], mend_emb[n_id]), adjs=adjs)
            # The first batch_size entries of n_id are the sampled target nodes.
            label = fill_data.y[n_id[:batch_size]].to(self.device)
            loss_clf = self.task.loss_fn(pred, label)
            self.task.optim.zero_grad()
            loss_clf.backward()
            self.task.optim.step()

    def get_override_evaluate(self):
        """Build the evaluation function installed as ``task.override_evaluate``.

        When no data is passed, the returned function evaluates on
        ``filled_data`` if it exists, otherwise on ``task.splitted_data``.

        Returns:
            function: ``override_evaluate(splitted_data=None, mute=False)``,
            which returns a dict of per-split losses and metrics.
        """
        from openfgl.utils.metrics import compute_supervised_metrics

        def override_evaluate(splitted_data=None, mute=False):
            if splitted_data is None:
                try:
                    # Only exists once phase 1 has mended the graph.
                    splitted_data = self.filled_data
                except AttributeError:
                    splitted_data = self.task.splitted_data
            else:
                for name in ("data", "train_mask", "val_mask", "test_mask"):
                    assert name in splitted_data, f"splitted_data is missing '{name}'"

            splits = ("train", "val", "test")
            labels = splitted_data["data"].y
            eval_output = {}
            self.task.model.eval()
            with torch.no_grad():
                logits = self.task.model.forward(splitted_data["data"])
                for split in splits:
                    mask = splitted_data[f"{split}_mask"]
                    eval_output[f"loss_{split}"] = self.task.loss_fn(logits[mask], labels[mask])

            for split in splits:
                mask = splitted_data[f"{split}_mask"]
                eval_output.update(compute_supervised_metrics(
                    metrics=self.args.metrics,
                    logits=logits[mask],
                    labels=labels[mask],
                    suffix=split,
                ))

            info = ""
            for key, val in eval_output.items():
                try:
                    info += f"\t{key}: {val:.4f}"
                except (TypeError, ValueError):
                    # Skip entries that cannot be formatted as a float.
                    continue

            prefix = f"[client {self.client_id}]" if self.client_id is not None else "[server]"
            if not mute:
                print(prefix + info)
            return eval_output

        return override_evaluate