Task
- class openfgl.task.base.BaseTask(args, client_id, data, data_dir, device)
Bases: object
Base class for defining a task in a federated learning setup.
- Attributes:
client_id (int): ID of the client.
data_dir (str): Directory containing the data.
args (Namespace): Arguments containing model and training configurations.
device (torch.device): Device to run the computations on.
data (object): Data specific to the task.
model (torch.nn.Module): Model to be trained.
optim (torch.optim.Optimizer): Optimizer for the model.
override_evaluate (function): Custom evaluation function, if provided.
step_preprocess (function): Custom preprocessing step, if provided.
- property default_loss_fn
Get the default loss function for the task. This property should be implemented by subclasses.
- Returns:
function: Default loss function.
- property default_model
Get the default model for the task. This property should be implemented by subclasses.
- Returns:
torch.nn.Module: Default model.
- property default_optim
Get the default optimizer for the task. This property should be implemented by subclasses.
- Returns:
torch.optim.Optimizer: Default optimizer.
- property default_train_val_test_split
Get the default train/validation/test split. This property should be implemented by subclasses.
- Returns:
dict: Default train/validation/test split.
- evaluate()
Evaluate the model on the provided data. This method should be implemented by subclasses.
- load_custom_model(custom_model)
Load a custom model for the task and reinitialize the optimizer.
- Args:
custom_model (torch.nn.Module): Custom model to be used.
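Reinitializing the optimizer matters because a PyTorch optimizer is constructed from `model.parameters()` and keeps references to those parameters; after a model swap, the old optimizer would keep updating the discarded model. The minimal sketch below shows the pattern with plain-Python stand-ins. `ToyModel`, `ToySGD`, and `ToyTask` are hypothetical names, not part of OpenFGL; a real task would use a `torch.nn.Module` and a `torch.optim` optimizer.

```python
class ToyModel:
    """Hypothetical stand-in for torch.nn.Module: just a list of parameters."""

    def __init__(self, params):
        self.params = list(params)

    def parameters(self):
        return self.params


class ToySGD:
    """Hypothetical stand-in for a torch.optim optimizer; it keeps a reference to its parameters."""

    def __init__(self, params, lr=0.01):
        self.params = params
        self.lr = lr


class ToyTask:
    """Hypothetical task mirroring the load_custom_model contract described above."""

    def __init__(self, model):
        self.model = model
        self.optim = ToySGD(self.model.parameters())

    def load_custom_model(self, custom_model):
        # Swap in the new model, then rebuild the optimizer so that it
        # tracks the new model's parameters instead of the old ones.
        self.model = custom_model
        self.optim = ToySGD(self.model.parameters())


task = ToyTask(ToyModel([1.0, 2.0]))
new_model = ToyModel([3.0])
task.load_custom_model(new_model)
print(task.optim.params is new_model.params)  # True
```

Without the second line of `load_custom_model`, `task.optim.params` would still point at the old model's parameters.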
- load_train_val_test_split()
Load the train/validation/test split from a file. This method should be implemented by subclasses.
- property num_samples
Get the number of samples in the dataset. This property should be implemented by subclasses.
- Returns:
int: Number of samples.
- train()
Train the model on the provided data. This method should be implemented by subclasses.
- property train_val_test_path
Get the path to the train/validation/test split file. This property should be implemented by subclasses.
- Returns:
str: Path to the split file.
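The members above form a contract that every concrete task fills in. The sketch below expresses a subset of that contract with Python's `abc` module; `SketchBaseTask` and `CountingTask` are hypothetical and are not OpenFGL's source. Using `abc` is a choice made here: it makes a missing override fail when the subclass is instantiated rather than when the member is first called, which may differ from how `BaseTask` itself behaves.

```python
from abc import ABC, abstractmethod


class SketchBaseTask(ABC):
    """Hypothetical skeleton of part of the BaseTask contract; not OpenFGL's code."""

    def __init__(self, args, client_id, data, data_dir, device):
        self.args = args
        self.client_id = client_id
        self.data = data
        self.data_dir = data_dir
        self.device = device

    @property
    @abstractmethod
    def num_samples(self):
        """Number of samples held by this client."""

    @property
    @abstractmethod
    def default_train_val_test_split(self):
        """Default split ratios for this task."""

    @abstractmethod
    def train(self):
        """Run local training."""

    @abstractmethod
    def evaluate(self):
        """Return a dict of evaluation metrics."""


class CountingTask(SketchBaseTask):
    """Toy subclass: data is a list and training only counts calls."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.steps = 0

    @property
    def num_samples(self):
        return len(self.data)

    @property
    def default_train_val_test_split(self):
        return (0.8, 0.1, 0.1)

    def train(self):
        self.steps += 1

    def evaluate(self):
        return {"steps": self.steps, "num_samples": self.num_samples}


task = CountingTask(args=None, client_id=0, data=[1, 2, 3], data_dir=".", device="cpu")
task.train()
print(task.evaluate())  # {'steps': 1, 'num_samples': 3}
```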
- class openfgl.task.graph_cls.GraphClsTask(args, client_id, data, data_dir, device)
Bases: BaseTask
Task class for graph classification in a federated learning setup.
- Attributes:
client_id (int): ID of the client.
data_dir (str): Directory containing the data.
args (Namespace): Arguments containing model and training configurations.
device (torch.device): Device to run the computations on.
data (object): Data specific to the task.
model (torch.nn.Module): Model to be trained.
optim (torch.optim.Optimizer): Optimizer for the model.
train_mask (torch.Tensor): Mask for the training set.
val_mask (torch.Tensor): Mask for the validation set.
test_mask (torch.Tensor): Mask for the test set.
train_dataloader (DataLoader): DataLoader for the training set.
val_dataloader (DataLoader): DataLoader for the validation set.
test_dataloader (DataLoader): DataLoader for the test set.
splitted_data (dict): Dictionary containing split data and DataLoaders.
processed_data (object): Processed data for training.
- property default_loss_fn
Get the default loss function for the task.
- Returns:
function: Default loss function.
- property default_model
Get the default model for graph classification.
- Returns:
torch.nn.Module: Default model.
- property default_optim
Get the default optimizer for the task.
- Returns:
torch.optim.Optimizer: Default optimizer.
- property default_train_val_test_split
Get the default train/validation/test split.
- Returns:
tuple: Default train/validation/test split ratios.
- evaluate(splitted_data=None, mute=False)
Evaluate the model on the provided or processed data.
- Args:
splitted_data (dict, optional): Dictionary containing split data and DataLoaders. Defaults to None.
mute (bool, optional): If True, suppress the print statements. Defaults to False.
- Returns:
dict: Dictionary containing evaluation metrics and results.
- local_graph_train_val_test_split(local_graphs, split, shuffle=True)
Split the local graphs into train, validation, and test sets.
- Args:
local_graphs (object): Local graphs to be split.
split (str or tuple): Split ratios or default split identifier.
shuffle (bool, optional): If True, shuffle the graphs before splitting. Defaults to True.
- Returns:
tuple: Masks for the train, validation, and test sets.
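A ratio-based split of this kind can be sketched as follows. `ratio_split_masks` is a hypothetical helper, not OpenFGL's implementation: it shuffles item indices with a seeded generator, assigns the first fraction to training and the next to validation, and gives any rounding remainder to the test set (a choice made here). Plain lists of booleans stand in for the `torch.Tensor` masks, and resolving a string split identifier is not covered.

```python
import random


def ratio_split_masks(num_items, ratios=(0.8, 0.1, 0.1), shuffle=True, seed=0):
    """Return boolean train/val/test masks over num_items items.

    Hypothetical helper for illustration; the library's own rounding and
    handling of string split identifiers may differ.
    """
    train_ratio, val_ratio, _ = ratios
    indices = list(range(num_items))
    if shuffle:
        random.Random(seed).shuffle(indices)  # reproducible shuffle

    num_train = int(train_ratio * num_items)
    num_val = int(val_ratio * num_items)

    train_mask = [False] * num_items
    val_mask = [False] * num_items
    test_mask = [False] * num_items
    for pos, idx in enumerate(indices):
        if pos < num_train:
            train_mask[idx] = True
        elif pos < num_train + num_val:
            val_mask[idx] = True
        else:
            test_mask[idx] = True  # the remainder goes to the test set
    return train_mask, val_mask, test_mask


train_mask, val_mask, test_mask = ratio_split_masks(10)
print(sum(train_mask), sum(val_mask), sum(test_mask))  # 8 1 1
```

Each item lands in exactly one of the three masks, and a fixed seed makes the split reproducible across runs.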
- loss_fn(embedding, logits, label, mask)
Calculate the loss for the model.
- Args:
embedding (torch.Tensor): Embeddings from the model.
logits (torch.Tensor): Logits from the model.
label (torch.Tensor): Ground truth labels.
mask (torch.Tensor): Mask to filter the logits and labels.
- Returns:
torch.Tensor: Calculated loss.
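The docstring says only that the mask filters the logits and labels. Assuming the default loss is softmax cross-entropy (the docs do not state this), the computation amounts to `torch.nn.functional.cross_entropy(logits[mask], label[mask])`. The plain-Python sketch below spells that out; `masked_cross_entropy` is a hypothetical name, and the `embedding` argument is left out because this sketch does not use it.

```python
import math


def masked_cross_entropy(logits, labels, mask):
    """Mean softmax cross-entropy over the rows where mask is True.

    logits: list of per-sample score lists; labels: list of class ids;
    mask: list of bools selecting which samples contribute to the loss.
    """
    losses = []
    for row, label, keep in zip(logits, labels, mask):
        if not keep:
            continue
        m = max(row)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        losses.append(log_sum_exp - row[label])
    if not losses:
        raise ValueError("mask selects no samples")
    return sum(losses) / len(losses)


logits = [[2.0, 0.0], [0.0, 0.0], [5.0, -5.0]]
labels = [0, 1, 1]
mask = [False, True, False]
print(masked_cross_entropy(logits, labels, mask))  # 0.6931471805599453 (= ln 2)
```

Only the second row contributes, and two equal logits give a loss of ln 2 regardless of which class is correct.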
- property num_feats
Get the number of features in the dataset.
- Returns:
int: Number of features.
- property num_global_classes
Get the number of global classes in the dataset.
- Returns:
int: Number of global classes.
- property num_samples
Get the number of samples in the dataset.
- Returns:
int: Number of samples.
- train(splitted_data=None)
Train the model on the provided or processed data.
- Args:
splitted_data (dict, optional): Dictionary containing split data and DataLoaders. Defaults to None.
- property train_val_test_path
Get the path to the train/validation/test split file.
- Returns:
str: Path to the split file.
- class openfgl.task.link_pred.LinkPredTask(args, client_id, data, data_dir, device)
Bases: BaseTask
Task class for link prediction in a federated learning setup.
- Attributes:
client_id (int): ID of the client.
data_dir (str): Directory containing the data.
args (Namespace): Arguments containing model and training configurations.
device (torch.device): Device to run the computations on.
data (object): Data specific to the task.
model (torch.nn.Module): Model to be trained.
optim (torch.optim.Optimizer): Optimizer for the model.
forward_data (Data): Data for the forward pass.
merged_edge_index (torch.Tensor): Merged edge indices.
merged_edge_label (torch.Tensor): Labels for merged edges.
merged_edge_train_mask (torch.Tensor): Mask for training edges.
merged_edge_val_mask (torch.Tensor): Mask for validation edges.
merged_edge_test_mask (torch.Tensor): Mask for test edges.
splitted_data (dict): Dictionary containing split data and DataLoaders.
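The docs do not say how `merged_edge_index` and `merged_edge_label` are built. A common construction for link prediction, and only an assumption here, merges the observed edges (label 1) with an equal number of sampled node pairs that are not edges (label 0). The hypothetical helper `merge_with_negatives` sketches this with edge lists of `(u, v)` pairs in place of a 2 x E `torch.Tensor`, treating the graph as undirected.

```python
import random


def merge_with_negatives(edges, num_nodes, seed=0):
    """Merge positive edges (label 1) with an equal number of sampled
    non-edges (label 0). Hypothetical helper; edges are treated as undirected.
    """
    existing = {frozenset(e) for e in edges}
    max_pairs = num_nodes * (num_nodes - 1) // 2
    if len(edges) > max_pairs - len(existing):
        raise ValueError("graph too dense to sample enough negative edges")

    rng = random.Random(seed)
    negatives = set()
    while len(negatives) < len(edges):
        u, v = rng.randrange(num_nodes), rng.randrange(num_nodes)
        pair = frozenset((u, v))
        if u != v and pair not in existing:
            negatives.add(pair)

    negative_edges = sorted(tuple(sorted(p)) for p in negatives)
    merged_edges = list(edges) + negative_edges
    merged_labels = [1] * len(edges) + [0] * len(negative_edges)
    return merged_edges, merged_labels


edges = [(0, 1), (1, 2), (2, 3)]
merged_edges, merged_labels = merge_with_negatives(edges, num_nodes=5)
print(merged_labels)  # [1, 1, 1, 0, 0, 0]
```

Train, validation, and test masks over the merged edges can then be drawn the same way as for nodes or graphs.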
- property default_loss_fn
Get the default loss function for the task.
- Returns:
function: Default loss function.
- property default_model
Get the default model for node- and edge-level tasks.
- Returns:
torch.nn.Module: Default model.
- property default_optim
Get the default optimizer for the task.
- Returns:
torch.optim.Optimizer: Default optimizer.
- property default_train_val_test_split
Get the default train/validation/test split.
- Returns:
tuple: Default train/validation/test split ratios.
- evaluate(splitted_data=None, mute=False)
Evaluate the model on the provided or processed data.
- Args:
splitted_data (dict, optional): Dictionary containing split data and DataLoaders. Defaults to None.
mute (bool, optional): If True, suppress the print statements. Defaults to False.
- Returns:
dict: Dictionary containing evaluation metrics and results.
- local_subgraph_train_val_test_split(local_subgraph, split, shuffle=True)
Split the local subgraph into train, validation, and test sets.
- Args:
local_subgraph (object): Local subgraph to be split.
split (str or tuple): Split ratios or default split identifier.
shuffle (bool, optional): If True, shuffle the subgraph before splitting. Defaults to True.
- Returns:
tuple: Masks for the train, validation, and test sets.
- loss_fn(embedding, logits, label, mask)
Calculate the loss for the model.
- Args:
embedding (torch.Tensor): Embeddings from the model.
logits (torch.Tensor): Logits from the model.
label (torch.Tensor): Ground truth labels.
mask (torch.Tensor): Mask to filter the logits and labels.
- Returns:
torch.Tensor: Calculated loss.
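For link prediction, a plausible default (assumed here, not stated in the docs) is binary cross-entropy on one logit per candidate edge, which in PyTorch is `torch.nn.functional.binary_cross_entropy_with_logits(logits[mask], label[mask].float())`. The hypothetical plain-Python sketch below uses the numerically stable form max(x, 0) - x*y + log(1 + exp(-|x|)), which avoids overflow for large positive or negative logits.

```python
import math


def masked_bce_with_logits(logits, labels, mask):
    """Mean binary cross-entropy over entries where mask is True.

    Uses the stable form max(x, 0) - x*y + log(1 + exp(-|x|)).
    """
    total, count = 0.0, 0
    for x, y, keep in zip(logits, labels, mask):
        if keep:
            total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
            count += 1
    if count == 0:
        raise ValueError("mask selects no edges")
    return total / count


edge_logits = [0.0, 100.0, -3.0]
edge_labels = [1, 1, 0]
edge_mask = [True, True, False]
print(masked_bce_with_logits(edge_logits, edge_labels, edge_mask))  # approx. 0.3466 (= ln 2 / 2)
```

The confident, correct prediction on the second edge contributes almost nothing, so the mean is half of the ln 2 loss from the undecided first edge.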
- property num_feats
Get the number of features in the dataset.
- Returns:
int: Number of features.
- property num_global_classes
Get the number of global classes in the dataset.
- Returns:
int: Number of global classes.
- property num_samples
Get the number of samples in the dataset.
- Returns:
int: Number of samples.
- train(splitted_data=None)
Train the model on the provided or processed data.
- Args:
splitted_data (dict, optional): Dictionary containing split data and DataLoaders. Defaults to None.
- property train_val_test_path
Get the path to the train/validation/test split file.
- Returns:
str: Path to the split file.
- class openfgl.task.node_cls.NodeClsTask(args, client_id, data, data_dir, device)
Bases: BaseTask
Task class for node classification in a federated learning setup.
- Attributes:
client_id (int): ID of the client.
data_dir (str): Directory containing the data.
args (Namespace): Arguments containing model and training configurations.
device (torch.device): Device to run the computations on.
data (object): Data specific to the task.
model (torch.nn.Module): Model to be trained.
optim (torch.optim.Optimizer): Optimizer for the model.
train_mask (torch.Tensor): Mask for the training set.
val_mask (torch.Tensor): Mask for the validation set.
test_mask (torch.Tensor): Mask for the test set.
splitted_data (dict): Dictionary containing split data and DataLoaders.
processed_data (object): Processed data for training.
- property default_loss_fn
Get the default loss function for the task.
- Returns:
function: Default loss function.
- property default_model
Get the default model for node- and edge-level tasks.
- Returns:
torch.nn.Module: Default model.
- property default_optim
Get the default optimizer for the task.
- Returns:
torch.optim.Optimizer: Default optimizer.
- property default_train_val_test_split
Get the default train/validation/test split based on the dataset.
- Returns:
tuple: Default train/validation/test split ratios.
- evaluate(splitted_data=None, mute=False)
Evaluate the model on the provided or processed data.
- Args:
splitted_data (dict, optional): Dictionary containing split data and DataLoaders. Defaults to None.
mute (bool, optional): If True, suppress the print statements. Defaults to False.
- Returns:
dict: Dictionary containing evaluation metrics and results.
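The returned metrics are not listed in the docs. Accuracy restricted to a mask (for example `test_mask`) is a metric commonly reported for node classification and is used below as an assumed example. `masked_accuracy` is a hypothetical helper that takes the argmax of each masked row of logits and compares it with the label.

```python
def masked_accuracy(logits, labels, mask):
    """Fraction of masked samples whose argmax prediction equals the label."""
    correct = total = 0
    for row, label, keep in zip(logits, labels, mask):
        if keep:
            prediction = max(range(len(row)), key=row.__getitem__)  # argmax
            correct += int(prediction == label)
            total += 1
    return correct / total if total else 0.0


logits = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]]
labels = [0, 1, 1, 1]
test_mask = [False, True, True, True]
print(masked_accuracy(logits, labels, test_mask))  # 0.6666666666666666
```

Of the three masked nodes, the second and fourth are predicted correctly and the third is not, giving 2/3.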
- local_subgraph_train_val_test_split(local_subgraph, split, shuffle=True)
Split the local subgraph into train, validation, and test sets.
- Args:
local_subgraph (object): Local subgraph to be split.
split (str or tuple): Split ratios or default split identifier.
shuffle (bool, optional): If True, shuffle the subgraph before splitting. Defaults to True.
- Returns:
tuple: Masks for the train, validation, and test sets.
- loss_fn(embedding, logits, label, mask)
Calculate the loss for the model.
- Args:
embedding (torch.Tensor): Embeddings from the model.
logits (torch.Tensor): Logits from the model.
label (torch.Tensor): Ground truth labels.
mask (torch.Tensor): Mask to filter the logits and labels.
- Returns:
torch.Tensor: Calculated loss.
- property num_feats
Get the number of features in the dataset.
- Returns:
int: Number of features.
- property num_global_classes
Get the number of global classes in the dataset.
- Returns:
int: Number of global classes.
- property num_samples
Get the number of samples in the dataset.
- Returns:
int: Number of samples.
- train(splitted_data=None)
Train the model on the provided or processed data.
- Args:
splitted_data (dict, optional): Dictionary containing split data and DataLoaders. Defaults to None.
- property train_val_test_path
Get the path to the train/validation/test split file.
- Returns:
str: Path to the split file.
- class openfgl.task.node_clust.NodeClustTask(args, client_id, data, data_dir, device)
Bases: BaseTask
Task class for node clustering in a federated learning setup.
- Attributes:
client_id (int): ID of the client.
data_dir (str): Directory containing the data.
args (Namespace): Arguments containing model and training configurations.
device (torch.device): Device to run the computations on.
data (object): Data specific to the task.
model (torch.nn.Module): Model to be trained.
optim (torch.optim.Optimizer): Optimizer for the model.
splitted_data (dict): Dictionary containing split data and DataLoaders.
- property default_loss_fn
Get the default loss function for the task.
- Returns:
function: Default loss function.
- property default_model
Get the default model for node- and edge-level tasks.
- Returns:
torch.nn.Module: Default model.
- property default_optim
Get the default optimizer for the task.
- Returns:
torch.optim.Optimizer: Default optimizer.
- property default_train_val_test_split
Get the default train/validation/test split. Not used in this task.
- Returns:
None
- evaluate(splitted_data=None, mute=False)
Evaluate the model on the provided or processed data.
- Args:
splitted_data (dict, optional): Dictionary containing split data and DataLoaders. Defaults to None.
mute (bool, optional): If True, suppress the print statements. Defaults to False.
- Returns:
dict: Dictionary containing evaluation metrics and results.
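Because clustering assigns arbitrary cluster ids, its evaluation metrics must be invariant to relabeling. The docs do not list which metrics are reported; normalized mutual information (NMI) is one standard choice and is sketched here as an assumption. `normalized_mutual_info` is a hypothetical helper using arithmetic-mean normalization, I(U;V) / ((H(U) + H(V)) / 2), which is also the default of scikit-learn's `normalized_mutual_info_score`.

```python
import math
from collections import Counter


def normalized_mutual_info(labels_true, labels_pred):
    """NMI with arithmetic-mean normalization: I(U;V) / ((H(U) + H(V)) / 2)."""
    n = len(labels_true)
    count_true = Counter(labels_true)
    count_pred = Counter(labels_pred)
    joint = Counter(zip(labels_true, labels_pred))

    def entropy(counts):
        return -sum(c / n * math.log(c / n) for c in counts.values())

    # Sum over co-occurring (true, predicted) pairs of p(t, p) * log(p(t, p) / (p(t) * p(p))).
    mutual_info = sum(
        c / n * math.log((c * n) / (count_true[t] * count_pred[p]))
        for (t, p), c in joint.items()
    )
    h_true, h_pred = entropy(count_true), entropy(count_pred)
    if h_true == 0 and h_pred == 0:
        return 1.0  # both partitions consist of a single cluster
    return mutual_info / ((h_true + h_pred) / 2)


print(normalized_mutual_info([0, 0, 1, 1], [1, 1, 0, 0]))  # ~1.0: same partition, relabeled
print(normalized_mutual_info([0, 0, 1, 1], [0, 1, 0, 1]))  # ~0.0: independent partitions
```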
- load_train_val_test_split()
Load the train/validation/test split from a file. Not used in this task.
- loss_fn(embedding, logits, label, mask)
Calculate the loss for the model.
- Args:
embedding (torch.Tensor): Embeddings from the model.
logits (torch.Tensor): Logits from the model.
label (torch.Tensor): Ground truth labels.
mask (torch.Tensor): Mask to filter the logits and labels.
- Returns:
torch.Tensor: Calculated loss.
- property num_feats
Get the number of features in the dataset.
- Returns:
int: Number of features.
- property num_global_classes
Get the number of global classes in the dataset.
- Returns:
int: Number of global classes.
- property num_samples
Get the number of samples in the dataset.
- Returns:
int: Number of samples.
- train(splitted_data=None)
Train the model on the provided or processed data.
- Args:
splitted_data (dict, optional): Dictionary containing split data and DataLoaders. Defaults to None.
- property train_val_test_path
Get the path to the train/validation/test split file. Not used in this task.
- Returns:
None