fl_sim.optimizers.get_optimizer


fl_sim.optimizers.get_optimizer(optimizer_name: str | type, params: Iterable[dict | Parameter], config: Any) → Optimizer

Get an optimizer by name or class.

Parameters:
  • optimizer_name (Union[str, type]) – Optimizer name or class

  • params (Iterable[Union[dict, torch.nn.parameter.Parameter]]) – Parameters to be optimized

  • config (Any) – Configuration for the optimizer. Should be a dict, or an object (e.g. a class) whose attributes can be read as config.attr.
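The docstring does not specify how config is turned into constructor arguments. Below is a minimal sketch that assumes only the keys or attributes matching the optimizer's keyword arguments are forwarded, so unrelated settings in a shared config are ignored. The helper config_to_kwargs and the DummySGD and Config classes are hypothetical illustrations, not part of fl_sim. DummySGD stands in for torch.optim.SGD so the sketch runs without PyTorch.

```python
import inspect
from collections.abc import Mapping
from typing import Any, Dict


def config_to_kwargs(optimizer_cls: type, config: Any) -> Dict[str, Any]:
    """Hypothetical helper: pick constructor kwargs for optimizer_cls from config.

    config may be a dict or any object (class or instance) readable as config.attr.
    Only names accepted by the optimizer's constructor (other than params) are kept.
    """
    accepted = [
        p.name
        for p in inspect.signature(optimizer_cls).parameters.values()
        if p.name != "params"
        and p.kind in (p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY)
    ]
    if isinstance(config, Mapping):
        # dict-style config: look up keys
        return {k: config[k] for k in accepted if k in config}
    # attribute-style config: read config.attr
    return {k: getattr(config, k) for k in accepted if hasattr(config, k)}


class DummySGD:
    # Hypothetical stand-in for torch.optim.SGD.
    def __init__(self, params, lr=1e-3, momentum=0.0, weight_decay=0.0):
        self.params = list(params)
        self.defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)


class Config:
    # Hypothetical attribute-style config.
    lr = 1e-2
    momentum = 0.9
    num_clients = 10  # not an optimizer argument, so it is ignored


print(config_to_kwargs(DummySGD, {"lr": 1e-2, "batch_size": 32}))  # {'lr': 0.01}
print(config_to_kwargs(DummySGD, Config))  # {'lr': 0.01, 'momentum': 0.9}
```

Filtering by the constructor signature lets one config object carry both training and optimizer settings. Whether fl_sim filters this way or passes everything through is not stated in this page.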

Returns:

Instance of the given optimizer.

Return type:

torch.optim.Optimizer

Examples

import torch
from fl_sim.optimizers import get_optimizer

model = torch.nn.Linear(10, 1)
optimizer = get_optimizer("SGD", model.parameters(), {"lr": 1e-2})  # PyTorch built-in
optimizer = get_optimizer("yogi", model.parameters(), {"lr": 1e-2})  # from pytorch_optimizer
optimizer = get_optimizer("FedPD_SGD", model.parameters(), {"lr": 1e-2})  # federated
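The examples draw optimizers from three sources: PyTorch built-ins, pytorch_optimizer, and federated optimizers. The following is a rough sketch of how a name-or-class argument could be resolved against such sources. The REGISTRIES table, the placeholder classes, and resolve_optimizer_class are hypothetical. The case-insensitive match is also an assumption, suggested only by the lowercase "yogi" in the example.

```python
from typing import Dict, Union


# Placeholder classes so the sketch runs without PyTorch or pytorch_optimizer.
class SGD: ...
class Yogi: ...
class FedPD_SGD: ...


# Hypothetical lookup tables, searched in order; the real sources may differ.
REGISTRIES: Dict[str, Dict[str, type]] = {
    "torch.optim": {"SGD": SGD},
    "pytorch_optimizer": {"Yogi": Yogi},
    "federated": {"FedPD_SGD": FedPD_SGD},
}


def resolve_optimizer_class(optimizer_name: Union[str, type]) -> type:
    """Return an optimizer class given either the class itself or its name."""
    if isinstance(optimizer_name, type):
        return optimizer_name  # already a class: use it as-is
    wanted = optimizer_name.lower()
    for registry in REGISTRIES.values():
        for name, cls in registry.items():
            if name.lower() == wanted:  # assumed case-insensitive match
                return cls
    raise ValueError(f"Unknown optimizer: {optimizer_name!r}")


print(resolve_optimizer_class("yogi").__name__)  # Yogi
```

Searching the registries in a fixed order means a name defined in more than one source resolves to the first match. How get_optimizer itself breaks such ties is not documented here.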