-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
release initial version for ideal word computation
- Loading branch information
Showing
11 changed files
with
564 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
name: Ruff | ||
on: [push, pull_request] | ||
jobs: | ||
ruff: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- uses: chartboost/ruff-action@v1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
name: Unit tests

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  test:
    name: py${{ matrix.python }}-${{ matrix.os }}
    runs-on: ${{ matrix.os }}-latest
    timeout-minutes: 10
    strategy:
      matrix:
        os: [ubuntu]
        # Versions must be quoted: an unquoted 3.10 is a YAML float and is
        # read as 3.1, so the job would request Python 3.1 instead of 3.10.
        python: ['3.10', '3.11', '3.12']
        include:
          - os: macos
            python: '3.12'
          - os: windows
            python: '3.12'
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python }}
      - name: Install
        run: |
          pip install .[dev]
      - name: Tests
        run: |
          pytest -vv
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
repos: | ||
- repo: https://github.com/pre-commit/pre-commit-hooks | ||
rev: v4.4.0 | ||
hooks: | ||
- { id: check-added-large-files, args: ["--maxkb=300"] } | ||
- { id: check-case-conflict } | ||
- { id: detect-private-key } | ||
- repo: https://github.com/astral-sh/ruff-pre-commit | ||
# Ruff version. | ||
rev: v0.4.8 | ||
hooks: | ||
# Run the linter. | ||
- id: ruff | ||
args: [ --fix , --exit-non-zero-on-fix ] | ||
# Run the formatter. | ||
- id: ruff-format |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,10 @@ | ||
# ideal_words | ||
A PyTorch implementation of ideal word computation. | ||
# Ideal Words | ||
|
||
This small package provides a PyTorch implementation of ideal word computation, which was proposed by Trager et al. in the paper [Linear Spaces of Meanings: Compositional Structures in Vision-Language Models](https://arxiv.org/abs/2302.14383). Ideal words can be seen as a compositional approximation to a given set of embedding vectors. This package allows computing these ideal words given a factored set of concepts $\mathcal{Z} = \mathcal{Z}_1 \times \dots \times \mathcal{Z}_k$ (e.g., $\{\mathrm{blue}, \mathrm{red}\} \times \{\mathrm{car}, \mathrm{bike}\}$) and an embedding function $f : \mathcal{Z} \to \mathbb{R}^n$. Additionally, it allows quantifying compositionality using the ideal word, real word, and average scores from the paper (see Tables 6 and 7 for details). | ||
|
||
## Usage | ||
|
||
You can install the package using: | ||
``` | ||
pip install git+https://github.com/icetube23/ideal_words.git | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import os | ||
|
||
import torch | ||
from open_clip import create_model_and_transforms, get_tokenizer | ||
|
||
from ideal_words import FactorEmbedding, IdealWords | ||
|
||
|
||
class AttributeObjectFactorEmbedding(FactorEmbedding):
    """Prompt templates and text encoding for two-factor (attribute, object) concepts."""

    def encode_text(self, text: torch.Tensor) -> torch.Tensor:
        """Embed `text` by calling `encode_text` on the wrapped `txt_encoder`."""
        # CLIP is more than a text encoder, so the text-encoding entry point
        # has to be selected explicitly here
        encoder = self.txt_encoder
        return encoder.encode_text(text)

    def joint_repr(self, pair: tuple[str, ...]) -> str:
        """Return the caption describing one (attribute, object) combination."""
        attribute, object_name = pair
        # classic zero shot type caption of attribute-object dataset
        caption = f'an image of a {attribute} {object_name}'
        return caption

    def single_repr(self, zi: str, factor_idx: int) -> str:
        """Return the caption for a single factor value `zi`.

        The real word score needs each factor encoded on its own:
        `factor_idx` 0 treats `zi` as an attribute, 1 as an object type.

        Raises:
            IndexError: if `factor_idx` is neither 0 nor 1.
        """
        if factor_idx not in (0, 1):
            raise IndexError(f'Invalid factor index: {factor_idx}')
        if factor_idx == 0:
            # zi is an attribute
            return f'image of a {zi} object'
        # zi is an object type
        return f'image of a {zi}'
|
||
|
||
def load_factors(path: str) -> tuple[list[str], list[str]]:
    """Read the two factors of a dataset from a two-line CSV file.

    Each non-blank line of the file holds the comma-separated values of one
    factor; the first line is returned as the first factor and the second line
    as the second factor. Blank lines (e.g. a trailing newline at the end of
    the file) are ignored.

    Raises:
        ValueError: if the file does not contain exactly two non-blank lines.
    """
    # explicit encoding so the result does not depend on the platform default
    with open(path, encoding='utf-8') as f:
        rows = [line.strip().split(',') for line in f if line.strip()]
    if len(rows) != 2:
        raise ValueError(f'Expected exactly 2 factor lines in {path}, found {len(rows)}')
    first, second = rows
    return first, second


if __name__ == '__main__':
    # in the paper, they used a ViT-L-14 based CLIP model from OpenAI
    # (the image preprocessing transforms are not needed for text-only encoding)
    clip, _, _ = create_model_and_transforms('ViT-L-14', precision='fp16', pretrained='openai')
    tokenizer = get_tokenizer('ViT-L-14')
    print('Loaded CLIP model and tokenizer.')

    # load factors for each dataset from the CSV files next to this script
    dirname = os.path.dirname(os.path.abspath(__file__))
    dataset_files = {'MIT-States': 'mit-states.csv', 'UT Zappos': 'ut-zappos.csv'}
    factors = {label: load_factors(os.path.join(dirname, filename)) for label, filename in dataset_files.items()}
    print('Loaded factors for MIT-States and UT Zappos.')

    fe = AttributeObjectFactorEmbedding(clip, tokenizer, normalize=True)

    # compute ideal words and scores for every dataset, in declaration order
    results = []
    for label, dataset_factors in factors.items():
        iw = IdealWords(fe, dataset_factors, weights=None)  # weights=None uses uniform weights for all factors
        iw_score, iw_std = iw.iw_score
        rw_score, rw_std = iw.rw_score
        avg_score, avg_std = iw.avg_score
        print(f'Computed ideal words and scores for {label}.')
        results.append((label, iw_score, iw_std, rw_score, rw_std, avg_score, avg_std))

    # print table
    print()
    print(' IW RW Avg ')
    print('----------------------------------------------------')
    for label, iw_score, iw_std, rw_score, rw_std, avg_score, avg_std in results:
        print(
            f'{label} {iw_score:.2f} ± {iw_std:.2f} {rw_score:.2f} ± {rw_std:.2f} {avg_score:.2f} ± {avg_std:.2f}'
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
ancient,barren,bent,blunt,bright,broken,browned,brushed,burnt,caramelized,chipped,clean,clear,closed,cloudy,cluttered,coiled,cooked,cored,cracked,creased,crinkled,crumpled,crushed,curved,cut,damp,dark,deflated,dented,diced,dirty,draped,dry,dull,empty,engraved,eroded,fallen,filled,foggy,folded,frayed,fresh,frozen,full,grimy,heavy,huge,inflated,large,lightweight,loose,mashed,melted,modern,moldy,molten,mossy,muddy,murky,narrow,new,old,open,painted,peeled,pierced,pressed,pureed,raw,ripe,ripped,rough,ruffled,runny,rusty,scratched,sharp,shattered,shiny,short,sliced,small,smooth,spilled,splintered,squished,standing,steaming,straight,sunny,tall,thawed,thick,thin,tight,tiny,toppled,torn,unpainted,unripe,upright,verdant,viscous,weathered,wet,whipped,wide,wilted,windblown,winding,worn,wrinkled,young | ||
aluminum,animal,apple,armor,bag,ball,balloon,banana,basement,basket,bathroom,bay,beach,bean,bear,bed,beef,belt,berry,bike,blade,boat,book,bottle,boulder,bowl,box,bracelet,branch,brass,bread,bridge,bronze,bubble,bucket,building,bus,bush,butter,cabinet,cable,cake,camera,candle,candy,canyon,car,card,carpet,castle,cat,cave,ceiling,ceramic,chains,chair,cheese,chicken,chocolate,church,city,clay,cliff,clock,clothes,cloud,coal,coast,coat,coffee,coin,column,computer,concrete,cookie,copper,cord,cotton,creek,deck,desert,desk,diamond,dirt,dog,door,dress,drum,dust,eggs,elephant,envelope,fabric,fan,farm,fence,field,fig,fire,fish,flame,floor,flower,foam,forest,frame,fruit,furniture,garage,garden,garlic,gate,gear,gemstone,glass,glasses,granite,ground,handle,hat,highway,horse,hose,house,ice,iguana,island,jacket,jewelry,jungle,key,keyboard,kitchen,knife,lake,laptop,lead,leaf,lemon,library,lightbulb,lightning,log,mat,meat,metal,milk,mirror,moss,mountain,mud,necklace,nest,newspaper,nut,ocean,oil,orange,paint,palm,pants,paper,pasta,paste,pear,penny,persimmon,phone,pie,pizza,plant,plastic,plate,pond,pool,pot,potato,redwood,ribbon,ring,river,road,rock,roof,room,roots,rope,rubber,salad,salmon,sand,sandwich,sauce,screw,sea,seafood,shell,shirt,shoes,shore,shorts,shower,silk,sky,smoke,snake,snow,soup,steel,steps,stone,stream,street,sugar,sword,table,tea,thread,tie,tiger,tile,tire,tomato,tower,town,toy,trail,tree,truck,tube,tulip,vacuum,valley,vegetable,velvet,wall,water,wave,wax,well,wheel,window,wire,wood,wool |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
canvas,cotton,faux fur,faux leather,full grain leather,hair calf,leather,nubuck,nylon,patent leather,rubber,satin,sheepskin,suede,synthetic,wool | ||
boots ankle,boots knee high,boots mid-calf,sandals,shoes boat shoes,shoes clogs and mules,shoes flats,shoes heels,shoes loafers,shoes oxfords,shoes sneakers and athletic shoes,slippers |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .ideal_words import FactorEmbedding, IdealWords | ||
|
||
__all__ = ['FactorEmbedding', 'IdealWords'] |
Oops, something went wrong.