spark.sql
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count, mean, rank, row_number, desc
from pyspark.sql.window import Window
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
# 初始化 SparkSession 对象
spark = SparkSession.builder \
.appName("Example PySpark Script with TempView and SQL") \
.getOrCreate()
# 定义数据结构
schema = StructType([
StructField("name", StringType(), True),
StructField("age", IntegerType(), True),
StructField("city", StringType(), True)
])
# 创建第一个 DataFrame
data1 = [
("Alice", 34, "New York"),
("Bob", 45, "Los Angeles"),
("Cathy", 29, "San Francisco"),
("David", 32, "Chicago"),
("Eve", 27, "Seattle")
]
df1 = spark.createDataFrame(data=data1, schema=schema)
# 创建第二个 DataFrame
data2 = [
("Frank", 30, "New York"),
("Grace", 38, "Los Angeles"),
("Hannah", 25, "San Francisco"),
("Ian", 42, "Chicago"),
("Jack", 28, "Seattle")
]
df2 = spark.createDataFrame(data=data2, schema=schema)
# 查看 DataFrame 结构
df1.printSchema()
df2.printSchema()
# 使用 filter 过滤年龄大于等于 30 的记录
filtered_df1 = df1.filter(col("age") >= 30)
filtered_df2 = df2.filter(col("age") >= 30)
# 使用 group by 计算每个城市的平均年龄
grouped_df1 = filtered_df1.groupBy("city").agg(
count("name").alias("count"),
mean("age").alias("avg_age")
)
grouped_df2 = filtered_df2.groupBy("city").agg(
count("name").alias("count"),
mean("age").alias("avg_age")
)
# 合并两个 DataFrame
merged_df = grouped_df1.union(grouped_df2)
# 从合并后的 DataFrame 中随机抽取 50% 的样本
sampled_df = merged_df.sample(withReplacement=False, fraction=0.5)
# 限制结果集的大小为 10 条记录
limited_df = sampled_df.limit(10)
# 使用窗口函数进行排名
window_spec = Window.partitionBy("city").orderBy(desc("avg_age"))
ranked_df = limited_df.withColumn("rank", rank().over(window_spec)).withColumn("row_number", row_number().over(window_spec))
# 将 DataFrame 注册为临时视图
ranked_df.createOrReplaceTempView("ranked_cities")
# 使用 SQL 查询
sql_query = """
SELECT city, count, avg_age, rank, row_number
FROM ranked_cities
WHERE rank <= 2
"""
# 执行 SQL 查询
sql_results = spark.sql(sql_query)
# 显示结果
sql_results.show(truncate=False)
# 关闭 SparkSession
spark.stop()
在 PySpark 中,createOrReplaceTempView
方法可以将 DataFrame 注册为临时视图(temporary view),这样就可以使用 SQL 查询来操作 DataFrame。临时视图只在当前 SparkSession 的生命周期内有效,并且在同一 SparkSession 中可以被多次替换。
我们可以在之前的示例中加入 createOrReplaceTempView
,以便使用 SQL 查询来完成一些操作。
代码解释
- 创建 DataFrame:定义数据结构,并创建两个 DataFrame。
- 使用
filter
:过滤符合条件的记录。 - 使用
group by
:按字段进行分组聚合。 - 使用
union
:将两个 DataFrame 合并。 - 使用
sample
:从 DataFrame 中随机抽取样本。 - 使用
limit
:限制结果集的大小。 - 使用窗口函数:添加窗口函数来执行复杂的分析。
- 使用
createOrReplaceTempView
:注册临时视图。 - 使用 SQL 查询:执行 SQL 查询。
在 PySpark 中,执行 SQL 查询可能会比直接使用 DataFrame API 慢一些,原因在于以下几个方面:
-
SQL 解析和优化:当使用 SQL 查询时,PySpark 需要解析 SQL 语句,将其转换成逻辑计划,然后进行优化,最终生成物理执行计划。这个过程可能需要一些时间,尤其是在复杂的查询中。
-
Shuffle 操作:如果 SQL 查询涉及 shuffle 操作(例如 group by、join 等),那么数据需要重新分区和排序,这会导致额外的计算开销和磁盘 I/O。在你的例子中,虽然没有涉及 shuffle 操作,但如果查询复杂度增加,shuffle 可能成为瓶颈。
-
数据序列化和反序列化:在执行 SQL 查询时,数据可能需要多次序列化和反序列化,这也会影响性能。
-
执行计划缓存:对于重复执行的查询,执行计划可以被缓存,从而加速后续执行。但是,对于一次性查询,这种缓存带来的好处有限。
-
数据量:如果数据量很大,即使是简单的筛选操作也可能花费一定的时间。
优化建议
为了提高 SQL 查询的性能,可以考虑以下几个优化策略:
-
减少 Shuffle:尽量减少涉及 shuffle 的操作,例如使用广播 join 而不是普通的 join。
-
缓存 DataFrame:如果你反复使用同一个 DataFrame,可以将其缓存(persist 或 cache)以减少重复计算。
-
使用 DataFrame API:尽可能使用 DataFrame API 替代 SQL 查询,因为 DataFrame API 通常更高效。
-
索引:虽然 PySpark 本身没有索引的概念,但可以通过预处理数据来减少查询时的数据扫描范围。
-
调整配置:调整 Spark 的配置参数,例如增加内存分配、调整 shuffle 的参数等。