Skip to content

Commit

Permalink
Merge pull request #7 from ChaseMonsterAway/master
Browse files Browse the repository at this point in the history
new: Show fp fn of cls in visualization.
  • Loading branch information
hxcai authored Dec 4, 2020
2 parents b2514ed + 32f1fc2 commit a3962ba
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 24 deletions.
27 changes: 19 additions & 8 deletions volkscv/analyzer/visualization/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,27 @@ from volkscv.analyzer.visualization import visualization


def test_cls():
gt_anno = parse_data(format='image',
# txt_file='data/val.txt',
imgs_folder='data')
dt_anno = parse_data(format='image',
# txt_file='data/val.txt',
imgs_folder='data')
"""
Format of txt_file:
img_path1 dog 0.7
img_path1 cat 0.9
The first column represents the path of the image file, the second column
represents the category, the third column represents the predict score.
Note: The score is optional, if you don't have score, you don't have add it.
"""
gt_anno = parse_data(format='txt',
txt_file=r'gt_file.txt',
categories=('dog', 'cat'),
imgs_folder=r'./test_vis')
dt_anno = parse_data(format='txt',
txt_file=r'pred_file.txt',
categories=('dog', 'cat'),
imgs_folder=r'./test_vis')

vis = visualization(task='cls', gt=gt_anno, pred=dt_anno)
params = dict(save_folder='./result',
# specified_imgs='data/val.txt',
show_ori=True,
show_fpfn=True,
show_ori=False,
category_to_show=None,
show_score=True)
vis.show(**params)
Expand Down
14 changes: 6 additions & 8 deletions volkscv/analyzer/visualization/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class BaseVis(metaclass=ABCMeta):
gt (dict): Gt data for visualization.
pred (dict): Pred data for visualization.
colors (dict): Colors for visualization.
extension (str): Image extention. Default: 'jpg'.
extension (str): Image extension. Default: 'jpg'.
"""

def __init__(self,
Expand Down Expand Up @@ -53,8 +53,7 @@ def __init__(self,

self._colors = get_pallete(
self.categories) if colors is None else colors
self.fnames = [os.path.split(i)[-1] for i in self.img_names]
self.img_prefixs = [os.path.split(i)[0] for i in self.img_names]
self.fnames = self.img_names.tolist()

@property
def colors(self):
Expand Down Expand Up @@ -112,8 +111,6 @@ def show(self,
key = -1
while len(index_list):
fname = index_list[0]
fname = os.path.join(self.img_prefixs[self.fnames.index(fname)],
fname)
img, flag = self.img_process(fname, **kwargs)
if flag:
key = show_img(img)
Expand Down Expand Up @@ -147,6 +144,8 @@ def show(self,
else:
index_list.append(index_list.popleft())
current_show += 1
if current_show >= len(index_list) - 1:
break
continue

def save(self, save_folder=None, specified_imgs=None, **kwargs):
Expand Down Expand Up @@ -217,9 +216,8 @@ def get_single_data(self, fname, category_to_show):
data = {}
labels = []
for key, value in self.data.items():
fname_list = value['img_names'].tolist()
if fname in fname_list:
index = fname_list.index(fname)
if fname in self.fnames:
index = self.fnames.index(fname)
anno = {k: v[index] for k, v in value.items() if v is not None}
if isinstance(anno['labels'], np.ndarray):
labels.extend(anno['labels'].tolist())
Expand Down
8 changes: 8 additions & 0 deletions volkscv/analyzer/visualization/classification.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import lru_cache

import cv2
import numpy as np

from .base import BaseVis
from .utils import draw_image, generate_mpl_figure
Expand All @@ -12,6 +13,7 @@ class ClsVis(BaseVis):
@lru_cache(maxsize=32)
def img_process(self,
fname,
show_fpfn=False,
show_ori=False,
category_to_show=None,
show_score=True):
Expand All @@ -21,6 +23,12 @@ def img_process(self,
if not flag:
return img, flag

if show_fpfn:
assert len(data) == 2, "Show fpfn need both ground truth file" \
" and prediction file."
if data['pred']['labels'] == data['gt']['labels']:
return None, False

imgs, title = {}, self.default_title.copy()
if show_ori:
imgs.update({'ori': img})
Expand Down
17 changes: 11 additions & 6 deletions volkscv/analyzer/visualization/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,27 +26,32 @@ def save_img(save_path, img):
cv2.imwrite(save_path, img)


def get_index_list(specified_imgs, default_index_list, extention='jpg'):
def get_index_list(specified_imgs, default_index_list, extension=('all',)):
"""Get index list for image browsing.
Args:
specified_imgs (str): Images need to be viewed, folder or txt file.
default_index_list (list): Default images.
extention (str): Image extention. Default: 'jpg'.
extension (str): Image extension. Default: 'jpg'.
Returns:
(deque): Images fname list.
"""

if not isinstance(extension, tuple):
extension = (extension, )
if specified_imgs is None:
index_list = deque(default_index_list)
elif os.path.isdir(specified_imgs):
img_list = os.listdir(specified_imgs)
index_list = deque([i for i in img_list])
index_list = deque([os.path.join(specified_imgs, i) for i in img_list if i.split('.')[-1] in extension])
elif specified_imgs.endswith('txt'):
img_list, _ = read_imglist(specified_imgs)
index_list = deque(
[i if i.endswith(extention) else f'{i}.{extention}' for i in
img_list])
if 'all' in extension:
index_list = deque(img_list)
else:
index_list = deque(
[i for i in img_list if i.split('.')[-1] in extension])
else:
index_list = deque(default_index_list)
return index_list
Expand Down
2 changes: 1 addition & 1 deletion volkscv/utils/parser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class BaseParser(metaclass=ABCMeta):
xxx.jpg
xxxx.jpg
extensions (str): Image extention. Default: 'jpg'.
extensions (str): Image extension. Default: 'jpg'.
"""

def __init__(self,
Expand Down
2 changes: 1 addition & 1 deletion volkscv/utils/parser/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def read_imglist(imglist_path):

fnames, annos = [], []
with open(imglist_path, 'r') as fd:
for line in fd:
for line in fd.readlines():
ll = line.strip().split()
fnames.append(ll[0])
if len(ll) > 1:
Expand Down

0 comments on commit a3962ba

Please sign in to comment.