RawAttention

class txv.exp.RawAttention(model: Module)

Basic attention visualization: returns the raw self-attention weights of a single layer. This is a class-agnostic explanation method, so it does not accept a class index as an argument.
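In a Vision Transformer, "raw attention" usually means the per-head softmax of scaled query-key dot products inside one self-attention block. The sketch below computes such a map with plain PyTorch. It illustrates the concept under stated assumptions and is not txv's implementation. The helper name `raw_attention`, the fused projection weight `w_qkv` (without bias), and the toy sizes (197 tokens = 1 [CLS] token + 14 × 14 patches) are all hypothetical.

```python
import torch


def raw_attention(tokens: torch.Tensor, w_qkv: torch.Tensor, num_heads: int) -> torch.Tensor:
    """Hypothetical helper: softmax(Q K^T / sqrt(head_dim)) for every head.

    tokens: (batch_size, num_tokens, dim) embeddings entering the block.
    w_qkv:  (dim, 3 * dim) fused query/key/value projection (bias omitted, an assumption).
    Returns a tensor of shape (batch_size, num_heads, num_tokens, num_tokens).
    """
    batch_size, num_tokens, dim = tokens.shape
    assert dim % num_heads == 0, "dim must be divisible by num_heads"
    head_dim = dim // num_heads
    qkv = tokens @ w_qkv  # (B, N, 3 * dim)
    qkv = qkv.reshape(batch_size, num_tokens, 3, num_heads, head_dim)
    q, k, _ = qkv.permute(2, 0, 3, 1, 4).unbind(0)  # each (B, H, N, head_dim)
    scores = q @ k.transpose(-2, -1) / head_dim ** 0.5  # (B, H, N, N)
    return scores.softmax(dim=-1)  # every query row sums to 1


torch.manual_seed(0)
x = torch.randn(2, 197, 64)  # toy batch: 1 [CLS] + 196 patch tokens
w = torch.randn(64, 192) * 0.05  # toy fused QKV weight
attn = raw_attention(x, w, num_heads=4)
print(attn.shape)  # torch.Size([2, 4, 197, 197])
```

The output has the same layout as the map documented for explain(): one row-stochastic num_tokens × num_tokens matrix per head and per image.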

__init__(model: Module)
Parameters:

model (torch.nn.Module) – A model from txv.vit

Tip

Use a model created with lrp=False, since LRP-enabled models have a higher memory footprint.

explain(input: Tensor, layer: int = 0) → Tensor
Parameters:
  • input (torch.Tensor) – Input tensor

  • layer (int, optional) – Index of the layer to visualize, by default 0. Must satisfy 0 ≤ layer ≤ model.depth - 1.

Returns:

Attention map of the specified layer, with dimensions (batch_size, num_heads, num_tokens, num_tokens).

Return type:

torch.Tensor
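Because layer must lie between 0 and model.depth - 1, a per-layer comparison can be sketched by calling explain() once per valid index and stacking the results. The helper `attention_all_layers` and the class `_StubExplainer` are hypothetical; the stub only stands in for a RawAttention instance so that the snippet runs without a txv model. Note that stacking every layer multiplies memory use by the depth.

```python
import torch


def attention_all_layers(explainer, x: torch.Tensor, depth: int) -> torch.Tensor:
    """Hypothetical helper: call explainer.explain(x, layer=i) for i in 0..depth-1 and stack.

    Returns a tensor of shape (depth, batch_size, num_heads, num_tokens, num_tokens).
    """
    maps = [explainer.explain(x, layer=layer) for layer in range(depth)]
    return torch.stack(maps, dim=0)


class _StubExplainer:
    """Hypothetical stand-in for RawAttention that returns random row-stochastic maps."""

    def __init__(self, num_heads: int = 3, num_tokens: int = 5):
        self.num_heads = num_heads
        self.num_tokens = num_tokens

    def explain(self, x: torch.Tensor, layer: int = 0) -> torch.Tensor:
        scores = torch.randn(x.shape[0], self.num_heads, self.num_tokens, self.num_tokens)
        return scores.softmax(dim=-1)


# With txv (untested sketch): attention_all_layers(RawAttention(model), images, model.depth)
stacked = attention_all_layers(_StubExplainer(), torch.randn(2, 3, 32, 32), depth=4)
print(stacked.shape)  # torch.Size([4, 2, 3, 5, 5])
```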

Note

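The returned map usually needs post-processing before it can be displayed, for example averaging or selecting heads, extracting the row of one query token, and reshaping the patch tokens into a spatial grid. Choose the query token deliberately: the [CLS] row shows how the classification token attends to the image patches, whereas the row of a patch token shows what that individual patch attends to.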
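One possible post-processing pipeline, sketched with plain PyTorch: average over heads, take the [CLS] row, drop the special-token columns, reshape the remaining patch scores into a square grid, upsample to the image size, and min-max scale each image. This is one common choice, not a procedure prescribed by txv. It assumes a single [CLS] token at index 0 and a square patch grid; `cls_attention_heatmap` and its `num_prefix_tokens` parameter (for models with extra special tokens) are hypothetical, and the random `fake_attn` tensor merely stands in for the output of explain().

```python
import math

import torch
import torch.nn.functional as F


def cls_attention_heatmap(attn: torch.Tensor, image_size: int, num_prefix_tokens: int = 1) -> torch.Tensor:
    """Hypothetical helper: turn a (B, H, N, N) attention map into (B, image_size, image_size) heatmaps.

    Assumes [CLS] is token 0 and tokens num_prefix_tokens.. form a square patch grid.
    Output is min-max scaled to [0, 1] per image.
    """
    attn = attn.mean(dim=1)  # average over heads -> (B, N, N)
    cls_row = attn[:, 0, num_prefix_tokens:]  # attention from [CLS] to each patch
    batch_size, num_patches = cls_row.shape
    side = math.isqrt(num_patches)
    if side * side != num_patches:
        raise ValueError("patch tokens do not form a square grid")
    grid = cls_row.reshape(batch_size, 1, side, side)
    heat = F.interpolate(grid, size=(image_size, image_size), mode="bilinear", align_corners=False)
    heat = heat.squeeze(1)  # (B, image_size, image_size)
    flat = heat.flatten(start_dim=1)
    lo = flat.min(dim=1).values.view(-1, 1, 1)
    hi = flat.max(dim=1).values.view(-1, 1, 1)
    return (heat - lo) / (hi - lo + 1e-8)  # per-image min-max scaling


torch.manual_seed(0)
fake_attn = torch.rand(2, 12, 197, 197).softmax(dim=-1)  # stand-in for explain() output
heatmap = cls_attention_heatmap(fake_attn, image_size=224)
print(heatmap.shape)  # torch.Size([2, 224, 224])
```

Averaging heads is the simplest reduction; taking the maximum over heads or inspecting a single head can reveal patterns that averaging blurs.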