INSightR-Net: Interpretable Neural Network for Regression using Similarity-based Comparisons to Prototypical Examples
Linde S Hesse1,2
Ana IL Namburete1,3

1 Oxford Machine Learning in NeuroImaging (OMNI) laboratory, Department of Computer Science, University of Oxford
2 Institute of Biomedical Engineering, Department of Engineering Science
3 Wellcome Centre for Integrative Neuroscience, Nuffield Department of Clinical Neuroscience, University of Oxford

[preprint]
[code]


Abstract

Convolutional neural networks (CNNs) have shown exceptional performance for a range of medical imaging tasks. However, conventional CNNs cannot explain their reasoning process, which limits their adoption in clinical practice. In this work, we propose an inherently interpretable CNN for regression using similarity-based comparisons (INSightR-Net) and demonstrate our method on the task of diabetic retinopathy grading. A prototype layer incorporated into the architecture enables visualization of the areas in the image that are most similar to learned prototypes. The final prediction is then intuitively modeled as a mean of prototype labels, weighted by the similarities. INSightR-Net achieved prediction performance competitive with a ResNet baseline, showing that it is not necessary to compromise performance for interpretability. Furthermore, we quantified the quality of our explanations using sparsity and diversity, two concepts considered important for a good explanation, and demonstrated the effect of several parameters on the latent space embeddings.


Methods

We included a prototype layer in our network architecture that computes the similarities of an image's latent representation to a set of learned prototypes. Each prototype is assigned a label at the start of training, within the range of the dataset labels. The model's prediction for a new image is then computed from these similarities as a weighted mean of the prototype labels, where each prototype's weight combines its similarity (s) with a prototype importance score (r).
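A minimal sketch of this prediction rule is shown below, assuming the weight of prototype j is the product s_j · r_j, so that the prediction is ŷ = Σ_j s_j r_j y_j / Σ_j s_j r_j with y_j the prototype labels. It takes the similarities and importance scores as already computed and non-negative; the function name `predict` and the example values are hypothetical and not taken from the INSightR-Net code.

```python
from typing import Sequence


def predict(similarities: Sequence[float],
            importances: Sequence[float],
            prototype_labels: Sequence[float]) -> float:
    """Weighted mean of prototype labels with (assumed) weights w_j = s_j * r_j.

    Assumes one non-negative similarity s_j and importance r_j per prototype.
    """
    if not (len(similarities) == len(importances) == len(prototype_labels)):
        raise ValueError("expected one similarity, importance and label per prototype")
    weights = [s * r for s, r in zip(similarities, importances)]
    total = sum(weights)
    if total <= 0:
        raise ValueError("the weights must have a positive sum")
    return sum(w * y for w, y in zip(weights, prototype_labels)) / total


# Hypothetical example: three prototypes labelled with grades 0, 2 and 4
s = [0.5, 0.25, 0.25]  # similarity of the image to each prototype
r = [1.0, 1.0, 2.0]    # importance score of each prototype
print(predict(s, r, [0.0, 2.0, 4.0]))  # prints 2.0
```

Because all weights are non-negative, the prediction always lies within the range spanned by the prototype labels, so spreading the prototype labels over the dataset label range lets the model reach any grade in that range.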




Results

Example prediction

Expressing the prediction for an image as a weighted mean of prototype labels provides an intuitive explanation of the model's reasoning process.
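The sketch below illustrates how such an explanation can be read off a single prediction: normalizing the weights gives each prototype's share of the prediction, and ranking by share lists the most contributing prototypes. It assumes weights of the form s_j · r_j; the function `explain` and all numbers are hypothetical.

```python
from typing import List, Sequence, Tuple


def explain(similarities: Sequence[float],
            importances: Sequence[float],
            prototype_labels: Sequence[float],
            top_k: int = 3) -> Tuple[float, List[Tuple[int, float, float]]]:
    """Return the prediction and the top_k prototypes as (index, share, label)."""
    weights = [s * r for s, r in zip(similarities, importances)]
    total = sum(weights)
    shares = [w / total for w in weights]  # normalized weights, sum to 1
    prediction = sum(a * y for a, y in zip(shares, prototype_labels))
    # Rank prototypes by how much weight they carry in this prediction
    ranked = sorted(range(len(shares)), key=lambda j: shares[j], reverse=True)
    top = [(j, shares[j], prototype_labels[j]) for j in ranked[:top_k]]
    return prediction, top


# Hypothetical similarities, importances and labels for four prototypes
pred, top = explain([0.8, 0.6, 0.3, 0.1], [1.0, 0.5, 1.0, 1.0], [1.0, 3.0, 2.0, 4.0])
print(f"predicted grade: {pred:.2f}")
for j, share, label in top:
    print(f"prototype {j} (label {label}): {share:.0%} of the weight")
```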


Quantitative results

INSightR-Net maintains baseline prediction performance, indicating that it is not necessary to sacrifice performance for interpretability. By replacing the Log-Similarity activation with a new similarity function, our method also obtains a lower sparsity score, meaning that fewer prototypes contribute to each prediction, while maintaining model diversity.
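The exact sparsity measure is not defined in this section. The sketch below assumes one plausible definition, the smallest number of prototypes whose normalized weights cover a fixed fraction (here a hypothetical 80%) of the total, so that lower values indicate sparser explanations. The function name and threshold are illustrative assumptions.

```python
from typing import Sequence


def sparsity(weights: Sequence[float], fraction: float = 0.8) -> int:
    """Hypothetical sparsity score: the smallest number of prototypes whose
    normalized weights add up to at least `fraction` (lower = sparser)."""
    total = sum(weights)
    cumulative = 0.0
    for count, w in enumerate(sorted(weights, reverse=True), start=1):
        cumulative += w / total
        if cumulative >= fraction - 1e-12:  # small tolerance for rounding
            return count
    return len(weights)


print(sparsity([0.7, 0.2, 0.05, 0.05]))  # prints 2: two prototypes carry 90%
print(sparsity([1.0] * 5))               # prints 4: weight is spread evenly
```

Under this definition, a similarity function that makes the weights more peaked around the closest prototypes directly lowers the score.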


Latent Representations

The latent representations clearly show the effect of the applied cluster loss and of replacing the minimum patch distance with the k-minimum. The cluster loss encourages latent patches to cluster around prototypes, so that each prototype represents a representative concept.
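The section names the cluster loss and the k-minimum without giving formulas. The sketch below is therefore a simplified illustration under stated assumptions: squared Euclidean distance between latent patches and prototypes, a k-minimum that averages the k smallest patch distances (k = 1 gives the plain minimum), and a label-agnostic, ProtoPNet-style cluster loss that pulls each sample's patches towards its nearest prototype. All names are hypothetical.

```python
from typing import List, Sequence

Vector = Sequence[float]


def squared_distance(a: Vector, b: Vector) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def k_min_distance(patches: List[Vector], prototype: Vector, k: int) -> float:
    """Mean of the k smallest patch-to-prototype distances (assumed k-minimum).

    With k = 1 this reduces to the plain minimum over patches.
    """
    dists = sorted(squared_distance(z, prototype) for z in patches)
    k = min(k, len(dists))  # guard against k larger than the number of patches
    return sum(dists[:k]) / k


def cluster_loss(batch: List[List[Vector]], prototypes: List[Vector], k: int) -> float:
    """Simplified cluster loss: mean over samples of the k-min distance to the
    nearest prototype, which pulls latent patches towards some prototype."""
    per_sample = [min(k_min_distance(patches, p, k) for p in prototypes)
                  for patches in batch]
    return sum(per_sample) / len(per_sample)


# Hypothetical 2D latent patches of one image and two prototypes
patches = [(0.0, 0.0), (1.0, 0.0), (5.0, 5.0)]
prototypes = [(0.0, 0.0), (5.0, 5.0)]
print(k_min_distance(patches, prototypes[0], k=1))  # prints 0.0
print(k_min_distance(patches, prototypes[0], k=2))  # prints 0.5
print(cluster_loss([patches], prototypes, k=2))     # prints 0.5
```

With k > 1, a single patch that happens to match a prototype is no longer enough to make the distance small, which is one way to see why the k-minimum changes the latent embeddings.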

Top row: 2D PCA of latent space embeddings (Z) of the test set samples (stars) and learned prototypes (circles). For each sample, only the 5 embeddings closest to a prototype are shown. Bottom row: Probability density histograms of the occurrence of a certain prototype in the top-5 most contributing prototypes of a test set sample.
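The bottom-row histograms can be sketched as follows: for every test sample, take the k prototypes with the largest weights and count how often each prototype appears, normalized over all top-k slots. This assumes per-sample prototype weights are available; the function name is hypothetical, and reading a flat histogram as diverse prototype use is an interpretation, not a definition from the source.

```python
from collections import Counter
from typing import List, Sequence


def top_k_occurrence(weight_rows: Sequence[Sequence[float]], k: int = 5) -> List[float]:
    """Fraction of all top-k slots (over the test set) taken by each prototype."""
    n_prototypes = len(weight_rows[0])
    counts: Counter = Counter()
    for weights in weight_rows:
        top = sorted(range(n_prototypes), key=lambda j: weights[j], reverse=True)[:k]
        counts.update(top)
    total = sum(counts.values())
    return [counts[j] / total for j in range(n_prototypes)]


# Hypothetical weights of 4 prototypes for 2 test samples, with top-2 instead of top-5
hist = top_k_occurrence([[3.0, 2.0, 1.0, 0.0], [3.0, 2.0, 0.0, 1.0]], k=2)
print(hist)                      # prints [0.5, 0.5, 0.0, 0.0]
print(sum(p > 0 for p in hist))  # prints 2: only two prototypes are ever used
```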


Acknowledgements

LH acknowledges the support of the UK Engineering and Physical Sciences Research Council (EPSRC) Doctoral Training Award. AN is grateful for support from the UK Royal Academy of Engineering under the Engineering for Development Research Fellowships scheme.

This template was originally made by Phillip Isola and Richard Zhang for a colorful ECCV project; the code can be found here.