Preface
The previous post, hive UDAF開發入門和運行過程詳解(轉), walked through how to develop a UDAF and suggested that, to really understand how a UDAF executes, you should read the source of the built-in average UDAF.
After reading that source I still did not fully understand what was going on, so I decided to write a new function of my own and try to understand it through practice.
What the function does
The function itself is admittedly a bit contrived. Hive already ships with the usual aggregate functions: sum, max, min and avg.
The goal here is to implement two of them in a single function: for each key, return both the maximum and the minimum of the corresponding set of values.
The tricky part is that the function receives only one input stream, yet has to produce two new, logically related values at the same time.
I am not great at explaining this in words, so after thinking about it for a while I decided to just show the code: a minimal sketch of the core idea first, then the full source.
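Before wading through the Hive plumbing, here is a minimal sketch of that core idea in plain Java, with no Hive API involved: keep the running max and min together in one buffer and update both for every incoming value. The names MaxMinBuffer, update() and result() are made up for illustration and do not appear in the real UDAF.

// Minimal sketch of the aggregation idea, without any Hive plumbing.
// MaxMinBuffer, update() and result() are hypothetical names for illustration only.
class MaxMinBuffer {
    double max = Double.MIN_VALUE;
    double min = Double.MAX_VALUE;

    // A single input value updates both outputs at once.
    void update(double v) {
        if (v > max) { max = v; }
        if (v < min) { min = v; }
    }

    // The two results are emitted together, joined by a tab,
    // mirroring what terminate() does in the real UDAF below.
    String result() {
        return max + "\t" + min;
    }
}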
Source code
package org.juefan.udaf;

import java.util.ArrayList;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.util.StringUtils;

/**
 * GenericUDAFMaxMin.
 */
@Description(name = "maxmin", value = "_FUNC_(x) - Returns the max and min value of a set of numbers")
public class GenericUDAFMaxMin extends AbstractGenericUDAFResolver {

    static final Log LOG = LogFactory.getLog(GenericUDAFMaxMin.class.getName());

    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
        if (parameters.length != 1) {
            throw new UDFArgumentTypeException(parameters.length - 1,
                "Exactly one argument is expected.");
        }
        if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
            throw new UDFArgumentTypeException(0,
                "Only primitive type arguments are accepted but "
                + parameters[0].getTypeName() + " is passed.");
        }
        switch (((PrimitiveTypeInfo) parameters[0]).getPrimitiveCategory()) {
        case BYTE:
        case SHORT:
        case INT:
        case LONG:
        case FLOAT:
        case DOUBLE:
        case STRING:
        case TIMESTAMP:
            return new GenericUDAFMaxMinEvaluator();
        case BOOLEAN:
        default:
            throw new UDFArgumentTypeException(0,
                "Only numeric or string type arguments are accepted but "
                + parameters[0].getTypeName() + " is passed.");
        }
    }

    /**
     * GenericUDAFMaxMinEvaluator.
     */
    public static class GenericUDAFMaxMinEvaluator extends GenericUDAFEvaluator {

        // For PARTIAL1 and COMPLETE
        PrimitiveObjectInspector inputOI;

        // For PARTIAL2 and FINAL
        StructObjectInspector soi;
        // Struct fields that carry the serialized max and min during the computation
        StructField maxField;
        StructField minField;
        // Inspectors whose get() returns the underlying double values directly
        DoubleObjectInspector maxFieldOI;
        DoubleObjectInspector minFieldOI;

        // For PARTIAL1 and PARTIAL2
        // Holds the intermediate (partial) result
        Object[] partialResult;

        // For FINAL and COMPLETE
        // The final output value
        Text result;

        @Override
        public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
            assert (parameters.length == 1);
            super.init(m, parameters);

            // Initialize the input side
            if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
                inputOI = (PrimitiveObjectInspector) parameters[0];
            } else {
                // The input is a partial aggregation, so cast it to the corresponding struct
                soi = (StructObjectInspector) parameters[0];
                // Look up the serialized fields by name
                maxField = soi.getStructFieldRef("max");
                minField = soi.getStructFieldRef("min");
                // Inspectors for reading the actual field values
                maxFieldOI = (DoubleObjectInspector) maxField.getFieldObjectInspector();
                minFieldOI = (DoubleObjectInspector) minField.getFieldObjectInspector();
            }

            // Initialize the output side
            if (m == Mode.PARTIAL1 || m == Mode.PARTIAL2) {
                // The output is a struct containing the max and min values
                // Field types of the struct
                ArrayList<ObjectInspector> foi = new ArrayList<ObjectInspector>();
                foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
                foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
                // Field names of the struct
                ArrayList<String> fname = new ArrayList<String>();
                fname.add("max");
                fname.add("min");
                partialResult = new Object[2];
                partialResult[0] = new DoubleWritable(0);
                partialResult[1] = new DoubleWritable(0);
                return ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi);
            } else {
                // Last step: declare the final output type
                result = new Text("");
                return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
            }
        }

        static class AverageAgg implements AggregationBuffer {
            double max;
            double min;
        };

        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            AverageAgg result = new AverageAgg();
            reset(result);
            return result;
        }

        @Override
        public void reset(AggregationBuffer agg) throws HiveException {
            AverageAgg myagg = (AverageAgg) agg;
            myagg.max = Double.MIN_VALUE;
            myagg.min = Double.MAX_VALUE;
        }

        boolean warned = false;

        @Override
        public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
            assert (parameters.length == 1);
            Object p = parameters[0];
            if (p != null) {
                AverageAgg myagg = (AverageAgg) agg;
                try {
                    // Read the input value and compare it against the running max/min
                    double v = PrimitiveObjectInspectorUtils.getDouble(p, inputOI);
                    if (myagg.max < v) {
                        myagg.max = v;
                    }
                    if (myagg.min > v) {
                        myagg.min = v;
                    }
                } catch (NumberFormatException e) {
                    if (!warned) {
                        warned = true;
                        LOG.warn(getClass().getSimpleName() + " "
                            + StringUtils.stringifyException(e));
                        LOG.warn(getClass().getSimpleName() + " ignoring similar exceptions.");
                    }
                }
            }
        }

        @Override
        public Object terminatePartial(AggregationBuffer agg) throws HiveException {
            // Wrap the intermediate result and hand it to the next stage
            AverageAgg myagg = (AverageAgg) agg;
            ((DoubleWritable) partialResult[0]).set(myagg.max);
            ((DoubleWritable) partialResult[1]).set(myagg.min);
            return partialResult;
        }

        @Override
        public void merge(AggregationBuffer agg, Object partial) throws HiveException {
            if (partial != null) {
                // Here partial is the output of terminatePartial from an earlier stage
                AverageAgg myagg = (AverageAgg) agg;
                Object partialmax = soi.getStructFieldData(partial, maxField);
                Object partialmin = soi.getStructFieldData(partial, minField);
                if (myagg.max < maxFieldOI.get(partialmax)) {
                    myagg.max = maxFieldOI.get(partialmax);
                }
                if (myagg.min > minFieldOI.get(partialmin)) {
                    myagg.min = minFieldOI.get(partialmin);
                }
            }
        }

        @Override
        public Object terminate(AggregationBuffer agg) throws HiveException {
            // Join the final max and min into one string and return it
            AverageAgg myagg = (AverageAgg) agg;
            if (myagg.max == 0) {
                return null;
            } else {
                result.set(myagg.max + "\t" + myagg.min);
                return result;
            }
        }
    }
}
Even after writing it I still don't feel I have understood the whole process thoroughly, so take the comments above with a grain of salt; I can't guarantee they are all correct!
This afternoon I will add some output to trace the execution. The logic of the code itself is fine, though: I have run it myself. One way to trace it outside of Hive is to drive the evaluator by hand, as sketched below.
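The following is a minimal sketch (not part of the original post) that walks the evaluator through the PARTIAL1 (map-side) and FINAL (reduce-side) stages directly, which makes it easy to print and inspect each intermediate value. The sample values, the class name MaxMinTrace and the choice of javaDoubleObjectInspector as the input inspector are my own assumptions for illustration.

import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

import org.juefan.udaf.GenericUDAFMaxMin.GenericUDAFMaxMinEvaluator;

public class MaxMinTrace {
    public static void main(String[] args) throws Exception {
        // Map side: PARTIAL1 consumes raw doubles and emits the {max, min} struct.
        GenericUDAFMaxMinEvaluator partialEval = new GenericUDAFMaxMinEvaluator();
        ObjectInspector inputOI = PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
        ObjectInspector partialOI =
                partialEval.init(GenericUDAFEvaluator.Mode.PARTIAL1, new ObjectInspector[]{inputOI});

        GenericUDAFEvaluator.AggregationBuffer mapAgg = partialEval.getNewAggregationBuffer();
        for (double v : new double[]{3.0, 7.5, 1.2}) {
            partialEval.iterate(mapAgg, new Object[]{v});
        }
        Object partial = partialEval.terminatePartial(mapAgg);
        System.out.println("partial {max, min}: " + java.util.Arrays.toString((Object[]) partial));

        // Reduce side: FINAL merges the partial struct and produces the tab-joined string.
        GenericUDAFMaxMinEvaluator finalEval = new GenericUDAFMaxMinEvaluator();
        finalEval.init(GenericUDAFEvaluator.Mode.FINAL, new ObjectInspector[]{partialOI});
        GenericUDAFEvaluator.AggregationBuffer reduceAgg = finalEval.getNewAggregationBuffer();
        finalEval.merge(reduceAgg, partial);
        System.out.println("final result: " + finalEval.terminate(reduceAgg));
    }
}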