Uploaded image for project: 'Spark'
  1. Spark
  2. SPARK-20307

SparkR: pass on setHandleInvalid to spark.mllib functions that use StringIndexer

    Details

    • Type: Improvement
    • Status: Resolved
    • Priority: Minor
    • Resolution: Fixed
    • Affects Version/s: 2.1.0
    • Fix Version/s: 2.3.0
    • Component/s: SparkR
    • Labels:
      None
    • Target Version/s:

      Description

      when training a model in SparkR with string variables (tested with spark.randomForest, but i assume is valid for all spark.xx functions that apply a StringIndexer under the hood), testing on a new dataset with factor levels that are not in the training set will throw an "Unseen label" error.
      I think this can be solved if there's a method to pass setHandleInvalid on to the StringIndexers when calling spark.randomForest.

      code snippet:

      # (i've run this in Zeppelin which already has SparkR and the context loaded)
      #library(SparkR)
      #sparkR.session(master = "local[*]") 
      data = data.frame(clicked = base::sample(c(0,1),100,replace=TRUE),
                                    someString = base::sample(c("this", "that"), 100, replace=TRUE), stringsAsFactors=FALSE)
      trainidxs = base::sample(nrow(data), nrow(data)*0.7)
      traindf = as.DataFrame(data[trainidxs,])
      testdf = as.DataFrame(rbind(data[-trainidxs,],c(0,"the other")))
      
      rf = spark.randomForest(traindf, clicked~., type="classification", 
                              maxDepth=10, 
                              maxBins=41,
                              numTrees = 100)
      
      predictions = predict(rf, testdf)
      SparkR::collect(predictions)    
      

      stack trace:

      Error in handleErrors(returnStatus, conn): org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 607.0 failed 1 times, most recent failure: Lost task 0.0 in stage 607.0 (TID 1581, localhost, executor driver): org.apache.spark.SparkException: Failed to execute user defined function($anonfun$4: (string) => double)
      at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
      at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
      at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:377)
      at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:231)
      at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:225)
      at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:826)
      at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:826)
      at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
      at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
      at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
      at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
      at org.apache.spark.scheduler.Task.run(Task.scala:99)
      at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)
      at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
      at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
      at java.lang.Thread.run(Thread.java:745)
      Caused by: org.apache.spark.SparkException: Unseen label: the other.
      at org.apache.spark.ml.feature.StringIndexerModel$$anonfun$4.apply(StringIndexer.scala:170)
      at org.apache.spark.ml.feature.StringIndexerModel$$anonfun$4.apply(StringIndexer.scala:166)
      ... 16 more

      Driver stacktrace:
      at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1435)
      at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1423)
      at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1422)
      at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
      at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
      at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1422)
      at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:802)
      at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:802)
      at scala.Option.foreach(Option.scala:257)
      at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:802)
      at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:1650)
      at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1605)
      at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1594)
      at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:48)
      at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:628)
      at org.apache.spark.SparkContext.runJob(SparkContext.scala:1918)
      at org.apache.spark.SparkContext.runJob(SparkContext.scala:1931)
      at org.apache.spark.SparkContext.runJob(SparkContext.scala:1944)
      at org.apache.spark.SparkContext.runJob(SparkContext.scala:1958)
      at org.apache.spark.rdd.RDD$$anonfun$collect$1.apply(RDD.scala:935)
      at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
      at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
      at org.apache.spark.rdd.RDD.withScope(RDD.scala:362)
      at org.apache.spark.rdd.RDD.collect(RDD.scala:934)
      at org.apache.spark.sql.execution.SparkPlan.executeCollect(SparkPlan.scala:275)
      at org.apache.spark.sql.Dataset$$anonfun$org$apache$spark$sql$Dataset$$execute$1$1.apply(Dataset.scala:2371)
      at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:57)
      at org.apache.spark.sql.Dataset.withNewExecutionId(Dataset.scala:2765)
      at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$execute$1(Dataset.scala:2370)
      at org.apache.spark.sql.Dataset$$anonfun$org$apache$spark$sql$Dataset$$collect$1.apply(Dataset.scala:2375)
      at org.apache.spark.sql.Dataset$$anonfun$org$apache$spark$sql$Dataset$$collect$1.apply(Dataset.scala:2375)
      at org.apache.spark.sql.Dataset.withCallback(Dataset.scala:2778)
      at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collect(Dataset.scala:2375)
      at org.apache.spark.sql.Dataset.collect(Dataset.scala:2351)
      at org.apache.spark.sql.api.r.SQLUtils$.dfToCols(SQLUtils.scala:208)
      at org.apache.spark.sql.api.r.SQLUtils.dfToCols(SQLUtils.scala)
      at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
      at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
      at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
      at java.lang.reflect.Method.invoke(Method.java:498)
      at org.apache.spark.api.r.RBackendHandler.handleMethodCall(RBackendHandler.scala:167)
      at org.apache.spark.api.r.RBackendHandler.channelRead0(RBackendHandler.scala:108)
      at org.apache.spark.api.r.RBackendHandler.channelRead0(RBackendHandler.scala:40)
      at io.netty.channel.SimpleChannelInboundHandler.channelRead(SimpleChannelInboundHandler.java:105)
      at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:367)
      at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:353)
      at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:346)
      at io.netty.handler.timeout.IdleStateHandler.channelRead(IdleStateHandler.java:266)
      at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:367)
      at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:353)
      at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:346)
      at io.netty.handler.codec.MessageToMessageDecoder.channelRead(MessageToMessageDecoder.java:102)
      at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:367)
      at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:353)
      at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:346)
      at io.netty.handler.codec.ByteToMessageDecoder.fireChannelRead(ByteToMessageDecoder.java:293)
      at io.netty.handler.codec.ByteToMessageDecoder.channelRead(ByteToMessageDecoder.java:267)
      at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:367)
      at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:353)
      at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:346)
      at io.netty.channel.DefaultChannelPipeline$HeadContext.channelRead(DefaultChannelPipeline.java:1294)
      at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:367)
      at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:353)
      at io.netty.channel.DefaultChannelPipeline.fireChannelRead(DefaultChannelPipeline.java:911)
      at io.netty.channel.nio.AbstractNioByteChannel$NioByteUnsafe.read(AbstractNioByteChannel.java:131)
      at io.netty.channel.nio.NioEventLoop.processSelectedKey(NioEventLoop.java:652)
      at io.netty.channel.nio.NioEventLoop.processSelectedKeysOptimized(NioEventLoop.java:575)
      at io.netty.channel.nio.NioEventLoop.processSelectedKeys(NioEventLoop.java:489)
      at io.netty.channel.nio.NioEventLoop.run(NioEventLoop.java:451)
      at io.netty.util.concurrent.SingleThreadEventExecutor$2.run(SingleThreadEventExecutor.java:140)
      at io.netty.util.concurrent.DefaultThreadFactory$DefaultRunnableDecorator.run(DefaultThreadFactory.java:144)
      at java

        Attachments

          Issue Links

            Activity

              People

              • Assignee:
                wm624 Miao Wang
                Reporter:
                alrutten Anne Rutten
              • Votes:
                0 Vote for this issue
                Watchers:
                6 Start watching this issue

                Dates

                • Created:
                  Updated:
                  Resolved: