I'm trying to generalize reliability diagrams [1] to a multiclass classifier and implement them using PyTorch and torchmetrics.
So far so good, but I'm somewhat confused about the definition of accuracy and how it applies across multiple classes and multiple confidence bins. Let me illustrate with an example:
Say I have 3 classes: A, B and C (C is there just to make sure the problem is not binary). For the sake of the example, my model always outputs confidence 0 for class C.
Let's say I see 20 samples where:
- 10 samples are classified as A with confidence 0.8 and their ground truth is actually A
- 9 samples are classified as B with confidence 0.8 and their ground truth is actually B
- 1 sample is classified as B with confidence 0.8 but its ground truth is A
Let's say I'm drawing a multiclass reliability diagram with 2 bins: [0.0, 0.5) and [0.5, 1.0]. My current code would output this:
- Class A: [0, 1]
- Class B: [0, 9/10]
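To make those numbers reproducible without torch, here is a minimal sketch in plain Python of the counting scheme used in my metric further down: every sample adds one count per class, in the bin of that class's probability, and a success is recorded only for the predicted class when the prediction is right. The tuple layout and the helper `bin_index` are assumptions made for this illustration only.

```python
# Toy data from the example: (probabilities for A, B, C), true label.
samples = (
    [((0.8, 0.2, 0.0), "A")] * 10   # predicted A, truth A
    + [((0.2, 0.8, 0.0), "B")] * 9  # predicted B, truth B
    + [((0.2, 0.8, 0.0), "A")] * 1  # predicted B, truth A
)
classes = ("A", "B", "C")
n_bins = 2  # [0.0, 0.5) and [0.5, 1.0]


def bin_index(p: float) -> int:
    """Map a probability in [0, 1] to a bin; 1.0 falls in the last bin."""
    return min(int(p * n_bins), n_bins - 1)


frequency = {k: [0] * n_bins for k in classes}
success = {k: [0] * n_bins for k in classes}
for probs, truth in samples:
    confidence = max(probs)
    predicted = classes[probs.index(confidence)]
    # Each sample is counted once per class, in the bin of that class's probability.
    for k, p_k in zip(classes, probs):
        frequency[k][bin_index(p_k)] += 1
    # A success is credited only to the predicted class, in the bin of its confidence.
    if predicted == truth:
        success[predicted][bin_index(confidence)] += 1

class_accuracy = {
    k: [s / f if f else 0.0 for s, f in zip(success[k], frequency[k])]
    for k in classes
}
print(class_accuracy["A"])  # [0.0, 1.0]
print(class_accuracy["B"])  # [0.0, 0.9]
```

Class C ends up with all 20 samples in the first bin and no successes, so its row is all zeros.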
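This seems correct to me, as Guo et al. define the accuracy of a confidence bin $B_m$ as:

$$\operatorname{acc}(B_m) = \frac{1}{|B_m|} \sum_{i \in B_m} \mathbf{1}(\hat{y}_i = y_i)$$

where $\hat{y}_i$ and $y_i$ are the predicted and true class labels of sample $i$.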
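As a sanity check on that definition, here is a minimal sketch in plain Python of the top-label version (one confidence per sample, the maximum class probability; the bin accuracy is the fraction of correct predictions among the samples in the bin), applied to the 20 samples above. The data layout and the helper `bin_index` are assumptions for illustration.

```python
# Toy data from the example: (probabilities for A, B, C), true label.
samples = (
    [((0.8, 0.2, 0.0), "A")] * 10   # predicted A, truth A
    + [((0.2, 0.8, 0.0), "B")] * 9  # predicted B, truth B
    + [((0.2, 0.8, 0.0), "A")] * 1  # predicted B, truth A
)
classes = ("A", "B", "C")
n_bins = 2  # [0.0, 0.5) and [0.5, 1.0]


def bin_index(confidence: float) -> int:
    """Map a confidence in [0, 1] to a bin; 1.0 falls in the last bin."""
    return min(int(confidence * n_bins), n_bins - 1)


count = [0] * n_bins    # number of samples in each bin
correct = [0] * n_bins  # number of correct top-label predictions in each bin
for probs, truth in samples:
    confidence = max(probs)
    predicted = classes[probs.index(confidence)]
    m = bin_index(confidence)
    count[m] += 1
    correct[m] += int(predicted == truth)

# Empty bins are reported as 0.0 instead of dividing by zero.
accuracy = [c / n if n else 0.0 for c, n in zip(correct, count)]
print(count)     # [0, 20]
print(accuracy)  # [0.0, 0.95]
```

Without the per-class split, the single misclassified sample shows up directly as 19/20 in the upper bin.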
But it bothers me to ascribe an accuracy of 1 to class A knowing it had a false negative. I've tried to rationalize this by telling myself that every false negative is already accounted for as a false positive on another class, but I'm not sure this makes sense. It also checks out that this false negative would need to be counted in the 0.0-0.5 bin of class A, and it technically is: it contributes exactly as a false negative would there, by incrementing the sample count for that bin but not the number of successes.
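To see where the false negative ends up, the sketch below (plain Python) looks only at class A's probability for each sample and counts, per bin, how many samples fall there and how many of them truly belong to A. The last line computes a one-vs-rest value per bin, the fraction of samples in the bin whose true label is A. That variant is purely an assumption added here for contrast with my scheme; I am not claiming it is the definition from Guo et al. or what torchmetrics does.

```python
# Toy data from the example: (probabilities for A, B, C), true label.
samples = (
    [((0.8, 0.2, 0.0), "A")] * 10   # predicted A, truth A
    + [((0.2, 0.8, 0.0), "B")] * 9  # predicted B, truth B
    + [((0.2, 0.8, 0.0), "A")] * 1  # predicted B, truth A (false negative for A)
)
n_bins = 2  # [0.0, 0.5) and [0.5, 1.0]


def bin_index(p: float) -> int:
    """Map a probability in [0, 1] to a bin; 1.0 falls in the last bin."""
    return min(int(p * n_bins), n_bins - 1)


in_bin = [0] * n_bins     # samples whose probability for A falls in the bin
truly_a = [0] * n_bins    # of those, how many have ground truth A
for probs, truth in samples:
    m = bin_index(probs[0])  # index 0 is class A
    in_bin[m] += 1
    truly_a[m] += int(truth == "A")

print(in_bin)   # [10, 10]
print(truly_a)  # [1, 10]

# Hypothetical one-vs-rest value per bin: fraction of samples that truly are A.
one_vs_rest_a = [t / n if n else 0.0 for t, n in zip(truly_a, in_bin)]
print(one_vs_rest_a)  # [0.1, 1.0]
```

In my scheme the lower bin of class A has 10 samples and 0 successes, so the false negative only enlarges the denominator; in the one-vs-rest variant the same sample makes that bin read 0.1 instead of 0.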
So my questions are:
- Is this the correct way to calculate a reliability diagram for each class when there are many different classes?
- Is there any literature on doing this?
- If there isn't, is this a reasonable way to do it?
This is my current code in case you're interested:
    from typing import Any, List, Optional, Tuple

    import torch
    from torch import Tensor
    from torchmetrics.metric import Metric


    class MulticlassReliabilityDiagram(Metric):
        r"""Compute the reliability diagram for classification tasks.

        A reliability diagram depicts accuracy on its Y-axis over confidence
        intervals split on several bins on its X-axis. Reliability diagrams are
        useful for visualizing a classifier's calibration error across all
        confidence bins by capturing fine calibration information in an easily
        interpretable plot.

        As input to ``forward`` and ``update`` the metric accepts the following input:

        - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)``.
          Preds should be a tensor containing probabilities or logits for each
          observation. If preds has values outside [0,1] range we consider the
          input to be logits and will auto apply softmax per sample.
        - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``.
          Target should be a tensor containing ground truth labels, and therefore
          only contain values in the [0, n_classes-1] range (except if
          `ignore_index` is specified).

        .. note::
           Additional dimension ``...`` will be flattened into the batch dimension.

        As output to ``forward`` and ``compute`` the metric returns:

        - ``reliability`` (:class:`~torch.Tensor`): a tensor containing the
          Reliability histogram
        - ``frequency`` (:class:`~torch.Tensor`): a tensor containing the frequency
          observed for each individual confidence interval
        - ``class_reliability`` (:class:`~torch.Tensor`): a tensor containing the
          Reliability histogram per class, ignored index does *not* affect this
          output in order to keep the same numbering for the classes, you will
          need to manually ignore the index when using this tensor
        - ``class_frequency`` (:class:`~torch.Tensor`): a tensor containing the
          frequency observed for each individual confidence interval per class

        Args:
            num_classes: Integer specifying the number of classes
            bins: Integer specifying number of bins in which to partition the
                confidence domain, default: 10
            kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for
                more info.
        """

        is_differentiable: bool = False
        higher_is_better: Optional[bool] = None
        full_state_update: bool = False

        preds: List[Tensor]
        target: List[Tensor]

        def __init__(
            self,
            num_classes: int,
            bins: Optional[int] = None,
            ignore_index: Optional[int] = None,
            **kwargs: Any,
        ) -> None:
            super().__init__(**kwargs)
            self.num_classes = num_classes
            self.ignore_index = ignore_index
            self.bins = bins or 10

            shape = (self.bins,)
            # Stores frequency of each bin
            self.add_state(
                "frequency",
                default=torch.zeros(shape, dtype=torch.int64),
                dist_reduce_fx="sum",
            )
            # Stores success frequency of each bin
            self.add_state(
                "success",
                default=torch.zeros(shape, dtype=torch.int64),
                dist_reduce_fx="sum",
            )

            shape_classes = (self.num_classes, self.bins)
            # Stores frequency of each bin per-class
            self.add_state(
                "class_frequency",
                default=torch.zeros(shape_classes, dtype=torch.int64),
                dist_reduce_fx="sum",
            )
            # Stores success frequency of each bin per-class
            self.add_state(
                "class_success",
                default=torch.zeros(shape_classes, dtype=torch.int64),
                dist_reduce_fx="sum",
            )

        def update(self, preds: Tensor, target: Tensor) -> None:
            """Update metric states."""
            preds, target = _multiclass_reliability_diagram_format(
                preds, target, self.num_classes, self.ignore_index
            )
            confidences, classes = torch.max(preds, dim=1)

            # NOTE: No deterministic implementation of histc, this is not very relevant
            # since this is used for generating reliability diagram, but it forces us
            # to turn determinism off for this func, see also next occurrence a few lines
            # further down
            torch.use_deterministic_algorithms(False)

            # Update frequencies and success counts both total and per-class
            freq_histogram = torch.histc(preds, bins=self.bins, min=0, max=1)
            self.frequency = torch.add(self.frequency, freq_histogram)

            succ_idx = classes == target
            succ_confidences = confidences[succ_idx]
            success_histogram = torch.histc(succ_confidences, bins=self.bins, min=0, max=1)
            self.success = torch.add(self.success, success_histogram)

            class_freq_histogram = torch.zeros(
                (self.num_classes, self.bins), device=self.device
            )
            class_succ_histogram = torch.zeros(
                (self.num_classes, self.bins), device=self.device
            )
            for index in range(self.num_classes):
                class_freq_histogram[index] = torch.histc(
                    preds[:, index], bins=self.bins, min=0, max=1
                )
                class_succ_idx = torch.logical_and(succ_idx, classes == index)
                class_succ_confidences = confidences[class_succ_idx]
                class_succ_histogram[index] = torch.histc(
                    class_succ_confidences, bins=self.bins, min=0, max=1
                )
            self.class_frequency = torch.add(self.class_frequency, class_freq_histogram)
            self.class_success = torch.add(self.class_success, class_succ_histogram)

            torch.use_deterministic_algorithms(True)

        def compute(
            self,
        ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
            """Compute metric."""
            accuracy = _safe_divide(self.success, self.frequency)
            frequency = _safe_divide(self.frequency, self.frequency.sum())
            class_accuracy = _safe_divide(self.class_success, self.class_frequency)
            class_frequency = _safe_divide(
                self.class_frequency, self.class_frequency.sum(dim=1, keepdim=True)
            )
            return (accuracy, frequency, class_accuracy, class_frequency)


    def _safe_divide(num: Tensor, denom: Tensor) -> Tensor:
        """Safe division, by preventing division by zero.

        Additionally casts to float if input is not already to secure backwards
        compatibility.
        """
        num = num if num.is_floating_point() else num.float()
        denom = denom if denom.is_floating_point() else denom.float()
        return num / torch.where(denom == 0.0, 1.0, denom)


    def _multiclass_reliability_diagram_format(
        preds: Tensor,
        target: Tensor,
        num_classes: int,
        ignore_index: Optional[int] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Convert all input to the right format.

        - flattens additional dimensions
        - Remove all datapoints that should be ignored
        - Applies softmax if pred tensor not in [0,1] range
        """
        preds = preds.transpose(0, 1).reshape(num_classes, -1).T
        target = target.flatten()

        if ignore_index is not None:
            idx = target != ignore_index
            preds = preds[idx]
            target = target[idx]

        if not torch.all((preds >= 0) * (preds <= 1)):
            preds = preds.softmax(1)

        return preds, target

I've also written some tests to check that this is right:
    import math

    import torch

    from ..src.reliability_diagram import MulticlassReliabilityDiagram


    def array_is_close(array1, array2, index) -> bool:
        return math.isclose(array1[index], array2[index])


    def test_all_right():
        # Batch size: 2
        # Number of classes: 3
        # Width, height: 2, 2
        target = torch.tensor([[[0, 1], [1, 2]], [[2, 1], [0, 0]]], dtype=torch.float64)
        prediction = torch.tensor(
            [
                # batch 0
                [
                    # class 0
                    [[0.6, 0.2], [0.2, 0.3]],
                    # class 1
                    [[0.1, 0.7], [0.7, 0.2]],
                    # class 2
                    [[0.3, 0.1], [0.1, 0.5]],
                ],
                # batch 1
                [
                    # class 0
                    [[0.3, 0.2], [0.6, 0.6]],
                    # class 1
                    [[0.2, 0.7], [0.1, 0.1]],
                    # class 2
                    [[0.5, 0.1], [0.3, 0.3]],
                ],
            ],
            dtype=torch.float64,
        )
        # bins = 4: 0, 0.25, 0.50, 0.75
        expected_reliability = [0, 0, 1, 0]
        expected_frequency = [11 / 24, 5 / 24, 8 / 24, 0]
        expected_class_reliability = [
            [0, 0, 1, 0],
            [0, 0, 1, 0],
            [0, 0, 1, 0],
        ]
        expected_class_frequency = [
            [3 / 8, 2 / 8, 3 / 8, 0],
            [5 / 8, 0 / 8, 3 / 8, 0],
            [3 / 8, 3 / 8, 2 / 8, 0],
        ]
        num_classes = 3
        num_bins = 4

        metric = MulticlassReliabilityDiagram(num_classes, num_bins)
        metric.update(prediction, target)
        reliability, frequency, class_reliability, class_frequency = metric.compute()

        for index in range(len(frequency)):
            assert array_is_close(frequency.tolist(), expected_frequency, index)
        for index in range(len(reliability)):
            assert array_is_close(reliability.tolist(), expected_reliability, index)
        for class_number in range(num_classes):
            for index in range(len(frequency)):
                assert array_is_close(
                    class_frequency[class_number].tolist(),
                    expected_class_frequency[class_number],
                    index,
                )
            for index in range(len(reliability)):
                assert array_is_close(
                    class_reliability[class_number].tolist(),
                    expected_class_reliability[class_number],
                    index,
                )


    def test_all_wrong():
        # Batch size: 2
        # Number of classes: 3
        # Width, height: 2, 2
        target = torch.tensor([[[0, 1], [1, 2]], [[2, 1], [0, 0]]], dtype=torch.float64)
        prediction = torch.tensor(
            [
                # batch 0
                [
                    # class 0
                    [[0.1, 0.2], [0.2, 0.5]],
                    # class 1
                    [[0.6, 0.1], [0.1, 0.2]],
                    # class 2
                    [[0.3, 0.7], [0.7, 0.3]],
                ],
                # batch 1
                [
                    # class 0
                    [[0.5, 0.2], [0.1, 0.1]],
                    # class 1
                    [[0.2, 0.1], [0.6, 0.6]],
                    # class 2
                    [[0.3, 0.7], [0.3, 0.3]],
                ],
            ],
            dtype=torch.float64,
        )
        # bins = 4: 0, 0.25, 0.50, 0.75
        expected_reliability = [0, 0, 0, 0]
        expected_frequency = [11 / 24, 5 / 24, 8 / 24, 0]
        expected_class_reliability = [
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
        ]
        expected_class_frequency = [
            [6 / 8, 0, 2 / 8, 0],
            [5 / 8, 0, 3 / 8, 0],
            [0, 5 / 8, 3 / 8, 0],
        ]
        num_classes = 3
        num_bins = 4

        metric = MulticlassReliabilityDiagram(num_classes, num_bins)
        metric.update(prediction, target)
        reliability, frequency, class_reliability, class_frequency = metric.compute()

        for index in range(len(frequency)):
            assert array_is_close(frequency.tolist(), expected_frequency, index)
        for index in range(len(reliability)):
            assert array_is_close(reliability.tolist(), expected_reliability, index)
        for class_number in range(num_classes):
            for index in range(len(frequency)):
                assert array_is_close(
                    class_frequency[class_number].tolist(),
                    expected_class_frequency[class_number],
                    index,
                )
            for index in range(len(reliability)):
                assert array_is_close(
                    class_reliability[class_number].tolist(),
                    expected_class_reliability[class_number],
                    index,
                )


    def test_mixed():
        # Batch size: 2
        # Number of classes: 3
        # Width, height: 2, 2
        target = torch.tensor([[[0, 1], [1, 2]], [[2, 1], [0, 0]]], dtype=torch.float64)
        prediction = torch.tensor(
            [
                # batch 0
                [
                    # class 0
                    [[0.6, 0.2], [0.2, 0.3]],
                    # class 1
                    [[0.1, 0.7], [0.7, 0.2]],
                    # class 2
                    [[0.3, 0.1], [0.1, 0.5]],
                ],
                # batch 1
                [
                    # class 0
                    [[0.5, 0.2], [0.6, 0.6]],
                    # class 1
                    [[0.2, 0.7], [0.1, 0.1]],
                    # class 2
                    [[0.3, 0.1], [0.3, 0.3]],
                ],
            ],
            dtype=torch.float64,
        )
        # bins = 4: 0, 0.25, 0.50, 0.75
        expected_reliability = [0, 0, 7 / 8, 0]
        expected_frequency = [11 / 24, 5 / 24, 8 / 24, 0]
        expected_class_reliability = [
            [0, 0, 3 / 4, 0],
            [0, 0, 1, 0],
            [0, 0, 1, 0],
        ]
        expected_class_frequency = [
            [3 / 8, 1 / 8, 4 / 8, 0],
            [5 / 8, 0 / 8, 3 / 8, 0],
            [3 / 8, 4 / 8, 1 / 8, 0],
        ]
        num_classes = 3
        num_bins = 4

        metric = MulticlassReliabilityDiagram(num_classes, num_bins)
        metric.update(prediction, target)
        reliability, frequency, class_reliability, class_frequency = metric.compute()

        for index in range(len(frequency)):
            assert array_is_close(frequency.tolist(), expected_frequency, index)
        for index in range(len(reliability)):
            assert array_is_close(reliability.tolist(), expected_reliability, index)
        for class_number in range(num_classes):
            for index in range(len(frequency)):
                assert array_is_close(
                    class_frequency[class_number].tolist(),
                    expected_class_frequency[class_number],
                    index,
                )
            for index in range(len(reliability)):
                assert array_is_close(
                    class_reliability[class_number].tolist(),
                    expected_class_reliability[class_number],
                    index,
                )


    def test_mixed_10_bins():
        # Batch size: 2
        # Number of classes: 3
        # Width, height: 2, 2
        target = torch.tensor([[[0, 1], [1, 2]], [[2, 1], [0, 0]]], dtype=torch.float64)
        prediction = torch.tensor(
            [
                # batch 0
                [
                    # class 0
                    [[0.1, 0.2], [0.2, 0.5]],
                    # class 1
                    [[0.6, 0.1], [0.1, 0.2]],
                    # class 2
                    [[0.3, 0.7], [0.7, 0.3]],
                ],
                # batch 1
                [
                    # class 0
                    [[0.3, 0.2], [0.1, 0.1]],
                    # class 1
                    [[0.2, 0.1], [0.6, 0.6]],
                    # class 2
                    [[0.5, 0.7], [0.3, 0.3]],
                ],
            ],
            dtype=torch.float64,
        )
        # bins = 10
        expected_reliability = [0, 0, 0, 0, 0, 1 / 2, 0, 0, 0, 0]
        expected_frequency = [
            0,
            6 / 24,
            5 / 24,
            5 / 24,
            0,
            2 / 24,
            3 / 24,
            3 / 24,
            0,
            0,
        ]
        num_classes = 3
        num_bins = 10

        metric = MulticlassReliabilityDiagram(num_classes, num_bins)
        metric.update(prediction, target)
        reliability, frequency, _class_reliability, _class_frequency = metric.compute()

        for index in range(len(frequency)):
            assert array_is_close(frequency.tolist(), expected_frequency, index)
        for index in range(len(reliability)):
            assert array_is_close(reliability.tolist(), expected_reliability, index)