demo2.py 747 B

12345678910111213141516171819202122
  1. import gradio as gr
  2. import torch
  3. from torchvision import transforms
  4. import requests
  5. from PIL import Image
  6. model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()
  7. # Download human-readable labels for ImageNet.
  8. response = requests.get("https://git.io/JJkYN")
  9. labels = response.text.split("\n")
  10. def predict(inp):
  11. inp = Image.fromarray(inp.astype('uint8'), 'RGB')
  12. inp = transforms.ToTensor()(inp).unsqueeze(0)
  13. with torch.no_grad():
  14. prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
  15. return {labels[i]: float(prediction[i]) for i in range(1000)}
  16. inputs = gr.inputs.Image()
  17. outputs = gr.outputs.Label(num_top_classes=3)
  18. gr.Interface(fn=predict, inputs=inputs, outputs=outputs).launch()