123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191 |
- #!/usr/bin/env python
- """
- @Contact : liuyuqi.gov@msn.cn
- @Time : 2024/03/22 08:06:24
- @License : Copyright © 2017-2022 liuyuqi. All Rights Reserved.
- @Desc : global config
- """
- import secrets
- import warnings
- from pathlib import Path
- from typing import Annotated, Any, Literal
- from pydantic import (AnyUrl, BaseModel, BeforeValidator, Field, HttpUrl,
- PostgresDsn, computed_field, model_validator)
- from pydantic_core import MultiHostUrl
- from pydantic_settings import BaseSettings, SettingsConfigDict
- from typing_extensions import Self
- def parse_cors(v: Any) -> list[str] | str:
- if isinstance(v, str) and not v.startswith("["):
- return [i.strip() for i in v.split(",")]
- elif isinstance(v, list | str):
- return v
- raise ValueError(v)
class Settings(BaseSettings):
    """Application settings read from environment variables and ``../.env``.

    The ``.env`` path is relative, so it is resolved against the process's
    working directory. Empty variables are ignored and unknown keys are
    dropped. Fields without a default (``PROJECT_NAME``, ``POSTGRES_SERVER``,
    ``POSTGRES_USER``, ``POSTGRES_PASSWORD``, ``FIRST_SUPERUSER``,
    ``FIRST_SUPERUSER_PASSWORD``) must be supplied or instantiation fails.
    """

    model_config = SettingsConfigDict(
        env_file="../.env", env_ignore_empty=True, extra="ignore"
    )

    API_V1_STR: str = "/api/v1"
    # The fallback key is generated once, when this class is defined, so it
    # changes on every restart and differs between processes; set SECRET_KEY
    # explicitly for anything beyond local development.
    SECRET_KEY: str = secrets.token_urlsafe(32)
    # 60 minutes * 24 hours * 8 days = 8 days
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8
    DOMAIN: str = "localhost"
    # "dev" is treated as local development (plain HTTP, warnings only for
    # placeholder secrets); any other value is treated as a deployment.
    ENVIRONMENT: Literal["dev", "prod"] = "dev"

    @computed_field  # type: ignore[misc]
    @property
    def server_host(self) -> str:
        """Public base URL: ``http://`` in ``dev``, ``https://`` otherwise."""
        # Compare against "dev", the development value ENVIRONMENT actually
        # accepts; a comparison with "local" can never be true.
        if self.ENVIRONMENT == "dev":
            return f"http://{self.DOMAIN}"
        return f"https://{self.DOMAIN}"

    # Accepts a list or a comma-separated string (normalised by parse_cors).
    BACKEND_CORS_ORIGINS: Annotated[
        list[AnyUrl] | str, BeforeValidator(parse_cors)
    ] = []

    PROJECT_NAME: str
    SENTRY_DSN: HttpUrl | None = None

    POSTGRES_SERVER: str
    POSTGRES_PORT: int = 5432
    POSTGRES_USER: str
    POSTGRES_PASSWORD: str
    POSTGRES_DB: str = ""

    @computed_field  # type: ignore[misc]
    @property
    def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn:
        """Database URL built from the POSTGRES_* fields (psycopg driver)."""
        return MultiHostUrl.build(
            scheme="postgresql+psycopg",
            username=self.POSTGRES_USER,
            password=self.POSTGRES_PASSWORD,
            host=self.POSTGRES_SERVER,
            port=self.POSTGRES_PORT,
            path=self.POSTGRES_DB,
        )

    SMTP_TLS: bool = True
    SMTP_SSL: bool = False
    SMTP_PORT: int = 587
    SMTP_HOST: str | None = None
    SMTP_USER: str | None = None
    SMTP_PASSWORD: str | None = None
    # TODO: update type to EmailStr when sqlmodel supports it
    EMAILS_FROM_EMAIL: str | None = None
    EMAILS_FROM_NAME: str | None = None

    @model_validator(mode="after")
    def _set_default_emails_from(self) -> Self:
        """Use PROJECT_NAME as the sender name when none is configured."""
        if not self.EMAILS_FROM_NAME:
            self.EMAILS_FROM_NAME = self.PROJECT_NAME
        return self

    EMAIL_RESET_TOKEN_EXPIRE_HOURS: int = 48

    @computed_field  # type: ignore[misc]
    @property
    def emails_enabled(self) -> bool:
        """True when both an SMTP host and a sender address are configured."""
        return bool(self.SMTP_HOST and self.EMAILS_FROM_EMAIL)

    # TODO: update type to EmailStr when sqlmodel supports it
    EMAIL_TEST_USER: str = "test@example.com"
    # TODO: update type to EmailStr when sqlmodel supports it
    FIRST_SUPERUSER: str
    FIRST_SUPERUSER_PASSWORD: str
    USERS_OPEN_REGISTRATION: bool = False

    def _check_default_secret(self, var_name: str, value: str | None) -> None:
        """Reject the placeholder value ``"changethis"`` for a secret.

        In ``dev`` only a warning is emitted. In any other environment a
        ``ValueError`` is raised, which fails model validation.
        """
        if value == "changethis":
            message = (
                f'The value of {var_name} is "changethis", '
                "for security, please change it, at least for deployments."
            )
            if self.ENVIRONMENT == "dev":
                warnings.warn(message, stacklevel=1)
            else:
                raise ValueError(message)

    @model_validator(mode="after")
    def _enforce_non_default_secrets(self) -> Self:
        """Check every secret-bearing field for the placeholder value."""
        self._check_default_secret("SECRET_KEY", self.SECRET_KEY)
        self._check_default_secret("POSTGRES_PASSWORD", self.POSTGRES_PASSWORD)
        self._check_default_secret(
            "FIRST_SUPERUSER_PASSWORD", self.FIRST_SUPERUSER_PASSWORD
        )
        return self
class AppConfig(BaseModel):
    """Filesystem locations used by the application.

    Every path hangs off ``BASE_DIR``, the directory two levels above the one
    containing this file. The sub-directories are created (if missing) while
    this class body executes, i.e. as a side effect of importing the module.
    """

    BASE_DIR: Path = Path(__file__).resolve().parent.parent.parent
    SETTINGS_DIR: Path = BASE_DIR / "settings"
    LOGS_DIR: Path = BASE_DIR / "logs"
    MODELS_DIR: Path = BASE_DIR / "models"
    # local cache directory to store images or text files
    CACHE_DIR: Path = BASE_DIR / "cache"

    # Import-time side effect: make sure each directory above exists.
    SETTINGS_DIR.mkdir(parents=True, exist_ok=True)
    LOGS_DIR.mkdir(parents=True, exist_ok=True)
    MODELS_DIR.mkdir(parents=True, exist_ok=True)
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
class GlobalConfig(BaseSettings):
    """Base settings shared by the per-environment configs.

    Every field is optional. Each one is looked up by its field name in the
    environment or in ``../.env``, which is a relative path resolved against
    the working directory. Subclasses prepend an environment-specific prefix
    via ``env_prefix``.
    """

    # Pydantic v2 reads configuration only from ``model_config`` (or the
    # deprecated ``class Config``); an inner class with any other name, such as
    # ``Setting``, is ignored. ``extra="ignore"`` mirrors ``Settings`` so
    # unrelated keys in the shared .env file do not fail validation.
    model_config = SettingsConfigDict(
        env_file="../.env", extra="ignore", from_attributes=True
    )

    APP_CONFIG: AppConfig = AppConfig()
    # Environment variable names are derived from the field names;
    # pydantic-settings v2 does not honour ``Field(env=...)``, so it is omitted.
    API_NAME: str | None = None
    API_DESCRIPTION: str | None = None
    API_VERSION: str | None = None
    API_DEBUG_MODE: bool | None = None
    ENV_STATE: str | None = None
    LOG_CONFIG_FILENAME: str | None = None
    HOST: str | None = None
    PORT: int | None = None
    LOG_LEVEL: str | None = None
    DB: str | None = None
    MOBILENET_V2: str | None = None
    INCEPTION_V3: str | None = None
class DevConfig(GlobalConfig):
    """Development settings: fields are read from ``DEV_``-prefixed variables."""

    # Configured through ``model_config`` (merged with the parent's); pydantic
    # ignores arbitrary inner classes such as ``class Setting``.
    model_config = SettingsConfigDict(env_prefix="DEV_")
class ProdConfig(GlobalConfig):
    """Production settings: fields are read from ``PROD_``-prefixed variables."""

    # Pydantic v2 style configuration (merged with the parent's), replacing
    # the deprecated inner ``class Config``.
    model_config = SettingsConfigDict(env_prefix="PROD_")
- class FactoryConfig:
- def __init__(self, env_state: str | None):
- self.env_state = env_state
- def __call__(self):
- if self.env_state == "dev":
- return DevConfig()
- elif self.env_state == "prod":
- return ProdConfig
# Module-level settings instance, built when this module is imported from the
# environment and ../.env. Importing fails if a required Settings field is
# missing or a placeholder secret is rejected outside "dev".
settings = Settings()
# Alternative (currently disabled): choose DevConfig/ProdConfig from ENV_STATE.
# settings = FactoryConfig(GlobalConfig().ENV_STATE)()
|