Spark ml how to extract split points from trained decision tree model


Spark ml how to extract split points from trained decision tree model

AaronLee-2
I am following the official Spark 2.4.3 tutorial
<https://spark.apache.org/docs/2.4.3/ml-classification-regression.html#decision-tree-classifier>
and trained a decision tree model. How do I extract the split points from the
trained model?

// model
val dt = new DecisionTreeClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")
  .setMaxBins(10)

// Train model.  This also runs the indexers.
val dtm = dt.fit(trainingData)

// extract bin split points
how to do it                   <- ?



--
Sent from: http://apache-spark-user-list.1001560.n3.nabble.com/



Re: Spark ml how to extract split points from trained decision tree model

srowen
You should be able to look at dtm.rootNode and, treating it as an InternalNode, get the .split from it
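A minimal sketch of that idea, assuming `dtm` is the trained `DecisionTreeClassificationModel` from the question. Pattern matching instead of a hard cast avoids a `ClassCastException` if the root happens to be a leaf:

```scala
import org.apache.spark.ml.tree.{CategoricalSplit, ContinuousSplit, InternalNode}

dtm.rootNode match {
  case n: InternalNode =>
    n.split match {
      // a continuous feature carries a single threshold
      case cs: ContinuousSplit =>
        println(s"root splits on feature ${cs.featureIndex} at threshold ${cs.threshold}")
      // a categorical feature carries the categories sent to the left child
      case cat: CategoricalSplit =>
        println(s"root splits on feature ${cat.featureIndex}, " +
          s"left categories = ${cat.leftCategories.mkString(",")}")
    }
  case leaf =>
    println(s"no splits: the root is a leaf ($leaf)")
}
```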

On Thu, Jun 11, 2020 at 7:02 PM AaronLee <[hidden email]> wrote:


Re: Spark ml how to extract split points from trained decision tree model

AaronLee-2
Thanks srowen. I also checked
https://www.programcreek.com/scala/org.apache.spark.ml.tree.InternalNode.
Splits are available via the `InternalNode.split` attribute, but
`dtm.rootNode` is a `LeafNode`:

```
scala> dtm.rootNode
res9: org.apache.spark.ml.tree.Node = LeafNode(prediction = 0.0, impurity =
0.3153051824490453)

scala> dtm.rootNode.
impurity   prediction

scala> dtm.rootNode.getClass.getSimpleName
res13: String = LeafNode

scala> import org.apache.spark.ml.tree.{InternalNode, LeafNode, Node}
import org.apache.spark.ml.tree.{InternalNode, LeafNode, Node}

scala> val intnode = dtm.rootNode.asInstanceOf[InternalNode]
java.lang.ClassCastException: org.apache.spark.ml.tree.LeafNode cannot be
cast to org.apache.spark.ml.tree.InternalNode
  ... 51 elided

```
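For reference, here is a sketch that walks an entire tree and collects every continuous-split threshold, again assuming `dtm` is the trained model. On a depth-0 model like the one above it simply returns an empty list instead of throwing:

```scala
import org.apache.spark.ml.tree.{ContinuousSplit, InternalNode, Node}

// recursively gather (featureIndex, threshold) pairs from all internal nodes
def collectThresholds(node: Node): Seq[(Int, Double)] = node match {
  case n: InternalNode =>
    val here = n.split match {
      case cs: ContinuousSplit => Seq((cs.featureIndex, cs.threshold))
      case _                   => Seq.empty // categorical split: no threshold
    }
    here ++ collectThresholds(n.leftChild) ++ collectThresholds(n.rightChild)
  case _ => Seq.empty // LeafNode terminates the recursion
}

collectThresholds(dtm.rootNode).foreach { case (f, t) =>
  println(s"feature $f <= $t")
}
```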





Re: Spark ml how to extract split points from trained decision tree model

srowen
Hm, the root is a leaf? It's possible, but that means there are no splits; if it's a toy example, that could be.
This was just off the top of my head from looking at the code, so I could be missing something, but a non-trivial tree should start with an InternalNode.
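Two quick checks make this visible without any casting; both are on the model itself rather than its nodes:

```scala
// a depth-0 model is a single leaf and has no splits to extract
println(s"depth = ${dtm.depth}, numNodes = ${dtm.numNodes}")

// prints the full tree as text, including every split condition
println(dtm.toDebugString)
```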

On Thu, Jun 11, 2020 at 11:01 PM AaronLee <[hidden email]> wrote:


Re: Spark ml how to extract split points from trained decision tree model

AaronLee-2
@srowen, you are totally right: the model was not trained correctly. But it
is weird, as the dataset I used actually has ~50M rows, a binary label with
20% positive, and 1 feature in the feature vector. I do not understand why
it does not train correctly.


```
scala> df2.count
res56: Long = 48174858

scala> df2.show
+--------------------+-----+
|            features|label|
+--------------------+-----+
|              [14.0]|  1.0|
|               [2.0]|  0.0|
|               [2.0]|  0.0|
|               [1.0]|  1.0|
|[0.9700000286102295]|  1.0|
|[1.9600000381469727]|  0.0|
|[0.9900000095367432]|  0.0|
|[11.739999771118164]|  1.0|
|               [1.0]|  0.0|
|[0.9800000190734863]|  0.0|
|               [5.0]|  0.0|
| [5.940000057220459]|  1.0|
|              [11.0]|  0.0|
|               [4.0]|  0.0|
|               [1.0]|  1.0|
|[1.9700000286102295]|  0.0|
| [6.989999771118164]|  0.0|
|[0.9700000286102295]|  0.0|
|[0.9700000286102295]|  0.0|
|[0.9900000095367432]|  0.0|
+--------------------+-----+
only showing top 20 rows


scala> df2.printSchema
root
 |-- features: vector (nullable = true)
 |-- label: double (nullable = true)

scala> val dt = new
DecisionTreeClassifier().setLabelCol("label").setFeaturesCol("features").setMaxBins(10)
dt: org.apache.spark.ml.classification.DecisionTreeClassifier =
dtc_2b6b6e170840

scala>  val dtm = dt.fit(df2)
dtm: org.apache.spark.ml.classification.DecisionTreeClassificationModel =
DecisionTreeClassificationModel (uid=dtc_2b6b6e170840) of depth 0 with 1
nodes

scala> val df3 = dtm.transform(df2)
df3: org.apache.spark.sql.DataFrame = [features: vector, label: double ... 3
more fields]

scala>  df3.show(100,false)
+--------------------+-----+----------------------+----------------------------------------+----------+
|features            |label|rawPrediction         |probability                             |prediction|
+--------------------+-----+----------------------+----------------------------------------+----------+
|[14.0]              |1.0  |[3.872715E7,9447708.0]|[0.8038871645454565,0.19611283545454353]|0.0       |
|[2.0]               |0.0  |[3.872715E7,9447708.0]|[0.8038871645454565,0.19611283545454353]|0.0       |
|[2.0]               |0.0  |[3.872715E7,9447708.0]|[0.8038871645454565,0.19611283545454353]|0.0       |
|[1.0]               |1.0  |[3.872715E7,9447708.0]|[0.8038871645454565,0.19611283545454353]|0.0       |
|[0.9700000286102295]|1.0  |[3.872715E7,9447708.0]|[0.8038871645454565,0.19611283545454353]|0.0       |
|[1.9600000381469727]|0.0  |[3.872715E7,9447708.0]|[0.8038871645454565,0.19611283545454353]|0.0       |
|[0.9900000095367432]|0.0  |[3.872715E7,9447708.0]|[0.8038871645454565,0.19611283545454353]|0.0       |
....
```
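For what it's worth, a depth-0 tree can come from the stopping criteria as well as from the data itself. A sketch of the knobs worth ruling out, with the Spark default values written out explicitly (this is a checklist, not a diagnosis of this particular dataset):

```scala
import org.apache.spark.ml.classification.DecisionTreeClassifier

val dt = new DecisionTreeClassifier()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setMaxBins(10)
  .setMaxDepth(5)              // default; 0 would force a single leaf
  .setMinInfoGain(0.0)         // default; a positive value suppresses weak splits
  .setMinInstancesPerNode(1)   // default; large values prevent splitting
```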






Re: Spark ml how to extract split points from trained decision tree model

AaronLee-2
Instead of continuing to explore and debug, I switched to the sklearn
decision tree in the end ... lol


