XAI

[XAI] ResNet50에 GradCAM 구현하기

아리22 2024. 3. 14. 14:38

사전학습된 ResNet50에 이미지 XAI(explainable AI) 모델인 GradCAM을 적용시켜보겠다. 

import numpy as np 
import tensorflow as tf 
from tensorflow import keras
from tensorflow.keras.applications import resnet50, ResNet50
from tensorflow.keras.preprocessing import image 
import matplotlib.pyplot as plt 
import matplotlib.cm as cm
model = ResNet50(weights = 'imagenet')
model.summary()

ImageNet이라는 대용량 데이터셋으로 CNN계열의 모델인 ResNet50을 사전학습 시킨다. 

image_path='./converse.jpg'
img = image.load_img(image_path, target_size=(224,224))
plt.matshow(img)

컨버스 사진을 소스코드가 있는 폴더에 넣어준다. 

x = image.img_to_array(img)
x = np.expand_dims(x,axis=0)
x = resnet50.preprocess_input(x)

preds=model.predict(x)
print("예측 결과:", resnet50.decode_predictions(preds,top=5)[0])

사전학습시킨 ResNet50으로 해당 사진을 분류하고 top 5의 결과만 print한다. 

결과를 보면 1위가 running_shoe, 2위는 sandal, 3위는 Loafer 등등으로 나온다. 

last_conv_layer=model.get_layer("conv5_block3_out")

model_1 = keras.Model(model.inputs,last_conv_layer.output)

input_2 = keras.Input(shape = last_conv_layer.output.shape[1:])
x_2=model.get_layer("avg_pool")(input_2)
x_2=model.get_layer("predictions")(x_2)
model_2=keras.Model(input_2,x_2)

with tf.GradientTape() as tape:
  output_1=model_1(x)
  tape.watch(output_1)
  preds=model_2(output_1)
  class_id=tf.argmax(preds[0])
  output_2=preds[:,class_id]

grads = tape.gradient(output_2,output_1)
pooled_grads = tf.reduce_mean(grads,axis=(0,1,2))

output_1 = output_1.numpy()[0]
pooled_grads=pooled_grads.numpy()
for i in range(pooled_grads.shape[-1]):
  output_1[:,:,i]*=pooled_grads[i]
heatmap=np.mean(output_1,axis=-1)

heatmap=np.maximum(heatmap,0)/np.max(heatmap)
plt.matshow(heatmap)

여기서부터 본격적인 Grad-CAM 코드이다. 

heatmap을 출력했다. 

img=image.load_img(image_path)

img=image.img_to_array(img)
heatmap=np.uint8(255*heatmap)

jet=cm.get_cmap("jet")
color=jet(np.arange(256))[:,:3]
color_heatmap = color[heatmap]

color_heatmap = keras.preprocessing.image.array_to_img(color_heatmap)
color_heatmap = color_heatmap.resize((img.shape[1],img.shape[0]))
color_heatmap = keras.preprocessing.image.img_to_array(color_heatmap)

overlay_img=color_heatmap*0.7+img
overlay_img=keras.preprocessing.image.array_to_img(overlay_img)
plt.matshow(overlay_img)

이제 원본사진에 heatmap을 덧씌워준다. 

결과

이 코드에서는 jet 색상표를 사용하여 사진은 인식해 분류하는데 중요하지 않은 화소는 파란색, 중요하게 작용한 화소는 빨간색으로 표시된다. 결과 사진을 보면 신발 영역을 빨갛게 표시하고 있어 GradCAM이 ResNet50의 의사결정을 제대로 셜명한다는 것을 확인할 수 있다. 


reference 

파이썬으로 만드는 인공지능 

Chapter 12. 설명 가능 인공지능