当前位置: 首页 > article >正文

将pyspark中的UDF提升6倍

本文亮点

      调用jar中的UDF,减少python与JVM的交互,简单banchmark下对于54亿条数据集进行udf计算,从3小时的执行时间缩短至16分钟。
牺牲UDF部分的开发时间,尽量提高性能。
以接近纯python的开发成本,获得逼近纯scala的性能。兼顾性能和开发效率。

前提

       当遇到sql无法直接处理的数据时(比如加密解密、thrift解析操作二进制),我们需要自定义函数(UDF)来进行处理。出于开发效率的考虑,我们一般会选择airflow,使用pyspark脚本。

优化后的代码

from datetime import datetime
from pyspark.sql import SparkSession
import sys

if __name__ == "__main__":
    # 创建 Spark 会话
    spark = SparkSession.builder.appName("xxx").enableHiveSupport().getOrCreate()

    DT = sys.argv[1]
    HOUR = sys.argv[2]
    F_DT = datetime.strptime(DT, "%Y-%m-%d").strftime("%Y%m%d")
    tmp_tbl_name = f"""temp_xxx_pre_0_{F_DT}_{HOUR}"""

    print('''注册java函数''')
    spark.udf.registerJavaFunction("get_area_complex_value", "com.xxx.utf.block.AreaComplexValue")
    spark.udf.registerJavaFunction("most_common_element", "com.xxx.utf.common.MosCommonElement")
    spark.sql('''set hive.exec.dynamic.partition.mode=nonstrict;''')
    exec_sql = '''
  insert overwrite table xxx.xxx --你的目标表
			select
			a.distinct_id device_id -- l临时
			,app_id
			,install_datetime
			,os
			,ip
			,country
			,city
			,uuid
			,a.distinct_id
			,event_name
			,event_timestamp
			,event_datetime
			,game_id
			,game_type
			,round_id
			,travel_id
			,travel_lv
			,matrix
			,position
			,concat('[', concat_ws(',', clean), ']') as clean
			,block_id
			,index_id
			,rec_strategy
			,rec_strategy_fact
			,combo_cnt
			,gain_score
			,gain_item
			,block_list
			,lag_event_timestamp
			,(event_timestamp - lag_event_timestamp) / 1000 AS time_diff_in_seconds
			,most_common_element( replace( replace(replace(rec_strategy_fact, '[', ''), ']', ''), '"', '' )) as rec_strategy_fact_most
			,lag_matrix
			,is_clear_screen
			,is_blast
			,blast_row_col_cnt
			,CASE 
                WHEN size(clean) > 0 THEN
                    CASE 
                        WHEN (IF(size(lag_clean_3) > 0, 1, 0) + IF(size(lag_clean_2) > 0, 1, 0) + IF(size(lag_clean_1) > 0, 1, 0)) > 0 THEN TRUE
                        WHEN (IF(size(lag_clean_3) > 0, 1, 0) + IF(size(lag_clean_2) > 0, 1, 0) + IF(size(lag_clean_1) > 0, 1, 0)) = 0
                        AND combo_cnt = '-1'
                        AND (IF(size(lead_clean_3) > 0, 1, 0) + IF(size(lead_clean_2) > 0, 1, 0) + IF(size(lead_clean_1) > 0, 1, 0)) > 0 THEN TRUE
                        ELSE FALSE
                    END
                ELSE
                    CASE 
                        WHEN (IF(size(lag_clean_2) > 0, 1, 0) + IF(size(lag_clean_1) > 0, 1, 0)) = 0 THEN FALSE
                        WHEN (IF(size(lag_clean_2) > 0, 1, 0) + IF(size(lag_clean_1) > 0, 1, 0)) = 2 THEN TRUE
                        WHEN size(lag_clean_1) > 0
                            AND lag_combo_cnt_1 = -1
                            AND (IF(size(lead_clean_1) > 0, 1, 0) + IF(size(lead_clean_2) > 0, 1, 0)) >= 1 THEN TRUE
                        WHEN size(lag_clean_1) > 0
                            AND lag_combo_cnt_1 > -1
                            AND (IF(size(lag_clean_3) > 0, 1, 0) + IF(size(lag_clean_4) > 0, 1, 0)) >= 1 THEN TRUE
                        WHEN size(lag_clean_2) > 0
                            AND lag_combo_cnt_2 = -1
                            AND size(lead_clean_1) > 0 THEN TRUE
                        WHEN size(lag_clean_2) > 0
                            AND lag_combo_cnt_2 > -1
                            AND (IF(size(lag_clean_3) > 0, 1, 0) + IF(size(lag_clean_4) > 0, 1, 0) + IF(size(lag_clean_5) > 0, 1, 0)) >= 1 THEN TRUE
                        ELSE FALSE
                    END
                END AS is_combo_status
			,common_block_cnt
			,step_score
			,block_index_id
			,cast(get_area_complex_value(lag_matrix) as int) matrix_complex_value
			,-0.1 as block_line_percent
			,-0.1 as corner_outside_percent
			,-0.1 as corner_inside_percent
			,sum((event_timestamp - lag_event_timestamp) / 1000) over(partition by game_id,game_type,distinct_id) time_accumulate_seconds --  double COMMENT '当前块距离本局游戏的开始的时长,秒级别的'
			,max(cast(round_id as Integer)) over(partition by game_id,game_type,distinct_id ) max_round_id -- 最大轮数2024-12-11数据开始准确
			,case when round_id=max(cast(round_id as Integer)) over(partition by game_id,game_type,distinct_id ) then true else false end  is_final_round -- boolean COMMENT '此轮是不是最后一轮' 最后一轮 2024-12-11数据开始准确
			,case when round_id=max(cast(round_id as Integer)) over(partition by game_id,game_type,distinct_id ) and block_index_id=max(block_index_id) over(partition by game_id,game_type,distinct_id,round_id )  then true 			else false end is_lethal_block -- boolean COMMENT '是不是致死块' 最后一轮里面的最后一个放块 2024-12-11数据开始准确
			,sum(step_score) over(partition by game_id,game_type,distinct_id) - step_score lag_accumulate_score -- double COMMENT '出此块前的累计分数'
			,sum(step_score) over(partition by game_id,game_type,distinct_id) accumulate_score -- double COMMENT '出此块后的累计分数'
			,cast((event_timestamp-last_click_time) / 1000 as int) as time_action_in_seconds -- 落块动作的时间
			,(event_timestamp - lag_event_timestamp) / 1000	- (event_timestamp-last_click_time) / 1000 as time_think_in_seconds --落块-思考时间
			,0 as last_click_time
			,cast(`gain_score_per_done` as Integer) as gain_score_per_done
			,cast(`is_clean_screen` as Integer) as is_clean_screen
			,cast(`weight` as float) as weight
			,cast(`put_rate` as float) as put_rate
			,0 as userwaynum
			,clean_times
			,clean_cnt
			,sum(clean_times) over(partition by game_id,game_type,distinct_id)   as accumulate_clean_times
			,sum(clean_cnt) over(partition by game_id,game_type,distinct_id)     as accumulate_clean_cnt
			,app_version
			,ram
			,disk
			,cast(get_area_complex_value(matrix) as int) cur_matrix_complex_value
			,block_down_color
			,design_position
			,1 as is_sdk_sample
			,network_type
			,session_id
			,block_shape_list
			,block_shape
			,design_postion_upleft
			,fps
			,-1 as fact_line
		    ,case when block_id in (1)  then 4  
                  when block_id in (2,3) then 6  
                  when block_id in (4 ,5 ,6 ,9 ,15 ,27 ,28 ,37 ,38 ) then 8 
                  when block_id in  (7 ,8 ,10 ,14 ,16 ,17 ,18 ,19 ,20 ,25 ,26 ,29 ,30 ,31 ,32 ,33 ,34 ,35 ,36 ,42)   then 10 
                  when block_id in (11, 12 ,13 ,21 ,22 ,23 ,24 ,39 ,40 ,41)   then 12 
             end as total_line
			,dt
			from xxx.xxx a --你的源表
			where dt='{DT}' and event_name = 'game_touchend_block_done' and (event_timestamp - lag_event_timestamp)>0;
     '''.format(DT=DT, tmp_tbl_name=tmp_tbl_name)
    print(exec_sql)
    spark.sql(exec_sql)

    # 关闭 Spark 会话
    spark.stop()

低层实现原理

如上图所示,pyspark并没有像dpark一样用python重新实现一个计算引擎,依旧是复用了scala的jvm计算底层,只是用py4j架设了一条python进程和jvm互相调用的桥梁。
driver: pyspark脚本和sparkContext的jvm使用py4j相互调用;
executor: 由于driver帮忙把spark算子封装好了,执行计划也生成了字节码,一般情况下不需要python进程参与,仅当需要运行UDF(含lambda表达式形式)时,将它委托给python进程处理(DAG图中的BatchEvalPython步骤),此时JVM和python进程使用socket通信。

上述使用简单UDF时的pyspark由于需要使用UDF,因此DAG图中有BatchEvalPython步骤:

BatchEvalPython过程

参考源码:spark/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala at master · apache/spark · GitHub
可以看到和这个名字一样直白,它就是每次取100条数据让python进程帮忙处理一下:

// 第58行:
// Input iterator to Python: input rows are grouped so we send them in batches to Python.
    // For each row, add it to the queue.
    val inputIterator = iter.map { row =>
      if (needConversion) {
        EvaluatePython.toJava(row, schema)
      } else {
        // fast path for these types that does not need conversion in Python
        val fields = new Array[Any](row.numFields)
        var i = 0
        while (i < row.numFields) {
          val dt = dataTypes(i)
          fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
          i += 1
        }
        fields
      }
    }.grouped(100).map(x => pickle.dumps(x.toArray))

由于我们的计算任务一般耗时瓶颈在于executor端的计算而不是driver,因此应该考虑尽量减少executor端调用python代码的次数从而优化性能。

参考源码:spark/python/pyspark/java_gateway.py at master · apache/spark · GitHub

// 大概135行的地方:
# Import the classes used by PySpark
java_import(gateway.jvm, "org.apache.spark.SparkConf")
java_import(gateway.jvm, "org.apache.spark.api.java.*")
java_import(gateway.jvm, "org.apache.spark.api.python.*")
java_import(gateway.jvm, "org.apache.spark.ml.python.*")
java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
# TODO(davies): move into sql
java_import(gateway.jvm, "org.apache.spark.sql.*")
java_import(gateway.jvm, "org.apache.spark.sql.api.python.*")
java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
java_import(gateway.jvm, "scala.Tuple2")

pyspark可以把很多常见的运算封装到JVM中,但是显然不包括我们的UDF。
所以一个很自然的思路就是把我们的UDF也封到JVM中。

将python的自定义函数改成java

https://github.com/sunlongjiang/adx/blob/master/adx_platform_common/src/main/java/com/hungrystudio/utf/block/AreaComplexValue.java

并在任务中通过--jars 引用该jar包

 "--jars", "s3://hungry-studio-data-warehouse/user/sunlj/java_udf/adx_platform_common-4.0.0.jar",

改写后运行任务发现比之前少了两个transform,没有了BatchEvalPython,也少了一个WholeStageCodeGen

优化前该任务的执行时长为3个小时,

优化后改任务的执行时长为16分钟,效果非常明显!!!

因此在pyspark中尽量使用spark算子和spark-sql,同时尽量将UDF(含lambda表达式形式)封装到一个地方减少JVM和python脚本的交互。
由于BatchEvalPython过程每次处理100行,也可以把多行聚合成一行减少交互次数。
最后还可以把UDF部分用java重写打包成jar包,其他部分则保持python脚本以获得不用编译随时修改的灵活性,以兼顾性能和开发效率。


http://www.kler.cn/a/551072.html

相关文章:

  • uniapp webview嵌入外部h5网页后的消息通知
  • 机器学习入门实战 3 - 数据可视化
  • 量化噪声介绍
  • 网络安全-攻击流程-传输层
  • 11、《Web开发性能优化:静态资源处理与缓存控制深度解析》
  • LeetCode--23. 合并 K 个升序链表【堆和分治】
  • rust学习笔记2-rust的包管理工具Cargo使用
  • 深化与细化:提示工程(Prompt Engineering)的进阶策略与实践指南2
  • 5G时代的运维变革与美信监控易的深度剖析
  • 【Windows使用VNC和Cpolar实现跨平台高安全性的远程桌面在线连接】
  • VSCode 实用快捷键
  • Query String 传递 json 对象参数、map参数
  • Linux中进程的状态2
  • C#Halcon九点标定自动标定插件
  • 11-跳跃游戏
  • android uri路径转正常本地图片路径
  • 利用爬虫精准获取淘宝商品描述:实战案例指南
  • Python in Excel高级分析:一键RFM分析
  • 美国股市主要指数介绍(Major U.S. Stock Market Indexes):三大股指(中英双语)
  • 基于Flask的艺恩影片票房分析系统的设计与实现