# Source code for openfgl.flcore.fggp.client

import torch
import torch.nn as nn
from openfgl.flcore.base import BaseClient
import copy
from torch_geometric.utils import to_torch_csc_tensor
from openfgl.flcore.fggp.models import FedGCN,MLP
from sklearn.neighbors import kneighbors_graph
from openfgl.flcore.fggp.fggp_config import config
from openfgl.flcore.fggp.utils import  get_norm_and_orig,get_proto_norm_weighted,proto_align_loss
import torch.nn.functional as F


class FGGPClient(BaseClient):
    """
    FGGPClient is a client-side implementation for the Federated Graph Learning with
    Generalizable Prototypes (FGGP) framework. This client handles local training,
    model updates, and prototype generation in a federated learning setting, focusing
    on overcoming domain shifts across clients.

    Attributes:
        global_model (nn.Module): A copy of the model that is overwritten with the
            server weights each round and used to compute embeddings for the kNN graph.
        personal_project (nn.Module): An MLP projection layer for personalizing
            embeddings (currently not used by the methods of this class).
        data2 (torch_geometric.data.Data): A copy of the client's data whose edges are
            the union of the original edges and a kNN graph built on global-model
            embeddings. Built in ``execute``; ``None`` before the first call.
    """

    def __init__(self, args, client_id, data, data_dir, message_pool, device):
        """
        Initializes the FGGPClient.

        Args:
            args (Namespace): Arguments containing model and training configurations.
            client_id (int): The ID of the client.
            data (torch_geometric.data.Data): The graph data specific to the client's task.
            data_dir (str): Directory containing the data.
            message_pool (dict): Pool for managing messages between client and server.
            device (torch.device): The device on which computations will be performed.
        """
        super(FGGPClient, self).__init__(args, client_id, data, data_dir, message_pool, device)
        self.task.load_custom_model(FedGCN(nfeat=self.task.num_feats,
                                           nhid=self.args.hid_dim,
                                           nclass=self.task.num_global_classes,
                                           nlayer=self.args.num_layers,
                                           dropout=self.args.dropout))
        # Separate copy of the model; its weights are overwritten with the server's in execute().
        self.global_model = copy.deepcopy(self.task.model)
        # Pass the node count explicitly so nodes with no edges (e.g. isolated nodes with
        # the highest indices) are still covered by the sparse adjacency.
        self.task.splitted_data["data"].adj = to_torch_csc_tensor(
            self.task.data.edge_index, size=self.task.data.num_nodes)
        self.personal_project = MLP(self.args.hid_dim, self.args.hid_dim, 0.5)
        # Augmented copy of the data; built in execute() before training starts.
        self.data2 = None

    def get_custom_loss_fn(self):
        """
        Returns the custom loss function used during local training.

        The returned loss is the sum of:
            - Cross-entropy on the original graph's logits (masked nodes).
            - A weighted binary cross-entropy between the augmenter's edge logits and
              ``data2.adj_orig`` (graph augmentation loss).
            - Cross-entropy on the logits produced from the augmented graph ``data2``.
            - ``proto_align_loss`` between confidence-weighted class prototypes computed
              from the augmented-graph embeddings and from the original-graph embeddings.

        Returns:
            Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]: ``loss(embedding, logits,
            label, mask)``. It reads ``self.data2``, so ``execute`` must build it first.
        """
        def custom_loss_fn(embedding, logits, label, mask):
            loss_ce = F.cross_entropy(logits[mask], label[mask])

            # Sample an augmented adjacency for data2 and run the model on it.
            adj_sampled, adj_logits = self.task.model.aug(self.data2)
            self.data2.adj = adj_sampled
            emb_g, logits_g = self.task.model(self.data2)

            ga_loss = self.data2.norm_w * F.binary_cross_entropy_with_logits(
                adj_logits, self.data2.adj_orig, pos_weight=self.data2.pos_weight)
            loss_ce2 = F.cross_entropy(logits_g[mask], self.data2.y[mask])

            # Pseudo-labels and confidences from the original-graph predictions; labelled
            # nodes use their ground-truth label with confidence 1.
            probs = torch.exp(F.log_softmax(logits, dim=1))
            # Two separate max() calls on purpose: the indices tensor is modified in place
            # below and must not be the one saved by the autograd node behind `confidences`.
            confidences = probs.max(1)[0]
            pseudo_labels = probs.max(1)[1].type_as(label)
            pseudo_labels[mask] = label[mask]
            confidences[mask] = 1.0
            unique_labels = torch.unique(pseudo_labels)

            proto = get_proto_norm_weighted(self.task.num_global_classes, embedding,
                                            pseudo_labels, confidences, unique_labels)
            proto_aug = get_proto_norm_weighted(self.task.num_global_classes, emb_g,
                                                pseudo_labels, confidences, unique_labels)
            loss_pa = proto_align_loss(proto_aug, proto, temperature=0.5)

            return loss_ce + ga_loss + loss_ce2 + loss_pa

        return custom_loss_fn

    def execute(self):
        """
        Executes the local training process.

        Steps:
            - Copy the server weights into the local model and ``global_model``.
            - Install the custom loss function.
            - Build a kNN graph on ``global_model`` embeddings and store it as
              ``self.task.data.global_edge_index``.
            - Build ``self.data2`` (original edges united with the kNN edges) and train.
        """
        self._sync_models_with_server()
        self.task.loss_fn = self.get_custom_loss_fn()
        global_edge_index = self._compute_global_knn_edges()
        self.task.data.global_edge_index = global_edge_index
        self.data2 = self._build_augmented_data(global_edge_index)
        self.task.train()

    def _sync_models_with_server(self):
        """Copy the server's weights into both the local model and ``global_model``."""
        with torch.no_grad():
            for local_param, g_param, server_param in zip(self.task.model.parameters(),
                                                          self.global_model.parameters(),
                                                          self.message_pool["server"]["weight"]):
                local_param.copy_(server_param)
                g_param.copy_(server_param)

    def _compute_global_knn_edges(self):
        """
        Build a cosine kNN graph (with self-loops) over ``global_model`` node embeddings.

        Returns:
            torch.Tensor: A ``(2, E)`` long edge index on ``self.device``.
        """
        self.global_model.eval()
        # The embeddings only feed sklearn's neighbour search, so no autograd graph is needed.
        with torch.no_grad():
            global_emb, _ = self.global_model(self.task.data)
        adj = kneighbors_graph(global_emb.cpu().numpy(), config['neibor_num'], metric='cosine')
        del global_emb
        adj.setdiag(1)
        coo = adj.tocoo()
        edge_index = torch.stack([torch.as_tensor(coo.row, dtype=torch.long),
                                  torch.as_tensor(coo.col, dtype=torch.long)])
        return edge_index.to(self.device)

    def _build_augmented_data(self, global_edge_index):
        """
        Clone the client's data with the union of original and kNN edges.

        Also attach the weights used by the graph-augmentation loss.

        Args:
            global_edge_index (torch.Tensor): kNN edge index from ``_compute_global_knn_edges``.

        Returns:
            torch_geometric.data.Data: The processed clone. It is passed through
            ``get_norm_and_orig`` and given ``norm_w`` (float) and ``pos_weight``
            (1-element float32 tensor on ``self.device``).
        """
        combined_edge_index = torch.cat([self.task.data.edge_index, global_edge_index], dim=1)
        # Drop duplicate (src, dst) pairs. The result stays on the CPU, matching the previous
        # set-based construction. Edge order is now sorted rather than set-iteration order.
        union_edge_index = torch.unique(combined_edge_index.cpu(), dim=1)

        data2 = self.task.splitted_data["data"].clone()
        data2.edge_index = union_edge_index
        data2 = get_norm_and_orig(data2)

        # Class-balancing weights for the BCE reconstruction loss on adj_orig.
        adj_orig = data2.adj_orig
        num_entries = adj_orig.shape[0] ** 2
        num_pos = adj_orig.sum()
        num_neg = num_entries - num_pos
        data2.norm_w = num_entries / float(num_neg * 2)
        data2.pos_weight = (float(num_neg) / num_pos).reshape(1).to(device=self.device,
                                                                    dtype=torch.float32)
        return data2

    def send_message(self):
        """
        Send the local model parameters and class prototypes to the server.

        Prototypes are confidence-weighted (via ``get_proto_norm_weighted``) over
        pseudo-labels. Confidences are softmax probabilities, as in the training loss.
        Training nodes use their true label with confidence 1.0.

        The message written to ``message_pool[f"client_{client_id}"]`` contains
        ``num_samples``, ``weight`` (list of model parameters) and ``protos``
        (dict mapping class index to a detached prototype tensor).
        """
        data = self.task.splitted_data["data"]
        train_mask = self.task.splitted_data["train_mask"]

        self.task.model.eval()
        with torch.no_grad():
            emb, logits = self.task.model(data)
            # Use probabilities, not raw logits, so that weights lie in [0, 1] and are
            # comparable to the 1.0 given to labelled nodes (same as custom_loss_fn).
            probs = F.softmax(logits, dim=1)
            confidences, pseudo_labels = probs.max(dim=1)
            pseudo_labels = pseudo_labels.type_as(data.y)
            pseudo_labels[train_mask] = data.y[train_mask]
            confidences[train_mask] = 1.0
            unique_labels = torch.unique(pseudo_labels)
            # NOTE(review): the training loss uses self.task.num_global_classes here;
            # confirm config['N_CLASS'] matches it (the server may rely on this size).
            proto = get_proto_norm_weighted(config['N_CLASS'], emb, pseudo_labels,
                                            confidences, unique_labels)

        self.message_pool[f"client_{self.client_id}"] = {
            "num_samples": self.task.num_samples,
            "weight": list(self.task.model.parameters()),
            "protos": {i: proto[i].detach() for i in range(proto.shape[0])},
        }