Pytorch复现导向反向传播Guided Backpropagation
迪丽瓦拉
2025-05-31 23:01:46
0

Pytorch复现导向反向传播Guided Backpropagation

  • 前言
  • 一、导向反向传播Guided Backpropagation的原理
  • 二、导向反向传播Guided Backpropagation的复现
  • 三、导向反向传播Guided Backpropagation的效果
  • 四、参考链接

前言

  笔者在学习Grad-Cam算法对应的论文时,注意到该论文利用导向反向传播Guided Backpropagation来可视化细粒度信息,用Grad-Cam来定位判别性区域,大致如下图所示。因此,便对导向反向传播Guided Backpropagation算法进行了学习与复现。如果你想了解Grad-Cam算法的复现,可参考这篇博客:https://editor.csdn.net/md/?articleId=129650976

在这里插入图片描述

一、导向反向传播Guided Backpropagation的原理

关于导向反向传播Guided Backpropagation的原理可参考下图。
在这里插入图片描述

二、导向反向传播Guided Backpropagation的复现

注意,只需修改下段代码中输入图像的路径及输出图像的路径,然后运行即可。

import os
import cv2
import torch
from torch import nn
from PIL import Image
from torchvision import models
from torchvision import transformsclass Guided_backprop():def __init__(self, model):self.model = modelself.image_reconstruction = Noneself.activation_maps = []self.model.eval()self.register_hooks()def register_hooks(self):def first_layer_hook_fn(module, grad_in, grad_out):# 在全局变量中保存输入图片的梯度,该梯度由第一层卷积层# 反向传播得到,因此该函数需绑定第一个 Conv2d Layerself.image_reconstruction = grad_in[0]def forward_hook_fn(module, input, output):# 在全局变量中保存 ReLU 层的前向传播输出# 用于将来做 guided backpropagationself.activation_maps.append(output)def backward_hook_fn(module, grad_in, grad_out):# ReLU 层反向传播时,用其正向传播的输出作为 guide# 反向传播和正向传播相反,先从后面传起grad = self.activation_maps.pop()# ReLU 正向传播的输出要么大于0,要么等于0,# 大于 0 的部分,梯度为1,# 等于0的部分,梯度还是 0grad[grad > 0] = 1# grad_in[0] 表示 feature 的梯度,只保留大于 0 的部分positive_grad_in = torch.clamp(grad_in[0], min=0.0)# 创建新的输入端梯度new_grad_in = positive_grad_in * grad# ReLU 不含 parameter,输入端梯度是一个只有一个元素的 tuplereturn (new_grad_in,)# 获取 module,这里只针对 alexnet,如果是别的,则需修改modules = list(self.model.features.named_children())# 遍历所有 module,对 ReLU 注册 forward hook 和 backward hookfor name, module in modules:if isinstance(module, nn.ReLU):module.register_forward_hook(forward_hook_fn)module.register_backward_hook(backward_hook_fn)# 对第1层卷积层注册 hookfirst_layer = modules[0][1]first_layer.register_backward_hook(first_layer_hook_fn)def visualize(self, input_image, target_class):# 获取输出,之前注册的 forward hook 开始起作用model_output = self.model(input_image)self.model.zero_grad()pred_class = model_output.argmax().item()# 生成目标类 one-hot 向量,作为反向传播的起点grad_target_map = torch.zeros(model_output.shape,dtype=torch.float)if target_class is not None:grad_target_map[0][target_class] = 1else:grad_target_map[0][pred_class] = 1# 反向传播,之前注册的 backward hook 开始起作用model_output.backward(grad_target_map)# 得到 target class 对输入图片的梯度,转换成图片格式result = self.image_reconstruction.data[0].permute(1, 2, 0)return result.numpy()def normalize(I):# 归一化梯度map,先归一化到 mean=0 std=1norm = (I - I.mean()) / I.std()# 把 std 重置为 0.1,让梯度map中的数值尽可能接近 0norm = norm * 0.1# 均值加 0.5,保证大部分的梯度值为正norm = norm + 0.5# 把 0,1 以外的梯度值分别设置为 0 和 1norm = norm.clip(0, 1)return normif __name__ == '__main__':I = Image.open('./test_0002_aligned.jpg').convert('RGB')transform = transforms.Compose([transforms.Resize((224,224)),transforms.CenterCrop((224,224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])tensor = transform(I).unsqueeze(0).requires_grad_()model = models.alexnet(pretrained=True)guided_bp = Guided_backprop(model)result = guided_bp.visualize(tensor, None)result = normalize(result)result= result[:, :, ::-1]*255cv2.imwrite('./test_aligned_deconv.jpg',result)

三、导向反向传播Guided Backpropagation的效果

下左图为上述代码的输入图像,即原图;下右图为为上述代码的输出图像,即原图经导向反向传播后的图。
在这里插入图片描述

四、参考链接

1.http://ddrv.cn/a/145213

2.https://blog.csdn.net/qq_41647438/article/details/109504316

3.https://zhuanlan.zhihu.com/p/479485138?utm_id=0

4.https://blog.csdn.net/cdknight_happy/article/details/108792065

相关内容

热门资讯

linux入门---制作进度条 了解缓冲区 我们首先来看看下面的操作: 我们首先创建了一个文件并在这个文件里面添加了...
C++ 机房预约系统(六):学... 8、 学生模块 8.1 学生子菜单、登录和注销 实现步骤: 在Student.cpp的...
A.机器学习入门算法(三):基... 机器学习算法(三):K近邻(k-nearest neigh...
数字温湿度传感器DHT11模块... 模块实例https://blog.csdn.net/qq_38393591/article/deta...
有限元三角形单元的等效节点力 文章目录前言一、重新复习一下有限元三角形单元的理论1、三角形单元的形函数(Nÿ...
Redis 所有支持的数据结构... Redis 是一种开源的基于键值对存储的 NoSQL 数据库,支持多种数据结构。以下是...
win下pytorch安装—c... 安装目录一、cuda安装1.1、cuda版本选择1.2、下载安装二、cudnn安装三、pytorch...
MySQL基础-多表查询 文章目录MySQL基础-多表查询一、案例及引入1、基础概念2、笛卡尔积的理解二、多表查询的分类1、等...
keil调试专题篇 调试的前提是需要连接调试器比如STLINK。 然后点击菜单或者快捷图标均可进入调试模式。 如果前面...
MATLAB | 全网最详细网... 一篇超超超长,超超超全面网络图绘制教程,本篇基本能讲清楚所有绘制要点&#...
IHome主页 - 让你的浏览... 随着互联网的发展,人们越来越离不开浏览器了。每天上班、学习、娱乐,浏览器...
TCP 协议 一、TCP 协议概念 TCP即传输控制协议(Transmission Control ...
营业执照的经营范围有哪些 营业执照的经营范围有哪些 经营范围是指企业可以从事的生产经营与服务项目,是进行公司注册...
C++ 可变体(variant... 一、可变体(variant) 基础用法 Union的问题: 无法知道当前使用的类型是什...
血压计语音芯片,电子医疗设备声... 语音电子血压计是带有语音提示功能的电子血压计,测量前至测量结果全程语音播报࿰...
MySQL OCP888题解0... 文章目录1、原题1.1、英文原题1.2、答案2、题目解析2.1、题干解析2.2、选项解析3、知识点3...
【2023-Pytorch-检... (肆十二想说的一些话)Yolo这个系列我们已经更新了大概一年的时间,现在基本的流程也走走通了,包含数...
实战项目:保险行业用户分类 这里写目录标题1、项目介绍1.1 行业背景1.2 数据介绍2、代码实现导入数据探索数据处理列标签名异...
记录--我在前端干工地(thr... 这里给大家分享我在网上总结出来的一些知识,希望对大家有所帮助 前段时间接触了Th...
43 openEuler搭建A... 文章目录43 openEuler搭建Apache服务器-配置文件说明和管理模块43.1 配置文件说明...