"""
Word-frequency counting with PySpark.

This module implements a word count on top of PySpark and is, per the
original author's note, the counterpart of the Java ``WordCount`` class.

What the module provides:
- integration with the project's configuration manager (``SparkConfig``)
- several input formats (text, JSON, CSV, Parquet, ORC; auto-detected
  from the path suffix) and the same set of output formats
- Spark tuning options taken from the configuration object
- optional stop-word and word-length filtering
- persistence of the result to a Spark-writable path and/or a local file
- a command line entry point (``main``)
"""

import copy
import csv
import json
import os
import sys
import time
from collections import Counter, defaultdict
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

from ..config import ConfigurationManager, SparkConfig, get_config
from ..utils.helpers import setup_logger, format_file_size


# Characters stripped from both ends of every whitespace-separated token.
_TOKEN_STRIP_CHARS = '.,!?;:()[]{}"\''

# Maximum number of entries kept in ``WordCountResult.top_words``.
_TOP_WORDS_LIMIT = 100


def _make_tokenizer(strip_chars: str = _TOKEN_STRIP_CHARS):
    """Return a function that splits one line of text into cleaned words.

    The returned function is a plain closure that captures only a string.
    That matters for Spark: it can be shipped to executors without dragging
    along an object that holds a SparkContext.
    """

    def tokenize(line):
        # Split on any whitespace, trim surrounding punctuation, lower-case,
        # and drop tokens that end up empty.
        cleaned = (raw.strip(strip_chars).lower() for raw in line.split())
        return [word for word in cleaned if word]

    return tokenize


# Shared tokenizer used by the local (non-Spark) code paths.
_tokenize = _make_tokenizer()


def _normalize_stop_words(stop_words: Optional[Iterable[str]]) -> set:
    """Return the stop words as a lower-cased set (empty when none given).

    Tokens are always lower-cased, so stop words must be as well in order
    to match.
    """
    if not stop_words:
        return set()
    return {word.lower() for word in stop_words}


def _filter_counts(word_counts: Dict[str, int],
                   stop_words: Optional[Iterable[str]],
                   min_word_length: int) -> Dict[str, int]:
    """Drop stop words and too-short words from an existing count mapping."""
    excluded = _normalize_stop_words(stop_words)
    if not excluded and min_word_length <= 1:
        # Tokens are never empty, so there is nothing to filter.
        return word_counts
    return {
        word: count
        for word, count in word_counts.items()
        if len(word) >= min_word_length and word not in excluded
    }


class OutputFormat(Enum):
    """Supported output formats."""
    TEXT = "text"
    JSON = "json"
    CSV = "csv"
    PARQUET = "parquet"
    ORC = "orc"


class InputFormat(Enum):
    """Supported input formats; ``AUTO`` infers the format from the path."""
    TEXT = "text"
    JSON = "json"
    CSV = "csv"
    PARQUET = "parquet"
    ORC = "orc"
    AUTO = "auto"


@dataclass
class WordCountResult:
    """Outcome of a word count run."""
    total_words: int = 0
    unique_words: int = 0
    top_words: List[Tuple[str, int]] = field(default_factory=list)
    word_counts: Dict[str, int] = field(default_factory=dict)
    execution_time_ms: float = 0.0
    input_size_bytes: int = 0
    output_size_bytes: int = 0
    # Set by WordCountSpark.run() when the result was written through Spark.
    # Declared as a field (it used to be attached ad hoc); it is deliberately
    # not part of to_dict() so the serialized shape stays unchanged.
    output_path: Optional[str] = None

    @property
    def input_size_formatted(self) -> str:
        """Input size rendered by ``format_file_size``."""
        return format_file_size(self.input_size_bytes)

    @property
    def output_size_formatted(self) -> str:
        """Output size rendered by ``format_file_size``."""
        return format_file_size(self.output_size_bytes)

    def to_dict(self) -> Dict[str, Any]:
        """Return the result as a plain dictionary."""
        return {
            'total_words': self.total_words,
            'unique_words': self.unique_words,
            'top_words': [{'word': w, 'count': c} for w, c in self.top_words],
            'word_counts': self.word_counts,
            'execution_time_ms': self.execution_time_ms,
            'input_size_bytes': self.input_size_bytes,
            'input_size_formatted': self.input_size_formatted,
            'output_size_bytes': self.output_size_bytes,
            'output_size_formatted': self.output_size_formatted,
        }

    def to_json(self, indent: int = 2) -> str:
        """Return the result as a JSON string (non-ASCII kept verbatim)."""
        return json.dumps(self.to_dict(), indent=indent, ensure_ascii=False)

    def save_to_file(self, file_path: str,
                     format: OutputFormat = OutputFormat.JSON):
        """Write the result to a local file.

        Args:
            file_path: Destination path.
            format: ``OutputFormat.JSON`` (full result), ``OutputFormat.CSV``
                or ``OutputFormat.TEXT`` (word/count pairs sorted by word).
                The enum's string value (e.g. ``"json"``) is accepted too.

        Raises:
            ValueError: If the format is unknown or cannot be written
                locally (Parquet and ORC need Spark). Previously such a
                call returned silently without writing anything.
        """
        # Enum(value) accepts both a member and its string value.
        format = OutputFormat(format)
        if format not in (OutputFormat.JSON, OutputFormat.CSV,
                          OutputFormat.TEXT):
            raise ValueError(
                f"Unsupported local output format: {format.value}"
            )

        if format == OutputFormat.JSON:
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(self.to_json())
            return

        rows = sorted(self.word_counts.items())
        if format == OutputFormat.CSV:
            # csv.writer quotes words containing commas, quotes or newlines;
            # hand-built lines produced a corrupt file for such words.
            with open(file_path, 'w', encoding='utf-8', newline='') as f:
                writer = csv.writer(f, lineterminator='\n')
                writer.writerow(('word', 'count'))
                writer.writerows(rows)
        else:
            with open(file_path, 'w', encoding='utf-8') as f:
                f.writelines(f"{word}\t{count}\n" for word, count in rows)


def _build_result(word_counts: Dict[str, int],
                  execution_time_ms: float,
                  input_size_bytes: int = 0) -> WordCountResult:
    """Assemble a WordCountResult (totals and top words) from raw counts."""
    ranked = sorted(word_counts.items(), key=lambda item: item[1],
                    reverse=True)
    return WordCountResult(
        total_words=sum(word_counts.values()),
        unique_words=len(word_counts),
        top_words=ranked[:_TOP_WORDS_LIMIT],
        word_counts=word_counts,
        execution_time_ms=execution_time_ms,
        input_size_bytes=input_size_bytes,
    )


class WordCountSpark:
    """
    Word count driver built on PySpark.

    Features:
    - configuration taken from the project's configuration manager
    - several input and output formats
    - Spark tuning options applied from the configuration
    - stop-word and word-length filtering
    - summary statistics for each run
    - a Spark-free local mode for tests and small inputs
    """

    def __init__(self,
                 config: Optional[SparkConfig] = None,
                 config_manager: Optional[ConfigurationManager] = None,
                 app_name: Optional[str] = None,
                 master: Optional[str] = None,
                 logger_name: str = 'wordcount_spark'):
        """
        Create a WordCountSpark instance; Spark itself is started lazily.

        Args:
            config: Spark configuration (optional).
            config_manager: Configuration manager used when ``config`` is
                not given (optional; defaults to ``get_config()``).
            app_name: Overrides the configured Spark application name.
            master: Overrides the configured Spark master URL.
            logger_name: Name of the logger to create.
        """
        self.logger = setup_logger(logger_name)

        # Only consult the configuration manager when it is actually needed.
        if config is None:
            if config_manager is None:
                config_manager = get_config()
            config = config_manager.spark

        if app_name or master:
            # Apply overrides to a shallow copy so the (possibly global,
            # shared) configuration object is not modified for everyone else.
            config = copy.copy(config)
            if app_name:
                config.app_name = app_name
            if master:
                config.master = master

        self.config = config
        self._spark = None
        self._sc = None

    @property
    def spark(self):
        """The SparkSession, created on first access."""
        if self._spark is None:
            self._init_spark()
        return self._spark

    @property
    def sc(self):
        """The SparkContext, created on first access."""
        if self._sc is None:
            self._init_spark()
        return self._sc

    def _init_spark(self):
        """Create the SparkSession/SparkContext from ``self.config``.

        Raises:
            ImportError: If PySpark is not installed.
            Exception: Any other failure is logged and re-raised unchanged.
        """
        try:
            from pyspark.sql import SparkSession
            from pyspark import SparkConf

            cfg = self.config
            conf = SparkConf()
            conf.setAppName(cfg.app_name)
            if cfg.master:
                conf.setMaster(cfg.master)

            settings = [
                ("spark.driver.memory", cfg.driver_memory),
                ("spark.executor.memory", cfg.executor_memory),
                ("spark.executor.cores", str(cfg.executor_cores)),
                ("spark.executor.instances", str(cfg.num_executors)),
                ("spark.sql.shuffle.partitions",
                 str(cfg.shuffle_partitions)),
                ("spark.serializer", cfg.serializer),
                ("spark.kryo.registrationRequired",
                 str(cfg.kryo_registration_required).lower()),
            ]
            if cfg.default_parallelism:
                # Fixed: the attribute name was misspelled
                # (``default_par_par``), raising AttributeError here.
                settings.append(("spark.default.parallelism",
                                 str(cfg.default_parallelism)))
            for key, value in settings:
                conf.set(key, value)

            # Caller-supplied extra settings are applied last so they win.
            for key, value in cfg.extra_configs.items():
                conf.set(key, value)

            self._spark = SparkSession.builder.config(conf=conf).getOrCreate()
            self._sc = self._spark.sparkContext
            self._sc.setLogLevel(cfg.log_level)

            self.logger.info(f"Spark session initialized: {self.config.app_name}")
            self.logger.info(f"Spark master: {self._sc.master}")
            self.logger.info(f"Spark version: {self._sc.version}")

        except ImportError as e:
            self.logger.error(f"PySpark is not installed: {e}")
            raise
        except Exception as e:
            self.logger.error(f"Failed to initialize Spark: {e}")
            raise

    def stop(self):
        """Stop the Spark session if one was started."""
        if self._spark:
            self._spark.stop()
            self._spark = None
            self._sc = None
            self.logger.info("Spark session stopped")

    def _infer_input_format(self, path: str) -> InputFormat:
        """Guess the input format from the path suffix (default: text)."""
        # A trailing slash is common for directory-style datasets.
        normalized = path.lower().rstrip('/')
        by_suffix = (
            (('.json', '.jsonl'), InputFormat.JSON),
            (('.csv',), InputFormat.CSV),
            (('.parquet',), InputFormat.PARQUET),
            (('.orc',), InputFormat.ORC),
        )
        for suffixes, detected in by_suffix:
            if normalized.endswith(suffixes):
                return detected
        return InputFormat.TEXT

    def _read_input(self,
                    path: str,
                    input_format: InputFormat = InputFormat.AUTO,
                    text_column: str = 'value') -> Any:
        """
        Load the input as a DataFrame.

        Args:
            path: Input path.
            input_format: Input format; ``AUTO`` infers it from ``path``.
            text_column: Name of the text column. Accepted for interface
                compatibility; it is not used while reading.

        Returns:
            A Spark DataFrame. Text input yields a single ``value`` column.
        """
        if input_format == InputFormat.AUTO:
            input_format = self._infer_input_format(path)

        self.logger.info(f"Reading input from {path} with format {input_format.value}")

        reader = self.spark.read
        if input_format == InputFormat.JSON:
            return reader.json(path)
        if input_format == InputFormat.CSV:
            return reader.csv(path, header=True, inferSchema=True)
        if input_format == InputFormat.PARQUET:
            return reader.parquet(path)
        if input_format == InputFormat.ORC:
            return reader.orc(path)
        # Plain text
        return reader.text(path)

    def _split_line(self, line: str) -> List[str]:
        """
        Split one line of text into cleaned words.

        Tokens are separated by whitespace, stripped of surrounding
        punctuation and lower-cased; empty tokens are dropped.

        Args:
            line: Input text line.

        Returns:
            List of words.
        """
        return _tokenize(line)

    def count_words_from_rdd(self, text_rdd) -> Dict[str, int]:
        """
        Count words in an RDD of text lines.

        Mirrors the classic map/reduce WordCount using Spark operators.

        Args:
            text_rdd: RDD whose elements are text lines.

        Returns:
            Mapping of word to count, collected on the driver.
        """
        # Use a self-free closure: passing the bound method
        # ``self._split_line`` would make Spark pickle ``self``, which holds
        # the SparkSession/SparkContext, and Spark rejects that.
        split_line = _make_tokenizer()

        counted = (
            text_rdd
            .flatMap(split_line)                   # line -> words
            .map(lambda word: (word, 1))           # word -> (word, 1)
            .reduceByKey(lambda x, y: x + y)       # sum per word
        )
        return dict(counted.collectAsMap())

    def count_words_from_dataframe(self,
                                   df,
                                   text_column: str = 'value',
                                   stop_words: Optional[List[str]] = None,
                                   min_word_length: int = 1,
                                   max_word_length: int = 100) -> Dict[str, int]:
        """
        Count words in a DataFrame column using the DataFrame API.

        Args:
            df: DataFrame containing the text.
            text_column: Name of the column holding the text.
            stop_words: Words to exclude (optional, matched
                case-insensitively).
            min_word_length: Minimum word length (inclusive).
            max_word_length: Maximum word length (inclusive).

        Returns:
            Mapping of word to count, collected on the driver.
        """
        from pyspark.sql.functions import (
            explode, split, lower, trim, regexp_replace, col, count, length
        )

        # 1. Normalize: replace everything except ASCII letters, digits and
        #    whitespace with a space, then trim and lower-case.
        # NOTE(review): this pattern discards all non-ASCII text (e.g. CJK),
        # unlike the RDD/local tokenizer - confirm whether that is intended.
        cleaned = lower(trim(
            regexp_replace(col(text_column), '[^a-zA-Z0-9\\s]', ' ')
        ))

        # 2. One row per word.
        words = (
            df.withColumn('clean_text', cleaned)
              .withColumn('word', explode(split(col('clean_text'), '\\s+')))
        )

        # 3. Drop empty tokens and enforce the length bounds.
        word_len = length(col('word'))
        kept = words.filter(
            (col('word') != '')
            & (word_len >= min_word_length)
            & (word_len <= max_word_length)
        )

        # 4. Drop stop words. ``isin`` is evaluated natively by Spark; the
        #    former broadcast variable and UDF were created but never used
        #    (and the UDF declared the wrong return type), so they are gone.
        excluded = sorted(_normalize_stop_words(stop_words))
        if excluded:
            kept = kept.filter(~col('word').isin(excluded))

        # 5. Aggregate and collect.
        counted = kept.groupBy('word').agg(count('*').alias('count'))
        return {row['word']: row['count'] for row in counted.collect()}

    def run(self,
            input_path: str,
            output_path: Optional[str] = None,
            output_format: OutputFormat = OutputFormat.TEXT,
            input_format: InputFormat = InputFormat.AUTO,
            use_dataframe: bool = True,
            text_column: str = 'value',
            stop_words: Optional[List[str]] = None,
            min_word_length: int = 1,
            save_local_result: bool = False,
            local_result_path: Optional[str] = None) -> WordCountResult:
        """
        Run a complete word count job on Spark.

        Args:
            input_path: Input path (local file path or HDFS path).
            output_path: Spark output path; results are written there when
                given (optional).
            output_format: Format used for ``output_path``.
            input_format: Input format (``AUTO`` infers from the path).
            use_dataframe: Use the DataFrame API when True, otherwise the
                RDD API.
            text_column: Name of the text column (for structured formats).
            stop_words: Words to exclude (optional).
            min_word_length: Minimum word length.
            save_local_result: Also save the result as local JSON.
            local_result_path: Path of that local JSON file (optional).

        Returns:
            A WordCountResult. ``execution_time_ms`` covers reading and
            counting, not the time spent saving results.
        """
        started = time.time()

        self.logger.info(f"Running WordCount job on: {input_path}")

        df = self._read_input(input_path, input_format, text_column)

        if use_dataframe:
            counts = self.count_words_from_dataframe(
                df, text_column, stop_words, min_word_length
            )
        else:
            # Null cells would crash the tokenizer, so skip them.
            lines = (
                df.select(text_column).rdd
                .map(lambda row: row[0])
                .filter(lambda line: line is not None)
            )
            counts = self.count_words_from_rdd(lines)
            # The RDD path used to ignore these options silently.
            counts = _filter_counts(counts, stop_words, min_word_length)

        elapsed_ms = (time.time() - started) * 1000
        wc_result = _build_result(counts, elapsed_ms)

        # Persist through Spark when requested.
        if output_path:
            self._save_result_to_hdfs(counts, output_path, output_format)
            wc_result.output_path = output_path

        # Persist locally when requested.
        if save_local_result and local_result_path:
            wc_result.save_to_file(local_result_path, OutputFormat.JSON)

        self._print_statistics(wc_result)

        return wc_result

    def _save_result_to_hdfs(self,
                             result: Dict[str, int],
                             output_path: str,
                             output_format: OutputFormat):
        """
        Write the counts to ``output_path`` through Spark (overwrite mode).

        Args:
            result: Mapping of word to count.
            output_path: Destination path.
            output_format: Output format; anything other than JSON, CSV,
                Parquet or ORC is written as tab-separated text.
        """
        from pyspark.sql.functions import col, concat_ws
        from pyspark.sql.types import (
            LongType, StringType, StructField, StructType
        )

        self.logger.info(f"Saving results to HDFS: {output_path} (format: {output_format.value})")

        # An explicit schema keeps this working for an empty result (schema
        # inference fails on an empty list) and avoids building Row objects
        # with a ``count`` field, which shadows ``tuple.count``.
        schema = StructType([
            StructField('word', StringType()),
            StructField('count', LongType()),
        ])
        df = self.spark.createDataFrame(sorted(result.items()), schema=schema)

        if output_format == OutputFormat.JSON:
            df.write.json(output_path, mode='overwrite')
        elif output_format == OutputFormat.CSV:
            df.write.csv(output_path, mode='overwrite', header=True)
        elif output_format == OutputFormat.PARQUET:
            df.write.parquet(output_path, mode='overwrite')
        elif output_format == OutputFormat.ORC:
            df.write.orc(output_path, mode='overwrite')
        else:
            # Text: one "word<TAB>count" line per word. The count is cast
            # explicitly instead of relying on implicit SQL coercion.
            line = concat_ws(
                '\t', col('word'), col('count').cast('string')
            ).alias('value')
            df.select(line).write.text(output_path, mode='overwrite')

        self.logger.info(f"Results saved to: {output_path}")

    def _print_statistics(self, result: WordCountResult):
        """
        Log a summary of the result, including the ten most frequent words.

        Args:
            result: The word count result.
        """
        if not result.word_counts:
            self.logger.info("No words found")
            return

        log = self.logger.info
        total = result.total_words

        log("=" * 60)
        log("WordCount Statistics")
        log("=" * 60)
        log(f"Total words: {result.total_words:,}")
        log(f"Unique words: {result.unique_words:,}")
        log(f"Execution time: {result.execution_time_ms:.2f} ms")
        log("-" * 60)
        log("Top 10 words:")

        for i, (word, count) in enumerate(result.top_words[:10], 1):
            # Guard against a hand-built result whose total is zero.
            percentage = (count / total) * 100 if total else 0.0
            log(f" {i:2d}. {word:15s} {count:5d} ({percentage:5.1f}%)")

        log("=" * 60)

    def count_words_locally(self, text: str,
                            stop_words: Optional[List[str]] = None,
                            min_word_length: int = 1) -> Dict[str, int]:
        """
        Count words without Spark.

        Intended for tests and small inputs.

        Args:
            text: Input text.
            stop_words: Words to exclude (optional, matched
                case-insensitively).
            min_word_length: Minimum word length.

        Returns:
            Mapping of word to count.
        """
        excluded = _normalize_stop_words(stop_words)
        counts = Counter()

        for line in text.split('\n'):
            counts.update(
                word for word in self._split_line(line)
                if len(word) >= min_word_length and word not in excluded
            )

        return dict(counts)

    def run_with_files(self,
                       files: List[str],
                       output_path: Optional[str] = None,
                       stop_words: Optional[List[str]] = None,
                       min_word_length: int = 1) -> WordCountResult:
        """
        Count words across several local files (no Spark involved).

        Files that cannot be read or decoded as UTF-8 are skipped with a
        warning.

        Args:
            files: List of file paths.
            output_path: Local JSON output path (optional).
            stop_words: Words to exclude (optional).
            min_word_length: Minimum word length.

        Returns:
            A WordCountResult; ``input_size_bytes`` is the UTF-8 size of
            the text that was read.
        """
        started = time.time()

        contents = []
        total_size = 0

        for file_path in files:
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = f.read()
            except (OSError, ValueError) as e:
                # ValueError covers UnicodeDecodeError.
                self.logger.warning(f"Failed to read file {file_path}: {e}")
                continue
            contents.append(content)
            total_size += len(content.encode('utf-8'))

        # Join once instead of growing a string inside the loop; every file
        # is terminated by a newline so words never merge across files.
        all_text = "".join(content + "\n" for content in contents)

        counts = self.count_words_locally(all_text, stop_words,
                                          min_word_length)

        elapsed_ms = (time.time() - started) * 1000
        wc_result = _build_result(counts, elapsed_ms, total_size)

        if output_path:
            wc_result.save_to_file(output_path, OutputFormat.JSON)

        self._print_statistics(wc_result)

        return wc_result

    # Convenience methods

    def analyze_text(self, text: str) -> Dict[str, Any]:
        """
        Analyze a text locally and return summary statistics.

        Args:
            text: Input text.

        Returns:
            Dictionary with the total and unique word counts, lexical
            density (unique / total), average word length, the 20 most
            frequent words and a histogram of how often words occur.
        """
        word_counts = self.count_words_locally(text)

        total_words = sum(word_counts.values())
        unique_words = len(word_counts)

        # Both ratios fall back to 0 for empty input.
        lexical_density = unique_words / total_words if total_words > 0 else 0
        total_chars = sum(len(word) * count
                          for word, count in word_counts.items())
        avg_word_length = total_chars / total_words if total_words > 0 else 0

        # Bucket the per-word frequencies in a single pass.
        distribution = {
            'once': 0,
            'twice': 0,
            'three_to_ten': 0,
            'more_than_ten': 0,
        }
        for count in word_counts.values():
            if count == 1:
                distribution['once'] += 1
            elif count == 2:
                distribution['twice'] += 1
            elif 3 <= count <= 10:
                distribution['three_to_ten'] += 1
            elif count > 10:
                distribution['more_than_ten'] += 1

        ranked = sorted(word_counts.items(), key=lambda x: x[1],
                        reverse=True)

        return {
            'total_words': total_words,
            'unique_words': unique_words,
            'lexical_density': lexical_density,
            'avg_word_length': avg_word_length,
            'top_words': [{'word': w, 'count': c} for w, c in ranked[:20]],
            'word_frequency_distribution': distribution,
        }


def main():
    """
    Command line entry point.

    Usage:
        python wordcount_spark.py [options] input_path [output_path]

    Options:
        --local         Local mode (no Spark cluster)
        --format        Output format: text, json, csv, parquet, orc
        --stop-words    Path to a stop-word file (one word per line)
        --min-length    Minimum word length
        --json-result   Save the JSON result to a local file
        --app-name      Spark application name
        --master        Spark master URL

    In local mode the JSON result is written to ``--json-result`` if given,
    otherwise to the positional ``output_path`` if given.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description='WordCount with PySpark',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # 使用 Spark 集群
  python wordcount_spark.py input.txt output

  # 本地模式
  python wordcount_spark.py --local input.txt output.json

  # 使用 JSON 格式输出
  python wordcount_spark.py --format json input.txt output

  # 使用停用词
  python wordcount_spark.py --stop-words stopwords.txt input.txt
        """
    )

    parser.add_argument('input_path', help='Input path (local or HDFS)')
    parser.add_argument('output_path', nargs='?',
                        help='Output path (optional)')
    parser.add_argument('--local', action='store_true',
                        help='Run in local mode (without Spark cluster)')
    parser.add_argument('--format',
                        choices=['text', 'json', 'csv', 'parquet', 'orc'],
                        default='text',
                        help='Output format (default: text)')
    parser.add_argument('--stop-words', help='Path to stop words file')
    parser.add_argument('--min-length', type=int, default=1,
                        help='Minimum word length (default: 1)')
    parser.add_argument('--json-result',
                        help='Save JSON result to local file')
    parser.add_argument('--app-name', help='Spark application name')
    parser.add_argument('--master', help='Spark master URL')

    args = parser.parse_args()

    # Load stop words; failure is reported but does not abort the run.
    stop_words = None
    if args.stop_words:
        try:
            with open(args.stop_words, 'r', encoding='utf-8') as f:
                stop_words = [line.strip().lower()
                              for line in f if line.strip()]
        except (OSError, ValueError) as e:
            print(f"Warning: Failed to load stop words: {e}")

    wc = WordCountSpark(
        app_name=args.app_name,
        master=args.master
    )

    try:
        if args.local:
            # Local mode. The positional output path used to be ignored
            # here even though the usage example passes one.
            result = wc.run_with_files(
                [args.input_path],
                output_path=args.json_result or args.output_path,
                stop_words=stop_words,
                min_word_length=args.min_length
            )
        else:
            # Spark mode; run() also writes the local JSON result.
            output_format = OutputFormat(args.format)
            result = wc.run(
                input_path=args.input_path,
                output_path=args.output_path,
                output_format=output_format,
                stop_words=stop_words,
                min_word_length=args.min_length,
                save_local_result=bool(args.json_result),
                local_result_path=args.json_result
            )

        # Result summary
        print("\n" + "=" * 60)
        print("Final Results")
        print("=" * 60)
        print(f"Total words: {result.total_words:,}")
        print(f"Unique words: {result.unique_words:,}")
        print("\nTop 20 words:")
        for i, (word, count) in enumerate(result.top_words[:20], 1):
            print(f" {i:2d}. {word:15s} {count:5d}")
        print("=" * 60)

        # run() has already written the file; writing it a second time
        # here was redundant, so only report it.
        if args.json_result and not args.local:
            print(f"\nJSON result saved to: {args.json_result}")

    finally:
        wc.stop()


if __name__ == '__main__':
    main()