您当前的位置: 首页 >  ar

宝哥大数据

暂无认证

  • 0浏览

    0关注

    1029博文

    0收益

  • 0浏览

    0点赞

    0打赏

    0留言

私信
关注
热门博文

Spark SQL 用户自定义函数UDF、用户自定义聚合函数UDAF

宝哥大数据 发布时间:2018-07-19 23:10:54 ,浏览量:0

在Spark中,也支持Hive中的自定义函数。自定义函数大致可以分为三种:
  • UDF(User-Defined-Function),即最基本的自定义函数,类似to_char,to_date等
  • UDAF(User- Defined Aggregation Funcation),用户自定义聚合函数,类似在group by之后使用的sum,avg
  • UDTF(User-Defined Table-Generating Functions),用户自定义生成函数,有点像stream里面的flatMap
一、自定义UDF 拼接三个参数, 1.1继承org.apache.spark.sql.api.java.UDFxx(1-22); 1.2、实现call方法
    @Override
    public String call(Long v1, String v2, String split) throws Exception {
        return String.valueOf(v1) + split + v2;
    }
完整代码实现
package com.chb.shopanalysis.hive.UDF;

import org.apache.spark.sql.api.java.UDF3;

/**
 * 自定义UDF
 * 1 上海  split
 * 拼接成"1:上海"
 * 将两个字段拼接起来(使用指定的分隔符)
 * @author chb
 *
 */
public class ConcatLongStringUDF implements UDF3 {

    private static final long serialVersionUID = 1L;

    @Override
    public String call(Long v1, String v2, String split) throws Exception {
        return String.valueOf(v1) + split + v2;
    }

}
1.4、注册函数
        // 注册自定义函数
        sqlContext.udf().register(
                "concat_long_string",       //自定义函数的名称
                new ConcatLongStringUDF(),  //自定义UDF对象
                DataTypes.StringType);      //返回数据类型
1.5、使用函数
    /**
     * 从hive表中读取数据, 使用自定义聚合函数
     */
    private static void readProductClickInfo() {

        // 可以获取到每个area下的每个product_id的城市信息拼接起来的串

        String sql = 
                "SELECT city_id, city_name,"
                    + "area,"
                    + "product_id,"
                    + "concat_long_string(city_id,city_name,':') city_infos "  
                + "FROM click_product_basic ";



        // 使用Spark SQL执行这条SQL语句
        DataFrame df = sqlContext.sql(sql);
        //展示结果
        df.show();

    }

这里写图片描述

二、用户自定义聚合函数UDAF 2.1、继承org.apache.spark.sql.expressions.UserDefinedAggregateFunction 2.2、定义输入,缓存,输出字段类型
    // 指定输入数据的字段与类型
    private StructType inputSchema = DataTypes.createStructType(Arrays.asList(
            DataTypes.createStructField("cityInfo", DataTypes.StringType, true)));  
    // 指定缓冲数据的字段与类型
    private StructType bufferSchema = DataTypes.createStructType(Arrays.asList(
            DataTypes.createStructField("bufferCityInfo", DataTypes.StringType, true)));  
    // 指定返回类型
    private DataType dataType = DataTypes.StringType;
2.3、deterministic()决定每次相同输入,是否返回相同输出, 一般都会设置为true.
    @Override
    //每次相同的输入是否返回相同的输出
    public boolean deterministic() {
        return deterministic;
    }
2.4、初始化
    /**
     * 初始化
     * 可以认为是,你自己在内部指定一个初始的值
     */
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, "");  
    }
2.5、更新, 这个是组类根据自己的逻辑进行拼接, 然后更新数据
    /**
     * 更新
     * 可以认为是,一个一个地将组内的字段值传递进来
     * 实现拼接的逻辑
     */
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        // 缓冲中的已经拼接过的城市信息串
        String bufferCityInfo = buffer.getString(0);
        // 刚刚传递进来的某个城市信息
        String cityInfo = input.getString(0);

        // 在这里要实现去重的逻辑
        // 判断:之前没有拼接过某个城市信息,那么这里才可以接下去拼接新的城市信息
        if(!bufferCityInfo.contains(cityInfo)) {
            if("".equals(bufferCityInfo)) {
                bufferCityInfo += cityInfo;
            } else {
                // 比如1:北京
                //2:上海
                //结果 1:北京,2:上海
                //再 来一个 1:北京  就不会拼接进去。
                bufferCityInfo += "," + cityInfo;
            }

            buffer.update(0, bufferCityInfo);  
        }
    }
2.6、合并, 将所有节点的数据进行合并
    /**
     * 合并
     * update操作,可能是针对一个分组内的部分数据,在某个节点上发生的
     * 但是可能一个分组内的数据,会分布在多个节点上处理
     * 此时就要用merge操作,将各个节点上分布式拼接好的串,合并起来
     */
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        String bufferCityInfo1 = buffer1.getString(0);
        String bufferCityInfo2 = buffer2.getString(0);

        for(String cityInfo : bufferCityInfo2.split(",")) {
            if(!bufferCityInfo1.contains(cityInfo)) {
                if("".equals(bufferCityInfo1)) {
                    bufferCityInfo1 += cityInfo;
                } else {
                    bufferCityInfo1 += "," + cityInfo;
                }
            }
        }
        buffer1.update(0, bufferCityInfo1);  
    }
2.7、输出最终结果, 可能我们需要的输出格式,可以在该方法中,进行格式化。
        @Override
        //计算出最终结果
        public Object evaluate(Row row) {  
            return row.getString(0);  
        }
2.8、注册函数
        sqlContext.udf().register("group_concat_distinct", 
                new GroupConcatDistinctUDAF());
2.9、使用
    /**
     * 从hive表中读取数据, 使用自定义聚合函数
     */
    private static void readProductClickInfo() {
        // 按照area和product_id两个字段进行分组
        // 计算出各区域各商品的点击次数
        // 可以获取到每个area下的每个product_id的城市信息拼接起来的串

        String sql =  "SELECT  area, product_id,"
                + "count(*) click_count, "  
                + "group_concat_distinct(concat_long_string(city_id,city_name,':')) city_infos "  
                + "FROM click_product_basic "
                + "GROUP BY area,product_id "; 

        // 使用Spark SQL执行这条SQL语句
        DataFrame df = sqlContext.sql(sql);

        df.show();
        // 再次将查询出来的数据注册为一个临时表
        // 各区域各商品的点击次数(以及额外的城市列表)
        df.registerTempTable("tmp_area_product_click_count");    
    }
关注
打赏
1587549273
查看更多评论
立即登录/注册

微信扫码登录

0.0410s