import torch
from openfgl.flcore.base import BaseServer
from openfgl.flcore.fedtad.fedtad_config import config
from torch.optim import Adam
from openfgl.flcore.fedtad.generator import FedTAD_ConGenerator
import torch.nn.functional as F
from openfgl.flcore.fedtad._utils import construct_graph, DiversityLoss
import torch.nn as nn
class FedTADServer(BaseServer):
    """
    Server side of FedTAD ('FedTAD: Topology-aware Data-free Knowledge
    Distillation for Subgraph Federated Learning').

    Each round the server (1) averages the client model weights into the global
    model, weighted by client sample counts, and (2) alternates between training
    a class-conditional generator that produces pseudo graphs and distilling the
    client models into the global model on those pseudo graphs.

    Attributes:
        generator (FedTAD_ConGenerator): Generator producing pseudo node
            features from noise and class labels.
        generator_optimizer (torch.optim.Optimizer): Adam optimizer for the
            generator.
    """

    def __init__(self, args, global_data, data_dir, message_pool, device):
        """
        Initializes the FedTADServer.

        Replaces the task optimizer with Adam at learning rate ``config["lr_d"]``
        and builds the generator together with its own Adam optimizer
        (``config["lr_g"]``).

        Args:
            args (Namespace): Model and training configuration; ``hid_dim`` and
                ``weight_decay`` are read here.
            global_data (object): Global dataset accessible to the server.
            data_dir (str): Directory containing the data.
            message_pool (object): Pool for exchanging messages between clients
                and server.
            device (torch.device): Device to run the computations on.
        """
        super(FedTADServer, self).__init__(args, global_data, data_dir, message_pool, device)
        self.task.optim = Adam(
            self.task.model.parameters(),
            lr=config["lr_d"],
            weight_decay=self.args.weight_decay,
        )

        # The generator emits hidden representations in 'rep_distill' mode and
        # raw input features otherwise.
        feat_dim = args.hid_dim if config["distill_mode"] == 'rep_distill' else self.task.num_feats
        self.generator = FedTAD_ConGenerator(
            noise_dim=config["noise_dim"],
            feat_dim=feat_dim,
            out_dim=self.task.num_global_classes,
            dropout=config["gen_dropout"],
        ).to(device)
        self.generator_optimizer = Adam(
            self.generator.parameters(),
            lr=config["lr_g"],
            weight_decay=self.args.weight_decay,
        )

    def execute(self):
        """
        Runs one server round: weighted aggregation followed by data-free
        knowledge distillation.

        For ``config["glb_epochs"]`` epochs, a fresh noise batch is sampled, the
        generator is trained for ``config["it_g"]`` steps, and the global model
        is then trained for ``config["it_d"]`` steps to match the client models
        on the generated pseudo graph. Per-class terms are weighted by each
        client's ``ckr`` entry (presumably the class-wise knowledge reliability
        from the paper -- confirm against the client implementation).

        Returns immediately, leaving the global model untouched, when no client
        was sampled.
        """
        sampled_clients = self.message_pool["sampled_clients"]
        # len() rather than truthiness so an array-like container also works.
        if len(sampled_clients) == 0:
            return

        self._aggregate_client_weights(sampled_clients)

        c, each_class_mask = self._build_class_labels()

        # Client models only act as frozen teachers on the server.
        for client_id in sampled_clients:
            self.message_pool[f"client_{client_id}"]["model"].eval()

        for _ in range(config["glb_epochs"]):
            # The noise width must match the noise_dim the generator was built
            # with. Sampling happens on the default device and is then moved, so
            # the random stream is the same as before for a given seed.
            z = torch.randn((config["num_gen"], config["noise_dim"])).to(self.device)

            self._train_generator(sampled_clients, z, c, each_class_mask)
            self._train_global_model(sampled_clients, z, c, each_class_mask)

    def _aggregate_client_weights(self, sampled_clients):
        """Overwrites the global parameters with the sample-weighted mean of the client parameters."""
        with torch.no_grad():
            num_tot_samples = sum(
                self.message_pool[f"client_{client_id}"]["num_samples"]
                for client_id in sampled_clients
            )
            for it, client_id in enumerate(sampled_clients):
                client_msg = self.message_pool[f"client_{client_id}"]
                weight = client_msg["num_samples"] / num_tot_samples
                for local_param, global_param in zip(client_msg["weight"], self.task.model.parameters()):
                    if it == 0:
                        # The first client overwrites the previous global value.
                        global_param.data.copy_(weight * local_param)
                    else:
                        global_param.data += weight * local_param

    def _build_class_labels(self):
        """
        Builds the class labels of the generated nodes and one boolean mask per class.

        ``config["num_gen"]`` nodes are split evenly over the classes, with the
        remainder going to the last class.

        Returns:
            tuple: ``(c, each_class_mask)`` where ``c`` is a long tensor of
            length ``num_gen`` and ``each_class_mask`` maps a class index to its
            boolean mask over ``c``. Classes that received no node are omitted,
            because reducing over an empty selection would produce NaN losses.
        """
        num_classes = self.task.num_global_classes
        num_gen = config["num_gen"]

        c_cnt = [num_gen // num_classes] * num_classes
        c_cnt[-1] += num_gen - sum(c_cnt)

        labels = [class_i for class_i, cnt in enumerate(c_cnt) for _ in range(cnt)]
        c = torch.tensor(labels, dtype=torch.long, device=self.device)

        each_class_mask = {
            class_i: c == class_i
            for class_i, cnt in enumerate(c_cnt)
            if cnt > 0
        }
        return c, each_class_mask

    def _generate_pseudo_graph(self, z, c):
        """Runs the generator and returns ``(node_logits, pseudo_graph)`` built from top-k cosine similarity."""
        node_logits = self.generator.forward(z=z, c=c)
        node_norm = F.normalize(node_logits, p=2, dim=1)
        # Inner products of L2-normalised rows, i.e. pairwise cosine similarity.
        adj_logits = torch.mm(node_norm, node_norm.t())
        pseudo_graph = construct_graph(node_logits, adj_logits, k=config["topk"])
        return node_logits, pseudo_graph

    @staticmethod
    def _select_pred(embedding, logits):
        """Picks the tensor to distill on: embeddings in 'rep_distill' mode, logits otherwise."""
        if config["distill_mode"] == 'rep_distill':
            return embedding
        return logits

    def _train_generator(self, sampled_clients, z, c, each_class_mask):
        """Trains the generator for ``config["it_g"]`` steps against the frozen client and global models."""
        self.generator.train()
        self.task.model.eval()

        for _ in range(config["it_g"]):
            loss_sem = 0
            loss_diverg = 0
            loss_div = 0
            self.generator_optimizer.zero_grad()

            for client_id in sampled_clients:
                client_msg = self.message_pool[f"client_{client_id}"]
                ckr = client_msg["ckr"]

                # Regenerated per client, as before: the generator is in train
                # mode, so each forward pass may differ (e.g. through dropout).
                node_logits, pseudo_graph = self._generate_pseudo_graph(z, c)

                local_pred = self._select_pred(*client_msg["model"].forward(pseudo_graph))
                global_pred = self._select_pred(*self.task.model.forward(pseudo_graph))

                # Semantic loss: the client model should assign the generated
                # nodes to the class they were conditioned on.
                for class_i, mask in each_class_mask.items():
                    loss_sem += ckr[class_i] * F.cross_entropy(local_pred[mask], c[mask])

                # Diversity loss between the noise batch and the generated features.
                loss_div += DiversityLoss(metric='l1').to(self.device)(z.view(z.shape[0], -1), node_logits)

                # Divergence loss, negated: the generator is rewarded where the
                # global model disagrees with the client model.
                for class_i, mask in each_class_mask.items():
                    loss_diverg += -ckr[class_i] * torch.mean(torch.mean(
                        torch.abs(global_pred[mask] - local_pred[mask].detach()), dim=1))

            loss_G = config["lam1"] * loss_sem + loss_diverg + config["lam2"] * loss_div
            loss_G.backward()
            self.generator_optimizer.step()

    def _train_global_model(self, sampled_clients, z, c, each_class_mask):
        """Distills the client models into the global model for ``config["it_d"]`` steps on one fixed pseudo graph."""
        self.generator.eval()
        self.task.model.train()

        # The graph is a constant input here, so no autograd graph is needed.
        with torch.no_grad():
            _, pseudo_graph = self._generate_pseudo_graph(z, c)

        for _ in range(config["it_d"]):
            self.task.optim.zero_grad()
            loss_D = 0

            for client_id in sampled_clients:
                client_msg = self.message_pool[f"client_{client_id}"]
                ckr = client_msg["ckr"]

                # The client model is a fixed teacher: only the global model is
                # optimised, so its forward pass does not need gradients.
                with torch.no_grad():
                    local_pred = self._select_pred(*client_msg["model"].forward(pseudo_graph))
                global_pred = self._select_pred(*self.task.model.forward(pseudo_graph))

                # Divergence loss: mean absolute difference per class.
                for class_i, mask in each_class_mask.items():
                    loss_D += ckr[class_i] * torch.mean(torch.mean(
                        torch.abs(global_pred[mask] - local_pred[mask]), dim=1))

            loss_D.backward()
            self.task.optim.step()

    def send_message(self):
        """
        Publishes the global model parameters for the clients.

        Stores the parameters of the global model, as left by aggregation and
        distillation, under ``message_pool["server"]["weight"]``.
        """
        self.message_pool["server"] = {
            "weight": list(self.task.model.parameters())
        }