wordcount_spark.py 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770
  1. """
  2. PySpark 方式的词频统计模块
  3. 使用 PySpark 实现词频统计,这是现代大数据处理的推荐方式:
  4. - 更简洁的 API
  5. - 更好的性能
  6. - 支持更多的数据处理操作
  7. - 可以与 Spark SQL、MLlib 等集成
  8. 现代化增强:
  9. - 配置管理集成
  10. - 多种数据格式支持(JSON、CSV、Parquet 等)
  11. - 性能优化配置
  12. - 数据质量检查
  13. - 结果持久化到多种存储
  14. - 命令行工具增强
  15. 对应 Java 版本的 WordCount 类,但使用更现代的 Spark 框架。
  16. """
  17. import sys
  18. import os
  19. import json
  20. from typing import Dict, List, Optional, Tuple, Any, Union
  21. from collections import defaultdict
  22. from dataclasses import dataclass, field
  23. from enum import Enum
  24. from pathlib import Path
  25. from ..config import ConfigurationManager, SparkConfig, get_config
  26. from ..utils.helpers import setup_logger, format_file_size
  class OutputFormat(Enum):
      """Formats in which word-count results can be written.

      The member values are the same strings the ``--format`` command-line
      option accepts, so ``OutputFormat(args.format)`` resolves a member.
      """

      # Plain text output.
      TEXT = "text"
      # Structured output formats.
      JSON = "json"
      CSV = "csv"
      PARQUET = "parquet"
      ORC = "orc"
  class InputFormat(Enum):
      """Formats from which input data can be read.

      ``AUTO`` is not a real format: it asks the reader to pick one of the
      other members based on the input path.
      """

      # Plain text input.
      TEXT = "text"
      # Structured input formats.
      JSON = "json"
      CSV = "csv"
      PARQUET = "parquet"
      ORC = "orc"
      # Placeholder meaning "decide from the path".
      AUTO = "auto"
  @dataclass
  class WordCountResult:
      """Result of a word-count run plus basic size and timing metrics.

      All fields have defaults, so an empty result can be created with no
      arguments and filled in afterwards.
      """

      total_words: int = 0  # total number of counted word occurrences
      unique_words: int = 0  # number of distinct words
      top_words: List[Tuple[str, int]] = field(default_factory=list)  # (word, count) pairs
      word_counts: Dict[str, int] = field(default_factory=dict)  # word -> count
      execution_time_ms: float = 0.0
      input_size_bytes: int = 0
      output_size_bytes: int = 0
      # Where the distributed output was written, if anywhere. Declared here
      # because WordCountSpark.run() assigns this attribute on the result.
      output_path: Optional[str] = None

      @property
      def input_size_formatted(self) -> str:
          """Input size rendered by ``format_file_size``."""
          return format_file_size(self.input_size_bytes)

      @property
      def output_size_formatted(self) -> str:
          """Output size rendered by ``format_file_size``."""
          return format_file_size(self.output_size_bytes)

      def to_dict(self) -> Dict[str, Any]:
          """Return the result as a plain dictionary.

          ``top_words`` is expanded into a list of ``{'word', 'count'}``
          dictionaries; the formatted sizes are included next to the raw ones.
          """
          return {
              'total_words': self.total_words,
              'unique_words': self.unique_words,
              'top_words': [{'word': w, 'count': c} for w, c in self.top_words],
              'word_counts': self.word_counts,
              'execution_time_ms': self.execution_time_ms,
              'input_size_bytes': self.input_size_bytes,
              'input_size_formatted': self.input_size_formatted,
              'output_size_bytes': self.output_size_bytes,
              'output_size_formatted': self.output_size_formatted,
              'output_path': self.output_path,
          }

      def to_json(self, indent: int = 2) -> str:
          """Serialize ``to_dict()`` as JSON, keeping non-ASCII words readable."""
          return json.dumps(self.to_dict(), indent=indent, ensure_ascii=False)

      def save_to_file(self, file_path: str, format: OutputFormat = OutputFormat.JSON):
          """Write the result to a local file.

          Args:
              file_path: Destination path; an existing file is overwritten.
              format: ``JSON`` writes ``to_json()``; ``CSV`` writes a
                  ``word,count`` header followed by one row per word;
                  ``TEXT`` writes ``word<TAB>count`` lines. CSV and TEXT rows
                  are sorted by word.

          Raises:
              ValueError: If ``format`` is not JSON, CSV or TEXT. Nothing is
                  written in that case.
          """
          import csv

          if format == OutputFormat.JSON:
              with open(file_path, 'w', encoding='utf-8') as f:
                  f.write(self.to_json())
          elif format == OutputFormat.CSV:
              # csv.writer quotes words containing commas or quotes, which the
              # previous hand-built rows did not. newline='' is what the csv
              # module expects from the underlying file object.
              with open(file_path, 'w', encoding='utf-8', newline='') as f:
                  writer = csv.writer(f, lineterminator='\n')
                  writer.writerow(("word", "count"))
                  writer.writerows(sorted(self.word_counts.items()))
          elif format == OutputFormat.TEXT:
              with open(file_path, 'w', encoding='utf-8') as f:
                  f.writelines(
                      f"{word}\t{count}\n"
                      for word, count in sorted(self.word_counts.items())
                  )
          else:
              # Previously these formats were accepted and silently ignored,
              # leaving the caller without any output file.
              raise ValueError(
                  f"Unsupported local output format: {format!r} "
                  "(supported: JSON, CSV, TEXT)"
              )
  90. class WordCountSpark:
  91. """
  92. 现代化 PySpark 词频统计类
  93. 特性:
  94. - 配置管理集成
  95. - 多种输入输出格式支持
  96. - 性能优化配置
  97. - 数据质量检查
  98. - 详细的统计信息
  99. - 同步和异步 API
  100. """
  def __init__(self,
               config: Optional[SparkConfig] = None,
               config_manager: Optional[ConfigurationManager] = None,
               app_name: Optional[str] = None,
               master: Optional[str] = None,
               logger_name: str = 'wordcount_spark'):
      """
      Set up the instance without starting Spark.

      The Spark session itself is created lazily on first use of
      ``self.spark`` / ``self.sc``.

      Args:
          config: Spark configuration (optional). Defaults to
              ``config_manager.spark``.
          config_manager: Configuration manager (optional). Defaults to the
              result of ``get_config()``.
          app_name: Spark application name (optional); overrides the value
              in the configuration.
          master: Spark master URL (optional); overrides the value in the
              configuration.
          logger_name: Name passed to ``setup_logger``.
      """
      import copy

      self.logger = setup_logger(logger_name)

      # Resolve the configuration to use.
      if config_manager is None:
          config_manager = get_config()
      if config is None:
          config = config_manager.spark

      # Apply overrides on a private shallow copy: the resolved object may be
      # shared (it can come straight from the configuration manager), and
      # writing into it would change the settings seen by every other user.
      if app_name or master:
          config = copy.copy(config)
          if app_name:
              config.app_name = app_name
          if master:
              config.master = master

      self.config = config
      self._spark = None
      self._sc = None
  130. @property
  131. def spark(self):
  132. """获取 SparkSession 实例(延迟初始化)"""
  133. if self._spark is None:
  134. self._init_spark()
  135. return self._spark
  136. @property
  137. def sc(self):
  138. """获取 SparkContext 实例"""
  139. if self._sc is None:
  140. self._init_spark()
  141. return self._sc
  def _init_spark(self):
      """Create the Spark session described by ``self.config``.

      Builds a ``SparkConf`` from the configuration, obtains a session with
      ``getOrCreate()``, stores it in ``self._spark`` / ``self._sc`` and
      applies the configured log level.

      Raises:
          ImportError: If PySpark cannot be imported (logged, then re-raised).
          Exception: Any other initialization failure (logged, then re-raised).
      """
      try:
          from pyspark.sql import SparkSession
          from pyspark import SparkConf

          cfg = self.config

          # Build the configuration.
          conf = SparkConf()
          conf.setAppName(cfg.app_name)
          if cfg.master:
              conf.setMaster(cfg.master)

          settings = [
              ("spark.driver.memory", cfg.driver_memory),
              ("spark.executor.memory", cfg.executor_memory),
              ("spark.executor.cores", str(cfg.executor_cores)),
              ("spark.executor.instances", str(cfg.num_executors)),
              ("spark.sql.shuffle.partitions", str(cfg.shuffle_partitions)),
              ("spark.serializer", cfg.serializer),
              ("spark.kryo.registrationRequired",
               str(cfg.kryo_registration_required).lower()),
          ]
          if cfg.default_parallelism:
              # Fixed: this used to read the non-existent attribute
              # `default_par_par` and raised AttributeError.
              settings.append(
                  ("spark.default.parallelism", str(cfg.default_parallelism))
              )
          # Extra settings go last, so they are applied after the ones above.
          settings.extend(cfg.extra_configs.items())

          for key, value in settings:
              conf.set(key, value)

          # Create the session.
          builder = SparkSession.builder.config(conf=conf)
          self._spark = builder.getOrCreate()
          self._sc = self._spark.sparkContext

          # Set the log level.
          self._sc.setLogLevel(cfg.log_level)

          self.logger.info(f"Spark session initialized: {cfg.app_name}")
          self.logger.info(f"Spark master: {self._sc.master}")
          self.logger.info(f"Spark version: {self._sc.version}")
      except ImportError as e:
          self.logger.error(f"PySpark is not installed: {e}")
          raise
      except Exception as e:
          self.logger.error(f"Failed to initialize Spark: {e}")
          raise
  180. def stop(self):
  181. """停止 Spark 会话"""
  182. if self._spark:
  183. self._spark.stop()
  184. self._spark = None
  185. self._sc = None
  186. self.logger.info("Spark session stopped")
  def _infer_input_format(self, path: str) -> InputFormat:
      """Guess the input format from the path's extension.

      The comparison is case-insensitive. Paths with an unrecognized
      extension are treated as plain text.
      """
      lowered = path.lower()
      # Checked in order; the first matching suffix group wins.
      suffix_table = (
          (('.json', '.jsonl'), InputFormat.JSON),
          (('.csv',), InputFormat.CSV),
          (('.parquet',), InputFormat.PARQUET),
          (('.orc',), InputFormat.ORC),
      )
      for suffixes, detected in suffix_table:
          if lowered.endswith(suffixes):
              return detected
      return InputFormat.TEXT
  def _read_input(self, path: str, input_format: InputFormat = InputFormat.AUTO,
                  text_column: str = 'value') -> Any:
      """
      Load the input through the Spark session's reader.

      Args:
          path: Input path.
          input_format: Input format; ``AUTO`` is resolved from the path
              extension first.
          text_column: Name of the text column for structured formats.
              Accepted for interface compatibility; this method does not
              use it.

      Returns:
          Whatever the matching ``spark.read`` method returns.
      """
      if input_format == InputFormat.AUTO:
          input_format = self._infer_input_format(path)

      self.logger.info(f"Reading input from {path} with format {input_format.value}")

      reader = self.spark.read
      if input_format == InputFormat.CSV:
          # CSV is the only format read with extra options.
          return reader.csv(path, header=True, inferSchema=True)

      loaders = {
          InputFormat.JSON: reader.json,
          InputFormat.PARQUET: reader.parquet,
          InputFormat.ORC: reader.orc,
      }
      # Anything not listed above is read as plain text.
      load = loaders.get(input_format, reader.text)
      return load(path)
  225. def _split_line(self, line: str) -> List[str]:
  226. """
  227. 分割一行文本为单词列表
  228. Args:
  229. line: 输入文本行
  230. Returns:
  231. 单词列表
  232. """
  233. words = []
  234. # 分割文本为单词(使用空格、制表符等分隔符)
  235. raw_words = line.strip().split()
  236. for word in raw_words:
  237. # 清理单词(移除标点符号,转为小写)
  238. word = word.strip('.,!?;:()[]{}"\'').lower()
  239. if word: # 确保单词非空
  240. words.append(word)
  241. return words
  def count_words_from_rdd(self, text_rdd) -> Dict[str, int]:
      """
      Count words in an RDD of text lines.

      Mirrors the classic WordCount pipeline with RDD operators:
      split into words, map to ``(word, 1)``, reduce by key, collect.

      Args:
          text_rdd: RDD whose elements are lines of text. Empty or ``None``
              elements contribute no words.

      Returns:
          Dictionary mapping each word to its count, collected on the driver.
      """
      # The tokenizer is a local function on purpose. Passing the bound
      # method `self._split_line` to flatMap would make Spark serialize
      # `self`, which holds the SparkSession/SparkContext, and Spark refuses
      # to ship those to executors. The logic matches `_split_line`.
      edge_punctuation = '.,!?;:()[]{}"\''

      def tokenize(line):
          if not line:
              return []
          normalized = (
              token.strip(edge_punctuation).lower()
              for token in line.strip().split()
          )
          return [token for token in normalized if token]

      counts_rdd = (
          text_rdd.flatMap(tokenize)                       # 1. lines -> words
          .map(lambda word: (word, 1))                     # 2. word -> (word, 1)
          .reduceByKey(lambda left, right: left + right)   # 3. sum per word
      )
      # 4. Bring the result back to the driver.
      return dict(counts_rdd.collectAsMap())
  def count_words_from_dataframe(self, df, text_column: str = 'value',
                                 stop_words: Optional[List[str]] = None,
                                 min_word_length: int = 1,
                                 max_word_length: int = 100) -> Dict[str, int]:
      """
      Count words in a DataFrame column using DataFrame operations.

      Args:
          df: DataFrame containing the text.
          text_column: Name of the column holding the text.
          stop_words: Words to exclude (optional). Compared against the
              lower-cased words exactly as given.
          min_word_length: Minimum word length (inclusive).
          max_word_length: Maximum word length (inclusive).

      Returns:
          Dictionary mapping each word to its count, collected on the driver.
      """
      from pyspark.sql.functions import (
          col, count, explode, length, lower, regexp_replace, split, trim
      )

      # 1. Clean the text: everything except ASCII letters, digits and
      #    whitespace becomes a space, then trim and lower-case.
      #    NOTE(review): this also removes non-ASCII letters, unlike the RDD
      #    tokenizer - confirm that is intended.
      df_clean = df.withColumn(
          'clean_text',
          lower(trim(regexp_replace(col(text_column), '[^a-zA-Z0-9\\s]', ' ')))
      )

      # 2. One row per word.
      df_words = df_clean.withColumn(
          'word',
          explode(split(col('clean_text'), '\\s+'))
      )

      # 3. Drop empty words and enforce the length limits.
      df_filtered = df_words.filter(
          (col('word') != '') &
          (length(col('word')) >= min_word_length) &
          (length(col('word')) <= max_word_length)
      )

      # 4. Drop stop words. A plain `isin` filter is all that is needed; the
      #    broadcast variable and UDF that used to be built here were never
      #    applied (and the broadcast was never released).
      if stop_words:
          df_filtered = df_filtered.filter(~col('word').isin(list(stop_words)))

      # 5. Count per word.
      df_counts = df_filtered.groupBy('word').agg(count('*').alias('count'))

      # 6. Collect to the driver.
      return {row['word']: row['count'] for row in df_counts.collect()}
  def run(self,
          input_path: str,
          output_path: Optional[str] = None,
          output_format: OutputFormat = OutputFormat.TEXT,
          input_format: InputFormat = InputFormat.AUTO,
          use_dataframe: bool = True,
          text_column: str = 'value',
          stop_words: Optional[List[str]] = None,
          min_word_length: int = 1,
          save_local_result: bool = False,
          local_result_path: Optional[str] = None) -> WordCountResult:
      """
      Run a complete word-count job.

      Args:
          input_path: Input path (local file path or HDFS path).
          output_path: Output path (optional); when given, the counts are
              written there in ``output_format``.
          output_format: Format used for ``output_path``.
          input_format: Input format.
          use_dataframe: Use the DataFrame API when True, the RDD API otherwise.
          text_column: Name of the text column.
          stop_words: Words to exclude (optional).
          min_word_length: Minimum word length.
          save_local_result: Whether to also save the result as local JSON.
          local_result_path: Path of the local JSON file (optional).

      Returns:
          A WordCountResult with counts, the top 100 words and timing.
      """
      import time

      start_time = time.time()
      self.logger.info(f"Running WordCount job on: {input_path}")

      # Read the input.
      df = self._read_input(input_path, input_format, text_column)

      # Count the words.
      if use_dataframe:
          result = self.count_words_from_dataframe(
              df, text_column, stop_words, min_word_length
          )
      else:
          # Null rows carry no text; drop them before tokenizing.
          text_rdd = (
              df.select(text_column).rdd
              .map(lambda row: row[0])
              .filter(lambda line: line is not None)
          )
          result = self.count_words_from_rdd(text_rdd)
          # The RDD counter takes no filter options, so apply them here;
          # otherwise stop_words / min_word_length would be silently ignored
          # on this path.
          excluded = set(stop_words) if stop_words else set()
          result = {
              word: count for word, count in result.items()
              if len(word) >= min_word_length and word not in excluded
          }

      # Statistics.
      execution_time_ms = (time.time() - start_time) * 1000
      total_words = sum(result.values())
      unique_words = len(result)
      top_words = sorted(result.items(), key=lambda x: x[1], reverse=True)[:100]
      # Only a plain local file has a size we can read here; any other kind
      # of path is reported as 0.
      input_size = os.path.getsize(input_path) if os.path.isfile(input_path) else 0

      wc_result = WordCountResult(
          total_words=total_words,
          unique_words=unique_words,
          top_words=top_words,
          word_counts=result,
          execution_time_ms=execution_time_ms,
          input_size_bytes=input_size,
      )

      # Save the distributed output, if requested.
      if output_path:
          self._save_result_to_hdfs(result, output_path, output_format)
          wc_result.output_path = output_path

      # Save the local JSON copy, if requested.
      if save_local_result:
          if local_result_path:
              wc_result.save_to_file(local_result_path, OutputFormat.JSON)
          else:
              self.logger.warning(
                  "save_local_result is set but local_result_path is empty; "
                  "skipping local save"
              )

      self._print_statistics(wc_result)
      return wc_result
  def _save_result_to_hdfs(self, result: Dict[str, int],
                           output_path: str,
                           output_format: OutputFormat):
      """
      Write the word counts to ``output_path`` through Spark.

      Existing output at the path is overwritten.

      Args:
          result: Dictionary mapping words to counts.
          output_path: Output path.
          output_format: Output format; anything other than JSON, CSV,
              PARQUET or ORC is written as ``word<TAB>count`` text lines.
      """
      from pyspark.sql.functions import col, concat_ws
      from pyspark.sql.types import LongType, StringType, StructField, StructType

      self.logger.info(f"Saving results to HDFS: {output_path} (format: {output_format.value})")

      # An explicit schema keeps the column order fixed and lets an empty
      # result be written; schema inference fails on an empty list.
      schema = StructType([
          StructField("word", StringType()),
          StructField("count", LongType()),
      ])
      df = self.spark.createDataFrame(sorted(result.items()), schema=schema)

      # Set the save mode on the writer itself: DataFrameWriter.text() has no
      # `mode` parameter, so passing it there raised TypeError.
      writer = df.write.mode('overwrite')
      if output_format == OutputFormat.JSON:
          writer.json(output_path)
      elif output_format == OutputFormat.CSV:
          writer.csv(output_path, header=True)
      elif output_format == OutputFormat.PARQUET:
          writer.parquet(output_path)
      elif output_format == OutputFormat.ORC:
          writer.orc(output_path)
      else:
          # Text output needs a single string column.
          lines = df.select(
              concat_ws('\t', col('word'), col('count').cast('string')).alias('value')
          )
          lines.write.mode('overwrite').text(output_path)

      self.logger.info(f"Results saved to: {output_path}")
  def _print_statistics(self, result: WordCountResult):
      """
      Log a summary of the result at INFO level.

      Logs totals, execution time and the first ten entries of
      ``result.top_words`` with their share of all words. When the result
      has no word counts, a single "No words found" line is logged instead.

      Args:
          result: The word-count result to summarize.
      """
      log = self.logger.info
      if not result.word_counts:
          log("No words found")
          return

      heavy_rule = "=" * 60
      light_rule = "-" * 60

      log(heavy_rule)
      log("WordCount Statistics")
      log(heavy_rule)
      log(f"Total words: {result.total_words:,}")
      log(f"Unique words: {result.unique_words:,}")
      log(f"Execution time: {result.execution_time_ms:.2f} ms")
      log(light_rule)
      log("Top 10 words:")
      for rank, (word, count) in enumerate(result.top_words[:10], start=1):
          share = (count / result.total_words) * 100
          log(f" {rank:2d}. {word:15s} {count:5d} ({share:5.1f}%)")
      log(heavy_rule)
  429. def count_words_locally(self, text: str,
  430. stop_words: Optional[List[str]] = None,
  431. min_word_length: int = 1) -> Dict[str, int]:
  432. """
  433. 本地统计单词(不使用 Spark 集群)
  434. 用于测试和小规模数据处理。
  435. Args:
  436. text: 输入文本
  437. stop_words: 停用词列表(可选)
  438. min_word_length: 最小单词长度
  439. Returns:
  440. 单词计数字典
  441. """
  442. word_counts = defaultdict(int)
  443. stop_words_set = set(stop_words) if stop_words else set()
  444. for line in text.split('\n'):
  445. words = self._split_line(line)
  446. for word in words:
  447. if (len(word) >= min_word_length and
  448. word not in stop_words_set):
  449. word_counts[word] += 1
  450. return dict(word_counts)
  def run_with_files(self, files: List[str],
                     output_path: Optional[str] = None,
                     stop_words: Optional[List[str]] = None,
                     min_word_length: int = 1) -> WordCountResult:
      """
      Count words across several local files without using Spark.

      Files that cannot be read are skipped with a warning.

      Args:
          files: Paths of the files to read (as UTF-8 text).
          output_path: Path for a JSON copy of the result (optional).
          stop_words: Words to exclude (optional).
          min_word_length: Minimum word length.

      Returns:
          A WordCountResult with counts, the top 100 words, timing and the
          combined UTF-8 size of the files that were read.
      """
      import time

      start_time = time.time()

      # Collect the contents of every readable file.
      contents = []
      total_size = 0
      for file_path in files:
          try:
              with open(file_path, 'r', encoding='utf-8') as f:
                  content = f.read()
          except (OSError, ValueError) as e:
              # OSError: missing/unreadable file. ValueError: undecodable
              # bytes (UnicodeDecodeError) or an invalid path string.
              self.logger.warning(f"Failed to read file {file_path}: {e}")
              continue
          contents.append(content)
          total_size += len(content.encode('utf-8'))

      # Join once instead of growing a string file by file. Each file is
      # newline-terminated so words never run together across files.
      all_text = "".join(f"{content}\n" for content in contents)

      # Count locally.
      result = self.count_words_locally(all_text, stop_words, min_word_length)

      # Statistics.
      execution_time_ms = (time.time() - start_time) * 1000
      total_words = sum(result.values())
      unique_words = len(result)
      top_words = sorted(result.items(), key=lambda x: x[1], reverse=True)[:100]

      wc_result = WordCountResult(
          total_words=total_words,
          unique_words=unique_words,
          top_words=top_words,
          word_counts=result,
          execution_time_ms=execution_time_ms,
          input_size_bytes=total_size,
      )

      # Save the result, if requested.
      if output_path:
          wc_result.save_to_file(output_path, OutputFormat.JSON)

      self._print_statistics(wc_result)
      return wc_result
  501. # 便捷方法
  502. def analyze_text(self, text: str) -> Dict[str, Any]:
  503. """
  504. 分析文本,返回详细的统计信息
  505. Args:
  506. text: 输入文本
  507. Returns:
  508. 详细的分析结果
  509. """
  510. word_counts = self.count_words_locally(text)
  511. # 计算统计信息
  512. total_words = sum(word_counts.values())
  513. unique_words = len(word_counts)
  514. # 词汇密度
  515. lexical_density = unique_words / total_words if total_words > 0 else 0
  516. # 平均词长
  517. total_chars = sum(len(word) * count for word, count in word_counts.items())
  518. avg_word_length = total_chars / total_words if total_words > 0 else 0
  519. # 词频分布
  520. sorted_counts = sorted(word_counts.values(), reverse=True)
  521. return {
  522. 'total_words': total_words,
  523. 'unique_words': unique_words,
  524. 'lexical_density': lexical_density,
  525. 'avg_word_length': avg_word_length,
  526. 'top_words': [{'word': w, 'count': c}
  527. for w, c in sorted(word_counts.items(), key=lambda x: x[1], reverse=True)[:20]],
  528. 'word_frequency_distribution': {
  529. 'once': sum(1 for c in word_counts.values() if c == 1),
  530. 'twice': sum(1 for c in word_counts.values() if c == 2),
  531. 'three_to_ten': sum(1 for c in word_counts.values() if 3 <= c <= 10),
  532. 'more_than_ten': sum(1 for c in word_counts.values() if c > 10),
  533. }
  534. }
  def main():
      """
      Entry point when the module is run as a script.

      Usage:
          python wordcount_spark.py [options] <input_path> [output_path]

      Options:
          --local               Count locally, without a Spark cluster
          --format <format>     Output format: text, json, csv, parquet, orc
          --stop-words <file>   Path to a stop-words file (one word per line)
          --min-length <n>      Minimum word length
          --json-result <path>  Save a JSON copy of the result locally
          --app-name <name>     Spark application name
          --master <url>        Spark master URL

      In local mode the result is written as JSON to ``--json-result`` or,
      if that is not given, to the positional ``output_path``.
      """
      import argparse

      parser = argparse.ArgumentParser(
          description='WordCount with PySpark',
          formatter_class=argparse.RawDescriptionHelpFormatter,
          epilog=(
              "\n"
              "Examples:\n"
              "# 使用 Spark 集群\n"
              "python wordcount_spark.py input.txt output\n"
              "# 本地模式\n"
              "python wordcount_spark.py --local input.txt output.json\n"
              "# 使用 JSON 格式输出\n"
              "python wordcount_spark.py --format json input.txt output\n"
              "# 使用停用词\n"
              "python wordcount_spark.py --stop-words stopwords.txt input.txt\n"
          ),
      )
      parser.add_argument('input_path', help='Input path (local or HDFS)')
      parser.add_argument('output_path', nargs='?', help='Output path (optional)')
      parser.add_argument('--local', action='store_true',
                          help='Run in local mode (without Spark cluster)')
      parser.add_argument('--format', choices=['text', 'json', 'csv', 'parquet', 'orc'],
                          default='text', help='Output format (default: text)')
      parser.add_argument('--stop-words', help='Path to stop words file')
      parser.add_argument('--min-length', type=int, default=1,
                          help='Minimum word length (default: 1)')
      parser.add_argument('--json-result', help='Save JSON result to local file')
      parser.add_argument('--app-name', help='Spark application name')
      parser.add_argument('--master', help='Spark master URL')
      args = parser.parse_args()

      # Load stop words; a failure here is reported but does not abort the run.
      stop_words = None
      if args.stop_words:
          try:
              with open(args.stop_words, 'r', encoding='utf-8') as f:
                  stop_words = [line.strip().lower() for line in f if line.strip()]
          except Exception as e:
              print(f"Warning: Failed to load stop words: {e}")

      wc = WordCountSpark(
          app_name=args.app_name,
          master=args.master
      )
      try:
          if args.local:
              # Local mode. Fall back to the positional output path so that
              # `--local input.txt output.json` (the documented example)
              # actually writes a file; it used to be ignored.
              json_path = args.json_result or args.output_path
              result = wc.run_with_files(
                  [args.input_path],
                  output_path=json_path,
                  stop_words=stop_words,
                  min_word_length=args.min_length
              )
          else:
              # Spark mode. run() writes the local JSON copy itself when
              # save_local_result is set, so it is not written again below.
              json_path = args.json_result
              result = wc.run(
                  input_path=args.input_path,
                  output_path=args.output_path,
                  output_format=OutputFormat(args.format),
                  stop_words=stop_words,
                  min_word_length=args.min_length,
                  save_local_result=bool(json_path),
                  local_result_path=json_path
              )

          # Print the result summary.
          print("\n" + "=" * 60)
          print("Final Results")
          print("=" * 60)
          print(f"Total words: {result.total_words:,}")
          print(f"Unique words: {result.unique_words:,}")
          print("\nTop 20 words:")
          for i, (word, count) in enumerate(result.top_words[:20], 1):
              print(f" {i:2d}. {word:15s} {count:5d}")
          print("=" * 60)

          if json_path:
              print(f"\nJSON result saved to: {json_path}")
      finally:
          # Always release the Spark session, even if the job failed.
          wc.stop()


  if __name__ == '__main__':
      main()