Task

class openfgl.task.base.BaseTask(args, client_id, data, data_dir, device)

Bases: object

Base class for defining a task in a federated learning setup.

Attributes:

client_id (int): ID of the client.

data_dir (str): Directory containing the data.

args (Namespace): Arguments containing model and training configurations.

device (torch.device): Device to run the computations on.

data (object): Data specific to the task.

model (torch.nn.Module): Model to be trained.

optim (torch.optim.Optimizer): Optimizer for the model.

override_evaluate (function): Custom evaluation function, if provided.

step_preprocess (function): Custom preprocessing step, if provided.

property default_loss_fn

Get the default loss function for the task. This method should be implemented by subclasses.

Returns:

function: Default loss function.

property default_model

Get the default model for the task. This method should be implemented by subclasses.

Returns:

torch.nn.Module: Default model.

property default_optim

Get the default optimizer for the task. This method should be implemented by subclasses.

Returns:

torch.optim.Optimizer: Default optimizer.

property default_train_val_test_split

Get the default train/validation/test split. This method should be implemented by subclasses.

Returns:

dict: Default train/validation/test split.

evaluate()

Evaluate the model on the provided data. This method should be implemented by subclasses.

load_custom_model(custom_model)

Load a custom model for the task and reinitialize the optimizer.

Args:

custom_model (torch.nn.Module): Custom model to be used.

load_train_val_test_split()

Load the train/validation/test split from a file. This method should be implemented by subclasses.

property num_samples

Get the number of samples in the dataset. This method should be implemented by subclasses.

Returns:

int: Number of samples.

train()

Train the model on the provided data. This method should be implemented by subclasses.

property train_val_test_path

Get the path to the train/validation/test split file. This method should be implemented by subclasses.

Returns:

str: Path to the split file.
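The contract above can be summarized as a small abstract class. This is a minimal sketch, not OpenFGL's implementation. It uses Python's `abc` module to force subclasses to supply the documented members. It also assumes that the constructor builds the default model and optimizer, which the documentation does not state. The names `TaskContract`, `ToyTask`, and `Incomplete`, along with the dictionary placeholders standing in for torch modules and optimizers, are hypothetical. The `load_custom_model` method mirrors the documented behavior: swap in a new model, then rebuild the optimizer.

```python
from abc import ABC, abstractmethod


class TaskContract(ABC):
    """Hypothetical sketch of the members a task subclass is expected to provide."""

    def __init__(self, args, client_id, data, data_dir, device):
        self.args = args
        self.client_id = client_id
        self.data = data
        self.data_dir = data_dir
        self.device = device
        # Assumption: the default model is built first, then an optimizer for it.
        self.model = self.default_model
        self.optim = self.default_optim

    @property
    @abstractmethod
    def default_model(self): ...

    @property
    @abstractmethod
    def default_optim(self): ...

    @property
    @abstractmethod
    def num_samples(self): ...

    @abstractmethod
    def train(self): ...

    @abstractmethod
    def evaluate(self): ...

    def load_custom_model(self, custom_model):
        # Replace the model, then rebuild the optimizer so it tracks the new model.
        self.model = custom_model
        self.optim = self.default_optim


class ToyTask(TaskContract):
    """Hypothetical subclass using placeholder objects instead of torch modules."""

    @property
    def default_model(self):
        return {"name": "default"}

    @property
    def default_optim(self):
        # A real task would build something like torch.optim.Adam(self.model.parameters()).
        return {"tracks": self.model["name"]}

    @property
    def num_samples(self):
        return len(self.data)

    def train(self):
        pass

    def evaluate(self):
        return {}


task = ToyTask(args=None, client_id=0, data=[1, 2, 3], data_dir=".", device="cpu")
task.load_custom_model({"name": "custom"})
print(task.optim)  # {'tracks': 'custom'}
```

With `abc`, a subclass that forgets one of the members fails when it is instantiated, not later during training.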

class openfgl.task.graph_cls.GraphClsTask(args, client_id, data, data_dir, device)

Bases: BaseTask

Task class for graph classification in a federated learning setup.

Attributes:

client_id (int): ID of the client.

data_dir (str): Directory containing the data.

args (Namespace): Arguments containing model and training configurations.

device (torch.device): Device to run the computations on.

data (object): Data specific to the task.

model (torch.nn.Module): Model to be trained.

optim (torch.optim.Optimizer): Optimizer for the model.

train_mask (torch.Tensor): Mask for the training set.

val_mask (torch.Tensor): Mask for the validation set.

test_mask (torch.Tensor): Mask for the test set.

train_dataloader (DataLoader): DataLoader for the training set.

val_dataloader (DataLoader): DataLoader for the validation set.

test_dataloader (DataLoader): DataLoader for the test set.

splitted_data (dict): Dictionary containing split data and DataLoaders.

processed_data (object): Processed data for training.

property default_loss_fn

Get the default loss function for the task.

Returns:

function: Default loss function.

property default_model

Get the default model for graph classification.

Returns:

torch.nn.Module: Default model.

property default_optim

Get the default optimizer for the task.

Returns:

torch.optim.Optimizer: Default optimizer.

property default_train_val_test_split

Get the default train/validation/test split.

Returns:

tuple: Default train/validation/test split ratios.

evaluate(splitted_data=None, mute=False)

Evaluate the model on the provided or processed data.

Args:

splitted_data (dict, optional): Dictionary containing split data and DataLoaders. Defaults to None.

mute (bool, optional): If True, suppress the print statements. Defaults to False.

Returns:

dict: Dictionary containing evaluation metrics and results.

load_train_val_test_split()

Load the train/validation/test split from a file.

local_graph_train_val_test_split(local_graphs, split, shuffle=True)

Split the local graphs into train, validation, and test sets.

Args:

local_graphs (object): Local graphs to be split.

split (str or tuple): Split ratios or default split identifier.

shuffle (bool, optional): If True, shuffle the graphs before splitting. Defaults to True.

Returns:

tuple: Masks for the train, validation, and test sets.
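One way to turn split ratios into the three masks described above is sketched below. It is a stdlib-only sketch with these assumptions:

- `split` is a `(train, val, test)` tuple of fractions summing to 1.
- The masks are boolean lists over graph indices.

The function name `ratio_split_masks` and its `seed` parameter are hypothetical. OpenFGL's own handling, including string split identifiers, may differ.

```python
import random


def ratio_split_masks(num_items, split, shuffle=True, seed=0):
    """Return (train_mask, val_mask, test_mask) as boolean lists of length num_items."""
    train_ratio, val_ratio, test_ratio = split
    if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
        raise ValueError("split ratios must sum to 1")

    indices = list(range(num_items))
    if shuffle:
        # A seeded generator keeps the split reproducible across runs.
        random.Random(seed).shuffle(indices)

    num_train = int(train_ratio * num_items)
    num_val = int(val_ratio * num_items)
    # Whatever remains after rounding down goes to the test set.
    train_idx = indices[:num_train]
    val_idx = indices[num_train:num_train + num_val]
    test_idx = indices[num_train + num_val:]

    def to_mask(selected):
        mask = [False] * num_items
        for i in selected:
            mask[i] = True
        return mask

    return to_mask(train_idx), to_mask(val_idx), to_mask(test_idx)


train_mask, val_mask, test_mask = ratio_split_masks(10, (0.8, 0.1, 0.1))
print(sum(train_mask), sum(val_mask), sum(test_mask))  # 8 1 1
```

Every item lands in exactly one mask, so the three sets are disjoint and cover all local graphs.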

loss_fn(embedding, logits, label, mask)

Calculate the loss for the model.

Args:

embedding (torch.Tensor): Embeddings from the model.

logits (torch.Tensor): Logits from the model.

label (torch.Tensor): Ground truth labels.

mask (torch.Tensor): Mask to filter the logits and labels.

Returns:

torch.Tensor: Calculated loss.
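The masked loss described above can be illustrated without any framework. This sketch assumes the default loss is cross-entropy averaged over the masked rows; the documentation does not name the default. It ignores `embedding`, which a plain supervised loss does not need. The function name `masked_cross_entropy` is hypothetical. In PyTorch the same computation is usually written as `torch.nn.functional.cross_entropy(logits[mask], label[mask])`.

```python
import math


def masked_cross_entropy(embedding, logits, label, mask):
    """Mean cross-entropy over the rows where mask is True.

    embedding is accepted only to mirror the documented signature; it is unused here.
    """
    total, count = 0.0, 0
    for row, target, keep in zip(logits, label, mask):
        if not keep:
            continue
        # Log-sum-exp with the max subtracted for numerical stability.
        peak = max(row)
        log_norm = peak + math.log(sum(math.exp(x - peak) for x in row))
        total += log_norm - row[target]
        count += 1
    if count == 0:
        raise ValueError("mask selects no rows")
    return total / count


logits = [[0.0, 0.0], [2.0, 0.0], [0.0, 5.0]]
label = [0, 0, 1]
mask = [True, False, False]
print(round(masked_cross_entropy(None, logits, label, mask), 4))  # 0.6931
```

Only the first row is selected. Its two equal logits give a loss of ln 2.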

property num_feats

Get the number of features in the dataset.

Returns:

int: Number of features.

property num_global_classes

Get the number of global classes in the dataset.

Returns:

int: Number of global classes.

property num_samples

Get the number of samples in the dataset.

Returns:

int: Number of samples.

train(splitted_data=None)

Train the model on the provided or processed data.

Args:

splitted_data (dict, optional): Dictionary containing split data and DataLoaders. Defaults to None.

property train_val_test_path

Get the path to the train/validation/test split file.

Returns:

str: Path to the split file.
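The two split-file members above, `train_val_test_path` and `load_train_val_test_split`, suggest that a client's split is cached on disk and reused across runs. A minimal sketch of that load-or-create pattern follows. The JSON format, the helper name `load_or_create_split`, and the `make_split` callback are hypothetical, not OpenFGL's file layout.

```python
import json
import os
import tempfile


def load_or_create_split(path, make_split):
    """Load cached masks from path, or build them with make_split() and save them.

    Hypothetical on-disk format: a JSON object with "train", "val", "test" lists.
    """
    if os.path.exists(path):
        with open(path) as f:
            cached = json.load(f)
        return cached["train"], cached["val"], cached["test"]
    train_mask, val_mask, test_mask = make_split()
    # Create the parent directory so the first run can write the cache file.
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "w") as f:
        json.dump({"train": train_mask, "val": val_mask, "test": test_mask}, f)
    return train_mask, val_mask, test_mask


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "client_0", "split.json")
    first = load_or_create_split(path, lambda: ([True, False], [False, True], [False, False]))
    # The second call reads the file and never calls the (different) builder.
    second = load_or_create_split(path, lambda: ([False, True], [True, False], [False, False]))
    print(first == second)  # True
```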

Bases: BaseTask

Task class for link prediction in a federated learning setup.

Attributes:

client_id (int): ID of the client.

data_dir (str): Directory containing the data.

args (Namespace): Arguments containing model and training configurations.

device (torch.device): Device to run the computations on.

data (object): Data specific to the task.

model (torch.nn.Module): Model to be trained.

optim (torch.optim.Optimizer): Optimizer for the model.

forward_data (Data): Data for the forward pass.

merged_edge_index (torch.Tensor): Merged edge indices.

merged_edge_label (torch.Tensor): Labels for merged edges.

merged_edge_train_mask (torch.Tensor): Mask for training edges.

merged_edge_val_mask (torch.Tensor): Mask for validation edges.

merged_edge_test_mask (torch.Tensor): Mask for test edges.

splitted_data (dict): Dictionary containing split data and DataLoaders.
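The `merged_edge_*` attributes listed above suggest a specific layout: positive and negative edges stored together, with one label list and per-split masks. The plain-Python sketch below shows that layout under that assumption:

- Existing edges get label 1.
- Sampled non-edges get label 0.
- Each merged edge is assigned to exactly one split.

The function name `merge_edges`, the uniform negative sampling, and the split procedure are hypothetical and not taken from OpenFGL.

```python
import random


def merge_edges(pos_edges, num_nodes, split, seed=0):
    """Build a merged edge list with 0/1 labels and train/val/test masks.

    pos_edges: list of (src, dst) pairs that exist in the graph.
    split: (train, val, test) fractions applied to the merged edge list.
    """
    rng = random.Random(seed)
    seen = set(pos_edges)
    neg_edges, attempts = [], 0
    # Sample one non-edge per positive edge (uniform negative sampling).
    while len(neg_edges) < len(pos_edges):
        attempts += 1
        if attempts > 1000 * (len(pos_edges) + 1):
            raise RuntimeError("could not sample enough negative edges")
        u, v = rng.randrange(num_nodes), rng.randrange(num_nodes)
        # Treat the graph as undirected: reject self-loops and either orientation of a known pair.
        if u != v and (u, v) not in seen and (v, u) not in seen:
            seen.add((u, v))
            neg_edges.append((u, v))

    merged = list(pos_edges) + neg_edges
    labels = [1] * len(pos_edges) + [0] * len(neg_edges)

    order = list(range(len(merged)))
    rng.shuffle(order)
    n_train = int(split[0] * len(merged))
    n_val = int(split[1] * len(merged))
    train_mask = [False] * len(merged)
    val_mask = [False] * len(merged)
    test_mask = [False] * len(merged)
    for rank, idx in enumerate(order):
        if rank < n_train:
            train_mask[idx] = True
        elif rank < n_train + n_val:
            val_mask[idx] = True
        else:
            test_mask[idx] = True

    # PyG-style edge index: row 0 holds sources, row 1 holds targets.
    edge_index = [[u for u, _ in merged], [v for _, v in merged]]
    return edge_index, labels, train_mask, val_mask, test_mask


edge_index, labels, tr, va, te = merge_edges([(0, 1), (1, 2), (2, 3), (3, 0)], 6, (0.5, 0.25, 0.25))
print(sum(labels), len(labels))  # 4 8
print(sum(tr), sum(va), sum(te))  # 4 2 2
```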

Get the default loss function for the task.

Returns:

function: Default loss function.

Get the default model for node and edge level tasks.

Returns:

torch.nn.Module: Default model.

Get the default optimizer for the task.

Returns:

torch.optim.Optimizer: Default optimizer.

Get the default train/validation/test split.

Returns:

tuple: Default train/validation/test split ratios.

Evaluate the model on the provided or processed data.

Args:

splitted_data (dict, optional): Dictionary containing split data and DataLoaders. Defaults to None.

mute (bool, optional): If True, suppress the print statements. Defaults to False.

Returns:

dict: Dictionary containing evaluation metrics and results.

Load the train/validation/test split from a file.

Split the local subgraph into train, validation, and test sets.

Args:

local_subgraph (object): Local subgraph to be split.

split (str or tuple): Split ratios or default split identifier.

shuffle (bool, optional): If True, shuffle the subgraph before splitting. Defaults to True.

Returns:

tuple: Masks for the train, validation, and test sets.

Calculate the loss for the model.

Args:

embedding (torch.Tensor): Embeddings from the model.

logits (torch.Tensor): Logits from the model.

label (torch.Tensor): Ground truth labels.

mask (torch.Tensor): Mask to filter the logits and labels.

Returns:

torch.Tensor: Calculated loss.
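For link prediction, the masked loss is computed over candidate edges rather than nodes. A common formulation has two steps: score each edge by the dot product of its endpoint embeddings, then apply binary cross-entropy against the 0/1 edge labels. That formulation is assumed here, not taken from OpenFGL. The function `edge_bce_loss` and its `edge_index`, `edge_label`, and `edge_mask` parameters are hypothetical. In PyTorch the same loss is usually computed with `torch.nn.functional.binary_cross_entropy_with_logits`.

```python
import math


def edge_bce_loss(embedding, edge_index, edge_label, edge_mask):
    """Mean binary cross-entropy of dot-product edge scores over masked edges."""
    total, count = 0.0, 0
    for u, v, y, keep in zip(edge_index[0], edge_index[1], edge_label, edge_mask):
        if not keep:
            continue
        # The logit of an edge is the dot product of its endpoint embeddings.
        score = sum(a * b for a, b in zip(embedding[u], embedding[v]))
        # Stable BCE-with-logits: max(s, 0) - s*y + log(1 + exp(-|s|)).
        total += max(score, 0.0) - score * y + math.log1p(math.exp(-abs(score)))
        count += 1
    if count == 0:
        raise ValueError("mask selects no edges")
    return total / count


embedding = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
edge_index = [[0, 0], [1, 2]]  # edges (0, 1) and (0, 2)
print(round(edge_bce_loss(embedding, edge_index, [1, 0], [True, True]), 4))  # 0.5032
```

The rearranged form avoids overflow in `exp` for large scores while staying equal to the usual `-y*log(sigmoid(s)) - (1-y)*log(1-sigmoid(s))`.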

Get the number of features in the dataset.

Returns:

int: Number of features.

Get the number of global classes in the dataset.

Returns:

int: Number of global classes.

Get the number of samples in the dataset.

Returns:

int: Number of samples.

Train the model on the provided or processed data.

Args:

splitted_data (dict, optional): Dictionary containing split data and DataLoaders. Defaults to None.

Get the path to the train/validation/test split file.

Returns:

str: Path to the split file.

class openfgl.task.node_cls.NodeClsTask(args, client_id, data, data_dir, device)

Bases: BaseTask

Task class for node classification in a federated learning setup.

Attributes:

client_id (int): ID of the client.

data_dir (str): Directory containing the data.

args (Namespace): Arguments containing model and training configurations.

device (torch.device): Device to run the computations on.

data (object): Data specific to the task.

model (torch.nn.Module): Model to be trained.

optim (torch.optim.Optimizer): Optimizer for the model.

train_mask (torch.Tensor): Mask for the training set.

val_mask (torch.Tensor): Mask for the validation set.

test_mask (torch.Tensor): Mask for the test set.

splitted_data (dict): Dictionary containing split data and DataLoaders.

processed_data (object): Processed data for training.

property default_loss_fn

Get the default loss function for the task.

Returns:

function: Default loss function.

property default_model

Get the default model for node and edge level tasks.

Returns:

torch.nn.Module: Default model.

property default_optim

Get the default optimizer for the task.

Returns:

torch.optim.Optimizer: Default optimizer.

property default_train_val_test_split

Get the default train/validation/test split based on the dataset.

Returns:

tuple: Default train/validation/test split ratios.

evaluate(splitted_data=None, mute=False)

Evaluate the model on the provided or processed data.

Args:

splitted_data (dict, optional): Dictionary containing split data and DataLoaders. Defaults to None.

mute (bool, optional): If True, suppress the print statements. Defaults to False.

Returns:

dict: Dictionary containing evaluation metrics and results.
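A minimal sketch of what such an evaluation can compute follows. It assumes the metric is accuracy on each masked split; the metrics OpenFGL actually reports are not listed here. The function name `evaluate_masked_accuracy` and the `accuracy_*` result keys are hypothetical. The `mute` flag plays the documented role of suppressing printed output.

```python
def evaluate_masked_accuracy(logits, labels, masks, mute=False):
    """Return a dict of accuracies, one per named mask.

    masks: mapping such as {"train": [...], "val": [...], "test": [...]}.
    """
    # Predicted class is the index of the largest logit in each row.
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
    result = {}
    for name, mask in masks.items():
        hits = [p == y for p, y, keep in zip(preds, labels, mask) if keep]
        result[f"accuracy_{name}"] = sum(hits) / len(hits) if hits else float("nan")
    if not mute:
        print(", ".join(f"{k}: {v:.4f}" for k, v in result.items()))
    return result


logits = [[2.0, 1.0], [0.0, 3.0], [1.0, 0.0], [0.0, 1.0]]
labels = [0, 1, 1, 1]
masks = {
    "train": [True, True, False, False],
    "val": [False, False, True, False],
    "test": [False, False, False, True],
}
evaluate_masked_accuracy(logits, labels, masks)
# prints: accuracy_train: 1.0000, accuracy_val: 0.0000, accuracy_test: 1.0000
```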

load_train_val_test_split()

Load the train/validation/test split from a file.

local_subgraph_train_val_test_split(local_subgraph, split, shuffle=True)

Split the local subgraph into train, validation, and test sets.

Args:

local_subgraph (object): Local subgraph to be split.

split (str or tuple): Split ratios or default split identifier.

shuffle (bool, optional): If True, shuffle the subgraph before splitting. Defaults to True.

Returns:

tuple: Masks for the train, validation, and test sets.

loss_fn(embedding, logits, label, mask)

Calculate the loss for the model.

Args:

embedding (torch.Tensor): Embeddings from the model.

logits (torch.Tensor): Logits from the model.

label (torch.Tensor): Ground truth labels.

mask (torch.Tensor): Mask to filter the logits and labels.

Returns:

torch.Tensor: Calculated loss.

property num_feats

Get the number of features in the dataset.

Returns:

int: Number of features.

property num_global_classes

Get the number of global classes in the dataset.

Returns:

int: Number of global classes.

property num_samples

Get the number of samples in the dataset.

Returns:

int: Number of samples.

train(splitted_data=None)

Train the model on the provided or processed data.

Args:

splitted_data (dict, optional): Dictionary containing split data and DataLoaders. Defaults to None.

property train_val_test_path

Get the path to the train/validation/test split file.

Returns:

str: Path to the split file.

class openfgl.task.node_clust.NodeClustTask(args, client_id, data, data_dir, device)

Bases: BaseTask

Task class for node clustering in a federated learning setup.

Attributes:

client_id (int): ID of the client.

data_dir (str): Directory containing the data.

args (Namespace): Arguments containing model and training configurations.

device (torch.device): Device to run the computations on.

data (object): Data specific to the task.

model (torch.nn.Module): Model to be trained.

optim (torch.optim.Optimizer): Optimizer for the model.

splitted_data (dict): Dictionary containing split data and DataLoaders.

property default_loss_fn

Get the default loss function for the task.

Returns:

function: Default loss function.

property default_model

Get the default model for node and edge level tasks.

Returns:

torch.nn.Module: Default model.

property default_optim

Get the default optimizer for the task.

Returns:

torch.optim.Optimizer: Default optimizer.

property default_train_val_test_split

Get the default train/validation/test split. Not used in this task.

Returns:

None

evaluate(splitted_data=None, mute=False)

Evaluate the model on the provided or processed data.

Args:

splitted_data (dict, optional): Dictionary containing split data and DataLoaders. Defaults to None.

mute (bool, optional): If True, suppress the print statements. Defaults to False.

Returns:

dict: Dictionary containing evaluation metrics and results.
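Clustering has no fixed correspondence between cluster ids and class labels, so evaluation typically uses permutation-invariant scores. The sketch below computes normalized mutual information (NMI), a common choice. Whether NodeClustTask reports NMI, and with which normalization, is an assumption. The sketch uses arithmetic-mean normalization, which is also the default of scikit-learn's `normalized_mutual_info_score`. The function name `normalized_mutual_info` is hypothetical.

```python
import math
from collections import Counter


def normalized_mutual_info(labels_true, labels_pred):
    """NMI with arithmetic-mean normalization: I(T; P) / ((H(T) + H(P)) / 2)."""
    n = len(labels_true)
    if n == 0 or n != len(labels_pred):
        raise ValueError("inputs must be non-empty and the same length")
    count_t = Counter(labels_true)
    count_p = Counter(labels_pred)
    joint = Counter(zip(labels_true, labels_pred))

    def entropy(counts):
        return -sum(c / n * math.log(c / n) for c in counts.values())

    # Mutual information from the contingency counts.
    mi = sum(
        c / n * math.log(c * n / (count_t[t] * count_p[p]))
        for (t, p), c in joint.items()
    )
    h_t, h_p = entropy(count_t), entropy(count_p)
    if h_t == 0.0 and h_p == 0.0:
        return 1.0  # both labelings put every node in a single cluster
    return mi / ((h_t + h_p) / 2)


# Renaming the clusters does not change the score.
print(round(normalized_mutual_info([0, 0, 1, 1], [1, 1, 0, 0]), 4))  # 1.0
# A clustering unrelated to the classes scores zero.
print(normalized_mutual_info([0, 0, 1, 1], [0, 1, 0, 1]))  # 0.0
```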

load_train_val_test_split()

Load the train/validation/test split from a file. Not used in this task.

loss_fn(embedding, logits, label, mask)

Calculate the loss for the model.

Args:

embedding (torch.Tensor): Embeddings from the model.

logits (torch.Tensor): Logits from the model.

label (torch.Tensor): Ground truth labels.

mask (torch.Tensor): Mask to filter the logits and labels.

Returns:

torch.Tensor: Calculated loss.

property num_feats

Get the number of features in the dataset.

Returns:

int: Number of features.

property num_global_classes

Get the number of global classes in the dataset.

Returns:

int: Number of global classes.

property num_samples

Get the number of samples in the dataset.

Returns:

int: Number of samples.

train(splitted_data=None)

Train the model on the provided or processed data.

Args:

splitted_data (dict, optional): Dictionary containing split data and DataLoaders. Defaults to None.

property train_val_test_path

Get the path to the train/validation/test split file. Not used in this task.

Returns:

None