Decoding the Black Box

Transformer model interpretability and experimental visualization

A.I Hub
14 min read · Aug 16, 2024

In the rapidly advancing field of AI, understanding the 'why' behind a transformer’s decision is just as crucial as the 'what.' Transformer models, with their unparalleled ability to capture complex patterns, often operate as black boxes, powerful yet opaque. But what if we could peer inside these models, unraveling the intricate web of attention and decision making processes? Enter the world of transformer model interpretability and experimental visualization, where we break open the black box, turning complexity into clarity. This journey not only demystifies how these models think but also empowers us to fine-tune and trust them like never before, setting the stage for a future where AI is both powerful and transparent.

Table of Contents

  • Introduction
  • Explainability vs interpretability
  • Interpretability
  • Explainability
  • Tools for explainability and interpretability
  • CAPTUM for interpreting transformer prediction
  • Model loading
  • Input preparation
  • Why baseline tensor
  • Layer integrated gradients
  • Visualization
  • TensorBoard for PyTorch models

Introduction


Machine learning interpretability is about understanding why a model produces certain results; it helps us explain model outcomes. Deep learning models like Transformers can be very complicated, and as they grow more advanced, it becomes harder to know why they decide certain things. This is especially important in areas like healthcare or self-driving cars, where model decisions can really affect people's lives. To use these models responsibly, we must understand their decisions. By understanding how a model thinks, we can make sure it is deciding things for good reasons and correct any wrong or biased choices.

Similarly, experimental logging and visualization help in fine-tuning and understanding machine learning models. Logging is recording how the model behaves during training, and visualization shows this information in charts or graphs. These tools make it easier to find and solve issues and get a clear picture of how the model operates. As models and datasets grow more complex, it is crucial to log and visualize data to ensure things run correctly.

Considering the intricacy of Transformer models, in this section we will explore tools and methods to interpret them. We will also see how to make these complex models more understandable. Plus, we will dive into tools that visually represent experimental data.

Explainability vs. Interpretability


Interpretability and explainability are two important concepts in the field of machine learning. They are often used interchangeably, but they do have distinct differences.

Interpretability is the degree to which a human can understand the inner workings of a machine learning model, or how the model makes decisions based on given inputs. An interpretable model allows you to predict what is going to happen given a change in input or algorithmic parameters. For example, linear regression models are considered highly interpretable because it is clear how changes in the input variables affect the output.

Explainability, on the other hand, is the extent to which a machine learning model's behavior can be explained in human-understandable terms. It focuses on providing understandable descriptions of how a model arrives at a decision, even if the internal workings of the model itself are not fully understood or transparent. This is often the case with complex models like neural networks and ensemble models. For instance, we might explain a decision made by a deep learning model in terms of which features were most influential in driving the prediction. In this section, we will discuss both interpretability and explainability in the context of the transformer model.

Interpretability


Let us consider the self-attention mechanism in a transformer, which allows the model to focus on different words when making predictions. Suppose we are running sentiment analysis on a review such as "The movie was not only intriguing but also full of suspense." In the context of interpretability, we could look at the attention scores the model assigns when processing the word intriguing.

If the model is functioning correctly, it should pay attention to the word intriguing when trying to determine the sentiment of the sentence. We can visualize this with an attention map. The attention map might show high attention scores between intriguing and not only and between intriguing and but also. This is because these phrases indicate that the word intriguing is being used in a positive context.
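As a minimal sketch of how such an attention map could be produced, the snippet below uses the Hugging Face transformers library with output_attentions=True and plots one layer's attention with matplotlib. The model name and the example sentence are illustrative choices, not something prescribed by the original text.

import torch
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModel

# Load a small encoder and ask it to return its attention weights
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModel.from_pretrained("distilbert-base-uncased", output_attentions=True)
model.eval()

sentence = "The movie was not only intriguing but also full of suspense"
inputs = tokenizer(sentence, return_tensors="pt")
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

with torch.no_grad():
    outputs = model(**inputs)

# outputs.attentions is a tuple with one (batch, heads, seq_len, seq_len) tensor per layer;
# here we average the heads of the last layer to get a single seq_len x seq_len map
attn = outputs.attentions[-1].mean(dim=1)[0]

plt.imshow(attn.numpy(), cmap="viridis")
plt.xticks(range(len(tokens)), tokens, rotation=90)
plt.yticks(range(len(tokens)), tokens)
plt.colorbar()
plt.title("Last-layer attention (head average)")
plt.tight_layout()
plt.show()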

Explainability


In the context of explainability, we want to describe how the model arrived at its final prediction. For our sentiment analysis example, let us say the model correctly predicts that the sentiment of the sentence is positive. An explainability tool like LIME could help us understand this decision.

LIME creates a simplified, locally linear version of the model around the prediction we are interested in explaining. It perturbs the input sentence, gets new predictions, and weighs them based on their proximity to the original sentence. LIME might show us that the words intriguing and suspense were highly influential in the model's decision to classify the sentiment as positive.
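As a rough illustration of what that looks like in code, here is a minimal sketch using the lime package together with a Hugging Face sentiment pipeline. The model name, the predict_proba helper, and the sample sizes are illustrative assumptions rather than part of the original example.

import numpy as np
from lime.lime_text import LimeTextExplainer
from transformers import pipeline

# Any text-classification pipeline that returns per-class probabilities will do
classifier = pipeline("text-classification",
                      model="distilbert-base-uncased-finetuned-sst-2-english",
                      top_k=None)

def predict_proba(texts):
    # LIME expects an (n_samples, n_classes) array of probabilities
    results = classifier(list(texts))
    order = {"NEGATIVE": 0, "POSITIVE": 1}  # fixed class order
    probs = np.zeros((len(results), 2))
    for i, scores in enumerate(results):
        for s in scores:
            probs[i, order[s["label"]]] = s["score"]
    return probs

explainer = LimeTextExplainer(class_names=["NEGATIVE", "POSITIVE"])
explanation = explainer.explain_instance(
    "The movie was not only intriguing but also full of suspense",
    predict_proba,
    num_features=6,
    num_samples=500)  # kept small here to limit runtime

print(explanation.as_list())  # word/weight pairs for the positive class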

Tools for Explainability and Interpretability


When dealing with Transformer models, we have several tools and techniques that aid both in interpretability and explainability.

  1. Attention maps — These are widely used for both interpretability and explainability of Transformer models. They allow us to visualize the attention weights in each layer of the model, highlighting the input tokens that each output token is attending to. For example, in a language translation task, an attention map can show which words in the source sentence are being considered while generating each word in the target sentence.
  2. BERTViz — This tool is specifically designed for the BERT model, a type of Transformer model. It visualizes attention in the model, helping with both interpretability (understanding how different parts of the model are interacting) and explainability (understanding which parts of the input sentence were most important for a particular output). A minimal usage sketch follows this list.
  3. ExBERT — This tool allows interactive exploration of BERT models. It provides multiple ways to analyze the model, such as neuron activations and attention distributions, thus aiding in both interpretability and explainability.
  4. Local Interpretable Model-Agnostic Explanations (LIME) — While not specifically designed for Transformers, LIME can be used with any model to help explain individual predictions. It works by approximating the model locally with an interpretable one and can thus provide insights into what features the model is using to make predictions.
  5. Captum — Captum is a model interpretability library for PyTorch. It allows researchers and developers to understand how the data is being used and transformed within their models. Captum offers a wide variety of attribution algorithms that provide insights into the importance of individual features and how they contribute to model predictions.
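To make item 2 concrete, here is a minimal BERTViz sketch, assuming you are working inside a Jupyter notebook where bertviz can render its interactive view; the model and sentence are illustrative choices.

from bertviz import head_view
from transformers import AutoModel, AutoTokenizer

# Load a BERT model that returns its attention weights
model = AutoModel.from_pretrained("bert-base-uncased", output_attentions=True)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

sentence = "The movie was not only intriguing but also full of suspense"
inputs = tokenizer(sentence, return_tensors="pt")
outputs = model(**inputs)

tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
# Opens an interactive per-head attention visualization in the notebook
head_view(outputs.attentions, tokens)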

Several other notable tools that should be mentioned include Eli5, SHAP and TensorFlow Model Analysis (TFMA). In the next section, we will demonstrate how we can use Captum for interpretability and explainability.

CAPTUM for Interpreting Transformer Prediction


In this section, we will use Captum with the model distilbert-base-uncased-finetuned-sst-2-english to interpret the sentiment analysis
of a given text. Let us explain the key components and how they work.

Model Loading

We are using the pre-trained DistilBERT model fine-tuned for sentiment
analysis. This model classifies given text into positive or negative
sentiment.

import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

# Pre-trained model and tokenizer
model_path = 'distilbert-base-uncased-finetuned-sst-2-english'
model = DistilBertForSequenceClassification.from_pretrained(model_path)
tokenizer = DistilBertTokenizer.from_pretrained(model_path)
model.eval()

Input Preparation

The function construct_input_and_baseline is designed to take a textual
input and transform it into tensors that can be fed into a model, such as
DistilBERT. In addition to the model’s input tensor, the function also
constructs a baseline tensor. Let us break down what is happening here,
specifically focusing on the concept of the baseline tensor.

Input Tensor

  1. Text tokenization — The input text is tokenized into a sequence of integers using the model's tokenizer. This sequence represents the words and subwords in the original text.
  2. Add special tokens — Special tokens [CLS] and [SEP] are added at the
    beginning and end of the sequence, respectively.
  3. Input IDs — The resulting sequence of integers (input_ids) is
    converted into a tensor that can be fed into the model.

Baseline Tensor

The baseline tensor is a reference input that represents the absence or neutral
state of the features you are trying to interpret. In the context of NLP, a
common choice for the baseline is a sequence of padding tokens.

  1. Baseline Token ID — The ID corresponding to the padding token is retrieved (baseline_token_id).
  2. Create Baseline Sequence — The baseline sequence is created by replacing the text's tokens with the padding token ID. The special tokens [CLS] and [SEP] are retained at the beginning and end of the sequence.
  3. Baseline Input IDs — The resulting sequence (baseline_input_ids) is converted into a tensor.

Example:

Suppose the input text is “I love movies” and the corresponding token IDs
after tokenization are [10, 18, 27]. The constructed input tensor and
baseline tensor might look like this.

  • Input IDs — [CLS_ID, 10, 18, 27, SEP_ID]
  • Baseline Input IDs — [CLS_ID, PAD_ID, PAD_ID, PAD_ID, SEP_ID]

Why Baseline Tensor


The baseline tensor is used in certain attribution methods like Integrated
Gradients to understand how much each feature contributes to the difference
between the model’s prediction for the actual input and the baseline. By
comparing the model’s behavior on the input to its behavior on this baseline,
you can interpret how important each feature is for the prediction.
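For reference, Integrated Gradients attributes feature i of an input x relative to a baseline x' by accumulating the model's gradients along the straight path from the baseline to the input (a standard statement of the method, with F denoting the model):

\mathrm{IG}_i(x) = (x_i - x'_i) \int_0^1 \frac{\partial F\big(x' + \alpha\,(x - x')\big)}{\partial x_i}\, d\alpha

At α = 0 the model sees the neutral baseline and at α = 1 it sees the actual input. Layer Integrated Gradients, used in the next section, applies the same idea to the outputs of a chosen layer, here the embedding layer, instead of the raw input features.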

In this code snippet, we construct both the actual input to the model (reflecting the text you want to analyze) and a baseline input (reflecting a neutral, non-informative version of the text). The comparison between these two inputs will be used to understand how the model is interpreting the text. The text and baseline are tokenized and converted into tensors. A baseline is often a reference input that represents the absence of the features of interest, for example, all padding tokens.

def construct_input_and_baseline(input_text: str):
    """Constructs input and baseline tensors for the given text."""
    max_length = 768
    baseline_token_id = tokenizer.pad_token_id
    sep_token_id = tokenizer.sep_token_id
    cls_token_id = tokenizer.cls_token_id

    text_ids = tokenizer.encode(input_text, max_length=max_length,
                                truncation=True, add_special_tokens=False)
    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    baseline_input_ids = [cls_token_id] + [baseline_token_id] * len(text_ids) + [sep_token_id]
    token_list = tokenizer.convert_ids_to_tokens(input_ids)

    return (torch.tensor([input_ids], device='cpu'),
            torch.tensor([baseline_input_ids], device='cpu'),
            token_list)

# 'text' holds the review to analyze; the sentence below is an illustrative
# example consistent with the positive review shown in Figure 1.1
text = "I really enjoyed this movie, the acting was awesome"

# Constructing input and baseline
input_ids, baseline_input_ids, all_tokens = construct_input_and_baseline(text)

Layer Integrated Gradients


The code uses an attribution method to clarify how a model's predictions are influenced by different parts of its input, which are tokens in this case. Here is a breakdown of its key components.

  1. Model output function — The function named model_output is a wrapper around the model's forward pass. It extracts the prediction scores, often referred to as logits, from the model's output.
  2. Setting up layer integrated gradients — The LayerIntegratedGradients function is initialized with two primary components:
  • The model's forward function, represented by model_output.
  • The specific layer of the model we're interested in examining, which is the embeddings layer (model.distilbert.embeddings).

Attribution Calculation

The code computes attributions for both sentiment classes, positive and negative. Attributions essentially give us a score indicating how much each token influenced the prediction for each sentiment class.

Attribution Summarization and Normalization

  • The importance scores (attributions) for each token are aggregated by summing across the embedding dimensions.
  • These aggregated scores are then normalized so that their magnitudes are comparable. The outcome is a 1D tensor, where each value signifies the relative significance of its corresponding token. For instance, if we are looking at negative sentiment, a higher score would mean that the token strongly suggests a negative sentiment.

Choosing Attribution Based on Prediction

  • The model determines whether a given text is positive or negative
    in sentiment.
  • Depending on this prediction, the code selects the corresponding
    set of attributions, either positive or negative, to analyze further.

In essence, this approach provides an in-depth look into which words or phrases hold the most sway over the model's sentiment prediction.

from captum.attr import LayerIntegratedGradients

# Model output function
def model_output(inputs):
    return model(inputs)[0]

# Setting up Layer Integrated Gradients on the embeddings layer
lig = LayerIntegratedGradients(model_output, model.distilbert.embeddings)

# Attribution calculation for both classes (0 = negative, 1 = positive)
target_classes = [0, 1]
attributions = {}
delta = {}
for target_class in target_classes:
    attributions[target_class], delta[target_class] = lig.attribute(
        inputs=input_ids,
        baselines=baseline_input_ids,
        target=target_class,
        return_convergence_delta=True,
        internal_batch_size=1)

# Attribution summarization and normalization
neg_attributions = attributions[0].sum(dim=-1).squeeze(0) / torch.norm(attributions[0])
pos_attributions = attributions[1].sum(dim=-1).squeeze(0) / torch.norm(attributions[1])

# Choosing attributions based on the predicted class
logits = model(input_ids)[0]
pred_prob, pred_class = torch.max(logits), int(torch.argmax(logits))
summarized_attr = pos_attributions if pred_class == 1 else neg_attributions

Visualization

Captum provides visualization tools like viz.visualize_text to represent the attributions visually. It shows the tokens and their corresponding importance scores, highlighting the tokens that are more influential in the model's decision. Let us understand the important aspects of this code.

  • true_class=None — This indicates the actual or ground-truth class for the input text. Since we are not providing any ground truth in this context, it is set to None.
  • raw_input_ids=all_tokens — This provides the tokenized version of the input text (all_tokens), which helps in mapping attributions back to their respective words/tokens in the visualization.
  • convergence_score=delta[pred_class] — This score measures the quality or reliability of the calculated attributions. A smaller convergence score indicates that the attributions are more reliable.
from captum.attr import visualization as viz

score_vis = viz.VisualizationDataRecord(
    word_attributions=summarized_attr,
    pred_prob=pred_prob,
    pred_class=pred_class,
    true_class=None,
    attr_class=text,
    attr_score=summarized_attr.sum(),
    raw_input_ids=all_tokens,
    convergence_score=delta[pred_class])

# Visualizing the result
viz.visualize_text([score_vis])

Figure 1.1 is the Captum visualization. As you can see, the words awesome and enjoyed have the highest attribution scores for the positive sentiment prediction.

Figure 1.1 - The result of Captum visualization

TensorBoard for PyTorch Models


TensorBoard, initially developed for TensorFlow, has become a vital
visualization toolkit for neural network training across different frameworks.
For PyTorch enthusiasts, the torch.utils.tensorboard integration allows
them to leverage TensorBoard’s robust visualization capabilities, ranging
from monitoring training milestones to examining learned embeddings. To
initiate TensorBoard, run tensorboard --logdir=runs in the terminal. By default, TensorBoard is available at http://localhost:6006.

You can visualize a variety of things related to your PyTorch models and
training sessions. Here are the key visualizations you can achieve using
TensorBoard with PyTorch.

  • Scalars — Scalars refer to simple, single-number metrics that you track over time or iterations. They are typically used to log and visualize metrics that change with each epoch or iteration, such as training loss, validation accuracy, learning rate and so on.
import torch
from torch.utils.tensorboard import SummaryWriter

# Create a dummy model and optimizer
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Instantiate SummaryWriter
writer = SummaryWriter()

for epoch in range(100):
    # Dummy training loop
    optimizer.zero_grad()
    output = model(torch.randn(32, 10))
    loss = ((output - torch.randn(32, 1))**2).mean()
    loss.backward()
    optimizer.step()

    # Log loss to TensorBoard
    writer.add_scalar("Training loss", loss, epoch)

# Close the writer
writer.close()
  • Histograms — Visualize the distribution of tensor values, for example, layer weights. The following code demonstrates how you can visualize the model's named parameters as histograms.

for name, weight in model.named_parameters():
    writer.add_histogram(name, weight, epoch)
  • Text — Log textual information. This code snippet shows how
    you can log textual data.
writer.add_text('Loss_Text', 'The training loss was very low this epoch', epoch)
  • Distributions — A distribution is just a smoother version of the histogram. You can use the same code you use for histograms.
  • Visualizing model graphs — Beyond just scalars, you can visualize the architecture of your model. This code shows how you can visualize the BERT architecture.
import torch
from transformers import BertModel, BertTokenizer
from torch.utils.tensorboard import SummaryWriter

# Load pre-trained BERT model and tokenizer
model_name = "bert-base-uncased"
bert_model = BertModel.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)

class SimpleBERT(torch.nn.Module):
    def __init__(self, bert_model):
        super(SimpleBERT, self).__init__()
        self.bert = bert_model

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        return outputs.last_hidden_state

model = SimpleBERT(bert_model)

# Instantiate the SummaryWriter
writer = SummaryWriter()

# Create a dummy input for the BERT model
tokens = tokenizer("Hello, TensorBoard!", return_tensors="pt")
input_ids = tokens["input_ids"]
attention_mask = tokens["attention_mask"]

# Add the BERT model graph to TensorBoard
writer.add_graph(model, [input_ids, attention_mask])

# Close the writer
writer.close()
  • Embeddings — Using this functionality, you can visualize the embeddings of tokens in 3-D space. When you view the embeddings in TensorBoard, you will see each token (word/sub-word) positioned in the embedding space. Similar words should appear near each other, whereas dissimilar words should appear further apart.
with torch.no_grad():
    embeddings = model(input_ids, attention_mask=attention_mask)

# Just as an example, using the tokens as metadata
# Note: We remove the [CLS] and [SEP] tokens for visualization.
text = "Hello, TensorBoard!"  # the same sentence that was tokenized above
metadata = [token for token in tokenizer.tokenize(text)]
embeddings = embeddings[0, 1:-1, :]  # Removing embeddings for [CLS] and [SEP]
writer.add_embedding(embeddings, metadata=metadata)

# Close the writer
writer.close()
  • PR curves — For understanding classification performance.
probs = model(input_data)
writer.add_pr_curve('pr_curve', true_labels, probs, epoch)
  • Hyperparameters — Visualize hyperparameters.
hparams = {'lr': 0.1, 'batch_size': 32}
metrics = {'accuracy': 0.8}
writer.add_hparams(hparams, metrics)
  • Profiling — PyTorch's torch.profiler is specifically designed to profile the execution of PyTorch models. When you profile a PyTorch model, here are some things you are typically interested in.
  1. Operator-level performance — Which specific operations (for example, matrix multiplications, convolutions) are taking the most time? How long does each operation take to execute?
  2. Memory consumption — Which operations consume the most memory? This is crucial for deep learning models, which can often be memory bound.
  3. Call stack information — Which lines in your source code correspond to the various operations? This helps link the profiled performance data back to specific lines of your code.
  4. CPU/GPU time — How long are operations taking on the CPU versus the GPU? This can help in identifying data transfer bottlenecks, among other things.

When this information is logged to TensorBoard using writer.add_text(), you can visualize and analyze it, making it easier to understand the performance characteristics of your model. This is especially valuable when you are trying to optimize a model to run faster or when diagnosing performance issues.

for inputs, targets in dataloader:
    with torch.profiler.profile(with_stack=True) as prof:
        train_step(inputs, targets)

    # Log the profiling results to TensorBoard
    writer.add_text("Profile", str(prof.key_averages().table()))
  • Visualizing image data — For convolutional networks or any model
    working with image data, visualizing the input or output can be
    informative.
import torchvision

images = torch.randn(32, 3, 64, 64)  # Simulating a batch of 32 images
grid = torchvision.utils.make_grid(images)
writer.add_image("images", grid, 0)

The end-to-end code for the TensorBoard logging discussed above is provided in the accompanying notebook.

Conclusion

In the realm of AI, where understanding is as crucial as innovation, mastering both explainability and interpretability transforms how we interact with our models. We have delved into the nuances that distinguish interpretability from explainability, exploring powerful tools like CAPTUM to dissect and illuminate transformer predictions. From model loading and input preparation to understanding the critical role of baseline tensors and leveraging layer integrated gradients, each step sharpens our ability to demystify complex models. Visualization with tools like TensorBoard further bridges the gap between raw data and actionable insights, making AI not just a powerful tool but a transparent ally. As we integrate these practices, we unlock a deeper connection with our models, ensuring they are not only effective but also understandable and trustworthy. This holistic approach to AI interpretability and visualization paves the way for smarter, more transparent systems that inspire confidence and drive innovation forward.
