@@ -1,14 +1,9 @@
-import copy
-import functools
-import json
-import os
 import pickle
 from glob import glob
 from typing import Tuple, List
 
 import torch
 import torch.utils.data as data
-from PIL import Image
 
 
 class AV(data.Dataset):
@@ -71,133 +66,3 @@ class AudioVideo3D(AV):
             ee = e.permute((1, 0, 2, 3))
             out.append(ee)
         return out
-
-
-def pil_loader(path):
-    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
-    with open(path, 'rb') as f:
-        with Image.open(f) as img:
-            return img.convert('RGB')
-
-
-def accimage_loader(path):
-    # try:
-    #     return accimage.Image(path)
-    # except IOError:
-    #     # Potentially a decoding problem, fall back to PIL.Image
-    #     return pil_loader(path)
-    return pil_loader(path)
-
-
-def get_default_image_loader():
-    from torchvision import get_image_backend
-    if get_image_backend() == 'accimage':
-        return accimage_loader
-    else:
-        return pil_loader
-
-
-def video_loader(video_dir_path, frame_indices, image_loader):
-    video = []
-    for i in frame_indices:
-        image_path = os.path.join(video_dir_path, 'image_{:05d}.jpg'.format(i))
-        if os.path.exists(image_path):
-            video.append(image_loader(image_path))
-        else:
-            return video
-
-    return video
-
-
-def get_default_video_loader():
-    image_loader = get_default_image_loader()
-    return functools.partial(video_loader, image_loader=image_loader)
-
-
-def load_annotation_data(data_file_path):
-    with open(data_file_path, 'r') as data_file:
-        return json.load(data_file)
-
-
-def get_class_labels(data):
-    class_labels_map = {}
-    index = 0
-    for class_label in data['labels']:
-        class_labels_map[class_label] = index
-        index += 1
-    return class_labels_map
-
-
-def get_video_names_and_annotations(data, subset):
-    video_names = []
-    annotations = []
-
-    for key, value in data['database'].items():
-        this_subset = value['subset']
-        if this_subset == subset:
-            if subset == 'testing':
-                video_names.append('test/{}'.format(key))
-            else:
-                label = value['annotations']['label']
-                video_names.append('{}/{}'.format(label, key))
-                annotations.append(value['annotations'])
-
-    return video_names, annotations
-
-
-def make_dataset(video_path, sample_duration):
-    dataset = []
-
-    n_frames = len(os.listdir(video_path))
-
-    begin_t = 1
-    end_t = n_frames
-    sample = {
-        'video': video_path,
-        'segment': [begin_t, end_t],
-        'n_frames': n_frames,
-    }
-
-    step = sample_duration
-    for i in range(1, (n_frames - sample_duration + 1), step):
-        sample_i = copy.deepcopy(sample)
-        sample_i['frame_indices'] = list(range(i, i + sample_duration))
-        sample_i['segment'] = torch.IntTensor([i, i + sample_duration - 1])
-        dataset.append(sample_i)
-
-    return dataset
-
-
-class Video(data.Dataset):
-    def __init__(self, video_path,
-                 spatial_transform=None, temporal_transform=None,
-                 sample_duration=16, get_loader=get_default_video_loader):
-        self.data = make_dataset(video_path, sample_duration)
-
-        self.spatial_transform = spatial_transform
-        self.temporal_transform = temporal_transform
-        self.loader = get_loader()
-
-    def __getitem__(self, index):
-        """
-        Args:
-            index (int): Index
-        Returns:
-            tuple: (image, target) where target is class_index of the target class.
-        """
-        path = self.data[index]['video']
-
-        frame_indices = self.data[index]['frame_indices']
-        if self.temporal_transform is not None:
-            frame_indices = self.temporal_transform(frame_indices)
-        clip = self.loader(path, frame_indices)
-        if self.spatial_transform is not None:
-            clip = [self.spatial_transform(img) for img in clip]
-        clip = torch.stack(clip, 0).permute(1, 0, 2, 3)
-
-        target = self.data[index]['segment']
-
-        return clip, target
-
-    def __len__(self):
-        return len(self.data)