Coverage for bob/med/tb/utils/grad_cams.py: 26%


66 statements  

#!/usr/bin/env python
# coding: utf-8
#
# Author: Kazuto Nakashima
# URL: http://kazuto1011.github.io
# Created: 2017-05-26

# "from collections import Sequence" is deprecated and fails on Python 3.10+;
# collections.abc is the correct home.
from collections.abc import Sequence

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from tqdm import tqdm


class BaseWrapper(object):
    def __init__(self, model):
        super(BaseWrapper, self).__init__()
        self.device = next(model.parameters()).device
        self.model_with_norm = model
        self.model = model.model
        self.handlers = []  # a list of hook handles, kept so they can be removed

    def _encode_one_hot(self, ids):
        # Build a tensor shaped like the logits with ones at the requested
        # class indices; used to seed class-specific backpropagation.
        one_hot = torch.zeros_like(self.logits).to(self.device)
        one_hot.scatter_(1, ids, 1.0)
        return one_hot

    def forward(self, image):
        self.image_shape = image.shape[2:]
        self.logits = self.model_with_norm(image)
        self.probs = torch.sigmoid(self.logits)
        return self.probs.sort(dim=1, descending=True)  # ordered results

    def backward(self, ids):
        """
        Class-specific backpropagation (see the sketch after this class)
        """
        one_hot = self._encode_one_hot(ids)
        self.model_with_norm.zero_grad()
        self.logits.backward(gradient=one_hot, retain_graph=True)

    def generate(self):
        raise NotImplementedError

    def remove_hook(self):
        """
        Remove all the forward/backward hook functions
        """
        for handle in self.handlers:
            handle.remove()
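
Seeding logits.backward(gradient=one_hot) is what makes BaseWrapper.backward
class-specific: only the gradient of the selected logit flows back through
the network. A minimal self-contained sketch of the trick (the toy linear
map and input below are hypothetical, purely for illustration):

    import torch

    x = torch.randn(1, 4, requires_grad=True)  # toy input
    w = torch.randn(4, 3)                      # toy 3-class linear map
    logits = x @ w                             # shape (1, 3)

    one_hot = torch.zeros_like(logits)
    one_hot[0, 2] = 1.0                        # select class 2
    logits.backward(gradient=one_hot)

    # x.grad is now d(logits[0, 2])/dx, i.e. the third column of w.
    assert torch.allclose(x.grad, w[:, 2].unsqueeze(0))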

class GradCAM(BaseWrapper):
    """
    "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization"
    https://arxiv.org/pdf/1610.02391.pdf
    Look at Figure 2 on page 4
    """

    def __init__(self, model, candidate_layers=None):
        super(GradCAM, self).__init__(model)
        self.fmap_pool = {}
        self.grad_pool = {}
        self.candidate_layers = candidate_layers  # list of layer names

        # Each closure captures the layer name so the hook stores feature
        # maps / gradients under the right key.
        def save_fmaps(key):
            def forward_hook(module, input, output):
                self.fmap_pool[key] = output.detach()

            return forward_hook

        def save_grads(key):
            def backward_hook(module, grad_in, grad_out):
                self.grad_pool[key] = grad_out[0].detach()

            return backward_hook

        # If no candidate layers are specified, hooks are registered on all
        # layers (see the usage sketch after this listing). Note that
        # register_backward_hook is deprecated in recent PyTorch releases in
        # favor of register_full_backward_hook.
        for name, module in self.model.named_modules():
            if self.candidate_layers is None or name in self.candidate_layers:
                self.handlers.append(module.register_forward_hook(save_fmaps(name)))
                self.handlers.append(module.register_backward_hook(save_grads(name)))

    def _find(self, pool, target_layer):
        if target_layer in pool.keys():
            return pool[target_layer]
        else:
            raise ValueError("Invalid layer name: {}".format(target_layer))

    def generate(self, target_layer):
        fmaps = self._find(self.fmap_pool, target_layer)
        grads = self._find(self.grad_pool, target_layer)
        # Channel weights: global average pooling of the gradients (the
        # alpha coefficients in the Grad-CAM paper).
        weights = F.adaptive_avg_pool2d(grads, 1)

        # Weighted sum of the feature maps, rectified, then upsampled to
        # the input resolution.
        gcam = torch.mul(fmaps, weights).sum(dim=1, keepdim=True)
        gcam = F.relu(gcam)
        gcam = F.interpolate(
            gcam, self.image_shape, mode="bilinear", align_corners=False
        )

        # Min-max normalize each map in the batch to [0, 1].
        B, C, H, W = gcam.shape
        gcam = gcam.view(B, -1)
        gcam -= gcam.min(dim=1, keepdim=True)[0]
        gcam /= gcam.max(dim=1, keepdim=True)[0]
        gcam = gcam.view(B, C, H, W)

        return gcam
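
The generate method implements the two equations from the Grad-CAM paper
cited above: the channel weights are the global-average-pooled gradients of
the class score with respect to the feature maps, and the map is the ReLU of
the weighted sum of those feature maps, here followed by bilinear upsampling
and per-image min-max normalization. A usage sketch follows; the wrapper
class, layer name, and input are hypothetical stand-ins, assumed only to
expose the .model attribute that BaseWrapper.__init__ expects:

    import torch
    import torch.nn as nn

    class Wrapper(nn.Module):
        # Minimal stand-in for whatever wrapper the project pairs with this
        # module; it only needs a .model attribute with named submodules.
        def __init__(self):
            super().__init__()
            self.model = nn.Sequential(
                nn.Conv2d(3, 8, 3, padding=1),  # hooked layer, named "0"
                nn.ReLU(),
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
                nn.Linear(8, 2),                # two toy classes
            )

        def forward(self, x):
            return self.model(x)

    net = Wrapper()
    cam = GradCAM(net, candidate_layers=["0"])
    images = torch.randn(1, 3, 224, 224)        # toy batch

    sorted_probs, sorted_ids = cam.forward(images)
    cam.backward(ids=sorted_ids[:, [0]])        # backprop only the top class
    maps = cam.generate(target_layer="0")       # (1, 1, 224, 224) in [0, 1]
    cam.remove_hook()                           # detach hooks when done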