Details

    • Type: Sub-task
    • Status: Closed
    • Priority: Major
    • Resolution: Implemented
    • Affects Version/s: None
    • Fix Version/s: 1.3.0
    • Component/s: Table API & SQL
    • Labels: None

        Activity

        githubbot ASF GitHub Bot added a comment -

        GitHub user shaoxuan-wang opened a pull request:

        https://github.com/apache/flink/pull/3809

        FLINK-5906 [table] Add support to register UDAGG in Table and SQL API

        Thanks for contributing to Apache Flink. Before you open your pull request, please take the following checklist into consideration.
        If your changes take all of the items into account, feel free to open your pull request. For more information and/or questions please refer to the [How To Contribute guide](http://flink.apache.org/how-to-contribute.html).
        In addition to going through the list, please provide a meaningful description of your changes.

        • [x] General
        • The pull request references the related JIRA issue ("[FLINK-XXX] Jira title text")
        • The pull request addresses only one issue
        • Each commit in the PR has a meaningful commit message (including the JIRA id)
        • [ ] Documentation
        • Documentation has been added for new functionality
        • Old documentation affected by the pull request has been updated
        • JavaDoc for public methods has been added
        • [x] Tests & Build
        • Functionality added by the pull request is covered by tests
        • `mvn clean verify` has been executed successfully locally or a Travis build has passed

        You can merge this pull request into a Git repository by running:

        $ git pull https://github.com/shaoxuan-wang/flink F5906-submit

        Alternatively you can review and apply these changes as the patch at:

        https://github.com/apache/flink/pull/3809.patch

        To close this pull request, make a commit to your master/trunk branch
        with (at least) the following in the commit message:

        This closes #3809


        commit ae2e0ae45f4a4e41a67de6c7feda427b4981c079
        Author: shaoxuan-wang <wshaoxuan@gmail.com>
        Date: 2017-05-02T15:00:51Z

        FLINK-5906 [table] Add support to register UDAGG in Table and SQL API


        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114447396

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala —

        ```
        @@ -911,7 +911,21 @@ class GroupedTable(
             */
            def select(fields: String): Table = {
              val fieldExprs = ExpressionParser.parseExpressionList(fields)
        -     select(fieldExprs: _*)
        +
        +     //get the correct expression for UDAGGFunctionCall
        +     val input: Seq[Expression] = fieldExprs.zipWithIndex.map {
        +       case (Call(name, args), idx) => {
        ```

        — End diff —

        Could also be implemented without `zipWithIndex`:
        ```
        val input: Seq[Expression] = fieldExprs.map {
          case Call(name, args) =>
            val function = table.tableEnv.getFunctionCatalog.lookupFunction(name, args)
            if (function.isInstanceOf[UDAGGFunctionCall]) {
              function
            } else {
              Call(name, args)
            }
          case x => x
        }
        ```

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114448027

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala —

        ```
        @@ -274,27 +278,73 @@ class CodeGenerator(
              constantFlags: Option[Array[(Int, Boolean)]],
              outputArity: Int,
              needRetract: Boolean,
        -     needMerge: Boolean)
        +     needMerge: Boolean,
        +     needReset: Boolean)
            : GeneratedAggregationsFunction = {

            // get unique function name
            val funcName = newName(name)
            // register UDAGGs
            val aggs = aggregates.map(a => generator.addReusableFunction(a))
            // get java types of accumulators
        -   val accTypes = aggregates.map { a =>
        -     a.getClass.getMethod("createAccumulator").getReturnType.getCanonicalName
        +   val accTypeClasses = aggregates.map { a =>
        +     val accType = TypeExtractor.createTypeInfo(a, classOf[AggregateFunction[_, _]], a.getClass, 1)
        ```

        — End diff —

        Why doesn't the original approach work anymore?
        IMO, it is better to avoid the TypeExtractor if possible, especially when trying to extract generic types.
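
        For reference, the reflection-based lookup this comment prefers can be sketched as a standalone helper. This is only a sketch: `aggregates` stands in for the surrounding code's sequence of aggregate functions, and the body mirrors the removed diff lines above.

        ```
        import org.apache.flink.table.functions.AggregateFunction

        // Sketch of the reflection-based approach: the accumulator class comes
        // from the declared createAccumulator() return type, with no
        // TypeExtractor involved.
        def accumulatorTypeNames(aggregates: Seq[AggregateFunction[_, _]]): Seq[String] =
          aggregates.map { a =>
            a.getClass.getMethod("createAccumulator").getReturnType.getCanonicalName
          }
        ```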

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114444069

        — Diff: flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/DStreamUDAGGITCase.scala —

        ```
        @@ -0,0 +1,309 @@
        +/*
        + * Licensed to the Apache Software Foundation (ASF) under one
        + * or more contributor license agreements. See the NOTICE file
        + * distributed with this work for additional information
        + * regarding copyright ownership. The ASF licenses this file
        + * to you under the Apache License, Version 2.0 (the
        + * "License"); you may not use this file except in compliance
        + * with the License. You may obtain a copy of the License at
        + *
        + * http://www.apache.org/licenses/LICENSE-2.0
        + *
        + * Unless required by applicable law or agreed to in writing, software
        + * distributed under the License is distributed on an "AS IS" BASIS,
        + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
        + * See the License for the specific language governing permissions and
        + * limitations under the License.
        + */
        +
        +package org.apache.flink.table.api.scala.stream.table
        +
        +import java.math.BigDecimal
        +
        +import org.apache.flink.api.java.typeutils.RowTypeInfo
        +import org.apache.flink.api.scala._
        +import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks
        +import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
        +import org.apache.flink.streaming.api.TimeCharacteristic
        +import org.apache.flink.table.utils.TableTestBase
        +import org.apache.flink.streaming.api.datastream.{DataStream => JDataStream}
        +import org.apache.flink.streaming.api.environment.{StreamExecutionEnvironment => JavaExecutionEnv}
        +import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment => ScalaExecutionEnv}
        +import org.apache.flink.streaming.api.watermark.Watermark
        +import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.{WeightedAvg, WeightedAvgWithMerge, WeightedAvgWithRetract}
        +import org.apache.flink.table.api.scala._
        +import org.apache.flink.table.api.scala.stream.table.DStreamUDAGGITCase.TimestampAndWatermarkWithOffset
        +import org.apache.flink.table.api.scala.stream.utils.StreamITCase
        +import org.apache.flink.table.api.{SlidingWindow, TableEnvironment, Types}
        +import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase
        +import org.apache.flink.table.functions.aggfunctions.CountAggFunction
        +import org.apache.flink.types.Row
        +import org.junit.Assert._
        +import org.junit.Test
        +import org.mockito.Mockito.{mock, when}
        +
        +import scala.collection.mutable
        +
        +/**
        + * We only test some aggregations until better testing of constructed DataStream
        + * programs is possible.
        + */
        +class DStreamUDAGGITCase
        +  extends StreamingMultipleProgramsTestBase {
        +
        +  val data = List(
        +    //('long, 'int, 'double, 'float, 'bigdec, 'string)
        +    (1000L, 1, 1d, 1f, new BigDecimal("1"), "Hello"),
        +    (2000L, 2, 2d, 2f, new BigDecimal("2"), "Hello"),
        +    (3000L, 3, 3d, 3f, new BigDecimal("3"), "Hello"),
        +    (5000L, 5, 5d, 5f, new BigDecimal("5"), "Hi"),
        +    (6000L, 6, 6d, 6f, new BigDecimal("6"), "Hi"),
        +    (7000L, 7, 7d, 7f, new BigDecimal("7"), "Hi"),
        +    (8000L, 8, 8d, 8f, new BigDecimal("8"), "Hello"),
        +    (9000L, 9, 9d, 9f, new BigDecimal("9"), "Hello"),
        +    (4000L, 4, 4d, 4f, new BigDecimal("4"), "Hello"),
        +    (10000L, 10, 10d, 10f, new BigDecimal("10"), "Hi"),
        +    (11000L, 11, 11d, 11f, new BigDecimal("11"), "Hi"),
        +    (12000L, 12, 12d, 12f, new BigDecimal("12"), "Hi"),
        +    (16000L, 16, 16d, 16f, new BigDecimal("16"), "Hello"))
        +
        +  @Test
        +  def testUdaggSlidingWindowGroupedAggregate(): Unit = {
        +    val env = StreamExecutionEnvironment.getExecutionEnvironment
        +    env.setParallelism(1)
        +    val tEnv = TableEnvironment.getTableEnvironment(env)
        +    StreamITCase.testResults = mutable.MutableList()
        +
        +    val stream = env.fromCollection(data).map(t => (t._1, t._2, t._3, t._4, t._6))
        +    val table = stream.toTable(tEnv, 'long, 'int, 'double, 'float, 'string)
        +
        +    val countFun = new CountAggFunction
        +
        +    val weightAvgFun = new WeightedAvg
        +
        +    val windowedTable = table
        +      .window(Slide over 4.rows every 2.rows as 'w)
        +      .groupBy('w, 'string)
        +      .select(
        +        'string,
        +        countFun('float),
        +        'double.sum,
        +        weightAvgFun('long, 'int),
        +        weightAvgFun('int, 'int))
        +
        +    val results = windowedTable.toDataStream[Row]
        +    results.addSink(new StreamITCase.StringSink)
        +    env.execute()
        +
        +    val expected = Seq(
        +      "Hello,2,3.0,1666,1", "Hi,2,11.0,5545,5", "Hello,4,14.0,5571,5",
        +      "Hello,4,24.0,7083,7", "Hi,4,28.0,7500,7", "Hi,4,40.0,10350,10")
        +    assertEquals(expected, StreamITCase.testResults)
        +  }
        +
        +  @Test
        +  def testUdaggSessionWindowGroupedAggregate(): Unit = {
        +    val env = StreamExecutionEnvironment.getExecutionEnvironment
        +    env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
        +    env.setParallelism(1)
        +    val tEnv = TableEnvironment.getTableEnvironment(env)
        +    StreamITCase.testResults = mutable.MutableList()
        +
        +    val countFun = new CountAggFunction
        +
        +    val weightAvgWithMergeFun = new WeightedAvgWithMerge
        +
        +    val stream = env
        +      .fromCollection(data)
        +      .assignTimestampsAndWatermarks(new TimestampAndWatermarkWithOffset(10000))
        +    val table = stream.toTable(tEnv, 'long, 'int, 'double, 'float, 'bigdec, 'string)
        +
        +    val windowedTable = table
        +      .window(Session withGap 5.second on 'rowtime as 'w)
        +      .groupBy('w, 'string)
        +      .select(
        +        'string,
        +        countFun('bigdec),
        +        'float.sum,
        +        weightAvgWithMergeFun('long, 'int),
        +        weightAvgWithMergeFun('int, 'int))
        +
        +    val results = windowedTable.toDataStream[Row]
        +    results.addSink(new StreamITCase.StringSink)
        +    env.execute()
        +
        +    val expected = Seq("Hello,6,27.0,6481,6", "Hi,6,51.0,9313,9", "Hello,1,16.0,16000,16")
        +    assertEquals(expected, StreamITCase.testResults)
        +  }
        +
        +  @Test
        +  def testUdaggProcTimeUnBoundedPartitionedRowOver(): Unit = {
        +
        +    val env = StreamExecutionEnvironment.getExecutionEnvironment
        +    env.setParallelism(1)
        +    val tEnv = TableEnvironment.getTableEnvironment(env)
        +    StreamITCase.testResults = mutable.MutableList()
        +    StreamITCase.clear
        +    val stream = env.fromCollection(data).map(t => (t._4, t._1, t._2, t._3, t._6))
        +    val table = stream.toTable(tEnv, 'float, 'long, 'int, 'double, 'string)
        +
        +    val countFun = new CountAggFunction
        +
        +    val weightedAvgWithRetractFun = new WeightedAvgWithRetract
        +
        +    val windowedTable = table
        +      .window(
        +        Over partitionBy 'string orderBy 'proctime preceding UNBOUNDED_ROW as 'w)
        +      .select(
        +        'string,
        +        'float,
        +        countFun('string) over 'w,
        +        'double.sum over 'w,
        +        weightedAvgWithRetractFun('long, 'int) over 'w,
        +        weightedAvgWithRetractFun('int, 'int) over 'w)
        +
        +    val results = windowedTable.toDataStream[Row]
        +    results.addSink(new StreamITCase.StringSink)
        +    env.execute()
        +
        +    val expected = Seq(
        +      "Hello,1.0,1,1.0,1000,1", "Hello,2.0,2,3.0,1666,1", "Hello,3.0,3,6.0,2333,2",
        +      "Hi,5.0,1,5.0,5000,5", "Hi,6.0,2,11.0,5545,5", "Hi,7.0,3,18.0,6111,6",
        +      "Hello,8.0,4,14.0,5571,5", "Hello,9.0,5,23.0,6913,6", "Hello,4.0,6,27.0,6481,6",
        +      "Hi,10.0,4,28.0,7500,7", "Hi,11.0,5,39.0,8487,8", "Hi,12.0,6,51.0,9313,9",
        +      "Hello,16.0,7,43.0,10023,10")
        +    assertEquals(expected, StreamITCase.testResults)
        +  }
        +
        +  @Test
        +  def testUdaggProcTimeUnBoundedPartitionedRowOverSQL(): Unit = {
        +
        +    val env = StreamExecutionEnvironment.getExecutionEnvironment
        +    env.setParallelism(1)
        +    val tEnv = TableEnvironment.getTableEnvironment(env)
        +    StreamITCase.testResults = mutable.MutableList()
        +    StreamITCase.clear
        +    val stream = env.fromCollection(data).map(t => (t._4, t._1, t._2, t._3, t._6))
        +    val table = stream.toTable(tEnv, 'f, 'l, 'i, 'd, 's)
        +
        +    tEnv.registerTable("T1", table)
        +    tEnv.registerFunction("countFun", new CountAggFunction)
        +    tEnv.registerFunction("wAvgWithRetract", new WeightedAvgWithRetract)
        +    val sqlQuery = "SELECT " +
        +      "s, " +
        +      "f, " +
        +      "countFun(i) OVER (PARTITION BY s ORDER BY ProcTime() RANGE UNBOUNDED preceding)," +
        +      "sum(d) OVER (PARTITION BY s ORDER BY ProcTime() RANGE UNBOUNDED preceding)," +
        +      "wAvgWithRetract(l,i) OVER (PARTITION BY s ORDER BY ProcTime() RANGE UNBOUNDED preceding)," +
        +      "wAvgWithRetract(i,i) OVER (PARTITION BY s ORDER BY ProcTime() RANGE UNBOUNDED preceding)" +
        +      "from T1"
        +
        +    val results = tEnv.sql(sqlQuery).toDataStream[Row]
        +    results.addSink(new StreamITCase.StringSink)
        +    env.execute()
        +
        +    val expected = Seq(
        +      "Hello,1.0,1,1.0,1000,1", "Hello,2.0,2,3.0,1666,1", "Hello,3.0,3,6.0,2333,2",
        +      "Hi,5.0,1,5.0,5000,5", "Hi,6.0,2,11.0,5545,5", "Hi,7.0,3,18.0,6111,6",
        +      "Hello,8.0,4,14.0,5571,5", "Hello,9.0,5,23.0,6913,6", "Hello,4.0,6,27.0,6481,6",
        +      "Hi,10.0,4,28.0,7500,7", "Hi,11.0,5,39.0,8487,8", "Hi,12.0,6,51.0,9313,9",
        +      "Hello,16.0,7,43.0,10023,10")
        +    assertEquals(expected, StreamITCase.testResults)
        +  }
        +
        +  @Test
        +  def testUdaggGroupedAggregateSQL(): Unit = {
        +
        +    val env = StreamExecutionEnvironment.getExecutionEnvironment
        +    env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
        +    env.setParallelism(1)
        +    val tEnv = TableEnvironment.getTableEnvironment(env)
        +    StreamITCase.testResults = mutable.MutableList()
        +    StreamITCase.clear
        +    val stream = env
        +      .fromCollection(data)
        +      .assignTimestampsAndWatermarks(new TimestampAndWatermarkWithOffset(10000))
        +      .map(t => (t._4, t._1, t._2, t._3, t._6))
        +    val table = stream.toTable(tEnv, 'f, 'l, 'i, 'd, 's)
        +
        +    tEnv.registerTable("T1", table)
        +    tEnv.registerFunction("countFun", new CountAggFunction)
        +    tEnv.registerFunction("wAvgWithRetract", new WeightedAvgWithRetract)
        +    val sqlQuery =
        +      "SELECT s, countFun(i), SUM(d)" +
        +      "FROM T1 " +
        +      "GROUP BY s, TUMBLE(rowtime(), INTERVAL '5' SECOND)"
        +
        +    val results = tEnv.sql(sqlQuery).toDataStream[Row]
        +    results.addSink(new StreamITCase.StringSink)
        +    env.execute()
        +
        +    val expected = Seq(
        +      "Hello,4,10.0", "Hi,3,18.0", "Hello,2,17.0", "Hi,3,33.0", "Hello,1,16.0")
        +    assertEquals(expected, StreamITCase.testResults)
        +  }
        +
        +  @Test
        +  def testUdaggJavaAPI(): Unit = {
        +    // mock
        +    val ds = mock(classOf[DataStream[Row]])
        +    val jDs = mock(classOf[JDataStream[Row]])
        +    val typeInfo = new RowTypeInfo(Seq(Types.INT, Types.LONG, Types.STRING): _*)
        +    when(ds.javaStream).thenReturn(jDs)
        +    when(jDs.getType).thenReturn(typeInfo)
        +    // Scala environment
        +    val env = mock(classOf[ScalaExecutionEnv])
        +    val tableEnv = TableEnvironment.getTableEnvironment(env)
        +    val in1 = ds.toTable(tableEnv).as('int, 'long, 'string)
        +
        +    // Java environment
        +    val javaEnv = mock(classOf[JavaExecutionEnv])
        +    val javaTableEnv = TableEnvironment.getTableEnvironment(javaEnv)
        +    val in2 = javaTableEnv.fromDataStream(jDs).as("int, long, string")
        +
        +    // Java API
        +    javaTableEnv.registerFunction("myCountFun", new CountAggFunction)
        +    javaTableEnv.registerFunction("weightAvgFun", new WeightedAvg)
        +    var javaTable = in2.window((new SlidingWindow(4.rows, 2.rows)).as("w"))
        +      .groupBy("w, string")
        +      .select(
        +        "string, " +
        +        "myCountFun(string), " +
        +        "int.sum, " +
        +        "weightAvgFun(long, int), " +
        +        "weightAvgFun(int, int)")
        +
        +    // Scala API
        +    val myCountFun = new CountAggFunction
        +    val weightAvgFun = new WeightedAvg
        +    var scalaTable = in1.window(Slide over 4.rows every 2.rows as 'w)
        +      .groupBy('w, 'string)
        +      .select(
        +        'string, myCountFun('string), 'int.sum, weightAvgFun('long, 'int),
        +        weightAvgFun('int, 'int))
        +
        +    val helper = new TableTestBase
        +    helper.verifyTableEquals(scalaTable, javaTable)
        +  }
        +}
        +
        +object DStreamUDAGGITCase {
        +
        +  class TimestampAndWatermarkWithOffset(offset: Int)
        ```

        — End diff —

        Can be removed if we use `.assignAscendingTimestamps(_._1)`
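
        A minimal sketch of that suggestion, assuming the first tuple field holds a millisecond timestamp that arrives in ascending order:

        ```
        import org.apache.flink.streaming.api.scala._

        // assignAscendingTimestamps derives watermarks directly from the
        // (assumed monotonically increasing) first field, so no custom
        // AssignerWithPunctuatedWatermarks class is needed.
        val env = StreamExecutionEnvironment.getExecutionEnvironment
        val stream = env
          .fromCollection(Seq((1000L, "Hello"), (2000L, "Hi"), (3000L, "Hello")))
          .assignAscendingTimestamps(_._1)
        ```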

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114449859

        — Diff: flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/DStreamUDAGGITCase.scala —

        ```
        @@ -0,0 +1,309 @@
        +/*
        + * Licensed to the Apache Software Foundation (ASF) under one
        + * or more contributor license agreements. See the NOTICE file
        + * distributed with this work for additional information
        + * regarding copyright ownership. The ASF licenses this file
        + * to you under the Apache License, Version 2.0 (the
        + * "License"); you may not use this file except in compliance
        + * with the License. You may obtain a copy of the License at
        + *
        + * http://www.apache.org/licenses/LICENSE-2.0
        + *
        + * Unless required by applicable law or agreed to in writing, software
        + * distributed under the License is distributed on an "AS IS" BASIS,
        + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
        + * See the License for the specific language governing permissions and
        + * limitations under the License.
        + */
        +
        +package org.apache.flink.table.api.scala.stream.table
        +
        +import java.math.BigDecimal
        +
        +import org.apache.flink.api.java.typeutils.RowTypeInfo
        +import org.apache.flink.api.scala._
        +import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks
        +import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
        +import org.apache.flink.streaming.api.TimeCharacteristic
        +import org.apache.flink.table.utils.TableTestBase
        +import org.apache.flink.streaming.api.datastream.{DataStream => JDataStream}
        +import org.apache.flink.streaming.api.environment.{StreamExecutionEnvironment => JavaExecutionEnv}
        +import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment => ScalaExecutionEnv}
        +import org.apache.flink.streaming.api.watermark.Watermark
        +import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.{WeightedAvg, WeightedAvgWithMerge, WeightedAvgWithRetract}
        +import org.apache.flink.table.api.scala._
        +import org.apache.flink.table.api.scala.stream.table.DStreamUDAGGITCase.TimestampAndWatermarkWithOffset
        +import org.apache.flink.table.api.scala.stream.utils.StreamITCase
        +import org.apache.flink.table.api.{SlidingWindow, TableEnvironment, Types}
        +import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase
        +import org.apache.flink.table.functions.aggfunctions.CountAggFunction
        +import org.apache.flink.types.Row
        +import org.junit.Assert._
        +import org.junit.Test
        +import org.mockito.Mockito.{mock, when}
        +
        +import scala.collection.mutable
        +
        +/**
        + * We only test some aggregations until better testing of constructed DataStream
        + * programs is possible.
        + */
        +class DStreamUDAGGITCase
        +  extends StreamingMultipleProgramsTestBase {
        +
        +  val data = List(
        +    //('long, 'int, 'double, 'float, 'bigdec, 'string)
        +    (1000L, 1, 1d, 1f, new BigDecimal("1"), "Hello"),
        +    (2000L, 2, 2d, 2f, new BigDecimal("2"), "Hello"),
        +    (3000L, 3, 3d, 3f, new BigDecimal("3"), "Hello"),
        +    (5000L, 5, 5d, 5f, new BigDecimal("5"), "Hi"),
        +    (6000L, 6, 6d, 6f, new BigDecimal("6"), "Hi"),
        +    (7000L, 7, 7d, 7f, new BigDecimal("7"), "Hi"),
        +    (8000L, 8, 8d, 8f, new BigDecimal("8"), "Hello"),
        +    (9000L, 9, 9d, 9f, new BigDecimal("9"), "Hello"),
        +    (4000L, 4, 4d, 4f, new BigDecimal("4"), "Hello"),
        +    (10000L, 10, 10d, 10f, new BigDecimal("10"), "Hi"),
        +    (11000L, 11, 11d, 11f, new BigDecimal("11"), "Hi"),
        +    (12000L, 12, 12d, 12f, new BigDecimal("12"), "Hi"),
        +    (16000L, 16, 16d, 16f, new BigDecimal("16"), "Hello"))
        +
        +  @Test
        +  def testUdaggSlidingWindowGroupedAggregate(): Unit = {
        +    val env = StreamExecutionEnvironment.getExecutionEnvironment
        +    env.setParallelism(1)
        +    val tEnv = TableEnvironment.getTableEnvironment(env)
        +    StreamITCase.testResults = mutable.MutableList()
        +
        +    val stream = env.fromCollection(data).map(t => (t._1, t._2, t._3, t._4, t._6))
        +    val table = stream.toTable(tEnv, 'long, 'int, 'double, 'float, 'string)
        +
        +    val countFun = new CountAggFunction
        +
        +    val weightAvgFun = new WeightedAvg
        +
        +    val windowedTable = table
        +      .window(Slide over 4.rows every 2.rows as 'w)
        +      .groupBy('w, 'string)
        +      .select(
        +        'string,
        +        countFun('float),
        +        'double.sum,
        +        weightAvgFun('long, 'int),
        +        weightAvgFun('int, 'int))
        +
        +    val results = windowedTable.toDataStream[Row]
        +    results.addSink(new StreamITCase.StringSink)
        +    env.execute()
        +
        +    val expected = Seq(
        +      "Hello,2,3.0,1666,1", "Hi,2,11.0,5545,5", "Hello,4,14.0,5571,5",
        +      "Hello,4,24.0,7083,7", "Hi,4,28.0,7500,7", "Hi,4,40.0,10350,10")
        +    assertEquals(expected, StreamITCase.testResults)
        +  }
        +
        +  @Test
        +  def testUdaggSessionWindowGroupedAggregate(): Unit = {
        +    val env = StreamExecutionEnvironment.getExecutionEnvironment
        +    env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
        +    env.setParallelism(1)
        +    val tEnv = TableEnvironment.getTableEnvironment(env)
        +    StreamITCase.testResults = mutable.MutableList()
        +
        +    val countFun = new CountAggFunction
        +
        +    val weightAvgWithMergeFun = new WeightedAvgWithMerge
        +
        +    val stream = env
        +      .fromCollection(data)
        +      .assignTimestampsAndWatermarks(new TimestampAndWatermarkWithOffset(10000))
        ```

        — End diff —

        use `.assignAscendingTimestamps(_._1)` instead?

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114448779

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala —

        ```
        @@ -274,27 +278,73 @@ class CodeGenerator(
              constantFlags: Option[Array[(Int, Boolean)]],
              outputArity: Int,
              needRetract: Boolean,
        -     needMerge: Boolean)
        +     needMerge: Boolean,
        +     needReset: Boolean)
            : GeneratedAggregationsFunction = {

            // get unique function name
            val funcName = newName(name)
            // register UDAGGs
            val aggs = aggregates.map(a => generator.addReusableFunction(a))
            // get java types of accumulators
        -   val accTypes = aggregates.map { a =>
        -     a.getClass.getMethod("createAccumulator").getReturnType.getCanonicalName
        +   val accTypeClasses = aggregates.map { a =>
        +     val accType = TypeExtractor.createTypeInfo(a, classOf[AggregateFunction[_, _]], a.getClass, 1)
        +     accType.getTypeClass
        +   }
        +   val accTypes = accTypeClasses.map(_.getCanonicalName)

        -   // get java types of input fields
        -   val javaTypes = inputType.getFieldList
        +   // get java classes of input fields
        +   val javaClasses = inputType.getFieldList
              .map(f => FlinkTypeFactory.toTypeInfo(f.getType))
        -     .map(t => t.getTypeClass.getCanonicalName)
        +     .map(t => t.getTypeClass)
            // get parameter lists for aggregation functions
        -   val parameters = aggFields.map { inFields =>
        -     val fields = for (f <- inFields) yield s"(${javaTypes(f)}) input.getField($f)"
        +   val parameters = aggFields.map { inFields =>
        +     val fields = for (f <- inFields) yield
        +       s"(${javaClasses(f).getCanonicalName}) input.getField($f)"
              fields.mkString(", ")
            }
        +   val methodSignaturesList = aggFields.map {
        +     inFields => for (f <- inFields) yield javaClasses(f)
        +   }
        +
        +   // check and validate the needed methods
        +   aggregates.zipWithIndex.map {
        +     case (a, i) => {
        +       getUserDefinedMethod(a, "accumulate", Array(accTypeClasses(i)) ++ methodSignaturesList(i))
        +         .getOrElse(
        +           throw new CodeGenException(
        +             s"No matching accumulate method found for aggregate " +
        ```

        — End diff —

        add parameter classes to exception, e.g., `... for aggregation function <class name> and parameters <parameter classes>`
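
        For illustration, the suggested message could be built roughly like this. It is only a sketch: `a`, `i`, `accTypeClasses`, and `methodSignaturesList` refer to the values in the quoted diff, and the exact wording is an assumption.

        ```
        // Hypothetical sketch: include the aggregate's class name and the
        // expected parameter classes in the CodeGenException message.
        val expectedSignature = Array(accTypeClasses(i)) ++ methodSignaturesList(i)
        throw new CodeGenException(
          s"No matching accumulate method found for aggregate " +
          s"${a.getClass.getCanonicalName} and parameters " +
          s"${expectedSignature.map(_.getCanonicalName).mkString(", ")}.")
        ```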

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114446564

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala —

        ```
        @@ -911,7 +911,21 @@ class GroupedTable(
             */
            def select(fields: String): Table = {
              val fieldExprs = ExpressionParser.parseExpressionList(fields)
        -     select(fieldExprs: _*)
        +
        +     //get the correct expression for UDAGGFunctionCall
        +     val input: Seq[Expression] = fieldExprs.zipWithIndex.map {
        +       case (Call(name, args), idx) => {
        ```

        — End diff —

        This does not catch all cases. The agg function call could be nested in one or more expressions:
        ```
        select('x, weightAvgFun('long, 'int) * 2)
        ```
        So we need to recursively search each expression for UDAGG calls.

        It would also be good to make this a utility function in `ProjectionTranslator`.

        I think this needs to be added to `Table.select()` as well (non-grouped, non-windowed aggregation).
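
        For illustration, the recursive search could look roughly like the following sketch. It is hypothetical, not the PR's code: `FunctionCatalog.lookupFunction` and `UDAGGFunctionCall` are the names used in this discussion, and `children`/`makeCopy` are assumed to behave like Flink's `TreeNode` helpers.

        ```
        // Hypothetical sketch: recursively rewrite every expression tree, replacing
        // Call nodes that resolve to a UDAGG with the resolved UDAGGFunctionCall,
        // so nested calls such as weightAvgFun('long, 'int) * 2 are caught too.
        def replaceUDAGGs(expr: Expression, catalog: FunctionCatalog): Expression =
          expr match {
            case Call(name, args) =>
              val newArgs = args.map(a => replaceUDAGGs(a, catalog))
              catalog.lookupFunction(name, newArgs) match {
                case udagg: UDAGGFunctionCall => udagg  // resolved aggregate call
                case _ => Call(name, newArgs)           // ordinary scalar call
              }
            case e =>
              // rebuild other nodes with recursively rewritten children
              val newChildren = e.children.map(c => replaceUDAGGs(c, catalog))
              e.makeCopy(newChildren.map(_.asInstanceOf[AnyRef]).toArray)
          }
        ```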

        githubbot ASF GitHub Bot added a comment -

        Github user shaoxuan-wang commented on the issue:

        https://github.com/apache/flink/pull/3809

        Thanks for the review, @fhueske, and very good point regarding the approach to "replace UDAGGFunctionCall". I have addressed your comments. Please take a look.

        githubbot ASF GitHub Bot added a comment -

        Github user shaoxuan-wang commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114498337

        — Diff: flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/DStreamUDAGGITCase.scala —

        ```
        @@ -0,0 +1,309 @@
        +/*
        + * Licensed to the Apache Software Foundation (ASF) under one
        + * or more contributor license agreements. See the NOTICE file
        + * distributed with this work for additional information
        + * regarding copyright ownership. The ASF licenses this file
        + * to you under the Apache License, Version 2.0 (the
        + * "License"); you may not use this file except in compliance
        + * with the License. You may obtain a copy of the License at
        + *
        + * http://www.apache.org/licenses/LICENSE-2.0
        + *
        + * Unless required by applicable law or agreed to in writing, software
        + * distributed under the License is distributed on an "AS IS" BASIS,
        + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
        + * See the License for the specific language governing permissions and
        + * limitations under the License.
        + */
        +
        +package org.apache.flink.table.api.scala.stream.table
        +
        +import java.math.BigDecimal
        +
        +import org.apache.flink.api.java.typeutils.RowTypeInfo
        +import org.apache.flink.api.scala._
        +import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks
        +import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
        +import org.apache.flink.streaming.api.TimeCharacteristic
        +import org.apache.flink.table.utils.TableTestBase
        +import org.apache.flink.streaming.api.datastream.{DataStream => JDataStream}
        +import org.apache.flink.streaming.api.environment.{StreamExecutionEnvironment => JavaExecutionEnv}
        +import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment => ScalaExecutionEnv}
        +import org.apache.flink.streaming.api.watermark.Watermark
        +import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.{WeightedAvg, WeightedAvgWithMerge, WeightedAvgWithRetract}
        +import org.apache.flink.table.api.scala._
        +import org.apache.flink.table.api.scala.stream.table.DStreamUDAGGITCase.TimestampAndWatermarkWithOffset
        +import org.apache.flink.table.api.scala.stream.utils.StreamITCase
        +import org.apache.flink.table.api.{SlidingWindow, TableEnvironment, Types}
        +import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase
        +import org.apache.flink.table.functions.aggfunctions.CountAggFunction
        +import org.apache.flink.types.Row
        +import org.junit.Assert._
        +import org.junit.Test
        +import org.mockito.Mockito.{mock, when}
        +
        +import scala.collection.mutable
        +
        +/**
        + * We only test some aggregations until better testing of constructed DataStream
        + * programs is possible.
        + */
        +class DStreamUDAGGITCase
        +  extends StreamingMultipleProgramsTestBase {
        +
        +  val data = List(
        +    //('long, 'int, 'double, 'float, 'bigdec, 'string)
        +    (1000L, 1, 1d, 1f, new BigDecimal("1"), "Hello"),
        +    (2000L, 2, 2d, 2f, new BigDecimal("2"), "Hello"),
        +    (3000L, 3, 3d, 3f, new BigDecimal("3"), "Hello"),
        +    (5000L, 5, 5d, 5f, new BigDecimal("5"), "Hi"),
        +    (6000L, 6, 6d, 6f, new BigDecimal("6"), "Hi"),
        +    (7000L, 7, 7d, 7f, new BigDecimal("7"), "Hi"),
        +    (8000L, 8, 8d, 8f, new BigDecimal("8"), "Hello"),
        +    (9000L, 9, 9d, 9f, new BigDecimal("9"), "Hello"),
        +    (4000L, 4, 4d, 4f, new BigDecimal("4"), "Hello"),
        +    (10000L, 10, 10d, 10f, new BigDecimal("10"), "Hi"),
        +    (11000L, 11, 11d, 11f, new BigDecimal("11"), "Hi"),
        +    (12000L, 12, 12d, 12f, new BigDecimal("12"), "Hi"),
        +    (16000L, 16, 16d, 16f, new BigDecimal("16"), "Hello"))
        +
        +  @Test
        +  def testUdaggSlidingWindowGroupedAggregate(): Unit = {
        +    val env = StreamExecutionEnvironment.getExecutionEnvironment
        +    env.setParallelism(1)
        +    val tEnv = TableEnvironment.getTableEnvironment(env)
        +    StreamITCase.testResults = mutable.MutableList()
        +
        +    val stream = env.fromCollection(data).map(t => (t._1, t._2, t._3, t._4, t._6))
        +    val table = stream.toTable(tEnv, 'long, 'int, 'double, 'float, 'string)
        +
        +    val countFun = new CountAggFunction
        +
        +    val weightAvgFun = new WeightedAvg
        +
        +    val windowedTable = table
        +      .window(Slide over 4.rows every 2.rows as 'w)
        +      .groupBy('w, 'string)
        +      .select(
        +        'string,
        +        countFun('float),
        +        'double.sum,
        +        weightAvgFun('long, 'int),
        +        weightAvgFun('int, 'int))
        +
        +    val results = windowedTable.toDataStream[Row]
        +    results.addSink(new StreamITCase.StringSink)
        +    env.execute()
        +
        +    val expected = Seq(
        +      "Hello,2,3.0,1666,1", "Hi,2,11.0,5545,5", "Hello,4,14.0,5571,5",
        +      "Hello,4,24.0,7083,7", "Hi,4,28.0,7500,7", "Hi,4,40.0,10350,10")
        +    assertEquals(expected, StreamITCase.testResults)
        +  }
        +
        +  @Test
        +  def testUdaggSessionWindowGroupedAggregate(): Unit = {
        +    val env = StreamExecutionEnvironment.getExecutionEnvironment
        +    env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
        +    env.setParallelism(1)
        +    val tEnv = TableEnvironment.getTableEnvironment(env)
        +    StreamITCase.testResults = mutable.MutableList()
        +
        +    val countFun = new CountAggFunction
        +
        +    val weightAvgWithMergeFun = new WeightedAvgWithMerge
        +
        +    val stream = env
        +      .fromCollection(data)
        +      .assignTimestampsAndWatermarks(new TimestampAndWatermarkWithOffset(10000))
        ```

        — End diff —

The reason I used `TimestampAndWatermarkWithOffset(10000)` is that I want to give this test a watermark that lets the late-arriving data merge two separate windows.
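For readers following the test, a minimal sketch of what such an offset assigner could look like is shown below. This is an illustration only: the real `TimestampAndWatermarkWithOffset` lives in the companion object of `DStreamUDAGGITCase`, and the tuple type and timestamp field here are assumptions taken from the quoted test data.

import java.math.BigDecimal
import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks
import org.apache.flink.streaming.api.watermark.Watermark

// After every element, emit a watermark that trails the element's timestamp by
// `offset` milliseconds. Rows arriving out of order within that offset are not
// late with respect to the watermark, so they can still bridge (merge) two
// session windows that were opened separately.
class TimestampAndWatermarkWithOffset(offset: Long)
  extends AssignerWithPunctuatedWatermarks[(Long, Int, Double, Float, BigDecimal, String)] {

  override def checkAndGetNextWatermark(
      lastElement: (Long, Int, Double, Float, BigDecimal, String),
      extractedTimestamp: Long): Watermark =
    new Watermark(extractedTimestamp - offset)

  override def extractTimestamp(
      element: (Long, Int, Double, Float, BigDecimal, String),
      previousElementTimestamp: Long): Long =
    element._1 // the first tuple field carries the event time in milliseconds
}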

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114504264

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala —
@@ -130,3 +142,63 @@ case class Avg(child: Expression) extends Aggregation {
    new SqlAvgAggFunction(AVG)
  }
}
        +
        +case class UDAGGFunctionCall(
        + aggregateFunction: AggregateFunction[_, _],
        + args: Seq[Expression])
        + extends Aggregation {
        +
        + override private[flink] def children: Seq[Expression] = args
        +
        + // Override makeCopy method in TreeNode, to produce vargars properly
        + override def makeCopy(args: Array[AnyRef]): this.type = {
        — End diff –

        Do we need this? `ScalarFunctionCall` has a `Seq[Expression]` parameter as well and does not override `makeCopy`

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114502444

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala —
@@ -130,3 +142,63 @@ case class Avg(child: Expression) extends Aggregation {
    new SqlAvgAggFunction(AVG)
  }
}
        +
        +case class UDAGGFunctionCall(
        + aggregateFunction: AggregateFunction[_, _],
        + args: Seq[Expression])
        + extends Aggregation {
        +
        + override private[flink] def children: Seq[Expression] = args
        +
        + // Override makeCopy method in TreeNode, to produce vargars properly
        — End diff –

        `vargars` -> `varargs`

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114506809

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala —
        @@ -129,6 +129,8 @@ class Table(
        */
        def select(fields: String): Table = {
        val fieldExprs = ExpressionParser.parseExpressionList(fields)
        + //get the correct expression for UDAGGFunctionCall
        + val input = fieldExprs.map(replaceUDAGGFunctionCall(_, tableEnv))
        — End diff –

        rename `input` to `withResolvedUdaggCalls`?

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114520070

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala —
        @@ -67,52 +68,89 @@ object UserDefinedFunctionUtils {
        }

 // ----------------------------------------------------------------------------------------------
-// Utilities for eval methods
+// Utilities for user-defined methods
 // ----------------------------------------------------------------------------------------------

 /**
- * Returns signatures matching the given signature of [[TypeInformation]].
+ * Returns signatures of eval methods matching the given signature of [[TypeInformation]].
  * Elements of the signature can be null (act as a wildcard).
  */
- def getSignature(
-     function: UserDefinedFunction,
-     signature: Seq[TypeInformation[_]])
+ def getEvalMethodSignature(
+   function: UserDefinedFunction,
— End diff –

        indent

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114505152

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala —
@@ -130,3 +142,63 @@ case class Avg(child: Expression) extends Aggregation {
    new SqlAvgAggFunction(AVG)
  }
}
+
+case class UDAGGFunctionCall(
+ aggregateFunction: AggregateFunction[_, _],
+ args: Seq[Expression])
+ extends Aggregation {
+
+ override private[flink] def children: Seq[Expression] = args
+
+ // Override makeCopy method in TreeNode, to produce vargars properly
+ override def makeCopy(args: Array[AnyRef]): this.type = {
+ if (args.length < 1) {
+ throw new TableException("Invalid constructor params")
+ }
+ val agg = args.head.asInstanceOf[AggregateFunction[_, _]]
+ val arg = args.last.asInstanceOf[Seq[Expression]]
+ new UDAGGFunctionCall(agg, arg).asInstanceOf[this.type]
+ }
+
+ override def resultType: TypeInformation[_] = TypeExtractor.createTypeInfo(
+ aggregateFunction, classOf[AggregateFunction[_, _]], aggregateFunction.getClass, 0)
+
+ override def validateInput(): ValidationResult = {
+ val signature = children.map(_.resultType)
+ // look for a signature that matches the input types
+ val foundSignature = getAccumulateMethodSignature(aggregateFunction, signature)
+ if (foundSignature.isEmpty) {
+ ValidationFailure(s"Given parameters do not match any signature. \n" +
+ s"Actual: ${signatureToString(signature)} \n" +
+ s"Expected: ${signaturesToString(aggregateFunction, "accumulate")}")
+ } else {
+ ValidationSuccess
+ }
+ }
+
+ override def toString(): String = s"${aggregateFunction.getClass.getSimpleName}($args)"
+
+ override def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
+ val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
+ val sqlFunction = AggSqlFunction(name, aggregateFunction, resultType, typeFactory)
+ relBuilder.aggregateCall(sqlFunction, false, null, name, args.map(_.toRexNode): _*)
+ }
+
+ override private[flink] def getSqlAggFunction()(implicit relBuilder: RelBuilder) = {
+ val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
+ AggSqlFunction("UDAGG", // tableAPI parser does not really use this
+ aggregateFunction,
+ resultType,
+ typeFactory)
+ }
+
+ override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
+ val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
+ relBuilder.call(
+ AggSqlFunction("UDAGG", // tableAPI parser does not really use this
        — End diff –

        Call `getSqlAggFunction` here?
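For illustration, the reuse suggested here might look roughly like the following sketch (assuming the `"UDAGG"` placeholder name is acceptable, since, as the comment in the diff notes, the Table API parser does not really use it):

// Sketch of the suggested reuse; not from the PR.
override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode =
  relBuilder.call(getSqlAggFunction(), args.map(_.toRexNode): _*)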

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114505101

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala —
@@ -130,3 +142,63 @@ case class Avg(child: Expression) extends Aggregation {
    new SqlAvgAggFunction(AVG)
  }
}
+
+case class UDAGGFunctionCall(
+ aggregateFunction: AggregateFunction[_, _],
+ args: Seq[Expression])
+ extends Aggregation {
+
+ override private[flink] def children: Seq[Expression] = args
+
+ // Override makeCopy method in TreeNode, to produce vargars properly
+ override def makeCopy(args: Array[AnyRef]): this.type = {
+ if (args.length < 1) {
+ throw new TableException("Invalid constructor params")
+ }
+ val agg = args.head.asInstanceOf[AggregateFunction[_, _]]
+ val arg = args.last.asInstanceOf[Seq[Expression]]
+ new UDAGGFunctionCall(agg, arg).asInstanceOf[this.type]
+ }
+
+ override def resultType: TypeInformation[_] = TypeExtractor.createTypeInfo(
+ aggregateFunction, classOf[AggregateFunction[_, _]], aggregateFunction.getClass, 0)
+
+ override def validateInput(): ValidationResult = {
+ val signature = children.map(_.resultType)
+ // look for a signature that matches the input types
+ val foundSignature = getAccumulateMethodSignature(aggregateFunction, signature)
+ if (foundSignature.isEmpty) {
+ ValidationFailure(s"Given parameters do not match any signature. \n" +
+ s"Actual: ${signatureToString(signature)} \n" +
+ s"Expected: ${signaturesToString(aggregateFunction, "accumulate")}")
+ } else {
+ ValidationSuccess
+ }
+ }
+
+ override def toString(): String = s"${aggregateFunction.getClass.getSimpleName}($args)"
+
+ override def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
+ val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
+ val sqlFunction = AggSqlFunction(name, aggregateFunction, resultType, typeFactory)
        — End diff –

        Call `getSqlAggFunction` here?
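For `toAggCall` the analogous sketch would be the following; note it is only a drop-in replacement if passing the user-provided `name` to `AggSqlFunction` is not required (otherwise `getSqlAggFunction` would need to accept the name as a parameter):

// Sketch of the suggested reuse; not from the PR.
override def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall =
  relBuilder.aggregateCall(getSqlAggFunction(), false, null, name, args.map(_.toRexNode): _*)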

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114506690

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala —
        @@ -129,6 +129,8 @@ class Table(
        */
        def select(fields: String): Table = {
        val fieldExprs = ExpressionParser.parseExpressionList(fields)
        + //get the correct expression for UDAGGFunctionCall
        + val input = fieldExprs.map(replaceUDAGGFunctionCall(_, tableEnv))
        select(fieldExprs: _*)
        — End diff –

        `fieldExprs` -> `input`

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114512453

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala —
        @@ -0,0 +1,177 @@
        +/*
        + * Licensed to the Apache Software Foundation (ASF) under one
        + * or more contributor license agreements. See the NOTICE file
        + * distributed with this work for additional information
        + * regarding copyright ownership. The ASF licenses this file
        + * to you under the Apache License, Version 2.0 (the
        + * "License"); you may not use this file except in compliance
        + * with the License. You may obtain a copy of the License at
        + *
        + * http://www.apache.org/licenses/LICENSE-2.0
        + *
        + * Unless required by applicable law or agreed to in writing, software
        + * distributed under the License is distributed on an "AS IS" BASIS,
        + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
        + * See the License for the specific language governing permissions and
        + * limitations under the License.
        + */
        +
        +package org.apache.flink.table.functions.utils
        +
        +import org.apache.calcite.rel.`type`.RelDataType
        +import org.apache.calcite.sql._
        +import org.apache.calcite.sql.`type`._
        +import org.apache.calcite.sql.`type`.SqlOperandTypeChecker.Consistency
        +import org.apache.calcite.sql.parser.SqlParserPos
        +import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction
        +import org.apache.flink.api.common.typeinfo._
        +import org.apache.flink.table.api.ValidationException
        +import org.apache.flink.table.calcite.FlinkTypeFactory
        +import org.apache.flink.table.functions.AggregateFunction
+import org.apache.flink.table.functions.utils.AggSqlFunction.{createOperandTypeChecker, createOperandTypeInference, createReturnTypeInference}

        +import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
        +
        +/**
        + * Calcite wrapper for user-defined aggregate functions.
        + *
        + * @param name function name (used by SQL parser)
        + * @param aggregateFunction aggregate function to be called
        + * @param returnType the type information of returned value
        + * @param typeFactory type factory for converting Flink's between Calcite's types
        + */
        +class AggSqlFunction(
        + name: String,
        + aggregateFunction: AggregateFunction[_, _],
        + returnType: TypeInformation[_],
        + typeFactory: FlinkTypeFactory)
        + extends SqlUserDefinedAggFunction(
        + new SqlIdentifier(name, SqlParserPos.ZERO),
        + createReturnTypeInference(returnType, typeFactory),
        + createOperandTypeInference(aggregateFunction, typeFactory),
        + createOperandTypeChecker(aggregateFunction),
        + // Do not need to provide a calcite aggregateFunction here. Flink aggregateion function
        + // will be generated when translating the calcite relnode to flink runtime execution plan
+ null
+ ) {
+
+ def getFunction: AggregateFunction[_, _] = aggregateFunction
+}

        +
        +object AggSqlFunction {
        +
        + def apply(
        + name: String,
        + aggregateFunction: AggregateFunction[_, _],
        + returnType: TypeInformation[_],
+ typeFactory: FlinkTypeFactory): AggSqlFunction = {
+
+ new AggSqlFunction(name, aggregateFunction, returnType, typeFactory)
+ }

        +
        + private[flink] def createOperandTypeInference(
        + aggregateFunction: AggregateFunction[_, _],
        + typeFactory: FlinkTypeFactory)
        + : SqlOperandTypeInference = {
        + /**
        + * Operand type inference based on [[AggregateFunction]] given information.
        + */
        + new SqlOperandTypeInference {
        + override def inferOperandTypes(
        + callBinding: SqlCallBinding,
        + returnType: RelDataType,
        + operandTypes: Array[RelDataType]): Unit = {
        +
        + val operandTypeInfo = getOperandTypeInfo(callBinding)
        +
        + val foundSignature = getAccumulateMethodSignature(aggregateFunction, operandTypeInfo)
        + .getOrElse(throw new ValidationException(s"Operand types of could not be inferred."))
        +
        + val inferredTypes = getParameterTypes(aggregateFunction, foundSignature.drop(1))
        + .map(typeFactory.createTypeFromTypeInfo)
        +
        + for (i <- operandTypes.indices) {
+ if (i < inferredTypes.length - 1) {
+ operandTypes(i) = inferredTypes(i)
+ } else if (null != inferredTypes.last.getComponentType) {
        — End diff –

        Can you explain the case of the array as last parameter? Is this to circumvent the Java parameter limit? This case is not reached by any tests.
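To make the branch concrete, here is a hypothetical UDAGG (all names invented, not from the PR) whose accumulate method takes an array as its last parameter; a SQL call such as `concatAgg(a, b, c)` would then map each trailing operand to the array's component type:

import org.apache.flink.table.functions.AggregateFunction
import scala.collection.mutable

// Hypothetical aggregate with a varargs-style accumulate signature.
class ConcatAgg extends AggregateFunction[String, mutable.ArrayBuffer[String]] {
  override def createAccumulator(): mutable.ArrayBuffer[String] = mutable.ArrayBuffer[String]()
  override def getValue(acc: mutable.ArrayBuffer[String]): String = acc.mkString(",")
  def accumulate(acc: mutable.ArrayBuffer[String], strs: Array[String]): Unit = {
    acc ++= strs // each extra SQL operand arrives as one element of the array
  }
}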

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114514490

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala —
        @@ -0,0 +1,177 @@
        +/*
        + * Licensed to the Apache Software Foundation (ASF) under one
        + * or more contributor license agreements. See the NOTICE file
        + * distributed with this work for additional information
        + * regarding copyright ownership. The ASF licenses this file
        + * to you under the Apache License, Version 2.0 (the
        + * "License"); you may not use this file except in compliance
        + * with the License. You may obtain a copy of the License at
        + *
        + * http://www.apache.org/licenses/LICENSE-2.0
        + *
        + * Unless required by applicable law or agreed to in writing, software
        + * distributed under the License is distributed on an "AS IS" BASIS,
        + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
        + * See the License for the specific language governing permissions and
        + * limitations under the License.
        + */
        +
        +package org.apache.flink.table.functions.utils
        +
        +import org.apache.calcite.rel.`type`.RelDataType
        +import org.apache.calcite.sql._
        +import org.apache.calcite.sql.`type`._
        +import org.apache.calcite.sql.`type`.SqlOperandTypeChecker.Consistency
        +import org.apache.calcite.sql.parser.SqlParserPos
        +import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction
        +import org.apache.flink.api.common.typeinfo._
        +import org.apache.flink.table.api.ValidationException
        +import org.apache.flink.table.calcite.FlinkTypeFactory
        +import org.apache.flink.table.functions.AggregateFunction
+import org.apache.flink.table.functions.utils.AggSqlFunction.{createOperandTypeChecker, createOperandTypeInference, createReturnTypeInference}

        +import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
        +
        +/**
        + * Calcite wrapper for user-defined aggregate functions.
        + *
        + * @param name function name (used by SQL parser)
        + * @param aggregateFunction aggregate function to be called
        + * @param returnType the type information of returned value
        + * @param typeFactory type factory for converting Flink's between Calcite's types
        + */
        +class AggSqlFunction(
        + name: String,
        + aggregateFunction: AggregateFunction[_, _],
        + returnType: TypeInformation[_],
        + typeFactory: FlinkTypeFactory)
        + extends SqlUserDefinedAggFunction(
        + new SqlIdentifier(name, SqlParserPos.ZERO),
        + createReturnTypeInference(returnType, typeFactory),
        + createOperandTypeInference(aggregateFunction, typeFactory),
        + createOperandTypeChecker(aggregateFunction),
        + // Do not need to provide a calcite aggregateFunction here. Flink aggregateion function
        + // will be generated when translating the calcite relnode to flink runtime execution plan
+ null
+ ) {
+
+ def getFunction: AggregateFunction[_, _] = aggregateFunction
+}

        +
        +object AggSqlFunction {
        +
        + def apply(
        + name: String,
        + aggregateFunction: AggregateFunction[_, _],
        + returnType: TypeInformation[_],
+ typeFactory: FlinkTypeFactory): AggSqlFunction = {
+
+ new AggSqlFunction(name, aggregateFunction, returnType, typeFactory)
+ }

        +
        + private[flink] def createOperandTypeInference(
        + aggregateFunction: AggregateFunction[_, _],
        + typeFactory: FlinkTypeFactory)
        + : SqlOperandTypeInference = {
        + /**
        + * Operand type inference based on [[AggregateFunction]] given information.
        + */
        + new SqlOperandTypeInference {
        + override def inferOperandTypes(
        + callBinding: SqlCallBinding,
        + returnType: RelDataType,
        + operandTypes: Array[RelDataType]): Unit = {
        +
        + val operandTypeInfo = getOperandTypeInfo(callBinding)
        +
        + val foundSignature = getAccumulateMethodSignature(aggregateFunction, operandTypeInfo)
        + .getOrElse(throw new ValidationException(s"Operand types of could not be inferred."))
        +
        + val inferredTypes = getParameterTypes(aggregateFunction, foundSignature.drop(1))
        + .map(typeFactory.createTypeFromTypeInfo)
        +
        + for (i <- operandTypes.indices) {
+ if (i < inferredTypes.length - 1) {
+ operandTypes(i) = inferredTypes(i)
+ } else if (null != inferredTypes.last.getComponentType) {
+ // last argument is a collection, the array type
+ operandTypes(i) = inferredTypes.last.getComponentType
+ } else {
+ operandTypes(i) = inferredTypes.last
+ }
+ }
+ }
+ }
+ }
        +
        + private[flink] def createReturnTypeInference(
        + resultType: TypeInformation[_],
        + typeFactory: FlinkTypeFactory)
        + : SqlReturnTypeInference = {
        +
        + new SqlReturnTypeInference {
+ override def inferReturnType(opBinding: SqlOperatorBinding): RelDataType = {
+ typeFactory.createTypeFromTypeInfo(resultType)
+ }
+ }
        + }
        +
        + private[flink] def createOperandTypeChecker(aggregateFunction: AggregateFunction[_, _])
        + : SqlOperandTypeChecker = {
        +
        + val signatures = getMethodSignatures(aggregateFunction, "accumulate")
        +
        + /**
        + * Operand type checker based on [[AggregateFunction]] given information.
        + */
        + new SqlOperandTypeChecker {
        + override def getAllowedSignatures(op: SqlOperator, opName: String): String = {
        + s"$opName[$

        {signaturesToString(aggregateFunction, "accumulate")}

        ]"
        + }
        +
        + override def getOperandCountRange: SqlOperandCountRange = {
        + var min = 255
        + var max = -1
        + signatures.foreach(
        + sig => {
        + val inputSig = sig.drop(1)
        + //do not count accumulator as input
        + var len = inputSig.length
        + if (len > 0 && inputSig(inputSig.length - 1).isArray) {
        + max = 254 // according to JVM spec 4.3.3
        — End diff –

What are we counting here? The number of parameters of the accumulate function, right?
So if we have `accumulate(acc: ACC, v: Int, a: Array[String])` this method would return `max = 254`, although the method has only two input parameters. Is that correct?

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114521448

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala —
        @@ -67,52 +68,89 @@ object UserDefinedFunctionUtils {
        }

 // ----------------------------------------------------------------------------------------------
-// Utilities for eval methods
+// Utilities for user-defined methods
 // ----------------------------------------------------------------------------------------------

 /**
- * Returns signatures matching the given signature of [[TypeInformation]].
+ * Returns signatures of eval methods matching the given signature of [[TypeInformation]].
  * Elements of the signature can be null (act as a wildcard).
  */
- def getSignature(
-     function: UserDefinedFunction,
-     signature: Seq[TypeInformation[_]])
+ def getEvalMethodSignature(
+   function: UserDefinedFunction,
+   signature: Seq[TypeInformation[_]])
  : Option[Array[Class[_]]] = {
-   getEvalMethod(function, signature).map(_.getParameterTypes)
+   getUserDefinedMethod(function, "eval", typeInfoToClass(signature)).map(_.getParameterTypes)
  }

 /**
- * Returns eval method matching the given signature of [[TypeInformation]].
+ * Returns signatures of accumulate methods matching the given signature of [[TypeInformation]].
+ * Elements of the signature can be null (act as a wildcard).
  */
- def getEvalMethod(
-     function: UserDefinedFunction,
+ def getAccumulateMethodSignature(
+   function: AggregateFunction[_, _],
    signature: Seq[TypeInformation[_]])
+  : Option[Array[Class[_]]] = {
+   val accType = TypeExtractor.createTypeInfo(
+     function, classOf[AggregateFunction[_, _]], function.getClass, 1)
+   val input = (Array(accType) ++ signature).toSeq
+   getUserDefinedMethod(
+     function,
+     "accumulate",
+     typeInfoToClass(input)).map(_.getParameterTypes)
+ }
+
+ def getParameterTypes(
+   function: UserDefinedFunction,
+   signature: Array[Class[_]]): Array[TypeInformation[_]] = {
+   signature.map { c =>
+     try {
+       TypeExtractor.getForClass(c)
+     } catch {
+       case ite ...
+     }
+   }
+ }
+
+ /**
+  * Returns user defined method matching the given name and signature.
+  *
+  * @param function function instance
+  * @param methodName method name
+  * @param methodSignature an array of raw Java classes. We compare the raw Java classes not the
+  *                        TypeInformation. TypeInformation does not matter during runtime (e.g.
+  *                        within a MapFunction)
+  */
+ def getUserDefinedMethod(
+   function: UserDefinedFunction,
+   methodName: String,
+   methodSignature: Array[Class[_]])
  : Option[Method] = {

- // We compare the raw Java classes not the TypeInformation.
- // TypeInformation does not matter during runtime (e.g. within a MapFunction).
- val actualSignature = typeInfoToClass(signature)
- val evalMethods = checkAndExtractEvalMethods(function)
- val filtered = evalMethods
- // go over all eval methods and filter out matching methods
+ val methods = checkAndExtractMethods(function, methodName)
+
+ val filtered = methods
+ // go over all the methods and filter out matching methods
  .filter {
    case cur if !cur.isVarArgs =>
      val signatures = cur.getParameterTypes
      // match parameters of signature to actual parameters
-     actualSignature.length == signatures.length &&
+     methodSignature.length == signatures.length &&
      signatures.zipWithIndex.forall { case (clazz, i) =>
— End diff –

can be simplified to `signatures.zip(methodSignature).forall { case (s, m) => parameterTypeEquals(m, s) }`
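A small self-contained check of that equivalence (`parameterTypeEquals` below is a stand-in, since its real definition in `UserDefinedFunctionUtils` is not part of the quoted diff, and the body of the indexed form is cut off in the diff, so its shape is presumed):

// Stand-in predicate, assumed semantics only.
def parameterTypeEquals(candidate: Class[_], expected: Class[_]): Boolean =
  candidate == null || expected.isAssignableFrom(candidate)

val signatures: Array[Class[_]] = Array(classOf[java.lang.Long], classOf[java.lang.Integer])
val methodSignature: Array[Class[_]] = Array(classOf[java.lang.Long], classOf[java.lang.Integer])

// Indexed form (presumed shape of the PR's version):
val a = signatures.zipWithIndex.forall { case (clazz, i) => parameterTypeEquals(methodSignature(i), clazz) }
// Simplified form, as suggested:
val b = signatures.zip(methodSignature).forall { case (s, m) => parameterTypeEquals(m, s) }
assert(a == b)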

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114514562

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala —
        @@ -0,0 +1,177 @@
        +/*
        + * Licensed to the Apache Software Foundation (ASF) under one
        + * or more contributor license agreements. See the NOTICE file
        + * distributed with this work for additional information
        + * regarding copyright ownership. The ASF licenses this file
        + * to you under the Apache License, Version 2.0 (the
        + * "License"); you may not use this file except in compliance
        + * with the License. You may obtain a copy of the License at
        + *
        + * http://www.apache.org/licenses/LICENSE-2.0
        + *
        + * Unless required by applicable law or agreed to in writing, software
        + * distributed under the License is distributed on an "AS IS" BASIS,
        + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
        + * See the License for the specific language governing permissions and
        + * limitations under the License.
        + */
        +
        +package org.apache.flink.table.functions.utils
        +
        +import org.apache.calcite.rel.`type`.RelDataType
        +import org.apache.calcite.sql._
        +import org.apache.calcite.sql.`type`._
        +import org.apache.calcite.sql.`type`.SqlOperandTypeChecker.Consistency
        +import org.apache.calcite.sql.parser.SqlParserPos
        +import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction
        +import org.apache.flink.api.common.typeinfo._
        +import org.apache.flink.table.api.ValidationException
        +import org.apache.flink.table.calcite.FlinkTypeFactory
        +import org.apache.flink.table.functions.AggregateFunction
+import org.apache.flink.table.functions.utils.AggSqlFunction.{createOperandTypeChecker, createOperandTypeInference, createReturnTypeInference}

        +import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
        +
        +/**
        + * Calcite wrapper for user-defined aggregate functions.
        + *
        + * @param name function name (used by SQL parser)
        + * @param aggregateFunction aggregate function to be called
        + * @param returnType the type information of returned value
+ * @param typeFactory type factory for converting between Flink's and Calcite's types
        + */
        +class AggSqlFunction(
        + name: String,
        + aggregateFunction: AggregateFunction[_, _],
        + returnType: TypeInformation[_],
        + typeFactory: FlinkTypeFactory)
        + extends SqlUserDefinedAggFunction(
        + new SqlIdentifier(name, SqlParserPos.ZERO),
        + createReturnTypeInference(returnType, typeFactory),
        + createOperandTypeInference(aggregateFunction, typeFactory),
        + createOperandTypeChecker(aggregateFunction),
+ // Do not need to provide a calcite aggregateFunction here. Flink aggregation function
        + // will be generated when translating the calcite relnode to flink runtime execution plan
        + null
+ ) {
+
+ def getFunction: AggregateFunction[_, _] = aggregateFunction
+}

        +
        +object AggSqlFunction {
        +
        + def apply(
        + name: String,
        + aggregateFunction: AggregateFunction[_, _],
        + returnType: TypeInformation[_],
+ typeFactory: FlinkTypeFactory): AggSqlFunction = {
+
+ new AggSqlFunction(name, aggregateFunction, returnType, typeFactory)
+ }

        +
        + private[flink] def createOperandTypeInference(
        + aggregateFunction: AggregateFunction[_, _],
        + typeFactory: FlinkTypeFactory)
        + : SqlOperandTypeInference = {
        + /**
        + * Operand type inference based on [[AggregateFunction]] given information.
        + */
        + new SqlOperandTypeInference {
        + override def inferOperandTypes(
        + callBinding: SqlCallBinding,
        + returnType: RelDataType,
        + operandTypes: Array[RelDataType]): Unit = {
        +
        + val operandTypeInfo = getOperandTypeInfo(callBinding)
        +
        + val foundSignature = getAccumulateMethodSignature(aggregateFunction, operandTypeInfo)
+ .getOrElse(throw new ValidationException(s"Operand types could not be inferred."))
        +
        + val inferredTypes = getParameterTypes(aggregateFunction, foundSignature.drop(1))
        + .map(typeFactory.createTypeFromTypeInfo)
        +
        + for (i <- operandTypes.indices) {
+ if (i < inferredTypes.length - 1) {
+ operandTypes(i) = inferredTypes(i)
+ } else if (null != inferredTypes.last.getComponentType) {
+ // last argument is a collection, the array type
+ operandTypes(i) = inferredTypes.last.getComponentType
+ } else {
+ operandTypes(i) = inferredTypes.last
+ }
+ }
        + }
        + }
        + }
        +
        + private[flink] def createReturnTypeInference(
        + resultType: TypeInformation[_],
        + typeFactory: FlinkTypeFactory)
        + : SqlReturnTypeInference = {
        +
        + new SqlReturnTypeInference {
+ override def inferReturnType(opBinding: SqlOperatorBinding): RelDataType = {
+ typeFactory.createTypeFromTypeInfo(resultType)
+ }

        + }
        + }
        +
        + private[flink] def createOperandTypeChecker(aggregateFunction: AggregateFunction[_, _])
        + : SqlOperandTypeChecker = {
        +
        + val signatures = getMethodSignatures(aggregateFunction, "accumulate")
        +
        + /**
        + * Operand type checker based on [[AggregateFunction]] given information.
        + */
        + new SqlOperandTypeChecker {
        + override def getAllowedSignatures(op: SqlOperator, opName: String): String = {
+ s"$opName[${signaturesToString(aggregateFunction, "accumulate")}]"
        + }
        +
        + override def getOperandCountRange: SqlOperandCountRange = {
        + var min = 255
        + var max = -1
        + signatures.foreach(
        + sig => {
        + val inputSig = sig.drop(1)
        + //do not count accumulator as input
        + var len = inputSig.length
        + if (len > 0 && inputSig(inputSig.length - 1).isArray) {
        + max = 254 // according to JVM spec 4.3.3
        — End diff –

I think we should have a test case for the case where the last argument is an array.
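
A sketch of what such a test function could look like (all names hypothetical): with Scala's `@varargs` annotation, the generated bridge method for `accumulate` ends in an `Int` array, which is exactly the case the operand count range logic caps at 254 inputs per JVM spec 4.3.3.

import scala.annotation.varargs
import org.apache.flink.table.functions.AggregateFunction

class VarArgsAcc {
  var sum: Long = 0
}

// Hypothetical UDAGG whose accumulate is a varargs method; @varargs makes
// the last JVM parameter an Int array.
class VarArgsSum extends AggregateFunction[Long, VarArgsAcc] {

  override def createAccumulator(): VarArgsAcc = new VarArgsAcc

  override def getValue(acc: VarArgsAcc): Long = acc.sum

  @varargs
  def accumulate(acc: VarArgsAcc, values: Int*): Unit =
    values.foreach(v => acc.sum += v)
}
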

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114514218

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala —
        @@ -0,0 +1,177 @@
        +/*
        + * Licensed to the Apache Software Foundation (ASF) under one
        + * or more contributor license agreements. See the NOTICE file
        + * distributed with this work for additional information
        + * regarding copyright ownership. The ASF licenses this file
        + * to you under the Apache License, Version 2.0 (the
        + * "License"); you may not use this file except in compliance
        + * with the License. You may obtain a copy of the License at
        + *
        + * http://www.apache.org/licenses/LICENSE-2.0
        + *
        + * Unless required by applicable law or agreed to in writing, software
        + * distributed under the License is distributed on an "AS IS" BASIS,
        + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
        + * See the License for the specific language governing permissions and
        + * limitations under the License.
        + */
        +
        +package org.apache.flink.table.functions.utils
        +
        +import org.apache.calcite.rel.`type`.RelDataType
        +import org.apache.calcite.sql._
        +import org.apache.calcite.sql.`type`._
        +import org.apache.calcite.sql.`type`.SqlOperandTypeChecker.Consistency
        +import org.apache.calcite.sql.parser.SqlParserPos
        +import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction
        +import org.apache.flink.api.common.typeinfo._
        +import org.apache.flink.table.api.ValidationException
        +import org.apache.flink.table.calcite.FlinkTypeFactory
        +import org.apache.flink.table.functions.AggregateFunction
+import org.apache.flink.table.functions.utils.AggSqlFunction.{createOperandTypeChecker, createOperandTypeInference, createReturnTypeInference}
        +import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
        +
        +/**
        + * Calcite wrapper for user-defined aggregate functions.
        + *
        + * @param name function name (used by SQL parser)
        + * @param aggregateFunction aggregate function to be called
        + * @param returnType the type information of returned value
+ * @param typeFactory type factory for converting between Flink's and Calcite's types
        + */
        +class AggSqlFunction(
        + name: String,
        + aggregateFunction: AggregateFunction[_, _],
        + returnType: TypeInformation[_],
        + typeFactory: FlinkTypeFactory)
        + extends SqlUserDefinedAggFunction(
        + new SqlIdentifier(name, SqlParserPos.ZERO),
        + createReturnTypeInference(returnType, typeFactory),
        + createOperandTypeInference(aggregateFunction, typeFactory),
        + createOperandTypeChecker(aggregateFunction),
+ // Do not need to provide a calcite aggregateFunction here. Flink aggregation function
        + // will be generated when translating the calcite relnode to flink runtime execution plan
        + null
+ ) {
+
+ def getFunction: AggregateFunction[_, _] = aggregateFunction
+}

        +
        +object AggSqlFunction {
        +
        + def apply(
        + name: String,
        + aggregateFunction: AggregateFunction[_, _],
        + returnType: TypeInformation[_],
+ typeFactory: FlinkTypeFactory): AggSqlFunction = {
+
+ new AggSqlFunction(name, aggregateFunction, returnType, typeFactory)
+ }

        +
        + private[flink] def createOperandTypeInference(
        + aggregateFunction: AggregateFunction[_, _],
        + typeFactory: FlinkTypeFactory)
        + : SqlOperandTypeInference = {
        + /**
        + * Operand type inference based on [[AggregateFunction]] given information.
        + */
        + new SqlOperandTypeInference {
        + override def inferOperandTypes(
        + callBinding: SqlCallBinding,
        + returnType: RelDataType,
        + operandTypes: Array[RelDataType]): Unit = {
        +
        + val operandTypeInfo = getOperandTypeInfo(callBinding)
        +
        + val foundSignature = getAccumulateMethodSignature(aggregateFunction, operandTypeInfo)
+ .getOrElse(throw new ValidationException(s"Operand types could not be inferred."))
        +
        + val inferredTypes = getParameterTypes(aggregateFunction, foundSignature.drop(1))
        + .map(typeFactory.createTypeFromTypeInfo)
        +
        + for (i <- operandTypes.indices) {
+ if (i < inferredTypes.length - 1) {
+ operandTypes(i) = inferredTypes(i)
+ } else if (null != inferredTypes.last.getComponentType) {
+ // last argument is a collection, the array type
+ operandTypes(i) = inferredTypes.last.getComponentType
+ } else {
+ operandTypes(i) = inferredTypes.last
+ }
+ }
        + }
        + }
        + }
        +
        + private[flink] def createReturnTypeInference(
        + resultType: TypeInformation[_],
        + typeFactory: FlinkTypeFactory)
        + : SqlReturnTypeInference = {
        +
        + new SqlReturnTypeInference {
+ override def inferReturnType(opBinding: SqlOperatorBinding): RelDataType = {
+ typeFactory.createTypeFromTypeInfo(resultType)
+ }

        + }
        + }
        +
        + private[flink] def createOperandTypeChecker(aggregateFunction: AggregateFunction[_, _])
        + : SqlOperandTypeChecker = {
        +
        + val signatures = getMethodSignatures(aggregateFunction, "accumulate")
        +
        + /**
        + * Operand type checker based on [[AggregateFunction]] given information.
        + */
        + new SqlOperandTypeChecker {
        + override def getAllowedSignatures(op: SqlOperator, opName: String): String = {
+ s"$opName[${signaturesToString(aggregateFunction, "accumulate")}]"
        + }
        +
        + override def getOperandCountRange: SqlOperandCountRange = {
        + var min = 255
        + var max = -1
        + signatures.foreach(
        + sig => {
        + val inputSig = sig.drop(1)
        + //do not count accumulator as input
        + var len = inputSig.length
        + if (len > 0 && inputSig(inputSig.length - 1).isArray) {
        — End diff –

        `inputSig.(inputSig.length - 1)` -> `inputSig.last`?

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114523374

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala —
        @@ -263,12 +327,12 @@ object UserDefinedFunctionUtils {
/**
 * Returns the return type of the evaluation method matching the given signature.
 */
- def getResultTypeClass(
+ def getResultTypeClassOfScalaFunction(
— End diff –

        `Scala` -> `Scalar`

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114523351

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala —
        @@ -236,12 +300,12 @@ object UserDefinedFunctionUtils {

/**
 * Internal method of [[ScalarFunction#getResultType()]] that does some pre-checking and uses
 * [[TypeExtractor]] as default return type inference.
 */
- def getResultType(
+ def getResultTypeOfScalaFunction(
— End diff –

        `Scala` -> `Scalar`

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114505720

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/call.scala —
        @@ -281,16 +281,19 @@ case class ScalarFunctionCall(
override def toString =
s"${scalarFunction.getClass.getCanonicalName}(${parameters.mkString(", ")})"

- override private[flink] def resultType = getResultType(scalarFunction, foundSignature.get)
+ override private[flink] def resultType =
+ getResultTypeOfScalaFunction(
— End diff –

        typo +r: `getResultTypeOfScalaFunction` -> `getResultTypeOfScalarFunction`

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114533693

        — Diff: flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/DSetUDAGGITCase.scala —
        @@ -0,0 +1,192 @@
        +/*
        + * Licensed to the Apache Software Foundation (ASF) under one
        + * or more contributor license agreements. See the NOTICE file
        + * distributed with this work for additional information
        + * regarding copyright ownership. The ASF licenses this file
        + * to you under the Apache License, Version 2.0 (the
        + * "License"); you may not use this file except in compliance
        + * with the License. You may obtain a copy of the License at
        + *
        + * http://www.apache.org/licenses/LICENSE-2.0
        + *
        + * Unless required by applicable law or agreed to in writing, software
        + * distributed under the License is distributed on an "AS IS" BASIS,
        + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
        + * See the License for the specific language governing permissions and
        + * limitations under the License.
        + */
        +
        +package org.apache.flink.table.api.scala.batch.table
        +
        +import java.math.BigDecimal
        +
+import org.apache.flink.api.java.{DataSet => JDataSet, ExecutionEnvironment => JavaExecutionEnv}
        +import org.apache.flink.api.java.typeutils.RowTypeInfo
        +import org.apache.flink.api.scala._
+import org.apache.flink.api.scala.{DataSet, ExecutionEnvironment => ScalaExecutionEnv}
+import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.{WeightedAvg, WeightedAvgWithMergeAndReset}
        +import org.apache.flink.table.api.scala.batch.utils.TableProgramsCollectionTestBase
        +import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
        +import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.{TableEnvironment, Types}
        +import org.apache.flink.table.functions.aggfunctions.CountAggFunction
        +import org.apache.flink.table.utils.TableTestBase
        +import org.apache.flink.test.util.TestBaseUtils
        +import org.apache.flink.types.Row
        +import org.junit._
        +import org.junit.runner.RunWith
        +import org.junit.runners.Parameterized
+import org.mockito.Mockito.{mock, when}
        +
        +import scala.collection.JavaConverters._
        +
        +/**
        + * We only test some aggregations until better testing of constructed DataSet
        + * programs is possible.
        + */
        +@RunWith(classOf[Parameterized])
        +class DSetUDAGGITCase(configMode: TableConfigMode)
        — End diff –

        Actually, I think we can just extend the existing aggregation tests and add a user-defined aggregation function to some. This way we can avoid many new tests and just adapt some of the existing ones.
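
For example, a sketch of how an existing batch aggregation test could simply mix in a UDAGG (assuming the `WeightedAvg` function from this PR's test utilities; data and column names are hypothetical):

import org.apache.flink.api.scala._
import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.api.scala._
import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.WeightedAvg

val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env)

// One UDAGG call added to an otherwise ordinary grouped aggregation.
val weightedAvg = new WeightedAvg
val result = env.fromElements((1L, 1, "Hello"), (2L, 2, "Hello"), (3L, 3, "Hi"))
  .toTable(tEnv, 'a, 'b, 'c)
  .groupBy('c)
  .select('c, 'b.sum, weightedAvg('a, 'b))
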

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114531389

        — Diff: flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/DSetUDAGGITCase.scala —
        @@ -0,0 +1,192 @@
        +/*
        + * Licensed to the Apache Software Foundation (ASF) under one
        + * or more contributor license agreements. See the NOTICE file
        + * distributed with this work for additional information
        + * regarding copyright ownership. The ASF licenses this file
        + * to you under the Apache License, Version 2.0 (the
        + * "License"); you may not use this file except in compliance
        + * with the License. You may obtain a copy of the License at
        + *
        + * http://www.apache.org/licenses/LICENSE-2.0
        + *
        + * Unless required by applicable law or agreed to in writing, software
        + * distributed under the License is distributed on an "AS IS" BASIS,
        + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
        + * See the License for the specific language governing permissions and
        + * limitations under the License.
        + */
        +
        +package org.apache.flink.table.api.scala.batch.table
        +
        +import java.math.BigDecimal
        +
+import org.apache.flink.api.java.{DataSet => JDataSet, ExecutionEnvironment => JavaExecutionEnv}
        +import org.apache.flink.api.java.typeutils.RowTypeInfo
        +import org.apache.flink.api.scala._
+import org.apache.flink.api.scala.{DataSet, ExecutionEnvironment => ScalaExecutionEnv}
+import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.{WeightedAvg, WeightedAvgWithMergeAndReset}
        +import org.apache.flink.table.api.scala.batch.utils.TableProgramsCollectionTestBase
        +import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
        +import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.{TableEnvironment, Types}
        +import org.apache.flink.table.functions.aggfunctions.CountAggFunction
        +import org.apache.flink.table.utils.TableTestBase
        +import org.apache.flink.test.util.TestBaseUtils
        +import org.apache.flink.types.Row
        +import org.junit._
        +import org.junit.runner.RunWith
        +import org.junit.runners.Parameterized
+import org.mockito.Mockito.{mock, when}
        +
        +import scala.collection.JavaConverters._
        +
        +/**
        + * We only test some aggregations until better testing of constructed DataSet
        + * programs is possible.
        + */
        +@RunWith(classOf[Parameterized])
        +class DSetUDAGGITCase(configMode: TableConfigMode)
        — End diff –

        It would be good to implement some unit tests that check the plan and extend `TableTestBase` similar to `UserDefinedTableFunctionTest`.
        Some of the tests should validate the equivalence of the Scala expression and String-based API.
        Others should validate the generated plan (for Scala Table API and SQL) and check that the validations are correct.
        The tests should cover all `select()` methods (`Table`, `GroupedTable`, `WindowGroupedTable`, `OverWindowedTable`).
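
A sketch of such a plan test; the helper names (`streamTestUtil`, `addTable`, `tableEnv`, `verifyTableEquals`) are assumed to match the existing `TableTestBase` utilities, and registering an `AggregateFunction` via `registerFunction` is the API this PR introduces:

import org.apache.flink.api.scala._
import org.apache.flink.table.api.scala._
import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.WeightedAvg
import org.apache.flink.table.utils.TableTestBase
import org.junit.Test

class UDAGGTableApiTest extends TableTestBase {

  @Test
  def testGroupedUdaggApiEquivalence(): Unit = {
    val util = streamTestUtil()
    val t = util.addTable[(Long, Int, String)]("MyTable", 'a, 'b, 'c)

    val weightedAvg = new WeightedAvg
    util.tableEnv.registerFunction("weightedAvg", weightedAvg)

    // The Scala expression API and the String-based API should translate
    // to the same logical plan.
    val resScala = t.groupBy('c).select('c, weightedAvg('a, 'b))
    val resString = t.groupBy("c").select("c, weightedAvg(a, b)")

    verifyTableEquals(resScala, resString)
  }
}
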

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114534286

        — Diff: flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/DStreamUDAGGITCase.scala —
        @@ -0,0 +1,318 @@
        +/*
        + * Licensed to the Apache Software Foundation (ASF) under one
        + * or more contributor license agreements. See the NOTICE file
        + * distributed with this work for additional information
        + * regarding copyright ownership. The ASF licenses this file
        + * to you under the Apache License, Version 2.0 (the
        + * "License"); you may not use this file except in compliance
        + * with the License. You may obtain a copy of the License at
        + *
        + * http://www.apache.org/licenses/LICENSE-2.0
        + *
        + * Unless required by applicable law or agreed to in writing, software
        + * distributed under the License is distributed on an "AS IS" BASIS,
        + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
        + * See the License for the specific language governing permissions and
        + * limitations under the License.
        + */
        +
        +package org.apache.flink.table.api.scala.stream.table
        +
        +import java.math.BigDecimal
        +
        +import org.apache.flink.api.java.typeutils.RowTypeInfo
        +import org.apache.flink.api.scala._
        +import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks
        +import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
        +import org.apache.flink.streaming.api.TimeCharacteristic
        +import org.apache.flink.table.utils.TableTestBase
+import org.apache.flink.streaming.api.datastream.{DataStream => JDataStream}
+import org.apache.flink.streaming.api.environment.{StreamExecutionEnvironment => JavaExecutionEnv}
+import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment => ScalaExecutionEnv}
        +import org.apache.flink.streaming.api.watermark.Watermark
        +import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase
+import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.{WeightedAvg, WeightedAvgWithMerge, WeightedAvgWithRetract}
        +import org.apache.flink.table.api.scala._
        +import org.apache.flink.table.api.scala.stream.table.DStreamUDAGGITCase.TimestampAndWatermarkWithOffset
        +import org.apache.flink.table.api.scala.stream.utils.StreamITCase
+import org.apache.flink.table.api.{SlidingWindow, TableEnvironment, Types}
        +import org.apache.flink.table.functions.aggfunctions.CountAggFunction
        +import org.apache.flink.types.Row
        +import org.junit.Assert._
        +import org.junit.Test
+import org.mockito.Mockito.{mock, when}
        +
        +import scala.collection.mutable
        +
        +/**
        + * We only test some aggregations until better testing of constructed DataStream
        + * programs is possible.
        + */
        +class DStreamUDAGGITCase
        — End diff –

        I think we could extend the existing tests for stream aggregates (group window, over window, non-windowed) and use some UDAGGs there (could be a built-in function which is called as a user-defined agg function).
        We should also add the plan tests as described for the batch aggregations.
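
For instance (a sketch; table and column names are hypothetical, and the UDAGG registration is the API introduced by this PR), the built-in `CountAggFunction` could be exercised through the user-defined path:

import org.apache.flink.api.scala._
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.api.scala._
import org.apache.flink.table.functions.aggfunctions.CountAggFunction

val env = StreamExecutionEnvironment.getExecutionEnvironment
val tEnv = TableEnvironment.getTableEnvironment(env)

val t = env.fromElements((1L, 1, "Hello"), (2L, 2, "Hi")).toTable(tEnv, 'a, 'b, 'c)
tEnv.registerTable("MyTable", t)

// The built-in count implementation, registered and invoked as a UDAGG.
tEnv.registerFunction("myCount", new CountAggFunction)
val result = tEnv.sql("SELECT c, myCount(b) FROM MyTable GROUP BY c")
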

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114534426

        — Diff: flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/DStreamUDAGGITCase.scala —
        @@ -0,0 +1,309 @@
        +/*
        + * Licensed to the Apache Software Foundation (ASF) under one
        + * or more contributor license agreements. See the NOTICE file
        + * distributed with this work for additional information
        + * regarding copyright ownership. The ASF licenses this file
        + * to you under the Apache License, Version 2.0 (the
        + * "License"); you may not use this file except in compliance
        + * with the License. You may obtain a copy of the License at
        + *
        + * http://www.apache.org/licenses/LICENSE-2.0
        + *
        + * Unless required by applicable law or agreed to in writing, software
        + * distributed under the License is distributed on an "AS IS" BASIS,
        + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
        + * See the License for the specific language governing permissions and
        + * limitations under the License.
        + */
        +
        +package org.apache.flink.table.api.scala.stream.table
        +
        +import java.math.BigDecimal
        +
        +import org.apache.flink.api.java.typeutils.RowTypeInfo
        +import org.apache.flink.api.scala._
        +import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks
        +import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
        +import org.apache.flink.streaming.api.TimeCharacteristic
        +import org.apache.flink.table.utils.TableTestBase
+import org.apache.flink.streaming.api.datastream.{DataStream => JDataStream}
+import org.apache.flink.streaming.api.environment.{StreamExecutionEnvironment => JavaExecutionEnv}
+import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment => ScalaExecutionEnv}
        +import org.apache.flink.streaming.api.watermark.Watermark
+import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.{WeightedAvg, WeightedAvgWithMerge, WeightedAvgWithRetract}
        +import org.apache.flink.table.api.scala._
        +import org.apache.flink.table.api.scala.stream.table.DStreamUDAGGITCase.TimestampAndWatermarkWithOffset
        +import org.apache.flink.table.api.scala.stream.utils.StreamITCase
+import org.apache.flink.table.api.{SlidingWindow, TableEnvironment, Types}
        +import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase
        +import org.apache.flink.table.functions.aggfunctions.CountAggFunction
        +import org.apache.flink.types.Row
        +import org.junit.Assert._
        +import org.junit.Test
+import org.mockito.Mockito.{mock, when}
        +
        +import scala.collection.mutable
        +
        +/**
        + * We only test some aggregations until better testing of constructed DataStream
        + * programs is possible.
        + */
        +class DStreamUDAGGITCase
        + extends StreamingMultipleProgramsTestBase {
        +
        + val data = List(
        + //('long, 'int, 'double, 'float, 'bigdec, 'string)
        + (1000L, 1, 1d, 1f, new BigDecimal("1"), "Hello"),
        + (2000L, 2, 2d, 2f, new BigDecimal("2"), "Hello"),
        + (3000L, 3, 3d, 3f, new BigDecimal("3"), "Hello"),
        + (5000L, 5, 5d, 5f, new BigDecimal("5"), "Hi"),
        + (6000L, 6, 6d, 6f, new BigDecimal("6"), "Hi"),
        + (7000L, 7, 7d, 7f, new BigDecimal("7"), "Hi"),
        + (8000L, 8, 8d, 8f, new BigDecimal("8"), "Hello"),
        + (9000L, 9, 9d, 9f, new BigDecimal("9"), "Hello"),
        + (4000L, 4, 4d, 4f, new BigDecimal("4"), "Hello"),
        + (10000L, 10, 10d, 10f, new BigDecimal("10"), "Hi"),
        + (11000L, 11, 11d, 11f, new BigDecimal("11"), "Hi"),
        + (12000L, 12, 12d, 12f, new BigDecimal("12"), "Hi"),
        + (16000L, 16, 16d, 16f, new BigDecimal("16"), "Hello"))
        +
        + @Test
+ def testUdaggSlidingWindowGroupedAggregate(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ env.setParallelism(1)
+ val tEnv = TableEnvironment.getTableEnvironment(env)
+ StreamITCase.testResults = mutable.MutableList()
+
+ val stream = env.fromCollection(data).map(t => (t._1, t._2, t._3, t._4, t._6))
+ val table = stream.toTable(tEnv, 'long, 'int, 'double, 'float, 'string)
+
+ val countFun = new CountAggFunction
+
+ val weightAvgFun = new WeightedAvg
+
+ val windowedTable = table
+ .window(Slide over 4.rows every 2.rows as 'w)
+ .groupBy('w, 'string)
+ .select(
+ 'string,
+ countFun('float),
+ 'double.sum,
+ weightAvgFun('long, 'int),
+ weightAvgFun('int, 'int))
+
+ val results = windowedTable.toDataStream[Row]
+ results.addSink(new StreamITCase.StringSink)
+ env.execute()
+
+ val expected = Seq(
+ "Hello,2,3.0,1666,1", "Hi,2,11.0,5545,5", "Hello,4,14.0,5571,5",
+ "Hello,4,24.0,7083,7", "Hi,4,28.0,7500,7", "Hi,4,40.0,10350,10")
+ assertEquals(expected, StreamITCase.testResults)
+ }
        +
        + @Test
        + def testUdaggSessionWindowGroupedAggregate(): Unit = {
        + val env = StreamExecutionEnvironment.getExecutionEnvironment
        + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
        + env.setParallelism(1)
        + val tEnv = TableEnvironment.getTableEnvironment(env)
        + StreamITCase.testResults = mutable.MutableList()
        +
        + val countFun = new CountAggFunction
        +
        + val weightAvgWithMergeFun = new WeightedAvgWithMerge
        +
        + val stream = env
        + .fromCollection(data)
        + .assignTimestampsAndWatermarks(new TimestampAndWatermarkWithOffset(10000))
        — End diff –

        I see, makes sense. Would be good to add a comment about the intention of the test, IMO.
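
For context, a minimal sketch of such an offset assigner (the actual `TimestampAndWatermarkWithOffset` implementation may differ): watermarks deliberately trail the element timestamps, so session windows are merged before the delayed watermark finally triggers them.

import java.math.BigDecimal
import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks
import org.apache.flink.streaming.api.watermark.Watermark

// Element type mirrors the test data: (long, int, double, float, bigdec, string).
class OffsetWatermarks(offset: Long)
  extends AssignerWithPunctuatedWatermarks[(Long, Int, Double, Float, BigDecimal, String)] {

  // Event time is taken from the first tuple field.
  override def extractTimestamp(
      element: (Long, Int, Double, Float, BigDecimal, String),
      previousElementTimestamp: Long): Long = element._1

  // Each watermark lags the extracted timestamp by `offset`, delaying
  // window evaluation and forcing session merges to happen first.
  override def checkAndGetNextWatermark(
      lastElement: (Long, Int, Double, Float, BigDecimal, String),
      extractedTimestamp: Long): Watermark = new Watermark(extractedTimestamp - offset)
}
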

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114523232

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala —
        @@ -220,14 +264,34 @@ object UserDefinedFunctionUtils {
        typeFactory: FlinkTypeFactory)
        : Seq[SqlFunction] = {
        val (fieldNames, fieldIndexes, _) = UserDefinedFunctionUtils.getFieldInfo(resultType)

- val evalMethods = checkAndExtractEvalMethods(tableFunction)
+ val evalMethods = checkAndExtractMethods(tableFunction, "eval")

  evalMethods.map { method =>
    val function = new FlinkTableFunctionImpl(resultType, fieldIndexes, fieldNames, method)
    TableSqlFunction(name, tableFunction, resultType, typeFactory, function)
  }
}

        + /**
        + * Create [[SqlFunction]]s for an [[AggregateFunction]]
        + *
        + * @param name function name
        + * @param aggFunction aggregate function
        + * @param resultType the type information of returned value
        + * @param typeFactory type factory
        + * @return the TableSqlFunction
        + */
        + def createAggregateSqlFunctions(
        — End diff –

        -s -> `createAggregateSqlFunction`

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114511003

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala —
        @@ -0,0 +1,177 @@
        +/*
        + * Licensed to the Apache Software Foundation (ASF) under one
        + * or more contributor license agreements. See the NOTICE file
        + * distributed with this work for additional information
        + * regarding copyright ownership. The ASF licenses this file
        + * to you under the Apache License, Version 2.0 (the
        + * "License"); you may not use this file except in compliance
        + * with the License. You may obtain a copy of the License at
        + *
        + * http://www.apache.org/licenses/LICENSE-2.0
        + *
        + * Unless required by applicable law or agreed to in writing, software
        + * distributed under the License is distributed on an "AS IS" BASIS,
        + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
        + * See the License for the specific language governing permissions and
        + * limitations under the License.
        + */
        +
        +package org.apache.flink.table.functions.utils
        +
        +import org.apache.calcite.rel.`type`.RelDataType
        +import org.apache.calcite.sql._
        +import org.apache.calcite.sql.`type`._
        +import org.apache.calcite.sql.`type`.SqlOperandTypeChecker.Consistency
        +import org.apache.calcite.sql.parser.SqlParserPos
        +import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction
        +import org.apache.flink.api.common.typeinfo._
        +import org.apache.flink.table.api.ValidationException
        +import org.apache.flink.table.calcite.FlinkTypeFactory
        +import org.apache.flink.table.functions.AggregateFunction
+import org.apache.flink.table.functions.utils.AggSqlFunction.{createOperandTypeChecker, createOperandTypeInference, createReturnTypeInference}

        +import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
        +
        +/**
        + * Calcite wrapper for user-defined aggregate functions.
        + *
        + * @param name function name (used by SQL parser)
        + * @param aggregateFunction aggregate function to be called
        + * @param returnType the type information of returned value
        + * @param typeFactory type factory for converting Flink's between Calcite's types
        + */
        +class AggSqlFunction(
        + name: String,
        + aggregateFunction: AggregateFunction[_, _],
        + returnType: TypeInformation[_],
        + typeFactory: FlinkTypeFactory)
        + extends SqlUserDefinedAggFunction(
        + new SqlIdentifier(name, SqlParserPos.ZERO),
        + createReturnTypeInference(returnType, typeFactory),
        + createOperandTypeInference(aggregateFunction, typeFactory),
        + createOperandTypeChecker(aggregateFunction),
        + // Do not need to provide a calcite aggregateFunction here. Flink aggregateion function
        + // will be generated when translating the calcite relnode to flink runtime execution plan
        + null
+ ) {
+
+  def getFunction: AggregateFunction[_, _] = aggregateFunction
+}

        +
        +object AggSqlFunction {
        +
        + def apply(
        + name: String,
        + aggregateFunction: AggregateFunction[_, _],
        + returnType: TypeInformation[_],
+ typeFactory: FlinkTypeFactory): AggSqlFunction = {
+
+  new AggSqlFunction(name, aggregateFunction, returnType, typeFactory)
+ }

        +
        + private[flink] def createOperandTypeInference(
        + aggregateFunction: AggregateFunction[_, _],
        + typeFactory: FlinkTypeFactory)
        + : SqlOperandTypeInference = {
        + /**
        + * Operand type inference based on [[AggregateFunction]] given information.
        + */
        + new SqlOperandTypeInference {
        + override def inferOperandTypes(
        + callBinding: SqlCallBinding,
        + returnType: RelDataType,
        + operandTypes: Array[RelDataType]): Unit = {
        +
        + val operandTypeInfo = getOperandTypeInfo(callBinding)
        +
        + val foundSignature = getAccumulateMethodSignature(aggregateFunction, operandTypeInfo)
        + .getOrElse(throw new ValidationException(s"Operand types of could not be inferred."))
        — End diff –

        Exception message is not complete
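For illustration, a complete message could name the function and echo the operand types; `aggregateFunction`, `operandTypeInfo`, and `signatureToString` are taken from the diff above, while the exact wording is only a suggestion:

```
throw new ValidationException(
  s"Operand types of aggregate function " +
    s"'${aggregateFunction.getClass.getCanonicalName}' could not be inferred. " +
    s"Actual: ${signatureToString(operandTypeInfo)}")
```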

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114542402

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala —
        @@ -274,27 +278,74 @@ class CodeGenerator(
        constantFlags: Option[Array[(Int, Boolean)]],
        outputArity: Int,
        needRetract: Boolean,

- needMerge: Boolean)
+ needMerge: Boolean,
+ needReset: Boolean)
  : GeneratedAggregationsFunction = {

  // get unique function name
  val funcName = newName(name)
  // register UDAGGs
  val aggs = aggregates.map(a => generator.addReusableFunction(a))
  // get java types of accumulators
- val accTypes = aggregates.map { a =>
-   a.getClass.getMethod("createAccumulator").getReturnType.getCanonicalName
+ val accTypeClasses = aggregates.map { a =>
+   a.getClass.getMethod("createAccumulator").getReturnType
  }
+ val accTypes = accTypeClasses.map(_.getCanonicalName)

- // get java types of input fields
- val javaTypes = inputType.getFieldList
+ // get java classes of input fields
+ val javaClasses = inputType.getFieldList
    .map(f => FlinkTypeFactory.toTypeInfo(f.getType))
-   .map(t => t.getTypeClass.getCanonicalName)
+   .map(t => t.getTypeClass)
  // get parameter lists for aggregation functions
- val parameters = aggFields.map { inFields =>
-   val fields = for (f <- inFields) yield s"(${javaTypes(f)}) input.getField($f)"
+ val parameters = aggFields.map { inFields =>
+   val fields = for (f <- inFields) yield
+     s"(${javaClasses(f).getCanonicalName}) input.getField($f)"
    fields.mkString(", ")
  }
+ val methodSignaturesList = aggFields.map {
+   inFields => for (f <- inFields) yield javaClasses(f)
+ }
+
+ // check and validate the needed methods
+ aggregates.zipWithIndex.map {
+   case (a, i) => {
+     getUserDefinedMethod(a, "accumulate", Array(accTypeClasses(i)) ++ methodSignaturesList(i))
+       .getOrElse(
+         throw new CodeGenException(
+           s"No matching accumulate method found for aggregate " +
+             s"'${a.getClass.getCanonicalName}'" +
+             s"with parameters '${signatureToString(methodSignaturesList(i))}'.")
+       )
+
+     if (needRetract) {
+       getUserDefinedMethod(a, "retract", Array(accTypeClasses(i)) ++ methodSignaturesList(i))
+         .getOrElse(
+           throw new CodeGenException(
+             s"No matching retract method found for aggregate " +
+               s"'${a.getClass.getCanonicalName}'" +
+               s"with parameters '${signatureToString(methodSignaturesList(i))}'.")
+         )
+     }
+
+     if (needMerge) {
+       val methods = checkAndExtractMethods(a, "merge")
— End diff –

This check can be done as:

```
// get method
val mergeMethod =
  getUserDefinedMethod(a, "merge", Array(accTypeClasses(i), classOf[JIterable[Any]]))
    .getOrElse(
      throw new CodeGenException(
        s"No matching merge method found for aggregate " +
          s"'${a.getClass.getCanonicalName}'.")
    )
// get class of iterable type
val iterableTypeClass = mergeMethod.getGenericParameterTypes.apply(1)
  .asInstanceOf[ParameterizedType].getActualTypeArguments.apply(0)
if (iterableTypeClass != accTypeClasses(i)) {
  throw new CodeGenException(
    s"No matching merge method found for aggregate " +
      s"'${a.getClass.getCanonicalName}'.")
}
```
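As a side note, the key move in the snippet above is recovering the element type of the `Iterable` parameter through `java.lang.reflect.ParameterizedType`. A minimal standalone sketch of that technique (the object and method names are mine, not from the PR):

```
import java.lang.reflect.{Method, ParameterizedType, Type}

object MergeMethodIntrospection {

  // Recover the element type of the Iterable accepted as the second
  // parameter of a merge(acc, Iterable[acc])-style method, if present.
  def iterableElementType(method: Method): Option[Type] =
    method.getGenericParameterTypes.lift(1) match {
      case Some(p: ParameterizedType) => p.getActualTypeArguments.headOption
      case _ => None
    }
}
```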

        githubbot ASF GitHub Bot added a comment -

        Github user twalthr commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114539517

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala —
        @@ -44,10 +44,10 @@ class ScalarFunctionCallGen(
        operands: Seq[GeneratedExpression])
        : GeneratedExpression = {
        // determine function method and result class

- val matchingMethod = getEvalMethod(scalarFunction, signature)
+ val matchingMethod = getUserDefinedMethod(scalarFunction, "eval", typeInfoToClass(signature))
    .getOrElse(throw new CodeGenException("No matching signature found."))
  val matchingSignature = matchingMethod.getParameterTypes
- val resultClass = getResultTypeClass(scalarFunction, matchingSignature)
+ val resultClass = getResultTypeClassOfScalaFunction(scalarFunction, matchingSignature)
— End diff –

        missing `r`

        githubbot ASF GitHub Bot added a comment -

        Github user twalthr commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114541980

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala —
@@ -130,3 +142,63 @@ case class Avg(child: Expression) extends Aggregation {
    new SqlAvgAggFunction(AVG)
  }
}
        +
        +case class UDAGGFunctionCall(
        + aggregateFunction: AggregateFunction[_, _],
        + args: Seq[Expression])
        + extends Aggregation {
        +
        + override private[flink] def children: Seq[Expression] = args
        +
        + // Override makeCopy method in TreeNode, to produce vargars properly
        + override def makeCopy(args: Array[AnyRef]): this.type = {
        — End diff –

        Do we need this? We have no varargs...

        githubbot ASF GitHub Bot added a comment -

        Github user twalthr commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114541011

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala —
@@ -130,3 +142,63 @@ case class Avg(child: Expression) extends Aggregation {
    new SqlAvgAggFunction(AVG)
  }
}
        +
        +case class UDAGGFunctionCall(
        + aggregateFunction: AggregateFunction[_, _],
        + args: Seq[Expression])
        + extends Aggregation {
        +
        + override private[flink] def children: Seq[Expression] = args
        +
        + // Override makeCopy method in TreeNode, to produce vargars properly
        + override def makeCopy(args: Array[AnyRef]): this.type = {
+ if (args.length < 1) {
+   throw new TableException("Invalid constructor params")
+ }
        + val agg = args.head.asInstanceOf[AggregateFunction[_, _]]
        + val arg = args.last.asInstanceOf[Seq[Expression]]
        + new UDAGGFunctionCall(agg, arg).asInstanceOf[this.type]
        + }
        +
        + override def resultType: TypeInformation[_] = TypeExtractor.createTypeInfo(
        — End diff –

        This is again very dangerous. Type extraction fails very often. We should provide a `getResultType` in `AggregateFunction` that the user can implement.
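To sketch what the proposed hook could look like from the user's side (the `getResultType` override is the suggestion under discussion, not an existing method of `AggregateFunction`; the rest follows the UDAGG contract used in this PR):

```
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.table.api.Types
import org.apache.flink.table.functions.AggregateFunction
import org.apache.flink.types.Row

class CountSumAcc { var count: Long = 0L; var sum: Long = 0L }

// A UDAGG emitting a Row: without a result-type hint, the extractor
// could only infer GenericType<Row>.
class CountSumAgg extends AggregateFunction[Row, CountSumAcc] {

  def createAccumulator(): CountSumAcc = new CountSumAcc

  def accumulate(acc: CountSumAcc, value: Long): Unit = {
    acc.count += 1
    acc.sum += value
  }

  def getValue(acc: CountSumAcc): Row = {
    val row = new Row(2)
    row.setField(0, acc.count)
    row.setField(1, acc.sum)
    row
  }

  // the proposed hook: declare the exact result type explicitly
  override def getResultType: TypeInformation[Row] =
    new RowTypeInfo(Types.LONG, Types.LONG)
}
```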

        githubbot ASF GitHub Bot added a comment -

        Github user twalthr commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114538809

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala —
        @@ -691,6 +692,8 @@ trait ImplicitExpressionConversions {
        implicit def sqlTimestamp2Literal(sqlTimestamp: Timestamp): Expression =
        Literal(sqlTimestamp)
        implicit def array2ArrayConstructor(array: Array[_]): Expression = convertArray(array)
        + implicit def UserDefinedAggFunctionConstructor[T: TypeInformation, ACC](udagg:
        + AggregateFunction[T, ACC]): UDAGGExpression[T, ACC] = UDAGGExpression(udagg)
        — End diff –

        Do we need this `UDAGGExpression`? Can't we convert directly into `UDAGGFunctionCall`? Imho I would call this `AggregateFunctionCall` and avoid acronyms.
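For context, this implicit conversion is what makes a UDAGG instance applicable like a built-in aggregate in Table API expressions. Trimmed from the sliding-window test of this PR (`table` and `WeightedAvg` as defined there):

```
val weightAvgFun = new WeightedAvg

// the implicit conversion turns the function instance into something callable
val windowedTable = table
  .window(Slide over 4.rows every 2.rows as 'w)
  .groupBy('w, 'string)
  .select('string, weightAvgFun('long, 'int))
```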

        githubbot ASF GitHub Bot added a comment -

        Github user twalthr commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114535675

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/java/BatchTableEnvironment.scala —
        @@ -178,4 +178,24 @@ class BatchTableEnvironment(

        registerTableFunctionInternal[T](name, tf)
        }
        +
        + /**
        + * Registers an [[AggregateFunction]] under a unique name in the TableEnvironment's catalog.
        + * Registered functions can be referenced in Table API and SQL queries.
        + *
        + * @param name The name under which the function is registered.
        + * @param f The AggregateFunction to register.
        + * @tparam T The type of the output value.
        + * @tparam ACC The type of aggregate accumulator.
        + */
        + def registerFunction[T, ACC](
        + name: String,
        + f: AggregateFunction[T, ACC])
        + : Unit = {
        + implicit val typeInfo: TypeInformation[T] = TypeExtractor
        — End diff –

        What happens if type extraction fails? This happens very often. This method and also `registerTableFunction` should be overloaded to also supply the return type manually.
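A possible shape for such an overload; the three-argument signature is only an illustration of the suggestion, not the PR's API:

```
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.table.functions.AggregateFunction

trait AggregateRegistration {

  // existing style: the result type is inferred via the TypeExtractor
  def registerFunction[T, ACC](name: String, f: AggregateFunction[T, ACC]): Unit

  // suggested overload: the caller supplies the result type explicitly,
  // bypassing reflective extraction entirely
  def registerFunction[T, ACC](
      name: String,
      f: AggregateFunction[T, ACC],
      resultType: TypeInformation[T]): Unit
}
```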

        githubbot ASF GitHub Bot added a comment -

        Github user twalthr commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114545441

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala —
        @@ -327,4 +332,56 @@ object ProjectionTranslator {
        }
        }

        + /**
        + * Find and replace UDAGG function Call to UDAGGFunctionCall
        + *
        + * @param field the expression to check
        + * @param tableEnv the TableEnvironment
        + * @return an expression with correct UDAGGFunctionCall type for UDAGG functions
        + */
        + def replaceUDAGGFunctionCall(field: Expression, tableEnv: TableEnvironment): Expression = {
        — End diff –

Couldn't we do the replacement during `replaceAggregationsAndProperties` and save some duplicate code? Actually, `LogicalNode#resolveExpressions` is responsible for looking up calls. Maybe that logic has to be adapted a little bit.

        githubbot ASF GitHub Bot added a comment -

        Github user twalthr commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114537431

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala —
        @@ -691,6 +692,8 @@ trait ImplicitExpressionConversions {
        implicit def sqlTimestamp2Literal(sqlTimestamp: Timestamp): Expression =
        Literal(sqlTimestamp)
        implicit def array2ArrayConstructor(array: Array[_]): Expression = convertArray(array)
        + implicit def UserDefinedAggFunctionConstructor[T: TypeInformation, ACC](udagg:
        — End diff –

        Methods start with lower case letters.

        githubbot ASF GitHub Bot added a comment -

        Github user shaoxuan-wang commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114571910

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala —
        @@ -327,4 +332,56 @@ object ProjectionTranslator {
        }
        }

        + /**
        + * Find and replace UDAGG function Call to UDAGGFunctionCall
        + *
        + * @param field the expression to check
        + * @param tableEnv the TableEnvironment
        + * @return an expression with correct UDAGGFunctionCall type for UDAGG functions
        + */
        + def replaceUDAGGFunctionCall(field: Expression, tableEnv: TableEnvironment): Expression = {
        — End diff –

We will not have the chance to execute `LogicalNode#resolveExpressions` before getting aggNames, projectFields, etc. I actually tried an alternative approach that conducts the replacement in extractAggregationsAndProperties and replaceAggregationsAndProperties (we have to check and handle the UDAGG call carefully in both functions). It works, but I do not like that design: it makes the logic of those two methods less clean. Also, over aggregates do not call extractAggregationsAndProperties and replaceAggregationsAndProperties at all. So I decided to implement a separate function to handle the UDAGGFunctionCall replacement.

        githubbot ASF GitHub Bot added a comment -

        Github user shaoxuan-wang commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114576600

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/java/BatchTableEnvironment.scala —
        @@ -178,4 +178,24 @@ class BatchTableEnvironment(

        registerTableFunctionInternal[T](name, tf)
        }
        +
        + /**
        + * Registers an [[AggregateFunction]] under a unique name in the TableEnvironment's catalog.
        + * Registered functions can be referenced in Table API and SQL queries.
        + *
        + * @param name The name under which the function is registered.
        + * @param f The AggregateFunction to register.
        + * @tparam T The type of the output value.
        + * @tparam ACC The type of aggregate accumulator.
        + */
        + def registerFunction[T, ACC](
        + name: String,
        + f: AggregateFunction[T, ACC])
        + : Unit = {
        + implicit val typeInfo: TypeInformation[T] = TypeExtractor
        — End diff –

Thanks @twalthr. If I understand you correctly, you suggest adding a contract method `getResultType` to the UDAGG interface, so that the user can provide the result type in case type extraction fails. Sounds good to me.
Can you give some examples of when type extraction will fail (for instance, a Row type?) and why it may fail, so that I can add some test cases?

        githubbot ASF GitHub Bot added a comment -

        Github user shaoxuan-wang commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114579772

        — Diff: flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/DSetUDAGGITCase.scala —
        @@ -0,0 +1,192 @@
        +/*
        + * Licensed to the Apache Software Foundation (ASF) under one
        + * or more contributor license agreements. See the NOTICE file
        + * distributed with this work for additional information
        + * regarding copyright ownership. The ASF licenses this file
        + * to you under the Apache License, Version 2.0 (the
        + * "License"); you may not use this file except in compliance
        + * with the License. You may obtain a copy of the License at
        + *
        + * http://www.apache.org/licenses/LICENSE-2.0
        + *
        + * Unless required by applicable law or agreed to in writing, software
        + * distributed under the License is distributed on an "AS IS" BASIS,
        + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
        + * See the License for the specific language governing permissions and
        + * limitations under the License.
        + */
        +
        +package org.apache.flink.table.api.scala.batch.table
        +
        +import java.math.BigDecimal
        +
+import org.apache.flink.api.java.{DataSet => JDataSet, ExecutionEnvironment => JavaExecutionEnv}
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.api.scala._
+import org.apache.flink.api.scala.{DataSet, ExecutionEnvironment => ScalaExecutionEnv}
+import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.{WeightedAvg, WeightedAvgWithMergeAndReset}
+import org.apache.flink.table.api.scala.batch.utils.TableProgramsCollectionTestBase
+import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.{TableEnvironment, Types}
+import org.apache.flink.table.functions.aggfunctions.CountAggFunction
+import org.apache.flink.table.utils.TableTestBase
+import org.apache.flink.test.util.TestBaseUtils
+import org.apache.flink.types.Row
+import org.junit._
+import org.junit.runner.RunWith
+import org.junit.runners.Parameterized
+import org.mockito.Mockito.{mock, when}

        +
        +import scala.collection.JavaConverters._
        +
        +/**
        + * We only test some aggregations until better testing of constructed DataSet
        + * programs is possible.
        + */
        +@RunWith(classOf[Parameterized])
        +class DSetUDAGGITCase(configMode: TableConfigMode)
        — End diff –

Sounds good to me. I did not spread the UDAGG tests across the existing aggregation test files, as I feel the current agg tests are a bit messy; it always takes me a while to find the right test cases among the different files. Putting the UDAGG test cases into one file made it easy to see what kinds of tests are covered. I think we need to consider how to reorganize our agg test structure. Considering the short time until the feature freeze, let us keep the current structure (I will split the UDAGG tests across the different agg test files) and clean up the tests later.

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114581240

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/java/BatchTableEnvironment.scala —
        @@ -178,4 +178,24 @@ class BatchTableEnvironment(

        registerTableFunctionInternal[T](name, tf)
        }
        +
        + /**
        + * Registers an [[AggregateFunction]] under a unique name in the TableEnvironment's catalog.
        + * Registered functions can be referenced in Table API and SQL queries.
        + *
        + * @param name The name under which the function is registered.
        + * @param f The AggregateFunction to register.
        + * @tparam T The type of the output value.
        + * @tparam ACC The type of aggregate accumulator.
        + */
        + def registerFunction[T, ACC](
        + name: String,
        + f: AggregateFunction[T, ACC])
        + : Unit = {
        + implicit val typeInfo: TypeInformation[T] = TypeExtractor
        — End diff –

        The idea would be to always check `getResultType()` first and only use the type extractor if the method is not implemented. So you would not need to enforce a TypeExtractor failure for the tests.
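A minimal sketch of that lookup order, assuming the optional `getResultType` hook exists and returns `null` when not implemented (the wrapper object and method name are mine):

```
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.TypeExtractor
import org.apache.flink.table.functions.AggregateFunction

object ResultTypeResolution {

  def resolveResultType[T, ACC](f: AggregateFunction[T, ACC]): TypeInformation[T] = {
    val hinted = f.getResultType // assumed optional hook, null if not implemented
    if (hinted != null) {
      hinted
    } else {
      // fall back to reflective extraction, as the PR currently does
      TypeExtractor
        .createTypeInfo(f, classOf[AggregateFunction[_, _]], f.getClass, 0)
        .asInstanceOf[TypeInformation[T]]
    }
  }
}
```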

        githubbot ASF GitHub Bot added a comment -

        Github user twalthr commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114593800

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/java/BatchTableEnvironment.scala —
        @@ -178,4 +178,24 @@ class BatchTableEnvironment(

        registerTableFunctionInternal[T](name, tf)
        }
        +
        + /**
        + * Registers an [[AggregateFunction]] under a unique name in the TableEnvironment's catalog.
        + * Registered functions can be referenced in Table API and SQL queries.
        + *
        + * @param name The name under which the function is registered.
        + * @param f The AggregateFunction to register.
        + * @tparam T The type of the output value.
        + * @tparam ACC The type of aggregate accumulator.
        + */
        + def registerFunction[T, ACC](
        + name: String,
        + f: AggregateFunction[T, ACC])
        + : Unit = {
        + implicit val typeInfo: TypeInformation[T] = TypeExtractor
        — End diff –

Yes, the Row type is a good example of type extraction problems. A user usually doesn't want to end up with `GenericType<Row>`. Other cases are types that need custom serializers or have complex generics.
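To make the `Row` example concrete: reflective extraction cannot see the field types, whereas a user-supplied `RowTypeInfo` carries the full schema. A small sketch using stable APIs:

```
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.{RowTypeInfo, TypeExtractor}
import org.apache.flink.table.api.Types
import org.apache.flink.types.Row

object RowTypeExample {

  // extraction only sees an opaque class: GenericType<Row>
  val extracted: TypeInformation[Row] = TypeExtractor.getForClass(classOf[Row])

  // an explicit hint describes the actual schema: Row(f0: Long, f1: String)
  val declared: TypeInformation[Row] = new RowTypeInfo(Types.LONG, Types.STRING)
}
```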

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114602957

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/java/BatchTableEnvironment.scala —
        @@ -178,4 +178,24 @@ class BatchTableEnvironment(

        registerTableFunctionInternal[T](name, tf)
        }
        +
        + /**
        + * Registers an [[AggregateFunction]] under a unique name in the TableEnvironment's catalog.
        + * Registered functions can be referenced in Table API and SQL queries.
        + *
        + * @param name The name under which the function is registered.
        + * @param f The AggregateFunction to register.
        + * @tparam T The type of the output value.
        + * @tparam ACC The type of aggregate accumulator.
        + */
        + def registerFunction[T, ACC](
        + name: String,
        + f: AggregateFunction[T, ACC])
        + : Unit = {
        + implicit val typeInfo: TypeInformation[T] = TypeExtractor
        — End diff –

Well, in the cases here, we are not really interested in the TypeInformation itself but in the type class. So it would not really matter whether you get Row.class from the `GenericTypeInfo` or the `RowTypeInfo`. Nonetheless, I think it is a good idea to have this optional method. It would also be consistent with the other user-defined functions.

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114651484

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala —
@@ -130,3 +142,63 @@ case class Avg(child: Expression) extends Aggregation {
    new SqlAvgAggFunction(AVG)
  }
}
        +
        +case class UDAGGFunctionCall(
        + aggregateFunction: AggregateFunction[_, _],
        + args: Seq[Expression])
        + extends Aggregation {
        +
        + override private[flink] def children: Seq[Expression] = args
        +
        + // Override makeCopy method in TreeNode, to produce vargars properly
        + override def makeCopy(args: Array[AnyRef]): this.type = {
+ if (args.length < 1) {
+   throw new TableException("Invalid constructor params")
+ }
        + val agg = args.head.asInstanceOf[AggregateFunction[_, _]]
        + val arg = args.last.asInstanceOf[Seq[Expression]]
        + new UDAGGFunctionCall(agg, arg).asInstanceOf[this.type]
        + }
        +
        + override def resultType: TypeInformation[_] = TypeExtractor.createTypeInfo(
        + aggregateFunction, classOf[AggregateFunction[_, _]], aggregateFunction.getClass, 0)
        +
        + override def validateInput(): ValidationResult = {
        + val signature = children.map(_.resultType)
        + // look for a signature that matches the input types
        + val foundSignature = getAccumulateMethodSignature(aggregateFunction, signature)
        + if (foundSignature.isEmpty) {
+ ValidationFailure(s"Given parameters do not match any signature. \n" +
+   s"Actual: ${signatureToString(signature)} \n" +
+   s"Expected: ${signaturesToString(aggregateFunction, "accumulate")}")
        — End diff –

        The signature string includes the Accumulator, which should be removed.

        githubbot ASF GitHub Bot added a comment -

        Github user shaoxuan-wang commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114690845

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala —
@@ -130,3 +142,63 @@ case class Avg(child: Expression) extends Aggregation {
    new SqlAvgAggFunction(AVG)
  }
}
        +
        +case class UDAGGFunctionCall(
        + aggregateFunction: AggregateFunction[_, _],
        + args: Seq[Expression])
        + extends Aggregation {
        +
        + override private[flink] def children: Seq[Expression] = args
        +
        + // Override makeCopy method in TreeNode, to produce vargars properly
        + override def makeCopy(args: Array[AnyRef]): this.type = {
+ if (args.length < 1) {
+   throw new TableException("Invalid constructor params")
+ }
        + val agg = args.head.asInstanceOf[AggregateFunction[_, _]]
        + val arg = args.last.asInstanceOf[Seq[Expression]]
        + new UDAGGFunctionCall(agg, arg).asInstanceOf[this.type]
        + }
        +
        + override def resultType: TypeInformation[_] = TypeExtractor.createTypeInfo(
        + aggregateFunction, classOf[AggregateFunction[_, _]], aggregateFunction.getClass, 0)
        +
        + override def validateInput(): ValidationResult = {
        + val signature = children.map(_.resultType)
        + // look for a signature that matches the input types
        + val foundSignature = getAccumulateMethodSignature(aggregateFunction, signature)
        + if (foundSignature.isEmpty) {
        + ValidationFailure(s"Given parameters do not match any signature. \n" +
        + s"Actual: $

        {signatureToString(signature)}

        \n" +
        + s"Expected: $

        {signaturesToString(aggregateFunction, "accumulate")}

        ")
        + } else

        { + ValidationSuccess + }

        + }
        +
        + override def toString(): String = s"$

        {aggregateFunction.getClass.getSimpleName}

        ($args)"
        +
        + override def toAggCall(name: String)(implicit relBuilder: RelBuilder): AggCall = {
        + val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory]
        + val sqlFunction = AggSqlFunction(name, aggregateFunction, resultType, typeFactory)
        — End diff –

        I was trying to keep this name consistent with `ScalaSqlFunction` and `TableSqlFunction`
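
For context, a hedged sketch of how a UDAGG registered through this machinery would be used; the function, accumulator, and table names (`WeightedAvg`, `WAvgAccum`, `wAvg`, `users`) are illustrative, not from the PR:

import org.apache.flink.table.functions.AggregateFunction

class WAvgAccum { var sum: Long = 0L; var count: Int = 0 }

class WeightedAvg extends AggregateFunction[Long, WAvgAccum] {
  override def createAccumulator(): WAvgAccum = new WAvgAccum

  override def getValue(acc: WAvgAccum): Long =
    if (acc.count == 0) 0L else acc.sum / acc.count

  // resolved by reflection; the accumulator is always the first parameter
  def accumulate(acc: WAvgAccum, value: Long, weight: Int): Unit = {
    acc.sum += value * weight
    acc.count += weight
  }
}

// Registration covers both the Table API and SQL once this PR is merged:
//   tableEnv.registerFunction("wAvg", new WeightedAvg)
//   tableEnv.sql("SELECT user, wAvg(points, level) FROM users GROUP BY user")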

        githubbot ASF GitHub Bot added a comment -

        Github user shaoxuan-wang commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114692276

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala —
        @@ -0,0 +1,177 @@
        +/*
        + * Licensed to the Apache Software Foundation (ASF) under one
        + * or more contributor license agreements. See the NOTICE file
        + * distributed with this work for additional information
        + * regarding copyright ownership. The ASF licenses this file
        + * to you under the Apache License, Version 2.0 (the
        + * "License"); you may not use this file except in compliance
        + * with the License. You may obtain a copy of the License at
        + *
        + * http://www.apache.org/licenses/LICENSE-2.0
        + *
        + * Unless required by applicable law or agreed to in writing, software
        + * distributed under the License is distributed on an "AS IS" BASIS,
        + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
        + * See the License for the specific language governing permissions and
        + * limitations under the License.
        + */
        +
        +package org.apache.flink.table.functions.utils
        +
        +import org.apache.calcite.rel.`type`.RelDataType
        +import org.apache.calcite.sql._
        +import org.apache.calcite.sql.`type`._
        +import org.apache.calcite.sql.`type`.SqlOperandTypeChecker.Consistency
        +import org.apache.calcite.sql.parser.SqlParserPos
        +import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction
        +import org.apache.flink.api.common.typeinfo._
        +import org.apache.flink.table.api.ValidationException
        +import org.apache.flink.table.calcite.FlinkTypeFactory
        +import org.apache.flink.table.functions.AggregateFunction
+import org.apache.flink.table.functions.utils.AggSqlFunction.{createOperandTypeChecker, createOperandTypeInference, createReturnTypeInference}
+import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
        +
        +/**
        + * Calcite wrapper for user-defined aggregate functions.
        + *
        + * @param name function name (used by SQL parser)
        + * @param aggregateFunction aggregate function to be called
        + * @param returnType the type information of returned value
+ * @param typeFactory type factory for converting between Flink's and Calcite's types
        + */
        +class AggSqlFunction(
        + name: String,
        + aggregateFunction: AggregateFunction[_, _],
        + returnType: TypeInformation[_],
        + typeFactory: FlinkTypeFactory)
        + extends SqlUserDefinedAggFunction(
        + new SqlIdentifier(name, SqlParserPos.ZERO),
        + createReturnTypeInference(returnType, typeFactory),
        + createOperandTypeInference(aggregateFunction, typeFactory),
        + createOperandTypeChecker(aggregateFunction),
+ // No need to provide a Calcite aggregateFunction here. The Flink aggregation function
+ // will be generated when translating the Calcite RelNode to the Flink runtime execution plan.
+ null
+ ) {
+
+ def getFunction: AggregateFunction[_, _] = aggregateFunction
+}

        +
        +object AggSqlFunction {
        +
        + def apply(
        + name: String,
        + aggregateFunction: AggregateFunction[_, _],
        + returnType: TypeInformation[_],
+ typeFactory: FlinkTypeFactory): AggSqlFunction = {
+
+ new AggSqlFunction(name, aggregateFunction, returnType, typeFactory)
+ }

        +
        + private[flink] def createOperandTypeInference(
        + aggregateFunction: AggregateFunction[_, _],
        + typeFactory: FlinkTypeFactory)
        + : SqlOperandTypeInference = {
        + /**
        + * Operand type inference based on [[AggregateFunction]] given information.
        + */
        + new SqlOperandTypeInference {
        + override def inferOperandTypes(
        + callBinding: SqlCallBinding,
        + returnType: RelDataType,
        + operandTypes: Array[RelDataType]): Unit = {
        +
        + val operandTypeInfo = getOperandTypeInfo(callBinding)
        +
        + val foundSignature = getAccumulateMethodSignature(aggregateFunction, operandTypeInfo)
+ .getOrElse(throw new ValidationException(s"Operand types could not be inferred."))
        +
        + val inferredTypes = getParameterTypes(aggregateFunction, foundSignature.drop(1))
        + .map(typeFactory.createTypeFromTypeInfo)
        +
        + for (i <- operandTypes.indices) {
+ if (i < inferredTypes.length - 1) {
+ operandTypes(i) = inferredTypes(i)
+ } else if (null != inferredTypes.last.getComponentType) {
+ // last argument is a collection, the array type
+ operandTypes(i) = inferredTypes.last.getComponentType
+ } else {
+ operandTypes(i) = inferredTypes.last
+ }
+ }
        + }
        + }
        + }
        +
        + private[flink] def createReturnTypeInference(
        + resultType: TypeInformation[_],
        + typeFactory: FlinkTypeFactory)
        + : SqlReturnTypeInference = {
        +
        + new SqlReturnTypeInference {
+ override def inferReturnType(opBinding: SqlOperatorBinding): RelDataType = {
+ typeFactory.createTypeFromTypeInfo(resultType)
+ }

        + }
        + }
        +
        + private[flink] def createOperandTypeChecker(aggregateFunction: AggregateFunction[_, _])
        + : SqlOperandTypeChecker = {
        +
        + val signatures = getMethodSignatures(aggregateFunction, "accumulate")
        +
        + /**
        + * Operand type checker based on [[AggregateFunction]] given information.
        + */
        + new SqlOperandTypeChecker {
        + override def getAllowedSignatures(op: SqlOperator, opName: String): String = {
        + s"$opName[$

        {signaturesToString(aggregateFunction, "accumulate")}

        ]"
        + }
        +
        + override def getOperandCountRange: SqlOperandCountRange = {
        + var min = 255
        + var max = -1
        + signatures.foreach(
        + sig => {
+ val inputSig = sig.drop(1) // do not count the accumulator as input
        + var len = inputSig.length
        + if (len > 0 && inputSig(inputSig.length - 1).isArray) {
        + max = 254 // according to JVM spec 4.3.3
        — End diff –

We want to make sure the number of parameters of every code-generated function stays within [min, max]; otherwise it will trigger a JVM failure. For the accumulate and retract methods, the max for operands (which does not take the accumulator into account) should be 253, since the function reference itself takes one slot and the accumulator takes another. Yes, we should consider adding test cases for a huge number of inputs and for the Array case.
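
A worked restatement of that limit (the arithmetic follows JVM spec §4.3.3, which caps a method at 255 parameter slots):

// 255 (JVM hard limit on parameter slots)
//  - 1 (the implicit `this` / function reference)
//  - 1 (the accumulator argument)
val maxUserOperands = 255 - 1 - 1 // = 253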

        githubbot ASF GitHub Bot added a comment -

        Github user shaoxuan-wang commented on a diff in the pull request:

        https://github.com/apache/flink/pull/3809#discussion_r114692346

        — Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/AggSqlFunction.scala —
        @@ -0,0 +1,177 @@
        +/*
        + * Licensed to the Apache Software Foundation (ASF) under one
        + * or more contributor license agreements. See the NOTICE file
        + * distributed with this work for additional information
        + * regarding copyright ownership. The ASF licenses this file
        + * to you under the Apache License, Version 2.0 (the
        + * "License"); you may not use this file except in compliance
        + * with the License. You may obtain a copy of the License at
        + *
        + * http://www.apache.org/licenses/LICENSE-2.0
        + *
        + * Unless required by applicable law or agreed to in writing, software
        + * distributed under the License is distributed on an "AS IS" BASIS,
        + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
        + * See the License for the specific language governing permissions and
        + * limitations under the License.
        + */
        +
        +package org.apache.flink.table.functions.utils
        +
        +import org.apache.calcite.rel.`type`.RelDataType
        +import org.apache.calcite.sql._
        +import org.apache.calcite.sql.`type`._
        +import org.apache.calcite.sql.`type`.SqlOperandTypeChecker.Consistency
        +import org.apache.calcite.sql.parser.SqlParserPos
        +import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction
        +import org.apache.flink.api.common.typeinfo._
        +import org.apache.flink.table.api.ValidationException
        +import org.apache.flink.table.calcite.FlinkTypeFactory
        +import org.apache.flink.table.functions.AggregateFunction
+import org.apache.flink.table.functions.utils.AggSqlFunction.{createOperandTypeChecker, createOperandTypeInference, createReturnTypeInference}
+import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
        +
        +/**
        + * Calcite wrapper for user-defined aggregate functions.
        + *
        + * @param name function name (used by SQL parser)
        + * @param aggregateFunction aggregate function to be called
        + * @param returnType the type information of returned value
+ * @param typeFactory type factory for converting between Flink's and Calcite's types
        + */
        +class AggSqlFunction(
        + name: String,
        + aggregateFunction: AggregateFunction[_, _],
        + returnType: TypeInformation[_],
        + typeFactory: FlinkTypeFactory)
        + extends SqlUserDefinedAggFunction(
        + new SqlIdentifier(name, SqlParserPos.ZERO),
        + createReturnTypeInference(returnType, typeFactory),
        + createOperandTypeInference(aggregateFunction, typeFactory),
        + createOperandTypeChecker(aggregateFunction),
+ // No need to provide a Calcite aggregateFunction here. The Flink aggregation function
+ // will be generated when translating the Calcite RelNode to the Flink runtime execution plan.
+ null
+ ) {
+
+ def getFunction: AggregateFunction[_, _] = aggregateFunction
+}

        +
        +object AggSqlFunction {
        +
        + def apply(
        + name: String,
        + aggregateFunction: AggregateFunction[_, _],
        + returnType: TypeInformation[_],
+ typeFactory: FlinkTypeFactory): AggSqlFunction = {
+
+ new AggSqlFunction(name, aggregateFunction, returnType, typeFactory)
+ }

        +
        + private[flink] def createOperandTypeInference(
        + aggregateFunction: AggregateFunction[_, _],
        + typeFactory: FlinkTypeFactory)
        + : SqlOperandTypeInference = {
        + /**
        + * Operand type inference based on [[AggregateFunction]] given information.
        + */
        + new SqlOperandTypeInference {
        + override def inferOperandTypes(
        + callBinding: SqlCallBinding,
        + returnType: RelDataType,
        + operandTypes: Array[RelDataType]): Unit = {
        +
        + val operandTypeInfo = getOperandTypeInfo(callBinding)
        +
        + val foundSignature = getAccumulateMethodSignature(aggregateFunction, operandTypeInfo)
+ .getOrElse(throw new ValidationException(s"Operand types could not be inferred."))
        +
        + val inferredTypes = getParameterTypes(aggregateFunction, foundSignature.drop(1))
        + .map(typeFactory.createTypeFromTypeInfo)
        +
        + for (i <- operandTypes.indices) {
+ if (i < inferredTypes.length - 1) {
+ operandTypes(i) = inferredTypes(i)
+ } else if (null != inferredTypes.last.getComponentType) {
        — End diff –

Yes, if the last parameter is an array type, say Array[Int], we want to get the element type Int, not the Array type.
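
As a quick illustration with plain Java reflection (Calcite's RelDataType.getComponentType behaves analogously for array types):

// getComponentType yields the element type for arrays and null otherwise
assert(classOf[Array[Int]].getComponentType == classOf[Int])
assert(classOf[Int].getComponentType == null)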

        githubbot ASF GitHub Bot added a comment -

        Github user fhueske commented on the issue:

        https://github.com/apache/flink/pull/3809

        Merging

        githubbot ASF GitHub Bot added a comment -

        Github user asfgit closed the pull request at:

        https://github.com/apache/flink/pull/3809

        fhueske Fabian Hueske added a comment -

        Implemented with 981dea41e593f3db763af3d0366bf7adbdd1d3bf


          People

          • Assignee:
            ShaoxuanWang Shaoxuan Wang
            Reporter:
            fhueske Fabian Hueske
          • Votes:
            0
            Watchers:
            3
