Description
When creating a custom data source with the Data Source API V2, the computation of possible scan reuses appears to be broken when the same data source is used twice in a query but configured with different options. If both scans produce the same schema (which is always the case for count queries with column pruning enabled), the optimizer reuses the scan produced by one of the data source instances for both branches of the query.
This can lead to wrong results whenever the configuration options influence the returned data in any way.
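To see why the schemas always match for a count query, note that column pruning pushes down an empty required schema (no columns are needed for the count), so every reader instance ends up reporting the same empty read schema, regardless of the options it was created with. A minimal sketch of just this schema part, using the AdvancedDataSourceV2 class from the reproducer below (the variable names are only illustrative):

import java.util.Collections

import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.reader.SupportsPushDownRequiredColumns
import org.apache.spark.sql.types.StructType

val source = new AdvancedDataSourceV2

// Two readers of the same source, configured with different "rows" options.
val r100 = source.createReader(new DataSourceOptions(Collections.singletonMap("rows", "100")))
val r10 = source.createReader(new DataSourceOptions(Collections.singletonMap("rows", "10")))

// For a plain count, column pruning pushes down an empty required schema ...
Seq(r100, r10).foreach {
  case r: SupportsPushDownRequiredColumns => r.pruneColumns(new StructType())
}

// ... so both readers now report identical (empty) schemas, even though they
// would return different numbers of rows.
assert(r100.readSchema() == r10.readSchema())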
The behavior can be reproduced with the following example:
import java.util.{List => JList}

import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{Row, SparkSession}

class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport {

  class Reader(rowCount: Int) extends DataSourceReader with SupportsPushDownRequiredColumns {

    var requiredSchema = new StructType().add("i", "int").add("j", "int")

    override def pruneColumns(requiredSchema: StructType): Unit = {
      this.requiredSchema = requiredSchema
    }

    override def readSchema(): StructType = {
      requiredSchema
    }

    override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = {
      val res = new java.util.ArrayList[DataReaderFactory[Row]]
      res.add(new AdvancedDataReaderFactory(0, 5, requiredSchema))
      res.add(new AdvancedDataReaderFactory(5, rowCount, requiredSchema))
      res
    }
  }

  // The "rows" option controls how many rows the source returns.
  override def createReader(options: DataSourceOptions): DataSourceReader =
    new Reader(options.get("rows").orElse("10").toInt)
}

class AdvancedDataReaderFactory(start: Int, end: Int, requiredSchema: StructType)
  extends DataReaderFactory[Row] with DataReader[Row] {

  private var current = start - 1

  override def createDataReader(): DataReader[Row] = {
    new AdvancedDataReaderFactory(start, end, requiredSchema)
  }

  override def close(): Unit = {}

  override def next(): Boolean = {
    current += 1
    current < end
  }

  override def get(): Row = {
    val values = requiredSchema.map(_.name).map {
      case "i" => current
      case "j" => -current
    }
    Row.fromSeq(values)
  }
}

object DataSourceTest extends App {

  val spark = SparkSession.builder().master("local[*]").getOrCreate()

  val cls = classOf[AdvancedDataSourceV2]

  // The same source, once configured to return 100 rows and once to return 10 rows.
  val with100 = spark.read.format(cls.getName).option("rows", 100).load()
  val with10 = spark.read.format(cls.getName).option("rows", 10).load()

  // Expected: 100 + 10 = 110 rows. With the faulty scan reuse, both union
  // branches read from the same instance and the assertion fails.
  assert(with100.union(with10).count == 110)
}
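On an affected build the assertion fails because both branches of the union end up backed by the same reused scan. One way to observe this, reusing the with100 and with10 datasets from the reproducer (just a diagnostic sketch; the combined name is illustrative and the exact plan output depends on the Spark version):

val combined = with100.union(with10)

// With the faulty reuse, both branches of the union reference the same scan.
combined.explain(true)

// Expected 110; with the bug, the result is twice the row count of whichever
// instance the optimizer happened to keep.
println(combined.count())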