Skip to content

4.2 Explainable AI for crack detection

Eric Breitbarth edited this page Apr 7, 2024 · 1 revision

We implemented the explainable machine learning tool "Seg-Grad-CAM", which highlights the areas of high model attention. Below, we show how to compute the corresponding attention heatmap for a model prediction. This follows the example in scripts/crack_detection/crack_tip_attention.py.

First, let us import the necessary modules and set the paths:

# Imports
import os

import torch
import matplotlib.pyplot as plt

from crackpy.crack_detection.model import get_model
from crackpy.crack_detection.deep_learning import setup, attention
from crackpy.crack_detection.detection import CrackDetection
from crackpy.fracture_analysis.data_processing import InputData
from crackpy.structure_elements.data_files import Nodemap

# Paths
# Name of the nodemap file that is analyzed
NODEMAP_FILE = 'ATON_Dummy2_WPXXX_DummyVersuch_2_dic_results_1_52.txt'
# Folder that contains the nodemap file (relative path)
DATA_PATH = os.path.join('..', '..', 'test_data', 'analysis', 'Nodemaps_new')
# Folder used as output path for the attention results
OUTPUT_PATH = 'attention'

Then, we configure the crack detection and the attention setup:

# Setup
# Use the first CUDA device when one is available, otherwise fall back to the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
det = CrackDetection(
    side='right',
    detection_window_size=30,
    offset=(5, 0),
    angle_det_radius=10,
    device=device,
)

# Names of the layers handed to `set_visu_layers`
visu_layers = ['down1', 'down2', 'down3', 'down4', 'base', 'up1', 'up2', 'up3', 'up4']

# NOTE: from here on, the name `setup` refers to the Setup instance
# and no longer to the imported `setup` module.
setup = setup.Setup()
# Reuse the side, window size and offset of the detection object
setup.side = det.side
setup.set_size(det.detection_window_size, det.offset)
setup.set_output_path(OUTPUT_PATH)
setup.set_visu_layers(visu_layers)

Next, we load the ParallelNets model and copy its weights into the hook-enabled variant:

# Load the model
# Create the hook-enabled network and copy the weights of the 'ParallelNets' model into it
hooked_parallel_nets = attention.ParallelNetsWithHooks()
model = get_model('ParallelNets')
hooked_parallel_nets.load_state_dict(model.state_dict())

# Keep only the `unet` sub-module under the name `model_with_hooks`
model_with_hooks = hooked_parallel_nets.unet

We then read and preprocess the nodemap data:

# Get nodemap data
# The nodemap is identified by its file name and folder and then wrapped in InputData
nodemap = Nodemap(name=NODEMAP_FILE, folder=DATA_PATH)
data = InputData(nodemap)

# Interpolate data on arrays (256 x 256 pixels)
# Only the first return value is needed here; the second one is discarded
interp_disps, _ = det.interpolate(data)

# Preprocess input
# The result is the input that is later passed to Seg-Grad-CAM
input_ch = det.preprocess(interp_disps)

We compute the attention heatmap with a forward and a backward pass, combining the captured features and gradients:

# Initialize Seg-Grad-CAM
# It is built from the setup object and the network with hooks
sgc = attention.SegGradCAM(setup, model_with_hooks)

# Forward pass with hooks to catch the features and gradients
# Calling the object on the preprocessed input returns two values,
# named here the model output and the attention heatmap
output, heatmap = sgc(input_ch)

Finally, let us plot and save the result:

# Plot and save heatmap
fig = sgc.plot(output, heatmap)

# Make sure the output folder exists: `savefig` does not create missing
# directories and would fail otherwise. This is a no-op if it already exists.
os.makedirs(OUTPUT_PATH, exist_ok=True)

# Strip the file extension whatever its length, instead of assuming
# it is exactly four characters long (e.g. '.txt').
file_stem = os.path.splitext(NODEMAP_FILE)[0]
plt.savefig(os.path.join(OUTPUT_PATH, file_stem + '_attention.png'), dpi=300)

# Close the figure to free its resources
plt.close(fig)

example attention heatmap