Preface
In an earlier post, "Getting Started with Hive UDAF Development and a Detailed Look at the Execution Process" (repost), I covered the UDAF development process and mentioned that, to really understand how a UDAF executes, you can read the source code of the built-in UDAF that computes the average.
After reading that source I still did not fully understand it, so I decided to write a new function myself and try to understand it through practice.
What the Function Does
The function itself is admittedly a bit silly. We all know Hive ships several common aggregate functions: sum, max, min, avg.
The goal here is to have a single function do two different jobs at once: for each key, return both the maximum and the minimum of the given set of values.
The tricky part is that the function receives only a single input value at a time, yet has to produce two new values at once, and the two outputs are logically related to each other. For example, for a key whose values are 3, 7 and 1, the result for that key should be 7 and 1 in one output.
I'm not good at explaining this in words; after mulling it over for a while, I'll just let the code speak: first a tiny plain-Java sketch of the idea below, then the full UDAF source.
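To make the idea concrete before diving into the Hive interfaces, here is a minimal, Hive-independent sketch (plain Java; the class and method names are purely illustrative): a single pass over the values keeps a running max and a running min together in one small buffer, which is exactly the job the aggregation buffer in the real UDAF will do.

public class MaxMinSketch {
  static class Buffer {
    double max = Double.MIN_VALUE;   // grows as values arrive
    double min = Double.MAX_VALUE;   // shrinks as values arrive

    void add(double v) {             // same role as iterate() in the UDAF
      if (v > max) { max = v; }
      if (v < min) { min = v; }
    }

    void merge(Buffer other) {       // same role as merge() in the UDAF
      if (other.max > max) { max = other.max; }
      if (other.min < min) { min = other.min; }
    }
  }

  public static void main(String[] args) {
    Buffer b = new Buffer();
    for (double v : new double[]{3.0, 7.0, 1.0}) {
      b.add(v);
    }
    System.out.println(b.max + "\t" + b.min);   // prints: 7.0	1.0
  }
}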
Source Code
package org.juefan.udaf;

import java.util.ArrayList;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.util.StringUtils;

/**
 * GenericUDAFMaxMin.
 */
@Description(name = "maxmin",
    value = "_FUNC_(x) - Returns the max and min value of a set of numbers")
public class GenericUDAFMaxMin extends AbstractGenericUDAFResolver {

  static final Log LOG = LogFactory.getLog(GenericUDAFMaxMin.class.getName());

  @Override
  public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
    if (parameters.length != 1) {
      throw new UDFArgumentTypeException(parameters.length - 1,
          "Exactly one argument is expected.");
    }
    if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentTypeException(0,
          "Only primitive type arguments are accepted but "
          + parameters[0].getTypeName() + " is passed.");
    }
    switch (((PrimitiveTypeInfo) parameters[0]).getPrimitiveCategory()) {
    case BYTE:
    case SHORT:
    case INT:
    case LONG:
    case FLOAT:
    case DOUBLE:
    case STRING:
    case TIMESTAMP:
      return new GenericUDAFMaxMinEvaluator();
    case BOOLEAN:
    default:
      throw new UDFArgumentTypeException(0,
          "Only numeric or string type arguments are accepted but "
          + parameters[0].getTypeName() + " is passed.");
    }
  }

  /**
   * GenericUDAFMaxMinEvaluator.
   */
  public static class GenericUDAFMaxMinEvaluator extends GenericUDAFEvaluator {

    // For PARTIAL1 and COMPLETE: inspector for the raw input column
    PrimitiveObjectInspector inputOI;

    // For PARTIAL2 and FINAL: inspector for the partial-result struct,
    // plus the struct fields that carry the running max and min
    StructObjectInspector soi;
    StructField maxField;
    StructField minField;
    // Field inspectors; get() returns the stored value directly as a double
    DoubleObjectInspector maxFieldOI;
    DoubleObjectInspector minFieldOI;

    // For PARTIAL1 and PARTIAL2: holds the intermediate (partial) result
    Object[] partialResult;

    // For FINAL and COMPLETE: holds the final output
    Text result;

    @Override
    public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
      assert (parameters.length == 1);
      super.init(m, parameters);

      // Set up the input side
      if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
        inputOI = (PrimitiveObjectInspector) parameters[0];
      } else {
        // The input is a partial result, so interpret it as the {max, min} struct defined below
        soi = (StructObjectInspector) parameters[0];
        // Look up the named struct fields
        maxField = soi.getStructFieldRef("max");
        minField = soi.getStructFieldRef("min");
        // Inspectors used to read the actual field values
        maxFieldOI = (DoubleObjectInspector) maxField.getFieldObjectInspector();
        minFieldOI = (DoubleObjectInspector) minField.getFieldObjectInspector();
      }

      // Set up the output side
      if (m == Mode.PARTIAL1 || m == Mode.PARTIAL2) {
        // The output is a struct that contains the max and the min
        // Field types of the struct
        ArrayList<ObjectInspector> foi = new ArrayList<ObjectInspector>();
        foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
        foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
        // Field names of the struct
        ArrayList<String> fname = new ArrayList<String>();
        fname.add("max");
        fname.add("min");
        partialResult = new Object[2];
        partialResult[0] = new DoubleWritable(0);
        partialResult[1] = new DoubleWritable(0);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi);
      } else {
        // Last step: the final output is a plain string
        result = new Text("");
        return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
      }
    }

    static class AverageAgg implements AggregationBuffer {
      double max;
      double min;
    }

    @Override
    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
      AverageAgg result = new AverageAgg();
      reset(result);
      return result;
    }

    @Override
    public void reset(AggregationBuffer agg) throws HiveException {
      AverageAgg myagg = (AverageAgg) agg;
      myagg.max = Double.MIN_VALUE;
      myagg.min = Double.MAX_VALUE;
    }

    boolean warned = false;

    @Override
    public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
      assert (parameters.length == 1);
      Object p = parameters[0];
      if (p != null) {
        AverageAgg myagg = (AverageAgg) agg;
        try {
          // Read the input value and update the running max/min
          double v = PrimitiveObjectInspectorUtils.getDouble(p, inputOI);
          if (myagg.max < v) {
            myagg.max = v;
          }
          if (myagg.min > v) {
            myagg.min = v;
          }
        } catch (NumberFormatException e) {
          if (!warned) {
            warned = true;
            LOG.warn(getClass().getSimpleName() + " " + StringUtils.stringifyException(e));
            LOG.warn(getClass().getSimpleName() + " ignoring similar exceptions.");
          }
        }
      }
    }

    @Override
    public Object terminatePartial(AggregationBuffer agg) throws HiveException {
      // Package the intermediate result and hand it to the next stage
      AverageAgg myagg = (AverageAgg) agg;
      ((DoubleWritable) partialResult[0]).set(myagg.max);
      ((DoubleWritable) partialResult[1]).set(myagg.min);
      return partialResult;
    }

    @Override
    public void merge(AggregationBuffer agg, Object partial) throws HiveException {
      if (partial != null) {
        // Here "partial" is the output of terminatePartial from an earlier stage
        AverageAgg myagg = (AverageAgg) agg;
        Object partialmax = soi.getStructFieldData(partial, maxField);
        Object partialmin = soi.getStructFieldData(partial, minField);
        if (myagg.max < maxFieldOI.get(partialmax)) {
          myagg.max = maxFieldOI.get(partialmax);
        }
        if (myagg.min > minFieldOI.get(partialmin)) {
          myagg.min = minFieldOI.get(partialmin);
        }
      }
    }

    @Override
    public Object terminate(AggregationBuffer agg) throws HiveException {
      // Join the final max and min into one tab-separated string and return it
      AverageAgg myagg = (AverageAgg) agg;
      if (myagg.max == Double.MIN_VALUE && myagg.min == Double.MAX_VALUE) {
        // The buffer is still in its reset state, i.e. no non-null rows were aggregated
        return null;
      } else {
        result.set(myagg.max + "\t" + myagg.min);
        return result;
      }
    }
  }
}
Even after writing it, I still don't feel I have understood the whole process thoroughly, so take the comments above with a grain of salt; I can't guarantee they are all correct!
I still need to add some print statements this afternoon and trace the execution, but the logic of the code itself is fine; I have run it myself!
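For reference while tracing, the four evaluator modes map onto the query plan roughly as follows: PARTIAL1 runs iterate() and terminatePartial() (map side, raw rows in, partial struct out); PARTIAL2 runs merge() and terminatePartial() (combining partials); FINAL runs merge() and terminate() (reduce side, final string out); and COMPLETE runs iterate() and terminate() when no reduce stage is needed. The snippet below is only my own throwaway test harness, not part of the UDAF: a minimal sketch, assuming the same Hive jars used by the class above are on the classpath, that drives the evaluator through a PARTIAL1 stage followed by a FINAL stage so the call order can be stepped through without a cluster.

package org.juefan.udaf;

import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

/** Throwaway harness for tracing the PARTIAL1 -> FINAL call sequence (test code only). */
public class MaxMinTrace {
  public static void main(String[] args) throws Exception {
    // "Map side" (PARTIAL1): raw doubles in, partial {max, min} struct out
    GenericUDAFEvaluator mapEval = new GenericUDAFMaxMin.GenericUDAFMaxMinEvaluator();
    ObjectInspector inputOI = PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
    StructObjectInspector partialOI =
        (StructObjectInspector) mapEval.init(Mode.PARTIAL1, new ObjectInspector[]{inputOI});

    AggregationBuffer mapBuf = mapEval.getNewAggregationBuffer();
    for (double v : new double[]{3.0, 7.0, 1.0}) {
      mapEval.iterate(mapBuf, new Object[]{v});      // updates the running max/min
    }
    Object partial = mapEval.terminatePartial(mapBuf); // {max=7.0, min=1.0}

    // "Reduce side" (FINAL): partial struct in, final tab-separated string out
    GenericUDAFEvaluator redEval = new GenericUDAFMaxMin.GenericUDAFMaxMinEvaluator();
    redEval.init(Mode.FINAL, new ObjectInspector[]{partialOI});
    AggregationBuffer redBuf = redEval.getNewAggregationBuffer();
    redEval.merge(redBuf, partial);

    System.out.println(redEval.terminate(redBuf));     // expected: 7.0	1.0
  }
}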