[SPARK-30986] Structured Streaming: mapGroupsWithState UDT serialization does not work
Details

    • Type: Bug
    • Status: Resolved
    • Priority: Major
    • Resolution: Duplicate
    • Affects Version/s: 2.3.0, 2.4.0
    • Fix Version/s: None
    • Component/s: Structured Streaming
    • Environment: Spark 2.3.0 on Ubuntu Linux and Windows with Scala 2.11.8

    Description

      Hello.

      I'm running Scala 2.11 with Spark 2.3.0. I've encountered a problem with mapGroupsWithState and was wondering if anyone had insight. We use Joda-Time in a number of data structures, so we've written a custom serializer (a UserDefinedType) for it. This works well in most Dataset/DataFrame structured streaming operations. However, when running mapGroupsWithState we observed that incorrect dates were being returned from state.
       
      Simple example:
      1. Input A arrives with date D.
      2. Input A updates state in mapGroupsWithState; the date stored in state is D.
      3. Input A arrives again with the correct date D, but the existing state now holds an invalid date.
       
      Here is a simple repro:
       
      Joda Time UDT:
       

      // NOTE: must live under the org.apache.spark.sql package tree; UDTRegistration
      // is private[spark], and private[sql]/private[spark] do not compile elsewhere.
      package org.apache.spark.sql

      import org.apache.spark.sql.types.{DataType, LongType, UDTRegistration, UserDefinedType}
      import org.joda.time.{DateTime, DateTimeZone}

      private[sql] class JodaTimeUDT extends UserDefinedType[DateTime] {
        override def sqlType: DataType = LongType
        override def serialize(obj: DateTime): Long = obj.getMillis
        override def deserialize(datum: Any): DateTime = datum match {
          case value: Long => new DateTime(value, DateTimeZone.UTC)
        }
        override def userClass: Class[DateTime] = classOf[DateTime]
        private[spark] override def asNullable: JodaTimeUDT = this
      }

      object JodaTimeUDTRegister {
        def register: Unit = {
          UDTRegistration.register(classOf[DateTime].getName, classOf[JodaTimeUDT].getName)
        }
      }
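
      Note that the UDT round-trips correctly in isolation, so the problem is not the serialize/deserialize pair itself. A minimal check (my sketch, not part of the original repro; run it from within org.apache.spark.sql since the class is private[sql]):

      val udt = new JodaTimeUDT
      val d = new DateTime(2020, 1, 2, 3, 4, 5, 6, DateTimeZone.UTC)
      assert(udt.deserialize(udt.serialize(d)) == d)  // passes: same millis, same UTC zone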
      

       
      Test Leveraging Joda UDT:
       

      import org.apache.spark.SparkConf
      import org.apache.spark.sql.{Dataset, JodaTimeUDTRegister, SparkSession}
      import org.apache.spark.sql.execution.streaming.MemoryStream
      import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, Trigger}
      import org.joda.time.{DateTime, DateTimeZone}
      import org.junit.runner.RunWith
      import org.scalamock.scalatest.MockFactory
      import org.scalatest.junit.JUnitRunner
      import org.scalatest.{BeforeAndAfterAll, FlatSpec, Matchers}

      case class FooWithDate(date: DateTime, s: String, i: Int)
      
      @RunWith(classOf[JUnitRunner])
      class TestJodaTimeUdt extends FlatSpec with Matchers with MockFactory with BeforeAndAfterAll {
        val application = this.getClass.getName
        var session: SparkSession = _
      
        override def beforeAll(): Unit = {
          System.setProperty("hadoop.home.dir", getClass.getResource("/").getPath)
          val sparkConf = new SparkConf()
            .set("spark.driver.allowMultipleContexts", "true")
            .set("spark.testing", "true")
            .set("spark.memory.fraction", "1")
            .set("spark.ui.enabled", "false")
            .set("spark.streaming.gracefulStopTimeout", "1000")
            .setAppName(application).setMaster("local[*]")

          session = SparkSession.builder().config(sparkConf).getOrCreate()
          session.sparkContext.setCheckpointDir("/")
          JodaTimeUDTRegister.register
        }
      
        override def afterAll(): Unit = {
          session.stop()
        }
      
        behavior of "mapGroupsWithState with a Joda UDT"

        it should "work correctly for a streaming input with stateful transformation" in {
          val date = new DateTime(2020, 1, 2, 3, 4, 5, 6, DateTimeZone.UTC)
          val sqlContext = session.sqlContext
          import sqlContext.implicits._
      
          val input = List(FooWithDate(date, "Foo", 1), FooWithDate(date, "Foo", 3), FooWithDate(date, "Foo", 3))
          val streamInput: MemoryStream[FooWithDate] = new MemoryStream[FooWithDate](42, session.sqlContext)
          streamInput.addData(input)
          val ds: Dataset[FooWithDate] = streamInput.toDS()
      
          val mapGroupsWithStateFunction: (Int, Iterator[FooWithDate], GroupState[FooWithDate]) => FooWithDate = TestJodaTimeUdt.updateFooState
          val result: Dataset[FooWithDate] = ds
            .groupByKey(x => x.i)
            .mapGroupsWithState(GroupStateTimeout.ProcessingTimeTimeout())(mapGroupsWithStateFunction)
          val writeTo = "random_table_name"
      
          result.writeStream
            .outputMode(OutputMode.Update)
            .format("memory")
            .queryName(writeTo)
            .trigger(Trigger.Once())
            .start()
            .awaitTermination()
          val combinedResults: Array[FooWithDate] = session.sql(sqlText = s"select * from $writeTo").as[FooWithDate].collect()
          val expected = Array(FooWithDate(date, "Foo", 1), FooWithDate(date, "FooFoo", 6))
          combinedResults should contain theSameElementsAs(expected)
        }
      }
      
      object TestJodaTimeUdt {
        def updateFooState(id: Int, inputs: Iterator[FooWithDate], state: GroupState[FooWithDate]): FooWithDate = {
          if (state.hasTimedOut) {
            // Capture the expiring value before remove(); afterwards getOption is None
            // and the original `state.remove(); state.getOption.get` would throw.
            val expired = state.get
            state.remove()
            expired
          } else {
            val inputsSeq: Seq[FooWithDate] = inputs.toSeq
            val startingState = state.getOption.getOrElse(inputsSeq.head)
            val toProcess = if (state.getOption.isDefined) inputsSeq else inputsSeq.tail
            val updatedFoo = toProcess.foldLeft(startingState)(concatFoo)
      
            state.update(updatedFoo)
            state.setTimeoutDuration("1 minute")
            updatedFoo
          }
        }
      
        def concatFoo(a: FooWithDate, b: FooWithDate): FooWithDate = FooWithDate(b.date, a.s + b.s, a.i + b.i)
      }
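
      For reference, the expected rows follow directly from the inputs and concatFoo above (worked out by hand):

      // Group i = 1: one input                            -> FooWithDate(date, "Foo", 1)
      // Group i = 3: (date,"Foo",3) concat (date,"Foo",3) -> FooWithDate(date, "FooFoo", 6)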
      
      

      The test output shows the invalid date returned from state:

      org.scalatest.exceptions.TestFailedException:
      Array(FooWithDate(2021-02-02T19:26:23.374Z,Foo,1), FooWithDate(2021-02-02T19:26:23.374Z,FooFoo,6)) did not contain the same elements as
      Array(FooWithDate(2020-01-02T03:04:05.006Z,Foo,1), FooWithDate(2020-01-02T03:04:05.006Z,FooFoo,6))
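
      A possible workaround until the underlying issue is fixed (a sketch on our side, not verified against this exact repro): keep only Catalyst-native primitives in the state class so the state encoder never goes through the UDT, converting at the boundaries of the state function:

      // Hypothetical state class storing epoch millis instead of a Joda DateTime.
      case class FooState(dateMillis: Long, s: String, i: Int)

      def toState(f: FooWithDate): FooState = FooState(f.date.getMillis, f.s, f.i)
      def fromState(st: FooState): FooWithDate =
        FooWithDate(new DateTime(st.dateMillis, DateTimeZone.UTC), st.s, st.i)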

People

    • Assignee: Unassigned
    • Reporter: Bryan Jeffrey (bryan.jeffrey@gmail.com)
    • Votes: 0
    • Watchers: 2
