import torch
from openfgl.flcore.base import BaseServer
from openfgl.flcore.fedgl.models import FedGCN
from openfgl.flcore.fedgl.fedgl_config import config
from scipy.spatial.distance import cdist
import scipy as sp
import numpy as np
from torch_geometric.utils import to_torch_csr_tensor
class FedGLServer(BaseServer):
    """
    Server for the Federated Graph Learning (FedGL) framework with global self-supervision.

    Besides averaging the model parameters uploaded by the sampled clients, the server
    fuses the clients' node embeddings and predictions into a global view, from which it
    derives pseudo-labels and a reconstructed ("pseudo") graph that are sent back to the
    clients.

    Attributes:
        pseudo_labels (list): One tensor of pseudo-labels per sampled client, in the order
            of ``message_pool["sampled_clients"]`` at the time of the last update.
        pseudo_labels_mask (list): One tensor per sampled client flagging (1/0) which of
            its nodes carry a pseudo-label; same ordering as ``pseudo_labels``.
        whole_adj (list): One sparse CSR adjacency block per sampled client, cut out of
            the reconstructed global graph; same ordering as ``pseudo_labels``.
    """

    def __init__(self, args, global_data, data_dir, message_pool, device):
        """
        Initializes the FedGLServer.

        Args:
            args (Namespace): Arguments containing model and training configurations.
            global_data (object): Global dataset accessible by the server.
            data_dir (str): Directory containing the data.
            message_pool (object): Pool for managing messages between server and clients.
            device (torch.device): Device to run the computations on.
        """
        super().__init__(args, global_data, data_dir, message_pool, device)
        self.task.load_custom_model(FedGCN(nfeat=self.task.num_feats, nhid=self.args.hid_dim,
                                           nclass=self.task.num_global_classes, nlayer=self.args.num_layers,
                                           dropout=self.args.dropout))
        self.task.splitted_data["data"].adj = to_torch_csr_tensor(self.task.data.edge_index)
        self.pseudo_labels = []
        self.pseudo_labels_mask = []
        self.whole_adj = []

    def send_message(self):
        """
        Publishes the server message for the clients.

        In round 0 only the global model parameters are sent. In every later round the
        message additionally carries the pseudo-labels, the pseudo-label masks and the
        reconstructed adjacency blocks currently stored on the server.
        """
        message = {"weight": list(self.task.model.parameters())}
        if self.message_pool["round"] != 0:
            message["pseudo_labels"] = self.pseudo_labels
            message["pseudo_labels_mask"] = self.pseudo_labels_mask
            message["whole_adj"] = self.whole_adj
        self.message_pool["server"] = message

    def execute(self):
        """
        Runs one server round.

        1. Aggregates the sampled clients' parameters into the global model, weighting
           each client by its share of the total number of samples.
        2. Every ``config["pseudo_labels_update_epoch"]`` rounds, and only if the matching
           loss weight in ``config`` is positive, rebuilds the per-client adjacency blocks
           (``pseudo_graph_weight``) and the pseudo-labels with their masks
           (``ssl_loss_weight``).

        Raises:
            ValueError: If ``config["pred_weight"]`` is neither ``'mean'`` nor
                ``'sampling_rate'``.
        """
        sampled_clients = self.message_pool["sampled_clients"]
        sample_weights = []
        client_masks = []
        client_embeddings = []
        client_preds = []

        # ------- weighted average of the client parameters --------
        with torch.no_grad():
            num_tot_samples = sum(self.message_pool[f"client_{client_id}"]["num_samples"]
                                  for client_id in sampled_clients)
            for it, client_id in enumerate(sampled_clients):
                client_msg = self.message_pool[f"client_{client_id}"]
                client_masks.append(client_msg["mask"])
                client_embeddings.append(client_msg["embeddings"])
                client_preds.append(client_msg["preds"])

                weight = client_msg["num_samples"] / num_tot_samples
                sample_weights.append(weight)

                for local_param, global_param in zip(client_msg["weight"], self.task.model.parameters()):
                    if it == 0:
                        # The first client overwrites the previous global parameters.
                        global_param.data.copy_(weight * local_param)
                    else:
                        global_param.data += weight * local_param

        # Global self-supervision is only refreshed on the configured rounds.
        if self.message_pool["round"] % config["pseudo_labels_update_epoch"] != 0:
            return

        use_pseudo_graph = config["pseudo_graph_weight"] > 0
        use_ssl = config["ssl_loss_weight"] > 0
        if not (use_pseudo_graph or use_ssl):
            return

        # ------- weight of each client per node --------
        if config["pred_weight"] == 'mean':
            client_rates = [1] * len(sampled_clients)
        elif config["pred_weight"] == 'sampling_rate':
            client_rates = sample_weights
        else:
            raise ValueError(
                f"unsupported config['pred_weight'] {config['pred_weight']!r}; "
                "expected 'mean' or 'sampling_rate'"
            )

        num_nodes = self.task.data.x.shape[0]
        num_classes = self.task.num_global_classes

        # Total rate of the clients that hold each node; used as the divisor when
        # averaging the clients' contributions.
        node_weights = torch.zeros(num_nodes)
        for rate, mask in zip(client_rates, client_masks):
            covered = torch.zeros(num_nodes)
            covered[mask] = 1.
            node_weights += rate * covered
        # Nodes held by no sampled client keep a divisor of 1 to avoid dividing by zero.
        node_weights[node_weights == 0] = 1.

        # -------- reconstruct the adjacency matrix on the server ------
        if use_pseudo_graph:
            global_emb = np.zeros((num_nodes, num_classes))
            for rate, embed, mask in zip(client_rates, client_embeddings, client_masks):
                client_emb = np.zeros((num_nodes, num_classes))
                client_emb[mask.detach().cpu().numpy()] = embed.detach().cpu().numpy()
                global_emb += rate * client_emb
            global_emb = global_emb / node_weights[:, None].numpy()

            server_adj = construct_server_adj(global_emb, type='dot', s=config['k'], mode=0, sigma=2)
            np.fill_diagonal(server_adj, 1)
            whole_adj = config['pseudo_graph_weight'] * normalize_server_adj(server_adj).to(self.device)
            # Rebuild the list instead of appending to it: appending made it grow every
            # update round while its leading entries stayed those of the first round.
            self.whole_adj = [whole_adj[mask, :][:, mask].to_sparse_csr() for mask in client_masks]

        # --------- pseudo labels -------
        if use_ssl:
            global_pred = torch.zeros((num_nodes, num_classes))
            for rate, pred, mask in zip(client_rates, client_preds, client_masks):
                client_pred = torch.zeros((num_nodes, num_classes))
                client_pred[mask] = torch.nn.functional.softmax(pred.detach(), dim=1).cpu()
                # Class probabilities below the threshold do not vote.
                client_pred[client_pred < config['probability_threshold']] = 0
                global_pred += rate * client_pred
            global_pred = global_pred / node_weights[:, None]

            # A node gets a pseudo-label only if at least one confident vote survived;
            # the label is the class with the largest fused score.
            pseudo_labels_col_index = torch.argmax(global_pred, dim=1)
            pseudo_labels_row_index = torch.where(global_pred.sum(dim=1) > 0)[0]

            p_mask = torch.zeros(num_nodes).to(self.device)
            p_mask[pseudo_labels_row_index] = 1

            p_global = torch.zeros(num_nodes).to(self.device)
            p_global[pseudo_labels_row_index] = (
                pseudo_labels_col_index[pseudo_labels_row_index].type(torch.float).to(self.device)
            )

            # Rebuilt (not appended) for the same reason as ``whole_adj`` above.
            self.pseudo_labels_mask = [p_mask[mask] for mask in client_masks]
            self.pseudo_labels = [p_global[mask] for mask in client_masks]
def construct_server_adj(data, type, s, mode=0, sigma=2):
    """
    Builds a row-normalized top-``s`` similarity graph from node embeddings.

    Args:
        data (numpy.ndarray): 2-D array of embeddings, one row per node.
        type (str): Similarity measure: ``'dot'`` (dot product, negatives clipped to 0)
            or ``'kernel'`` (Gaussian kernel of ``data`` against itself).
        s (int): Number of largest similarities kept per row.
        mode (int, optional): Forwarded to ``gaussian_kernel`` as its ``mode``; only used
            when ``type == 'kernel'``. Defaults to 0.
        sigma (float, optional): Forwarded to ``gaussian_kernel`` as its ``segma``; only
            used when ``type == 'kernel'``. Defaults to 2.

    Returns:
        numpy.ndarray: Square matrix in which each row holds its ``s`` largest
        similarities scaled to sum to 1. Rows with no positive similarity stay all-zero.
        Every column that would otherwise be empty receives a single 0.1 entry in a
        randomly chosen row.

    Raises:
        ValueError: If ``type`` is neither ``'dot'`` nor ``'kernel'``.
    """
    if type == 'dot':
        Z_full = data.dot(data.T)
        Z_full[Z_full < 0] = 0
    elif type == 'kernel':
        # The kernel takes two point sets; the graph is built over ``data`` against itself.
        Z_full = gaussian_kernel(data, data, mode, sigma)
    else:
        raise ValueError(f"unknown similarity type {type!r}; expected 'dot' or 'kernel'")

    # Keep only the s largest similarities of every row.
    Z = np.zeros(Z_full.shape)
    for i in range(Z.shape[0]):
        index_s = np.argsort(-Z_full[i, :])[0:s]
        Z[i, index_s] = Z_full[i, index_s]

    # Row-normalize; rows that sum to zero are left untouched so they do not turn into NaN.
    row_sums = Z.sum(axis=1, keepdims=True)
    np.divide(Z, row_sums, out=Z, where=row_sums > 0)

    # Give every empty column one small entry in a random row.
    for col in np.flatnonzero(~Z.any(axis=0)):
        row = np.random.choice(Z.shape[0], 1)
        Z[row[0], col] = 0.1
    return Z
def gaussian_kernel(X, Y, mode=1, segma=2, K=5):
    """
    Evaluates ``exp(-||x - y||^2 / bandwidth)`` for every pair of rows of ``X`` and ``Y``.

    Args:
        X (numpy.ndarray): First set of vectors, one per row.
        Y (numpy.ndarray): Second set of vectors, one per row.
        mode (int, optional): How the bandwidth is chosen. ``0``: the constant
            ``2 * segma ** 2``. ``1``: per-pair product of a row scale and a column scale,
            each ``1/K`` times the square root of the sum of the ``K - 1`` smallest
            squared distances in that row/column. Any other value: per-pair product of
            the square roots of the ``K``-th smallest squared distance of the row and of
            the column. Defaults to 1.
        segma (float, optional): Bandwidth parameter, used only when ``mode == 0``.
            Defaults to 2.
        K (int, optional): Neighbour count used by the adaptive modes. Defaults to 5.

    Returns:
        numpy.ndarray: Kernel matrix with one row per row of ``X`` and one column per
        row of ``Y``.
    """
    sq_dists = cdist(X, Y, metric='sqeuclidean')
    n_rows, n_cols = sq_dists.shape

    # Fixed bandwidth: no neighbourhood statistics needed.
    if mode == 0:
        return np.exp(-sq_dists / (2 * segma ** 2))

    # Both adaptive modes work on the distances sorted along each axis.
    by_row = np.sort(sq_dists, axis=1)
    by_col = np.sort(sq_dists, axis=0)
    if mode == 1:
        row_scale = 1 / K * np.sqrt(by_row[:, :K - 1].sum(axis=1))
        col_scale = 1 / K * np.sqrt(by_col[:K - 1, :].sum(axis=0))
    else:
        row_scale = np.sqrt(by_row[:, K - 1])
        col_scale = np.sqrt(by_col[K - 1, :])

    # Outer product of the two scales gives one bandwidth per (row, column) pair.
    bandwidth = row_scale.reshape(n_rows, 1) @ col_scale.reshape(1, n_cols)
    return np.exp(-sq_dists / bandwidth)
def normalize_server_adj(adj):
    """
    Degree-normalizes the server adjacency matrix.

    Computes ``D^-1/2 @ adj.T @ D^-1/2`` where ``D`` is the diagonal matrix of the row
    sums of ``adj``. For a symmetric ``adj`` this is the usual symmetric normalization
    ``D^-1/2 @ adj @ D^-1/2``.

    Args:
        adj (numpy.ndarray): Square adjacency matrix (anything ``torch.as_tensor``
            accepts).

    Returns:
        torch.Tensor: Dense normalized matrix with the same shape as ``adj``.
    """
    adj = torch.as_tensor(adj)
    rowsum = adj.sum(1)
    d_inv_sqrt = torch.pow(rowsum, -0.5).flatten()
    # Rows that sum to zero give inf; treat them as having zero weight.
    d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.
    # Scaling by a diagonal matrix on both sides is an element-wise row/column scaling,
    # so broadcasting avoids building a dense diagonal matrix and two dense matmuls.
    return d_inv_sqrt[:, None] * adj.T * d_inv_sqrt[None, :]