DiffMixin

class fl_sim.models.DiffMixin

Bases: object

Mixin that adds a diff() method for computing the difference between two models.

Examples

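import torch.nn as nn

from fl_sim.models import DiffMixin

class ModelA(nn.Module, DiffMixin):
    def __init__(self, out_dim):
        super().__init__()
        self.fc = nn.Linear(10, out_dim)

model_1 = ModelA(10)
model_2 = ModelA(10)
# Scalar (float) difference between the two models, measured in the 2-norm
model_1.diff(model_2, norm=2)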
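The example above relies on the mixin being combined with an nn.Module. The following is a minimal sketch of how such a mixin could work; it is not fl_sim's actual implementation. The class name SketchDiffMixin is hypothetical. The sketch assumes that parameters are compared pairwise in parameters() order. It also assumes that, when norm is given, all per-parameter differences are flattened into a single vector before a vector norm is taken, so only numeric norm values are handled here.

```python
from typing import List, Optional, Union

import torch
import torch.nn as nn


class SketchDiffMixin:
    """Hypothetical stand-in for DiffMixin; assumes the host class is an nn.Module."""

    def diff(
        self, other: nn.Module, norm: Optional[Union[int, float]] = None
    ) -> Union[float, List[torch.Tensor]]:
        # Pair up parameters in registration order; both models must share one architecture.
        params_self = list(self.parameters())
        params_other = list(other.parameters())
        if len(params_self) != len(params_other):
            raise ValueError("models have different numbers of parameter tensors")
        raw = []
        for p, q in zip(params_self, params_other):
            if p.shape != q.shape:
                raise ValueError(f"shape mismatch: {tuple(p.shape)} vs {tuple(q.shape)}")
            # detach() so the difference is not tracked by autograd
            raw.append(p.detach() - q.detach())
        if norm is None:
            return raw  # raw per-parameter differences
        # Assumption: flatten every difference into one vector, then take its vector norm.
        flat = torch.cat([d.flatten() for d in raw])
        return torch.linalg.vector_norm(flat, ord=norm).item()


class ModelA(nn.Module, SketchDiffMixin):
    def __init__(self, out_dim: int):
        super().__init__()
        self.fc = nn.Linear(10, out_dim)


torch.manual_seed(0)
model_1 = ModelA(10)
model_2 = ModelA(10)
raw = model_1.diff(model_2)           # list of tensors: weight diff (10x10), bias diff (10,)
dist = model_1.diff(model_2, norm=2)  # a single float
```

Flattening before taking the norm treats the whole model as one parameter vector. Summing per-parameter norms would be another reasonable design, and the two generally give different numbers.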

diff(other: object, norm: str | int | float | None = None) → float | List[Tensor]

Compute the difference between this model and another model.

Parameters:
  • other (object) – Another model with the same structure as this one.

  • norm (str or int or float, optional) – The norm used to measure the difference. If None (the default), the raw difference is returned instead of a scalar. See torch.linalg.norm() for the supported values.
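To illustrate what the norm argument controls, the snippet below applies torch.linalg.norm() directly to a small, made-up difference vector. It does not call DiffMixin.diff. How diff() combines the differences of multiple parameter tensors is not specified on this page. The snippet only shows how the choice of ord changes the resulting scalar.

```python
import torch

# A made-up difference between two parameter vectors.
d = torch.tensor([3.0, -4.0, 0.0])

l2 = torch.linalg.norm(d, ord=2).item()               # sqrt(9 + 16) = 5.0
l1 = torch.linalg.norm(d, ord=1).item()               # |3| + |-4| + |0| = 7.0
linf = torch.linalg.norm(d, ord=float("inf")).item()  # max(|3|, |-4|, |0|) = 4.0
```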

Returns:

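diff – The raw difference (a list of tensors) if norm is None; otherwise the difference measured in the given norm (a float).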

Return type:

float or List[torch.Tensor]