CLFMixin

class fl_sim.models.CLFMixin

Bases: object

Mixin class that adds the prediction methods predict and predict_proba to classifier models.
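This page does not show how CLFMixin is wired into a model. The sketch below illustrates the general mixin pattern under two assumptions: the host model supplies a `forward` method that returns per-class scores, and all class names (`ClassifierMixinSketch`, `LinearModel`, `LinearClassifier`) are hypothetical. Under those assumptions, a stateless mixin deriving only from `object` can add a prediction helper to any class it is combined with.

```python
import numpy as np


class ClassifierMixinSketch:
    """Hypothetical mixin (not fl_sim's code): it has no __init__ and no state.

    Assumption: the host class defines ``forward`` returning per-class scores.
    """

    def predict(self, input: np.ndarray) -> list:
        # Treat a single 1-D sample as a batch of one.
        scores = self.forward(np.atleast_2d(input))
        # The highest score wins; softmax preserves the arg-max, so raw scores suffice.
        return scores.argmax(axis=-1).tolist()


class LinearModel:
    """Hypothetical host model: a fixed linear layer."""

    def __init__(self, weight: np.ndarray) -> None:
        self.weight = weight

    def forward(self, x: np.ndarray) -> np.ndarray:
        return x @ self.weight


class LinearClassifier(ClassifierMixinSketch, LinearModel):
    """Gets ``predict`` from the mixin and ``__init__``/``forward`` from the host."""
```

For example, `LinearClassifier(np.eye(3)).predict(np.array([[0.1, 2.0, 0.3], [5.0, 0.0, 0.0]]))` returns `[1, 0]`. Because the mixin defines no `__init__`, construction falls through the method resolution order to the host class.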

predict(input: Tensor | ndarray, thr: float | None = None, class_map: Dict[int, str] | None = None, batched: bool = False) → list

Predict the class labels.

Parameters:
  • input (torch.Tensor or numpy.ndarray) – The input data.

  • thr (float, optional) – The threshold for multi-label classification. Use None (the default) for single-label classification.

  • class_map (dict, optional) – The mapping from class index to class name.

  • batched (bool, default False) – Whether the input is batched.

Returns:

labels – The predicted class labels.

Return type:

list
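The page does not say how labels are derived from class probabilities. The sketch below rests on three assumptions, none taken from fl_sim: single-label prediction picks the arg-max class; multi-label prediction keeps every class whose probability is at least `thr` (the `>=` comparison is itself assumed); and `class_map` translates indices to names. The helper name `labels_from_proba` is hypothetical.

```python
from typing import Dict, Optional

import numpy as np


def labels_from_proba(
    proba: np.ndarray,
    thr: Optional[float] = None,
    class_map: Optional[Dict[int, str]] = None,
) -> list:
    """Hypothetical helper: turn a (n_samples, n_classes) probability array into labels.

    A 1-D array is treated as a single sample.
    """
    proba = np.atleast_2d(proba)
    if thr is None:
        # Single-label (assumed): the most probable class for each sample.
        labels = [int(i) for i in proba.argmax(axis=-1)]
        if class_map is not None:
            labels = [class_map[i] for i in labels]
    else:
        # Multi-label (assumed): every class whose probability reaches thr.
        labels = [[int(i) for i in np.flatnonzero(row >= thr)] for row in proba]
        if class_map is not None:
            labels = [[class_map[i] for i in row] for row in labels]
    return labels
```

With `proba = np.array([[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]])`, the call `labels_from_proba(proba)` gives `[1, 0]`. Passing `thr=0.25` gives `[[1], [0, 1]]`, a list of label lists per sample.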

predict_proba(input: Tensor | ndarray, multi_label: bool = False, batched: bool = False) → ndarray

Predict probabilities for each class.

Parameters:
  • input (torch.Tensor or numpy.ndarray) – The input data.

  • multi_label (bool, default False) – Whether the model is a multi-label classifier.

  • batched (bool, default False) – Whether the input is batched.

Returns:

proba – The predicted probabilities.

Return type:

numpy.ndarray
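The page also does not say how the probabilities are computed. A common convention, assumed here rather than taken from fl_sim, applies a softmax over the class scores for single-label models and an element-wise sigmoid for multi-label models. `proba_from_logits` is a hypothetical NumPy-only helper. It also mimics the `batched` flag by adding a batch axis for a single sample and removing it afterwards.

```python
import numpy as np


def proba_from_logits(
    logits: np.ndarray, multi_label: bool = False, batched: bool = False
) -> np.ndarray:
    """Hypothetical helper: map raw class scores (logits) to probabilities."""
    logits = np.asarray(logits, dtype=float)
    if not batched:
        # Assumption: an unbatched input is one sample; add a leading batch axis.
        logits = logits[np.newaxis, ...]
    if multi_label:
        # Independent per-class probabilities via the logistic sigmoid.
        proba = 1.0 / (1.0 + np.exp(-logits))
    else:
        # Mutually exclusive classes: numerically stable softmax over the last axis.
        exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
        proba = exp / exp.sum(axis=-1, keepdims=True)
    if not batched:
        # Remove the batch axis added above.
        proba = proba[0]
    return proba
```

The per-class sigmoid outputs need not sum to one. A multi-label probability array therefore pairs with a threshold such as `thr` in predict rather than with an arg-max.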