Node#

class fl_sim.nodes.Node[source]#

Bases: ReprMixin, ABC

An abstract base class for the server and client nodes.

static aggregate_results_from_json_log(d: dict | str | Path, part: str = 'val', metric: str = 'acc') → ndarray[source]#

Aggregate the federated results from a json/yaml log.

Parameters:
  • d (dict or str or pathlib.Path) – The dict of the json/yaml log, or the path to the json/yaml log file.

  • part (str, default "val") – The part of the log to aggregate.

  • metric (str, default "acc") – The metric to aggregate.

Returns:

The aggregated results (curve).

Return type:

numpy.ndarray

Note

The parameter d should be a dict with a structure similar to the following:

{
    "train": {
        "client0": [
            {
                "epoch": 1,
                "step": 1,
                "time": "2020-01-01 00:00:00",
                "loss": 0.1,
                "acc": 0.2,
                "top3_acc": 0.3,
                "top5_acc": 0.4,
                "num_samples": 100
            }
        ]
    },
    "val": {
        "client0": [
            {
                "epoch": 1,
                "step": 1,
                "time": "2020-01-01 00:00:00",
                "loss": 0.1,
                "acc": 0.2,
                "top3_acc": 0.3,
                "top5_acc": 0.4,
                "num_samples": 100
            }
        ]
    }
}
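
For instance, a minimal usage sketch (the log contents below are illustrative):

from fl_sim.nodes import Node

# an illustrative log dict following the structure above
log = {
    "val": {
        "client0": [
            {"epoch": 1, "step": 1, "time": "2020-01-01 00:00:00",
             "loss": 0.10, "acc": 0.20, "num_samples": 100},
            {"epoch": 2, "step": 2, "time": "2020-01-01 00:10:00",
             "loss": 0.08, "acc": 0.30, "num_samples": 100},
        ],
    },
}

# aggregate the per-client validation accuracy into a single curve
curve = Node.aggregate_results_from_json_log(log, part="val", metric="acc")
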
abstract communicate(target: Node) → None[source]#

Communicate with the target node.

The current node communicates model parameters, gradients, etc. to the target node. For example, a client node communicates its local model parameters to the server node via

target._received_messages.append(
    ClientMessage(
        {
            "client_id": self.client_id,
            "parameters": self.get_detached_model_parameters(),
            "train_samples": self.config.num_epochs * self.config.num_steps * self.config.batch_size,
            "metrics": self._metrics,
        }
    )
)

For a server node, global model parameters are communicated to clients via

target._received_messages = {"parameters": self.get_detached_model_parameters()}

compute_gradients(at: Sequence[Tensor] | None = None, dataloader: DataLoader | None = None) → List[Tensor][source]#

Compute the gradients of the model on the node.

The gradients are computed at the model parameters at (or at the current model parameters if at is None), as the average of the gradients over the mini-batches from dataloader (or self.train_loader if dataloader is None).

Parameters:
  • at (list of torch.Tensor, optional) – The model parameters at which to compute the gradients. If None, the current model parameters are used.

  • dataloader (torch.utils.data.DataLoader, optional) – The dataloader providing the mini-batches for computing the gradients. If None, self.train_loader is used.
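
For example (node and global_params are hypothetical: a concrete Node subclass instance and a list of tensors, respectively):

# gradients at the current model parameters, averaged over self.train_loader
grads = node.compute_gradients()

# gradients evaluated at another parameter point, e.g. received global parameters
grads_at_global = node.compute_gradients(at=global_params)
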

get_detached_model_parameters() → List[Tensor][source]#

Get the detached model parameters.

get_gradients(norm: str | int | float | None = None, model: Module | None = None) → float | List[Tensor][source]#

Get the gradients or norm of the gradients of the model on the node.

Parameters:
  • norm (str or int or float, optional) – The type of norm to compute. If None, the raw gradients (a list of tensors) are returned. Refer to torch.linalg.norm() for more details.

  • model (torch.nn.Module, optional) – The model whose gradients are retrieved; defaults to self.model.

Returns:

The gradients or norm of the gradients.

Return type:

float or List[torch.Tensor]
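
For example (node denotes a hypothetical concrete Node subclass instance):

# the raw gradients, as a list of tensors
raw_grads = node.get_gradients()

# the scalar 2-norm of the gradients of self.model
grad_norm = node.get_gradients(norm=2)
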

static get_norm(tensor: Number | Tensor | ndarray | Parameter | generator | Sequence[ndarray | Tensor | Parameter | generator], norm: str | int | float = 'fro') → float[source]#

Get the norm of a tensor.

Parameters:
  • tensor (numbers.Number, torch.Tensor, numpy.ndarray, torch.nn.Parameter, generator, or a sequence thereof) – The tensor (or collection of tensors) whose norm is computed.

  • norm (str or int or float, default "fro") – The type of norm to compute. Refer to torch.linalg.norm() for more details.

Returns:

The norm of the tensor.

Return type:

float
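
A minimal sketch, assuming the norm semantics mirror torch.linalg.norm():

import torch

from fl_sim.nodes import Node

t = torch.randn(3, 4)
fro_norm = Node.get_norm(t)  # Frobenius norm, the default
one_norm = Node.get_norm(t, norm=1)  # 1-norm
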

abstract property is_convergent: bool#

Whether the training process on the node has converged.

abstract property required_config_fields: List[str]#

The list of required fields in the config.

set_parameters(params: Iterable[Parameter], model: Module | None = None) → None[source]#

Set the parameters of the model on the node.

Parameters:
  • params (Iterable of torch.nn.Parameter) – The parameters to set on the model.

  • model (torch.nn.Module, optional) – The model whose parameters are set; defaults to self.model.

Return type:

None
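
For example, loading a detached copy of another node's parameters into the local model (node and other_node are hypothetical Node subclass instances):

# overwrite the local model parameters with the other node's
node.set_parameters(other_node.get_detached_model_parameters())
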

abstract update() → None[source]#

Update model parameters, gradients, etc. according to self._received_messages.
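
Putting the abstract interface together, a minimal sketch of a concrete client node might look as follows. The import location of ClientMessage, the required config fields, and the convergence test are illustrative assumptions, not the library's actual defaults:

from typing import List

from fl_sim.nodes import ClientMessage, Node  # ClientMessage location assumed


class ToyClient(Node):
    """A toy client node illustrating the abstract interface."""

    def communicate(self, target: Node) -> None:
        # push the local parameters and metrics to the server node
        target._received_messages.append(
            ClientMessage(
                {
                    "client_id": self.client_id,
                    "parameters": self.get_detached_model_parameters(),
                    "train_samples": self.config.num_epochs * self.config.num_steps * self.config.batch_size,
                    "metrics": self._metrics,
                }
            )
        )

    def update(self) -> None:
        # load the global parameters received from the server, then train locally
        self.set_parameters(self._received_messages["parameters"])
        # ... local training loop goes here ...

    @property
    def required_config_fields(self) -> List[str]:
        # illustrative required config fields
        return ["num_epochs", "num_steps", "batch_size"]

    @property
    def is_convergent(self) -> bool:
        # illustrative: a real implementation would test a convergence criterion
        return False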