import copy
import functools
import json
import os

import torch
import torch.utils.data as data
from PIL import Image

def pil_loader(path):
    # Open path as a file to avoid ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835).
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert('RGB')

def accimage_loader(path):
    try:
        import accimage
        return accimage.Image(path)
    except (ImportError, IOError):
        # accimage is unavailable or hit a decoding problem;
        # fall back to PIL.Image.
        return pil_loader(path)

def get_default_image_loader():
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader
    else:
        return pil_loader

def video_loader(video_dir_path, frame_indices, image_loader):
    # Load the frames named image_00001.jpg, image_00002.jpg, ... that make
    # up one clip. If a frame is missing, return the frames read so far.
    video = []
    for i in frame_indices:
        image_path = os.path.join(video_dir_path, 'image_{:05d}.jpg'.format(i))
        if not os.path.exists(image_path):
            return video
        video.append(image_loader(image_path))
    return video

def get_default_video_loader():
    image_loader = get_default_image_loader()
    return functools.partial(video_loader, image_loader=image_loader)

def load_annotation_data(data_file_path):
    with open(data_file_path, 'r') as data_file:
        return json.load(data_file)

def get_class_labels(data):
    # Map each label name to a contiguous integer index.
    return {label: index for index, label in enumerate(data['labels'])}

def get_video_names_and_annotations(data, subset):
    video_names = []
    annotations = []
    for key, value in data['database'].items():
        this_subset = value['subset']
        if this_subset == subset:
            if subset == 'testing':
                video_names.append('test/{}'.format(key))
            else:
                label = value['annotations']['label']
                video_names.append('{}/{}'.format(label, key))
                annotations.append(value['annotations'])
    return video_names, annotations

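# The annotation JSON consumed above is expected to look roughly like the
# ActivityNet-style layout sketched below. This is inferred from the key
# accesses in this file; the concrete label and id values are placeholders:
#
# {
#     "labels": ["abseiling", "archery", ...],
#     "database": {
#         "some_video_id": {
#             "subset": "training",
#             "annotations": {"label": "archery"}
#         }
#     }
# }
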
def make_dataset(video_path, sample_duration):
    # Slice one frame directory into consecutive, non-overlapping clips of
    # sample_duration frames; a trailing remainder shorter than
    # sample_duration is dropped.
    dataset = []
    n_frames = len(os.listdir(video_path))

    begin_t = 1
    end_t = n_frames
    sample = {
        'video': video_path,
        'segment': [begin_t, end_t],
        'n_frames': n_frames,
    }

    step = sample_duration
    for i in range(1, (n_frames - sample_duration + 1), step):
        sample_i = copy.deepcopy(sample)
        sample_i['frame_indices'] = list(range(i, i + sample_duration))
        sample_i['segment'] = torch.IntTensor([i, i + sample_duration - 1])
        dataset.append(sample_i)

    return dataset

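# Worked example: with n_frames = 40 and sample_duration = 16, the loop above
# runs over range(1, 25, 16) and yields windows starting at frames 1 and 17,
# i.e. segments [1, 16] and [17, 32]; frames 33..40 are dropped because a
# full window no longer fits.
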
class Video(data.Dataset):
    def __init__(self, video_path,
                 spatial_transform=None, temporal_transform=None,
                 sample_duration=16, get_loader=get_default_video_loader):
        self.data = make_dataset(video_path, sample_duration)
        self.spatial_transform = spatial_transform
        self.temporal_transform = temporal_transform
        self.loader = get_loader()

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (clip, segment) where clip is a (channels, sample_duration,
            height, width) tensor and segment holds the [first, last] frame
            indices of the clip.
        """
        path = self.data[index]['video']
        frame_indices = self.data[index]['frame_indices']
        if self.temporal_transform is not None:
            frame_indices = self.temporal_transform(frame_indices)
        clip = self.loader(path, frame_indices)
        if self.spatial_transform is not None:
            clip = [self.spatial_transform(img) for img in clip]
        # Stack frames into (T, C, H, W), then permute to (C, T, H, W).
        clip = torch.stack(clip, 0).permute(1, 0, 2, 3)
        target = self.data[index]['segment']
        return clip, target

    def __len__(self):
        return len(self.data)
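
if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module. It assumes a
    # directory 'data/video_frames' (a placeholder path) containing frames
    # named image_00001.jpg, image_00002.jpg, ... and that torchvision is
    # installed for the spatial transforms.
    from torchvision import transforms

    spatial = transforms.Compose([
        transforms.Resize((112, 112)),
        transforms.ToTensor(),
    ])
    dataset = Video('data/video_frames',
                    spatial_transform=spatial,
                    sample_duration=16)
    loader = data.DataLoader(dataset, batch_size=4, shuffle=False)
    for clips, segments in loader:
        # clips: (batch, channels, sample_duration, height, width)
        # segments: (batch, 2) IntTensor of [first, last] frame indices
        print(clips.shape, segments)
        break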