4.2 Explainable AI for crack detection
Eric Breitbarth edited this page Apr 7, 2024
We implemented the explainable machine learning method Seg-Grad-CAM, which highlights the regions of the input that the model pays most attention to when making a prediction. Below, we show how to compute the corresponding attention heatmap for a crack detection prediction. The steps follow the example script `scripts/crack_detection/crack_tip_attention.py`.
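For reference, Seg-Grad-CAM adapts Grad-CAM to segmentation networks. Take a chosen layer with feature maps $A^k$ (each with $Z$ pixels indexed by $u,v$) and the network output $y^c$ for class $c$. The gradient of the output summed over a region of interest $\mathcal{M}$ is averaged over each feature map to give one weight per channel. The heatmap is the rectified weighted sum of the feature maps. For crack detection, $c$ would presumably be the crack tip segmentation channel. This is our reading of the method, not a statement from the CrackPy code.

$$
\alpha_k^c = \frac{1}{Z} \sum_{u,v} \frac{\partial \sum_{(i,j) \in \mathcal{M}} y_{ij}^c}{\partial A_{uv}^k},
\qquad
L^c = \mathrm{ReLU}\left( \sum_k \alpha_k^c A^k \right)
$$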
First, let us import the necessary modules and set the paths:
```python
# Imports
import os
import torch
import matplotlib.pyplot as plt
from crackpy.crack_detection.model import get_model
from crackpy.crack_detection.deep_learning import setup, attention
from crackpy.crack_detection.detection import CrackDetection
from crackpy.fracture_analysis.data_processing import InputData
from crackpy.structure_elements.data_files import Nodemap

# Paths
NODEMAP_FILE = 'ATON_Dummy2_WPXXX_DummyVersuch_2_dic_results_1_52.txt'
DATA_PATH = os.path.join('..', '..', 'test_data', 'analysis', 'Nodemaps_new')
OUTPUT_PATH = 'attention'
```
Next, we configure the crack detection and the Seg-Grad-CAM setup:
```python
# Setup
det = CrackDetection(
    side='right',
    detection_window_size=30,
    offset=(5, 0),
    angle_det_radius=10,
    device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
)

# Note: this rebinds the name `setup` from the imported module to a Setup instance
setup = setup.Setup()
setup.side = det.side
setup.set_size(det.detection_window_size, det.offset)
setup.set_output_path(OUTPUT_PATH)
# U-Net layers to visualize
setup.set_visu_layers(['down1', 'down2', 'down3', 'down4', 'base', 'up1', 'up2', 'up3', 'up4'])
```
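Then, we load the trained `ParallelNets` model, copy its weights into a version of the network equipped with hooks, and keep only its `unet` part: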
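```python
# Load the model
model_with_hooks = attention.ParallelNetsWithHooks()
model = get_model('ParallelNets')
# Copy the trained weights into the network with hooks
model_with_hooks.load_state_dict(model.state_dict())
# Keep only the U-Net part of the network
model_with_hooks = model_with_hooks.unet
```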
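By default, PyTorch's `load_state_dict` is strict. It raises an error if the two state dicts do not contain exactly the same keys. The weight transfer above therefore relies on `ParallelNetsWithHooks` using the same parameter names as `ParallelNets`. The pure-Python sketch below mimics that key check. The helper `compare_state_dict_keys` and the parameter names are hypothetical and belong to neither CrackPy nor PyTorch.

```python
from typing import List, Mapping, Tuple


def compare_state_dict_keys(
    target: Mapping[str, object], source: Mapping[str, object]
) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: return (missing, unexpected) keys when loading
    `source` into a model whose parameters are described by `target`."""
    missing = sorted(k for k in target if k not in source)
    unexpected = sorted(k for k in source if k not in target)
    return missing, unexpected


# Hypothetical parameter names; real names depend on the network definition
hooked_params = {"unet.down1.conv.weight": None, "unet.up1.conv.weight": None}
trained_params = {"unet.down1.conv.weight": None, "unet.up1.conv.weight": None}
missing, unexpected = compare_state_dict_keys(hooked_params, trained_params)
print(missing, unexpected)  # [] []
```

With `strict=False`, PyTorch does not raise. Instead, it returns the mismatches in the fields `missing_keys` and `unexpected_keys` so they can be inspected.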
Afterwards, we load and preprocess the nodemap data:
```python
# Get nodemap data
nodemap = Nodemap(name=NODEMAP_FILE, folder=DATA_PATH)
data = InputData(nodemap)

# Interpolate data on arrays (256 x 256 pixels)
interp_disps, _ = det.interpolate(data)

# Preprocess input
input_ch = det.preprocess(interp_disps)
```
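According to the comment above, `det.interpolate` maps the scattered DIC nodemap data onto regular 256 x 256 pixel arrays. The numpy sketch below illustrates that idea with nearest-neighbour sampling. The helper name and the choice of nearest-neighbour interpolation are assumptions. CrackPy's own implementation may use a different interpolation scheme and coordinate window.

```python
import numpy as np


def nearest_grid_interpolation(x, y, values, x_range, y_range, size=256):
    """Hypothetical helper: sample scattered nodal values onto a size x size
    grid using nearest-neighbour interpolation.

    x, y, values: 1D arrays of equal length (node coordinates and, e.g., one
    displacement component). Rows of the result follow y, columns follow x.
    """
    xs = np.linspace(x_range[0], x_range[1], size)
    ys = np.linspace(y_range[0], y_range[1], size)
    grid_x, grid_y = np.meshgrid(xs, ys)  # both of shape (size, size)
    # Squared distance from every grid point to every node: (size, size, n_nodes)
    d2 = (grid_x[..., None] - x) ** 2 + (grid_y[..., None] - y) ** 2
    nearest = np.argmin(d2, axis=-1)  # index of the closest node per pixel
    return values[nearest]


# Tiny example: two nodes on the bottom edge, 4 x 4 output grid
x = np.array([0.0, 1.0])
y = np.array([0.0, 0.0])
u = np.array([10.0, 20.0])
print(nearest_grid_interpolation(x, y, u, (0.0, 1.0), (0.0, 1.0), size=4))
# each row: [10. 10. 20. 20.]
```

The brute-force distance computation needs memory proportional to the number of grid points times the number of nodes. For full-size nodemaps, `scipy.interpolate.griddata(points, values, xi, method='nearest')` (or `method='linear'`) is the more practical choice.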
We compute the attention heatmap with a forward and a backward pass, combining the captured feature maps with their gradients:
```python
# Initialize Seg-Grad-CAM
sgc = attention.SegGradCAM(setup, model_with_hooks)

# Forward and backward pass; the hooks capture the features and gradients
output, heatmap = sgc(input_ch)
```
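The combination step for a single layer can be sketched in a few lines of numpy. First, average the gradients of each feature map over space to get one weight per channel. Then, form the weighted sum of the feature maps and clip negative values to zero. This sketch is our assumption of what happens per layer, not CrackPy's `SegGradCAM` code. In particular, the final scaling to [0, 1] is an illustrative choice for plotting.

```python
import numpy as np


def seg_grad_cam(activations, gradients):
    """Combine one layer's feature maps and gradients into a heatmap.

    activations: array (K, H, W), the feature maps A^k
    gradients:   array (K, H, W), gradient of the region-summed segmentation
                 output with respect to A^k
    """
    # One weight per channel: spatial mean of that channel's gradient
    weights = gradients.mean(axis=(1, 2))  # shape (K,)
    # Weighted sum over channels, then ReLU
    cam = np.tensordot(weights, activations, axes=1)  # shape (H, W)
    cam = np.maximum(cam, 0.0)
    # Scale to [0, 1] for plotting; an all-zero map is returned unchanged
    peak = cam.max()
    return cam / peak if peak > 0 else cam


# Two channels: channel 0 is weighted +1, channel 1 is weighted -1
A = np.zeros((2, 2, 2))
A[0, 0, 0] = 1.0
A[1, 1, 1] = 2.0
G = np.stack([np.ones((2, 2)), -np.ones((2, 2))])
print(seg_grad_cam(A, G))  # only the top-left pixel stays active
```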
Finally, let us plot and save the result:
```python
# Plot and save heatmap
os.makedirs(OUTPUT_PATH, exist_ok=True)  # make sure the output folder exists
fig = sgc.plot(output, heatmap)
plt.savefig(os.path.join(OUTPUT_PATH, NODEMAP_FILE[:-4] + '_attention.png'), dpi=300)
plt.close(fig)
```
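In a standard U-Net, the feature maps of deeper layers such as `base` have a lower spatial resolution than the 256 x 256 input. Their heatmaps must therefore be upsampled before they can be overlaid on the input. Whether and how `sgc.plot` does this is not covered here. The hypothetical helper below shows one possible approach: simple nearest-neighbour (block) upsampling with numpy.

```python
import numpy as np


def upsample_nearest(heatmap, out_size=256):
    """Hypothetical helper: nearest-neighbour upsampling of a square (h, h)
    heatmap to (out_size, out_size); out_size must be a multiple of h."""
    h, w = heatmap.shape
    if h != w or out_size % h != 0:
        raise ValueError("expected a square heatmap whose size divides out_size")
    factor = out_size // h
    # Each heatmap cell becomes a factor x factor block of identical values
    return np.kron(heatmap, np.ones((factor, factor)))


coarse = np.array([[0.0, 1.0], [0.5, 0.25]])
fine = upsample_nearest(coarse, out_size=256)
print(fine.shape)  # (256, 256)
```

Bilinear interpolation (for example, `torch.nn.functional.interpolate(..., mode='bilinear')`) gives smoother overlays. Nearest-neighbour upsampling keeps the coarse cell structure of the heatmap visible.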