# Source code for openfgl.flcore.fedtgp.client

import torch
import torch.nn as nn
from openfgl.flcore.base import BaseClient
from openfgl.flcore.fedtgp.fedtgp_config import config

class FedTGPClient(BaseClient):
    """
    FedTGPClient implements the client-side operations for the Federated Learning algorithm
    described in the paper 'FedTGP: Trainable Global Prototypes with Adaptive-Margin-Enhanced
    Contrastive Learning for Data and Model Heterogeneity in Federated Learning'.

    This class handles the local training of the model, updating local prototypes,
    and communication with the server.

    Attributes:
        local_prototype (dict): Maps each global class index to this client's prototype for
            that class: the mean embedding of its training nodes with that label, or a zero
            vector of length ``args.hid_dim`` when it has no such training nodes.
    """

    def __init__(self, args, client_id, data, data_dir, message_pool, device):
        """
        Initializes the FedTGPClient.

        Args:
            args (Namespace): Arguments containing model and training configurations.
            client_id (int): ID of the client.
            data (object): Data specific to the client's task.
            data_dir (str): Directory containing the data.
            message_pool (object): Pool for managing messages between client and server.
            device (torch.device): Device to run the computations on.
        """
        super().__init__(args, client_id, data, data_dir, message_pool, device)
        self.local_prototype = {}

    def execute(self):
        """
        Executes the local training process.

        Installs the prototype-regularized loss on the task, trains locally, and then
        recomputes the local prototypes from the trained model.
        """
        self.task.loss_fn = self.get_custom_loss_fn()
        self.task.train()
        self.update_local_prototype()

    def get_custom_loss_fn(self):
        """
        Builds the loss used during local training.

        The loss is the task's default classification loss plus, from round 1 onward,
        ``config["fedtgp_lambda"]`` times the sum over classes of the MSE between this
        client's training-node embeddings of that class and the server's global prototype
        for that class.

        Returns:
            custom_loss_fn (function): ``(embedding, logits, label, mask) -> loss``. It falls
                back to the plain default loss in round 0 or when ``label`` does not have
                ``self.task.num_samples`` entries (evaluation on global data).
        """
        # MSELoss carries no state, so a single instance can serve every class and call.
        mse = nn.MSELoss()

        def custom_loss_fn(embedding, logits, label, mask):
            base_loss = self.task.default_loss_fn(logits[mask], label[mask])
            # Round 0 has no global prototypes yet; a label tensor whose length differs
            # from the local sample count means we are evaluating on global data.
            if self.message_pool["round"] == 0 or self.task.num_samples != label.shape[0]:
                return base_loss

            global_prototype = self.message_pool["server"]["global_prototype"]
            loss_fedtgp = 0
            for class_i in range(self.task.num_global_classes):
                # NOTE(review): the prototype term always selects nodes via train_mask,
                # not the `mask` argument; confirm this is intended when the loss is
                # computed with a validation/test mask.
                selected_idx = self.task.train_mask & (label == class_i)
                if not selected_idx.any():
                    continue
                class_embedding = embedding[selected_idx]
                target = global_prototype[class_i].expand_as(class_embedding)
                loss_fedtgp += mse(class_embedding, target)
            return base_loss + config["fedtgp_lambda"] * loss_fedtgp

        return custom_loss_fn

    def update_local_prototype(self):
        """
        Recomputes the local prototype of every global class from the trained model.

        For each class, the prototype is the mean embedding (from ``self.task.evaluate``)
        of this client's training nodes with that label. Classes with no training nodes
        get a zero vector of length ``self.args.hid_dim`` on ``self.device``.
        """
        with torch.no_grad():
            embedding = self.task.evaluate(mute=True)["embedding"]
            # Transfer the labels once instead of once per class.
            labels = self.task.data.y.to(self.device)
            for class_i in range(self.task.num_global_classes):
                selected_idx = self.task.train_mask & (labels == class_i)
                if selected_idx.any():
                    self.local_prototype[class_i] = torch.mean(embedding[selected_idx], dim=0)
                else:
                    # Allocate directly on the target device (no CPU alloc + copy).
                    self.local_prototype[class_i] = torch.zeros(self.args.hid_dim, device=self.device)

    def send_message(self):
        """
        Sends a message to the server containing the number of samples and the updated
        local prototypes.
        """
        self.message_pool[f"client_{self.client_id}"] = {
            "num_samples": self.task.num_samples,
            "local_prototype": self.local_prototype,
        }