diff --git a/cpp/include/cudf/interop.hpp b/cpp/include/cudf/interop.hpp
index e210179b147..3ad99cbcc99 100644
--- a/cpp/include/cudf/interop.hpp
+++ b/cpp/include/cudf/interop.hpp
@@ -133,6 +133,19 @@ std::shared_ptr<arrow::Table> to_arrow(table_view input,
                                        std::vector<column_metadata> const& metadata = {},
                                        arrow::MemoryPool* ar_mr = arrow::default_memory_pool());
 
+/**
+ * @brief Create `arrow::Scalar` from cudf scalar `input`
+ *
+ * Converts the `cudf::scalar` to `arrow::Scalar`.
+ *
+ * @param input scalar that needs to be converted to arrow Scalar
+ * @param metadata Contains hierarchy of names of columns and children
+ * @param ar_mr arrow memory pool to allocate memory for arrow Scalar
+ * @return arrow Scalar generated from `input`
+ */
+std::shared_ptr<arrow::Scalar> to_arrow(cudf::scalar const& input,
+                                        column_metadata const& metadata = {},
+                                        arrow::MemoryPool* ar_mr = arrow::default_memory_pool());
 /**
  * @brief Create `cudf::table` from given arrow Table input
  *
@@ -145,5 +158,17 @@ std::unique_ptr<cudf::table> from_arrow(
   arrow::Table const& input,
   rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());
 
+/**
+ * @brief Create `cudf::scalar` from given arrow Scalar input
+ *
+ * @param input arrow::Scalar that needs to be converted to `cudf::scalar`
+ * @param mr Device memory resource used to allocate `cudf::scalar`
+ * @return cudf scalar generated from given arrow Scalar
+ */
+
+std::unique_ptr<cudf::scalar> from_arrow(
+  arrow::Scalar const& input,
+  rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());
+
 /** @} */  // end of group
 }  // namespace cudf
diff --git a/cpp/src/interop/from_arrow.cu b/cpp/src/interop/from_arrow.cu
index 30cfee97fd8..8ff361d7e04 100644
--- a/cpp/src/interop/from_arrow.cu
+++ b/cpp/src/interop/from_arrow.cu
@@ -13,6 +13,10 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include
+#include
+#include
+#include
 #include
 #include
 #include
@@ -35,6 +39,7 @@
 #include
 #include
+#include
 #include
 
 namespace cudf {
 
@@ -472,4 +477,135 @@ std::unique_ptr<table> from_arrow(arrow::Table const& input_table,
   return detail::from_arrow(input_table, cudf::get_default_stream(), mr);
 }
 
+template <typename Functor, typename... Ts>
+constexpr decltype(auto) arrow_type_dispatcher(arrow::DataType const& dtype,
+                                               Functor f,
+                                               Ts&&... args)
+{
+  switch (dtype.id()) {
+    case arrow::Type::INT8:
+      return f.template operator()<arrow::Int8Type>(std::forward<Ts>(args)...);
+    case arrow::Type::INT16:
+      return f.template operator()<arrow::Int16Type>(std::forward<Ts>(args)...);
+    case arrow::Type::INT32:
+      return f.template operator()<arrow::Int32Type>(std::forward<Ts>(args)...);
+    case arrow::Type::INT64:
+      return f.template operator()<arrow::Int64Type>(std::forward<Ts>(args)...);
+    case arrow::Type::UINT8:
+      return f.template operator()<arrow::UInt8Type>(std::forward<Ts>(args)...);
+    case arrow::Type::UINT16:
+      return f.template operator()<arrow::UInt16Type>(std::forward<Ts>(args)...);
+    case arrow::Type::UINT32:
+      return f.template operator()<arrow::UInt32Type>(std::forward<Ts>(args)...);
+    case arrow::Type::UINT64:
+      return f.template operator()<arrow::UInt64Type>(std::forward<Ts>(args)...);
+    case arrow::Type::FLOAT:
+      return f.template operator()<arrow::FloatType>(std::forward<Ts>(args)...);
+    case arrow::Type::DOUBLE:
+      return f.template operator()<arrow::DoubleType>(std::forward<Ts>(args)...);
+    case arrow::Type::BOOL:
+      return f.template operator()<arrow::BooleanType>(std::forward<Ts>(args)...);
+    case arrow::Type::TIMESTAMP:
+      return f.template operator()<arrow::TimestampType>(std::forward<Ts>(args)...);
+    case arrow::Type::DURATION:
+      return f.template operator()<arrow::DurationType>(std::forward<Ts>(args)...);
+    // case arrow::Type::DICTIONARY32:
+    //   return f.template operator()<arrow::DictionaryType>(
+    //     std::forward<Ts>(args)...);
+    case arrow::Type::STRING:
+      return f.template operator()<arrow::StringType>(std::forward<Ts>(args)...);
+    case arrow::Type::LIST:
+      return f.template operator()<arrow::ListType>(std::forward<Ts>(args)...);
+    case arrow::Type::DECIMAL128:
+      return f.template operator()<arrow::Decimal128Type>(std::forward<Ts>(args)...);
+    case arrow::Type::STRUCT:
+      return f.template operator()<arrow::StructType>(std::forward<Ts>(args)...);
+    default: {
+      CUDF_FAIL("Invalid type.");
+    }
+  }
+}
+
+struct BuilderGenerator {
+  template <typename T>
+  std::shared_ptr<arrow::ArrayBuilder> operator()(std::shared_ptr<arrow::DataType> const& type)
+  {
+    return std::make_shared<typename arrow::TypeTraits<T>::BuilderType>(
+      type, arrow::default_memory_pool());
+  }
+};
+
+template <>
+std::shared_ptr<arrow::ArrayBuilder> BuilderGenerator::operator()<arrow::ListType>(
+  std::shared_ptr<arrow::DataType> const& type)
+{
+  CUDF_FAIL("Not implemented");
+}
+
+template <>
+std::shared_ptr<arrow::ArrayBuilder> BuilderGenerator::operator()<arrow::StructType>(
+  std::shared_ptr<arrow::DataType> const& type)
+{
+  CUDF_FAIL("Not implemented");
+}
+
+std::shared_ptr<arrow::ArrayBuilder> make_builder(std::shared_ptr<arrow::DataType> const& type)
+{
+  switch (type->id()) {
+    case arrow::Type::STRUCT: {
+      std::vector<std::shared_ptr<arrow::ArrayBuilder>> field_builders;
+
+      for (auto i = 0; i < type->num_fields(); ++i) {
+        auto const vt = type->field(i)->type();
+        if (vt->id() == arrow::Type::STRUCT || vt->id() == arrow::Type::LIST) {
+          field_builders.push_back(make_builder(vt));
+        } else {
+          field_builders.push_back(arrow_type_dispatcher(*vt, BuilderGenerator{}, vt));
+        }
+      }
+      return std::make_shared<arrow::StructBuilder>(
+        type, arrow::default_memory_pool(), field_builders);
+    }
+    case arrow::Type::LIST: {
+      return std::make_shared<arrow::ListBuilder>(arrow::default_memory_pool(),
+                                                  make_builder(type->field(0)->type()));
+    }
+    default: {
+      return arrow_type_dispatcher(*type, BuilderGenerator{}, type);
+    }
+  }
+}
+
+std::unique_ptr<cudf::scalar> from_arrow(arrow::Scalar const& input,
+                                         rmm::mr::device_memory_resource* mr)
+{
+  // Get a builder for the scalar type
+  auto builder = make_builder(input.type);
+
+  auto status = builder->AppendScalar(input);
+  if (status != arrow::Status::OK()) {
+    if (status.IsNotImplemented()) {
+      // The only known failure case here is for nulls
+      CUDF_FAIL("Cannot create untyped null scalars or nested types with untyped null leaf nodes",
+                std::invalid_argument);
+    }
+    CUDF_FAIL("Arrow ArrayBuilder::AppendScalar failed");
+  }
+
+  auto maybe_array = builder->Finish();
+  if (!maybe_array.ok()) { CUDF_FAIL("Arrow ArrayBuilder::Finish failed"); }
+  auto array = *maybe_array;
+
+  auto field = arrow::field("", input.type);
+
+  auto table = arrow::Table::Make(arrow::schema({field}), {array});
+
+  auto cudf_table = from_arrow(*table);
+
+  auto col = cudf_table->get_column(0);
+
+  auto cv = col.view();
+  return get_element(cv, 0);
+}
+
 }  // namespace cudf
diff --git a/cpp/src/interop/to_arrow.cu b/cpp/src/interop/to_arrow.cu
index 958a2fcb95f..b501c1e1d7e 100644
--- a/cpp/src/interop/to_arrow.cu
+++ b/cpp/src/interop/to_arrow.cu
@@ -15,10 +15,14 @@
  */
 
 #include
+#include
 #include
+#include
 #include
 #include
+#include
 #include
+#include
 #include
 #include
 #include
@@ -139,6 +143,46 @@
   }
 };
 
+template <>
+std::shared_ptr<arrow::Array> dispatch_to_arrow::operator()<numeric::decimal32>(
+  column_view input,
+  cudf::type_id,
+  column_metadata const&,
+  arrow::MemoryPool* ar_mr,
+  rmm::cuda_stream_view stream)
+{
+  using DeviceType                = int32_t;
+  size_type const BIT_WIDTH_RATIO = 4;  // Array::Type:type::DECIMAL (128) / int32_t
+
+  rmm::device_uvector<__int128_t> buf(input.size() * BIT_WIDTH_RATIO, stream);
+
+  auto count = thrust::make_counting_iterator(0);
+
+  thrust::for_each(rmm::exec_policy(cudf::get_default_stream()),
+                   count,
+                   count + input.size(),
+                   [in = input.begin<DeviceType>(), out = buf.data()] __device__(auto in_idx) {
+                     auto const out_idx   = in_idx;
+                     auto unsigned_value  = in[in_idx] < 0 ? -in[in_idx] : in[in_idx];
+                     auto unsigned_128bit = static_cast<__int128_t>(unsigned_value);
+                     auto signed_128bit   = in[in_idx] < 0 ? -unsigned_128bit : unsigned_128bit;
+                     out[out_idx]         = signed_128bit;
+                   });
+
+  auto const buf_size_in_bytes = buf.size() * sizeof(DeviceType);
+  auto data_buffer             = allocate_arrow_buffer(buf_size_in_bytes, ar_mr);
+
+  CUDF_CUDA_TRY(cudaMemcpyAsync(
+    data_buffer->mutable_data(), buf.data(), buf_size_in_bytes, cudaMemcpyDefault, stream.value()));
+
+  auto type    = arrow::decimal(9, -input.type().scale());
+  auto mask    = fetch_mask_buffer(input, ar_mr, stream);
+  auto buffers = std::vector<std::shared_ptr<arrow::Buffer>>{mask, std::move(data_buffer)};
+  auto data    = std::make_shared<arrow::ArrayData>(type, input.size(), buffers);
+
+  return std::make_shared<arrow::Decimal128Array>(data);
+}
+
 template <>
 std::shared_ptr<arrow::Array> dispatch_to_arrow::operator()<numeric::decimal64>(
   column_view input,
@@ -413,4 +457,18 @@ std::shared_ptr<arrow::Table> to_arrow(table_view input,
   return detail::to_arrow(input, metadata, cudf::get_default_stream(), ar_mr);
 }
 
+std::shared_ptr<arrow::Scalar> to_arrow(cudf::scalar const& input,
+                                        column_metadata const& metadata,
+
+                                        arrow::MemoryPool* ar_mr)
+{
+  auto stream       = cudf::get_default_stream();
+  auto column       = cudf::make_column_from_scalar(input, 1);
+  cudf::table_view tv{{column->view()}};
+  auto arrow_table  = cudf::to_arrow(tv, {metadata});
+  auto ac           = arrow_table->column(0);
+  auto maybe_scalar = ac->GetScalar(0);
+  if (!maybe_scalar.ok()) { CUDF_FAIL("Failed to produce a scalar"); }
+  return maybe_scalar.ValueOrDie();
+}
 }  // namespace cudf
diff --git a/python/cudf/cudf/_lib/cpp/interop.pxd b/python/cudf/cudf/_lib/cpp/interop.pxd
index e81f0d617fb..88e9d83ee98 100644
--- a/python/cudf/cudf/_lib/cpp/interop.pxd
+++ b/python/cudf/cudf/_lib/cpp/interop.pxd
@@ -1,12 +1,13 @@
-# Copyright (c) 2020, NVIDIA CORPORATION.
+# Copyright (c) 2020-2023, NVIDIA CORPORATION.
from libcpp.memory cimport shared_ptr, unique_ptr from libcpp.string cimport string from libcpp.vector cimport vector -from pyarrow.lib cimport CTable +from pyarrow.lib cimport CScalar, CTable from cudf._lib.types import cudf_to_np_types, np_to_cudf_types +from cudf._lib.cpp.scalar.scalar cimport scalar from cudf._lib.cpp.table.table cimport table from cudf._lib.cpp.table.table_view cimport table_view @@ -24,6 +25,7 @@ cdef extern from "cudf/interop.hpp" namespace "cudf" \ ) except + cdef unique_ptr[table] from_arrow(CTable input) except + + cdef unique_ptr[scalar] from_arrow(CScalar input) except + cdef cppclass column_metadata: column_metadata() except + @@ -35,3 +37,8 @@ cdef extern from "cudf/interop.hpp" namespace "cudf" \ table_view input, vector[column_metadata] metadata, ) except + + + cdef shared_ptr[CScalar] to_arrow( + const scalar& input, + column_metadata metadata, + ) except + diff --git a/python/cudf/cudf/_lib/cpp/libcpp/functional.pxd b/python/cudf/cudf/_lib/cpp/libcpp/functional.pxd index f3e2d6d0878..c38db036119 100644 --- a/python/cudf/cudf/_lib/cpp/libcpp/functional.pxd +++ b/python/cudf/cudf/_lib/cpp/libcpp/functional.pxd @@ -1,6 +1,8 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# TODO: Can be replaced once https://github.com/cython/cython/pull/5671 is +# merged and released cdef extern from "" namespace "std" nogil: cdef cppclass reference_wrapper[T]: reference_wrapper() diff --git a/python/cudf/cudf/_lib/cpp/reduce.pxd b/python/cudf/cudf/_lib/cpp/reduce.pxd index 7952c717916..997782dec6c 100644 --- a/python/cudf/cudf/_lib/cpp/reduce.pxd +++ b/python/cudf/cudf/_lib/cpp/reduce.pxd @@ -1,14 +1,13 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. from libcpp.memory cimport unique_ptr from libcpp.utility cimport pair -from cudf._lib.aggregation cimport reduce_aggregation, scan_aggregation +from cudf._lib.cpp.aggregation cimport reduce_aggregation, scan_aggregation from cudf._lib.cpp.column.column cimport column from cudf._lib.cpp.column.column_view cimport column_view from cudf._lib.cpp.scalar.scalar cimport scalar from cudf._lib.cpp.types cimport data_type -from cudf._lib.scalar cimport DeviceScalar cdef extern from "cudf/reduction.hpp" namespace "cudf" nogil: diff --git a/python/cudf/cudf/_lib/datetime.pyx b/python/cudf/cudf/_lib/datetime.pyx index 81949dbaa20..3d96f59c4d6 100644 --- a/python/cudf/cudf/_lib/datetime.pyx +++ b/python/cudf/cudf/_lib/datetime.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. 
from cudf.core.buffer import acquire_spill_lock @@ -10,6 +10,7 @@ from cudf._lib.column cimport Column from cudf._lib.cpp.column.column cimport column from cudf._lib.cpp.column.column_view cimport column_view from cudf._lib.cpp.filling cimport calendrical_month_sequence +from cudf._lib.cpp.scalar.scalar cimport scalar from cudf._lib.cpp.types cimport size_type from cudf._lib.scalar cimport DeviceScalar @@ -166,10 +167,11 @@ def date_range(DeviceScalar start, size_type n, offset): + offset.kwds.get("months", 0) ) + cdef const scalar* c_start = start.c_value.get() with nogil: c_result = move(calendrical_month_sequence( n, - start.c_value.get()[0], + c_start[0], months )) return Column.from_unique_ptr(move(c_result)) diff --git a/python/cudf/cudf/_lib/pylibcudf/CMakeLists.txt b/python/cudf/cudf/_lib/pylibcudf/CMakeLists.txt index 0ce42dc43ff..64adb38ace3 100644 --- a/python/cudf/cudf/_lib/pylibcudf/CMakeLists.txt +++ b/python/cudf/cudf/_lib/pylibcudf/CMakeLists.txt @@ -12,10 +12,35 @@ # the License. # ============================================================================= -set(cython_sources column.pyx copying.pyx gpumemoryview.pyx table.pyx types.pyx utils.pyx) +set(cython_sources column.pyx copying.pyx gpumemoryview.pyx interop.pyx scalar.pyx table.pyx + types.pyx utils.pyx +) set(linked_libraries cudf::cudf) rapids_cython_create_modules( CXX SOURCE_FILES "${cython_sources}" LINKED_LIBRARIES "${linked_libraries}" MODULE_PREFIX pylibcudf_ ASSOCIATED_TARGETS cudf ) + +find_package(Python 3.9 REQUIRED COMPONENTS Interpreter) + +execute_process( + COMMAND "${Python_EXECUTABLE}" -c "import pyarrow; print(pyarrow.get_include())" + OUTPUT_VARIABLE PYARROW_INCLUDE_DIR + OUTPUT_STRIP_TRAILING_WHITESPACE +) + +set(targets_using_arrow_headers pylibcudf_interop pylibcudf_scalar) +foreach(target IN LISTS targets_using_arrow_headers) + target_include_directories(${target} PRIVATE "${PYARROW_INCLUDE_DIR}") +endforeach() + +# TODO: Clean up this include when switching to scikit-build-core. See cudf/_lib/CMakeLists.txt for +# more info +find_package(NumPy REQUIRED) +set(targets_using_numpy pylibcudf_interop pylibcudf_scalar) +foreach(target IN LISTS targets_using_numpy) + target_include_directories(${target} PRIVATE "${NumPy_INCLUDE_DIRS}") + # Switch to the line below when we switch back to FindPython.cmake in CMake 3.24. + # target_include_directories(${target} PRIVATE "${Python_NumPy_INCLUDE_DIRS}") +endforeach() diff --git a/python/cudf/cudf/_lib/pylibcudf/__init__.pxd b/python/cudf/cudf/_lib/pylibcudf/__init__.pxd index ba7822b0a54..cf0007b9303 100644 --- a/python/cudf/cudf/_lib/pylibcudf/__init__.pxd +++ b/python/cudf/cudf/_lib/pylibcudf/__init__.pxd @@ -1,9 +1,13 @@ # Copyright (c) 2023, NVIDIA CORPORATION. # TODO: Verify consistent usage of relative/absolute imports in pylibcudf. -from . cimport copying +# TODO: Cannot import interop because it introduces a build-time pyarrow header +# dependency for everything that cimports pylibcudf. See if there's a way to +# avoid that before polluting the whole package. +from . 
cimport copying # , interop from .column cimport Column from .gpumemoryview cimport gpumemoryview +from .scalar cimport Scalar from .table cimport Table # TODO: cimport type_id once # https://github.com/cython/cython/issues/5609 is resolved @@ -12,7 +16,9 @@ from .types cimport DataType __all__ = [ "Column", "DataType", + "Scalar", "Table", "copying", "gpumemoryview", + # "interop", ] diff --git a/python/cudf/cudf/_lib/pylibcudf/__init__.py b/python/cudf/cudf/_lib/pylibcudf/__init__.py index 3edff9a53e8..72b74a57b87 100644 --- a/python/cudf/cudf/_lib/pylibcudf/__init__.py +++ b/python/cudf/cudf/_lib/pylibcudf/__init__.py @@ -1,16 +1,19 @@ # Copyright (c) 2023, NVIDIA CORPORATION. -from . import copying +from . import copying, interop from .column import Column from .gpumemoryview import gpumemoryview +from .scalar import Scalar from .table import Table from .types import DataType, TypeId __all__ = [ "Column", "DataType", + "Scalar", "Table", "TypeId", "copying", "gpumemoryview", + "interop", ] diff --git a/python/cudf/cudf/_lib/pylibcudf/copying.pxd b/python/cudf/cudf/_lib/pylibcudf/copying.pxd index d57be650710..a2232fc5d81 100644 --- a/python/cudf/cudf/_lib/pylibcudf/copying.pxd +++ b/python/cudf/cudf/_lib/pylibcudf/copying.pxd @@ -13,3 +13,9 @@ cpdef Table gather( Column gather_map, out_of_bounds_policy bounds_policy ) + +cpdef Table scatter ( + list source_scalars, + Column indices, + Table target, +) diff --git a/python/cudf/cudf/_lib/pylibcudf/copying.pyx b/python/cudf/cudf/_lib/pylibcudf/copying.pyx index a27b44b3107..dff3e95328d 100644 --- a/python/cudf/cudf/_lib/pylibcudf/copying.pyx +++ b/python/cudf/cudf/_lib/pylibcudf/copying.pyx @@ -2,6 +2,7 @@ from libcpp.memory cimport unique_ptr from libcpp.utility cimport move +from libcpp.vector cimport vector # TODO: We want to make cpp a more full-featured package so that we can access # directly from that. It will make namespacing much cleaner in pylibcudf. What @@ -9,13 +10,16 @@ from libcpp.utility cimport move # cimport libcudf... libcudf.copying.algo(...) from cudf._lib.cpp cimport copying as cpp_copying from cudf._lib.cpp.copying cimport out_of_bounds_policy +from cudf._lib.cpp.libcpp.functional cimport reference_wrapper from cudf._lib.cpp.copying import \ out_of_bounds_policy as OutOfBoundsPolicy # no-cython-lint +from cudf._lib.cpp.scalar.scalar cimport scalar from cudf._lib.cpp.table.table cimport table from .column cimport Column +from .scalar cimport Scalar from .table cimport Table @@ -55,3 +59,52 @@ cpdef Table gather( ) ) return Table.from_libcudf(move(c_result)) + + +ctypedef const scalar constscalar + +cpdef Table scatter ( + list source_scalars, + Column indices, + Table target, +): + """Scatter source_scalars into target according to the indices. + + For details on the implementation, see cudf::scatter in libcudf. + + Parameters + ---------- + source_scalars : List[Scalar] + A list containing one scalar for each column in target. + indices : Column + The rows of the target into which the source_scalars should be written. + target : Table + The table into which data should be written. 
+ + Returns + ------- + pylibcudf.Table + The result of the scatter + """ + cdef unique_ptr[table] c_result + # TODO: This doesn't require the constscalar ctypedef + cdef vector[reference_wrapper[const scalar]] c_scalars + c_scalars.reserve(len(source_scalars)) + cdef Scalar d_slr + for d_slr in source_scalars: + c_scalars.push_back( + # TODO: This requires the constscalar ctypedef + # Possibly the same as https://github.com/cython/cython/issues/4180 + reference_wrapper[constscalar](d_slr.get()[0]) + ) + + with nogil: + c_result = move( + cpp_copying.scatter( + c_scalars, + indices.view(), + target.view(), + ) + ) + + return Table.from_libcudf(move(c_result)) diff --git a/python/cudf/cudf/_lib/pylibcudf/gpumemoryview.pxd b/python/cudf/cudf/_lib/pylibcudf/gpumemoryview.pxd index 713697bd139..93449fd02d1 100644 --- a/python/cudf/cudf/_lib/pylibcudf/gpumemoryview.pxd +++ b/python/cudf/cudf/_lib/pylibcudf/gpumemoryview.pxd @@ -1,9 +1,14 @@ # Copyright (c) 2023, NVIDIA CORPORATION. +from libcpp cimport bool + cdef class gpumemoryview: # TODO: Eventually probably want to make this opaque, but for now it's fine # to treat this object as something like a POD struct cdef readonly: Py_ssize_t ptr - object obj + object _obj + bool _released + + cpdef release(self) diff --git a/python/cudf/cudf/_lib/pylibcudf/gpumemoryview.pyx b/python/cudf/cudf/_lib/pylibcudf/gpumemoryview.pyx index fc98f087a1b..9121092ecdd 100644 --- a/python/cudf/cudf/_lib/pylibcudf/gpumemoryview.pyx +++ b/python/cudf/cudf/_lib/pylibcudf/gpumemoryview.pyx @@ -19,9 +19,21 @@ cdef class gpumemoryview: "gpumemoryview must be constructed from an object supporting " "the CUDA array interface" ) - self.obj = obj + self._obj = obj + self._released = False # TODO: Need to respect readonly self.ptr = cai["data"][0] def __cuda_array_interface__(self): return self.obj.__cuda_array_interface__ + + @property + def obj(self): + if not self._released: + return self._obj + else: + raise ValueError("operation forbidden on released gpumemoryview object") + + cpdef release(self): + self._obj = None + self._released = True diff --git a/python/cudf/cudf/_lib/pylibcudf/interop.pxd b/python/cudf/cudf/_lib/pylibcudf/interop.pxd new file mode 100644 index 00000000000..c1268b8ca1a --- /dev/null +++ b/python/cudf/cudf/_lib/pylibcudf/interop.pxd @@ -0,0 +1,26 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. + +from pyarrow.lib cimport Scalar as pa_Scalar, Table as pa_Table + +from cudf._lib.cpp.interop cimport column_metadata + +from .scalar cimport Scalar +from .table cimport Table + + +cdef class ColumnMetadata: + cdef public object name + cdef public object children_meta + cdef column_metadata to_c_metadata(self) + +cpdef Table from_arrow( + pa_Table pyarrow_table, +) + +cpdef Scalar from_arrow_scalar( + pa_Scalar pyarrow_scalar, +) + +cpdef pa_Table to_arrow(Table tbl, list metadata) + +cpdef pa_Scalar to_arrow_scalar(Scalar slr, ColumnMetadata metadata) diff --git a/python/cudf/cudf/_lib/pylibcudf/interop.pyx b/python/cudf/cudf/_lib/pylibcudf/interop.pyx new file mode 100644 index 00000000000..9e6b44b117d --- /dev/null +++ b/python/cudf/cudf/_lib/pylibcudf/interop.pyx @@ -0,0 +1,98 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. 
+ +from cython.operator cimport dereference +from libcpp.memory cimport shared_ptr, unique_ptr +from libcpp.utility cimport move +from libcpp.vector cimport vector +from pyarrow.lib cimport ( + CScalar as pa_CScalar, + CTable as pa_CTable, + Scalar as pa_Scalar, + Table as pa_Table, + pyarrow_unwrap_scalar, + pyarrow_unwrap_table, + pyarrow_wrap_scalar, + pyarrow_wrap_table, +) + +from cudf._lib.cpp.interop cimport ( + column_metadata, + from_arrow as cpp_from_arrow, + to_arrow as cpp_to_arrow, +) +from cudf._lib.cpp.scalar.scalar cimport scalar +from cudf._lib.cpp.table.table cimport table + +from .scalar cimport Scalar +from .table cimport Table + + +cdef class ColumnMetadata: + def __init__(self, name): + self.name = name + self.children_meta = [] + + cdef column_metadata to_c_metadata(self): + """Convert to C++ column_metadata. + + Since this class is mutable and cheap, it is easier to create the C++ + object on the fly rather than have it directly backing the storage for + the Cython class. + """ + cdef column_metadata c_metadata + cdef ColumnMetadata child_meta + c_metadata.name = self.name.encode() + for child_meta in self.children_meta: + c_metadata.children_meta.push_back(child_meta.to_c_metadata()) + return c_metadata + + +cpdef Table from_arrow( + pa_Table pyarrow_table, +): + cdef shared_ptr[pa_CTable] ctable = ( + pyarrow_unwrap_table(pyarrow_table) + ) + cdef unique_ptr[table] c_result + + with nogil: + c_result = move(cpp_from_arrow(ctable.get()[0])) + + return Table.from_libcudf(move(c_result)) + + +cpdef Scalar from_arrow_scalar( + pa_Scalar pyarrow_scalar, +): + cdef shared_ptr[pa_CScalar] cscalar = ( + pyarrow_unwrap_scalar(pyarrow_scalar) + ) + cdef unique_ptr[scalar] c_result + + with nogil: + c_result = move(cpp_from_arrow(cscalar.get()[0])) + + return Scalar.from_libcudf(move(c_result)) + + +cpdef pa_Table to_arrow(Table tbl, list metadata): + cdef shared_ptr[pa_CTable] c_result + cdef vector[column_metadata] c_metadata + cdef ColumnMetadata meta + for meta in metadata: + c_metadata.push_back(meta.to_c_metadata()) + + with nogil: + c_result = move(cpp_to_arrow(tbl.view(), c_metadata)) + + return pyarrow_wrap_table(c_result) + + +cpdef pa_Scalar to_arrow_scalar(Scalar slr, ColumnMetadata metadata): + cdef shared_ptr[pa_CScalar] c_result + cdef column_metadata c_metadata = metadata.to_c_metadata() + + with nogil: + c_result = move(cpp_to_arrow(dereference(slr.c_obj.get()), c_metadata)) + + return pyarrow_wrap_scalar(c_result) diff --git a/python/cudf/cudf/_lib/pylibcudf/scalar.pxd b/python/cudf/cudf/_lib/pylibcudf/scalar.pxd new file mode 100644 index 00000000000..d20c65f0be0 --- /dev/null +++ b/python/cudf/cudf/_lib/pylibcudf/scalar.pxd @@ -0,0 +1,32 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. + +from libcpp cimport bool +from libcpp.memory cimport unique_ptr + +from rmm._lib.memory_resource cimport DeviceMemoryResource + +from cudf._lib.cpp.scalar.scalar cimport scalar + +from .types cimport DataType + + +cdef class Scalar: + cdef unique_ptr[scalar] c_obj + cdef DataType _data_type + + # Holds a reference to the DeviceMemoryResource used for allocation. + # Ensures the MR does not get destroyed before this DeviceBuffer. 
`mr` is + # needed for deallocation + cdef DeviceMemoryResource mr + + cdef const scalar* get(self) except * + + cpdef DataType type(self) + cpdef bool is_valid(self) + + @staticmethod + cdef Scalar from_libcudf(unique_ptr[scalar] libcudf_scalar, dtype=*) + + # TODO: Make sure I'm correct to avoid typing the metadata as + # ColumnMetadata, I assume that will cause circular cimport problems + cpdef to_pyarrow_scalar(self, metadata) diff --git a/python/cudf/cudf/_lib/pylibcudf/scalar.pyx b/python/cudf/cudf/_lib/pylibcudf/scalar.pyx new file mode 100644 index 00000000000..d06bb9adeb9 --- /dev/null +++ b/python/cudf/cudf/_lib/pylibcudf/scalar.pyx @@ -0,0 +1,119 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. + +cimport pyarrow.lib +from cython cimport no_gc_clear +from libcpp.memory cimport unique_ptr + +import pyarrow.lib + +from rmm._lib.memory_resource cimport get_current_device_resource + +from cudf._lib.cpp.scalar.scalar cimport fixed_point_scalar, scalar +from cudf._lib.cpp.wrappers.decimals cimport ( + decimal32, + decimal64, + decimal128, + scale_type, +) + +from .types cimport DataType, type_id + + +# The DeviceMemoryResource attribute could be released prematurely +# by the gc if the DeviceScalar is in a reference cycle. Removing +# the tp_clear function with the no_gc_clear decoration prevents that. +# See https://github.com/rapidsai/rmm/pull/931 for details. +@no_gc_clear +cdef class Scalar: + """A scalar value in device memory.""" + # Unlike for columns, libcudf does not support scalar views. All APIs that + # accept scalar values accept references to the owning object rather than a + # special view type. As a result, pylibcudf.Scalar has a simpler structure + # than pylibcudf.Column because it can be a true wrapper around a libcudf + # column + + def __cinit__(self, *args, **kwargs): + self.mr = get_current_device_resource() + + def __init__(self, pyarrow.lib.Scalar value=None): + # TODO: This case is not something we really want to + # support, but it here for now to ease the transition of + # DeviceScalar. + if value is not None: + raise ValueError("Scalar should be constructed with a factory") + + @staticmethod + def from_pyarrow_scalar(pyarrow.lib.Scalar value, DataType data_type=None): + # Allow passing a dtype, but only for the purpose of decimals for now + + # Need a local import here to avoid a circular dependency because + # from_arrow_scalar returns a Scalar. 
+        from .interop import from_arrow_scalar
+
+        cdef Scalar s = from_arrow_scalar(value)
+        if s.type().id() != type_id.DECIMAL128:
+            if data_type is not None:
+                raise ValueError(
+                    "dtype may not be passed for non-decimal types"
+                )
+            return s
+
+        if data_type is None:
+            raise ValueError(
+                "Decimal scalars must be constructed with a dtype"
+            )
+
+        cdef type_id tid = data_type.id()
+        if tid not in (type_id.DECIMAL32, type_id.DECIMAL64, type_id.DECIMAL128):
+            raise ValueError(
+                "Decimal scalars may only be cast to decimals"
+            )
+
+        if tid == type_id.DECIMAL128:
+            return s
+
+        if tid == type_id.DECIMAL32:
+            s.c_obj.reset(
+                new fixed_point_scalar[decimal32](
+                    (<fixed_point_scalar[decimal128]*> s.c_obj.get()).value(),
+                    scale_type(-value.type.scale),
+                    s.c_obj.get().is_valid()
+                )
+            )
+        elif tid == type_id.DECIMAL64:
+            s.c_obj.reset(
+                new fixed_point_scalar[decimal64](
+                    (<fixed_point_scalar[decimal128]*> s.c_obj.get()).value(),
+                    scale_type(-value.type.scale),
+                    s.c_obj.get().is_valid()
+                )
+            )
+        return s
+
+    cpdef to_pyarrow_scalar(self, metadata):
+        from .interop import to_arrow_scalar
+        return to_arrow_scalar(self, metadata)
+
+    cdef const scalar* get(self) except *:
+        return self.c_obj.get()
+
+    cpdef DataType type(self):
+        """The type of data in the scalar."""
+        return self._data_type
+
+    cpdef bool is_valid(self):
+        """True if the scalar is valid, false if not"""
+        return self.get().is_valid()
+
+    @staticmethod
+    cdef Scalar from_libcudf(unique_ptr[scalar] libcudf_scalar, dtype=None):
+        """Construct a Scalar object from a libcudf scalar.
+
+        This method is for pylibcudf's functions to use to ingest outputs of
+        calling libcudf algorithms, and should generally not be needed by users
+        (even direct pylibcudf Cython users).
+        """
+        cdef Scalar s = Scalar.__new__(Scalar)
+        s.c_obj.swap(libcudf_scalar)
+        s._data_type = DataType.from_libcudf(s.get().type())
+        return s
diff --git a/python/cudf/cudf/_lib/scalar.pxd b/python/cudf/cudf/_lib/scalar.pxd
index 1deed60d67d..ae1d350edc6 100644
--- a/python/cudf/cudf/_lib/scalar.pxd
+++ b/python/cudf/cudf/_lib/scalar.pxd
@@ -1,20 +1,16 @@
-# Copyright (c) 2020-2022, NVIDIA CORPORATION.
+# Copyright (c) 2020-2023, NVIDIA CORPORATION.
 
 from libcpp cimport bool
 from libcpp.memory cimport unique_ptr
 
 from rmm._lib.memory_resource cimport DeviceMemoryResource
 
+from cudf._lib cimport pylibcudf
 from cudf._lib.cpp.scalar.scalar cimport scalar
 
 
 cdef class DeviceScalar:
-    cdef unique_ptr[scalar] c_value
-
-    # Holds a reference to the DeviceMemoryResource used for allocation.
-    # Ensures the MR does not get destroyed before this DeviceBuffer. `mr` is
-    # needed for deallocation
-    cdef DeviceMemoryResource mr
+    cdef pylibcudf.Scalar c_value
 
     cdef object _dtype
diff --git a/python/cudf/cudf/_lib/scalar.pyx b/python/cudf/cudf/_lib/scalar.pyx
index 0407785b2d8..30d46bd3fe2 100644
--- a/python/cudf/cudf/_lib/scalar.pyx
+++ b/python/cudf/cudf/_lib/scalar.pyx
@@ -1,7 +1,6 @@
 # Copyright (c) 2020-2023, NVIDIA CORPORATION.
-cimport cython - +import copy import decimal import numpy as np @@ -22,14 +21,17 @@ from libcpp cimport bool from libcpp.memory cimport unique_ptr from libcpp.utility cimport move -from rmm._lib.memory_resource cimport get_current_device_resource - import cudf +from cudf._lib.pylibcudf.interop import ColumnMetadata + +from cudf._lib.pylibcudf.types cimport type_id + from cudf._lib.types import ( LIBCUDF_TO_SUPPORTED_NUMPY_TYPES, datetime_unit_map, duration_unit_map, ) +from cudf.api.types import is_list_dtype, is_struct_dtype from cudf.core.dtypes import ListDtype, StructDtype from cudf.core.missing import NA, NaT @@ -38,9 +40,10 @@ from cudf._lib.cpp.column.column_view cimport column_view from cudf._lib.cpp.table.table_view cimport table_view from cudf._lib.types cimport dtype_from_column_view, underlying_type_t_type_id -from cudf._lib.interop import from_arrow, to_arrow +from cudf._lib.interop import to_arrow cimport cudf._lib.cpp.types as libcudf_types +from cudf._lib cimport pylibcudf from cudf._lib.cpp.scalar.scalar cimport ( duration_scalar, fixed_point_scalar, @@ -51,12 +54,7 @@ from cudf._lib.cpp.scalar.scalar cimport ( struct_scalar, timestamp_scalar, ) -from cudf._lib.cpp.wrappers.decimals cimport ( - decimal32, - decimal64, - decimal128, - scale_type, -) +from cudf._lib.cpp.wrappers.decimals cimport decimal32, decimal64, decimal128 from cudf._lib.cpp.wrappers.durations cimport ( duration_ms, duration_ns, @@ -69,18 +67,67 @@ from cudf._lib.cpp.wrappers.timestamps cimport ( timestamp_s, timestamp_us, ) -from cudf._lib.utils cimport columns_from_table_view, table_view_from_columns +from cudf._lib.utils cimport columns_from_table_view + + +# TODO: Check if this could replace _nested_na_replace +def _replace_nested_nulls(obj): + if isinstance(obj, list): + for i, item in enumerate(obj): + # TODO: Check if this should use _is_null_host_scalar + if cudf.utils.utils.is_na_like(item): + obj[i] = None + elif isinstance(item, (dict, list)): + _replace_nested_nulls(item) + elif isinstance(obj, dict): + for k, v in obj.items(): + if cudf.utils.utils.is_na_like(v): + obj[k] = None + elif isinstance(v, (dict, list)): + _replace_nested_nulls(v) + + +def _replace_nested_none(obj): + if isinstance(obj, list): + for i, item in enumerate(obj): + if item is None: + obj[i] = NA + elif isinstance(item, (dict, list)): + _replace_nested_none(item) + elif isinstance(obj, dict): + for k, v in obj.items(): + if v is None: + obj[k] = NA + elif isinstance(v, (dict, list)): + _replace_nested_none(v) + + +def gather_metadata(dtypes): + # dtypes is a dict mapping names to column dtypes + # This interface is a bit clunky, but it matches libcudf. May want to + # consider better approaches to building up the metadata eventually. + out = [] + for name, dtype in dtypes.items(): + v = ColumnMetadata(name) + if is_struct_dtype(dtype): + v.children_meta = gather_metadata(dtype.fields) + elif is_list_dtype(dtype): + # Offsets column is unnamed and has no children + v.children_meta.append(ColumnMetadata("")) + v.children_meta.extend( + gather_metadata({"": dtype.element_type}) + ) + out.append(v) + return out -# The DeviceMemoryResource attribute could be released prematurely -# by the gc if the DeviceScalar is in a reference cycle. Removing -# the tp_clear function with the no_gc_clear decoration prevents that. -# See https://github.com/rapidsai/rmm/pull/931 for details. 
-@cython.no_gc_clear cdef class DeviceScalar: + # I think this should be removable, except that currently the way that + # from_unique_ptr is implemented is probably dereferencing this in an + # invalid state. See what the best way to fix that is. def __cinit__(self, *args, **kwargs): - self.mr = get_current_device_resource() + self.c_value = pylibcudf.Scalar() def __init__(self, value, dtype): """ @@ -96,63 +143,102 @@ cdef class DeviceScalar: dtype : dtype A NumPy dtype. """ - self._dtype = dtype if dtype.kind != 'U' else cudf.dtype('object') - self._set_value(value, self._dtype) - - def _set_value(self, value, dtype): - # IMPORTANT: this should only ever be called from __init__ - valid = not _is_null_host_scalar(value) + dtype = dtype if dtype.kind != 'U' else cudf.dtype('object') - if isinstance(dtype, cudf.core.dtypes.DecimalDtype): - _set_decimal_from_scalar( - self.c_value, value, dtype, valid) - elif isinstance(dtype, cudf.ListDtype): - _set_list_from_pylist( - self.c_value, value, dtype, valid) - elif isinstance(dtype, cudf.StructDtype): - _set_struct_from_pydict(self.c_value, value, dtype, valid) + if cudf.utils.utils.is_na_like(value): + value = None + else: + # TODO: For now we always deepcopy the input value to avoid + # overwriting the input values when replacing nulls. Since it's + # just host values it's not that expensive, but we could consider + # alternatives. + value = copy.deepcopy(value) + _replace_nested_nulls(value) + + if isinstance(dtype, cudf.core.dtypes._BaseDtype): + pa_type = dtype.to_arrow() elif pd.api.types.is_string_dtype(dtype): - _set_string_from_np_string(self.c_value, value, valid) - elif pd.api.types.is_numeric_dtype(dtype): - _set_numeric_from_np_scalar(self.c_value, - value, - dtype, - valid) - elif pd.api.types.is_datetime64_dtype(dtype): - _set_datetime64_from_np_scalar( - self.c_value, value, dtype, valid - ) - elif pd.api.types.is_timedelta64_dtype(dtype): - _set_timedelta64_from_np_scalar( - self.c_value, value, dtype, valid - ) + # Have to manually convert object types, which we use internally + # for strings but pyarrow only supports as unicode 'U' + pa_type = pa.string() else: - raise ValueError( - f"Cannot convert value of type " - f"{type(value).__name__} to cudf scalar" - ) + pa_type = pa.from_numpy_dtype(dtype) + + pa_scalar = pa.scalar(value, type=pa_type) + + cdef type_id tid + data_type = None + if isinstance(dtype, cudf.core.dtypes.DecimalDtype): + tid = type_id.DECIMAL128 + if isinstance(dtype, cudf.core.dtypes.Decimal32Dtype): + tid = type_id.DECIMAL32 + elif isinstance(dtype, cudf.core.dtypes.Decimal64Dtype): + tid = type_id.DECIMAL64 + data_type = pylibcudf.DataType(tid, -dtype.scale) + + self.c_value = pylibcudf.Scalar.from_pyarrow_scalar(pa_scalar, data_type) + self._dtype = dtype def _to_host_scalar(self): + metadata = gather_metadata({"": self.dtype})[0] if isinstance(self.dtype, cudf.core.dtypes.DecimalDtype): - result = _get_py_decimal_from_fixed_point(self.c_value) + ps = self.c_value.to_pyarrow_scalar(metadata) + if not ps.is_valid: + return NA + return ps.as_py() + # For nested types we should eventually account for the special cases + # that we handle for other types, e.g. numerics being casted to numpy + # types or datetime/timedelta needing to be cast to int64 to handle + # overflow. However, the old implementation didn't handle these cases + # either, so we can leave that for a follow-up PR. 
elif cudf.api.types.is_struct_dtype(self.dtype): - result = _get_py_dict_from_struct(self.c_value, self.dtype) + ps = self.c_value.to_pyarrow_scalar(metadata) + if not ps.is_valid: + return NA + ret = ps.as_py() + _replace_nested_none(ret) + return ret elif cudf.api.types.is_list_dtype(self.dtype): - result = _get_py_list_from_list(self.c_value, self.dtype) + ps = self.c_value.to_pyarrow_scalar(metadata) + if not ps.is_valid: + return NA + ret = ps.as_py() + _replace_nested_none(ret) + return ret elif pd.api.types.is_string_dtype(self.dtype): - result = _get_py_string_from_string(self.c_value) + ps = self.c_value.to_pyarrow_scalar(metadata) + if not ps.is_valid: + return NA + return ps.as_py() elif pd.api.types.is_numeric_dtype(self.dtype): - result = _get_np_scalar_from_numeric(self.c_value) + ps = self.c_value.to_pyarrow_scalar(metadata) + if not ps.is_valid: + return NA + return ps.type.to_pandas_dtype()(ps.as_py()) elif pd.api.types.is_datetime64_dtype(self.dtype): - result = _get_np_scalar_from_timestamp64(self.c_value) + ps = self.c_value.to_pyarrow_scalar(metadata) + if not ps.is_valid: + return NaT + time_unit, _ = np.datetime_data(self.dtype) + # Cast to int64 to avoid overflow + return np.datetime64(ps.cast('int64').as_py(), time_unit) elif pd.api.types.is_timedelta64_dtype(self.dtype): - result = _get_np_scalar_from_timedelta64(self.c_value) + ps = self.c_value.to_pyarrow_scalar(metadata) + if not ps.is_valid: + return NaT + time_unit, _ = np.datetime_data(self.dtype) + # Cast to int64 to avoid overflow + return np.timedelta64(ps.cast('int64').as_py(), time_unit) else: raise ValueError( "Could not convert cudf::scalar to a Python value" ) return result + # TODO: This is just here for testing and should be removed. + def get(self): + return self.c_value + @property def dtype(self): """ @@ -169,13 +255,13 @@ cdef class DeviceScalar: return self._to_host_scalar() cdef const scalar* get_raw_ptr(self) except *: - return self.c_value.get() + return self.c_value.c_obj.get() cpdef bool is_valid(self): """ Returns if the Scalar is valid or not(i.e., ). 
""" - return self.get_raw_ptr()[0].is_valid() + return self.c_value.is_valid() def __repr__(self): if cudf.utils.utils.is_na_like(self.value): @@ -194,7 +280,7 @@ cdef class DeviceScalar: cdef DeviceScalar s = DeviceScalar.__new__(DeviceScalar) cdef libcudf_types.data_type cdtype - s.c_value = move(ptr) + s.c_value = pylibcudf.Scalar.from_libcudf(move(ptr)) cdtype = s.get_raw_ptr()[0].type() if dtype is not None: @@ -236,42 +322,6 @@ cdef class DeviceScalar: return s -cdef _set_string_from_np_string(unique_ptr[scalar]& s, value, bool valid=True): - value = value if valid else "" - s.reset(new string_scalar(value.encode(), valid)) - - -cdef _set_numeric_from_np_scalar(unique_ptr[scalar]& s, - object value, - object dtype, - bool valid=True): - value = value if valid else 0 - if dtype == "int8": - s.reset(new numeric_scalar[int8_t](value, valid)) - elif dtype == "int16": - s.reset(new numeric_scalar[int16_t](value, valid)) - elif dtype == "int32": - s.reset(new numeric_scalar[int32_t](value, valid)) - elif dtype == "int64": - s.reset(new numeric_scalar[int64_t](value, valid)) - elif dtype == "uint8": - s.reset(new numeric_scalar[uint8_t](value, valid)) - elif dtype == "uint16": - s.reset(new numeric_scalar[uint16_t](value, valid)) - elif dtype == "uint32": - s.reset(new numeric_scalar[uint32_t](value, valid)) - elif dtype == "uint64": - s.reset(new numeric_scalar[uint64_t](value, valid)) - elif dtype == "float32": - s.reset(new numeric_scalar[float](value, valid)) - elif dtype == "float64": - s.reset(new numeric_scalar[double](value, valid)) - elif dtype == "bool": - s.reset(new numeric_scalar[bool](value, valid)) - else: - raise ValueError(f"dtype not supported: {dtype}") - - cdef _set_datetime64_from_np_scalar(unique_ptr[scalar]& s, object value, object dtype, @@ -324,62 +374,6 @@ cdef _set_timedelta64_from_np_scalar(unique_ptr[scalar]& s, else: raise ValueError(f"dtype not supported: {dtype}") -cdef _set_decimal_from_scalar(unique_ptr[scalar]& s, - object value, - object dtype, - bool valid=True): - value = cudf.utils.dtypes._decimal_to_int64(value) if valid else 0 - if isinstance(dtype, cudf.Decimal64Dtype): - s.reset( - new fixed_point_scalar[decimal64]( - np.int64(value), scale_type(-dtype.scale), valid - ) - ) - elif isinstance(dtype, cudf.Decimal32Dtype): - s.reset( - new fixed_point_scalar[decimal32]( - np.int32(value), scale_type(-dtype.scale), valid - ) - ) - elif isinstance(dtype, cudf.Decimal128Dtype): - s.reset( - new fixed_point_scalar[decimal128]( - value, scale_type(-dtype.scale), valid - ) - ) - else: - raise ValueError(f"dtype not supported: {dtype}") - -cdef _set_struct_from_pydict(unique_ptr[scalar]& s, - object value, - object dtype, - bool valid=True): - arrow_schema = dtype.to_arrow() - columns = [str(i) for i in range(len(arrow_schema))] - if valid: - pyarrow_table = pa.Table.from_arrays( - [ - pa.array([value[f.name]], from_pandas=True, type=f.type) - for f in arrow_schema - ], - names=columns - ) - else: - pyarrow_table = pa.Table.from_arrays( - [ - pa.array([NA], from_pandas=True, type=f.type) - for f in arrow_schema - ], - names=columns - ) - - data = from_arrow(pyarrow_table) - cdef table_view struct_view = table_view_from_columns(data) - - s.reset( - new struct_scalar(struct_view, valid) - ) - cdef _get_py_dict_from_struct(unique_ptr[scalar]& s, dtype): if not s.get()[0].is_valid(): return NA @@ -395,25 +389,6 @@ cdef _get_py_dict_from_struct(unique_ptr[scalar]& s, dtype): python_dict = table.to_pydict()["None"][0] return {k: 
_nested_na_replace([python_dict[k]])[0] for k in python_dict} -cdef _set_list_from_pylist(unique_ptr[scalar]& s, - object value, - object dtype, - bool valid=True): - - value = value if valid else [NA] - cdef Column col - if isinstance(dtype.element_type, ListDtype): - pa_type = dtype.element_type.to_arrow() - else: - pa_type = dtype.to_arrow().value_type - col = cudf.core.column.as_column( - pa.array(value, from_pandas=True, type=pa_type) - ) - cdef column_view col_view = col.view() - s.reset( - new list_scalar(col_view, valid) - ) - cdef _get_py_list_from_list(unique_ptr[scalar]& s, dtype): @@ -601,9 +576,9 @@ def _create_proxy_nat_scalar(dtype): if dtype.char in 'mM': nat = dtype.type('NaT').astype(dtype) if dtype.type == np.datetime64: - _set_datetime64_from_np_scalar(result.c_value, nat, dtype, True) + _set_datetime64_from_np_scalar(result.c_value.c_obj, nat, dtype, True) elif dtype.type == np.timedelta64: - _set_timedelta64_from_np_scalar(result.c_value, nat, dtype, True) + _set_timedelta64_from_np_scalar(result.c_value.c_obj, nat, dtype, True) return result else: raise TypeError('NAT only valid for datetime and timedelta') diff --git a/python/cudf/cudf/tests/test_list.py b/python/cudf/cudf/tests/test_list.py index 5dd58d8a875..ac10dd97c56 100644 --- a/python/cudf/cudf/tests/test_list.py +++ b/python/cudf/cudf/tests/test_list.py @@ -895,14 +895,14 @@ def test_memory_usage(): "data, idx", [ ( - [[{"f2": {"a": 100}, "f1": "a"}, {"f1": "sf12", "f2": None}]], + [[{"f2": {"a": 100}, "f1": "a"}, {"f1": "sf12", "f2": NA}]], 0, ), ( [ [ {"f2": {"a": 100, "c": 90, "f2": 10}, "f1": "a"}, - {"f1": "sf12", "f2": None}, + {"f1": "sf12", "f2": NA}, ] ], 0, diff --git a/python/cudf/cudf/tests/test_struct.py b/python/cudf/cudf/tests/test_struct.py index a3593e55b97..ce6dc587320 100644 --- a/python/cudf/cudf/tests/test_struct.py +++ b/python/cudf/cudf/tests/test_struct.py @@ -150,9 +150,7 @@ def test_struct_setitem(data, item): "data", [ {"a": 1, "b": "rapids", "c": [1, 2, 3, 4]}, - {"a": 1, "b": "rapids", "c": [1, 2, 3, 4], "d": cudf.NA}, {"a": "Hello"}, - {"b": [], "c": [1, 2, 3]}, ], ) def test_struct_scalar_host_construction(data): @@ -161,6 +159,39 @@ def test_struct_scalar_host_construction(data): assert list(slr.device_value.value.values()) == list(data.values()) +@pytest.mark.parametrize( + ("data", "dtype"), + [ + ( + {"a": 1, "b": "rapids", "c": [1, 2, 3, 4], "d": cudf.NA}, + cudf.StructDtype( + { + "a": np.dtype(np.int64), + "b": np.dtype(np.str_), + "c": cudf.ListDtype(np.dtype(np.int64)), + "d": np.dtype(np.int64), + } + ), + ), + ( + {"b": [], "c": [1, 2, 3]}, + cudf.StructDtype( + { + "b": cudf.ListDtype(np.dtype(np.int64)), + "c": cudf.ListDtype(np.dtype(np.int64)), + } + ), + ), + ], +) +def test_struct_scalar_host_construction_no_dtype_inference(data, dtype): + # cudf cannot infer the dtype of the scalar when it contains only nulls or + # is empty. 
+    slr = cudf.Scalar(data, dtype=dtype)
+    assert slr.value == data
+    assert list(slr.device_value.value.values()) == list(data.values())
+
+
 def test_struct_scalar_null():
     slr = cudf.Scalar(cudf.NA, dtype=StructDtype)
     assert slr.device_value.value is cudf.NA
diff --git a/python/cudf/cudf/utils/dtypes.py b/python/cudf/cudf/utils/dtypes.py
index 1b94db75340..73ea8e2cfc4 100644
--- a/python/cudf/cudf/utils/dtypes.py
+++ b/python/cudf/cudf/utils/dtypes.py
@@ -463,24 +463,6 @@ def _get_nan_for_dtype(dtype):
         return np.float64("nan")
 
 
-def _decimal_to_int64(decimal: Decimal) -> int:
-    """
-    Scale a Decimal such that the result is the integer
-    that would result from removing the decimal point.
-
-    Examples
-    --------
-    >>> _decimal_to_int64(Decimal('1.42'))
-    142
-    >>> _decimal_to_int64(Decimal('0.0042'))
-    42
-    >>> _decimal_to_int64(Decimal('-1.004201'))
-    -1004201
-
-    """
-    return int(f"{decimal:0f}".replace(".", ""))
-
-
 def get_allowed_combinations_for_operator(dtype_l, dtype_r, op):
     error = TypeError(
         f"{op} not supported between {dtype_l} and {dtype_r} scalars"
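
For context, the sketch below shows roughly how the scalar interop pieces introduced by this patch are meant to fit together from Python. It is illustrative only and not part of the diff; the import paths and the `ColumnMetadata`, `from_pyarrow_scalar`, and `to_pyarrow_scalar` names are taken from the changes above, and it assumes a cudf build that contains them.

```python
# Minimal sketch: pyarrow scalar -> libcudf scalar -> pyarrow scalar.
# Assumes a cudf build containing this patch; names follow the diff above.
import pyarrow as pa

from cudf._lib import pylibcudf
from cudf._lib.pylibcudf.interop import ColumnMetadata

pa_slr = pa.scalar(42, type=pa.int64())

# Host pyarrow scalar -> device cudf::scalar, wrapped by pylibcudf.Scalar
slr = pylibcudf.Scalar.from_pyarrow_scalar(pa_slr)
assert slr.is_valid()

# Device scalar -> host pyarrow scalar; the metadata carries column/child names
out = slr.to_pyarrow_scalar(ColumnMetadata(""))
assert out.as_py() == 42
```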