```python
attr = LayerIntegratedGradients(
    vqa_resnet,
    [vqa_resnet.module.input_maps["v"], vqa_resnet.module.module.text.embedding],
)
```
The use of `.module.module` in your code suggests that `vqa_resnet` is wrapped inside a module container (likely using `torch.nn.DataParallel` or `torch.nn.parallel.DistributedDataParallel`), which is a common practice when working with multi-GPU setups in PyTorch. Let me break this down more clearly:

1. **`.module` in PyTorch:**
   When you use `torch.nn.DataParallel` or `torch.nn.parallel.DistributedDataParallel`, PyTorch wraps the original model (`vqa_resnet` in this case) inside a container. The container has a `.module` attribute that points to the actual model. For example:

   ```python
   model = torch.nn.DataParallel(vqa_resnet)
   # or: model = torch.nn.parallel.DistributedDataParallel(vqa_resnet)
   ```
   This means that:

   - `vqa_resnet` is now inside a `DataParallel` (or `DistributedDataParallel`) container.
   - To access the original `vqa_resnet` model, you need to use `.module`.
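
   As a minimal sketch of this unwrapping (using a stand-in `nn.Linear` model purely for illustration, not the actual VQA model):

   ```python
   import torch.nn as nn

   net = nn.Linear(10, 2)          # stand-in for the real model
   wrapped = nn.DataParallel(net)  # the container adds one level of indirection

   print(type(wrapped).__name__)   # DataParallel
   print(wrapped.module is net)    # True: .module returns the original model
   ```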
2. **`.module.module`:**

   Now, based on the code you provided:

   ```python
   vqa_resnet.module.module.text.embedding
   ```

   it suggests that the `vqa_resnet` model has been wrapped twice in a container (perhaps once by a custom wrapper inside your codebase). This would mean:
   - The first `.module` accesses the model wrapped by `DataParallel` or `DistributedDataParallel`. (Here, that inner model is Captum's `ModelInputWrapper`.)
   - The second `.module` accesses another level of encapsulation or custom module (like another wrapper or submodule) around `vqa_resnet`.
There are indeed two wrapper layers here: the first is `ModelInputWrapper(vqa_resnet)`, and the second is `torch.nn.DataParallel(vqa_resnet)` applied on top of it. Since `DataParallel` is the outer container, the first `.module` strips it off to reach the `ModelInputWrapper`, and the second `.module` reaches the original model.
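
A sketch of how these two layers of wrapping arise, following the Captum VQA tutorial setup (the `ModelInputWrapper` import path below matches recent Captum versions, and `vqa_net` stands for the already-loaded pytorch-vqa model; adjust both to your setup):

```python
import torch
from captum.attr._utils.input_layer_wrapper import ModelInputWrapper

# vqa_net: the original pytorch-vqa model (assumed already constructed/loaded)
vqa_resnet = ModelInputWrapper(vqa_net)         # layer 1: exposes each forward input as a layer
vqa_resnet = torch.nn.DataParallel(vqa_resnet)  # layer 2: multi-GPU container

# Unwrapping, from the outside in:
# vqa_resnet.module                  -> the ModelInputWrapper
# vqa_resnet.module.module           -> the original vqa_net
# vqa_resnet.module.input_maps["v"]  -> the identity layer standing in for the "v" (visual) input
```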
Consulting the pytorch-vqa source code clarifies what `text.embedding` is: `self.text` is an instance of the `TextProcessor` class, and its `embedding` is a PyTorch `nn.Embedding` layer that maps the input sequence of word indices (the question's token ids) to word vectors (embeddings).
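
For intuition, a minimal self-contained demo of what such an `nn.Embedding` layer does (the vocabulary size and dimensions here are illustrative, not pytorch-vqa's actual values):

```python
import torch
import torch.nn as nn

# toy vocabulary of 1000 tokens, each mapped to a 300-dim word vector
embedding = nn.Embedding(num_embeddings=1000, embedding_dim=300, padding_idx=0)

token_ids = torch.tensor([[4, 17, 253, 0]])  # one tokenized question; 0 is padding
vectors = embedding(token_ids)
print(vectors.shape)                         # torch.Size([1, 4, 300])
```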