Description
I have a Dataset of non-unique identifiers that I pass to Dataset::flatMap() to create multiple rows, one per sub-identifier, for each id. When I run the code below, the limit(2) call is placed after the flatMap() call in the optimized logical plan. As a result, the query unexpectedly returns only 2 rows, whereas I would expect 6 (two distinct ids, each expanded into three subIds).
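```java
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.LongStream;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.catalyst.encoders.RowEncoder;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

// Body of a method returning Dataset<Row>; `context` supplies the SparkSession.

// Input schema: a single non-nullable "id" column.
StructType idSchema = DataTypes.createStructType(List.of(
        DataTypes.createStructField("id", DataTypes.LongType, false)));

// Output schema of the flatMap: the original id plus a generated "subId".
StructType flatMapSchema = DataTypes.createStructType(List.of(
        DataTypes.createStructField("id", DataTypes.LongType, false),
        DataTypes.createStructField("subId", DataTypes.LongType, false)));

Dataset<Row> inputDataset = context.sparkSession().createDataset(
        LongStream.range(0, 5).mapToObj((id) -> RowFactory.create(id)).collect(Collectors.toList()),
        RowEncoder.apply(idSchema));

return inputDataset
        .distinct()
        .limit(2)
        // Expand each remaining id into one row per subId.
        // LongStream.range's upper bound is exclusive, so (6, 9) yields 6, 7, 8,
        // matching the expected output below.
        .flatMap((Row row) -> {
            Long id = row.getLong(row.fieldIndex("id"));
            return LongStream.range(6, 9).mapToObj((subid) -> RowFactory.create(id, subid)).iterator();
        }, RowEncoder.apply(flatMapSchema));
```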
When run, the above code produces something like:
id | subId |
---|---|
0 | 6 |
0 | 7 |
But I would expect something like:
id | subId |
---|---|
1 | 6 |
1 | 7 |
1 | 8 |
0 | 6 |
0 | 7 |
0 | 8 |
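The difference between the two outputs comes down to whether the limit is applied before or after the flatMap. The following plain-Java sketch models both orders with `java.util.stream`, without Spark; the class `LimitOrderSketch` and its methods are hypothetical names used only for illustration, and subIds 6 to 8 are taken from the expected output above. It models the semantics only and says nothing about how Spark's optimizer actually rewrites the plan.

```java
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.LongStream;

/**
 * Hypothetical plain-Java model (no Spark) of the two evaluation orders.
 * Each row is represented as long[]{id, subId}.
 */
public class LimitOrderSketch {

    // Mirrors the flatMap in the reproducer: each id expands to subIds 6, 7, 8.
    static List<long[]> expand(long id) {
        return LongStream.range(6, 9)
                .mapToObj(subId -> new long[] {id, subId})
                .collect(Collectors.toList());
    }

    // Order the code asks for: distinct -> limit(2) -> flatMap.
    static List<long[]> limitThenFlatMap(List<Long> ids) {
        return ids.stream()
                .distinct()
                .limit(2)
                .flatMap(id -> expand(id).stream())
                .collect(Collectors.toList());
    }

    // Order observed in the optimized plan: distinct -> flatMap -> limit(2).
    static List<long[]> flatMapThenLimit(List<Long> ids) {
        return ids.stream()
                .distinct()
                .flatMap(id -> expand(id).stream())
                .limit(2)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<Long> ids = LongStream.range(0, 5).boxed().collect(Collectors.toList());
        System.out.println("limit then flatMap: " + limitThenFlatMap(ids).size() + " rows"); // 6 rows
        System.out.println("flatMap then limit: " + flatMapThenLimit(ids).size() + " rows"); // 2 rows
    }
}
```

Limiting first keeps two ids and expands each into three rows; limiting after the expansion keeps only the first two expanded rows, both belonging to id 0, which is the shape of the output actually observed.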