Skip to content

Commit

Permalink
Fixed error with then method - fixes #64
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Apr 19, 2024
1 parent 1d1ced0 commit 9a563b9
Show file tree
Hide file tree
Showing 9 changed files with 253 additions and 63 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- Added `truncate_ragged_lines` option to `read_csv`, `read_csv_batched`, and `scan_csv` methods
- Added precompiled gem for Linux x86-64 MUSL
- Changed `drop` method to ignore missing columns
- Fixed error with `then` method

## 0.9.0 (2024-03-03)

Expand Down
64 changes: 47 additions & 17 deletions ext/polars/src/functions/whenthen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,72 @@ use polars::lazy::dsl;

use crate::RbExpr;

pub fn when(condition: &RbExpr) -> RbWhen {
RbWhen {
inner: dsl::when(condition.inner.clone()),
}
}

#[magnus::wrap(class = "Polars::RbWhen")]
#[derive(Clone)]
pub struct RbWhen {
pub inner: dsl::When,
}

impl From<dsl::When> for RbWhen {
fn from(inner: dsl::When) -> Self {
RbWhen { inner }
}
}

#[magnus::wrap(class = "Polars::RbWhenThen")]
#[magnus::wrap(class = "Polars::RbThen")]
#[derive(Clone)]
pub struct RbThen {
pub inner: dsl::Then,
}

impl From<dsl::Then> for RbThen {
fn from(inner: dsl::Then) -> Self {
RbThen { inner }
}
#[magnus::wrap(class = "Polars::RbChainedWhen")]
#[derive(Clone)]
pub struct RbChainedWhen {
pub inner: dsl::ChainedWhen,
}

#[magnus::wrap(class = "Polars::RbChainedThen")]
#[derive(Clone)]
pub struct RbChainedThen {
pub inner: dsl::ChainedThen,
}

impl RbWhen {
pub fn then(&self, expr: &RbExpr) -> RbThen {
self.inner.clone().then(expr.inner.clone()).into()
pub fn then(&self, statement: &RbExpr) -> RbThen {
RbThen {
inner: self.inner.clone().then(statement.inner.clone()),
}
}
}

impl RbThen {
pub fn overwise(&self, expr: &RbExpr) -> RbExpr {
self.inner.clone().otherwise(expr.inner.clone()).into()
pub fn when(&self, condition: &RbExpr) -> RbChainedWhen {
RbChainedWhen {
inner: self.inner.clone().when(condition.inner.clone()),
}
}

pub fn otherwise(&self, statement: &RbExpr) -> RbExpr {
self.inner.clone().otherwise(statement.inner.clone()).into()
}
}

pub fn when(predicate: &RbExpr) -> RbWhen {
dsl::when(predicate.inner.clone()).into()
impl RbChainedWhen {
pub fn then(&self, statement: &RbExpr) -> RbChainedThen {
RbChainedThen {
inner: self.inner.clone().then(statement.inner.clone()),
}
}
}

impl RbChainedThen {
pub fn when(&self, condition: &RbExpr) -> RbChainedWhen {
RbChainedWhen {
inner: self.inner.clone().when(condition.inner.clone()),
}
}

pub fn otherwise(&self, statement: &RbExpr) -> RbExpr {
self.inner.clone().otherwise(statement.inner.clone()).into()
}
}
17 changes: 13 additions & 4 deletions ext/polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use error::{RbPolarsErr, RbTypeError, RbValueError};
use expr::rb_exprs_to_exprs;
use expr::RbExpr;
use functions::string_cache::RbStringCacheHolder;
use functions::whenthen::{RbThen, RbWhen};
use functions::whenthen::{RbChainedThen, RbChainedWhen, RbThen, RbWhen};
use lazyframe::RbLazyFrame;
use lazygroupby::RbLazyGroupBy;
use magnus::{define_module, function, method, prelude::*, Error, Ruby};
Expand Down Expand Up @@ -1032,11 +1032,20 @@ fn init(ruby: &Ruby) -> RbResult<()> {
// extra
class.define_method("extend_constant", method!(RbSeries::extend_constant, 2))?;

// when then
let class = module.define_class("RbWhen", ruby.class_object())?;
class.define_method("_then", method!(RbWhen::then, 1))?;
class.define_method("then", method!(RbWhen::then, 1))?;

let class = module.define_class("RbWhenThen", ruby.class_object())?;
class.define_method("otherwise", method!(RbThen::overwise, 1))?;
let class = module.define_class("RbThen", ruby.class_object())?;
class.define_method("when", method!(RbThen::when, 1))?;
class.define_method("otherwise", method!(RbThen::otherwise, 1))?;

let class = module.define_class("RbChainedWhen", ruby.class_object())?;
class.define_method("then", method!(RbChainedWhen::then, 1))?;

let class = module.define_class("RbChainedThen", ruby.class_object())?;
class.define_method("when", method!(RbChainedThen::when, 1))?;
class.define_method("otherwise", method!(RbChainedThen::otherwise, 1))?;

// sql
let class = module.define_class("RbSQLContext", ruby.class_object())?;
Expand Down
3 changes: 1 addition & 2 deletions lib/polars.rb
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@
require_relative "polars/testing"
require_relative "polars/utils"
require_relative "polars/version"
require_relative "polars/when"
require_relative "polars/when_then"
require_relative "polars/whenthen"

module Polars
extend Convert
Expand Down
79 changes: 74 additions & 5 deletions lib/polars/functions/whenthen.rb
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ module Functions
#
# @return [When]
#
# @example
# @example Below we add a column with the value 1, where column "foo" > 2 and the value -1 where it isn't.
# df = Polars::DataFrame.new({"foo" => [1, 3, 4], "bar" => [3, 4, 0]})
# df.with_column(Polars.when(Polars.col("foo") > 2).then(Polars.lit(1)).otherwise(Polars.lit(-1)))
# # =>
Expand All @@ -18,10 +18,79 @@ module Functions
# # │ 3 ┆ 4 ┆ 1 │
# # │ 4 ┆ 0 ┆ 1 │
# # └─────┴─────┴─────────┘
def when(expr)
expr = Utils.expr_to_lit_or_expr(expr)
pw = Plr.when(expr._rbexpr)
When.new(pw)
#
# @example Or with multiple when-then operations chained:
# df.with_columns(
# Polars.when(Polars.col("foo") > 2)
# .then(1)
# .when(Polars.col("bar") > 2)
# .then(4)
# .otherwise(-1)
# .alias("val")
# )
# # =>
# # shape: (3, 3)
# # ┌─────┬─────┬─────┐
# # │ foo ┆ bar ┆ val │
# # │ --- ┆ --- ┆ --- │
# # │ i64 ┆ i64 ┆ i32 │
# # ╞═════╪═════╪═════╡
# # │ 1 ┆ 3 ┆ 4 │
# # │ 3 ┆ 4 ┆ 1 │
# # │ 4 ┆ 0 ┆ 1 │
# # └─────┴─────┴─────┘
#
# @example The `otherwise` at the end is optional. If left out, any rows where none of the `when` expressions evaluate to True, are set to `null`:
# df.with_columns(Polars.when(Polars.col("foo") > 2).then(1).alias("val"))
# # =>
# # shape: (3, 3)
# # ┌─────┬─────┬──────┐
# # │ foo ┆ bar ┆ val │
# # │ --- ┆ --- ┆ --- │
# # │ i64 ┆ i64 ┆ i32 │
# # ╞═════╪═════╪══════╡
# # │ 1 ┆ 3 ┆ null │
# # │ 3 ┆ 4 ┆ 1 │
# # │ 4 ┆ 0 ┆ 1 │
# # └─────┴─────┴──────┘
#
# @example Pass multiple predicates, each of which must be met:
# df.with_columns(
# val: Polars.when(
# Polars.col("bar") > 0,
# Polars.col("foo") % 2 != 0
# )
# .then(99)
# .otherwise(-1)
# )
# # =>
# # shape: (3, 3)
# # ┌─────┬─────┬─────┐
# # │ foo ┆ bar ┆ val │
# # │ --- ┆ --- ┆ --- │
# # │ i64 ┆ i64 ┆ i32 │
# # ╞═════╪═════╪═════╡
# # │ 1 ┆ 3 ┆ 99 │
# # │ 3 ┆ 4 ┆ 99 │
# # │ 4 ┆ 0 ┆ -1 │
# # └─────┴─────┴─────┘
#
# @example Pass conditions as keyword arguments:
# df.with_columns(val: Polars.when(foo: 4, bar: 0).then(99).otherwise(-1))
# # =>
# # shape: (3, 3)
# # ┌─────┬─────┬─────┐
# # │ foo ┆ bar ┆ val │
# # │ --- ┆ --- ┆ --- │
# # │ i64 ┆ i64 ┆ i32 │
# # ╞═════╪═════╪═════╡
# # │ 1 ┆ 3 ┆ -1 │
# # │ 3 ┆ 4 ┆ -1 │
# # │ 4 ┆ 0 ┆ 99 │
# # └─────┴─────┴─────┘
def when(*predicates, **constraints)
condition = Utils.parse_when_inputs(*predicates, **constraints)
When.new(Plr.when(condition))
end
end
end
34 changes: 34 additions & 0 deletions lib/polars/utils.rb
Original file line number Diff line number Diff line change
Expand Up @@ -384,5 +384,39 @@ def self._expand_selectors(frame, *items)
def self.is_selector(obj)
false
end

def self.parse_predicates_constraints_as_expression(*predicates, **constraints)
all_predicates = _parse_positional_inputs(predicates)

if constraints.any?
constraint_predicates = _parse_constraints(constraints)
all_predicates.concat(constraint_predicates)
end

_combine_predicates(all_predicates)
end

def self._parse_constraints(constraints)
constraints.map do |name, value|
Polars.col(name).eq(value)._rbexpr
end
end

def self._combine_predicates(predicates)
if !predicates.any?
msg = "at least one predicate or constraint must be provided"
raise TypeError, msg
end

if predicates.length == 1
return predicates[0]
end

Plr.all_horizontal(predicates)
end

def self.parse_when_inputs(*predicates, **constraints)
parse_predicates_constraints_as_expression(*predicates, **constraints)
end
end
end
16 changes: 0 additions & 16 deletions lib/polars/when.rb

This file was deleted.

19 changes: 0 additions & 19 deletions lib/polars/when_then.rb

This file was deleted.

83 changes: 83 additions & 0 deletions lib/polars/whenthen.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
module Polars
# @private
class When
attr_accessor :_when

def initialize(rbwhen)
self._when = rbwhen
end

def then(statement)
statement_rbexpr = Utils.parse_as_expression(statement)
Then.new(_when.then(statement_rbexpr))
end
end

# @private
class Then < Expr
attr_accessor :_then

def initialize(rbthen)
self._then = rbthen
end

def self._from_rbexpr(rbexpr)
Utils.wrap_expr(rbexpr)
end

def _rbexpr
_then.otherwise(Polars.lit(nil)._rbexpr)
end

def when(*predicates, **constraints)
condition_rbexpr = Utils.parse_when_inputs(*predicates, **constraints)
ChainedWhen.new(_then.when(condition_rbexpr))
end

def otherwise(statement)
statement_rbexpr = Utils.parse_as_expression(statement)
Utils.wrap_expr(_then.otherwise(statement_rbexpr))
end
end

# @private
class ChainedWhen
attr_accessor :_chained_when

def initialize(chained_when)
self._chained_when = chained_when
end

def then(statement)
statement_rbexpr = Utils.parse_as_expression(statement)
ChainedThen.new(_chained_when.then(statement_rbexpr))
end
end

# @private
class ChainedThen < Expr
attr_accessor :_chained_then

def initialize(chained_then)
self._chained_then = chained_then
end

def self._from_rbexpr(rbexpr)
Utils.wrap_expr(rbexpr)
end

def _rbexpr
_chained_then.otherwise(Polars.lit(nil)._rbexpr)
end

def when(*predicates, **constraints)
condition_rbexpr = Utils.parse_when_inputs(*predicates, **constraints)
ChainedWhen.new(_chained_then.when(condition_rbexpr))
end

def otherwise(statement)
statement_rbexpr = Utils.parse_as_expression(statement)
Utils.wrap_expr(_chained_then.otherwise(statement_rbexpr))
end
end
end

0 comments on commit 9a563b9

Please sign in to comment.