How to use XGBoost PySpark API with MLeap? #867

Open
venkatacrc opened this issue Sep 13, 2023 · 4 comments

venkatacrc commented Sep 13, 2023

Problem description:
We were able to serialize the XGBoost model with MLeap using the older PySpark API (dmlc/xgboost#4656) as shown below:

import mleap.pyspark
from mleap.pyspark.spark_support import SimpleSparkSerializer

# model is a fitted model from the old JVM-backed XGBoost4J-Spark wrapper
trans_model = model.transform(df)
local_path = "jar:file:/tmp/pyspark.model.zip"
model.serializeToBundle(local_path, trans_model)

But we are not able to do this with the official PySpark API (https://xgboost.readthedocs.io/en/stable/tutorials/spark_estimator.html). Are there any flags that I need to use to make this work?
We use MLeap to store the model as a serialized bundle and then use it in a Java runtime environment for model serving.
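
Roughly, the code we are trying with the new API looks like the following minimal sketch (the dataset path and column names here are placeholders, not our actual job):

import mleap.pyspark
from mleap.pyspark.spark_support import SimpleSparkSerializer
from pyspark.ml import Pipeline
from pyspark.sql import SparkSession
from xgboost.spark import SparkXGBClassifier

spark = SparkSession.builder.getOrCreate()
df = spark.read.parquet("/tmp/train.parquet")  # placeholder training data

classifier = SparkXGBClassifier(features_col="features", label_col="label")
pipeline = Pipeline(stages=[classifier])
model = pipeline.fit(df)             # PipelineModel whose stage is a SparkXGBClassifierModel
predictions = model.transform(df)

# This is the call that fails with the new API:
model.serializeToBundle("jar:file:/tmp/pyspark.model.zip", predictions)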

Steps Followed:

Download spark

wget https://dlcdn.apache.org/spark/spark-3.3.3/spark-3.3.3-bin-hadoop3.tgz
tar -xvf spark-3.3.3-bin-hadoop3.tgz

Download xgboost jars

wget https://repo1.maven.org/maven2/ml/dmlc/xgboost4j_2.12/1.7.3/xgboost4j_2.12-1.7.3.jar
wget https://repo1.maven.org/maven2/ml/dmlc/xgboost4j-spark_2.12/1.7.3/xgboost4j-spark_2.12-1.7.3.jar

Build MLeap fat jar:

https://github.com/combust/mleap/blob/master/mleap-databricks-runtime-fat/README.md
git clone --recursive https://github.com/combust/mleap.git
cd mleap
git checkout tags/v0.22.0
sbt mleap-databricks-runtime-fat/assembly
cp mleap-databricks-runtime-fat/target/scala-2.12/mleap-databricks-runtime-fat-assembly-0.22.0.jar ../spark-3.3.3-bin-hadoop3/jars

Install python requirements

pip install mleap==0.22.0
pip install xgboost==1.7.3
pip install pyarrow

Running the example

cd ../spark-3.3.3-bin-hadoop3
./bin/spark-submit example.py

error log

23/09/13 22:06:29 INFO CodeGenerator: Code generated in 25.565617 ms
+-----+-----+-----+-----+-----+-----+-----+-----+------+------+------+-----+--------------------+-------------+----------+-----------+
|feat1|feat2|feat3|feat4|feat5|feat7|feat8|feat9|feat10|feat11|feat12|label|            features|rawPrediction|prediction|probability|
+-----+-----+-----+-----+-----+-----+-----+-----+------+------+------+-----+--------------------+-------------+----------+-----------+
|    7|   20|    3|    6|    1|   10|    3|53948|245351|     1|     2|    1|[7.0,20.0,3.0,6.0...|   [-0.0,0.0]|       0.0|  [0.5,0.5]|
|    7|   20|    3|    6|    1|   10|    3|53948|245351|     1|     2|    1|[7.0,20.0,3.0,6.0...|   [-0.0,0.0]|       0.0|  [0.5,0.5]|
|    7|   20|    1|    6|    1|   10|    3|53948|245351|     1|     2|    0|[7.0,20.0,1.0,6.0...|   [-0.0,0.0]|       0.0|  [0.5,0.5]|
|    7|   20|    1|    6|    1|   10|    3|53948|245351|     1|     2|    0|[7.0,20.0,1.0,6.0...|   [-0.0,0.0]|       0.0|  [0.5,0.5]|
|    5|   20|    1|    6|    1|   10|    3|53948|245351|     1|     2|    0|[5.0,20.0,1.0,6.0...|   [-0.0,0.0]|       0.0|  [0.5,0.5]|
|    5|   20|    3|    6|    1|   10|    3|53948|245351|     1|     2|    1|[5.0,20.0,3.0,6.0...|   [-0.0,0.0]|       0.0|  [0.5,0.5]|
+-----+-----+-----+-----+-----+-----+-----+-----+------+------+------+-----+--------------------+-------------+----------+-----------+

Traceback (most recent call last):
File "/Users/s0a018g/local_setup/spark-3.3.3-bin-hadoop3/bin/code1.py", line 47, in
model.serializeToBundle(local_path, predictions)
File "/Users/s0a018g/opt/anaconda3/lib/python3.8/site-packages/mleap/pyspark/spark_support.py", line 25, in serializeToBundle
serializer.serializeToBundle(self, path, dataset=dataset)
File "/Users/s0a018g/opt/anaconda3/lib/python3.8/site-packages/mleap/pyspark/spark_support.py", line 42, in serializeToBundle
self._java_obj.serializeToBundle(transformer._to_java(), path, dataset._jdf)
File "/Users/s0a018g/local_setup/spark-3.3.3-bin-hadoop3/python/lib/pyspark.zip/pyspark/ml/pipeline.py", line 363, in _to_java
AttributeError: 'SparkXGBClassifierModel' object has no attribute '_to_java'
23/09/13 22:06:30 INFO SparkContext: Invoking stop() from shutdown hook

example.py

Thanks @agsachin for creating the instructions that can be easily reproduced on Mac.

venkatacrc (Author) commented

@WeichenXu123 Any insights on this issue? Please help us.

austinzh (Contributor) commented Nov 7, 2023

@venkatacrc MLeap's serializeToBundle works only with Java objects.
The latest XGBoost estimator is a pure Python implementation rather than a Java estimator.

OLD API

class XGboostEstimator(JavaEstimator, XGBoostReadable, JavaMLWritable, ParamGettersSetters):

vs NEW API

class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
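
The difference is easy to check at runtime: the serializer ultimately calls transformer._to_java(), which only JVM-backed stages provide. A hypothetical helper (not part of MLeap) to illustrate:

from pyspark.ml import PipelineModel

def mleap_can_serialize(transformer):
    # Hypothetical check: MLeap's SimpleSparkSerializer ends up calling
    # _to_java() on the transformer and, for a PipelineModel, on every stage,
    # so each stage must be backed by a JVM object.
    if isinstance(transformer, PipelineModel):
        return all(hasattr(s, "_to_java") for s in transformer.stages)
    return hasattr(transformer, "_to_java")

# The old XGboostEstimator produces a Java-backed model that passes this check;
# xgboost.spark's SparkXGBClassifierModel is pure Python and does not, which is
# exactly the AttributeError in the traceback above.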

venkatacrc (Author) commented Nov 7, 2023

@austinzh Thank you. Will there be Java-backed support like the old API in the future?

austinzh (Contributor) commented

Supporting the pure Python implementation of _SparkXGBEstimator would take a lot of groundwork:

  • We would need to re-implement SimpleSparkSerializer to work in Python instead of the JVM; for that we could use MLeapSerializer.
  • We would need to implement a serializer that works with the Python XGBoost model, similar to the existing sklearn support.
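
To make the shape of that groundwork concrete, here is a purely hypothetical skeleton (none of these classes exist in MLeap today; only get_booster() and save_model() are real XGBoost APIs):

from xgboost.spark import SparkXGBClassifierModel


class PythonXGBoostOp:
    """Hypothetical MLeap-style op for the pure-Python XGBoost model,
    analogous to the existing sklearn support."""

    def serialize(self, model: SparkXGBClassifierModel, bundle_dir: str):
        # get_booster() returns the underlying xgboost.Booster that the
        # estimator trained; save_model() writes it in XGBoost's own format.
        booster = model.get_booster()
        booster.save_model(f"{bundle_dir}/xgboost.model")
        # ...MLeap's model.json / node.json metadata would be written here.


class PurePythonSparkSerializer:
    """Hypothetical replacement for SimpleSparkSerializer that walks the
    pipeline in Python instead of calling transformer._to_java()."""

    def serialize_to_bundle(self, transformer, path, dataset):
        # Dispatch each pipeline stage to a Python op such as PythonXGBoostOp.
        raise NotImplementedError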
