# Source code for tensorlayer.files.dataset_loaders.mpii_dataset

#! /usr/bin/python
# -*- coding: utf-8 -*-

import os

from tensorlayer import logging
from tensorlayer.files.utils import (del_file, folder_exists, load_file_list, maybe_download_and_extract)

__all__ = ['load_mpii_pose_dataset']


def _mpii_maybe_download(path, url, tar_filename, extracted_name, label):
    """Download and extract ``tar_filename`` into ``path`` unless ``extracted_name`` already exists.

    The downloaded archive is deleted after extraction.
    """
    if not folder_exists(os.path.join(path, extracted_name)):
        logging.info("[MPII] ({}) {} is nonexistent in {}".format(label, extracted_name, path))
        maybe_download_and_extract(tar_filename, path, url, extract=True)
        del_file(os.path.join(path, tar_filename))


def _parse_mpii_annotations(mat_path, is_16_pos_only):
    """Parse the MPII ``.mat`` annotation file into per-image lists.

    Returns ``(img_train_list, ann_train_list, img_test_list, ann_test_list)``.
    Each annotation list holds one list of per-person dicts per image, aligned
    with the matching image-name list. Format description:
    http://human-pose.mpi-inf.mpg.de/#download
    """
    # Local import: scipy is only needed by this loader.
    import scipy.io as sio

    img_train_list, ann_train_list = [], []
    img_test_list, ann_test_list = [], []

    release = sio.loadmat(mat_path)['RELEASE']
    for anno, train_flag in zip(release['annolist'][0, 0][0], release['img_train'][0, 0][0]):
        img_fn = anno['image']['name'][0, 0][0]
        train_flag = int(train_flag)

        # Every image gets an entry (possibly empty) so both lists stay aligned.
        people = []
        if train_flag:
            img_train_list.append(img_fn)
            ann_train_list.append(people)
        else:
            img_test_list.append(img_fn)
            ann_test_list.append(people)

        annorect = anno['annorect']
        if 'annopoints' not in str(annorect.dtype):
            continue

        for annopoint, head_x1, head_y1, head_x2, head_y2 in zip(
                annorect['annopoints'][0], annorect['x1'][0], annorect['y1'][0],
                annorect['x2'][0], annorect['y2'][0]):
            # Empty array: this person has no keypoint annotation.
            if not annopoint.size:
                continue

            head_rect = [
                float(head_x1[0, 0]), float(head_y1[0, 0]),
                float(head_x2[0, 0]), float(head_y2[0, 0])
            ]

            # joint coordinates
            points = annopoint['point'][0, 0]
            j_id = [str(j_i[0, 0]) for j_i in points['id'][0]]
            joint_pos = {
                int(jid): [float(px[0, 0]), float(py[0, 0])]
                for jid, px, py in zip(j_id, points['x'][0], points['y'][0])
            }

            # visibility (keyed by joint id string); a missing value counts as 0
            if 'is_visible' in str(points.dtype):
                vis = {}
                for k, v in zip(j_id, points['is_visible'][0]):
                    row = v[0] if v.size > 0 else [0]
                    vis[k] = int(row[0]) if len(row) > 0 else 0
            else:
                vis = None

            # Optionally keep only people with all 16 keypoints annotated.
            if is_16_pos_only and len(joint_pos) != 16:
                continue

            people.append({
                'filename': img_fn,
                'train': train_flag,
                'head_rect': head_rect,
                'is_visible': vis,
                'joint_pos': joint_pos
            })

    return img_train_list, ann_train_list, img_test_list, ann_test_list


def _mpii_drop_missing_images(img_list, ann_list, available, img_dir, desc, list_tag):
    """Return copies of ``img_list``/``ann_list`` without images absent from ``available``.

    Builds new lists instead of deleting while iterating, so pairs stay aligned
    and no entry is skipped.
    """
    kept_imgs, kept_anns = [], []
    for im, ann in zip(img_list, ann_list):
        if im in available:
            kept_imgs.append(im)
            kept_anns.append(ann)
        else:
            print('missing {} image {} in {} (remove from img(ann)_{}_list)'.format(desc, im, img_dir, list_tag))
    return kept_imgs, kept_anns


def load_mpii_pose_dataset(path='data', is_16_pos_only=False):
    """Load MPII Human Pose Dataset.

    Parameters
    -----------
    path : str
        The path that the data is downloaded to.
    is_16_pos_only : boolean
        If True, only return the people that have all 16 pose keypoints annotated.
        (Usually used for single-person pose estimation.)

    Returns
    ----------
    img_train_list : list of str
        The image paths of training data.
    ann_train_list : list of list of dict
        The annotations of training data: one list of per-person dicts per
        image, aligned with ``img_train_list``.
    img_test_list : list of str
        The image paths of testing data.
    ann_test_list : list of list of dict
        The annotations of testing data, aligned with ``img_test_list``.

    Each per-person dict has the keys ``'filename'``, ``'train'``,
    ``'head_rect'`` (``[x1, y1, x2, y2]``), ``'is_visible'`` (dict keyed by
    joint-id string, or None when the annotation has no visibility field) and
    ``'joint_pos'`` (dict mapping int joint id to ``[x, y]``).

    Examples
    --------
    >>> import pprint
    >>> import tensorlayer as tl
    >>> img_train_list, ann_train_list, img_test_list, ann_test_list = tl.files.load_mpii_pose_dataset()
    >>> image = tl.vis.read_image(img_train_list[0])
    >>> tl.vis.draw_mpii_pose_to_image(image, ann_train_list[0], 'image.png')
    >>> pprint.pprint(ann_train_list[0])

    References
    -----------
    - `MPII Human Pose Dataset. CVPR 14 <http://human-pose.mpi-inf.mpg.de>`__
    - `MPII Human Pose Models. CVPR 16 <http://pose.mpi-inf.mpg.de>`__
    - `MPII Human Shape, Poselet Conditioned Pictorial Structures and etc <http://pose.mpi-inf.mpg.de/#related>`__
    - `MPII Keypoints and ID <http://human-pose.mpi-inf.mpg.de/#download>`__
    """
    path = os.path.join(path, 'mpii_human_pose')
    logging.info("Load or Download MPII Human Pose > {}".format(path))

    url = "http://datasets.d2.mpi-inf.mpg.de/andriluka14cvpr/"
    ann_dir_name = "mpii_human_pose_v1_u12_2"
    img_dir_name = "images"
    _mpii_maybe_download(path, url, "mpii_human_pose_v1_u12_2.zip", ann_dir_name, "annotation")
    _mpii_maybe_download(path, url, "mpii_human_pose_v1.tar.gz", img_dir_name, "images")

    logging.info("reading annotations from mat file ...")
    # NOTE: the annotation file read from the u12_2 folder is named u12_1.mat.
    mat_path = os.path.join(path, ann_dir_name, "mpii_human_pose_v1_u12_1.mat")
    img_train_list, ann_train_list, img_test_list, ann_test_list = _parse_mpii_annotations(
        mat_path, is_16_pos_only
    )

    # Drop annotated images that are not present on disk.
    logging.info("reading images list ...")
    img_dir = os.path.join(path, img_dir_name)
    available = set(load_file_list(path=img_dir, regx='\\.jpg', printable=False))
    img_train_list, ann_train_list = _mpii_drop_missing_images(
        img_train_list, ann_train_list, available, img_dir, 'training', 'train'
    )
    img_test_list, ann_test_list = _mpii_drop_missing_images(
        img_test_list, ann_test_list, available, img_dir, 'testing', 'test'
    )

    # Check annotations and images.
    n_train_images = len(img_train_list)
    n_test_images = len(img_test_list)
    n_images = n_train_images + n_test_images
    logging.info("n_images: {} n_train_images: {} n_test_images: {}".format(n_images, n_train_images, n_test_images))
    n_train_ann = len(ann_train_list)
    n_test_ann = len(ann_test_list)
    n_ann = n_train_ann + n_test_ann
    logging.info("n_ann: {} n_train_ann: {} n_test_ann: {}".format(n_ann, n_train_ann, n_test_ann))
    n_train_people = sum(len(people) for people in ann_train_list)
    n_test_people = sum(len(people) for people in ann_test_list)
    n_people = n_train_people + n_test_people
    logging.info("n_people: {} n_train_people: {} n_test_people: {}".format(n_people, n_train_people, n_test_people))

    # Add the image directory to every file name.
    img_train_list = [os.path.join(img_dir, name) for name in img_train_list]
    img_test_list = [os.path.join(img_dir, name) for name in img_test_list]
    return img_train_list, ann_train_list, img_test_list, ann_test_list