diff --git ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/CastDecimalToFloat.java ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/CastDecimalToFloat.java new file mode 100644 index 0000000..4ef5422 --- /dev/null +++ ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/CastDecimalToFloat.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hive.ql.exec.vector.expressions; + +import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector; + +/** + * Cast a decimal to float, built on the vectorized decimal to double function. 
+ * Each value is converted with floatValue() before it is stored in the output + * DoubleColumnVector. + */ +public class CastDecimalToFloat extends FuncDecimalToDouble { + + private static final long serialVersionUID = 1L; + + public CastDecimalToFloat() { + super(); + } + + public CastDecimalToFloat(int inputCol, int outputColumnNum) { + super(inputCol, outputColumnNum); + } + + /** + * Stores the float value of input row i in output row i. + */ + @Override + protected void func(DoubleColumnVector outV, DecimalColumnVector inV, int i) { + outV.vector[i] = inV.vector[i].floatValue(); + } +} diff --git ql/src/java/org/apache/hadoop/hive/ql/udf/UDFToFloat.java ql/src/java/org/apache/hadoop/hive/ql/udf/UDFToFloat.java index fd49d1f..53c59b3 100755 --- ql/src/java/org/apache/hadoop/hive/ql/udf/UDFToFloat.java +++ ql/src/java/org/apache/hadoop/hive/ql/udf/UDFToFloat.java @@ -20,7 +20,7 @@ import org.apache.hadoop.hive.ql.exec.UDF; import org.apache.hadoop.hive.ql.exec.vector.VectorizedExpressions; -import org.apache.hadoop.hive.ql.exec.vector.expressions.CastDecimalToDouble; +import org.apache.hadoop.hive.ql.exec.vector.expressions.CastDecimalToFloat; import org.apache.hadoop.hive.ql.exec.vector.expressions.CastStringToFloat; import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.CastLongToFloatViaLongToDouble; import org.apache.hadoop.hive.ql.exec.vector.expressions.CastTimestampToDouble; @@ -42,7 +42,7 @@ * */ @VectorizedExpressions({CastTimestampToDouble.class, CastLongToFloatViaLongToDouble.class, - CastDecimalToDouble.class, CastStringToFloat.class}) + CastDecimalToFloat.class, CastStringToFloat.class}) public class UDFToFloat extends UDF { private final FloatWritable floatWritable = new FloatWritable(); diff --git ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorTypeCasts.java ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorTypeCasts.java index 6aa6da9..0485e8e 100644 --- ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorTypeCasts.java +++ ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestVectorTypeCasts.java @@ -377,6 +377,54 @@ private 
VectorizedRowBatch getBatchDecimalDouble() { return b; } + + @Test + public void testCastDecimalToFloat() { + + final double eps = 0.00000000000001d; // tolerance for comparing the expected and actual float values (held as doubles) + + double f1 = HiveDecimal.create("1.1").floatValue(); + double f2 = HiveDecimal.create("-2.2").floatValue(); + double f3 = HiveDecimal.create("9999999999999999.00").floatValue(); + + // basic case: output rows 0-2 must equal the float values f1, f2 and f3 + VectorizedRowBatch b = getBatchDecimalDouble(); + VectorExpression expr = new CastDecimalToFloat(0, 1); + expr.evaluate(b); + DoubleColumnVector r = (DoubleColumnVector) b.cols[1]; + assertEquals(f1, r.vector[0], eps); + assertEquals(f2, r.vector[1], eps); + assertEquals(f3, r.vector[2], eps); + + // null in input row 1: output row 1 is null, row 0 is non-null and still cast + b = getBatchDecimalDouble(); + b.cols[0].noNulls = false; + b.cols[0].isNull[1] = true; + expr.evaluate(b); + r = (DoubleColumnVector) b.cols[1]; + assertFalse(r.noNulls); + assertTrue(r.isNull[1]); + assertFalse(r.isNull[0]); + assertEquals(f1, r.vector[0], eps); + + // repeating input: output is flagged repeating and row 0 holds the cast value + b = getBatchDecimalDouble(); + b.cols[0].isRepeating = true; + expr.evaluate(b); + r = (DoubleColumnVector) b.cols[1]; + assertTrue(r.isRepeating); + assertEquals(f1, r.vector[0], eps); + + // repeating null input: output is flagged repeating and row 0 is null + b = getBatchDecimalDouble(); + b.cols[0].isRepeating = true; + b.cols[0].noNulls = false; + b.cols[0].isNull[0] = true; + expr.evaluate(b); + r = (DoubleColumnVector) b.cols[1]; + assertTrue(r.isRepeating); + assertTrue(r.isNull[0]); + } @Test public void testCastDecimalToString() throws HiveException { diff --git ql/src/test/queries/clientpositive/vectorization_parquet_ppd_decimal.q ql/src/test/queries/clientpositive/vectorization_parquet_ppd_decimal.q new file mode 100644 index 0000000..006caac --- /dev/null +++ ql/src/test/queries/clientpositive/vectorization_parquet_ppd_decimal.q @@ -0,0 +1,165 @@ +SET hive.vectorized.execution.enabled=true; +SET hive.input.format=org.apache.hadoop.hive.ql.io.HiveInputFormat; +SET mapred.min.split.size=1000; +SET 
mapred.max.split.size=5000; +set hive.llap.cache.allow.synthetic.fileid=true; + +create table newtypestbl(c char(10), v varchar(10), d decimal(5,3), da date) stored as parquet; + +insert overwrite table newtypestbl select * from (select cast("apple" as char(10)), cast("bee" as varchar(10)), 0.22, cast("1970-02-20" as date) from src src1 union all select cast("hello" as char(10)), cast("world" as varchar(10)), 11.22, cast("1970-02-27" as date) from src src2 limit 10) uniontbl; + +-- decimal data types (EQUAL, NOT_EQUAL, LESS_THAN, LESS_THAN_EQUALS, IN, BETWEEN tests) +select * from newtypestbl where d=0.22; + +set hive.optimize.index.filter=true; +select * from newtypestbl where d=0.22; + +set hive.optimize.index.filter=false; +select * from newtypestbl where d='0.22'; + +set hive.optimize.index.filter=true; +select * from newtypestbl where d='0.22'; + +set hive.optimize.index.filter=false; +select * from newtypestbl where d=cast('0.22' as float); + +set hive.optimize.index.filter=true; +select * from newtypestbl where d=cast('0.22' as float); + +set hive.optimize.index.filter=false; +select * from newtypestbl where d!=0.22; + +set hive.optimize.index.filter=true; +select * from newtypestbl where d!=0.22; + +set hive.optimize.index.filter=false; +select * from newtypestbl where d!='0.22'; + +set hive.optimize.index.filter=true; +select * from newtypestbl where d!='0.22'; + +set hive.optimize.index.filter=false; +select * from newtypestbl where d!=cast('0.22' as float); + +set hive.optimize.index.filter=true; +select * from newtypestbl where d!=cast('0.22' as float); + +set hive.optimize.index.filter=false; +select * from newtypestbl where d<11.22; + +set hive.optimize.index.filter=true; +select * from newtypestbl where d<11.22; + +set hive.optimize.index.filter=false; +select * from newtypestbl where d<'11.22'; + +set hive.optimize.index.filter=true; +select * from newtypestbl where d<'11.22'; + +set hive.optimize.index.filter=false; +select * from newtypestbl where 
d