Skip to content

Commit

Permalink
Added Testing module
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Mar 3, 2024
1 parent 8f34b07 commit ab67ea3
Show file tree
Hide file tree
Showing 7 changed files with 356 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ See the [upgrade guide](https://docs.pola.rs/releases/upgrade/0.20/)
- Changed `count` method to exclude null values
- Changed `dtype` and `schema` methods to always return instances of data types
- Added `Enum` type
- Added `Testing` module
- Added `sink_ipc`, `sink_csv`, and `sink_ndjson` methods to `LazyFrame`
- Added `replace` method to `Series` and `Expr`
- Added `eq`, `eq_missing`, `ne`, and `ne_missing` methods to `Series` and `Expr`
Expand Down
1 change: 1 addition & 0 deletions lib/polars.rb
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
require_relative "polars/string_name_space"
require_relative "polars/struct_expr"
require_relative "polars/struct_name_space"
require_relative "polars/testing"
require_relative "polars/utils"
require_relative "polars/version"
require_relative "polars/when"
Expand Down
2 changes: 1 addition & 1 deletion lib/polars/data_types.rb
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def ==(other)
if other.is_a?(Class)
is_a?(other)
else
eql?(other)
other.instance_of?(self.class)
end
end

Expand Down
344 changes: 344 additions & 0 deletions lib/polars/testing.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,344 @@
module Polars
module Testing
class AssertionError < StandardError; end

def assert_frame_equal(
left,
right,
check_row_order: true,
check_column_order: true,
check_dtype: true,
check_exact: false,
rtol: 1e-5,
atol: 1e-8,
categorical_as_str: false
)
lazy = _assert_correct_input_type(left, right)
objects = lazy ? "LazyFrames" : "DataFrames"

_assert_frame_schema_equal(
left,
right,
check_column_order: check_column_order,
check_dtype: check_dtype,
objects: objects,
)

if lazy
left, right = left.collect, right.collect
end

if left.height != right.height
raise_assertion_error(
objects, "number of rows does not match", left.height, right.height
)
end

if !check_row_order
left, right = _sort_dataframes(left, right)
end

left.columns.each do |c|
s_left, s_right = left.get_column(c), right.get_column(c)
begin
_assert_series_values_equal(
s_left,
s_right,
check_exact: check_exact,
rtol: rtol,
atol: atol,
categorical_as_str: categorical_as_str
)
rescue AssertionError => exc
raise_assertion_error(
objects,
"value mismatch for column #{c.inspect}",
s_left.to_a,
s_right.to_a,
cause: exc
)
end
end
end

def assert_series_equal(
left,
right,
check_dtype: true,
check_names: true,
check_exact: false,
rtol: 1e-5,
atol: 1e-8,
categorical_as_str: false
)
if !(left.is_a?(Series) && right.is_a?(Series))
raise_assertion_error(
"inputs",
"unexpected input types",
left.class.name,
right.class.name
)
end

if left.len != right.len
raise_assertion_error("Series", "length mismatch", left.len, right.len)
end

if check_names && left.name != right.name
raise_assertion_error("Series", "name mismatch", left.name, right.name)
end

if check_dtype && left.dtype != right.dtype
raise_assertion_error("Series", "dtype mismatch", left.dtype, right.dtype)
end

_assert_series_values_equal(
left,
right,
check_exact: check_exact,
rtol: rtol,
atol: atol,
categorical_as_str: categorical_as_str
)
end

private

def _assert_correct_input_type(left, right)
if left.is_a?(DataFrame) && right.is_a?(DataFrame)
return false
elsif left.is_a?(LazyFrame) && right.is_a?(DataFrame)
return true
else
raise_assertion_error(
"inputs",
"unexpected input types",
left.class.name,
right.class.name
)
end
end

def _assert_frame_schema_equal(
left,
right,
check_dtype:,
check_column_order:,
objects:
)
left_schema, right_schema = left.schema, right.schema

# Fast path for equal frames
if left_schema == right_schema
return
end

# Special error message for when column names do not match
if left_schema.keys != right_schema.keys
if (left_not_right = right_schema.keys - left_schema.keys).any?
msg = "columns #{left_not_right.inspect} in left #{objects[..-1]}, but not in right"
raise msg
else
right_not_left = right_schema.keys - left_schema.keys
msg = "columns #{right_not_left.inspect} in right #{objects[..-1]}, but not in left"
raise msg
end
end

if check_column_order
left_columns, right_columns = left_schema.keys, right_schema.keys
if left_columns != right_columns
detail = "columns are not in the same order"
raise_assertion_error(objects, detail, left_columns, right_columns)
end
end

if check_dtype
left_schema_dict, right_schema_dict = left_schema.to_h, right_schema.to_h
if check_column_order || left_schema_dict != right_schema_dict
detail = "dtypes do not match"
raise_assertion_error(objects, detail, left_schema_dict, right_schema_dict)
end
end
end

def _sort_dataframes(left, right)
by = left.columns
begin
left = left.sort(by)
right = right.sort(by)
rescue ComputeError
msg = "cannot set `check_row_order: false` on frame with unsortable columns"
raise msg
end
[left, right]
end

def _assert_series_values_equal(
left,
right,
check_exact:,
rtol:,
atol:,
categorical_as_str:
)
if categorical_as_str
if left.dtype == Categorical
left = left.cast(String)
end
if right.dtype == Categorical
right = right.cast(String)
end
end

# Determine unequal elements
begin
unequal = left.ne_missing(right)
rescue ComputeError => exc
raise_assertion_error(
"Series",
"incompatible data types",
left.dtype,
right.dtype,
cause: exc
)
end

# Check nested dtypes in separate function
if _comparing_nested_floats(left.dtype, right.dtype)
begin
_assert_series_nested_values_equal(
left: left.filter(unequal),
right: right.filter(unequal),
check_exact: check_exact,
rtol: rtol,
atol: atol,
categorical_as_str: categorical_as_str
)
rescue AssertionError => exc
raise_assertion_error(
"Series",
"nested value mismatch",
left.to_a,
right.to_a,
cause: exc
)
else
return
end
end

# If no differences found during exact checking, we're done
if !unequal.any
return
end

# Only do inexact checking for float types
if check_exact || !left.dtype.is_float || !right.dtype.is_float
raise_assertion_error(
"Series", "exact value mismatch", left.to_a, right.to_a
)
end

_assert_series_null_values_match(left, right)
_assert_series_nan_values_match(left, right)
_assert_series_values_within_tolerance(
left,
right,
unequal,
rtol: rtol,
atol: atol
)
end

def _assert_series_nested_values_equal(
left,
right,
check_exact:,
rtol:,
atol:,
categorical_as_str:
)
# compare nested lists element-wise
if _comparing_lists(left.dtype, right.dtype)
left.right.zip do |s1, s2|
if s1.nil? || s2.nil?
raise_assertion_error("Series", "nested value mismatch", s1, s2)
end

_assert_series_values_equal(
s1,
s2,
check_exact: check_exact,
rtol: rtol,
atol: atol,
categorical_as_str: categorical_as_str
)
end

# unnest structs as series and compare
else
ls, rs = left.struct.unnest, right.struct.unnest
ls.zip(rs) do |s1, s2|
_assert_series_values_equal(
s1,
s2,
check_exact: check_exact,
rtol: rtol,
atol: atol,
categorical_as_str: categorical_as_str
)
end
end
end

def _assert_series_null_values_match(left, right)
null_value_mismatch = left.is_null != right.is_null
if null_value_mismatch.any
raise_assertion_error(
"Series", "null value mismatch", left.to_a, right.to_a
)
end
end

def _assert_series_nan_values_match(left, right)
if !_comparing_floats(left.dtype, right.dtype)
return
end
nan_value_mismatch = left.is_nan != right.is_nan
if nan_value_mismatch.any
raise_assertion_error(
"Series",
"nan value mismatch",
left.to_a,
right.to_a
)
end
end

def _comparing_floats(left, right)
left.is_float && right.is_float
end

def _comparing_lists(left, right)
[List, Array].include?(left) && [List, Array].include?(right)
end

def _comparing_structs(left, right)
left == Struct && right == Struct
end

def _comparing_nested_floats(left, right)
if !(_comparing_lists(left, right) || _comparing_structs(left, right))
return false
end

(FLOAT_DTYPES & unpack_dtypes(left)) && (FLOAT_DTYPES & unpack_dtypes(right))
end

def raise_assertion_error(objects, detail, left, right, cause: nil)
msg = "#{objects} are different (#{detail})\n[left]: #{left}\n[right]: #{right}"
raise AssertionError, msg
end
end
end
4 changes: 2 additions & 2 deletions test/data_frame_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -534,8 +534,8 @@ def test_null_count

def test_count
df = Polars::DataFrame.new({"a" => [1, 2, nil]})
assert_frame Polars::DataFrame.new({"a" => [2]}), df.select(Polars.col("a").count)
assert_frame Polars::DataFrame.new({"a" => [3]}), df.select(Polars.col("a").len)
assert_frame Polars::DataFrame.new({"a" => [2]}), df.select(Polars.col("a").count), check_dtype: false
assert_frame Polars::DataFrame.new({"a" => [3]}), df.select(Polars.col("a").len), check_dtype: false
end

def test_replace
Expand Down
2 changes: 1 addition & 1 deletion test/series_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ def test_to_dummies
def test_value_counts
s = Polars::Series.new("a", ["x", "x", "y"])
expected = Polars::DataFrame.new({"a" => ["x", "y"], "count" => [2, 1]})
assert_frame expected, s.value_counts.sort("a")
assert_frame expected, s.value_counts, check_row_order: false, check_dtype: false
end

def test_entropy
Expand Down
Loading

0 comments on commit ab67ea3

Please sign in to comment.