config.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. #!/usr/bin/env python
  2. """
  3. @Contact : liuyuqi.gov@msn.cn
  4. @Time : 2024/03/22 08:06:24
  5. @License : Copyright © 2017-2022 liuyuqi. All Rights Reserved.
  6. @Desc : global config
  7. """
  8. import secrets
  9. import warnings
  10. from pathlib import Path
  11. from typing import Annotated, Any, Literal
  12. from pydantic import (AnyUrl, BaseModel, BeforeValidator, Field, HttpUrl,
  13. PostgresDsn, computed_field, model_validator)
  14. from pydantic_core import MultiHostUrl
  15. from pydantic_settings import BaseSettings, SettingsConfigDict
  16. from typing_extensions import Self
  17. def parse_cors(v: Any) -> list[str] | str:
  18. if isinstance(v, str) and not v.startswith("["):
  19. return [i.strip() for i in v.split(",")]
  20. elif isinstance(v, list | str):
  21. return v
  22. raise ValueError(v)
  23. class Settings(BaseSettings):
  24. model_config = SettingsConfigDict(
  25. env_file="../.env", env_ignore_empty=True, extra="ignore"
  26. )
  27. API_V1_STR: str = "/api/v1"
  28. SECRET_KEY: str = secrets.token_urlsafe(32)
  29. # 60 minutes * 24 hours * 8 days = 8 days
  30. ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8
  31. DOMAIN: str = "localhost"
  32. ENVIRONMENT: Literal["dev", "prod"] = "dev"
  33. @computed_field # type: ignore[misc]
  34. @property
  35. def server_host(self) -> str:
  36. # Use HTTPS for anything other than local development
  37. if self.ENVIRONMENT == "local":
  38. return f"http://{self.DOMAIN}"
  39. return f"https://{self.DOMAIN}"
  40. BACKEND_CORS_ORIGINS: Annotated[
  41. list[AnyUrl] | str, BeforeValidator(parse_cors)
  42. ] = []
  43. PROJECT_NAME: str
  44. SENTRY_DSN: HttpUrl | None = None
  45. POSTGRES_SERVER: str
  46. POSTGRES_PORT: int = 5432
  47. POSTGRES_USER: str
  48. POSTGRES_PASSWORD: str
  49. POSTGRES_DB: str = ""
  50. @computed_field # type: ignore[misc]
  51. @property
  52. def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn:
  53. return MultiHostUrl.build(
  54. scheme="postgresql+psycopg",
  55. username=self.POSTGRES_USER,
  56. password=self.POSTGRES_PASSWORD,
  57. host=self.POSTGRES_SERVER,
  58. port=self.POSTGRES_PORT,
  59. path=self.POSTGRES_DB,
  60. )
  61. SMTP_TLS: bool = True
  62. SMTP_SSL: bool = False
  63. SMTP_PORT: int = 587
  64. SMTP_HOST: str | None = None
  65. SMTP_USER: str | None = None
  66. SMTP_PASSWORD: str | None = None
  67. # TODO: update type to EmailStr when sqlmodel supports it
  68. EMAILS_FROM_EMAIL: str | None = None
  69. EMAILS_FROM_NAME: str | None = None
  70. @model_validator(mode="after")
  71. def _set_default_emails_from(self) -> Self:
  72. if not self.EMAILS_FROM_NAME:
  73. self.EMAILS_FROM_NAME = self.PROJECT_NAME
  74. return self
  75. EMAIL_RESET_TOKEN_EXPIRE_HOURS: int = 48
  76. @computed_field # type: ignore[misc]
  77. @property
  78. def emails_enabled(self) -> bool:
  79. return bool(self.SMTP_HOST and self.EMAILS_FROM_EMAIL)
  80. # TODO: update type to EmailStr when sqlmodel supports it
  81. EMAIL_TEST_USER: str = "test@example.com"
  82. # TODO: update type to EmailStr when sqlmodel supports it
  83. FIRST_SUPERUSER: str
  84. FIRST_SUPERUSER_PASSWORD: str
  85. USERS_OPEN_REGISTRATION: bool = False
  86. def _check_default_secret(self, var_name: str, value: str | None) -> None:
  87. if value == "changethis":
  88. message = (
  89. f'The value of {var_name} is "changethis", '
  90. "for security, please change it, at least for deployments."
  91. )
  92. if self.ENVIRONMENT == "local":
  93. warnings.warn(message, stacklevel=1)
  94. else:
  95. raise ValueError(message)
  96. @model_validator(mode="after")
  97. def _enforce_non_default_secrets(self) -> Self:
  98. self._check_default_secret("SECRET_KEY", self.SECRET_KEY)
  99. self._check_default_secret("POSTGRES_PASSWORD", self.POSTGRES_PASSWORD)
  100. self._check_default_secret(
  101. "FIRST_SUPERUSER_PASSWORD", self.FIRST_SUPERUSER_PASSWORD
  102. )
  103. return self
  104. # class Config:
  105. # env_file = "../.env"
  106. # from_attributes = True
  107. class AppConfig(BaseModel):
  108. """ """
  109. BASE_DIR: Path = Path(__file__).resolve().parent.parent.parent
  110. SETTINGS_DIR: Path = BASE_DIR.joinpath("settings")
  111. SETTINGS_DIR.mkdir(parents=True, exist_ok=True)
  112. LOGS_DIR: Path = BASE_DIR.joinpath("logs")
  113. LOGS_DIR.mkdir(parents=True, exist_ok=True)
  114. MODELS_DIR: Path = BASE_DIR.joinpath("models")
  115. MODELS_DIR.mkdir(parents=True, exist_ok=True)
  116. # local cache directory to store images or text file
  117. CACHE_DIR: Path = BASE_DIR.joinpath("cache")
  118. CACHE_DIR.mkdir(parents=True, exist_ok=True)
  119. class GlobalConfig(BaseSettings):
  120. # ENV_STATE: Optional[str] = Field(None, env="ENV_STATE")
  121. APP_CONFIG: AppConfig = AppConfig()
  122. API_NAME: str | None = Field(None, env="API_NAME")
  123. API_DESCRIPTION: str | None = Field(None, env="API_DESCRIPTION")
  124. API_VERSION: str | None = Field(None, env="API_VERSION")
  125. API_DEBUG_MODE: bool | None = Field(None, env="API_DEBUG_MODE")
  126. ENV_STATE: str | None = Field(None, env="ENV_STATE")
  127. LOG_CONFIG_FILENAME: str | None = Field(None, env="LOG_CONFIG_FILENAME")
  128. HOST: str | None = None
  129. PORT: int | None = None
  130. LOG_LEVEL: str | None = None
  131. DB: str | None = None
  132. MOBILENET_V2: str | None = None
  133. INCEPTION_V3: str | None = None
  134. class Setting:
  135. env_file: str = "../.env"
  136. from_attributes = True
  137. class DevConfig(GlobalConfig):
  138. class Setting:
  139. env_prefix: str = "DEV_"
  140. class ProdConfig(GlobalConfig):
  141. class Config:
  142. env_prefix: str = "PROD_"
  143. class FactoryConfig:
  144. def __init__(self, env_state: str | None):
  145. self.env_state = env_state
  146. def __call__(self):
  147. if self.env_state == "dev":
  148. return DevConfig()
  149. elif self.env_state == "prod":
  150. return ProdConfig
# Module-level settings object, constructed when this module is imported.
# Settings declares required fields with no default (e.g. PROJECT_NAME,
# POSTGRES_SERVER), so the import fails with a pydantic validation error if
# they are missing from the environment / .env file.
settings: Settings = Settings()
# Alternative: pick DevConfig/ProdConfig based on ENV_STATE.
# settings = FactoryConfig(GlobalConfig().ENV_STATE)()