|
@@ -1,12 +1,13 @@
|
|
|
-from typing import List, Tuple
|
|
|
+import os
|
|
|
import random
|
|
|
+from typing import List, Tuple
|
|
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
-
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
from PIL.Image import Image
|
|
|
from matplotlib.pyplot import figure, imshow, axis
|
|
|
+from pytube import YouTube
|
|
|
from torch import nn
|
|
|
|
|
|
from pipeline import BuildDataset
|
|
@@ -105,7 +106,7 @@ class Segmentor:
|
|
|
axis('off')
|
|
|
plt.show()
|
|
|
|
|
|
- def visualize_segments(self, path_video: str, n_to_show: int=10) -> None:
|
|
|
+ def visualize_segments(self, path_video: str, n_to_show: int = 10) -> None:
|
|
|
segments = self.get_segments(path_video)
|
|
|
n_segments = len(segments)
|
|
|
print(f'Found {len(segments)} segments')
|
|
@@ -122,3 +123,18 @@ class Segmentor:
|
|
|
print('Last 10')
|
|
|
self.show_images_horizontally(segment_images[-n_to_show:])
|
|
|
print('=' * 10)
|
|
|
+
|
|
|
+ def visualize_segments_youtube(self,
|
|
|
+ youtube_id: str,
|
|
|
+ n_to_show: int = 10,
|
|
|
+ show_title: bool = True,
|
|
|
+ remove_file: bool = True):
|
|
|
+ yt = YouTube(f'http://youtube.com/watch?v={youtube_id}')
|
|
|
+ if show_title:
|
|
|
+ print(f'Title: {yt.title}')
|
|
|
+ yt_stream = yt.streams.first()
|
|
|
+ path = f'{yt_stream.default_filename}'
|
|
|
+ yt_stream.download()
|
|
|
+ self.visualize_segments(path, n_to_show)
|
|
|
+ if remove_file:
|
|
|
+ os.remove(path)
|