
I'm trying to generalize reliability diagrams [1] to a multiclass classifier and implement that using PyTorch and TorchMetrics.

So far so good, but I'm somewhat confused about the definition of accuracy and how it applies at the intersection of multiple classes and multiple confidence bins. Let me illustrate with an example:

Say I have 3 classes A, B, and C (C is there just to make sure the problem is not a binary classifier). For the sake of the example, my model always outputs confidence 0 for class C.

Let's say I see 20 samples where:

  • 10 samples are classified as A with confidence 0.8 and their ground truth is actually A
  • 9 samples are classified as B with confidence 0.8 and their ground truth is actually B
  • 1 sample is classified as B with confidence 0.8 but its ground truth is A

Let's say I'm drawing a multiclass reliability diagram with 2 bins: [0.0, 0.5) and [0.5, 1.0]. My current code would output this: Class A: [0, 1], Class B: [0, 9/10].
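To make those numbers concrete, here is a minimal hand-count sketch (my own plain-Python illustration, independent of the torchmetrics-based code further down) of how I'm currently binning: each sample contributes only to the class it was predicted as, using its top confidence:

# Toy data from above: (predicted_class, top_confidence, true_class)
samples = [("A", 0.8, "A")] * 10 + [("B", 0.8, "B")] * 9 + [("B", 0.8, "A")]

# Two confidence bins: [0.0, 0.5) and [0.5, 1.0]
bins = [(0.0, 0.5), (0.5, 1.0 + 1e-9)]

for cls in ("A", "B"):
    per_bin_accuracy = []
    for lo, hi in bins:
        # Only samples *predicted* as `cls` whose top confidence falls in this bin
        in_bin = [s for s in samples if s[0] == cls and lo <= s[1] < hi]
        correct = sum(1 for pred, _, true in in_bin if pred == true)
        per_bin_accuracy.append(correct / len(in_bin) if in_bin else 0.0)
    print(cls, per_bin_accuracy)

# A [0.0, 1.0]   <- all 10 samples predicted as A are truly A
# B [0.0, 0.9]   <- 9 of the 10 samples predicted as B are truly B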

This seems correct to me, since the definition of per-bin accuracy from Guo et al. ("the accuracy of $B_m$") is

$$\mathrm{acc}(B_m) = \frac{1}{|B_m|} \sum_{i \in B_m} \mathbf{1}(\hat{y}_i = y_i)$$
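For comparison, this is the same toy data run through the class-agnostic version of that formula (again just a hand-count sketch, not the metric code below):

# acc(B_m) = (1 / |B_m|) * sum over i in B_m of 1[y_hat_i == y_i]
samples = [("A", 0.8, "A")] * 10 + [("B", 0.8, "B")] * 9 + [("B", 0.8, "A")]

n_bins = 2
bin_total = [0] * n_bins
bin_correct = [0] * n_bins
for pred, conf, true in samples:
    m = min(int(conf * n_bins), n_bins - 1)  # index of the bin the confidence falls into
    bin_total[m] += 1
    bin_correct[m] += int(pred == true)

print([c / t if t else 0.0 for c, t in zip(bin_correct, bin_total)])
# [0.0, 0.95]  <- all 20 samples land in [0.5, 1.0]; 19 of them are correct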

But it very much shocks me to ascribe an accuracy of 1 to class A knowing it had a false negative. I've tried to rationalize this away by telling myself that every false negative is already counted as a false positive on some other class, but I'm not sure that makes much sense. It also checks out that this false negative would need to be counted in the [0.0, 0.5) bin... and technically it is, since it contributes there exactly as a false negative should: by incrementing the sample count for that bin without contributing to its reliability.

So my questions are:

  • Is this the correct way to calculate a reliability diagram for each class when there are many different classes?
  • Is there any literature on doing this?
  • If there isn't, is this a reasonable way to do it?

This is my current code in case you're interested:

from typing import Any, List, Optional, Tuple

import torch
from torch import Tensor
from torchmetrics.metric import Metric


class MulticlassReliabilityDiagram(Metric):
    r"""Compute the reliability diagram for classification tasks.

    A reliability diagram depicts accuracy on its Y-axis over confidence intervals
    split on several bins on its X-axis. Reliability diagrams are useful for
    visualizing a classifier's calibration error across all confidence bins by
    capturing fine calibration information in an easily interpretable plot.

    As input to ``forward`` and ``update`` the metric accepts the following input:

    - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)``.
      Preds should be a tensor containing probabilities or logits for each
      observation. If preds has values outside the [0, 1] range we consider the
      input to be logits and will auto apply softmax per sample.
    - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``.
      Target should be a tensor containing ground truth labels, and therefore only
      contain values in the [0, n_classes-1] range (except if `ignore_index` is
      specified).

    .. note::
        Additional dimension ``...`` will be flattened into the batch dimension.

    As output to ``forward`` and ``compute`` the metric returns:

    - ``reliability`` (:class:`~torch.Tensor`): a tensor containing the
      reliability histogram
    - ``frequency`` (:class:`~torch.Tensor`): a tensor containing the frequency
      observed for each individual confidence interval
    - ``class_reliability`` (:class:`~torch.Tensor`): a tensor containing the
      reliability histogram per class; the ignored index does *not* affect this
      output in order to keep the same numbering for the classes, so you will
      need to manually ignore the index when using this tensor
    - ``class_frequency`` (:class:`~torch.Tensor`): a tensor containing the
      frequency observed for each individual confidence interval per class

    Args:
        num_classes: Integer specifying the number of classes
        bins: Integer specifying number of bins in which to partition the
            confidence domain, default: 10
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
    """

    is_differentiable: bool = False
    higher_is_better: Optional[bool] = None
    full_state_update: bool = False
    preds: List[Tensor]
    target: List[Tensor]

    def __init__(
        self,
        num_classes: int,
        bins: Optional[int] = None,
        ignore_index: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.bins = bins or 10

        shape = (self.bins,)
        # Stores frequency of each bin
        self.add_state(
            "frequency",
            default=torch.zeros(shape, dtype=torch.int64),
            dist_reduce_fx="sum",
        )
        # Stores success frequency of each bin
        self.add_state(
            "success",
            default=torch.zeros(shape, dtype=torch.int64),
            dist_reduce_fx="sum",
        )

        shape_classes = (self.num_classes, self.bins)
        # Stores frequency of each bin per-class
        self.add_state(
            "class_frequency",
            default=torch.zeros(shape_classes, dtype=torch.int64),
            dist_reduce_fx="sum",
        )
        # Stores success frequency of each bin per-class
        self.add_state(
            "class_success",
            default=torch.zeros(shape_classes, dtype=torch.int64),
            dist_reduce_fx="sum",
        )

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Update metric states."""
        preds, target = _multiclass_reliability_diagram_format(
            preds, target, self.num_classes, self.ignore_index
        )
        confidences, classes = torch.max(preds, dim=1)

        # NOTE: There is no deterministic implementation of histc. This is not very
        # relevant since this is used for generating reliability diagrams, but it
        # forces us to turn determinism off for this function; see also the next
        # occurrence a few lines further down.
        torch.use_deterministic_algorithms(False)

        # Update frequencies and success counts, both total and per-class
        freq_histogram = torch.histc(preds, bins=self.bins, min=0, max=1)
        self.frequency = torch.add(self.frequency, freq_histogram)

        succ_idx = classes == target
        succ_confidences = confidences[succ_idx]
        success_histogram = torch.histc(succ_confidences, bins=self.bins, min=0, max=1)
        self.success = torch.add(self.success, success_histogram)

        class_freq_histogram = torch.zeros(
            (self.num_classes, self.bins), device=self.device
        )
        class_succ_histogram = torch.zeros(
            (self.num_classes, self.bins), device=self.device
        )
        for index in range(self.num_classes):
            class_freq_histogram[index] = torch.histc(
                preds[:, index], bins=self.bins, min=0, max=1
            )

            class_succ_idx = torch.logical_and(succ_idx, classes == index)
            class_succ_confidences = confidences[class_succ_idx]
            class_succ_histogram[index] = torch.histc(
                class_succ_confidences, bins=self.bins, min=0, max=1
            )
        self.class_frequency = torch.add(self.class_frequency, class_freq_histogram)
        self.class_success = torch.add(self.class_success, class_succ_histogram)

        torch.use_deterministic_algorithms(True)

    def compute(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Compute metric."""
        accuracy = _safe_divide(self.success, self.frequency)
        frequency = _safe_divide(self.frequency, self.frequency.sum())
        class_accuracy = _safe_divide(self.class_success, self.class_frequency)
        class_frequency = _safe_divide(
            self.class_frequency, self.class_frequency.sum(dim=1, keepdim=True)
        )
        return (accuracy, frequency, class_accuracy, class_frequency)


def _safe_divide(num: Tensor, denom: Tensor) -> Tensor:
    """Safe division, by preventing division by zero.

    Additionally casts to float if the input is not already, to secure backwards
    compatibility.
    """
    num = num if num.is_floating_point() else num.float()
    denom = denom if denom.is_floating_point() else denom.float()
    return num / torch.where(denom == 0.0, 1.0, denom)


def _multiclass_reliability_diagram_format(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    ignore_index: Optional[int] = None,
) -> Tuple[Tensor, Tensor]:
    """Convert all input to the right format.

    - flattens additional dimensions
    - removes all datapoints that should be ignored
    - applies softmax if the pred tensor is not in the [0, 1] range
    """
    preds = preds.transpose(0, 1).reshape(num_classes, -1).T
    target = target.flatten()

    if ignore_index is not None:
        idx = target != ignore_index
        preds = preds[idx]
        target = target[idx]

    if not torch.all((preds >= 0) * (preds <= 1)):
        preds = preds.softmax(1)

    return preds, target

I've also written some tests to check that this is right:

import math

import torch

from ..src.reliability_diagram import MulticlassReliabilityDiagram


def array_is_close(array1, array2, index) -> bool:
    return math.isclose(array1[index], array2[index])


def test_all_right():
    # Batch size: 2
    # Number of classes: 3
    # Width, height: 2, 2
    target = torch.tensor([[[0, 1], [1, 2]], [[2, 1], [0, 0]]], dtype=torch.float64)
    prediction = torch.tensor(
        [
            # batch 0
            [
                # class 0
                [[0.6, 0.2], [0.2, 0.3]],
                # class 1
                [[0.1, 0.7], [0.7, 0.2]],
                # class 2
                [[0.3, 0.1], [0.1, 0.5]],
            ],
            # batch 1
            [
                # class 0
                [[0.3, 0.2], [0.6, 0.6]],
                # class 1
                [[0.2, 0.7], [0.1, 0.1]],
                # class 2
                [[0.5, 0.1], [0.3, 0.3]],
            ],
        ],
        dtype=torch.float64,
    )

    # bins = 4 -> edges at 0, 0.25, 0.50, 0.75
    expected_reliability = [0, 0, 1, 0]
    expected_frequency = [11 / 24, 5 / 24, 8 / 24, 0]
    expected_class_reliability = [
        [0, 0, 1, 0],
        [0, 0, 1, 0],
        [0, 0, 1, 0],
    ]
    expected_class_frequency = [
        [3 / 8, 2 / 8, 3 / 8, 0],
        [5 / 8, 0 / 8, 3 / 8, 0],
        [3 / 8, 3 / 8, 2 / 8, 0],
    ]

    num_classes = 3
    num_bins = 4
    metric = MulticlassReliabilityDiagram(num_classes, num_bins)
    metric.update(prediction, target)
    reliability, frequency, class_reliability, class_frequency = metric.compute()

    for index in range(len(frequency)):
        assert array_is_close(frequency.tolist(), expected_frequency, index)
    for index in range(len(reliability)):
        assert array_is_close(reliability.tolist(), expected_reliability, index)
    for class_number in range(num_classes):
        for index in range(len(frequency)):
            assert array_is_close(
                class_frequency[class_number].tolist(),
                expected_class_frequency[class_number],
                index,
            )
        for index in range(len(reliability)):
            assert array_is_close(
                class_reliability[class_number].tolist(),
                expected_class_reliability[class_number],
                index,
            )


def test_all_wrong():
    # Batch size: 2
    # Number of classes: 3
    # Width, height: 2, 2
    target = torch.tensor([[[0, 1], [1, 2]], [[2, 1], [0, 0]]], dtype=torch.float64)
    prediction = torch.tensor(
        [
            # batch 0
            [
                # class 0
                [[0.1, 0.2], [0.2, 0.5]],
                # class 1
                [[0.6, 0.1], [0.1, 0.2]],
                # class 2
                [[0.3, 0.7], [0.7, 0.3]],
            ],
            # batch 1
            [
                # class 0
                [[0.5, 0.2], [0.1, 0.1]],
                # class 1
                [[0.2, 0.1], [0.6, 0.6]],
                # class 2
                [[0.3, 0.7], [0.3, 0.3]],
            ],
        ],
        dtype=torch.float64,
    )

    # bins = 4 -> edges at 0, 0.25, 0.50, 0.75
    expected_reliability = [0, 0, 0, 0]
    expected_frequency = [11 / 24, 5 / 24, 8 / 24, 0]
    expected_class_reliability = [
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
    ]
    expected_class_frequency = [
        [6 / 8, 0, 2 / 8, 0],
        [5 / 8, 0, 3 / 8, 0],
        [0, 5 / 8, 3 / 8, 0],
    ]

    num_classes = 3
    num_bins = 4
    metric = MulticlassReliabilityDiagram(num_classes, num_bins)
    metric.update(prediction, target)
    reliability, frequency, class_reliability, class_frequency = metric.compute()

    for index in range(len(frequency)):
        assert array_is_close(frequency.tolist(), expected_frequency, index)
    for index in range(len(reliability)):
        assert array_is_close(reliability.tolist(), expected_reliability, index)
    for class_number in range(num_classes):
        for index in range(len(frequency)):
            assert array_is_close(
                class_frequency[class_number].tolist(),
                expected_class_frequency[class_number],
                index,
            )
        for index in range(len(reliability)):
            assert array_is_close(
                class_reliability[class_number].tolist(),
                expected_class_reliability[class_number],
                index,
            )


def test_mixed():
    # Batch size: 2
    # Number of classes: 3
    # Width, height: 2, 2
    target = torch.tensor([[[0, 1], [1, 2]], [[2, 1], [0, 0]]], dtype=torch.float64)
    prediction = torch.tensor(
        [
            # batch 0
            [
                # class 0
                [[0.6, 0.2], [0.2, 0.3]],
                # class 1
                [[0.1, 0.7], [0.7, 0.2]],
                # class 2
                [[0.3, 0.1], [0.1, 0.5]],
            ],
            # batch 1
            [
                # class 0
                [[0.5, 0.2], [0.6, 0.6]],
                # class 1
                [[0.2, 0.7], [0.1, 0.1]],
                # class 2
                [[0.3, 0.1], [0.3, 0.3]],
            ],
        ],
        dtype=torch.float64,
    )

    # bins = 4 -> edges at 0, 0.25, 0.50, 0.75
    expected_reliability = [0, 0, 7 / 8, 0]
    expected_frequency = [11 / 24, 5 / 24, 8 / 24, 0]
    expected_class_reliability = [
        [0, 0, 3 / 4, 0],
        [0, 0, 1, 0],
        [0, 0, 1, 0],
    ]
    expected_class_frequency = [
        [3 / 8, 1 / 8, 4 / 8, 0],
        [5 / 8, 0 / 8, 3 / 8, 0],
        [3 / 8, 4 / 8, 1 / 8, 0],
    ]

    num_classes = 3
    num_bins = 4
    metric = MulticlassReliabilityDiagram(num_classes, num_bins)
    metric.update(prediction, target)
    reliability, frequency, class_reliability, class_frequency = metric.compute()

    for index in range(len(frequency)):
        assert array_is_close(frequency.tolist(), expected_frequency, index)
    for index in range(len(reliability)):
        assert array_is_close(reliability.tolist(), expected_reliability, index)
    for class_number in range(num_classes):
        for index in range(len(frequency)):
            assert array_is_close(
                class_frequency[class_number].tolist(),
                expected_class_frequency[class_number],
                index,
            )
        for index in range(len(reliability)):
            assert array_is_close(
                class_reliability[class_number].tolist(),
                expected_class_reliability[class_number],
                index,
            )


def test_mixed_10_bins():
    # Batch size: 2
    # Number of classes: 3
    # Width, height: 2, 2
    target = torch.tensor([[[0, 1], [1, 2]], [[2, 1], [0, 0]]], dtype=torch.float64)
    prediction = torch.tensor(
        [
            # batch 0
            [
                # class 0
                [[0.1, 0.2], [0.2, 0.5]],
                # class 1
                [[0.6, 0.1], [0.1, 0.2]],
                # class 2
                [[0.3, 0.7], [0.7, 0.3]],
            ],
            # batch 1
            [
                # class 0
                [[0.3, 0.2], [0.1, 0.1]],
                # class 1
                [[0.2, 0.1], [0.6, 0.6]],
                # class 2
                [[0.5, 0.7], [0.3, 0.3]],
            ],
        ],
        dtype=torch.float64,
    )

    # bins = 10
    expected_reliability = [0, 0, 0, 0, 0, 1 / 2, 0, 0, 0, 0]
    expected_frequency = [
        0,
        6 / 24,
        5 / 24,
        5 / 24,
        0,
        2 / 24,
        3 / 24,
        3 / 24,
        0,
        0,
    ]

    num_classes = 3
    num_bins = 10
    metric = MulticlassReliabilityDiagram(num_classes, num_bins)
    metric.update(prediction, target)
    reliability, frequency, _class_reliability, _class_frequency = metric.compute()

    for index in range(len(frequency)):
        assert array_is_close(frequency.tolist(), expected_frequency, index)
    for index in range(len(reliability)):
        assert array_is_close(reliability.tolist(), expected_reliability, index)

1 Answer


I think the output for Class A in your case should be [1/10, 1]. That's because you have 9 samples whose predicted probability for A is 0.2 but whose ground truth is actually B, and 1 sample whose predicted probability for A is 0.2 and whose ground truth is indeed A. So for the [0.0, 0.5) bin the value for class A is 1/(1+9) = 1/10.
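Here is a minimal hand-count sketch of that one-vs-rest counting on your toy example (it assumes, as stated in the question, that class C always gets probability 0, so a sample predicted as B with confidence 0.8 has probability 0.2 for A):

# Per-sample probability vectors (p_A, p_B, p_C) and true labels for the 20 toy samples
probs = [(0.8, 0.2, 0.0)] * 10 + [(0.2, 0.8, 0.0)] * 10  # 10 predicted as A, 10 predicted as B
truth = ["A"] * 10 + ["B"] * 9 + ["A"]                    # the last B-predicted sample is truly A

bins = [(0.0, 0.5), (0.5, 1.0 + 1e-9)]
for k, cls in enumerate(("A", "B")):
    reliability = []
    for lo, hi in bins:
        # Every sample contributes to class `cls` through its probability for that class
        in_bin = [t for p, t in zip(probs, truth) if lo <= p[k] < hi]
        hits = sum(1 for t in in_bin if t == cls)
        reliability.append(hits / len(in_bin) if in_bin else 0.0)
    print(cls, reliability)

# A [0.1, 1.0]   <- of the 10 samples with p_A = 0.2, exactly one is truly A
# B [0.0, 0.9]   <- of the 10 samples with p_B = 0.8, nine are truly B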

One link that helped me understand this: https://towardsdatascience.com/introduction-to-reliability-diagrams-for-probability-calibration-ed785b3f5d44

  • Thank you! I have since finished my research and, since there are no more answers, I'm accepting yours until a more comprehensive one appears. Commented Nov 6, 2024 at 17:06
