Source code for openfgl.flcore.fedtgp.server

import torch
from openfgl.flcore.base import BaseServer
import torch.nn as nn
import torch.nn.functional as F

class Trainable_prototypes(nn.Module):
    """
    Trainable_prototypes implements a neural network model that learns global prototypes 
    for each class in a federated learning setup. The model is used by the server in 
    FedTGP to generate prototypes that adapt to data and model heterogeneity across clients.

    Each class id is looked up in an embedding table, passed through a
    Linear + ReLU block and projected back to ``feature_dim`` by a final
    linear layer.

    Attributes:
        device (torch.device): The device on which the model is running (e.g., CPU or GPU).

        embedings (nn.Embedding): Embedding layer that stores a trainable vector for each class.
            The attribute keeps its historical spelling so that existing ``state_dict`` keys stay valid.

        middle (nn.Sequential): A sequence of layers that apply a linear transformation followed by a ReLU activation function.
        
        fc (nn.Linear): The final fully connected layer that maps the transformed embeddings to the desired feature dimension.
    """
    def __init__(self, num_classes, server_hidden_dim, feature_dim, device):
        """
        Builds the embedding table and the two-layer projection head.

        Args:
            num_classes (int): Number of classes, i.e. rows of the embedding table.
            server_hidden_dim (int): Width of the hidden layer.
            feature_dim (int): Size of both the class embedding and the produced prototype.
            device (torch.device): Device on which class ids are materialised in ``forward``.
        """
        super().__init__()

        self.device = device

        self.embedings = nn.Embedding(num_classes, feature_dim)
        # The Linear/ReLU pair is deliberately wrapped in a nested Sequential:
        # flattening it would rename the parameters ("middle.0.0.*").
        layers = [nn.Sequential(
            nn.Linear(feature_dim, server_hidden_dim), 
            nn.ReLU()
        )]
        self.middle = nn.Sequential(*layers)
        self.fc = nn.Linear(server_hidden_dim, feature_dim)

    def forward(self, class_id):
        """
        Generates the prototype(s) of the requested class id(s).

        Args:
            class_id (int | list[int] | torch.Tensor): One class id or a collection of class ids.

        Returns:
            torch.Tensor: Prototype(s) whose leading shape matches ``class_id`` and whose
            last dimension is ``feature_dim``.
        """
        # as_tensor accepts ints, sequences and tensors alike; unlike
        # torch.tensor it does not warn (or copy needlessly) when the caller
        # already passes a tensor.
        class_id = torch.as_tensor(class_id, device=self.device)

        emb = self.embedings(class_id)
        mid = self.middle(emb)
        z = self.fc(mid)
        return z
    
    
class FedTGPServer(BaseServer):
    """
    FedTGPServer implements the server-side operations for the Federated Learning algorithm
    described in the paper 'FedTGP: Trainable Global Prototypes with Adaptive-Margin-Enhanced
    Contrastive Learning for Data and Model Heterogeneity in Federated Learning'.
    This class handles the global aggregation of prototypes, the training of global prototypes,
    and communication with the clients.

    Attributes:
        fedtgp_lambda (float): Weight for the FedTGP loss component.
        num_glb_epochs (int): Number of global epochs for training the prototypes.
        lr_glb (float): Learning rate for the global prototype optimizer.
        trainable_prototypes (nn.Module): A trainable model for generating global prototypes.
        gp_optimizer (torch.optim.Optimizer): Optimizer for training the global prototypes.
        global_prototype (dict): Dictionary to store the global prototypes for each class.
    """

    # Upper bound applied to the adaptive margin.
    _MAX_MARGIN = 100
    # Floor applied to squared distances before the square root (see execute()).
    _SQ_DIST_EPS = 1e-12

    def __init__(self, args, global_data, data_dir, message_pool, device, fedtgp_lambda=1, num_glb_epochs=10, lr_glb=1e-2):
        """
        Initializes the FedTGPServer.

        Args:
            args (Namespace): Arguments containing model and training configurations.
                ``args.hid_dim`` and ``args.weight_decay`` are read here.
            global_data (object): The global dataset, if available.
            data_dir (str): Directory containing the data.
            message_pool (object): Pool for managing messages between server and clients.
            device (torch.device): Device to run the computations on.
            fedtgp_lambda (float): Weight for the FedTGP loss component, default is 1.
            num_glb_epochs (int): Number of global epochs for training the prototypes, default is 10.
            lr_glb (float): Learning rate for the global prototype optimizer, default is 1e-2.
        """
        super().__init__(args, global_data, data_dir, message_pool, device)
        self.fedtgp_lambda = fedtgp_lambda
        self.num_glb_epochs = num_glb_epochs
        self.lr_glb = lr_glb
        # NOTE(review): self.task is not set in this class; presumably BaseServer.__init__ provides it.
        # args.hid_dim is used for both the hidden width and the prototype dimension.
        self.trainable_prototypes = Trainable_prototypes(
            self.task.num_global_classes, args.hid_dim, args.hid_dim, device
        ).to(device)
        self.gp_optimizer = torch.optim.Adam(
            self.trainable_prototypes.parameters(), lr=lr_glb, weight_decay=args.weight_decay
        )
        self.global_prototype = {}

    def execute(self):
        """
        Executes the global aggregation and prototype training process.

        The method first gathers the local prototypes uploaded by the sampled clients,
        derives an adaptive margin from the per-class average prototypes, and then trains
        the global prototypes for ``num_glb_epochs`` steps with a margin-enhanced
        contrastive (cross-entropy over negative distances) objective. Finally the
        resulting prototype of every class is stored in ``self.global_prototype``.
        """
        num_classes = self.task.num_global_classes
        y_list, all_local_prototypes = self._gather_local_prototypes(num_classes)
        y = torch.tensor(y_list, dtype=torch.int64, device=self.device)

        margin = self._adaptive_margin(y_list, all_local_prototypes, num_classes)
        # The margin is added only to the distance towards the true class; the
        # mask does not change between epochs, so it is built once.
        margin_mask = F.one_hot(y, num_classes).to(self.device) * margin

        row_id = list(range(num_classes))
        for _ in range(self.num_glb_epochs):
            self.gp_optimizer.zero_grad()
            global_prototypes = self.trainable_prototypes(row_id)

            # Pairwise Euclidean distance via ||a||^2 - 2 a.b + ||b||^2.
            features_square = torch.sum(torch.pow(all_local_prototypes, 2), 1, keepdim=True)
            centers_square = torch.sum(torch.pow(global_prototypes, 2), 1, keepdim=True)
            features_into_centers = torch.matmul(all_local_prototypes, global_prototypes.T)
            dist = features_square - 2 * features_into_centers + centers_square.T
            # Rounding can push the expanded form slightly below zero (NaN after
            # sqrt) and the gradient of sqrt is infinite at exactly zero, so
            # keep the argument strictly positive.
            dist = torch.sqrt(torch.clamp(dist, min=self._SQ_DIST_EPS))

            dist = dist + margin_mask
            glb_loss = F.cross_entropy(-dist, y)
            glb_loss.backward()
            self.gp_optimizer.step()

        # Export the trained prototypes; no autograd graph is needed for them.
        with torch.no_grad():
            for class_i in range(num_classes):
                self.global_prototype[class_i] = self.trainable_prototypes(class_i).detach()

    def _gather_local_prototypes(self, num_classes):
        """Collects every sampled client's per-class prototype together with its class label.

        Returns:
            tuple: ``(y_list, prototypes)`` where ``y_list[i]`` is the class id of row ``i``
            of the stacked ``prototypes`` tensor.
        """
        y_list = []
        tensor_list = []
        for client_i in self.message_pool["sampled_clients"]:
            local_prototype = self.message_pool[f"client_{client_i}"]["local_prototype"]
            for class_i in range(num_classes):
                y_list.append(class_i)
                tensor_list.append(local_prototype[class_i])
        return y_list, torch.stack(tensor_list, dim=0)

    def _adaptive_margin(self, y_list, all_local_prototypes, num_classes):
        """Computes the adaptive margin: the largest nearest-neighbour gap between class means, capped."""
        avg_proto = torch.zeros((num_classes, all_local_prototypes.shape[1]), device=self.device)
        for label, proto in zip(y_list, all_local_prototypes):
            avg_proto[label] += proto
        for class_i in range(num_classes):
            # Normalise each class row by its own count only; dividing the whole
            # matrix here would rescale every class once per loop iteration.
            avg_proto[class_i] /= y_list.count(class_i)

        # gap[k] = distance from the mean of class k to the closest other class mean.
        gap = torch.full((num_classes,), 1e9, device=self.device)
        for k1 in range(num_classes):
            for k2 in range(k1):
                dis = torch.norm(avg_proto[k1] - avg_proto[k2], p=2)
                gap[k1] = torch.min(gap[k1], dis)
                gap[k2] = torch.min(gap[k2], dis)

        # Entries still at the sentinel value fall back to the smallest gap.
        min_gap = torch.min(gap)
        gap = torch.where(gap > 1e8, min_gap, gap)
        max_gap = torch.max(gap)
        return min(max_gap.item(), self._MAX_MARGIN)

    def send_message(self):
        """
        Sends a message to the clients containing the updated global prototypes.

        The message is published under the ``"server"`` key of the message pool.
        """
        self.message_pool["server"] = {
            "global_prototype": self.global_prototype
        }