[Spark SQL]: Dataframe group by potential bug (Scala)

ludwiggj
This is using Spark 2.4.4 with the Scala API. I'm getting some very strange
behaviour after reading a dataframe from a JSON file using sparkSession.read in
permissive mode. I've included the corrupt record column (_corrupt_record) when
reading the data, as I want to log details of any errors in the input JSON file.

My suspicion is that I've found a bug in Spark, though I'm happy to be
wrong. I can't find any reference to this issue online.

Given this schema:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

val salesSchema = StructType(Seq(
      StructField("shopId", LongType, nullable = false),
      StructField("game", StringType, nullable = false),
      StructField("sales", LongType, nullable = false),
      StructField("_corrupt_record", StringType)
))

I'm reading in this file:

{"shopId": 1, "game":  "Monopoly", "sales": 60}
{"shopId": 1, "game":  "Cleudo", "sales": 25}
{"shopId": 2, "game":  "Monopoly", "sales": 40}
{"shopId": "err", "game":  "Cleudo", "sales": 75}

Note that the last line has a deliberate error in the shopId field: its value
is the string "err" rather than a number.

I read in the data:

val inputDataDF = sparkSession.read
      .schema(salesSchema)
      .option("mode", "PERMISSIVE")
      .json(filePath)

On displaying it:

+------+--------+-----+-------------------------------------------------+
|shopId|game    |sales|_corrupt_record                                  |
+------+--------+-----+-------------------------------------------------+
|1     |Monopoly|60   |null                                             |
|1     |Cleudo  |25   |null                                             |
|2     |Monopoly|40   |null                                             |
|null  |null    |null |{"shopId": "err", "game":  "Cleudo", "sales": 75}|
+------+--------+-----+-------------------------------------------------+

I then filter out the failures:

val validSales = inputDataDF.filter(col("_corrupt_record").isNull)

I then group by game to count the records and sum the sales:

val incorrectReportDF = validSales.groupBy("game")
      .agg(
        count(col("game")),
        sum(col("sales")) as "salesTotal"
      ).sort("game")

The result is incorrect:

+--------+-----------+----------+
|game    |count(game)|salesTotal|
+--------+-----------+----------+
|Cleudo  |2          |100       |
|Monopoly|2          |100       |
+--------+-----------+----------+

The Cleudo total should be 25, but the count column shows that the erroneous
record has been included. Since that record's sales are 75, the total comes out
as 25 + 75 = 100. (The Monopoly total of 60 + 40 = 100 is correct.)
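A hedged way to see which columns the JSON reader actually parses for this query is to inspect the physical plan. In Spark 2.4 the file scan node reports a ReadSchema; if shopId is missing from it, the reader presumably never converts that field, which would explain why the "err" value does not mark the row as corrupt. The sketch below rebuilds the aggregation above and returns the plan text; PlanInspection and physicalPlan are hypothetical names, not part of Spark.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, count, sum}
import org.apache.spark.sql.types._

// Hypothetical diagnostic helper: rebuilds the aggregation from the post and
// returns the text of its physical plan so the scanned columns can be checked.
object PlanInspection {
  val salesSchema: StructType = StructType(Seq(
    StructField("shopId", LongType, nullable = false),
    StructField("game", StringType, nullable = false),
    StructField("sales", LongType, nullable = false),
    StructField("_corrupt_record", StringType)
  ))

  def physicalPlan(spark: SparkSession, filePath: String): String = {
    val validSales = spark.read
      .schema(salesSchema)
      .option("mode", "PERMISSIVE")
      .json(filePath)
      .filter(col("_corrupt_record").isNull)

    val report = validSales.groupBy("game")
      .agg(count(col("game")), sum(col("sales")) as "salesTotal")
      .sort("game")

    // Roughly what report.explain() prints; look for "ReadSchema" on the FileScan line.
    report.queryExecution.executedPlan.toString
  }
}
```

Running the same check on the version that uses collect_list(struct("*")) should, if this explanation holds, show shopId in the ReadSchema.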

If I change the groupBy statement to collect all the records contributing
to each group, I get a different result:

val reportDF = validSales.groupBy("game")
      .agg(
        count(col("game")),
        sum(col("sales")) as "salesTotal",
        collect_list(struct("*")).as("allRecords")
      ).sort("game")

+--------+-----------+----------+----------------------------------------+
|game    |count(game)|salesTotal|allRecords                              |
+--------+-----------+----------+----------------------------------------+
|Cleudo  |1          |25        |[[1, Cleudo, 25,]]                      |
|Monopoly|2          |100       |[[1, Monopoly, 60,], [2, Monopoly, 40,]]|
+--------+-----------+----------+----------------------------------------+

The salesTotal values are now correct. However, if I process this dataframe
further, for example by dropping the allRecords column or converting it to a
Dataset of a simple case class, the salesTotal values revert to the incorrect
ones.
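One plausible explanation, offered as an assumption rather than a confirmed diagnosis, is that Spark's JSON reader only parses the columns a query needs. A type error in a column the query never reads then cannot populate _corrupt_record. collect_list(struct("*")) references every column, forcing shopId to be parsed, while dropping allRecords or mapping to a narrower case class would let the optimizer prune shopId again. The toy model below is plain Scala, not Spark code, and all of its names are hypothetical; it only mimics that behaviour on the four sample lines.

```scala
import scala.util.Try

// Toy model only (hypothetical names, not Spark's implementation): a permissive
// "parser" that converts only the fields a query requests, mimicking the column
// pruning that Spark 2.4's JSON reader appears to apply.
object PrunedParsingModel {
  // One input line, already split into field name -> raw value text.
  type RawRecord = Map[String, String]

  // Fields declared as LongType in the post's schema; game is a string.
  val longFields: Set[String] = Set("shopId", "sales")

  // The four lines of the sample file.
  val input: Seq[RawRecord] = Seq(
    Map("shopId" -> "1", "game" -> "Monopoly", "sales" -> "60"),
    Map("shopId" -> "1", "game" -> "Cleudo", "sales" -> "25"),
    Map("shopId" -> "2", "game" -> "Monopoly", "sales" -> "40"),
    Map("shopId" -> "err", "game" -> "Cleudo", "sales" -> "75")
  )

  // A record is flagged as corrupt only if a requested Long field fails to
  // convert; fields the query does not request are never inspected.
  def isCorrupt(record: RawRecord, requested: Set[String]): Boolean =
    requested.intersect(longFields).exists { field =>
      record.get(field).exists(raw => Try(raw.toLong).isFailure)
    }

  // Drops flagged records (like the _corrupt_record isNull filter),
  // then sums sales per game.
  def salesTotals(requested: Set[String]): Map[String, Long] =
    input
      .filterNot(record => isCorrupt(record, requested))
      .groupBy(record => record("game"))
      .map { case (game, records) => game -> records.map(_("sales").toLong).sum }
}
```

Requesting only game and sales, as the original aggregation does, gives Cleudo -> 100 and Monopoly -> 100; adding shopId to the requested set, as struct("*") effectively does, gives Cleudo -> 25 and Monopoly -> 100.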

The only reliable way I've found to handle this is to explode the allRecords
column and then group the resulting records again.

In a single statement:

import sparkSession.implicits._ // enables the $"..." column syntax

val allInOneReport = validSales.groupBy("game")
      .agg(
        collect_list(struct("*")).as("allRecords")
      )
      .select(explode($"allRecords"))
      .select($"col.game", $"col.sales")
      .groupBy("game")
      .agg(
        sum(col("sales")) as "salesTotal"
      )
      .sort("game")

+--------+----------+
|game    |salesTotal|
+--------+----------+
|Cleudo  |25        |
|Monopoly|100       |
+--------+----------+
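One plausible cause, an assumption not confirmed in this thread, is that the JSON reader skips parsing shopId unless a query references it. If so, a simpler alternative to the explode-and-regroup approach may be to cache the parsed data before filtering: the cached copy is built from all four columns, so shopId is always converted and the bad row is flagged. Spark's 2.3 migration guide recommends caching or saving the parsed results for queries that rely on the corrupt record column. This is an untested sketch; CachedSalesReport and report are hypothetical names.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, count, sum}
import org.apache.spark.sql.types._

// Hypothetical helper: same read, filter and aggregation as in the post,
// but with the parsed input cached before any filtering.
object CachedSalesReport {
  val salesSchema: StructType = StructType(Seq(
    StructField("shopId", LongType, nullable = false),
    StructField("game", StringType, nullable = false),
    StructField("sales", LongType, nullable = false),
    StructField("_corrupt_record", StringType)
  ))

  def report(spark: SparkSession, filePath: String): DataFrame = {
    // Caching materializes every column of the parsed file, so later queries
    // read rows whose _corrupt_record was set with shopId taken into account.
    val inputDataDF = spark.read
      .schema(salesSchema)
      .option("mode", "PERMISSIVE")
      .json(filePath)
      .cache()

    inputDataDF
      .filter(col("_corrupt_record").isNull)
      .groupBy("game")
      .agg(count(col("game")), sum(col("sales")) as "salesTotal")
      .sort("game")
  }
}
```

The trade-off is that the whole parsed file is kept by the cache (Dataset.cache uses the MEMORY_AND_DISK storage level); for large inputs, saving the parsed result first may be preferable.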

I've created a gist (https://gist.github.com/ludwiggj/1fc3ac09ca698e22143e824c683e2394) with all the code and the output.

Thanks,

Graeme.

--
Sent from: http://apache-spark-user-list.1001560.n3.nabble.com/