# Source code for openfgl.flcore.adafgl.server

import torch
from openfgl.flcore.base import BaseServer
from openfgl.flcore.adafgl.adafgl_config import config
from openfgl.flcore.adafgl._utils import adj_initialize
from scipy import sparse as sp 



        
        
        
class AdaFGLServer(BaseServer):
    """
    AdaFGLServer implements the server-side logic for federated learning using the AdaFGL model,
    as described in the paper "AdaFGL: A New Paradigm for Federated Node Classification with
    Topology Heterogeneity". It extends the BaseServer class by aggregating model updates from
    clients and coordinating the training process across different phases, particularly handling
    topology heterogeneity.

    Attributes:
        phase (int): Current phase of the server's operations. It starts at 0 (vanilla
            sample-weighted averaging of client weights) and switches to 1 (AdaFGL phase, in
            which the server does no aggregation) once the round counter reaches
            ``config["num_rounds_vanilla"]``. Nothing in this class switches it back to 0.
    """

    def __init__(self, args, global_data, data_dir, message_pool, device):
        """
        Initializes the AdaFGLServer.

        Args:
            args (Namespace): Arguments containing model and training configurations.
            global_data (object): Global dataset accessible by the server.
            data_dir (str): Directory containing the data.
            message_pool (object): Pool for managing messages between server and clients.
            device (torch.device): Device to run the computations on.
        """
        # personalized=True is forwarded to BaseServer; its exact effect is defined there.
        super(AdaFGLServer, self).__init__(
            args, global_data, data_dir, message_pool, device, personalized=True
        )
        # 0 = vanilla weighted-averaging phase, 1 = AdaFGL phase (server idle).
        self.phase = 0

    def execute(self):
        """
        Executes the server-side operations for the current round.

        First switches to phase 1 once ``message_pool["round"]`` has reached
        ``config["num_rounds_vanilla"]``. While still in phase 0, overwrites the global
        model's parameters with the average of the sampled clients' uploaded ``"weight"``
        lists, each client weighted by its share of the total ``"num_samples"``. In
        phase 1 this method does nothing.

        Raises:
            ZeroDivisionError: if clients are sampled but their ``num_samples`` sum to 0.
        """
        # Switch phase. ``>=`` rather than ``==`` so the switch still happens when the round
        # counter is already past the threshold the first time it is checked (e.g. rounds
        # numbered from 1 with num_rounds_vanilla=0, or a resumed run). For rounds counted
        # 0, 1, 2, ... this is identical to an equality test.
        if self.message_pool["round"] >= config["num_rounds_vanilla"]:
            self.phase = 1

        if self.phase != 0:
            return  # AdaFGL phase: no server-side aggregation.

        sampled_clients = self.message_pool["sampled_clients"]
        with torch.no_grad():
            num_tot_samples = sum(
                self.message_pool[f"client_{client_id}"]["num_samples"]
                for client_id in sampled_clients
            )
            for it, client_id in enumerate(sampled_clients):
                client_msg = self.message_pool[f"client_{client_id}"]
                weight = client_msg["num_samples"] / num_tot_samples
                for local_param, global_param in zip(
                    client_msg["weight"], self.task.model.parameters()
                ):
                    if it == 0:
                        # The first client overwrites, so the previous global weights are
                        # discarded rather than mixed into the new average.
                        global_param.data.copy_(weight * local_param)
                    else:
                        global_param.data += weight * local_param

    def send_message(self):
        """
        Publishes the server's message for the current round in ``message_pool["server"]``.

        In phase 0 the message carries the global model's parameters under ``"weight"``.
        These are the live parameter tensors returned by ``parameters()``, not copies. In
        phase 1 an empty dict is published.
        """
        if self.phase == 0:
            self.message_pool["server"] = {"weight": list(self.task.model.parameters())}
        else:
            self.message_pool["server"] = {}