# Source code for openfgl.flcore.fedstar.client

import torch
import torch.nn as nn
from openfgl.flcore.base import BaseClient
from openfgl.flcore.fedstar._utils import init_structure_encoding
from openfgl.flcore.fedstar.gin_dc import DecoupledGIN
from torch_geometric.loader import DataLoader
from openfgl.flcore.fedstar.fedstar_config import config


class FedStarClient(BaseClient):
    """
    FedStarClient is the client-side implementation for the Federated Learning algorithm
    described in the paper 'Federated Learning on Non-IID Graphs via Structural Knowledge
    Sharing'. This class handles local training, structural knowledge sharing, and
    communication with the server within a federated learning framework.

    Attributes:
        task (object): The task object that holds the model and data for training.
        device (torch.device): The device on which computations will be performed.
    """

    def __init__(self, args, client_id, data, data_dir, message_pool, device):
        """
        Initializes the FedStarClient.

        Installs a DecoupledGIN model on the task, replaces the task data with the
        output of ``init_structure_encoding``, and builds train/val/test DataLoaders
        from the task's split masks.

        Args:
            args (Namespace): Arguments containing model and training configurations
                (``hid_dim``, ``num_layers``, ``dropout`` and ``batch_size`` are read here).
            client_id (int): ID of the client.
            data (object): The graph data specific to the client's task.
            data_dir (str): Directory containing the data.
            message_pool (object): Pool for managing messages between client and server.
            device (torch.device): Device to run the computations on.
        """
        super().__init__(args, client_id, data, data_dir, message_pool, device, personalized=True)

        # Width of the structure encoding fed to the model; n_rw / n_dg are presumably
        # the random-walk and degree encoding sizes (see fedstar_config) — confirm.
        n_se = config["n_rw"] + config["n_dg"]
        model = DecoupledGIN(
            input_dim=self.task.num_feats,
            hid_dim=self.args.hid_dim,
            output_dim=self.task.num_global_classes,
            n_se=n_se,
            num_layers=self.args.num_layers,
            dropout=self.args.dropout,
        )
        self.task.load_custom_model(model.to(self.device))

        self.task.data = init_structure_encoding(
            config["n_rw"], config["n_dg"], self.task.data, config["type_init"]
        )

        # One loader per split, built from the items of self.task.data selected by each mask.
        splits = (
            ("train_dataloader", self.task.train_mask),
            ("val_dataloader", self.task.val_mask),
            ("test_dataloader", self.task.test_mask),
        )
        for key, mask in splits:
            self.task.splitted_data[key] = self._make_split_loader(mask)

    def _make_split_loader(self, mask):
        """Return a non-shuffling DataLoader over the items of ``self.task.data`` where ``mask`` is set."""
        indices = torch.nonzero(mask, as_tuple=True)[0].tolist()
        items = [self.task.data[i] for i in indices]
        # NOTE(review): the train loader is also built with shuffle=False — confirm intended.
        return DataLoader(items, batch_size=self.args.batch_size, shuffle=False)

    def get_custom_loss_fn(self):
        """
        Returns a custom loss function for the FedStarClient.

        The returned function applies negative log-likelihood loss (``nll_loss``) to the
        masked logits and labels.

        Returns:
            function: A function ``(embedding, logits, label, mask) -> loss``; the
            ``embedding`` argument is accepted but not used.
        """
        def custom_loss_fn(embedding, logits, label, mask):
            # nll_loss expects log-probabilities, so this assumes the model outputs
            # log-softmax scores — TODO confirm against DecoupledGIN.
            return torch.nn.functional.nll_loss(logits[mask], label[mask])

        return custom_loss_fn

    def execute(self):
        """
        Executes the local training process.

        Before training, the structural knowledge (state-dict entries with '_s' in their
        names) is overwritten in place with the global weights received from the server.

        Raises:
            KeyError: If the server weights lack an '_s' entry present in the local model.
        """
        global_weights = self.message_pool["server"]["weight"]
        with torch.no_grad():
            for name, local_tensor in self.task.model.state_dict().items():
                if '_s' in name:
                    # state_dict() yields detached aliases that share storage with the
                    # model's parameters/buffers. Rebinding `.data` on such an alias would
                    # leave the model untouched; copy_ writes into the shared storage
                    # (and, being a copy, does not alias the server's tensor).
                    local_tensor.copy_(global_weights[name])

        self.task.loss_fn = self.get_custom_loss_fn()
        self.task.train()

    def send_message(self):
        """
        Sends a message to the server containing the local model's state_dict and the
        number of samples used for training.
        """
        # NOTE(review): state_dict() tensors alias the live model storage, so the server
        # presumably consumes them before this client's next local update — confirm.
        self.message_pool[f"client_{self.client_id}"] = {
            "num_samples": self.task.num_samples,
            "weight": self.task.model.state_dict(),
        }