import copy
import torch
import numpy as np
import torch.nn.functional as F
from torch_geometric.loader import NeighborSampler
from openfgl.flcore.base import BaseClient
from openfgl.flcore.feddep.localdep import LocalDGen, FedDEP, Classifier_F
from openfgl.flcore.feddep._utils import HideGraph, LocalRecLoss, FedRecLoss, GraphMender
from openfgl.flcore.feddep.feddep_config import config
class FedDEPClient(BaseClient):
    """
    FedDEPClient is a client implementation for the Federated Learning algorithm with Deep
    Efficient Private Neighbor Generation for Subgraph Federated Learning (FedDEP).
    It extends the BaseClient class and handles local training, private neighbor generation,
    and graph recovery tasks in a federated setting.

    Work is split into two phases selected from ``message_pool["round"]``:

    * Phase 0 (round 0): pre-train a ``LocalDGen`` generator on the hidden local graph, wrap it
      in ``FedDEP`` and refine it with gradients computed against the ``embedding`` /
      ``x_missing`` entries other sampled clients published in the message pool.
    * Phase 1 (every later round): build the mended graph with ``GraphMender`` the first time it
      is needed, then train the task classifier, starting from the server weights, on
      NeighborSampler mini-batches of that graph.

    Attributes:
        hide_graph_model (nn.Module): Model for generating private subgraphs by hiding parts of the input graph.
        data (torch_geometric.data.Data): The original dataset after applying splits for training, validation, and testing.
        hide_data (torch_geometric.data.Data): The dataset with hidden information for training the private neighbor generation model.
        emb (torch.Tensor): Embedding tensor generated by the hide_graph_model.
        x_missing (torch.Tensor): Tensor containing the missing features of the graph generated by the hide_graph_model.
        loss_fn_num (function): Smooth-L1 loss between predicted and true missing-neighbor counts.
        loss_fn_rec (function): ``LocalRecLoss`` reconstruction loss for generated neighbor embeddings.
        phase (int): 0 during generator training (round 0), 1 during classifier training.
        feddep_model (FedDEP): Neighbor generator produced in phase 0 (exists only after phase 0 ran).
        filled_data (dict): ``GraphMender`` output; indexed here by "data" and the split masks.
        fill_dataloader (dict): ``"data"`` is the mended graph on ``device``; ``"train"`` is a
            NeighborSampler (sizes [5, 5], batch size 64) over its training nodes.
    """

    def __init__(self, args, client_id, data, data_dir, message_pool, device):
        """
        Initializes the FedDEPClient.

        Installs a ``Classifier_F`` task model, builds the ``HideGraph`` model, attaches the split
        masks to ``self.data``, computes the hidden graph together with the embeddings and
        missing features that are shared with other clients, and publishes them immediately.

        Args:
            args (Namespace): Arguments containing model and training configurations.
            client_id (int): ID of the client.
            data (object): Data specific to the client's task.
            data_dir (str): Directory containing the data.
            message_pool (object): Pool for managing messages between client and server.
            device (torch.device): Device to run the computations on.
        """
        super().__init__(args, client_id, data, data_dir, message_pool, device)
        # input_dim is a (num_feats, hid_dim) pair; in phase 1 the model is fed (x, mend_emb) tuples.
        self.task.load_custom_model(Classifier_F(
            input_dim=(self.task.num_feats, self.args.hid_dim),
            hid_dim=self.args.hid_dim, output_dim=self.task.num_global_classes,
            num_layers=self.args.num_layers, dropout=self.args.dropout))
        self.hide_graph_model = HideGraph(
            encoder_hid_dim=self.args.hid_dim,
            encoder_output_dim=self.task.num_global_classes,
            encoder_num_layers=self.args.num_layers,
            hidden_portion=config["hide_portion"],
            num_preds=config["num_preds"],
            num_protos=config["num_protos"],
            device=device,
        )
        self.data = self.task.splitted_data["data"]
        self.data.train_mask = self.task.splitted_data["train_mask"]
        self.data.val_mask = self.task.splitted_data["val_mask"]
        self.data.test_mask = self.task.splitted_data["test_mask"]
        self.hide_data, self.emb, self.x_missing = self.hide_graph_model(data=self.data)
        self.loss_fn_num = F.smooth_l1_loss
        self.loss_fn_rec = LocalRecLoss
        self.task.loss_fn = F.cross_entropy
        self.task.override_evaluate = self.get_override_evaluate()
        # Publish right away so other clients can read this client's embeddings in round 0.
        self.send_message()

    def send_message(self):
        """
        Sends a message to the server containing the current model parameters, the embedding
        tensor, and the tensor of missing features after applying the hide_graph_model.
        """
        self.message_pool[f"client_{self.client_id}"] = {
            "num_samples": self.task.num_samples,
            "weight": list(self.task.model.parameters()),
            "embedding": self.emb,
            "x_missing": self.x_missing
        }

    def execute(self):
        """
        Executes the local training step for the current round.

        Round 0 runs phase 0: generator pre-training followed by FedDEP refinement. Every later
        round runs phase 1. The mended-graph dataloader is built on first use, and then the
        classifier is trained from the latest server weights.
        """
        # Derive the phase from the round each time. Otherwise a client that was not sampled in
        # round 1 would stay in phase 0 and keep re-running generator pre-training.
        self.phase = 0 if self.message_pool["round"] == 0 else 1
        if self.phase == 0:
            pre_train_model = self._pretrain_local_generator()
            self._train_feddep(pre_train_model)
            return
        if not hasattr(self, "fill_dataloader"):
            self._build_fill_dataloader()
        self._train_classifier()

    def _local_dgen_loss(self, pred_missing, pred_emb, nc_pred):
        """Weighted generator loss on the training nodes of ``hide_data``.

        Combines the missing-count loss (``beta_d``), the node-classification loss (``beta_c``)
        and the neighbor reconstruction loss (``beta_n``), with weights taken from ``config``.
        """
        train_mask = self.hide_data.train_mask
        # LocalRecLoss takes the ground-truth missing features as a list, one entry per train node.
        train_index = np.flatnonzero(train_mask.cpu().numpy())
        loss_num = self.loss_fn_num(
            pred_missing[train_mask],
            self.hide_data.num_missing[train_mask]
        )
        loss_rec = self.loss_fn_rec(
            pred_embs=pred_emb[train_mask],
            true_embs=[self.hide_data.x_missing[node] for node in train_index],
            pred_missing=pred_missing[train_mask],
            true_missing=self.hide_data.num_missing[train_mask],
            num_preds=config["num_preds"]
        )
        loss_clf = self.task.loss_fn(
            nc_pred[train_mask],
            self.hide_data.y[train_mask],
        )
        return config["beta_d"] * loss_num + config["beta_c"] * loss_clf + config["beta_n"] * loss_rec

    def _pretrain_local_generator(self):
        """Pre-train a fresh ``LocalDGen`` on the hidden local graph and return it."""
        pre_train_model = LocalDGen(
            input_dim=self.task.num_feats,
            emb_shape=self.args.hid_dim,
            output_dim=self.task.num_global_classes,
            hid_dim=self.args.hid_dim,
            gen_dim=config["gen_hidden"],
            dropout=self.args.dropout,
            num_preds=config["num_preds"],
        ).to(self.device)
        print(f"Client {self.client_id} pre-train start...")
        pre_train_model.train()
        pre_train_optim = self.task.default_optim(
            pre_train_model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
        # Move the hidden graph once; the FedDEP stage below reuses it on the same device.
        self.hide_data = self.hide_data.to(self.device)
        for epoch in range(config["pre_train_epochs"]):
            pred_missing, pred_emb, nc_pred = pre_train_model(self.hide_data)
            pre_train_loss = self._local_dgen_loss(pred_missing, pred_emb, nc_pred)
            pre_train_optim.zero_grad()
            pre_train_loss.backward()
            pre_train_optim.step()
            print(f"Client {self.client_id} local pre-train @Epoch {epoch}.")
        print(f"Client {self.client_id} pre-train finish!")
        return pre_train_model

    def _train_feddep(self, pre_train_model):
        """Wrap the pre-trained generator in ``FedDEP`` and refine it with cross-client gradients.

        Each epoch computes the summed gradients induced by the other sampled clients, then
        back-propagates the local loss (scaled by ``1 / num_clients``), adds the cross-client
        gradients on top, and takes one optimizer step.
        """
        self.feddep_model = FedDEP(pre_train_model).to(self.device)
        feddep_optim = self.task.default_optim(
            self.feddep_model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
        for _ in range(config["feddep_epochs"]):
            dep_grad = self._cross_client_gradients()
            self.feddep_model.train()
            pred_missing, pred_emb, nc_pred = self.feddep_model(self.hide_data)
            feddep_loss = self._local_dgen_loss(pred_missing, pred_emb, nc_pred)
            feddep_loss = feddep_loss.float() / self.args.num_clients
            feddep_optim.zero_grad()
            feddep_loss.backward()
            for name, param in self.feddep_model.named_parameters():
                extra = dep_grad.get(name)
                if extra is None:
                    # No other client produced a gradient for this parameter (or none was sampled).
                    continue
                if param.grad is None:
                    param.grad = extra
                else:
                    param.grad += extra
            feddep_optim.step()

    def _cross_client_gradients(self):
        """Sum the ``FedDEP`` gradients induced by every other sampled client's shared data.

        For each other sampled client, the model is reset to the same parameter snapshot and
        its gradients are cleared. The ``FedRecLoss`` against targets sampled from that client's
        ``embedding`` / ``x_missing`` is then back-propagated and the gradients are accumulated.
        The model is restored to the snapshot before returning.

        Returns:
            dict: Parameter name -> summed gradient tensor. Parameters that received no
            gradient are absent; the dict is empty when no other client was sampled.
        """
        dep_grad = {}
        para_backup = copy.deepcopy(self.feddep_model.state_dict())
        for client_id in self.message_pool["sampled_clients"]:
            if client_id == self.client_id:
                continue
            message = self.message_pool[f"client_{client_id}"]
            emb, x_missing = message["embedding"], message["x_missing"]
            self.feddep_model.load_state_dict(para_backup)
            self.feddep_model.train()
            # Clear stale gradients (previous client / pre-training) so each client's
            # contribution is counted exactly once.
            self.feddep_model.zero_grad()
            _, embedding = self.feddep_model.encoder_model(self.hide_data)
            pred_missing = self.feddep_model.reg_model(embedding)
            pred_embs = self.feddep_model.gen(embedding)
            emb_len = pred_embs.shape[-1] // config["num_preds"]
            global_target_emb = self._sample_global_targets(
                emb, x_missing, embedding.shape[0], emb_len)
            loss_emb = FedRecLoss(
                pred_embs=pred_embs,
                true_embs=global_target_emb,
                pred_missing=pred_missing,
                num_preds=config["num_preds"],
            )
            # NOTE(review): if FedRecLoss returns a tensor detached from the model,
            # requires_grad_() turns it into a leaf and no parameter gradients result — confirm.
            other_loss = (
                1.0 / self.args.num_clients * config["beta_n"] * loss_emb
            ).requires_grad_()
            other_loss.backward()
            for name, param in self.feddep_model.named_parameters():
                if param.grad is None:
                    continue
                if name in dep_grad:
                    dep_grad[name] += param.grad
                else:
                    # Clone: storing param.grad itself would alias the live buffer.
                    dep_grad[name] = param.grad.detach().clone()
        # Rollback
        self.feddep_model.load_state_dict(para_backup)
        return dep_grad

    @staticmethod
    def _sample_global_targets(emb, x_missing, num_nodes, emb_len):
        """Sample ``config["num_preds"]`` target neighbor embeddings per local node.

        For every local node, one node ``c_i`` of the other client is drawn with
        ``np.random.choice`` (with replacement), then ``num_preds`` of its ``x_missing`` slots.
        A slot whose sum is below 1e-15 (presumably zero padding) is replaced by ``emb[c_i]``.

        Returns:
            np.ndarray: Targets reshaped to ``(num_nodes, num_preds, emb_len)``.
        """
        num_preds = config["num_preds"]
        targets = []
        for c_i in np.random.choice(len(x_missing), num_nodes):
            for ch_i in np.random.choice(len(x_missing[c_i]), num_preds):
                candidate = x_missing[c_i][ch_i]
                source = emb[c_i] if torch.sum(candidate) < 1e-15 else candidate
                # Convert both branches identically; a CUDA / grad-tracking tensor would
                # otherwise make np.asarray fail.
                if isinstance(source, torch.Tensor):
                    source = source.detach().cpu().numpy()
                targets.append(source)
        return np.asarray(targets).reshape((num_nodes, num_preds, emb_len))

    def _build_fill_dataloader(self):
        """Mend the hidden graph with the trained FedDEP model and build the phase-1 loader.

        Raises:
            RuntimeError: If phase 0 never ran on this client, so no ``feddep_model`` exists.
        """
        if not hasattr(self, "feddep_model"):
            raise RuntimeError(
                f"Client {self.client_id} has no trained FedDEP model; "
                "it must take part in round 0 before classifier training.")
        self.filled_data = GraphMender(
            model=self.feddep_model, impaired_data=self.hide_data,
            original_data=self.data, num_preds=config["num_preds"])
        self.filled_data["data"] = self.filled_data["data"].to(self.device)
        train_index = self.filled_data["train_mask"].nonzero(as_tuple=True)[0]
        self.fill_dataloader = {
            "data": self.filled_data["data"],
            "train": NeighborSampler(
                self.filled_data["data"].edge_index,
                num_nodes=self.filled_data["data"].num_nodes,
                node_idx=train_index,
                sizes=[5, 5],
                batch_size=64,
                shuffle=True
            ),
        }

    def _train_classifier(self):
        """Load the server weights and train the task classifier for one pass over the loader."""
        for local_param, global_param in zip(
                self.task.model.parameters(), self.message_pool["server"]["weight"]):
            local_param.data.copy_(global_param)
        # override_evaluate leaves the model in eval mode; re-enable dropout etc. for training.
        self.task.model.train()
        graph = self.fill_dataloader["data"]
        # Loop-invariant: use the mended embeddings if present, else one shared zero matrix.
        if "mend_emb" in graph:
            mend_emb = graph.mend_emb
        else:
            mend_emb = torch.zeros(
                (len(graph.x), self.task.model.emb_len), device=self.device)
        for batch_size, n_id, adjs in self.fill_dataloader["train"]:
            adjs = [adj.to(self.device) for adj in adjs]
            pred = self.task.model.forward((graph.x[n_id], mend_emb[n_id]), adjs=adjs)
            # The first batch_size entries of n_id are the target (seed) nodes of this batch.
            label = graph.y[n_id[:batch_size]].to(self.device)
            loss_clf = self.task.loss_fn(pred, label)
            self.task.optim.zero_grad()
            loss_clf.backward()
            self.task.optim.step()

    def get_override_evaluate(self):
        """
        Overrides the default evaluation method to evaluate the model on the locally filled data.
        This method computes the evaluation metrics on training, validation, and test datasets.

        Returns:
            function: ``override_evaluate(splitted_data=None, mute=False)``. When
            ``splitted_data`` is None it uses ``self.filled_data`` if phase 1 has built it,
            otherwise ``self.task.splitted_data``. It returns a dict with ``loss_{split}``
            entries followed by the metrics from ``compute_supervised_metrics``.
        """
        from openfgl.utils.metrics import compute_supervised_metrics

        def override_evaluate(splitted_data=None, mute=False):
            if splitted_data is None:
                # filled_data only exists once the mended graph has been built.
                try:
                    splitted_data = self.filled_data
                except AttributeError:
                    splitted_data = self.task.splitted_data
            else:
                for name in ("data", "train_mask", "val_mask", "test_mask"):
                    assert name in splitted_data, f"missing key: {name}"
            data = splitted_data["data"]
            losses = {}
            metrics = {}
            self.task.model.eval()
            with torch.no_grad():
                logits = self.task.model.forward(data)
                for split in ("train", "val", "test"):
                    mask = splitted_data[f"{split}_mask"]
                    split_logits = logits[mask]
                    split_labels = data.y[mask]
                    losses[f"loss_{split}"] = self.task.loss_fn(split_logits, split_labels)
                    metrics.update(compute_supervised_metrics(
                        metrics=self.args.metrics,
                        logits=split_logits,
                        labels=split_labels,
                        suffix=split
                    ))
            # Losses first, then metrics, matching the printed order.
            eval_output = {**losses, **metrics}
            parts = []
            for key, val in eval_output.items():
                try:
                    parts.append(f"\t{key}: {val:.4f}")
                except (TypeError, ValueError):
                    # Values that cannot be formatted as a float are left out of the summary.
                    continue
            info = "".join(parts)
            prefix = f"[client {self.client_id}]" if self.client_id is not None else "[server]"
            if not mute:
                print(prefix + info)
            return eval_output

        return override_evaluate