Client

class fl_sim.nodes.Client(client_id: int, device: device, model: Module, dataset: FedDataset, config: ClientConfig)

Bases: Node

Class that simulates a client node.

The client node is responsible for training its local model and for communicating with the server node.
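To make the two responsibilities above concrete, here is a minimal, stdlib-only sketch of a client round: receive the global parameters, train locally, and send the update back. It is not the fl_sim API; the class `ToyClient`, its methods `receive`, `train`, and `send`, and the one-parameter model y ≈ w * x are all hypothetical choices for illustration.

```python
from typing import Dict, List, Tuple


class ToyClient:
    """Hypothetical stand-in for a client node (not the fl_sim API).

    Fits y ≈ w * x on its private data with full-batch gradient descent
    and exchanges only the parameter w with the server.
    """

    def __init__(self, client_id: int, data: List[Tuple[float, float]], lr: float = 0.1) -> None:
        self.client_id = client_id
        self.data = data  # local (x, y) pairs; never sent to the server
        self.lr = lr
        self.w = 0.0

    def receive(self, global_params: Dict[str, float]) -> None:
        # Communication: overwrite the local model with the server's model.
        self.w = global_params["w"]

    def train(self, num_epochs: int = 1) -> None:
        # Local training: gradient descent on the mean squared error.
        n = len(self.data)
        for _ in range(num_epochs):
            grad = sum(2 * (self.w * x - y) * x for x, y in self.data) / n
            self.w -= self.lr * grad

    def send(self) -> Dict[str, float]:
        # Communication: upload the updated parameter and the local sample count.
        return {"w": self.w, "num_samples": float(len(self.data))}


# One simulated round with data drawn from y = 2x.
client = ToyClient(0, [(1.0, 2.0), (2.0, 4.0)])
client.receive({"w": 0.0})
client.train(num_epochs=50)
update = client.send()
```

Only `update` leaves the client, which mirrors the federated setting in which raw data stays local and the server aggregates parameters.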

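Parameters:

client_id (int) – The ID of the client.
device (device) – The device on which the local model is trained.
model (Module) – The local model to train.
dataset (FedDataset) – The federated dataset that provides the client's local data.
config (ClientConfig) – The configuration of the client.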

evaluate(part: str) → Dict[str, float]

Evaluate the model on the given part of the dataset.

Parameters:

part (str) – The part of the dataset to evaluate on, can be either “train” or “val”.

Returns:

A dictionary mapping each metric name to its value.

Return type:

Dict[str, float]
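As an illustration of the contract above (a part name in, a metric dictionary out), here is a hypothetical stdlib-only sketch. The function signature, the use of plain lists in place of tensors, and the metric names "acc" and "num_samples" are assumptions; the real method may report different metrics.

```python
from typing import Callable, Dict, List, Tuple

# (features, labels) of one split; plain lists stand in for tensors.
Split = Tuple[List[float], List[int]]


def evaluate(part: str, splits: Dict[str, Split], predict: Callable[[float], int]) -> Dict[str, float]:
    """Hypothetical sketch: score `predict` on the "train" or "val" split."""
    if part not in ("train", "val"):
        raise ValueError(f"part must be 'train' or 'val', got {part!r}")
    features, labels = splits[part]
    correct = sum(int(predict(x) == y) for x, y in zip(features, labels))
    # Metric names ("acc", "num_samples") are illustrative assumptions.
    return {"acc": correct / len(labels), "num_samples": float(len(labels))}


def threshold(x: float) -> int:
    # A toy classifier: predict 1 when the feature exceeds 0.5.
    return int(x > 0.5)


splits = {
    "train": ([0.2, 0.8, 0.6, 0.1], [0, 1, 1, 0]),
    "val": ([0.4, 0.9], [1, 1]),
}
print(evaluate("val", splits, threshold))  # {'acc': 0.5, 'num_samples': 2.0}
```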

extra_repr_keys() → List[str]

Extra keys to include in the output of __repr__() and __str__().
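One plausible way such keys are consumed is sketched below, assuming (not confirmed by this page) that the base class looks up each key as an attribute and formats it into the representation. `ReprMixin` and `DemoClient` are hypothetical names. Subclasses extend the parent's list via `super()`, so keys accumulate along the class hierarchy.

```python
from typing import List


class ReprMixin:
    """Hypothetical sketch: build __repr__ from the keys in extra_repr_keys()."""

    def extra_repr_keys(self) -> List[str]:
        return []

    def __repr__(self) -> str:
        fields = ", ".join(f"{key}={getattr(self, key)!r}" for key in self.extra_repr_keys())
        return f"{self.__class__.__name__}({fields})"

    # No __str__ is needed: object.__str__ falls back to __repr__.


class DemoClient(ReprMixin):
    def __init__(self, client_id: int, num_epochs: int) -> None:
        self.client_id = client_id
        self.num_epochs = num_epochs

    def extra_repr_keys(self) -> List[str]:
        # Extend, rather than replace, the parent's keys.
        return super().extra_repr_keys() + ["client_id", "num_epochs"]


print(DemoClient(3, 5))  # DemoClient(client_id=3, num_epochs=5)
```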

get_all_data() → Tuple[Tensor, Tensor]

Get all of the client's data.

This helper provides fast access to all of the data on the client: both the training and the validation data, and both the features and the labels.
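The key point of combining the two splits is that features and labels must be concatenated in the same order, so that row i of the features still matches label i. A stdlib-only sketch using lists in place of tensors follows; the function signature is hypothetical. With tensors, the analogous operation would presumably be `torch.cat([train_x, val_x], dim=0)` applied to features and labels separately.

```python
from typing import List, Tuple

# (features, labels) of one split; lists stand in for tensors.
Split = Tuple[List[List[float]], List[int]]


def get_all_data(train: Split, val: Split) -> Tuple[List[List[float]], List[int]]:
    """Hypothetical sketch: concatenate the training and validation splits.

    Features and labels are concatenated in the same (train, then val) order,
    so the pairing between row i of the features and label i is preserved.
    """
    train_x, train_y = train
    val_x, val_y = val
    return train_x + val_x, train_y + val_y


features, labels = get_all_data(([[1.0], [2.0]], [0, 1]), ([[3.0]], [1]))
```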

property is_convergent: bool

Whether the training process has converged.
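This page does not state how convergence is judged. As one common criterion, shown here only as an assumption, training can be treated as convergent once the relative change in loss has stayed below a tolerance for several consecutive epochs. The function name and the `window` and `rel_tol` parameters are hypothetical.

```python
from typing import Sequence


def is_convergent(losses: Sequence[float], window: int = 3, rel_tol: float = 1e-3) -> bool:
    """Hypothetical criterion: the last `window` relative loss changes are all <= rel_tol."""
    if len(losses) < window + 1:
        return False  # not enough history to decide
    recent = list(losses[-(window + 1):])
    for prev, curr in zip(recent, recent[1:]):
        # Guard against division-like blow-up when the loss is (near) zero.
        if abs(prev - curr) > rel_tol * max(abs(prev), 1e-12):
            return False
    return True
```

A windowed test is less sensitive to a single flat epoch than comparing only the last two losses.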

sample_data() → Tuple[Tensor, Tensor]

Sample data for training.
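The sampling strategy is not specified here. A minimal sketch, assuming a mini-batch drawn uniformly at random without replacement, is shown below; the function signature and the explicit `random.Random` argument (used for reproducibility) are hypothetical. Sampling indices once and indexing both lists keeps each feature paired with its label.

```python
import random
from typing import List, Sequence, Tuple


def sample_data(
    features: Sequence[float],
    labels: Sequence[int],
    batch_size: int,
    rng: random.Random,
) -> Tuple[List[float], List[int]]:
    """Hypothetical sketch: draw a mini-batch uniformly, without replacement."""
    # Cap the batch size so a small local dataset is returned in full.
    k = min(batch_size, len(labels))
    indices = rng.sample(range(len(labels)), k=k)
    return [features[i] for i in indices], [labels[i] for i in indices]


batch_x, batch_y = sample_data([0.0, 1.0, 2.0, 3.0], [0, 1, 2, 3], 2, random.Random(0))
```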

solve_inner() → None

The main part of the inner-loop solver.

Basic example:

self.model.train()  # switch to training mode (e.g. enables dropout)
epoch_losses = []  # mean training loss of each local epoch
for _ in range(self.config.num_epochs):
    batch_losses = []
    for data, target in self.train_loader:
        data, target = data.to(self.device), target.to(self.device)
        self.optimizer.zero_grad()  # clear gradients from the previous step
        output = self.model(data)
        loss = self.criterion(output, target)
        loss.backward()  # back-propagate the loss
        self.optimizer.step()  # update the local model parameters
        batch_losses.append(loss.item())
    epoch_losses.append(sum(batch_losses) / len(batch_losses))
    self.lr_scheduler.step()  # adjust the learning rate once per epoch
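The loop above only runs inside a client that already owns a model, optimizer, criterion, data loader, and scheduler. To show its control flow in isolation, here is a stdlib-only analogue, assuming a one-parameter model y ≈ w * x trained with mean squared error. The function name, its arguments, and the per-epoch multiplicative learning-rate decay are illustrative assumptions, not part of fl_sim.

```python
from typing import List, Sequence, Tuple

Batch = List[Tuple[float, float]]  # (x, y) pairs


def solve_inner(batches: Sequence[Batch], num_epochs: int, lr: float, gamma: float = 0.5) -> Tuple[float, List[float]]:
    """Hypothetical stdlib analogue of the training loop for y ≈ w * x.

    Per batch: compute the MSE loss, then take one gradient step.
    Per epoch: record the mean batch loss, then multiply lr by gamma.
    """
    w = 0.0
    epoch_losses: List[float] = []
    for _ in range(num_epochs):
        batch_losses: List[float] = []
        for batch in batches:
            n = len(batch)
            loss = sum((w * x - y) ** 2 for x, y in batch) / n  # criterion(output, target)
            grad = sum(2 * (w * x - y) * x for x, y in batch) / n  # loss.backward()
            w -= lr * grad  # optimizer.step()
            batch_losses.append(loss)  # loss.item(), recorded before the update
        epoch_losses.append(sum(batch_losses) / len(batch_losses))
        lr *= gamma  # lr_scheduler.step()
    return w, epoch_losses


w, losses = solve_inner([[(1.0, 2.0)], [(2.0, 4.0)]], num_epochs=3, lr=0.1)
```

As in the original loop, each recorded batch loss is the value computed before that batch's parameter update.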

abstract train() → None

The main part of the inner-loop solver.

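Because train() is abstract, a concrete client subclass has to override it. Assuming the base class uses `abc.ABCMeta` (as `abc.ABC` does), Python then refuses to instantiate any class that leaves it unimplemented. The sketch below demonstrates this mechanism with hypothetical classes `Node`, `BaseClient`, and `LoggingClient`; they are not the fl_sim classes.

```python
from abc import ABC, abstractmethod
from typing import List


class Node(ABC):
    """Hypothetical minimal base class (not the fl_sim Node)."""


class BaseClient(Node):
    def __init__(self) -> None:
        self.history: List[str] = []

    @abstractmethod
    def train(self) -> None:
        """Local training; every concrete client must override this."""
        raise NotImplementedError


class LoggingClient(BaseClient):
    def train(self) -> None:
        # A trivial override; a real client would run its local training loop here.
        self.history.append("trained one round")


try:
    BaseClient()  # raises TypeError: train() is still abstract
except TypeError as err:
    print("cannot instantiate:", err)

client = LoggingClient()
client.train()
```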