Decoding the Black Box
Transformer model interpretability and experimental visualization
In the rapidly advancing field of AI, understanding the 'why' behind a transformer’s decision is just as crucial as the 'what.' Transformer models, with their remarkable ability to capture complex patterns, often operate as black boxes: powerful yet opaque. But what if we could peer inside these models and unravel their intricate web of attention and decision-making processes? Enter the world of transformer model interpretability and experimental visualization, where we open the black box and turn complexity into clarity. This journey not only demystifies how these models work but also helps us fine-tune them and trust them with greater confidence, setting the stage for a future where AI is both powerful and transparent.
Table of Contents
- Introduction
- Explainability vs interpretability
- Interpretability
- Explainability
- Tools for explainability and interpretability
- Captum for interpreting transformer predictions
- Model loading
- Input preparation
- Why a baseline tensor
- Layer integrated gradients
- Visualization
- TensorBoard for PyTorch models
- Conclusion
Introduction
Machine learning interpretability is about understanding why a model produces a particular output, which helps us explain model outcomes. Deep learning models such as Transformers can be very complicated, and as they grow more advanced, it becomes harder to know why they make the decisions they do. This matters most in areas like healthcare or self-driving cars, where model decisions can directly affect people’s lives. To use these models responsibly, we must understand their decisions. By understanding how a model reasons, we can check that it is deciding things for the right reasons and correct any wrong or biased choices.
Similarly, experiment logging and visualization help in fine-tuning and understanding machine learning models. Logging records how the model behaves during training, and visualization presents this information as charts or graphs. Together, these tools make it easier to find and fix issues and to get a clear picture of how the model operates. As models and datasets grow more complex, logging and visualizing data becomes crucial for making sure training runs correctly.
Considering the intricacy of Transformer models, in this section we will explore tools and methods to interpret them and make these complex models more understandable. We will also dive into tools that visually represent experimental data.
Explainability vs. Interpretability
Interpretability and explainability are two important concepts in machine learning. They are often used interchangeably, but they have distinct meanings.
Interpretability is the degree to which a human can understand the inner workings of a machine learning model, that is, how the model makes decisions based on given inputs. An interpretable model lets you predict what will happen when the input or the algorithm's parameters change. For example, linear regression models are considered highly interpretable because each coefficient states directly how a change in its input variable affects the output.
Explainability, on the other hand, is the extent to which a machine learning model’s behavior can be explained in human-understandable terms. It focuses on providing understandable descriptions of how a model arrives at a decision, even if the model's internal workings are not fully understood or transparent. This is often the case with complex models such as neural networks and ensemble models. For instance, we might explain a decision made by a deep learning model in terms of which input features were most influential in driving the prediction. In this section, we will discuss both interpretability and explainability in the context of the Transformer model.
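To make the linear regression example concrete, the sketch below fits ordinary least squares to synthetic data and shows that raising one feature by a unit moves the prediction by exactly that feature's coefficient. The data, the true weights, and the predict helper are hypothetical and exist only for illustration.

```python
import numpy as np

# Hypothetical synthetic data: y = 3.0*x0 - 1.5*x1 + 0.5 plus a little noise.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
y = X @ np.array([3.0, -1.5]) + 0.5 + rng.normal(scale=0.01, size=200)

# Fit ordinary least squares; the leading column of ones models the intercept.
X_design = np.column_stack([np.ones(len(X)), X])
coef, *_ = np.linalg.lstsq(X_design, y, rcond=None)
intercept, w = coef[0], coef[1:]

def predict(x):
    return intercept + x @ w

# Raising feature 0 by one unit moves the prediction by exactly w[0],
# no matter where we start: the coefficient itself is the explanation.
x = np.array([0.2, 1.0])
change = predict(x + np.array([1.0, 0.0])) - predict(x)
print("coefficients:", np.round(w, 2), "change:", round(float(change), 2))
```

A deep network offers no such single number per feature, which is why the explainability tools discussed below are needed.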
Interpretability
Let us consider the self-attention mechanism in a transformer, which allows the model to focus on different words when making predictions. In the context of interpretability, we could look at the attention scores the model assigns when processing the word 'intriguing'.
If the model is functioning correctly, it should pay attention to the word 'intriguing' when determining the sentiment of the sentence. We can visualize this with an attention map. The attention map might show high attention scores between 'intriguing' and 'not only', and between 'intriguing' and 'but also', because these phrases signal that 'intriguing' is being used in a positive context.
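To make the idea of an attention map concrete, the sketch below computes single-head scaled dot-product attention weights, softmax(QK^T / sqrt(d)), for a short token sequence. The tokens, the random embeddings, and the projection matrices are hypothetical stand-ins, so the particular weights carry no linguistic meaning. A trained model would supply learned values. Row i of the resulting matrix is what an attention map draws for token i.

```python
import numpy as np

# Hypothetical token sequence and random embeddings (illustration only).
tokens = ["the", "plot", "was", "not", "only", "intriguing"]
d_model = 8
rng = np.random.default_rng(42)
X = rng.normal(size=(len(tokens), d_model))

# Hypothetical query/key projections; a trained model learns these matrices.
W_q = rng.normal(size=(d_model, d_model))
W_k = rng.normal(size=(d_model, d_model))
Q, K = X @ W_q, X @ W_k

# Scaled dot-product scores followed by a numerically stable row-wise softmax.
scores = Q @ K.T / np.sqrt(d_model)
scores -= scores.max(axis=1, keepdims=True)
weights = np.exp(scores)
weights /= weights.sum(axis=1, keepdims=True)

# The row for "intriguing": how much it attends to each token in the sequence.
row = weights[tokens.index("intriguing")]
for tok, weight in zip(tokens, row):
    print(f"{tok:>10s} {weight:.3f}")
```

Each row sums to 1, so an attention map is effectively a heat map of these per-token probability distributions, one per head and layer.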
Explainability
In the context of explainability, we want to describe how the model arrived at its final prediction. For our sentiment analysis example, let us say the model correctly predicts that the sentiment of the sentence is positive.
An explainability tool like LIME could help us understand this decision. LIME builds a simple, locally linear approximation of the model around the prediction we want to explain. It perturbs the input sentence (for example, by removing words), gets new predictions for the perturbed versions, and weights them by their proximity to the original sentence before fitting the linear approximation.
LIME might show us that the words 'intriguing' and 'suspense' were highly influential in the model’s decision to classify the sentiment as positive.
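The perturb, predict, weight, and fit loop described above can be sketched from scratch in a few lines. This is not the lime library's API. The keyword-based classifier toy_positive_prob, its word scores, the example sentence, the exponential kernel, its width, and the sample count are all hypothetical choices made only to illustrate the procedure.

```python
import numpy as np

# Hypothetical stand-in for a sentiment model: P(positive) from a list of words.
WORD_SCORES = {"intriguing": 2.0, "suspense": 1.5, "boring": -2.0}

def toy_positive_prob(words):
    z = sum(WORD_SCORES.get(word, 0.0) for word in words) - 1.0
    return 1.0 / (1.0 + np.exp(-z))

def lime_style_explanation(words, predict, n_samples=500, kernel_width=0.5, seed=0):
    """Fit a proximity-weighted linear surrogate to `predict` around `words`."""
    rng = np.random.default_rng(seed)
    # 1. Perturb: each row is a binary mask (1 = keep the word, 0 = drop it).
    masks = rng.integers(0, 2, size=(n_samples, len(words)))
    masks[0] = 1  # include the unperturbed sentence itself
    # 2. Predict on every perturbed sentence.
    preds = np.array([predict([w for w, keep in zip(words, m) if keep]) for m in masks])
    # 3. Weight samples by proximity: fewer dropped words -> larger weight.
    distance = 1.0 - masks.mean(axis=1)
    sample_weights = np.exp(-(distance ** 2) / kernel_width ** 2)
    # 4. Weighted least squares on [intercept, mask] features.
    design = np.column_stack([np.ones(n_samples), masks])
    root_w = np.sqrt(sample_weights)
    coef, *_ = np.linalg.lstsq(design * root_w[:, None], preds * root_w, rcond=None)
    return dict(zip(words, coef[1:]))

sentence = "the plot was intriguing and full of suspense".split()
explanation = lime_style_explanation(sentence, toy_positive_prob)
for word, weight in sorted(explanation.items(), key=lambda kv: -kv[1]):
    print(f"{word:>10s} {weight:+.3f}")
```

The surrogate's coefficients are the explanation. With this toy model, 'intriguing' and 'suspense' receive the largest positive weights, and the neutral words stay near zero.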
Tools for Explainability and Interpretability
When dealing with Transformer models, we have several tools and
techniques that aid both in interpretability and explainability.
- Attention maps: These are widely used for both interpretability and explainability of Transformer models. They visualize the attention weights in each layer of the model, highlighting the input tokens that each output token attends to. For example, in a language translation task, an attention map can show which words in the source sentence are being considered while generating each word in the target sentence.
- BERTViz: Originally built for BERT, this tool visualizes attention in BERT and other Transformer language models. It helps with interpretability, by showing how different parts of the model interact, and with explainability, by showing which parts of the input sentence were most important for a particular output.
- ExBERT: This tool allows interactive exploration of BERT models. It provides multiple ways to analyze the model, such as neuron activations and attention distributions, thus aiding both interpretability and explainability.
- LIME (Local Interpretable Model-agnostic Explanations): While not specifically designed for Transformers, LIME can be used with any model to help explain individual predictions. It approximates the model locally with an interpretable one and can thus provide insights into which features the model uses to make predictions.
- Captum: Captum is a model interpretability library for PyTorch. It helps researchers and developers understand how data is used and transformed within their models. Captum offers a wide variety of attribution algorithms that reveal the importance of individual features and how they contribute to model predictions.
Other notable tools include ELI5, SHAP, and TensorFlow Model Analysis (TFMA). In the next section, we will demonstrate how to use Captum for interpretability and explainability.
Captum for Interpreting Transformer Predictions
In this section, we will use Captum with the distilbert-base-uncased-finetuned-sst-2-english model to interpret its sentiment prediction for a given text. Let us walk through the key components and how they work.
Model Loading
We are using the pre-trained DistilBERT model fine-tuned for sentiment
analysis. This model classifies given text into positive or negative
sentiment.
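import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
from captum.attr import LayerIntegratedGradients
from captum.attr import visualization as viz

# Pre-trained model and tokenizer
model_path = 'distilbert-base-uncased-finetuned-sst-2-english'
model = DistilBertForSequenceClassification.from_pretrained(model_path)
tokenizer = DistilBertTokenizer.from_pretrained(model_path)
model.eval()  # inference mode: disables dropout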
Input Preparation
The function construct_input_and_baseline is designed to take a textual
input and transform it into tensors that can be fed into a model, such as
DistilBERT. In addition to the model’s input tensor, the function also
constructs a baseline tensor. Let us break down what is happening here,
specifically focusing on the concept of the baseline tensor.
Input Tensor
- Text tokenization: The input text is tokenized into a sequence of integers using the model’s tokenizer. This sequence represents the words and subwords in the original text.
- Add special tokens: The special tokens [CLS] and [SEP] are added at the beginning and end of the sequence, respectively.
- Input IDs: The resulting sequence of integers (input_ids) is converted into a tensor that can be fed into the model.
Baseline Tensor
The baseline tensor is a reference input that represents the absence or neutral
state of the features you are trying to interpret. In the context of NLP, a
common choice for the baseline is a sequence of padding tokens.
- Baseline token ID: The ID corresponding to the padding token is retrieved (baseline_token_id).
- Create baseline sequence: The baseline sequence is created by replacing the text’s tokens with the padding token ID. The special tokens [CLS] and [SEP] are retained at the beginning and end of the sequence.
- Baseline input IDs: The resulting sequence (baseline_input_ids) is converted into a tensor.
Example:
Suppose the input text is “I love movies” and the corresponding token IDs
after tokenization are [10, 18, 27]. The constructed input tensor and
baseline tensor might look like this.
- Input IDs: [CLS_ID, 10, 18, 27, SEP_ID]
- Baseline Input IDs: [CLS_ID, PAD_ID, PAD_ID, PAD_ID, SEP_ID]
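The same construction can be written as a small, dependency-free helper. The helper name build_input_and_baseline_ids is hypothetical, and the text IDs [10, 18, 27] are the made-up values from the example above. The special-token IDs 101 ([CLS]), 102 ([SEP]), and 0 ([PAD]) are the ones in the uncased BERT vocabulary that DistilBERT shares. In practice, read them from the tokenizer rather than hard-coding them.

```python
def build_input_and_baseline_ids(text_ids, cls_id, sep_id, pad_id):
    """Wrap text IDs in [CLS]/[SEP] and build a same-length all-padding baseline."""
    input_ids = [cls_id] + list(text_ids) + [sep_id]
    # Every content token becomes [PAD]; the special tokens stay in place, so the
    # baseline differs from the input only where the actual words are.
    baseline_ids = [cls_id] + [pad_id] * len(text_ids) + [sep_id]
    return input_ids, baseline_ids

input_ids, baseline_ids = build_input_and_baseline_ids([10, 18, 27], cls_id=101, sep_id=102, pad_id=0)
print(input_ids)     # [101, 10, 18, 27, 102]
print(baseline_ids)  # [101, 0, 0, 0, 102]
```

Keeping both sequences the same length matters because attribution methods interpolate position by position between the baseline and the input.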
Why a Baseline Tensor
The baseline tensor is used in certain attribution methods like Integrated
Gradients to understand how much each feature contributes to the difference
between the model’s prediction for the actual input and the baseline. By
comparing the model’s behavior on the input to its behavior on this baseline,
you can interpret how important each feature is for the prediction.
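Integrated Gradients makes this comparison precise. In the notation below, which is ours rather than the source's, x is the input, x' is the baseline, and F is the model's score for one target class. The attribution of feature i accumulates the gradient along the straight line from x' to x. The key property, completeness, states that the attributions sum to F(x) - F(x'). In practice the integral is approximated numerically, for example with a Riemann sum of m steps. The gap between the two sides of the completeness equation is the convergence delta that Captum can return.

```latex
\begin{aligned}
\mathrm{IG}_i(x) &= (x_i - x'_i)\int_{0}^{1}\frac{\partial F\bigl(x' + \alpha\,(x - x')\bigr)}{\partial x_i}\,d\alpha \\
\sum_{i}\mathrm{IG}_i(x) &= F(x) - F(x') \\
\mathrm{IG}_i(x) &\approx (x_i - x'_i)\,\frac{1}{m}\sum_{k=1}^{m}\frac{\partial F\bigl(x' + \tfrac{k}{m}(x - x')\bigr)}{\partial x_i} \\
\delta &= \sum_{i}\mathrm{IG}_i(x) - \bigl(F(x) - F(x')\bigr)
\end{aligned}
```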
The following code constructs two inputs. The first is the actual input to the model, reflecting the text you want to analyze. The second is a baseline input that reflects a neutral or non-informative version of the text. Comparing these two inputs shows how the model interprets the text. Both are tokenized and converted into tensors, and the baseline replaces every content token with the padding token.
def construct_input_and_baseline(input_text: str):
    """Constructs input and baseline tensors for the given text."""
    # DistilBERT accepts at most 512 positions; reserve two for [CLS] and [SEP].
    max_length = 510
    baseline_token_id = tokenizer.pad_token_id
    sep_token_id = tokenizer.sep_token_id
    cls_token_id = tokenizer.cls_token_id
    # Tokenize without special tokens so that [CLS]/[SEP] can be added explicitly.
    text_ids = tokenizer.encode(input_text, max_length=max_length,
                                truncation=True, add_special_tokens=False)
    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    # Baseline: same length as the input, with every content token replaced by [PAD].
    baseline_input_ids = [cls_token_id] + [baseline_token_id] * len(text_ids) + [sep_token_id]
    token_list = tokenizer.convert_ids_to_tokens(input_ids)
    return (torch.tensor([input_ids], device='cpu'),
            torch.tensor([baseline_input_ids], device='cpu'),
            token_list)

# Constructing input and baseline
# `text` is the input review to analyze, e.g. text = "your review here"
input_ids, baseline_input_ids, all_tokens = construct_input_and_baseline(text)
Layer Integrated Gradients
The code uses an attribution method to clarify how a model’s predictions are influenced by different parts of its input, which in this case are tokens. Here is a breakdown of its key components.
- Model output function: The function model_output is a wrapper around the model’s forward pass. It extracts the prediction scores, often referred to as logits, from the model’s output.
- Setting up Layer Integrated Gradients: LayerIntegratedGradients is initialized with two primary components:
- The model’s forward function, represented by model_output.
- The specific layer of the model we are interested in examining, which is the embedding layer (model.distilbert.embeddings). Token IDs are discrete integers, so gradients with respect to them are not meaningful. Attributing at the embedding layer’s output instead lets Integrated Gradients interpolate between the baseline and input embeddings in a continuous space.
Attribution Calculation
The code computes attributions for both sentiment classes, negative (class 0) and positive (class 1). Attributions give each token a score indicating how much it influenced the prediction for that sentiment class.
Attribution Summarization and Normalization
- The attribution scores for each token are aggregated by summing across the embedding dimensions.
- These aggregated scores are then divided by the norm of the full attribution tensor so that their magnitudes are comparable. The outcome is a 1D tensor in which each value signifies the relative significance of the corresponding token. For instance, for the negative class, a higher score means the token pushes the prediction more strongly toward negative sentiment.
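The whole pipeline can be sketched without Captum on a toy model: Integrated Gradients on the embedding layer's output, followed by the summarization and normalization described above. Everything here is hypothetical. The 20-word vocabulary, the random embedding table, the mean-pool-plus-logistic "classifier", and the special-token IDs are invented for illustration, and analytic gradients replace autograd. This is a sketch of the idea, not how LayerIntegratedGradients is implemented internally.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, dim = 20, 4
embedding_table = rng.normal(size=(vocab_size, dim))  # hypothetical embedding layer
w, b = rng.normal(size=dim), 0.1                      # hypothetical classifier head

def forward_from_embeddings(emb):
    """Toy model output F: mean-pool the (seq, dim) embeddings, then a logistic unit."""
    z = emb.mean(axis=0) @ w + b
    return 1.0 / (1.0 + np.exp(-z))

def grad_wrt_embeddings(emb):
    """Analytic dF/d(emb); identical at every position because of mean pooling."""
    p = forward_from_embeddings(emb)
    per_token = p * (1.0 - p) * w / emb.shape[0]
    return np.tile(per_token, (emb.shape[0], 1))

def layer_integrated_gradients(input_ids, baseline_ids, n_steps=50):
    """IG computed on the embedding layer's output instead of on the integer IDs."""
    x = embedding_table[input_ids]          # (seq, dim) embeddings of the input
    x_base = embedding_table[baseline_ids]  # (seq, dim) embeddings of the baseline
    alphas = (np.arange(n_steps) + 0.5) / n_steps  # midpoint Riemann sum on [0, 1]
    avg_grad = np.mean([grad_wrt_embeddings(x_base + a * (x - x_base)) for a in alphas], axis=0)
    attributions = (x - x_base) * avg_grad  # (seq, dim)
    delta = attributions.sum() - (forward_from_embeddings(x) - forward_from_embeddings(x_base))
    return attributions, delta

CLS, SEP, PAD = 1, 2, 0  # hypothetical special-token IDs in this toy vocabulary
input_ids = [CLS, 10, 18, 7, SEP]
baseline_ids = [CLS, PAD, PAD, PAD, SEP]

attributions, delta = layer_integrated_gradients(input_ids, baseline_ids)
# Same summarization as described above: sum over the embedding dimension,
# then divide by the norm of the whole attribution tensor.
summarized = attributions.sum(axis=-1) / np.linalg.norm(attributions)
print(np.round(summarized, 3), f"convergence delta = {delta:.1e}")
```

Because [CLS] and [SEP] are identical in the input and the baseline, their attributions come out exactly zero. That is one practical reason the baseline keeps them in place. The small convergence delta confirms that the 50-step approximation satisfies completeness closely.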
Choosing Attribution Based on Prediction
- The model predicts whether the given text is positive or negative in sentiment.
- Depending on this prediction, the code selects the corresponding set of attributions, positive or negative, for further analysis.
In essence, this approach provides an in-depth look at which words or phrases have the most sway over the model’s sentiment prediction.
# Model output function: returns the logits (prediction scores)
def model_output(inputs):
    return model(inputs)[0]

# Setting up Layer Integrated Gradients on the embedding layer
lig = LayerIntegratedGradients(model_output, model.distilbert.embeddings)

# Attribution calculation for both classes (0 = negative, 1 = positive)
target_classes = [0, 1]
attributions = {}
delta = {}
for target_class in target_classes:
    attributions[target_class], delta[target_class] = lig.attribute(
        inputs=input_ids,
        baselines=baseline_input_ids,
        target=target_class,
        return_convergence_delta=True,
        internal_batch_size=1)

# Attribution summarization and normalization:
# sum over the embedding dimension, then divide by the norm of the full tensor
neg_attributions = attributions[0].sum(dim=-1).squeeze(0) / torch.norm(attributions[0])
pos_attributions = attributions[1].sum(dim=-1).squeeze(0) / torch.norm(attributions[1])

# Choosing attributions based on the prediction.
# Run the model once and apply softmax so that pred_prob is a probability, not a logit.
with torch.no_grad():
    probs = torch.softmax(model(input_ids)[0], dim=-1)
pred_prob, pred_class = torch.max(probs, dim=-1)
pred_prob, pred_class = pred_prob.item(), int(pred_class.item())

# Selecting the attributions based on the predicted class
summarized_attr = pos_attributions if pred_class == 1 else neg_attributions
Visualization
Captum provides visualization tools such as viz.visualize_text to represent attributions visually. It shows the tokens and their importance scores, highlighting the tokens that are most influential in the model’s decision. Let us look at the important arguments in this code.
- true_class=None: This is the actual (ground-truth) class of the input text. Since we are not providing any ground truth here, it is set to None.
- raw_input_ids=all_tokens: This provides the tokenized input text (all_tokens), which maps the attributions back to their respective words or tokens in the visualization.
- convergence_score=delta[pred_class]: This is the convergence delta returned by Integrated Gradients. It is the difference between the sum of the attributions and the change in model output between the baseline and the input. A value closer to zero means the numerical approximation of the integral is more accurate, so the attributions are more reliable.
score_vis = viz.VisualizationDataRecord(
word_attributions=summarized_attr,
pred_prob=pred_prob,
pred_class=pred_class,
true_class=None,
attr_class=pred_class,  # the class whose attributions are shown
attr_score=summarized_attr.sum(),
raw_input_ids=all_tokens,
convergence_score=delta[pred_class])
# Visualizing the result
viz.visualize_text([score_vis])
Figure 1.1 shows the Captum visualization. The words awesome and enjoyed have the highest attribution scores for the positive sentiment prediction.
TensorBoard for PyTorch Models
TensorBoard, initially developed for TensorFlow, has become a vital visualization toolkit for neural network training across different frameworks. For PyTorch users, the torch.utils.tensorboard integration provides access to TensorBoard’s visualization capabilities, ranging from monitoring training progress to examining learned embeddings. By default, SummaryWriter writes its event files to a subdirectory of ./runs, so to launch TensorBoard, run tensorboard --logdir=runs in the terminal. TensorBoard is then available at http://localhost:6006 by default.
You can visualize a variety of things related to your PyTorch models and training sessions. Here are the key visualizations you can create using TensorBoard with PyTorch.
- Scalars: Scalars are single-number metrics that you track over time or iterations. They are typically used to log and visualize metrics that change with each epoch or iteration, such as training loss, validation accuracy, and learning rate.
import torch
from torch.utils.tensorboard import SummaryWriter

# Create a dummy model and optimizer
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Instantiate SummaryWriter (logs to ./runs/<datetime>_<hostname> by default)
writer = SummaryWriter()

for epoch in range(100):
    # Dummy training loop on random data
    optimizer.zero_grad()
    output = model(torch.randn(32, 10))
    loss = ((output - torch.randn(32, 1)) ** 2).mean()
    loss.backward()
    optimizer.step()
    # Log loss to TensorBoard as a plain Python float
    writer.add_scalar("Training loss", loss.item(), epoch)

# Close the writer
writer.close()
- Histogram: Visualize the distribution of tensor values, for example, layer weights. The following code shows how you can log each of the model’s named parameters as a histogram. Place it inside the training loop above (before writer.close()) so that epoch is defined.
for name, weight in model.named_parameters():
    writer.add_histogram(name, weight, epoch)
- Text: Log textual information. The following snippet shows how you can log textual data.
writer.add_text('Loss_Text', 'The training loss was very low this epoch', epoch)
- Distribution: The Distributions dashboard shows the same data as the histograms, drawn as bands of percentiles over time, so the histogram code above also populates it.
- Visualizing model graphs: Beyond scalars, you can visualize the architecture of your model. The following code shows how to visualize the BERT architecture.
import torch
from transformers import BertModel, BertTokenizer
from torch.utils.tensorboard import SummaryWriter

# Load pre-trained BERT model and tokenizer
model_name = "bert-base-uncased"
bert_model = BertModel.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)

class SimpleBERT(torch.nn.Module):
    """Wrapper that returns a plain tensor so that add_graph can trace the model."""
    def __init__(self, bert_model):
        super().__init__()
        self.bert = bert_model

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        return outputs.last_hidden_state

model = SimpleBERT(bert_model)
# Instantiate the SummaryWriter
writer = SummaryWriter()
# Create a dummy input for the BERT model
tokens = tokenizer("Hello, TensorBoard!", return_tensors="pt")
input_ids = tokens["input_ids"]
attention_mask = tokens["attention_mask"]
# Add the BERT model graph to TensorBoard
writer.add_graph(model, [input_ids, attention_mask])
# Close the writer
writer.close()
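- Embedding: With this functionality, you can visualize token embeddings projected into 2-D or 3-D space. TensorBoard’s projector offers PCA, t-SNE, and UMAP for the projection. When you view the embeddings in TensorBoard, each token (word or subword) appears as a point in the projected space. Ideally, similar words appear close together, whereas dissimilar words appear farther apart. The code below uses BERT’s final hidden states, so each point is a context-dependent representation of a token.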
# Reuses model, tokenizer, input_ids and attention_mask from the graph example above
writer = SummaryWriter()  # the previous writer has already been closed
with torch.no_grad():
    embeddings = model(input_ids, attention_mask=attention_mask)  # (1, seq_len, 768)
# Use the tokens as metadata; drop [CLS] and [SEP] to match the embeddings below
metadata = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())[1:-1]
embeddings = embeddings[0, 1:-1, :]  # Remove the [CLS] and [SEP] embeddings
writer.add_embedding(embeddings, metadata=metadata)
# Close the writer
writer.close()
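TensorBoard’s projector computes the projection for you in the browser. The numpy sketch below only illustrates what a PCA projection to three dimensions does to an (N, D) embedding matrix. The random 768-dimensional matrix is a hypothetical stand-in for the token hidden states logged above. pca_project is a hypothetical helper, not part of TensorBoard.

```python
import numpy as np

def pca_project(embeddings, n_components=3):
    """Project the rows of an (N, D) matrix onto its top principal components."""
    centered = embeddings - embeddings.mean(axis=0)
    # Rows of vt are the principal directions, sorted by decreasing singular value.
    _, singular_values, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:n_components].T, singular_values

rng = np.random.default_rng(0)
fake_embeddings = rng.normal(size=(5, 768))  # hypothetical stand-in for 5 token embeddings
coords, singular_values = pca_project(fake_embeddings)
print(coords.shape)  # (5, 3)
```

The first axis of coords carries the most variance, the second the next most, and so on. Distances in the 3-D view are therefore only an approximation of distances in the original 768-dimensional space.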
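- PR curves: Precision-recall curves help you understand classification performance across decision thresholds.
# Assumes a binary classifier that outputs one logit per example; input_data and
# true_labels (0/1) are placeholders. add_pr_curve expects probabilities in [0, 1].
probs = torch.sigmoid(model(input_data)).squeeze(-1)
writer.add_pr_curve('pr_curve', true_labels, probs, epoch)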
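Conceptually, each point on a PR curve is the precision and recall obtained by thresholding the predicted probabilities. TensorBoard computes these points itself, and add_pr_curve has a num_thresholds parameter for this. The sketch below computes a few points by hand. The labels, probabilities, thresholds, and the helper precision_recall_at_thresholds are hypothetical.

```python
import numpy as np

def precision_recall_at_thresholds(labels, probs, thresholds):
    """Precision and recall of the rule `prob >= t` for each threshold t."""
    labels = np.asarray(labels, dtype=bool)
    probs = np.asarray(probs, dtype=float)
    precision, recall = [], []
    for t in thresholds:
        predicted = probs >= t
        tp = np.sum(predicted & labels)
        fp = np.sum(predicted & ~labels)
        fn = np.sum(~predicted & labels)
        # Convention (an assumption): precision is 1.0 when nothing is predicted positive.
        precision.append(tp / (tp + fp) if tp + fp > 0 else 1.0)
        recall.append(tp / (tp + fn) if tp + fn > 0 else 0.0)
    return np.array(precision), np.array(recall)

labels = [1, 0, 1, 1, 0, 0]             # hypothetical ground truth
probs = [0.9, 0.8, 0.7, 0.4, 0.3, 0.1]  # hypothetical model probabilities
precision, recall = precision_recall_at_thresholds(labels, probs, [0.25, 0.5, 0.75])
print(precision, recall)  # precision 0.6, 0.667, 0.5; recall 1.0, 0.667, 0.333
```

Raising the threshold trades recall for fewer positive predictions, and the curve makes that trade-off visible across all thresholds at once.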
- Hyperparameters: Log hyperparameters together with the metrics they produced, so you can compare runs.
hparams = {'lr': 0.1, 'batch_size': 32}
metrics = {'accuracy': 0.8}
writer.add_hparams(hparams, metrics)
- Profiling: PyTorch’s torch.profiler is designed specifically to profile the execution of PyTorch models. When you profile a PyTorch model, these are the things you are typically interested in:
- Operator-level performance: Which specific operations (for example, matrix multiplications or convolutions) take the most time? How long does each operation take to execute?
- Memory consumption: Which operations consume the most memory? This is crucial for deep learning models, which are often memory-bound. Memory is only recorded when the profiler is created with profile_memory=True.
- Call stack information: Which lines in your source code correspond to the various operations? This links the profiled performance data back to specific lines of your code and requires with_stack=True.
- CPU/GPU time: How long do operations take on the CPU versus the GPU? This can help identify data transfer bottlenecks, among other things.
When the profiler’s summary table is logged to TensorBoard with writer.add_text(), you can view and analyze it alongside your other logs. This makes it easier to understand the performance characteristics of your model. It is especially valuable when you are trying to make a model run faster or diagnose performance issues.
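# dataloader and train_step are placeholders for your own data loader and
# training-step function; writer is an open SummaryWriter.
for step, (inputs, targets) in enumerate(dataloader):
    with torch.profiler.profile(with_stack=True, profile_memory=True) as prof:
        train_step(inputs, targets)
    # Log the profiling results (top 10 operators by total CPU time) to TensorBoard
    writer.add_text("Profile",
                    prof.key_averages().table(sort_by="cpu_time_total", row_limit=10),
                    step)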
- Visualizing image data: For convolutional networks or any model working with image data, visualizing the inputs or outputs can be informative.
import torch
import torchvision

# Simulated batch of 32 RGB images with values in [0, 1); writer is an open SummaryWriter
images = torch.rand(32, 3, 64, 64)
grid = torchvision.utils.make_grid(images)  # tile the batch into a single image
writer.add_image("images", grid, 0)
The end-to-end code for the TensorBoard logging discussed above is provided in the accompanying notebook.
Conclusion
In the realm of AI, where understanding is as crucial as innovation, mastering both explainability and interpretability transforms how we interact with our models. We have explored the nuances that distinguish interpretability from explainability and used Captum to dissect and illuminate transformer predictions. From model loading and input preparation to understanding the critical role of baseline tensors and applying Layer Integrated Gradients, each step sharpens our ability to demystify complex models. Visualization with tools like TensorBoard further bridges the gap between raw data and actionable insights, making AI not just a powerful tool but a transparent ally. By integrating these practices, we build a deeper understanding of our models and can check that they are not only effective but also understandable and trustworthy. This holistic approach to interpretability and visualization paves the way for smarter, more transparent systems that inspire confidence and drive innovation forward.