import torch
import torch.nn as nn
from openfgl.task.base import BaseTask
from openfgl.utils.basic_utils import extract_floats, idx_to_mask_tensor, mask_tensor_to_idx
from os import path as osp
from openfgl.utils.metrics import compute_supervised_metrics
import os
import torch
from openfgl.utils.task_utils import load_node_edge_level_default_model
import pickle
import numpy as np
from sklearn.cluster import KMeans
def compute_edge_logits(node_embedding, edge_index):
"""
Compute edge logits based on node embeddings and edge index.
Attributes:
node_embedding (torch.Tensor): Node embeddings.
edge_index (torch.Tensor): Edge indices.
Returns:
torch.Tensor: Edge logits.
"""
source_node_embedding = node_embedding[edge_index[0]]
target_node_embedding = node_embedding[edge_index[1]]
edge_logits = (source_node_embedding * target_node_embedding).sum(dim=1)
return edge_logits
[docs]class NodeClustTask(BaseTask):
"""
Task class for node clustering in a federated learning setup.
Attributes:
client_id (int): ID of the client.
data_dir (str): Directory containing the data.
args (Namespace): Arguments containing model and training configurations.
device (torch.device): Device to run the computations on.
data (object): Data specific to the task.
model (torch.nn.Module): Model to be trained.
optim (torch.optim.Optimizer): Optimizer for the model.
splitted_data (dict): Dictionary containing split data and DataLoaders.
"""
def __init__(self, args, client_id, data, data_dir, device):
"""
Initialize the NodeClustTask with provided arguments, data, and device.
Attributes:
args (Namespace): Arguments containing model and training configurations.
client_id (int): ID of the client.
data (object): Data specific to the task.
data_dir (str): Directory containing the data.
device (torch.device): Device to run the computations on.
"""
super(NodeClustTask, self).__init__(args, client_id, data, data_dir, device)
merged_edge_index_list = []
for source in range(self.num_samples):
for target in range(self.num_samples):
merged_edge_index_list.append((source, target))
merged_edge_index = torch.tensor(merged_edge_index_list).T.long().to(self.device)
merged_edge_label = torch.zeros((merged_edge_index.shape[1],)).float().to(self.device)
for edge_id in range(self.data.edge_index.shape[1]):
source = self.data.edge_index[0, edge_id].item()
target = self.data.edge_index[1, edge_id].item()
idx = source * self.num_samples + target
merged_edge_label[idx] = 1
for source in range(self.num_samples):
idx = source * self.num_samples + source
merged_edge_label[idx] = 1
merged_edge_label = merged_edge_label.to(self.device)
self.splitted_data = {
"data": self.data,
"merged_edge_index": merged_edge_index,
"merged_edge_label": merged_edge_label
}
[docs] def train(self, splitted_data=None):
"""
Train the model on the provided or processed data.
Attributes:
splitted_data (dict, optional): Dictionary containing split data and DataLoaders. Defaults to None.
"""
if splitted_data is None:
splitted_data = self.splitted_data
else:
names = ["data", "merged_edge_index", "merged_edge_label"]
for name in names:
assert name in splitted_data
self.model.train()
for _ in range(self.args.num_epochs):
self.optim.zero_grad()
node_embedding, node_logits = self.model.forward(splitted_data["data"])
edge_logits = compute_edge_logits(node_embedding, splitted_data["merged_edge_index"])
loss = self.loss_fn(None, edge_logits, splitted_data["merged_edge_label"], mask=None)
loss.backward()
if self.step_preprocess is not None:
self.step_preprocess()
self.optim.step()
[docs] def evaluate(self, splitted_data=None, mute=False):
"""
Evaluate the model on the provided or processed data.
Attributes:
splitted_data (dict, optional): Dictionary containing split data and DataLoaders. Defaults to None.
mute (bool, optional): If True, suppress the print statements. Defaults to False.
Returns:
dict: Dictionary containing evaluation metrics and results.
"""
if self.override_evaluate is None:
if splitted_data is None:
splitted_data = self.splitted_data
else:
names = ["data", "merged_edge_index", "merged_edge_label"]
for name in names:
assert name in splitted_data
eval_output = {}
self.model.eval()
with torch.no_grad():
node_embedding, node_logits = self.model.forward(splitted_data["data"])
edge_logits = compute_edge_logits(node_embedding, splitted_data["merged_edge_index"])
loss = self.loss_fn(None, edge_logits, splitted_data["merged_edge_label"], mask=None)
kmeans = KMeans(n_clusters=self.args.num_clusters, random_state=0)
node_embeddings_np = node_embedding.detach().cpu().numpy()
cluster_label_tensor = torch.tensor(kmeans.fit_predict(node_embeddings_np)).to(self.device)
eval_output["embedding"] = None
eval_output["logits"] = cluster_label_tensor
eval_output["loss"] = loss
metric = compute_supervised_metrics(metrics=self.args.metrics, logits=cluster_label_tensor, labels=splitted_data["data"].y, suffix="")
eval_output = {**eval_output, **metric}
info = ""
for key, val in eval_output.items():
try:
info += f"\t{key}: {val:.4f}"
except:
continue
prefix = f"[client {self.client_id}]" if self.client_id is not None else "[server]"
if not mute:
print(prefix+info)
return eval_output
else:
return self.override_evaluate(splitted_data, mute)
[docs] def loss_fn(self, embedding, logits, label, mask):
"""
Calculate the loss for the model.
Attributes:
embedding (torch.Tensor): Embeddings from the model.
logits (torch.Tensor): Logits from the model.
label (torch.Tensor): Ground truth labels.
mask (torch.Tensor): Mask to filter the logits and labels.
Returns:
torch.Tensor: Calculated loss.
"""
return self.default_loss_fn(logits, label)
@property
def train_val_test_path(self):
"""
Get the path to the train/validation/test split file.
Returns:
str: Path to the split file.
"""
return osp.join(self.data_dir, f"node_clust")
@property
def default_model(self):
"""
Get the default model for node and edge level tasks.
Returns:
torch.nn.Module: Default model.
"""
return load_node_edge_level_default_model(self.args, input_dim=self.num_feats, output_dim=self.num_global_classes, client_id=self.client_id)
@property
def default_optim(self):
"""
Get the default optimizer for the task.
Returns:
torch.optim.Optimizer: Default optimizer.
"""
if self.args.optim == "adam":
from torch.optim import Adam
return Adam
@property
def num_samples(self):
"""
Get the number of samples in the dataset.
Returns:
int: Number of samples.
"""
return self.data.x.shape[0]
@property
def num_feats(self):
"""
Get the number of features in the dataset.
Returns:
int: Number of features.
"""
return self.data.x.shape[1]
@property
def num_global_classes(self):
"""
Get the number of global classes in the dataset.
Returns:
int: Number of global classes.
"""
return self.data.num_global_classes
@property
def default_loss_fn(self):
"""
Get the default loss function for the task.
Returns:
function: Default loss function.
"""
return nn.BCEWithLogitsLoss(weight=None)
@property
def default_train_val_test_split(self):
"""
Get the default train/validation/test split. Not used in this task.
Returns:
None
"""
return None
@property
def train_val_test_path(self):
"""
Get the path to the train/validation/test split file. Not used in this task.
Returns:
None
"""
pass
[docs] def load_train_val_test_split(self):
"""
Load the train/validation/test split from a file. Not used in this task.
"""
pass