data.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. import copy
  2. import functools
  3. import json
  4. import os
  5. import pickle
  6. from glob import glob
  7. from typing import Tuple
  8. import torch
  9. import torch.utils.data as data
  10. from PIL import Image
  11. def pil_loader(path):
  12. # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
  13. with open(path, 'rb') as f:
  14. with Image.open(f) as img:
  15. return img.convert('RGB')
  16. def accimage_loader(path):
  17. # try:
  18. # return accimage.Image(path)
  19. # except IOError:
  20. # # Potentially a decoding problem, fall back to PIL.Image
  21. # return pil_loader(path)
  22. return pil_loader(path)
  23. def get_default_image_loader():
  24. from torchvision import get_image_backend
  25. if get_image_backend() == 'accimage':
  26. return accimage_loader
  27. else:
  28. return pil_loader
  29. def video_loader(video_dir_path, frame_indices, image_loader):
  30. video = []
  31. for i in frame_indices:
  32. image_path = os.path.join(video_dir_path, 'image_{:05d}.jpg'.format(i))
  33. if os.path.exists(image_path):
  34. video.append(image_loader(image_path))
  35. else:
  36. return video
  37. return video
  38. def get_default_video_loader():
  39. image_loader = get_default_image_loader()
  40. return functools.partial(video_loader, image_loader=image_loader)
  41. def load_annotation_data(data_file_path):
  42. with open(data_file_path, 'r') as data_file:
  43. return json.load(data_file)
  44. def get_class_labels(data):
  45. class_labels_map = {}
  46. index = 0
  47. for class_label in data['labels']:
  48. class_labels_map[class_label] = index
  49. index += 1
  50. return class_labels_map
def get_video_names_and_annotations(data, subset):
    """Collect video names (and annotations) for the entries of one subset.

    Args:
        data: Parsed annotation object; ``data['database']`` maps a video key
            to a record with at least a ``'subset'`` field.
        subset: Subset name to select; ``'testing'`` is handled specially.

    Returns:
        Tuple ``(video_names, annotations)``. Names are ``'test/<key>'`` for
        the testing subset and ``'<label>/<key>'`` otherwise.
    """
    video_names = []
    annotations = []
    for key, value in data['database'].items():
        this_subset = value['subset']
        if this_subset == subset:
            if subset == 'testing':
                # Testing entries are named without reading a label.
                video_names.append('test/{}'.format(key))
            else:
                label = value['annotations']['label']
                video_names.append('{}/{}'.format(label, key))
                # NOTE(review): the source this was recovered from had lost
                # its indentation; this append is assumed to belong to the
                # non-testing branch (so testing entries contribute no
                # annotation) -- confirm against the original file.
                annotations.append(value['annotations'])
    return video_names, annotations
  64. def make_dataset(video_path, sample_duration):
  65. dataset = []
  66. n_frames = len(os.listdir(video_path))
  67. begin_t = 1
  68. end_t = n_frames
  69. sample = {
  70. 'video': video_path,
  71. 'segment': [begin_t, end_t],
  72. 'n_frames': n_frames,
  73. }
  74. step = sample_duration
  75. for i in range(1, (n_frames - sample_duration + 1), step):
  76. sample_i = copy.deepcopy(sample)
  77. sample_i['frame_indices'] = list(range(i, i + sample_duration))
  78. sample_i['segment'] = torch.IntTensor([i, i + sample_duration - 1])
  79. dataset.append(sample_i)
  80. return dataset
  81. class Video(data.Dataset):
  82. def __init__(self, video_path,
  83. spatial_transform=None, temporal_transform=None,
  84. sample_duration=16, get_loader=get_default_video_loader):
  85. self.data = make_dataset(video_path, sample_duration)
  86. self.spatial_transform = spatial_transform
  87. self.temporal_transform = temporal_transform
  88. self.loader = get_loader()
  89. def __getitem__(self, index):
  90. """
  91. Args:
  92. index (int): Index
  93. Returns:
  94. tuple: (image, target) where target is class_index of the target class.
  95. """
  96. path = self.data[index]['video']
  97. frame_indices = self.data[index]['frame_indices']
  98. if self.temporal_transform is not None:
  99. frame_indices = self.temporal_transform(frame_indices)
  100. clip = self.loader(path, frame_indices)
  101. if self.spatial_transform is not None:
  102. clip = [self.spatial_transform(img) for img in clip]
  103. clip = torch.stack(clip, 0).permute(1, 0, 2, 3)
  104. target = self.data[index]['segment']
  105. return clip, target
  106. def __len__(self):
  107. return len(self.data)
  108. class AudioVideo(data.Dataset):
  109. def __init__(self, path: str):
  110. self.path = path
  111. self.data = []
  112. for file_path in glob(f'{path}/*.pkl'):
  113. audios, images, label = pickle.load(open(file_path, 'rb'))
  114. self.data += [(audios[i], images[i], label) for i in range(len(audios))]
  115. def __len__(self):
  116. return len(self.data)
  117. def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, int]:
  118. # output format:
  119. # return (
  120. # torch.rand((1, 96, 64)),
  121. # torch.rand((3, 224, 224)),
  122. # np.random.choice([0, 1])
  123. # )
  124. return self.data[idx]