| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770 |
- """
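Word-frequency counting with PySpark.

Implements word counting on PySpark, the recommended approach for modern
big-data processing:
- A more concise API
- Better performance
- Support for a wider range of data-processing operations
- Integration with Spark SQL, MLlib, and related libraries

Modernization enhancements:
- Integrated configuration management
- Support for multiple data formats (JSON, CSV, Parquet, etc.)
- Performance-tuning configuration
- Data-quality checks
- Persisting results to multiple storage backends
- Enhanced command-line tooling

Counterpart of the Java WordCount class, built on the more modern Spark
framework.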
- """
import csv
import json
import os
import sys
import time
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

from ..config import ConfigurationManager, SparkConfig, get_config
from ..utils.helpers import format_file_size, setup_logger
class OutputFormat(Enum):
    """Formats in which word-count results can be written.

    Each member's value is the lowercase format name, so a member can be
    looked up from a plain string, e.g. ``OutputFormat("json")``.
    """

    TEXT = "text"        # one "word<TAB>count" record per line
    JSON = "json"
    CSV = "csv"
    PARQUET = "parquet"
    ORC = "orc"
class InputFormat(Enum):
    """Formats from which input data can be read.

    ``AUTO`` is a placeholder rather than a real format: the reader resolves
    it to a concrete member from the input path's file extension.
    """

    TEXT = "text"
    JSON = "json"
    CSV = "csv"
    PARQUET = "parquet"
    ORC = "orc"
    AUTO = "auto"
@dataclass
class WordCountResult:
    """Outcome of a single word-count run.

    Attributes:
        total_words: Sum of all word occurrences.
        unique_words: Number of distinct words.
        top_words: ``(word, count)`` pairs, highest count first.
        word_counts: Mapping of every word to its count.
        execution_time_ms: Wall-clock duration of the run in milliseconds.
        input_size_bytes: Size of the input that was read, in bytes.
        output_size_bytes: Size of the output that was written, in bytes.
        output_path: Where the results were written, if anywhere.
    """

    total_words: int = 0
    unique_words: int = 0
    top_words: List[Tuple[str, int]] = field(default_factory=list)
    word_counts: Dict[str, int] = field(default_factory=dict)
    execution_time_ms: float = 0.0
    input_size_bytes: int = 0
    output_size_bytes: int = 0
    # Declared so the attribute assigned by WordCountSpark.run() is a real
    # field. It is deliberately not emitted by to_dict(), which keeps the
    # serialized layout unchanged.
    output_path: Optional[str] = None

    @property
    def input_size_formatted(self) -> str:
        """Return ``input_size_bytes`` rendered by ``format_file_size``."""
        return format_file_size(self.input_size_bytes)

    @property
    def output_size_formatted(self) -> str:
        """Return ``output_size_bytes`` rendered by ``format_file_size``."""
        return format_file_size(self.output_size_bytes)

    def to_dict(self) -> Dict[str, Any]:
        """Return the result as a plain, JSON-serializable dictionary."""
        return {
            'total_words': self.total_words,
            'unique_words': self.unique_words,
            'top_words': [{'word': w, 'count': c} for w, c in self.top_words],
            'word_counts': self.word_counts,
            'execution_time_ms': self.execution_time_ms,
            'input_size_bytes': self.input_size_bytes,
            'input_size_formatted': self.input_size_formatted,
            'output_size_bytes': self.output_size_bytes,
            'output_size_formatted': self.output_size_formatted,
        }

    def to_json(self, indent: int = 2) -> str:
        """Return the result as a JSON string; non-ASCII text is kept as-is."""
        return json.dumps(self.to_dict(), indent=indent, ensure_ascii=False)

    def save_to_file(self, file_path: str, format: OutputFormat = OutputFormat.JSON):
        """Write the result to a local file.

        JSON stores the full ``to_dict()`` payload. CSV and TEXT store only
        the ``word_counts`` table, sorted by word.

        Args:
            file_path: Destination file; overwritten if it exists.
            format: ``OutputFormat.JSON``, ``OutputFormat.CSV`` or
                ``OutputFormat.TEXT``.

        Raises:
            ValueError: If ``format`` is not one of the three supported
                formats. Previously such a call returned without writing
                anything.
            OSError: If the file cannot be opened or written.
        """
        if format == OutputFormat.JSON:
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(self.to_json())
        elif format == OutputFormat.CSV:
            # newline='' is what the csv module requires of its file object;
            # csv.writer quotes words that contain commas, quotes or line
            # breaks, which a hand-built "word,count" line would corrupt.
            with open(file_path, 'w', encoding='utf-8', newline='') as f:
                writer = csv.writer(f, lineterminator='\n')
                writer.writerow(('word', 'count'))
                writer.writerows(sorted(self.word_counts.items()))
        elif format == OutputFormat.TEXT:
            with open(file_path, 'w', encoding='utf-8') as f:
                f.writelines(
                    f"{word}\t{count}\n"
                    for word, count in sorted(self.word_counts.items())
                )
        else:
            raise ValueError(f"Unsupported local output format: {format!r}")
class WordCountSpark:
    """Word-frequency counter backed by PySpark, with a pure-Python fallback.

    The Spark session is created lazily on first access to :attr:`spark` or
    :attr:`sc`. The local helpers (:meth:`count_words_locally`,
    :meth:`run_with_files`, :meth:`analyze_text`) never touch the session.
    """

    # Number of highest-frequency words kept in ``WordCountResult.top_words``.
    TOP_WORDS_LIMIT = 100

    def __init__(self,
                 config: Optional[SparkConfig] = None,
                 config_manager: Optional[ConfigurationManager] = None,
                 app_name: Optional[str] = None,
                 master: Optional[str] = None,
                 logger_name: str = 'wordcount_spark'):
        """Create a counter; no Spark session is started yet.

        Args:
            config: Spark configuration. When omitted, ``config_manager.spark``
                is used.
            config_manager: Source of the default configuration. When omitted
                (and ``config`` is omitted too), ``get_config()`` supplies it.
            app_name: Overrides ``config.app_name`` when non-empty.
            master: Overrides ``config.master`` when non-empty.
            logger_name: Name handed to ``setup_logger``.
        """
        self.logger = setup_logger(logger_name)

        # The configuration manager is consulted only when no explicit Spark
        # config was supplied.
        if config is None:
            if config_manager is None:
                config_manager = get_config()
            config = config_manager.spark

        self.config = config
        self._spark = None
        self._sc = None

        # NOTE(review): the overrides are written onto the config object
        # itself, so a config shared with other code is modified as well.
        if app_name:
            self.config.app_name = app_name
        if master:
            self.config.master = master

    @property
    def spark(self):
        """The SparkSession, created on first access."""
        if self._spark is None:
            self._init_spark()
        return self._spark

    @property
    def sc(self):
        """The SparkContext, created on first access."""
        if self._sc is None:
            self._init_spark()
        return self._sc

    def _init_spark(self):
        """Build the SparkSession and SparkContext from ``self.config``.

        Raises:
            ImportError: If PySpark is not installed.
            Exception: Any other failure while starting Spark; it is logged
                and re-raised unchanged.
        """
        try:
            # Imported lazily so the local-only helpers work without PySpark.
            from pyspark.sql import SparkSession
            from pyspark import SparkConf

            conf = SparkConf()
            conf.setAppName(self.config.app_name)

            if self.config.master:
                conf.setMaster(self.config.master)

            conf.set("spark.driver.memory", self.config.driver_memory)
            conf.set("spark.executor.memory", self.config.executor_memory)
            conf.set("spark.executor.cores", str(self.config.executor_cores))
            conf.set("spark.executor.instances", str(self.config.num_executors))
            conf.set("spark.sql.shuffle.partitions", str(self.config.shuffle_partitions))
            conf.set("spark.serializer", self.config.serializer)
            conf.set("spark.kryo.registrationRequired",
                     str(self.config.kryo_registration_required).lower())

            if self.config.default_parallelism:
                conf.set("spark.default.parallelism",
                         str(self.config.default_parallelism))

            # Caller-supplied settings go last so they can override the above.
            for key, value in self.config.extra_configs.items():
                conf.set(key, value)

            self._spark = SparkSession.builder.config(conf=conf).getOrCreate()
            self._sc = self._spark.sparkContext
            self._sc.setLogLevel(self.config.log_level)

            self.logger.info(f"Spark session initialized: {self.config.app_name}")
            self.logger.info(f"Spark master: {self._sc.master}")
            self.logger.info(f"Spark version: {self._sc.version}")

        except ImportError as e:
            self.logger.error(f"PySpark is not installed: {e}")
            raise
        except Exception as e:
            self.logger.error(f"Failed to initialize Spark: {e}")
            raise

    def stop(self):
        """Stop the Spark session if one was started; otherwise do nothing."""
        if self._spark:
            self._spark.stop()
            self._spark = None
            self._sc = None
            self.logger.info("Spark session stopped")

    def _infer_input_format(self, path: str) -> InputFormat:
        """Guess the input format from the (case-insensitive) file extension.

        Unrecognized extensions fall back to ``InputFormat.TEXT``.
        """
        lowered = path.lower()
        by_suffix = (
            (('.json', '.jsonl'), InputFormat.JSON),
            ('.csv', InputFormat.CSV),
            ('.parquet', InputFormat.PARQUET),
            ('.orc', InputFormat.ORC),
        )
        for suffixes, fmt in by_suffix:
            if lowered.endswith(suffixes):
                return fmt
        return InputFormat.TEXT

    def _read_input(self, path: str, input_format: InputFormat = InputFormat.AUTO,
                    text_column: str = 'value') -> Any:
        """Load the input as a Spark DataFrame.

        Args:
            path: Input location handed to the Spark reader.
            input_format: Format to read; ``AUTO`` is resolved from ``path``.
            text_column: Accepted for interface compatibility; the reader
                itself does not use it.

        Returns:
            The DataFrame produced by the matching ``spark.read`` method.
            CSV input is read with a header row and schema inference.
        """
        if input_format == InputFormat.AUTO:
            input_format = self._infer_input_format(path)

        self.logger.info(f"Reading input from {path} with format {input_format.value}")

        reader = self.spark.read
        if input_format == InputFormat.JSON:
            return reader.json(path)
        if input_format == InputFormat.CSV:
            return reader.csv(path, header=True, inferSchema=True)
        if input_format == InputFormat.PARQUET:
            return reader.parquet(path)
        if input_format == InputFormat.ORC:
            return reader.orc(path)
        return reader.text(path)

    @staticmethod
    def _split_line(line: str) -> List[str]:
        """Split one line of text into cleaned, lower-cased words.

        Tokens are separated on whitespace; surrounding punctuation is
        stripped, and tokens that end up empty are dropped.

        This is a static method on purpose: Spark transformations pickle the
        callable they are given, and a bound method would drag the instance
        along with the SparkContext it holds, which PySpark refuses to
        serialize.

        Args:
            line: Text to split. ``None`` or an empty string yields ``[]``.

        Returns:
            The cleaned words in their original order.
        """
        if not line:
            return []
        stripped = (token.strip('.,!?;:()[]{}"\'').lower() for token in line.split())
        return [word for word in stripped if word]

    def count_words_from_rdd(self, text_rdd,
                             stop_words: Optional[List[str]] = None,
                             min_word_length: int = 1) -> Dict[str, int]:
        """Count words in an RDD of text lines using RDD operators.

        Args:
            text_rdd: RDD whose elements are lines of text.
            stop_words: Words to exclude (compared after lower-casing of the
                input words). Optional.
            min_word_length: Shortest word length to keep.

        Returns:
            Mapping of word to occurrence count, collected to the driver.
        """
        # Everything the closures need is bound to locals so that none of
        # them references ``self``.
        split_line = self._split_line
        excluded = frozenset(stop_words or ())

        words_rdd = text_rdd.flatMap(split_line)
        if excluded or min_word_length > 1:
            words_rdd = words_rdd.filter(
                lambda word: len(word) >= min_word_length and word not in excluded
            )

        counts = (
            words_rdd
            .map(lambda word: (word, 1))
            .reduceByKey(lambda x, y: x + y)
            .collectAsMap()
        )
        return dict(counts)

    def count_words_from_dataframe(self, df, text_column: str = 'value',
                                   stop_words: Optional[List[str]] = None,
                                   min_word_length: int = 1,
                                   max_word_length: int = 100) -> Dict[str, int]:
        """Count words in a DataFrame column using the DataFrame API.

        Every character that is not a letter, digit or whitespace is replaced
        by a space before the text is lower-cased and split on whitespace.

        Args:
            df: DataFrame containing the text.
            text_column: Name of the column holding the text.
            stop_words: Words to exclude. Optional.
            min_word_length: Shortest word length to keep (inclusive).
            max_word_length: Longest word length to keep (inclusive).

        Returns:
            Mapping of word to occurrence count, collected to the driver.
        """
        from pyspark.sql import functions as F

        # Unicode categories keep letters and digits of every script; an
        # ASCII-only class would cut "café" down to "caf".
        cleaned = F.lower(F.trim(
            F.regexp_replace(F.col(text_column), r'[^\p{L}\p{N}\s]', ' ')
        ))
        words = df.select(F.explode(F.split(cleaned, r'\s+')).alias('word'))

        word_col = F.col('word')
        words = words.filter(
            (word_col != '') &
            (F.length(word_col) >= min_word_length) &
            (F.length(word_col) <= max_word_length)
        )

        if stop_words:
            words = words.filter(~word_col.isin(list(stop_words)))

        counted = words.groupBy('word').agg(F.count('*').alias('count'))
        return {row['word']: row['count'] for row in counted.collect()}

    def _build_result(self, word_counts: Dict[str, int], started_at: float,
                      input_size_bytes: int = 0) -> WordCountResult:
        """Assemble a WordCountResult from raw counts and a start timestamp.

        Args:
            word_counts: Mapping of word to count.
            started_at: ``time.perf_counter()`` reading taken at job start.
            input_size_bytes: Size of the input, when known.
        """
        elapsed_ms = (time.perf_counter() - started_at) * 1000
        ranked = sorted(word_counts.items(), key=lambda item: item[1], reverse=True)
        return WordCountResult(
            total_words=sum(word_counts.values()),
            unique_words=len(word_counts),
            top_words=ranked[:self.TOP_WORDS_LIMIT],
            word_counts=word_counts,
            execution_time_ms=elapsed_ms,
            input_size_bytes=input_size_bytes,
        )

    def run(self,
            input_path: str,
            output_path: Optional[str] = None,
            output_format: OutputFormat = OutputFormat.TEXT,
            input_format: InputFormat = InputFormat.AUTO,
            use_dataframe: bool = True,
            text_column: str = 'value',
            stop_words: Optional[List[str]] = None,
            min_word_length: int = 1,
            save_local_result: bool = False,
            local_result_path: Optional[str] = None) -> WordCountResult:
        """Run a complete word-count job on Spark.

        Args:
            input_path: Input location handed to the Spark reader.
            output_path: When given, the counts are also written there through
                the Spark DataFrame writer.
            output_format: Format used for ``output_path``.
            input_format: Format of the input; ``AUTO`` infers it from the path.
            use_dataframe: Use the DataFrame API when true, the RDD API
                otherwise.
            text_column: Column holding the text.
            stop_words: Words to exclude. Optional; honored by both APIs.
            min_word_length: Shortest word length to keep; honored by both
                APIs.
            save_local_result: Also write a local JSON file when true.
            local_result_path: Destination of that local JSON file.

        Returns:
            The assembled WordCountResult. Its execution time covers reading
            and counting, not the optional writes.
        """
        started_at = time.perf_counter()

        self.logger.info(f"Running WordCount job on: {input_path}")

        df = self._read_input(input_path, input_format, text_column)

        if use_dataframe:
            counts = self.count_words_from_dataframe(
                df, text_column, stop_words, min_word_length
            )
        else:
            text_rdd = df.select(text_column).rdd.map(lambda row: row[0])
            counts = self.count_words_from_rdd(text_rdd, stop_words, min_word_length)

        wc_result = self._build_result(counts, started_at)

        if output_path:
            self._save_result_to_hdfs(counts, output_path, output_format)
            wc_result.output_path = output_path

        if save_local_result and local_result_path:
            wc_result.save_to_file(local_result_path, OutputFormat.JSON)

        self._print_statistics(wc_result)

        return wc_result

    def _save_result_to_hdfs(self, result: Dict[str, int],
                             output_path: str,
                             output_format: OutputFormat):
        """Write the counts through the Spark DataFrame writer, overwriting.

        Args:
            result: Mapping of word to count.
            output_path: Destination path.
            output_format: Output format; anything other than JSON, CSV,
                PARQUET or ORC is written as tab-separated text.
        """
        from pyspark.sql import functions as F
        from pyspark.sql.types import LongType, StringType, StructField, StructType

        self.logger.info(f"Saving results to HDFS: {output_path} (format: {output_format.value})")

        # An explicit schema fixes the column order and lets an empty result
        # be written; schema inference fails on an empty dataset.
        schema = StructType([
            StructField('word', StringType()),
            StructField('count', LongType()),
        ])
        df = self.spark.createDataFrame(sorted(result.items()), schema=schema)

        # The save mode is set on the writer because DataFrameWriter.text()
        # does not accept a ``mode`` argument.
        if output_format == OutputFormat.JSON:
            df.write.mode('overwrite').json(output_path)
        elif output_format == OutputFormat.CSV:
            df.write.mode('overwrite').csv(output_path, header=True)
        elif output_format == OutputFormat.PARQUET:
            df.write.mode('overwrite').parquet(output_path)
        elif output_format == OutputFormat.ORC:
            df.write.mode('overwrite').orc(output_path)
        else:
            lines = df.select(
                F.concat_ws('\t', F.col('word'), F.col('count').cast('string'))
                .alias('value')
            )
            lines.write.mode('overwrite').text(output_path)

        self.logger.info(f"Results saved to: {output_path}")

    def _print_statistics(self, result: WordCountResult):
        """Log a summary of the result, including the ten most frequent words.

        Args:
            result: The word-count result to summarize.
        """
        if not result.word_counts:
            self.logger.info("No words found")
            return

        heavy_rule = "=" * 60
        self.logger.info(heavy_rule)
        self.logger.info("WordCount Statistics")
        self.logger.info(heavy_rule)
        self.logger.info(f"Total words: {result.total_words:,}")
        self.logger.info(f"Unique words: {result.unique_words:,}")
        self.logger.info(f"Execution time: {result.execution_time_ms:.2f} ms")
        self.logger.info("-" * 60)
        self.logger.info("Top 10 words:")

        for i, (word, count) in enumerate(result.top_words[:10], 1):
            # A hand-built result may leave total_words at 0 while
            # word_counts is populated; avoid dividing by zero.
            percentage = (count / result.total_words) * 100 if result.total_words else 0.0
            self.logger.info(f" {i:2d}. {word:15s} {count:5d} ({percentage:5.1f}%)")

        self.logger.info(heavy_rule)

    def count_words_locally(self, text: str,
                            stop_words: Optional[List[str]] = None,
                            min_word_length: int = 1) -> Dict[str, int]:
        """Count words in pure Python, without Spark.

        Intended for tests and small inputs.

        Args:
            text: Input text, possibly spanning several lines.
            stop_words: Words to exclude. Optional.
            min_word_length: Shortest word length to keep.

        Returns:
            Mapping of word to occurrence count.
        """
        excluded = frozenset(stop_words or ())
        word_counts: Dict[str, int] = defaultdict(int)

        for line in text.split('\n'):
            for word in self._split_line(line):
                if len(word) >= min_word_length and word not in excluded:
                    word_counts[word] += 1

        return dict(word_counts)

    def run_with_files(self, files: List[str],
                       output_path: Optional[str] = None,
                       stop_words: Optional[List[str]] = None,
                       min_word_length: int = 1) -> WordCountResult:
        """Count words across several local files, without Spark.

        Files that cannot be read or decoded as UTF-8 are skipped with a
        warning.

        Args:
            files: Paths of the files to read.
            output_path: When given, the result is saved there as JSON.
            stop_words: Words to exclude. Optional.
            min_word_length: Shortest word length to keep.

        Returns:
            The assembled WordCountResult; ``input_size_bytes`` is the UTF-8
            size of the text that was read successfully.
        """
        started_at = time.perf_counter()

        # Collected in a list and joined once, rather than growing one string.
        contents: List[str] = []
        total_size = 0

        for file_path in files:
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = f.read()
            except (OSError, ValueError) as e:
                # ValueError also covers UnicodeDecodeError.
                self.logger.warning(f"Failed to read file {file_path}: {e}")
                continue
            contents.append(content)
            total_size += len(content.encode('utf-8'))

        counts = self.count_words_locally("\n".join(contents), stop_words, min_word_length)
        wc_result = self._build_result(counts, started_at, total_size)

        if output_path:
            wc_result.save_to_file(output_path, OutputFormat.JSON)

        self._print_statistics(wc_result)

        return wc_result

    # Convenience helpers

    def analyze_text(self, text: str) -> Dict[str, Any]:
        """Return descriptive statistics for a piece of text.

        Args:
            text: Input text.

        Returns:
            A dictionary with the total and unique word counts, lexical
            density (unique / total), average word length, the 20 most
            frequent words, and a histogram of how many words occur once,
            twice, three to ten times, and more than ten times. The ratios
            are 0 for empty input.
        """
        word_counts = self.count_words_locally(text)

        total_words = sum(word_counts.values())
        unique_words = len(word_counts)

        lexical_density = unique_words / total_words if total_words > 0 else 0

        total_chars = sum(len(word) * count for word, count in word_counts.items())
        avg_word_length = total_chars / total_words if total_words > 0 else 0

        distribution = {'once': 0, 'twice': 0, 'three_to_ten': 0, 'more_than_ten': 0}
        for count in word_counts.values():
            if count == 1:
                distribution['once'] += 1
            elif count == 2:
                distribution['twice'] += 1
            elif 3 <= count <= 10:
                distribution['three_to_ten'] += 1
            elif count > 10:
                distribution['more_than_ten'] += 1

        ranked = sorted(word_counts.items(), key=lambda item: item[1], reverse=True)

        return {
            'total_words': total_words,
            'unique_words': unique_words,
            'lexical_density': lexical_density,
            'avg_word_length': avg_word_length,
            'top_words': [{'word': w, 'count': c} for w, c in ranked[:20]],
            'word_frequency_distribution': distribution,
        }
def main():
    """Command-line entry point.

    Usage:
        python wordcount_spark.py [options] <input_path> [output_path]

    Options:
        --local               Count in pure Python instead of on Spark.
        --format <format>     Output format: text, json, csv, parquet, orc
                              (Spark mode only).
        --stop-words <file>   File with one stop word per line.
        --min-length <n>      Minimum word length.
        --json-result <path>  Save the result as JSON to a local file.
        --app-name <name>     Spark application name.
        --master <url>        Spark master URL.

    In ``--local`` mode the result is saved as JSON to ``--json-result`` or,
    when that is absent, to the positional ``output_path``.

    NOTE(review): this module uses package-relative imports, so it presumably
    has to be launched as ``python -m <package>.wordcount_spark`` — confirm.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description='WordCount with PySpark',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # 使用 Spark 集群
  python wordcount_spark.py input.txt output

  # 本地模式
  python wordcount_spark.py --local input.txt output.json

  # 使用 JSON 格式输出
  python wordcount_spark.py --format json input.txt output

  # 使用停用词
  python wordcount_spark.py --stop-words stopwords.txt input.txt
"""
    )

    parser.add_argument('input_path', help='Input path (local or HDFS)')
    parser.add_argument('output_path', nargs='?', help='Output path (optional)')

    parser.add_argument('--local', action='store_true',
                        help='Run in local mode (without Spark cluster)')
    parser.add_argument('--format', choices=['text', 'json', 'csv', 'parquet', 'orc'],
                        default='text', help='Output format (default: text)')
    parser.add_argument('--stop-words', help='Path to stop words file')
    parser.add_argument('--min-length', type=int, default=1,
                        help='Minimum word length (default: 1)')
    parser.add_argument('--json-result', help='Save JSON result to local file')
    parser.add_argument('--app-name', help='Spark application name')
    parser.add_argument('--master', help='Spark master URL')

    args = parser.parse_args()

    # Stop words are optional: a missing or unreadable file produces a
    # warning, not a failure.
    stop_words = None
    if args.stop_words:
        try:
            with open(args.stop_words, 'r', encoding='utf-8') as f:
                stop_words = [line.strip().lower() for line in f if line.strip()]
        except (OSError, ValueError) as e:
            print(f"Warning: Failed to load stop words: {e}", file=sys.stderr)

    wc = WordCountSpark(
        app_name=args.app_name,
        master=args.master
    )

    try:
        if args.local:
            # The positional output path is honored as a fallback, matching
            # the "--local input.txt output.json" example in the help text.
            json_path = args.json_result or args.output_path
            result = wc.run_with_files(
                [args.input_path],
                output_path=json_path,
                stop_words=stop_words,
                min_word_length=args.min_length
            )
        else:
            # run() writes the local JSON file itself, so it is not written a
            # second time afterwards.
            json_path = args.json_result
            result = wc.run(
                input_path=args.input_path,
                output_path=args.output_path,
                output_format=OutputFormat(args.format),
                stop_words=stop_words,
                min_word_length=args.min_length,
                save_local_result=bool(json_path),
                local_result_path=json_path
            )

        print("\n" + "=" * 60)
        print("Final Results")
        print("=" * 60)
        print(f"Total words: {result.total_words:,}")
        print(f"Unique words: {result.unique_words:,}")
        print("\nTop 20 words:")

        for i, (word, count) in enumerate(result.top_words[:20], 1):
            print(f" {i:2d}. {word:15s} {count:5d}")

        print("=" * 60)

        if json_path:
            print(f"\nJSON result saved to: {json_path}")

    finally:
        # Always release the Spark session, even when the job failed.
        wc.stop()
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|