data.py

import pickle
from glob import glob
from typing import List, Tuple

import torch
import torch.utils.data as data


class AV(data.Dataset):
    """Base dataset holding (audio, video, label) samples in memory."""

    def __init__(self, path: str):
        self.path = path
        self.data = []

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, int]:
        return self.data[idx]


class AudioVideo(AV):
    def __init__(self, path: str):
        # output format of each sample:
        # (
        #     torch.rand((1, 96, 64)),    # audio spectrogram
        #     torch.rand((3, 224, 224)),  # single RGB frame
        #     np.random.choice([0, 1])    # binary label
        # )
        super().__init__(path)
        for file_path in glob(f'{path}/*.pkl'):
            with open(file_path, 'rb') as f:
                audios, images, label = pickle.load(f)
            self.data += [(audios[i], images[i], label) for i in range(len(audios))]


class AudioVideo3D(AV):
    def __init__(self, path: str):
        # output format of each sample:
        # (
        #     torch.rand((1, 96, 64)),        # audio spectrogram
        #     torch.rand((3, 16, 224, 224)),  # 16-frame RGB clip
        #     np.random.choice([0, 1])        # binary label
        # )
        super().__init__(path)
        frames = 16
        for file_path in glob(f'{path}/*.pkl'):
            with open(file_path, 'rb') as f:
                audios, images, label = pickle.load(f)
            images_temporal = self._process_temporal_tensor(images, frames)
            self.data += [(audios[i], images_temporal[i], label) for i in range(len(audios))]

    @staticmethod
    def _process_temporal_tensor(images: List[torch.Tensor],
                                 frames: int) -> List[torch.Tensor]:
        out = []
        for i in range(len(images)):
            # Zero-padded buffer of `frames` slots; the last slot is the most recent.
            e = torch.zeros((frames, 3, 224, 224))
            e[-1] = images[0]
            # Fill the last min(i, frames) slots with the clip's leading frames.
            for j in range(min(i, frames)):
                e[-1 - j] = images[j]
            # Reorder to (channels, frames, height, width) for 3D convolutions.
            ee = e.permute((1, 0, 2, 3))
            out.append(ee)
        return out
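

# Usage sketch (not part of the original file): iterate over one of the datasets
# with a torch DataLoader. The directory 'data/train' below is a hypothetical
# placeholder; any folder of .pkl files matching the format above should work.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = AudioVideo3D('data/train')  # assumed location of the pickled clips
    loader = DataLoader(dataset, batch_size=8, shuffle=True)

    for audio, video, label in loader:
        # Expected batch shapes (assuming at least 8 samples):
        # audio: (8, 1, 96, 64), video: (8, 3, 16, 224, 224), label: (8,)
        print(audio.shape, video.shape, label.shape)
        break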