Implementation of the paper Understanding Retrieval Robustness for Retrieval-Augmented Image Captioning (ACL 2024).
The code was developed in Python 3.9 and is based on SmallCap.
conda create -n robcap python=3.9
conda activate robcap
pip install -r requirements.txt
Follow the data downloading and preprocessing steps as in SmallCap. Retrieved captions for COCO can be downloaded here; the file includes the top 7 most relevant captions for each image. A preprocessed, larger list (top 50) of retrieved captions can be downloaded here. You can also follow the instructions to retrieve captions yourself. To permute the order of the retrieved captions, run:
python src/preprocess/permute_retrieved_caps.py --input_file <input_file_path> --method permute --topk 4
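For reference, a minimal sketch of what this permutation step might look like is given below. It assumes the retrieved-caption file is a JSON mapping each image id to a ranked list of caption strings; the function name and file format here are assumptions, not the exact logic of permute_retrieved_caps.py.

```python
# Minimal sketch of permuting the top-k retrieved captions per image.
# Assumes the input JSON maps each image id to a ranked list of caption
# strings; the actual format expected by permute_retrieved_caps.py may differ.
import json
import random

def permute_topk(input_file: str, output_file: str, topk: int = 4, seed: int = 42) -> None:
    random.seed(seed)
    with open(input_file) as f:
        retrieved = json.load(f)  # e.g. {"139": ["caption 1", "caption 2", ...], ...}

    permuted = {}
    for image_id, captions in retrieved.items():
        top = list(captions[:topk])
        random.shuffle(top)                       # permute the order of the top-k captions
        permuted[image_id] = top + captions[topk:]

    with open(output_file, "w") as f:
        json.dump(permuted, f)

permute_topk("retrieved_caps.json", "retrieved_caps_permuted.json", topk=4)
```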
Use this version of pycocoevalcap (recommended); the SPICE model is installed automatically, and it also makes it possible to compute CLIPScore:
pip install git+https://github.com/jmhessel/pycocoevalcap
Otherwise, you can download Stanford models for computing SPICE (a slightly modified version of this repo):
./coco-caption/get_stanford_models.sh
After pycocoevalcap is installed, you can run:
python src/run_eval.py <GOLD_ANN_PATH> <PREDICTIONS_PATH>
Output results are saved in the same folder as <PREDICTIONS_PATH>.
e.g. python src/run_eval.py coco-caption/annotations/captions_valKarpathy.json baseline/rag_7M_gpt2/checkpoint-88560/val_preds_original.json
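For reference, caption evaluation with pycocoevalcap typically follows the pattern below; this is a sketch of what src/run_eval.py presumably wraps, not its exact code.

```python
# Minimal sketch of COCO-style caption evaluation with pycocoevalcap.
import sys
from pycocotools.coco import COCO
from pycocoevalcap.eval import COCOEvalCap

gold_ann_path, predictions_path = sys.argv[1], sys.argv[2]

coco = COCO(gold_ann_path)                 # gold annotations (e.g. Karpathy val split)
coco_res = coco.loadRes(predictions_path)  # model predictions in COCO results format
coco_eval = COCOEvalCap(coco, coco_res)
coco_eval.params["image_id"] = coco_res.getImgIds()  # evaluate only on predicted images
coco_eval.evaluate()

for metric, score in coco_eval.eval.items():  # BLEU, METEOR, ROUGE_L, CIDEr, SPICE, ...
    print(f"{metric}: {score:.4f}")
```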
The pretrained model is available on Hugging Face and can be loaded with:
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained('lyan62/RobustCap')
model = AutoModel.from_pretrained('lyan62/RobustCap')
model.config = config
To train (here with the retrieved-caption order set to "sample"):
export ORDER="sample"
python train.py \
--experiments_dir $EXP \
--captions_path $CAPTIONS_PATH \
--decoder_name facebook/opt-350m \
--attention_size 1.75 \
--batch_size 64 \
--n_epochs 10 \
--order $ORDER \
--k 4
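The --order flag controls how the top-k retrieved captions are ordered before being placed in the prompt. A hypothetical sketch of the ordering logic, with assumed function and option names, is:

```python
# Hypothetical sketch of how the --order option could be applied to the
# top-k retrieved captions before they are placed in the prompt; the actual
# implementation lives in the training/inference code and may differ.
import random
from typing import List

def order_retrieved_caps(captions: List[str], k: int, order: str = "default") -> List[str]:
    top = list(captions[:k])
    if order == "sample":
        random.shuffle(top)  # randomly shuffle the order of the top-k captions
    # "default" keeps the original retrieval (relevance) order
    return top

# Example: captions ranked by retrieval relevance
caps = ["a dog on a couch", "a brown dog sleeping", "a dog and a cat", "a pet on a sofa"]
print(order_retrieved_caps(caps, k=4, order="sample"))
```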
To run inference (here with the default retrieval order):
export ORDER="default"
python infer.py --model_path $MODEL_PATH --checkpoint_path checkpoint-88560 \
--decoder_name "facebook/opt-350m" \
--captions_path $CAPTIONS_PATH \
--order $ORDER \
--outfile_postfix _$ORDER
and calculate scores with:
python src/run_eval.py \
coco-caption/annotations/captions_valKarpathy.json \
$MODEL_PATH/checkpoint-88560/val_coco_preds_$ORDER.json
- src/vis contains two notebooks that visualize the attention plots: vis_attn.ipynb for decoder self-attention and vis_cross_attn.ipynb for cross-attention.
- get_attn_layer_distr.py and get_prompt_token_attn_distr.py are helper scripts that extract maximum attention scores in a layerwise or tokenwise manner for visualization (a minimal sketch of this extraction is shown below; these scripts might need to be organized later).
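As a reference for the layerwise extraction, the sketch below pulls per-layer maximum cross-attention scores using the standard transformers output_attentions mechanism. It assumes an encoder-decoder style forward pass that exposes cross_attentions; the actual scripts may differ.

```python
# Minimal sketch of extracting layerwise maximum cross-attention scores,
# in the spirit of get_attn_layer_distr.py; not the exact repo code.
import torch

@torch.no_grad()
def layerwise_max_attention(model, **inputs):
    outputs = model(**inputs, output_attentions=True)
    # cross_attentions: tuple of (batch, heads, tgt_len, src_len), one per decoder layer
    per_layer_max = []
    for layer_attn in outputs.cross_attentions:
        max_over_src = layer_attn.max(dim=-1).values   # max over source tokens -> (batch, heads, tgt_len)
        per_layer_max.append(max_over_src.mean(dim=1))  # average over heads -> (batch, tgt_len)
    return torch.stack(per_layer_max)                   # (layers, batch, tgt_len)
```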
@article{li2024understanding,
  title={Understanding Retrieval Robustness for Retrieval-Augmented Image Captioning},
  author={Li, Wenyan and Li, Jiaang and Ramos, Rita and Tang, Raphael and Elliott, Desmond},
  journal={arXiv preprint arXiv:2406.02265},
  year={2024}
}