Client
- class openfgl.flcore.base.BaseClient(args, client_id, data, data_dir, message_pool, device, personalized=False)
Bases: object
Base class for a client in a federated learning setup.
- Attributes:
args (Namespace): Arguments containing model and training configurations.
client_id (int): ID of the client.
message_pool (object): Pool for managing messages between client and server.
device (torch.device): Device to run the computations on.
task (object): Task-specific data and functions loaded via the load_task utility.
personalized (bool): Flag to indicate if the client is using a personalized algorithm.
- __init__(args, client_id, data, data_dir, message_pool, device, personalized=False)
Initializes the BaseClient with the provided arguments and data.
- Args:
args (Namespace): Arguments containing model and training configurations.
client_id (int): ID of the client.
data (object): Data specific to the client’s task.
data_dir (str): Directory containing the data.
message_pool (object): Pool for managing messages between client and server.
device (torch.device): Device to run the computations on.
personalized (bool, optional): Flag to indicate if the client is using a personalized algorithm. Defaults to False.
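The subclasses below share one protocol: `execute()` synchronizes with the server and trains locally, and `send_message()` publishes results to the message pool. The framework-free sketch below illustrates that pattern. `ToyClient`, the `"server"` and `f"client_{id}"` pool keys, and the one-step "training" rule are hypothetical stand-ins, not OpenFGL's actual API; only the `num_samples`/`weight` payload mirrors the fields documented for FedProxClient.send_message.

```python
class ToyClient:
    """Hypothetical stand-in for the BaseClient pattern (not OpenFGL code)."""

    def __init__(self, client_id, message_pool, local_data):
        self.client_id = client_id
        self.message_pool = message_pool  # shared dict standing in for the message pool
        self.local_data = local_data      # list of floats standing in for client data
        self.weight = [0.0]               # one-parameter "model"

    def execute(self):
        # Synchronize with the global weights (assumed to live under the "server" key).
        self.weight = list(self.message_pool["server"]["weight"])
        # Stand-in for local training: move halfway toward the local data mean.
        target = sum(self.local_data) / len(self.local_data)
        self.weight = [w + 0.5 * (target - w) for w in self.weight]

    def send_message(self):
        # Publish the local result under a per-client key (assumed naming).
        self.message_pool[f"client_{self.client_id}"] = {
            "num_samples": len(self.local_data),
            "weight": self.weight,
        }


pool = {"server": {"weight": [0.0]}}
client = ToyClient(0, pool, [2.0, 4.0])
client.execute()
client.send_message()
print(pool["client_0"])  # {'num_samples': 2, 'weight': [1.5]}
```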
- class openfgl.flcore.fedavg.client.FedAvgClient(args, client_id, data, data_dir, message_pool, device)
Bases: BaseClient
FedAvgClient implements the client-side logic for the Federated Averaging (FedAvg) algorithm, introduced in the paper “Communication-Efficient Learning of Deep Networks from Decentralized Data” by McMahan et al. (2017). This class extends the BaseClient class and manages local training and communication with the server.
The FedAvg algorithm allows clients to train models locally on their data and send the updated model parameters to the server for aggregation, enabling efficient learning in decentralized environments.
- Attributes:
None (inherits attributes from BaseClient)
- __init__(args, client_id, data, data_dir, message_pool, device)
Initializes the FedAvgClient.
- Args:
args (Namespace): Arguments containing model and training configurations.
client_id (int): ID of the client.
data (object): Data specific to the client’s task.
data_dir (str): Directory containing the data.
message_pool (object): Pool for managing messages between client and server.
device (torch.device): Device to run the computations on.
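The client reports its sample count and parameters, and the aggregation itself happens on the server. As a sketch of how those two fields are combined in FedAvg (a sample-weighted mean of parameters, per McMahan et al.), using plain lists in place of tensors; `fedavg_aggregate` is a hypothetical helper, not part of OpenFGL:

```python
def fedavg_aggregate(messages):
    """Sample-weighted average of client parameter vectors (FedAvg-style)."""
    total = sum(m["num_samples"] for m in messages)
    dim = len(messages[0]["weight"])
    return [
        sum(m["num_samples"] * m["weight"][i] for m in messages) / total
        for i in range(dim)
    ]


msgs = [
    {"num_samples": 1, "weight": [1.0, 2.0]},
    {"num_samples": 3, "weight": [3.0, 4.0]},
]
print(fedavg_aggregate(msgs))  # [2.5, 3.5]
```

A client holding three times as many samples pulls the average three times as strongly.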
- class openfgl.flcore.adafgl.client.AdaFGLClient(args, client_id, data, data_dir, message_pool, device)
Bases: BaseClient
AdaFGLClient implements the client-side logic for federated learning using the AdaFGL model, as described in the paper “AdaFGL: A New Paradigm for Federated Node Classification with Topology Heterogeneity”. It extends the BaseClient class by incorporating topology-aware learning methods, enabling the client to adapt to varying graph structures during the federated learning process.
- Attributes:
phase (int): Indicates the current phase of training. Initially set to 0, it switches to 1 when entering the AdaFGL phase.
- __init__(args, client_id, data, data_dir, message_pool, device)
Initializes the AdaFGLClient.
- Args:
args (Namespace): Arguments containing model and training configurations.
client_id (int): ID of the client.
data (object): Data specific to the client’s task.
data_dir (str): Directory containing the data.
message_pool (object): Pool for managing messages between client and server.
device (torch.device): Device to run the computations on.
- adafgl_postprocess(loss_ce_fn=CrossEntropyLoss())
Performs post-processing for the AdaFGL model after training. This includes computing the loss, performing backpropagation, and updating the model parameters.
- Args:
loss_ce_fn: Cross-entropy loss function. Defaults to nn.CrossEntropyLoss().
- execute()
Executes the training process. This method handles the switching between different phases of training, initializes the AdaFGL model, and performs training based on the current phase.
- get_adafgl_override_evaluate()
Returns a custom evaluation function that overrides the default evaluation method for the AdaFGL model. This function computes metrics based on both homogeneous and heterogeneous forward passes.
- Returns:
override_evaluate (function): A custom evaluation function.
- class openfgl.flcore.feddc.client.FedDCClient(args, client_id, data, data_dir, message_pool, device)
Bases: BaseClient
FedDCClient is a client implementation for the Federated Learning algorithm with Drift Decoupling and Correction (FedDC). It extends the BaseClient class and manages local training while correcting for local drift to handle non-IID data effectively.
- Attributes:
local_drift (list): A list of tensors representing the accumulated drift for each model parameter.
last_update (list): A list of tensors representing the last update applied to each model parameter.
- __init__(args, client_id, data, data_dir, message_pool, device)
Initializes the FedDCClient.
- Args:
args (Namespace): Arguments containing model and training configurations.
client_id (int): ID of the client.
data (object): Data specific to the client’s task.
data_dir (str): Directory containing the data.
message_pool (object): Pool for managing messages between client and server.
device (torch.device): Device to run the computations on.
- execute()
Executes the local training process. This method first synchronizes the local model with the global model parameters received from the server, then trains the model using a custom loss function that incorporates drift correction. After training, it updates the local drift and last update tensors.
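A minimal sketch of the drift bookkeeping described for execute(), following our reading of the FedDC paper: the drift h accumulates each round's local update (theta - w), and a penalty (alpha/2) * ||theta + h - w||^2 ties the drift-corrected local model to the global one. The paper's gradient-correction term is omitted, `alpha` is a hypothetical setting, and plain lists stand in for the per-parameter tensors held in local_drift and last_update.

```python
def feddc_penalty(local, global_, drift, alpha=0.01):
    """(alpha / 2) * ||theta + h - w||^2 summed over all parameters."""
    return 0.5 * alpha * sum((t + h - w) ** 2 for t, h, w in zip(local, drift, global_))


def update_drift(drift, local_after, global_before):
    """Return (new local_drift, last_update) after one round of local training."""
    last_update = [t - w for t, w in zip(local_after, global_before)]
    new_drift = [h + u for h, u in zip(drift, last_update)]
    return new_drift, last_update


drift, last = update_drift([0.0, 0.0], local_after=[1.0, 2.0], global_before=[0.5, 0.5])
print(drift, last)  # [0.5, 1.5] [0.5, 1.5]
print(feddc_penalty([1.0], [0.0], [1.0], alpha=0.5))  # 1.0
```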
- class openfgl.flcore.feddep.client.FedDEPClient(args, client_id, data, data_dir, message_pool, device)
Bases: BaseClient
FedDEPClient is a client implementation for FedDEP (Deep Efficient Private Neighbor Generation for Subgraph Federated Learning). It extends the BaseClient class and handles local training, private neighbor generation, and graph recovery tasks in a federated setting.
- Attributes:
hide_graph_model (nn.Module): Model for generating private subgraphs by hiding parts of the input graph.
data (torch_geometric.data.Data): The original dataset after applying splits for training, validation, and testing.
hide_data (torch_geometric.data.Data): The dataset with hidden information for training the private neighbor generation model.
emb (torch.Tensor): Embedding tensor generated by the hide_graph_model.
x_missing (torch.Tensor): Tensor containing the missing features of the graph generated by the hide_graph_model.
loss_fn_num (function): Loss function used for numerical loss during training.
loss_fn_rec (function): Loss function used for reconstruction loss during training.
fill_dataloader (dict): Dataloader for filling in missing graph information.
- __init__(args, client_id, data, data_dir, message_pool, device)
Initializes the FedDEPClient.
- Args:
args (Namespace): Arguments containing model and training configurations.
client_id (int): ID of the client.
data (object): Data specific to the client’s task.
data_dir (str): Directory containing the data.
message_pool (object): Pool for managing messages between client and server.
device (torch.device): Device to run the computations on.
- execute()
Executes the local training process. Depending on the current round, it either pre-trains the local model for private neighbor generation and prepares the FedDEP model, or performs the federated training with the global model updates.
- class openfgl.flcore.fedgl.client.FedGLClient(args, client_id, data, data_dir, message_pool, device)
Bases: BaseClient
FedGLClient is a client implementation for the Federated Graph Learning (FedGL) framework with global self-supervision. It extends the BaseClient class and handles the local training of graph neural networks in a federated learning environment, incorporating global self-supervision through pseudo-labels and global graph structures.
- Attributes:
adj (torch.Tensor): Sparse adjacency matrix in CSR format representing the local graph structure.
mask (torch.Tensor): Tensor indicating which nodes are included in the global map, used for masking operations.
- __init__(args, client_id, data, data_dir, message_pool, device)
Initializes the FedGLClient.
- Args:
args (Namespace): Arguments containing model and training configurations.
client_id (int): ID of the client.
data (object): Data specific to the client’s task.
data_dir (str): Directory containing the data.
message_pool (object): Pool for managing messages between client and server.
device (torch.device): Device to run the computations on.
- execute()
Executes the local training process. The method synchronizes the local model with the global model parameters received from the server, and if applicable, incorporates the global graph structure into the adjacency matrix. It then trains the model using the custom loss function.
- get_custom_loss_fn()
Returns a custom loss function for the FedGL framework. This loss function combines the standard cross-entropy loss with an additional self-supervised learning (SSL) loss based on pseudo-labels and a global graph structure.
- Returns:
custom_loss_fn (function): A custom loss function.
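The SSL term in get_custom_loss_fn() relies on pseudo-labels derived from global predictions. The sketch below shows one common way to build such a term: confidence-thresholded pseudo-labeling added to supervised cross-entropy. The threshold, the weight `lam`, and all helper names are hypothetical, and the global graph structure that FedGL also uses is left out.

```python
import math


def pseudo_labels(global_probs, threshold=0.9):
    """Map node index -> predicted class for nodes whose top probability clears the threshold."""
    labels = {}
    for node, probs in enumerate(global_probs):
        conf = max(probs)
        if conf >= threshold:
            labels[node] = probs.index(conf)
    return labels


def mean_cross_entropy(probs, targets):
    """Average -log p(target) over the nodes in `targets` (node -> class)."""
    if not targets:
        return 0.0
    return sum(-math.log(probs[n][c]) for n, c in targets.items()) / len(targets)


def fedgl_style_loss(local_probs, train_targets, global_probs, lam=0.5, threshold=0.9):
    """Supervised CE on labelled nodes plus lam * CE against confident pseudo-labels."""
    pseudo = {n: c for n, c in pseudo_labels(global_probs, threshold).items()
              if n not in train_targets}
    return (mean_cross_entropy(local_probs, train_targets)
            + lam * mean_cross_entropy(local_probs, pseudo))


global_probs = [[0.95, 0.05], [0.6, 0.4], [0.02, 0.98]]
print(pseudo_labels(global_probs))  # {0: 0, 2: 1}
```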
- class openfgl.flcore.fedgta.client.FedGTAClient(args, client_id, data, data_dir, message_pool, device)
Bases: BaseClient
FedGTAClient is a client implementation for the Federated Graph Learning framework with Topology-aware Averaging (FedGTA). This client handles local model training, label propagation, and the computation of topology-aware metrics for federated learning.
- Attributes:
LP (LabelPropagation): A label propagation model for graph-based semi-supervised learning.
num_neig (torch.Tensor): Tensor representing the degree (number of neighbors) of each node in the graph.
train_label_onehot (torch.Tensor): One-hot encoded labels for the training nodes.
lp_moment_v (torch.Tensor): Computed moments from the label propagation results, used for topology-aware averaging.
agg_w (torch.Tensor): Aggregation weights based on the information entropy of label propagation results.
- __init__(args, client_id, data, data_dir, message_pool, device)
Initializes the FedGTAClient.
- Args:
args (Namespace): Arguments containing model and training configurations.
client_id (int): ID of the client.
data (object): Data specific to the client’s task.
data_dir (str): Directory containing the data.
message_pool (object): Pool for managing messages between client and server.
device (torch.device): Device to run the computations on.
- execute()
Executes the local training process. The method first synchronizes the local model with the global model parameters received from the server, then trains the model locally, and finally performs post-processing using label propagation and topology-aware averaging.
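To illustrate the post-processing step, here is a generic label-propagation sketch (the classic iteration Y <- alpha * D^-1 A Y + (1 - alpha) * Y0) followed by the Shannon entropy of each node's soft label, the kind of quantity agg_w is described as being based on. This is not FedGTA's exact procedure: the propagation rule, `alpha`, the iteration count, and the helper names are assumptions, and the moment computation behind lp_moment_v is omitted.

```python
import math


def label_propagation(adj, y0, alpha=0.5, num_iters=10):
    """Iterate Y <- alpha * D^-1 A Y + (1 - alpha) * Y0 on a dense adjacency matrix."""
    n, c = len(y0), len(y0[0])
    deg = [sum(row) or 1 for row in adj]  # guard isolated nodes against division by zero
    y = [row[:] for row in y0]
    for _ in range(num_iters):
        y = [
            [alpha * sum(adj[i][j] * y[j][k] for j in range(n)) / deg[i]
             + (1 - alpha) * y0[i][k]
             for k in range(c)]
            for i in range(n)
        ]
    return y


def entropy(scores):
    """Shannon entropy (nats) of a non-negative score vector after normalization."""
    total = sum(scores)
    probs = [s / total for s in scores if s > 0]
    return -sum(p * math.log(p) for p in probs)


# Path graph 0 - 1 - 2: node 0 labelled class 0, node 2 class 1, node 1 unlabelled.
adj = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
y0 = [[1.0, 0.0], [0.0, 0.0], [0.0, 1.0]]
soft = label_propagation(adj, y0)
print([round(entropy(row), 3) for row in soft])  # the middle entry is log 2 ~ 0.693
```

Node 1 sits exactly between the two labelled nodes, so its soft label is maximally uncertain; in an entropy-based weighting, such uncertain predictions would carry less confidence.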
- class openfgl.flcore.fedproto.client.FedProtoClient(args, client_id, data, data_dir, message_pool, device)
Bases: BaseClient
FedProtoClient is a client implementation for the Federated Prototype Learning (FedProto) framework. This client handles the local training of models, computes class-specific prototypes, and interacts with the server to contribute to the global model updates.
- Attributes:
local_prototype (dict): A dictionary storing the local prototypes for each class after training.
- __init__(args, client_id, data, data_dir, message_pool, device)
Initializes the FedProtoClient.
- Args:
args (Namespace): Arguments containing model and training configurations.
client_id (int): ID of the client.
data (object): Data specific to the client’s task.
data_dir (str): Directory containing the data.
message_pool (object): Pool for managing messages between client and server.
device (torch.device): Device to run the computations on.
- execute()
Executes the local training process. This method sets a custom loss function that incorporates the prototype-based regularization term, performs local training, and then updates the local prototypes for each class.
- get_custom_loss_fn()
Returns a custom loss function for the FedProto framework. This loss function combines the standard cross-entropy loss with an additional prototype-based regularization term.
- Returns:
custom_loss_fn (function): A custom loss function.
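The two pieces of FedProtoClient can be sketched with plain lists: class prototypes as the mean embedding of each class, and a regularizer measuring the squared distance between each embedding and its class's global prototype. The mean-squared-distance form and the weight `lam` are assumptions for illustration; the client itself works on tensors produced by its model.

```python
def compute_prototypes(embeddings, labels):
    """Mean embedding per class label."""
    sums, counts = {}, {}
    for emb, y in zip(embeddings, labels):
        if y not in sums:
            sums[y], counts[y] = [0.0] * len(emb), 0
        sums[y] = [s + e for s, e in zip(sums[y], emb)]
        counts[y] += 1
    return {y: [s / counts[y] for s in sums[y]] for y in sums}


def proto_regularizer(embeddings, labels, global_protos, lam=1.0):
    """lam * mean squared distance between each embedding and its class's global prototype."""
    terms = []
    for emb, y in zip(embeddings, labels):
        if y in global_protos:  # classes unseen by the server contribute nothing
            proto = global_protos[y]
            terms.append(sum((e - p) ** 2 for e, p in zip(emb, proto)) / len(emb))
    return lam * sum(terms) / len(terms) if terms else 0.0


embs = [[1.0, 0.0], [3.0, 0.0], [0.0, 2.0]]
labels = [0, 0, 1]
protos = compute_prototypes(embs, labels)
print(protos)  # {0: [2.0, 0.0], 1: [0.0, 2.0]}
print(proto_regularizer(embs, labels, protos))  # about 0.333
```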
- class openfgl.flcore.fedprox.client.FedProxClient(args, client_id, data, data_dir, message_pool, device)
Bases: BaseClient
FedProxClient is a client implementation for the Federated Proximal (FedProx) framework, introduced in the paper “Federated Optimization in Heterogeneous Networks.” This client handles local training with a custom loss function that includes a proximal term, designed to address the challenges of heterogeneity in federated learning environments.
- Attributes:
None
- __init__(args, client_id, data, data_dir, message_pool, device)
Initializes the FedProxClient.
- Args:
args (Namespace): Arguments containing model and training configurations.
client_id (int): ID of the client.
data (object): Data specific to the client’s task.
data_dir (str): Directory containing the data.
message_pool (object): Pool for managing messages between client and server.
device (torch.device): Device to run the computations on.
- execute()
Executes the local training process. This method first synchronizes the local model with the global model parameters received from the server, and then trains the model locally using the custom loss function that includes the FedProx proximal term.
- get_custom_loss_fn()
Returns a custom loss function for the FedProx framework. This loss function combines the standard task-specific loss (e.g., cross-entropy) with a proximal term that penalizes the deviation of local model parameters from the global model parameters.
- Returns:
custom_loss_fn (function): A custom loss function that includes the proximal term.
- send_message()
Sends a message to the server containing the local model parameters and the number of samples used for training. This information is used by the server to update the global model parameters.
- The message includes:
num_samples: The number of samples used in local training.
weight: The updated local model parameters.
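The proximal term described for get_custom_loss_fn() is, in the FedProx paper, (mu/2) * ||w - w_global||^2 added to the task loss. A sketch on flat parameter lists follows; `mu` and the helper names are illustrative and may not match OpenFGL's argument names.

```python
def proximal_term(local, global_, mu):
    """(mu / 2) * squared L2 distance between local and global parameters."""
    return 0.5 * mu * sum((w - wg) ** 2 for w, wg in zip(local, global_))


def fedprox_loss(task_loss, local, global_, mu=0.01):
    """Task loss (e.g. cross-entropy) plus the FedProx proximal penalty."""
    return task_loss + proximal_term(local, global_, mu)


print(proximal_term([1.0, 2.0], [0.0, 0.0], mu=0.1))  # about 0.25
print(fedprox_loss(1.0, [1.0, 2.0], [0.0, 0.0], mu=0.1))  # about 1.25
```

With mu = 0 the loss reduces to plain local training, which is why FedProx is often described as a generalization of FedAvg.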
- class openfgl.flcore.fedstar.client.FedStarClient(args, client_id, data, data_dir, message_pool, device)
Bases: BaseClient
FedStarClient is the client-side implementation for the Federated Learning algorithm described in the paper ‘Federated Learning on Non-IID Graphs via Structural Knowledge Sharing’. This class handles local training, structural knowledge sharing, and communication with the server within a federated learning framework.
- Attributes:
task (object): The task object that holds the model and data for training.
device (torch.device): The device on which computations will be performed.
- __init__(args, client_id, data, data_dir, message_pool, device)
Initializes the FedStarClient.
- Args:
args (Namespace): Arguments containing model and training configurations.
client_id (int): ID of the client.
data (object): The graph data specific to the client’s task.
data_dir (str): Directory containing the data.
message_pool (object): Pool for managing messages between client and server.
device (torch.device): Device to run the computations on.
- execute()
Executes the local training process. Before training, the structural knowledge (weights with ‘_s’ in their names) is updated with the global weights received from the server.
- get_custom_loss_fn()
Returns a custom loss function for the FedStarClient.
This loss function is based on negative log-likelihood loss (nll_loss) and is applied to the model’s logits and the true labels.
- Returns:
function: A function that computes the loss given embeddings, logits, labels, and mask.
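The selective synchronization described for execute() can be sketched over name-to-value dicts: only entries whose names contain `_s` take the server's value, while everything else keeps its local value. The dict representation and the parameter names are hypothetical; the client itself operates on the model's named parameters.

```python
def sync_structural_weights(local_state, global_state, marker="_s"):
    """Overwrite only the structural entries (names containing `marker`) with global values."""
    return {
        name: global_state[name] if marker in name and name in global_state else value
        for name, value in local_state.items()
    }


local = {"encoder_s.weight": 1.0, "encoder.weight": 2.0}
remote = {"encoder_s.weight": 9.0, "encoder.weight": 8.0}
print(sync_structural_weights(local, remote))  # {'encoder_s.weight': 9.0, 'encoder.weight': 2.0}
```

In this sketch the match is a plain substring test, so any parameter whose name happens to contain `_s` (for example `conv_scale`) would also be overwritten.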
- class openfgl.flcore.fedtad.client.FedTADClient(args, client_id, data, data_dir, message_pool, device)
Bases: BaseClient
FedTADClient implements the client-side operations for the Federated Learning algorithm described in the paper ‘FedTAD: Topology-aware Data-free Knowledge Distillation for Subgraph Federated Learning’. This class handles the local training, model updates, and knowledge distillation process based on topological data.
- Attributes:
ckr (torch.Tensor): Class-wise Knowledge Reliability (CKR) vector, which stores the reliability of the topological knowledge for each class.
- __init__(args, client_id, data, data_dir, message_pool, device)
Initializes the FedTADClient.
- Args:
args (Namespace): Arguments containing model and training configurations.
client_id (int): ID of the client.
data (object): Data specific to the client’s task.
data_dir (str): Directory containing the data.
message_pool (object): Pool for managing messages between client and server.
device (torch.device): Device to run the computations on.
- execute()
Executes the local training process.
The method first synchronizes the local model with the global model weights received from the server. Then, it trains the local model using the client’s data.
- fedtad_initialization()
Initializes the Class-wise Knowledge Reliability (CKR) based on the topological data.
This method calculates the CKR by computing topological embeddings for the graph structure and using cosine similarity between nodes to update the CKR for each class. The CKR is saved to disk if the configuration allows.
- class openfgl.flcore.fedtgp.client.FedTGPClient(args, client_id, data, data_dir, message_pool, device)
Bases: BaseClient
FedTGPClient implements the client-side operations for the Federated Learning algorithm described in the paper ‘FedTGP: Trainable Global Prototypes with Adaptive-Margin-Enhanced Contrastive Learning for Data and Model Heterogeneity in Federated Learning’. This class handles the local training of the model, updating local prototypes, and communication with the server.
- Attributes:
local_prototype (dict): Stores the local prototypes for each class.
- __init__(args, client_id, data, data_dir, message_pool, device)
Initializes the FedTGPClient.
- Args:
args (Namespace): Arguments containing model and training configurations.
client_id (int): ID of the client.
data (object): Data specific to the client’s task.
data_dir (str): Directory containing the data.
message_pool (object): Pool for managing messages between client and server.
device (torch.device): Device to run the computations on.
- execute()
Executes the local training process. The method applies the custom loss function that includes the prototype-based loss, and updates the local prototypes after training.
- get_custom_loss_fn()
Defines a custom loss function that incorporates both the standard classification loss and the prototype-based loss for contrastive learning.
- Returns:
custom_loss_fn (function): A function that computes the loss based on the current round and whether the model is being evaluated on global data.
- class openfgl.flcore.fggp.client.FGGPClient(args, client_id, data, data_dir, message_pool, device)
Bases: BaseClient
FGGPClient is a client-side implementation for the Federated Graph Learning with Generalizable Prototypes (FGGP) framework. This client handles local training, model updates, and prototype generation in a federated learning setting, focusing on overcoming domain shifts across clients.
- Attributes:
global_model (nn.Module): A copy of the global model used to compute global embeddings.
personal_project (nn.Module): A projection layer used for personalizing embeddings.
data2 (torch_geometric.data.Data): A copy of the data with modified edges for use in the FGGP algorithm.
- __init__(args, client_id, data, data_dir, message_pool, device)
Initializes the FGGPClient.
- Args:
args (Namespace): Arguments containing model and training configurations.
client_id (int): The ID of the client.
data (torch_geometric.data.Data): The graph data specific to the client’s task.
data_dir (str): Directory containing the data.
message_pool (dict): Pool for managing messages between client and server.
device (torch.device): The device on which computations will be performed (e.g., CPU or GPU).
- execute()
Executes the local training process. This involves:
  - Synchronizing the local and global model parameters with the server.
  - Computing the k-nearest-neighbor graph for the global embeddings.
  - Training the model using the custom loss function.
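One way to build the k-nearest-neighbor graph mentioned in execute() is to connect each node to the k other nodes whose embeddings have the highest cosine similarity to its own. The similarity measure, the directed edge-list output, and `knn_edges` are assumptions for illustration, not FGGP's documented choices.

```python
import math


def cosine(a, b):
    """Cosine similarity of two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def knn_edges(embeddings, k=1):
    """Directed edges (i, j) from each node i to its k most similar other nodes."""
    edges = []
    for i, ei in enumerate(embeddings):
        scored = sorted(
            ((cosine(ei, ej), j) for j, ej in enumerate(embeddings) if j != i),
            reverse=True,
        )
        edges.extend((i, j) for _, j in scored[:k])
    return edges


emb = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
print(knn_edges(emb, k=1))  # [(0, 1), (1, 0), (2, 3), (3, 2)]
```

This brute-force version is quadratic in the number of nodes, which is fine for illustration but not for large graphs.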
- class openfgl.flcore.gcfl_plus.client.GCFLPlusClient(args, client_id, data, data_dir, message_pool, device)
Bases: BaseClient
GCFLPlusClient implements the client-side functionality for the Federated Graph Classification framework (GCFL+). This client is designed to operate on non-IID graphs and includes personalized training, weight updating, and message passing functionalities. It builds on a Graph Isomorphism Network (GIN) model for graph classification tasks.
- Attributes:
task (object): The task object containing the model and data configurations.
W (dict): A dictionary containing the current model parameters.
dW (dict): A dictionary to store the differences between the current and previous model parameters.
W_old (dict): A dictionary to store a copy of the model parameters before training.
gconvNames (list): A list of the names of the graph convolution layers in the model, which are updated during training.
- __init__(args, client_id, data, data_dir, message_pool, device)
Initializes the GCFLPlusClient with the provided arguments, data, and device.
- Args:
args (Namespace): Arguments containing model and training configurations.
client_id (int): The unique identifier for the client.
data (object): The client’s local graph data.
data_dir (str): Directory containing the data.
message_pool (dict): Pool for managing messages between the client and server.
device (torch.device): Device on which computations will be performed (e.g., CPU or GPU).
- execute()
Executes the local training process on the client’s data. During the first round, the client initializes the graph convolutional layer names. It then updates the model weights based on the server’s clustered model weights, trains the model, and calculates the gradients for the graph convolutional layers.
- send_message()
Sends the updated model parameters, gradient norms, and weight differences to the server. This information will be used by the server to update the global model and cluster assignments.
- The message contains:
num_samples: The number of samples the client trained on.
W: The current model parameters.
convGradsNorm: The norm of the gradients for the graph convolutional layers.
dW: The differences between the current and previous model parameters.
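Two of these fields can be sketched with name-to-list dicts: dW is the element-wise difference between the parameters after and before local training, and convGradsNorm is an L2 norm restricted to the layers listed in gconvNames. Computing that norm over the flattened gradients of those layers follows our reading of the original GCFL code and is an assumption here; the dicts and helper names are hypothetical.

```python
import math


def weight_delta(W, W_old):
    """dW[name] = W[name] - W_old[name], element-wise."""
    return {name: [a - b for a, b in zip(W[name], W_old[name])] for name in W}


def conv_grads_norm(grads, gconv_names):
    """L2 norm of the gradients of the graph-convolution layers, flattened together."""
    return math.sqrt(sum(g * g for name in gconv_names for g in grads[name]))


W_old = {"conv1": [0.0, 0.0], "head": [0.0]}
W = {"conv1": [0.5, 1.0], "head": [2.0]}
grads = {"conv1": [3.0, 4.0], "head": [10.0]}
print(weight_delta(W, W_old))             # {'conv1': [0.5, 1.0], 'head': [2.0]}
print(conv_grads_norm(grads, ["conv1"]))  # 5.0
```

The `head` gradient is excluded from the norm because only the graph-convolution layers are listed.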
- class openfgl.flcore.moon.client.MoonClient(args, client_id, data, data_dir, message_pool, device)
Bases: BaseClient
MoonClient implements the client-side logic for Model-contrastive Federated Learning (MOON). This approach adds a model-contrastive loss to local training that pulls the representation produced by the client’s model toward the one produced by the global model and pushes it away from the one produced by the client’s previous local model. The class extends the BaseClient class and manages local training, contrastive loss computation, and embedding updates.
- Attributes:
prev_local_embedding (torch.Tensor): The embedding generated by the previous local model.
global_embedding (torch.Tensor): The embedding generated by the global model.
- __init__(args, client_id, data, data_dir, message_pool, device)
Initializes the MoonClient with the provided arguments, data, and device.
- Args:
args (Namespace): Arguments containing model and training configurations.
client_id (int): The unique ID of the client.
data (object): The data associated with this client.
data_dir (str): Directory containing the data.
message_pool (dict): Pool for managing messages between the client and server.
device (torch.device): The device (CPU or GPU) to be used for computations.
- execute()
Executes the client’s local training process. It synchronizes the local model with the global model, evaluates the current embedding, and then trains the local model using the custom loss function. The previous local embedding is updated after training.
- get_custom_loss_fn()
Defines a custom loss function that combines the standard task loss with a model-contrastive loss. The contrastive loss treats the global model’s embedding as the positive and the previous local model’s embedding as the negative, encouraging the local model to stay close to the global model while moving away from its previous state.
- Returns:
function: The custom loss function.
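The model-contrastive term from the MOON paper scores the current embedding z against the global model's embedding z_glob (positive) and the previous local model's embedding z_prev (negative) using temperature-scaled cosine similarity: l_con = -log(exp(sim(z, z_glob)/tau) / (exp(sim(z, z_glob)/tau) + exp(sim(z, z_prev)/tau))). A plain-Python sketch on single vectors follows; `tau`, `mu`, and the helper names are illustrative rather than OpenFGL's argument names.

```python
import math


def cosine(a, b):
    """Cosine similarity of two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def moon_contrastive_loss(z, z_glob, z_prev, tau=0.5):
    """-log of the softmax weight on the positive (global) pair versus the negative (previous local) pair."""
    pos = math.exp(cosine(z, z_glob) / tau)
    neg = math.exp(cosine(z, z_prev) / tau)
    return -math.log(pos / (pos + neg))


def moon_total_loss(task_loss, z, z_glob, z_prev, mu=1.0, tau=0.5):
    """Task loss plus mu times the model-contrastive term."""
    return task_loss + mu * moon_contrastive_loss(z, z_glob, z_prev, tau)


near_global = moon_contrastive_loss([1.0, 0.0], [1.0, 0.0], [0.0, 1.0])
near_prev = moon_contrastive_loss([0.0, 1.0], [1.0, 0.0], [0.0, 1.0])
print(near_global < near_prev)  # True
```

An embedding aligned with the global model incurs a small loss, while one aligned with the previous local model is penalized, matching the pull/push behavior described above.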
- class openfgl.flcore.isolate.client.IsolateClient(args, client_id, data, data_dir, message_pool, device)
Bases: BaseClient
IsolateClient represents a federated learning client that operates in isolation, without participating in the typical federated aggregation and communication process. This class is intended for use cases where the client trains a model independently and does not send updates back to the server.
- Attributes:
task (object): The task object containing the model, data, and training configurations.
- __init__(args, client_id, data, data_dir, message_pool, device)
Initializes the IsolateClient with the provided arguments, data, and device.
- Args:
args (Namespace): Arguments containing model and training configurations.
client_id (int): Unique identifier for the client.
data (object): The dataset assigned to this client.
data_dir (str): Directory containing the data.
message_pool (dict): Pool for managing messages between the server and clients.
device (torch.device): Device on which computations will be performed (e.g., CPU or GPU).
- class openfgl.flcore.scaffold.client.ScaffoldClient(args, client_id, data, data_dir, message_pool, device)
Bases: BaseClient
ScaffoldClient implements the client-side logic for the SCAFFOLD algorithm in Federated Learning. SCAFFOLD aims to reduce the variance caused by client drift by introducing control variates (local and global control variables) that adjust the client updates during training.
- Attributes:
local_control (list[torch.Tensor]): A list of tensors representing the local control variates for each parameter in the model.
- __init__(args, client_id, data, data_dir, message_pool, device)
Initializes the ScaffoldClient with the provided arguments, client ID, data, and device.
- Args:
args (Namespace): Arguments containing model and training configurations.
client_id (int): ID of the client.
data (object): Data specific to the client’s task.
data_dir (str): Directory containing the data.
message_pool (dict): Pool for managing messages between the client and the server.
device (torch.device): Device to run the computations on (CPU or GPU).
- execute()
Executes the local training process for the client. It involves updating the local model with the global model parameters and applying the control variates to adjust the gradients before training.
- send_message()
Sends the updated model parameters and local control variates to the server after local training is completed. This information is used by the server to update the global model and control variates for the next round.
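The two SCAFFOLD updates can be sketched on flat parameter lists, following Karimireddy et al.: during local training each gradient g is corrected to g - c_i + c, and after K local steps with learning rate lr the local control variate is refreshed (the paper's "option II") as c_i+ = c_i - c + (x - y_i) / (K * lr), where x is the global model and y_i the trained local model. The helper names are hypothetical, and this reference does not state which refresh option OpenFGL uses.

```python
def corrected_gradient(grad, local_control, global_control):
    """SCAFFOLD gradient correction: g - c_i + c."""
    return [g - ci + c for g, ci, c in zip(grad, local_control, global_control)]


def refresh_local_control(local_control, global_control, global_weight, local_weight,
                          num_steps, lr):
    """Option II update: c_i+ = c_i - c + (x - y_i) / (K * lr)."""
    scale = 1.0 / (num_steps * lr)
    return [
        ci - c + (x - y) * scale
        for ci, c, x, y in zip(local_control, global_control, global_weight, local_weight)
    ]


print(corrected_gradient([1.0], [0.2], [0.5]))  # approximately [1.3]
new_c = refresh_local_control([0.2], [0.5], [1.0], [0.8], num_steps=2, lr=0.1)
print(new_c)  # approximately [0.7]
```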