import torch
import torch.nn as nn
from openfgl.utils.basic_utils import load_task
class BaseClient:
    """
    Base class for a client in a federated learning setup.

    Subclasses provide the client's local work in :meth:`execute` and its
    outgoing message in :meth:`send_message`; both raise
    ``NotImplementedError`` here.

    Attributes:
        args (Namespace): Arguments containing model and training configurations.
        client_id (int): ID of the client.
        message_pool (object): Pool for managing messages between client and server.
        device (torch.device): Device to run the computations on.
        task (object): Task-specific data and functions loaded via the `load_task` utility.
        personalized (bool): Flag to indicate if the client is using a personalized algorithm.
    """

    def __init__(
        self,
        args,
        client_id,
        data,
        data_dir,
        message_pool,
        device,
        personalized=False,
    ):
        """
        Initialize the BaseClient with provided arguments and data.

        Args:
            args (Namespace): Arguments containing model and training configurations.
            client_id (int): ID of the client.
            data (object): Data specific to the client's task.
            data_dir (str): Directory containing the data.
            message_pool (object): Pool for managing messages between client and server.
            device (torch.device): Device to run the computations on.
            personalized (bool, optional): Flag to indicate if the client is using a
                personalized algorithm. Defaults to False.
        """
        self.args = args
        self.client_id = client_id
        self.message_pool = message_pool
        self.device = device
        # `data` and `data_dir` are not kept on the instance; they are only
        # forwarded to `load_task`, which builds this client's task object.
        self.task = load_task(args, client_id, data, data_dir, device)
        self.personalized = personalized

    def execute(self):
        """
        Client local execution. This method should be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def send_message(self):
        """
        Send a message to the server. This method should be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
class BaseServer:
    """
    Base class for a server in a federated learning setup.

    Subclasses provide the server's global work in :meth:`execute` and its
    outgoing messages in :meth:`send_message`; both raise
    ``NotImplementedError`` here.

    Attributes:
        args (Namespace): Arguments containing model and training configurations.
        message_pool (object): Pool for managing messages between client and server.
        device (torch.device): Device to run the computations on.
        task (object): Task-specific data and functions loaded via the `load_task` utility.
        personalized (bool): Flag to indicate if the server is using a personalized algorithm.
    """

    def __init__(
        self,
        args,
        global_data,
        data_dir,
        message_pool,
        device,
        personalized=False,
    ):
        """
        Initialize the BaseServer with provided arguments and data.

        Args:
            args (Namespace): Arguments containing model and training configurations.
            global_data (object): Global data accessible to the server.
            data_dir (str): Directory containing the data.
            message_pool (object): Pool for managing messages between client and server.
            device (torch.device): Device to run the computations on.
            personalized (bool, optional): Flag to indicate if the server is using a
                personalized algorithm. Defaults to False.
        """
        self.args = args
        self.message_pool = message_pool
        self.device = device
        # The server has no client id, so `None` is passed in that slot.
        # NOTE(review): presumably `load_task` treats a `None` client_id as the
        # global/server-side task — confirm against its implementation.
        self.task = load_task(args, None, global_data, data_dir, device)
        self.personalized = personalized

    def execute(self):
        """
        Server global execution. This method should be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def send_message(self):
        """
        Send messages to clients. This method should be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError