当前位置: 首页 > article >正文

Spark-累加器源码分析

一、累加器使用

源码中给的例子是:org.apache.spark.examples.AccumulatorMetricsTest

此示例显示了如何针对累加器源注册累加器,创建一个简单的RDD,在Task中对累加器递增,结果与累加器的值一起输出到Driver中的stdout。为了可以看到效果,我们对累加过程做了些调整

其中我们关心的代码如下,即创建、累加和使用

val accLong = sc.longAccumulator("my-long-metric")
val accDouble = sc.doubleAccumulator("my-double-metric")
val accCollection = sc.collectionAccumulator[String]("my-collection-metric")

val num = if (args.length > 0) args(0).toInt else 1000
val accumulatorTest = sc.parallelize(1 to num).foreach(thisNum=> {
  accLong.add(3)
  accDouble.add(1.1)
  accDouble.add(2.1)
  accCollection.add("num:"+thisNum)
})

println("*** Long accumulator (my-long-metric): " + accLong.value)
println("*** Long accumulator (my-long-metric): count:" + accLong.count)
println("*** Long accumulator (my-long-metric): sum:" + accLong.sum)
println("*** Long accumulator (my-long-metric): avg:" + accLong.avg)
println("*** Double accumulator (my-double-metric): " + accDouble.value)
println("*** Double accumulator (my-double-metric): count:" + accDouble.count)
println("*** Double accumulator (my-double-metric): sum:" + accDouble.sum)
println("*** Double accumulator (my-double-metric): avg:" + accDouble.avg)
println("*** Collection accumulator (my-collection-metric): " + accCollection.value)

输出结果为:

*** Long accumulator (my-long-metric): 3000
*** Long accumulator (my-long-metric): count:1000
*** Long accumulator (my-long-metric): sum:3000
*** Long accumulator (my-long-metric): avg:3.0
*** Double accumulator (my-double-metric): 3199.9999999998868
*** Double accumulator (my-double-metric): count:2000
*** Double accumulator (my-double-metric): sum:3199.9999999998868
*** Double accumulator (my-double-metric): avg:1.5999999999999435
*** Collection accumulator (my-collection-metric): [num:1, num:2, num:3, num:4, num:5, num:6, ......,num:999, num:1000]

二、创建累加器

我们拿最常用的longAccumulator来看下:

1、SparkContext

创建并注册一个Long累加器,它从0开始,通过“add”累加输入

  def longAccumulator(name: String): LongAccumulator = {
    val acc = new LongAccumulator
    register(acc, name)
    acc
  }

  def register(acc: AccumulatorV2[_, _], name: String): Unit = {
    //调用AccumulatorV2的register
    acc.register(this, name = Option(name))
  }

2、AccumulatorV2

累加器的基类,可以累加“IN”类型的输入,并产生“OUT”类型的输出

它是LongAccumulator、DoubleAccumulator、CollectionAccumulator的父类

abstract class AccumulatorV2[IN, OUT] extends Serializable {

  private[spark] def register(
      sc: SparkContext,
      name: Option[String] = None,
      countFailedValues: Boolean = false): Unit = {
    if (this.metadata != null) {
      throw new IllegalStateException("Cannot register an Accumulator twice.")
    }
    this.metadata = AccumulatorMetadata(AccumulatorContext.newId(), name, countFailedValues)
    AccumulatorContext.register(this)
    sc.cleaner.foreach(_.registerAccumulatorForCleanup(this))
  }

    //...............

//内部类
private[spark] object AccumulatorContext extends Logging {

  //此全局映射保存在Driver上创建的原始累加器对象。它保留了对这些对象的弱引用,这样一旦RDD和引用它们的用户代码被清理干净,累加器就可以被垃圾回收。
  private val originals = new ConcurrentHashMap[Long, jl.ref.WeakReference[AccumulatorV2[_, _]]]

  //注册在Driver上创建的[[AcumulatorV2]],以便在Executor上使用。
  //此处注册的所有累加器稍后都可以用作跨多个Task累加部分值的容器。这就是org.apache.spark.scheduler。DAGScheduler上来做的。
  //注意:如果在此处注册了累加器,则还应将其注册到活动上下文清理器中进行清理,以避免内存泄漏。
  //如果已经注册了具有相同ID的[[AcumulatorV2]],这只会覆盖它,而不会做任何事情。我们永远不会重复注册同一个累加器
  def register(a: AccumulatorV2[_, _]): Unit = {
    originals.putIfAbsent(a.id, new jl.ref.WeakReference[AccumulatorV2[_, _]](a))
  }

  //...............
}

//用于计算64位整数的求和、计数和平均值的[[AcumulatorV2累加器]]
class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] {

  private var _sum = 0L
  private var _count = 0L

  override def add(v: jl.Long): Unit = {
    _sum += v
    _count += 1
  }

  def count: Long = _count

  def sum: Long = _sum

  def avg: Double = _sum.toDouble / _count

  override def value: jl.Long = _sum

  override def merge(other: AccumulatorV2[jl.Long, jl.Long]): Unit = other match {
    case o: LongAccumulator =>
      _sum += o.sum
      _count += o.count
    case _ =>
     //.....抛异常....
  }

  //...............
}

//用于计算双精度浮点数的求和、计数和平均值的累加器
class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] {

  private var _sum = 0.0
  private var _count = 0L

  override def add(v: jl.Double): Unit = {
    _sum += v
    _count += 1
  }

  def count: Long = _count

  def sum: Double = _sum

  def avg: Double = _sum / _count

  override def value: jl.Double = _sum

  override def merge(other: AccumulatorV2[jl.Double, jl.Double]): Unit = other match {
    case o: DoubleAccumulator =>
      _sum += o.sum
      _count += o.count
    case _ =>
     //.....抛异常....
  }

  //...............
}

//用于收集元素列表的[[AcumulatorV2累加器]]
class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] {

  private var _list: java.util.List[T] = _

  override def merge(other: AccumulatorV2[T, java.util.List[T]]): Unit = other match {
    case o: CollectionAccumulator[T] => this.synchronized(getOrCreate.addAll(o.value))
    case _ => //.....抛异常....
  }

  private def getOrCreate = {
    _list = Option(_list).getOrElse(new java.util.ArrayList[T]())
    _list
  }

  override def value: java.util.List[T] = this.synchronized {
    java.util.Collections.unmodifiableList(new ArrayList[T](getOrCreate))
  }

  //...............
}

}

三、实现累加

从AccumulatorV2中的方法我们可以知道最终是在DAGScheduler调用它的merge方法来实现的累加,下面我们详细看下在什么位置:

  private def doOnReceive(event: DAGSchedulerEvent): Unit = event match {
    //当任务完成后会调用handleTaskCompletion
    case completion: CompletionEvent =>
      dagScheduler.handleTaskCompletion(completion)

  }


  private[scheduler] def handleTaskCompletion(event: CompletionEvent): Unit = {
    val task = event.task
    val stageId = task.stageId

    //........省略........

    //确保在任何其他处理发生之前更新任务的累加器,以便我们可以在更新任何作业或阶段之前发布任务结束事件。
    event.reason match {
      case Success =>
        task match {
          case rt: ResultTask[_, _] =>
            val resultStage = stage.asInstanceOf[ResultStage]
            resultStage.activeJob match {
              case Some(job) =>
                // 对于每个结果任务,只更新一次累加器
                if (!job.finished(rt.outputId)) {
                  updateAccumulators(event)
                }
              case None => // 如果任务的作业已完成,则忽略更新
            }
          case _ =>
            updateAccumulators(event)
        }
      case _: ExceptionFailure | _: TaskKilled => updateAccumulators(event)
      case _ =>
    }

    //........省略........

  }

  private def updateAccumulators(event: CompletionEvent): Unit = {
    val task = event.task
    val stage = stageIdToStage(task.stageId)

    event.accumUpdates.foreach { updates =>
      val id = updates.id
      try {
        // 在Driver上找到相应的累加器并更新
        val acc: AccumulatorV2[Any, Any] = AccumulatorContext.get(id) match {
          case Some(accum) => accum.asInstanceOf[AccumulatorV2[Any, Any]]
          case None =>
            throw new SparkException(s"attempted to access non-existent accumulator $id")
        }
        acc.merge(updates.asInstanceOf[AccumulatorV2[Any, Any]])

      } catch {
        //......异常处理........
      }
    }
  }

四、总结

1、通过SparkContext创建累加器(LongAccumulator、DoubleAccumulator、CollectionAccumulator)

2、在Driver端注册累加器(累加器必须先注册再使用)(其实就是向全局Map中放入了该累加器)

3、累加器从0开始计数,在每层Stage对应的Task结束时通过merge方法更新Driver端的累计器

4、当一个Job跑完后我们就可以使用累加器变量了,如果是数值型可以拿到总和、累加次数、平均值,如果时集合型可以拿到一个数据序列


http://www.kler.cn/news/315924.html

相关文章:

  • JS执行机制(同步和异步)
  • 深度学习入门:探索神经网络、感知器与损失函数
  • html实现TAB选项卡切换
  • LLMs之OCR:llm_aided_ocr(基于LLM辅助的OCR项目)的简介、安装和使用方法、案例应用之详细攻略
  • Python之一些列表的练习题
  • Spring Boot入门:构建你的首个Spring Boot应用
  • Mybatis-plus进阶篇(二)
  • 【JUC并发编程系列】深入理解Java并发机制:线程局部变量的奥秘与最佳实践(五、ThreadLocal原理、对象之间的引用)
  • 数据结构 ——— 常见的时间复杂度计算例题(最终篇)
  • Linux驱动开发 ——架构体系
  • 求最大公约数
  • CSS 布局三大样式简单学习
  • 【解密 Kotlin 扩展函数】命名参数和默认值(十三)
  • 【深入Java枚举类:不仅仅是常量的容器】
  • 数据结构——串的模式匹配算法(BF算法和KMP算法)
  • 设计模式-装饰者模式
  • VMware虚拟机经常性卡死,打开运行一段时间后卡死,CPU占比增至100%
  • 电脑网络怎么弄动态ip :步骤详解与优势探讨
  • Tomcat系列漏洞复现
  • AI时代最好的编程语言应该选择谁?
  • vue h5 蓝牙连接 webBluetooth API
  • MySQL 中删除重复的数据并只保留一条
  • C#实现指南:将文件夹与exe合并为一个exe
  • vscode 环境搭建
  • 神经网络修剪实战
  • ubuntu安装docker compose
  • 解决 TortoiseGitPlink Fatal Error:深入解析
  • JS巧用.padStart()|.padEnd()方法用另一个字符串填充当前字符串
  • 9月16日笔记
  • 工作笔记:Vue 3 中使用 vue-router 进行导航与监听路由变化