简介

UDAF(User Defined Aggregate Function),即用户自定义聚合函数,至于啥叫聚合函数,用来干嘛的,熟悉 SQL 的自不必多说,而且 UDAF 面向的是 SparkSQL,熟悉 SQL 是前提条件。

场景

在一次我对 Spark 项目优化过程中,需要将一个复杂的计算从 Driver 端提取出来,重新设计然后放入 SparkSQL 中进行计算,但是已有的聚合函数是完全无法满足需求的,我需要处理的数据包含三个列,一般的聚合函数只能满足一列,这时候就需要使用自定义聚合函数了。

(PS:至于为什么要从 Driver 端提取出来是因为历史原因,这个放到以后的 Spark 优化方案博客中说明)

UDAF 的使用

自定义 UDAF 一共分为三步:

  1. 自定义类继承 UserDefinedAggregateFunction 类,并实现对应的方法;
  2. 使用 Spark 对定义好的类进行注册,并提供一个可在 SQL 语句中调用的函数名;
  3. SQL 中使用;

自定义 UDAF 类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}

/**
* @FileName: CustomMaxSpeedUdaf
* @Author: ChenTianyi
* @Date: 2021-01-28 9:23
* @Description: 自定义获取最大速度(降序第 5% 个位置处的速度)
*/
class CustomMaxSpeedUdaf extends UserDefinedAggregateFunction {

/**
* 集合函数传递进来的值
* 使用 List 搜集传递进来的值的类型,这里只有 13 中类型,分别是:
* StringType, BinaryType, BooleanType, DateType, TimestampType
* CalendarIntervalType, DoubleType, FloatType, ByteType, IntegerType, LongType, ShortType, NullType
* 注意:这里字段的顺序需要跟你传递进来的字段顺序一致
*/
override def inputSchema: StructType = StructType(
List(
StructField("speed_in", DataTypes.IntegerType)
)
)

// 用于缓存计算的中间结果
override def bufferSchema: StructType = StructType(
List(
StructField("speed_buffer", DataTypes.StringType)
)
)

// 最终计算完成返回的结果类型
override def dataType: DataType = DataTypes.StringType

// 每次相同的输入是否返回相同的输出,为保证一致性,建议使用 true
override def deterministic: Boolean = true

// 初始化缓存值
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = ""
}

// 用于更新中间缓存值 (update 相当于在每个分区中计算)
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer(0) = buffer.getString(0) + "," + input.getInt(0).toString
}

// 所有分区中的中间结果聚合
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getString(0) + "," + buffer2.getString(0)
}

// 计算最终结果
override def evaluate(buffer: Row): Any = {
val speeds = buffer.getString(0)
.split(",")
.filter(_.nonEmpty)
.map(speed => {
if (speed.isEmpty) 0 else speed.toInt
})
.toList
.sortWith((x, y) => {
x.compareTo(y) > 0
})
val index = (speeds.length * 0.05).toInt
speeds(index).toString
}
}

Spark 对其进行注册

1
2
3
// 注册自定义 spark sql 函数
// 这里注意了,要用 SparkSession 实例进行注册,不是 SparkContext
sparksession.sqlContext.udf.register("CUSTOM_MAX_SPEED", new CustomMaxSpeedUdaf)

在 SQL 中调用

1
2
3
-- 不贴实际的业务代码了,大概表示一下
SELECT CUSTOM_MAX_SPEED(speed) AS custom_max_speed
FROM TABLE

总结

开发人员可以根据自己的业务需求,将大部分计算规则通过 UDAF 的方式实现并注册,通过 SparkSQL 进行调用,而 SparkSQL 是在多个 Executor 上执行的,可以大幅度提高运行效率。就我优化的这个项目来说,在进行优化之后,一天的数据计算时间从15 分钟减少为 45 秒左右,提升了 20 多倍;而历史数据计算时间从1 天半缩减为 16 分钟左右,这个结果可想而知;合理设计 Spark 计算任务在对于离线计算中的效率提升是有很大的帮助的。