Speeding up UDFs in PySpark by ~11x
Highlights
Call UDFs packaged in a jar to cut Python-JVM round trips: in a simple benchmark over a 5.4-billion-row dataset, execution time dropped from 3 hours to 16 minutes.
Spend a little more time developing the UDF in exchange for much better runtime performance.
Get close to pure-Scala performance at close to pure-Python development cost, balancing performance and development efficiency.
Background
When the data cannot be handled by SQL alone (e.g. encryption/decryption, or operating on binary via Thrift deserialization), we need user-defined functions (UDFs). For development efficiency, we usually schedule jobs with Airflow and write them as PySpark scripts.
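For context, the pre-optimization approach registers the function as a plain Python UDF. A minimal sketch: the toy function body is an assumption (the real AreaComplexValue logic lives in the jar referenced later), and the registration lines are commented out because they need a live SparkSession.

```python
# Toy stand-in for the matrix-complexity UDF; the real logic is in the Java
# class referenced later, so this body is only illustrative.
def get_area_complex_value(matrix: str) -> int:
    # Count "occupied" cells in a serialized matrix string.
    return sum(1 for ch in (matrix or "") if ch == "1")

# Registering it as a Python UDF (requires a live SparkSession):
# from pyspark.sql.types import IntegerType
# spark.udf.register("get_area_complex_value", get_area_complex_value, IntegerType())
#
# Every row that reaches a Python UDF like this is serialized in the JVM,
# shipped to a Python worker over a socket, evaluated, and shipped back.
```

That per-row round trip is exactly the cost the rest of this article eliminates.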
Optimized code
from datetime import datetime
from pyspark.sql import SparkSession
import sys
if __name__ == "__main__":
    # Create the Spark session
    spark = SparkSession.builder.appName("xxx").enableHiveSupport().getOrCreate()
    DT = sys.argv[1]
    HOUR = sys.argv[2]
    F_DT = datetime.strptime(DT, "%Y-%m-%d").strftime("%Y%m%d")
    tmp_tbl_name = f"temp_xxx_pre_0_{F_DT}_{HOUR}"
    # Register the Java UDFs shipped in the jar
    print("Registering Java UDFs")
    spark.udf.registerJavaFunction("get_area_complex_value", "com.xxx.utf.block.AreaComplexValue")
    spark.udf.registerJavaFunction("most_common_element", "com.xxx.utf.common.MosCommonElement")
    spark.sql("set hive.exec.dynamic.partition.mode=nonstrict")
    exec_sql = '''
insert overwrite table xxx.xxx -- your target table
select
a.distinct_id device_id -- temporary alias
,app_id
,install_datetime
,os
,ip
,country
,city
,uuid
,a.distinct_id
,event_name
,event_timestamp
,event_datetime
,game_id
,game_type
,round_id
,travel_id
,travel_lv
,matrix
,position
,concat('[', concat_ws(',', clean), ']') as clean
,block_id
,index_id
,rec_strategy
,rec_strategy_fact
,combo_cnt
,gain_score
,gain_item
,block_list
,lag_event_timestamp
,(event_timestamp - lag_event_timestamp) / 1000 AS time_diff_in_seconds
,most_common_element( replace( replace(replace(rec_strategy_fact, '[', ''), ']', ''), '"', '' )) as rec_strategy_fact_most
,lag_matrix
,is_clear_screen
,is_blast
,blast_row_col_cnt
,CASE
WHEN size(clean) > 0 THEN
CASE
WHEN (IF(size(lag_clean_3) > 0, 1, 0) + IF(size(lag_clean_2) > 0, 1, 0) + IF(size(lag_clean_1) > 0, 1, 0)) > 0 THEN TRUE
WHEN (IF(size(lag_clean_3) > 0, 1, 0) + IF(size(lag_clean_2) > 0, 1, 0) + IF(size(lag_clean_1) > 0, 1, 0)) = 0
AND combo_cnt = '-1'
AND (IF(size(lead_clean_3) > 0, 1, 0) + IF(size(lead_clean_2) > 0, 1, 0) + IF(size(lead_clean_1) > 0, 1, 0)) > 0 THEN TRUE
ELSE FALSE
END
ELSE
CASE
WHEN (IF(size(lag_clean_2) > 0, 1, 0) + IF(size(lag_clean_1) > 0, 1, 0)) = 0 THEN FALSE
WHEN (IF(size(lag_clean_2) > 0, 1, 0) + IF(size(lag_clean_1) > 0, 1, 0)) = 2 THEN TRUE
WHEN size(lag_clean_1) > 0
AND lag_combo_cnt_1 = -1
AND (IF(size(lead_clean_1) > 0, 1, 0) + IF(size(lead_clean_2) > 0, 1, 0)) >= 1 THEN TRUE
WHEN size(lag_clean_1) > 0
AND lag_combo_cnt_1 > -1
AND (IF(size(lag_clean_3) > 0, 1, 0) + IF(size(lag_clean_4) > 0, 1, 0)) >= 1 THEN TRUE
WHEN size(lag_clean_2) > 0
AND lag_combo_cnt_2 = -1
AND size(lead_clean_1) > 0 THEN TRUE
WHEN size(lag_clean_2) > 0
AND lag_combo_cnt_2 > -1
AND (IF(size(lag_clean_3) > 0, 1, 0) + IF(size(lag_clean_4) > 0, 1, 0) + IF(size(lag_clean_5) > 0, 1, 0)) >= 1 THEN TRUE
ELSE FALSE
END
END AS is_combo_status
,common_block_cnt
,step_score
,block_index_id
,cast(get_area_complex_value(lag_matrix) as int) matrix_complex_value
,-0.1 as block_line_percent
,-0.1 as corner_outside_percent
,-0.1 as corner_inside_percent
,sum((event_timestamp - lag_event_timestamp) / 1000) over(partition by game_id,game_type,distinct_id) time_accumulate_seconds -- double COMMENT 'seconds elapsed from the start of this game to the current block'
,max(cast(round_id as Integer)) over(partition by game_id,game_type,distinct_id ) max_round_id -- max round number; accurate for data from 2024-12-11 onward
,case when round_id=max(cast(round_id as Integer)) over(partition by game_id,game_type,distinct_id ) then true else false end is_final_round -- boolean COMMENT 'whether this round is the final round'; accurate for data from 2024-12-11 onward
,case when round_id=max(cast(round_id as Integer)) over(partition by game_id,game_type,distinct_id ) and block_index_id=max(block_index_id) over(partition by game_id,game_type,distinct_id,round_id ) then true else false end is_lethal_block -- boolean COMMENT 'whether this is the lethal block (the last block placed in the final round)'; accurate for data from 2024-12-11 onward
,sum(step_score) over(partition by game_id,game_type,distinct_id) - step_score lag_accumulate_score -- double COMMENT 'cumulative score before this block'
,sum(step_score) over(partition by game_id,game_type,distinct_id) accumulate_score -- double COMMENT 'cumulative score after this block'
,cast((event_timestamp-last_click_time) / 1000 as int) as time_action_in_seconds -- duration of the block-drop action
,(event_timestamp - lag_event_timestamp) / 1000 - (event_timestamp-last_click_time) / 1000 as time_think_in_seconds -- thinking time before the drop (step time minus action time)
,0 as last_click_time
,cast(`gain_score_per_done` as Integer) as gain_score_per_done
,cast(`is_clean_screen` as Integer) as is_clean_screen
,cast(`weight` as float) as weight
,cast(`put_rate` as float) as put_rate
,0 as userwaynum
,clean_times
,clean_cnt
,sum(clean_times) over(partition by game_id,game_type,distinct_id) as accumulate_clean_times
,sum(clean_cnt) over(partition by game_id,game_type,distinct_id) as accumulate_clean_cnt
,app_version
,ram
,disk
,cast(get_area_complex_value(matrix) as int) cur_matrix_complex_value
,block_down_color
,design_position
,1 as is_sdk_sample
,network_type
,session_id
,block_shape_list
,block_shape
,design_postion_upleft
,fps
,-1 as fact_line
,case when block_id in (1) then 4
when block_id in (2,3) then 6
when block_id in (4 ,5 ,6 ,9 ,15 ,27 ,28 ,37 ,38 ) then 8
when block_id in (7 ,8 ,10 ,14 ,16 ,17 ,18 ,19 ,20 ,25 ,26 ,29 ,30 ,31 ,32 ,33 ,34 ,35 ,36 ,42) then 10
when block_id in (11, 12 ,13 ,21 ,22 ,23 ,24 ,39 ,40 ,41) then 12
end as total_line
,dt
from xxx.xxx a -- your source table
where dt='{DT}' and event_name = 'game_touchend_block_done' and (event_timestamp - lag_event_timestamp)>0
    '''.format(DT=DT, tmp_tbl_name=tmp_tbl_name)
    print(exec_sql)
    spark.sql(exec_sql)
    # Stop the Spark session
    spark.stop()
Under the hood
As the diagram above shows, PySpark does not reimplement a compute engine in Python the way DPark did; it still reuses the Scala/JVM compute engine underneath, and merely uses Py4J to build a bridge over which the Python process and the JVM call each other.
driver
: the PySpark script and the SparkContext's JVM call each other via Py4J;
executor
: because the driver has already wrapped the Spark operators and compiled the execution plan to bytecode, the Python process normally does not participate at all. Only when a UDF (including lambda expressions) must run is the work delegated to a Python process (the BatchEvalPython step in the DAG); at that point the JVM and the Python process communicate over a socket.
Because the PySpark job above uses a simple UDF, its DAG contains a BatchEvalPython step:
The BatchEvalPython stage
Reference source: spark/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala at master · apache/spark · GitHub
As blunt as the name suggests, it simply hands rows to the Python process in batches of 100:
// Around line 58:
// Input iterator to Python: input rows are grouped so we send them in batches to Python.
// For each row, add it to the queue.
val inputIterator = iter.map { row =>
if (needConversion) {
EvaluatePython.toJava(row, schema)
} else {
// fast path for these types that does not need conversion in Python
val fields = new Array[Any](row.numFields)
var i = 0
while (i < row.numFields) {
val dt = dataTypes(i)
fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
i += 1
}
fields
}
}.grouped(100).map(x => pickle.dumps(x.toArray))
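The Scala logic above can be sketched in plain Python: rows are grouped into batches of 100 and each batch is pickled before being sent to the Python worker. This is a simplified model for intuition, not Spark's actual wire protocol.

```python
import pickle

def to_python_batches(rows, batch_size=100):
    """Mimic BatchEvalPythonExec's iter.grouped(100).map(pickle.dumps):
    group rows into fixed-size batches and pickle each batch."""
    batches, batch = [], []
    for row in rows:
        batch.append(row)
        if len(batch) == batch_size:
            batches.append(pickle.dumps(batch))
            batch = []
    if batch:  # flush the final partial batch
        batches.append(pickle.dumps(batch))
    return batches

# 250 rows -> 3 pickled payloads (100 + 100 + 50), i.e. 3 hand-offs to Python
batches = to_python_batches(list(range(250)))
```

Each payload is one socket hand-off to the Python worker, which is why cutting the number of rows that need Python at all pays off so much.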
Since the bottleneck of our jobs is usually computation on the executor side rather than on the driver, we should focus on minimizing how often the executors call into Python code.
Reference source: spark/python/pyspark/java_gateway.py at master · apache/spark · GitHub
# Around line 135:
# Import the classes used by PySpark
java_import(gateway.jvm, "org.apache.spark.SparkConf")
java_import(gateway.jvm, "org.apache.spark.api.java.*")
java_import(gateway.jvm, "org.apache.spark.api.python.*")
java_import(gateway.jvm, "org.apache.spark.ml.python.*")
java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
# TODO(davies): move into sql
java_import(gateway.jvm, "org.apache.spark.sql.*")
java_import(gateway.jvm, "org.apache.spark.sql.api.python.*")
java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
java_import(gateway.jvm, "scala.Tuple2")
PySpark pushes many common operations into the JVM, but obviously not our UDFs.
So a natural idea is to push our UDFs into the JVM as well.
Rewriting the Python UDF in Java
https://github.com/sunlongjiang/adx/blob/master/adx_platform_common/src/main/java/com/hungrystudio/utf/block/AreaComplexValue.java
and reference that jar in the job via --jars:
"--jars", "s3://hungry-studio-data-warehouse/user/sunlj/java_udf/adx_platform_common-4.0.0.jar",
After the rewrite, the job has two fewer transforms than before: the BatchEvalPython step is gone, and so is one WholeStageCodegen step.
Before the optimization the job ran for 3 hours;
after the optimization it ran in 16 minutes. The improvement is dramatic!
Takeaways: in PySpark, prefer Spark operators and Spark SQL, and consolidate UDFs (including lambda expressions) in one place to reduce interaction between the JVM and the Python script.
Since the BatchEvalPython stage processes 100 rows at a time, you can also pack multiple rows into one row to reduce the number of round trips.
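The row-packing idea can be illustrated in plain Python. A counter tracks how many times the UDF fires: calling it once per value costs 300 invocations, while packing 100 values into one delimited row (in real SQL you would do this with functions like concat_ws or collect_list before the UDF) cuts it to 3. The functions here are toy examples, not the production UDFs.

```python
calls = {"n": 0}

def udf_once(value: str) -> int:
    # Stand-in UDF; the counter models one JVM -> Python hand-off per call.
    calls["n"] += 1
    return len(value)

values = [str(i) for i in range(300)]

# Naive: one UDF invocation per row.
_ = [udf_once(v) for v in values]
naive_calls = calls["n"]  # 300 invocations

# Packed: 100 values joined into each row, one invocation per packed row.
calls["n"] = 0
packed_rows = [",".join(values[i:i + 100]) for i in range(0, len(values), 100)]
_ = [udf_once(row) for row in packed_rows]
packed_calls = calls["n"]  # 3 invocations
```

The packed variant does the same total work inside the UDF but crosses the process boundary two orders of magnitude less often.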
Finally, you can rewrite just the UDF part in Java and package it as a jar, keeping the rest as a Python script that can be modified at any time without compilation, balancing performance and development efficiency.