|
|
@@ -1,3 +1,4 @@
|
|
|
+import os
|
|
|
from sqlalchemy import Column, Integer, String, Text, DateTime, Float, create_engine
|
|
|
from sqlalchemy.ext.declarative import declarative_base
|
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
@@ -101,13 +102,30 @@ class PriceHistory(Base):
|
|
|
|
|
|
|
|
|
def get_engine():
|
|
|
- db_config = DATABASE_CONFIG
|
|
|
- connection_string = (
|
|
|
- f"mysql+pymysql://{db_config['user']}:{db_config['password']}"
|
|
|
- f"@{db_config['host']}:{db_config['port']}/{db_config['database']}"
|
|
|
- f"?charset={db_config['charset']}"
|
|
|
- )
|
|
|
- return create_engine(connection_string, echo=False)
|
|
|
+ db_type = DATABASE_CONFIG.get('type', 'sqlite')
|
|
|
+
|
|
|
+ if db_type == 'sqlite':
|
|
|
+ sqlite_config = DATABASE_CONFIG.get('sqlite', {})
|
|
|
+ db_path = sqlite_config.get('path', 'data/price_crawler.db')
|
|
|
+
|
|
|
+ db_dir = os.path.dirname(db_path)
|
|
|
+ if db_dir and not os.path.exists(db_dir):
|
|
|
+ os.makedirs(db_dir, exist_ok=True)
|
|
|
+
|
|
|
+ connection_string = f"sqlite:///{db_path}"
|
|
|
+ return create_engine(connection_string, echo=False, connect_args={'check_same_thread': False})
|
|
|
+
|
|
|
+ elif db_type == 'mysql':
|
|
|
+ mysql_config = DATABASE_CONFIG.get('mysql', {})
|
|
|
+ connection_string = (
|
|
|
+ f"mysql+pymysql://{mysql_config.get('user', 'root')}:{mysql_config.get('password', '')}"
|
|
|
+ f"@{mysql_config.get('host', 'localhost')}:{mysql_config.get('port', 3306)}/{mysql_config.get('database', 'price_crawler')}"
|
|
|
+ f"?charset={mysql_config.get('charset', 'utf8mb4')}"
|
|
|
+ )
|
|
|
+ return create_engine(connection_string, echo=False)
|
|
|
+
|
|
|
+ else:
|
|
|
+ raise ValueError(f"不支持的数据库类型: {db_type}")
|
|
|
|
|
|
|
|
|
def init_db():
|