# Source code for openfgl.flcore.adafgl.client

import torch
import torch.nn as nn
from openfgl.flcore.base import BaseClient
from openfgl.flcore.adafgl.adafgl_models import AdaFGLModel
from openfgl.flcore.adafgl.adafgl_config import config
from openfgl.flcore.adafgl._utils import adj_initialize
from torch.optim import Adam
import torch.nn.functional as F


class AdaFGLClient(BaseClient):
    """Client-side logic for federated learning with the AdaFGL model.

    Implements the client described in the paper "AdaFGL: A New Paradigm for
    Federated Node Classification with Topology Heterogeneity". It extends
    ``BaseClient`` with a second training phase in which a topology-aware
    ``AdaFGLModel`` is fitted on the client's own task data.

    Attributes:
        phase (int): Current training phase. Starts at 0 and is set to 1 when
            the AdaFGL phase begins.
        adafgl_model (AdaFGLModel): Created by ``execute`` when the phase
            switches; the attribute does not exist before that.
        adafgl_optimizer (torch.optim.Adam): Optimizer over the parameters of
            ``adafgl_model``; created together with the model.
    """

    def __init__(self, args, client_id, data, data_dir, message_pool, device):
        """Initialize the client in phase 0.

        Args:
            args (Namespace): Arguments containing model and training
                configurations.
            client_id (int): ID of the client.
            data (object): Data specific to the client's task.
            data_dir (str): Directory containing the data.
            message_pool (object): Pool for managing messages between client
                and server.
            device (torch.device): Device to run the computations on.
        """
        super().__init__(args, client_id, data, data_dir, message_pool, device, personalized=True)
        self.phase = 0

    def _load_global_weights(self):
        """Copy the server weights from the message pool into the local task model."""
        server_weights = self.message_pool["server"]["weight"]
        with torch.no_grad():
            for local_param, global_param in zip(self.task.model.parameters(), server_weights):
                local_param.data.copy_(global_param)

    def _enter_adafgl_phase(self):
        """Switch to phase 1 and build the AdaFGL model, optimizer and evaluation hook."""
        self.phase = 1
        self.adafgl_model = AdaFGLModel(
            prop_steps=config["prop_steps"],
            feat_dim=self.task.num_feats,
            hidden_dim=self.args.hid_dim,
            output_dim=self.task.num_global_classes,
            train_mask=self.task.train_mask,
            val_mask=self.task.val_mask,
            test_mask=self.task.test_mask,
            r=config["r"],
        )
        # ``self.task.data.adj`` is read further down, after this call.
        self.task.data = adj_initialize(self.task.data)

        # Download the global model one last time; its predictions seed the
        # AdaFGL model below.
        self._load_global_weights()

        # Initialize the AdaFGL model from the soft labels of the task model.
        # This must run before ``override_evaluate`` is installed so that the
        # task's default evaluation is the one being called here.
        eval_output = self.task.evaluate()
        node_logits = eval_output["logits"]
        soft_label = nn.Softmax(dim=1)(node_logits)
        self.adafgl_model.non_para_lp(
            subgraph=self.task.data,
            soft_label=soft_label,
            x=self.task.data.x,
            device=self.device,
        )
        self.adafgl_model.preprocess(adj=self.task.data.adj)
        self.adafgl_model = self.adafgl_model.to(self.device)

        self.adafgl_optimizer = Adam(
            self.adafgl_model.parameters(),
            lr=config["adafgl_lr"],
            weight_decay=config["adafgl_weight_decay"],
        )

        # From now on the task evaluates through the AdaFGL model.
        self.task.override_evaluate = self.get_adafgl_override_evaluate()

    def execute(self):
        """Run one round of local work for the current phase.

        When the round stored in the message pool equals
        ``config["num_rounds_vanilla"]`` the client first switches to phase 1
        and sets up the AdaFGL model. Afterwards, phase 0 loads the server
        weights and trains the task model, while phase 1 performs one AdaFGL
        optimization step via ``adafgl_postprocess``.
        """
        # NOTE(review): the switch only fires on an exact round match, so it
        # relies on ``execute`` being called in that round - confirm against
        # the trainer's client sampling.
        if self.message_pool["round"] == config["num_rounds_vanilla"]:
            self._enter_adafgl_phase()

        if self.phase == 0:
            self._load_global_weights()
            self.task.train()
        else:
            self.adafgl_postprocess()

    def send_message(self):
        """Publish this client's message to the message pool.

        In phase 0 the message holds the number of samples and the task
        model's parameters. In phase 1 an empty dict is written instead.
        """
        if self.phase == 0:
            self.message_pool[f"client_{self.client_id}"] = {
                "num_samples": self.task.num_samples,
                "weight": list(self.task.model.parameters()),
            }
        else:
            self.message_pool[f"client_{self.client_id}"] = {}

    def adafgl_postprocess(self, loss_ce_fn=nn.CrossEntropyLoss()):
        """Perform one optimization step of the AdaFGL model.

        Sums the losses of the homogeneous and heterogeneous forward passes,
        backpropagates the total and steps ``adafgl_optimizer``.

        Args:
            loss_ce_fn: Classification loss applied to the training nodes.
                Defaults to ``nn.CrossEntropyLoss()``.
        """
        self.adafgl_model.train()
        self.adafgl_optimizer.zero_grad()

        train_mask = self.task.train_mask
        train_labels = self.task.data.y[train_mask]

        # Homogeneous branch: supervised loss on the smoothed logits plus an
        # MSE term between the smoothed logits and the global logits.
        local_smooth_logits, global_logits = self.adafgl_model.homo_forward(self.device)
        loss_train1 = loss_ce_fn(local_smooth_logits[train_mask], train_labels)
        loss_train2 = nn.MSELoss()(local_smooth_logits, global_logits)
        loss_train_homo = loss_train1 + loss_train2

        # Heterogeneous branch: supervised loss on each of its three outputs.
        local_ori_logits, local_smooth_logits, local_message_propagation = \
            self.adafgl_model.hete_forward(self.device)
        loss_train1 = loss_ce_fn(local_ori_logits[train_mask], train_labels)
        loss_train2 = loss_ce_fn(local_smooth_logits[train_mask], train_labels)
        loss_train3 = loss_ce_fn(local_message_propagation[train_mask], train_labels)
        loss_train_hete = loss_train1 + loss_train2 + loss_train3

        loss_final = loss_train_homo + loss_train_hete
        loss_final.backward()
        self.adafgl_optimizer.step()

    def get_adafgl_override_evaluate(self):
        """Build the evaluation function used once the AdaFGL phase is active.

        Returns:
            function: ``override_evaluate(splitted_data=None, mute=False)``,
            which evaluates ``adafgl_model`` on ``self.task.splitted_data``
            and returns a dict with ``embedding``, ``logits``, the
            train/val/test losses and the entries produced by
            ``compute_supervised_metrics`` for each split.
        """
        # Imported here, as in the original code, rather than at module level.
        from openfgl.utils.metrics import compute_supervised_metrics

        def override_evaluate(splitted_data=None, mute=False):
            assert splitted_data is None, "AdaFGL doesn't support global data evaluation."
            splitted_data = self.task.splitted_data
            labels = splitted_data["data"].y
            splits = ("train", "val", "test")

            self.adafgl_model.eval()

            with torch.no_grad():
                # Homogeneous prediction: mean of the two softmax outputs.
                local_smooth_logits, global_logits = self.adafgl_model.homo_forward(self.device)
                output_homo = (F.softmax(local_smooth_logits.data, 1)
                               + F.softmax(global_logits.data, 1)) / 2

                # Heterogeneous prediction: mean of the three softmax outputs.
                local_ori_logits, local_smooth_logits, local_message_propagation = \
                    self.adafgl_model.hete_forward(self.device)
                output_hete = (F.softmax(local_ori_logits.data, 1)
                               + F.softmax(local_smooth_logits.data, 1)
                               + F.softmax(local_message_propagation.data, 1)) / 3

                # Blend both predictions, weighted by the model's
                # ``reliability_acc``.
                homo_weight = self.adafgl_model.reliability_acc
                embedding = None
                logits = homo_weight * output_homo + (1 - homo_weight) * output_hete

                losses = {
                    f"loss_{split}": self.task.loss_fn(
                        embedding, logits, labels, splitted_data[f"{split}_mask"])
                    for split in splits
                }

            eval_output = {"embedding": embedding, "logits": logits, **losses}
            for split in splits:
                mask = splitted_data[f"{split}_mask"]
                eval_output.update(compute_supervised_metrics(
                    metrics=self.args.metrics,
                    logits=logits[mask],
                    labels=labels[mask],
                    suffix=split,
                ))

            info = ""
            for key, val in eval_output.items():
                try:
                    info += f"\t{key}: {val:.4f}"
                except (TypeError, ValueError):
                    # Entries that are not scalar numbers (``embedding`` is
                    # None, ``logits`` is a matrix) cannot take a float format
                    # spec; they are left out of the log line.
                    continue

            prefix = f"[client {self.client_id}]" if self.client_id is not None else "[server]"
            if not mute:
                print(prefix + info)
            return eval_output

        return override_evaluate