# db_utils.py (6.9 KB)

  1. import logging
  2. from typing import Dict, List, Optional
  3. from datetime import datetime
  4. from sqlalchemy.exc import IntegrityError
  5. from models.product import Product, PriceHistory, get_session
  6. class DBUtils:
  7. def __init__(self):
  8. self.logger = logging.getLogger(__name__)
  9. self.session = get_session()
  10. def __del__(self):
  11. if hasattr(self, 'session'):
  12. self.session.close()
  13. def add_product(self, product_data: Dict) -> Optional[Product]:
  14. """
  15. 添加或更新商品信息
  16. :param product_data: 商品数据字典
  17. :return: Product对象或None
  18. """
  19. try:
  20. existing = self.session.query(Product).filter(
  21. Product.product_id == product_data.get('product_id'),
  22. Product.platform == product_data.get('platform')
  23. ).first()
  24. if existing:
  25. for key, value in product_data.items():
  26. if hasattr(existing, key) and value is not None:
  27. setattr(existing, key, value)
  28. existing.update_time = datetime.now()
  29. self.session.commit()
  30. self.logger.info(f"更新商品信息: {product_data.get('product_id')}")
  31. return existing
  32. else:
  33. product = Product(**product_data)
  34. self.session.add(product)
  35. self.session.commit()
  36. self.logger.info(f"添加新商品: {product_data.get('product_id')}")
  37. return product
  38. except IntegrityError as e:
  39. self.session.rollback()
  40. self.logger.warning(f"商品已存在,跳过: {product_data.get('product_id')} - {e}")
  41. return None
  42. except Exception as e:
  43. self.session.rollback()
  44. self.logger.error(f"添加商品失败: {e}")
  45. return None
  46. def add_products_batch(self, products_data: List[Dict]) -> int:
  47. """
  48. 批量添加商品
  49. :param products_data: 商品数据列表
  50. :return: 成功添加的数量
  51. """
  52. success_count = 0
  53. for product_data in products_data:
  54. if self.add_product(product_data):
  55. success_count += 1
  56. return success_count
  57. def add_price_history(self, price_data: Dict) -> Optional[PriceHistory]:
  58. """
  59. 添加价格历史记录
  60. :param price_data: 价格数据字典
  61. :return: PriceHistory对象或None
  62. """
  63. try:
  64. price_history = PriceHistory(**price_data)
  65. self.session.add(price_history)
  66. self.session.commit()
  67. self.logger.debug(f"添加价格历史: product_id={price_data.get('product_id')}, price={price_data.get('price')}")
  68. return price_history
  69. except Exception as e:
  70. self.session.rollback()
  71. self.logger.error(f"添加价格历史失败: {e}")
  72. return None
  73. def add_price_history_batch(self, prices_data: List[Dict]) -> int:
  74. """
  75. 批量添加价格历史
  76. :param prices_data: 价格数据列表
  77. :return: 成功添加的数量
  78. """
  79. success_count = 0
  80. for price_data in prices_data:
  81. if self.add_price_history(price_data):
  82. success_count += 1
  83. return success_count
  84. def get_product_by_id(self, product_id: str, platform: str) -> Optional[Product]:
  85. """
  86. 根据商品ID和平台获取商品
  87. :param product_id: 商品ID
  88. :param platform: 平台
  89. :return: Product对象或None
  90. """
  91. try:
  92. return self.session.query(Product).filter(
  93. Product.product_id == product_id,
  94. Product.platform == platform
  95. ).first()
  96. except Exception as e:
  97. self.logger.error(f"获取商品失败: {e}")
  98. return None
  99. def get_products_by_platform(self, platform: str, limit: int = 100) -> List[Dict]:
  100. """
  101. 获取指定平台的商品
  102. :param platform: 平台
  103. :param limit: 返回数量限制
  104. :return: 商品列表
  105. """
  106. try:
  107. products = self.session.query(Product).filter(
  108. Product.platform == platform
  109. ).order_by(Product.update_time.desc()).limit(limit).all()
  110. return [p.to_dict() for p in products]
  111. except Exception as e:
  112. self.logger.error(f"获取商品列表失败: {e}")
  113. return []
  114. def get_product_count(self, platform: str = None) -> int:
  115. """
  116. 获取商品总数
  117. :param platform: 平台,可选
  118. :return: 商品数量
  119. """
  120. try:
  121. query = self.session.query(Product)
  122. if platform:
  123. query = query.filter(Product.platform == platform)
  124. return query.count()
  125. except Exception as e:
  126. self.logger.error(f"获取商品数量失败: {e}")
  127. return 0
  128. def get_price_history(self, product_id: str, platform: str, limit: int = 30) -> List[Dict]:
  129. """
  130. 获取商品的价格历史
  131. :param product_id: 商品ID
  132. :param platform: 平台
  133. :param limit: 返回数量限制
  134. :return: 价格历史列表
  135. """
  136. try:
  137. history = self.session.query(PriceHistory).filter(
  138. PriceHistory.product_id == product_id,
  139. PriceHistory.platform == platform
  140. ).order_by(PriceHistory.crawl_time.desc()).limit(limit).all()
  141. return [h.to_dict() for h in history]
  142. except Exception as e:
  143. self.logger.error(f"获取价格历史失败: {e}")
  144. return []
  145. def get_all_products(self, limit: int = 100, offset: int = 0) -> List[Dict]:
  146. """
  147. 获取所有商品
  148. :param limit: 返回数量限制
  149. :param offset: 偏移量
  150. :return: 商品列表
  151. """
  152. try:
  153. products = self.session.query(Product).order_by(
  154. Product.update_time.desc()
  155. ).offset(offset).limit(limit).all()
  156. return [p.to_dict() for p in products]
  157. except Exception as e:
  158. self.logger.error(f"获取所有商品失败: {e}")
  159. return []
  160. def search_products(self, keyword: str, platform: str = None, limit: int = 50) -> List[Dict]:
  161. """
  162. 搜索商品
  163. :param keyword: 搜索关键词
  164. :param platform: 平台,可选
  165. :param limit: 返回数量限制
  166. :return: 商品列表
  167. """
  168. try:
  169. query = self.session.query(Product).filter(
  170. Product.name.like(f'%{keyword}%')
  171. )
  172. if platform:
  173. query = query.filter(Product.platform == platform)
  174. products = query.order_by(Product.update_time.desc()).limit(limit).all()
  175. return [p.to_dict() for p in products]
  176. except Exception as e:
  177. self.logger.error(f"搜索商品失败: {e}")
  178. return []