import torch
import torch.nn as nn
from openfgl.flcore.base import BaseClient
from openfgl.flcore.moon.moon_config import config
class MoonClient(BaseClient):
    """
    MoonClient implements the client-side logic for Model-contrastive Federated Learning (MOON).

    On top of the task loss, local training minimizes a model-contrastive loss. For each sample,
    the cosine similarity between the current embedding and the global model's embedding is
    treated as the positive logit. The similarity to the previous local model's embedding is
    treated as the negative logit. The class extends BaseClient and manages local training,
    contrastive loss computation, and embedding updates.

    Attributes:
        prev_local_embedding (torch.Tensor or None): Embedding produced by this client's model
            right after its most recent local training; None until execute() has completed once.
        global_embedding (torch.Tensor or None): Embedding produced by the freshly received
            global model at the start of the current execute(); None until execute() runs.
    """

    def __init__(self, args, client_id, data, data_dir, message_pool, device):
        """
        Initialize the MoonClient with the provided arguments, data, and device.

        Args:
            args (Namespace): Arguments containing model and training configurations.
            client_id (int): The unique ID of the client.
            data (object): The data associated with this client.
            data_dir (str): Directory containing the data.
            message_pool (dict): Pool for exchanging messages between clients and the server.
            device (torch.device): The device (CPU or GPU) used for computations.
        """
        super(MoonClient, self).__init__(args, client_id, data, data_dir, message_pool, device)
        self.prev_local_embedding = None
        self.global_embedding = None

    def get_custom_loss_fn(self):
        """
        Build the MOON training loss.

        The returned function computes the task loss on the masked samples. When both reference
        embeddings are available, it adds ``config["moon_mu"]`` times a temperature-scaled
        contrastive cross-entropy term. That term treats similarity to the global model's
        embedding as the positive class and similarity to the previous local embedding as the
        negative class.

        Returns:
            function: ``custom_loss_fn(embedding, logits, label, mask) -> torch.Tensor``.
        """
        def custom_loss_fn(embedding, logits, label, mask):
            task_loss = self.task.default_loss_fn(logits[mask], label[mask])

            # Fall back to the plain task loss when the contrastive term cannot or should not
            # be formed:
            # - round 0 (the original note: "first round eval on global");
            # - the label count differs from this client's num_samples (presumably the loss is
            #   being computed on data other than the local training set — TODO confirm);
            # - a reference embedding is still missing, e.g. when this client's first
            #   execute() happens after round 0, so prev_local_embedding is still None.
            if (
                self.message_pool["round"] == 0
                or self.task.num_samples != label.shape[0]
                or self.global_embedding is None
                or self.prev_local_embedding is None
            ):
                return task_loss

            sim_global = torch.cosine_similarity(embedding, self.global_embedding, dim=-1).view(-1, 1)
            sim_prev = torch.cosine_similarity(embedding, self.prev_local_embedding, dim=-1).view(-1, 1)
            # Column 0 (similarity to the global model) is the positive pair; column 1 the negative.
            contrastive_logits = torch.cat((sim_global, sim_prev), dim=1) / config["temperature"]
            targets = torch.zeros(embedding.size(0), dtype=torch.long, device=self.device)
            contrastive_loss = nn.CrossEntropyLoss()(contrastive_logits, targets)
            return task_loss + config["moon_mu"] * contrastive_loss

        return custom_loss_fn

    def execute(self):
        """
        Run one round of local training.

        The method copies the server's weights into the local model and records the embedding
        that this global model produces. It then installs the MOON loss and trains. Finally, it
        records the trained local model's embedding, which serves as the negative reference
        the next time the MOON loss is used.
        """
        with torch.no_grad():
            for local_param, global_param in zip(self.task.model.parameters(), self.message_pool["server"]["weight"]):
                local_param.data.copy_(global_param)

        self.global_embedding = self.task.evaluate(mute=True)["embedding"].detach()
        self.task.loss_fn = self.get_custom_loss_fn()
        self.task.train()
        self.prev_local_embedding = self.task.evaluate(mute=True)["embedding"].detach()

    def send_message(self):
        """
        Publish this client's update to the message pool.

        Writes the sample count and the list of local model parameters under the key
        ``client_<id>``, where the server can read them for aggregation.
        """
        self.message_pool[f"client_{self.client_id}"] = {
            "num_samples": self.task.num_samples,
            "weight": list(self.task.model.parameters()),
        }