# text_mongo.py (2.6 KB)
  1. """
  2. mongodb中的文本数据。
  3. """
  4. import pandas as pd
  5. from dateutil.parser import parse
  6. from pymongo import MongoClient
  7. from spiderNotices.settings import REMOTEMONGO
  8. class TextMongo(object):
  9. """" 只做数据查询。"""
  10. def __init__(self, uri=REMOTEMONGO['uri']):
  11. self.client = MongoClient(uri)
  12. # 上市公司公告的数据库
  13. self.db_notices = self.client[REMOTEMONGO['notices']]
  14. def get_notices_stk(self):
  15. """ 获取notices数据库下存在的表。"""
  16. coll_names = self.db_notices.list_collection_names(session=None)
  17. coll_names.sort()
  18. return coll_names
  19. def get_notices(self, stk_list=[], begin='', end='', columns=[]):
  20. """
  21. 从mongodb中获取数据。
  22. :param stk_list: xxxxxx.zz或xxxxxx.zzzz格式,切分后取前面数字编码。
  23. :param begin:
  24. :param end:
  25. :param columns:
  26. :return: DataFrame
  27. """
  28. # 循环股票列表
  29. stk_list = list(set(stk_list))
  30. stk_list.sort()
  31. each_list = []
  32. for stk in stk_list:
  33. each = self.get_notices_single(stk, begin=begin, end=end, columns=columns)
  34. if not each.empty:
  35. each_list.append(each)
  36. df = pd.concat(each_list).reset_index(drop=True)
  37. return df
  38. def get_notices_single(self, stk, begin='', end='', columns=[]):
  39. # 数据库表
  40. coll = self.db_notices[stk.split('.')[0]]
  41. # 查询条件
  42. query = {}
  43. if begin:
  44. begin = parse(begin)
  45. if end:
  46. end = parse(end)
  47. query['ann_date'] = {"$gte": begin, "$lte": end}
  48. else:
  49. query['ann_date'] = {"$gte": begin}
  50. else:
  51. if end:
  52. end = parse(end)
  53. query['ann_date'] = {"$lte": end}
  54. else:
  55. pass
  56. # 查询列
  57. if columns:
  58. cursor = coll.find(query, {x: 1 for x in columns}) # query为{}时,全取出
  59. else:
  60. cursor = coll.find(query)
  61. df = pd.DataFrame(list(cursor))
  62. # 整理数据
  63. if '_id' in df.columns:
  64. del df['_id']
  65. df.reset_index(drop=True, inplace=True)
  66. return df
  67. if __name__ == '__main__':
  68. # 单个获取
  69. result = TextMongo().get_notices_single('000001.SZ', '2010-01-01', '2012-12-31')
  70. result = TextMongo().get_notices_single('000001.SZ')
  71. # 多个获取
  72. result = TextMongo().get_notices(['000001.SZ', '000002.SZ'])
  73. # 遍历存有的股票
  74. result = TextMongo().get_notices_stk()