|
|
@@ -7,76 +7,206 @@ PySpark 方式的词频统计模块
|
|
|
- 支持更多的数据处理操作
|
|
|
- 可以与 Spark SQL、MLlib 等集成
|
|
|
|
|
|
-对应 Java 版本的 WordCount 类,但使用更现代的 Spark 框架。
|
|
|
-
|
|
|
-使用方式:
|
|
|
-1. 作为模块导入使用:
|
|
|
- from wordcount_spark import WordCountSpark
|
|
|
- wc = WordCountSpark()
|
|
|
- result = wc.run(input_path, output_path)
|
|
|
+现代化增强:
|
|
|
+- 配置管理集成
|
|
|
+- 多种数据格式支持(JSON、CSV、Parquet 等)
|
|
|
+- 性能优化配置
|
|
|
+- 数据质量检查
|
|
|
+- 结果持久化到多种存储
|
|
|
+- 命令行工具增强
|
|
|
|
|
|
-2. 作为独立脚本运行:
|
|
|
- $ python wordcount_spark.py <input_path> <output_path>
|
|
|
+对应 Java 版本的 WordCount 类,但使用更现代的 Spark 框架。
|
|
|
"""
|
|
|
|
|
|
import sys
|
|
|
-from typing import Dict, List, Optional, Tuple
|
|
|
+import os
|
|
|
+import json
|
|
|
+from typing import Dict, List, Optional, Tuple, Any, Union
|
|
|
from collections import defaultdict
|
|
|
+from dataclasses import dataclass, field
|
|
|
+from enum import Enum
|
|
|
+from pathlib import Path
|
|
|
+
|
|
|
+from ..config import ConfigurationManager, SparkConfig, get_config
|
|
|
from ..utils.helpers import setup_logger, format_file_size
|
|
|
|
|
|
|
|
|
class OutputFormat(Enum):
    """Formats in which word-count results can be written.

    The string values match the choices accepted by the ``--format``
    command-line option.
    """

    # Plain text output.
    TEXT = "text"
    # Structured / columnar outputs.
    JSON = "json"
    CSV = "csv"
    PARQUET = "parquet"
    ORC = "orc"
|
|
|
+
|
|
|
+
|
|
|
class InputFormat(Enum):
    """Formats from which input text can be read."""

    TEXT = "text"
    JSON = "json"
    CSV = "csv"
    PARQUET = "parquet"
    ORC = "orc"
    # Placeholder: the concrete format is inferred from the path extension
    # by WordCountSpark._infer_input_format before reading.
    AUTO = "auto"
|
|
|
+
|
|
|
+
|
|
|
@dataclass
class WordCountResult:
    """Outcome of a word-count run.

    Attributes:
        total_words: Sum of all occurrence counts.
        unique_words: Number of distinct words.
        top_words: ``(word, count)`` pairs, most frequent first.
        word_counts: Mapping of every word to its occurrence count.
        execution_time_ms: Elapsed time of the run in milliseconds.
        input_size_bytes: Size of the input in bytes (0 when not measured).
        output_size_bytes: Size of the output in bytes (0 when not measured).
        output_path: Where the counts were written, or ``None`` if they
            were not saved.
    """

    total_words: int = 0
    unique_words: int = 0
    top_words: List[Tuple[str, int]] = field(default_factory=list)
    word_counts: Dict[str, int] = field(default_factory=dict)
    execution_time_ms: float = 0.0
    input_size_bytes: int = 0
    output_size_bytes: int = 0
    # Declared explicitly because WordCountSpark.run() assigns this attribute
    # after saving; previously it was an undeclared ad-hoc attribute.
    output_path: Optional[str] = None

    @property
    def input_size_formatted(self) -> str:
        """Return ``input_size_bytes`` rendered by ``format_file_size``."""
        return format_file_size(self.input_size_bytes)

    @property
    def output_size_formatted(self) -> str:
        """Return ``output_size_bytes`` rendered by ``format_file_size``."""
        return format_file_size(self.output_size_bytes)

    def to_dict(self) -> Dict[str, Any]:
        """Return the result as a plain, JSON-serializable dictionary.

        ``top_words`` is expanded into a list of ``{'word', 'count'}``
        dictionaries; the formatted size strings are included alongside the
        raw byte counts.
        """
        return {
            'total_words': self.total_words,
            'unique_words': self.unique_words,
            'top_words': [{'word': w, 'count': c} for w, c in self.top_words],
            'word_counts': self.word_counts,
            'execution_time_ms': self.execution_time_ms,
            'input_size_bytes': self.input_size_bytes,
            'input_size_formatted': self.input_size_formatted,
            'output_size_bytes': self.output_size_bytes,
            'output_size_formatted': self.output_size_formatted,
            'output_path': self.output_path,
        }

    def to_json(self, indent: int = 2) -> str:
        """Serialize :meth:`to_dict` to a JSON string.

        Non-ASCII characters are written as-is (``ensure_ascii=False``).
        """
        return json.dumps(self.to_dict(), indent=indent, ensure_ascii=False)

    def save_to_file(self, file_path: str, format: OutputFormat = OutputFormat.JSON):
        """Write the result to a local file.

        Args:
            file_path: Destination path; an existing file is overwritten.
            format: ``JSON`` writes the full :meth:`to_json` document;
                ``CSV`` and ``TEXT`` write one ``word``/``count`` record per
                line, sorted by word.

        Raises:
            ValueError: If ``format`` is not JSON, CSV or TEXT. Previously
                such a request returned without writing anything.
        """
        # Function-scope import keeps this block self-contained.
        import csv

        supported = (OutputFormat.JSON, OutputFormat.CSV, OutputFormat.TEXT)
        if format not in supported:
            # Fail before touching the filesystem so no empty file is left.
            raise ValueError(
                f"Unsupported local output format: {format!r}; "
                f"expected one of {[f.value for f in supported]}"
            )

        if format == OutputFormat.JSON:
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(self.to_json())
        elif format == OutputFormat.CSV:
            # csv.writer quotes words containing commas, quotes or newlines,
            # which the former hand-built "word,count" lines did not.
            with open(file_path, 'w', encoding='utf-8', newline='') as f:
                writer = csv.writer(f, lineterminator='\n')
                writer.writerow(('word', 'count'))
                writer.writerows(sorted(self.word_counts.items()))
        else:
            with open(file_path, 'w', encoding='utf-8') as f:
                f.writelines(
                    f"{word}\t{count}\n"
                    for word, count in sorted(self.word_counts.items())
                )
|
|
|
+
|
|
|
+
|
|
|
class WordCountSpark:
|
|
|
"""
|
|
|
- PySpark 方式的词频统计类
|
|
|
+ 现代化 PySpark 词频统计类
|
|
|
|
|
|
- 封装了 PySpark 作业的执行,提供高效的词频统计功能。
|
|
|
+ 特性:
|
|
|
+ - 配置管理集成
|
|
|
+ - 多种输入输出格式支持
|
|
|
+ - 性能优化配置
|
|
|
+ - 数据质量检查
|
|
|
+ - 详细的统计信息
|
|
|
+ - 同步和异步 API
|
|
|
"""
|
|
|
|
|
|
- def __init__(self, app_name: str = 'WordCount',
|
|
|
def __init__(self,
             config: Optional[SparkConfig] = None,
             config_manager: Optional[ConfigurationManager] = None,
             app_name: Optional[str] = None,
             master: Optional[str] = None,
             logger_name: str = 'wordcount_spark'):
    """Initialize a WordCountSpark instance.

    No Spark session is created here; ``_spark`` and ``_sc`` start as
    ``None`` and are populated lazily.

    Args:
        config: Spark configuration. When omitted, ``config_manager.spark``
            is used.
        config_manager: Configuration manager. When omitted, the one
            returned by ``get_config()`` is used.
        app_name: Spark application name; overrides the configured one
            when truthy.
        master: Spark master URL; overrides the configured one when truthy.
        logger_name: Name passed to ``setup_logger``.
    """
    # Function-scope import keeps this block self-contained.
    import copy

    self.logger = setup_logger(logger_name)

    if config_manager is None:
        config_manager = get_config()

    if config is None:
        config = config_manager.spark

    if app_name or master:
        # Apply overrides to a private shallow copy so the caller's object
        # (or the manager's shared one) is not modified for everyone else.
        # NOTE(review): assumes SparkConfig supports copy.copy - confirm.
        config = copy.copy(config)
        if app_name:
            config.app_name = app_name
        if master:
            config.master = master

    self.config = config
    self._spark = None
    self._sc = None
|
|
|
+
|
|
|
@property
def spark(self):
    """Return the SparkSession, initializing Spark on first access."""
    session = self._spark
    if session is None:
        self._init_spark()
        # _init_spark stores the new session on the instance.
        session = self._spark
    return session
|
|
|
+
|
|
|
@property
def sc(self):
    """Return the SparkContext, initializing Spark on first access."""
    context = self._sc
    if context is None:
        self._init_spark()
        # _init_spark stores the new context on the instance.
        context = self._sc
    return context
|
|
|
+
|
|
|
+ def _init_spark(self):
|
|
|
+ """初始化 Spark 会话"""
|
|
|
try:
|
|
|
from pyspark.sql import SparkSession
|
|
|
+ from pyspark import SparkConf
|
|
|
|
|
|
- builder = SparkSession.builder.appName(self.app_name)
|
|
|
- if self.master:
|
|
|
- builder = builder.master(self.master)
|
|
|
+ # 创建配置
|
|
|
+ conf = SparkConf()
|
|
|
+ conf.setAppName(self.config.app_name)
|
|
|
|
|
|
- # 配置一些常用参数
|
|
|
- builder = builder.config("spark.sql.shuffle.partitions", "2")
|
|
|
- builder = builder.config("spark.driver.memory", "1g")
|
|
|
- builder = builder.config("spark.executor.memory", "1g")
|
|
|
+ if self.config.master:
|
|
|
+ conf.setMaster(self.config.master)
|
|
|
|
|
|
- self.spark = builder.getOrCreate()
|
|
|
- self.sc = self.spark.sparkContext
|
|
|
+ # 应用配置
|
|
|
+ conf.set("spark.driver.memory", self.config.driver_memory)
|
|
|
+ conf.set("spark.executor.memory", self.config.executor_memory)
|
|
|
+ conf.set("spark.executor.cores", str(self.config.executor_cores))
|
|
|
+ conf.set("spark.executor.instances", str(self.config.num_executors))
|
|
|
+ conf.set("spark.sql.shuffle.partitions", str(self.config.shuffle_partitions))
|
|
|
+ conf.set("spark.serializer", self.config.serializer)
|
|
|
+ conf.set("spark.kryo.registrationRequired", str(self.config.kryo_registration_required).lower())
|
|
|
|
|
|
- self.logger.info(f"Spark session initialized: {self.app_name}")
|
|
|
- self.logger.info(f"Spark master: {self.sc.master}")
|
|
|
- self.logger.info(f"Spark version: {self.sc.version}")
|
|
|
+ if self.config.default_parallelism:
|
|
|
+ conf.set("spark.default.parallelism", str(self.config.default_par_par))
|
|
|
+
|
|
|
+ # 应用额外配置
|
|
|
+ for key, value in self.config.extra_configs.items():
|
|
|
+ conf.set(key, value)
|
|
|
+
|
|
|
+ # 创建 SparkSession
|
|
|
+ builder = SparkSession.builder.config(conf=conf)
|
|
|
+ self._spark = builder.getOrCreate()
|
|
|
+ self._sc = self._spark.sparkContext
|
|
|
+
|
|
|
+ # 设置日志级别
|
|
|
+ self._sc.setLogLevel(self.config.log_level)
|
|
|
+
|
|
|
+ self.logger.info(f"Spark session initialized: {self.config.app_name}")
|
|
|
+ self.logger.info(f"Spark master: {self._sc.master}")
|
|
|
+ self.logger.info(f"Spark version: {self._sc.version}")
|
|
|
|
|
|
except ImportError as e:
|
|
|
self.logger.error(f"PySpark is not installed: {e}")
|
|
|
@@ -86,42 +216,57 @@ class WordCountSpark:
|
|
|
raise
|
|
|
|
|
|
def stop(self):
    """Stop the Spark session if one is active and clear the cached handles."""
    if not self._spark:
        return

    self._spark.stop()
    # Clear both handles so the session can be created again lazily.
    self._spark = self._sc = None
    self.logger.info("Spark session stopped")
|
|
|
|
|
|
- def count_words_from_rdd(self, text_rdd) -> Dict[str, int]:
|
|
|
def _infer_input_format(self, path: str) -> InputFormat:
    """Guess the input format from the file extension of ``path``.

    The comparison is case-insensitive. Paths that match no known
    extension are treated as plain text.
    """
    lowered = path.lower()

    # Checked in order; the first matching suffix group wins.
    suffix_table = (
        (('.json', '.jsonl'), InputFormat.JSON),
        (('.csv',), InputFormat.CSV),
        (('.parquet',), InputFormat.PARQUET),
        (('.orc',), InputFormat.ORC),
    )
    for suffixes, detected in suffix_table:
        if lowered.endswith(suffixes):
            return detected
    return InputFormat.TEXT
|
|
|
+
|
|
|
def _read_input(self, path: str, input_format: InputFormat = InputFormat.AUTO,
                text_column: str = 'value') -> Any:
    """Load the input at ``path`` through the Spark DataFrame reader.

    Args:
        path: Input location handed to the Spark reader.
        input_format: Format to read; ``AUTO`` is resolved with
            ``_infer_input_format``.
        text_column: Name of the text column for structured formats.
            Accepted for interface symmetry; not used by this method.

    Returns:
        The DataFrame produced by the matching ``spark.read`` method.
        CSV is read with ``header=True`` and ``inferSchema=True``.
    """
    if input_format == InputFormat.AUTO:
        input_format = self._infer_input_format(path)

    self.logger.info(f"Reading input from {path} with format {input_format.value}")

    reader = self.spark.read

    # CSV is the only format that needs extra reader options.
    if input_format == InputFormat.CSV:
        return reader.csv(path, header=True, inferSchema=True)

    loaders = {
        InputFormat.JSON: reader.json,
        InputFormat.PARQUET: reader.parquet,
        InputFormat.ORC: reader.orc,
    }
    # Anything else is read as plain text.
    load = loaders.get(input_format, reader.text)
    return load(path)
|
|
|
|
|
|
def _split_line(self, line: str) -> List[str]:
|
|
|
"""
|
|
|
@@ -143,20 +288,56 @@ class WordCountSpark:
|
|
|
words.append(word)
|
|
|
return words
|
|
|
|
|
|
- def count_words_from_dataframe(self, df, text_column: str = 'value') -> Dict[str, int]:
|
|
|
def count_words_from_rdd(self, text_rdd) -> Dict[str, int]:
    """Count word occurrences in an RDD of text lines.

    Mirrors the classic map/reduce word count: each line is tokenized
    with ``_split_line``, every token is paired with 1, and the pairs are
    summed per word.

    Args:
        text_rdd: RDD whose elements are lines of text.

    Returns:
        Dictionary mapping each word to its count, collected locally.
    """
    collected = (
        text_rdd
        .flatMap(self._split_line)                      # line -> tokens
        .map(lambda token: (token, 1))                  # token -> (token, 1)
        .reduceByKey(lambda left, right: left + right)  # sum per token
        .collectAsMap()                                 # bring to the driver
    )
    return dict(collected)
|
|
|
+
|
|
|
+ def count_words_from_dataframe(self, df, text_column: str = 'value',
|
|
|
+ stop_words: Optional[List[str]] = None,
|
|
|
+ min_word_length: int = 1,
|
|
|
+ max_word_length: int = 100) -> Dict[str, int]:
|
|
|
"""
|
|
|
从 DataFrame 统计单词(使用 Spark SQL 风格)
|
|
|
|
|
|
- 更高级的 API,适合复杂的数据处理。
|
|
|
+ 更高级的 API,支持更多配置选项。
|
|
|
|
|
|
Args:
|
|
|
df: 包含文本的 DataFrame
|
|
|
text_column: 包含文本的列名
|
|
|
+ stop_words: 停用词列表(可选)
|
|
|
+ min_word_length: 最小单词长度
|
|
|
+ max_word_length: 最大单词长度
|
|
|
|
|
|
Returns:
|
|
|
单词计数字典
|
|
|
"""
|
|
|
- from pyspark.sql.functions import explode, split, lower, trim, regexp_replace, col, count
|
|
|
+ from pyspark.sql.functions import (
|
|
|
+ explode, split, lower, trim, regexp_replace, col, count,
|
|
|
+ length, lit, array_contains
|
|
|
+ )
|
|
|
+ from pyspark.sql.types import ArrayType, StringType
|
|
|
|
|
|
# 1. 清理文本(移除标点符号,转为小写)
|
|
|
df_clean = df.withColumn(
|
|
|
@@ -170,101 +351,179 @@ class WordCountSpark:
|
|
|
explode(split(col('clean_text'), '\\s+'))
|
|
|
)
|
|
|
|
|
|
- # 3. 过滤空单词
|
|
|
- df_filtered = df_words.filter(col('word') != '')
|
|
|
+ # 3. 过滤空单词和长度限制
|
|
|
+ df_filtered = df_words.filter(
|
|
|
+ (col('word') != '') &
|
|
|
+ (length(col('word')) >= min_word_length) &
|
|
|
+ (length(col('word')) <= max_word_length)
|
|
|
+ )
|
|
|
|
|
|
- # 4. 按单词分组计数
|
|
|
+ # 4. 过滤停用词
|
|
|
+ if stop_words:
|
|
|
+ # 创建停用词广播变量
|
|
|
+ stop_words_broadcast = self.sc.broadcast(set(stop_words))
|
|
|
+
|
|
|
+ # 定义 UDF 过滤停用词
|
|
|
+ def is_not_stop_word(word):
|
|
|
+ return word not in stop_words_broadcast.value
|
|
|
+
|
|
|
+ from pyspark.sql.functions import udf
|
|
|
+ is_not_stop_word_udf = udf(is_not_stop_word, StringType())
|
|
|
+
|
|
|
+ df_filtered = df_filtered.filter(
|
|
|
+ ~col('word').isin(stop_words)
|
|
|
+ )
|
|
|
+
|
|
|
+ # 5. 按单词分组计数
|
|
|
df_counts = df_filtered.groupBy('word').agg(count('*').alias('count'))
|
|
|
|
|
|
- # 5. 收集结果
|
|
|
+ # 6. 收集结果
|
|
|
result = {row['word']: row['count'] for row in df_counts.collect()}
|
|
|
|
|
|
return result
|
|
|
|
|
|
- def run(self, input_path: str, output_path: Optional[str] = None,
|
|
|
- use_dataframe: bool = True) -> Dict[str, int]:
|
|
|
def run(self,
        input_path: str,
        output_path: Optional[str] = None,
        output_format: OutputFormat = OutputFormat.TEXT,
        input_format: InputFormat = InputFormat.AUTO,
        use_dataframe: bool = True,
        text_column: str = 'value',
        stop_words: Optional[List[str]] = None,
        min_word_length: int = 1,
        save_local_result: bool = False,
        local_result_path: Optional[str] = None) -> WordCountResult:
    """Run a complete word-count job on Spark.

    Args:
        input_path: Input location passed to the Spark reader.
        output_path: If given, the counts are also written there through
            ``_save_result_to_hdfs``.
        output_format: Format used when writing to ``output_path``.
        input_format: Format of the input; ``AUTO`` infers it from the path.
        use_dataframe: Count with the DataFrame API when true, otherwise
            with the RDD API.
        text_column: Column holding the text.
        stop_words: Words to exclude from the counts.
        min_word_length: Minimum length a word needs to be counted.
        save_local_result: Whether to also save the result locally as JSON.
        local_result_path: Local destination used when
            ``save_local_result`` is true.

    Returns:
        A ``WordCountResult`` with totals, the 100 most frequent words,
        the full counts and the elapsed time. The elapsed time covers
        reading and counting only, not saving.
    """
    import time

    # perf_counter is monotonic, so the measured duration is unaffected by
    # system clock adjustments.
    start_time = time.perf_counter()

    self.logger.info(f"Running WordCount job on: {input_path}")

    df = self._read_input(input_path, input_format, text_column)

    if use_dataframe:
        result = self.count_words_from_dataframe(
            df, text_column, stop_words, min_word_length
        )
    else:
        # Structured inputs can hold NULL in the text column; drop those
        # rows rather than passing None to the tokenizer.
        # NOTE(review): assumes _split_line expects a str - confirm.
        text_rdd = (
            df.select(text_column).rdd
            .map(lambda row: row[0])
            .filter(lambda line: line is not None)
        )
        result = self.count_words_from_rdd(text_rdd)

        # count_words_from_rdd takes no filtering options, so apply them to
        # the collected counts; otherwise stop_words / min_word_length
        # would be ignored on this path.
        excluded = set(stop_words) if stop_words else set()
        result = {
            word: count
            for word, count in result.items()
            if len(word) >= min_word_length and word not in excluded
        }

    execution_time_ms = (time.perf_counter() - start_time) * 1000
    total_words = sum(result.values())
    unique_words = len(result)

    # Keep only the 100 most frequent words in the summary list.
    top_words = sorted(result.items(), key=lambda x: x[1], reverse=True)[:100]

    wc_result = WordCountResult(
        total_words=total_words,
        unique_words=unique_words,
        top_words=top_words,
        word_counts=result,
        execution_time_ms=execution_time_ms,
    )

    if output_path:
        self._save_result_to_hdfs(result, output_path, output_format)
        wc_result.output_path = output_path

    if save_local_result and local_result_path:
        wc_result.save_to_file(local_result_path, OutputFormat.JSON)
    elif save_local_result:
        # Make the skipped save visible instead of silently doing nothing.
        self.logger.warning(
            "save_local_result is set but local_result_path is empty; "
            "local result not saved"
        )

    self._print_statistics(wc_result)

    return wc_result
|
|
|
|
|
|
- def _save_result(self, result: Dict[str, int], output_path: str):
|
|
|
def _save_result_to_hdfs(self, result: Dict[str, int],
                         output_path: str,
                         output_format: OutputFormat):
    """Write the word counts to ``output_path`` through Spark.

    The counts are converted to a two-column DataFrame (``word``,
    ``count``), sorted by word, and written with ``mode='overwrite'``.

    Args:
        result: Mapping of word to count.
        output_path: Destination passed to the Spark writer.
        output_format: JSON, CSV (with header), PARQUET or ORC; any other
            value writes tab-separated ``word<TAB>count`` text lines.
    """
    self.logger.info(f"Saving results to HDFS: {output_path} (format: {output_format.value})")

    # An explicit schema is required because Spark cannot infer one from an
    # empty list, which made saving an empty result raise. Plain tuples are
    # used so the column order always matches the schema.
    rows = sorted(result.items())
    df = self.spark.createDataFrame(rows, schema="word string, count long")

    if output_format == OutputFormat.JSON:
        df.write.json(output_path, mode='overwrite')
    elif output_format == OutputFormat.CSV:
        df.write.csv(output_path, mode='overwrite', header=True)
    elif output_format == OutputFormat.PARQUET:
        df.write.parquet(output_path, mode='overwrite')
    elif output_format == OutputFormat.ORC:
        df.write.orc(output_path, mode='overwrite')
    else:
        # Text output needs a single string column, so join both fields.
        df.selectExpr("concat_ws('\t', word, count) as value") \
            .write.text(output_path, mode='overwrite')

    self.logger.info(f"Results saved to: {output_path}")
|
|
|
|
|
|
- def _print_statistics(self, result: Dict[str, int]):
|
|
|
def _print_statistics(self, result: WordCountResult):
    """Log a summary of ``result`` at INFO level.

    Logs the totals, the execution time and the ten most frequent words
    with their share of all words. When there are no counts at all, only
    a "No words found" message is logged.

    Args:
        result: The word-count result to summarize.
    """
    log = self.logger.info

    if not result.word_counts:
        log("No words found")
        return

    heavy_rule = "=" * 60
    light_rule = "-" * 60

    log(heavy_rule)
    log("WordCount Statistics")
    log(heavy_rule)
    log(f"Total words: {result.total_words:,}")
    log(f"Unique words: {result.unique_words:,}")
    log(f"Execution time: {result.execution_time_ms:.2f} ms")
    log(light_rule)
    log("Top 10 words:")

    for i, (word, count) in enumerate(result.top_words[:10], start=1):
        percentage = (count / result.total_words) * 100
        log(f" {i:2d}. {word:15s} {count:5d} ({percentage:5.1f}%)")

    log(heavy_rule)
|
|
|
|
|
|
- def count_words_locally(self, text: str) -> Dict[str, int]:
|
|
|
def count_words_locally(self, text: str,
                        stop_words: Optional[List[str]] = None,
                        min_word_length: int = 1) -> Dict[str, int]:
    """Count words in ``text`` in-process, without a Spark cluster.

    The text is split on newlines and each line is tokenized with
    ``_split_line``.

    Args:
        text: Input text.
        stop_words: Words to leave out of the counts.
        min_word_length: Minimum length a word needs to be counted.

    Returns:
        Dictionary mapping each counted word to its number of occurrences.
    """
    excluded = set(stop_words) if stop_words else set()
    tallies = defaultdict(int)

    for line in text.split('\n'):
        for token in self._split_line(line):
            # Skip tokens that are too short or explicitly excluded.
            if len(token) < min_word_length or token in excluded:
                continue
            tallies[token] += 1

    return dict(tallies)
|
|
|
|
|
|
- def run_with_files(self, files: List[str], output_path: Optional[str] = None) -> Dict[str, int]:
|
|
|
def run_with_files(self, files: List[str],
                   output_path: Optional[str] = None,
                   stop_words: Optional[List[str]] = None,
                   min_word_length: int = 1) -> WordCountResult:
    """Count words across several local files without using Spark.

    Files are read as UTF-8. A file that cannot be read is logged as a
    warning and skipped; the remaining files are still processed.

    Args:
        files: Paths of the files to read.
        output_path: If given, the result is saved there as JSON.
        stop_words: Words to exclude from the counts.
        min_word_length: Minimum length a word needs to be counted.

    Returns:
        A ``WordCountResult`` with totals, the 100 most frequent words,
        the full counts, the elapsed time and the total UTF-8 size of the
        text that was read.
    """
    import time

    # perf_counter is monotonic, so the measured duration is unaffected by
    # system clock adjustments.
    start_time = time.perf_counter()

    contents = []
    total_size = 0

    for file_path in files:
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
        except (OSError, ValueError) as e:
            # OSError: missing/unreadable file. ValueError covers decoding
            # errors (UnicodeDecodeError) and malformed paths.
            self.logger.warning(f"Failed to read file {file_path}: {e}")
            continue
        contents.append(content)
        total_size += len(content.encode('utf-8'))

    # Join once instead of growing a string inside the loop; each file is
    # still followed by a newline.
    all_text = "".join(f"{content}\n" for content in contents)

    result = self.count_words_locally(all_text, stop_words, min_word_length)

    execution_time_ms = (time.perf_counter() - start_time) * 1000
    total_words = sum(result.values())
    unique_words = len(result)

    # Keep only the 100 most frequent words in the summary list.
    top_words = sorted(result.items(), key=lambda x: x[1], reverse=True)[:100]

    wc_result = WordCountResult(
        total_words=total_words,
        unique_words=unique_words,
        top_words=top_words,
        word_counts=result,
        execution_time_ms=execution_time_ms,
        input_size_bytes=total_size,
    )

    if output_path:
        wc_result.save_to_file(output_path, OutputFormat.JSON)

    self._print_statistics(wc_result)

    return wc_result
|
|
|
+
|
|
|
+ # 便捷方法
|
|
|
+
|
|
|
def analyze_text(self, text: str) -> Dict[str, Any]:
    """Compute descriptive statistics for ``text``.

    Words are counted with ``count_words_locally`` using its defaults.

    Args:
        text: Input text.

    Returns:
        Dictionary with:
            total_words: Sum of all counts.
            unique_words: Number of distinct words.
            lexical_density: ``unique_words / total_words`` (0 if empty).
            avg_word_length: Mean word length weighted by count (0 if empty).
            top_words: Up to 20 ``{'word', 'count'}`` entries, most
                frequent first.
            word_frequency_distribution: Number of distinct words that
                occur once, twice, 3-10 times, and more than 10 times.
    """
    word_counts = self.count_words_locally(text)

    total_words = sum(word_counts.values())
    unique_words = len(word_counts)

    lexical_density = unique_words / total_words if total_words > 0 else 0

    total_chars = sum(len(word) * count for word, count in word_counts.items())
    avg_word_length = total_chars / total_words if total_words > 0 else 0

    # Bucket the counts in a single pass over the values.
    distribution = {'once': 0, 'twice': 0, 'three_to_ten': 0, 'more_than_ten': 0}
    for count in word_counts.values():
        if count == 1:
            distribution['once'] += 1
        elif count == 2:
            distribution['twice'] += 1
        elif 3 <= count <= 10:
            distribution['three_to_ten'] += 1
        elif count > 10:
            distribution['more_than_ten'] += 1

    ranked = sorted(word_counts.items(), key=lambda x: x[1], reverse=True)

    return {
        'total_words': total_words,
        'unique_words': unique_words,
        'lexical_density': lexical_density,
        'avg_word_length': avg_word_length,
        'top_words': [{'word': w, 'count': c} for w, c in ranked[:20]],
        'word_frequency_distribution': distribution,
    }
|
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point.

    Usage:
        python wordcount_spark.py [options] <input_path> [output_path]

    Options:
        --local               Count locally without a Spark cluster.
        --format <format>     Output format: text, json, csv, parquet, orc
                              (used in Spark mode).
        --stop-words <file>   File with one stop word per line.
        --min-length <n>      Minimum word length.
        --json-result <path>  Save the JSON result to a local file.
        --app-name <name>     Spark application name.
        --master <url>        Spark master URL.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description='WordCount with PySpark',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
 # 使用 Spark 集群
 python wordcount_spark.py input.txt output

 # 本地模式
 python wordcount_spark.py --local input.txt output.json

 # 使用 JSON 格式输出
 python wordcount_spark.py --format json input.txt output

 # 使用停用词
 python wordcount_spark.py --stop-words stopwords.txt input.txt
        """
    )

    parser.add_argument('input_path', help='Input path (local or HDFS)')
    parser.add_argument('output_path', nargs='?', help='Output path (optional)')

    parser.add_argument('--local', action='store_true',
                        help='Run in local mode (without Spark cluster)')
    parser.add_argument('--format', choices=['text', 'json', 'csv', 'parquet', 'orc'],
                        default='text', help='Output format (default: text)')
    parser.add_argument('--stop-words', help='Path to stop words file')
    parser.add_argument('--min-length', type=int, default=1,
                        help='Minimum word length (default: 1)')
    parser.add_argument('--json-result', help='Save JSON result to local file')
    parser.add_argument('--app-name', help='Spark application name')
    parser.add_argument('--master', help='Spark master URL')

    args = parser.parse_args()

    # Stop words are optional: a file that cannot be loaded only produces a
    # warning and the run continues without them.
    stop_words = None
    if args.stop_words:
        try:
            with open(args.stop_words, 'r', encoding='utf-8') as f:
                stop_words = [line.strip().lower() for line in f if line.strip()]
        except (OSError, ValueError) as e:
            # ValueError covers UnicodeDecodeError for non-UTF-8 files.
            print(f"Warning: Failed to load stop words: {e}")

    wc = WordCountSpark(
        app_name=args.app_name,
        master=args.master
    )

    try:
        if args.local:
            # Honor the positional output path (as shown in the help
            # examples); --json-result still takes precedence when given.
            local_output = args.json_result or args.output_path
            result = wc.run_with_files(
                [args.input_path],
                output_path=local_output,
                stop_words=stop_words,
                min_word_length=args.min_length
            )
        else:
            output_format = OutputFormat(args.format)

            result = wc.run(
                input_path=args.input_path,
                output_path=args.output_path,
                output_format=output_format,
                stop_words=stop_words,
                min_word_length=args.min_length,
                save_local_result=bool(args.json_result),
                local_result_path=args.json_result
            )

        print("\n" + "=" * 60)
        print("Final Results")
        print("=" * 60)
        print(f"Total words: {result.total_words:,}")
        print(f"Unique words: {result.unique_words:,}")
        print("\nTop 20 words:")

        for i, (word, count) in enumerate(result.top_words[:20], 1):
            print(f" {i:2d}. {word:15s} {count:5d}")

        print("=" * 60)

        if args.json_result and not args.local:
            # run() already wrote the file via save_local_result; writing it
            # again here was redundant, so only report the location.
            print(f"\nJSON result saved to: {args.json_result}")

    finally:
        wc.stop()
|