import torch
import torch.nn.functional as F
from openfgl.flcore.base import BaseServer
from openfgl.flcore.feddep.localdep import Classifier_F
[docs]class FedDEPEServer(BaseServer):
"""
FedDEPEServer is a server implementation for the Federated Learning algorithm with Deep
Efficient Private Neighbor Generation for Subgraph Federated Learning (FedDEP). This
server manages the aggregation of model parameters from multiple clients and oversees
the global model updates in a federated learning environment.
Attributes:
None (inherits attributes from BaseServer)
"""
[docs] def __init__(self, args, global_data, data_dir, message_pool, device):
"""
Initializes the FedDEPEServer.
Attributes:
args (Namespace): Arguments containing model and training configurations.
global_data (object): Global dataset accessible by the server.
data_dir (str): Directory containing the data.
message_pool (object): Pool for managing messages between server and clients.
device (torch.device): Device to run the computations on.
"""
super(FedDEPEServer, self).__init__(args, global_data, data_dir, message_pool, device)
self.task.load_custom_model(Classifier_F(
input_dim=(self.task.num_feats, self.args.hid_dim),
hid_dim=self.args.hid_dim, output_dim=self.task.num_global_classes,
num_layers=self.args.num_layers, dropout=self.args.dropout))
self.task.loss_fn = F.cross_entropy
self.task.override_evaluate = self.get_override_evaluate()
[docs] def execute(self):
"""
Executes the server-side operations. If it's not the initial round, this method
aggregates the model parameters received from sampled clients by computing their
weighted average to update the global model.
"""
if self.message_pool["round"] == 0:
pass
else:
with torch.no_grad():
for it, client_id in enumerate(self.message_pool["sampled_clients"]):
weight = 1 / len(self.message_pool["sampled_clients"])
for local_param, global_param in zip(
self.message_pool[f"client_{client_id}"]["weight"],
self.task.model.parameters()):
if it == 0:
global_param.data.copy_(weight * local_param)
else:
global_param.data += weight * local_param
[docs] def send_message(self):
"""
Sends a message to the clients containing the updated global model parameters
after aggregation.
"""
self.message_pool["server"] = {"weight": list(self.task.model.parameters())}
[docs] def get_override_evaluate(self):
"""
Overrides the default evaluation method. This method evaluates the global model
on the training, validation, and test datasets using the specified evaluation
metrics.
Returns:
function: A custom evaluation function.
"""
from openfgl.utils.metrics import compute_supervised_metrics
def override_evaluate(splitted_data=None, mute=False):
"""
Evaluates the model on the provided dataset splits (or the default splits) and
computes relevant metrics. Outputs evaluation information unless muted.
Args:
splitted_data (dict, optional): The dataset splits to evaluate on. Defaults to None.
mute (bool, optional): If True, suppresses the print output. Defaults to False.
Returns:
dict: Evaluation output containing losses and metrics for training, validation, and test datasets.
"""
if splitted_data is None:
splitted_data = self.task.splitted_data
else:
names = ["train_mask", "val_mask", "test_mask"]
for name in names:
assert name in splitted_data
splitted_data["data"] = splitted_data["data"].to(self.device)
eval_output = {}
self.task.model.eval()
with torch.no_grad():
logits = self.task.model.forward(splitted_data["data"])
loss_train = self.task.loss_fn(logits[splitted_data["train_mask"]], splitted_data["data"].y[splitted_data["train_mask"]])
loss_val = self.task.loss_fn(logits[splitted_data["val_mask"]], splitted_data["data"].y[splitted_data["val_mask"]])
loss_test = self.task.loss_fn(logits[splitted_data["test_mask"]], splitted_data["data"].y[splitted_data["test_mask"]])
eval_output["loss_train"] = loss_train
eval_output["loss_val"] = loss_val
eval_output["loss_test"] = loss_test
metric_train = compute_supervised_metrics(
metrics=self.args.metrics,
logits=logits[splitted_data["train_mask"]],
labels=splitted_data["data"].y[splitted_data["train_mask"]],
suffix="train"
)
metric_val = compute_supervised_metrics(
metrics=self.args.metrics,
logits=logits[splitted_data["val_mask"]],
labels=splitted_data["data"].y[splitted_data["val_mask"]],
suffix="val"
)
metric_test = compute_supervised_metrics(
metrics=self.args.metrics,
logits=logits[splitted_data["test_mask"]],
labels=splitted_data["data"].y[splitted_data["test_mask"]],
suffix="test"
)
eval_output = {**eval_output, **metric_train, **metric_val, **metric_test}
info = ""
for key, val in eval_output.items():
try:
info += f"\t{key}: {val:.4f}"
except:
continue
prefix = "[server]"
if not mute:
print(prefix + info)
return eval_output
return override_evaluate