diff --git a/pyzoo/test/zoo/orca/learn/jep/test_pytorch_estimator_for_spark.py b/pyzoo/test/zoo/orca/learn/jep/test_pytorch_estimator_for_spark.py index 9b01ce0afb..9ce108a982 100644 --- a/pyzoo/test/zoo/orca/learn/jep/test_pytorch_estimator_for_spark.py +++ b/pyzoo/test/zoo/orca/learn/jep/test_pytorch_estimator_for_spark.py @@ -22,6 +22,7 @@ from zoo.orca import init_orca_context, stop_orca_context from zoo.orca.data.pandas import read_csv +from zoo.orca.data import SparkXShards from zoo.orca.learn.pytorch import Estimator from zoo.orca.learn.metrics import Accuracy from zoo.orca.learn.trigger import EveryEpoch @@ -68,6 +69,10 @@ def transform(df): } return result + def transform_del_y(d): + result = {"x": d["x"]} + return result + OrcaContext.pandas_read_backend = "pandas" file_path = os.path.join(resource_path, "orca/learn/ncf.csv") data_shard = read_csv(file_path) @@ -84,6 +89,13 @@ def transform(df): est2.fit(data=data_shard, epochs=8, batch_size=2, validation_data=data_shard, validation_methods=[Accuracy()], checkpoint_trigger=EveryEpoch()) est2.evaluate(data_shard, validation_methods=[Accuracy()], batch_size=2) + pred_result = est2.predict(data_shard) + pred_c = pred_result.collect() + assert(pred_result, SparkXShards) + pred_shard = data_shard.transform_shard(transform_del_y) + pred_result2 = est2.predict(pred_shard) + pred_c_2 = pred_result2.collect() + assert (pred_c[0]["prediction"] == pred_c_2[0]["prediction"]).all() if __name__ == "__main__": diff --git a/pyzoo/zoo/orca/data/utils.py b/pyzoo/zoo/orca/data/utils.py index 81031b7a88..35ef1e0b00 100644 --- a/pyzoo/zoo/orca/data/utils.py +++ b/pyzoo/zoo/orca/data/utils.py @@ -194,8 +194,11 @@ def to_sample(data): from bigdl.util.common import Sample data = check_type_and_convert(data, allow_list=True, allow_tuple=False) features = data["x"] - labels = data["y"] length = features[0].shape[0] + if "y" in data: + labels = data["y"] + else: + labels = np.array([[-1] * length]) for i in range(length): fs = [feat[i] for feat in features] diff --git a/pyzoo/zoo/orca/learn/pytorch/estimator.py b/pyzoo/zoo/orca/learn/pytorch/estimator.py index 9130682165..582273c6a2 100644 --- a/pyzoo/zoo/orca/learn/pytorch/estimator.py +++ b/pyzoo/zoo/orca/learn/pytorch/estimator.py @@ -251,8 +251,16 @@ def fit(self, data, epochs=1, batch_size=32, validation_data=None, validation_me "callable data_creators but get " + data.__class__.__name__) return self - def predict(self, data, **kwargs): - pass + def predict(self, data, batch_size=4): + from zoo.orca.learn.utils import convert_predict_to_xshard + if isinstance(data, SparkXShards): + from zoo.orca.data.utils import to_sample + data_rdd = data.rdd.flatMap(to_sample) + else: + raise ValueError("Data should be XShards, each element needs to be {'x': a feature " + "numpy array}.") + predicted_rdd = self.model.predict(data_rdd, batch_size=batch_size) + return convert_predict_to_xshard(predicted_rdd) def evaluate(self, data, validation_methods=None, batch_size=32): from zoo.orca.data.utils import to_sample