Spark - Accumulator Source Code Analysis
I. Using Accumulators
The example given in the Spark source is org.apache.spark.examples.AccumulatorMetricsTest.
It shows how to register accumulators against an accumulator source, create a simple RDD, increment the accumulators inside the Tasks, and print the result together with the accumulator values to stdout on the Driver. To make the effect easier to see, we tweaked the accumulation logic slightly.
The code we care about is shown below, covering creation, accumulation, and reading the result:
// Create and register the accumulators on the Driver
val accLong = sc.longAccumulator("my-long-metric")
val accDouble = sc.doubleAccumulator("my-double-metric")
val accCollection = sc.collectionAccumulator[String]("my-collection-metric")
val num = if (args.length > 0) args(0).toInt else 1000
// Accumulate inside the Tasks
sc.parallelize(1 to num).foreach { thisNum =>
  accLong.add(3)
  accDouble.add(1.1)
  accDouble.add(2.1)
  accCollection.add("num:" + thisNum)
}
println("*** Long accumulator (my-long-metric): " + accLong.value)
println("*** Long accumulator (my-long-metric): count:" + accLong.count)
println("*** Long accumulator (my-long-metric): sum:" + accLong.sum)
println("*** Long accumulator (my-long-metric): avg:" + accLong.avg)
println("*** Double accumulator (my-double-metric): " + accDouble.value)
println("*** Double accumulator (my-double-metric): count:" + accDouble.count)
println("*** Double accumulator (my-double-metric): sum:" + accDouble.sum)
println("*** Double accumulator (my-double-metric): avg:" + accDouble.avg)
println("*** Collection accumulator (my-collection-metric): " + accCollection.value)
The output is:
*** Long accumulator (my-long-metric): 3000
*** Long accumulator (my-long-metric): count:1000
*** Long accumulator (my-long-metric): sum:3000
*** Long accumulator (my-long-metric): avg:3.0
*** Double accumulator (my-double-metric): 3199.9999999998868
*** Double accumulator (my-double-metric): count:2000
*** Double accumulator (my-double-metric): sum:3199.9999999998868
*** Double accumulator (my-double-metric): avg:1.5999999999999435
*** Collection accumulator (my-collection-metric): [num:1, num:2, num:3, num:4, num:5, num:6, ......,num:999, num:1000]
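Each of the 1000 elements adds 3 to the Long accumulator, so its count is 1000 and its sum is 1000 × 3 = 3000. Each element also performs two Double additions (1.1 and 2.1), so the Double accumulator's count is 2000 and its sum is 1000 × (1.1 + 2.1) ≈ 3200, the small deviation being floating-point rounding error.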
II. Creating Accumulators
Let's take the most commonly used one, longAccumulator, as an example.
1. SparkContext
Creates and registers a Long accumulator, which starts at 0 and accumulates inputs via "add":
def longAccumulator(name: String): LongAccumulator = {
val acc = new LongAccumulator
register(acc, name)
acc
}
def register(acc: AccumulatorV2[_, _], name: String): Unit = {
//Delegate to AccumulatorV2's register
acc.register(this, name = Option(name))
}
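This register method is public, so an accumulator instantiated by hand can be registered through exactly the same path. A minimal sketch (the variable and metric names are made up for illustration):
val manualAcc = new org.apache.spark.util.LongAccumulator
sc.register(manualAcc, "manually-registered-metric")  //equivalent to sc.longAccumulator("manually-registered-metric")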
2. AccumulatorV2
The base class of accumulators: it accumulates inputs of type IN and produces an output of type OUT.
It is the parent class of LongAccumulator, DoubleAccumulator, and CollectionAccumulator.
abstract class AccumulatorV2[IN, OUT] extends Serializable {
private[spark] def register(
sc: SparkContext,
name: Option[String] = None,
countFailedValues: Boolean = false): Unit = {
if (this.metadata != null) {
throw new IllegalStateException("Cannot register an Accumulator twice.")
}
this.metadata = AccumulatorMetadata(AccumulatorContext.newId(), name, countFailedValues)
AccumulatorContext.register(this)
sc.cleaner.foreach(_.registerAccumulatorForCleanup(this))
}
  //...............
}
//AccumulatorContext and the concrete accumulators below are defined alongside AccumulatorV2, not as inner classes of it:
private[spark] object AccumulatorContext extends Logging {
//This global map holds the original accumulator objects created on the Driver. It keeps weak references to them so that, once the RDDs and the user code referencing them are cleaned up, the accumulators can be garbage-collected.
private val originals = new ConcurrentHashMap[Long, jl.ref.WeakReference[AccumulatorV2[_, _]]]
//Registers an AccumulatorV2 created on the Driver so that it can be used on the Executors.
//All accumulators registered here can later be used as a container for accumulating partial values across multiple Tasks; this is what org.apache.spark.scheduler.DAGScheduler does.
//Note: an accumulator registered here should also be registered with the active context cleaner for cleanup, to avoid memory leaks.
//If an AccumulatorV2 with the same ID has already been registered, this does nothing instead of overwriting it (we never register the same accumulator twice anyway).
def register(a: AccumulatorV2[_, _]): Unit = {
originals.putIfAbsent(a.id, new jl.ref.WeakReference[AccumulatorV2[_, _]](a))
}
//...............
}
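//Aside (not part of the Spark source): the "never overwrite" behavior comes from ConcurrentHashMap.putIfAbsent,
//whose semantics can be seen in isolation (the key value here is made up):
val putIfAbsentDemo = new java.util.concurrent.ConcurrentHashMap[java.lang.Long, String]()
putIfAbsentDemo.putIfAbsent(1L, "first")
putIfAbsentDemo.putIfAbsent(1L, "second")   //no-op: key 1 is already present
assert(putIfAbsentDemo.get(1L) == "first")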
//An AccumulatorV2 accumulator for computing the sum, count, and average of 64-bit integers
class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] {
private var _sum = 0L
private var _count = 0L
override def add(v: jl.Long): Unit = {
_sum += v
_count += 1
}
def count: Long = _count
def sum: Long = _sum
def avg: Double = _sum.toDouble / _count
override def value: jl.Long = _sum
override def merge(other: AccumulatorV2[jl.Long, jl.Long]): Unit = other match {
case o: LongAccumulator =>
_sum += o.sum
_count += o.count
case _ =>
//..... throws an exception .....
}
//...............
}
//An accumulator for computing the sum, count, and average of double-precision floating-point numbers
class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] {
private var _sum = 0.0
private var _count = 0L
override def add(v: jl.Double): Unit = {
_sum += v
_count += 1
}
def count: Long = _count
def sum: Double = _sum
def avg: Double = _sum / _count
override def value: jl.Double = _sum
override def merge(other: AccumulatorV2[jl.Double, jl.Double]): Unit = other match {
case o: DoubleAccumulator =>
_sum += o.sum
_count += o.count
case _ =>
//..... throws an exception .....
}
//...............
}
//An AccumulatorV2 accumulator for collecting a list of elements
class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] {
private var _list: java.util.List[T] = _
override def merge(other: AccumulatorV2[T, java.util.List[T]]): Unit = other match {
case o: CollectionAccumulator[T] => this.synchronized(getOrCreate.addAll(o.value))
case _ => //..... throws an exception .....
}
private def getOrCreate = {
_list = Option(_list).getOrElse(new java.util.ArrayList[T]())
_list
}
override def value: java.util.List[T] = this.synchronized {
java.util.Collections.unmodifiableList(new ArrayList[T](getOrCreate))
}
//...............
}
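Besides these built-in implementations, user code can implement the same AccumulatorV2 contract (isZero / copy / reset / add / merge / value) and register it through the SparkContext.register method shown earlier. A minimal sketch of a hypothetical set-based accumulator (the class and metric names are made up for illustration):
import org.apache.spark.util.AccumulatorV2

//Collects distinct strings; not part of the Spark source
class DistinctStringAccumulator extends AccumulatorV2[String, java.util.Set[String]] {
  private val _set = new java.util.HashSet[String]()
  override def isZero: Boolean = _set.isEmpty
  override def copy(): DistinctStringAccumulator = {
    val newAcc = new DistinctStringAccumulator
    newAcc._set.addAll(_set)
    newAcc
  }
  override def reset(): Unit = _set.clear()
  override def add(v: String): Unit = { _set.add(v) }
  override def merge(other: AccumulatorV2[String, java.util.Set[String]]): Unit = {
    _set.addAll(other.value)
  }
  override def value: java.util.Set[String] = _set
}
//val distinctAcc = new DistinctStringAccumulator
//sc.register(distinctAcc, "my-distinct-metric")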
III. Implementing the Accumulation
From the methods of AccumulatorV2 we can see that the accumulation is ultimately performed by the DAGScheduler calling merge. Let's look in detail at where this happens:
private def doOnReceive(event: DAGSchedulerEvent): Unit = event match {
//handleTaskCompletion is called when a task has completed
case completion: CompletionEvent =>
dagScheduler.handleTaskCompletion(completion)
}
private[scheduler] def handleTaskCompletion(event: CompletionEvent): Unit = {
val task = event.task
val stageId = task.stageId
//........ omitted: among other things, this is where stage is looked up via stageIdToStage(task.stageId) ........
//Make sure the task's accumulators are updated before any other processing happens, so that we can post a task-end event before any job or stage updates.
event.reason match {
case Success =>
task match {
case rt: ResultTask[_, _] =>
val resultStage = stage.asInstanceOf[ResultStage]
resultStage.activeJob match {
case Some(job) =>
// Only update the accumulators once for each result task
if (!job.finished(rt.outputId)) {
updateAccumulators(event)
}
case None => // Ignore the update if the task's job has already finished
}
case _ =>
updateAccumulators(event)
}
case _: ExceptionFailure | _: TaskKilled => updateAccumulators(event)
case _ =>
}
//........ omitted ........
}
private def updateAccumulators(event: CompletionEvent): Unit = {
val task = event.task
val stage = stageIdToStage(task.stageId)
event.accumUpdates.foreach { updates =>
val id = updates.id
try {
// Find the corresponding accumulator on the Driver and update it
val acc: AccumulatorV2[Any, Any] = AccumulatorContext.get(id) match {
case Some(accum) => accum.asInstanceOf[AccumulatorV2[Any, Any]]
case None =>
throw new SparkException(s"attempted to access non-existent accumulator $id")
}
acc.merge(updates.asInstanceOf[AccumulatorV2[Any, Any]])
} catch {
//...... exception handling ......
}
}
}
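To make merge concrete, here is a small Driver-only sketch (no cluster involved, the variable names are made up) that mimics what updateAccumulators does: each Task works on its own zeroed copy, and the original Driver-side instance folds the partial results back in via merge:
import org.apache.spark.util.LongAccumulator

val driverAcc = new LongAccumulator      //the "original" that AccumulatorContext would hold
val task1 = driverAcc.copyAndReset()     //per-task copies start from zero
val task2 = driverAcc.copyAndReset()
task1.add(3L); task1.add(3L)             //task 1 saw two elements
task2.add(3L)                            //task 2 saw one element
driverAcc.merge(task1)                   //what updateAccumulators effectively does per task
driverAcc.merge(task2)
//driverAcc.sum == 9, driverAcc.count == 3, driverAcc.avg == 3.0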
IV. Summary
1. Accumulators (LongAccumulator, DoubleAccumulator, CollectionAccumulator) are created through SparkContext.
2. Accumulators are registered on the Driver side (an accumulator must be registered before use); registration simply puts the accumulator into a global map.
3. An accumulator starts from 0; when the Tasks of each Stage finish, the Driver-side accumulator is updated via the merge method.
4. Once a Job has finished, the accumulator variable can be read: numeric accumulators expose the sum, the number of additions, and the average, while collection accumulators expose the collected sequence of data.