# Source code for openfgl.task.base

from torch.optim import Adam
    
class BaseTask:
    """
    Base class for defining a task in a federated learning setup.

    Subclasses are expected to implement the hooks that raise
    ``NotImplementedError`` here (``train``, ``evaluate``,
    ``load_train_val_test_split`` and the ``num_samples`` / ``default_*`` /
    ``train_val_test_path`` properties).

    Attributes:
        client_id (int): ID of the client.
        data_dir (str): Directory containing the data.
        args (Namespace): Arguments containing model and training configurations
            (``args.lr`` and ``args.weight_decay`` are read to build the optimizer).
        device (torch.device): Device to run the computations on.
        data (object): Data specific to the task. Only set when ``data`` passed
            to the constructor is not None.
        model (torch.nn.Module): Model to be trained. Only set when ``data`` is
            not None, or after ``load_custom_model`` is called.
        optim (torch.optim.Optimizer): Adam optimizer over ``model``'s parameters.
            Set together with ``model``.
        override_evaluate (function): Custom evaluation function, if provided.
            Always initialised to None.
        step_preprocess (function): Custom preprocessing step, if provided.
            Always initialised to None.
    """

    def __init__(self, args, client_id, data, data_dir, device):
        """
        Initialize the BaseTask with provided arguments, data, and device.

        Args:
            args (Namespace): Arguments containing model and training configurations.
            client_id (int): ID of the client.
            data (object): Data specific to the task, or None. When None, no
                data, split, model or optimizer is set up.
            data_dir (str): Directory containing the data.
            device (torch.device): Device to run the computations on.
        """
        self.client_id = client_id
        self.data_dir = data_dir
        self.args = args
        self.device = device

        if data is not None:
            self.data = data
            # NOTE(review): presumably clears a cached per-item list (e.g. a
            # PyG InMemoryDataset ``_data_list`` cache) before moving the data
            # to the device — confirm against the data types used.
            if hasattr(self.data, "_data_list"):
                self.data._data_list = None
            self.data = self.data.to(device)
            # The split and the default model may depend on ``self.data``,
            # so they are only built once the data is available.
            self.load_train_val_test_split()
            self.model = self.default_model.to(device)
            self.optim = self._build_optim()

        # Hooks are initialised unconditionally so callers can always compare
        # them against None, even for tasks constructed without data.
        self.override_evaluate = None
        self.step_preprocess = None

    def _build_optim(self):
        """Create an Adam optimizer over ``self.model`` using ``args.lr`` and ``args.weight_decay``."""
        return Adam(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)

    def train(self):
        """
        Train the model on the provided data.

        This method should be implemented by subclasses.
        """
        raise NotImplementedError

    def evaluate(self):
        """
        Evaluate the model on the provided data.

        This method should be implemented by subclasses.
        """
        raise NotImplementedError

    @property
    def num_samples(self):
        """
        Get the number of samples in the dataset.

        This method should be implemented by subclasses.

        Returns:
            int: Number of samples.
        """
        raise NotImplementedError

    @property
    def default_model(self):
        """
        Get the default model for the task.

        This method should be implemented by subclasses.

        Returns:
            torch.nn.Module: Default model.
        """
        raise NotImplementedError

    @property
    def default_optim(self):
        """
        Get the default optimizer for the task.

        This method should be implemented by subclasses.

        Returns:
            torch.optim.Optimizer: Default optimizer.
        """
        raise NotImplementedError

    @property
    def default_loss_fn(self):
        """
        Get the default loss function for the task.

        This method should be implemented by subclasses.

        Returns:
            function: Default loss function.
        """
        raise NotImplementedError

    @property
    def train_val_test_path(self):
        """
        Get the path to the train/validation/test split file.

        This method should be implemented by subclasses.

        Returns:
            str: Path to the split file.
        """
        raise NotImplementedError

    @property
    def default_train_val_test_split(self):
        """
        Get the default train/validation/test split.

        This method should be implemented by subclasses.

        Returns:
            dict: Default train/validation/test split.
        """
        raise NotImplementedError

    def load_train_val_test_split(self):
        """
        Load the train/validation/test split from a file.

        This method should be implemented by subclasses.
        """
        raise NotImplementedError

    def load_custom_model(self, custom_model):
        """
        Load a custom model for the task and reinitialize the optimizer.

        Args:
            custom_model (torch.nn.Module): Custom model to be used. It is
                moved to ``self.device`` and a fresh Adam optimizer is built
                over its parameters.
        """
        self.model = custom_model.to(self.device)
        self.optim = self._build_optim()