# data.py
  1. import copy
  2. import functools
  3. import os
  4. import torch
  5. import torch.utils.data as data
  6. from PIL import Image
  7. # import accimage
  8. import json
  9. def pil_loader(path):
  10. # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
  11. with open(path, 'rb') as f:
  12. with Image.open(f) as img:
  13. return img.convert('RGB')
  14. def accimage_loader(path):
  15. # try:
  16. # return accimage.Image(path)
  17. # except IOError:
  18. # # Potentially a decoding problem, fall back to PIL.Image
  19. # return pil_loader(path)
  20. return pil_loader(path)
  21. def get_default_image_loader():
  22. from torchvision import get_image_backend
  23. if get_image_backend() == 'accimage':
  24. return accimage_loader
  25. else:
  26. return pil_loader
  27. def video_loader(video_dir_path, frame_indices, image_loader):
  28. video = []
  29. for i in frame_indices:
  30. image_path = os.path.join(video_dir_path, 'image_{:05d}.jpg'.format(i))
  31. if os.path.exists(image_path):
  32. video.append(image_loader(image_path))
  33. else:
  34. return video
  35. return video
  36. def get_default_video_loader():
  37. image_loader = get_default_image_loader()
  38. return functools.partial(video_loader, image_loader=image_loader)
  39. def load_annotation_data(data_file_path):
  40. with open(data_file_path, 'r') as data_file:
  41. return json.load(data_file)
  42. def get_class_labels(data):
  43. class_labels_map = {}
  44. index = 0
  45. for class_label in data['labels']:
  46. class_labels_map[class_label] = index
  47. index += 1
  48. return class_labels_map
  49. def get_video_names_and_annotations(data, subset):
  50. video_names = []
  51. annotations = []
  52. for key, value in data['database'].items():
  53. this_subset = value['subset']
  54. if this_subset == subset:
  55. if subset == 'testing':
  56. video_names.append('test/{}'.format(key))
  57. else:
  58. label = value['annotations']['label']
  59. video_names.append('{}/{}'.format(label, key))
  60. annotations.append(value['annotations'])
  61. return video_names, annotations
  62. def make_dataset(video_path, sample_duration):
  63. dataset = []
  64. n_frames = len(os.listdir(video_path))
  65. begin_t = 1
  66. end_t = n_frames
  67. sample = {
  68. 'video': video_path,
  69. 'segment': [begin_t, end_t],
  70. 'n_frames': n_frames,
  71. }
  72. step = sample_duration
  73. for i in range(1, (n_frames - sample_duration + 1), step):
  74. sample_i = copy.deepcopy(sample)
  75. sample_i['frame_indices'] = list(range(i, i + sample_duration))
  76. sample_i['segment'] = torch.IntTensor([i, i + sample_duration - 1])
  77. dataset.append(sample_i)
  78. return dataset
  79. class Video(data.Dataset):
  80. def __init__(self, video_path,
  81. spatial_transform=None, temporal_transform=None,
  82. sample_duration=16, get_loader=get_default_video_loader):
  83. self.data = make_dataset(video_path, sample_duration)
  84. self.spatial_transform = spatial_transform
  85. self.temporal_transform = temporal_transform
  86. self.loader = get_loader()
  87. def __getitem__(self, index):
  88. """
  89. Args:
  90. index (int): Index
  91. Returns:
  92. tuple: (image, target) where target is class_index of the target class.
  93. """
  94. path = self.data[index]['video']
  95. frame_indices = self.data[index]['frame_indices']
  96. if self.temporal_transform is not None:
  97. frame_indices = self.temporal_transform(frame_indices)
  98. clip = self.loader(path, frame_indices)
  99. if self.spatial_transform is not None:
  100. clip = [self.spatial_transform(img) for img in clip]
  101. clip = torch.stack(clip, 0).permute(1, 0, 2, 3)
  102. target = self.data[index]['segment']
  103. return clip, target
  104. def __len__(self):
  105. return len(self.data)