Load ML Pipeline model with UDF & Custom Transformer on Spark local mode


Hi guys,

I have a special usage scenario where I need to run Spark in local mode and load ML Pipeline models containing custom ML Transformers from the local file system at runtime. Everything works well except for models whose transform method uses a UDF.

I uploaded an example project to my GitHub: https://github.com/ihainan/UDLocalExampleProject. You can run "cd App && sbt run" and see the following error message:

20/07/28 05:22:23 INFO DAGScheduler: ResultStage 4 (show at Application.scala:20) failed in 0.046 s due to Job aborted due to stage failure: Task 0 in stage 4.0 failed 1 times, most recent failure: Lost task 0.0 in stage 4.0 (TID 4, localhost, executor driver): java.lang.ClassCastException: cannot assign instance of scala.collection.immutable.List$SerializationProxy to field org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$dependencies_ of type scala.collection.Seq in instance of org.apache.spark.rdd.MapPartitionsRDD
at java.io.ObjectStreamClass$FieldReflector.setObjFieldValues(ObjectStreamClass.java:2301)
at java.io.ObjectStreamClass.setObjFieldValues(ObjectStreamClass.java:1431)
at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2350)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2268)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2126)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2344)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2268)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2126)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
at java.io.ObjectInputStream.readObject(ObjectInputStream.java:465)
at java.io.ObjectInputStream.readObject(ObjectInputStream.java:423)
at scala.collection.immutable.List$SerializationProxy.readObject(List.scala:479)
at sun.reflect.GeneratedMethodAccessor13.invoke(Unknown Source)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at java.io.ObjectStreamClass.invokeReadObject(ObjectStreamClass.java:1184)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2235)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2126)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2344)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2268)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2126)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2344)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2268)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2126)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1625)
at java.io.ObjectInputStream.readObject(ObjectInputStream.java:465)
at java.io.ObjectInputStream.readObject(ObjectInputStream.java:423)
at org.apache.spark.serializer.JavaDeserializationStream.readObject(JavaSerializer.scala:75)
at org.apache.spark.serializer.JavaSerializerInstance.deserialize(JavaSerializer.scala:114)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:83)
at org.apache.spark.scheduler.Task.run(Task.scala:123)
at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:748)


The implementation of the custom ML transformer is pretty simple:

---
// https://github.com/ihainan/UDLocalExampleProject/blob/master/CustomTransformer/src/main/scala/me/ihainan/test/MyTransformer.scala
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.util.DefaultParamsWritable
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions.{sum, udf}

class MyTransformer(override val uid: String)
  extends Transformer
    with DefaultParamsWritable {

  // ...implement other methods

  override def transform(df: Dataset[_]): org.apache.spark.sql.DataFrame = {
    // create a UDF and use it
    import df.sparkSession.implicits._
    val addOneUDF = udf { in: Int => in + 1 }
    df.groupBy("C3").agg(addOneUDF(sum($"C2")).as("text_counts"))
  }
}
---
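
(The elided methods are the standard Transformer boilerplate — copy and transformSchema — roughly along the lines of the sketch below. This is only illustrative; see the repo for the actual code, and the output schema shown here is an assumption.)

---
// Illustrative sketch of the elided boilerplate (the repo code may differ)
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

override def copy(extra: ParamMap): MyTransformer = defaultCopy(extra)

override def transformSchema(schema: StructType): StructType =
  StructType(Seq(
    StructField("C3", StringType),
    StructField("text_counts", IntegerType)
  ))
---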

And my application looks like:

---
// https://github.com/ihainan/UDLocalExampleProject/blob/master/App/src/main/scala/me/ihainan/test/Application.scala
import java.io.File
import java.net.URLClassLoader
import org.apache.spark.ml.PipelineModel
import org.apache.spark.sql.SparkSession

object Application extends App {
  val spark = SparkSession.builder().master("local[*]").appName("test").getOrCreate()

  // Load custom transformer JAR file during runtime
  val customTransformerURL = new File("./transformer.jar").toURI.toURL
  val newClassLoader = new URLClassLoader(Array(customTransformerURL), this.getClass.getClassLoader)
  Thread.currentThread().setContextClassLoader(newClassLoader)

  // Load model
  val model = PipelineModel.load("./model")

  // Test model
  val df = spark.createDataFrame(Seq(
    (1, 2, "1"),
    (3, 4, "2"),
    (5, 6, "3")
  )).toDF("C1", "C2", "C3")
  model.transform(df).show()
}
---

I have two questions about this issue:

1. The error occurs inside the ResultTask.run method, while it is deserializing the RDDs of the task. My guess is that the task thread can't find the anonymous class generated for the UDF, because if I add the custom transformer JAR to the classpath (i.e. move transformer.jar into the App/lib directory), the application runs fine. Is there a way to make my original application, which loads the JAR at runtime, work?
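
For example, would registering the JAR with Spark itself, in addition to swapping the context class loader, be enough for the task deserializer to resolve the class? Something along these lines (just an untested sketch, not code from the repo):

---
// Untested idea: also register the JAR with the SparkContext, hoping the
// executor-side class loader used for task deserialization picks it up.
val spark = SparkSession.builder().master("local[*]").appName("test").getOrCreate()
spark.sparkContext.addJar("./transformer.jar")

// Keep the runtime class loader for the driver-side model loading code.
val customTransformerURL = new File("./transformer.jar").toURI.toURL
val newClassLoader = new URLClassLoader(Array(customTransformerURL), this.getClass.getClassLoader)
Thread.currentThread().setContextClassLoader(newClassLoader)
---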

2. I also tried implementing the transformer as follows (replacing groupBy & agg with select), and that version works:

---
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.util.DefaultParamsWritable
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions.{col, udf}

class MyTransformer(override val uid: String)
  extends Transformer
    with DefaultParamsWritable {

  // ...implement other methods

  override def transform(df: Dataset[_]): org.apache.spark.sql.DataFrame = {
    // create a UDF and use it
    import df.sparkSession.implicits._
    val addOneString = udf { in: String => in + "1" }
    df.select(col("*"), addOneString($"C3").as("text_counts")) // DON'T use groupBy here
  }
}
---

I noticed that my first transformer produces a DAG with two stages (a ShuffleMapStage and a ResultStage), while the second one has only a ResultStage. As far as I know, both ResultTask and ShuffleMapTask run on threads in the executor's thread pool, so why can only the second application find the UDF class?

[Attached screenshot: Stage 4 failed.]
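
To dig into this, I'm considering logging the class loader from inside the UDF, so I can compare what the task thread sees in the groupBy/agg case versus the select-only case. Just a debugging sketch, not code from the repo:

---
// Debugging sketch: print the class loader visible to the task thread when the
// UDF runs, to compare the groupBy/agg pipeline with the select-only pipeline.
val addOneUDF = udf { in: Int =>
  println(s"UDF running with context class loader: ${Thread.currentThread().getContextClassLoader}")
  in + 1
}
---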

Thanks in advance!

Best Regards,
--
Ji Gao Fu (ihainan)