SparkSQL用UDAF实现Bitmap函数
创建测试表
使用 Phoenix 在 HBase 中创建测试表，其中存放 bitmap 的字段 dist_mem 使用 VARBINARY 类型：
-- Test table for the bitmap UDAF: the primary key is `date`, so there is one row per date value.
-- `dist_mem` stores a RoaringBitmap serialized to bytes; the Spark UDAF deserializes it with
-- RoaringBitmap.deserialize(DataInput), so writers must use the matching serialize(DataOutput).
-- SALT_BUCKETS=6 is a Phoenix table option (row-key salting into 6 buckets);
-- NOTE(review): bucket count presumably chosen for the target cluster — confirm.
CREATE TABLE IF NOT EXISTS test_binary (
date VARCHAR NOT NULL,
dist_mem VARBINARY
CONSTRAINT test_binary_pk PRIMARY KEY (date)
) SALT_BUCKETS=6;
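表创建完成后，将需要去重计数的 ID 集合构建为 RoaringBitmap，用 `RoaringBitmap.serialize(DataOutput)` 序列化后写入 dist_mem 字段。下文的 UDAF 通过对应的 `deserialize(DataInput)` 读取该字段，因此两端的序列化格式必须一致。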
实现自定义聚合函数bitmap
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.roaringbitmap.RoaringBitmap;
import java.io.*;
import java.util.ArrayList;
import java.util.List;
/**
 * Spark SQL user-defined aggregate function that unions serialized {@link RoaringBitmap}s and
 * returns the cardinality (number of distinct values) of the union.
 *
 * <p>Input: one nullable {@code BinaryType} column holding bitmaps in RoaringBitmap's portable
 * serialization, i.e. bytes written by {@link RoaringBitmap#serialize(java.io.DataOutput)}.
 * Null inputs are ignored; a zero-length array is treated as an empty bitmap.
 *
 * <p>Output: a {@code LongType} count, {@code 0} when every input in the group is null.
 *
 * <p>The aggregation buffer keeps the running union as serialized bytes, so every row after the
 * first costs one deserialize / OR / serialize round trip.
 *
 * <p>Corrupt bitmap bytes fail the task with an {@link UncheckedIOException} instead of being
 * silently skipped, which would otherwise produce a wrong count.
 *
 * <p>NOTE(review): {@code UserDefinedAggregateFunction} is deprecated since Spark 3.0 in favor of
 * {@code Aggregator} registered through {@code functions.udaf}; consider migrating on Spark 3.
 */
public class UdafBitMap extends UserDefinedAggregateFunction {

    @Override
    public StructType inputSchema() {
        return singleBinaryFieldSchema();
    }

    @Override
    public StructType bufferSchema() {
        // Same single-binary layout as the input, which lets merge() reuse the update logic.
        return singleBinaryFieldSchema();
    }

    @Override
    public DataType dataType() {
        return DataTypes.LongType;
    }

    /**
     * Reports this function as deterministic: the result depends only on the set of input
     * bitmaps (OR is commutative and associative), not on row order or partitioning.
     * Reporting {@code false} would only restrict what the optimizer may do with it.
     */
    @Override
    public boolean deterministic() {
        return true;
    }

    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        // null means "no bitmap seen yet"; evaluate() maps it to a count of 0.
        buffer.update(0, null);
    }

    /** Folds one input row into the buffer (rows handled within the same partition). */
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        mergeInto(buffer, input);
    }

    /** Folds a partial aggregate from another partition into {@code buffer1}. */
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        mergeInto(buffer1, buffer2);
    }

    /** Returns the cardinality of the accumulated union as a {@code Long} (0 if nothing seen). */
    @Override
    public Object evaluate(Row buffer) {
        Object val = buffer.get(0);
        if (val == null) {
            return 0L;
        }
        return deserialize((byte[]) val).getLongCardinality();
    }

    /** Builds the one-column, nullable BinaryType schema shared by input and buffer. */
    private static StructType singleBinaryFieldSchema() {
        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField("field", DataTypes.BinaryType, true));
        return DataTypes.createStructType(structFields);
    }

    /**
     * ORs the bitmap in column 0 of {@code other} into column 0 of {@code buffer}.
     * A null source leaves the buffer unchanged; a null buffer simply adopts the source bytes.
     */
    private static void mergeInto(MutableAggregationBuffer buffer, Row other) {
        Object in = other.get(0);
        if (in == null) {
            return;
        }
        byte[] inBytes = (byte[]) in;
        Object out = buffer.get(0);
        if (out == null) {
            buffer.update(0, inBytes);
            return;
        }
        RoaringBitmap union = deserialize((byte[]) out);
        union.or(deserialize(inBytes));
        buffer.update(0, serialize(union));
    }

    /**
     * Decodes portable-format bitmap bytes; a zero-length array yields an empty bitmap.
     *
     * @throws UncheckedIOException if the bytes are truncated or not a valid RoaringBitmap
     */
    private static RoaringBitmap deserialize(byte[] bytes) {
        RoaringBitmap bitmap = new RoaringBitmap();
        if (bytes.length == 0) {
            return bitmap;
        }
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(bytes))) {
            bitmap.deserialize(in);
        } catch (IOException e) {
            throw new UncheckedIOException(
                    "Failed to deserialize RoaringBitmap from " + bytes.length + " bytes", e);
        }
        return bitmap;
    }

    /** Encodes {@code bitmap} in the portable format read back by {@link #deserialize(byte[])}. */
    private static byte[] serialize(RoaringBitmap bitmap) {
        // Presize to the exact serialized length to avoid intermediate array growth.
        ByteArrayOutputStream bos = new ByteArrayOutputStream(bitmap.serializedSizeInBytes());
        try (DataOutputStream out = new DataOutputStream(bos)) {
            bitmap.serialize(out);
        } catch (IOException e) {
            // Writing to an in-memory stream should not fail; surface it rather than hide it.
            throw new UncheckedIOException("Failed to serialize RoaringBitmap", e);
        }
        return bos.toByteArray();
    }
}
调用示例
/**
 * Demonstrates the {@code bitmap} aggregate function.
 *
 * <p>Registers {@link UdafBitMap} as {@code bitmap}, exposes the Phoenix table
 * {@code test_binary} to Spark SQL over JDBC, and prints the distinct-member count per date.
 * Connection settings are read from {@code DB_PHOENIX_CONF_FILE} via {@code PropUtil}.
 *
 * @param sparkSession active session used to register the function and run the query
 * @throws IllegalStateException if loading the configuration, reading the table, or running the
 *     query fails; the original exception is attached as the cause
 */
private static void udafBitmap(SparkSession sparkSession) {
    try {
        Properties prop = PropUtil.loadProp(DB_PHOENIX_CONF_FILE);
        // JDBC connection properties. Unset keys are skipped: Properties is a Hashtable and
        // throws NullPointerException on a null value.
        Properties connProp = new Properties();
        putIfConfigured(connProp, "driver", prop.getProperty(DB_PHOENIX_DRIVER));
        putIfConfigured(connProp, "user", prop.getProperty(DB_PHOENIX_USER));
        putIfConfigured(connProp, "password", prop.getProperty(DB_PHOENIX_PASS));
        putIfConfigured(connProp, "fetchsize", prop.getProperty(DB_PHOENIX_FETCHSIZE));

        // Register the custom aggregate function under the SQL name "bitmap".
        sparkSession.udf().register("bitmap", new UdafBitMap());
        sparkSession
                .read()
                .jdbc(prop.getProperty(DB_PHOENIX_URL), "test_binary", connProp)
                // A global temp view lives in the global_temp database, so SQL must refer to
                // it as global_temp.test_binary or the table will not be found.
                .createOrReplaceGlobalTempView("test_binary");

        // Per-year variant of the same aggregation:
        //   select YEAR(TO_DATE(date)) year, bitmap(dist_mem) memNum
        //   from global_temp.test_binary group by YEAR(TO_DATE(date))
        sparkSession.sql("select date,bitmap(dist_mem) memNum from global_temp.test_binary group by date").show();
    } catch (Exception e) {
        // Propagate instead of printing: a swallowed failure makes the job look successful.
        throw new IllegalStateException("Failed to run the bitmap UDAF example", e);
    }
}

/** Sets {@code key} to {@code value} in {@code target} only when {@code value} is non-null. */
private static void putIfConfigured(Properties target, String key, String value) {
    if (value != null) {
        target.setProperty(key, value);
    }
}
结果: