In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="1"
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPModel, AutoProcessor
from PIL import Image
from config import get_default_cfg
from iprm_model import IPRM_Model
import pickle
from vis_utils import *
import json
import numpy as np
[2024-03-15 04:46:30,771] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
In [3]:
image_processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
# The line below is needed for VQA settings, where the full image should be visible and not center cropped.
# Note: 'transformers.image_processing_clip' also needs to be modified to handle 'do_center_crop':
# change/override its 'resize' function so that default_to_square=True when do_center_crop is False;
# otherwise you will get non-uniform image resolutions.
image_processor.image_processor.do_center_crop = False
clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
clip_model = clip_model.to(device)
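As a reference, here is a minimal sketch of one possible workaround that avoids patching transformers directly (the helper to_square and the 224 target are illustrative assumptions, not part of this notebook's pipeline): pre-resize each PIL image to a fixed square so that, with do_center_crop disabled, every image still reaches the processor at a uniform resolution.
def to_square(pil_img, target=224, resample=Image.BICUBIC):
    # Hypothetical helper: squash the image to target x target before the CLIP processor,
    # so that disabling center crop does not yield non-uniform resolutions across images.
    return pil_img.resize((target, target), resample=resample)

# usage sketch: image_processor(images=[to_square(img)], return_tensors="pt")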
In [4]:
cfg = get_default_cfg()
config_file = "./configs/clip-vit-l-iprm.yaml"
if config_file:
    cfg.merge_from_file(config_file)
In [5]:
iprm_model = IPRM_Model(cfg).to(device)
aid2w = iprm_model.aid2w
In [6]:
with open("/data/usrdata/clip_iprm_vitl_gqa.model", 'rb') as f:
state = torch.load(f, map_location=device)
iprm_model.load_state_dict(state, strict=True)
Out[6]:
<All keys matched successfully>
In [7]:
iprm_model = iprm_model.eval()
In [8]:
questions = json.load(open("/data/usrdata/GQA/balanced_testdev_data.json",'rb'))['questions']
Example 1 (Correct model prediction with sensible intermediate attentions)¶
In [9]:
q = questions[9894]
q
Out[9]:
{'questionId': '201873667', 'group': 'existAttrOrC', 'answer': 'no', 'type': 'logical', 'semanticStr': 'select: flag (9)->filter color: red [0]->exist: ? [1]->select: kite (-) ->filter color: red [3]->exist: ? [4]->or: [2, 5]', 'fullAnswer': 'No, there is a flag but it is blue.', 'question': 'Are there any red kites or flags?', 'semantic': [{'operation': 'select', 'dependencies': [], 'argument': 'flag (9)'}, {'operation': 'filter color', 'dependencies': [0], 'argument': 'red'}, {'operation': 'exist', 'dependencies': [1], 'argument': '?'}, {'operation': 'select', 'dependencies': [], 'argument': 'kite (-) '}, {'operation': 'filter color', 'dependencies': [3], 'argument': 'red'}, {'operation': 'exist', 'dependencies': [4], 'argument': '?'}, {'operation': 'or', 'dependencies': [2, 5], 'argument': ''}], 'imageId': 'n309148'}
In [10]:
img = Image.open(f"/data/usrdata/GQA/images/{q['imageId']}.jpg")
In [11]:
img
Out[11]:
In [12]:
image_input = image_processor(images=[img], return_tensors="pt").to(device)
clip_outs = clip_model.vision_model.forward(**image_input)['last_hidden_state']
clip_outs.shape
Out[12]:
torch.Size([1, 257, 1024])
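The 257 tokens are one CLS token plus a 16x16 grid of patch tokens (ViT-L/14 at 224x224 resolution). A minimal sketch of how the 256 patch tokens map back to a spatial grid, shown for illustration only (this reshape is not used by the notebook's pipeline):
patch_tokens = clip_outs[:, 1:, :]                # drop the CLS token -> [1, 256, 1024]
patch_grid = patch_tokens.reshape(1, 16, 16, -1)  # recover the 16x16 spatial layout
patch_grid.shape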
In [13]:
og_question = q['question']
out = iprm_model(clip_outs, og_question)
In [14]:
outputs = out.argmax(dim=-1)
aid2w[int(outputs[0])] #model answer
Out[14]:
'no'
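As a quick sanity check, one can also inspect the model's top-scoring answers (a sketch, assuming out holds raw logits of shape [batch, num_answers] over the answer vocabulary, as the argmax above suggests):
probs = F.softmax(out, dim=-1)   # convert logits to answer probabilities
top5 = probs[0].topk(5)          # five highest-scoring answers for this sample
[(aid2w[int(i)], round(float(p), 3)) for p, i in zip(top5.values, top5.indices)]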
In [15]:
atts = iprm_model.vlm_module.attentions
In [16]:
b_i=0
img = img.convert('RGB') # convert to RGB in case the image has an alpha (4th) channel
vis_parallel_ops = 6
vis_iterative_steps = 8
img_att_tokens = atts['image'].squeeze(-1).transpose(0,1)[b_i,:vis_iterative_steps,:vis_parallel_ops,:]
img_att_tokens = F.softmax(img_att_tokens, dim=2).detach().cpu()
text_att_tokens = atts['text'].squeeze(-1).transpose(0,1)[b_i,:vis_iterative_steps,:vis_parallel_ops,:]
text_att_tokens = F.softmax(text_att_tokens, dim=2).detach().cpu()
mem_att_tokens = atts['inter_mem_op'].squeeze(-1).transpose(0,1)[b_i].detach().cpu()
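Shape bookkeeping (a sketch; the layout is inferred from the indexing above rather than stated by the model code): after the squeeze and transpose, the attention tensors are indexed as [batch, iterative_step, parallel_op, token], so the visualized slices should come out as [vis_iterative_steps, vis_parallel_ops, num_tokens].
# inferred-layout check (assumes the model runs at least 8 iterative steps with at least 6 parallel ops)
assert img_att_tokens.shape[:2] == (vis_iterative_steps, vis_parallel_ops)
assert text_att_tokens.shape[:2] == (vis_iterative_steps, vis_parallel_ops)
img_att_tokens.shape, text_att_tokens.shape, mem_att_tokens.shape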
In [17]:
og_question_text_list = iprm_model.tokenizer.tokenize(og_question, add_special_tokens=False, return_length=True, padding=True)
og_question_text = og_question
In [18]:
num_parallel_ops = iprm_model.vlm_module.num_parallel_ops
num_iterative_steps = iprm_model.vlm_module.num_iterative_steps
In [19]:
prev_text_att_wt = None
prev_vis_att_wt = None
img_atts_over_t = []
text_atts_over_t = []
text_atts_over_t_maxed = []
text_atts_over_t_meaned = []
img_atts_over_t_maxed = []
img_atts_over_t_meaned = []
text_atts_over_t_cumulate_ops = []
for t in range(vis_iterative_steps):
    input_tokens_text = ['Op{}'.format(i) for i in range(vis_parallel_ops)]
    inter_output_tokens_text = og_question_text_list
    output_tokens_text = []
    for tok in inter_output_tokens_text:
        if tok.startswith('Ġ'):  # strip the BPE whitespace marker so token labels read cleanly
            output_tokens_text.append(tok[1:])
        else:
            output_tokens_text.append(tok)
    attention_weights_text_before = np.array(text_att_tokens[t].transpose(0,1))
    attention_weights_text_before = attention_weights_text_before[:len(output_tokens_text), :]
    text_atts_over_t.append((attention_weights_text_before, input_tokens_text, output_tokens_text))
    text_atts_over_t_meaned.append((attention_weights_text_before.mean(axis=1, keepdims=True), ['Cuml_Op'], output_tokens_text))
    text_atts_over_t_maxed.append((attention_weights_text_before.max(axis=1, keepdims=True), ['Cuml_Op'], output_tokens_text))
    img_atts = get_img_atts(img, img_att_tokens[t], vis_parallel_ops)
    img_atts_over_t.append(img_atts)
    img_atts_over_t_meaned.append(np.array(img_atts).mean(axis=0, keepdims=True))
    img_atts_over_t_maxed.append(np.array(img_atts).max(axis=0, keepdims=True))
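For readers without access to vis_utils, here is a minimal hypothetical sketch of what a helper like get_img_atts plausibly does (a reconstruction under assumptions, not the actual implementation): reshape each operation's 256 patch attentions into the 16x16 CLIP grid and upsample to the image resolution, giving one heatmap per parallel op.
def sketch_get_img_atts(pil_img, att_tokens, num_ops, grid=16):
    # att_tokens: [num_ops, grid*grid] softmaxed patch attentions for a single iterative step
    w, h = pil_img.size
    heatmaps = []
    for op in range(num_ops):
        att = att_tokens[op].reshape(grid, grid).numpy().astype('float32')
        att_img = Image.fromarray(att, mode='F').resize((w, h), Image.BILINEAR)  # upsample to image size
        heatmaps.append(np.array(att_img))
    return heatmaps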
In [20]:
plot_lang_atts_across_times_together(img_atts_over_t, text_atts_over_t, figsize=(13,3.7))
In [21]:
plot_img_atts_across_times_together(img_atts_over_t_maxed, text_atts_over_t_maxed, figsize=(24,24),
                                    hlow=0, hmax=256, wlow=0, wmax=256,
                                    wspace=0.05, hspace=0.05)
Example 2 (Incorrect model prediction, but intermediate attentions appear sensible and the model may be referring to the ground-truth answer with a wrong/imprecise name)¶
In [22]:
q = questions[3910]
q
Out[22]:
{'questionId': '20705704', 'group': 'categoryThat', 'answer': 'monitor', 'type': 'query', 'semanticStr': 'select: device (1)->filter: on [0]->query: name [1]', 'fullAnswer': 'The device is a monitor.', 'question': 'What is the device that is on?', 'semantic': [{'operation': 'select', 'dependencies': [], 'argument': 'device (1)'}, {'operation': 'filter', 'dependencies': [0], 'argument': 'on'}, {'operation': 'query', 'dependencies': [1], 'argument': 'name'}], 'imageId': 'n264887'}
In [23]:
img = Image.open(f"/data/usrdata/GQA/images/{q['imageId']}.jpg")
In [24]:
img
Out[24]:
In [25]:
image_input = image_processor(images=[img], return_tensors="pt").to(device)
clip_outs = clip_model.vision_model.forward(**image_input)['last_hidden_state']
clip_outs.shape
Out[25]:
torch.Size([1, 257, 1024])
In [26]:
og_question = q['question']
out = iprm_model(clip_outs, og_question)
In [27]:
outputs = out.argmax(dim=-1)
aid2w[int(outputs[0])] # model prediction is 'computer' but the ground truth is 'monitor' (the visual attentions below suggest the model refers to the monitor as a 'computer')
Out[27]:
'computer'
In [28]:
atts = iprm_model.vlm_module.attentions
In [29]:
b_i=0
img = img.convert('RGB') # convert to RGB in case the image has an alpha (4th) channel
img_att_tokens = atts['image'].squeeze(-1).transpose(0,1)[b_i,:vis_iterative_steps,:vis_parallel_ops,:]
img_att_tokens = F.softmax(img_att_tokens, dim=2).detach().cpu()
text_att_tokens = atts['text'].squeeze(-1).transpose(0,1)[b_i,:vis_iterative_steps,:vis_parallel_ops,:]
text_att_tokens = F.softmax(text_att_tokens, dim=2).detach().cpu()
mem_att_tokens = atts['inter_mem_op'].squeeze(-1).transpose(0,1)[b_i].detach().cpu()
In [30]:
og_question_text_list = iprm_model.tokenizer.tokenize(og_question, add_special_tokens=False, return_length=True, padding=True)
og_question_text = og_question
In [31]:
num_parallel_ops = iprm_model.vlm_module.num_parallel_ops
num_iterative_steps = iprm_model.vlm_module.num_iterative_steps
In [32]:
prev_text_att_wt = None
prev_vis_att_wt = None
img_atts_over_t = []
text_atts_over_t = []
text_atts_over_t_maxed = []
text_atts_over_t_meaned = []
img_atts_over_t_maxed = []
img_atts_over_t_meaned = []
text_atts_over_t_cumulate_ops = []
for t in range(vis_iterative_steps):
    input_tokens_text = ['Op{}'.format(i) for i in range(vis_parallel_ops)]
    inter_output_tokens_text = og_question_text_list
    output_tokens_text = []
    for tok in inter_output_tokens_text:
        if tok.startswith('Ġ'):  # strip the BPE whitespace marker so token labels read cleanly
            output_tokens_text.append(tok[1:])
        else:
            output_tokens_text.append(tok)
    attention_weights_text_before = np.array(text_att_tokens[t].transpose(0,1))
    attention_weights_text_before = attention_weights_text_before[:len(output_tokens_text), :]
    text_atts_over_t.append((attention_weights_text_before, input_tokens_text, output_tokens_text))
    text_atts_over_t_meaned.append((attention_weights_text_before.mean(axis=1, keepdims=True), ['Cuml_Op'], output_tokens_text))
    text_atts_over_t_maxed.append((attention_weights_text_before.max(axis=1, keepdims=True), ['Cuml_Op'], output_tokens_text))
    img_atts = get_img_atts(img, img_att_tokens[t], vis_parallel_ops)
    img_atts_over_t.append(img_atts)
    img_atts_over_t_meaned.append(np.array(img_atts).mean(axis=0, keepdims=True))
    img_atts_over_t_maxed.append(np.array(img_atts).max(axis=0, keepdims=True))
In [33]:
plot_lang_atts_across_times_together(img_atts_over_t, text_atts_over_t, figsize=(13,3.7))
In [34]:
plot_img_atts_across_times_together(img_atts_over_t_maxed, text_atts_over_t_maxed, figsize=(24,24),
                                    hlow=0, hmax=256, wlow=0, wmax=256,
                                    wspace=0.05, hspace=0.05)
Example 3: Correct model prediction but intermediate visual attentions are more scattered and relatively imprecise¶
In [35]:
q = questions[8521]
q
Out[35]:
{'questionId': '201623853', 'group': 'existRelS', 'answer': 'yes', 'type': 'verify', 'semanticStr': 'select: refrigerator (0)->relate: floor,on,o (8) [0]->relate: cabinets,above,s (5) [1]->exist: ? [2]', 'fullAnswer': 'Yes, there are cabinets above the floor.', 'question': 'Do you see a cabinet above the floor the freezer is on?', 'semantic': [{'operation': 'select', 'dependencies': [], 'argument': 'refrigerator (0)'}, {'operation': 'relate', 'dependencies': [0], 'argument': 'floor,on,o (8)'}, {'operation': 'relate', 'dependencies': [1], 'argument': 'cabinets,above,s (5)'}, {'operation': 'exist', 'dependencies': [2], 'argument': '?'}], 'imageId': 'n501609'}
In [36]:
img = Image.open(f"/data/usrdata/GQA/images/{q['imageId']}.jpg")
In [37]:
img
Out[37]:
In [38]:
image_input = image_processor(images=[img], return_tensors="pt").to(device)
clip_outs = clip_model.vision_model.forward(**image_input)['last_hidden_state']
clip_outs.shape
Out[38]:
torch.Size([1, 257, 1024])
In [39]:
og_question = q['question']
out = iprm_model(clip_outs, og_question)
In [40]:
outputs = out.argmax(dim=-1)
aid2w[int(outputs[0])] # model answer is correct, though the visual attentions are less relevant / more scattered
Out[40]:
'yes'
In [41]:
atts = iprm_model.vlm_module.attentions
In [42]:
b_i=0
img = img.convert('RGB') # convert to RGB in case the image has an alpha (4th) channel
img_att_tokens = atts['image'].squeeze(-1).transpose(0,1)[b_i,:vis_iterative_steps,:vis_parallel_ops,:]
img_att_tokens = F.softmax(img_att_tokens, dim=2).detach().cpu()
text_att_tokens = atts['text'].squeeze(-1).transpose(0,1)[b_i,:vis_iterative_steps,:vis_parallel_ops,:]
text_att_tokens = F.softmax(text_att_tokens, dim=2).detach().cpu()
mem_att_tokens = atts['inter_mem_op'].squeeze(-1).transpose(0,1)[b_i].detach().cpu()
In [43]:
img_att_tokens.shape
Out[43]:
torch.Size([8, 6, 256])
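That is, 8 iterative steps by 6 parallel operations by 256 image patch tokens (the 16x16 CLIP patch grid), matching vis_iterative_steps and vis_parallel_ops above.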
In [44]:
og_question_text_list = iprm_model.tokenizer.tokenize(og_question, add_special_tokens=False, return_length=True, padding=True)
og_question_text = og_question
In [45]:
num_parallel_ops = iprm_model.vlm_module.num_parallel_ops
num_iterative_steps = iprm_model.vlm_module.num_iterative_steps
In [46]:
prev_text_att_wt = None
prev_vis_att_wt = None
img_atts_over_t = []
text_atts_over_t = []
text_atts_over_t_maxed = []
text_atts_over_t_meaned = []
img_atts_over_t_maxed = []
img_atts_over_t_meaned = []
text_atts_over_t_cumulate_ops = []
for t in range(vis_iterative_steps):
    input_tokens_text = ['Op{}'.format(i) for i in range(vis_parallel_ops)]
    inter_output_tokens_text = og_question_text_list
    output_tokens_text = []
    for tok in inter_output_tokens_text:
        if tok.startswith('Ġ'):  # strip the BPE whitespace marker so token labels read cleanly
            output_tokens_text.append(tok[1:])
        else:
            output_tokens_text.append(tok)
    attention_weights_text_before = np.array(text_att_tokens[t].transpose(0,1))
    attention_weights_text_before = attention_weights_text_before[:len(output_tokens_text), :]
    text_atts_over_t.append((attention_weights_text_before, input_tokens_text, output_tokens_text))
    text_atts_over_t_meaned.append((attention_weights_text_before.mean(axis=1, keepdims=True), ['Cuml_Op'], output_tokens_text))
    text_atts_over_t_maxed.append((attention_weights_text_before.max(axis=1, keepdims=True), ['Cuml_Op'], output_tokens_text))
    img_atts = get_img_atts(img, img_att_tokens[t], vis_parallel_ops)
    img_atts_over_t.append(img_atts)
    img_atts_over_t_meaned.append(np.array(img_atts).mean(axis=0, keepdims=True))
    img_atts_over_t_maxed.append(np.array(img_atts).max(axis=0, keepdims=True))
In [47]:
plot_lang_atts_across_times_together(img_atts_over_t, text_atts_over_t, figsize=(13,3.7))
In [48]:
plot_img_atts_across_times_together(img_atts_over_t_maxed, text_atts_over_t_maxed, figsize=(24,24),
                                    hlow=0, hmax=256, wlow=0, wmax=256,
                                    wspace=0.05, hspace=0.05)