RawAttention#
- class txv.exp.RawAttention(model: Module)#
Basic attention visualization. This is a class-agnostic explanation method, so a class index cannot be passed as an argument.
- __init__(model: Module)#
- Parameters:
model (torch.nn.Module) – A model from txv.vit
Tip
Use the model with lrp=False, as LRP models have a higher memory footprint.
- explain(input: Tensor, layer: int = 0) → Tensor#
- Parameters:
input (torch.Tensor) – Input tensor
layer (int, optional) – Layer number to visualize, by default 0. Must satisfy 0 ≤ layer ≤ model.depth - 1.
- Returns:
Attention map of the specified layer, with shape (batch_size, num_heads, num_tokens, num_tokens).
- Return type:
torch.Tensor
Note
The returned map is raw: perform the necessary post-processing to visualize it, and take care when choosing between the [CLS] token and the patch tokens.
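A minimal sketch of the post-processing mentioned above, using only PyTorch. It assumes a ViT-style token layout (one [CLS] token followed by a 14 × 14 patch grid, i.e. 197 tokens) and a randomly generated tensor standing in for the output of explain(); the token counts and head count are illustrative assumptions, not part of the txv API.

```python
import torch

# Stand-in for the output of RawAttention.explain():
# (batch_size, num_heads, num_tokens, num_tokens).
# 197 tokens = 1 [CLS] token + 14 * 14 patch tokens (assumed layout).
attn = torch.rand(1, 12, 197, 197)

# Average the attention over heads.
attn_mean = attn.mean(dim=1)           # (1, 197, 197)

# Take the [CLS] token's attention to the patch tokens only
# (row 0 is [CLS]; column 0, also [CLS], is dropped).
cls_attn = attn_mean[:, 0, 1:]         # (1, 196)

# Reshape into the patch grid and normalize to [0, 1] for display.
heatmap = cls_attn.reshape(1, 14, 14)
heatmap = heatmap / heatmap.max()
```

From here, heatmap can be upsampled to the input resolution (e.g. with torch.nn.functional.interpolate) and overlaid on the image.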