SNOW-1788391 Validation Error When Converting Pandas String Column to NumPy Array in Prediction Function #123
Comments
Hi @kenkoooo, could you kindly provide a repro code snippet?
When you register a model as in the following example, the registered model accepts a DataFrame with a column of string type:

    import snowflake.snowpark as snowpark
    from snowflake.ml.model import custom_model
    import pandas as pd
    from snowflake.ml.model.model_signature import ModelSignature, FeatureSpec, DataType
    from snowflake.ml.registry import Registry


    class MyCustomModel(custom_model.CustomModel):
        def __init__(self, context: custom_model.ModelContext) -> None:
            super().__init__(context)

        @custom_model.inference_api
        def predict(self, X: pd.DataFrame) -> pd.DataFrame:
            # Echo the input back unchanged
            return X


    def main(session: snowpark.Session):
        mc = custom_model.ModelContext()
        model = MyCustomModel(mc)
        # Both input and output are a single STRING column named COL
        signature = ModelSignature(
            inputs=[FeatureSpec(name="COL", dtype=DataType.STRING)],
            outputs=[FeatureSpec(name="COL", dtype=DataType.STRING)],
        )
        reg = Registry(session=session)
        reg.log_model(model, model_name="MY_COOL_MODEL", signatures={"predict": signature})
        return session.create_dataframe([["OK"]])

You can use this model like this:

    import snowflake.snowpark as snowpark
    import pandas as pd
    from snowflake.ml.registry import Registry


    def main(session: snowpark.Session):
        reg = Registry(session=session)
        mv = reg.get_model("MY_COOL_MODEL").last()
        X = pd.DataFrame({"COL": ["A", "B", "C"]})
        res = mv.run(X, function_name="PREDICT", strict_input_validation=False)
        return session.create_dataframe(res)

This works as expected when the column keeps its default dtype. However, if the column is explicitly converted to the pandas `string` dtype, as below:

    import snowflake.snowpark as snowpark
    import pandas as pd
    from snowflake.ml.registry import Registry


    def main(session: snowpark.Session):
        reg = Registry(session=session)
        mv = reg.get_model("MY_COOL_MODEL").last()
        X = pd.DataFrame({"COL": ["A", "B", "C"]})
        # Ensure that the column is of type string
        X["COL"] = X["COL"].astype("string")
        res = mv.run(X, function_name="PREDICT", strict_input_validation=False)
        return session.create_dataframe(res)

With this change, `mv.run` fails with a validation error.
Hi @kenkoooo, thank you for reporting this issue. We believe this happens because `pandas.StringDtype` is not yet supported on our side. We will add support soon; until then, you could work around it by …
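The workaround in the comment above is cut off, so the following is only one possible approach and not necessarily the one the maintainers had in mind: cast any `pandas.StringDtype` column back to `object` before calling `mv.run`. The helper name `string_columns_to_object` is hypothetical and is not part of `snowflake-ml-python`.

```python
import pandas as pd


def string_columns_to_object(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of df with pandas StringDtype columns cast back to object.

    Hypothetical helper for illustration; assumes column names are unique.
    """
    out = df.copy()
    for col in out.columns:
        if isinstance(out[col].dtype, pd.StringDtype):
            out[col] = out[col].astype(object)
    return out


# Same input as in the failing repro above
X = pd.DataFrame({"COL": ["A", "B", "C"]})
X["COL"] = X["COL"].astype("string")

X_compat = string_columns_to_object(X)
print(X_compat["COL"].dtype)  # object
```

`X_compat` could then be passed to `mv.run(X_compat, function_name="PREDICT", strict_input_validation=False)` exactly as in the working snippet. Whether this avoids the error on a given library version is an assumption to verify.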
@kenkoooo This bug is fixed as of version …
When a Pandas DataFrame containing a string column is passed to the prediction function, it is converted to a NumPy array and then validated. During validation, the column's data type is compared with the type specified in the saved model's signature.
Even if the column's type is `string[python]` in the original Pandas DataFrame, it will be represented as 'O' (object) after conversion to a NumPy array. As a result, `np.can_cast(arr.dtype, feature_type._numpy_type, casting='no')` will return `False`, causing the validation to fail if the first type is 'O' and the second is `np.str_`.
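The dtype mismatch described above can be observed locally without a Snowflake session. This is a minimal sketch that assumes `Series.to_numpy()` stands in for the library's internal conversion step and that `np.str_` is the NumPy type associated with `DataType.STRING`, as stated in the description. The helper name `passes_exact_dtype_check` is made up for illustration.

```python
import numpy as np
import pandas as pd


def passes_exact_dtype_check(arr_dtype: np.dtype, expected_type) -> bool:
    """Mirror the check quoted in the issue: an exact-match cast test."""
    return bool(np.can_cast(arr_dtype, expected_type, casting="no"))


# A column explicitly stored with the pandas string extension dtype
string_col = pd.Series(["A", "B", "C"], dtype="string[python]")

# Converting to NumPy yields an object array, not a NumPy unicode array
arr = string_col.to_numpy()
print(arr.dtype)  # object

# The exact-match check against np.str_ therefore fails
print(passes_exact_dtype_check(arr.dtype, np.str_))  # False
```

The elements of `arr` are still ordinary Python `str` objects; only the array's dtype differs from what the signature check expects.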