Skip to content

Commit

Permalink
Improved Enum support
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Feb 17, 2024
1 parent 0d4100c commit f74bc7f
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
21 changes: 19 additions & 2 deletions ext/polars/src/conversion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ use polars::frame::NullStrategy;
use polars::io::avro::AvroCompression;
use polars::prelude::*;
use polars::series::ops::NullBehavior;
use polars_core::utils::arrow::array::Array;
use polars_utils::total_ord::TotalEq;
use smartstring::alias::String as SmartString;

use crate::object::OBJECT_NAME;
use crate::rb_modules::series;
use crate::{RbDataFrame, RbLazyFrame, RbPolarsErr, RbResult, RbSeries, RbTypeError, RbValueError};

pub(crate) fn slice_to_wrapped<T>(slice: &[T]) -> &[Wrap<T>] {
Expand Down Expand Up @@ -80,6 +82,13 @@ pub(crate) fn get_series(obj: Value) -> RbResult<Series> {
Ok(rbs.series.borrow().clone())
}

pub(crate) fn to_series(s: RbSeries) -> Value {
let series = series();
series
.funcall::<_, _, Value>("_from_rbseries", (s,))
.unwrap()
}

impl TryConvert for Wrap<NullValues> {
fn try_convert(ob: Value) -> RbResult<Self> {
if let Ok(s) = String::try_convert(ob) {
Expand Down Expand Up @@ -155,8 +164,16 @@ impl IntoValue for Wrap<DataType> {
}
DataType::Object(_, _) => pl.const_get::<_, Value>("Object").unwrap(),
DataType::Categorical(_, _) => pl.const_get::<_, Value>("Categorical").unwrap(),
DataType::Enum(_, _) => {
todo!()
DataType::Enum(rev_map, _) => {
// we should always have an initialized rev_map coming from rust
let categories = rev_map.as_ref().unwrap().get_categories();
let class = pl.const_get::<_, Value>("Enum").unwrap();
let s = Series::from_arrow("category", categories.to_boxed()).unwrap();
let series = to_series(s.into());
class
.funcall::<_, _, Value>("new", (series,))
.unwrap()
.into()
}
DataType::Time => pl.const_get::<_, Value>("Time").unwrap(),
DataType::Struct(fields) => {
Expand Down
2 changes: 1 addition & 1 deletion test/series_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def test_new_strict
def test_new_enum
dtype = Polars::Enum.new(["a", "b"])
s = Polars::Series.new([nil, "a", "b"], dtype: dtype)
assert_series [nil, "a", "b"], s #, dtype: dtype
assert_series [nil, "a", "b"], s, dtype: dtype
end

def test_new_bigdecimal
Expand Down

0 comments on commit f74bc7f

Please sign in to comment.