#!/usr/bin/env python
"""
@Contact :   liuyuqi.gov@msn.cn
@Time    :   2024/03/25 11:34:35
@License :   Copyright © 2017-2022 liuyuqi. All Rights Reserved.
@Desc    :   image inference
"""

from typing import Any

from fastapi import APIRouter, File, HTTPException, UploadFile

from apps.models.item import ItemsOut
from apps.service.image_inference import ImageClassificationService

# Router that the endpoints in this module are registered on.
router = APIRouter()

# Classification service instance created once at import time and reused by
# the request handlers below.
image_inference = ImageClassificationService()


# POST is the conventional verb for file uploads (many clients and proxies do
# not send a body with GET); the GET route is kept so existing callers still work.
@router.post("/predict", response_model=ItemsOut)
@router.get("/predict", response_model=ItemsOut)
async def predict(file: UploadFile = File(...)) -> Any:
    """
    Predict image category.

    Checks the extension of the uploaded file, reads its content and hands
    the decoded image to the module-level ``image_inference`` service.

    Args:
        file: Uploaded image. Its filename must end in ``.jpg``, ``.jpeg``
            or ``.png`` (compared case-insensitively).

    Returns:
        The value produced by ``image_inference.classify``, serialized by
        FastAPI according to ``response_model=ItemsOut``.

    Raises:
        HTTPException: 400 when the filename is missing or its extension is
            not one of the accepted image formats.
    """
    # ``filename`` may be None when the client omits it. ``rpartition`` only
    # yields a separator when a real dot is present, so a bare name such as
    # "jpg" is not mistaken for an extension.
    _, dot, suffix = (file.filename or "").rpartition(".")
    if not dot or suffix.lower() not in ("jpg", "jpeg", "png"):
        # Returning a plain string here would fail validation against
        # ``response_model`` and surface as a 500; report a client error instead.
        raise HTTPException(
            status_code=400, detail="Image must be jpg or png format!"
        )
    # NOTE(review): ``BasicImageUtils`` is not imported in the visible part of
    # this file — confirm where it lives and add the import.
    image = await BasicImageUtils.read_image_file(
        await file.read(), filename=file.filename, cache=True
    )
    image_category = await image_inference.classify(image)
    return image_category