diff --git a/.asf.yaml b/.asf.yaml
index 646bdac5f912..968c6779215a 100644
--- a/.asf.yaml
+++ b/.asf.yaml
@@ -15,6 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 
+# Documentation can be found here:
+# https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=127405038
+
 notifications:
   commits: commits@arrow.apache.org
   issues: github@arrow.apache.org
@@ -28,4 +31,11 @@ github:
     merge: false
     rebase: false
   features:
-    issues: true
\ No newline at end of file
+    issues: true
+  protected_branches:
+    master:
+      required_status_checks:
+        # require branches to be up-to-date before merging
+        strict: true
+        # don't require any jobs to pass
+        contexts: []
\ No newline at end of file
diff --git a/.github/actions/setup-builder/action.yaml b/.github/actions/setup-builder/action.yaml
new file mode 100644
index 000000000000..0157caf8c296
--- /dev/null
+++ b/.github/actions/setup-builder/action.yaml
@@ -0,0 +1,64 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+name: Prepare Rust Builder
+description: 'Prepare Rust Build Environment'
+inputs:
+  rust-version:
+    description: 'version of rust to install (e.g. stable)'
+    required: true
+    default: 'stable'
+runs:
+  using: "composite"
+  steps:
+    - name: Cache Cargo
+      uses: actions/cache@v3
+      with:
+        # these represent dependencies downloaded by cargo
+        # and thus do not depend on the OS, arch nor rust version.
+        #
+        # source https://github.com/actions/cache/blob/main/examples.md#rust---cargo
+        path: |
+          /usr/local/cargo/bin/
+          /usr/local/cargo/registry/index/
+          /usr/local/cargo/registry/cache/
+          /usr/local/cargo/git/db/
+        key: cargo-cache3-${{ hashFiles('**/Cargo.toml') }}
+        restore-keys: cargo-cache3-
+    - name: Generate lockfile
+      shell: bash
+      run: cargo fetch
+    - name: Cache Rust dependencies
+      uses: actions/cache@v3
+      with:
+        # these represent compiled steps of both dependencies and arrow
+        # and thus are specific for a particular OS, arch and rust version.
+        path: /github/home/target
+        key: ${{ runner.os }}-${{ runner.arch }}-target-cache3-${{ inputs.rust-version }}-${{ hashFiles('**/Cargo.lock') }}
+        restore-keys: ${{ runner.os }}-${{ runner.arch }}-target-cache3-${{ inputs.rust-version }}-
+    - name: Install Build Dependencies
+      shell: bash
+      run: |
+        apt-get update
+        apt-get install -y protobuf-compiler
+    - name: Setup Rust toolchain
+      shell: bash
+      run: |
+        echo "Installing ${{ inputs.rust-version }}"
+        rustup toolchain install ${{ inputs.rust-version }}
+        rustup default ${{ inputs.rust-version }}
+        echo "CARGO_TARGET_DIR=/github/home/target" >> $GITHUB_ENV
diff --git a/.github/dependabot.yml b/.github/dependabot.yml
new file mode 100644
index 000000000000..9bd42dbaa0d6
--- /dev/null
+++ b/.github/dependabot.yml
@@ -0,0 +1,9 @@
+version: 2
+updates:
+  - package-ecosystem: cargo
+    directory: "/"
+    schedule:
+      interval: daily
+    open-pull-requests-limit: 10
+    target-branch: master
+    labels: [auto-dependencies]
diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml
index 41b1dcbe8eb9..7eed6b8e94c9 100644
--- a/.github/workflows/integration.yml
+++ b/.github/workflows/integration.yml
@@ -39,7 +39,7 @@ jobs:
           path: rust
           fetch-depth: 0
       - name: Setup Python
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v3
         with:
           python-version: 3.8
       - name: Setup Archery
@@ -64,17 +64,17 @@
           rustup default ${{ matrix.rust }}
           rustup component add rustfmt clippy
       - name: Cache Cargo
-        uses: actions/cache@v2
+        uses: actions/cache@v3
         with:
           path: /home/runner/.cargo
           key: cargo-maturin-cache-
       - name: Cache Rust dependencies
-        uses: actions/cache@v2
+        uses: actions/cache@v3
         with:
           path: /home/runner/target
           # this key is not equal because maturin uses different compilation flags.
           key: ${{ runner.os }}-${{ matrix.arch }}-target-maturin-cache-${{ matrix.rust }}-
-      - uses: actions/setup-python@v2
+      - uses: actions/setup-python@v3
         with:
           python-version: '3.7'
       - name: Upgrade pip and setuptools
diff --git a/.github/workflows/miri.sh b/.github/workflows/miri.sh
new file mode 100755
index 000000000000..56da5c5c5d3e
--- /dev/null
+++ b/.github/workflows/miri.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+#
+# Script
+#
+# Must be run with nightly rust for example
+# rustup default nightly
+
+
+# stacked borrows checking uses too much memory to run successfully in github actions
+# re-enable if the CI is migrated to something more powerful (https://github.com/apache/arrow-rs/issues/1833)
+# see also https://github.com/rust-lang/miri/issues/1367
+export MIRIFLAGS="-Zmiri-disable-isolation -Zmiri-disable-stacked-borrows"
+cargo miri setup
+cargo clean
+
+echo "Starting Arrow MIRI run..."
+cargo miri test -p arrow -- --skip csv --skip ipc --skip json
diff --git a/.github/workflows/miri.yaml b/.github/workflows/miri.yaml
index 136b0e136008..7feacc07dd73 100644
--- a/.github/workflows/miri.yaml
+++ b/.github/workflows/miri.yaml
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-name: Rust
+name: MIRI
 
 on:
   # always trigger
@@ -23,38 +23,21 @@
   pull_request:
 
 jobs:
-
   miri-checks:
     name: MIRI
     runs-on: ubuntu-latest
-    strategy:
-      matrix:
-        arch: [amd64]
-        rust: [nightly-2021-07-04]
     steps:
       - uses: actions/checkout@v2
        with:
          submodules: true
-      - uses: actions/cache@v2
-        with:
-          path: |
-            ~/.cargo/registry
-            ~/.cargo/git
-            target
-          key: ${{ runner.os }}-cargo-miri-${{ hashFiles('**/Cargo.lock') }}
      - name: Setup Rust toolchain
        run: |
-         rustup toolchain install ${{ matrix.rust }}
-         rustup default ${{ matrix.rust }}
-         rustup component add rustfmt clippy miri
+         rustup toolchain install nightly --component miri
+         rustup override set nightly
+         cargo miri setup
      - name: Run Miri Checks
        env:
          RUST_BACKTRACE: full
-         RUST_LOG: 'trace'
+         RUST_LOG: "trace"
        run: |
-         export MIRIFLAGS="-Zmiri-disable-isolation"
-         cargo miri setup
-         cargo clean
-         # Currently only the arrow crate is tested with miri
-         # IO related tests and some unsupported tests are skipped
-         cargo miri test -p arrow -- --skip csv --skip ipc --skip json
+         bash .github/workflows/miri.sh
diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml
index 20392efa5102..9331db745659 100644
--- a/.github/workflows/rust.yml
+++ b/.github/workflows/rust.yml
@@ -30,8 +30,8 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        arch: [amd64]
-        rust: [stable]
+        arch: [ amd64 ]
+        rust: [ stable ]
     container:
       image: ${{ matrix.arch }}/rust
       env:
@@ -40,40 +40,23 @@
         RUSTFLAGS: "-C debuginfo=1"
     steps:
       - uses: actions/checkout@v2
-      - name: Cache Cargo
-        uses: actions/cache@v2
-        with:
-          # these represent dependencies downloaded by cargo
-          # and thus do not depend on the OS, arch nor rust version.
-          path: /github/home/.cargo
-          key: cargo-cache2-
-      - name: Cache Rust dependencies
-        uses: actions/cache@v2
-        with:
-          # these represent compiled steps of both dependencies and arrow
-          # and thus are specific for a particular OS, arch and rust version.
-          path: /github/home/target
-          key: ${{ runner.os }}-${{ matrix.arch }}-target-cache-${{ matrix.rust }}-
       - name: Setup Rust toolchain
-        run: |
-          rustup toolchain install ${{ matrix.rust }}
-          rustup default ${{ matrix.rust }}
-          rustup component add rustfmt
+        uses: ./.github/actions/setup-builder
+        with:
+          rust-version: ${{ matrix.rust }}
      - name: Build Workspace
        run: |
-         export CARGO_HOME="/github/home/.cargo"
-         export CARGO_TARGET_DIR="/github/home/target"
          cargo build
 
   # test the crate
   linux-test:
     name: Test Workspace on AMD64 Rust ${{ matrix.rust }}
-    needs: [linux-build-lib]
+    needs: [ linux-build-lib ]
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        arch: [amd64]
-        rust: [stable]
+        arch: [ amd64 ]
+        rust: [ stable ]
     container:
       image: ${{ matrix.arch }}/rust
       env:
@@ -86,43 +69,52 @@
      - uses: actions/checkout@v2
        with:
          submodules: true
-      - name: Cache Cargo
-        uses: actions/cache@v2
-        with:
-          path: /github/home/.cargo
-          # this key equals the ones on `linux-build-lib` for re-use
-          key: cargo-cache2-
-      - name: Cache Rust dependencies
-        uses: actions/cache@v2
-        with:
-          path: /github/home/target
-          # this key equals the ones on `linux-build-lib` for re-use
-          key: ${{ runner.os }}-${{ matrix.arch }}-target-cache-${{ matrix.rust }}
      - name: Setup Rust toolchain
-        run: |
-          rustup toolchain install ${{ matrix.rust }}
-          rustup default ${{ matrix.rust }}
-          rustup component add rustfmt
+        uses: ./.github/actions/setup-builder
+        with:
+          rust-version: ${{ matrix.rust }}
      - name: Run tests
        run: |
-         export CARGO_HOME="/github/home/.cargo"
-         export CARGO_TARGET_DIR="/github/home/target"
          # run tests on all workspace members with default feature list
          cargo test
-         cd arrow
-         # re-run tests on arrow workspace with additional features
-         cargo test --features=prettyprint
-         # run test on arrow with minimal set of features
-         cargo test --no-default-features
+      - name: Re-run tests with all supported features
+        run: |
+          cargo test -p arrow --features=force_validate,prettyprint
+      - name: Run examples
+        run: |
+          # Test arrow examples
          cargo run --example builders
          cargo run --example dynamic_types
          cargo run --example read_csv
          cargo run --example read_csv_infer_schema
-         # Exit arrow directory
-         cd ..
-         (cd parquet && cargo check --no-default-features)
-         (cd arrow && cargo check --no-default-features)
-         (cd arrow-flight && cargo check --no-default-features)
+      - name: Test compilation of arrow library crate with different feature combinations
+        run: |
+          cargo check -p arrow
+          cargo check -p arrow --no-default-features
+      - name: Test compilation of arrow targets with different feature combinations
+        run: |
+          cargo check -p arrow --all-targets
+          cargo check -p arrow --no-default-features --all-targets
+          cargo check -p arrow --no-default-features --all-targets --features test_utils
+      - name: Re-run tests on arrow-flight with all features
+        run: |
+          cargo test -p arrow-flight --all-features
+      - name: Re-run tests on parquet crate with all features
+        run: |
+          cargo test -p parquet --all-features
+      - name: Test compilation of parquet library crate with different feature combinations
+        run: |
+          cargo check -p parquet
+          cargo check -p parquet --no-default-features
+          cargo check -p parquet --no-default-features --features arrow
+      - name: Test compilation of parquet targets with different feature combinations
+        run: |
+          cargo check -p parquet --all-targets
+          cargo check -p parquet --no-default-features --all-targets
+          cargo check -p parquet --no-default-features --features arrow --all-targets
+      - name: Test compilation of parquet_derive macro with different feature combinations
+        run: |
+          cargo check -p parquet_derive
 
   # test the --features "simd" of the arrow crate. This requires nightly.
   linux-test-simd:
@@ -130,8 +122,8 @@
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        arch: [amd64]
-        rust: [nightly]
+        arch: [ amd64 ]
+        rust: [ nightly ]
     container:
       image: ${{ matrix.arch }}/rust
       env:
@@ -143,43 +135,25 @@
      - uses: actions/checkout@v2
        with:
          submodules: true
-      - name: Cache Cargo
-        uses: actions/cache@v2
-        with:
-          path: /github/home/.cargo
-          # this key equals the ones on `linux-build-lib` for re-use
-          key: cargo-cache2-
-      - name: Cache Rust dependencies
-        uses: actions/cache@v2
-        with:
-          path: /github/home/target
-          # this key equals the ones on `linux-build-lib` for re-use
-          key: ${{ runner.os }}-${{ matrix.arch }}-target-cache-${{ matrix.rust }}
      - name: Setup Rust toolchain
-        run: |
-          rustup toolchain install ${{ matrix.rust }}
-          rustup default ${{ matrix.rust }}
-          rustup component add rustfmt
+        uses: ./.github/actions/setup-builder
+        with:
+          rust-version: ${{ matrix.rust }}
      - name: Run tests
        run: |
-         export CARGO_HOME="/github/home/.cargo"
-         export CARGO_TARGET_DIR="/github/home/target"
-         cd arrow
-         cargo test --features "simd"
-      - name: Check new project build with simd features
+          cargo test -p arrow --features "simd"
+      - name: Check compilation with simd features
        run: |
-         export CARGO_HOME="/github/home/.cargo"
-         export CARGO_TARGET_DIR="/github/home/target"
-         cd arrow/test/dependency/simd
-         cargo check
+          cargo check -p arrow --features simd
+          cargo check -p arrow --features simd --all-targets
 
   windows-and-macos:
     name: Test on ${{ matrix.os }} Rust ${{ matrix.rust }}
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        os: [windows-latest, macos-latest]
-        rust: [stable]
+        os: [ windows-latest, macos-latest ]
+        rust: [ stable ]
     steps:
       - uses: actions/checkout@v2
         with:
@@ -190,7 +164,6 @@
        run: |
          rustup toolchain install ${{ matrix.rust }}
          rustup default ${{ matrix.rust }}
-         rustup component add rustfmt
      - name: Run tests
        shell: bash
        run: |
@@ -202,12 +175,12 @@
 
   clippy:
     name: Clippy
-    needs: [linux-build-lib]
+    needs: [ linux-build-lib ]
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        arch: [amd64]
-        rust: [stable]
+        arch: [ amd64 ]
+        rust: [ stable ]
     container:
       image: ${{ matrix.arch }}/rust
       env:
@@ -218,31 +191,44 @@
      - uses: actions/checkout@v2
        with:
          submodules: true
-      - name: Cache Cargo
-        uses: actions/cache@v2
-        with:
-          path: /github/home/.cargo
-          # this key equals the ones on `linux-build-lib` for re-use
-          key: cargo-cache2-
-      - name: Cache Rust dependencies
-        uses: actions/cache@v2
-        with:
-          path: /github/home/target
-          # this key equals the ones on `linux-build-lib` for re-use
-          key: ${{ runner.os }}-${{ matrix.arch }}-target-cache-${{ matrix.rust }}
      - name: Setup Rust toolchain
+        uses: ./.github/actions/setup-builder
+        with:
+          rust-version: ${{ matrix.rust }}
+      - name: Setup Clippy
        run: |
-         rustup toolchain install ${{ matrix.rust }}
-         rustup default ${{ matrix.rust }}
-         rustup component add rustfmt clippy
+          rustup component add clippy
      - name: Run clippy
        run: |
-         export CARGO_HOME="/github/home/.cargo"
-         export CARGO_TARGET_DIR="/github/home/target"
-         cargo clippy --features test_common --all-targets --workspace -- -D warnings -A clippy::redundant_field_names
+          cargo clippy --features test_common --features prettyprint --features=async --all-targets --workspace -- -D warnings
+
+  check_benches:
+    name: Check Benchmarks (but don't run them)
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        arch: [ amd64 ]
+        rust: [ stable ]
+    container:
+      image: ${{ matrix.arch }}/rust
+      env:
+        # Disable full debug symbol generation to speed up CI build and keep memory down
+        # "1" means line tables only, which is useful for panic tracebacks.
+        RUSTFLAGS: "-C debuginfo=1"
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          submodules: true
+      - name: Setup Rust toolchain
+        uses: ./.github/actions/setup-builder
+        with:
+          rust-version: ${{ matrix.rust }}
+      - name: Check benchmarks
+        run: |
+          cargo check --benches --workspace --features test_common,prettyprint,async,experimental
 
   lint:
-    name: Lint
+    name: Lint (cargo fmt)
     runs-on: ubuntu-latest
     container:
       image: amd64/rust
@@ -261,24 +247,28 @@
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        arch: [amd64]
-        rust: [stable]
+        arch: [ amd64 ]
+        rust: [ stable ]
     steps:
       - uses: actions/checkout@v2
         with:
           submodules: true
+      - name: Setup Rust toolchain
+        run: |
+          rustup toolchain install ${{ matrix.rust }}
+          rustup default ${{ matrix.rust }}
      - name: Cache Cargo
-        uses: actions/cache@v2
+        uses: actions/cache@v3
        with:
          path: /home/runner/.cargo
          # this key is not equal because the user is different than on a container (runner vs github)
-         key: cargo-coverage-cache-
+          key: cargo-coverage-cache3-
      - name: Cache Rust dependencies
-        uses: actions/cache@v2
+        uses: actions/cache@v3
        with:
          path: /home/runner/target
          # this key is not equal because coverage uses different compilation flags.
-         key: ${{ runner.os }}-${{ matrix.arch }}-target-coverage-cache-${{ matrix.rust }}-
+          key: ${{ runner.os }}-${{ matrix.arch }}-target-coverage-cache3-${{ matrix.rust }}-
      - name: Run coverage
        run: |
          export CARGO_HOME="/home/runner/.cargo"
@@ -287,6 +277,8 @@
          export ARROW_TEST_DATA=$(pwd)/testing/data
          export PARQUET_TEST_DATA=$(pwd)/parquet-testing/data
 
+          rustup toolchain install stable
+          rustup default stable
          cargo install --version 0.18.2 cargo-tarpaulin
          cargo tarpaulin --all --out Xml
      - name: Report coverage
@@ -299,8 +291,8 @@
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        arch: [amd64]
-        rust: [nightly]
+        arch: [ amd64 ]
+        rust: [ nightly ]
     container:
       image: ${{ matrix.arch }}/rust
       env:
@@ -314,78 +306,54 @@
        with:
          submodules: true
      - name: Cache Cargo
-        uses: actions/cache@v2
+        uses: actions/cache@v3
        with:
          path: /github/home/.cargo
-          # this key equals the ones on `linux-build-lib` for re-use
-          key: cargo-cache2-
+          key: cargo-wasm32-cache3-
      - name: Cache Rust dependencies
-        uses: actions/cache@v2
+        uses: actions/cache@v3
        with:
          path: /github/home/target
-          key: ${{ runner.os }}-${{ matrix.arch }}-target-wasm32-cache-${{ matrix.rust }}
-      - name: Setup Rust toolchain
+          key: ${{ runner.os }}-${{ matrix.arch }}-target-wasm32-cache3-${{ matrix.rust }}
+      - name: Setup Rust toolchain for WASM
        run: |
          rustup toolchain install ${{ matrix.rust }}
          rustup override set ${{ matrix.rust }}
-         rustup component add rustfmt
          rustup target add wasm32-unknown-unknown
          rustup target add wasm32-wasi
      - name: Build arrow crate
        run: |
-         export CARGO_HOME="/github/home/.cargo"
-         export CARGO_TARGET_DIR="/github/home/target"
          cd arrow
          cargo build --no-default-features --features=csv,ipc,simd --target wasm32-unknown-unknown
          cargo build --no-default-features --features=csv,ipc,simd --target wasm32-wasi
 
-  # test builds with various feature flag combinations outside the main workspace
-  default-build:
-    name: Feature Flag Builds ${{ matrix.rust }}
+  # test doc links still work
+  docs:
+    name: Docs are clean on AMD64 Rust ${{ matrix.rust }}
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        arch: [amd64]
-        rust: [stable]
+        arch: [ amd64 ]
+        rust: [ nightly ]
     container:
       image: ${{ matrix.arch }}/rust
       env:
-        # Disable debug symbol generation to speed up CI build and keep memory down
-        RUSTFLAGS: "-C debuginfo=0"
+        # Disable full debug symbol generation to speed up CI build and keep memory down
+        # "1" means line tables only, which is useful for panic tracebacks.
+        RUSTFLAGS: "-C debuginfo=1"
+        RUSTDOCFLAGS: "-Dwarnings"
     steps:
       - uses: actions/checkout@v2
-      - name: Cache Cargo
-        uses: actions/cache@v2
-        with:
-          path: /github/home/.cargo
-          # this key equals the ones on `linux-build-lib` for re-use
-          key: cargo-cache2-
-      - name: Cache Rust dependencies
-        uses: actions/cache@v2
        with:
-          path: /github/home/target
-          # this key equals the ones on `linux-build-lib` for re-use
-          key: ${{ runner.os }}-${{ matrix.arch }}-target-cache-${{ matrix.rust }}
-      - name: Setup Rust toolchain
-        run: |
-          rustup toolchain install ${{ matrix.rust }}
-          rustup override set ${{ matrix.rust }}
-          rustup component add rustfmt
-      - name: Arrow Build with default features
-        run: |
-          export CARGO_HOME="/github/home/.cargo"
-          export CARGO_TARGET_DIR="/github/home/target"
-          cd arrow/test/dependency/default-features
-          cargo check
-      - name: Arrow Build with default-features=false
+          submodules: true
+      - name: Install python dev
        run: |
-         export CARGO_HOME="/github/home/.cargo"
-         export CARGO_TARGET_DIR="/github/home/target"
-         cd arrow/test/dependency/no-default-features
-         cargo check
-      - name: Parquet Derive build with default-features
+          apt update
+          apt install -y libpython3.9-dev
+      - name: Setup Rust toolchain
+        uses: ./.github/actions/setup-builder
+        with:
+          rust-version: ${{ matrix.rust }}
+      - name: Run cargo doc
        run: |
-         export CARGO_HOME="/github/home/.cargo"
-         export CARGO_TARGET_DIR="/github/home/target"
-         cd parquet_derive/test/dependency/default-features
-         cargo check
+          cargo doc --document-private-items --no-deps --workspace --all-features
diff --git a/.github_changelog_generator b/.github_changelog_generator
index b02125c16848..cc23a6332d60 100644
--- a/.github_changelog_generator
+++ b/.github_changelog_generator
@@ -18,10 +18,8 @@
 # under the License.
 #
 
-# point to the old changelog in apache/arrow
-front-matter=For older versions, see [apache/arrow/CHANGELOG.md](https://github.com/apache/arrow/blob/master/CHANGELOG.md)\n
-# some issues are just documentation
-add-sections={"documentation":{"prefix":"**Documentation updates:**","labels":["documentation"]}}
+# Add special sections for documentation, security and performance
+add-sections={"documentation":{"prefix":"**Documentation updates:**","labels":["documentation"]},"security":{"prefix":"**Security updates:**","labels":["security"]},"performance":{"prefix":"**Performance improvements:**","labels":["performance"]}}
 # uncomment to not show PRs. TBD if we shown them or not.
 #pull-requests=false
 # so that the component is shown associated with the issue
diff --git a/.gitignore b/.gitignore
index 8c158a246328..2088dd5d2068 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,4 +7,7 @@ rusty-tags.vi
 .vscode
 venv/*
 # created by doctests
-parquet/data.parquet
\ No newline at end of file
+parquet/data.parquet
+# release notes cache
+.githubchangeloggenerator.cache
+.githubchangeloggenerator.cache.log
\ No newline at end of file
diff --git a/CHANGELOG-old.md b/CHANGELOG-old.md
new file mode 100644
index 000000000000..518697ce09a0
--- /dev/null
+++ b/CHANGELOG-old.md
@@ -0,0 +1,1311 @@
+
+
+
+## [15.0.0](https://github.com/apache/arrow-rs/tree/15.0.0) (2022-05-27)
+
+[Full Changelog](https://github.com/apache/arrow-rs/compare/14.0.0...15.0.0)
+
+**Breaking changes:**
+
+- Change `ArrayDataBuilder::null_bit_buffer` to accept `Option` rather than `Buffer` [\#1739](https://github.com/apache/arrow-rs/pull/1739) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670))
+- Remove `null_count` from `ArrayData::try_new()` [\#1721](https://github.com/apache/arrow-rs/pull/1721) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670))
+- Change parquet writers to use standard `std:io::Write` rather custom `ParquetWriter` trait \(\#1717\) \(\#1163\) [\#1719](https://github.com/apache/arrow-rs/pull/1719) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold))
+- Add explicit column mask for selection in parquet: `ProjectionMask` \(\#1701\) [\#1716](https://github.com/apache/arrow-rs/pull/1716) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold))
+- Add type\_ids in Union datatype [\#1703](https://github.com/apache/arrow-rs/pull/1703) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya))
+- Fix Parquet Reader's Arrow Schema Inference [\#1682](https://github.com/apache/arrow-rs/pull/1682) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold))
+
+**Implemented enhancements:**
+
+- Rename the `string` kernel to `concatenate_elements` [\#1747](https://github.com/apache/arrow-rs/issues/1747) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)]
+- `ArrayDataBuilder::null_bit_buffer` should accept `Option` as input type [\#1737](https://github.com/apache/arrow-rs/issues/1737) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)]
+- Fix schema comparison for non\_canonical\_map when running flight test [\#1730](https://github.com/apache/arrow-rs/issues/1730) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)]
+- Add support in aggregate kernel for `BinaryArray` [\#1724](https://github.com/apache/arrow-rs/issues/1724) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)]
+- Fix incorrect null\_count in `generate_unions_case` integration test [\#1712](https://github.com/apache/arrow-rs/issues/1712) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)]
+- Keep type ids in Union datatype to follow Arrow spec and integrate with other implementations [\#1690](https://github.com/apache/arrow-rs/issues/1690) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)]
+- Support Reading Alternative List Representations to Arrow From Parquet
[\#1680](https://github.com/apache/arrow-rs/issues/1680) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Speed up the offsets checking [\#1675](https://github.com/apache/arrow-rs/issues/1675) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Separate Parquet -\> Arrow Schema Conversion From ArrayBuilder [\#1655](https://github.com/apache/arrow-rs/issues/1655) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Add `leaf_columns` argument to `ArrowReader::get_record_reader_by_columns` [\#1653](https://github.com/apache/arrow-rs/issues/1653) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Implement `string_concat` kernel [\#1540](https://github.com/apache/arrow-rs/issues/1540) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Improve Unit Test Coverage of ArrayReaderBuilder [\#1484](https://github.com/apache/arrow-rs/issues/1484) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Fixed bugs:** + +- Parquet write failure \(from record batches\) when data is nested two levels deep [\#1744](https://github.com/apache/arrow-rs/issues/1744) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- IPC reader may break on projection [\#1735](https://github.com/apache/arrow-rs/issues/1735) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Latest nightly fails to build with feature simd [\#1734](https://github.com/apache/arrow-rs/issues/1734) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Trying to write parquet file in parallel results in corrupt file [\#1717](https://github.com/apache/arrow-rs/issues/1717) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Roundtrip failure when using DELTA\_BINARY\_PACKED [\#1708](https://github.com/apache/arrow-rs/issues/1708) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `ArrayData::try_new` cannot always return expected error. 
[\#1707](https://github.com/apache/arrow-rs/issues/1707) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- "out of order projection is not supported" after Fix Parquet Arrow Schema Inference [\#1701](https://github.com/apache/arrow-rs/issues/1701) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Rust is not interoperability with C++ for IPC schemas with dictionaries [\#1694](https://github.com/apache/arrow-rs/issues/1694) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Incorrect Repeated Field Schema Inference [\#1681](https://github.com/apache/arrow-rs/issues/1681) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Parquet Treats Embedded Arrow Schema as Authoritative [\#1663](https://github.com/apache/arrow-rs/issues/1663) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- parquet\_to\_arrow\_schema\_by\_columns Incorrectly Handles Nested Types [\#1654](https://github.com/apache/arrow-rs/issues/1654) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Inconsistent Arrow Schema When Projecting Nested Parquet File [\#1652](https://github.com/apache/arrow-rs/issues/1652) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- StructArrayReader Cannot Handle Nested Lists [\#1651](https://github.com/apache/arrow-rs/issues/1651) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Bug \(`substring` kernel\): The null buffer is not aligned when `offset != 0` [\#1639](https://github.com/apache/arrow-rs/issues/1639) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Documentation updates:** + +- Parquet command line tool does not install "globally" [\#1710](https://github.com/apache/arrow-rs/issues/1710) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Improve integration test document to follow Arrow C++ repo CI [\#1742](https://github.com/apache/arrow-rs/pull/1742) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) + +**Merged pull requests:** + +- Test for list array equality with different offsets [\#1756](https://github.com/apache/arrow-rs/pull/1756) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Rename `string_concat` to `concat_elements_utf8` [\#1754](https://github.com/apache/arrow-rs/pull/1754) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Rename the `string` kernel to `concat_elements`. 
[\#1752](https://github.com/apache/arrow-rs/pull/1752) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Support writing nested lists to parquet [\#1746](https://github.com/apache/arrow-rs/pull/1746) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Pin nightly version to bypass packed\_simd build error [\#1743](https://github.com/apache/arrow-rs/pull/1743) ([viirya](https://github.com/viirya)) +- Fix projection in IPC reader [\#1736](https://github.com/apache/arrow-rs/pull/1736) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([iyupeng](https://github.com/iyupeng)) +- `cargo install` installs not globally [\#1732](https://github.com/apache/arrow-rs/pull/1732) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([kazuk](https://github.com/kazuk)) +- Fix schema comparison for non\_canonical\_map when running flight test [\#1731](https://github.com/apache/arrow-rs/pull/1731) ([viirya](https://github.com/viirya)) +- Add `min_binary` and `max_binary` aggregate kernels [\#1725](https://github.com/apache/arrow-rs/pull/1725) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Fix parquet benchmarks [\#1723](https://github.com/apache/arrow-rs/pull/1723) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Fix BitReader::get\_batch zero extension \(\#1708\) [\#1722](https://github.com/apache/arrow-rs/pull/1722) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Implementation string concat [\#1720](https://github.com/apache/arrow-rs/pull/1720) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Ismail-Maj](https://github.com/Ismail-Maj)) +- Check the length of `null_bit_buffer` in `ArrayData::try_new()` [\#1714](https://github.com/apache/arrow-rs/pull/1714) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Fix incorrect null\_count in `generate_unions_case` integration test [\#1713](https://github.com/apache/arrow-rs/pull/1713) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Fix: Null buffer accounts for `offset` in `substring` kernel. 
[\#1704](https://github.com/apache/arrow-rs/pull/1704) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Minor: Refine `OffsetSizeTrait` to extend `num::Integer` [\#1702](https://github.com/apache/arrow-rs/pull/1702) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Fix StructArrayReader handling nested lists \(\#1651\) [\#1700](https://github.com/apache/arrow-rs/pull/1700) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Speed up the offsets checking [\#1684](https://github.com/apache/arrow-rs/pull/1684) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) + +## [14.0.0](https://github.com/apache/arrow-rs/tree/14.0.0) (2022-05-13) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/13.0.0...14.0.0) + +**Breaking changes:** + +- Use `bytes` in parquet rather than custom Buffer implementation \(\#1474\) [\#1683](https://github.com/apache/arrow-rs/pull/1683) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Rename `OffsetSize::fn is_large` to `const OffsetSize::IS_LARGE` [\#1664](https://github.com/apache/arrow-rs/pull/1664) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Remove `StringOffsetTrait` and `BinaryOffsetTrait` [\#1645](https://github.com/apache/arrow-rs/pull/1645) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Fix `generate_nested_dictionary_case` integration test failure [\#1636](https://github.com/apache/arrow-rs/pull/1636) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([viirya](https://github.com/viirya)) + +**Implemented enhancements:** + +- Add support for `DataType::Duration` in ffi interface [\#1688](https://github.com/apache/arrow-rs/issues/1688) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Fix `generate_unions_case` integration test [\#1676](https://github.com/apache/arrow-rs/issues/1676) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add `DictionaryArray` support for `bit_length` kernel [\#1673](https://github.com/apache/arrow-rs/issues/1673) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add `DictionaryArray` support for `length` kernel [\#1672](https://github.com/apache/arrow-rs/issues/1672) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- flight\_client\_scenarios integration test should receive schema from flight data [\#1669](https://github.com/apache/arrow-rs/issues/1669) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Unpin Flatbuffer version dependency [\#1667](https://github.com/apache/arrow-rs/issues/1667) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add dictionary array support for substring function [\#1656](https://github.com/apache/arrow-rs/issues/1656) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Exclude dict\_id and dict\_is\_ordered from equality comparison of `Field` [\#1646](https://github.com/apache/arrow-rs/issues/1646) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Remove `StringOffsetTrait` and `BinaryOffsetTrait` [\#1644](https://github.com/apache/arrow-rs/issues/1644) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add tests and examples for `UnionArray::from(data: ArrayData)` [\#1643](https://github.com/apache/arrow-rs/issues/1643) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add methods `pub fn offsets_buffer`, `pub fn types_ids_buffer`and `pub fn data_buffer` for `ArrayDataBuilder` [\#1640](https://github.com/apache/arrow-rs/issues/1640) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Fix `generate_nested_dictionary_case` integration test failure for Rust cases [\#1635](https://github.com/apache/arrow-rs/issues/1635) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Expose `ArrowWriter` row group flush in public API [\#1626](https://github.com/apache/arrow-rs/issues/1626) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Add `substring` support for `FixedSizeBinaryArray` [\#1618](https://github.com/apache/arrow-rs/issues/1618) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add PrettyPrint for `UnionArray`s [\#1594](https://github.com/apache/arrow-rs/issues/1594) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add SIMD support for the `length` kernel [\#1489](https://github.com/apache/arrow-rs/issues/1489) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support dictionary arrays in length and bit\_length [\#1674](https://github.com/apache/arrow-rs/pull/1674) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add dictionary array support for substring function [\#1665](https://github.com/apache/arrow-rs/pull/1665) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([sunchao](https://github.com/sunchao)) +- Add `DecimalType` support in `new_null_array ` [\#1659](https://github.com/apache/arrow-rs/pull/1659) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([yjshen](https://github.com/yjshen)) + +**Fixed bugs:** + +- Docs.rs build is broken [\#1695](https://github.com/apache/arrow-rs/issues/1695) +- Interoperability with C++ for IPC schemas with dictionaries [\#1694](https://github.com/apache/arrow-rs/issues/1694) +- `UnionArray::is_null` incorrect [\#1625](https://github.com/apache/arrow-rs/issues/1625) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Published Parquet documentation missing `arrow::async_reader` [\#1617](https://github.com/apache/arrow-rs/issues/1617) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Files written with Julia's Arrow.jl in IPC format cannot be read by arrow-rs [\#1335](https://github.com/apache/arrow-rs/issues/1335) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Documentation updates:** + +- Correct arrow-flight readme version [\#1641](https://github.com/apache/arrow-rs/pull/1641) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) + +**Closed issues:** + +- Make `OffsetSizeTrait::IS_LARGE` as a const value [\#1658](https://github.com/apache/arrow-rs/issues/1658) +- Question: Why are there 3 types of `OffsetSizeTrait`s? [\#1638](https://github.com/apache/arrow-rs/issues/1638) +- Written Parquet file way bigger than input files [\#1627](https://github.com/apache/arrow-rs/issues/1627) +- Ensure there is a single zero in the offsets buffer for an empty ListArray. 
[\#1620](https://github.com/apache/arrow-rs/issues/1620) +- Filtering `UnionArray` Changes DataType [\#1595](https://github.com/apache/arrow-rs/issues/1595) + +**Merged pull requests:** + +- Fix docs.rs build [\#1696](https://github.com/apache/arrow-rs/pull/1696) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- support duration in ffi [\#1689](https://github.com/apache/arrow-rs/pull/1689) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ryan-jacobs1](https://github.com/ryan-jacobs1)) +- fix bench command line options [\#1685](https://github.com/apache/arrow-rs/pull/1685) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kazuk](https://github.com/kazuk)) +- Enable branch protection [\#1679](https://github.com/apache/arrow-rs/pull/1679) ([tustvold](https://github.com/tustvold)) +- Fix logical merge conflict in \#1588 [\#1678](https://github.com/apache/arrow-rs/pull/1678) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Fix generate\_unions\_case for Rust case [\#1677](https://github.com/apache/arrow-rs/pull/1677) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Receive schema from flight data [\#1670](https://github.com/apache/arrow-rs/pull/1670) ([viirya](https://github.com/viirya)) +- unpin flatbuffers dependency version [\#1668](https://github.com/apache/arrow-rs/pull/1668) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Cheappie](https://github.com/Cheappie)) +- Remove parquet dictionary converters \(\#1661\) [\#1662](https://github.com/apache/arrow-rs/pull/1662) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Minor: simplify the function `GenericListArray::get_type` [\#1650](https://github.com/apache/arrow-rs/pull/1650) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Pretty Print `UnionArray`s [\#1648](https://github.com/apache/arrow-rs/pull/1648) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tfeda](https://github.com/tfeda)) +- Exclude `dict_id` and `dict_is_ordered` from equality comparison of `Field` [\#1647](https://github.com/apache/arrow-rs/pull/1647) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- expose row-group flush in public api [\#1634](https://github.com/apache/arrow-rs/pull/1634) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Cheappie](https://github.com/Cheappie)) +- Add `substring` support for `FixedSizeBinaryArray` [\#1633](https://github.com/apache/arrow-rs/pull/1633) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Fix UnionArray is\_null [\#1632](https://github.com/apache/arrow-rs/pull/1632) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Do not assume dictionaries exists in footer [\#1631](https://github.com/apache/arrow-rs/pull/1631) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([pcjentsch](https://github.com/pcjentsch)) +- Add support for nested list arrays from parquet to arrow arrays \(\#993\) [\#1588](https://github.com/apache/arrow-rs/pull/1588) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Add `async` into doc 
features [\#1349](https://github.com/apache/arrow-rs/pull/1349) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([HaoYang670](https://github.com/HaoYang670)) + + +## [13.0.0](https://github.com/apache/arrow-rs/tree/13.0.0) (2022-04-29) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/12.0.0...13.0.0) + +**Breaking changes:** + +- Update `parquet::basic::LogicalType` to be more idomatic [\#1612](https://github.com/apache/arrow-rs/pull/1612) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tfeda](https://github.com/tfeda)) +- Fix Null Mask Handling in `ArrayData`, `UnionArray`, and `MapArray` [\#1589](https://github.com/apache/arrow-rs/pull/1589) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Replace `&Option` with `Option<&T>` in several `arrow` and `parquet` APIs [\#1571](https://github.com/apache/arrow-rs/pull/1571) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tfeda](https://github.com/tfeda)) + +**Implemented enhancements:** + +- Read/write nested dictionary under fixed size list in ipc stream reader/write [\#1609](https://github.com/apache/arrow-rs/issues/1609) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add support for `BinaryArray` in `substring` kernel [\#1593](https://github.com/apache/arrow-rs/issues/1593) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Read/write nested dictionary under large list in ipc stream reader/write [\#1584](https://github.com/apache/arrow-rs/issues/1584) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Read/write nested dictionary under map in ipc stream reader/write [\#1582](https://github.com/apache/arrow-rs/issues/1582) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Implement `Clone` for JSON `DecoderOptions` [\#1580](https://github.com/apache/arrow-rs/issues/1580) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add utf-8 validation checking to `substring` kernel [\#1575](https://github.com/apache/arrow-rs/issues/1575) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support casting to/from `DataType::Null` in `cast` kernel [\#1572](https://github.com/apache/arrow-rs/pull/1572) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([WinkerDu](https://github.com/WinkerDu)) + +**Fixed bugs:** + +- Parquet schema should allow scale == precision for decimal type [\#1606](https://github.com/apache/arrow-rs/issues/1606) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- ListArray::from\(ArrayData\) dereferences invalid pointer when offsets are empty [\#1601](https://github.com/apache/arrow-rs/issues/1601) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- ArrayData Equality Incorrect Null Mask Offset Handling [\#1599](https://github.com/apache/arrow-rs/issues/1599) +- Filtering UnionArray Incorrect Handles Runs [\#1598](https://github.com/apache/arrow-rs/issues/1598) +- \[Safety\] Filtering Dense UnionArray Produces Invalid Offsets [\#1596](https://github.com/apache/arrow-rs/issues/1596) +- \[Safety\] UnionBuilder Doesn't Check Types [\#1591](https://github.com/apache/arrow-rs/issues/1591) +- Union Layout Should Not Support Separate Validity Mask [\#1590](https://github.com/apache/arrow-rs/issues/1590) +- Incorrect nullable flag when reading maps \( test\_read\_maps fails when `force_validate` is active\) [\#1587](https://github.com/apache/arrow-rs/issues/1587) 
[[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Output of `ipc::reader::tests::projection_should_work` fails validation [\#1548](https://github.com/apache/arrow-rs/issues/1548) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Incorrect min/max statistics for decimals with byte-array notation [\#1532](https://github.com/apache/arrow-rs/issues/1532) + +**Documentation updates:** + +- Minor: Clarify docs on `UnionBuilder::append_null` [\#1628](https://github.com/apache/arrow-rs/pull/1628) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) + +**Closed issues:** + +- Dense UnionArray Offsets Are i32 not i8 [\#1597](https://github.com/apache/arrow-rs/issues/1597) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Replace `&Option` with `Option<&T>` in some APIs [\#1556](https://github.com/apache/arrow-rs/issues/1556) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Improve ergonomics of `parquet::basic::LogicalType` [\#1554](https://github.com/apache/arrow-rs/issues/1554) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Mark the current `substring` function as `unsafe` and rename it. [\#1541](https://github.com/apache/arrow-rs/issues/1541) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Requirements for Async Parquet API [\#1473](https://github.com/apache/arrow-rs/issues/1473) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Merged pull requests:** + +- Nit: use the standard function `div_ceil` [\#1629](https://github.com/apache/arrow-rs/pull/1629) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Update flatbuffers requirement from =2.1.1 to =2.1.2 [\#1622](https://github.com/apache/arrow-rs/pull/1622) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Fix decimals min max statistics [\#1621](https://github.com/apache/arrow-rs/pull/1621) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([atefsawaed](https://github.com/atefsawaed)) +- Add example readme [\#1615](https://github.com/apache/arrow-rs/pull/1615) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Improve docs and examples links on main readme [\#1614](https://github.com/apache/arrow-rs/pull/1614) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Read/Write nested dictionaries under FixedSizeList in IPC [\#1610](https://github.com/apache/arrow-rs/pull/1610) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add `substring` support for binary [\#1608](https://github.com/apache/arrow-rs/pull/1608) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Parquet: schema validation should allow scale == precision for decimal type [\#1607](https://github.com/apache/arrow-rs/pull/1607) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([sunchao](https://github.com/sunchao)) +- Don't access and validate offset buffer in ListArray::from\(ArrayData\) [\#1602](https://github.com/apache/arrow-rs/pull/1602) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- Fix map nullable flag in `ParquetTypeConverter` 
[\#1592](https://github.com/apache/arrow-rs/pull/1592) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([viirya](https://github.com/viirya)) +- Read/write nested dictionary under large list in ipc stream reader/writer [\#1585](https://github.com/apache/arrow-rs/pull/1585) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Read/write nested dictionary under map in ipc stream reader/writer [\#1583](https://github.com/apache/arrow-rs/pull/1583) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Derive `Clone` and `PartialEq` for json `DecoderOptions` [\#1581](https://github.com/apache/arrow-rs/pull/1581) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Add utf-8 validation checking for `substring` [\#1577](https://github.com/apache/arrow-rs/pull/1577) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Use `Option` rather than `Option<&T>` for copy types in substring kernel [\#1576](https://github.com/apache/arrow-rs/pull/1576) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use littleendian arrow files for `projection_should_work` [\#1573](https://github.com/apache/arrow-rs/pull/1573) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) + + +## [12.0.0](https://github.com/apache/arrow-rs/tree/12.0.0) (2022-04-15) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/11.1.0...12.0.0) + +**Breaking changes:** + +- Add `ArrowReaderOptions` to `ParquetFileArrowReader`, add option to skip decoding arrow metadata from parquet \(\#1459\) [\#1558](https://github.com/apache/arrow-rs/pull/1558) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Support `RecordBatch` with zero columns but non zero row count, add field to `RecordBatchOptions` \(\#1536\) [\#1552](https://github.com/apache/arrow-rs/pull/1552) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Consolidate JSON Reader options and `DecoderOptions` [\#1539](https://github.com/apache/arrow-rs/pull/1539) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Update `prost`, `prost-derive` and `prost-types` to 0.10, `tonic`, and `tonic-build` to `0.7` [\#1510](https://github.com/apache/arrow-rs/pull/1510) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Add Json `DecoderOptions` and support custom `format_string` for each field [\#1451](https://github.com/apache/arrow-rs/pull/1451) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([sum12](https://github.com/sum12)) + +**Implemented enhancements:** + +- Read/write nested dictionary in ipc stream reader/writer [\#1565](https://github.com/apache/arrow-rs/issues/1565) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `FixedSizeBinary` in the Arrow C data interface [\#1553](https://github.com/apache/arrow-rs/issues/1553) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support Empty Column Projection in `ParquetRecordBatchReader` [\#1537](https://github.com/apache/arrow-rs/issues/1537) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support `RecordBatch` with zero columns but non zero row count 
[\#1536](https://github.com/apache/arrow-rs/issues/1536) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add support for `Date32`/`Date64`\<--\> `String`/`LargeString` in `cast` kernel [\#1535](https://github.com/apache/arrow-rs/issues/1535) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support creating arrays from externally owned memory like `Vec` or `String` [\#1516](https://github.com/apache/arrow-rs/issues/1516) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Speed up the `substring` kernel [\#1511](https://github.com/apache/arrow-rs/issues/1511) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Handle Parquet Files With Inconsistent Timestamp Units [\#1459](https://github.com/apache/arrow-rs/issues/1459) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Fixed bugs:** + +- Error Infering Schema for LogicalType::UNKNOWN [\#1557](https://github.com/apache/arrow-rs/issues/1557) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Read dictionary from nested struct in ipc stream reader panics [\#1549](https://github.com/apache/arrow-rs/issues/1549) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `filter` produces invalid sparse `UnionArray`s [\#1547](https://github.com/apache/arrow-rs/issues/1547) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Documentation for `GenericListBuilder` is not exposed. [\#1518](https://github.com/apache/arrow-rs/issues/1518) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- cannot read parquet file [\#1515](https://github.com/apache/arrow-rs/issues/1515) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- The `substring` kernel panics when chars \> U+0x007F [\#1478](https://github.com/apache/arrow-rs/issues/1478) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Hang due to infinite loop when reading some parquet files with RLE encoding and bit packing [\#1458](https://github.com/apache/arrow-rs/issues/1458) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Documentation updates:** + +- Improve JSON reader documentation [\#1559](https://github.com/apache/arrow-rs/pull/1559) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Improve doc string for `substring` kernel [\#1529](https://github.com/apache/arrow-rs/pull/1529) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Expose documentation of `GenericListBuilder` [\#1525](https://github.com/apache/arrow-rs/pull/1525) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([comath](https://github.com/comath)) +- Add a diagram to `take` kernel documentation [\#1524](https://github.com/apache/arrow-rs/pull/1524) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) + +**Closed issues:** + +- Interesting benchmark results of `min_max_helper` [\#1400](https://github.com/apache/arrow-rs/issues/1400) + +**Merged pull requests:** + +- Fix incorrect `into_buffers` for UnionArray [\#1567](https://github.com/apache/arrow-rs/pull/1567) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Read/write nested dictionary in ipc stream reader/writer [\#1566](https://github.com/apache/arrow-rs/pull/1566) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Support FixedSizeBinary and FixedSizeList for the C data interface 
[\#1564](https://github.com/apache/arrow-rs/pull/1564) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([sunchao](https://github.com/sunchao)) +- Split out ListArrayReader into separate module \(\#1483\) [\#1563](https://github.com/apache/arrow-rs/pull/1563) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Split out `MapArray` into separate module \(\#1483\) [\#1562](https://github.com/apache/arrow-rs/pull/1562) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Support empty projection in `ParquetRecordBatchReader` [\#1560](https://github.com/apache/arrow-rs/pull/1560) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- fix infinite loop in not fully packed bit-packed runs [\#1555](https://github.com/apache/arrow-rs/pull/1555) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Add test for creating FixedSizeBinaryArray::try\_from\_sparse\_iter failed when given all Nones [\#1551](https://github.com/apache/arrow-rs/pull/1551) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Fix reading dictionaries from nested structs in ipc `StreamReader` [\#1550](https://github.com/apache/arrow-rs/pull/1550) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dispanser](https://github.com/dispanser)) +- Add support for Date32/64 \<--\> String/LargeString in `cast` kernel [\#1534](https://github.com/apache/arrow-rs/pull/1534) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([yjshen](https://github.com/yjshen)) +- fix clippy errors in 1.60 [\#1527](https://github.com/apache/arrow-rs/pull/1527) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Mark `remove-old-releases.sh` executable [\#1522](https://github.com/apache/arrow-rs/pull/1522) ([alamb](https://github.com/alamb)) +- Delete duplicate code in the `sort` kernel [\#1519](https://github.com/apache/arrow-rs/pull/1519) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Fix reading nested lists from parquet files [\#1517](https://github.com/apache/arrow-rs/pull/1517) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([viirya](https://github.com/viirya)) +- Speed up the `substring` kernel by about 2x [\#1512](https://github.com/apache/arrow-rs/pull/1512) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Add `new_from_strings` to create `MapArrays` [\#1507](https://github.com/apache/arrow-rs/pull/1507) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Decouple buffer deallocation from ffi and allow creating buffers from rust vec [\#1494](https://github.com/apache/arrow-rs/pull/1494) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) + +## [11.1.0](https://github.com/apache/arrow-rs/tree/11.1.0) (2022-03-31) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/11.0.0...11.1.0) + +**Implemented enhancements:** + +- Implement `size_hint` and `ExactSizedIterator` for DecimalArray [\#1505](https://github.com/apache/arrow-rs/issues/1505) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support 
calculating length by chars for `StringArray` [\#1493](https://github.com/apache/arrow-rs/issues/1493) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add `length` kernel support for `ListArray` [\#1470](https://github.com/apache/arrow-rs/issues/1470) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- The length kernel should work with `BinaryArray`s [\#1464](https://github.com/apache/arrow-rs/issues/1464) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- FFI for Arrow C Stream Interface [\#1348](https://github.com/apache/arrow-rs/issues/1348) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Improve performance of `DictionaryArray::try_new()` [\#1313](https://github.com/apache/arrow-rs/issues/1313) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- MIRI error in math\_checked\_divide\_op/try\_from\_trusted\_len\_iter [\#1496](https://github.com/apache/arrow-rs/issues/1496) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Parquet Writer Incorrect Definition Levels for Nested NullArray [\#1480](https://github.com/apache/arrow-rs/issues/1480) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- FFI: ArrowArray::try\_from\_raw shouldn't clone [\#1425](https://github.com/apache/arrow-rs/issues/1425) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Parquet reader fails to read null list. [\#1399](https://github.com/apache/arrow-rs/issues/1399) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Documentation updates:** + +- A small mistake in the doc of `BinaryArray` and `LargeBinaryArray` [\#1455](https://github.com/apache/arrow-rs/issues/1455) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- A small mistake in the doc of `GenericBinaryArray::take_iter_unchecked` [\#1454](https://github.com/apache/arrow-rs/issues/1454) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add links in the doc of `BinaryOffsetSizeTrait` [\#1453](https://github.com/apache/arrow-rs/issues/1453) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- The doc of `FixedSizeBinaryArray` is confusing. [\#1452](https://github.com/apache/arrow-rs/issues/1452) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Clarify docs that SlicesIterator ignores null values [\#1504](https://github.com/apache/arrow-rs/pull/1504) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Update the doc of `BinaryArray` and `LargeBinaryArray` [\#1471](https://github.com/apache/arrow-rs/pull/1471) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) + +**Closed issues:** + +- `packed_simd` v.s. `portable_simd`, which should be used? 
[\#1492](https://github.com/apache/arrow-rs/issues/1492) +- Cleanup: Use Arrow take kernel Within parquet ListArrayReader [\#1482](https://github.com/apache/arrow-rs/issues/1482) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Merged pull requests:** + +- Implement `size_hint` and `ExactSizedIterator` for `DecimalArray` [\#1506](https://github.com/apache/arrow-rs/pull/1506) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Add `StringArray::num_chars` for calculating number of characters [\#1503](https://github.com/apache/arrow-rs/pull/1503) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Workaround nightly miri error in `try_from_trusted_len_iter` [\#1497](https://github.com/apache/arrow-rs/pull/1497) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- update doc of array\_binary and array\_string [\#1491](https://github.com/apache/arrow-rs/pull/1491) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Use Arrow take kernel within ListArrayReader [\#1490](https://github.com/apache/arrow-rs/pull/1490) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([viirya](https://github.com/viirya)) +- Add `length` kernel support for List Array [\#1488](https://github.com/apache/arrow-rs/pull/1488) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Support sort for `Decimal` data type [\#1487](https://github.com/apache/arrow-rs/pull/1487) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([yjshen](https://github.com/yjshen)) +- Fix reading/writing nested null arrays \(\#1480\) \(\#1036\) \(\#1399\) [\#1481](https://github.com/apache/arrow-rs/pull/1481) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Implement ArrayEqual for UnionArray [\#1469](https://github.com/apache/arrow-rs/pull/1469) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Support the `length` kernel on Binary Array [\#1465](https://github.com/apache/arrow-rs/pull/1465) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Remove Clone and copy source structs internally [\#1449](https://github.com/apache/arrow-rs/pull/1449) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Fix Parquet reader for null lists [\#1448](https://github.com/apache/arrow-rs/pull/1448) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([viirya](https://github.com/viirya)) +- Improve performance of DictionaryArray::try\_new\(\)  [\#1435](https://github.com/apache/arrow-rs/pull/1435) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jackwener](https://github.com/jackwener)) +- Add FFI for Arrow C Stream Interface [\#1384](https://github.com/apache/arrow-rs/pull/1384) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) + +## [11.0.0](https://github.com/apache/arrow-rs/tree/11.0.0) (2022-03-17) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/10.0.0...11.0.0) + +**Breaking changes:** + +- Replace `filter_row_groups` with `ReadOptions` in parquet SerializedFileReader [\#1389](https://github.com/apache/arrow-rs/pull/1389) 
[[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([yjshen](https://github.com/yjshen)) +- Implement projection for arrow `IPC Reader` file / streams [\#1339](https://github.com/apache/arrow-rs/pull/1339) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([Dandandan](https://github.com/Dandandan)) + +**Implemented enhancements:** + +- Fix generate\_interval\_case integration test failure [\#1445](https://github.com/apache/arrow-rs/issues/1445) +- Make the doc examples of `ListArray` and `LargeListArray` more readable [\#1433](https://github.com/apache/arrow-rs/issues/1433) +- Redundant `if` and `abs` in `shift()` [\#1427](https://github.com/apache/arrow-rs/issues/1427) +- Improve substring kernel performance [\#1422](https://github.com/apache/arrow-rs/issues/1422) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add missing value\_unchecked\(\) of `FixedSizeBinaryArray` [\#1419](https://github.com/apache/arrow-rs/issues/1419) +- Remove duplicate bound check in function `shift` [\#1408](https://github.com/apache/arrow-rs/issues/1408) +- Support dictionary array in C data interface [\#1397](https://github.com/apache/arrow-rs/issues/1397) +- filter kernel should work with `UnionArray`s [\#1394](https://github.com/apache/arrow-rs/issues/1394) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- filter kernel should work with `FixedSizeListArrays`s [\#1393](https://github.com/apache/arrow-rs/issues/1393) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add doc examples for creating FixedSizeListArray [\#1392](https://github.com/apache/arrow-rs/issues/1392) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Update `rust-version` to 1.59 [\#1377](https://github.com/apache/arrow-rs/issues/1377) +- Arrow IPC projection support [\#1338](https://github.com/apache/arrow-rs/issues/1338) +- Implement basic FlightSQL Server [\#1386](https://github.com/apache/arrow-rs/pull/1386) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([wangfenjin](https://github.com/wangfenjin)) + +**Fixed bugs:** + +- DictionaryArray::try\_new ignores validity bitmap of the keys [\#1429](https://github.com/apache/arrow-rs/issues/1429) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- The doc of `GenericListArray` is confusing [\#1424](https://github.com/apache/arrow-rs/issues/1424) +- DeltaBitPackDecoder Incorrectly Handles Non-Zero MiniBlock Bit Width Padding [\#1417](https://github.com/apache/arrow-rs/issues/1417) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- DeltaBitPackEncoder Pads Miniblock BitWidths With Arbitrary Values [\#1416](https://github.com/apache/arrow-rs/issues/1416) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Possible unaligned write with MutableBuffer::push [\#1410](https://github.com/apache/arrow-rs/issues/1410) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Integration Test is failing on master branch [\#1398](https://github.com/apache/arrow-rs/issues/1398) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Documentation updates:** + +- Rewrite doc of `GenericListArray` [\#1450](https://github.com/apache/arrow-rs/pull/1450) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Fix integration doc about build.ninja location [\#1438](https://github.com/apache/arrow-rs/pull/1438) 
([viirya](https://github.com/viirya)) + +**Merged pull requests:** + +- Rewrite doc example of `ListArray` and `LargeListArray` [\#1447](https://github.com/apache/arrow-rs/pull/1447) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Fix generate\_interval\_case in integration test [\#1446](https://github.com/apache/arrow-rs/pull/1446) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Fix generate\_decimal128\_case in integration test [\#1440](https://github.com/apache/arrow-rs/pull/1440) ([viirya](https://github.com/viirya)) +- `filter` kernel should work with FixedSizeListArrays [\#1434](https://github.com/apache/arrow-rs/pull/1434) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Support nullable keys in DictionaryArray::try\_new [\#1430](https://github.com/apache/arrow-rs/pull/1430) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- remove redundant if/clamp\_min/abs [\#1428](https://github.com/apache/arrow-rs/pull/1428) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jackwener](https://github.com/jackwener)) +- Add doc example for creating `FixedSizeListArray` [\#1426](https://github.com/apache/arrow-rs/pull/1426) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Directly write to MutableBuffer in substring [\#1423](https://github.com/apache/arrow-rs/pull/1423) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Fix possibly unaligned writes in MutableBuffer [\#1421](https://github.com/apache/arrow-rs/pull/1421) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- Add value\_unchecked\(\) and unit test [\#1420](https://github.com/apache/arrow-rs/pull/1420) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jackwener](https://github.com/jackwener)) +- Fix DeltaBitPack MiniBlock Bit Width Padding [\#1418](https://github.com/apache/arrow-rs/pull/1418) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Update zstd requirement from 0.10 to 0.11 [\#1415](https://github.com/apache/arrow-rs/pull/1415) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Set `default-features = false` for `zstd` in the parquet crate to support `wasm32-unknown-unknown` [\#1414](https://github.com/apache/arrow-rs/pull/1414) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([kylebarron](https://github.com/kylebarron)) +- Add support for `UnionArray` in`filter` kernel [\#1412](https://github.com/apache/arrow-rs/pull/1412) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Remove duplicate bound check in the function `shift` [\#1409](https://github.com/apache/arrow-rs/pull/1409) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Add dictionary support for C data interface [\#1407](https://github.com/apache/arrow-rs/pull/1407) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([sunchao](https://github.com/sunchao)) +- Fix a small spelling mistake in docs. 
[\#1406](https://github.com/apache/arrow-rs/pull/1406) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Add unit test to check `FixedSizeBinaryArray` input all none [\#1405](https://github.com/apache/arrow-rs/pull/1405) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jackwener](https://github.com/jackwener)) +- Move csv Parser trait and its implementations to utils module [\#1385](https://github.com/apache/arrow-rs/pull/1385) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([sum12](https://github.com/sum12)) + +## [10.0.0](https://github.com/apache/arrow-rs/tree/10.0.0) (2022-03-04) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/9.1.0...10.0.0) + +**Breaking changes:** + +- Remove existing has\_ methods for optional fields in `ColumnChunkMetaData` [\#1346](https://github.com/apache/arrow-rs/pull/1346) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([shanisolomon](https://github.com/shanisolomon)) +- Remove redundant `has_` methods in `ColumnChunkMetaData` [\#1345](https://github.com/apache/arrow-rs/pull/1345) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([shanisolomon](https://github.com/shanisolomon)) + +**Implemented enhancements:** + +- Add extract month and day in temporal.rs [\#1387](https://github.com/apache/arrow-rs/issues/1387) +- Add clone to `IpcWriteOptions` [\#1381](https://github.com/apache/arrow-rs/issues/1381) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `MapArray` in `filter` kernel [\#1378](https://github.com/apache/arrow-rs/issues/1378) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add `week` temporal kernel [\#1375](https://github.com/apache/arrow-rs/issues/1375) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Improve performance of `compare_dict_op` [\#1371](https://github.com/apache/arrow-rs/issues/1371) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add support for LargeUtf8 in json writer [\#1357](https://github.com/apache/arrow-rs/issues/1357) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Make `arrow::array::builder::MapBuilder` public [\#1354](https://github.com/apache/arrow-rs/issues/1354) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Refactor `StructArray::from` [\#1351](https://github.com/apache/arrow-rs/issues/1351) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Refactor `RecordBatch::validate_new_batch` [\#1350](https://github.com/apache/arrow-rs/issues/1350) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Remove redundant has\_ methods for optional column metadata fields [\#1344](https://github.com/apache/arrow-rs/issues/1344) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Add `write` method to JsonWriter [\#1340](https://github.com/apache/arrow-rs/issues/1340) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Refactor the code of `Bitmap::new` [\#1337](https://github.com/apache/arrow-rs/issues/1337) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Use DictionaryArray's iterator in `compare_dict_op` [\#1329](https://github.com/apache/arrow-rs/issues/1329) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add `as_decimal_array(arr: &dyn Array) -> &DecimalArray` [\#1312](https://github.com/apache/arrow-rs/issues/1312) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- More ergonomic / idiomatic primitive array creation from 
iterators [\#1298](https://github.com/apache/arrow-rs/issues/1298) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Implement DictionaryArray support in `eq_dyn`, `neq_dyn`, `lt_dyn`, `lt_eq_dyn`, `gt_dyn`, `gt_eq_dyn` [\#1201](https://github.com/apache/arrow-rs/issues/1201) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- `cargo clippy` fails on the `master` branch [\#1362](https://github.com/apache/arrow-rs/issues/1362) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `ArrowArray::try_from_raw` should not assume the pointers are from Arc [\#1333](https://github.com/apache/arrow-rs/issues/1333) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Fix CSV Writer::new to accept delimiter and make WriterBuilder::build use it [\#1328](https://github.com/apache/arrow-rs/issues/1328) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Make bounds configurable via builder when reading CSV [\#1327](https://github.com/apache/arrow-rs/issues/1327) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add `with_datetime_format()` to CSV WriterBuilder [\#1272](https://github.com/apache/arrow-rs/issues/1272) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Performance improvements:** + +- Improve performance of `min` and `max` aggregation kernels without nulls [\#1373](https://github.com/apache/arrow-rs/issues/1373) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Closed issues:** + +- Consider removing redundant has\_XXX metadata functions in `ColumnChunkMetadata` [\#1332](https://github.com/apache/arrow-rs/issues/1332) + +**Merged pull requests:** + +- Support extract `day` and `month` in temporal.rs [\#1388](https://github.com/apache/arrow-rs/pull/1388) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- Add write method to Json Writer [\#1383](https://github.com/apache/arrow-rs/pull/1383) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([matthewmturner](https://github.com/matthewmturner)) +- Derive `Clone` for `IpcWriteOptions` [\#1382](https://github.com/apache/arrow-rs/pull/1382) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([matthewmturner](https://github.com/matthewmturner)) +- feat: support maps in MutableArrayData [\#1379](https://github.com/apache/arrow-rs/pull/1379) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([helgikrs](https://github.com/helgikrs)) +- Support extract `week` in temporal.rs [\#1376](https://github.com/apache/arrow-rs/pull/1376) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- Speed up the function `min_max_string` [\#1374](https://github.com/apache/arrow-rs/pull/1374) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Improve performance if dictionary kernels, add benchmark and add `take_iter_unchecked` [\#1372](https://github.com/apache/arrow-rs/pull/1372) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Update pyo3 requirement from 0.15 to 0.16 [\#1369](https://github.com/apache/arrow-rs/pull/1369) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Update contributing guide [\#1368](https://github.com/apache/arrow-rs/pull/1368) ([HaoYang670](https://github.com/HaoYang670)) +- Allow primitive array creation from iterators of 
PrimitiveTypes \(as well as `Option`\) [\#1367](https://github.com/apache/arrow-rs/pull/1367) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Update flatbuffers requirement from =2.1.0 to =2.1.1 [\#1364](https://github.com/apache/arrow-rs/pull/1364) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Fix clippy lints [\#1363](https://github.com/apache/arrow-rs/pull/1363) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Refactor `RecordBatch::validate_new_batch` [\#1361](https://github.com/apache/arrow-rs/pull/1361) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Refactor `StructArray::from` [\#1360](https://github.com/apache/arrow-rs/pull/1360) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Update flatbuffers requirement from =2.0.0 to =2.1.0 [\#1359](https://github.com/apache/arrow-rs/pull/1359) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- fix: add LargeUtf8 support in json writer [\#1358](https://github.com/apache/arrow-rs/pull/1358) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tiphaineruy](https://github.com/tiphaineruy)) +- Add `as_decimal_array` function [\#1356](https://github.com/apache/arrow-rs/pull/1356) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- Publicly export arrow::array::MapBuilder [\#1355](https://github.com/apache/arrow-rs/pull/1355) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tjwilson90](https://github.com/tjwilson90)) +- Add with\_datetime\_format to csv WriterBuilder [\#1347](https://github.com/apache/arrow-rs/pull/1347) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([gsserge](https://github.com/gsserge)) +- Refactor `Bitmap::new` [\#1343](https://github.com/apache/arrow-rs/pull/1343) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Remove delimiter from csv Writer [\#1342](https://github.com/apache/arrow-rs/pull/1342) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([gsserge](https://github.com/gsserge)) +- Make bounds configurable in csv ReaderBuilder [\#1341](https://github.com/apache/arrow-rs/pull/1341) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([gsserge](https://github.com/gsserge)) +- `ArrowArray::try_from_raw` should not assume the pointers are from Arc [\#1334](https://github.com/apache/arrow-rs/pull/1334) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Use DictionaryArray's iterator in `compare_dict_op` [\#1330](https://github.com/apache/arrow-rs/pull/1330) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Implement DictionaryArray support in neq\_dyn, lt\_dyn, lt\_eq\_dyn, gt\_dyn, gt\_eq\_dyn [\#1326](https://github.com/apache/arrow-rs/pull/1326) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Arrow Rust + Conbench Integration [\#1289](https://github.com/apache/arrow-rs/pull/1289) ([dianaclarke](https://github.com/dianaclarke)) + +## [9.1.0](https://github.com/apache/arrow-rs/tree/9.1.0) 
(2022-02-19) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/9.0.2...9.1.0) + +**Implemented enhancements:** + +- Exposing page encoding stats [\#1321](https://github.com/apache/arrow-rs/issues/1321) +- Improve filter performance by special casing high and low selectivity predicates [\#1288](https://github.com/apache/arrow-rs/issues/1288) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Speed up `DeltaBitPackDecoder` [\#1281](https://github.com/apache/arrow-rs/issues/1281) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Fix all clippy lints in arrow crate [\#1255](https://github.com/apache/arrow-rs/issues/1255) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Expose page encoding `ColumnChunkMetadata` [\#1322](https://github.com/apache/arrow-rs/pull/1322) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([shanisolomon](https://github.com/shanisolomon)) +- Expose column index and offset index in `ColumnChunkMetadata` [\#1318](https://github.com/apache/arrow-rs/pull/1318) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([shanisolomon](https://github.com/shanisolomon)) +- Expose bloom filter offset in `ColumnChunkMetadata` [\#1309](https://github.com/apache/arrow-rs/pull/1309) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([shanisolomon](https://github.com/shanisolomon)) +- Add `DictionaryArray::try_new()` to create dictionaries from pre existing arrays [\#1300](https://github.com/apache/arrow-rs/pull/1300) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Add `DictionaryArray::keys_iter`, and `take_iter` for other array types [\#1296](https://github.com/apache/arrow-rs/pull/1296) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Make `rle` decoder public under `experimental` feature [\#1271](https://github.com/apache/arrow-rs/pull/1271) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([zeevm](https://github.com/zeevm)) +- Add `DictionaryArray` support in `eq_dyn` kernel [\#1263](https://github.com/apache/arrow-rs/pull/1263) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) + +**Fixed bugs:** + +- `len` is not a parameter of `MutableArrayData::extend` [\#1316](https://github.com/apache/arrow-rs/issues/1316) +- module `data_type` is private in Rust Parquet 8.0.0 [\#1302](https://github.com/apache/arrow-rs/issues/1302) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Test failure: bit\_chunk\_iterator [\#1294](https://github.com/apache/arrow-rs/issues/1294) +- csv\_writer benchmark fails with "no such file or directory" [\#1292](https://github.com/apache/arrow-rs/issues/1292) + +**Documentation updates:** + +- Fix warnings in `cargo doc` [\#1268](https://github.com/apache/arrow-rs/pull/1268) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) + +**Performance improvements:** + +- Vectorize DeltaBitPackDecoder, up to 5x faster decoding [\#1284](https://github.com/apache/arrow-rs/pull/1284) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Skip zero-ing primitive nulls [\#1280](https://github.com/apache/arrow-rs/pull/1280) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Add specialized filter kernels 
in `compute` module \(up to 10x faster\) [\#1248](https://github.com/apache/arrow-rs/pull/1248) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +**Closed issues:** + +- Expose column and offset index metadata offset [\#1317](https://github.com/apache/arrow-rs/issues/1317) +- Expose bloom filter metadata offset [\#1308](https://github.com/apache/arrow-rs/issues/1308) +- Improve ergonomics to construct `DictionaryArrays` from `Key` and `Value` arrays [\#1299](https://github.com/apache/arrow-rs/issues/1299) +- Make it easier to iterate over `DictionaryArray` [\#1295](https://github.com/apache/arrow-rs/issues/1295) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- (WON'T FIX) Don't Intertwine Bit and Byte Aligned Operations in `BitReader` [\#1282](https://github.com/apache/arrow-rs/issues/1282) +- how to create arrow::array from streamReader [\#1278](https://github.com/apache/arrow-rs/issues/1278) +- Remove scientific notation when converting floats to strings. [\#983](https://github.com/apache/arrow-rs/issues/983) + +**Merged pull requests:** + +- Update the document of function `MutableArrayData::extend` [\#1336](https://github.com/apache/arrow-rs/pull/1336) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Fix clippy lint `dead_code` [\#1324](https://github.com/apache/arrow-rs/pull/1324) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([gsserge](https://github.com/gsserge)) +- fix test bug and ensure that bloom filter metadata is serialized in `to_thrift` [\#1320](https://github.com/apache/arrow-rs/pull/1320) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([shanisolomon](https://github.com/shanisolomon)) +- Enable more clippy lints in arrow [\#1315](https://github.com/apache/arrow-rs/pull/1315) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([gsserge](https://github.com/gsserge)) +- Fix clippy lint `clippy::type_complexity` [\#1310](https://github.com/apache/arrow-rs/pull/1310) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([gsserge](https://github.com/gsserge)) +- Fix clippy lint `clippy::float_equality_without_abs` [\#1305](https://github.com/apache/arrow-rs/pull/1305) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([gsserge](https://github.com/gsserge)) +- Fix clippy `clippy::vec_init_then_push` lint [\#1303](https://github.com/apache/arrow-rs/pull/1303) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([gsserge](https://github.com/gsserge)) +- Fix failing csv\_writer bench [\#1293](https://github.com/apache/arrow-rs/pull/1293) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([andygrove](https://github.com/andygrove)) +- Changes for 9.0.2 [\#1291](https://github.com/apache/arrow-rs/pull/1291) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Fix bitmask creation also for simd comparisons with scalar [\#1290](https://github.com/apache/arrow-rs/pull/1290) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- Fix simd comparison kernels [\#1286](https://github.com/apache/arrow-rs/pull/1286) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([jhorstmann](https://github.com/jhorstmann)) +- Restrict Decoder to compatible types \(\#1276\) [\#1277](https://github.com/apache/arrow-rs/pull/1277) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Fix some clippy lints in parquet crate, rename `LevelEncoder` variants to conform to Rust standards [\#1273](https://github.com/apache/arrow-rs/pull/1273) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([HaoYang670](https://github.com/HaoYang670)) +- Use new DecimalArray creation API in arrow crate [\#1249](https://github.com/apache/arrow-rs/pull/1249) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Improve `DecimalArray` API ergonomics: add `iter()`, `FromIterator`, `with_precision_and_scale` [\#1223](https://github.com/apache/arrow-rs/pull/1223) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) + + +## [9.0.2](https://github.com/apache/arrow-rs/tree/9.0.2) (2022-02-09) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/8.0.0...9.0.2) + +**Breaking changes:** + +- Add `Send` + `Sync` to `DataType`, `RowGroupReader`, `FileReader`, `ChunkReader`. [\#1264](https://github.com/apache/arrow-rs/issues/1264) +- Rename the function `Bitmap::len` to `Bitmap::bit_len` to clarify its meaning [\#1242](https://github.com/apache/arrow-rs/pull/1242) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Remove unused / broken `memory-check` feature [\#1222](https://github.com/apache/arrow-rs/pull/1222) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- Potentially buffer multiple `RecordBatches` before writing a parquet row group in `ArrowWriter` [\#1214](https://github.com/apache/arrow-rs/pull/1214) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- Add `async` arrow parquet reader [\#1154](https://github.com/apache/arrow-rs/pull/1154) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Rename `Bitmap::len` to `Bitmap::bit_len` [\#1233](https://github.com/apache/arrow-rs/issues/1233) +- Extend CSV schema inference to allow scientific notation for floating point types [\#1215](https://github.com/apache/arrow-rs/issues/1215) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Write Multiple RecordBatch to Parquet Row Group [\#1211](https://github.com/apache/arrow-rs/issues/1211) +- Add doc examples for `eq_dyn` etc. 
[\#1202](https://github.com/apache/arrow-rs/issues/1202) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add comparison kernels for `BinaryArray` [\#1108](https://github.com/apache/arrow-rs/issues/1108) +- `impl ArrowNativeType for i128` [\#1098](https://github.com/apache/arrow-rs/issues/1098) +- Remove `Copy` trait bound from dyn scalar kernels [\#1243](https://github.com/apache/arrow-rs/pull/1243) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([matthewmturner](https://github.com/matthewmturner)) +- Add `into_inner` for IPC `FileWriter` [\#1236](https://github.com/apache/arrow-rs/pull/1236) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([yjshen](https://github.com/yjshen)) +- \[Minor\]Re-export `array::builder::make_builder` to make it available for downstream [\#1235](https://github.com/apache/arrow-rs/pull/1235) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([yjshen](https://github.com/yjshen)) + +**Fixed bugs:** + +- Parquet v8.0.0 panics when reading all null column to NullArray [\#1245](https://github.com/apache/arrow-rs/issues/1245) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Get `Unknown configuration option rust-version` when running the rust format command [\#1240](https://github.com/apache/arrow-rs/issues/1240) +- `Bitmap` Length Validation is Incorrect [\#1231](https://github.com/apache/arrow-rs/issues/1231) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Writing sliced `ListArray` or `MapArray` ignore offsets [\#1226](https://github.com/apache/arrow-rs/issues/1226) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Remove broken `memory-tracking` crate feature [\#1171](https://github.com/apache/arrow-rs/issues/1171) +- Revert making `parquet::data_type` and `parquet::arrow::schema` experimental [\#1244](https://github.com/apache/arrow-rs/pull/1244) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) + +**Documentation updates:** + +- Update parquet crate documentation and examples [\#1253](https://github.com/apache/arrow-rs/pull/1253) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Refresh parquet readme / contributing guide [\#1252](https://github.com/apache/arrow-rs/pull/1252) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Add docs examples for dynamically compare functions [\#1250](https://github.com/apache/arrow-rs/pull/1250) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Add Rust Docs examples for UnionArray [\#1241](https://github.com/apache/arrow-rs/pull/1241) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Improve documentation for Bitmap [\#1237](https://github.com/apache/arrow-rs/pull/1237) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) + +**Performance improvements:** + +- Improve performance for arithmetic kernels with `simd` feature enabled \(except for division/modulo\) [\#1221](https://github.com/apache/arrow-rs/pull/1221) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- Do not concatenate identical dictionaries [\#1219](https://github.com/apache/arrow-rs/pull/1219) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Preserve dictionary encoding when decoding parquet into Arrow arrays, 60x perf improvement \(\#171\) [\#1180](https://github.com/apache/arrow-rs/pull/1180) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) + +**Closed issues:** + +- `UnalignedBitChunkIterator` to that iterates through already aligned `u64` blocks [\#1227](https://github.com/apache/arrow-rs/issues/1227) +- Remove unused `ArrowArrayReader` in parquet [\#1197](https://github.com/apache/arrow-rs/issues/1197) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Merged pull requests:** + +- Upgrade clap to 3.0.0 [\#1261](https://github.com/apache/arrow-rs/pull/1261) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Jimexist](https://github.com/Jimexist)) +- Update chrono-tz requirement from 0.4 to 0.6 [\#1259](https://github.com/apache/arrow-rs/pull/1259) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Update zstd requirement from 0.9 to 0.10 [\#1257](https://github.com/apache/arrow-rs/pull/1257) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Fix NullArrayReader \(\#1245\) [\#1246](https://github.com/apache/arrow-rs/pull/1246) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- dyn compare for binary array [\#1238](https://github.com/apache/arrow-rs/pull/1238) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Remove arrow array reader \(\#1197\) [\#1234](https://github.com/apache/arrow-rs/pull/1234) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Fix null bitmap length validation \(\#1231\) [\#1232](https://github.com/apache/arrow-rs/pull/1232) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Faster bitmask iteration [\#1228](https://github.com/apache/arrow-rs/pull/1228) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add non utf8 values into the test cases of BinaryArray comparison [\#1220](https://github.com/apache/arrow-rs/pull/1220) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Update DECIMAL\_RE to allow scientific notation in auto inferred schemas [\#1216](https://github.com/apache/arrow-rs/pull/1216) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([pjmore](https://github.com/pjmore)) +- Fix simd comparison kernels [\#1286](https://github.com/apache/arrow-rs/pull/1286) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- Fix bitmask creation also for simd comparisons with scalar [\#1290](https://github.com/apache/arrow-rs/pull/1290) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) + +## [8.0.0](https://github.com/apache/arrow-rs/tree/8.0.0) (2022-01-20) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/7.0.0...8.0.0) + +**Breaking changes:** + +- Return error from JSON writer rather than panic [\#1205](https://github.com/apache/arrow-rs/pull/1205) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- Remove `ArrowSignedNumericType ` to Simplify and reduce code duplication in arithmetic kernels [\#1161](https://github.com/apache/arrow-rs/pull/1161) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- Restrict RecordReader and friends to scalar types \(\#1132\) [\#1155](https://github.com/apache/arrow-rs/pull/1155) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Move more parquet functionality behind experimental feature flag \(\#1032\) [\#1134](https://github.com/apache/arrow-rs/pull/1134) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- Parquet reader should be able to read structs within list [\#1186](https://github.com/apache/arrow-rs/issues/1186) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Disable serde\_json `arbitrary_precision` feature flag [\#1174](https://github.com/apache/arrow-rs/issues/1174) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Simplify and reduce code duplication in arithmetic.rs [\#1160](https://github.com/apache/arrow-rs/issues/1160) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Return `Err` from JSON writer rather than `panic!` for unsupported types [\#1157](https://github.com/apache/arrow-rs/issues/1157) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `scalar` mathematics kernels for `Array` and scalar value [\#1153](https://github.com/apache/arrow-rs/issues/1153) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `DecimalArray` in sort kernel [\#1137](https://github.com/apache/arrow-rs/issues/1137) +- Parquet Fuzz Tests [\#1053](https://github.com/apache/arrow-rs/issues/1053) +- BooleanBufferBuilder Append Packed [\#1038](https://github.com/apache/arrow-rs/issues/1038) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- parquet Performance Optimization: StructArrayReader Redundant Level & Bitmap Computation [\#1034](https://github.com/apache/arrow-rs/issues/1034) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Reduce Public Parquet API [\#1032](https://github.com/apache/arrow-rs/issues/1032) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Add `from_iter_values` for binary array [\#1188](https://github.com/apache/arrow-rs/pull/1188) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jimexist](https://github.com/Jimexist)) +- Add support for `MapArray` in json writer [\#1149](https://github.com/apache/arrow-rs/pull/1149) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([helgikrs](https://github.com/helgikrs)) + +**Fixed bugs:** + +- Empty string arrays with no nulls are not equal [\#1208](https://github.com/apache/arrow-rs/issues/1208) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Pretty print a `RecordBatch` containing `Float16` triggers a panic [\#1193](https://github.com/apache/arrow-rs/issues/1193) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Writing structs nested in lists produces an incorrect output [\#1184](https://github.com/apache/arrow-rs/issues/1184) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Undefined behavior for `GenericStringArray::from_iter_values` if reported iterator upper bound is incorrect 
[\#1144](https://github.com/apache/arrow-rs/issues/1144) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Interval comparisons with `simd` feature asserts [\#1136](https://github.com/apache/arrow-rs/issues/1136) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- RecordReader Permits Illegal Types [\#1132](https://github.com/apache/arrow-rs/issues/1132) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Security fixes:** + +- Fix undefined behavor in GenericStringArray::from\_iter\_values [\#1145](https://github.com/apache/arrow-rs/pull/1145) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- parquet: Optimized ByteArrayReader, Add UTF-8 Validation \(\#1040\) [\#1082](https://github.com/apache/arrow-rs/pull/1082) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +**Documentation updates:** + +- Update parquet crate readme [\#1192](https://github.com/apache/arrow-rs/pull/1192) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Document safety justification of some uses of `from_trusted_len_iter` [\#1148](https://github.com/apache/arrow-rs/pull/1148) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) + +**Performance improvements:** + +- Improve parquet reading performance for columns with nulls by preserving bitmask when possible \(\#1037\) [\#1054](https://github.com/apache/arrow-rs/pull/1054) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Improve parquet performance: Skip levels computation for required struct arrays in parquet [\#1035](https://github.com/apache/arrow-rs/pull/1035) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) + +**Closed issues:** + +- Generify ColumnReaderImpl and RecordReader [\#1040](https://github.com/apache/arrow-rs/issues/1040) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Parquet Preserve BitMask [\#1037](https://github.com/apache/arrow-rs/issues/1037) + +**Merged pull requests:** + +- fix a bug in variable sized equality [\#1209](https://github.com/apache/arrow-rs/pull/1209) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([helgikrs](https://github.com/helgikrs)) +- Pin WASM / packed SIMD tests to nightly-2022-01-17 [\#1204](https://github.com/apache/arrow-rs/pull/1204) ([alamb](https://github.com/alamb)) +- feat: add support for casting Duration/Interval to Int64Array [\#1196](https://github.com/apache/arrow-rs/pull/1196) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([e-dard](https://github.com/e-dard)) +- Add comparison support for fully qualified BinaryArray [\#1195](https://github.com/apache/arrow-rs/pull/1195) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Fix in display of `Float16Array` [\#1194](https://github.com/apache/arrow-rs/pull/1194) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([helgikrs](https://github.com/helgikrs)) +- update nightly version for miri [\#1189](https://github.com/apache/arrow-rs/pull/1189) ([Jimexist](https://github.com/Jimexist)) +- feat\(parquet\): support for reading structs nested within lists 
[\#1187](https://github.com/apache/arrow-rs/pull/1187) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([helgikrs](https://github.com/helgikrs)) +- fix: Fix a bug in how definition levels are calculated for nested structs in a list [\#1185](https://github.com/apache/arrow-rs/pull/1185) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([helgikrs](https://github.com/helgikrs)) +- Truncate bitmask on BooleanBufferBuilder::resize: [\#1183](https://github.com/apache/arrow-rs/pull/1183) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add ticket reference for false positive in clippy [\#1181](https://github.com/apache/arrow-rs/pull/1181) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Fix record formatting in 1.58 [\#1178](https://github.com/apache/arrow-rs/pull/1178) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Serialize i128 as JSON string [\#1175](https://github.com/apache/arrow-rs/pull/1175) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support DecimalType in `sort` and `take` kernels [\#1172](https://github.com/apache/arrow-rs/pull/1172) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- Fix new clippy lints introduced in Rust 1.58 [\#1170](https://github.com/apache/arrow-rs/pull/1170) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Fix compilation error with simd feature [\#1169](https://github.com/apache/arrow-rs/pull/1169) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- Fix bug while writing parquet with empty lists of structs [\#1166](https://github.com/apache/arrow-rs/pull/1166) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([helgikrs](https://github.com/helgikrs)) +- Use tempfile for parquet tests [\#1165](https://github.com/apache/arrow-rs/pull/1165) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Remove left over dev/README.md file from arrow/arrow-rs split [\#1162](https://github.com/apache/arrow-rs/pull/1162) ([alamb](https://github.com/alamb)) +- Add multiply\_scalar kernel [\#1159](https://github.com/apache/arrow-rs/pull/1159) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Fuzz test different parquet encodings [\#1156](https://github.com/apache/arrow-rs/pull/1156) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Add subtract\_scalar kernel [\#1152](https://github.com/apache/arrow-rs/pull/1152) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add add\_scalar kernel [\#1151](https://github.com/apache/arrow-rs/pull/1151) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Move simd right out of for\_each loop [\#1150](https://github.com/apache/arrow-rs/pull/1150) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Internal Remove `GenericStringArray::from_vec` and `GenericStringArray::from_opt_vec` 
[\#1147](https://github.com/apache/arrow-rs/pull/1147) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Implement SIMD comparison operations for types with less than 4 lanes \(i128\) [\#1146](https://github.com/apache/arrow-rs/pull/1146) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- Extends parquet fuzz tests to also test nulls, dictionaries and row groups with multiple pages \(\#1053\) [\#1110](https://github.com/apache/arrow-rs/pull/1110) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Generify ColumnReaderImpl and RecordReader \(\#1040\) [\#1041](https://github.com/apache/arrow-rs/pull/1041) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- BooleanBufferBuilder::append\_packed \(\#1038\) [\#1039](https://github.com/apache/arrow-rs/pull/1039) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +## [7.0.0](https://github.com/apache/arrow-rs/tree/7.0.0) (2022-01-07) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/6.5.0...7.0.0) + +### Arrow + +**Breaking changes:** +- `pretty_format_batches` now returns `Result` rather than `String`: [#975](https://github.com/apache/arrow-rs/pull/975) +- `MutableBuffer::typed_data_mut` is marked `unsafe`: [#1029](https://github.com/apache/arrow-rs/pull/1029) +- UnionArray updated to match latest Arrow spec, added `UnionMode`, `UnionArray::new()` marked `unsafe`: [#885](https://github.com/apache/arrow-rs/pull/885) + +**New Features:** +- Support for `Float16Array` types [#888](https://github.com/apache/arrow-rs/pull/888) +- IPC support for `UnionArray` [#654](https://github.com/apache/arrow-rs/issues/654) +- Dynamic comparison kernels for scalars (e.g. 
`eq_dyn_scalar`), including `DictionaryArray`: [#1113](https://github.com/apache/arrow-rs/issues/1113) + +**Enhancements:** +- Added `Schema::with_metadata` and `Field::with_metadata` [#1092](https://github.com/apache/arrow-rs/pull/1092) +- Support for custom datetime format for inference and parsing csv files [#1112](https://github.com/apache/arrow-rs/pull/1112) +- Implement `Array` for `ArrayRef` for easier use [#1129](https://github.com/apache/arrow-rs/pull/1129) +- Pretty printing display support for `FixedSizeBinaryArray` [#1097](https://github.com/apache/arrow-rs/pull/1097) +- Dependency Upgrades: `pyo3`, `parquet-format`, `prost`, `tonic` +- Avoid allocating vector of indices in `lexicographical_partition_ranges`[#998](https://github.com/apache/arrow-rs/pull/998) + +### Parquet + +**Fixed bugs:** +- (parquet) Fix reading of dictionary encoded pages with null values: [#1130](https://github.com/apache/arrow-rs/pull/1130) + + +# Changelog + +## [6.5.0](https://github.com/apache/arrow-rs/tree/6.5.0) (2021-12-23) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/6.4.0...6.5.0) + +* [092fc64bbb019244887ebd0d9c9a2d3e3a9aebc0](https://github.com/apache/arrow-rs/commit/092fc64bbb019244887ebd0d9c9a2d3e3a9aebc0) support cast decimal to decimal ([#1084](https://github.com/apache/arrow-rs/pull/1084)) ([#1093](https://github.com/apache/arrow-rs/pull/1093)) +* [01459762ed18b504e00e7b2818fce91f19188b1e](https://github.com/apache/arrow-rs/commit/01459762ed18b504e00e7b2818fce91f19188b1e) Fix like regex escaping ([#1085](https://github.com/apache/arrow-rs/pull/1085)) ([#1090](https://github.com/apache/arrow-rs/pull/1090)) +* [7c748bfccbc2eac0c1138378736b70dcb7e26a5b](https://github.com/apache/arrow-rs/commit/7c748bfccbc2eac0c1138378736b70dcb7e26a5b) support cast decimal to signed numeric ([#1073](https://github.com/apache/arrow-rs/pull/1073)) ([#1089](https://github.com/apache/arrow-rs/pull/1089)) +* [bd3600b6483c253ae57a38928a636d39a6b7cb02](https://github.com/apache/arrow-rs/commit/bd3600b6483c253ae57a38928a636d39a6b7cb02) parquet: Use constant for RLE decoder buffer size ([#1070](https://github.com/apache/arrow-rs/pull/1070)) ([#1088](https://github.com/apache/arrow-rs/pull/1088)) +* [2b5c53ecd92468fd95328637a15de7f35b6fcf28](https://github.com/apache/arrow-rs/commit/2b5c53ecd92468fd95328637a15de7f35b6fcf28) Box RleDecoder index buffer ([#1061](https://github.com/apache/arrow-rs/pull/1061)) ([#1062](https://github.com/apache/arrow-rs/pull/1062)) ([#1081](https://github.com/apache/arrow-rs/pull/1081)) +* [78721bc1a467177679ad6196b994759cf4d73377](https://github.com/apache/arrow-rs/commit/78721bc1a467177679ad6196b994759cf4d73377) BooleanBufferBuilder correct buffer length ([#1051](https://github.com/apache/arrow-rs/pull/1051)) ([#1052](https://github.com/apache/arrow-rs/pull/1052)) ([#1080](https://github.com/apache/arrow-rs/pull/1080)) +* [3a5e3541d3a4db61a828011ed95c8539adf1d57c](https://github.com/apache/arrow-rs/commit/3a5e3541d3a4db61a828011ed95c8539adf1d57c) support cast signed numeric to decimal ([#1044](https://github.com/apache/arrow-rs/pull/1044)) ([#1079](https://github.com/apache/arrow-rs/pull/1079)) +* [000bdb3053098255d43288aa3e8665e8b1892a6c](https://github.com/apache/arrow-rs/commit/000bdb3053098255d43288aa3e8665e8b1892a6c) fix(compute): LIKE escape parenthesis ([#1042](https://github.com/apache/arrow-rs/pull/1042)) ([#1078](https://github.com/apache/arrow-rs/pull/1078)) +* 
[e0abdb9e62772a2f853974e68e744246e7f47569](https://github.com/apache/arrow-rs/commit/e0abdb9e62772a2f853974e68e744246e7f47569) Add Schema::project and RecordBatch::project functions ([#1033](https://github.com/apache/arrow-rs/pull/1033)) ([#1077](https://github.com/apache/arrow-rs/pull/1077)) +* [31911a4d6328d889d98796b896412b3997f73e13](https://github.com/apache/arrow-rs/commit/31911a4d6328d889d98796b896412b3997f73e13) Remove outdated safety example from doc ([#1050](https://github.com/apache/arrow-rs/pull/1050)) ([#1058](https://github.com/apache/arrow-rs/pull/1058)) +* [71ac8620993a65a7f1f57278c3495556625356b3](https://github.com/apache/arrow-rs/commit/71ac8620993a65a7f1f57278c3495556625356b3) Use existing array type in `take` kernel ([#1046](https://github.com/apache/arrow-rs/pull/1046)) ([#1057](https://github.com/apache/arrow-rs/pull/1057)) +* [1c5902376b7f7d56cb5249db4f98a6a370ead919](https://github.com/apache/arrow-rs/commit/1c5902376b7f7d56cb5249db4f98a6a370ead919) Extract method to drive PageIterator -> RecordReader ([#1031](https://github.com/apache/arrow-rs/pull/1031)) ([#1056](https://github.com/apache/arrow-rs/pull/1056)) +* [7ca39361f8733b86bc0cef5ed5d74093e2c6b14d](https://github.com/apache/arrow-rs/commit/7ca39361f8733b86bc0cef5ed5d74093e2c6b14d) Clarify governance of arrow crate ([#1030](https://github.com/apache/arrow-rs/pull/1030)) ([#1055](https://github.com/apache/arrow-rs/pull/1055)) + + +## [6.4.0](https://github.com/apache/arrow-rs/tree/6.4.0) (2021-12-10) + + +[Full Changelog](https://github.com/apache/arrow-rs/compare/6.3.0...6.4.0) + + +* [049f48559f578243935b6e512d06c4c2df360bf1](https://github.com/apache/arrow-rs/commit/049f48559f578243935b6e512d06c4c2df360bf1) Force new cargo and target caching to fix CI ([#1023](https://github.com/apache/arrow-rs/pull/1023)) ([#1024](https://github.com/apache/arrow-rs/pull/1024)) +* [ef37da3b60f71a52d5ad67e9ca810dca38b29f00](https://github.com/apache/arrow-rs/commit/ef37da3b60f71a52d5ad67e9ca810dca38b29f00) Fix a broken link and some missing styling in the main arrow crate docs ([#1013](https://github.com/apache/arrow-rs/pull/1013)) ([#1019](https://github.com/apache/arrow-rs/pull/1019)) +* [f2c746a9b968714cfe05d35fcee8658371acd899](https://github.com/apache/arrow-rs/commit/f2c746a9b968714cfe05d35fcee8658371acd899) Remove out of date comment ([#1008](https://github.com/apache/arrow-rs/pull/1008)) ([#1018](https://github.com/apache/arrow-rs/pull/1018)) +* [557fc11e3b2a09a680c0cfbf38d27b13101b63fe](https://github.com/apache/arrow-rs/commit/557fc11e3b2a09a680c0cfbf38d27b13101b63fe) Remove unneeded `rc` feature of serde ([#990](https://github.com/apache/arrow-rs/pull/990)) ([#1016](https://github.com/apache/arrow-rs/pull/1016)) +* [b28385e096b1cf8f5fb2773d49b160f93d94fbac](https://github.com/apache/arrow-rs/commit/b28385e096b1cf8f5fb2773d49b160f93d94fbac) Docstrings for Timestamp*Array. 
([#988](https://github.com/apache/arrow-rs/pull/988)) ([#1015](https://github.com/apache/arrow-rs/pull/1015)) +* [a92672e40217670d2566a85d70b0b59fffac594c](https://github.com/apache/arrow-rs/commit/a92672e40217670d2566a85d70b0b59fffac594c) Add full data validation for ArrayData::try_new() ([#1007](https://github.com/apache/arrow-rs/pull/1007)) +* [6c8b2936d7b07e1e2f5d1d48eea425a385382dfb](https://github.com/apache/arrow-rs/commit/6c8b2936d7b07e1e2f5d1d48eea425a385382dfb) Add boolean comparison to scalar kernels for less then, greater than ([#977](https://github.com/apache/arrow-rs/pull/977)) ([#1005](https://github.com/apache/arrow-rs/pull/1005)) +* [14d140aeca608a23a8a6b2c251c8f53ffd377e61](https://github.com/apache/arrow-rs/commit/14d140aeca608a23a8a6b2c251c8f53ffd377e61) Fix some typos in code and comments ([#985](https://github.com/apache/arrow-rs/pull/985)) ([#1006](https://github.com/apache/arrow-rs/pull/1006)) +* [b4507f562fb0eddfb79840871cd2733dc0e337cd](https://github.com/apache/arrow-rs/commit/b4507f562fb0eddfb79840871cd2733dc0e337cd) Fix warnings introduced by Rust/Clippy 1.57.0 ([#1004](https://github.com/apache/arrow-rs/pull/1004)) + + +## [6.3.0](https://github.com/apache/arrow-rs/tree/6.3.0) (2021-11-26) + + +[Full Changelog](https://github.com/apache/arrow-rs/compare/6.2.0...6.3.0) + + +**Changes:** +* [7e51df015ce851a5de444ca08b57b38e7ee959a3](https://github.com/apache/arrow-rs/commit/7e51df015ce851a5de444ca08b57b38e7ee959a3) add more error test case and change the code style ([#952](https://github.com/apache/arrow-rs/pull/952)) ([#976](https://github.com/apache/arrow-rs/pull/976)) +* [6c570cfe98d6a7a4ec74b139b733c5c72ed10015](https://github.com/apache/arrow-rs/commit/6c570cfe98d6a7a4ec74b139b733c5c72ed10015) Support read decimal data from csv reader if user provide the schema with decimal data type ([#941](https://github.com/apache/arrow-rs/pull/941)) ([#974](https://github.com/apache/arrow-rs/pull/974)) +* [4fa0d4d7f7d9ca0a3da2a6dfe3eae6dc2d51a79a](https://github.com/apache/arrow-rs/commit/4fa0d4d7f7d9ca0a3da2a6dfe3eae6dc2d51a79a) Adding Pretty Print Support For Fixed Size List ([#958](https://github.com/apache/arrow-rs/pull/958)) ([#968](https://github.com/apache/arrow-rs/pull/968)) +* [9d453a3128013c03e8ed854ded76b15cc6f28be4](https://github.com/apache/arrow-rs/commit/9d453a3128013c03e8ed854ded76b15cc6f28be4) Fix bug in temporal utilities due to DST being ignored. ([#955](https://github.com/apache/arrow-rs/pull/955)) ([#967](https://github.com/apache/arrow-rs/pull/967)) +* [1b9fd9e3fb2653236513bb7dda5aa2fa14d1d831](https://github.com/apache/arrow-rs/commit/1b9fd9e3fb2653236513bb7dda5aa2fa14d1d831) Inferring 2. 
as Float64 for issue [#929](https://github.com/apache/arrow-rs/pull/929) ([#950](https://github.com/apache/arrow-rs/pull/950)) ([#966](https://github.com/apache/arrow-rs/pull/966)) +* [e6c5e1c877bd94b3d6e545567f901d9962257cf8](https://github.com/apache/arrow-rs/commit/e6c5e1c877bd94b3d6e545567f901d9962257cf8) Fix CI for latest nightly ([#970](https://github.com/apache/arrow-rs/pull/970)) ([#973](https://github.com/apache/arrow-rs/pull/973)) +* [c96e8de457442806e18944f0b26dd06ba4cb1aee](https://github.com/apache/arrow-rs/commit/c96e8de457442806e18944f0b26dd06ba4cb1aee) Fix primitive sort when input contains more nulls than the given sort limit ([#954](https://github.com/apache/arrow-rs/pull/954)) ([#965](https://github.com/apache/arrow-rs/pull/965)) +* [094037d418381584178db1d886cad3b5024b414a](https://github.com/apache/arrow-rs/commit/094037d418381584178db1d886cad3b5024b414a) Update comfy-table to 5.0 ([#957](https://github.com/apache/arrow-rs/pull/957)) ([#964](https://github.com/apache/arrow-rs/pull/964)) +* [9f635021eee6786c5377c891218c5f88ebce07c3](https://github.com/apache/arrow-rs/commit/9f635021eee6786c5377c891218c5f88ebce07c3) Fix csv writing of timestamps to show timezone. ([#849](https://github.com/apache/arrow-rs/pull/849)) ([#963](https://github.com/apache/arrow-rs/pull/963)) +* [f7deba4c3a050a52608462ee8a827bb8f6364140](https://github.com/apache/arrow-rs/commit/f7deba4c3a050a52608462ee8a827bb8f6364140) Adding ability to parse float from number with leading decimal ([#831](https://github.com/apache/arrow-rs/pull/831)) ([#962](https://github.com/apache/arrow-rs/pull/962)) +* [59f96e842d05b63882f7ba285c66a9739761cf84](https://github.com/apache/arrow-rs/commit/59f96e842d05b63882f7ba285c66a9739761cf84) add ilike comparitor ([#874](https://github.com/apache/arrow-rs/pull/874)) ([#961](https://github.com/apache/arrow-rs/pull/961)) +* [54023c8a5543c9f9fa4955afa01189029f3e96f5](https://github.com/apache/arrow-rs/commit/54023c8a5543c9f9fa4955afa01189029f3e96f5) Remove unpassable cargo publish check from verify-release-candidate.sh ([#882](https://github.com/apache/arrow-rs/pull/882)) ([#949](https://github.com/apache/arrow-rs/pull/949)) + + + +## [6.2.0](https://github.com/apache/arrow-rs/tree/6.2.0) (2021-11-12) + + +[Full Changelog](https://github.com/apache/arrow-rs/compare/6.1.0...6.2.0) + +**Features / Fixes:** + + +* [4037933e43cad9e4de027039ce14caa65f78300a](https://github.com/apache/arrow-rs/commit/4037933e43cad9e4de027039ce14caa65f78300a) Fix validation for offsets of StructArrays ([#942](https://github.com/apache/arrow-rs/pull/942)) ([#946](https://github.com/apache/arrow-rs/pull/946)) +* [1af9ca5d363d870550026a7b1abcb749befbb371](https://github.com/apache/arrow-rs/commit/1af9ca5d363d870550026a7b1abcb749befbb371) implement take kernel for null arrays ([#939](https://github.com/apache/arrow-rs/pull/939)) ([#944](https://github.com/apache/arrow-rs/pull/944)) +* [320de1c20aefbf204f6888e2ad3663863afeba9f](https://github.com/apache/arrow-rs/commit/320de1c20aefbf204f6888e2ad3663863afeba9f) add checker for appending i128 to decimal builder ([#928](https://github.com/apache/arrow-rs/pull/928)) ([#943](https://github.com/apache/arrow-rs/pull/943)) +* [dff14113884ad4246a8cafb9be579ebdb4e1481f](https://github.com/apache/arrow-rs/commit/dff14113884ad4246a8cafb9be579ebdb4e1481f) Validate arguments to ArrayData::new and null bit buffer and buffers ([#810](https://github.com/apache/arrow-rs/pull/810)) ([#936](https://github.com/apache/arrow-rs/pull/936)) +* 
[c3eae1ec56303b97c9e15263063a6a13122ef194](https://github.com/apache/arrow-rs/commit/c3eae1ec56303b97c9e15263063a6a13122ef194) fix some warning about unused variables in panic tests ([#894](https://github.com/apache/arrow-rs/pull/894)) ([#933](https://github.com/apache/arrow-rs/pull/933)) +* [e80bb018450f13a30811ffd244c42917d8bf8a62](https://github.com/apache/arrow-rs/commit/e80bb018450f13a30811ffd244c42917d8bf8a62) fix some clippy warnings ([#896](https://github.com/apache/arrow-rs/pull/896)) ([#930](https://github.com/apache/arrow-rs/pull/930)) +* [bde89463b627be3f60b5569d038ca36c434da71d](https://github.com/apache/arrow-rs/commit/bde89463b627be3f60b5569d038ca36c434da71d) feat(ipc): add support for deserializing messages with nested dictionary fields ([#923](https://github.com/apache/arrow-rs/pull/923)) ([#931](https://github.com/apache/arrow-rs/pull/931)) +* [792544b5fb7b84224ef9745ecb9f330663c14fb4](https://github.com/apache/arrow-rs/commit/792544b5fb7b84224ef9745ecb9f330663c14fb4) refactor regexp_is_match_utf8_scalar to try to mitigate miri failures ([#895](https://github.com/apache/arrow-rs/pull/895)) ([#932](https://github.com/apache/arrow-rs/pull/932)) +* [3f0e252811cbb6e3f7c774959787dcfec985d03e](https://github.com/apache/arrow-rs/commit/3f0e252811cbb6e3f7c774959787dcfec985d03e) Automatically retry failed MIRI runs to work around intermittent failures ([#934](https://github.com/apache/arrow-rs/pull/934)) +* [c9a9515c46d560ced00e23ff57cb10a1c97573cb](https://github.com/apache/arrow-rs/commit/c9a9515c46d560ced00e23ff57cb10a1c97573cb) Update mod.rs ([#909](https://github.com/apache/arrow-rs/pull/909)) ([#919](https://github.com/apache/arrow-rs/pull/919)) +* [64ed79ece67141b92dc45b8a1d43cb9d909aa6a9](https://github.com/apache/arrow-rs/commit/64ed79ece67141b92dc45b8a1d43cb9d909aa6a9) Mark boolean kernels public ([#913](https://github.com/apache/arrow-rs/pull/913)) ([#920](https://github.com/apache/arrow-rs/pull/920)) +* [8b95fe0bbf03588c5cc00f67365c5b0dac4d7a34](https://github.com/apache/arrow-rs/commit/8b95fe0bbf03588c5cc00f67365c5b0dac4d7a34) doc example mistype ([#904](https://github.com/apache/arrow-rs/pull/904)) ([#918](https://github.com/apache/arrow-rs/pull/918)) +* [34c5eab4862cab16fdfd5f5ed6c68dce6298dfa4](https://github.com/apache/arrow-rs/commit/34c5eab4862cab16fdfd5f5ed6c68dce6298dfa4) allow null array to be cast to all other types ([#884](https://github.com/apache/arrow-rs/pull/884)) ([#917](https://github.com/apache/arrow-rs/pull/917)) +* [3c69752e55ed0c58f5a8faed918a22b45cd93766](https://github.com/apache/arrow-rs/commit/3c69752e55ed0c58f5a8faed918a22b45cd93766) Fix instances of UB that cause tests to not pass under miri ([#878](https://github.com/apache/arrow-rs/pull/878)) ([#916](https://github.com/apache/arrow-rs/pull/916)) +* [85402148c3af03d0855e81f855715ea98a7491c5](https://github.com/apache/arrow-rs/commit/85402148c3af03d0855e81f855715ea98a7491c5) feat(ipc): Support writing dictionaries nested in structs and unions ([#870](https://github.com/apache/arrow-rs/pull/870)) ([#915](https://github.com/apache/arrow-rs/pull/915)) +* [03d95e626cb0e654775fefa77786674ea41be4a2](https://github.com/apache/arrow-rs/commit/03d95e626cb0e654775fefa77786674ea41be4a2) Fix references to changelog ([#905](https://github.com/apache/arrow-rs/pull/905)) + + +## [6.1.0](https://github.com/apache/arrow-rs/tree/6.1.0) (2021-10-29) + + +[Full Changelog](https://github.com/apache/arrow-rs/compare/6.0.0...6.1.0) + +**Features / Fixes:** + +* 
[b42649b0088fe7762c713a41a23c1abdf8d0496d](https://github.com/apache/arrow-rs/commit/b42649b0088fe7762c713a41a23c1abdf8d0496d) implement eq_dyn and neq_dyn ([#858](https://github.com/apache/arrow-rs/pull/858)) ([#867](https://github.com/apache/arrow-rs/pull/867)) +* [01743f3f10a377c1ca857cd554acbf84155766d8](https://github.com/apache/arrow-rs/commit/01743f3f10a377c1ca857cd554acbf84155766d8) fix: fix a bug in offset calculation for unions ([#863](https://github.com/apache/arrow-rs/pull/863)) ([#871](https://github.com/apache/arrow-rs/pull/871)) +* [8bfff793a23f0e71008c7a9eea7a54d6b913ecff](https://github.com/apache/arrow-rs/commit/8bfff793a23f0e71008c7a9eea7a54d6b913ecff) add lt_bool, lt_eq_bool, gt_bool, gt_eq_bool ([#860](https://github.com/apache/arrow-rs/pull/860)) ([#868](https://github.com/apache/arrow-rs/pull/868)) +* [8845e91d4ab584c822e9ee903db7069551b124af](https://github.com/apache/arrow-rs/commit/8845e91d4ab584c822e9ee903db7069551b124af) fix(ipc): Support serializing structs containing dictionaries ([#848](https://github.com/apache/arrow-rs/pull/848)) ([#865](https://github.com/apache/arrow-rs/pull/865)) +* [620282a0d9fdd2a8ed7e8313d17ba3dec64c80e5](https://github.com/apache/arrow-rs/commit/620282a0d9fdd2a8ed7e8313d17ba3dec64c80e5) Implement boolean equality kernels ([#844](https://github.com/apache/arrow-rs/pull/844)) ([#857](https://github.com/apache/arrow-rs/pull/857)) +* [94cddcacf785be982e69689291ce034ef00220b4](https://github.com/apache/arrow-rs/commit/94cddcacf785be982e69689291ce034ef00220b4) Cherry pick fix parquet_derive with default features (and fix cargo publish) ([#856](https://github.com/apache/arrow-rs/pull/856)) +* [733fd583ddb3dbe6b4d58a809c444ee16ac0eae8](https://github.com/apache/arrow-rs/commit/733fd583ddb3dbe6b4d58a809c444ee16ac0eae8) Use kernel utility for parsing timestamps in csv reader. 
([#832](https://github.com/apache/arrow-rs/pull/832)) ([#853](https://github.com/apache/arrow-rs/pull/853)) +* [2cc64937a153f632796915d2d9869d5c2a501d28](https://github.com/apache/arrow-rs/commit/2cc64937a153f632796915d2d9869d5c2a501d28) [Minor] Fix clippy errors with new rust version (1.56) and float formatting with nightly ([#845](https://github.com/apache/arrow-rs/pull/845)) ([#850](https://github.com/apache/arrow-rs/pull/850)) + +**Other:** +* [bfac9e5a027e3bd78b7a1ec90c75a3e385bd66bb](https://github.com/apache/arrow-rs/commit/bfac9e5a027e3bd78b7a1ec90c75a3e385bd66bb) Test out new tarpaulin version ([#852](https://github.com/apache/arrow-rs/pull/852)) ([#866](https://github.com/apache/arrow-rs/pull/866)) +* [809350ced392cfc78d8a1a46228d4ffc25dea9ff](https://github.com/apache/arrow-rs/commit/809350ced392cfc78d8a1a46228d4ffc25dea9ff) Update README.md ([#834](https://github.com/apache/arrow-rs/pull/834)) ([#854](https://github.com/apache/arrow-rs/pull/854)) +* [70582f40dd21f5c710c4946266d0563a92b92337](https://github.com/apache/arrow-rs/commit/70582f40dd21f5c710c4946266d0563a92b92337) [MINOR] Delete temp file from docs ([#836](https://github.com/apache/arrow-rs/pull/836)) ([#855](https://github.com/apache/arrow-rs/pull/855)) +* [a721e00014015a7e598946b6efb9b1da8080ec85](https://github.com/apache/arrow-rs/commit/a721e00014015a7e598946b6efb9b1da8080ec85) Force fresh cargo cache key in CI ([#839](https://github.com/apache/arrow-rs/pull/839)) ([#851](https://github.com/apache/arrow-rs/pull/851)) + + +## [6.0.0](https://github.com/apache/arrow-rs/tree/6.0.0) (2021-10-13) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/5.5.0...6.0.0) + +**Breaking changes:** + +- Replace `ArrayData::new()` with `ArrayData::try_new()` and `unsafe ArrayData::new_unchecked` [\#822](https://github.com/apache/arrow-rs/pull/822) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Update Bitmap::len to return bits rather than bytes [\#749](https://github.com/apache/arrow-rs/pull/749) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([matthewmturner](https://github.com/matthewmturner)) +- use sort\_unstable\_by in primitive sorting [\#552](https://github.com/apache/arrow-rs/pull/552) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jimexist](https://github.com/Jimexist)) +- New MapArray support [\#491](https://github.com/apache/arrow-rs/pull/491) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([nevi-me](https://github.com/nevi-me)) + +**Implemented enhancements:** + +- Improve parquet binary writer speed by reducing allocations [\#819](https://github.com/apache/arrow-rs/issues/819) +- Expose buffer operations [\#808](https://github.com/apache/arrow-rs/issues/808) +- Add doc examples of writing parquet files using `ArrowWriter` [\#788](https://github.com/apache/arrow-rs/issues/788) + +**Fixed bugs:** + +- JSON reader can create null struct children on empty lists [\#825](https://github.com/apache/arrow-rs/issues/825) +- Incorrect null count for cast kernel for list arrays [\#815](https://github.com/apache/arrow-rs/issues/815) +- `minute` and `second` temporal kernels do not respect timezone [\#500](https://github.com/apache/arrow-rs/issues/500) +- Fix data corruption in json decoder f64-to-i64 cast [\#652](https://github.com/apache/arrow-rs/pull/652) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([xianwill](https://github.com/xianwill)) + +**Documentation updates:** + +- Doctest for PrimitiveArray using from\_iter\_values. [\#694](https://github.com/apache/arrow-rs/pull/694) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([novemberkilo](https://github.com/novemberkilo)) +- Doctests for BinaryArray and LargeBinaryArray. [\#625](https://github.com/apache/arrow-rs/pull/625) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([novemberkilo](https://github.com/novemberkilo)) +- Add links in docstrings [\#605](https://github.com/apache/arrow-rs/pull/605) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) + + +## [5.5.0](https://github.com/apache/arrow-rs/tree/5.5.0) (2021-09-24) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/5.4.0...5.5.0) + +**Implemented enhancements:** + +- parquet should depend on a small set of arrow features [\#800](https://github.com/apache/arrow-rs/issues/800) +- Support equality on RecordBatch [\#735](https://github.com/apache/arrow-rs/issues/735) + +**Fixed bugs:** + +- Converting from string to timestamp uses microseconds instead of milliseconds [\#780](https://github.com/apache/arrow-rs/issues/780) +- Document has no link to `RowColumIter` [\#762](https://github.com/apache/arrow-rs/issues/762) +- length on slices with null doesn't work [\#744](https://github.com/apache/arrow-rs/issues/744) + +## [5.4.0](https://github.com/apache/arrow-rs/tree/5.4.0) (2021-09-10) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/5.3.0...5.4.0) + +**Implemented enhancements:** + +- Upgrade lexical-core to 0.8 [\#747](https://github.com/apache/arrow-rs/issues/747) +- `append_nulls` and `append_trusted_len_iter` for PrimitiveBuilder [\#725](https://github.com/apache/arrow-rs/issues/725) +- Optimize MutableArrayData::extend for null buffers [\#397](https://github.com/apache/arrow-rs/issues/397) + +**Fixed bugs:** + +- Arithmetic with scalars doesn't work on slices [\#742](https://github.com/apache/arrow-rs/issues/742) +- Comparisons with scalar don't work on slices [\#740](https://github.com/apache/arrow-rs/issues/740) +- `unary` kernel doesn't respect offset [\#738](https://github.com/apache/arrow-rs/issues/738) +- `new_null_array` creates invalid struct arrays [\#734](https://github.com/apache/arrow-rs/issues/734) +- --no-default-features is broken for parquet [\#733](https://github.com/apache/arrow-rs/issues/733) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `Bitmap::len` returns the number of bytes, not bits. [\#730](https://github.com/apache/arrow-rs/issues/730) +- Decimal logical type is formatted incorrectly by print\_schema [\#713](https://github.com/apache/arrow-rs/issues/713) +- parquet\_derive does not support chrono time values [\#711](https://github.com/apache/arrow-rs/issues/711) +- Numeric overflow when formatting Decimal type [\#710](https://github.com/apache/arrow-rs/issues/710) +- The integration tests are not running [\#690](https://github.com/apache/arrow-rs/issues/690) + +**Closed issues:** + +- Question: Is there no way to create a DictionaryArray with a pre-arranged mapping? 
[\#729](https://github.com/apache/arrow-rs/issues/729) + +## [5.3.0](https://github.com/apache/arrow-rs/tree/5.3.0) (2021-08-26) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/5.2.0...5.3.0) + +**Implemented enhancements:** + +- Add optimized filter kernel for regular expression matching [\#697](https://github.com/apache/arrow-rs/issues/697) +- Can't cast from timestamp array to string array [\#587](https://github.com/apache/arrow-rs/issues/587) + +**Fixed bugs:** + +- 'Encoding DELTA\_BYTE\_ARRAY is not supported' with parquet arrow readers [\#708](https://github.com/apache/arrow-rs/issues/708) +- Support reading json string into binary data type. [\#701](https://github.com/apache/arrow-rs/issues/701) + +**Closed issues:** + +- Resolve Issues with `prettytable-rs` dependency [\#69](https://github.com/apache/arrow-rs/issues/69) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +## [5.2.0](https://github.com/apache/arrow-rs/tree/5.2.0) (2021-08-12) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/5.1.0...5.2.0) + +**Implemented enhancements:** + +- Make rand an optional dependency [\#671](https://github.com/apache/arrow-rs/issues/671) +- Remove undefined behavior in `value` method of boolean and primitive arrays [\#645](https://github.com/apache/arrow-rs/issues/645) +- Avoid materialization of indices in filter\_record\_batch for single arrays [\#636](https://github.com/apache/arrow-rs/issues/636) +- Add a note about arrow crate security / safety [\#627](https://github.com/apache/arrow-rs/issues/627) +- Allow the creation of String arrays from an interator of &Option\<&str\> [\#598](https://github.com/apache/arrow-rs/issues/598) +- Support arrow map datatype [\#395](https://github.com/apache/arrow-rs/issues/395) + +**Fixed bugs:** + +- Parquet fixed length byte array columns write byte array statistics [\#660](https://github.com/apache/arrow-rs/issues/660) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Parquet boolean columns write Int32 statistics [\#659](https://github.com/apache/arrow-rs/issues/659) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Writing Parquet with a boolean column fails [\#657](https://github.com/apache/arrow-rs/issues/657) +- JSON decoder data corruption for large i64/u64 [\#653](https://github.com/apache/arrow-rs/issues/653) +- Incorrect min/max statistics for strings in parquet files [\#641](https://github.com/apache/arrow-rs/issues/641) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Closed issues:** + +- Release candidate verifying script seems work on macOS [\#640](https://github.com/apache/arrow-rs/issues/640) +- Update CONTRIBUTING [\#342](https://github.com/apache/arrow-rs/issues/342) + +## [5.1.0](https://github.com/apache/arrow-rs/tree/5.1.0) (2021-07-29) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/5.0.0...5.1.0) + +**Implemented enhancements:** + +- Make FFI\_ArrowArray empty\(\) public [\#602](https://github.com/apache/arrow-rs/issues/602) +- exponential sort can be used to speed up lexico partition kernel [\#586](https://github.com/apache/arrow-rs/issues/586) +- Implement sort\(\) for binary array [\#568](https://github.com/apache/arrow-rs/issues/568) +- primitive sorting can be improved and more consistent with and without `limit` if sorted unstably [\#553](https://github.com/apache/arrow-rs/issues/553) + +**Fixed bugs:** + +- Confusing memory usage with CSV reader [\#623](https://github.com/apache/arrow-rs/issues/623) +- FFI 
implementation deviates from specification for array release [\#595](https://github.com/apache/arrow-rs/issues/595) +- Parquet file content is different if `~/.cargo` is in a git checkout [\#589](https://github.com/apache/arrow-rs/issues/589) +- Ensure output of MIRI is checked for success [\#581](https://github.com/apache/arrow-rs/issues/581) +- MIRI failure in `array::ffi::tests::test_struct` and other ffi tests [\#580](https://github.com/apache/arrow-rs/issues/580) +- ListArray equality check may return wrong result [\#570](https://github.com/apache/arrow-rs/issues/570) +- cargo audit failed [\#561](https://github.com/apache/arrow-rs/issues/561) +- ArrayData::slice\(\) does not work for nested types such as StructArray [\#554](https://github.com/apache/arrow-rs/issues/554) + +**Documentation updates:** + +- More examples of how to construct Arrays [\#301](https://github.com/apache/arrow-rs/issues/301) + +**Closed issues:** + +- Implement StringBuilder::append\_option [\#263](https://github.com/apache/arrow-rs/issues/263) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +## [5.0.0](https://github.com/apache/arrow-rs/tree/5.0.0) (2021-07-14) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/4.4.0...5.0.0) + +**Breaking changes:** + +- Remove lifetime from DynComparator [\#543](https://github.com/apache/arrow-rs/issues/543) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Simplify interactions with arrow flight APIs [\#376](https://github.com/apache/arrow-rs/issues/376) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- refactor: remove lifetime from DynComparator [\#542](https://github.com/apache/arrow-rs/pull/542) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([e-dard](https://github.com/e-dard)) +- use iterator for partition kernel instead of generating vec [\#438](https://github.com/apache/arrow-rs/pull/438) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jimexist](https://github.com/Jimexist)) +- Remove DictionaryArray::keys\_array method [\#419](https://github.com/apache/arrow-rs/pull/419) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- simplify interactions with arrow flight APIs [\#377](https://github.com/apache/arrow-rs/pull/377) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([garyanaplan](https://github.com/garyanaplan)) +- return reference from DictionaryArray::values\(\) \(\#313\) [\#314](https://github.com/apache/arrow-rs/pull/314) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- Allow creation of StringArrays from Vec\ [\#519](https://github.com/apache/arrow-rs/issues/519) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Implement RecordBatch::concat [\#461](https://github.com/apache/arrow-rs/issues/461) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Implement RecordBatch::slice\(\) to slice RecordBatches [\#460](https://github.com/apache/arrow-rs/issues/460) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add a RecordBatch::split to split large batches into a set of smaller batches [\#343](https://github.com/apache/arrow-rs/issues/343) +- generate parquet schema from rust struct [\#539](https://github.com/apache/arrow-rs/pull/539) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([nevi-me](https://github.com/nevi-me)) +- Implement `RecordBatch::concat` 
[\#537](https://github.com/apache/arrow-rs/pull/537) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([silathdiir](https://github.com/silathdiir)) +- Implement function slice for RecordBatch [\#490](https://github.com/apache/arrow-rs/pull/490) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([b41sh](https://github.com/b41sh)) +- add lexicographically partition points and ranges [\#424](https://github.com/apache/arrow-rs/pull/424) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jimexist](https://github.com/Jimexist)) +- allow to read non-standard CSV [\#326](https://github.com/apache/arrow-rs/pull/326) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kazuk](https://github.com/kazuk)) +- parquet: Speed up `BitReader`/`DeltaBitPackDecoder` [\#325](https://github.com/apache/arrow-rs/pull/325) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([kornholi](https://github.com/kornholi)) +- ARROW-12343: \[Rust\] Support auto-vectorization for min/max [\#9](https://github.com/apache/arrow-rs/pull/9) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- ARROW-12411: \[Rust\] Create RecordBatches from Iterators [\#7](https://github.com/apache/arrow-rs/pull/7) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) + +**Fixed bugs:** + +- Error building on master - error: cyclic package dependency: package `ahash v0.7.4` depends on itself. Cycle [\#544](https://github.com/apache/arrow-rs/issues/544) +- IPC reader panics with out of bounds error [\#541](https://github.com/apache/arrow-rs/issues/541) +- Take kernel doesn't handle nulls and structs correctly [\#530](https://github.com/apache/arrow-rs/issues/530) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- master fails to compile with `default-features=false` [\#529](https://github.com/apache/arrow-rs/issues/529) +- README developer instructions out of date [\#523](https://github.com/apache/arrow-rs/issues/523) +- Update rustc and packed\_simd in CI before 5.0 release [\#517](https://github.com/apache/arrow-rs/issues/517) +- Incorrect memory usage calculation for dictionary arrays [\#503](https://github.com/apache/arrow-rs/issues/503) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- sliced null buffers lead to incorrect result in take kernel \(and probably on other places\) [\#502](https://github.com/apache/arrow-rs/issues/502) +- Cast of utf8 types and list container types don't respect offset [\#334](https://github.com/apache/arrow-rs/issues/334) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- fix take kernel null handling on structs [\#531](https://github.com/apache/arrow-rs/pull/531) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([bjchambers](https://github.com/bjchambers)) +- Correct array memory usage calculation for dictionary arrays [\#505](https://github.com/apache/arrow-rs/pull/505) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- parquet: improve BOOLEAN writing logic and report error on encoding fail [\#443](https://github.com/apache/arrow-rs/pull/443) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([garyanaplan](https://github.com/garyanaplan)) +- Fix bug with null buffer offset in boolean not kernel [\#418](https://github.com/apache/arrow-rs/pull/418) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- 
respect offset in utf8 and list casts [\#335](https://github.com/apache/arrow-rs/pull/335) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ritchie46](https://github.com/ritchie46)) +- Fix comparison of dictionaries with different values arrays \(\#332\) [\#333](https://github.com/apache/arrow-rs/pull/333) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- ensure null-counts are written for all-null columns [\#307](https://github.com/apache/arrow-rs/pull/307) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([crepererum](https://github.com/crepererum)) +- fix invalid null handling in filter [\#296](https://github.com/apache/arrow-rs/pull/296) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ritchie46](https://github.com/ritchie46)) +- fix NaN handling in parquet statistics [\#256](https://github.com/apache/arrow-rs/pull/256) ([crepererum](https://github.com/crepererum)) + +**Documentation updates:** + +- Improve arrow's crate's readme on crates.io [\#463](https://github.com/apache/arrow-rs/issues/463) +- Clean up README.md in advance of the 5.0 release [\#536](https://github.com/apache/arrow-rs/pull/536) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- fix readme instructions to reflect new structure [\#524](https://github.com/apache/arrow-rs/pull/524) ([marcvanheerden](https://github.com/marcvanheerden)) +- Improve docs for NullArray, new\_null\_array and new\_empty\_array [\#240](https://github.com/apache/arrow-rs/pull/240) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) + +**Merged pull requests:** + +- Fix default arrow build [\#533](https://github.com/apache/arrow-rs/pull/533) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Add tests for building applications using arrow with different feature flags [\#532](https://github.com/apache/arrow-rs/pull/532) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Remove unused futures dependency from arrow-flight [\#528](https://github.com/apache/arrow-rs/pull/528) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- CI: update rust nightly and packed\_simd [\#525](https://github.com/apache/arrow-rs/pull/525) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ritchie46](https://github.com/ritchie46)) +- Support `StringArray` creation from String Vec [\#522](https://github.com/apache/arrow-rs/pull/522) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([silathdiir](https://github.com/silathdiir)) +- Fix parquet benchmark schema [\#513](https://github.com/apache/arrow-rs/pull/513) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([nevi-me](https://github.com/nevi-me)) +- Fix parquet definition levels [\#511](https://github.com/apache/arrow-rs/pull/511) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([nevi-me](https://github.com/nevi-me)) +- Fix for primitive and boolean take kernel for nullable indices with an offset [\#509](https://github.com/apache/arrow-rs/pull/509) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- Bump flatbuffers 
[\#499](https://github.com/apache/arrow-rs/pull/499) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([PsiACE](https://github.com/PsiACE)) +- implement second/minute helpers for temporal [\#493](https://github.com/apache/arrow-rs/pull/493) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ovr](https://github.com/ovr)) +- special case concatenating single element array shortcut [\#492](https://github.com/apache/arrow-rs/pull/492) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jimexist](https://github.com/Jimexist)) +- update docs to reflect recent changes \(joins and window functions\) [\#489](https://github.com/apache/arrow-rs/pull/489) ([Jimexist](https://github.com/Jimexist)) +- Update rand, proc-macro and zstd dependencies [\#488](https://github.com/apache/arrow-rs/pull/488) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Doctest for GenericListArray. [\#474](https://github.com/apache/arrow-rs/pull/474) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([novemberkilo](https://github.com/novemberkilo)) +- remove stale comment on `ArrayData` equality and update unit tests [\#472](https://github.com/apache/arrow-rs/pull/472) ([Jimexist](https://github.com/Jimexist)) +- remove unused patch file [\#471](https://github.com/apache/arrow-rs/pull/471) ([Jimexist](https://github.com/Jimexist)) +- fix clippy warnings for rust 1.53 [\#470](https://github.com/apache/arrow-rs/pull/470) ([Jimexist](https://github.com/Jimexist)) +- Fix PR labeler [\#468](https://github.com/apache/arrow-rs/pull/468) ([Dandandan](https://github.com/Dandandan)) +- Tweak dev backporting docs [\#466](https://github.com/apache/arrow-rs/pull/466) ([alamb](https://github.com/alamb)) +- Unvendor Archery [\#459](https://github.com/apache/arrow-rs/pull/459) ([kszucs](https://github.com/kszucs)) +- Add sort boolean benchmark [\#457](https://github.com/apache/arrow-rs/pull/457) ([alamb](https://github.com/alamb)) +- Add C data interface for decimal128 and timestamp [\#453](https://github.com/apache/arrow-rs/pull/453) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alippai](https://github.com/alippai)) +- Implement the Iterator trait for the json Reader. 
[\#451](https://github.com/apache/arrow-rs/pull/451) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([LaurentMazare](https://github.com/LaurentMazare)) +- Update release docs + release email template [\#450](https://github.com/apache/arrow-rs/pull/450) ([alamb](https://github.com/alamb)) +- remove clippy unnecessary wraps suppresions in cast kernel [\#449](https://github.com/apache/arrow-rs/pull/449) ([Jimexist](https://github.com/Jimexist)) +- Use partition for bool sort [\#448](https://github.com/apache/arrow-rs/pull/448) ([Jimexist](https://github.com/Jimexist)) +- remove unnecessary wraps in sort [\#445](https://github.com/apache/arrow-rs/pull/445) ([Jimexist](https://github.com/Jimexist)) +- Python FFI bridge for Schema, Field and DataType [\#439](https://github.com/apache/arrow-rs/pull/439) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kszucs](https://github.com/kszucs)) +- Update release Readme.md [\#436](https://github.com/apache/arrow-rs/pull/436) ([alamb](https://github.com/alamb)) +- Derive Eq and PartialEq for SortOptions [\#425](https://github.com/apache/arrow-rs/pull/425) ([tustvold](https://github.com/tustvold)) +- refactor lexico sort for future code reuse [\#423](https://github.com/apache/arrow-rs/pull/423) ([Jimexist](https://github.com/Jimexist)) +- Reenable MIRI check on PRs [\#421](https://github.com/apache/arrow-rs/pull/421) ([alamb](https://github.com/alamb)) +- Sort by float lists [\#420](https://github.com/apache/arrow-rs/pull/420) ([medwards](https://github.com/medwards)) +- Fix out of bounds read in bit chunk iterator [\#416](https://github.com/apache/arrow-rs/pull/416) ([jhorstmann](https://github.com/jhorstmann)) +- Doctests for DecimalArray. [\#414](https://github.com/apache/arrow-rs/pull/414) ([novemberkilo](https://github.com/novemberkilo)) +- Add Decimal to CsvWriter and improve debug display [\#406](https://github.com/apache/arrow-rs/pull/406) ([alippai](https://github.com/alippai)) +- MINOR: update install instruction [\#400](https://github.com/apache/arrow-rs/pull/400) ([alippai](https://github.com/alippai)) +- use prettier to auto format md files [\#398](https://github.com/apache/arrow-rs/pull/398) ([Jimexist](https://github.com/Jimexist)) +- window::shift to work for all array types [\#388](https://github.com/apache/arrow-rs/pull/388) ([Jimexist](https://github.com/Jimexist)) +- add more tests for window::shift and handle boundary cases [\#386](https://github.com/apache/arrow-rs/pull/386) ([Jimexist](https://github.com/Jimexist)) +- Implement faster arrow array reader [\#384](https://github.com/apache/arrow-rs/pull/384) ([yordan-pavlov](https://github.com/yordan-pavlov)) +- Add set\_bit to BooleanBufferBuilder to allow mutating bit in index [\#383](https://github.com/apache/arrow-rs/pull/383) ([boazberman](https://github.com/boazberman)) +- make sure that only concat preallocates buffers [\#382](https://github.com/apache/arrow-rs/pull/382) ([ritchie46](https://github.com/ritchie46)) +- Respect max rowgroup size in Arrow writer [\#381](https://github.com/apache/arrow-rs/pull/381) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([nevi-me](https://github.com/nevi-me)) +- Fix typo in release script, update release location [\#380](https://github.com/apache/arrow-rs/pull/380) ([alamb](https://github.com/alamb)) +- Doctests for FixedSizeBinaryArray [\#378](https://github.com/apache/arrow-rs/pull/378) ([novemberkilo](https://github.com/novemberkilo)) +- Simplify shift kernel using new\_null\_array 
[\#370](https://github.com/apache/arrow-rs/pull/370) ([Dandandan](https://github.com/Dandandan)) +- allow `SliceableCursor` to be constructed from an `Arc` directly [\#369](https://github.com/apache/arrow-rs/pull/369) ([crepererum](https://github.com/crepererum)) +- Add doctest for ArrayBuilder [\#367](https://github.com/apache/arrow-rs/pull/367) ([alippai](https://github.com/alippai)) +- Fix version in readme [\#365](https://github.com/apache/arrow-rs/pull/365) ([domoritz](https://github.com/domoritz)) +- Remove superfluous space [\#363](https://github.com/apache/arrow-rs/pull/363) ([domoritz](https://github.com/domoritz)) +- Add crate badges [\#362](https://github.com/apache/arrow-rs/pull/362) ([domoritz](https://github.com/domoritz)) +- Disable MIRI check until it runs cleanly on CI [\#360](https://github.com/apache/arrow-rs/pull/360) ([alamb](https://github.com/alamb)) +- Only register Flight.proto with cargo if it exists [\#351](https://github.com/apache/arrow-rs/pull/351) ([tustvold](https://github.com/tustvold)) +- Reduce memory usage of concat \(large\)utf8 [\#348](https://github.com/apache/arrow-rs/pull/348) ([ritchie46](https://github.com/ritchie46)) +- Fix filter UB and add fast path [\#341](https://github.com/apache/arrow-rs/pull/341) ([ritchie46](https://github.com/ritchie46)) +- Automatic cherry-pick script [\#339](https://github.com/apache/arrow-rs/pull/339) ([alamb](https://github.com/alamb)) +- Doctests for BooleanArray. [\#338](https://github.com/apache/arrow-rs/pull/338) ([novemberkilo](https://github.com/novemberkilo)) +- feature gate ipc reader/writer [\#336](https://github.com/apache/arrow-rs/pull/336) ([ritchie46](https://github.com/ritchie46)) +- Add ported Rust release verification script [\#331](https://github.com/apache/arrow-rs/pull/331) ([wesm](https://github.com/wesm)) +- Doctests for StringArray and LargeStringArray. [\#330](https://github.com/apache/arrow-rs/pull/330) ([novemberkilo](https://github.com/novemberkilo)) +- inline PrimitiveArray::value [\#329](https://github.com/apache/arrow-rs/pull/329) ([ritchie46](https://github.com/ritchie46)) +- Enable wasm32 as a target architecture for the SIMD feature [\#324](https://github.com/apache/arrow-rs/pull/324) ([roee88](https://github.com/roee88)) +- Fix undefined behavior in FFI and enable MIRI checks on CI [\#323](https://github.com/apache/arrow-rs/pull/323) ([roee88](https://github.com/roee88)) +- Mutablebuffer::shrink\_to\_fit [\#318](https://github.com/apache/arrow-rs/pull/318) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ritchie46](https://github.com/ritchie46)) +- Add \(simd\) modulus op [\#317](https://github.com/apache/arrow-rs/pull/317) ([gangliao](https://github.com/gangliao)) +- feature gate csv functionality [\#312](https://github.com/apache/arrow-rs/pull/312) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ritchie46](https://github.com/ritchie46)) +- \[Minor\] Version upgrades [\#304](https://github.com/apache/arrow-rs/pull/304) ([Dandandan](https://github.com/Dandandan)) +- Remove old release scripts [\#293](https://github.com/apache/arrow-rs/pull/293) ([alamb](https://github.com/alamb)) +- Add Send to the ArrayBuilder trait [\#291](https://github.com/apache/arrow-rs/pull/291) ([Max-Meldrum](https://github.com/Max-Meldrum)) +- Added changelog generator script and configuration. 
[\#289](https://github.com/apache/arrow-rs/pull/289) ([jorgecarleitao](https://github.com/jorgecarleitao)) +- manually bump development version [\#288](https://github.com/apache/arrow-rs/pull/288) ([nevi-me](https://github.com/nevi-me)) +- Fix FFI and add support for Struct type [\#287](https://github.com/apache/arrow-rs/pull/287) ([roee88](https://github.com/roee88)) +- Fix subtraction underflow when sorting string arrays with many nulls [\#285](https://github.com/apache/arrow-rs/pull/285) ([medwards](https://github.com/medwards)) +- Speed up bound checking in `take` [\#281](https://github.com/apache/arrow-rs/pull/281) ([Dandandan](https://github.com/Dandandan)) +- Update PR template by commenting out instructions [\#278](https://github.com/apache/arrow-rs/pull/278) ([nevi-me](https://github.com/nevi-me)) +- Added Decimal support to pretty-print display utility \(\#230\) [\#273](https://github.com/apache/arrow-rs/pull/273) ([mgill25](https://github.com/mgill25)) +- Fix null struct and list roundtrip [\#270](https://github.com/apache/arrow-rs/pull/270) ([nevi-me](https://github.com/nevi-me)) +- 1.52 clippy fixes [\#267](https://github.com/apache/arrow-rs/pull/267) ([nevi-me](https://github.com/nevi-me)) +- Fix typo in csv/reader.rs [\#265](https://github.com/apache/arrow-rs/pull/265) ([domoritz](https://github.com/domoritz)) +- Fix empty Schema::metadata deserialization error [\#260](https://github.com/apache/arrow-rs/pull/260) ([hulunbier](https://github.com/hulunbier)) +- update datafusion and ballista doc links [\#259](https://github.com/apache/arrow-rs/pull/259) ([Jimexist](https://github.com/Jimexist)) +- support full u32 and u64 roundtrip through parquet [\#258](https://github.com/apache/arrow-rs/pull/258) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([crepererum](https://github.com/crepererum)) +- \[MINOR\] Added env to run rust in integration. [\#253](https://github.com/apache/arrow-rs/pull/253) ([jorgecarleitao](https://github.com/jorgecarleitao)) +- \[Minor\] Made integration tests always run. [\#248](https://github.com/apache/arrow-rs/pull/248) ([jorgecarleitao](https://github.com/jorgecarleitao)) +- fix parquet max\_definition for non-null structs [\#246](https://github.com/apache/arrow-rs/pull/246) ([nevi-me](https://github.com/nevi-me)) +- Disabled rebase needed until demonstrate working. 
[\#243](https://github.com/apache/arrow-rs/pull/243) ([jorgecarleitao](https://github.com/jorgecarleitao)) +- pin flatbuffers to 0.8.4 [\#239](https://github.com/apache/arrow-rs/pull/239) ([ritchie46](https://github.com/ritchie46)) +- sort\_primitive result is capped to the min of limit or values.len [\#236](https://github.com/apache/arrow-rs/pull/236) ([medwards](https://github.com/medwards)) +- Read list field correctly [\#234](https://github.com/apache/arrow-rs/pull/234) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([nevi-me](https://github.com/nevi-me)) +- Fix code examples for RecordBatch::try\_from\_iter [\#231](https://github.com/apache/arrow-rs/pull/231) ([alamb](https://github.com/alamb)) +- Support string dictionaries in csv reader \(\#228\) [\#229](https://github.com/apache/arrow-rs/pull/229) ([tustvold](https://github.com/tustvold)) +- support LargeUtf8 in sort kernel [\#26](https://github.com/apache/arrow-rs/pull/26) ([ritchie46](https://github.com/ritchie46)) +- Removed unused files [\#22](https://github.com/apache/arrow-rs/pull/22) ([jorgecarleitao](https://github.com/jorgecarleitao)) +- ARROW-12504: Buffer::from\_slice\_ref set correct capacity [\#18](https://github.com/apache/arrow-rs/pull/18) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add GitHub templates [\#17](https://github.com/apache/arrow-rs/pull/17) ([andygrove](https://github.com/andygrove)) +- ARROW-12493: Add support for writing dictionary arrays to CSV and JSON [\#16](https://github.com/apache/arrow-rs/pull/16) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- ARROW-12426: \[Rust\] Fix concatentation of arrow dictionaries [\#15](https://github.com/apache/arrow-rs/pull/15) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update repository and homepage urls [\#14](https://github.com/apache/arrow-rs/pull/14) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Dandandan](https://github.com/Dandandan)) +- Added rebase-needed bot [\#13](https://github.com/apache/arrow-rs/pull/13) ([jorgecarleitao](https://github.com/jorgecarleitao)) +- Added Integration tests against arrow [\#10](https://github.com/apache/arrow-rs/pull/10) ([jorgecarleitao](https://github.com/jorgecarleitao)) + +## [4.4.0](https://github.com/apache/arrow-rs/tree/4.4.0) (2021-06-24) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/4.3.0...4.4.0) + +**Breaking changes:** + +- migrate partition kernel to use Iterator trait [\#437](https://github.com/apache/arrow-rs/issues/437) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Remove DictionaryArray::keys\_array [\#391](https://github.com/apache/arrow-rs/issues/391) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Implemented enhancements:** + +- sort kernel boolean sort can be O\(n\) [\#447](https://github.com/apache/arrow-rs/issues/447) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- C data interface for decimal128, timestamp, date32 and date64 [\#413](https://github.com/apache/arrow-rs/issues/413) +- Add Decimal to CsvWriter [\#405](https://github.com/apache/arrow-rs/issues/405) +- Use iterators to increase performance of creating Arrow arrays [\#200](https://github.com/apache/arrow-rs/issues/200) 
[[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Fixed bugs:** + +- Release Audit Tool \(RAT\) is not being triggered [\#481](https://github.com/apache/arrow-rs/issues/481) +- Security Vulnerabilities: flatbuffers: `read_scalar` and `read_scalar_at` allow transmuting values without `unsafe` blocks [\#476](https://github.com/apache/arrow-rs/issues/476) +- Clippy broken after upgrade to rust 1.53 [\#467](https://github.com/apache/arrow-rs/issues/467) +- Pull Request Labeler is not working [\#462](https://github.com/apache/arrow-rs/issues/462) +- Arrow 4.3 release: error\[E0658\]: use of unstable library feature 'partition\_point': new API [\#456](https://github.com/apache/arrow-rs/issues/456) +- parquet reading hangs when row\_group contains more than 2048 rows of data [\#349](https://github.com/apache/arrow-rs/issues/349) +- Fail to build arrow [\#247](https://github.com/apache/arrow-rs/issues/247) +- JSON reader does not implement iterator [\#193](https://github.com/apache/arrow-rs/issues/193) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Security fixes:** + +- Ensure a successful MIRI Run on CI [\#227](https://github.com/apache/arrow-rs/issues/227) + +**Closed issues:** + +- sort kernel has a lot of unnecessary wrapping [\#446](https://github.com/apache/arrow-rs/issues/446) +- \[Parquet\] Plain encoded boolean column chunks limited to 2048 values [\#48](https://github.com/apache/arrow-rs/issues/48) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +## [4.3.0](https://github.com/apache/arrow-rs/tree/4.3.0) (2021-06-10) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/4.2.0...4.3.0) + +**Implemented enhancements:** + +- Add partitioning kernel for sorted arrays [\#428](https://github.com/apache/arrow-rs/issues/428) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Implement sort by float lists [\#427](https://github.com/apache/arrow-rs/issues/427) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Derive Eq and PartialEq for SortOptions [\#426](https://github.com/apache/arrow-rs/issues/426) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- use prettier and github action to normalize markdown document syntax [\#399](https://github.com/apache/arrow-rs/issues/399) +- window::shift can work for more than just primitive array type [\#392](https://github.com/apache/arrow-rs/issues/392) +- Doctest for ArrayBuilder [\#366](https://github.com/apache/arrow-rs/issues/366) + +**Fixed bugs:** + +- Boolean `not` kernel does not take offset of null buffer into account [\#417](https://github.com/apache/arrow-rs/issues/417) +- my contribution not marged in 4.2 release [\#394](https://github.com/apache/arrow-rs/issues/394) +- window::shift shall properly handle boundary cases [\#387](https://github.com/apache/arrow-rs/issues/387) +- Parquet `WriterProperties.max_row_group_size` not wired up [\#257](https://github.com/apache/arrow-rs/issues/257) +- Out of bound reads in chunk iterator [\#198](https://github.com/apache/arrow-rs/issues/198) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +## [4.2.0](https://github.com/apache/arrow-rs/tree/4.2.0) (2021-05-29) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/4.1.0...4.2.0) + +**Breaking changes:** + +- DictionaryArray::values\(\) clones the underlying ArrayRef [\#313](https://github.com/apache/arrow-rs/issues/313) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Implemented enhancements:** + +- Simplify shift kernel using 
null array [\#371](https://github.com/apache/arrow-rs/issues/371) +- Provide `Arc`-based constructor for `parquet::util::cursor::SliceableCursor` [\#368](https://github.com/apache/arrow-rs/issues/368) +- Add badges to crates [\#361](https://github.com/apache/arrow-rs/issues/361) +- Consider inlining PrimitiveArray::value [\#328](https://github.com/apache/arrow-rs/issues/328) +- Implement automated release verification script [\#327](https://github.com/apache/arrow-rs/issues/327) +- Add wasm32 to the list of target architectures of the simd feature [\#316](https://github.com/apache/arrow-rs/issues/316) +- add with\_escape for csv::ReaderBuilder [\#315](https://github.com/apache/arrow-rs/issues/315) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- IPC feature gate [\#310](https://github.com/apache/arrow-rs/issues/310) +- csv feature gate [\#309](https://github.com/apache/arrow-rs/issues/309) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add `shrink_to` / `shrink_to_fit` to `MutableBuffer` [\#297](https://github.com/apache/arrow-rs/issues/297) + +**Fixed bugs:** + +- Incorrect crate setup instructions [\#364](https://github.com/apache/arrow-rs/issues/364) +- Arrow-flight only register rerun-if-changed if file exists [\#350](https://github.com/apache/arrow-rs/issues/350) +- Dictionary Comparison Uses Wrong Values Array [\#332](https://github.com/apache/arrow-rs/issues/332) +- Undefined behavior in FFI implementation [\#322](https://github.com/apache/arrow-rs/issues/322) +- All-null column get wrong parquet null-counts [\#306](https://github.com/apache/arrow-rs/issues/306) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Filter has inconsistent null handling [\#295](https://github.com/apache/arrow-rs/issues/295) + +## [4.1.0](https://github.com/apache/arrow-rs/tree/4.1.0) (2021-05-17) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/4.0.0...4.1.0) + +**Implemented enhancements:** + +- Add Send to ArrayBuilder [\#290](https://github.com/apache/arrow-rs/issues/290) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Improve performance of bound checking option [\#280](https://github.com/apache/arrow-rs/issues/280) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- extend compute kernel arity to include nullary functions [\#276](https://github.com/apache/arrow-rs/issues/276) +- Implement FFI / CDataInterface for Struct Arrays [\#251](https://github.com/apache/arrow-rs/issues/251) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add support for pretty-printing Decimal numbers [\#230](https://github.com/apache/arrow-rs/issues/230) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- CSV Reader String Dictionary Support [\#228](https://github.com/apache/arrow-rs/issues/228) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add Builder interface for adding Arrays to record batches [\#210](https://github.com/apache/arrow-rs/issues/210) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support auto-vectorization for min/max [\#209](https://github.com/apache/arrow-rs/issues/209) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support LargeUtf8 in sort kernel [\#25](https://github.com/apache/arrow-rs/issues/25) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- no method named `select_nth_unstable_by` found for mutable reference `&mut [T]` [\#283](https://github.com/apache/arrow-rs/issues/283) +- Rust 1.52 Clippy error 
[\#266](https://github.com/apache/arrow-rs/issues/266) +- NaNs can break parquet statistics [\#255](https://github.com/apache/arrow-rs/issues/255) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- u64::MAX does not roundtrip through parquet [\#254](https://github.com/apache/arrow-rs/issues/254) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Integration tests failing to compile \(flatbuffer\) [\#249](https://github.com/apache/arrow-rs/issues/249) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Fix compatibility quirks between arrow and parquet structs [\#245](https://github.com/apache/arrow-rs/issues/245) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Unable to write non-null Arrow structs to Parquet [\#244](https://github.com/apache/arrow-rs/issues/244) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- schema: missing field `metadata` when deserialize [\#241](https://github.com/apache/arrow-rs/issues/241) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Arrow does not compile due to flatbuffers upgrade [\#238](https://github.com/apache/arrow-rs/issues/238) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Sort with limit panics for the limit includes some but not all nulls, for large arrays [\#235](https://github.com/apache/arrow-rs/issues/235) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- arrow-rs contains a copy of the "format" directory [\#233](https://github.com/apache/arrow-rs/issues/233) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Fix SEGFAULT/ SIGILL in child-data ffi [\#206](https://github.com/apache/arrow-rs/issues/206) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Read list field correctly in \\> [\#167](https://github.com/apache/arrow-rs/issues/167) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- FFI listarray lead to undefined behavior. 
[\#20](https://github.com/apache/arrow-rs/issues/20) + +**Security fixes:** + +- Fix MIRI build on CI [\#226](https://github.com/apache/arrow-rs/issues/226) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Get MIRI running again [\#224](https://github.com/apache/arrow-rs/issues/224) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Documentation updates:** + +- Comment out the instructions in the PR template [\#277](https://github.com/apache/arrow-rs/issues/277) +- Update links to datafusion and ballista in README.md [\#19](https://github.com/apache/arrow-rs/issues/19) +- Update "repository" in Cargo.toml [\#12](https://github.com/apache/arrow-rs/issues/12) + +**Closed issues:** + +- Arrow Aligned Vec [\#268](https://github.com/apache/arrow-rs/issues/268) +- \[Rust\]: Tracking issue for AVX-512 [\#220](https://github.com/apache/arrow-rs/issues/220) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Umbrella issue for clippy integration [\#217](https://github.com/apache/arrow-rs/issues/217) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support sort [\#215](https://github.com/apache/arrow-rs/issues/215) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support stable Rust [\#214](https://github.com/apache/arrow-rs/issues/214) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Remove Rust and point integration tests to arrow-rs repo [\#211](https://github.com/apache/arrow-rs/issues/211) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- ArrayData buffers are inconsistent accross implementations [\#207](https://github.com/apache/arrow-rs/issues/207) +- 3.0.1 patch release [\#204](https://github.com/apache/arrow-rs/issues/204) +- Document patch release process [\#202](https://github.com/apache/arrow-rs/issues/202) +- Simplify Offset [\#186](https://github.com/apache/arrow-rs/issues/186) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Typed Bytes [\#185](https://github.com/apache/arrow-rs/issues/185) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[CI\]docker-compose setup should enable caching [\#175](https://github.com/apache/arrow-rs/issues/175) +- Improve take primitive performance [\#174](https://github.com/apache/arrow-rs/issues/174) +- \[CI\] Try out buildkite [\#165](https://github.com/apache/arrow-rs/issues/165) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Update assignees in JIRA where missing [\#160](https://github.com/apache/arrow-rs/issues/160) +- \[Rust\]: From\ implementations should validate data type [\#103](https://github.com/apache/arrow-rs/issues/103) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[DataFusion\] Verify that projection push down does not remove aliases columns [\#99](https://github.com/apache/arrow-rs/issues/99) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Rust\]\[DataFusion\] Implement modulus expression [\#98](https://github.com/apache/arrow-rs/issues/98) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[DataFusion\] Add constant folding to expressions during logically planning [\#96](https://github.com/apache/arrow-rs/issues/96) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[DataFusion\] DataFrame.collect should return RecordBatchReader [\#95](https://github.com/apache/arrow-rs/issues/95) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Rust\]\[DataFusion\] Add FORMAT to explain plan and an easy to visualize format 
[\#94](https://github.com/apache/arrow-rs/issues/94) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[DataFusion\] Implement metrics framework [\#90](https://github.com/apache/arrow-rs/issues/90) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[DataFusion\] Implement micro benchmarks for each operator [\#89](https://github.com/apache/arrow-rs/issues/89) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[DataFusion\] Implement pretty print for physical query plan [\#88](https://github.com/apache/arrow-rs/issues/88) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Archery\] Support rust clippy in the lint command [\#83](https://github.com/apache/arrow-rs/issues/83) +- \[rust\]\[datafusion\] optimize count\(\*\) queries on parquet sources [\#75](https://github.com/apache/arrow-rs/issues/75) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Rust\]\[DataFusion\] Improve like/nlike performance [\#71](https://github.com/apache/arrow-rs/issues/71) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[DataFusion\] Implement optimizer rule to remove redundant projections [\#56](https://github.com/apache/arrow-rs/issues/56) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[DataFusion\] Parquet data source does not support complex types [\#39](https://github.com/apache/arrow-rs/issues/39) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Merge utils from Parquet and Arrow [\#32](https://github.com/apache/arrow-rs/issues/32) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Add benchmarks for Parquet [\#30](https://github.com/apache/arrow-rs/issues/30) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Mark methods that do not perform bounds checking as unsafe [\#28](https://github.com/apache/arrow-rs/issues/28) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Test issue [\#24](https://github.com/apache/arrow-rs/issues/24) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- This is a test issue [\#11](https://github.com/apache/arrow-rs/issues/11) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3acf7706a3e8..549d4da1a6b0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,485 +17,83 @@ under the License. 
--> -For older versions, see [apache/arrow/CHANGELOG.md](https://github.com/apache/arrow/blob/master/CHANGELOG.md) - # Changelog -## [6.0.0](https://github.com/apache/arrow-rs/tree/6.0.0) (2021-10-13) +## [16.0.0](https://github.com/apache/arrow-rs/tree/16.0.0) (2022-06-10) -[Full Changelog](https://github.com/apache/arrow-rs/compare/5.5.0...6.0.0) +[Full Changelog](https://github.com/apache/arrow-rs/compare/15.0.0...16.0.0) **Breaking changes:** -- Replace `ArrayData::new()` with `ArrayData::try_new()` and `unsafe ArrayData::new_unchecked` [\#822](https://github.com/apache/arrow-rs/pull/822) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) -- Update Bitmap::len to return bits rather than bytes [\#749](https://github.com/apache/arrow-rs/pull/749) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([matthewmturner](https://github.com/matthewmturner)) -- use sort\_unstable\_by in primitive sorting [\#552](https://github.com/apache/arrow-rs/pull/552) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jimexist](https://github.com/Jimexist)) -- New MapArray support [\#491](https://github.com/apache/arrow-rs/pull/491) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([nevi-me](https://github.com/nevi-me)) - -**Implemented enhancements:** - -- Improve parquet binary writer speed by reducing allocations [\#819](https://github.com/apache/arrow-rs/issues/819) -- Expose buffer operations [\#808](https://github.com/apache/arrow-rs/issues/808) -- Add doc examples of writing parquet files using `ArrowWriter` [\#788](https://github.com/apache/arrow-rs/issues/788) - -**Fixed bugs:** - -- JSON reader can create null struct children on empty lists [\#825](https://github.com/apache/arrow-rs/issues/825) -- Incorrect null count for cast kernel for list arrays [\#815](https://github.com/apache/arrow-rs/issues/815) -- `minute` and `second` temporal kernels do not respect timezone [\#500](https://github.com/apache/arrow-rs/issues/500) -- Fix data corruption in json decoder f64-to-i64 cast [\#652](https://github.com/apache/arrow-rs/pull/652) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([xianwill](https://github.com/xianwill)) - -**Documentation updates:** - -- Doctest for PrimitiveArray using from\_iter\_values. [\#694](https://github.com/apache/arrow-rs/pull/694) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([novemberkilo](https://github.com/novemberkilo)) -- Doctests for BinaryArray and LargeBinaryArray. 
[\#625](https://github.com/apache/arrow-rs/pull/625) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([novemberkilo](https://github.com/novemberkilo)) -- Add links in docstrings [\#605](https://github.com/apache/arrow-rs/pull/605) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) - - -## [5.5.0](https://github.com/apache/arrow-rs/tree/5.5.0) (2021-09-24) - -[Full Changelog](https://github.com/apache/arrow-rs/compare/5.4.0...5.5.0) - -**Implemented enhancements:** - -- parquet should depend on a small set of arrow features [\#800](https://github.com/apache/arrow-rs/issues/800) -- Support equality on RecordBatch [\#735](https://github.com/apache/arrow-rs/issues/735) - -**Fixed bugs:** - -- Converting from string to timestamp uses microseconds instead of milliseconds [\#780](https://github.com/apache/arrow-rs/issues/780) -- Document has no link to `RowColumIter` [\#762](https://github.com/apache/arrow-rs/issues/762) -- length on slices with null doesn't work [\#744](https://github.com/apache/arrow-rs/issues/744) - -## [5.4.0](https://github.com/apache/arrow-rs/tree/5.4.0) (2021-09-10) - -[Full Changelog](https://github.com/apache/arrow-rs/compare/5.3.0...5.4.0) - -**Implemented enhancements:** - -- Upgrade lexical-core to 0.8 [\#747](https://github.com/apache/arrow-rs/issues/747) -- `append_nulls` and `append_trusted_len_iter` for PrimitiveBuilder [\#725](https://github.com/apache/arrow-rs/issues/725) -- Optimize MutableArrayData::extend for null buffers [\#397](https://github.com/apache/arrow-rs/issues/397) - -**Fixed bugs:** - -- Arithmetic with scalars doesn't work on slices [\#742](https://github.com/apache/arrow-rs/issues/742) -- Comparisons with scalar don't work on slices [\#740](https://github.com/apache/arrow-rs/issues/740) -- `unary` kernel doesn't respect offset [\#738](https://github.com/apache/arrow-rs/issues/738) -- `new_null_array` creates invalid struct arrays [\#734](https://github.com/apache/arrow-rs/issues/734) -- --no-default-features is broken for parquet [\#733](https://github.com/apache/arrow-rs/issues/733) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- `Bitmap::len` returns the number of bytes, not bits. [\#730](https://github.com/apache/arrow-rs/issues/730) -- Decimal logical type is formatted incorrectly by print\_schema [\#713](https://github.com/apache/arrow-rs/issues/713) -- parquet\_derive does not support chrono time values [\#711](https://github.com/apache/arrow-rs/issues/711) -- Numeric overflow when formatting Decimal type [\#710](https://github.com/apache/arrow-rs/issues/710) -- The integration tests are not running [\#690](https://github.com/apache/arrow-rs/issues/690) - -**Closed issues:** - -- Question: Is there no way to create a DictionaryArray with a pre-arranged mapping? [\#729](https://github.com/apache/arrow-rs/issues/729) - -## [5.3.0](https://github.com/apache/arrow-rs/tree/5.3.0) (2021-08-26) - -[Full Changelog](https://github.com/apache/arrow-rs/compare/5.2.0...5.3.0) - -**Implemented enhancements:** - -- Add optimized filter kernel for regular expression matching [\#697](https://github.com/apache/arrow-rs/issues/697) -- Can't cast from timestamp array to string array [\#587](https://github.com/apache/arrow-rs/issues/587) - -**Fixed bugs:** - -- 'Encoding DELTA\_BYTE\_ARRAY is not supported' with parquet arrow readers [\#708](https://github.com/apache/arrow-rs/issues/708) -- Support reading json string into binary data type. 
[\#701](https://github.com/apache/arrow-rs/issues/701) - -**Closed issues:** - -- Resolve Issues with `prettytable-rs` dependency [\#69](https://github.com/apache/arrow-rs/issues/69) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] - -## [5.2.0](https://github.com/apache/arrow-rs/tree/5.2.0) (2021-08-12) - -[Full Changelog](https://github.com/apache/arrow-rs/compare/5.1.0...5.2.0) - -**Implemented enhancements:** - -- Make rand an optional dependency [\#671](https://github.com/apache/arrow-rs/issues/671) -- Remove undefined behavior in `value` method of boolean and primitive arrays [\#645](https://github.com/apache/arrow-rs/issues/645) -- Avoid materialization of indices in filter\_record\_batch for single arrays [\#636](https://github.com/apache/arrow-rs/issues/636) -- Add a note about arrow crate security / safety [\#627](https://github.com/apache/arrow-rs/issues/627) -- Allow the creation of String arrays from an interator of &Option\<&str\> [\#598](https://github.com/apache/arrow-rs/issues/598) -- Support arrow map datatype [\#395](https://github.com/apache/arrow-rs/issues/395) - -**Fixed bugs:** - -- Parquet fixed length byte array columns write byte array statistics [\#660](https://github.com/apache/arrow-rs/issues/660) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Parquet boolean columns write Int32 statistics [\#659](https://github.com/apache/arrow-rs/issues/659) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Writing Parquet with a boolean column fails [\#657](https://github.com/apache/arrow-rs/issues/657) -- JSON decoder data corruption for large i64/u64 [\#653](https://github.com/apache/arrow-rs/issues/653) -- Incorrect min/max statistics for strings in parquet files [\#641](https://github.com/apache/arrow-rs/issues/641) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] - -**Closed issues:** - -- Release candidate verifying script seems work on macOS [\#640](https://github.com/apache/arrow-rs/issues/640) -- Update CONTRIBUTING [\#342](https://github.com/apache/arrow-rs/issues/342) - -## [5.1.0](https://github.com/apache/arrow-rs/tree/5.1.0) (2021-07-29) - -[Full Changelog](https://github.com/apache/arrow-rs/compare/5.0.0...5.1.0) +- Seal `ArrowNativeType` and `OffsetSizeTrait` for safety \(\#1028\) [\#1819](https://github.com/apache/arrow-rs/pull/1819) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Improve API for `csv::infer_file_schema` by removing redundant ref [\#1776](https://github.com/apache/arrow-rs/pull/1776) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) **Implemented enhancements:** -- Make FFI\_ArrowArray empty\(\) public [\#602](https://github.com/apache/arrow-rs/issues/602) -- exponential sort can be used to speed up lexico partition kernel [\#586](https://github.com/apache/arrow-rs/issues/586) -- Implement sort\(\) for binary array [\#568](https://github.com/apache/arrow-rs/issues/568) -- primitive sorting can be improved and more consistent with and without `limit` if sorted unstably [\#553](https://github.com/apache/arrow-rs/issues/553) +- List equality method should work on empty offset `ListArray` [\#1817](https://github.com/apache/arrow-rs/issues/1817) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Command line tool for convert CSV to Parquet [\#1797](https://github.com/apache/arrow-rs/issues/1797) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- IPC writer 
should write validity buffer for `UnionArray` in V4 IPC message [\#1793](https://github.com/apache/arrow-rs/issues/1793) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add function for row alignment with page mask [\#1790](https://github.com/apache/arrow-rs/issues/1790) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Rust IPC Read should be able to read V4 UnionType Array [\#1788](https://github.com/apache/arrow-rs/issues/1788) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `combine_option_bitmap` should accept arbitrary number of input arrays. [\#1780](https://github.com/apache/arrow-rs/issues/1780) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add `substring_by_char` kernels for slicing on character boundaries [\#1768](https://github.com/apache/arrow-rs/issues/1768) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support reading `PageIndex` from column metadata [\#1761](https://github.com/apache/arrow-rs/issues/1761) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support casting from `DataType::Utf8` to `DataType::Boolean` [\#1740](https://github.com/apache/arrow-rs/issues/1740) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Make current position available in `FileWriter`. [\#1691](https://github.com/apache/arrow-rs/issues/1691) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support writing parquet to `stdout` [\#1687](https://github.com/apache/arrow-rs/issues/1687) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] **Fixed bugs:** -- Confusing memory usage with CSV reader [\#623](https://github.com/apache/arrow-rs/issues/623) -- FFI implementation deviates from specification for array release [\#595](https://github.com/apache/arrow-rs/issues/595) -- Parquet file content is different if `~/.cargo` is in a git checkout [\#589](https://github.com/apache/arrow-rs/issues/589) -- Ensure output of MIRI is checked for success [\#581](https://github.com/apache/arrow-rs/issues/581) -- MIRI failure in `array::ffi::tests::test_struct` and other ffi tests [\#580](https://github.com/apache/arrow-rs/issues/580) -- ListArray equality check may return wrong result [\#570](https://github.com/apache/arrow-rs/issues/570) -- cargo audit failed [\#561](https://github.com/apache/arrow-rs/issues/561) -- ArrayData::slice\(\) does not work for nested types such as StructArray [\#554](https://github.com/apache/arrow-rs/issues/554) +- Incorrect Offset Validation for Sliced List Array Children [\#1814](https://github.com/apache/arrow-rs/issues/1814) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Parquet Snappy Codec overwrites Existing Data in Decompression Buffer [\#1806](https://github.com/apache/arrow-rs/issues/1806) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `flight_data_to_arrow_batch` does not support `RecordBatch`es with no columns [\#1783](https://github.com/apache/arrow-rs/issues/1783) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- parquet does not compile with `features=["zstd"]` [\#1630](https://github.com/apache/arrow-rs/issues/1630) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] **Documentation updates:** -- More examples of how to construct Arrays [\#301](https://github.com/apache/arrow-rs/issues/301) +- Update arrow module docs [\#1840](https://github.com/apache/arrow-rs/pull/1840) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([tustvold](https://github.com/tustvold)) +- Update safety disclaimer [\#1837](https://github.com/apache/arrow-rs/pull/1837) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Update ballista readme link [\#1765](https://github.com/apache/arrow-rs/pull/1765) ([tustvold](https://github.com/tustvold)) +- Move changelog archive to `CHANGELOG-old.md` [\#1759](https://github.com/apache/arrow-rs/pull/1759) ([alamb](https://github.com/alamb)) **Closed issues:** -- Implement StringBuilder::append\_option [\#263](https://github.com/apache/arrow-rs/issues/263) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] - -## [5.0.0](https://github.com/apache/arrow-rs/tree/5.0.0) (2021-07-14) - -[Full Changelog](https://github.com/apache/arrow-rs/compare/4.4.0...5.0.0) - -**Breaking changes:** - -- Remove lifetime from DynComparator [\#543](https://github.com/apache/arrow-rs/issues/543) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Simplify interactions with arrow flight APIs [\#376](https://github.com/apache/arrow-rs/issues/376) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] -- refactor: remove lifetime from DynComparator [\#542](https://github.com/apache/arrow-rs/pull/542) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([e-dard](https://github.com/e-dard)) -- use iterator for partition kernel instead of generating vec [\#438](https://github.com/apache/arrow-rs/pull/438) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jimexist](https://github.com/Jimexist)) -- Remove DictionaryArray::keys\_array method [\#419](https://github.com/apache/arrow-rs/pull/419) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) -- simplify interactions with arrow flight APIs [\#377](https://github.com/apache/arrow-rs/pull/377) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([garyanaplan](https://github.com/garyanaplan)) -- return reference from DictionaryArray::values\(\) \(\#313\) [\#314](https://github.com/apache/arrow-rs/pull/314) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) - -**Implemented enhancements:** - -- Allow creation of StringArrays from Vec\ [\#519](https://github.com/apache/arrow-rs/issues/519) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Implement RecordBatch::concat [\#461](https://github.com/apache/arrow-rs/issues/461) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Implement RecordBatch::slice\(\) to slice RecordBatches [\#460](https://github.com/apache/arrow-rs/issues/460) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Add a RecordBatch::split to split large batches into a set of smaller batches [\#343](https://github.com/apache/arrow-rs/issues/343) -- generate parquet schema from rust struct [\#539](https://github.com/apache/arrow-rs/pull/539) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([nevi-me](https://github.com/nevi-me)) -- Implement `RecordBatch::concat` [\#537](https://github.com/apache/arrow-rs/pull/537) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([silathdiir](https://github.com/silathdiir)) -- Implement function slice for RecordBatch [\#490](https://github.com/apache/arrow-rs/pull/490) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([b41sh](https://github.com/b41sh)) -- add lexicographically partition points and ranges 
[\#424](https://github.com/apache/arrow-rs/pull/424) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jimexist](https://github.com/Jimexist)) -- allow to read non-standard CSV [\#326](https://github.com/apache/arrow-rs/pull/326) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kazuk](https://github.com/kazuk)) -- parquet: Speed up `BitReader`/`DeltaBitPackDecoder` [\#325](https://github.com/apache/arrow-rs/pull/325) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([kornholi](https://github.com/kornholi)) -- ARROW-12343: \[Rust\] Support auto-vectorization for min/max [\#9](https://github.com/apache/arrow-rs/pull/9) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) -- ARROW-12411: \[Rust\] Create RecordBatches from Iterators [\#7](https://github.com/apache/arrow-rs/pull/7) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) - -**Fixed bugs:** - -- Error building on master - error: cyclic package dependency: package `ahash v0.7.4` depends on itself. Cycle [\#544](https://github.com/apache/arrow-rs/issues/544) -- IPC reader panics with out of bounds error [\#541](https://github.com/apache/arrow-rs/issues/541) -- Take kernel doesn't handle nulls and structs correctly [\#530](https://github.com/apache/arrow-rs/issues/530) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- master fails to compile with `default-features=false` [\#529](https://github.com/apache/arrow-rs/issues/529) -- README developer instructions out of date [\#523](https://github.com/apache/arrow-rs/issues/523) -- Update rustc and packed\_simd in CI before 5.0 release [\#517](https://github.com/apache/arrow-rs/issues/517) -- Incorrect memory usage calculation for dictionary arrays [\#503](https://github.com/apache/arrow-rs/issues/503) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- sliced null buffers lead to incorrect result in take kernel \(and probably on other places\) [\#502](https://github.com/apache/arrow-rs/issues/502) -- Cast of utf8 types and list container types don't respect offset [\#334](https://github.com/apache/arrow-rs/issues/334) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- fix take kernel null handling on structs [\#531](https://github.com/apache/arrow-rs/pull/531) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([bjchambers](https://github.com/bjchambers)) -- Correct array memory usage calculation for dictionary arrays [\#505](https://github.com/apache/arrow-rs/pull/505) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) -- parquet: improve BOOLEAN writing logic and report error on encoding fail [\#443](https://github.com/apache/arrow-rs/pull/443) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([garyanaplan](https://github.com/garyanaplan)) -- Fix bug with null buffer offset in boolean not kernel [\#418](https://github.com/apache/arrow-rs/pull/418) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) -- respect offset in utf8 and list casts [\#335](https://github.com/apache/arrow-rs/pull/335) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ritchie46](https://github.com/ritchie46)) -- Fix comparison of dictionaries with different values arrays \(\#332\) [\#333](https://github.com/apache/arrow-rs/pull/333) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] 
([tustvold](https://github.com/tustvold)) -- ensure null-counts are written for all-null columns [\#307](https://github.com/apache/arrow-rs/pull/307) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([crepererum](https://github.com/crepererum)) -- fix invalid null handling in filter [\#296](https://github.com/apache/arrow-rs/pull/296) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ritchie46](https://github.com/ritchie46)) -- fix NaN handling in parquet statistics [\#256](https://github.com/apache/arrow-rs/pull/256) ([crepererum](https://github.com/crepererum)) - -**Documentation updates:** - -- Improve arrow's crate's readme on crates.io [\#463](https://github.com/apache/arrow-rs/issues/463) -- Clean up README.md in advance of the 5.0 release [\#536](https://github.com/apache/arrow-rs/pull/536) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) -- fix readme instructions to reflect new structure [\#524](https://github.com/apache/arrow-rs/pull/524) ([marcvanheerden](https://github.com/marcvanheerden)) -- Improve docs for NullArray, new\_null\_array and new\_empty\_array [\#240](https://github.com/apache/arrow-rs/pull/240) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- `DataType::Decimal` Non-Compliant? [\#1779](https://github.com/apache/arrow-rs/issues/1779) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Further simplify the offset validation [\#1770](https://github.com/apache/arrow-rs/issues/1770) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Best way to convert arrow to Rust native type [\#1760](https://github.com/apache/arrow-rs/issues/1760) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Why `Parquet` is a part of `Arrow`? 
[\#1715](https://github.com/apache/arrow-rs/issues/1715) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] **Merged pull requests:** -- Fix default arrow build [\#533](https://github.com/apache/arrow-rs/pull/533) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) -- Add tests for building applications using arrow with different feature flags [\#532](https://github.com/apache/arrow-rs/pull/532) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) -- Remove unused futures dependency from arrow-flight [\#528](https://github.com/apache/arrow-rs/pull/528) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) -- CI: update rust nightly and packed\_simd [\#525](https://github.com/apache/arrow-rs/pull/525) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ritchie46](https://github.com/ritchie46)) -- Support `StringArray` creation from String Vec [\#522](https://github.com/apache/arrow-rs/pull/522) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([silathdiir](https://github.com/silathdiir)) -- Fix parquet benchmark schema [\#513](https://github.com/apache/arrow-rs/pull/513) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([nevi-me](https://github.com/nevi-me)) -- Fix parquet definition levels [\#511](https://github.com/apache/arrow-rs/pull/511) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([nevi-me](https://github.com/nevi-me)) -- Fix for primitive and boolean take kernel for nullable indices with an offset [\#509](https://github.com/apache/arrow-rs/pull/509) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) -- Bump flatbuffers [\#499](https://github.com/apache/arrow-rs/pull/499) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([PsiACE](https://github.com/PsiACE)) -- implement second/minute helpers for temporal [\#493](https://github.com/apache/arrow-rs/pull/493) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ovr](https://github.com/ovr)) -- special case concatenating single element array shortcut [\#492](https://github.com/apache/arrow-rs/pull/492) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jimexist](https://github.com/Jimexist)) -- update docs to reflect recent changes \(joins and window functions\) [\#489](https://github.com/apache/arrow-rs/pull/489) ([Jimexist](https://github.com/Jimexist)) -- Update rand, proc-macro and zstd dependencies [\#488](https://github.com/apache/arrow-rs/pull/488) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) -- Doctest for GenericListArray. 
[\#474](https://github.com/apache/arrow-rs/pull/474) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([novemberkilo](https://github.com/novemberkilo)) -- remove stale comment on `ArrayData` equality and update unit tests [\#472](https://github.com/apache/arrow-rs/pull/472) ([Jimexist](https://github.com/Jimexist)) -- remove unused patch file [\#471](https://github.com/apache/arrow-rs/pull/471) ([Jimexist](https://github.com/Jimexist)) -- fix clippy warnings for rust 1.53 [\#470](https://github.com/apache/arrow-rs/pull/470) ([Jimexist](https://github.com/Jimexist)) -- Fix PR labeler [\#468](https://github.com/apache/arrow-rs/pull/468) ([Dandandan](https://github.com/Dandandan)) -- Tweak dev backporting docs [\#466](https://github.com/apache/arrow-rs/pull/466) ([alamb](https://github.com/alamb)) -- Unvendor Archery [\#459](https://github.com/apache/arrow-rs/pull/459) ([kszucs](https://github.com/kszucs)) -- Add sort boolean benchmark [\#457](https://github.com/apache/arrow-rs/pull/457) ([alamb](https://github.com/alamb)) -- Add C data interface for decimal128 and timestamp [\#453](https://github.com/apache/arrow-rs/pull/453) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alippai](https://github.com/alippai)) -- Implement the Iterator trait for the json Reader. [\#451](https://github.com/apache/arrow-rs/pull/451) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([LaurentMazare](https://github.com/LaurentMazare)) -- Update release docs + release email template [\#450](https://github.com/apache/arrow-rs/pull/450) ([alamb](https://github.com/alamb)) -- remove clippy unnecessary wraps suppresions in cast kernel [\#449](https://github.com/apache/arrow-rs/pull/449) ([Jimexist](https://github.com/Jimexist)) -- Use partition for bool sort [\#448](https://github.com/apache/arrow-rs/pull/448) ([Jimexist](https://github.com/Jimexist)) -- remove unnecessary wraps in sort [\#445](https://github.com/apache/arrow-rs/pull/445) ([Jimexist](https://github.com/Jimexist)) -- Python FFI bridge for Schema, Field and DataType [\#439](https://github.com/apache/arrow-rs/pull/439) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kszucs](https://github.com/kszucs)) -- Update release Readme.md [\#436](https://github.com/apache/arrow-rs/pull/436) ([alamb](https://github.com/alamb)) -- Derive Eq and PartialEq for SortOptions [\#425](https://github.com/apache/arrow-rs/pull/425) ([tustvold](https://github.com/tustvold)) -- refactor lexico sort for future code reuse [\#423](https://github.com/apache/arrow-rs/pull/423) ([Jimexist](https://github.com/Jimexist)) -- Reenable MIRI check on PRs [\#421](https://github.com/apache/arrow-rs/pull/421) ([alamb](https://github.com/alamb)) -- Sort by float lists [\#420](https://github.com/apache/arrow-rs/pull/420) ([medwards](https://github.com/medwards)) -- Fix out of bounds read in bit chunk iterator [\#416](https://github.com/apache/arrow-rs/pull/416) ([jhorstmann](https://github.com/jhorstmann)) -- Doctests for DecimalArray. 
[\#414](https://github.com/apache/arrow-rs/pull/414) ([novemberkilo](https://github.com/novemberkilo)) -- Add Decimal to CsvWriter and improve debug display [\#406](https://github.com/apache/arrow-rs/pull/406) ([alippai](https://github.com/alippai)) -- MINOR: update install instruction [\#400](https://github.com/apache/arrow-rs/pull/400) ([alippai](https://github.com/alippai)) -- use prettier to auto format md files [\#398](https://github.com/apache/arrow-rs/pull/398) ([Jimexist](https://github.com/Jimexist)) -- window::shift to work for all array types [\#388](https://github.com/apache/arrow-rs/pull/388) ([Jimexist](https://github.com/Jimexist)) -- add more tests for window::shift and handle boundary cases [\#386](https://github.com/apache/arrow-rs/pull/386) ([Jimexist](https://github.com/Jimexist)) -- Implement faster arrow array reader [\#384](https://github.com/apache/arrow-rs/pull/384) ([yordan-pavlov](https://github.com/yordan-pavlov)) -- Add set\_bit to BooleanBufferBuilder to allow mutating bit in index [\#383](https://github.com/apache/arrow-rs/pull/383) ([boazberman](https://github.com/boazberman)) -- make sure that only concat preallocates buffers [\#382](https://github.com/apache/arrow-rs/pull/382) ([ritchie46](https://github.com/ritchie46)) -- Respect max rowgroup size in Arrow writer [\#381](https://github.com/apache/arrow-rs/pull/381) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([nevi-me](https://github.com/nevi-me)) -- Fix typo in release script, update release location [\#380](https://github.com/apache/arrow-rs/pull/380) ([alamb](https://github.com/alamb)) -- Doctests for FixedSizeBinaryArray [\#378](https://github.com/apache/arrow-rs/pull/378) ([novemberkilo](https://github.com/novemberkilo)) -- Simplify shift kernel using new\_null\_array [\#370](https://github.com/apache/arrow-rs/pull/370) ([Dandandan](https://github.com/Dandandan)) -- allow `SliceableCursor` to be constructed from an `Arc` directly [\#369](https://github.com/apache/arrow-rs/pull/369) ([crepererum](https://github.com/crepererum)) -- Add doctest for ArrayBuilder [\#367](https://github.com/apache/arrow-rs/pull/367) ([alippai](https://github.com/alippai)) -- Fix version in readme [\#365](https://github.com/apache/arrow-rs/pull/365) ([domoritz](https://github.com/domoritz)) -- Remove superfluous space [\#363](https://github.com/apache/arrow-rs/pull/363) ([domoritz](https://github.com/domoritz)) -- Add crate badges [\#362](https://github.com/apache/arrow-rs/pull/362) ([domoritz](https://github.com/domoritz)) -- Disable MIRI check until it runs cleanly on CI [\#360](https://github.com/apache/arrow-rs/pull/360) ([alamb](https://github.com/alamb)) -- Only register Flight.proto with cargo if it exists [\#351](https://github.com/apache/arrow-rs/pull/351) ([tustvold](https://github.com/tustvold)) -- Reduce memory usage of concat \(large\)utf8 [\#348](https://github.com/apache/arrow-rs/pull/348) ([ritchie46](https://github.com/ritchie46)) -- Fix filter UB and add fast path [\#341](https://github.com/apache/arrow-rs/pull/341) ([ritchie46](https://github.com/ritchie46)) -- Automatic cherry-pick script [\#339](https://github.com/apache/arrow-rs/pull/339) ([alamb](https://github.com/alamb)) -- Doctests for BooleanArray. 
[\#338](https://github.com/apache/arrow-rs/pull/338) ([novemberkilo](https://github.com/novemberkilo)) -- feature gate ipc reader/writer [\#336](https://github.com/apache/arrow-rs/pull/336) ([ritchie46](https://github.com/ritchie46)) -- Add ported Rust release verification script [\#331](https://github.com/apache/arrow-rs/pull/331) ([wesm](https://github.com/wesm)) -- Doctests for StringArray and LargeStringArray. [\#330](https://github.com/apache/arrow-rs/pull/330) ([novemberkilo](https://github.com/novemberkilo)) -- inline PrimitiveArray::value [\#329](https://github.com/apache/arrow-rs/pull/329) ([ritchie46](https://github.com/ritchie46)) -- Enable wasm32 as a target architecture for the SIMD feature [\#324](https://github.com/apache/arrow-rs/pull/324) ([roee88](https://github.com/roee88)) -- Fix undefined behavior in FFI and enable MIRI checks on CI [\#323](https://github.com/apache/arrow-rs/pull/323) ([roee88](https://github.com/roee88)) -- Mutablebuffer::shrink\_to\_fit [\#318](https://github.com/apache/arrow-rs/pull/318) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ritchie46](https://github.com/ritchie46)) -- Add \(simd\) modulus op [\#317](https://github.com/apache/arrow-rs/pull/317) ([gangliao](https://github.com/gangliao)) -- feature gate csv functionality [\#312](https://github.com/apache/arrow-rs/pull/312) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ritchie46](https://github.com/ritchie46)) -- \[Minor\] Version upgrades [\#304](https://github.com/apache/arrow-rs/pull/304) ([Dandandan](https://github.com/Dandandan)) -- Remove old release scripts [\#293](https://github.com/apache/arrow-rs/pull/293) ([alamb](https://github.com/alamb)) -- Add Send to the ArrayBuilder trait [\#291](https://github.com/apache/arrow-rs/pull/291) ([Max-Meldrum](https://github.com/Max-Meldrum)) -- Added changelog generator script and configuration. 
[\#289](https://github.com/apache/arrow-rs/pull/289) ([jorgecarleitao](https://github.com/jorgecarleitao)) -- manually bump development version [\#288](https://github.com/apache/arrow-rs/pull/288) ([nevi-me](https://github.com/nevi-me)) -- Fix FFI and add support for Struct type [\#287](https://github.com/apache/arrow-rs/pull/287) ([roee88](https://github.com/roee88)) -- Fix subtraction underflow when sorting string arrays with many nulls [\#285](https://github.com/apache/arrow-rs/pull/285) ([medwards](https://github.com/medwards)) -- Speed up bound checking in `take` [\#281](https://github.com/apache/arrow-rs/pull/281) ([Dandandan](https://github.com/Dandandan)) -- Update PR template by commenting out instructions [\#278](https://github.com/apache/arrow-rs/pull/278) ([nevi-me](https://github.com/nevi-me)) -- Added Decimal support to pretty-print display utility \(\#230\) [\#273](https://github.com/apache/arrow-rs/pull/273) ([mgill25](https://github.com/mgill25)) -- Fix null struct and list roundtrip [\#270](https://github.com/apache/arrow-rs/pull/270) ([nevi-me](https://github.com/nevi-me)) -- 1.52 clippy fixes [\#267](https://github.com/apache/arrow-rs/pull/267) ([nevi-me](https://github.com/nevi-me)) -- Fix typo in csv/reader.rs [\#265](https://github.com/apache/arrow-rs/pull/265) ([domoritz](https://github.com/domoritz)) -- Fix empty Schema::metadata deserialization error [\#260](https://github.com/apache/arrow-rs/pull/260) ([hulunbier](https://github.com/hulunbier)) -- update datafusion and ballista doc links [\#259](https://github.com/apache/arrow-rs/pull/259) ([Jimexist](https://github.com/Jimexist)) -- support full u32 and u64 roundtrip through parquet [\#258](https://github.com/apache/arrow-rs/pull/258) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([crepererum](https://github.com/crepererum)) -- \[MINOR\] Added env to run rust in integration. [\#253](https://github.com/apache/arrow-rs/pull/253) ([jorgecarleitao](https://github.com/jorgecarleitao)) -- \[Minor\] Made integration tests always run. [\#248](https://github.com/apache/arrow-rs/pull/248) ([jorgecarleitao](https://github.com/jorgecarleitao)) -- fix parquet max\_definition for non-null structs [\#246](https://github.com/apache/arrow-rs/pull/246) ([nevi-me](https://github.com/nevi-me)) -- Disabled rebase needed until demonstrate working. 
[\#243](https://github.com/apache/arrow-rs/pull/243) ([jorgecarleitao](https://github.com/jorgecarleitao)) -- pin flatbuffers to 0.8.4 [\#239](https://github.com/apache/arrow-rs/pull/239) ([ritchie46](https://github.com/ritchie46)) -- sort\_primitive result is capped to the min of limit or values.len [\#236](https://github.com/apache/arrow-rs/pull/236) ([medwards](https://github.com/medwards)) -- Read list field correctly [\#234](https://github.com/apache/arrow-rs/pull/234) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([nevi-me](https://github.com/nevi-me)) -- Fix code examples for RecordBatch::try\_from\_iter [\#231](https://github.com/apache/arrow-rs/pull/231) ([alamb](https://github.com/alamb)) -- Support string dictionaries in csv reader \(\#228\) [\#229](https://github.com/apache/arrow-rs/pull/229) ([tustvold](https://github.com/tustvold)) -- support LargeUtf8 in sort kernel [\#26](https://github.com/apache/arrow-rs/pull/26) ([ritchie46](https://github.com/ritchie46)) -- Removed unused files [\#22](https://github.com/apache/arrow-rs/pull/22) ([jorgecarleitao](https://github.com/jorgecarleitao)) -- ARROW-12504: Buffer::from\_slice\_ref set correct capacity [\#18](https://github.com/apache/arrow-rs/pull/18) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Add GitHub templates [\#17](https://github.com/apache/arrow-rs/pull/17) ([andygrove](https://github.com/andygrove)) -- ARROW-12493: Add support for writing dictionary arrays to CSV and JSON [\#16](https://github.com/apache/arrow-rs/pull/16) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- ARROW-12426: \[Rust\] Fix concatentation of arrow dictionaries [\#15](https://github.com/apache/arrow-rs/pull/15) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Update repository and homepage urls [\#14](https://github.com/apache/arrow-rs/pull/14) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Dandandan](https://github.com/Dandandan)) -- Added rebase-needed bot [\#13](https://github.com/apache/arrow-rs/pull/13) ([jorgecarleitao](https://github.com/jorgecarleitao)) -- Added Integration tests against arrow [\#10](https://github.com/apache/arrow-rs/pull/10) ([jorgecarleitao](https://github.com/jorgecarleitao)) - -## [4.4.0](https://github.com/apache/arrow-rs/tree/4.4.0) (2021-06-24) - -[Full Changelog](https://github.com/apache/arrow-rs/compare/4.3.0...4.4.0) - -**Breaking changes:** - -- migrate partition kernel to use Iterator trait [\#437](https://github.com/apache/arrow-rs/issues/437) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Remove DictionaryArray::keys\_array [\#391](https://github.com/apache/arrow-rs/issues/391) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] - -**Implemented enhancements:** - -- sort kernel boolean sort can be O\(n\) [\#447](https://github.com/apache/arrow-rs/issues/447) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- C data interface for decimal128, timestamp, date32 and date64 [\#413](https://github.com/apache/arrow-rs/issues/413) -- Add Decimal to CsvWriter [\#405](https://github.com/apache/arrow-rs/issues/405) -- Use iterators to increase performance of creating Arrow arrays [\#200](https://github.com/apache/arrow-rs/issues/200) 
[[parquet](https://github.com/apache/arrow-rs/labels/parquet)] - -**Fixed bugs:** - -- Release Audit Tool \(RAT\) is not being triggered [\#481](https://github.com/apache/arrow-rs/issues/481) -- Security Vulnerabilities: flatbuffers: `read_scalar` and `read_scalar_at` allow transmuting values without `unsafe` blocks [\#476](https://github.com/apache/arrow-rs/issues/476) -- Clippy broken after upgrade to rust 1.53 [\#467](https://github.com/apache/arrow-rs/issues/467) -- Pull Request Labeler is not working [\#462](https://github.com/apache/arrow-rs/issues/462) -- Arrow 4.3 release: error\[E0658\]: use of unstable library feature 'partition\_point': new API [\#456](https://github.com/apache/arrow-rs/issues/456) -- parquet reading hangs when row\_group contains more than 2048 rows of data [\#349](https://github.com/apache/arrow-rs/issues/349) -- Fail to build arrow [\#247](https://github.com/apache/arrow-rs/issues/247) -- JSON reader does not implement iterator [\#193](https://github.com/apache/arrow-rs/issues/193) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] - -**Security fixes:** - -- Ensure a successful MIRI Run on CI [\#227](https://github.com/apache/arrow-rs/issues/227) - -**Closed issues:** - -- sort kernel has a lot of unnecessary wrapping [\#446](https://github.com/apache/arrow-rs/issues/446) -- \[Parquet\] Plain encoded boolean column chunks limited to 2048 values [\#48](https://github.com/apache/arrow-rs/issues/48) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] - -## [4.3.0](https://github.com/apache/arrow-rs/tree/4.3.0) (2021-06-10) - -[Full Changelog](https://github.com/apache/arrow-rs/compare/4.2.0...4.3.0) - -**Implemented enhancements:** - -- Add partitioning kernel for sorted arrays [\#428](https://github.com/apache/arrow-rs/issues/428) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Implement sort by float lists [\#427](https://github.com/apache/arrow-rs/issues/427) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Derive Eq and PartialEq for SortOptions [\#426](https://github.com/apache/arrow-rs/issues/426) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- use prettier and github action to normalize markdown document syntax [\#399](https://github.com/apache/arrow-rs/issues/399) -- window::shift can work for more than just primitive array type [\#392](https://github.com/apache/arrow-rs/issues/392) -- Doctest for ArrayBuilder [\#366](https://github.com/apache/arrow-rs/issues/366) - -**Fixed bugs:** - -- Boolean `not` kernel does not take offset of null buffer into account [\#417](https://github.com/apache/arrow-rs/issues/417) -- my contribution not marged in 4.2 release [\#394](https://github.com/apache/arrow-rs/issues/394) -- window::shift shall properly handle boundary cases [\#387](https://github.com/apache/arrow-rs/issues/387) -- Parquet `WriterProperties.max_row_group_size` not wired up [\#257](https://github.com/apache/arrow-rs/issues/257) -- Out of bound reads in chunk iterator [\#198](https://github.com/apache/arrow-rs/issues/198) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] - -## [4.2.0](https://github.com/apache/arrow-rs/tree/4.2.0) (2021-05-29) - -[Full Changelog](https://github.com/apache/arrow-rs/compare/4.1.0...4.2.0) - -**Breaking changes:** - -- DictionaryArray::values\(\) clones the underlying ArrayRef [\#313](https://github.com/apache/arrow-rs/issues/313) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] - -**Implemented enhancements:** - -- Simplify shift kernel using 
null array [\#371](https://github.com/apache/arrow-rs/issues/371) -- Provide `Arc`-based constructor for `parquet::util::cursor::SliceableCursor` [\#368](https://github.com/apache/arrow-rs/issues/368) -- Add badges to crates [\#361](https://github.com/apache/arrow-rs/issues/361) -- Consider inlining PrimitiveArray::value [\#328](https://github.com/apache/arrow-rs/issues/328) -- Implement automated release verification script [\#327](https://github.com/apache/arrow-rs/issues/327) -- Add wasm32 to the list of target architectures of the simd feature [\#316](https://github.com/apache/arrow-rs/issues/316) -- add with\_escape for csv::ReaderBuilder [\#315](https://github.com/apache/arrow-rs/issues/315) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- IPC feature gate [\#310](https://github.com/apache/arrow-rs/issues/310) -- csv feature gate [\#309](https://github.com/apache/arrow-rs/issues/309) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Add `shrink_to` / `shrink_to_fit` to `MutableBuffer` [\#297](https://github.com/apache/arrow-rs/issues/297) - -**Fixed bugs:** - -- Incorrect crate setup instructions [\#364](https://github.com/apache/arrow-rs/issues/364) -- Arrow-flight only register rerun-if-changed if file exists [\#350](https://github.com/apache/arrow-rs/issues/350) -- Dictionary Comparison Uses Wrong Values Array [\#332](https://github.com/apache/arrow-rs/issues/332) -- Undefined behavior in FFI implementation [\#322](https://github.com/apache/arrow-rs/issues/322) -- All-null column get wrong parquet null-counts [\#306](https://github.com/apache/arrow-rs/issues/306) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Filter has inconsistent null handling [\#295](https://github.com/apache/arrow-rs/issues/295) - -## [4.1.0](https://github.com/apache/arrow-rs/tree/4.1.0) (2021-05-17) - -[Full Changelog](https://github.com/apache/arrow-rs/compare/4.0.0...4.1.0) - -**Implemented enhancements:** - -- Add Send to ArrayBuilder [\#290](https://github.com/apache/arrow-rs/issues/290) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Improve performance of bound checking option [\#280](https://github.com/apache/arrow-rs/issues/280) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- extend compute kernel arity to include nullary functions [\#276](https://github.com/apache/arrow-rs/issues/276) -- Implement FFI / CDataInterface for Struct Arrays [\#251](https://github.com/apache/arrow-rs/issues/251) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Add support for pretty-printing Decimal numbers [\#230](https://github.com/apache/arrow-rs/issues/230) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- CSV Reader String Dictionary Support [\#228](https://github.com/apache/arrow-rs/issues/228) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Add Builder interface for adding Arrays to record batches [\#210](https://github.com/apache/arrow-rs/issues/210) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Support auto-vectorization for min/max [\#209](https://github.com/apache/arrow-rs/issues/209) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Support LargeUtf8 in sort kernel [\#25](https://github.com/apache/arrow-rs/issues/25) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] - -**Fixed bugs:** - -- no method named `select_nth_unstable_by` found for mutable reference `&mut [T]` [\#283](https://github.com/apache/arrow-rs/issues/283) -- Rust 1.52 Clippy error 
[\#266](https://github.com/apache/arrow-rs/issues/266) -- NaNs can break parquet statistics [\#255](https://github.com/apache/arrow-rs/issues/255) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- u64::MAX does not roundtrip through parquet [\#254](https://github.com/apache/arrow-rs/issues/254) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Integration tests failing to compile \(flatbuffer\) [\#249](https://github.com/apache/arrow-rs/issues/249) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Fix compatibility quirks between arrow and parquet structs [\#245](https://github.com/apache/arrow-rs/issues/245) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Unable to write non-null Arrow structs to Parquet [\#244](https://github.com/apache/arrow-rs/issues/244) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- schema: missing field `metadata` when deserialize [\#241](https://github.com/apache/arrow-rs/issues/241) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Arrow does not compile due to flatbuffers upgrade [\#238](https://github.com/apache/arrow-rs/issues/238) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Sort with limit panics for the limit includes some but not all nulls, for large arrays [\#235](https://github.com/apache/arrow-rs/issues/235) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- arrow-rs contains a copy of the "format" directory [\#233](https://github.com/apache/arrow-rs/issues/233) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Fix SEGFAULT/ SIGILL in child-data ffi [\#206](https://github.com/apache/arrow-rs/issues/206) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Read list field correctly in \\> [\#167](https://github.com/apache/arrow-rs/issues/167) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- FFI listarray lead to undefined behavior. 
[\#20](https://github.com/apache/arrow-rs/issues/20) - -**Security fixes:** - -- Fix MIRI build on CI [\#226](https://github.com/apache/arrow-rs/issues/226) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Get MIRI running again [\#224](https://github.com/apache/arrow-rs/issues/224) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] - -**Documentation updates:** - -- Comment out the instructions in the PR template [\#277](https://github.com/apache/arrow-rs/issues/277) -- Update links to datafusion and ballista in README.md [\#19](https://github.com/apache/arrow-rs/issues/19) -- Update "repository" in Cargo.toml [\#12](https://github.com/apache/arrow-rs/issues/12) - -**Closed issues:** - -- Arrow Aligned Vec [\#268](https://github.com/apache/arrow-rs/issues/268) -- \[Rust\]: Tracking issue for AVX-512 [\#220](https://github.com/apache/arrow-rs/issues/220) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Umbrella issue for clippy integration [\#217](https://github.com/apache/arrow-rs/issues/217) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Support sort [\#215](https://github.com/apache/arrow-rs/issues/215) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Support stable Rust [\#214](https://github.com/apache/arrow-rs/issues/214) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Remove Rust and point integration tests to arrow-rs repo [\#211](https://github.com/apache/arrow-rs/issues/211) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- ArrayData buffers are inconsistent accross implementations [\#207](https://github.com/apache/arrow-rs/issues/207) -- 3.0.1 patch release [\#204](https://github.com/apache/arrow-rs/issues/204) -- Document patch release process [\#202](https://github.com/apache/arrow-rs/issues/202) -- Simplify Offset [\#186](https://github.com/apache/arrow-rs/issues/186) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Typed Bytes [\#185](https://github.com/apache/arrow-rs/issues/185) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- \[CI\]docker-compose setup should enable caching [\#175](https://github.com/apache/arrow-rs/issues/175) -- Improve take primitive performance [\#174](https://github.com/apache/arrow-rs/issues/174) -- \[CI\] Try out buildkite [\#165](https://github.com/apache/arrow-rs/issues/165) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Update assignees in JIRA where missing [\#160](https://github.com/apache/arrow-rs/issues/160) -- \[Rust\]: From\ implementations should validate data type [\#103](https://github.com/apache/arrow-rs/issues/103) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- \[DataFusion\] Verify that projection push down does not remove aliases columns [\#99](https://github.com/apache/arrow-rs/issues/99) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- \[Rust\]\[DataFusion\] Implement modulus expression [\#98](https://github.com/apache/arrow-rs/issues/98) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- \[DataFusion\] Add constant folding to expressions during logically planning [\#96](https://github.com/apache/arrow-rs/issues/96) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- \[DataFusion\] DataFrame.collect should return RecordBatchReader [\#95](https://github.com/apache/arrow-rs/issues/95) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- \[Rust\]\[DataFusion\] Add FORMAT to explain plan and an easy to visualize format 
[\#94](https://github.com/apache/arrow-rs/issues/94) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- \[DataFusion\] Implement metrics framework [\#90](https://github.com/apache/arrow-rs/issues/90) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- \[DataFusion\] Implement micro benchmarks for each operator [\#89](https://github.com/apache/arrow-rs/issues/89) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- \[DataFusion\] Implement pretty print for physical query plan [\#88](https://github.com/apache/arrow-rs/issues/88) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- \[Archery\] Support rust clippy in the lint command [\#83](https://github.com/apache/arrow-rs/issues/83) -- \[rust\]\[datafusion\] optimize count\(\*\) queries on parquet sources [\#75](https://github.com/apache/arrow-rs/issues/75) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- \[Rust\]\[DataFusion\] Improve like/nlike performance [\#71](https://github.com/apache/arrow-rs/issues/71) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- \[DataFusion\] Implement optimizer rule to remove redundant projections [\#56](https://github.com/apache/arrow-rs/issues/56) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- \[DataFusion\] Parquet data source does not support complex types [\#39](https://github.com/apache/arrow-rs/issues/39) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Merge utils from Parquet and Arrow [\#32](https://github.com/apache/arrow-rs/issues/32) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Add benchmarks for Parquet [\#30](https://github.com/apache/arrow-rs/issues/30) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Mark methods that do not perform bounds checking as unsafe [\#28](https://github.com/apache/arrow-rs/issues/28) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Test issue [\#24](https://github.com/apache/arrow-rs/issues/24) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- This is a test issue [\#11](https://github.com/apache/arrow-rs/issues/11) +- Make equals\_datatype method public, enabling other modules [\#1838](https://github.com/apache/arrow-rs/pull/1838) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([nl5887](https://github.com/nl5887)) +- \[Minor\] Clarify `PageIterator` Documentation [\#1831](https://github.com/apache/arrow-rs/pull/1831) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- Update MIRI pin [\#1828](https://github.com/apache/arrow-rs/pull/1828) ([tustvold](https://github.com/tustvold)) +- Change to use `resolver v2`, test more feature flag combinations in CI, fix errors \(\#1630\) [\#1822](https://github.com/apache/arrow-rs/pull/1822) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add ScalarBuffer abstraction \(\#1811\) [\#1820](https://github.com/apache/arrow-rs/pull/1820) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix list equal for empty offset list array [\#1818](https://github.com/apache/arrow-rs/pull/1818) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Fix Decimal and List ArrayData Validation \(\#1813\) \(\#1814\) 
[\#1816](https://github.com/apache/arrow-rs/pull/1816) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Don't overwrite existing data on snappy decompress \(\#1806\) [\#1807](https://github.com/apache/arrow-rs/pull/1807) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Rename `arrow/benches/string_kernels.rs` to `arrow/benches/substring_kernels.rs` [\#1805](https://github.com/apache/arrow-rs/pull/1805) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Add public API for decoding parquet footer [\#1804](https://github.com/apache/arrow-rs/pull/1804) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Add AsyncFileReader trait [\#1803](https://github.com/apache/arrow-rs/pull/1803) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- add parquet-fromcsv \(\#1\) [\#1798](https://github.com/apache/arrow-rs/pull/1798) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([kazuk](https://github.com/kazuk)) +- Use IPC row count info in IPC reader [\#1796](https://github.com/apache/arrow-rs/pull/1796) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Fix typos in the Memory and Buffers section of the docs home [\#1795](https://github.com/apache/arrow-rs/pull/1795) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([datapythonista](https://github.com/datapythonista)) +- Write validity buffer for UnionArray in V4 IPC message [\#1794](https://github.com/apache/arrow-rs/pull/1794) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- feat:Add function for row alignment with page mask [\#1791](https://github.com/apache/arrow-rs/pull/1791) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- Read and skip validity buffer of UnionType Array for V4 ipc message [\#1789](https://github.com/apache/arrow-rs/pull/1789) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([viirya](https://github.com/viirya)) +- Add `Substring_by_char` [\#1784](https://github.com/apache/arrow-rs/pull/1784) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Add `ParquetFileArrowReader::try_new` [\#1782](https://github.com/apache/arrow-rs/pull/1782) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Arbitrary size combine option bitmap [\#1781](https://github.com/apache/arrow-rs/pull/1781) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Ismail-Maj](https://github.com/Ismail-Maj)) +- Implement `ChunkReader` for `Bytes`, deprecate `SliceableCursor` [\#1775](https://github.com/apache/arrow-rs/pull/1775) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Access metadata of flushed row groups on write \(\#1691\) [\#1774](https://github.com/apache/arrow-rs/pull/1774) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Simplify ParquetFileArrowReader Metadata API [\#1773](https://github.com/apache/arrow-rs/pull/1773) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] 
([tustvold](https://github.com/tustvold)) +- MINOR: Unpin nightly version as packed\_simd releases new version [\#1771](https://github.com/apache/arrow-rs/pull/1771) ([viirya](https://github.com/viirya)) +- Update comfy-table requirement from 5.0 to 6.0 [\#1769](https://github.com/apache/arrow-rs/pull/1769) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Optionally disable `validate_decimal_precision` check in `DecimalBuilder.append_value` for interop test [\#1767](https://github.com/apache/arrow-rs/pull/1767) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Minor: Clean up the code of MutableArrayData [\#1763](https://github.com/apache/arrow-rs/pull/1763) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Support reading PageIndex from parquet metadata, prepare for skipping pages at reading [\#1762](https://github.com/apache/arrow-rs/pull/1762) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- Support casting `Utf8` to `Boolean` [\#1738](https://github.com/apache/arrow-rs/pull/1738) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([MazterQyou](https://github.com/MazterQyou)) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9fe6c48a243d..4e4c53e5e2bd 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -17,8 +17,30 @@ under the License. --> +## Introduction + +We welcome and encourage contributions of all kinds, such as: + +1. Tickets with issue reports or feature requests +2. Documentation improvements +3. Code (PR or PR Review) + +In addition to submitting new PRs, we have a healthy tradition of community members helping review each other's PRs. Doing so is a great way to help the community as well as get more familiar with Rust and the relevant codebases. + ## Developer's guide to Arrow Rust +### Setting Up Your Build Environment + +Install the Rust toolchain: + +https://www.rust-lang.org/tools/install + +Also, make sure your Rust toolchain is up-to-date, because we always use the latest stable version of Rust to test this project. + +```bash +rustup update stable +``` + ### How to compile This is a standard cargo project with workspaces. To build it, you need to have `rust` and `cargo`: diff --git a/Cargo.toml b/Cargo.toml index de7d36f34814..2837f028e8c4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,15 @@ members = [ "arrow-flight", "integration-testing", ] +# Enable the version 2 feature resolver, which avoids unifying features for targets that are not being built +# +# Critically, this prevents dev-dependencies from enabling features even when not building a target that +# uses dev-dependencies, e.g. the library crate. This in turn ensures that we can catch invalid feature +# flag combinations that would otherwise only surface in dependent crates +# +# Reference - https://doc.rust-lang.org/nightly/cargo/reference/features.html#feature-resolver-version-2 +# +resolver = "2" # this package is excluded because it requires different compilation flags, thereby significantly changing # how it is compiled within the workspace, causing the whole workspace to be compiled from scratch diff --git a/README.md b/README.md index 54dcbe74ff07..08385fb6c15d 100644 --- a/README.md +++ b/README.md @@ -58,10 +58,13 @@ and bug fixes and this plays a critical role in the release process.
For design discussions we generally collaborate on Google documents and file a GitHub issue linking to the document. +There is more information in the [contributing] guide. + [rust]: https://www.rust-lang.org/ [arrow-readme]: arrow/README.md +[contributing]: CONTRIBUTING.md [parquet-readme]: parquet/README.md [flight-readme]: arrow-flight/README.md [datafusion-readme]: https://github.com/apache/arrow-datafusion/blob/master/README.md -[ballista-readme]: https://github.com/apache/arrow-datafusion/blob/master/ballista/README.md +[ballista-readme]: https://github.com/apache/arrow-ballista/blob/master/README.md [issues]: https://github.com/apache/arrow-rs/issues diff --git a/arrow-flight/Cargo.toml b/arrow-flight/Cargo.toml index 4b50f4e5b9e9..c5522766e4bd 100644 --- a/arrow-flight/Cargo.toml +++ b/arrow-flight/Cargo.toml @@ -18,31 +18,37 @@ [package] name = "arrow-flight" description = "Apache Arrow Flight" -version = "7.0.0-SNAPSHOT" -edition = "2018" +version = "16.0.0" +edition = "2021" +rust-version = "1.57" authors = ["Apache Arrow "] homepage = "https://github.com/apache/arrow-rs" repository = "https://github.com/apache/arrow-rs" license = "Apache-2.0" [dependencies] -arrow = { path = "../arrow", version = "7.0.0-SNAPSHOT" } -base64 = "0.13" -tonic = "0.5" -bytes = "1" -prost = "0.8" -prost-derive = "0.8" -tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread"] } +arrow = { path = "../arrow", version = "16.0.0", default-features = false, features = ["ipc"] } +base64 = { version = "0.13", default-features = false } +tonic = { version = "0.7", default-features = false, features = ["transport", "codegen", "prost"] } +bytes = { version = "1", default-features = false } +prost = { version = "0.10", default-features = false } +prost-types = { version = "0.10.0", default-features = false, optional = true } +prost-derive = { version = "0.10", default-features = false } +tokio = { version = "1.0", default-features = false, features = ["macros", "rt", "rt-multi-thread"] } +futures = { version = "0.3", default-features = false, features = ["alloc"]} + +[features] +default = [] +flight-sql-experimental = ["prost-types"] [dev-dependencies] -futures = { version = "0.3", default-features = false, features = ["alloc"]} [build-dependencies] -tonic-build = "0.5" +tonic-build = { version = "0.7", default-features = false, features = ["transport", "prost"] } # Pin specific version of the tonic-build dependencies to avoid auto-generated # (and checked in) arrow.flight.protocol.rs from changing -proc-macro2 = "=1.0.27" +proc-macro2 = { version = ">1.0.30", default-features = false } -#[lib] -#name = "flight" -#path = "src/lib.rs" +[[example]] +name = "flight_sql_server" +required-features = ["flight-sql-experimental"] diff --git a/arrow-flight/README.md b/arrow-flight/README.md index b9bc466e205e..a951699f40aa 100644 --- a/arrow-flight/README.md +++ b/arrow-flight/README.md @@ -27,7 +27,7 @@ Add this to your Cargo.toml: ```toml [dependencies] -arrow-flight = "5.0" +arrow-flight = "16.0.0" ``` Apache Arrow Flight is a gRPC based protocol for exchanging Arrow data between processes. See the blog post [Introducing Apache Arrow Flight: A Framework for Fast Data Transport](https://arrow.apache.org/blog/2019/10/13/introducing-arrow-flight/) for more information. 
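The README hunk above bumps the published crate to `arrow-flight = "16.0.0"`; elsewhere in this diff the crate also gains a `flight-sql-experimental` feature and a stub `flight_sql_server` example that listens on `0.0.0.0:50051`. As a rough companion sketch (not taken from the diff), the snippet below shows how a downstream crate might exercise the generated `FlightServiceClient` against that example; the endpoint URL and the `tokio`/`tonic` dependencies are assumptions, and the stub server answers every call with `Status::unimplemented` until its handlers are filled in.

```rust
// Minimal, hypothetical client sketch for the example server added in this diff.
// Assumes a binary crate depending on `arrow-flight = "16.0.0"`, `tonic = "0.7"`,
// and `tokio = { version = "1", features = ["macros", "rt-multi-thread"] }`.
use arrow_flight::flight_service_client::FlightServiceClient;
use arrow_flight::Criteria;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Connect to the address the flight_sql_server example binds (assumed reachable locally).
    let mut client = FlightServiceClient::connect("http://localhost:50051").await?;

    // Ask the service to list its streams; an empty expression means "no filter".
    let request = tonic::Request::new(Criteria { expression: vec![] });

    // The stub example returns `Status::unimplemented`, so expect an error here
    // until the FlightSqlService handlers are actually implemented.
    match client.list_flights(request).await {
        Ok(response) => {
            let mut stream = response.into_inner();
            while let Some(info) = stream.message().await? {
                println!("flight: {:?}", info);
            }
        }
        Err(status) => println!("server replied: {}", status),
    }
    Ok(())
}
```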
diff --git a/arrow-flight/build.rs b/arrow-flight/build.rs index 1cbfceb9262c..2054f8e7118e 100644 --- a/arrow-flight/build.rs +++ b/arrow-flight/build.rs @@ -33,7 +33,14 @@ fn main() -> Result<(), Box> { // avoid rerunning build if the file has not changed println!("cargo:rerun-if-changed=../format/Flight.proto"); - tonic_build::compile_protos("../format/Flight.proto")?; + let proto_dir = Path::new("../format"); + let proto_path = Path::new("../format/Flight.proto"); + + tonic_build::configure() + // protoc in ubuntu builder needs this option + .protoc_arg("--experimental_allow_proto3_optional") + .compile(&[proto_path], &[proto_dir])?; + // read file contents to string let mut file = OpenOptions::new() .read(true) @@ -49,6 +56,38 @@ fn main() -> Result<(), Box> { file.write_all(buffer.as_bytes())?; } + // override the build location, in order to check in the changes to proto files + env::set_var("OUT_DIR", "src/sql"); + // The current working directory can vary depending on how the project is being + // built or released so we build an absolute path to the proto file + let path = Path::new("../format/FlightSql.proto"); + if path.exists() { + // avoid rerunning build if the file has not changed + println!("cargo:rerun-if-changed=../format/FlightSql.proto"); + + let proto_dir = Path::new("../format"); + let proto_path = Path::new("../format/FlightSql.proto"); + + tonic_build::configure() + // protoc in ubuntu builder needs this option + .protoc_arg("--experimental_allow_proto3_optional") + .compile(&[proto_path], &[proto_dir])?; + + // read file contents to string + let mut file = OpenOptions::new() + .read(true) + .open("src/sql/arrow.flight.protocol.sql.rs")?; + let mut buffer = String::new(); + file.read_to_string(&mut buffer)?; + // append warning that file was auto-generated + let mut file = OpenOptions::new() + .write(true) + .truncate(true) + .open("src/sql/arrow.flight.protocol.sql.rs")?; + file.write_all("// This file was automatically generated through the build.rs script, and should not be edited.\n\n".as_bytes())?; + file.write_all(buffer.as_bytes())?; + } + // As the proto file is checked in, the build should not fail if the file is not found Ok(()) } diff --git a/arrow-flight/examples/flight_sql_server.rs b/arrow-flight/examples/flight_sql_server.rs new file mode 100644 index 000000000000..8b4fe477b868 --- /dev/null +++ b/arrow-flight/examples/flight_sql_server.rs @@ -0,0 +1,240 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
+ +use arrow_flight::sql::{ActionCreatePreparedStatementResult, SqlInfo}; +use arrow_flight::FlightData; +use tonic::transport::Server; +use tonic::{Response, Status, Streaming}; + +use arrow_flight::{ + flight_service_server::FlightService, + flight_service_server::FlightServiceServer, + sql::{ + server::FlightSqlService, ActionClosePreparedStatementRequest, + ActionCreatePreparedStatementRequest, CommandGetCatalogs, + CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, + CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, + CommandGetTableTypes, CommandGetTables, CommandPreparedStatementQuery, + CommandPreparedStatementUpdate, CommandStatementQuery, CommandStatementUpdate, + TicketStatementQuery, + }, + FlightDescriptor, FlightInfo, +}; + +#[derive(Clone)] +pub struct FlightSqlServiceImpl {} + +#[tonic::async_trait] +impl FlightSqlService for FlightSqlServiceImpl { + type FlightService = FlightSqlServiceImpl; + // get_flight_info + async fn get_flight_info_statement( + &self, + _query: CommandStatementQuery, + _request: FlightDescriptor, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + async fn get_flight_info_prepared_statement( + &self, + _query: CommandPreparedStatementQuery, + _request: FlightDescriptor, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + async fn get_flight_info_catalogs( + &self, + _query: CommandGetCatalogs, + _request: FlightDescriptor, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + async fn get_flight_info_schemas( + &self, + _query: CommandGetDbSchemas, + _request: FlightDescriptor, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + async fn get_flight_info_tables( + &self, + _query: CommandGetTables, + _request: FlightDescriptor, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + async fn get_flight_info_table_types( + &self, + _query: CommandGetTableTypes, + _request: FlightDescriptor, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + async fn get_flight_info_sql_info( + &self, + _query: CommandGetSqlInfo, + _request: FlightDescriptor, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + async fn get_flight_info_primary_keys( + &self, + _query: CommandGetPrimaryKeys, + _request: FlightDescriptor, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + async fn get_flight_info_exported_keys( + &self, + _query: CommandGetExportedKeys, + _request: FlightDescriptor, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + async fn get_flight_info_imported_keys( + &self, + _query: CommandGetImportedKeys, + _request: FlightDescriptor, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + async fn get_flight_info_cross_reference( + &self, + _query: CommandGetCrossReference, + _request: FlightDescriptor, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + // do_get + async fn do_get_statement( + &self, + _ticket: TicketStatementQuery, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_get_prepared_statement( + &self, + _query: CommandPreparedStatementQuery, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + async fn do_get_catalogs( + &self, + _query: CommandGetCatalogs, + ) -> 
Result::DoGetStream>, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + async fn do_get_schemas( + &self, + _query: CommandGetDbSchemas, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + async fn do_get_tables( + &self, + _query: CommandGetTables, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + async fn do_get_table_types( + &self, + _query: CommandGetTableTypes, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + async fn do_get_sql_info( + &self, + _query: CommandGetSqlInfo, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + async fn do_get_primary_keys( + &self, + _query: CommandGetPrimaryKeys, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + async fn do_get_exported_keys( + &self, + _query: CommandGetExportedKeys, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + async fn do_get_imported_keys( + &self, + _query: CommandGetImportedKeys, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + async fn do_get_cross_reference( + &self, + _query: CommandGetCrossReference, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + // do_put + async fn do_put_statement_update( + &self, + _ticket: CommandStatementUpdate, + ) -> Result { + Err(Status::unimplemented("Not yet implemented")) + } + async fn do_put_prepared_statement_query( + &self, + _query: CommandPreparedStatementQuery, + _request: Streaming, + ) -> Result::DoPutStream>, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + async fn do_put_prepared_statement_update( + &self, + _query: CommandPreparedStatementUpdate, + _request: Streaming, + ) -> Result { + Err(Status::unimplemented("Not yet implemented")) + } + // do_action + async fn do_action_create_prepared_statement( + &self, + _query: ActionCreatePreparedStatementRequest, + ) -> Result { + Err(Status::unimplemented("Not yet implemented")) + } + async fn do_action_close_prepared_statement( + &self, + _query: ActionClosePreparedStatementRequest, + ) { + unimplemented!("Not yet implemented") + } + + async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {} +} + +/// This example shows how to run a FlightSql server +#[tokio::main] +async fn main() -> Result<(), Box> { + let addr = "0.0.0.0:50051".parse()?; + + let svc = FlightServiceServer::new(FlightSqlServiceImpl {}); + + println!("Listening on {:?}", addr); + + Server::builder().add_service(svc).serve(addr).await?; + + Ok(()) +} diff --git a/arrow-flight/src/arrow.flight.protocol.rs b/arrow-flight/src/arrow.flight.protocol.rs index b1a79ee72031..c76469b39ce7 100644 --- a/arrow-flight/src/arrow.flight.protocol.rs +++ b/arrow-flight/src/arrow.flight.protocol.rs @@ -6,43 +6,44 @@ pub struct HandshakeRequest { /// /// A defined protocol version - #[prost(uint64, tag = "1")] + #[prost(uint64, tag="1")] pub protocol_version: u64, /// /// Arbitrary auth/handshake info. - #[prost(bytes = "vec", tag = "2")] + #[prost(bytes="vec", tag="2")] pub payload: ::prost::alloc::vec::Vec, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct HandshakeResponse { /// /// A defined protocol version - #[prost(uint64, tag = "1")] + #[prost(uint64, tag="1")] pub protocol_version: u64, /// /// Arbitrary auth/handshake info. 
- #[prost(bytes = "vec", tag = "2")] + #[prost(bytes="vec", tag="2")] pub payload: ::prost::alloc::vec::Vec, } /// /// A message for doing simple auth. #[derive(Clone, PartialEq, ::prost::Message)] pub struct BasicAuth { - #[prost(string, tag = "2")] + #[prost(string, tag="2")] pub username: ::prost::alloc::string::String, - #[prost(string, tag = "3")] + #[prost(string, tag="3")] pub password: ::prost::alloc::string::String, } #[derive(Clone, PartialEq, ::prost::Message)] -pub struct Empty {} +pub struct Empty { +} /// /// Describes an available action, including both the name used for execution /// along with a short description of the purpose of the action. #[derive(Clone, PartialEq, ::prost::Message)] pub struct ActionType { - #[prost(string, tag = "1")] + #[prost(string, tag="1")] pub r#type: ::prost::alloc::string::String, - #[prost(string, tag = "2")] + #[prost(string, tag="2")] pub description: ::prost::alloc::string::String, } /// @@ -50,23 +51,23 @@ pub struct ActionType { /// of available Arrow Flight streams. #[derive(Clone, PartialEq, ::prost::Message)] pub struct Criteria { - #[prost(bytes = "vec", tag = "1")] + #[prost(bytes="vec", tag="1")] pub expression: ::prost::alloc::vec::Vec, } /// /// An opaque action specific for the service. #[derive(Clone, PartialEq, ::prost::Message)] pub struct Action { - #[prost(string, tag = "1")] + #[prost(string, tag="1")] pub r#type: ::prost::alloc::string::String, - #[prost(bytes = "vec", tag = "2")] + #[prost(bytes="vec", tag="2")] pub body: ::prost::alloc::vec::Vec, } /// /// An opaque result returned after executing an action. #[derive(Clone, PartialEq, ::prost::Message)] pub struct Result { - #[prost(bytes = "vec", tag = "1")] + #[prost(bytes="vec", tag="1")] pub body: ::prost::alloc::vec::Vec, } /// @@ -74,7 +75,7 @@ pub struct Result { #[derive(Clone, PartialEq, ::prost::Message)] pub struct SchemaResult { /// schema of the dataset as described in Schema.fbs::Schema. - #[prost(bytes = "vec", tag = "1")] + #[prost(bytes="vec", tag="1")] pub schema: ::prost::alloc::vec::Vec, } /// @@ -82,26 +83,24 @@ pub struct SchemaResult { /// a flight or be used to expose a set of previously defined flights. #[derive(Clone, PartialEq, ::prost::Message)] pub struct FlightDescriptor { - #[prost(enumeration = "flight_descriptor::DescriptorType", tag = "1")] + #[prost(enumeration="flight_descriptor::DescriptorType", tag="1")] pub r#type: i32, /// /// Opaque value used to express a command. Should only be defined when /// type = CMD. - #[prost(bytes = "vec", tag = "2")] + #[prost(bytes="vec", tag="2")] pub cmd: ::prost::alloc::vec::Vec, /// /// List of strings identifying a particular dataset. Should only be defined /// when type = PATH. - #[prost(string, repeated, tag = "3")] + #[prost(string, repeated, tag="3")] pub path: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, } /// Nested message and enum types in `FlightDescriptor`. pub mod flight_descriptor { /// /// Describes what type of descriptor is defined. - #[derive( - Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration, - )] + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum DescriptorType { /// Protobuf pattern, not used. @@ -122,21 +121,21 @@ pub mod flight_descriptor { #[derive(Clone, PartialEq, ::prost::Message)] pub struct FlightInfo { /// schema of the dataset as described in Schema.fbs::Schema. 
- #[prost(bytes = "vec", tag = "1")] + #[prost(bytes="vec", tag="1")] pub schema: ::prost::alloc::vec::Vec, /// /// The descriptor associated with this info. - #[prost(message, optional, tag = "2")] + #[prost(message, optional, tag="2")] pub flight_descriptor: ::core::option::Option, /// /// A list of endpoints associated with the flight. To consume the whole /// flight, all endpoints must be consumed. - #[prost(message, repeated, tag = "3")] + #[prost(message, repeated, tag="3")] pub endpoint: ::prost::alloc::vec::Vec, /// Set these to -1 if unknown. - #[prost(int64, tag = "4")] + #[prost(int64, tag="4")] pub total_records: i64, - #[prost(int64, tag = "5")] + #[prost(int64, tag="5")] pub total_bytes: i64, } /// @@ -145,13 +144,13 @@ pub struct FlightInfo { pub struct FlightEndpoint { /// /// Token used to retrieve this stream. - #[prost(message, optional, tag = "1")] + #[prost(message, optional, tag="1")] pub ticket: ::core::option::Option, /// /// A list of URIs where this ticket can be redeemed. If the list is /// empty, the expectation is that the ticket can only be redeemed on the /// current service where the ticket was generated. - #[prost(message, repeated, tag = "2")] + #[prost(message, repeated, tag="2")] pub location: ::prost::alloc::vec::Vec, } /// @@ -159,7 +158,7 @@ pub struct FlightEndpoint { /// stream given a ticket. #[derive(Clone, PartialEq, ::prost::Message)] pub struct Location { - #[prost(string, tag = "1")] + #[prost(string, tag="1")] pub uri: ::prost::alloc::string::String, } /// @@ -167,7 +166,7 @@ pub struct Location { /// portion of a stream. #[derive(Clone, PartialEq, ::prost::Message)] pub struct Ticket { - #[prost(bytes = "vec", tag = "1")] + #[prost(bytes="vec", tag="1")] pub ticket: ::prost::alloc::vec::Vec, } /// @@ -177,46 +176,46 @@ pub struct FlightData { /// /// The descriptor of the data. This is only relevant when a client is /// starting a new DoPut stream. - #[prost(message, optional, tag = "1")] + #[prost(message, optional, tag="1")] pub flight_descriptor: ::core::option::Option, /// /// Header for message data as described in Message.fbs::Message. - #[prost(bytes = "vec", tag = "2")] + #[prost(bytes="vec", tag="2")] pub data_header: ::prost::alloc::vec::Vec, /// /// Application-defined metadata. - #[prost(bytes = "vec", tag = "3")] + #[prost(bytes="vec", tag="3")] pub app_metadata: ::prost::alloc::vec::Vec, /// /// The actual batch of Arrow data. Preferably handled with minimal-copies /// coming last in the definition to help with sidecar patterns (it is /// expected that some implementations will fetch this field off the wire /// with specialized code to avoid extra memory copies). - #[prost(bytes = "vec", tag = "1000")] + #[prost(bytes="vec", tag="1000")] pub data_body: ::prost::alloc::vec::Vec, } ///* /// The response message associated with the submission of a DoPut. #[derive(Clone, PartialEq, ::prost::Message)] pub struct PutResult { - #[prost(bytes = "vec", tag = "1")] + #[prost(bytes="vec", tag="1")] pub app_metadata: ::prost::alloc::vec::Vec, } -#[doc = r" Generated client implementations."] +/// Generated client implementations. pub mod flight_service_client { #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] use tonic::codegen::*; - #[doc = ""] - #[doc = " A flight service is an endpoint for retrieving or storing Arrow data. A"] - #[doc = " flight service can expose one or more predefined endpoints that can be"] - #[doc = " accessed using the Arrow Flight Protocol. 
Additionally, a flight service"] - #[doc = " can expose a set of actions that are available."] + /// + /// A flight service is an endpoint for retrieving or storing Arrow data. A + /// flight service can expose one or more predefined endpoints that can be + /// accessed using the Arrow Flight Protocol. Additionally, a flight service + /// can expose a set of actions that are available. #[derive(Debug, Clone)] pub struct FlightServiceClient { inner: tonic::client::Grpc, } impl FlightServiceClient { - #[doc = r" Attempt to create a new client by connecting to a given endpoint."] + /// Attempt to create a new client by connecting to a given endpoint. pub async fn connect(dst: D) -> Result where D: std::convert::TryInto, @@ -229,8 +228,8 @@ pub mod flight_service_client { impl FlightServiceClient where T: tonic::client::GrpcService, - T::ResponseBody: Body + Send + Sync + 'static, T::Error: Into, + T::ResponseBody: Body + Send + 'static, ::Error: Into + Send, { pub fn new(inner: T) -> Self { @@ -243,35 +242,39 @@ pub mod flight_service_client { ) -> FlightServiceClient> where F: tonic::service::Interceptor, + T::ResponseBody: Default, T: tonic::codegen::Service< http::Request, Response = http::Response< >::ResponseBody, >, >, - >>::Error: - Into + Send + Sync, + , + >>::Error: Into + Send + Sync, { FlightServiceClient::new(InterceptedService::new(inner, interceptor)) } - #[doc = r" Compress requests with `gzip`."] - #[doc = r""] - #[doc = r" This requires the server to support it otherwise it might respond with an"] - #[doc = r" error."] + /// Compress requests with `gzip`. + /// + /// This requires the server to support it otherwise it might respond with an + /// error. + #[must_use] pub fn send_gzip(mut self) -> Self { self.inner = self.inner.send_gzip(); self } - #[doc = r" Enable decompressing responses with `gzip`."] + /// Enable decompressing responses with `gzip`. + #[must_use] pub fn accept_gzip(mut self) -> Self { self.inner = self.inner.accept_gzip(); self } - #[doc = ""] - #[doc = " Handshake between client and server. Depending on the server, the"] - #[doc = " handshake may be required to determine the token that should be used for"] - #[doc = " future operations. Both request and response are streams to allow multiple"] - #[doc = " round-trips depending on auth mechanism."] + /// + /// Handshake between client and server. Depending on the server, the + /// handshake may be required to determine the token that should be used for + /// future operations. Both request and response are streams to allow multiple + /// round-trips depending on auth mechanism. pub async fn handshake( &mut self, request: impl tonic::IntoStreamingRequest, @@ -279,27 +282,28 @@ pub mod flight_service_client { tonic::Response>, tonic::Status, > { - self.inner.ready().await.map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; let codec = tonic::codec::ProstCodec::default(); let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/Handshake", ); - self.inner - .streaming(request.into_streaming_request(), path, codec) - .await + self.inner.streaming(request.into_streaming_request(), path, codec).await } - #[doc = ""] - #[doc = " Get a list of available streams given a particular criteria. 
Most flight"] - #[doc = " services will expose one or more streams that are readily available for"] - #[doc = " retrieval. This api allows listing the streams available for"] - #[doc = " consumption. A user can also provide a criteria. The criteria can limit"] - #[doc = " the subset of streams that can be listed via this interface. Each flight"] - #[doc = " service allows its own definition of how to consume criteria."] + /// + /// Get a list of available streams given a particular criteria. Most flight + /// services will expose one or more streams that are readily available for + /// retrieval. This api allows listing the streams available for + /// consumption. A user can also provide a criteria. The criteria can limit + /// the subset of streams that can be listed via this interface. Each flight + /// service allows its own definition of how to consume criteria. pub async fn list_flights( &mut self, request: impl tonic::IntoRequest, @@ -307,73 +311,80 @@ pub mod flight_service_client { tonic::Response>, tonic::Status, > { - self.inner.ready().await.map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; let codec = tonic::codec::ProstCodec::default(); let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/ListFlights", ); - self.inner - .server_streaming(request.into_request(), path, codec) - .await + self.inner.server_streaming(request.into_request(), path, codec).await } - #[doc = ""] - #[doc = " For a given FlightDescriptor, get information about how the flight can be"] - #[doc = " consumed. This is a useful interface if the consumer of the interface"] - #[doc = " already can identify the specific flight to consume. This interface can"] - #[doc = " also allow a consumer to generate a flight stream through a specified"] - #[doc = " descriptor. For example, a flight descriptor might be something that"] - #[doc = " includes a SQL statement or a Pickled Python operation that will be"] - #[doc = " executed. In those cases, the descriptor will not be previously available"] - #[doc = " within the list of available streams provided by ListFlights but will be"] - #[doc = " available for consumption for the duration defined by the specific flight"] - #[doc = " service."] + /// + /// For a given FlightDescriptor, get information about how the flight can be + /// consumed. This is a useful interface if the consumer of the interface + /// already can identify the specific flight to consume. This interface can + /// also allow a consumer to generate a flight stream through a specified + /// descriptor. For example, a flight descriptor might be something that + /// includes a SQL statement or a Pickled Python operation that will be + /// executed. In those cases, the descriptor will not be previously available + /// within the list of available streams provided by ListFlights but will be + /// available for consumption for the duration defined by the specific flight + /// service. 
pub async fn get_flight_info( &mut self, request: impl tonic::IntoRequest, ) -> Result, tonic::Status> { - self.inner.ready().await.map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; let codec = tonic::codec::ProstCodec::default(); let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/GetFlightInfo", ); self.inner.unary(request.into_request(), path, codec).await } - #[doc = ""] - #[doc = " For a given FlightDescriptor, get the Schema as described in Schema.fbs::Schema"] - #[doc = " This is used when a consumer needs the Schema of flight stream. Similar to"] - #[doc = " GetFlightInfo this interface may generate a new flight that was not previously"] - #[doc = " available in ListFlights."] + /// + /// For a given FlightDescriptor, get the Schema as described in Schema.fbs::Schema + /// This is used when a consumer needs the Schema of flight stream. Similar to + /// GetFlightInfo this interface may generate a new flight that was not previously + /// available in ListFlights. pub async fn get_schema( &mut self, request: impl tonic::IntoRequest, ) -> Result, tonic::Status> { - self.inner.ready().await.map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; let codec = tonic::codec::ProstCodec::default(); let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/GetSchema", ); self.inner.unary(request.into_request(), path, codec).await } - #[doc = ""] - #[doc = " Retrieve a single stream associated with a particular descriptor"] - #[doc = " associated with the referenced ticket. A Flight can be composed of one or"] - #[doc = " more streams where each stream can be retrieved using a separate opaque"] - #[doc = " ticket that the flight service uses for managing a collection of streams."] + /// + /// Retrieve a single stream associated with a particular descriptor + /// associated with the referenced ticket. A Flight can be composed of one or + /// more streams where each stream can be retrieved using a separate opaque + /// ticket that the flight service uses for managing a collection of streams. pub async fn do_get( &mut self, request: impl tonic::IntoRequest, @@ -381,27 +392,28 @@ pub mod flight_service_client { tonic::Response>, tonic::Status, > { - self.inner.ready().await.map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; let codec = tonic::codec::ProstCodec::default(); let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/DoGet", ); - self.inner - .server_streaming(request.into_request(), path, codec) - .await + self.inner.server_streaming(request.into_request(), path, codec).await } - #[doc = ""] - #[doc = " Push a stream to the flight service associated with a particular"] - #[doc = " flight stream. This allows a client of a flight service to upload a stream"] - #[doc = " of data. 
Depending on the particular flight service, a client consumer"] - #[doc = " could be allowed to upload a single stream per descriptor or an unlimited"] - #[doc = " number. In the latter, the service might implement a 'seal' action that"] - #[doc = " can be applied to a descriptor once all streams are uploaded."] + /// + /// Push a stream to the flight service associated with a particular + /// flight stream. This allows a client of a flight service to upload a stream + /// of data. Depending on the particular flight service, a client consumer + /// could be allowed to upload a single stream per descriptor or an unlimited + /// number. In the latter, the service might implement a 'seal' action that + /// can be applied to a descriptor once all streams are uploaded. pub async fn do_put( &mut self, request: impl tonic::IntoStreamingRequest, @@ -409,26 +421,27 @@ pub mod flight_service_client { tonic::Response>, tonic::Status, > { - self.inner.ready().await.map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; let codec = tonic::codec::ProstCodec::default(); let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/DoPut", ); - self.inner - .streaming(request.into_streaming_request(), path, codec) - .await + self.inner.streaming(request.into_streaming_request(), path, codec).await } - #[doc = ""] - #[doc = " Open a bidirectional data channel for a given descriptor. This"] - #[doc = " allows clients to send and receive arbitrary Arrow data and"] - #[doc = " application-specific metadata in a single logical stream. In"] - #[doc = " contrast to DoGet/DoPut, this is more suited for clients"] - #[doc = " offloading computation (rather than storage) to a Flight service."] + /// + /// Open a bidirectional data channel for a given descriptor. This + /// allows clients to send and receive arbitrary Arrow data and + /// application-specific metadata in a single logical stream. In + /// contrast to DoGet/DoPut, this is more suited for clients + /// offloading computation (rather than storage) to a Flight service. pub async fn do_exchange( &mut self, request: impl tonic::IntoStreamingRequest, @@ -436,50 +449,54 @@ pub mod flight_service_client { tonic::Response>, tonic::Status, > { - self.inner.ready().await.map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; let codec = tonic::codec::ProstCodec::default(); let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/DoExchange", ); - self.inner - .streaming(request.into_streaming_request(), path, codec) - .await + self.inner.streaming(request.into_streaming_request(), path, codec).await } - #[doc = ""] - #[doc = " Flight services can support an arbitrary number of simple actions in"] - #[doc = " addition to the possible ListFlights, GetFlightInfo, DoGet, DoPut"] - #[doc = " operations that are potentially available. DoAction allows a flight client"] - #[doc = " to do a specific action against a flight service. 
An action includes"] - #[doc = " opaque request and response objects that are specific to the type action"] - #[doc = " being undertaken."] + /// + /// Flight services can support an arbitrary number of simple actions in + /// addition to the possible ListFlights, GetFlightInfo, DoGet, DoPut + /// operations that are potentially available. DoAction allows a flight client + /// to do a specific action against a flight service. An action includes + /// opaque request and response objects that are specific to the type action + /// being undertaken. pub async fn do_action( &mut self, request: impl tonic::IntoRequest, - ) -> Result>, tonic::Status> - { - self.inner.ready().await.map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; + ) -> Result< + tonic::Response>, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; let codec = tonic::codec::ProstCodec::default(); let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/DoAction", ); - self.inner - .server_streaming(request.into_request(), path, codec) - .await + self.inner.server_streaming(request.into_request(), path, codec).await } - #[doc = ""] - #[doc = " A flight service exposes all of the available action types that it has"] - #[doc = " along with descriptions. This allows different flight consumers to"] - #[doc = " understand the capabilities of the flight service."] + /// + /// A flight service exposes all of the available action types that it has + /// along with descriptions. This allows different flight consumers to + /// understand the capabilities of the flight service. pub async fn list_actions( &mut self, request: impl tonic::IntoRequest, @@ -487,163 +504,171 @@ pub mod flight_service_client { tonic::Response>, tonic::Status, > { - self.inner.ready().await.map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; let codec = tonic::codec::ProstCodec::default(); let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/ListActions", ); - self.inner - .server_streaming(request.into_request(), path, codec) - .await + self.inner.server_streaming(request.into_request(), path, codec).await } } } -#[doc = r" Generated server implementations."] +/// Generated server implementations. pub mod flight_service_server { #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] use tonic::codegen::*; - #[doc = "Generated trait containing gRPC methods that should be implemented for use with FlightServiceServer."] + ///Generated trait containing gRPC methods that should be implemented for use with FlightServiceServer. #[async_trait] pub trait FlightService: Send + Sync + 'static { - #[doc = "Server streaming response type for the Handshake method."] - type HandshakeStream: futures_core::Stream> + ///Server streaming response type for the Handshake method. + type HandshakeStream: futures_core::Stream< + Item = Result, + > + Send - + Sync + 'static; - #[doc = ""] - #[doc = " Handshake between client and server. Depending on the server, the"] - #[doc = " handshake may be required to determine the token that should be used for"] - #[doc = " future operations. 
Both request and response are streams to allow multiple"] - #[doc = " round-trips depending on auth mechanism."] + /// + /// Handshake between client and server. Depending on the server, the + /// handshake may be required to determine the token that should be used for + /// future operations. Both request and response are streams to allow multiple + /// round-trips depending on auth mechanism. async fn handshake( &self, request: tonic::Request>, ) -> Result, tonic::Status>; - #[doc = "Server streaming response type for the ListFlights method."] - type ListFlightsStream: futures_core::Stream> + ///Server streaming response type for the ListFlights method. + type ListFlightsStream: futures_core::Stream< + Item = Result, + > + Send - + Sync + 'static; - #[doc = ""] - #[doc = " Get a list of available streams given a particular criteria. Most flight"] - #[doc = " services will expose one or more streams that are readily available for"] - #[doc = " retrieval. This api allows listing the streams available for"] - #[doc = " consumption. A user can also provide a criteria. The criteria can limit"] - #[doc = " the subset of streams that can be listed via this interface. Each flight"] - #[doc = " service allows its own definition of how to consume criteria."] + /// + /// Get a list of available streams given a particular criteria. Most flight + /// services will expose one or more streams that are readily available for + /// retrieval. This api allows listing the streams available for + /// consumption. A user can also provide a criteria. The criteria can limit + /// the subset of streams that can be listed via this interface. Each flight + /// service allows its own definition of how to consume criteria. async fn list_flights( &self, request: tonic::Request, ) -> Result, tonic::Status>; - #[doc = ""] - #[doc = " For a given FlightDescriptor, get information about how the flight can be"] - #[doc = " consumed. This is a useful interface if the consumer of the interface"] - #[doc = " already can identify the specific flight to consume. This interface can"] - #[doc = " also allow a consumer to generate a flight stream through a specified"] - #[doc = " descriptor. For example, a flight descriptor might be something that"] - #[doc = " includes a SQL statement or a Pickled Python operation that will be"] - #[doc = " executed. In those cases, the descriptor will not be previously available"] - #[doc = " within the list of available streams provided by ListFlights but will be"] - #[doc = " available for consumption for the duration defined by the specific flight"] - #[doc = " service."] + /// + /// For a given FlightDescriptor, get information about how the flight can be + /// consumed. This is a useful interface if the consumer of the interface + /// already can identify the specific flight to consume. This interface can + /// also allow a consumer to generate a flight stream through a specified + /// descriptor. For example, a flight descriptor might be something that + /// includes a SQL statement or a Pickled Python operation that will be + /// executed. In those cases, the descriptor will not be previously available + /// within the list of available streams provided by ListFlights but will be + /// available for consumption for the duration defined by the specific flight + /// service. 
async fn get_flight_info( &self, request: tonic::Request, ) -> Result, tonic::Status>; - #[doc = ""] - #[doc = " For a given FlightDescriptor, get the Schema as described in Schema.fbs::Schema"] - #[doc = " This is used when a consumer needs the Schema of flight stream. Similar to"] - #[doc = " GetFlightInfo this interface may generate a new flight that was not previously"] - #[doc = " available in ListFlights."] + /// + /// For a given FlightDescriptor, get the Schema as described in Schema.fbs::Schema + /// This is used when a consumer needs the Schema of flight stream. Similar to + /// GetFlightInfo this interface may generate a new flight that was not previously + /// available in ListFlights. async fn get_schema( &self, request: tonic::Request, ) -> Result, tonic::Status>; - #[doc = "Server streaming response type for the DoGet method."] - type DoGetStream: futures_core::Stream> + ///Server streaming response type for the DoGet method. + type DoGetStream: futures_core::Stream< + Item = Result, + > + Send - + Sync + 'static; - #[doc = ""] - #[doc = " Retrieve a single stream associated with a particular descriptor"] - #[doc = " associated with the referenced ticket. A Flight can be composed of one or"] - #[doc = " more streams where each stream can be retrieved using a separate opaque"] - #[doc = " ticket that the flight service uses for managing a collection of streams."] + /// + /// Retrieve a single stream associated with a particular descriptor + /// associated with the referenced ticket. A Flight can be composed of one or + /// more streams where each stream can be retrieved using a separate opaque + /// ticket that the flight service uses for managing a collection of streams. async fn do_get( &self, request: tonic::Request, ) -> Result, tonic::Status>; - #[doc = "Server streaming response type for the DoPut method."] - type DoPutStream: futures_core::Stream> + ///Server streaming response type for the DoPut method. + type DoPutStream: futures_core::Stream< + Item = Result, + > + Send - + Sync + 'static; - #[doc = ""] - #[doc = " Push a stream to the flight service associated with a particular"] - #[doc = " flight stream. This allows a client of a flight service to upload a stream"] - #[doc = " of data. Depending on the particular flight service, a client consumer"] - #[doc = " could be allowed to upload a single stream per descriptor or an unlimited"] - #[doc = " number. In the latter, the service might implement a 'seal' action that"] - #[doc = " can be applied to a descriptor once all streams are uploaded."] + /// + /// Push a stream to the flight service associated with a particular + /// flight stream. This allows a client of a flight service to upload a stream + /// of data. Depending on the particular flight service, a client consumer + /// could be allowed to upload a single stream per descriptor or an unlimited + /// number. In the latter, the service might implement a 'seal' action that + /// can be applied to a descriptor once all streams are uploaded. async fn do_put( &self, request: tonic::Request>, ) -> Result, tonic::Status>; - #[doc = "Server streaming response type for the DoExchange method."] - type DoExchangeStream: futures_core::Stream> + ///Server streaming response type for the DoExchange method. + type DoExchangeStream: futures_core::Stream< + Item = Result, + > + Send - + Sync + 'static; - #[doc = ""] - #[doc = " Open a bidirectional data channel for a given descriptor. 
This"] - #[doc = " allows clients to send and receive arbitrary Arrow data and"] - #[doc = " application-specific metadata in a single logical stream. In"] - #[doc = " contrast to DoGet/DoPut, this is more suited for clients"] - #[doc = " offloading computation (rather than storage) to a Flight service."] + /// + /// Open a bidirectional data channel for a given descriptor. This + /// allows clients to send and receive arbitrary Arrow data and + /// application-specific metadata in a single logical stream. In + /// contrast to DoGet/DoPut, this is more suited for clients + /// offloading computation (rather than storage) to a Flight service. async fn do_exchange( &self, request: tonic::Request>, ) -> Result, tonic::Status>; - #[doc = "Server streaming response type for the DoAction method."] - type DoActionStream: futures_core::Stream> + ///Server streaming response type for the DoAction method. + type DoActionStream: futures_core::Stream< + Item = Result, + > + Send - + Sync + 'static; - #[doc = ""] - #[doc = " Flight services can support an arbitrary number of simple actions in"] - #[doc = " addition to the possible ListFlights, GetFlightInfo, DoGet, DoPut"] - #[doc = " operations that are potentially available. DoAction allows a flight client"] - #[doc = " to do a specific action against a flight service. An action includes"] - #[doc = " opaque request and response objects that are specific to the type action"] - #[doc = " being undertaken."] + /// + /// Flight services can support an arbitrary number of simple actions in + /// addition to the possible ListFlights, GetFlightInfo, DoGet, DoPut + /// operations that are potentially available. DoAction allows a flight client + /// to do a specific action against a flight service. An action includes + /// opaque request and response objects that are specific to the type action + /// being undertaken. async fn do_action( &self, request: tonic::Request, ) -> Result, tonic::Status>; - #[doc = "Server streaming response type for the ListActions method."] - type ListActionsStream: futures_core::Stream> + ///Server streaming response type for the ListActions method. + type ListActionsStream: futures_core::Stream< + Item = Result, + > + Send - + Sync + 'static; - #[doc = ""] - #[doc = " A flight service exposes all of the available action types that it has"] - #[doc = " along with descriptions. This allows different flight consumers to"] - #[doc = " understand the capabilities of the flight service."] + /// + /// A flight service exposes all of the available action types that it has + /// along with descriptions. This allows different flight consumers to + /// understand the capabilities of the flight service. async fn list_actions( &self, request: tonic::Request, ) -> Result, tonic::Status>; } - #[doc = ""] - #[doc = " A flight service is an endpoint for retrieving or storing Arrow data. A"] - #[doc = " flight service can expose one or more predefined endpoints that can be"] - #[doc = " accessed using the Arrow Flight Protocol. Additionally, a flight service"] - #[doc = " can expose a set of actions that are available."] + /// + /// A flight service is an endpoint for retrieving or storing Arrow data. A + /// flight service can expose one or more predefined endpoints that can be + /// accessed using the Arrow Flight Protocol. Additionally, a flight service + /// can expose a set of actions that are available. 
#[derive(Debug)] pub struct FlightServiceServer { inner: _Inner, @@ -653,7 +678,9 @@ pub mod flight_service_server { struct _Inner(Arc); impl FlightServiceServer { pub fn new(inner: T) -> Self { - let inner = Arc::new(inner); + Self::from_arc(Arc::new(inner)) + } + pub fn from_arc(inner: Arc) -> Self { let inner = _Inner(inner); Self { inner, @@ -674,13 +701,16 @@ pub mod flight_service_server { impl tonic::codegen::Service> for FlightServiceServer where T: FlightService, - B: Body + Send + Sync + 'static, + B: Body + Send + 'static, B::Error: Into + Send + 'static, { type Response = http::Response; - type Error = Never; + type Error = std::convert::Infallible; type Future = BoxFuture; - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + fn poll_ready( + &mut self, + _cx: &mut Context<'_>, + ) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, req: http::Request) -> Self::Future { @@ -689,10 +719,10 @@ pub mod flight_service_server { "/arrow.flight.protocol.FlightService/Handshake" => { #[allow(non_camel_case_types)] struct HandshakeSvc(pub Arc); - impl - tonic::server::StreamingService - for HandshakeSvc - { + impl< + T: FlightService, + > tonic::server::StreamingService + for HandshakeSvc { type Response = super::HandshakeResponse; type ResponseStream = T::HandshakeStream; type Future = BoxFuture< @@ -730,10 +760,10 @@ pub mod flight_service_server { "/arrow.flight.protocol.FlightService/ListFlights" => { #[allow(non_camel_case_types)] struct ListFlightsSvc(pub Arc); - impl - tonic::server::ServerStreamingService - for ListFlightsSvc - { + impl< + T: FlightService, + > tonic::server::ServerStreamingService + for ListFlightsSvc { type Response = super::FlightInfo; type ResponseStream = T::ListFlightsStream; type Future = BoxFuture< @@ -745,7 +775,9 @@ pub mod flight_service_server { request: tonic::Request, ) -> Self::Future { let inner = self.0.clone(); - let fut = async move { (*inner).list_flights(request).await }; + let fut = async move { + (*inner).list_flights(request).await + }; Box::pin(fut) } } @@ -769,20 +801,23 @@ pub mod flight_service_server { "/arrow.flight.protocol.FlightService/GetFlightInfo" => { #[allow(non_camel_case_types)] struct GetFlightInfoSvc(pub Arc); - impl - tonic::server::UnaryService - for GetFlightInfoSvc - { + impl< + T: FlightService, + > tonic::server::UnaryService + for GetFlightInfoSvc { type Response = super::FlightInfo; - type Future = - BoxFuture, tonic::Status>; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; fn call( &mut self, request: tonic::Request, ) -> Self::Future { let inner = self.0.clone(); - let fut = - async move { (*inner).get_flight_info(request).await }; + let fut = async move { + (*inner).get_flight_info(request).await + }; Box::pin(fut) } } @@ -806,13 +841,15 @@ pub mod flight_service_server { "/arrow.flight.protocol.FlightService/GetSchema" => { #[allow(non_camel_case_types)] struct GetSchemaSvc(pub Arc); - impl - tonic::server::UnaryService - for GetSchemaSvc - { + impl< + T: FlightService, + > tonic::server::UnaryService + for GetSchemaSvc { type Response = super::SchemaResult; - type Future = - BoxFuture, tonic::Status>; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; fn call( &mut self, request: tonic::Request, @@ -842,10 +879,10 @@ pub mod flight_service_server { "/arrow.flight.protocol.FlightService/DoGet" => { #[allow(non_camel_case_types)] struct DoGetSvc(pub Arc); - impl - tonic::server::ServerStreamingService - for DoGetSvc - { + impl< + T: FlightService, + > 
tonic::server::ServerStreamingService + for DoGetSvc { type Response = super::FlightData; type ResponseStream = T::DoGetStream; type Future = BoxFuture< @@ -881,10 +918,10 @@ pub mod flight_service_server { "/arrow.flight.protocol.FlightService/DoPut" => { #[allow(non_camel_case_types)] struct DoPutSvc(pub Arc); - impl - tonic::server::StreamingService - for DoPutSvc - { + impl< + T: FlightService, + > tonic::server::StreamingService + for DoPutSvc { type Response = super::PutResult; type ResponseStream = T::DoPutStream; type Future = BoxFuture< @@ -920,10 +957,10 @@ pub mod flight_service_server { "/arrow.flight.protocol.FlightService/DoExchange" => { #[allow(non_camel_case_types)] struct DoExchangeSvc(pub Arc); - impl - tonic::server::StreamingService - for DoExchangeSvc - { + impl< + T: FlightService, + > tonic::server::StreamingService + for DoExchangeSvc { type Response = super::FlightData; type ResponseStream = T::DoExchangeStream; type Future = BoxFuture< @@ -959,10 +996,10 @@ pub mod flight_service_server { "/arrow.flight.protocol.FlightService/DoAction" => { #[allow(non_camel_case_types)] struct DoActionSvc(pub Arc); - impl - tonic::server::ServerStreamingService - for DoActionSvc - { + impl< + T: FlightService, + > tonic::server::ServerStreamingService + for DoActionSvc { type Response = super::Result; type ResponseStream = T::DoActionStream; type Future = BoxFuture< @@ -998,10 +1035,10 @@ pub mod flight_service_server { "/arrow.flight.protocol.FlightService/ListActions" => { #[allow(non_camel_case_types)] struct ListActionsSvc(pub Arc); - impl - tonic::server::ServerStreamingService - for ListActionsSvc - { + impl< + T: FlightService, + > tonic::server::ServerStreamingService + for ListActionsSvc { type Response = super::ActionType; type ResponseStream = T::ListActionsStream; type Future = BoxFuture< @@ -1013,7 +1050,9 @@ pub mod flight_service_server { request: tonic::Request, ) -> Self::Future { let inner = self.0.clone(); - let fut = async move { (*inner).list_actions(request).await }; + let fut = async move { + (*inner).list_actions(request).await + }; Box::pin(fut) } } @@ -1034,14 +1073,18 @@ pub mod flight_service_server { }; Box::pin(fut) } - _ => Box::pin(async move { - Ok(http::Response::builder() - .status(200) - .header("grpc-status", "12") - .header("content-type", "application/grpc") - .body(empty_body()) - .unwrap()) - }), + _ => { + Box::pin(async move { + Ok( + http::Response::builder() + .status(200) + .header("grpc-status", "12") + .header("content-type", "application/grpc") + .body(empty_body()) + .unwrap(), + ) + }) + } } } } diff --git a/arrow-flight/src/lib.rs b/arrow-flight/src/lib.rs index a431cfc08845..5cfbd3f60657 100644 --- a/arrow-flight/src/lib.rs +++ b/arrow-flight/src/lib.rs @@ -67,6 +67,9 @@ pub use gen::Ticket; pub mod utils; +#[cfg(feature = "flight-sql-experimental")] +pub mod sql; + use flight_descriptor::DescriptorType; /// SchemaAsIpc represents a pairing of a `Schema` with IpcWriteOptions diff --git a/arrow-flight/src/sql/arrow.flight.protocol.sql.rs b/arrow-flight/src/sql/arrow.flight.protocol.sql.rs new file mode 100644 index 000000000000..ea378a0a2577 --- /dev/null +++ b/arrow-flight/src/sql/arrow.flight.protocol.sql.rs @@ -0,0 +1,1139 @@ +// This file was automatically generated through the build.rs script, and should not be edited. + +/// +/// Represents a metadata request. Used in the command member of FlightDescriptor +/// for the following RPC calls: +/// - GetSchema: return the Arrow schema of the query. 
+/// - GetFlightInfo: execute the metadata request. +/// +/// The returned Arrow schema will be: +/// < +/// info_name: uint32 not null, +/// value: dense_union< +/// string_value: utf8, +/// bool_value: bool, +/// bigint_value: int64, +/// int32_bitmask: int32, +/// string_list: list +/// int32_to_int32_list_map: map> +/// > +/// where there is one row per requested piece of metadata information. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CommandGetSqlInfo { + /// + /// Values are modelled after ODBC's SQLGetInfo() function. This information is intended to provide + /// Flight SQL clients with basic, SQL syntax and SQL functions related information. + /// More information types can be added in future releases. + /// E.g. more SQL syntax support types, scalar functions support, type conversion support etc. + /// + /// Note that the set of metadata may expand. + /// + /// Initially, Flight SQL will support the following information types: + /// - Server Information - Range [0-500) + /// - Syntax Information - Range [500-1000) + /// Range [0-10,000) is reserved for defaults (see SqlInfo enum for default options). + /// Custom options should start at 10,000. + /// + /// If omitted, then all metadata will be retrieved. + /// Flight SQL Servers may choose to include additional metadata above and beyond the specified set, however they must + /// at least return the specified set. IDs ranging from 0 to 10,000 (exclusive) are reserved for future use. + /// If additional metadata is included, the metadata IDs should start from 10,000. + #[prost(uint32, repeated, tag="1")] + pub info: ::prost::alloc::vec::Vec, +} +/// +/// Represents a request to retrieve the list of catalogs on a Flight SQL enabled backend. +/// The definition of a catalog depends on vendor/implementation. It is usually the database itself +/// Used in the command member of FlightDescriptor for the following RPC calls: +/// - GetSchema: return the Arrow schema of the query. +/// - GetFlightInfo: execute the catalog metadata request. +/// +/// The returned Arrow schema will be: +/// < +/// catalog_name: utf8 not null +/// > +/// The returned data should be ordered by catalog_name. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CommandGetCatalogs { +} +/// +/// Represents a request to retrieve the list of database schemas on a Flight SQL enabled backend. +/// The definition of a database schema depends on vendor/implementation. It is usually a collection of tables. +/// Used in the command member of FlightDescriptor for the following RPC calls: +/// - GetSchema: return the Arrow schema of the query. +/// - GetFlightInfo: execute the catalog metadata request. +/// +/// The returned Arrow schema will be: +/// < +/// catalog_name: utf8, +/// db_schema_name: utf8 not null +/// > +/// The returned data should be ordered by catalog_name, then db_schema_name. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CommandGetDbSchemas { + /// + /// Specifies the Catalog to search for the tables. + /// An empty string retrieves those without a catalog. + /// If omitted the catalog name should not be used to narrow the search. + #[prost(string, optional, tag="1")] + pub catalog: ::core::option::Option<::prost::alloc::string::String>, + /// + /// Specifies a filter pattern for schemas to search for. + /// When no db_schema_filter_pattern is provided, the pattern will not be used to narrow the search. 
+ /// In the pattern string, two special characters can be used to denote matching rules: + /// - "%" means to match any substring with 0 or more characters. + /// - "_" means to match any one character. + #[prost(string, optional, tag="2")] + pub db_schema_filter_pattern: ::core::option::Option<::prost::alloc::string::String>, +} +/// +/// Represents a request to retrieve the list of tables, and optionally their schemas, on a Flight SQL enabled backend. +/// Used in the command member of FlightDescriptor for the following RPC calls: +/// - GetSchema: return the Arrow schema of the query. +/// - GetFlightInfo: execute the catalog metadata request. +/// +/// The returned Arrow schema will be: +/// < +/// catalog_name: utf8, +/// db_schema_name: utf8, +/// table_name: utf8 not null, +/// table_type: utf8 not null, +/// \[optional\] table_schema: bytes not null (schema of the table as described in Schema.fbs::Schema, +/// it is serialized as an IPC message.) +/// > +/// The returned data should be ordered by catalog_name, db_schema_name, table_name, then table_type, followed by table_schema if requested. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CommandGetTables { + /// + /// Specifies the Catalog to search for the tables. + /// An empty string retrieves those without a catalog. + /// If omitted the catalog name should not be used to narrow the search. + #[prost(string, optional, tag="1")] + pub catalog: ::core::option::Option<::prost::alloc::string::String>, + /// + /// Specifies a filter pattern for schemas to search for. + /// When no db_schema_filter_pattern is provided, all schemas matching other filters are searched. + /// In the pattern string, two special characters can be used to denote matching rules: + /// - "%" means to match any substring with 0 or more characters. + /// - "_" means to match any one character. + #[prost(string, optional, tag="2")] + pub db_schema_filter_pattern: ::core::option::Option<::prost::alloc::string::String>, + /// + /// Specifies a filter pattern for tables to search for. + /// When no table_name_filter_pattern is provided, all tables matching other filters are searched. + /// In the pattern string, two special characters can be used to denote matching rules: + /// - "%" means to match any substring with 0 or more characters. + /// - "_" means to match any one character. + #[prost(string, optional, tag="3")] + pub table_name_filter_pattern: ::core::option::Option<::prost::alloc::string::String>, + /// + /// Specifies a filter of table types which must match. + /// The table types depend on vendor/implementation. It is usually used to separate tables from views or system tables. + /// TABLE, VIEW, and SYSTEM TABLE are commonly supported. + #[prost(string, repeated, tag="4")] + pub table_types: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + /// Specifies if the Arrow schema should be returned for found tables. + #[prost(bool, tag="5")] + pub include_schema: bool, +} +/// +/// Represents a request to retrieve the list of table types on a Flight SQL enabled backend. +/// The table types depend on vendor/implementation. It is usually used to separate tables from views or system tables. +/// TABLE, VIEW, and SYSTEM TABLE are commonly supported. +/// Used in the command member of FlightDescriptor for the following RPC calls: +/// - GetSchema: return the Arrow schema of the query. +/// - GetFlightInfo: execute the catalog metadata request. 
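The same "%" / "_" wildcard semantics apply to the filter patterns of `CommandGetDbSchemas` and `CommandGetTables` above. As a rough sketch (assuming the `flight-sql-experimental` feature and the re-exports added later in this change), a client could build a `CommandGetTables` that lists tables and views in schemas matching "public" and serialize it with prost; wrapping the bytes in a `google.protobuf.Any` and placing them in a `FlightDescriptor` is omitted here:

```rust
use arrow_flight::sql::CommandGetTables;
use prost::Message;

fn main() {
    // Illustrative request: tables and views in schemas named "public",
    // asking the server to include each table's Arrow schema.
    let cmd = CommandGetTables {
        catalog: None,
        db_schema_filter_pattern: Some("public".to_string()),
        table_name_filter_pattern: Some("%".to_string()),
        table_types: vec!["TABLE".to_string(), "VIEW".to_string()],
        include_schema: true,
    };

    // Raw prost encoding of the command; a real client would wrap this in a
    // google.protobuf.Any before placing it in FlightDescriptor::cmd.
    let bytes = cmd.encode_to_vec();
    assert!(!bytes.is_empty());
}
```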
+/// +/// The returned Arrow schema will be: +/// < +/// table_type: utf8 not null +/// > +/// The returned data should be ordered by table_type. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CommandGetTableTypes { +} +/// +/// Represents a request to retrieve the primary keys of a table on a Flight SQL enabled backend. +/// Used in the command member of FlightDescriptor for the following RPC calls: +/// - GetSchema: return the Arrow schema of the query. +/// - GetFlightInfo: execute the catalog metadata request. +/// +/// The returned Arrow schema will be: +/// < +/// catalog_name: utf8, +/// db_schema_name: utf8, +/// table_name: utf8 not null, +/// column_name: utf8 not null, +/// key_name: utf8, +/// key_sequence: int not null +/// > +/// The returned data should be ordered by catalog_name, db_schema_name, table_name, key_name, then key_sequence. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CommandGetPrimaryKeys { + /// + /// Specifies the catalog to search for the table. + /// An empty string retrieves those without a catalog. + /// If omitted the catalog name should not be used to narrow the search. + #[prost(string, optional, tag="1")] + pub catalog: ::core::option::Option<::prost::alloc::string::String>, + /// + /// Specifies the schema to search for the table. + /// An empty string retrieves those without a schema. + /// If omitted the schema name should not be used to narrow the search. + #[prost(string, optional, tag="2")] + pub db_schema: ::core::option::Option<::prost::alloc::string::String>, + /// Specifies the table to get the primary keys for. + #[prost(string, tag="3")] + pub table: ::prost::alloc::string::String, +} +/// +/// Represents a request to retrieve a description of the foreign key columns that reference the given table's +/// primary key columns (the foreign keys exported by a table) of a table on a Flight SQL enabled backend. +/// Used in the command member of FlightDescriptor for the following RPC calls: +/// - GetSchema: return the Arrow schema of the query. +/// - GetFlightInfo: execute the catalog metadata request. +/// +/// The returned Arrow schema will be: +/// < +/// pk_catalog_name: utf8, +/// pk_db_schema_name: utf8, +/// pk_table_name: utf8 not null, +/// pk_column_name: utf8 not null, +/// fk_catalog_name: utf8, +/// fk_db_schema_name: utf8, +/// fk_table_name: utf8 not null, +/// fk_column_name: utf8 not null, +/// key_sequence: int not null, +/// fk_key_name: utf8, +/// pk_key_name: utf8, +/// update_rule: uint1 not null, +/// delete_rule: uint1 not null +/// > +/// The returned data should be ordered by fk_catalog_name, fk_db_schema_name, fk_table_name, fk_key_name, then key_sequence. +/// update_rule and delete_rule returns a byte that is equivalent to actions declared on UpdateDeleteRules enum. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CommandGetExportedKeys { + /// + /// Specifies the catalog to search for the foreign key table. + /// An empty string retrieves those without a catalog. + /// If omitted the catalog name should not be used to narrow the search. + #[prost(string, optional, tag="1")] + pub catalog: ::core::option::Option<::prost::alloc::string::String>, + /// + /// Specifies the schema to search for the foreign key table. + /// An empty string retrieves those without a schema. + /// If omitted the schema name should not be used to narrow the search. 
+ #[prost(string, optional, tag="2")] + pub db_schema: ::core::option::Option<::prost::alloc::string::String>, + /// Specifies the foreign key table to get the foreign keys for. + #[prost(string, tag="3")] + pub table: ::prost::alloc::string::String, +} +/// +/// Represents a request to retrieve the foreign keys of a table on a Flight SQL enabled backend. +/// Used in the command member of FlightDescriptor for the following RPC calls: +/// - GetSchema: return the Arrow schema of the query. +/// - GetFlightInfo: execute the catalog metadata request. +/// +/// The returned Arrow schema will be: +/// < +/// pk_catalog_name: utf8, +/// pk_db_schema_name: utf8, +/// pk_table_name: utf8 not null, +/// pk_column_name: utf8 not null, +/// fk_catalog_name: utf8, +/// fk_db_schema_name: utf8, +/// fk_table_name: utf8 not null, +/// fk_column_name: utf8 not null, +/// key_sequence: int not null, +/// fk_key_name: utf8, +/// pk_key_name: utf8, +/// update_rule: uint1 not null, +/// delete_rule: uint1 not null +/// > +/// The returned data should be ordered by pk_catalog_name, pk_db_schema_name, pk_table_name, pk_key_name, then key_sequence. +/// update_rule and delete_rule returns a byte that is equivalent to actions: +/// - 0 = CASCADE +/// - 1 = RESTRICT +/// - 2 = SET NULL +/// - 3 = NO ACTION +/// - 4 = SET DEFAULT +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CommandGetImportedKeys { + /// + /// Specifies the catalog to search for the primary key table. + /// An empty string retrieves those without a catalog. + /// If omitted the catalog name should not be used to narrow the search. + #[prost(string, optional, tag="1")] + pub catalog: ::core::option::Option<::prost::alloc::string::String>, + /// + /// Specifies the schema to search for the primary key table. + /// An empty string retrieves those without a schema. + /// If omitted the schema name should not be used to narrow the search. + #[prost(string, optional, tag="2")] + pub db_schema: ::core::option::Option<::prost::alloc::string::String>, + /// Specifies the primary key table to get the foreign keys for. + #[prost(string, tag="3")] + pub table: ::prost::alloc::string::String, +} +/// +/// Represents a request to retrieve a description of the foreign key columns in the given foreign key table that +/// reference the primary key or the columns representing a unique constraint of the parent table (could be the same +/// or a different table) on a Flight SQL enabled backend. +/// Used in the command member of FlightDescriptor for the following RPC calls: +/// - GetSchema: return the Arrow schema of the query. +/// - GetFlightInfo: execute the catalog metadata request. +/// +/// The returned Arrow schema will be: +/// < +/// pk_catalog_name: utf8, +/// pk_db_schema_name: utf8, +/// pk_table_name: utf8 not null, +/// pk_column_name: utf8 not null, +/// fk_catalog_name: utf8, +/// fk_db_schema_name: utf8, +/// fk_table_name: utf8 not null, +/// fk_column_name: utf8 not null, +/// key_sequence: int not null, +/// fk_key_name: utf8, +/// pk_key_name: utf8, +/// update_rule: uint1 not null, +/// delete_rule: uint1 not null +/// > +/// The returned data should be ordered by pk_catalog_name, pk_db_schema_name, pk_table_name, pk_key_name, then key_sequence. 
+/// update_rule and delete_rule returns a byte that is equivalent to actions: +/// - 0 = CASCADE +/// - 1 = RESTRICT +/// - 2 = SET NULL +/// - 3 = NO ACTION +/// - 4 = SET DEFAULT +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CommandGetCrossReference { + ///* + /// The catalog name where the parent table is. + /// An empty string retrieves those without a catalog. + /// If omitted the catalog name should not be used to narrow the search. + #[prost(string, optional, tag="1")] + pub pk_catalog: ::core::option::Option<::prost::alloc::string::String>, + ///* + /// The Schema name where the parent table is. + /// An empty string retrieves those without a schema. + /// If omitted the schema name should not be used to narrow the search. + #[prost(string, optional, tag="2")] + pub pk_db_schema: ::core::option::Option<::prost::alloc::string::String>, + ///* + /// The parent table name. It cannot be null. + #[prost(string, tag="3")] + pub pk_table: ::prost::alloc::string::String, + ///* + /// The catalog name where the foreign table is. + /// An empty string retrieves those without a catalog. + /// If omitted the catalog name should not be used to narrow the search. + #[prost(string, optional, tag="4")] + pub fk_catalog: ::core::option::Option<::prost::alloc::string::String>, + ///* + /// The schema name where the foreign table is. + /// An empty string retrieves those without a schema. + /// If omitted the schema name should not be used to narrow the search. + #[prost(string, optional, tag="5")] + pub fk_db_schema: ::core::option::Option<::prost::alloc::string::String>, + ///* + /// The foreign table name. It cannot be null. + #[prost(string, tag="6")] + pub fk_table: ::prost::alloc::string::String, +} +// SQL Execution Action Messages + +/// +/// Request message for the "CreatePreparedStatement" action on a Flight SQL enabled backend. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ActionCreatePreparedStatementRequest { + /// The valid SQL string to create a prepared statement for. + #[prost(string, tag="1")] + pub query: ::prost::alloc::string::String, +} +/// +/// Wrap the result of a "GetPreparedStatement" action. +/// +/// The resultant PreparedStatement can be closed either: +/// - Manually, through the "ClosePreparedStatement" action; +/// - Automatically, by a server timeout. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ActionCreatePreparedStatementResult { + /// Opaque handle for the prepared statement on the server. + #[prost(bytes="vec", tag="1")] + pub prepared_statement_handle: ::prost::alloc::vec::Vec, + /// If a result set generating query was provided, dataset_schema contains the + /// schema of the dataset as described in Schema.fbs::Schema, it is serialized as an IPC message. + #[prost(bytes="vec", tag="2")] + pub dataset_schema: ::prost::alloc::vec::Vec, + /// If the query provided contained parameters, parameter_schema contains the + /// schema of the expected parameters as described in Schema.fbs::Schema, it is serialized as an IPC message. + #[prost(bytes="vec", tag="3")] + pub parameter_schema: ::prost::alloc::vec::Vec, +} +/// +/// Request message for the "ClosePreparedStatement" action on a Flight SQL enabled backend. +/// Closes server resources associated with the prepared statement handle. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ActionClosePreparedStatementRequest { + /// Opaque handle for the prepared statement on the server. 
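Taken together, these action messages describe the prepared-statement lifecycle: create a statement from a SQL string, receive an opaque handle (plus optional dataset and parameter schemas), and later close it. The sketch below only shows the prost-level encoding and decoding; the surrounding `DoAction` call and the `google.protobuf.Any` wrapping required by the Flight SQL specification are deliberately left out, and the server reply is fabricated locally:

```rust
use arrow_flight::sql::{
    ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest,
    ActionCreatePreparedStatementResult,
};
use prost::Message;

fn main() -> Result<(), prost::DecodeError> {
    // Body of a hypothetical "CreatePreparedStatement" action request.
    let create = ActionCreatePreparedStatementRequest {
        query: "SELECT id, name FROM users WHERE id = ?".to_string(),
    };
    let _request_bytes = create.encode_to_vec();

    // Pretend these bytes came back from the server; we fabricate them here
    // purely to demonstrate the decoding step.
    let reply_bytes = ActionCreatePreparedStatementResult {
        prepared_statement_handle: b"handle-1".to_vec(),
        dataset_schema: vec![],
        parameter_schema: vec![],
    }
    .encode_to_vec();
    let result = ActionCreatePreparedStatementResult::decode(reply_bytes.as_slice())?;

    // Closing the statement reuses the opaque handle returned by the server.
    let close = ActionClosePreparedStatementRequest {
        prepared_statement_handle: result.prepared_statement_handle,
    };
    let _close_bytes = close.encode_to_vec();
    Ok(())
}
```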
+ #[prost(bytes="vec", tag="1")] + pub prepared_statement_handle: ::prost::alloc::vec::Vec, +} +// SQL Execution Messages. + +/// +/// Represents a SQL query. Used in the command member of FlightDescriptor +/// for the following RPC calls: +/// - GetSchema: return the Arrow schema of the query. +/// - GetFlightInfo: execute the query. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CommandStatementQuery { + /// The SQL syntax. + #[prost(string, tag="1")] + pub query: ::prost::alloc::string::String, +} +///* +/// Represents a ticket resulting from GetFlightInfo with a CommandStatementQuery. +/// This should be used only once and treated as an opaque value, that is, clients should not attempt to parse this. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct TicketStatementQuery { + /// Unique identifier for the instance of the statement to execute. + #[prost(bytes="vec", tag="1")] + pub statement_handle: ::prost::alloc::vec::Vec, +} +/// +/// Represents an instance of executing a prepared statement. Used in the command member of FlightDescriptor for +/// the following RPC calls: +/// - DoPut: bind parameter values. All of the bound parameter sets will be executed as a single atomic execution. +/// - GetFlightInfo: execute the prepared statement instance. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CommandPreparedStatementQuery { + /// Opaque handle for the prepared statement on the server. + #[prost(bytes="vec", tag="1")] + pub prepared_statement_handle: ::prost::alloc::vec::Vec, +} +/// +/// Represents a SQL update query. Used in the command member of FlightDescriptor +/// for the RPC call DoPut to cause the server to execute the included SQL update. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CommandStatementUpdate { + /// The SQL syntax. + #[prost(string, tag="1")] + pub query: ::prost::alloc::string::String, +} +/// +/// Represents a SQL update query. Used in the command member of FlightDescriptor +/// for the RPC call DoPut to cause the server to execute the included +/// prepared statement handle as an update. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CommandPreparedStatementUpdate { + /// Opaque handle for the prepared statement on the server. + #[prost(bytes="vec", tag="1")] + pub prepared_statement_handle: ::prost::alloc::vec::Vec, +} +/// +/// Returned from the RPC call DoPut when a CommandStatementUpdate +/// or CommandPreparedStatementUpdate was in the request, containing +/// results from the update. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct DoPutUpdateResult { + /// The number of records updated. A return value of -1 represents + /// an unknown updated record count. + #[prost(int64, tag="1")] + pub record_count: i64, +} +/// Options for CommandGetSqlInfo. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum SqlInfo { + // Server Information [0-500): Provides basic information about the Flight SQL Server. + + /// Retrieves a UTF-8 string with the name of the Flight SQL Server. + FlightSqlServerName = 0, + /// Retrieves a UTF-8 string with the native version of the Flight SQL Server. + FlightSqlServerVersion = 1, + /// Retrieves a UTF-8 string with the Arrow format version of the Flight SQL Server. + FlightSqlServerArrowVersion = 2, + /// + /// Retrieves a boolean value indicating whether the Flight SQL Server is read only.
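For updates, the client places a `CommandStatementUpdate` (or `CommandPreparedStatementUpdate`) in the `FlightDescriptor` for `DoPut` and then decodes a `DoPutUpdateResult` out of the response metadata; a `record_count` of `-1` means the server could not determine how many rows were affected. A hedged, message-level sketch (the gRPC call and `Any` wrapping are again omitted, and the reply bytes are fabricated):

```rust
use arrow_flight::sql::{CommandStatementUpdate, DoPutUpdateResult};
use prost::Message;

fn main() -> Result<(), prost::DecodeError> {
    // The update statement a client would place in the DoPut descriptor.
    let update = CommandStatementUpdate {
        query: "UPDATE users SET active = false WHERE last_login < '2020-01-01'"
            .to_string(),
    };
    let _cmd_bytes = update.encode_to_vec();

    // Bytes a server might send back in the DoPut result metadata.
    let reply_bytes = DoPutUpdateResult { record_count: 42 }.encode_to_vec();

    let reply = DoPutUpdateResult::decode(reply_bytes.as_slice())?;
    if reply.record_count < 0 {
        println!("server did not report an affected-row count");
    } else {
        println!("{} rows updated", reply.record_count);
    }
    Ok(())
}
```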
+ /// + /// Returns: + /// - false: if read-write + /// - true: if read only + FlightSqlServerReadOnly = 3, + // SQL Syntax Information [500-1000): provides information about SQL syntax supported by the Flight SQL Server. + + /// + /// Retrieves a boolean value indicating whether the Flight SQL Server supports CREATE and DROP of catalogs. + /// + /// Returns: + /// - false: if it doesn't support CREATE and DROP of catalogs. + /// - true: if it supports CREATE and DROP of catalogs. + SqlDdlCatalog = 500, + /// + /// Retrieves a boolean value indicating whether the Flight SQL Server supports CREATE and DROP of schemas. + /// + /// Returns: + /// - false: if it doesn't support CREATE and DROP of schemas. + /// - true: if it supports CREATE and DROP of schemas. + SqlDdlSchema = 501, + /// + /// Indicates whether the Flight SQL Server supports CREATE and DROP of tables. + /// + /// Returns: + /// - false: if it doesn't support CREATE and DROP of tables. + /// - true: if it supports CREATE and DROP of tables. + SqlDdlTable = 502, + /// + /// Retrieves a uint32 value representing the enum uint32 ordinal for the case sensitivity of catalog, table, schema and table names. + /// + /// The possible values are listed in `arrow.flight.protocol.sql.SqlSupportedCaseSensitivity`. + SqlIdentifierCase = 503, + /// Retrieves a UTF-8 string with the supported character(s) used to surround a delimited identifier. + SqlIdentifierQuoteChar = 504, + /// + /// Retrieves a uint32 value representing the enum uint32 ordinal for the case sensitivity of quoted identifiers. + /// + /// The possible values are listed in `arrow.flight.protocol.sql.SqlSupportedCaseSensitivity`. + SqlQuotedIdentifierCase = 505, + /// + /// Retrieves a boolean value indicating whether all tables are selectable. + /// + /// Returns: + /// - false: if not all tables are selectable or if none are; + /// - true: if all tables are selectable. + SqlAllTablesAreSelectable = 506, + /// + /// Retrieves the null ordering. + /// + /// Returns a uint32 ordinal for the null ordering being used, as described in + /// `arrow.flight.protocol.sql.SqlNullOrdering`. + SqlNullOrdering = 507, + /// Retrieves a UTF-8 string list with values of the supported keywords. + SqlKeywords = 508, + /// Retrieves a UTF-8 string list with values of the supported numeric functions. + SqlNumericFunctions = 509, + /// Retrieves a UTF-8 string list with values of the supported string functions. + SqlStringFunctions = 510, + /// Retrieves a UTF-8 string list with values of the supported system functions. + SqlSystemFunctions = 511, + /// Retrieves a UTF-8 string list with values of the supported datetime functions. + SqlDatetimeFunctions = 512, + /// + /// Retrieves the UTF-8 string that can be used to escape wildcard characters. + /// This is the string that can be used to escape '_' or '%' in the catalog search parameters that are a pattern + /// (and therefore use one of the wildcard characters). + /// The '_' character represents any single character; the '%' character represents any sequence of zero or more + /// characters. + SqlSearchStringEscape = 513, + /// + /// Retrieves a UTF-8 string with all the "extra" characters that can be used in unquoted identifier names + /// (those beyond a-z, A-Z, 0-9 and _). + SqlExtraNameCharacters = 514, + /// + /// Retrieves a boolean value indicating whether column aliasing is supported. + /// If so, the SQL AS clause can be used to provide names for computed columns or to provide alias names for columns + /// as required.
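The `info` field of `CommandGetSqlInfo` carries these `SqlInfo` ordinals as plain `u32` values, and leaving it empty asks for the full metadata set. For instance (illustrative only), a client interested just in the server name and the identifier quote character might build:

```rust
use arrow_flight::sql::{CommandGetSqlInfo, SqlInfo};
use prost::Message;

fn main() {
    // Request two specific pieces of metadata; an empty `info` vector would
    // request everything the server knows about.
    let cmd = CommandGetSqlInfo {
        info: vec![
            SqlInfo::FlightSqlServerName as u32,
            SqlInfo::SqlIdentifierQuoteChar as u32,
        ],
    };
    let _bytes = cmd.encode_to_vec();
}
```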
+ /// + /// Returns: + /// - false: if column aliasing is unsupported; + /// - true: if column aliasing is supported. + SqlSupportsColumnAliasing = 515, + /// + /// Retrieves a boolean value indicating whether concatenations between null and non-null values being + /// null are supported. + /// + /// - Returns: + /// - false: if concatenations between null and non-null values being null are unsupported; + /// - true: if concatenations between null and non-null values being null are supported. + SqlNullPlusNullIsNull = 516, + /// + /// Retrieves a map where the key is the type to convert from and the value is a list with the types to convert to, + /// indicating the supported conversions. Each key and each item on the list value is a value to a predefined type on + /// SqlSupportsConvert enum. + /// The returned map will be: map> + SqlSupportsConvert = 517, + /// + /// Retrieves a boolean value indicating whether, when table correlation names are supported, + /// they are restricted to being different from the names of the tables. + /// + /// Returns: + /// - false: if table correlation names are unsupported; + /// - true: if table correlation names are supported. + SqlSupportsTableCorrelationNames = 518, + /// + /// Retrieves a boolean value indicating whether, when table correlation names are supported, + /// they are restricted to being different from the names of the tables. + /// + /// Returns: + /// - false: if different table correlation names are unsupported; + /// - true: if different table correlation names are supported + SqlSupportsDifferentTableCorrelationNames = 519, + /// + /// Retrieves a boolean value indicating whether expressions in ORDER BY lists are supported. + /// + /// Returns: + /// - false: if expressions in ORDER BY are unsupported; + /// - true: if expressions in ORDER BY are supported; + SqlSupportsExpressionsInOrderBy = 520, + /// + /// Retrieves a boolean value indicating whether using a column that is not in the SELECT statement in a GROUP BY + /// clause is supported. + /// + /// Returns: + /// - false: if using a column that is not in the SELECT statement in a GROUP BY clause is unsupported; + /// - true: if using a column that is not in the SELECT statement in a GROUP BY clause is supported. + SqlSupportsOrderByUnrelated = 521, + /// + /// Retrieves the supported GROUP BY commands; + /// + /// Returns an int32 bitmask value representing the supported commands. + /// The returned bitmask should be parsed in order to retrieve the supported commands. + /// + /// For instance: + /// - return 0 (\b0) => [] (GROUP BY is unsupported); + /// - return 1 (\b1) => \[SQL_GROUP_BY_UNRELATED\]; + /// - return 2 (\b10) => \[SQL_GROUP_BY_BEYOND_SELECT\]; + /// - return 3 (\b11) => [SQL_GROUP_BY_UNRELATED, SQL_GROUP_BY_BEYOND_SELECT]. + /// Valid GROUP BY types are described under `arrow.flight.protocol.sql.SqlSupportedGroupBy`. + SqlSupportedGroupBy = 522, + /// + /// Retrieves a boolean value indicating whether specifying a LIKE escape clause is supported. + /// + /// Returns: + /// - false: if specifying a LIKE escape clause is unsupported; + /// - true: if specifying a LIKE escape clause is supported. + SqlSupportsLikeEscapeClause = 523, + /// + /// Retrieves a boolean value indicating whether columns may be defined as non-nullable. + /// + /// Returns: + /// - false: if columns cannot be defined as non-nullable; + /// - true: if columns may be defined as non-nullable. 
+ SqlSupportsNonNullableColumns = 524, + /// + /// Retrieves the supported SQL grammar level as per the ODBC specification. + /// + /// Returns an int32 bitmask value representing the supported SQL grammar level. + /// The returned bitmask should be parsed in order to retrieve the supported grammar levels. + /// + /// For instance: + /// - return 0 (\b0) => [] (SQL grammar is unsupported); + /// - return 1 (\b1) => \[SQL_MINIMUM_GRAMMAR\]; + /// - return 2 (\b10) => \[SQL_CORE_GRAMMAR\]; + /// - return 3 (\b11) => [SQL_MINIMUM_GRAMMAR, SQL_CORE_GRAMMAR]; + /// - return 4 (\b100) => \[SQL_EXTENDED_GRAMMAR\]; + /// - return 5 (\b101) => [SQL_MINIMUM_GRAMMAR, SQL_EXTENDED_GRAMMAR]; + /// - return 6 (\b110) => [SQL_CORE_GRAMMAR, SQL_EXTENDED_GRAMMAR]; + /// - return 7 (\b111) => [SQL_MINIMUM_GRAMMAR, SQL_CORE_GRAMMAR, SQL_EXTENDED_GRAMMAR]. + /// Valid SQL grammar levels are described under `arrow.flight.protocol.sql.SupportedSqlGrammar`. + SqlSupportedGrammar = 525, + /// + /// Retrieves the supported ANSI92 SQL grammar level. + /// + /// Returns an int32 bitmask value representing the supported ANSI92 SQL grammar level. + /// The returned bitmask should be parsed in order to retrieve the supported commands. + /// + /// For instance: + /// - return 0 (\b0) => [] (ANSI92 SQL grammar is unsupported); + /// - return 1 (\b1) => \[ANSI92_ENTRY_SQL\]; + /// - return 2 (\b10) => \[ANSI92_INTERMEDIATE_SQL\]; + /// - return 3 (\b11) => [ANSI92_ENTRY_SQL, ANSI92_INTERMEDIATE_SQL]; + /// - return 4 (\b100) => \[ANSI92_FULL_SQL\]; + /// - return 5 (\b101) => [ANSI92_ENTRY_SQL, ANSI92_FULL_SQL]; + /// - return 6 (\b110) => [ANSI92_INTERMEDIATE_SQL, ANSI92_FULL_SQL]; + /// - return 7 (\b111) => [ANSI92_ENTRY_SQL, ANSI92_INTERMEDIATE_SQL, ANSI92_FULL_SQL]. + /// Valid ANSI92 SQL grammar levels are described under `arrow.flight.protocol.sql.SupportedAnsi92SqlGrammarLevel`. + SqlAnsi92SupportedLevel = 526, + /// + /// Retrieves a boolean value indicating whether the SQL Integrity Enhancement Facility is supported. + /// + /// Returns: + /// - false: if the SQL Integrity Enhancement Facility is unsupported; + /// - true: if the SQL Integrity Enhancement Facility is supported. + SqlSupportsIntegrityEnhancementFacility = 527, + /// + /// Retrieves the support level for SQL OUTER JOINs. + /// + /// Returns a uint32 ordinal for the SQL ordering being used, as described in + /// `arrow.flight.protocol.sql.SqlOuterJoinsSupportLevel`. + SqlOuterJoinsSupportLevel = 528, + /// Retrieves a UTF-8 string with the preferred term for "schema". + SqlSchemaTerm = 529, + /// Retrieves a UTF-8 string with the preferred term for "procedure". + SqlProcedureTerm = 530, + /// Retrieves a UTF-8 string with the preferred term for "catalog". + SqlCatalogTerm = 531, + /// + /// Retrieves a boolean value indicating whether a catalog appears at the start of a fully qualified table name. + /// + /// - false: if a catalog does not appear at the start of a fully qualified table name; + /// - true: if a catalog appears at the start of a fully qualified table name. + SqlCatalogAtStart = 532, + /// + /// Retrieves the supported actions for a SQL schema. + /// + /// Returns an int32 bitmask value representing the supported actions for a SQL schema. + /// The returned bitmask should be parsed in order to retrieve the supported actions for a SQL schema.
+ /// + /// For instance: + /// - return 0 (\b0) => [] (no supported actions for SQL schema); + /// - return 1 (\b1) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS\]; + /// - return 2 (\b10) => \[SQL_ELEMENT_IN_INDEX_DEFINITIONS\]; + /// - return 3 (\b11) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS]; + /// - return 4 (\b100) => \[SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]; + /// - return 5 (\b101) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; + /// - return 6 (\b110) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; + /// - return 7 (\b111) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]. + /// Valid actions for a SQL schema are described under `arrow.flight.protocol.sql.SqlSupportedElementActions`. + SqlSchemasSupportedActions = 533, + /// + /// Retrieves the supported actions for a SQL catalog. + /// + /// Returns an int32 bitmask value representing the supported actions for a SQL catalog. + /// The returned bitmask should be parsed in order to retrieve the supported actions for a SQL catalog. + /// + /// For instance: + /// - return 0 (\b0) => [] (no supported actions for SQL catalog); + /// - return 1 (\b1) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS\]; + /// - return 2 (\b10) => \[SQL_ELEMENT_IN_INDEX_DEFINITIONS\]; + /// - return 3 (\b11) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS]; + /// - return 4 (\b100) => \[SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]; + /// - return 5 (\b101) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; + /// - return 6 (\b110) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; + /// - return 7 (\b111) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]. + /// Valid actions for a SQL catalog are described under `arrow.flight.protocol.sql.SqlSupportedElementActions`. + SqlCatalogsSupportedActions = 534, + /// + /// Retrieves the supported SQL positioned commands. + /// + /// Returns an int32 bitmask value representing the supported SQL positioned commands. + /// The returned bitmask should be parsed in order to retrieve the supported SQL positioned commands. + /// + /// For instance: + /// - return 0 (\b0) => [] (no supported SQL positioned commands); + /// - return 1 (\b1) => \[SQL_POSITIONED_DELETE\]; + /// - return 2 (\b10) => \[SQL_POSITIONED_UPDATE\]; + /// - return 3 (\b11) => [SQL_POSITIONED_DELETE, SQL_POSITIONED_UPDATE]. + /// Valid SQL positioned commands are described under `arrow.flight.protocol.sql.SqlSupportedPositionedCommands`. + SqlSupportedPositionedCommands = 535, + /// + /// Retrieves a boolean value indicating whether SELECT FOR UPDATE statements are supported. + /// + /// Returns: + /// - false: if SELECT FOR UPDATE statements are unsupported; + /// - true: if SELECT FOR UPDATE statements are supported. + SqlSelectForUpdateSupported = 536, + /// + /// Retrieves a boolean value indicating whether stored procedure calls that use the stored procedure escape syntax + /// are supported. + /// + /// Returns: + /// - false: if stored procedure calls that use the stored procedure escape syntax are unsupported; + /// - true: if stored procedure calls that use the stored procedure escape syntax are supported. + SqlStoredProceduresSupported = 537, + /// + /// Retrieves the supported SQL subqueries. + /// + /// Returns an int32 bitmask value representing the supported SQL subqueries.
+ /// The returned bitmask should be parsed in order to retrieve the supported SQL subqueries. + /// + /// For instance: + /// - return 0 (\b0) => [] (no supported SQL subqueries); + /// - return 1 (\b1) => \[SQL_SUBQUERIES_IN_COMPARISONS\]; + /// - return 2 (\b10) => \[SQL_SUBQUERIES_IN_EXISTS\]; + /// - return 3 (\b11) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS]; + /// - return 4 (\b100) => \[SQL_SUBQUERIES_IN_INS\]; + /// - return 5 (\b101) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_INS]; + /// - return 6 (\b110) => [SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_EXISTS]; + /// - return 7 (\b111) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS]; + /// - return 8 (\b1000) => \[SQL_SUBQUERIES_IN_QUANTIFIEDS\]; + /// - return 9 (\b1001) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + /// - return 10 (\b1010) => [SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + /// - return 11 (\b1011) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + /// - return 12 (\b1100) => [SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + /// - return 13 (\b1101) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + /// - return 14 (\b1110) => [SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + /// - return 15 (\b1111) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + /// - ... + /// Valid SQL subqueries are described under `arrow.flight.protocol.sql.SqlSupportedSubqueries`. + SqlSupportedSubqueries = 538, + /// + /// Retrieves a boolean value indicating whether correlated subqueries are supported. + /// + /// Returns: + /// - false: if correlated subqueries are unsupported; + /// - true: if correlated subqueries are supported. + SqlCorrelatedSubqueriesSupported = 539, + /// + /// Retrieves the supported SQL UNIONs. + /// + /// Returns an int32 bitmask value representing the supported SQL UNIONs. + /// The returned bitmask should be parsed in order to retrieve the supported SQL UNIONs. + /// + /// For instance: + /// - return 0 (\b0) => [] (no supported SQL unions); + /// - return 1 (\b1) => \[SQL_UNION\]; + /// - return 2 (\b10) => \[SQL_UNION_ALL\]; + /// - return 3 (\b11) => [SQL_UNION, SQL_UNION_ALL]. + /// Valid SQL unions are described under `arrow.flight.protocol.sql.SqlSupportedUnions`. + SqlSupportedUnions = 540, + /// Retrieves a uint32 value representing the maximum number of hex characters allowed in an inline binary literal. + SqlMaxBinaryLiteralLength = 541, + /// Retrieves a uint32 value representing the maximum number of characters allowed for a character literal. + SqlMaxCharLiteralLength = 542, + /// Retrieves a uint32 value representing the maximum number of characters allowed for a column name. + SqlMaxColumnNameLength = 543, + /// Retrieves a uint32 value representing the maximum number of columns allowed in a GROUP BY clause. + SqlMaxColumnsInGroupBy = 544, + /// Retrieves a uint32 value representing the maximum number of columns allowed in an index. + SqlMaxColumnsInIndex = 545, + /// Retrieves a uint32 value representing the maximum number of columns allowed in an ORDER BY clause. + SqlMaxColumnsInOrderBy = 546,
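All of the bitmask-valued infos above follow the same convention: bit `1 << ordinal` of the returned int32 is set when the corresponding enum variant is supported. A small helper (the function name is illustrative, not part of this API) shows how a client might test a `SqlSupportedSubqueries` bitmask:

```rust
use arrow_flight::sql::SqlSupportedSubqueries;

/// Illustrative helper: true when `bitmask` has the bit for `kind` set,
/// following the `1 << ordinal` convention described in the docs above.
fn supports_subquery(bitmask: i32, kind: SqlSupportedSubqueries) -> bool {
    (bitmask & (1 << kind as i32)) != 0
}

fn main() {
    // 0b0101 => subqueries in comparisons and in INs are supported,
    // EXISTS and quantified subqueries are not.
    let bitmask = 0b0101;
    assert!(supports_subquery(bitmask, SqlSupportedSubqueries::SqlSubqueriesInComparisons));
    assert!(!supports_subquery(bitmask, SqlSupportedSubqueries::SqlSubqueriesInExists));
    assert!(supports_subquery(bitmask, SqlSupportedSubqueries::SqlSubqueriesInIns));
}
```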
+ /// Retrieves a uint32 value representing the maximum number of columns allowed in a SELECT list. + SqlMaxColumnsInSelect = 547, + /// Retrieves a uint32 value representing the maximum number of columns allowed in a table. + SqlMaxColumnsInTable = 548, + /// Retrieves a uint32 value representing the maximum number of concurrent connections possible. + SqlMaxConnections = 549, + /// Retrieves a uint32 value representing the maximum number of characters allowed in a cursor name. + SqlMaxCursorNameLength = 550, + /// + /// Retrieves a uint32 value representing the maximum number of bytes allowed for an index, + /// including all of the parts of the index. + SqlMaxIndexLength = 551, + /// Retrieves a uint32 value representing the maximum number of characters allowed in a schema name. + SqlDbSchemaNameLength = 552, + /// Retrieves a uint32 value representing the maximum number of characters allowed in a procedure name. + SqlMaxProcedureNameLength = 553, + /// Retrieves a uint32 value representing the maximum number of characters allowed in a catalog name. + SqlMaxCatalogNameLength = 554, + /// Retrieves a uint32 value representing the maximum number of bytes allowed in a single row. + SqlMaxRowSize = 555, + /// + /// Retrieves a boolean indicating whether the return value for the JDBC method getMaxRowSize includes the SQL + /// data types LONGVARCHAR and LONGVARBINARY. + /// + /// Returns: + /// - false: if return value for the JDBC method getMaxRowSize does + /// not include the SQL data types LONGVARCHAR and LONGVARBINARY; + /// - true: if return value for the JDBC method getMaxRowSize includes + /// the SQL data types LONGVARCHAR and LONGVARBINARY. + SqlMaxRowSizeIncludesBlobs = 556, + /// + /// Retrieves a uint32 value representing the maximum number of characters allowed for an SQL statement; + /// a result of 0 (zero) means that there is no limit or the limit is not known. + SqlMaxStatementLength = 557, + /// Retrieves a uint32 value representing the maximum number of active statements that can be open at the same time. + SqlMaxStatements = 558, + /// Retrieves a uint32 value representing the maximum number of characters allowed in a table name. + SqlMaxTableNameLength = 559, + /// Retrieves a uint32 value representing the maximum number of tables allowed in a SELECT statement. + SqlMaxTablesInSelect = 560, + /// Retrieves a uint32 value representing the maximum number of characters allowed in a user name. + SqlMaxUsernameLength = 561, + /// + /// Retrieves this database's default transaction isolation level as described in + /// `arrow.flight.protocol.sql.SqlTransactionIsolationLevel`. + /// + /// Returns a uint32 ordinal for the SQL transaction isolation level. + SqlDefaultTransactionIsolation = 562, + /// + /// Retrieves a boolean value indicating whether transactions are supported. If not, invoking the method commit is a + /// noop, and the isolation level is `arrow.flight.protocol.sql.SqlTransactionIsolationLevel.TRANSACTION_NONE`. + /// + /// Returns: + /// - false: if transactions are unsupported; + /// - true: if transactions are supported. + SqlTransactionsSupported = 563, + /// + /// Retrieves the supported transactions isolation levels. + /// + /// Returns an int32 bitmask value representing the supported transactions isolation levels. + /// The returned bitmask should be parsed in order to retrieve the supported transactions isolation levels.
+ /// + /// For instance: + /// - return 0 (\b0) => [] (no supported SQL transactions isolation levels); + /// - return 1 (\b1) => \[SQL_TRANSACTION_NONE\]; + /// - return 2 (\b10) => \[SQL_TRANSACTION_READ_UNCOMMITTED\]; + /// - return 3 (\b11) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED]; + /// - return 4 (\b100) => \[SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 5 (\b101) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ]; + /// - return 6 (\b110) => [SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; + /// - return 7 (\b111) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; + /// - return 8 (\b1000) => \[SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 9 (\b1001) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ]; + /// - return 10 (\b1010) => [SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; + /// - return 11 (\b1011) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; + /// - return 12 (\b1100) => [SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; + /// - return 13 (\b1101) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; + /// - return 14 (\b1110) => [SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; + /// - return 15 (\b1111) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; + /// - return 16 (\b10000) => \[SQL_TRANSACTION_SERIALIZABLE\]; + /// - ... + /// Valid SQL positioned commands are described under `arrow.flight.protocol.sql.SqlTransactionIsolationLevel`. + SqlSupportedTransactionsIsolationLevels = 564, + /// + /// Retrieves a boolean value indicating whether a data definition statement within a transaction forces + /// the transaction to commit. + /// + /// Returns: + /// - false: if a data definition statement within a transaction does not force the transaction to commit; + /// - true: if a data definition statement within a transaction forces the transaction to commit. + SqlDataDefinitionCausesTransactionCommit = 565, + /// + /// Retrieves a boolean value indicating whether a data definition statement within a transaction is ignored. + /// + /// Returns: + /// - false: if a data definition statement within a transaction is taken into account; + /// - true: a data definition statement within a transaction is ignored. + SqlDataDefinitionsInTransactionsIgnored = 566, + /// + /// Retrieves an int32 bitmask value representing the supported result set types. + /// The returned bitmask should be parsed in order to retrieve the supported result set types. 
+ /// + /// For instance: + /// - return 0 (\b0) => [] (no supported result set types); + /// - return 1 (\b1) => \[SQL_RESULT_SET_TYPE_UNSPECIFIED\]; + /// - return 2 (\b10) => \[SQL_RESULT_SET_TYPE_FORWARD_ONLY\]; + /// - return 3 (\b11) => [SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_FORWARD_ONLY]; + /// - return 4 (\b100) => \[SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE\]; + /// - return 5 (\b101) => [SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE]; + /// - return 6 (\b110) => [SQL_RESULT_SET_TYPE_FORWARD_ONLY, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE]; + /// - return 7 (\b111) => [SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_FORWARD_ONLY, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE]; + /// - return 8 (\b1000) => \[SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE\]; + /// - ... + /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetType`. + SqlSupportedResultSetTypes = 567, + /// + /// Returns an int32 bitmask value concurrency types supported for + /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_UNSPECIFIED`. + /// + /// For instance: + /// - return 0 (\b0) => [] (no supported concurrency types for this result set type) + /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] + /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] + /// - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + /// - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + /// - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. + SqlSupportedConcurrenciesForResultSetUnspecified = 568, + /// + /// Returns an int32 bitmask value concurrency types supported for + /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_FORWARD_ONLY`. + /// + /// For instance: + /// - return 0 (\b0) => [] (no supported concurrency types for this result set type) + /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] + /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] + /// - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + /// - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + /// - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. + SqlSupportedConcurrenciesForResultSetForwardOnly = 569, + /// + /// Returns an int32 bitmask value concurrency types supported for + /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE`. 
+ /// + /// For instance: + /// - return 0 (\b0) => [] (no supported concurrency types for this result set type) + /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] + /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] + /// - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + /// - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + /// - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. + SqlSupportedConcurrenciesForResultSetScrollSensitive = 570, + /// + /// Returns an int32 bitmask value concurrency types supported for + /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE`. + /// + /// For instance: + /// - return 0 (\b0) => [] (no supported concurrency types for this result set type) + /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] + /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] + /// - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + /// - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + /// - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. + SqlSupportedConcurrenciesForResultSetScrollInsensitive = 571, + /// + /// Retrieves a boolean value indicating whether this database supports batch updates. + /// + /// - false: if this database does not support batch updates; + /// - true: if this database supports batch updates. + SqlBatchUpdatesSupported = 572, + /// + /// Retrieves a boolean value indicating whether this database supports savepoints. + /// + /// Returns: + /// - false: if this database does not support savepoints; + /// - true: if this database supports savepoints. + SqlSavepointsSupported = 573, + /// + /// Retrieves a boolean value indicating whether named parameters are supported in callable statements. + /// + /// Returns: + /// - false: if named parameters in callable statements are unsupported; + /// - true: if named parameters in callable statements are supported. + SqlNamedParametersSupported = 574, + /// + /// Retrieves a boolean value indicating whether updates made to a LOB are made on a copy or directly to the LOB. + /// + /// Returns: + /// - false: if updates made to a LOB are made directly to the LOB; + /// - true: if updates made to a LOB are made on a copy. + SqlLocatorsUpdateCopy = 575, + /// + /// Retrieves a boolean value indicating whether invoking user-defined or vendor functions + /// using the stored procedure escape syntax is supported. 
+ /// + /// Returns: + /// - false: if invoking user-defined or vendor functions using the stored procedure escape syntax is unsupported; + /// - true: if invoking user-defined or vendor functions using the stored procedure escape syntax is supported. + SqlStoredFunctionsUsingCallSyntaxSupported = 576, +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum SqlSupportedCaseSensitivity { + SqlCaseSensitivityUnknown = 0, + SqlCaseSensitivityCaseInsensitive = 1, + SqlCaseSensitivityUppercase = 2, + SqlCaseSensitivityLowercase = 3, +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum SqlNullOrdering { + SqlNullsSortedHigh = 0, + SqlNullsSortedLow = 1, + SqlNullsSortedAtStart = 2, + SqlNullsSortedAtEnd = 3, +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum SupportedSqlGrammar { + SqlMinimumGrammar = 0, + SqlCoreGrammar = 1, + SqlExtendedGrammar = 2, +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum SupportedAnsi92SqlGrammarLevel { + Ansi92EntrySql = 0, + Ansi92IntermediateSql = 1, + Ansi92FullSql = 2, +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum SqlOuterJoinsSupportLevel { + SqlJoinsUnsupported = 0, + SqlLimitedOuterJoins = 1, + SqlFullOuterJoins = 2, +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum SqlSupportedGroupBy { + SqlGroupByUnrelated = 0, + SqlGroupByBeyondSelect = 1, +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum SqlSupportedElementActions { + SqlElementInProcedureCalls = 0, + SqlElementInIndexDefinitions = 1, + SqlElementInPrivilegeDefinitions = 2, +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum SqlSupportedPositionedCommands { + SqlPositionedDelete = 0, + SqlPositionedUpdate = 1, +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum SqlSupportedSubqueries { + SqlSubqueriesInComparisons = 0, + SqlSubqueriesInExists = 1, + SqlSubqueriesInIns = 2, + SqlSubqueriesInQuantifieds = 3, +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum SqlSupportedUnions { + SqlUnion = 0, + SqlUnionAll = 1, +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum SqlTransactionIsolationLevel { + SqlTransactionNone = 0, + SqlTransactionReadUncommitted = 1, + SqlTransactionReadCommitted = 2, + SqlTransactionRepeatableRead = 3, + SqlTransactionSerializable = 4, +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum SqlSupportedTransactions { + SqlTransactionUnspecified = 0, + SqlDataDefinitionTransactions = 1, + SqlDataManipulationTransactions = 2, +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum SqlSupportedResultSetType { + SqlResultSetTypeUnspecified = 0, + SqlResultSetTypeForwardOnly = 1, + SqlResultSetTypeScrollInsensitive = 2, + SqlResultSetTypeScrollSensitive = 3, +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, 
::prost::Enumeration)] +#[repr(i32)] +pub enum SqlSupportedResultSetConcurrency { + SqlResultSetConcurrencyUnspecified = 0, + SqlResultSetConcurrencyReadOnly = 1, + SqlResultSetConcurrencyUpdatable = 2, +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum SqlSupportsConvert { + SqlConvertBigint = 0, + SqlConvertBinary = 1, + SqlConvertBit = 2, + SqlConvertChar = 3, + SqlConvertDate = 4, + SqlConvertDecimal = 5, + SqlConvertFloat = 6, + SqlConvertInteger = 7, + SqlConvertIntervalDayTime = 8, + SqlConvertIntervalYearMonth = 9, + SqlConvertLongvarbinary = 10, + SqlConvertLongvarchar = 11, + SqlConvertNumeric = 12, + SqlConvertReal = 13, + SqlConvertSmallint = 14, + SqlConvertTime = 15, + SqlConvertTimestamp = 16, + SqlConvertTinyint = 17, + SqlConvertVarbinary = 18, + SqlConvertVarchar = 19, +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum UpdateDeleteRules { + Cascade = 0, + Restrict = 1, + SetNull = 2, + NoAction = 3, + SetDefault = 4, +} diff --git a/arrow-flight/src/sql/mod.rs b/arrow-flight/src/sql/mod.rs new file mode 100644 index 000000000000..cd198a1401d1 --- /dev/null +++ b/arrow-flight/src/sql/mod.rs @@ -0,0 +1,178 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
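The `SqlInfo` values 570/571 above describe their result as an int32 bitmask over `SqlSupportedResultSetConcurrency`. As a standalone illustration (not part of this patch; the string constants simply mirror the protocol names in the doc comments above), a client could decode such a bitmask like this:

```rust
/// Decode the bitmask documented for the
/// SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_* SqlInfo values.
/// Bit 0 => UNSPECIFIED, bit 1 => READ_ONLY, bit 2 => UPDATABLE.
fn supported_concurrencies(bitmask: i32) -> Vec<&'static str> {
    let all = [
        "SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED",
        "SQL_RESULT_SET_CONCURRENCY_READ_ONLY",
        "SQL_RESULT_SET_CONCURRENCY_UPDATABLE",
    ];
    let mut out = Vec::new();
    for (bit, name) in all.iter().enumerate() {
        if bitmask & (1 << bit) != 0 {
            out.push(*name);
        }
    }
    out
}

fn main() {
    // 6 (0b110) => READ_ONLY and UPDATABLE, matching the table in the doc comments.
    assert_eq!(
        supported_concurrencies(0b110),
        [
            "SQL_RESULT_SET_CONCURRENCY_READ_ONLY",
            "SQL_RESULT_SET_CONCURRENCY_UPDATABLE"
        ]
    );
}
```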
+ +use arrow::error::{ArrowError, Result as ArrowResult}; +use prost::Message; + +mod gen { + #![allow(clippy::all)] + include!("arrow.flight.protocol.sql.rs"); +} + +pub use gen::ActionClosePreparedStatementRequest; +pub use gen::ActionCreatePreparedStatementRequest; +pub use gen::ActionCreatePreparedStatementResult; +pub use gen::CommandGetCatalogs; +pub use gen::CommandGetCrossReference; +pub use gen::CommandGetDbSchemas; +pub use gen::CommandGetExportedKeys; +pub use gen::CommandGetImportedKeys; +pub use gen::CommandGetPrimaryKeys; +pub use gen::CommandGetSqlInfo; +pub use gen::CommandGetTableTypes; +pub use gen::CommandGetTables; +pub use gen::CommandPreparedStatementQuery; +pub use gen::CommandPreparedStatementUpdate; +pub use gen::CommandStatementQuery; +pub use gen::CommandStatementUpdate; +pub use gen::DoPutUpdateResult; +pub use gen::SqlInfo; +pub use gen::SqlNullOrdering; +pub use gen::SqlOuterJoinsSupportLevel; +pub use gen::SqlSupportedCaseSensitivity; +pub use gen::SqlSupportedElementActions; +pub use gen::SqlSupportedGroupBy; +pub use gen::SqlSupportedPositionedCommands; +pub use gen::SqlSupportedResultSetConcurrency; +pub use gen::SqlSupportedResultSetType; +pub use gen::SqlSupportedSubqueries; +pub use gen::SqlSupportedTransactions; +pub use gen::SqlSupportedUnions; +pub use gen::SqlSupportsConvert; +pub use gen::SqlTransactionIsolationLevel; +pub use gen::SupportedSqlGrammar; +pub use gen::TicketStatementQuery; +pub use gen::UpdateDeleteRules; + +pub mod server; + +/// ProstMessageExt are useful utility methods for prost::Message types +pub trait ProstMessageExt: prost::Message + Default { + /// type_url for this Message + fn type_url() -> &'static str; + + /// Convert this Message to prost_types::Any + fn as_any(&self) -> prost_types::Any; +} + +macro_rules! prost_message_ext { + ($($name:ty,)*) => { + $( + impl ProstMessageExt for $name { + fn type_url() -> &'static str { + concat!("type.googleapis.com/arrow.flight.protocol.sql.", stringify!($name)) + } + + fn as_any(&self) -> prost_types::Any { + prost_types::Any { + type_url: <$name>::type_url().to_string(), + value: self.encode_to_vec(), + } + } + } + )* + }; +} + +// Implement ProstMessageExt for all structs defined in FlightSql.proto +prost_message_ext!( + ActionClosePreparedStatementRequest, + ActionCreatePreparedStatementRequest, + ActionCreatePreparedStatementResult, + CommandGetCatalogs, + CommandGetCrossReference, + CommandGetDbSchemas, + CommandGetExportedKeys, + CommandGetImportedKeys, + CommandGetPrimaryKeys, + CommandGetSqlInfo, + CommandGetTableTypes, + CommandGetTables, + CommandPreparedStatementQuery, + CommandPreparedStatementUpdate, + CommandStatementQuery, + CommandStatementUpdate, + DoPutUpdateResult, + TicketStatementQuery, +); + +/// ProstAnyExt are useful utility methods for prost_types::Any +/// The API design is inspired by [rust-protobuf](https://github.com/stepancheg/rust-protobuf/blob/master/protobuf/src/well_known_types_util/any.rs) +pub trait ProstAnyExt { + /// Check if `Any` contains a message of given type. + fn is(&self) -> bool; + + /// Extract a message from this `Any`. + /// + /// # Returns + /// + /// * `Ok(None)` when message type mismatch + /// * `Err` when parse failed + fn unpack(&self) -> ArrowResult>; + + /// Pack any message into `prost_types::Any` value. 
+ fn pack(message: &M) -> ArrowResult; +} + +impl ProstAnyExt for prost_types::Any { + fn is(&self) -> bool { + M::type_url() == self.type_url + } + + fn unpack(&self) -> ArrowResult> { + if !self.is::() { + return Ok(None); + } + let m = prost::Message::decode(&*self.value).map_err(|err| { + ArrowError::ParseError(format!("Unable to decode Any value: {}", err)) + })?; + Ok(Some(m)) + } + + fn pack(message: &M) -> ArrowResult { + Ok(message.as_any()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_type_url() { + assert_eq!( + TicketStatementQuery::type_url(), + "type.googleapis.com/arrow.flight.protocol.sql.TicketStatementQuery" + ); + assert_eq!( + CommandStatementQuery::type_url(), + "type.googleapis.com/arrow.flight.protocol.sql.CommandStatementQuery" + ); + } + + #[test] + fn test_prost_any_pack_unpack() -> ArrowResult<()> { + let query = CommandStatementQuery { + query: "select 1".to_string(), + }; + let any = prost_types::Any::pack(&query)?; + assert!(any.is::()); + let unpack_query: CommandStatementQuery = any.unpack()?.unwrap(); + assert_eq!(query, unpack_query); + Ok(()) + } +} diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs new file mode 100644 index 000000000000..87e282b103b7 --- /dev/null +++ b/arrow-flight/src/sql/server.rs @@ -0,0 +1,658 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::pin::Pin; + +use futures::Stream; +use prost::Message; +use tonic::{Request, Response, Status, Streaming}; + +use super::{ + super::{ + flight_service_server::FlightService, Action, ActionType, Criteria, Empty, + FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, + PutResult, SchemaResult, Ticket, + }, + ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest, + ActionCreatePreparedStatementResult, CommandGetCatalogs, CommandGetCrossReference, + CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, + CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, + CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery, + CommandStatementUpdate, DoPutUpdateResult, ProstAnyExt, ProstMessageExt, SqlInfo, + TicketStatementQuery, +}; + +static CREATE_PREPARED_STATEMENT: &str = "CreatePreparedStatement"; +static CLOSE_PREPARED_STATEMENT: &str = "ClosePreparedStatement"; + +/// Implements FlightSqlService to handle the flight sql protocol +#[tonic::async_trait] +pub trait FlightSqlService: + std::marker::Sync + std::marker::Send + std::marker::Sized + 'static +{ + /// When impl FlightSqlService, you can always set FlightService to Self + type FlightService: FlightService; + + /// Get a FlightInfo for executing a SQL query. 
+ async fn get_flight_info_statement( + &self, + query: CommandStatementQuery, + request: FlightDescriptor, + ) -> Result, Status>; + + /// Get a FlightInfo for executing an already created prepared statement. + async fn get_flight_info_prepared_statement( + &self, + query: CommandPreparedStatementQuery, + request: FlightDescriptor, + ) -> Result, Status>; + + /// Get a FlightInfo for listing catalogs. + async fn get_flight_info_catalogs( + &self, + query: CommandGetCatalogs, + request: FlightDescriptor, + ) -> Result, Status>; + + /// Get a FlightInfo for listing schemas. + async fn get_flight_info_schemas( + &self, + query: CommandGetDbSchemas, + request: FlightDescriptor, + ) -> Result, Status>; + + /// Get a FlightInfo for listing tables. + async fn get_flight_info_tables( + &self, + query: CommandGetTables, + request: FlightDescriptor, + ) -> Result, Status>; + + /// Get a FlightInfo to extract information about the table types. + async fn get_flight_info_table_types( + &self, + query: CommandGetTableTypes, + request: FlightDescriptor, + ) -> Result, Status>; + + /// Get a FlightInfo for retrieving other information (See SqlInfo). + async fn get_flight_info_sql_info( + &self, + query: CommandGetSqlInfo, + request: FlightDescriptor, + ) -> Result, Status>; + + /// Get a FlightInfo to extract information about primary and foreign keys. + async fn get_flight_info_primary_keys( + &self, + query: CommandGetPrimaryKeys, + request: FlightDescriptor, + ) -> Result, Status>; + + /// Get a FlightInfo to extract information about exported keys. + async fn get_flight_info_exported_keys( + &self, + query: CommandGetExportedKeys, + request: FlightDescriptor, + ) -> Result, Status>; + + /// Get a FlightInfo to extract information about imported keys. + async fn get_flight_info_imported_keys( + &self, + query: CommandGetImportedKeys, + request: FlightDescriptor, + ) -> Result, Status>; + + /// Get a FlightInfo to extract information about cross reference. + async fn get_flight_info_cross_reference( + &self, + query: CommandGetCrossReference, + request: FlightDescriptor, + ) -> Result, Status>; + + // do_get + + /// Get a FlightDataStream containing the query results. + async fn do_get_statement( + &self, + ticket: TicketStatementQuery, + ) -> Result::DoGetStream>, Status>; + + /// Get a FlightDataStream containing the prepared statement query results. + async fn do_get_prepared_statement( + &self, + query: CommandPreparedStatementQuery, + ) -> Result::DoGetStream>, Status>; + + /// Get a FlightDataStream containing the list of catalogs. + async fn do_get_catalogs( + &self, + query: CommandGetCatalogs, + ) -> Result::DoGetStream>, Status>; + + /// Get a FlightDataStream containing the list of schemas. + async fn do_get_schemas( + &self, + query: CommandGetDbSchemas, + ) -> Result::DoGetStream>, Status>; + + /// Get a FlightDataStream containing the list of tables. + async fn do_get_tables( + &self, + query: CommandGetTables, + ) -> Result::DoGetStream>, Status>; + + /// Get a FlightDataStream containing the data related to the table types. + async fn do_get_table_types( + &self, + query: CommandGetTableTypes, + ) -> Result::DoGetStream>, Status>; + + /// Get a FlightDataStream containing the list of SqlInfo results. + async fn do_get_sql_info( + &self, + query: CommandGetSqlInfo, + ) -> Result::DoGetStream>, Status>; + + /// Get a FlightDataStream containing the data related to the primary and foreign keys. 
+ async fn do_get_primary_keys( + &self, + query: CommandGetPrimaryKeys, + ) -> Result::DoGetStream>, Status>; + + /// Get a FlightDataStream containing the data related to the exported keys. + async fn do_get_exported_keys( + &self, + query: CommandGetExportedKeys, + ) -> Result::DoGetStream>, Status>; + + /// Get a FlightDataStream containing the data related to the imported keys. + async fn do_get_imported_keys( + &self, + query: CommandGetImportedKeys, + ) -> Result::DoGetStream>, Status>; + + /// Get a FlightDataStream containing the data related to the cross reference. + async fn do_get_cross_reference( + &self, + query: CommandGetCrossReference, + ) -> Result::DoGetStream>, Status>; + + // do_put + + /// Execute an update SQL statement. + async fn do_put_statement_update( + &self, + ticket: CommandStatementUpdate, + ) -> Result; + + /// Bind parameters to given prepared statement. + async fn do_put_prepared_statement_query( + &self, + query: CommandPreparedStatementQuery, + request: Streaming, + ) -> Result::DoPutStream>, Status>; + + /// Execute an update SQL prepared statement. + async fn do_put_prepared_statement_update( + &self, + query: CommandPreparedStatementUpdate, + request: Streaming, + ) -> Result; + + // do_action + + /// Create a prepared statement from given SQL statement. + async fn do_action_create_prepared_statement( + &self, + query: ActionCreatePreparedStatementRequest, + ) -> Result; + + /// Close a prepared statement. + async fn do_action_close_prepared_statement( + &self, + query: ActionClosePreparedStatementRequest, + ); + + /// Register a new SqlInfo result, making it available when calling GetSqlInfo. + async fn register_sql_info(&self, id: i32, result: &SqlInfo); +} + +/// Implements the lower level interface to handle FlightSQL +#[tonic::async_trait] +impl FlightService for T +where + T: FlightSqlService + std::marker::Send, +{ + type HandshakeStream = + Pin> + Send + 'static>>; + type ListFlightsStream = + Pin> + Send + 'static>>; + type DoGetStream = + Pin> + Send + 'static>>; + type DoPutStream = + Pin> + Send + 'static>>; + type DoActionStream = Pin< + Box> + Send + 'static>, + >; + type ListActionsStream = + Pin> + Send + 'static>>; + type DoExchangeStream = + Pin> + Send + 'static>>; + + async fn handshake( + &self, + _request: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn list_flights( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn get_flight_info( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + let any: prost_types::Any = + prost::Message::decode(&*request.cmd).map_err(decode_error_to_status)?; + + if any.is::() { + return self + .get_flight_info_statement( + any.unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"), + request, + ) + .await; + } + if any.is::() { + return self + .get_flight_info_prepared_statement( + any.unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"), + request, + ) + .await; + } + if any.is::() { + return self + .get_flight_info_catalogs( + any.unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"), + request, + ) + .await; + } + if any.is::() { + return self + .get_flight_info_schemas( + any.unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"), + request, + ) + .await; + } + if any.is::() { + return self + .get_flight_info_tables( + any.unpack() + .map_err(arrow_error_to_status)? 
+ .expect("unreachable"), + request, + ) + .await; + } + if any.is::() { + return self + .get_flight_info_table_types( + any.unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"), + request, + ) + .await; + } + if any.is::() { + return self + .get_flight_info_sql_info( + any.unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"), + request, + ) + .await; + } + if any.is::() { + return self + .get_flight_info_primary_keys( + any.unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"), + request, + ) + .await; + } + if any.is::() { + return self + .get_flight_info_exported_keys( + any.unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"), + request, + ) + .await; + } + if any.is::() { + return self + .get_flight_info_imported_keys( + any.unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"), + request, + ) + .await; + } + if any.is::() { + return self + .get_flight_info_cross_reference( + any.unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"), + request, + ) + .await; + } + + Err(Status::unimplemented(format!( + "get_flight_info: The defined request is invalid: {:?}", + String::from_utf8(any.encode_to_vec()).unwrap() + ))) + } + + async fn get_schema( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_get( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + let any: prost_types::Any = + prost::Message::decode(&*request.ticket).map_err(decode_error_to_status)?; + + if any.is::() { + return self + .do_get_statement( + any.unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"), + ) + .await; + } + if any.is::() { + return self + .do_get_prepared_statement( + any.unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"), + ) + .await; + } + if any.is::() { + return self + .do_get_catalogs( + any.unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"), + ) + .await; + } + if any.is::() { + return self + .do_get_schemas( + any.unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"), + ) + .await; + } + if any.is::() { + return self + .do_get_tables( + any.unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"), + ) + .await; + } + if any.is::() { + return self + .do_get_table_types( + any.unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"), + ) + .await; + } + if any.is::() { + return self + .do_get_sql_info( + any.unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"), + ) + .await; + } + if any.is::() { + return self + .do_get_primary_keys( + any.unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"), + ) + .await; + } + if any.is::() { + return self + .do_get_exported_keys( + any.unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"), + ) + .await; + } + if any.is::() { + return self + .do_get_imported_keys( + any.unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"), + ) + .await; + } + if any.is::() { + return self + .do_get_cross_reference( + any.unpack() + .map_err(arrow_error_to_status)? 
+ .expect("unreachable"), + ) + .await; + } + + Err(Status::unimplemented(format!( + "do_get: The defined request is invalid: {:?}", + String::from_utf8(request.ticket).unwrap() + ))) + } + + async fn do_put( + &self, + request: Request>, + ) -> Result, Status> { + let mut request = request.into_inner(); + let cmd = request.message().await?.unwrap(); + let any: prost_types::Any = + prost::Message::decode(&*cmd.flight_descriptor.unwrap().cmd) + .map_err(decode_error_to_status)?; + if any.is::() { + let record_count = self + .do_put_statement_update( + any.unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"), + ) + .await?; + let result = DoPutUpdateResult { record_count }; + let output = futures::stream::iter(vec![Ok(super::super::gen::PutResult { + app_metadata: result.as_any().encode_to_vec(), + })]); + return Ok(Response::new(Box::pin(output))); + } + if any.is::() { + return self + .do_put_prepared_statement_query( + any.unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"), + request, + ) + .await; + } + if any.is::() { + let record_count = self + .do_put_prepared_statement_update( + any.unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"), + request, + ) + .await?; + let result = DoPutUpdateResult { record_count }; + let output = futures::stream::iter(vec![Ok(super::super::gen::PutResult { + app_metadata: result.as_any().encode_to_vec(), + })]); + return Ok(Response::new(Box::pin(output))); + } + + Err(Status::invalid_argument(format!( + "do_put: The defined request is invalid: {:?}", + String::from_utf8(any.encode_to_vec()).unwrap() + ))) + } + + async fn list_actions( + &self, + _request: Request, + ) -> Result, Status> { + let create_prepared_statement_action_type = ActionType { + r#type: CREATE_PREPARED_STATEMENT.to_string(), + description: "Creates a reusable prepared statement resource on the server.\n + Request Message: ActionCreatePreparedStatementRequest\n + Response Message: ActionCreatePreparedStatementResult" + .into(), + }; + let close_prepared_statement_action_type = ActionType { + r#type: CLOSE_PREPARED_STATEMENT.to_string(), + description: "Closes a reusable prepared statement resource on the server.\n + Request Message: ActionClosePreparedStatementRequest\n + Response Message: N/A" + .into(), + }; + let actions: Vec> = vec![ + Ok(create_prepared_statement_action_type), + Ok(close_prepared_statement_action_type), + ]; + let output = futures::stream::iter(actions); + Ok(Response::new(Box::pin(output) as Self::ListActionsStream)) + } + + async fn do_action( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + + if request.r#type == CREATE_PREPARED_STATEMENT { + let any: prost_types::Any = + prost::Message::decode(&*request.body).map_err(decode_error_to_status)?; + + let cmd: ActionCreatePreparedStatementRequest = any + .unpack() + .map_err(arrow_error_to_status)? + .ok_or_else(|| { + Status::invalid_argument( + "Unable to unpack ActionCreatePreparedStatementRequest.", + ) + })?; + let stmt = self.do_action_create_prepared_statement(cmd).await?; + let output = futures::stream::iter(vec![Ok(super::super::gen::Result { + body: stmt.as_any().encode_to_vec(), + })]); + return Ok(Response::new(Box::pin(output))); + } + if request.r#type == CLOSE_PREPARED_STATEMENT { + let any: prost_types::Any = + prost::Message::decode(&*request.body).map_err(decode_error_to_status)?; + + let cmd: ActionClosePreparedStatementRequest = any + .unpack() + .map_err(arrow_error_to_status)? 
+ .ok_or_else(|| { + Status::invalid_argument( + "Unable to unpack ActionClosePreparedStatementRequest.", + ) + })?; + self.do_action_close_prepared_statement(cmd).await; + return Ok(Response::new(Box::pin(futures::stream::empty()))); + } + + Err(Status::invalid_argument(format!( + "do_action: The defined request is invalid: {:?}", + request.r#type + ))) + } + + async fn do_exchange( + &self, + _request: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } +} + +fn decode_error_to_status(err: prost::DecodeError) -> tonic::Status { + tonic::Status::invalid_argument(format!("{:?}", err)) +} + +fn arrow_error_to_status(err: arrow::error::ArrowError) -> tonic::Status { + tonic::Status::internal(format!("{:?}", err)) +} diff --git a/arrow-flight/src/utils.rs b/arrow-flight/src/utils.rs index 9232acc2aefb..dda3fc7fe3db 100644 --- a/arrow-flight/src/utils.rs +++ b/arrow-flight/src/utils.rs @@ -18,6 +18,7 @@ //! Utilities to assist with reading and writing Arrow data as Flight messages use crate::{FlightData, IpcMessage, SchemaAsIpc, SchemaResult}; +use std::collections::HashMap; use arrow::array::ArrayRef; use arrow::datatypes::{Schema, SchemaRef}; @@ -49,7 +50,7 @@ pub fn flight_data_from_arrow_batch( pub fn flight_data_to_arrow_batch( data: &FlightData, schema: SchemaRef, - dictionaries_by_field: &[Option], + dictionaries_by_id: &HashMap, ) -> Result { // check that the data_header is a record batch message let message = arrow::ipc::root_as_message(&data.data_header[..]).map_err(|err| { @@ -68,7 +69,9 @@ pub fn flight_data_to_arrow_batch( &data.data_body, batch, schema, - dictionaries_by_field, + dictionaries_by_id, + None, + &message.version(), ) })? } diff --git a/arrow-pyarrow-integration-testing/Cargo.toml b/arrow-pyarrow-integration-testing/Cargo.toml index c0c20f6d2889..58ba726091c8 100644 --- a/arrow-pyarrow-integration-testing/Cargo.toml +++ b/arrow-pyarrow-integration-testing/Cargo.toml @@ -18,21 +18,22 @@ [package] name = "arrow-pyarrow-integration-testing" description = "" -version = "7.0.0-SNAPSHOT" +version = "16.0.0" homepage = "https://github.com/apache/arrow-rs" repository = "https://github.com/apache/arrow-rs" authors = ["Apache Arrow "] license = "Apache-2.0" keywords = [ "arrow" ] -edition = "2018" +edition = "2021" +rust-version = "1.57" [lib] name = "arrow_pyarrow_integration_testing" crate-type = ["cdylib"] [dependencies] -arrow = { path = "../arrow", version = "7.0.0-SNAPSHOT", features = ["pyarrow"] } -pyo3 = { version = "0.14", features = ["extension-module"] } +arrow = { path = "../arrow", version = "16.0.0", features = ["pyarrow"] } +pyo3 = { version = "0.16", features = ["extension-module"] } [package.metadata.maturin] requires-dist = ["pyarrow>=1"] diff --git a/arrow-pyarrow-integration-testing/README.md b/arrow-pyarrow-integration-testing/README.md index 7e78aa9ec70e..e63953ad7900 100644 --- a/arrow-pyarrow-integration-testing/README.md +++ b/arrow-pyarrow-integration-testing/README.md @@ -17,7 +17,7 @@ under the License. --> -# Arrow c integration +# Arrow + PyArrow integration testing This is a Rust crate that tests compatibility between Rust's Arrow implementation and PyArrow. @@ -45,7 +45,7 @@ we can use pyarrow's interface to move pointers from and to Rust. 
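For example, once the module has been built with maturin as described below, a round trip from Python to Rust and back might look like this (a minimal sketch; the array contents are illustrative):

```python
import pyarrow as pa
# the compiled extension module produced by this crate
import arrow_pyarrow_integration_testing as rust

a = pa.array([1, None, 3], type=pa.int64())
# Python -> Rust (via C data interface pointers) -> Python
b = rust.round_trip_array(a)
assert a == b
```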
```bash # prepare development environment (used to build wheel / install in development) python -m venv venv -venv/bin/pip install maturin==0.8.2 toml==0.10.1 pyarrow==1.0.0 +venv/bin/pip install maturin toml pytest pytz pyarrow>=5.0 ``` Whenever rust code changes (your changes or via git pull): @@ -53,5 +53,5 @@ Whenever rust code changes (your changes or via git pull): ```bash source venv/bin/activate maturin develop -python -m unittest discover tests +pytest -v . ``` diff --git a/arrow-pyarrow-integration-testing/src/lib.rs b/arrow-pyarrow-integration-testing/src/lib.rs index 082a72e9e1ff..086b21834657 100644 --- a/arrow-pyarrow-integration-testing/src/lib.rs +++ b/arrow-pyarrow-integration-testing/src/lib.rs @@ -27,6 +27,7 @@ use arrow::array::{ArrayData, ArrayRef, Int64Array}; use arrow::compute::kernels; use arrow::datatypes::{DataType, Field, Schema}; use arrow::error::ArrowError; +use arrow::ffi_stream::ArrowArrayStreamReader; use arrow::pyarrow::PyArrowConvert; use arrow::record_batch::RecordBatch; @@ -70,7 +71,7 @@ fn substring(array: ArrayData, start: i64) -> PyResult { let array = ArrayRef::from(array); // substring - let array = kernels::substring::substring(array.as_ref(), start, &None)?; + let array = kernels::substring::substring(array.as_ref(), start, None)?; Ok(array.data().to_owned()) } @@ -111,6 +112,13 @@ fn round_trip_record_batch(obj: RecordBatch) -> PyResult { Ok(obj) } +#[pyfunction] +fn round_trip_record_batch_reader( + obj: ArrowArrayStreamReader, +) -> PyResult { + Ok(obj) +} + #[pymodule] fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(double))?; @@ -122,5 +130,6 @@ fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()> m.add_wrapped(wrap_pyfunction!(round_trip_schema))?; m.add_wrapped(wrap_pyfunction!(round_trip_array))?; m.add_wrapped(wrap_pyfunction!(round_trip_record_batch))?; + m.add_wrapped(wrap_pyfunction!(round_trip_record_batch_reader))?; Ok(()) } diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py index bacd1188ce4f..a17ba6d06135 100644 --- a/arrow-pyarrow-integration-testing/tests/test_sql.py +++ b/arrow-pyarrow-integration-testing/tests/test_sql.py @@ -55,15 +55,21 @@ def assert_pyarrow_leak(): pa.timestamp("us"), pa.timestamp("us", tz="UTC"), pa.timestamp("us", tz="Europe/Paris"), + pa.duration("s"), + pa.duration("ms"), + pa.duration("us"), + pa.duration("ns"), pa.float16(), pa.float32(), pa.float64(), pa.decimal128(19, 4), pa.string(), pa.binary(), + pa.binary(10), pa.large_string(), pa.large_binary(), pa.list_(pa.int32()), + pa.list_(pa.int32(), 2), pa.large_list(pa.uint16()), pa.struct( [ @@ -79,13 +85,11 @@ def assert_pyarrow_leak(): pa.field("c", pa.string()), ] ), + pa.dictionary(pa.int8(), pa.string()), ] _unsupported_pyarrow_types = [ pa.decimal256(76, 38), - pa.duration("s"), - pa.binary(10), - pa.list_(pa.int32(), 2), pa.map_(pa.string(), pa.int32()), pa.union( [pa.field("a", pa.binary(10)), pa.field("b", pa.string())], @@ -122,14 +126,6 @@ def test_type_roundtrip_raises(pyarrow_type): with pytest.raises(pa.ArrowException): rust.round_trip_type(pyarrow_type) - -def test_dictionary_type_roundtrip(): - # the dictionary type conversion is incomplete - pyarrow_type = pa.dictionary(pa.int32(), pa.string()) - ty = rust.round_trip_type(pyarrow_type) - assert ty == pa.int32() - - @pytest.mark.parametrize('pyarrow_type', _supported_pyarrow_types, ids=str) def 
test_field_roundtrip(pyarrow_type): pyarrow_field = pa.field("test", pyarrow_type, nullable=True) @@ -197,6 +193,29 @@ def test_time32_python(): del b del expected +def test_binary_array(): + """ + Python -> Rust -> Python + """ + a = pa.array(["a", None, "bb", "ccc"], pa.binary()) + b = rust.round_trip_array(a) + b.validate(full=True) + assert a.to_pylist() == b.to_pylist() + assert a.type == b.type + del a + del b + +def test_fixed_len_binary_array(): + """ + Python -> Rust -> Python + """ + a = pa.array(["aaa", None, "bbb", "ccc"], pa.binary(3)) + b = rust.round_trip_array(a) + b.validate(full=True) + assert a.to_pylist() == b.to_pylist() + assert a.type == b.type + del a + del b def test_list_array(): """ @@ -210,6 +229,17 @@ def test_list_array(): del a del b +def test_fixed_len_list_array(): + """ + Python -> Rust -> Python + """ + a = pa.array([[1, 2], None, [3, 4], [5, 6]], pa.list_(pa.int64(), 2)) + b = rust.round_trip_array(a) + b.validate(full=True) + assert a.to_pylist() == b.to_pylist() + assert a.type == b.type + del a + del b def test_timestamp_python(): """ @@ -263,3 +293,29 @@ def test_decimal_python(): assert a == b del a del b + +def test_dictionary_python(): + """ + Python -> Rust -> Python + """ + a = pa.array(["a", None, "b", None, "a"], type=pa.dictionary(pa.int8(), pa.string())) + b = rust.round_trip_array(a) + assert a == b + del a + del b + +def test_record_batch_reader(): + """ + Python -> Rust -> Python + """ + schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'}) + batches = [ + pa.record_batch([[[1], [2, 42]]], schema), + pa.record_batch([[None, [], [5, 6]]], schema), + ] + a = pa.RecordBatchReader.from_batches(schema, batches) + b = rust.round_trip_record_batch_reader(a) + + assert b.schema == schema + got_batches = list(b) + assert got_batches == batches diff --git a/arrow/CONTRIBUTING.md b/arrow/CONTRIBUTING.md index 843e1faf05e7..bbf309d4d225 100644 --- a/arrow/CONTRIBUTING.md +++ b/arrow/CONTRIBUTING.md @@ -26,19 +26,6 @@ Rust [README.md](../README.md). Please refer to [lib.rs](src/lib.rs) for an introduction to this specific crate and its current functionality. -### How to check memory allocations - -This crate heavily uses `unsafe` due to how memory is allocated in cache lines. -We have a small tool to verify that this crate does not leak memory (beyond what the compiler already does) - -Run it with - -```bash -cargo test --features memory-check --lib -- --test-threads 1 -``` - -This runs all unit-tests on a single thread and counts all allocations and de-allocations. - ## IPC The expected flatc version is 1.12.0+, built from [flatbuffers](https://github.com/google/flatbuffers) @@ -99,7 +86,15 @@ The arrow format declares a IPC protocol, which this crate supports. IPC is equi #### SIMD -The API provided by the `packed_simd` library is currently `unsafe`. However, SIMD offers a significant performance improvement over non-SIMD operations. +The API provided by the [packed_simd_2](https://docs.rs/packed_simd_2/latest/packed_simd_2/) crate is currently `unsafe`. However, +SIMD offers a significant performance improvement over non-SIMD operations. A related crate in development is +[portable-simd](https://rust-lang.github.io/portable-simd/core_simd/) which has a nice +[beginners guide](https://github.com/rust-lang/portable-simd/blob/master/beginners-guide.md). These crates provide the ability +for code on x86 and ARM architectures to use some of the available parallel register operations. 
As an example if two arrays +of numbers are added, [1,2,3,4] + [5,6,7,8], rather than using four instructions to add each of the elements of the arrays, +one instruction can be used to all all four elements at the same time, which leads to improved time to solution. SIMD instructions +are typically most effective when data is aligned to allow a single load instruction to bring multiple consecutive data elements +to the registers, before use of a SIMD instruction. #### Performance diff --git a/arrow/Cargo.toml b/arrow/Cargo.toml index 88623211dab7..b59a69753856 100644 --- a/arrow/Cargo.toml +++ b/arrow/Cargo.toml @@ -17,7 +17,7 @@ [package] name = "arrow" -version = "7.0.0-SNAPSHOT" +version = "16.0.0" description = "Rust implementation of Apache Arrow" homepage = "https://github.com/apache/arrow-rs" repository = "https://github.com/apache/arrow-rs" @@ -29,36 +29,38 @@ include = [ "src/**/*.rs", "Cargo.toml", ] -edition = "2018" +edition = "2021" +rust-version = "1.57" [lib] name = "arrow" path = "src/lib.rs" +bench = false [dependencies] -serde = { version = "1.0", features = ["rc"] } -serde_derive = "1.0" -serde_json = { version = "1.0", features = ["preserve_order"] } -indexmap = "1.6" -rand = { version = "0.8", optional = true } -num = "0.4" -csv_crate = { version = "1.1", optional = true, package="csv" } -regex = "1.3" -lazy_static = "1.4" -packed_simd = { version = "0.3", optional = true, package = "packed_simd_2" } -chrono = "0.4" -chrono-tz = {version = "0.4", optional = true} -flatbuffers = { version = "=2.0.0", optional = true } -hex = "0.4" -comfy-table = { version = "4.0", optional = true, default-features = false } -pyo3 = { version = "0.14", optional = true } -lexical-core = "^0.8" -multiversion = "0.6.1" -bitflags = "1.2.1" +serde = { version = "1.0", default-features = false } +serde_derive = { version = "1.0", default-features = false } +serde_json = { version = "1.0", default-features = false, features = ["preserve_order"] } +indexmap = { version = "1.6", default-features = false, features = ["std"] } +rand = { version = "0.8", default-features = false, features = ["std", "std_rng"], optional = true } +num = { version = "0.4", default-features = false, features = ["std"] } +half = { version = "1.8", default-features = false } +csv_crate = { version = "1.1", default-features = false, optional = true, package="csv" } +regex = { version = "1.5.6", default-features = false, features = ["std", "unicode"] } +lazy_static = { version = "1.4", default-features = false } +packed_simd = { version = "0.3", default-features = false, optional = true, package = "packed_simd_2" } +chrono = { version = "0.4", default-features = false, features = ["clock"] } +chrono-tz = {version = "0.6", default-features = false, optional = true} +flatbuffers = { version = "2.1.2", default-features = false, features = ["thiserror"], optional = true } +hex = { version = "0.4", default-features = false, features = ["std"] } +comfy-table = { version = "6.0", optional = true, default-features = false } +pyo3 = { version = "0.16", default-features = false, optional = true } +lexical-core = { version = "^0.8", default-features = false, features = ["write-integers", "write-floats", "parse-integers", "parse-floats"] } +multiversion = { version = "0.6.1", default-features = false } +bitflags = { version = "1.2.1", default-features = false } [features] default = ["csv", "ipc", "test_utils"] -avx512 = [] csv = ["csv_crate"] ipc = ["flatbuffers"] simd = ["packed_simd"] @@ -68,23 +70,24 @@ prettyprint = ["comfy-table"] 
# an optional dependency for supporting compile to wasm32-unknown-unknown # target without assuming an environment containing JavaScript. test_utils = ["rand"] -# this is only intended to be used in single-threaded programs: it verifies that -# all allocated memory is being released (no memory leaks). -# See README for details -memory-check = [] pyarrow = ["pyo3"] +# force_validate runs full data validation for all arrays that are created +# this is not enabled by default as it is too computationally expensive +# but is run as part of our CI checks +force_validate = [] [dev-dependencies] -rand = "0.8" -criterion = "0.3" -flate2 = "1" -tempfile = "3" +rand = { version = "0.8", default-features = false, features = ["std", "std_rng"] } +criterion = { version = "0.3", default-features = false } +flate2 = { version = "1", default-features = false, features = ["rust_backend"] } +tempfile = { version = "3", default-features = false } [build-dependencies] [[bench]] name = "aggregate_kernels" harness = false +required-features = ["test_utils"] [[bench]] name = "array_from_vec" @@ -93,6 +96,7 @@ harness = false [[bench]] name = "builder" harness = false +required-features = ["test_utils"] [[bench]] name = "buffer_bit_ops" @@ -101,26 +105,36 @@ harness = false [[bench]] name = "boolean_kernels" harness = false +required-features = ["test_utils"] + +[[bench]] +name = "boolean_append_packed" +harness = false [[bench]] name = "arithmetic_kernels" harness = false +required-features = ["test_utils"] [[bench]] name = "cast_kernels" harness = false +required-features = ["test_utils"] [[bench]] name = "comparison_kernels" harness = false +required-features = ["test_utils"] [[bench]] name = "filter_kernels" harness = false +required-features = ["test_utils"] [[bench]] name = "take_kernels" harness = false +required-features = ["test_utils"] [[bench]] name = "length_kernel" @@ -133,10 +147,12 @@ harness = false [[bench]] name = "sort_kernel" harness = false +required-features = ["test_utils"] [[bench]] name = "partition_kernels" harness = false +required-features = ["test_utils"] [[bench]] name = "csv_writer" @@ -149,6 +165,7 @@ harness = false [[bench]] name = "equal" harness = false +required-features = ["test_utils"] [[bench]] name = "array_slice" @@ -157,11 +174,23 @@ harness = false [[bench]] name = "concatenate_kernel" harness = false +required-features = ["test_utils"] [[bench]] name = "mutable_array" harness = false +required-features = ["test_utils"] [[bench]] name = "buffer_create" harness = false +required-features = ["test_utils"] + +[[bench]] +name = "substring_kernels" +harness = false +required-features = ["test_utils"] + +[[bench]] +name = "array_data_validate" +harness = false diff --git a/arrow/README.md b/arrow/README.md index c73ccd16a353..28240e77dff3 100644 --- a/arrow/README.md +++ b/arrow/README.md @@ -19,17 +19,20 @@ # Apache Arrow Official Native Rust Implementation -[![Crates.io](https://img.shields.io/crates/v/arrow.svg)](https://crates.io/crates/arrow) +[![crates.io](https://img.shields.io/crates/v/arrow.svg)](https://crates.io/crates/arrow) +[![docs.rs](https://img.shields.io/docsrs/arrow.svg)](https://docs.rs/arrow/latest/arrow/) -This crate contains the official Native Rust implementation of [Apache Arrow][arrow] in memory format. Please see the API documents for additional details. +This crate contains the official Native Rust implementation of [Apache Arrow][arrow] in memory format, governed by the Apache Software Foundation. 
Additional details can be found on [crates.io](https://crates.io/crates/arrow), [docs.rs](https://docs.rs/arrow/latest/arrow/) and [examples](https://github.com/apache/arrow-rs/tree/master/arrow/examples). ## Rust Version Compatibility -This crate is tested with the latest stable version of Rust. We do not currently test against other, older versions of the Rust compiler. +This crate is tested with the latest stable version of Rust. We do not currently test against other, older versions. ## Versioning / Releases -Unlike many other crates in the Rust ecosystem which spend extended time in "pre 1.0.0" state, releasing versions 0.x, the arrow-rs crate follows the versioning scheme of the overall [Apache Arrow][arrow] project in an effort to signal which language implementations have been integration tested with each other. +The arrow crate follows the [SemVer standard](https://doc.rust-lang.org/cargo/reference/semver.html) defined by Cargo and works well within the Rust crate ecosystem. + +However, for historical reasons, this crate uses versions with major numbers greater than `0.x` (e.g. `16.0.0`), unlike many other crates in the Rust ecosystem which spend extended time releasing versions `0.x` to signal planned ongoing API changes. Minor arrow releases contain only compatible changes, while major releases may contain breaking API changes. ## Features @@ -40,27 +43,37 @@ The arrow crate provides the following features which may be enabled: - `prettyprint` - support for formatting record batches as textual columns - `js` - support for building arrow for WebAssembly / JavaScript - `simd` - (_Requires Nightly Rust_) alternate optimized - implementations of some [compute](https://github.com/apache/arrow/tree/master/rust/arrow/src/compute) - kernels using explicit SIMD processor intrinsics. + implementations of some [compute](https://github.com/apache/arrow-rs/tree/master/arrow/src/compute/kernels) + kernels using explicit SIMD instructions available through [packed_simd_2](https://docs.rs/packed_simd_2/latest/packed_simd_2/). - `chrono-tz` - support of parsing timezone using [chrono-tz](https://docs.rs/chrono-tz/0.6.0/chrono_tz/) ## Safety -TLDR: You should avoid using the `alloc` and `buffer` and `bitmap` modules if at all possible. These modules contain `unsafe` code and are easy to misuse. +Arrow seeks to uphold the Rust Soundness Pledge as articulated eloquently [here](https://raphlinus.github.io/rust/2020/01/18/soundness-pledge.html). Specifically: -As with all open source code, you should carefully evaluate the suitability of `arrow` for your project, taking into consideration your needs and risk tolerance prior to use. +> The intent of this crate is to be free of soundness bugs. The developers will do their best to avoid them, and welcome help in analyzing and fixing them -_Background_: There are various parts of the `arrow` crate which use `unsafe` and `transmute` code internally. We are actively working as a community to minimize undefined behavior and remove `unsafe` usage to align more with Rust's core principles of safety (e.g. the arrow2 project). +Where soundness in turn is defined as: -As `arrow` exists today, it is fairly easy to misuse the APIs, leading to undefined behavior, and it is especially easy to misuse code in modules named above. 
For an example, as described in [the arrow2 crate](https://github.com/jorgecarleitao/arrow2#why), the following code compiles, does not panic, but results in undefined behavior: +> Code is unable to trigger undefined behaviour using safe APIs -```rust -let buffer = Buffer::from_slice_ref(&[0i32, 2i32]) -let data = ArrayData::new(DataType::Int64, 10, 0, None, 0, vec![buffer], vec![]); -let array = Float64Array::from(Arc::new(data)); +One way to ensure this would be to not use `unsafe`, however, as described in the opening chapter of the [Rustonomicon](https://doc.rust-lang.org/nomicon/meet-safe-and-unsafe.html) this is not a requirement, and flexibility in this regard is actually one of Rust's great strengths. -println!("{:?}", array.value(1)); -``` +In particular there are a number of scenarios where `unsafe` is largely unavoidable: + +* Invariants that cannot be statically verified by the compiler and unlock non-trivial performance wins, e.g. values in a StringArray are UTF-8, [TrustedLen](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html) iterators, etc... +* FFI +* SIMD + +Additionally, this crate exposes a number of `unsafe` APIs, allowing downstream crates to explicitly opt-out of potentially expensive invariant checking where appropriate. + +We have a number of strategies to help reduce this risk: + +* Provide strongly-typed `Array` and `ArrayBuilder` APIs to safely and efficiently interact with arrays +* Extensive validation logic to safely construct `ArrayData` from untrusted sources +* All commits are verified using [MIRI](https://github.com/rust-lang/miri) to detect undefined behaviour +* We provide a `force_validate` feature that enables additional validation checks for use in test/debug builds +* There is ongoing work to reduce and better document the use of unsafe, and we welcome contributions in this space ## Building for WASM @@ -87,3 +100,17 @@ cargo run --example read_csv ``` [arrow]: https://arrow.apache.org/ + + +## Performance + +Most of the compute kernels benefit a lot from being optimized for a specific CPU target. +This is especially so on x86-64 since without specifying a target the compiler can only assume support for SSE2 vector instructions. +One of the following values as `-Ctarget-cpu=value` in `RUSTFLAGS` can therefore improve performance significantly: + + - `native`: Target the exact features of the cpu that the build is running on. + This should give the best performance when building and running locally, but should be used carefully for example when building in a CI pipeline or when shipping pre-compiled software. + - `x86-64-v3`: Includes AVX2 support and is close to the intel `haswell` architecture released in 2013 and should be supported by any recent Intel or Amd cpu. + - `x86-64-v4`: Includes AVX512 support available on intel `skylake` server and `icelake`/`tigerlake`/`rocketlake` laptop and desktop processors. + +These flags should be used in addition to the `simd` feature, since they will also affect the code generated by the simd library. 
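As a sketch (the exact command depends on your environment; `simd` currently requires a nightly toolchain as noted in the feature list above), a local benchmark run combining such a flag with the `simd` feature could look like:

```bash
# Target the host CPU's full instruction set and enable the simd feature
RUSTFLAGS="-C target-cpu=native" cargo +nightly bench --features simd
```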
\ No newline at end of file diff --git a/arrow/benches/arithmetic_kernels.rs b/arrow/benches/arithmetic_kernels.rs index bbe412366988..4be4a26933aa 100644 --- a/arrow/benches/arithmetic_kernels.rs +++ b/arrow/benches/arithmetic_kernels.rs @@ -24,7 +24,6 @@ use std::sync::Arc; extern crate arrow; -use arrow::compute::kernels::limit::*; use arrow::util::bench_util::*; use arrow::{array::*, datatypes::Float32Type}; use arrow::{compute::kernels::arithmetic::*, util::test_util::seedable_rng}; @@ -59,44 +58,69 @@ fn bench_divide(arr_a: &ArrayRef, arr_b: &ArrayRef) { criterion::black_box(divide(arr_a, arr_b).unwrap()); } +fn bench_divide_unchecked(arr_a: &ArrayRef, arr_b: &ArrayRef) { + let arr_a = arr_a.as_any().downcast_ref::().unwrap(); + let arr_b = arr_b.as_any().downcast_ref::().unwrap(); + criterion::black_box(divide_unchecked(arr_a, arr_b).unwrap()); +} + fn bench_divide_scalar(array: &ArrayRef, divisor: f32) { let array = array.as_any().downcast_ref::().unwrap(); criterion::black_box(divide_scalar(array, divisor).unwrap()); } -fn bench_limit(arr_a: &ArrayRef, max: usize) { - criterion::black_box(limit(arr_a, max)); +fn bench_modulo(arr_a: &ArrayRef, arr_b: &ArrayRef) { + let arr_a = arr_a.as_any().downcast_ref::().unwrap(); + let arr_b = arr_b.as_any().downcast_ref::().unwrap(); + criterion::black_box(modulus(arr_a, arr_b).unwrap()); +} + +fn bench_modulo_scalar(array: &ArrayRef, divisor: f32) { + let array = array.as_any().downcast_ref::().unwrap(); + criterion::black_box(modulus_scalar(array, divisor).unwrap()); } fn add_benchmark(c: &mut Criterion) { - let arr_a = create_array(512, false); - let arr_b = create_array(512, false); + const BATCH_SIZE: usize = 64 * 1024; + let arr_a = create_array(BATCH_SIZE, false); + let arr_b = create_array(BATCH_SIZE, false); let scalar = seedable_rng().gen(); - c.bench_function("add 512", |b| b.iter(|| bench_add(&arr_a, &arr_b))); - c.bench_function("subtract 512", |b| { - b.iter(|| bench_subtract(&arr_a, &arr_b)) + c.bench_function("add", |b| b.iter(|| bench_add(&arr_a, &arr_b))); + c.bench_function("subtract", |b| b.iter(|| bench_subtract(&arr_a, &arr_b))); + c.bench_function("multiply", |b| b.iter(|| bench_multiply(&arr_a, &arr_b))); + c.bench_function("divide", |b| b.iter(|| bench_divide(&arr_a, &arr_b))); + c.bench_function("divide_unchecked", |b| { + b.iter(|| bench_divide_unchecked(&arr_a, &arr_b)) }); - c.bench_function("multiply 512", |b| { - b.iter(|| bench_multiply(&arr_a, &arr_b)) - }); - c.bench_function("divide 512", |b| b.iter(|| bench_divide(&arr_a, &arr_b))); - c.bench_function("divide_scalar 512", |b| { + c.bench_function("divide_scalar", |b| { b.iter(|| bench_divide_scalar(&arr_a, scalar)) }); - c.bench_function("limit 512, 512", |b| b.iter(|| bench_limit(&arr_a, 512))); + c.bench_function("modulo", |b| b.iter(|| bench_modulo(&arr_a, &arr_b))); + c.bench_function("modulo_scalar", |b| { + b.iter(|| bench_modulo_scalar(&arr_a, scalar)) + }); - let arr_a_nulls = create_array(512, false); - let arr_b_nulls = create_array(512, false); - c.bench_function("add_nulls_512", |b| { + let arr_a_nulls = create_array(BATCH_SIZE, true); + let arr_b_nulls = create_array(BATCH_SIZE, true); + c.bench_function("add_nulls", |b| { b.iter(|| bench_add(&arr_a_nulls, &arr_b_nulls)) }); - c.bench_function("divide_nulls_512", |b| { + c.bench_function("divide_nulls", |b| { b.iter(|| bench_divide(&arr_a_nulls, &arr_b_nulls)) }); - c.bench_function("divide_scalar_nulls_512", |b| { + c.bench_function("divide_nulls_unchecked", |b| { + b.iter(|| 
bench_divide_unchecked(&arr_a_nulls, &arr_b_nulls)) + }); + c.bench_function("divide_scalar_nulls", |b| { b.iter(|| bench_divide_scalar(&arr_a_nulls, scalar)) }); + c.bench_function("modulo_nulls", |b| { + b.iter(|| bench_modulo(&arr_a_nulls, &arr_b_nulls)) + }); + c.bench_function("modulo_scalar_nulls", |b| { + b.iter(|| bench_modulo_scalar(&arr_a_nulls, scalar)) + }); } criterion_group!(benches, add_benchmark); diff --git a/arrow/benches/array_data_validate.rs b/arrow/benches/array_data_validate.rs new file mode 100644 index 000000000000..c46252bececd --- /dev/null +++ b/arrow/benches/array_data_validate.rs @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[macro_use] +extern crate criterion; +use criterion::Criterion; + +extern crate arrow; + +use arrow::{array::*, buffer::Buffer, datatypes::DataType}; + +fn create_binary_array_data(length: i32) -> ArrayData { + let value_buffer = Buffer::from_iter(0_i32..length); + let offsets_buffer = Buffer::from_iter(0_i32..length + 1); + ArrayData::try_new( + DataType::Binary, + length as usize, + None, + 0, + vec![offsets_buffer, value_buffer], + vec![], + ) + .unwrap() +} + +fn array_slice_benchmark(c: &mut Criterion) { + c.bench_function("validate_binary_array_data 20000", |b| { + b.iter(|| create_binary_array_data(20000)) + }); +} + +criterion_group!(benches, array_slice_benchmark); +criterion_main!(benches); diff --git a/arrow/benches/boolean_append_packed.rs b/arrow/benches/boolean_append_packed.rs new file mode 100644 index 000000000000..62bcbcc352fd --- /dev/null +++ b/arrow/benches/boolean_append_packed.rs @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::BooleanBufferBuilder; +use criterion::{criterion_group, criterion_main, Criterion}; +use rand::{thread_rng, Rng}; + +fn rand_bytes(len: usize) -> Vec { + let mut rng = thread_rng(); + let mut buf = vec![0_u8; len]; + rng.fill(buf.as_mut_slice()); + buf +} + +fn boolean_append_packed(c: &mut Criterion) { + let mut rng = thread_rng(); + let source = rand_bytes(1024); + let ranges: Vec<_> = (0..100) + .into_iter() + .map(|_| { + let start: usize = rng.gen_range(0..1024 * 8); + let end: usize = rng.gen_range(start..1024 * 8); + start..end + }) + .collect(); + + let total_bits: usize = ranges.iter().map(|x| x.end - x.start).sum(); + + c.bench_function("boolean_append_packed", |b| { + b.iter(|| { + let mut buffer = BooleanBufferBuilder::new(total_bits); + for range in &ranges { + buffer.append_packed_range(range.clone(), &source); + } + assert_eq!(buffer.len(), total_bits); + }) + }); +} + +criterion_group!(benches, boolean_append_packed); +criterion_main!(benches); diff --git a/arrow/benches/buffer_bit_ops.rs b/arrow/benches/buffer_bit_ops.rs index 063f39c92729..6c6bb0463b28 100644 --- a/arrow/benches/buffer_bit_ops.rs +++ b/arrow/benches/buffer_bit_ops.rs @@ -17,11 +17,14 @@ #[macro_use] extern crate criterion; -use criterion::Criterion; + +use criterion::{Criterion, Throughput}; extern crate arrow; -use arrow::buffer::{Buffer, MutableBuffer}; +use arrow::buffer::{ + buffer_bin_and, buffer_bin_or, buffer_unary_not, Buffer, MutableBuffer, +}; /// Helper function to create arrays fn create_buffer(size: usize) -> Buffer { @@ -42,17 +45,59 @@ fn bench_buffer_or(left: &Buffer, right: &Buffer) { criterion::black_box((left | right).unwrap()); } +fn bench_buffer_not(buffer: &Buffer) { + criterion::black_box(!buffer); +} + +fn bench_buffer_and_with_offsets( + left: &Buffer, + left_offset: usize, + right: &Buffer, + right_offset: usize, + len: usize, +) { + criterion::black_box(buffer_bin_and(left, left_offset, right, right_offset, len)); +} + +fn bench_buffer_or_with_offsets( + left: &Buffer, + left_offset: usize, + right: &Buffer, + right_offset: usize, + len: usize, +) { + criterion::black_box(buffer_bin_or(left, left_offset, right, right_offset, len)); +} + +fn bench_buffer_not_with_offsets(buffer: &Buffer, offset: usize, len: usize) { + criterion::black_box(buffer_unary_not(buffer, offset, len)); +} + fn bit_ops_benchmark(c: &mut Criterion) { let left = create_buffer(512 * 10); let right = create_buffer(512 * 10); - c.bench_function("buffer_bit_ops and", |b| { - b.iter(|| bench_buffer_and(&left, &right)) - }); + c.benchmark_group("buffer_binary_ops") + .throughput(Throughput::Bytes(3 * left.len() as u64)) + .bench_function("and", |b| b.iter(|| bench_buffer_and(&left, &right))) + .bench_function("or", |b| b.iter(|| bench_buffer_or(&left, &right))) + .bench_function("and_with_offset", |b| { + b.iter(|| { + bench_buffer_and_with_offsets(&left, 1, &right, 2, left.len() * 8 - 5) + }) + }) + .bench_function("or_with_offset", |b| { + b.iter(|| { + bench_buffer_or_with_offsets(&left, 1, &right, 2, left.len() * 8 - 5) + }) + }); - c.bench_function("buffer_bit_ops or", |b| { - b.iter(|| bench_buffer_or(&left, &right)) - }); + c.benchmark_group("buffer_unary_ops") + .throughput(Throughput::Bytes(2 * left.len() as u64)) + .bench_function("not", |b| b.iter(|| bench_buffer_not(&left))) + .bench_function("not_with_offset", |b| { + b.iter(|| bench_buffer_not_with_offsets(&left, 1, left.len() * 8 - 5)) + }); } criterion_group!(benches, bit_ops_benchmark); diff --git 
a/arrow/benches/comparison_kernels.rs b/arrow/benches/comparison_kernels.rs index bfee9b977a37..21d83e07eec3 100644 --- a/arrow/benches/comparison_kernels.rs +++ b/arrow/benches/comparison_kernels.rs @@ -22,9 +22,9 @@ use criterion::Criterion; extern crate arrow; use arrow::compute::*; -use arrow::datatypes::ArrowNumericType; +use arrow::datatypes::{ArrowNumericType, IntervalMonthDayNanoType}; use arrow::util::bench_util::*; -use arrow::{array::*, datatypes::Float32Type}; +use arrow::{array::*, datatypes::Float32Type, datatypes::Int32Type}; fn bench_eq(arr_a: &PrimitiveArray, arr_b: &PrimitiveArray) where @@ -119,6 +119,16 @@ fn bench_nlike_utf8_scalar(arr_a: &StringArray, value_b: &str) { .unwrap(); } +fn bench_ilike_utf8_scalar(arr_a: &StringArray, value_b: &str) { + ilike_utf8_scalar(criterion::black_box(arr_a), criterion::black_box(value_b)) + .unwrap(); +} + +fn bench_nilike_utf8_scalar(arr_a: &StringArray, value_b: &str) { + nilike_utf8_scalar(criterion::black_box(arr_a), criterion::black_box(value_b)) + .unwrap(); +} + fn bench_regexp_is_match_utf8_scalar(arr_a: &StringArray, value_b: &str) { regexp_is_match_utf8_scalar( criterion::black_box(arr_a), @@ -128,11 +138,28 @@ fn bench_regexp_is_match_utf8_scalar(arr_a: &StringArray, value_b: &str) { .unwrap(); } +fn bench_dict_eq(arr_a: &DictionaryArray, arr_b: &DictionaryArray) +where + T: ArrowNumericType, +{ + cmp_dict_utf8::( + criterion::black_box(arr_a), + criterion::black_box(arr_b), + |a, b| a == b, + ) + .unwrap(); +} + fn add_benchmark(c: &mut Criterion) { let size = 65536; let arr_a = create_primitive_array_with_seed::(size, 0.0, 42); let arr_b = create_primitive_array_with_seed::(size, 0.0, 43); + let arr_month_day_nano_a = + create_primitive_array_with_seed::(size, 0.0, 43); + let arr_month_day_nano_b = + create_primitive_array_with_seed::(size, 0.0, 43); + let arr_string = create_string_array::(size, 0.0); c.bench_function("eq Float32", |b| b.iter(|| bench_eq(&arr_a, &arr_b))); @@ -165,6 +192,13 @@ fn add_benchmark(c: &mut Criterion) { b.iter(|| bench_gt_eq_scalar(&arr_a, 1.0)) }); + c.bench_function("eq MonthDayNano", |b| { + b.iter(|| bench_eq(&arr_month_day_nano_a, &arr_month_day_nano_b)) + }); + c.bench_function("eq scalar MonthDayNano", |b| { + b.iter(|| bench_eq_scalar(&arr_month_day_nano_a, 123)) + }); + c.bench_function("like_utf8 scalar equals", |b| { b.iter(|| bench_like_utf8_scalar(&arr_string, "xxxx")) }); @@ -205,6 +239,46 @@ fn add_benchmark(c: &mut Criterion) { b.iter(|| bench_nlike_utf8_scalar(&arr_string, "%xx_xx%xxx")) }); + c.bench_function("ilike_utf8 scalar equals", |b| { + b.iter(|| bench_ilike_utf8_scalar(&arr_string, "xxXX")) + }); + + c.bench_function("ilike_utf8 scalar contains", |b| { + b.iter(|| bench_ilike_utf8_scalar(&arr_string, "%xxXX%")) + }); + + c.bench_function("ilike_utf8 scalar ends with", |b| { + b.iter(|| bench_ilike_utf8_scalar(&arr_string, "xXXx%")) + }); + + c.bench_function("ilike_utf8 scalar starts with", |b| { + b.iter(|| bench_ilike_utf8_scalar(&arr_string, "%XXXx")) + }); + + c.bench_function("ilike_utf8 scalar complex", |b| { + b.iter(|| bench_ilike_utf8_scalar(&arr_string, "%xx_xX%xXX")) + }); + + c.bench_function("nilike_utf8 scalar equals", |b| { + b.iter(|| bench_nilike_utf8_scalar(&arr_string, "xxXX")) + }); + + c.bench_function("nilike_utf8 scalar contains", |b| { + b.iter(|| bench_nilike_utf8_scalar(&arr_string, "%xxXX%")) + }); + + c.bench_function("nilike_utf8 scalar ends with", |b| { + b.iter(|| bench_nilike_utf8_scalar(&arr_string, "xXXx%")) + }); + + 
c.bench_function("nilike_utf8 scalar starts with", |b| { + b.iter(|| bench_nilike_utf8_scalar(&arr_string, "%XXXx")) + }); + + c.bench_function("nilike_utf8 scalar complex", |b| { + b.iter(|| bench_nilike_utf8_scalar(&arr_string, "%xx_xX%xXX")) + }); + c.bench_function("egexp_matches_utf8 scalar starts with", |b| { b.iter(|| bench_regexp_is_match_utf8_scalar(&arr_string, "^xx")) }); @@ -212,6 +286,13 @@ fn add_benchmark(c: &mut Criterion) { c.bench_function("egexp_matches_utf8 scalar ends with", |b| { b.iter(|| bench_regexp_is_match_utf8_scalar(&arr_string, "xx$")) }); + + let dict_arr_a = create_string_dict_array::(size, 0.0); + let dict_arr_b = create_string_dict_array::(size, 0.0); + + c.bench_function("dict eq string", |b| { + b.iter(|| bench_dict_eq(&dict_arr_a, &dict_arr_b)) + }); } criterion_group!(benches, add_benchmark); diff --git a/arrow/benches/csv_writer.rs b/arrow/benches/csv_writer.rs index 62c5da980312..3ecf514ad6db 100644 --- a/arrow/benches/csv_writer.rs +++ b/arrow/benches/csv_writer.rs @@ -25,6 +25,7 @@ use arrow::array::*; use arrow::csv; use arrow::datatypes::*; use arrow::record_batch::RecordBatch; +use std::env; use std::fs::File; use std::sync::Arc; @@ -56,7 +57,8 @@ fn criterion_benchmark(c: &mut Criterion) { vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)], ) .unwrap(); - let file = File::create("target/bench_write_csv.csv").unwrap(); + let path = env::temp_dir().join("bench_write_csv.csv"); + let file = File::create(path).unwrap(); let mut writer = csv::Writer::new(file); let batches = vec![&b, &b, &b, &b, &b, &b, &b, &b, &b, &b, &b]; diff --git a/arrow/benches/filter_kernels.rs b/arrow/benches/filter_kernels.rs index d5ff09c040b8..be6d9027a8db 100644 --- a/arrow/benches/filter_kernels.rs +++ b/arrow/benches/filter_kernels.rs @@ -18,13 +18,13 @@ extern crate arrow; use std::sync::Arc; -use arrow::compute::{filter_record_batch, Filter}; +use arrow::compute::{filter_record_batch, FilterBuilder, FilterPredicate}; use arrow::record_batch::RecordBatch; use arrow::util::bench_util::*; use arrow::array::*; -use arrow::compute::{build_filter, filter}; -use arrow::datatypes::{Field, Float32Type, Schema, UInt8Type}; +use arrow::compute::filter; +use arrow::datatypes::{Field, Float32Type, Int32Type, Schema, UInt8Type}; use criterion::{criterion_group, criterion_main, Criterion}; @@ -32,8 +32,8 @@ fn bench_filter(data_array: &dyn Array, filter_array: &BooleanArray) { criterion::black_box(filter(data_array, filter_array).unwrap()); } -fn bench_built_filter<'a>(filter: &Filter<'a>, data: &impl Array) { - criterion::black_box(filter(data.data())); +fn bench_built_filter(filter: &FilterPredicate, array: &dyn Array) { + criterion::black_box(filter.filter(array).unwrap()); } fn add_benchmark(c: &mut Criterion) { @@ -42,68 +42,145 @@ fn add_benchmark(c: &mut Criterion) { let dense_filter_array = create_boolean_array(size, 0.0, 1.0 - 1.0 / 1024.0); let sparse_filter_array = create_boolean_array(size, 0.0, 1.0 / 1024.0); - let filter = build_filter(&filter_array).unwrap(); - let dense_filter = build_filter(&dense_filter_array).unwrap(); - let sparse_filter = build_filter(&sparse_filter_array).unwrap(); + let filter = FilterBuilder::new(&filter_array).optimize().build(); + let dense_filter = FilterBuilder::new(&dense_filter_array).optimize().build(); + let sparse_filter = FilterBuilder::new(&sparse_filter_array).optimize().build(); let data_array = create_primitive_array::(size, 0.0); - c.bench_function("filter u8", |b| { + c.bench_function("filter optimize (kept 1/2)", 
|b| { + b.iter(|| FilterBuilder::new(&filter_array).optimize().build()) + }); + + c.bench_function("filter optimize high selectivity (kept 1023/1024)", |b| { + b.iter(|| FilterBuilder::new(&dense_filter_array).optimize().build()) + }); + + c.bench_function("filter optimize low selectivity (kept 1/1024)", |b| { + b.iter(|| FilterBuilder::new(&sparse_filter_array).optimize().build()) + }); + + c.bench_function("filter u8 (kept 1/2)", |b| { b.iter(|| bench_filter(&data_array, &filter_array)) }); - c.bench_function("filter u8 high selectivity", |b| { + c.bench_function("filter u8 high selectivity (kept 1023/1024)", |b| { b.iter(|| bench_filter(&data_array, &dense_filter_array)) }); - c.bench_function("filter u8 low selectivity", |b| { + c.bench_function("filter u8 low selectivity (kept 1/1024)", |b| { b.iter(|| bench_filter(&data_array, &sparse_filter_array)) }); - c.bench_function("filter context u8", |b| { + c.bench_function("filter context u8 (kept 1/2)", |b| { b.iter(|| bench_built_filter(&filter, &data_array)) }); - c.bench_function("filter context u8 high selectivity", |b| { + c.bench_function("filter context u8 high selectivity (kept 1023/1024)", |b| { b.iter(|| bench_built_filter(&dense_filter, &data_array)) }); - c.bench_function("filter context u8 low selectivity", |b| { + c.bench_function("filter context u8 low selectivity (kept 1/1024)", |b| { b.iter(|| bench_built_filter(&sparse_filter, &data_array)) }); - let data_array = create_primitive_array::(size, 0.5); - c.bench_function("filter context u8 w NULLs", |b| { - b.iter(|| bench_built_filter(&filter, &data_array)) + let data_array = create_primitive_array::(size, 0.0); + c.bench_function("filter i32 (kept 1/2)", |b| { + b.iter(|| bench_filter(&data_array, &filter_array)) }); - c.bench_function("filter context u8 w NULLs high selectivity", |b| { - b.iter(|| bench_built_filter(&dense_filter, &data_array)) + c.bench_function("filter i32 high selectivity (kept 1023/1024)", |b| { + b.iter(|| bench_filter(&data_array, &dense_filter_array)) + }); + c.bench_function("filter i32 low selectivity (kept 1/1024)", |b| { + b.iter(|| bench_filter(&data_array, &sparse_filter_array)) + }); + + c.bench_function("filter context i32 (kept 1/2)", |b| { + b.iter(|| bench_built_filter(&filter, &data_array)) }); - c.bench_function("filter context u8 w NULLs low selectivity", |b| { + c.bench_function( + "filter context i32 high selectivity (kept 1023/1024)", + |b| b.iter(|| bench_built_filter(&dense_filter, &data_array)), + ); + c.bench_function("filter context i32 low selectivity (kept 1/1024)", |b| { b.iter(|| bench_built_filter(&sparse_filter, &data_array)) }); + let data_array = create_primitive_array::(size, 0.5); + c.bench_function("filter context i32 w NULLs (kept 1/2)", |b| { + b.iter(|| bench_built_filter(&filter, &data_array)) + }); + c.bench_function( + "filter context i32 w NULLs high selectivity (kept 1023/1024)", + |b| b.iter(|| bench_built_filter(&dense_filter, &data_array)), + ); + c.bench_function( + "filter context i32 w NULLs low selectivity (kept 1/1024)", + |b| b.iter(|| bench_built_filter(&sparse_filter, &data_array)), + ); + + let data_array = create_primitive_array::(size, 0.5); + c.bench_function("filter context u8 w NULLs (kept 1/2)", |b| { + b.iter(|| bench_built_filter(&filter, &data_array)) + }); + c.bench_function( + "filter context u8 w NULLs high selectivity (kept 1023/1024)", + |b| b.iter(|| bench_built_filter(&dense_filter, &data_array)), + ); + c.bench_function( + "filter context u8 w NULLs low selectivity (kept 
1/1024)", + |b| b.iter(|| bench_built_filter(&sparse_filter, &data_array)), + ); + let data_array = create_primitive_array::(size, 0.5); - c.bench_function("filter f32", |b| { + c.bench_function("filter f32 (kept 1/2)", |b| { b.iter(|| bench_filter(&data_array, &filter_array)) }); - c.bench_function("filter context f32", |b| { + c.bench_function("filter context f32 (kept 1/2)", |b| { b.iter(|| bench_built_filter(&filter, &data_array)) }); - c.bench_function("filter context f32 high selectivity", |b| { - b.iter(|| bench_built_filter(&dense_filter, &data_array)) - }); - c.bench_function("filter context f32 low selectivity", |b| { + c.bench_function( + "filter context f32 high selectivity (kept 1023/1024)", + |b| b.iter(|| bench_built_filter(&dense_filter, &data_array)), + ); + c.bench_function("filter context f32 low selectivity (kept 1/1024)", |b| { b.iter(|| bench_built_filter(&sparse_filter, &data_array)) }); let data_array = create_string_array::(size, 0.5); - c.bench_function("filter context string", |b| { + c.bench_function("filter context string (kept 1/2)", |b| { b.iter(|| bench_built_filter(&filter, &data_array)) }); - c.bench_function("filter context string high selectivity", |b| { - b.iter(|| bench_built_filter(&dense_filter, &data_array)) - }); - c.bench_function("filter context string low selectivity", |b| { + c.bench_function( + "filter context string high selectivity (kept 1023/1024)", + |b| b.iter(|| bench_built_filter(&dense_filter, &data_array)), + ); + c.bench_function("filter context string low selectivity (kept 1/1024)", |b| { b.iter(|| bench_built_filter(&sparse_filter, &data_array)) }); + let data_array = create_string_dict_array::(size, 0.0); + c.bench_function("filter context string dictionary (kept 1/2)", |b| { + b.iter(|| bench_built_filter(&filter, &data_array)) + }); + c.bench_function( + "filter context string dictionary high selectivity (kept 1023/1024)", + |b| b.iter(|| bench_built_filter(&dense_filter, &data_array)), + ); + c.bench_function( + "filter context string dictionary low selectivity (kept 1/1024)", + |b| b.iter(|| bench_built_filter(&sparse_filter, &data_array)), + ); + + let data_array = create_string_dict_array::(size, 0.5); + c.bench_function("filter context string dictionary w NULLs (kept 1/2)", |b| { + b.iter(|| bench_built_filter(&filter, &data_array)) + }); + c.bench_function( + "filter context string dictionary w NULLs high selectivity (kept 1023/1024)", + |b| b.iter(|| bench_built_filter(&dense_filter, &data_array)), + ); + c.bench_function( + "filter context string dictionary w NULLs low selectivity (kept 1/1024)", + |b| b.iter(|| bench_built_filter(&sparse_filter, &data_array)), + ); + let data_array = create_primitive_array::(size, 0.0); let field = Field::new("c1", data_array.data_type().clone(), true); diff --git a/arrow/benches/substring_kernels.rs b/arrow/benches/substring_kernels.rs new file mode 100644 index 000000000000..6bbfc9c09839 --- /dev/null +++ b/arrow/benches/substring_kernels.rs @@ -0,0 +1,65 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[macro_use] +extern crate criterion; +use criterion::Criterion; + +extern crate arrow; + +use arrow::array::*; +use arrow::compute::kernels::substring::*; +use arrow::util::bench_util::*; + +fn bench_substring(arr: &dyn Array, start: i64, length: Option) { + substring(criterion::black_box(arr), start, length).unwrap(); +} + +fn bench_substring_by_char( + arr: &GenericStringArray, + start: i64, + length: Option, +) { + substring_by_char(criterion::black_box(arr), start, length).unwrap(); +} + +fn add_benchmark(c: &mut Criterion) { + let size = 65536; + let val_len = 1000; + + let arr_string = create_string_array_with_len::(size, 0.0, val_len); + let arr_fsb = create_fsb_array(size, 0.0, val_len); + + c.bench_function("substring utf8 (start = 0, length = None)", |b| { + b.iter(|| bench_substring(&arr_string, 0, None)) + }); + + c.bench_function("substring utf8 (start = 1, length = str_len - 1)", |b| { + b.iter(|| bench_substring(&arr_string, 1, Some((val_len - 1) as u64))) + }); + + c.bench_function("substring utf8 by char", |b| { + b.iter(|| bench_substring_by_char(&arr_string, 1, Some((val_len - 1) as u64))) + }); + + c.bench_function("substring fixed size binary array", |b| { + b.iter(|| bench_substring(&arr_fsb, 1, Some((val_len - 1) as u64))) + }); +} + +criterion_group!(benches, add_benchmark); +criterion_main!(benches); diff --git a/parquet_derive/test/dependency/README.md b/arrow/examples/README.md similarity index 65% rename from parquet_derive/test/dependency/README.md rename to arrow/examples/README.md index b618b4636e7c..41ffd823357d 100644 --- a/parquet_derive/test/dependency/README.md +++ b/arrow/examples/README.md @@ -17,5 +17,11 @@ under the License. --> -This directory contains projects that use arrow as a dependency with -various combinations of feature flags. 
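The new `substring_kernels.rs` benchmark covers both the byte-based `substring` kernel and the character-based `substring_by_char` variant. A minimal sketch of the character-based variant, assuming the `arrow` crate as modified here; the input strings are illustrative only:

```rust
use arrow::array::StringArray;
use arrow::compute::kernels::substring::substring_by_char;

fn main() {
    let input = StringArray::from(vec!["arrow-rs", "données"]);

    // start and length are counted in characters rather than bytes,
    // so multi-byte UTF-8 characters are never split.
    let sliced = substring_by_char(&input, 1, Some(4)).unwrap();

    assert_eq!(sliced.value(0), "rrow");
    assert_eq!(sliced.value(1), "onné");
}
```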
+# Examples + +- [`builders.rs`](builders.rs): Using the Builder API +- `collect` (TODO): Using the `FromIter` API +- [`dynamic_types.rs`](dynamic_types.rs): +- [`read_csv.rs`](read_csv.rs): Reading CSV files with explict schema, pretty printing Arrays +- [`read_csv_infer_schema.rs`](read_csv_infer_schema.rs): Reading CSV files, pretty printing Arrays +- [`tensor_builder.rs`](tensor_builder.rs): Using tensor builder diff --git a/arrow/examples/builders.rs b/arrow/examples/builders.rs index 0dc1d76f34f4..d35cb5ab744d 100644 --- a/arrow/examples/builders.rs +++ b/arrow/examples/builders.rs @@ -81,7 +81,7 @@ fn main() { .len(3) .add_buffer(Buffer::from(offsets.to_byte_slice())) .add_buffer(Buffer::from(&values[..])) - .null_bit_buffer(Buffer::from([0b00000101])) + .null_bit_buffer(Some(Buffer::from([0b00000101]))) .build() .unwrap(); let binary_array = StringArray::from(array_data); diff --git a/arrow/examples/read_csv.rs b/arrow/examples/read_csv.rs index 506b89887657..5ccf0c58a797 100644 --- a/arrow/examples/read_csv.rs +++ b/arrow/examples/read_csv.rs @@ -35,10 +35,11 @@ fn main() { Field::new("lng", DataType::Float64, false), ]); - let file = File::open("test/data/uk_cities.csv").unwrap(); + let path = format!("{}/test/data/uk_cities.csv", env!("CARGO_MANIFEST_DIR")); + let file = File::open(path).unwrap(); let mut csv = - csv::Reader::new(file, Arc::new(schema), false, None, 1024, None, None); + csv::Reader::new(file, Arc::new(schema), false, None, 1024, None, None, None); let _batch = csv.next().unwrap().unwrap(); #[cfg(feature = "prettyprint")] { diff --git a/arrow/examples/read_csv_infer_schema.rs b/arrow/examples/read_csv_infer_schema.rs index 11f8cfb7f7d2..e9f5ff650706 100644 --- a/arrow/examples/read_csv_infer_schema.rs +++ b/arrow/examples/read_csv_infer_schema.rs @@ -26,7 +26,11 @@ use std::fs::File; fn main() { #[cfg(feature = "csv")] { - let file = File::open("test/data/uk_cities_with_headers.csv").unwrap(); + let path = format!( + "{}/test/data/uk_cities_with_headers.csv", + env!("CARGO_MANIFEST_DIR") + ); + let file = File::open(path).unwrap(); let builder = csv::ReaderBuilder::new() .has_header(true) .infer_schema(Some(100)); diff --git a/arrow/src/alloc/mod.rs b/arrow/src/alloc/mod.rs index a225d32dd82d..418bc95fd2e8 100644 --- a/arrow/src/alloc/mod.rs +++ b/arrow/src/alloc/mod.rs @@ -18,12 +18,12 @@ //! Defines memory-related functions, such as allocate/deallocate/reallocate memory //! regions, cache and allocation alignments. 
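Several hunks in this change (for example `builders.rs` above) move the `ArrayData` builder method `null_bit_buffer` from taking a `Buffer` to taking an `Option<Buffer>`. A minimal sketch of the new call shape, assuming the `arrow` crate as modified here; the values are illustrative only:

```rust
use arrow::array::{Array, ArrayData, Int32Array};
use arrow::buffer::Buffer;
use arrow::datatypes::DataType;

fn main() {
    // Some(buffer) supplies a validity bitmap; None means all values are valid.
    let data = ArrayData::builder(DataType::Int32)
        .len(3)
        .add_buffer(Buffer::from_slice_ref(&[1i32, 2, 3]))
        .null_bit_buffer(Some(Buffer::from([0b00000101])))
        .build()
        .unwrap();

    let array = Int32Array::from(data);
    assert!(array.is_null(1)); // bit 1 of the validity bitmap is unset
}
```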
+use std::alloc::{handle_alloc_error, Layout}; +use std::fmt::{Debug, Formatter}; use std::mem::size_of; +use std::panic::RefUnwindSafe; use std::ptr::NonNull; -use std::{ - alloc::{handle_alloc_error, Layout}, - sync::atomic::AtomicIsize, -}; +use std::sync::Arc; mod alignment; mod types; @@ -31,9 +31,6 @@ mod types; pub use alignment::ALIGNMENT; pub use types::NativeType; -// If this number is not zero after all objects have been `drop`, there is a memory leak -pub static mut ALLOCATIONS: AtomicIsize = AtomicIsize::new(0); - #[inline] unsafe fn null_pointer() -> NonNull { NonNull::new_unchecked(ALIGNMENT as *mut T) @@ -48,7 +45,6 @@ pub fn allocate_aligned(size: usize) -> NonNull { null_pointer() } else { let size = size * size_of::(); - ALLOCATIONS.fetch_add(size as isize, std::sync::atomic::Ordering::SeqCst); let layout = Layout::from_size_align_unchecked(size, ALIGNMENT); let raw_ptr = std::alloc::alloc(layout) as *mut T; @@ -66,7 +62,6 @@ pub fn allocate_aligned_zeroed(size: usize) -> NonNull { null_pointer() } else { let size = size * size_of::(); - ALLOCATIONS.fetch_add(size as isize, std::sync::atomic::Ordering::SeqCst); let layout = Layout::from_size_align_unchecked(size, ALIGNMENT); let raw_ptr = std::alloc::alloc_zeroed(layout) as *mut T; @@ -86,7 +81,6 @@ pub fn allocate_aligned_zeroed(size: usize) -> NonNull { pub unsafe fn free_aligned(ptr: NonNull, size: usize) { if ptr != null_pointer() { let size = size * size_of::(); - ALLOCATIONS.fetch_sub(size as isize, std::sync::atomic::Ordering::SeqCst); std::alloc::dealloc( ptr.as_ptr() as *mut u8, Layout::from_size_align_unchecked(size, ALIGNMENT), @@ -121,10 +115,6 @@ pub unsafe fn reallocate( return null_pointer(); } - ALLOCATIONS.fetch_add( - new_size as isize - old_size as isize, - std::sync::atomic::Ordering::SeqCst, - ); let raw_ptr = std::alloc::realloc( ptr.as_ptr() as *mut u8, Layout::from_size_align_unchecked(old_size, ALIGNMENT), @@ -134,3 +124,32 @@ pub unsafe fn reallocate( handle_alloc_error(Layout::from_size_align_unchecked(new_size, ALIGNMENT)) }) } + +/// The owner of an allocation. +/// The trait implementation is responsible for dropping the allocations once no more references exist. +pub trait Allocation: RefUnwindSafe {} + +impl Allocation for T {} + +/// Mode of deallocating memory regions +pub(crate) enum Deallocation { + /// An allocation of the given capacity that needs to be deallocated using arrows's cache aligned allocator. + /// See [allocate_aligned] and [free_aligned]. + Arrow(usize), + /// An allocation from an external source like the FFI interface or a Rust Vec. + /// Deallocation will happen + Custom(Arc), +} + +impl Debug for Deallocation { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + match self { + Deallocation::Arrow(capacity) => { + write!(f, "Deallocation::Arrow {{ capacity: {} }}", capacity) + } + Deallocation::Custom(_) => { + write!(f, "Deallocation::Custom {{ capacity: unknown }}") + } + } + } +} diff --git a/arrow/src/alloc/types.rs b/arrow/src/alloc/types.rs index 92a6107f3d54..026e1241f46b 100644 --- a/arrow/src/alloc/types.rs +++ b/arrow/src/alloc/types.rs @@ -16,6 +16,7 @@ // under the License. use crate::datatypes::DataType; +use half::f16; /// A type that Rust's custom allocator knows how to allocate and deallocate. 
/// This is implemented for all Arrow's physical types whose in-memory representation @@ -67,5 +68,6 @@ create_native!( i64, DataType::Int64 | DataType::Date64 | DataType::Time64(_) | DataType::Timestamp(_, _) ); +create_native!(f16, DataType::Float16); create_native!(f32, DataType::Float32); create_native!(f64, DataType::Float64); diff --git a/arrow/src/arch/avx512.rs b/arrow/src/arch/avx512.rs deleted file mode 100644 index 264532f3594c..000000000000 --- a/arrow/src/arch/avx512.rs +++ /dev/null @@ -1,73 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -pub(crate) const AVX512_U8X64_LANES: usize = 64; - -#[target_feature(enable = "avx512f")] -pub(crate) unsafe fn avx512_bin_and(left: &[u8], right: &[u8], res: &mut [u8]) { - use core::arch::x86_64::{__m512i, _mm512_and_si512, _mm512_loadu_epi64}; - - let l: __m512i = _mm512_loadu_epi64(left.as_ptr() as *const _); - let r: __m512i = _mm512_loadu_epi64(right.as_ptr() as *const _); - let f = _mm512_and_si512(l, r); - let s = &f as *const __m512i as *const u8; - let d = res.get_unchecked_mut(0) as *mut _ as *mut u8; - std::ptr::copy_nonoverlapping(s, d, std::mem::size_of::<__m512i>()); -} - -#[target_feature(enable = "avx512f")] -pub(crate) unsafe fn avx512_bin_or(left: &[u8], right: &[u8], res: &mut [u8]) { - use core::arch::x86_64::{__m512i, _mm512_loadu_epi64, _mm512_or_si512}; - - let l: __m512i = _mm512_loadu_epi64(left.as_ptr() as *const _); - let r: __m512i = _mm512_loadu_epi64(right.as_ptr() as *const _); - let f = _mm512_or_si512(l, r); - let s = &f as *const __m512i as *const u8; - let d = res.get_unchecked_mut(0) as *mut _ as *mut u8; - std::ptr::copy_nonoverlapping(s, d, std::mem::size_of::<__m512i>()); -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_bitwise_and_avx512() { - let buf1 = [0b00110011u8; 64]; - let buf2 = [0b11110000u8; 64]; - let mut buf3 = [0b00000000; 64]; - unsafe { - avx512_bin_and(&buf1, &buf2, &mut buf3); - }; - for i in buf3.iter() { - assert_eq!(&0b00110000u8, i); - } - } - - #[test] - fn test_bitwise_or_avx512() { - let buf1 = [0b00010011u8; 64]; - let buf2 = [0b11100000u8; 64]; - let mut buf3 = [0b00000000; 64]; - unsafe { - avx512_bin_or(&buf1, &buf2, &mut buf3); - }; - for i in buf3.iter() { - assert_eq!(&0b11110011u8, i); - } - } -} diff --git a/arrow/src/array/array.rs b/arrow/src/array/array.rs index fcf4647666e8..c566ff99f12e 100644 --- a/arrow/src/array/array.rs +++ b/arrow/src/array/array.rs @@ -227,6 +227,69 @@ pub trait Array: fmt::Debug + Send + Sync + JsonEqual { /// A reference-counted reference to a generic `Array`. 
pub type ArrayRef = Arc; +/// Ergonomics: Allow use of an ArrayRef as an `&dyn Array` +impl Array for ArrayRef { + fn as_any(&self) -> &dyn Any { + self.as_ref().as_any() + } + + fn data(&self) -> &ArrayData { + self.as_ref().data() + } + + fn data_ref(&self) -> &ArrayData { + self.as_ref().data_ref() + } + + fn data_type(&self) -> &DataType { + self.as_ref().data_type() + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + self.as_ref().slice(offset, length) + } + + fn len(&self) -> usize { + self.as_ref().len() + } + + fn is_empty(&self) -> bool { + self.as_ref().is_empty() + } + + fn offset(&self) -> usize { + self.as_ref().offset() + } + + fn is_null(&self, index: usize) -> bool { + self.as_ref().is_null(index) + } + + fn is_valid(&self, index: usize) -> bool { + self.as_ref().is_valid(index) + } + + fn null_count(&self) -> usize { + self.as_ref().null_count() + } + + fn get_buffer_memory_size(&self) -> usize { + self.as_ref().get_buffer_memory_size() + } + + fn get_array_memory_size(&self) -> usize { + self.as_ref().get_array_memory_size() + } + + fn to_raw( + &self, + ) -> Result<(*const ffi::FFI_ArrowArray, *const ffi::FFI_ArrowSchema)> { + let data = self.data().clone(); + let array = ffi::ArrowArray::try_from(data)?; + Ok(ffi::ArrowArray::into_raw(array)) + } +} + /// Constructs an array using the input `data`. /// Returns a reference-counted `Array` instance. pub fn make_array(data: ArrayData) -> ArrayRef { @@ -240,7 +303,7 @@ pub fn make_array(data: ArrayData) -> ArrayRef { DataType::UInt16 => Arc::new(UInt16Array::from(data)) as ArrayRef, DataType::UInt32 => Arc::new(UInt32Array::from(data)) as ArrayRef, DataType::UInt64 => Arc::new(UInt64Array::from(data)) as ArrayRef, - DataType::Float16 => panic!("Float16 datatype not supported"), + DataType::Float16 => Arc::new(Float16Array::from(data)) as ArrayRef, DataType::Float32 => Arc::new(Float32Array::from(data)) as ArrayRef, DataType::Float64 => Arc::new(Float64Array::from(data)) as ArrayRef, DataType::Date32 => Arc::new(Date32Array::from(data)) as ArrayRef, @@ -275,6 +338,9 @@ pub fn make_array(data: ArrayData) -> ArrayRef { DataType::Interval(IntervalUnit::DayTime) => { Arc::new(IntervalDayTimeArray::from(data)) as ArrayRef } + DataType::Interval(IntervalUnit::MonthDayNano) => { + Arc::new(IntervalMonthDayNanoArray::from(data)) as ArrayRef + } DataType::Duration(TimeUnit::Second) => { Arc::new(DurationSecondArray::from(data)) as ArrayRef } @@ -298,7 +364,7 @@ pub fn make_array(data: ArrayData) -> ArrayRef { DataType::LargeList(_) => Arc::new(LargeListArray::from(data)) as ArrayRef, DataType::Struct(_) => Arc::new(StructArray::from(data)) as ArrayRef, DataType::Map(_, _) => Arc::new(MapArray::from(data)) as ArrayRef, - DataType::Union(_) => Arc::new(UnionArray::from(data)) as ArrayRef, + DataType::Union(_, _, _) => Arc::new(UnionArray::from(data)) as ArrayRef, DataType::FixedSizeList(_, _) => { Arc::new(FixedSizeListArray::from(data)) as ArrayRef } @@ -393,7 +459,7 @@ pub fn new_null_array(data_type: &DataType, length: usize) -> ArrayRef { DataType::UInt8 => new_null_sized_array::(data_type, length), DataType::Int16 => new_null_sized_array::(data_type, length), DataType::UInt16 => new_null_sized_array::(data_type, length), - DataType::Float16 => unreachable!(), + DataType::Float16 => new_null_sized_array::(data_type, length), DataType::Int32 => new_null_sized_array::(data_type, length), DataType::UInt32 => new_null_sized_array::(data_type, length), DataType::Float32 => new_null_sized_array::(data_type, length), @@ 
-415,6 +481,9 @@ pub fn new_null_array(data_type: &DataType, length: usize) -> ArrayRef { IntervalUnit::DayTime => { new_null_sized_array::(data_type, length) } + IntervalUnit::MonthDayNano => { + new_null_sized_array::(data_type, length) + } }, DataType::FixedSizeBinary(value_len) => make_array(unsafe { ArrayData::new_unchecked( @@ -466,7 +535,7 @@ pub fn new_null_array(data_type: &DataType, length: usize) -> ArrayRef { DataType::Map(field, _keys_sorted) => { new_null_list_array::(data_type, field.data_type(), length) } - DataType::Union(_) => { + DataType::Union(_, _, _) => { unimplemented!("Creating null Union array not yet supported") } DataType::Dictionary(key, value) => { @@ -485,9 +554,7 @@ pub fn new_null_array(data_type: &DataType, length: usize) -> ArrayRef { ) }) } - DataType::Decimal(_, _) => { - unimplemented!("Creating null Decimal array not yet supported") - } + DataType::Decimal(_, _) => new_null_sized_decimal(data_type, length), } } @@ -551,6 +618,24 @@ fn new_null_sized_array( }) } +#[inline] +fn new_null_sized_decimal(data_type: &DataType, length: usize) -> ArrayRef { + make_array(unsafe { + ArrayData::new_unchecked( + data_type.clone(), + length, + Some(length), + Some(MutableBuffer::new_null(length).into()), + 0, + vec![Buffer::from(vec![ + 0u8; + length * std::mem::size_of::() + ])], + vec![], + ) + }) +} + /// Creates a new array from two FFI pointers. Used to import arrays from the C Data Interface /// # Safety /// Assumes that these pointers represent valid C Data Interfaces, both in memory @@ -563,6 +648,30 @@ pub unsafe fn make_array_from_raw( let data = ArrayData::try_from(array)?; Ok(make_array(data)) } + +/// Exports an array to raw pointers of the C Data Interface provided by the consumer. +/// # Safety +/// Assumes that these pointers represent valid C Data Interfaces, both in memory +/// representation and lifetime via the `release` mechanism. +/// +/// This function copies the content of two FFI structs [ffi::FFI_ArrowArray] and +/// [ffi::FFI_ArrowSchema] in the array to the location pointed by the raw pointers. +/// Usually the raw pointers are provided by the array data consumer. +pub unsafe fn export_array_into_raw( + src: ArrayRef, + out_array: *mut ffi::FFI_ArrowArray, + out_schema: *mut ffi::FFI_ArrowSchema, +) -> Result<()> { + let data = src.data(); + let array = ffi::FFI_ArrowArray::new(data); + let schema = ffi::FFI_ArrowSchema::try_from(data.data_type())?; + + std::ptr::write_unaligned(out_array, array); + std::ptr::write_unaligned(out_schema, schema); + + Ok(()) +} + // Helper function for printing potentially long arrays. pub(super) fn print_long_array( array: &A, @@ -764,11 +873,13 @@ mod tests { #[test] fn test_memory_size_primitive_nullable() { - let arr: PrimitiveArray = (0..128).map(Some).collect(); + let arr: PrimitiveArray = (0..128) + .map(|i| if i % 20 == 0 { Some(i) } else { None }) + .collect(); let empty_with_bitmap = PrimitiveArray::::from( ArrayData::builder(arr.data_type().clone()) .add_buffer(MutableBuffer::new(0).into()) - .null_bit_buffer(MutableBuffer::new_null(0).into()) + .null_bit_buffer(Some(MutableBuffer::new_null(0).into())) .build() .unwrap(), ); @@ -837,4 +948,22 @@ mod tests { expected_size ); } + + /// Test function that takes an &dyn Array + fn compute_my_thing(arr: &dyn Array) -> bool { + !arr.is_empty() + } + + #[test] + fn test_array_ref_as_array() { + let arr: Int32Array = vec![1, 2, 3].into_iter().map(Some).collect(); + + // works well! 
+ assert!(compute_my_thing(&arr)); + + // Should also work when wrapped as an ArrayRef + let arr: ArrayRef = Arc::new(arr); + assert!(compute_my_thing(&arr)); + assert!(compute_my_thing(arr.as_ref())); + } } diff --git a/arrow/src/array/array_binary.rs b/arrow/src/array/array_binary.rs index 89a3efd2caf2..481ea92d66c3 100644 --- a/arrow/src/array/array_binary.rs +++ b/arrow/src/array/array_binary.rs @@ -15,42 +15,47 @@ // specific language governing permissions and limitations // under the License. +use std::borrow::Borrow; use std::convert::{From, TryInto}; use std::fmt; use std::{any::Any, iter::FromIterator}; +use super::BooleanBufferBuilder; use super::{ array::print_long_array, raw_pointer::RawPtrBox, Array, ArrayData, FixedSizeListArray, GenericBinaryIter, GenericListArray, OffsetSizeTrait, }; +pub use crate::array::DecimalIter; use crate::buffer::Buffer; -use crate::error::ArrowError; +use crate::datatypes::{ + validate_decimal_precision, DECIMAL_DEFAULT_SCALE, DECIMAL_MAX_PRECISION, + DECIMAL_MAX_SCALE, +}; +use crate::error::{ArrowError, Result}; use crate::util::bit_util; +use crate::util::decimal::Decimal128; use crate::{buffer::MutableBuffer, datatypes::DataType}; -/// Like OffsetSizeTrait, but specialized for Binary -// This allow us to expose a constant datatype for the GenericBinaryArray -pub trait BinaryOffsetSizeTrait: OffsetSizeTrait { - const DATA_TYPE: DataType; -} - -impl BinaryOffsetSizeTrait for i32 { - const DATA_TYPE: DataType = DataType::Binary; -} - -impl BinaryOffsetSizeTrait for i64 { - const DATA_TYPE: DataType = DataType::LargeBinary; -} - /// See [`BinaryArray`] and [`LargeBinaryArray`] for storing /// binary data. -pub struct GenericBinaryArray { +pub struct GenericBinaryArray { data: ArrayData, value_offsets: RawPtrBox, value_data: RawPtrBox, } -impl GenericBinaryArray { +impl GenericBinaryArray { + /// Get the data type of the array. + // Declare this function as `pub const fn` after + // https://github.com/rust-lang/rust/issues/93706 is merged. + pub fn get_data_type() -> DataType { + if OffsetSize::IS_LARGE { + DataType::LargeBinary + } else { + DataType::Binary + } + } + /// Returns the length for value at index `i`. #[inline] pub fn value_length(&self, i: usize) -> OffsetSize { @@ -124,22 +129,10 @@ impl GenericBinaryArray { } /// Creates a [GenericBinaryArray] from a vector of byte slices + /// + /// See also [`Self::from_iter_values`] pub fn from_vec(v: Vec<&[u8]>) -> Self { - let mut offsets = Vec::with_capacity(v.len() + 1); - let mut values = Vec::new(); - let mut length_so_far: OffsetSize = OffsetSize::zero(); - offsets.push(length_so_far); - for s in &v { - length_so_far += OffsetSize::from_usize(s.len()).unwrap(); - offsets.push(length_so_far); - values.extend_from_slice(s); - } - let array_data = ArrayData::builder(OffsetSize::DATA_TYPE) - .len(v.len()) - .add_buffer(Buffer::from_slice_ref(&offsets)) - .add_buffer(Buffer::from_slice_ref(&values)); - let array_data = unsafe { array_data.build_unchecked() }; - GenericBinaryArray::::from(array_data) + Self::from_iter_values(v) } /// Creates a [GenericBinaryArray] from a vector of Optional (null) byte slices @@ -160,29 +153,82 @@ impl GenericBinaryArray { "BinaryArray can only be created from List arrays, mismatched data types." 
); - let mut builder = ArrayData::builder(OffsetSize::DATA_TYPE) + let builder = ArrayData::builder(Self::get_data_type()) .len(v.len()) .add_buffer(v.data_ref().buffers()[0].clone()) - .add_buffer(v.data_ref().child_data()[0].buffers()[0].clone()); - if let Some(bitmap) = v.data_ref().null_bitmap() { - builder = builder.null_bit_buffer(bitmap.bits.clone()) - } + .add_buffer(v.data_ref().child_data()[0].buffers()[0].clone()) + .null_bit_buffer(v.data_ref().null_buffer().cloned()); let data = unsafe { builder.build_unchecked() }; Self::from(data) } + + /// Creates a [`GenericBinaryArray`] based on an iterator of values without nulls + pub fn from_iter_values(iter: I) -> Self + where + Ptr: AsRef<[u8]>, + I: IntoIterator, + { + let iter = iter.into_iter(); + let (_, data_len) = iter.size_hint(); + let data_len = data_len.expect("Iterator must be sized"); // panic if no upper bound. + + let mut offsets = + MutableBuffer::new((data_len + 1) * std::mem::size_of::()); + let mut values = MutableBuffer::new(0); + + let mut length_so_far = OffsetSize::zero(); + offsets.push(length_so_far); + + for s in iter { + let s = s.as_ref(); + length_so_far += OffsetSize::from_usize(s.len()).unwrap(); + offsets.push(length_so_far); + values.extend_from_slice(s); + } + + // iterator size hint may not be correct so compute the actual number of offsets + assert!(!offsets.is_empty()); // wrote at least one + let actual_len = (offsets.len() / std::mem::size_of::()) - 1; + + let array_data = ArrayData::builder(Self::get_data_type()) + .len(actual_len) + .add_buffer(offsets.into()) + .add_buffer(values.into()); + let array_data = unsafe { array_data.build_unchecked() }; + Self::from(array_data) + } + + /// Returns an iterator that returns the values of `array.value(i)` for an iterator with each element `i` + pub fn take_iter<'a>( + &'a self, + indexes: impl Iterator> + 'a, + ) -> impl Iterator> + 'a { + indexes.map(|opt_index| opt_index.map(|index| self.value(index))) + } + + /// Returns an iterator that returns the values of `array.value(i)` for an iterator with each element `i` + /// # Safety + /// + /// caller must ensure that the indexes in the iterator are less than the `array.len()` + pub unsafe fn take_iter_unchecked<'a>( + &'a self, + indexes: impl Iterator> + 'a, + ) -> impl Iterator> + 'a { + indexes.map(|opt_index| opt_index.map(|index| self.value_unchecked(index))) + } } -impl<'a, T: BinaryOffsetSizeTrait> GenericBinaryArray { +impl<'a, T: OffsetSizeTrait> GenericBinaryArray { /// constructs a new iterator pub fn iter(&'a self) -> GenericBinaryIter<'a, T> { GenericBinaryIter::<'a, T>::new(self) } } -impl fmt::Debug for GenericBinaryArray { +impl fmt::Debug for GenericBinaryArray { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let prefix = if OffsetSize::is_large() { "Large" } else { "" }; + let prefix = if OffsetSize::IS_LARGE { "Large" } else { "" }; write!(f, "{}BinaryArray\n[\n", prefix)?; print_long_array(self, f, |array, index, f| { @@ -192,7 +238,7 @@ impl fmt::Debug for GenericBinaryArray Array for GenericBinaryArray { +impl Array for GenericBinaryArray { fn as_any(&self) -> &dyn Any { self } @@ -202,13 +248,11 @@ impl Array for GenericBinaryArray } } -impl From - for GenericBinaryArray -{ +impl From for GenericBinaryArray { fn from(data: ArrayData) -> Self { assert_eq!( data.data_type(), - &::DATA_TYPE, + &Self::get_data_type(), "[Large]BinaryArray expects Datatype::[Large]Binary" ); assert_eq!( @@ -226,7 +270,7 @@ impl From } } -impl FromIterator> +impl FromIterator> for 
GenericBinaryArray where Ptr: AsRef<[u8]>, @@ -259,19 +303,20 @@ where // calculate actual data_len, which may be different from the iterator's upper bound let data_len = offsets.len() - 1; - let array_data = ArrayData::builder(OffsetSize::DATA_TYPE) + let array_data = ArrayData::builder(Self::get_data_type()) .len(data_len) .add_buffer(Buffer::from_slice_ref(&offsets)) .add_buffer(Buffer::from_slice_ref(&values)) - .null_bit_buffer(null_buf.into()); + .null_bit_buffer(Some(null_buf.into())); let array_data = unsafe { array_data.build_unchecked() }; Self::from(array_data) } } -/// An array where each element is a byte whose maximum length is represented by a i32. +/// An array where each element contains 0 or more bytes. +/// The byte length of each element is represented by an i32. /// -/// Examples +/// # Examples /// /// Create a BinaryArray from a vector of byte slices. /// @@ -308,8 +353,10 @@ where /// pub type BinaryArray = GenericBinaryArray; -/// An array where each element is a byte whose maximum length is represented by a i64. -/// Examples +/// An array where each element contains 0 or more bytes. +/// The byte length of each element is represented by an i64. +/// +/// # Examples /// /// Create a LargeBinaryArray from a vector of byte slices. /// @@ -346,7 +393,7 @@ pub type BinaryArray = GenericBinaryArray; /// pub type LargeBinaryArray = GenericBinaryArray; -impl<'a, T: BinaryOffsetSizeTrait> IntoIterator for &'a GenericBinaryArray { +impl<'a, T: OffsetSizeTrait> IntoIterator for &'a GenericBinaryArray { type Item = Option<&'a [u8]>; type IntoIter = GenericBinaryIter<'a, T>; @@ -355,29 +402,27 @@ impl<'a, T: BinaryOffsetSizeTrait> IntoIterator for &'a GenericBinaryArray { } } -impl From>> +impl From>> for GenericBinaryArray { fn from(v: Vec>) -> Self { - GenericBinaryArray::::from_opt_vec(v) + Self::from_opt_vec(v) } } -impl From> - for GenericBinaryArray -{ +impl From> for GenericBinaryArray { fn from(v: Vec<&[u8]>) -> Self { - GenericBinaryArray::::from_vec(v) + Self::from_iter_values(v) } } -impl From> for GenericBinaryArray { +impl From> for GenericBinaryArray { fn from(v: GenericListArray) -> Self { - GenericBinaryArray::::from_list(v) + Self::from_list(v) } } -/// A type of `FixedSizeListArray` whose elements are binaries. +/// An array where each element is a fixed-size sequence of bytes. /// /// # Examples /// @@ -424,6 +469,18 @@ impl FixedSizeBinaryArray { } } + /// Returns the element at index `i` as a byte slice. + /// # Safety + /// Caller is responsible for ensuring that the index is within the bounds of the array + pub unsafe fn value_unchecked(&self, i: usize) -> &[u8] { + let offset = i.checked_add(self.data.offset()).unwrap(); + let pos = self.value_offset_at(offset); + std::slice::from_raw_parts( + self.value_data.as_ptr().offset(pos as isize), + (self.value_offset_at(offset + 1) - pos) as usize, + ) + } + /// Returns the offset for the element at index `i`. /// /// Note this doesn't do any bound checking, for performance reason. @@ -467,7 +524,7 @@ impl FixedSizeBinaryArray { /// # Errors /// /// Returns error if argument has length zero, or sizes of nested slices don't match. 
- pub fn try_from_sparse_iter(mut iter: T) -> Result + pub fn try_from_sparse_iter(mut iter: T) -> Result where T: Iterator>, U: AsRef<[u8]>, @@ -478,7 +535,7 @@ impl FixedSizeBinaryArray { let mut null_buf = MutableBuffer::from_len_zeroed(0); let mut buffer = MutableBuffer::from_len_zeroed(0); let mut prepend = 0; - iter.try_for_each(|item| -> Result<(), ArrowError> { + iter.try_for_each(|item| -> Result<()> { // extend null bitmask by one byte per each 8 items if byte == 0 { null_buf.push(0u8); @@ -551,7 +608,7 @@ impl FixedSizeBinaryArray { /// # Errors /// /// Returns error if argument has length zero, or sizes of nested slices don't match. - pub fn try_from_iter(mut iter: T) -> Result + pub fn try_from_iter(mut iter: T) -> Result where T: Iterator, U: AsRef<[u8]>, @@ -559,7 +616,7 @@ impl FixedSizeBinaryArray { let mut len = 0; let mut size = None; let mut buffer = MutableBuffer::from_len_zeroed(0); - iter.try_for_each(|item| -> Result<(), ArrowError> { + iter.try_for_each(|item| -> Result<()> { let slice = item.as_ref(); if let Some(size) = size { if size != slice.len() { @@ -634,18 +691,28 @@ impl From for FixedSizeBinaryArray { "FixedSizeBinaryArray can only be created from FixedSizeList arrays, mismatched data types." ); - let mut builder = ArrayData::builder(DataType::FixedSizeBinary(v.value_length())) + let builder = ArrayData::builder(DataType::FixedSizeBinary(v.value_length())) .len(v.len()) - .add_buffer(v.data_ref().child_data()[0].buffers()[0].clone()); - if let Some(bitmap) = v.data_ref().null_bitmap() { - builder = builder.null_bit_buffer(bitmap.bits.clone()) - } + .add_buffer(v.data_ref().child_data()[0].buffers()[0].clone()) + .null_bit_buffer(v.data_ref().null_buffer().cloned()); let data = unsafe { builder.build_unchecked() }; Self::from(data) } } +impl From>> for FixedSizeBinaryArray { + fn from(v: Vec>) -> Self { + Self::try_from_sparse_iter(v.into_iter()).unwrap() + } +} + +impl From> for FixedSizeBinaryArray { + fn from(v: Vec<&[u8]>) -> Self { + Self::try_from_iter(v.into_iter()).unwrap() + } +} + impl fmt::Debug for FixedSizeBinaryArray { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "FixedSizeBinaryArray<{}>\n[\n", self.value_length())?; @@ -666,22 +733,31 @@ impl Array for FixedSizeBinaryArray { } } -/// A type of `DecimalArray` whose elements are binaries. +/// `DecimalArray` stores fixed width decimal numbers, +/// with a fixed precision and scale. 
/// /// # Examples /// /// ``` -/// use arrow::array::{Array, DecimalArray, DecimalBuilder}; +/// use arrow::array::{Array, DecimalArray}; /// use arrow::datatypes::DataType; -/// let mut builder = DecimalBuilder::new(30, 23, 6); /// -/// builder.append_value(8_887_000_000).unwrap(); -/// builder.append_null().unwrap(); -/// builder.append_value(-8_887_000_000).unwrap(); -/// let decimal_array: DecimalArray = builder.finish(); +/// // Create a DecimalArray with the default precision and scale +/// let decimal_array: DecimalArray = vec![ +/// Some(8_887_000_000), +/// None, +/// Some(-8_887_000_000), +/// ] +/// .into_iter().collect(); +/// +/// // set precision and scale so values are interpreted +/// // as `8887.000000`, `Null`, and `-8887.000000` +/// let decimal_array = decimal_array +/// .with_precision_and_scale(23, 6) +/// .unwrap(); /// /// assert_eq!(&DataType::Decimal(23, 6), decimal_array.data_type()); -/// assert_eq!(8_887_000_000, decimal_array.value(0)); +/// assert_eq!(8_887_000_000_i128, decimal_array.value(0).as_i128()); /// assert_eq!("8887.000000", decimal_array.value_as_string(0)); /// assert_eq!(3, decimal_array.len()); /// assert_eq!(1, decimal_array.null_count()); @@ -700,8 +776,8 @@ pub struct DecimalArray { } impl DecimalArray { - /// Returns the element at index `i` as i128. - pub fn value(&self, i: usize) -> i128 { + /// Returns the element at index `i`. + pub fn value(&self, i: usize) -> Decimal128 { assert!(i < self.data.len(), "DecimalArray out of bounds access"); let offset = i.checked_add(self.data.offset()).unwrap(); let raw_val = unsafe { @@ -712,10 +788,11 @@ impl DecimalArray { ) }; let as_array = raw_val.try_into(); - match as_array { + let integer = match as_array { Ok(v) if raw_val.len() == 16 => i128::from_le_bytes(v), _ => panic!("DecimalArray elements are not 128bit integers."), - } + }; + Decimal128::new_from_i128(self.precision, self.scale, integer) } /// Returns the offset for the element at index `i`. @@ -746,23 +823,7 @@ impl DecimalArray { #[inline] pub fn value_as_string(&self, row: usize) -> String { - let value = self.value(row); - let value_str = value.to_string(); - - if self.scale == 0 { - value_str - } else { - let (sign, rest) = value_str.split_at(if value >= 0 { 0 } else { 1 }); - - if rest.len() > self.scale { - // Decimal separator is in the middle of the string - let (whole, decimal) = value_str.split_at(value_str.len() - self.scale); - format!("{}.{}", whole, decimal) - } else { - // String has to be padded - format!("{}0.{:0>width$}", sign, rest, width = self.scale) - } - } + self.value(row).as_string() } pub fn from_fixed_size_list_array( @@ -782,23 +843,101 @@ impl DecimalArray { "DecimalArray can only be created from FixedSizeList arrays, mismatched data types." 
); - let mut builder = ArrayData::builder(DataType::Decimal(precision, scale)) + let builder = ArrayData::builder(DataType::Decimal(precision, scale)) .len(v.len()) - .add_buffer(v.data_ref().child_data()[0].buffers()[0].clone()); - if let Some(bitmap) = v.data_ref().null_bitmap() { - builder = builder.null_bit_buffer(bitmap.bits.clone()) - } + .add_buffer(v.data_ref().child_data()[0].buffers()[0].clone()) + .null_bit_buffer(v.data_ref().null_buffer().cloned()); let array_data = unsafe { builder.build_unchecked() }; Self::from(array_data) } + + /// Creates a [DecimalArray] with default precision and scale, + /// based on an iterator of `i128` values without nulls + pub fn from_iter_values>(iter: I) -> Self { + let val_buf: Buffer = iter.into_iter().collect(); + let data = unsafe { + ArrayData::new_unchecked( + Self::default_type(), + val_buf.len() / std::mem::size_of::(), + None, + None, + 0, + vec![val_buf], + vec![], + ) + }; + DecimalArray::from(data) + } + + /// Return the precision (total digits) that can be stored by this array pub fn precision(&self) -> usize { self.precision } + /// Return the scale (digits after the decimal) that can be stored by this array pub fn scale(&self) -> usize { self.scale } + + /// Returns a DecimalArray with the same data as self, with the + /// specified precision. + /// + /// Returns an Error if: + /// 1. `precision` is larger than [`DECIMAL_MAX_PRECISION`] + /// 2. `scale` is larger than [`DECIMAL_MAX_SCALE`]; + /// 3. `scale` is > `precision` + pub fn with_precision_and_scale( + mut self, + precision: usize, + scale: usize, + ) -> Result { + if precision > DECIMAL_MAX_PRECISION { + return Err(ArrowError::InvalidArgumentError(format!( + "precision {} is greater than max {}", + precision, DECIMAL_MAX_PRECISION + ))); + } + if scale > DECIMAL_MAX_SCALE { + return Err(ArrowError::InvalidArgumentError(format!( + "scale {} is greater than max {}", + scale, DECIMAL_MAX_SCALE + ))); + } + if scale > precision { + return Err(ArrowError::InvalidArgumentError(format!( + "scale {} is greater than precision {}", + scale, precision + ))); + } + + // Ensure that all values are within the requested + // precision. For performance, only check if the precision is + // decreased + if precision < self.precision { + for v in self.iter().flatten() { + validate_decimal_precision(v, precision)?; + } + } + + assert_eq!( + self.data.data_type(), + &DataType::Decimal(self.precision, self.scale) + ); + + // safety: self.data is valid DataType::Decimal as checked above + let new_data_type = DataType::Decimal(precision, scale); + self.precision = precision; + self.scale = scale; + self.data = self.data.with_data_type(new_data_type); + Ok(self) + } + + /// The default precision and scale used when not specified. 
+ pub fn default_type() -> DataType { + // Keep maximum precision + DataType::Decimal(DECIMAL_MAX_PRECISION, DECIMAL_DEFAULT_SCALE) + } } impl From for DecimalArray { @@ -824,6 +963,64 @@ impl From for DecimalArray { } } +impl From for ArrayData { + fn from(array: DecimalArray) -> Self { + array.data + } +} + +impl<'a> IntoIterator for &'a DecimalArray { + type Item = Option; + type IntoIter = DecimalIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + DecimalIter::<'a>::new(self) + } +} + +impl<'a> DecimalArray { + /// constructs a new iterator + pub fn iter(&'a self) -> DecimalIter<'a> { + DecimalIter::new(self) + } +} + +impl>> FromIterator for DecimalArray { + fn from_iter>(iter: I) -> Self { + let iter = iter.into_iter(); + let (lower, upper) = iter.size_hint(); + let size_hint = upper.unwrap_or(lower); + + let mut null_buf = BooleanBufferBuilder::new(size_hint); + + let buffer: Buffer = iter + .map(|item| { + if let Some(a) = item.borrow() { + null_buf.append(true); + *a + } else { + null_buf.append(false); + // arbitrary value for NULL + 0 + } + }) + .collect(); + + let data = unsafe { + ArrayData::new_unchecked( + Self::default_type(), + null_buf.len(), + None, + Some(null_buf.into()), + 0, + vec![buffer], + vec![], + ) + }; + DecimalArray::from(data) + } +} + impl fmt::Debug for DecimalArray { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "DecimalArray<{}, {}>\n[\n", self.precision, self.scale)?; @@ -848,9 +1045,12 @@ impl Array for DecimalArray { #[cfg(test)] mod tests { + use std::sync::Arc; + use crate::{ array::{DecimalBuilder, LargeListArray, ListArray}, - datatypes::Field, + datatypes::{Field, Schema}, + record_batch::RecordBatch, }; use super::*; @@ -891,10 +1091,18 @@ mod tests { assert!(binary_array.is_valid(i)); assert!(!binary_array.is_null(i)); } + } + + #[test] + fn test_binary_array_with_offsets() { + let values: [u8; 12] = [ + b'h', b'e', b'l', b'l', b'o', b'p', b'a', b'r', b'q', b'u', b'e', b't', + ]; + let offsets: [i32; 4] = [0, 5, 5, 12]; // Test binary array with offset let array_data = ArrayData::builder(DataType::Binary) - .len(4) + .len(2) .offset(1) .add_buffer(Buffer::from_slice_ref(&offsets)) .add_buffer(Buffer::from_slice_ref(&values)) @@ -947,10 +1155,18 @@ mod tests { assert!(binary_array.is_valid(i)); assert!(!binary_array.is_null(i)); } + } + + #[test] + fn test_large_binary_array_with_offsets() { + let values: [u8; 12] = [ + b'h', b'e', b'l', b'l', b'o', b'p', b'a', b'r', b'q', b'u', b'e', b't', + ]; + let offsets: [i64; 4] = [0, 5, 5, 12]; // Test binary array with offset let array_data = ArrayData::builder(DataType::LargeBinary) - .len(4) + .len(2) .offset(1) .add_buffer(Buffer::from_slice_ref(&offsets)) .add_buffer(Buffer::from_slice_ref(&values)) @@ -1064,7 +1280,7 @@ mod tests { } } - fn test_generic_binary_array_from_opt_vec() { + fn test_generic_binary_array_from_opt_vec() { let values: Vec> = vec![Some(b"one"), Some(b"two"), None, Some(b""), Some(b"three")]; let array = GenericBinaryArray::::from_opt_vec(values); @@ -1138,7 +1354,7 @@ mod tests { .build() .unwrap(); let list_array = ListArray::from(array_data); - BinaryArray::from(list_array); + drop(BinaryArray::from(list_array)); } #[test] @@ -1196,28 +1412,30 @@ mod tests { #[test] #[should_panic( - expected = "FixedSizeBinaryArray can only be created from list array of u8 values \ - (i.e. FixedSizeList>)." 
+ expected = "FixedSizeBinaryArray can only be created from FixedSizeList arrays" )] + // Different error messages, so skip for now + // https://github.com/apache/arrow-rs/issues/1545 + #[cfg(not(feature = "force_validate"))] fn test_fixed_size_binary_array_from_incorrect_list_array() { let values: [u32; 12] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]; let values_data = ArrayData::builder(DataType::UInt32) .len(12) .add_buffer(Buffer::from_slice_ref(&values)) - .add_child_data(ArrayData::builder(DataType::Boolean).build().unwrap()) .build() .unwrap(); - let array_data = ArrayData::builder(DataType::FixedSizeList( - Box::new(Field::new("item", DataType::Binary, false)), - 4, - )) - .len(3) - .add_child_data(values_data) - .build() - .unwrap(); + let array_data = unsafe { + ArrayData::builder(DataType::FixedSizeList( + Box::new(Field::new("item", DataType::Binary, false)), + 4, + )) + .len(3) + .add_child_data(values_data) + .build_unchecked() + }; let list_array = FixedSizeListArray::from(array_data); - FixedSizeBinaryArray::from(list_array); + drop(FixedSizeBinaryArray::from(list_array)); } #[test] @@ -1260,24 +1478,130 @@ mod tests { 192, 219, 180, 17, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 64, 36, 75, 238, 253, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]; - let array_data = ArrayData::builder(DataType::Decimal(23, 6)) + let array_data = ArrayData::builder(DataType::Decimal(38, 6)) .len(2) .add_buffer(Buffer::from(&values[..])) .build() .unwrap(); let decimal_array = DecimalArray::from(array_data); - assert_eq!(8_887_000_000, decimal_array.value(0)); - assert_eq!(-8_887_000_000, decimal_array.value(1)); + assert_eq!(8_887_000_000_i128, decimal_array.value(0).into()); + assert_eq!(-8_887_000_000_i128, decimal_array.value(1).into()); assert_eq!(16, decimal_array.value_length()); } #[test] - fn test_decimal_array_value_as_string() { - let mut decimal_builder = DecimalBuilder::new(7, 5, 3); - for value in [123450, -123450, 100, -100, 10, -10, 0] { - decimal_builder.append_value(value).unwrap(); + #[cfg(not(feature = "force_validate"))] + fn test_decimal_append_error_value() { + let mut decimal_builder = DecimalBuilder::new(10, 5, 3); + let mut result = decimal_builder.append_value(123456); + let mut error = result.unwrap_err(); + assert_eq!( + "Invalid argument error: 123456 is too large to store in a Decimal of precision 5. Max is 99999", + error.to_string() + ); + + unsafe { + decimal_builder.disable_value_validation(); } + result = decimal_builder.append_value(123456); + assert!(result.is_ok()); + decimal_builder.append_value(12345).unwrap(); let arr = decimal_builder.finish(); + assert_eq!("12.345", arr.value_as_string(1)); + + decimal_builder = DecimalBuilder::new(10, 2, 1); + result = decimal_builder.append_value(100); + error = result.unwrap_err(); + assert_eq!( + "Invalid argument error: 100 is too large to store in a Decimal of precision 2. 
Max is 99", + error.to_string() + ); + + unsafe { + decimal_builder.disable_value_validation(); + } + result = decimal_builder.append_value(100); + assert!(result.is_ok()); + decimal_builder.append_value(99).unwrap(); + result = decimal_builder.append_value(-100); + assert!(result.is_ok()); + decimal_builder.append_value(-99).unwrap(); + let arr = decimal_builder.finish(); + assert_eq!("9.9", arr.value_as_string(1)); + assert_eq!("-9.9", arr.value_as_string(3)); + } + + #[test] + fn test_decimal_from_iter_values() { + let array = DecimalArray::from_iter_values(vec![-100, 0, 101].into_iter()); + assert_eq!(array.len(), 3); + assert_eq!(array.data_type(), &DataType::Decimal(38, 10)); + assert_eq!(-100_i128, array.value(0).into()); + assert!(!array.is_null(0)); + assert_eq!(0_i128, array.value(1).into()); + assert!(!array.is_null(1)); + assert_eq!(101_i128, array.value(2).into()); + assert!(!array.is_null(2)); + } + + #[test] + fn test_decimal_from_iter() { + let array: DecimalArray = vec![Some(-100), None, Some(101)].into_iter().collect(); + assert_eq!(array.len(), 3); + assert_eq!(array.data_type(), &DataType::Decimal(38, 10)); + assert_eq!(-100_i128, array.value(0).into()); + assert!(!array.is_null(0)); + assert!(array.is_null(1)); + assert_eq!(101_i128, array.value(2).into()); + assert!(!array.is_null(2)); + } + + #[test] + fn test_decimal_iter() { + let data = vec![Some(-100), None, Some(101)]; + let array: DecimalArray = data.clone().into_iter().collect(); + + let collected: Vec<_> = array.iter().collect(); + assert_eq!(data, collected); + } + + #[test] + fn test_decimal_into_iter() { + let data = vec![Some(-100), None, Some(101)]; + let array: DecimalArray = data.clone().into_iter().collect(); + + let collected: Vec<_> = array.into_iter().collect(); + assert_eq!(data, collected); + } + + #[test] + fn test_decimal_iter_sized() { + let data = vec![Some(-100), None, Some(101)]; + let array: DecimalArray = data.into_iter().collect(); + let mut iter = array.into_iter(); + + // is exact sized + assert_eq!(array.len(), 3); + + // size_hint is reported correctly + assert_eq!(iter.size_hint(), (3, Some(3))); + iter.next().unwrap(); + assert_eq!(iter.size_hint(), (2, Some(2))); + iter.next().unwrap(); + iter.next().unwrap(); + assert_eq!(iter.size_hint(), (0, Some(0))); + assert!(iter.next().is_none()); + assert_eq!(iter.size_hint(), (0, Some(0))); + } + + #[test] + fn test_decimal_array_value_as_string() { + let arr = [123450, -123450, 100, -100, 10, -10, 0] + .into_iter() + .map(Some) + .collect::() + .with_precision_and_scale(6, 3) + .unwrap(); assert_eq!("123.450", arr.value_as_string(0)); assert_eq!("-123.450", arr.value_as_string(1)); @@ -1288,16 +1612,65 @@ mod tests { assert_eq!("0.000", arr.value_as_string(6)); } + #[test] + fn test_decimal_array_with_precision_and_scale() { + let arr = DecimalArray::from_iter_values([12345, 456, 7890, -123223423432432]) + .with_precision_and_scale(20, 2) + .unwrap(); + + assert_eq!(arr.data_type(), &DataType::Decimal(20, 2)); + assert_eq!(arr.precision(), 20); + assert_eq!(arr.scale(), 2); + + let actual: Vec<_> = (0..arr.len()).map(|i| arr.value_as_string(i)).collect(); + let expected = vec!["123.45", "4.56", "78.90", "-1232234234324.32"]; + + assert_eq!(actual, expected); + } + + #[test] + #[should_panic( + expected = "-123223423432432 is too small to store in a Decimal of precision 5. 
Min is -99999" + )] + fn test_decimal_array_with_precision_and_scale_out_of_range() { + DecimalArray::from_iter_values([12345, 456, 7890, -123223423432432]) + // precision is too small to hold value + .with_precision_and_scale(5, 2) + .unwrap(); + } + + #[test] + #[should_panic(expected = "precision 40 is greater than max 38")] + fn test_decimal_array_with_precision_and_scale_invalid_precision() { + DecimalArray::from_iter_values([12345, 456]) + .with_precision_and_scale(40, 2) + .unwrap(); + } + + #[test] + #[should_panic(expected = "scale 40 is greater than max 38")] + fn test_decimal_array_with_precision_and_scale_invalid_scale() { + DecimalArray::from_iter_values([12345, 456]) + .with_precision_and_scale(20, 40) + .unwrap(); + } + + #[test] + #[should_panic(expected = "scale 10 is greater than precision 4")] + fn test_decimal_array_with_precision_and_scale_invalid_precision_and_scale() { + DecimalArray::from_iter_values([12345, 456]) + .with_precision_and_scale(4, 10) + .unwrap(); + } + #[test] fn test_decimal_array_fmt_debug() { - let values: Vec = vec![8887000000, -8887000000]; - let mut decimal_builder = DecimalBuilder::new(3, 23, 6); + let arr = [Some(8887000000), Some(-8887000000), None] + .iter() + .collect::() + .with_precision_and_scale(23, 6) + .unwrap(); - values.iter().for_each(|&value| { - decimal_builder.append_value(value).unwrap(); - }); - decimal_builder.append_null().unwrap(); - let arr = decimal_builder.finish(); assert_eq!( "DecimalArray<23, 6>\n[\n 8887.000000,\n -8887.000000,\n null,\n]", format!("{:?}", arr) @@ -1313,6 +1686,16 @@ mod tests { assert_eq!(3, arr.len()) } + #[test] + fn test_all_none_fixed_size_binary_array_from_sparse_iter() { + let none_option: Option<[u8; 32]> = None; + let input_arg = vec![none_option, none_option, none_option]; + let arr = + FixedSizeBinaryArray::try_from_sparse_iter(input_arg.into_iter()).unwrap(); + assert_eq!(0, arr.value_length()); + assert_eq!(3, arr.len()) + } + #[test] fn test_fixed_size_binary_array_from_sparse_iter() { let input_arg = vec![ @@ -1327,4 +1710,111 @@ mod tests { assert_eq!(2, arr.value_length()); assert_eq!(5, arr.len()) } + + #[test] + fn test_fixed_size_binary_array_from_vec() { + let values = vec!["one".as_bytes(), b"two", b"six", b"ten"]; + let array = FixedSizeBinaryArray::from(values); + assert_eq!(array.len(), 4); + assert_eq!(array.null_count(), 0); + assert_eq!(array.value(0), b"one"); + assert_eq!(array.value(1), b"two"); + assert_eq!(array.value(2), b"six"); + assert_eq!(array.value(3), b"ten"); + assert!(!array.is_null(0)); + assert!(!array.is_null(1)); + assert!(!array.is_null(2)); + assert!(!array.is_null(3)); + } + + #[test] + #[should_panic(expected = "Nested array size mismatch: one is 3, and the other is 5")] + fn test_fixed_size_binary_array_from_vec_incorrect_length() { + let values = vec!["one".as_bytes(), b"two", b"three", b"four"]; + let _ = FixedSizeBinaryArray::from(values); + } + + #[test] + fn test_fixed_size_binary_array_from_opt_vec() { + let values = vec![ + Some("one".as_bytes()), + Some(b"two"), + None, + Some(b"six"), + Some(b"ten"), + ]; + let array = FixedSizeBinaryArray::from(values); + assert_eq!(array.len(), 5); + assert_eq!(array.value(0), b"one"); + assert_eq!(array.value(1), b"two"); + assert_eq!(array.value(3), b"six"); + assert_eq!(array.value(4), b"ten"); + assert!(!array.is_null(0)); + assert!(!array.is_null(1)); + assert!(array.is_null(2)); + assert!(!array.is_null(3)); + assert!(!array.is_null(4)); + } + + #[test] + #[should_panic(expected = "Nested array size 
mismatch: one is 3, and the other is 5")] + fn test_fixed_size_binary_array_from_opt_vec_incorrect_length() { + let values = vec![ + Some("one".as_bytes()), + Some(b"two"), + None, + Some(b"three"), + Some(b"four"), + ]; + let _ = FixedSizeBinaryArray::from(values); + } + + #[test] + fn test_binary_array_all_null() { + let data = vec![None]; + let array = BinaryArray::from(data); + array + .data() + .validate_full() + .expect("All null array has valid array data"); + } + + #[test] + fn test_large_binary_array_all_null() { + let data = vec![None]; + let array = LargeBinaryArray::from(data); + array + .data() + .validate_full() + .expect("All null array has valid array data"); + } + + #[test] + fn fixed_size_binary_array_all_null() { + let data = vec![None] as Vec>; + let array = FixedSizeBinaryArray::try_from_sparse_iter(data.into_iter()).unwrap(); + array + .data() + .validate_full() + .expect("All null array has valid array data"); + } + + #[test] + // Test for https://github.com/apache/arrow-rs/issues/1390 + #[should_panic( + expected = "column types must match schema types, expected FixedSizeBinary(2) but found FixedSizeBinary(0) at column index 0" + )] + fn fixed_size_binary_array_all_null_in_batch_with_schema() { + let schema = + Schema::new(vec![Field::new("a", DataType::FixedSizeBinary(2), false)]); + + let none_option: Option<[u8; 2]> = None; + let item = FixedSizeBinaryArray::try_from_sparse_iter( + vec![none_option, none_option, none_option].into_iter(), + ) + .unwrap(); + + // Should not panic + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(item)]).unwrap(); + } } diff --git a/arrow/src/array/array_boolean.rs b/arrow/src/array/array_boolean.rs index 07f3da6c4147..f4e9ce28b733 100644 --- a/arrow/src/array/array_boolean.rs +++ b/arrow/src/array/array_boolean.rs @@ -122,6 +122,25 @@ impl BooleanArray { // `i < self.len() unsafe { self.value_unchecked(i) } } + + /// Returns an iterator that returns the values of `array.value(i)` for an iterator with each element `i` + pub fn take_iter<'a>( + &'a self, + indexes: impl Iterator> + 'a, + ) -> impl Iterator> + 'a { + indexes.map(|opt_index| opt_index.map(|index| self.value(index))) + } + + /// Returns an iterator that returns the values of `array.value(i)` for an iterator with each element `i` + /// # Safety + /// + /// caller must ensure that the offsets in the iterator are less than the array len() + pub unsafe fn take_iter_unchecked<'a>( + &'a self, + indexes: impl Iterator> + 'a, + ) -> impl Iterator> + 'a { + indexes.map(|opt_index| opt_index.map(|index| self.value_unchecked(index))) + } } impl Array for BooleanArray { @@ -331,11 +350,15 @@ mod tests { #[test] #[should_panic(expected = "BooleanArray data should contain a single buffer only \ (values buffer)")] + // Different error messages, so skip for now + // https://github.com/apache/arrow-rs/issues/1545 + #[cfg(not(feature = "force_validate"))] fn test_boolean_array_invalid_buffer_len() { - let data = ArrayData::builder(DataType::Boolean) - .len(5) - .build() - .unwrap(); - BooleanArray::from(data); + let data = unsafe { + ArrayData::builder(DataType::Boolean) + .len(5) + .build_unchecked() + }; + drop(BooleanArray::from(data)); } } diff --git a/arrow/src/array/array_dictionary.rs b/arrow/src/array/array_dictionary.rs index c684c253aa7b..0fbd5a34eb60 100644 --- a/arrow/src/array/array_dictionary.rs +++ b/arrow/src/array/array_dictionary.rs @@ -24,8 +24,10 @@ use super::{ make_array, Array, ArrayData, ArrayRef, PrimitiveArray, PrimitiveBuilder, StringArray, StringBuilder, 
StringDictionaryBuilder, }; -use crate::datatypes::ArrowNativeType; -use crate::datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType, DataType}; +use crate::datatypes::{ + ArrowDictionaryKeyType, ArrowNativeType, ArrowPrimitiveType, DataType, +}; +use crate::error::Result; /// A dictionary array where each element is a single value indexed by an integer key. /// This is mostly used to represent strings or a limited set of primitive types as integers, @@ -50,15 +52,31 @@ use crate::datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType, DataType}; /// let array : DictionaryArray = test.into_iter().collect(); /// assert_eq!(array.keys(), &Int8Array::from(vec![0, 0, 1, 2])); /// ``` +/// +/// Example from existing arrays: +/// +/// ``` +/// use arrow::array::{DictionaryArray, Int8Array, StringArray}; +/// use arrow::datatypes::Int8Type; +/// // You can form your own DictionaryArray by providing the +/// // values (dictionary) and keys (indexes into the dictionary): +/// let values = StringArray::from_iter_values(["a", "b", "c"]); +/// let keys = Int8Array::from_iter_values([0, 0, 1, 2]); +/// let array = DictionaryArray::::try_new(&keys, &values).unwrap(); +/// let expected: DictionaryArray:: = vec!["a", "a", "b", "c"] +/// .into_iter() +/// .collect(); +/// assert_eq!(&array, &expected); +/// ``` pub struct DictionaryArray { /// Data of this dictionary. Note that this is _not_ compatible with the C Data interface, /// as, in the current implementation, `values` below are the first child of this struct. data: ArrayData, - /// The keys of this dictionary. These are constructed from the buffer and null bitmap - /// of `data`. - /// Also, note that these do not correspond to the true values of this array. Rather, they map - /// to the real values. + /// The keys of this dictionary. These are constructed from the + /// buffer and null bitmap of `data`. Also, note that these do + /// not correspond to the true values of this array. Rather, they + /// map to the real values. keys: PrimitiveArray, /// Array of dictionary values (can by any DataType). @@ -69,6 +87,42 @@ pub struct DictionaryArray { } impl<'a, K: ArrowPrimitiveType> DictionaryArray { + /// Attempt to create a new DictionaryArray with a specified keys + /// (indexes into the dictionary) and values (dictionary) + /// array. Returns an error if there are any keys that are outside + /// of the dictionary array. + pub fn try_new(keys: &PrimitiveArray, values: &dyn Array) -> Result { + let dict_data_type = DataType::Dictionary( + Box::new(keys.data_type().clone()), + Box::new(values.data_type().clone()), + ); + + // Note: This use the ArrayDataBuilder::build_unchecked and afterwards + // call the new function which only validates that the keys are in bounds. + let mut data = ArrayData::builder(dict_data_type) + .len(keys.len()) + .add_buffer(keys.data().buffers()[0].clone()) + .add_child_data(values.data().clone()); + + match keys.data().null_buffer() { + Some(buffer) if keys.data().null_count() > 0 => { + data = data + .null_bit_buffer(Some(buffer.clone())) + .null_count(keys.data().null_count()); + } + _ => data = data.null_count(0), + } + + // Safety: `validate` ensures key type is correct, and + // `validate_values` ensures all offsets are within range + let array = unsafe { data.build_unchecked() }; + + array.validate()?; + array.validate_values()?; + + Ok(array.into()) + } + /// Return an array view of the keys of this dictionary as a PrimitiveArray. 
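(Reviewer note, not part of the patch: the `try_new` constructor added above only succeeds when every key is a valid index into the dictionary values. A minimal sketch of the error path, mirroring the `test_try_new_index_too_large` test later in this hunk:)

```rust
use arrow::array::{DictionaryArray, Int32Array, StringArray};
use arrow::datatypes::Int32Type;

let values = StringArray::from_iter_values(["foo", "bar"]);
// key 3 points past the two dictionary values, so validation fails
let keys = Int32Array::from_iter_values([0, 3]);
assert!(DictionaryArray::<Int32Type>::try_new(&keys, &values).is_err());
```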
pub fn keys(&self) -> &PrimitiveArray { &self.keys @@ -81,8 +135,7 @@ impl<'a, K: ArrowPrimitiveType> DictionaryArray { (0..rd_buf.len()) .position(|i| rd_buf.value(i) == value) - .map(K::Native::from_usize) - .flatten() + .and_then(K::Native::from_usize) } /// Returns a reference to the dictionary values array @@ -105,10 +158,17 @@ impl<'a, K: ArrowPrimitiveType> DictionaryArray { self.keys.is_empty() } - // Currently exists for compatibility purposes with Arrow IPC. + /// Currently exists for compatibility purposes with Arrow IPC. pub fn is_ordered(&self) -> bool { self.is_ordered } + + /// Return an iterator over the keys (indexes into the dictionary) + pub fn keys_iter(&self) -> impl Iterator> + '_ { + self.keys + .iter() + .map(|key| key.map(|k| k.to_usize().expect("Dictionary index not usize"))) + } } /// Constructs a `DictionaryArray` from an array data reference. @@ -256,14 +316,16 @@ impl fmt::Debug for DictionaryArray { mod tests { use super::*; - use crate::{ - array::Int16Array, - datatypes::{Int32Type, Int8Type, UInt32Type, UInt8Type}, - }; + use crate::array::{Float32Array, Int8Array}; + use crate::datatypes::{Float32Type, Int16Type}; use crate::{ array::Int16DictionaryArray, array::PrimitiveDictionaryBuilder, datatypes::DataType, }; + use crate::{ + array::{Int16Array, Int32Array}, + datatypes::{Int32Type, Int8Type, UInt32Type, UInt8Type}, + }; use crate::{buffer::Buffer, datatypes::ToByteSlice}; #[test] @@ -412,4 +474,120 @@ mod tests { assert_eq!(1, keys.value(2)); assert_eq!(0, keys.value(5)); } + + #[test] + fn test_dictionary_all_nulls() { + let test = vec![None, None, None]; + let array: DictionaryArray = test.into_iter().collect(); + array + .data() + .validate_full() + .expect("All null array has valid array data"); + } + + #[test] + fn test_dictionary_iter() { + // Construct a value array + let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); + let keys = Int16Array::from_iter_values([2_i16, 3, 4]); + + // Construct a dictionary array from the above two + let dict_array = DictionaryArray::::try_new(&keys, &values).unwrap(); + + let mut key_iter = dict_array.keys_iter(); + assert_eq!(2, key_iter.next().unwrap().unwrap()); + assert_eq!(3, key_iter.next().unwrap().unwrap()); + assert_eq!(4, key_iter.next().unwrap().unwrap()); + assert!(key_iter.next().is_none()); + + let mut iter = dict_array + .values() + .as_any() + .downcast_ref::() + .unwrap() + .take_iter(dict_array.keys_iter()); + + assert_eq!(12, iter.next().unwrap().unwrap()); + assert_eq!(13, iter.next().unwrap().unwrap()); + assert_eq!(14, iter.next().unwrap().unwrap()); + assert!(iter.next().is_none()); + } + + #[test] + fn test_dictionary_iter_with_null() { + let test = vec![Some("a"), None, Some("b"), None, None, Some("a")]; + let array: DictionaryArray = test.into_iter().collect(); + + let mut iter = array + .values() + .as_any() + .downcast_ref::() + .unwrap() + .take_iter(array.keys_iter()); + + assert_eq!("a", iter.next().unwrap().unwrap()); + assert!(iter.next().unwrap().is_none()); + assert_eq!("b", iter.next().unwrap().unwrap()); + assert!(iter.next().unwrap().is_none()); + assert!(iter.next().unwrap().is_none()); + assert_eq!("a", iter.next().unwrap().unwrap()); + assert!(iter.next().is_none()); + } + + #[test] + fn test_try_new() { + let values: StringArray = [Some("foo"), Some("bar"), Some("baz")] + .into_iter() + .collect(); + let keys: Int32Array = [Some(0), Some(2), None, Some(1)].into_iter().collect(); + + let array = DictionaryArray::::try_new(&keys, 
&values).unwrap(); + assert_eq!(array.keys().data_type(), &DataType::Int32); + assert_eq!(array.values().data_type(), &DataType::Utf8); + + assert_eq!(array.data().null_count(), 1); + + assert!(array.keys().is_valid(0)); + assert!(array.keys().is_valid(1)); + assert!(array.keys().is_null(2)); + assert!(array.keys().is_valid(3)); + + assert_eq!(array.keys().value(0), 0); + assert_eq!(array.keys().value(1), 2); + assert_eq!(array.keys().value(3), 1); + + assert_eq!( + "DictionaryArray {keys: PrimitiveArray\n[\n 0,\n 2,\n null,\n 1,\n] values: StringArray\n[\n \"foo\",\n \"bar\",\n \"baz\",\n]}\n", + format!("{:?}", array) + ); + } + + #[test] + #[should_panic( + expected = "Value at position 1 out of bounds: 3 (should be in [0, 1])" + )] + fn test_try_new_index_too_large() { + let values: StringArray = [Some("foo"), Some("bar")].into_iter().collect(); + // dictionary only has 2 values, so offset 3 is out of bounds + let keys: Int32Array = [Some(0), Some(3)].into_iter().collect(); + DictionaryArray::::try_new(&keys, &values).unwrap(); + } + + #[test] + #[should_panic( + expected = "Value at position 0 out of bounds: -100 (should be in [0, 1])" + )] + fn test_try_new_index_too_small() { + let values: StringArray = [Some("foo"), Some("bar")].into_iter().collect(); + let keys: Int32Array = [Some(-100)].into_iter().collect(); + DictionaryArray::::try_new(&keys, &values).unwrap(); + } + + #[test] + #[should_panic(expected = "Dictionary key type must be integer, but was Float32")] + fn test_try_wrong_dictionary_key_type() { + let values: StringArray = [Some("foo"), Some("bar")].into_iter().collect(); + let keys: Float32Array = [Some(0_f32), None, Some(3_f32)].into_iter().collect(); + DictionaryArray::::try_new(&keys, &values).unwrap(); + } } diff --git a/arrow/src/array/array_list.rs b/arrow/src/array/array_list.rs index fbba8fcf412d..709e4e7ba7d0 100644 --- a/arrow/src/array/array_list.rs +++ b/arrow/src/array/array_list.rs @@ -18,7 +18,7 @@ use std::any::Any; use std::fmt; -use num::Num; +use num::Integer; use super::{ array::print_long_array, make_array, raw_pointer::RawPtrBox, Array, ArrayData, @@ -31,25 +31,22 @@ use crate::{ }; /// trait declaring an offset size, relevant for i32 vs i64 array types. -pub trait OffsetSizeTrait: ArrowNativeType + Num + Ord + std::ops::AddAssign { - fn is_large() -> bool; +pub trait OffsetSizeTrait: ArrowNativeType + std::ops::AddAssign + Integer { + const IS_LARGE: bool; } impl OffsetSizeTrait for i32 { - #[inline] - fn is_large() -> bool { - false - } + const IS_LARGE: bool = false; } impl OffsetSizeTrait for i64 { - #[inline] - fn is_large() -> bool { - true - } + const IS_LARGE: bool = true; } -/// Generic struct for a primitive Array +/// Generic struct for a variable-size list array. 
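(Reviewer note on the `OffsetSizeTrait` change above: the runtime `is_large()` method becomes an associated constant, so it can be evaluated in const contexts and matched on directly, as the rewritten `get_type` below does. A tiny sketch of the new usage:)

```rust
use arrow::array::OffsetSizeTrait;

// i32 offsets back ListArray/StringArray; i64 offsets back the Large* variants
assert!(!<i32 as OffsetSizeTrait>::IS_LARGE);
assert!(<i64 as OffsetSizeTrait>::IS_LARGE);
```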
+/// +/// Columnar format in Apache Arrow: +/// /// /// For non generic lists, you may wish to consider using [`ListArray`] or [`LargeListArray`]` pub struct GenericListArray { @@ -115,16 +112,11 @@ impl GenericListArray { #[inline] fn get_type(data_type: &DataType) -> Option<&DataType> { - if OffsetSize::is_large() { - if let DataType::LargeList(child) = data_type { + match (OffsetSize::IS_LARGE, data_type) { + (true, DataType::LargeList(child)) | (false, DataType::List(child)) => { Some(child.data_type()) - } else { - None } - } else if let DataType::List(child) = data_type { - Some(child.data_type()) - } else { - None + _ => None, } } @@ -177,7 +169,7 @@ impl GenericListArray { .collect(); let field = Box::new(Field::new("item", T::DATA_TYPE, true)); - let data_type = if OffsetSize::is_large() { + let data_type = if OffsetSize::IS_LARGE { DataType::LargeList(field) } else { DataType::List(field) @@ -186,7 +178,7 @@ impl GenericListArray { .len(null_buf.len()) .add_buffer(offsets.into()) .add_child_data(values.data().clone()) - .null_bit_buffer(null_buf.into()); + .null_bit_buffer(Some(null_buf.into())); let array_data = unsafe { array_data.build_unchecked() }; Self::from(array_data) @@ -236,15 +228,7 @@ impl GenericListArray { let values = make_array(values); let value_offsets = data.buffers()[0].as_ptr(); - let value_offsets = unsafe { RawPtrBox::::new(value_offsets) }; - unsafe { - if !(*value_offsets.as_ptr().offset(0)).is_zero() { - return Err(ArrowError::InvalidArgumentError(String::from( - "offsets do not start at zero", - ))); - } - } Ok(Self { data, values, @@ -265,7 +249,7 @@ impl Array for GenericListArray fmt::Debug for GenericListArray { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let prefix = if OffsetSize::is_large() { "Large" } else { "" }; + let prefix = if OffsetSize::IS_LARGE { "Large" } else { "" }; write!(f, "{}ListArray\n[\n", prefix)?; print_long_array(self, f, |array, index, f| { @@ -281,30 +265,25 @@ impl fmt::Debug for GenericListArray { /// # Example /// /// ``` -/// # use arrow::array::{Array, ListArray, Int32Array}; -/// # use arrow::datatypes::{DataType, Int32Type}; -/// let data = vec![ -/// Some(vec![Some(0), Some(1), Some(2)]), -/// None, -/// Some(vec![Some(3), None, Some(5), Some(19)]), -/// Some(vec![Some(6), Some(7)]), -/// ]; -/// let list_array = ListArray::from_iter_primitive::(data); -/// assert_eq!(DataType::Int32, list_array.value_type()); -/// assert_eq!(4, list_array.len()); -/// assert_eq!(1, list_array.null_count()); -/// assert_eq!(3, list_array.value_length(0)); -/// assert_eq!(0, list_array.value_length(1)); -/// assert_eq!(4, list_array.value_length(2)); -/// assert_eq!( -/// 19, -/// list_array -/// .value(2) -/// .as_any() -/// .downcast_ref::() -/// .unwrap() -/// .value(3) -/// ) +/// # use arrow::array::{Array, ListArray, Int32Array}; +/// # use arrow::datatypes::{DataType, Int32Type}; +/// let data = vec![ +/// Some(vec![]), +/// None, +/// Some(vec![Some(3), None, Some(5), Some(19)]), +/// Some(vec![Some(6), Some(7)]), +/// ]; +/// let list_array = ListArray::from_iter_primitive::(data); +/// +/// assert_eq!(false, list_array.is_valid(1)); +/// +/// let list0 = list_array.value(0); +/// let list2 = list_array.value(2); +/// let list3 = list_array.value(3); +/// +/// assert_eq!(&[] as &[i32], list0.as_any().downcast_ref::().unwrap().values()); +/// assert_eq!(false, list2.as_any().downcast_ref::().unwrap().is_valid(1)); +/// assert_eq!(&[6, 7], list3.as_any().downcast_ref::().unwrap().values()); /// ``` pub type 
ListArray = GenericListArray; @@ -313,37 +292,64 @@ pub type ListArray = GenericListArray; /// # Example /// /// ``` -/// # use arrow::array::{Array, LargeListArray, Int64Array}; -/// # use arrow::datatypes::{DataType, Int64Type}; -/// let data = vec![ -/// Some(vec![Some(0), Some(1), Some(2)]), -/// None, -/// Some(vec![Some(3), None, Some(5), Some(19)]), -/// Some(vec![Some(6), Some(7)]), -/// ]; -/// let list_array = LargeListArray::from_iter_primitive::(data); -/// assert_eq!(DataType::Int64, list_array.value_type()); -/// assert_eq!(4, list_array.len()); -/// assert_eq!(1, list_array.null_count()); -/// assert_eq!(3, list_array.value_length(0)); -/// assert_eq!(0, list_array.value_length(1)); -/// assert_eq!(4, list_array.value_length(2)); -/// assert_eq!( -/// 19, -/// list_array -/// .value(2) -/// .as_any() -/// .downcast_ref::() -/// .unwrap() -/// .value(3) -/// ) +/// # use arrow::array::{Array, LargeListArray, Int32Array}; +/// # use arrow::datatypes::{DataType, Int32Type}; +/// let data = vec![ +/// Some(vec![]), +/// None, +/// Some(vec![Some(3), None, Some(5), Some(19)]), +/// Some(vec![Some(6), Some(7)]), +/// ]; +/// let list_array = LargeListArray::from_iter_primitive::(data); +/// +/// assert_eq!(false, list_array.is_valid(1)); +/// +/// let list0 = list_array.value(0); +/// let list2 = list_array.value(2); +/// let list3 = list_array.value(3); +/// +/// assert_eq!(&[] as &[i32], list0.as_any().downcast_ref::().unwrap().values()); +/// assert_eq!(false, list2.as_any().downcast_ref::().unwrap().is_valid(1)); +/// assert_eq!(&[6, 7], list3.as_any().downcast_ref::().unwrap().values()); /// ``` pub type LargeListArray = GenericListArray; /// A list array where each element is a fixed-size sequence of values with the same /// type whose maximum length is represented by a i32. 
/// -/// For non generic lists, you may wish to consider using [`FixedSizeBinaryArray`] +/// # Example +/// +/// ``` +/// # use arrow::array::{Array, ArrayData, FixedSizeListArray, Int32Array}; +/// # use arrow::datatypes::{DataType, Field}; +/// # use arrow::buffer::Buffer; +/// // Construct a value array +/// let value_data = ArrayData::builder(DataType::Int32) +/// .len(9) +/// .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7, 8])) +/// .build() +/// .unwrap(); +/// let list_data_type = DataType::FixedSizeList( +/// Box::new(Field::new("item", DataType::Int32, false)), +/// 3, +/// ); +/// let list_data = ArrayData::builder(list_data_type.clone()) +/// .len(3) +/// .add_child_data(value_data.clone()) +/// .build() +/// .unwrap(); +/// let list_array = FixedSizeListArray::from(list_data); +/// let list0 = list_array.value(0); +/// let list1 = list_array.value(1); +/// let list2 = list_array.value(2); +/// +/// assert_eq!( &[0, 1, 2], list0.as_any().downcast_ref::().unwrap().values()); +/// assert_eq!( &[3, 4, 5], list1.as_any().downcast_ref::().unwrap().values()); +/// assert_eq!( &[6, 7, 8], list2.as_any().downcast_ref::().unwrap().values()); +/// ``` +/// +/// For non generic lists, you may wish to consider using +/// [crate::array::FixedSizeBinaryArray] pub struct FixedSizeListArray { data: ArrayData, values: ArrayRef, @@ -499,6 +505,32 @@ mod tests { assert_eq!(list_array, another) } + #[test] + fn test_empty_list_array() { + // Construct an empty value array + let value_data = ArrayData::builder(DataType::Int32) + .len(0) + .add_buffer(Buffer::from([])) + .build() + .unwrap(); + + // Construct an empty offset buffer + let value_offsets = Buffer::from([]); + + // Construct a list array from the above two + let list_data_type = + DataType::List(Box::new(Field::new("item", DataType::Int32, false))); + let list_data = ArrayData::builder(list_data_type) + .len(0) + .add_buffer(value_offsets) + .add_child_data(value_data) + .build() + .unwrap(); + + let list_array = ListArray::from(list_data); + assert_eq!(list_array.len(), 0) + } + #[test] fn test_list_array() { // Construct a value array @@ -552,9 +584,10 @@ mod tests { assert!(!list_array.is_null(i)); } - // Now test with a non-zero offset + // Now test with a non-zero offset (skip first element) + // [[3, 4, 5], [6, 7]] let list_data = ArrayData::builder(list_data_type) - .len(3) + .len(2) .offset(1) .add_buffer(value_offsets) .add_child_data(value_data.clone()) @@ -565,7 +598,7 @@ mod tests { let values = list_array.values(); assert_eq!(&value_data, values.data()); assert_eq!(DataType::Int32, list_array.value_type()); - assert_eq!(3, list_array.len()); + assert_eq!(2, list_array.len()); assert_eq!(0, list_array.null_count()); assert_eq!(6, list_array.value_offsets()[1]); assert_eq!(2, list_array.value_length(1)); @@ -642,8 +675,9 @@ mod tests { } // Now test with a non-zero offset + // [[3, 4, 5], [6, 7]] let list_data = ArrayData::builder(list_data_type) - .len(3) + .len(2) .offset(1) .add_buffer(value_offsets) .add_child_data(value_data.clone()) @@ -654,7 +688,7 @@ mod tests { let values = list_array.values(); assert_eq!(&value_data, values.data()); assert_eq!(DataType::Int32, list_array.value_type()); - assert_eq!(3, list_array.len()); + assert_eq!(2, list_array.len()); assert_eq!(0, list_array.null_count()); assert_eq!(6, list_array.value_offsets()[1]); assert_eq!(2, list_array.value_length(1)); @@ -750,6 +784,9 @@ mod tests { #[should_panic( expected = "FixedSizeListArray child array length should be a multiple of 3" 
)] + // Different error messages, so skip for now + // https://github.com/apache/arrow-rs/issues/1545 + #[cfg(not(feature = "force_validate"))] fn test_fixed_size_list_array_unequal_children() { // Construct a value array let value_data = ArrayData::builder(DataType::Int32) @@ -763,12 +800,13 @@ mod tests { Box::new(Field::new("item", DataType::Int32, false)), 3, ); - let list_data = ArrayData::builder(list_data_type) - .len(3) - .add_child_data(value_data) - .build() - .unwrap(); - FixedSizeListArray::from(list_data); + let list_data = unsafe { + ArrayData::builder(list_data_type) + .len(3) + .add_child_data(value_data) + .build_unchecked() + }; + drop(FixedSizeListArray::from(list_data)); } #[test] @@ -798,7 +836,7 @@ mod tests { .len(9) .add_buffer(value_offsets) .add_child_data(value_data.clone()) - .null_bit_buffer(Buffer::from(null_bits)) + .null_bit_buffer(Some(Buffer::from(null_bits))) .build() .unwrap(); let list_array = ListArray::from(list_data); @@ -862,7 +900,7 @@ mod tests { .len(9) .add_buffer(value_offsets) .add_child_data(value_data.clone()) - .null_bit_buffer(Buffer::from(null_bits)) + .null_bit_buffer(Some(Buffer::from(null_bits))) .build() .unwrap(); let list_array = LargeListArray::from(list_data); @@ -929,7 +967,7 @@ mod tests { .len(9) .add_buffer(value_offsets) .add_child_data(value_data) - .null_bit_buffer(Buffer::from(null_bits)) + .null_bit_buffer(Some(Buffer::from(null_bits))) .build() .unwrap(); let list_array = LargeListArray::from(list_data); @@ -963,7 +1001,7 @@ mod tests { let list_data = ArrayData::builder(list_data_type) .len(5) .add_child_data(value_data.clone()) - .null_bit_buffer(Buffer::from(null_bits)) + .null_bit_buffer(Some(Buffer::from(null_bits))) .build() .unwrap(); let list_array = FixedSizeListArray::from(list_data); @@ -1025,7 +1063,7 @@ mod tests { let list_data = ArrayData::builder(list_data_type) .len(5) .add_child_data(value_data) - .null_bit_buffer(Buffer::from(null_bits)) + .null_bit_buffer(Some(Buffer::from(null_bits))) .build() .unwrap(); let list_array = FixedSizeListArray::from(list_data); @@ -1037,41 +1075,49 @@ mod tests { #[should_panic( expected = "ListArray data should contain a single buffer only (value offsets)" )] + // Different error messages, so skip for now + // https://github.com/apache/arrow-rs/issues/1545 + #[cfg(not(feature = "force_validate"))] fn test_list_array_invalid_buffer_len() { - let value_data = ArrayData::builder(DataType::Int32) - .len(8) - .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7])) - .build() - .unwrap(); + let value_data = unsafe { + ArrayData::builder(DataType::Int32) + .len(8) + .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7])) + .build_unchecked() + }; let list_data_type = DataType::List(Box::new(Field::new("item", DataType::Int32, false))); - let list_data = ArrayData::builder(list_data_type) - .len(3) - .add_child_data(value_data) - .build() - .unwrap(); - ListArray::from(list_data); + let list_data = unsafe { + ArrayData::builder(list_data_type) + .len(3) + .add_child_data(value_data) + .build_unchecked() + }; + drop(ListArray::from(list_data)); } #[test] #[should_panic( expected = "ListArray should contain a single child array (values array)" )] + // Different error messages, so skip for now + // https://github.com/apache/arrow-rs/issues/1545 + #[cfg(not(feature = "force_validate"))] fn test_list_array_invalid_child_array_len() { let value_offsets = Buffer::from_slice_ref(&[0, 2, 5, 7]); let list_data_type = DataType::List(Box::new(Field::new("item", 
DataType::Int32, false))); - let list_data = ArrayData::builder(list_data_type) - .len(3) - .add_buffer(value_offsets) - .build() - .unwrap(); - ListArray::from(list_data); + let list_data = unsafe { + ArrayData::builder(list_data_type) + .len(3) + .add_buffer(value_offsets) + .build_unchecked() + }; + drop(ListArray::from(list_data)); } #[test] - #[should_panic(expected = "offsets do not start at zero")] - fn test_list_array_invalid_value_offset_start() { + fn test_list_array_offsets_need_not_start_at_zero() { let value_data = ArrayData::builder(DataType::Int32) .len(8) .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7])) @@ -1088,7 +1134,11 @@ mod tests { .add_child_data(value_data) .build() .unwrap(); - ListArray::from(list_data); + + let list_array = ListArray::from(list_data); + assert_eq!(list_array.value_length(0), 0); + assert_eq!(list_array.value_length(1), 3); + assert_eq!(list_array.value_length(2), 2); } #[test] @@ -1101,30 +1151,35 @@ mod tests { .add_buffer(buf2) .build() .unwrap(); - Int32Array::from(array_data); + drop(Int32Array::from(array_data)); } #[test] #[should_panic(expected = "memory is not aligned")] + // Different error messages, so skip for now + // https://github.com/apache/arrow-rs/issues/1545 + #[cfg(not(feature = "force_validate"))] fn test_list_array_alignment() { let ptr = alloc::allocate_aligned::(8); let buf = unsafe { Buffer::from_raw_parts(ptr, 8, 8) }; let buf2 = buf.slice(1); let values: [i32; 8] = [0; 8]; - let value_data = ArrayData::builder(DataType::Int32) - .add_buffer(Buffer::from_slice_ref(&values)) - .build() - .unwrap(); + let value_data = unsafe { + ArrayData::builder(DataType::Int32) + .add_buffer(Buffer::from_slice_ref(&values)) + .build_unchecked() + }; let list_data_type = DataType::List(Box::new(Field::new("item", DataType::Int32, false))); - let list_data = ArrayData::builder(list_data_type) - .add_buffer(buf2) - .add_child_data(value_data) - .build() - .unwrap(); - ListArray::from(list_data); + let list_data = unsafe { + ArrayData::builder(list_data_type) + .add_buffer(buf2) + .add_child_data(value_data) + .build_unchecked() + }; + drop(ListArray::from(list_data)); } #[test] diff --git a/arrow/src/array/array_map.rs b/arrow/src/array/array_map.rs index bd888ff83e9b..081362021aa0 100644 --- a/arrow/src/array/array_map.rs +++ b/arrow/src/array/array_map.rs @@ -15,22 +15,25 @@ // specific language governing permissions and limitations // under the License. +use crate::array::{StringArray, StructArray}; +use crate::buffer::Buffer; use std::any::Any; use std::fmt; use std::mem; +use std::sync::Arc; use super::make_array; use super::{ array::print_long_array, raw_pointer::RawPtrBox, Array, ArrayData, ArrayRef, }; -use crate::datatypes::{ArrowNativeType, DataType}; +use crate::datatypes::{ArrowNativeType, DataType, Field, ToByteSlice}; use crate::error::ArrowError; /// A nested array type where each record is a key-value map. /// Keys should always be non-null, but values can be null. /// -/// [MapArray] is physically a [ListArray] that has a [StructArray] -/// with 2 child fields. +/// [MapArray] is physically a [crate::array::ListArray] that has a +/// [crate::array::StructArray] with 2 child fields. pub struct MapArray { data: ArrayData, values: ArrayRef, @@ -152,6 +155,44 @@ impl MapArray { value_offsets, }) } + + /// Creates map array from provided keys, values and entry_offsets. 
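(Reviewer note: the doc line above introduces `new_from_strings`, defined just below. A sketch of how it is driven, condensed from the `test_new_from_strings` test later in this hunk:)

```rust
use arrow::array::{Array, MapArray, UInt32Array};

// three maps: {a:0, b:10, c:20}, {d:30, e:40, f:50}, {g:60, h:70}
let keys = vec!["a", "b", "c", "d", "e", "f", "g", "h"];
let values = UInt32Array::from(vec![0u32, 10, 20, 30, 40, 50, 60, 70]);
// entry_offsets delimit the entries of each map: [0,3), [3,6), [6,8)
let entry_offsets = [0, 3, 6, 8];

let map = MapArray::new_from_strings(keys.into_iter(), &values, &entry_offsets).unwrap();
assert_eq!(map.len(), 3);
assert_eq!(map.value_length(2), 2);
```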
+ pub fn new_from_strings<'a>( + keys: impl Iterator, + values: &dyn Array, + entry_offsets: &[u32], + ) -> Result { + let entry_offsets_buffer = Buffer::from(entry_offsets.to_byte_slice()); + let keys_data = StringArray::from_iter_values(keys); + + let keys_field = Field::new("keys", DataType::Utf8, false); + let values_field = Field::new( + "values", + values.data_type().clone(), + values.null_count() > 0, + ); + + let entry_struct = StructArray::from(vec![ + (keys_field, Arc::new(keys_data) as ArrayRef), + (values_field, make_array(values.data().clone())), + ]); + + let map_data_type = DataType::Map( + Box::new(Field::new( + "entries", + entry_struct.data_type().clone(), + true, + )), + false, + ); + let map_data = ArrayData::builder(map_data_type) + .len(entry_offsets.len() - 1) + .add_buffer(entry_offsets_buffer) + .add_child_data(entry_struct.data().clone()) + .build()?; + + Ok(MapArray::from(map_data)) + } } impl Array for MapArray { @@ -255,7 +296,7 @@ mod tests { .add_buffer(Buffer::from( &[0u32, 10, 20, 0, 40, 0, 60, 70].to_byte_slice(), )) - .null_bit_buffer(Buffer::from(&[0b11010110])) + .null_bit_buffer(Some(Buffer::from(&[0b11010110]))) .build() .unwrap(); @@ -320,7 +361,7 @@ mod tests { // Now test with a non-zero offset let map_data = ArrayData::builder(map_array.data_type().clone()) - .len(3) + .len(2) .offset(1) .add_buffer(map_array.data().buffers()[0].clone()) .add_child_data(map_array.data().child_data()[0].clone()) @@ -331,7 +372,7 @@ mod tests { let values = map_array.values(); assert_eq!(&value_data, values.data()); assert_eq!(DataType::UInt32, map_array.value_type()); - assert_eq!(3, map_array.len()); + assert_eq!(2, map_array.len()); assert_eq!(0, map_array.null_count()); assert_eq!(6, map_array.value_offsets()[1]); assert_eq!(2, map_array.value_length(1)); @@ -428,4 +469,54 @@ mod tests { map_array.value(map_array.len()); } + + #[test] + fn test_new_from_strings() { + let keys = vec!["a", "b", "c", "d", "e", "f", "g", "h"]; + let values_data = UInt32Array::from(vec![0u32, 10, 20, 30, 40, 50, 60, 70]); + + // Construct a buffer for value offsets, for the nested array: + // [[a, b, c], [d, e, f], [g, h]] + let entry_offsets = [0, 3, 6, 8]; + + let map_array = MapArray::new_from_strings( + keys.clone().into_iter(), + &values_data, + &entry_offsets, + ) + .unwrap(); + + let values = map_array.values(); + assert_eq!( + &values_data, + values.as_any().downcast_ref::().unwrap() + ); + assert_eq!(DataType::UInt32, map_array.value_type()); + assert_eq!(3, map_array.len()); + assert_eq!(0, map_array.null_count()); + assert_eq!(6, map_array.value_offsets()[2]); + assert_eq!(2, map_array.value_length(2)); + + let key_array = Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef; + let value_array = Arc::new(UInt32Array::from(vec![0u32, 10, 20])) as ArrayRef; + let keys_field = Field::new("keys", DataType::Utf8, false); + let values_field = Field::new("values", DataType::UInt32, false); + let struct_array = + StructArray::from(vec![(keys_field, key_array), (values_field, value_array)]); + assert_eq!( + struct_array, + StructArray::from(map_array.value(0).data().clone()) + ); + assert_eq!( + &struct_array, + unsafe { map_array.value_unchecked(0) } + .as_any() + .downcast_ref::() + .unwrap() + ); + for i in 0..3 { + assert!(map_array.is_valid(i)); + assert!(!map_array.is_null(i)); + } + } } diff --git a/arrow/src/array/array_primitive.rs b/arrow/src/array/array_primitive.rs index a93e703946d1..6f496562f896 100644 --- a/arrow/src/array/array_primitive.rs +++ 
b/arrow/src/array/array_primitive.rs @@ -16,7 +16,6 @@ // under the License. use std::any::Any; -use std::borrow::Borrow; use std::convert::From; use std::fmt; use std::iter::{FromIterator, IntoIterator}; @@ -34,14 +33,7 @@ use crate::{ util::trusted_len_unzip, }; -/// Number of seconds in a day -const SECONDS_IN_DAY: i64 = 86_400; -/// Number of milliseconds in a second -const MILLISECONDS: i64 = 1_000; -/// Number of microseconds in a second -const MICROSECONDS: i64 = 1_000_000; -/// Number of nanoseconds in a second -const NANOSECONDS: i64 = 1_000_000_000; +use half::f16; /// Array whose elements are of primitive types. /// @@ -140,7 +132,7 @@ impl PrimitiveArray { /// Creates a PrimitiveArray based on a constant value with `count` elements pub fn from_value(value: T::Native, count: usize) -> Self { - // # Safety: length is known + // # Safety: iterator (0..count) correctly reports its length let val_buf = unsafe { Buffer::from_trusted_len_iter((0..count).map(|_| value)) }; let data = unsafe { ArrayData::new_unchecked( @@ -155,6 +147,25 @@ impl PrimitiveArray { }; PrimitiveArray::from(data) } + + /// Returns an iterator that returns the values of `array.value(i)` for an iterator with each element `i` + pub fn take_iter<'a>( + &'a self, + indexes: impl Iterator> + 'a, + ) -> impl Iterator> + 'a { + indexes.map(|opt_index| opt_index.map(|index| self.value(index))) + } + + /// Returns an iterator that returns the values of `array.value(i)` for an iterator with each element `i` + /// # Safety + /// + /// caller must ensure that the offsets in the iterator are less than the array len() + pub unsafe fn take_iter_unchecked<'a>( + &'a self, + indexes: impl Iterator> + 'a, + ) -> impl Iterator> + 'a { + indexes.map(|opt_index| opt_index.map(|index| self.value_unchecked(index))) + } } impl Array for PrimitiveArray { @@ -330,36 +341,91 @@ impl<'a, T: ArrowPrimitiveType> PrimitiveArray { } } -impl::Native>>> - FromIterator for PrimitiveArray +/// This struct is used as an adapter when creating `PrimitiveArray` from an iterator. +/// `FromIterator` for `PrimitiveArray` takes an iterator where the elements can be `into` +/// this struct. So once implementing `From` or `Into` trait for a type, an iterator of +/// the type can be collected to `PrimitiveArray`. +#[derive(Debug)] +pub struct NativeAdapter { + pub native: Option, +} + +macro_rules! 
def_from_for_primitive { + ( $ty:ident, $tt:tt) => { + impl From<$tt> for NativeAdapter<$ty> { + fn from(value: $tt) -> Self { + NativeAdapter { + native: Some(value), + } + } + } + }; +} + +def_from_for_primitive!(Int8Type, i8); +def_from_for_primitive!(Int16Type, i16); +def_from_for_primitive!(Int32Type, i32); +def_from_for_primitive!(Int64Type, i64); +def_from_for_primitive!(UInt8Type, u8); +def_from_for_primitive!(UInt16Type, u16); +def_from_for_primitive!(UInt32Type, u32); +def_from_for_primitive!(UInt64Type, u64); +def_from_for_primitive!(Float16Type, f16); +def_from_for_primitive!(Float32Type, f32); +def_from_for_primitive!(Float64Type, f64); + +impl From::Native>> + for NativeAdapter +{ + fn from(value: Option<::Native>) -> Self { + NativeAdapter { native: value } + } +} + +impl From<&Option<::Native>> + for NativeAdapter +{ + fn from(value: &Option<::Native>) -> Self { + NativeAdapter { native: *value } + } +} + +impl<'a, T: ArrowPrimitiveType, Ptr: Into>> FromIterator + for PrimitiveArray { fn from_iter>(iter: I) -> Self { let iter = iter.into_iter(); let (lower, _) = iter.size_hint(); - let mut null_buf = BooleanBufferBuilder::new(lower); + let mut null_builder = BooleanBufferBuilder::new(lower); let buffer: Buffer = iter .map(|item| { - if let Some(a) = item.borrow() { - null_buf.append(true); - *a + if let Some(a) = item.into().native { + null_builder.append(true); + a } else { - null_buf.append(false); + null_builder.append(false); // this ensures that null items on the buffer are not arbitrary. - // This is important because falible operations can use null values (e.g. a vectorized "add") + // This is important because fallible operations can use null values (e.g. a vectorized "add") // which may panic (e.g. overflow if the number on the slots happen to be very large). T::Native::default() } }) .collect(); + let len = null_builder.len(); + let null_buf: Buffer = null_builder.into(); + let valid_count = null_buf.count_set_bits(); + let null_count = len - valid_count; + let opt_null_buf = (null_count != 0).then(|| null_buf); + let data = unsafe { ArrayData::new_unchecked( T::DATA_TYPE, - null_buf.len(), - None, - Some(null_buf.into()), + len, + Some(null_count), + opt_null_buf, 0, vec![buffer], vec![], @@ -444,6 +510,7 @@ def_numeric_from_vec!(Time64MicrosecondType); def_numeric_from_vec!(Time64NanosecondType); def_numeric_from_vec!(IntervalYearMonthType); def_numeric_from_vec!(IntervalDayTimeType); +def_numeric_from_vec!(IntervalMonthDayNanoType); def_numeric_from_vec!(DurationSecondType); def_numeric_from_vec!(DurationMillisecondType); def_numeric_from_vec!(DurationMicrosecondType); @@ -489,7 +556,7 @@ impl PrimitiveArray { ArrayData::builder(DataType::Timestamp(T::get_time_unit(), timezone)) .len(data_len) .add_buffer(val_buf.into()) - .null_bit_buffer(null_buf.into()); + .null_bit_buffer(Some(null_buf.into())); let array_data = unsafe { array_data.build_unchecked() }; PrimitiveArray::from(array_data) } @@ -519,6 +586,7 @@ mod tests { use std::thread; use crate::buffer::Buffer; + use crate::compute::eq_dyn; use crate::datatypes::DataType; #[test] @@ -649,6 +717,23 @@ mod tests { assert!(arr.is_null(1)); assert_eq!(-5, arr.value(2)); assert_eq!(-5, arr.values()[2]); + + // a month_day_nano interval contains months, days and nanoseconds, + // but we do not yet have accessors for the values. + // TODO: implement month, day, and nanos access method for month_day_nano. 
+ let arr = IntervalMonthDayNanoArray::from(vec![ + Some(100000000000000000000), + None, + Some(-500000000000000000000), + ]); + assert_eq!(3, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(1, arr.null_count()); + assert_eq!(100000000000000000000, arr.value(0)); + assert_eq!(100000000000000000000, arr.values()[0]); + assert!(arr.is_null(1)); + assert_eq!(-500000000000000000000, arr.value(2)); + assert_eq!(-500000000000000000000, arr.values()[2]); } #[test] @@ -894,7 +979,7 @@ mod tests { #[test] fn test_primitive_array_builder() { // Test building a primitive array with ArrayData builder and offset - let buf = Buffer::from_slice_ref(&[0, 1, 2, 3, 4]); + let buf = Buffer::from_slice_ref(&[0i32, 1, 2, 3, 4, 5, 6]); let buf2 = buf.clone(); let data = ArrayData::builder(DataType::Int32) .len(5) @@ -946,12 +1031,33 @@ mod tests { assert_eq!(primitive_array.len(), 10); } + #[test] + fn test_primitive_array_from_non_null_iter() { + let iter = (0..10_i32).map(Some); + let primitive_array = PrimitiveArray::::from_iter(iter); + assert_eq!(primitive_array.len(), 10); + assert_eq!(primitive_array.null_count(), 0); + assert_eq!(primitive_array.data().null_buffer(), None); + assert_eq!(primitive_array.values(), &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + } + #[test] #[should_panic(expected = "PrimitiveArray data should contain a single buffer only \ (values buffer)")] + // Different error messages, so skip for now + // https://github.com/apache/arrow-rs/issues/1545 + #[cfg(not(feature = "force_validate"))] fn test_primitive_array_invalid_buffer_len() { - let data = ArrayData::builder(DataType::Int32).len(5).build().unwrap(); - Int32Array::from(data); + let buffer = Buffer::from_slice_ref(&[0i32, 1, 2, 3, 4]); + let data = unsafe { + ArrayData::builder(DataType::Int32) + .add_buffer(buffer.clone()) + .add_buffer(buffer) + .len(5) + .build_unchecked() + }; + + drop(Int32Array::from(data)); } #[test] @@ -962,4 +1068,16 @@ mod tests { assert!(ret.is_ok()); assert_eq!(8, ret.ok().unwrap()); } + + #[test] + fn test_primitive_array_creation() { + let array1: Int8Array = [10_i8, 11, 12, 13, 14].into_iter().collect(); + let array2: Int8Array = [10_i8, 11, 12, 13, 14].into_iter().map(Some).collect(); + + let result = eq_dyn(&array1, &array2); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![true, true, true, true, true]) + ); + } } diff --git a/arrow/src/array/array_string.rs b/arrow/src/array/array_string.rs index c07f34a6b726..9e09350f7e68 100644 --- a/arrow/src/array/array_string.rs +++ b/arrow/src/array/array_string.rs @@ -27,31 +27,28 @@ use crate::buffer::Buffer; use crate::util::bit_util; use crate::{buffer::MutableBuffer, datatypes::DataType}; -/// Like OffsetSizeTrait, but specialized for Strings -// This allow us to expose a constant datatype for the GenericStringArray -pub trait StringOffsetSizeTrait: OffsetSizeTrait { - const DATA_TYPE: DataType; -} - -impl StringOffsetSizeTrait for i32 { - const DATA_TYPE: DataType = DataType::Utf8; -} - -impl StringOffsetSizeTrait for i64 { - const DATA_TYPE: DataType = DataType::LargeUtf8; -} - /// Generic struct for \[Large\]StringArray /// /// See [`StringArray`] and [`LargeStringArray`] for storing /// specific string data. -pub struct GenericStringArray { +pub struct GenericStringArray { data: ArrayData, value_offsets: RawPtrBox, value_data: RawPtrBox, } -impl GenericStringArray { +impl GenericStringArray { + /// Get the data type of the array. 
+ // Declare this function as `pub const fn` after + // https://github.com/rust-lang/rust/issues/93706 is merged. + pub fn get_data_type() -> DataType { + if OffsetSize::IS_LARGE { + DataType::LargeUtf8 + } else { + DataType::Utf8 + } + } + /// Returns the length for the element at index `i`. #[inline] pub fn value_length(&self, i: usize) -> OffsetSize { @@ -78,6 +75,15 @@ impl GenericStringArray { self.data.buffers()[1].clone() } + /// Returns the number of `Unicode Scalar Value` in the string at index `i`. + /// # Performance + /// This function has `O(n)` time complexity where `n` is the string length. + /// If you can make sure that all chars in the string are in the range `U+0x0000` ~ `U+0x007F`, + /// please use the function [`value_length`](#method.value_length) which has O(1) time complexity. + pub fn num_chars(&self, i: usize) -> usize { + self.value(i).chars().count() + } + /// Returns the element at index /// # Safety /// caller is responsible for ensuring that index is within the array bounds @@ -125,50 +131,21 @@ impl GenericStringArray { "StringArray can only be created from List arrays, mismatched data types." ); - let mut builder = ArrayData::builder(OffsetSize::DATA_TYPE) + let builder = ArrayData::builder(Self::get_data_type()) .len(v.len()) .add_buffer(v.data().buffers()[0].clone()) - .add_buffer(v.data().child_data()[0].buffers()[0].clone()); - if let Some(bitmap) = v.data().null_bitmap() { - builder = builder.null_bit_buffer(bitmap.bits.clone()) - } + .add_buffer(v.data().child_data()[0].buffers()[0].clone()) + .null_bit_buffer(v.data().null_buffer().cloned()); let array_data = unsafe { builder.build_unchecked() }; Self::from(array_data) } - pub(crate) fn from_vec(v: Vec) -> Self - where - Ptr: AsRef, - { - let mut offsets = - MutableBuffer::new((v.len() + 1) * std::mem::size_of::()); - let mut values = MutableBuffer::new(0); - - let mut length_so_far = OffsetSize::zero(); - offsets.push(length_so_far); - - for s in &v { - length_so_far += OffsetSize::from_usize(s.as_ref().len()).unwrap(); - offsets.push(length_so_far); - values.extend_from_slice(s.as_ref().as_bytes()); - } - let array_data = ArrayData::builder(OffsetSize::DATA_TYPE) - .len(v.len()) - .add_buffer(offsets.into()) - .add_buffer(values.into()); - let array_data = unsafe { array_data.build_unchecked() }; - Self::from(array_data) - } - - pub(crate) fn from_opt_vec(v: Vec>) -> Self { - v.into_iter().collect() - } - - /// Creates a `GenericStringArray` based on an iterator of values without nulls - pub fn from_iter_values>(iter: I) -> Self + /// Creates a [`GenericStringArray`] based on an iterator of values without nulls + pub fn from_iter_values(iter: I) -> Self where Ptr: AsRef, + I: IntoIterator, { let iter = iter.into_iter(); let (_, data_len) = iter.size_hint(); @@ -187,16 +164,40 @@ impl GenericStringArray { offsets.push(length_so_far); values.extend_from_slice(s.as_bytes()); } - let array_data = ArrayData::builder(OffsetSize::DATA_TYPE) - .len(data_len) + + // iterator size hint may not be correct so compute the actual number of offsets + assert!(!offsets.is_empty()); // wrote at least one + let actual_len = (offsets.len() / std::mem::size_of::()) - 1; + + let array_data = ArrayData::builder(Self::get_data_type()) + .len(actual_len) .add_buffer(offsets.into()) .add_buffer(values.into()); let array_data = unsafe { array_data.build_unchecked() }; Self::from(array_data) } + + /// Returns an iterator that returns the values of `array.value(i)` for an iterator with each element `i` + pub fn take_iter<'a>( 
+ &'a self, + indexes: impl Iterator> + 'a, + ) -> impl Iterator> + 'a { + indexes.map(|opt_index| opt_index.map(|index| self.value(index))) + } + + /// Returns an iterator that returns the values of `array.value(i)` for an iterator with each element `i` + /// # Safety + /// + /// caller must ensure that the indexes in the iterator are less than the `array.len()` + pub unsafe fn take_iter_unchecked<'a>( + &'a self, + indexes: impl Iterator> + 'a, + ) -> impl Iterator> + 'a { + indexes.map(|opt_index| opt_index.map(|index| self.value_unchecked(index))) + } } -impl<'a, Ptr, OffsetSize: StringOffsetSizeTrait> FromIterator<&'a Option> +impl<'a, Ptr, OffsetSize: OffsetSizeTrait> FromIterator<&'a Option> for GenericStringArray where Ptr: AsRef + 'a, @@ -210,12 +211,12 @@ where } } -impl<'a, Ptr, OffsetSize: StringOffsetSizeTrait> FromIterator> +impl<'a, Ptr, OffsetSize: OffsetSizeTrait> FromIterator> for GenericStringArray where Ptr: AsRef, { - /// Creates a [`GenericStringArray`] based on an iterator of `Option`s + /// Creates a [`GenericStringArray`] based on an iterator of [`Option`]s fn from_iter>>(iter: I) -> Self { let iter = iter.into_iter(); let (_, data_len) = iter.size_hint(); @@ -245,17 +246,17 @@ where // calculate actual data_len, which may be different from the iterator's upper bound let data_len = (offsets.len() / offset_size) - 1; - let array_data = ArrayData::builder(OffsetSize::DATA_TYPE) + let array_data = ArrayData::builder(Self::get_data_type()) .len(data_len) .add_buffer(offsets.into()) .add_buffer(values.into()) - .null_bit_buffer(null_buf.into()); + .null_bit_buffer(Some(null_buf.into())); let array_data = unsafe { array_data.build_unchecked() }; Self::from(array_data) } } -impl<'a, T: StringOffsetSizeTrait> IntoIterator for &'a GenericStringArray { +impl<'a, T: OffsetSizeTrait> IntoIterator for &'a GenericStringArray { type Item = Option<&'a str>; type IntoIter = GenericStringIter<'a, T>; @@ -264,16 +265,16 @@ impl<'a, T: StringOffsetSizeTrait> IntoIterator for &'a GenericStringArray { } } -impl<'a, T: StringOffsetSizeTrait> GenericStringArray { +impl<'a, T: OffsetSizeTrait> GenericStringArray { /// constructs a new iterator pub fn iter(&'a self) -> GenericStringIter<'a, T> { GenericStringIter::<'a, T>::new(self) } } -impl fmt::Debug for GenericStringArray { +impl fmt::Debug for GenericStringArray { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let prefix = if OffsetSize::is_large() { "Large" } else { "" }; + let prefix = if OffsetSize::IS_LARGE { "Large" } else { "" }; write!(f, "{}StringArray\n[\n", prefix)?; print_long_array(self, f, |array, index, f| { @@ -283,7 +284,7 @@ impl fmt::Debug for GenericStringArray Array for GenericStringArray { +impl Array for GenericStringArray { fn as_any(&self) -> &dyn Any { self } @@ -293,13 +294,11 @@ impl Array for GenericStringArray } } -impl From - for GenericStringArray -{ +impl From for GenericStringArray { fn from(data: ArrayData) -> Self { assert_eq!( data.data_type(), - &::DATA_TYPE, + &Self::get_data_type(), "[Large]StringArray expects Datatype::[Large]Utf8" ); assert_eq!( @@ -317,27 +316,23 @@ impl From } } -impl From>> +impl From>> for GenericStringArray { fn from(v: Vec>) -> Self { - GenericStringArray::::from_opt_vec(v) + v.into_iter().collect() } } -impl From> - for GenericStringArray -{ +impl From> for GenericStringArray { fn from(v: Vec<&str>) -> Self { - GenericStringArray::::from_vec(v) + Self::from_iter_values(v) } } -impl From> - for GenericStringArray -{ +impl From> for GenericStringArray { fn from(v: 
Vec) -> Self { - GenericStringArray::::from_vec(v) + Self::from_iter_values(v) } } @@ -365,7 +360,7 @@ pub type StringArray = GenericStringArray; /// ``` pub type LargeStringArray = GenericStringArray; -impl From> for GenericStringArray { +impl From> for GenericStringArray { fn from(v: GenericListArray) -> Self { GenericStringArray::::from_list(v) } @@ -380,9 +375,9 @@ mod tests { #[test] fn test_string_array_from_u8_slice() { - let values: Vec<&str> = vec!["hello", "", "parquet"]; + let values: Vec<&str> = vec!["hello", "", "A£ऀ𖼚𝌆৩ƐZ"]; - // Array data: ["hello", "", "parquet"] + // Array data: ["hello", "", "A£ऀ𖼚𝌆৩ƐZ"] let string_array = StringArray::from(values); assert_eq!(3, string_array.len()); @@ -391,10 +386,12 @@ mod tests { assert_eq!("hello", unsafe { string_array.value_unchecked(0) }); assert_eq!("", string_array.value(1)); assert_eq!("", unsafe { string_array.value_unchecked(1) }); - assert_eq!("parquet", string_array.value(2)); - assert_eq!("parquet", unsafe { string_array.value_unchecked(2) }); - assert_eq!(5, string_array.value_offsets()[2]); - assert_eq!(7, string_array.value_length(2)); + assert_eq!("A£ऀ𖼚𝌆৩ƐZ", string_array.value(2)); + assert_eq!("A£ऀ𖼚𝌆৩ƐZ", unsafe { + string_array.value_unchecked(2) + }); + assert_eq!(20, string_array.value_length(2)); // 1 + 2 + 3 + 4 + 4 + 3 + 2 + 1 + assert_eq!(8, string_array.num_chars(2)); for i in 0..3 { assert!(string_array.is_valid(i)); assert!(!string_array.is_null(i)); @@ -405,14 +402,14 @@ mod tests { #[should_panic(expected = "[Large]StringArray expects Datatype::[Large]Utf8")] fn test_string_array_from_int() { let array = LargeStringArray::from(vec!["a", "b"]); - StringArray::from(array.data().clone()); + drop(StringArray::from(array.data().clone())); } #[test] fn test_large_string_array_from_u8_slice() { - let values: Vec<&str> = vec!["hello", "", "parquet"]; + let values: Vec<&str> = vec!["hello", "", "A£ऀ𖼚𝌆৩ƐZ"]; - // Array data: ["hello", "", "parquet"] + // Array data: ["hello", "", "A£ऀ𖼚𝌆৩ƐZ"] let string_array = LargeStringArray::from(values); assert_eq!(3, string_array.len()); @@ -421,10 +418,13 @@ mod tests { assert_eq!("hello", unsafe { string_array.value_unchecked(0) }); assert_eq!("", string_array.value(1)); assert_eq!("", unsafe { string_array.value_unchecked(1) }); - assert_eq!("parquet", string_array.value(2)); - assert_eq!("parquet", unsafe { string_array.value_unchecked(2) }); + assert_eq!("A£ऀ𖼚𝌆৩ƐZ", string_array.value(2)); + assert_eq!("A£ऀ𖼚𝌆৩ƐZ", unsafe { + string_array.value_unchecked(2) + }); assert_eq!(5, string_array.value_offsets()[2]); - assert_eq!(7, string_array.value_length(2)); + assert_eq!(20, string_array.value_length(2)); // 1 + 2 + 3 + 4 + 4 + 3 + 2 + 1 + assert_eq!(8, string_array.num_chars(2)); for i in 0..3 { assert!(string_array.is_valid(i)); assert!(!string_array.is_null(i)); @@ -554,13 +554,88 @@ mod tests { } #[test] - fn test_string_array_from_string_vec() { - let data = vec!["Foo".to_owned(), "Bar".to_owned(), "Baz".to_owned()]; + fn test_string_array_all_null() { + let data = vec![None]; let array = StringArray::from(data); + array + .data() + .validate_full() + .expect("All null array has valid array data"); + } + + #[test] + fn test_large_string_array_all_null() { + let data = vec![None]; + let array = LargeStringArray::from(data); + array + .data() + .validate_full() + .expect("All null array has valid array data"); + } + + #[cfg(feature = "test_utils")] + #[test] + fn bad_size_collect_string() { + use crate::util::test_util::BadIterator; + let data = vec![Some("foo"), None, 
Some("bar")]; + let expected: StringArray = data.clone().into_iter().collect(); + + // Iterator reports too many items + let arr: StringArray = BadIterator::new(3, 10, data.clone()).collect(); + assert_eq!(expected, arr); + + // Iterator reports too few items + let arr: StringArray = BadIterator::new(3, 1, data.clone()).collect(); + assert_eq!(expected, arr); + } + + #[cfg(feature = "test_utils")] + #[test] + fn bad_size_collect_large_string() { + use crate::util::test_util::BadIterator; + let data = vec![Some("foo"), None, Some("bar")]; + let expected: LargeStringArray = data.clone().into_iter().collect(); + + // Iterator reports too many items + let arr: LargeStringArray = BadIterator::new(3, 10, data.clone()).collect(); + assert_eq!(expected, arr); - assert_eq!(array.len(), 3); - assert_eq!(array.value(0), "Foo"); - assert_eq!(array.value(1), "Bar"); - assert_eq!(array.value(2), "Baz"); + // Iterator reports too few items + let arr: LargeStringArray = BadIterator::new(3, 1, data.clone()).collect(); + assert_eq!(expected, arr); + } + + #[cfg(feature = "test_utils")] + #[test] + fn bad_size_iter_values_string() { + use crate::util::test_util::BadIterator; + let data = vec!["foo", "bar", "baz"]; + let expected: StringArray = data.clone().into_iter().map(Some).collect(); + + // Iterator reports too many items + let arr = StringArray::from_iter_values(BadIterator::new(3, 10, data.clone())); + assert_eq!(expected, arr); + + // Iterator reports too few items + let arr = StringArray::from_iter_values(BadIterator::new(3, 1, data.clone())); + assert_eq!(expected, arr); + } + + #[cfg(feature = "test_utils")] + #[test] + fn bad_size_iter_values_large_string() { + use crate::util::test_util::BadIterator; + let data = vec!["foo", "bar", "baz"]; + let expected: LargeStringArray = data.clone().into_iter().map(Some).collect(); + + // Iterator reports too many items + let arr = + LargeStringArray::from_iter_values(BadIterator::new(3, 10, data.clone())); + assert_eq!(expected, arr); + + // Iterator reports too few items + let arr = + LargeStringArray::from_iter_values(BadIterator::new(3, 1, data.clone())); + assert_eq!(expected, arr); } } diff --git a/arrow/src/array/array_struct.rs b/arrow/src/array/array_struct.rs index a1cab7f50c70..91c77c72b17d 100644 --- a/arrow/src/array/array_struct.rs +++ b/arrow/src/array/array_struct.rs @@ -108,10 +108,12 @@ impl StructArray { impl From for StructArray { fn from(data: ArrayData) -> Self { - let mut boxed_fields = vec![]; - for cd in data.child_data() { - boxed_fields.push(make_array(cd.clone())); - } + let boxed_fields = data + .child_data() + .iter() + .map(|cd| make_array(cd.clone())) + .collect(); + Self { data, boxed_fields } } } @@ -174,12 +176,10 @@ impl TryFrom> for StructArray { } let len = len.unwrap(); - let mut builder = ArrayData::builder(DataType::Struct(fields)) + let builder = ArrayData::builder(DataType::Struct(fields)) .len(len) + .null_bit_buffer(null) .child_data(child_data); - if let Some(null_buffer) = null { - builder = builder.null_bit_buffer(null_buffer); - } let array_data = unsafe { builder.build_unchecked() }; @@ -268,7 +268,7 @@ impl From<(Vec<(Field, ArrayRef)>, Buffer)> for StructArray { } let array_data = ArrayData::builder(DataType::Struct(field_types)) - .null_bit_buffer(pair.1) + .null_bit_buffer(Some(pair.1)) .child_data(field_values.into_iter().map(|a| a.data().clone()).collect()) .len(length); let array_data = unsafe { array_data.build_unchecked() }; @@ -358,13 +358,13 @@ mod tests { assert_eq!(1, struct_data.null_count()); 
assert_eq!( // 00001011 - &Some(Bitmap::from(Buffer::from(&[11_u8]))), + Some(&Bitmap::from(Buffer::from(&[11_u8]))), struct_data.null_bitmap() ); let expected_string_data = ArrayData::builder(DataType::Utf8) .len(4) - .null_bit_buffer(Buffer::from(&[9_u8])) + .null_bit_buffer(Some(Buffer::from(&[9_u8]))) .add_buffer(Buffer::from(&[0, 3, 3, 3, 7].to_byte_slice())) .add_buffer(Buffer::from(b"joemark")) .build() @@ -372,7 +372,7 @@ mod tests { let expected_int_data = ArrayData::builder(DataType::Int32) .len(4) - .null_bit_buffer(Buffer::from(&[11_u8])) + .null_bit_buffer(Some(Buffer::from(&[11_u8]))) .add_buffer(Buffer::from(&[1, 2, 0, 4].to_byte_slice())) .build() .unwrap(); @@ -408,7 +408,7 @@ mod tests { expected = "the field data types must match the array data in a StructArray" )] fn test_struct_array_from_mismatched_types() { - StructArray::from(vec![ + drop(StructArray::from(vec![ ( Field::new("b", DataType::Int16, false), Arc::new(BooleanArray::from(vec![false, false, true, true])) @@ -418,7 +418,7 @@ mod tests { Field::new("c", DataType::Utf8, false), Arc::new(Int32Array::from(vec![42, 28, 19, 31])), ), - ]); + ])); } #[test] @@ -426,24 +426,25 @@ mod tests { let boolean_data = ArrayData::builder(DataType::Boolean) .len(5) .add_buffer(Buffer::from([0b00010000])) - .null_bit_buffer(Buffer::from([0b00010001])) + .null_bit_buffer(Some(Buffer::from([0b00010001]))) .build() .unwrap(); let int_data = ArrayData::builder(DataType::Int32) .len(5) .add_buffer(Buffer::from([0, 28, 42, 0, 0].to_byte_slice())) - .null_bit_buffer(Buffer::from([0b00000110])) + .null_bit_buffer(Some(Buffer::from([0b00000110]))) .build() .unwrap(); - let mut field_types = vec![]; - field_types.push(Field::new("a", DataType::Boolean, false)); - field_types.push(Field::new("b", DataType::Int32, false)); + let field_types = vec![ + Field::new("a", DataType::Boolean, false), + Field::new("b", DataType::Int32, false), + ]; let struct_array_data = ArrayData::builder(DataType::Struct(field_types)) .len(5) .add_child_data(boolean_data.clone()) .add_child_data(int_data.clone()) - .null_bit_buffer(Buffer::from([0b00010111])) + .null_bit_buffer(Some(Buffer::from([0b00010111]))) .build() .unwrap(); let struct_array = StructArray::from(struct_array_data); @@ -515,7 +516,7 @@ mod tests { expected = "all child arrays of a StructArray must have the same length" )] fn test_invalid_struct_child_array_lengths() { - StructArray::from(vec![ + drop(StructArray::from(vec![ ( Field::new("b", DataType::Float32, false), Arc::new(Float32Array::from(vec![1.1])) as Arc, @@ -524,6 +525,6 @@ mod tests { Field::new("c", DataType::Float64, false), Arc::new(Float64Array::from(vec![2.2, 3.3])), ), - ]); + ])); } } diff --git a/arrow/src/array/array_union.rs b/arrow/src/array/array_union.rs index ba563ec796b4..4ff0a31c6529 100644 --- a/arrow/src/array/array_union.rs +++ b/arrow/src/array/array_union.rs @@ -17,25 +17,95 @@ /// Contains the `UnionArray` type. /// -use crate::array::{data::count_nulls, make_array, Array, ArrayData, ArrayRef}; +use crate::array::{make_array, Array, ArrayData, ArrayRef}; use crate::buffer::Buffer; use crate::datatypes::*; use crate::error::{ArrowError, Result}; use core::fmt; use std::any::Any; -use std::mem::size_of; /// An Array that can represent slots of varying types. /// -/// Each slot in a `UnionArray` can have a value chosen from a number of types. Each of the -/// possible types are named like the fields of a [`StructArray`](crate::array::StructArray). 
-/// A `UnionArray` can have two possible memory layouts, "dense" or "sparse". For more information -/// on please see the [specification](https://arrow.apache.org/docs/format/Columnar.html#union-layout). +/// Each slot in a [UnionArray] can have a value chosen from a number +/// of types. Each of the possible types are named like the fields of +/// a [`StructArray`](crate::array::StructArray). A `UnionArray` can +/// have two possible memory layouts, "dense" or "sparse". For more +/// information please see the +/// [specification](https://arrow.apache.org/docs/format/Columnar.html#union-layout). /// -/// [`UnionBuilder`]can be used to create `UnionArray`'s of primitive types. `UnionArray`'s of nested -/// types are also supported but not via `UnionBuilder`, see the tests for examples. +/// [UnionBuilder](crate::array::UnionBuilder) can be used to +/// create [UnionArray]'s of primitive types. `UnionArray`'s of nested +/// types are also supported but not via `UnionBuilder`, see the tests +/// for examples. /// +/// # Examples +/// ## Create a dense UnionArray `[1, 3.2, 34]` +/// ``` +/// use arrow::buffer::Buffer; +/// use arrow::datatypes::*; +/// use std::sync::Arc; +/// use arrow::array::{Array, Int32Array, Float64Array, UnionArray}; +/// +/// let int_array = Int32Array::from(vec![1, 34]); +/// let float_array = Float64Array::from(vec![3.2]); +/// let type_id_buffer = Buffer::from_slice_ref(&[0_i8, 1, 0]); +/// let value_offsets_buffer = Buffer::from_slice_ref(&[0_i32, 0, 1]); +/// +/// let children: Vec<(Field, Arc<dyn Array>)> = vec![ +/// (Field::new("A", DataType::Int32, false), Arc::new(int_array)), +/// (Field::new("B", DataType::Float64, false), Arc::new(float_array)), +/// ]; +/// +/// let array = UnionArray::try_new( +/// &vec![0, 1], +/// type_id_buffer, +/// Some(value_offsets_buffer), +/// children, +/// ).unwrap(); +/// +/// let value = array.value(0).as_any().downcast_ref::<Int32Array>().unwrap().value(0); +/// assert_eq!(1, value); +/// +/// let value = array.value(1).as_any().downcast_ref::<Float64Array>().unwrap().value(0); +/// assert!(3.2 - value < f64::EPSILON); +/// +/// let value = array.value(2).as_any().downcast_ref::<Int32Array>().unwrap().value(0); +/// assert_eq!(34, value); +/// ``` +/// +/// ## Create a sparse UnionArray `[1, 3.2, 34]` +/// ``` +/// use arrow::buffer::Buffer; +/// use arrow::datatypes::*; +/// use std::sync::Arc; +/// use arrow::array::{Array, Int32Array, Float64Array, UnionArray}; +/// +/// let int_array = Int32Array::from(vec![Some(1), None, Some(34)]); +/// let float_array = Float64Array::from(vec![None, Some(3.2), None]); +/// let type_id_buffer = Buffer::from_slice_ref(&[0_i8, 1, 0]); +/// +/// let children: Vec<(Field, Arc<dyn Array>)> = vec![ +/// (Field::new("A", DataType::Int32, false), Arc::new(int_array)), +/// (Field::new("B", DataType::Float64, false), Arc::new(float_array)), +/// ]; +/// +/// let array = UnionArray::try_new( +/// &vec![0, 1], +/// type_id_buffer, +/// None, +/// children, +/// ).unwrap(); +/// +/// let value = array.value(0).as_any().downcast_ref::<Int32Array>().unwrap().value(0); +/// assert_eq!(1, value); +/// +/// let value = array.value(1).as_any().downcast_ref::<Float64Array>().unwrap().value(0); +/// assert!(3.2 - value < f64::EPSILON); +/// +/// let value = array.value(2).as_any().downcast_ref::<Int32Array>().unwrap().value(0); +/// assert_eq!(34, value); +/// ``` pub struct UnionArray { data: ArrayData, boxed_fields: Vec<ArrayRef>, @@ -49,7 +119,7 @@ impl UnionArray { /// caller and assumes that each of the components are correct and consistent with each other.
/// See `try_new` for an alternative that validates the data provided. /// - /// # Data Consistency + /// # Safety /// /// The `type_ids` `Buffer` should contain `i8` values. These values should be greater than /// zero and must be less than the number of children provided in `child_arrays`. These values @@ -57,8 +127,8 @@ impl UnionArray { /// /// The `value_offsets` `Buffer` is only provided in the case of a dense union, sparse unions /// should use `None`. If provided the `value_offsets` `Buffer` should contain `i32` values. - /// These values should be greater than zero and must be less than the length of the overall - /// array. + /// The values in this array should be greater than zero and must be less than the length of the + /// overall array. /// /// In both cases above we use signed integer types to maintain compatibility with other /// Arrow implementations. @@ -66,40 +136,47 @@ impl UnionArray { /// In both of the cases above we are accepting `Buffer`'s which are assumed to be representing /// `i8` and `i32` values respectively. `Buffer` objects are untyped and no attempt is made /// to ensure that the data provided is valid. - pub fn new( + pub unsafe fn new_unchecked( + field_type_ids: &[i8], type_ids: Buffer, value_offsets: Option, child_arrays: Vec<(Field, ArrayRef)>, - bitmap_data: Option, ) -> Self { let (field_types, field_values): (Vec<_>, Vec<_>) = child_arrays.into_iter().unzip(); let len = type_ids.len(); - let mut builder = ArrayData::builder(DataType::Union(field_types)) - .add_buffer(type_ids) - .child_data(field_values.into_iter().map(|a| a.data().clone()).collect()) - .len(len); - if let Some(bitmap) = bitmap_data { - builder = builder.null_bit_buffer(bitmap) - } - let data = unsafe { - match value_offsets { - Some(b) => builder.add_buffer(b).build_unchecked(), - None => builder.build_unchecked(), - } + + let mode = if value_offsets.is_some() { + UnionMode::Dense + } else { + UnionMode::Sparse + }; + + let builder = ArrayData::builder(DataType::Union( + field_types, + Vec::from(field_type_ids), + mode, + )) + .add_buffer(type_ids) + .child_data(field_values.into_iter().map(|a| a.data().clone()).collect()) + .len(len); + + let data = match value_offsets { + Some(b) => builder.add_buffer(b).build_unchecked(), + None => builder.build_unchecked(), }; Self::from(data) } - /// Attempts to create a new `UnionArray` and validates the inputs provided. + + /// Attempts to create a new `UnionArray`, validating the inputs provided. pub fn try_new( + field_type_ids: &[i8], type_ids: Buffer, value_offsets: Option, child_arrays: Vec<(Field, ArrayRef)>, - bitmap: Option, ) -> Result { if let Some(b) = &value_offsets { - let nulls = count_nulls(bitmap.as_ref(), 0, type_ids.len()); - if ((type_ids.len() - nulls) * 4) != b.len() { + if ((type_ids.len()) * 4) != b.len() { return Err(ArrowError::InvalidArgumentError( "Type Ids and Offsets represent a different number of array slots." 
.to_string(), @@ -108,7 +185,7 @@ impl UnionArray { } // Check the type_ids - let type_id_slice: &[i8] = unsafe { type_ids.typed_data() }; + let type_id_slice: &[i8] = type_ids.typed_data(); let invalid_type_ids = type_id_slice .iter() .filter(|i| *i < &0) @@ -124,7 +201,7 @@ impl UnionArray { // Check the value offsets if provided if let Some(offset_buffer) = &value_offsets { let max_len = type_ids.len() as i32; - let offsets_slice: &[i32] = unsafe { offset_buffer.typed_data() }; + let offsets_slice: &[i32] = offset_buffer.typed_data(); let invalid_offsets = offsets_slice .iter() .filter(|i| *i < &0 || *i > &max_len) @@ -138,7 +215,14 @@ impl UnionArray { } } - Ok(Self::new(type_ids, value_offsets, child_arrays, bitmap)) + // Unsafe Justification: arguments were validated above (and + // re-revalidated as part of data().validate() below) + let new_self = unsafe { + Self::new_unchecked(field_type_ids, type_ids, value_offsets, child_arrays) + }; + new_self.data().validate()?; + + Ok(new_self) } /// Accesses the child array for `type_id`. @@ -171,13 +255,7 @@ impl UnionArray { pub fn value_offset(&self, index: usize) -> i32 { assert!(index - self.offset() < self.len()); if self.is_dense() { - // In format v4 unions had their own validity bitmap and offsets are compressed by omitting null values - // Starting with v5 unions don't have a validity bitmap and it's possible to directly index into the offsets buffer - let valid_slots = match self.data.null_buffer() { - Some(b) => b.count_set_bits_offset(0, index), - None => index, - }; - self.data().buffers()[1].as_slice()[valid_slots * size_of::()] as i32 + self.data().buffers()[1].typed_data::()[index] } else { index as i32 } @@ -198,7 +276,7 @@ impl UnionArray { /// Returns the names of the types in the union. pub fn type_names(&self) -> Vec<&str> { match self.data.data_type() { - DataType::Union(fields) => fields + DataType::Union(fields, _, _) => fields .iter() .map(|f| f.name().as_str()) .collect::>(), @@ -208,7 +286,10 @@ impl UnionArray { /// Returns whether the `UnionArray` is dense (or sparse if `false`). fn is_dense(&self) -> bool { - self.data().buffers().len() == 2 + match self.data.data_type() { + DataType::Union(_, _, mode) => mode == &UnionMode::Dense, + _ => unreachable!("Union array's data type is not a union!"), + } } } @@ -230,6 +311,24 @@ impl Array for UnionArray { fn data(&self) -> &ArrayData { &self.data } + + /// Union types always return non null as there is no validity buffer. + /// To check validity correctly you must check the underlying vector. + fn is_null(&self, _index: usize) -> bool { + false + } + + /// Union types always return non null as there is no validity buffer. + /// To check validity correctly you must check the underlying vector. + fn is_valid(&self, _index: usize) -> bool { + true + } + + /// Union types always return 0 null count as there is no validity buffer. + /// To get null count correctly you must check the underlying vector. 
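To make the new validity semantics concrete, a short sketch (not part of the patch) of how callers are expected to probe null-ness now that the union itself reports every slot as valid; it assumes a union built with the `UnionBuilder` API shown later in this diff:

```rust
use arrow::array::{Array, UnionBuilder};
use arrow::datatypes::Int32Type;

let mut builder = UnionBuilder::new_sparse(2);
builder.append::<Int32Type>("a", 1).unwrap();
builder.append_null::<Int32Type>("a").unwrap();
let union = builder.build().unwrap();

assert_eq!(union.null_count(), 0); // the union itself never reports nulls
assert!(union.is_valid(1));        // ... even for the slot appended as null
let slot = union.value(1);         // length-1 array for logical slot 1
assert!(slot.is_null(0));          // the null is recorded in the child array
```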
+ fn null_count(&self) -> usize { + 0 + } } impl fmt::Debug for UnionArray { @@ -334,6 +433,49 @@ mod tests { } } + #[test] + #[cfg_attr(miri, ignore)] + fn test_dense_i32_large() { + let mut builder = UnionBuilder::new_dense(1024); + + let expected_type_ids = vec![0_i8; 1024]; + let expected_value_offsets: Vec<_> = (0..1024).collect(); + let expected_array_values: Vec<_> = (1..=1024).collect(); + + expected_array_values + .iter() + .for_each(|v| builder.append::("a", *v).unwrap()); + + let union = builder.build().unwrap(); + + // Check type ids + assert_eq!( + union.data().buffers()[0], + Buffer::from_slice_ref(&expected_type_ids) + ); + for (i, id) in expected_type_ids.iter().enumerate() { + assert_eq!(id, &union.type_id(i)); + } + + // Check offsets + assert_eq!( + union.data().buffers()[1], + Buffer::from_slice_ref(&expected_value_offsets) + ); + for (i, id) in expected_value_offsets.iter().enumerate() { + assert_eq!(&union.value_offset(i), id); + } + + for (i, expected_value) in expected_array_values.iter().enumerate() { + assert!(!union.is_null(i)); + let slot = union.value(i); + let slot = slot.as_any().downcast_ref::().unwrap(); + assert_eq!(slot.len(), 1); + let value = slot.value(0); + assert_eq!(expected_value, &value); + } + } + #[test] fn test_dense_mixed() { let mut builder = UnionBuilder::new_dense(7); @@ -390,7 +532,7 @@ mod tests { builder.append::("a", 1).unwrap(); builder.append::("c", 3).unwrap(); builder.append::("a", 10).unwrap(); - builder.append_null().unwrap(); + builder.append_null::("a").unwrap(); builder.append::("a", 6).unwrap(); let union = builder.build().unwrap(); @@ -400,29 +542,29 @@ mod tests { match i { 0 => { let slot = slot.as_any().downcast_ref::().unwrap(); - assert!(!union.is_null(i)); + assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); assert_eq!(1_i32, value); } 1 => { let slot = slot.as_any().downcast_ref::().unwrap(); - assert!(!union.is_null(i)); + assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); assert_eq!(3_i64, value); } 2 => { let slot = slot.as_any().downcast_ref::().unwrap(); - assert!(!union.is_null(i)); + assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); assert_eq!(10_i32, value); } - 3 => assert!(union.is_null(i)), + 3 => assert!(slot.is_null(0)), 4 => { let slot = slot.as_any().downcast_ref::().unwrap(); - assert!(!union.is_null(i)); + assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); assert_eq!(6_i32, value); @@ -438,7 +580,7 @@ mod tests { builder.append::("a", 1).unwrap(); builder.append::("c", 3).unwrap(); builder.append::("a", 10).unwrap(); - builder.append_null().unwrap(); + builder.append_null::("a").unwrap(); builder.append::("a", 6).unwrap(); let union = builder.build().unwrap(); @@ -451,15 +593,15 @@ mod tests { match i { 0 => { let slot = slot.as_any().downcast_ref::().unwrap(); - assert!(!union.is_null(i)); + assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); assert_eq!(10_i32, value); } - 1 => assert!(new_union.is_null(i)), + 1 => assert!(slot.is_null(0)), 2 => { let slot = slot.as_any().downcast_ref::().unwrap(); - assert!(!union.is_null(i)); + assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); assert_eq!(6_i32, value); @@ -481,21 +623,22 @@ mod tests { let type_id_buffer = Buffer::from_slice_ref(&type_ids); let value_offsets_buffer = Buffer::from_slice_ref(&value_offsets); - let mut children: Vec<(Field, Arc)> = Vec::new(); - 
children.push(( - Field::new("A", DataType::Utf8, false), - Arc::new(string_array), - )); - children.push((Field::new("B", DataType::Int32, false), Arc::new(int_array))); - children.push(( - Field::new("C", DataType::Float64, false), - Arc::new(float_array), - )); + let children: Vec<(Field, Arc)> = vec![ + ( + Field::new("A", DataType::Utf8, false), + Arc::new(string_array), + ), + (Field::new("B", DataType::Int32, false), Arc::new(int_array)), + ( + Field::new("C", DataType::Float64, false), + Arc::new(float_array), + ), + ]; let array = UnionArray::try_new( + &[0, 1, 2], type_id_buffer, Some(value_offsets_buffer), children, - None, ) .unwrap(); @@ -543,7 +686,7 @@ mod tests { .downcast_ref::() .unwrap() .value(0); - assert!(10.0 - value < f64::EPSILON); + assert_eq!(10.0, value); let slot = array.value(4); let value = slot @@ -648,7 +791,7 @@ mod tests { let slot = slot.as_any().downcast_ref::().unwrap(); assert_eq!(slot.len(), 1); let value = slot.value(0); - assert!(value - 3_f64 < f64::EPSILON); + assert_eq!(value, 3_f64); } 2 => { let slot = slot.as_any().downcast_ref::().unwrap(); @@ -660,7 +803,7 @@ mod tests { let slot = slot.as_any().downcast_ref::().unwrap(); assert_eq!(slot.len(), 1); let value = slot.value(0); - assert!(5_f64 - value < f64::EPSILON); + assert_eq!(5_f64, value); } 4 => { let slot = slot.as_any().downcast_ref::().unwrap(); @@ -677,7 +820,7 @@ mod tests { fn test_sparse_mixed_with_nulls() { let mut builder = UnionBuilder::new_sparse(5); builder.append::("a", 1).unwrap(); - builder.append_null().unwrap(); + builder.append_null::("a").unwrap(); builder.append::("c", 3.0).unwrap(); builder.append::("a", 4).unwrap(); let union = builder.build().unwrap(); @@ -701,22 +844,22 @@ mod tests { match i { 0 => { let slot = slot.as_any().downcast_ref::().unwrap(); - assert!(!union.is_null(i)); + assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); assert_eq!(1_i32, value); } - 1 => assert!(union.is_null(i)), + 1 => assert!(slot.is_null(0)), 2 => { let slot = slot.as_any().downcast_ref::().unwrap(); - assert!(!union.is_null(i)); + assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); - assert!(value - 3_f64 < f64::EPSILON); + assert_eq!(value, 3_f64); } 3 => { let slot = slot.as_any().downcast_ref::().unwrap(); - assert!(!union.is_null(i)); + assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); assert_eq!(4_i32, value); @@ -730,9 +873,9 @@ mod tests { fn test_sparse_mixed_with_nulls_and_offset() { let mut builder = UnionBuilder::new_sparse(5); builder.append::("a", 1).unwrap(); - builder.append_null().unwrap(); + builder.append_null::("a").unwrap(); builder.append::("c", 3.0).unwrap(); - builder.append_null().unwrap(); + builder.append_null::("c").unwrap(); builder.append::("a", 4).unwrap(); let union = builder.build().unwrap(); @@ -743,18 +886,18 @@ mod tests { for i in 0..new_union.len() { let slot = new_union.value(i); match i { - 0 => assert!(new_union.is_null(i)), + 0 => assert!(slot.is_null(0)), 1 => { let slot = slot.as_any().downcast_ref::().unwrap(); - assert!(!new_union.is_null(i)); + assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); - assert!(value - 3_f64 < f64::EPSILON); + assert_eq!(value, 3_f64); } - 2 => assert!(new_union.is_null(i)), + 2 => assert!(slot.is_null(0)), 3 => { let slot = slot.as_any().downcast_ref::().unwrap(); - assert!(!new_union.is_null(i)); + assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); 
assert_eq!(4_i32, value); @@ -763,4 +906,44 @@ mod tests { } } } + + fn test_union_validity(union_array: &UnionArray) { + assert_eq!(union_array.null_count(), 0); + + for i in 0..union_array.len() { + assert!(!union_array.is_null(i)); + assert!(union_array.is_valid(i)); + } + } + + #[test] + fn test_union_array_validaty() { + let mut builder = UnionBuilder::new_sparse(5); + builder.append::("a", 1).unwrap(); + builder.append_null::("a").unwrap(); + builder.append::("c", 3.0).unwrap(); + builder.append_null::("c").unwrap(); + builder.append::("a", 4).unwrap(); + let union = builder.build().unwrap(); + + test_union_validity(&union); + + let mut builder = UnionBuilder::new_dense(5); + builder.append::("a", 1).unwrap(); + builder.append_null::("a").unwrap(); + builder.append::("c", 3.0).unwrap(); + builder.append_null::("c").unwrap(); + builder.append::("a", 4).unwrap(); + let union = builder.build().unwrap(); + + test_union_validity(&union); + } + + #[test] + fn test_type_check() { + let mut builder = UnionBuilder::new_sparse(2); + builder.append::("a", 1.0).unwrap(); + let err = builder.append::("a", 1).unwrap_err().to_string(); + assert!(err.contains("Attempt to write col \"a\" with type Int32 doesn't match existing type Float32"), "{}", err); + } } diff --git a/arrow/src/array/builder.rs b/arrow/src/array/builder.rs index 60f76d95485f..3802b1d2d406 100644 --- a/arrow/src/array/builder.rs +++ b/arrow/src/array/builder.rs @@ -25,6 +25,7 @@ use std::collections::HashMap; use std::fmt; use std::marker::PhantomData; use std::mem; +use std::ops::Range; use std::sync::Arc; use crate::array::*; @@ -33,29 +34,6 @@ use crate::datatypes::*; use crate::error::{ArrowError, Result}; use crate::util::bit_util; -/// Converts a `MutableBuffer` to a `BufferBuilder`. -/// -/// `slots` is the number of array slots currently represented in the `MutableBuffer`. -pub(crate) fn mutable_buffer_to_builder( - mutable_buffer: MutableBuffer, - slots: usize, -) -> BufferBuilder { - BufferBuilder:: { - buffer: mutable_buffer, - len: slots, - _marker: PhantomData, - } -} - -/// Converts a `BufferBuilder` into its underlying `MutableBuffer`. -/// -/// `From` is not implemented because associated type bounds are unstable. -pub(crate) fn builder_to_mutable_buffer( - builder: BufferBuilder, -) -> MutableBuffer { - builder.buffer -} - /// Builder for creating a [`Buffer`](crate::buffer::Buffer) object. 
/// /// A [`Buffer`](crate::buffer::Buffer) is the underlying data @@ -75,7 +53,7 @@ pub(crate) fn builder_to_mutable_buffer( /// builder.append(45); /// let buffer = builder.finish(); /// -/// assert_eq!(unsafe { buffer.typed_data::() }, &[42, 43, 44, 45]); +/// assert_eq!(buffer.typed_data::(), &[42, 43, 44, 45]); /// # Ok(()) /// # } /// ``` @@ -179,8 +157,7 @@ impl BufferBuilder { /// ``` #[inline] pub fn advance(&mut self, i: usize) { - let new_buffer_len = (self.len + i) * mem::size_of::<T>(); - self.buffer.resize(new_buffer_len, 0); + self.buffer.extend_zeros(i * mem::size_of::<T>()); self.len += i; } @@ -243,6 +220,24 @@ impl BufferBuilder { self.len += n; } + /// Appends `n` zero-initialized values + /// + /// # Example: + /// + /// ``` + /// use arrow::array::UInt32BufferBuilder; + /// + /// let mut builder = UInt32BufferBuilder::new(10); + /// builder.append_n_zeroed(3); + /// + /// assert_eq!(builder.len(), 3); + /// assert_eq!(builder.as_slice(), &[0, 0, 0]); + /// ``` + #[inline] + pub fn append_n_zeroed(&mut self, n: usize) { + self.buffer.extend_zeros(n * mem::size_of::<T>()); + self.len += n; + } + /// Appends a slice of type `T`, growing the internal buffer as needed. /// /// # Example: /// @@ -261,6 +256,78 @@ impl BufferBuilder { self.len += slice.len(); } + /// View the contents of this buffer as a slice + /// + /// ``` + /// use arrow::array::Float64BufferBuilder; + /// + /// let mut builder = Float64BufferBuilder::new(10); + /// builder.append(1.3); + /// builder.append_n(2, 2.3); + /// + /// assert_eq!(builder.as_slice(), &[1.3, 2.3, 2.3]); + /// ``` + #[inline] + pub fn as_slice(&self) -> &[T] { + // SAFETY + // + // - MutableBuffer is aligned and initialized for len elements of T + // - MutableBuffer corresponds to a single allocation + // - MutableBuffer does not support modification whilst active immutable borrows + unsafe { std::slice::from_raw_parts(self.buffer.as_ptr() as _, self.len) } + } + + /// View the contents of this buffer as a mutable slice + /// + /// # Example: + /// + /// ``` + /// use arrow::array::Float32BufferBuilder; + /// + /// let mut builder = Float32BufferBuilder::new(10); + /// + /// builder.append_slice(&[1., 2., 3.4]); + /// assert_eq!(builder.as_slice(), &[1., 2., 3.4]); + /// + /// builder.as_slice_mut()[1] = 4.2; + /// assert_eq!(builder.as_slice(), &[1., 4.2, 3.4]); + /// ``` + #[inline] + pub fn as_slice_mut(&mut self) -> &mut [T] { + // SAFETY + // + // - MutableBuffer is aligned and initialized for len elements of T + // - MutableBuffer corresponds to a single allocation + // - MutableBuffer does not support modification whilst active immutable borrows + unsafe { std::slice::from_raw_parts_mut(self.buffer.as_mut_ptr() as _, self.len) } + } + + /// Shorten this BufferBuilder to `len` items + /// + /// If `len` is greater than the builder's current length, this has no effect + /// + /// # Example: + /// + /// ``` + /// use arrow::array::UInt16BufferBuilder; + /// + /// let mut builder = UInt16BufferBuilder::new(10); + /// + /// builder.append_slice(&[42, 44, 46]); + /// assert_eq!(builder.as_slice(), &[42, 44, 46]); + /// + /// builder.truncate(2); + /// assert_eq!(builder.as_slice(), &[42, 44]); + /// + /// builder.append(12); + /// assert_eq!(builder.as_slice(), &[42, 44, 12]); + /// ``` + #[inline] + pub fn truncate(&mut self, len: usize) { + self.buffer.truncate(len * mem::size_of::<T>()); + self.len = len; + } + /// # Safety /// This requires the iterator be a trusted length.
This could instead require /// the iterator implement `TrustedLen` once that is stabilized. @@ -290,7 +357,7 @@ impl BufferBuilder { /// /// let buffer = builder.finish(); /// - /// assert_eq!(unsafe { buffer.typed_data::() }, &[42, 44, 46]); + /// assert_eq!(buffer.typed_data::(), &[42, 44, 46]); /// ``` #[inline] pub fn finish(&mut self) -> Buffer { @@ -310,7 +377,7 @@ impl BooleanBufferBuilder { #[inline] pub fn new(capacity: usize) -> Self { let byte_capacity = bit_util::ceil(capacity, 8); - let buffer = MutableBuffer::from_len_zeroed(byte_capacity); + let buffer = MutableBuffer::new(byte_capacity); Self { buffer, len: 0 } } @@ -366,6 +433,15 @@ impl BooleanBufferBuilder { } } + /// Resizes the buffer, either truncating its contents (with no change in capacity), or + /// growing it (potentially reallocating it) and writing `false` in the newly available bits. + #[inline] + pub fn resize(&mut self, len: usize) { + let len_bytes = bit_util::ceil(len, 8); + self.buffer.resize(len_bytes, 0); + self.len = len; + } + #[inline] pub fn append(&mut self, v: bool) { self.advance(1); @@ -398,6 +474,31 @@ impl BooleanBufferBuilder { } } + /// Append `range` bits from `to_set` + /// + /// `to_set` is a slice of bits packed LSB-first into `[u8]` + /// + /// # Panics + /// + /// Panics if `to_set` does not contain `ceil(range.end / 8)` bytes + pub fn append_packed_range(&mut self, range: Range, to_set: &[u8]) { + let offset_write = self.len; + let len = range.end - range.start; + self.advance(len); + crate::util::bit_mask::set_bits( + self.buffer.as_slice_mut(), + to_set, + offset_write, + range.start, + len, + ); + } + + /// Returns the packed bits + pub fn as_slice(&self) -> &[u8] { + self.buffer.as_slice() + } + #[inline] pub fn finish(&mut self) -> Buffer { let buf = std::mem::replace(&mut self.buffer, MutableBuffer::new(0)); @@ -595,12 +696,11 @@ impl BooleanBuilder { let len = self.len(); let null_bit_buffer = self.bitmap_builder.finish(); let null_count = len - null_bit_buffer.count_set_bits(); - let mut builder = ArrayData::builder(DataType::Boolean) + let builder = ArrayData::builder(DataType::Boolean) .len(len) - .add_buffer(self.values_builder.finish()); - if null_count > 0 { - builder = builder.null_bit_buffer(null_bit_buffer); - } + .add_buffer(self.values_builder.finish()) + .null_bit_buffer((null_count > 0).then(|| null_bit_buffer)); + let array_data = unsafe { builder.build_unchecked() }; BooleanArray::from(array_data) } @@ -794,12 +894,15 @@ impl PrimitiveBuilder { .as_ref() .map(|b| b.count_set_bits()) .unwrap_or(len); - let mut builder = ArrayData::builder(T::DATA_TYPE) + let builder = ArrayData::builder(T::DATA_TYPE) .len(len) - .add_buffer(self.values_builder.finish()); - if null_count > 0 { - builder = builder.null_bit_buffer(null_bit_buffer.unwrap()); - } + .add_buffer(self.values_builder.finish()) + .null_bit_buffer(if null_count > 0 { + null_bit_buffer + } else { + None + }); + let array_data = unsafe { builder.build_unchecked() }; PrimitiveArray::::from(array_data) } @@ -821,7 +924,7 @@ impl PrimitiveBuilder { .len(len) .add_buffer(self.values_builder.finish()); if null_count > 0 { - builder = builder.null_bit_buffer(null_bit_buffer.unwrap()); + builder = builder.null_bit_buffer(null_bit_buffer); } builder = builder.add_child_data(values.data().clone()); let array_data = unsafe { builder.build_unchecked() }; @@ -948,7 +1051,7 @@ where values_data.data_type().clone(), true, // TODO: find a consistent way of getting this )); - let data_type = if OffsetSize::is_large() { + let 
data_type = if OffsetSize::IS_LARGE { DataType::LargeList(field) } else { DataType::List(field) @@ -957,7 +1060,7 @@ where .len(len) .add_buffer(offset_buffer) .add_child_data(values_data.clone()) - .null_bit_buffer(null_bit_buffer); + .null_bit_buffer(Some(null_bit_buffer)); let array_data = unsafe { array_data.build_unchecked() }; @@ -1088,7 +1191,7 @@ where )) .len(len) .add_child_data(values_data.clone()) - .null_bit_buffer(null_bit_buffer); + .null_bit_buffer(Some(null_bit_buffer)); let array_data = unsafe { array_data.build_unchecked() }; @@ -1128,11 +1231,13 @@ pub struct DecimalBuilder { builder: FixedSizeListBuilder, precision: usize, scale: usize, + + /// Should i128 values be validated for compatibility with scale and precision? + /// defaults to true + value_validation: bool, } -impl ArrayBuilder - for GenericBinaryBuilder -{ +impl ArrayBuilder for GenericBinaryBuilder { /// Returns the builder as a non-mutable `Any` reference. fn as_any(&self) -> &dyn Any { self @@ -1164,9 +1269,7 @@ impl ArrayBuilder } } -impl ArrayBuilder - for GenericStringBuilder -{ +impl ArrayBuilder for GenericStringBuilder { /// Returns the builder as a non-mutable `Any` reference. fn as_any(&self) -> &dyn Any { self @@ -1263,7 +1366,7 @@ impl ArrayBuilder for DecimalBuilder { } } -impl GenericBinaryBuilder { +impl GenericBinaryBuilder { /// Creates a new `GenericBinaryBuilder`, `capacity` is the number of bytes in the values /// array pub fn new(capacity: usize) -> Self { @@ -1312,7 +1415,7 @@ impl GenericBinaryBuilder { } } -impl GenericStringBuilder { +impl GenericStringBuilder { /// Creates a new `StringBuilder`, /// `capacity` is the number of bytes of string data to pre-allocate space for in this builder pub fn new(capacity: usize) -> Self { @@ -1422,15 +1525,32 @@ impl DecimalBuilder { builder: FixedSizeListBuilder::new(values_builder, byte_width), precision, scale, + value_validation: true, } } + /// Disable validation + /// + /// # Safety + /// + /// After disabling validation, caller must ensure that appended values are compatible + /// for the specified precision and scale. + pub unsafe fn disable_value_validation(&mut self) { + self.value_validation = false; + } + /// Appends a byte slice into the builder. /// /// Automatically calls the `append` method to delimit the slice appended in as a /// distinct array element. #[inline] - pub fn append_value(&mut self, value: i128) -> Result<()> { + pub fn append_value(&mut self, value: impl Into) -> Result<()> { + let value = if self.value_validation { + validate_decimal_precision(value.into(), self.precision)? 
+ } else { + value.into() + }; + let value_as_bytes = Self::from_i128_to_fixed_size_bytes( value, self.builder.value_length() as usize, @@ -1446,7 +1566,7 @@ impl DecimalBuilder { self.builder.append(true) } - fn from_i128_to_fixed_size_bytes(v: i128, size: usize) -> Result> { + pub(crate) fn from_i128_to_fixed_size_bytes(v: i128, size: usize) -> Result> { if size > 16 { return Err(ArrowError::InvalidArgumentError( "DecimalBuilder only supports values up to 16 bytes.".to_string(), @@ -1597,6 +1717,9 @@ pub fn make_builder(datatype: &DataType, capacity: usize) -> Box { Box::new(IntervalDayTimeBuilder::new(capacity)) } + DataType::Interval(IntervalUnit::MonthDayNano) => { + Box::new(IntervalMonthDayNanoBuilder::new(capacity)) + } DataType::Duration(TimeUnit::Second) => { Box::new(DurationSecondBuilder::new(capacity)) } @@ -1675,7 +1798,7 @@ impl StructBuilder { .len(self.len) .child_data(child_data); if null_count > 0 { - builder = builder.null_bit_buffer(null_bit_buffer); + builder = builder.null_bit_buffer(Some(null_bit_buffer)); } self.len = 0; @@ -1712,6 +1835,7 @@ impl Default for MapFieldNames { } } +#[allow(dead_code)] impl MapBuilder { pub fn new( field_names: Option, @@ -1809,7 +1933,7 @@ impl MapBuilder { .len(len) .add_buffer(offset_buffer) .add_child_data(struct_array.data().clone()) - .null_bit_buffer(null_bit_buffer); + .null_bit_buffer(Some(null_bit_buffer)); let array_data = unsafe { array_data.build_unchecked() }; @@ -1851,106 +1975,65 @@ struct FieldData { /// The Arrow data type represented in the `values_buffer`, which is untyped data_type: DataType, /// A buffer containing the values for this field in raw bytes - values_buffer: Option, + values_buffer: Box, /// The number of array slots represented by the buffer slots: usize, - /// A builder for the bitmap if required (for Sparse Unions) - bitmap_builder: Option, + /// A builder for the null bitmap + bitmap_builder: BooleanBufferBuilder, +} + +/// A type-erased [`BufferBuilder`] used by [`FieldData`] +trait FieldDataValues: std::fmt::Debug { + fn as_mut_any(&mut self) -> &mut dyn Any; + + fn append_null(&mut self); + + fn finish(&mut self) -> Buffer; +} + +impl FieldDataValues for BufferBuilder { + fn as_mut_any(&mut self) -> &mut dyn Any { + self + } + + fn append_null(&mut self) { + self.advance(1) + } + + fn finish(&mut self) -> Buffer { + self.finish() + } } impl FieldData { /// Creates a new `FieldData`. - fn new( - type_id: i8, - data_type: DataType, - bitmap_builder: Option, - ) -> Self { + fn new(type_id: i8, data_type: DataType) -> Self { Self { type_id, data_type, - values_buffer: Some(MutableBuffer::new(1)), slots: 0, - bitmap_builder, + values_buffer: Box::new(BufferBuilder::::new(1)), + bitmap_builder: BooleanBufferBuilder::new(1), } } /// Appends a single value to this `FieldData`'s `values_buffer`. 
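Returning to the decimal changes above: `DecimalBuilder::append_value` now validates each value against the builder's precision (via `validate_decimal_precision`) unless validation is explicitly disabled with the new `unsafe` `disable_value_validation`. A rough sketch of the expected behaviour, as an illustration only (the exact error message is not shown in this patch):

```rust
use arrow::array::DecimalBuilder;

// DecimalBuilder::new(capacity, precision, scale); precision 5 admits at most
// five significant decimal digits, so the second append should be rejected.
let mut builder = DecimalBuilder::new(10, 5, 2);
builder.append_value(12345_i128).unwrap();              // fits in 5 digits
assert!(builder.append_value(1234567_i128).is_err());   // rejected by validation
```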
- #[allow(clippy::unnecessary_wraps)] - fn append_to_values_buffer( - &mut self, - v: T::Native, - ) -> Result<()> { - let values_buffer = self - .values_buffer - .take() - .expect("Values buffer was never created"); - let mut builder: BufferBuilder = - mutable_buffer_to_builder(values_buffer, self.slots); - builder.append(v); - let mutable_buffer = builder_to_mutable_buffer(builder); - self.values_buffer = Some(mutable_buffer); + fn append_value(&mut self, v: T::Native) { + self.values_buffer + .as_mut_any() + .downcast_mut::>() + .expect("Tried to append unexpected type") + .append(v); + self.bitmap_builder.append(true); self.slots += 1; - if let Some(b) = &mut self.bitmap_builder { - b.append(true) - }; - Ok(()) } /// Appends a null to this `FieldData`. - #[allow(clippy::unnecessary_wraps)] - fn append_null(&mut self) -> Result<()> { - if let Some(b) = &mut self.bitmap_builder { - let values_buffer = self - .values_buffer - .take() - .expect("Values buffer was never created"); - let mut builder: BufferBuilder = - mutable_buffer_to_builder(values_buffer, self.slots); - builder.advance(1); - let mutable_buffer = builder_to_mutable_buffer(builder); - self.values_buffer = Some(mutable_buffer); - self.slots += 1; - b.append(false); - }; - Ok(()) - } - - /// Appends a null to this `FieldData` when the type is not known at compile time. - /// - /// As the main `append` method of `UnionBuilder` is generic, we need a way to append null - /// slots to the fields that are not being appended to in the case of sparse unions. This - /// method solves this problem by appending dynamically based on `DataType`. - /// - /// Note, this method does **not** update the length of the `UnionArray` (this is done by the - /// main append operation) and assumes that it is called from a method that is generic over `T` - /// where `T` satisfies the bound `ArrowPrimitiveType`. - fn append_null_dynamic(&mut self) -> Result<()> { - match self.data_type { - DataType::Null => unimplemented!(), - DataType::Int8 => self.append_null::()?, - DataType::Int16 => self.append_null::()?, - DataType::Int32 - | DataType::Date32 - | DataType::Time32(_) - | DataType::Interval(IntervalUnit::YearMonth) => { - self.append_null::()? 
- } - DataType::Int64 - | DataType::Timestamp(_, _) - | DataType::Date64 - | DataType::Time64(_) - | DataType::Interval(IntervalUnit::DayTime) - | DataType::Duration(_) => self.append_null::()?, - DataType::UInt8 => self.append_null::()?, - DataType::UInt16 => self.append_null::()?, - DataType::UInt32 => self.append_null::()?, - DataType::UInt64 => self.append_null::()?, - DataType::Float32 => self.append_null::()?, - DataType::Float64 => self.append_null::()?, - _ => unreachable!("All cases of types that satisfy the trait bounds over T are covered above."), - }; - Ok(()) + fn append_null(&mut self) { + self.values_buffer.append_null(); + self.bitmap_builder.append(false); + self.slots += 1; } } @@ -2006,8 +2089,6 @@ pub struct UnionBuilder { type_id_builder: Int8BufferBuilder, /// Builder to keep track of offsets (`None` for sparse unions) value_offset_builder: Option, - /// Optional builder for null slots - bitmap_builder: Option, } impl UnionBuilder { @@ -2018,7 +2099,6 @@ impl UnionBuilder { fields: HashMap::default(), type_id_builder: Int8BufferBuilder::new(capacity), value_offset_builder: Some(Int32BufferBuilder::new(capacity)), - bitmap_builder: None, } } @@ -2029,35 +2109,19 @@ impl UnionBuilder { fields: HashMap::default(), type_id_builder: Int8BufferBuilder::new(capacity), value_offset_builder: None, - bitmap_builder: None, } } - /// Appends a null to this builder. + /// Appends a null to this builder, encoding the null in the array + /// of the `type_name` child / field. + /// + /// Since `UnionArray` encodes nulls as an entry in its children + /// (it doesn't have a validity bitmap itself), and where the null + /// is part of the final array, appending a NULL requires + /// specifying which field (child) to use. #[inline] - pub fn append_null(&mut self) -> Result<()> { - if self.bitmap_builder.is_none() { - let mut builder = BooleanBufferBuilder::new(self.len + 1); - for _ in 0..self.len { - builder.append(true); - } - self.bitmap_builder = Some(builder) - } - self.bitmap_builder - .as_mut() - .expect("Cannot be None") - .append(false); - - self.type_id_builder.append(i8::default()); - - // Handle sparse union - if self.value_offset_builder.is_none() { - for (_, fd) in self.fields.iter_mut() { - fd.append_null_dynamic()?; - } - } - self.len += 1; - Ok(()) + pub fn append_null(&mut self, type_name: &str) -> Result<()> { + self.append_option::(type_name, None) } /// Appends a value to this builder. 
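A small usage sketch of the reworked `UnionBuilder` API described above (illustration only): nulls are appended into a named child, and reusing a field name with a different type is now rejected.

```rust
use arrow::array::UnionBuilder;
use arrow::datatypes::{Float64Type, Int32Type};

let mut builder = UnionBuilder::new_dense(3);
builder.append::<Int32Type>("a", 1).unwrap();
// the null names the child ("a") that stores the placeholder slot
builder.append_null::<Int32Type>("a").unwrap();
// "a" was created as Int32, so appending a Float64 under the same name errors
assert!(builder.append::<Float64Type>("a", 1.0).is_err());
```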
@@ -2066,21 +2130,31 @@ impl UnionBuilder { &mut self, type_name: &str, v: T::Native, + ) -> Result<()> { + self.append_option::(type_name, Some(v)) + } + + fn append_option( + &mut self, + type_name: &str, + v: Option, ) -> Result<()> { let type_name = type_name.to_string(); let mut field_data = match self.fields.remove(&type_name) { - Some(data) => data, + Some(data) => { + if data.data_type != T::DATA_TYPE { + return Err(ArrowError::InvalidArgumentError(format!("Attempt to write col \"{}\" with type {} doesn't match existing type {}", type_name, T::DATA_TYPE, data.data_type))); + } + data + } None => match self.value_offset_builder { - Some(_) => FieldData::new(self.fields.len() as i8, T::DATA_TYPE, None), + Some(_) => FieldData::new::(self.fields.len() as i8, T::DATA_TYPE), None => { - let mut fd = FieldData::new( - self.fields.len() as i8, - T::DATA_TYPE, - Some(BooleanBufferBuilder::new(1)), - ); + let mut fd = + FieldData::new::(self.fields.len() as i8, T::DATA_TYPE); for _ in 0..self.len { - fd.append_null::()?; + fd.append_null(); } fd } @@ -2095,20 +2169,19 @@ impl UnionBuilder { } // Sparse Union None => { - for (name, fd) in self.fields.iter_mut() { - if name != &type_name { - fd.append_null_dynamic()?; - } + for (_, fd) in self.fields.iter_mut() { + // Append to all bar the FieldData currently being appended to + fd.append_null(); } } } - field_data.append_to_values_buffer::(v)?; - self.fields.insert(type_name, field_data); - // Update the bitmap builder if it exists - if let Some(b) = &mut self.bitmap_builder { - b.append(true); + match v { + Some(v) => field_data.append_value::(v), + None => field_data.append_null(), } + + self.fields.insert(type_name, field_data); self.len += 1; Ok(()) } @@ -2123,27 +2196,19 @@ impl UnionBuilder { FieldData { type_id, data_type, - values_buffer, + mut values_buffer, slots, - bitmap_builder, + mut bitmap_builder, }, ) in self.fields.into_iter() { - let buffer = values_buffer - .expect("The `values_buffer` should only ever be None inside the `append` method.") - .into(); + let buffer = values_buffer.finish(); let arr_data_builder = ArrayDataBuilder::new(data_type.clone()) .add_buffer(buffer) - .len(slots); - // .build(); - let arr_data_ref = unsafe { - match bitmap_builder { - Some(mut bb) => arr_data_builder - .null_bit_buffer(bb.finish()) - .build_unchecked(), - None => arr_data_builder.build_unchecked(), - } - }; + .len(slots) + .null_bit_buffer(Some(bitmap_builder.finish())); + + let arr_data_ref = unsafe { arr_data_builder.build_unchecked() }; let array_ref = make_array(arr_data_ref); children.push((type_id, (Field::new(&name, data_type, false), array_ref))) } @@ -2153,9 +2218,10 @@ impl UnionBuilder { .expect("This will never be None as type ids are always i8 values.") }); let children: Vec<_> = children.into_iter().map(|(_, b)| b).collect(); - let bitmap = self.bitmap_builder.map(|mut b| b.finish()); - UnionArray::try_new(type_id_buffer, value_offsets_buffer, children, bitmap) + let type_ids: Vec = (0_i8..children.len() as i8).collect(); + + UnionArray::try_new(&type_ids, type_id_buffer, value_offsets_buffer, children) } } @@ -2493,6 +2559,7 @@ mod tests { use crate::array::Array; use crate::bitmap::Bitmap; + use crate::util::decimal::Decimal128; #[test] fn test_builder_i32_empty() { @@ -2624,7 +2691,8 @@ mod tests { let buffer = b.finish(); assert_eq!(1, buffer.len()); - let mut b = BooleanBufferBuilder::new(4); + // Overallocate capacity + let mut b = BooleanBufferBuilder::new(8); b.append_slice(&[false, true, false, true]); 
assert_eq!(4, b.len()); assert_eq!(512, b.capacity()); @@ -2745,6 +2813,42 @@ mod tests { assert!(buffer.get_bit(11)); } + #[test] + fn test_bool_buffer_fuzz() { + use rand::prelude::*; + + let mut buffer = BooleanBufferBuilder::new(12); + let mut all_bools = vec![]; + let mut rng = rand::thread_rng(); + + let src_len = 32; + let (src, compacted_src) = { + let src: Vec<_> = std::iter::from_fn(|| Some(rng.next_u32() & 1 == 0)) + .take(src_len) + .collect(); + + let mut compacted_src = BooleanBufferBuilder::new(src_len); + compacted_src.append_slice(&src); + (src, compacted_src.finish()) + }; + + for _ in 0..100 { + let a = rng.next_u32() as usize % src_len; + let b = rng.next_u32() as usize % src_len; + + let start = a.min(b); + let end = a.max(b); + + buffer.append_packed_range(start..end, compacted_src.as_slice()); + all_bools.extend_from_slice(&src[start..end]); + } + + let mut compacted = BooleanBufferBuilder::new(all_bools.len()); + compacted.append_slice(&all_bools); + + assert_eq!(buffer.finish(), compacted.finish()) + } + #[test] fn test_boolean_array_builder_append_slice() { let arr1 = @@ -2771,6 +2875,29 @@ mod tests { assert_eq!(arr1, arr2); } + #[test] + fn test_boolean_array_builder_resize() { + let mut builder = BooleanBufferBuilder::new(20); + builder.append_n(4, true); + builder.append_n(7, false); + builder.append_n(2, true); + builder.resize(20); + + assert_eq!(builder.len, 20); + assert_eq!( + builder.buffer.as_slice(), + &[0b00001111, 0b00011000, 0b00000000] + ); + + builder.resize(5); + assert_eq!(builder.len, 5); + assert_eq!(builder.buffer.as_slice(), &[0b00001111]); + + builder.append_n(4, true); + assert_eq!(builder.len, 9); + assert_eq!(builder.buffer.as_slice(), &[0b11101111, 0b00000001]); + } + #[test] fn test_boolean_builder_increases_buffer_len() { // 00000010 01001000 @@ -3343,14 +3470,34 @@ mod tests { #[test] fn test_decimal_builder() { - let mut builder = DecimalBuilder::new(30, 23, 6); + let mut builder = DecimalBuilder::new(30, 38, 6); + + builder.append_value(8_887_000_000_i128).unwrap(); + builder.append_null().unwrap(); + builder.append_value(-8_887_000_000_i128).unwrap(); + let decimal_array: DecimalArray = builder.finish(); + + assert_eq!(&DataType::Decimal(38, 6), decimal_array.data_type()); + assert_eq!(3, decimal_array.len()); + assert_eq!(1, decimal_array.null_count()); + assert_eq!(32, decimal_array.value_offset(2)); + assert_eq!(16, decimal_array.value_length()); + } - builder.append_value(8_887_000_000).unwrap(); + #[test] + fn test_decimal_builder_with_decimal128() { + let mut builder = DecimalBuilder::new(30, 38, 6); + + builder + .append_value(Decimal128::new_from_i128(30, 38, 8_887_000_000_i128)) + .unwrap(); builder.append_null().unwrap(); - builder.append_value(-8_887_000_000).unwrap(); + builder + .append_value(Decimal128::new_from_i128(30, 38, -8_887_000_000_i128)) + .unwrap(); let decimal_array: DecimalArray = builder.finish(); - assert_eq!(&DataType::Decimal(23, 6), decimal_array.data_type()); + assert_eq!(&DataType::Decimal(38, 6), decimal_array.data_type()); assert_eq!(3, decimal_array.len()); assert_eq!(1, decimal_array.null_count()); assert_eq!(32, decimal_array.value_offset(2)); @@ -3453,13 +3600,13 @@ mod tests { assert_eq!(4, struct_data.len()); assert_eq!(1, struct_data.null_count()); assert_eq!( - &Some(Bitmap::from(Buffer::from(&[11_u8]))), + Some(&Bitmap::from(Buffer::from(&[11_u8]))), struct_data.null_bitmap() ); let expected_string_data = ArrayData::builder(DataType::Utf8) .len(4) - 
.null_bit_buffer(Buffer::from(&[9_u8])) + .null_bit_buffer(Some(Buffer::from(&[9_u8]))) .add_buffer(Buffer::from_slice_ref(&[0, 3, 3, 3, 7])) .add_buffer(Buffer::from_slice_ref(b"joemark")) .build() @@ -3467,7 +3614,7 @@ mod tests { let expected_int_data = ArrayData::builder(DataType::Int32) .len(4) - .null_bit_buffer(Buffer::from_slice_ref(&[11_u8])) + .null_bit_buffer(Some(Buffer::from_slice_ref(&[11_u8]))) .add_buffer(Buffer::from_slice_ref(&[1, 2, 0, 4])) .build() .unwrap(); @@ -3567,13 +3714,13 @@ mod tests { assert_eq!(3, map_data.len()); assert_eq!(1, map_data.null_count()); assert_eq!( - &Some(Bitmap::from(Buffer::from(&[5_u8]))), + Some(&Bitmap::from(Buffer::from(&[5_u8]))), map_data.null_bitmap() ); let expected_string_data = ArrayData::builder(DataType::Utf8) .len(4) - .null_bit_buffer(Buffer::from(&[9_u8])) + .null_bit_buffer(Some(Buffer::from(&[9_u8]))) .add_buffer(Buffer::from_slice_ref(&[0, 3, 3, 3, 7])) .add_buffer(Buffer::from_slice_ref(b"joemark")) .build() @@ -3581,7 +3728,7 @@ mod tests { let expected_int_data = ArrayData::builder(DataType::Int32) .len(4) - .null_bit_buffer(Buffer::from_slice_ref(&[11_u8])) + .null_bit_buffer(Some(Buffer::from_slice_ref(&[11_u8]))) .add_buffer(Buffer::from_slice_ref(&[1, 2, 0, 4])) .build() .unwrap(); @@ -3596,12 +3743,14 @@ mod tests { #[test] fn test_struct_array_builder_from_schema() { - let mut fields = Vec::new(); - fields.push(Field::new("f1", DataType::Float32, false)); - fields.push(Field::new("f2", DataType::Utf8, false)); - let mut sub_fields = Vec::new(); - sub_fields.push(Field::new("g1", DataType::Int32, false)); - sub_fields.push(Field::new("g2", DataType::Boolean, false)); + let mut fields = vec![ + Field::new("f1", DataType::Float32, false), + Field::new("f2", DataType::Utf8, false), + ]; + let sub_fields = vec![ + Field::new("g1", DataType::Int32, false), + Field::new("g2", DataType::Boolean, false), + ]; let struct_type = DataType::Struct(sub_fields); fields.push(Field::new("f3", struct_type, false)); @@ -3617,8 +3766,7 @@ mod tests { expected = "Data type List(Field { name: \"item\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }) is not currently supported" )] fn test_struct_array_builder_from_schema_unsupported_type() { - let mut fields = Vec::new(); - fields.push(Field::new("f1", DataType::Int16, false)); + let mut fields = vec![Field::new("f1", DataType::Int16, false)]; let list_type = DataType::List(Box::new(Field::new("item", DataType::Int64, true))); fields.push(Field::new("f2", list_type, false)); diff --git a/arrow/src/array/cast.rs b/arrow/src/array/cast.rs index dfc15608555f..d0b77a0d27b5 100644 --- a/arrow/src/array/cast.rs +++ b/arrow/src/array/cast.rs @@ -21,7 +21,7 @@ use crate::array::*; use crate::datatypes::*; /// Force downcast ArrayRef to PrimitiveArray -pub fn as_primitive_array(arr: &ArrayRef) -> &PrimitiveArray +pub fn as_primitive_array(arr: &dyn Array) -> &PrimitiveArray where T: ArrowPrimitiveType, { @@ -31,7 +31,7 @@ where } /// Force downcast ArrayRef to DictionaryArray -pub fn as_dictionary_array(arr: &ArrayRef) -> &DictionaryArray +pub fn as_dictionary_array(arr: &dyn Array) -> &DictionaryArray where T: ArrowDictionaryKeyType, { @@ -41,7 +41,9 @@ where } #[doc = "Force downcast ArrayRef to GenericListArray"] -pub fn as_generic_list_array(arr: &ArrayRef) -> &GenericListArray { +pub fn as_generic_list_array( + arr: &dyn Array, +) -> &GenericListArray { arr.as_any() .downcast_ref::>() .expect("Unable to downcast to list array") @@ -49,20 +51,20 @@ pub 
fn as_generic_list_array(arr: &ArrayRef) -> &GenericList #[doc = "Force downcast ArrayRef to ListArray"] #[inline] -pub fn as_list_array(arr: &ArrayRef) -> &ListArray { +pub fn as_list_array(arr: &dyn Array) -> &ListArray { as_generic_list_array::(arr) } #[doc = "Force downcast ArrayRef to LargeListArray"] #[inline] -pub fn as_large_list_array(arr: &ArrayRef) -> &LargeListArray { +pub fn as_large_list_array(arr: &dyn Array) -> &LargeListArray { as_generic_list_array::(arr) } #[doc = "Force downcast ArrayRef to GenericBinaryArray"] #[inline] -pub fn as_generic_binary_array( - arr: &ArrayRef, +pub fn as_generic_binary_array( + arr: &dyn Array, ) -> &GenericBinaryArray { arr.as_any() .downcast_ref::>() @@ -73,7 +75,7 @@ macro_rules! array_downcast_fn { ($name: ident, $arrty: ty, $arrty_str:expr) => { #[doc = "Force downcast ArrayRef to "] #[doc = $arrty_str] - pub fn $name(arr: &ArrayRef) -> &$arrty { + pub fn $name(arr: &dyn Array) -> &$arrty { arr.as_any().downcast_ref::<$arrty>().expect(concat!( "Unable to downcast to typed array through ", stringify!($name) @@ -92,3 +94,45 @@ array_downcast_fn!(as_largestring_array, LargeStringArray); array_downcast_fn!(as_boolean_array, BooleanArray); array_downcast_fn!(as_null_array, NullArray); array_downcast_fn!(as_struct_array, StructArray); +array_downcast_fn!(as_union_array, UnionArray); +array_downcast_fn!(as_map_array, MapArray); +array_downcast_fn!(as_decimal_array, DecimalArray); + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use super::*; + + #[test] + fn test_as_decimal_array_ref() { + let array: DecimalArray = vec![Some(123), None, Some(1111)] + .into_iter() + .collect::() + .with_precision_and_scale(10, 2) + .unwrap(); + assert!(!as_decimal_array(&array).is_empty()); + let result_decimal = as_decimal_array(&array); + assert_eq!(result_decimal, &array); + } + + #[test] + fn test_as_primitive_array_ref() { + let array: Int32Array = vec![1, 2, 3].into_iter().map(Some).collect(); + assert!(!as_primitive_array::(&array).is_empty()); + + // should also work when wrapped in an Arc + let array: ArrayRef = Arc::new(array); + assert!(!as_primitive_array::(&array).is_empty()); + } + + #[test] + fn test_as_string_array_ref() { + let array: StringArray = vec!["foo", "bar"].into_iter().map(Some).collect(); + assert!(!as_string_array(&array).is_empty()); + + // should also work when wrapped in an Arc + let array: ArrayRef = Arc::new(array); + assert!(!as_string_array(&array).is_empty()) + } +} diff --git a/arrow/src/array/data.rs b/arrow/src/array/data.rs index dbc54342b034..3e7e66496162 100644 --- a/arrow/src/array/data.rs +++ b/arrow/src/array/data.rs @@ -18,16 +18,18 @@ //! Contains `ArrayData`, a generic representation of Arrow array data which encapsulates //! common attributes and operations for Arrow array. 
-use std::mem; -use std::sync::Arc; - -use crate::datatypes::{DataType, IntervalUnit}; -use crate::error::Result; +use crate::datatypes::{validate_decimal_precision, DataType, IntervalUnit, UnionMode}; +use crate::error::{ArrowError, Result}; use crate::{bitmap::Bitmap, datatypes::ArrowNativeType}; use crate::{ buffer::{Buffer, MutableBuffer}, util::bit_util, }; +use half::f16; +use std::convert::TryInto; +use std::mem; +use std::ops::Range; +use std::sync::Arc; use super::equal::equal; @@ -88,6 +90,10 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff MutableBuffer::new(capacity * mem::size_of::()), empty_buffer, ], + DataType::Float16 => [ + MutableBuffer::new(capacity * mem::size_of::()), + empty_buffer, + ], DataType::Float32 => [ MutableBuffer::new(capacity * mem::size_of::()), empty_buffer, @@ -115,6 +121,10 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff MutableBuffer::new(capacity * mem::size_of::()), empty_buffer, ], + DataType::Interval(IntervalUnit::MonthDayNano) => [ + MutableBuffer::new(capacity * mem::size_of::()), + empty_buffer, + ], DataType::Utf8 | DataType::Binary => { let mut buffer = MutableBuffer::new((1 + capacity) * mem::size_of::()); // safety: `unsafe` code assumes that this buffer is initialized with one element @@ -177,7 +187,6 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff ], _ => unreachable!(), }, - DataType::Float16 => unreachable!(), DataType::FixedSizeList(_, _) | DataType::Struct(_) => { [empty_buffer, MutableBuffer::new(0)] } @@ -185,7 +194,16 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff MutableBuffer::new(capacity * mem::size_of::()), empty_buffer, ], - DataType::Union(_) => unimplemented!(), + DataType::Union(_, _, mode) => { + let type_ids = MutableBuffer::new(capacity * mem::size_of::()); + match mode { + UnionMode::Sparse => [type_ids, empty_buffer], + UnionMode::Dense => { + let offsets = MutableBuffer::new(capacity * mem::size_of::()); + [type_ids, offsets] + } + } + } } } @@ -197,11 +215,18 @@ pub(crate) fn into_buffers( buffer2: MutableBuffer, ) -> Vec { match data_type { - DataType::Null | DataType::Struct(_) => vec![], + DataType::Null | DataType::Struct(_) | DataType::FixedSizeList(_, _) => vec![], DataType::Utf8 | DataType::Binary | DataType::LargeUtf8 | DataType::LargeBinary => vec![buffer1.into(), buffer2.into()], + DataType::Union(_, _, mode) => { + match mode { + // Based on Union's DataTypeLayout + UnionMode::Sparse => vec![buffer1.into()], + UnionMode::Dense => vec![buffer1.into(), buffer2.into()], + } + } _ => vec![buffer1.into()], } } @@ -253,6 +278,7 @@ impl ArrayData { /// Note: This is a low level API and most users of the arrow /// crate should create arrays using the methods in the `array` /// module. + #[allow(clippy::let_and_return)] pub unsafe fn new_unchecked( data_type: DataType, len: usize, @@ -267,7 +293,7 @@ impl ArrayData { Some(null_count) => null_count, }; let null_bitmap = null_bit_buffer.map(Bitmap::from); - Self { + let new_self = Self { data_type, len, null_count, @@ -275,33 +301,47 @@ impl ArrayData { buffers, child_data, null_bitmap, - } + }; + + // Provide a force_validate mode + #[cfg(feature = "force_validate")] + new_self.validate_full().unwrap(); + new_self } /// Create a new ArrayData, validating that the provided buffers /// form a valid Arrow array of the specified data type. 
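With `UnionMode` now carried in the data type, a union's buffer count follows the mode: one buffer (type ids) for sparse unions, two (type ids plus `i32` offsets) for dense unions. A brief sketch, not part of the patch:

```rust
use arrow::array::{Array, UnionBuilder};
use arrow::datatypes::Int32Type;

let mut sparse = UnionBuilder::new_sparse(1);
sparse.append::<Int32Type>("a", 1).unwrap();
assert_eq!(sparse.build().unwrap().data().buffers().len(), 1); // type ids only

let mut dense = UnionBuilder::new_dense(1);
dense.append::<Int32Type>("a", 1).unwrap();
assert_eq!(dense.build().unwrap().data().buffers().len(), 2); // type ids + offsets
```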
/// - /// If `null_count` is not specified, the number of nulls in - /// null_bit_buffer is calculated - /// /// Note: This is a low level API and most users of the arrow /// crate should create arrays using the methods in the `array` /// module. pub fn try_new( data_type: DataType, len: usize, - null_count: Option<usize>, null_bit_buffer: Option<Buffer>, offset: usize, buffers: Vec<Buffer>, child_data: Vec<ArrayData>, ) -> Result<Self> { - // Safetly justification: `validate` is (will be) called below + // we must check the length of `null_bit_buffer` first + // because we use this buffer to calculate `null_count` + // in `Self::new_unchecked`. + if let Some(null_bit_buffer) = null_bit_buffer.as_ref() { + let needed_len = bit_util::ceil(len + offset, 8); + if null_bit_buffer.len() < needed_len { + return Err(ArrowError::InvalidArgumentError(format!( + "null_bit_buffer size too small. got {} needed {}", + null_bit_buffer.len(), + needed_len + ))); + } + } + // Safety justification: `validate_full` is called below let new_self = unsafe { Self::new_unchecked( data_type, len, - null_count, + None, null_bit_buffer, offset, buffers, @@ -309,22 +349,11 @@ impl ArrayData { ) }; - new_self.validate()?; + // As the data is not trusted, do a full validation of its contents + new_self.validate_full()?; Ok(new_self) } - /// Validates that buffers in this ArrayData are sufficiently - /// sized, to store `len` + `offset` total elements of - /// `data_type`. - /// - /// This check is "cheap" in the sense that it does not validate the - /// contents of the buffers (e.g. that string offsets for UTF8 arrays - /// are within the length of the buffer). - pub fn validate(&self) -> Result<()> { - // will be filled in a subsequent PR - Ok(()) - } - /// Returns a builder to construct a `ArrayData` instance. #[inline] pub const fn builder(data_type: DataType) -> ArrayDataBuilder { @@ -337,6 +366,27 @@ impl ArrayData { &self.data_type } + /// Updates the [DataType] of this ArrayData. + /// + /// Panics if the new DataType is not compatible with the + /// existing type. + /// + /// Note: currently only changing a [DataType::Decimal]'s precision + /// and scale is supported + #[inline] + pub(crate) fn with_data_type(mut self, new_data_type: DataType) -> Self { + assert!( + matches!(self.data_type, DataType::Decimal(_, _)), + "only DecimalType is supported for existing type" + ); + assert!( + matches!(new_data_type, DataType::Decimal(_, _)), + "only DecimalType is supported for new datatype" + ); + self.data_type = new_data_type; + self + } + /// Returns a slice of buffers for this array data pub fn buffers(&self) -> &[Buffer] { &self.buffers[..] @@ -357,8 +407,8 @@ impl ArrayData { /// Returns a reference to the null bitmap of this array data #[inline] - pub const fn null_bitmap(&self) -> &Option<Bitmap> { - &self.null_bitmap + pub const fn null_bitmap(&self) -> Option<&Bitmap> { + self.null_bitmap.as_ref() } /// Returns a reference to the null buffer of this array data.
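For reference, a sketch of the slimmed-down `ArrayData::try_new` (the explicit `null_count` argument is gone; the count is now always derived from the bitmap). Illustration only, reusing the Int32 layout from the earlier sketch:

```rust
use arrow::array::ArrayData;
use arrow::buffer::Buffer;
use arrow::datatypes::DataType;

let data = ArrayData::try_new(
    DataType::Int32,
    4,                                       // len
    Some(Buffer::from(&[0b0000_1011_u8])),   // null_bit_buffer (no null_count param)
    0,                                       // offset
    vec![Buffer::from_slice_ref(&[1_i32, 2, 0, 4])],
    vec![],                                  // child_data
)
.unwrap();
assert_eq!(data.null_count(), 1);
```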
@@ -459,7 +509,7 @@ impl ArrayData { .iter() .map(|data| data.slice(offset, length)) .collect(), - null_bitmap: self.null_bitmap().clone(), + null_bitmap: self.null_bitmap().cloned(), }; new_data @@ -508,6 +558,7 @@ impl ArrayData { | DataType::Int16 | DataType::Int32 | DataType::Int64 + | DataType::Float16 | DataType::Float32 | DataType::Float64 | DataType::Date32 @@ -539,11 +590,13 @@ impl ArrayData { DataType::Map(field, _) => { vec![Self::new_empty(field.data_type())] } - DataType::Union(_) => unimplemented!(), + DataType::Union(fields, _, _) => fields + .iter() + .map(|field| Self::new_empty(field.data_type())) + .collect(), DataType::Dictionary(_, data_type) => { vec![Self::new_empty(data_type)] } - DataType::Float16 => unreachable!(), }; // Data was constructed correctly above @@ -559,6 +612,769 @@ impl ArrayData { ) } } + + /// "cheap" validation of an `ArrayData`. Ensures buffers are + /// sufficiently sized to store `len` + `offset` total elements of + /// `data_type` and performs other inexpensive consistency checks. + /// + /// This check is "cheap" in the sense that it does not validate the + /// contents of the buffers (e.g. that all offsets for UTF8 arrays + /// are within the bounds of the values buffer). + /// + /// See [ArrayData::validate_full] to fully validate the offset content + /// and the validity of utf8 data + pub fn validate(&self) -> Result<()> { + // Need at least this much space in each buffer + let len_plus_offset = self.len + self.offset; + + // Check that the data layout conforms to the spec + let layout = layout(&self.data_type); + + if !layout.can_contain_null_mask && self.null_bitmap.is_some() { + return Err(ArrowError::InvalidArgumentError(format!( + "Arrays of type {:?} cannot contain a null bitmask", + self.data_type, + ))); + } + + if self.buffers.len() != layout.buffers.len() { + return Err(ArrowError::InvalidArgumentError(format!( + "Expected {} buffers in array of type {:?}, got {}", + layout.buffers.len(), + self.data_type, + self.buffers.len(), + ))); + } + + for (i, (buffer, spec)) in + self.buffers.iter().zip(layout.buffers.iter()).enumerate() + { + match spec { + BufferSpec::FixedWidth { byte_width } => { + let min_buffer_size = len_plus_offset + .checked_mul(*byte_width) + .expect("integer overflow computing min buffer size"); + + if buffer.len() < min_buffer_size { + return Err(ArrowError::InvalidArgumentError(format!( + "Need at least {} bytes in buffers[{}] in array of type {:?}, but got {}", + min_buffer_size, i, self.data_type, buffer.len() + ))); + } + } + BufferSpec::VariableWidth => { + // not cheap to validate (need to look at the + // data). Partially checked in validate_offsets + // called below.
Can check with `validate_full` + } + BufferSpec::BitMap => { + let min_buffer_size = bit_util::ceil(len_plus_offset, 8); + if buffer.len() < min_buffer_size { + return Err(ArrowError::InvalidArgumentError(format!( + "Need at least {} bytes for bitmap in buffers[{}] in array of type {:?}, but got {}", + min_buffer_size, i, self.data_type, buffer.len() + ))); + } + } + BufferSpec::AlwaysNull => { + // Nothing to validate + } + } + } + + if self.null_count > self.len { + return Err(ArrowError::InvalidArgumentError(format!( + "null_count {} for an array exceeds length of {} elements", + self.null_count, self.len + ))); + } + + // check null bit buffer size + if let Some(null_bit_map) = self.null_bitmap.as_ref() { + let null_bit_buffer = null_bit_map.buffer_ref(); + let needed_len = bit_util::ceil(len_plus_offset, 8); + if null_bit_buffer.len() < needed_len { + return Err(ArrowError::InvalidArgumentError(format!( + "null_bit_buffer size too small. got {} needed {}", + null_bit_buffer.len(), + needed_len + ))); + } + } else if self.null_count > 0 { + return Err(ArrowError::InvalidArgumentError(format!( + "Array of type {} has {} nulls but no null bitmap", + self.data_type, self.null_count + ))); + } + + self.validate_child_data()?; + + // Additional Type specific checks + match &self.data_type { + DataType::Utf8 | DataType::Binary => { + self.validate_offsets::(self.buffers[1].len())?; + } + DataType::LargeUtf8 | DataType::LargeBinary => { + self.validate_offsets::(self.buffers[1].len())?; + } + DataType::Dictionary(key_type, _value_type) => { + // At the moment, constructing a DictionaryArray will also check this + if !DataType::is_dictionary_key_type(key_type) { + return Err(ArrowError::InvalidArgumentError(format!( + "Dictionary key type must be integer, but was {}", + key_type + ))); + } + } + _ => {} + }; + + Ok(()) + } + + /// Returns a reference to the data in `buffer` as a typed slice + /// (typically `&[i32]` or `&[i64]`) after validating. The + /// returned slice is guaranteed to have at least `self.len + 1` + /// entries. + /// + /// For an empty array, the `buffer` can also be empty. + fn typed_offsets(&self) -> Result<&[T]> { + // An empty list-like array can have 0 offsets + if self.len == 0 && self.buffers[0].is_empty() { + return Ok(&[]); + } + + self.typed_buffer(0, self.len + 1) + } + + /// Returns a reference to the data in `buffers[idx]` as a typed slice after validating + fn typed_buffer( + &self, + idx: usize, + len: usize, + ) -> Result<&[T]> { + let buffer = &self.buffers[idx]; + + let required_len = (len + self.offset) * std::mem::size_of::(); + + if buffer.len() < required_len { + return Err(ArrowError::InvalidArgumentError(format!( + "Buffer {} of {} isn't large enough. 
Expected {} bytes got {}", + idx, + self.data_type, + required_len, + buffer.len() + ))); + } + + Ok(&buffer.typed_data::()[self.offset..self.offset + len]) + } + + /// Does a cheap sanity check that the `self.len` values in `buffer` are valid + /// offsets (of type T) into some other buffer of `values_length` bytes long + fn validate_offsets( + &self, + values_length: usize, + ) -> Result<()> { + // Justification: buffer size was validated above + let offsets = self.typed_offsets::()?; + if offsets.is_empty() { + return Ok(()); + } + + let first_offset = offsets[0].to_usize().ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "Error converting offset[0] ({}) to usize for {}", + offsets[0], self.data_type + )) + })?; + + let last_offset = offsets[self.len].to_usize().ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "Error converting offset[{}] ({}) to usize for {}", + self.len, offsets[self.len], self.data_type + )) + })?; + + if first_offset > values_length { + return Err(ArrowError::InvalidArgumentError(format!( + "First offset {} of {} is larger than values length {}", + first_offset, self.data_type, values_length, + ))); + } + + if last_offset > values_length { + return Err(ArrowError::InvalidArgumentError(format!( + "Last offset {} of {} is larger than values length {}", + last_offset, self.data_type, values_length, + ))); + } + + if first_offset > last_offset { + return Err(ArrowError::InvalidArgumentError(format!( + "First offset {} in {} is smaller than last offset {}", + first_offset, self.data_type, last_offset, + ))); + } + + Ok(()) + } + + /// Validates the layout of `child_data` ArrayData structures + fn validate_child_data(&self) -> Result<()> { + match &self.data_type { + DataType::List(field) | DataType::Map(field, _) => { + let values_data = self.get_single_valid_child_data(field.data_type())?; + self.validate_offsets::(values_data.len)?; + Ok(()) + } + DataType::LargeList(field) => { + let values_data = self.get_single_valid_child_data(field.data_type())?; + self.validate_offsets::(values_data.len)?; + Ok(()) + } + DataType::FixedSizeList(field, list_size) => { + let values_data = self.get_single_valid_child_data(field.data_type())?; + + let list_size: usize = (*list_size).try_into().map_err(|_| { + ArrowError::InvalidArgumentError(format!( + "{} has a negative list_size {}", + self.data_type, list_size + )) + })?; + + let expected_values_len = self.len + .checked_mul(list_size) + .expect("integer overflow computing expected number of expected values in FixedListSize"); + + if values_data.len < expected_values_len { + return Err(ArrowError::InvalidArgumentError(format!( + "Values length {} is less than the length ({}) multiplied by the value size ({}) for {}", + values_data.len, list_size, list_size, self.data_type + ))); + } + + Ok(()) + } + DataType::Struct(fields) => { + self.validate_num_child_data(fields.len())?; + for (i, field) in fields.iter().enumerate() { + let field_data = self.get_valid_child_data(i, field.data_type())?; + + // Ensure child field has sufficient size + if field_data.len < self.len { + return Err(ArrowError::InvalidArgumentError(format!( + "{} child array #{} for field {} has length smaller than expected for struct array ({} < {})", + self.data_type, i, field.name(), field_data.len, self.len + ))); + } + } + Ok(()) + } + DataType::Union(fields, _, mode) => { + self.validate_num_child_data(fields.len())?; + + for (i, field) in fields.iter().enumerate() { + let field_data = self.get_valid_child_data(i, field.data_type())?; 
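+ // Sparse union children are laid out like the parent array and must cover + // all `len + offset` slots; dense union children are reached through the + // offsets buffer and may be shorter.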
+ + if mode == &UnionMode::Sparse + && field_data.len < (self.len + self.offset) + { + return Err(ArrowError::InvalidArgumentError(format!( + "Sparse union child array #{} has length smaller than expected for union array ({} < {})", + i, field_data.len, self.len + self.offset + ))); + } + } + Ok(()) + } + DataType::Dictionary(_key_type, value_type) => { + self.get_single_valid_child_data(value_type)?; + Ok(()) + } + _ => { + // other types do not have child data + if !self.child_data.is_empty() { + return Err(ArrowError::InvalidArgumentError(format!( + "Expected no child arrays for type {} but got {}", + self.data_type, + self.child_data.len() + ))); + } + Ok(()) + } + } + } + + /// Ensures that this array data has a single child_data with the + /// expected type, and calls `validate()` on it. Returns a + /// reference to that child_data + fn get_single_valid_child_data( + &self, + expected_type: &DataType, + ) -> Result<&ArrayData> { + self.validate_num_child_data(1)?; + self.get_valid_child_data(0, expected_type) + } + + /// Returns `Err` if self.child_data does not have exactly `expected_len` elements + fn validate_num_child_data(&self, expected_len: usize) -> Result<()> { + if self.child_data().len() != expected_len { + Err(ArrowError::InvalidArgumentError(format!( + "Value data for {} should contain {} child data array(s), had {}", + self.data_type(), + expected_len, + self.child_data.len() + ))) + } else { + Ok(()) + } + } + + /// Ensures that `child_data[i]` has the expected type, calls + /// `validate()` on it, and returns a reference to that child_data + fn get_valid_child_data( + &self, + i: usize, + expected_type: &DataType, + ) -> Result<&ArrayData> { + let values_data = self.child_data + .get(i) + .ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "{} did not have enough child arrays. Expected at least {} but had only {}", + self.data_type, i+1, self.child_data.len() + )) + })?; + + if expected_type != &values_data.data_type { + return Err(ArrowError::InvalidArgumentError(format!( + "Child type mismatch for {}. Expected {} but child data had {}", + self.data_type, expected_type, values_data.data_type + ))); + } + + values_data.validate()?; + Ok(values_data) + } + + /// "expensive" validation that ensures: + /// + /// 1. Null count is correct + /// 2. All offsets are valid + /// 3. All String data is valid UTF-8 + /// 4. All dictionary offsets are valid + /// + /// Does not (yet) check + /// 1. 
Union type_ids are valid see [#85](https://github.com/apache/arrow-rs/issues/85) + /// Note calls `validate()` internally + pub fn validate_full(&self) -> Result<()> { + // Check all buffer sizes prior to looking at them more deeply in this function + self.validate()?; + + let null_bitmap_buffer = self + .null_bitmap + .as_ref() + .map(|null_bitmap| null_bitmap.buffer_ref()); + + let actual_null_count = count_nulls(null_bitmap_buffer, self.offset, self.len); + if actual_null_count != self.null_count { + return Err(ArrowError::InvalidArgumentError(format!( + "null_count value ({}) doesn't match actual number of nulls in array ({})", + self.null_count, actual_null_count + ))); + } + + self.validate_values()?; + + // validate all children recursively + self.child_data + .iter() + .enumerate() + .try_for_each(|(i, child_data)| { + child_data.validate_full().map_err(|e| { + ArrowError::InvalidArgumentError(format!( + "{} child #{} invalid: {}", + self.data_type, i, e + )) + }) + })?; + + Ok(()) + } + + pub fn validate_values(&self) -> Result<()> { + match &self.data_type { + DataType::Decimal(p, _) => { + let values_buffer: &[i128] = self.typed_buffer(0, self.len)?; + for value in values_buffer { + validate_decimal_precision(*value, *p)?; + } + Ok(()) + } + DataType::Utf8 => self.validate_utf8::(), + DataType::LargeUtf8 => self.validate_utf8::(), + DataType::Binary => self.validate_offsets_full::(self.buffers[1].len()), + DataType::LargeBinary => { + self.validate_offsets_full::(self.buffers[1].len()) + } + DataType::List(_) | DataType::Map(_, _) => { + let child = &self.child_data[0]; + self.validate_offsets_full::(child.len) + } + DataType::LargeList(_) => { + let child = &self.child_data[0]; + self.validate_offsets_full::(child.len) + } + DataType::Union(_, _, _) => { + // Validate Union Array as part of implementing new Union semantics + // See comments in `ArrayData::validate()` + // https://github.com/apache/arrow-rs/issues/85 + // + // TODO file follow on ticket for full union validation + Ok(()) + } + DataType::Dictionary(key_type, _value_type) => { + let dictionary_length: i64 = self.child_data[0].len.try_into().unwrap(); + let max_value = dictionary_length - 1; + match key_type.as_ref() { + DataType::UInt8 => self.check_bounds::(max_value), + DataType::UInt16 => self.check_bounds::(max_value), + DataType::UInt32 => self.check_bounds::(max_value), + DataType::UInt64 => self.check_bounds::(max_value), + DataType::Int8 => self.check_bounds::(max_value), + DataType::Int16 => self.check_bounds::(max_value), + DataType::Int32 => self.check_bounds::(max_value), + DataType::Int64 => self.check_bounds::(max_value), + _ => unreachable!(), + } + } + _ => { + // No extra validation check required for other types + Ok(()) + } + } + } + + /// Calls the `validate(item_index, range)` function for each of + /// the ranges specified in the arrow offsets buffer of type + /// `T`. Also validates that each offset is smaller than + /// `offset_limit` + /// + /// For an empty array, the offsets buffer can either be empty + /// or contain a single `0`. + /// + /// For example, the offsets buffer contained `[1, 2, 4]`, this + /// function would call `validate([1,2])`, and `validate([2,4])` + fn validate_each_offset(&self, offset_limit: usize, validate: V) -> Result<()> + where + T: ArrowNativeType + TryInto + num::Num + std::fmt::Display, + V: Fn(usize, Range) -> Result<()>, + { + self.typed_offsets::()? 
+ .iter() + .enumerate() + .map(|(i, x)| { + // check if the offset can be converted to usize + let r = x.to_usize().ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "Offset invariant failure: Could not convert offset {} to usize at position {}", + x, i))} + ); + // check if the offset exceeds the limit + match r { + Ok(n) if n <= offset_limit => Ok((i, n)), + Ok(_) => Err(ArrowError::InvalidArgumentError(format!( + "Offset invariant failure: offset at position {} out of bounds: {} > {}", + i, x, offset_limit)) + ), + Err(e) => Err(e), + } + }) + .scan(0_usize, |start, end| { + // check offsets are monotonically increasing + match end { + Ok((i, end)) if *start <= end => { + let range = Some(Ok((i, *start..end))); + *start = end; + range + } + Ok((i, end)) => Some(Err(ArrowError::InvalidArgumentError(format!( + "Offset invariant failure: non-monotonic offset at slot {}: {} > {}", + i - 1, start, end)) + )), + Err(err) => Some(Err(err)), + } + }) + .skip(1) // the first element is meaningless + .try_for_each(|res: Result<(usize, Range)>| { + let (item_index, range) = res?; + validate(item_index-1, range) + }) + } + + /// Ensures that all strings formed by the offsets in `buffers[0]` + /// into `buffers[1]` are valid utf8 sequences + fn validate_utf8(&self) -> Result<()> + where + T: ArrowNativeType + TryInto + num::Num + std::fmt::Display, + { + let values_buffer = &self.buffers[1].as_slice(); + + self.validate_each_offset::(values_buffer.len(), |string_index, range| { + std::str::from_utf8(&values_buffer[range.clone()]).map_err(|e| { + ArrowError::InvalidArgumentError(format!( + "Invalid UTF8 sequence at string index {} ({:?}): {}", + string_index, range, e + )) + })?; + Ok(()) + }) + } + + /// Ensures that all offsets in `buffers[0]` into `buffers[1]` are + /// between `0` and `offset_limit` + fn validate_offsets_full(&self, offset_limit: usize) -> Result<()> + where + T: ArrowNativeType + TryInto + num::Num + std::fmt::Display, + { + self.validate_each_offset::(offset_limit, |_string_index, _range| { + // No validation applied to each value, but the iteration + // itself applies bounds checking to each range + Ok(()) + }) + } + + /// Validates that each value in self.buffers (typed as T) + /// is within the range [0, max_value], inclusive + fn check_bounds(&self, max_value: i64) -> Result<()> + where + T: ArrowNativeType + TryInto + num::Num + std::fmt::Display, + { + let required_len = self.len + self.offset; + let buffer = &self.buffers[0]; + + // This should have been checked as part of `validate()` prior + // to calling `validate_full()` but double check to be sure + assert!(buffer.len() / std::mem::size_of::() >= required_len); + + // Justification: buffer size was validated above + let indexes: &[T] = + &buffer.typed_data::()[self.offset..self.offset + self.len]; + + indexes.iter().enumerate().try_for_each(|(i, &dict_index)| { + // Do not check the value is null (value can be arbitrary) + if self.is_null(i) { + return Ok(()); + } + let dict_index: i64 = dict_index.try_into().map_err(|_| { + ArrowError::InvalidArgumentError(format!( + "Value at position {} out of bounds: {} (can not convert to i64)", + i, dict_index + )) + })?; + + if dict_index < 0 || dict_index > max_value { + return Err(ArrowError::InvalidArgumentError(format!( + "Value at position {} out of bounds: {} (should be in [0, {}])", + i, dict_index, max_value + ))); + } + Ok(()) + }) + } + + /// Returns true if this `ArrayData` is equal to `other`, using pointer comparisons + /// to determine buffer 
equality. This is cheaper than `PartialEq::eq` but may + /// return false when the arrays are logically equal + pub fn ptr_eq(&self, other: &Self) -> bool { + if self.offset != other.offset + || self.len != other.len + || self.null_count != other.null_count + || self.data_type != other.data_type + || self.buffers.len() != other.buffers.len() + || self.child_data.len() != other.child_data.len() + { + return false; + } + + match (&self.null_bitmap, &other.null_bitmap) { + (Some(a), Some(b)) if a.bits.as_ptr() != b.bits.as_ptr() => return false, + (Some(_), None) | (None, Some(_)) => return false, + _ => {} + }; + + if !self + .buffers + .iter() + .zip(other.buffers.iter()) + .all(|(a, b)| a.as_ptr() == b.as_ptr()) + { + return false; + } + + self.child_data + .iter() + .zip(other.child_data.iter()) + .all(|(a, b)| a.ptr_eq(b)) + } +} + +/// Return the expected [`DataTypeLayout`] Arrays of this data +/// type are expected to have +fn layout(data_type: &DataType) -> DataTypeLayout { + // based on C/C++ implementation in + // https://github.com/apache/arrow/blob/661c7d749150905a63dd3b52e0a04dac39030d95/cpp/src/arrow/type.h (and .cc) + use std::mem::size_of; + match data_type { + DataType::Null => DataTypeLayout { + buffers: vec![], + can_contain_null_mask: false, + }, + DataType::Boolean => DataTypeLayout { + buffers: vec![BufferSpec::BitMap], + can_contain_null_mask: true, + }, + DataType::Int8 => DataTypeLayout::new_fixed_width(size_of::()), + DataType::Int16 => DataTypeLayout::new_fixed_width(size_of::()), + DataType::Int32 => DataTypeLayout::new_fixed_width(size_of::()), + DataType::Int64 => DataTypeLayout::new_fixed_width(size_of::()), + DataType::UInt8 => DataTypeLayout::new_fixed_width(size_of::()), + DataType::UInt16 => DataTypeLayout::new_fixed_width(size_of::()), + DataType::UInt32 => DataTypeLayout::new_fixed_width(size_of::()), + DataType::UInt64 => DataTypeLayout::new_fixed_width(size_of::()), + DataType::Float16 => DataTypeLayout::new_fixed_width(size_of::()), + DataType::Float32 => DataTypeLayout::new_fixed_width(size_of::()), + DataType::Float64 => DataTypeLayout::new_fixed_width(size_of::()), + DataType::Timestamp(_, _) => DataTypeLayout::new_fixed_width(size_of::()), + DataType::Date32 => DataTypeLayout::new_fixed_width(size_of::()), + DataType::Date64 => DataTypeLayout::new_fixed_width(size_of::()), + DataType::Time32(_) => DataTypeLayout::new_fixed_width(size_of::()), + DataType::Time64(_) => DataTypeLayout::new_fixed_width(size_of::()), + DataType::Interval(IntervalUnit::YearMonth) => { + DataTypeLayout::new_fixed_width(size_of::()) + } + DataType::Interval(IntervalUnit::DayTime) => { + DataTypeLayout::new_fixed_width(size_of::()) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + DataTypeLayout::new_fixed_width(size_of::()) + } + DataType::Duration(_) => DataTypeLayout::new_fixed_width(size_of::()), + DataType::Binary => DataTypeLayout::new_binary(size_of::()), + DataType::FixedSizeBinary(bytes_per_value) => { + let bytes_per_value: usize = (*bytes_per_value) + .try_into() + .expect("negative size for fixed size binary"); + DataTypeLayout::new_fixed_width(bytes_per_value) + } + DataType::LargeBinary => DataTypeLayout::new_binary(size_of::()), + DataType::Utf8 => DataTypeLayout::new_binary(size_of::()), + DataType::LargeUtf8 => DataTypeLayout::new_binary(size_of::()), + DataType::List(_) => DataTypeLayout::new_fixed_width(size_of::()), + DataType::FixedSizeList(_, _) => DataTypeLayout::new_empty(), // all in child data + DataType::LargeList(_) => 
DataTypeLayout::new_fixed_width(size_of::()), + DataType::Struct(_) => DataTypeLayout::new_empty(), // all in child data, + DataType::Union(_, _, mode) => { + let type_ids = BufferSpec::FixedWidth { + byte_width: size_of::(), + }; + + DataTypeLayout { + buffers: match mode { + UnionMode::Sparse => { + vec![type_ids] + } + UnionMode::Dense => { + vec![ + type_ids, + BufferSpec::FixedWidth { + byte_width: size_of::(), + }, + ] + } + }, + can_contain_null_mask: false, + } + } + DataType::Dictionary(key_type, _value_type) => layout(key_type), + DataType::Decimal(_, _) => { + // Decimals are always some fixed width; The rust implementation + // always uses 16 bytes / size of i128 + DataTypeLayout::new_fixed_width(size_of::()) + } + DataType::Map(_, _) => { + // same as ListType + DataTypeLayout::new_fixed_width(size_of::()) + } + } +} + +/// Layout specification for a data type +#[derive(Debug, PartialEq)] +// Note: Follows structure from C++: https://github.com/apache/arrow/blob/master/cpp/src/arrow/type.h#L91 +struct DataTypeLayout { + /// A vector of buffer layout specifications, one for each expected buffer + pub buffers: Vec, + + /// Can contain a null bitmask + pub can_contain_null_mask: bool, +} + +impl DataTypeLayout { + /// Describes a basic numeric array where each element has a fixed width + pub fn new_fixed_width(byte_width: usize) -> Self { + Self { + buffers: vec![BufferSpec::FixedWidth { byte_width }], + can_contain_null_mask: true, + } + } + + /// Describes arrays which have no data of their own + /// (e.g. FixedSizeList). Note such arrays may still have a Null + /// Bitmap + pub fn new_empty() -> Self { + Self { + buffers: vec![], + can_contain_null_mask: true, + } + } + + /// Describes a basic numeric array where each element has a fixed + /// with offset buffer of `offset_byte_width` bytes, followed by a + /// variable width data buffer + pub fn new_binary(offset_byte_width: usize) -> Self { + Self { + buffers: vec![ + // offsets + BufferSpec::FixedWidth { + byte_width: offset_byte_width, + }, + // values + BufferSpec::VariableWidth, + ], + can_contain_null_mask: true, + } + } +} + +/// Layout specification for a single data type buffer +#[derive(Debug, PartialEq)] +enum BufferSpec { + /// each element has a fixed width + FixedWidth { byte_width: usize }, + /// Variable width, such as string data for utf8 data + VariableWidth, + /// Buffer holds a bitmap. + /// + /// Note: Unlike the C++ implementation, the null/validity buffer + /// is handled specially rather than as another of the buffers in + /// the spec, so this variant is only used for the Boolean type. + BitMap, + /// Buffer is always null. 
Unused currently in Rust implementation, + /// (used in C++ for Union type) + #[allow(dead_code)] + AlwaysNull, } impl PartialEq for ArrayData { @@ -605,8 +1421,8 @@ impl ArrayDataBuilder { self } - pub fn null_bit_buffer(mut self, buf: Buffer) -> Self { - self.null_bit_buffer = Some(buf); + pub fn null_bit_buffer(mut self, buf: Option) -> Self { + self.null_bit_buffer = buf; self } @@ -659,7 +1475,6 @@ impl ArrayDataBuilder { ArrayData::try_new( self.data_type, self.len, - self.null_count, self.null_bit_buffer, self.offset, self.buffers, @@ -671,52 +1486,64 @@ impl ArrayDataBuilder { #[cfg(test)] mod tests { use super::*; + use std::ptr::NonNull; + use crate::array::{ + make_array, Array, BooleanBuilder, DecimalBuilder, FixedSizeListBuilder, + Int32Array, Int32Builder, Int64Array, StringArray, StructBuilder, UInt64Array, + UInt8Builder, + }; use crate::buffer::Buffer; + use crate::datatypes::Field; use crate::util::bit_util; #[test] - fn test_new() { - let arr_data = - ArrayData::try_new(DataType::Boolean, 10, Some(1), None, 2, vec![], vec![]) - .unwrap(); - assert_eq!(10, arr_data.len()); - assert_eq!(1, arr_data.null_count()); - assert_eq!(2, arr_data.offset()); - assert_eq!(0, arr_data.buffers().len()); - assert_eq!(0, arr_data.child_data().len()); + fn test_builder() { + // Buffer needs to be at least 25 long + let v = (0..25).collect::>(); + let b1 = Buffer::from_slice_ref(&v); + let arr_data = ArrayData::builder(DataType::Int32) + .len(20) + .offset(5) + .add_buffer(b1) + .null_bit_buffer(Some(Buffer::from(vec![ + 0b01011111, 0b10110101, 0b01100011, 0b00011110, + ]))) + .build() + .unwrap(); + + assert_eq!(20, arr_data.len()); + assert_eq!(10, arr_data.null_count()); + assert_eq!(5, arr_data.offset()); + assert_eq!(1, arr_data.buffers().len()); + assert_eq!( + Buffer::from_slice_ref(&v).as_slice(), + arr_data.buffers()[0].as_slice() + ); } #[test] - fn test_builder() { + fn test_builder_with_child_data() { let child_arr_data = ArrayData::try_new( DataType::Int32, 5, - Some(0), None, 0, vec![Buffer::from_slice_ref(&[1i32, 2, 3, 4, 5])], vec![], ) .unwrap(); - let v = vec![0, 1, 2, 3]; - let b1 = Buffer::from(&v[..]); - let arr_data = ArrayData::builder(DataType::Int32) - .len(20) - .offset(5) - .add_buffer(b1) - .null_bit_buffer(Buffer::from(vec![ - 0b01011111, 0b10110101, 0b01100011, 0b00011110, - ])) + + let data_type = DataType::Struct(vec![Field::new("x", DataType::Int32, true)]); + + let arr_data = ArrayData::builder(data_type) + .len(5) + .offset(0) .add_child_data(child_arr_data.clone()) .build() .unwrap(); - assert_eq!(20, arr_data.len()); - assert_eq!(10, arr_data.null_count()); - assert_eq!(5, arr_data.offset()); - assert_eq!(1, arr_data.buffers().len()); - assert_eq!(&[0, 1, 2, 3], arr_data.buffers()[0].as_slice()); + assert_eq!(5, arr_data.len()); assert_eq!(1, arr_data.child_data().len()); assert_eq!(child_arr_data, arr_data.child_data()[0]); } @@ -729,7 +1556,8 @@ mod tests { bit_util::set_bit(&mut bit_v, 10); let arr_data = ArrayData::builder(DataType::Int32) .len(16) - .null_bit_buffer(Buffer::from(bit_v)) + .add_buffer(make_i32_buffer(16)) + .null_bit_buffer(Some(Buffer::from(bit_v))) .build() .unwrap(); assert_eq!(13, arr_data.null_count()); @@ -742,7 +1570,8 @@ mod tests { let arr_data = ArrayData::builder(DataType::Int32) .len(12) .offset(2) - .null_bit_buffer(Buffer::from(bit_v)) + .add_buffer(make_i32_buffer(14)) // requires at least 14 bytes of space, + .null_bit_buffer(Some(Buffer::from(bit_v))) .build() .unwrap(); assert_eq!(10, arr_data.null_count()); @@ 
-756,7 +1585,8 @@ mod tests { bit_util::set_bit(&mut bit_v, 10); let arr_data = ArrayData::builder(DataType::Int32) .len(16) - .null_bit_buffer(Buffer::from(bit_v)) + .add_buffer(make_i32_buffer(16)) + .null_bit_buffer(Some(Buffer::from(bit_v))) .build() .unwrap(); assert!(arr_data.null_buffer().is_some()); @@ -771,7 +1601,8 @@ mod tests { bit_util::set_bit(&mut bit_v, 10); let data = ArrayData::builder(DataType::Int32) .len(16) - .null_bit_buffer(Buffer::from(bit_v)) + .add_buffer(make_i32_buffer(16)) + .null_bit_buffer(Some(Buffer::from(bit_v))) .build() .unwrap(); let new_data = data.slice(1, 15); @@ -788,9 +1619,55 @@ mod tests { #[test] fn test_equality() { - let int_data = ArrayData::builder(DataType::Int32).build().unwrap(); - let float_data = ArrayData::builder(DataType::Float32).build().unwrap(); + let int_data = ArrayData::builder(DataType::Int32) + .len(1) + .add_buffer(make_i32_buffer(1)) + .build() + .unwrap(); + + let float_data = ArrayData::builder(DataType::Float32) + .len(1) + .add_buffer(make_f32_buffer(1)) + .build() + .unwrap(); assert_ne!(int_data, float_data); + assert!(!int_data.ptr_eq(&float_data)); + assert!(int_data.ptr_eq(&int_data)); + + let int_data_clone = int_data.clone(); + assert_eq!(int_data, int_data_clone); + assert!(int_data.ptr_eq(&int_data_clone)); + assert!(int_data_clone.ptr_eq(&int_data)); + + let int_data_slice = int_data_clone.slice(1, 0); + assert!(int_data_slice.ptr_eq(&int_data_slice)); + assert!(!int_data.ptr_eq(&int_data_slice)); + assert!(!int_data_slice.ptr_eq(&int_data)); + + let data_buffer = Buffer::from_slice_ref(&"abcdef".as_bytes()); + let offsets_buffer = Buffer::from_slice_ref(&[0_i32, 2_i32, 2_i32, 5_i32]); + let string_data = ArrayData::try_new( + DataType::Utf8, + 3, + Some(Buffer::from_iter(vec![true, false, true])), + 0, + vec![offsets_buffer, data_buffer], + vec![], + ) + .unwrap(); + + assert_ne!(float_data, string_data); + assert!(!float_data.ptr_eq(&string_data)); + + assert!(string_data.ptr_eq(&string_data)); + + let string_data_cloned = string_data.clone(); + assert!(string_data_cloned.ptr_eq(&string_data)); + assert!(string_data.ptr_eq(&string_data_cloned)); + + let string_data_slice = string_data.slice(1, 2); + assert!(string_data_slice.ptr_eq(&string_data_slice)); + assert!(!string_data_slice.ptr_eq(&string_data)) } #[test] @@ -802,4 +1679,1098 @@ mod tests { let count = count_nulls(null_buffer.as_ref(), 4, 8); assert_eq!(count, 3); } + + #[test] + #[should_panic( + expected = "Need at least 80 bytes in buffers[0] in array of type Int64, but got 8" + )] + fn test_buffer_too_small() { + let buffer = Buffer::from_slice_ref(&[0i32, 2i32]); + // should fail as the declared size (10*8 = 80) is larger than the underlying bfufer (8) + ArrayData::try_new(DataType::Int64, 10, None, 0, vec![buffer], vec![]).unwrap(); + } + + #[test] + #[should_panic( + expected = "Need at least 16 bytes in buffers[0] in array of type Int64, but got 8" + )] + fn test_buffer_too_small_offset() { + let buffer = Buffer::from_slice_ref(&[0i32, 2i32]); + // should fail -- size is ok, but also has offset + ArrayData::try_new(DataType::Int64, 1, None, 1, vec![buffer], vec![]).unwrap(); + } + + #[test] + #[should_panic(expected = "Expected 1 buffers in array of type Int64, got 2")] + fn test_bad_number_of_buffers() { + let buffer1 = Buffer::from_slice_ref(&[0i32, 2i32]); + let buffer2 = Buffer::from_slice_ref(&[0i32, 2i32]); + ArrayData::try_new(DataType::Int64, 1, None, 0, vec![buffer1, buffer2], vec![]) + .unwrap(); + } + + #[test] + 
#[should_panic(expected = "integer overflow computing min buffer size")] + fn test_fixed_width_overflow() { + let buffer = Buffer::from_slice_ref(&[0i32, 2i32]); + ArrayData::try_new(DataType::Int64, usize::MAX, None, 0, vec![buffer], vec![]) + .unwrap(); + } + + #[test] + #[should_panic(expected = "null_bit_buffer size too small. got 1 needed 2")] + fn test_bitmap_too_small() { + let buffer = make_i32_buffer(9); + let null_bit_buffer = Buffer::from(vec![0b11111111]); + + ArrayData::try_new( + DataType::Int32, + 9, + Some(null_bit_buffer), + 0, + vec![buffer], + vec![], + ) + .unwrap(); + } + + // Test creating a dictionary with a non integer type + #[test] + #[should_panic(expected = "Dictionary key type must be integer, but was Utf8")] + fn test_non_int_dictionary() { + let i32_buffer = Buffer::from_slice_ref(&[0i32, 2i32]); + let data_type = + DataType::Dictionary(Box::new(DataType::Utf8), Box::new(DataType::Int32)); + let child_data = ArrayData::try_new( + DataType::Int32, + 1, + None, + 0, + vec![i32_buffer.clone()], + vec![], + ) + .unwrap(); + ArrayData::try_new( + data_type, + 1, + None, + 0, + vec![i32_buffer.clone(), i32_buffer], + vec![child_data], + ) + .unwrap(); + } + + #[test] + #[should_panic(expected = "Expected LargeUtf8 but child data had Utf8")] + fn test_mismatched_dictionary_types() { + // test w/ dictionary created with a child array data that has type different than declared + let string_array: StringArray = + vec![Some("foo"), Some("bar")].into_iter().collect(); + let i32_buffer = Buffer::from_slice_ref(&[0i32, 1i32]); + // Dict says LargeUtf8 but array is Utf8 + let data_type = DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::LargeUtf8), + ); + let child_data = string_array.data().clone(); + ArrayData::try_new(data_type, 1, None, 0, vec![i32_buffer], vec![child_data]) + .unwrap(); + } + + #[test] + fn test_empty_utf8_array_with_empty_offsets_buffer() { + let data_buffer = Buffer::from(&[]); + let offsets_buffer = Buffer::from(&[]); + ArrayData::try_new( + DataType::Utf8, + 0, + None, + 0, + vec![offsets_buffer, data_buffer], + vec![], + ) + .unwrap(); + } + + #[test] + fn test_empty_utf8_array_with_single_zero_offset() { + let data_buffer = Buffer::from(&[]); + let offsets_buffer = Buffer::from_slice_ref(&[0i32]); + ArrayData::try_new( + DataType::Utf8, + 0, + None, + 0, + vec![offsets_buffer, data_buffer], + vec![], + ) + .unwrap(); + } + + #[test] + #[should_panic(expected = "First offset 1 of Utf8 is larger than values length 0")] + fn test_empty_utf8_array_with_invalid_offset() { + let data_buffer = Buffer::from(&[]); + let offsets_buffer = Buffer::from_slice_ref(&[1i32]); + ArrayData::try_new( + DataType::Utf8, + 0, + None, + 0, + vec![offsets_buffer, data_buffer], + vec![], + ) + .unwrap(); + } + + #[test] + fn test_empty_utf8_array_with_non_zero_offset() { + let data_buffer = Buffer::from_slice_ref(&"abcdef".as_bytes()); + let offsets_buffer = Buffer::from_slice_ref(&[0i32, 2, 6, 0]); + ArrayData::try_new( + DataType::Utf8, + 0, + None, + 3, + vec![offsets_buffer, data_buffer], + vec![], + ) + .unwrap(); + } + + #[test] + #[should_panic( + expected = "Buffer 0 of LargeUtf8 isn't large enough. 
Expected 8 bytes got 4" + )] + fn test_empty_large_utf8_array_with_wrong_type_offsets() { + let data_buffer = Buffer::from(&[]); + let offsets_buffer = Buffer::from_slice_ref(&[0i32]); + ArrayData::try_new( + DataType::LargeUtf8, + 0, + None, + 0, + vec![offsets_buffer, data_buffer], + vec![], + ) + .unwrap(); + } + + #[test] + #[should_panic( + expected = "Buffer 0 of Utf8 isn't large enough. Expected 12 bytes got 8" + )] + fn test_validate_offsets_i32() { + let data_buffer = Buffer::from_slice_ref(&"abcdef".as_bytes()); + let offsets_buffer = Buffer::from_slice_ref(&[0i32, 2i32]); + ArrayData::try_new( + DataType::Utf8, + 2, + None, + 0, + vec![offsets_buffer, data_buffer], + vec![], + ) + .unwrap(); + } + + #[test] + #[should_panic( + expected = "Buffer 0 of LargeUtf8 isn't large enough. Expected 24 bytes got 16" + )] + fn test_validate_offsets_i64() { + let data_buffer = Buffer::from_slice_ref(&"abcdef".as_bytes()); + let offsets_buffer = Buffer::from_slice_ref(&[0i64, 2i64]); + ArrayData::try_new( + DataType::LargeUtf8, + 2, + None, + 0, + vec![offsets_buffer, data_buffer], + vec![], + ) + .unwrap(); + } + + #[test] + #[should_panic(expected = "Error converting offset[0] (-2) to usize for Utf8")] + fn test_validate_offsets_negative_first_i32() { + let data_buffer = Buffer::from_slice_ref(&"abcdef".as_bytes()); + let offsets_buffer = Buffer::from_slice_ref(&[-2i32, 1i32, 3i32]); + ArrayData::try_new( + DataType::Utf8, + 2, + None, + 0, + vec![offsets_buffer, data_buffer], + vec![], + ) + .unwrap(); + } + + #[test] + #[should_panic(expected = "Error converting offset[2] (-3) to usize for Utf8")] + fn test_validate_offsets_negative_last_i32() { + let data_buffer = Buffer::from_slice_ref(&"abcdef".as_bytes()); + let offsets_buffer = Buffer::from_slice_ref(&[0i32, 2i32, -3i32]); + ArrayData::try_new( + DataType::Utf8, + 2, + None, + 0, + vec![offsets_buffer, data_buffer], + vec![], + ) + .unwrap(); + } + + #[test] + #[should_panic(expected = "First offset 4 in Utf8 is smaller than last offset 3")] + fn test_validate_offsets_range_too_small() { + let data_buffer = Buffer::from_slice_ref(&"abcdef".as_bytes()); + // start offset is larger than end + let offsets_buffer = Buffer::from_slice_ref(&[4i32, 2i32, 3i32]); + ArrayData::try_new( + DataType::Utf8, + 2, + None, + 0, + vec![offsets_buffer, data_buffer], + vec![], + ) + .unwrap(); + } + + #[test] + #[should_panic(expected = "Last offset 10 of Utf8 is larger than values length 6")] + fn test_validate_offsets_range_too_large() { + let data_buffer = Buffer::from_slice_ref(&"abcdef".as_bytes()); + // 10 is off the end of the buffer + let offsets_buffer = Buffer::from_slice_ref(&[0i32, 2i32, 10i32]); + ArrayData::try_new( + DataType::Utf8, + 2, + None, + 0, + vec![offsets_buffer, data_buffer], + vec![], + ) + .unwrap(); + } + + #[test] + #[should_panic(expected = "First offset 10 of Utf8 is larger than values length 6")] + fn test_validate_offsets_first_too_large() { + let data_buffer = Buffer::from_slice_ref(&"abcdef".as_bytes()); + // 10 is off the end of the buffer + let offsets_buffer = Buffer::from_slice_ref(&[10i32, 2i32, 10i32]); + ArrayData::try_new( + DataType::Utf8, + 2, + None, + 0, + vec![offsets_buffer, data_buffer], + vec![], + ) + .unwrap(); + } + + #[test] + fn test_validate_offsets_first_too_large_skipped() { + let data_buffer = Buffer::from_slice_ref(&"abcdef".as_bytes()); + // 10 is off the end of the buffer, but offset starts at 1 so it is skipped + let offsets_buffer = Buffer::from_slice_ref(&[10i32, 2i32, 3i32, 4i32]); + 
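+ // with `offset == 1` and `len == 2`, validation only reads offsets[1..=3] = [2, 3, 4], + // so the out-of-range 10 at slot 0 is never examined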
let data = ArrayData::try_new( + DataType::Utf8, + 2, + None, + 1, + vec![offsets_buffer, data_buffer], + vec![], + ) + .unwrap(); + let array: StringArray = data.into(); + let expected: StringArray = vec![Some("c"), Some("d")].into_iter().collect(); + assert_eq!(array, expected); + } + + #[test] + #[should_panic(expected = "Last offset 8 of Utf8 is larger than values length 6")] + fn test_validate_offsets_last_too_large() { + let data_buffer = Buffer::from_slice_ref(&"abcdef".as_bytes()); + // 10 is off the end of the buffer + let offsets_buffer = Buffer::from_slice_ref(&[5i32, 7i32, 8i32]); + ArrayData::try_new( + DataType::Utf8, + 2, + None, + 0, + vec![offsets_buffer, data_buffer], + vec![], + ) + .unwrap(); + } + + #[test] + #[should_panic( + expected = "Values length 4 is less than the length (2) multiplied by the value size (2) for FixedSizeList" + )] + fn test_validate_fixed_size_list() { + // child has 4 elements, + let child_array = vec![Some(1), Some(2), Some(3), None] + .into_iter() + .collect::(); + + // but claim we have 3 elements for a fixed size of 2 + // 10 is off the end of the buffer + let field = Field::new("field", DataType::Int32, true); + ArrayData::try_new( + DataType::FixedSizeList(Box::new(field), 2), + 3, + None, + 0, + vec![], + vec![child_array.data().clone()], + ) + .unwrap(); + } + + #[test] + #[should_panic(expected = "Child type mismatch for Struct")] + fn test_validate_struct_child_type() { + let field1 = vec![Some(1), Some(2), Some(3), None] + .into_iter() + .collect::(); + + // validate the the type of struct fields matches child fields + ArrayData::try_new( + DataType::Struct(vec![Field::new("field1", DataType::Int64, true)]), + 3, + None, + 0, + vec![], + vec![field1.data().clone()], + ) + .unwrap(); + } + + #[test] + #[should_panic( + expected = "child array #0 for field field1 has length smaller than expected for struct array (4 < 6)" + )] + fn test_validate_struct_child_length() { + // field length only has 4 items, but array claims to have 6 + let field1 = vec![Some(1), Some(2), Some(3), None] + .into_iter() + .collect::(); + + ArrayData::try_new( + DataType::Struct(vec![Field::new("field1", DataType::Int32, true)]), + 6, + None, + 0, + vec![], + vec![field1.data().clone()], + ) + .unwrap(); + } + + /// Test that the array of type `data_type` that has invalid utf8 data errors + fn check_utf8_validation(data_type: DataType) { + // 0x80 is a utf8 continuation sequence and is not a valid utf8 sequence itself + let data_buffer = Buffer::from_slice_ref(&[b'a', b'a', 0x80, 0x00]); + let offsets: Vec = [0, 2, 3] + .iter() + .map(|&v| T::from_usize(v).unwrap()) + .collect(); + + let offsets_buffer = Buffer::from_slice_ref(&offsets); + ArrayData::try_new( + data_type, + 2, + None, + 0, + vec![offsets_buffer, data_buffer], + vec![], + ) + .unwrap(); + } + + #[test] + #[should_panic(expected = "Invalid UTF8 sequence at string index 1 (2..3)")] + fn test_validate_utf8_content() { + check_utf8_validation::(DataType::Utf8); + } + + #[test] + #[should_panic(expected = "Invalid UTF8 sequence at string index 1 (2..3)")] + fn test_validate_large_utf8_content() { + check_utf8_validation::(DataType::LargeUtf8); + } + + /// Test that the array of type `data_type` that has invalid indexes (out of bounds) + fn check_index_out_of_bounds_validation(data_type: DataType) { + let data_buffer = Buffer::from_slice_ref(&[b'a', b'b', b'c', b'd']); + // First two offsets are fine, then 5 is out of bounds + let offsets: Vec = [0, 1, 2, 5, 2] + .iter() + .map(|&v| 
T::from_usize(v).unwrap()) + .collect(); + + let offsets_buffer = Buffer::from_slice_ref(&offsets); + ArrayData::try_new( + data_type, + 4, + None, + 0, + vec![offsets_buffer, data_buffer], + vec![], + ) + .unwrap(); + } + + #[test] + #[should_panic( + expected = "Offset invariant failure: offset at position 3 out of bounds: 5 > 4" + )] + fn test_validate_utf8_out_of_bounds() { + check_index_out_of_bounds_validation::(DataType::Utf8); + } + + #[test] + #[should_panic( + expected = "Offset invariant failure: offset at position 3 out of bounds: 5 > 4" + )] + fn test_validate_large_utf8_out_of_bounds() { + check_index_out_of_bounds_validation::(DataType::LargeUtf8); + } + + #[test] + #[should_panic( + expected = "Offset invariant failure: offset at position 3 out of bounds: 5 > 4" + )] + fn test_validate_binary_out_of_bounds() { + check_index_out_of_bounds_validation::(DataType::Binary); + } + + #[test] + #[should_panic( + expected = "Offset invariant failure: offset at position 3 out of bounds: 5 > 4" + )] + fn test_validate_large_binary_out_of_bounds() { + check_index_out_of_bounds_validation::(DataType::LargeBinary); + } + + // validate that indexes don't go bacwards check indexes that go backwards + fn check_index_backwards_validation(data_type: DataType) { + let data_buffer = Buffer::from_slice_ref(&[b'a', b'b', b'c', b'd']); + // First three offsets are fine, then 1 goes backwards + let offsets: Vec = [0, 1, 2, 2, 1] + .iter() + .map(|&v| T::from_usize(v).unwrap()) + .collect(); + + let offsets_buffer = Buffer::from_slice_ref(&offsets); + ArrayData::try_new( + data_type, + 4, + None, + 0, + vec![offsets_buffer, data_buffer], + vec![], + ) + .unwrap(); + } + + #[test] + #[should_panic( + expected = "Offset invariant failure: non-monotonic offset at slot 3: 2 > 1" + )] + fn test_validate_utf8_index_backwards() { + check_index_backwards_validation::(DataType::Utf8); + } + + #[test] + #[should_panic( + expected = "Offset invariant failure: non-monotonic offset at slot 3: 2 > 1" + )] + fn test_validate_large_utf8_index_backwards() { + check_index_backwards_validation::(DataType::LargeUtf8); + } + + #[test] + #[should_panic( + expected = "Offset invariant failure: non-monotonic offset at slot 3: 2 > 1" + )] + fn test_validate_binary_index_backwards() { + check_index_backwards_validation::(DataType::Binary); + } + + #[test] + #[should_panic( + expected = "Offset invariant failure: non-monotonic offset at slot 3: 2 > 1" + )] + fn test_validate_large_binary_index_backwards() { + check_index_backwards_validation::(DataType::LargeBinary); + } + + #[test] + #[should_panic( + expected = "Value at position 1 out of bounds: 3 (should be in [0, 1])" + )] + fn test_validate_dictionary_index_too_large() { + let values: StringArray = [Some("foo"), Some("bar")].into_iter().collect(); + + // 3 is not a valid index into the values (only 0 and 1) + let keys: Int32Array = [Some(1), Some(3)].into_iter().collect(); + + let data_type = DataType::Dictionary( + Box::new(keys.data_type().clone()), + Box::new(values.data_type().clone()), + ); + + ArrayData::try_new( + data_type, + 2, + None, + 0, + vec![keys.data().buffers[0].clone()], + vec![values.data().clone()], + ) + .unwrap(); + } + + #[test] + #[should_panic( + expected = "Value at position 1 out of bounds: -1 (should be in [0, 1]" + )] + fn test_validate_dictionary_index_negative() { + let values: StringArray = [Some("foo"), Some("bar")].into_iter().collect(); + + // -1 is not a valid index at all! 
+ let keys: Int32Array = [Some(1), Some(-1)].into_iter().collect(); + + let data_type = DataType::Dictionary( + Box::new(keys.data_type().clone()), + Box::new(values.data_type().clone()), + ); + + ArrayData::try_new( + data_type, + 2, + None, + 0, + vec![keys.data().buffers[0].clone()], + vec![values.data().clone()], + ) + .unwrap(); + } + + #[test] + fn test_validate_dictionary_index_negative_but_not_referenced() { + let values: StringArray = [Some("foo"), Some("bar")].into_iter().collect(); + + // -1 is not a valid index at all, but the array is length 1 + // so the -1 should not be looked at + let keys: Int32Array = [Some(1), Some(-1)].into_iter().collect(); + + let data_type = DataType::Dictionary( + Box::new(keys.data_type().clone()), + Box::new(values.data_type().clone()), + ); + + // Expect this not to panic + ArrayData::try_new( + data_type, + 1, + None, + 0, + vec![keys.data().buffers[0].clone()], + vec![values.data().clone()], + ) + .unwrap(); + } + + #[test] + #[should_panic( + expected = "Value at position 0 out of bounds: 18446744073709551615 (can not convert to i64)" + )] + fn test_validate_dictionary_index_giant_negative() { + let values: StringArray = [Some("foo"), Some("bar")].into_iter().collect(); + + // -1 is not a valid index at all! + let keys: UInt64Array = [Some(u64::MAX), Some(1)].into_iter().collect(); + + let data_type = DataType::Dictionary( + Box::new(keys.data_type().clone()), + Box::new(values.data_type().clone()), + ); + + ArrayData::try_new( + data_type, + 2, + None, + 0, + vec![keys.data().buffers[0].clone()], + vec![values.data().clone()], + ) + .unwrap(); + } + + /// Test that the list of type `data_type` generates correct offset out of bounds errors + fn check_list_offsets(data_type: DataType) { + let values: Int32Array = + [Some(1), Some(2), Some(3), Some(4)].into_iter().collect(); + + // 5 is an invalid offset into a list of only three values + let offsets: Vec = [0, 2, 5, 4] + .iter() + .map(|&v| T::from_usize(v).unwrap()) + .collect(); + let offsets_buffer = Buffer::from_slice_ref(&offsets); + + ArrayData::try_new( + data_type, + 3, + None, + 0, + vec![offsets_buffer], + vec![values.data().clone()], + ) + .unwrap(); + } + + #[test] + #[should_panic( + expected = "Offset invariant failure: offset at position 2 out of bounds: 5 > 4" + )] + fn test_validate_list_offsets() { + let field_type = Field::new("f", DataType::Int32, true); + check_list_offsets::(DataType::List(Box::new(field_type))); + } + + #[test] + #[should_panic( + expected = "Offset invariant failure: offset at position 2 out of bounds: 5 > 4" + )] + fn test_validate_large_list_offsets() { + let field_type = Field::new("f", DataType::Int32, true); + check_list_offsets::(DataType::LargeList(Box::new(field_type))); + } + + /// Test that the list of type `data_type` generates correct errors for negative offsets + #[test] + #[should_panic( + expected = "Offset invariant failure: Could not convert offset -1 to usize at position 2" + )] + fn test_validate_list_negative_offsets() { + let values: Int32Array = + [Some(1), Some(2), Some(3), Some(4)].into_iter().collect(); + let field_type = Field::new("f", values.data_type().clone(), true); + let data_type = DataType::List(Box::new(field_type)); + + // -1 is an invalid offset any way you look at it + let offsets: Vec = vec![0, 2, -1, 4]; + let offsets_buffer = Buffer::from_slice_ref(&offsets); + + ArrayData::try_new( + data_type, + 3, + None, + 0, + vec![offsets_buffer], + vec![values.data().clone()], + ) + .unwrap(); + } + + #[test] + 
#[should_panic( + expected = "Value at position 1 out of bounds: -1 (should be in [0, 1])" + )] + /// test that children are validated recursively (aka bugs in child data of struct also are flagged) + fn test_validate_recursive() { + // Form invalid dictionary array + let values: StringArray = [Some("foo"), Some("bar")].into_iter().collect(); + // -1 is not a valid index + let keys: Int32Array = [Some(1), Some(-1), Some(1)].into_iter().collect(); + + let dict_data_type = DataType::Dictionary( + Box::new(keys.data_type().clone()), + Box::new(values.data_type().clone()), + ); + + // purposely create an invalid child data + let dict_data = unsafe { + ArrayData::new_unchecked( + dict_data_type, + 2, + None, + None, + 0, + vec![keys.data().buffers[0].clone()], + vec![values.data().clone()], + ) + }; + + // Now, try and create a struct with this invalid child data (and expect an error) + let data_type = + DataType::Struct(vec![Field::new("d", dict_data.data_type().clone(), true)]); + + ArrayData::try_new(data_type, 1, None, 0, vec![], vec![dict_data]).unwrap(); + } + + /// returns a buffer initialized with some constant value for tests + fn make_i32_buffer(n: usize) -> Buffer { + Buffer::from_slice_ref(&vec![42i32; n]) + } + + /// returns a buffer initialized with some constant value for tests + fn make_f32_buffer(n: usize) -> Buffer { + Buffer::from_slice_ref(&vec![42f32; n]) + } + + #[test] + #[should_panic(expected = "Expected Int64 but child data had Int32")] + fn test_validate_union_different_types() { + let field1 = vec![Some(1), Some(2)].into_iter().collect::(); + + let field2 = vec![Some(1), Some(2)].into_iter().collect::(); + + let type_ids = Buffer::from_slice_ref(&[0i8, 1i8]); + + ArrayData::try_new( + DataType::Union( + vec![ + Field::new("field1", DataType::Int32, true), + Field::new("field2", DataType::Int64, true), // data is int32 + ], + vec![0, 1], + UnionMode::Sparse, + ), + 2, + None, + 0, + vec![type_ids], + vec![field1.data().clone(), field2.data().clone()], + ) + .unwrap(); + } + + // sparse with wrong sized children + #[test] + #[should_panic( + expected = "Sparse union child array #1 has length smaller than expected for union array (1 < 2)" + )] + fn test_validate_union_sparse_different_child_len() { + let field1 = vec![Some(1), Some(2)].into_iter().collect::(); + + // field 2 only has 1 item but array should have 2 + let field2 = vec![Some(1)].into_iter().collect::(); + + let type_ids = Buffer::from_slice_ref(&[0i8, 1i8]); + + ArrayData::try_new( + DataType::Union( + vec![ + Field::new("field1", DataType::Int32, true), + Field::new("field2", DataType::Int64, true), + ], + vec![0, 1], + UnionMode::Sparse, + ), + 2, + None, + 0, + vec![type_ids], + vec![field1.data().clone(), field2.data().clone()], + ) + .unwrap(); + } + + #[test] + #[should_panic(expected = "Expected 2 buffers in array of type Union")] + fn test_validate_union_dense_without_offsets() { + let field1 = vec![Some(1), Some(2)].into_iter().collect::(); + + let field2 = vec![Some(1)].into_iter().collect::(); + + let type_ids = Buffer::from_slice_ref(&[0i8, 1i8]); + + ArrayData::try_new( + DataType::Union( + vec![ + Field::new("field1", DataType::Int32, true), + Field::new("field2", DataType::Int64, true), + ], + vec![0, 1], + UnionMode::Dense, + ), + 2, + None, + 0, + vec![type_ids], // need offsets buffer here too + vec![field1.data().clone(), field2.data().clone()], + ) + .unwrap(); + } + + #[test] + #[should_panic( + expected = "Need at least 8 bytes in buffers[1] in array of type Union" + )] + fn 
test_validate_union_dense_with_bad_len() { + let field1 = vec![Some(1), Some(2)].into_iter().collect::(); + + let field2 = vec![Some(1)].into_iter().collect::(); + + let type_ids = Buffer::from_slice_ref(&[0i8, 1i8]); + let offsets = Buffer::from_slice_ref(&[0i32]); // should have 2 offsets, but only have 1 + + ArrayData::try_new( + DataType::Union( + vec![ + Field::new("field1", DataType::Int32, true), + Field::new("field2", DataType::Int64, true), + ], + vec![0, 1], + UnionMode::Dense, + ), + 2, + None, + 0, + vec![type_ids, offsets], + vec![field1.data().clone(), field2.data().clone()], + ) + .unwrap(); + } + + #[test] + fn test_try_new_sliced_struct() { + let mut builder = StructBuilder::new( + vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Boolean, true), + ], + vec![ + Box::new(Int32Builder::new(5)), + Box::new(BooleanBuilder::new(5)), + ], + ); + + // struct[0] = { a: 10, b: true } + builder + .field_builder::(0) + .unwrap() + .append_option(Some(10)) + .unwrap(); + builder + .field_builder::(1) + .unwrap() + .append_option(Some(true)) + .unwrap(); + builder.append(true).unwrap(); + + // struct[1] = null + builder + .field_builder::(0) + .unwrap() + .append_option(None) + .unwrap(); + builder + .field_builder::(1) + .unwrap() + .append_option(None) + .unwrap(); + builder.append(false).unwrap(); + + // struct[2] = { a: null, b: false } + builder + .field_builder::(0) + .unwrap() + .append_option(None) + .unwrap(); + builder + .field_builder::(1) + .unwrap() + .append_option(Some(false)) + .unwrap(); + builder.append(true).unwrap(); + + // struct[3] = { a: 21, b: null } + builder + .field_builder::(0) + .unwrap() + .append_option(Some(21)) + .unwrap(); + builder + .field_builder::(1) + .unwrap() + .append_option(None) + .unwrap(); + builder.append(true).unwrap(); + + // struct[4] = { a: 18, b: false } + builder + .field_builder::(0) + .unwrap() + .append_option(Some(18)) + .unwrap(); + builder + .field_builder::(1) + .unwrap() + .append_option(Some(false)) + .unwrap(); + builder.append(true).unwrap(); + + let struct_array = builder.finish(); + let struct_array_slice = struct_array.slice(1, 3); + let struct_array_data = struct_array_slice.data(); + + let cloned_data = ArrayData::try_new( + struct_array_slice.data_type().clone(), + struct_array_slice.len(), + struct_array_data.null_buffer().cloned(), + struct_array_slice.offset(), + struct_array_data.buffers().to_vec(), + struct_array_data.child_data().to_vec(), + ) + .unwrap(); + let cloned = crate::array::make_array(cloned_data); + + assert_eq!(&struct_array_slice, &cloned); + } + + #[test] + fn test_into_buffers() { + let data_types = vec![ + DataType::Union(vec![], vec![], UnionMode::Dense), + DataType::Union(vec![], vec![], UnionMode::Sparse), + ]; + + for data_type in data_types { + let buffers = new_buffers(&data_type, 0); + let [buffer1, buffer2] = buffers; + let buffers = into_buffers(&data_type, buffer1, buffer2); + + let layout = layout(&data_type); + assert_eq!(buffers.len(), layout.buffers.len()); + } + } + + #[test] + fn test_string_data_from_foreign() { + let mut strings = "foobarfoobar".to_owned(); + let mut offsets = vec![0_i32, 0, 3, 6, 12]; + let mut bitmap = vec![0b1110_u8]; + + let strings_buffer = unsafe { + Buffer::from_custom_allocation( + NonNull::new_unchecked(strings.as_mut_ptr()), + strings.len(), + Arc::new(strings), + ) + }; + let offsets_buffer = unsafe { + Buffer::from_custom_allocation( + NonNull::new_unchecked(offsets.as_mut_ptr() as *mut u8), + offsets.len() * 
std::mem::size_of::(), + Arc::new(offsets), + ) + }; + let null_buffer = unsafe { + Buffer::from_custom_allocation( + NonNull::new_unchecked(bitmap.as_mut_ptr()), + bitmap.len(), + Arc::new(bitmap), + ) + }; + + let data = ArrayData::try_new( + DataType::Utf8, + 4, + Some(null_buffer), + 0, + vec![offsets_buffer, strings_buffer], + vec![], + ) + .unwrap(); + + let array = make_array(data); + let array = array.as_any().downcast_ref::().unwrap(); + + let expected = + StringArray::from(vec![None, Some("foo"), Some("bar"), Some("foobar")]); + + assert_eq!(array, &expected); + } + + #[test] + #[cfg(not(feature = "force_validate"))] + fn test_decimal_full_validation() { + let values_builder = UInt8Builder::new(10); + let byte_width = 16; + let mut fixed_size_builder = + FixedSizeListBuilder::new(values_builder, byte_width); + let value_as_bytes = DecimalBuilder::from_i128_to_fixed_size_bytes( + 123456, + fixed_size_builder.value_length() as usize, + ) + .unwrap(); + fixed_size_builder + .values() + .append_slice(value_as_bytes.as_slice()) + .unwrap(); + fixed_size_builder.append(true).unwrap(); + let fixed_size_array = fixed_size_builder.finish(); + + // Build ArrayData for Decimal + let builder = ArrayData::builder(DataType::Decimal(5, 3)) + .len(fixed_size_array.len()) + .add_buffer(fixed_size_array.data_ref().child_data()[0].buffers()[0].clone()); + let array_data = unsafe { builder.build_unchecked() }; + let validation_result = array_data.validate_full(); + let error = validation_result.unwrap_err(); + assert_eq!( + "Invalid argument error: 123456 is too large to store in a Decimal of precision 5. Max is 99999", + error.to_string() + ); + } + + #[test] + fn test_decimal_validation() { + let mut builder = DecimalBuilder::new(4, 10, 4); + builder.append_value(10000).unwrap(); + builder.append_value(20000).unwrap(); + let array = builder.finish(); + + array.data().validate_full().unwrap(); + } + + #[test] + #[cfg(not(feature = "force_validate"))] + fn test_sliced_array_child() { + let values = Int32Array::from_iter_values([1, 2, 3]); + let values_sliced = values.slice(1, 2); + let offsets = Buffer::from_iter([1_i32, 3_i32]); + + let list_field = Field::new("element", DataType::Int32, false); + let data_type = DataType::List(Box::new(list_field)); + + let data = unsafe { + ArrayData::new_unchecked( + data_type, + 1, + None, + None, + 0, + vec![offsets], + vec![values_sliced.data().clone()], + ) + }; + + let err = data.validate_values().unwrap_err(); + assert_eq!(err.to_string(), "Invalid argument error: Offset invariant failure: offset at position 1 out of bounds: 3 > 2"); + } } diff --git a/arrow/src/array/equal/boolean.rs b/arrow/src/array/equal/boolean.rs index 35c9786e49f9..de34d7fab189 100644 --- a/arrow/src/array/equal/boolean.rs +++ b/arrow/src/array/equal/boolean.rs @@ -16,7 +16,6 @@ // under the License. 
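The equality helpers below stop taking the null buffers as separate parameters and instead read them from each `ArrayData`, adding the array's own offset when counting nulls. A self-contained sketch of that counting step (bits are least-significant-first as in the Arrow spec; this is not the crate's `count_nulls` itself):

// Counts the zero (null) bits of a validity bitmap within [start, start + len),
// the quantity the reworked equality helpers compute from the ArrayData's own
// null buffer and offset.
fn count_nulls_in_range(bitmap: &[u8], start: usize, len: usize) -> usize {
    (start..start + len)
        .filter(|&i| (bitmap[i / 8] & (1u8 << (i % 8))) == 0) // unset bit => null
        .count()
}

fn main() {
    // bits 0, 3 and 10 are set (valid); the other 13 of the first 16 slots are null
    let bitmap = [0b0000_1001u8, 0b0000_0100u8];
    assert_eq!(count_nulls_in_range(&bitmap, 0, 16), 13);
    // a window of 8 slots starting at slot 4 only contains the valid bit 10
    assert_eq!(count_nulls_in_range(&bitmap, 4, 8), 7);
}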
use crate::array::{data::count_nulls, ArrayData}; -use crate::buffer::Buffer; use crate::util::bit_util::get_bit; use super::utils::{equal_bits, equal_len}; @@ -24,8 +23,6 @@ use super::utils::{equal_bits, equal_len}; pub(super) fn boolean_equal( lhs: &ArrayData, rhs: &ArrayData, - lhs_nulls: Option<&Buffer>, - rhs_nulls: Option<&Buffer>, mut lhs_start: usize, mut rhs_start: usize, mut len: usize, @@ -33,8 +30,8 @@ pub(super) fn boolean_equal( let lhs_values = lhs.buffers()[0].as_slice(); let rhs_values = rhs.buffers()[0].as_slice(); - let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len); - let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len); + let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); + let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); if lhs_null_count == 0 && rhs_null_count == 0 { // Optimize performance for starting offset at u8 boundary. @@ -73,8 +70,8 @@ pub(super) fn boolean_equal( ) } else { // get a ref of the null buffer bytes, to use in testing for nullness - let lhs_null_bytes = lhs_nulls.as_ref().unwrap().as_slice(); - let rhs_null_bytes = rhs_nulls.as_ref().unwrap().as_slice(); + let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice(); + let rhs_null_bytes = rhs.null_buffer().as_ref().unwrap().as_slice(); let lhs_start = lhs.offset() + lhs_start; let rhs_start = rhs.offset() + rhs_start; diff --git a/arrow/src/array/equal/decimal.rs b/arrow/src/array/equal/decimal.rs index 1ee6ec9b5436..e9879f3f281e 100644 --- a/arrow/src/array/equal/decimal.rs +++ b/arrow/src/array/equal/decimal.rs @@ -16,7 +16,6 @@ // under the License. use crate::array::{data::count_nulls, ArrayData}; -use crate::buffer::Buffer; use crate::datatypes::DataType; use crate::util::bit_util::get_bit; @@ -25,8 +24,6 @@ use super::utils::equal_len; pub(super) fn decimal_equal( lhs: &ArrayData, rhs: &ArrayData, - lhs_nulls: Option<&Buffer>, - rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, @@ -39,8 +36,8 @@ pub(super) fn decimal_equal( let lhs_values = &lhs.buffers()[0].as_slice()[lhs.offset() * size..]; let rhs_values = &rhs.buffers()[0].as_slice()[rhs.offset() * size..]; - let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len); - let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len); + let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); + let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); if lhs_null_count == 0 && rhs_null_count == 0 { equal_len( @@ -52,8 +49,8 @@ pub(super) fn decimal_equal( ) } else { // get a ref of the null buffer bytes, to use in testing for nullness - let lhs_null_bytes = lhs_nulls.as_ref().unwrap().as_slice(); - let rhs_null_bytes = rhs_nulls.as_ref().unwrap().as_slice(); + let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice(); + let rhs_null_bytes = rhs.null_buffer().as_ref().unwrap().as_slice(); // with nulls, we need to compare item by item whenever it is not null (0..len).all(|i| { let lhs_pos = lhs_start + i; diff --git a/arrow/src/array/equal/dictionary.rs b/arrow/src/array/equal/dictionary.rs index 22add2494d2b..4c9bcf798760 100644 --- a/arrow/src/array/equal/dictionary.rs +++ b/arrow/src/array/equal/dictionary.rs @@ -16,7 +16,6 @@ // under the License. 
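For the dictionary case, equality is defined on the values the keys point at, so two arrays with different dictionaries can still compare equal; `dictionary_equal` below does this one slot at a time via `equal_range`. A toy illustration of that rule using plain slices (names and data are made up, not the crate's representation):

// Two dictionary-encoded columns are logically equal when, slot by slot, either
// both entries are null or the values their keys point at are equal.
fn dict_logically_equal(
    lhs_keys: &[usize], lhs_values: &[&str], lhs_valid: &[bool],
    rhs_keys: &[usize], rhs_values: &[&str], rhs_valid: &[bool],
) -> bool {
    lhs_keys.len() == rhs_keys.len()
        && (0..lhs_keys.len()).all(|i| match (lhs_valid[i], rhs_valid[i]) {
            (false, false) => true, // both null
            (true, true) => lhs_values[lhs_keys[i]] == rhs_values[rhs_keys[i]],
            _ => false, // null vs non-null
        })
}

fn main() {
    // same logical contents encoded against different dictionaries
    let lhs_values = ["foo", "bar"];
    let rhs_values = ["bar", "foo"];
    assert!(dict_logically_equal(
        &[0, 1, 0], &lhs_values, &[true, true, false],
        &[1, 0, 1], &rhs_values, &[true, true, false],
    ));
}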
use crate::array::{data::count_nulls, ArrayData}; -use crate::buffer::Buffer; use crate::datatypes::ArrowNativeType; use crate::util::bit_util::get_bit; @@ -25,8 +24,6 @@ use super::equal_range; pub(super) fn dictionary_equal<T: ArrowNativeType>( lhs: &ArrayData, rhs: &ArrayData, - lhs_nulls: Option<&Buffer>, - rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, @@ -37,8 +34,8 @@ pub(super) fn dictionary_equal<T: ArrowNativeType>( let lhs_values = &lhs.child_data()[0]; let rhs_values = &rhs.child_data()[0]; - let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len); - let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len); + let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); + let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); if lhs_null_count == 0 && rhs_null_count == 0 { (0..len).all(|i| { @@ -48,8 +45,6 @@ pub(super) fn dictionary_equal<T: ArrowNativeType>( equal_range( lhs_values, rhs_values, - lhs_values.null_buffer(), - rhs_values.null_buffer(), lhs_keys[lhs_pos].to_usize().unwrap(), rhs_keys[rhs_pos].to_usize().unwrap(), 1, @@ -57,8 +52,8 @@ pub(super) fn dictionary_equal<T: ArrowNativeType>( }) } else { // get a ref of the null buffer bytes, to use in testing for nullness - let lhs_null_bytes = lhs_nulls.as_ref().unwrap().as_slice(); - let rhs_null_bytes = rhs_nulls.as_ref().unwrap().as_slice(); + let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice(); + let rhs_null_bytes = rhs.null_buffer().as_ref().unwrap().as_slice(); (0..len).all(|i| { let lhs_pos = lhs_start + i; let rhs_pos = rhs_start + i; @@ -71,8 +66,6 @@ pub(super) fn dictionary_equal<T: ArrowNativeType>( && equal_range( lhs_values, rhs_values, - lhs_values.null_buffer(), - rhs_values.null_buffer(), lhs_keys[lhs_pos].to_usize().unwrap(), rhs_keys[rhs_pos].to_usize().unwrap(), 1, diff --git a/arrow/src/array/equal/fixed_binary.rs b/arrow/src/array/equal/fixed_binary.rs index 5f8f93232d53..aea0e08a9ebf 100644 --- a/arrow/src/array/equal/fixed_binary.rs +++ b/arrow/src/array/equal/fixed_binary.rs @@ -16,7 +16,6 @@ // under the License.
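The dictionary comparison above resolves each key into the child values array and then compares a one-element range of values, so two dictionary arrays can be logically equal even when their raw keys differ. A toy illustration of that idea over plain slices rather than `ArrayData` (names are illustrative):

```rust
// Two dictionary-encoded slots are logically equal when the values their keys
// resolve to are equal, regardless of the key values themselves.
fn dict_slots_equal(
    lhs_keys: &[usize],
    lhs_dict: &[&str],
    rhs_keys: &[usize],
    rhs_dict: &[&str],
    lhs_pos: usize,
    rhs_pos: usize,
) -> bool {
    lhs_dict[lhs_keys[lhs_pos]] == rhs_dict[rhs_keys[rhs_pos]]
}
```

For example, key `0` over the dictionary `["a", "b"]` and key `1` over `["b", "a"]` both resolve to `"a"`, so the corresponding slots compare equal.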
use crate::array::{data::count_nulls, ArrayData}; -use crate::buffer::Buffer; use crate::datatypes::DataType; use crate::util::bit_util::get_bit; @@ -25,8 +24,6 @@ use super::utils::equal_len; pub(super) fn fixed_binary_equal( lhs: &ArrayData, rhs: &ArrayData, - lhs_nulls: Option<&Buffer>, - rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, @@ -39,8 +36,8 @@ pub(super) fn fixed_binary_equal( let lhs_values = &lhs.buffers()[0].as_slice()[lhs.offset() * size..]; let rhs_values = &rhs.buffers()[0].as_slice()[rhs.offset() * size..]; - let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len); - let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len); + let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); + let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); if lhs_null_count == 0 && rhs_null_count == 0 { equal_len( @@ -52,8 +49,8 @@ pub(super) fn fixed_binary_equal( ) } else { // get a ref of the null buffer bytes, to use in testing for nullness - let lhs_null_bytes = lhs_nulls.as_ref().unwrap().as_slice(); - let rhs_null_bytes = rhs_nulls.as_ref().unwrap().as_slice(); + let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice(); + let rhs_null_bytes = rhs.null_buffer().as_ref().unwrap().as_slice(); // with nulls, we need to compare item by item whenever it is not null (0..len).all(|i| { let lhs_pos = lhs_start + i; diff --git a/arrow/src/array/equal/fixed_list.rs b/arrow/src/array/equal/fixed_list.rs index e708a06efcdb..82a347c86574 100644 --- a/arrow/src/array/equal/fixed_list.rs +++ b/arrow/src/array/equal/fixed_list.rs @@ -16,7 +16,6 @@ // under the License. use crate::array::{data::count_nulls, ArrayData}; -use crate::buffer::Buffer; use crate::datatypes::DataType; use crate::util::bit_util::get_bit; @@ -25,8 +24,6 @@ use super::equal_range; pub(super) fn fixed_list_equal( lhs: &ArrayData, rhs: &ArrayData, - lhs_nulls: Option<&Buffer>, - rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, @@ -39,23 +36,21 @@ pub(super) fn fixed_list_equal( let lhs_values = &lhs.child_data()[0]; let rhs_values = &rhs.child_data()[0]; - let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len); - let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len); + let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); + let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); if lhs_null_count == 0 && rhs_null_count == 0 { equal_range( lhs_values, rhs_values, - lhs_values.null_buffer(), - rhs_values.null_buffer(), - size * lhs_start, - size * rhs_start, + (lhs_start + lhs.offset()) * size, + (rhs_start + rhs.offset()) * size, size * len, ) } else { // get a ref of the null buffer bytes, to use in testing for nullness - let lhs_null_bytes = lhs_nulls.as_ref().unwrap().as_slice(); - let rhs_null_bytes = rhs_nulls.as_ref().unwrap().as_slice(); + let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice(); + let rhs_null_bytes = rhs.null_buffer().as_ref().unwrap().as_slice(); // with nulls, we need to compare item by item whenever it is not null (0..len).all(|i| { let lhs_pos = lhs_start + i; @@ -69,10 +64,8 @@ pub(super) fn fixed_list_equal( && equal_range( lhs_values, rhs_values, - lhs_values.null_buffer(), - rhs_values.null_buffer(), - lhs_pos * size, - rhs_pos * size, + (lhs_pos + lhs.offset()) * size, + (rhs_pos + rhs.offset()) * size, size, // 1 * size since we are comparing a single entry ) }) diff --git 
a/arrow/src/array/equal/list.rs b/arrow/src/array/equal/list.rs index 20e6400d9520..0feefa7aa11a 100644 --- a/arrow/src/array/equal/list.rs +++ b/arrow/src/array/equal/list.rs @@ -18,11 +18,10 @@ use crate::{ array::ArrayData, array::{data::count_nulls, OffsetSizeTrait}, - buffer::Buffer, util::bit_util::get_bit, }; -use super::{equal_range, utils::child_logical_null_buffer}; +use super::equal_range; fn lengths_equal(lhs: &[T], rhs: &[T]) -> bool { // invariant from `base_equal` @@ -46,41 +45,9 @@ fn lengths_equal(lhs: &[T], rhs: &[T]) -> bool { }) } -#[allow(clippy::too_many_arguments)] -#[inline] -fn offset_value_equal( - lhs_values: &ArrayData, - rhs_values: &ArrayData, - lhs_nulls: Option<&Buffer>, - rhs_nulls: Option<&Buffer>, - lhs_offsets: &[T], - rhs_offsets: &[T], - lhs_pos: usize, - rhs_pos: usize, - len: usize, -) -> bool { - let lhs_start = lhs_offsets[lhs_pos].to_usize().unwrap(); - let rhs_start = rhs_offsets[rhs_pos].to_usize().unwrap(); - let lhs_len = lhs_offsets[lhs_pos + len] - lhs_offsets[lhs_pos]; - let rhs_len = rhs_offsets[rhs_pos + len] - rhs_offsets[rhs_pos]; - - lhs_len == rhs_len - && equal_range( - lhs_values, - rhs_values, - lhs_nulls, - rhs_nulls, - lhs_start, - rhs_start, - lhs_len.to_usize().unwrap(), - ) -} - pub(super) fn list_equal( lhs: &ArrayData, rhs: &ArrayData, - lhs_nulls: Option<&Buffer>, - rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, @@ -95,7 +62,7 @@ pub(super) fn list_equal( // no child values. This causes panics when trying to count set bits. // // We caught this by chance from an accidental test-case, but due to the nature of this - // crash only occuring on list equality checks, we are adding a check here, instead of + // crash only occurring on list equality checks, we are adding a check here, instead of // on the buffer/bitmap utilities, as a length check would incur a penalty for almost all // other use-cases. // @@ -106,10 +73,16 @@ pub(super) fn list_equal( // however, one is more likely to slice into a list array and get a region that has 0 // child values. // The test that triggered this behaviour had [4, 4] as a slice of 1 value slot. - let lhs_child_length = lhs_offsets.get(len).unwrap().to_usize().unwrap() - - lhs_offsets.first().unwrap().to_usize().unwrap(); - let rhs_child_length = rhs_offsets.get(len).unwrap().to_usize().unwrap() - - rhs_offsets.first().unwrap().to_usize().unwrap(); + // For the edge case that zero length list arrays are always equal. 
+ if len == 0 { + return true; + } + + let lhs_child_length = lhs_offsets[lhs_start + len].to_usize().unwrap() + - lhs_offsets[lhs_start].to_usize().unwrap(); + + let rhs_child_length = rhs_offsets[rhs_start + len].to_usize().unwrap() + - rhs_offsets[rhs_start].to_usize().unwrap(); if lhs_child_length == 0 && lhs_child_length == rhs_child_length { return true; @@ -118,35 +91,33 @@ pub(super) fn list_equal( let lhs_values = &lhs.child_data()[0]; let rhs_values = &rhs.child_data()[0]; - let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len); - let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len); + let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); + let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); - // compute the child logical bitmap - let child_lhs_nulls = - child_logical_null_buffer(lhs, lhs_nulls, lhs.child_data().get(0).unwrap()); - let child_rhs_nulls = - child_logical_null_buffer(rhs, rhs_nulls, rhs.child_data().get(0).unwrap()); + if lhs_null_count != rhs_null_count { + return false; + } if lhs_null_count == 0 && rhs_null_count == 0 { - lengths_equal( - &lhs_offsets[lhs_start..lhs_start + len], - &rhs_offsets[rhs_start..rhs_start + len], - ) && equal_range( - lhs_values, - rhs_values, - child_lhs_nulls.as_ref(), - child_rhs_nulls.as_ref(), - lhs_offsets[lhs_start].to_usize().unwrap(), - rhs_offsets[rhs_start].to_usize().unwrap(), - (lhs_offsets[lhs_start + len] - lhs_offsets[lhs_start]) - .to_usize() - .unwrap(), - ) + lhs_child_length == rhs_child_length + && lengths_equal( + &lhs_offsets[lhs_start..lhs_start + len], + &rhs_offsets[rhs_start..rhs_start + len], + ) + && equal_range( + lhs_values, + rhs_values, + lhs_offsets[lhs_start].to_usize().unwrap(), + rhs_offsets[rhs_start].to_usize().unwrap(), + lhs_child_length, + ) } else { // get a ref of the parent null buffer bytes, to use in testing for nullness - let lhs_null_bytes = lhs_nulls.unwrap().as_slice(); - let rhs_null_bytes = rhs_nulls.unwrap().as_slice(); + let lhs_null_bytes = lhs.null_buffer().unwrap().as_slice(); + let rhs_null_bytes = rhs.null_buffer().unwrap().as_slice(); + // with nulls, we need to compare item by item whenever it is not null + // TODO: Could potentially compare runs of not NULL values (0..len).all(|i| { let lhs_pos = lhs_start + i; let rhs_pos = rhs_start + i; @@ -154,19 +125,76 @@ pub(super) fn list_equal( let lhs_is_null = !get_bit(lhs_null_bytes, lhs_pos + lhs.offset()); let rhs_is_null = !get_bit(rhs_null_bytes, rhs_pos + rhs.offset()); + if lhs_is_null != rhs_is_null { + return false; + } + + let lhs_offset_start = lhs_offsets[lhs_pos].to_usize().unwrap(); + let lhs_offset_end = lhs_offsets[lhs_pos + 1].to_usize().unwrap(); + let rhs_offset_start = rhs_offsets[rhs_pos].to_usize().unwrap(); + let rhs_offset_end = rhs_offsets[rhs_pos + 1].to_usize().unwrap(); + + let lhs_len = lhs_offset_end - lhs_offset_start; + let rhs_len = rhs_offset_end - rhs_offset_start; + lhs_is_null - || (lhs_is_null == rhs_is_null) - && offset_value_equal::( + || (lhs_len == rhs_len + && equal_range( lhs_values, rhs_values, - child_lhs_nulls.as_ref(), - child_rhs_nulls.as_ref(), - lhs_offsets, - rhs_offsets, - lhs_pos, - rhs_pos, - 1, - ) + lhs_offset_start, + rhs_offset_start, + lhs_len, + )) }) } } + +#[cfg(test)] +mod tests { + use crate::{ + array::{Array, Int64Builder, ListArray, ListBuilder}, + datatypes::Int32Type, + }; + + #[test] + fn list_array_non_zero_nulls() { + // Tests handling of list arrays with non-empty null ranges + let 
mut builder = ListBuilder::new(Int64Builder::new(10)); + builder.values().append_value(1).unwrap(); + builder.values().append_value(2).unwrap(); + builder.values().append_value(3).unwrap(); + builder.append(true).unwrap(); + builder.append(false).unwrap(); + let array1 = builder.finish(); + + let mut builder = ListBuilder::new(Int64Builder::new(10)); + builder.values().append_value(1).unwrap(); + builder.values().append_value(2).unwrap(); + builder.values().append_value(3).unwrap(); + builder.append(true).unwrap(); + builder.values().append_null().unwrap(); + builder.values().append_null().unwrap(); + builder.append(false).unwrap(); + let array2 = builder.finish(); + + assert_eq!(array1, array2); + } + + #[test] + fn test_list_different_offsets() { + let a = ListArray::from_iter_primitive::([ + Some([Some(0), Some(0)]), + Some([Some(1), Some(2)]), + Some([None, None]), + ]); + let b = ListArray::from_iter_primitive::([ + Some([Some(1), Some(2)]), + Some([None, None]), + Some([None, None]), + ]); + let a_slice = a.slice(1, 2); + let b_slice = b.slice(0, 2); + assert_eq!(&a_slice, &b_slice); + } +} diff --git a/arrow/src/array/equal/mod.rs b/arrow/src/array/equal/mod.rs index 15d41a0d67d6..c3b0bbc95c2b 100644 --- a/arrow/src/array/equal/mod.rs +++ b/arrow/src/array/equal/mod.rs @@ -20,16 +20,12 @@ //! depend on dynamic casting of `Array`. use super::{ - Array, ArrayData, BinaryOffsetSizeTrait, BooleanArray, DecimalArray, - FixedSizeBinaryArray, FixedSizeListArray, GenericBinaryArray, GenericListArray, - GenericStringArray, MapArray, NullArray, OffsetSizeTrait, PrimitiveArray, - StringOffsetSizeTrait, StructArray, -}; - -use crate::{ - buffer::Buffer, - datatypes::{ArrowPrimitiveType, DataType, IntervalUnit}, + Array, ArrayData, BooleanArray, DecimalArray, DictionaryArray, FixedSizeBinaryArray, + FixedSizeListArray, GenericBinaryArray, GenericListArray, GenericStringArray, + MapArray, NullArray, OffsetSizeTrait, PrimitiveArray, StructArray, }; +use crate::datatypes::{ArrowPrimitiveType, DataType, IntervalUnit}; +use half::f16; mod boolean; mod decimal; @@ -40,6 +36,7 @@ mod list; mod null; mod primitive; mod structure; +mod union; mod utils; mod variable_size; @@ -55,6 +52,7 @@ use list::list_equal; use null::null_equal; use primitive::primitive_equal; use structure::struct_equal; +use union::union_equal; use variable_size::variable_sized_equal; impl PartialEq for dyn Array { @@ -81,19 +79,25 @@ impl PartialEq for PrimitiveArray { } } +impl PartialEq for DictionaryArray { + fn eq(&self, other: &Self) -> bool { + equal(self.data(), other.data()) + } +} + impl PartialEq for BooleanArray { fn eq(&self, other: &BooleanArray) -> bool { equal(self.data(), other.data()) } } -impl PartialEq for GenericStringArray { +impl PartialEq for GenericStringArray { fn eq(&self, other: &Self) -> bool { equal(self.data(), other.data()) } } -impl PartialEq for GenericBinaryArray { +impl PartialEq for GenericBinaryArray { fn eq(&self, other: &Self) -> bool { equal(self.data(), other.data()) } @@ -136,140 +140,99 @@ impl PartialEq for StructArray { } /// Compares the values of two [ArrayData] starting at `lhs_start` and `rhs_start` respectively -/// for `len` slots. The null buffers `lhs_nulls` and `rhs_nulls` inherit parent nullability. -/// -/// If an array is a child of a struct or list, the array's nulls have to be merged with the parent. -/// This then affects the null count of the array, thus the merged nulls are passed separately -/// as `lhs_nulls` and `rhs_nulls` variables to functions. 
-/// The nulls are merged with a bitwise AND, and null counts are recomputed where necessary. +/// for `len` slots. #[inline] fn equal_values( lhs: &ArrayData, rhs: &ArrayData, - lhs_nulls: Option<&Buffer>, - rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, ) -> bool { match lhs.data_type() { DataType::Null => null_equal(lhs, rhs, lhs_start, rhs_start, len), - DataType::Boolean => { - boolean_equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) - } - DataType::UInt8 => primitive_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::UInt16 => primitive_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::UInt32 => primitive_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::UInt64 => primitive_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::Int8 => primitive_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::Int16 => primitive_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::Int32 => primitive_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::Int64 => primitive_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::Float32 => primitive_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::Float64 => primitive_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), + DataType::Boolean => boolean_equal(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt8 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt16 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt32 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt64 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int8 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int16 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int32 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int64 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Float32 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Float64 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), DataType::Date32 | DataType::Time32(_) - | DataType::Interval(IntervalUnit::YearMonth) => primitive_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), + | DataType::Interval(IntervalUnit::YearMonth) => { + primitive_equal::(lhs, rhs, lhs_start, rhs_start, len) + } DataType::Date64 | DataType::Interval(IntervalUnit::DayTime) | DataType::Time64(_) | DataType::Timestamp(_, _) - | DataType::Duration(_) => primitive_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::Utf8 | DataType::Binary => variable_sized_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::LargeUtf8 | DataType::LargeBinary => variable_sized_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::FixedSizeBinary(_) => { - fixed_binary_equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + | DataType::Duration(_) => { + primitive_equal::(lhs, rhs, lhs_start, rhs_start, len) } - DataType::Decimal(_, _) => { - decimal_equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + DataType::Interval(IntervalUnit::MonthDayNano) => 
{ + primitive_equal::(lhs, rhs, lhs_start, rhs_start, len) } - DataType::List(_) => { - list_equal::(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + DataType::Utf8 | DataType::Binary => { + variable_sized_equal::(lhs, rhs, lhs_start, rhs_start, len) } - DataType::LargeList(_) => { - list_equal::(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + DataType::LargeUtf8 | DataType::LargeBinary => { + variable_sized_equal::(lhs, rhs, lhs_start, rhs_start, len) } - DataType::FixedSizeList(_, _) => { - fixed_list_equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + DataType::FixedSizeBinary(_) => { + fixed_binary_equal(lhs, rhs, lhs_start, rhs_start, len) } - DataType::Struct(_) => { - struct_equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + DataType::Decimal(_, _) => decimal_equal(lhs, rhs, lhs_start, rhs_start, len), + DataType::List(_) => list_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::LargeList(_) => list_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::FixedSizeList(_, _) => { + fixed_list_equal(lhs, rhs, lhs_start, rhs_start, len) } - DataType::Union(_) => unimplemented!("See ARROW-8576"), + DataType::Struct(_) => struct_equal(lhs, rhs, lhs_start, rhs_start, len), + DataType::Union(_, _, _) => union_equal(lhs, rhs, lhs_start, rhs_start, len), DataType::Dictionary(data_type, _) => match data_type.as_ref() { - DataType::Int8 => dictionary_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::Int16 => dictionary_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::Int32 => dictionary_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::Int64 => dictionary_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::UInt8 => dictionary_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::UInt16 => dictionary_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::UInt32 => dictionary_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::UInt64 => dictionary_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), + DataType::Int8 => dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int16 => { + dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::Int32 => { + dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::Int64 => { + dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::UInt8 => { + dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::UInt16 => { + dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::UInt32 => { + dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::UInt64 => { + dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len) + } _ => unreachable!(), }, - DataType::Float16 => unreachable!(), - DataType::Map(_, _) => { - list_equal::(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) - } + DataType::Float16 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Map(_, _) => list_equal::(lhs, rhs, lhs_start, rhs_start, len), } } fn equal_range( lhs: &ArrayData, rhs: &ArrayData, - lhs_nulls: Option<&Buffer>, - rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, ) -> bool { - utils::base_equal(lhs, rhs) - && utils::equal_nulls(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) - && 
equal_values(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + utils::equal_nulls(lhs, rhs, lhs_start, rhs_start, len) + && equal_values(lhs, rhs, lhs_start, rhs_start, len) } /// Logically compares two [ArrayData]. @@ -285,12 +248,10 @@ fn equal_range( /// This function may panic whenever any of the [ArrayData] does not follow the Arrow specification. /// (e.g. wrong number of buffers, buffer `len` does not correspond to the declared `len`) pub fn equal(lhs: &ArrayData, rhs: &ArrayData) -> bool { - let lhs_nulls = lhs.null_buffer(); - let rhs_nulls = rhs.null_buffer(); utils::base_equal(lhs, rhs) && lhs.null_count() == rhs.null_count() - && utils::equal_nulls(lhs, rhs, lhs_nulls, rhs_nulls, 0, 0, lhs.len()) - && equal_values(lhs, rhs, lhs_nulls, rhs_nulls, 0, 0, lhs.len()) + && utils::equal_nulls(lhs, rhs, 0, 0, lhs.len()) + && equal_values(lhs, rhs, 0, 0, lhs.len()) } #[cfg(test)] @@ -299,14 +260,14 @@ mod tests { use std::sync::Arc; use crate::array::{ - array::Array, ArrayDataBuilder, ArrayRef, BinaryOffsetSizeTrait, BooleanArray, - DecimalBuilder, FixedSizeBinaryBuilder, FixedSizeListBuilder, GenericBinaryArray, - Int32Builder, ListBuilder, NullArray, PrimitiveBuilder, StringArray, - StringDictionaryBuilder, StringOffsetSizeTrait, StructArray, + array::Array, ArrayData, ArrayDataBuilder, ArrayRef, BooleanArray, + FixedSizeBinaryBuilder, FixedSizeListBuilder, GenericBinaryArray, Int32Builder, + ListBuilder, NullArray, PrimitiveBuilder, StringArray, StringDictionaryBuilder, + StructArray, UnionBuilder, }; use crate::array::{GenericStringArray, Int32Array}; use crate::buffer::Buffer; - use crate::datatypes::{Field, Int16Type, ToByteSlice}; + use crate::datatypes::{Field, Int16Type, Int32Type, ToByteSlice}; use super::*; @@ -479,6 +440,13 @@ mod tests { (1, 2), true, ), + ( + vec![Some(1), Some(2), None, Some(0)], + (2, 2), + vec![Some(4), Some(5), Some(0), None], + (2, 2), + false, + ), ]; for (lhs, slice_lhs, rhs, slice_rhs, expected) in cases { @@ -502,7 +470,9 @@ mod tests { assert_eq!(equal(rhs, lhs), expected, "\n{:?}\n{:?}", rhs, lhs); } - fn binary_cases() -> Vec<(Vec>, Vec>, bool)> { + type OptionString = Option; + + fn binary_cases() -> Vec<(Vec, Vec, bool)> { let base = vec![ Some("hello".to_owned()), None, @@ -535,15 +505,13 @@ mod tests { ] } - fn test_generic_string_equal() { + fn test_generic_string_equal() { let cases = binary_cases(); for (lhs, rhs, expected) in cases { - let lhs = lhs.iter().map(|x| x.as_deref()).collect(); - let rhs = rhs.iter().map(|x| x.as_deref()).collect(); - let lhs = GenericStringArray::::from_opt_vec(lhs); + let lhs: GenericStringArray = lhs.into_iter().collect(); let lhs = lhs.data(); - let rhs = GenericStringArray::::from_opt_vec(rhs); + let rhs: GenericStringArray = rhs.into_iter().collect(); let rhs = rhs.data(); test_equal(lhs, rhs, expected); } @@ -559,7 +527,7 @@ mod tests { test_generic_string_equal::() } - fn test_generic_binary_equal() { + fn test_generic_binary_equal() { let cases = binary_cases(); for (lhs, rhs, expected) in cases { @@ -589,6 +557,19 @@ mod tests { test_generic_binary_equal::() } + #[test] + fn test_fixed_size_binary_array() { + let a_input_arg = vec![vec![1, 2], vec![3, 4], vec![5, 6]]; + let a = FixedSizeBinaryArray::try_from_iter(a_input_arg.into_iter()).unwrap(); + let a = a.data(); + + let b_input_arg = vec![vec![1, 2], vec![3, 4], vec![5, 6]]; + let b = FixedSizeBinaryArray::try_from_iter(b_input_arg.into_iter()).unwrap(); + let b = b.data(); + + test_equal(a, b, true); + } + #[test] fn 
test_string_offset() { let a = StringArray::from(vec![Some("a"), None, Some("b")]); @@ -648,6 +629,57 @@ mod tests { test_equal(&a, &b, false); } + #[test] + fn test_empty_offsets_list_equal() { + let empty: Vec = vec![]; + let values = Int32Array::from(empty); + let empty_offsets: [u8; 0] = []; + + let a = ArrayDataBuilder::new(DataType::List(Box::new(Field::new( + "item", + DataType::Int32, + true, + )))) + .len(0) + .add_buffer(Buffer::from(&empty_offsets)) + .add_child_data(values.data().clone()) + .null_bit_buffer(Some(Buffer::from(&empty_offsets))) + .build() + .unwrap(); + + let b = ArrayDataBuilder::new(DataType::List(Box::new(Field::new( + "item", + DataType::Int32, + true, + )))) + .len(0) + .add_buffer(Buffer::from(&empty_offsets)) + .add_child_data(values.data().clone()) + .null_bit_buffer(Some(Buffer::from(&empty_offsets))) + .build() + .unwrap(); + + test_equal(&a, &b, true); + + let c = ArrayDataBuilder::new(DataType::List(Box::new(Field::new( + "item", + DataType::Int32, + true, + )))) + .len(0) + .add_buffer(Buffer::from(vec![0i32, 2, 3, 4, 6, 7, 8].to_byte_slice())) + .add_child_data( + Int32Array::from(vec![1, 2, -1, -2, 3, 4, -3, -4]) + .data() + .clone(), + ) + .null_bit_buffer(Some(Buffer::from(vec![0b00001001]))) + .build() + .unwrap(); + + test_equal(&a, &c, true); + } + // Test the case where null_count > 0 #[test] fn test_list_null() { @@ -681,7 +713,7 @@ mod tests { .len(6) .add_buffer(Buffer::from(vec![0i32, 2, 3, 4, 6, 7, 8].to_byte_slice())) .add_child_data(c_values.data().clone()) - .null_bit_buffer(Buffer::from(vec![0b00001001])) + .null_bit_buffer(Some(Buffer::from(vec![0b00001001]))) .build() .unwrap(); @@ -703,7 +735,7 @@ mod tests { .len(6) .add_buffer(Buffer::from(vec![0i32, 2, 3, 4, 6, 7, 8].to_byte_slice())) .add_child_data(d_values.data().clone()) - .null_bit_buffer(Buffer::from(vec![0b00001001])) + .null_bit_buffer(Some(Buffer::from(vec![0b00001001]))) .build() .unwrap(); test_equal(&c, &d, true); @@ -807,16 +839,11 @@ mod tests { } fn create_decimal_array(data: &[Option]) -> ArrayData { - let mut builder = DecimalBuilder::new(20, 23, 6); - - for d in data { - if let Some(v) = d { - builder.append_value(*v).unwrap(); - } else { - builder.append_null().unwrap(); - } - } - builder.finish().data().clone() + data.iter() + .collect::() + .with_precision_and_scale(23, 6) + .unwrap() + .into() } #[test] @@ -967,6 +994,11 @@ mod tests { None, ]); test_equal(&a, &b, false); + + let b = create_fixed_size_list_array(&[None, Some(&[4, 5, 6]), None, None]); + + test_equal(&a.slice(2, 4), &b, true); + test_equal(&a.slice(3, 3), &b.slice(1, 3), true); } #[test] @@ -1052,7 +1084,7 @@ mod tests { Field::new("f1", DataType::Utf8, true), Field::new("f2", DataType::Int32, true), ])) - .null_bit_buffer(Buffer::from(vec![0b00001011])) + .null_bit_buffer(Some(Buffer::from(vec![0b00001011]))) .len(5) .add_child_data(strings.data_ref().clone()) .add_child_data(ints.data_ref().clone()) @@ -1064,7 +1096,7 @@ mod tests { Field::new("f1", DataType::Utf8, true), Field::new("f2", DataType::Int32, true), ])) - .null_bit_buffer(Buffer::from(vec![0b00001011])) + .null_bit_buffer(Some(Buffer::from(vec![0b00001011]))) .len(5) .add_child_data(strings.data_ref().clone()) .add_child_data(ints_non_null.data_ref().clone()) @@ -1080,7 +1112,7 @@ mod tests { Field::new("f1", DataType::Utf8, true), Field::new("f2", DataType::Int32, true), ])) - .null_bit_buffer(Buffer::from(vec![0b00001011])) + .null_bit_buffer(Some(Buffer::from(vec![0b00001011]))) .len(5) 
.add_child_data(strings.data_ref().clone()) .add_child_data(c_ints_non_null.data_ref().clone()) @@ -1096,7 +1128,7 @@ mod tests { a.data_type().clone(), true, )])) - .null_bit_buffer(Buffer::from(vec![0b00011110])) + .null_bit_buffer(Some(Buffer::from(vec![0b00011110]))) .len(5) .add_child_data(a.data_ref().clone()) .build() @@ -1115,7 +1147,7 @@ mod tests { Field::new("f1", DataType::Utf8, true), Field::new("f2", DataType::Int32, true), ])) - .null_bit_buffer(Buffer::from(vec![0b00001011])) + .null_bit_buffer(Some(Buffer::from(vec![0b00001011]))) .len(5) .add_child_data(strings.data_ref().clone()) .add_child_data(ints_non_null.data_ref().clone()) @@ -1127,7 +1159,7 @@ mod tests { b.data_type().clone(), true, )])) - .null_bit_buffer(Buffer::from(vec![0b00011110])) + .null_bit_buffer(Some(Buffer::from(vec![0b00011110]))) .len(5) .add_child_data(b) .build() @@ -1160,7 +1192,7 @@ mod tests { DataType::Utf8, true, )])) - .null_bit_buffer(Buffer::from(vec![0b00001010])) + .null_bit_buffer(Some(Buffer::from(vec![0b00001010]))) .len(5) .add_child_data(strings1.data_ref().clone()) .build() @@ -1172,7 +1204,7 @@ mod tests { DataType::Utf8, true, )])) - .null_bit_buffer(Buffer::from(vec![0b00001010])) + .null_bit_buffer(Some(Buffer::from(vec![0b00001010]))) .len(5) .add_child_data(strings2.data_ref().clone()) .build() @@ -1194,7 +1226,7 @@ mod tests { DataType::Utf8, true, )])) - .null_bit_buffer(Buffer::from(vec![0b00001011])) + .null_bit_buffer(Some(Buffer::from(vec![0b00001011]))) .len(5) .add_child_data(strings3.data_ref().clone()) .build() @@ -1294,4 +1326,132 @@ mod tests { ); test_equal(&a, &b, false); } + + #[test] + fn test_non_null_empty_strings() { + let s = StringArray::from(vec![Some(""), Some(""), Some("")]); + + let string1 = s.data(); + + let string2 = ArrayData::builder(DataType::Utf8) + .len(string1.len()) + .buffers(string1.buffers().to_vec()) + .build() + .unwrap(); + + // string2 is identical to string1 except that it has no validity buffer but since there + // are no nulls, string1 and string2 are equal + test_equal(string1, &string2, true); + } + + #[test] + fn test_null_empty_strings() { + let s = StringArray::from(vec![Some(""), None, Some("")]); + + let string1 = s.data(); + + let string2 = ArrayData::builder(DataType::Utf8) + .len(string1.len()) + .buffers(string1.buffers().to_vec()) + .build() + .unwrap(); + + // string2 is identical to string1 except that it has no validity buffer since string1 has + // nulls in it, string1 and string2 are not equal + test_equal(string1, &string2, false); + } + + #[test] + fn test_union_equal_dense() { + let mut builder = UnionBuilder::new_dense(7); + builder.append::("a", 1).unwrap(); + builder.append::("b", 2).unwrap(); + builder.append::("c", 3).unwrap(); + builder.append::("a", 4).unwrap(); + builder.append_null::("a").unwrap(); + builder.append::("a", 6).unwrap(); + builder.append::("b", 7).unwrap(); + let union1 = builder.build().unwrap(); + + builder = UnionBuilder::new_dense(7); + builder.append::("a", 1).unwrap(); + builder.append::("b", 2).unwrap(); + builder.append::("c", 3).unwrap(); + builder.append::("a", 4).unwrap(); + builder.append_null::("a").unwrap(); + builder.append::("a", 6).unwrap(); + builder.append::("b", 7).unwrap(); + let union2 = builder.build().unwrap(); + + builder = UnionBuilder::new_dense(7); + builder.append::("a", 1).unwrap(); + builder.append::("b", 2).unwrap(); + builder.append::("c", 3).unwrap(); + builder.append::("a", 5).unwrap(); + builder.append::("c", 4).unwrap(); + builder.append::("a", 
6).unwrap(); + builder.append::("b", 7).unwrap(); + let union3 = builder.build().unwrap(); + + builder = UnionBuilder::new_dense(7); + builder.append::("a", 1).unwrap(); + builder.append::("b", 2).unwrap(); + builder.append::("c", 3).unwrap(); + builder.append::("a", 4).unwrap(); + builder.append_null::("c").unwrap(); + builder.append_null::("b").unwrap(); + builder.append::("b", 7).unwrap(); + let union4 = builder.build().unwrap(); + + test_equal(union1.data(), union2.data(), true); + test_equal(union1.data(), union3.data(), false); + test_equal(union1.data(), union4.data(), false); + } + + #[test] + fn test_union_equal_sparse() { + let mut builder = UnionBuilder::new_sparse(7); + builder.append::("a", 1).unwrap(); + builder.append::("b", 2).unwrap(); + builder.append::("c", 3).unwrap(); + builder.append::("a", 4).unwrap(); + builder.append_null::("a").unwrap(); + builder.append::("a", 6).unwrap(); + builder.append::("b", 7).unwrap(); + let union1 = builder.build().unwrap(); + + builder = UnionBuilder::new_sparse(7); + builder.append::("a", 1).unwrap(); + builder.append::("b", 2).unwrap(); + builder.append::("c", 3).unwrap(); + builder.append::("a", 4).unwrap(); + builder.append_null::("a").unwrap(); + builder.append::("a", 6).unwrap(); + builder.append::("b", 7).unwrap(); + let union2 = builder.build().unwrap(); + + builder = UnionBuilder::new_sparse(7); + builder.append::("a", 1).unwrap(); + builder.append::("b", 2).unwrap(); + builder.append::("c", 3).unwrap(); + builder.append::("a", 5).unwrap(); + builder.append::("c", 4).unwrap(); + builder.append::("a", 6).unwrap(); + builder.append::("b", 7).unwrap(); + let union3 = builder.build().unwrap(); + + builder = UnionBuilder::new_sparse(7); + builder.append::("a", 1).unwrap(); + builder.append::("b", 2).unwrap(); + builder.append::("c", 3).unwrap(); + builder.append::("a", 4).unwrap(); + builder.append_null::("a").unwrap(); + builder.append_null::("a").unwrap(); + builder.append::("b", 7).unwrap(); + let union4 = builder.build().unwrap(); + + test_equal(union1.data(), union2.data(), true); + test_equal(union1.data(), union3.data(), false); + test_equal(union1.data(), union4.data(), false); + } } diff --git a/arrow/src/array/equal/primitive.rs b/arrow/src/array/equal/primitive.rs index db7587915c8a..09882cd78509 100644 --- a/arrow/src/array/equal/primitive.rs +++ b/arrow/src/array/equal/primitive.rs @@ -18,7 +18,6 @@ use std::mem::size_of; use crate::array::{data::count_nulls, ArrayData}; -use crate::buffer::Buffer; use crate::util::bit_util::get_bit; use super::utils::equal_len; @@ -26,8 +25,6 @@ use super::utils::equal_len; pub(super) fn primitive_equal( lhs: &ArrayData, rhs: &ArrayData, - lhs_nulls: Option<&Buffer>, - rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, @@ -36,8 +33,8 @@ pub(super) fn primitive_equal( let lhs_values = &lhs.buffers()[0].as_slice()[lhs.offset() * byte_width..]; let rhs_values = &rhs.buffers()[0].as_slice()[rhs.offset() * byte_width..]; - let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len); - let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len); + let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); + let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); if lhs_null_count == 0 && rhs_null_count == 0 { // without nulls, we just need to compare slices @@ -50,8 +47,8 @@ pub(super) fn primitive_equal( ) } else { // get a ref of the null buffer bytes, to use in testing for nullness - let lhs_null_bytes = 
lhs_nulls.as_ref().unwrap().as_slice(); - let rhs_null_bytes = rhs_nulls.as_ref().unwrap().as_slice(); + let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice(); + let rhs_null_bytes = rhs.null_buffer().as_ref().unwrap().as_slice(); // with nulls, we need to compare item by item whenever it is not null (0..len).all(|i| { let lhs_pos = lhs_start + i; diff --git a/arrow/src/array/equal/structure.rs b/arrow/src/array/equal/structure.rs index b3cc4029e9ec..0f943e40cac6 100644 --- a/arrow/src/array/equal/structure.rs +++ b/arrow/src/array/equal/structure.rs @@ -15,24 +15,15 @@ // specific language governing permissions and limitations // under the License. -use crate::{ - array::data::count_nulls, array::ArrayData, buffer::Buffer, util::bit_util::get_bit, -}; +use crate::{array::data::count_nulls, array::ArrayData, util::bit_util::get_bit}; -use super::{equal_range, utils::child_logical_null_buffer}; +use super::equal_range; /// Compares the values of two [ArrayData] starting at `lhs_start` and `rhs_start` respectively -/// for `len` slots. The null buffers `lhs_nulls` and `rhs_nulls` inherit parent nullability. -/// -/// If an array is a child of a struct or list, the array's nulls have to be merged with the parent. -/// This then affects the null count of the array, thus the merged nulls are passed separately -/// as `lhs_nulls` and `rhs_nulls` variables to functions. -/// The nulls are merged with a bitwise AND, and null counts are recomputed where necessary. -fn equal_values( +/// for `len` slots. +fn equal_child_values( lhs: &ArrayData, rhs: &ArrayData, - lhs_nulls: Option<&Buffer>, - rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, @@ -41,39 +32,27 @@ fn equal_values( .iter() .zip(rhs.child_data()) .all(|(lhs_values, rhs_values)| { - // merge the null data - let lhs_merged_nulls = child_logical_null_buffer(lhs, lhs_nulls, lhs_values); - let rhs_merged_nulls = child_logical_null_buffer(rhs, rhs_nulls, rhs_values); - equal_range( - lhs_values, - rhs_values, - lhs_merged_nulls.as_ref(), - rhs_merged_nulls.as_ref(), - lhs_start, - rhs_start, - len, - ) + equal_range(lhs_values, rhs_values, lhs_start, rhs_start, len) }) } pub(super) fn struct_equal( lhs: &ArrayData, rhs: &ArrayData, - lhs_nulls: Option<&Buffer>, - rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, ) -> bool { // we have to recalculate null counts from the null buffers - let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len); - let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len); + let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); + let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); + if lhs_null_count == 0 && rhs_null_count == 0 { - equal_values(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + equal_child_values(lhs, rhs, lhs_start, rhs_start, len) } else { // get a ref of the null buffer bytes, to use in testing for nullness - let lhs_null_bytes = lhs_nulls.as_ref().unwrap().as_slice(); - let rhs_null_bytes = rhs_nulls.as_ref().unwrap().as_slice(); + let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice(); + let rhs_null_bytes = rhs.null_buffer().as_ref().unwrap().as_slice(); // with nulls, we need to compare item by item whenever it is not null (0..len).all(|i| { let lhs_pos = lhs_start + i; @@ -82,9 +61,11 @@ pub(super) fn struct_equal( let lhs_is_null = !get_bit(lhs_null_bytes, lhs_pos + lhs.offset()); let rhs_is_null = !get_bit(rhs_null_bytes, rhs_pos + 
rhs.offset()); - lhs_is_null - || (lhs_is_null == rhs_is_null) - && equal_values(lhs, rhs, lhs_nulls, rhs_nulls, lhs_pos, rhs_pos, 1) + if lhs_is_null != rhs_is_null { + return false; + } + + lhs_is_null || equal_child_values(lhs, rhs, lhs_pos, rhs_pos, 1) }) } } diff --git a/arrow/src/array/equal/union.rs b/arrow/src/array/equal/union.rs new file mode 100644 index 000000000000..e8b9d27b6f0f --- /dev/null +++ b/arrow/src/array/equal/union.rs @@ -0,0 +1,123 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{array::ArrayData, datatypes::DataType, datatypes::UnionMode}; + +use super::equal_range; + +#[allow(clippy::too_many_arguments)] +fn equal_dense( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_type_ids: &[i8], + rhs_type_ids: &[i8], + lhs_offsets: &[i32], + rhs_offsets: &[i32], + lhs_field_type_ids: &[i8], + rhs_field_type_ids: &[i8], +) -> bool { + let offsets = lhs_offsets.iter().zip(rhs_offsets.iter()); + + lhs_type_ids + .iter() + .zip(rhs_type_ids.iter()) + .zip(offsets) + .all(|((l_type_id, r_type_id), (l_offset, r_offset))| { + let lhs_child_index = lhs_field_type_ids + .iter() + .position(|r| r == l_type_id) + .unwrap(); + let rhs_child_index = rhs_field_type_ids + .iter() + .position(|r| r == r_type_id) + .unwrap(); + let lhs_values = &lhs.child_data()[lhs_child_index]; + let rhs_values = &rhs.child_data()[rhs_child_index]; + + equal_range( + lhs_values, + rhs_values, + *l_offset as usize, + *r_offset as usize, + 1, + ) + }) +} + +fn equal_sparse( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + lhs.child_data() + .iter() + .zip(rhs.child_data()) + .all(|(lhs_values, rhs_values)| { + equal_range(lhs_values, rhs_values, lhs_start, rhs_start, len) + }) +} + +pub(super) fn union_equal( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + let lhs_type_ids = lhs.buffer::(0); + let rhs_type_ids = rhs.buffer::(0); + + let lhs_type_id_range = &lhs_type_ids[lhs_start..lhs_start + len]; + let rhs_type_id_range = &rhs_type_ids[rhs_start..rhs_start + len]; + + match (lhs.data_type(), rhs.data_type()) { + ( + DataType::Union(_, lhs_type_ids, UnionMode::Dense), + DataType::Union(_, rhs_type_ids, UnionMode::Dense), + ) => { + let lhs_offsets = lhs.buffer::(1); + let rhs_offsets = rhs.buffer::(1); + + let lhs_offsets_range = &lhs_offsets[lhs_start..lhs_start + len]; + let rhs_offsets_range = &rhs_offsets[rhs_start..rhs_start + len]; + + lhs_type_id_range == rhs_type_id_range + && equal_dense( + lhs, + rhs, + lhs_type_id_range, + rhs_type_id_range, + lhs_offsets_range, + rhs_offsets_range, + lhs_type_ids, + rhs_type_ids, + ) + } + ( + DataType::Union(_, _, UnionMode::Sparse), + DataType::Union(_, _, UnionMode::Sparse), + ) => 
{ + lhs_type_id_range == rhs_type_id_range + && equal_sparse(lhs, rhs, lhs_start, rhs_start, len) + } + _ => unimplemented!( + "Logical equality not yet implemented between dense and sparse union arrays" + ), + } +} diff --git a/arrow/src/array/equal/utils.rs b/arrow/src/array/equal/utils.rs index 8eb988cb9a98..fed3933a0893 100644 --- a/arrow/src/array/equal/utils.rs +++ b/arrow/src/array/equal/utils.rs @@ -15,9 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::array::{data::count_nulls, ArrayData, OffsetSizeTrait}; -use crate::bitmap::Bitmap; -use crate::buffer::{Buffer, MutableBuffer}; +use crate::array::{data::count_nulls, ArrayData}; use crate::datatypes::DataType; use crate::util::bit_util; @@ -41,17 +39,20 @@ pub(super) fn equal_bits( pub(super) fn equal_nulls( lhs: &ArrayData, rhs: &ArrayData, - lhs_nulls: Option<&Buffer>, - rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, ) -> bool { - let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len); - let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len); + let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); + let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); + + if lhs_null_count != rhs_null_count { + return false; + } + if lhs_null_count > 0 || rhs_null_count > 0 { - let lhs_values = lhs_nulls.unwrap().as_slice(); - let rhs_values = rhs_nulls.unwrap().as_slice(); + let lhs_values = lhs.null_buffer().unwrap().as_slice(); + let rhs_values = rhs.null_buffer().unwrap().as_slice(); equal_bits( lhs_values, rhs_values, @@ -66,7 +67,38 @@ pub(super) fn equal_nulls( #[inline] pub(super) fn base_equal(lhs: &ArrayData, rhs: &ArrayData) -> bool { - lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() + let equal_type = match (lhs.data_type(), rhs.data_type()) { + (DataType::Union(l_fields, _, l_mode), DataType::Union(r_fields, _, r_mode)) => { + l_fields == r_fields && l_mode == r_mode + } + (DataType::Map(l_field, l_sorted), DataType::Map(r_field, r_sorted)) => { + let field_equal = match (l_field.data_type(), r_field.data_type()) { + (DataType::Struct(l_fields), DataType::Struct(r_fields)) + if l_fields.len() == 2 && r_fields.len() == 2 => + { + let l_key_field = l_fields.get(0).unwrap(); + let r_key_field = r_fields.get(0).unwrap(); + let l_value_field = l_fields.get(1).unwrap(); + let r_value_field = r_fields.get(1).unwrap(); + + // We don't enforce the equality of field names + let data_type_equal = l_key_field.data_type() + == r_key_field.data_type() + && l_value_field.data_type() == r_value_field.data_type(); + let nullability_equal = l_key_field.is_nullable() + == r_key_field.is_nullable() + && l_value_field.is_nullable() == r_value_field.is_nullable(); + let metadata_equal = l_key_field.metadata() == r_key_field.metadata() + && l_value_field.metadata() == r_value_field.metadata(); + data_type_equal && nullability_equal && metadata_equal + } + _ => panic!("Map type should have 2 fields Struct in its field"), + }; + field_equal && l_sorted == r_sorted + } + (l_data_type, r_data_type) => l_data_type == r_data_type, + }; + equal_type && lhs.len() == rhs.len() } // whether the two memory regions are equal @@ -80,188 +112,3 @@ pub(super) fn equal_len( ) -> bool { lhs_values[lhs_start..(lhs_start + len)] == rhs_values[rhs_start..(rhs_start + len)] } - -/// Computes the logical validity bitmap of the array data using the -/// parent's array data. 
The parent should be a list or struct, else -/// the logical bitmap of the array is returned unaltered. -/// -/// Parent data is passed along with the parent's logical bitmap, as -/// nested arrays could have a logical bitmap different to the physical -/// one on the `ArrayData`. -pub(super) fn child_logical_null_buffer( - parent_data: &ArrayData, - logical_null_buffer: Option<&Buffer>, - child_data: &ArrayData, -) -> Option { - let parent_len = parent_data.len(); - let parent_bitmap = logical_null_buffer - .cloned() - .map(Bitmap::from) - .unwrap_or_else(|| { - let ceil = bit_util::ceil(parent_len, 8); - Bitmap::from(Buffer::from(vec![0b11111111; ceil])) - }); - let self_null_bitmap = child_data.null_bitmap().clone().unwrap_or_else(|| { - let ceil = bit_util::ceil(child_data.len(), 8); - Bitmap::from(Buffer::from(vec![0b11111111; ceil])) - }); - match parent_data.data_type() { - DataType::List(_) | DataType::Map(_, _) => Some(logical_list_bitmap::( - parent_data, - parent_bitmap, - self_null_bitmap, - )), - DataType::LargeList(_) => Some(logical_list_bitmap::( - parent_data, - parent_bitmap, - self_null_bitmap, - )), - DataType::FixedSizeList(_, len) => { - let len = *len as usize; - let array_offset = parent_data.offset(); - let bitmap_len = bit_util::ceil(parent_len * len, 8); - let mut buffer = MutableBuffer::from_len_zeroed(bitmap_len); - let mut null_slice = buffer.as_slice_mut(); - (array_offset..parent_len + array_offset).for_each(|index| { - let start = index * len; - let end = start + len; - let mask = parent_bitmap.is_set(index); - (start..end).for_each(|child_index| { - if mask && self_null_bitmap.is_set(child_index) { - bit_util::set_bit(&mut null_slice, child_index); - } - }); - }); - Some(buffer.into()) - } - DataType::Struct(_) => { - // Arrow implementations are free to pad data, which can result in null buffers not - // having the same length. - // Rust bitwise comparisons will return an error if left AND right is performed on - // buffers of different length. - // This might be a valid case during integration testing, where we read Arrow arrays - // from IPC data, which has padding. - // - // We first perform a bitwise comparison, and if there is an error, we revert to a - // slower method that indexes into the buffers one-by-one. 
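The removed `child_logical_null_buffer` helper computed a child's logical validity by folding in the parent's validity; at its core that merge was a bitwise AND of the two bitmaps, roughly as in this simplified sketch that ignores offsets and list fan-out:

```rust
// Simplified sketch of the removed merge: a child slot was treated as valid only
// when both the parent slot and the child slot were valid.
fn merge_validity(parent: &[u8], child: &[u8]) -> Vec<u8> {
    parent.iter().zip(child.iter()).map(|(p, c)| p & c).collect()
}
```

The new comparison code avoids materializing such merged bitmaps altogether and instead consults each `ArrayData`'s own null buffer (together with its offset) at comparison time.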
- let result = &parent_bitmap & &self_null_bitmap; - if let Ok(bitmap) = result { - return Some(bitmap.bits); - } - // slow path - let array_offset = parent_data.offset(); - let mut buffer = MutableBuffer::new_null(parent_len); - let mut null_slice = buffer.as_slice_mut(); - (0..parent_len).for_each(|index| { - if parent_bitmap.is_set(index + array_offset) - && self_null_bitmap.is_set(index + array_offset) - { - bit_util::set_bit(&mut null_slice, index); - } - }); - Some(buffer.into()) - } - DataType::Union(_) => { - unimplemented!("Logical equality not yet implemented for union arrays") - } - DataType::Dictionary(_, _) => { - unimplemented!("Logical equality not yet implemented for nested dictionaries") - } - data_type => panic!("Data type {:?} is not a supported nested type", data_type), - } -} - -// Calculate a list child's logical bitmap/buffer -#[inline] -fn logical_list_bitmap( - parent_data: &ArrayData, - parent_bitmap: Bitmap, - child_bitmap: Bitmap, -) -> Buffer { - let offsets = parent_data.buffer::(0); - let offset_start = offsets.first().unwrap().to_usize().unwrap(); - let offset_len = offsets.get(parent_data.len()).unwrap().to_usize().unwrap(); - let mut buffer = MutableBuffer::new_null(offset_len - offset_start); - let mut null_slice = buffer.as_slice_mut(); - - offsets - .windows(2) - .enumerate() - .take(parent_data.len()) - .for_each(|(index, window)| { - let start = window[0].to_usize().unwrap(); - let end = window[1].to_usize().unwrap(); - let mask = parent_bitmap.is_set(index); - (start..end).for_each(|child_index| { - if mask && child_bitmap.is_set(child_index) { - bit_util::set_bit(&mut null_slice, child_index - offset_start); - } - }); - }); - buffer.into() -} - -#[cfg(test)] -mod tests { - use super::*; - - use crate::datatypes::{Field, ToByteSlice}; - - #[test] - fn test_logical_null_buffer() { - let child_data = ArrayData::builder(DataType::Int32) - .len(11) - .add_buffer(Buffer::from( - vec![1i32, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11].to_byte_slice(), - )) - .build() - .unwrap(); - - let data = ArrayData::builder(DataType::List(Box::new(Field::new( - "item", - DataType::Int32, - false, - )))) - .len(7) - .add_buffer(Buffer::from(vec![0, 0, 3, 5, 6, 9, 10, 11].to_byte_slice())) - .null_bit_buffer(Buffer::from(vec![0b01011010])) - .add_child_data(child_data.clone()) - .build() - .unwrap(); - - // Get the child logical null buffer. The child is non-nullable, but because the list has nulls, - // we expect the child to logically have some nulls, inherited from the parent: - // [1, 2, 3, null, null, 6, 7, 8, 9, null, 11] - let nulls = child_logical_null_buffer( - &data, - data.null_buffer(), - data.child_data().get(0).unwrap(), - ); - let expected = Some(Buffer::from(vec![0b11100111, 0b00000101])); - assert_eq!(nulls, expected); - - // test with offset - let data = ArrayData::builder(DataType::List(Box::new(Field::new( - "item", - DataType::Int32, - false, - )))) - .len(4) - .offset(3) - .add_buffer(Buffer::from(vec![0, 0, 3, 5, 6, 9, 10, 11].to_byte_slice())) - // the null_bit_buffer doesn't have an offset, i.e. 
cleared the 3 offset bits 0b[---]01011[010] - .null_bit_buffer(Buffer::from(vec![0b00001011])) - .add_child_data(child_data) - .build() - .unwrap(); - - let nulls = child_logical_null_buffer( - &data, - data.null_buffer(), - data.child_data().get(0).unwrap(), - ); - - let expected = Some(Buffer::from(vec![0b00101111])); - assert_eq!(nulls, expected); - } -} diff --git a/arrow/src/array/equal/variable_size.rs b/arrow/src/array/equal/variable_size.rs index ecb3bc2a3c20..f40f79e404ac 100644 --- a/arrow/src/array/equal/variable_size.rs +++ b/arrow/src/array/equal/variable_size.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use crate::buffer::Buffer; use crate::util::bit_util::get_bit; use crate::{ array::data::count_nulls, @@ -51,8 +50,6 @@ fn offset_value_equal( pub(super) fn variable_sized_equal( lhs: &ArrayData, rhs: &ArrayData, - lhs_nulls: Option<&Buffer>, - rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, @@ -64,8 +61,8 @@ pub(super) fn variable_sized_equal( let lhs_values = lhs.buffers()[1].as_slice(); let rhs_values = rhs.buffers()[1].as_slice(); - let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len); - let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len); + let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); + let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); if lhs_null_count == 0 && rhs_null_count == 0 @@ -86,13 +83,16 @@ pub(super) fn variable_sized_equal( let lhs_pos = lhs_start + i; let rhs_pos = rhs_start + i; - // the null bits can still be `None`, so we don't unwrap - let lhs_is_null = !lhs_nulls + // the null bits can still be `None`, indicating that the value is valid. + let lhs_is_null = !lhs + .null_buffer() .map(|v| get_bit(v.as_slice(), lhs.offset() + lhs_pos)) - .unwrap_or(false); - let rhs_is_null = !rhs_nulls + .unwrap_or(true); + + let rhs_is_null = !rhs + .null_buffer() .map(|v| get_bit(v.as_slice(), rhs.offset() + rhs_pos)) - .unwrap_or(false); + .unwrap_or(true); lhs_is_null || (lhs_is_null == rhs_is_null) diff --git a/arrow/src/array/equal_json.rs b/arrow/src/array/equal_json.rs index adc33a7a1cd3..9db1a4397cb8 100644 --- a/arrow/src/array/equal_json.rs +++ b/arrow/src/array/equal_json.rs @@ -251,7 +251,7 @@ impl PartialEq for Value { } } -impl JsonEqual for GenericBinaryArray { +impl JsonEqual for GenericBinaryArray { fn equals_json(&self, json: &[&Value]) -> bool { if self.len() != json.len() { return false; @@ -271,9 +271,7 @@ impl JsonEqual for GenericBinaryArray PartialEq - for GenericBinaryArray -{ +impl PartialEq for GenericBinaryArray { fn eq(&self, json: &Value) -> bool { match json { Value::Array(json_array) => self.equals_json_values(json_array), @@ -282,9 +280,7 @@ impl PartialEq } } -impl PartialEq> - for Value -{ +impl PartialEq> for Value { fn eq(&self, arrow: &GenericBinaryArray) -> bool { match self { Value::Array(json_array) => arrow.equals_json_values(json_array), @@ -293,7 +289,7 @@ impl PartialEq } } -impl JsonEqual for GenericStringArray { +impl JsonEqual for GenericStringArray { fn equals_json(&self, json: &[&Value]) -> bool { if self.len() != json.len() { return false; @@ -307,9 +303,7 @@ impl JsonEqual for GenericStringArray PartialEq - for GenericStringArray -{ +impl PartialEq for GenericStringArray { fn eq(&self, json: &Value) -> bool { match json { Value::Array(json_array) => self.equals_json_values(json_array), @@ -318,9 +312,7 @@ impl PartialEq } } -impl PartialEq> - for 
Value -{ +impl PartialEq> for Value { fn eq(&self, arrow: &GenericStringArray) -> bool { match self { Value::Array(json_array) => arrow.equals_json_values(json_array), @@ -378,7 +370,7 @@ impl JsonEqual for DecimalArray { self.is_valid(i) && (s .parse::() - .map_or_else(|_| false, |v| v == self.value(i))) + .map_or_else(|_| false, |v| v == self.value(i).as_i128())) } JNull => self.is_null(i), _ => false, @@ -441,6 +433,12 @@ impl PartialEq for NullArray { } } +impl JsonEqual for ArrayRef { + fn equals_json(&self, json: &[&Value]) -> bool { + self.as_ref().equals_json(json) + } +} + #[cfg(test)] mod tests { use super::*; @@ -931,11 +929,11 @@ mod tests { #[test] fn test_decimal_json_equal() { // Test the equal case - let mut builder = DecimalBuilder::new(30, 23, 6); - builder.append_value(1_000).unwrap(); - builder.append_null().unwrap(); - builder.append_value(-250).unwrap(); - let arrow_array: DecimalArray = builder.finish(); + let arrow_array = [Some(1_000), None, Some(-250)] + .iter() + .collect::() + .with_precision_and_scale(23, 6) + .unwrap(); let json_array: Value = serde_json::from_str( r#" [ @@ -950,10 +948,11 @@ mod tests { assert!(json_array.eq(&arrow_array)); // Test unequal case - builder.append_value(1_000).unwrap(); - builder.append_null().unwrap(); - builder.append_value(55).unwrap(); - let arrow_array: DecimalArray = builder.finish(); + let arrow_array = [Some(1_000), None, Some(55)] + .iter() + .collect::() + .with_precision_and_scale(23, 6) + .unwrap(); let json_array: Value = serde_json::from_str( r#" [ diff --git a/arrow/src/array/ffi.rs b/arrow/src/array/ffi.rs index 847649ce1264..12d6f440b78d 100644 --- a/arrow/src/array/ffi.rs +++ b/arrow/src/array/ffi.rs @@ -45,11 +45,14 @@ impl TryFrom for ffi::ArrowArray { #[cfg(test)] mod tests { + use crate::array::{DictionaryArray, FixedSizeListArray, Int32Array, StringArray}; + use crate::buffer::Buffer; use crate::error::Result; + use crate::util::bit_util; use crate::{ array::{ - Array, ArrayData, BooleanArray, Int64Array, StructArray, UInt32Array, - UInt64Array, + Array, ArrayData, BooleanArray, FixedSizeBinaryArray, Int64Array, + StructArray, UInt32Array, UInt64Array, }, datatypes::{DataType, Field}, ffi::ArrowArray, @@ -71,6 +74,11 @@ mod tests { let result = &ArrayData::try_from(d1)?; assert_eq!(result, expected); + + unsafe { + Arc::from_raw(array); + Arc::from_raw(schema); + } Ok(()) } @@ -127,4 +135,133 @@ mod tests { let data = array.data(); test_round_trip(data) } + + #[test] + fn test_dictionary() -> Result<()> { + let values = StringArray::from(vec![Some("foo"), Some("bar"), None]); + let keys = Int32Array::from(vec![ + Some(0), + Some(1), + None, + Some(1), + Some(1), + None, + Some(1), + Some(2), + Some(1), + None, + ]); + let array = DictionaryArray::try_new(&keys, &values)?; + + let data = array.data(); + test_round_trip(data) + } + + #[test] + fn test_fixed_size_binary() -> Result<()> { + let values = vec![vec![10, 10, 10], vec![20, 20, 20], vec![30, 30, 30]]; + let array = FixedSizeBinaryArray::try_from_iter(values.into_iter())?; + + let data = array.data(); + test_round_trip(data) + } + + #[test] + fn test_fixed_size_binary_with_nulls() -> Result<()> { + let values = vec![ + None, + Some(vec![10, 10, 10]), + None, + Some(vec![20, 20, 20]), + Some(vec![30, 30, 30]), + None, + ]; + let array = FixedSizeBinaryArray::try_from_sparse_iter(values.into_iter())?; + + let data = array.data(); + test_round_trip(data) + } + + #[test] + fn test_fixed_size_list() -> Result<()> { + let v: Vec = 
(0..9).into_iter().collect(); + let value_data = ArrayData::builder(DataType::Int64) + .len(9) + .add_buffer(Buffer::from_slice_ref(&v)) + .build()?; + let list_data_type = + DataType::FixedSizeList(Box::new(Field::new("f", DataType::Int64, false)), 3); + let list_data = ArrayData::builder(list_data_type) + .len(3) + .add_child_data(value_data) + .build()?; + let array = FixedSizeListArray::from(list_data); + + let data = array.data(); + test_round_trip(data) + } + + #[test] + fn test_fixed_size_list_with_nulls() -> Result<()> { + // 0100 0110 + let mut validity_bits: [u8; 1] = [0; 1]; + bit_util::set_bit(&mut validity_bits, 1); + bit_util::set_bit(&mut validity_bits, 2); + bit_util::set_bit(&mut validity_bits, 6); + + let v: Vec = (0..16).into_iter().collect(); + let value_data = ArrayData::builder(DataType::Int16) + .len(16) + .add_buffer(Buffer::from_slice_ref(&v)) + .build()?; + let list_data_type = + DataType::FixedSizeList(Box::new(Field::new("f", DataType::Int16, false)), 2); + let list_data = ArrayData::builder(list_data_type) + .len(8) + .null_bit_buffer(Some(Buffer::from(validity_bits))) + .add_child_data(value_data) + .build()?; + let array = FixedSizeListArray::from(list_data); + + let data = array.data(); + test_round_trip(data) + } + + #[test] + fn test_fixed_size_list_nested() -> Result<()> { + let v: Vec = (0..16).into_iter().collect(); + let value_data = ArrayData::builder(DataType::Int32) + .len(16) + .add_buffer(Buffer::from_slice_ref(&v)) + .build()?; + + let offsets: Vec = vec![0, 2, 4, 6, 8, 10, 12, 14, 16]; + let value_offsets = Buffer::from_slice_ref(&offsets); + let inner_list_data_type = + DataType::List(Box::new(Field::new("item", DataType::Int32, false))); + let inner_list_data = ArrayData::builder(inner_list_data_type.clone()) + .len(8) + .add_buffer(value_offsets) + .add_child_data(value_data) + .build()?; + + // 0000 0100 + let mut validity_bits: [u8; 1] = [0; 1]; + bit_util::set_bit(&mut validity_bits, 2); + + let list_data_type = DataType::FixedSizeList( + Box::new(Field::new("f", inner_list_data_type, false)), + 2, + ); + let list_data = ArrayData::builder(list_data_type) + .len(4) + .null_bit_buffer(Some(Buffer::from(validity_bits))) + .add_child_data(inner_list_data) + .build()?; + + let array = FixedSizeListArray::from(list_data); + + let data = array.data(); + test_round_trip(data) + } } diff --git a/arrow/src/array/iterator.rs b/arrow/src/array/iterator.rs index d97aa16744c1..bc70d1a2a8ed 100644 --- a/arrow/src/array/iterator.rs +++ b/arrow/src/array/iterator.rs @@ -18,9 +18,8 @@ use crate::datatypes::ArrowPrimitiveType; use super::{ - Array, ArrayRef, BinaryOffsetSizeTrait, BooleanArray, GenericBinaryArray, - GenericListArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray, - StringOffsetSizeTrait, + Array, ArrayRef, BooleanArray, DecimalArray, GenericBinaryArray, GenericListArray, + GenericStringArray, OffsetSizeTrait, PrimitiveArray, }; /// an iterator that returns Some(T) or None, that can be used on any PrimitiveArray @@ -172,14 +171,14 @@ impl<'a> std::iter::ExactSizeIterator for BooleanIter<'a> {} #[derive(Debug)] pub struct GenericStringIter<'a, T> where - T: StringOffsetSizeTrait, + T: OffsetSizeTrait, { array: &'a GenericStringArray, current: usize, current_end: usize, } -impl<'a, T: StringOffsetSizeTrait> GenericStringIter<'a, T> { +impl<'a, T: OffsetSizeTrait> GenericStringIter<'a, T> { /// create a new iterator pub fn new(array: &'a GenericStringArray) -> Self { GenericStringIter:: { @@ -190,7 +189,7 @@ impl<'a, T: 
StringOffsetSizeTrait> GenericStringIter<'a, T> { } } -impl<'a, T: StringOffsetSizeTrait> std::iter::Iterator for GenericStringIter<'a, T> { +impl<'a, T: OffsetSizeTrait> std::iter::Iterator for GenericStringIter<'a, T> { type Item = Option<&'a str>; fn next(&mut self) -> Option { @@ -219,9 +218,7 @@ impl<'a, T: StringOffsetSizeTrait> std::iter::Iterator for GenericStringIter<'a, } } -impl<'a, T: StringOffsetSizeTrait> std::iter::DoubleEndedIterator - for GenericStringIter<'a, T> -{ +impl<'a, T: OffsetSizeTrait> std::iter::DoubleEndedIterator for GenericStringIter<'a, T> { fn next_back(&mut self) -> Option { if self.current_end == self.current { None @@ -242,23 +239,20 @@ impl<'a, T: StringOffsetSizeTrait> std::iter::DoubleEndedIterator } /// all arrays have known size. -impl<'a, T: StringOffsetSizeTrait> std::iter::ExactSizeIterator - for GenericStringIter<'a, T> -{ -} +impl<'a, T: OffsetSizeTrait> std::iter::ExactSizeIterator for GenericStringIter<'a, T> {} /// an iterator that returns `Some(&[u8])` or `None`, for binary arrays #[derive(Debug)] pub struct GenericBinaryIter<'a, T> where - T: BinaryOffsetSizeTrait, + T: OffsetSizeTrait, { array: &'a GenericBinaryArray, current: usize, current_end: usize, } -impl<'a, T: BinaryOffsetSizeTrait> GenericBinaryIter<'a, T> { +impl<'a, T: OffsetSizeTrait> GenericBinaryIter<'a, T> { /// create a new iterator pub fn new(array: &'a GenericBinaryArray) -> Self { GenericBinaryIter:: { @@ -269,7 +263,7 @@ impl<'a, T: BinaryOffsetSizeTrait> GenericBinaryIter<'a, T> { } } -impl<'a, T: BinaryOffsetSizeTrait> std::iter::Iterator for GenericBinaryIter<'a, T> { +impl<'a, T: OffsetSizeTrait> std::iter::Iterator for GenericBinaryIter<'a, T> { type Item = Option<&'a [u8]>; fn next(&mut self) -> Option { @@ -298,9 +292,7 @@ impl<'a, T: BinaryOffsetSizeTrait> std::iter::Iterator for GenericBinaryIter<'a, } } -impl<'a, T: BinaryOffsetSizeTrait> std::iter::DoubleEndedIterator - for GenericBinaryIter<'a, T> -{ +impl<'a, T: OffsetSizeTrait> std::iter::DoubleEndedIterator for GenericBinaryIter<'a, T> { fn next_back(&mut self) -> Option { if self.current_end == self.current { None @@ -321,10 +313,7 @@ impl<'a, T: BinaryOffsetSizeTrait> std::iter::DoubleEndedIterator } /// all arrays have known size. 
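With `StringOffsetSizeTrait` and `BinaryOffsetSizeTrait` folded into the single `OffsetSizeTrait`, one generic helper can now serve both the 32-bit and 64-bit offset variants. A minimal sketch, not part of this diff, with an illustrative helper name:

```rust
use arrow::array::{GenericStringArray, LargeStringArray, OffsetSizeTrait, StringArray};

// The same code covers StringArray (i32 offsets) and LargeStringArray (i64 offsets).
fn first_non_null<O: OffsetSizeTrait>(array: &GenericStringArray<O>) -> Option<&str> {
    array.iter().flatten().next()
}

let small = StringArray::from(vec![None, Some("foo")]);
let large = LargeStringArray::from(vec![Some("bar"), None]);
assert_eq!(first_non_null(&small), Some("foo"));
assert_eq!(first_non_null(&large), Some("bar"));
```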
-impl<'a, T: BinaryOffsetSizeTrait> std::iter::ExactSizeIterator - for GenericBinaryIter<'a, T> -{ -} +impl<'a, T: OffsetSizeTrait> std::iter::ExactSizeIterator for GenericBinaryIter<'a, T> {} #[derive(Debug)] pub struct GenericListArrayIter<'a, S> @@ -403,6 +392,54 @@ impl<'a, S: OffsetSizeTrait> std::iter::ExactSizeIterator { } +/// an iterator that returns `Some(i128)` or `None`, that can be used on a +/// [`DecimalArray`] +#[derive(Debug)] +pub struct DecimalIter<'a> { + array: &'a DecimalArray, + current: usize, + current_end: usize, +} + +impl<'a> DecimalIter<'a> { + pub fn new(array: &'a DecimalArray) -> Self { + Self { + array, + current: 0, + current_end: array.len(), + } + } +} + +impl<'a> std::iter::Iterator for DecimalIter<'a> { + type Item = Option; + + fn next(&mut self) -> Option { + if self.current == self.current_end { + None + } else { + let old = self.current; + self.current += 1; + // TODO: Improve performance by avoiding bounds check here + // (by using adding a `value_unchecked, for example) + if self.array.is_null(old) { + Some(None) + } else { + Some(Some(self.array.value(old).as_i128())) + } + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let remain = self.current_end - self.current; + (remain, Some(remain)) + } +} + +/// iterator has known size. +impl<'a> std::iter::ExactSizeIterator for DecimalIter<'a> {} + #[cfg(test)] mod tests { use std::sync::Arc; diff --git a/arrow/src/array/mod.rs b/arrow/src/array/mod.rs index 63b8b6146389..f4375a0d5897 100644 --- a/arrow/src/array/mod.rs +++ b/arrow/src/array/mod.rs @@ -15,42 +15,66 @@ // specific language governing permissions and limitations // under the License. -//! The central type in Apache Arrow are arrays, represented -//! by the [`Array` trait](crate::array::Array). -//! An array represents a known-length sequence of values all -//! having the same type. +//! The central type in Apache Arrow are arrays, which are a known-length sequence of values +//! all having the same type. This module provides concrete implementations of each type, as +//! well as an [`Array`] trait that can be used for type-erasure. //! -//! Internally, those values are represented by one or several -//! [buffers](crate::buffer::Buffer), the number and meaning -//! of which depend on the array’s data type, as documented in -//! [the Arrow data layout specification](https://arrow.apache.org/docs/format/Columnar.html). -//! For example, the type `Int16Array` represents an Apache -//! Arrow array of 16-bit integers. +//! # Downcasting an Array //! -//! Those buffers consist of the value data itself and an -//! optional [bitmap buffer](crate::bitmap::Bitmap) that -//! indicates which array entries are null values. -//! The bitmap buffer can be entirely omitted if the array is -//! known to have zero null values. +//! Arrays are often passed around as a dynamically typed [`&dyn Array`] or [`ArrayRef`]. +//! For example, [`RecordBatch`](`crate::record_batch::RecordBatch`) stores columns as [`ArrayRef`]. //! -//! There are concrete implementations of this trait for each -//! data type, that help you access individual values of the -//! array. +//! Whilst these arrays can be passed directly to the [`compute`](crate::compute), +//! [`csv`](crate::csv), [`json`](crate::json), etc... APIs, it is often the case that you wish +//! to interact with the data directly. This requires downcasting to the concrete type of the array: +//! +//! ``` +//! # use arrow::array::{Array, Float32Array, Int32Array}; +//! # +//! 
fn sum_int32(array: &dyn Array) -> i32 { +//! let integers: &Int32Array = array.as_any().downcast_ref().unwrap(); +//! integers.iter().map(|val| val.unwrap_or_default()).sum() +//! } +//! +//! // Note: the values for positions corresponding to nulls will be arbitrary +//! fn as_f32_slice(array: &dyn Array) -> &[f32] { +//! array.as_any().downcast_ref::().unwrap().values() +//! } +//! ``` //! //! # Building an Array //! -//! Arrow's `Arrays` are immutable, but there is the trait -//! [`ArrayBuilder`](crate::array::ArrayBuilder) -//! that helps you with constructing new `Arrays`. As with the -//! `Array` trait, there are builder implementations for all -//! concrete array types. +//! Most [`Array`] implementations can be constructed directly from iterators or [`Vec`] //! -//! # Example //! ``` -//! extern crate arrow; +//! # use arrow::array::Int32Array; +//! # use arrow::array::StringArray; +//! # use arrow::array::ListArray; +//! # use arrow::datatypes::Int32Type; +//! # +//! Int32Array::from(vec![1, 2]); +//! Int32Array::from(vec![Some(1), None]); +//! Int32Array::from_iter([1, 2, 3, 4]); +//! Int32Array::from_iter([Some(1), Some(2), None, Some(4)]); +//! +//! StringArray::from(vec!["foo", "bar"]); +//! StringArray::from(vec![Some("foo"), None]); +//! StringArray::from_iter([Some("foo"), None]); +//! StringArray::from_iter_values(["foo", "bar"]); //! -//! use arrow::array::Int16Array; +//! ListArray::from_iter_primitive::([ +//! Some(vec![Some(1), None, Some(3)]), +//! None, +//! Some(vec![]) +//! ]); +//! ``` //! +//! Additionally [`ArrayBuilder`](crate::array::ArrayBuilder) implementations can be +//! used to construct arrays with a push-based interface +//! +//! ``` +//! # use arrow::array::Int16Array; +//! # //! // Create a new builder with a capacity of 100 //! let mut builder = Int16Array::builder(100); //! @@ -80,6 +104,43 @@ //! "Get slice of len 2 starting at idx 3" //! ) //! ``` +//! +//! # Zero-Copy Slicing +//! +//! Given an [`Array`] of arbitrary length, it is possible to create an owned slice of this +//! data. Internally this just increments some ref-counts, and so is incredibly cheap +//! +//! ```rust +//! # use std::sync::Arc; +//! # use arrow::array::{Array, Int32Array, ArrayRef}; +//! let array = Arc::new(Int32Array::from_iter([1, 2, 3])) as ArrayRef; +//! +//! // Slice with offset 1 and length 2 +//! let sliced = array.slice(1, 2); +//! let ints = sliced.as_any().downcast_ref::().unwrap(); +//! assert_eq!(ints.values(), &[2, 3]); +//! ``` +//! +//! # Internal Representation +//! +//! Internally, arrays are represented by one or several [`Buffer`], the number and meaning of +//! which depend on the array’s data type, as documented in the [Arrow specification]. +//! +//! For example, the type `Int16Array` represents an array of 16-bit integers and consists of: +//! +//! * An optional [`Bitmap`] identifying any null values +//! * A contiguous [`Buffer`] of 16-bit integers +//! +//! Similarly, the type `StringArray` represents an array of UTF-8 strings and consists of: +//! +//! * An optional [`Bitmap`] identifying any null values +//! * An offsets [`Buffer`] of 32-bit integers identifying valid UTF-8 sequences within the values buffer +//! * A values [`Buffer`] of UTF-8 encoded string data +//! +//! [Arrow specification]: https://arrow.apache.org/docs/format/Columnar.html +//! [`&dyn Array`]: Array +//! [`Bitmap`]: crate::bitmap::Bitmap +//! 
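The new "Internal Representation" section lists the buffers behind `StringArray` without showing them; a small sketch (not part of the diff) of inspecting them through `ArrayData` may make the layout concrete:

```rust
use arrow::array::{Array, StringArray};

let array = StringArray::from(vec![Some("foo"), None, Some("bar")]);
let data = array.data();

// Two buffers: 32-bit offsets and the UTF-8 values, plus a validity bitmap.
assert_eq!(data.buffers().len(), 2);
assert_eq!(data.buffers()[0].len(), 4 * std::mem::size_of::<i32>()); // offsets [0, 3, 3, 6]
assert_eq!(data.buffers()[1].as_slice(), b"foobar");
assert_eq!(data.null_count(), 1);
```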
[`Buffer`]: crate::buffer::Buffer #[allow(clippy::module_inception)] mod array; @@ -194,6 +255,14 @@ pub type UInt64Array = PrimitiveArray; /// /// # Example: Using `collect` /// ``` +/// # use arrow::array::Float16Array; +/// use half::f16; +/// let arr : Float16Array = [Some(f16::from_f64(1.0)), Some(f16::from_f64(2.0))].into_iter().collect(); +/// ``` +pub type Float16Array = PrimitiveArray; +/// +/// # Example: Using `collect` +/// ``` /// # use arrow::array::Float32Array; /// let arr : Float32Array = [Some(1.0), Some(2.0)].into_iter().collect(); /// ``` @@ -318,10 +387,58 @@ pub type UInt32DictionaryArray = DictionaryArray; /// assert_eq!(array.values(), &values); /// ``` pub type UInt64DictionaryArray = DictionaryArray; - +/// +/// A primitive array where each element is of type [TimestampSecondType]. +/// See also [`Timestamp`](crate::datatypes::DataType::Timestamp). +/// +/// # Example: UTC timestamps post epoch +/// ``` +/// # use arrow::array::TimestampSecondArray; +/// use chrono::FixedOffset; +/// // Corresponds to single element array with entry 1970-05-09T14:25:11+0:00 +/// let arr = TimestampSecondArray::from_vec(vec![11111111], None); +/// // OR +/// let arr = TimestampSecondArray::from_opt_vec(vec![Some(11111111)], None); +/// let utc_offset = FixedOffset::east(0); +/// +/// assert_eq!(arr.value_as_datetime_with_tz(0, utc_offset).map(|v| v.to_string()).unwrap(), "1970-05-09 14:25:11") +/// ``` +/// +/// # Example: UTC timestamps pre epoch +/// ``` +/// # use arrow::array::TimestampSecondArray; +/// use chrono::FixedOffset; +/// // Corresponds to single element array with entry 1969-08-25T09:34:49+0:00 +/// let arr = TimestampSecondArray::from_vec(vec![-11111111], None); +/// // OR +/// let arr = TimestampSecondArray::from_opt_vec(vec![Some(-11111111)], None); +/// let utc_offset = FixedOffset::east(0); +/// +/// assert_eq!(arr.value_as_datetime_with_tz(0, utc_offset).map(|v| v.to_string()).unwrap(), "1969-08-25 09:34:49") +/// ``` +/// +/// # Example: With timezone specified +/// ``` +/// # use arrow::array::TimestampSecondArray; +/// use chrono::FixedOffset; +/// // Corresponds to single element array with entry 1970-05-10T00:25:11+10:00 +/// let arr = TimestampSecondArray::from_vec(vec![11111111], Some("+10:00".to_string())); +/// // OR +/// let arr = TimestampSecondArray::from_opt_vec(vec![Some(11111111)], Some("+10:00".to_string())); +/// let sydney_offset = FixedOffset::east(10 * 60 * 60); +/// +/// assert_eq!(arr.value_as_datetime_with_tz(0, sydney_offset).map(|v| v.to_string()).unwrap(), "1970-05-10 00:25:11") +/// ``` +/// pub type TimestampSecondArray = PrimitiveArray; +/// A primitive array where each element is of type `TimestampMillisecondType.` +/// See examples for [`TimestampSecondArray.`](crate::array::TimestampSecondArray) pub type TimestampMillisecondArray = PrimitiveArray; +/// A primitive array where each element is of type `TimestampMicrosecondType.` +/// See examples for [`TimestampSecondArray.`](crate::array::TimestampSecondArray) pub type TimestampMicrosecondArray = PrimitiveArray; +/// A primitive array where each element is of type `TimestampNanosecondType.` +/// See examples for [`TimestampSecondArray.`](crate::array::TimestampSecondArray) pub type TimestampNanosecondArray = PrimitiveArray; pub type Date32Array = PrimitiveArray; pub type Date64Array = PrimitiveArray; @@ -331,20 +448,20 @@ pub type Time64MicrosecondArray = PrimitiveArray; pub type Time64NanosecondArray = PrimitiveArray; pub type IntervalYearMonthArray = PrimitiveArray; pub type 
IntervalDayTimeArray = PrimitiveArray; +pub type IntervalMonthDayNanoArray = PrimitiveArray; pub type DurationSecondArray = PrimitiveArray; pub type DurationMillisecondArray = PrimitiveArray; pub type DurationMicrosecondArray = PrimitiveArray; pub type DurationNanosecondArray = PrimitiveArray; -pub use self::array_binary::BinaryOffsetSizeTrait; pub use self::array_binary::GenericBinaryArray; pub use self::array_list::GenericListArray; pub use self::array_list::OffsetSizeTrait; pub use self::array_string::GenericStringArray; -pub use self::array_string::StringOffsetSizeTrait; // --------------------- Array Builder --------------------- +pub use self::builder::make_builder; pub use self::builder::BooleanBufferBuilder; pub use self::builder::BufferBuilder; @@ -359,22 +476,38 @@ pub type UInt64BufferBuilder = BufferBuilder; pub type Float32BufferBuilder = BufferBuilder; pub type Float64BufferBuilder = BufferBuilder; -pub type TimestampSecondBufferBuilder = BufferBuilder; -pub type TimestampMillisecondBufferBuilder = BufferBuilder; -pub type TimestampMicrosecondBufferBuilder = BufferBuilder; -pub type TimestampNanosecondBufferBuilder = BufferBuilder; -pub type Date32BufferBuilder = BufferBuilder; -pub type Date64BufferBuilder = BufferBuilder; -pub type Time32SecondBufferBuilder = BufferBuilder; -pub type Time32MillisecondBufferBuilder = BufferBuilder; -pub type Time64MicrosecondBufferBuilder = BufferBuilder; -pub type Time64NanosecondBufferBuilder = BufferBuilder; -pub type IntervalYearMonthBufferBuilder = BufferBuilder; -pub type IntervalDayTimeBufferBuilder = BufferBuilder; -pub type DurationSecondBufferBuilder = BufferBuilder; -pub type DurationMillisecondBufferBuilder = BufferBuilder; -pub type DurationMicrosecondBufferBuilder = BufferBuilder; -pub type DurationNanosecondBufferBuilder = BufferBuilder; +pub type TimestampSecondBufferBuilder = + BufferBuilder<::Native>; +pub type TimestampMillisecondBufferBuilder = + BufferBuilder<::Native>; +pub type TimestampMicrosecondBufferBuilder = + BufferBuilder<::Native>; +pub type TimestampNanosecondBufferBuilder = + BufferBuilder<::Native>; +pub type Date32BufferBuilder = BufferBuilder<::Native>; +pub type Date64BufferBuilder = BufferBuilder<::Native>; +pub type Time32SecondBufferBuilder = + BufferBuilder<::Native>; +pub type Time32MillisecondBufferBuilder = + BufferBuilder<::Native>; +pub type Time64MicrosecondBufferBuilder = + BufferBuilder<::Native>; +pub type Time64NanosecondBufferBuilder = + BufferBuilder<::Native>; +pub type IntervalYearMonthBufferBuilder = + BufferBuilder<::Native>; +pub type IntervalDayTimeBufferBuilder = + BufferBuilder<::Native>; +pub type IntervalMonthDayNanoBufferBuilder = + BufferBuilder<::Native>; +pub type DurationSecondBufferBuilder = + BufferBuilder<::Native>; +pub type DurationMillisecondBufferBuilder = + BufferBuilder<::Native>; +pub type DurationMicrosecondBufferBuilder = + BufferBuilder<::Native>; +pub type DurationNanosecondBufferBuilder = + BufferBuilder<::Native>; pub use self::builder::ArrayBuilder; pub use self::builder::BinaryBuilder; @@ -382,11 +515,13 @@ pub use self::builder::BooleanBuilder; pub use self::builder::DecimalBuilder; pub use self::builder::FixedSizeBinaryBuilder; pub use self::builder::FixedSizeListBuilder; +pub use self::builder::GenericListBuilder; pub use self::builder::GenericStringBuilder; pub use self::builder::LargeBinaryBuilder; pub use self::builder::LargeListBuilder; pub use self::builder::LargeStringBuilder; pub use self::builder::ListBuilder; +pub use self::builder::MapBuilder; 
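The buffer builder aliases are now expressed through each type's `Native` representation, so they behave exactly like the plain integer builders. A short usage sketch (illustrative, not taken from the diff):

```rust
use arrow::array::TimestampSecondBufferBuilder;

// TimestampSecondType::Native is i64, so this alias is just a BufferBuilder<i64>.
let mut builder = TimestampSecondBufferBuilder::new(8);
builder.append(1_620_000_000);        // seconds since the epoch
builder.append_slice(&[0, 60, 3600]);
let buffer = builder.finish();
assert_eq!(buffer.len(), 4 * std::mem::size_of::<i64>());
```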
pub use self::builder::PrimitiveBuilder; pub use self::builder::PrimitiveDictionaryBuilder; pub use self::builder::StringBuilder; @@ -417,6 +552,7 @@ pub type Time64MicrosecondBuilder = PrimitiveBuilder; pub type Time64NanosecondBuilder = PrimitiveBuilder; pub type IntervalYearMonthBuilder = PrimitiveBuilder; pub type IntervalDayTimeBuilder = PrimitiveBuilder; +pub type IntervalMonthDayNanoBuilder = PrimitiveBuilder; pub type DurationSecondBuilder = PrimitiveBuilder; pub type DurationMillisecondBuilder = PrimitiveBuilder; pub type DurationMicrosecondBuilder = PrimitiveBuilder; @@ -439,11 +575,46 @@ pub use self::ord::{build_compare, DynComparator}; // --------------------- Array downcast helper functions --------------------- pub use self::cast::{ - as_boolean_array, as_dictionary_array, as_generic_binary_array, + as_boolean_array, as_decimal_array, as_dictionary_array, as_generic_binary_array, as_generic_list_array, as_large_list_array, as_largestring_array, as_list_array, - as_null_array, as_primitive_array, as_string_array, as_struct_array, + as_map_array, as_null_array, as_primitive_array, as_string_array, as_struct_array, + as_union_array, }; // ------------------------------ C Data Interface --------------------------- -pub use self::array::make_array_from_raw; +pub use self::array::{export_array_into_raw, make_array_from_raw}; + +#[cfg(test)] +mod tests { + use crate::array::*; + + #[test] + fn test_buffer_builder_availability() { + let _builder = Int8BufferBuilder::new(10); + let _builder = Int16BufferBuilder::new(10); + let _builder = Int32BufferBuilder::new(10); + let _builder = Int64BufferBuilder::new(10); + let _builder = UInt16BufferBuilder::new(10); + let _builder = UInt32BufferBuilder::new(10); + let _builder = Float32BufferBuilder::new(10); + let _builder = Float64BufferBuilder::new(10); + let _builder = TimestampSecondBufferBuilder::new(10); + let _builder = TimestampMillisecondBufferBuilder::new(10); + let _builder = TimestampMicrosecondBufferBuilder::new(10); + let _builder = TimestampNanosecondBufferBuilder::new(10); + let _builder = Date32BufferBuilder::new(10); + let _builder = Date64BufferBuilder::new(10); + let _builder = Time32SecondBufferBuilder::new(10); + let _builder = Time32MillisecondBufferBuilder::new(10); + let _builder = Time64MicrosecondBufferBuilder::new(10); + let _builder = Time64NanosecondBufferBuilder::new(10); + let _builder = IntervalYearMonthBufferBuilder::new(10); + let _builder = IntervalDayTimeBufferBuilder::new(10); + let _builder = IntervalMonthDayNanoBufferBuilder::new(10); + let _builder = DurationSecondBufferBuilder::new(10); + let _builder = DurationMillisecondBufferBuilder::new(10); + let _builder = DurationMicrosecondBufferBuilder::new(10); + let _builder = DurationNanosecondBufferBuilder::new(10); + } +} diff --git a/arrow/src/array/ord.rs b/arrow/src/array/ord.rs index d6534efc9286..be910f96bd54 100644 --- a/arrow/src/array/ord.rs +++ b/arrow/src/array/ord.rs @@ -72,7 +72,7 @@ where fn compare_string(left: &dyn Array, right: &dyn Array) -> DynComparator where - T: StringOffsetSizeTrait, + T: OffsetSizeTrait, { let left: StringArray = StringArray::from(left.data().clone()); let right: StringArray = StringArray::from(right.data().clone()); @@ -174,6 +174,9 @@ pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result { compare_primitives::(left, right) } + (Interval(MonthDayNano), Interval(MonthDayNano)) => { + compare_primitives::(left, right) + } (Duration(Second), Duration(Second)) => { compare_primitives::(left, right) } @@ 
-222,6 +225,11 @@ pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result { + let left: DecimalArray = DecimalArray::from(left.data().clone()); + let right: DecimalArray = DecimalArray::from(right.data().clone()); + Box::new(move |i, j| left.value(i).cmp(&right.value(j))) + } (lhs, _) => { return Err(ArrowError::InvalidArgumentError(format!( "The data type type {:?} has no natural order", @@ -290,6 +298,20 @@ pub mod tests { Ok(()) } + #[test] + fn test_decimal() -> Result<()> { + let array = vec![Some(5), Some(2), Some(3)] + .iter() + .collect::() + .with_precision_and_scale(23, 6) + .unwrap(); + + let cmp = build_compare(&array, &array)?; + assert_eq!(Ordering::Less, (cmp)(1, 0)); + assert_eq!(Ordering::Greater, (cmp)(0, 2)); + Ok(()) + } + #[test] fn test_dict() -> Result<()> { let data = vec!["a", "b", "c", "a", "a", "c", "c"]; diff --git a/arrow/src/array/raw_pointer.rs b/arrow/src/array/raw_pointer.rs index cd1f802ae778..1016b808bc5a 100644 --- a/arrow/src/array/raw_pointer.rs +++ b/arrow/src/array/raw_pointer.rs @@ -17,8 +17,10 @@ use std::ptr::NonNull; -/// This struct is highly `unsafe` and offers the possibility to self-reference a [arrow::Buffer] from [arrow::array::ArrayData]. -/// as a pointer to the beginning of its contents. +/// This struct is highly `unsafe` and offers the possibility to +/// self-reference a [crate::buffer::Buffer] from +/// [crate::array::ArrayData], as a pointer to the beginning of its +/// contents. pub(super) struct RawPtrBox { ptr: NonNull, } diff --git a/arrow/src/array/transform/boolean.rs b/arrow/src/array/transform/boolean.rs index 182914971732..e0b6231a226e 100644 --- a/arrow/src/array/transform/boolean.rs +++ b/arrow/src/array/transform/boolean.rs @@ -15,12 +15,9 @@ // specific language governing permissions and limitations // under the License. +use super::{Extend, _MutableArrayData, utils::resize_for_bits}; use crate::array::ArrayData; - -use super::{ - Extend, _MutableArrayData, - utils::{resize_for_bits, set_bits}, -}; +use crate::util::bit_mask::set_bits; pub(super) fn build_extend(array: &ArrayData) -> Extend { let values = array.buffers()[0].as_slice(); @@ -29,7 +26,7 @@ pub(super) fn build_extend(array: &ArrayData) -> Extend { let buffer = &mut mutable.buffer1; resize_for_bits(buffer, mutable.len + len); set_bits( - &mut buffer.as_slice_mut(), + buffer.as_slice_mut(), values, mutable.len, array.offset() + start, diff --git a/arrow/src/array/transform/fixed_size_list.rs b/arrow/src/array/transform/fixed_size_list.rs new file mode 100644 index 000000000000..77912a7026fd --- /dev/null +++ b/arrow/src/array/transform/fixed_size_list.rs @@ -0,0 +1,73 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
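`build_compare` now also covers `Interval(MonthDayNano)` and `Decimal`. The new decimal test compares an array against itself; it may be worth noting that the returned `DynComparator` can equally index into two distinct arrays, as in this sketch (not part of the diff):

```rust
use std::cmp::Ordering;
use arrow::array::{build_compare, Int32Array};

let left = Int32Array::from(vec![1, 5, 3]);
let right = Int32Array::from(vec![2, 2, 2]);

// `i` indexes into `left`, `j` indexes into `right`.
let cmp = build_compare(&left, &right).unwrap();
assert_eq!(Ordering::Less, (cmp)(0, 0));    // 1 < 2
assert_eq!(Ordering::Greater, (cmp)(1, 1)); // 5 > 2
```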
+ +use crate::array::ArrayData; +use crate::datatypes::DataType; + +use super::{Extend, _MutableArrayData}; + +pub(super) fn build_extend(array: &ArrayData) -> Extend { + let size = match array.data_type() { + DataType::FixedSizeList(_, i) => *i as usize, + _ => unreachable!(), + }; + + if array.null_count() == 0 { + Box::new( + move |mutable: &mut _MutableArrayData, + index: usize, + start: usize, + len: usize| { + mutable.child_data.iter_mut().for_each(|child| { + child.extend(index, start * size, (start + len) * size) + }) + }, + ) + } else { + Box::new( + move |mutable: &mut _MutableArrayData, + index: usize, + start: usize, + len: usize| { + (start..start + len).for_each(|i| { + if array.is_valid(i) { + mutable.child_data.iter_mut().for_each(|child| { + child.extend(index, i * size, (i + 1) * size) + }) + } else { + mutable + .child_data + .iter_mut() + .for_each(|child| child.extend_nulls(size)) + } + }) + }, + ) + } +} + +pub(super) fn extend_nulls(mutable: &mut _MutableArrayData, len: usize) { + let size = match mutable.data_type { + DataType::FixedSizeList(_, i) => i as usize, + _ => unreachable!(), + }; + + mutable + .child_data + .iter_mut() + .for_each(|child| child.extend_nulls(len * size)) +} diff --git a/arrow/src/array/transform/mod.rs b/arrow/src/array/transform/mod.rs index a598f0d7167e..68ae7f6d4d0d 100644 --- a/arrow/src/array/transform/mod.rs +++ b/arrow/src/array/transform/mod.rs @@ -15,26 +15,27 @@ // specific language governing permissions and limitations // under the License. +use super::{ + data::{into_buffers, new_buffers}, + ArrayData, ArrayDataBuilder, OffsetSizeTrait, +}; use crate::{ buffer::MutableBuffer, datatypes::DataType, error::{ArrowError, Result}, util::bit_util, }; +use half::f16; use std::mem; -use super::{ - data::{into_buffers, new_buffers}, - ArrayData, ArrayDataBuilder, -}; -use crate::array::StringOffsetSizeTrait; - mod boolean; mod fixed_binary; +mod fixed_size_list; mod list; mod null; mod primitive; mod structure; +mod union; mod utils; mod variable_size; @@ -77,18 +78,13 @@ impl<'a> _MutableArrayData<'a> { } }; - let mut array_data_builder = ArrayDataBuilder::new(self.data_type) + ArrayDataBuilder::new(self.data_type) .offset(0) .len(self.len) .null_count(self.null_count) .buffers(buffers) - .child_data(child_data); - if self.null_count > 0 { - array_data_builder = - array_data_builder.null_bit_buffer(self.null_buffer.into()); - } - - array_data_builder + .child_data(child_data) + .null_bit_buffer((self.null_count > 0).then(|| self.null_buffer.into())) } } @@ -97,7 +93,7 @@ fn build_extend_null_bits(array: &ArrayData, use_nulls: bool) -> ExtendNullBits let bytes = bitmap.bits.as_slice(); Box::new(move |mutable, start, len| { utils::resize_for_bits(&mut mutable.null_buffer, mutable.len + len); - mutable.null_count += utils::set_bits( + mutable.null_count += crate::util::bit_mask::set_bits( mutable.null_buffer.as_slice_mut(), bytes, mutable.len, @@ -140,6 +136,7 @@ fn build_extend_null_bits(array: &ArrayData, use_nulls: bool) -> ExtendNullBits /// assert_eq!(Int32Array::from(vec![2, 3, 1, 2, 3]), new_array); /// ``` pub struct MutableArrayData<'a> { + #[allow(dead_code)] arrays: Vec<&'a ArrayData>, // The attributes in [_MutableArrayData] cannot be in [MutableArrayData] due to // mutability invariants (interior mutability): @@ -182,50 +179,23 @@ fn build_extend_dictionary( max: usize, ) -> Option { use crate::datatypes::*; - use std::convert::TryInto; - + macro_rules! 
validate_and_build { + ($dt: ty) => {{ + let _: $dt = max.try_into().ok()?; + let offset: $dt = offset.try_into().ok()?; + Some(primitive::build_extend_with_offset(array, offset)) + }}; + } match array.data_type() { DataType::Dictionary(child_data_type, _) => match child_data_type.as_ref() { - DataType::UInt8 => { - let _: u8 = max.try_into().ok()?; - let offset: u8 = offset.try_into().ok()?; - Some(primitive::build_extend_with_offset(array, offset)) - } - DataType::UInt16 => { - let _: u16 = max.try_into().ok()?; - let offset: u16 = offset.try_into().ok()?; - Some(primitive::build_extend_with_offset(array, offset)) - } - DataType::UInt32 => { - let _: u32 = max.try_into().ok()?; - let offset: u32 = offset.try_into().ok()?; - Some(primitive::build_extend_with_offset(array, offset)) - } - DataType::UInt64 => { - let _: u64 = max.try_into().ok()?; - let offset: u64 = offset.try_into().ok()?; - Some(primitive::build_extend_with_offset(array, offset)) - } - DataType::Int8 => { - let _: i8 = max.try_into().ok()?; - let offset: i8 = offset.try_into().ok()?; - Some(primitive::build_extend_with_offset(array, offset)) - } - DataType::Int16 => { - let _: i16 = max.try_into().ok()?; - let offset: i16 = offset.try_into().ok()?; - Some(primitive::build_extend_with_offset(array, offset)) - } - DataType::Int32 => { - let _: i32 = max.try_into().ok()?; - let offset: i32 = offset.try_into().ok()?; - Some(primitive::build_extend_with_offset(array, offset)) - } - DataType::Int64 => { - let _: i64 = max.try_into().ok()?; - let offset: i64 = offset.try_into().ok()?; - Some(primitive::build_extend_with_offset(array, offset)) - } + DataType::UInt8 => validate_and_build!(u8), + DataType::UInt16 => validate_and_build!(u16), + DataType::UInt32 => validate_and_build!(u32), + DataType::UInt64 => validate_and_build!(u64), + DataType::Int8 => validate_and_build!(i8), + DataType::Int16 => validate_and_build!(i16), + DataType::Int32 => validate_and_build!(i32), + DataType::Int64 => validate_and_build!(i64), _ => unreachable!(), }, _ => None, @@ -235,6 +205,7 @@ fn build_extend_dictionary( fn build_extend(array: &ArrayData) -> Extend { use crate::datatypes::*; match array.data_type() { + DataType::Decimal(_, _) => primitive::build_extend::(array), DataType::Null => null::build_extend(array), DataType::Boolean => boolean::build_extend(array), DataType::UInt8 => primitive::build_extend::(array), @@ -259,27 +230,31 @@ fn build_extend(array: &ArrayData) -> Extend { | DataType::Interval(IntervalUnit::DayTime) => { primitive::build_extend::(array) } + DataType::Interval(IntervalUnit::MonthDayNano) => { + primitive::build_extend::(array) + } DataType::Utf8 | DataType::Binary => variable_size::build_extend::(array), DataType::LargeUtf8 | DataType::LargeBinary => { variable_size::build_extend::(array) } - DataType::List(_) => list::build_extend::(array), + DataType::Map(_, _) | DataType::List(_) => list::build_extend::(array), DataType::LargeList(_) => list::build_extend::(array), DataType::Dictionary(_, _) => unreachable!("should use build_extend_dictionary"), DataType::Struct(_) => structure::build_extend(array), DataType::FixedSizeBinary(_) => fixed_binary::build_extend(array), - DataType::Float16 => unreachable!(), - /* - DataType::FixedSizeList(_, _) => {} - DataType::Union(_) => {} - */ - _ => todo!("Take and filter operations still not supported for this datatype"), + DataType::Float16 => primitive::build_extend::(array), + DataType::FixedSizeList(_, _) => fixed_size_list::build_extend(array), + DataType::Union(_, _, mode) 
=> match mode { + UnionMode::Sparse => union::build_extend_sparse(array), + UnionMode::Dense => union::build_extend_dense(array), + }, } } fn build_extend_nulls(data_type: &DataType) -> ExtendNulls { use crate::datatypes::*; Box::new(match data_type { + DataType::Decimal(_, _) => primitive::extend_nulls::, DataType::Null => null::extend_nulls, DataType::Boolean => boolean::extend_nulls, DataType::UInt8 => primitive::extend_nulls::, @@ -300,9 +275,10 @@ fn build_extend_nulls(data_type: &DataType) -> ExtendNulls { | DataType::Timestamp(_, _) | DataType::Duration(_) | DataType::Interval(IntervalUnit::DayTime) => primitive::extend_nulls::, + DataType::Interval(IntervalUnit::MonthDayNano) => primitive::extend_nulls::, DataType::Utf8 | DataType::Binary => variable_size::extend_nulls::, DataType::LargeUtf8 | DataType::LargeBinary => variable_size::extend_nulls::, - DataType::List(_) => list::extend_nulls::, + DataType::Map(_, _) | DataType::List(_) => list::extend_nulls::, DataType::LargeList(_) => list::extend_nulls::, DataType::Dictionary(child_data_type, _) => match child_data_type.as_ref() { DataType::UInt8 => primitive::extend_nulls::, @@ -317,23 +293,23 @@ fn build_extend_nulls(data_type: &DataType) -> ExtendNulls { }, DataType::Struct(_) => structure::extend_nulls, DataType::FixedSizeBinary(_) => fixed_binary::extend_nulls, - DataType::Float16 => unreachable!(), - /* - DataType::FixedSizeList(_, _) => {} - DataType::Union(_) => {} - */ - _ => todo!("Take and filter operations still not supported for this datatype"), + DataType::Float16 => primitive::extend_nulls::, + DataType::FixedSizeList(_, _) => fixed_size_list::extend_nulls, + DataType::Union(_, _, mode) => match mode { + UnionMode::Sparse => union::extend_nulls_sparse, + UnionMode::Dense => union::extend_nulls_dense, + }, }) } -fn preallocate_offset_and_binary_buffer( +fn preallocate_offset_and_binary_buffer( capacity: usize, binary_size: usize, ) -> [MutableBuffer; 2] { // offsets let mut buffer = MutableBuffer::new((1 + capacity) * mem::size_of::()); // safety: `unsafe` code assumes that this buffer is initialized with one element - if Offset::is_large() { + if Offset::IS_LARGE { buffer.push(0i64); } else { buffer.push(0i32) @@ -380,15 +356,15 @@ impl<'a> MutableArrayData<'a> { Self::with_capacities(arrays, use_nulls, Capacities::Array(capacity)) } - /// Similar to [MutableArray::new], but lets users define the preallocated capacities of the array. - /// See also [MutableArray::new] for more information on the arguments. + /// Similar to [MutableArrayData::new], but lets users define the preallocated capacities of the array. + /// See also [MutableArrayData::new] for more information on the arguments. /// /// # Panic /// This function panics if the given `capacities` don't match the data type of `arrays`. Or when /// a [Capacities] variant is not yet supported. pub fn with_capacities( arrays: Vec<&'a ArrayData>, - mut use_nulls: bool, + use_nulls: bool, capacities: Capacities, ) -> Self { let data_type = arrays[0].data_type(); @@ -396,20 +372,22 @@ impl<'a> MutableArrayData<'a> { // if any of the arrays has nulls, insertions from any array requires setting bits // as there is at least one array with nulls. 
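In practical terms, the null handling below means the output carries a validity bitmap whenever any source array has one, regardless of the `use_nulls` flag the caller passes. A minimal sketch of the public flow (illustrative, not from the diff):

```rust
use arrow::array::{Array, Int32Array, MutableArrayData};

let a = Int32Array::from(vec![Some(1), None, Some(3)]);
let b = Int32Array::from(vec![Some(10), Some(20)]);

// `use_nulls` is false, but source `a` contains nulls, so the result still
// ends up with a validity bitmap.
let mut mutable = MutableArrayData::new(vec![a.data(), b.data()], false, 5);
mutable.extend(0, 0, 3); // copy a[0..3]
mutable.extend(1, 1, 2); // copy b[1..2]
let result = Int32Array::from(mutable.freeze());
assert_eq!(result, Int32Array::from(vec![Some(1), None, Some(3), Some(20)]));
```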
- if arrays.iter().any(|array| array.null_count() > 0) { - use_nulls = true; - }; + let use_nulls = use_nulls | arrays.iter().any(|array| array.null_count() > 0); let mut array_capacity; let [buffer1, buffer2] = match (data_type, &capacities) { - (DataType::LargeUtf8, Capacities::Binary(capacity, Some(value_cap))) - | (DataType::LargeBinary, Capacities::Binary(capacity, Some(value_cap))) => { + ( + DataType::LargeUtf8 | DataType::LargeBinary, + Capacities::Binary(capacity, Some(value_cap)), + ) => { array_capacity = *capacity; preallocate_offset_and_binary_buffer::(*capacity, *value_cap) } - (DataType::Utf8, Capacities::Binary(capacity, Some(value_cap))) - | (DataType::Binary, Capacities::Binary(capacity, Some(value_cap))) => { + ( + DataType::Utf8 | DataType::Binary, + Capacities::Binary(capacity, Some(value_cap)), + ) => { array_capacity = *capacity; preallocate_offset_and_binary_buffer::(*capacity, *value_cap) } @@ -417,11 +395,19 @@ impl<'a> MutableArrayData<'a> { array_capacity = *capacity; new_buffers(data_type, *capacity) } + ( + DataType::List(_) | DataType::LargeList(_), + Capacities::List(capacity, _), + ) => { + array_capacity = *capacity; + new_buffers(data_type, *capacity) + } _ => panic!("Capacities: {:?} not yet supported", capacities), }; let child_data = match &data_type { - DataType::Null + DataType::Decimal(_, _) + | DataType::Null | DataType::Boolean | DataType::UInt8 | DataType::UInt16 @@ -431,6 +417,7 @@ impl<'a> MutableArrayData<'a> { | DataType::Int16 | DataType::Int32 | DataType::Int64 + | DataType::Float16 | DataType::Float32 | DataType::Float64 | DataType::Date32 @@ -445,7 +432,7 @@ impl<'a> MutableArrayData<'a> { | DataType::LargeBinary | DataType::Interval(_) | DataType::FixedSizeBinary(_) => vec![], - DataType::List(_) | DataType::LargeList(_) => { + DataType::Map(_, _) | DataType::List(_) | DataType::LargeList(_) => { let childs = arrays .iter() .map(|array| &array.child_data()[0]) @@ -454,11 +441,10 @@ impl<'a> MutableArrayData<'a> { let capacities = if let Capacities::List(capacity, ref child_capacities) = capacities { - array_capacity = capacity; child_capacities .clone() .map(|c| *c) - .unwrap_or(Capacities::Array(array_capacity)) + .unwrap_or(Capacities::Array(capacity)) } else { Capacities::Array(array_capacity) }; @@ -469,7 +455,6 @@ impl<'a> MutableArrayData<'a> { } // the dictionary type just appends keys and clones the values. 
DataType::Dictionary(_, _) => vec![], - DataType::Float16 => unreachable!(), DataType::Struct(fields) => match capacities { Capacities::Struct(capacity, Some(ref child_capacities)) => { array_capacity = capacity; @@ -510,39 +495,58 @@ impl<'a> MutableArrayData<'a> { }) .collect::>(), }, - _ => { - todo!("Take and filter operations still not supported for this datatype") + DataType::FixedSizeList(_, _) => { + let childs = arrays + .iter() + .map(|array| &array.child_data()[0]) + .collect::>(); + vec![MutableArrayData::new(childs, use_nulls, array_capacity)] } + DataType::Union(fields, _, _) => (0..fields.len()) + .map(|i| { + let child_arrays = arrays + .iter() + .map(|array| &array.child_data()[i]) + .collect::>(); + MutableArrayData::new(child_arrays, use_nulls, array_capacity) + }) + .collect::>(), }; - let dictionary = match &data_type { - DataType::Dictionary(_, _) => match arrays.len() { - 0 => unreachable!(), - 1 => Some(arrays[0].child_data()[0].clone()), - _ => { - if let Capacities::Dictionary(_, _) = capacities { - panic!("dictionary capacity not yet supported") - } - // Concat dictionaries together - let dictionaries: Vec<_> = - arrays.iter().map(|array| &array.child_data()[0]).collect(); - let lengths: Vec<_> = dictionaries - .iter() - .map(|dictionary| dictionary.len()) - .collect(); - let capacity = lengths.iter().sum(); + // Get the dictionary if any, and if it is a concatenation of multiple + let (dictionary, dict_concat) = match &data_type { + DataType::Dictionary(_, _) => { + // If more than one dictionary, concatenate dictionaries together + let dict_concat = !arrays + .windows(2) + .all(|a| a[0].child_data()[0].ptr_eq(&a[1].child_data()[0])); + + match dict_concat { + false => (Some(arrays[0].child_data()[0].clone()), false), + true => { + if let Capacities::Dictionary(_, _) = capacities { + panic!("dictionary capacity not yet supported") + } + let dictionaries: Vec<_> = + arrays.iter().map(|array| &array.child_data()[0]).collect(); + let lengths: Vec<_> = dictionaries + .iter() + .map(|dictionary| dictionary.len()) + .collect(); + let capacity = lengths.iter().sum(); - let mut mutable = - MutableArrayData::new(dictionaries, false, capacity); + let mut mutable = + MutableArrayData::new(dictionaries, false, capacity); - for (i, len) in lengths.iter().enumerate() { - mutable.extend(i, 0, *len) - } + for (i, len) in lengths.iter().enumerate() { + mutable.extend(i, 0, *len) + } - Some(mutable.freeze()) + (Some(mutable.freeze()), true) + } } - }, - _ => None, + } + _ => (None, false), }; let extend_nulls = build_extend_nulls(data_type); @@ -567,8 +571,13 @@ impl<'a> MutableArrayData<'a> { .iter() .map(|array| { let offset = next_offset; - next_offset += array.child_data()[0].len(); - build_extend_dictionary(array, offset, next_offset) + let dict_len = array.child_data()[0].len(); + + if dict_concat { + next_offset += dict_len; + } + + build_extend_dictionary(array, offset, offset + dict_len) .ok_or(ArrowError::DictionaryKeyOverflowError) }) .collect(); @@ -597,10 +606,17 @@ impl<'a> MutableArrayData<'a> { } } - /// Extends this [MutableArrayData] with elements from the bounded [ArrayData] at `start` - /// and for a size of `len`. + /// Extends this array with a chunk of its source arrays + /// + /// # Arguments + /// * `index` - the index of array that you what to copy values from + /// * `start` - the start index of the chunk (inclusive) + /// * `end` - the end index of the chunk (exclusive) + /// /// # Panic - /// This function panics if the range is out of bounds, i.e. 
if `start + len >= array.len()`. + /// This function panics if there is an invalid index, + /// i.e. `index` >= the number of source arrays + /// or `end` > the length of the `index`th array pub fn extend(&mut self, index: usize, start: usize, end: usize) { let len = end - start; (self.extend_null_bits[index])(&mut self.data, start, len); @@ -654,12 +670,13 @@ mod tests { use super::*; + use crate::array::DecimalArray; use crate::{ array::{ Array, ArrayData, ArrayRef, BooleanArray, DictionaryArray, FixedSizeBinaryArray, Int16Array, Int16Type, Int32Array, Int64Array, - Int64Builder, ListBuilder, NullArray, PrimitiveBuilder, StringArray, - StringDictionaryBuilder, StructArray, UInt8Array, + Int64Builder, ListBuilder, MapBuilder, NullArray, PrimitiveBuilder, + StringArray, StringDictionaryBuilder, StructArray, UInt8Array, }, buffer::Buffer, datatypes::Field, @@ -669,6 +686,68 @@ mod tests { error::Result, }; + fn create_decimal_array( + array: &[Option], + precision: usize, + scale: usize, + ) -> DecimalArray { + array + .iter() + .collect::() + .with_precision_and_scale(precision, scale) + .unwrap() + } + + #[test] + #[cfg(not(feature = "force_validate"))] + fn test_decimal() { + let decimal_array = + create_decimal_array(&[Some(1), Some(2), None, Some(3)], 10, 3); + let arrays = vec![decimal_array.data()]; + let mut a = MutableArrayData::new(arrays, true, 3); + a.extend(0, 0, 3); + a.extend(0, 2, 3); + let result = a.freeze(); + let array = DecimalArray::from(result); + let expected = create_decimal_array(&[Some(1), Some(2), None, None], 10, 3); + assert_eq!(array, expected); + } + #[test] + #[cfg(not(feature = "force_validate"))] + fn test_decimal_offset() { + let decimal_array = + create_decimal_array(&[Some(1), Some(2), None, Some(3)], 10, 3); + let decimal_array = decimal_array.slice(1, 3); // 2, null, 3 + let arrays = vec![decimal_array.data()]; + let mut a = MutableArrayData::new(arrays, true, 2); + a.extend(0, 0, 2); // 2, null + let result = a.freeze(); + let array = DecimalArray::from(result); + let expected = create_decimal_array(&[Some(2), None], 10, 3); + assert_eq!(array, expected); + } + + #[test] + #[cfg(not(feature = "force_validate"))] + fn test_decimal_null_offset_nulls() { + let decimal_array = + create_decimal_array(&[Some(1), Some(2), None, Some(3)], 10, 3); + let decimal_array = decimal_array.slice(1, 3); // 2, null, 3 + let arrays = vec![decimal_array.data()]; + let mut a = MutableArrayData::new(arrays, true, 2); + a.extend(0, 0, 2); // 2, null + a.extend_nulls(3); // 2, null, null, null, null + a.extend(0, 1, 3); //2, null, null, null, null, null, 3 + let result = a.freeze(); + let array = DecimalArray::from(result); + let expected = create_decimal_array( + &[Some(2), None, None, None, None, None, Some(3)], + 10, + 3, + ); + assert_eq!(array, expected); + } + /// tests extending from a primitive array w/ offset nor nulls #[test] fn test_primitive() { @@ -1154,7 +1233,6 @@ mod tests { DataType::List(Box::new(Field::new("item", DataType::Int64, true))), 8, None, - None, 0, vec![list_value_offsets], vec![expected_int_array.data().clone()], @@ -1235,7 +1313,6 @@ mod tests { let expected_list_data = ArrayData::try_new( DataType::List(Box::new(Field::new("item", DataType::Int64, true))), 12, - None, Some(Buffer::from(&[0b11011011, 0b1110])), 0, vec![list_value_offsets], @@ -1247,6 +1324,190 @@ mod tests { Ok(()) } + #[test] + fn test_list_append_with_capacities() -> Result<()> { + let mut builder = ListBuilder::::new(Int64Builder::new(24)); + 
builder.values().append_slice(&[1, 2, 3])?; + builder.append(true)?; + builder.values().append_slice(&[4, 5])?; + builder.append(true)?; + builder.values().append_slice(&[6, 7, 8])?; + builder.values().append_slice(&[9, 10, 11])?; + builder.append(true)?; + let a = builder.finish(); + + let a_builder = Int64Builder::new(24); + let mut a_builder = ListBuilder::::new(a_builder); + a_builder.values().append_slice(&[12, 13])?; + a_builder.append(true)?; + a_builder.append(true)?; + a_builder.values().append_slice(&[14, 15, 16, 17])?; + a_builder.append(true)?; + let b = a_builder.finish(); + + let mutable = MutableArrayData::with_capacities( + vec![a.data(), b.data()], + false, + Capacities::List(6, Some(Box::new(Capacities::Array(17)))), + ); + + // capacities are rounded up to multiples of 64 by MutableBuffer + assert_eq!(mutable.data.buffer1.capacity(), 64); + assert_eq!(mutable.data.child_data[0].data.buffer1.capacity(), 192); + + Ok(()) + } + + #[test] + fn test_map_nulls_append() -> Result<()> { + let mut builder = MapBuilder::::new( + None, + Int64Builder::new(32), + Int64Builder::new(32), + ); + builder.keys().append_slice(&[1, 2, 3])?; + builder.values().append_slice(&[1, 2, 3])?; + builder.append(true)?; + builder.keys().append_slice(&[4, 5])?; + builder.values().append_slice(&[4, 5])?; + builder.append(true)?; + builder.append(false)?; + builder + .keys() + .append_slice(&[6, 7, 8, 100, 101, 9, 10, 11])?; + builder.values().append_slice(&[6, 7, 8])?; + builder.values().append_null()?; + builder.values().append_null()?; + builder.values().append_slice(&[9, 10, 11])?; + builder.append(true)?; + + let a = builder.finish(); + let a = a.data(); + + let mut builder = MapBuilder::::new( + None, + Int64Builder::new(32), + Int64Builder::new(32), + ); + + builder.keys().append_slice(&[12, 13])?; + builder.values().append_slice(&[12, 13])?; + builder.append(true)?; + builder.append(false)?; + builder.append(true)?; + builder.keys().append_slice(&[100, 101, 14, 15])?; + builder.values().append_null()?; + builder.values().append_null()?; + builder.values().append_slice(&[14, 15])?; + builder.append(true)?; + + let b = builder.finish(); + let b = b.data(); + let c = b.slice(1, 2); + let d = b.slice(2, 2); + + let mut mutable = MutableArrayData::new(vec![a, b, &c, &d], false, 10); + + mutable.extend(0, 0, a.len()); + mutable.extend(1, 0, b.len()); + mutable.extend(2, 0, c.len()); + mutable.extend(3, 0, d.len()); + let result = mutable.freeze(); + + let expected_key_array = Int64Array::from(vec![ + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + Some(8), + Some(100), + Some(101), + Some(9), + Some(10), + Some(11), + // second array + Some(12), + Some(13), + Some(100), + Some(101), + Some(14), + Some(15), + // slice(1, 2) results in no values added + Some(100), + Some(101), + Some(14), + Some(15), + ]); + + let expected_value_array = Int64Array::from(vec![ + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + Some(8), + None, + None, + Some(9), + Some(10), + Some(11), + // second array + Some(12), + Some(13), + None, + None, + Some(14), + Some(15), + // slice(1, 2) results in no values added + None, + None, + Some(14), + Some(15), + ]); + + let expected_entry_array = StructArray::from(vec![ + ( + Field::new("keys", DataType::Int64, false), + Arc::new(expected_key_array) as ArrayRef, + ), + ( + Field::new("values", DataType::Int64, true), + Arc::new(expected_value_array) as ArrayRef, + ), + ]); + + let map_offsets = + Buffer::from_slice_ref(&[0, 
3, 5, 5, 13, 15, 15, 15, 19, 19, 19, 19, 23]); + + let expected_list_data = ArrayData::try_new( + DataType::Map( + Box::new(Field::new( + "entries", + DataType::Struct(vec![ + Field::new("keys", DataType::Int64, false), + Field::new("values", DataType::Int64, true), + ]), + false, + )), + false, + ), + 12, + Some(Buffer::from(&[0b11011011, 0b1110])), + 0, + vec![map_offsets], + vec![expected_entry_array.data().clone()], + ) + .unwrap(); + assert_eq!(result, expected_list_data); + + Ok(()) + } + #[test] fn test_list_of_strings_append() -> Result<()> { // [["alpha", "beta", None]] @@ -1308,7 +1569,6 @@ mod tests { DataType::List(Box::new(Field::new("item", DataType::Utf8, true))), 6, None, - None, 0, vec![list_value_offsets], vec![expected_string_array.data().clone()], diff --git a/arrow/src/array/transform/union.rs b/arrow/src/array/transform/union.rs new file mode 100644 index 000000000000..bbea508219d0 --- /dev/null +++ b/arrow/src/array/transform/union.rs @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::array::ArrayData; + +use super::{Extend, _MutableArrayData}; + +pub(super) fn build_extend_sparse(array: &ArrayData) -> Extend { + let type_ids = array.buffer::(0); + + Box::new( + move |mutable: &mut _MutableArrayData, index: usize, start: usize, len: usize| { + // extends type_ids + mutable + .buffer1 + .extend_from_slice(&type_ids[start..start + len]); + + mutable + .child_data + .iter_mut() + .for_each(|child| child.extend(index, start, start + len)) + }, + ) +} + +pub(super) fn build_extend_dense(array: &ArrayData) -> Extend { + let type_ids = array.buffer::(0); + let offsets = array.buffer::(1); + + Box::new( + move |mutable: &mut _MutableArrayData, index: usize, start: usize, len: usize| { + // extends type_ids + mutable + .buffer1 + .extend_from_slice(&type_ids[start..start + len]); + + (start..start + len).for_each(|i| { + let type_id = type_ids[i] as usize; + let src_offset = offsets[i] as usize; + let child_data = &mut mutable.child_data[type_id]; + let dst_offset = child_data.len(); + + // Extend offsets + mutable.buffer2.push(dst_offset as i32); + mutable.child_data[type_id].extend(index, src_offset, src_offset + 1) + }) + }, + ) +} + +pub(super) fn extend_nulls_dense(_mutable: &mut _MutableArrayData, _len: usize) { + panic!("cannot call extend_nulls on UnionArray as cannot infer type"); +} + +pub(super) fn extend_nulls_sparse(_mutable: &mut _MutableArrayData, _len: usize) { + panic!("cannot call extend_nulls on UnionArray as cannot infer type"); +} diff --git a/arrow/src/array/transform/utils.rs b/arrow/src/array/transform/utils.rs index ed7f455a7b2b..68aee79c41bb 100644 --- a/arrow/src/array/transform/utils.rs +++ b/arrow/src/array/transform/utils.rs @@ -15,14 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::{ - array::OffsetSizeTrait, - buffer::MutableBuffer, - util::{ - bit_chunk_iterator::BitChunks, - bit_util::{self, ceil}, - }, -}; +use crate::{array::OffsetSizeTrait, buffer::MutableBuffer, util::bit_util}; /// extends the `buffer` to be able to hold `len` bits, setting all bits of the new size to zero. 
#[inline] @@ -33,49 +26,6 @@ pub(super) fn resize_for_bits(buffer: &mut MutableBuffer, len: usize) { } } -/// sets all bits on `write_data` on the range `[offset_write..offset_write+len]` to be equal to the -/// bits on `data` on the range `[offset_read..offset_read+len]` -/// returns the number of `0` bits `data[offset_read..offset_read+len]` -pub(super) fn set_bits( - write_data: &mut [u8], - data: &[u8], - offset_write: usize, - offset_read: usize, - len: usize, -) -> usize { - let mut null_count = 0; - - let mut bits_to_align = offset_write % 8; - if bits_to_align > 0 { - bits_to_align = std::cmp::min(len, 8 - bits_to_align); - } - let mut write_byte_index = ceil(offset_write + bits_to_align, 8); - - // Set full bytes provided by bit chunk iterator (which iterates in 64 bits at a time) - let chunks = BitChunks::new(data, offset_read + bits_to_align, len - bits_to_align); - chunks.iter().for_each(|chunk| { - null_count += chunk.count_zeros(); - chunk.to_le_bytes().iter().for_each(|b| { - write_data[write_byte_index] = *b; - write_byte_index += 1; - }) - }); - - // Set individual bits both to align write_data to a byte offset and the remainder bits not covered by the bit chunk iterator - let remainder_offset = len - chunks.remainder_len(); - (0..bits_to_align) - .chain(remainder_offset..len) - .for_each(|i| { - if bit_util::get_bit(data, offset_read + i) { - bit_util::set_bit(write_data, offset_write + i); - } else { - null_count += 1; - } - }); - - null_count as usize -} - pub(super) fn extend_offsets( buffer: &mut MutableBuffer, mut last_offset: T, @@ -104,130 +54,3 @@ pub(super) unsafe fn get_last_offset( debug_assert!(prefix.is_empty() && suffix.is_empty()); *offsets.get_unchecked(offsets.len() - 1) } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_set_bits_aligned() { - let mut destination: Vec = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; - let source: &[u8] = &[ - 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, - 0b11100111, 0b10100101, - ]; - - let destination_offset = 8; - let source_offset = 0; - - let len = 64; - - let expected_data: &[u8] = &[ - 0, 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, - 0b11100111, 0b10100101, 0, - ]; - let expected_null_count = 24; - let result = set_bits( - destination.as_mut_slice(), - source, - destination_offset, - source_offset, - len, - ); - - assert_eq!(destination, expected_data); - assert_eq!(result, expected_null_count); - } - - #[test] - fn test_set_bits_unaligned_destination_start() { - let mut destination: Vec = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; - let source: &[u8] = &[ - 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, - 0b11100111, 0b10100101, - ]; - - let destination_offset = 3; - let source_offset = 0; - - let len = 64; - - let expected_data: &[u8] = &[ - 0b00111000, 0b00101111, 0b11001101, 0b11011100, 0b01011110, 0b00011111, - 0b00111110, 0b00101111, 0b00000101, 0b00000000, - ]; - let expected_null_count = 24; - let result = set_bits( - destination.as_mut_slice(), - source, - destination_offset, - source_offset, - len, - ); - - assert_eq!(destination, expected_data); - assert_eq!(result, expected_null_count); - } - - #[test] - fn test_set_bits_unaligned_destination_end() { - let mut destination: Vec = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; - let source: &[u8] = &[ - 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, - 0b11100111, 0b10100101, - ]; - - let destination_offset = 8; - let source_offset = 0; - - let len = 62; - - let 
expected_data: &[u8] = &[ - 0, 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, - 0b11100111, 0b00100101, 0, - ]; - let expected_null_count = 23; - let result = set_bits( - destination.as_mut_slice(), - source, - destination_offset, - source_offset, - len, - ); - - assert_eq!(destination, expected_data); - assert_eq!(result, expected_null_count); - } - - #[test] - fn test_set_bits_unaligned() { - let mut destination: Vec = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; - let source: &[u8] = &[ - 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, - 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, - 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, - ]; - - let destination_offset = 3; - let source_offset = 5; - - let len = 95; - - let expected_data: &[u8] = &[ - 0b01111000, 0b01101001, 0b11100110, 0b11110110, 0b11111010, 0b11110000, - 0b01111001, 0b01101001, 0b11100110, 0b11110110, 0b11111010, 0b11110000, - 0b00000001, - ]; - let expected_null_count = 35; - let result = set_bits( - destination.as_mut_slice(), - source, - destination_offset, - source_offset, - len, - ); - - assert_eq!(destination, expected_data); - assert_eq!(result, expected_null_count); - } -} diff --git a/arrow/src/bitmap.rs b/arrow/src/bitmap.rs index 9d3784b31b85..4ba1bb9f8882 100644 --- a/arrow/src/bitmap.rs +++ b/arrow/src/bitmap.rs @@ -15,8 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Defines a bitmap, which is used to track which values in an Arrow array are null. -//! This is called a "validity bitmap" in the Arrow documentation. +//! Defines [Bitmap] for tracking validity bitmaps use crate::buffer::Buffer; use crate::error::Result; @@ -26,25 +25,25 @@ use std::mem; use std::ops::{BitAnd, BitOr}; #[derive(Debug, Clone)] +/// Defines a bitmap, which is used to track which values in an Arrow +/// array are null. +/// +/// This is called a "validity bitmap" in the Arrow documentation. 
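As a side note on the `Bitmap::new` rewrite below, the new size computation can be checked with a short sketch that uses the same public `bit_util` helpers; the numbers mirror the updated `test_bitmap_length` test.

use arrow::util::bit_util;

let num_bits = 63 * 8;                                        // caller asks for 504 bits
let num_bytes = bit_util::ceil(num_bits, 8);                  // 63 bytes needed
let padded = bit_util::round_upto_multiple_of_64(num_bytes);  // padded to 64 bytes
assert_eq!(padded * 8, 512);                                  // Bitmap::new(63 * 8).bit_len()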
pub struct Bitmap { pub(crate) bits: Buffer, } impl Bitmap { pub fn new(num_bits: usize) -> Self { - let num_bytes = num_bits / 8 + if num_bits % 8 > 0 { 1 } else { 0 }; - let r = num_bytes % 64; - let len = if r == 0 { - num_bytes - } else { - num_bytes + 64 - r - }; + let num_bytes = bit_util::ceil(num_bits, 8); + let len = bit_util::round_upto_multiple_of_64(num_bytes); Bitmap { bits: Buffer::from(&vec![0xFF; len]), } } - pub fn len(&self) -> usize { + /// Return the length of this Bitmap in bits (not bytes) + pub fn bit_len(&self) -> usize { self.bits.len() * 8 } @@ -117,9 +116,9 @@ mod tests { #[test] fn test_bitmap_length() { - assert_eq!(512, Bitmap::new(63 * 8).len()); - assert_eq!(512, Bitmap::new(64 * 8).len()); - assert_eq!(1024, Bitmap::new(65 * 8).len()); + assert_eq!(512, Bitmap::new(63 * 8).bit_len()); + assert_eq!(512, Bitmap::new(64 * 8).bit_len()); + assert_eq!(1024, Bitmap::new(65 * 8).bit_len()); } #[test] diff --git a/arrow/src/buffer/immutable.rs b/arrow/src/buffer/immutable.rs index f0aefd9b94b7..f5d59c5ed555 100644 --- a/arrow/src/buffer/immutable.rs +++ b/arrow/src/buffer/immutable.rs @@ -21,12 +21,10 @@ use std::ptr::NonNull; use std::sync::Arc; use std::{convert::AsRef, usize}; -use crate::util::bit_chunk_iterator::BitChunks; -use crate::{ - bytes::{Bytes, Deallocation}, - datatypes::ArrowNativeType, - ffi, -}; +use crate::alloc::{Allocation, Deallocation}; +use crate::ffi::FFI_ArrowArray; +use crate::util::bit_chunk_iterator::{BitChunks, UnalignedBitChunk}; +use crate::{bytes::Bytes, datatypes::ArrowNativeType}; use super::ops::bitwise_unary_op_helper; use super::MutableBuffer; @@ -76,7 +74,7 @@ impl Buffer { /// bytes. If the `ptr` and `capacity` come from a `Buffer`, then this is guaranteed. pub unsafe fn from_raw_parts(ptr: NonNull, len: usize, capacity: usize) -> Self { assert!(len <= capacity); - Buffer::build_with_arguments(ptr, len, Deallocation::Native(capacity)) + Buffer::build_with_arguments(ptr, len, Deallocation::Arrow(capacity)) } /// Creates a buffer from an existing memory region (must already be byte-aligned), this @@ -86,18 +84,41 @@ impl Buffer { /// /// * `ptr` - Pointer to raw parts /// * `len` - Length of raw parts in **bytes** - /// * `data` - An [ffi::FFI_ArrowArray] with the data + /// * `data` - An [crate::ffi::FFI_ArrowArray] with the data /// /// # Safety /// /// This function is unsafe as there is no guarantee that the given pointer is valid for `len` /// bytes and that the foreign deallocator frees the region. + #[deprecated( + note = "use from_custom_allocation instead which makes it clearer that the allocation is in fact owned" + )] pub unsafe fn from_unowned( ptr: NonNull, len: usize, - data: Arc, + data: Arc, ) -> Self { - Buffer::build_with_arguments(ptr, len, Deallocation::Foreign(data)) + Self::from_custom_allocation(ptr, len, data) + } + + /// Creates a buffer from an existing memory region. Ownership of the memory is tracked via reference counting + /// and the memory will be freed using the `drop` method of [crate::alloc::Allocation] when the reference count reaches zero. 
+ /// + /// # Arguments + /// + /// * `ptr` - Pointer to raw parts + /// * `len` - Length of raw parts in **bytes** + /// * `owner` - A [crate::alloc::Allocation] which is responsible for freeing that data + /// + /// # Safety + /// + /// This function is unsafe as there is no guarantee that the given pointer is valid for `len` bytes + pub unsafe fn from_custom_allocation( + ptr: NonNull, + len: usize, + owner: Arc, + ) -> Self { + Buffer::build_with_arguments(ptr, len, Deallocation::Custom(owner)) } /// Auxiliary method to create a new Buffer @@ -119,7 +140,7 @@ impl Buffer { } /// Returns the capacity of this buffer. - /// For exernally owned buffers, this returns zero + /// For externally owned buffers, this returns zero pub fn capacity(&self) -> usize { self.data.capacity() } @@ -153,29 +174,22 @@ impl Buffer { /// /// Note that this should be used cautiously, and the returned pointer should not be /// stored anywhere, to avoid dangling pointers. + #[inline] pub fn as_ptr(&self) -> *const u8 { unsafe { self.data.ptr().as_ptr().add(self.offset) } } /// View buffer as typed slice. /// - /// # Safety - /// - /// `ArrowNativeType` is public so that it can be used as a trait bound for other public - /// components, such as the `ToByteSlice` trait. However, this means that it can be - /// implemented by user defined types, which it is not intended for. + /// # Panics /// - /// Also `typed_data::` is unsafe as `0x00` and `0x01` are the only valid values for - /// `bool` in Rust. However, `bool` arrays in Arrow are bit-packed which breaks this condition. - /// View buffer as typed slice. - pub unsafe fn typed_data(&self) -> &[T] { - // JUSTIFICATION - // Benefit - // Many of the buffers represent specific types, and consumers of `Buffer` often need to re-interpret them. - // Soundness - // * The pointer is non-null by construction - // * alignment asserted below. - let (prefix, offsets, suffix) = self.as_slice().align_to::(); + /// This function panics if the underlying buffer is not aligned + /// correctly for type `T`. + pub fn typed_data(&self) -> &[T] { + // SAFETY + // ArrowNativeType is trivially transmutable, is sealed to prevent potentially incorrect + // implementation outside this crate, and this method checks alignment + let (prefix, offsets, suffix) = unsafe { self.as_slice().align_to::() }; assert!(prefix.is_empty() && suffix.is_empty()); offsets } @@ -208,11 +222,7 @@ impl Buffer { /// Returns the number of 1-bits in this buffer, starting from `offset` with `length` bits /// inspected. Note that both `offset` and `length` are measured in bits. pub fn count_set_bits_offset(&self, offset: usize, len: usize) -> usize { - let chunks = self.bit_chunks(offset, len); - let mut count = chunks.iter().map(|c| c.count_ones() as usize).sum(); - count += chunks.remainder_bits().count_ones() as usize; - - count + UnalignedBitChunk::new(self.as_slice(), offset, len).count_ones() } } @@ -248,6 +258,8 @@ impl std::ops::Deref for Buffer { } unsafe impl Sync for Buffer {} +// false positive, see https://github.com/apache/arrow-rs/pull/1169 +#[allow(clippy::non_send_fields_in_send_ty)] unsafe impl Send for Buffer {} impl From for Buffer { @@ -326,6 +338,7 @@ impl FromIterator for Buffer { #[cfg(test)] mod tests { + use std::panic::{RefUnwindSafe, UnwindSafe}; use std::thread; use super::*; @@ -434,7 +447,7 @@ mod tests { macro_rules! 
check_as_typed_data { ($input: expr, $native_t: ty) => {{ let buffer = Buffer::from_slice_ref($input); - let slice: &[$native_t] = unsafe { buffer.typed_data::<$native_t>() }; + let slice: &[$native_t] = buffer.typed_data::<$native_t>(); assert_eq!($input, slice); }}; } @@ -538,4 +551,30 @@ mod tests { Buffer::from(&[0b01101101, 0b10101010]).count_set_bits_offset(7, 9) ); } + + #[test] + fn test_unwind_safe() { + fn assert_unwind_safe() {} + assert_unwind_safe::() + } + + #[test] + fn test_from_foreign_vec() { + let mut vector = vec![1_i32, 2, 3, 4, 5]; + let buffer = unsafe { + Buffer::from_custom_allocation( + NonNull::new_unchecked(vector.as_mut_ptr() as *mut u8), + vector.len() * std::mem::size_of::(), + Arc::new(vector), + ) + }; + + let slice = buffer.typed_data::(); + assert_eq!(slice, &[1, 2, 3, 4, 5]); + + let buffer = buffer.slice(std::mem::size_of::()); + + let slice = buffer.typed_data::(); + assert_eq!(slice, &[2, 3, 4, 5]); + } } diff --git a/arrow/src/buffer/mod.rs b/arrow/src/buffer/mod.rs index cf0461b5f536..b392b0583d6d 100644 --- a/arrow/src/buffer/mod.rs +++ b/arrow/src/buffer/mod.rs @@ -23,6 +23,9 @@ pub use immutable::*; mod mutable; pub use mutable::*; mod ops; +mod scalar; +pub use scalar::*; + pub use ops::*; use crate::error::{ArrowError, Result}; diff --git a/arrow/src/buffer/mutable.rs b/arrow/src/buffer/mutable.rs index d83997a3d24c..11783b82da54 100644 --- a/arrow/src/buffer/mutable.rs +++ b/arrow/src/buffer/mutable.rs @@ -1,14 +1,3 @@ -use std::ptr::NonNull; - -use crate::{ - alloc, - bytes::{Bytes, Deallocation}, - datatypes::{ArrowNativeType, ToByteSlice}, - util::bit_util, -}; - -use super::Buffer; - // Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information @@ -26,12 +15,26 @@ use super::Buffer; // specific language governing permissions and limitations // under the License. +use super::Buffer; +use crate::alloc::Deallocation; +use crate::{ + alloc, + bytes::Bytes, + datatypes::{ArrowNativeType, ToByteSlice}, + util::bit_util, +}; +use std::ptr::NonNull; + /// A [`MutableBuffer`] is Arrow's interface to build a [`Buffer`] out of items or slices of items. /// [`Buffer`]s created from [`MutableBuffer`] (via `into`) are guaranteed to have its pointer aligned /// along cache lines and in multiple of 64 bytes. /// Use [MutableBuffer::push] to insert an item, [MutableBuffer::extend_from_slice] /// to insert many items, and `into` to convert it to [`Buffer`]. +/// +/// For a safe, strongly typed API consider using [`crate::array::BufferBuilder`] +/// /// # Example +/// /// ``` /// # use arrow::buffer::{Buffer, MutableBuffer}; /// let mut buffer = MutableBuffer::new(0); @@ -153,6 +156,17 @@ impl MutableBuffer { } } + /// Truncates this buffer to `len` bytes + /// + /// If `len` is greater than the buffer's current length, this has no effect + #[inline(always)] + pub fn truncate(&mut self, len: usize) { + if len > self.len { + return; + } + self.len = len; + } + /// Resizes the buffer, either truncating its contents (with no change in capacity), or /// growing it (potentially reallocating it) and writing `value` in the newly available bytes. 
/// # Example @@ -268,22 +282,26 @@ impl MutableBuffer { #[inline] pub(super) fn into_buffer(self) -> Buffer { let bytes = unsafe { - Bytes::new(self.data, self.len, Deallocation::Native(self.capacity)) + Bytes::new(self.data, self.len, Deallocation::Arrow(self.capacity)) }; std::mem::forget(self); Buffer::from_bytes(bytes) } - /// View this buffer asa slice of a specific type. - /// # Safety - /// This function must only be used when this buffer was extended with items of type `T`. - /// Failure to do so results in undefined behavior. + /// View this buffer as a slice of a specific type. + /// + /// # Panics + /// + /// This function panics if the underlying buffer is not aligned + /// correctly for type `T`. pub fn typed_data_mut(&mut self) -> &mut [T] { - unsafe { - let (prefix, offsets, suffix) = self.as_slice_mut().align_to_mut::(); - assert!(prefix.is_empty() && suffix.is_empty()); - offsets - } + // SAFETY + // ArrowNativeType is trivially transmutable, is sealed to prevent potentially incorrect + // implementation outside this crate, and this method checks alignment + let (prefix, offsets, suffix) = + unsafe { self.as_slice_mut().align_to_mut::() }; + assert!(prefix.is_empty() && suffix.is_empty()); + offsets } /// Extends this buffer from a slice of items that can be represented in bytes, increasing its capacity if needed. @@ -295,13 +313,16 @@ impl MutableBuffer { /// assert_eq!(buffer.len(), 8) // u32 has 4 bytes /// ``` #[inline] - pub fn extend_from_slice(&mut self, items: &[T]) { + pub fn extend_from_slice(&mut self, items: &[T]) { let len = items.len(); let additional = len * std::mem::size_of::(); self.reserve(additional); unsafe { - let dst = self.data.as_ptr().add(self.len); + // this assumes that `[ToByteSlice]` can be copied directly + // without calling `to_byte_slice` for each element, + // which is correct for all ArrowNativeType implementations. 
let src = items.as_ptr() as *const u8; + let dst = self.data.as_ptr().add(self.len); std::ptr::copy_nonoverlapping(src, dst, additional) } self.len += additional; @@ -320,8 +341,9 @@ impl MutableBuffer { let additional = std::mem::size_of::(); self.reserve(additional); unsafe { - let dst = self.data.as_ptr().add(self.len) as *mut T; - std::ptr::write(dst, item); + let src = item.to_byte_slice().as_ptr(); + let dst = self.data.as_ptr().add(self.len); + std::ptr::copy_nonoverlapping(src, dst, additional); } self.len += additional; } @@ -332,8 +354,9 @@ impl MutableBuffer { #[inline] pub unsafe fn push_unchecked(&mut self, item: T) { let additional = std::mem::size_of::(); - let dst = self.data.as_ptr().add(self.len) as *mut T; - std::ptr::write(dst, item); + let src = item.to_byte_slice().as_ptr(); + let dst = self.data.as_ptr().add(self.len); + std::ptr::copy_nonoverlapping(src, dst, additional); self.len += additional; } @@ -380,23 +403,24 @@ impl MutableBuffer { &mut self, mut iterator: I, ) { - let size = std::mem::size_of::(); + let item_size = std::mem::size_of::(); let (lower, _) = iterator.size_hint(); - let additional = lower * size; + let additional = lower * item_size; self.reserve(additional); // this is necessary because of https://github.com/rust-lang/rust/issues/32155 let mut len = SetLenOnDrop::new(&mut self.len); - let mut dst = unsafe { self.data.as_ptr().add(len.local_len) as *mut T }; + let mut dst = unsafe { self.data.as_ptr().add(len.local_len) }; let capacity = self.capacity; - while len.local_len + size <= capacity { + while len.local_len + item_size <= capacity { if let Some(item) = iterator.next() { unsafe { - std::ptr::write(dst, item); - dst = dst.add(1); + let src = item.to_byte_slice().as_ptr(); + std::ptr::copy_nonoverlapping(src, dst, item_size); + dst = dst.add(item_size); } - len.local_len += size; + len.local_len += item_size; } else { break; } @@ -427,21 +451,23 @@ impl MutableBuffer { pub unsafe fn from_trusted_len_iter>( iterator: I, ) -> Self { + let item_size = std::mem::size_of::(); let (_, upper) = iterator.size_hint(); let upper = upper.expect("from_trusted_len_iter requires an upper limit"); - let len = upper * std::mem::size_of::(); + let len = upper * item_size; let mut buffer = MutableBuffer::new(len); - let mut dst = buffer.data.as_ptr() as *mut T; + let mut dst = buffer.data.as_ptr(); for item in iterator { // note how there is no reserve here (compared with `extend_from_iter`) - std::ptr::write(dst, item); - dst = dst.add(1); + let src = item.to_byte_slice().as_ptr(); + std::ptr::copy_nonoverlapping(src, dst, item_size); + dst = dst.add(item_size); } assert_eq!( - dst.offset_from(buffer.data.as_ptr() as *mut T) as usize, - upper, + dst.offset_from(buffer.data.as_ptr()) as usize, + len, "Trusted iterator length was not accurately reported" ); buffer.len = len; @@ -518,34 +544,32 @@ impl MutableBuffer { >( iterator: I, ) -> std::result::Result { + let item_size = std::mem::size_of::(); let (_, upper) = iterator.size_hint(); let upper = upper.expect("try_from_trusted_len_iter requires an upper limit"); - let len = upper * std::mem::size_of::(); + let len = upper * item_size; let mut buffer = MutableBuffer::new(len); - let mut dst = buffer.data.as_ptr() as *mut T; + let mut dst = buffer.data.as_ptr(); for item in iterator { + let item = item?; // note how there is no reserve here (compared with `extend_from_iter`) - std::ptr::write(dst, item?); - dst = dst.add(1); + let src = item.to_byte_slice().as_ptr(); + std::ptr::copy_nonoverlapping(src, 
dst, item_size); + dst = dst.add(item_size); } // try_from_trusted_len_iter is instantiated a lot, so we extract part of it into a less // generic method to reduce compile time - unsafe fn finalize_buffer( - dst: *mut T, - buffer: &mut MutableBuffer, - upper: usize, - len: usize, - ) { + unsafe fn finalize_buffer(dst: *mut u8, buffer: &mut MutableBuffer, len: usize) { assert_eq!( - dst.offset_from(buffer.data.as_ptr() as *mut T) as usize, - upper, + dst.offset_from(buffer.data.as_ptr()) as usize, + len, "Trusted iterator length was not accurately reported" ); buffer.len = len; } - finalize_buffer(dst, &mut buffer, upper, len); + finalize_buffer(dst, &mut buffer, len); Ok(buffer) } } @@ -707,6 +731,44 @@ mod tests { ); } + #[test] + fn mutable_extend_from_iter_unaligned_u64() { + let mut buf = MutableBuffer::new(16); + buf.push(1_u8); + buf.extend([1_u64]); + assert_eq!(9, buf.len()); + assert_eq!(&[1u8, 1u8, 0, 0, 0, 0, 0, 0, 0], buf.as_slice()); + } + + #[test] + fn mutable_extend_from_slice_unaligned_u64() { + let mut buf = MutableBuffer::new(16); + buf.extend_from_slice(&[1_u8]); + buf.extend_from_slice(&[1_u64]); + assert_eq!(9, buf.len()); + assert_eq!(&[1u8, 1u8, 0, 0, 0, 0, 0, 0, 0], buf.as_slice()); + } + + #[test] + fn mutable_push_unaligned_u64() { + let mut buf = MutableBuffer::new(16); + buf.push(1_u8); + buf.push(1_u64); + assert_eq!(9, buf.len()); + assert_eq!(&[1u8, 1u8, 0, 0, 0, 0, 0, 0, 0], buf.as_slice()); + } + + #[test] + fn mutable_push_unchecked_unaligned_u64() { + let mut buf = MutableBuffer::new(16); + unsafe { + buf.push_unchecked(1_u8); + buf.push_unchecked(1_u64); + } + assert_eq!(9, buf.len()); + assert_eq!(&[1u8, 1u8, 0, 0, 0, 0, 0, 0, 0], buf.as_slice()); + } + #[test] fn test_from_trusted_len_iter() { let iter = vec![1u32, 2].into_iter(); diff --git a/arrow/src/buffer/ops.rs b/arrow/src/buffer/ops.rs index c37fd14bd22a..ea155c8d78e4 100644 --- a/arrow/src/buffer/ops.rs +++ b/arrow/src/buffer/ops.rs @@ -15,110 +15,8 @@ // specific language governing permissions and limitations // under the License. -#[cfg(feature = "simd")] -use crate::util::bit_util; -#[cfg(feature = "simd")] -use packed_simd::u8x64; - -#[cfg(feature = "avx512")] -use crate::arch::avx512::*; -use crate::util::bit_util::ceil; -#[cfg(any(feature = "simd", feature = "avx512"))] -use std::borrow::BorrowMut; - use super::{Buffer, MutableBuffer}; - -/// Apply a bitwise operation `simd_op` / `scalar_op` to two inputs using simd instructions and return the result as a Buffer. -/// The `simd_op` functions gets applied on chunks of 64 bytes (512 bits) at a time -/// and the `scalar_op` gets applied to remaining bytes. -/// Contrary to the non-simd version `bitwise_bin_op_helper`, the offset and length is specified in bytes -/// and this version does not support operations starting at arbitrary bit offsets. 
-#[cfg(feature = "simd")] -pub fn bitwise_bin_op_simd_helper( - left: &Buffer, - left_offset: usize, - right: &Buffer, - right_offset: usize, - len: usize, - simd_op: F_SIMD, - scalar_op: F_SCALAR, -) -> Buffer -where - F_SIMD: Fn(u8x64, u8x64) -> u8x64, - F_SCALAR: Fn(u8, u8) -> u8, -{ - let mut result = MutableBuffer::new(len).with_bitset(len, false); - let lanes = u8x64::lanes(); - - let mut left_chunks = left.as_slice()[left_offset..].chunks_exact(lanes); - let mut right_chunks = right.as_slice()[right_offset..].chunks_exact(lanes); - let mut result_chunks = result.as_slice_mut().chunks_exact_mut(lanes); - - result_chunks - .borrow_mut() - .zip(left_chunks.borrow_mut().zip(right_chunks.borrow_mut())) - .for_each(|(res, (left, right))| { - unsafe { bit_util::bitwise_bin_op_simd(&left, &right, res, &simd_op) }; - }); - - result_chunks - .into_remainder() - .iter_mut() - .zip( - left_chunks - .remainder() - .iter() - .zip(right_chunks.remainder().iter()), - ) - .for_each(|(res, (left, right))| { - *res = scalar_op(*left, *right); - }); - - result.into() -} - -/// Apply a bitwise operation `simd_op` / `scalar_op` to one input using simd instructions and return the result as a Buffer. -/// The `simd_op` functions gets applied on chunks of 64 bytes (512 bits) at a time -/// and the `scalar_op` gets applied to remaining bytes. -/// Contrary to the non-simd version `bitwise_unary_op_helper`, the offset and length is specified in bytes -/// and this version does not support operations starting at arbitrary bit offsets. -#[cfg(feature = "simd")] -pub fn bitwise_unary_op_simd_helper( - left: &Buffer, - left_offset: usize, - len: usize, - simd_op: F_SIMD, - scalar_op: F_SCALAR, -) -> Buffer -where - F_SIMD: Fn(u8x64) -> u8x64, - F_SCALAR: Fn(u8) -> u8, -{ - let mut result = MutableBuffer::new(len).with_bitset(len, false); - let lanes = u8x64::lanes(); - - let mut left_chunks = left.as_slice()[left_offset..].chunks_exact(lanes); - let mut result_chunks = result.as_slice_mut().chunks_exact_mut(lanes); - - result_chunks - .borrow_mut() - .zip(left_chunks.borrow_mut()) - .for_each(|(res, left)| unsafe { - let data_simd = u8x64::from_slice_unaligned_unchecked(left); - let simd_result = simd_op(data_simd); - simd_result.write_to_slice_unaligned_unchecked(res); - }); - - result_chunks - .into_remainder() - .iter_mut() - .zip(left_chunks.remainder().iter()) - .for_each(|(res, left)| { - *res = scalar_op(*left); - }); - - result.into() -} +use crate::util::bit_util::ceil; /// Apply a bitwise operation `op` to two inputs and return the result as a Buffer. /// The inputs are treated as bitmaps, meaning that offsets and length are specified in number of bits. 
@@ -140,7 +38,8 @@ where .iter() .zip(right_chunks.iter()) .map(|(left, right)| op(left, right)); - // Soundness: `BitChunks` is a trusted len iterator + // Soundness: `BitChunks` is a `BitChunks` iterator which + // correctly reports its upper bound let mut buffer = unsafe { MutableBuffer::from_trusted_len_iter(chunks) }; let remainder_bytes = ceil(left_chunks.remainder_len(), 8); @@ -168,6 +67,7 @@ where MutableBuffer::new(ceil(len_in_bits, 8)).with_bitset(len_in_bits / 64 * 8, false); let left_chunks = left.bit_chunks(offset_in_bits, len_in_bits); + let result_chunks = result.typed_data_mut::().iter_mut(); result_chunks @@ -185,100 +85,6 @@ where result.into() } -#[cfg(all(target_arch = "x86_64", feature = "avx512"))] -pub fn buffer_bin_and( - left: &Buffer, - left_offset_in_bits: usize, - right: &Buffer, - right_offset_in_bits: usize, - len_in_bits: usize, -) -> Buffer { - if left_offset_in_bits % 8 == 0 - && right_offset_in_bits % 8 == 0 - && len_in_bits % 8 == 0 - { - let len = len_in_bits / 8; - let left_offset = left_offset_in_bits / 8; - let right_offset = right_offset_in_bits / 8; - - let mut result = MutableBuffer::new(len).with_bitset(len, false); - - let mut left_chunks = - left.as_slice()[left_offset..].chunks_exact(AVX512_U8X64_LANES); - let mut right_chunks = - right.as_slice()[right_offset..].chunks_exact(AVX512_U8X64_LANES); - let mut result_chunks = - result.as_slice_mut().chunks_exact_mut(AVX512_U8X64_LANES); - - result_chunks - .borrow_mut() - .zip(left_chunks.borrow_mut().zip(right_chunks.borrow_mut())) - .for_each(|(res, (left, right))| unsafe { - avx512_bin_and(left, right, res); - }); - - result_chunks - .into_remainder() - .iter_mut() - .zip( - left_chunks - .remainder() - .iter() - .zip(right_chunks.remainder().iter()), - ) - .for_each(|(res, (left, right))| { - *res = *left & *right; - }); - - result.into() - } else { - bitwise_bin_op_helper( - &left, - left_offset_in_bits, - right, - right_offset_in_bits, - len_in_bits, - |a, b| a & b, - ) - } -} - -#[cfg(all(feature = "simd", not(feature = "avx512")))] -pub fn buffer_bin_and( - left: &Buffer, - left_offset_in_bits: usize, - right: &Buffer, - right_offset_in_bits: usize, - len_in_bits: usize, -) -> Buffer { - if left_offset_in_bits % 8 == 0 - && right_offset_in_bits % 8 == 0 - && len_in_bits % 8 == 0 - { - bitwise_bin_op_simd_helper( - &left, - left_offset_in_bits / 8, - &right, - right_offset_in_bits / 8, - len_in_bits / 8, - |a, b| a & b, - |a, b| a & b, - ) - } else { - bitwise_bin_op_helper( - &left, - left_offset_in_bits, - right, - right_offset_in_bits, - len_in_bits, - |a, b| a & b, - ) - } -} - -// Note: do not target specific features like x86 without considering -// other targets like wasm32, as those would fail to build -#[cfg(all(not(any(feature = "simd", feature = "avx512"))))] pub fn buffer_bin_and( left: &Buffer, left_offset_in_bits: usize, @@ -296,98 +102,6 @@ pub fn buffer_bin_and( ) } -#[cfg(all(target_arch = "x86_64", feature = "avx512"))] -pub fn buffer_bin_or( - left: &Buffer, - left_offset_in_bits: usize, - right: &Buffer, - right_offset_in_bits: usize, - len_in_bits: usize, -) -> Buffer { - if left_offset_in_bits % 8 == 0 - && right_offset_in_bits % 8 == 0 - && len_in_bits % 8 == 0 - { - let len = len_in_bits / 8; - let left_offset = left_offset_in_bits / 8; - let right_offset = right_offset_in_bits / 8; - - let mut result = MutableBuffer::new(len).with_bitset(len, false); - - let mut left_chunks = - left.as_slice()[left_offset..].chunks_exact(AVX512_U8X64_LANES); - let mut right_chunks = 
- right.as_slice()[right_offset..].chunks_exact(AVX512_U8X64_LANES); - let mut result_chunks = - result.as_slice_mut().chunks_exact_mut(AVX512_U8X64_LANES); - - result_chunks - .borrow_mut() - .zip(left_chunks.borrow_mut().zip(right_chunks.borrow_mut())) - .for_each(|(res, (left, right))| unsafe { - avx512_bin_or(left, right, res); - }); - - result_chunks - .into_remainder() - .iter_mut() - .zip( - left_chunks - .remainder() - .iter() - .zip(right_chunks.remainder().iter()), - ) - .for_each(|(res, (left, right))| { - *res = *left | *right; - }); - - result.into() - } else { - bitwise_bin_op_helper( - &left, - left_offset_in_bits, - right, - right_offset_in_bits, - len_in_bits, - |a, b| a | b, - ) - } -} - -#[cfg(all(feature = "simd", not(feature = "avx512")))] -pub fn buffer_bin_or( - left: &Buffer, - left_offset_in_bits: usize, - right: &Buffer, - right_offset_in_bits: usize, - len_in_bits: usize, -) -> Buffer { - if left_offset_in_bits % 8 == 0 - && right_offset_in_bits % 8 == 0 - && len_in_bits % 8 == 0 - { - bitwise_bin_op_simd_helper( - &left, - left_offset_in_bits / 8, - &right, - right_offset_in_bits / 8, - len_in_bits / 8, - |a, b| a | b, - |a, b| a | b, - ) - } else { - bitwise_bin_op_helper( - &left, - left_offset_in_bits, - right, - right_offset_in_bits, - len_in_bits, - |a, b| a | b, - ) - } -} - -#[cfg(all(not(any(feature = "simd", feature = "avx512"))))] pub fn buffer_bin_or( left: &Buffer, left_offset_in_bits: usize, @@ -410,20 +124,5 @@ pub fn buffer_unary_not( offset_in_bits: usize, len_in_bits: usize, ) -> Buffer { - // SIMD implementation if available and byte-aligned - #[cfg(feature = "simd")] - if offset_in_bits % 8 == 0 && len_in_bits % 8 == 0 { - return bitwise_unary_op_simd_helper( - &left, - offset_in_bits / 8, - len_in_bits / 8, - |a| !a, - |a| !a, - ); - } - // Default implementation - #[allow(unreachable_code)] - { - bitwise_unary_op_helper(left, offset_in_bits, len_in_bits, |a| !a) - } + bitwise_unary_op_helper(left, offset_in_bits, len_in_bits, |a| !a) } diff --git a/arrow/src/buffer/scalar.rs b/arrow/src/buffer/scalar.rs new file mode 100644 index 000000000000..7d663cd2bf96 --- /dev/null +++ b/arrow/src/buffer/scalar.rs @@ -0,0 +1,149 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::buffer::Buffer; +use crate::datatypes::ArrowNativeType; +use std::ops::Deref; + +/// Provides a safe API for interpreting a [`Buffer`] as a slice of [`ArrowNativeType`] +/// +/// # Safety +/// +/// All [`ArrowNativeType`] are valid for all possible backing byte representations, and as +/// a result they are "trivially safely transmutable". 
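Before the `ScalarBuffer` definition that follows, a usage sketch of the new API; it mirrors the `test_basic` case further down and assumes only what this new module exports.

use arrow::buffer::{Buffer, ScalarBuffer};

let buffer = Buffer::from_slice_ref(&[0i32, 1, 2, 3]);

// offset and length are expressed in units of `i32`, not in bytes
let typed: ScalarBuffer<i32> = ScalarBuffer::new(buffer, 1, 2);
assert_eq!(*typed, [1, 2]);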
+#[derive(Debug)] +pub struct ScalarBuffer { + #[allow(unused)] + buffer: Buffer, + // Borrows from `buffer` and is valid for the lifetime of `buffer` + ptr: *const T, + // The length of this slice + len: usize, +} + +impl ScalarBuffer { + /// Create a new [`ScalarBuffer`] from a [`Buffer`], and an `offset` + /// and `length` in units of `T` + /// + /// # Panics + /// + /// This method will panic if + /// + /// * `offset` or `len` would result in overflow + /// * `buffer` is not aligned to a multiple of `std::mem::size_of::` + /// * `bytes` is not large enough for the requested slice + pub fn new(buffer: Buffer, offset: usize, len: usize) -> Self { + let size = std::mem::size_of::(); + let offset_len = offset.checked_add(len).expect("length overflow"); + let start_bytes = offset.checked_mul(size).expect("start bytes overflow"); + let end_bytes = offset_len.checked_mul(size).expect("end bytes overflow"); + + let bytes = &buffer.as_slice()[start_bytes..end_bytes]; + + // SAFETY: all byte sequences correspond to a valid instance of T + let (prefix, offsets, suffix) = unsafe { bytes.align_to::() }; + assert!( + prefix.is_empty() && suffix.is_empty(), + "buffer is not aligned to {} byte boundary", + size + ); + + let ptr = offsets.as_ptr(); + Self { buffer, ptr, len } + } +} + +impl Deref for ScalarBuffer { + type Target = [T]; + + fn deref(&self) -> &Self::Target { + // SAFETY: Bounds checked in constructor and ptr is valid for the lifetime of self + unsafe { std::slice::from_raw_parts(self.ptr, self.len) } + } +} + +impl AsRef<[T]> for ScalarBuffer { + fn as_ref(&self) -> &[T] { + self + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_basic() { + let expected = [0_i32, 1, 2]; + let buffer = Buffer::from_iter(expected.iter().cloned()); + let typed = ScalarBuffer::::new(buffer.clone(), 0, 3); + assert_eq!(*typed, expected); + + let typed = ScalarBuffer::::new(buffer.clone(), 1, 2); + assert_eq!(*typed, expected[1..]); + + let typed = ScalarBuffer::::new(buffer.clone(), 1, 0); + assert!(typed.is_empty()); + + let typed = ScalarBuffer::::new(buffer, 3, 0); + assert!(typed.is_empty()); + } + + #[test] + #[should_panic(expected = "buffer is not aligned to 4 byte boundary")] + fn test_unaligned() { + let expected = [0_i32, 1, 2]; + let buffer = Buffer::from_iter(expected.iter().cloned()); + let buffer = buffer.slice(1); + ScalarBuffer::::new(buffer, 0, 2); + } + + #[test] + #[should_panic(expected = "range end index 16 out of range for slice of length 12")] + fn test_length_out_of_bounds() { + let buffer = Buffer::from_iter([0_i32, 1, 2]); + ScalarBuffer::::new(buffer, 1, 3); + } + + #[test] + #[should_panic(expected = "range end index 16 out of range for slice of length 12")] + fn test_offset_out_of_bounds() { + let buffer = Buffer::from_iter([0_i32, 1, 2]); + ScalarBuffer::::new(buffer, 4, 0); + } + + #[test] + #[should_panic(expected = "length overflow")] + fn test_length_overflow() { + let buffer = Buffer::from_iter([0_i32, 1, 2]); + ScalarBuffer::::new(buffer, usize::MAX, 1); + } + + #[test] + #[should_panic(expected = "start bytes overflow")] + fn test_start_overflow() { + let buffer = Buffer::from_iter([0_i32, 1, 2]); + ScalarBuffer::::new(buffer, usize::MAX / 4 + 1, 0); + } + + #[test] + #[should_panic(expected = "end bytes overflow")] + fn test_end_overflow() { + let buffer = Buffer::from_iter([0_i32, 1, 2]); + ScalarBuffer::::new(buffer, 0, usize::MAX / 4 + 1); + } +} diff --git a/arrow/src/bytes.rs b/arrow/src/bytes.rs index 38fa4439b42d..7b57552e60f6 100644 --- 
a/arrow/src/bytes.rs +++ b/arrow/src/bytes.rs @@ -21,45 +21,25 @@ use core::slice; use std::ptr::NonNull; -use std::sync::Arc; use std::{fmt::Debug, fmt::Formatter}; -use crate::{alloc, ffi}; - -/// Mode of deallocating memory regions -pub enum Deallocation { - /// Native deallocation, using Rust deallocator with Arrow-specific memory aligment - Native(usize), - /// Foreign interface, via a callback - Foreign(Arc), -} - -impl Debug for Deallocation { - fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { - match self { - Deallocation::Native(capacity) => { - write!(f, "Deallocation::Native {{ capacity: {} }}", capacity) - } - Deallocation::Foreign(_) => { - write!(f, "Deallocation::Foreign {{ capacity: unknown }}") - } - } - } -} +use crate::alloc; +use crate::alloc::Deallocation; /// A continuous, fixed-size, immutable memory region that knows how to de-allocate itself. /// This structs' API is inspired by the `bytes::Bytes`, but it is not limited to using rust's -/// global allocator nor u8 aligmnent. +/// global allocator nor u8 alignment. +/// +/// In the most common case, this buffer is allocated using [`allocate_aligned`](crate::alloc::allocate_aligned) +/// and deallocated accordingly [`free_aligned`](crate::alloc::free_aligned). /// -/// In the most common case, this buffer is allocated using [`allocate_aligned`](memory::allocate_aligned) -/// and deallocated accordingly [`free_aligned`](memory::free_aligned). -/// When the region is allocated by an foreign allocator, [Deallocation::Foreign], this calls the -/// foreign deallocator to deallocate the region when it is no longer needed. +/// When the region is allocated by a different allocator, [Deallocation::Custom], this calls the +/// custom deallocator to deallocate the region when it is no longer needed. pub struct Bytes { - /// The raw pointer to be begining of the region + /// The raw pointer to be beginning of the region ptr: NonNull, - /// The number of bytes visible to this region. This is always smaller than its capacity (when avaliable). + /// The number of bytes visible to this region. This is always smaller than its capacity (when available). len: usize, /// how to deallocate this region @@ -80,7 +60,7 @@ impl Bytes { /// This function is unsafe as there is no guarantee that the given pointer is valid for `len` /// bytes. If the `ptr` and `capacity` come from a `Buffer`, then this is guaranteed. #[inline] - pub unsafe fn new( + pub(crate) unsafe fn new( ptr: std::ptr::NonNull, len: usize, deallocation: Deallocation, @@ -113,10 +93,10 @@ impl Bytes { pub fn capacity(&self) -> usize { match self.deallocation { - Deallocation::Native(capacity) => capacity, + Deallocation::Arrow(capacity) => capacity, // we cannot determine this in general, // and thus we state that this is externally-owned memory - Deallocation::Foreign(_) => 0, + Deallocation::Custom(_) => 0, } } } @@ -125,11 +105,11 @@ impl Drop for Bytes { #[inline] fn drop(&mut self) { match &self.deallocation { - Deallocation::Native(capacity) => { + Deallocation::Arrow(capacity) => { unsafe { alloc::free_aligned::(self.ptr, *capacity) }; } - // foreign interface knows how to deallocate itself. 
- Deallocation::Foreign(_) => (), + // The automatic drop implementation will free the memory once the reference count reaches zero + Deallocation::Custom(_allocation) => (), } } } diff --git a/arrow/src/compute/kernels/aggregate.rs b/arrow/src/compute/kernels/aggregate.rs index a385c674cbe7..12ead669f79d 100644 --- a/arrow/src/compute/kernels/aggregate.rs +++ b/arrow/src/compute/kernels/aggregate.rs @@ -21,7 +21,8 @@ use multiversion::multiversion; use std::ops::Add; use crate::array::{ - Array, BooleanArray, GenericStringArray, PrimitiveArray, StringOffsetSizeTrait, + Array, BooleanArray, GenericBinaryArray, GenericStringArray, OffsetSizeTrait, + PrimitiveArray, }; use crate::datatypes::{ArrowNativeType, ArrowNumericType}; @@ -32,41 +33,6 @@ fn is_nan(a: T) -> bool { !(a == a) } -/// Helper macro to perform min/max of strings -fn min_max_string bool>( - array: &GenericStringArray, - cmp: F, -) -> Option<&str> { - let null_count = array.null_count(); - - if null_count == array.len() { - return None; - } - let data = array.data(); - let mut n; - if null_count == 0 { - n = array.value(0); - for i in 1..data.len() { - let item = array.value(i); - if cmp(n, item) { - n = item; - } - } - } else { - n = ""; - let mut has_value = false; - - for i in 0..data.len() { - let item = array.value(i); - if data.is_valid(i) && (!has_value || cmp(n, item)) { - has_value = true; - n = item; - } - } - } - Some(n) -} - /// Returns the minimum value in the array, according to the natural order. /// For floating point arrays any NaN values are considered to be greater than any other non-null value #[cfg(not(feature = "simd"))] @@ -89,20 +55,6 @@ where min_max_helper(array, |a, b| (!is_nan(*a) & is_nan(*b)) || a < b) } -/// Returns the maximum value in the string array, according to the natural order. -pub fn max_string( - array: &GenericStringArray, -) -> Option<&str> { - min_max_string(array, |a, b| a < b) -} - -/// Returns the minimum value in the string array, according to the natural order. -pub fn min_string( - array: &GenericStringArray, -) -> Option<&str> { - min_max_string(array, |a, b| a > b) -} - /// Helper function to perform min/max lambda function on values from a numeric array. #[multiversion] #[clone(target = "x86_64+avx")] @@ -190,6 +142,48 @@ pub fn max_boolean(array: &BooleanArray) -> Option { .or(Some(false)) } +/// Helper to compute min/max of [`GenericStringArray`] and [`GenericBinaryArray`] +macro_rules! min_max_binary_string { + ($array: expr, $cmp: expr) => {{ + let null_count = $array.null_count(); + if null_count == $array.len() { + None + } else if null_count == 0 { + // JUSTIFICATION + // Benefit: ~8% speedup + // Soundness: `i` is always within the array bounds + (0..$array.len()) + .map(|i| unsafe { $array.value_unchecked(i) }) + .reduce(|acc, item| if $cmp(acc, item) { item } else { acc }) + } else { + $array + .iter() + .flatten() + .reduce(|acc, item| if $cmp(acc, item) { item } else { acc }) + } + }}; +} + +/// Returns the maximum value in the binary array, according to the natural order. +pub fn max_binary(array: &GenericBinaryArray) -> Option<&[u8]> { + min_max_binary_string!(array, |a, b| a < b) +} + +/// Returns the minimum value in the binary array, according to the natural order. +pub fn min_binary(array: &GenericBinaryArray) -> Option<&[u8]> { + min_max_binary_string!(array, |a, b| a > b) +} + +/// Returns the maximum value in the string array, according to the natural order. 
+pub fn max_string(array: &GenericStringArray) -> Option<&str> { + min_max_binary_string!(array, |a, b| a < b) +} + +/// Returns the minimum value in the string array, according to the natural order. +pub fn min_string(array: &GenericStringArray) -> Option<&str> { + min_max_binary_string!(array, |a, b| a > b) +} + /// Returns the sum of values in the array. /// /// Returns `None` if the array is empty or only contains null values. @@ -641,7 +635,7 @@ mod tests { #[test] fn test_primitive_array_float_sum() { let a = Float64Array::from(vec![1.1, 2.2, 3.3, 4.4, 5.5]); - assert!(16.5 - sum(&a).unwrap() < f64::EPSILON); + assert_eq!(16.5, sum(&a).unwrap()); } #[test] @@ -899,11 +893,45 @@ mod tests { assert!(max(&a).unwrap().is_nan()); } + #[test] + fn test_binary_min_max_with_nulls() { + let a = BinaryArray::from(vec![ + Some("b".as_bytes()), + None, + None, + Some(b"a"), + Some(b"c"), + ]); + assert_eq!(Some("a".as_bytes()), min_binary(&a)); + assert_eq!(Some("c".as_bytes()), max_binary(&a)); + } + + #[test] + fn test_binary_min_max_no_null() { + let a = BinaryArray::from(vec![Some("b".as_bytes()), Some(b"a"), Some(b"c")]); + assert_eq!(Some("a".as_bytes()), min_binary(&a)); + assert_eq!(Some("c".as_bytes()), max_binary(&a)); + } + + #[test] + fn test_binary_min_max_all_nulls() { + let a = BinaryArray::from(vec![None, None]); + assert_eq!(None, min_binary(&a)); + assert_eq!(None, max_binary(&a)); + } + + #[test] + fn test_binary_min_max_1() { + let a = BinaryArray::from(vec![None, None, Some("b".as_bytes()), Some(b"a")]); + assert_eq!(Some("a".as_bytes()), min_binary(&a)); + assert_eq!(Some("b".as_bytes()), max_binary(&a)); + } + #[test] fn test_string_min_max_with_nulls() { let a = StringArray::from(vec![Some("b"), None, None, Some("a"), Some("c")]); - assert_eq!("a", min_string(&a).unwrap()); - assert_eq!("c", max_string(&a).unwrap()); + assert_eq!(Some("a"), min_string(&a)); + assert_eq!(Some("c"), max_string(&a)); } #[test] diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index f92888b37965..04865e15bca2 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -41,114 +41,6 @@ use std::borrow::BorrowMut; #[cfg(feature = "simd")] use std::slice::{ChunksExact, ChunksExactMut}; -/// SIMD vectorized version of `unary_math_op` above specialized for signed numerical values. 
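The `min_max_binary_string!` macro above lets the string and binary reductions share one implementation; a short usage sketch, assuming these kernels are re-exported from `arrow::compute` the same way the existing aggregate kernels are:

use arrow::array::{BinaryArray, StringArray};
use arrow::compute::{max_binary, max_string, min_binary, min_string};

let strings = StringArray::from(vec![Some("b"), None, Some("a"), Some("c")]);
assert_eq!(Some("a"), min_string(&strings));
assert_eq!(Some("c"), max_string(&strings));

let bytes = BinaryArray::from(vec![Some("b".as_bytes()), None, Some(b"a"), Some(b"c")]);
assert_eq!(Some("a".as_bytes()), min_binary(&bytes));
assert_eq!(Some("c".as_bytes()), max_binary(&bytes));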
-#[cfg(feature = "simd")] -fn simd_signed_unary_math_op( - array: &PrimitiveArray, - simd_op: SIMD_OP, - scalar_op: SCALAR_OP, -) -> Result> -where - T: datatypes::ArrowSignedNumericType, - SIMD_OP: Fn(T::SignedSimd) -> T::SignedSimd, - SCALAR_OP: Fn(T::Native) -> T::Native, -{ - let lanes = T::lanes(); - let buffer_size = array.len() * std::mem::size_of::(); - let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false); - - let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes); - let mut array_chunks = array.values().chunks_exact(lanes); - - result_chunks - .borrow_mut() - .zip(array_chunks.borrow_mut()) - .for_each(|(result_slice, input_slice)| { - let simd_input = T::load_signed(input_slice); - let simd_result = T::signed_unary_op(simd_input, &simd_op); - T::write_signed(simd_result, result_slice); - }); - - let result_remainder = result_chunks.into_remainder(); - let array_remainder = array_chunks.remainder(); - - result_remainder.into_iter().zip(array_remainder).for_each( - |(scalar_result, scalar_input)| { - *scalar_result = scalar_op(*scalar_input); - }, - ); - - let data = unsafe { - ArrayData::new_unchecked( - T::DATA_TYPE, - array.len(), - None, - array - .data_ref() - .null_buffer() - .map(|b| b.bit_slice(array.offset(), array.len())), - 0, - vec![result.into()], - vec![], - ) - }; - Ok(PrimitiveArray::::from(data)) -} - -#[cfg(feature = "simd")] -fn simd_float_unary_math_op( - array: &PrimitiveArray, - simd_op: SIMD_OP, - scalar_op: SCALAR_OP, -) -> Result> -where - T: datatypes::ArrowFloatNumericType, - SIMD_OP: Fn(T::Simd) -> T::Simd, - SCALAR_OP: Fn(T::Native) -> T::Native, -{ - let lanes = T::lanes(); - let buffer_size = array.len() * std::mem::size_of::(); - - let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false); - - let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes); - let mut array_chunks = array.values().chunks_exact(lanes); - - result_chunks - .borrow_mut() - .zip(array_chunks.borrow_mut()) - .for_each(|(result_slice, input_slice)| { - let simd_input = T::load(input_slice); - let simd_result = T::unary_op(simd_input, &simd_op); - T::write(simd_result, result_slice); - }); - - let result_remainder = result_chunks.into_remainder(); - let array_remainder = array_chunks.remainder(); - - result_remainder.into_iter().zip(array_remainder).for_each( - |(scalar_result, scalar_input)| { - *scalar_result = scalar_op(*scalar_input); - }, - ); - - let data = unsafe { - ArrayData::new_unchecked( - T::DATA_TYPE, - array.len(), - None, - array - .data_ref() - .null_buffer() - .map(|b| b.bit_slice(array.offset(), array.len())), - 0, - vec![result.into()], - vec![], - ) - }; - Ok(PrimitiveArray::::from(data)) -} - /// Helper function to perform math lambda function on values from two arrays. If either /// left or right value is null then the output value is also null, so `1 + null` is /// `null`. @@ -172,7 +64,7 @@ where } let null_bit_buffer = - combine_option_bitmap(left.data_ref(), right.data_ref(), left.len())?; + combine_option_bitmap(&[left.data_ref(), right.data_ref()], left.len())?; let values = left .values() @@ -183,7 +75,7 @@ where // Benefit // ~60% speedup // Soundness - // `values` is an iterator with a known size. + // `values` is an iterator with a known size from a PrimitiveArray let buffer = unsafe { Buffer::from_trusted_len_iter(values) }; let data = unsafe { @@ -200,20 +92,23 @@ where Ok(PrimitiveArray::::from(data)) } -/// Helper function to modulus two arrays. 
+/// Helper function for operations where a valid `0` on the right array should +/// result in an [ArrowError::DivideByZero], namely the division and modulo operations /// /// # Errors /// /// This function errors if: /// * the arrays have different lengths -/// * a division by zero is found -fn math_modulus( +/// * there is an element where both left and right values are valid and the right value is `0` +fn math_checked_divide_op( left: &PrimitiveArray, right: &PrimitiveArray, + op: F, ) -> Result> where T: ArrowNumericType, - T::Native: Rem + Zero, + T::Native: One + Zero, + F: Fn(T::Native, T::Native) -> T::Native, { if left.len() != right.len() { return Err(ArrowError::ComputeError( @@ -222,7 +117,7 @@ where } let null_bit_buffer = - combine_option_bitmap(left.data_ref(), right.data_ref(), left.len())?; + combine_option_bitmap(&[left.data_ref(), right.data_ref()], left.len())?; let buffer = if let Some(b) = &null_bit_buffer { let values = left.values().iter().zip(right.values()).enumerate().map( @@ -232,13 +127,14 @@ where if right.is_zero() { Err(ArrowError::DivideByZero) } else { - Ok(*left % *right) + Ok(op(*left, *right)) } } else { Ok(T::default_value()) } }, ); + // Safety: Iterator comes from a PrimitiveArray which reports its size correctly unsafe { Buffer::try_from_trusted_len_iter(values) } } else { // no value is null @@ -250,9 +146,10 @@ where if right.is_zero() { Err(ArrowError::DivideByZero) } else { - Ok(*left % *right) + Ok(op(*left, *right)) } }); + // Safety: Iterator comes from a PrimitiveArray which reports its size correctly unsafe { Buffer::try_from_trusted_len_iter(values) } }?; @@ -270,176 +167,12 @@ where Ok(PrimitiveArray::::from(data)) } -/// Helper function to divide two arrays. +/// Calculates the modulus operation `left % right` on two SIMD inputs. +/// The lower-most bits of `valid_mask` specify which vector lanes are considered as valid. /// /// # Errors /// -/// This function errors if: -/// * the arrays have different lengths -/// * a division by zero is found -fn math_divide( - left: &PrimitiveArray, - right: &PrimitiveArray, -) -> Result> -where - T: ArrowNumericType, - T::Native: Div + Zero, -{ - if left.len() != right.len() { - return Err(ArrowError::ComputeError( - "Cannot perform math operation on arrays of different length".to_string(), - )); - } - - let null_bit_buffer = - combine_option_bitmap(left.data_ref(), right.data_ref(), left.len())?; - - let buffer = if let Some(b) = &null_bit_buffer { - let values = left.values().iter().zip(right.values()).enumerate().map( - |(i, (left, right))| { - let is_valid = unsafe { bit_util::get_bit_raw(b.as_ptr(), i) }; - if is_valid { - if right.is_zero() { - Err(ArrowError::DivideByZero) - } else { - Ok(*left / *right) - } - } else { - Ok(T::default_value()) - } - }, - ); - unsafe { Buffer::try_from_trusted_len_iter(values) } - } else { - // no value is null - let values = left - .values() - .iter() - .zip(right.values()) - .map(|(left, right)| { - if right.is_zero() { - Err(ArrowError::DivideByZero) - } else { - Ok(*left / *right) - } - }); - unsafe { Buffer::try_from_trusted_len_iter(values) } - }?; - - let data = unsafe { - ArrayData::new_unchecked( - T::DATA_TYPE, - left.len(), - None, - null_bit_buffer, - 0, - vec![buffer], - vec![], - ) - }; - Ok(PrimitiveArray::::from(data)) -} - -/// Scalar-modulo version of `math_modulus`. 
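A sketch of the behaviour that `math_checked_divide_op` gives both division and modulo: a `0` on the right side that lines up with valid values on both sides is reported as `ArrowError::DivideByZero`, otherwise the kernel behaves like any other element-wise operation. This assumes the public `divide` and `modulus` kernels in `arrow::compute`, which route through this helper in the non-SIMD build.

use arrow::array::Int32Array;
use arrow::compute::{divide, modulus};

let a = Int32Array::from(vec![10, 20, 30]);

// a valid `0` divisor surfaces as an error rather than a null or a panic
let zeros = Int32Array::from(vec![3, 0, 5]);
assert!(divide(&a, &zeros).is_err());

let b = Int32Array::from(vec![3, 4, 7]);
assert_eq!(&[1, 0, 2], modulus(&a, &b).unwrap().values());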
-fn math_modulus_scalar( - array: &PrimitiveArray, - modulo: T::Native, -) -> Result> -where - T: ArrowNumericType, - T::Native: Rem + Zero, -{ - if modulo.is_zero() { - return Err(ArrowError::DivideByZero); - } - - Ok(unary(array, |value| value % modulo)) -} - -/// Scalar-divisor version of `math_divide`. -fn math_divide_scalar( - array: &PrimitiveArray, - divisor: T::Native, -) -> Result> -where - T: ArrowNumericType, - T::Native: Div + Zero, -{ - if divisor.is_zero() { - return Err(ArrowError::DivideByZero); - } - - Ok(unary(array, |value| value / divisor)) -} - -/// SIMD vectorized version of `math_op` above. -#[cfg(feature = "simd")] -fn simd_math_op( - left: &PrimitiveArray, - right: &PrimitiveArray, - simd_op: SIMD_OP, - scalar_op: SCALAR_OP, -) -> Result> -where - T: ArrowNumericType, - SIMD_OP: Fn(T::Simd, T::Simd) -> T::Simd, - SCALAR_OP: Fn(T::Native, T::Native) -> T::Native, -{ - if left.len() != right.len() { - return Err(ArrowError::ComputeError( - "Cannot perform math operation on arrays of different length".to_string(), - )); - } - - let null_bit_buffer = - combine_option_bitmap(left.data_ref(), right.data_ref(), left.len())?; - - let lanes = T::lanes(); - let buffer_size = left.len() * std::mem::size_of::(); - let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false); - - let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes); - let mut left_chunks = left.values().chunks_exact(lanes); - let mut right_chunks = right.values().chunks_exact(lanes); - - result_chunks - .borrow_mut() - .zip(left_chunks.borrow_mut().zip(right_chunks.borrow_mut())) - .for_each(|(result_slice, (left_slice, right_slice))| { - let simd_left = T::load(left_slice); - let simd_right = T::load(right_slice); - let simd_result = T::bin_op(simd_left, simd_right, &simd_op); - T::write(simd_result, result_slice); - }); - - let result_remainder = result_chunks.into_remainder(); - let left_remainder = left_chunks.remainder(); - let right_remainder = right_chunks.remainder(); - - result_remainder - .iter_mut() - .zip(left_remainder.iter().zip(right_remainder.iter())) - .for_each(|(scalar_result, (scalar_left, scalar_right))| { - *scalar_result = scalar_op(*scalar_left, *scalar_right); - }); - - let data = unsafe { - ArrayData::new_unchecked( - T::DATA_TYPE, - left.len(), - None, - null_bit_buffer, - 0, - vec![result.into()], - vec![], - ) - }; - Ok(PrimitiveArray::::from(data)) -} - -/// SIMD vectorized implementation of `left % right`. -/// If any of the lanes marked as valid in `valid_mask` are `0` then an `ArrowError::DivideByZero` -/// is returned. The contents of no-valid lanes are undefined. +/// This function returns a [`ArrowError::DivideByZero`] if a valid element in `right` is `0` #[cfg(feature = "simd")] #[inline] fn simd_checked_modulus( @@ -471,9 +204,12 @@ where } } -/// SIMD vectorized implementation of `left / right`. -/// If any of the lanes marked as valid in `valid_mask` are `0` then an `ArrowError::DivideByZero` -/// is returned. The contents of no-valid lanes are undefined. +/// Calculates the division operation `left / right` on two SIMD inputs. +/// The lower-most bits of `valid_mask` specify which vector lanes are considered as valid. 
+/// +/// # Errors +/// +/// This function returns a [`ArrowError::DivideByZero`] if a valid element in `right` is `0` #[cfg(feature = "simd")] #[inline] fn simd_checked_divide( @@ -488,281 +224,89 @@ where let one = T::init(T::Native::one()); let right_no_invalid_zeros = match valid_mask { - Some(mask) => { - let simd_mask = T::mask_from_u64(mask); - // select `1` for invalid lanes, which will be a no-op during division later - T::mask_select(simd_mask, right, one) - } - None => right, - }; - - let zero_mask = T::eq(right_no_invalid_zeros, zero); - - if T::mask_any(zero_mask) { - Err(ArrowError::DivideByZero) - } else { - Ok(T::bin_op(left, right_no_invalid_zeros, |a, b| a / b)) - } -} - -/// Scalar implementation of `left % right` for the remainder elements after complete chunks have been processed using SIMD. -/// If any of the values marked as valid in `valid_mask` are `0` then an `ArrowError::DivideByZero` is returned. -#[cfg(feature = "simd")] -#[inline] -fn simd_checked_modulus_remainder( - valid_mask: Option, - left_chunks: ChunksExact, - right_chunks: ChunksExact, - result_chunks: ChunksExactMut, -) -> Result<()> -where - T::Native: Zero + Rem, -{ - let result_remainder = result_chunks.into_remainder(); - let left_remainder = left_chunks.remainder(); - let right_remainder = right_chunks.remainder(); - - result_remainder - .iter_mut() - .zip(left_remainder.iter().zip(right_remainder.iter())) - .enumerate() - .try_for_each(|(i, (result_scalar, (left_scalar, right_scalar)))| { - if valid_mask.map(|mask| mask & (1 << i) != 0).unwrap_or(true) { - if *right_scalar == T::Native::zero() { - return Err(ArrowError::DivideByZero); - } - *result_scalar = *left_scalar % *right_scalar; - } - Ok(()) - })?; - - Ok(()) -} - -/// Scalar implementation of `left / right` for the remainder elements after complete chunks have been processed using SIMD. -/// If any of the values marked as valid in `valid_mask` are `0` then an `ArrowError::DivideByZero` is returned. -#[cfg(feature = "simd")] -#[inline] -fn simd_checked_divide_remainder( - valid_mask: Option, - left_chunks: ChunksExact, - right_chunks: ChunksExact, - result_chunks: ChunksExactMut, -) -> Result<()> -where - T::Native: Zero + Div, -{ - let result_remainder = result_chunks.into_remainder(); - let left_remainder = left_chunks.remainder(); - let right_remainder = right_chunks.remainder(); - - result_remainder - .iter_mut() - .zip(left_remainder.iter().zip(right_remainder.iter())) - .enumerate() - .try_for_each(|(i, (result_scalar, (left_scalar, right_scalar)))| { - if valid_mask.map(|mask| mask & (1 << i) != 0).unwrap_or(true) { - if *right_scalar == T::Native::zero() { - return Err(ArrowError::DivideByZero); - } - *result_scalar = *left_scalar / *right_scalar; - } - Ok(()) - })?; - - Ok(()) -} - -/// Scalar-modulo version of `simd_checked_modulus_remainder`. -#[cfg(feature = "simd")] -#[inline] -fn simd_checked_modulus_scalar_remainder( - array_chunks: ChunksExact, - modulo: T::Native, - result_chunks: ChunksExactMut, -) -> Result<()> -where - T::Native: Zero + Rem, -{ - if modulo.is_zero() { - return Err(ArrowError::DivideByZero); - } - - let result_remainder = result_chunks.into_remainder(); - let array_remainder = array_chunks.remainder(); - - result_remainder - .iter_mut() - .zip(array_remainder.iter()) - .for_each(|(result_scalar, array_scalar)| { - *result_scalar = *array_scalar % modulo; - }); - - Ok(()) -} - -/// Scalar-divisor version of `simd_checked_divide_remainder`. 
-#[cfg(feature = "simd")] -#[inline] -fn simd_checked_divide_scalar_remainder( - array_chunks: ChunksExact, - divisor: T::Native, - result_chunks: ChunksExactMut, -) -> Result<()> -where - T::Native: Zero + Div, -{ - if divisor.is_zero() { - return Err(ArrowError::DivideByZero); - } - - let result_remainder = result_chunks.into_remainder(); - let array_remainder = array_chunks.remainder(); - - result_remainder - .iter_mut() - .zip(array_remainder.iter()) - .for_each(|(result_scalar, array_scalar)| { - *result_scalar = *array_scalar / divisor; - }); - - Ok(()) -} - -/// SIMD vectorized version of `modulus`. -/// -/// The modulus kernels need their own implementation as there is a need to handle situations -/// where a modulus by `0` occurs. This is complicated by `NULL` slots and padding. -#[cfg(feature = "simd")] -fn simd_modulus( - left: &PrimitiveArray, - right: &PrimitiveArray, -) -> Result> -where - T: ArrowNumericType, - T::Native: One + Zero + Rem, -{ - if left.len() != right.len() { - return Err(ArrowError::ComputeError( - "Cannot perform math operation on arrays of different length".to_string(), - )); - } - - // Create the combined `Bitmap` - let null_bit_buffer = - combine_option_bitmap(left.data_ref(), right.data_ref(), left.len())?; - - let lanes = T::lanes(); - let buffer_size = left.len() * std::mem::size_of::(); - let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false); - - match &null_bit_buffer { - Some(b) => { - // combine_option_bitmap returns a slice or new buffer starting at 0 - let valid_chunks = b.bit_chunks(0, left.len()); - - // process data in chunks of 64 elements since we also get 64 bits of validity information at a time - let mut result_chunks = result.typed_data_mut().chunks_exact_mut(64); - let mut left_chunks = left.values().chunks_exact(64); - let mut right_chunks = right.values().chunks_exact(64); - - valid_chunks - .iter() - .zip( - result_chunks - .borrow_mut() - .zip(left_chunks.borrow_mut().zip(right_chunks.borrow_mut())), - ) - .try_for_each( - |(mut mask, (result_slice, (left_slice, right_slice)))| { - // split chunks further into slices corresponding to the vector length - // the compiler is able to unroll this inner loop and remove bounds checks - // since the outer chunk size (64) is always a multiple of the number of lanes - result_slice - .chunks_exact_mut(lanes) - .zip(left_slice.chunks_exact(lanes).zip(right_slice.chunks_exact(lanes))) - .try_for_each(|(result_slice, (left_slice, right_slice))| -> Result<()> { - let simd_left = T::load(left_slice); - let simd_right = T::load(right_slice); - - let simd_result = simd_checked_modulus::(Some(mask), simd_left, simd_right)?; - - T::write(simd_result, result_slice); - - // skip the shift and avoid overflow for u8 type, which uses 64 lanes. 
- mask >>= T::lanes() % 64; - - Ok(()) - }) - }, - )?; - - let valid_remainder = valid_chunks.remainder_bits(); - - simd_checked_modulus_remainder::( - Some(valid_remainder), - left_chunks, - right_chunks, - result_chunks, - )?; - } - None => { - let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes); - let mut left_chunks = left.values().chunks_exact(lanes); - let mut right_chunks = right.values().chunks_exact(lanes); - - result_chunks - .borrow_mut() - .zip(left_chunks.borrow_mut().zip(right_chunks.borrow_mut())) - .try_for_each( - |(result_slice, (left_slice, right_slice))| -> Result<()> { - let simd_left = T::load(left_slice); - let simd_right = T::load(right_slice); + Some(mask) => { + let simd_mask = T::mask_from_u64(mask); + // select `1` for invalid lanes, which will be a no-op during division later + T::mask_select(simd_mask, right, one) + } + None => right, + }; - let simd_result = - simd_checked_modulus::(None, simd_left, simd_right)?; + let zero_mask = T::eq(right_no_invalid_zeros, zero); - T::write(simd_result, result_slice); + if T::mask_any(zero_mask) { + Err(ArrowError::DivideByZero) + } else { + Ok(T::bin_op(left, right_no_invalid_zeros, |a, b| a / b)) + } +} - Ok(()) - }, - )?; +/// Applies `op` on the remainder elements of two input chunks and writes the result into +/// the remainder elements of `result_chunks`. +/// The lower-most bits of `valid_mask` specify which elements are considered as valid. +/// +/// # Errors +/// +/// This function returns a [`ArrowError::DivideByZero`] if a valid element in `right` is `0` +#[cfg(feature = "simd")] +#[inline] +fn simd_checked_divide_op_remainder( + valid_mask: Option, + left_chunks: ChunksExact, + right_chunks: ChunksExact, + result_chunks: ChunksExactMut, + op: F, +) -> Result<()> +where + T: ArrowNumericType, + T::Native: Zero, + F: Fn(T::Native, T::Native) -> T::Native, +{ + let result_remainder = result_chunks.into_remainder(); + let left_remainder = left_chunks.remainder(); + let right_remainder = right_chunks.remainder(); - simd_checked_modulus_remainder::( - None, - left_chunks, - right_chunks, - result_chunks, - )?; - } - } + result_remainder + .iter_mut() + .zip(left_remainder.iter().zip(right_remainder.iter())) + .enumerate() + .try_for_each(|(i, (result_scalar, (left_scalar, right_scalar)))| { + if valid_mask.map(|mask| mask & (1 << i) != 0).unwrap_or(true) { + if *right_scalar == T::Native::zero() { + return Err(ArrowError::DivideByZero); + } + *result_scalar = op(*left_scalar, *right_scalar); + } else { + *result_scalar = T::default_value(); + } + Ok(()) + })?; - let data = unsafe { - ArrayData::new_unchecked( - T::DATA_TYPE, - left.len(), - None, - null_bit_buffer, - 0, - vec![result.into()], - vec![], - ) - }; - Ok(PrimitiveArray::::from(data)) + Ok(()) } -/// SIMD vectorized version of `divide`. +/// Creates a new PrimitiveArray by applying `simd_op` to the `left` and `right` input array. +/// If the length of the arrays is not multiple of the number of vector lanes +/// then the remainder of the array will be calculated using `scalar_op`. +/// Any operation on a `NULL` value will result in a `NULL` value in the output. /// -/// The divide kernels need their own implementation as there is a need to handle situations -/// where a divide by `0` occurs. This is complicated by `NULL` slots and padding. 
+/// # Errors +/// +/// This function errors if: +/// * the arrays have different lengths +/// * there is an element where both left and right values are valid and the right value is `0` #[cfg(feature = "simd")] -fn simd_divide( +fn simd_checked_divide_op( left: &PrimitiveArray, right: &PrimitiveArray, + simd_op: SI, + scalar_op: SC, ) -> Result> where T: ArrowNumericType, - T::Native: One + Zero + Div, + T::Native: One + Zero, + SI: Fn(Option, T::Simd, T::Simd) -> Result, + SC: Fn(T::Native, T::Native) -> T::Native, { if left.len() != right.len() { return Err(ArrowError::ComputeError( @@ -772,7 +316,7 @@ where // Create the combined `Bitmap` let null_bit_buffer = - combine_option_bitmap(left.data_ref(), right.data_ref(), left.len())?; + combine_option_bitmap(&[left.data_ref(), right.data_ref()], left.len())?; let lanes = T::lanes(); let buffer_size = left.len() * std::mem::size_of::(); @@ -784,7 +328,10 @@ where let valid_chunks = b.bit_chunks(0, left.len()); // process data in chunks of 64 elements since we also get 64 bits of validity information at a time - let mut result_chunks = result.typed_data_mut().chunks_exact_mut(64); + + // safety: result is newly created above, always written as a T below + let mut result_chunks = + unsafe { result.typed_data_mut().chunks_exact_mut(64) }; let mut left_chunks = left.values().chunks_exact(64); let mut right_chunks = right.values().chunks_exact(64); @@ -807,7 +354,7 @@ where let simd_left = T::load(left_slice); let simd_right = T::load(right_slice); - let simd_result = simd_checked_divide::(Some(mask), simd_left, simd_right)?; + let simd_result = simd_op(Some(mask), simd_left, simd_right)?; T::write(simd_result, result_slice); @@ -821,15 +368,18 @@ where let valid_remainder = valid_chunks.remainder_bits(); - simd_checked_divide_remainder::( + simd_checked_divide_op_remainder::( Some(valid_remainder), left_chunks, right_chunks, result_chunks, + scalar_op, )?; } None => { - let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes); + // safety: result is newly created above, always written as a T below + let mut result_chunks = + unsafe { result.typed_data_mut().chunks_exact_mut(lanes) }; let mut left_chunks = left.values().chunks_exact(lanes); let mut right_chunks = right.values().chunks_exact(lanes); @@ -841,8 +391,7 @@ where let simd_left = T::load(left_slice); let simd_right = T::load(right_slice); - let simd_result = - simd_checked_divide::(None, simd_left, simd_right)?; + let simd_result = simd_op(None, simd_left, simd_right)?; T::write(simd_result, result_slice); @@ -850,11 +399,12 @@ where }, )?; - simd_checked_divide_remainder::( + simd_checked_divide_op_remainder::( None, left_chunks, right_chunks, result_chunks, + scalar_op, )?; } } @@ -873,133 +423,50 @@ where Ok(PrimitiveArray::::from(data)) } -/// SIMD vectorized version of `modulus_scalar`. -#[cfg(feature = "simd")] -fn simd_modulus_scalar( - array: &PrimitiveArray, - modulo: T::Native, +/// Perform `left + right` operation on two arrays. If either left or right value is null +/// then the result is also null. 
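// Editorial sketch (illustrative only, assuming this crate's public kernels):
// the checked divide/modulus paths built on `simd_checked_divide_op` error
// only when a zero divisor sits in a slot that is valid on *both* sides;
// a NULL on either side simply yields a NULL output slot.
fn _checked_divide_null_semantics_sketch() {
    use crate::array::Int32Array;

    let a = Int32Array::from(vec![Some(10), Some(10), None]);
    let b = Int32Array::from(vec![Some(2), None, Some(0)]);

    // 10 / 2 = 5; NULL divisor -> NULL; NULL dividend beside a zero divisor -> NULL, not an error
    let c = divide(&a, &b).unwrap();
    assert_eq!(c, Int32Array::from(vec![Some(5), None, None]));

    // A zero divisor in a fully valid slot is rejected.
    let d = divide(
        &Int32Array::from(vec![Some(1)]),
        &Int32Array::from(vec![Some(0)]),
    );
    assert!(d.is_err());
}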
+pub fn add( + left: &PrimitiveArray, + right: &PrimitiveArray, ) -> Result> where T: ArrowNumericType, - T::Native: One + Zero + Rem, + T::Native: Add, { - if modulo.is_zero() { - return Err(ArrowError::DivideByZero); - } - - let lanes = T::lanes(); - let buffer_size = array.len() * std::mem::size_of::(); - let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false); - - let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes); - let mut array_chunks = array.values().chunks_exact(lanes); - - result_chunks - .borrow_mut() - .zip(array_chunks.borrow_mut()) - .for_each(|(result_slice, array_slice)| { - let simd_left = T::load(array_slice); - let simd_right = T::init(modulo); - - let simd_result = T::bin_op(simd_left, simd_right, |a, b| a % b); - T::write(simd_result, result_slice); - }); - - simd_checked_modulus_scalar_remainder::(array_chunks, modulo, result_chunks)?; - - let data = unsafe { - ArrayData::new_unchecked( - T::DATA_TYPE, - array.len(), - None, - array - .data_ref() - .null_buffer() - .map(|b| b.bit_slice(array.offset(), array.len())), - 0, - vec![result.into()], - vec![], - ) - }; - Ok(PrimitiveArray::::from(data)) + math_op(left, right, |a, b| a + b) } -/// SIMD vectorized version of `divide_scalar`. -#[cfg(feature = "simd")] -fn simd_divide_scalar( +/// Add every value in an array by a scalar. If any value in the array is null then the +/// result is also null. +pub fn add_scalar( array: &PrimitiveArray, - divisor: T::Native, + scalar: T::Native, ) -> Result> where - T: ArrowNumericType, - T::Native: One + Zero + Div, + T: datatypes::ArrowNumericType, + T::Native: Add, { - if divisor.is_zero() { - return Err(ArrowError::DivideByZero); - } - - let lanes = T::lanes(); - let buffer_size = array.len() * std::mem::size_of::(); - let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false); - - let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes); - let mut array_chunks = array.values().chunks_exact(lanes); - - result_chunks - .borrow_mut() - .zip(array_chunks.borrow_mut()) - .for_each(|(result_slice, array_slice)| { - let simd_left = T::load(array_slice); - let simd_right = T::init(divisor); - - let simd_result = T::bin_op(simd_left, simd_right, |a, b| a / b); - T::write(simd_result, result_slice); - }); - - simd_checked_divide_scalar_remainder::(array_chunks, divisor, result_chunks)?; - - let data = unsafe { - ArrayData::new_unchecked( - T::DATA_TYPE, - array.len(), - None, - array - .data_ref() - .null_buffer() - .map(|b| b.bit_slice(array.offset(), array.len())), - 0, - vec![result.into()], - vec![], - ) - }; - Ok(PrimitiveArray::::from(data)) + Ok(unary(array, |value| value + scalar)) } -/// Perform `left + right` operation on two arrays. If either left or right value is null +/// Perform `left - right` operation on two arrays. If either left or right value is null /// then the result is also null. -pub fn add( +pub fn subtract( left: &PrimitiveArray, right: &PrimitiveArray, ) -> Result> where - T: ArrowNumericType, - T::Native: Add - + Sub - + Mul - + Div - + Zero, + T: datatypes::ArrowNumericType, + T::Native: Sub, { - #[cfg(feature = "simd")] - return simd_math_op(&left, &right, |a, b| a + b, |a, b| a + b); - #[cfg(not(feature = "simd"))] - return math_op(left, right, |a, b| a + b); + math_op(left, right, |a, b| a - b) } -/// Perform `left - right` operation on two arrays. If either left or right value is null -/// then the result is also null. 
-pub fn subtract( - left: &PrimitiveArray, - right: &PrimitiveArray, +/// Subtract every value in an array by a scalar. If any value in the array is null then the +/// result is also null. +pub fn subtract_scalar( + array: &PrimitiveArray, + scalar: T::Native, ) -> Result> where T: datatypes::ArrowNumericType, @@ -1009,22 +476,16 @@ where + Div + Zero, { - #[cfg(feature = "simd")] - return simd_math_op(&left, &right, |a, b| a - b, |a, b| a - b); - #[cfg(not(feature = "simd"))] - return math_op(left, right, |a, b| a - b); + Ok(unary(array, |value| value - scalar)) } /// Perform `-` operation on an array. If value is null then the result is also null. pub fn negate(array: &PrimitiveArray) -> Result> where - T: datatypes::ArrowSignedNumericType, + T: datatypes::ArrowNumericType, T::Native: Neg, { - #[cfg(feature = "simd")] - return simd_signed_unary_math_op(array, |x| -x, |x| -x); - #[cfg(not(feature = "simd"))] - return Ok(unary(array, |x| -x)); + Ok(unary(array, |x| -x)) } /// Raise array with floating point values to the power of a scalar. @@ -1036,17 +497,7 @@ where T: datatypes::ArrowFloatNumericType, T::Native: Pow, { - #[cfg(feature = "simd")] - { - let raise_vector = T::init(raise); - return simd_float_unary_math_op( - array, - |x| T::pow(x, raise_vector), - |x| x.pow(raise), - ); - } - #[cfg(not(feature = "simd"))] - return Ok(unary(array, |x| x.pow(raise))); + Ok(unary(array, |x| x.pow(raise))) } /// Perform `left * right` operation on two arrays. If either left or right value is null @@ -1055,6 +506,19 @@ pub fn multiply( left: &PrimitiveArray, right: &PrimitiveArray, ) -> Result> +where + T: datatypes::ArrowNumericType, + T::Native: Mul, +{ + math_op(left, right, |a, b| a * b) +} + +/// Multiply every value in an array by a scalar. If any value in the array is null then the +/// result is also null. +pub fn multiply_scalar( + array: &PrimitiveArray, + scalar: T::Native, +) -> Result> where T: datatypes::ArrowNumericType, T::Native: Add @@ -1062,12 +526,10 @@ where + Mul + Div + Rem - + Zero, + + Zero + + One, { - #[cfg(feature = "simd")] - return simd_math_op(&left, &right, |a, b| a * b, |a, b| a * b); - #[cfg(not(feature = "simd"))] - return math_op(left, right, |a, b| a * b); + Ok(unary(array, |value| value * scalar)) } /// Perform `left % right` operation on two arrays. If either left or right value is null @@ -1079,18 +541,14 @@ pub fn modulus( ) -> Result> where T: datatypes::ArrowNumericType, - T::Native: Add - + Sub - + Mul - + Div - + Rem - + Zero - + One, + T::Native: Rem + Zero + One, { #[cfg(feature = "simd")] - return simd_modulus(&left, &right); + return simd_checked_divide_op(&left, &right, simd_checked_modulus::, |a, b| { + a % b + }); #[cfg(not(feature = "simd"))] - return math_modulus(left, right); + return math_checked_divide_op(left, right, |a, b| a % b); } /// Perform `left / right` operation on two arrays. If either left or right value is null @@ -1102,18 +560,26 @@ pub fn divide( ) -> Result> where T: datatypes::ArrowNumericType, - T::Native: Add - + Sub - + Mul - + Div - + Rem - + Zero - + One, + T::Native: Div + Zero + One, { #[cfg(feature = "simd")] - return simd_divide(&left, &right); + return simd_checked_divide_op(&left, &right, simd_checked_divide::, |a, b| a / b); #[cfg(not(feature = "simd"))] - return math_divide(left, right); + return math_checked_divide_op(left, right, |a, b| a / b); +} + +/// Perform `left / right` operation on two arrays without checking for division by zero. 
+/// The result of dividing by zero follows normal floating point rules. +/// If either left or right value is null then the result is also null. If any right hand value is zero then the result of this +pub fn divide_unchecked( + left: &PrimitiveArray, + right: &PrimitiveArray, +) -> Result> +where + T: datatypes::ArrowFloatNumericType, + T::Native: Div, +{ + math_op(left, right, |a, b| a / b) } /// Modulus every value in an array by a scalar. If any value in the array is null then the @@ -1125,18 +591,13 @@ pub fn modulus_scalar( ) -> Result> where T: datatypes::ArrowNumericType, - T::Native: Add - + Sub - + Mul - + Div - + Rem - + Zero - + One, + T::Native: Rem + Zero, { - #[cfg(feature = "simd")] - return simd_modulus_scalar(&array, modulo); - #[cfg(not(feature = "simd"))] - return math_modulus_scalar(array, modulo); + if modulo.is_zero() { + return Err(ArrowError::DivideByZero); + } + + Ok(unary(array, |a| a % modulo)) } /// Divide every value in an array by a scalar. If any value in the array is null then the @@ -1148,18 +609,12 @@ pub fn divide_scalar( ) -> Result> where T: datatypes::ArrowNumericType, - T::Native: Add - + Sub - + Mul - + Div - + Rem - + Zero - + One, + T::Native: Div + Zero, { - #[cfg(feature = "simd")] - return simd_divide_scalar(&array, divisor); - #[cfg(not(feature = "simd"))] - return math_divide_scalar(array, divisor); + if divisor.is_zero() { + return Err(ArrowError::DivideByZero); + } + Ok(unary(array, |a| a / divisor)) } #[cfg(test)] @@ -1213,6 +668,25 @@ mod tests { ); } + #[test] + fn test_primitive_array_add_scalar() { + let a = Int32Array::from(vec![15, 14, 9, 8, 1]); + let b = 3; + let c = add_scalar(&a, b).unwrap(); + let expected = Int32Array::from(vec![18, 17, 12, 11, 4]); + assert_eq!(c, expected); + } + + #[test] + fn test_primitive_array_add_scalar_sliced() { + let a = Int32Array::from(vec![Some(15), None, Some(9), Some(8), None]); + let a = a.slice(1, 4); + let a = as_primitive_array(&a); + let actual = add_scalar(a, 3).unwrap(); + let expected = Int32Array::from(vec![None, Some(12), Some(11), None]); + assert_eq!(actual, expected); + } + #[test] fn test_primitive_array_subtract() { let a = Int32Array::from(vec![1, 2, 3, 4, 5]); @@ -1225,6 +699,25 @@ mod tests { assert_eq!(4, c.value(4)); } + #[test] + fn test_primitive_array_subtract_scalar() { + let a = Int32Array::from(vec![15, 14, 9, 8, 1]); + let b = 3; + let c = subtract_scalar(&a, b).unwrap(); + let expected = Int32Array::from(vec![12, 11, 6, 5, -2]); + assert_eq!(c, expected); + } + + #[test] + fn test_primitive_array_subtract_scalar_sliced() { + let a = Int32Array::from(vec![Some(15), None, Some(9), Some(8), None]); + let a = a.slice(1, 4); + let a = as_primitive_array(&a); + let actual = subtract_scalar(a, 3).unwrap(); + let expected = Int32Array::from(vec![None, Some(6), Some(5), None]); + assert_eq!(actual, expected); + } + #[test] fn test_primitive_array_multiply() { let a = Int32Array::from(vec![5, 6, 7, 8, 9]); @@ -1237,6 +730,25 @@ mod tests { assert_eq!(72, c.value(4)); } + #[test] + fn test_primitive_array_multiply_scalar() { + let a = Int32Array::from(vec![15, 14, 9, 8, 1]); + let b = 3; + let c = multiply_scalar(&a, b).unwrap(); + let expected = Int32Array::from(vec![45, 42, 27, 24, 3]); + assert_eq!(c, expected); + } + + #[test] + fn test_primitive_array_multiply_scalar_sliced() { + let a = Int32Array::from(vec![Some(15), None, Some(9), Some(8), None]); + let a = a.slice(1, 4); + let a = as_primitive_array(&a); + let actual = multiply_scalar(a, 3).unwrap(); + let expected = 
Int32Array::from(vec![None, Some(27), Some(24), None]); + assert_eq!(actual, expected); + } + #[test] fn test_primitive_array_divide() { let a = Int32Array::from(vec![15, 15, 8, 1, 9]); @@ -1508,9 +1020,9 @@ mod tests { let a = Float64Array::from(vec![15.0, 15.0, 8.0]); let b = Float64Array::from(vec![5.0, 6.0, 8.0]); let c = divide(&a, &b).unwrap(); - assert!(3.0 - c.value(0) < f64::EPSILON); - assert!(2.5 - c.value(1) < f64::EPSILON); - assert!(1.0 - c.value(2) < f64::EPSILON); + assert_eq!(3.0, c.value(0)); + assert_eq!(2.5, c.value(1)); + assert_eq!(1.0, c.value(2)); } #[test] diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs index 41206e001d77..60a0cb77fe20 100644 --- a/arrow/src/compute/kernels/arity.rs +++ b/arrow/src/compute/kernels/arity.rs @@ -42,12 +42,12 @@ fn into_primitive_array_data( } } -/// Applies an unary and infalible function to a primitive array. +/// Applies an unary and infallible function to a primitive array. /// This is the fastest way to perform an operation on a primitive array when /// the benefits of a vectorized operation outweights the cost of branching nulls and non-nulls. /// # Implementation /// This will apply the function for all values, including those on null slots. -/// This implies that the operation must be infalible for any value of the corresponding type +/// This implies that the operation must be infallible for any value of the corresponding type /// or this function may panic. /// # Example /// ```rust diff --git a/arrow/src/compute/kernels/boolean.rs b/arrow/src/compute/kernels/boolean.rs index fd3539455b5a..209edc48d195 100644 --- a/arrow/src/compute/kernels/boolean.rs +++ b/arrow/src/compute/kernels/boolean.rs @@ -193,7 +193,7 @@ where let left_data = left.data_ref(); let right_data = right.data_ref(); - let null_bit_buffer = combine_option_bitmap(left_data, right_data, len)?; + let null_bit_buffer = combine_option_bitmap(&[left_data, right_data], len)?; let left_buffer = &left_data.buffers()[0]; let right_buffer = &right_data.buffers()[0]; @@ -1010,7 +1010,7 @@ mod tests { let expected = BooleanArray::from(vec![false, false, false, false]); assert_eq!(expected, res); - assert_eq!(&None, res.data_ref().null_bitmap()); + assert_eq!(None, res.data_ref().null_bitmap()); } #[test] @@ -1023,7 +1023,7 @@ mod tests { let expected = BooleanArray::from(vec![false, false, false, false]); assert_eq!(expected, res); - assert_eq!(&None, res.data_ref().null_bitmap()); + assert_eq!(None, res.data_ref().null_bitmap()); } #[test] @@ -1035,7 +1035,7 @@ mod tests { let expected = BooleanArray::from(vec![true, true, true, true]); assert_eq!(expected, res); - assert_eq!(&None, res.data_ref().null_bitmap()); + assert_eq!(None, res.data_ref().null_bitmap()); } #[test] @@ -1048,7 +1048,7 @@ mod tests { let expected = BooleanArray::from(vec![true, true, true, true]); assert_eq!(expected, res); - assert_eq!(&None, res.data_ref().null_bitmap()); + assert_eq!(None, res.data_ref().null_bitmap()); } #[test] @@ -1060,7 +1060,7 @@ mod tests { let expected = BooleanArray::from(vec![false, true, false, true]); assert_eq!(expected, res); - assert_eq!(&None, res.data_ref().null_bitmap()); + assert_eq!(None, res.data_ref().null_bitmap()); } #[test] @@ -1091,7 +1091,7 @@ mod tests { let expected = BooleanArray::from(vec![false, true, false, true]); assert_eq!(expected, res); - assert_eq!(&None, res.data_ref().null_bitmap()); + assert_eq!(None, res.data_ref().null_bitmap()); } #[test] @@ -1103,7 +1103,7 @@ mod tests { let expected = 
BooleanArray::from(vec![true, false, true, false]); assert_eq!(expected, res); - assert_eq!(&None, res.data_ref().null_bitmap()); + assert_eq!(None, res.data_ref().null_bitmap()); } #[test] @@ -1134,7 +1134,7 @@ mod tests { let expected = BooleanArray::from(vec![true, false, true, false]); assert_eq!(expected, res); - assert_eq!(&None, res.data_ref().null_bitmap()); + assert_eq!(None, res.data_ref().null_bitmap()); } #[test] diff --git a/arrow/src/compute/kernels/cast.rs b/arrow/src/compute/kernels/cast.rs index 69733fa1bd9b..fa92179b747c 100644 --- a/arrow/src/compute/kernels/cast.rs +++ b/arrow/src/compute/kernels/cast.rs @@ -68,6 +68,79 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { } match (from_type, to_type) { + // TODO UTF8/unsigned numeric to decimal + // cast one decimal type to another decimal type + (Decimal(_, _), Decimal(_, _)) => true, + // signed numeric to decimal + (Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal(_, _)) | + // decimal to signed numeric + (Decimal(_, _), Int8 | Int16 | Int32 | Int64 | Float32 | Float64) + | ( + Null, + Boolean + | Int8 + | UInt8 + | Int16 + | UInt16 + | Int32 + | UInt32 + | Float32 + | Date32 + | Time32(_) + | Int64 + | UInt64 + | Float64 + | Date64 + | Timestamp(_, _) + | Time64(_) + | Duration(_) + | Interval(_) + | FixedSizeBinary(_) + | Binary + | Utf8 + | LargeBinary + | LargeUtf8 + | List(_) + | LargeList(_) + | FixedSizeList(_, _) + | Struct(_) + | Map(_, _) + | Dictionary(_, _), + ) + | ( + Boolean + | Int8 + | UInt8 + | Int16 + | UInt16 + | Int32 + | UInt32 + | Float32 + | Date32 + | Time32(_) + | Int64 + | UInt64 + | Float64 + | Date64 + | Timestamp(_, _) + | Time64(_) + | Duration(_) + | Interval(_) + | FixedSizeBinary(_) + | Binary + | Utf8 + | LargeBinary + | LargeUtf8 + | List(_) + | LargeList(_) + | FixedSizeList(_, _) + | Struct(_) + | Map(_, _) + | Dictionary(_, _), + Null, + ) => true, + (Decimal(_, _), _) => false, + (_, Decimal(_, _)) => false, (Struct(_), _) => false, (_, Struct(_)) => false, (LargeList(list_from), LargeList(list_to)) => { @@ -88,138 +161,78 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Dictionary(_, value_type), _) => can_cast_types(value_type, to_type), (_, Dictionary(_, value_type)) => can_cast_types(from_type, value_type), - (_, Boolean) => DataType::is_numeric(from_type), + (_, Boolean) => DataType::is_numeric(from_type) || from_type == &Utf8, (Boolean, _) => DataType::is_numeric(to_type) || to_type == &Utf8, (Utf8, LargeUtf8) => true, (LargeUtf8, Utf8) => true, - (Utf8, Date32) => true, - (Utf8, Date64) => true, - (Utf8, Timestamp(TimeUnit::Nanosecond, None)) => true, + (Utf8, Date32 | Date64 | Timestamp(TimeUnit::Nanosecond, None)) => true, (Utf8, _) => DataType::is_numeric(to_type), - (LargeUtf8, Date32) => true, - (LargeUtf8, Date64) => true, - (LargeUtf8, Timestamp(TimeUnit::Nanosecond, None)) => true, + (LargeUtf8, Date32 | Date64 | Timestamp(TimeUnit::Nanosecond, None)) => true, (LargeUtf8, _) => DataType::is_numeric(to_type), (Timestamp(_, _), Utf8) | (Timestamp(_, _), LargeUtf8) => true, - (_, Utf8) | (_, LargeUtf8) => { - DataType::is_numeric(from_type) || from_type == &Binary - } + (Date32, Utf8) | (Date32, LargeUtf8) => true, + (Date64, Utf8) | (Date64, LargeUtf8) => true, + (_, Utf8 | LargeUtf8) => DataType::is_numeric(from_type) || from_type == &Binary, // start numeric casts - (UInt8, UInt16) => true, - (UInt8, UInt32) => true, - (UInt8, UInt64) => true, - (UInt8, Int8) => true, - (UInt8, Int16) => true, - 
(UInt8, Int32) => true, - (UInt8, Int64) => true, - (UInt8, Float32) => true, - (UInt8, Float64) => true, - - (UInt16, UInt8) => true, - (UInt16, UInt32) => true, - (UInt16, UInt64) => true, - (UInt16, Int8) => true, - (UInt16, Int16) => true, - (UInt16, Int32) => true, - (UInt16, Int64) => true, - (UInt16, Float32) => true, - (UInt16, Float64) => true, - - (UInt32, UInt8) => true, - (UInt32, UInt16) => true, - (UInt32, UInt64) => true, - (UInt32, Int8) => true, - (UInt32, Int16) => true, - (UInt32, Int32) => true, - (UInt32, Int64) => true, - (UInt32, Float32) => true, - (UInt32, Float64) => true, - - (UInt64, UInt8) => true, - (UInt64, UInt16) => true, - (UInt64, UInt32) => true, - (UInt64, Int8) => true, - (UInt64, Int16) => true, - (UInt64, Int32) => true, - (UInt64, Int64) => true, - (UInt64, Float32) => true, - (UInt64, Float64) => true, - - (Int8, UInt8) => true, - (Int8, UInt16) => true, - (Int8, UInt32) => true, - (Int8, UInt64) => true, - (Int8, Int16) => true, - (Int8, Int32) => true, - (Int8, Int64) => true, - (Int8, Float32) => true, - (Int8, Float64) => true, - - (Int16, UInt8) => true, - (Int16, UInt16) => true, - (Int16, UInt32) => true, - (Int16, UInt64) => true, - (Int16, Int8) => true, - (Int16, Int32) => true, - (Int16, Int64) => true, - (Int16, Float32) => true, - (Int16, Float64) => true, - - (Int32, UInt8) => true, - (Int32, UInt16) => true, - (Int32, UInt32) => true, - (Int32, UInt64) => true, - (Int32, Int8) => true, - (Int32, Int16) => true, - (Int32, Int64) => true, - (Int32, Float32) => true, - (Int32, Float64) => true, - - (Int64, UInt8) => true, - (Int64, UInt16) => true, - (Int64, UInt32) => true, - (Int64, UInt64) => true, - (Int64, Int8) => true, - (Int64, Int16) => true, - (Int64, Int32) => true, - (Int64, Float32) => true, - (Int64, Float64) => true, - - (Float32, UInt8) => true, - (Float32, UInt16) => true, - (Float32, UInt32) => true, - (Float32, UInt64) => true, - (Float32, Int8) => true, - (Float32, Int16) => true, - (Float32, Int32) => true, - (Float32, Int64) => true, - (Float32, Float64) => true, - - (Float64, UInt8) => true, - (Float64, UInt16) => true, - (Float64, UInt32) => true, - (Float64, UInt64) => true, - (Float64, Int8) => true, - (Float64, Int16) => true, - (Float64, Int32) => true, - (Float64, Int64) => true, - (Float64, Float32) => true, + ( + UInt8, + UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, + ) => true, + + ( + UInt16, + UInt8 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, + ) => true, + + ( + UInt32, + UInt8 | UInt16 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, + ) => true, + + ( + UInt64, + UInt8 | UInt16 | UInt32 | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, + ) => true, + + ( + Int8, + UInt8 | UInt16 | UInt32 | UInt64 | Int16 | Int32 | Int64 | Float32 | Float64, + ) => true, + + ( + Int16, + UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int32 | Int64 | Float32 | Float64, + ) => true, + + ( + Int32, + UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int64 | Float32 | Float64, + ) => true, + + ( + Int64, + UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Float32 | Float64, + ) => true, + + ( + Float32, + UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float64, + ) => true, + + ( + Float64, + UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float32, + ) => true, // end numeric casts // temporal casts - (Int32, Date32) => true, - (Int32, Date64) => true, - (Int32, Time32(_)) => true, - (Date32, Int32) => true, - 
(Date32, Int64) => true, + (Int32, Date32 | Date64 | Time32(_)) => true, + (Date32, Int32 | Int64) => true, (Time32(_), Int32) => true, - (Int64, Date64) => true, - (Int64, Date32) => true, - (Int64, Time64(_)) => true, - (Date64, Int64) => true, - (Date64, Int32) => true, + (Int64, Date64 | Date32 | Time64(_)) => true, + (Date64, Int64 | Int32) => true, (Time64(_), Int64) => true, (Date32, Date64) => true, (Date64, Date32) => true, @@ -233,12 +246,31 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { } (Timestamp(_, _), Int64) => true, (Int64, Timestamp(_, _)) => true, - (Timestamp(_, _), Timestamp(_, _)) => true, - (Timestamp(_, _), Date32) => true, - (Timestamp(_, _), Date64) => true, + (Timestamp(_, _), Timestamp(_, _) | Date32 | Date64) => true, // date64 to timestamp might not make sense, (Int64, Duration(_)) => true, - (Null, Int32) => true, + (Duration(_), Int64) => true, + (Interval(from_type), Int64) => { + match from_type{ + IntervalUnit::YearMonth => true, + IntervalUnit::DayTime => true, + IntervalUnit::MonthDayNano => false, // Native type is i128 + } + }, + (Int32, Interval(to_type)) => { + match to_type{ + IntervalUnit::YearMonth => true, + IntervalUnit::DayTime => false, + IntervalUnit::MonthDayNano => false, + } + }, + (Int64, Interval(to_type)) => { + match to_type{ + IntervalUnit::YearMonth => false, + IntervalUnit::DayTime => true, + IntervalUnit::MonthDayNano => false, + } + } (_, _) => false, } } @@ -248,6 +280,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { /// /// Behavior: /// * Boolean to Utf8: `true` => '1', `false` => `0` +/// * Utf8 to boolean: `true`, `yes`, `on`, `1` => `true`, `false`, `no`, `off`, `0` => `false`, +/// short variants are accepted, other strings return null or error /// * Utf8 to numeric: strings that can't be parsed to numbers return null, float strings /// in integer casts return null /// * Numeric to boolean: 0 returns `false`, any other value returns `true` @@ -261,18 +295,110 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { /// Unsupported Casts /// * To or from `StructArray` /// * List to primitive -/// * Utf8 to boolean /// * Interval and duration pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result { cast_with_options(array, to_type, &DEFAULT_CAST_OPTIONS) } +// cast the integer array to defined decimal data type array +macro_rules! cast_integer_to_decimal { + ($ARRAY: expr, $ARRAY_TYPE: ident, $PRECISION : ident, $SCALE : ident) => {{ + let array = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); + let mul: i128 = 10_i128.pow(*$SCALE as u32); + let decimal_array = array + .iter() + .map(|v| { + v.map(|v| { + let v = v as i128; + // with_precision_and_scale validates the + // value is within range for the output precision + mul * v + }) + }) + .collect::() + .with_precision_and_scale(*$PRECISION, *$SCALE)?; + Ok(Arc::new(decimal_array)) + }}; +} + +// cast the floating-point array to defined decimal data type array +macro_rules! 
cast_floating_point_to_decimal { + ($ARRAY: expr, $ARRAY_TYPE: ident, $PRECISION : ident, $SCALE : ident) => {{ + let array = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); + let mul = 10_f64.powi(*$SCALE as i32); + let decimal_array = array + .iter() + .map(|v| { + v.map(|v| { + // with_precision_and_scale validates the + // value is within range for the output precision + ((v as f64) * mul) as i128 + }) + }) + .collect::() + .with_precision_and_scale(*$PRECISION, *$SCALE)?; + Ok(Arc::new(decimal_array)) + }}; +} + +// cast the decimal array to integer array +macro_rules! cast_decimal_to_integer { + ($ARRAY:expr, $SCALE : ident, $VALUE_BUILDER: ident, $NATIVE_TYPE : ident, $DATA_TYPE : expr) => {{ + let array = $ARRAY.as_any().downcast_ref::().unwrap(); + let mut value_builder = $VALUE_BUILDER::new(array.len()); + let div: i128 = 10_i128.pow(*$SCALE as u32); + let min_bound = ($NATIVE_TYPE::MIN) as i128; + let max_bound = ($NATIVE_TYPE::MAX) as i128; + for i in 0..array.len() { + if array.is_null(i) { + value_builder.append_null()?; + } else { + let v = array.value(i).as_i128() / div; + // check the overflow + // For example: Decimal(128,10,0) as i8 + // 128 is out of range i8 + if v <= max_bound && v >= min_bound { + value_builder.append_value(v as $NATIVE_TYPE)?; + } else { + return Err(ArrowError::CastError(format!( + "value of {} is out of range {}", + v, $DATA_TYPE + ))); + } + } + } + Ok(Arc::new(value_builder.finish())) + }}; +} + +// cast the decimal array to floating-point array +macro_rules! cast_decimal_to_float { + ($ARRAY:expr, $SCALE : ident, $VALUE_BUILDER: ident, $NATIVE_TYPE : ty) => {{ + let array = $ARRAY.as_any().downcast_ref::().unwrap(); + let div = 10_f64.powi(*$SCALE as i32); + let mut value_builder = $VALUE_BUILDER::new(array.len()); + for i in 0..array.len() { + if array.is_null(i) { + value_builder.append_null()?; + } else { + // The range of f32 or f64 is larger than i128, we don't need to check overflow. + // cast the i128 to f64 will lose precision, for example the `112345678901234568` will be as `112345678901234560`. + let v = (array.value(i).as_i128() as f64 / div) as $NATIVE_TYPE; + value_builder.append_value(v)?; + } + } + Ok(Arc::new(value_builder.finish())) + }}; +} + /// Cast `array` to the provided data type and return a new Array with /// type `to_type`, if possible. It accepts `CastOptions` to allow consumers /// to configure cast behavior. 
/// /// Behavior: /// * Boolean to Utf8: `true` => '1', `false` => `0` +/// * Utf8 to boolean: `true`, `yes`, `on`, `1` => `true`, `false`, `no`, `off`, `0` => `false`, +/// short variants are accepted, other strings return null or error /// * Utf8 to numeric: strings that can't be parsed to numbers return null, float strings /// in integer casts return null /// * Numeric to boolean: 0 returns `false`, any other value returns `true` @@ -286,8 +412,6 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result { /// Unsupported Casts /// * To or from `StructArray` /// * List to primitive -/// * Utf8 to boolean -/// * Interval and duration pub fn cast_with_options( array: &ArrayRef, to_type: &DataType, @@ -301,6 +425,126 @@ pub fn cast_with_options( return Ok(array.clone()); } match (from_type, to_type) { + (Decimal(_, s1), Decimal(p2, s2)) => cast_decimal_to_decimal(array, s1, p2, s2), + (Decimal(_, scale), _) => { + // cast decimal to other type + match to_type { + Int8 => { + cast_decimal_to_integer!(array, scale, Int8Builder, i8, Int8) + } + Int16 => { + cast_decimal_to_integer!(array, scale, Int16Builder, i16, Int16) + } + Int32 => { + cast_decimal_to_integer!(array, scale, Int32Builder, i32, Int32) + } + Int64 => { + cast_decimal_to_integer!(array, scale, Int64Builder, i64, Int64) + } + Float32 => { + cast_decimal_to_float!(array, scale, Float32Builder, f32) + } + Float64 => { + cast_decimal_to_float!(array, scale, Float64Builder, f64) + } + _ => Err(ArrowError::CastError(format!( + "Casting from {:?} to {:?} not supported", + from_type, to_type + ))), + } + } + (_, Decimal(precision, scale)) => { + // cast data to decimal + match from_type { + // TODO now just support signed numeric to decimal, support decimal to numeric later + Int8 => { + cast_integer_to_decimal!(array, Int8Array, precision, scale) + } + Int16 => { + cast_integer_to_decimal!(array, Int16Array, precision, scale) + } + Int32 => { + cast_integer_to_decimal!(array, Int32Array, precision, scale) + } + Int64 => { + cast_integer_to_decimal!(array, Int64Array, precision, scale) + } + Float32 => { + cast_floating_point_to_decimal!(array, Float32Array, precision, scale) + } + Float64 => { + cast_floating_point_to_decimal!(array, Float64Array, precision, scale) + } + _ => Err(ArrowError::CastError(format!( + "Casting from {:?} to {:?} not supported", + from_type, to_type + ))), + } + } + ( + Null, + Boolean + | Int8 + | UInt8 + | Int16 + | UInt16 + | Int32 + | UInt32 + | Float32 + | Date32 + | Time32(_) + | Int64 + | UInt64 + | Float64 + | Date64 + | Timestamp(_, _) + | Time64(_) + | Duration(_) + | Interval(_) + | FixedSizeBinary(_) + | Binary + | Utf8 + | LargeBinary + | LargeUtf8 + | List(_) + | LargeList(_) + | FixedSizeList(_, _) + | Struct(_) + | Map(_, _) + | Dictionary(_, _), + ) + | ( + Boolean + | Int8 + | UInt8 + | Int16 + | UInt16 + | Int32 + | UInt32 + | Float32 + | Date32 + | Time32(_) + | Int64 + | UInt64 + | Float64 + | Date64 + | Timestamp(_, _) + | Time64(_) + | Duration(_) + | Interval(_) + | FixedSizeBinary(_) + | Binary + | Utf8 + | LargeBinary + | LargeUtf8 + | List(_) + | LargeList(_) + | FixedSizeList(_, _) + | Struct(_) + | Map(_, _) + | Dictionary(_, _), + Null, + ) => Ok(new_null_array(to_type, array.len())), (Struct(_), _) => Err(ArrowError::CastError( "Cannot cast from struct to other types".to_string(), )), @@ -401,10 +645,7 @@ pub fn cast_with_options( Int64 => cast_numeric_to_bool::(array), Float32 => cast_numeric_to_bool::(array), Float64 => cast_numeric_to_bool::(array), - Utf8 => 
Err(ArrowError::CastError(format!( - "Casting from {:?} to {:?} not supported", - from_type, to_type, - ))), + Utf8 => cast_utf8_to_boolean(array, cast_options), _ => Err(ArrowError::CastError(format!( "Casting from {:?} to {:?} not supported", from_type, to_type, @@ -483,6 +724,8 @@ pub fn cast_with_options( cast_timestamp_to_string::(array) } }, + Date32 => cast_date32_to_string::(array), + Date64 => cast_date64_to_string::(array), Binary => { let array = array.as_any().downcast_ref::().unwrap(); Ok(Arc::new( @@ -537,6 +780,8 @@ pub fn cast_with_options( cast_timestamp_to_string::(array) } }, + Date32 => cast_date32_to_string::(array), + Date64 => cast_date64_to_string::(array), Binary => { let array = array.as_any().downcast_ref::().unwrap(); Ok(Arc::new( @@ -946,10 +1191,35 @@ pub fn cast_with_options( } } } - - // null to primitive/flat types - (Null, Int32) => Ok(Arc::new(Int32Array::from(vec![None; array.len()]))), - + (Duration(_), Int64) => cast_array_data::(array, to_type.clone()), + (Interval(from_type), Int64) => match from_type { + IntervalUnit::YearMonth => { + cast_numeric_arrays::(array) + } + IntervalUnit::DayTime => cast_array_data::(array, to_type.clone()), + IntervalUnit::MonthDayNano => Err(ArrowError::CastError(format!( + "Casting from {:?} to {:?} not supported", + from_type, to_type, + ))), + }, + (Int32, Interval(to_type)) => match to_type { + IntervalUnit::YearMonth => { + cast_array_data::(array, Interval(to_type.clone())) + } + _ => Err(ArrowError::CastError(format!( + "Casting from {:?} to {:?} not supported", + from_type, to_type, + ))), + }, + (Int64, Interval(to_type)) => match to_type { + IntervalUnit::DayTime => { + cast_array_data::(array, Interval(to_type.clone())) + } + _ => Err(ArrowError::CastError(format!( + "Casting from {:?} to {:?} not supported", + from_type, to_type, + ))), + }, (_, _) => Err(ArrowError::CastError(format!( "Casting from {:?} to {:?} not supported", from_type, to_type, @@ -980,6 +1250,37 @@ const MILLISECONDS_IN_DAY: i64 = SECONDS_IN_DAY * MILLISECONDS; /// Number of days between 0001-01-01 and 1970-01-01 const EPOCH_DAYS_FROM_CE: i32 = 719_163; +/// Cast one type of decimal array to another type of decimal array +fn cast_decimal_to_decimal( + array: &ArrayRef, + input_scale: &usize, + output_precision: &usize, + output_scale: &usize, +) -> Result { + let array = array.as_any().downcast_ref::().unwrap(); + + let output_array = if input_scale > output_scale { + // For example, input_scale is 4 and output_scale is 3; + // Original value is 11234_i128, and will be cast to 1123_i128. + let div = 10_i128.pow((input_scale - output_scale) as u32); + array + .iter() + .map(|v| v.map(|v| v / div)) + .collect::() + } else { + // For example, input_scale is 3 and output_scale is 4; + // Original value is 1123_i128, and will be cast to 11230_i128. + let mul = 10_i128.pow((output_scale - input_scale) as u32); + array + .iter() + .map(|v| v.map(|v| v * mul)) + .collect::() + } + .with_precision_and_scale(*output_precision, *output_scale)?; + + Ok(Arc::new(output_array)) +} + /// Cast an array by changing its array_data type to the desired type /// /// Arrays should have the same primitive data type, otherwise this should fail. 
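// Editorial worked example (illustrative only): the rescaling rule used by
// `cast_decimal_to_decimal` above, spelled out on the raw i128 values that
// back a DecimalArray. The numbers mirror the comments in that function.
fn _decimal_rescale_sketch() {
    // input_scale = 4, output_scale = 3: drop a fractional digit by integer division
    let div = 10_i128.pow((4 - 3) as u32);
    assert_eq!(11234_i128 / div, 1123_i128);

    // input_scale = 3, output_scale = 4: add a fractional digit by multiplication
    let mul = 10_i128.pow((4 - 3) as u32);
    assert_eq!(1123_i128 * mul, 11230_i128);
}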
@@ -994,7 +1295,11 @@ where to_type, array.len(), Some(array.null_count()), - array.data().null_bitmap().clone().map(|bitmap| bitmap.bits), + array + .data() + .null_bitmap() + .cloned() + .map(|bitmap| bitmap.bits), array.data().offset(), array.data().buffers().to_vec(), vec![], @@ -1039,7 +1344,7 @@ fn cast_timestamp_to_string(array: &ArrayRef) -> Result where T: ArrowTemporalType + ArrowNumericType, i64: From<::Native>, - OffsetSize: StringOffsetSizeTrait, + OffsetSize: OffsetSizeTrait, { let array = array.as_any().downcast_ref::>().unwrap(); @@ -1056,12 +1361,50 @@ where )) } +/// Cast date32 types to Utf8/LargeUtf8 +fn cast_date32_to_string( + array: &ArrayRef, +) -> Result { + let array = array.as_any().downcast_ref::().unwrap(); + + Ok(Arc::new( + (0..array.len()) + .map(|ix| { + if array.is_null(ix) { + None + } else { + array.value_as_date(ix).map(|v| v.to_string()) + } + }) + .collect::>(), + )) +} + +/// Cast date64 types to Utf8/LargeUtf8 +fn cast_date64_to_string( + array: &ArrayRef, +) -> Result { + let array = array.as_any().downcast_ref::().unwrap(); + + Ok(Arc::new( + (0..array.len()) + .map(|ix| { + if array.is_null(ix) { + None + } else { + array.value_as_datetime(ix).map(|v| v.to_string()) + } + }) + .collect::>(), + )) +} + /// Cast numeric types to Utf8 fn cast_numeric_to_string(array: &ArrayRef) -> Result where FROM: ArrowNumericType, FROM::Native: lexical_core::ToLexical, - OffsetSize: StringOffsetSizeTrait, + OffsetSize: OffsetSizeTrait, { Ok(Arc::new(numeric_to_string_cast::( array @@ -1077,7 +1420,7 @@ fn numeric_to_string_cast( where T: ArrowPrimitiveType + ArrowNumericType, T::Native: lexical_core::ToLexical, - OffsetSize: StringOffsetSizeTrait, + OffsetSize: OffsetSizeTrait, { from.iter() .map(|maybe_value| maybe_value.map(lexical_to_string)) @@ -1085,7 +1428,7 @@ where } /// Cast numeric types to Utf8 -fn cast_string_to_numeric( +fn cast_string_to_numeric( from: &ArrayRef, cast_options: &CastOptions, ) -> Result @@ -1101,7 +1444,7 @@ where )?)) } -fn string_to_numeric_cast( +fn string_to_numeric_cast( from: &GenericStringArray, cast_options: &CastOptions, ) -> Result> @@ -1150,7 +1493,7 @@ where } /// Casts generic string arrays to Date32Array -fn cast_string_to_date32( +fn cast_string_to_date32( array: &dyn Array, cast_options: &CastOptions, ) -> Result { @@ -1212,7 +1555,7 @@ fn cast_string_to_date32( } /// Casts generic string arrays to Date64Array -fn cast_string_to_date64( +fn cast_string_to_date64( array: &dyn Array, cast_options: &CastOptions, ) -> Result { @@ -1245,7 +1588,7 @@ fn cast_string_to_date64( if string_array.is_null(i) { Ok(None) } else { - let string = string_array + let string = string_array .value(i); let result = string @@ -1273,7 +1616,7 @@ fn cast_string_to_date64( } /// Casts generic string arrays to TimeStampNanosecondArray -fn cast_string_to_timestamp_ns( +fn cast_string_to_timestamp_ns( array: &dyn Array, cast_options: &CastOptions, ) -> Result { @@ -1317,6 +1660,34 @@ fn cast_string_to_timestamp_ns( Ok(Arc::new(array) as ArrayRef) } +/// Casts Utf8 to Boolean +fn cast_utf8_to_boolean(from: &ArrayRef, cast_options: &CastOptions) -> Result { + let array = as_string_array(from); + + let output_array = array + .iter() + .map(|value| match value { + Some(value) => match value.to_ascii_lowercase().trim() { + "t" | "tr" | "tru" | "true" | "y" | "ye" | "yes" | "on" | "1" => { + Ok(Some(true)) + } + "f" | "fa" | "fal" | "fals" | "false" | "n" | "no" | "of" | "off" + | "0" => Ok(Some(false)), + invalid_value => match cast_options.safe { + 
true => Ok(None), + false => Err(ArrowError::CastError(format!( + "Cannot cast string '{}' to value of Boolean type", + invalid_value, + ))), + }, + }, + None => Ok(None), + }) + .collect::>()?; + + Ok(Arc::new(output_array)) +} + /// Cast numeric types to Boolean /// /// Any zero value returns `false` while non-zero returns `true` @@ -1442,7 +1813,7 @@ fn dictionary_cast( cast_keys .data() .null_bitmap() - .clone() + .cloned() .map(|bitmap| bitmap.bits), cast_keys.data().offset(), cast_keys.data().buffers().to_vec(), @@ -1464,7 +1835,7 @@ fn dictionary_cast( return Err(ArrowError::CastError(format!( "Unsupported type {:?} for dictionary index", to_index_type - ))) + ))); } }; @@ -1499,9 +1870,9 @@ where // Note take requires first casting the indices to u32 let keys_array: ArrayRef = Arc::new(PrimitiveArray::::from(dict_array.keys().data().clone())); - let indicies = cast_with_options(&keys_array, &DataType::UInt32, cast_options)?; - let u32_indicies = - indicies + let indices = cast_with_options(&keys_array, &DataType::UInt32, cast_options)?; + let u32_indices = + indices .as_any() .downcast_ref::() .ok_or_else(|| { @@ -1510,7 +1881,7 @@ where ) })?; - take(cast_dict_values.as_ref(), u32_indicies, None) + take(cast_dict_values.as_ref(), u32_indices, None) } /// Attempts to encode an array into an `ArrayDictionary` with index @@ -1660,7 +2031,7 @@ fn cast_primitive_to_list( cast_array .data() .null_bitmap() - .clone() + .cloned() .map(|bitmap| bitmap.bits), 0, vec![offsets.into()], @@ -1688,7 +2059,7 @@ fn cast_list_inner( to_type.clone(), array.len(), Some(data.null_count()), - data.null_bitmap().clone().map(|bitmap| bitmap.bits), + data.null_bitmap().cloned().map(|bitmap| bitmap.bits), array.offset(), // reuse offset buffer data.buffers().to_vec(), @@ -1703,8 +2074,8 @@ fn cast_list_inner( /// a `Utf8` array it will return an Error. 
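// Editorial sketch (illustrative only, using the public `cast` entry point)
// of the offset widening performed by `cast_str_container` below: Utf8 stores
// i32 offsets, LargeUtf8 stores i64 offsets, and the value and null buffers
// are carried over unchanged.
fn _utf8_to_large_utf8_sketch() {
    use crate::array::{Array, ArrayRef, LargeStringArray, StringArray};
    use crate::datatypes::DataType;
    use std::sync::Arc;

    let small = Arc::new(StringArray::from(vec![Some("a"), None, Some("bc")])) as ArrayRef;
    let large = cast(&small, &DataType::LargeUtf8).unwrap();
    let large = large.as_any().downcast_ref::<LargeStringArray>().unwrap();

    assert_eq!(large.value(0), "a");
    assert!(large.is_null(1));
    assert_eq!(large.value(2), "bc");
}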
fn cast_str_container(array: &dyn Array) -> Result where - OffsetSizeFrom: StringOffsetSizeTrait + ToPrimitive, - OffsetSizeTo: StringOffsetSizeTrait + NumCast + ArrowNativeType, + OffsetSizeFrom: OffsetSizeTrait + ToPrimitive, + OffsetSizeTo: OffsetSizeTrait + NumCast + ArrowNativeType, { let str_array = array .as_any() @@ -1713,7 +2084,7 @@ where let list_data = array.data(); let str_values_buf = str_array.value_data(); - let offsets = unsafe { list_data.buffers()[0].typed_data::() }; + let offsets = list_data.buffers()[0].typed_data::(); let mut offset_builder = BufferBuilder::::new(offsets.len()); offsets.iter().try_for_each::<_, Result<_>>(|offset| { @@ -1734,15 +2105,13 @@ where DataType::Utf8 }; - let mut builder = ArrayData::builder(dtype) + let builder = ArrayData::builder(dtype) .offset(array.offset()) .len(array.len()) .add_buffer(offset_buffer) - .add_buffer(str_values_buf); + .add_buffer(str_values_buf) + .null_bit_buffer(list_data.null_buffer().cloned()); - if let Some(buf) = list_data.null_buffer() { - builder = builder.null_bit_buffer(buf.clone()) - } let array_data = unsafe { builder.build_unchecked() }; Ok(Arc::new(GenericStringArray::::from( @@ -1813,15 +2182,13 @@ where let offset_buffer = unsafe { Buffer::from_trusted_len_iter(iter) }; // wrap up - let mut builder = ArrayData::builder(out_dtype) + let builder = ArrayData::builder(out_dtype) .offset(array.offset()) .len(array.len()) .add_buffer(offset_buffer) - .add_child_data(value_data); + .add_child_data(value_data) + .null_bit_buffer(data.null_buffer().cloned()); - if let Some(buf) = data.null_buffer() { - builder = builder.null_bit_buffer(buf.clone()) - } let array_data = unsafe { builder.build_unchecked() }; Ok(make_array(array_data)) } @@ -1829,19 +2196,316 @@ where #[cfg(test)] mod tests { use super::*; + use crate::util::decimal::Decimal128; use crate::{buffer::Buffer, util::display::array_value_to_string}; + macro_rules! 
generate_cast_test_case { + ($INPUT_ARRAY: expr, $OUTPUT_TYPE_ARRAY: ident, $OUTPUT_TYPE: expr, $OUTPUT_VALUES: expr) => { + // assert cast type + let input_array_type = $INPUT_ARRAY.data_type(); + assert!(can_cast_types(input_array_type, $OUTPUT_TYPE)); + let casted_array = cast($INPUT_ARRAY, $OUTPUT_TYPE).unwrap(); + let result_array = casted_array + .as_any() + .downcast_ref::<$OUTPUT_TYPE_ARRAY>() + .unwrap(); + assert_eq!($OUTPUT_TYPE, result_array.data_type()); + assert_eq!(result_array.len(), $OUTPUT_VALUES.len()); + for (i, x) in $OUTPUT_VALUES.iter().enumerate() { + match x { + Some(x) => { + assert_eq!(result_array.value(i), *x); + } + None => { + assert!(result_array.is_null(i)); + } + } + } + }; + } + + fn create_decimal_array( + array: &[Option], + precision: usize, + scale: usize, + ) -> Result { + array + .iter() + .collect::() + .with_precision_and_scale(precision, scale) + } + + #[test] + fn test_cast_decimal_to_decimal() { + let input_type = DataType::Decimal(20, 3); + let output_type = DataType::Decimal(20, 4); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(1123456), Some(2123456), Some(3123456), None]; + let input_decimal_array = create_decimal_array(&array, 20, 3).unwrap(); + let array = Arc::new(input_decimal_array) as ArrayRef; + generate_cast_test_case!( + &array, + DecimalArray, + &output_type, + vec![ + Some(Decimal128::new_from_i128(20, 4, 11234560_i128)), + Some(Decimal128::new_from_i128(20, 4, 21234560_i128)), + Some(Decimal128::new_from_i128(20, 4, 31234560_i128)), + None + ] + ); + // negative test + let array = vec![Some(123456), None]; + let input_decimal_array = create_decimal_array(&array, 10, 0).unwrap(); + let array = Arc::new(input_decimal_array) as ArrayRef; + let result = cast(&array, &DataType::Decimal(2, 2)); + assert!(result.is_err()); + assert_eq!("Invalid argument error: 12345600 is too large to store in a Decimal of precision 2. 
Max is 99", + result.unwrap_err().to_string()); + } + + #[test] + fn test_cast_decimal_to_numeric() { + let decimal_type = DataType::Decimal(38, 2); + // negative test + assert!(!can_cast_types(&decimal_type, &DataType::UInt8)); + let value_array: Vec> = + vec![Some(125), Some(225), Some(325), None, Some(525)]; + let decimal_array = create_decimal_array(&value_array, 38, 2).unwrap(); + let array = Arc::new(decimal_array) as ArrayRef; + // i8 + generate_cast_test_case!( + &array, + Int8Array, + &DataType::Int8, + vec![Some(1_i8), Some(2_i8), Some(3_i8), None, Some(5_i8)] + ); + // i16 + generate_cast_test_case!( + &array, + Int16Array, + &DataType::Int16, + vec![Some(1_i16), Some(2_i16), Some(3_i16), None, Some(5_i16)] + ); + // i32 + generate_cast_test_case!( + &array, + Int32Array, + &DataType::Int32, + vec![Some(1_i32), Some(2_i32), Some(3_i32), None, Some(5_i32)] + ); + // i64 + generate_cast_test_case!( + &array, + Int64Array, + &DataType::Int64, + vec![Some(1_i64), Some(2_i64), Some(3_i64), None, Some(5_i64)] + ); + // f32 + generate_cast_test_case!( + &array, + Int64Array, + &DataType::Int64, + vec![Some(1_i64), Some(2_i64), Some(3_i64), None, Some(5_i64)] + ); + // f64 + generate_cast_test_case!( + &array, + Int64Array, + &DataType::Int64, + vec![Some(1_i64), Some(2_i64), Some(3_i64), None, Some(5_i64)] + ); + + // overflow test: out of range of max i8 + let value_array: Vec> = vec![Some(24400)]; + let decimal_array = create_decimal_array(&value_array, 38, 2).unwrap(); + let array = Arc::new(decimal_array) as ArrayRef; + let casted_array = cast(&array, &DataType::Int8); + assert_eq!( + "Cast error: value of 244 is out of range Int8".to_string(), + casted_array.unwrap_err().to_string() + ); + + // loss the precision: convert decimal to f32、f64 + // f32 + // 112345678_f32 and 112345679_f32 are same, so the 112345679_f32 will lose precision. + let value_array: Vec> = vec![ + Some(125), + Some(225), + Some(325), + None, + Some(525), + Some(112345678), + Some(112345679), + ]; + let decimal_array = create_decimal_array(&value_array, 38, 2).unwrap(); + let array = Arc::new(decimal_array) as ArrayRef; + generate_cast_test_case!( + &array, + Float32Array, + &DataType::Float32, + vec![ + Some(1.25_f32), + Some(2.25_f32), + Some(3.25_f32), + None, + Some(5.25_f32), + Some(1_123_456.7_f32), + Some(1_123_456.7_f32) + ] + ); + + // f64 + // 112345678901234568_f64 and 112345678901234560_f64 are same, so the 112345678901234568_f64 will lose precision. 
+ let value_array: Vec> = vec![ + Some(125), + Some(225), + Some(325), + None, + Some(525), + Some(112345678901234568), + Some(112345678901234560), + ]; + let decimal_array = create_decimal_array(&value_array, 38, 2).unwrap(); + let array = Arc::new(decimal_array) as ArrayRef; + generate_cast_test_case!( + &array, + Float64Array, + &DataType::Float64, + vec![ + Some(1.25_f64), + Some(2.25_f64), + Some(3.25_f64), + None, + Some(5.25_f64), + Some(1_123_456_789_012_345.6_f64), + Some(1_123_456_789_012_345.6_f64), + ] + ); + } + + #[test] + fn test_cast_numeric_to_decimal() { + // test negative cast type + let decimal_type = DataType::Decimal(38, 6); + assert!(!can_cast_types(&DataType::UInt64, &decimal_type)); + + // i8, i16, i32, i64 + let input_datas = vec![ + Arc::new(Int8Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // i8 + Arc::new(Int16Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // i16 + Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // i32 + Arc::new(Int64Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(5), + ])) as ArrayRef, // i64 + ]; + for array in input_datas { + generate_cast_test_case!( + &array, + DecimalArray, + &decimal_type, + vec![ + Some(Decimal128::new_from_i128(38, 6, 1000000_i128)), + Some(Decimal128::new_from_i128(38, 6, 2000000_i128)), + Some(Decimal128::new_from_i128(38, 6, 3000000_i128)), + None, + Some(Decimal128::new_from_i128(38, 6, 5000000_i128)) + ] + ); + } + + // test i8 to decimal type with overflow the result type + // the 100 will be converted to 1000_i128, but it is out of range for max value in the precision 3. + let array = Int8Array::from(vec![1, 2, 3, 4, 100]); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast(&array, &DataType::Decimal(3, 1)); + assert!(casted_array.is_err()); + assert_eq!("Invalid argument error: 1000 is too large to store in a Decimal of precision 3. 
Max is 999", casted_array.unwrap_err().to_string()); + + // test f32 to decimal type + let array = Float32Array::from(vec![ + Some(1.1), + Some(2.2), + Some(4.4), + None, + Some(1.123_456_7), + Some(1.123_456_7), + ]); + let array = Arc::new(array) as ArrayRef; + generate_cast_test_case!( + &array, + DecimalArray, + &decimal_type, + vec![ + Some(Decimal128::new_from_i128(38, 6, 1100000_i128)), + Some(Decimal128::new_from_i128(38, 6, 2200000_i128)), + Some(Decimal128::new_from_i128(38, 6, 4400000_i128)), + None, + Some(Decimal128::new_from_i128(38, 6, 1123456_i128)), + Some(Decimal128::new_from_i128(38, 6, 1123456_i128)), + ] + ); + + // test f64 to decimal type + let array = Float64Array::from(vec![ + Some(1.1), + Some(2.2), + Some(4.4), + None, + Some(1.123_456_789_123_4), + Some(1.123_456_789_012_345_6), + Some(1.123_456_789_012_345_6), + ]); + let array = Arc::new(array) as ArrayRef; + generate_cast_test_case!( + &array, + DecimalArray, + &decimal_type, + vec![ + Some(Decimal128::new_from_i128(38, 6, 1100000_i128)), + Some(Decimal128::new_from_i128(38, 6, 2200000_i128)), + Some(Decimal128::new_from_i128(38, 6, 4400000_i128)), + None, + Some(Decimal128::new_from_i128(38, 6, 1123456_i128)), + Some(Decimal128::new_from_i128(38, 6, 1123456_i128)), + Some(Decimal128::new_from_i128(38, 6, 1123456_i128)), + ] + ); + } + #[test] fn test_cast_i32_to_f64() { let a = Int32Array::from(vec![5, 6, 7, 8, 9]); let array = Arc::new(a) as ArrayRef; let b = cast(&array, &DataType::Float64).unwrap(); let c = b.as_any().downcast_ref::().unwrap(); - assert!(5.0 - c.value(0) < f64::EPSILON); - assert!(6.0 - c.value(1) < f64::EPSILON); - assert!(7.0 - c.value(2) < f64::EPSILON); - assert!(8.0 - c.value(3) < f64::EPSILON); - assert!(9.0 - c.value(4) < f64::EPSILON); + assert_eq!(5.0, c.value(0)); + assert_eq!(6.0, c.value(1)); + assert_eq!(7.0, c.value(2)); + assert_eq!(8.0, c.value(3)); + assert_eq!(9.0, c.value(4)); } #[test] @@ -1963,10 +2627,10 @@ mod tests { let values = arr.values(); let c = values.as_any().downcast_ref::().unwrap(); assert_eq!(1, c.null_count()); - assert!(7.0 - c.value(0) < f64::EPSILON); - assert!(8.0 - c.value(1) < f64::EPSILON); + assert_eq!(7.0, c.value(0)); + assert_eq!(8.0, c.value(1)); assert!(!c.is_valid(2)); - assert!(10.0 - c.value(3) < f64::EPSILON); + assert_eq!(10.0, c.value(3)); } #[test] @@ -1998,6 +2662,34 @@ mod tests { } } + #[test] + fn test_cast_utf8_to_bool() { + let strings = Arc::new(StringArray::from(vec![ + "true", "false", "invalid", " Y ", "", + ])) as ArrayRef; + let casted = cast(&strings, &DataType::Boolean).unwrap(); + let expected = + BooleanArray::from(vec![Some(true), Some(false), None, Some(true), None]); + assert_eq!(*as_boolean_array(&casted), expected); + } + + #[test] + fn test_cast_with_options_utf8_to_bool() { + let strings = Arc::new(StringArray::from(vec![ + "true", "false", "invalid", " Y ", "", + ])) as ArrayRef; + let casted = + cast_with_options(&strings, &DataType::Boolean, &CastOptions { safe: false }); + match casted { + Ok(_) => panic!("expected error"), + Err(e) => { + assert!(e.to_string().contains( + "Cast error: Cannot cast string 'invalid' to value of Boolean type" + )) + } + } + } + #[test] fn test_cast_bool_to_i32() { let a = BooleanArray::from(vec![Some(true), Some(false), None]); @@ -2015,8 +2707,8 @@ mod tests { let array = Arc::new(a) as ArrayRef; let b = cast(&array, &DataType::Float64).unwrap(); let c = b.as_any().downcast_ref::().unwrap(); - assert!(1.0 - c.value(0) < f64::EPSILON); - assert!(0.0 - c.value(1) < 
f64::EPSILON); + assert_eq!(1.0, c.value(0)); + assert_eq!(0.0, c.value(1)); assert!(!c.is_valid(2)); } @@ -2082,6 +2774,7 @@ mod tests { assert_eq!(4, values.null_count()); let u16arr = values.as_any().downcast_ref::().unwrap(); + // expect 4 nulls: negative numbers and overflow let expected: UInt16Array = vec![Some(0), Some(0), Some(0), None, None, None, Some(2), None] .into_iter() @@ -2170,6 +2863,48 @@ mod tests { } } + #[test] + fn test_cast_string_to_date32() { + let a1 = Arc::new(StringArray::from(vec![ + Some("2018-12-25"), + Some("Not a valid date"), + None, + ])) as ArrayRef; + let a2 = Arc::new(LargeStringArray::from(vec![ + Some("2018-12-25"), + Some("Not a valid date"), + None, + ])) as ArrayRef; + for array in &[a1, a2] { + let b = cast(array, &DataType::Date32).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!(17890, c.value(0)); + assert!(c.is_null(1)); + assert!(c.is_null(2)); + } + } + + #[test] + fn test_cast_string_to_date64() { + let a1 = Arc::new(StringArray::from(vec![ + Some("2020-09-08T12:00:00"), + Some("Not a valid date"), + None, + ])) as ArrayRef; + let a2 = Arc::new(LargeStringArray::from(vec![ + Some("2020-09-08T12:00:00"), + Some("Not a valid date"), + None, + ])) as ArrayRef; + for array in &[a1, a2] { + let b = cast(array, &DataType::Date64).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!(1599566400000, c.value(0)); + assert!(c.is_null(1)); + assert!(c.is_null(2)); + } + } + #[test] fn test_cast_date32_to_int32() { let a = Date32Array::from(vec![10000, 17890]); @@ -2249,6 +2984,28 @@ mod tests { assert!(c.is_null(2)); } + #[test] + fn test_cast_date32_to_string() { + let a = Date32Array::from(vec![10000, 17890]); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Utf8).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!(&DataType::Utf8, c.data_type()); + assert_eq!("1997-05-19", c.value(0)); + assert_eq!("2018-12-25", c.value(1)); + } + + #[test] + fn test_cast_date64_to_string() { + let a = Date64Array::from(vec![10000 * 86400000, 17890 * 86400000]); + let array = Arc::new(a) as ArrayRef; + let b = cast(&array, &DataType::Utf8).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!(&DataType::Utf8, c.data_type()); + assert_eq!("1997-05-19 00:00:00", c.value(0)); + assert_eq!("2018-12-25 00:00:00", c.value(1)); + } + #[test] fn test_cast_between_timestamps() { let a = TimestampMillisecondArray::from_opt_vec( @@ -2263,6 +3020,44 @@ mod tests { assert!(c.is_null(2)); } + #[test] + fn test_cast_duration_to_i64() { + let base = vec![5, 6, 7, 8, 100000000]; + + let duration_arrays = vec![ + Arc::new(DurationNanosecondArray::from(base.clone())) as ArrayRef, + Arc::new(DurationMicrosecondArray::from(base.clone())) as ArrayRef, + Arc::new(DurationMillisecondArray::from(base.clone())) as ArrayRef, + Arc::new(DurationSecondArray::from(base.clone())) as ArrayRef, + ]; + + for arr in duration_arrays { + assert!(can_cast_types(arr.data_type(), &DataType::Int64)); + let result = cast(&arr, &DataType::Int64).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + assert_eq!(base.as_slice(), result.values()); + } + } + + #[test] + fn test_cast_interval_to_i64() { + let base = vec![5, 6, 7, 8]; + + let interval_arrays = vec![ + Arc::new(IntervalDayTimeArray::from(base.clone())) as ArrayRef, + Arc::new(IntervalYearMonthArray::from( + base.iter().map(|x| *x as i32).collect::>(), + )) as ArrayRef, + ]; + + for arr in interval_arrays { + 
assert!(can_cast_types(arr.data_type(), &DataType::Int64)); + let result = cast(&arr, &DataType::Int64).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + assert_eq!(base.as_slice(), result.values()); + } + } + #[test] fn test_cast_to_strings() { let a = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; @@ -3518,17 +4313,115 @@ mod tests { } #[test] - fn test_cast_null_array_to_int32() { - let array = Arc::new(NullArray::new(6)) as ArrayRef; + fn test_cast_null_array_from_and_to_primitive_array() { + macro_rules! typed_test { + ($ARR_TYPE:ident, $DATATYPE:ident, $TYPE:tt) => {{ + { + let array = Arc::new(NullArray::new(6)) as ArrayRef; + let expected = $ARR_TYPE::from(vec![None; 6]); + let cast_type = DataType::$DATATYPE; + let cast_array = cast(&array, &cast_type).expect("cast failed"); + let cast_array = as_primitive_array::<$TYPE>(&cast_array); + assert_eq!(cast_array.data_type(), &cast_type); + assert_eq!(cast_array, &expected); + } + { + let array = Arc::new($ARR_TYPE::from(vec![None; 4])) as ArrayRef; + let expected = NullArray::new(4); + let cast_array = cast(&array, &DataType::Null).expect("cast failed"); + let cast_array = as_null_array(&cast_array); + assert_eq!(cast_array.data_type(), &DataType::Null); + assert_eq!(cast_array, &expected); + } + }}; + } - let expected = Int32Array::from(vec![None; 6]); + typed_test!(Int16Array, Int16, Int16Type); + typed_test!(Int32Array, Int32, Int32Type); + typed_test!(Int64Array, Int64, Int64Type); - // Cast to a dictionary (same value type, Utf8) - let cast_type = DataType::Int32; - let cast_array = cast(&array, &cast_type).expect("cast failed"); - let cast_array = as_primitive_array::(&cast_array); - assert_eq!(cast_array.data_type(), &cast_type); - assert_eq!(cast_array, &expected); + typed_test!(UInt16Array, UInt16, UInt16Type); + typed_test!(UInt32Array, UInt32, UInt32Type); + typed_test!(UInt64Array, UInt64, UInt64Type); + + typed_test!(Float32Array, Float32, Float32Type); + typed_test!(Float64Array, Float64, Float64Type); + + typed_test!(Date32Array, Date32, Date32Type); + typed_test!(Date64Array, Date64, Date64Type); + } + + fn cast_from_and_to_null(data_type: &DataType) { + // Cast from data_type to null + { + let array = new_null_array(data_type, 4); + assert_eq!(array.data_type(), data_type); + let cast_array = cast(&array, &DataType::Null).expect("cast failed"); + assert_eq!(cast_array.data_type(), &DataType::Null); + for i in 0..4 { + assert!(cast_array.is_null(i)); + } + } + // Cast from null to data_type + { + let array = new_null_array(&DataType::Null, 4); + assert_eq!(array.data_type(), &DataType::Null); + let cast_array = cast(&array, data_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), data_type); + for i in 0..4 { + assert!(cast_array.is_null(i)); + } + } + } + + #[test] + fn test_cast_null_from_and_to_variable_sized() { + cast_from_and_to_null(&DataType::Utf8); + cast_from_and_to_null(&DataType::LargeUtf8); + cast_from_and_to_null(&DataType::Binary); + cast_from_and_to_null(&DataType::LargeBinary); + } + + #[test] + fn test_cast_null_from_and_to_nested_type() { + // Cast null from and to map + let data_type = DataType::Map( + Box::new(Field::new( + "entry", + DataType::Struct(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Int32, true), + ]), + false, + )), + false, + ); + cast_from_and_to_null(&data_type); + + // Cast null from and to list + let data_type = + DataType::List(Box::new(Field::new("item", DataType::Int32, true))); + 
cast_from_and_to_null(&data_type); + let data_type = + DataType::LargeList(Box::new(Field::new("item", DataType::Int32, true))); + cast_from_and_to_null(&data_type); + let data_type = DataType::FixedSizeList( + Box::new(Field::new("item", DataType::Int32, true)), + 4, + ); + cast_from_and_to_null(&data_type); + + // Cast null from and to dictionary + let values = vec![None, None, None, None] as Vec>; + let array: DictionaryArray = values.into_iter().collect(); + let array = Arc::new(array) as ArrayRef; + let data_type = array.data_type().to_owned(); + cast_from_and_to_null(&data_type); + + // Cast null from and to struct + let data_type = + DataType::Struct(vec![Field::new("data", DataType::Int64, false)]); + cast_from_and_to_null(&data_type); } /// Print the `DictionaryArray` `array` as a vector of strings @@ -3712,7 +4605,7 @@ mod tests { Arc::new(Int32Array::from(vec![42, 28, 19, 31])), ), ])), - //Arc::new(make_union_array()), + Arc::new(make_union_array()), Arc::new(NullArray::new(10)), Arc::new(StringArray::from(vec!["foo", "bar"])), Arc::new(LargeStringArray::from(vec!["foo", "bar"])), @@ -3755,10 +4648,14 @@ mod tests { Arc::new(Time64NanosecondArray::from(vec![1000, 2000])), Arc::new(IntervalYearMonthArray::from(vec![1000, 2000])), Arc::new(IntervalDayTimeArray::from(vec![1000, 2000])), + Arc::new(IntervalMonthDayNanoArray::from(vec![1000, 2000])), Arc::new(DurationSecondArray::from(vec![1000, 2000])), Arc::new(DurationMillisecondArray::from(vec![1000, 2000])), Arc::new(DurationMicrosecondArray::from(vec![1000, 2000])), Arc::new(DurationNanosecondArray::from(vec![1000, 2000])), + Arc::new( + create_decimal_array(&[Some(1), Some(2), Some(3), None], 38, 0).unwrap(), + ), ] } @@ -3910,6 +4807,7 @@ mod tests { Duration(TimeUnit::Nanosecond), Interval(IntervalUnit::YearMonth), Interval(IntervalUnit::DayTime), + Interval(IntervalUnit::MonthDayNano), Binary, FixedSizeBinary(10), LargeBinary, @@ -3925,13 +4823,18 @@ mod tests { Field::new("f1", DataType::Int32, false), Field::new("f2", DataType::Utf8, true), ]), - Union(vec![ - Field::new("f1", DataType::Int32, false), - Field::new("f2", DataType::Utf8, true), - ]), + Union( + vec![ + Field::new("f1", DataType::Int32, false), + Field::new("f2", DataType::Utf8, true), + ], + vec![0, 1], + UnionMode::Dense, + ), Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)), Dictionary(Box::new(DataType::Int16), Box::new(DataType::Utf8)), Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)), + Decimal(38, 0), ] } diff --git a/arrow/src/compute/kernels/cast_utils.rs b/arrow/src/compute/kernels/cast_utils.rs index 8c1b6696722b..e43961b4ab8a 100644 --- a/arrow/src/compute/kernels/cast_utils.rs +++ b/arrow/src/compute/kernels/cast_utils.rs @@ -54,7 +54,7 @@ use chrono::{prelude::*, LocalResult}; /// /// Numerical values of timestamps are stored compared to offset UTC. 
/// -/// This function intertprets strings without an explicit time zone as +/// This function interprets strings without an explicit time zone as /// timestamps with offsets of the local time on the machine /// /// For example, `1997-01-31 09:26:56.123Z` is interpreted as UTC, as @@ -202,7 +202,7 @@ mod tests { Ok(()) } - /// Interprets a naive_datetime (with no explicit timzone offset) + /// Interprets a naive_datetime (with no explicit timezone offset) /// using the local timezone and returns the timestamp in UTC (0 /// offset) fn naive_datetime_to_timestamp(naive_datetime: &NaiveDateTime) -> i64 { @@ -224,7 +224,7 @@ mod tests { fn string_to_timestamp_no_timezone() -> Result<()> { // This test is designed to succeed in regardless of the local // timezone the test machine is running. Thus it is still - // somewhat suceptable to bugs in the use of chrono + // somewhat susceptible to bugs in the use of chrono let naive_datetime = NaiveDateTime::new( NaiveDate::from_ymd(2020, 9, 8), NaiveTime::from_hms_nano(13, 42, 29, 190855000), diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index 1f0cb1a39cd2..068b9dedf59b 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -15,24 +15,28 @@ // specific language governing permissions and limitations // under the License. -//! Defines basic comparison kernels for [`PrimitiveArray`]s. +//! Comparison kernels for `Array`s. //! //! These kernels can leverage SIMD if available on your system. Currently no runtime //! detection is provided, you should enable the specific SIMD intrinsics using //! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation //! [here](https://doc.rust-lang.org/stable/core/arch/) for more information. +//! use crate::array::*; use crate::buffer::{bitwise_bin_op_helper, buffer_unary_not, Buffer, MutableBuffer}; use crate::compute::binary_boolean_kernel; use crate::compute::util::combine_option_bitmap; use crate::datatypes::{ - ArrowNumericType, DataType, Float32Type, Float64Type, Int16Type, Int32Type, - Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, + ArrowNativeType, ArrowNumericType, DataType, Date32Type, Date64Type, Float32Type, + Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType, + IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, TimeUnit, + TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, + TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; use crate::error::{ArrowError, Result}; use crate::util::bit_util; -use regex::Regex; +use regex::{escape, Regex}; use std::any::type_name; use std::collections::HashMap; @@ -48,7 +52,7 @@ macro_rules! compare_op { } let null_bit_buffer = - combine_option_bitmap($left.data_ref(), $right.data_ref(), $left.len())?; + combine_option_bitmap(&[$left.data_ref(), $right.data_ref()], $left.len())?; // Safety: // `i < $left.len()` and $left.len() == $right.len() @@ -82,7 +86,7 @@ macro_rules! compare_op_primitive { } let null_bit_buffer = - combine_option_bitmap($left.data_ref(), $right.data_ref(), $left.len())?; + combine_option_bitmap(&[$left.data_ref(), $right.data_ref()], $left.len())?; let mut values = MutableBuffer::from_len_zeroed(($left.len() + 7) / 8); let lhs_chunks_iter = $left.values().chunks_exact(8); @@ -228,28 +232,23 @@ where compare_op_scalar_primitive!(left, right, op) } -/// Perform SQL `left LIKE right` operation on [`StringArray`] / [`LargeStringArray`]. 
-/// -/// There are two wildcards supported with the LIKE operator: -/// -/// 1. `%` - The percent sign represents zero, one, or multiple characters -/// 2. `_` - The underscore represents a single character -/// -/// For example: -/// ``` -/// use arrow::array::{StringArray, BooleanArray}; -/// use arrow::compute::like_utf8; -/// -/// let strings = StringArray::from(vec!["Arrow", "Arrow", "Arrow", "Ar"]); -/// let patterns = StringArray::from(vec!["A%", "B%", "A.", "A."]); +fn is_like_pattern(c: char) -> bool { + c == '%' || c == '_' +} + +/// Evaluate regex `op(left)` matching `right` on [`StringArray`] / [`LargeStringArray`] /// -/// let result = like_utf8(&strings, &patterns).unwrap(); -/// assert_eq!(result, BooleanArray::from(vec![true, false, false, true])); -/// ``` -pub fn like_utf8( +/// If `negate_regex` is true, the regex expression will be negated. (for example, with `not like`) +fn regex_like( left: &GenericStringArray, right: &GenericStringArray, -) -> Result { + negate_regex: bool, + op: F, +) -> Result +where + OffsetSize: OffsetSizeTrait, + F: Fn(&str) -> Result, +{ let mut map = HashMap::new(); if left.len() != right.len() { return Err(ArrowError::ComputeError( @@ -259,7 +258,7 @@ pub fn like_utf8( } let null_bit_buffer = - combine_option_bitmap(left.data_ref(), right.data_ref(), left.len())?; + combine_option_bitmap(&[left.data_ref(), right.data_ref()], left.len())?; let mut result = BooleanBufferBuilder::new(left.len()); for i in 0..left.len() { @@ -268,18 +267,17 @@ pub fn like_utf8( let re = if let Some(ref regex) = map.get(pat) { regex } else { - let re_pattern = pat.replace("%", ".*").replace("_", "."); - let re = Regex::new(&format!("^{}$", re_pattern)).map_err(|e| { - ArrowError::ComputeError(format!( - "Unable to build regex from LIKE pattern: {}", - e - )) - })?; + let re_pattern = escape(pat).replace('%', ".*").replace('_', "."); + let re = op(&re_pattern)?; map.insert(pat, re); map.get(pat).unwrap() }; - result.append(re.is_match(haystack)); + result.append(if negate_regex { + !re.is_match(haystack) + } else { + re.is_match(haystack) + }); } let data = unsafe { @@ -296,15 +294,43 @@ pub fn like_utf8( Ok(BooleanArray::from(data)) } -fn is_like_pattern(c: char) -> bool { - c == '%' || c == '_' +/// Perform SQL `left LIKE right` operation on [`StringArray`] / [`LargeStringArray`]. +/// +/// There are two wildcards supported with the LIKE operator: +/// +/// 1. `%` - The percent sign represents zero, one, or multiple characters +/// 2. `_` - The underscore represents a single character +/// +/// For example: +/// ``` +/// use arrow::array::{StringArray, BooleanArray}; +/// use arrow::compute::like_utf8; +/// +/// let strings = StringArray::from(vec!["Arrow", "Arrow", "Arrow", "Ar"]); +/// let patterns = StringArray::from(vec!["A%", "B%", "A.", "A_"]); +/// +/// let result = like_utf8(&strings, &patterns).unwrap(); +/// assert_eq!(result, BooleanArray::from(vec![true, false, false, true])); +/// ``` +pub fn like_utf8( + left: &GenericStringArray, + right: &GenericStringArray, +) -> Result { + regex_like(left, right, false, |re_pattern| { + Regex::new(&format!("^{}$", re_pattern)).map_err(|e| { + ArrowError::ComputeError(format!( + "Unable to build regex from LIKE pattern: {}", + e + )) + }) + }) } /// Perform SQL `left LIKE right` operation on [`StringArray`] / /// [`LargeStringArray`] and a scalar. /// /// See the documentation on [`like_utf8`] for more details. 
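A note on the key behavioral change in `regex_like` above: the LIKE pattern is now passed through `regex::escape` before `%` and `_` are rewritten, so regex metacharacters in a pattern match literally instead of being interpreted. A minimal sketch of that translation (illustrative only, not part of the patch; `like_pattern_to_regex` is a hypothetical helper name):

```rust
use regex::escape;

// Sketch of the LIKE -> regex translation used by `regex_like` above:
// escape regex metacharacters first, then translate the SQL wildcards.
fn like_pattern_to_regex(pattern: &str) -> String {
    format!("^{}$", escape(pattern).replace('%', ".*").replace('_', "."))
}

fn main() {
    // "A." previously became "^A.$" and wrongly matched "Ar"; with escaping it
    // becomes "^A\.$" and only matches the literal string "A.".
    assert_eq!(like_pattern_to_regex("A."), r"^A\.$");
    assert_eq!(like_pattern_to_regex("A%"), "^A.*$");
    assert_eq!(like_pattern_to_regex("A_"), "^A.$");
}
```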
-pub fn like_utf8_scalar( +pub fn like_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { @@ -338,7 +364,7 @@ pub fn like_utf8_scalar( } } } else { - let re_pattern = right.replace("%", ".*").replace("_", "."); + let re_pattern = escape(right).replace('%', ".*").replace('_', "."); let re = Regex::new(&format!("^{}$", re_pattern)).map_err(|e| { ArrowError::ComputeError(format!( "Unable to build regex from LIKE pattern: {}", @@ -372,40 +398,140 @@ pub fn like_utf8_scalar( /// [`LargeStringArray`]. /// /// See the documentation on [`like_utf8`] for more details. -pub fn nlike_utf8( +pub fn nlike_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { - let mut map = HashMap::new(); - if left.len() != right.len() { - return Err(ArrowError::ComputeError( - "Cannot perform comparison operation on arrays of different length" - .to_string(), - )); + regex_like(left, right, true, |re_pattern| { + Regex::new(&format!("^{}$", re_pattern)).map_err(|e| { + ArrowError::ComputeError(format!( + "Unable to build regex from LIKE pattern: {}", + e + )) + }) + }) +} + +/// Perform SQL `left NOT LIKE right` operation on [`StringArray`] / +/// [`LargeStringArray`] and a scalar. +/// +/// See the documentation on [`like_utf8`] for more details. +pub fn nlike_utf8_scalar( + left: &GenericStringArray, + right: &str, +) -> Result { + let null_bit_buffer = left.data().null_buffer().cloned(); + let mut result = BooleanBufferBuilder::new(left.len()); + + if !right.contains(is_like_pattern) { + // fast path, can use equals + for i in 0..left.len() { + result.append(left.value(i) != right); + } + } else if right.ends_with('%') && !right[..right.len() - 1].contains(is_like_pattern) + { + // fast path, can use ends_with + for i in 0..left.len() { + result.append(!left.value(i).starts_with(&right[..right.len() - 1])); + } + } else if right.starts_with('%') && !right[1..].contains(is_like_pattern) { + // fast path, can use starts_with + for i in 0..left.len() { + result.append(!left.value(i).ends_with(&right[1..])); + } + } else { + let re_pattern = escape(right).replace('%', ".*").replace('_', "."); + let re = Regex::new(&format!("^{}$", re_pattern)).map_err(|e| { + ArrowError::ComputeError(format!( + "Unable to build regex from LIKE pattern: {}", + e + )) + })?; + for i in 0..left.len() { + let haystack = left.value(i); + result.append(!re.is_match(haystack)); + } } - let null_bit_buffer = - combine_option_bitmap(left.data_ref(), right.data_ref(), left.len())?; + let data = unsafe { + ArrayData::new_unchecked( + DataType::Boolean, + left.len(), + None, + null_bit_buffer, + 0, + vec![result.finish()], + vec![], + ) + }; + Ok(BooleanArray::from(data)) +} + +/// Perform SQL `left ILIKE right` operation on [`StringArray`] / +/// [`LargeStringArray`]. +/// +/// See the documentation on [`like_utf8`] for more details. +pub fn ilike_utf8( + left: &GenericStringArray, + right: &GenericStringArray, +) -> Result { + regex_like(left, right, false, |re_pattern| { + Regex::new(&format!("(?i)^{}$", re_pattern)).map_err(|e| { + ArrowError::ComputeError(format!( + "Unable to build regex from ILIKE pattern: {}", + e + )) + }) + }) +} +/// Perform SQL `left ILIKE right` operation on [`StringArray`] / +/// [`LargeStringArray`] and a scalar. +/// +/// See the documentation on [`like_utf8`] for more details. 
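For reference, a usage sketch of the new case-insensitive `ilike_utf8` kernel introduced above, assuming it is re-exported from `arrow::compute` alongside the existing `like_utf8`; the input values and expected output are illustrative:

```rust
use arrow::array::{BooleanArray, StringArray};
use arrow::compute::ilike_utf8;

fn main() -> arrow::error::Result<()> {
    // Case-insensitive LIKE: the pattern is compiled with the regex `(?i)` flag.
    let strings = StringArray::from(vec!["Arrow", "ARROW", "parquet"]);
    let patterns = StringArray::from(vec!["arrow%", "arrow%", "arrow%"]);
    let result = ilike_utf8(&strings, &patterns)?;
    assert_eq!(result, BooleanArray::from(vec![true, true, false]));
    Ok(())
}
```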
+pub fn ilike_utf8_scalar( + left: &GenericStringArray, + right: &str, +) -> Result { + let null_bit_buffer = left.data().null_buffer().cloned(); let mut result = BooleanBufferBuilder::new(left.len()); - for i in 0..left.len() { - let haystack = left.value(i); - let pat = right.value(i); - let re = if let Some(ref regex) = map.get(pat) { - regex - } else { - let re_pattern = pat.replace("%", ".*").replace("_", "."); - let re = Regex::new(&format!("^{}$", re_pattern)).map_err(|e| { - ArrowError::ComputeError(format!( - "Unable to build regex from LIKE pattern: {}", - e - )) - })?; - map.insert(pat, re); - map.get(pat).unwrap() - }; - result.append(!re.is_match(haystack)); + if !right.contains(is_like_pattern) { + // fast path, can use equals + for i in 0..left.len() { + result.append(left.value(i) == right); + } + } else if right.ends_with('%') && !right[..right.len() - 1].contains(is_like_pattern) + { + // fast path, can use ends_with + for i in 0..left.len() { + result.append( + left.value(i) + .to_uppercase() + .starts_with(&right[..right.len() - 1].to_uppercase()), + ); + } + } else if right.starts_with('%') && !right[1..].contains(is_like_pattern) { + // fast path, can use starts_with + for i in 0..left.len() { + result.append( + left.value(i) + .to_uppercase() + .ends_with(&right[1..].to_uppercase()), + ); + } + } else { + let re_pattern = escape(right).replace('%', ".*").replace('_', "."); + let re = Regex::new(&format!("(?i)^{}$", re_pattern)).map_err(|e| { + ArrowError::ComputeError(format!( + "Unable to build regex from ILIKE pattern: {}", + e + )) + })?; + for i in 0..left.len() { + let haystack = left.value(i); + result.append(re.is_match(haystack)); + } } let data = unsafe { @@ -422,11 +548,29 @@ pub fn nlike_utf8( Ok(BooleanArray::from(data)) } -/// Perform SQL `left NOT LIKE right` operation on [`StringArray`] / +/// Perform SQL `left NOT ILIKE right` operation on [`StringArray`] / +/// [`LargeStringArray`]. +/// +/// See the documentation on [`like_utf8`] for more details. +pub fn nilike_utf8( + left: &GenericStringArray, + right: &GenericStringArray, +) -> Result { + regex_like(left, right, true, |re_pattern| { + Regex::new(&format!("(?i)^{}$", re_pattern)).map_err(|e| { + ArrowError::ComputeError(format!( + "Unable to build regex from ILIKE pattern: {}", + e + )) + }) + }) +} + +/// Perform SQL `left NOT ILIKE right` operation on [`StringArray`] / /// [`LargeStringArray`] and a scalar. /// /// See the documentation on [`like_utf8`] for more details. 
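Similarly, a sketch of the scalar variant: a pattern with a single trailing `%` exercises the uppercasing `starts_with` fast path shown above rather than the regex path (again assuming re-export from `arrow::compute`):

```rust
use arrow::array::{BooleanArray, StringArray};
use arrow::compute::ilike_utf8_scalar;

fn main() -> arrow::error::Result<()> {
    let strings = StringArray::from(vec!["Arrow", "ARROW", "parquet"]);
    // "ar%" contains only a trailing wildcard, so no regex is built: each value
    // is uppercased and compared with `starts_with("AR")`.
    let result = ilike_utf8_scalar(&strings, "ar%")?;
    assert_eq!(result, BooleanArray::from(vec![true, true, false]));
    Ok(())
}
```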
-pub fn nlike_utf8_scalar( +pub fn nilike_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { @@ -442,18 +586,28 @@ pub fn nlike_utf8_scalar( { // fast path, can use ends_with for i in 0..left.len() { - result.append(!left.value(i).starts_with(&right[..right.len() - 1])); + result.append( + !left + .value(i) + .to_uppercase() + .starts_with(&right[..right.len() - 1].to_uppercase()), + ); } } else if right.starts_with('%') && !right[1..].contains(is_like_pattern) { // fast path, can use starts_with for i in 0..left.len() { - result.append(!left.value(i).ends_with(&right[1..])); + result.append( + !left + .value(i) + .to_uppercase() + .ends_with(&right[1..].to_uppercase()), + ); } } else { - let re_pattern = right.replace("%", ".*").replace("_", "."); - let re = Regex::new(&format!("^{}$", re_pattern)).map_err(|e| { + let re_pattern = escape(right).replace('%', ".*").replace('_', "."); + let re = Regex::new(&format!("(?i)^{}$", re_pattern)).map_err(|e| { ArrowError::ComputeError(format!( - "Unable to build regex from LIKE pattern: {}", + "Unable to build regex from ILIKE pattern: {}", e )) })?; @@ -484,7 +638,7 @@ pub fn nlike_utf8_scalar( /// special search modes, such as case insensitive and multi-line mode. /// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags) /// for more information. -pub fn regexp_is_match_utf8( +pub fn regexp_is_match_utf8( array: &GenericStringArray, regex_array: &GenericStringArray, flags_array: Option<&GenericStringArray>, @@ -496,7 +650,7 @@ pub fn regexp_is_match_utf8( )); } let null_bit_buffer = - combine_option_bitmap(array.data_ref(), regex_array.data_ref(), array.len())?; + combine_option_bitmap(&[array.data_ref(), regex_array.data_ref()], array.len())?; let mut patterns: HashMap = HashMap::new(); let mut result = BooleanBufferBuilder::new(array.len()); @@ -568,7 +722,7 @@ pub fn regexp_is_match_utf8( /// [`LargeStringArray`] and a scalar. /// /// See the documentation on [`regexp_is_match_utf8`] for more details. -pub fn regexp_is_match_utf8_scalar( +pub fn regexp_is_match_utf8_scalar( array: &GenericStringArray, regex: &str, flag: Option<&str>, @@ -580,10 +734,8 @@ pub fn regexp_is_match_utf8_scalar( Some(flag) => format!("(?{}){}", flag, regex), None => regex.to_string(), }; - if pattern == *"" { - for _i in 0..array.len() { - result.append(true); - } + if pattern.is_empty() { + result.append_n(array.len(), true); } else { let re = Regex::new(pattern.as_str()).map_err(|e| { ArrowError::ComputeError(format!( @@ -597,6 +749,7 @@ pub fn regexp_is_match_utf8_scalar( } } + let buffer = result.finish(); let data = unsafe { ArrayData::new_unchecked( DataType::Boolean, @@ -604,7 +757,7 @@ pub fn regexp_is_match_utf8_scalar( None, null_bit_buffer, 0, - vec![result.finish()], + vec![buffer], vec![], ) }; @@ -612,7 +765,7 @@ pub fn regexp_is_match_utf8_scalar( } /// Perform `left == right` operation on [`StringArray`] / [`LargeStringArray`]. -pub fn eq_utf8( +pub fn eq_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { @@ -620,7 +773,7 @@ pub fn eq_utf8( } /// Perform `left == right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. 
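The regexp kernels touched above splice an optional flag into the pattern as `(?flag)`, so passing `Some("i")` gives the same case-insensitive behavior that ILIKE uses. A small usage sketch under that assumption (values are illustrative; assumes `regexp_is_match_utf8_scalar` is re-exported from `arrow::compute`):

```rust
use arrow::array::{BooleanArray, StringArray};
use arrow::compute::regexp_is_match_utf8_scalar;

fn main() -> arrow::error::Result<()> {
    let strings = StringArray::from(vec!["Arrow", "ARROW", "parquet"]);
    // The flag is inserted as "(?i)", yielding the pattern "(?i)^arrow".
    let result = regexp_is_match_utf8_scalar(&strings, "^arrow", Some("i"))?;
    assert_eq!(result, BooleanArray::from(vec![true, true, false]));
    Ok(())
}
```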
-pub fn eq_utf8_scalar( +pub fn eq_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { @@ -657,37 +810,37 @@ where } /// Perform `left == right` operation on [`BooleanArray`] -fn eq_bool(left: &BooleanArray, right: &BooleanArray) -> Result { +pub fn eq_bool(left: &BooleanArray, right: &BooleanArray) -> Result { binary_boolean_op(left, right, |a, b| !(a ^ b)) } /// Perform `left != right` operation on [`BooleanArray`] -fn neq_bool(left: &BooleanArray, right: &BooleanArray) -> Result { +pub fn neq_bool(left: &BooleanArray, right: &BooleanArray) -> Result { binary_boolean_op(left, right, |a, b| (a ^ b)) } /// Perform `left < right` operation on [`BooleanArray`] -fn lt_bool(left: &BooleanArray, right: &BooleanArray) -> Result { +pub fn lt_bool(left: &BooleanArray, right: &BooleanArray) -> Result { binary_boolean_op(left, right, |a, b| ((!a) & b)) } /// Perform `left <= right` operation on [`BooleanArray`] -fn lt_eq_bool(left: &BooleanArray, right: &BooleanArray) -> Result { +pub fn lt_eq_bool(left: &BooleanArray, right: &BooleanArray) -> Result { binary_boolean_op(left, right, |a, b| !(a & (!b))) } /// Perform `left > right` operation on [`BooleanArray`] -fn gt_bool(left: &BooleanArray, right: &BooleanArray) -> Result { +pub fn gt_bool(left: &BooleanArray, right: &BooleanArray) -> Result { binary_boolean_op(left, right, |a, b| (a & (!b))) } /// Perform `left >= right` operation on [`BooleanArray`] -fn gt_eq_bool(left: &BooleanArray, right: &BooleanArray) -> Result { +pub fn gt_eq_bool(left: &BooleanArray, right: &BooleanArray) -> Result { binary_boolean_op(left, right, |a, b| !((!a) & b)) } /// Perform `left == right` operation on [`BooleanArray`] and a scalar -fn eq_bool_scalar(left: &BooleanArray, right: bool) -> Result { +pub fn eq_bool_scalar(left: &BooleanArray, right: bool) -> Result { let len = left.len(); let left_offset = left.offset(); @@ -715,13 +868,129 @@ fn eq_bool_scalar(left: &BooleanArray, right: bool) -> Result { Ok(BooleanArray::from(data)) } +/// Perform `left < right` operation on [`BooleanArray`] and a scalar +pub fn lt_bool_scalar(left: &BooleanArray, right: bool) -> Result { + compare_op_scalar!(left, right, |a: bool, b: bool| !a & b) +} + +/// Perform `left <= right` operation on [`BooleanArray`] and a scalar +pub fn lt_eq_bool_scalar(left: &BooleanArray, right: bool) -> Result { + compare_op_scalar!(left, right, |a, b| a <= b) +} + +/// Perform `left > right` operation on [`BooleanArray`] and a scalar +pub fn gt_bool_scalar(left: &BooleanArray, right: bool) -> Result { + compare_op_scalar!(left, right, |a: bool, b: bool| a & !b) +} + +/// Perform `left >= right` operation on [`BooleanArray`] and a scalar +pub fn gt_eq_bool_scalar(left: &BooleanArray, right: bool) -> Result { + compare_op_scalar!(left, right, |a, b| a >= b) +} + /// Perform `left != right` operation on [`BooleanArray`] and a scalar -fn neq_bool_scalar(left: &BooleanArray, right: bool) -> Result { +pub fn neq_bool_scalar(left: &BooleanArray, right: bool) -> Result { eq_bool_scalar(left, !right) } +/// Perform `left == right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. 
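The newly public boolean scalar kernels are thin wrappers over bitwise identities (`a < b` is `!a & b`, `a > b` is `a & !b`, and so on). A quick sketch of what that means in practice, assuming the functions are re-exported from `arrow::compute`:

```rust
use arrow::array::BooleanArray;
use arrow::compute::{gt_bool_scalar, lt_eq_bool_scalar};

fn main() -> arrow::error::Result<()> {
    let a = BooleanArray::from(vec![Some(true), Some(false), None]);
    // `a > false` is `a & !false`, i.e. the values themselves; nulls stay null.
    assert_eq!(
        gt_bool_scalar(&a, false)?,
        BooleanArray::from(vec![Some(true), Some(false), None])
    );
    // `a <= false` is the logical negation of the above.
    assert_eq!(
        lt_eq_bool_scalar(&a, false)?,
        BooleanArray::from(vec![Some(false), Some(true), None])
    );
    Ok(())
}
```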
+pub fn eq_binary( + left: &GenericBinaryArray, + right: &GenericBinaryArray, +) -> Result { + compare_op!(left, right, |a, b| a == b) +} + +/// Perform `left == right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar +pub fn eq_binary_scalar( + left: &GenericBinaryArray, + right: &[u8], +) -> Result { + compare_op_scalar!(left, right, |a, b| a == b) +} + +/// Perform `left != right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. +pub fn neq_binary( + left: &GenericBinaryArray, + right: &GenericBinaryArray, +) -> Result { + compare_op!(left, right, |a, b| a != b) +} + +/// Perform `left != right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar. +pub fn neq_binary_scalar( + left: &GenericBinaryArray, + right: &[u8], +) -> Result { + compare_op_scalar!(left, right, |a, b| a != b) +} + +/// Perform `left < right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. +pub fn lt_binary( + left: &GenericBinaryArray, + right: &GenericBinaryArray, +) -> Result { + compare_op!(left, right, |a, b| a < b) +} + +/// Perform `left < right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar. +pub fn lt_binary_scalar( + left: &GenericBinaryArray, + right: &[u8], +) -> Result { + compare_op_scalar!(left, right, |a, b| a < b) +} + +/// Perform `left <= right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. +pub fn lt_eq_binary( + left: &GenericBinaryArray, + right: &GenericBinaryArray, +) -> Result { + compare_op!(left, right, |a, b| a <= b) +} + +/// Perform `left <= right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar. +pub fn lt_eq_binary_scalar( + left: &GenericBinaryArray, + right: &[u8], +) -> Result { + compare_op_scalar!(left, right, |a, b| a <= b) +} + +/// Perform `left > right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. +pub fn gt_binary( + left: &GenericBinaryArray, + right: &GenericBinaryArray, +) -> Result { + compare_op!(left, right, |a, b| a > b) +} + +/// Perform `left > right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar. +pub fn gt_binary_scalar( + left: &GenericBinaryArray, + right: &[u8], +) -> Result { + compare_op_scalar!(left, right, |a, b| a > b) +} + +/// Perform `left >= right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. +pub fn gt_eq_binary( + left: &GenericBinaryArray, + right: &GenericBinaryArray, +) -> Result { + compare_op!(left, right, |a, b| a >= b) +} + +/// Perform `left >= right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar. +pub fn gt_eq_binary_scalar( + left: &GenericBinaryArray, + right: &[u8], +) -> Result { + compare_op_scalar!(left, right, |a, b| a >= b) +} + /// Perform `left != right` operation on [`StringArray`] / [`LargeStringArray`]. -pub fn neq_utf8( +pub fn neq_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { @@ -729,7 +998,7 @@ pub fn neq_utf8( } /// Perform `left != right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. -pub fn neq_utf8_scalar( +pub fn neq_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { @@ -737,7 +1006,7 @@ pub fn neq_utf8_scalar( } /// Perform `left < right` operation on [`StringArray`] / [`LargeStringArray`]. -pub fn lt_utf8( +pub fn lt_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { @@ -745,7 +1014,7 @@ pub fn lt_utf8( } /// Perform `left < right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. 
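The binary kernels added above mirror the utf8 ones but compare raw byte slices. A usage sketch (array contents are illustrative; assumes re-export from `arrow::compute`):

```rust
use arrow::array::{BinaryArray, BooleanArray};
use arrow::compute::{eq_binary_scalar, gt_binary};

fn main() -> arrow::error::Result<()> {
    let left = BinaryArray::from(vec!["one".as_bytes(), "two".as_bytes()]);
    let right = BinaryArray::from(vec!["two".as_bytes(), "two".as_bytes()]);
    // Byte-wise comparisons, exactly like the utf8 kernels but over &[u8].
    assert_eq!(gt_binary(&left, &right)?, BooleanArray::from(vec![false, false]));
    assert_eq!(eq_binary_scalar(&left, b"two")?, BooleanArray::from(vec![false, true]));
    Ok(())
}
```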
-pub fn lt_utf8_scalar( +pub fn lt_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { @@ -753,7 +1022,7 @@ pub fn lt_utf8_scalar( } /// Perform `left <= right` operation on [`StringArray`] / [`LargeStringArray`]. -pub fn lt_eq_utf8( +pub fn lt_eq_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { @@ -761,7 +1030,7 @@ pub fn lt_eq_utf8( } /// Perform `left <= right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. -pub fn lt_eq_utf8_scalar( +pub fn lt_eq_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { @@ -769,7 +1038,7 @@ pub fn lt_eq_utf8_scalar( } /// Perform `left > right` operation on [`StringArray`] / [`LargeStringArray`]. -pub fn gt_utf8( +pub fn gt_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { @@ -777,7 +1046,7 @@ pub fn gt_utf8( } /// Perform `left > right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. -pub fn gt_utf8_scalar( +pub fn gt_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { @@ -785,7 +1054,7 @@ pub fn gt_utf8_scalar( } /// Perform `left >= right` operation on [`StringArray`] / [`LargeStringArray`]. -pub fn gt_eq_utf8( +pub fn gt_eq_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { @@ -793,28 +1062,694 @@ pub fn gt_eq_utf8( } /// Perform `left >= right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. -pub fn gt_eq_utf8_scalar( +pub fn gt_eq_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { compare_op_scalar!(left, right, |a, b| a >= b) } -/// Helper function to perform boolean lambda function on values from two arrays using -/// SIMD. -#[cfg(feature = "simd")] -fn simd_compare_op( - left: &PrimitiveArray, - right: &PrimitiveArray, - simd_op: SIMD_OP, - scalar_op: SCALAR_OP, -) -> Result -where - T: ArrowNumericType, - SIMD_OP: Fn(T::Simd, T::Simd) -> T::SimdMask, - SCALAR_OP: Fn(T::Native, T::Native) -> bool, -{ - use std::borrow::BorrowMut; +/// Calls $RIGHT.$TY() (e.g. `right.to_i128()`) with a nice error message. +/// Type of expression is `Result<.., ArrowError>` +macro_rules! try_to_type { + ($RIGHT: expr, $TY: ident) => {{ + $RIGHT.$TY().ok_or_else(|| { + ArrowError::ComputeError(format!( + "Could not convert {} with {}", + stringify!($RIGHT), + stringify!($TY) + )) + }) + }}; +} + +macro_rules! 
dyn_compare_scalar { + // Applies `LEFT OP RIGHT` when `LEFT` is a `PrimitiveArray` + ($LEFT: expr, $RIGHT: expr, $OP: ident) => {{ + match $LEFT.data_type() { + DataType::Int8 => { + let right = try_to_type!($RIGHT, to_i8)?; + let left = as_primitive_array::($LEFT); + $OP::(left, right) + } + DataType::Int16 => { + let right = try_to_type!($RIGHT, to_i16)?; + let left = as_primitive_array::($LEFT); + $OP::(left, right) + } + DataType::Int32 => { + let right = try_to_type!($RIGHT, to_i32)?; + let left = as_primitive_array::($LEFT); + $OP::(left, right) + } + DataType::Int64 => { + let right = try_to_type!($RIGHT, to_i64)?; + let left = as_primitive_array::($LEFT); + $OP::(left, right) + } + DataType::UInt8 => { + let right = try_to_type!($RIGHT, to_u8)?; + let left = as_primitive_array::($LEFT); + $OP::(left, right) + } + DataType::UInt16 => { + let right = try_to_type!($RIGHT, to_u16)?; + let left = as_primitive_array::($LEFT); + $OP::(left, right) + } + DataType::UInt32 => { + let right = try_to_type!($RIGHT, to_u32)?; + let left = as_primitive_array::($LEFT); + $OP::(left, right) + } + DataType::UInt64 => { + let right = try_to_type!($RIGHT, to_u64)?; + let left = as_primitive_array::($LEFT); + $OP::(left, right) + } + DataType::Float32 => { + let right = try_to_type!($RIGHT, to_f32)?; + let left = as_primitive_array::($LEFT); + $OP::(left, right) + } + DataType::Float64 => { + let right = try_to_type!($RIGHT, to_f64)?; + let left = as_primitive_array::($LEFT); + $OP::(left, right) + } + _ => Err(ArrowError::ComputeError(format!( + "Unsupported data type {:?} for comparison {} with {:?}", + $LEFT.data_type(), + stringify!($OP), + $RIGHT + ))), + } + }}; + // Applies `LEFT OP RIGHT` when `LEFT` is a `DictionaryArray` with keys of type `KT` + ($LEFT: expr, $RIGHT: expr, $KT: ident, $OP: ident) => {{ + match $KT.as_ref() { + DataType::UInt8 => { + let left = as_dictionary_array::($LEFT); + unpack_dict_comparison( + left, + dyn_compare_scalar!(left.values(), $RIGHT, $OP)?, + ) + } + DataType::UInt16 => { + let left = as_dictionary_array::($LEFT); + unpack_dict_comparison( + left, + dyn_compare_scalar!(left.values(), $RIGHT, $OP)?, + ) + } + DataType::UInt32 => { + let left = as_dictionary_array::($LEFT); + unpack_dict_comparison( + left, + dyn_compare_scalar!(left.values(), $RIGHT, $OP)?, + ) + } + DataType::UInt64 => { + let left = as_dictionary_array::($LEFT); + unpack_dict_comparison( + left, + dyn_compare_scalar!(left.values(), $RIGHT, $OP)?, + ) + } + DataType::Int8 => { + let left = as_dictionary_array::($LEFT); + unpack_dict_comparison( + left, + dyn_compare_scalar!(left.values(), $RIGHT, $OP)?, + ) + } + DataType::Int16 => { + let left = as_dictionary_array::($LEFT); + unpack_dict_comparison( + left, + dyn_compare_scalar!(left.values(), $RIGHT, $OP)?, + ) + } + DataType::Int32 => { + let left = as_dictionary_array::($LEFT); + unpack_dict_comparison( + left, + dyn_compare_scalar!(left.values(), $RIGHT, $OP)?, + ) + } + DataType::Int64 => { + let left = as_dictionary_array::($LEFT); + unpack_dict_comparison( + left, + dyn_compare_scalar!(left.values(), $RIGHT, $OP)?, + ) + } + _ => Err(ArrowError::ComputeError(format!( + "Unsupported dictionary key type {:?}", + $KT.as_ref() + ))), + } + }}; +} + +macro_rules! dyn_compare_utf8_scalar { + ($LEFT: expr, $RIGHT: expr, $KT: ident, $OP: ident) => {{ + match $KT.as_ref() { + DataType::UInt8 => { + let left = as_dictionary_array::($LEFT); + let values = as_string_array(left.values()); + unpack_dict_comparison(left, $OP(values, $RIGHT)?) 
+ } + DataType::UInt16 => { + let left = as_dictionary_array::($LEFT); + let values = as_string_array(left.values()); + unpack_dict_comparison(left, $OP(values, $RIGHT)?) + } + DataType::UInt32 => { + let left = as_dictionary_array::($LEFT); + let values = as_string_array(left.values()); + unpack_dict_comparison(left, $OP(values, $RIGHT)?) + } + DataType::UInt64 => { + let left = as_dictionary_array::($LEFT); + let values = as_string_array(left.values()); + unpack_dict_comparison(left, $OP(values, $RIGHT)?) + } + DataType::Int8 => { + let left = as_dictionary_array::($LEFT); + let values = as_string_array(left.values()); + unpack_dict_comparison(left, $OP(values, $RIGHT)?) + } + DataType::Int16 => { + let left = as_dictionary_array::($LEFT); + let values = as_string_array(left.values()); + unpack_dict_comparison(left, $OP(values, $RIGHT)?) + } + DataType::Int32 => { + let left = as_dictionary_array::($LEFT); + let values = as_string_array(left.values()); + unpack_dict_comparison(left, $OP(values, $RIGHT)?) + } + DataType::Int64 => { + let left = as_dictionary_array::($LEFT); + let values = as_string_array(left.values()); + unpack_dict_comparison(left, $OP(values, $RIGHT)?) + } + _ => Err(ArrowError::ComputeError(String::from("Unknown key type"))), + } + }}; +} + +/// Perform `left == right` operation on an array and a numeric scalar +/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values +pub fn eq_dyn_scalar(left: &dyn Array, right: T) -> Result +where + T: num::ToPrimitive + std::fmt::Debug, +{ + match left.data_type() { + DataType::Dictionary(key_type, _value_type) => { + dyn_compare_scalar!(left, right, key_type, eq_scalar) + } + _ => dyn_compare_scalar!(left, right, eq_scalar), + } +} + +/// Perform `left < right` operation on an array and a numeric scalar +/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values +pub fn lt_dyn_scalar(left: &dyn Array, right: T) -> Result +where + T: num::ToPrimitive + std::fmt::Debug, +{ + match left.data_type() { + DataType::Dictionary(key_type, _value_type) => { + dyn_compare_scalar!(left, right, key_type, lt_scalar) + } + _ => dyn_compare_scalar!(left, right, lt_scalar), + } +} + +/// Perform `left <= right` operation on an array and a numeric scalar +/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values +pub fn lt_eq_dyn_scalar(left: &dyn Array, right: T) -> Result +where + T: num::ToPrimitive + std::fmt::Debug, +{ + match left.data_type() { + DataType::Dictionary(key_type, _value_type) => { + dyn_compare_scalar!(left, right, key_type, lt_eq_scalar) + } + _ => dyn_compare_scalar!(left, right, lt_eq_scalar), + } +} + +/// Perform `left > right` operation on an array and a numeric scalar +/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values +pub fn gt_dyn_scalar(left: &dyn Array, right: T) -> Result +where + T: num::ToPrimitive + std::fmt::Debug, +{ + match left.data_type() { + DataType::Dictionary(key_type, _value_type) => { + dyn_compare_scalar!(left, right, key_type, gt_scalar) + } + _ => dyn_compare_scalar!(left, right, gt_scalar), + } +} + +/// Perform `left >= right` operation on an array and a numeric scalar +/// value. 
Supports PrimitiveArrays, and DictionaryArrays that have primitive values +pub fn gt_eq_dyn_scalar(left: &dyn Array, right: T) -> Result +where + T: num::ToPrimitive + std::fmt::Debug, +{ + match left.data_type() { + DataType::Dictionary(key_type, _value_type) => { + dyn_compare_scalar!(left, right, key_type, gt_eq_scalar) + } + _ => dyn_compare_scalar!(left, right, gt_eq_scalar), + } +} + +/// Perform `left != right` operation on an array and a numeric scalar +/// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values +pub fn neq_dyn_scalar(left: &dyn Array, right: T) -> Result +where + T: num::ToPrimitive + std::fmt::Debug, +{ + match left.data_type() { + DataType::Dictionary(key_type, _value_type) => { + dyn_compare_scalar!(left, right, key_type, neq_scalar) + } + _ => dyn_compare_scalar!(left, right, neq_scalar), + } +} + +/// Perform `left == right` operation on an array and a numeric scalar +/// value. Supports BinaryArray and LargeBinaryArray +pub fn eq_dyn_binary_scalar(left: &dyn Array, right: &[u8]) -> Result { + match left.data_type() { + DataType::Binary => { + let left = as_generic_binary_array::(left); + eq_binary_scalar(left, right) + } + DataType::LargeBinary => { + let left = as_generic_binary_array::(left); + eq_binary_scalar(left, right) + } + _ => Err(ArrowError::ComputeError( + "eq_dyn_binary_scalar only supports Binary or LargeBinary arrays".to_string(), + )), + } +} + +/// Perform `left != right` operation on an array and a numeric scalar +/// value. Supports BinaryArray and LargeBinaryArray +pub fn neq_dyn_binary_scalar(left: &dyn Array, right: &[u8]) -> Result { + match left.data_type() { + DataType::Binary => { + let left = as_generic_binary_array::(left); + neq_binary_scalar(left, right) + } + DataType::LargeBinary => { + let left = as_generic_binary_array::(left); + neq_binary_scalar(left, right) + } + _ => Err(ArrowError::ComputeError( + "neq_dyn_binary_scalar only supports Binary or LargeBinary arrays" + .to_string(), + )), + } +} + +/// Perform `left < right` operation on an array and a numeric scalar +/// value. Supports BinaryArray and LargeBinaryArray +pub fn lt_dyn_binary_scalar(left: &dyn Array, right: &[u8]) -> Result { + match left.data_type() { + DataType::Binary => { + let left = as_generic_binary_array::(left); + lt_binary_scalar(left, right) + } + DataType::LargeBinary => { + let left = as_generic_binary_array::(left); + lt_binary_scalar(left, right) + } + _ => Err(ArrowError::ComputeError( + "lt_dyn_binary_scalar only supports Binary or LargeBinary arrays".to_string(), + )), + } +} + +/// Perform `left <= right` operation on an array and a numeric scalar +/// value. Supports BinaryArray and LargeBinaryArray +pub fn lt_eq_dyn_binary_scalar(left: &dyn Array, right: &[u8]) -> Result { + match left.data_type() { + DataType::Binary => { + let left = as_generic_binary_array::(left); + lt_eq_binary_scalar(left, right) + } + DataType::LargeBinary => { + let left = as_generic_binary_array::(left); + lt_eq_binary_scalar(left, right) + } + _ => Err(ArrowError::ComputeError( + "lt_eq_dyn_binary_scalar only supports Binary or LargeBinary arrays" + .to_string(), + )), + } +} + +/// Perform `left > right` operation on an array and a numeric scalar +/// value. 
Supports BinaryArray and LargeBinaryArray +pub fn gt_dyn_binary_scalar(left: &dyn Array, right: &[u8]) -> Result { + match left.data_type() { + DataType::Binary => { + let left = as_generic_binary_array::(left); + gt_binary_scalar(left, right) + } + DataType::LargeBinary => { + let left = as_generic_binary_array::(left); + gt_binary_scalar(left, right) + } + _ => Err(ArrowError::ComputeError( + "gt_dyn_binary_scalar only supports Binary or LargeBinary arrays".to_string(), + )), + } +} + +/// Perform `left >= right` operation on an array and a numeric scalar +/// value. Supports BinaryArray and LargeBinaryArray +pub fn gt_eq_dyn_binary_scalar(left: &dyn Array, right: &[u8]) -> Result { + match left.data_type() { + DataType::Binary => { + let left = as_generic_binary_array::(left); + gt_eq_binary_scalar(left, right) + } + DataType::LargeBinary => { + let left = as_generic_binary_array::(left); + gt_eq_binary_scalar(left, right) + } + _ => Err(ArrowError::ComputeError( + "gt_eq_dyn_binary_scalar only supports Binary or LargeBinary arrays" + .to_string(), + )), + } +} + +/// Perform `left == right` operation on an array and a numeric scalar +/// value. Supports StringArrays, and DictionaryArrays that have string values +pub fn eq_dyn_utf8_scalar(left: &dyn Array, right: &str) -> Result { + let result = match left.data_type() { + DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { + DataType::Utf8 | DataType::LargeUtf8 => { + dyn_compare_utf8_scalar!(left, right, key_type, eq_utf8_scalar) + } + _ => Err(ArrowError::ComputeError( + "eq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays or DictionaryArray with Utf8 or LargeUtf8 values".to_string(), + )), + }, + DataType::Utf8 => { + let left = as_string_array(left); + eq_utf8_scalar(left, right) + } + DataType::LargeUtf8 => { + let left = as_largestring_array(left); + eq_utf8_scalar(left, right) + } + _ => Err(ArrowError::ComputeError( + "eq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays".to_string(), + )), + }; + result +} + +/// Perform `left < right` operation on an array and a numeric scalar +/// value. Supports StringArrays, and DictionaryArrays that have string values +pub fn lt_dyn_utf8_scalar(left: &dyn Array, right: &str) -> Result { + let result = match left.data_type() { + DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { + DataType::Utf8 | DataType::LargeUtf8 => { + dyn_compare_utf8_scalar!(left, right, key_type, lt_utf8_scalar) + } + _ => Err(ArrowError::ComputeError( + "lt_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays or DictionaryArray with Utf8 or LargeUtf8 values".to_string(), + )), + }, + DataType::Utf8 => { + let left = as_string_array(left); + lt_utf8_scalar(left, right) + } + DataType::LargeUtf8 => { + let left = as_largestring_array(left); + lt_utf8_scalar(left, right) + } + _ => Err(ArrowError::ComputeError( + "lt_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays".to_string(), + )), + }; + result +} + +/// Perform `left >= right` operation on an array and a numeric scalar +/// value. 
Supports StringArrays, and DictionaryArrays that have string values +pub fn gt_eq_dyn_utf8_scalar(left: &dyn Array, right: &str) -> Result { + let result = match left.data_type() { + DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { + DataType::Utf8 | DataType::LargeUtf8 => { + dyn_compare_utf8_scalar!(left, right, key_type, gt_eq_utf8_scalar) + } + _ => Err(ArrowError::ComputeError( + "gt_eq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays or DictionaryArray with Utf8 or LargeUtf8 values".to_string(), + )), + }, + DataType::Utf8 => { + let left = as_string_array(left); + gt_eq_utf8_scalar(left, right) + } + DataType::LargeUtf8 => { + let left = as_largestring_array(left); + gt_eq_utf8_scalar(left, right) + } + _ => Err(ArrowError::ComputeError( + "gt_eq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays".to_string(), + )), + }; + result +} + +/// Perform `left <= right` operation on an array and a numeric scalar +/// value. Supports StringArrays, and DictionaryArrays that have string values +pub fn lt_eq_dyn_utf8_scalar(left: &dyn Array, right: &str) -> Result { + let result = match left.data_type() { + DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { + DataType::Utf8 | DataType::LargeUtf8 => { + dyn_compare_utf8_scalar!(left, right, key_type, lt_eq_utf8_scalar) + } + _ => Err(ArrowError::ComputeError( + "lt_eq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays or DictionaryArray with Utf8 or LargeUtf8 values".to_string(), + )), + }, + DataType::Utf8 => { + let left = as_string_array(left); + lt_eq_utf8_scalar(left, right) + } + DataType::LargeUtf8 => { + let left = as_largestring_array(left); + lt_eq_utf8_scalar(left, right) + } + _ => Err(ArrowError::ComputeError( + "lt_eq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays".to_string(), + )), + }; + result +} + +/// Perform `left > right` operation on an array and a numeric scalar +/// value. Supports StringArrays, and DictionaryArrays that have string values +pub fn gt_dyn_utf8_scalar(left: &dyn Array, right: &str) -> Result { + let result = match left.data_type() { + DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { + DataType::Utf8 | DataType::LargeUtf8 => { + dyn_compare_utf8_scalar!(left, right, key_type, gt_utf8_scalar) + } + _ => Err(ArrowError::ComputeError( + "gt_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays or DictionaryArray with Utf8 or LargeUtf8 values".to_string(), + )), + }, + DataType::Utf8 => { + let left = as_string_array(left); + gt_utf8_scalar(left, right) + } + DataType::LargeUtf8 => { + let left = as_largestring_array(left); + gt_utf8_scalar(left, right) + } + _ => Err(ArrowError::ComputeError( + "gt_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays".to_string(), + )), + }; + result +} + +/// Perform `left != right` operation on an array and a numeric scalar +/// value. 
Supports StringArrays, and DictionaryArrays that have string values +pub fn neq_dyn_utf8_scalar(left: &dyn Array, right: &str) -> Result { + let result = match left.data_type() { + DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { + DataType::Utf8 | DataType::LargeUtf8 => { + dyn_compare_utf8_scalar!(left, right, key_type, neq_utf8_scalar) + } + _ => Err(ArrowError::ComputeError( + "neq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays or DictionaryArray with Utf8 or LargeUtf8 values".to_string(), + )), + }, + DataType::Utf8 => { + let left = as_string_array(left); + neq_utf8_scalar(left, right) + } + DataType::LargeUtf8 => { + let left = as_largestring_array(left); + neq_utf8_scalar(left, right) + } + _ => Err(ArrowError::ComputeError( + "neq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays".to_string(), + )), + }; + result +} + +/// Perform `left == right` operation on an array and a numeric scalar +/// value. +pub fn eq_dyn_bool_scalar(left: &dyn Array, right: bool) -> Result { + let result = match left.data_type() { + DataType::Boolean => { + let left = as_boolean_array(left); + eq_bool_scalar(left, right) + } + _ => Err(ArrowError::ComputeError( + "eq_dyn_bool_scalar only supports BooleanArray".to_string(), + )), + }; + result +} + +/// Perform `left < right` operation on an array and a numeric scalar +/// value. Supports BooleanArrays. +pub fn lt_dyn_bool_scalar(left: &dyn Array, right: bool) -> Result { + let result = match left.data_type() { + DataType::Boolean => { + let left = as_boolean_array(left); + lt_bool_scalar(left, right) + } + _ => Err(ArrowError::ComputeError( + "lt_dyn_bool_scalar only supports BooleanArray".to_string(), + )), + }; + result +} + +/// Perform `left > right` operation on an array and a numeric scalar +/// value. Supports BooleanArrays. +pub fn gt_dyn_bool_scalar(left: &dyn Array, right: bool) -> Result { + let result = match left.data_type() { + DataType::Boolean => { + let left = as_boolean_array(left); + gt_bool_scalar(left, right) + } + _ => Err(ArrowError::ComputeError( + "gt_dyn_bool_scalar only supports BooleanArray".to_string(), + )), + }; + result +} + +/// Perform `left <= right` operation on an array and a numeric scalar +/// value. Supports BooleanArrays. +pub fn lt_eq_dyn_bool_scalar(left: &dyn Array, right: bool) -> Result { + let result = match left.data_type() { + DataType::Boolean => { + let left = as_boolean_array(left); + lt_eq_bool_scalar(left, right) + } + _ => Err(ArrowError::ComputeError( + "lt_eq_dyn_bool_scalar only supports BooleanArray".to_string(), + )), + }; + result +} + +/// Perform `left >= right` operation on an array and a numeric scalar +/// value. Supports BooleanArrays. +pub fn gt_eq_dyn_bool_scalar(left: &dyn Array, right: bool) -> Result { + let result = match left.data_type() { + DataType::Boolean => { + let left = as_boolean_array(left); + gt_eq_bool_scalar(left, right) + } + _ => Err(ArrowError::ComputeError( + "gt_eq_dyn_bool_scalar only supports BooleanArray".to_string(), + )), + }; + result +} + +/// Perform `left != right` operation on an array and a numeric scalar +/// value. Supports BooleanArrays. 
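The `*_dyn_*_scalar` functions dispatch on the runtime `DataType` of the left array, so a single call site can handle both flat arrays and dictionary-encoded arrays; in the dictionary case the comparison runs once over the dictionary values and is expanded through the keys by `unpack_dict_comparison`. A usage sketch for the utf8 variant (assumes re-export from `arrow::compute`; values are illustrative):

```rust
use arrow::array::{BooleanArray, DictionaryArray, StringArray};
use arrow::compute::eq_dyn_utf8_scalar;
use arrow::datatypes::Int8Type;

fn main() -> arrow::error::Result<()> {
    // Works directly on string arrays...
    let strings = StringArray::from(vec!["foo", "bar", "foo"]);
    assert_eq!(
        eq_dyn_utf8_scalar(&strings, "foo")?,
        BooleanArray::from(vec![true, false, true])
    );

    // ...and on dictionary-encoded strings: the comparison is evaluated once per
    // distinct dictionary value and then unpacked through the keys.
    let dict: DictionaryArray<Int8Type> = vec!["foo", "bar", "foo"].into_iter().collect();
    assert_eq!(
        eq_dyn_utf8_scalar(&dict, "foo")?,
        BooleanArray::from(vec![true, false, true])
    );
    Ok(())
}
```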
+pub fn neq_dyn_bool_scalar(left: &dyn Array, right: bool) -> Result { + let result = match left.data_type() { + DataType::Boolean => { + let left = as_boolean_array(left); + neq_bool_scalar(left, right) + } + _ => Err(ArrowError::ComputeError( + "neq_dyn_bool_scalar only supports BooleanArray".to_string(), + )), + }; + result +} + +/// unpacks the results of comparing left.values (as a boolean) +/// +/// TODO add example +/// +fn unpack_dict_comparison( + dict: &DictionaryArray, + dict_comparison: BooleanArray, +) -> Result +where + K: ArrowNumericType, +{ + assert_eq!(dict_comparison.len(), dict.values().len()); + + let result: BooleanArray = dict + .keys() + .iter() + .map(|key| { + key.map(|key| unsafe { + // safety lengths were verified above + let key = key.to_usize().expect("Dictionary index not usize"); + dict_comparison.value_unchecked(key) + }) + }) + .collect(); + + Ok(result) +} + +/// Helper function to perform boolean lambda function on values from two arrays using +/// SIMD. +#[cfg(feature = "simd")] +fn simd_compare_op( + left: &PrimitiveArray, + right: &PrimitiveArray, + simd_op: SI, + scalar_op: SC, +) -> Result +where + T: ArrowNumericType, + SI: Fn(T::Simd, T::Simd) -> T::SimdMask, + SC: Fn(T::Native, T::Native) -> bool, +{ + use std::borrow::BorrowMut; let len = left.len(); if len != right.len() { @@ -824,59 +1759,68 @@ where )); } - let null_bit_buffer = combine_option_bitmap(left.data_ref(), right.data_ref(), len)?; + let null_bit_buffer = + combine_option_bitmap(&[left.data_ref(), right.data_ref()], len)?; + // we process the data in chunks so that each iteration results in one u64 of comparison result bits + const CHUNK_SIZE: usize = 64; let lanes = T::lanes(); - let buffer_size = bit_util::ceil(len, 8); - let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false); // this is currently the case for all our datatypes and allows us to always append full bytes assert!( - lanes % 8 == 0, - "Number of vector lanes must be multiple of 8" + lanes <= CHUNK_SIZE, + "Number of vector lanes must be at most 64" ); - let mut left_chunks = left.values().chunks_exact(lanes); - let mut right_chunks = right.values().chunks_exact(lanes); + let buffer_size = bit_util::ceil(len, 8); + let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false); + + let mut left_chunks = left.values().chunks_exact(CHUNK_SIZE); + let mut right_chunks = right.values().chunks_exact(CHUNK_SIZE); + + // safety: result is newly created above, always written as a T below + let result_chunks = unsafe { result.typed_data_mut() }; let result_remainder = left_chunks .borrow_mut() .zip(right_chunks.borrow_mut()) - .fold( - result.typed_data_mut(), - |result_slice, (left_slice, right_slice)| { - let simd_left = T::load(left_slice); - let simd_right = T::load(right_slice); + .fold(result_chunks, |result_slice, (left_slice, right_slice)| { + let mut i = 0; + let mut bitmask = 0_u64; + while i < CHUNK_SIZE { + let simd_left = T::load(&left_slice[i..]); + let simd_right = T::load(&right_slice[i..]); let simd_result = simd_op(simd_left, simd_right); - let bitmask = T::mask_to_u64(&simd_result); - let bytes = bitmask.to_le_bytes(); - result_slice[0..lanes / 8].copy_from_slice(&bytes[0..lanes / 8]); + let m = T::mask_to_u64(&simd_result); + bitmask |= m << i; - &mut result_slice[lanes / 8..] - }, - ); + i += lanes; + } + let bytes = bitmask.to_le_bytes(); + result_slice[0..8].copy_from_slice(&bytes); + + &mut result_slice[8..] 
+ }); let left_remainder = left_chunks.remainder(); let right_remainder = right_chunks.remainder(); assert_eq!(left_remainder.len(), right_remainder.len()); - let remainder_bitmask = left_remainder - .iter() - .zip(right_remainder.iter()) - .enumerate() - .fold(0_u64, |mut mask, (i, (scalar_left, scalar_right))| { - let bit = if scalar_op(*scalar_left, *scalar_right) { - 1_u64 - } else { - 0_u64 - }; - mask |= bit << i; - mask - }); - let remainder_mask_as_bytes = - &remainder_bitmask.to_le_bytes()[0..bit_util::ceil(left_remainder.len(), 8)]; - result_remainder.copy_from_slice(remainder_mask_as_bytes); + if !left_remainder.is_empty() { + let remainder_bitmask = left_remainder + .iter() + .zip(right_remainder.iter()) + .enumerate() + .fold(0_u64, |mut mask, (i, (scalar_left, scalar_right))| { + let bit = scalar_op(*scalar_left, *scalar_right) as u64; + mask |= bit << i; + mask + }); + let remainder_mask_as_bytes = + &remainder_bitmask.to_le_bytes()[0..bit_util::ceil(left_remainder.len(), 8)]; + result_remainder.copy_from_slice(remainder_mask_as_bytes); + } let data = unsafe { ArrayData::new_unchecked( @@ -895,65 +1839,75 @@ where /// Helper function to perform boolean lambda function on values from an array and a scalar value using /// SIMD. #[cfg(feature = "simd")] -fn simd_compare_op_scalar( +fn simd_compare_op_scalar( left: &PrimitiveArray, right: T::Native, - simd_op: SIMD_OP, - scalar_op: SCALAR_OP, + simd_op: SI, + scalar_op: SC, ) -> Result where T: ArrowNumericType, - SIMD_OP: Fn(T::Simd, T::Simd) -> T::SimdMask, - SCALAR_OP: Fn(T::Native, T::Native) -> bool, + SI: Fn(T::Simd, T::Simd) -> T::SimdMask, + SC: Fn(T::Native, T::Native) -> bool, { use std::borrow::BorrowMut; let len = left.len(); + // we process the data in chunks so that each iteration results in one u64 of comparison result bits + const CHUNK_SIZE: usize = 64; let lanes = T::lanes(); - let buffer_size = bit_util::ceil(len, 8); - let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false); // this is currently the case for all our datatypes and allows us to always append full bytes assert!( - lanes % 8 == 0, - "Number of vector lanes must be multiple of 8" + lanes <= CHUNK_SIZE, + "Number of vector lanes must be at most 64" ); - let mut left_chunks = left.values().chunks_exact(lanes); - let simd_right = T::init(right); - let result_remainder = left_chunks.borrow_mut().fold( - result.typed_data_mut(), - |result_slice, left_slice| { - let simd_left = T::load(left_slice); - let simd_result = simd_op(simd_left, simd_right); + let buffer_size = bit_util::ceil(len, 8); + let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false); - let bitmask = T::mask_to_u64(&simd_result); - let bytes = bitmask.to_le_bytes(); - result_slice[0..lanes / 8].copy_from_slice(&bytes[0..lanes / 8]); + let mut left_chunks = left.values().chunks_exact(CHUNK_SIZE); + let simd_right = T::init(right); - &mut result_slice[lanes / 8..] 
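The rewritten SIMD kernels process their input in fixed 64-element chunks so that every iteration produces one full `u64` of result bits regardless of the lane width, and the sub-64-element tail is packed separately with only `ceil(len / 8)` bytes written. A plain-Rust sketch of that packing scheme (no SIMD, little-endian byte order as in the kernel):

```rust
/// Pack one full 64-value chunk of `left == right` results into 8 bytes,
/// mirroring the inner loop of the chunked comparison above (without SIMD).
fn pack_chunk(left: &[i64], right: &[i64]) -> [u8; 8] {
    assert_eq!(left.len(), 64);
    assert_eq!(right.len(), 64);
    let mut bitmask = 0_u64;
    for (i, (l, r)) in left.iter().zip(right.iter()).enumerate() {
        bitmask |= ((l == r) as u64) << i; // bit i holds the result for element i
    }
    bitmask.to_le_bytes()
}

/// Pack the remaining (< 64) values, writing only the bytes that are needed.
fn pack_remainder(left: &[i64], right: &[i64], out: &mut [u8]) {
    assert_eq!(left.len(), right.len());
    if left.is_empty() {
        return; // the kernel skips the copy entirely for an empty remainder
    }
    let mut bitmask = 0_u64;
    for (i, (l, r)) in left.iter().zip(right.iter()).enumerate() {
        bitmask |= ((l == r) as u64) << i;
    }
    let n_bytes = (left.len() + 7) / 8; // bit_util::ceil(len, 8)
    out[..n_bytes].copy_from_slice(&bitmask.to_le_bytes()[..n_bytes]);
}
```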
- }, - ); + // safety: result is newly created above, always written as a T below + let result_chunks = unsafe { result.typed_data_mut() }; + let result_remainder = + left_chunks + .borrow_mut() + .fold(result_chunks, |result_slice, left_slice| { + let mut i = 0; + let mut bitmask = 0_u64; + while i < CHUNK_SIZE { + let simd_left = T::load(&left_slice[i..]); + let simd_result = simd_op(simd_left, simd_right); + + let m = T::mask_to_u64(&simd_result); + bitmask |= m << i; + + i += lanes; + } + let bytes = bitmask.to_le_bytes(); + result_slice[0..8].copy_from_slice(&bytes); + + &mut result_slice[8..] + }); let left_remainder = left_chunks.remainder(); - let remainder_bitmask = - left_remainder - .iter() - .enumerate() - .fold(0_u64, |mut mask, (i, scalar_left)| { - let bit = if scalar_op(*scalar_left, right) { - 1_u64 - } else { - 0_u64 - }; + if !left_remainder.is_empty() { + let remainder_bitmask = left_remainder.iter().enumerate().fold( + 0_u64, + |mut mask, (i, scalar_left)| { + let bit = scalar_op(*scalar_left, right) as u64; mask |= bit << i; mask - }); - let remainder_mask_as_bytes = - &remainder_bitmask.to_le_bytes()[0..bit_util::ceil(left_remainder.len(), 8)]; - result_remainder.copy_from_slice(remainder_mask_as_bytes); + }, + ); + let remainder_mask_as_bytes = + &remainder_bitmask.to_le_bytes()[0..bit_util::ceil(left_remainder.len(), 8)]; + result_remainder.copy_from_slice(remainder_mask_as_bytes); + } let null_bit_buffer = left .data_ref() @@ -1011,7 +1965,7 @@ macro_rules! typed_cmp { } macro_rules! typed_compares { - ($LEFT: expr, $RIGHT: expr, $OP_BOOL: ident, $OP_PRIM: ident, $OP_STR: ident) => {{ + ($LEFT: expr, $RIGHT: expr, $OP_BOOL: ident, $OP_PRIM: ident, $OP_STR: ident, $OP_BINARY: ident) => {{ match ($LEFT.data_type(), $RIGHT.data_type()) { (DataType::Boolean, DataType::Boolean) => { typed_cmp!($LEFT, $RIGHT, BooleanArray, $OP_BOOL) @@ -1052,6 +2006,102 @@ macro_rules! 
typed_compares { (DataType::LargeUtf8, DataType::LargeUtf8) => { typed_cmp!($LEFT, $RIGHT, LargeStringArray, $OP_STR, i64) } + (DataType::Binary, DataType::Binary) => { + typed_cmp!($LEFT, $RIGHT, BinaryArray, $OP_BINARY, i32) + } + (DataType::LargeBinary, DataType::LargeBinary) => { + typed_cmp!($LEFT, $RIGHT, LargeBinaryArray, $OP_BINARY, i64) + } + ( + DataType::Timestamp(TimeUnit::Nanosecond, _), + DataType::Timestamp(TimeUnit::Nanosecond, _), + ) => { + typed_cmp!( + $LEFT, + $RIGHT, + TimestampNanosecondArray, + $OP_PRIM, + TimestampNanosecondType + ) + } + ( + DataType::Timestamp(TimeUnit::Microsecond, _), + DataType::Timestamp(TimeUnit::Microsecond, _), + ) => { + typed_cmp!( + $LEFT, + $RIGHT, + TimestampMicrosecondArray, + $OP_PRIM, + TimestampMicrosecondType + ) + } + ( + DataType::Timestamp(TimeUnit::Millisecond, _), + DataType::Timestamp(TimeUnit::Millisecond, _), + ) => { + typed_cmp!( + $LEFT, + $RIGHT, + TimestampMillisecondArray, + $OP_PRIM, + TimestampMillisecondType + ) + } + ( + DataType::Timestamp(TimeUnit::Second, _), + DataType::Timestamp(TimeUnit::Second, _), + ) => { + typed_cmp!( + $LEFT, + $RIGHT, + TimestampSecondArray, + $OP_PRIM, + TimestampSecondType + ) + } + (DataType::Date32, DataType::Date32) => { + typed_cmp!($LEFT, $RIGHT, Date32Array, $OP_PRIM, Date32Type) + } + (DataType::Date64, DataType::Date64) => { + typed_cmp!($LEFT, $RIGHT, Date64Array, $OP_PRIM, Date64Type) + } + ( + DataType::Interval(IntervalUnit::YearMonth), + DataType::Interval(IntervalUnit::YearMonth), + ) => { + typed_cmp!( + $LEFT, + $RIGHT, + IntervalYearMonthArray, + $OP_PRIM, + IntervalYearMonthType + ) + } + ( + DataType::Interval(IntervalUnit::DayTime), + DataType::Interval(IntervalUnit::DayTime), + ) => { + typed_cmp!( + $LEFT, + $RIGHT, + IntervalDayTimeArray, + $OP_PRIM, + IntervalDayTimeType + ) + } + ( + DataType::Interval(IntervalUnit::MonthDayNano), + DataType::Interval(IntervalUnit::MonthDayNano), + ) => { + typed_cmp!( + $LEFT, + $RIGHT, + IntervalMonthDayNanoArray, + $OP_PRIM, + IntervalMonthDayNanoType + ) + } (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!( "Comparing arrays of type {} is not yet implemented", t1 @@ -1064,52 +2114,425 @@ macro_rules! typed_compares { }}; } -/// Perform `left == right` operation on two (dynamic) [`Array`]s. -/// -/// Only when two arrays are of the same type the comparison will happen otherwise it will err +/// Applies $OP to $LEFT and $RIGHT which are two dictionaries which have (the same) key type $KT +macro_rules! 
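With the new `$OP_BINARY` arm and the temporal arms above, the `*_dyn` kernels now accept binary, timestamp, date and interval arrays of matching types. A short sketch (constructors as used in the tests below):

```rust
use arrow::array::{BinaryArray, BooleanArray, Date32Array};
use arrow::compute::{eq_dyn, lt_dyn};
use arrow::error::Result;

fn dyn_binary_and_temporal_example() -> Result<()> {
    let a = BinaryArray::from_vec(vec![b"arrow", b"flight"]);
    let b = BinaryArray::from_vec(vec![b"arrow", b"parquet"]);
    assert_eq!(eq_dyn(&a, &b)?, BooleanArray::from(vec![true, false]));

    let a = Date32Array::from(vec![100, 200]);
    let b = Date32Array::from(vec![150, 200]);
    assert_eq!(lt_dyn(&a, &b)?, BooleanArray::from(vec![true, false]));
    Ok(())
}
```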
typed_dict_cmp { + ($LEFT: expr, $RIGHT: expr, $OP: expr, $OP_BOOL: expr, $KT: tt) => {{ + match ($LEFT.value_type(), $RIGHT.value_type()) { + (DataType::Boolean, DataType::Boolean) => { + cmp_dict_bool::<$KT, _>($LEFT, $RIGHT, $OP_BOOL) + } + (DataType::Int8, DataType::Int8) => { + cmp_dict::<$KT, Int8Type, _>($LEFT, $RIGHT, $OP) + } + (DataType::Int16, DataType::Int16) => { + cmp_dict::<$KT, Int16Type, _>($LEFT, $RIGHT, $OP) + } + (DataType::Int32, DataType::Int32) => { + cmp_dict::<$KT, Int32Type, _>($LEFT, $RIGHT, $OP) + } + (DataType::Int64, DataType::Int64) => { + cmp_dict::<$KT, Int64Type, _>($LEFT, $RIGHT, $OP) + } + (DataType::UInt8, DataType::UInt8) => { + cmp_dict::<$KT, UInt8Type, _>($LEFT, $RIGHT, $OP) + } + (DataType::UInt16, DataType::UInt16) => { + cmp_dict::<$KT, UInt16Type, _>($LEFT, $RIGHT, $OP) + } + (DataType::UInt32, DataType::UInt32) => { + cmp_dict::<$KT, UInt32Type, _>($LEFT, $RIGHT, $OP) + } + (DataType::UInt64, DataType::UInt64) => { + cmp_dict::<$KT, UInt64Type, _>($LEFT, $RIGHT, $OP) + } + (DataType::Float32, DataType::Float32) => { + cmp_dict::<$KT, Float32Type, _>($LEFT, $RIGHT, $OP) + } + (DataType::Float64, DataType::Float64) => { + cmp_dict::<$KT, Float64Type, _>($LEFT, $RIGHT, $OP) + } + (DataType::Utf8, DataType::Utf8) => { + cmp_dict_utf8::<$KT, i32, _>($LEFT, $RIGHT, $OP) + } + (DataType::LargeUtf8, DataType::LargeUtf8) => { + cmp_dict_utf8::<$KT, i64, _>($LEFT, $RIGHT, $OP) + } + (DataType::Binary, DataType::Binary) => { + cmp_dict_binary::<$KT, i32, _>($LEFT, $RIGHT, $OP) + } + (DataType::LargeBinary, DataType::LargeBinary) => { + cmp_dict_binary::<$KT, i64, _>($LEFT, $RIGHT, $OP) + } + ( + DataType::Timestamp(TimeUnit::Nanosecond, _), + DataType::Timestamp(TimeUnit::Nanosecond, _), + ) => { + cmp_dict::<$KT, TimestampNanosecondType, _>($LEFT, $RIGHT, $OP) + } + ( + DataType::Timestamp(TimeUnit::Microsecond, _), + DataType::Timestamp(TimeUnit::Microsecond, _), + ) => { + cmp_dict::<$KT, TimestampMicrosecondType, _>($LEFT, $RIGHT, $OP) + } + ( + DataType::Timestamp(TimeUnit::Millisecond, _), + DataType::Timestamp(TimeUnit::Millisecond, _), + ) => { + cmp_dict::<$KT, TimestampMillisecondType, _>($LEFT, $RIGHT, $OP) + } + ( + DataType::Timestamp(TimeUnit::Second, _), + DataType::Timestamp(TimeUnit::Second, _), + ) => { + cmp_dict::<$KT, TimestampSecondType, _>($LEFT, $RIGHT, $OP) + } + (DataType::Date32, DataType::Date32) => { + cmp_dict::<$KT, Date32Type, _>($LEFT, $RIGHT, $OP) + } + (DataType::Date64, DataType::Date64) => { + cmp_dict::<$KT, Date64Type, _>($LEFT, $RIGHT, $OP) + } + ( + DataType::Interval(IntervalUnit::YearMonth), + DataType::Interval(IntervalUnit::YearMonth), + ) => { + cmp_dict::<$KT, IntervalYearMonthType, _>($LEFT, $RIGHT, $OP) + } + ( + DataType::Interval(IntervalUnit::DayTime), + DataType::Interval(IntervalUnit::DayTime), + ) => { + cmp_dict::<$KT, IntervalDayTimeType, _>($LEFT, $RIGHT, $OP) + } + ( + DataType::Interval(IntervalUnit::MonthDayNano), + DataType::Interval(IntervalUnit::MonthDayNano), + ) => { + cmp_dict::<$KT, IntervalMonthDayNanoType, _>($LEFT, $RIGHT, $OP) + } + (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!( + "Comparing dictionary arrays of value type {} is not yet implemented", + t1 + ))), + (t1, t2) => Err(ArrowError::CastError(format!( + "Cannot compare two dictionary arrays of different value types ({} and {})", + t1, t2 + ))), + } + }}; +} + +macro_rules! 
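The value-type dispatch in `typed_dict_cmp!` means two `DictionaryArray`s with matching key and value types can be compared element-wise through the ordinary `*_dyn` entry points. A sketch using the `FromIterator<&str>` impl for string dictionaries:

```rust
use arrow::array::{BooleanArray, DictionaryArray};
use arrow::compute::eq_dyn;
use arrow::datatypes::Int32Type;
use arrow::error::Result;

fn dict_vs_dict_example() -> Result<()> {
    let left: DictionaryArray<Int32Type> =
        vec!["foo", "bar", "baz"].into_iter().collect();
    let right: DictionaryArray<Int32Type> =
        vec!["foo", "baz", "baz"].into_iter().collect();

    // the values referenced by the keys are compared, not the keys themselves
    assert_eq!(
        eq_dyn(&left, &right)?,
        BooleanArray::from(vec![true, false, true])
    );
    Ok(())
}
```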
typed_dict_compares { + // Applies `LEFT OP RIGHT` when `LEFT` and `RIGHT` both are `DictionaryArray` + ($LEFT: expr, $RIGHT: expr, $OP: expr, $OP_BOOL: expr) => {{ + match ($LEFT.data_type(), $RIGHT.data_type()) { + (DataType::Dictionary(left_key_type, _), DataType::Dictionary(right_key_type, _))=> { + match (left_key_type.as_ref(), right_key_type.as_ref()) { + (DataType::Int8, DataType::Int8) => { + let left = as_dictionary_array::($LEFT); + let right = as_dictionary_array::($RIGHT); + typed_dict_cmp!(left, right, $OP, $OP_BOOL, Int8Type) + } + (DataType::Int16, DataType::Int16) => { + let left = as_dictionary_array::($LEFT); + let right = as_dictionary_array::($RIGHT); + typed_dict_cmp!(left, right, $OP, $OP_BOOL, Int16Type) + } + (DataType::Int32, DataType::Int32) => { + let left = as_dictionary_array::($LEFT); + let right = as_dictionary_array::($RIGHT); + typed_dict_cmp!(left, right, $OP, $OP_BOOL, Int32Type) + } + (DataType::Int64, DataType::Int64) => { + let left = as_dictionary_array::($LEFT); + let right = as_dictionary_array::($RIGHT); + typed_dict_cmp!(left, right, $OP, $OP_BOOL, Int64Type) + } + (DataType::UInt8, DataType::UInt8) => { + let left = as_dictionary_array::($LEFT); + let right = as_dictionary_array::($RIGHT); + typed_dict_cmp!(left, right, $OP, $OP_BOOL, UInt8Type) + } + (DataType::UInt16, DataType::UInt16) => { + let left = as_dictionary_array::($LEFT); + let right = as_dictionary_array::($RIGHT); + typed_dict_cmp!(left, right, $OP, $OP_BOOL, UInt16Type) + } + (DataType::UInt32, DataType::UInt32) => { + let left = as_dictionary_array::($LEFT); + let right = as_dictionary_array::($RIGHT); + typed_dict_cmp!(left, right, $OP, $OP_BOOL, UInt32Type) + } + (DataType::UInt64, DataType::UInt64) => { + let left = as_dictionary_array::($LEFT); + let right = as_dictionary_array::($RIGHT); + typed_dict_cmp!(left, right, $OP, $OP_BOOL, UInt64Type) + } + (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!( + "Comparing dictionary arrays of type {} is not yet implemented", + t1 + ))), + (t1, t2) => Err(ArrowError::CastError(format!( + "Cannot compare two dictionary arrays of different key types ({} and {})", + t1, t2 + ))), + } + } + (t1, t2) => Err(ArrowError::CastError(format!( + "Cannot compare dictionary array with non-dictionary array ({} and {})", + t1, t2 + ))), + } + }}; +} + +/// Helper function to perform boolean lambda function on values from two dictionary arrays, this +/// version does not attempt to use SIMD explicitly (though the compiler may auto vectorize) +macro_rules! 
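`typed_dict_compares!` only succeeds when both inputs are dictionaries with identical key types; every other combination falls through to a `CastError` (or `NotYetImplemented` for unsupported key types). A sketch of the error paths, under the same assumptions as the example above:

```rust
use arrow::array::{DictionaryArray, StringArray};
use arrow::compute::eq_dyn;
use arrow::datatypes::{Int16Type, Int8Type};

fn dict_error_paths() {
    let a: DictionaryArray<Int8Type> = vec!["x", "y"].into_iter().collect();
    let b: DictionaryArray<Int16Type> = vec!["x", "y"].into_iter().collect();
    // different key types -> CastError
    assert!(eq_dyn(&a, &b).is_err());

    let c = StringArray::from(vec!["x", "y"]);
    // dictionary compared with a non-dictionary array -> CastError
    assert!(eq_dyn(&a, &c).is_err());
}
```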
compare_dict_op { + ($left: expr, $right:expr, $op:expr, $value_ty:ty) => {{ + if $left.len() != $right.len() { + return Err(ArrowError::ComputeError( + "Cannot perform comparison operation on arrays of different length" + .to_string(), + )); + } + + // Safety justification: Since the inputs are valid Arrow arrays, all values are + // valid indexes into the dictionary (which is verified during construction) + + let left_iter = unsafe { + $left + .values() + .as_any() + .downcast_ref::<$value_ty>() + .unwrap() + .take_iter_unchecked($left.keys_iter()) + }; + + let right_iter = unsafe { + $right + .values() + .as_any() + .downcast_ref::<$value_ty>() + .unwrap() + .take_iter_unchecked($right.keys_iter()) + }; + + let result = left_iter + .zip(right_iter) + .map(|(left_value, right_value)| { + if let (Some(left), Some(right)) = (left_value, right_value) { + Some($op(left, right)) + } else { + None + } + }) + .collect(); + + Ok(result) + }}; +} + +/// Perform given operation on two `DictionaryArray`s. +/// Returns an error if the two arrays have different value type +pub fn cmp_dict( + left: &DictionaryArray, + right: &DictionaryArray, + op: F, +) -> Result +where + K: ArrowNumericType, + T: ArrowNumericType, + F: Fn(T::Native, T::Native) -> bool, +{ + compare_dict_op!(left, right, op, PrimitiveArray) +} + +/// Perform the given operation on two `DictionaryArray`s which value type is +/// `DataType::Boolean`. +pub fn cmp_dict_bool( + left: &DictionaryArray, + right: &DictionaryArray, + op: F, +) -> Result +where + K: ArrowNumericType, + F: Fn(bool, bool) -> bool, +{ + compare_dict_op!(left, right, op, BooleanArray) +} + +/// Perform the given operation on two `DictionaryArray`s which value type is +/// `DataType::Utf8` or `DataType::LargeUtf8`. +pub fn cmp_dict_utf8( + left: &DictionaryArray, + right: &DictionaryArray, + op: F, +) -> Result +where + K: ArrowNumericType, + F: Fn(&str, &str) -> bool, +{ + compare_dict_op!(left, right, op, GenericStringArray) +} + +/// Perform the given operation on two `DictionaryArray`s which value type is +/// `DataType::Binary` or `DataType::LargeBinary`. +pub fn cmp_dict_binary( + left: &DictionaryArray, + right: &DictionaryArray, + op: F, +) -> Result +where + K: ArrowNumericType, + F: Fn(&[u8], &[u8]) -> bool, +{ + compare_dict_op!(left, right, op, GenericBinaryArray) +} + +/// Perform `left == right` operation on two (dynamic) [`Array`]s. +/// +/// Only when two arrays are of the same type the comparison will happen otherwise it will err /// with a casting error. +/// +/// # Example +/// ``` +/// use arrow::array::{StringArray, BooleanArray}; +/// use arrow::compute::eq_dyn; +/// let array1 = StringArray::from(vec![Some("foo"), None, Some("bar")]); +/// let array2 = StringArray::from(vec![Some("foo"), None, Some("baz")]); +/// let result = eq_dyn(&array1, &array2).unwrap(); +/// assert_eq!(BooleanArray::from(vec![Some(true), None, Some(false)]), result); +/// ``` pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { - typed_compares!(left, right, eq_bool, eq, eq_utf8) + match left.data_type() { + DataType::Dictionary(_, _) => { + typed_dict_compares!(left, right, |a, b| a == b, |a, b| a == b) + } + _ => typed_compares!(left, right, eq_bool, eq, eq_utf8, eq_binary), + } } /// Perform `left != right` operation on two (dynamic) [`Array`]s. /// /// Only when two arrays are of the same type the comparison will happen otherwise it will err /// with a casting error. 
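Since `cmp_dict`, `cmp_dict_bool`, `cmp_dict_utf8` and `cmp_dict_binary` are public, a caller can also evaluate an arbitrary predicate over two dictionaries without going through the `*_dyn` kernels. A hedged sketch; the turbofish `::<Int16Type, i32, _>` mirrors the `typed_dict_cmp!` call sites and assumes the function is re-exported from `arrow::compute`:

```rust
use arrow::array::{BooleanArray, DictionaryArray};
use arrow::compute::cmp_dict_utf8;
use arrow::datatypes::Int16Type;
use arrow::error::Result;

fn custom_dict_predicate() -> Result<()> {
    let left: DictionaryArray<Int16Type> =
        vec!["apple", "banana"].into_iter().collect();
    let right: DictionaryArray<Int16Type> =
        vec!["avocado", "cherry"].into_iter().collect();

    // compare only the first character of each value
    let same_initial = cmp_dict_utf8::<Int16Type, i32, _>(&left, &right, |a, b| {
        a.chars().next() == b.chars().next()
    })?;
    assert_eq!(same_initial, BooleanArray::from(vec![true, false]));
    Ok(())
}
```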
+/// +/// # Example +/// ``` +/// use arrow::array::{BinaryArray, BooleanArray}; +/// use arrow::compute::neq_dyn; +/// let values1: Vec> = vec![Some(&[0xfc, 0xa9]), None, Some(&[0x36])]; +/// let values2: Vec> = vec![Some(&[0xfc, 0xa9]), None, Some(&[0x36, 0x00])]; +/// let array1 = BinaryArray::from(values1); +/// let array2 = BinaryArray::from(values2); +/// let result = neq_dyn(&array1, &array2).unwrap(); +/// assert_eq!(BooleanArray::from(vec![Some(false), None, Some(true)]), result); +/// ``` pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result { - typed_compares!(left, right, neq_bool, neq, neq_utf8) + match left.data_type() { + DataType::Dictionary(_, _) => { + typed_dict_compares!(left, right, |a, b| a != b, |a, b| a != b) + } + _ => typed_compares!(left, right, neq_bool, neq, neq_utf8, neq_binary), + } } /// Perform `left < right` operation on two (dynamic) [`Array`]s. /// /// Only when two arrays are of the same type the comparison will happen otherwise it will err /// with a casting error. +/// +/// # Example +/// ``` +/// use arrow::array::{PrimitiveArray, BooleanArray}; +/// use arrow::datatypes::Int32Type; +/// use arrow::compute::lt_dyn; +/// let array1: PrimitiveArray = PrimitiveArray::from(vec![Some(0), Some(1), Some(2)]); +/// let array2: PrimitiveArray = PrimitiveArray::from(vec![Some(1), Some(1), None]); +/// let result = lt_dyn(&array1, &array2).unwrap(); +/// assert_eq!(BooleanArray::from(vec![Some(true), Some(false), None]), result); +/// ``` +#[allow(clippy::bool_comparison)] pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result { - typed_compares!(left, right, lt_bool, lt, lt_utf8) + match left.data_type() { + DataType::Dictionary(_, _) => { + typed_dict_compares!(left, right, |a, b| a < b, |a, b| a < b) + } + _ => typed_compares!(left, right, lt_bool, lt, lt_utf8, lt_binary), + } } /// Perform `left <= right` operation on two (dynamic) [`Array`]s. /// /// Only when two arrays are of the same type the comparison will happen otherwise it will err /// with a casting error. +/// +/// # Example +/// ``` +/// use arrow::array::{PrimitiveArray, BooleanArray}; +/// use arrow::datatypes::Date32Type; +/// use arrow::compute::lt_eq_dyn; +/// let array1: PrimitiveArray = vec![Some(12356), Some(13548), Some(-365), Some(365)].into(); +/// let array2: PrimitiveArray = vec![Some(12355), Some(13548), Some(-364), None].into(); +/// let result = lt_eq_dyn(&array1, &array2).unwrap(); +/// assert_eq!(BooleanArray::from(vec![Some(false), Some(true), Some(true), None]), result); +/// ``` pub fn lt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { - typed_compares!(left, right, lt_eq_bool, lt_eq, lt_eq_utf8) + match left.data_type() { + DataType::Dictionary(_, _) => { + typed_dict_compares!(left, right, |a, b| a <= b, |a, b| a <= b) + } + _ => typed_compares!(left, right, lt_eq_bool, lt_eq, lt_eq_utf8, lt_eq_binary), + } } /// Perform `left > right` operation on two (dynamic) [`Array`]s. /// /// Only when two arrays are of the same type the comparison will happen otherwise it will err /// with a casting error. 
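The ordering kernels take the same dictionary route, passing `|a, b| a < b` (and the corresponding boolean closure) into `typed_dict_compares!`. Sketch:

```rust
use arrow::array::{BooleanArray, DictionaryArray};
use arrow::compute::{lt_dyn, lt_eq_dyn};
use arrow::datatypes::Int8Type;
use arrow::error::Result;

fn dict_ordering_example() -> Result<()> {
    let a: DictionaryArray<Int8Type> = vec!["abc", "def", "xyz"].into_iter().collect();
    let b: DictionaryArray<Int8Type> = vec!["abd", "def", "aaa"].into_iter().collect();

    assert_eq!(lt_dyn(&a, &b)?, BooleanArray::from(vec![true, false, false]));
    assert_eq!(lt_eq_dyn(&a, &b)?, BooleanArray::from(vec![true, true, false]));
    Ok(())
}
```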
+/// +/// # Example +/// ``` +/// use arrow::array::BooleanArray; +/// use arrow::compute::gt_dyn; +/// let array1 = BooleanArray::from(vec![Some(true), Some(false), None]); +/// let array2 = BooleanArray::from(vec![Some(false), Some(true), None]); +/// let result = gt_dyn(&array1, &array2).unwrap(); +/// assert_eq!(BooleanArray::from(vec![Some(true), Some(false), None]), result); +/// ``` +#[allow(clippy::bool_comparison)] pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result { - typed_compares!(left, right, gt_bool, gt, gt_utf8) + match left.data_type() { + DataType::Dictionary(_, _) => { + typed_dict_compares!(left, right, |a, b| a > b, |a, b| a > b) + } + _ => typed_compares!(left, right, gt_bool, gt, gt_utf8, gt_binary), + } } /// Perform `left >= right` operation on two (dynamic) [`Array`]s. /// /// Only when two arrays are of the same type the comparison will happen otherwise it will err /// with a casting error. +/// +/// # Example +/// ``` +/// use arrow::array::{BooleanArray, StringArray}; +/// use arrow::compute::gt_eq_dyn; +/// let array1 = StringArray::from(vec![Some(""), Some("aaa"), None]); +/// let array2 = StringArray::from(vec![Some(" "), Some("aa"), None]); +/// let result = gt_eq_dyn(&array1, &array2).unwrap(); +/// assert_eq!(BooleanArray::from(vec![Some(false), Some(true), None]), result); +/// ``` pub fn gt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { - typed_compares!(left, right, gt_eq_bool, gt_eq, gt_eq_utf8) + match left.data_type() { + DataType::Dictionary(_, _) => { + typed_dict_compares!(left, right, |a, b| a >= b, |a, b| a >= b) + } + _ => typed_compares!(left, right, gt_eq_bool, gt_eq, gt_eq_utf8, gt_eq_binary), + } } /// Perform `left == right` operation on two [`PrimitiveArray`]s. @@ -1278,7 +2701,7 @@ where let num_bytes = bit_util::ceil(left_len, 8); let not_both_null_bit_buffer = - match combine_option_bitmap(left.data_ref(), right.data_ref(), left_len)? { + match combine_option_bitmap(&[left.data_ref(), right.data_ref()], left_len)? { Some(buff) => buff, None => new_all_set_buffer(num_bytes), }; @@ -1322,7 +2745,7 @@ pub fn contains_utf8( right: &ListArray, ) -> Result where - OffsetSize: StringOffsetSizeTrait, + OffsetSize: OffsetSizeTrait, { let left_len = left.len(); if left_len != right.len() { @@ -1335,7 +2758,7 @@ where let num_bytes = bit_util::ceil(left_len, 8); let not_both_null_bit_buffer = - match combine_option_bitmap(left.data_ref(), right.data_ref(), left_len)? { + match combine_option_bitmap(&[left.data_ref(), right.data_ref()], left_len)? { Some(buff) => buff, None => new_all_set_buffer(num_bytes), }; @@ -1389,18 +2812,21 @@ fn new_all_set_buffer(len: usize) -> Buffer { #[rustfmt::skip::macros(vec)] #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use crate::datatypes::Int8Type; use crate::{array::Int32Array, array::Int64Array, datatypes::Field}; /// Evaluate `KERNEL` with two vectors as inputs and assert against the expected output. - /// `A_VEC` and `B_VEC` can be of type `Vec` or `Vec>`. + /// `A_VEC` and `B_VEC` can be of type `Vec` or `Vec>` where `T` is the native + /// type of the data type of the Arrow array element. /// `EXPECTED` can be either `Vec` or `Vec>`. /// The main reason for this macro is that inputs and outputs align nicely after `cargo fmt`. - macro_rules! cmp_i64 { - ($KERNEL:ident, $DYN_KERNEL:ident, $A_VEC:expr, $B_VEC:expr, $EXPECTED:expr) => { - let a = Int64Array::from($A_VEC); - let b = Int64Array::from($B_VEC); + macro_rules! 
cmp_vec { + ($KERNEL:ident, $DYN_KERNEL:ident, $ARRAY:ident, $A_VEC:expr, $B_VEC:expr, $EXPECTED:expr) => { + let a = $ARRAY::from($A_VEC); + let b = $ARRAY::from($B_VEC); let c = $KERNEL(&a, &b).unwrap(); assert_eq!(BooleanArray::from($EXPECTED), c); @@ -1409,6 +2835,30 @@ mod tests { let b = b.slice(0, b.len()); let c = $DYN_KERNEL(a.as_ref(), b.as_ref()).unwrap(); assert_eq!(BooleanArray::from($EXPECTED), c); + + // test with a larger version of the same data to ensure we cover the chunked part of the comparison + let mut a = vec![]; + let mut b = vec![]; + let mut e = vec![]; + for _i in 0..10 { + a.extend($A_VEC); + b.extend($B_VEC); + e.extend($EXPECTED); + } + let a = $ARRAY::from(a); + let b = $ARRAY::from(b); + let c = $KERNEL(&a, &b).unwrap(); + assert_eq!(BooleanArray::from(e), c); + }; + } + + /// Evaluate `KERNEL` with two vectors as inputs and assert against the expected output. + /// `A_VEC` and `B_VEC` can be of type `Vec` or `Vec>`. + /// `EXPECTED` can be either `Vec` or `Vec>`. + /// The main reason for this macro is that inputs and outputs align nicely after `cargo fmt`. + macro_rules! cmp_i64 { + ($KERNEL:ident, $DYN_KERNEL:ident, $A_VEC:expr, $B_VEC:expr, $EXPECTED:expr) => { + cmp_vec!($KERNEL, $DYN_KERNEL, Int64Array, $A_VEC, $B_VEC, $EXPECTED); }; } @@ -1421,6 +2871,18 @@ mod tests { let a = Int64Array::from($A_VEC); let c = $KERNEL(&a, $B).unwrap(); assert_eq!(BooleanArray::from($EXPECTED), c); + + // test with a larger version of the same data to ensure we cover the chunked part of the comparison + let mut a = vec![]; + let mut e = vec![]; + for _i in 0..10 { + a.extend($A_VEC); + e.extend($EXPECTED); + } + let a = Int64Array::from(a); + let c = $KERNEL(&a, $B).unwrap(); + assert_eq!(BooleanArray::from(e), c); + }; } @@ -1433,6 +2895,15 @@ mod tests { vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10], vec![false, false, true, false, false, false, false, true, false, false] ); + + cmp_vec!( + eq, + eq_dyn, + TimestampSecondArray, + vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8], + vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + vec![false, false, true, false, false, false, false, true, false, false] + ); } #[test] @@ -1480,6 +2951,15 @@ mod tests { vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10], vec![true, true, false, true, true, true, true, false, true, true] ); + + cmp_vec!( + neq, + neq_dyn, + TimestampMillisecondArray, + vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8], + vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + vec![true, true, false, true, true, true, true, false, true, true] + ); } #[test] @@ -1621,6 +3101,62 @@ mod tests { assert_eq!(res2, vec![Some(false), Some(true), None]); } + #[test] + fn test_boolean_array_lt_scalar() { + let a: BooleanArray = vec![Some(true), Some(false), None].into(); + + let res1: Vec> = lt_bool_scalar(&a, false).unwrap().iter().collect(); + + assert_eq!(res1, vec![Some(false), Some(false), None]); + + let res2: Vec> = lt_bool_scalar(&a, true).unwrap().iter().collect(); + + assert_eq!(res2, vec![Some(false), Some(true), None]); + } + + #[test] + fn test_boolean_array_lt_eq_scalar() { + let a: BooleanArray = vec![Some(true), Some(false), None].into(); + + let res1: Vec> = + lt_eq_bool_scalar(&a, false).unwrap().iter().collect(); + + assert_eq!(res1, vec![Some(false), Some(true), None]); + + let res2: Vec> = + lt_eq_bool_scalar(&a, true).unwrap().iter().collect(); + + assert_eq!(res2, vec![Some(true), Some(true), None]); + } + + #[test] + fn test_boolean_array_gt_scalar() { + let a: BooleanArray = vec![Some(true), Some(false), None].into(); + + let res1: Vec> = gt_bool_scalar(&a, 
false).unwrap().iter().collect(); + + assert_eq!(res1, vec![Some(true), Some(false), None]); + + let res2: Vec> = gt_bool_scalar(&a, true).unwrap().iter().collect(); + + assert_eq!(res2, vec![Some(false), Some(false), None]); + } + + #[test] + fn test_boolean_array_gt_eq_scalar() { + let a: BooleanArray = vec![Some(true), Some(false), None].into(); + + let res1: Vec> = + gt_eq_bool_scalar(&a, false).unwrap().iter().collect(); + + assert_eq!(res1, vec![Some(true), Some(true), None]); + + let res2: Vec> = + gt_eq_bool_scalar(&a, true).unwrap().iter().collect(); + + assert_eq!(res2, vec![Some(true), Some(false), None]); + } + #[test] fn test_primitive_array_lt() { cmp_i64!( @@ -1630,6 +3166,15 @@ mod tests { vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10], vec![false, false, false, true, true, false, false, false, true, true] ); + + cmp_vec!( + lt, + lt_dyn, + TimestampMillisecondArray, + vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8], + vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + vec![false, false, false, true, true, false, false, false, true, true] + ); } #[test] @@ -1651,6 +3196,15 @@ mod tests { vec![None, Some(1), None, Some(1), None, Some(3), None, Some(3),], vec![None, None, None, Some(false), None, None, None, Some(true)] ); + + cmp_vec!( + lt, + lt_dyn, + TimestampMillisecondArray, + vec![None, None, Some(1), Some(1), None, None, Some(2), Some(2),], + vec![None, Some(1), None, Some(1), None, Some(3), None, Some(3),], + vec![None, None, None, Some(false), None, None, None, Some(true)] + ); } #[test] @@ -1857,7 +3411,7 @@ mod tests { .len(4) .add_buffer(value_offsets) .add_child_data(value_data) - .null_bit_buffer(Buffer::from([0b00001011])) + .null_bit_buffer(Some(Buffer::from([0b00001011]))) .build() .unwrap(); @@ -1885,30 +3439,252 @@ mod tests { ); } - // Expected behaviour: - // contains("ab", ["ab", "cd", null]) = true - // contains("ef", ["ab", "cd", null]) = false - // contains(null, ["ab", "cd", null]) = false - // contains(null, null) = false #[test] - fn test_contains_utf8() { - let values_builder = StringBuilder::new(10); - let mut builder = ListBuilder::new(values_builder); + fn test_interval_array() { + let a = IntervalDayTimeArray::from( + vec![Some(0), Some(6), Some(834), None, Some(3), None], + ); + let b = IntervalDayTimeArray::from( + vec![Some(70), Some(6), Some(833), Some(6), Some(3), None], + ); + let res = eq(&a, &b).unwrap(); + let res_dyn = eq_dyn(&a, &b).unwrap(); + assert_eq!(res, res_dyn); + assert_eq!( + &res_dyn, + &BooleanArray::from( + vec![Some(false), Some(true), Some(false), None, Some(true), None] + ) + ); - builder.values().append_value("Lorem").unwrap(); - builder.values().append_value("ipsum").unwrap(); - builder.values().append_null().unwrap(); - builder.append(true).unwrap(); - builder.values().append_value("sit").unwrap(); - builder.values().append_value("amet").unwrap(); - builder.values().append_value("Lorem").unwrap(); - builder.append(true).unwrap(); - builder.append(false).unwrap(); - builder.values().append_value("ipsum").unwrap(); - builder.append(true).unwrap(); + let a = IntervalMonthDayNanoArray::from( + vec![Some(0), Some(6), Some(834), None, Some(3), None], + ); + let b = IntervalMonthDayNanoArray::from( + vec![Some(86), Some(5), Some(8), Some(6), Some(3), None], + ); + let res = lt(&a, &b).unwrap(); + let res_dyn = lt_dyn(&a, &b).unwrap(); + assert_eq!(res, res_dyn); + assert_eq!( + &res_dyn, + &BooleanArray::from( + vec![Some(true), Some(false), Some(false), None, Some(false), None] + ) + ); - // [["Lorem", "ipsum", null], ["sit", "amet", "Lorem"], null, 
["ipsum"]] - // value_offsets = [0, 3, 6, 6] + let a = IntervalYearMonthArray::from( + vec![Some(0), Some(623), Some(834), None, Some(3), None], + ); + let b = IntervalYearMonthArray::from( + vec![Some(86), Some(5), Some(834), Some(6), Some(86), None], + ); + let res = gt_eq(&a, &b).unwrap(); + let res_dyn = gt_eq_dyn(&a, &b).unwrap(); + assert_eq!(res, res_dyn); + assert_eq!( + &res_dyn, + &BooleanArray::from( + vec![Some(false), Some(true), Some(true), None, Some(false), None] + ) + ); + } + + macro_rules! test_binary { + ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => { + #[test] + fn $test_name() { + let left = BinaryArray::from_vec($left); + let right = BinaryArray::from_vec($right); + let res = $op(&left, &right).unwrap(); + let expected = $expected; + assert_eq!(expected.len(), res.len()); + for i in 0..res.len() { + let v = res.value(i); + assert_eq!(v, expected[i]); + } + + let left = LargeBinaryArray::from_vec($left); + let right = LargeBinaryArray::from_vec($right); + let res = $op(&left, &right).unwrap(); + let expected = $expected; + assert_eq!(expected.len(), res.len()); + for i in 0..res.len() { + let v = res.value(i); + assert_eq!(v, expected[i]); + } + } + }; + } + + #[test] + fn test_binary_eq_scalar_on_slice() { + let a = BinaryArray::from_opt_vec( + vec![Some(b"hi"), None, Some(b"hello"), Some(b"world")], + ); + let a = a.slice(1, 3); + let a = as_generic_binary_array::(&a); + let a_eq = eq_binary_scalar(a, b"hello").unwrap(); + assert_eq!( + a_eq, + BooleanArray::from(vec![None, Some(true), Some(false)]) + ); + } + + macro_rules! test_binary_scalar { + ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => { + #[test] + fn $test_name() { + let left = BinaryArray::from_vec($left); + let res = $op(&left, $right).unwrap(); + let expected = $expected; + assert_eq!(expected.len(), res.len()); + for i in 0..res.len() { + let v = res.value(i); + assert_eq!( + v, + expected[i], + "unexpected result when comparing {:?} at position {} to {:?} ", + left.value(i), + i, + $right + ); + } + + let left = LargeBinaryArray::from_vec($left); + let res = $op(&left, $right).unwrap(); + let expected = $expected; + assert_eq!(expected.len(), res.len()); + for i in 0..res.len() { + let v = res.value(i); + assert_eq!( + v, + expected[i], + "unexpected result when comparing {:?} at position {} to {:?} ", + left.value(i), + i, + $right + ); + } + } + }; + } + + test_binary!( + test_binary_array_eq, + vec![b"arrow", b"arrow", b"arrow", b"arrow", &[0xff, 0xf8]], + vec![b"arrow", b"parquet", b"datafusion", b"flight", &[0xff, 0xf8]], + eq_binary, + vec![true, false, false, false, true] + ); + + test_binary_scalar!( + test_binary_array_eq_scalar, + vec![b"arrow", b"parquet", b"datafusion", b"flight", &[0xff, 0xf8]], + "arrow".as_bytes(), + eq_binary_scalar, + vec![true, false, false, false, false] + ); + + test_binary!( + test_binary_array_neq, + vec![b"arrow", b"arrow", b"arrow", b"arrow", &[0xff, 0xf8]], + vec![b"arrow", b"parquet", b"datafusion", b"flight", &[0xff, 0xf9]], + neq_binary, + vec![false, true, true, true, true] + ); + test_binary_scalar!( + test_binary_array_neq_scalar, + vec![b"arrow", b"parquet", b"datafusion", b"flight", &[0xff, 0xf8]], + "arrow".as_bytes(), + neq_binary_scalar, + vec![false, true, true, true, true] + ); + + test_binary!( + test_binary_array_lt, + vec![b"arrow", b"datafusion", b"flight", b"parquet", &[0xff, 0xf8]], + vec![b"flight", b"flight", b"flight", b"flight", &[0xff, 0xf9]], + lt_binary, + vec![true, true, false, 
false, true] + ); + test_binary_scalar!( + test_binary_array_lt_scalar, + vec![b"arrow", b"datafusion", b"flight", b"parquet", &[0xff, 0xf8]], + "flight".as_bytes(), + lt_binary_scalar, + vec![true, true, false, false, false] + ); + + test_binary!( + test_binary_array_lt_eq, + vec![b"arrow", b"datafusion", b"flight", b"parquet", &[0xff, 0xf8]], + vec![b"flight", b"flight", b"flight", b"flight", &[0xff, 0xf8, 0xf9]], + lt_eq_binary, + vec![true, true, true, false, true] + ); + test_binary_scalar!( + test_binary_array_lt_eq_scalar, + vec![b"arrow", b"datafusion", b"flight", b"parquet", &[0xff, 0xf8]], + "flight".as_bytes(), + lt_eq_binary_scalar, + vec![true, true, true, false, false] + ); + + test_binary!( + test_binary_array_gt, + vec![b"arrow", b"datafusion", b"flight", b"parquet", &[0xff, 0xf9]], + vec![b"flight", b"flight", b"flight", b"flight", &[0xff, 0xf8]], + gt_binary, + vec![false, false, false, true, true] + ); + test_binary_scalar!( + test_binary_array_gt_scalar, + vec![b"arrow", b"datafusion", b"flight", b"parquet", &[0xff, 0xf8]], + "flight".as_bytes(), + gt_binary_scalar, + vec![false, false, false, true, true] + ); + + test_binary!( + test_binary_array_gt_eq, + vec![b"arrow", b"datafusion", b"flight", b"parquet", &[0xff, 0xf8]], + vec![b"flight", b"flight", b"flight", b"flight", &[0xff, 0xf8]], + gt_eq_binary, + vec![false, false, true, true, true] + ); + test_binary_scalar!( + test_binary_array_gt_eq_scalar, + vec![b"arrow", b"datafusion", b"flight", b"parquet", &[0xff, 0xf8]], + "flight".as_bytes(), + gt_eq_binary_scalar, + vec![false, false, true, true, true] + ); + + // Expected behaviour: + // contains("ab", ["ab", "cd", null]) = true + // contains("ef", ["ab", "cd", null]) = false + // contains(null, ["ab", "cd", null]) = false + // contains(null, null) = false + #[test] + fn test_contains_utf8() { + let values_builder = StringBuilder::new(10); + let mut builder = ListBuilder::new(values_builder); + + builder.values().append_value("Lorem").unwrap(); + builder.values().append_value("ipsum").unwrap(); + builder.values().append_null().unwrap(); + builder.append(true).unwrap(); + builder.values().append_value("sit").unwrap(); + builder.values().append_value("amet").unwrap(); + builder.values().append_value("Lorem").unwrap(); + builder.append(true).unwrap(); + builder.append(false).unwrap(); + builder.values().append_value("ipsum").unwrap(); + builder.append(true).unwrap(); + + // [["Lorem", "ipsum", null], ["sit", "amet", "Lorem"], null, ["ipsum"]] + // value_offsets = [0, 3, 6, 6] let list_array = builder.finish(); let nulls = StringArray::from(vec![None, None, None, None]); @@ -2083,10 +3859,34 @@ mod tests { test_utf8!( test_utf8_array_like, - vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrows", "arrow"], - vec!["arrow", "ar%", "%ro%", "foo", "arr", "arrow_", "arrow_"], + vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrows", "arrow", "arrow"], + vec!["arrow", "ar%", "%ro%", "foo", "arr", "arrow_", "arrow_", ".*"], like_utf8, - vec![true, true, true, false, false, true, false] + vec![true, true, true, false, false, true, false, false] + ); + + test_utf8_scalar!( + test_utf8_array_like_scalar_escape_testing, + vec!["varchar(255)", "int(255)", "varchar", "int"], + "%(%)%", + like_utf8_scalar, + vec![true, true, false, false] + ); + + test_utf8_scalar!( + test_utf8_array_like_scalar_escape_regex, + vec![".*", "a", "*"], + ".*", + like_utf8_scalar, + vec![true, false, false] + ); + + test_utf8_scalar!( + test_utf8_array_like_scalar_escape_regex_dot, + vec![".", 
"a", "*"], + ".", + like_utf8_scalar, + vec![true, false, false] ); test_utf8_scalar!( @@ -2128,6 +3928,21 @@ mod tests { vec![false, true, false, false] ); + test_utf8!( + test_utf8_array_eq, + vec!["arrow", "arrow", "arrow", "arrow"], + vec!["arrow", "parquet", "datafusion", "flight"], + eq_utf8, + vec![true, false, false, false] + ); + test_utf8_scalar!( + test_utf8_array_eq_scalar, + vec!["arrow", "parquet", "datafusion", "flight"], + "arrow", + eq_utf8_scalar, + vec![true, false, false, false] + ); + test_utf8!( test_utf8_array_nlike, vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrows", "arrow"], @@ -2136,26 +3951,34 @@ mod tests { vec![false, false, false, true, true, false, true] ); test_utf8_scalar!( - test_utf8_array_nlike_scalar, - vec!["arrow", "parquet", "datafusion", "flight"], - "%ar%", + test_utf8_array_nlike_escape_testing, + vec!["varchar(255)", "int(255)", "varchar", "int"], + "%(%)%", nlike_utf8_scalar, vec![false, false, true, true] ); - test_utf8!( - test_utf8_array_eq, - vec!["arrow", "arrow", "arrow", "arrow"], - vec!["arrow", "parquet", "datafusion", "flight"], - eq_utf8, - vec![true, false, false, false] + test_utf8_scalar!( + test_utf8_array_nlike_scalar_escape_regex, + vec![".*", "a", "*"], + ".*", + nlike_utf8_scalar, + vec![false, true, true] ); + test_utf8_scalar!( - test_utf8_array_eq_scalar, + test_utf8_array_nlike_scalar_escape_regex_dot, + vec![".", "a", "*"], + ".", + nlike_utf8_scalar, + vec![false, true, true] + ); + test_utf8_scalar!( + test_utf8_array_nlike_scalar, vec!["arrow", "parquet", "datafusion", "flight"], - "arrow", - eq_utf8_scalar, - vec![true, false, false, false] + "%ar%", + nlike_utf8_scalar, + vec![false, false, true, true] ); test_utf8_scalar!( @@ -2190,6 +4013,114 @@ mod tests { vec![true, false, true, true] ); + test_utf8!( + test_utf8_array_ilike, + vec!["arrow", "arrow", "ARROW", "arrow", "ARROW", "ARROWS", "arROw"], + vec!["arrow", "ar%", "%ro%", "foo", "ar%r", "arrow_", "arrow_"], + ilike_utf8, + vec![true, true, true, false, false, true, false] + ); + test_utf8_scalar!( + ilike_utf8_scalar_escape_testing, + vec!["varchar(255)", "int(255)", "varchar", "int"], + "%(%)%", + ilike_utf8_scalar, + vec![true, true, false, false] + ); + test_utf8_scalar!( + test_utf8_array_ilike_scalar, + vec!["arrow", "parquet", "datafusion", "flight"], + "%AR%", + ilike_utf8_scalar, + vec![true, true, false, false] + ); + + test_utf8_scalar!( + test_utf8_array_ilike_scalar_start, + vec!["arrow", "parrow", "arrows", "ARR"], + "aRRow%", + ilike_utf8_scalar, + vec![true, false, true, false] + ); + + test_utf8_scalar!( + test_utf8_array_ilike_scalar_end, + vec!["ArroW", "parrow", "ARRowS", "arr"], + "%arrow", + ilike_utf8_scalar, + vec![true, true, false, false] + ); + + test_utf8_scalar!( + test_utf8_array_ilike_scalar_equals, + vec!["arrow", "parrow", "arrows", "arr"], + "arrow", + ilike_utf8_scalar, + vec![true, false, false, false] + ); + + test_utf8_scalar!( + test_utf8_array_ilike_scalar_one, + vec!["arrow", "arrows", "parrow", "arr"], + "arrow_", + ilike_utf8_scalar, + vec![false, true, false, false] + ); + + test_utf8!( + test_utf8_array_nilike, + vec!["arrow", "arrow", "ARROW", "arrow", "ARROW", "ARROWS", "arROw"], + vec!["arrow", "ar%", "%ro%", "foo", "ar%r", "arrow_", "arrow_"], + nilike_utf8, + vec![false, false, false, true, true, false, true] + ); + test_utf8_scalar!( + nilike_utf8_scalar_escape_testing, + vec!["varchar(255)", "int(255)", "varchar", "int"], + "%(%)%", + nilike_utf8_scalar, + vec![false, false, true, true] + ); + 
test_utf8_scalar!( + test_utf8_array_nilike_scalar, + vec!["arrow", "parquet", "datafusion", "flight"], + "%AR%", + nilike_utf8_scalar, + vec![false, false, true, true] + ); + + test_utf8_scalar!( + test_utf8_array_nilike_scalar_start, + vec!["arrow", "parrow", "arrows", "ARR"], + "aRRow%", + nilike_utf8_scalar, + vec![false, true, false, true] + ); + + test_utf8_scalar!( + test_utf8_array_nilike_scalar_end, + vec!["ArroW", "parrow", "ARRowS", "arr"], + "%arrow", + nilike_utf8_scalar, + vec![false, false, true, true] + ); + + test_utf8_scalar!( + test_utf8_array_nilike_scalar_equals, + vec!["arrow", "parrow", "arrows", "arr"], + "arrow", + nilike_utf8_scalar, + vec![false, true, true, true] + ); + + test_utf8_scalar!( + test_utf8_array_nilike_scalar_one, + vec!["arrow", "arrows", "parrow", "arr"], + "arrow_", + nilike_utf8_scalar, + vec![true, false, true, true] + ); + test_utf8!( test_utf8_array_neq, vec!["arrow", "arrow", "arrow", "arrow"], @@ -2302,4 +4233,818 @@ mod tests { regexp_is_match_utf8_scalar, vec![true, true, false, false] ); + + #[test] + fn test_eq_dyn_scalar() { + let array = Int32Array::from(vec![6, 7, 8, 8, 10]); + let a_eq = eq_dyn_scalar(&array, 8).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from( + vec![Some(false), Some(false), Some(true), Some(true), Some(false)] + ) + ); + } + + #[test] + fn test_eq_dyn_scalar_with_dict() { + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = PrimitiveBuilder::::new(2); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(123).unwrap(); + builder.append_null().unwrap(); + builder.append(23).unwrap(); + let array = builder.finish(); + let a_eq = eq_dyn_scalar(&array, 123).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from(vec![Some(true), None, Some(false)]) + ); + } + + #[test] + fn test_eq_dyn_scalar_float() { + let array: Float32Array = vec![6.0, 7.0, 8.0, 8.0, 10.0] + .into_iter() + .map(Some) + .collect(); + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(true), Some(true), Some(false)], + ); + assert_eq!(eq_dyn_scalar(&array, 8).unwrap(), expected); + + let array: ArrayRef = Arc::new(array); + let array = crate::compute::cast(&array, &DataType::Float64).unwrap(); + assert_eq!(eq_dyn_scalar(&array, 8).unwrap(), expected); + } + + #[test] + fn test_lt_dyn_scalar() { + let array = Int32Array::from(vec![6, 7, 8, 8, 10]); + let a_eq = lt_dyn_scalar(&array, 8).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from( + vec![Some(true), Some(true), Some(false), Some(false), Some(false)] + ) + ); + } + + #[test] + fn test_lt_dyn_scalar_with_dict() { + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = PrimitiveBuilder::::new(2); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(123).unwrap(); + builder.append_null().unwrap(); + builder.append(23).unwrap(); + let array = builder.finish(); + let a_eq = lt_dyn_scalar(&array, 123).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from(vec![Some(false), None, Some(true)]) + ); + } + + #[test] + fn test_lt_dyn_scalar_float() { + let array: Float32Array = vec![6.0, 7.0, 8.0, 8.0, 10.0] + .into_iter() + .map(Some) + .collect(); + let expected = BooleanArray::from( + vec![Some(true), Some(true), Some(false), Some(false), Some(false)], + ); + assert_eq!(lt_dyn_scalar(&array, 8).unwrap(), expected); + + let array: ArrayRef = Arc::new(array); + let array = crate::compute::cast(&array, &DataType::Float64).unwrap(); + 
assert_eq!(lt_dyn_scalar(&array, 8).unwrap(), expected); + } + + #[test] + fn test_lt_eq_dyn_scalar() { + let array = Int32Array::from(vec![6, 7, 8, 8, 10]); + let a_eq = lt_eq_dyn_scalar(&array, 8).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from( + vec![Some(true), Some(true), Some(true), Some(true), Some(false)] + ) + ); + } + #[test] + fn test_lt_eq_dyn_scalar_with_dict() { + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = PrimitiveBuilder::::new(2); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(123).unwrap(); + builder.append_null().unwrap(); + builder.append(23).unwrap(); + let array = builder.finish(); + let a_eq = lt_eq_dyn_scalar(&array, 23).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from(vec![Some(false), None, Some(true)]) + ); + } + + #[test] + fn test_lt_eq_dyn_scalar_float() { + let array: Float32Array = vec![6.0, 7.0, 8.0, 8.0, 10.0] + .into_iter() + .map(Some) + .collect(); + let expected = BooleanArray::from( + vec![Some(true), Some(true), Some(true), Some(true), Some(false)], + ); + assert_eq!(lt_eq_dyn_scalar(&array, 8).unwrap(), expected); + + let array: ArrayRef = Arc::new(array); + let array = crate::compute::cast(&array, &DataType::Float64).unwrap(); + assert_eq!(lt_eq_dyn_scalar(&array, 8).unwrap(), expected); + } + + #[test] + fn test_gt_dyn_scalar() { + let array = Int32Array::from(vec![6, 7, 8, 8, 10]); + let a_eq = gt_dyn_scalar(&array, 8).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from( + vec![Some(false), Some(false), Some(false), Some(false), Some(true)] + ) + ); + } + + #[test] + fn test_gt_dyn_scalar_with_dict() { + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = PrimitiveBuilder::::new(2); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(123).unwrap(); + builder.append_null().unwrap(); + builder.append(23).unwrap(); + let array = builder.finish(); + let a_eq = gt_dyn_scalar(&array, 23).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from(vec![Some(true), None, Some(false)]) + ); + } + + #[test] + fn test_gt_dyn_scalar_float() { + let array: Float32Array = vec![6.0, 7.0, 8.0, 8.0, 10.0] + .into_iter() + .map(Some) + .collect(); + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(false), Some(false), Some(true)], + ); + assert_eq!(gt_dyn_scalar(&array, 8).unwrap(), expected); + + let array: ArrayRef = Arc::new(array); + let array = crate::compute::cast(&array, &DataType::Float64).unwrap(); + assert_eq!(gt_dyn_scalar(&array, 8).unwrap(), expected); + } + + #[test] + fn test_gt_eq_dyn_scalar() { + let array = Int32Array::from(vec![6, 7, 8, 8, 10]); + let a_eq = gt_eq_dyn_scalar(&array, 8).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from( + vec![Some(false), Some(false), Some(true), Some(true), Some(true)] + ) + ); + } + + #[test] + fn test_gt_eq_dyn_scalar_with_dict() { + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = PrimitiveBuilder::::new(2); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(22).unwrap(); + builder.append_null().unwrap(); + builder.append(23).unwrap(); + let array = builder.finish(); + let a_eq = gt_eq_dyn_scalar(&array, 23).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from(vec![Some(false), None, Some(true)]) + ); + } + + #[test] + fn test_gt_eq_dyn_scalar_float() { + let array: Float32Array = vec![6.0, 7.0, 8.0, 8.0, 10.0] + .into_iter() + .map(Some) + .collect(); + let expected = 
BooleanArray::from( + vec![Some(false), Some(false), Some(true), Some(true), Some(true)], + ); + assert_eq!(gt_eq_dyn_scalar(&array, 8).unwrap(), expected); + + let array: ArrayRef = Arc::new(array); + let array = crate::compute::cast(&array, &DataType::Float64).unwrap(); + assert_eq!(gt_eq_dyn_scalar(&array, 8).unwrap(), expected); + } + + #[test] + fn test_neq_dyn_scalar() { + let array = Int32Array::from(vec![6, 7, 8, 8, 10]); + let a_eq = neq_dyn_scalar(&array, 8).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from( + vec![Some(true), Some(true), Some(false), Some(false), Some(true)] + ) + ); + } + + #[test] + fn test_neq_dyn_scalar_with_dict() { + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = PrimitiveBuilder::::new(2); + let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); + builder.append(22).unwrap(); + builder.append_null().unwrap(); + builder.append(23).unwrap(); + let array = builder.finish(); + let a_eq = neq_dyn_scalar(&array, 23).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from(vec![Some(true), None, Some(false)]) + ); + } + + #[test] + fn test_neq_dyn_scalar_float() { + let array: Float32Array = vec![6.0, 7.0, 8.0, 8.0, 10.0] + .into_iter() + .map(Some) + .collect(); + let expected = BooleanArray::from( + vec![Some(true), Some(true), Some(false), Some(false), Some(true)], + ); + assert_eq!(neq_dyn_scalar(&array, 8).unwrap(), expected); + + let array: ArrayRef = Arc::new(array); + let array = crate::compute::cast(&array, &DataType::Float64).unwrap(); + assert_eq!(neq_dyn_scalar(&array, 8).unwrap(), expected); + } + + #[test] + fn test_eq_dyn_binary_scalar() { + let data: Vec> = vec![Some(b"arrow"), Some(b"datafusion"), Some(b"flight"), Some(b"parquet"), Some(&[0xff, 0xf8]), None]; + let array = BinaryArray::from(data.clone()); + let large_array = LargeBinaryArray::from(data); + let scalar = "flight".as_bytes(); + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(true), Some(false), Some(false), None], + ); + + assert_eq!(eq_dyn_binary_scalar(&array, scalar).unwrap(), expected); + assert_eq!( + eq_dyn_binary_scalar(&large_array, scalar).unwrap(), + expected + ); + } + + #[test] + fn test_neq_dyn_binary_scalar() { + let data: Vec> = vec![Some(b"arrow"), Some(b"datafusion"), Some(b"flight"), Some(b"parquet"), Some(&[0xff, 0xf8]), None]; + let array = BinaryArray::from(data.clone()); + let large_array = LargeBinaryArray::from(data); + let scalar = "flight".as_bytes(); + let expected = BooleanArray::from( + vec![Some(true), Some(true), Some(false), Some(true), Some(true), None], + ); + + assert_eq!(neq_dyn_binary_scalar(&array, scalar).unwrap(), expected); + assert_eq!( + neq_dyn_binary_scalar(&large_array, scalar).unwrap(), + expected + ); + } + + #[test] + fn test_lt_dyn_binary_scalar() { + let data: Vec> = vec![Some(b"arrow"), Some(b"datafusion"), Some(b"flight"), Some(b"parquet"), Some(&[0xff, 0xf8]), None]; + let array = BinaryArray::from(data.clone()); + let large_array = LargeBinaryArray::from(data); + let scalar = "flight".as_bytes(); + let expected = BooleanArray::from( + vec![Some(true), Some(true), Some(false), Some(false), Some(false), None], + ); + + assert_eq!(lt_dyn_binary_scalar(&array, scalar).unwrap(), expected); + assert_eq!( + lt_dyn_binary_scalar(&large_array, scalar).unwrap(), + expected + ); + } + + #[test] + fn test_lt_eq_dyn_binary_scalar() { + let data: Vec> = vec![Some(b"arrow"), Some(b"datafusion"), Some(b"flight"), Some(b"parquet"), Some(&[0xff, 0xf8]), None]; + let array = 
BinaryArray::from(data.clone()); + let large_array = LargeBinaryArray::from(data); + let scalar = "flight".as_bytes(); + let expected = BooleanArray::from( + vec![Some(true), Some(true), Some(true), Some(false), Some(false), None], + ); + + assert_eq!(lt_eq_dyn_binary_scalar(&array, scalar).unwrap(), expected); + assert_eq!( + lt_eq_dyn_binary_scalar(&large_array, scalar).unwrap(), + expected + ); + } + + #[test] + fn test_gt_dyn_binary_scalar() { + let data: Vec> = vec![Some(b"arrow"), Some(b"datafusion"), Some(b"flight"), Some(b"parquet"), Some(&[0xff, 0xf8]), None]; + let array = BinaryArray::from(data.clone()); + let large_array = LargeBinaryArray::from(data); + let scalar = "flight".as_bytes(); + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(false), Some(true), Some(true), None], + ); + + assert_eq!(gt_dyn_binary_scalar(&array, scalar).unwrap(), expected); + assert_eq!( + gt_dyn_binary_scalar(&large_array, scalar).unwrap(), + expected + ); + } + + #[test] + fn test_gt_eq_dyn_binary_scalar() { + let data: Vec> = vec![Some(b"arrow"), Some(b"datafusion"), Some(b"flight"), Some(b"parquet"), Some(&[0xff, 0xf8]), None]; + let array = BinaryArray::from(data.clone()); + let large_array = LargeBinaryArray::from(data); + let scalar = &[0xff, 0xf8]; + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(false), Some(false), Some(true), None], + ); + + assert_eq!(gt_eq_dyn_binary_scalar(&array, scalar).unwrap(), expected); + assert_eq!( + gt_eq_dyn_binary_scalar(&large_array, scalar).unwrap(), + expected + ); + } + + #[test] + fn test_eq_dyn_utf8_scalar() { + let array = StringArray::from(vec!["abc", "def", "xyz"]); + let a_eq = eq_dyn_utf8_scalar(&array, "xyz").unwrap(); + assert_eq!( + a_eq, + BooleanArray::from(vec![Some(false), Some(false), Some(true)]) + ); + } + + #[test] + fn test_eq_dyn_utf8_scalar_with_dict() { + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = StringBuilder::new(100); + let mut builder = StringDictionaryBuilder::new(key_builder, value_builder); + builder.append("abc").unwrap(); + builder.append_null().unwrap(); + builder.append("def").unwrap(); + builder.append("def").unwrap(); + builder.append("abc").unwrap(); + let array = builder.finish(); + let a_eq = eq_dyn_utf8_scalar(&array, "def").unwrap(); + assert_eq!( + a_eq, + BooleanArray::from( + vec![Some(false), None, Some(true), Some(true), Some(false)] + ) + ); + } + #[test] + fn test_lt_dyn_utf8_scalar() { + let array = StringArray::from(vec!["abc", "def", "xyz"]); + let a_eq = lt_dyn_utf8_scalar(&array, "xyz").unwrap(); + assert_eq!( + a_eq, + BooleanArray::from(vec![Some(true), Some(true), Some(false)]) + ); + } + #[test] + fn test_lt_dyn_utf8_scalar_with_dict() { + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = StringBuilder::new(100); + let mut builder = StringDictionaryBuilder::new(key_builder, value_builder); + builder.append("abc").unwrap(); + builder.append_null().unwrap(); + builder.append("def").unwrap(); + builder.append("def").unwrap(); + builder.append("abc").unwrap(); + let array = builder.finish(); + let a_eq = lt_dyn_utf8_scalar(&array, "def").unwrap(); + assert_eq!( + a_eq, + BooleanArray::from( + vec![Some(true), None, Some(false), Some(false), Some(true)] + ) + ); + } + + #[test] + fn test_lt_eq_dyn_utf8_scalar() { + let array = StringArray::from(vec!["abc", "def", "xyz"]); + let a_eq = lt_eq_dyn_utf8_scalar(&array, "def").unwrap(); + assert_eq!( + a_eq, + BooleanArray::from(vec![Some(true), Some(true), 
Some(false)]) + ); + } + #[test] + fn test_lt_eq_dyn_utf8_scalar_with_dict() { + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = StringBuilder::new(100); + let mut builder = StringDictionaryBuilder::new(key_builder, value_builder); + builder.append("abc").unwrap(); + builder.append_null().unwrap(); + builder.append("def").unwrap(); + builder.append("def").unwrap(); + builder.append("xyz").unwrap(); + let array = builder.finish(); + let a_eq = lt_eq_dyn_utf8_scalar(&array, "def").unwrap(); + assert_eq!( + a_eq, + BooleanArray::from( + vec![Some(true), None, Some(true), Some(true), Some(false)] + ) + ); + } + + #[test] + fn test_gt_eq_dyn_utf8_scalar() { + let array = StringArray::from(vec!["abc", "def", "xyz"]); + let a_eq = gt_eq_dyn_utf8_scalar(&array, "def").unwrap(); + assert_eq!( + a_eq, + BooleanArray::from(vec![Some(false), Some(true), Some(true)]) + ); + } + #[test] + fn test_gt_eq_dyn_utf8_scalar_with_dict() { + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = StringBuilder::new(100); + let mut builder = StringDictionaryBuilder::new(key_builder, value_builder); + builder.append("abc").unwrap(); + builder.append_null().unwrap(); + builder.append("def").unwrap(); + builder.append("def").unwrap(); + builder.append("xyz").unwrap(); + let array = builder.finish(); + let a_eq = gt_eq_dyn_utf8_scalar(&array, "def").unwrap(); + assert_eq!( + a_eq, + BooleanArray::from( + vec![Some(false), None, Some(true), Some(true), Some(true)] + ) + ); + } + + #[test] + fn test_gt_dyn_utf8_scalar() { + let array = StringArray::from(vec!["abc", "def", "xyz"]); + let a_eq = gt_dyn_utf8_scalar(&array, "def").unwrap(); + assert_eq!( + a_eq, + BooleanArray::from(vec![Some(false), Some(false), Some(true)]) + ); + } + + #[test] + fn test_gt_dyn_utf8_scalar_with_dict() { + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = StringBuilder::new(100); + let mut builder = StringDictionaryBuilder::new(key_builder, value_builder); + builder.append("abc").unwrap(); + builder.append_null().unwrap(); + builder.append("def").unwrap(); + builder.append("def").unwrap(); + builder.append("xyz").unwrap(); + let array = builder.finish(); + let a_eq = gt_dyn_utf8_scalar(&array, "def").unwrap(); + assert_eq!( + a_eq, + BooleanArray::from( + vec![Some(false), None, Some(false), Some(false), Some(true)] + ) + ); + } + + #[test] + fn test_neq_dyn_utf8_scalar() { + let array = StringArray::from(vec!["abc", "def", "xyz"]); + let a_eq = neq_dyn_utf8_scalar(&array, "xyz").unwrap(); + assert_eq!( + a_eq, + BooleanArray::from(vec![Some(true), Some(true), Some(false)]) + ); + } + #[test] + fn test_neq_dyn_utf8_scalar_with_dict() { + let key_builder = PrimitiveBuilder::::new(3); + let value_builder = StringBuilder::new(100); + let mut builder = StringDictionaryBuilder::new(key_builder, value_builder); + builder.append("abc").unwrap(); + builder.append_null().unwrap(); + builder.append("def").unwrap(); + builder.append("def").unwrap(); + builder.append("abc").unwrap(); + let array = builder.finish(); + let a_eq = neq_dyn_utf8_scalar(&array, "def").unwrap(); + assert_eq!( + a_eq, + BooleanArray::from( + vec![Some(true), None, Some(false), Some(false), Some(true)] + ) + ); + } + + #[test] + fn test_eq_dyn_bool_scalar() { + let array = BooleanArray::from(vec![true, false, true]); + let a_eq = eq_dyn_bool_scalar(&array, false).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from(vec![Some(false), Some(true), Some(false)]) + ); + } + + #[test] + fn test_lt_dyn_bool_scalar() { + let array = 
BooleanArray::from(vec![Some(true), Some(false), Some(true), None]); + let a_eq = lt_dyn_bool_scalar(&array, false).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from(vec![Some(false), Some(false), Some(false), None]) + ); + } + + #[test] + fn test_gt_dyn_bool_scalar() { + let array = BooleanArray::from(vec![true, false, true]); + let a_eq = gt_dyn_bool_scalar(&array, false).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from(vec![Some(true), Some(false), Some(true)]) + ); + } + + #[test] + fn test_lt_eq_dyn_bool_scalar() { + let array = BooleanArray::from(vec![true, false, true]); + let a_eq = lt_eq_dyn_bool_scalar(&array, false).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from(vec![Some(false), Some(true), Some(false)]) + ); + } + + #[test] + fn test_gt_eq_dyn_bool_scalar() { + let array = BooleanArray::from(vec![true, false, true]); + let a_eq = gt_eq_dyn_bool_scalar(&array, false).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from(vec![Some(true), Some(true), Some(true)]) + ); + } + + #[test] + fn test_neq_dyn_bool_scalar() { + let array = BooleanArray::from(vec![true, false, true]); + let a_eq = neq_dyn_bool_scalar(&array, false).unwrap(); + assert_eq!( + a_eq, + BooleanArray::from(vec![Some(true), Some(false), Some(true)]) + ); + } + + #[test] + fn test_eq_dyn_neq_dyn_dictionary_i8_array() { + // Construct a value array + let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); + + let keys1 = Int8Array::from_iter_values([2_i8, 3, 4]); + let keys2 = Int8Array::from_iter_values([2_i8, 4, 4]); + let dict_array1 = DictionaryArray::try_new(&keys1, &values).unwrap(); + let dict_array2 = DictionaryArray::try_new(&keys2, &values).unwrap(); + + let result = eq_dyn(&dict_array1, &dict_array2); + assert_eq!(result.unwrap(), BooleanArray::from(vec![true, false, true])); + + let result = neq_dyn(&dict_array1, &dict_array2); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![false, true, false]) + ); + } + + #[test] + fn test_eq_dyn_neq_dyn_dictionary_u64_array() { + let values = UInt64Array::from_iter_values([10_u64, 11, 12, 13, 14, 15, 16, 17]); + + let keys1 = UInt64Array::from_iter_values([1_u64, 3, 4]); + let keys2 = UInt64Array::from_iter_values([2_u64, 3, 5]); + let dict_array1 = + DictionaryArray::::try_new(&keys1, &values).unwrap(); + let dict_array2 = + DictionaryArray::::try_new(&keys2, &values).unwrap(); + + let result = eq_dyn(&dict_array1, &dict_array2); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![false, true, false]) + ); + + let result = neq_dyn(&dict_array1, &dict_array2); + assert_eq!(result.unwrap(), BooleanArray::from(vec![true, false, true])); + } + + #[test] + fn test_eq_dyn_neq_dyn_dictionary_utf8_array() { + let test1 = vec!["a", "a", "b", "c"]; + let test2 = vec!["a", "b", "b", "c"]; + + let dict_array1: DictionaryArray = test1 + .iter() + .map(|&x| if x == "b" { None } else { Some(x) }) + .collect(); + let dict_array2: DictionaryArray = test2 + .iter() + .map(|&x| if x == "b" { None } else { Some(x) }) + .collect(); + + let result = eq_dyn(&dict_array1, &dict_array2); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(true), None, None, Some(true)]) + ); + + let result = neq_dyn(&dict_array1, &dict_array2); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(false), None, None, Some(false)]) + ); + } + + #[test] + fn test_eq_dyn_neq_dyn_dictionary_binary_array() { + let values: BinaryArray = ["hello", "", "parquet"] + .into_iter() + .map(|b| Some(b.as_bytes())) + .collect(); + + let keys1 = 
UInt64Array::from_iter_values([0_u64, 1, 2]); + let keys2 = UInt64Array::from_iter_values([0_u64, 2, 1]); + let dict_array1 = + DictionaryArray::::try_new(&keys1, &values).unwrap(); + let dict_array2 = + DictionaryArray::::try_new(&keys2, &values).unwrap(); + + let result = eq_dyn(&dict_array1, &dict_array2); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![true, false, false]) + ); + + let result = neq_dyn(&dict_array1, &dict_array2); + assert_eq!(result.unwrap(), BooleanArray::from(vec![false, true, true])); + } + + #[test] + fn test_eq_dyn_neq_dyn_dictionary_interval_array() { + let values = IntervalDayTimeArray::from(vec![1, 6, 10, 2, 3, 5]); + + let keys1 = UInt64Array::from_iter_values([1_u64, 0, 3]); + let keys2 = UInt64Array::from_iter_values([2_u64, 0, 3]); + let dict_array1 = + DictionaryArray::::try_new(&keys1, &values).unwrap(); + let dict_array2 = + DictionaryArray::::try_new(&keys2, &values).unwrap(); + + let result = eq_dyn(&dict_array1, &dict_array2); + assert_eq!(result.unwrap(), BooleanArray::from(vec![false, true, true])); + + let result = neq_dyn(&dict_array1, &dict_array2); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![true, false, false]) + ); + } + + #[test] + fn test_eq_dyn_neq_dyn_dictionary_date_array() { + let values = Date32Array::from(vec![1, 6, 10, 2, 3, 5]); + + let keys1 = UInt64Array::from_iter_values([1_u64, 0, 3]); + let keys2 = UInt64Array::from_iter_values([2_u64, 0, 3]); + let dict_array1 = + DictionaryArray::::try_new(&keys1, &values).unwrap(); + let dict_array2 = + DictionaryArray::::try_new(&keys2, &values).unwrap(); + + let result = eq_dyn(&dict_array1, &dict_array2); + assert_eq!(result.unwrap(), BooleanArray::from(vec![false, true, true])); + + let result = neq_dyn(&dict_array1, &dict_array2); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![true, false, false]) + ); + } + + #[test] + fn test_eq_dyn_neq_dyn_dictionary_bool_array() { + let values = BooleanArray::from(vec![true, false]); + + let keys1 = UInt64Array::from_iter_values([1_u64, 1, 1]); + let keys2 = UInt64Array::from_iter_values([0_u64, 1, 0]); + let dict_array1 = + DictionaryArray::::try_new(&keys1, &values).unwrap(); + let dict_array2 = + DictionaryArray::::try_new(&keys2, &values).unwrap(); + + let result = eq_dyn(&dict_array1, &dict_array2); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![false, true, false]) + ); + + let result = neq_dyn(&dict_array1, &dict_array2); + assert_eq!(result.unwrap(), BooleanArray::from(vec![true, false, true])); + } + + #[test] + fn test_lt_dyn_gt_dyn_dictionary_i8_array() { + // Construct a value array + let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); + + let keys1 = Int8Array::from_iter_values([3_i8, 4, 4]); + let keys2 = Int8Array::from_iter_values([4_i8, 3, 4]); + let dict_array1 = DictionaryArray::try_new(&keys1, &values).unwrap(); + let dict_array2 = DictionaryArray::try_new(&keys2, &values).unwrap(); + + let result = lt_dyn(&dict_array1, &dict_array2); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![true, false, false]) + ); + + let result = lt_eq_dyn(&dict_array1, &dict_array2); + assert_eq!(result.unwrap(), BooleanArray::from(vec![true, false, true])); + + let result = gt_dyn(&dict_array1, &dict_array2); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![false, true, false]) + ); + + let result = gt_eq_dyn(&dict_array1, &dict_array2); + assert_eq!(result.unwrap(), BooleanArray::from(vec![false, true, true])); + } + + #[test] + fn 
test_lt_dyn_gt_dyn_dictionary_bool_array() { + let values = BooleanArray::from(vec![true, false]); + + let keys1 = UInt64Array::from_iter_values([1_u64, 1, 0]); + let keys2 = UInt64Array::from_iter_values([0_u64, 1, 1]); + let dict_array1 = + DictionaryArray::::try_new(&keys1, &values).unwrap(); + let dict_array2 = + DictionaryArray::::try_new(&keys2, &values).unwrap(); + + let result = lt_dyn(&dict_array1, &dict_array2); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![true, false, false]) + ); + + let result = lt_eq_dyn(&dict_array1, &dict_array2); + assert_eq!(result.unwrap(), BooleanArray::from(vec![true, true, false])); + + let result = gt_dyn(&dict_array1, &dict_array2); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![false, false, true]) + ); + + let result = gt_eq_dyn(&dict_array1, &dict_array2); + assert_eq!(result.unwrap(), BooleanArray::from(vec![false, true, true])); + } } diff --git a/arrow/src/compute/kernels/concat.rs b/arrow/src/compute/kernels/concat.rs index 303c919e0c31..a7a3ffc782c6 100644 --- a/arrow/src/compute/kernels/concat.rs +++ b/arrow/src/compute/kernels/concat.rs @@ -34,9 +34,7 @@ use crate::array::*; use crate::datatypes::DataType; use crate::error::{ArrowError, Result}; -fn compute_str_values_length( - arrays: &[&ArrayData], -) -> usize { +fn compute_str_values_length(arrays: &[&ArrayData]) -> usize { arrays .iter() .map(|&data| { @@ -525,4 +523,50 @@ mod tests { Ok(()) } + + #[test] + fn test_dictionary_concat_reuse() { + let array: DictionaryArray = + vec!["a", "a", "b", "c"].into_iter().collect(); + let copy: DictionaryArray = array.data().clone().into(); + + // dictionary is "a", "b", "c" + assert_eq!( + array.values(), + &(Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef) + ); + assert_eq!(array.keys(), &Int8Array::from(vec![0, 0, 1, 2])); + + // concatenate it with itself + let combined = concat(&[© as _, &array as _]).unwrap(); + + let combined = combined + .as_any() + .downcast_ref::>() + .unwrap(); + + assert_eq!( + combined.values(), + &(Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef), + "Actual: {:#?}", + combined + ); + + assert_eq!( + combined.keys(), + &Int8Array::from(vec![0, 0, 1, 2, 0, 0, 1, 2]) + ); + + // Should have reused the dictionary + assert!(array.data().child_data()[0].ptr_eq(&combined.data().child_data()[0])); + assert!(copy.data().child_data()[0].ptr_eq(&combined.data().child_data()[0])); + + let new: DictionaryArray = vec!["d"].into_iter().collect(); + let combined = concat(&[© as _, &array as _, &new as _]).unwrap(); + + // Should not have reused the dictionary + assert!(!array.data().child_data()[0].ptr_eq(&combined.data().child_data()[0])); + assert!(!copy.data().child_data()[0].ptr_eq(&combined.data().child_data()[0])); + assert!(!new.data().child_data()[0].ptr_eq(&combined.data().child_data()[0])); + } } diff --git a/arrow/src/compute/kernels/concat_elements.rs b/arrow/src/compute/kernels/concat_elements.rs new file mode 100644 index 000000000000..bc341df889c0 --- /dev/null +++ b/arrow/src/compute/kernels/concat_elements.rs @@ -0,0 +1,193 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::array::*; +use crate::compute::util::combine_option_bitmap; +use crate::error::{ArrowError, Result}; + +/// Returns the elementwise concatenation of a [`StringArray`]. +/// +/// An index of the resulting [`StringArray`] is null if any of +/// `StringArray` are null at that location. +/// +/// ```text +/// e.g: +/// +/// ["Hello"] + ["World"] = ["HelloWorld"] +/// +/// ["a", "b"] + [None, "c"] = [None, "bc"] +/// ``` +/// +/// An error will be returned if `left` and `right` have different lengths +pub fn concat_elements_utf8( + left: &GenericStringArray, + right: &GenericStringArray, +) -> Result> { + if left.len() != right.len() { + return Err(ArrowError::ComputeError(format!( + "Arrays must have the same length: {} != {}", + left.len(), + right.len() + ))); + } + + let output_bitmap = combine_option_bitmap(&[left.data(), right.data()], left.len())?; + + let left_offsets = left.value_offsets(); + let right_offsets = right.value_offsets(); + + let left_buffer = left.value_data(); + let right_buffer = right.value_data(); + let left_values = left_buffer.as_slice(); + let right_values = right_buffer.as_slice(); + + let mut output_values = BufferBuilder::::new( + left_values.len() + right_values.len() + - left_offsets[0].to_usize().unwrap() + - right_offsets[0].to_usize().unwrap(), + ); + + let mut output_offsets = BufferBuilder::::new(left_offsets.len()); + output_offsets.append(Offset::zero()); + for (left_idx, right_idx) in left_offsets.windows(2).zip(right_offsets.windows(2)) { + output_values.append_slice( + &left_values + [left_idx[0].to_usize().unwrap()..left_idx[1].to_usize().unwrap()], + ); + output_values.append_slice( + &right_values + [right_idx[0].to_usize().unwrap()..right_idx[1].to_usize().unwrap()], + ); + output_offsets.append(Offset::from_usize(output_values.len()).unwrap()); + } + + let builder = ArrayDataBuilder::new(GenericStringArray::::get_data_type()) + .len(left.len()) + .add_buffer(output_offsets.finish()) + .add_buffer(output_values.finish()) + .null_bit_buffer(output_bitmap); + + // SAFETY - offsets valid by construction + Ok(unsafe { builder.build_unchecked() }.into()) +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_string_concat() { + let left = [Some("foo"), Some("bar"), None] + .into_iter() + .collect::(); + let right = [None, Some("yyy"), Some("zzz")] + .into_iter() + .collect::(); + + let output = concat_elements_utf8(&left, &right).unwrap(); + + let expected = [None, Some("baryyy"), None] + .into_iter() + .collect::(); + + assert_eq!(output, expected); + } + + #[test] + fn test_string_concat_empty_string() { + let left = [Some("foo"), Some(""), Some("bar")] + .into_iter() + .collect::(); + let right = [Some("baz"), Some(""), Some("")] + .into_iter() + .collect::(); + + let output = concat_elements_utf8(&left, &right).unwrap(); + + let expected = [Some("foobaz"), Some(""), Some("bar")] + .into_iter() + .collect::(); + + assert_eq!(output, expected); + } + + #[test] + fn test_string_concat_no_null() { + let left = StringArray::from(vec!["foo", "bar"]); + let right = StringArray::from(vec!["bar", "baz"]); + + let 
output = concat_elements_utf8(&left, &right).unwrap(); + + let expected = StringArray::from(vec!["foobar", "barbaz"]); + + assert_eq!(output, expected); + } + + #[test] + fn test_string_concat_error() { + let left = StringArray::from(vec!["foo", "bar"]); + let right = StringArray::from(vec!["baz"]); + + let output = concat_elements_utf8(&left, &right); + + assert!(output.is_err()); + } + + #[test] + fn test_string_concat_slice() { + let left = &StringArray::from(vec![None, Some("foo"), Some("bar"), Some("baz")]); + let right = &StringArray::from(vec![Some("boo"), None, Some("far"), Some("faz")]); + + let left_slice = left.slice(0, 3); + let right_slice = right.slice(1, 3); + let output = concat_elements_utf8( + left_slice + .as_any() + .downcast_ref::>() + .unwrap(), + right_slice + .as_any() + .downcast_ref::>() + .unwrap(), + ) + .unwrap(); + + let expected = [None, Some("foofar"), Some("barfaz")] + .into_iter() + .collect::(); + + assert_eq!(output, expected); + + let left_slice = left.slice(2, 2); + let right_slice = right.slice(1, 2); + + let output = concat_elements_utf8( + left_slice + .as_any() + .downcast_ref::>() + .unwrap(), + right_slice + .as_any() + .downcast_ref::>() + .unwrap(), + ) + .unwrap(); + + let expected = [None, Some("bazfar")].into_iter().collect::(); + + assert_eq!(output, expected); + } +} diff --git a/arrow/src/compute/kernels/filter.rs b/arrow/src/compute/kernels/filter.rs index 61a73d0d64bf..1af93bff5ad7 100644 --- a/arrow/src/compute/kernels/filter.rs +++ b/arrow/src/compute/kernels/filter.rs @@ -17,184 +17,144 @@ //! Defines miscellaneous array kernels. -use crate::buffer::buffer_bin_and; -use crate::datatypes::DataType; -use crate::error::Result; +use std::ops::AddAssign; +use std::sync::Arc; + +use num::Zero; + +use TimeUnit::*; + +use crate::array::*; +use crate::buffer::{buffer_bin_and, Buffer, MutableBuffer}; +use crate::datatypes::*; +use crate::error::{ArrowError, Result}; use crate::record_batch::RecordBatch; -use crate::{array::*, util::bit_chunk_iterator::BitChunkIterator}; -use std::iter::Enumerate; +use crate::util::bit_iterator::{BitIndexIterator, BitSliceIterator}; +use crate::util::bit_util; -/// Function that can filter arbitrary arrays -pub type Filter<'a> = Box ArrayData + 'a>; +/// If the filter selects more than this fraction of rows, use +/// [`SlicesIterator`] to copy ranges of values. Otherwise iterate +/// over individual rows using [`IndexIterator`] +/// +/// Threshold of 0.8 chosen based on +/// +const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8; -/// Internal state of [SlicesIterator] -#[derive(Debug, PartialEq)] -enum State { - // it is iterating over bits of a mask (`u64`, steps of size of 1 slot) - Bits(u64), - // it is iterating over chunks (steps of size of 64 slots) - Chunks, - // it is iterating over the remainding bits (steps of size of 1 slot) - Remainder, - // nothing more to iterate. - Finish, +macro_rules! downcast_filter { + ($type: ty, $values: expr, $filter: expr) => {{ + let values = $values + .as_any() + .downcast_ref::>() + .expect("Unable to downcast to a primitive array"); + + Ok(Arc::new(filter_primitive::<$type>(&values, $filter))) + }}; +} + +macro_rules! 
downcast_dict_filter { + ($type: ty, $values: expr, $filter: expr) => {{ + let values = $values + .as_any() + .downcast_ref::>() + .expect("Unable to downcast to a dictionary array"); + Ok(Arc::new(filter_dict::<$type>(values, $filter))) + }}; } -/// An iterator of `(usize, usize)` each representing an interval `[start,end[` whose -/// slots of a [BooleanArray] are true. Each interval corresponds to a contiguous region of memory to be +/// An iterator of `(usize, usize)` each representing an interval +/// `[start, end)` whose slots of a [BooleanArray] are true. Each +/// interval corresponds to a contiguous region of memory to be /// "taken" from an array to be filtered. +/// +/// ## Notes: +/// +/// 1. Ignores the validity bitmap (ignores nulls) +/// +/// 2. Only performant for filters that copy across long contiguous runs #[derive(Debug)] -pub struct SlicesIterator<'a> { - iter: Enumerate>, - state: State, - filter: &'a BooleanArray, - remainder_mask: u64, - remainder_len: usize, - chunk_len: usize, - len: usize, - start: usize, - on_region: bool, - current_chunk: usize, - current_bit: usize, -} +pub struct SlicesIterator<'a>(BitSliceIterator<'a>); impl<'a> SlicesIterator<'a> { pub fn new(filter: &'a BooleanArray) -> Self { let values = &filter.data_ref().buffers()[0]; - let chunks = values.bit_chunks(filter.offset(), filter.len()); + let len = filter.len(); + let offset = filter.offset(); - Self { - iter: chunks.iter().enumerate(), - state: State::Chunks, - filter, - remainder_len: chunks.remainder_len(), - chunk_len: chunks.chunk_len(), - remainder_mask: chunks.remainder_bits(), - len: 0, - start: 0, - on_region: false, - current_chunk: 0, - current_bit: 0, - } + Self(BitSliceIterator::new(values, offset, len)) } +} - /// Counts the number of set bits in the filter array. - fn filter_count(&self) -> usize { - let values = self.filter.values(); - values.count_set_bits_offset(self.filter.offset(), self.filter.len()) - } +impl<'a> Iterator for SlicesIterator<'a> { + type Item = (usize, usize); - #[inline] - fn current_start(&self) -> usize { - self.current_chunk * 64 + self.current_bit + fn next(&mut self) -> Option { + self.0.next() } +} - #[inline] - fn iterate_bits(&mut self, mask: u64, max: usize) -> Option<(usize, usize)> { - while self.current_bit < max { - if (mask & (1 << self.current_bit)) != 0 { - if !self.on_region { - self.start = self.current_start(); - self.on_region = true; - } - self.len += 1; - } else if self.on_region { - let result = (self.start, self.start + self.len); - self.len = 0; - self.on_region = false; - self.current_bit += 1; - return Some(result); - } - self.current_bit += 1; - } - self.current_bit = 0; - None - } +/// An iterator of `usize` whose index in [`BooleanArray`] is true +/// +/// This provides the best performance on most predicates, apart from those which keep +/// large runs and therefore favour [`SlicesIterator`] +struct IndexIterator<'a> { + remaining: usize, + iter: BitIndexIterator<'a>, +} - /// iterates over chunks. 
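// Illustrative sketch (not part of this patch) of what the public `SlicesIterator`
// yields for a boolean mask: half-open `[start, end)` ranges covering the runs of
// `true` slots, ignoring the validity bitmap. It assumes the type is re-exported
// under `arrow::compute`, which may vary between crate versions.
use arrow::array::BooleanArray;
use arrow::compute::SlicesIterator;

fn main() {
    let mask = BooleanArray::from(vec![true, true, false, false, true]);
    let slices: Vec<(usize, usize)> = SlicesIterator::new(&mask).collect();
    // Two runs of set bits: rows 0..2 and row 4
    assert_eq!(slices, vec![(0, 2), (4, 5)]);
}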
- #[inline] - fn iterate_chunks(&mut self) -> Option<(usize, usize)> { - while let Some((i, mask)) = self.iter.next() { - self.current_chunk = i; - if mask == 0 { - if self.on_region { - let result = (self.start, self.start + self.len); - self.len = 0; - self.on_region = false; - return Some(result); - } - } else if mask == 18446744073709551615u64 { - // = !0u64 - if !self.on_region { - self.start = self.current_start(); - self.on_region = true; - } - self.len += 64; - } else { - // there is a chunk that has a non-trivial mask => iterate over bits. - self.state = State::Bits(mask); - return None; - } - } - // no more chunks => start iterating over the remainder - self.current_chunk = self.chunk_len; - self.state = State::Remainder; - None +impl<'a> IndexIterator<'a> { + fn new(filter: &'a BooleanArray, remaining: usize) -> Self { + assert_eq!(filter.null_count(), 0); + let data = filter.data(); + let iter = BitIndexIterator::new(&data.buffers()[0], data.offset(), data.len()); + Self { remaining, iter } } } -impl<'a> Iterator for SlicesIterator<'a> { - type Item = (usize, usize); +impl<'a> Iterator for IndexIterator<'a> { + type Item = usize; fn next(&mut self) -> Option { - match self.state { - State::Chunks => { - match self.iterate_chunks() { - None => { - // iterating over chunks does not yield any new slice => continue to the next - self.current_bit = 0; - self.next() - } - other => other, - } - } - State::Bits(mask) => { - match self.iterate_bits(mask, 64) { - None => { - // iterating over bits does not yield any new slice => change back - // to chunks and continue to the next - self.state = State::Chunks; - self.next() - } - other => other, - } - } - State::Remainder => { - match self.iterate_bits(self.remainder_mask, self.remainder_len) { - None => { - self.state = State::Finish; - if self.on_region { - Some((self.start, self.start + self.len)) - } else { - None - } - } - other => other, - } - } - State::Finish => None, + if self.remaining != 0 { + // Fascinatingly swapping these two lines around results in a 50% + // performance regression for some benchmarks + let next = self.iter.next().expect("IndexIterator exhausted early"); + self.remaining -= 1; + // Must panic if exhausted early as trusted length iterator + return Some(next); } + None + } + + fn size_hint(&self) -> (usize, Option) { + (self.remaining, Some(self.remaining)) } } +/// Counts the number of set bits in `filter` +fn filter_count(filter: &BooleanArray) -> usize { + filter + .values() + .count_set_bits_offset(filter.offset(), filter.len()) +} + +/// Function that can filter arbitrary arrays +/// +/// Deprecated: Use [`FilterPredicate`] instead +#[deprecated] +pub type Filter<'a> = Box ArrayData + 'a>; + /// Returns a prepared function optimized to filter multiple arrays. /// Creating this function requires time, but using it is faster than [filter] when the /// same filter needs to be applied to multiple arrays (e.g. a multi-column `RecordBatch`). /// WARNING: the nulls of `filter` are ignored and the value on its slot is considered. /// Therefore, it is considered undefined behavior to pass `filter` with null values. 
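// Illustrative sketch (not part of this patch): unlike the deprecated `build_filter`,
// the high-level `filter` kernel normalises a predicate containing nulls (via
// `prep_null_mask_filter`), so a null slot never selects a row. Assumes the kernel
// is re-exported as `arrow::compute::filter`.
use arrow::array::{Array, BooleanArray, Int32Array};
use arrow::compute::filter;

fn main() {
    let values = Int32Array::from(vec![1, 2, 3]);
    let predicate = BooleanArray::from(vec![Some(true), None, Some(true)]);

    let filtered = filter(&values, &predicate).unwrap();
    let filtered = filtered.as_any().downcast_ref::<Int32Array>().unwrap();

    // The null predicate slot behaves like `false`: only rows 0 and 2 are kept
    assert_eq!(filtered, &Int32Array::from(vec![1, 3]));
}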
+/// +/// Deprecated: Use [`FilterBuilder`] instead +#[deprecated] +#[allow(deprecated)] pub fn build_filter(filter: &BooleanArray) -> Result { let iter = SlicesIterator::new(filter); - let filter_count = iter.filter_count(); + let filter_count = filter_count(filter); let chunks = iter.collect::>(); Ok(Box::new(move |array: &ArrayData| { @@ -247,35 +207,9 @@ pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray { /// # Ok(()) /// # } /// ``` -pub fn filter(array: &dyn Array, predicate: &BooleanArray) -> Result { - if predicate.null_count() > 0 { - // this greatly simplifies subsequent filtering code - // now we only have a boolean mask to deal with - let predicate = prep_null_mask_filter(predicate); - return filter(array, &predicate); - } - - let iter = SlicesIterator::new(predicate); - let filter_count = iter.filter_count(); - match filter_count { - 0 => { - // return empty - Ok(new_empty_array(array.data_type())) - } - len if len == array.len() => { - // return all - let data = array.data().clone(); - Ok(make_array(data)) - } - _ => { - // actually filter - let mut mutable = - MutableArrayData::new(vec![array.data_ref()], false, filter_count); - iter.for_each(|(start, end)| mutable.extend(0, start, end)); - let data = mutable.freeze(); - Ok(make_array(data)) - } - } +pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result { + let predicate = FilterBuilder::new(predicate).build(); + filter_array(values, &predicate) } /// Returns a new [RecordBatch] with arrays containing only values matching the filter. @@ -283,40 +217,591 @@ pub fn filter_record_batch( record_batch: &RecordBatch, predicate: &BooleanArray, ) -> Result { - if predicate.null_count() > 0 { - // this greatly simplifies subsequent filtering code - // now we only have a boolean mask to deal with - let predicate = prep_null_mask_filter(predicate); - return filter_record_batch(record_batch, &predicate); + let mut filter_builder = FilterBuilder::new(predicate); + if record_batch.num_columns() > 1 { + // Only optimize if filtering more than one column + filter_builder = filter_builder.optimize(); + } + let filter = filter_builder.build(); + + let filtered_arrays = record_batch + .columns() + .iter() + .map(|a| filter_array(a, &filter)) + .collect::>>()?; + + RecordBatch::try_new(record_batch.schema(), filtered_arrays) +} + +/// A builder to construct [`FilterPredicate`] +#[derive(Debug)] +pub struct FilterBuilder { + filter: BooleanArray, + count: usize, + strategy: IterationStrategy, +} + +impl FilterBuilder { + /// Create a new [`FilterBuilder`] that can be used to construct a [`FilterPredicate`] + pub fn new(filter: &BooleanArray) -> Self { + let filter = match filter.null_count() { + 0 => BooleanArray::from(filter.data().clone()), + _ => prep_null_mask_filter(filter), + }; + + let count = filter_count(&filter); + let strategy = IterationStrategy::default_strategy(filter.len(), count); + + Self { + filter, + count, + strategy, + } + } + + /// Compute an optimised representation of the provided `filter` mask that can be + /// applied to an array more quickly. 
+ /// + /// Note: There is limited benefit to calling this to then filter a single array + /// Note: This will likely have a larger memory footprint than the original mask + pub fn optimize(mut self) -> Self { + match self.strategy { + IterationStrategy::SlicesIterator => { + let slices = SlicesIterator::new(&self.filter).collect(); + self.strategy = IterationStrategy::Slices(slices) + } + IterationStrategy::IndexIterator => { + let indices = IndexIterator::new(&self.filter, self.count).collect(); + self.strategy = IterationStrategy::Indices(indices) + } + _ => {} + } + self + } + + /// Construct the final `FilterPredicate` + pub fn build(self) -> FilterPredicate { + FilterPredicate { + filter: self.filter, + count: self.count, + strategy: self.strategy, + } + } +} + +/// The iteration strategy used to evaluate [`FilterPredicate`] +#[derive(Debug)] +enum IterationStrategy { + /// A lazily evaluated iterator of ranges + SlicesIterator, + /// A lazily evaluated iterator of indices + IndexIterator, + /// A precomputed list of indices + Indices(Vec), + /// A precomputed array of ranges + Slices(Vec<(usize, usize)>), + /// Select all rows + All, + /// Select no rows + None, +} + +impl IterationStrategy { + /// The default [`IterationStrategy`] for a filter of length `filter_length` + /// and selecting `filter_count` rows + fn default_strategy(filter_length: usize, filter_count: usize) -> Self { + if filter_length == 0 || filter_count == 0 { + return IterationStrategy::None; + } + + if filter_count == filter_length { + return IterationStrategy::All; + } + + // Compute the selectivity of the predicate by dividing the number of true + // bits in the predicate by the predicate's total length + // + // This can then be used as a heuristic for the optimal iteration strategy + let selectivity_frac = filter_count as f64 / filter_length as f64; + if selectivity_frac > FILTER_SLICES_SELECTIVITY_THRESHOLD { + return IterationStrategy::SlicesIterator; + } + IterationStrategy::IndexIterator } +} - let num_colums = record_batch.columns().len(); +/// A filtering predicate that can be applied to an [`Array`] +#[derive(Debug)] +pub struct FilterPredicate { + filter: BooleanArray, + count: usize, + strategy: IterationStrategy, +} - let filtered_arrays = match num_colums { - 1 => { - vec![filter(record_batch.columns()[0].as_ref(), predicate)?] 
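// Illustrative sketch (not part of this patch): building one `FilterPredicate`
// and reusing it for several arrays, which is where `optimize` pays off by
// materialising the slices/indices once (its `filter` method is defined just
// below). Assumes `FilterBuilder` is re-exported under `arrow::compute`.
use arrow::array::{Array, BooleanArray, Int32Array, StringArray};
use arrow::compute::FilterBuilder;
use arrow::error::Result;

fn filter_two_columns() -> Result<()> {
    let predicate = BooleanArray::from(vec![true, false, true, false]);
    // Precompute the iteration strategy once, then apply it to each column
    let prepared = FilterBuilder::new(&predicate).optimize().build();

    let ints = Int32Array::from(vec![1, 2, 3, 4]);
    let strings = StringArray::from(vec!["a", "b", "c", "d"]);

    assert_eq!(prepared.filter(&ints)?.len(), 2);
    assert_eq!(prepared.filter(&strings)?.len(), 2);
    Ok(())
}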
+impl FilterPredicate { + /// Selects rows from `values` based on this [`FilterPredicate`] + pub fn filter(&self, values: &dyn Array) -> Result { + filter_array(values, self) + } +} + +fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result { + if predicate.filter.len() > values.len() { + return Err(ArrowError::InvalidArgumentError(format!( + "Filter predicate of length {} is larger than target array of length {}", + predicate.filter.len(), + values.len() + ))); + } + + match predicate.strategy { + IterationStrategy::None => Ok(new_empty_array(values.data_type())), + IterationStrategy::All => Ok(make_array(values.data().slice(0, predicate.count))), + // actually filter + _ => match values.data_type() { + DataType::Boolean => { + let values = values.as_any().downcast_ref::().unwrap(); + Ok(Arc::new(filter_boolean(values, predicate))) + } + DataType::Int8 => { + downcast_filter!(Int8Type, values, predicate) + } + DataType::Int16 => { + downcast_filter!(Int16Type, values, predicate) + } + DataType::Int32 => { + downcast_filter!(Int32Type, values, predicate) + } + DataType::Int64 => { + downcast_filter!(Int64Type, values, predicate) + } + DataType::UInt8 => { + downcast_filter!(UInt8Type, values, predicate) + } + DataType::UInt16 => { + downcast_filter!(UInt16Type, values, predicate) + } + DataType::UInt32 => { + downcast_filter!(UInt32Type, values, predicate) + } + DataType::UInt64 => { + downcast_filter!(UInt64Type, values, predicate) + } + DataType::Float32 => { + downcast_filter!(Float32Type, values, predicate) + } + DataType::Float64 => { + downcast_filter!(Float64Type, values, predicate) + } + DataType::Date32 => { + downcast_filter!(Date32Type, values, predicate) + } + DataType::Date64 => { + downcast_filter!(Date64Type, values, predicate) + } + DataType::Time32(Second) => { + downcast_filter!(Time32SecondType, values, predicate) + } + DataType::Time32(Millisecond) => { + downcast_filter!(Time32MillisecondType, values, predicate) + } + DataType::Time64(Microsecond) => { + downcast_filter!(Time64MicrosecondType, values, predicate) + } + DataType::Time64(Nanosecond) => { + downcast_filter!(Time64NanosecondType, values, predicate) + } + DataType::Timestamp(Second, _) => { + downcast_filter!(TimestampSecondType, values, predicate) + } + DataType::Timestamp(Millisecond, _) => { + downcast_filter!(TimestampMillisecondType, values, predicate) + } + DataType::Timestamp(Microsecond, _) => { + downcast_filter!(TimestampMicrosecondType, values, predicate) + } + DataType::Timestamp(Nanosecond, _) => { + downcast_filter!(TimestampNanosecondType, values, predicate) + } + DataType::Interval(IntervalUnit::YearMonth) => { + downcast_filter!(IntervalYearMonthType, values, predicate) + } + DataType::Interval(IntervalUnit::DayTime) => { + downcast_filter!(IntervalDayTimeType, values, predicate) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + downcast_filter!(IntervalMonthDayNanoType, values, predicate) + } + DataType::Duration(TimeUnit::Second) => { + downcast_filter!(DurationSecondType, values, predicate) + } + DataType::Duration(TimeUnit::Millisecond) => { + downcast_filter!(DurationMillisecondType, values, predicate) + } + DataType::Duration(TimeUnit::Microsecond) => { + downcast_filter!(DurationMicrosecondType, values, predicate) + } + DataType::Duration(TimeUnit::Nanosecond) => { + downcast_filter!(DurationNanosecondType, values, predicate) + } + DataType::Utf8 => { + let values = values + .as_any() + .downcast_ref::>() + .unwrap(); + Ok(Arc::new(filter_string::(values, 
predicate))) + } + DataType::LargeUtf8 => { + let values = values + .as_any() + .downcast_ref::>() + .unwrap(); + Ok(Arc::new(filter_string::(values, predicate))) + } + DataType::Dictionary(key_type, _) => match key_type.as_ref() { + DataType::Int8 => downcast_dict_filter!(Int8Type, values, predicate), + DataType::Int16 => downcast_dict_filter!(Int16Type, values, predicate), + DataType::Int32 => downcast_dict_filter!(Int32Type, values, predicate), + DataType::Int64 => downcast_dict_filter!(Int64Type, values, predicate), + DataType::UInt8 => downcast_dict_filter!(UInt8Type, values, predicate), + DataType::UInt16 => downcast_dict_filter!(UInt16Type, values, predicate), + DataType::UInt32 => downcast_dict_filter!(UInt32Type, values, predicate), + DataType::UInt64 => downcast_dict_filter!(UInt64Type, values, predicate), + t => { + unimplemented!("Filter not supported for dictionary key type {:?}", t) + } + }, + _ => { + // fallback to using MutableArrayData + let mut mutable = MutableArrayData::new( + vec![values.data_ref()], + false, + predicate.count, + ); + + match &predicate.strategy { + IterationStrategy::Slices(slices) => { + slices + .iter() + .for_each(|(start, end)| mutable.extend(0, *start, *end)); + } + _ => { + let iter = SlicesIterator::new(&predicate.filter); + iter.for_each(|(start, end)| mutable.extend(0, start, end)); + } + } + + let data = mutable.freeze(); + Ok(make_array(data)) + } + }, + } +} + +/// Computes a new null mask for `data` based on `predicate` +/// +/// If the predicate selected no null-rows, returns `None`, otherwise returns +/// `Some((null_count, null_buffer))` where `null_count` is the number of nulls +/// in the filtered output, and `null_buffer` is the filtered null buffer +/// +fn filter_null_mask( + data: &ArrayData, + predicate: &FilterPredicate, +) -> Option<(usize, Buffer)> { + if data.null_count() == 0 { + return None; + } + + let nulls = filter_bits(data.null_buffer()?, data.offset(), predicate); + // The filtered `nulls` has a length of `predicate.count` bits and + // therefore the null count is this minus the number of valid bits + let null_count = predicate.count - nulls.count_set_bits(); + + if null_count == 0 { + return None; + } + + Some((null_count, nulls)) +} + +/// Filter the packed bitmask `buffer`, with `predicate` starting at bit offset `offset` +fn filter_bits(buffer: &Buffer, offset: usize, predicate: &FilterPredicate) -> Buffer { + let src = buffer.as_slice(); + + match &predicate.strategy { + IterationStrategy::IndexIterator => { + let bits = IndexIterator::new(&predicate.filter, predicate.count) + .map(|src_idx| bit_util::get_bit(src, src_idx + offset)); + + // SAFETY: `IndexIterator` reports its size correctly + unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() } } - _ => { - let filter = build_filter(predicate)?; - record_batch - .columns() + IterationStrategy::Indices(indices) => { + let bits = indices .iter() - .map(|a| make_array(filter(a.data()))) - .collect() + .map(|src_idx| bit_util::get_bit(src, *src_idx + offset)); + + // SAFETY: `Vec::iter()` reports its size correctly + unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() } + } + IterationStrategy::SlicesIterator => { + let mut builder = + BooleanBufferBuilder::new(bit_util::ceil(predicate.count, 8)); + for (start, end) in SlicesIterator::new(&predicate.filter) { + builder.append_packed_range(start + offset..end + offset, src) + } + builder.finish() + } + IterationStrategy::Slices(slices) => { + let mut builder = + 
BooleanBufferBuilder::new(bit_util::ceil(predicate.count, 8)); + for (start, end) in slices { + builder.append_packed_range(*start + offset..*end + offset, src) + } + builder.finish() + } + IterationStrategy::All | IterationStrategy::None => unreachable!(), + } +} + +/// `filter` implementation for boolean buffers +fn filter_boolean(values: &BooleanArray, predicate: &FilterPredicate) -> BooleanArray { + let data = values.data(); + assert_eq!(data.buffers().len(), 1); + assert_eq!(data.child_data().len(), 0); + + let values = filter_bits(&data.buffers()[0], data.offset(), predicate); + + let mut builder = ArrayDataBuilder::new(DataType::Boolean) + .len(predicate.count) + .add_buffer(values); + + if let Some((null_count, nulls)) = filter_null_mask(data, predicate) { + builder = builder.null_count(null_count).null_bit_buffer(Some(nulls)); + } + + let data = unsafe { builder.build_unchecked() }; + BooleanArray::from(data) +} + +/// `filter` implementation for primitive arrays +fn filter_primitive( + values: &PrimitiveArray, + predicate: &FilterPredicate, +) -> PrimitiveArray +where + T: ArrowPrimitiveType, +{ + let data = values.data(); + assert_eq!(data.buffers().len(), 1); + assert_eq!(data.child_data().len(), 0); + + let values = data.buffer::(0); + assert!(values.len() >= predicate.filter.len()); + + let buffer = match &predicate.strategy { + IterationStrategy::SlicesIterator => { + let mut buffer = + MutableBuffer::with_capacity(predicate.count * T::get_byte_width()); + for (start, end) in SlicesIterator::new(&predicate.filter) { + buffer.extend_from_slice(&values[start..end]); + } + buffer + } + IterationStrategy::Slices(slices) => { + let mut buffer = + MutableBuffer::with_capacity(predicate.count * T::get_byte_width()); + for (start, end) in slices { + buffer.extend_from_slice(&values[*start..*end]); + } + buffer + } + IterationStrategy::IndexIterator => { + let iter = + IndexIterator::new(&predicate.filter, predicate.count).map(|x| values[x]); + + // SAFETY: IndexIterator is trusted length + unsafe { MutableBuffer::from_trusted_len_iter(iter) } } + IterationStrategy::Indices(indices) => { + let iter = indices.iter().map(|x| values[*x]); + + // SAFETY: `Vec::iter` is trusted length + unsafe { MutableBuffer::from_trusted_len_iter(iter) } + } + IterationStrategy::All | IterationStrategy::None => unreachable!(), }; - RecordBatch::try_new(record_batch.schema(), filtered_arrays) + + let mut builder = ArrayDataBuilder::new(data.data_type().clone()) + .len(predicate.count) + .add_buffer(buffer.into()); + + if let Some((null_count, nulls)) = filter_null_mask(data, predicate) { + builder = builder.null_count(null_count).null_bit_buffer(Some(nulls)); + } + + let data = unsafe { builder.build_unchecked() }; + PrimitiveArray::from(data) +} + +/// [`FilterString`] is created from a source [`GenericStringArray`] and can be +/// used to build a new [`GenericStringArray`] by copying values from the source +/// +/// TODO(raphael): Could this be used for the take kernel as well? 
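// Illustrative sketch (not part of this patch): dictionary arrays go through the
// `downcast_dict_filter!` dispatch above, so only the keys are filtered while the
// dictionary values are carried over unchanged (see `filter_dict` further below).
// Assumes the kernel is re-exported as `arrow::compute::filter`.
use arrow::array::{Array, BooleanArray, DictionaryArray, StringArray};
use arrow::compute::filter;
use arrow::datatypes::Int32Type;

fn main() {
    let dict: DictionaryArray<Int32Type> = vec!["a", "b", "a", "c"].into_iter().collect();
    let predicate = BooleanArray::from(vec![true, false, true, false]);

    let filtered = filter(&dict, &predicate).unwrap();
    let filtered = filtered
        .as_any()
        .downcast_ref::<DictionaryArray<Int32Type>>()
        .unwrap();

    // Two keys survive, but the dictionary still holds all three distinct values
    assert_eq!(filtered.len(), 2);
    let values = filtered.values().as_any().downcast_ref::<StringArray>().unwrap();
    assert_eq!(values.len(), 3);
}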
+struct FilterString<'a, OffsetSize> { + src_offsets: &'a [OffsetSize], + src_values: &'a [u8], + dst_offsets: MutableBuffer, + dst_values: MutableBuffer, + cur_offset: OffsetSize, +} + +impl<'a, OffsetSize> FilterString<'a, OffsetSize> +where + OffsetSize: Zero + AddAssign + OffsetSizeTrait, +{ + fn new(capacity: usize, array: &'a GenericStringArray) -> Self { + let num_offsets_bytes = (capacity + 1) * std::mem::size_of::(); + let mut dst_offsets = MutableBuffer::new(num_offsets_bytes); + let dst_values = MutableBuffer::new(0); + let cur_offset = OffsetSize::zero(); + dst_offsets.push(cur_offset); + + Self { + src_offsets: array.value_offsets(), + src_values: &array.data().buffers()[1], + dst_offsets, + dst_values, + cur_offset, + } + } + + /// Returns the byte offset at `idx` + #[inline] + fn get_value_offset(&self, idx: usize) -> usize { + self.src_offsets[idx].to_usize().expect("illegal offset") + } + + /// Returns the start and end of the value at index `idx` along with its length + #[inline] + fn get_value_range(&self, idx: usize) -> (usize, usize, OffsetSize) { + // These can only fail if `array` contains invalid data + let start = self.get_value_offset(idx); + let end = self.get_value_offset(idx + 1); + let len = OffsetSize::from_usize(end - start).expect("illegal offset range"); + (start, end, len) + } + + /// Extends the in-progress array by the indexes in the provided iterator + fn extend_idx(&mut self, iter: impl Iterator) { + for idx in iter { + let (start, end, len) = self.get_value_range(idx); + self.cur_offset += len; + self.dst_offsets.push(self.cur_offset); + self.dst_values + .extend_from_slice(&self.src_values[start..end]); + } + } + + /// Extends the in-progress array by the ranges in the provided iterator + fn extend_slices(&mut self, iter: impl Iterator) { + for (start, end) in iter { + // These can only fail if `array` contains invalid data + for idx in start..end { + let (_, _, len) = self.get_value_range(idx); + self.cur_offset += len; + self.dst_offsets.push(self.cur_offset); // push_unchecked? + } + + let value_start = self.get_value_offset(start); + let value_end = self.get_value_offset(end); + self.dst_values + .extend_from_slice(&self.src_values[value_start..value_end]); + } + } +} + +/// `filter` implementation for string arrays +/// +/// Note: NULLs with a non-zero slot length in `array` will have the corresponding +/// data copied across. 
This allows handling the null mask separately from the data +fn filter_string( + array: &GenericStringArray, + predicate: &FilterPredicate, +) -> GenericStringArray +where + OffsetSize: Zero + AddAssign + OffsetSizeTrait, +{ + let data = array.data(); + assert_eq!(data.buffers().len(), 2); + assert_eq!(data.child_data().len(), 0); + let mut filter = FilterString::new(predicate.count, array); + + match &predicate.strategy { + IterationStrategy::SlicesIterator => { + filter.extend_slices(SlicesIterator::new(&predicate.filter)) + } + IterationStrategy::Slices(slices) => filter.extend_slices(slices.iter().cloned()), + IterationStrategy::IndexIterator => { + filter.extend_idx(IndexIterator::new(&predicate.filter, predicate.count)) + } + IterationStrategy::Indices(indices) => filter.extend_idx(indices.iter().cloned()), + IterationStrategy::All | IterationStrategy::None => unreachable!(), + } + + let mut builder = ArrayDataBuilder::new(data.data_type().clone()) + .len(predicate.count) + .add_buffer(filter.dst_offsets.into()) + .add_buffer(filter.dst_values.into()); + + if let Some((null_count, nulls)) = filter_null_mask(data, predicate) { + builder = builder.null_count(null_count).null_bit_buffer(Some(nulls)); + } + + let data = unsafe { builder.build_unchecked() }; + GenericStringArray::from(data) +} + +/// `filter` implementation for dictionaries +fn filter_dict( + array: &DictionaryArray, + predicate: &FilterPredicate, +) -> DictionaryArray +where + T: ArrowPrimitiveType, + T::Native: num::Num, +{ + let filtered_keys = filter_primitive::(array.keys(), predicate); + let filtered_data = filtered_keys.data_ref(); + + let data = unsafe { + ArrayData::new_unchecked( + array.data_type().clone(), + filtered_data.len(), + Some(filtered_data.null_count()), + filtered_data.null_buffer().cloned(), + filtered_data.offset(), + filtered_data.buffers().to_vec(), + array.data().child_data().to_vec(), + ) + }; + + DictionaryArray::::from(data) } #[cfg(test)] mod tests { - use super::*; + use rand::distributions::{Alphanumeric, Standard}; + use rand::prelude::*; + use crate::datatypes::Int64Type; use crate::{ buffer::Buffer, datatypes::{DataType, Field}, }; + use super::*; + macro_rules! 
def_temporal_test { ($test:ident, $array_type: ident, $data: expr) => { #[test] @@ -473,7 +958,7 @@ mod tests { } #[test] - fn test_filter_primative_array_with_null() { + fn test_filter_primitive_array_with_null() { let a = Int32Array::from(vec![Some(5), None]); let b = BooleanArray::from(vec![false, true]); let c = filter(&a, &b).unwrap(); @@ -579,7 +1064,7 @@ mod tests { .len(4) .add_buffer(value_offsets) .add_child_data(value_data) - .null_bit_buffer(Buffer::from([0b00000111])) + .null_bit_buffer(Some(Buffer::from([0b00000111]))) .build() .unwrap(); @@ -603,7 +1088,7 @@ mod tests { .len(2) .add_buffer(value_offsets) .add_child_data(value_data) - .null_bit_buffer(Buffer::from([0b00000001])) + .null_bit_buffer(Some(Buffer::from([0b00000001]))) .build() .unwrap(); @@ -614,9 +1099,9 @@ mod tests { fn test_slice_iterator_bits() { let filter_values = (0..64).map(|i| i == 1).collect::>(); let filter = BooleanArray::from(filter_values); + let filter_count = filter_count(&filter); let iter = SlicesIterator::new(&filter); - let filter_count = iter.filter_count(); let chunks = iter.collect::>(); assert_eq!(chunks, vec![(1, 2)]); @@ -627,9 +1112,9 @@ mod tests { fn test_slice_iterator_bits1() { let filter_values = (0..64).map(|i| i != 1).collect::>(); let filter = BooleanArray::from(filter_values); + let filter_count = filter_count(&filter); let iter = SlicesIterator::new(&filter); - let filter_count = iter.filter_count(); let chunks = iter.collect::>(); assert_eq!(chunks, vec![(0, 1), (2, 64)]); @@ -640,9 +1125,9 @@ mod tests { fn test_slice_iterator_chunk_and_bits() { let filter_values = (0..130).map(|i| i % 62 != 0).collect::>(); let filter = BooleanArray::from(filter_values); + let filter_count = filter_count(&filter); let iter = SlicesIterator::new(&filter); - let filter_count = iter.filter_count(); let chunks = iter.collect::>(); assert_eq!(chunks, vec![(1, 62), (63, 124), (125, 130)]); @@ -693,4 +1178,537 @@ mod tests { assert_eq!(out.data_type(), &DataType::Int64); Ok(()) } + + #[test] + fn test_slices() { + // takes up 2 u64s + let bools = std::iter::repeat(true) + .take(10) + .chain(std::iter::repeat(false).take(30)) + .chain(std::iter::repeat(true).take(20)) + .chain(std::iter::repeat(false).take(17)) + .chain(std::iter::repeat(true).take(4)); + + let bool_array: BooleanArray = bools.map(Some).collect(); + + let slices: Vec<_> = SlicesIterator::new(&bool_array).collect(); + let expected = vec![(0, 10), (40, 60), (77, 81)]; + assert_eq!(slices, expected); + + // slice with offset and truncated len + let len = bool_array.len(); + let sliced_array = bool_array.slice(7, len - 10); + let sliced_array = sliced_array + .as_any() + .downcast_ref::() + .unwrap(); + let slices: Vec<_> = SlicesIterator::new(sliced_array).collect(); + let expected = vec![(0, 3), (33, 53), (70, 71)]; + assert_eq!(slices, expected); + } + + fn test_slices_fuzz(mask_len: usize, offset: usize, truncate: usize) { + let mut rng = thread_rng(); + + let bools: Vec = std::iter::from_fn(|| Some(rng.gen())) + .take(mask_len) + .collect(); + + let buffer = Buffer::from_iter(bools.iter().cloned()); + + let truncated_length = mask_len - offset - truncate; + + let data = ArrayDataBuilder::new(DataType::Boolean) + .len(truncated_length) + .offset(offset) + .add_buffer(buffer) + .build() + .unwrap(); + + let filter = BooleanArray::from(data); + + let slice_bits: Vec<_> = SlicesIterator::new(&filter) + .flat_map(|(start, end)| start..end) + .collect(); + + let count = filter_count(&filter); + let index_bits: Vec<_> = 
IndexIterator::new(&filter, count).collect(); + + let expected_bits: Vec<_> = bools + .iter() + .skip(offset) + .take(truncated_length) + .enumerate() + .flat_map(|(idx, v)| v.then(|| idx)) + .collect(); + + assert_eq!(slice_bits, expected_bits); + assert_eq!(index_bits, expected_bits); + } + + #[test] + #[cfg_attr(miri, ignore)] + fn fuzz_test_slices_iterator() { + let mut rng = thread_rng(); + + for _ in 0..100 { + let mask_len = rng.gen_range(0..1024); + let max_offset = 64.min(mask_len); + let offset = rng.gen::().checked_rem(max_offset).unwrap_or(0); + + let max_truncate = 128.min(mask_len - offset); + let truncate = rng.gen::().checked_rem(max_truncate).unwrap_or(0); + + test_slices_fuzz(mask_len, offset, truncate); + } + + test_slices_fuzz(64, 0, 0); + test_slices_fuzz(64, 8, 0); + test_slices_fuzz(64, 8, 8); + test_slices_fuzz(32, 8, 8); + test_slices_fuzz(32, 5, 9); + } + + /// Filters `values` by `predicate` using standard rust iterators + fn filter_rust(values: impl IntoIterator, predicate: &[bool]) -> Vec { + values + .into_iter() + .zip(predicate) + .filter(|(_, x)| **x) + .map(|(a, _)| a) + .collect() + } + + /// Generates an array of length `len` with `valid_percent` non-null values + fn gen_primitive(len: usize, valid_percent: f64) -> Vec> + where + Standard: Distribution, + { + let mut rng = thread_rng(); + (0..len) + .map(|_| rng.gen_bool(valid_percent).then(|| rng.gen())) + .collect() + } + + /// Generates an array of length `len` with `valid_percent` non-null values + fn gen_strings( + len: usize, + valid_percent: f64, + str_len_range: std::ops::Range, + ) -> Vec> { + let mut rng = thread_rng(); + (0..len) + .map(|_| { + rng.gen_bool(valid_percent).then(|| { + let len = rng.gen_range(str_len_range.clone()); + (0..len) + .map(|_| char::from(rng.sample(Alphanumeric))) + .collect() + }) + }) + .collect() + } + + /// Returns an iterator that calls `Option::as_deref` on each item + fn as_deref( + src: &[Option], + ) -> impl Iterator> { + src.iter().map(|x| x.as_deref()) + } + + #[test] + #[cfg_attr(miri, ignore)] + fn fuzz_filter() { + let mut rng = thread_rng(); + + for i in 0..100 { + let filter_percent = match i { + 0..=4 => 1., + 5..=10 => 0., + _ => rng.gen_range(0.0..1.0), + }; + + let valid_percent = rng.gen_range(0.0..1.0); + + let array_len = rng.gen_range(32..256); + let array_offset = rng.gen_range(0..10); + + // Construct a predicate + let filter_offset = rng.gen_range(0..10); + let filter_truncate = rng.gen_range(0..10); + let bools: Vec<_> = std::iter::from_fn(|| Some(rng.gen_bool(filter_percent))) + .take(array_len + filter_offset - filter_truncate) + .collect(); + + let predicate = BooleanArray::from_iter(bools.iter().cloned().map(Some)); + + // Offset predicate + let predicate = predicate.slice(filter_offset, array_len - filter_truncate); + let predicate = predicate.as_any().downcast_ref::().unwrap(); + let bools = &bools[filter_offset..]; + + // Test i32 + let values = gen_primitive(array_len + array_offset, valid_percent); + let src = Int32Array::from_iter(values.iter().cloned()); + + let src = src.slice(array_offset, array_len); + let src = src.as_any().downcast_ref::().unwrap(); + let values = &values[array_offset..]; + + let filtered = filter(src, predicate).unwrap(); + let array = filtered.as_any().downcast_ref::().unwrap(); + let actual: Vec<_> = array.iter().collect(); + + assert_eq!(actual, filter_rust(values.iter().cloned(), bools)); + + // Test string + let strings = gen_strings(array_len + array_offset, valid_percent, 0..20); + let src = 
StringArray::from_iter(as_deref(&strings)); + + let src = src.slice(array_offset, array_len); + let src = src.as_any().downcast_ref::().unwrap(); + + let filtered = filter(src, predicate).unwrap(); + let array = filtered.as_any().downcast_ref::().unwrap(); + let actual: Vec<_> = array.iter().collect(); + + let expected_strings = filter_rust(as_deref(&strings[array_offset..]), bools); + assert_eq!(actual, expected_strings); + + // Test string dictionary + let src = DictionaryArray::::from_iter(as_deref(&strings)); + + let src = src.slice(array_offset, array_len); + let src = src + .as_any() + .downcast_ref::>() + .unwrap(); + + let filtered = filter(src, predicate).unwrap(); + + let array = filtered + .as_any() + .downcast_ref::>() + .unwrap(); + + let values = array + .values() + .as_any() + .downcast_ref::() + .unwrap(); + + let actual: Vec<_> = array + .keys() + .iter() + .map(|key| key.map(|key| values.value(key as usize))) + .collect(); + + assert_eq!(actual, expected_strings); + } + } + + #[test] + fn test_filter_map() { + let mut builder = + MapBuilder::new(None, StringBuilder::new(16), Int64Builder::new(4)); + // [{"key1": 1}, {"key2": 2, "key3": 3}, null, {"key1": 1} + builder.keys().append_value("key1").unwrap(); + builder.values().append_value(1).unwrap(); + builder.append(true).unwrap(); + builder.keys().append_value("key2").unwrap(); + builder.keys().append_value("key3").unwrap(); + builder.values().append_value(2).unwrap(); + builder.values().append_value(3).unwrap(); + builder.append(true).unwrap(); + builder.append(false).unwrap(); + builder.keys().append_value("key1").unwrap(); + builder.values().append_value(1).unwrap(); + builder.append(true).unwrap(); + let maparray = Arc::new(builder.finish()) as ArrayRef; + + let indices = vec![Some(true), Some(false), Some(false), Some(true)] + .into_iter() + .collect::(); + let got = filter(&maparray, &indices).unwrap(); + + let mut builder = + MapBuilder::new(None, StringBuilder::new(8), Int64Builder::new(2)); + builder.keys().append_value("key1").unwrap(); + builder.values().append_value(1).unwrap(); + builder.append(true).unwrap(); + builder.keys().append_value("key1").unwrap(); + builder.values().append_value(1).unwrap(); + builder.append(true).unwrap(); + let expected = Arc::new(builder.finish()) as ArrayRef; + + assert_eq!(&expected, &got); + } + + #[test] + fn test_filter_fixed_size_list_arrays() { + let value_data = ArrayData::builder(DataType::Int32) + .len(9) + .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7, 8])) + .build() + .unwrap(); + let list_data_type = DataType::FixedSizeList( + Box::new(Field::new("item", DataType::Int32, false)), + 3, + ); + let list_data = ArrayData::builder(list_data_type) + .len(3) + .add_child_data(value_data) + .build() + .unwrap(); + let array = FixedSizeListArray::from(list_data); + + let filter_array = BooleanArray::from(vec![true, false, false]); + + let c = filter(&array, &filter_array).unwrap(); + let filtered = c.as_any().downcast_ref::().unwrap(); + + assert_eq!(filtered.len(), 1); + + let list = filtered.value(0); + assert_eq!( + &[0, 1, 2], + list.as_any().downcast_ref::().unwrap().values() + ); + + let filter_array = BooleanArray::from(vec![true, false, true]); + + let c = filter(&array, &filter_array).unwrap(); + let filtered = c.as_any().downcast_ref::().unwrap(); + + assert_eq!(filtered.len(), 2); + + let list = filtered.value(0); + assert_eq!( + &[0, 1, 2], + list.as_any().downcast_ref::().unwrap().values() + ); + let list = filtered.value(1); + assert_eq!( + &[6, 
7, 8], + list.as_any().downcast_ref::().unwrap().values() + ); + } + + #[test] + fn test_filter_fixed_size_list_arrays_with_null() { + let value_data = ArrayData::builder(DataType::Int32) + .len(10) + .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) + .build() + .unwrap(); + + // Set null buts for the nested array: + // [[0, 1], null, null, [6, 7], [8, 9]] + // 01011001 00000001 + let mut null_bits: [u8; 1] = [0; 1]; + bit_util::set_bit(&mut null_bits, 0); + bit_util::set_bit(&mut null_bits, 3); + bit_util::set_bit(&mut null_bits, 4); + + let list_data_type = DataType::FixedSizeList( + Box::new(Field::new("item", DataType::Int32, false)), + 2, + ); + let list_data = ArrayData::builder(list_data_type) + .len(5) + .add_child_data(value_data) + .null_bit_buffer(Some(Buffer::from(null_bits))) + .build() + .unwrap(); + let array = FixedSizeListArray::from(list_data); + + let filter_array = BooleanArray::from(vec![true, true, false, true, false]); + + let c = filter(&array, &filter_array).unwrap(); + let filtered = c.as_any().downcast_ref::().unwrap(); + + assert_eq!(filtered.len(), 3); + + let list = filtered.value(0); + assert_eq!( + &[0, 1], + list.as_any().downcast_ref::().unwrap().values() + ); + assert!(filtered.is_null(1)); + let list = filtered.value(2); + assert_eq!( + &[6, 7], + list.as_any().downcast_ref::().unwrap().values() + ); + } + + fn test_filter_union_array(array: UnionArray) { + let filter_array = BooleanArray::from(vec![true, false, false]); + let c = filter(&array, &filter_array).unwrap(); + let filtered = c.as_any().downcast_ref::().unwrap(); + + let mut builder = UnionBuilder::new_dense(1); + builder.append::("A", 1).unwrap(); + let expected_array = builder.build().unwrap(); + + compare_union_arrays(filtered, &expected_array); + + let filter_array = BooleanArray::from(vec![true, false, true]); + let c = filter(&array, &filter_array).unwrap(); + let filtered = c.as_any().downcast_ref::().unwrap(); + + let mut builder = UnionBuilder::new_dense(2); + builder.append::("A", 1).unwrap(); + builder.append::("A", 34).unwrap(); + let expected_array = builder.build().unwrap(); + + compare_union_arrays(filtered, &expected_array); + + let filter_array = BooleanArray::from(vec![true, true, false]); + let c = filter(&array, &filter_array).unwrap(); + let filtered = c.as_any().downcast_ref::().unwrap(); + + let mut builder = UnionBuilder::new_dense(2); + builder.append::("A", 1).unwrap(); + builder.append::("B", 3.2).unwrap(); + let expected_array = builder.build().unwrap(); + + compare_union_arrays(filtered, &expected_array); + } + + #[test] + fn test_filter_union_array_dense() { + let mut builder = UnionBuilder::new_dense(3); + builder.append::("A", 1).unwrap(); + builder.append::("B", 3.2).unwrap(); + builder.append::("A", 34).unwrap(); + let array = builder.build().unwrap(); + + test_filter_union_array(array); + } + + #[test] + fn test_filter_run_union_array_dense() { + let mut builder = UnionBuilder::new_dense(3); + builder.append::("A", 1).unwrap(); + builder.append::("A", 3).unwrap(); + builder.append::("A", 34).unwrap(); + let array = builder.build().unwrap(); + + let filter_array = BooleanArray::from(vec![true, true, false]); + let c = filter(&array, &filter_array).unwrap(); + let filtered = c.as_any().downcast_ref::().unwrap(); + + let mut builder = UnionBuilder::new_dense(3); + builder.append::("A", 1).unwrap(); + builder.append::("A", 3).unwrap(); + let expected = builder.build().unwrap(); + + assert_eq!(filtered.data(), expected.data()); + } + + #[test] 
+ fn test_filter_union_array_dense_with_nulls() { + let mut builder = UnionBuilder::new_dense(4); + builder.append::("A", 1).unwrap(); + builder.append::("B", 3.2).unwrap(); + builder.append_null::("B").unwrap(); + builder.append::("A", 34).unwrap(); + let array = builder.build().unwrap(); + + let filter_array = BooleanArray::from(vec![true, true, false, false]); + let c = filter(&array, &filter_array).unwrap(); + let filtered = c.as_any().downcast_ref::().unwrap(); + + let mut builder = UnionBuilder::new_dense(2); + builder.append::("A", 1).unwrap(); + builder.append::("B", 3.2).unwrap(); + let expected_array = builder.build().unwrap(); + + compare_union_arrays(filtered, &expected_array); + + let filter_array = BooleanArray::from(vec![true, false, true, false]); + let c = filter(&array, &filter_array).unwrap(); + let filtered = c.as_any().downcast_ref::().unwrap(); + + let mut builder = UnionBuilder::new_dense(2); + builder.append::("A", 1).unwrap(); + builder.append_null::("B").unwrap(); + let expected_array = builder.build().unwrap(); + + compare_union_arrays(filtered, &expected_array); + } + + #[test] + fn test_filter_union_array_sparse() { + let mut builder = UnionBuilder::new_sparse(3); + builder.append::("A", 1).unwrap(); + builder.append::("B", 3.2).unwrap(); + builder.append::("A", 34).unwrap(); + let array = builder.build().unwrap(); + + test_filter_union_array(array); + } + + #[test] + fn test_filter_union_array_sparse_with_nulls() { + let mut builder = UnionBuilder::new_sparse(4); + builder.append::("A", 1).unwrap(); + builder.append::("B", 3.2).unwrap(); + builder.append_null::("B").unwrap(); + builder.append::("A", 34).unwrap(); + let array = builder.build().unwrap(); + + let filter_array = BooleanArray::from(vec![true, false, true, false]); + let c = filter(&array, &filter_array).unwrap(); + let filtered = c.as_any().downcast_ref::().unwrap(); + + let mut builder = UnionBuilder::new_sparse(2); + builder.append::("A", 1).unwrap(); + builder.append_null::("B").unwrap(); + let expected_array = builder.build().unwrap(); + + compare_union_arrays(filtered, &expected_array); + } + + fn compare_union_arrays(union1: &UnionArray, union2: &UnionArray) { + assert_eq!(union1.len(), union2.len()); + + for i in 0..union1.len() { + let type_id = union1.type_id(i); + + let slot1 = union1.value(i); + let slot2 = union2.value(i); + + assert_eq!(slot1.is_null(0), slot2.is_null(0)); + + if !slot1.is_null(0) && !slot2.is_null(0) { + match type_id { + 0 => { + let slot1 = slot1.as_any().downcast_ref::().unwrap(); + assert_eq!(slot1.len(), 1); + let value1 = slot1.value(0); + + let slot2 = slot2.as_any().downcast_ref::().unwrap(); + assert_eq!(slot2.len(), 1); + let value2 = slot2.value(0); + assert_eq!(value1, value2); + } + 1 => { + let slot1 = + slot1.as_any().downcast_ref::().unwrap(); + assert_eq!(slot1.len(), 1); + let value1 = slot1.value(0); + + let slot2 = + slot2.as_any().downcast_ref::().unwrap(); + assert_eq!(slot2.len(), 1); + let value2 = slot2.value(0); + assert_eq!(value1, value2); + } + _ => unreachable!(), + } + } + } + } } diff --git a/arrow/src/compute/kernels/length.rs b/arrow/src/compute/kernels/length.rs index b0f3d9ad58ef..a68aa2bde4eb 100644 --- a/arrow/src/compute/kernels/length.rs +++ b/arrow/src/compute/kernels/length.rs @@ -15,116 +15,204 @@ // specific language governing permissions and limitations // under the License. -//! Defines kernel for length of a string array +//! 
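// Illustrative sketch, not part of the patch: the union tests above exercise the
// same `filter` kernel that handles every supported array type. Minimal downstream
// usage, assuming the public `arrow` crate paths; the sample values are made up.
use arrow::array::{BooleanArray, Int32Array};
use arrow::compute::kernels::filter::filter;

fn filter_usage_sketch() {
    let values = Int32Array::from(vec![Some(1), None, Some(3)]);
    // keep rows 0 and 2, drop row 1
    let predicate = BooleanArray::from(vec![true, false, true]);
    let filtered = filter(&values, &predicate).unwrap();
    let filtered = filtered.as_any().downcast_ref::<Int32Array>().unwrap();
    assert_eq!(filtered, &Int32Array::from(vec![Some(1), Some(3)]));
}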
Defines kernel for length of string arrays and binary arrays +use crate::{array::*, buffer::Buffer, datatypes::ArrowPrimitiveType}; use crate::{ - array::*, - buffer::Buffer, - datatypes::{ArrowNativeType, ArrowPrimitiveType}, -}; -use crate::{ - datatypes::{DataType, Int32Type, Int64Type}, + datatypes::*, error::{ArrowError, Result}, }; -fn unary_offsets_string( - array: &GenericStringArray, - data_type: DataType, - op: F, -) -> ArrayRef +use std::sync::Arc; + +macro_rules! unary_offsets { + ($array: expr, $data_type: expr, $op: expr) => {{ + let slice = $array.value_offsets(); + + let lengths = slice.windows(2).map(|offset| $op(offset[1] - offset[0])); + + // JUSTIFICATION + // Benefit + // ~60% speedup + // Soundness + // `values` come from a slice iterator with a known size. + let buffer = unsafe { Buffer::from_trusted_len_iter(lengths) }; + + let null_bit_buffer = $array + .data_ref() + .null_buffer() + .map(|b| b.bit_slice($array.offset(), $array.len())); + + let data = unsafe { + ArrayData::new_unchecked( + $data_type, + $array.len(), + None, + null_bit_buffer, + 0, + vec![buffer], + vec![], + ) + }; + make_array(data) + }}; +} + +macro_rules! kernel_dict { + ($array: ident, $kernel: expr, $kt: ident, $($t: ident: $gt: ident), *) => { + match $kt.as_ref() { + $(&DataType::$t => { + let dict = $array + .as_any() + .downcast_ref::>() + .unwrap_or_else(|| { + panic!("Expect 'DictionaryArray<{}>' but got array of data type {:?}", + stringify!($gt), $array.data_type()) + }); + let values = $kernel(dict.values())?; + let result = DictionaryArray::try_new(dict.keys(), &values)?; + Ok(Arc::new(result)) + }, + )* + t => panic!("Unsupported dictionary key type: {}", t) + } + } +} + +fn length_list(array: &dyn Array) -> ArrayRef where - O: StringOffsetSizeTrait + ArrowNativeType, - F: Fn(O) -> O, + O: OffsetSizeTrait, + T: ArrowPrimitiveType, + T::Native: OffsetSizeTrait, { - // note: offsets are stored as u8, but they can be interpreted as OffsetSize - let offsets = &array.data_ref().buffers()[0]; - // this is a 30% improvement over iterating over u8s and building OffsetSize, which - // justifies the usage of `unsafe`. - let slice: &[O] = &unsafe { offsets.typed_data::() }[array.offset()..]; - - let lengths = slice.windows(2).map(|offset| op(offset[1] - offset[0])); - - // JUSTIFICATION - // Benefit - // ~60% speedup - // Soundness - // `values` is an iterator with a known size. 
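// Illustrative sketch, not part of the patch: the heart of `unary_offsets!` above is
// pairwise differencing of the offsets buffer; a hypothetical offsets slice makes the
// `windows(2)` step concrete.
fn offsets_to_lengths_sketch() {
    // offsets for three variable-length values of byte lengths 3, 2 and 0
    let offsets: [i32; 4] = [0, 3, 5, 5];
    let lengths: Vec<i32> = offsets.windows(2).map(|w| w[1] - w[0]).collect();
    assert_eq!(lengths, vec![3, 2, 0]);
}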
- let buffer = unsafe { Buffer::from_trusted_len_iter(lengths) }; - - let null_bit_buffer = array - .data_ref() - .null_buffer() - .map(|b| b.bit_slice(array.offset(), array.len())); - - let data = unsafe { - ArrayData::new_unchecked( - data_type, - array.len(), - None, - null_bit_buffer, - 0, - vec![buffer], - vec![], - ) - }; - make_array(data) + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + unary_offsets!(array, T::DATA_TYPE, |x| x) +} + +fn length_binary(array: &dyn Array) -> ArrayRef +where + O: OffsetSizeTrait, + T: ArrowPrimitiveType, + T::Native: OffsetSizeTrait, +{ + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + unary_offsets!(array, T::DATA_TYPE, |x| x) } -fn octet_length( - array: &dyn Array, -) -> ArrayRef +fn length_string(array: &dyn Array) -> ArrayRef where - T::Native: StringOffsetSizeTrait, + O: OffsetSizeTrait, + T: ArrowPrimitiveType, + T::Native: OffsetSizeTrait, { let array = array .as_any() .downcast_ref::>() .unwrap(); - unary_offsets_string::(array, T::DATA_TYPE, |x| x) + unary_offsets!(array, T::DATA_TYPE, |x| x) } -fn bit_length_impl( - array: &dyn Array, -) -> ArrayRef +fn bit_length_binary(array: &dyn Array) -> ArrayRef where - T::Native: StringOffsetSizeTrait, + O: OffsetSizeTrait, + T: ArrowPrimitiveType, + T::Native: OffsetSizeTrait, +{ + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let bits_in_bytes = O::from_usize(8).unwrap(); + unary_offsets!(array, T::DATA_TYPE, |x| x * bits_in_bytes) +} + +fn bit_length_string(array: &dyn Array) -> ArrayRef +where + O: OffsetSizeTrait, + T: ArrowPrimitiveType, + T::Native: OffsetSizeTrait, { let array = array .as_any() .downcast_ref::>() .unwrap(); let bits_in_bytes = O::from_usize(8).unwrap(); - unary_offsets_string::(array, T::DATA_TYPE, |x| x * bits_in_bytes) + unary_offsets!(array, T::DATA_TYPE, |x| x * bits_in_bytes) } -/// Returns an array of Int32/Int64 denoting the number of bytes in each string in the array. +/// Returns an array of Int32/Int64 denoting the length of each value in the array. +/// For list array, length is the number of elements in each list. +/// For string array and binary array, length is the number of bytes of each value. /// -/// * this only accepts StringArray/Utf8 and LargeString/LargeUtf8 +/// * this only accepts ListArray/LargeListArray, StringArray/LargeStringArray and BinaryArray/LargeBinaryArray, +/// or DictionaryArray with above Arrays as values /// * length of null is null. -/// * length is in number of bytes pub fn length(array: &dyn Array) -> Result { match array.data_type() { - DataType::Utf8 => Ok(octet_length::(array)), - DataType::LargeUtf8 => Ok(octet_length::(array)), - _ => Err(ArrowError::ComputeError(format!( + DataType::Dictionary(kt, _) => { + kernel_dict!( + array, + |a| { length(a) }, + kt, + Int8: Int8Type, + Int16: Int16Type, + Int32: Int32Type, + Int64: Int64Type, + UInt8: UInt8Type, + UInt16: UInt16Type, + UInt32: UInt32Type, + UInt64: UInt64Type + ) + } + DataType::List(_) => Ok(length_list::(array)), + DataType::LargeList(_) => Ok(length_list::(array)), + DataType::Utf8 => Ok(length_string::(array)), + DataType::LargeUtf8 => Ok(length_string::(array)), + DataType::Binary => Ok(length_binary::(array)), + DataType::LargeBinary => Ok(length_binary::(array)), + other => Err(ArrowError::ComputeError(format!( "length not supported for {:?}", - array.data_type() + other ))), } } -/// Returns an array of Int32/Int64 denoting the number of bits in each string in the array. 
+/// Returns an array of Int32/Int64 denoting the number of bits in each value in the array. /// -/// * this only accepts StringArray/Utf8 and LargeString/LargeUtf8 +/// * this only accepts StringArray/Utf8, LargeString/LargeUtf8, BinaryArray and LargeBinaryArray, +/// or DictionaryArray with above Arrays as values /// * bit_length of null is null. /// * bit_length is in number of bits pub fn bit_length(array: &dyn Array) -> Result { match array.data_type() { - DataType::Utf8 => Ok(bit_length_impl::(array)), - DataType::LargeUtf8 => Ok(bit_length_impl::(array)), - _ => Err(ArrowError::ComputeError(format!( + DataType::Dictionary(kt, _) => { + kernel_dict!( + array, + |a| { bit_length(a) }, + kt, + Int8: Int8Type, + Int16: Int16Type, + Int32: Int32Type, + Int64: Int64Type, + UInt8: UInt8Type, + UInt16: UInt16Type, + UInt32: UInt32Type, + UInt64: UInt64Type + ) + } + DataType::Utf8 => Ok(bit_length_string::(array)), + DataType::LargeUtf8 => Ok(bit_length_string::(array)), + DataType::Binary => Ok(bit_length_binary::(array)), + DataType::LargeBinary => Ok(bit_length_binary::(array)), + other => Err(ArrowError::ComputeError(format!( "bit_length not supported for {:?}", - array.data_type() + other ))), } } @@ -133,11 +221,11 @@ pub fn bit_length(array: &dyn Array) -> Result { mod tests { use super::*; - fn length_cases() -> Vec<(Vec<&'static str>, usize, Vec)> { - fn double_vec(v: Vec) -> Vec { - [&v[..], &v[..]].concat() - } + fn double_vec(v: Vec) -> Vec { + [&v[..], &v[..]].concat() + } + fn length_cases_string() -> Vec<(Vec<&'static str>, usize, Vec)> { // a large array let mut values = vec!["one", "on", "o", ""]; let mut expected = vec![3, 2, 1, 0]; @@ -154,10 +242,35 @@ mod tests { ] } + macro_rules! length_binary_helper { + ($offset_ty: ty, $result_ty: ty, $kernel: ident, $value: expr, $expected: expr) => {{ + let array = GenericBinaryArray::<$offset_ty>::from($value); + let result = $kernel(&array)?; + let result = result.as_any().downcast_ref::<$result_ty>().unwrap(); + let expected: $result_ty = $expected.into(); + assert_eq!(expected.data(), result.data()); + Ok(()) + }}; + } + + macro_rules! 
length_list_helper { + ($offset_ty: ty, $result_ty: ty, $element_ty: ty, $value: expr, $expected: expr) => {{ + let array = + GenericListArray::<$offset_ty>::from_iter_primitive::<$element_ty, _, _>( + $value, + ); + let result = length(&array)?; + let result = result.as_any().downcast_ref::<$result_ty>().unwrap(); + let expected: $result_ty = $expected.into(); + assert_eq!(expected.data(), result.data()); + Ok(()) + }}; + } + #[test] #[cfg_attr(miri, ignore)] // running forever fn length_test_string() -> Result<()> { - length_cases() + length_cases_string() .into_iter() .try_for_each(|(input, len, expected)| { let array = StringArray::from(input); @@ -174,7 +287,7 @@ mod tests { #[test] #[cfg_attr(miri, ignore)] // running forever fn length_test_large_string() -> Result<()> { - length_cases() + length_cases_string() .into_iter() .try_for_each(|(input, len, expected)| { let array = LargeStringArray::from(input); @@ -188,7 +301,45 @@ mod tests { }) } - fn length_null_cases() -> Vec<(Vec>, usize, Vec>)> { + #[test] + fn length_test_binary() -> Result<()> { + let value: Vec<&[u8]> = vec![b"zero", b"one", &[0xff, 0xf8]]; + let result: Vec = vec![4, 3, 2]; + length_binary_helper!(i32, Int32Array, length, value, result) + } + + #[test] + fn length_test_large_binary() -> Result<()> { + let value: Vec<&[u8]> = vec![b"zero", &[0xff, 0xf8], b"two"]; + let result: Vec = vec![4, 2, 3]; + length_binary_helper!(i64, Int64Array, length, value, result) + } + + #[test] + fn length_test_list() -> Result<()> { + let value = vec![ + Some(vec![]), + Some(vec![Some(1), Some(2), Some(4)]), + Some(vec![Some(0)]), + ]; + let result: Vec = vec![0, 3, 1]; + length_list_helper!(i32, Int32Array, Int32Type, value, result) + } + + #[test] + fn length_test_large_list() -> Result<()> { + let value = vec![ + Some(vec![]), + Some(vec![Some(1.1), Some(2.2), Some(3.3)]), + Some(vec![None]), + ]; + let result: Vec = vec![0, 3, 1]; + length_list_helper!(i64, Int64Array, Float32Type, value, result) + } + + type OptionStr = Option<&'static str>; + + fn length_null_cases_string() -> Vec<(Vec, usize, Vec>)> { vec![( vec![Some("one"), None, Some("three"), Some("four")], 4, @@ -198,7 +349,7 @@ mod tests { #[test] fn length_null_string() -> Result<()> { - length_null_cases() + length_null_cases_string() .into_iter() .try_for_each(|(input, len, expected)| { let array = StringArray::from(input); @@ -214,7 +365,7 @@ mod tests { #[test] fn length_null_large_string() -> Result<()> { - length_null_cases() + length_null_cases_string() .into_iter() .try_for_each(|(input, len, expected)| { let array = LargeStringArray::from(input); @@ -233,6 +384,46 @@ mod tests { }) } + #[test] + fn length_null_binary() -> Result<()> { + let value: Vec> = + vec![Some(b"zero"), None, Some(&[0xff, 0xf8]), Some(b"three")]; + let result: Vec> = vec![Some(4), None, Some(2), Some(5)]; + length_binary_helper!(i32, Int32Array, length, value, result) + } + + #[test] + fn length_null_large_binary() -> Result<()> { + let value: Vec> = + vec![Some(&[0xff, 0xf8]), None, Some(b"two"), Some(b"three")]; + let result: Vec> = vec![Some(2), None, Some(3), Some(5)]; + length_binary_helper!(i64, Int64Array, length, value, result) + } + + #[test] + fn length_null_list() -> Result<()> { + let value = vec![ + Some(vec![]), + None, + Some(vec![Some(1), None, Some(2), Some(4)]), + Some(vec![Some(0)]), + ]; + let result: Vec> = vec![Some(0), None, Some(4), Some(1)]; + length_list_helper!(i32, Int32Array, Int8Type, value, result) + } + + #[test] + fn length_null_large_list() -> 
Result<()> { + let value = vec![ + Some(vec![]), + None, + Some(vec![Some(1.1), None, Some(4.0)]), + Some(vec![Some(0.1)]), + ]; + let result: Vec> = vec![Some(0), None, Some(3), Some(1)]; + length_list_helper!(i64, Int64Array, Float32Type, value, result) + } + /// Tests that length is not valid for u64. #[test] fn length_wrong_type() { @@ -243,7 +434,7 @@ mod tests { /// Tests with an offset #[test] - fn length_offsets() -> Result<()> { + fn length_offsets_string() -> Result<()> { let a = StringArray::from(vec![Some("hello"), Some(" "), Some("world"), None]); let b = a.slice(1, 3); let result = length(b.as_ref())?; @@ -255,11 +446,22 @@ mod tests { Ok(()) } - fn bit_length_cases() -> Vec<(Vec<&'static str>, usize, Vec)> { - fn double_vec(v: Vec) -> Vec { - [&v[..], &v[..]].concat() - } + #[test] + fn length_offsets_binary() -> Result<()> { + let value: Vec> = + vec![Some(b"hello"), Some(b" "), Some(&[0xff, 0xf8]), None]; + let a = BinaryArray::from(value); + let b = a.slice(1, 3); + let result = length(b.as_ref())?; + let result: &Int32Array = as_primitive_array(&result); + let expected = Int32Array::from(vec![Some(1), Some(2), None]); + assert_eq!(&expected, result); + + Ok(()) + } + + fn bit_length_cases() -> Vec<(Vec<&'static str>, usize, Vec)> { // a large array let mut values = vec!["one", "on", "o", ""]; let mut expected = vec![24, 16, 8, 0]; @@ -310,8 +512,21 @@ mod tests { }) } - fn bit_length_null_cases() -> Vec<(Vec>, usize, Vec>)> - { + #[test] + fn bit_length_binary() -> Result<()> { + let value: Vec<&[u8]> = vec![b"one", &[0xff, 0xf8], b"three"]; + let expected: Vec = vec![24, 16, 40]; + length_binary_helper!(i32, Int32Array, bit_length, value, expected) + } + + #[test] + fn bit_length_large_binary() -> Result<()> { + let value: Vec<&[u8]> = vec![b"zero", b" ", &[0xff, 0xf8]]; + let expected: Vec = vec![32, 8, 16]; + length_binary_helper!(i64, Int64Array, bit_length, value, expected) + } + + fn bit_length_null_cases() -> Vec<(Vec, usize, Vec>)> { vec![( vec![Some("one"), None, Some("three"), Some("four")], 4, @@ -356,6 +571,22 @@ mod tests { }) } + #[test] + fn bit_length_null_binary() -> Result<()> { + let value: Vec> = + vec![Some(b"one"), None, Some(b"three"), Some(&[0xff, 0xf8])]; + let expected: Vec> = vec![Some(24), None, Some(40), Some(16)]; + length_binary_helper!(i32, Int32Array, bit_length, value, expected) + } + + #[test] + fn bit_length_null_large_binary() -> Result<()> { + let value: Vec> = + vec![Some(b"one"), None, Some(&[0xff, 0xf8]), Some(b"four")]; + let expected: Vec> = vec![Some(24), None, Some(16), Some(32)]; + length_binary_helper!(i64, Int64Array, bit_length, value, expected) + } + /// Tests that bit_length is not valid for u64. 
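// Illustrative sketch, not part of the patch: expected behaviour of the widened
// `length`/`bit_length` kernels on a plain StringArray, mirroring the test cases
// above; written with the public `arrow` crate paths and made-up values.
use arrow::array::{Int32Array, StringArray};
use arrow::compute::kernels::length::{bit_length, length};

fn length_usage_sketch() {
    let array = StringArray::from(vec![Some("arrow"), None, Some("")]);

    let bytes = length(&array).unwrap();
    let bytes = bytes.as_any().downcast_ref::<Int32Array>().unwrap();
    assert_eq!(bytes, &Int32Array::from(vec![Some(5), None, Some(0)]));

    let bits = bit_length(&array).unwrap();
    let bits = bits.as_any().downcast_ref::<Int32Array>().unwrap();
    assert_eq!(bits, &Int32Array::from(vec![Some(40), None, Some(0)]));
}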
#[test] fn bit_length_wrong_type() { @@ -366,7 +597,7 @@ mod tests { /// Tests with an offset #[test] - fn bit_length_offsets() -> Result<()> { + fn bit_length_offsets_string() -> Result<()> { let a = StringArray::from(vec![Some("hello"), Some(" "), Some("world"), None]); let b = a.slice(1, 3); let result = bit_length(b.as_ref())?; @@ -377,4 +608,121 @@ mod tests { Ok(()) } + + #[test] + fn bit_length_offsets_binary() -> Result<()> { + let value: Vec> = + vec![Some(b"hello"), Some(&[]), Some(b"world"), None]; + let a = BinaryArray::from(value); + let b = a.slice(1, 3); + let result = bit_length(b.as_ref())?; + let result: &Int32Array = as_primitive_array(&result); + + let expected = Int32Array::from(vec![Some(0), Some(40), None]); + assert_eq!(&expected, result); + + Ok(()) + } + + #[test] + fn length_dictionary() -> Result<()> { + _length_dictionary::()?; + _length_dictionary::()?; + _length_dictionary::()?; + _length_dictionary::()?; + _length_dictionary::()?; + _length_dictionary::()?; + _length_dictionary::()?; + _length_dictionary::()?; + Ok(()) + } + + fn _length_dictionary() -> Result<()> { + const TOTAL: i32 = 100; + + let v = ["aaaa", "bb", "ccccc", "ddd", "eeeeee"]; + let data: Vec> = (0..TOTAL) + .map(|n| { + let i = n % 5; + if i == 3 { + None + } else { + Some(v[i as usize]) + } + }) + .collect(); + + let dict_array: DictionaryArray = data.clone().into_iter().collect(); + + let expected: Vec> = + data.iter().map(|opt| opt.map(|s| s.len() as i32)).collect(); + + let res = length(&dict_array)?; + let actual = res.as_any().downcast_ref::>().unwrap(); + let actual: Vec> = actual + .values() + .as_any() + .downcast_ref::() + .unwrap() + .take_iter(dict_array.keys_iter()) + .collect(); + + for i in 0..TOTAL as usize { + assert_eq!(expected[i], actual[i],); + } + + Ok(()) + } + + #[test] + fn bit_length_dictionary() -> Result<()> { + _bit_length_dictionary::()?; + _bit_length_dictionary::()?; + _bit_length_dictionary::()?; + _bit_length_dictionary::()?; + _bit_length_dictionary::()?; + _bit_length_dictionary::()?; + _bit_length_dictionary::()?; + _bit_length_dictionary::()?; + Ok(()) + } + + fn _bit_length_dictionary() -> Result<()> { + const TOTAL: i32 = 100; + + let v = ["aaaa", "bb", "ccccc", "ddd", "eeeeee"]; + let data: Vec> = (0..TOTAL) + .map(|n| { + let i = n % 5; + if i == 3 { + None + } else { + Some(v[i as usize]) + } + }) + .collect(); + + let dict_array: DictionaryArray = data.clone().into_iter().collect(); + + let expected: Vec> = data + .iter() + .map(|opt| opt.map(|s| (s.chars().count() * 8) as i32)) + .collect(); + + let res = bit_length(&dict_array)?; + let actual = res.as_any().downcast_ref::>().unwrap(); + let actual: Vec> = actual + .values() + .as_any() + .downcast_ref::() + .unwrap() + .take_iter(dict_array.keys_iter()) + .collect(); + + for i in 0..TOTAL as usize { + assert_eq!(expected[i], actual[i],); + } + + Ok(()) + } } diff --git a/arrow/src/compute/kernels/limit.rs b/arrow/src/compute/kernels/limit.rs index 34f5dcafc3d0..07cf727b09d4 100644 --- a/arrow/src/compute/kernels/limit.rs +++ b/arrow/src/compute/kernels/limit.rs @@ -113,7 +113,7 @@ mod tests { .len(9) .add_buffer(value_offsets) .add_child_data(value_data) - .null_bit_buffer(Buffer::from(null_bits)) + .null_bit_buffer(Some(Buffer::from(null_bits))) .build() .unwrap(); let list_array: ArrayRef = Arc::new(ListArray::from(list_data)); @@ -145,24 +145,25 @@ mod tests { let boolean_data = ArrayData::builder(DataType::Boolean) .len(5) .add_buffer(Buffer::from([0b00010000])) - 
.null_bit_buffer(Buffer::from([0b00010001])) + .null_bit_buffer(Some(Buffer::from([0b00010001]))) .build() .unwrap(); let int_data = ArrayData::builder(DataType::Int32) .len(5) .add_buffer(Buffer::from_slice_ref(&[0, 28, 42, 0, 0])) - .null_bit_buffer(Buffer::from([0b00000110])) + .null_bit_buffer(Some(Buffer::from([0b00000110]))) .build() .unwrap(); - let mut field_types = vec![]; - field_types.push(Field::new("a", DataType::Boolean, false)); - field_types.push(Field::new("b", DataType::Int32, false)); + let field_types = vec![ + Field::new("a", DataType::Boolean, false), + Field::new("b", DataType::Int32, false), + ]; let struct_array_data = ArrayData::builder(DataType::Struct(field_types)) .len(5) .add_child_data(boolean_data.clone()) .add_child_data(int_data.clone()) - .null_bit_buffer(Buffer::from([0b00010111])) + .null_bit_buffer(Some(Buffer::from([0b00010111]))) .build() .unwrap(); let struct_array = StructArray::from(struct_array_data); diff --git a/arrow/src/compute/kernels/mod.rs b/arrow/src/compute/kernels/mod.rs index a0ef50a7b85a..c615d3a55e1a 100644 --- a/arrow/src/compute/kernels/mod.rs +++ b/arrow/src/compute/kernels/mod.rs @@ -25,6 +25,7 @@ pub mod cast; pub mod cast_utils; pub mod comparison; pub mod concat; +pub mod concat_elements; pub mod filter; pub mod length; pub mod limit; diff --git a/arrow/src/compute/kernels/partition.rs b/arrow/src/compute/kernels/partition.rs index 8a16942220e9..e3a1497b8d27 100644 --- a/arrow/src/compute/kernels/partition.rs +++ b/arrow/src/compute/kernels/partition.rs @@ -44,7 +44,6 @@ struct LexicographicalPartitionIterator<'a> { num_rows: usize, previous_partition_point: usize, partition_point: usize, - value_indices: Vec, } impl<'a> LexicographicalPartitionIterator<'a> { @@ -62,39 +61,78 @@ impl<'a> LexicographicalPartitionIterator<'a> { }; let comparator = LexicographicalComparator::try_new(columns)?; - let value_indices = (0..num_rows).collect::>(); Ok(LexicographicalPartitionIterator { comparator, num_rows, previous_partition_point: 0, partition_point: 0, - value_indices, }) } } -/// Exponential search is to remedy for the case when array size and cardinality are both large +/// Returns the next partition point of the range `start..end` according to the given comparator. +/// The return value is the index of the first element of the second partition, +/// and is guaranteed to be between `start..=end` (inclusive). +/// +/// The values corresponding to those indices are assumed to be partitioned according to the given comparator. +/// +/// Exponential search is to remedy for the case when array size and cardinality are both large. +/// In these cases the partition point would be near the beginning of the range and +/// plain binary search would be doing some unnecessary iterations on each call. 
+/// /// see #[inline] -fn exponential_search( - indices: &[usize], - target: &usize, +fn exponential_search_next_partition_point( + start: usize, + end: usize, comparator: &LexicographicalComparator<'_>, ) -> usize { + let target = start; let mut bound = 1; - while bound < indices.len() - && comparator.compare(&indices[bound], target) != Ordering::Greater + while bound + start < end + && comparator.compare(&(bound + start), &target) != Ordering::Greater { bound *= 2; } + // invariant after while loop: - // indices[bound / 2] <= target < indices[min(indices.len(), bound + 1)] + // (start + bound / 2) <= target < min(end, start + bound + 1) // where <= and < are defined by the comparator; - // note here we have right = min(indices.len(), bound + 1) because indices[bound] might + // note here we have right = min(end, start + bound + 1) because (start + bound) might // actually be considered and must be included. - (bound / 2) - + indices[(bound / 2)..indices.len().min(bound + 1)] - .partition_point(|idx| comparator.compare(idx, target) != Ordering::Greater) + partition_point(start + bound / 2, end.min(start + bound + 1), |idx| { + comparator.compare(&idx, &target) != Ordering::Greater + }) +} + +/// Returns the partition point of the range `start..end` according to the given predicate. +/// The return value is the index of the first element of the second partition, +/// and is guaranteed to be between `start..=end` (inclusive). +/// +/// The algorithm is similar to a binary search. +/// +/// The values corresponding to those indices are assumed to be partitioned according to the given predicate. +/// +/// See [`slice::partition_point`] +#[inline] +fn partition_point bool>(start: usize, end: usize, pred: P) -> usize { + let mut left = start; + let mut right = end; + let mut size = right - left; + while left < right { + let mid = left + size / 2; + + let less = pred(mid); + + if less { + left = mid + 1; + } else { + right = mid; + } + + size = right - left; + } + left } impl<'a> Iterator for LexicographicalPartitionIterator<'a> { @@ -103,17 +141,12 @@ impl<'a> Iterator for LexicographicalPartitionIterator<'a> { fn next(&mut self) -> Option { if self.partition_point < self.num_rows { // invariant: - // value_indices[0..previous_partition_point] all are values <= value_indices[previous_partition_point] - // so in order to save time we can do binary search on the value_indices[previous_partition_point..] - // and find when any value is greater than value_indices[previous_partition_point]; because we are using - // new indices, the new offset is _added_ to the previous_partition_point. - // - // be careful that idx is of type &usize which points to the actual value within value_indices, which itself - // contains usize (0..row_count), providing access to lexicographical_comparator as pointers into the - // original columnar data. 
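// Illustrative sketch, not part of the patch: the index-range `partition_point`
// above mirrors std's `slice::partition_point`, just phrased over row indices so
// the lexicographical comparator can look values up in columnar data.
fn partition_point_sketch() {
    let input = [1, 1, 2, 2, 2, 3];
    // first index whose value is greater than 1
    assert_eq!(input.partition_point(|&v| v <= 1), 2);
    // the same answer from an index-based predicate over 0..input.len()
    let by_index = (0..input.len())
        .find(|&i| input[i] > 1)
        .unwrap_or(input.len());
    assert_eq!(by_index, 2);
}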
- self.partition_point += exponential_search( - &self.value_indices[self.partition_point..], - &self.partition_point, + // in the range [0..previous_partition_point] all values are <= the value at [previous_partition_point] + // so in order to save time we can do binary search on the range [previous_partition_point..num_rows] + // and find the index where any value is greater than the value at [previous_partition_point] + self.partition_point = exponential_search_next_partition_point( + self.partition_point, + self.num_rows, &self.comparator, ); let start = self.previous_partition_point; @@ -134,6 +167,56 @@ mod tests { use crate::datatypes::DataType; use std::sync::Arc; + #[test] + fn test_partition_point() { + let input = &[1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 4]; + { + let median = input[input.len() / 2]; + assert_eq!( + 9, + partition_point( + 0, + input.len(), + &(|i: usize| input[i].cmp(&median) != Ordering::Greater) + ) + ); + } + { + let search = input[9]; + assert_eq!( + 12, + partition_point( + 9, + input.len(), + &(|i: usize| input[i].cmp(&search) != Ordering::Greater) + ) + ); + } + { + let search = input[0]; + assert_eq!( + 3, + partition_point( + 0, + 9, + &(|i: usize| input[i].cmp(&search) != Ordering::Greater) + ) + ); + } + let input = &[1, 2, 2, 2, 2, 2, 2, 2, 9]; + { + let search = input[5]; + assert_eq!( + 8, + partition_point( + 5, + 9, + &(|i: usize| input[i].cmp(&search) != Ordering::Greater) + ) + ); + } + } + #[test] fn test_lexicographical_partition_ranges_empty() { let input = vec![]; diff --git a/arrow/src/compute/kernels/regexp.rs b/arrow/src/compute/kernels/regexp.rs index 5093dcede8fe..081a6e193bda 100644 --- a/arrow/src/compute/kernels/regexp.rs +++ b/arrow/src/compute/kernels/regexp.rs @@ -19,8 +19,7 @@ //! expression of a \[Large\]StringArray use crate::array::{ - ArrayRef, GenericStringArray, GenericStringBuilder, ListBuilder, - StringOffsetSizeTrait, + ArrayRef, GenericStringArray, GenericStringBuilder, ListBuilder, OffsetSizeTrait, }; use crate::error::{ArrowError, Result}; use std::collections::HashMap; @@ -30,7 +29,7 @@ use std::sync::Arc; use regex::Regex; /// Extract all groups matched by a regular expression for a given String array. 
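// Illustrative sketch, not part of the patch: typical downstream use of `regexp_match`;
// each output row is a list holding the capture groups matched for that input row.
// The paths and sample values here are assumptions, not taken from the diff.
use arrow::array::{ListArray, StringArray};
use arrow::compute::kernels::regexp::regexp_match;

fn regexp_match_sketch() {
    let values = StringArray::from(vec![Some("abc-005-def"), Some("x-7-5")]);
    let patterns = StringArray::from(vec![r".*-(\d*)-.*"; 2]);
    let matched = regexp_match(&values, &patterns, None).unwrap();
    let matched = matched.as_any().downcast_ref::<ListArray>().unwrap();
    assert_eq!(matched.len(), 2);
}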
-pub fn regexp_match( +pub fn regexp_match( array: &GenericStringArray, regex_array: &GenericStringArray, flags_array: Option<&GenericStringArray>, diff --git a/arrow/src/compute/kernels/sort.rs b/arrow/src/compute/kernels/sort.rs index 88c7785bc985..e399cf9f0c19 100644 --- a/arrow/src/compute/kernels/sort.rs +++ b/arrow/src/compute/kernels/sort.rs @@ -170,6 +170,7 @@ pub fn sort_to_indices( let (v, n) = partition_validity(values); Ok(match values.data_type() { + DataType::Decimal(_, _) => sort_decimal(values, v, n, cmp, &options, limit), DataType::Boolean => sort_boolean(values, v, n, &options, limit), DataType::Int8 => { sort_primitive::(values, v, n, cmp, &options, limit) @@ -243,6 +244,11 @@ pub fn sort_to_indices( DataType::Interval(IntervalUnit::DayTime) => { sort_primitive::(values, v, n, cmp, &options, limit) } + DataType::Interval(IntervalUnit::MonthDayNano) => { + sort_primitive::( + values, v, n, cmp, &options, limit, + ) + } DataType::Duration(TimeUnit::Second) => { sort_primitive::(values, v, n, cmp, &options, limit) } @@ -263,7 +269,9 @@ pub fn sort_to_indices( } DataType::Utf8 => sort_string::(values, v, n, &options, limit), DataType::LargeUtf8 => sort_string::(values, v, n, &options, limit), - DataType::List(field) => match field.data_type() { + DataType::List(field) | DataType::FixedSizeList(field, _) => match field + .data_type() + { DataType::Int8 => sort_list::(values, v, n, &options, limit), DataType::Int16 => sort_list::(values, v, n, &options, limit), DataType::Int32 => sort_list::(values, v, n, &options, limit), @@ -288,7 +296,7 @@ pub fn sort_to_indices( return Err(ArrowError::ComputeError(format!( "Sort not supported for list type {:?}", t - ))) + ))); } }, DataType::LargeList(field) => match field.data_type() { @@ -316,35 +324,7 @@ pub fn sort_to_indices( return Err(ArrowError::ComputeError(format!( "Sort not supported for list type {:?}", t - ))) - } - }, - DataType::FixedSizeList(field, _) => match field.data_type() { - DataType::Int8 => sort_list::(values, v, n, &options, limit), - DataType::Int16 => sort_list::(values, v, n, &options, limit), - DataType::Int32 => sort_list::(values, v, n, &options, limit), - DataType::Int64 => sort_list::(values, v, n, &options, limit), - DataType::UInt8 => sort_list::(values, v, n, &options, limit), - DataType::UInt16 => { - sort_list::(values, v, n, &options, limit) - } - DataType::UInt32 => { - sort_list::(values, v, n, &options, limit) - } - DataType::UInt64 => { - sort_list::(values, v, n, &options, limit) - } - DataType::Float32 => { - sort_list::(values, v, n, &options, limit) - } - DataType::Float64 => { - sort_list::(values, v, n, &options, limit) - } - t => { - return Err(ArrowError::ComputeError(format!( - "Sort not supported for list type {:?}", - t - ))) + ))); } }, DataType::Dictionary(key_type, value_type) @@ -379,7 +359,7 @@ pub fn sort_to_indices( return Err(ArrowError::ComputeError(format!( "Sort not supported for dictionary key type {:?}", t - ))) + ))); } } } @@ -391,7 +371,7 @@ pub fn sort_to_indices( return Err(ArrowError::ComputeError(format!( "Sort not supported for data type {:?}", t - ))) + ))); } }) } @@ -420,8 +400,9 @@ impl Default for SortOptions { /// when a limit is present, the sort is pair-comparison based as k-select might be more efficient, /// when the limit is absent, binary partition is used to speed up (which is linear). /// -/// TODO maybe partition_validity call can be eliminated in this case and tri-color sort can be used -/// instead. 
https://en.wikipedia.org/wiki/Dutch_national_flag_problem +/// TODO maybe partition_validity call can be eliminated in this case +/// and [tri-color sort](https://en.wikipedia.org/wiki/Dutch_national_flag_problem) +/// can be used instead. fn sort_boolean( values: &ArrayRef, value_indices: Vec, @@ -453,7 +434,7 @@ fn sort_boolean( // when limit is not present, we have a better way than sorting: we can just partition // the vec into [false..., true...] or [true..., false...] when descending // TODO when https://github.com/rust-lang/rust/issues/62543 is merged we can use partition_in_place - let (mut a, b): (Vec<(u32, bool)>, Vec<(u32, bool)>) = value_indices + let (mut a, b): (Vec<_>, Vec<_>) = value_indices .into_iter() .map(|index| (index, values.value(index as usize))) .partition(|(_, value)| *value == descending); @@ -503,6 +484,30 @@ fn sort_boolean( UInt32Array::from(result_data) } +/// Sort Decimal array +fn sort_decimal( + decimal_values: &ArrayRef, + value_indices: Vec, + null_indices: Vec, + cmp: F, + options: &SortOptions, + limit: Option, +) -> UInt32Array +where + F: Fn(i128, i128) -> std::cmp::Ordering, +{ + // downcast to decimal array + let decimal_array = decimal_values + .as_any() + .downcast_ref::() + .expect("Unable to downcast to decimal array"); + let valids = value_indices + .into_iter() + .map(|index| (index, decimal_array.value(index as usize).as_i128())) + .collect::>(); + sort_primitive_inner(decimal_values, null_indices, cmp, options, limit, valids) +} + /// Sort primitive values fn sort_primitive( values: &ArrayRef, @@ -607,7 +612,7 @@ fn insert_valid_values(result_slice: &mut [u32], offset: usize, valids: &[(u3 } /// Sort strings -fn sort_string( +fn sort_string( values: &ArrayRef, value_indices: Vec, null_indices: Vec, @@ -772,7 +777,7 @@ fn sort_binary( limit: Option, ) -> UInt32Array where - S: BinaryOffsetSizeTrait, + S: OffsetSizeTrait, { let mut valids: Vec<(u32, &[u8])> = values .as_any() @@ -816,7 +821,7 @@ where } } -/// Compare two `Array`s based on the ordering defined in [ord](crate::array::ord). 
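// Illustrative sketch, not part of the patch: `build_compare` (used by `cmp_array`
// below) returns a boxed closure comparing row i of the left array with row j of
// the right array. The `arrow::array::build_compare` re-export is an assumption.
use arrow::array::{build_compare, Int32Array};
use std::cmp::Ordering;

fn build_compare_sketch() {
    let left = Int32Array::from(vec![1, 5]);
    let right = Int32Array::from(vec![3]);
    let cmp = build_compare(&left, &right).unwrap();
    assert_eq!(cmp(0, 0), Ordering::Less); // 1 < 3
    assert_eq!(cmp(1, 0), Ordering::Greater); // 5 > 3
}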
+/// Compare two `Array`s based on the ordering defined in [build_compare] fn cmp_array(a: &dyn Array, b: &dyn Array) -> Ordering { let cmp_op = build_compare(a, b).unwrap(); let length = a.len().max(b.len()); @@ -1032,13 +1037,11 @@ fn sort_valids( ) where T: ?Sized + Copy, { - let nulls_len = nulls.len(); + let valids_len = valids.len(); if !descending { - sort_unstable_by(valids, len.saturating_sub(nulls_len), |a, b| cmp(a.1, b.1)); + sort_unstable_by(valids, len.min(valids_len), |a, b| cmp(a.1, b.1)); } else { - sort_unstable_by(valids, len.saturating_sub(nulls_len), |a, b| { - cmp(a.1, b.1).reverse() - }); + sort_unstable_by(valids, len.min(valids_len), |a, b| cmp(a.1, b.1).reverse()); // reverse to keep a stable ordering nulls.reverse(); } @@ -1050,13 +1053,13 @@ fn sort_valids_array( nulls: &mut [T], len: usize, ) { - let nulls_len = nulls.len(); + let valids_len = valids.len(); if !descending { - sort_unstable_by(valids, len.saturating_sub(nulls_len), |a, b| { + sort_unstable_by(valids, len.min(valids_len), |a, b| { cmp_array(a.1.as_ref(), b.1.as_ref()) }); } else { - sort_unstable_by(valids, len.saturating_sub(nulls_len), |a, b| { + sort_unstable_by(valids, len.min(valids_len), |a, b| { cmp_array(a.1.as_ref(), b.1.as_ref()).reverse() }); // reverse to keep a stable ordering @@ -1075,6 +1078,43 @@ mod tests { use std::convert::TryFrom; use std::sync::Arc; + fn create_decimal_array(data: &[Option]) -> DecimalArray { + data.iter() + .collect::() + .with_precision_and_scale(23, 6) + .unwrap() + } + + fn test_sort_to_indices_decimal_array( + data: Vec>, + options: Option, + limit: Option, + expected_data: Vec, + ) { + let output = create_decimal_array(&data); + let expected = UInt32Array::from(expected_data); + let output = + sort_to_indices(&(Arc::new(output) as ArrayRef), options, limit).unwrap(); + assert_eq!(output, expected) + } + + fn test_sort_decimal_array( + data: Vec>, + options: Option, + limit: Option, + expected_data: Vec>, + ) { + let output = create_decimal_array(&data); + let expected = Arc::new(create_decimal_array(&expected_data)) as ArrayRef; + let output = match limit { + Some(_) => { + sort_limit(&(Arc::new(output) as ArrayRef), options, limit).unwrap() + } + _ => sort(&(Arc::new(output) as ArrayRef), options).unwrap(), + }; + assert_eq!(&output, &expected) + } + fn test_sort_to_indices_boolean_arrays( data: Vec>, options: Option, @@ -1301,7 +1341,7 @@ mod tests { } // Generic size binary array - fn make_generic_binary_array( + fn make_generic_binary_array( data: &[Option>], ) -> Arc> { Arc::new(GenericBinaryArray::::from_opt_vec( @@ -1555,6 +1595,19 @@ mod tests { ); } + #[test] + fn test_sort_to_indices_primitive_more_nulls_than_limit() { + test_sort_to_indices_primitive_arrays::( + vec![None, None, Some(3), None, Some(1), None, Some(2)], + Some(SortOptions { + descending: false, + nulls_first: false, + }), + Some(2), + vec![4, 6], + ); + } + #[test] fn test_sort_boolean() { // boolean @@ -1641,6 +1694,162 @@ mod tests { ); } + #[test] + fn test_sort_indices_decimal128() { + // decimal default + test_sort_to_indices_decimal_array( + vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], + None, + None, + vec![0, 6, 4, 2, 3, 5, 1], + ); + // decimal descending + test_sort_to_indices_decimal_array( + vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], + Some(SortOptions { + descending: true, + nulls_first: false, + }), + None, + vec![1, 5, 3, 2, 4, 6, 0], + ); + // decimal null_first and descending + test_sort_to_indices_decimal_array( + 
vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], + Some(SortOptions { + descending: true, + nulls_first: true, + }), + None, + vec![6, 0, 1, 5, 3, 2, 4], + ); + // decimal null_first + test_sort_to_indices_decimal_array( + vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], + Some(SortOptions { + descending: false, + nulls_first: true, + }), + None, + vec![0, 6, 4, 2, 3, 5, 1], + ); + // limit + test_sort_to_indices_decimal_array( + vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], + None, + Some(3), + vec![0, 6, 4], + ); + // limit descending + test_sort_to_indices_decimal_array( + vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], + Some(SortOptions { + descending: true, + nulls_first: false, + }), + Some(3), + vec![1, 5, 3], + ); + // limit descending null_first + test_sort_to_indices_decimal_array( + vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], + Some(SortOptions { + descending: true, + nulls_first: true, + }), + Some(3), + vec![6, 0, 1], + ); + // limit null_first + test_sort_to_indices_decimal_array( + vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], + Some(SortOptions { + descending: false, + nulls_first: true, + }), + Some(3), + vec![0, 6, 4], + ); + } + + #[test] + fn test_sort_decimal128() { + // decimal default + test_sort_decimal_array( + vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], + None, + None, + vec![None, None, Some(1), Some(2), Some(3), Some(4), Some(5)], + ); + // decimal descending + test_sort_decimal_array( + vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], + Some(SortOptions { + descending: true, + nulls_first: false, + }), + None, + vec![Some(5), Some(4), Some(3), Some(2), Some(1), None, None], + ); + // decimal null_first and descending + test_sort_decimal_array( + vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], + Some(SortOptions { + descending: true, + nulls_first: true, + }), + None, + vec![None, None, Some(5), Some(4), Some(3), Some(2), Some(1)], + ); + // decimal null_first + test_sort_decimal_array( + vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], + Some(SortOptions { + descending: false, + nulls_first: true, + }), + None, + vec![None, None, Some(1), Some(2), Some(3), Some(4), Some(5)], + ); + // limit + test_sort_decimal_array( + vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], + None, + Some(3), + vec![None, None, Some(1)], + ); + // limit descending + test_sort_decimal_array( + vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], + Some(SortOptions { + descending: true, + nulls_first: false, + }), + Some(3), + vec![Some(5), Some(4), Some(3)], + ); + // limit descending null_first + test_sort_decimal_array( + vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], + Some(SortOptions { + descending: true, + nulls_first: true, + }), + Some(3), + vec![None, None, Some(5)], + ); + // limit null_first + test_sort_decimal_array( + vec![None, Some(5), Some(2), Some(3), Some(1), Some(4), None], + Some(SortOptions { + descending: false, + nulls_first: true, + }), + Some(3), + vec![None, None, Some(1)], + ); + } + #[test] fn test_sort_primitives() { // default case diff --git a/arrow/src/compute/kernels/substring.rs b/arrow/src/compute/kernels/substring.rs index 01fdf640bdae..024f5633fef4 100644 --- a/arrow/src/compute/kernels/substring.rs +++ b/arrow/src/compute/kernels/substring.rs @@ -15,261 +15,1022 @@ // specific language governing permissions and limitations // under the License. -//! 
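// Illustrative sketch, not part of the patch: the decimal test helpers above follow
// the same call pattern as sorting any other array type; a minimal run with explicit
// SortOptions, assuming the public `arrow` crate paths and made-up values.
use arrow::array::{ArrayRef, Int32Array, UInt32Array};
use arrow::compute::kernels::sort::{sort_to_indices, SortOptions};
use std::sync::Arc;

fn sort_to_indices_sketch() {
    let array: ArrayRef = Arc::new(Int32Array::from(vec![None, Some(3), Some(1)]));
    let options = Some(SortOptions {
        descending: false,
        nulls_first: true,
    });
    // null first, then 1 (row 2), then 3 (row 1)
    let indices = sort_to_indices(&array, options, None).unwrap();
    assert_eq!(indices, UInt32Array::from(vec![0u32, 2, 1]));
}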
Defines kernel to extract a substring of a \[Large\]StringArray +//! Defines kernel to extract a substring of an Array +//! Supported array types: +//! [GenericStringArray], [GenericBinaryArray], [FixedSizeBinaryArray], [DictionaryArray] +use crate::array::DictionaryArray; +use crate::buffer::MutableBuffer; +use crate::datatypes::*; use crate::{array::*, buffer::Buffer}; use crate::{ datatypes::DataType, error::{ArrowError, Result}, }; +use std::cmp::Ordering; +use std::sync::Arc; -#[allow(clippy::unnecessary_wraps)] -fn generic_substring( - array: &GenericStringArray, - start: OffsetSize, - length: &Option, -) -> Result { - // compute current offsets - let offsets = array.data_ref().clone().buffers()[0].clone(); - let offsets: &[OffsetSize] = unsafe { offsets.typed_data::() }; - - // compute null bitmap (copy) - let null_bit_buffer = array.data_ref().null_buffer().cloned(); +/// Returns an [`ArrayRef`] with substrings of all the elements in `array`. +/// +/// # Arguments +/// +/// * `start` - The start index of all substrings. +/// If `start >= 0`, then count from the start of the string, +/// otherwise count from the end of the string. +/// +/// * `length`(option) - The length of all substrings. +/// If `length` is [None], then the substring is from `start` to the end of the string. +/// +/// Attention: Both `start` and `length` are counted by byte, not by char. +/// +/// # Basic usage +/// ``` +/// # use arrow::array::StringArray; +/// # use arrow::compute::kernels::substring::substring; +/// let array = StringArray::from(vec![Some("arrow"), None, Some("rust")]); +/// let result = substring(&array, 1, Some(4)).unwrap(); +/// let result = result.as_any().downcast_ref::().unwrap(); +/// assert_eq!(result, &StringArray::from(vec![Some("rrow"), None, Some("ust")])); +/// ``` +/// +/// # Error +/// - The function errors when the passed array is not a [`GenericStringArray`], [`GenericBinaryArray`], [`FixedSizeBinaryArray`] +/// or [`DictionaryArray`] with supported array type as its value type. +/// - The function errors if the offset of a substring in the input array is at invalid char boundary (only for \[Large\]String array). +/// It is recommended to use [`substring_by_char`] if the input array may contain non-ASCII chars. +/// +/// ## Example of trying to get an invalid utf-8 format substring +/// ``` +/// # use arrow::array::StringArray; +/// # use arrow::compute::kernels::substring::substring; +/// let array = StringArray::from(vec![Some("E=mc²")]); +/// let error = substring(&array, 0, Some(5)).unwrap_err().to_string(); +/// assert!(error.contains("invalid utf-8 boundary")); +/// ``` +pub fn substring(array: &dyn Array, start: i64, length: Option) -> Result { + macro_rules! 
substring_dict { + ($kt: ident, $($t: ident: $gt: ident), *) => { + match $kt.as_ref() { + $( + &DataType::$t => { + let dict = array + .as_any() + .downcast_ref::>() + .unwrap_or_else(|| { + panic!("Expect 'DictionaryArray<{}>' but got array of data type {:?}", + stringify!($gt), array.data_type()) + }); + let values = substring(dict.values(), start, length)?; + let result = DictionaryArray::try_new(dict.keys(), &values)?; + Ok(Arc::new(result)) + }, + )* + t => panic!("Unsupported dictionary key type: {}", t) + } + } + } - // compute values - let values = &array.data_ref().buffers()[1]; - let data = values.as_slice(); + match array.data_type() { + DataType::Dictionary(kt, _) => { + substring_dict!( + kt, + Int8: Int8Type, + Int16: Int16Type, + Int32: Int32Type, + Int64: Int64Type, + UInt8: UInt8Type, + UInt16: UInt16Type, + UInt32: UInt32Type, + UInt64: UInt64Type + ) + } + DataType::LargeBinary => binary_substring( + array + .as_any() + .downcast_ref::() + .expect("A large binary is expected"), + start, + length.map(|e| e as i64), + ), + DataType::Binary => binary_substring( + array + .as_any() + .downcast_ref::() + .expect("A binary is expected"), + start as i32, + length.map(|e| e as i32), + ), + DataType::FixedSizeBinary(old_len) => fixed_size_binary_substring( + array + .as_any() + .downcast_ref::() + .expect("a fixed size binary is expected"), + *old_len, + start as i32, + length.map(|e| e as i32), + ), + DataType::LargeUtf8 => utf8_substring( + array + .as_any() + .downcast_ref::() + .expect("A large string is expected"), + start, + length.map(|e| e as i64), + ), + DataType::Utf8 => utf8_substring( + array + .as_any() + .downcast_ref::() + .expect("A string is expected"), + start as i32, + length.map(|e| e as i32), + ), + _ => Err(ArrowError::ComputeError(format!( + "substring does not support type {:?}", + array.data_type() + ))), + } +} - let mut new_values = Vec::new(); // we have no way to estimate how much this will be. - let mut new_offsets: Vec = Vec::with_capacity(array.len() + 1); +/// # Arguments +/// * `array` - The input string array +/// +/// * `start` - The start index of all substrings. +/// If `start >= 0`, then count from the start of the string, +/// otherwise count from the end of the string. +/// +/// * `length`(option) - The length of all substrings. +/// If `length` is `None`, then the substring is from `start` to the end of the string. +/// +/// Attention: Both `start` and `length` are counted by char. +/// +/// # Performance +/// This function is slower than [substring]. +/// Theoretically, the time complexity is `O(n)` where `n` is the length of the value buffer. +/// It is recommended to use [substring] if the input array only contains ASCII chars. 
+/// +/// # Basic usage +/// ``` +/// # use arrow::array::StringArray; +/// # use arrow::compute::kernels::substring::substring_by_char; +/// let array = StringArray::from(vec![Some("arrow"), None, Some("Γ ⊢x:T")]); +/// let result = substring_by_char(&array, 1, Some(4)).unwrap(); +/// assert_eq!(result, StringArray::from(vec![Some("rrow"), None, Some(" ⊢x:")])); +/// ``` +pub fn substring_by_char( + array: &GenericStringArray, + start: i64, + length: Option, +) -> Result> { + let mut vals = BufferBuilder::::new({ + let offsets = array.value_offsets(); + (offsets[array.len()] - offsets[0]).to_usize().unwrap() + }); + let mut new_offsets = BufferBuilder::::new(array.len() + 1); + new_offsets.append(OffsetSize::zero()); + let length = length.map(|len| len.to_usize().unwrap()); - let mut length_so_far = OffsetSize::zero(); - new_offsets.push(length_so_far); - (0..array.len()).for_each(|i| { - // the length of this entry - let length_i: OffsetSize = offsets[i + 1] - offsets[i]; - // compute where we should start slicing this entry - let start = offsets[i] - + if start >= OffsetSize::zero() { - start + array.iter().for_each(|val| { + if let Some(val) = val { + let char_count = val.chars().count(); + let start = if start >= 0 { + start.to_usize().unwrap() } else { - length_i + start + char_count - (-start).to_usize().unwrap().min(char_count) }; + let (start_offset, end_offset) = get_start_end_offset(val, start, length); + vals.append_slice(&val.as_bytes()[start_offset..end_offset]); + } + new_offsets.append(OffsetSize::from_usize(vals.len()).unwrap()); + }); + let data = unsafe { + ArrayData::new_unchecked( + GenericStringArray::::get_data_type(), + array.len(), + None, + array + .data_ref() + .null_buffer() + .map(|b| b.bit_slice(array.offset(), array.len())), + 0, + vec![new_offsets.finish(), vals.finish()], + vec![], + ) + }; + Ok(GenericStringArray::::from(data)) +} - let start = start.max(offsets[i]).min(offsets[i + 1]); - // compute the length of the slice - let length: OffsetSize = length - .unwrap_or(length_i) - // .max(0) is not needed as it is guaranteed - .min(offsets[i + 1] - start); // so we do not go beyond this entry - - length_so_far += length; +/// * `val` - string +/// * `start` - the start char index of the substring +/// * `length` - the char length of the substring +/// +/// Return the `start` and `end` offset (by byte) of the substring +fn get_start_end_offset( + val: &str, + start: usize, + length: Option, +) -> (usize, usize) { + let len = val.len(); + let mut offset_char_iter = val.char_indices(); + let start_offset = offset_char_iter + .nth(start) + .map_or(len, |(offset, _)| offset); + let end_offset = length.map_or(len, |length| { + if length > 0 { + offset_char_iter + .nth(length - 1) + .map_or(len, |(offset, _)| offset) + } else { + start_offset + } + }); + (start_offset, end_offset) +} - new_offsets.push(length_so_far); +fn binary_substring( + array: &GenericBinaryArray, + start: OffsetSize, + length: Option, +) -> Result { + let offsets = array.value_offsets(); + let values = array.value_data(); + let data = values.as_slice(); + let zero = OffsetSize::zero(); - // we need usize for ranges - let start = start.to_usize().unwrap(); - let length = length.to_usize().unwrap(); + // start and end offsets of all substrings + let mut new_starts_ends: Vec<(OffsetSize, OffsetSize)> = + Vec::with_capacity(array.len()); + let mut new_offsets: Vec = Vec::with_capacity(array.len() + 1); + let mut len_so_far = zero; + new_offsets.push(zero); - 
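// Illustrative sketch, not part of the patch: a negative `start` counts from the end
// of each value for both `substring` and `substring_by_char`; for pure-ASCII data the
// two kernels agree. The paths and sample values here are assumptions.
use arrow::array::StringArray;
use arrow::compute::kernels::substring::substring;

fn negative_start_sketch() {
    let array = StringArray::from(vec![Some("arrow"), None, Some("rust")]);
    let result = substring(&array, -3, None).unwrap();
    let result = result.as_any().downcast_ref::<StringArray>().unwrap();
    assert_eq!(
        result,
        &StringArray::from(vec![Some("row"), None, Some("ust")])
    );
}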
new_values.extend_from_slice(&data[start..start + length]); + offsets.windows(2).for_each(|pair| { + let new_start = match start.cmp(&zero) { + Ordering::Greater => (pair[0] + start).min(pair[1]), + Ordering::Equal => pair[0], + Ordering::Less => (pair[1] + start).max(pair[0]), + }; + let new_end = match length { + Some(length) => (length + new_start).min(pair[1]), + None => pair[1], + }; + len_so_far += new_end - new_start; + new_starts_ends.push((new_start, new_end)); + new_offsets.push(len_so_far); }); + // concatenate substrings into a buffer + let mut new_values = + MutableBuffer::new(new_offsets.last().unwrap().to_usize().unwrap()); + + new_starts_ends + .iter() + .map(|(start, end)| { + let start = start.to_usize().unwrap(); + let end = end.to_usize().unwrap(); + &data[start..end] + }) + .for_each(|slice| new_values.extend_from_slice(slice)); + let data = unsafe { ArrayData::new_unchecked( - ::DATA_TYPE, + GenericBinaryArray::::get_data_type(), array.len(), None, - null_bit_buffer, + array + .data_ref() + .null_buffer() + .map(|b| b.bit_slice(array.offset(), array.len())), 0, - vec![ - Buffer::from_slice_ref(&new_offsets), - Buffer::from_slice_ref(&new_values), - ], + vec![Buffer::from_slice_ref(&new_offsets), new_values.into()], vec![], ) }; Ok(make_array(data)) } -/// Returns an ArrayRef with a substring starting from `start` and with optional length `length` of each of the elements in `array`. -/// `start` can be negative, in which case the start counts from the end of the string. -/// this function errors when the passed array is not a \[Large\]String array. -pub fn substring( - array: &dyn Array, - start: i64, - length: &Option, +fn fixed_size_binary_substring( + array: &FixedSizeBinaryArray, + old_len: i32, + start: i32, + length: Option, ) -> Result { - match array.data_type() { - DataType::LargeUtf8 => generic_substring( + let new_start = if start >= 0 { + start.min(old_len) + } else { + (old_len + start).max(0) + }; + let new_len = match length { + Some(len) => len.min(old_len - new_start), + None => old_len - new_start, + }; + + // build value buffer + let num_of_elements = array.len(); + let values = array.value_data(); + let data = values.as_slice(); + let mut new_values = MutableBuffer::new(num_of_elements * (new_len as usize)); + (0..num_of_elements) + .map(|idx| { + let offset = array.value_offset(idx); + ( + (offset + new_start) as usize, + (offset + new_start + new_len) as usize, + ) + }) + .for_each(|(start, end)| new_values.extend_from_slice(&data[start..end])); + + let array_data = unsafe { + ArrayData::new_unchecked( + DataType::FixedSizeBinary(new_len), + num_of_elements, + None, array - .as_any() - .downcast_ref::() - .expect("A large string is expected"), - start, - &length.map(|e| e as i64), - ), - DataType::Utf8 => generic_substring( + .data_ref() + .null_buffer() + .map(|b| b.bit_slice(array.offset(), num_of_elements)), + 0, + vec![new_values.into()], + vec![], + ) + }; + + Ok(make_array(array_data)) +} + +/// substring by byte +fn utf8_substring( + array: &GenericStringArray, + start: OffsetSize, + length: Option, +) -> Result { + let offsets = array.value_offsets(); + let values = array.value_data(); + let data = values.as_slice(); + let zero = OffsetSize::zero(); + + // Check if `offset` is at a valid char boundary. 
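// Illustrative sketch, not part of the patch: `substring` on a FixedSizeBinaryArray
// yields a new FixedSizeBinaryArray whose fixed width is the (possibly shorter)
// substring length. Construction via `try_from_iter` and the sample bytes are assumptions.
use arrow::array::FixedSizeBinaryArray;
use arrow::compute::kernels::substring::substring;

fn fixed_size_binary_substring_sketch() {
    let array =
        FixedSizeBinaryArray::try_from_iter(vec![b"cat", b"dog"].into_iter()).unwrap();
    let result = substring(&array, 1, None).unwrap();
    let result = result
        .as_any()
        .downcast_ref::<FixedSizeBinaryArray>()
        .unwrap();
    assert_eq!(result.value_length(), 2);
    assert_eq!(result.value(0), &b"at"[..]);
    assert_eq!(result.value(1), &b"og"[..]);
}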
+ // If yes, return `offset`, else return error + let check_char_boundary = { + // Safety: a StringArray must contain valid UTF8 data + let data_str = unsafe { std::str::from_utf8_unchecked(data) }; + |offset: OffsetSize| { + let offset_usize = offset.to_usize().unwrap(); + if data_str.is_char_boundary(offset_usize) { + Ok(offset) + } else { + Err(ArrowError::ComputeError(format!( + "The offset {} is at an invalid utf-8 boundary.", + offset_usize + ))) + } + } + }; + + // start and end offsets of all substrings + let mut new_starts_ends: Vec<(OffsetSize, OffsetSize)> = + Vec::with_capacity(array.len()); + let mut new_offsets: Vec = Vec::with_capacity(array.len() + 1); + let mut len_so_far = zero; + new_offsets.push(zero); + + offsets.windows(2).try_for_each(|pair| -> Result<()> { + let new_start = match start.cmp(&zero) { + Ordering::Greater => check_char_boundary((pair[0] + start).min(pair[1]))?, + Ordering::Equal => pair[0], + Ordering::Less => check_char_boundary((pair[1] + start).max(pair[0]))?, + }; + let new_end = match length { + Some(length) => check_char_boundary((length + new_start).min(pair[1]))?, + None => pair[1], + }; + len_so_far += new_end - new_start; + new_starts_ends.push((new_start, new_end)); + new_offsets.push(len_so_far); + Ok(()) + })?; + + // concatenate substrings into a buffer + let mut new_values = + MutableBuffer::new(new_offsets.last().unwrap().to_usize().unwrap()); + + new_starts_ends + .iter() + .map(|(start, end)| { + let start = start.to_usize().unwrap(); + let end = end.to_usize().unwrap(); + &data[start..end] + }) + .for_each(|slice| new_values.extend_from_slice(slice)); + + let data = unsafe { + ArrayData::new_unchecked( + GenericStringArray::::get_data_type(), + array.len(), + None, array - .as_any() - .downcast_ref::() - .expect("A string is expected"), - start as i32, - &length.map(|e| e as i32), - ), - _ => Err(ArrowError::ComputeError(format!( - "substring does not support type {:?}", - array.data_type() - ))), - } + .data_ref() + .null_buffer() + .map(|b| b.bit_slice(array.offset(), array.len())), + 0, + vec![Buffer::from_slice_ref(&new_offsets), new_values.into()], + vec![], + ) + }; + Ok(make_array(data)) } #[cfg(test)] mod tests { use super::*; + use crate::datatypes::*; + + /// A helper macro to generate test cases. + /// # Arguments + /// * `input` - A vector which array can be built from. + /// * `start` - The start index of the substring. + /// * `len` - The length of the substring. + /// * `result` - The expected result of substring, which is a vector that array can be built from. + /// # Return + /// A vector of `(input, start, len, result)`. + /// + /// Users can provide any number of `(start, len, result)` to generate test cases for one `input`. + macro_rules! gen_test_cases { + ($input:expr, $(($start:expr, $len:expr, $result:expr)), *) => { + [ + $( + ($input.clone(), $start, $len, $result), + )* + ] + }; + } - fn with_nulls>>>( - ) -> Result<()> { - let cases = vec![ + /// A helper macro to test the substring functions. + /// # Arguments + /// * `cases` - The test cases which is a vector of `(input, start, len, result)`. + /// Please look at [`gen_test_cases`] to find how to generate it. + /// * `array_ty` - The array type. + /// * `substring_fn` - Either [`substring`] or [`substring_by_char`]. + macro_rules! 
do_test { + ($cases:expr, $array_ty:ty, $substring_fn:ident) => { + $cases + .into_iter() + .for_each(|(array, start, length, expected)| { + let array = <$array_ty>::from(array); + let result = $substring_fn(&array, start, length).unwrap(); + let result = result.as_any().downcast_ref::<$array_ty>().unwrap(); + let expected = <$array_ty>::from(expected); + assert_eq!(&expected, result); + }) + }; + } + + fn with_nulls_generic_binary() { + let input = vec![ + Some("hello".as_bytes()), + None, + Some(&[0xf8, 0xf9, 0xff, 0xfa]), + ]; + // all-nulls array is always identical + let base_case = gen_test_cases!( + vec![None, None, None], + (-1, Some(1), vec![None, None, None]) + ); + let cases = gen_test_cases!( + input, // identity - ( - vec![Some("hello"), None, Some("word")], - 0, - None, - vec![Some("hello"), None, Some("word")], - ), + (0, None, input.clone()), // 0 length -> Nothing - ( - vec![Some("hello"), None, Some("word")], - 0, - Some(0), - vec![Some(""), None, Some("")], - ), + (0, Some(0), vec![Some(&[]), None, Some(&[])]), // high start -> Nothing - ( - vec![Some("hello"), None, Some("word")], - 1000, - Some(0), - vec![Some(""), None, Some("")], - ), + (1000, Some(0), vec![Some(&[]), None, Some(&[])]), // high negative start -> identity - ( - vec![Some("hello"), None, Some("word")], - -1000, - None, - vec![Some("hello"), None, Some("word")], - ), + (-1000, None, input.clone()), // high length -> identity - ( - vec![Some("hello"), None, Some("word")], - 0, - Some(1000), - vec![Some("hello"), None, Some("word")], - ), + (0, Some(1000), input.clone()) + ); + + do_test!( + [&base_case[..], &cases[..]].concat(), + GenericBinaryArray, + substring + ); + } + + #[test] + fn with_nulls_binary() { + with_nulls_generic_binary::() + } + + #[test] + fn with_nulls_large_binary() { + with_nulls_generic_binary::() + } + + fn without_nulls_generic_binary() { + let input = vec!["hello".as_bytes(), b"", &[0xf8, 0xf9, 0xff, 0xfa]]; + // empty array is always identical + let base_case = gen_test_cases!( + vec!["".as_bytes(), b"", b""], + (2, Some(1), vec!["".as_bytes(), b"", b""]) + ); + let cases = gen_test_cases!( + input, + // identity + (0, None, input.clone()), + // increase start + (1, None, vec![b"ello", b"", &[0xf9, 0xff, 0xfa]]), + (2, None, vec![b"llo", b"", &[0xff, 0xfa]]), + (3, None, vec![b"lo", b"", &[0xfa]]), + (10, None, vec![b"", b"", b""]), + // increase start negatively + (-1, None, vec![b"o", b"", &[0xfa]]), + (-2, None, vec![b"lo", b"", &[0xff, 0xfa]]), + (-3, None, vec![b"llo", b"", &[0xf9, 0xff, 0xfa]]), + (-10, None, input.clone()), + // increase length + (1, Some(1), vec![b"e", b"", &[0xf9]]), + (1, Some(2), vec![b"el", b"", &[0xf9, 0xff]]), + (1, Some(3), vec![b"ell", b"", &[0xf9, 0xff, 0xfa]]), + (1, Some(4), vec![b"ello", b"", &[0xf9, 0xff, 0xfa]]), + (-3, Some(1), vec![b"l", b"", &[0xf9]]), + (-3, Some(2), vec![b"ll", b"", &[0xf9, 0xff]]), + (-3, Some(3), vec![b"llo", b"", &[0xf9, 0xff, 0xfa]]), + (-3, Some(4), vec![b"llo", b"", &[0xf9, 0xff, 0xfa]]) + ); + + do_test!( + [&base_case[..], &cases[..]].concat(), + GenericBinaryArray, + substring + ); + } + + #[test] + fn without_nulls_binary() { + without_nulls_generic_binary::() + } + + #[test] + fn without_nulls_large_binary() { + without_nulls_generic_binary::() + } + + fn generic_binary_with_non_zero_offset() { + let values = 0_u8..15; + let offsets = &[ + O::zero(), + O::from_usize(5).unwrap(), + O::from_usize(10).unwrap(), + O::from_usize(15).unwrap(), ]; + // set the first and third element to be valid + let bitmap = 
[0b101_u8]; + + let data = ArrayData::builder(GenericBinaryArray::::get_data_type()) + .len(2) + .add_buffer(Buffer::from_slice_ref(offsets)) + .add_buffer(Buffer::from_iter(values)) + .null_bit_buffer(Some(Buffer::from(bitmap))) + .offset(1) + .build() + .unwrap(); + // array is `[null, [10, 11, 12, 13, 14]]` + let array = GenericBinaryArray::::from(data); + // result is `[null, [11, 12, 13, 14]]` + let result = substring(&array, 1, None).unwrap(); + let result = result + .as_any() + .downcast_ref::>() + .unwrap(); + let expected = + GenericBinaryArray::::from_opt_vec(vec![None, Some(&[11_u8, 12, 13, 14])]); + assert_eq!(result, &expected); + } - cases.into_iter().try_for_each::<_, Result<()>>( - |(array, start, length, expected)| { - let array = T::from(array); - let result: ArrayRef = substring(&array, start, &length)?; - assert_eq!(array.len(), result.len()); + #[test] + fn binary_with_non_zero_offset() { + generic_binary_with_non_zero_offset::() + } - let result = result.as_any().downcast_ref::().unwrap(); - let expected = T::from(expected); - assert_eq!(&expected, result); - Ok(()) - }, - )?; + #[test] + fn large_binary_with_non_zero_offset() { + generic_binary_with_non_zero_offset::() + } - Ok(()) + #[test] + fn with_nulls_fixed_size_binary() { + let input = vec![Some("cat".as_bytes()), None, Some(&[0xf8, 0xf9, 0xff])]; + // all-nulls array is always identical + let base_case = + gen_test_cases!(vec![None, None, None], (3, Some(2), vec![None, None, None])); + let cases = gen_test_cases!( + input, + // identity + (0, None, input.clone()), + // increase start + (1, None, vec![Some(b"at"), None, Some(&[0xf9, 0xff])]), + (2, None, vec![Some(b"t"), None, Some(&[0xff])]), + (3, None, vec![Some(b""), None, Some(b"")]), + (10, None, vec![Some(b""), None, Some(b"")]), + // increase start negatively + (-1, None, vec![Some(b"t"), None, Some(&[0xff])]), + (-2, None, vec![Some(b"at"), None, Some(&[0xf9, 0xff])]), + (-3, None, input.clone()), + (-10, None, input.clone()), + // increase length + (1, Some(1), vec![Some(b"a"), None, Some(&[0xf9])]), + (1, Some(2), vec![Some(b"at"), None, Some(&[0xf9, 0xff])]), + (1, Some(3), vec![Some(b"at"), None, Some(&[0xf9, 0xff])]), + (-3, Some(1), vec![Some(b"c"), None, Some(&[0xf8])]), + (-3, Some(2), vec![Some(b"ca"), None, Some(&[0xf8, 0xf9])]), + (-3, Some(3), input.clone()), + (-3, Some(4), input.clone()) + ); + + do_test!( + [&base_case[..], &cases[..]].concat(), + FixedSizeBinaryArray, + substring + ); + } + + #[test] + fn without_nulls_fixed_size_binary() { + let input = vec!["cat".as_bytes(), b"dog", &[0xf8, 0xf9, 0xff]]; + // empty array is always identical + let base_case = gen_test_cases!( + vec!["".as_bytes(), &[], &[]], + (1, Some(2), vec!["".as_bytes(), &[], &[]]) + ); + let cases = gen_test_cases!( + input, + // identity + (0, None, input.clone()), + // increase start + (1, None, vec![b"at", b"og", &[0xf9, 0xff]]), + (2, None, vec![b"t", b"g", &[0xff]]), + (3, None, vec![&[], &[], &[]]), + (10, None, vec![&[], &[], &[]]), + // increase start negatively + (-1, None, vec![b"t", b"g", &[0xff]]), + (-2, None, vec![b"at", b"og", &[0xf9, 0xff]]), + (-3, None, input.clone()), + (-10, None, input.clone()), + // increase length + (1, Some(1), vec![b"a", b"o", &[0xf9]]), + (1, Some(2), vec![b"at", b"og", &[0xf9, 0xff]]), + (1, Some(3), vec![b"at", b"og", &[0xf9, 0xff]]), + (-3, Some(1), vec![b"c", b"d", &[0xf8]]), + (-3, Some(2), vec![b"ca", b"do", &[0xf8, 0xf9]]), + (-3, Some(3), input.clone()), + (-3, Some(4), input.clone()) + ); + + do_test!( + 
[&base_case[..], &cases[..]].concat(), + FixedSizeBinaryArray, + substring + ); + } + + #[test] + fn fixed_size_binary_with_non_zero_offset() { + let values: [u8; 15] = *b"hellotherearrow"; + // set the first and third element to be valid + let bits_v = [0b101_u8]; + + let data = ArrayData::builder(DataType::FixedSizeBinary(5)) + .len(2) + .add_buffer(Buffer::from(&values[..])) + .offset(1) + .null_bit_buffer(Some(Buffer::from(bits_v))) + .build() + .unwrap(); + // array is `[null, "arrow"]` + let array = FixedSizeBinaryArray::from(data); + // result is `[null, "rrow"]` + let result = substring(&array, 1, None).unwrap(); + let result = result + .as_any() + .downcast_ref::() + .unwrap(); + let expected = FixedSizeBinaryArray::try_from_sparse_iter( + vec![None, Some(b"rrow")].into_iter(), + ) + .unwrap(); + assert_eq!(result, &expected); + } + + fn with_nulls_generic_string() { + let input = vec![Some("hello"), None, Some("word")]; + // all-nulls array is always identical + let base_case = + gen_test_cases!(vec![None, None, None], (0, None, vec![None, None, None])); + let cases = gen_test_cases!( + input, + // identity + (0, None, input.clone()), + // 0 length -> Nothing + (0, Some(0), vec![Some(""), None, Some("")]), + // high start -> Nothing + (1000, Some(0), vec![Some(""), None, Some("")]), + // high negative start -> identity + (-1000, None, input.clone()), + // high length -> identity + (0, Some(1000), input.clone()) + ); + + do_test!( + [&base_case[..], &cases[..]].concat(), + GenericStringArray, + substring + ); } #[test] - fn with_nulls_string() -> Result<()> { - with_nulls::() + fn with_nulls_string() { + with_nulls_generic_string::() } #[test] - fn with_nulls_large_string() -> Result<()> { - with_nulls::() + fn with_nulls_large_string() { + with_nulls_generic_string::() + } + + fn without_nulls_generic_string() { + let input = vec!["hello", "", "word"]; + // empty array is always identical + let base_case = gen_test_cases!(vec!["", "", ""], (0, None, vec!["", "", ""])); + let cases = gen_test_cases!( + input, + // identity + (0, None, input.clone()), + (1, None, vec!["ello", "", "ord"]), + (2, None, vec!["llo", "", "rd"]), + (3, None, vec!["lo", "", "d"]), + (10, None, vec!["", "", ""]), + // increase start negatively + (-1, None, vec!["o", "", "d"]), + (-2, None, vec!["lo", "", "rd"]), + (-3, None, vec!["llo", "", "ord"]), + (-10, None, input.clone()), + // increase length + (1, Some(1), vec!["e", "", "o"]), + (1, Some(2), vec!["el", "", "or"]), + (1, Some(3), vec!["ell", "", "ord"]), + (1, Some(4), vec!["ello", "", "ord"]), + (-3, Some(1), vec!["l", "", "o"]), + (-3, Some(2), vec!["ll", "", "or"]), + (-3, Some(3), vec!["llo", "", "ord"]), + (-3, Some(4), vec!["llo", "", "ord"]) + ); + + do_test!( + [&base_case[..], &cases[..]].concat(), + GenericStringArray, + substring + ); } - fn without_nulls>>>( - ) -> Result<()> { - let cases = vec![ + #[test] + fn without_nulls_string() { + without_nulls_generic_string::() + } + + #[test] + fn without_nulls_large_string() { + without_nulls_generic_string::() + } + + fn generic_string_with_non_zero_offset() { + let values = "hellotherearrow"; + let offsets = &[ + O::zero(), + O::from_usize(5).unwrap(), + O::from_usize(10).unwrap(), + O::from_usize(15).unwrap(), + ]; + // set the first and third element to be valid + let bitmap = [0b101_u8]; + + let data = ArrayData::builder(GenericStringArray::::get_data_type()) + .len(2) + .add_buffer(Buffer::from_slice_ref(offsets)) + .add_buffer(Buffer::from(values)) + 
.null_bit_buffer(Some(Buffer::from(bitmap))) + .offset(1) + .build() + .unwrap(); + // array is `[null, "arrow"]` + let array = GenericStringArray::::from(data); + // result is `[null, "rrow"]` + let result = substring(&array, 1, None).unwrap(); + let result = result + .as_any() + .downcast_ref::>() + .unwrap(); + let expected = GenericStringArray::::from(vec![None, Some("rrow")]); + assert_eq!(result, &expected); + } + + #[test] + fn string_with_non_zero_offset() { + generic_string_with_non_zero_offset::() + } + + #[test] + fn large_string_with_non_zero_offset() { + generic_string_with_non_zero_offset::() + } + + fn with_nulls_generic_string_by_char() { + let input = vec![Some("hello"), None, Some("Γ ⊢x:T")]; + // all-nulls array is always identical + let base_case = + gen_test_cases!(vec![None, None, None], (0, None, vec![None, None, None])); + let cases = gen_test_cases!( + input, + // identity + (0, None, input.clone()), + // 0 length -> Nothing + (0, Some(0), vec![Some(""), None, Some("")]), + // high start -> Nothing + (1000, Some(0), vec![Some(""), None, Some("")]), + // high negative start -> identity + (-1000, None, input.clone()), + // high length -> identity + (0, Some(1000), input.clone()) + ); + + do_test!( + [&base_case[..], &cases[..]].concat(), + GenericStringArray, + substring_by_char + ); + } + + #[test] + fn with_nulls_string_by_char() { + with_nulls_generic_string_by_char::() + } + + #[test] + fn with_nulls_large_string_by_char() { + with_nulls_generic_string_by_char::() + } + + fn without_nulls_generic_string_by_char() { + let input = vec!["hello", "", "Γ ⊢x:T"]; + // empty array is always identical + let base_case = gen_test_cases!(vec!["", "", ""], (0, None, vec!["", "", ""])); + let cases = gen_test_cases!( + input, + //identity + (0, None, input.clone()), // increase start - ( - vec!["hello", "", "word"], - 0, - None, - vec!["hello", "", "word"], - ), - (vec!["hello", "", "word"], 1, None, vec!["ello", "", "ord"]), - (vec!["hello", "", "word"], 2, None, vec!["llo", "", "rd"]), - (vec!["hello", "", "word"], 3, None, vec!["lo", "", "d"]), - (vec!["hello", "", "word"], 10, None, vec!["", "", ""]), + (1, None, vec!["ello", "", " ⊢x:T"]), + (2, None, vec!["llo", "", "⊢x:T"]), + (3, None, vec!["lo", "", "x:T"]), + (10, None, vec!["", "", ""]), // increase start negatively - (vec!["hello", "", "word"], -1, None, vec!["o", "", "d"]), - (vec!["hello", "", "word"], -2, None, vec!["lo", "", "rd"]), - (vec!["hello", "", "word"], -3, None, vec!["llo", "", "ord"]), - ( - vec!["hello", "", "word"], - -10, - None, - vec!["hello", "", "word"], - ), + (-1, None, vec!["o", "", "T"]), + (-2, None, vec!["lo", "", ":T"]), + (-4, None, vec!["ello", "", "⊢x:T"]), + (-10, None, input.clone()), // increase length - (vec!["hello", "", "word"], 1, Some(1), vec!["e", "", "o"]), - (vec!["hello", "", "word"], 1, Some(2), vec!["el", "", "or"]), - ( - vec!["hello", "", "word"], - 1, - Some(3), - vec!["ell", "", "ord"], - ), - ( - vec!["hello", "", "word"], - 1, - Some(4), - vec!["ello", "", "ord"], - ), - (vec!["hello", "", "word"], -3, Some(1), vec!["l", "", "o"]), - (vec!["hello", "", "word"], -3, Some(2), vec!["ll", "", "or"]), - ( - vec!["hello", "", "word"], - -3, - Some(3), - vec!["llo", "", "ord"], - ), - ( - vec!["hello", "", "word"], - -3, - Some(4), - vec!["llo", "", "ord"], - ), + (1, Some(1), vec!["e", "", " "]), + (1, Some(2), vec!["el", "", " ⊢"]), + (1, Some(3), vec!["ell", "", " ⊢x"]), + (1, Some(6), vec!["ello", "", " ⊢x:T"]), + (-4, Some(1), vec!["e", "", "⊢"]), + (-4, Some(2), 
vec!["el", "", "⊢x"]), + (-4, Some(3), vec!["ell", "", "⊢x:"]), + (-4, Some(4), vec!["ello", "", "⊢x:T"]) + ); + + do_test!( + [&base_case[..], &cases[..]].concat(), + GenericStringArray, + substring_by_char + ); + } + + #[test] + fn without_nulls_string_by_char() { + without_nulls_generic_string_by_char::() + } + + #[test] + fn without_nulls_large_string_by_char() { + without_nulls_generic_string_by_char::() + } + + fn generic_string_by_char_with_non_zero_offset() { + let values = "S→T = Πx:S.T"; + let offsets = &[ + O::zero(), + O::from_usize(values.char_indices().nth(3).map(|(pos, _)| pos).unwrap()) + .unwrap(), + O::from_usize(values.char_indices().nth(6).map(|(pos, _)| pos).unwrap()) + .unwrap(), + O::from_usize(values.len()).unwrap(), ]; + // set the first and third element to be valid + let bitmap = [0b101_u8]; - cases.into_iter().try_for_each::<_, Result<()>>( - |(array, start, length, expected)| { - let array = StringArray::from(array); - let result = substring(&array, start, &length)?; - assert_eq!(array.len(), result.len()); - let result = result.as_any().downcast_ref::().unwrap(); - let expected = StringArray::from(expected); - assert_eq!(&expected, result,); - Ok(()) - }, - )?; + let data = ArrayData::builder(GenericStringArray::::get_data_type()) + .len(2) + .add_buffer(Buffer::from_slice_ref(offsets)) + .add_buffer(Buffer::from(values)) + .null_bit_buffer(Some(Buffer::from(bitmap))) + .offset(1) + .build() + .unwrap(); + // array is `[null, "Πx:S.T"]` + let array = GenericStringArray::::from(data); + // result is `[null, "x:S.T"]` + let result = substring_by_char(&array, 1, None).unwrap(); + let expected = GenericStringArray::::from(vec![None, Some("x:S.T")]); + assert_eq!(result, expected); + } - Ok(()) + #[test] + fn string_with_non_zero_offset_by_char() { + generic_string_by_char_with_non_zero_offset::() + } + + #[test] + fn large_string_with_non_zero_offset_by_char() { + generic_string_by_char_with_non_zero_offset::() + } + + #[test] + fn dictionary() { + _dictionary::(); + _dictionary::(); + _dictionary::(); + _dictionary::(); + _dictionary::(); + _dictionary::(); + _dictionary::(); + _dictionary::(); + } + + fn _dictionary() { + const TOTAL: i32 = 100; + + let v = ["aaa", "bbb", "ccc", "ddd", "eee"]; + let data: Vec> = (0..TOTAL) + .map(|n| { + let i = n % 5; + if i == 3 { + None + } else { + Some(v[i as usize]) + } + }) + .collect(); + + let dict_array: DictionaryArray = data.clone().into_iter().collect(); + + let expected: Vec> = + data.iter().map(|opt| opt.map(|s| &s[1..3])).collect(); + + let res = substring(&dict_array, 1, Some(2)).unwrap(); + let actual = res.as_any().downcast_ref::>().unwrap(); + let actual: Vec> = actual + .values() + .as_any() + .downcast_ref::>() + .unwrap() + .take_iter(actual.keys_iter()) + .collect(); + + for i in 0..TOTAL as usize { + assert_eq!(expected[i], actual[i],); + } + } + + #[test] + fn check_invalid_array_type() { + let array = Int32Array::from(vec![Some(1), Some(2), Some(3)]); + let err = substring(&array, 0, None).unwrap_err().to_string(); + assert!(err.contains("substring does not support type")); } + // tests for the utf-8 validation checking #[test] - fn without_nulls_string() -> Result<()> { - without_nulls::() + fn check_start_index() { + let array = StringArray::from(vec![Some("E=mc²"), Some("ascii")]); + let err = substring(&array, -1, None).unwrap_err().to_string(); + assert!(err.contains("invalid utf-8 boundary")); } #[test] - fn without_nulls_large_string() -> Result<()> { - without_nulls::() + fn check_length() { + 
let array = StringArray::from(vec![Some("E=mc²"), Some("ascii")]); + let err = substring(&array, 0, Some(5)).unwrap_err().to_string(); + assert!(err.contains("invalid utf-8 boundary")); } } diff --git a/arrow/src/compute/kernels/take.rs b/arrow/src/compute/kernels/take.rs index 692de278974d..624e9ddcdb58 100644 --- a/arrow/src/compute/kernels/take.rs +++ b/arrow/src/compute/kernels/take.rs @@ -53,6 +53,21 @@ macro_rules! downcast_dict_take { /// Take elements by index from [Array], creating a new [Array] from those indexes. /// +/// ```text +/// ┌─────────────────┐ ┌─────────┐ ┌─────────────────┐ +/// │ A │ │ 0 │ │ A │ +/// ├─────────────────┤ ├─────────┤ ├─────────────────┤ +/// │ D │ │ 2 │ │ B │ +/// ├─────────────────┤ ├─────────┤ take(values, indicies) ├─────────────────┤ +/// │ B │ │ 3 │ ─────────────────────────▶ │ C │ +/// ├─────────────────┤ ├─────────┤ ├─────────────────┤ +/// │ C │ │ 1 │ │ D │ +/// ├─────────────────┤ └─────────┘ └─────────────────┘ +/// │ E │ +/// └─────────────────┘ +/// values array indicies array result +/// ``` +/// /// # Errors /// This function errors whenever: /// * An index cannot be casted to `usize` (typically 32 bit architectures) @@ -131,6 +146,10 @@ where let values = values.as_any().downcast_ref::().unwrap(); Ok(Arc::new(take_boolean(values, indices)?)) } + DataType::Decimal(_, _) => { + let decimal_values = values.as_any().downcast_ref::().unwrap(); + Ok(Arc::new(take_decimal128(decimal_values, indices)?)) + } DataType::Int8 => downcast_take!(Int8Type, values, indices), DataType::Int16 => downcast_take!(Int16Type, values, indices), DataType::Int32 => downcast_take!(Int32Type, values, indices), @@ -171,6 +190,9 @@ where DataType::Interval(IntervalUnit::DayTime) => { downcast_take!(IntervalDayTimeType, values, indices) } + DataType::Interval(IntervalUnit::MonthDayNano) => { + downcast_take!(IntervalMonthDayNanoType, values, indices) + } DataType::Duration(TimeUnit::Second) => { downcast_take!(DurationSecondType, values, indices) } @@ -280,12 +302,23 @@ where .unwrap(); Ok(Arc::new(take_fixed_size_binary(values, indices)?)) } + DataType::Null => { + // Take applied to a null array produces a null array. + if values.len() >= indices.len() { + // If the existing null array is as big as the indices, we can use a slice of it + // to avoid allocating a new null array. + Ok(values.slice(0, indices.len())) + } else { + // If the existing null array isn't big enough, create a new one. + Ok(new_null_array(&DataType::Null, indices.len())) + } + } t => unimplemented!("Take not supported for data type {:?}", t), } } /// Options that define how `take` should behave -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Default)] pub struct TakeOptions { /// Perform bounds check before taking indices from values. /// If enabled, an `ArrowError` is returned if the indices are out of bounds. 
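For orientation, a minimal sketch of driving the `take` kernel documented above (not part of this diff; it assumes the kernel is re-exported as `arrow::compute::take`, otherwise the full path `arrow::compute::kernels::take::take` applies). A null slot in the indices propagates to a null slot in the output, which is the behaviour the new decimal and null-array cases in this patch rely on:

    use arrow::array::{Array, StringArray, UInt32Array};
    use arrow::compute::take;
    use arrow::error::Result;

    fn take_example() -> Result<()> {
        // values: ["A", "D", "B", "C", "E"]; indices: [0, 2, 3, null]
        let values = StringArray::from(vec!["A", "D", "B", "C", "E"]);
        let indices = UInt32Array::from(vec![Some(0), Some(2), Some(3), None]);

        let taken = take(&values, &indices, None)?;
        let taken = taken.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(taken.value(0), "A"); // index 0 -> "A"
        assert_eq!(taken.value(1), "B"); // index 2 -> "B"
        assert_eq!(taken.value(2), "C"); // index 3 -> "C"
        assert!(taken.is_null(3));       // null index -> null output slot
        Ok(())
    }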
@@ -293,14 +326,6 @@ pub struct TakeOptions { pub check_bounds: bool, } -impl Default for TakeOptions { - fn default() -> Self { - Self { - check_bounds: false, - } - } -} - #[inline(always)] fn maybe_usize(index: I) -> Result { index @@ -477,6 +502,41 @@ where Ok((buffer, nulls)) } +/// `take` implementation for decimal arrays +fn take_decimal128( + decimal_values: &DecimalArray, + indices: &PrimitiveArray, +) -> Result +where + IndexType: ArrowNumericType, + IndexType::Native: ToPrimitive, +{ + indices + .iter() + .map(|index| { + // Use type annotations below for readability (was blowing + // my mind otherwise) + let t: Option>> = index.map(|index| { + let index = ToPrimitive::to_usize(&index).ok_or_else(|| { + ArrowError::ComputeError("Cast to usize failed".to_string()) + })?; + + if decimal_values.is_null(index) { + Ok(None) + } else { + Ok(Some(decimal_values.value(index).as_i128())) + } + }); + let t: Result>> = t.transpose(); + let t: Result> = t.map(|t| t.flatten()); + t + }) + .collect::>()? + // PERF: we could avoid re-validating that the data in + // DecimalArray was in range as we know it came from a valid DecimalArray + .with_precision_and_scale(decimal_values.precision(), decimal_values.scale()) +} + /// `take` implementation for all primitive arrays /// /// This checks if an `indices` slot is populated, and gets the value from `values` @@ -525,7 +585,7 @@ where let data = unsafe { ArrayData::new_unchecked( - T::DATA_TYPE, + values.data_type().clone(), indices.len(), None, nulls, @@ -555,8 +615,7 @@ where let null_count = values.null_count(); - let nulls; - if null_count == 0 { + let nulls = if null_count == 0 { (0..data_len).try_for_each::<_, Result<()>>(|i| { let index = ToPrimitive::to_usize(&indices.value(i)).ok_or_else(|| { ArrowError::ComputeError("Cast to usize failed".to_string()) @@ -569,7 +628,7 @@ where Ok(()) })?; - nulls = indices.data_ref().null_buffer().cloned(); + indices.data_ref().null_buffer().cloned() } else { let mut null_buf = MutableBuffer::new(num_byte).with_bitset(num_byte, true); let null_slice = null_buf.as_slice_mut(); @@ -588,7 +647,7 @@ where Ok(()) })?; - nulls = match indices.data_ref().null_buffer() { + match indices.data_ref().null_buffer() { Some(buffer) => Some(buffer_bin_and( buffer, indices.offset(), @@ -597,8 +656,8 @@ where indices.len(), )), None => Some(null_buf.into()), - }; - } + } + }; let data = unsafe { ArrayData::new_unchecked( @@ -620,7 +679,7 @@ fn take_string( indices: &PrimitiveArray, ) -> Result> where - OffsetSize: Zero + AddAssign + StringOffsetSizeTrait, + OffsetSize: Zero + AddAssign + OffsetSizeTrait, IndexType: ArrowNumericType, IndexType::Native: ToPrimitive, { @@ -717,14 +776,13 @@ where }; } - let mut array_data = - ArrayData::builder(::DATA_TYPE) + let array_data = + ArrayData::builder(GenericStringArray::::get_data_type()) .len(data_len) .add_buffer(offsets_buffer.into()) - .add_buffer(values.into()); - if let Some(null_buffer) = nulls { - array_data = array_data.null_bit_buffer(null_buffer); - } + .add_buffer(values.into()) + .null_bit_buffer(nulls); + let array_data = unsafe { array_data.build_unchecked() }; Ok(GenericStringArray::::from(array_data)) @@ -772,7 +830,7 @@ where // create a new list with taken data and computed null information let list_data = ArrayDataBuilder::new(values.data_type().clone()) .len(indices.len()) - .null_bit_buffer(null_buf.into()) + .null_bit_buffer(Some(null_buf.into())) .offset(0) .add_child_data(taken.data().clone()) .add_buffer(value_offsets); @@ -815,7 +873,7 @@ where let 
list_data = ArrayDataBuilder::new(values.data_type().clone()) .len(indices.len()) - .null_bit_buffer(null_buf.into()) + .null_bit_buffer(Some(null_buf.into())) .offset(0) .add_child_data(taken.data().clone()); @@ -829,7 +887,7 @@ fn take_binary( indices: &PrimitiveArray, ) -> Result> where - OffsetType: BinaryOffsetSizeTrait, + OffsetType: OffsetSizeTrait, IndexType: ArrowNumericType, IndexType::Native: ToPrimitive, { @@ -914,6 +972,32 @@ mod tests { use super::*; use crate::compute::util::tests::build_fixed_size_list_nullable; + fn test_take_decimal_arrays( + data: Vec>, + index: &UInt32Array, + options: Option, + expected_data: Vec>, + precision: &usize, + scale: &usize, + ) -> Result<()> { + let output = data + .into_iter() + .collect::() + .with_precision_and_scale(*precision, *scale) + .unwrap(); + + let expected = expected_data + .into_iter() + .collect::() + .with_precision_and_scale(*precision, *scale) + .unwrap(); + + let expected = Arc::new(expected) as ArrayRef; + let output = take(&output, index, options).unwrap(); + assert_eq!(&output, &expected); + Ok(()) + } + fn test_take_boolean_arrays( data: Vec>, index: &UInt32Array, @@ -1010,6 +1094,38 @@ mod tests { struct_builder.finish() } + #[test] + fn test_take_decimal128_non_null_indices() { + let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]); + let precision: usize = 10; + let scale: usize = 5; + test_take_decimal_arrays( + vec![None, Some(3), Some(5), Some(2), Some(3), None], + &index, + None, + vec![None, None, Some(2), Some(3), Some(3), Some(5)], + &precision, + &scale, + ) + .unwrap(); + } + + #[test] + fn test_take_decimal128() { + let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]); + let precision: usize = 10; + let scale: usize = 5; + test_take_decimal_arrays( + vec![Some(0), Some(1), Some(2), Some(3), Some(4)], + &index, + None, + vec![Some(3), None, Some(1), Some(3), Some(2)], + &precision, + &scale, + ) + .unwrap(); + } + #[test] fn test_take_primitive_non_null_indices() { let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]); @@ -1182,6 +1298,15 @@ mod tests { ) .unwrap(); + // interval_month_day_nano + test_take_primitive_arrays::( + vec![Some(0), None, Some(2), Some(-15), None], + &index, + None, + vec![Some(-15), None, None, Some(-15), Some(2)], + ) + .unwrap(); + // duration_second test_take_primitive_arrays::( vec![Some(0), None, Some(2), Some(-15), None], @@ -1237,6 +1362,23 @@ mod tests { .unwrap(); } + #[test] + fn test_take_preserve_timezone() { + let index = Int64Array::from(vec![Some(0), None]); + + let input = TimestampNanosecondArray::from_vec( + vec![1_639_715_368_000_000_000, 1_639_715_368_000_000_000], + Some("UTC".to_owned()), + ); + let result = take_impl(&input, &index, None).unwrap(); + match result.data_type() { + DataType::Timestamp(TimeUnit::Nanosecond, tz) => { + assert_eq!(tz.clone(), Some("UTC".to_owned())) + } + _ => panic!(), + } + } + #[test] fn test_take_impl_primitive_with_int64_indices() { let index = Int64Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]); @@ -1429,9 +1571,7 @@ mod tests { let expected_list_data = ArrayData::builder(list_data_type) .len(5) // null buffer remains the same as only the indices have nulls - .null_bit_buffer( - index.data().null_bitmap().as_ref().unwrap().bits.clone(), - ) + .null_bit_buffer(index.data().null_buffer().cloned()) .add_buffer(expected_offsets) .add_child_data(expected_data) .build() @@ -1470,7 +1610,7 @@ mod tests { let list_data = ArrayData::builder(list_data_type.clone()) .len(4) 
.add_buffer(value_offsets) - .null_bit_buffer(Buffer::from([0b10111101, 0b00000000])) + .null_bit_buffer(Some(Buffer::from([0b10111101, 0b00000000]))) .add_child_data(value_data) .build() .unwrap(); @@ -1505,9 +1645,7 @@ mod tests { let expected_list_data = ArrayData::builder(list_data_type) .len(5) // null buffer remains the same as only the indices have nulls - .null_bit_buffer( - index.data().null_bitmap().as_ref().unwrap().bits.clone(), - ) + .null_bit_buffer(index.data().null_buffer().cloned()) .add_buffer(expected_offsets) .add_child_data(expected_data) .build() @@ -1545,7 +1683,7 @@ mod tests { let list_data = ArrayData::builder(list_data_type.clone()) .len(4) .add_buffer(value_offsets) - .null_bit_buffer(Buffer::from([0b01111101])) + .null_bit_buffer(Some(Buffer::from([0b01111101]))) .add_child_data(value_data) .build() .unwrap(); @@ -1583,7 +1721,7 @@ mod tests { let expected_list_data = ArrayData::builder(list_data_type) .len(5) // null buffer must be recalculated as both values and indices have nulls - .null_bit_buffer(Buffer::from(null_bits)) + .null_bit_buffer(Some(Buffer::from(null_bits))) .add_buffer(expected_offsets) .add_child_data(expected_data) .build() @@ -1813,6 +1951,38 @@ mod tests { .unwrap(); } + #[test] + fn test_null_array_smaller_than_indices() { + let values = NullArray::new(2); + let indices = UInt32Array::from(vec![Some(0), None, Some(15)]); + + let result = take(&values, &indices, None).unwrap(); + let expected: ArrayRef = Arc::new(NullArray::new(3)); + assert_eq!(&result, &expected); + } + + #[test] + fn test_null_array_larger_than_indices() { + let values = NullArray::new(5); + let indices = UInt32Array::from(vec![Some(0), None, Some(15)]); + + let result = take(&values, &indices, None).unwrap(); + let expected: ArrayRef = Arc::new(NullArray::new(3)); + assert_eq!(&result, &expected); + } + + #[test] + fn test_null_array_indices_out_of_bounds() { + let values = NullArray::new(5); + let indices = UInt32Array::from(vec![Some(0), None, Some(15)]); + + let result = take(&values, &indices, Some(TakeOptions { check_bounds: true })); + assert_eq!( + result.unwrap_err().to_string(), + "Compute error: Array index out of bounds, cannot get item at index 15 from 5 entries" + ); + } + #[test] fn test_take_dict() { let keys_builder = Int16Builder::new(8); diff --git a/arrow/src/compute/kernels/temporal.rs b/arrow/src/compute/kernels/temporal.rs index 24559b0f8cd4..f731ff5ce745 100644 --- a/arrow/src/compute/kernels/temporal.rs +++ b/arrow/src/compute/kernels/temporal.rs @@ -17,7 +17,7 @@ //! Defines temporal kernels for time and date related functions. -use chrono::{Datelike, Timelike}; +use chrono::{Datelike, NaiveDate, NaiveDateTime, Timelike}; use crate::array::*; use crate::datatypes::*; @@ -40,6 +40,20 @@ macro_rules! extract_component_from_array { } } }; + ($array:ident, $builder:ident, $extract_fn1:ident, $extract_fn2:ident, $using:ident) => { + for i in 0..$array.len() { + if $array.is_null(i) { + $builder.append_null()?; + } else { + match $array.$using(i) { + Some(dt) => { + $builder.append_value(dt.$extract_fn1().$extract_fn2() as i32)? + } + None => $builder.append_null()?, + } + } + } + }; ($array:ident, $builder:ident, $extract_fn:ident, $using:ident, $tz:ident, $parsed:ident) => { if ($tz.starts_with('+') || $tz.starts_with('-')) && !$tz.contains(':') { return_compute_error_with!( @@ -47,23 +61,44 @@ macro_rules! 
extract_component_from_array { "Expected format [+-]XX:XX".to_string() ) } else { - let fixed_offset = match parse(&mut $parsed, $tz, StrftimeItems::new("%z")) { + let tz_parse_result = parse(&mut $parsed, $tz, StrftimeItems::new("%z")); + let fixed_offset_from_parsed = match tz_parse_result { Ok(_) => match $parsed.to_fixed_offset() { - Ok(fo) => fo, + Ok(fo) => Some(fo), err => return_compute_error_with!("Invalid timezone", err), }, - _ => match using_chrono_tz($tz) { - Some(fo) => fo, - err => return_compute_error_with!("Unable to parse timezone", err), - }, + _ => None, }; + for i in 0..$array.len() { if $array.is_null(i) { $builder.append_null()?; } else { - match $array.$using(i, fixed_offset) { - Some(dt) => $builder.append_value(dt.$extract_fn() as i32)?, - None => $builder.append_null()?, + match $array.value_as_datetime(i) { + Some(utc) => { + let fixed_offset = match fixed_offset_from_parsed { + Some(fo) => fo, + None => match using_chrono_tz_and_utc_naive_date_time( + $tz, utc, + ) { + Some(fo) => fo, + err => return_compute_error_with!( + "Unable to parse timezone", + err + ), + }, + }; + match $array.$using(i, fixed_offset) { + Some(dt) => { + $builder.append_value(dt.$extract_fn() as i32)? + } + None => $builder.append_null()?, + } + } + err => return_compute_error_with!( + "Unable to read value as datetime", + err + ), } } } @@ -77,21 +112,54 @@ macro_rules! return_compute_error_with { }; } -/// Parse the given string into a string representing fixed-offset +trait ChronoDateQuarter { + /// Returns a value in range `1..=4` indicating the quarter this date falls into + fn quarter(&self) -> u32; + + /// Returns a value in range `0..=3` indicating the quarter (zero-based) this date falls into + fn quarter0(&self) -> u32; +} + +impl ChronoDateQuarter for NaiveDateTime { + fn quarter(&self) -> u32 { + self.quarter0() + 1 + } + + fn quarter0(&self) -> u32 { + self.month0() / 3 + } +} + +impl ChronoDateQuarter for NaiveDate { + fn quarter(&self) -> u32 { + self.quarter0() + 1 + } + + fn quarter0(&self) -> u32 { + self.month0() / 3 + } +} + #[cfg(not(feature = "chrono-tz"))] -pub fn using_chrono_tz(_: &str) -> Option { +pub fn using_chrono_tz_and_utc_naive_date_time( + _tz: &str, + _utc: chrono::NaiveDateTime, +) -> Option { None } -/// Parse the given string into a string representing fixed-offset +/// Parse the given string into a string representing fixed-offset that is correct as of the given +/// UTC NaiveDateTime. +/// Note that the offset is function of time and can vary depending on whether daylight savings is +/// in effect or not. e.g. Australia/Sydney is +10:00 or +11:00 depending on DST. 
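// Illustration (hedged, assumes the `chrono-tz` feature and chrono's `TimeZone`/`Offset`
// traits; not part of this change): resolving the offset for one concrete UTC instant,
// which is what the function below does internally.
//     use chrono::{NaiveDate, Offset, TimeZone};
//     let tz: chrono_tz::Tz = "Australia/Sydney".parse().unwrap();
//     let utc = NaiveDate::from_ymd(2021, 12, 25).and_hms(0, 0, 0);
//     // 2021-12-25T00:00:00Z falls inside AEDT, so the resolved offset is +11:00
//     assert_eq!(tz.offset_from_utc_datetime(&utc).fix(), chrono::FixedOffset::east(11 * 3600));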
#[cfg(feature = "chrono-tz")] -pub fn using_chrono_tz(tz: &str) -> Option { +pub fn using_chrono_tz_and_utc_naive_date_time( + tz: &str, + utc: chrono::NaiveDateTime, +) -> Option { use chrono::{Offset, TimeZone}; tz.parse::() - .map(|tz| { - tz.offset_from_utc_datetime(&chrono::NaiveDateTime::from_timestamp(0, 0)) - .fix() - }) + .map(|tz| tz.offset_from_utc_datetime(&utc).fix()) .ok() } @@ -143,6 +211,118 @@ where Ok(b.finish()) } +/// Extracts the quarter of a given temporal array as an array of integers +pub fn quarter(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: std::convert::From, +{ + let mut b = Int32Builder::new(array.len()); + match array.data_type() { + &DataType::Date32 | &DataType::Date64 | &DataType::Timestamp(_, None) => { + extract_component_from_array!(array, b, quarter, value_as_datetime) + } + &DataType::Timestamp(_, Some(ref tz)) => { + let mut scratch = Parsed::new(); + extract_component_from_array!( + array, + b, + quarter, + value_as_datetime_with_tz, + tz, + scratch + ) + } + dt => return_compute_error_with!("quarter does not support", dt), + } + + Ok(b.finish()) +} + +/// Extracts the month of a given temporal array as an array of integers +pub fn month(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: std::convert::From, +{ + let mut b = Int32Builder::new(array.len()); + match array.data_type() { + &DataType::Date32 | &DataType::Date64 | &DataType::Timestamp(_, None) => { + extract_component_from_array!(array, b, month, value_as_datetime) + } + &DataType::Timestamp(_, Some(ref tz)) => { + let mut scratch = Parsed::new(); + extract_component_from_array!( + array, + b, + month, + value_as_datetime_with_tz, + tz, + scratch + ) + } + dt => return_compute_error_with!("month does not support", dt), + } + + Ok(b.finish()) +} + +/// Extracts the day of week of a given temporal array as an array of integers +pub fn weekday(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: std::convert::From, +{ + let mut b = Int32Builder::new(array.len()); + match array.data_type() { + &DataType::Date32 | &DataType::Date64 | &DataType::Timestamp(_, None) => { + extract_component_from_array!(array, b, weekday, value_as_datetime) + } + &DataType::Timestamp(_, Some(ref tz)) => { + let mut scratch = Parsed::new(); + extract_component_from_array!( + array, + b, + weekday, + value_as_datetime_with_tz, + tz, + scratch + ) + } + dt => return_compute_error_with!("weekday does not support", dt), + } + + Ok(b.finish()) +} + +/// Extracts the day of a given temporal array as an array of integers +pub fn day(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: std::convert::From, +{ + let mut b = Int32Builder::new(array.len()); + match array.data_type() { + &DataType::Date32 | &DataType::Date64 | &DataType::Timestamp(_, None) => { + extract_component_from_array!(array, b, day, value_as_datetime) + } + &DataType::Timestamp(_, Some(ref tz)) => { + let mut scratch = Parsed::new(); + extract_component_from_array!( + array, + b, + day, + value_as_datetime_with_tz, + tz, + scratch + ) + } + dt => return_compute_error_with!("day does not support", dt), + } + + Ok(b.finish()) +} + /// Extracts the minutes of a given temporal array as an array of integers pub fn minute(array: &PrimitiveArray) -> Result where @@ -171,6 +351,24 @@ where Ok(b.finish()) } +/// Extracts the week of a given temporal array as an array of integers +pub fn week(array: 
&PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: std::convert::From, +{ + let mut b = Int32Builder::new(array.len()); + + match array.data_type() { + &DataType::Date32 | &DataType::Date64 | &DataType::Timestamp(_, None) => { + extract_component_from_array!(array, b, iso_week, week, value_as_datetime) + } + dt => return_compute_error_with!("week does not support", dt), + } + + Ok(b.finish()) +} + /// Extracts the seconds of a given temporal array as an array of integers pub fn second(array: &PrimitiveArray) -> Result where @@ -202,6 +400,8 @@ where #[cfg(test)] mod tests { use super::*; + #[cfg(feature = "chrono-tz")] + use chrono::NaiveDate; #[test] fn test_temporal_array_date64_hour() { @@ -273,6 +473,145 @@ mod tests { assert_eq!(2012, b.value(2)); } + #[test] + fn test_temporal_array_date64_quarter() { + //1514764800000 -> 2018-01-01 + //1566275025000 -> 2019-08-20 + let a: PrimitiveArray = + vec![Some(1514764800000), None, Some(1566275025000)].into(); + + let b = quarter(&a).unwrap(); + assert_eq!(1, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(3, b.value(2)); + } + + #[test] + fn test_temporal_array_date32_quarter() { + let a: PrimitiveArray = vec![Some(1), None, Some(300)].into(); + + let b = quarter(&a).unwrap(); + assert_eq!(1, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(4, b.value(2)); + } + + #[test] + fn test_temporal_array_timestamp_quarter_with_timezone() { + use std::sync::Arc; + + // 24 * 60 * 60 = 86400 + let a = Arc::new(TimestampSecondArray::from_vec( + vec![86400 * 90], + Some("+00:00".to_string()), + )); + let b = quarter(&a).unwrap(); + assert_eq!(2, b.value(0)); + let a = Arc::new(TimestampSecondArray::from_vec( + vec![86400 * 90], + Some("-10:00".to_string()), + )); + let b = quarter(&a).unwrap(); + assert_eq!(1, b.value(0)); + } + + #[test] + fn test_temporal_array_date64_month() { + //1514764800000 -> 2018-01-01 + //1550636625000 -> 2019-02-20 + let a: PrimitiveArray = + vec![Some(1514764800000), None, Some(1550636625000)].into(); + + let b = month(&a).unwrap(); + assert_eq!(1, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(2, b.value(2)); + } + + #[test] + fn test_temporal_array_date32_month() { + let a: PrimitiveArray = vec![Some(1), None, Some(31)].into(); + + let b = month(&a).unwrap(); + assert_eq!(1, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(2, b.value(2)); + } + + #[test] + fn test_temporal_array_timestamp_month_with_timezone() { + use std::sync::Arc; + + // 24 * 60 * 60 = 86400 + let a = Arc::new(TimestampSecondArray::from_vec( + vec![86400 * 31], + Some("+00:00".to_string()), + )); + let b = month(&a).unwrap(); + assert_eq!(2, b.value(0)); + let a = Arc::new(TimestampSecondArray::from_vec( + vec![86400 * 31], + Some("-10:00".to_string()), + )); + let b = month(&a).unwrap(); + assert_eq!(1, b.value(0)); + } + + #[test] + fn test_temporal_array_timestamp_day_with_timezone() { + use std::sync::Arc; + + // 24 * 60 * 60 = 86400 + let a = Arc::new(TimestampSecondArray::from_vec( + vec![86400], + Some("+00:00".to_string()), + )); + let b = day(&a).unwrap(); + assert_eq!(2, b.value(0)); + let a = Arc::new(TimestampSecondArray::from_vec( + vec![86400], + Some("-10:00".to_string()), + )); + let b = day(&a).unwrap(); + assert_eq!(1, b.value(0)); + } + + #[test] + fn test_temporal_array_date64_weekday() { + //1514764800000 -> 2018-01-01 (Monday) + //1550636625000 -> 2019-02-20 (Wednesday) + let a: PrimitiveArray = + vec![Some(1514764800000), None, Some(1550636625000)].into(); + + let b = 
weekday(&a).unwrap(); + assert_eq!(0, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(2, b.value(2)); + } + + #[test] + fn test_temporal_array_date64_day() { + //1514764800000 -> 2018-01-01 + //1550636625000 -> 2019-02-20 + let a: PrimitiveArray = + vec![Some(1514764800000), None, Some(1550636625000)].into(); + + let b = day(&a).unwrap(); + assert_eq!(1, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(20, b.value(2)); + } + + #[test] + fn test_temporal_array_date32_day() { + let a: PrimitiveArray = vec![Some(0), None, Some(31)].into(); + + let b = day(&a).unwrap(); + assert_eq!(1, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(1, b.value(2)); + } + #[test] fn test_temporal_array_timestamp_micro_year() { let a: TimestampMicrosecondArray = @@ -306,6 +645,48 @@ mod tests { assert_eq!(44, b.value(2)); } + #[test] + fn test_temporal_array_date32_week() { + let a: PrimitiveArray = vec![Some(0), None, Some(7)].into(); + + let b = week(&a).unwrap(); + assert_eq!(1, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(2, b.value(2)); + } + + #[test] + fn test_temporal_array_date64_week() { + // 1646116175000 -> 2022.03.01 , 1641171600000 -> 2022.01.03 + // 1640998800000 -> 2022.01.01 + let a: PrimitiveArray = vec![ + Some(1646116175000), + None, + Some(1641171600000), + Some(1640998800000), + ] + .into(); + + let b = week(&a).unwrap(); + assert_eq!(9, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(1, b.value(2)); + assert_eq!(52, b.value(3)); + } + + #[test] + fn test_temporal_array_timestamp_micro_week() { + //1612025847000000 -> 2021.1.30 + //1722015847000000 -> 2024.7.27 + let a: TimestampMicrosecondArray = + vec![Some(1612025847000000), None, Some(1722015847000000)].into(); + + let b = week(&a).unwrap(); + assert_eq!(4, b.value(0)); + assert!(!b.is_valid(1)); + assert_eq!(30, b.value(2)); + } + #[test] fn test_temporal_array_date64_second() { let a: PrimitiveArray = @@ -424,6 +805,22 @@ mod tests { assert_eq!(15, b.value(0)); } + #[cfg(feature = "chrono-tz")] + #[test] + fn test_temporal_array_timestamp_hour_with_dst_timezone_using_chrono_tz() { + // + // 1635577147 converts to 2021-10-30 17:59:07 in time zone Australia/Sydney (AEDT) + // The offset (difference to UTC) is +11:00. Note that daylight savings is in effect on 2021-10-30. + // When daylight savings is not in effect, Australia/Sydney has an offset difference of +10:00. + + let a = TimestampMillisecondArray::from_opt_vec( + vec![Some(1635577147000)], + Some("Australia/Sydney".to_string()), + ); + let b = hour(&a).unwrap(); + assert_eq!(17, b.value(0)); + } + #[cfg(not(feature = "chrono-tz"))] #[test] fn test_temporal_array_timestamp_hour_with_timezone_using_chrono_tz() { @@ -435,4 +832,62 @@ mod tests { )); assert!(matches!(hour(&a), Err(ArrowError::ComputeError(_)))) } + + #[cfg(feature = "chrono-tz")] + #[test] + fn test_using_chrono_tz_and_utc_naive_date_time() { + let sydney_tz = "Australia/Sydney".to_string(); + let sydney_offset_without_dst = FixedOffset::east(10 * 60 * 60); + let sydney_offset_with_dst = FixedOffset::east(11 * 60 * 60); + // Daylight savings ends + // When local daylight time was about to reach + // Sunday, 4 April 2021, 3:00:00 am clocks were turned backward 1 hour to + // Sunday, 4 April 2021, 2:00:00 am local standard time instead. + + // Daylight savings starts + // When local standard time was about to reach + // Sunday, 3 October 2021, 2:00:00 am clocks were turned forward 1 hour to + // Sunday, 3 October 2021, 3:00:00 am local daylight time instead. 
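// Each assertion below feeds a UTC instant just before or after one of these transitions
// and checks whether the resolved fixed offset is the daylight (+11:00) or standard
// (+10:00) variant.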
+ + // Sydney 2021-04-04T02:30:00+11:00 is 2021-04-03T15:30:00Z + let utc_just_before_sydney_dst_ends = + NaiveDate::from_ymd(2021, 4, 3).and_hms_nano(15, 30, 0, 0); + assert_eq!( + using_chrono_tz_and_utc_naive_date_time( + &sydney_tz, + utc_just_before_sydney_dst_ends + ), + Some(sydney_offset_with_dst) + ); + // Sydney 2021-04-04T02:30:00+10:00 is 2021-04-03T16:30:00Z + let utc_just_after_sydney_dst_ends = + NaiveDate::from_ymd(2021, 4, 3).and_hms_nano(16, 30, 0, 0); + assert_eq!( + using_chrono_tz_and_utc_naive_date_time( + &sydney_tz, + utc_just_after_sydney_dst_ends + ), + Some(sydney_offset_without_dst) + ); + // Sydney 2021-10-03T01:30:00+10:00 is 2021-10-02T15:30:00Z + let utc_just_before_sydney_dst_starts = + NaiveDate::from_ymd(2021, 10, 2).and_hms_nano(15, 30, 0, 0); + assert_eq!( + using_chrono_tz_and_utc_naive_date_time( + &sydney_tz, + utc_just_before_sydney_dst_starts + ), + Some(sydney_offset_without_dst) + ); + // Sydney 2021-04-04T03:30:00+11:00 is 2021-10-02T16:30:00Z + let utc_just_after_sydney_dst_starts = + NaiveDate::from_ymd(2022, 10, 2).and_hms_nano(16, 30, 0, 0); + assert_eq!( + using_chrono_tz_and_utc_naive_date_time( + &sydney_tz, + utc_just_after_sydney_dst_starts + ), + Some(sydney_offset_with_dst) + ); + } } diff --git a/arrow/src/compute/kernels/window.rs b/arrow/src/compute/kernels/window.rs index b354b4aa3bd8..54b11c3b2747 100644 --- a/arrow/src/compute/kernels/window.rs +++ b/arrow/src/compute/kernels/window.rs @@ -23,7 +23,7 @@ use crate::{ array::{make_array, new_null_array}, compute::concat, }; -use num::{abs, clamp}; +use num::abs; /// Shifts array by defined number of items (to left or right) /// A positive value for `offset` shifts the array to the right @@ -51,7 +51,7 @@ use num::{abs, clamp}; /// let expected: Int32Array = vec![Some(1), None, Some(4)].into(); /// assert_eq!(res.as_ref(), &expected); /// -/// // shift array 3 element tot he right +/// // shift array 3 element to the right /// let res = shift(&a, 3).unwrap(); /// let expected: Int32Array = vec![None, None, None].into(); /// assert_eq!(res.as_ref(), &expected); @@ -63,18 +63,21 @@ pub fn shift(array: &dyn Array, offset: i64) -> Result { } else if offset == i64::MIN || abs(offset) >= value_len { Ok(new_null_array(array.data_type(), array.len())) } else { - let slice_offset = clamp(-offset, 0, value_len) as usize; - let length = array.len() - abs(offset) as usize; - let slice = array.slice(slice_offset, length); - - // Generate array with remaining `null` items - let nulls = abs(offset) as usize; - let null_arr = new_null_array(array.data_type(), nulls); - // Concatenate both arrays, add nulls after if shift > 0 else before if offset > 0 { + let length = array.len() - offset as usize; + let slice = array.slice(0, length); + + // Generate array with remaining `null` items + let null_arr = new_null_array(array.data_type(), offset as usize); concat(&[null_arr.as_ref(), slice.as_ref()]) } else { + let offset = -offset as usize; + let length = array.len() - offset; + let slice = array.slice(offset, length); + + // Generate array with remaining `null` items + let null_arr = new_null_array(array.data_type(), offset); concat(&[slice.as_ref(), null_arr.as_ref()]) } } diff --git a/arrow/src/compute/util.rs b/arrow/src/compute/util.rs index f4ddbaf56d1e..7d5837261059 100644 --- a/arrow/src/compute/util.rs +++ b/arrow/src/compute/util.rs @@ -18,78 +18,45 @@ //! Common utilities for computation kernels. 
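// Worked illustration of the n-ary null-bitmap combination implemented below (values taken
// from the tests in this file): AND-combining the validity bitmaps 0b01001010 and 0b10110101
// yields 0b00000000, i.e. a slot stays valid only if it is valid in every input, while an
// input with no null bitmap is treated as all-valid and leaves the other bitmaps unchanged.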
use crate::array::*; -use crate::buffer::{buffer_bin_and, buffer_bin_or, Buffer}; +use crate::buffer::{buffer_bin_and, Buffer}; use crate::datatypes::*; use crate::error::{ArrowError, Result}; use num::{One, ToPrimitive, Zero}; use std::ops::Add; -/// Combines the null bitmaps of two arrays using a bitwise `and` operation. +/// Combines the null bitmaps of multiple arrays using a bitwise `and` operation. /// /// This function is useful when implementing operations on higher level arrays. #[allow(clippy::unnecessary_wraps)] pub(super) fn combine_option_bitmap( - left_data: &ArrayData, - right_data: &ArrayData, + arrays: &[&ArrayData], len_in_bits: usize, ) -> Result> { - let left_offset_in_bits = left_data.offset(); - let right_offset_in_bits = right_data.offset(); - - let left = left_data.null_buffer(); - let right = right_data.null_buffer(); - - match left { - None => match right { - None => Ok(None), - Some(r) => Ok(Some(r.bit_slice(right_offset_in_bits, len_in_bits))), - }, - Some(l) => match right { - None => Ok(Some(l.bit_slice(left_offset_in_bits, len_in_bits))), - - Some(r) => Ok(Some(buffer_bin_and( - l, - left_offset_in_bits, - r, - right_offset_in_bits, - len_in_bits, - ))), - }, - } -} - -/// Compares the null bitmaps of two arrays using a bitwise `or` operation. -/// -/// This function is useful when implementing operations on higher level arrays. -#[allow(clippy::unnecessary_wraps)] -pub(super) fn compare_option_bitmap( - left_data: &ArrayData, - right_data: &ArrayData, - len_in_bits: usize, -) -> Result> { - let left_offset_in_bits = left_data.offset(); - let right_offset_in_bits = right_data.offset(); - - let left = left_data.null_buffer(); - let right = right_data.null_buffer(); - - match left { - None => match right { - None => Ok(None), - Some(r) => Ok(Some(r.bit_slice(right_offset_in_bits, len_in_bits))), - }, - Some(l) => match right { - None => Ok(Some(l.bit_slice(left_offset_in_bits, len_in_bits))), - - Some(r) => Ok(Some(buffer_bin_or( - l, - left_offset_in_bits, - r, - right_offset_in_bits, - len_in_bits, - ))), - }, - } + arrays + .iter() + .map(|array| (array.null_buffer().cloned(), array.offset())) + .reduce(|acc, buffer_and_offset| match (acc, buffer_and_offset) { + ((None, _), (None, _)) => (None, 0), + ((Some(buffer), offset), (None, _)) | ((None, _), (Some(buffer), offset)) => { + (Some(buffer), offset) + } + ((Some(buffer_left), offset_left), (Some(buffer_right), offset_right)) => ( + Some(buffer_bin_and( + &buffer_left, + offset_left, + &buffer_right, + offset_right, + len_in_bits, + )), + 0, + ), + }) + .map_or( + Err(ArrowError::ComputeError( + "Arrays must not be empty".to_string(), + )), + |(buffer, offset)| Ok(buffer.map(|buffer| buffer.slice(offset))), + ) } /// Takes/filters a list array's inner data using the offsets of the list array. @@ -175,24 +142,58 @@ pub(super) mod tests { use std::sync::Arc; + use crate::buffer::buffer_bin_or; use crate::datatypes::DataType; use crate::util::bit_util; use crate::{array::ArrayData, buffer::MutableBuffer}; + /// Compares the null bitmaps of two arrays using a bitwise `or` operation. + /// + /// This function is useful when implementing operations on higher level arrays. 
+ pub(super) fn compare_option_bitmap( + left_data: &ArrayData, + right_data: &ArrayData, + len_in_bits: usize, + ) -> Result> { + let left_offset_in_bits = left_data.offset(); + let right_offset_in_bits = right_data.offset(); + + let left = left_data.null_buffer(); + let right = right_data.null_buffer(); + + match left { + None => match right { + None => Ok(None), + Some(r) => Ok(Some(r.bit_slice(right_offset_in_bits, len_in_bits))), + }, + Some(l) => match right { + None => Ok(Some(l.bit_slice(left_offset_in_bits, len_in_bits))), + + Some(r) => Ok(Some(buffer_bin_or( + l, + left_offset_in_bits, + r, + right_offset_in_bits, + len_in_bits, + ))), + }, + } + } + fn make_data_with_null_bit_buffer( len: usize, offset: usize, null_bit_buffer: Option, ) -> Arc { - // empty vec for buffers and children is not really correct, but for these tests we only care about the null bitmap + let buffer = Buffer::from(&vec![11; len]); + Arc::new( ArrayData::try_new( DataType::UInt8, len, - None, null_bit_buffer, offset, - vec![], + vec![buffer], vec![], ) .unwrap(), @@ -206,25 +207,52 @@ pub(super) mod tests { make_data_with_null_bit_buffer(8, 0, Some(Buffer::from([0b01001010]))); let inverse_bitmap = make_data_with_null_bit_buffer(8, 0, Some(Buffer::from([0b10110101]))); + let some_other_bitmap = + make_data_with_null_bit_buffer(8, 0, Some(Buffer::from([0b11010111]))); assert_eq!( - None, - combine_option_bitmap(&none_bitmap, &none_bitmap, 8).unwrap() + combine_option_bitmap(&[], 8).unwrap_err().to_string(), + "Compute error: Arrays must not be empty", ); assert_eq!( Some(Buffer::from([0b01001010])), - combine_option_bitmap(&some_bitmap, &none_bitmap, 8).unwrap() + combine_option_bitmap(&[&some_bitmap], 8).unwrap() + ); + assert_eq!( + None, + combine_option_bitmap(&[&none_bitmap, &none_bitmap], 8).unwrap() ); assert_eq!( Some(Buffer::from([0b01001010])), - combine_option_bitmap(&none_bitmap, &some_bitmap, 8,).unwrap() + combine_option_bitmap(&[&some_bitmap, &none_bitmap], 8).unwrap() + ); + assert_eq!( + Some(Buffer::from([0b11010111])), + combine_option_bitmap(&[&none_bitmap, &some_other_bitmap], 8).unwrap() ); assert_eq!( Some(Buffer::from([0b01001010])), - combine_option_bitmap(&some_bitmap, &some_bitmap, 8,).unwrap() + combine_option_bitmap(&[&some_bitmap, &some_bitmap], 8,).unwrap() ); assert_eq!( Some(Buffer::from([0b0])), - combine_option_bitmap(&some_bitmap, &inverse_bitmap, 8,).unwrap() + combine_option_bitmap(&[&some_bitmap, &inverse_bitmap], 8,).unwrap() + ); + assert_eq!( + Some(Buffer::from([0b01000010])), + combine_option_bitmap(&[&some_bitmap, &some_other_bitmap, &none_bitmap], 8,) + .unwrap() + ); + assert_eq!( + Some(Buffer::from([0b00001001])), + combine_option_bitmap( + &[ + &some_bitmap.slice(3, 5), + &inverse_bitmap.slice(2, 5), + &some_other_bitmap.slice(1, 5) + ], + 5, + ) + .unwrap() ); } @@ -300,7 +328,7 @@ pub(super) mod tests { values.append(&mut array); } else { list_null_count += 1; - bit_util::unset_bit(&mut list_bitmap.as_slice_mut(), idx); + bit_util::unset_bit(list_bitmap.as_slice_mut(), idx); } offset.push(values.len() as i64); } @@ -333,7 +361,7 @@ pub(super) mod tests { let list_data = ArrayData::builder(list_data_type) .len(list_len) - .null_bit_buffer(list_bitmap.into()) + .null_bit_buffer(Some(list_bitmap.into())) .add_buffer(value_offsets) .add_child_data(value_data) .build() @@ -342,27 +370,6 @@ pub(super) mod tests { GenericListArray::::from(list_data) } - pub(crate) fn build_fixed_size_list( - data: Vec>>, - length: ::Native, - ) -> FixedSizeListArray - where - 
T: ArrowPrimitiveType, - PrimitiveArray: From>>, - { - let data = data - .into_iter() - .map(|subarray| { - subarray.map(|item| { - item.into_iter() - .map(Some) - .collect::>>() - }) - }) - .collect(); - build_fixed_size_list_nullable(data, length) - } - pub(crate) fn build_fixed_size_list_nullable( list_values: Vec>>>, length: ::Native, @@ -385,7 +392,7 @@ pub(super) mod tests { values.extend(items.into_iter()); } else { list_null_count += 1; - bit_util::unset_bit(&mut list_bitmap.as_slice_mut(), idx); + bit_util::unset_bit(list_bitmap.as_slice_mut(), idx); values.extend(vec![None; length as usize].into_iter()); } } @@ -399,7 +406,7 @@ pub(super) mod tests { let list_data = ArrayData::builder(list_data_type) .len(list_len) - .null_bit_buffer(list_bitmap.into()) + .null_bit_buffer(Some(list_bitmap.into())) .add_child_data(child_data) .build() .unwrap(); diff --git a/arrow/src/csv/reader.rs b/arrow/src/csv/reader.rs index b68ac1b378fb..21e107ee4c8e 100644 --- a/arrow/src/csv/reader.rs +++ b/arrow/src/csv/reader.rs @@ -36,7 +36,7 @@ //! //! let file = File::open("test/data/uk_cities.csv").unwrap(); //! -//! let mut csv = csv::Reader::new(file, Arc::new(schema), false, None, 1024, None, None); +//! let mut csv = csv::Reader::new(file, Arc::new(schema), false, None, 1024, None, None, None); //! let batch = csv.next().unwrap().unwrap(); //! ``` @@ -50,17 +50,21 @@ use std::io::{Read, Seek, SeekFrom}; use std::sync::Arc; use crate::array::{ - ArrayRef, BooleanArray, DictionaryArray, PrimitiveArray, StringArray, + ArrayRef, BooleanArray, DecimalBuilder, DictionaryArray, PrimitiveArray, StringArray, }; -use crate::compute::kernels::cast_utils::string_to_timestamp_nanos; use crate::datatypes::*; use crate::error::{ArrowError, Result}; use crate::record_batch::RecordBatch; +use crate::util::reader_parser::Parser; use csv_crate::{ByteRecord, StringRecord}; +use std::ops::Neg; lazy_static! { - static ref DECIMAL_RE: Regex = Regex::new(r"^-?(\d+\.\d+)$").unwrap(); + static ref PARSE_DECIMAL_RE: Regex = + Regex::new(r"^-?(\d+\.?\d*|\d*\.?\d+)$").unwrap(); + static ref DECIMAL_RE: Regex = + Regex::new(r"^-?((\d*\.\d+|\d+\.\d*)([eE]-?\d+)?|\d+([eE]-?\d+))$").unwrap(); static ref INTEGER_RE: Regex = Regex::new(r"^-?(\d+)$").unwrap(); static ref BOOLEAN_RE: Regex = RegexBuilder::new(r"^(true)$|^(false)$") .case_insensitive(true) @@ -72,7 +76,8 @@ lazy_static! 
{ } /// Infer the data type of a record -fn infer_field_schema(string: &str) -> DataType { +fn infer_field_schema(string: &str, datetime_re: Option) -> DataType { + let datetime_re = datetime_re.unwrap_or_else(|| DATETIME_RE.clone()); // when quoting is enabled in the reader, these quotes aren't escaped, we default to // Utf8 for them if string.starts_with('"') { @@ -85,7 +90,7 @@ fn infer_field_schema(string: &str) -> DataType { DataType::Float64 } else if INTEGER_RE.is_match(string) { DataType::Int64 - } else if DATETIME_RE.is_match(string) { + } else if datetime_re.is_match(string) { DataType::Date64 } else if DATE_RE.is_match(string) { DataType::Date32 @@ -94,51 +99,50 @@ fn infer_field_schema(string: &str) -> DataType { } } +/// This is a collection of options for csv reader when the builder pattern cannot be used +/// and the parameters need to be passed around +#[derive(Debug, Default, Clone)] +pub struct ReaderOptions { + has_header: bool, + delimiter: Option, + escape: Option, + quote: Option, + terminator: Option, + max_read_records: Option, + datetime_re: Option, +} + /// Infer the schema of a CSV file by reading through the first n records of the file, /// with `max_read_records` controlling the maximum number of records to read. /// /// If `max_read_records` is not set, the whole file is read to infer its schema. /// -/// Return infered schema and number of records used for inference. This function does not change +/// Return inferred schema and number of records used for inference. This function does not change /// reader cursor offset. pub fn infer_file_schema( - reader: &mut R, + reader: R, delimiter: u8, max_read_records: Option, has_header: bool, ) -> Result<(Schema, usize)> { - infer_file_schema_with_csv_options( - reader, - delimiter, + let roptions = ReaderOptions { + delimiter: Some(delimiter), max_read_records, has_header, - None, - None, - None, - ) + ..Default::default() + }; + + infer_file_schema_with_csv_options(reader, roptions) } fn infer_file_schema_with_csv_options( - reader: &mut R, - delimiter: u8, - max_read_records: Option, - has_header: bool, - escape: Option, - quote: Option, - terminator: Option, + mut reader: R, + roptions: ReaderOptions, ) -> Result<(Schema, usize)> { let saved_offset = reader.seek(SeekFrom::Current(0))?; - let (schema, records_count) = infer_reader_schema_with_csv_options( - reader, - delimiter, - max_read_records, - has_header, - escape, - quote, - terminator, - )?; - + let (schema, records_count) = + infer_reader_schema_with_csv_options(&mut reader, roptions)?; // return the reader seek back to the start reader.seek(SeekFrom::Start(saved_offset))?; @@ -152,43 +156,36 @@ fn infer_file_schema_with_csv_options( /// /// Return infered schema and number of records used for inference. 
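The refactor above replaces long positional parameter lists with a ReaderOptions struct, so callers set only the fields they care about and take the rest from Default. A small, self-contained illustration of that pattern (the struct is trimmed to a few fields and the types are simplified here):

#[derive(Debug, Default, Clone)]
struct ReaderOptions {
    has_header: bool,
    delimiter: Option<u8>,
    max_read_records: Option<usize>,
    // the real struct also carries escape, quote, terminator and datetime_re
}

fn main() {
    // Spell out only the interesting options; everything else comes from Default.
    let opts = ReaderOptions {
        delimiter: Some(b'|'),
        has_header: true,
        ..Default::default()
    };
    assert_eq!(opts.max_read_records, None);
    println!("{:?}", opts);
}

This keeps the public infer_file_schema / infer_reader_schema signatures stable while the internal _with_csv_options variants can grow new options without further churn.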
pub fn infer_reader_schema( - reader: &mut R, + reader: R, delimiter: u8, max_read_records: Option, has_header: bool, ) -> Result<(Schema, usize)> { - infer_reader_schema_with_csv_options( - reader, - delimiter, + let roptions = ReaderOptions { + delimiter: Some(delimiter), max_read_records, has_header, - None, - None, - None, - ) + ..Default::default() + }; + infer_reader_schema_with_csv_options(reader, roptions) } fn infer_reader_schema_with_csv_options( - reader: &mut R, - delimiter: u8, - max_read_records: Option, - has_header: bool, - escape: Option, - quote: Option, - terminator: Option, + reader: R, + roptions: ReaderOptions, ) -> Result<(Schema, usize)> { let mut csv_reader = Reader::build_csv_reader( reader, - has_header, - Some(delimiter), - escape, - quote, - terminator, + roptions.has_header, + roptions.delimiter, + roptions.escape, + roptions.quote, + roptions.terminator, ); // get or create header names // when has_header is false, creates default column names with column_ prefix - let headers: Vec = if has_header { + let headers: Vec = if roptions.has_header { let headers = &csv_reader.headers()?.clone(); headers.iter().map(|s| s.to_string()).collect() } else { @@ -208,7 +205,7 @@ fn infer_reader_schema_with_csv_options( let mut fields = vec![]; let mut record = StringRecord::new(); - let max_records = max_read_records.unwrap_or(usize::MAX); + let max_records = roptions.max_read_records.unwrap_or(usize::MAX); while records_count < max_records { if !csv_reader.read_record(&mut record)? { break; @@ -220,7 +217,8 @@ fn infer_reader_schema_with_csv_options( if string.is_empty() { nulls[i] = true; } else { - column_types[i].insert(infer_field_schema(string)); + column_types[i] + .insert(infer_field_schema(string, roptions.datetime_re.clone())); } } } @@ -271,7 +269,7 @@ pub fn infer_schema_from_files( has_header: bool, ) -> Result { let mut schemas = vec![]; - let mut records_to_read = max_read_records.unwrap_or(std::usize::MAX); + let mut records_to_read = max_read_records.unwrap_or(usize::MAX); for fname in files.iter() { let (schema, records_read) = infer_file_schema( @@ -312,6 +310,10 @@ pub struct Reader { batch_size: usize, /// Vector that can hold the `StringRecord`s of the batches batch_records: Vec, + /// datetime format used to parse datetime values, (format understood by chrono) + /// + /// For format refer to [chrono docs](https://docs.rs/chrono/0.4.19/chrono/format/strftime/index.html) + datetime_format: Option, } impl fmt::Debug for Reader @@ -323,6 +325,7 @@ where .field("schema", &self.schema) .field("projection", &self.projection) .field("line_number", &self.line_number) + .field("datetime_format", &self.datetime_format) .finish() } } @@ -333,6 +336,7 @@ impl Reader { /// If reading a `File` or an input that supports `std::io::Read` and `std::io::Seek`; /// you can customise the Reader, such as to enable schema inference, use /// `ReaderBuilder`. + #[allow(clippy::too_many_arguments)] pub fn new( reader: R, schema: SchemaRef, @@ -341,9 +345,17 @@ impl Reader { batch_size: usize, bounds: Bounds, projection: Option>, + datetime_format: Option, ) -> Self { Self::from_reader( - reader, schema, has_header, delimiter, batch_size, bounds, projection, + reader, + schema, + has_header, + delimiter, + batch_size, + bounds, + projection, + datetime_format, ) } @@ -366,6 +378,7 @@ impl Reader { /// /// This constructor allows you more flexibility in what records are processed by the /// csv reader. 
+ #[allow(clippy::too_many_arguments)] pub fn from_reader( reader: R, schema: SchemaRef, @@ -374,11 +387,18 @@ impl Reader { batch_size: usize, bounds: Bounds, projection: Option>, + datetime_format: Option, ) -> Self { let csv_reader = Self::build_csv_reader(reader, has_header, delimiter, None, None, None); Self::from_csv_reader( - csv_reader, schema, has_header, batch_size, bounds, projection, + csv_reader, + schema, + has_header, + batch_size, + bounds, + projection, + datetime_format, ) } @@ -413,6 +433,7 @@ impl Reader { batch_size: usize, bounds: Bounds, projection: Option>, + datetime_format: Option, ) -> Self { let (start, end) = match bounds { None => (0, usize::MAX), @@ -421,7 +442,7 @@ impl Reader { // First we will skip `start` rows // note that this skips by iteration. This is because in general it is not possible - // to seek in CSV. However, skiping still saves the burden of creating arrow arrays, + // to seek in CSV. However, skipping still saves the burden of creating arrow arrays, // which is a slow operation that scales with the number of columns let mut record = ByteRecord::new(); @@ -446,6 +467,7 @@ impl Reader { batch_size, end, batch_records, + datetime_format, } } } @@ -478,13 +500,19 @@ impl Iterator for Reader { return None; } + let format: Option<&str> = match self.datetime_format { + Some(ref format) => Some(format.as_ref()), + _ => None, + }; + // parse the batches into a RecordBatch let result = parse( &self.batch_records[..read_records], self.schema.fields(), Some(self.schema.metadata.clone()), - &self.projection, + self.projection.as_ref(), self.line_number, + format, ); self.line_number += read_records; @@ -493,16 +521,18 @@ impl Iterator for Reader { } } -/// parses a slice of [csv_crate::StringRecord] into a [array::record_batch::RecordBatch]. +/// parses a slice of [csv_crate::StringRecord] into a +/// [RecordBatch](crate::record_batch::RecordBatch). 
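The datetime_format threaded through Reader::new and from_reader above is ultimately applied by parse_formatted when a column is typed Date64. For a single cell, the conversion is roughly the chrono call below (a sketch, not the library's code path; parse_date64 is an illustrative name):

use chrono::NaiveDateTime;

// Parse one cell with an explicit chrono format string and convert it to the
// Date64 representation (milliseconds since the Unix epoch).
fn parse_date64(cell: &str, format: &str) -> Option<i64> {
    NaiveDateTime::parse_from_str(cell, format)
        .ok()
        .map(|dt| dt.timestamp_millis())
}

fn main() {
    // consistent with the parse_formatted expectations in the tests further down
    assert_eq!(
        parse_date64("1900-02-28 12:34:56", "%Y-%m-%d %H:%M:%S"),
        Some(-2203932304000)
    );
}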
fn parse( rows: &[StringRecord], fields: &[Field], metadata: Option>, - projection: &Option>, + projection: Option<&Vec>, line_number: usize, + datetime_format: Option<&str>, ) -> Result { let projection: Vec = match projection { - Some(ref v) => v.clone(), + Some(v) => v.clone(), None => fields.iter().enumerate().map(|(i, _)| i).collect(), }; @@ -513,47 +543,63 @@ fn parse( let field = &fields[i]; match field.data_type() { DataType::Boolean => build_boolean_array(line_number, rows, i), - DataType::Int8 => build_primitive_array::(line_number, rows, i), + DataType::Decimal(precision, scale) => { + build_decimal_array(line_number, rows, i, *precision, *scale) + } + DataType::Int8 => { + build_primitive_array::(line_number, rows, i, None) + } DataType::Int16 => { - build_primitive_array::(line_number, rows, i) + build_primitive_array::(line_number, rows, i, None) } DataType::Int32 => { - build_primitive_array::(line_number, rows, i) + build_primitive_array::(line_number, rows, i, None) } DataType::Int64 => { - build_primitive_array::(line_number, rows, i) + build_primitive_array::(line_number, rows, i, None) } DataType::UInt8 => { - build_primitive_array::(line_number, rows, i) + build_primitive_array::(line_number, rows, i, None) } DataType::UInt16 => { - build_primitive_array::(line_number, rows, i) + build_primitive_array::(line_number, rows, i, None) } DataType::UInt32 => { - build_primitive_array::(line_number, rows, i) + build_primitive_array::(line_number, rows, i, None) } DataType::UInt64 => { - build_primitive_array::(line_number, rows, i) + build_primitive_array::(line_number, rows, i, None) } DataType::Float32 => { - build_primitive_array::(line_number, rows, i) + build_primitive_array::(line_number, rows, i, None) } DataType::Float64 => { - build_primitive_array::(line_number, rows, i) + build_primitive_array::(line_number, rows, i, None) } DataType::Date32 => { - build_primitive_array::(line_number, rows, i) - } - DataType::Date64 => { - build_primitive_array::(line_number, rows, i) + build_primitive_array::(line_number, rows, i, None) } - DataType::Timestamp(TimeUnit::Microsecond, _) => build_primitive_array::< - TimestampMicrosecondType, - >( - line_number, rows, i + DataType::Date64 => build_primitive_array::( + line_number, + rows, + i, + datetime_format, ), + DataType::Timestamp(TimeUnit::Microsecond, _) => { + build_primitive_array::( + line_number, + rows, + i, + None, + ) + } DataType::Timestamp(TimeUnit::Nanosecond, _) => { - build_primitive_array::(line_number, rows, i) + build_primitive_array::( + line_number, + rows, + i, + None, + ) } DataType::Utf8 => Ok(Arc::new( rows.iter().map(|row| row.get(i)).collect::(), @@ -626,105 +672,161 @@ fn parse( arrays.and_then(|arr| RecordBatch::try_new(projected_schema, arr)) } - -/// Specialized parsing implementations -trait Parser: ArrowPrimitiveType { - fn parse(string: &str) -> Option { - string.parse::().ok() - } +fn parse_item(string: &str) -> Option { + T::parse(string) } -impl Parser for Float32Type { - fn parse(string: &str) -> Option { - lexical_core::parse(string.as_bytes()).ok() - } +fn parse_formatted(string: &str, format: &str) -> Option { + T::parse_formatted(string, format) } -impl Parser for Float64Type { - fn parse(string: &str) -> Option { - lexical_core::parse(string.as_bytes()).ok() +fn parse_bool(string: &str) -> Option { + if string.eq_ignore_ascii_case("false") { + Some(false) + } else if string.eq_ignore_ascii_case("true") { + Some(true) + } else { + None } } -impl Parser for UInt64Type {} - -impl Parser 
for UInt32Type {} - -impl Parser for UInt16Type {} - -impl Parser for UInt8Type {} - -impl Parser for Int64Type {} - -impl Parser for Int32Type {} - -impl Parser for Int16Type {} - -impl Parser for Int8Type {} - -/// Number of days between 0001-01-01 and 1970-01-01 -const EPOCH_DAYS_FROM_CE: i32 = 719_163; - -impl Parser for Date32Type { - fn parse(string: &str) -> Option { - use chrono::Datelike; - - match Self::DATA_TYPE { - DataType::Date32 => { - let date = string.parse::().ok()?; - Self::Native::from_i32(date.num_days_from_ce() - EPOCH_DAYS_FROM_CE) +// parse the column string to an Arrow Array +fn build_decimal_array( + _line_number: usize, + rows: &[StringRecord], + col_idx: usize, + precision: usize, + scale: usize, +) -> Result { + let mut decimal_builder = DecimalBuilder::new(rows.len(), precision, scale); + for row in rows { + let col_s = row.get(col_idx); + match col_s { + None => { + // No data for this row + decimal_builder.append_null()?; + } + Some(s) => { + if s.is_empty() { + // append null + decimal_builder.append_null()?; + } else { + let decimal_value: Result = + parse_decimal_with_parameter(s, precision, scale); + match decimal_value { + Ok(v) => { + decimal_builder.append_value(v)?; + } + Err(e) => { + return Err(e); + } + } + } } - _ => None, } } + Ok(Arc::new(decimal_builder.finish())) } -impl Parser for Date64Type { - fn parse(string: &str) -> Option { - match Self::DATA_TYPE { - DataType::Date64 => { - let date_time = string.parse::().ok()?; - Self::Native::from_i64(date_time.timestamp_millis()) +// Parse the string format decimal value to i128 format and checking the precision and scale. +// The result i128 value can't be out of bounds. +fn parse_decimal_with_parameter(s: &str, precision: usize, scale: usize) -> Result { + if PARSE_DECIMAL_RE.is_match(s) { + let mut offset = s.len(); + let len = s.len(); + let mut base = 1; + + // handle the value after the '.' and meet the scale + let delimiter_position = s.find('.'); + match delimiter_position { + None => { + // there is no '.' + base = 10_i128.pow(scale as u32); } - _ => None, - } - } -} + Some(mid) => { + // there is the '.' + if len - mid >= scale + 1 { + // If the string value is "123.12345" and the scale is 2, we should just remain '.12' and drop the '345' value. + offset -= len - mid - 1 - scale; + } else { + // If the string value is "123.12" and the scale is 4, we should append '00' to the tail. + base = 10_i128.pow((scale + 1 + mid - len) as u32); + } + } + }; + + // each byte is digit、'-' or '.' + let bytes = s.as_bytes(); + let mut negative = false; + let mut result: i128 = 0; -impl Parser for TimestampNanosecondType { - fn parse(string: &str) -> Option { - match Self::DATA_TYPE { - DataType::Timestamp(TimeUnit::Nanosecond, None) => { - string_to_timestamp_nanos(string).ok() + bytes[0..offset].iter().rev().for_each(|&byte| match byte { + b'-' => { + negative = true; } - _ => None, + b'0'..=b'9' => { + result += i128::from(byte - b'0') * base; + base *= 10; + } + // because of the PARSE_DECIMAL_RE, bytes just contains digit、'-' and '.'. 
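The scale handling the comments above describe (drop extra fractional digits, zero-pad short ones, then read the remaining digits as an i128) can be shown with a simplified string-based version. This is only a sketch of the arithmetic; it has none of the regex validation or precision checking of parse_decimal_with_parameter, and to_scaled_i128 is an illustrative name:

// Scale a decimal string to an i128 carrying exactly `scale` fractional digits.
fn to_scaled_i128(s: &str, scale: usize) -> Option<i128> {
    let negative = s.starts_with('-');
    let s = s.trim_start_matches('-');
    let (int_part, frac_part) = match s.split_once('.') {
        Some((i, f)) => (i, f),
        None => (s, ""),
    };
    // keep at most `scale` fractional digits, then right-pad with zeros
    let mut frac: String = frac_part.chars().take(scale).collect();
    while frac.len() < scale {
        frac.push('0');
    }
    let value: i128 = format!("{}{}", int_part, frac).parse().ok()?;
    Some(if negative { -value } else { value })
}

// to_scaled_i128("123.12345", 2) => Some(12312)   // extra digits dropped
// to_scaled_i128("123.12", 4)    => Some(1231200) // short fraction zero-padded
// to_scaled_i128("-.123", 3)     => Some(-123)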
+ _ => {} + }); + + if negative { + result = result.neg(); } + validate_decimal_precision(result, precision) + .map_err(|e| ArrowError::ParseError(format!("parse decimal overflow: {}", e))) + } else { + Err(ArrowError::ParseError(format!( + "can't parse the string value {} to decimal", + s + ))) } } -impl Parser for TimestampMicrosecondType { - fn parse(string: &str) -> Option { - match Self::DATA_TYPE { - DataType::Timestamp(TimeUnit::Microsecond, None) => { - let nanos = string_to_timestamp_nanos(string).ok(); - nanos.map(|x| x / 1000) +// Parse the string format decimal value to i128 format without checking the precision and scale. +// Like "125.12" to 12512_i128. +#[cfg(test)] +fn parse_decimal(s: &str) -> Result { + if PARSE_DECIMAL_RE.is_match(s) { + let mut offset = s.len(); + // each byte is digit、'-' or '.' + let bytes = s.as_bytes(); + let mut negative = false; + let mut result: i128 = 0; + let mut base = 1; + while offset > 0 { + match bytes[offset - 1] { + b'-' => { + negative = true; + } + b'.' => { + // do nothing + } + b'0'..=b'9' => { + result += i128::from(bytes[offset - 1] - b'0') * base; + base *= 10; + } + _ => { + return Err(ArrowError::ParseError(format!( + "can't match byte {}", + bytes[offset - 1] + ))); + } } - _ => None, + offset -= 1; + } + if negative { + Ok(result.neg()) + } else { + Ok(result) } - } -} - -fn parse_item(string: &str) -> Option { - T::parse(string) -} - -fn parse_bool(string: &str) -> Option { - if string.eq_ignore_ascii_case("false") { - Some(false) - } else if string.eq_ignore_ascii_case("true") { - Some(true) } else { - None + Err(ArrowError::ParseError(format!( + "can't parse the string value {} to decimal", + s + ))) } } @@ -733,6 +835,7 @@ fn build_primitive_array( line_number: usize, rows: &[StringRecord], col_idx: usize, + format: Option<&str>, ) -> Result { rows.iter() .enumerate() @@ -743,7 +846,10 @@ fn build_primitive_array( return Ok(None); } - let parsed = parse_item::(s); + let parsed = match format { + Some(format) => parse_formatted::(s, format), + _ => parse_item::(s), + }; match parsed { Some(e) => Ok(Some(e)), None => Err(ArrowError::ParseError(format!( @@ -811,9 +917,9 @@ pub struct ReaderBuilder { has_header: bool, /// An optional column delimiter. Defaults to `b','` delimiter: Option, - /// An optional escape charactor. Defaults None + /// An optional escape character. Defaults None escape: Option, - /// An optional quote charactor. Defaults b'\"' + /// An optional quote character. Defaults b'\"' quote: Option, /// An optional record terminator. 
Defaults CRLF terminator: Option, @@ -829,6 +935,10 @@ pub struct ReaderBuilder { bounds: Bounds, /// Optional projection for which columns to load (zero-based column indices) projection: Option>, + /// DateTime regex to be used while trying to infer the datetime format + datetime_re: Option, + /// DateTime format to be used while parsing datetime values + datetime_format: Option, } impl Default for ReaderBuilder { @@ -844,6 +954,8 @@ impl Default for ReaderBuilder { batch_size: 1024, bounds: None, projection: None, + datetime_re: None, + datetime_format: None, } } } @@ -888,6 +1000,23 @@ impl ReaderBuilder { self } + /// Set the datetime regex used to parse the string to Date64Type + /// this regex is used while inferring the schema + pub fn with_datetime_re(mut self, datetime_re: Regex) -> Self { + self.datetime_re = Some(datetime_re); + self + } + + /// Set the datetime format used to parse the string to Date64Type + /// this format is used when the schema expects a Date64Type column. + /// + /// For format refer to [chrono docs](https://docs.rs/chrono/0.4.19/chrono/format/strftime/index.html) + /// + pub fn with_datetime_format(mut self, datetime_format: String) -> Self { + self.datetime_format = Some(datetime_format); + self + } + /// Set the CSV file's column delimiter as a byte character pub fn with_delimiter(mut self, delimiter: u8) -> Self { self.delimiter = Some(delimiter); @@ -923,6 +1052,13 @@ impl ReaderBuilder { self } + /// Set the bounds over which to scan the reader. + /// `start` and `end` are line numbers. + pub fn with_bounds(mut self, start: usize, end: usize) -> Self { + self.bounds = Some((start, end)); + self + } + /// Set the reader's column projection pub fn with_projection(mut self, projection: Vec) -> Self { self.projection = Some(projection); @@ -936,15 +1072,17 @@ impl ReaderBuilder { let schema = match self.schema { Some(schema) => schema, None => { - let (inferred_schema, _) = infer_file_schema_with_csv_options( - &mut reader, - delimiter, - self.max_records, - self.has_header, - self.escape, - self.quote, - self.terminator, - )?; + let roptions = ReaderOptions { + delimiter: Some(delimiter), + max_read_records: self.max_records, + has_header: self.has_header, + escape: self.escape, + quote: self.quote, + terminator: self.terminator, + datetime_re: self.datetime_re, + }; + let (inferred_schema, _) = + infer_file_schema_with_csv_options(&mut reader, roptions)?; Arc::new(inferred_schema) } @@ -962,8 +1100,9 @@ impl ReaderBuilder { schema, self.has_header, self.batch_size, - None, + self.bounds, self.projection.clone(), + self.datetime_format, )) } } @@ -983,44 +1122,49 @@ mod tests { #[test] fn test_csv() { - let schema = Schema::new(vec![ - Field::new("city", DataType::Utf8, false), - Field::new("lat", DataType::Float64, false), - Field::new("lng", DataType::Float64, false), - ]); - - let file = File::open("test/data/uk_cities.csv").unwrap(); - - let mut csv = Reader::new( - file, - Arc::new(schema.clone()), - false, - None, - 1024, - None, - None, - ); - assert_eq!(Arc::new(schema), csv.schema()); - let batch = csv.next().unwrap().unwrap(); - assert_eq!(37, batch.num_rows()); - assert_eq!(3, batch.num_columns()); - - // access data from a primitive array - let lat = batch - .column(1) - .as_any() - .downcast_ref::() - .unwrap(); - assert!(57.653484 - lat.value(0) < f64::EPSILON); - - // access data from a string array (ListArray) - let city = batch - .column(0) - .as_any() - .downcast_ref::() - .unwrap(); - - assert_eq!("Aberdeen, Aberdeen City, UK",
city.value(13)); + let _: Vec<()> = vec![None, Some("%Y-%m-%dT%H:%M:%S%.f%:z".to_string())] + .into_iter() + .map(|format| { + let schema = Schema::new(vec![ + Field::new("city", DataType::Utf8, false), + Field::new("lat", DataType::Float64, false), + Field::new("lng", DataType::Float64, false), + ]); + + let file = File::open("test/data/uk_cities.csv").unwrap(); + let mut csv = Reader::new( + file, + Arc::new(schema.clone()), + false, + None, + 1024, + None, + None, + format, + ); + assert_eq!(Arc::new(schema), csv.schema()); + let batch = csv.next().unwrap().unwrap(); + assert_eq!(37, batch.num_rows()); + assert_eq!(3, batch.num_columns()); + + // access data from a primitive array + let lat = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(57.653484, lat.value(0)); + + // access data from a string array (ListArray) + let city = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13)); + }) + .collect(); } #[test] @@ -1046,6 +1190,7 @@ mod tests { 1024, None, None, + None, ); assert_eq!(Arc::new(schema), csv.schema()); let batch = csv.next().unwrap().unwrap(); @@ -1055,6 +1200,38 @@ mod tests { assert_eq!(&metadata, batch.schema().metadata()); } + #[test] + fn test_csv_reader_with_decimal() { + let schema = Schema::new(vec![ + Field::new("city", DataType::Utf8, false), + Field::new("lat", DataType::Decimal(38, 6), false), + Field::new("lng", DataType::Decimal(38, 6), false), + ]); + + let file = File::open("test/data/decimal_test.csv").unwrap(); + + let mut csv = + Reader::new(file, Arc::new(schema), false, None, 1024, None, None, None); + let batch = csv.next().unwrap().unwrap(); + // access data from a primitive array + let lat = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!("57.653484", lat.value_as_string(0)); + assert_eq!("53.002666", lat.value_as_string(1)); + assert_eq!("52.412811", lat.value_as_string(2)); + assert_eq!("51.481583", lat.value_as_string(3)); + assert_eq!("12.123456", lat.value_as_string(4)); + assert_eq!("50.760000", lat.value_as_string(5)); + assert_eq!("0.123000", lat.value_as_string(6)); + assert_eq!("123.000000", lat.value_as_string(7)); + assert_eq!("123.000000", lat.value_as_string(8)); + assert_eq!("-50.760000", lat.value_as_string(9)); + } + #[test] fn test_csv_from_buf_reader() { let schema = Schema::new(vec![ @@ -1077,6 +1254,7 @@ mod tests { 1024, None, None, + None, ); let batch = csv.next().unwrap().unwrap(); assert_eq!(74, batch.num_rows()); @@ -1106,7 +1284,7 @@ mod tests { .as_any() .downcast_ref::() .unwrap(); - assert!(57.653484 - lat.value(0) < f64::EPSILON); + assert_eq!(57.653484, lat.value(0)); // access data from a string array (ListArray) let city = batch @@ -1144,7 +1322,7 @@ mod tests { .as_any() .downcast_ref::() .unwrap(); - assert!(57.653484 - lat.value(0) < f64::EPSILON); + assert_eq!(57.653484, lat.value(0)); // access data from a string array (ListArray) let city = batch @@ -1156,6 +1334,30 @@ mod tests { assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13)); } + #[test] + fn test_csv_builder_with_bounds() { + let file = File::open("test/data/uk_cities.csv").unwrap(); + + // Set the bounds to the lines 0, 1 and 2. 
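Putting the new reader options together: a CSV whose timestamps use a space instead of the RFC 3339 'T' separator can be both inferred and parsed as Date64 by pairing with_datetime_re with with_datetime_format. A usage sketch ("events.csv" and its layout are hypothetical, and the regex and format are just examples):

use std::fs::File;
use arrow::csv::ReaderBuilder;
use regex::Regex;

fn read_custom_datetimes() -> arrow::error::Result<()> {
    let file = File::open("events.csv").unwrap();
    let mut csv = ReaderBuilder::new()
        // recognise "2020-11-08 14:20:01" as Date64 during schema inference ...
        .with_datetime_re(Regex::new(r"^\d{4}-\d\d-\d\d \d\d:\d\d:\d\d$").unwrap())
        // ... and tell the parser how to convert it once the schema says Date64
        .with_datetime_format("%Y-%m-%d %H:%M:%S".to_string())
        .build(file)?;
    let batch = csv.next().unwrap()?;
    println!("{} rows", batch.num_rows());
    Ok(())
}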
+ let mut csv = ReaderBuilder::new().with_bounds(0, 2).build(file).unwrap(); + let batch = csv.next().unwrap().unwrap(); + + // access data from a string array (ListArray) + let city = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + // The value on line 0 is within the bounds + assert_eq!("Elgin, Scotland, the UK", city.value(0)); + + // The value on line 13 is outside of the bounds. Therefore + // the call to .value() will panic. + let result = std::panic::catch_unwind(|| city.value(13)); + assert!(result.is_err()); + } + #[test] fn test_csv_with_projection() { let schema = Schema::new(vec![ @@ -1174,6 +1376,7 @@ mod tests { 1024, None, Some(vec![0, 1]), + None, ); let projected_schema = Arc::new(Schema::new(vec![ Field::new("city", DataType::Utf8, false), @@ -1208,6 +1411,7 @@ mod tests { 1024, None, Some(vec![0, 1]), + None, ); let projected_schema = Arc::new(Schema::new(vec![ Field::new( @@ -1241,7 +1445,8 @@ mod tests { let file = File::open("test/data/null_test.csv").unwrap(); - let mut csv = Reader::new(file, Arc::new(schema), true, None, 1024, None, None); + let mut csv = + Reader::new(file, Arc::new(schema), true, None, 1024, None, None, None); let batch = csv.next().unwrap().unwrap(); assert!(!batch.column(1).is_null(0)); @@ -1265,7 +1470,7 @@ mod tests { let mut csv = builder.build(file).unwrap(); let batch = csv.next().unwrap().unwrap(); - assert_eq!(5, batch.num_rows()); + assert_eq!(7, batch.num_rows()); assert_eq!(6, batch.num_columns()); let schema = batch.schema(); @@ -1338,14 +1543,31 @@ mod tests { #[test] fn test_infer_field_schema() { - assert_eq!(infer_field_schema("A"), DataType::Utf8); - assert_eq!(infer_field_schema("\"123\""), DataType::Utf8); - assert_eq!(infer_field_schema("10"), DataType::Int64); - assert_eq!(infer_field_schema("10.2"), DataType::Float64); - assert_eq!(infer_field_schema("true"), DataType::Boolean); - assert_eq!(infer_field_schema("false"), DataType::Boolean); - assert_eq!(infer_field_schema("2020-11-08"), DataType::Date32); - assert_eq!(infer_field_schema("2020-11-08T14:20:01"), DataType::Date64); + assert_eq!(infer_field_schema("A", None), DataType::Utf8); + assert_eq!(infer_field_schema("\"123\"", None), DataType::Utf8); + assert_eq!(infer_field_schema("10", None), DataType::Int64); + assert_eq!(infer_field_schema("10.2", None), DataType::Float64); + assert_eq!(infer_field_schema(".2", None), DataType::Float64); + assert_eq!(infer_field_schema("2.", None), DataType::Float64); + assert_eq!(infer_field_schema("true", None), DataType::Boolean); + assert_eq!(infer_field_schema("false", None), DataType::Boolean); + assert_eq!(infer_field_schema("2020-11-08", None), DataType::Date32); + assert_eq!( + infer_field_schema("2020-11-08T14:20:01", None), + DataType::Date64 + ); + // to be inferred as a date64 this needs a custom datetime_re + assert_eq!( + infer_field_schema("2020-11-08 14:20:01", None), + DataType::Utf8 + ); + let reg = Regex::new(r"^\d{4}-\d\d-\d\d \d\d:\d\d:\d\d$").ok(); + assert_eq!( + infer_field_schema("2020-11-08 14:20:01", reg), + DataType::Date64 + ); + assert_eq!(infer_field_schema("-5.13", None), DataType::Float64); + assert_eq!(infer_field_schema("0.1300", None), DataType::Float64); } #[test] @@ -1370,6 +1592,90 @@ mod tests { parse_item::("1900-02-28T12:34:56").unwrap(), -2203932304000 ); + assert_eq!( + parse_formatted::("1900-02-28 12:34:56", "%Y-%m-%d %H:%M:%S") + .unwrap(), + -2203932304000 + ); + assert_eq!( + parse_formatted::( + "1900-02-28 12:34:56+0030", + "%Y-%m-%d %H:%M:%S%z" + ) + .unwrap(), 
+ -2203932304000 - (30 * 60 * 1000) + ); + } + + #[test] + fn test_parse_decimal() { + let tests = [ + ("123.00", 12300i128), + ("123.123", 123123i128), + ("0.0123", 123i128), + ("0.12300", 12300i128), + ("-5.123", -5123i128), + ("-45.432432", -45432432i128), + ]; + for (s, i) in tests { + let result = parse_decimal(s); + assert_eq!(i, result.unwrap()); + } + } + + #[test] + fn test_parse_decimal_with_parameter() { + let tests = [ + ("123.123", 123123i128), + ("123.1234", 123123i128), + ("123.1", 123100i128), + ("123", 123000i128), + ("-123.123", -123123i128), + ("-123.1234", -123123i128), + ("-123.1", -123100i128), + ("-123", -123000i128), + ("0.0000123", 0i128), + ("12.", 12000i128), + ("-12.", -12000i128), + ("00.1", 100i128), + ("-00.1", -100i128), + ("12345678912345678.1234", 12345678912345678123i128), + ("-12345678912345678.1234", -12345678912345678123i128), + ("99999999999999999.999", 99999999999999999999i128), + ("-99999999999999999.999", -99999999999999999999i128), + (".123", 123i128), + ("-.123", -123i128), + ("123.", 123000i128), + ("-123.", -123000i128), + ]; + for (s, i) in tests { + let result = parse_decimal_with_parameter(s, 20, 3); + assert_eq!(i, result.unwrap()) + } + let can_not_parse_tests = ["123,123", ".", "123.123.123"]; + for s in can_not_parse_tests { + let result = parse_decimal_with_parameter(s, 20, 3); + assert_eq!( + format!( + "Parser error: can't parse the string value {} to decimal", + s + ), + result.unwrap_err().to_string() + ); + } + let overflow_parse_tests = ["12345678", "12345678.9", "99999999.99"]; + for s in overflow_parse_tests { + let result = parse_decimal_with_parameter(s, 10, 3); + let expected = "Parser error: parse decimal overflow"; + let actual = result.unwrap_err().to_string(); + + assert!( + actual.contains(&expected), + "actual: '{}', expected: '{}'", + actual, + expected + ); + } } /// Interprets a naive_datetime (with no explicit timezone offset) @@ -1473,6 +1779,7 @@ mod tests { writeln!(csv1, "c1,c2,c3")?; writeln!(csv1, "1,\"foo\",0.5")?; writeln!(csv1, "3,\"bar\",1")?; + writeln!(csv1, "3,\"bar\",2e-06")?; // reading csv2 will set c2 to optional writeln!(csv2, "c1,c2,c3,c4")?; writeln!(csv2, "10,,3.14,true")?; @@ -1488,7 +1795,7 @@ mod tests { csv4.path().to_str().unwrap().to_string(), ], b',', - Some(3), // only csv1 and csv2 should be read + Some(4), // only csv1 and csv2 should be read true, )?; @@ -1537,6 +1844,7 @@ mod tests { // starting at row 2 and up to row 6. Some((2, 6)), Some(vec![0]), + None, ); let batch = csv.next().unwrap().unwrap(); @@ -1578,6 +1886,8 @@ mod tests { assert_eq!(Some(-12.34), parse_item::("-12.34")); assert_eq!(Some(12.0), parse_item::("12")); assert_eq!(Some(0.0), parse_item::("0")); + assert_eq!(Some(2.0), parse_item::("2.")); + assert_eq!(Some(0.2), parse_item::(".2")); assert!(parse_item::("nan").unwrap().is_nan()); assert!(parse_item::("NaN").unwrap().is_nan()); assert!(parse_item::("inf").unwrap().is_infinite()); diff --git a/arrow/src/csv/writer.rs b/arrow/src/csv/writer.rs index c6e49f0d6a3d..6735d9668560 100644 --- a/arrow/src/csv/writer.rs +++ b/arrow/src/csv/writer.rs @@ -35,7 +35,7 @@ //! Field::new("c1", DataType::Utf8, false), //! Field::new("c2", DataType::Float64, true), //! Field::new("c3", DataType::UInt32, false), -//! Field::new("c3", DataType::Boolean, true), +//! Field::new("c4", DataType::Boolean, true), //! ]); //! let c1 = StringArray::from(vec![ //! 
"Lorem ipsum dolor sit amet", @@ -67,6 +67,11 @@ use std::io::Write; +#[cfg(feature = "chrono-tz")] +use crate::compute::kernels::temporal::using_chrono_tz_and_utc_naive_date_time; +#[cfg(feature = "chrono-tz")] +use chrono::{DateTime, Utc}; + use crate::datatypes::*; use crate::error::{ArrowError, Result}; use crate::record_batch::RecordBatch; @@ -75,6 +80,7 @@ use crate::{array::*, util::serialization::lexical_to_string}; const DEFAULT_DATE_FORMAT: &str = "%F"; const DEFAULT_TIME_FORMAT: &str = "%T"; const DEFAULT_TIMESTAMP_FORMAT: &str = "%FT%H:%M:%S.%9f"; +const DEFAULT_TIMESTAMP_TZ_FORMAT: &str = "%FT%H:%M:%S.%9f%:z"; fn write_primitive_value(array: &ArrayRef, i: usize) -> String where @@ -90,8 +96,6 @@ where pub struct Writer { /// The object to write to writer: csv_crate::Writer, - /// Column delimiter. Defaults to `b','` - delimiter: u8, /// Whether file should be written with headers. Defaults to `true` has_headers: bool, /// The date format for date arrays @@ -100,6 +104,9 @@ pub struct Writer { datetime_format: String, /// The timestamp format for timestamp arrays timestamp_format: String, + /// The timestamp format for timestamp (with timezone) arrays + #[allow(dead_code)] + timestamp_tz_format: String, /// The time format for time arrays time_format: String, /// Is the beginning-of-writer @@ -114,12 +121,12 @@ impl Writer { let writer = builder.delimiter(delimiter).from_writer(writer); Writer { writer, - delimiter, has_headers: true, date_format: DEFAULT_DATE_FORMAT.to_string(), datetime_format: DEFAULT_TIMESTAMP_FORMAT.to_string(), time_format: DEFAULT_TIME_FORMAT.to_string(), timestamp_format: DEFAULT_TIMESTAMP_FORMAT.to_string(), + timestamp_tz_format: DEFAULT_TIMESTAMP_TZ_FORMAT.to_string(), beginning: true, } } @@ -213,35 +220,8 @@ impl Writer { .format(&self.time_format) .to_string() } - DataType::Timestamp(time_unit, _) => { - use TimeUnit::*; - let datetime = match time_unit { - Second => col - .as_any() - .downcast_ref::() - .unwrap() - .value_as_datetime(row_index) - .unwrap(), - Millisecond => col - .as_any() - .downcast_ref::() - .unwrap() - .value_as_datetime(row_index) - .unwrap(), - Microsecond => col - .as_any() - .downcast_ref::() - .unwrap() - .value_as_datetime(row_index) - .unwrap(), - Nanosecond => col - .as_any() - .downcast_ref::() - .unwrap() - .value_as_datetime(row_index) - .unwrap(), - }; - format!("{}", datetime.format(&self.timestamp_format)) + DataType::Timestamp(time_unit, time_zone) => { + self.handle_timestamp(time_unit, time_zone.as_ref(), row_index, col)? } DataType::Decimal(..) 
=> make_string_from_decimal(col, row_index)?, t => { @@ -258,6 +238,102 @@ impl Writer { Ok(()) } + #[cfg(not(feature = "chrono-tz"))] + fn handle_timestamp( + &self, + time_unit: &TimeUnit, + _time_zone: Option<&String>, + row_index: usize, + col: &ArrayRef, + ) -> Result { + use TimeUnit::*; + let datetime = match time_unit { + Second => col + .as_any() + .downcast_ref::() + .unwrap() + .value_as_datetime(row_index) + .unwrap(), + Millisecond => col + .as_any() + .downcast_ref::() + .unwrap() + .value_as_datetime(row_index) + .unwrap(), + Microsecond => col + .as_any() + .downcast_ref::() + .unwrap() + .value_as_datetime(row_index) + .unwrap(), + Nanosecond => col + .as_any() + .downcast_ref::() + .unwrap() + .value_as_datetime(row_index) + .unwrap(), + }; + Ok(format!("{}", datetime.format(&self.timestamp_format))) + } + + #[cfg(feature = "chrono-tz")] + fn handle_timestamp( + &self, + time_unit: &TimeUnit, + time_zone: Option<&String>, + row_index: usize, + col: &ArrayRef, + ) -> Result { + use TimeUnit::*; + + let datetime = match time_unit { + Second => col + .as_any() + .downcast_ref::() + .unwrap() + .value_as_datetime(row_index) + .unwrap(), + Millisecond => col + .as_any() + .downcast_ref::() + .unwrap() + .value_as_datetime(row_index) + .unwrap(), + Microsecond => col + .as_any() + .downcast_ref::() + .unwrap() + .value_as_datetime(row_index) + .unwrap(), + Nanosecond => col + .as_any() + .downcast_ref::() + .unwrap() + .value_as_datetime(row_index) + .unwrap(), + }; + let tzs = match time_zone { + None => "UTC".to_string(), + Some(tzs) => tzs.to_string(), + }; + + match using_chrono_tz_and_utc_naive_date_time(&tzs, datetime) { + Some(tz) => { + let utc_time = DateTime::::from_utc(datetime, Utc); + Ok(format!( + "{}", + utc_time + .with_timezone(&tz) + .format(&self.timestamp_tz_format) + )) + } + err => Err(ArrowError::ComputeError(format!( + "{}: {:?}", + "Unable to parse timezone", err + ))), + } + } + /// Write a vector of record batches to a writable object pub fn write(&mut self, batch: &RecordBatch) -> Result<()> { let num_columns = batch.num_columns(); @@ -311,6 +387,8 @@ pub struct WriterBuilder { datetime_format: Option, /// Optional timestamp format for timestamp arrays timestamp_format: Option, + /// Optional timestamp format for timestamp with timezone arrays + timestamp_tz_format: Option, /// Optional time format for time arrays time_format: Option, } @@ -324,6 +402,7 @@ impl Default for WriterBuilder { datetime_format: Some(DEFAULT_TIMESTAMP_FORMAT.to_string()), time_format: Some(DEFAULT_TIME_FORMAT.to_string()), timestamp_format: Some(DEFAULT_TIMESTAMP_FORMAT.to_string()), + timestamp_tz_format: Some(DEFAULT_TIMESTAMP_TZ_FORMAT.to_string()), } } } @@ -373,6 +452,12 @@ impl WriterBuilder { self } + /// Set the CSV file's datetime format + pub fn with_datetime_format(mut self, format: String) -> Self { + self.datetime_format = Some(format); + self + } + /// Set the CSV file's time format pub fn with_time_format(mut self, format: String) -> Self { self.time_format = Some(format); @@ -392,7 +477,6 @@ impl WriterBuilder { let writer = builder.delimiter(delimiter).from_writer(writer); Writer { writer, - delimiter, has_headers: self.has_headers, date_format: self .date_format @@ -406,6 +490,9 @@ impl WriterBuilder { timestamp_format: self .timestamp_format .unwrap_or_else(|| DEFAULT_TIMESTAMP_FORMAT.to_string()), + timestamp_tz_format: self + .timestamp_tz_format + .unwrap_or_else(|| DEFAULT_TIMESTAMP_TZ_FORMAT.to_string()), beginning: true, } } @@ -417,6 +504,7 @@ mod 
tests { use crate::csv::Reader; use crate::datatypes::{Field, Schema}; + #[cfg(feature = "chrono-tz")] use crate::util::string_writer::StringWriter; use crate::util::test_util::get_temp_file; use std::fs::File; @@ -485,7 +573,16 @@ mod tests { let mut buffer: Vec = vec![]; file.read_to_end(&mut buffer).unwrap(); - assert_eq!( + let expected = if cfg!(feature = "chrono-tz") { + r#"c1,c2,c3,c4,c5,c6,c7 +Lorem ipsum dolor sit amet,123.564532,3,true,,00:20:34,cupcakes +consectetur adipiscing elit,,2,false,2019-04-18T10:54:47.378000000+00:00,06:51:20,cupcakes +sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000+00:00,23:46:03,foo +Lorem ipsum dolor sit amet,123.564532,3,true,,00:20:34,cupcakes +consectetur adipiscing elit,,2,false,2019-04-18T10:54:47.378000000+00:00,06:51:20,cupcakes +sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000+00:00,23:46:03,foo +"# + } else { r#"c1,c2,c3,c4,c5,c6,c7 Lorem ipsum dolor sit amet,123.564532,3,true,,00:20:34,cupcakes consectetur adipiscing elit,,2,false,2019-04-18T10:54:47.378000000,06:51:20,cupcakes @@ -494,9 +591,8 @@ Lorem ipsum dolor sit amet,123.564532,3,true,,00:20:34,cupcakes consectetur adipiscing elit,,2,false,2019-04-18T10:54:47.378000000,06:51:20,cupcakes sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03,foo "# - .to_string(), - String::from_utf8(buffer).unwrap() - ); + }; + assert_eq!(expected.to_string(), String::from_utf8(buffer).unwrap()); } #[test] @@ -559,73 +655,53 @@ sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03,foo ); } + #[cfg(feature = "chrono-tz")] #[test] - fn test_export_csv_string() { + fn test_export_csv_timestamps() { let schema = Schema::new(vec![ - Field::new("c1", DataType::Utf8, false), - Field::new("c2", DataType::Float64, true), - Field::new("c3", DataType::UInt32, false), - Field::new("c4", DataType::Boolean, true), - Field::new("c5", DataType::Timestamp(TimeUnit::Millisecond, None), true), - Field::new("c6", DataType::Time32(TimeUnit::Second), false), - Field::new("c7", DataType::Decimal(6, 2), false), + Field::new( + "c1", + DataType::Timestamp( + TimeUnit::Millisecond, + Some("Australia/Sydney".to_string()), + ), + true, + ), + Field::new("c2", DataType::Timestamp(TimeUnit::Millisecond, None), true), ]); - let c1 = StringArray::from(vec![ - "Lorem ipsum dolor sit amet", - "consectetur adipiscing elit", - "sed do eiusmod tempor", - ]); - let c2 = PrimitiveArray::::from(vec![ - Some(123.564532), - None, - Some(-556132.25), - ]); - let c3 = PrimitiveArray::::from(vec![3, 2, 1]); - let c4 = BooleanArray::from(vec![Some(true), Some(false), None]); - let c5 = TimestampMillisecondArray::from_opt_vec( - vec![None, Some(1555584887378), Some(1555555555555)], + let c1 = TimestampMillisecondArray::from_opt_vec( + // 1555584887 converts to 2019-04-18, 20:54:47 in time zone Australia/Sydney (AEST). + // The offset (difference to UTC) is +10:00. + // 1635577147 converts to 2021-10-30 17:59:07 in time zone Australia/Sydney (AEDT) + // The offset (difference to UTC) is +11:00. Note that daylight savings is in effect on 2021-10-30. 
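The conversion this test pins down: with the chrono-tz feature enabled, values stored as UTC milliseconds are rendered in the column's time zone rather than as naive timestamps. A minimal sketch of that conversion using chrono and chrono-tz directly (not the writer's handle_timestamp helper, and printed here with millisecond rather than nanosecond precision):

use chrono::{DateTime, NaiveDateTime, Utc};
use chrono_tz::Tz;

fn format_millis_in_zone(millis: i64, tz_name: &str) -> String {
    // interpret the stored value as UTC ...
    let naive =
        NaiveDateTime::from_timestamp(millis / 1000, ((millis % 1000) * 1_000_000) as u32);
    let utc: DateTime<Utc> = DateTime::from_utc(naive, Utc);
    // ... then render it in the column's zone with an explicit offset
    let tz: Tz = tz_name.parse().unwrap();
    utc.with_timezone(&tz).format("%FT%H:%M:%S%.3f%:z").to_string()
}

// format_millis_in_zone(1555584887378, "Australia/Sydney")
//   => "2019-04-18T20:54:47.378+10:00"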
+ // + vec![Some(1555584887378), Some(1635577147000)], + Some("Australia/Sydney".to_string()), + ); + let c2 = TimestampMillisecondArray::from_opt_vec( + vec![Some(1555584887378), Some(1635577147000)], None, ); - let c6 = Time32SecondArray::from(vec![1234, 24680, 85563]); - let mut c7_builder = DecimalBuilder::new(5, 6, 2); - c7_builder.append_value(12345_i128).unwrap(); - c7_builder.append_value(-12345_i128).unwrap(); - c7_builder.append_null().unwrap(); - let c7 = c7_builder.finish(); - - let batch = RecordBatch::try_new( - Arc::new(schema), - vec![ - Arc::new(c1), - Arc::new(c2), - Arc::new(c3), - Arc::new(c4), - Arc::new(c5), - Arc::new(c6), - Arc::new(c7), - ], - ) - .unwrap(); + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(c1), Arc::new(c2)]) + .unwrap(); let sw = StringWriter::new(); let mut writer = Writer::new(sw); - let batches = vec![&batch, &batch]; + let batches = vec![&batch]; for batch in batches { writer.write(batch).unwrap(); } - let left = "c1,c2,c3,c4,c5,c6,c7 -Lorem ipsum dolor sit amet,123.564532,3,true,,00:20:34,123.45 -consectetur adipiscing elit,,2,false,2019-04-18T10:54:47.378000000,06:51:20,-123.45 -sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03, -Lorem ipsum dolor sit amet,123.564532,3,true,,00:20:34,123.45 -consectetur adipiscing elit,,2,false,2019-04-18T10:54:47.378000000,06:51:20,-123.45 -sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03,\n"; + let left = "c1,c2 +2019-04-18T20:54:47.378000000+10:00,2019-04-18T10:54:47.378000000+00:00 +2021-10-30T17:59:07.000000000+11:00,2021-10-30T06:59:07.000000000+00:00\n"; let right = writer.writer.into_inner().map(|s| s.to_string()); assert_eq!(Some(left.to_string()), right.ok()); } + #[cfg(not(feature = "chrono-tz"))] #[test] fn test_conversion_consistency() { // test if we can serialize and deserialize whilst retaining the same type information/ precision @@ -663,6 +739,7 @@ sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03,\n"; // starting at row 2 and up to row 6. None, None, + None, ); let rb = reader.next().unwrap().unwrap(); let c1 = rb.column(0).as_any().downcast_ref::().unwrap(); @@ -675,4 +752,71 @@ sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03,\n"; let expected = vec![Some(3), Some(2), Some(1)]; assert_eq!(actual, expected); } + + #[cfg(feature = "chrono-tz")] + #[test] + fn test_conversion_consistency() { + // test if we can serialize and deserialize whilst retaining the same type information/ precision + + let schema = Schema::new(vec![ + Field::new("c1", DataType::Date32, false), + Field::new("c2", DataType::Date64, false), + Field::new("c3", DataType::Timestamp(TimeUnit::Nanosecond, None), false), + ]); + + let nanoseconds = vec![ + 1599566300000000000, + 1599566200000000000, + 1599566100000000000, + ]; + let c1 = Date32Array::from(vec![3, 2, 1]); + let c2 = Date64Array::from(vec![3, 2, 1]); + let c3 = TimestampNanosecondArray::from_vec(nanoseconds.clone(), None); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(c1), Arc::new(c2), Arc::new(c3)], + ) + .unwrap(); + + let builder = WriterBuilder::new().has_headers(false); + + let mut buf: Cursor> = Default::default(); + // drop the writer early to release the borrow. + { + let mut writer = builder.build(&mut buf); + writer.write(&batch).unwrap(); + } + buf.set_position(0); + + let mut reader = Reader::new( + buf, + Arc::new(schema), + false, + None, + 3, + // starting at row 2 and up to row 6. 
+ None, + None, + None, + ); + let rb = reader.next().unwrap().unwrap(); + let c1 = rb.column(0).as_any().downcast_ref::().unwrap(); + let c2 = rb.column(1).as_any().downcast_ref::().unwrap(); + let c3 = rb + .column(2) + .as_any() + .downcast_ref::() + .unwrap(); + + let actual = c1.into_iter().collect::>(); + let expected = vec![Some(3), Some(2), Some(1)]; + assert_eq!(actual, expected); + let actual = c2.into_iter().collect::>(); + let expected = vec![Some(3), Some(2), Some(1)]; + assert_eq!(actual, expected); + let actual = c3.into_iter().collect::>(); + let expected = nanoseconds.into_iter().map(|x| Some(x)).collect::>(); + assert_eq!(actual, expected); + } } diff --git a/arrow/src/datatypes/datatype.rs b/arrow/src/datatypes/datatype.rs index 1cbec341cf37..895e5cc67c38 100644 --- a/arrow/src/datatypes/datatype.rs +++ b/arrow/src/datatypes/datatype.rs @@ -114,8 +114,12 @@ pub enum DataType { LargeList(Box), /// A nested datatype that contains a number of sub-fields. Struct(Vec), - /// A nested datatype that can represent slots of differing types. - Union(Vec), + /// A nested datatype that can represent slots of differing types. Components: + /// + /// 1. [`Field`] for each possible child type the Union can hold + /// 2. The corresponding `type_id` used to identify which Field + /// 3. The type of union (Sparse or Dense) + Union(Vec, Vec, UnionMode), /// A dictionary encoded array (`key_type`, `value_type`), where /// each array element is an index of `key_type` into an /// associated dictionary of `value_type`. @@ -127,7 +131,12 @@ pub enum DataType { /// This type mostly used to represent low cardinality string /// arrays or a limited set of primitive types as integers. Dictionary(Box, Box), - /// Decimal value with precision and scale + /// Exact decimal value with precision and scale + /// + /// * precision is the total number of digits + /// * scale is the number of digits past the decimal + /// + /// For example the number 123.45 has precision 5 and scale 2. Decimal(usize, usize), /// A Map is a logical nested type that is represented as /// @@ -158,7 +167,7 @@ pub enum TimeUnit { Nanosecond, } -/// YEAR_MONTH or DAY_TIME interval in SQL style. +/// YEAR_MONTH, DAY_TIME, MONTH_DAY_NANO interval in SQL style. #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub enum IntervalUnit { /// Indicates the number of elapsed whole months, stored as 4-byte integers. @@ -166,6 +175,21 @@ pub enum IntervalUnit { /// Indicates the number of elapsed days and milliseconds, /// stored as 2 contiguous 32-bit integers (days, milliseconds) (8-bytes in total). DayTime, + /// A triple of the number of elapsed months, days, and nanoseconds. + /// The values are stored contiguously in 16 byte blocks. Months and + /// days are encoded as 32 bit integers and nanoseconds is encoded as a + /// 64 bit integer. All integers are signed. Each field is independent + /// (e.g. there is no constraint that nanoseconds have the same sign + /// as days or that the quantity of nanoseconds represents less + /// than a day's worth of time). 
+ MonthDayNano, +} + +// Sparse or Dense union layouts +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub enum UnionMode { + Sparse, + Dense, } impl fmt::Display for DataType { @@ -174,6 +198,123 @@ impl fmt::Display for DataType { } } +/// `MAX_DECIMAL_FOR_EACH_PRECISION[p]` holds the maximum `i128` value +/// that can be stored in [DataType::Decimal] value of precision `p` +pub const MAX_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [ + 9, + 99, + 999, + 9999, + 99999, + 999999, + 9999999, + 99999999, + 999999999, + 9999999999, + 99999999999, + 999999999999, + 9999999999999, + 99999999999999, + 999999999999999, + 9999999999999999, + 99999999999999999, + 999999999999999999, + 9999999999999999999, + 99999999999999999999, + 999999999999999999999, + 9999999999999999999999, + 99999999999999999999999, + 999999999999999999999999, + 9999999999999999999999999, + 99999999999999999999999999, + 999999999999999999999999999, + 9999999999999999999999999999, + 99999999999999999999999999999, + 999999999999999999999999999999, + 9999999999999999999999999999999, + 99999999999999999999999999999999, + 999999999999999999999999999999999, + 9999999999999999999999999999999999, + 99999999999999999999999999999999999, + 999999999999999999999999999999999999, + 9999999999999999999999999999999999999, + 170141183460469231731687303715884105727, +]; + +/// `MIN_DECIMAL_FOR_EACH_PRECISION[p]` holds the minimum `i128` value +/// that can be stored in a [DataType::Decimal] value of precision `p` +pub const MIN_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [ + -9, + -99, + -999, + -9999, + -99999, + -999999, + -9999999, + -99999999, + -999999999, + -9999999999, + -99999999999, + -999999999999, + -9999999999999, + -99999999999999, + -999999999999999, + -9999999999999999, + -99999999999999999, + -999999999999999999, + -9999999999999999999, + -99999999999999999999, + -999999999999999999999, + -9999999999999999999999, + -99999999999999999999999, + -999999999999999999999999, + -9999999999999999999999999, + -99999999999999999999999999, + -999999999999999999999999999, + -9999999999999999999999999999, + -99999999999999999999999999999, + -999999999999999999999999999999, + -9999999999999999999999999999999, + -99999999999999999999999999999999, + -999999999999999999999999999999999, + -9999999999999999999999999999999999, + -99999999999999999999999999999999999, + -999999999999999999999999999999999999, + -9999999999999999999999999999999999999, + -170141183460469231731687303715884105728, +]; + +/// The maximum precision for [DataType::Decimal] values +pub const DECIMAL_MAX_PRECISION: usize = 38; + +/// The maximum scale for [DataType::Decimal] values +pub const DECIMAL_MAX_SCALE: usize = 38; + +/// The default scale for [DataType::Decimal] values +pub const DECIMAL_DEFAULT_SCALE: usize = 10; + +/// Validates that the specified `i128` value can be properly +/// interpreted as a Decimal number with precision `precision` +#[inline] +pub(crate) fn validate_decimal_precision(value: i128, precision: usize) -> Result { + let max = MAX_DECIMAL_FOR_EACH_PRECISION[precision - 1]; + let min = MIN_DECIMAL_FOR_EACH_PRECISION[precision - 1]; + + if value > max { + Err(ArrowError::InvalidArgumentError(format!( + "{} is too large to store in a Decimal of precision {}. Max is {}", + value, precision, max + ))) + } else if value < min { + Err(ArrowError::InvalidArgumentError(format!( + "{} is too small to store in a Decimal of precision {}. 
Min is {}", + value, precision, min + ))) + } else { + Ok(value) + } +} + impl DataType { /// Parse a data type from a JSON representation. pub(crate) fn from(json: &Value) -> Result { @@ -287,6 +428,9 @@ impl DataType { Some(p) if p == "YEAR_MONTH" => { Ok(DataType::Interval(IntervalUnit::YearMonth)) } + Some(p) if p == "MONTH_DAY_NANO" => { + Ok(DataType::Interval(IntervalUnit::MonthDayNano)) + } _ => Err(ArrowError::ParseError( "interval unit missing or invalid".to_string(), )), @@ -359,6 +503,43 @@ impl DataType { )) } } + Some(s) if s == "union" => { + if let Some(Value::String(mode)) = map.get("mode") { + let union_mode = if mode == "SPARSE" { + UnionMode::Sparse + } else if mode == "DENSE" { + UnionMode::Dense + } else { + return Err(ArrowError::ParseError(format!( + "Unknown union mode {:?} for union", + mode + ))); + }; + if let Some(type_ids) = map.get("typeIds") { + let type_ids = type_ids + .as_array() + .unwrap() + .iter() + .map(|t| t.as_i64().unwrap() as i8) + .collect::>(); + + let default_fields = type_ids + .iter() + .map(|_| default_field.clone()) + .collect::>(); + + Ok(DataType::Union(default_fields, type_ids, union_mode)) + } else { + Err(ArrowError::ParseError( + "Expecting a typeIds for union ".to_string(), + )) + } + } else { + Err(ArrowError::ParseError( + "Expecting a mode for union".to_string(), + )) + } + } Some(other) => Err(ArrowError::ParseError(format!( "invalid or unsupported type name: {} in {:?}", other, json @@ -395,7 +576,7 @@ impl DataType { json!({"name": "fixedsizebinary", "byteWidth": byte_width}) } DataType::Struct(_) => json!({"name": "struct"}), - DataType::Union(_) => json!({"name": "union"}), + DataType::Union(_, _, _) => json!({"name": "union"}), DataType::List(_) => json!({ "name": "list"}), DataType::LargeList(_) => json!({ "name": "largelist"}), DataType::FixedSizeList(_, length) => { @@ -442,6 +623,7 @@ impl DataType { DataType::Interval(unit) => json!({"name": "interval", "unit": match unit { IntervalUnit::YearMonth => "YEAR_MONTH", IntervalUnit::DayTime => "DAY_TIME", + IntervalUnit::MonthDayNano => "MONTH_DAY_NANO", }}), DataType::Duration(unit) => json!({"name": "duration", "unit": match unit { TimeUnit::Second => "SECOND", @@ -477,9 +659,19 @@ impl DataType { ) } + /// Returns true if this type is valid as a dictionary key + /// (e.g. [`super::ArrowDictionaryKeyType`] + pub fn is_dictionary_key_type(t: &DataType) -> bool { + use DataType::*; + matches!( + t, + UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 + ) + } + /// Compares the datatype with another, ignoring nested field names /// and metadata. 
- pub(crate) fn equals_datatype(&self, other: &DataType) -> bool { + pub fn equals_datatype(&self, other: &DataType) -> bool { match (&self, other) { (DataType::List(a), DataType::List(b)) | (DataType::LargeList(a), DataType::LargeList(b)) => { diff --git a/arrow/src/datatypes/ffi.rs b/arrow/src/datatypes/ffi.rs index 7e98508cf090..2f1b092a862b 100644 --- a/arrow/src/datatypes/ffi.rs +++ b/arrow/src/datatypes/ffi.rs @@ -26,9 +26,9 @@ use crate::{ impl TryFrom<&FFI_ArrowSchema> for DataType { type Error = ArrowError; - /// See https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings + /// See [CDataInterface docs](https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings) fn try_from(c_schema: &FFI_ArrowSchema) -> Result { - let dtype = match c_schema.format() { + let mut dtype = match c_schema.format() { "n" => DataType::Null, "b" => DataType::Boolean, "c" => DataType::Int8, @@ -52,6 +52,10 @@ impl TryFrom<&FFI_ArrowSchema> for DataType { "ttm" => DataType::Time32(TimeUnit::Millisecond), "ttu" => DataType::Time64(TimeUnit::Microsecond), "ttn" => DataType::Time64(TimeUnit::Nanosecond), + "tDs" => DataType::Duration(TimeUnit::Second), + "tDm" => DataType::Duration(TimeUnit::Millisecond), + "tDu" => DataType::Duration(TimeUnit::Microsecond), + "tDn" => DataType::Duration(TimeUnit::Nanosecond), "+l" => { let c_child = c_schema.child(0); DataType::List(Box::new(Field::try_from(c_child)?)) @@ -67,6 +71,23 @@ impl TryFrom<&FFI_ArrowSchema> for DataType { // Parametrized types, requiring string parse other => { match other.splitn(2, ':').collect::>().as_slice() { + // FixedSizeBinary type in format "w:num_bytes" + ["w", num_bytes] => { + let parsed_num_bytes = num_bytes.parse::().map_err(|_| { + ArrowError::CDataInterface( + "FixedSizeBinary requires an integer parameter representing number of bytes per element".to_string()) + })?; + DataType::FixedSizeBinary(parsed_num_bytes) + }, + // FixedSizeList type in format "+w:num_elems" + ["+w", num_elems] => { + let c_child = c_schema.child(0); + let parsed_num_elems = num_elems.parse::().map_err(|_| { + ArrowError::CDataInterface( + "The FixedSizeList type requires an integer parameter representing number of elements per list".to_string()) + })?; + DataType::FixedSizeList(Box::new(Field::try_from(c_child)?), parsed_num_elems) + }, // Decimal types in format "d:precision,scale" or "d:precision,scale,bitWidth" ["d", extra] => { match extra.splitn(3, ',').collect::>().as_slice() { @@ -134,6 +155,12 @@ impl TryFrom<&FFI_ArrowSchema> for DataType { } } }; + + if let Some(dict_schema) = c_schema.dictionary() { + let value_type = Self::try_from(dict_schema)?; + dtype = DataType::Dictionary(Box::new(dtype), Box::new(value_type)); + } + Ok(dtype) } } @@ -167,54 +194,14 @@ impl TryFrom<&FFI_ArrowSchema> for Schema { impl TryFrom<&DataType> for FFI_ArrowSchema { type Error = ArrowError; - /// See https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings + /// See [CDataInterface docs](https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings) fn try_from(dtype: &DataType) -> Result { - let format = match dtype { - DataType::Null => "n".to_string(), - DataType::Boolean => "b".to_string(), - DataType::Int8 => "c".to_string(), - DataType::UInt8 => "C".to_string(), - DataType::Int16 => "s".to_string(), - DataType::UInt16 => "S".to_string(), - DataType::Int32 => "i".to_string(), - DataType::UInt32 => "I".to_string(), - 
DataType::Int64 => "l".to_string(), - DataType::UInt64 => "L".to_string(), - DataType::Float16 => "e".to_string(), - DataType::Float32 => "f".to_string(), - DataType::Float64 => "g".to_string(), - DataType::Binary => "z".to_string(), - DataType::LargeBinary => "Z".to_string(), - DataType::Utf8 => "u".to_string(), - DataType::LargeUtf8 => "U".to_string(), - DataType::Decimal(precision, scale) => format!("d:{},{}", precision, scale), - DataType::Date32 => "tdD".to_string(), - DataType::Date64 => "tdm".to_string(), - DataType::Time32(TimeUnit::Second) => "tts".to_string(), - DataType::Time32(TimeUnit::Millisecond) => "ttm".to_string(), - DataType::Time64(TimeUnit::Microsecond) => "ttu".to_string(), - DataType::Time64(TimeUnit::Nanosecond) => "ttn".to_string(), - DataType::Timestamp(TimeUnit::Second, None) => "tss:".to_string(), - DataType::Timestamp(TimeUnit::Millisecond, None) => "tsm:".to_string(), - DataType::Timestamp(TimeUnit::Microsecond, None) => "tsu:".to_string(), - DataType::Timestamp(TimeUnit::Nanosecond, None) => "tsn:".to_string(), - DataType::Timestamp(TimeUnit::Second, Some(tz)) => format!("tss:{}", tz), - DataType::Timestamp(TimeUnit::Millisecond, Some(tz)) => format!("tsm:{}", tz), - DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => format!("tsu:{}", tz), - DataType::Timestamp(TimeUnit::Nanosecond, Some(tz)) => format!("tsn:{}", tz), - DataType::List(_) => "+l".to_string(), - DataType::LargeList(_) => "+L".to_string(), - DataType::Struct(_) => "+s".to_string(), - other => { - return Err(ArrowError::CDataInterface(format!( - "The datatype \"{:?}\" is still not supported in Rust implementation", - other - ))) - } - }; + let format = get_format_string(dtype)?; // allocate and hold the children let children = match dtype { - DataType::List(child) | DataType::LargeList(child) => { + DataType::List(child) + | DataType::LargeList(child) + | DataType::FixedSizeList(child, _) => { vec![FFI_ArrowSchema::try_from(child.as_ref())?] } DataType::Struct(fields) => fields @@ -223,7 +210,63 @@ impl TryFrom<&DataType> for FFI_ArrowSchema { .collect::>>()?, _ => vec![], }; - FFI_ArrowSchema::try_new(&format, children) + let dictionary = if let DataType::Dictionary(_, value_data_type) = dtype { + Some(Self::try_from(value_data_type.as_ref())?) 
+ } else { + None + }; + FFI_ArrowSchema::try_new(&format, children, dictionary) + } +} + +fn get_format_string(dtype: &DataType) -> Result { + match dtype { + DataType::Null => Ok("n".to_string()), + DataType::Boolean => Ok("b".to_string()), + DataType::Int8 => Ok("c".to_string()), + DataType::UInt8 => Ok("C".to_string()), + DataType::Int16 => Ok("s".to_string()), + DataType::UInt16 => Ok("S".to_string()), + DataType::Int32 => Ok("i".to_string()), + DataType::UInt32 => Ok("I".to_string()), + DataType::Int64 => Ok("l".to_string()), + DataType::UInt64 => Ok("L".to_string()), + DataType::Float16 => Ok("e".to_string()), + DataType::Float32 => Ok("f".to_string()), + DataType::Float64 => Ok("g".to_string()), + DataType::Binary => Ok("z".to_string()), + DataType::LargeBinary => Ok("Z".to_string()), + DataType::Utf8 => Ok("u".to_string()), + DataType::LargeUtf8 => Ok("U".to_string()), + DataType::FixedSizeBinary(num_bytes) => Ok(format!("w:{}", num_bytes)), + DataType::FixedSizeList(_, num_elems) => Ok(format!("+w:{}", num_elems)), + DataType::Decimal(precision, scale) => Ok(format!("d:{},{}", precision, scale)), + DataType::Date32 => Ok("tdD".to_string()), + DataType::Date64 => Ok("tdm".to_string()), + DataType::Time32(TimeUnit::Second) => Ok("tts".to_string()), + DataType::Time32(TimeUnit::Millisecond) => Ok("ttm".to_string()), + DataType::Time64(TimeUnit::Microsecond) => Ok("ttu".to_string()), + DataType::Time64(TimeUnit::Nanosecond) => Ok("ttn".to_string()), + DataType::Timestamp(TimeUnit::Second, None) => Ok("tss:".to_string()), + DataType::Timestamp(TimeUnit::Millisecond, None) => Ok("tsm:".to_string()), + DataType::Timestamp(TimeUnit::Microsecond, None) => Ok("tsu:".to_string()), + DataType::Timestamp(TimeUnit::Nanosecond, None) => Ok("tsn:".to_string()), + DataType::Timestamp(TimeUnit::Second, Some(tz)) => Ok(format!("tss:{}", tz)), + DataType::Timestamp(TimeUnit::Millisecond, Some(tz)) => Ok(format!("tsm:{}", tz)), + DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => Ok(format!("tsu:{}", tz)), + DataType::Timestamp(TimeUnit::Nanosecond, Some(tz)) => Ok(format!("tsn:{}", tz)), + DataType::Duration(TimeUnit::Second) => Ok("tDs".to_string()), + DataType::Duration(TimeUnit::Millisecond) => Ok("tDm".to_string()), + DataType::Duration(TimeUnit::Microsecond) => Ok("tDu".to_string()), + DataType::Duration(TimeUnit::Nanosecond) => Ok("tDn".to_string()), + DataType::List(_) => Ok("+l".to_string()), + DataType::LargeList(_) => Ok("+L".to_string()), + DataType::Struct(_) => Ok("+s".to_string()), + DataType::Dictionary(key_data_type, _) => get_format_string(key_data_type), + other => Err(ArrowError::CDataInterface(format!( + "The datatype \"{:?}\" is still not supported in Rust implementation", + other + ))), } } @@ -311,6 +354,11 @@ mod tests { round_trip_type(DataType::Float64)?; round_trip_type(DataType::Date64)?; round_trip_type(DataType::Time64(TimeUnit::Nanosecond))?; + round_trip_type(DataType::FixedSizeBinary(12))?; + round_trip_type(DataType::FixedSizeList( + Box::new(Field::new("a", DataType::Int64, false)), + 5, + ))?; round_trip_type(DataType::Utf8)?; round_trip_type(DataType::List(Box::new(Field::new( "a", diff --git a/arrow/src/datatypes/field.rs b/arrow/src/datatypes/field.rs index 497dbb389fd7..5025d32a4f37 100644 --- a/arrow/src/datatypes/field.rs +++ b/arrow/src/datatypes/field.rs @@ -15,7 +15,9 @@ // specific language governing permissions and limitations // under the License. 
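A minimal sketch of exercising the format-string mappings added to the FFI schema conversion above; it assumes the public `TryFrom<&DataType> for FFI_ArrowSchema` impl and its `format()` accessor, and is illustrative only:

    use std::convert::TryFrom;
    use arrow::datatypes::{DataType, TimeUnit};
    use arrow::ffi::FFI_ArrowSchema;

    // Duration types now map to the "tD*" format strings.
    let duration = FFI_ArrowSchema::try_from(&DataType::Duration(TimeUnit::Millisecond)).unwrap();
    assert_eq!(duration.format(), "tDm");

    // FixedSizeBinary encodes its byte width as "w:num_bytes".
    let fsb = FFI_ArrowSchema::try_from(&DataType::FixedSizeBinary(16)).unwrap();
    assert_eq!(fsb.format(), "w:16");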
+use std::cmp::Ordering; use std::collections::BTreeMap; +use std::hash::{Hash, Hasher}; use serde_derive::{Deserialize, Serialize}; use serde_json::{json, Value}; @@ -27,7 +29,7 @@ use super::DataType; /// Contains the meta-data for a single relative type. /// /// The `Schema` object is an ordered collection of `Field` objects. -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub struct Field { name: String, data_type: DataType, @@ -39,6 +41,47 @@ pub struct Field { metadata: Option>, } +// Auto-derive `PartialEq` traits will pull `dict_id` and `dict_is_ordered` +// into comparison. However, these properties are only used in IPC context +// for matching dictionary encoded data. They are not necessary to be same +// to consider schema equality. For example, in C++ `Field` implementation, +// it doesn't contain these dictionary properties too. +impl PartialEq for Field { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + && self.data_type == other.data_type + && self.nullable == other.nullable + && self.metadata == other.metadata + } +} + +impl Eq for Field {} + +impl PartialOrd for Field { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Field { + fn cmp(&self, other: &Self) -> Ordering { + self.name + .cmp(other.name()) + .then(self.data_type.cmp(other.data_type())) + .then(self.nullable.cmp(&other.nullable)) + .then(self.metadata.cmp(&other.metadata)) + } +} + +impl Hash for Field { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.data_type.hash(state); + self.nullable.hash(state); + self.metadata.hash(state); + } +} + impl Field { /// Creates a new field pub fn new(name: &str, data_type: DataType, nullable: bool) -> Self { @@ -83,10 +126,16 @@ impl Field { } } + /// Sets the metadata of this `Field` to be `metadata` and returns self + pub fn with_metadata(mut self, metadata: Option>) -> Self { + self.set_metadata(metadata); + self + } + /// Returns the immutable reference to the `Field`'s optional custom metadata. #[inline] - pub const fn metadata(&self) -> &Option> { - &self.metadata + pub const fn metadata(&self) -> Option<&BTreeMap> { + self.metadata.as_ref() } /// Returns an immutable reference to the `Field`'s name. 
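The custom equality above deliberately ignores the dictionary properties, and `with_metadata` enables builder-style construction now that `metadata()` returns `Option<&BTreeMap<String, String>>`. A hedged sketch of both (field names are illustrative):

    use std::collections::BTreeMap;
    use arrow::datatypes::{DataType, Field};

    // Builder-style metadata.
    let meta: BTreeMap<String, String> =
        [("origin".to_string(), "example".to_string())].into_iter().collect();
    let field = Field::new("a", DataType::Utf8, false).with_metadata(Some(meta));
    assert_eq!(field.metadata().unwrap()["origin"], "example");

    // Fields that differ only in dict_id / dict_is_ordered now compare equal.
    let d1 = Field::new_dict(
        "d",
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
        false,
        1,
        false,
    );
    let d2 = Field::new_dict(
        "d",
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
        false,
        2,
        true,
    );
    assert_eq!(d1, d2);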
@@ -107,6 +156,47 @@ impl Field { self.nullable } + /// Returns a (flattened) vector containing all fields contained within this field (including it self) + pub(crate) fn fields(&self) -> Vec<&Field> { + let mut collected_fields = vec![self]; + collected_fields.append(&mut self._fields(&self.data_type)); + + collected_fields + } + + fn _fields<'a>(&'a self, dt: &'a DataType) -> Vec<&Field> { + let mut collected_fields = vec![]; + + match dt { + DataType::Struct(fields) | DataType::Union(fields, _, _) => { + collected_fields.extend(fields.iter().flat_map(|f| f.fields())) + } + DataType::List(field) + | DataType::LargeList(field) + | DataType::FixedSizeList(field, _) + | DataType::Map(field, _) => collected_fields.extend(field.fields()), + DataType::Dictionary(_, value_field) => { + collected_fields.append(&mut self._fields(value_field.as_ref())) + } + _ => (), + } + + collected_fields + } + + /// Returns a vector containing all (potentially nested) `Field` instances selected by the + /// dictionary ID they use + #[inline] + pub(crate) fn fields_with_dict_id(&self, id: i64) -> Vec<&Field> { + self.fields() + .into_iter() + .filter(|&field| { + matches!(field.data_type(), DataType::Dictionary(_, _)) + && field.dict_id == id + }) + .collect() + } + /// Returns the dictionary ID, if this is a dictionary type. #[inline] pub const fn dict_id(&self) -> Option { @@ -256,7 +346,7 @@ impl Field { DataType::Struct(mut fields) => match map.get("children") { Some(Value::Array(values)) => { let struct_fields: Result> = - values.iter().map(|v| Field::from(v)).collect(); + values.iter().map(Field::from).collect(); fields.append(&mut struct_fields?); DataType::Struct(fields) } @@ -300,6 +390,23 @@ impl Field { } } } + DataType::Union(_, type_ids, mode) => match map.get("children") { + Some(Value::Array(values)) => { + let union_fields: Vec = + values.iter().map(Field::from).collect::>()?; + DataType::Union(union_fields, type_ids, mode) + } + Some(_) => { + return Err(ArrowError::ParseError( + "Field 'children' must be an array".to_string(), + )) + } + None => { + return Err(ArrowError::ParseError( + "Field missing 'children' attribute".to_string(), + )); + } + }, _ => data_type, }; @@ -454,18 +561,34 @@ impl Field { )); } }, - DataType::Union(nested_fields) => match &from.data_type { - DataType::Union(from_nested_fields) => { - for from_field in from_nested_fields { + DataType::Union(nested_fields, type_ids, _) => match &from.data_type { + DataType::Union(from_nested_fields, from_type_ids, _) => { + for (idx, from_field) in from_nested_fields.iter().enumerate() { let mut is_new_field = true; - for self_field in nested_fields.iter_mut() { + let field_type_id = from_type_ids.get(idx).unwrap(); + + for (self_idx, self_field) in nested_fields.iter_mut().enumerate() + { if from_field == self_field { + let self_type_id = type_ids.get(self_idx).unwrap(); + + // If the nested fields in two unions are the same, they must have same + // type id. 
+ if self_type_id != field_type_id { + return Err(ArrowError::SchemaError( + "Fail to merge schema Field due to conflicting type ids in union datatype" + .to_string(), + )); + } + is_new_field = false; break; } } + if is_new_field { nested_fields.push(from_field.clone()); + type_ids.push(*field_type_id); } } } @@ -572,3 +695,102 @@ impl std::fmt::Display for Field { write!(f, "{:?}", self) } } + +#[cfg(test)] +mod test { + use super::{DataType, Field}; + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + #[test] + fn test_fields_with_dict_id() { + let dict1 = Field::new_dict( + "dict1", + DataType::Dictionary(DataType::Utf8.into(), DataType::Int32.into()), + false, + 10, + false, + ); + let dict2 = Field::new_dict( + "dict2", + DataType::Dictionary(DataType::Int32.into(), DataType::Int8.into()), + false, + 20, + false, + ); + + let field = Field::new( + "struct]>", + DataType::Struct(vec![ + dict1.clone(), + Field::new( + "list[struct]>]", + DataType::List(Box::new(Field::new( + "struct]>", + DataType::Struct(vec![ + dict1.clone(), + Field::new( + "list[struct]", + DataType::List(Box::new(Field::new( + "struct", + DataType::Struct(vec![dict2.clone()]), + false, + ))), + false, + ), + ]), + false, + ))), + false, + ), + ]), + false, + ); + + for field in field.fields_with_dict_id(10) { + assert_eq!(dict1, *field); + } + for field in field.fields_with_dict_id(20) { + assert_eq!(dict2, *field); + } + } + + fn get_field_hash(field: &Field) -> u64 { + let mut s = DefaultHasher::new(); + field.hash(&mut s); + s.finish() + } + + #[test] + fn test_field_comparison_case() { + // dictionary-encoding properties not used for field comparison + let dict1 = Field::new_dict( + "dict1", + DataType::Dictionary(DataType::Utf8.into(), DataType::Int32.into()), + false, + 10, + false, + ); + let dict2 = Field::new_dict( + "dict1", + DataType::Dictionary(DataType::Utf8.into(), DataType::Int32.into()), + false, + 20, + false, + ); + + assert_eq!(dict1, dict2); + assert_eq!(get_field_hash(&dict1), get_field_hash(&dict2)); + + let dict1 = Field::new_dict( + "dict0", + DataType::Dictionary(DataType::Utf8.into(), DataType::Int32.into()), + false, + 10, + false, + ); + + assert_ne!(dict1, dict2); + assert_ne!(get_field_hash(&dict1), get_field_hash(&dict2)); + } +} diff --git a/arrow/src/datatypes/mod.rs b/arrow/src/datatypes/mod.rs index 9920cf95d3c6..47074633d7e2 100644 --- a/arrow/src/datatypes/mod.rs +++ b/arrow/src/datatypes/mod.rs @@ -46,7 +46,7 @@ pub type SchemaRef = Arc; mod tests { use super::*; use crate::error::Result; - use serde_json::Value::{Bool, Number as VNumber}; + use serde_json::Value::{Bool, Number as VNumber, String as VString}; use serde_json::{Number, Value}; use std::{ collections::{BTreeMap, HashMap}, @@ -123,12 +123,12 @@ mod tests { let field_metadata: BTreeMap = kv_array.iter().cloned().collect(); // Non-empty map: should be converted as JSON obj { ... } - let mut first_name = Field::new("first_name", DataType::Utf8, false); - first_name.set_metadata(Some(field_metadata)); + let first_name = Field::new("first_name", DataType::Utf8, false) + .with_metadata(Some(field_metadata)); // Empty map: should be omitted. 
- let mut last_name = Field::new("last_name", DataType::Utf8, false); - last_name.set_metadata(Some(BTreeMap::default())); + let last_name = Field::new("last_name", DataType::Utf8, false) + .with_metadata(Some(BTreeMap::default())); let person = DataType::Struct(vec![ first_name, @@ -392,6 +392,61 @@ mod tests { assert_eq!(expected, dt); } + #[test] + fn parse_union_from_json() { + let json = r#" + { + "name": "my_union", + "nullable": false, + "type": { + "name": "union", + "mode": "SPARSE", + "typeIds": [ + 5, + 7 + ] + }, + "children": [ + { + "name": "f1", + "type": { + "name": "int", + "isSigned": true, + "bitWidth": 32 + }, + "nullable": true, + "children": [] + }, + { + "name": "f2", + "type": { + "name": "utf8" + }, + "nullable": true, + "children": [] + } + ] + } + "#; + let value: Value = serde_json::from_str(json).unwrap(); + let dt = Field::from(&value).unwrap(); + + let expected = Field::new( + "my_union", + DataType::Union( + vec![ + Field::new("f1", DataType::Int32, true), + Field::new("f2", DataType::Utf8, true), + ], + vec![5, 7], + UnionMode::Sparse, + ), + false, + ); + + assert_eq!(expected, dt); + } + #[test] fn parse_utf8_from_json() { let json = "{\"name\":\"utf8\"}"; @@ -454,13 +509,14 @@ mod tests { ), Field::new("c19", DataType::Interval(IntervalUnit::DayTime), false), Field::new("c20", DataType::Interval(IntervalUnit::YearMonth), false), + Field::new("c21", DataType::Interval(IntervalUnit::MonthDayNano), false), Field::new( - "c21", + "c22", DataType::List(Box::new(Field::new("item", DataType::Boolean, true))), false, ), Field::new( - "c22", + "c23", DataType::FixedSizeList( Box::new(Field::new("bools", DataType::Boolean, false)), 5, @@ -468,7 +524,7 @@ mod tests { false, ), Field::new( - "c23", + "c24", DataType::List(Box::new(Field::new( "inner_list", DataType::List(Box::new(Field::new( @@ -481,21 +537,22 @@ mod tests { true, ), Field::new( - "c24", + "c25", DataType::Struct(vec![ Field::new("a", DataType::Utf8, false), Field::new("b", DataType::UInt16, false), ]), false, ), - Field::new("c25", DataType::Interval(IntervalUnit::YearMonth), true), - Field::new("c26", DataType::Interval(IntervalUnit::DayTime), true), - Field::new("c27", DataType::Duration(TimeUnit::Second), false), - Field::new("c28", DataType::Duration(TimeUnit::Millisecond), false), - Field::new("c29", DataType::Duration(TimeUnit::Microsecond), false), - Field::new("c30", DataType::Duration(TimeUnit::Nanosecond), false), + Field::new("c26", DataType::Interval(IntervalUnit::YearMonth), true), + Field::new("c27", DataType::Interval(IntervalUnit::DayTime), true), + Field::new("c28", DataType::Interval(IntervalUnit::MonthDayNano), true), + Field::new("c29", DataType::Duration(TimeUnit::Second), false), + Field::new("c30", DataType::Duration(TimeUnit::Millisecond), false), + Field::new("c31", DataType::Duration(TimeUnit::Microsecond), false), + Field::new("c32", DataType::Duration(TimeUnit::Nanosecond), false), Field::new_dict( - "c31", + "c33", DataType::Dictionary( Box::new(DataType::Int32), Box::new(DataType::Utf8), @@ -504,10 +561,10 @@ mod tests { 123, true, ), - Field::new("c32", DataType::LargeBinary, true), - Field::new("c33", DataType::LargeUtf8, true), + Field::new("c34", DataType::LargeBinary, true), + Field::new("c35", DataType::LargeUtf8, true), Field::new( - "c34", + "c36", DataType::LargeList(Box::new(Field::new( "inner_large_list", DataType::LargeList(Box::new(Field::new( @@ -520,7 +577,7 @@ mod tests { true, ), Field::new( - "c35", + "c37", DataType::Map( Box::new(Field::new( 
"my_entries", @@ -731,6 +788,15 @@ mod tests { { "name": "c21", "nullable": false, + "type": { + "name": "interval", + "unit": "MONTH_DAY_NANO" + }, + "children": [] + }, + { + "name": "c22", + "nullable": false, "type": { "name": "list" }, @@ -746,7 +812,7 @@ mod tests { ] }, { - "name": "c22", + "name": "c23", "nullable": false, "type": { "name": "fixedsizelist", @@ -764,7 +830,7 @@ mod tests { ] }, { - "name": "c23", + "name": "c24", "nullable": true, "type": { "name": "list" @@ -790,7 +856,7 @@ mod tests { ] }, { - "name": "c24", + "name": "c25", "nullable": false, "type": { "name": "struct" @@ -817,7 +883,7 @@ mod tests { ] }, { - "name": "c25", + "name": "c26", "nullable": true, "type": { "name": "interval", @@ -826,7 +892,7 @@ mod tests { "children": [] }, { - "name": "c26", + "name": "c27", "nullable": true, "type": { "name": "interval", @@ -835,7 +901,16 @@ mod tests { "children": [] }, { - "name": "c27", + "name": "c28", + "nullable": true, + "type": { + "name": "interval", + "unit": "MONTH_DAY_NANO" + }, + "children": [] + }, + { + "name": "c29", "nullable": false, "type": { "name": "duration", @@ -844,7 +919,7 @@ mod tests { "children": [] }, { - "name": "c28", + "name": "c30", "nullable": false, "type": { "name": "duration", @@ -853,7 +928,7 @@ mod tests { "children": [] }, { - "name": "c29", + "name": "c31", "nullable": false, "type": { "name": "duration", @@ -862,7 +937,7 @@ mod tests { "children": [] }, { - "name": "c30", + "name": "c32", "nullable": false, "type": { "name": "duration", @@ -871,7 +946,7 @@ mod tests { "children": [] }, { - "name": "c31", + "name": "c33", "nullable": true, "children": [], "type": { @@ -888,7 +963,7 @@ mod tests { } }, { - "name": "c32", + "name": "c34", "nullable": true, "type": { "name": "largebinary" @@ -896,7 +971,7 @@ mod tests { "children": [] }, { - "name": "c33", + "name": "c35", "nullable": true, "type": { "name": "largeutf8" @@ -904,7 +979,7 @@ mod tests { "children": [] }, { - "name": "c34", + "name": "c36", "nullable": true, "type": { "name": "largelist" @@ -930,7 +1005,7 @@ mod tests { ] }, { - "name": "c35", + "name": "c37", "nullable": false, "type": { "name": "map", @@ -1134,8 +1209,7 @@ mod tests { assert!(schema2 != schema4); assert!(schema3 != schema4); - let mut f = Field::new("c1", DataType::Utf8, false); - f.set_metadata(Some( + let f = Field::new("c1", DataType::Utf8, false).with_metadata(Some( [("foo".to_string(), "bar".to_string())] .iter() .cloned() @@ -1156,6 +1230,7 @@ mod tests { assert_eq!(Some(VNumber(Number::from(1))), 1i16.into_json_value()); assert_eq!(Some(VNumber(Number::from(1))), 1i32.into_json_value()); assert_eq!(Some(VNumber(Number::from(1))), 1i64.into_json_value()); + assert_eq!(Some(VString("1".to_string())), 1i128.into_json_value()); assert_eq!(Some(VNumber(Number::from(1))), 1u8.into_json_value()); assert_eq!(Some(VNumber(Number::from(1))), 1u16.into_json_value()); assert_eq!(Some(VNumber(Number::from(1))), 1u32.into_json_value()); @@ -1174,8 +1249,8 @@ mod tests { fn person_schema() -> Schema { let kv_array = [("k".to_string(), "v".to_string())]; let field_metadata: BTreeMap = kv_array.iter().cloned().collect(); - let mut first_name = Field::new("first_name", DataType::Utf8, false); - first_name.set_metadata(Some(field_metadata)); + let first_name = Field::new("first_name", DataType::Utf8, false) + .with_metadata(Some(field_metadata)); Schema::new(vec![ first_name, @@ -1206,16 +1281,16 @@ mod tests { .iter() .cloned() .collect(); - let mut f1 = Field::new("first_name", DataType::Utf8, false); - 
f1.set_metadata(Some(metadata1)); + let f1 = Field::new("first_name", DataType::Utf8, false) + .with_metadata(Some(metadata1)); let metadata2: BTreeMap = [("foo".to_string(), "baz".to_string())] .iter() .cloned() .collect(); - let mut f2 = Field::new("first_name", DataType::Utf8, false); - f2.set_metadata(Some(metadata2)); + let f2 = Field::new("first_name", DataType::Utf8, false) + .with_metadata(Some(metadata2)); assert!( Schema::try_merge(vec![Schema::new(vec![f1]), Schema::new(vec![f2])]) @@ -1229,8 +1304,8 @@ mod tests { .iter() .cloned() .collect(); - let mut f2 = Field::new("first_name", DataType::Utf8, false); - f2.set_metadata(Some(metadata2)); + let f2 = Field::new("first_name", DataType::Utf8, false) + .with_metadata(Some(metadata2)); assert!(f1.try_merge(&f2).is_ok()); assert!(f1.metadata().is_some()); @@ -1240,15 +1315,13 @@ mod tests { ); // 3. Some + Some - let mut f1 = Field::new("first_name", DataType::Utf8, false); - f1.set_metadata(Some( + let mut f1 = Field::new("first_name", DataType::Utf8, false).with_metadata(Some( [("foo".to_string(), "bar".to_string())] .iter() .cloned() .collect(), )); - let mut f2 = Field::new("first_name", DataType::Utf8, false); - f2.set_metadata(Some( + let f2 = Field::new("first_name", DataType::Utf8, false).with_metadata(Some( [("foo2".to_string(), "bar2".to_string())] .iter() .cloned() @@ -1258,7 +1331,7 @@ mod tests { assert!(f1.try_merge(&f2).is_ok()); assert!(f1.metadata().is_some()); assert_eq!( - f1.metadata().clone().unwrap(), + f1.metadata().cloned().unwrap(), [ ("foo".to_string(), "bar".to_string()), ("foo2".to_string(), "bar2".to_string()) @@ -1269,8 +1342,7 @@ mod tests { ); // 4. Some + None. - let mut f1 = Field::new("first_name", DataType::Utf8, false); - f1.set_metadata(Some( + let mut f1 = Field::new("first_name", DataType::Utf8, false).with_metadata(Some( [("foo".to_string(), "bar".to_string())] .iter() .cloned() @@ -1280,7 +1352,7 @@ mod tests { assert!(f1.try_merge(&f2).is_ok()); assert!(f1.metadata().is_some()); assert_eq!( - f1.metadata().clone().unwrap(), + f1.metadata().cloned().unwrap(), [("foo".to_string(), "bar".to_string())] .iter() .cloned() @@ -1358,28 +1430,40 @@ mod tests { Schema::try_merge(vec![ Schema::new(vec![Field::new( "c1", - DataType::Union(vec![ - Field::new("c11", DataType::Utf8, true), - Field::new("c12", DataType::Utf8, true), - ]), + DataType::Union( + vec![ + Field::new("c11", DataType::Utf8, true), + Field::new("c12", DataType::Utf8, true), + ], + vec![0, 1], + UnionMode::Dense + ), false ),]), Schema::new(vec![Field::new( "c1", - DataType::Union(vec![ - Field::new("c12", DataType::Utf8, true), - Field::new("c13", DataType::Time64(TimeUnit::Second), true), - ]), + DataType::Union( + vec![ + Field::new("c12", DataType::Utf8, true), + Field::new("c13", DataType::Time64(TimeUnit::Second), true), + ], + vec![1, 2], + UnionMode::Dense + ), false ),]) ])?, Schema::new(vec![Field::new( "c1", - DataType::Union(vec![ - Field::new("c11", DataType::Utf8, true), - Field::new("c12", DataType::Utf8, true), - Field::new("c13", DataType::Time64(TimeUnit::Second), true), - ]), + DataType::Union( + vec![ + Field::new("c11", DataType::Utf8, true), + Field::new("c12", DataType::Utf8, true), + Field::new("c13", DataType::Time64(TimeUnit::Second), true), + ], + vec![0, 1, 2], + UnionMode::Dense + ), false ),]), ); diff --git a/arrow/src/datatypes/native.rs b/arrow/src/datatypes/native.rs index 6e8cf892237e..d9a3f667d8e4 100644 --- a/arrow/src/datatypes/native.rs +++ b/arrow/src/datatypes/native.rs @@ -15,9 +15,13 @@ 
// specific language governing permissions and limitations // under the License. +use super::DataType; +use half::f16; use serde_json::{Number, Value}; -use super::DataType; +mod private { + pub trait Sealed {} +} /// Trait declaring any type that is serializable to JSON. This includes all primitive types (bool, i32, etc.). pub trait JsonSerializable: 'static { @@ -26,8 +30,26 @@ pub trait JsonSerializable: 'static { /// Trait expressing a Rust type that has the same in-memory representation /// as Arrow. This includes `i16`, `f32`, but excludes `bool` (which in arrow is represented in bits). +/// /// In little endian machines, types that implement [`ArrowNativeType`] can be memcopied to arrow buffers /// as is. +/// +/// # Transmute Safety +/// +/// A type T implementing this trait means that any arbitrary slice of bytes of length and +/// alignment `size_of::()` can be safely interpreted as a value of that type without +/// being unsound, i.e. potentially resulting in undefined behaviour. +/// +/// Note: in the case of floating point numbers this transmutation can result in a signalling +/// NaN, which, whilst sound, can be unwieldy. In general, whilst it is perfectly sound to +/// reinterpret bytes as different types using this trait, it is likely unwise. For more information +/// see [f32::from_bits] and [f64::from_bits]. +/// +/// Note: `bool` is restricted to `0` or `1`, and so `bool: !ArrowNativeType` +/// +/// # Sealed +/// +/// Due to the above restrictions, this trait is sealed to prevent accidental misuse pub trait ArrowNativeType: std::fmt::Debug + Send @@ -37,6 +59,7 @@ pub trait ArrowNativeType: + std::str::FromStr + Default + JsonSerializable + + private::Sealed { /// Convert native type from usize. #[inline] @@ -67,6 +90,12 @@ pub trait ArrowNativeType: fn from_i64(_: i64) -> Option { None } + + /// Convert native type from i128. + #[inline] + fn from_i128(_: i128) -> Option { + None + } } /// Trait bridging the dynamic-typed nature of Arrow (via [`DataType`]) with the @@ -103,6 +132,7 @@ impl JsonSerializable for i8 { } } +impl private::Sealed for i8 {} impl ArrowNativeType for i8 { #[inline] fn from_usize(v: usize) -> Option { @@ -126,6 +156,7 @@ impl JsonSerializable for i16 { } } +impl private::Sealed for i16 {} impl ArrowNativeType for i16 { #[inline] fn from_usize(v: usize) -> Option { @@ -149,6 +180,7 @@ impl JsonSerializable for i32 { } } +impl private::Sealed for i32 {} impl ArrowNativeType for i32 { #[inline] fn from_usize(v: usize) -> Option { @@ -178,6 +210,7 @@ impl JsonSerializable for i64 { } } +impl private::Sealed for i64 {} impl ArrowNativeType for i64 { #[inline] fn from_usize(v: usize) -> Option { @@ -201,12 +234,47 @@ impl ArrowNativeType for i64 { } } +impl JsonSerializable for i128 { + fn into_json_value(self) -> Option { + // Serialize as string to avoid issues with arbitrary_precision serde_json feature + // - https://github.com/serde-rs/json/issues/559 + // - https://github.com/serde-rs/json/issues/845 + // - https://github.com/serde-rs/json/issues/846 + Some(self.to_string().into()) + } +} + +impl private::Sealed for i128 {} +impl ArrowNativeType for i128 { + #[inline] + fn from_usize(v: usize) -> Option { + num::FromPrimitive::from_usize(v) + } + + #[inline] + fn to_usize(&self) -> Option { + num::ToPrimitive::to_usize(self) + } + + #[inline] + fn to_isize(&self) -> Option { + num::ToPrimitive::to_isize(self) + } + + /// Convert native type from i128. 
+ #[inline] + fn from_i128(val: i128) -> Option { + Some(val) + } +} + impl JsonSerializable for u8 { fn into_json_value(self) -> Option { Some(self.into()) } } +impl private::Sealed for u8 {} impl ArrowNativeType for u8 { #[inline] fn from_usize(v: usize) -> Option { @@ -230,6 +298,7 @@ impl JsonSerializable for u16 { } } +impl private::Sealed for u16 {} impl ArrowNativeType for u16 { #[inline] fn from_usize(v: usize) -> Option { @@ -253,6 +322,7 @@ impl JsonSerializable for u32 { } } +impl private::Sealed for u32 {} impl ArrowNativeType for u32 { #[inline] fn from_usize(v: usize) -> Option { @@ -276,6 +346,7 @@ impl JsonSerializable for u64 { } } +impl private::Sealed for u64 {} impl ArrowNativeType for u64 { #[inline] fn from_usize(v: usize) -> Option { @@ -293,6 +364,12 @@ impl ArrowNativeType for u64 { } } +impl JsonSerializable for f16 { + fn into_json_value(self) -> Option { + Number::from_f64(f64::round(f64::from(self) * 1000.0) / 1000.0).map(Value::Number) + } +} + impl JsonSerializable for f32 { fn into_json_value(self) -> Option { Number::from_f64(f64::round(self as f64 * 1000.0) / 1000.0).map(Value::Number) @@ -305,8 +382,12 @@ impl JsonSerializable for f64 { } } +impl ArrowNativeType for f16 {} +impl private::Sealed for f16 {} impl ArrowNativeType for f32 {} +impl private::Sealed for f32 {} impl ArrowNativeType for f64 {} +impl private::Sealed for f64 {} /// Allows conversion from supported Arrow types to a byte slice. pub trait ToByteSlice { diff --git a/arrow/src/datatypes/numeric.rs b/arrow/src/datatypes/numeric.rs index 39c6732c3231..b8fa87197c38 100644 --- a/arrow/src/datatypes/numeric.rs +++ b/arrow/src/datatypes/numeric.rs @@ -19,9 +19,7 @@ use super::*; #[cfg(feature = "simd")] use packed_simd::*; #[cfg(feature = "simd")] -use std::ops::{ - Add, BitAnd, BitAndAssign, BitOr, BitOrAssign, Div, Mul, Neg, Not, Rem, Sub, -}; +use std::ops::{Add, BitAnd, BitAndAssign, BitOr, BitOrAssign, Div, Mul, Not, Rem, Sub}; /// A subtype of primitive type that represents numeric values. /// @@ -147,6 +145,20 @@ macro_rules! make_numeric_type { // this match will get removed by the compiler since the number of lanes is known at // compile-time for each concrete numeric type match Self::lanes() { + 4 => { + // the bit position in each lane indicates the index of that lane + let vecidx = i128x4::new(1, 2, 4, 8); + + // broadcast the lowermost 8 bits of mask to each lane + let vecmask = i128x4::splat((mask & 0x0F) as i128); + // compute whether the bit corresponding to each lanes index is set + let vecmask = (vecidx & vecmask).eq(vecidx); + + // transmute is necessary because the different match arms return different + // mask types, at runtime only one of those expressions will exist per type, + // with the type being equal to `SimdMask`. 
+ unsafe { std::mem::transmute(vecmask) } + } 8 => { // the bit position in each lane indicates the index of that lane let vecidx = i64x8::new(1, 2, 4, 8, 16, 32, 64, 128); @@ -348,79 +360,12 @@ make_numeric_type!(Time64MicrosecondType, i64, i64x8, m64x8); make_numeric_type!(Time64NanosecondType, i64, i64x8, m64x8); make_numeric_type!(IntervalYearMonthType, i32, i32x16, m32x16); make_numeric_type!(IntervalDayTimeType, i64, i64x8, m64x8); +make_numeric_type!(IntervalMonthDayNanoType, i128, i128x4, m128x4); make_numeric_type!(DurationSecondType, i64, i64x8, m64x8); make_numeric_type!(DurationMillisecondType, i64, i64x8, m64x8); make_numeric_type!(DurationMicrosecondType, i64, i64x8, m64x8); make_numeric_type!(DurationNanosecondType, i64, i64x8, m64x8); -/// A subtype of primitive type that represents signed numeric values. -/// -/// SIMD operations are defined in this trait if available on the target system. -#[cfg(feature = "simd")] -pub trait ArrowSignedNumericType: ArrowNumericType -where - Self::SignedSimd: Neg, -{ - /// Defines the SIMD type that should be used for this numeric type - type SignedSimd; - - /// Loads a slice of signed numeric type into a SIMD register - fn load_signed(slice: &[Self::Native]) -> Self::SignedSimd; - - /// Performs a SIMD unary operation on signed numeric type - fn signed_unary_op Self::SignedSimd>( - a: Self::SignedSimd, - op: F, - ) -> Self::SignedSimd; - - /// Writes a signed SIMD result back to a slice - fn write_signed(simd_result: Self::SignedSimd, slice: &mut [Self::Native]); -} - -#[cfg(not(feature = "simd"))] -pub trait ArrowSignedNumericType: ArrowNumericType -where - Self::Native: std::ops::Neg, -{ -} - -macro_rules! make_signed_numeric_type { - ($impl_ty:ty, $simd_ty:ident) => { - #[cfg(feature = "simd")] - impl ArrowSignedNumericType for $impl_ty { - type SignedSimd = $simd_ty; - - #[inline] - fn load_signed(slice: &[Self::Native]) -> Self::SignedSimd { - unsafe { Self::SignedSimd::from_slice_unaligned_unchecked(slice) } - } - - #[inline] - fn signed_unary_op Self::SignedSimd>( - a: Self::SignedSimd, - op: F, - ) -> Self::SignedSimd { - op(a) - } - - #[inline] - fn write_signed(simd_result: Self::SignedSimd, slice: &mut [Self::Native]) { - unsafe { simd_result.write_to_slice_unaligned_unchecked(slice) }; - } - } - - #[cfg(not(feature = "simd"))] - impl ArrowSignedNumericType for $impl_ty {} - }; -} - -make_signed_numeric_type!(Int8Type, i8x64); -make_signed_numeric_type!(Int16Type, i16x32); -make_signed_numeric_type!(Int32Type, i32x16); -make_signed_numeric_type!(Int64Type, i64x8); -make_signed_numeric_type!(Float32Type, f32x16); -make_signed_numeric_type!(Float64Type, f64x8); - #[cfg(feature = "simd")] pub trait ArrowFloatNumericType: ArrowNumericType { fn pow(base: Self::Simd, raise: Self::Simd) -> Self::Simd; @@ -447,11 +392,11 @@ macro_rules! 
make_float_numeric_type { make_float_numeric_type!(Float32Type, f32x16); make_float_numeric_type!(Float64Type, f64x8); -#[cfg(all(test, simd_x86))] +#[cfg(all(test, feature = "simd"))] mod tests { use crate::datatypes::{ ArrowNumericType, Float32Type, Float64Type, Int32Type, Int64Type, Int8Type, - UInt16Type, + IntervalMonthDayNanoType, UInt16Type, }; use packed_simd::*; use FromCast; @@ -469,6 +414,17 @@ mod tests { }}; } + #[test] + fn test_mask_i128() { + let mask = 0b1101; + let actual = IntervalMonthDayNanoType::mask_from_u64(mask); + let expected = expected_mask!(i128, mask); + let expected = + m128x4::from_cast(i128x4::from_slice_unaligned(expected.as_slice())); + + assert_eq!(expected, actual); + } + #[test] fn test_mask_f64() { let mask = 0b10101010; diff --git a/arrow/src/datatypes/schema.rs b/arrow/src/datatypes/schema.rs index cfc0744954b1..5a7336624f5b 100644 --- a/arrow/src/datatypes/schema.rs +++ b/arrow/src/datatypes/schema.rs @@ -53,7 +53,6 @@ impl Schema { /// # Example /// /// ``` - /// # extern crate arrow; /// # use arrow::datatypes::{Field, DataType, Schema}; /// let field_a = Field::new("a", DataType::Int64, false); /// let field_b = Field::new("b", DataType::Boolean, false); @@ -70,7 +69,6 @@ impl Schema { /// # Example /// /// ``` - /// # extern crate arrow; /// # use arrow::datatypes::{Field, DataType, Schema}; /// # use std::collections::HashMap; /// let field_a = Field::new("a", DataType::Int64, false); @@ -89,6 +87,30 @@ impl Schema { Self { fields, metadata } } + /// Sets the metadata of this `Schema` to be `metadata` and returns self + pub fn with_metadata(mut self, metadata: HashMap) -> Self { + self.metadata = metadata; + self + } + + /// Returns a new schema with only the specified columns in the new schema + /// This carries metadata from the parent schema over as well + pub fn project(&self, indices: &[usize]) -> Result { + let new_fields = indices + .iter() + .map(|i| { + self.fields.get(*i).cloned().ok_or_else(|| { + ArrowError::SchemaError(format!( + "project index {} out of bounds, max field {}", + i, + self.fields().len() + )) + }) + }) + .collect::>>()?; + Ok(Self::new_with_metadata(new_fields, self.metadata.clone())) + } + /// Merge schema into self if it is compatible. Struct fields will be merged recursively. /// /// Example: @@ -159,6 +181,12 @@ impl Schema { &self.fields } + /// Returns a vector with references to all fields (including nested fields) + #[inline] + pub(crate) fn all_fields(&self) -> Vec<&Field> { + self.fields.iter().flat_map(|f| f.fields()).collect() + } + /// Returns an immutable reference of a specific `Field` instance selected using an /// offset within the internal `fields` vector. pub fn field(&self, i: usize) -> &Field { @@ -175,7 +203,7 @@ impl Schema { pub fn fields_with_dict_id(&self, dict_id: i64) -> Vec<&Field> { self.fields .iter() - .filter(|f| f.dict_id() == Some(dict_id)) + .flat_map(|f| f.fields_with_dict_id(dict_id)) .collect() } @@ -222,10 +250,7 @@ impl Schema { match *json { Value::Object(ref schema) => { let fields = if let Some(Value::Array(fields)) = schema.get("fields") { - fields - .iter() - .map(|f| Field::from(f)) - .collect::>()? + fields.iter().map(Field::from).collect::>()? } else { return Err(ArrowError::ParseError( "Schema fields should be an array".to_string(), @@ -281,7 +306,7 @@ impl Schema { } } - /// Check to see if `self` is a superset of `other` schema. Here are the comparision rules: + /// Check to see if `self` is a superset of `other` schema. 
Here are the comparison rules: /// /// * `self` and `other` should contain the same number of fields /// * for every field `f` in `other`, the field in `self` with corresponding index should be a @@ -346,7 +371,7 @@ mod tests { #[test] fn test_ser_de_metadata() { // ser/de with empty metadata - let mut schema = Schema::new(vec![ + let schema = Schema::new(vec![ Field::new("name", DataType::Utf8, false), Field::new("address", DataType::Utf8, false), Field::new("priority", DataType::UInt8, false), @@ -358,13 +383,54 @@ mod tests { assert_eq!(schema, de_schema); // ser/de with non-empty metadata - schema.metadata = [("key".to_owned(), "val".to_owned())] - .iter() - .cloned() - .collect(); + let schema = schema + .with_metadata([("key".to_owned(), "val".to_owned())].into_iter().collect()); let json = serde_json::to_string(&schema).unwrap(); let de_schema = serde_json::from_str(&json).unwrap(); assert_eq!(schema, de_schema); } + + #[test] + fn test_projection() { + let mut metadata = HashMap::new(); + metadata.insert("meta".to_string(), "data".to_string()); + + let schema = Schema::new(vec![ + Field::new("name", DataType::Utf8, false), + Field::new("address", DataType::Utf8, false), + Field::new("priority", DataType::UInt8, false), + ]) + .with_metadata(metadata); + + let projected: Schema = schema.project(&[0, 2]).unwrap(); + + assert_eq!(projected.fields().len(), 2); + assert_eq!(projected.fields()[0].name(), "name"); + assert_eq!(projected.fields()[1].name(), "priority"); + assert_eq!(projected.metadata.get("meta").unwrap(), "data") + } + + #[test] + fn test_oob_projection() { + let mut metadata = HashMap::new(); + metadata.insert("meta".to_string(), "data".to_string()); + + let schema = Schema::new(vec![ + Field::new("name", DataType::Utf8, false), + Field::new("address", DataType::Utf8, false), + Field::new("priority", DataType::UInt8, false), + ]) + .with_metadata(metadata); + + let projected: Result = schema.project(&[0, 3]); + + assert!(projected.is_err()); + if let Err(e) = projected { + assert_eq!( + e.to_string(), + "Schema error: project index 3 out of bounds, max field 3".to_string() + ) + } + } } diff --git a/arrow/src/datatypes/types.rs b/arrow/src/datatypes/types.rs index 30c9aae89565..0937c3b3c9d7 100644 --- a/arrow/src/datatypes/types.rs +++ b/arrow/src/datatypes/types.rs @@ -16,6 +16,7 @@ // under the License. use super::{ArrowPrimitiveType, DataType, IntervalUnit, TimeUnit}; +use half::f16; // BooleanType is special: its bit-width is not the size of the primitive type, and its `index` // operation assumes bit-packing. 
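A minimal sketch of the `Schema::with_metadata` / `Schema::project` API added in schema.rs above; column names are illustrative and mirror the new tests:

    use std::collections::HashMap;
    use arrow::datatypes::{DataType, Field, Schema};

    let metadata: HashMap<String, String> =
        [("meta".to_string(), "data".to_string())].into_iter().collect();
    let schema = Schema::new(vec![
        Field::new("name", DataType::Utf8, false),
        Field::new("address", DataType::Utf8, false),
        Field::new("priority", DataType::UInt8, false),
    ])
    .with_metadata(metadata);

    // Select columns 0 and 2; schema-level metadata is carried over.
    let projected = schema.project(&[0, 2]).unwrap();
    assert_eq!(projected.fields().len(), 2);
    assert_eq!(projected.fields()[1].name(), "priority");

    // Out-of-bounds indices produce a SchemaError instead of panicking.
    assert!(schema.project(&[0, 3]).is_err());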
@@ -46,6 +47,7 @@ make_type!(UInt8Type, u8, DataType::UInt8); make_type!(UInt16Type, u16, DataType::UInt16); make_type!(UInt32Type, u32, DataType::UInt32); make_type!(UInt64Type, u64, DataType::UInt64); +make_type!(Float16Type, f16, DataType::Float16); make_type!(Float32Type, f32, DataType::Float32); make_type!(Float64Type, f64, DataType::Float64); make_type!( @@ -96,6 +98,11 @@ make_type!( i64, DataType::Interval(IntervalUnit::DayTime) ); +make_type!( + IntervalMonthDayNanoType, + i128, + DataType::Interval(IntervalUnit::MonthDayNano) +); make_type!( DurationSecondType, i64, @@ -152,6 +159,7 @@ impl ArrowTemporalType for Time64MicrosecondType {} impl ArrowTemporalType for Time64NanosecondType {} // impl ArrowTemporalType for IntervalYearMonthType {} // impl ArrowTemporalType for IntervalDayTimeType {} +// impl ArrowTemporalType for IntervalMonthDayNanoType {} impl ArrowTemporalType for DurationSecondType {} impl ArrowTemporalType for DurationMillisecondType {} impl ArrowTemporalType for DurationMicrosecondType {} diff --git a/arrow/src/error.rs b/arrow/src/error.rs index 86896c016882..ef7abbbddef9 100644 --- a/arrow/src/error.rs +++ b/arrow/src/error.rs @@ -65,7 +65,7 @@ impl From for ArrowError { csv_crate::ErrorKind::Io(error) => ArrowError::CsvError(error.to_string()), csv_crate::ErrorKind::Utf8 { pos: _, err } => ArrowError::CsvError(format!( "Encountered UTF-8 error while reading CSV file: {}", - err.to_string() + err )), csv_crate::ErrorKind::UnequalLengths { expected_len, len, .. diff --git a/arrow/src/ffi.rs b/arrow/src/ffi.rs index a61f291bd4ab..84905af20a63 100644 --- a/arrow/src/ffi.rs +++ b/arrow/src/ffi.rs @@ -26,9 +26,10 @@ //! //! ```rust //! # use std::sync::Arc; -//! # use arrow::array::{Int32Array, Array, ArrayData, make_array_from_raw}; +//! # use arrow::array::{Int32Array, Array, ArrayData, export_array_into_raw, make_array, make_array_from_raw}; //! # use arrow::error::{Result, ArrowError}; //! # use arrow::compute::kernels::arithmetic; +//! # use arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema}; //! # use std::convert::TryFrom; //! # fn main() -> Result<()> { //! // create an array natively @@ -51,7 +52,37 @@ //! // verify //! assert_eq!(array, Int32Array::from(vec![Some(2), None, Some(6)])); //! +//! // Simulate if raw pointers are provided by consumer +//! let array = make_array(Int32Array::from(vec![Some(1), None, Some(3)]).data().clone()); +//! +//! let out_array = Box::new(FFI_ArrowArray::empty()); +//! let out_schema = Box::new(FFI_ArrowSchema::empty()); +//! let out_array_ptr = Box::into_raw(out_array); +//! let out_schema_ptr = Box::into_raw(out_schema); +//! +//! // export array into raw pointers from consumer +//! unsafe { export_array_into_raw(array, out_array_ptr, out_schema_ptr)?; }; +//! +//! // import it +//! let array = unsafe { make_array_from_raw(out_array_ptr, out_schema_ptr)? }; +//! +//! // perform some operation +//! let array = array.as_any().downcast_ref::().ok_or( +//! ArrowError::ParseError("Expects an int32".to_string()), +//! )?; +//! let array = arithmetic::add(&array, &array)?; +//! +//! // verify +//! assert_eq!(array, Int32Array::from(vec![Some(2), None, Some(6)])); +//! //! // (drop/release) +//! unsafe { +//! Box::from_raw(out_array_ptr); +//! Box::from_raw(out_schema_ptr); +//! Arc::from_raw(array_ptr); +//! Arc::from_raw(schema_ptr); +//! } +//! //! Ok(()) //! } //! ``` @@ -109,19 +140,20 @@ bitflags! 
{ #[repr(C)] #[derive(Debug)] pub struct FFI_ArrowSchema { - format: *const c_char, - name: *const c_char, - metadata: *const c_char, - flags: i64, - n_children: i64, - children: *mut *mut FFI_ArrowSchema, - dictionary: *mut FFI_ArrowSchema, - release: Option, - private_data: *mut c_void, + pub(crate) format: *const c_char, + pub(crate) name: *const c_char, + pub(crate) metadata: *const c_char, + pub(crate) flags: i64, + pub(crate) n_children: i64, + pub(crate) children: *mut *mut FFI_ArrowSchema, + pub(crate) dictionary: *mut FFI_ArrowSchema, + pub(crate) release: Option, + pub(crate) private_data: *mut c_void, } struct SchemaPrivateData { children: Box<[*mut FFI_ArrowSchema]>, + dictionary: *mut FFI_ArrowSchema, } // callback used to drop [FFI_ArrowSchema] when it is exported. @@ -132,15 +164,19 @@ unsafe extern "C" fn release_schema(schema: *mut FFI_ArrowSchema) { let schema = &mut *schema; // take ownership back to release it. - CString::from_raw(schema.format as *mut c_char); + drop(CString::from_raw(schema.format as *mut c_char)); if !schema.name.is_null() { - CString::from_raw(schema.name as *mut c_char); + drop(CString::from_raw(schema.name as *mut c_char)); } if !schema.private_data.is_null() { let private_data = Box::from_raw(schema.private_data as *mut SchemaPrivateData); for child in private_data.children.iter() { drop(Box::from_raw(*child)) } + if !private_data.dictionary.is_null() { + drop(Box::from_raw(private_data.dictionary)); + } + drop(private_data); } @@ -148,8 +184,13 @@ unsafe extern "C" fn release_schema(schema: *mut FFI_ArrowSchema) { } impl FFI_ArrowSchema { - /// create a new [`Ffi_ArrowSchema`]. This fails if the fields' [`DataType`] is not supported. - pub fn try_new(format: &str, children: Vec) -> Result { + /// create a new [`FFI_ArrowSchema`]. This fails if the fields' + /// [`DataType`] is not supported. + pub fn try_new( + format: &str, + children: Vec, + dictionary: Option, + ) -> Result { let mut this = Self::empty(); let children_ptr = children @@ -162,13 +203,20 @@ impl FFI_ArrowSchema { this.release = Some(release_schema); this.n_children = children_ptr.len() as i64; + let dictionary_ptr = dictionary + .map(|d| Box::into_raw(Box::new(d))) + .unwrap_or(std::ptr::null_mut()); + let mut private_data = Box::new(SchemaPrivateData { children: children_ptr, + dictionary: dictionary_ptr, }); // intentionally set from private_data (see https://github.com/apache/arrow-rs/issues/580) this.children = private_data.children.as_mut_ptr(); + this.dictionary = dictionary_ptr; + this.private_data = Box::into_raw(private_data) as *mut c_void; Ok(this) @@ -232,6 +280,10 @@ impl FFI_ArrowSchema { pub fn nullable(&self) -> bool { (self.flags / 2) & 1 == 1 } + + pub fn dictionary(&self) -> Option<&Self> { + unsafe { self.dictionary.as_ref() } + } } impl Drop for FFI_ArrowSchema { @@ -245,6 +297,7 @@ impl Drop for FFI_ArrowSchema { // returns the number of bits that buffer `i` (in the C data interface) is expected to have. 
// This is set by the Arrow specification +#[allow(clippy::manual_bits)] fn bit_width(data_type: &DataType, i: usize) -> Result { Ok(match (data_type, i) { // the null buffer is bit sized @@ -263,6 +316,7 @@ fn bit_width(data_type: &DataType, i: usize) -> Result { (DataType::Float64, 1) => size_of::() * 8, (DataType::Decimal(..), 1) => size_of::() * 8, (DataType::Timestamp(..), 1) => size_of::() * 8, + (DataType::Duration(..), 1) => size_of::() * 8, // primitive types have a single buffer (DataType::Boolean, _) | (DataType::UInt8, _) | @@ -276,12 +330,24 @@ fn bit_width(data_type: &DataType, i: usize) -> Result { (DataType::Float32, _) | (DataType::Float64, _) | (DataType::Decimal(..), _) | - (DataType::Timestamp(..), _) => { + (DataType::Timestamp(..), _) | + (DataType::Duration(..), _) => { return Err(ArrowError::CDataInterface(format!( "The datatype \"{:?}\" expects 2 buffers, but requested {}. Please verify that the C data interface is correctly implemented.", data_type, i ))) } + (DataType::FixedSizeBinary(num_bytes), 1) => size_of::() * (*num_bytes as usize) * 8, + (DataType::FixedSizeList(f, num_elems), 1) => { + let child_bit_width = bit_width(f.data_type(), 1)?; + child_bit_width * (*num_elems as usize) + }, + (DataType::FixedSizeBinary(_), _) | (DataType::FixedSizeList(_, _), _) => { + return Err(ArrowError::CDataInterface(format!( + "The datatype \"{:?}\" expects 2 buffers, but requested {}. Please verify that the C data interface is correctly implemented.", + data_type, i + ))) + }, // Variable-sized binaries: have two buffers. // "small": first buffer is i32, second is in bytes (DataType::Utf8, 1) | (DataType::Binary, 1) | (DataType::List(_), 1) => size_of::() * 8, @@ -323,15 +389,15 @@ pub struct FFI_ArrowArray { pub(crate) n_buffers: i64, pub(crate) n_children: i64, pub(crate) buffers: *mut *const c_void, - children: *mut *mut FFI_ArrowArray, - dictionary: *mut FFI_ArrowArray, - release: Option, + pub(crate) children: *mut *mut FFI_ArrowArray, + pub(crate) dictionary: *mut FFI_ArrowArray, + pub(crate) release: Option, // When exported, this MUST contain everything that is owned by this array. // for example, any buffer pointed to in `buffers` must be here, as well // as the `buffers` pointer itself. // In other words, everything in [FFI_ArrowArray] must be owned by // `private_data` and can assume that they do not outlive `private_data`. - private_data: *mut c_void, + pub(crate) private_data: *mut c_void, } impl Drop for FFI_ArrowArray { @@ -355,14 +421,19 @@ unsafe extern "C" fn release_array(array: *mut FFI_ArrowArray) { for child in private.children.iter() { let _ = Box::from_raw(*child); } + if !private.dictionary.is_null() { + let _ = Box::from_raw(private.dictionary); + } array.release = None; } struct ArrayPrivateData { + #[allow(dead_code)] buffers: Vec>, buffers_ptr: Box<[*const c_void]>, children: Box<[*mut FFI_ArrowArray]>, + dictionary: *mut FFI_ArrowArray, } impl FFI_ArrowArray { @@ -370,7 +441,7 @@ impl FFI_ArrowArray { /// # Safety /// This method releases `buffers`. Consumers of this struct *must* call `release` before /// releasing this struct, or contents in `buffers` leak. - fn new(data: &ArrayData) -> Self { + pub fn new(data: &ArrayData) -> Self { // * insert the null buffer at the start // * make all others `Option`. 
let buffers = iter::once(data.null_buffer().cloned()) @@ -387,8 +458,16 @@ impl FFI_ArrowArray { }) .collect::>(); - let children = data - .child_data() + let empty = vec![]; + let (child_data, dictionary) = match data.data_type() { + DataType::Dictionary(_, _) => ( + empty.as_slice(), + Box::into_raw(Box::new(FFI_ArrowArray::new(&data.child_data()[0]))), + ), + _ => (data.child_data(), std::ptr::null_mut()), + }; + + let children = child_data .iter() .map(|child| Box::into_raw(Box::new(FFI_ArrowArray::new(child)))) .collect::>(); @@ -400,6 +479,7 @@ impl FFI_ArrowArray { buffers, buffers_ptr, children, + dictionary, }); Self { @@ -410,7 +490,7 @@ impl FFI_ArrowArray { n_children, buffers: private_data.buffers_ptr.as_mut_ptr(), children: private_data.children.as_mut_ptr(), - dictionary: std::ptr::null_mut(), + dictionary, release: Some(release_array), private_data: Box::into_raw(private_data) as *mut c_void, } @@ -474,7 +554,8 @@ unsafe fn create_buffer( assert!(index < array.n_buffers as usize); let ptr = *buffers.add(index); - NonNull::new(ptr as *mut u8).map(|ptr| Buffer::from_unowned(ptr, len, owner)) + NonNull::new(ptr as *mut u8) + .map(|ptr| Buffer::from_custom_allocation(ptr, len, owner)) } fn create_child( @@ -506,7 +587,7 @@ pub trait ArrowArrayRef { let buffers = self.buffers()?; let null_bit_buffer = self.null_bit_buffer(); - let child_data = (0..self.array().n_children as usize) + let mut child_data: Vec = (0..self.array().n_children as usize) .map(|i| { let child = self.child(i); child.to_data() @@ -514,6 +595,13 @@ pub trait ArrowArrayRef { .map(|d| d.unwrap()) .collect(); + if let Some(d) = self.dictionary() { + // For dictionary type there should only be a single child, so we don't need to worry if + // there are other children added above. + assert!(child_data.is_empty()); + child_data.push(d.to_data()?); + } + // Should FFI be checking validity? Ok(unsafe { ArrayData::new_unchecked( @@ -553,10 +641,15 @@ pub trait ArrowArrayRef { // for variable-sized buffers, such as the second buffer of a stringArray, we need // to fetch offset buffer's len to build the second buffer. fn buffer_len(&self, i: usize) -> Result { - // Inner type is not important for buffer length. - let data_type = &self.data_type()?; + // Special handling for dictionary type as we only care about the key type in the case. + let t = self.data_type()?; + let data_type = match &t { + DataType::Dictionary(key_data_type, _) => key_data_type.as_ref(), + dt => dt, + }; - Ok(match (data_type, i) { + // Inner type is not important for buffer length. + Ok(match (&data_type, i) { (DataType::Utf8, 1) | (DataType::LargeUtf8, 1) | (DataType::Binary, 1) @@ -620,8 +713,24 @@ pub trait ArrowArrayRef { fn array(&self) -> &FFI_ArrowArray; fn schema(&self) -> &FFI_ArrowSchema; fn data_type(&self) -> Result; + fn dictionary(&self) -> Option { + unsafe { + assert!(!(self.array().dictionary.is_null() ^ self.schema().dictionary.is_null()), + "Dictionary should both be set or not set in FFI_ArrowArray and FFI_ArrowSchema"); + if !self.array().dictionary.is_null() { + Some(ArrowArrayChild::from_raw( + &*self.array().dictionary, + &*self.schema().dictionary, + self.owner().clone(), + )) + } else { + None + } + } + } } +#[allow(rustdoc::private_intra_doc_links)] /// Struct used to move an Array from and to the C Data Interface. /// Its main responsibility is to expose functionality that requires /// both [FFI_ArrowArray] and [FFI_ArrowSchema]. 
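With the dictionary plumbing above, a `Dictionary` type exports the key (index) type as the format string and the value type through the schema's `dictionary` field. A hedged sketch, assuming the public `TryFrom` impl plus the `format()` and new `dictionary()` accessors:

    use std::convert::TryFrom;
    use arrow::datatypes::DataType;
    use arrow::ffi::FFI_ArrowSchema;

    let dt = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
    let c_schema = FFI_ArrowSchema::try_from(&dt).unwrap();

    // The format string describes the key type ...
    assert_eq!(c_schema.format(), "i");
    // ... and the value type is exported through the dictionary field.
    assert_eq!(c_schema.dictionary().unwrap().format(), "u");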
@@ -643,8 +752,8 @@ pub trait ArrowArrayRef { /// Furthermore, this struct assumes that the incoming data agrees with the C data interface. #[derive(Debug)] pub struct ArrowArray { - array: Arc, - schema: Arc, + pub(crate) array: Arc, + pub(crate) schema: Arc, } #[derive(Debug)] @@ -706,6 +815,9 @@ impl ArrowArray { /// creates a new [ArrowArray] from two pointers. Used to import from the C Data Interface. /// # Safety /// See safety of [ArrowArray] + /// Note that this function will copy the content pointed by the raw pointers. Considering + /// the raw pointers can be from `Arc::into_raw` or other raw pointers, users must be responsible + /// on managing the allocation of the structs by themselves. /// # Error /// Errors if any of the pointers is null pub unsafe fn try_from_raw( @@ -718,9 +830,16 @@ impl ArrowArray { .to_string(), )); }; + + let array_mut = array as *mut FFI_ArrowArray; + let schema_mut = schema as *mut FFI_ArrowSchema; + + let array_data = std::ptr::replace(array_mut, FFI_ArrowArray::empty()); + let schema_data = std::ptr::replace(schema_mut, FFI_ArrowSchema::empty()); + Ok(Self { - array: Arc::from_raw(array as *mut FFI_ArrowArray), - schema: Arc::from_raw(schema as *mut FFI_ArrowSchema), + array: Arc::new(array_data), + schema: Arc::new(schema_data), }) } @@ -757,13 +876,13 @@ impl<'a> ArrowArrayChild<'a> { mod tests { use super::*; use crate::array::{ - make_array, Array, ArrayData, BinaryOffsetSizeTrait, BooleanArray, DecimalArray, - DecimalBuilder, GenericBinaryArray, GenericListArray, GenericStringArray, - Int32Array, OffsetSizeTrait, StringOffsetSizeTrait, Time32MillisecondArray, - TimestampMillisecondArray, + export_array_into_raw, make_array, Array, ArrayData, BooleanArray, DecimalArray, + DictionaryArray, DurationSecondArray, FixedSizeBinaryArray, FixedSizeListArray, + GenericBinaryArray, GenericListArray, GenericStringArray, Int32Array, + OffsetSizeTrait, Time32MillisecondArray, TimestampMillisecondArray, }; use crate::compute::kernels; - use crate::datatypes::Field; + use crate::datatypes::{Field, Int8Type}; use std::convert::TryFrom; #[test] @@ -790,13 +909,14 @@ mod tests { } #[test] + #[cfg(not(feature = "force_validate"))] fn test_decimal_round_trip() -> Result<()> { // create an array natively - let mut builder = DecimalBuilder::new(5, 6, 2); - builder.append_value(12345_i128).unwrap(); - builder.append_value(-12345_i128).unwrap(); - builder.append_null().unwrap(); - let original_array = builder.finish(); + let original_array = [Some(12345_i128), Some(-12345_i128), None] + .into_iter() + .collect::() + .with_precision_and_scale(6, 2) + .unwrap(); // export it let array = ArrowArray::try_from(original_array.data().clone())?; @@ -816,7 +936,7 @@ mod tests { } // case with nulls is tested in the docs, through the example on this module. 
- fn test_generic_string() -> Result<()> { + fn test_generic_string() -> Result<()> { // create an array natively let array = GenericStringArray::::from(vec![Some("a"), None, Some("aaa")]); @@ -928,7 +1048,7 @@ mod tests { test_generic_list::() } - fn test_generic_binary() -> Result<()> { + fn test_generic_binary() -> Result<()> { // create an array natively let array: Vec> = vec![Some(b"a"), None, Some(b"aaa")]; let array = GenericBinaryArray::::from(array); @@ -1070,4 +1190,213 @@ mod tests { // (drop/release) Ok(()) } + + #[test] + fn test_fixed_size_binary_array() -> Result<()> { + let values = vec![ + None, + Some(vec![10, 10, 10]), + None, + Some(vec![20, 20, 20]), + Some(vec![30, 30, 30]), + None, + ]; + let array = FixedSizeBinaryArray::try_from_sparse_iter(values.into_iter())?; + + // export it + let array = ArrowArray::try_from(array.data().clone())?; + + // (simulate consumer) import it + let data = ArrayData::try_from(array)?; + let array = make_array(data); + + // perform some operation + let array = kernels::concat::concat(&[array.as_ref(), array.as_ref()]).unwrap(); + let array = array + .as_any() + .downcast_ref::() + .unwrap(); + + // verify + assert_eq!( + array, + &FixedSizeBinaryArray::try_from_sparse_iter( + vec![ + None, + Some(vec![10, 10, 10]), + None, + Some(vec![20, 20, 20]), + Some(vec![30, 30, 30]), + None, + None, + Some(vec![10, 10, 10]), + None, + Some(vec![20, 20, 20]), + Some(vec![30, 30, 30]), + None, + ] + .into_iter() + )? + ); + + // (drop/release) + Ok(()) + } + + #[test] + fn test_fixed_size_list_array() -> Result<()> { + // 0000 0100 + let mut validity_bits: [u8; 1] = [0; 1]; + bit_util::set_bit(&mut validity_bits, 2); + + let v: Vec = (0..9).into_iter().collect(); + let value_data = ArrayData::builder(DataType::Int32) + .len(9) + .add_buffer(Buffer::from_slice_ref(&v)) + .build()?; + + let list_data_type = + DataType::FixedSizeList(Box::new(Field::new("f", DataType::Int32, false)), 3); + let list_data = ArrayData::builder(list_data_type.clone()) + .len(3) + .null_bit_buffer(Some(Buffer::from(validity_bits))) + .add_child_data(value_data) + .build()?; + + // export it + let array = ArrowArray::try_from(list_data)?; + + // (simulate consumer) import it + let data = ArrayData::try_from(array)?; + let array = make_array(data); + + // perform some operation + let array = kernels::concat::concat(&[array.as_ref(), array.as_ref()]).unwrap(); + let array = array.as_any().downcast_ref::().unwrap(); + + // 0010 0100 + let mut expected_validity_bits: [u8; 1] = [0; 1]; + bit_util::set_bit(&mut expected_validity_bits, 2); + bit_util::set_bit(&mut expected_validity_bits, 5); + + let mut w = vec![]; + w.extend_from_slice(&v); + w.extend_from_slice(&v); + + let expected_value_data = ArrayData::builder(DataType::Int32) + .len(18) + .add_buffer(Buffer::from_slice_ref(&w)) + .build()?; + + let expected_list_data = ArrayData::builder(list_data_type) + .len(6) + .null_bit_buffer(Some(Buffer::from(expected_validity_bits))) + .add_child_data(expected_value_data) + .build()?; + let expected_array = FixedSizeListArray::from(expected_list_data); + + // verify + assert_eq!(array, &expected_array); + + // (drop/release) + Ok(()) + } + + #[test] + fn test_dictionary() -> Result<()> { + // create an array natively + let values = vec!["a", "aaa", "aaa"]; + let dict_array: DictionaryArray = values.into_iter().collect(); + + // export it + let array = ArrowArray::try_from(dict_array.data().clone())?; + + // (simulate consumer) import it + let data = ArrayData::try_from(array)?; + 
let array = make_array(data); + + // perform some operation + let array = kernels::concat::concat(&[array.as_ref(), array.as_ref()]).unwrap(); + let actual = array + .as_any() + .downcast_ref::>() + .unwrap(); + + // verify + let new_values = vec!["a", "aaa", "aaa", "a", "aaa", "aaa"]; + let expected: DictionaryArray = new_values.into_iter().collect(); + assert_eq!(actual, &expected); + + // (drop/release) + Ok(()) + } + + #[test] + fn test_export_array_into_raw() -> Result<()> { + let array = make_array(Int32Array::from(vec![1, 2, 3]).data().clone()); + + // Assume two raw pointers provided by the consumer + let out_array = Box::new(FFI_ArrowArray::empty()); + let out_schema = Box::new(FFI_ArrowSchema::empty()); + let out_array_ptr = Box::into_raw(out_array); + let out_schema_ptr = Box::into_raw(out_schema); + + unsafe { + export_array_into_raw(array, out_array_ptr, out_schema_ptr)?; + } + + // (simulate consumer) import it + unsafe { + let array = ArrowArray::try_from_raw(out_array_ptr, out_schema_ptr).unwrap(); + let data = ArrayData::try_from(array)?; + let array = make_array(data); + + // perform some operation + let array = array.as_any().downcast_ref::().unwrap(); + let array = kernels::arithmetic::add(array, array).unwrap(); + + // verify + assert_eq!(array, Int32Array::from(vec![2, 4, 6])); + + Box::from_raw(out_array_ptr); + Box::from_raw(out_schema_ptr); + } + Ok(()) + } + + #[test] + fn test_duration() -> Result<()> { + // create an array natively + let array = DurationSecondArray::from(vec![None, Some(1), Some(2)]); + + // export it + let array = ArrowArray::try_from(array.data().clone())?; + + // (simulate consumer) import it + let data = ArrayData::try_from(array)?; + let array = make_array(data); + + // perform some operation + let array = kernels::concat::concat(&[array.as_ref(), array.as_ref()]).unwrap(); + let array = array + .as_any() + .downcast_ref::() + .unwrap(); + + // verify + assert_eq!( + array, + &DurationSecondArray::from(vec![ + None, + Some(1), + Some(2), + None, + Some(1), + Some(2) + ]) + ); + + // (drop/release) + Ok(()) + } } diff --git a/arrow/src/ffi_stream.rs b/arrow/src/ffi_stream.rs new file mode 100644 index 000000000000..3a85f2ef6421 --- /dev/null +++ b/arrow/src/ffi_stream.rs @@ -0,0 +1,553 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Contains declarations to bind to the [C Stream Interface](https://arrow.apache.org/docs/format/CStreamInterface.html). +//! +//! This module has two main interfaces: +//! One interface maps C ABI to native Rust types, i.e. convert c-pointers, c_char, to native rust. +//! This is handled by [FFI_ArrowArrayStream]. +//! +//! The second interface is used to import `FFI_ArrowArrayStream` as Rust implementation `RecordBatch` reader. +//! 
This is handled by `ArrowArrayStreamReader`. +//! +//! ```ignore +//! # use std::fs::File; +//! # use std::sync::Arc; +//! # use arrow::error::Result; +//! # use arrow::ffi_stream::{export_reader_into_raw, ArrowArrayStreamReader, FFI_ArrowArrayStream}; +//! # use arrow::ipc::reader::FileReader; +//! # use arrow::record_batch::RecordBatchReader; +//! # fn main() -> Result<()> { +//! // create an record batch reader natively +//! let file = File::open("arrow_file").unwrap(); +//! let reader = Box::new(FileReader::try_new(file).unwrap()); +//! +//! // export it +//! let stream = Box::new(FFI_ArrowArrayStream::empty()); +//! let stream_ptr = Box::into_raw(stream) as *mut FFI_ArrowArrayStream; +//! unsafe { export_reader_into_raw(reader, stream_ptr) }; +//! +//! // consumed and used by something else... +//! +//! // import it +//! let stream_reader = unsafe { ArrowArrayStreamReader::from_raw(stream_ptr).unwrap() }; +//! let imported_schema = stream_reader.schema(); +//! +//! let mut produced_batches = vec![]; +//! for batch in stream_reader { +//! produced_batches.push(batch.unwrap()); +//! } +//! +//! // (drop/release) +//! unsafe { +//! Box::from_raw(stream_ptr); +//! } +//! Ok(()) +//! } +//! ``` + +use std::{ + convert::TryFrom, + ffi::CString, + os::raw::{c_char, c_int, c_void}, + sync::Arc, +}; + +use crate::array::Array; +use crate::array::StructArray; +use crate::datatypes::{Schema, SchemaRef}; +use crate::error::ArrowError; +use crate::error::Result; +use crate::ffi::*; +use crate::record_batch::{RecordBatch, RecordBatchReader}; + +const ENOMEM: i32 = 12; +const EIO: i32 = 5; +const EINVAL: i32 = 22; +const ENOSYS: i32 = 78; + +/// ABI-compatible struct for `ArrayStream` from C Stream Interface +/// See +/// This was created by bindgen +#[repr(C)] +#[derive(Debug)] +pub struct FFI_ArrowArrayStream { + pub get_schema: Option< + unsafe extern "C" fn( + arg1: *mut FFI_ArrowArrayStream, + out: *mut FFI_ArrowSchema, + ) -> c_int, + >, + pub get_next: Option< + unsafe extern "C" fn( + arg1: *mut FFI_ArrowArrayStream, + out: *mut FFI_ArrowArray, + ) -> c_int, + >, + pub get_last_error: + Option *const c_char>, + pub release: Option, + pub private_data: *mut c_void, +} + +// callback used to drop [FFI_ArrowArrayStream] when it is exported. 
+unsafe extern "C" fn release_stream(stream: *mut FFI_ArrowArrayStream) { + if stream.is_null() { + return; + } + let stream = &mut *stream; + + stream.get_schema = None; + stream.get_next = None; + stream.get_last_error = None; + + let private_data = Box::from_raw(stream.private_data as *mut StreamPrivateData); + drop(private_data); + + stream.release = None; +} + +struct StreamPrivateData { + batch_reader: Box, + last_error: String, +} + +// The callback used to get array schema +unsafe extern "C" fn get_schema( + stream: *mut FFI_ArrowArrayStream, + schema: *mut FFI_ArrowSchema, +) -> c_int { + ExportedArrayStream { stream }.get_schema(schema) +} + +// The callback used to get next array +unsafe extern "C" fn get_next( + stream: *mut FFI_ArrowArrayStream, + array: *mut FFI_ArrowArray, +) -> c_int { + ExportedArrayStream { stream }.get_next(array) +} + +// The callback used to get the error from last operation on the `FFI_ArrowArrayStream` +unsafe extern "C" fn get_last_error(stream: *mut FFI_ArrowArrayStream) -> *const c_char { + let mut ffi_stream = ExportedArrayStream { stream }; + let last_error = ffi_stream.get_last_error(); + CString::new(last_error.as_str()).unwrap().into_raw() +} + +impl Drop for FFI_ArrowArrayStream { + fn drop(&mut self) { + match self.release { + None => (), + Some(release) => unsafe { release(self) }, + }; + } +} + +impl FFI_ArrowArrayStream { + /// Creates a new [`FFI_ArrowArrayStream`]. + pub fn new(batch_reader: Box) -> Self { + let private_data = Box::new(StreamPrivateData { + batch_reader, + last_error: String::new(), + }); + + Self { + get_schema: Some(get_schema), + get_next: Some(get_next), + get_last_error: Some(get_last_error), + release: Some(release_stream), + private_data: Box::into_raw(private_data) as *mut c_void, + } + } + + /// Creates a new empty [FFI_ArrowArrayStream]. Used to import from the C Stream Interface. + pub fn empty() -> Self { + Self { + get_schema: None, + get_next: None, + get_last_error: None, + release: None, + private_data: std::ptr::null_mut(), + } + } +} + +struct ExportedArrayStream { + stream: *mut FFI_ArrowArrayStream, +} + +impl ExportedArrayStream { + fn get_private_data(&mut self) -> &mut StreamPrivateData { + unsafe { &mut *((*self.stream).private_data as *mut StreamPrivateData) } + } + + pub fn get_schema(&mut self, out: *mut FFI_ArrowSchema) -> i32 { + let mut private_data = self.get_private_data(); + let reader = &private_data.batch_reader; + + let schema = FFI_ArrowSchema::try_from(reader.schema().as_ref()); + + match schema { + Ok(mut schema) => unsafe { + std::ptr::copy(&schema as *const FFI_ArrowSchema, out, 1); + schema.release = None; + 0 + }, + Err(ref err) => { + private_data.last_error = err.to_string(); + get_error_code(err) + } + } + } + + pub fn get_next(&mut self, out: *mut FFI_ArrowArray) -> i32 { + let mut private_data = self.get_private_data(); + let reader = &mut private_data.batch_reader; + + let ret_code = match reader.next() { + None => { + // Marks ArrowArray released to indicate reaching the end of stream. 
+ unsafe { + (*out).release = None; + } + 0 + } + Some(next_batch) => { + if let Ok(batch) = next_batch { + let struct_array = StructArray::from(batch); + let mut array = FFI_ArrowArray::new(struct_array.data()); + + unsafe { + std::ptr::copy(&array as *const FFI_ArrowArray, out, 1); + array.release = None; + 0 + } + } else { + let err = &next_batch.unwrap_err(); + private_data.last_error = err.to_string(); + get_error_code(err) + } + } + }; + + ret_code + } + + pub fn get_last_error(&mut self) -> &String { + &self.get_private_data().last_error + } +} + +fn get_error_code(err: &ArrowError) -> i32 { + match err { + ArrowError::NotYetImplemented(_) => ENOSYS, + ArrowError::MemoryError(_) => ENOMEM, + ArrowError::IoError(_) => EIO, + _ => EINVAL, + } +} + +/// A `RecordBatchReader` which imports Arrays from `FFI_ArrowArrayStream`. +/// Struct used to fetch `RecordBatch` from the C Stream Interface. +/// Its main responsibility is to expose `RecordBatchReader` functionality +/// that requires [FFI_ArrowArrayStream]. +#[derive(Debug, Clone)] +pub struct ArrowArrayStreamReader { + stream: Arc, + schema: SchemaRef, +} + +/// Gets schema from a raw pointer of `FFI_ArrowArrayStream`. This is used when constructing +/// `ArrowArrayStreamReader` to cache schema. +fn get_stream_schema(stream_ptr: *mut FFI_ArrowArrayStream) -> Result { + let empty_schema = Arc::new(FFI_ArrowSchema::empty()); + let schema_ptr = Arc::into_raw(empty_schema) as *mut FFI_ArrowSchema; + + let ret_code = unsafe { (*stream_ptr).get_schema.unwrap()(stream_ptr, schema_ptr) }; + + let ffi_schema = unsafe { Arc::from_raw(schema_ptr) }; + + if ret_code == 0 { + let schema = Schema::try_from(ffi_schema.as_ref()).unwrap(); + Ok(Arc::new(schema)) + } else { + Err(ArrowError::CDataInterface(format!( + "Cannot get schema from input stream. Error code: {:?}", + ret_code + ))) + } +} + +impl ArrowArrayStreamReader { + /// Creates a new `ArrowArrayStreamReader` from a `FFI_ArrowArrayStream`. + /// This is used to import from the C Stream Interface. + #[allow(dead_code)] + pub fn try_new(stream: FFI_ArrowArrayStream) -> Result { + if stream.release.is_none() { + return Err(ArrowError::CDataInterface( + "input stream is already released".to_string(), + )); + } + + let stream_ptr = Arc::into_raw(Arc::new(stream)) as *mut FFI_ArrowArrayStream; + + let schema = get_stream_schema(stream_ptr)?; + + Ok(Self { + stream: unsafe { Arc::from_raw(stream_ptr) }, + schema, + }) + } + + /// Creates a new `ArrowArrayStreamReader` from a raw pointer of `FFI_ArrowArrayStream`. + /// + /// Assumes that the pointer represents valid C Stream Interfaces. + /// This function copies the content from the raw pointer and cleans up it to prevent + /// double-dropping. The caller is responsible for freeing up the memory allocated for + /// the pointer. + /// + /// # Safety + /// This function dereferences a raw pointer of `FFI_ArrowArrayStream`. 
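Besides the raw-pointer route shown in the module docs, the `try_new` constructor above takes the stream struct by value. A small sketch of that path, assuming some `reader: Box<dyn RecordBatchReader>` produced elsewhere:

```rust
use arrow::error::Result;
use arrow::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream};
use arrow::record_batch::RecordBatchReader;

fn import_stream(reader: Box<dyn RecordBatchReader>) -> Result<()> {
    // Producer side: wrap the reader into the C stream struct.
    let stream = FFI_ArrowArrayStream::new(reader);

    // Consumer side: `try_new` takes the struct by value and errors if the
    // stream was already released.
    let stream_reader = ArrowArrayStreamReader::try_new(stream)?;
    let schema = stream_reader.schema();

    for batch in stream_reader {
        assert_eq!(batch?.schema(), schema);
    }
    Ok(())
}
```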
+ pub unsafe fn from_raw(raw_stream: *mut FFI_ArrowArrayStream) -> Result { + let stream_data = std::ptr::replace(raw_stream, FFI_ArrowArrayStream::empty()); + + Self::try_new(stream_data) + } + + /// Get the last error from `ArrowArrayStreamReader` + fn get_stream_last_error(&self) -> Option { + self.stream.get_last_error?; + + let stream_ptr = Arc::as_ptr(&self.stream) as *mut FFI_ArrowArrayStream; + + let error_str = unsafe { + let c_str = self.stream.get_last_error.unwrap()(stream_ptr) as *mut c_char; + CString::from_raw(c_str).into_string() + }; + + if let Err(err) = error_str { + Some(err.to_string()) + } else { + Some(error_str.unwrap()) + } + } +} + +impl Iterator for ArrowArrayStreamReader { + type Item = Result; + + fn next(&mut self) -> Option { + let stream_ptr = Arc::as_ptr(&self.stream) as *mut FFI_ArrowArrayStream; + + let empty_array = Arc::new(FFI_ArrowArray::empty()); + let array_ptr = Arc::into_raw(empty_array) as *mut FFI_ArrowArray; + + let ret_code = unsafe { self.stream.get_next.unwrap()(stream_ptr, array_ptr) }; + + if ret_code == 0 { + let ffi_array = unsafe { Arc::from_raw(array_ptr) }; + + // The end of stream has been reached + ffi_array.release?; + + let schema_ref = self.schema(); + let schema = FFI_ArrowSchema::try_from(schema_ref.as_ref()).ok()?; + + let data = ArrowArray { + array: ffi_array, + schema: Arc::new(schema), + } + .to_data() + .ok()?; + + let record_batch = RecordBatch::from(&StructArray::from(data)); + + Some(Ok(record_batch)) + } else { + unsafe { Arc::from_raw(array_ptr) }; + + let last_error = self.get_stream_last_error(); + let err = ArrowError::CDataInterface(last_error.unwrap()); + Some(Err(err)) + } + } +} + +impl RecordBatchReader for ArrowArrayStreamReader { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +/// Exports a record batch reader to raw pointer of the C Stream Interface provided by the consumer. +/// +/// # Safety +/// Assumes that the pointer represents valid C Stream Interfaces, both in memory +/// representation and lifetime via the `release` mechanism. 
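`export_reader_into_raw` accepts any boxed `RecordBatchReader`; since no generic adapter ships with this module, the tests below define their own. A hypothetical equivalent (analogous to the `TestRecordBatchReader` in those tests) that turns a vector of batches into a reader might look like this:

```rust
use arrow::datatypes::SchemaRef;
use arrow::error::Result;
use arrow::record_batch::{RecordBatch, RecordBatchReader};

/// Hypothetical adapter exposing a Vec of batches as a `RecordBatchReader`.
struct VecReader {
    schema: SchemaRef,
    batches: std::vec::IntoIter<RecordBatch>,
}

impl Iterator for VecReader {
    type Item = Result<RecordBatch>;

    fn next(&mut self) -> Option<Self::Item> {
        // Each batch is already materialized, so reading cannot fail here.
        self.batches.next().map(Ok)
    }
}

impl RecordBatchReader for VecReader {
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }
}
```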
+pub unsafe fn export_reader_into_raw( + reader: Box, + out_stream: *mut FFI_ArrowArrayStream, +) { + let stream = FFI_ArrowArrayStream::new(reader); + + std::ptr::write_unaligned(out_stream, stream); +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::array::Int32Array; + use crate::datatypes::{Field, Schema}; + + struct TestRecordBatchReader { + schema: SchemaRef, + iter: Box>>, + } + + impl TestRecordBatchReader { + pub fn new( + schema: SchemaRef, + iter: Box>>, + ) -> Box { + Box::new(TestRecordBatchReader { schema, iter }) + } + } + + impl Iterator for TestRecordBatchReader { + type Item = Result; + + fn next(&mut self) -> Option { + self.iter.next() + } + } + + impl RecordBatchReader for TestRecordBatchReader { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + } + + fn _test_round_trip_export(arrays: Vec>) -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", arrays[0].data_type().clone(), true), + Field::new("b", arrays[1].data_type().clone(), true), + Field::new("c", arrays[2].data_type().clone(), true), + ])); + let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap(); + let iter = Box::new(vec![batch.clone(), batch.clone()].into_iter().map(Ok)) as _; + + let reader = TestRecordBatchReader::new(schema.clone(), iter); + + // Export a `RecordBatchReader` through `FFI_ArrowArrayStream` + let stream = Arc::new(FFI_ArrowArrayStream::empty()); + let stream_ptr = Arc::into_raw(stream) as *mut FFI_ArrowArrayStream; + + unsafe { export_reader_into_raw(reader, stream_ptr) }; + + let empty_schema = Arc::new(FFI_ArrowSchema::empty()); + let schema_ptr = Arc::into_raw(empty_schema) as *mut FFI_ArrowSchema; + + // Get schema from `FFI_ArrowArrayStream` + let ret_code = unsafe { get_schema(stream_ptr, schema_ptr) }; + assert_eq!(ret_code, 0); + + let ffi_schema = unsafe { Arc::from_raw(schema_ptr) }; + + let exported_schema = Schema::try_from(ffi_schema.as_ref()).unwrap(); + assert_eq!(&exported_schema, schema.as_ref()); + + // Get array from `FFI_ArrowArrayStream` + let mut produced_batches = vec![]; + loop { + let empty_array = Arc::new(FFI_ArrowArray::empty()); + let array_ptr = Arc::into_raw(empty_array.clone()) as *mut FFI_ArrowArray; + + let ret_code = unsafe { get_next(stream_ptr, array_ptr) }; + assert_eq!(ret_code, 0); + + // The end of stream has been reached + let ffi_array = unsafe { Arc::from_raw(array_ptr) }; + if ffi_array.release.is_none() { + break; + } + + let array = ArrowArray { + array: ffi_array, + schema: ffi_schema.clone(), + } + .to_data() + .unwrap(); + + let record_batch = RecordBatch::from(&StructArray::from(array)); + produced_batches.push(record_batch); + } + + assert_eq!(produced_batches, vec![batch.clone(), batch]); + + unsafe { Arc::from_raw(stream_ptr) }; + Ok(()) + } + + fn _test_round_trip_import(arrays: Vec>) -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", arrays[0].data_type().clone(), true), + Field::new("b", arrays[1].data_type().clone(), true), + Field::new("c", arrays[2].data_type().clone(), true), + ])); + let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap(); + let iter = Box::new(vec![batch.clone(), batch.clone()].into_iter().map(Ok)) as _; + + let reader = TestRecordBatchReader::new(schema.clone(), iter); + + // Import through `FFI_ArrowArrayStream` as `ArrowArrayStreamReader` + let stream = Arc::new(FFI_ArrowArrayStream::new(reader)); + let stream_ptr = Arc::into_raw(stream) as *mut FFI_ArrowArrayStream; + let stream_reader = + unsafe { 
ArrowArrayStreamReader::from_raw(stream_ptr).unwrap() }; + + let imported_schema = stream_reader.schema(); + assert_eq!(imported_schema, schema); + + let mut produced_batches = vec![]; + for batch in stream_reader { + produced_batches.push(batch.unwrap()); + } + + assert_eq!(produced_batches, vec![batch.clone(), batch]); + + unsafe { Arc::from_raw(stream_ptr) }; + Ok(()) + } + + #[test] + fn test_stream_round_trip_export() -> Result<()> { + let array = Int32Array::from(vec![Some(2), None, Some(1), None]); + let array: Arc = Arc::new(array); + + _test_round_trip_export(vec![array.clone(), array.clone(), array]) + } + + #[test] + fn test_stream_round_trip_import() -> Result<()> { + let array = Int32Array::from(vec![Some(2), None, Some(1), None]); + let array: Arc = Arc::new(array); + + _test_round_trip_import(vec![array.clone(), array.clone(), array]) + } +} diff --git a/arrow/src/ipc/convert.rs b/arrow/src/ipc/convert.rs index dcc9fcc84a0f..c81ea8278c4f 100644 --- a/arrow/src/ipc/convert.rs +++ b/arrow/src/ipc/convert.rs @@ -17,7 +17,7 @@ //! Utilities for converting between IPC types and native Arrow types -use crate::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}; +use crate::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit, UnionMode}; use crate::error::{ArrowError, Result}; use crate::ipc; @@ -72,7 +72,7 @@ pub fn schema_to_fb_offset<'a>( /// Convert an IPC Field to Arrow Field impl<'a> From> for Field { fn from(field: ipc::Field) -> Field { - let mut arrow_field = if let Some(dictionary) = field.dictionary() { + let arrow_field = if let Some(dictionary) = field.dictionary() { Field::new_dict( field.name().unwrap(), get_data_type(field, true), @@ -99,8 +99,7 @@ impl<'a> From> for Field { metadata = Some(metadata_map); } - arrow_field.set_metadata(metadata); - arrow_field + arrow_field.with_metadata(metadata) } } @@ -263,6 +262,9 @@ pub(crate) fn get_data_type(field: ipc::Field, may_be_dictionary: bool) -> DataT DataType::Interval(IntervalUnit::YearMonth) } ipc::IntervalUnit::DAY_TIME => DataType::Interval(IntervalUnit::DayTime), + ipc::IntervalUnit::MONTH_DAY_NANO => { + DataType::Interval(IntervalUnit::MonthDayNano) + } z => panic!("Interval type with unit of {:?} unsupported", z), } } @@ -320,6 +322,29 @@ pub(crate) fn get_data_type(field: ipc::Field, may_be_dictionary: bool) -> DataT let fsb = field.type_as_decimal().unwrap(); DataType::Decimal(fsb.precision() as usize, fsb.scale() as usize) } + ipc::Type::Union => { + let union = field.type_as_union().unwrap(); + + let union_mode = match union.mode() { + ipc::UnionMode::Dense => UnionMode::Dense, + ipc::UnionMode::Sparse => UnionMode::Sparse, + mode => panic!("Unexpected union mode: {:?}", mode), + }; + + let mut fields = vec![]; + if let Some(children) = field.children() { + for i in 0..children.len() { + fields.push(children.get(i).into()); + } + }; + + let type_ids: Vec = match union.typeIds() { + None => (0_i8..fields.len() as i8).collect(), + Some(ids) => ids.iter().map(|i| i as i8).collect(), + }; + + DataType::Union(fields, type_ids, union_mode) + } t => unimplemented!("Type {:?} not supported", t), } } @@ -533,7 +558,7 @@ pub(crate) fn get_fb_field_type<'a>( } } Timestamp(unit, tz) => { - let tz = tz.clone().unwrap_or_else(String::new); + let tz = tz.clone().unwrap_or_default(); let tz_str = fbb.create_string(tz.as_str()); let mut builder = ipc::TimestampBuilder::new(fbb); let time_unit = match unit { @@ -557,6 +582,7 @@ pub(crate) fn get_fb_field_type<'a>( let interval_unit = match unit { 
IntervalUnit::YearMonth => ipc::IntervalUnit::YEAR_MONTH, IntervalUnit::DayTime => ipc::IntervalUnit::DAY_TIME, + IntervalUnit::MonthDayNano => ipc::IntervalUnit::MONTH_DAY_NANO, }; builder.add_unit(interval_unit); FBFieldType { @@ -645,7 +671,29 @@ pub(crate) fn get_fb_field_type<'a>( children: Some(fbb.create_vector(&empty_fields[..])), } } - t => unimplemented!("Type {:?} not supported", t), + Union(fields, type_ids, mode) => { + let mut children = vec![]; + for field in fields { + children.push(build_field(fbb, field)); + } + + let union_mode = match mode { + UnionMode::Sparse => ipc::UnionMode::Sparse, + UnionMode::Dense => ipc::UnionMode::Dense, + }; + + let fbb_type_ids = fbb + .create_vector(&type_ids.iter().map(|t| *t as i32).collect::>()); + let mut builder = ipc::UnionBuilder::new(fbb); + builder.add_mode(union_mode); + builder.add_typeIds(fbb_type_ids); + + FBFieldType { + type_type: ipc::Type::Union, + type_: builder.finish().as_union_value(), + children: Some(fbb.create_vector(&children[..])), + } + } } } @@ -687,7 +735,7 @@ pub(crate) fn get_fb_dictionary<'a>( #[cfg(test)] mod tests { use super::*; - use crate::datatypes::{DataType, Field, Schema}; + use crate::datatypes::{DataType, Field, Schema, UnionMode}; #[test] fn convert_schema_round_trip() { @@ -701,11 +749,7 @@ mod tests { .collect(); let schema = Schema::new_with_metadata( vec![ - { - let mut f = Field::new("uint8", DataType::UInt8, false); - f.set_metadata(Some(field_md)); - f - }, + Field::new("uint8", DataType::UInt8, false).with_metadata(Some(field_md)), Field::new("uint16", DataType::UInt16, true), Field::new("uint32", DataType::UInt32, false), Field::new("uint64", DataType::UInt64, true), @@ -757,6 +801,11 @@ mod tests { DataType::Interval(IntervalUnit::DayTime), true, ), + Field::new( + "interval[mdn]", + DataType::Interval(IntervalUnit::MonthDayNano), + true, + ), Field::new("utf8", DataType::Utf8, false), Field::new("binary", DataType::Binary, false), Field::new( @@ -816,7 +865,68 @@ mod tests { ]), false, ), + Field::new( + "union]>]>", + DataType::Union( + vec![ + Field::new("int64", DataType::Int64, true), + Field::new( + "list[union]>]", + DataType::List(Box::new(Field::new( + "union]>", + DataType::Union( + vec![ + Field::new("date32", DataType::Date32, true), + Field::new( + "list[union<>]", + DataType::List(Box::new(Field::new( + "union", + DataType::Union( + vec![], + vec![], + UnionMode::Sparse, + ), + false, + ))), + false, + ), + ], + vec![0, 1], + UnionMode::Dense, + ), + false, + ))), + false, + ), + ], + vec![0, 1], + UnionMode::Sparse, + ), + false, + ), Field::new("struct<>", DataType::Struct(vec![]), true), + Field::new( + "union<>", + DataType::Union(vec![], vec![], UnionMode::Dense), + true, + ), + Field::new( + "union<>", + DataType::Union(vec![], vec![], UnionMode::Sparse), + true, + ), + Field::new( + "union", + DataType::Union( + vec![ + Field::new("int32", DataType::Int32, true), + Field::new("utf8", DataType::Utf8, true), + ], + vec![2, 3], // non-default type ids + UnionMode::Dense, + ), + true, + ), Field::new_dict( "dictionary", DataType::Dictionary( diff --git a/arrow/src/ipc/gen/Schema.rs b/arrow/src/ipc/gen/Schema.rs index 12af5b5b0806..dd204e0704df 100644 --- a/arrow/src/ipc/gen/Schema.rs +++ b/arrow/src/ipc/gen/Schema.rs @@ -639,8 +639,11 @@ pub const ENUM_MAX_INTERVAL_UNIT: i16 = 1; note = "Use associated constants instead. This will no longer be generated in 2021." 
)] #[allow(non_camel_case_types)] -pub const ENUM_VALUES_INTERVAL_UNIT: [IntervalUnit; 2] = - [IntervalUnit::YEAR_MONTH, IntervalUnit::DAY_TIME]; +pub const ENUM_VALUES_INTERVAL_UNIT: [IntervalUnit; 3] = [ + IntervalUnit::YEAR_MONTH, + IntervalUnit::DAY_TIME, + IntervalUnit::MONTH_DAY_NANO, +]; #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(transparent)] @@ -649,15 +652,18 @@ pub struct IntervalUnit(pub i16); impl IntervalUnit { pub const YEAR_MONTH: Self = Self(0); pub const DAY_TIME: Self = Self(1); + pub const MONTH_DAY_NANO: Self = Self(2); pub const ENUM_MIN: i16 = 0; - pub const ENUM_MAX: i16 = 1; - pub const ENUM_VALUES: &'static [Self] = &[Self::YEAR_MONTH, Self::DAY_TIME]; + pub const ENUM_MAX: i16 = 2; + pub const ENUM_VALUES: &'static [Self] = + &[Self::YEAR_MONTH, Self::DAY_TIME, Self::MONTH_DAY_NANO]; /// Returns the variant's name or "" if unknown. pub fn variant_name(self) -> Option<&'static str> { match self { Self::YEAR_MONTH => Some("YEAR_MONTH"), Self::DAY_TIME => Some("DAY_TIME"), + Self::MONTH_DAY_NANO => Some("MONTH_DAY_NANO"), _ => None, } } diff --git a/arrow/src/ipc/mod.rs b/arrow/src/ipc/mod.rs index a2d7103aacfc..d5455b454e7f 100644 --- a/arrow/src/ipc/mod.rs +++ b/arrow/src/ipc/mod.rs @@ -27,6 +27,7 @@ pub mod writer; #[allow(clippy::extra_unused_lifetimes)] #[allow(clippy::redundant_static_lifetimes)] #[allow(clippy::redundant_field_names)] +#[allow(non_camel_case_types)] pub mod gen; pub use self::gen::File::*; diff --git a/arrow/src/ipc/reader.rs b/arrow/src/ipc/reader.rs index e925e2af8f4b..1ac519382199 100644 --- a/arrow/src/ipc/reader.rs +++ b/arrow/src/ipc/reader.rs @@ -27,10 +27,10 @@ use std::sync::Arc; use crate::array::*; use crate::buffer::Buffer; use crate::compute::cast; -use crate::datatypes::{DataType, Field, IntervalUnit, Schema, SchemaRef}; +use crate::datatypes::{DataType, Field, IntervalUnit, Schema, SchemaRef, UnionMode}; use crate::error::{ArrowError, Result}; use crate::ipc; -use crate::record_batch::{RecordBatch, RecordBatchReader}; +use crate::record_batch::{RecordBatch, RecordBatchOptions, RecordBatchReader}; use ipc::CONTINUATION_MARKER; use DataType::*; @@ -52,16 +52,19 @@ fn read_buffer(buf: &ipc::Buffer, a_data: &[u8]) -> Buffer { /// - check if the bit width of non-64-bit numbers is 64, and /// - read the buffer as 64-bit (signed integer or float), and /// - cast the 64-bit array to the appropriate data type +#[allow(clippy::too_many_arguments)] fn create_array( nodes: &[ipc::FieldNode], - data_type: &DataType, + field: &Field, data: &[u8], buffers: &[ipc::Buffer], - dictionaries: &[Option], + dictionaries_by_id: &HashMap, mut node_index: usize, mut buffer_index: usize, -) -> (ArrayRef, usize, usize) { + metadata: &ipc::MetadataVersion, +) -> Result<(ArrayRef, usize, usize)> { use DataType::*; + let data_type = field.data_type(); let array = match data_type { Utf8 | Binary | LargeBinary | LargeUtf8 => { let array = create_primitive_array( @@ -99,13 +102,14 @@ fn create_array( buffer_index += 2; let triple = create_array( nodes, - list_field.data_type(), + list_field, data, buffers, - dictionaries, + dictionaries_by_id, node_index, buffer_index, - ); + metadata, + )?; node_index = triple.1; buffer_index = triple.2; @@ -121,13 +125,14 @@ fn create_array( buffer_index += 1; let triple = create_array( nodes, - list_field.data_type(), + list_field, data, buffers, - dictionaries, + dictionaries_by_id, node_index, buffer_index, - ); + metadata, + )?; node_index = triple.1; buffer_index = triple.2; @@ -146,13 +151,14 
@@ fn create_array( for struct_field in struct_fields { let triple = create_array( nodes, - struct_field.data_type(), + struct_field, data, buffers, - dictionaries, + dictionaries_by_id, node_index, buffer_index, - ); + metadata, + )?; node_index = triple.1; buffer_index = triple.2; struct_arrays.push((struct_field.clone(), triple.0)); @@ -173,7 +179,17 @@ fn create_array( .iter() .map(|buf| read_buffer(buf, data)) .collect(); - let value_array = dictionaries[node_index].clone().unwrap(); + + let dict_id = field.dict_id().ok_or_else(|| { + ArrowError::IoError(format!("Field {} does not have dict id", field)) + })?; + + let value_array = dictionaries_by_id.get(&dict_id).ok_or_else(|| { + ArrowError::IoError(format!( + "Cannot find a dictionary batch with dict id: {}", + dict_id + )) + })?; node_index += 1; buffer_index += 2; @@ -181,13 +197,73 @@ fn create_array( index_node, data_type, &index_buffers[..], - value_array, + value_array.clone(), ) } + Union(fields, field_type_ids, mode) => { + let union_node = nodes[node_index]; + node_index += 1; + + let len = union_node.length() as usize; + + // In V4, union types has validity bitmap + // In V5 and later, union types have no validity bitmap + if metadata < &ipc::MetadataVersion::V5 { + read_buffer(&buffers[buffer_index], data); + buffer_index += 1; + } + + let type_ids: Buffer = + read_buffer(&buffers[buffer_index], data)[..len].into(); + + buffer_index += 1; + + let value_offsets = match mode { + UnionMode::Dense => { + let buffer = read_buffer(&buffers[buffer_index], data); + buffer_index += 1; + Some(buffer[..len * 4].into()) + } + UnionMode::Sparse => None, + }; + + let mut children = vec![]; + + for field in fields { + let triple = create_array( + nodes, + field, + data, + buffers, + dictionaries_by_id, + node_index, + buffer_index, + metadata, + )?; + + node_index = triple.1; + buffer_index = triple.2; + + children.push((field.clone(), triple.0)); + } + + let array = + UnionArray::try_new(field_type_ids, type_ids, value_offsets, children)?; + Arc::new(array) + } Null => { - let length = nodes[node_index].length() as usize; + let length = nodes[node_index].length(); + let null_count = nodes[node_index].null_count(); + + if length != null_count { + return Err(ArrowError::IoError(format!( + "Field {} of NullArray has unequal null_count {} and len {}", + field, null_count, length + ))); + } + let data = ArrayData::builder(data_type.clone()) - .len(length) + .len(length as usize) .offset(0) .build() .unwrap(); @@ -209,7 +285,121 @@ fn create_array( array } }; - (array, node_index, buffer_index) + Ok((array, node_index, buffer_index)) +} + +/// Skip fields based on data types to advance `node_index` and `buffer_index`. +/// This function should be called when doing projection in fn `read_record_batch`. +/// The advancement logic references fn `create_array`. 
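As a worked illustration of the doc comment above (the helper and schema here are hypothetical, not part of this change), the cursor arithmetic for skipping flat columns works out as follows:

```rust
use arrow::datatypes::DataType;

// Hypothetical helper mirroring the per-field increments used by `skip_field`
// for a few flat types, to show why every skipped column must still advance
// the node and buffer cursors.
fn flat_field_increments(data_type: &DataType) -> (usize, usize) {
    match data_type {
        // validity + offsets + values
        DataType::Utf8 | DataType::Binary | DataType::LargeBinary | DataType::LargeUtf8 => (1, 3),
        // a NullArray carries no buffers
        DataType::Null => (1, 0),
        // other flat types: validity + values
        _ => (1, 2),
    }
}

fn main() {
    // Schema [a: Int32, b: Utf8, c: Float64], projecting only "c":
    let skipped = [DataType::Int32, DataType::Utf8];
    let (nodes, buffers) = skipped
        .iter()
        .map(flat_field_increments)
        .fold((0, 0), |(n, b), (dn, db)| (n + dn, b + db));

    // `create_array` for "c" must therefore start at node 2 and buffer 5.
    assert_eq!((nodes, buffers), (2, 5));
}
```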
+fn skip_field( + nodes: &[ipc::FieldNode], + field: &Field, + data: &[u8], + buffers: &[ipc::Buffer], + dictionaries_by_id: &HashMap, + mut node_index: usize, + mut buffer_index: usize, +) -> Result<(usize, usize)> { + use DataType::*; + let data_type = field.data_type(); + match data_type { + Utf8 | Binary | LargeBinary | LargeUtf8 => { + node_index += 1; + buffer_index += 3; + } + FixedSizeBinary(_) => { + node_index += 1; + buffer_index += 2; + } + List(ref list_field) | LargeList(ref list_field) | Map(ref list_field, _) => { + node_index += 1; + buffer_index += 2; + let tuple = skip_field( + nodes, + list_field, + data, + buffers, + dictionaries_by_id, + node_index, + buffer_index, + )?; + node_index = tuple.0; + buffer_index = tuple.1; + } + FixedSizeList(ref list_field, _) => { + node_index += 1; + buffer_index += 1; + let tuple = skip_field( + nodes, + list_field, + data, + buffers, + dictionaries_by_id, + node_index, + buffer_index, + )?; + node_index = tuple.0; + buffer_index = tuple.1; + } + Struct(struct_fields) => { + node_index += 1; + buffer_index += 1; + + // skip for each field + for struct_field in struct_fields { + let tuple = skip_field( + nodes, + struct_field, + data, + buffers, + dictionaries_by_id, + node_index, + buffer_index, + )?; + node_index = tuple.0; + buffer_index = tuple.1; + } + } + Dictionary(_, _) => { + node_index += 1; + buffer_index += 2; + } + Union(fields, _field_type_ids, mode) => { + node_index += 1; + buffer_index += 1; + + match mode { + UnionMode::Dense => { + buffer_index += 1; + } + UnionMode::Sparse => {} + }; + + for field in fields { + let tuple = skip_field( + nodes, + field, + data, + buffers, + dictionaries_by_id, + node_index, + buffer_index, + )?; + + node_index = tuple.0; + buffer_index = tuple.1; + } + } + Null => { + node_index += 1; + // no buffer increases + } + _ => { + node_index += 1; + buffer_index += 2; + } + }; + Ok((node_index, buffer_index)) } /// Reads the correct number of buffers based on data type and null_count, and creates a @@ -224,24 +414,22 @@ fn create_primitive_array( let array_data = match data_type { Utf8 | Binary | LargeBinary | LargeUtf8 => { // read 3 buffers - let mut builder = ArrayData::builder(data_type.clone()) + ArrayData::builder(data_type.clone()) .len(length) .buffers(buffers[1..3].to_vec()) - .offset(0); - if null_count > 0 { - builder = builder.null_bit_buffer(buffers[0].clone()) - } - builder.build().unwrap() + .offset(0) + .null_bit_buffer((null_count > 0).then(|| buffers[0].clone())) + .build() + .unwrap() } FixedSizeBinary(_) => { // read 3 buffers - let mut builder = ArrayData::builder(data_type.clone()) + let builder = ArrayData::builder(data_type.clone()) .len(length) .buffers(buffers[1..2].to_vec()) - .offset(0); - if null_count > 0 { - builder = builder.null_bit_buffer(buffers[0].clone()) - } + .offset(0) + .null_bit_buffer((null_count > 0).then(|| buffers[0].clone())); + unsafe { builder.build_unchecked() } } Int8 @@ -255,52 +443,48 @@ fn create_primitive_array( | Interval(IntervalUnit::YearMonth) => { if buffers[1].len() / 8 == length && length != 1 { // interpret as a signed i64, and cast appropriately - let mut builder = ArrayData::builder(DataType::Int64) + let builder = ArrayData::builder(DataType::Int64) .len(length) .buffers(buffers[1..].to_vec()) - .offset(0); - if null_count > 0 { - builder = builder.null_bit_buffer(buffers[0].clone()) - } + .offset(0) + .null_bit_buffer((null_count > 0).then(|| buffers[0].clone())); + let data = unsafe { builder.build_unchecked() }; let 
values = Arc::new(Int64Array::from(data)) as ArrayRef; // this cast is infallible, the unwrap is safe let casted = cast(&values, data_type).unwrap(); casted.data().clone() } else { - let mut builder = ArrayData::builder(data_type.clone()) + let builder = ArrayData::builder(data_type.clone()) .len(length) .buffers(buffers[1..].to_vec()) - .offset(0); - if null_count > 0 { - builder = builder.null_bit_buffer(buffers[0].clone()) - } + .offset(0) + .null_bit_buffer((null_count > 0).then(|| buffers[0].clone())); + unsafe { builder.build_unchecked() } } } Float32 => { if buffers[1].len() / 8 == length && length != 1 { // interpret as a f64, and cast appropriately - let mut builder = ArrayData::builder(DataType::Float64) + let builder = ArrayData::builder(DataType::Float64) .len(length) .buffers(buffers[1..].to_vec()) - .offset(0); - if null_count > 0 { - builder = builder.null_bit_buffer(buffers[0].clone()) - } + .offset(0) + .null_bit_buffer((null_count > 0).then(|| buffers[0].clone())); + let data = unsafe { builder.build_unchecked() }; let values = Arc::new(Float64Array::from(data)) as ArrayRef; // this cast is infallible, the unwrap is safe let casted = cast(&values, data_type).unwrap(); casted.data().clone() } else { - let mut builder = ArrayData::builder(data_type.clone()) + let builder = ArrayData::builder(data_type.clone()) .len(length) .buffers(buffers[1..].to_vec()) - .offset(0); - if null_count > 0 { - builder = builder.null_bit_buffer(buffers[0].clone()) - } + .offset(0) + .null_bit_buffer((null_count > 0).then(|| buffers[0].clone())); + unsafe { builder.build_unchecked() } } } @@ -312,25 +496,24 @@ fn create_primitive_array( | Timestamp(_, _) | Date64 | Duration(_) - | Interval(IntervalUnit::DayTime) => { - let mut builder = ArrayData::builder(data_type.clone()) + | Interval(IntervalUnit::DayTime) + | Interval(IntervalUnit::MonthDayNano) => { + let builder = ArrayData::builder(data_type.clone()) .len(length) .buffers(buffers[1..].to_vec()) - .offset(0); - if null_count > 0 { - builder = builder.null_bit_buffer(buffers[0].clone()) - } + .offset(0) + .null_bit_buffer((null_count > 0).then(|| buffers[0].clone())); + unsafe { builder.build_unchecked() } } Decimal(_, _) => { // read 3 buffers - let mut builder = ArrayData::builder(data_type.clone()) + let builder = ArrayData::builder(data_type.clone()) .len(length) .buffers(buffers[1..2].to_vec()) - .offset(0); - if null_count > 0 { - builder = builder.null_bit_buffer(buffers[0].clone()) - } + .offset(0) + .null_bit_buffer((null_count > 0).then(|| buffers[0].clone())); + unsafe { builder.build_unchecked() } } t => panic!("Data type {:?} either unsupported or not primitive", t), @@ -349,36 +532,33 @@ fn create_list_array( ) -> ArrayRef { if let DataType::List(_) | DataType::LargeList(_) = *data_type { let null_count = field_node.null_count() as usize; - let mut builder = ArrayData::builder(data_type.clone()) + let builder = ArrayData::builder(data_type.clone()) .len(field_node.length() as usize) .buffers(buffers[1..2].to_vec()) .offset(0) - .child_data(vec![child_array.data().clone()]); - if null_count > 0 { - builder = builder.null_bit_buffer(buffers[0].clone()) - } + .child_data(vec![child_array.data().clone()]) + .null_bit_buffer((null_count > 0).then(|| buffers[0].clone())); + make_array(unsafe { builder.build_unchecked() }) } else if let DataType::FixedSizeList(_, _) = *data_type { let null_count = field_node.null_count() as usize; - let mut builder = ArrayData::builder(data_type.clone()) + let builder = 
ArrayData::builder(data_type.clone()) .len(field_node.length() as usize) .buffers(buffers[1..1].to_vec()) .offset(0) - .child_data(vec![child_array.data().clone()]); - if null_count > 0 { - builder = builder.null_bit_buffer(buffers[0].clone()) - } + .child_data(vec![child_array.data().clone()]) + .null_bit_buffer((null_count > 0).then(|| buffers[0].clone())); + make_array(unsafe { builder.build_unchecked() }) } else if let DataType::Map(_, _) = *data_type { let null_count = field_node.null_count() as usize; - let mut builder = ArrayData::builder(data_type.clone()) + let builder = ArrayData::builder(data_type.clone()) .len(field_node.length() as usize) .buffers(buffers[1..2].to_vec()) .offset(0) - .child_data(vec![child_array.data().clone()]); - if null_count > 0 { - builder = builder.null_bit_buffer(buffers[0].clone()) - } + .child_data(vec![child_array.data().clone()]) + .null_bit_buffer((null_count > 0).then(|| buffers[0].clone())); + make_array(unsafe { builder.build_unchecked() }) } else { panic!("Cannot create list or map array from {:?}", data_type) @@ -395,14 +575,13 @@ fn create_dictionary_array( ) -> ArrayRef { if let DataType::Dictionary(_, _) = *data_type { let null_count = field_node.null_count() as usize; - let mut builder = ArrayData::builder(data_type.clone()) + let builder = ArrayData::builder(data_type.clone()) .len(field_node.length() as usize) .buffers(buffers[1..2].to_vec()) .offset(0) - .child_data(vec![value_array.data().clone()]); - if null_count > 0 { - builder = builder.null_bit_buffer(buffers[0].clone()) - } + .child_data(vec![value_array.data().clone()]) + .null_bit_buffer((null_count > 0).then(|| buffers[0].clone())); + make_array(unsafe { builder.build_unchecked() }) } else { unreachable!("Cannot create dictionary array from {:?}", data_type) @@ -414,7 +593,9 @@ pub fn read_record_batch( buf: &[u8], batch: ipc::RecordBatch, schema: SchemaRef, - dictionaries: &[Option], + dictionaries_by_id: &HashMap, + projection: Option<&[usize]>, + metadata: &ipc::MetadataVersion, ) -> Result { let buffers = batch.buffers().ok_or_else(|| { ArrowError::IoError("Unable to get buffers from IPC RecordBatch".to_string()) @@ -427,32 +608,80 @@ pub fn read_record_batch( let mut node_index = 0; let mut arrays = vec![]; - // keep track of index as lists require more than one node - for field in schema.fields() { - let triple = create_array( - field_nodes, - field.data_type(), - buf, - buffers, - dictionaries, - node_index, - buffer_index, - ); - node_index = triple.1; - buffer_index = triple.2; - arrays.push(triple.0); - } + let options = RecordBatchOptions { + row_count: Some(batch.length() as usize), + ..Default::default() + }; - RecordBatch::try_new(schema, arrays) + if let Some(projection) = projection { + // project fields + for (idx, field) in schema.fields().iter().enumerate() { + // Create array for projected field + if projection.contains(&idx) { + let triple = create_array( + field_nodes, + field, + buf, + buffers, + dictionaries_by_id, + node_index, + buffer_index, + metadata, + )?; + node_index = triple.1; + buffer_index = triple.2; + arrays.push(triple.0); + } else { + // Skip field. + // This must be called to advance `node_index` and `buffer_index`. 
+ let tuple = skip_field( + field_nodes, + field, + buf, + buffers, + dictionaries_by_id, + node_index, + buffer_index, + )?; + node_index = tuple.0; + buffer_index = tuple.1; + } + } + + RecordBatch::try_new_with_options( + Arc::new(schema.project(projection)?), + arrays, + &options, + ) + } else { + // keep track of index as lists require more than one node + for field in schema.fields() { + let triple = create_array( + field_nodes, + field, + buf, + buffers, + dictionaries_by_id, + node_index, + buffer_index, + metadata, + )?; + node_index = triple.1; + buffer_index = triple.2; + arrays.push(triple.0); + } + RecordBatch::try_new_with_options(schema, arrays, &options) + } } /// Read the dictionary from the buffer and provided metadata, -/// updating the `dictionaries_by_field` with the resulting dictionary +/// updating the `dictionaries_by_id` with the resulting dictionary pub fn read_dictionary( buf: &[u8], batch: ipc::DictionaryBatch, schema: &Schema, - dictionaries_by_field: &mut [Option], + dictionaries_by_id: &mut HashMap, + metadata: &ipc::MetadataVersion, ) -> Result<()> { if batch.isDelta() { return Err(ArrowError::IoError( @@ -481,7 +710,9 @@ pub fn read_dictionary( buf, batch.data().unwrap(), Arc::new(schema), - dictionaries_by_field, + dictionaries_by_id, + None, + metadata, )?; Some(record_batch.column(0).clone()) } @@ -491,16 +722,10 @@ pub fn read_dictionary( ArrowError::InvalidArgumentError("dictionary id not found in schema".to_string()) })?; - // for all fields with this dictionary id, update the dictionaries vector - // in the reader. Note that a dictionary batch may be shared between many fields. // We don't currently record the isOrdered field. This could be general // attributes of arrays. - for (i, field) in schema.fields().iter().enumerate() { - if field.dict_id() == Some(id) { - // Add (possibly multiple) array refs to the dictionaries array. - dictionaries_by_field[i] = Some(dictionary_values.clone()); - } - } + // Add (possibly multiple) array refs to the dictionaries array. + dictionaries_by_id.insert(id, dictionary_values.clone()); Ok(()) } @@ -527,10 +752,13 @@ pub struct FileReader { /// Optional dictionaries for each schema field. /// /// Dictionaries may be appended to in the streaming format. - dictionaries_by_field: Vec>, + dictionaries_by_id: HashMap, /// Metadata version metadata_version: ipc::MetadataVersion, + + /// Optional projection and projected_schema + projection: Option<(Vec, Schema)>, } impl FileReader { @@ -538,7 +766,7 @@ impl FileReader { /// /// Returns errors if the file does not meet the Arrow Format header and footer /// requirements - pub fn try_new(reader: R) -> Result { + pub fn try_new(reader: R, projection: Option>) -> Result { let mut reader = BufReader::new(reader); // check if header and footer contain correct magic bytes let mut magic_buffer: [u8; 6] = [0; 6]; @@ -582,45 +810,63 @@ impl FileReader { let schema = ipc::convert::fb_to_schema(ipc_schema); // Create an array of optional dictionary value arrays, one per field. 
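With the projection parameter now threaded through `read_record_batch` and `FileReader::try_new`, reading a subset of columns looks roughly like the sketch below (the file path is hypothetical), mirroring the `projection_should_work` test later in this diff:

```rust
use std::fs::File;

use arrow::error::Result;
use arrow::ipc::reader::FileReader;

fn read_first_column(path: &str) -> Result<()> {
    let file = File::open(path).unwrap();

    // `Some(vec![0])` materializes only column 0 of each batch;
    // `None` keeps the previous read-all-columns behaviour.
    let reader = FileReader::try_new(file, Some(vec![0]))?;

    for batch in reader {
        assert_eq!(batch?.num_columns(), 1);
    }
    Ok(())
}
```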
- let mut dictionaries_by_field = vec![None; schema.fields().len()]; - for block in footer.dictionaries().unwrap() { - // read length from end of offset - let mut message_size: [u8; 4] = [0; 4]; - reader.seek(SeekFrom::Start(block.offset() as u64))?; - reader.read_exact(&mut message_size)?; - if message_size == CONTINUATION_MARKER { + let mut dictionaries_by_id = HashMap::new(); + if let Some(dictionaries) = footer.dictionaries() { + for block in dictionaries { + // read length from end of offset + let mut message_size: [u8; 4] = [0; 4]; + reader.seek(SeekFrom::Start(block.offset() as u64))?; reader.read_exact(&mut message_size)?; - } - let footer_len = i32::from_le_bytes(message_size); - let mut block_data = vec![0; footer_len as usize]; - - reader.read_exact(&mut block_data)?; - - let message = ipc::root_as_message(&block_data[..]).map_err(|err| { - ArrowError::IoError(format!("Unable to get root as message: {:?}", err)) - })?; + if message_size == CONTINUATION_MARKER { + reader.read_exact(&mut message_size)?; + } + let footer_len = i32::from_le_bytes(message_size); + let mut block_data = vec![0; footer_len as usize]; - match message.header_type() { - ipc::MessageHeader::DictionaryBatch => { - let batch = message.header_as_dictionary_batch().unwrap(); + reader.read_exact(&mut block_data)?; - // read the block that makes up the dictionary batch into a buffer - let mut buf = vec![0; block.bodyLength() as usize]; - reader.seek(SeekFrom::Start( - block.offset() as u64 + block.metaDataLength() as u64, - ))?; - reader.read_exact(&mut buf)?; + let message = ipc::root_as_message(&block_data[..]).map_err(|err| { + ArrowError::IoError(format!( + "Unable to get root as message: {:?}", + err + )) + })?; - read_dictionary(&buf, batch, &schema, &mut dictionaries_by_field)?; - } - t => { - return Err(ArrowError::IoError(format!( - "Expecting DictionaryBatch in dictionary blocks, found {:?}.", - t - ))); + match message.header_type() { + ipc::MessageHeader::DictionaryBatch => { + let batch = message.header_as_dictionary_batch().unwrap(); + + // read the block that makes up the dictionary batch into a buffer + let mut buf = vec![0; block.bodyLength() as usize]; + reader.seek(SeekFrom::Start( + block.offset() as u64 + block.metaDataLength() as u64, + ))?; + reader.read_exact(&mut buf)?; + + read_dictionary( + &buf, + batch, + &schema, + &mut dictionaries_by_id, + &message.version(), + )?; + } + t => { + return Err(ArrowError::IoError(format!( + "Expecting DictionaryBatch in dictionary blocks, found {:?}.", + t + ))); + } } - }; + } } + let projection = match projection { + Some(projection_indices) => { + let schema = schema.project(&projection_indices)?; + Some((projection_indices, schema)) + } + _ => None, + }; Ok(Self { reader, @@ -628,8 +874,9 @@ impl FileReader { blocks: blocks.to_vec(), current_block: 0, total_blocks, - dictionaries_by_field, + dictionaries_by_id, metadata_version: footer.version(), + projection, }) } @@ -709,7 +956,10 @@ impl FileReader { &buf, batch, self.schema(), - &self.dictionaries_by_field, + &self.dictionaries_by_id, + self.projection.as_ref().map(|x| x.0.as_ref()), + &message.version() + ).map(Some) } ipc::MessageHeader::NONE => { @@ -752,12 +1002,15 @@ pub struct StreamReader { /// Optional dictionaries for each schema field. /// /// Dictionaries may be appended to in the streaming format. - dictionaries_by_field: Vec>, + dictionaries_by_id: HashMap, /// An indicator of whether the stream is complete. 
/// /// This value is set to `true` the first time the reader's `next()` returns `None`. finished: bool, + + /// Optional projection + projection: Option<(Vec, Schema)>, } impl StreamReader { @@ -766,7 +1019,7 @@ impl StreamReader { /// The first message in the stream is the schema, the reader will fail if it does not /// encounter a schema. /// To check if the reader is done, use `is_finished(self)` - pub fn try_new(reader: R) -> Result { + pub fn try_new(reader: R, projection: Option>) -> Result { let mut reader = BufReader::new(reader); // determine metadata length let mut meta_size: [u8; 4] = [0; 4]; @@ -793,13 +1046,21 @@ impl StreamReader { let schema = ipc::convert::fb_to_schema(ipc_schema); // Create an array of optional dictionary value arrays, one per field. - let dictionaries_by_field = vec![None; schema.fields().len()]; + let dictionaries_by_id = HashMap::new(); + let projection = match projection { + Some(projection_indices) => { + let schema = schema.project(&projection_indices)?; + Some((projection_indices, schema)) + } + _ => None, + }; Ok(Self { reader, schema: Arc::new(schema), finished: false, - dictionaries_by_field, + dictionaries_by_id, + projection, }) } @@ -872,7 +1133,7 @@ impl StreamReader { let mut buf = vec![0; message.bodyLength() as usize]; self.reader.read_exact(&mut buf)?; - read_record_batch(&buf, batch, self.schema(), &self.dictionaries_by_field).map(Some) + read_record_batch(&buf, batch, self.schema(), &self.dictionaries_by_id, self.projection.as_ref().map(|x| x.0.as_ref()), &message.version()).map(Some) } ipc::MessageHeader::DictionaryBatch => { let batch = message.header_as_dictionary_batch().ok_or_else(|| { @@ -885,7 +1146,7 @@ impl StreamReader { self.reader.read_exact(&mut buf)?; read_dictionary( - &buf, batch, &self.schema, &mut self.dictionaries_by_field + &buf, batch, &self.schema, &mut self.dictionaries_by_id, &message.version() )?; // read the next message until we encounter a RecordBatch @@ -923,9 +1184,11 @@ mod tests { use flate2::read::GzDecoder; - use crate::util::integration_util::*; + use crate::datatypes::{ArrowNativeType, Float64Type, Int32Type, Int8Type}; + use crate::{datatypes, util::integration_util::*}; #[test] + #[cfg(not(feature = "force_validate"))] fn read_generated_files_014() { let testdata = crate::util::test_util::arrow_test_data(); let version = "0.14.1"; @@ -948,7 +1211,7 @@ mod tests { )) .unwrap(); - let mut reader = FileReader::try_new(file).unwrap(); + let mut reader = FileReader::try_new(file, None).unwrap(); // read expected JSON output let arrow_json = read_gzip_json(version, path); @@ -965,7 +1228,23 @@ mod tests { testdata )) .unwrap(); - FileReader::try_new(file).unwrap(); + FileReader::try_new(file, None).unwrap(); + } + + #[test] + #[should_panic( + expected = "Last offset 687865856 of Utf8 is larger than values length 41" + )] + fn read_dictionary_be_not_implemented() { + // The offsets are not translated for big-endian files + // https://github.com/apache/arrow-rs/issues/859 + let testdata = crate::util::test_util::arrow_test_data(); + let file = File::open(format!( + "{}/arrow-ipc-stream/integration/1.0.0-bigendian/generated_dictionary.arrow_file", + testdata + )) + .unwrap(); + FileReader::try_new(file, None).unwrap(); } #[test] @@ -975,7 +1254,6 @@ mod tests { let paths = vec![ "generated_interval", "generated_datetime", - "generated_dictionary", "generated_map", "generated_nested", "generated_null_trivial", @@ -991,11 +1269,47 @@ mod tests { )) .unwrap(); - FileReader::try_new(file).unwrap(); + 
FileReader::try_new(file, None).unwrap(); }); } #[test] + fn projection_should_work() { + // complementary to the previous test + let testdata = crate::util::test_util::arrow_test_data(); + let paths = vec![ + "generated_interval", + "generated_datetime", + "generated_map", + "generated_nested", + "generated_null_trivial", + "generated_null", + "generated_primitive_no_batches", + "generated_primitive_zerolength", + "generated_primitive", + ]; + paths.iter().for_each(|path| { + // We must use littleendian files here. + // The offsets are not translated for big-endian files + // https://github.com/apache/arrow-rs/issues/859 + let file = File::open(format!( + "{}/arrow-ipc-stream/integration/1.0.0-littleendian/{}.arrow_file", + testdata, path + )) + .unwrap(); + + let reader = FileReader::try_new(file, Some(vec![0])).unwrap(); + let datatype_0 = reader.schema().fields()[0].data_type().clone(); + reader.for_each(|batch| { + let batch = batch.unwrap(); + assert_eq!(batch.columns().len(), 1); + assert_eq!(datatype_0, batch.schema().fields()[0].data_type().clone()); + }); + }); + } + + #[test] + #[cfg(not(feature = "force_validate"))] fn read_generated_streams_014() { let testdata = crate::util::test_util::arrow_test_data(); let version = "0.14.1"; @@ -1018,7 +1332,7 @@ mod tests { )) .unwrap(); - let mut reader = StreamReader::try_new(file).unwrap(); + let mut reader = StreamReader::try_new(file, None).unwrap(); // read expected JSON output let arrow_json = read_gzip_json(version, path); @@ -1055,7 +1369,7 @@ mod tests { )) .unwrap(); - let mut reader = FileReader::try_new(file).unwrap(); + let mut reader = FileReader::try_new(file, None).unwrap(); // read expected JSON output let arrow_json = read_gzip_json(version, path); @@ -1088,7 +1402,7 @@ mod tests { )) .unwrap(); - let mut reader = StreamReader::try_new(file).unwrap(); + let mut reader = StreamReader::try_new(file, None).unwrap(); // read expected JSON output let arrow_json = read_gzip_json(version, path); @@ -1100,6 +1414,169 @@ mod tests { }); } + fn create_test_projection_schema() -> Schema { + // define field types + let list_data_type = + DataType::List(Box::new(Field::new("item", DataType::Int32, true))); + + let fixed_size_list_data_type = DataType::FixedSizeList( + Box::new(Field::new("item", DataType::Int32, false)), + 3, + ); + + let key_type = DataType::Int8; + let value_type = DataType::Utf8; + let dict_data_type = + DataType::Dictionary(Box::new(key_type), Box::new(value_type)); + + let union_fileds = vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Float64, false), + ]; + let union_data_type = DataType::Union(union_fileds, vec![0, 1], UnionMode::Dense); + + let struct_fields = vec![ + Field::new("id", DataType::Int32, false), + Field::new( + "list", + DataType::List(Box::new(Field::new("item", DataType::Int8, true))), + false, + ), + ]; + let struct_data_type = DataType::Struct(struct_fields); + + // define schema + Schema::new(vec![ + Field::new("f0", DataType::UInt32, false), + Field::new("f1", DataType::Utf8, false), + Field::new("f2", DataType::Boolean, false), + Field::new("f3", union_data_type, true), + Field::new("f4", DataType::Null, true), + Field::new("f5", DataType::Float64, true), + Field::new("f6", list_data_type, false), + Field::new("f7", DataType::FixedSizeBinary(3), true), + Field::new("f8", fixed_size_list_data_type, false), + Field::new("f9", struct_data_type, false), + Field::new("f10", DataType::Boolean, false), + Field::new("f11", dict_data_type, false), + Field::new("f12", 
DataType::Utf8, false), + ]) + } + + fn create_test_projection_batch_data(schema: &Schema) -> RecordBatch { + // set test data for each column + let array0 = UInt32Array::from(vec![1, 2, 3]); + let array1 = StringArray::from(vec!["foo", "bar", "baz"]); + let array2 = BooleanArray::from(vec![true, false, true]); + + let mut union_builder = UnionBuilder::new_dense(3); + union_builder.append::("a", 1).unwrap(); + union_builder.append::("b", 10.1).unwrap(); + union_builder.append_null::("b").unwrap(); + let array3 = union_builder.build().unwrap(); + + let array4 = NullArray::new(3); + let array5 = Float64Array::from(vec![Some(1.1), None, Some(3.3)]); + let array6_values = vec![ + Some(vec![Some(10), Some(10), Some(10)]), + Some(vec![Some(20), Some(20), Some(20)]), + Some(vec![Some(30), Some(30)]), + ]; + let array6 = ListArray::from_iter_primitive::(array6_values); + let array7_values = vec![vec![11, 12, 13], vec![22, 23, 24], vec![33, 34, 35]]; + let array7 = + FixedSizeBinaryArray::try_from_iter(array7_values.into_iter()).unwrap(); + + let array8_values = ArrayData::builder(DataType::Int32) + .len(9) + .add_buffer(Buffer::from_slice_ref(&[ + 40, 41, 42, 43, 44, 45, 46, 47, 48, + ])) + .build() + .unwrap(); + let array8_data = ArrayData::builder(schema.field(8).data_type().clone()) + .len(3) + .add_child_data(array8_values) + .build() + .unwrap(); + let array8 = FixedSizeListArray::from(array8_data); + + let array9_id: ArrayRef = Arc::new(Int32Array::from(vec![1001, 1002, 1003])); + let array9_list: ArrayRef = + Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(-10)]), + Some(vec![Some(-20), Some(-20), Some(-20)]), + Some(vec![Some(-30)]), + ])); + let array9 = ArrayDataBuilder::new(schema.field(9).data_type().clone()) + .add_child_data(array9_id.data().clone()) + .add_child_data(array9_list.data().clone()) + .len(3) + .build() + .unwrap(); + let array9: ArrayRef = Arc::new(StructArray::from(array9)); + + let array10 = BooleanArray::from(vec![false, false, true]); + + let array11_values = StringArray::from(vec!["x", "yy", "zzz"]); + let array11_keys = Int8Array::from_iter_values([1, 1, 2]); + let array11 = + DictionaryArray::::try_new(&array11_keys, &array11_values).unwrap(); + + let array12 = StringArray::from(vec!["a", "bb", "ccc"]); + + // create record batch + RecordBatch::try_new( + Arc::new(schema.clone()), + vec![ + Arc::new(array0), + Arc::new(array1), + Arc::new(array2), + Arc::new(array3), + Arc::new(array4), + Arc::new(array5), + Arc::new(array6), + Arc::new(array7), + Arc::new(array8), + Arc::new(array9), + Arc::new(array10), + Arc::new(array11), + Arc::new(array12), + ], + ) + .unwrap() + } + + #[test] + fn test_projection_array_values() { + // define schema + let schema = create_test_projection_schema(); + + // create record batch with test data + let batch = create_test_projection_batch_data(&schema); + + // write record batch in IPC format + let mut buf = Vec::new(); + { + let mut writer = ipc::writer::FileWriter::try_new(&mut buf, &schema).unwrap(); + writer.write(&batch).unwrap(); + writer.finish().unwrap(); + } + + // read record batch with projection + for index in 0..12 { + let projection = vec![index]; + let reader = + FileReader::try_new(std::io::Cursor::new(buf.clone()), Some(projection)); + let read_batch = reader.unwrap().next().unwrap().unwrap(); + let projected_column = read_batch.column(0); + let expected_column = batch.column(index); + + // check the projected column equals the expected column + assert_eq!(projected_column.as_ref(), 
expected_column.as_ref()); + } + } + #[test] fn test_arrow_single_float_row() { let schema = Schema::new(vec![ @@ -1124,7 +1601,7 @@ mod tests { // read stream back let file = File::open("target/debug/testdata/float.stream").unwrap(); - let reader = StreamReader::try_new(file).unwrap(); + let reader = StreamReader::try_new(file, None).unwrap(); reader.for_each(|batch| { let batch = batch.unwrap(); @@ -1146,7 +1623,107 @@ mod tests { .value(0) != 0.0 ); - }) + }); + + let file = File::open("target/debug/testdata/float.stream").unwrap(); + + // Read with projection + let reader = StreamReader::try_new(file, Some(vec![0, 3])).unwrap(); + + reader.for_each(|batch| { + let batch = batch.unwrap(); + assert_eq!(batch.schema().fields().len(), 2); + assert_eq!(batch.schema().fields()[0].data_type(), &DataType::Float32); + assert_eq!(batch.schema().fields()[1].data_type(), &DataType::Int32); + }); + } + + fn roundtrip_ipc(rb: &RecordBatch) -> RecordBatch { + let mut buf = Vec::new(); + let mut writer = + ipc::writer::FileWriter::try_new(&mut buf, &rb.schema()).unwrap(); + writer.write(rb).unwrap(); + writer.finish().unwrap(); + drop(writer); + + let mut reader = + ipc::reader::FileReader::try_new(std::io::Cursor::new(buf), None).unwrap(); + reader.next().unwrap().unwrap() + } + + fn roundtrip_ipc_stream(rb: &RecordBatch) -> RecordBatch { + let mut buf = Vec::new(); + let mut writer = + ipc::writer::StreamWriter::try_new(&mut buf, &rb.schema()).unwrap(); + writer.write(rb).unwrap(); + writer.finish().unwrap(); + drop(writer); + + let mut reader = + ipc::reader::StreamReader::try_new(std::io::Cursor::new(buf), None).unwrap(); + reader.next().unwrap().unwrap() + } + + #[test] + fn test_roundtrip_nested_dict() { + let inner: DictionaryArray = + vec!["a", "b", "a"].into_iter().collect(); + + let array = Arc::new(inner) as ArrayRef; + + let dctfield = Field::new("dict", array.data_type().clone(), false); + + let s = StructArray::from(vec![(dctfield, array)]); + let struct_array = Arc::new(s) as ArrayRef; + + let schema = Arc::new(Schema::new(vec![Field::new( + "struct", + struct_array.data_type().clone(), + false, + )])); + + let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap(); + + assert_eq!(batch, roundtrip_ipc(&batch)); + } + + fn check_union_with_builder(mut builder: UnionBuilder) { + builder.append::("a", 1).unwrap(); + builder.append_null::("a").unwrap(); + builder.append::("c", 3.0).unwrap(); + builder.append::("a", 4).unwrap(); + builder.append::("d", 11).unwrap(); + let union = builder.build().unwrap(); + + let schema = Arc::new(Schema::new(vec![Field::new( + "union", + union.data_type().clone(), + false, + )])); + + let union_array = Arc::new(union) as ArrayRef; + + let rb = RecordBatch::try_new(schema, vec![union_array]).unwrap(); + let rb2 = roundtrip_ipc(&rb); + // TODO: equality not yet implemented for union, so we check that the length of the array is + // the same and that all of the buffers are the same instead. 
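+        // The assertions below compare the schema, the shape and the raw buffers of
+        // the union column (its type id buffer and, for a dense union, its value
+        // offset buffer) rather than the arrays themselves.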
+ assert_eq!(rb.schema(), rb2.schema()); + assert_eq!(rb.num_columns(), rb2.num_columns()); + assert_eq!(rb.num_rows(), rb2.num_rows()); + let union1 = rb.column(0); + let union2 = rb2.column(0); + + assert_eq!(union1.data().buffers(), union2.data().buffers()); + } + + #[test] + fn test_roundtrip_dense_union() { + check_union_with_builder(UnionBuilder::new_dense(6)); + } + + #[test] + fn test_roundtrip_sparse_union() { + check_union_with_builder(UnionBuilder::new_sparse(6)); } /// Read gzipped JSON file @@ -1164,4 +1741,218 @@ mod tests { let arrow_json: ArrowJson = serde_json::from_str(&s).unwrap(); arrow_json } + + #[test] + fn test_roundtrip_stream_nested_dict() { + let xs = vec!["AA", "BB", "AA", "CC", "BB"]; + let dict = Arc::new( + xs.clone() + .into_iter() + .collect::>(), + ); + let string_array: ArrayRef = Arc::new(StringArray::from(xs.clone())); + let struct_array = StructArray::from(vec![ + (Field::new("f2.1", DataType::Utf8, false), string_array), + ( + Field::new("f2.2_struct", dict.data_type().clone(), false), + dict.clone() as ArrayRef, + ), + ]); + let schema = Arc::new(Schema::new(vec![ + Field::new("f1_string", DataType::Utf8, false), + Field::new("f2_struct", struct_array.data_type().clone(), false), + ])); + let input_batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(StringArray::from(xs.clone())), + Arc::new(struct_array), + ], + ) + .unwrap(); + let output_batch = roundtrip_ipc_stream(&input_batch); + assert_eq!(input_batch, output_batch); + } + + #[test] + fn test_roundtrip_stream_nested_dict_of_map_of_dict() { + let values = StringArray::from(vec![Some("a"), None, Some("b"), Some("c")]); + let value_dict_keys = Int8Array::from_iter_values([0, 1, 1, 2, 3, 1]); + let value_dict_array = + DictionaryArray::::try_new(&value_dict_keys, &values).unwrap(); + + let key_dict_keys = Int8Array::from_iter_values([0, 0, 2, 1, 1, 3]); + let key_dict_array = + DictionaryArray::::try_new(&key_dict_keys, &values).unwrap(); + + let keys_field = Field::new_dict( + "keys", + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)), + true, + 1, + false, + ); + let values_field = Field::new_dict( + "values", + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)), + true, + 1, + false, + ); + let entry_struct = StructArray::from(vec![ + (keys_field, make_array(key_dict_array.data().clone())), + (values_field, make_array(value_dict_array.data().clone())), + ]); + let map_data_type = DataType::Map( + Box::new(Field::new( + "entries", + entry_struct.data_type().clone(), + true, + )), + false, + ); + + let entry_offsets = Buffer::from_slice_ref(&[0, 2, 4, 6]); + let map_data = ArrayData::builder(map_data_type) + .len(3) + .add_buffer(entry_offsets) + .add_child_data(entry_struct.data().clone()) + .build() + .unwrap(); + let map_array = MapArray::from(map_data); + + let dict_keys = Int8Array::from_iter_values([0, 1, 1, 2, 2, 1]); + let dict_dict_array = + DictionaryArray::::try_new(&dict_keys, &map_array).unwrap(); + + let schema = Arc::new(Schema::new(vec![Field::new( + "f1", + dict_dict_array.data_type().clone(), + false, + )])); + let input_batch = + RecordBatch::try_new(schema, vec![Arc::new(dict_dict_array)]).unwrap(); + let output_batch = roundtrip_ipc_stream(&input_batch); + assert_eq!(input_batch, output_batch); + } + + fn test_roundtrip_stream_dict_of_list_of_dict_impl< + OffsetSize: OffsetSizeTrait, + U: ArrowNativeType, + >( + list_data_type: DataType, + offsets: &[U; 5], + ) { + let values = StringArray::from(vec![Some("a"), None, 
Some("c"), None]); + let keys = Int8Array::from_iter_values([0, 0, 1, 2, 0, 1, 3]); + let dict_array = DictionaryArray::::try_new(&keys, &values).unwrap(); + let dict_data = dict_array.data(); + + let value_offsets = Buffer::from_slice_ref(offsets); + + let list_data = ArrayData::builder(list_data_type) + .len(4) + .add_buffer(value_offsets) + .add_child_data(dict_data.clone()) + .build() + .unwrap(); + let list_array = GenericListArray::::from(list_data); + + let keys_for_dict = Int8Array::from_iter_values([0, 3, 0, 1, 1, 2, 0, 1, 3]); + let dict_dict_array = + DictionaryArray::::try_new(&keys_for_dict, &list_array).unwrap(); + + let schema = Arc::new(Schema::new(vec![Field::new( + "f1", + dict_dict_array.data_type().clone(), + false, + )])); + let input_batch = + RecordBatch::try_new(schema, vec![Arc::new(dict_dict_array)]).unwrap(); + let output_batch = roundtrip_ipc_stream(&input_batch); + assert_eq!(input_batch, output_batch); + } + + #[test] + fn test_roundtrip_stream_dict_of_list_of_dict() { + // list + let list_data_type = DataType::List(Box::new(Field::new_dict( + "item", + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)), + true, + 1, + false, + ))); + let offsets: &[i32; 5] = &[0, 2, 4, 4, 6]; + test_roundtrip_stream_dict_of_list_of_dict_impl::( + list_data_type, + offsets, + ); + + // large list + let list_data_type = DataType::LargeList(Box::new(Field::new_dict( + "item", + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)), + true, + 1, + false, + ))); + let offsets: &[i64; 5] = &[0, 2, 4, 4, 7]; + test_roundtrip_stream_dict_of_list_of_dict_impl::( + list_data_type, + offsets, + ); + } + + #[test] + fn test_roundtrip_stream_dict_of_fixed_size_list_of_dict() { + let values = StringArray::from(vec![Some("a"), None, Some("c"), None]); + let keys = Int8Array::from_iter_values([0, 0, 1, 2, 0, 1, 3, 1, 2]); + let dict_array = DictionaryArray::::try_new(&keys, &values).unwrap(); + let dict_data = dict_array.data(); + + let list_data_type = DataType::FixedSizeList( + Box::new(Field::new_dict( + "item", + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)), + true, + 1, + false, + )), + 3, + ); + let list_data = ArrayData::builder(list_data_type) + .len(3) + .add_child_data(dict_data.clone()) + .build() + .unwrap(); + let list_array = FixedSizeListArray::from(list_data); + + let keys_for_dict = Int8Array::from_iter_values([0, 1, 0, 1, 1, 2, 0, 1, 2]); + let dict_dict_array = + DictionaryArray::::try_new(&keys_for_dict, &list_array).unwrap(); + + let schema = Arc::new(Schema::new(vec![Field::new( + "f1", + dict_dict_array.data_type().clone(), + false, + )])); + let input_batch = + RecordBatch::try_new(schema, vec![Arc::new(dict_dict_array)]).unwrap(); + let output_batch = roundtrip_ipc_stream(&input_batch); + assert_eq!(input_batch, output_batch); + } + + #[test] + fn test_no_columns_batch() { + let schema = Arc::new(Schema::new(vec![])); + let options = RecordBatchOptions { + match_field_names: true, + row_count: Some(10), + }; + let input_batch = + RecordBatch::try_new_with_options(schema, vec![], &options).unwrap(); + let output_batch = roundtrip_ipc_stream(&input_batch); + assert_eq!(input_batch, output_batch); + } } diff --git a/arrow/src/ipc/writer.rs b/arrow/src/ipc/writer.rs index 0376265f4f65..120eb7ab9b7e 100644 --- a/arrow/src/ipc/writer.rs +++ b/arrow/src/ipc/writer.rs @@ -25,7 +25,10 @@ use std::io::{BufWriter, Write}; use flatbuffers::FlatBufferBuilder; -use crate::array::{ArrayData, ArrayRef}; +use 
crate::array::{ + as_large_list_array, as_list_array, as_map_array, as_struct_array, as_union_array, + make_array, Array, ArrayData, ArrayRef, FixedSizeListArray, +}; use crate::buffer::{Buffer, MutableBuffer}; use crate::datatypes::*; use crate::error::{ArrowError, Result}; @@ -36,7 +39,7 @@ use crate::util::bit_util; use ipc::CONTINUATION_MARKER; /// IPC write options used to control the behaviour of the writer -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct IpcWriteOptions { /// Write padding after memory buffers to this multiple of bytes. /// Generally 8 or 64, defaults to 8 @@ -137,26 +140,134 @@ impl IpcDataGenerator { } } - pub fn encoded_batch( + fn _encode_dictionaries( &self, - batch: &RecordBatch, + column: &ArrayRef, + encoded_dictionaries: &mut Vec, dictionary_tracker: &mut DictionaryTracker, write_options: &IpcWriteOptions, - ) -> Result<(Vec, EncodedData)> { - // TODO: handle nested dictionaries - let schema = batch.schema(); - let mut encoded_dictionaries = Vec::with_capacity(schema.fields().len()); + ) -> Result<()> { + match column.data_type() { + DataType::Struct(fields) => { + let s = as_struct_array(column); + for (field, &column) in fields.iter().zip(s.columns().iter()) { + self.encode_dictionaries( + field, + column, + encoded_dictionaries, + dictionary_tracker, + write_options, + )?; + } + } + DataType::List(field) => { + let list = as_list_array(column); + self.encode_dictionaries( + field, + &list.values(), + encoded_dictionaries, + dictionary_tracker, + write_options, + )?; + } + DataType::LargeList(field) => { + let list = as_large_list_array(column); + self.encode_dictionaries( + field, + &list.values(), + encoded_dictionaries, + dictionary_tracker, + write_options, + )?; + } + DataType::FixedSizeList(field, _) => { + let list = column + .as_any() + .downcast_ref::() + .expect("Unable to downcast to fixed size list array"); + self.encode_dictionaries( + field, + &list.values(), + encoded_dictionaries, + dictionary_tracker, + write_options, + )?; + } + DataType::Map(field, _) => { + let map_array = as_map_array(column); + + let (keys, values) = match field.data_type() { + DataType::Struct(fields) if fields.len() == 2 => { + (&fields[0], &fields[1]) + } + _ => panic!("Incorrect field data type {:?}", field.data_type()), + }; + + // keys + self.encode_dictionaries( + keys, + &map_array.keys(), + encoded_dictionaries, + dictionary_tracker, + write_options, + )?; + + // values + self.encode_dictionaries( + values, + &map_array.values(), + encoded_dictionaries, + dictionary_tracker, + write_options, + )?; + } + DataType::Union(fields, _, _) => { + let union = as_union_array(column); + for (field, ref column) in fields + .iter() + .enumerate() + .map(|(n, f)| (f, union.child(n as i8))) + { + self.encode_dictionaries( + field, + column, + encoded_dictionaries, + dictionary_tracker, + write_options, + )?; + } + } + _ => (), + } - for (i, field) in schema.fields().iter().enumerate() { - let column = batch.column(i); + Ok(()) + } - if let DataType::Dictionary(_key_type, _value_type) = column.data_type() { + fn encode_dictionaries( + &self, + field: &Field, + column: &ArrayRef, + encoded_dictionaries: &mut Vec, + dictionary_tracker: &mut DictionaryTracker, + write_options: &IpcWriteOptions, + ) -> Result<()> { + match column.data_type() { + DataType::Dictionary(_key_type, _value_type) => { let dict_id = field .dict_id() .expect("All Dictionary types have `dict_id`"); let dict_data = column.data(); let dict_values = &dict_data.child_data()[0]; + let values = 
make_array(dict_data.child_data()[0].clone()); + + self._encode_dictionaries( + &values, + encoded_dictionaries, + dictionary_tracker, + write_options, + )?; + let emit = dictionary_tracker.insert(dict_id, column)?; if emit { @@ -167,10 +278,38 @@ impl IpcDataGenerator { )); } } + _ => self._encode_dictionaries( + column, + encoded_dictionaries, + dictionary_tracker, + write_options, + )?, } - let encoded_message = self.record_batch_to_bytes(batch, write_options); + Ok(()) + } + + pub fn encoded_batch( + &self, + batch: &RecordBatch, + dictionary_tracker: &mut DictionaryTracker, + write_options: &IpcWriteOptions, + ) -> Result<(Vec, EncodedData)> { + let schema = batch.schema(); + let mut encoded_dictionaries = Vec::with_capacity(schema.all_fields().len()); + + for (i, field) in schema.fields().iter().enumerate() { + let column = batch.column(i); + self.encode_dictionaries( + field, + column, + &mut encoded_dictionaries, + dictionary_tracker, + write_options, + )?; + } + let encoded_message = self.record_batch_to_bytes(batch, write_options); Ok((encoded_dictionaries, encoded_message)) } @@ -197,6 +336,7 @@ impl IpcDataGenerator { offset, array.len(), array.null_count(), + write_options, ); } @@ -250,6 +390,7 @@ impl IpcDataGenerator { 0, array_data.len(), array_data.null_count(), + write_options, ); // write data @@ -467,6 +608,18 @@ impl FileWriter { Ok(()) } + + /// Unwraps the BufWriter housed in FileWriter.writer, returning the underlying + /// writer + /// + /// The buffer is flushed and the FileWriter is finished before returning the + /// writer. + pub fn into_inner(mut self) -> Result { + if !self.finished { + self.finish()?; + } + self.writer.into_inner().map_err(ArrowError::from) + } } pub struct StreamWriter { @@ -474,8 +627,6 @@ pub struct StreamWriter { writer: BufWriter, /// IPC write options write_options: IpcWriteOptions, - /// A reference to the schema, used in validating record batches - schema: Schema, /// Whether the writer footer has been written, and the writer is finished finished: bool, /// Keeps track of dictionaries that have been written @@ -504,7 +655,6 @@ impl StreamWriter { Ok(Self { writer, write_options, - schema: schema.clone(), finished: false, dictionary_tracker: DictionaryTracker::new(false), data_gen, @@ -701,20 +851,37 @@ fn write_continuation( Ok(written) } +/// In V4, null types have no validity bitmap +/// In V5 and later, null and union types have no validity bitmap +fn has_validity_bitmap(data_type: &DataType, write_options: &IpcWriteOptions) -> bool { + if write_options.metadata_version < ipc::MetadataVersion::V5 { + !matches!(data_type, DataType::Null) + } else { + !matches!(data_type, DataType::Null | DataType::Union(_, _, _)) + } +} + /// Write array data to a vector of bytes +#[allow(clippy::too_many_arguments)] fn write_array_data( array_data: &ArrayData, - mut buffers: &mut Vec, - mut arrow_data: &mut Vec, - mut nodes: &mut Vec, + buffers: &mut Vec, + arrow_data: &mut Vec, + nodes: &mut Vec, offset: i64, num_rows: usize, null_count: usize, + write_options: &IpcWriteOptions, ) -> i64 { let mut offset = offset; - nodes.push(ipc::FieldNode::new(num_rows as i64, null_count as i64)); - // NullArray does not have any buffers, thus the null buffer is not generated - if array_data.data_type() != &DataType::Null { + if !matches!(array_data.data_type(), DataType::Null) { + nodes.push(ipc::FieldNode::new(num_rows as i64, null_count as i64)); + } else { + // NullArray's null_count equals to len, but the `null_count` passed in is from ArrayData + // 
where null_count is always 0. + nodes.push(ipc::FieldNode::new(num_rows as i64, num_rows as i64)); + } + if has_validity_bitmap(array_data.data_type(), write_options) { // write null buffer if exists let null_buffer = match array_data.null_buffer() { None => { @@ -727,11 +894,11 @@ fn write_array_data( Some(buffer) => buffer.clone(), }; - offset = write_buffer(&null_buffer, &mut buffers, &mut arrow_data, offset); + offset = write_buffer(&null_buffer, buffers, arrow_data, offset); } array_data.buffers().iter().for_each(|buffer| { - offset = write_buffer(buffer, &mut buffers, &mut arrow_data, offset); + offset = write_buffer(buffer, buffers, arrow_data, offset); }); if !matches!(array_data.data_type(), DataType::Dictionary(_, _)) { @@ -740,12 +907,13 @@ fn write_array_data( // write the nested data (e.g list data) offset = write_array_data( data_ref, - &mut buffers, - &mut arrow_data, - &mut nodes, + buffers, + arrow_data, + nodes, offset, data_ref.len(), data_ref.null_count(), + write_options, ); }); } @@ -824,7 +992,7 @@ mod tests { let file = File::open(format!("target/debug/testdata/{}.arrow_file", "arrow")) .unwrap(); - let mut reader = FileReader::try_new(file).unwrap(); + let mut reader = FileReader::try_new(file, None).unwrap(); while let Some(Ok(read_batch)) = reader.next() { read_batch .columns() @@ -872,7 +1040,7 @@ mod tests { { let file = File::open(&file_name).unwrap(); - let reader = FileReader::try_new(file).unwrap(); + let reader = FileReader::try_new(file, None).unwrap(); reader.for_each(|maybe_batch| { maybe_batch .unwrap() @@ -920,6 +1088,7 @@ mod tests { } #[test] + #[cfg(not(feature = "force_validate"))] fn read_and_rewrite_generated_files_014() { let testdata = crate::util::test_util::arrow_test_data(); let version = "0.14.1"; @@ -942,7 +1111,7 @@ mod tests { )) .unwrap(); - let mut reader = FileReader::try_new(file).unwrap(); + let mut reader = FileReader::try_new(file, None).unwrap(); // read and rewrite the file to a temp location { @@ -963,7 +1132,7 @@ mod tests { version, path )) .unwrap(); - let mut reader = FileReader::try_new(file).unwrap(); + let mut reader = FileReader::try_new(file, None).unwrap(); // read expected JSON output let arrow_json = read_gzip_json(version, path); @@ -972,6 +1141,7 @@ mod tests { } #[test] + #[cfg(not(feature = "force_validate"))] fn read_and_rewrite_generated_streams_014() { let testdata = crate::util::test_util::arrow_test_data(); let version = "0.14.1"; @@ -994,7 +1164,7 @@ mod tests { )) .unwrap(); - let reader = StreamReader::try_new(file).unwrap(); + let reader = StreamReader::try_new(file, None).unwrap(); // read and rewrite the stream to a temp location { @@ -1013,7 +1183,7 @@ mod tests { let file = File::open(format!("target/debug/testdata/{}-{}.stream", version, path)) .unwrap(); - let mut reader = StreamReader::try_new(file).unwrap(); + let mut reader = StreamReader::try_new(file, None).unwrap(); // read expected JSON output let arrow_json = read_gzip_json(version, path); @@ -1051,7 +1221,7 @@ mod tests { )) .unwrap(); - let mut reader = FileReader::try_new(file).unwrap(); + let mut reader = FileReader::try_new(file, None).unwrap(); // read and rewrite the file to a temp location { @@ -1077,7 +1247,7 @@ mod tests { version, path )) .unwrap(); - let mut reader = FileReader::try_new(file).unwrap(); + let mut reader = FileReader::try_new(file, None).unwrap(); // read expected JSON output let arrow_json = read_gzip_json(version, path); @@ -1115,7 +1285,7 @@ mod tests { )) .unwrap(); - let reader = 
StreamReader::try_new(file).unwrap(); + let reader = StreamReader::try_new(file, None).unwrap(); // read and rewrite the stream to a temp location { @@ -1138,7 +1308,7 @@ mod tests { let file = File::open(format!("target/debug/testdata/{}-{}.stream", version, path)) .unwrap(); - let mut reader = StreamReader::try_new(file).unwrap(); + let mut reader = StreamReader::try_new(file, None).unwrap(); // read expected JSON output let arrow_json = read_gzip_json(version, path); @@ -1161,4 +1331,180 @@ mod tests { let arrow_json: ArrowJson = serde_json::from_str(&s).unwrap(); arrow_json } + + #[test] + fn track_union_nested_dict() { + let inner: DictionaryArray = vec!["a", "b", "a"].into_iter().collect(); + + let array = Arc::new(inner) as ArrayRef; + + // Dict field with id 2 + let dctfield = + Field::new_dict("dict", array.data_type().clone(), false, 2, false); + + let types = Buffer::from_slice_ref(&[0_i8, 0, 0]); + let offsets = Buffer::from_slice_ref(&[0_i32, 1, 2]); + + let union = + UnionArray::try_new(&[0], types, Some(offsets), vec![(dctfield, array)]) + .unwrap(); + + let schema = Arc::new(Schema::new(vec![Field::new( + "union", + union.data_type().clone(), + false, + )])); + + let batch = RecordBatch::try_new(schema, vec![Arc::new(union)]).unwrap(); + + let gen = IpcDataGenerator {}; + let mut dict_tracker = DictionaryTracker::new(false); + gen.encoded_batch(&batch, &mut dict_tracker, &Default::default()) + .unwrap(); + + // Dictionary with id 2 should have been written to the dict tracker + assert!(dict_tracker.written.contains_key(&2)); + } + + #[test] + fn track_struct_nested_dict() { + let inner: DictionaryArray = vec!["a", "b", "a"].into_iter().collect(); + + let array = Arc::new(inner) as ArrayRef; + + // Dict field with id 2 + let dctfield = + Field::new_dict("dict", array.data_type().clone(), false, 2, false); + + let s = StructArray::from(vec![(dctfield, array)]); + let struct_array = Arc::new(s) as ArrayRef; + + let schema = Arc::new(Schema::new(vec![Field::new( + "struct", + struct_array.data_type().clone(), + false, + )])); + + let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap(); + + let gen = IpcDataGenerator {}; + let mut dict_tracker = DictionaryTracker::new(false); + gen.encoded_batch(&batch, &mut dict_tracker, &Default::default()) + .unwrap(); + + // Dictionary with id 2 should have been written to the dict tracker + assert!(dict_tracker.written.contains_key(&2)); + } + + #[test] + fn read_union_017() { + let testdata = crate::util::test_util::arrow_test_data(); + let version = "0.17.1"; + let data_file = File::open(format!( + "{}/arrow-ipc-stream/integration/0.17.1/generated_union.stream", + testdata, + )) + .unwrap(); + + let reader = StreamReader::try_new(data_file, None).unwrap(); + + // read and rewrite the stream to a temp location + { + let file = File::create(format!( + "target/debug/testdata/{}-generated_union.stream", + version + )) + .unwrap(); + let mut writer = StreamWriter::try_new(file, &reader.schema()).unwrap(); + reader.for_each(|batch| { + writer.write(&batch.unwrap()).unwrap(); + }); + writer.finish().unwrap(); + } + + // Compare original file and rewrote file + let file = File::open(format!( + "target/debug/testdata/{}-generated_union.stream", + version + )) + .unwrap(); + let rewrite_reader = StreamReader::try_new(file, None).unwrap(); + + let data_file = File::open(format!( + "{}/arrow-ipc-stream/integration/0.17.1/generated_union.stream", + testdata, + )) + .unwrap(); + let reader = StreamReader::try_new(data_file, 
None).unwrap(); + + reader.into_iter().zip(rewrite_reader.into_iter()).for_each( + |(batch1, batch2)| { + assert_eq!(batch1.unwrap(), batch2.unwrap()); + }, + ); + } + + fn write_union_file(options: IpcWriteOptions) { + let schema = Schema::new(vec![Field::new( + "union", + DataType::Union( + vec![ + Field::new("a", DataType::Int32, false), + Field::new("c", DataType::Float64, false), + ], + vec![0, 1], + UnionMode::Sparse, + ), + true, + )]); + let mut builder = UnionBuilder::new_sparse(5); + builder.append::("a", 1).unwrap(); + builder.append_null::("a").unwrap(); + builder.append::("c", 3.0).unwrap(); + builder.append_null::("c").unwrap(); + builder.append::("a", 4).unwrap(); + let union = builder.build().unwrap(); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(union) as ArrayRef], + ) + .unwrap(); + let file_name = "target/debug/testdata/union.arrow_file"; + { + let file = File::create(&file_name).unwrap(); + let mut writer = + FileWriter::try_new_with_options(file, &schema, options).unwrap(); + + writer.write(&batch).unwrap(); + writer.finish().unwrap(); + } + + { + let file = File::open(&file_name).unwrap(); + let reader = FileReader::try_new(file, None).unwrap(); + reader.for_each(|maybe_batch| { + maybe_batch + .unwrap() + .columns() + .iter() + .zip(batch.columns()) + .for_each(|(a, b)| { + assert_eq!(a.data_type(), b.data_type()); + assert_eq!(a.len(), b.len()); + assert_eq!(a.null_count(), b.null_count()); + }); + }); + } + } + + #[test] + fn test_write_union_file_v4_v5() { + write_union_file( + IpcWriteOptions::try_new(8, false, MetadataVersion::V4).unwrap(), + ); + write_union_file( + IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap(), + ); + } } diff --git a/arrow/src/json/reader.rs b/arrow/src/json/reader.rs index eb78e0a420fc..e1fa54f8a644 100644 --- a/arrow/src/json/reader.rs +++ b/arrow/src/json/reader.rs @@ -38,7 +38,12 @@ //! //! let file = File::open("test/data/basic.json").unwrap(); //! -//! let mut json = json::Reader::new(BufReader::new(file), Arc::new(schema), 1024, None); +//! let mut json = json::Reader::new( +//! BufReader::new(file), +//! Arc::new(schema), +//! json::reader::DecoderOptions::new(), +//! ); +//! //! let batch = json.next().unwrap().unwrap(); //! ``` @@ -55,6 +60,7 @@ use crate::datatypes::*; use crate::error::{ArrowError, Result}; use crate::record_batch::RecordBatch; use crate::util::bit_util; +use crate::util::reader_parser::Parser; use crate::{array::*, buffer::Buffer}; #[derive(Debug, Clone)] @@ -549,12 +555,15 @@ where generate_schema(field_types) } -/// JSON values to Arrow record batch decoder. Decoder's next_batch method takes a JSON Value -/// iterator as input and outputs Arrow record batch. +/// JSON values to Arrow record batch decoder. +/// +/// A [`Decoder`] decodes arbitrary streams of [`serde_json::Value`]s and +/// converts them to [`RecordBatch`]es. To decode JSON formatted files, +/// see [`Reader`]. 
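+///
+/// Decoding behaviour is controlled through [`DecoderOptions`]: the batch size,
+/// an optional column projection, and optional per-column format strings used
+/// when parsing string values into primitive columns.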
/// /// # Examples /// ``` -/// use arrow::json::reader::{Decoder, ValueIter, infer_json_schema}; +/// use arrow::json::reader::{Decoder, DecoderOptions, ValueIter, infer_json_schema}; /// use std::fs::File; /// use std::io::{BufReader, Seek, SeekFrom}; /// use std::sync::Arc; @@ -562,8 +571,9 @@ where /// let mut reader = /// BufReader::new(File::open("test/data/mixed_arrays.json").unwrap()); /// let inferred_schema = infer_json_schema(&mut reader, None).unwrap(); -/// let batch_size = 1024; -/// let decoder = Decoder::new(Arc::new(inferred_schema), batch_size, None); +/// let options = DecoderOptions::new() +/// .with_batch_size(1024); +/// let decoder = Decoder::new(Arc::new(inferred_schema), options); /// /// // seek back to start so that the original file is usable again /// reader.seek(SeekFrom::Start(0)).unwrap(); @@ -576,31 +586,70 @@ where pub struct Decoder { /// Explicit schema for the JSON file schema: SchemaRef, + /// This is a collection of options for json decoder + options: DecoderOptions, +} + +#[derive(Debug, Clone, PartialEq)] +/// Options for JSON decoding +pub struct DecoderOptions { + /// Batch size (number of records to load each time), defaults to 1024 records + batch_size: usize, /// Optional projection for which columns to load (case-sensitive names) projection: Option>, - /// Batch size (number of records to load each time) - batch_size: usize, + /// optional HashMap of column name to its format string + format_strings: Option>, } -impl Decoder { - /// Create a new JSON decoder from any value that implements the `Iterator>` - /// trait. - pub fn new( - schema: SchemaRef, - batch_size: usize, - projection: Option>, - ) -> Self { +impl Default for DecoderOptions { + fn default() -> Self { Self { - schema, - projection, - batch_size, + batch_size: 1024, + projection: None, + format_strings: None, } } +} + +impl DecoderOptions { + pub fn new() -> Self { + Default::default() + } + + /// Set the batch size (number of records to load at one time) + pub fn with_batch_size(mut self, batch_size: usize) -> Self { + self.batch_size = batch_size; + self + } + + /// Set the reader's column projection + pub fn with_projection(mut self, projection: Vec) -> Self { + self.projection = Some(projection); + self + } + + /// Set the decoder's format Strings param + pub fn with_format_strings( + mut self, + format_strings: HashMap, + ) -> Self { + self.format_strings = Some(format_strings); + self + } +} + +impl Decoder { + /// Create a new JSON decoder from some value that implements an + /// iterator over [`serde_json::Value`]s (aka implements the + /// `Iterator>` trait). + pub fn new(schema: SchemaRef, options: DecoderOptions) -> Self { + Self { schema, options } + } /// Returns the schema of the reader, useful for getting the schema without reading /// record batches pub fn schema(&self) -> SchemaRef { - match &self.projection { + match &self.options.projection { Some(projection) => { let fields = self.schema.fields(); let projected_fields: Vec = fields @@ -620,14 +669,18 @@ impl Decoder { } } - /// Read the next batch of records + /// Read the next batch of [`serde_json::Value`] records from the + /// interator into a [`RecordBatch`]. + /// + /// Returns `None` if the input iterator is exhausted. 
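+    ///
+    /// A minimal usage sketch, assuming an in-memory iterator of JSON values and
+    /// a single nullable `Int64` column named `a` (both illustrative):
+    ///
+    /// ```
+    /// use std::sync::Arc;
+    /// use arrow::datatypes::{DataType, Field, Schema};
+    /// use arrow::json::reader::{Decoder, DecoderOptions};
+    ///
+    /// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));
+    /// let decoder = Decoder::new(schema, DecoderOptions::new());
+    /// let mut values = vec![serde_json::json!({"a": 1}), serde_json::json!({"a": 2})]
+    ///     .into_iter()
+    ///     .map(Ok);
+    /// let batch = decoder.next_batch(&mut values).unwrap().unwrap();
+    /// assert_eq!(batch.num_rows(), 2);
+    /// ```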
pub fn next_batch(&self, value_iter: &mut I) -> Result> where I: Iterator>, { - let mut rows: Vec = Vec::with_capacity(self.batch_size); + let batch_size = self.options.batch_size; + let mut rows: Vec = Vec::with_capacity(batch_size); - for value in value_iter.by_ref().take(self.batch_size) { + for value in value_iter.by_ref().take(batch_size) { let v = value?; match v { Value::Object(_) => rows.push(v), @@ -645,7 +698,7 @@ impl Decoder { } let rows = &rows[..]; - let projection = self.projection.clone().unwrap_or_else(Vec::new); + let projection = self.options.projection.clone().unwrap_or_default(); let arrays = self.build_struct_array(rows, self.schema.fields(), &projection); let projected_fields: Vec = if projection.is_empty() { @@ -653,8 +706,7 @@ impl Decoder { } else { projection .iter() - .map(|name| self.schema.column_with_name(name)) - .flatten() + .filter_map(|name| self.schema.column_with_name(name)) .map(|(_, field)| field.clone()) .collect() }; @@ -735,14 +787,14 @@ impl Decoder { } #[inline(always)] - fn list_array_string_array_builder( + fn list_array_string_array_builder
( &self, data_type: &DataType, col_name: &str, rows: &[Value], ) -> Result where - DICT_TY: ArrowPrimitiveType + ArrowDictionaryKeyType, + DT: ArrowPrimitiveType + ArrowDictionaryKeyType, { let mut builder: Box = match data_type { DataType::Utf8 => { @@ -751,7 +803,7 @@ impl Decoder { } DataType::Dictionary(_, _) => { let values_builder = - self.build_string_dictionary_builder::(rows.len() * 5)?; + self.build_string_dictionary_builder::
(rows.len() * 5)?; Box::new(ListBuilder::new(values_builder)) } e => { @@ -813,7 +865,7 @@ impl Decoder { builder.append(true)?; } DataType::Dictionary(_, _) => { - let builder = builder.as_any_mut().downcast_mut::>>().ok_or_else(||ArrowError::JsonError( + let builder = builder.as_any_mut().downcast_mut::>>().ok_or_else(||ArrowError::JsonError( "Cast failed for ListBuilder during nested data parsing".to_string(), ))?; for val in vals { @@ -914,7 +966,7 @@ impl Decoder { } #[allow(clippy::unnecessary_wraps)] - fn build_primitive_array( + fn build_primitive_array( &self, rows: &[Value], col_name: &str, @@ -923,20 +975,30 @@ impl Decoder { T: ArrowNumericType, T::Native: num::NumCast, { + let format_string = self + .options + .format_strings + .as_ref() + .and_then(|fmts| fmts.get(col_name)); Ok(Arc::new( rows.iter() .map(|row| { - row.get(&col_name) - .and_then(|value| { - if value.is_i64() { - value.as_i64().map(num::cast::cast) - } else if value.is_u64() { - value.as_u64().map(num::cast::cast) - } else { - value.as_f64().map(num::cast::cast) + row.get(&col_name).and_then(|value| { + if value.is_i64() { + value.as_i64().and_then(num::cast::cast) + } else if value.is_u64() { + value.as_u64().and_then(num::cast::cast) + } else if value.is_string() { + match format_string { + Some(fmt) => { + T::parse_formatted(value.as_str().unwrap(), fmt) + } + None => T::parse(value.as_str().unwrap()), } - }) - .flatten() + } else { + value.as_f64().and_then(num::cast::cast) + } + }) }) .collect::>(), )) @@ -1002,7 +1064,7 @@ impl Decoder { ArrayData::builder(list_field.data_type().clone()) .len(valid_len) .add_buffer(bool_values.into()) - .null_bit_buffer(bool_nulls.into()) + .null_bit_buffer(Some(bool_nulls.into())) .build_unchecked() } } @@ -1081,7 +1143,7 @@ impl Decoder { unsafe { ArrayDataBuilder::new(data_type) .len(rows.len()) - .null_bit_buffer(buf) + .null_bit_buffer(Some(buf)) .child_data( arrays.into_iter().map(|a| a.data().clone()).collect(), ) @@ -1100,7 +1162,7 @@ impl Decoder { .len(list_len) .add_buffer(Buffer::from_slice_ref(&offsets)) .add_child_data(array_data) - .null_bit_buffer(list_nulls.into()); + .null_bit_buffer(Some(list_nulls.into())); let list_data = unsafe { list_data.build_unchecked() }; Ok(Arc::new(GenericListArray::::from(list_data))) } @@ -1272,12 +1334,7 @@ impl Decoder { .iter() .enumerate() .map(|(i, row)| { - ( - i, - row.as_object() - .map(|v| v.get(field.name())) - .flatten(), - ) + (i, row.as_object().and_then(|v| v.get(field.name()))) }) .map(|(i, v)| match v { // we want the field as an object, if it's not, we treat as null @@ -1294,7 +1351,7 @@ impl Decoder { let data_type = DataType::Struct(fields.clone()); let data = ArrayDataBuilder::new(data_type) .len(len) - .null_bit_buffer(null_buffer.into()) + .null_bit_buffer(Some(null_buffer.into())) .child_data( arrays.into_iter().map(|a| a.data().clone()).collect(), ); @@ -1351,8 +1408,7 @@ impl Decoder { let value_map_iter = rows.iter().map(|value| { value .get(field_name) - .map(|v| v.as_object().map(|map| (map, map.len() as i32))) - .flatten() + .and_then(|v| v.as_object().map(|map| (map, map.len() as i32))) }); let rows_len = rows.len(); let mut list_offsets = Vec::with_capacity(rows_len + 1); @@ -1542,13 +1598,8 @@ impl Reader { /// /// If reading a `File`, you can customise the Reader, such as to enable schema /// inference, use `ReaderBuilder`. 
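+    ///
+    /// Decoding options such as the batch size and an optional column projection
+    /// are supplied through [`DecoderOptions`], for example
+    /// `Reader::new(file, schema, DecoderOptions::new().with_batch_size(64))`.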
- pub fn new( - reader: R, - schema: SchemaRef, - batch_size: usize, - projection: Option>, - ) -> Self { - Self::from_buf_reader(BufReader::new(reader), schema, batch_size, projection) + pub fn new(reader: R, schema: SchemaRef, options: DecoderOptions) -> Self { + Self::from_buf_reader(BufReader::new(reader), schema, options) } /// Create a new JSON Reader from a `BufReader` @@ -1557,12 +1608,11 @@ impl Reader { pub fn from_buf_reader( reader: BufReader, schema: SchemaRef, - batch_size: usize, - projection: Option>, + options: DecoderOptions, ) -> Self { Self { reader, - decoder: Decoder::new(schema, batch_size, projection), + decoder: Decoder::new(schema, options), } } @@ -1581,7 +1631,7 @@ impl Reader { } /// JSON file reader builder -#[derive(Debug)] +#[derive(Debug, Default)] pub struct ReaderBuilder { /// Optional schema for the JSON file /// @@ -1592,23 +1642,8 @@ pub struct ReaderBuilder { /// /// If a number is not provided, all the records are read. max_records: Option, - /// Batch size (number of records to load each time) - /// - /// The default batch size when using the `ReaderBuilder` is 1024 records - batch_size: usize, - /// Optional projection for which columns to load (zero-based column indices) - projection: Option>, -} - -impl Default for ReaderBuilder { - fn default() -> Self { - Self { - schema: None, - max_records: None, - batch_size: 1024, - projection: None, - } - } + /// Options for json decoder + options: DecoderOptions, } impl ReaderBuilder { @@ -1655,13 +1690,22 @@ impl ReaderBuilder { /// Set the batch size (number of records to load at one time) pub fn with_batch_size(mut self, batch_size: usize) -> Self { - self.batch_size = batch_size; + self.options = self.options.with_batch_size(batch_size); self } /// Set the reader's column projection pub fn with_projection(mut self, projection: Vec) -> Self { - self.projection = Some(projection); + self.options = self.options.with_projection(projection); + self + } + + /// Set the decoder's format Strings param + pub fn with_format_strings( + mut self, + format_strings: HashMap, + ) -> Self { + self.options = self.options.with_format_strings(format_strings); self } @@ -1681,12 +1725,7 @@ impl ReaderBuilder { )?), }; - Ok(Reader::from_buf_reader( - buf_reader, - schema, - self.batch_size, - self.projection, - )) + Ok(Reader::from_buf_reader(buf_reader, schema, self.options)) } } @@ -1718,7 +1757,7 @@ mod tests { .unwrap(); let batch = reader.next().unwrap().unwrap(); - assert_eq!(4, batch.num_columns()); + assert_eq!(5, batch.num_columns()); assert_eq!(12, batch.num_rows()); let schema = reader.schema(); @@ -1750,8 +1789,8 @@ mod tests { .as_any() .downcast_ref::() .unwrap(); - assert!(2.0 - bb.value(0) < f64::EPSILON); - assert!(-3.5 - bb.value(1) < f64::EPSILON); + assert_eq!(2.0, bb.value(0)); + assert_eq!(-3.5, bb.value(1)); let cc = batch .column(c.0) .as_any() @@ -1839,8 +1878,7 @@ mod tests { let mut reader: Reader = Reader::new( File::open("test/data/basic.json").unwrap(), Arc::new(schema.clone()), - 1024, - None, + DecoderOptions::new(), ); let reader_schema = reader.schema(); assert_eq!(reader_schema, Arc::new(schema)); @@ -1873,8 +1911,39 @@ mod tests { .as_any() .downcast_ref::() .unwrap(); - assert!(2.0 - bb.value(0) < f32::EPSILON); - assert!(-3.5 - bb.value(1) < f32::EPSILON); + assert_eq!(2.0, bb.value(0)); + assert_eq!(-3.5, bb.value(1)); + } + + #[test] + fn test_json_format_strings_for_date() { + let schema = + Arc::new(Schema::new(vec![Field::new("e", DataType::Date32, false)])); + let e = 
schema.column_with_name("e").unwrap(); + assert_eq!(&DataType::Date32, e.1.data_type()); + let mut fmts = HashMap::new(); + let date_format = "%Y-%m-%d".to_string(); + fmts.insert("e".to_string(), date_format.clone()); + + let mut reader: Reader = Reader::new( + File::open("test/data/basic.json").unwrap(), + schema.clone(), + DecoderOptions::new().with_format_strings(fmts), + ); + let reader_schema = reader.schema(); + assert_eq!(reader_schema, schema); + let batch = reader.next().unwrap().unwrap(); + + let ee = batch + .column(e.0) + .as_any() + .downcast_ref::() + .unwrap(); + let dt = Date32Type::parse_formatted("1970-1-2", &date_format).unwrap(); + assert_eq!(dt, ee.value(0)); + let dt = Date32Type::parse_formatted("1969-12-31", &date_format).unwrap(); + assert_eq!(dt, ee.value(1)); + assert!(!ee.is_valid(2)); } #[test] @@ -1891,8 +1960,7 @@ mod tests { let mut reader: Reader = Reader::new( File::open("test/data/basic.json").unwrap(), Arc::new(schema), - 1024, - Some(vec!["a".to_string(), "c".to_string()]), + DecoderOptions::new().with_projection(vec!["a".to_string(), "c".to_string()]), ); let reader_schema = reader.schema(); let expected_schema = Arc::new(Schema::new(vec![ @@ -1962,8 +2030,8 @@ mod tests { let bb = bb.values(); let bb = bb.as_any().downcast_ref::().unwrap(); assert_eq!(9, bb.len()); - assert!(2.0 - bb.value(0) < f64::EPSILON); - assert!(-6.1 - bb.value(5) < f64::EPSILON); + assert_eq!(2.0, bb.value(0)); + assert_eq!(-6.1, bb.value(5)); assert!(!bb.is_valid(7)); let cc = batch @@ -2059,7 +2127,8 @@ mod tests { file.seek(SeekFrom::Start(0)).unwrap(); let reader = BufReader::new(GzDecoder::new(&file)); - let mut reader = Reader::from_buf_reader(reader, Arc::new(schema), 64, None); + let options = DecoderOptions::new().with_batch_size(64); + let mut reader = Reader::from_buf_reader(reader, Arc::new(schema), options); let batch_gz = reader.next().unwrap().unwrap(); for batch in vec![batch, batch_gz] { @@ -2094,7 +2163,7 @@ mod tests { let bb = bb.values(); let bb = bb.as_any().downcast_ref::().unwrap(); assert_eq!(10, bb.len()); - assert!(4.0 - bb.value(9) < f64::EPSILON); + assert_eq!(4.0, bb.value(9)); let cc = batch .column(c.0) @@ -2166,7 +2235,7 @@ mod tests { let c = ArrayDataBuilder::new(c_field.data_type().clone()) .len(4) .add_child_data(d.data().clone()) - .null_bit_buffer(Buffer::from(vec![0b00000101])) + .null_bit_buffer(Some(Buffer::from(vec![0b00000101]))) .build() .unwrap(); let b = BooleanArray::from(vec![Some(true), Some(false), Some(true), None]); @@ -2174,7 +2243,7 @@ mod tests { .len(4) .add_child_data(b.data().clone()) .add_child_data(c) - .null_bit_buffer(Buffer::from(vec![0b00000111])) + .null_bit_buffer(Some(Buffer::from(vec![0b00000111]))) .build() .unwrap(); let expected = make_array(a); @@ -2232,7 +2301,7 @@ mod tests { let c = ArrayDataBuilder::new(c_field.data_type().clone()) .len(7) .add_child_data(d.data().clone()) - .null_bit_buffer(Buffer::from(vec![0b00111011])) + .null_bit_buffer(Some(Buffer::from(vec![0b00111011]))) .build() .unwrap(); let b = BooleanArray::from(vec![ @@ -2248,14 +2317,14 @@ mod tests { .len(7) .add_child_data(b.data().clone()) .add_child_data(c.clone()) - .null_bit_buffer(Buffer::from(vec![0b00111111])) + .null_bit_buffer(Some(Buffer::from(vec![0b00111111]))) .build() .unwrap(); let a_list = ArrayDataBuilder::new(a_field.data_type().clone()) .len(6) .add_buffer(Buffer::from_slice_ref(&[0i32, 2, 3, 6, 6, 6, 7])) .add_child_data(a) - .null_bit_buffer(Buffer::from(vec![0b00110111])) + 
.null_bit_buffer(Some(Buffer::from(vec![0b00110111]))) .build() .unwrap(); let expected = make_array(a_list); @@ -2354,7 +2423,7 @@ mod tests { vec![0i32, 2, 4, 7, 7, 8, 8, 9].to_byte_slice(), )) .add_child_data(expected_value_array_data) - .null_bit_buffer(Buffer::from(vec![0b01010111])) + .null_bit_buffer(Some(Buffer::from(vec![0b01010111]))) .build() .unwrap(); let expected_stocks_entries_data = ArrayDataBuilder::new(entries_struct_type) @@ -3088,6 +3157,37 @@ mod tests { assert_eq!(5, aa.value(7)); } + #[test] + fn test_time_from_string() { + parse_string_column::(4); + parse_string_column::(4); + parse_string_column::(4); + parse_string_column::(4); + } + + fn parse_string_column(value: T::Native) + where + T: ArrowPrimitiveType, + { + let schema = Schema::new(vec![Field::new("d", T::DATA_TYPE, true)]); + + let builder = ReaderBuilder::new() + .with_schema(Arc::new(schema)) + .with_batch_size(64); + let mut reader: Reader = builder + .build::(File::open("test/data/basic_nulls.json").unwrap()) + .unwrap(); + + let batch = reader.next().unwrap().unwrap(); + let dd = batch + .column(0) + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(value, dd.value(1)); + assert!(!dd.is_valid(2)); + } + #[test] fn test_json_read_nested_list() { let schema = Schema::new(vec![Field::new( @@ -3100,7 +3200,7 @@ mod tests { true, )]); - let decoder = Decoder::new(Arc::new(schema), 1024, None); + let decoder = Decoder::new(Arc::new(schema), DecoderOptions::new()); let batch = decoder .next_batch( &mut vec![ @@ -3135,7 +3235,7 @@ mod tests { true, )]); - let decoder = Decoder::new(Arc::new(schema), 1024, None); + let decoder = Decoder::new(Arc::new(schema), DecoderOptions::new()); let batch = decoder .next_batch( // NOTE: total struct element count needs to be greater than @@ -3164,7 +3264,7 @@ mod tests { #[test] fn test_json_read_binary_structs() { let schema = Schema::new(vec![Field::new("c1", DataType::Binary, true)]); - let decoder = Decoder::new(Arc::new(schema), 1024, None); + let decoder = Decoder::new(Arc::new(schema), DecoderOptions::new()); let batch = decoder .next_batch( &mut vec![ @@ -3207,7 +3307,7 @@ mod tests { let mut sum_a = 0; for batch in reader { let batch = batch.unwrap(); - assert_eq!(4, batch.num_columns()); + assert_eq!(5, batch.num_columns()); sum_num_rows += batch.num_rows(); num_batches += 1; let batch_schema = batch.schema(); @@ -3223,4 +3323,12 @@ mod tests { assert_eq!(3, num_batches); assert_eq!(100000000000011, sum_a); } + + #[test] + fn test_options_clone() { + // ensure options have appropriate derivation + let options = DecoderOptions::new().with_batch_size(64); + let cloned = options.clone(); + assert_eq!(options, cloned); + } } diff --git a/arrow/src/json/writer.rs b/arrow/src/json/writer.rs index 4279ab786fa0..078382f57dd0 100644 --- a/arrow/src/json/writer.rs +++ b/arrow/src/json/writer.rs @@ -38,7 +38,7 @@ //! let a = Int32Array::from(vec![1, 2, 3]); //! let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap(); //! -//! let json_rows = json::writer::record_batches_to_json_rows(&[batch]); +//! let json_rows = json::writer::record_batches_to_json_rows(&[batch]).unwrap(); //! assert_eq!( //! serde_json::Value::Object(json_rows[1].clone()), //! 
serde_json::json!({"a": 2}), @@ -110,64 +110,68 @@ use serde_json::Value; use crate::array::*; use crate::datatypes::*; -use crate::error::Result; +use crate::error::{ArrowError, Result}; use crate::record_batch::RecordBatch; -fn primitive_array_to_json(array: &ArrayRef) -> Vec { - as_primitive_array::(array) +fn primitive_array_to_json( + array: &ArrayRef, +) -> Result> { + Ok(as_primitive_array::(array) .iter() .map(|maybe_value| match maybe_value { Some(v) => v.into_json_value().unwrap_or(Value::Null), None => Value::Null, }) - .collect() + .collect()) } fn struct_array_to_jsonmap_array( array: &StructArray, row_count: usize, -) -> Vec> { +) -> Result>> { let inner_col_names = array.column_names(); let mut inner_objs = iter::repeat(JsonMap::new()) .take(row_count) .collect::>>(); - array - .columns() - .iter() - .enumerate() - .for_each(|(j, struct_col)| { - set_column_for_json_rows( - &mut inner_objs, - row_count, - struct_col, - inner_col_names[j], - ); - }); - - inner_objs + for (j, struct_col) in array.columns().iter().enumerate() { + set_column_for_json_rows( + &mut inner_objs, + row_count, + struct_col, + inner_col_names[j], + )? + } + Ok(inner_objs) } /// Converts an arrow [`ArrayRef`] into a `Vec` of Serde JSON [`serde_json::Value`]'s -pub fn array_to_json_array(array: &ArrayRef) -> Vec { +pub fn array_to_json_array(array: &ArrayRef) -> Result> { match array.data_type() { - DataType::Null => iter::repeat(Value::Null).take(array.len()).collect(), - DataType::Boolean => as_boolean_array(array) + DataType::Null => Ok(iter::repeat(Value::Null).take(array.len()).collect()), + DataType::Boolean => Ok(as_boolean_array(array) .iter() .map(|maybe_value| match maybe_value { Some(v) => v.into(), None => Value::Null, }) - .collect(), + .collect()), - DataType::Utf8 => as_string_array(array) + DataType::Utf8 => Ok(as_string_array(array) .iter() .map(|maybe_value| match maybe_value { Some(v) => v.into(), None => Value::Null, }) - .collect(), + .collect()), + DataType::LargeUtf8 => Ok(as_largestring_array(array) + .iter() + .map(|maybe_value| match maybe_value { + Some(v) => v.into(), + None => Value::Null, + }) + .collect()), DataType::Int8 => primitive_array_to_json::(array), DataType::Int16 => primitive_array_to_json::(array), DataType::Int32 => primitive_array_to_json::(array), @@ -181,28 +185,26 @@ pub fn array_to_json_array(array: &ArrayRef) -> Vec { DataType::List(_) => as_list_array(array) .iter() .map(|maybe_value| match maybe_value { - Some(v) => Value::Array(array_to_json_array(&v)), - None => Value::Null, + Some(v) => Ok(Value::Array(array_to_json_array(&v)?)), + None => Ok(Value::Null), }) .collect(), DataType::LargeList(_) => as_large_list_array(array) .iter() .map(|maybe_value| match maybe_value { - Some(v) => Value::Array(array_to_json_array(&v)), - None => Value::Null, + Some(v) => Ok(Value::Array(array_to_json_array(&v)?)), + None => Ok(Value::Null), }) .collect(), DataType::Struct(_) => { let jsonmaps = - struct_array_to_jsonmap_array(as_struct_array(array), array.len()); - jsonmaps.into_iter().map(Value::Object).collect() - } - _ => { - panic!( - "Unsupported datatype for array conversion: {:#?}", - array.data_type() - ); + struct_array_to_jsonmap_array(as_struct_array(array), array.len())?; + Ok(jsonmaps.into_iter().map(Value::Object).collect()) } + t => Err(ArrowError::JsonError(format!( + "data type {:?} not supported", + t + ))), } } @@ -261,37 +263,37 @@ fn set_column_for_json_rows( row_count: usize, array: &ArrayRef, col_name: &str, -) { +) -> Result<()> { match 
array.data_type() { DataType::Int8 => { - set_column_by_primitive_type::(rows, row_count, array, col_name) + set_column_by_primitive_type::(rows, row_count, array, col_name); } DataType::Int16 => { - set_column_by_primitive_type::(rows, row_count, array, col_name) + set_column_by_primitive_type::(rows, row_count, array, col_name); } DataType::Int32 => { - set_column_by_primitive_type::(rows, row_count, array, col_name) + set_column_by_primitive_type::(rows, row_count, array, col_name); } DataType::Int64 => { - set_column_by_primitive_type::(rows, row_count, array, col_name) + set_column_by_primitive_type::(rows, row_count, array, col_name); } DataType::UInt8 => { - set_column_by_primitive_type::(rows, row_count, array, col_name) + set_column_by_primitive_type::(rows, row_count, array, col_name); } DataType::UInt16 => { - set_column_by_primitive_type::(rows, row_count, array, col_name) + set_column_by_primitive_type::(rows, row_count, array, col_name); } DataType::UInt32 => { - set_column_by_primitive_type::(rows, row_count, array, col_name) + set_column_by_primitive_type::(rows, row_count, array, col_name); } DataType::UInt64 => { - set_column_by_primitive_type::(rows, row_count, array, col_name) + set_column_by_primitive_type::(rows, row_count, array, col_name); } DataType::Float32 => { - set_column_by_primitive_type::(rows, row_count, array, col_name) + set_column_by_primitive_type::(rows, row_count, array, col_name); } DataType::Float64 => { - set_column_by_primitive_type::(rows, row_count, array, col_name) + set_column_by_primitive_type::(rows, row_count, array, col_name); } DataType::Null => { // when value is null, we simply skip setting the key @@ -302,6 +304,15 @@ fn set_column_for_json_rows( DataType::Utf8 => { set_column_by_array_type!(as_string_array, col_name, rows, array, row_count); } + DataType::LargeUtf8 => { + set_column_by_array_type!( + as_largestring_array, + col_name, + rows, + array, + row_count + ); + } DataType::Date32 => { set_temporal_column_by_array_type!( Date32Array, @@ -444,7 +455,7 @@ fn set_column_for_json_rows( } DataType::Struct(_) => { let inner_objs = - struct_array_to_jsonmap_array(as_struct_array(array), row_count); + struct_array_to_jsonmap_array(as_struct_array(array), row_count)?; rows.iter_mut() .take(row_count) .zip(inner_objs.into_iter()) @@ -457,46 +468,88 @@ fn set_column_for_json_rows( rows.iter_mut() .zip(listarr.iter()) .take(row_count) - .for_each(|(row, maybe_value)| { + .try_for_each(|(row, maybe_value)| -> Result<()> { if let Some(v) = maybe_value { row.insert( col_name.to_string(), - Value::Array(array_to_json_array(&v)), + Value::Array(array_to_json_array(&v)?), ); } - }); + Ok(()) + })?; } DataType::LargeList(_) => { let listarr = as_large_list_array(array); rows.iter_mut() .zip(listarr.iter()) .take(row_count) - .for_each(|(row, maybe_value)| { + .try_for_each(|(row, maybe_value)| -> Result<()> { if let Some(v) = maybe_value { - row.insert( - col_name.to_string(), - Value::Array(array_to_json_array(&v)), - ); + let val = array_to_json_array(&v)?; + row.insert(col_name.to_string(), Value::Array(val)); } - }); + Ok(()) + })?; } DataType::Dictionary(_, value_type) => { let slice = array.slice(0, row_count); let hydrated = crate::compute::kernels::cast::cast(&slice, value_type) .expect("cannot cast dictionary to underlying values"); - set_column_for_json_rows(rows, row_count, &hydrated, col_name) + set_column_for_json_rows(rows, row_count, &hydrated, col_name)?; + } + DataType::Map(_, _) => { + let maparr = as_map_array(array); + + let 
keys = maparr.keys(); + let values = maparr.values(); + + // Keys have to be strings to convert to json. + if !matches!(keys.data_type(), DataType::Utf8) { + return Err(ArrowError::JsonError(format!( + "data type {:?} not supported in nested map for json writer", + keys.data_type() + ))); + } + + let keys = as_string_array(&keys); + let values = array_to_json_array(&values)?; + + let mut kv = keys.iter().zip(values.into_iter()); + + for (i, row) in rows.iter_mut().take(row_count).enumerate() { + if maparr.is_null(i) { + row.insert(col_name.to_string(), serde_json::Value::Null); + continue; + } + + let len = maparr.value_length(i) as usize; + let mut obj = serde_json::Map::new(); + + for (_, (k, v)) in (0..len).zip(&mut kv) { + obj.insert( + k.expect("keys in a map should be non-null").to_string(), + v, + ); + } + + row.insert(col_name.to_string(), serde_json::Value::Object(obj)); + } } _ => { - panic!("Unsupported datatype: {:#?}", array.data_type()); + return Err(ArrowError::JsonError(format!( + "data type {:?} not supported in nested map for json writer", + array.data_type() + ))) } } + Ok(()) } /// Converts an arrow [`RecordBatch`] into a `Vec` of Serde JSON /// [`JsonMap`]s (objects) pub fn record_batches_to_json_rows( batches: &[RecordBatch], -) -> Vec> { +) -> Result>> { let mut rows: Vec> = iter::repeat(JsonMap::new()) .take(batches.iter().map(|b| b.num_rows()).sum()) .collect(); @@ -504,17 +557,17 @@ pub fn record_batches_to_json_rows( if !rows.is_empty() { let schema = batches[0].schema(); let mut base = 0; - batches.iter().for_each(|batch| { + for batch in batches { let row_count = batch.num_rows(); - batch.columns().iter().enumerate().for_each(|(j, col)| { + for (j, col) in batch.columns().iter().enumerate() { let col_name = schema.field(j).name(); - set_column_for_json_rows(&mut rows[base..], row_count, col, col_name); - }); + set_column_for_json_rows(&mut rows[base..], row_count, col, col_name)? + } base += row_count; - }); + } } - rows + Ok(rows) } /// This trait defines how to format a sequence of JSON objects to a @@ -646,9 +699,17 @@ where Ok(()) } + /// Convert the `RecordBatch` into JSON rows, and write them to the output + pub fn write(&mut self, batch: RecordBatch) -> Result<()> { + for row in record_batches_to_json_rows(&[batch])? { + self.write_row(&Value::Object(row))?; + } + Ok(()) + } + /// Convert the [`RecordBatch`] into JSON rows, and write them to the output pub fn write_batches(&mut self, batches: &[RecordBatch]) -> Result<()> { - for row in record_batches_to_json_rows(batches) { + for row in record_batches_to_json_rows(batches)? 
{ self.write_row(&Value::Object(row))?; } Ok(()) @@ -715,6 +776,37 @@ mod tests { ); } + #[test] + fn write_large_utf8() { + let schema = Schema::new(vec![ + Field::new("c1", DataType::Utf8, true), + Field::new("c2", DataType::LargeUtf8, true), + ]); + + let a = StringArray::from(vec![Some("a"), None, Some("c"), Some("d"), None]); + let b = LargeStringArray::from(vec![Some("a"), Some("b"), None, Some("d"), None]); + + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]) + .unwrap(); + + let mut buf = Vec::new(); + { + let mut writer = LineDelimitedWriter::new(&mut buf); + writer.write_batches(&[batch]).unwrap(); + } + + assert_eq!( + String::from_utf8(buf).unwrap(), + r#"{"c1":"a","c2":"a"} +{"c2":"b"} +{"c1":"c"} +{"c1":"d","c2":"d"} +{} +"# + ); + } + #[test] fn write_dictionary() { let schema = Schema::new(vec![ @@ -1027,7 +1119,7 @@ mod tests { .len(5) .add_buffer(a_value_offsets) .add_child_data(a_values.data().clone()) - .null_bit_buffer(Buffer::from(vec![0b00011111])) + .null_bit_buffer(Some(Buffer::from(vec![0b00011111]))) .build() .unwrap(); let a = ListArray::from(a_list_data); @@ -1078,7 +1170,7 @@ mod tests { let a_list_data = ArrayData::builder(list_inner_type.data_type().clone()) .len(3) .add_buffer(a_value_offsets) - .null_bit_buffer(Buffer::from(vec![0b00000111])) + .null_bit_buffer(Some(Buffer::from(vec![0b00000111]))) .add_child_data(a_values.data().clone()) .build() .unwrap(); @@ -1162,7 +1254,7 @@ mod tests { .len(3) .add_buffer(c1_value_offsets) .add_child_data(struct_values.data().clone()) - .null_bit_buffer(Buffer::from(vec![0b00000101])) + .null_bit_buffer(Some(Buffer::from(vec![0b00000101]))) .build() .unwrap(); let c1 = ListArray::from(c1_list_data); @@ -1315,4 +1407,94 @@ mod tests { "# ); } + + #[test] + fn json_writer_map() { + let keys_array = + super::StringArray::from(vec!["foo", "bar", "baz", "qux", "quux"]); + let values_array = super::Int64Array::from(vec![10, 20, 30, 40, 50]); + + let keys = Field::new("keys", DataType::Utf8, false); + let values = Field::new("values", DataType::Int64, false); + let entry_struct = StructArray::from(vec![ + (keys, Arc::new(keys_array) as ArrayRef), + (values, Arc::new(values_array) as ArrayRef), + ]); + + let map_data_type = DataType::Map( + Box::new(Field::new( + "entries", + entry_struct.data_type().clone(), + true, + )), + false, + ); + + // [{"foo": 10}, null, {}, {"bar": 20, "baz": 30, "qux": 40}, {"quux": 50}, {}] + let entry_offsets = Buffer::from(&[0, 1, 1, 1, 4, 5, 5].to_byte_slice()); + let valid_buffer = Buffer::from([0b00111101]); + + let map_data = ArrayData::builder(map_data_type.clone()) + .len(6) + .null_bit_buffer(Some(valid_buffer)) + .add_buffer(entry_offsets) + .add_child_data(entry_struct.data().clone()) + .build() + .unwrap(); + + let map = MapArray::from(map_data); + + let map_field = Field::new("map", map_data_type, false); + let schema = Arc::new(Schema::new(vec![map_field])); + + let batch = RecordBatch::try_new(schema, vec![Arc::new(map)]).unwrap(); + + let mut buf = Vec::new(); + { + let mut writer = LineDelimitedWriter::new(&mut buf); + writer.write_batches(&[batch]).unwrap(); + } + + assert_eq!( + String::from_utf8(buf).unwrap(), + r#"{"map":{"foo":10}} +{"map":null} +{"map":{}} +{"map":{"bar":20,"baz":30,"qux":40}} +{"map":{"quux":50}} +{"map":{}} +"# + ); + } + + #[test] + fn test_write_single_batch() { + let test_file = "test/data/basic.json"; + let builder = ReaderBuilder::new() + .infer_schema(None) + .with_batch_size(1024); + let mut reader: Reader = 
builder + .build::(File::open(test_file).unwrap()) + .unwrap(); + let batch = reader.next().unwrap().unwrap(); + + let mut buf = Vec::new(); + { + let mut writer = LineDelimitedWriter::new(&mut buf); + writer.write(batch).unwrap(); + } + + let result = String::from_utf8(buf).unwrap(); + let expected = read_to_string(test_file).unwrap(); + for (r, e) in result.lines().zip(expected.lines()) { + let mut expected_json = serde_json::from_str::(e).unwrap(); + // remove null value from object to make comparision consistent: + if let Value::Object(obj) = expected_json { + expected_json = Value::Object( + obj.into_iter().filter(|(_, v)| *v != Value::Null).collect(), + ); + } + assert_eq!(serde_json::from_str::(r).unwrap(), expected_json,); + } + } } diff --git a/arrow/src/lib.rs b/arrow/src/lib.rs index 2c2590cb4fc3..95c69ca0be6d 100644 --- a/arrow/src/lib.rs +++ b/arrow/src/lib.rs @@ -15,134 +15,220 @@ // specific language governing permissions and limitations // under the License. -//! A native Rust implementation of [Apache Arrow](https://arrow.apache.org), a cross-language +//! A complete, safe, native Rust implementation of [Apache Arrow](https://arrow.apache.org), a cross-language //! development platform for in-memory data. //! -//! ### DataType +//! # Columnar Format //! -//! Every [`Array`](array::Array) in this crate has an associated [`DataType`](datatypes::DataType), -//! that specifies how its data is layed in memory and represented. -//! Thus, a central enum of this crate is [`DataType`](datatypes::DataType), that contains the set of valid -//! DataTypes in the specification. For example, [`DataType::Utf8`](datatypes::DataType::Utf8). +//! The [`array`] module provides statically typed implementations of all the array +//! types as defined by the [Arrow Columnar Format](https://arrow.apache.org/docs/format/Columnar.html). //! -//! ## Array -//! -//! The central trait of this package is the dynamically-typed [`Array`](array::Array) that -//! represents a fixed-sized, immutable, Send + Sync Array of nullable elements. An example of such an array is [`UInt32Array`](array::UInt32Array). -//! One way to think about an arrow [`Array`](array::Array) is a `Arc<[Option; len]>` where T can be anything ranging from an integer to a string, or even -//! another [`Array`](array::Array). -//! -//! [`Arrays`](array::Array) have [`len()`](array::Array::len), [`data_type()`](array::Array::data_type), and the nullability of each of its elements, -//! can be obtained via [`is_null(index)`](array::Array::is_null). To downcast an [`Array`](array::Array) to a specific implementation, you can use +//! For example, an [`Int32Array`](array::Int32Array) represents a nullable array of `i32` //! //! ```rust -//! use arrow::array::{Array, UInt32Array}; -//! let array = UInt32Array::from(vec![Some(1), None, Some(3)]); +//! # use arrow::array::{Array, Int32Array}; +//! let array = Int32Array::from(vec![Some(1), None, Some(3)]); //! assert_eq!(array.len(), 3); //! assert_eq!(array.value(0), 1); //! assert_eq!(array.is_null(1), true); -//! ``` //! -//! To make the array dynamically typed, we wrap it in an [`Arc`](std::sync::Arc): -//! -//! ```rust -//! # use std::sync::Arc; -//! use arrow::datatypes::DataType; -//! use arrow::array::{UInt32Array, ArrayRef}; -//! # let array = UInt32Array::from(vec![Some(1), None, Some(3)]); -//! let array: ArrayRef = Arc::new(array); -//! assert_eq!(array.len(), 3); -//! // array.value() is not available in the dynamically-typed version -//! assert_eq!(array.is_null(1), true); -//! 
assert_eq!(array.data_type(), &DataType::UInt32); +//! let collected: Vec<_> = array.iter().collect(); +//! assert_eq!(collected, vec![Some(1), None, Some(3)]); +//! assert_eq!(array.values(), [1, 0, 3]) //! ``` //! -//! to downcast, use `as_any()`: +//! It is also possible to write generic code. For example, the following is generic over +//! all primitively typed arrays: //! //! ```rust -//! # use std::sync::Arc; -//! # use arrow::array::{UInt32Array, ArrayRef}; -//! # let array = UInt32Array::from(vec![Some(1), None, Some(3)]); -//! # let array: ArrayRef = Arc::new(array); -//! let array = array.as_any().downcast_ref::().unwrap(); -//! assert_eq!(array.value(0), 1); +//! # use std::iter::Sum; +//! # use arrow::array::{Float32Array, PrimitiveArray, TimestampNanosecondArray}; +//! # use arrow::datatypes::ArrowPrimitiveType; +//! # +//! fn sum(array: &PrimitiveArray) -> T::Native +//! where +//! T: ArrowPrimitiveType, +//! T::Native: Sum +//! { +//! array.iter().map(|v| v.unwrap_or_default()).sum() +//! } +//! +//! assert_eq!(sum(&Float32Array::from(vec![1.1, 2.9, 3.])), 7.); +//! assert_eq!(sum(&TimestampNanosecondArray::from(vec![1, 2, 3])), 6); //! ``` //! -//! ## Memory and Buffers +//! For more examples, consult the [`array`] docs. //! -//! Data in [`Array`](array::Array) is stored in [`ArrayData`](array::ArrayData), that in turn -//! is a collection of other [`ArrayData`](array::ArrayData) and [`Buffers`](buffer::Buffer). -//! [`Buffers`](buffer::Buffer) is the central struct that array implementations use keep allocated memory and pointers. -//! The [`MutableBuffer`](buffer::MutableBuffer) is the mutable counter-part of[`Buffer`](buffer::Buffer). -//! These are the lowest abstractions of this crate, and are used throughout the crate to -//! efficiently allocate, write, read and deallocate memory. +//! # Type Erasure / Trait Objects //! -//! ## Field, Schema and RecordBatch +//! It is often the case that code wishes to handle any type of array, without necessarily knowing +//! its concrete type. This use-case is catered for by a combination of [`Array`] +//! and [`DataType`](datatypes::DataType), with the former providing a type-erased container for +//! the array, and the latter identifying the concrete type of array. //! -//! [`Field`](datatypes::Field) is a struct that contains an array's metadata (datatype and whether its values -//! can be null), and a name. [`Schema`](datatypes::Schema) is a vector of fields with optional metadata. -//! Together, they form the basis of a schematic representation of a group of [`Arrays`](array::Array). +//! ```rust +//! # use arrow::array::{Array, Float32Array}; +//! # use arrow::array::StringArray; +//! # use arrow::datatypes::DataType; +//! # +//! fn impl_string(array: &StringArray) {} +//! fn impl_f32(array: &Float32Array) {} +//! +//! fn impl_dyn(array: &dyn Array) { +//! match array.data_type() { +//! DataType::Utf8 => impl_string(array.as_any().downcast_ref().unwrap()), +//! DataType::Float32 => impl_f32(array.as_any().downcast_ref().unwrap()), +//! _ => unimplemented!() +//! } +//! } +//! ``` //! -//! In fact, [`RecordBatch`](record_batch::RecordBatch) is a struct with a [`Schema`](datatypes::Schema) and a vector of -//! [`Array`](array::Array)s, all with the same `len`. A record batch is the highest order struct that this crate currently offers -//! and is broadly used to represent a table where each column in an `Array`. +//! It is also common to want to write a function that returns one of a number of possible +//! 
array implementations. [`ArrayRef`] is a type-alias for [`Arc`](array::Array) +//! which is frequently used for this purpose //! -//! ## Compute +//! ```rust +//! # use std::str::FromStr; +//! # use std::sync::Arc; +//! # use arrow::array::{ArrayRef, Int32Array, PrimitiveArray}; +//! # use arrow::datatypes::{ArrowPrimitiveType, DataType, Int32Type, UInt32Type}; +//! # use arrow::compute::cast; +//! # +//! fn parse_to_primitive<'a, T, I>(iter: I) -> PrimitiveArray +//! where +//! T: ArrowPrimitiveType, +//! I: IntoIterator, +//! { +//! PrimitiveArray::from_iter(iter.into_iter().map(|val| T::Native::from_str(val).ok())) +//! } +//! +//! fn parse_strings<'a, I>(iter: I, to_data_type: DataType) -> ArrayRef +//! where +//! I: IntoIterator, +//! { +//! match to_data_type { +//! DataType::Int32 => Arc::new(parse_to_primitive::(iter)) as _, +//! DataType::UInt32 => Arc::new(parse_to_primitive::(iter)) as _, +//! _ => unimplemented!() +//! } +//! } +//! +//! let array = parse_strings(["1", "2", "3"], DataType::Int32); +//! let integers = array.as_any().downcast_ref::().unwrap(); +//! assert_eq!(integers.values(), [1, 2, 3]) +//! ``` //! -//! This crate offers many operations (called kernels) to operate on `Array`s, that you can find at [compute::kernels]. -//! It has both vertical and horizontal operations, and some of them have an SIMD implementation. +//! # Compute Kernels //! -//! ## Status +//! The [`compute`](compute) module provides optimised implementations of many common operations, +//! for example the `parse_strings` operation above could also be implemented as follows: //! -//! This crate has most of the implementation of the arrow specification. Specifically, it supports the following types: +//! ``` +//! # use std::sync::Arc; +//! # use arrow::error::Result; +//! # use arrow::array::{ArrayRef, StringArray, UInt32Array}; +//! # use arrow::datatypes::DataType; +//! # +//! fn parse_strings<'a, I>(iter: I, to_data_type: &DataType) -> Result +//! where +//! I: IntoIterator, +//! { +//! let array = Arc::new(StringArray::from_iter(iter.into_iter().map(Some))) as _; +//! arrow::compute::cast(&array, to_data_type) +//! } +//! +//! let array = parse_strings(["1", "2", "3"], &DataType::UInt32).unwrap(); +//! let integers = array.as_any().downcast_ref::().unwrap(); +//! assert_eq!(integers.values(), [1, 2, 3]) +//! ``` //! -//! * All arrow primitive types, such as [`Int32Array`](array::UInt8Array), [`BooleanArray`](array::BooleanArray) and [`Float64Array`](array::Float64Array). -//! * All arrow variable length types, such as [`StringArray`](array::StringArray) and [`BinaryArray`](array::BinaryArray) -//! * All composite types such as [`StructArray`](array::StructArray) and [`ListArray`](array::ListArray) -//! * Dictionary types [`DictionaryArray`](array::DictionaryArray) - +//! This module also implements many common vertical operations: //! -//! This crate also implements many common vertical operations: -//! * all mathematical binary operators, such as [`subtract`](compute::kernels::arithmetic::subtract) -//! * all boolean binary operators such as [`equality`](compute::kernels::comparison::eq) +//! * All mathematical binary operators, such as [`subtract`](compute::kernels::arithmetic::subtract) +//! * All boolean binary operators such as [`equality`](compute::kernels::comparison::eq) //! * [`cast`](compute::kernels::cast::cast) //! * [`filter`](compute::kernels::filter::filter) //! * [`take`](compute::kernels::take::take) and [`limit`](compute::kernels::limit::limit) //! 
* [`sort`](compute::kernels::sort::sort) //! * some string operators such as [`substring`](compute::kernels::substring::substring) and [`length`](compute::kernels::length::length) //! -//! as well as some horizontal operations, such as +//! As well as some horizontal operations, such as: //! //! * [`min`](compute::kernels::aggregate::min) and [`max`](compute::kernels::aggregate::max) //! * [`sum`](compute::kernels::aggregate::sum) //! -//! Finally, this crate implements some readers and writers to different formats: +//! # Tabular Representation +//! +//! It is common to want to group one or more columns together into a tabular representation. This +//! is provided by [`RecordBatch`] which combines a [`Schema`](datatypes::Schema) +//! and a corresponding list of [`ArrayRef`]. +//! +//! +//! ``` +//! # use std::sync::Arc; +//! # use arrow::array::{Float32Array, Int32Array}; +//! # use arrow::record_batch::RecordBatch; +//! # +//! let col_1 = Arc::new(Int32Array::from_iter([1, 2, 3])) as _; +//! let col_2 = Arc::new(Float32Array::from_iter([1., 6.3, 4.])) as _; +//! +//! let batch = RecordBatch::try_from_iter([("col1", col_1), ("col_2", col_2)]).unwrap(); +//! ``` +//! +//! # IO +//! +//! This crate provides readers and writers for various formats to/from [`RecordBatch`] +//! +//! * JSON: [`Reader`](json::reader::Reader) and [`Writer`](json::writer::Writer) +//! * CSV: [`Reader`](csv::reader::Reader) and [`Writer`](csv::writer::Writer) +//! * IPC: [`Reader`](ipc::reader::StreamReader) and [`Writer`](ipc::writer::FileWriter) +//! +//! Parquet is published as a [separate crate](https://crates.io/crates/parquet) +//! +//! # Memory and Buffers +//! +//! Advanced users may wish to interact with the underlying buffers of an [`Array`], for example, +//! for FFI or high-performance conversion from other formats. This interface is provided by +//! [`ArrayData`] which stores the [`Buffer`] comprising an [`Array`], and can be accessed +//! with [`Array::data`](array::Array::data) +//! +//! The APIs for constructing [`ArrayData`] come in safe, and unsafe variants, with the former +//! performing extensive, but potentially expensive validation to ensure the buffers are well-formed. +//! +//! An [`ArrayRef`] can be cheaply created from an [`ArrayData`] using [`make_array`], +//! or by using the appropriate [`From`] conversion on the concrete [`Array`] implementation. +//! +//! # Safety and Security +//! +//! Like many crates, this crate makes use of unsafe where prudent. However, it endeavours to be +//! sound. Specifically, **it should not be possible to trigger undefined behaviour using safe APIs.** +//! +//! If you think you have found an instance where this is possible, please file +//! a ticket in our [issue tracker] and it will be triaged and fixed. For more information on +//! arrow's use of unsafe, see [here](https://github.com/apache/arrow-rs/tree/master/arrow#safety). +//! +//! # Higher-level Processing +//! +//! This crate aims to provide reusable, low-level primitives for operating on columnar data. For +//! more sophisticated query processing workloads, consider checking out [DataFusion]. This +//! orchestrates the primitives exported by this crate into an embeddable query engine, with +//! SQL and DataFrame frontends, and heavily influences this crate's roadmap. //! -//! * json: [reader](json::reader::Reader) -//! * csv: [reader](csv::reader::Reader) and [writer](csv::writer::Writer) -//! * ipc: [reader](ipc::reader::StreamReader) and [writer](ipc::writer::FileWriter) +//! 
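As a minimal, hedged sketch of the IO path documented above (this is not part of the diff; it assumes `arrow::json::writer::LineDelimitedWriter` and the new `Writer::write` method introduced in this change, and the column names and values are illustrative only):

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array, StringArray};
use arrow::error::Result;
use arrow::json::writer::LineDelimitedWriter;
use arrow::record_batch::RecordBatch;

// Build a small RecordBatch and serialise it as newline-delimited JSON.
fn to_ndjson() -> Result<Vec<u8>> {
    let ids: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let names: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
    let batch = RecordBatch::try_from_iter(vec![("id", ids), ("name", names)])?;

    let mut buf = Vec::new();
    {
        // `write` consumes a single batch; `write_batches` accepts a slice.
        let mut writer = LineDelimitedWriter::new(&mut buf);
        writer.write(batch)?;
    }
    Ok(buf)
}
```

The inner braces scope the writer so the mutable borrow of `buf` ends before the buffer is returned.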
[`array`]: mod@array +//! [`Array`]: array::Array +//! [`ArrayRef`]: array::ArrayRef +//! [`ArrayData`]: array::ArrayData +//! [`make_array`]: array::make_array +//! [`Buffer`]: buffer::Buffer +//! [`RecordBatch`]: record_batch::RecordBatch +//! [DataFusion]: https://github.com/apache/arrow-datafusion +//! [issue tracker]: https://github.com/apache/arrow-rs/issues //! -//! The parquet implementation is on a [separate crate](https://crates.io/crates/parquet) -#![cfg_attr(feature = "avx512", feature(stdsimd))] -#![cfg_attr(feature = "avx512", feature(repr_simd))] -#![cfg_attr(feature = "avx512", feature(avx512_target_feature))] -#![allow(dead_code)] -#![allow(non_camel_case_types)] #![deny(clippy::redundant_clone)] -#![allow( - // introduced to ignore lint errors when upgrading from 2020-04-22 to 2020-11-14 - clippy::float_equality_without_abs, - clippy::type_complexity, - // upper_case_acronyms lint was introduced in Rust 1.51. - // It is triggered in the ffi module, and ipc::gen, which we have no control over - clippy::upper_case_acronyms, - clippy::vec_init_then_push -)] #![warn(missing_debug_implementations)] pub mod alloc; -mod arch; pub mod array; pub mod bitmap; pub mod buffer; @@ -153,6 +239,7 @@ pub mod csv; pub mod datatypes; pub mod error; pub mod ffi; +pub mod ffi_stream; #[cfg(feature = "ipc")] pub mod ipc; pub mod json; @@ -162,4 +249,3 @@ pub mod record_batch; pub mod temporal_conversions; pub mod tensor; pub mod util; -mod zz_memory_check; diff --git a/arrow/src/pyarrow.rs b/arrow/src/pyarrow.rs index 12d4c0d1f92d..3ae5b3b9987f 100644 --- a/arrow/src/pyarrow.rs +++ b/arrow/src/pyarrow.rs @@ -24,13 +24,16 @@ use std::sync::Arc; use pyo3::ffi::Py_uintptr_t; use pyo3::import_exception; use pyo3::prelude::*; -use pyo3::types::PyList; +use pyo3::types::{PyList, PyTuple}; -use crate::array::{make_array, Array, ArrayData, ArrayRef}; +use crate::array::{Array, ArrayData, ArrayRef}; use crate::datatypes::{DataType, Field, Schema}; use crate::error::ArrowError; use crate::ffi; use crate::ffi::FFI_ArrowSchema; +use crate::ffi_stream::{ + export_reader_into_raw, ArrowArrayStreamReader, FFI_ArrowArrayStream, +}; use crate::record_batch::RecordBatch; import_exception!(pyarrow, ArrowException); @@ -148,16 +151,6 @@ impl PyArrowConvert for ArrayData { } } -impl PyArrowConvert for ArrayRef { - fn from_pyarrow(value: &PyAny) -> PyResult { - Ok(make_array(ArrayData::from_pyarrow(value)?)) - } - - fn to_pyarrow(&self, py: Python) -> PyResult { - self.data().to_pyarrow(py) - } -} - impl PyArrowConvert for T where T: Array + From, @@ -208,6 +201,42 @@ impl PyArrowConvert for RecordBatch { } } +impl PyArrowConvert for ArrowArrayStreamReader { + fn from_pyarrow(value: &PyAny) -> PyResult { + // prepare a pointer to receive the stream struct + let stream = Box::new(FFI_ArrowArrayStream::empty()); + let stream_ptr = Box::into_raw(stream) as *mut FFI_ArrowArrayStream; + + // make the conversion through PyArrow's private API + // this changes the pointer's memory and is thus unsafe. 
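+        // (`stream_ptr` comes from `Box::into_raw` above and is reclaimed with `Box::from_raw` below.)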
+ // In particular, `_export_to_c` can go out of bounds + let args = PyTuple::new(value.py(), &[stream_ptr as Py_uintptr_t]); + value.call_method1("_export_to_c", args)?; + + let stream_reader = + unsafe { ArrowArrayStreamReader::from_raw(stream_ptr).unwrap() }; + + unsafe { + Box::from_raw(stream_ptr); + } + + Ok(stream_reader) + } + + fn to_pyarrow(&self, py: Python) -> PyResult { + let stream = Box::new(FFI_ArrowArrayStream::empty()); + let stream_ptr = Box::into_raw(stream) as *mut FFI_ArrowArrayStream; + + unsafe { export_reader_into_raw(Box::new(self.clone()), stream_ptr) }; + + let module = py.import("pyarrow")?; + let class = module.getattr("RecordBatchReader")?; + let args = PyTuple::new(py, &[stream_ptr as Py_uintptr_t]); + let reader = class.call_method1("_import_from_c", args)?; + Ok(PyObject::from(reader)) + } +} + macro_rules! add_conversion { ($typ:ty) => { impl<'source> FromPyObject<'source> for $typ { @@ -229,3 +258,4 @@ add_conversion!(Field); add_conversion!(Schema); add_conversion!(ArrayData); add_conversion!(RecordBatch); +add_conversion!(ArrowArrayStreamReader); diff --git a/arrow/src/record_batch.rs b/arrow/src/record_batch.rs index b441f6cf295e..ae8fae58f1af 100644 --- a/arrow/src/record_batch.rs +++ b/arrow/src/record_batch.rs @@ -41,6 +41,11 @@ use crate::error::{ArrowError, Result}; pub struct RecordBatch { schema: SchemaRef, columns: Vec>, + + /// The number of rows in this RecordBatch + /// + /// This is stored separately from the columns to handle the case of no columns + row_count: usize, } impl RecordBatch { @@ -77,8 +82,7 @@ impl RecordBatch { /// ``` pub fn try_new(schema: SchemaRef, columns: Vec) -> Result { let options = RecordBatchOptions::default(); - Self::validate_new_batch(&schema, columns.as_slice(), &options)?; - Ok(RecordBatch { schema, columns }) + Self::try_new_impl(schema, columns, &options) } /// Creates a `RecordBatch` from a schema and columns, with additional options, @@ -90,8 +94,7 @@ impl RecordBatch { columns: Vec, options: &RecordBatchOptions, ) -> Result { - Self::validate_new_batch(&schema, columns.as_slice(), options)?; - Ok(RecordBatch { schema, columns }) + Self::try_new_impl(schema, columns, options) } /// Creates a new empty [`RecordBatch`]. @@ -101,23 +104,21 @@ impl RecordBatch { .iter() .map(|field| new_empty_array(field.data_type())) .collect(); - RecordBatch { schema, columns } + + RecordBatch { + schema, + columns, + row_count: 0, + } } /// Validate the schema and columns using [`RecordBatchOptions`]. Returns an error - /// if any validation check fails. 
- fn validate_new_batch( - schema: &SchemaRef, - columns: &[ArrayRef], + /// if any validation check fails, otherwise returns the created [`Self`] + fn try_new_impl( + schema: SchemaRef, + columns: Vec, options: &RecordBatchOptions, - ) -> Result<()> { - // check that there are some columns - if columns.is_empty() { - return Err(ArrowError::InvalidArgumentError( - "at least one column must be defined to create a record batch" - .to_string(), - )); - } + ) -> Result { // check that number of fields in schema match column length if schema.fields().len() != columns.len() { return Err(ArrowError::InvalidArgumentError(format!( @@ -126,48 +127,60 @@ impl RecordBatch { schema.fields().len(), ))); } - // check that all columns have the same row count, and match the schema - let len = columns[0].data().len(); - - // This is a bit repetitive, but it is better to check the condition outside the loop - if options.match_field_names { - for (i, column) in columns.iter().enumerate() { - if column.len() != len { - return Err(ArrowError::InvalidArgumentError( - "all columns in a record batch must have the same length" - .to_string(), - )); - } - if column.data_type() != schema.field(i).data_type() { - return Err(ArrowError::InvalidArgumentError(format!( - "column types must match schema types, expected {:?} but found {:?} at column index {}", - schema.field(i).data_type(), - column.data_type(), - i))); + + // check that all columns have the same row count + let row_count = options + .row_count + .or_else(|| columns.first().map(|col| col.len())) + .ok_or_else(|| { + ArrowError::InvalidArgumentError( + "must either specify a row count or at least one column".to_string(), + ) + })?; + + if columns.iter().any(|c| c.len() != row_count) { + let err = match options.row_count { + Some(_) => { + "all columns in a record batch must have the specified row count" } + None => "all columns in a record batch must have the same length", + }; + return Err(ArrowError::InvalidArgumentError(err.to_string())); + } + + // function for comparing column type and field type + // return true if 2 types are not matched + let type_not_match = if options.match_field_names { + |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| { + col_type != field_type } } else { - for (i, column) in columns.iter().enumerate() { - if column.len() != len { - return Err(ArrowError::InvalidArgumentError( - "all columns in a record batch must have the same length" - .to_string(), - )); - } - if !column - .data_type() - .equals_datatype(schema.field(i).data_type()) - { - return Err(ArrowError::InvalidArgumentError(format!( - "column types must match schema types, expected {:?} but found {:?} at column index {}", - schema.field(i).data_type(), - column.data_type(), - i))); - } + |(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| { + !col_type.equals_datatype(field_type) } + }; + + // check that all columns match the schema + let not_match = columns + .iter() + .zip(schema.fields().iter()) + .map(|(col, field)| (col.data_type(), field.data_type())) + .enumerate() + .find(type_not_match); + + if let Some((i, (col_type, field_type))) = not_match { + return Err(ArrowError::InvalidArgumentError(format!( + "column types must match schema types, expected {:?} but found {:?} at column index {}", + field_type, + col_type, + i))); } - Ok(()) + Ok(RecordBatch { + schema, + columns, + row_count, + }) } /// Returns the [`Schema`](crate::datatypes::Schema) of the record batch. 
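A small sketch (not part of the diff) of how the reworked constructor behaves for a schema with no columns, assuming the `row_count` option introduced in this change; the row count used here is arbitrary:

```rust
use std::sync::Arc;

use arrow::datatypes::Schema;
use arrow::error::Result;
use arrow::record_batch::{RecordBatch, RecordBatchOptions};

// Build a RecordBatch that has no columns but still reports three rows.
fn empty_columns_batch() -> Result<RecordBatch> {
    let schema = Arc::new(Schema::new(vec![]));

    // `RecordBatchOptions` is marked #[non_exhaustive], so downstream code
    // starts from its Default and sets the public `row_count` field.
    let mut options = RecordBatchOptions::default();
    options.row_count = Some(3);

    RecordBatch::try_new_with_options(schema, vec![], &options)
}

fn main() -> Result<()> {
    let batch = empty_columns_batch()?;
    assert_eq!(batch.num_columns(), 0);
    assert_eq!(batch.num_rows(), 3);
    Ok(())
}
```

Without `row_count`, a zero-column `try_new` now fails with "must either specify a row count or at least one column", as exercised by the `test_no_column_record_batch` test below.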
@@ -175,6 +188,25 @@ impl RecordBatch { self.schema.clone() } + /// Projects the schema onto the specified columns + pub fn project(&self, indices: &[usize]) -> Result { + let projected_schema = self.schema.project(indices)?; + let batch_fields = indices + .iter() + .map(|f| { + self.columns.get(*f).cloned().ok_or_else(|| { + ArrowError::SchemaError(format!( + "project index {} out of bounds, max field {}", + f, + self.columns.len() + )) + }) + }) + .collect::>>()?; + + RecordBatch::try_new(SchemaRef::new(projected_schema), batch_fields) + } + /// Returns the number of columns in the record batch. /// /// # Example @@ -203,10 +235,6 @@ impl RecordBatch { /// Returns the number of rows in each column. /// - /// # Panics - /// - /// Panics if the `RecordBatch` contains no columns. - /// /// # Example /// /// ``` @@ -228,7 +256,7 @@ impl RecordBatch { /// # } /// ``` pub fn num_rows(&self) -> usize { - self.columns[0].data().len() + self.row_count } /// Get a reference to a column's array by index. @@ -252,10 +280,6 @@ impl RecordBatch { /// /// Panics if `offset` with `length` is greater than column length. pub fn slice(&self, offset: usize, length: usize) -> RecordBatch { - if self.schema.fields().is_empty() { - assert!((offset + length) == 0); - return RecordBatch::new_empty(self.schema.clone()); - } assert!((offset + length) <= self.num_rows()); let columns = self @@ -267,6 +291,7 @@ impl RecordBatch { Self { schema: self.schema.clone(), columns, + row_count: length, } } @@ -387,15 +412,20 @@ impl RecordBatch { /// Options that control the behaviour used when creating a [`RecordBatch`]. #[derive(Debug)] +#[non_exhaustive] pub struct RecordBatchOptions { /// Match field names of structs and lists. If set to `true`, the names must match. pub match_field_names: bool, + + /// Optional row count, useful for specifying a row count for a RecordBatch with no columns + pub row_count: Option, } impl Default for RecordBatchOptions { fn default() -> Self { Self { match_field_names: true, + row_count: None, } } } @@ -411,6 +441,7 @@ impl From<&StructArray> for RecordBatch { let columns = struct_array.boxed_fields.clone(); RecordBatch { schema: Arc::new(schema), + row_count: struct_array.len(), columns, } } else { @@ -517,7 +548,7 @@ mod tests { } #[test] - #[should_panic(expected = "assertion failed: (offset + length) == 0")] + #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")] fn create_record_batch_slice_empty_batch() { let schema = Schema::new(vec![]); @@ -582,7 +613,7 @@ mod tests { let a = Int64Array::from(vec![1, 2, 3, 4, 5]); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]); - assert!(!batch.is_ok()); + assert!(batch.is_err()); } #[test] @@ -629,6 +660,7 @@ mod tests { // creating the batch without field name validation should pass let options = RecordBatchOptions { match_field_names: false, + row_count: None, }; let batch = RecordBatch::try_new_with_options(schema, vec![a], &options); assert!(batch.is_ok()); @@ -643,7 +675,7 @@ mod tests { let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]); - assert!(!batch.is_ok()); + assert!(batch.is_err()); } #[test] @@ -900,4 +932,51 @@ mod tests { assert_ne!(batch1, batch2); } + + #[test] + fn project() { + let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])); + let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"])); + let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"])); + + let record_batch = 
RecordBatch::try_from_iter(vec![ + ("a", a.clone()), + ("b", b.clone()), + ("c", c.clone()), + ]) + .expect("valid conversion"); + + let expected = RecordBatch::try_from_iter(vec![("a", a), ("c", c)]) + .expect("valid conversion"); + + assert_eq!(expected, record_batch.project(&[0, 2]).unwrap()); + } + + #[test] + fn test_no_column_record_batch() { + let schema = Arc::new(Schema::new(vec![])); + + let err = RecordBatch::try_new(schema.clone(), vec![]).unwrap_err(); + assert!(err + .to_string() + .contains("must either specify a row count or at least one column")); + + let options = RecordBatchOptions { + row_count: Some(10), + ..Default::default() + }; + + let ok = + RecordBatch::try_new_with_options(schema.clone(), vec![], &options).unwrap(); + assert_eq!(ok.num_rows(), 10); + + let a = ok.slice(2, 5); + assert_eq!(a.num_rows(), 5); + + let b = ok.slice(5, 0); + assert_eq!(b.num_rows(), 0); + + assert_ne!(a, b); + assert_eq!(b, RecordBatch::new_empty(schema)) + } } diff --git a/arrow/src/util/bench_util.rs b/arrow/src/util/bench_util.rs index 40340336882b..395f3702d57a 100644 --- a/arrow/src/util/bench_util.rs +++ b/arrow/src/util/bench_util.rs @@ -91,9 +91,18 @@ where } /// Creates an random (but fixed-seeded) array of a given size and null density -pub fn create_string_array( +pub fn create_string_array( size: usize, null_density: f32, +) -> GenericStringArray { + create_string_array_with_len(size, null_density, 4) +} + +/// Creates a random (but fixed-seeded) array of a given size, null density and length +pub fn create_string_array_with_len( + size: usize, + null_density: f32, + str_len: usize, ) -> GenericStringArray { let rng = &mut seedable_rng(); @@ -102,7 +111,7 @@ pub fn create_string_array( if rng.gen::() < null_density { None } else { - let value = rng.sample_iter(&Alphanumeric).take(4).collect(); + let value = rng.sample_iter(&Alphanumeric).take(str_len).collect(); let value = String::from_utf8(value).unwrap(); Some(value) } @@ -110,8 +119,31 @@ pub fn create_string_array( .collect() } +/// Creates an random (but fixed-seeded) array of a given size and null density +/// consisting of random 4 character alphanumeric strings +pub fn create_string_dict_array( + size: usize, + null_density: f32, +) -> DictionaryArray { + let rng = &mut seedable_rng(); + + let data: Vec<_> = (0..size) + .map(|_| { + if rng.gen::() < null_density { + None + } else { + let value = rng.sample_iter(&Alphanumeric).take(4).collect(); + let value = String::from_utf8(value).unwrap(); + Some(value) + } + }) + .collect(); + + data.iter().map(|x| x.as_deref()).collect() +} + /// Creates an random (but fixed-seeded) binary array of a given size and null density -pub fn create_binary_array( +pub fn create_binary_array( size: usize, null_density: f32, ) -> GenericBinaryArray { diff --git a/arrow/src/util/bit_chunk_iterator.rs b/arrow/src/util/bit_chunk_iterator.rs index ea9280cee3f1..f0127ed2267f 100644 --- a/arrow/src/util/bit_chunk_iterator.rs +++ b/arrow/src/util/bit_chunk_iterator.rs @@ -14,9 +14,197 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. + +//! 
Types for iterating over bitmasks in 64-bit chunks + use crate::util::bit_util::ceil; use std::fmt::Debug; +/// Iterates over an arbitrarily aligned byte buffer +/// +/// Yields an iterator of aligned u64, along with the leading and trailing +/// u64 necessary to align the buffer to a 8-byte boundary +/// +/// This is unlike [`BitChunkIterator`] which only exposes a trailing u64, +/// and consequently has to perform more work for each read +#[derive(Debug)] +pub struct UnalignedBitChunk<'a> { + lead_padding: usize, + trailing_padding: usize, + + prefix: Option, + chunks: &'a [u64], + suffix: Option, +} + +impl<'a> UnalignedBitChunk<'a> { + /// Create a from a byte array, and and an offset and length in bits + pub fn new(buffer: &'a [u8], offset: usize, len: usize) -> Self { + if len == 0 { + return Self { + lead_padding: 0, + trailing_padding: 0, + prefix: None, + chunks: &[], + suffix: None, + }; + } + + let byte_offset = offset / 8; + let offset_padding = offset % 8; + + let bytes_len = (len + offset_padding + 7) / 8; + let buffer = &buffer[byte_offset..byte_offset + bytes_len]; + + let prefix_mask = compute_prefix_mask(offset_padding); + + // If less than 8 bytes, read into prefix + if buffer.len() <= 8 { + let (suffix_mask, trailing_padding) = + compute_suffix_mask(len, offset_padding); + let prefix = read_u64(buffer) & suffix_mask & prefix_mask; + + return Self { + lead_padding: offset_padding, + trailing_padding, + prefix: Some(prefix), + chunks: &[], + suffix: None, + }; + } + + // If less than 16 bytes, read into prefix and suffix + if buffer.len() <= 16 { + let (suffix_mask, trailing_padding) = + compute_suffix_mask(len, offset_padding); + let prefix = read_u64(&buffer[..8]) & prefix_mask; + let suffix = read_u64(&buffer[8..]) & suffix_mask; + + return Self { + lead_padding: offset_padding, + trailing_padding, + prefix: Some(prefix), + chunks: &[], + suffix: Some(suffix), + }; + } + + // Read into prefix and suffix as needed + let (prefix, mut chunks, suffix) = unsafe { buffer.align_to::() }; + assert!( + prefix.len() < 8 && suffix.len() < 8, + "align_to did not return largest possible aligned slice" + ); + + let (alignment_padding, prefix) = match (offset_padding, prefix.is_empty()) { + (0, true) => (0, None), + (_, true) => { + let prefix = chunks[0] & prefix_mask; + chunks = &chunks[1..]; + (0, Some(prefix)) + } + (_, false) => { + let alignment_padding = (8 - prefix.len()) * 8; + + let prefix = (read_u64(prefix) & prefix_mask) << alignment_padding; + (alignment_padding, Some(prefix)) + } + }; + + let lead_padding = offset_padding + alignment_padding; + let (suffix_mask, trailing_padding) = compute_suffix_mask(len, lead_padding); + + let suffix = match (trailing_padding, suffix.is_empty()) { + (0, _) => None, + (_, true) => { + let suffix = chunks[chunks.len() - 1] & suffix_mask; + chunks = &chunks[..chunks.len() - 1]; + Some(suffix) + } + (_, false) => Some(read_u64(suffix) & suffix_mask), + }; + + Self { + lead_padding, + trailing_padding, + prefix, + chunks, + suffix, + } + } + + pub fn lead_padding(&self) -> usize { + self.lead_padding + } + + pub fn trailing_padding(&self) -> usize { + self.trailing_padding + } + + pub fn prefix(&self) -> Option { + self.prefix + } + + pub fn suffix(&self) -> Option { + self.suffix + } + + pub fn chunks(&self) -> &'a [u64] { + self.chunks + } + + pub(crate) fn iter(&self) -> UnalignedBitChunkIterator<'a> { + self.prefix + .into_iter() + .chain(self.chunks.iter().cloned()) + .chain(self.suffix.into_iter()) + } + + /// Counts the number of 
ones + pub fn count_ones(&self) -> usize { + self.iter().map(|x| x.count_ones() as usize).sum() + } +} + +pub(crate) type UnalignedBitChunkIterator<'a> = std::iter::Chain< + std::iter::Chain< + std::option::IntoIter, + std::iter::Cloned>, + >, + std::option::IntoIter, +>; + +#[inline] +fn read_u64(input: &[u8]) -> u64 { + let len = input.len().min(8); + let mut buf = [0_u8; 8]; + (&mut buf[..len]).copy_from_slice(input); + u64::from_le_bytes(buf) +} + +#[inline] +fn compute_prefix_mask(lead_padding: usize) -> u64 { + !((1 << lead_padding) - 1) +} + +#[inline] +fn compute_suffix_mask(len: usize, lead_padding: usize) -> (u64, usize) { + let trailing_bits = (len + lead_padding) % 64; + + if trailing_bits == 0 { + return (u64::MAX, 0); + } + + let trailing_padding = 64 - trailing_bits; + let suffix_mask = (1 << trailing_bits) - 1; + (suffix_mask, trailing_padding) +} + +/// Iterates over an arbitrarily aligned byte buffer +/// +/// Yields an iterator of u64, and a remainder. The first byte in the buffer +/// will be the least significant byte in output u64 +/// #[derive(Debug)] pub struct BitChunks<'a> { buffer: &'a [u8], @@ -174,7 +362,10 @@ impl ExactSizeIterator for BitChunkIterator<'_> { #[cfg(test)] mod tests { + use rand::prelude::*; + use crate::buffer::Buffer; + use crate::util::bit_chunk_iterator::UnalignedBitChunk; #[test] fn test_iter_aligned() { @@ -272,4 +463,199 @@ mod tests { assert_eq!(u64::MAX, bitchunks.iter().last().unwrap()); assert_eq!(0x7F, bitchunks.remainder_bits()); } + + #[test] + #[allow(clippy::assertions_on_constants)] + fn test_unaligned_bit_chunk_iterator() { + let buffer = Buffer::from(&[0xFF; 5]); + let unaligned = UnalignedBitChunk::new(buffer.as_slice(), 0, 40); + + assert!(unaligned.chunks().is_empty()); // Less than 128 elements + assert_eq!(unaligned.lead_padding(), 0); + assert_eq!(unaligned.trailing_padding(), 24); + // 24x 1 bit then 40x 0 bits + assert_eq!( + unaligned.prefix(), + Some(0b0000000000000000000000001111111111111111111111111111111111111111) + ); + assert_eq!(unaligned.suffix(), None); + + let buffer = buffer.slice(1); + let unaligned = UnalignedBitChunk::new(buffer.as_slice(), 0, 32); + + assert!(unaligned.chunks().is_empty()); // Less than 128 elements + assert_eq!(unaligned.lead_padding(), 0); + assert_eq!(unaligned.trailing_padding(), 32); + // 32x 1 bit then 32x 0 bits + assert_eq!( + unaligned.prefix(), + Some(0b0000000000000000000000000000000011111111111111111111111111111111) + ); + assert_eq!(unaligned.suffix(), None); + + let unaligned = UnalignedBitChunk::new(buffer.as_slice(), 5, 27); + + assert!(unaligned.chunks().is_empty()); // Less than 128 elements + assert_eq!(unaligned.lead_padding(), 5); // 5 % 8 == 5 + assert_eq!(unaligned.trailing_padding(), 32); + // 5x 0 bit, 27x 1 bit then 32x 0 bits + assert_eq!( + unaligned.prefix(), + Some(0b0000000000000000000000000000000011111111111111111111111111100000) + ); + assert_eq!(unaligned.suffix(), None); + + let unaligned = UnalignedBitChunk::new(buffer.as_slice(), 12, 20); + + assert!(unaligned.chunks().is_empty()); // Less than 128 elements + assert_eq!(unaligned.lead_padding(), 4); // 12 % 8 == 4 + assert_eq!(unaligned.trailing_padding(), 40); + // 4x 0 bit, 20x 1 bit then 40x 0 bits + assert_eq!( + unaligned.prefix(), + Some(0b0000000000000000000000000000000000000000111111111111111111110000) + ); + assert_eq!(unaligned.suffix(), None); + + let buffer = Buffer::from(&[0xFF; 14]); + + // Verify buffer alignment + let (prefix, aligned, suffix) = unsafe { buffer.as_slice().align_to::() 
}; + assert_eq!(prefix.len(), 0); + assert_eq!(aligned.len(), 1); + assert_eq!(suffix.len(), 6); + + let unaligned = UnalignedBitChunk::new(buffer.as_slice(), 0, 112); + + assert!(unaligned.chunks().is_empty()); // Less than 128 elements + assert_eq!(unaligned.lead_padding(), 0); // No offset and buffer aligned on 64-bit boundary + assert_eq!(unaligned.trailing_padding(), 16); + assert_eq!(unaligned.prefix(), Some(u64::MAX)); + assert_eq!(unaligned.suffix(), Some((1 << 48) - 1)); + + let buffer = Buffer::from(&[0xFF; 16]); + + // Verify buffer alignment + let (prefix, aligned, suffix) = unsafe { buffer.as_slice().align_to::() }; + assert_eq!(prefix.len(), 0); + assert_eq!(aligned.len(), 2); + assert_eq!(suffix.len(), 0); + + let unaligned = UnalignedBitChunk::new(buffer.as_slice(), 0, 128); + + assert_eq!(unaligned.prefix(), Some(u64::MAX)); + assert_eq!(unaligned.suffix(), Some(u64::MAX)); + assert!(unaligned.chunks().is_empty()); // Exactly 128 elements + + let buffer = Buffer::from(&[0xFF; 64]); + + // Verify buffer alignment + let (prefix, aligned, suffix) = unsafe { buffer.as_slice().align_to::() }; + assert_eq!(prefix.len(), 0); + assert_eq!(aligned.len(), 8); + assert_eq!(suffix.len(), 0); + + let unaligned = UnalignedBitChunk::new(buffer.as_slice(), 0, 512); + + // Buffer is completely aligned and larger than 128 elements -> all in chunks array + assert_eq!(unaligned.suffix(), None); + assert_eq!(unaligned.prefix(), None); + assert_eq!(unaligned.chunks(), [u64::MAX; 8].as_slice()); + assert_eq!(unaligned.lead_padding(), 0); + assert_eq!(unaligned.trailing_padding(), 0); + + let buffer = buffer.slice(1); // Offset buffer 1 byte off 64-bit alignment + + // Verify buffer alignment + let (prefix, aligned, suffix) = unsafe { buffer.as_slice().align_to::() }; + assert_eq!(prefix.len(), 7); + assert_eq!(aligned.len(), 7); + assert_eq!(suffix.len(), 0); + + let unaligned = UnalignedBitChunk::new(buffer.as_slice(), 0, 504); + + // Need a prefix with 1 byte of lead padding to bring the buffer into alignment + assert_eq!(unaligned.prefix(), Some(u64::MAX - 0xFF)); + assert_eq!(unaligned.suffix(), None); + assert_eq!(unaligned.chunks(), [u64::MAX; 7].as_slice()); + assert_eq!(unaligned.lead_padding(), 8); + assert_eq!(unaligned.trailing_padding(), 0); + + let unaligned = UnalignedBitChunk::new(buffer.as_slice(), 17, 300); + + // Out of 64-bit alignment by 8 bits from buffer, and 17 bits from provided offset + // => need 8 + 17 = 25 bits of lead padding + 39 bits in prefix + // + // This leaves 300 - 17 = 261 bits remaining + // => 4x 64-bit aligned 64-bit chunks + 5 remaining bits + // => trailing padding of 59 bits + assert_eq!(unaligned.lead_padding(), 25); + assert_eq!(unaligned.trailing_padding(), 59); + assert_eq!(unaligned.prefix(), Some(u64::MAX - (1 << 25) + 1)); + assert_eq!(unaligned.suffix(), Some(0b11111)); + assert_eq!(unaligned.chunks(), [u64::MAX; 4].as_slice()); + + let unaligned = UnalignedBitChunk::new(buffer.as_slice(), 17, 0); + + assert_eq!(unaligned.prefix(), None); + assert_eq!(unaligned.suffix(), None); + assert!(unaligned.chunks().is_empty()); + assert_eq!(unaligned.lead_padding(), 0); + assert_eq!(unaligned.trailing_padding(), 0); + + let unaligned = UnalignedBitChunk::new(buffer.as_slice(), 17, 1); + + assert_eq!(unaligned.prefix(), Some(2)); + assert_eq!(unaligned.suffix(), None); + assert!(unaligned.chunks().is_empty()); + assert_eq!(unaligned.lead_padding(), 1); + assert_eq!(unaligned.trailing_padding(), 62); + } + + #[test] + #[cfg_attr(miri, ignore)] + fn 
fuzz_unaligned_bit_chunk_iterator() { + let mut rng = thread_rng(); + + for _ in 0..100 { + let mask_len = rng.gen_range(0..1024); + let bools: Vec<_> = std::iter::from_fn(|| Some(rng.gen())) + .take(mask_len) + .collect(); + + let buffer = Buffer::from_iter(bools.iter().cloned()); + + let max_offset = 64.min(mask_len); + let offset = rng.gen::().checked_rem(max_offset).unwrap_or(0); + + let max_truncate = 128.min(mask_len - offset); + let truncate = rng.gen::().checked_rem(max_truncate).unwrap_or(0); + + let unaligned = UnalignedBitChunk::new( + buffer.as_slice(), + offset, + mask_len - offset - truncate, + ); + + let bool_slice = &bools[offset..mask_len - truncate]; + + let count = unaligned.count_ones(); + let expected_count = bool_slice.iter().filter(|x| **x).count(); + + assert_eq!(count, expected_count); + + let collected: Vec = unaligned.iter().collect(); + + let get_bit = |idx: usize| -> bool { + let padded_index = idx + unaligned.lead_padding(); + let byte_idx = padded_index / 64; + let bit_idx = padded_index % 64; + (collected[byte_idx] & (1 << bit_idx)) != 0 + }; + + for (idx, b) in bool_slice.iter().enumerate() { + assert_eq!(*b, get_bit(idx)) + } + } + } } diff --git a/arrow/src/util/bit_iterator.rs b/arrow/src/util/bit_iterator.rs new file mode 100644 index 000000000000..bba9dac60a4b --- /dev/null +++ b/arrow/src/util/bit_iterator.rs @@ -0,0 +1,160 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::util::bit_chunk_iterator::{UnalignedBitChunk, UnalignedBitChunkIterator}; + +/// Iterator of contiguous ranges of set bits within a provided packed bitmask +/// +/// Returns `(usize, usize)` each representing an interval where the corresponding +/// bits in the provides mask are set +/// +#[derive(Debug)] +pub struct BitSliceIterator<'a> { + iter: UnalignedBitChunkIterator<'a>, + len: usize, + current_offset: i64, + current_chunk: u64, +} + +impl<'a> BitSliceIterator<'a> { + /// Create a new [`BitSliceIterator`] from the provide `buffer`, + /// and `offset` and `len` in bits + pub fn new(buffer: &'a [u8], offset: usize, len: usize) -> Self { + let chunk = UnalignedBitChunk::new(buffer, offset, len); + let mut iter = chunk.iter(); + + let current_offset = -(chunk.lead_padding() as i64); + let current_chunk = iter.next().unwrap_or(0); + + Self { + iter, + len, + current_offset, + current_chunk, + } + } + + /// Returns `Some((chunk_offset, bit_offset))` for the next chunk that has at + /// least one bit set, or None if there is no such chunk. 
+ /// + /// Where `chunk_offset` is the bit offset to the current `u64` chunk + /// and `bit_offset` is the offset of the first `1` bit in that chunk + fn advance_to_set_bit(&mut self) -> Option<(i64, u32)> { + loop { + if self.current_chunk != 0 { + // Find the index of the first 1 + let bit_pos = self.current_chunk.trailing_zeros(); + return Some((self.current_offset, bit_pos)); + } + + self.current_chunk = self.iter.next()?; + self.current_offset += 64; + } + } +} + +impl<'a> Iterator for BitSliceIterator<'a> { + type Item = (usize, usize); + + fn next(&mut self) -> Option { + // Used as termination condition + if self.len == 0 { + return None; + } + + let (start_chunk, start_bit) = self.advance_to_set_bit()?; + + // Set bits up to start + self.current_chunk |= (1 << start_bit) - 1; + + loop { + if self.current_chunk != u64::MAX { + // Find the index of the first 0 + let end_bit = self.current_chunk.trailing_ones(); + + // Zero out up to end_bit + self.current_chunk &= !((1 << end_bit) - 1); + + return Some(( + (start_chunk + start_bit as i64) as usize, + (self.current_offset + end_bit as i64) as usize, + )); + } + + match self.iter.next() { + Some(next) => { + self.current_chunk = next; + self.current_offset += 64; + } + None => { + return Some(( + (start_chunk + start_bit as i64) as usize, + std::mem::replace(&mut self.len, 0), + )); + } + } + } + } +} + +/// An iterator of `usize` whose index in a provided bitmask is true +/// +/// This provides the best performance on most masks, apart from those which contain +/// large runs and therefore favour [`BitSliceIterator`] +#[derive(Debug)] +pub struct BitIndexIterator<'a> { + current_chunk: u64, + chunk_offset: i64, + iter: UnalignedBitChunkIterator<'a>, +} + +impl<'a> BitIndexIterator<'a> { + /// Create a new [`BitIndexIterator`] from the provide `buffer`, + /// and `offset` and `len` in bits + pub fn new(buffer: &'a [u8], offset: usize, len: usize) -> Self { + let chunks = UnalignedBitChunk::new(buffer, offset, len); + let mut iter = chunks.iter(); + + let current_chunk = iter.next().unwrap_or(0); + let chunk_offset = -(chunks.lead_padding() as i64); + + Self { + current_chunk, + chunk_offset, + iter, + } + } +} + +impl<'a> Iterator for BitIndexIterator<'a> { + type Item = usize; + + fn next(&mut self) -> Option { + loop { + if self.current_chunk != 0 { + let bit_pos = self.current_chunk.trailing_zeros(); + self.current_chunk ^= 1 << bit_pos; + return Some((self.chunk_offset + bit_pos as i64) as usize); + } + + self.current_chunk = self.iter.next()?; + self.chunk_offset += 64; + } + } +} + +// Note: tests located in filter module diff --git a/arrow/src/util/bit_mask.rs b/arrow/src/util/bit_mask.rs new file mode 100644 index 000000000000..501a46173fac --- /dev/null +++ b/arrow/src/util/bit_mask.rs @@ -0,0 +1,191 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utils for working with packed bit masks + +use crate::util::bit_chunk_iterator::BitChunks; +use crate::util::bit_util::{ceil, get_bit, set_bit}; + +/// Sets all bits on `write_data` in the range `[offset_write..offset_write+len]` to be equal to the +/// bits in `data` in the range `[offset_read..offset_read+len]` +/// returns the number of `0` bits `data[offset_read..offset_read+len]` +pub fn set_bits( + write_data: &mut [u8], + data: &[u8], + offset_write: usize, + offset_read: usize, + len: usize, +) -> usize { + let mut null_count = 0; + + let mut bits_to_align = offset_write % 8; + if bits_to_align > 0 { + bits_to_align = std::cmp::min(len, 8 - bits_to_align); + } + let mut write_byte_index = ceil(offset_write + bits_to_align, 8); + + // Set full bytes provided by bit chunk iterator (which iterates in 64 bits at a time) + let chunks = BitChunks::new(data, offset_read + bits_to_align, len - bits_to_align); + chunks.iter().for_each(|chunk| { + null_count += chunk.count_zeros(); + chunk.to_le_bytes().iter().for_each(|b| { + write_data[write_byte_index] = *b; + write_byte_index += 1; + }) + }); + + // Set individual bits both to align write_data to a byte offset and the remainder bits not covered by the bit chunk iterator + let remainder_offset = len - chunks.remainder_len(); + (0..bits_to_align) + .chain(remainder_offset..len) + .for_each(|i| { + if get_bit(data, offset_read + i) { + set_bit(write_data, offset_write + i); + } else { + null_count += 1; + } + }); + + null_count as usize +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_set_bits_aligned() { + let mut destination: Vec = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; + let source: &[u8] = &[ + 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, + 0b11100111, 0b10100101, + ]; + + let destination_offset = 8; + let source_offset = 0; + + let len = 64; + + let expected_data: &[u8] = &[ + 0, 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, + 0b11100111, 0b10100101, 0, + ]; + let expected_null_count = 24; + let result = set_bits( + destination.as_mut_slice(), + source, + destination_offset, + source_offset, + len, + ); + + assert_eq!(destination, expected_data); + assert_eq!(result, expected_null_count); + } + + #[test] + fn test_set_bits_unaligned_destination_start() { + let mut destination: Vec = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; + let source: &[u8] = &[ + 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, + 0b11100111, 0b10100101, + ]; + + let destination_offset = 3; + let source_offset = 0; + + let len = 64; + + let expected_data: &[u8] = &[ + 0b00111000, 0b00101111, 0b11001101, 0b11011100, 0b01011110, 0b00011111, + 0b00111110, 0b00101111, 0b00000101, 0b00000000, + ]; + let expected_null_count = 24; + let result = set_bits( + destination.as_mut_slice(), + source, + destination_offset, + source_offset, + len, + ); + + assert_eq!(destination, expected_data); + assert_eq!(result, expected_null_count); + } + + #[test] + fn test_set_bits_unaligned_destination_end() { + let mut destination: Vec = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; + let source: &[u8] = &[ + 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, + 0b11100111, 0b10100101, + ]; + + let destination_offset = 8; + let source_offset = 0; + + let len = 62; + + let expected_data: &[u8] = &[ + 0, 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, + 
0b11100111, 0b00100101, 0, + ]; + let expected_null_count = 23; + let result = set_bits( + destination.as_mut_slice(), + source, + destination_offset, + source_offset, + len, + ); + + assert_eq!(destination, expected_data); + assert_eq!(result, expected_null_count); + } + + #[test] + fn test_set_bits_unaligned() { + let mut destination: Vec = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; + let source: &[u8] = &[ + 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, + 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, + 0b11100111, 0b10100101, 0b10011001, 0b11011011, 0b11101011, 0b11000011, + ]; + + let destination_offset = 3; + let source_offset = 5; + + let len = 95; + + let expected_data: &[u8] = &[ + 0b01111000, 0b01101001, 0b11100110, 0b11110110, 0b11111010, 0b11110000, + 0b01111001, 0b01101001, 0b11100110, 0b11110110, 0b11111010, 0b11110000, + 0b00000001, + ]; + let expected_null_count = 35; + let result = set_bits( + destination.as_mut_slice(), + source, + destination_offset, + source_offset, + len, + ); + + assert_eq!(destination, expected_data); + assert_eq!(result, expected_null_count); + } +} diff --git a/arrow/src/util/bit_util.rs b/arrow/src/util/bit_util.rs index f643d593fabd..5752c5df972e 100644 --- a/arrow/src/util/bit_util.rs +++ b/arrow/src/util/bit_util.rs @@ -17,6 +17,7 @@ //! Utils for working with bits +use num::Integer; #[cfg(feature = "simd")] use packed_simd::u8x64; @@ -99,12 +100,9 @@ pub unsafe fn unset_bit_raw(data: *mut u8, i: usize) { /// Returns the ceil of `value`/`divisor` #[inline] pub fn ceil(value: usize, divisor: usize) -> usize { - let (quot, rem) = (value / divisor, value % divisor); - if rem > 0 && divisor > 0 { - quot + 1 - } else { - quot - } + // Rewrite as `value.div_ceil(&divisor)` after + // https://github.com/rust-lang/rust/issues/88581 is merged. + Integer::div_ceil(&value, &divisor) } /// Performs SIMD bitwise binary operations. diff --git a/arrow/src/util/data_gen.rs b/arrow/src/util/data_gen.rs index 35b65ef303db..21b8ee8c9fd1 100644 --- a/arrow/src/util/data_gen.rs +++ b/arrow/src/util/data_gen.rs @@ -49,6 +49,7 @@ pub fn create_random_batch( columns, &RecordBatchOptions { match_field_names: false, + row_count: None, }, ) } diff --git a/arrow/src/util/decimal.rs b/arrow/src/util/decimal.rs new file mode 100644 index 000000000000..b78af3acc6cd --- /dev/null +++ b/arrow/src/util/decimal.rs @@ -0,0 +1,150 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Decimal related utils + +use std::cmp::Ordering; + +/// Represents a decimal value with precision and scale. +/// The decimal value is represented by a signed 128-bit integer. 
+#[derive(Debug)] +pub struct Decimal128 { + #[allow(dead_code)] + precision: usize, + scale: usize, + value: i128, +} + +impl PartialOrd for Decimal128 { + fn partial_cmp(&self, other: &Self) -> Option { + assert_eq!( + self.scale, other.scale, + "Cannot compare two Decimal128 with different scale: {}, {}", + self.scale, other.scale + ); + self.value.partial_cmp(&other.value) + } +} + +impl Ord for Decimal128 { + fn cmp(&self, other: &Self) -> Ordering { + assert_eq!( + self.scale, other.scale, + "Cannot compare two Decimal128 with different scale: {}, {}", + self.scale, other.scale + ); + self.value.cmp(&other.value) + } +} + +impl PartialEq for Decimal128 { + fn eq(&self, other: &Self) -> bool { + assert_eq!( + self.scale, other.scale, + "Cannot compare two Decimal128 with different scale: {}, {}", + self.scale, other.scale + ); + self.value.eq(&other.value) + } +} + +impl Eq for Decimal128 {} + +impl Decimal128 { + pub fn new_from_bytes(precision: usize, scale: usize, bytes: &[u8]) -> Self { + let as_array = bytes.try_into(); + let value = match as_array { + Ok(v) if bytes.len() == 16 => i128::from_le_bytes(v), + _ => panic!("Input to Decimal128 is not 128bit integer."), + }; + + Decimal128 { + precision, + scale, + value, + } + } + + pub fn new_from_i128(precision: usize, scale: usize, value: i128) -> Self { + Decimal128 { + precision, + scale, + value, + } + } + + pub fn as_i128(&self) -> i128 { + self.value + } + + pub fn as_string(&self) -> String { + let value_str = self.value.to_string(); + + if self.scale == 0 { + value_str + } else { + let (sign, rest) = value_str.split_at(if self.value >= 0 { 0 } else { 1 }); + + if rest.len() > self.scale { + // Decimal separator is in the middle of the string + let (whole, decimal) = value_str.split_at(value_str.len() - self.scale); + format!("{}.{}", whole, decimal) + } else { + // String has to be padded + format!("{}0.{:0>width$}", sign, rest, width = self.scale) + } + } + } +} + +impl From for i128 { + fn from(decimal: Decimal128) -> Self { + decimal.as_i128() + } +} + +#[cfg(test)] +mod tests { + use crate::util::decimal::Decimal128; + + #[test] + fn decimal_128_to_string() { + let mut value = Decimal128::new_from_i128(5, 2, 100); + assert_eq!(value.as_string(), "1.00"); + + value = Decimal128::new_from_i128(5, 3, 100); + assert_eq!(value.as_string(), "0.100"); + } + + #[test] + fn decimal_128_from_bytes() { + let bytes = 100_i128.to_le_bytes(); + let value = Decimal128::new_from_bytes(5, 2, &bytes); + assert_eq!(value.as_string(), "1.00"); + } + + fn i128_func(value: impl Into) -> i128 { + value.into() + } + + #[test] + fn decimal_128_to_i128() { + let value = Decimal128::new_from_i128(5, 2, 100); + let integer = i128_func(value); + assert_eq!(integer, 100); + } +} diff --git a/arrow/src/util/display.rs b/arrow/src/util/display.rs index fbc0bd579e0b..6da73e4cff67 100644 --- a/arrow/src/util/display.rs +++ b/arrow/src/util/display.rs @@ -23,8 +23,9 @@ use std::sync::Arc; use crate::array::Array; use crate::datatypes::{ - ArrowNativeType, ArrowPrimitiveType, DataType, Int16Type, Int32Type, Int64Type, - Int8Type, TimeUnit, UInt16Type, UInt32Type, UInt64Type, UInt8Type, + ArrowNativeType, ArrowPrimitiveType, DataType, Field, Int16Type, Int32Type, + Int64Type, Int8Type, TimeUnit, UInt16Type, UInt32Type, UInt64Type, UInt8Type, + UnionMode, }; use crate::{array, datatypes::IntervalUnit}; @@ -106,6 +107,45 @@ macro_rules! make_string_interval_day_time { }}; } +macro_rules! 
make_string_interval_month_day_nano { + ($column: ident, $row: ident) => {{ + let array = $column + .as_any() + .downcast_ref::() + .unwrap(); + + let s = if array.is_null($row) { + "NULL".to_string() + } else { + let value: u128 = array.value($row) as u128; + + let months_part: i32 = + ((value & 0xFFFFFFFF000000000000000000000000) >> 96) as i32; + let days_part: i32 = ((value & 0xFFFFFFFF0000000000000000) >> 64) as i32; + let nanoseconds_part: i64 = (value & 0xFFFFFFFFFFFFFFFF) as i64; + + let secs = nanoseconds_part / 1000000000; + let mins = secs / 60; + let hours = mins / 60; + + let secs = secs - (mins * 60); + let mins = mins - (hours * 60); + + format!( + "0 years {} mons {} days {} hours {} mins {}.{:02} secs", + months_part, + days_part, + hours, + mins, + secs, + (nanoseconds_part % 1000000000), + ) + }; + + Ok(s) + }}; +} + macro_rules! make_string_date { ($array_type:ty, $column: ident, $row: ident) => {{ let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); @@ -194,6 +234,22 @@ macro_rules! make_string_from_list { }}; } +macro_rules! make_string_from_fixed_size_list { + ($column: ident, $row: ident) => {{ + let list = $column + .as_any() + .downcast_ref::() + .ok_or(ArrowError::InvalidArgumentError(format!( + "Repl error: could not convert list column to list array." + )))? + .value($row); + let string_values = (0..list.len()) + .map(|i| array_value_to_string(&list.clone(), i)) + .collect::>>()?; + Ok(format!("[{}]", string_values.join(", "))) + }}; +} + #[inline(always)] pub fn make_string_from_decimal(column: &Arc, row: usize) -> Result { let array = column @@ -246,6 +302,9 @@ pub fn array_value_to_string(column: &array::ArrayRef, row: usize) -> Result make_string!(array::LargeStringArray, column, row), DataType::Binary => make_string_hex!(array::BinaryArray, column, row), DataType::LargeBinary => make_string_hex!(array::LargeBinaryArray, column, row), + DataType::FixedSizeBinary(_) => { + make_string_hex!(array::FixedSizeBinaryArray, column, row) + } DataType::Boolean => make_string!(array::BooleanArray, column, row), DataType::Int8 => make_string!(array::Int8Array, column, row), DataType::Int16 => make_string!(array::Int16Array, column, row), @@ -255,7 +314,7 @@ pub fn array_value_to_string(column: &array::ArrayRef, row: usize) -> Result make_string!(array::UInt16Array, column, row), DataType::UInt32 => make_string!(array::UInt32Array, column, row), DataType::UInt64 => make_string!(array::UInt64Array, column, row), - DataType::Float16 => make_string!(array::Float32Array, column, row), + DataType::Float16 => make_string!(array::Float16Array, column, row), DataType::Float32 => make_string!(array::Float32Array, column, row), DataType::Float64 => make_string!(array::Float64Array, column, row), DataType::Decimal(..) 
=> make_string_from_decimal(column, row), @@ -292,6 +351,9 @@ pub fn array_value_to_string(column: &array::ArrayRef, row: usize) -> Result { make_string_interval_year_month!(column, row) } + IntervalUnit::MonthDayNano => { + make_string_interval_month_day_nano!(column, row) + } }, DataType::List(_) => make_string_from_list!(column, row), DataType::Dictionary(index_type, _value_type) => match **index_type { @@ -308,6 +370,7 @@ pub fn array_value_to_string(column: &array::ArrayRef, row: usize) -> Result make_string_from_fixed_size_list!(column, row), DataType::Struct(_) => { let st = column .as_any() @@ -333,6 +396,9 @@ pub fn array_value_to_string(column: &array::ArrayRef, row: usize) -> Result { + union_to_string(column, row, field_vec, type_ids, mode) + } _ => Err(ArrowError::InvalidArgumentError(format!( "Pretty printing not implemented for {:?} type", column.data_type() @@ -340,6 +406,41 @@ pub fn array_value_to_string(column: &array::ArrayRef, row: usize) -> Result Result { + let list = column + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::InvalidArgumentError( + "Repl error: could not convert union column to union array.".to_string(), + ) + })?; + let type_id = list.type_id(row); + let field_idx = type_ids.iter().position(|t| t == &type_id).ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "Repl error: could not get field name for type id: {} in union array.", + type_id, + )) + })?; + let name = fields.get(field_idx).unwrap().name(); + + let value = array_value_to_string( + &list.child(type_id), + match mode { + UnionMode::Dense => list.value_offset(row) as usize, + UnionMode::Sparse => row, + }, + )?; + + Ok(format!("{{{}={}}}", name, value)) +} /// Converts the value of the dictionary array at `row` to a String fn dict_array_value_to_string( colum: &array::ArrayRef, diff --git a/arrow/src/util/integration_util.rs b/arrow/src/util/integration_util.rs index 1a402bc6e368..a174da6ea436 100644 --- a/arrow/src/util/integration_util.rs +++ b/arrow/src/util/integration_util.rs @@ -132,6 +132,8 @@ pub struct ArrowJsonColumn { pub data: Option>, #[serde(rename = "OFFSET")] pub offset: Option>, // leaving as Value as 64-bit offsets are strings + #[serde(rename = "TYPE_ID")] + pub type_id: Option>, pub children: Option>, } @@ -286,6 +288,10 @@ impl ArrowJsonBatch { .collect::>(); arr.equals_json(&x.iter().collect::>()[..]) } + DataType::Interval(IntervalUnit::MonthDayNano) => { + let arr = IntervalMonthDayNanoArray::from(arr.data().clone()); + arr.equals_json(&json_array.iter().collect::>()[..]) + } DataType::UInt8 => { let arr = arr.as_any().downcast_ref::().unwrap(); arr.equals_json(&json_array.iter().collect::>()[..]) @@ -468,6 +474,7 @@ impl ArrowJsonBatch { validity: Some(validity), data: Some(data), offset: None, + type_id: None, children: None, } } @@ -477,6 +484,7 @@ impl ArrowJsonBatch { validity: None, data: None, offset: None, + type_id: None, children: None, }, }; @@ -779,97 +787,90 @@ mod tests { let micros_tz = Some("UTC".to_string()); let nanos_tz = Some("Africa/Johannesburg".to_string()); - let schema = Schema::new(vec![ - { - let mut f = - Field::new("bools-with-metadata-map", DataType::Boolean, true); - f.set_metadata(Some( - [("k".to_string(), "v".to_string())] - .iter() - .cloned() - .collect(), - )); - f - }, - { - let mut f = - Field::new("bools-with-metadata-vec", DataType::Boolean, true); - f.set_metadata(Some( - [("k2".to_string(), "v2".to_string())] - .iter() - .cloned() - .collect(), - )); - f - }, - Field::new("bools", DataType::Boolean, 
true), - Field::new("int8s", DataType::Int8, true), - Field::new("int16s", DataType::Int16, true), - Field::new("int32s", DataType::Int32, true), - Field::new("int64s", DataType::Int64, true), - Field::new("uint8s", DataType::UInt8, true), - Field::new("uint16s", DataType::UInt16, true), - Field::new("uint32s", DataType::UInt32, true), - Field::new("uint64s", DataType::UInt64, true), - Field::new("float32s", DataType::Float32, true), - Field::new("float64s", DataType::Float64, true), - Field::new("date_days", DataType::Date32, true), - Field::new("date_millis", DataType::Date64, true), - Field::new("time_secs", DataType::Time32(TimeUnit::Second), true), - Field::new("time_millis", DataType::Time32(TimeUnit::Millisecond), true), - Field::new("time_micros", DataType::Time64(TimeUnit::Microsecond), true), - Field::new("time_nanos", DataType::Time64(TimeUnit::Nanosecond), true), - Field::new("ts_secs", DataType::Timestamp(TimeUnit::Second, None), true), - Field::new( - "ts_millis", - DataType::Timestamp(TimeUnit::Millisecond, None), - true, - ), - Field::new( - "ts_micros", - DataType::Timestamp(TimeUnit::Microsecond, None), - true, - ), - Field::new( - "ts_nanos", - DataType::Timestamp(TimeUnit::Nanosecond, None), - true, - ), - Field::new( - "ts_secs_tz", - DataType::Timestamp(TimeUnit::Second, secs_tz.clone()), - true, - ), - Field::new( - "ts_millis_tz", - DataType::Timestamp(TimeUnit::Millisecond, millis_tz.clone()), - true, - ), - Field::new( - "ts_micros_tz", - DataType::Timestamp(TimeUnit::Microsecond, micros_tz.clone()), - true, - ), - Field::new( - "ts_nanos_tz", - DataType::Timestamp(TimeUnit::Nanosecond, nanos_tz.clone()), - true, - ), - Field::new("utf8s", DataType::Utf8, true), - Field::new( - "lists", - DataType::List(Box::new(Field::new("item", DataType::Int32, true))), - true, - ), - Field::new( - "structs", - DataType::Struct(vec![ - Field::new("int32s", DataType::Int32, true), - Field::new("utf8s", DataType::Utf8, true), - ]), - true, - ), - ]); + let schema = + Schema::new(vec![ + Field::new("bools-with-metadata-map", DataType::Boolean, true) + .with_metadata(Some( + [("k".to_string(), "v".to_string())] + .iter() + .cloned() + .collect(), + )), + Field::new("bools-with-metadata-vec", DataType::Boolean, true) + .with_metadata(Some( + [("k2".to_string(), "v2".to_string())] + .iter() + .cloned() + .collect(), + )), + Field::new("bools", DataType::Boolean, true), + Field::new("int8s", DataType::Int8, true), + Field::new("int16s", DataType::Int16, true), + Field::new("int32s", DataType::Int32, true), + Field::new("int64s", DataType::Int64, true), + Field::new("uint8s", DataType::UInt8, true), + Field::new("uint16s", DataType::UInt16, true), + Field::new("uint32s", DataType::UInt32, true), + Field::new("uint64s", DataType::UInt64, true), + Field::new("float32s", DataType::Float32, true), + Field::new("float64s", DataType::Float64, true), + Field::new("date_days", DataType::Date32, true), + Field::new("date_millis", DataType::Date64, true), + Field::new("time_secs", DataType::Time32(TimeUnit::Second), true), + Field::new("time_millis", DataType::Time32(TimeUnit::Millisecond), true), + Field::new("time_micros", DataType::Time64(TimeUnit::Microsecond), true), + Field::new("time_nanos", DataType::Time64(TimeUnit::Nanosecond), true), + Field::new("ts_secs", DataType::Timestamp(TimeUnit::Second, None), true), + Field::new( + "ts_millis", + DataType::Timestamp(TimeUnit::Millisecond, None), + true, + ), + Field::new( + "ts_micros", + DataType::Timestamp(TimeUnit::Microsecond, None), + 
true, + ), + Field::new( + "ts_nanos", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + ), + Field::new( + "ts_secs_tz", + DataType::Timestamp(TimeUnit::Second, secs_tz.clone()), + true, + ), + Field::new( + "ts_millis_tz", + DataType::Timestamp(TimeUnit::Millisecond, millis_tz.clone()), + true, + ), + Field::new( + "ts_micros_tz", + DataType::Timestamp(TimeUnit::Microsecond, micros_tz.clone()), + true, + ), + Field::new( + "ts_nanos_tz", + DataType::Timestamp(TimeUnit::Nanosecond, nanos_tz.clone()), + true, + ), + Field::new("utf8s", DataType::Utf8, true), + Field::new( + "lists", + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), + true, + ), + Field::new( + "structs", + DataType::Struct(vec![ + Field::new("int32s", DataType::Int32, true), + Field::new("utf8s", DataType::Utf8, true), + ]), + true, + ), + ]); let bools_with_metadata_map = BooleanArray::from(vec![Some(true), None, Some(false)]); diff --git a/arrow/src/util/mod.rs b/arrow/src/util/mod.rs index 1802d3e3adff..86253da8d777 100644 --- a/arrow/src/util/mod.rs +++ b/arrow/src/util/mod.rs @@ -18,6 +18,8 @@ #[cfg(feature = "test_utils")] pub mod bench_util; pub mod bit_chunk_iterator; +pub mod bit_iterator; +pub(crate) mod bit_mask; pub mod bit_util; #[cfg(feature = "test_utils")] pub mod data_gen; @@ -33,3 +35,6 @@ pub mod test_util; mod trusted_len; pub(crate) use trusted_len::trusted_len_unzip; + +pub mod decimal; +pub(crate) mod reader_parser; diff --git a/arrow/src/util/pretty.rs b/arrow/src/util/pretty.rs index 28bb0165ddd9..124de6127ddd 100644 --- a/arrow/src/util/pretty.rs +++ b/arrow/src/util/pretty.rs @@ -19,6 +19,7 @@ //! available unless `feature = "prettyprint"` is enabled. use crate::{array::ArrayRef, record_batch::RecordBatch}; +use std::fmt::Display; use comfy_table::{Cell, Table}; @@ -27,13 +28,16 @@ use crate::error::Result; use super::display::array_value_to_string; ///! Create a visual representation of record batches -pub fn pretty_format_batches(results: &[RecordBatch]) -> Result { - Ok(create_table(results)?.to_string()) +pub fn pretty_format_batches(results: &[RecordBatch]) -> Result { + create_table(results) } ///! Create a visual representation of columns -pub fn pretty_format_columns(col_name: &str, results: &[ArrayRef]) -> Result { - Ok(create_column(col_name, results)?.to_string()) +pub fn pretty_format_columns( + col_name: &str, + results: &[ArrayRef], +) -> Result { + create_column(col_name, results) } ///! Prints a visual representation of record batches to stdout @@ -70,7 +74,7 @@ fn create_table(results: &[RecordBatch]) -> Result { let mut cells = Vec::new(); for col in 0..batch.num_columns() { let column = batch.column(col); - cells.push(Cell::new(&array_value_to_string(&column, row)?)); + cells.push(Cell::new(&array_value_to_string(column, row)?)); } table.add_row(cells); } @@ -92,7 +96,7 @@ fn create_column(field: &str, columns: &[ArrayRef]) -> Result
{ for col in columns { for row in 0..col.len() { - let cells = vec![Cell::new(&array_value_to_string(&col, row)?)]; + let cells = vec![Cell::new(&array_value_to_string(col, row)?)]; table.add_row(cells); } } @@ -104,19 +108,24 @@ fn create_column(field: &str, columns: &[ArrayRef]) -> Result
{ mod tests { use crate::{ array::{ - self, new_null_array, Array, Date32Array, Date64Array, PrimitiveBuilder, + self, new_null_array, Array, Date32Array, Date64Array, + FixedSizeBinaryBuilder, Float16Array, Int32Array, PrimitiveBuilder, StringArray, StringBuilder, StringDictionaryBuilder, StructArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampNanosecondArray, TimestampSecondArray, + TimestampNanosecondArray, TimestampSecondArray, UnionArray, UnionBuilder, }, - datatypes::{DataType, Field, Int32Type, Schema}, + buffer::Buffer, + datatypes::{DataType, Field, Float64Type, Int32Type, Schema, UnionMode}, }; use super::*; - use crate::array::{DecimalBuilder, Int32Array}; + use crate::array::{DecimalArray, FixedSizeListBuilder}; + use std::fmt::Write; use std::sync::Arc; + use half::f16; + #[test] fn test_pretty_format_batches() -> Result<()> { // define a schema. @@ -144,7 +153,7 @@ mod tests { ], )?; - let table = pretty_format_batches(&[batch])?; + let table = pretty_format_batches(&[batch])?.to_string(); let expected = vec![ "+---+-----+", @@ -176,7 +185,7 @@ mod tests { Arc::new(array::StringArray::from(vec![Some("e"), None, Some("g")])), ]; - let table = pretty_format_columns("a", &columns)?; + let table = pretty_format_columns("a", &columns)?.to_string(); let expected = vec![ "+---+", "| a |", "+---+", "| a |", "| b |", "| |", "| d |", "| e |", @@ -208,7 +217,7 @@ mod tests { // define data (null) let batch = RecordBatch::try_new(schema, arrays).unwrap(); - let table = pretty_format_batches(&[batch]).unwrap(); + let table = pretty_format_batches(&[batch]).unwrap().to_string(); let expected = vec![ "+---+---+---+", @@ -244,7 +253,7 @@ mod tests { let batch = RecordBatch::try_new(schema, vec![array])?; - let table = pretty_format_batches(&[batch])?; + let table = pretty_format_batches(&[batch])?.to_string(); let expected = vec![ "+-------+", @@ -263,6 +272,79 @@ mod tests { Ok(()) } + #[test] + fn test_pretty_format_fixed_size_list() -> Result<()> { + // define a schema. + let field_type = DataType::FixedSizeList( + Box::new(Field::new("item", DataType::Int32, true)), + 3, + ); + let schema = Arc::new(Schema::new(vec![Field::new("d1", field_type, true)])); + + let keys_builder = Int32Array::builder(3); + let mut builder = FixedSizeListBuilder::new(keys_builder, 3); + + builder.values().append_slice(&[1, 2, 3]).unwrap(); + builder.append(true).unwrap(); + builder.values().append_slice(&[4, 5, 6]).unwrap(); + builder.append(false).unwrap(); + builder.values().append_slice(&[7, 8, 9]).unwrap(); + builder.append(true).unwrap(); + + let array = Arc::new(builder.finish()); + + let batch = RecordBatch::try_new(schema, vec![array])?; + let table = pretty_format_batches(&[batch])?.to_string(); + let expected = vec![ + "+-----------+", + "| d1 |", + "+-----------+", + "| [1, 2, 3] |", + "| |", + "| [7, 8, 9] |", + "+-----------+", + ]; + + let actual: Vec<&str> = table.lines().collect(); + + assert_eq!(expected, actual, "Actual result:\n{}", table); + + Ok(()) + } + + #[test] + fn test_pretty_format_fixed_size_binary() -> Result<()> { + // define a schema. 
+ let field_type = DataType::FixedSizeBinary(3); + let schema = Arc::new(Schema::new(vec![Field::new("d1", field_type, true)])); + + let mut builder = FixedSizeBinaryBuilder::new(3, 3); + + builder.append_value(&[1, 2, 3]).unwrap(); + builder.append_null().unwrap(); + builder.append_value(&[7, 8, 9]).unwrap(); + + let array = Arc::new(builder.finish()); + + let batch = RecordBatch::try_new(schema, vec![array])?; + let table = pretty_format_batches(&[batch])?.to_string(); + let expected = vec![ + "+--------+", + "| d1 |", + "+--------+", + "| 010203 |", + "| |", + "| 070809 |", + "+--------+", + ]; + + let actual: Vec<&str> = table.lines().collect(); + + assert_eq!(expected, actual, "Actual result:\n{}", table); + + Ok(()) + } + /// Generate an array with type $ARRAYTYPE with a numeric value of /// $VALUE, and compare $EXPECTED_RESULT to the output of /// formatting that array with `pretty_format_batches` @@ -280,7 +362,9 @@ mod tests { )])); let batch = RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap(); - let table = pretty_format_batches(&[batch]).expect("formatting batches"); + let table = pretty_format_batches(&[batch]) + .expect("formatting batches") + .to_string(); let expected = $EXPECTED_RESULT; let actual: Vec<&str> = table.lines().collect(); @@ -434,17 +518,16 @@ mod tests { #[test] fn test_decimal_display() -> Result<()> { - let capacity = 10; let precision = 10; let scale = 2; - let mut builder = DecimalBuilder::new(capacity, precision, scale); - builder.append_value(101).unwrap(); - builder.append_null().unwrap(); - builder.append_value(200).unwrap(); - builder.append_value(3040).unwrap(); + let array = [Some(101), None, Some(200), Some(3040)] + .into_iter() + .collect::() + .with_precision_and_scale(precision, scale) + .unwrap(); - let dm = Arc::new(builder.finish()) as ArrayRef; + let dm = Arc::new(array) as ArrayRef; let schema = Arc::new(Schema::new(vec![Field::new( "f", @@ -454,7 +537,7 @@ mod tests { let batch = RecordBatch::try_new(schema, vec![dm])?; - let table = pretty_format_batches(&[batch])?; + let table = pretty_format_batches(&[batch])?.to_string(); let expected = vec![ "+-------+", @@ -475,17 +558,16 @@ mod tests { #[test] fn test_decimal_display_zero_scale() -> Result<()> { - let capacity = 10; let precision = 5; let scale = 0; - let mut builder = DecimalBuilder::new(capacity, precision, scale); - builder.append_value(101).unwrap(); - builder.append_null().unwrap(); - builder.append_value(200).unwrap(); - builder.append_value(3040).unwrap(); + let array = [Some(101), None, Some(200), Some(3040)] + .into_iter() + .collect::() + .with_precision_and_scale(precision, scale) + .unwrap(); - let dm = Arc::new(builder.finish()) as ArrayRef; + let dm = Arc::new(array) as ArrayRef; let schema = Arc::new(Schema::new(vec![Field::new( "f", @@ -495,7 +577,7 @@ mod tests { let batch = RecordBatch::try_new(schema, vec![dm])?; - let table = pretty_format_batches(&[batch])?; + let table = pretty_format_batches(&[batch])?.to_string(); let expected = vec![ "+------+", "| f |", "+------+", "| 101 |", "| |", "| 200 |", "| 3040 |", "+------+", @@ -549,7 +631,7 @@ mod tests { RecordBatch::try_new(Arc::new(schema), vec![Arc::new(c1), Arc::new(c2)]) .unwrap(); - let table = pretty_format_batches(&[batch])?; + let table = pretty_format_batches(&[batch])?.to_string(); let expected = vec![ r#"+-------------------------------------+----+"#, r#"| c1 | c2 |"#, @@ -565,4 +647,226 @@ mod tests { Ok(()) } + + #[test] + fn test_pretty_format_dense_union() -> Result<()> { + let mut 
builder = UnionBuilder::new_dense(4); + builder.append::("a", 1).unwrap(); + builder.append::("b", 3.2234).unwrap(); + builder.append_null::("b").unwrap(); + builder.append_null::("a").unwrap(); + let union = builder.build().unwrap(); + + let schema = Schema::new(vec![Field::new( + "Teamsters", + DataType::Union( + vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Float64, false), + ], + vec![0, 1], + UnionMode::Dense, + ), + false, + )]); + + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(union)]).unwrap(); + let table = pretty_format_batches(&[batch])?.to_string(); + let actual: Vec<&str> = table.lines().collect(); + let expected = vec![ + "+------------+", + "| Teamsters |", + "+------------+", + "| {a=1} |", + "| {b=3.2234} |", + "| {b=} |", + "| {a=} |", + "+------------+", + ]; + + assert_eq!(expected, actual); + Ok(()) + } + + #[test] + fn test_pretty_format_sparse_union() -> Result<()> { + let mut builder = UnionBuilder::new_sparse(4); + builder.append::("a", 1).unwrap(); + builder.append::("b", 3.2234).unwrap(); + builder.append_null::("b").unwrap(); + builder.append_null::("a").unwrap(); + let union = builder.build().unwrap(); + + let schema = Schema::new(vec![Field::new( + "Teamsters", + DataType::Union( + vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Float64, false), + ], + vec![0, 1], + UnionMode::Sparse, + ), + false, + )]); + + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(union)]).unwrap(); + let table = pretty_format_batches(&[batch])?.to_string(); + let actual: Vec<&str> = table.lines().collect(); + let expected = vec![ + "+------------+", + "| Teamsters |", + "+------------+", + "| {a=1} |", + "| {b=3.2234} |", + "| {b=} |", + "| {a=} |", + "+------------+", + ]; + + assert_eq!(expected, actual); + Ok(()) + } + + #[test] + fn test_pretty_format_nested_union() -> Result<()> { + //Inner UnionArray + let mut builder = UnionBuilder::new_dense(5); + builder.append::("b", 1).unwrap(); + builder.append::("c", 3.2234).unwrap(); + builder.append_null::("c").unwrap(); + builder.append_null::("b").unwrap(); + builder.append_null::("c").unwrap(); + let inner = builder.build().unwrap(); + + let inner_field = Field::new( + "European Union", + DataType::Union( + vec![ + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Float64, false), + ], + vec![0, 1], + UnionMode::Dense, + ), + false, + ); + + // Can't use UnionBuilder with non-primitive types, so manually build outer UnionArray + let a_array = Int32Array::from(vec![None, None, None, Some(1234), Some(23)]); + let type_ids = Buffer::from_slice_ref(&[1_i8, 1, 0, 0, 1]); + + let children: Vec<(Field, Arc)> = vec![ + (Field::new("a", DataType::Int32, true), Arc::new(a_array)), + (inner_field.clone(), Arc::new(inner)), + ]; + + let outer = UnionArray::try_new(&[0, 1], type_ids, None, children).unwrap(); + + let schema = Schema::new(vec![Field::new( + "Teamsters", + DataType::Union( + vec![Field::new("a", DataType::Int32, true), inner_field], + vec![0, 1], + UnionMode::Sparse, + ), + false, + )]); + + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(outer)]).unwrap(); + let table = pretty_format_batches(&[batch])?.to_string(); + let actual: Vec<&str> = table.lines().collect(); + let expected = vec![ + "+-----------------------------+", + "| Teamsters |", + "+-----------------------------+", + "| {European Union={b=1}} |", + "| {European Union={c=3.2234}} |", + "| {a=} |", + "| {a=1234} |", + "| {European 
Union={c=}} |", + "+-----------------------------+", + ]; + assert_eq!(expected, actual); + Ok(()) + } + + #[test] + fn test_writing_formatted_batches() -> Result<()> { + // define a schema. + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, true), + Field::new("b", DataType::Int32, true), + ])); + + // define data. + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(array::StringArray::from(vec![ + Some("a"), + Some("b"), + None, + Some("d"), + ])), + Arc::new(array::Int32Array::from(vec![ + Some(1), + None, + Some(10), + Some(100), + ])), + ], + )?; + + let mut buf = String::new(); + write!(&mut buf, "{}", pretty_format_batches(&[batch])?).unwrap(); + + let s = vec![ + "+---+-----+", + "| a | b |", + "+---+-----+", + "| a | 1 |", + "| b | |", + "| | 10 |", + "| d | 100 |", + "+---+-----+", + ]; + let expected = s.join("\n"); + assert_eq!(expected, buf); + + Ok(()) + } + + #[test] + fn test_float16_display() -> Result<()> { + let values = vec![ + Some(f16::from_f32(f32::NAN)), + Some(f16::from_f32(4.0)), + Some(f16::from_f32(f32::NEG_INFINITY)), + ]; + let array = Arc::new(values.into_iter().collect::()) as ArrayRef; + + let schema = Arc::new(Schema::new(vec![Field::new( + "f16", + array.data_type().clone(), + true, + )])); + + let batch = RecordBatch::try_new(schema, vec![array])?; + + let table = pretty_format_batches(&[batch])?.to_string(); + + let expected = vec![ + "+------+", "| f16 |", "+------+", "| NaN |", "| 4 |", "| -inf |", + "+------+", + ]; + + let actual: Vec<&str> = table.lines().collect(); + assert_eq!(expected, actual, "Actual result:\n{}", table); + + Ok(()) + } } diff --git a/arrow/src/util/reader_parser.rs b/arrow/src/util/reader_parser.rs new file mode 100644 index 000000000000..6b6f24f82a43 --- /dev/null +++ b/arrow/src/util/reader_parser.rs @@ -0,0 +1,144 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::compute::kernels::cast_utils::string_to_timestamp_nanos; +use crate::datatypes::*; + +/// Specialized parsing implementations +/// used by csv and json reader +pub(crate) trait Parser: ArrowPrimitiveType { + fn parse(string: &str) -> Option { + string.parse::().ok() + } + + fn parse_formatted(string: &str, _format: &str) -> Option { + Self::parse(string) + } +} + +impl Parser for Float32Type { + fn parse(string: &str) -> Option { + lexical_core::parse(string.as_bytes()).ok() + } +} + +impl Parser for Float64Type { + fn parse(string: &str) -> Option { + lexical_core::parse(string.as_bytes()).ok() + } +} + +impl Parser for UInt64Type {} + +impl Parser for UInt32Type {} + +impl Parser for UInt16Type {} + +impl Parser for UInt8Type {} + +impl Parser for Int64Type {} + +impl Parser for Int32Type {} + +impl Parser for Int16Type {} + +impl Parser for Int8Type {} + +impl Parser for TimestampNanosecondType { + fn parse(string: &str) -> Option { + string_to_timestamp_nanos(string).ok() + } +} + +impl Parser for TimestampMicrosecondType { + fn parse(string: &str) -> Option { + let nanos = string_to_timestamp_nanos(string).ok(); + nanos.map(|x| x / 1000) + } +} + +impl Parser for TimestampMillisecondType { + fn parse(string: &str) -> Option { + let nanos = string_to_timestamp_nanos(string).ok(); + nanos.map(|x| x / 1_000_000) + } +} + +impl Parser for TimestampSecondType { + fn parse(string: &str) -> Option { + let nanos = string_to_timestamp_nanos(string).ok(); + nanos.map(|x| x / 1_000_000_000) + } +} + +impl Parser for Time64NanosecondType {} + +impl Parser for Time64MicrosecondType {} + +impl Parser for Time32MillisecondType {} + +impl Parser for Time32SecondType {} + +/// Number of days between 0001-01-01 and 1970-01-01 +const EPOCH_DAYS_FROM_CE: i32 = 719_163; + +impl Parser for Date32Type { + fn parse(string: &str) -> Option { + use chrono::Datelike; + let date = string.parse::().ok()?; + Self::Native::from_i32(date.num_days_from_ce() - EPOCH_DAYS_FROM_CE) + } + + fn parse_formatted(string: &str, format: &str) -> Option { + use chrono::Datelike; + let date = chrono::NaiveDate::parse_from_str(string, format).ok()?; + Self::Native::from_i32(date.num_days_from_ce() - EPOCH_DAYS_FROM_CE) + } +} + +impl Parser for Date64Type { + fn parse(string: &str) -> Option { + let date_time = string.parse::().ok()?; + Self::Native::from_i64(date_time.timestamp_millis()) + } + + fn parse_formatted(string: &str, format: &str) -> Option { + use chrono::format::Fixed; + use chrono::format::StrftimeItems; + let fmt = StrftimeItems::new(format); + let has_zone = fmt.into_iter().any(|item| match item { + chrono::format::Item::Fixed(fixed_item) => matches!( + fixed_item, + Fixed::RFC2822 + | Fixed::RFC3339 + | Fixed::TimezoneName + | Fixed::TimezoneOffsetColon + | Fixed::TimezoneOffsetColonZ + | Fixed::TimezoneOffset + | Fixed::TimezoneOffsetZ + ), + _ => false, + }); + if has_zone { + let date_time = chrono::DateTime::parse_from_str(string, format).ok()?; + Self::Native::from_i64(date_time.timestamp_millis()) + } else { + let date_time = chrono::NaiveDateTime::parse_from_str(string, format).ok()?; + Self::Native::from_i64(date_time.timestamp_millis()) + } + } +} diff --git a/arrow/src/util/test_util.rs b/arrow/src/util/test_util.rs index 4b193f774178..cae148a53d5c 100644 --- a/arrow/src/util/test_util.rs +++ b/arrow/src/util/test_util.rs @@ -124,7 +124,7 @@ fn get_data_dir(udf_env: &str, submodule_data: &str) -> Result Result { + /// where the iterator currently is + cur: usize, + /// How many items will 
this iterator *actually* make + limit: usize, + /// How many items this iterator claims it will make + claimed: usize, + /// The items to return. If there are fewer items than `limit` + /// they will be repeated + pub items: Vec, +} + +impl BadIterator { + /// Create a new iterator for items, but that reports to + /// produce items. Must provide at least 1 item. + pub fn new(limit: usize, claimed: usize, items: Vec) -> Self { + assert!(!items.is_empty()); + Self { + cur: 0, + limit, + claimed, + items, + } + } +} + +impl Iterator for BadIterator { + type Item = T; + + fn next(&mut self) -> Option { + if self.cur < self.limit { + let next_item_idx = self.cur % self.items.len(); + let next_item = self.items[next_item_idx].clone(); + self.cur += 1; + Some(next_item) + } else { + None + } + } + + /// report whatever the iterator says to + fn size_hint(&self) -> (usize, Option) { + (0, Some(self.claimed as usize)) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/arrow/test/data/basic.json b/arrow/test/data/basic.json index dafd2dd2e420..556c39c46be9 100644 --- a/arrow/test/data/basic.json +++ b/arrow/test/data/basic.json @@ -1,6 +1,6 @@ -{"a":1, "b":2.0, "c":false, "d":"4"} -{"a":-10, "b":-3.5, "c":true, "d":"4"} -{"a":2, "b":0.6, "c":false, "d":"text"} +{"a":1, "b":2.0, "c":false, "d":"4", "e":"1970-1-2"} +{"a":-10, "b":-3.5, "c":true, "d":"4", "e": "1969-12-31"} +{"a":2, "b":0.6, "c":false, "d":"text", "e": "1970-01-02 11:11:11"} {"a":1, "b":2.0, "c":false, "d":"4"} {"a":7, "b":-3.5, "c":true, "d":"4"} {"a":1, "b":0.6, "c":false, "d":"text"} @@ -9,4 +9,4 @@ {"a":1, "b":0.6, "c":false, "d":"text"} {"a":1, "b":2.0, "c":false, "d":"4"} {"a":1, "b":-3.5, "c":true, "d":"4"} -{"a":100000000000000, "b":0.6, "c":false, "d":"text"} \ No newline at end of file +{"a":100000000000000, "b":0.6, "c":false, "d":"text"} diff --git a/arrow/test/data/decimal_test.csv b/arrow/test/data/decimal_test.csv new file mode 100644 index 000000000000..460ed808c1a0 --- /dev/null +++ b/arrow/test/data/decimal_test.csv @@ -0,0 +1,10 @@ +"Elgin, Scotland, the UK",57.653484,-3.335724 +"Stoke-on-Trent, Staffordshire, the UK",53.002666,-2.179404 +"Solihull, Birmingham, UK",52.412811,-1.778197 +"Cardiff, Cardiff county, UK",51.481583,-3.179090 +"Cardiff, Cardiff county, UK",12.12345678,-3.179090 +"Eastbourne, East Sussex, UK",50.76,0.290472 +"Eastbourne, East Sussex, UK",.123,0.290472 +"Eastbourne, East Sussex, UK",123.,0.290472 +"Eastbourne, East Sussex, UK",123,0.290472 +"Eastbourne, East Sussex, UK",-50.76,0.290472 \ No newline at end of file diff --git a/arrow/test/data/various_types.csv b/arrow/test/data/various_types.csv index 8f4466fbe6a4..570d07f5c221 100644 --- a/arrow/test/data/various_types.csv +++ b/arrow/test/data/various_types.csv @@ -3,4 +3,6 @@ c_int|c_float|c_string|c_bool|c_date|c_datetime 2|2.2|"2.22"|true|2020-11-08|2020-11-08T01:00:00 3||"3.33"|true|1969-12-31|1969-11-08T02:00:00 4|4.4||false|| -5|6.6|""|false|1990-01-01|1990-01-01T03:00:00 \ No newline at end of file +5|6.6|""|false|1990-01-01|1990-01-01T03:00:00 +4|4e6||false|| +4|4.0e-6||false|| \ No newline at end of file diff --git a/arrow/test/dependency/README.md b/arrow/test/dependency/README.md deleted file mode 100644 index b618b4636e7c..000000000000 --- a/arrow/test/dependency/README.md +++ /dev/null @@ -1,21 +0,0 @@ - - -This directory contains projects that use arrow as a dependency with -various combinations of feature flags. 
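Stepping back to the `BadIterator` helper added to `test_util.rs` above: a minimal sketch of how a test might use it to check that an over-stated `size_hint` is not trusted. The test name and values here are illustrative only, not part of the diff:

```rust
#[test]
fn bad_iterator_overstates_size_hint() {
    // Actually yields 3 items, but claims via size_hint that it may yield 10.
    let iter = BadIterator::new(3, 10, vec![1_u32, 2, 3]);
    assert_eq!(iter.size_hint(), (0, Some(10)));

    // Consumers must count what is really produced rather than trust the claim.
    assert_eq!(iter.count(), 3);
}
```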
diff --git a/arrow/test/dependency/default-features/src/main.rs b/arrow/test/dependency/default-features/src/main.rs deleted file mode 100644 index e7a11a969c03..000000000000 --- a/arrow/test/dependency/default-features/src/main.rs +++ /dev/null @@ -1,3 +0,0 @@ -fn main() { - println!("Hello, world!"); -} diff --git a/arrow/test/dependency/no-default-features/src/main.rs b/arrow/test/dependency/no-default-features/src/main.rs deleted file mode 100644 index e7a11a969c03..000000000000 --- a/arrow/test/dependency/no-default-features/src/main.rs +++ /dev/null @@ -1,3 +0,0 @@ -fn main() { - println!("Hello, world!"); -} diff --git a/arrow/test/dependency/simd/Cargo.toml b/arrow/test/dependency/simd/Cargo.toml deleted file mode 100644 index 22d411a537be..000000000000 --- a/arrow/test/dependency/simd/Cargo.toml +++ /dev/null @@ -1,29 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -[package] -name = "defeault-features" -description = "Models a user application of arrow that uses the simd feature of arrow" -version = "0.1.0" -edition = "2018" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -arrow = { path = "../../../../arrow", version = "7.0.0-SNAPSHOT", features = ["simd"]} - -[workspace] diff --git a/arrow/test/dependency/simd/src/main.rs b/arrow/test/dependency/simd/src/main.rs deleted file mode 100644 index e7a11a969c03..000000000000 --- a/arrow/test/dependency/simd/src/main.rs +++ /dev/null @@ -1,3 +0,0 @@ -fn main() { - println!("Hello, world!"); -} diff --git a/conbench/.flake8 b/conbench/.flake8 new file mode 100644 index 000000000000..e44b81084185 --- /dev/null +++ b/conbench/.flake8 @@ -0,0 +1,2 @@ +[flake8] +ignore = E501 diff --git a/conbench/.gitignore b/conbench/.gitignore new file mode 100755 index 000000000000..aa44ee2adbd4 --- /dev/null +++ b/conbench/.gitignore @@ -0,0 +1,130 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + diff --git a/conbench/.isort.cfg b/conbench/.isort.cfg new file mode 100644 index 000000000000..f238bf7ea137 --- /dev/null +++ b/conbench/.isort.cfg @@ -0,0 +1,2 @@ +[settings] +profile = black diff --git a/conbench/README.md b/conbench/README.md new file mode 100644 index 000000000000..8c7f38cb398a --- /dev/null +++ b/conbench/README.md @@ -0,0 +1,251 @@ + + +# Arrow Rust + Conbench Integration + + +## Quick start + +``` +$ cd ~/arrow-rs/conbench/ +$ conda create -y -n conbench python=3.9 +$ conda activate conbench +(conbench) $ pip install -r requirements.txt +(conbench) $ conbench arrow-rs +``` + +## Example output + +``` +{ + "batch_id": "b68c559358cc43a3aab02d893d2693f4", + "context": { + "benchmark_language": "Rust" + }, + "github": { + "commit": "ca33a0a50494f95840ade2e9509c3c3d4df35249", + "repository": "https://github.com/dianaclarke/arrow-rs" + }, + "info": {}, + "machine_info": { + "architecture_name": "x86_64", + "cpu_core_count": "8", + "cpu_frequency_max_hz": "2400000000", + "cpu_l1d_cache_bytes": "65536", + "cpu_l1i_cache_bytes": "131072", + "cpu_l2_cache_bytes": "4194304", + "cpu_l3_cache_bytes": "0", + "cpu_model_name": "Apple M1", + "cpu_thread_count": "8", + "gpu_count": "0", + "gpu_product_names": [], + "kernel_name": "20.6.0", + "memory_bytes": "17179869184", + "name": "diana", + "os_name": "macOS", + "os_version": "10.16" + }, + "run_id": "08353595bde147fb9deebdb4facd019a", + "stats": { + "data": [ + "0.000287", + "0.000286", + "0.000285", + "0.000281", + "0.000282", + "0.000277", + "0.000277", + "0.000286", + "0.000279", + "0.000282", + "0.000277", + "0.000277", + "0.000282", + "0.000276", + "0.000281", + "0.000281", + "0.000281", + "0.000281", + "0.000281", + "0.000284", + "0.000288", + "0.000278", + "0.000276", + "0.000278", + "0.000275", + "0.000275", + "0.000275", + "0.000275", + "0.000281", + "0.000284", + "0.000277", + "0.000277", + "0.000278", + "0.000282", + "0.000281", + "0.000284", + "0.000282", + "0.000279", + "0.000280", + "0.000281", + "0.000281", + "0.000286", + "0.000278", + "0.000278", + "0.000281", + "0.000276", + "0.000284", + "0.000281", + "0.000276", + 
"0.000276", + "0.000279", + "0.000283", + "0.000282", + "0.000278", + "0.000281", + "0.000284", + "0.000279", + "0.000276", + "0.000278", + "0.000283", + "0.000282", + "0.000276", + "0.000281", + "0.000279", + "0.000276", + "0.000277", + "0.000283", + "0.000279", + "0.000281", + "0.000283", + "0.000279", + "0.000282", + "0.000283", + "0.000278", + "0.000281", + "0.000282", + "0.000278", + "0.000276", + "0.000281", + "0.000278", + "0.000276", + "0.000282", + "0.000281", + "0.000282", + "0.000280", + "0.000281", + "0.000282", + "0.000280", + "0.000282", + "0.000280", + "0.000280", + "0.000282", + "0.000278", + "0.000284", + "0.000290", + "0.000282", + "0.000281", + "0.000281", + "0.000281", + "0.000278" + ], + "iqr": "0.000004", + "iterations": 100, + "max": "0.000290", + "mean": "0.000280", + "median": "0.000281", + "min": "0.000275", + "q1": "0.000278", + "q3": "0.000282", + "stdev": "0.000003", + "time_unit": "s", + "times": [], + "unit": "s" + }, + "tags": { + "name": "nlike_utf8 scalar starts with", + "suite": "nlike_utf8 scalar starts with" + }, + "timestamp": "2022-02-09T02:33:26.792404+00:00" +} +``` + +## Debug with test benchmark + +``` +(conbench) $ cd ~/arrow-rs/conbench/ +(conbench) $ conbench test --iterations=3 + +Benchmark result: +{ + "batch_id": "f4235d547e9d4f94925b54692e625d7d", + "context": { + "benchmark_language": "Python" + }, + "github": { + "commit": "35e16be01e680e9381b2d1393c2e3f8e7acb7b13", + "repository": "https://github.com/dianaclarke/arrow-rs" + }, + "info": { + "benchmark_language_version": "Python 3.9.7" + }, + "machine_info": { + "architecture_name": "x86_64", + "cpu_core_count": "8", + "cpu_frequency_max_hz": "2400000000", + "cpu_l1d_cache_bytes": "65536", + "cpu_l1i_cache_bytes": "131072", + "cpu_l2_cache_bytes": "4194304", + "cpu_l3_cache_bytes": "0", + "cpu_model_name": "Apple M1", + "cpu_thread_count": "8", + "gpu_count": "0", + "gpu_product_names": [], + "kernel_name": "20.6.0", + "memory_bytes": "17179869184", + "name": "diana", + "os_name": "macOS", + "os_version": "10.16" + }, + "run_id": "b2ca0d581cf14f21936276ff7ca5a940", + "stats": { + "data": [ + "0.000002", + "0.000001", + "0.000001" + ], + "iqr": "0.000001", + "iterations": 3, + "max": "0.000002", + "mean": "0.000001", + "median": "0.000001", + "min": "0.000001", + "q1": "0.000001", + "q3": "0.000002", + "stdev": "0.000001", + "time_unit": "s", + "times": [], + "unit": "s" + }, + "tags": { + "name": "test" + }, + "timestamp": "2022-02-09T01:55:52.250727+00:00" +} +``` diff --git a/conbench/_criterion.py b/conbench/_criterion.py new file mode 100644 index 000000000000..168a1b9b6cb1 --- /dev/null +++ b/conbench/_criterion.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import collections +import csv +import os +import pathlib +import subprocess + +import conbench.runner +from conbench.machine_info import github_info + + +def _result_in_seconds(row): + # sample_measured_value - The value of the measurement for this sample. + # Note that this is the measured value for the whole sample, not the + # time-per-iteration To calculate the time-per-iteration, use + # sample_measured_value/iteration_count + # -- https://bheisler.github.io/criterion.rs/book/user_guide/csv_output.html + count = int(row["iteration_count"]) + sample = float(row["sample_measured_value"]) + return sample / count / 10**9 + + +def _parse_benchmark_group(row): + parts = row["group"].split(",") + if len(parts) > 1: + suite, name = parts[0], ",".join(parts[1:]) + else: + suite, name = row["group"], row["group"] + return suite.strip(), name.strip() + + +def _read_results(src_dir): + results = collections.defaultdict(lambda: collections.defaultdict(list)) + path = pathlib.Path(os.path.join(src_dir, "target", "criterion")) + for path in list(path.glob("**/new/raw.csv")): + with open(path) as csv_file: + reader = csv.DictReader(csv_file) + for row in reader: + suite, name = _parse_benchmark_group(row) + results[suite][name].append(_result_in_seconds(row)) + return results + + +def _execute_command(command): + try: + print(command) + result = subprocess.run(command, capture_output=True, check=True) + except subprocess.CalledProcessError as e: + print(e.stderr.decode("utf-8")) + raise e + return result.stdout.decode("utf-8"), result.stderr.decode("utf-8") + + +class CriterionBenchmark(conbench.runner.Benchmark): + external = True + + def run(self, **kwargs): + src_dir = os.path.join(os.getcwd(), "..") + self._cargo_bench(src_dir) + results = _read_results(src_dir) + for suite in results: + self.conbench.mark_new_batch() + for name, data in results[suite].items(): + yield self._record_result(suite, name, data, kwargs) + + def _cargo_bench(self, src_dir): + os.chdir(src_dir) + _execute_command(["cargo", "bench"]) + + def _record_result(self, suite, name, data, options): + tags = {"suite": suite} + result = {"data": data, "unit": "s"} + context = {"benchmark_language": "Rust"} + github = github_info() + return self.conbench.record( + result, + name, + tags=tags, + context=context, + github=github, + options=options, + ) diff --git a/conbench/benchmarks.json b/conbench/benchmarks.json new file mode 100644 index 000000000000..3e1f33a7dfa9 --- /dev/null +++ b/conbench/benchmarks.json @@ -0,0 +1,8 @@ +[ + { + "command": "arrow-rs", + "flags": { + "language": "Rust" + } + } +] diff --git a/arrow/test/dependency/default-features/Cargo.toml b/conbench/benchmarks.py similarity index 60% rename from arrow/test/dependency/default-features/Cargo.toml rename to conbench/benchmarks.py index 425a891564d9..bc4c1796b85f 100644 --- a/arrow/test/dependency/default-features/Cargo.toml +++ b/conbench/benchmarks.py @@ -15,15 +15,27 @@ # specific language governing permissions and limitations # under the License. 
-[package] -name = "defeault-features" -description = "Models a user application of arrow that uses default features of arrow" -version = "0.1.0" -edition = "2018" +import conbench.runner -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +import _criterion -[dependencies] -arrow = { path = "../../../../arrow", version = "7.0.0-SNAPSHOT" } -[workspace] +@conbench.runner.register_benchmark +class TestBenchmark(conbench.runner.Benchmark): + name = "test" + + def run(self, **kwargs): + yield self.conbench.benchmark( + self._f(), + self.name, + options=kwargs, + ) + + def _f(self): + return lambda: 1 + 1 + + +@conbench.runner.register_benchmark +class CargoBenchmarks(_criterion.CriterionBenchmark): + name = "arrow-rs" + description = "Run Arrow Rust micro benchmarks." diff --git a/conbench/requirements-test.txt b/conbench/requirements-test.txt new file mode 100644 index 000000000000..5e5647acd2d6 --- /dev/null +++ b/conbench/requirements-test.txt @@ -0,0 +1,3 @@ +black +flake8 +isort diff --git a/conbench/requirements.txt b/conbench/requirements.txt new file mode 100644 index 000000000000..a877c7b44e9b --- /dev/null +++ b/conbench/requirements.txt @@ -0,0 +1 @@ +conbench diff --git a/dev/README.md b/dev/README.md deleted file mode 100644 index f9d2070665d7..000000000000 --- a/dev/README.md +++ /dev/null @@ -1,57 +0,0 @@ - - -# Arrow Developer Scripts - -This directory contains scripts useful to developers when packaging, -testing, or committing to Arrow. - -## Verifying Release Candidates - -We have provided a script to assist with verifying release candidates: - -```shell -bash dev/release/verify-release-candidate.sh 0.7.0 0 -``` - -This works on Linux and macOS. Read the script for information about system -dependencies. - -On Windows, we have a script that verifies C++ and Python (requires Visual -Studio 2015): - -``` -dev/release/verify-release-candidate.bat apache-arrow-0.7.0.tar.gz -``` - -### Verifying the JavaScript release - -For JavaScript-specific releases, use a different verification script: - -```shell -bash dev/release/js-verify-release-candidate.sh 0.7.0 0 -``` - -# Integration testing - -Build the following base image used by multiple tests: - -```shell -docker build -t arrow_integration_xenial_base -f docker_common/Dockerfile.xenial.base . -``` diff --git a/dev/release/README.md b/dev/release/README.md index ef379a573204..912b60dae6b3 100644 --- a/dev/release/README.md +++ b/dev/release/README.md @@ -19,33 +19,15 @@ # Release Process -## Branching +## Overview -Arrow maintains two branches: `active_release` and `master`. +We try to release a new version of Arrow every two weeks. This cadence balances getting new features into arrow without overwhelming downstream projects with too frequent changes. -- All new PRs are created and merged against `master` -- All versions are created from the `active_release` branch -- Once merged to master, changes are "cherry-picked" (via a hopefully soon to be automated process), to the `active_release` branch based on the judgement of the original PR author and maintainers. - -- We do not merge breaking api changes, as defined in [Rust RFC 1105](https://github.com/rust-lang/rfcs/blob/master/text/1105-api-evolution.md) to the `active_release` branch. Instead, they are merged to the `master` branch and included in the next major release. 
- -Please see the [original proposal](https://docs.google.com/document/d/1tMQ67iu8XyGGZuj--h9WQYB9inCk6c2sL_4xMTwENGc/edit?ts=60961758) document the rational of this change. - -## Release Branching - -We aim to release every other week from the `active_release` branch. - -Every other week, a maintainer proposes a minor (e.g. `4.1.0` to `4.2.0`) or patch (e.g `4.1.0` to `4.1.1`) release, depending on changes to the `active_release` in the previous 2 weeks, following the process below. - -If this release is approved by at least three PMC members, that tarball is uploaded to the official apache distribution sites, a new version from that tarball is released to crates.io later in the week. - -The overall Apache Arrow in general does synchronized major releases every three months. The Rust implementation aims to do its major releases in the same time frame. +If any code has been merged to master that has a breaking API change, as defined in [Rust RFC 1105](https://github.com/rust-lang/rfcs/blob/master/text/1105-api-evolution.md), the major version number is incremented (e.g. `9.0.2` to `10.0.0`). Otherwise the new minor version is incremented (e.g. `9.0.2` to `9.1.0`). # Release Mechanics -This directory contains the scripts used to manage an Apache Arrow Release. - -# Process Overview +## Process Overview As part of the Apache governance model, official releases consist of signed source tarballs approved by the PMC. @@ -53,12 +35,6 @@ signed source tarballs approved by the PMC. We then use the code in the approved source tarball to release to crates.io, the Rust ecosystem's package manager. -## Branching - -# Release Preparation - -# Change Log - We create a `CHANGELOG.md` so our users know what has been changed between releases. The CHANGELOG is created automatically using @@ -67,47 +43,14 @@ The CHANGELOG is created automatically using This script creates a changelog using github issues and the labels associated with them. -## CHANGELOG for maintenance releases +## Prepare CHANGELOG and version: -At the time of writing, the `update_change_log.sh` script does not work well with branches as it seems to intermix issues that were resolved in master. +Now prepare a PR to update `CHANGELOG.md` and versions on `master` to reflect the planned release. -To generate a bare bones CHANGELOG for maintenance releases, you can use a command similar to the following to get all changes between 5.0.0 and the active_release. - -```shell -git log --pretty=oneline 5.0.0..apache/active_release -``` - -This command will make markdown links that work when rendered outside of github: - -```shell -git log --pretty=oneline 5.1.0..apache/active_release | sed -e 's|\(^[0-9a-f]*\) |* [\1](https://github.com/apache/arrow-rs/commit/\1) |' | sed -e 's|#\([0-9]*\)|[#\1](https://github.com/apache/arrow-rs/pull/\1)|g' -``` - -# Mechanics of creating a release - -## Prepare the release branch and tags - -First, ensure that `active_release` contains the content of the desired release. For minor and patch releases, no additional steps are needed. - -To prepare for _a major release_, change `active release` to point at the latest `master` with commands such as: - -``` -git checkout active_release -git fetch apache -git reset --hard apache/master -git push -f -``` - -### Update CHANGELOG.md + Version - -Now prepare a PR to update `CHANGELOG.md` and versions on `active_release` branch to reflect the planned release. - -See [#298](https://github.com/apache/arrow-rs/pull/298) for an example.
- -Here are the commands that could be used to prepare the 5.1.0 release: +See [#1141](https://github.com/apache/arrow-rs/pull/1141) for an example. ```bash -git checkout active_release +git checkout master git pull git checkout -b make-release @@ -118,7 +61,7 @@ CHANGELOG_GITHUB_TOKEN= ./dev/release/update_change_log.sh git commit -a -m 'Create changelog' # update versions -sed -i '' -e 's/5.0.0-SNAPSHOT/5.1.0/g' `find . -name 'Cargo.toml' -or -name '*.md'` +sed -i '' -e 's/14.0.0/16.0.0/g' `find . -name 'Cargo.toml' -or -name '*.md' | grep -v CHANGELOG.md` git commit -a -m 'Update version' ``` @@ -126,9 +69,14 @@ Note that when reviewing the change log, rather than editing the `CHANGELOG.md`, it is preferred to update the issues and their labels (e.g. add `invalid` label to exclude them from release notes) +Merge this PR to `master` prior to the next step. + ## Prepare release candidate tarball -(Note you need to be a committer to run these scripts as they upload to the apache svn distribution servers) +After you have merged the updates to the `CHANGELOG` and version, +create a release candidate using the following steps. Note you need to +be a committer to run these scripts as they upload to the apache `svn` +distribution servers. ### Create git tag for the release: @@ -138,7 +86,7 @@ Using a string such as `4.0.1` as the ``, create and push the tag thusl ```shell git fetch apache -git tag apache/active_release +git tag apache/master # push tag to apache git push apache ``` @@ -198,7 +146,7 @@ The vote will be open for at least 72 hours. For the release to become "official" it needs at least three PMC members to vote +1 on it. -#### Verifying Release Candidates +## Verifying release candidates The `dev/release/verify-release-candidate.sh` is a script in this repository that can assist in the verification process. Run it like: @@ -246,41 +194,3 @@ following commands (cd parquet && cargo publish) (cd parquet_derive && cargo publish) ``` - -# Backporting - -As of the time of writing, backporting to `active_release` done semi-manually. - -_Note_: Since minor releases will be automatically picked up by other CI systems, it is CRITICAL to only cherry pick commits that are API compatible -- that is that do not require any code changes in crates that rely on arrow. API changes are released with the next major version. - -Step 1: Pick the commit to cherry-pick. - -Step 2: Create cherry-pick PR to active_release - -Step 3a: If CI passes, merge cherry-pick PR - -Step 3b: If CI doesn't pass or some other changes are needed, the PR should be reviewed / approved as normal prior to merge - -For example, to backport `b2de5446cc1e45a0559fb39039d0545df1ac0d26` to active_release, you could use the following command - -```shell -git clone git@github.com:apache/arrow-rs.git /tmp/arrow-rs - -CHERRY_PICK_SHA=b2de5446cc1e45a0559fb39039d0545df1ac0d26 ARROW_GITHUB_API_TOKEN=$ARROW_GITHUB_API_TOKEN CHECKOUT_ROOT=/tmp/arrow-rs python3 dev/release/cherry-pick-pr.py -``` - -## Labels - -There are two labels that help keep track of backporting: - -1. [`cherry-picked`](https://github.com/apache/arrow-rs/labels/cherry-picked) for PRs that have been cherry-picked/backported to `active_release` -2. 
[`release-cherry-pick`](https://github.com/apache/arrow-rs/labels/release-cherry-pick) for the PRs that are the cherry pick to `active_release` - -You can find candidates to cherry pick using [this filter](https://github.com/apache/arrow-rs/pulls?q=is%3Apr+is%3Aclosed+-label%3Arelease-cherry-pick+-label%3Acherry-picked) - -## Rationale for creating PRs on cherry picked PRs: - -1. PRs are a natural place to run the CI tests to make sure there are no logical conflicts -2. PRs offer a place for the original author / committers to comment and say it should/should not be backported. -3. PRs offer a way to make cleanups / fixups and approve (if needed) for non cherry pick PRs -4. There is an additional control / review when the candidate release is created diff --git a/dev/release/cherry-pick-pr.py b/dev/release/cherry-pick-pr.py deleted file mode 100755 index 2886a0dd458f..000000000000 --- a/dev/release/cherry-pick-pr.py +++ /dev/null @@ -1,154 +0,0 @@ -#!/usr/bin/python3 -############################################################################## -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-############################################################################## - -# This script is designed to create a cherry pick PR to a target branch -# -# Usage: python3 cherry_pick_pr.py -# -# To test locally: -# -# git clone git@github.com:apache/arrow-rs.git /tmp/arrow-rs -# -# pip3 install PyGithub -# ARROW_GITHUB_API_TOKEN=<..> -# CHECKOUT_ROOT= -# CHERRY_PICK_SHA= python3 cherry-pick-pr.py -# -import os -import sys -import six -import subprocess - -from pathlib import Path - -TARGET_BRANCH = 'active_release' -TARGET_REPO = 'apache/arrow-rs' - -p = Path(__file__) - -# Use github workspace if specified -repo_root = os.environ.get("CHECKOUT_ROOT") -if repo_root is None: - print("arrow-rs checkout must be supplied in CHECKOUT_ROOT environment") - sys.exit(1) - -print("Using checkout in {}".format(repo_root)) - -token = os.environ.get('ARROW_GITHUB_API_TOKEN', None) -if token is None: - print("GITHUB token must be supplied in ARROW_GITHUB_API_TOKEN environmet") - sys.exit(1) - -new_sha = os.environ.get('CHERRY_PICK_SHA', None) -if new_sha is None: - print("SHA to cherry pick must be supplied in CHERRY_PICK_SHA environment") - sys.exit(1) - - -# from merge_pr.py from arrow repo -def run_cmd(cmd): - if isinstance(cmd, six.string_types): - cmd = cmd.split(' ') - try: - output = subprocess.check_output(cmd) - except subprocess.CalledProcessError as e: - # this avoids hiding the stdout / stderr of failed processes - print('Command failed: %s' % cmd) - print('With output:') - print('--------------') - print(e.output) - print('--------------') - raise e - - if isinstance(output, six.binary_type): - output = output.decode('utf-8') - - return output - - -os.chdir(repo_root) -new_sha_short = run_cmd("git rev-parse --short {}".format(new_sha)).strip() -new_branch = 'cherry_pick_{}'.format(new_sha_short) - - -def make_cherry_pick(): - if os.environ.get('GITHUB_SHA', None) is not None: - print("Running on github runner, setting email/username") - run_cmd(['git', 'config', 'user.email', 'dev@arrow.apache.com']) - run_cmd(['git', 'config', 'user.name', 'Arrow-RS Automation']) - - # - # Create a new branch from active_release - # and cherry pick to there. - # - - print("Creating cherry pick from {} to {}".format( - new_sha_short, new_branch - )) - - # The following tortured dance is required due to how the github - # actions/checkout works (it doesn't pull other branches and pulls - # only one commit back) - - # pull 10 commits back so we can get the proper cherry pick - # (probably only need 2 but 10 must be better, right?) 
- run_cmd(['git', 'fetch', '--depth', '10', 'origin', 'master']) - run_cmd(['git', 'fetch', 'origin', 'active_release']) - run_cmd(['git', 'checkout', '-b', new_branch]) - run_cmd(['git', 'reset', '--hard', 'origin/active_release']) - run_cmd(['git', 'cherry-pick', new_sha]) - run_cmd(['git', 'push', '-u', 'origin', new_branch]) - - - -def make_cherry_pick_pr(): - from github import Github - g = Github(token) - repo = g.get_repo(TARGET_REPO) - - release_cherry_pick_label = repo.get_label('release-cherry-pick') - cherry_picked_label = repo.get_label('cherry-picked') - - # Default titles - new_title = 'Cherry pick {} to active_release'.format(new_sha) - new_commit_message = 'Automatic cherry-pick of {}\n'.format(new_sha) - - # try and get info from github api - commit = repo.get_commit(new_sha) - for orig_pull in commit.get_pulls(): - new_commit_message += '* Originally appeared in {}: {}\n'.format( - orig_pull.html_url, orig_pull.title) - new_title = 'Cherry pick {} to active_release'.format(orig_pull.title) - orig_pull.add_to_labels(cherry_picked_label) - - pr = repo.create_pull(title=new_title, - body=new_commit_message, - base='refs/heads/active_release', - head='refs/heads/{}'.format(new_branch), - maintainer_can_modify=True, - ) - - pr.add_to_labels(release_cherry_pick_label) - - print('Created PR {}'.format(pr.html_url)) - - -make_cherry_pick() -make_cherry_pick_pr() diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt index c976e2c58f86..466f6fa45267 100644 --- a/dev/release/rat_exclude_files.txt +++ b/dev/release/rat_exclude_files.txt @@ -11,6 +11,13 @@ parquet_derive/test/dependency/* Cargo.lock filtered_rat.txt rat.txt +conbench/benchmarks.json +conbench/requirements.txt +conbench/requirements-test.txt +conbench/.flake8 +conbench/.isort.cfg # auto-generated arrow-flight/src/arrow.flight.protocol.rs +arrow-flight/src/sql/arrow.flight.protocol.sql.rs .github/* +parquet/src/bin/parquet-fromcsv-help.txt diff --git a/dev/release/release-tarball.sh b/dev/release/release-tarball.sh index 9612921e8bef..352c15a3038b 100755 --- a/dev/release/release-tarball.sh +++ b/dev/release/release-tarball.sh @@ -32,6 +32,9 @@ set -e set -u +SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +SOURCE_TOP_DIR="$(cd "${SOURCE_DIR}/../../" && pwd)" + if [ "$#" -ne 2 ]; then echo "Usage: $0 " echo "ex. $0 4.1.0 2" @@ -68,5 +71,9 @@ svn ci -m "Apache Arrow Rust ${version}" ${tmp_dir}/release echo "Clean up" rm -rf ${tmp_dir} -echo "Success! The release is available here:" +echo "Success!" +echo "The release is available here:" echo " https://dist.apache.org/repos/dist/release/arrow/${release_version}" + +echo "Clean up old versions from svn" +"${SOURCE_TOP_DIR}"/dev/release/remove-old-releases.sh diff --git a/arrow/test/dependency/no-default-features/Cargo.toml b/dev/release/remove-old-releases.sh old mode 100644 new mode 100755 similarity index 52% rename from arrow/test/dependency/no-default-features/Cargo.toml rename to dev/release/remove-old-releases.sh index 6c7f0652324d..0722197a84b4 --- a/arrow/test/dependency/no-default-features/Cargo.toml +++ b/dev/release/remove-old-releases.sh @@ -1,3 +1,5 @@ +#!/bin/bash +# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -14,16 +16,30 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+# -[package] -name = "no-default-features" -description = "Models a user application of arrow that specifies no-default-features=true" -version = "0.1.0" -edition = "2018" +# This script removes all but the most recent versions of arrow-rs +# from svn +# +# The older versions are in SVN history as well as available on the +# archive page https://archive.apache.org/dist/ +# +# See +# https://infra.apache.org/release-download-pages.html -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +set -e +set -u -[dependencies] -arrow = { path = "../../../../arrow", version = "7.0.0-SNAPSHOT", default-features = false } +svn_base="https://dist.apache.org/repos/dist/release/arrow" -[workspace] +echo "Remove all but the most recent version" +old_releases=$( + svn ls ${svn_base} | \ + grep -E '^arrow-rs-[0-9\.]+' | \ + sort --version-sort --reverse | \ + tail -n +2 +) +for old_release_version in $old_releases; do + echo "Remove old release ${old_release_version}" + svn delete -m "Removing ${old_release_version}" ${svn_base}/${old_release_version} +done diff --git a/dev/release/update_change_log.sh b/dev/release/update_change_log.sh index 53c7f06016fd..316f10c2594b 100755 --- a/dev/release/update_change_log.sh +++ b/dev/release/update_change_log.sh @@ -29,14 +29,48 @@ set -e +SINCE_TAG="15.0.0" +FUTURE_RELEASE="16.0.0" + SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" SOURCE_TOP_DIR="$(cd "${SOURCE_DIR}/../../" && pwd)" -pushd ${SOURCE_TOP_DIR} -docker run -it --rm -e CHANGELOG_GITHUB_TOKEN=$CHANGELOG_GITHUB_TOKEN -v "$(pwd)":/usr/local/src/your-app githubchangeloggenerator/github-changelog-generator \ +OUTPUT_PATH="${SOURCE_TOP_DIR}/CHANGELOG.md" + +# remove license header so github-changelog-generator has a clean base to append +sed -i.bak '1,18d' "${OUTPUT_PATH}" + +pushd "${SOURCE_TOP_DIR}" +docker run -it --rm -e CHANGELOG_GITHUB_TOKEN="$CHANGELOG_GITHUB_TOKEN" -v "$(pwd)":/usr/local/src/your-app githubchangeloggenerator/github-changelog-generator \ --user apache \ --project arrow-rs \ - --since-tag 5.0.0 \ - --future-release 6.0.0 + --cache-file=.githubchangeloggenerator.cache \ + --cache-log=.githubchangeloggenerator.cache.log \ + --http-cache \ + --max-issues=300 \ + --since-tag ${SINCE_TAG} \ + --future-release ${FUTURE_RELEASE} + +sed -i.bak "s/\\\n/\n\n/" "${OUTPUT_PATH}" + +# Put license header back on +echo ' +' | cat - "${OUTPUT_PATH}" > "${OUTPUT_PATH}".tmp +mv "${OUTPUT_PATH}".tmp "${OUTPUT_PATH}" diff --git a/dev/release/verify-release-candidate.sh b/dev/release/verify-release-candidate.sh index 6501aa0ae430..a5ed04c6f8b8 100755 --- a/dev/release/verify-release-candidate.sh +++ b/dev/release/verify-release-candidate.sh @@ -53,6 +53,14 @@ import_gpg_keys() { gpg --import KEYS } +if type shasum >/dev/null 2>&1; then + sha256_verify="shasum -a 256 -c" + sha512_verify="shasum -a 512 -c" +else + sha256_verify="sha256sum -c" + sha512_verify="sha512sum -c" +fi + fetch_archive() { local dist_name=$1 download_rc_file ${dist_name}.tar.gz @@ -60,8 +68,8 @@ fetch_archive() { download_rc_file ${dist_name}.tar.gz.sha256 download_rc_file ${dist_name}.tar.gz.sha512 gpg --verify ${dist_name}.tar.gz.asc ${dist_name}.tar.gz - shasum -a 256 -c ${dist_name}.tar.gz.sha256 - shasum -a 512 -c ${dist_name}.tar.gz.sha512 + ${sha256_verify} ${dist_name}.tar.gz.sha256 + ${sha512_verify} ${dist_name}.tar.gz.sha512 } verify_dir_artifact_signatures() { @@ -75,9 +83,9 @@ verify_dir_artifact_signatures() { pushd $(dirname $artifact) base_artifact=$(basename 
$artifact) if [ -f $base_artifact.sha256 ]; then - shasum -a 256 -c $base_artifact.sha256 || exit 1 + ${sha256_verify} $base_artifact.sha256 || exit 1 fi - shasum -a 512 -c $base_artifact.sha512 || exit 1 + ${sha512_verify} $base_artifact.sha512 || exit 1 popd done } @@ -139,17 +147,9 @@ test_source_distribution() { cargo publish --dry-run popd - pushd arrow-flight - cargo publish --dry-run - popd - - pushd parquet - cargo publish --dry-run - popd + # Note can't verify parquet/arrow-flight/parquet-derive until arrow is actually published + # as they depend on arrow - pushd parquet_derive - cargo publish --dry-run - popd } TEST_SUCCESS=no diff --git a/format/FlightSql.proto b/format/FlightSql.proto new file mode 100644 index 000000000000..3e85e348bc9c --- /dev/null +++ b/format/FlightSql.proto @@ -0,0 +1,1337 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; +import "google/protobuf/descriptor.proto"; + +option java_package = "org.apache.arrow.flight.sql.impl"; + +package arrow.flight.protocol.sql; + +/* + * Represents a metadata request. Used in the command member of FlightDescriptor + * for the following RPC calls: + * - GetSchema: return the Arrow schema of the query. + * - GetFlightInfo: execute the metadata request. + * + * The returned Arrow schema will be: + * < + * info_name: uint32 not null, + * value: dense_union< + * string_value: utf8, + * bool_value: bool, + * bigint_value: int64, + * int32_bitmask: int32, + * string_list: list + * int32_to_int32_list_map: map> + * > + * where there is one row per requested piece of metadata information. + */ +message CommandGetSqlInfo { + option (experimental) = true; + + /* + * Values are modelled after ODBC's SQLGetInfo() function. This information is intended to provide + * Flight SQL clients with basic, SQL syntax and SQL functions related information. + * More information types can be added in future releases. + * E.g. more SQL syntax support types, scalar functions support, type conversion support etc. + * + * Note that the set of metadata may expand. + * + * Initially, Flight SQL will support the following information types: + * - Server Information - Range [0-500) + * - Syntax Information - Range [500-1000) + * Range [0-10,000) is reserved for defaults (see SqlInfo enum for default options). + * Custom options should start at 10,000. + * + * If omitted, then all metadata will be retrieved. + * Flight SQL Servers may choose to include additional metadata above and beyond the specified set, however they must + * at least return the specified set. IDs ranging from 0 to 10,000 (exclusive) are reserved for future use. + * If additional metadata is included, the metadata IDs should start from 10,000. + */ + repeated uint32 info = 1; +} + +// Options for CommandGetSqlInfo. +enum SqlInfo { + + // Server Information [0-500): Provides basic information about the Flight SQL Server. + + // Retrieves a UTF-8 string with the name of the Flight SQL Server. + FLIGHT_SQL_SERVER_NAME = 0; + + // Retrieves a UTF-8 string with the native version of the Flight SQL Server. + FLIGHT_SQL_SERVER_VERSION = 1; + + // Retrieves a UTF-8 string with the Arrow format version of the Flight SQL Server. + FLIGHT_SQL_SERVER_ARROW_VERSION = 2; + + /* + * Retrieves a boolean value indicating whether the Flight SQL Server is read only. + * + * Returns: + * - false: if read-write + * - true: if read only + */ + FLIGHT_SQL_SERVER_READ_ONLY = 3; + + + // SQL Syntax Information [500-1000): provides information about SQL syntax supported by the Flight SQL Server. + + /* + * Retrieves a boolean value indicating whether the Flight SQL Server supports CREATE and DROP of catalogs. + * + * Returns: + * - false: if it doesn't support CREATE and DROP of catalogs. + * - true: if it supports CREATE and DROP of catalogs. + */ + SQL_DDL_CATALOG = 500; + + /* + * Retrieves a boolean value indicating whether the Flight SQL Server supports CREATE and DROP of schemas. + * + * Returns: + * - false: if it doesn't support CREATE and DROP of schemas. 
+ * - true: if it supports CREATE and DROP of schemas. + */ + SQL_DDL_SCHEMA = 501; + + /* + * Indicates whether the Flight SQL Server supports CREATE and DROP of tables. + * + * Returns: + * - false: if it doesn't support CREATE and DROP of tables. + * - true: if it supports CREATE and DROP of tables. + */ + SQL_DDL_TABLE = 502; + + /* + * Retrieves a uint32 value representing the enu uint32 ordinal for the case sensitivity of catalog, table, schema and table names. + * + * The possible values are listed in `arrow.flight.protocol.sql.SqlSupportedCaseSensitivity`. + */ + SQL_IDENTIFIER_CASE = 503; + + // Retrieves a UTF-8 string with the supported character(s) used to surround a delimited identifier. + SQL_IDENTIFIER_QUOTE_CHAR = 504; + + /* + * Retrieves a uint32 value representing the enu uint32 ordinal for the case sensitivity of quoted identifiers. + * + * The possible values are listed in `arrow.flight.protocol.sql.SqlSupportedCaseSensitivity`. + */ + SQL_QUOTED_IDENTIFIER_CASE = 505; + + /* + * Retrieves a boolean value indicating whether all tables are selectable. + * + * Returns: + * - false: if not all tables are selectable or if none are; + * - true: if all tables are selectable. + */ + SQL_ALL_TABLES_ARE_SELECTABLE = 506; + + /* + * Retrieves the null ordering. + * + * Returns a uint32 ordinal for the null ordering being used, as described in + * `arrow.flight.protocol.sql.SqlNullOrdering`. + */ + SQL_NULL_ORDERING = 507; + + // Retrieves a UTF-8 string list with values of the supported keywords. + SQL_KEYWORDS = 508; + + // Retrieves a UTF-8 string list with values of the supported numeric functions. + SQL_NUMERIC_FUNCTIONS = 509; + + // Retrieves a UTF-8 string list with values of the supported string functions. + SQL_STRING_FUNCTIONS = 510; + + // Retrieves a UTF-8 string list with values of the supported system functions. + SQL_SYSTEM_FUNCTIONS = 511; + + // Retrieves a UTF-8 string list with values of the supported datetime functions. + SQL_DATETIME_FUNCTIONS = 512; + + /* + * Retrieves the UTF-8 string that can be used to escape wildcard characters. + * This is the string that can be used to escape '_' or '%' in the catalog search parameters that are a pattern + * (and therefore use one of the wildcard characters). + * The '_' character represents any single character; the '%' character represents any sequence of zero or more + * characters. + */ + SQL_SEARCH_STRING_ESCAPE = 513; + + /* + * Retrieves a UTF-8 string with all the "extra" characters that can be used in unquoted identifier names + * (those beyond a-z, A-Z, 0-9 and _). + */ + SQL_EXTRA_NAME_CHARACTERS = 514; + + /* + * Retrieves a boolean value indicating whether column aliasing is supported. + * If so, the SQL AS clause can be used to provide names for computed columns or to provide alias names for columns + * as required. + * + * Returns: + * - false: if column aliasing is unsupported; + * - true: if column aliasing is supported. + */ + SQL_SUPPORTS_COLUMN_ALIASING = 515; + + /* + * Retrieves a boolean value indicating whether concatenations between null and non-null values being + * null are supported. + * + * - Returns: + * - false: if concatenations between null and non-null values being null are unsupported; + * - true: if concatenations between null and non-null values being null are supported. + */ + SQL_NULL_PLUS_NULL_IS_NULL = 516; + + /* + * Retrieves a map where the key is the type to convert from and the value is a list with the types to convert to, + * indicating the supported conversions. 
Each key and each item on the list value is a value to a predefined type on + * SqlSupportsConvert enum. + * The returned map will be: map> + */ + SQL_SUPPORTS_CONVERT = 517; + + /* + * Retrieves a boolean value indicating whether, when table correlation names are supported, + * they are restricted to being different from the names of the tables. + * + * Returns: + * - false: if table correlation names are unsupported; + * - true: if table correlation names are supported. + */ + SQL_SUPPORTS_TABLE_CORRELATION_NAMES = 518; + + /* + * Retrieves a boolean value indicating whether, when table correlation names are supported, + * they are restricted to being different from the names of the tables. + * + * Returns: + * - false: if different table correlation names are unsupported; + * - true: if different table correlation names are supported + */ + SQL_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES = 519; + + /* + * Retrieves a boolean value indicating whether expressions in ORDER BY lists are supported. + * + * Returns: + * - false: if expressions in ORDER BY are unsupported; + * - true: if expressions in ORDER BY are supported; + */ + SQL_SUPPORTS_EXPRESSIONS_IN_ORDER_BY = 520; + + /* + * Retrieves a boolean value indicating whether using a column that is not in the SELECT statement in a GROUP BY + * clause is supported. + * + * Returns: + * - false: if using a column that is not in the SELECT statement in a GROUP BY clause is unsupported; + * - true: if using a column that is not in the SELECT statement in a GROUP BY clause is supported. + */ + SQL_SUPPORTS_ORDER_BY_UNRELATED = 521; + + /* + * Retrieves the supported GROUP BY commands; + * + * Returns an int32 bitmask value representing the supported commands. + * The returned bitmask should be parsed in order to retrieve the supported commands. + * + * For instance: + * - return 0 (\b0) => [] (GROUP BY is unsupported); + * - return 1 (\b1) => [SQL_GROUP_BY_UNRELATED]; + * - return 2 (\b10) => [SQL_GROUP_BY_BEYOND_SELECT]; + * - return 3 (\b11) => [SQL_GROUP_BY_UNRELATED, SQL_GROUP_BY_BEYOND_SELECT]. + * Valid GROUP BY types are described under `arrow.flight.protocol.sql.SqlSupportedGroupBy`. + */ + SQL_SUPPORTED_GROUP_BY = 522; + + /* + * Retrieves a boolean value indicating whether specifying a LIKE escape clause is supported. + * + * Returns: + * - false: if specifying a LIKE escape clause is unsupported; + * - true: if specifying a LIKE escape clause is supported. + */ + SQL_SUPPORTS_LIKE_ESCAPE_CLAUSE = 523; + + /* + * Retrieves a boolean value indicating whether columns may be defined as non-nullable. + * + * Returns: + * - false: if columns cannot be defined as non-nullable; + * - true: if columns may be defined as non-nullable. + */ + SQL_SUPPORTS_NON_NULLABLE_COLUMNS = 524; + + /* + * Retrieves the supported SQL grammar level as per the ODBC specification. + * + * Returns an int32 bitmask value representing the supported SQL grammar level. + * The returned bitmask should be parsed in order to retrieve the supported grammar levels. + * + * For instance: + * - return 0 (\b0) => [] (SQL grammar is unsupported); + * - return 1 (\b1) => [SQL_MINIMUM_GRAMMAR]; + * - return 2 (\b10) => [SQL_CORE_GRAMMAR]; + * - return 3 (\b11) => [SQL_MINIMUM_GRAMMAR, SQL_CORE_GRAMMAR]; + * - return 4 (\b100) => [SQL_EXTENDED_GRAMMAR]; + * - return 5 (\b101) => [SQL_MINIMUM_GRAMMAR, SQL_EXTENDED_GRAMMAR]; + * - return 6 (\b110) => [SQL_CORE_GRAMMAR, SQL_EXTENDED_GRAMMAR]; + * - return 7 (\b111) => [SQL_MINIMUM_GRAMMAR, SQL_CORE_GRAMMAR, SQL_EXTENDED_GRAMMAR]. 
+ * Valid SQL grammar levels are described under `arrow.flight.protocol.sql.SupportedSqlGrammar`. + */ + SQL_SUPPORTED_GRAMMAR = 525; + + /* + * Retrieves the supported ANSI92 SQL grammar level. + * + * Returns an int32 bitmask value representing the supported ANSI92 SQL grammar level. + * The returned bitmask should be parsed in order to retrieve the supported commands. + * + * For instance: + * - return 0 (\b0) => [] (ANSI92 SQL grammar is unsupported); + * - return 1 (\b1) => [ANSI92_ENTRY_SQL]; + * - return 2 (\b10) => [ANSI92_INTERMEDIATE_SQL]; + * - return 3 (\b11) => [ANSI92_ENTRY_SQL, ANSI92_INTERMEDIATE_SQL]; + * - return 4 (\b100) => [ANSI92_FULL_SQL]; + * - return 5 (\b101) => [ANSI92_ENTRY_SQL, ANSI92_FULL_SQL]; + * - return 6 (\b110) => [ANSI92_INTERMEDIATE_SQL, ANSI92_FULL_SQL]; + * - return 7 (\b111) => [ANSI92_ENTRY_SQL, ANSI92_INTERMEDIATE_SQL, ANSI92_FULL_SQL]. + * Valid ANSI92 SQL grammar levels are described under `arrow.flight.protocol.sql.SupportedAnsi92SqlGrammarLevel`. + */ + SQL_ANSI92_SUPPORTED_LEVEL = 526; + + /* + * Retrieves a boolean value indicating whether the SQL Integrity Enhancement Facility is supported. + * + * Returns: + * - false: if the SQL Integrity Enhancement Facility is unsupported; + * - true: if the SQL Integrity Enhancement Facility is supported. + */ + SQL_SUPPORTS_INTEGRITY_ENHANCEMENT_FACILITY = 527; + + /* + * Retrieves the support level for SQL OUTER JOINs. + * + * Returns a uint32 ordinal for the SQL ordering being used, as described in + * `arrow.flight.protocol.sql.SqlOuterJoinsSupportLevel`. + */ + SQL_OUTER_JOINS_SUPPORT_LEVEL = 528; + + // Retrieves a UTF-8 string with the preferred term for "schema". + SQL_SCHEMA_TERM = 529; + + // Retrieves a UTF-8 string with the preferred term for "procedure". + SQL_PROCEDURE_TERM = 530; + + // Retrieves a UTF-8 string with the preferred term for "catalog". + SQL_CATALOG_TERM = 531; + + /* + * Retrieves a boolean value indicating whether a catalog appears at the start of a fully qualified table name. + * + * - false: if a catalog does not appear at the start of a fully qualified table name; + * - true: if a catalog appears at the start of a fully qualified table name. + */ + SQL_CATALOG_AT_START = 532; + + /* + * Retrieves the supported actions for a SQL schema. + * + * Returns an int32 bitmask value representing the supported actions for a SQL schema. + * The returned bitmask should be parsed in order to retrieve the supported actions for a SQL schema. + * + * For instance: + * - return 0 (\b0) => [] (no supported actions for SQL schema); + * - return 1 (\b1) => [SQL_ELEMENT_IN_PROCEDURE_CALLS]; + * - return 2 (\b10) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS]; + * - return 3 (\b11) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS]; + * - return 4 (\b100) => [SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; + * - return 5 (\b101) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; + * - return 6 (\b110) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; + * - return 7 (\b111) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]. + * Valid actions for a SQL schema are described under `arrow.flight.protocol.sql.SqlSupportedElementActions`. + */ + SQL_SCHEMAS_SUPPORTED_ACTIONS = 533; + + /* + * Retrieves the supported actions for a SQL catalog. + * + * Returns an int32 bitmask value representing the supported actions for a SQL catalog. 
+ * The returned bitmask should be parsed in order to retrieve the supported actions for a SQL catalog. + * + * For instance: + * - return 0 (\b0) => [] (no supported actions for SQL catalog); + * - return 1 (\b1) => [SQL_ELEMENT_IN_PROCEDURE_CALLS]; + * - return 2 (\b10) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS]; + * - return 3 (\b11) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS]; + * - return 4 (\b100) => [SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; + * - return 5 (\b101) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; + * - return 6 (\b110) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; + * - return 7 (\b111) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]. + * Valid actions for a SQL catalog are described under `arrow.flight.protocol.sql.SqlSupportedElementActions`. + */ + SQL_CATALOGS_SUPPORTED_ACTIONS = 534; + + /* + * Retrieves the supported SQL positioned commands. + * + * Returns an int32 bitmask value representing the supported SQL positioned commands. + * The returned bitmask should be parsed in order to retrieve the supported SQL positioned commands. + * + * For instance: + * - return 0 (\b0) => [] (no supported SQL positioned commands); + * - return 1 (\b1) => [SQL_POSITIONED_DELETE]; + * - return 2 (\b10) => [SQL_POSITIONED_UPDATE]; + * - return 3 (\b11) => [SQL_POSITIONED_DELETE, SQL_POSITIONED_UPDATE]. + * Valid SQL positioned commands are described under `arrow.flight.protocol.sql.SqlSupportedPositionedCommands`. + */ + SQL_SUPPORTED_POSITIONED_COMMANDS = 535; + + /* + * Retrieves a boolean value indicating whether SELECT FOR UPDATE statements are supported. + * + * Returns: + * - false: if SELECT FOR UPDATE statements are unsupported; + * - true: if SELECT FOR UPDATE statements are supported. + */ + SQL_SELECT_FOR_UPDATE_SUPPORTED = 536; + + /* + * Retrieves a boolean value indicating whether stored procedure calls that use the stored procedure escape syntax + * are supported. + * + * Returns: + * - false: if stored procedure calls that use the stored procedure escape syntax are unsupported; + * - true: if stored procedure calls that use the stored procedure escape syntax are supported. + */ + SQL_STORED_PROCEDURES_SUPPORTED = 537; + + /* + * Retrieves the supported SQL subqueries. + * + * Returns an int32 bitmask value representing the supported SQL subqueries. + * The returned bitmask should be parsed in order to retrieve the supported SQL subqueries. 
+ * + * For instance: + * - return 0 (\b0) => [] (no supported SQL subqueries); + * - return 1 (\b1) => [SQL_SUBQUERIES_IN_COMPARISONS]; + * - return 2 (\b10) => [SQL_SUBQUERIES_IN_EXISTS]; + * - return 3 (\b11) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS]; + * - return 4 (\b100) => [SQL_SUBQUERIES_IN_INS]; + * - return 5 (\b101) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_INS]; + * - return 6 (\b110) => [SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_EXISTS]; + * - return 7 (\b111) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS]; + * - return 8 (\b1000) => [SQL_SUBQUERIES_IN_QUANTIFIEDS]; + * - return 9 (\b1001) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + * - return 10 (\b1010) => [SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + * - return 11 (\b1011) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + * - return 12 (\b1100) => [SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + * - return 13 (\b1101) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + * - return 14 (\b1110) => [SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + * - return 15 (\b1111) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + * - ... + * Valid SQL subqueries are described under `arrow.flight.protocol.sql.SqlSupportedSubqueries`. + */ + SQL_SUPPORTED_SUBQUERIES = 538; + + /* + * Retrieves a boolean value indicating whether correlated subqueries are supported. + * + * Returns: + * - false: if correlated subqueries are unsupported; + * - true: if correlated subqueries are supported. + */ + SQL_CORRELATED_SUBQUERIES_SUPPORTED = 539; + + /* + * Retrieves the supported SQL UNIONs. + * + * Returns an int32 bitmask value representing the supported SQL UNIONs. + * The returned bitmask should be parsed in order to retrieve the supported SQL UNIONs. + * + * For instance: + * - return 0 (\b0) => [] (no supported SQL positioned commands); + * - return 1 (\b1) => [SQL_UNION]; + * - return 2 (\b10) => [SQL_UNION_ALL]; + * - return 3 (\b11) => [SQL_UNION, SQL_UNION_ALL]. + * Valid SQL positioned commands are described under `arrow.flight.protocol.sql.SqlSupportedUnions`. + */ + SQL_SUPPORTED_UNIONS = 540; + + // Retrieves a uint32 value representing the maximum number of hex characters allowed in an inline binary literal. + SQL_MAX_BINARY_LITERAL_LENGTH = 541; + + // Retrieves a uint32 value representing the maximum number of characters allowed for a character literal. + SQL_MAX_CHAR_LITERAL_LENGTH = 542; + + // Retrieves a uint32 value representing the maximum number of characters allowed for a column name. + SQL_MAX_COLUMN_NAME_LENGTH = 543; + + // Retrieves a uint32 value representing the the maximum number of columns allowed in a GROUP BY clause. + SQL_MAX_COLUMNS_IN_GROUP_BY = 544; + + // Retrieves a uint32 value representing the maximum number of columns allowed in an index. + SQL_MAX_COLUMNS_IN_INDEX = 545; + + // Retrieves a uint32 value representing the maximum number of columns allowed in an ORDER BY clause. + SQL_MAX_COLUMNS_IN_ORDER_BY = 546; + + // Retrieves a uint32 value representing the maximum number of columns allowed in a SELECT list. + SQL_MAX_COLUMNS_IN_SELECT = 547; + + // Retrieves a uint32 value representing the maximum number of columns allowed in a table. 
+ SQL_MAX_COLUMNS_IN_TABLE = 548; + + // Retrieves a uint32 value representing the maximum number of concurrent connections possible. + SQL_MAX_CONNECTIONS = 549; + + // Retrieves a uint32 value the maximum number of characters allowed in a cursor name. + SQL_MAX_CURSOR_NAME_LENGTH = 550; + + /* + * Retrieves a uint32 value representing the maximum number of bytes allowed for an index, + * including all of the parts of the index. + */ + SQL_MAX_INDEX_LENGTH = 551; + + // Retrieves a uint32 value representing the maximum number of characters allowed in a schema name. + SQL_DB_SCHEMA_NAME_LENGTH = 552; + + // Retrieves a uint32 value representing the maximum number of characters allowed in a procedure name. + SQL_MAX_PROCEDURE_NAME_LENGTH = 553; + + // Retrieves a uint32 value representing the maximum number of characters allowed in a catalog name. + SQL_MAX_CATALOG_NAME_LENGTH = 554; + + // Retrieves a uint32 value representing the maximum number of bytes allowed in a single row. + SQL_MAX_ROW_SIZE = 555; + + /* + * Retrieves a boolean indicating whether the return value for the JDBC method getMaxRowSize includes the SQL + * data types LONGVARCHAR and LONGVARBINARY. + * + * Returns: + * - false: if return value for the JDBC method getMaxRowSize does + * not include the SQL data types LONGVARCHAR and LONGVARBINARY; + * - true: if return value for the JDBC method getMaxRowSize includes + * the SQL data types LONGVARCHAR and LONGVARBINARY. + */ + SQL_MAX_ROW_SIZE_INCLUDES_BLOBS = 556; + + /* + * Retrieves a uint32 value representing the maximum number of characters allowed for an SQL statement; + * a result of 0 (zero) means that there is no limit or the limit is not known. + */ + SQL_MAX_STATEMENT_LENGTH = 557; + + // Retrieves a uint32 value representing the maximum number of active statements that can be open at the same time. + SQL_MAX_STATEMENTS = 558; + + // Retrieves a uint32 value representing the maximum number of characters allowed in a table name. + SQL_MAX_TABLE_NAME_LENGTH = 559; + + // Retrieves a uint32 value representing the maximum number of tables allowed in a SELECT statement. + SQL_MAX_TABLES_IN_SELECT = 560; + + // Retrieves a uint32 value representing the maximum number of characters allowed in a user name. + SQL_MAX_USERNAME_LENGTH = 561; + + /* + * Retrieves this database's default transaction isolation level as described in + * `arrow.flight.protocol.sql.SqlTransactionIsolationLevel`. + * + * Returns a uint32 ordinal for the SQL transaction isolation level. + */ + SQL_DEFAULT_TRANSACTION_ISOLATION = 562; + + /* + * Retrieves a boolean value indicating whether transactions are supported. If not, invoking the method commit is a + * noop, and the isolation level is `arrow.flight.protocol.sql.SqlTransactionIsolationLevel.TRANSACTION_NONE`. + * + * Returns: + * - false: if transactions are unsupported; + * - true: if transactions are supported. + */ + SQL_TRANSACTIONS_SUPPORTED = 563; + + /* + * Retrieves the supported transactions isolation levels. + * + * Returns an int32 bitmask value representing the supported transactions isolation levels. + * The returned bitmask should be parsed in order to retrieve the supported transactions isolation levels. 
+ * + * For instance: + * - return 0 (\b0) => [] (no supported SQL transactions isolation levels); + * - return 1 (\b1) => [SQL_TRANSACTION_NONE]; + * - return 2 (\b10) => [SQL_TRANSACTION_READ_UNCOMMITTED]; + * - return 3 (\b11) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED]; + * - return 4 (\b100) => [SQL_TRANSACTION_REPEATABLE_READ]; + * - return 5 (\b101) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ]; + * - return 6 (\b110) => [SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; + * - return 7 (\b111) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; + * - return 8 (\b1000) => [SQL_TRANSACTION_REPEATABLE_READ]; + * - return 9 (\b1001) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ]; + * - return 10 (\b1010) => [SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; + * - return 11 (\b1011) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; + * - return 12 (\b1100) => [SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; + * - return 13 (\b1101) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; + * - return 14 (\b1110) => [SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; + * - return 15 (\b1111) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; + * - return 16 (\b10000) => [SQL_TRANSACTION_SERIALIZABLE]; + * - ... + * Valid SQL positioned commands are described under `arrow.flight.protocol.sql.SqlTransactionIsolationLevel`. + */ + SQL_SUPPORTED_TRANSACTIONS_ISOLATION_LEVELS = 564; + + /* + * Retrieves a boolean value indicating whether a data definition statement within a transaction forces + * the transaction to commit. + * + * Returns: + * - false: if a data definition statement within a transaction does not force the transaction to commit; + * - true: if a data definition statement within a transaction forces the transaction to commit. + */ + SQL_DATA_DEFINITION_CAUSES_TRANSACTION_COMMIT = 565; + + /* + * Retrieves a boolean value indicating whether a data definition statement within a transaction is ignored. + * + * Returns: + * - false: if a data definition statement within a transaction is taken into account; + * - true: a data definition statement within a transaction is ignored. + */ + SQL_DATA_DEFINITIONS_IN_TRANSACTIONS_IGNORED = 566; + + /* + * Retrieves an int32 bitmask value representing the supported result set types. + * The returned bitmask should be parsed in order to retrieve the supported result set types. + * + * For instance: + * - return 0 (\b0) => [] (no supported result set types); + * - return 1 (\b1) => [SQL_RESULT_SET_TYPE_UNSPECIFIED]; + * - return 2 (\b10) => [SQL_RESULT_SET_TYPE_FORWARD_ONLY]; + * - return 3 (\b11) => [SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_FORWARD_ONLY]; + * - return 4 (\b100) => [SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE]; + * - return 5 (\b101) => [SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE]; + * - return 6 (\b110) => [SQL_RESULT_SET_TYPE_FORWARD_ONLY, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE]; + * - return 7 (\b111) => [SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_FORWARD_ONLY, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE]; + * - return 8 (\b1000) => [SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE]; + * - ... 
+ * Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetType`. + */ + SQL_SUPPORTED_RESULT_SET_TYPES = 567; + + /* + * Returns an int32 bitmask value concurrency types supported for + * `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_UNSPECIFIED`. + * + * For instance: + * - return 0 (\b0) => [] (no supported concurrency types for this result set type) + * - return 1 (\b1) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED] + * - return 2 (\b10) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + * - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + * - return 4 (\b100) => [SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. + */ + SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_UNSPECIFIED = 568; + + /* + * Returns an int32 bitmask value concurrency types supported for + * `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_FORWARD_ONLY`. + * + * For instance: + * - return 0 (\b0) => [] (no supported concurrency types for this result set type) + * - return 1 (\b1) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED] + * - return 2 (\b10) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + * - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + * - return 4 (\b100) => [SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. + */ + SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_FORWARD_ONLY = 569; + + /* + * Returns an int32 bitmask value concurrency types supported for + * `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE`. + * + * For instance: + * - return 0 (\b0) => [] (no supported concurrency types for this result set type) + * - return 1 (\b1) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED] + * - return 2 (\b10) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + * - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + * - return 4 (\b100) => [SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. 
+ */ + SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_SCROLL_SENSITIVE = 570; + + /* + * Returns an int32 bitmask value concurrency types supported for + * `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE`. + * + * For instance: + * - return 0 (\b0) => [] (no supported concurrency types for this result set type) + * - return 1 (\b1) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED] + * - return 2 (\b10) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + * - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + * - return 4 (\b100) => [SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + * Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. + */ + SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_SCROLL_INSENSITIVE = 571; + + /* + * Retrieves a boolean value indicating whether this database supports batch updates. + * + * - false: if this database does not support batch updates; + * - true: if this database supports batch updates. + */ + SQL_BATCH_UPDATES_SUPPORTED = 572; + + /* + * Retrieves a boolean value indicating whether this database supports savepoints. + * + * Returns: + * - false: if this database does not support savepoints; + * - true: if this database supports savepoints. + */ + SQL_SAVEPOINTS_SUPPORTED = 573; + + /* + * Retrieves a boolean value indicating whether named parameters are supported in callable statements. + * + * Returns: + * - false: if named parameters in callable statements are unsupported; + * - true: if named parameters in callable statements are supported. + */ + SQL_NAMED_PARAMETERS_SUPPORTED = 574; + + /* + * Retrieves a boolean value indicating whether updates made to a LOB are made on a copy or directly to the LOB. + * + * Returns: + * - false: if updates made to a LOB are made directly to the LOB; + * - true: if updates made to a LOB are made on a copy. + */ + SQL_LOCATORS_UPDATE_COPY = 575; + + /* + * Retrieves a boolean value indicating whether invoking user-defined or vendor functions + * using the stored procedure escape syntax is supported. + * + * Returns: + * - false: if invoking user-defined or vendor functions using the stored procedure escape syntax is unsupported; + * - true: if invoking user-defined or vendor functions using the stored procedure escape syntax is supported. 
+ */ + SQL_STORED_FUNCTIONS_USING_CALL_SYNTAX_SUPPORTED = 576; +} + +enum SqlSupportedCaseSensitivity { + SQL_CASE_SENSITIVITY_UNKNOWN = 0; + SQL_CASE_SENSITIVITY_CASE_INSENSITIVE = 1; + SQL_CASE_SENSITIVITY_UPPERCASE = 2; + SQL_CASE_SENSITIVITY_LOWERCASE = 3; +} + +enum SqlNullOrdering { + SQL_NULLS_SORTED_HIGH = 0; + SQL_NULLS_SORTED_LOW = 1; + SQL_NULLS_SORTED_AT_START = 2; + SQL_NULLS_SORTED_AT_END = 3; +} + +enum SupportedSqlGrammar { + SQL_MINIMUM_GRAMMAR = 0; + SQL_CORE_GRAMMAR = 1; + SQL_EXTENDED_GRAMMAR = 2; +} + +enum SupportedAnsi92SqlGrammarLevel { + ANSI92_ENTRY_SQL = 0; + ANSI92_INTERMEDIATE_SQL = 1; + ANSI92_FULL_SQL = 2; +} + +enum SqlOuterJoinsSupportLevel { + SQL_JOINS_UNSUPPORTED = 0; + SQL_LIMITED_OUTER_JOINS = 1; + SQL_FULL_OUTER_JOINS = 2; +} + +enum SqlSupportedGroupBy { + SQL_GROUP_BY_UNRELATED = 0; + SQL_GROUP_BY_BEYOND_SELECT = 1; +} + +enum SqlSupportedElementActions { + SQL_ELEMENT_IN_PROCEDURE_CALLS = 0; + SQL_ELEMENT_IN_INDEX_DEFINITIONS = 1; + SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS = 2; +} + +enum SqlSupportedPositionedCommands { + SQL_POSITIONED_DELETE = 0; + SQL_POSITIONED_UPDATE = 1; +} + +enum SqlSupportedSubqueries { + SQL_SUBQUERIES_IN_COMPARISONS = 0; + SQL_SUBQUERIES_IN_EXISTS = 1; + SQL_SUBQUERIES_IN_INS = 2; + SQL_SUBQUERIES_IN_QUANTIFIEDS = 3; +} + +enum SqlSupportedUnions { + SQL_UNION = 0; + SQL_UNION_ALL = 1; +} + +enum SqlTransactionIsolationLevel { + SQL_TRANSACTION_NONE = 0; + SQL_TRANSACTION_READ_UNCOMMITTED = 1; + SQL_TRANSACTION_READ_COMMITTED = 2; + SQL_TRANSACTION_REPEATABLE_READ = 3; + SQL_TRANSACTION_SERIALIZABLE = 4; +} + +enum SqlSupportedTransactions { + SQL_TRANSACTION_UNSPECIFIED = 0; + SQL_DATA_DEFINITION_TRANSACTIONS = 1; + SQL_DATA_MANIPULATION_TRANSACTIONS = 2; +} + +enum SqlSupportedResultSetType { + SQL_RESULT_SET_TYPE_UNSPECIFIED = 0; + SQL_RESULT_SET_TYPE_FORWARD_ONLY = 1; + SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE = 2; + SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE = 3; +} + +enum SqlSupportedResultSetConcurrency { + SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED = 0; + SQL_RESULT_SET_CONCURRENCY_READ_ONLY = 1; + SQL_RESULT_SET_CONCURRENCY_UPDATABLE = 2; +} + +enum SqlSupportsConvert { + SQL_CONVERT_BIGINT = 0; + SQL_CONVERT_BINARY = 1; + SQL_CONVERT_BIT = 2; + SQL_CONVERT_CHAR = 3; + SQL_CONVERT_DATE = 4; + SQL_CONVERT_DECIMAL = 5; + SQL_CONVERT_FLOAT = 6; + SQL_CONVERT_INTEGER = 7; + SQL_CONVERT_INTERVAL_DAY_TIME = 8; + SQL_CONVERT_INTERVAL_YEAR_MONTH = 9; + SQL_CONVERT_LONGVARBINARY = 10; + SQL_CONVERT_LONGVARCHAR = 11; + SQL_CONVERT_NUMERIC = 12; + SQL_CONVERT_REAL = 13; + SQL_CONVERT_SMALLINT = 14; + SQL_CONVERT_TIME = 15; + SQL_CONVERT_TIMESTAMP = 16; + SQL_CONVERT_TINYINT = 17; + SQL_CONVERT_VARBINARY = 18; + SQL_CONVERT_VARCHAR = 19; +} + +/* + * Represents a request to retrieve the list of catalogs on a Flight SQL enabled backend. + * The definition of a catalog depends on vendor/implementation. It is usually the database itself + * Used in the command member of FlightDescriptor for the following RPC calls: + * - GetSchema: return the Arrow schema of the query. + * - GetFlightInfo: execute the catalog metadata request. + * + * The returned Arrow schema will be: + * < + * catalog_name: utf8 not null + * > + * The returned data should be ordered by catalog_name. + */ +message CommandGetCatalogs { + option (experimental) = true; +} + +/* + * Represents a request to retrieve the list of database schemas on a Flight SQL enabled backend. + * The definition of a database schema depends on vendor/implementation. 
It is usually a collection of tables. + * Used in the command member of FlightDescriptor for the following RPC calls: + * - GetSchema: return the Arrow schema of the query. + * - GetFlightInfo: execute the catalog metadata request. + * + * The returned Arrow schema will be: + * < + * catalog_name: utf8, + * db_schema_name: utf8 not null + * > + * The returned data should be ordered by catalog_name, then db_schema_name. + */ +message CommandGetDbSchemas { + option (experimental) = true; + + /* + * Specifies the Catalog to search for the tables. + * An empty string retrieves those without a catalog. + * If omitted the catalog name should not be used to narrow the search. + */ + optional string catalog = 1; + + /* + * Specifies a filter pattern for schemas to search for. + * When no db_schema_filter_pattern is provided, the pattern will not be used to narrow the search. + * In the pattern string, two special characters can be used to denote matching rules: + * - "%" means to match any substring with 0 or more characters. + * - "_" means to match any one character. + */ + optional string db_schema_filter_pattern = 2; +} + +/* + * Represents a request to retrieve the list of tables, and optionally their schemas, on a Flight SQL enabled backend. + * Used in the command member of FlightDescriptor for the following RPC calls: + * - GetSchema: return the Arrow schema of the query. + * - GetFlightInfo: execute the catalog metadata request. + * + * The returned Arrow schema will be: + * < + * catalog_name: utf8, + * db_schema_name: utf8, + * table_name: utf8 not null, + * table_type: utf8 not null, + * [optional] table_schema: bytes not null (schema of the table as described in Schema.fbs::Schema, + * it is serialized as an IPC message.) + * > + * The returned data should be ordered by catalog_name, db_schema_name, table_name, then table_type, followed by table_schema if requested. + */ +message CommandGetTables { + option (experimental) = true; + + /* + * Specifies the Catalog to search for the tables. + * An empty string retrieves those without a catalog. + * If omitted the catalog name should not be used to narrow the search. + */ + optional string catalog = 1; + + /* + * Specifies a filter pattern for schemas to search for. + * When no db_schema_filter_pattern is provided, all schemas matching other filters are searched. + * In the pattern string, two special characters can be used to denote matching rules: + * - "%" means to match any substring with 0 or more characters. + * - "_" means to match any one character. + */ + optional string db_schema_filter_pattern = 2; + + /* + * Specifies a filter pattern for tables to search for. + * When no table_name_filter_pattern is provided, all tables matching other filters are searched. + * In the pattern string, two special characters can be used to denote matching rules: + * - "%" means to match any substring with 0 or more characters. + * - "_" means to match any one character. + */ + optional string table_name_filter_pattern = 3; + + /* + * Specifies a filter of table types which must match. + * The table types depend on vendor/implementation. It is usually used to separate tables from views or system tables. + * TABLE, VIEW, and SYSTEM TABLE are commonly supported. + */ + repeated string table_types = 4; + + // Specifies if the Arrow schema should be returned for found tables. + bool include_schema = 5; +} + +/* + * Represents a request to retrieve the list of table types on a Flight SQL enabled backend. 
+ * The table types depend on vendor/implementation. It is usually used to separate tables from views or system tables. + * TABLE, VIEW, and SYSTEM TABLE are commonly supported. + * Used in the command member of FlightDescriptor for the following RPC calls: + * - GetSchema: return the Arrow schema of the query. + * - GetFlightInfo: execute the catalog metadata request. + * + * The returned Arrow schema will be: + * < + * table_type: utf8 not null + * > + * The returned data should be ordered by table_type. + */ +message CommandGetTableTypes { + option (experimental) = true; +} + +/* + * Represents a request to retrieve the primary keys of a table on a Flight SQL enabled backend. + * Used in the command member of FlightDescriptor for the following RPC calls: + * - GetSchema: return the Arrow schema of the query. + * - GetFlightInfo: execute the catalog metadata request. + * + * The returned Arrow schema will be: + * < + * catalog_name: utf8, + * db_schema_name: utf8, + * table_name: utf8 not null, + * column_name: utf8 not null, + * key_name: utf8, + * key_sequence: int not null + * > + * The returned data should be ordered by catalog_name, db_schema_name, table_name, key_name, then key_sequence. + */ +message CommandGetPrimaryKeys { + option (experimental) = true; + + /* + * Specifies the catalog to search for the table. + * An empty string retrieves those without a catalog. + * If omitted the catalog name should not be used to narrow the search. + */ + optional string catalog = 1; + + /* + * Specifies the schema to search for the table. + * An empty string retrieves those without a schema. + * If omitted the schema name should not be used to narrow the search. + */ + optional string db_schema = 2; + + // Specifies the table to get the primary keys for. + string table = 3; +} + +enum UpdateDeleteRules { + CASCADE = 0; + RESTRICT = 1; + SET_NULL = 2; + NO_ACTION = 3; + SET_DEFAULT = 4; +} + +/* + * Represents a request to retrieve a description of the foreign key columns that reference the given table's + * primary key columns (the foreign keys exported by a table) of a table on a Flight SQL enabled backend. + * Used in the command member of FlightDescriptor for the following RPC calls: + * - GetSchema: return the Arrow schema of the query. + * - GetFlightInfo: execute the catalog metadata request. + * + * The returned Arrow schema will be: + * < + * pk_catalog_name: utf8, + * pk_db_schema_name: utf8, + * pk_table_name: utf8 not null, + * pk_column_name: utf8 not null, + * fk_catalog_name: utf8, + * fk_db_schema_name: utf8, + * fk_table_name: utf8 not null, + * fk_column_name: utf8 not null, + * key_sequence: int not null, + * fk_key_name: utf8, + * pk_key_name: utf8, + * update_rule: uint1 not null, + * delete_rule: uint1 not null + * > + * The returned data should be ordered by fk_catalog_name, fk_db_schema_name, fk_table_name, fk_key_name, then key_sequence. + * update_rule and delete_rule returns a byte that is equivalent to actions declared on UpdateDeleteRules enum. + */ +message CommandGetExportedKeys { + option (experimental) = true; + + /* + * Specifies the catalog to search for the foreign key table. + * An empty string retrieves those without a catalog. + * If omitted the catalog name should not be used to narrow the search. + */ + optional string catalog = 1; + + /* + * Specifies the schema to search for the foreign key table. + * An empty string retrieves those without a schema. + * If omitted the schema name should not be used to narrow the search. 
+ */ + optional string db_schema = 2; + + // Specifies the foreign key table to get the foreign keys for. + string table = 3; +} + +/* + * Represents a request to retrieve the foreign keys of a table on a Flight SQL enabled backend. + * Used in the command member of FlightDescriptor for the following RPC calls: + * - GetSchema: return the Arrow schema of the query. + * - GetFlightInfo: execute the catalog metadata request. + * + * The returned Arrow schema will be: + * < + * pk_catalog_name: utf8, + * pk_db_schema_name: utf8, + * pk_table_name: utf8 not null, + * pk_column_name: utf8 not null, + * fk_catalog_name: utf8, + * fk_db_schema_name: utf8, + * fk_table_name: utf8 not null, + * fk_column_name: utf8 not null, + * key_sequence: int not null, + * fk_key_name: utf8, + * pk_key_name: utf8, + * update_rule: uint1 not null, + * delete_rule: uint1 not null + * > + * The returned data should be ordered by pk_catalog_name, pk_db_schema_name, pk_table_name, pk_key_name, then key_sequence. + * update_rule and delete_rule returns a byte that is equivalent to actions: + * - 0 = CASCADE + * - 1 = RESTRICT + * - 2 = SET NULL + * - 3 = NO ACTION + * - 4 = SET DEFAULT + */ +message CommandGetImportedKeys { + option (experimental) = true; + + /* + * Specifies the catalog to search for the primary key table. + * An empty string retrieves those without a catalog. + * If omitted the catalog name should not be used to narrow the search. + */ + optional string catalog = 1; + + /* + * Specifies the schema to search for the primary key table. + * An empty string retrieves those without a schema. + * If omitted the schema name should not be used to narrow the search. + */ + optional string db_schema = 2; + + // Specifies the primary key table to get the foreign keys for. + string table = 3; +} + +/* + * Represents a request to retrieve a description of the foreign key columns in the given foreign key table that + * reference the primary key or the columns representing a unique constraint of the parent table (could be the same + * or a different table) on a Flight SQL enabled backend. + * Used in the command member of FlightDescriptor for the following RPC calls: + * - GetSchema: return the Arrow schema of the query. + * - GetFlightInfo: execute the catalog metadata request. + * + * The returned Arrow schema will be: + * < + * pk_catalog_name: utf8, + * pk_db_schema_name: utf8, + * pk_table_name: utf8 not null, + * pk_column_name: utf8 not null, + * fk_catalog_name: utf8, + * fk_db_schema_name: utf8, + * fk_table_name: utf8 not null, + * fk_column_name: utf8 not null, + * key_sequence: int not null, + * fk_key_name: utf8, + * pk_key_name: utf8, + * update_rule: uint1 not null, + * delete_rule: uint1 not null + * > + * The returned data should be ordered by pk_catalog_name, pk_db_schema_name, pk_table_name, pk_key_name, then key_sequence. + * update_rule and delete_rule returns a byte that is equivalent to actions: + * - 0 = CASCADE + * - 1 = RESTRICT + * - 2 = SET NULL + * - 3 = NO ACTION + * - 4 = SET DEFAULT + */ +message CommandGetCrossReference { + option (experimental) = true; + + /** + * The catalog name where the parent table is. + * An empty string retrieves those without a catalog. + * If omitted the catalog name should not be used to narrow the search. + */ + optional string pk_catalog = 1; + + /** + * The Schema name where the parent table is. + * An empty string retrieves those without a schema. + * If omitted the schema name should not be used to narrow the search. 
+ */ + optional string pk_db_schema = 2; + + /** + * The parent table name. It cannot be null. + */ + string pk_table = 3; + + /** + * The catalog name where the foreign table is. + * An empty string retrieves those without a catalog. + * If omitted the catalog name should not be used to narrow the search. + */ + optional string fk_catalog = 4; + + /** + * The schema name where the foreign table is. + * An empty string retrieves those without a schema. + * If omitted the schema name should not be used to narrow the search. + */ + optional string fk_db_schema = 5; + + /** + * The foreign table name. It cannot be null. + */ + string fk_table = 6; +} + +// SQL Execution Action Messages + +/* + * Request message for the "CreatePreparedStatement" action on a Flight SQL enabled backend. + */ +message ActionCreatePreparedStatementRequest { + option (experimental) = true; + + // The valid SQL string to create a prepared statement for. + string query = 1; +} + +/* + * Wrap the result of a "GetPreparedStatement" action. + * + * The resultant PreparedStatement can be closed either: + * - Manually, through the "ClosePreparedStatement" action; + * - Automatically, by a server timeout. + */ +message ActionCreatePreparedStatementResult { + option (experimental) = true; + + // Opaque handle for the prepared statement on the server. + bytes prepared_statement_handle = 1; + + // If a result set generating query was provided, dataset_schema contains the + // schema of the dataset as described in Schema.fbs::Schema, it is serialized as an IPC message. + bytes dataset_schema = 2; + + // If the query provided contained parameters, parameter_schema contains the + // schema of the expected parameters as described in Schema.fbs::Schema, it is serialized as an IPC message. + bytes parameter_schema = 3; +} + +/* + * Request message for the "ClosePreparedStatement" action on a Flight SQL enabled backend. + * Closes server resources associated with the prepared statement handle. + */ +message ActionClosePreparedStatementRequest { + option (experimental) = true; + + // Opaque handle for the prepared statement on the server. + bytes prepared_statement_handle = 1; +} + + +// SQL Execution Messages. + +/* + * Represents a SQL query. Used in the command member of FlightDescriptor + * for the following RPC calls: + * - GetSchema: return the Arrow schema of the query. + * - GetFlightInfo: execute the query. + */ +message CommandStatementQuery { + option (experimental) = true; + + // The SQL syntax. + string query = 1; +} + +/** + * Represents a ticket resulting from GetFlightInfo with a CommandStatementQuery. + * This should be used only once and treated as an opaque value, that is, clients should not attempt to parse this. + */ +message TicketStatementQuery { + option (experimental) = true; + + // Unique identifier for the instance of the statement to execute. + bytes statement_handle = 1; +} + +/* + * Represents an instance of executing a prepared statement. Used in the command member of FlightDescriptor for + * the following RPC calls: + * - DoPut: bind parameter values. All of the bound parameter sets will be executed as a single atomic execution. + * - GetFlightInfo: execute the prepared statement instance. + */ +message CommandPreparedStatementQuery { + option (experimental) = true; + + // Opaque handle for the prepared statement on the server. + bytes prepared_statement_handle = 1; +} + +/* + * Represents a SQL update query. 
Used in the command member of FlightDescriptor + * for the the RPC call DoPut to cause the server to execute the included SQL update. + */ +message CommandStatementUpdate { + option (experimental) = true; + + // The SQL syntax. + string query = 1; +} + +/* + * Represents a SQL update query. Used in the command member of FlightDescriptor + * for the the RPC call DoPut to cause the server to execute the included + * prepared statement handle as an update. + */ +message CommandPreparedStatementUpdate { + option (experimental) = true; + + // Opaque handle for the prepared statement on the server. + bytes prepared_statement_handle = 1; +} + +/* + * Returned from the RPC call DoPut when a CommandStatementUpdate + * CommandPreparedStatementUpdate was in the request, containing + * results from the update. + */ +message DoPutUpdateResult { + option (experimental) = true; + + // The number of records updated. A return value of -1 represents + // an unknown updated record count. + int64 record_count = 1; +} + +extend google.protobuf.MessageOptions { + bool experimental = 1000; +} diff --git a/format/Schema.fbs b/format/Schema.fbs index 3b00dd4780d6..9da095177c7d 100644 --- a/format/Schema.fbs +++ b/format/Schema.fbs @@ -246,15 +246,24 @@ table Timestamp { timezone: string; } -enum IntervalUnit: short { YEAR_MONTH, DAY_TIME} +enum IntervalUnit: short { YEAR_MONTH, DAY_TIME, MONTH_DAY_NANO} // A "calendar" interval which models types that don't necessarily // have a precise duration without the context of a base timestamp (e.g. // days can differ in length during day light savings time transitions). +// All integers in the types below are stored in the endianness indicated +// by the schema. // YEAR_MONTH - Indicates the number of elapsed whole months, stored as -// 4-byte integers. +// 4-byte signed integers. // DAY_TIME - Indicates the number of elapsed days and milliseconds, // stored as 2 contiguous 32-bit integers (8-bytes in total). Support // of this IntervalUnit is not required for full arrow compatibility. +// MONTH_DAY_NANO - A triple of the number of elapsed months, days, and nanoseconds. +// The values are stored contiguously in 16 byte blocks. Months and +// days are encoded as 32 bit integers and nanoseconds is encoded as a +// 64 bit integer. All integers are signed. Each field is independent +// (e.g. there is no constraint that nanoseconds have the same sign +// as days or that the quantity of nanoseconds represents less +// than a day's worth of time). 
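The 16-byte MONTH_DAY_NANO layout described above is the same layout the integration-testing changes later in this diff build as a single 128-bit value when parsing JSON test data. A minimal packing sketch (illustrative function name, not a library API):

```rust
// Pack a MONTH_DAY_NANO interval into its 128-bit representation:
// months in bits 0..32, days in bits 32..64, nanoseconds in bits 64..128.
fn pack_month_day_nano(months: i32, days: i32, nanoseconds: i64) -> i128 {
    ((nanoseconds as i128) & 0xFFFF_FFFF_FFFF_FFFF) << 64
        | ((days as i128) & 0xFFFF_FFFF) << 32
        | ((months as i128) & 0xFFFF_FFFF)
}

fn main() {
    // Each field is independent and signed, e.g. 1 month, -2 days, 3 nanoseconds.
    let packed = pack_month_day_nano(1, -2, 3);
    assert_eq!((packed & 0xFFFF_FFFF) as u32 as i32, 1); // months
    assert_eq!(((packed >> 32) & 0xFFFF_FFFF) as u32 as i32, -2); // days
    assert_eq!((packed >> 64) as i64, 3); // nanoseconds
}
```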
table Interval { unit: IntervalUnit; } diff --git a/integration-testing/Cargo.toml b/integration-testing/Cargo.toml index 7ee23a910f96..57b5211129ff 100644 --- a/integration-testing/Cargo.toml +++ b/integration-testing/Cargo.toml @@ -18,13 +18,14 @@ [package] name = "arrow-integration-testing" description = "Binaries used in the Arrow integration tests" -version = "7.0.0-SNAPSHOT" +version = "16.0.0" homepage = "https://github.com/apache/arrow-rs" repository = "https://github.com/apache/arrow-rs" authors = ["Apache Arrow "] license = "Apache-2.0" -edition = "2018" +edition = "2021" publish = false +rust-version = "1.57" [features] logging = ["tracing-subscriber"] @@ -33,13 +34,13 @@ logging = ["tracing-subscriber"] arrow = { path = "../arrow" } arrow-flight = { path = "../arrow-flight" } async-trait = "0.1.41" -clap = "2.33" +clap = { version = "~3.1", features = ["derive", "env"] } futures = "0.3" hex = "0.4" -prost = "0.8" +prost = "0.10" serde = { version = "1.0", features = ["rc"] } serde_derive = "1.0" serde_json = { version = "1.0", features = ["preserve_order"] } tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread"] } -tonic = "0.5" -tracing-subscriber = { version = "0.2.15", optional = true } +tonic = "0.7" +tracing-subscriber = { version = "0.3.1", optional = true } diff --git a/integration-testing/README.md b/integration-testing/README.md index ff365426b520..e82591e6b139 100644 --- a/integration-testing/README.md +++ b/integration-testing/README.md @@ -19,12 +19,90 @@ # Apache Arrow Rust Integration Testing -See [Integration.rst](../../docs/source/format/Integration.rst) for an overview of integration testing. +See [Integration Testing](https://arrow.apache.org/docs/format/Integration.html) for an overview of integration testing. This crate contains the following binaries, which are invoked by Archery during integration testing with other Arrow implementations. -| Binary | Purpose | -| --------------------------- | ----------------------------------------- | -| arrow-file-to-stream | Converts an Arrow file to an Arrow stream | -| arrow-stream-to-file | Converts an Arrow stream to an Arrow file | -| arrow-json-integration-test | Converts between Arrow and JSON formats | +| Binary | Purpose | +| ------------------------------ | ----------------------------------------- | +| arrow-file-to-stream | Converts an Arrow file to an Arrow stream | +| arrow-stream-to-file | Converts an Arrow stream to an Arrow file | +| arrow-json-integration-test | Converts between Arrow and JSON formats | +| flight-test-integration-server | Flight integration test: Server | +| flight-test-integration-client | Flight integration test: Client | + +# Notes on how to run Rust Integration Test against C/C++ + +The code for running the integration tests is in the [arrow](https://github.com/apache/arrow) repository + +### Check out code: + +```shell +# check out arrow +git clone git@github.com:apache/arrow.git +# link rust source code into arrow +ln -s arrow/rust +``` + +### Install the tools: + +```shell +cd arrow +pip install -e dev/archery[docker] +``` + +### Build the C++ binaries: + +Follow the [C++ Direction](https://github.com/apache/arrow/tree/master/docs/source/developers/cpp) and build the integration test binaries with a command like this: + +``` +# build cpp binaries +cd arrow/cpp +mkdir build +cd build +cmake -DARROW_BUILD_INTEGRATION=ON -DARROW_FLIGHT=ON --preset ninja-debug-minimal .. 
+ninja +``` + +### Build the Rust binaries + +Then + +``` +# build rust: +cd ../arrow-rs +cargo build --all +``` + +### Run archery + +You can run the Archery tool using a command such as the following: + +```shell +archery integration --with-cpp=true --with-rust=true +``` + +Above command will run producer and consumer for the matrix of enabled Arrow implementations (e.g. C++ and Rust), +and compare the results. + +In C++ Arrow repo's CI, for integration test, it also runs consumer test with golden files specified by `--gold-dirs`: + +```shell +archery integration --with-cpp=true --with-rust=true --gold-dirs=/path/to/arrow/testing/data/arrow-ipc-stream/integration/0.14.1 --gold-dirs=/path/to/arrow/testing/data/arrow-ipc-stream/integration/0.17.1 +``` + +Actually C++ Arrow repo's CI runs with more implementations other than just C++ and Rust. This is the command to reproduce test result of CI: + +```shell +archery integration --run-flight --with-cpp=1 --with-csharp=1 --with-java=1 --with-js=1 --with-go=1 --gold-dirs=/arrow/testing/data/arrow-ipc-stream/integration/0.14.1 --gold-dirs=/arrow/testing/data/arrow-ipc-stream/integration/0.17.1 --gold-dirs=/arrow/testing/data/arrow-ipc-stream/integration/1.0.0-bigendian --gold-dirs=/arrow/testing/data/arrow-ipc-stream/integration/1.0.0-littleendian --gold-dirs=/arrow/testing/data/arrow-ipc-stream/integration/2.0.0-compression --gold-dirs=/arrow/testing/data/arrow-ipc-stream/integration/4.0.0-shareddict +``` + +To debug an individual test scenario, it is also possible to run the binaries directly: + +```shell +# Run cpp server +$ arrow/cpp/build/debug/flight-test-integration-server -port 49153 + +# run rust client (you can see file names if you run archery --debug +$ arrow/rust/target/debug/flight-test-integration-client --host localhost --port=49153 --path /tmp/generated_dictionary_unsigned.json +``` diff --git a/integration-testing/src/bin/arrow-file-to-stream.rs b/integration-testing/src/bin/arrow-file-to-stream.rs index d6bb0428c0f6..e939fe4f0bf7 100644 --- a/integration-testing/src/bin/arrow-file-to-stream.rs +++ b/integration-testing/src/bin/arrow-file-to-stream.rs @@ -15,20 +15,24 @@ // specific language governing permissions and limitations // under the License. -use std::env; -use std::fs::File; -use std::io::{self, BufReader}; - use arrow::error::Result; use arrow::ipc::reader::FileReader; use arrow::ipc::writer::StreamWriter; +use clap::Parser; +use std::fs::File; +use std::io::{self, BufReader}; + +#[derive(Debug, Parser)] +#[clap(author, version, about("Read an arrow file and stream to stdout"), long_about = None)] +struct Args { + file_name: String, +} fn main() -> Result<()> { - let args: Vec = env::args().collect(); - let filename = &args[1]; - let f = File::open(filename)?; + let args = Args::parse(); + let f = File::open(&args.file_name)?; let reader = BufReader::new(f); - let mut reader = FileReader::try_new(reader)?; + let mut reader = FileReader::try_new(reader, None)?; let schema = reader.schema(); let mut writer = StreamWriter::try_new(io::stdout(), &schema)?; diff --git a/integration-testing/src/bin/arrow-json-integration-test.rs b/integration-testing/src/bin/arrow-json-integration-test.rs index 257802028b20..69b73b19f222 100644 --- a/integration-testing/src/bin/arrow-json-integration-test.rs +++ b/integration-testing/src/bin/arrow-json-integration-test.rs @@ -15,52 +15,48 @@ // specific language governing permissions and limitations // under the License. 
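For reference, the file-to-stream binary shown above boils down to the following loop; this is a simplified sketch with a hypothetical input path, using the updated `FileReader::try_new(reader, None)` signature introduced by this change:

```rust
use arrow::error::Result;
use arrow::ipc::reader::FileReader;
use arrow::ipc::writer::StreamWriter;
use std::fs::File;
use std::io::{self, BufReader};

fn main() -> Result<()> {
    // Hypothetical input path; the real binary takes it as a CLI argument.
    let f = File::open("data.arrow")?;
    // The second argument is the optional projection added by this change.
    let reader = FileReader::try_new(BufReader::new(f), None)?;
    let schema = reader.schema();

    // Re-emit every record batch from the file as an IPC stream on stdout.
    let mut writer = StreamWriter::try_new(io::stdout(), &schema)?;
    for batch in reader {
        writer.write(&batch?)?;
    }
    writer.finish()?;
    Ok(())
}
```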
-use std::fs::File; - -use clap::{App, Arg}; - +use arrow::datatypes::Schema; +use arrow::datatypes::{DataType, Field}; use arrow::error::{ArrowError, Result}; use arrow::ipc::reader::FileReader; use arrow::ipc::writer::FileWriter; use arrow::util::integration_util::*; use arrow_integration_testing::read_json_file; +use clap::Parser; +use std::fs::File; + +#[derive(clap::ArgEnum, Debug, Clone)] +#[clap(rename_all = "SCREAMING_SNAKE_CASE")] +enum Mode { + ArrowToJson, + JsonToArrow, + Validate, +} + +#[derive(Debug, Parser)] +#[clap(author, version, about("rust arrow-json-integration-test"), long_about = None)] +struct Args { + #[clap(short, long)] + integration: bool, + #[clap(short, long, help("Path to ARROW file"))] + arrow: String, + #[clap(short, long, help("Path to JSON file"))] + json: String, + #[clap(arg_enum, short, long, default_value_t = Mode::Validate, help="Mode of integration testing tool")] + mode: Mode, + #[clap(short, long)] + verbose: bool, +} fn main() -> Result<()> { - let matches = App::new("rust arrow-json-integration-test") - .arg(Arg::with_name("integration") - .long("integration")) - .arg(Arg::with_name("arrow") - .long("arrow") - .help("path to ARROW file") - .takes_value(true)) - .arg(Arg::with_name("json") - .long("json") - .help("path to JSON file") - .takes_value(true)) - .arg(Arg::with_name("mode") - .long("mode") - .help("mode of integration testing tool (ARROW_TO_JSON, JSON_TO_ARROW, VALIDATE)") - .takes_value(true) - .default_value("VALIDATE")) - .arg(Arg::with_name("verbose") - .long("verbose") - .help("enable/disable verbose mode")) - .get_matches(); - - let arrow_file = matches - .value_of("arrow") - .expect("must provide path to arrow file"); - let json_file = matches - .value_of("json") - .expect("must provide path to json file"); - let mode = matches.value_of("mode").unwrap(); - let verbose = true; //matches.value_of("verbose").is_some(); - - match mode { - "JSON_TO_ARROW" => json_to_arrow(json_file, arrow_file, verbose), - "ARROW_TO_JSON" => arrow_to_json(arrow_file, json_file, verbose), - "VALIDATE" => validate(arrow_file, json_file, verbose), - _ => panic!("mode {} not supported", mode), + let args = Args::parse(); + let arrow_file = args.arrow; + let json_file = args.json; + let verbose = args.verbose; + match args.mode { + Mode::JsonToArrow => json_to_arrow(&json_file, &arrow_file, verbose), + Mode::ArrowToJson => arrow_to_json(&arrow_file, &json_file, verbose), + Mode::Validate => validate(&arrow_file, &json_file, verbose), } } @@ -89,7 +85,7 @@ fn arrow_to_json(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> } let arrow_file = File::open(arrow_name)?; - let reader = FileReader::try_new(arrow_file)?; + let reader = FileReader::try_new(arrow_file, None)?; let mut fields: Vec = vec![]; for f in reader.schema().fields() { @@ -113,6 +109,47 @@ fn arrow_to_json(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> Ok(()) } +fn canonicalize_schema(schema: &Schema) -> Schema { + let fields = schema + .fields() + .iter() + .map(|field| match field.data_type() { + DataType::Map(child_field, sorted) => match child_field.data_type() { + DataType::Struct(fields) if fields.len() == 2 => { + let first_field = fields.get(0).unwrap(); + let key_field = Field::new( + "key", + first_field.data_type().clone(), + first_field.is_nullable(), + ); + let second_field = fields.get(1).unwrap(); + let value_field = Field::new( + "value", + second_field.data_type().clone(), + second_field.is_nullable(), + ); + + let struct_type = 
DataType::Struct(vec![key_field, value_field]); + let child_field = + Field::new("entries", struct_type, child_field.is_nullable()); + + Field::new( + field.name().as_str(), + DataType::Map(Box::new(child_field), *sorted), + field.is_nullable(), + ) + } + _ => panic!( + "The child field of Map type should be Struct type with 2 fields." + ), + }, + _ => field.clone(), + }) + .collect::>(); + + Schema::new(fields).with_metadata(schema.metadata().clone()) +} + fn validate(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> { if verbose { eprintln!("Validating {} and {}", arrow_name, json_name); @@ -123,11 +160,11 @@ fn validate(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> { // open Arrow file let arrow_file = File::open(arrow_name)?; - let mut arrow_reader = FileReader::try_new(arrow_file)?; + let mut arrow_reader = FileReader::try_new(arrow_file, None)?; let arrow_schema = arrow_reader.schema().as_ref().to_owned(); // compare schemas - if json_file.schema != arrow_schema { + if canonicalize_schema(&json_file.schema) != canonicalize_schema(&arrow_schema) { return Err(ArrowError::ComputeError(format!( "Schemas do not match. JSON: {:?}. Arrow: {:?}", json_file.schema, arrow_schema diff --git a/integration-testing/src/bin/arrow-stream-to-file.rs b/integration-testing/src/bin/arrow-stream-to-file.rs index f81d42e6eda2..07ac5c7ddd42 100644 --- a/integration-testing/src/bin/arrow-stream-to-file.rs +++ b/integration-testing/src/bin/arrow-stream-to-file.rs @@ -22,7 +22,7 @@ use arrow::ipc::reader::StreamReader; use arrow::ipc::writer::FileWriter; fn main() -> Result<()> { - let mut arrow_stream_reader = StreamReader::try_new(io::stdin())?; + let mut arrow_stream_reader = StreamReader::try_new(io::stdin(), None)?; let schema = arrow_stream_reader.schema(); let mut writer = FileWriter::try_new(io::stdout(), &schema)?; diff --git a/integration-testing/src/bin/flight-test-integration-client.rs b/integration-testing/src/bin/flight-test-integration-client.rs index 1901553109f9..fa99b424e378 100644 --- a/integration-testing/src/bin/flight-test-integration-client.rs +++ b/integration-testing/src/bin/flight-test-integration-client.rs @@ -16,44 +16,53 @@ // under the License. use arrow_integration_testing::flight_client_scenarios; - -use clap::{App, Arg}; - +use clap::Parser; type Error = Box; type Result = std::result::Result; +#[derive(clap::ArgEnum, Debug, Clone)] +enum Scenario { + Middleware, + #[clap(name = "auth:basic_proto")] + AuthBasicProto, +} + +#[derive(Debug, Parser)] +#[clap(author, version, about("rust flight-test-integration-client"), long_about = None)] +struct Args { + #[clap(long, help = "host of flight server")] + host: String, + #[clap(long, help = "port of flight server")] + port: u16, + #[clap( + short, + long, + help = "path to the descriptor file, only used when scenario is not provided. 
See https://arrow.apache.org/docs/format/Integration.html#json-test-data-format" + )] + path: Option, + #[clap(long, arg_enum)] + scenario: Option, +} + #[tokio::main] async fn main() -> Result { #[cfg(feature = "logging")] tracing_subscriber::fmt::init(); - let matches = App::new("rust flight-test-integration-client") - .arg(Arg::with_name("host").long("host").takes_value(true)) - .arg(Arg::with_name("port").long("port").takes_value(true)) - .arg(Arg::with_name("path").long("path").takes_value(true)) - .arg( - Arg::with_name("scenario") - .long("scenario") - .takes_value(true), - ) - .get_matches(); - - let host = matches.value_of("host").expect("Host is required"); - let port = matches.value_of("port").expect("Port is required"); + let args = Args::parse(); + let host = args.host; + let port = args.port; - match matches.value_of("scenario") { - Some("middleware") => { - flight_client_scenarios::middleware::run_scenario(host, port).await? + match args.scenario { + Some(Scenario::Middleware) => { + flight_client_scenarios::middleware::run_scenario(&host, port).await? } - Some("auth:basic_proto") => { - flight_client_scenarios::auth_basic_proto::run_scenario(host, port).await? + Some(Scenario::AuthBasicProto) => { + flight_client_scenarios::auth_basic_proto::run_scenario(&host, port).await? } - Some(scenario_name) => unimplemented!("Scenario not found: {}", scenario_name), None => { - let path = matches - .value_of("path") - .expect("Path is required if scenario is not specified"); - flight_client_scenarios::integration_test::run_scenario(host, port, path) + let path = args.path.expect("No path is given"); + flight_client_scenarios::integration_test::run_scenario(&host, port, &path) .await?; } } diff --git a/integration-testing/src/bin/flight-test-integration-server.rs b/integration-testing/src/bin/flight-test-integration-server.rs index b1b280743c3c..6ed22ad81d90 100644 --- a/integration-testing/src/bin/flight-test-integration-server.rs +++ b/integration-testing/src/bin/flight-test-integration-server.rs @@ -15,38 +15,43 @@ // specific language governing permissions and limitations // under the License. -use clap::{App, Arg}; - use arrow_integration_testing::flight_server_scenarios; +use clap::Parser; type Error = Box; type Result = std::result::Result; +#[derive(clap::ArgEnum, Debug, Clone)] +enum Scenario { + Middleware, + #[clap(name = "auth:basic_proto")] + AuthBasicProto, +} + +#[derive(Debug, Parser)] +#[clap(author, version, about("rust flight-test-integration-server"), long_about = None)] +struct Args { + #[clap(long)] + port: u16, + #[clap(long, arg_enum)] + scenario: Option, +} + #[tokio::main] async fn main() -> Result { #[cfg(feature = "logging")] tracing_subscriber::fmt::init(); - let matches = App::new("rust flight-test-integration-server") - .about("Integration testing server for Flight.") - .arg(Arg::with_name("port").long("port").takes_value(true)) - .arg( - Arg::with_name("scenario") - .long("scenario") - .takes_value(true), - ) - .get_matches(); - - let port = matches.value_of("port").unwrap_or("0"); + let args = Args::parse(); + let port = args.port; - match matches.value_of("scenario") { - Some("middleware") => { + match args.scenario { + Some(Scenario::Middleware) => { flight_server_scenarios::middleware::scenario_setup(port).await? } - Some("auth:basic_proto") => { + Some(Scenario::AuthBasicProto) => { flight_server_scenarios::auth_basic_proto::scenario_setup(port).await? 
} - Some(scenario_name) => unimplemented!("Scenario not found: {}", scenario_name), None => { flight_server_scenarios::integration_test::scenario_setup(port).await?; } diff --git a/integration-testing/src/flight_client_scenarios/auth_basic_proto.rs b/integration-testing/src/flight_client_scenarios/auth_basic_proto.rs index 5e8cd4671988..ab398d3d2e7b 100644 --- a/integration-testing/src/flight_client_scenarios/auth_basic_proto.rs +++ b/integration-testing/src/flight_client_scenarios/auth_basic_proto.rs @@ -29,7 +29,7 @@ type Result = std::result::Result; type Client = FlightServiceClient; -pub async fn run_scenario(host: &str, port: &str) -> Result { +pub async fn run_scenario(host: &str, port: u16) -> Result { let url = format!("http://{}:{}", host, port); let mut client = FlightServiceClient::connect(url).await?; diff --git a/integration-testing/src/flight_client_scenarios/integration_test.rs b/integration-testing/src/flight_client_scenarios/integration_test.rs index bf64451e9f3d..62fe2b85d262 100644 --- a/integration-testing/src/flight_client_scenarios/integration_test.rs +++ b/integration-testing/src/flight_client_scenarios/integration_test.rs @@ -16,6 +16,7 @@ // under the License. use crate::{read_json_file, ArrowFile}; +use std::collections::HashMap; use arrow::{ array::ArrayRef, @@ -31,6 +32,7 @@ use arrow_flight::{ use futures::{channel::mpsc, sink::SinkExt, stream, StreamExt}; use tonic::{Request, Streaming}; +use arrow::datatypes::Schema; use std::sync::Arc; type Error = Box; @@ -38,7 +40,7 @@ type Result = std::result::Result; type Client = FlightServiceClient; -pub async fn run_scenario(host: &str, port: &str, path: &str) -> Result { +pub async fn run_scenario(host: &str, port: u16, path: &str) -> Result { let url = format!("http://{}:{}", host, port); let client = FlightServiceClient::connect(url).await?; @@ -60,7 +62,7 @@ pub async fn run_scenario(host: &str, port: &str, path: &str) -> Result { batches.clone(), ) .await?; - verify_data(client, descriptor, schema, &batches).await?; + verify_data(client, descriptor, &batches).await?; Ok(()) } @@ -143,7 +145,6 @@ async fn send_batch( async fn verify_data( mut client: Client, descriptor: FlightDescriptor, - expected_schema: SchemaRef, expected_data: &[RecordBatch], ) -> Result { let resp = client.get_flight_info(Request::new(descriptor)).await?; @@ -163,13 +164,7 @@ async fn verify_data( "No locations returned from Flight server", ); for location in endpoint.location { - consume_flight_location( - location, - ticket.clone(), - expected_data, - expected_schema.clone(), - ) - .await?; + consume_flight_location(location, ticket.clone(), expected_data).await?; } } @@ -180,28 +175,29 @@ async fn consume_flight_location( location: Location, ticket: Ticket, expected_data: &[RecordBatch], - schema: SchemaRef, ) -> Result { let mut location = location; // The other Flight implementations use the `grpc+tcp` scheme, but the Rust http libs // don't recognize this as valid. - location.uri = location.uri.replace("grpc+tcp://", "grpc://"); + // more details: https://github.com/apache/arrow-rs/issues/1398 + location.uri = location.uri.replace("grpc+tcp://", "http://"); let mut client = FlightServiceClient::connect(location.uri).await?; let resp = client.do_get(ticket).await?; let mut resp = resp.into_inner(); - // We already have the schema from the FlightInfo, but the server sends it again as the - // first FlightData. Ignore this one. 
- let _schema_again = resp.next().await.unwrap(); + let flight_schema = receive_schema_flight_data(&mut resp) + .await + .unwrap_or_else(|| panic!("Failed to receive flight schema")); + let actual_schema = Arc::new(flight_schema); - let mut dictionaries_by_field = vec![None; schema.fields().len()]; + let mut dictionaries_by_id = HashMap::new(); for (counter, expected_batch) in expected_data.iter().enumerate() { let data = receive_batch_flight_data( &mut resp, - schema.clone(), - &mut dictionaries_by_field, + actual_schema.clone(), + &mut dictionaries_by_id, ) .await .unwrap_or_else(|| { @@ -216,10 +212,10 @@ async fn consume_flight_location( assert_eq!(metadata, data.app_metadata); let actual_batch = - flight_data_to_arrow_batch(&data, schema.clone(), &dictionaries_by_field) + flight_data_to_arrow_batch(&data, actual_schema.clone(), &dictionaries_by_id) .expect("Unable to convert flight data to Arrow batch"); - assert_eq!(expected_batch.schema(), actual_batch.schema()); + assert_eq!(actual_schema, actual_batch.schema()); assert_eq!(expected_batch.num_columns(), actual_batch.num_columns()); assert_eq!(expected_batch.num_rows(), actual_batch.num_rows()); let schema = expected_batch.schema(); @@ -243,10 +239,24 @@ async fn consume_flight_location( Ok(()) } +async fn receive_schema_flight_data(resp: &mut Streaming) -> Option { + let data = resp.next().await?.ok()?; + let message = arrow::ipc::root_as_message(&data.data_header[..]) + .expect("Error parsing message"); + + // message header is a Schema, so read it + let ipc_schema: ipc::Schema = message + .header_as_schema() + .expect("Unable to read IPC message as schema"); + let schema = ipc::convert::fb_to_schema(ipc_schema); + + Some(schema) +} + async fn receive_batch_flight_data( resp: &mut Streaming, schema: SchemaRef, - dictionaries_by_field: &mut [Option], + dictionaries_by_id: &mut HashMap, ) -> Option { let mut data = resp.next().await?.ok()?; let mut message = arrow::ipc::root_as_message(&data.data_header[..]) @@ -259,7 +269,8 @@ async fn receive_batch_flight_data( .header_as_dictionary_batch() .expect("Error parsing dictionary"), &schema, - dictionaries_by_field, + dictionaries_by_id, + &message.version(), ) .expect("Error reading dictionary"); diff --git a/integration-testing/src/flight_client_scenarios/middleware.rs b/integration-testing/src/flight_client_scenarios/middleware.rs index cbca879dca51..db8c42cc081c 100644 --- a/integration-testing/src/flight_client_scenarios/middleware.rs +++ b/integration-testing/src/flight_client_scenarios/middleware.rs @@ -24,7 +24,7 @@ use tonic::{Request, Status}; type Error = Box; type Result = std::result::Result; -pub async fn run_scenario(host: &str, port: &str) -> Result { +pub async fn run_scenario(host: &str, port: u16) -> Result { let url = format!("http://{}:{}", host, port); let conn = tonic::transport::Endpoint::new(url)?.connect().await?; let mut client = FlightServiceClient::with_interceptor(conn, middleware_interceptor); diff --git a/integration-testing/src/flight_server_scenarios.rs b/integration-testing/src/flight_server_scenarios.rs index 9163b6920860..e56252f1dfbf 100644 --- a/integration-testing/src/flight_server_scenarios.rs +++ b/integration-testing/src/flight_server_scenarios.rs @@ -27,7 +27,7 @@ pub mod middleware; type Error = Box; type Result = std::result::Result; -pub async fn listen_on(port: &str) -> Result { +pub async fn listen_on(port: u16) -> Result { let addr: SocketAddr = format!("0.0.0.0:{}", port).parse()?; let listener = TcpListener::bind(addr).await?; diff 
--git a/integration-testing/src/flight_server_scenarios/auth_basic_proto.rs b/integration-testing/src/flight_server_scenarios/auth_basic_proto.rs index ea7ad3c3385c..68a4a0d3b4ad 100644 --- a/integration-testing/src/flight_server_scenarios/auth_basic_proto.rs +++ b/integration-testing/src/flight_server_scenarios/auth_basic_proto.rs @@ -37,7 +37,7 @@ use prost::Message; use crate::{AUTH_PASSWORD, AUTH_USERNAME}; -pub async fn scenario_setup(port: &str) -> Result { +pub async fn scenario_setup(port: u16) -> Result { let service = AuthBasicProtoScenarioImpl { username: AUTH_USERNAME.into(), password: AUTH_PASSWORD.into(), @@ -58,6 +58,7 @@ pub async fn scenario_setup(port: &str) -> Result { pub struct AuthBasicProtoScenarioImpl { username: Arc, password: Arc, + #[allow(dead_code)] peer_identity: Arc>>, } @@ -191,7 +192,8 @@ impl FlightService for AuthBasicProtoScenarioImpl { &self, request: Request>, ) -> Result, Status> { - self.check_auth(request.metadata()).await?; + let metadata = request.metadata(); + self.check_auth(metadata).await?; Err(Status::unimplemented("Not yet implemented")) } @@ -219,7 +221,8 @@ impl FlightService for AuthBasicProtoScenarioImpl { &self, request: Request>, ) -> Result, Status> { - self.check_auth(request.metadata()).await?; + let metadata = request.metadata(); + self.check_auth(metadata).await?; Err(Status::unimplemented("Not yet implemented")) } } diff --git a/integration-testing/src/flight_server_scenarios/integration_test.rs b/integration-testing/src/flight_server_scenarios/integration_test.rs index ae370b6f8366..7ad3d18eb5ba 100644 --- a/integration-testing/src/flight_server_scenarios/integration_test.rs +++ b/integration-testing/src/flight_server_scenarios/integration_test.rs @@ -43,7 +43,7 @@ type TonicStream = Pin + Send + Sync + 'static>>; type Error = Box; type Result = std::result::Result; -pub async fn scenario_setup(port: &str) -> Result { +pub async fn scenario_setup(port: u16) -> Result { let addr = super::listen_on(port).await?; let service = FlightServiceImpl { @@ -284,7 +284,7 @@ async fn record_batch_from_message( message: ipc::Message<'_>, data_body: &[u8], schema_ref: SchemaRef, - dictionaries_by_field: &[Option], + dictionaries_by_id: &HashMap, ) -> Result { let ipc_batch = message.header_as_record_batch().ok_or_else(|| { Status::internal("Could not parse message header as record batch") @@ -294,7 +294,9 @@ async fn record_batch_from_message( data_body, ipc_batch, schema_ref, - dictionaries_by_field, + dictionaries_by_id, + None, + &message.version(), ); arrow_batch_result.map_err(|e| { @@ -306,14 +308,19 @@ async fn dictionary_from_message( message: ipc::Message<'_>, data_body: &[u8], schema_ref: SchemaRef, - dictionaries_by_field: &mut [Option], + dictionaries_by_id: &mut HashMap, ) -> Result<(), Status> { let ipc_batch = message.header_as_dictionary_batch().ok_or_else(|| { Status::internal("Could not parse message header as dictionary batch") })?; - let dictionary_batch_result = - reader::read_dictionary(data_body, ipc_batch, &schema_ref, dictionaries_by_field); + let dictionary_batch_result = reader::read_dictionary( + data_body, + ipc_batch, + &schema_ref, + dictionaries_by_id, + &message.version(), + ); dictionary_batch_result.map_err(|e| { Status::internal(format!("Could not convert to Dictionary: {:?}", e)) }) @@ -330,7 +337,7 @@ async fn save_uploaded_chunks( let mut chunks = vec![]; let mut uploaded_chunks = uploaded_chunks.lock().await; - let mut dictionaries_by_field = vec![None; schema_ref.fields().len()]; + let mut 
dictionaries_by_id = HashMap::new(); while let Some(Ok(data)) = input_stream.next().await { let message = arrow::ipc::root_as_message(&data.data_header[..]) @@ -349,7 +356,7 @@ async fn save_uploaded_chunks( message, &data.data_body, schema_ref.clone(), - &dictionaries_by_field, + &dictionaries_by_id, ) .await?; @@ -360,7 +367,7 @@ async fn save_uploaded_chunks( message, &data.data_body, schema_ref.clone(), - &mut dictionaries_by_field, + &mut dictionaries_by_id, ) .await?; } diff --git a/integration-testing/src/flight_server_scenarios/middleware.rs b/integration-testing/src/flight_server_scenarios/middleware.rs index 1416acc4088c..5876ac9bfe6d 100644 --- a/integration-testing/src/flight_server_scenarios/middleware.rs +++ b/integration-testing/src/flight_server_scenarios/middleware.rs @@ -31,7 +31,7 @@ type TonicStream = Pin + Send + Sync + 'static>>; type Error = Box; type Result = std::result::Result; -pub async fn scenario_setup(port: &str) -> Result { +pub async fn scenario_setup(port: u16) -> Result { let service = MiddlewareScenarioImpl {}; let svc = FlightServiceServer::new(service); let addr = super::listen_on(port).await?; diff --git a/integration-testing/src/lib.rs b/integration-testing/src/lib.rs index f25157f635bc..c7796ece4c73 100644 --- a/integration-testing/src/lib.rs +++ b/integration-testing/src/lib.rs @@ -202,6 +202,39 @@ fn array_from_json( Value::String(s) => { s.parse().expect("Unable to parse string as i64") } + Value::Object(ref map) + if map.contains_key("days") + && map.contains_key("milliseconds") => + { + match field.data_type() { + DataType::Interval(IntervalUnit::DayTime) => { + let days = map.get("days").unwrap(); + let milliseconds = map.get("milliseconds").unwrap(); + + match (days, milliseconds) { + (Value::Number(d), Value::Number(m)) => { + let mut bytes = [0_u8; 8]; + let m = (m.as_i64().unwrap() as i32) + .to_le_bytes(); + let d = (d.as_i64().unwrap() as i32) + .to_le_bytes(); + + let c = [d, m].concat(); + bytes.copy_from_slice(c.as_slice()); + i64::from_le_bytes(bytes) + } + _ => panic!( + "Unable to parse {:?} as interval daytime", + value + ), + } + } + _ => panic!( + "Unable to parse {:?} as interval daytime", + value + ), + } + } _ => panic!("Unable to parse {:?} as number", value), }), _ => b.append_null(), @@ -280,6 +313,49 @@ fn array_from_json( } Ok(Arc::new(b.finish())) } + DataType::Interval(IntervalUnit::MonthDayNano) => { + let mut b = IntervalMonthDayNanoBuilder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(match value { + Value::Object(v) => { + let months = v.get("months").unwrap(); + let days = v.get("days").unwrap(); + let nanoseconds = v.get("nanoseconds").unwrap(); + match (months, days, nanoseconds) { + ( + Value::Number(months), + Value::Number(days), + Value::Number(nanoseconds), + ) => { + let months = months.as_i64().unwrap() as i32; + let days = days.as_i64().unwrap() as i32; + let nanoseconds = nanoseconds.as_i64().unwrap(); + let months_days_ns: i128 = ((nanoseconds as i128) + & 0xFFFFFFFFFFFFFFFF) + << 64 + | ((days as i128) & 0xFFFFFFFF) << 32 + | ((months as i128) & 0xFFFFFFFF); + months_days_ns + } + (_, _, _) => { + panic!("Unable to parse {:?} as MonthDayNano", v) + } + } + } + _ => panic!("Unable to parse {:?} as MonthDayNano", value), + }), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } DataType::Float32 => { let mut b = Float32Builder::new(json_col.count); for 
(is_valid, value) in json_col @@ -420,7 +496,7 @@ fn array_from_json( .offset(0) .add_buffer(Buffer::from(&offsets.to_byte_slice())) .add_child_data(child_array.data().clone()) - .null_bit_buffer(null_buf) + .null_bit_buffer(Some(null_buf)) .build() .unwrap(); Ok(Arc::new(ListArray::from(list_data))) @@ -448,7 +524,7 @@ fn array_from_json( .offset(0) .add_buffer(Buffer::from(&offsets.to_byte_slice())) .add_child_data(child_array.data().clone()) - .null_bit_buffer(null_buf) + .null_bit_buffer(Some(null_buf)) .build() .unwrap(); Ok(Arc::new(LargeListArray::from(list_data))) @@ -464,7 +540,7 @@ fn array_from_json( let list_data = ArrayData::builder(field.data_type().clone()) .len(json_col.count) .add_child_data(child_array.data().clone()) - .null_bit_buffer(null_buf) + .null_bit_buffer(Some(null_buf)) .build() .unwrap(); Ok(Arc::new(FixedSizeListArray::from(list_data))) @@ -474,7 +550,7 @@ fn array_from_json( let null_buf = create_null_buf(&json_col); let mut array_data = ArrayData::builder(field.data_type().clone()) .len(json_col.count) - .null_bit_buffer(null_buf); + .null_bit_buffer(Some(null_buf)); for (field, col) in fields.iter().zip(json_col.children.unwrap()) { let array = array_from_json(field, col, dictionaries)?; @@ -502,7 +578,12 @@ fn array_from_json( .get(&dict_id); match dictionary { Some(dictionary) => dictionary_array_from_json( - field, json_col, key_type, value_type, dictionary, + field, + json_col, + key_type, + value_type, + dictionary, + dictionaries, ), None => Err(ArrowError::JsonError(format!( "Unable to find dictionary for field {:?}", @@ -510,6 +591,81 @@ fn array_from_json( ))), } } + DataType::Decimal(precision, scale) => { + let mut b = DecimalBuilder::new(json_col.count, *precision, *scale); + // C++ interop tests involve incompatible decimal values + unsafe { + b.disable_value_validation(); + } + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_str().unwrap().parse::().unwrap()), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::Map(child_field, _) => { + let null_buf = create_null_buf(&json_col); + let children = json_col.children.clone().unwrap(); + let child_array = array_from_json( + child_field, + children.get(0).unwrap().clone(), + dictionaries, + )?; + let offsets: Vec = json_col + .offset + .unwrap() + .iter() + .map(|v| v.as_i64().unwrap() as i32) + .collect(); + let array_data = ArrayData::builder(field.data_type().clone()) + .len(json_col.count) + .add_buffer(Buffer::from(&offsets.to_byte_slice())) + .add_child_data(child_array.data().clone()) + .null_bit_buffer(Some(null_buf)) + .build() + .unwrap(); + + let array = MapArray::from(array_data); + Ok(Arc::new(array)) + } + DataType::Union(fields, field_type_ids, _) => { + let type_ids = if let Some(type_id) = json_col.type_id { + type_id + } else { + return Err(ArrowError::JsonError( + "Cannot find expected type_id in json column".to_string(), + )); + }; + + let offset: Option = json_col.offset.map(|offsets| { + let offsets: Vec = + offsets.iter().map(|v| v.as_i64().unwrap() as i32).collect(); + Buffer::from(&offsets.to_byte_slice()) + }); + + let mut children: Vec<(Field, Arc)> = vec![]; + for (field, col) in fields.iter().zip(json_col.children.unwrap()) { + let array = array_from_json(field, col, dictionaries)?; + children.push((field.clone(), array)); + } + + let array = UnionArray::try_new( + field_type_ids, + Buffer::from(&type_ids.to_byte_slice()), + 
offset, + children, + ) + .unwrap(); + Ok(Arc::new(array)) + } t => Err(ArrowError::JsonError(format!( "data type {:?} not supported", t @@ -523,6 +679,7 @@ fn dictionary_array_from_json( dict_key: &DataType, dict_value: &DataType, dictionary: &ArrowJsonDictionaryBatch, + dictionaries: Option<&HashMap>, ) -> Result { match dict_key { DataType::Int8 @@ -550,15 +707,17 @@ fn dictionary_array_from_json( let keys = array_from_json(&key_field, json_col, None)?; // note: not enough info on nullability of dictionary let value_field = Field::new("value", dict_value.clone(), true); - println!("dictionary value type: {:?}", dict_value); - let values = - array_from_json(&value_field, dictionary.data.columns[0].clone(), None)?; + let values = array_from_json( + &value_field, + dictionary.data.columns[0].clone(), + dictionaries, + )?; // convert key and value to dictionary data let dict_data = ArrayData::builder(field.data_type().clone()) .len(keys.len()) .add_buffer(keys.data().buffers()[0].clone()) - .null_bit_buffer(null_buf) + .null_bit_buffer(Some(null_buf)) .add_child_data(values.data().clone()) .build() .unwrap(); diff --git a/parquet-testing b/parquet-testing index 8e7badc6a381..7175a4713397 160000 --- a/parquet-testing +++ b/parquet-testing @@ -1 +1 @@ -Subproject commit 8e7badc6a3817a02e06d17b5d8ab6b6dc356e890 +Subproject commit 7175a471339704c7645af0fe66c68305e2e6759c diff --git a/parquet/CONTRIBUTING.md b/parquet/CONTRIBUTING.md index 834b6af9d4ef..77e9f417e49a 100644 --- a/parquet/CONTRIBUTING.md +++ b/parquet/CONTRIBUTING.md @@ -59,3 +59,8 @@ Run `cargo bench` for benchmarks. To build documentation, run `cargo doc --no-deps`. To compile and view in the browser, run `cargo doc --no-deps --open`. + +## Update Supported Parquet Version + +To update Parquet format to a newer version, check if [parquet-format](https://github.com/sunchao/parquet-format-rs) +version is available. Then simply update version of `parquet-format` crate in Cargo.toml. diff --git a/parquet/Cargo.toml b/parquet/Cargo.toml index 289513fbed3c..28347bcb7dda 100644 --- a/parquet/Cargo.toml +++ b/parquet/Cargo.toml @@ -17,67 +17,92 @@ [package] name = "parquet" -version = "7.0.0-SNAPSHOT" +version = "16.0.0" license = "Apache-2.0" description = "Apache Parquet implementation in Rust" homepage = "https://github.com/apache/arrow-rs" repository = "https://github.com/apache/arrow-rs" authors = ["Apache Arrow "] -keywords = [ "arrow", "parquet", "hadoop" ] +keywords = ["arrow", "parquet", "hadoop"] readme = "README.md" build = "build.rs" -edition = "2018" +edition = "2021" +rust-version = "1.57" [dependencies] -# update note: pin `parquet-format` to specific version until it does not break at minor -# version, see ARROW-11187. 
-parquet-format = "~2.6.1" -byteorder = "1" -thrift = "0.13" -snap = { version = "1.0", optional = true } -brotli = { version = "3.3", optional = true } -flate2 = { version = "1.0", optional = true } -lz4 = { version = "1.23", optional = true } -zstd = { version = "0.9", optional = true } -chrono = "0.4" -num-bigint = "0.4" -arrow = { path = "../arrow", version = "7.0.0-SNAPSHOT", optional = true, default-features = false, features = ["ipc"] } -base64 = { version = "0.13", optional = true } -clap = { version = "2.33.3", optional = true } -serde_json = { version = "1.0", features = ["preserve_order"], optional = true } -rand = "0.8" +parquet-format = { version = "4.0.0", default-features = false } +bytes = { version = "1.1", default-features = false, features = ["std"] } +byteorder = { version = "1", default-features = false } +thrift = { version = "0.13", default-features = false } +snap = { version = "1.0", default-features = false, optional = true } +brotli = { version = "3.3", default-features = false, features = ["std"], optional = true } +flate2 = { version = "1.0", default-features = false, features = ["rust_backend"], optional = true } +lz4 = { version = "1.23", default-features = false, optional = true } +zstd = { version = "0.11.1", optional = true, default-features = false } +chrono = { version = "0.4", default-features = false, features = ["alloc"] } +num = { version = "0.4", default-features = false } +num-bigint = { version = "0.4", default-features = false } +arrow = { path = "../arrow", version = "16.0.0", optional = true, default-features = false, features = ["ipc"] } +base64 = { version = "0.13", default-features = false, features = ["std"], optional = true } +clap = { version = "~3.1", default-features = false, features = ["std", "derive", "env"], optional = true } +serde_json = { version = "1.0", default-features = false, optional = true } +rand = { version = "0.8", default-features = false } +futures = { version = "0.3", default-features = false, features = ["std" ], optional = true } +tokio = { version = "1.0", optional = true, default-features = false, features = ["macros", "fs", "rt", "io-util"] } [dev-dependencies] -criterion = "0.3" -rand = "0.8" -snap = "1.0" -brotli = "3.3" -flate2 = "1.0" -lz4 = "1.23" -serde_json = { version = "1.0", features = ["preserve_order"] } -arrow = { path = "../arrow", version = "7.0.0-SNAPSHOT", default-features = false, features = ["test_utils"] } +base64 = { version = "0.13", default-features = false, features = ["std"] } +criterion = { version = "0.3", default-features = false } +snap = { version = "1.0", default-features = false } +tempfile = { version = "3.0", default-features = false } +brotli = { version = "3.3", default-features = false, features = [ "std" ] } +flate2 = { version = "1.0", default-features = false, features = [ "rust_backend" ] } +lz4 = { version = "1.23", default-features = false } +zstd = { version = "0.11", default-features = false } +serde_json = { version = "1.0", default-features = false, features = ["preserve_order"] } +arrow = { path = "../arrow", version = "16.0.0", default-features = false, features = ["ipc", "test_utils", "prettyprint"] } + +[package.metadata.docs.rs] +all-features = true [features] default = ["arrow", "snap", "brotli", "flate2", "lz4", "zstd", "base64"] -cli = ["serde_json", "base64", "clap"] +# Enable arrow reader/writer APIs +arrow = ["dep:arrow", "base64"] +# Enable CLI tools +cli = ["serde_json", "base64", "clap","arrow/csv"] +# Enable internal testing APIs test_common = [] +# 
Experimental, unstable functionality primarily used for testing +experimental = [] +# Enable async APIs +async = ["futures", "tokio"] -[[ bin ]] +[[bin]] name = "parquet-read" required-features = ["cli"] -[[ bin ]] +[[bin]] name = "parquet-schema" required-features = ["cli"] -[[ bin ]] +[[bin]] name = "parquet-rowcount" required-features = ["cli"] +[[bin]] +name = "parquet-fromcsv" +required-features = ["cli"] + [[bench]] name = "arrow_writer" +required-features = ["arrow"] harness = false [[bench]] -name = "arrow_array_reader" +name = "arrow_reader" +required-features = ["arrow", "test_common", "experimental"] harness = false + +[lib] +bench = false diff --git a/parquet/README.md b/parquet/README.md index c071fb8d69d4..fbb6e3e1b5d5 100644 --- a/parquet/README.md +++ b/parquet/README.md @@ -23,49 +23,11 @@ This crate contains the official Native Rust implementation of [Apache Parquet](https://parquet.apache.org/), which is part of the [Apache Arrow](https://arrow.apache.org/) project. -## Example - -Example usage of reading data: - -```rust -use std::fs::File; -use std::path::Path; -use parquet::file::reader::{FileReader, SerializedFileReader}; - -let file = File::open(&Path::new("/path/to/file")).unwrap(); -let reader = SerializedFileReader::new(file).unwrap(); -let mut iter = reader.get_row_iter(None).unwrap(); -while let Some(record) = iter.next() { - println!("{}", record); -} -``` - -See [crate documentation](https://docs.rs/crate/parquet) on available API. +See [crate documentation](https://docs.rs/parquet/latest/parquet/) for examples and the full API. ## Rust Version Compatbility -This crate is tested with the latest stable version of Rust. We do not currrently test against other, older versions of the Rust compiler. - -## Upgrading from versions prior to 4.0 - -If you are upgrading from version 3.0 or previous of this crate, you -likely need to change your code to use [`ConvertedType`] rather than -[`LogicalType`] to preserve existing behaviour in your code. - -Version 2.4.0 of the Parquet format introduced a `LogicalType` to replace the existing `ConvertedType`. -This crate used `parquet::basic::LogicalType` to map to the `ConvertedType`, but this has been renamed to `parquet::basic::ConvertedType` from version 4.0 of this crate. - -The `ConvertedType` is deprecated in the format, but is still written -to preserve backward compatibility. -It is preferred that `LogicalType` is used, as it supports nanosecond -precision timestamps without using the deprecated `Int96` Parquet type. - -## Supported Parquet Version - -- Parquet-format 2.6.0 - -To update Parquet format to a newer version, check if [parquet-format](https://github.com/sunchao/parquet-format-rs) -version is available. Then simply update version of `parquet-format` crate in Cargo.toml. +This crate is tested with the latest stable version of Rust. We do not currently test against other, older versions of the Rust compiler. ## Features @@ -75,13 +37,15 @@ version is available. 
Then simply update version of `parquet-format` crate in Ca - [x] Primitive column value readers - [x] Row record reader - [x] Arrow record reader + - [x] Async support (to Arrow) - [x] Statistics support - [x] Write support - [x] Primitive column value writers - [ ] Row record writer - [x] Arrow record writer + - [ ] Async support - [ ] Predicate pushdown -- [x] Parquet format 2.6.0 support +- [x] Parquet format 4.0.0 support ## License diff --git a/parquet/benches/arrow_array_reader.rs b/parquet/benches/arrow_array_reader.rs deleted file mode 100644 index acc5141bcbab..000000000000 --- a/parquet/benches/arrow_array_reader.rs +++ /dev/null @@ -1,762 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use criterion::{criterion_group, criterion_main, Criterion}; -use parquet::util::{DataPageBuilder, DataPageBuilderImpl, InMemoryPageIterator}; -use parquet::{ - arrow::array_reader::ArrayReader, - basic::Encoding, - column::page::PageIterator, - data_type::{ByteArrayType, Int32Type}, - schema::types::{ColumnDescPtr, SchemaDescPtr}, -}; -use std::{collections::VecDeque, sync::Arc}; - -fn build_test_schema() -> SchemaDescPtr { - use parquet::schema::{parser::parse_message_type, types::SchemaDescriptor}; - let message_type = " - message test_schema { - REQUIRED INT32 mandatory_int32_leaf; - OPTIONAL INT32 optional_int32_leaf; - REQUIRED BYTE_ARRAY mandatory_string_leaf (UTF8); - OPTIONAL BYTE_ARRAY optional_string_leaf (UTF8); - } - "; - parse_message_type(message_type) - .map(|t| Arc::new(SchemaDescriptor::new(Arc::new(t)))) - .unwrap() -} - -// test data params -const NUM_ROW_GROUPS: usize = 1; -const PAGES_PER_GROUP: usize = 2; -const VALUES_PER_PAGE: usize = 10_000; -const BATCH_SIZE: usize = 8192; - -use rand::{rngs::StdRng, Rng, SeedableRng}; - -pub fn seedable_rng() -> StdRng { - StdRng::seed_from_u64(42) -} - -fn build_plain_encoded_int32_page_iterator( - schema: SchemaDescPtr, - column_desc: ColumnDescPtr, - null_density: f32, -) -> impl PageIterator + Clone { - let max_def_level = column_desc.max_def_level(); - let max_rep_level = column_desc.max_rep_level(); - let rep_levels = vec![0; VALUES_PER_PAGE]; - let mut rng = seedable_rng(); - let mut pages: Vec> = Vec::new(); - let mut int32_value = 0; - for _i in 0..NUM_ROW_GROUPS { - let mut column_chunk_pages = Vec::new(); - for _j in 0..PAGES_PER_GROUP { - // generate page - let mut values = Vec::with_capacity(VALUES_PER_PAGE); - let mut def_levels = Vec::with_capacity(VALUES_PER_PAGE); - for _k in 0..VALUES_PER_PAGE { - let def_level = if rng.gen::() < null_density { - max_def_level - 1 - } else { - max_def_level - }; - if def_level == max_def_level { - int32_value += 1; - values.push(int32_value); - } - def_levels.push(def_level); - } - let mut page_builder = - 
DataPageBuilderImpl::new(column_desc.clone(), values.len() as u32, true); - page_builder.add_rep_levels(max_rep_level, &rep_levels); - page_builder.add_def_levels(max_def_level, &def_levels); - page_builder.add_values::(Encoding::PLAIN, &values); - column_chunk_pages.push(page_builder.consume()); - } - pages.push(column_chunk_pages); - } - - InMemoryPageIterator::new(schema, column_desc, pages) -} - -fn build_dictionary_encoded_int32_page_iterator( - schema: SchemaDescPtr, - column_desc: ColumnDescPtr, - null_density: f32, -) -> impl PageIterator + Clone { - use parquet::encoding::{DictEncoder, Encoder}; - let max_def_level = column_desc.max_def_level(); - let max_rep_level = column_desc.max_rep_level(); - let rep_levels = vec![0; VALUES_PER_PAGE]; - // generate 1% unique values - const NUM_UNIQUE_VALUES: usize = VALUES_PER_PAGE / 100; - let unique_values = (0..NUM_UNIQUE_VALUES) - .map(|x| (x + 1) as i32) - .collect::>(); - let mut rng = seedable_rng(); - let mut pages: Vec> = Vec::new(); - for _i in 0..NUM_ROW_GROUPS { - let mut column_chunk_pages = VecDeque::new(); - let mem_tracker = Arc::new(parquet::memory::MemTracker::new()); - let mut dict_encoder = - DictEncoder::::new(column_desc.clone(), mem_tracker); - // add data pages - for _j in 0..PAGES_PER_GROUP { - // generate page - let mut values = Vec::with_capacity(VALUES_PER_PAGE); - let mut def_levels = Vec::with_capacity(VALUES_PER_PAGE); - for _k in 0..VALUES_PER_PAGE { - let def_level = if rng.gen::() < null_density { - max_def_level - 1 - } else { - max_def_level - }; - if def_level == max_def_level { - // select random value from list of unique values - let int32_value = unique_values[rng.gen_range(0..NUM_UNIQUE_VALUES)]; - values.push(int32_value); - } - def_levels.push(def_level); - } - let mut page_builder = - DataPageBuilderImpl::new(column_desc.clone(), values.len() as u32, true); - page_builder.add_rep_levels(max_rep_level, &rep_levels); - page_builder.add_def_levels(max_def_level, &def_levels); - let _ = dict_encoder.put(&values); - let indices = dict_encoder - .write_indices() - .expect("write_indices() should be OK"); - page_builder.add_indices(indices); - column_chunk_pages.push_back(page_builder.consume()); - } - // add dictionary page - let dict = dict_encoder - .write_dict() - .expect("write_dict() should be OK"); - let dict_page = parquet::column::page::Page::DictionaryPage { - buf: dict, - num_values: dict_encoder.num_entries() as u32, - encoding: Encoding::RLE_DICTIONARY, - is_sorted: false, - }; - column_chunk_pages.push_front(dict_page); - pages.push(column_chunk_pages.into()); - } - - InMemoryPageIterator::new(schema, column_desc, pages) -} - -fn build_plain_encoded_string_page_iterator( - schema: SchemaDescPtr, - column_desc: ColumnDescPtr, - null_density: f32, -) -> impl PageIterator + Clone { - let max_def_level = column_desc.max_def_level(); - let max_rep_level = column_desc.max_rep_level(); - let rep_levels = vec![0; VALUES_PER_PAGE]; - let mut rng = seedable_rng(); - let mut pages: Vec> = Vec::new(); - for i in 0..NUM_ROW_GROUPS { - let mut column_chunk_pages = Vec::new(); - for j in 0..PAGES_PER_GROUP { - // generate page - let mut values = Vec::with_capacity(VALUES_PER_PAGE); - let mut def_levels = Vec::with_capacity(VALUES_PER_PAGE); - for k in 0..VALUES_PER_PAGE { - let def_level = if rng.gen::() < null_density { - max_def_level - 1 - } else { - max_def_level - }; - if def_level == max_def_level { - let string_value = - format!("Test value {}, row group: {}, page: {}", k, i, j); - values - 
.push(parquet::data_type::ByteArray::from(string_value.as_str())); - } - def_levels.push(def_level); - } - let mut page_builder = - DataPageBuilderImpl::new(column_desc.clone(), values.len() as u32, true); - page_builder.add_rep_levels(max_rep_level, &rep_levels); - page_builder.add_def_levels(max_def_level, &def_levels); - page_builder.add_values::(Encoding::PLAIN, &values); - column_chunk_pages.push(page_builder.consume()); - } - pages.push(column_chunk_pages); - } - - InMemoryPageIterator::new(schema, column_desc, pages) -} - -fn build_dictionary_encoded_string_page_iterator( - schema: SchemaDescPtr, - column_desc: ColumnDescPtr, - null_density: f32, -) -> impl PageIterator + Clone { - use parquet::encoding::{DictEncoder, Encoder}; - let max_def_level = column_desc.max_def_level(); - let max_rep_level = column_desc.max_rep_level(); - let rep_levels = vec![0; VALUES_PER_PAGE]; - // generate 1% unique values - const NUM_UNIQUE_VALUES: usize = VALUES_PER_PAGE / 100; - let unique_values = (0..NUM_UNIQUE_VALUES) - .map(|x| format!("Dictionary value {}", x)) - .collect::>(); - let mut rng = seedable_rng(); - let mut pages: Vec> = Vec::new(); - for _i in 0..NUM_ROW_GROUPS { - let mut column_chunk_pages = VecDeque::new(); - let mem_tracker = Arc::new(parquet::memory::MemTracker::new()); - let mut dict_encoder = - DictEncoder::::new(column_desc.clone(), mem_tracker); - // add data pages - for _j in 0..PAGES_PER_GROUP { - // generate page - let mut values = Vec::with_capacity(VALUES_PER_PAGE); - let mut def_levels = Vec::with_capacity(VALUES_PER_PAGE); - for _k in 0..VALUES_PER_PAGE { - let def_level = if rng.gen::() < null_density { - max_def_level - 1 - } else { - max_def_level - }; - if def_level == max_def_level { - // select random value from list of unique values - let string_value = - unique_values[rng.gen_range(0..NUM_UNIQUE_VALUES)].as_str(); - values.push(parquet::data_type::ByteArray::from(string_value)); - } - def_levels.push(def_level); - } - let mut page_builder = - DataPageBuilderImpl::new(column_desc.clone(), values.len() as u32, true); - page_builder.add_rep_levels(max_rep_level, &rep_levels); - page_builder.add_def_levels(max_def_level, &def_levels); - let _ = dict_encoder.put(&values); - let indices = dict_encoder - .write_indices() - .expect("write_indices() should be OK"); - page_builder.add_indices(indices); - column_chunk_pages.push_back(page_builder.consume()); - } - // add dictionary page - let dict = dict_encoder - .write_dict() - .expect("write_dict() should be OK"); - let dict_page = parquet::column::page::Page::DictionaryPage { - buf: dict, - num_values: dict_encoder.num_entries() as u32, - encoding: Encoding::RLE_DICTIONARY, - is_sorted: false, - }; - column_chunk_pages.push_front(dict_page); - pages.push(column_chunk_pages.into()); - } - - InMemoryPageIterator::new(schema, column_desc, pages) -} - -fn bench_array_reader(mut array_reader: impl ArrayReader) -> usize { - // test procedure: read data in batches of 8192 until no more data - let mut total_count = 0; - loop { - let array = array_reader.next_batch(BATCH_SIZE); - let array_len = array.unwrap().len(); - total_count += array_len; - if array_len < BATCH_SIZE { - break; - } - } - total_count -} - -fn create_int32_arrow_array_reader( - page_iterator: impl PageIterator + 'static, - column_desc: ColumnDescPtr, -) -> impl ArrayReader { - use parquet::arrow::arrow_array_reader::{ArrowArrayReader, PrimitiveArrayConverter}; - let converter = PrimitiveArrayConverter::::new(); - ArrowArrayReader::try_new(page_iterator, 
column_desc, converter, None).unwrap() -} - -fn create_int32_primitive_array_reader( - page_iterator: impl PageIterator + 'static, - column_desc: ColumnDescPtr, -) -> impl ArrayReader { - use parquet::arrow::array_reader::PrimitiveArrayReader; - PrimitiveArrayReader::::new(Box::new(page_iterator), column_desc, None) - .unwrap() -} - -fn create_string_arrow_array_reader( - page_iterator: impl PageIterator + 'static, - column_desc: ColumnDescPtr, -) -> impl ArrayReader { - use parquet::arrow::arrow_array_reader::{ArrowArrayReader, StringArrayConverter}; - let converter = StringArrayConverter::new(); - ArrowArrayReader::try_new(page_iterator, column_desc, converter, None).unwrap() -} - -fn create_string_complex_array_reader( - page_iterator: impl PageIterator + 'static, - column_desc: ColumnDescPtr, -) -> impl ArrayReader { - use parquet::arrow::array_reader::ComplexObjectArrayReader; - use parquet::arrow::converter::{Utf8ArrayConverter, Utf8Converter}; - let converter = Utf8Converter::new(Utf8ArrayConverter {}); - ComplexObjectArrayReader::::new( - Box::new(page_iterator), - column_desc, - converter, - None, - ) - .unwrap() -} - -fn add_benches(c: &mut Criterion) { - const EXPECTED_VALUE_COUNT: usize = - NUM_ROW_GROUPS * PAGES_PER_GROUP * VALUES_PER_PAGE; - let mut group = c.benchmark_group("arrow_array_reader"); - - let mut count: usize = 0; - - let schema = build_test_schema(); - let mandatory_int32_column_desc = schema.column(0); - let optional_int32_column_desc = schema.column(1); - let mandatory_string_column_desc = schema.column(2); - // println!("mandatory_string_column_desc: {:?}", mandatory_string_column_desc); - let optional_string_column_desc = schema.column(3); - // println!("optional_string_column_desc: {:?}", optional_string_column_desc); - - // primitive / int32 benchmarks - // ============================= - - // int32, plain encoded, no NULLs - let plain_int32_no_null_data = build_plain_encoded_int32_page_iterator( - schema.clone(), - mandatory_int32_column_desc.clone(), - 0.0, - ); - group.bench_function( - "read Int32Array, plain encoded, mandatory, no NULLs - old", - |b| { - b.iter(|| { - let array_reader = create_int32_primitive_array_reader( - plain_int32_no_null_data.clone(), - mandatory_int32_column_desc.clone(), - ); - count = bench_array_reader(array_reader); - }) - }, - ); - assert_eq!(count, EXPECTED_VALUE_COUNT); - - group.bench_function( - "read Int32Array, plain encoded, mandatory, no NULLs - new", - |b| { - b.iter(|| { - let array_reader = create_int32_arrow_array_reader( - plain_int32_no_null_data.clone(), - mandatory_int32_column_desc.clone(), - ); - count = bench_array_reader(array_reader); - }) - }, - ); - assert_eq!(count, EXPECTED_VALUE_COUNT); - - let plain_int32_no_null_data = build_plain_encoded_int32_page_iterator( - schema.clone(), - optional_int32_column_desc.clone(), - 0.0, - ); - group.bench_function( - "read Int32Array, plain encoded, optional, no NULLs - old", - |b| { - b.iter(|| { - let array_reader = create_int32_primitive_array_reader( - plain_int32_no_null_data.clone(), - optional_int32_column_desc.clone(), - ); - count = bench_array_reader(array_reader); - }) - }, - ); - assert_eq!(count, EXPECTED_VALUE_COUNT); - - group.bench_function( - "read Int32Array, plain encoded, optional, no NULLs - new", - |b| { - b.iter(|| { - let array_reader = create_int32_arrow_array_reader( - plain_int32_no_null_data.clone(), - optional_int32_column_desc.clone(), - ); - count = bench_array_reader(array_reader); - }) - }, - ); - assert_eq!(count, 
EXPECTED_VALUE_COUNT); - - // int32, plain encoded, half NULLs - let plain_int32_half_null_data = build_plain_encoded_int32_page_iterator( - schema.clone(), - optional_int32_column_desc.clone(), - 0.5, - ); - group.bench_function( - "read Int32Array, plain encoded, optional, half NULLs - old", - |b| { - b.iter(|| { - let array_reader = create_int32_primitive_array_reader( - plain_int32_half_null_data.clone(), - optional_int32_column_desc.clone(), - ); - count = bench_array_reader(array_reader); - }) - }, - ); - assert_eq!(count, EXPECTED_VALUE_COUNT); - - group.bench_function( - "read Int32Array, plain encoded, optional, half NULLs - new", - |b| { - b.iter(|| { - let array_reader = create_int32_arrow_array_reader( - plain_int32_half_null_data.clone(), - optional_int32_column_desc.clone(), - ); - count = bench_array_reader(array_reader); - }) - }, - ); - assert_eq!(count, EXPECTED_VALUE_COUNT); - - // int32, dictionary encoded, no NULLs - let dictionary_int32_no_null_data = build_dictionary_encoded_int32_page_iterator( - schema.clone(), - mandatory_int32_column_desc.clone(), - 0.0, - ); - group.bench_function( - "read Int32Array, dictionary encoded, mandatory, no NULLs - old", - |b| { - b.iter(|| { - let array_reader = create_int32_primitive_array_reader( - dictionary_int32_no_null_data.clone(), - mandatory_int32_column_desc.clone(), - ); - count = bench_array_reader(array_reader); - }) - }, - ); - assert_eq!(count, EXPECTED_VALUE_COUNT); - - group.bench_function( - "read Int32Array, dictionary encoded, mandatory, no NULLs - new", - |b| { - b.iter(|| { - let array_reader = create_int32_arrow_array_reader( - dictionary_int32_no_null_data.clone(), - mandatory_int32_column_desc.clone(), - ); - count = bench_array_reader(array_reader); - }) - }, - ); - assert_eq!(count, EXPECTED_VALUE_COUNT); - - let dictionary_int32_no_null_data = build_dictionary_encoded_int32_page_iterator( - schema.clone(), - optional_int32_column_desc.clone(), - 0.0, - ); - group.bench_function( - "read Int32Array, dictionary encoded, optional, no NULLs - old", - |b| { - b.iter(|| { - let array_reader = create_int32_primitive_array_reader( - dictionary_int32_no_null_data.clone(), - optional_int32_column_desc.clone(), - ); - count = bench_array_reader(array_reader); - }) - }, - ); - assert_eq!(count, EXPECTED_VALUE_COUNT); - - group.bench_function( - "read Int32Array, dictionary encoded, optional, no NULLs - new", - |b| { - b.iter(|| { - let array_reader = create_int32_arrow_array_reader( - dictionary_int32_no_null_data.clone(), - optional_int32_column_desc.clone(), - ); - count = bench_array_reader(array_reader); - }) - }, - ); - assert_eq!(count, EXPECTED_VALUE_COUNT); - - // int32, dictionary encoded, half NULLs - let dictionary_int32_half_null_data = build_dictionary_encoded_int32_page_iterator( - schema.clone(), - optional_int32_column_desc.clone(), - 0.5, - ); - group.bench_function( - "read Int32Array, dictionary encoded, optional, half NULLs - old", - |b| { - b.iter(|| { - let array_reader = create_int32_primitive_array_reader( - dictionary_int32_half_null_data.clone(), - optional_int32_column_desc.clone(), - ); - count = bench_array_reader(array_reader); - }) - }, - ); - assert_eq!(count, EXPECTED_VALUE_COUNT); - - group.bench_function( - "read Int32Array, dictionary encoded, optional, half NULLs - new", - |b| { - b.iter(|| { - let array_reader = create_int32_arrow_array_reader( - dictionary_int32_half_null_data.clone(), - optional_int32_column_desc.clone(), - ); - count = bench_array_reader(array_reader); - }) - 
}, - ); - assert_eq!(count, EXPECTED_VALUE_COUNT); - - // string benchmarks - //============================== - - // string, plain encoded, no NULLs - let plain_string_no_null_data = build_plain_encoded_string_page_iterator( - schema.clone(), - mandatory_string_column_desc.clone(), - 0.0, - ); - group.bench_function( - "read StringArray, plain encoded, mandatory, no NULLs - old", - |b| { - b.iter(|| { - let array_reader = create_string_complex_array_reader( - plain_string_no_null_data.clone(), - mandatory_string_column_desc.clone(), - ); - count = bench_array_reader(array_reader); - }) - }, - ); - assert_eq!(count, EXPECTED_VALUE_COUNT); - - group.bench_function( - "read StringArray, plain encoded, mandatory, no NULLs - new", - |b| { - b.iter(|| { - let array_reader = create_string_arrow_array_reader( - plain_string_no_null_data.clone(), - mandatory_string_column_desc.clone(), - ); - count = bench_array_reader(array_reader); - }) - }, - ); - assert_eq!(count, EXPECTED_VALUE_COUNT); - - let plain_string_no_null_data = build_plain_encoded_string_page_iterator( - schema.clone(), - optional_string_column_desc.clone(), - 0.0, - ); - group.bench_function( - "read StringArray, plain encoded, optional, no NULLs - old", - |b| { - b.iter(|| { - let array_reader = create_string_complex_array_reader( - plain_string_no_null_data.clone(), - optional_string_column_desc.clone(), - ); - count = bench_array_reader(array_reader); - }) - }, - ); - assert_eq!(count, EXPECTED_VALUE_COUNT); - - group.bench_function( - "read StringArray, plain encoded, optional, no NULLs - new", - |b| { - b.iter(|| { - let array_reader = create_string_arrow_array_reader( - plain_string_no_null_data.clone(), - optional_string_column_desc.clone(), - ); - count = bench_array_reader(array_reader); - }) - }, - ); - assert_eq!(count, EXPECTED_VALUE_COUNT); - - // string, plain encoded, half NULLs - let plain_string_half_null_data = build_plain_encoded_string_page_iterator( - schema.clone(), - optional_string_column_desc.clone(), - 0.5, - ); - group.bench_function( - "read StringArray, plain encoded, optional, half NULLs - old", - |b| { - b.iter(|| { - let array_reader = create_string_complex_array_reader( - plain_string_half_null_data.clone(), - optional_string_column_desc.clone(), - ); - count = bench_array_reader(array_reader); - }) - }, - ); - assert_eq!(count, EXPECTED_VALUE_COUNT); - - group.bench_function( - "read StringArray, plain encoded, optional, half NULLs - new", - |b| { - b.iter(|| { - let array_reader = create_string_arrow_array_reader( - plain_string_half_null_data.clone(), - optional_string_column_desc.clone(), - ); - count = bench_array_reader(array_reader); - }) - }, - ); - assert_eq!(count, EXPECTED_VALUE_COUNT); - - // string, dictionary encoded, no NULLs - let dictionary_string_no_null_data = build_dictionary_encoded_string_page_iterator( - schema.clone(), - mandatory_string_column_desc.clone(), - 0.0, - ); - group.bench_function( - "read StringArray, dictionary encoded, mandatory, no NULLs - old", - |b| { - b.iter(|| { - let array_reader = create_string_complex_array_reader( - dictionary_string_no_null_data.clone(), - mandatory_string_column_desc.clone(), - ); - count = bench_array_reader(array_reader); - }) - }, - ); - assert_eq!(count, EXPECTED_VALUE_COUNT); - - group.bench_function( - "read StringArray, dictionary encoded, mandatory, no NULLs - new", - |b| { - b.iter(|| { - let array_reader = create_string_arrow_array_reader( - dictionary_string_no_null_data.clone(), - mandatory_string_column_desc.clone(), - 
); - count = bench_array_reader(array_reader); - }) - }, - ); - assert_eq!(count, EXPECTED_VALUE_COUNT); - - let dictionary_string_no_null_data = build_dictionary_encoded_string_page_iterator( - schema.clone(), - optional_string_column_desc.clone(), - 0.0, - ); - group.bench_function( - "read StringArray, dictionary encoded, optional, no NULLs - old", - |b| { - b.iter(|| { - let array_reader = create_string_complex_array_reader( - dictionary_string_no_null_data.clone(), - optional_string_column_desc.clone(), - ); - count = bench_array_reader(array_reader); - }) - }, - ); - assert_eq!(count, EXPECTED_VALUE_COUNT); - - group.bench_function( - "read StringArray, dictionary encoded, optional, no NULLs - new", - |b| { - b.iter(|| { - let array_reader = create_string_arrow_array_reader( - dictionary_string_no_null_data.clone(), - optional_string_column_desc.clone(), - ); - count = bench_array_reader(array_reader); - }) - }, - ); - assert_eq!(count, EXPECTED_VALUE_COUNT); - - // string, dictionary encoded, half NULLs - let dictionary_string_half_null_data = build_dictionary_encoded_string_page_iterator( - schema, - optional_string_column_desc.clone(), - 0.5, - ); - group.bench_function( - "read StringArray, dictionary encoded, optional, half NULLs - old", - |b| { - b.iter(|| { - let array_reader = create_string_complex_array_reader( - dictionary_string_half_null_data.clone(), - optional_string_column_desc.clone(), - ); - count = bench_array_reader(array_reader); - }) - }, - ); - assert_eq!(count, EXPECTED_VALUE_COUNT); - - group.bench_function( - "read StringArray, dictionary encoded, optional, half NULLs - new", - |b| { - b.iter(|| { - let array_reader = create_string_arrow_array_reader( - dictionary_string_half_null_data.clone(), - optional_string_column_desc.clone(), - ); - count = bench_array_reader(array_reader); - }) - }, - ); - assert_eq!(count, EXPECTED_VALUE_COUNT); - - group.finish(); -} - -criterion_group!(benches, add_benches); -criterion_main!(benches); diff --git a/parquet/benches/arrow_reader.rs b/parquet/benches/arrow_reader.rs new file mode 100644 index 000000000000..647a8dc6f393 --- /dev/null +++ b/parquet/benches/arrow_reader.rs @@ -0,0 +1,697 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
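The new `parquet/benches/arrow_reader.rs` introduced below follows Criterion's benchmark-group layout throughout. For readers unfamiliar with that harness, here is a minimal standalone sketch of the same pattern; the names are illustrative only, and the `[[bench]]` registration with `harness = false` in `Cargo.toml` is assumed rather than shown in this diff.

    // Illustrative Criterion skeleton; not part of the benchmark file below.
    use criterion::{criterion_group, criterion_main, Criterion};

    fn example_benches(c: &mut Criterion) {
        // A group gives related benchmarks a common prefix in the report,
        // mirroring the "arrow_array_reader" groups used further down.
        let mut group = c.benchmark_group("example_group");
        let data: Vec<i32> = (0..10_000).collect();
        group.bench_function("sum", |b| {
            // Criterion runs this closure many times and records wall time.
            b.iter(|| data.iter().sum::<i32>())
        });
        group.finish();
    }

    criterion_group!(benches, example_benches);
    criterion_main!(benches);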
+ +use arrow::array::Array; +use arrow::datatypes::DataType; +use criterion::measurement::WallTime; +use criterion::{criterion_group, criterion_main, BenchmarkGroup, Criterion}; +use num::FromPrimitive; +use parquet::basic::Type; +use parquet::util::{DataPageBuilder, DataPageBuilderImpl, InMemoryPageIterator}; +use parquet::{ + arrow::array_reader::ArrayReader, + basic::Encoding, + column::page::PageIterator, + data_type::{ByteArrayType, Int32Type, Int64Type}, + schema::types::{ColumnDescPtr, SchemaDescPtr}, +}; +use rand::distributions::uniform::SampleUniform; +use rand::{rngs::StdRng, Rng, SeedableRng}; +use std::{collections::VecDeque, sync::Arc}; + +fn build_test_schema() -> SchemaDescPtr { + use parquet::schema::{parser::parse_message_type, types::SchemaDescriptor}; + let message_type = " + message test_schema { + REQUIRED INT32 mandatory_int32_leaf; + OPTIONAL INT32 optional_int32_leaf; + REQUIRED BYTE_ARRAY mandatory_string_leaf (UTF8); + OPTIONAL BYTE_ARRAY optional_string_leaf (UTF8); + REQUIRED INT64 mandatory_int64_leaf; + OPTIONAL INT64 optional_int64_leaf; + } + "; + parse_message_type(message_type) + .map(|t| Arc::new(SchemaDescriptor::new(Arc::new(t)))) + .unwrap() +} + +// test data params +const NUM_ROW_GROUPS: usize = 1; +const PAGES_PER_GROUP: usize = 2; +const VALUES_PER_PAGE: usize = 10_000; +const BATCH_SIZE: usize = 8192; +const EXPECTED_VALUE_COUNT: usize = NUM_ROW_GROUPS * PAGES_PER_GROUP * VALUES_PER_PAGE; + +pub fn seedable_rng() -> StdRng { + StdRng::seed_from_u64(42) +} + +fn build_encoded_primitive_page_iterator( + schema: SchemaDescPtr, + column_desc: ColumnDescPtr, + null_density: f32, + encoding: Encoding, +) -> impl PageIterator + Clone +where + T: parquet::data_type::DataType, + T::T: SampleUniform + FromPrimitive, +{ + let max_def_level = column_desc.max_def_level(); + let max_rep_level = column_desc.max_rep_level(); + let rep_levels = vec![0; VALUES_PER_PAGE]; + let mut rng = seedable_rng(); + let mut pages: Vec> = Vec::new(); + for _i in 0..NUM_ROW_GROUPS { + let mut column_chunk_pages = Vec::new(); + for _j in 0..PAGES_PER_GROUP { + // generate page + let mut values = Vec::with_capacity(VALUES_PER_PAGE); + let mut def_levels = Vec::with_capacity(VALUES_PER_PAGE); + for _k in 0..VALUES_PER_PAGE { + let def_level = if rng.gen::() < null_density { + max_def_level - 1 + } else { + max_def_level + }; + if def_level == max_def_level { + let value = + FromPrimitive::from_usize(rng.gen_range(0..1000)).unwrap(); + values.push(value); + } + def_levels.push(def_level); + } + let mut page_builder = + DataPageBuilderImpl::new(column_desc.clone(), values.len() as u32, true); + page_builder.add_rep_levels(max_rep_level, &rep_levels); + page_builder.add_def_levels(max_def_level, &def_levels); + page_builder.add_values::(encoding, &values); + column_chunk_pages.push(page_builder.consume()); + } + pages.push(column_chunk_pages); + } + + InMemoryPageIterator::new(schema, column_desc, pages) +} + +fn build_dictionary_encoded_primitive_page_iterator( + schema: SchemaDescPtr, + column_desc: ColumnDescPtr, + null_density: f32, +) -> impl PageIterator + Clone +where + T: parquet::data_type::DataType, + T::T: SampleUniform + FromPrimitive + Copy, +{ + use parquet::encoding::{DictEncoder, Encoder}; + let max_def_level = column_desc.max_def_level(); + let max_rep_level = column_desc.max_rep_level(); + let rep_levels = vec![0; VALUES_PER_PAGE]; + // generate 1% unique values + const NUM_UNIQUE_VALUES: usize = VALUES_PER_PAGE / 100; + let unique_values: Vec = 
(0..NUM_UNIQUE_VALUES) + .map(|x| FromPrimitive::from_usize(x + 1).unwrap()) + .collect::>(); + let mut rng = seedable_rng(); + let mut pages: Vec> = Vec::new(); + for _i in 0..NUM_ROW_GROUPS { + let mut column_chunk_pages = VecDeque::new(); + let mut dict_encoder = DictEncoder::::new(column_desc.clone()); + // add data pages + for _j in 0..PAGES_PER_GROUP { + // generate page + let mut values = Vec::with_capacity(VALUES_PER_PAGE); + let mut def_levels = Vec::with_capacity(VALUES_PER_PAGE); + for _k in 0..VALUES_PER_PAGE { + let def_level = if rng.gen::() < null_density { + max_def_level - 1 + } else { + max_def_level + }; + if def_level == max_def_level { + // select random value from list of unique values + let value = unique_values[rng.gen_range(0..NUM_UNIQUE_VALUES)]; + values.push(value); + } + def_levels.push(def_level); + } + let mut page_builder = + DataPageBuilderImpl::new(column_desc.clone(), values.len() as u32, true); + page_builder.add_rep_levels(max_rep_level, &rep_levels); + page_builder.add_def_levels(max_def_level, &def_levels); + let _ = dict_encoder.put(&values); + let indices = dict_encoder + .write_indices() + .expect("write_indices() should be OK"); + page_builder.add_indices(indices); + column_chunk_pages.push_back(page_builder.consume()); + } + // add dictionary page + let dict = dict_encoder + .write_dict() + .expect("write_dict() should be OK"); + let dict_page = parquet::column::page::Page::DictionaryPage { + buf: dict, + num_values: dict_encoder.num_entries() as u32, + encoding: Encoding::RLE_DICTIONARY, + is_sorted: false, + }; + column_chunk_pages.push_front(dict_page); + pages.push(column_chunk_pages.into()); + } + + InMemoryPageIterator::new(schema, column_desc, pages) +} + +fn build_plain_encoded_string_page_iterator( + schema: SchemaDescPtr, + column_desc: ColumnDescPtr, + null_density: f32, +) -> impl PageIterator + Clone { + let max_def_level = column_desc.max_def_level(); + let max_rep_level = column_desc.max_rep_level(); + let rep_levels = vec![0; VALUES_PER_PAGE]; + let mut rng = seedable_rng(); + let mut pages: Vec> = Vec::new(); + for i in 0..NUM_ROW_GROUPS { + let mut column_chunk_pages = Vec::new(); + for j in 0..PAGES_PER_GROUP { + // generate page + let mut values = Vec::with_capacity(VALUES_PER_PAGE); + let mut def_levels = Vec::with_capacity(VALUES_PER_PAGE); + for k in 0..VALUES_PER_PAGE { + let def_level = if rng.gen::() < null_density { + max_def_level - 1 + } else { + max_def_level + }; + if def_level == max_def_level { + let string_value = + format!("Test value {}, row group: {}, page: {}", k, i, j); + values + .push(parquet::data_type::ByteArray::from(string_value.as_str())); + } + def_levels.push(def_level); + } + let mut page_builder = + DataPageBuilderImpl::new(column_desc.clone(), values.len() as u32, true); + page_builder.add_rep_levels(max_rep_level, &rep_levels); + page_builder.add_def_levels(max_def_level, &def_levels); + page_builder.add_values::(Encoding::PLAIN, &values); + column_chunk_pages.push(page_builder.consume()); + } + pages.push(column_chunk_pages); + } + + InMemoryPageIterator::new(schema, column_desc, pages) +} + +fn build_dictionary_encoded_string_page_iterator( + schema: SchemaDescPtr, + column_desc: ColumnDescPtr, + null_density: f32, +) -> impl PageIterator + Clone { + use parquet::encoding::{DictEncoder, Encoder}; + let max_def_level = column_desc.max_def_level(); + let max_rep_level = column_desc.max_rep_level(); + let rep_levels = vec![0; VALUES_PER_PAGE]; + // generate 1% unique values + const 
NUM_UNIQUE_VALUES: usize = VALUES_PER_PAGE / 100; + let unique_values = (0..NUM_UNIQUE_VALUES) + .map(|x| format!("Dictionary value {}", x)) + .collect::>(); + let mut rng = seedable_rng(); + let mut pages: Vec> = Vec::new(); + for _i in 0..NUM_ROW_GROUPS { + let mut column_chunk_pages = VecDeque::new(); + let mut dict_encoder = DictEncoder::::new(column_desc.clone()); + // add data pages + for _j in 0..PAGES_PER_GROUP { + // generate page + let mut values = Vec::with_capacity(VALUES_PER_PAGE); + let mut def_levels = Vec::with_capacity(VALUES_PER_PAGE); + for _k in 0..VALUES_PER_PAGE { + let def_level = if rng.gen::() < null_density { + max_def_level - 1 + } else { + max_def_level + }; + if def_level == max_def_level { + // select random value from list of unique values + let string_value = + unique_values[rng.gen_range(0..NUM_UNIQUE_VALUES)].as_str(); + values.push(parquet::data_type::ByteArray::from(string_value)); + } + def_levels.push(def_level); + } + let mut page_builder = + DataPageBuilderImpl::new(column_desc.clone(), values.len() as u32, true); + page_builder.add_rep_levels(max_rep_level, &rep_levels); + page_builder.add_def_levels(max_def_level, &def_levels); + let _ = dict_encoder.put(&values); + let indices = dict_encoder + .write_indices() + .expect("write_indices() should be OK"); + page_builder.add_indices(indices); + column_chunk_pages.push_back(page_builder.consume()); + } + // add dictionary page + let dict = dict_encoder + .write_dict() + .expect("write_dict() should be OK"); + let dict_page = parquet::column::page::Page::DictionaryPage { + buf: dict, + num_values: dict_encoder.num_entries() as u32, + encoding: Encoding::RLE_DICTIONARY, + is_sorted: false, + }; + column_chunk_pages.push_front(dict_page); + pages.push(column_chunk_pages.into()); + } + + InMemoryPageIterator::new(schema, column_desc, pages) +} + +fn bench_array_reader(mut array_reader: Box) -> usize { + // test procedure: read data in batches of 8192 until no more data + let mut total_count = 0; + loop { + let array = array_reader.next_batch(BATCH_SIZE); + let array_len = array.unwrap().len(); + total_count += array_len; + if array_len < BATCH_SIZE { + break; + } + } + total_count +} + +fn create_primitive_array_reader( + page_iterator: impl PageIterator + 'static, + column_desc: ColumnDescPtr, +) -> Box { + use parquet::arrow::array_reader::PrimitiveArrayReader; + match column_desc.physical_type() { + Type::INT32 => { + let reader = PrimitiveArrayReader::::new_with_options( + Box::new(page_iterator), + column_desc, + None, + true, + ) + .unwrap(); + Box::new(reader) + } + Type::INT64 => { + let reader = PrimitiveArrayReader::::new_with_options( + Box::new(page_iterator), + column_desc, + None, + true, + ) + .unwrap(); + Box::new(reader) + } + _ => unreachable!(), + } +} + +fn create_string_byte_array_reader( + page_iterator: impl PageIterator + 'static, + column_desc: ColumnDescPtr, +) -> Box { + use parquet::arrow::array_reader::make_byte_array_reader; + make_byte_array_reader(Box::new(page_iterator), column_desc, None, true).unwrap() +} + +fn create_string_byte_array_dictionary_reader( + page_iterator: impl PageIterator + 'static, + column_desc: ColumnDescPtr, +) -> Box { + use parquet::arrow::array_reader::make_byte_array_dictionary_reader; + let arrow_type = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + + make_byte_array_dictionary_reader( + Box::new(page_iterator), + column_desc, + Some(arrow_type), + true, + ) + .unwrap() +} + +fn bench_primitive( + group: &mut 
BenchmarkGroup, + schema: &SchemaDescPtr, + mandatory_column_desc: &ColumnDescPtr, + optional_column_desc: &ColumnDescPtr, +) where + T: parquet::data_type::DataType, + T::T: SampleUniform + FromPrimitive + Copy, +{ + let mut count: usize = 0; + + // plain encoded, no NULLs + let data = build_encoded_primitive_page_iterator::( + schema.clone(), + mandatory_column_desc.clone(), + 0.0, + Encoding::PLAIN, + ); + group.bench_function("plain encoded, mandatory, no NULLs", |b| { + b.iter(|| { + let array_reader = create_primitive_array_reader( + data.clone(), + mandatory_column_desc.clone(), + ); + count = bench_array_reader(array_reader); + }); + assert_eq!(count, EXPECTED_VALUE_COUNT); + }); + + let data = build_encoded_primitive_page_iterator::( + schema.clone(), + optional_column_desc.clone(), + 0.0, + Encoding::PLAIN, + ); + group.bench_function("plain encoded, optional, no NULLs", |b| { + b.iter(|| { + let array_reader = + create_primitive_array_reader(data.clone(), optional_column_desc.clone()); + count = bench_array_reader(array_reader); + }); + assert_eq!(count, EXPECTED_VALUE_COUNT); + }); + + // plain encoded, half NULLs + let data = build_encoded_primitive_page_iterator::( + schema.clone(), + optional_column_desc.clone(), + 0.5, + Encoding::PLAIN, + ); + group.bench_function("plain encoded, optional, half NULLs", |b| { + b.iter(|| { + let array_reader = + create_primitive_array_reader(data.clone(), optional_column_desc.clone()); + count = bench_array_reader(array_reader); + }); + assert_eq!(count, EXPECTED_VALUE_COUNT); + }); + + // binary packed, no NULLs + let data = build_encoded_primitive_page_iterator::( + schema.clone(), + mandatory_column_desc.clone(), + 0.0, + Encoding::DELTA_BINARY_PACKED, + ); + group.bench_function("binary packed, mandatory, no NULLs", |b| { + b.iter(|| { + let array_reader = create_primitive_array_reader( + data.clone(), + mandatory_column_desc.clone(), + ); + count = bench_array_reader(array_reader); + }); + assert_eq!(count, EXPECTED_VALUE_COUNT); + }); + + let data = build_encoded_primitive_page_iterator::( + schema.clone(), + optional_column_desc.clone(), + 0.0, + Encoding::DELTA_BINARY_PACKED, + ); + group.bench_function("binary packed, optional, no NULLs", |b| { + b.iter(|| { + let array_reader = + create_primitive_array_reader(data.clone(), optional_column_desc.clone()); + count = bench_array_reader(array_reader); + }); + assert_eq!(count, EXPECTED_VALUE_COUNT); + }); + + // binary packed, half NULLs + let data = build_encoded_primitive_page_iterator::( + schema.clone(), + optional_column_desc.clone(), + 0.5, + Encoding::DELTA_BINARY_PACKED, + ); + group.bench_function("binary packed, optional, half NULLs", |b| { + b.iter(|| { + let array_reader = + create_primitive_array_reader(data.clone(), optional_column_desc.clone()); + count = bench_array_reader(array_reader); + }); + assert_eq!(count, EXPECTED_VALUE_COUNT); + }); + + // dictionary encoded, no NULLs + let data = build_dictionary_encoded_primitive_page_iterator::( + schema.clone(), + mandatory_column_desc.clone(), + 0.0, + ); + group.bench_function("dictionary encoded, mandatory, no NULLs", |b| { + b.iter(|| { + let array_reader = create_primitive_array_reader( + data.clone(), + mandatory_column_desc.clone(), + ); + count = bench_array_reader(array_reader); + }); + assert_eq!(count, EXPECTED_VALUE_COUNT); + }); + + let data = build_dictionary_encoded_primitive_page_iterator::( + schema.clone(), + optional_column_desc.clone(), + 0.0, + ); + group.bench_function("dictionary encoded, optional, no 
NULLs", |b| { + b.iter(|| { + let array_reader = + create_primitive_array_reader(data.clone(), optional_column_desc.clone()); + count = bench_array_reader(array_reader); + }); + assert_eq!(count, EXPECTED_VALUE_COUNT); + }); + + // dictionary encoded, half NULLs + let data = build_dictionary_encoded_primitive_page_iterator::( + schema.clone(), + optional_column_desc.clone(), + 0.5, + ); + group.bench_function("dictionary encoded, optional, half NULLs", |b| { + b.iter(|| { + let array_reader = + create_primitive_array_reader(data.clone(), optional_column_desc.clone()); + count = bench_array_reader(array_reader); + }); + assert_eq!(count, EXPECTED_VALUE_COUNT); + }); +} + +fn add_benches(c: &mut Criterion) { + let mut count: usize = 0; + + let schema = build_test_schema(); + let mandatory_int32_column_desc = schema.column(0); + let optional_int32_column_desc = schema.column(1); + let mandatory_string_column_desc = schema.column(2); + let optional_string_column_desc = schema.column(3); + let mandatory_int64_column_desc = schema.column(4); + let optional_int64_column_desc = schema.column(5); + // primitive / int32 benchmarks + // ============================= + + let mut group = c.benchmark_group("arrow_array_reader/Int32Array"); + bench_primitive::( + &mut group, + &schema, + &mandatory_int32_column_desc, + &optional_int32_column_desc, + ); + group.finish(); + + // primitive / int64 benchmarks + // ============================= + + let mut group = c.benchmark_group("arrow_array_reader/Int64Array"); + bench_primitive::( + &mut group, + &schema, + &mandatory_int64_column_desc, + &optional_int64_column_desc, + ); + group.finish(); + + // string benchmarks + //============================== + + let mut group = c.benchmark_group("arrow_array_reader/StringArray"); + + // string, plain encoded, no NULLs + let plain_string_no_null_data = build_plain_encoded_string_page_iterator( + schema.clone(), + mandatory_string_column_desc.clone(), + 0.0, + ); + group.bench_function("plain encoded, mandatory, no NULLs", |b| { + b.iter(|| { + let array_reader = create_string_byte_array_reader( + plain_string_no_null_data.clone(), + mandatory_string_column_desc.clone(), + ); + count = bench_array_reader(array_reader); + }); + assert_eq!(count, EXPECTED_VALUE_COUNT); + }); + + let plain_string_no_null_data = build_plain_encoded_string_page_iterator( + schema.clone(), + optional_string_column_desc.clone(), + 0.0, + ); + group.bench_function("plain encoded, optional, no NULLs", |b| { + b.iter(|| { + let array_reader = create_string_byte_array_reader( + plain_string_no_null_data.clone(), + optional_string_column_desc.clone(), + ); + count = bench_array_reader(array_reader); + }); + assert_eq!(count, EXPECTED_VALUE_COUNT); + }); + + // string, plain encoded, half NULLs + let plain_string_half_null_data = build_plain_encoded_string_page_iterator( + schema.clone(), + optional_string_column_desc.clone(), + 0.5, + ); + group.bench_function("plain encoded, optional, half NULLs", |b| { + b.iter(|| { + let array_reader = create_string_byte_array_reader( + plain_string_half_null_data.clone(), + optional_string_column_desc.clone(), + ); + count = bench_array_reader(array_reader); + }); + assert_eq!(count, EXPECTED_VALUE_COUNT); + }); + + // string, dictionary encoded, no NULLs + let dictionary_string_no_null_data = build_dictionary_encoded_string_page_iterator( + schema.clone(), + mandatory_string_column_desc.clone(), + 0.0, + ); + group.bench_function("dictionary encoded, mandatory, no NULLs", |b| { + b.iter(|| { + let 
array_reader = create_string_byte_array_reader( + dictionary_string_no_null_data.clone(), + mandatory_string_column_desc.clone(), + ); + count = bench_array_reader(array_reader); + }); + assert_eq!(count, EXPECTED_VALUE_COUNT); + }); + + let dictionary_string_no_null_data = build_dictionary_encoded_string_page_iterator( + schema.clone(), + optional_string_column_desc.clone(), + 0.0, + ); + group.bench_function("dictionary encoded, optional, no NULLs", |b| { + b.iter(|| { + let array_reader = create_string_byte_array_reader( + dictionary_string_no_null_data.clone(), + optional_string_column_desc.clone(), + ); + count = bench_array_reader(array_reader); + }); + assert_eq!(count, EXPECTED_VALUE_COUNT); + }); + + // string, dictionary encoded, half NULLs + let dictionary_string_half_null_data = build_dictionary_encoded_string_page_iterator( + schema.clone(), + optional_string_column_desc.clone(), + 0.5, + ); + group.bench_function("dictionary encoded, optional, half NULLs", |b| { + b.iter(|| { + let array_reader = create_string_byte_array_reader( + dictionary_string_half_null_data.clone(), + optional_string_column_desc.clone(), + ); + count = bench_array_reader(array_reader); + }); + assert_eq!(count, EXPECTED_VALUE_COUNT); + }); + + group.finish(); + + // string dictionary benchmarks + //============================== + + let mut group = c.benchmark_group("arrow_array_reader/StringDictionary"); + + group.bench_function("dictionary encoded, mandatory, no NULLs", |b| { + b.iter(|| { + let array_reader = create_string_byte_array_dictionary_reader( + dictionary_string_no_null_data.clone(), + mandatory_string_column_desc.clone(), + ); + count = bench_array_reader(array_reader); + }); + assert_eq!(count, EXPECTED_VALUE_COUNT); + }); + + group.bench_function("dictionary encoded, optional, no NULLs", |b| { + b.iter(|| { + let array_reader = create_string_byte_array_dictionary_reader( + dictionary_string_no_null_data.clone(), + optional_string_column_desc.clone(), + ); + count = bench_array_reader(array_reader); + }); + assert_eq!(count, EXPECTED_VALUE_COUNT); + }); + + group.bench_function("dictionary encoded, optional, half NULLs", |b| { + b.iter(|| { + let array_reader = create_string_byte_array_dictionary_reader( + dictionary_string_half_null_data.clone(), + optional_string_column_desc.clone(), + ); + count = bench_array_reader(array_reader); + }); + assert_eq!(count, EXPECTED_VALUE_COUNT); + }); + + group.finish(); +} + +criterion_group!(benches, add_benches); +criterion_main!(benches); diff --git a/parquet/benches/arrow_writer.rs b/parquet/benches/arrow_writer.rs index f1154eb9e394..25ff1ca90dc6 100644 --- a/parquet/benches/arrow_writer.rs +++ b/parquet/benches/arrow_writer.rs @@ -26,9 +26,7 @@ use std::sync::Arc; use arrow::datatypes::*; use arrow::{record_batch::RecordBatch, util::data_gen::*}; -use parquet::{ - arrow::ArrowWriter, errors::Result, file::writer::InMemoryWriteableCursor, -}; +use parquet::{arrow::ArrowWriter, errors::Result}; fn create_primitive_bench_batch( size: usize, @@ -278,8 +276,8 @@ fn _create_nested_bench_batch( #[inline] fn write_batch(batch: &RecordBatch) -> Result<()> { // Write batch to an in-memory writer - let cursor = InMemoryWriteableCursor::default(); - let mut writer = ArrowWriter::try_new(cursor, batch.schema(), None)?; + let buffer = vec![]; + let mut writer = ArrowWriter::try_new(buffer, batch.schema(), None)?; writer.write(batch)?; writer.close()?; diff --git a/parquet/src/arrow/array_reader.rs b/parquet/src/arrow/array_reader.rs deleted file mode 100644 
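The `arrow_writer.rs` hunk just above drops `InMemoryWriteableCursor` because `ArrowWriter` now writes to any `std::io::Write` implementation, so a plain `Vec<u8>` serves as the in-memory sink. A minimal sketch of that usage, using the same API as the hunk (with error handling kept simple), might look like:

    use std::sync::Arc;

    use arrow::array::{ArrayRef, Int32Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use parquet::arrow::ArrowWriter;

    fn main() -> parquet::errors::Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, false)]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef],
        )
        .expect("valid batch");

        // Vec<u8> implements std::io::Write, replacing the removed cursor type.
        let buffer: Vec<u8> = Vec::new();
        let mut writer = ArrowWriter::try_new(buffer, schema, None)?;
        writer.write(&batch)?;
        writer.close()?;
        Ok(())
    }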
index ae001ed73391..000000000000 --- a/parquet/src/arrow/array_reader.rs +++ /dev/null @@ -1,2750 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::cmp::{max, min}; -use std::collections::{HashMap, HashSet}; -use std::marker::PhantomData; -use std::mem::size_of; -use std::result::Result::Ok; -use std::sync::Arc; -use std::vec::Vec; - -use arrow::array::{ - new_empty_array, Array, ArrayData, ArrayDataBuilder, ArrayRef, BinaryArray, - BinaryBuilder, BooleanArray, BooleanBufferBuilder, BooleanBuilder, DecimalBuilder, - FixedSizeBinaryArray, FixedSizeBinaryBuilder, GenericListArray, Int16BufferBuilder, - Int32Array, Int64Array, MapArray, OffsetSizeTrait, PrimitiveArray, PrimitiveBuilder, - StringArray, StringBuilder, StructArray, -}; -use arrow::buffer::{Buffer, MutableBuffer}; -use arrow::datatypes::{ - ArrowPrimitiveType, BooleanType as ArrowBooleanType, DataType as ArrowType, - Date32Type as ArrowDate32Type, Date64Type as ArrowDate64Type, - DurationMicrosecondType as ArrowDurationMicrosecondType, - DurationMillisecondType as ArrowDurationMillisecondType, - DurationNanosecondType as ArrowDurationNanosecondType, - DurationSecondType as ArrowDurationSecondType, Field, - Float32Type as ArrowFloat32Type, Float64Type as ArrowFloat64Type, - Int16Type as ArrowInt16Type, Int32Type as ArrowInt32Type, - Int64Type as ArrowInt64Type, Int8Type as ArrowInt8Type, IntervalUnit, Schema, - Time32MillisecondType as ArrowTime32MillisecondType, - Time32SecondType as ArrowTime32SecondType, - Time64MicrosecondType as ArrowTime64MicrosecondType, - Time64NanosecondType as ArrowTime64NanosecondType, TimeUnit as ArrowTimeUnit, - TimestampMicrosecondType as ArrowTimestampMicrosecondType, - TimestampMillisecondType as ArrowTimestampMillisecondType, - TimestampNanosecondType as ArrowTimestampNanosecondType, - TimestampSecondType as ArrowTimestampSecondType, ToByteSlice, - UInt16Type as ArrowUInt16Type, UInt32Type as ArrowUInt32Type, - UInt64Type as ArrowUInt64Type, UInt8Type as ArrowUInt8Type, -}; -use arrow::util::bit_util; - -use crate::arrow::converter::{ - BinaryArrayConverter, BinaryConverter, Converter, DecimalArrayConverter, - DecimalConverter, FixedLenBinaryConverter, FixedSizeArrayConverter, - Int96ArrayConverter, Int96Converter, IntervalDayTimeArrayConverter, - IntervalDayTimeConverter, IntervalYearMonthArrayConverter, - IntervalYearMonthConverter, LargeBinaryArrayConverter, LargeBinaryConverter, - LargeUtf8ArrayConverter, LargeUtf8Converter, -}; -use crate::arrow::record_reader::RecordReader; -use crate::arrow::schema::parquet_to_arrow_field; -use crate::basic::{ConvertedType, Repetition, Type as PhysicalType}; -use crate::column::page::PageIterator; -use crate::column::reader::ColumnReaderImpl; -use crate::data_type::{ - BoolType, 
ByteArrayType, DataType, DoubleType, FixedLenByteArrayType, FloatType, - Int32Type, Int64Type, Int96Type, -}; -use crate::errors::{ParquetError, ParquetError::ArrowError, Result}; -use crate::file::reader::{FilePageIterator, FileReader}; -use crate::schema::types::{ - ColumnDescPtr, ColumnDescriptor, ColumnPath, SchemaDescPtr, Type, TypePtr, -}; -use crate::schema::visitor::TypeVisitor; -use std::any::Any; - -/// Array reader reads parquet data into arrow array. -pub trait ArrayReader { - fn as_any(&self) -> &dyn Any; - - /// Returns the arrow type of this array reader. - fn get_data_type(&self) -> &ArrowType; - - /// Reads at most `batch_size` records into an arrow array and return it. - fn next_batch(&mut self, batch_size: usize) -> Result; - - /// Returns the definition levels of data from last call of `next_batch`. - /// The result is used by parent array reader to calculate its own definition - /// levels and repetition levels, so that its parent can calculate null bitmap. - fn get_def_levels(&self) -> Option<&[i16]>; - - /// Return the repetition levels of data from last call of `next_batch`. - /// The result is used by parent array reader to calculate its own definition - /// levels and repetition levels, so that its parent can calculate null bitmap. - fn get_rep_levels(&self) -> Option<&[i16]>; -} - -/// A NullArrayReader reads Parquet columns stored as null int32s with an Arrow -/// NullArray type. -pub struct NullArrayReader { - data_type: ArrowType, - pages: Box, - def_levels_buffer: Option, - rep_levels_buffer: Option, - column_desc: ColumnDescPtr, - record_reader: RecordReader, - _type_marker: PhantomData, -} - -impl NullArrayReader { - /// Construct null array reader. - pub fn new( - mut pages: Box, - column_desc: ColumnDescPtr, - ) -> Result { - let mut record_reader = RecordReader::::new(column_desc.clone()); - if let Some(page_reader) = pages.next() { - record_reader.set_page_reader(page_reader?)?; - } - - Ok(Self { - data_type: ArrowType::Null, - pages, - def_levels_buffer: None, - rep_levels_buffer: None, - column_desc, - record_reader, - _type_marker: PhantomData, - }) - } -} - -/// Implementation of primitive array reader. -impl ArrayReader for NullArrayReader { - fn as_any(&self) -> &dyn Any { - self - } - - /// Returns data type of primitive array. - fn get_data_type(&self) -> &ArrowType { - &self.data_type - } - - /// Reads at most `batch_size` records into array. 
- fn next_batch(&mut self, batch_size: usize) -> Result { - let mut records_read = 0usize; - while records_read < batch_size { - let records_to_read = batch_size - records_read; - - // NB can be 0 if at end of page - let records_read_once = self.record_reader.read_records(records_to_read)?; - records_read += records_read_once; - - // Record reader exhausted - if records_read_once < records_to_read { - if let Some(page_reader) = self.pages.next() { - // Read from new page reader - self.record_reader.set_page_reader(page_reader?)?; - } else { - // Page reader also exhausted - break; - } - } - } - - // convert to arrays - let array = arrow::array::NullArray::new(records_read); - - // save definition and repetition buffers - self.def_levels_buffer = self.record_reader.consume_def_levels()?; - self.rep_levels_buffer = self.record_reader.consume_rep_levels()?; - self.record_reader.reset(); - Ok(Arc::new(array)) - } - - fn get_def_levels(&self) -> Option<&[i16]> { - self.def_levels_buffer - .as_ref() - .map(|buf| unsafe { buf.typed_data() }) - } - - fn get_rep_levels(&self) -> Option<&[i16]> { - self.rep_levels_buffer - .as_ref() - .map(|buf| unsafe { buf.typed_data() }) - } -} - -/// Primitive array readers are leaves of array reader tree. They accept page iterator -/// and read them into primitive arrays. -pub struct PrimitiveArrayReader { - data_type: ArrowType, - pages: Box, - def_levels_buffer: Option, - rep_levels_buffer: Option, - column_desc: ColumnDescPtr, - record_reader: RecordReader, - _type_marker: PhantomData, -} - -impl PrimitiveArrayReader { - /// Construct primitive array reader. - pub fn new( - mut pages: Box, - column_desc: ColumnDescPtr, - arrow_type: Option, - ) -> Result { - // Check if Arrow type is specified, else create it from Parquet type - let data_type = match arrow_type { - Some(t) => t, - None => parquet_to_arrow_field(column_desc.as_ref())? - .data_type() - .clone(), - }; - - let mut record_reader = RecordReader::::new(column_desc.clone()); - if let Some(page_reader) = pages.next() { - record_reader.set_page_reader(page_reader?)?; - } - - Ok(Self { - data_type, - pages, - def_levels_buffer: None, - rep_levels_buffer: None, - column_desc, - record_reader, - _type_marker: PhantomData, - }) - } -} - -/// Implementation of primitive array reader. -impl ArrayReader for PrimitiveArrayReader { - fn as_any(&self) -> &dyn Any { - self - } - - /// Returns data type of primitive array. - fn get_data_type(&self) -> &ArrowType { - &self.data_type - } - - /// Reads at most `batch_size` records into array. 
- fn next_batch(&mut self, batch_size: usize) -> Result { - let mut records_read = 0usize; - while records_read < batch_size { - let records_to_read = batch_size - records_read; - - // NB can be 0 if at end of page - let records_read_once = self.record_reader.read_records(records_to_read)?; - records_read += records_read_once; - - // Record reader exhausted - if records_read_once < records_to_read { - if let Some(page_reader) = self.pages.next() { - // Read from new page reader - self.record_reader.set_page_reader(page_reader?)?; - } else { - // Page reader also exhausted - break; - } - } - } - - let target_type = self.get_data_type().clone(); - let arrow_data_type = match T::get_physical_type() { - PhysicalType::BOOLEAN => ArrowBooleanType::DATA_TYPE, - PhysicalType::INT32 => { - match target_type { - ArrowType::UInt32 => { - // follow C++ implementation and use overflow/reinterpret cast from i32 to u32 which will map - // `i32::MIN..0` to `(i32::MAX as u32)..u32::MAX` - ArrowUInt32Type::DATA_TYPE - } - _ => ArrowInt32Type::DATA_TYPE, - } - } - PhysicalType::INT64 => { - match target_type { - ArrowType::UInt64 => { - // follow C++ implementation and use overflow/reinterpret cast from i64 to u64 which will map - // `i64::MIN..0` to `(i64::MAX as u64)..u64::MAX` - ArrowUInt64Type::DATA_TYPE - } - _ => ArrowInt64Type::DATA_TYPE, - } - } - PhysicalType::FLOAT => ArrowFloat32Type::DATA_TYPE, - PhysicalType::DOUBLE => ArrowFloat64Type::DATA_TYPE, - PhysicalType::INT96 - | PhysicalType::BYTE_ARRAY - | PhysicalType::FIXED_LEN_BYTE_ARRAY => { - unreachable!( - "PrimitiveArrayReaders don't support complex physical types" - ); - } - }; - - // Convert to arrays by using the Parquet phyisical type. - // The physical types are then cast to Arrow types if necessary - - let mut record_data = self.record_reader.consume_record_data()?; - - if T::get_physical_type() == PhysicalType::BOOLEAN { - let mut boolean_buffer = BooleanBufferBuilder::new(record_data.len()); - - for e in record_data.as_slice() { - boolean_buffer.append(*e > 0); - } - record_data = boolean_buffer.finish(); - } - - let mut array_data = ArrayDataBuilder::new(arrow_data_type) - .len(self.record_reader.num_values()) - .add_buffer(record_data); - - if let Some(b) = self.record_reader.consume_bitmap_buffer()? { - array_data = array_data.null_bit_buffer(b); - } - - let array_data = unsafe { array_data.build_unchecked() }; - let array = match T::get_physical_type() { - PhysicalType::BOOLEAN => Arc::new(BooleanArray::from(array_data)) as ArrayRef, - PhysicalType::INT32 => { - Arc::new(PrimitiveArray::::from(array_data)) as ArrayRef - } - PhysicalType::INT64 => { - Arc::new(PrimitiveArray::::from(array_data)) as ArrayRef - } - PhysicalType::FLOAT => { - Arc::new(PrimitiveArray::::from(array_data)) as ArrayRef - } - PhysicalType::DOUBLE => { - Arc::new(PrimitiveArray::::from(array_data)) as ArrayRef - } - PhysicalType::INT96 - | PhysicalType::BYTE_ARRAY - | PhysicalType::FIXED_LEN_BYTE_ARRAY => { - unreachable!( - "PrimitiveArrayReaders don't support complex physical types" - ); - } - }; - - // cast to Arrow type - // We make a strong assumption here that the casts should be infallible. - // If the cast fails because of incompatible datatypes, then there might - // be a bigger problem with how Arrow schemas are converted to Parquet. - // - // As there is not always a 1:1 mapping between Arrow and Parquet, there - // are datatypes which we must convert explicitly. 
- // These are: - // - date64: we should cast int32 to date32, then date32 to date64. - let array = match target_type { - ArrowType::Date64 => { - // this is cheap as it internally reinterprets the data - let a = arrow::compute::cast(&array, &ArrowType::Date32)?; - arrow::compute::cast(&a, &target_type)? - } - ArrowType::Decimal(p, s) => { - let mut builder = DecimalBuilder::new(array.len(), p, s); - match array.data_type() { - ArrowType::Int32 => { - let values = array.as_any().downcast_ref::().unwrap(); - for maybe_value in values.iter() { - match maybe_value { - Some(value) => builder.append_value(value as i128)?, - None => builder.append_null()?, - } - } - } - ArrowType::Int64 => { - let values = array.as_any().downcast_ref::().unwrap(); - for maybe_value in values.iter() { - match maybe_value { - Some(value) => builder.append_value(value as i128)?, - None => builder.append_null()?, - } - } - } - _ => { - return Err(ArrowError(format!( - "Cannot convert {:?} to decimal", - array.data_type() - ))) - } - } - Arc::new(builder.finish()) as ArrayRef - } - _ => arrow::compute::cast(&array, &target_type)?, - }; - - // save definition and repetition buffers - self.def_levels_buffer = self.record_reader.consume_def_levels()?; - self.rep_levels_buffer = self.record_reader.consume_rep_levels()?; - self.record_reader.reset(); - Ok(array) - } - - fn get_def_levels(&self) -> Option<&[i16]> { - self.def_levels_buffer - .as_ref() - .map(|buf| unsafe { buf.typed_data() }) - } - - fn get_rep_levels(&self) -> Option<&[i16]> { - self.rep_levels_buffer - .as_ref() - .map(|buf| unsafe { buf.typed_data() }) - } -} - -/// Primitive array readers are leaves of array reader tree. They accept page iterator -/// and read them into primitive arrays. -pub struct ComplexObjectArrayReader -where - T: DataType, - C: Converter>, ArrayRef> + 'static, -{ - data_type: ArrowType, - pages: Box, - def_levels_buffer: Option>, - rep_levels_buffer: Option>, - column_desc: ColumnDescPtr, - column_reader: Option>, - converter: C, - _parquet_type_marker: PhantomData, - _converter_marker: PhantomData, -} - -impl ArrayReader for ComplexObjectArrayReader -where - T: DataType, - C: Converter>, ArrayRef> + 'static, -{ - fn as_any(&self) -> &dyn Any { - self - } - - fn get_data_type(&self) -> &ArrowType { - &self.data_type - } - - fn next_batch(&mut self, batch_size: usize) -> Result { - // Try to initialize column reader - if self.column_reader.is_none() { - self.next_column_reader()?; - } - - let mut data_buffer: Vec = Vec::with_capacity(batch_size); - data_buffer.resize_with(batch_size, T::T::default); - - let mut def_levels_buffer = if self.column_desc.max_def_level() > 0 { - let mut buf: Vec = Vec::with_capacity(batch_size); - buf.resize_with(batch_size, || 0); - Some(buf) - } else { - None - }; - - let mut rep_levels_buffer = if self.column_desc.max_rep_level() > 0 { - let mut buf: Vec = Vec::with_capacity(batch_size); - buf.resize_with(batch_size, || 0); - Some(buf) - } else { - None - }; - - let mut num_read = 0; - - while self.column_reader.is_some() && num_read < batch_size { - let num_to_read = batch_size - num_read; - let cur_data_buf = &mut data_buffer[num_read..]; - let cur_def_levels_buf = - def_levels_buffer.as_mut().map(|b| &mut b[num_read..]); - let cur_rep_levels_buf = - rep_levels_buffer.as_mut().map(|b| &mut b[num_read..]); - let (data_read, levels_read) = - self.column_reader.as_mut().unwrap().read_batch( - num_to_read, - cur_def_levels_buf, - cur_rep_levels_buf, - cur_data_buf, - )?; - - // Fill space - if 
levels_read > data_read { - def_levels_buffer.iter().for_each(|def_levels_buffer| { - let (mut level_pos, mut data_pos) = (levels_read, data_read); - while level_pos > 0 && data_pos > 0 { - if def_levels_buffer[num_read + level_pos - 1] - == self.column_desc.max_def_level() - { - cur_data_buf.swap(level_pos - 1, data_pos - 1); - level_pos -= 1; - data_pos -= 1; - } else { - level_pos -= 1; - } - } - }); - } - - let values_read = max(levels_read, data_read); - num_read += values_read; - // current page exhausted && page iterator exhausted - if values_read < num_to_read && !self.next_column_reader()? { - break; - } - } - - data_buffer.truncate(num_read); - def_levels_buffer - .iter_mut() - .for_each(|buf| buf.truncate(num_read)); - rep_levels_buffer - .iter_mut() - .for_each(|buf| buf.truncate(num_read)); - - self.def_levels_buffer = def_levels_buffer; - self.rep_levels_buffer = rep_levels_buffer; - - let data: Vec> = if self.def_levels_buffer.is_some() { - data_buffer - .into_iter() - .zip(self.def_levels_buffer.as_ref().unwrap().iter()) - .map(|(t, def_level)| { - if *def_level == self.column_desc.max_def_level() { - Some(t) - } else { - None - } - }) - .collect() - } else { - data_buffer.into_iter().map(Some).collect() - }; - - let mut array = self.converter.convert(data)?; - - if let ArrowType::Dictionary(_, _) = self.data_type { - array = arrow::compute::cast(&array, &self.data_type)?; - } - - Ok(array) - } - - fn get_def_levels(&self) -> Option<&[i16]> { - self.def_levels_buffer.as_deref() - } - - fn get_rep_levels(&self) -> Option<&[i16]> { - self.rep_levels_buffer.as_deref() - } -} - -impl ComplexObjectArrayReader -where - T: DataType, - C: Converter>, ArrayRef> + 'static, -{ - pub fn new( - pages: Box, - column_desc: ColumnDescPtr, - converter: C, - arrow_type: Option, - ) -> Result { - let data_type = match arrow_type { - Some(t) => t, - None => parquet_to_arrow_field(column_desc.as_ref())? - .data_type() - .clone(), - }; - - Ok(Self { - data_type, - pages, - def_levels_buffer: None, - rep_levels_buffer: None, - column_desc, - column_reader: None, - converter, - _parquet_type_marker: PhantomData, - _converter_marker: PhantomData, - }) - } - - fn next_column_reader(&mut self) -> Result { - Ok(match self.pages.next() { - Some(page) => { - self.column_reader = - Some(ColumnReaderImpl::::new(self.column_desc.clone(), page?)); - true - } - None => false, - }) - } -} - -/// Implementation of list array reader. -pub struct ListArrayReader { - item_reader: Box, - data_type: ArrowType, - item_type: ArrowType, - list_def_level: i16, - list_rep_level: i16, - list_empty_def_level: i16, - list_null_def_level: i16, - def_level_buffer: Option, - rep_level_buffer: Option, - _marker: PhantomData, -} - -impl ListArrayReader { - /// Construct list array reader. - pub fn new( - item_reader: Box, - data_type: ArrowType, - item_type: ArrowType, - def_level: i16, - rep_level: i16, - list_null_def_level: i16, - list_empty_def_level: i16, - ) -> Self { - Self { - item_reader, - data_type, - item_type, - list_def_level: def_level, - list_rep_level: rep_level, - list_null_def_level, - list_empty_def_level, - def_level_buffer: None, - rep_level_buffer: None, - _marker: PhantomData, - } - } -} - -macro_rules! 
remove_primitive_array_indices { - ($arr: expr, $item_type:ty, $indices:expr) => {{ - let array_data = match $arr.as_any().downcast_ref::>() { - Some(a) => a, - _ => return Err(ParquetError::General(format!("Error generating next batch for ListArray: {:?} cannot be downcast to PrimitiveArray", $arr))), - }; - let mut builder = PrimitiveBuilder::<$item_type>::new($arr.len()); - for i in 0..array_data.len() { - if !$indices.contains(&i) { - if array_data.is_null(i) { - builder.append_null()?; - } else { - builder.append_value(array_data.value(i))?; - } - } - } - Ok(Arc::new(builder.finish())) - }}; -} - -macro_rules! remove_array_indices_custom_builder { - ($arr: expr, $array_type:ty, $item_builder:ident, $indices:expr) => {{ - let array_data = match $arr.as_any().downcast_ref::<$array_type>() { - Some(a) => a, - _ => return Err(ParquetError::General(format!("Error generating next batch for ListArray: {:?} cannot be downcast to PrimitiveArray", $arr))), - }; - let mut builder = $item_builder::new(array_data.len()); - - for i in 0..array_data.len() { - if !$indices.contains(&i) { - if array_data.is_null(i) { - builder.append_null()?; - } else { - builder.append_value(array_data.value(i))?; - } - } - } - Ok(Arc::new(builder.finish())) - }}; -} - -macro_rules! remove_fixed_size_binary_array_indices { - ($arr: expr, $array_type:ty, $item_builder:ident, $indices:expr, $len:expr) => {{ - let array_data = match $arr.as_any().downcast_ref::<$array_type>() { - Some(a) => a, - _ => return Err(ParquetError::General(format!("Error generating next batch for ListArray: {:?} cannot be downcast to PrimitiveArray", $arr))), - }; - let mut builder = FixedSizeBinaryBuilder::new(array_data.len(), $len); - for i in 0..array_data.len() { - if !$indices.contains(&i) { - if array_data.is_null(i) { - builder.append_null()?; - } else { - builder.append_value(array_data.value(i))?; - } - } - } - Ok(Arc::new(builder.finish())) - }}; -} - -fn remove_indices( - arr: ArrayRef, - item_type: ArrowType, - indices: Vec, -) -> Result { - match item_type { - ArrowType::UInt8 => remove_primitive_array_indices!(arr, ArrowUInt8Type, indices), - ArrowType::UInt16 => { - remove_primitive_array_indices!(arr, ArrowUInt16Type, indices) - } - ArrowType::UInt32 => { - remove_primitive_array_indices!(arr, ArrowUInt32Type, indices) - } - ArrowType::UInt64 => { - remove_primitive_array_indices!(arr, ArrowUInt64Type, indices) - } - ArrowType::Int8 => remove_primitive_array_indices!(arr, ArrowInt8Type, indices), - ArrowType::Int16 => remove_primitive_array_indices!(arr, ArrowInt16Type, indices), - ArrowType::Int32 => remove_primitive_array_indices!(arr, ArrowInt32Type, indices), - ArrowType::Int64 => remove_primitive_array_indices!(arr, ArrowInt64Type, indices), - ArrowType::Float32 => { - remove_primitive_array_indices!(arr, ArrowFloat32Type, indices) - } - ArrowType::Float64 => { - remove_primitive_array_indices!(arr, ArrowFloat64Type, indices) - } - ArrowType::Boolean => { - remove_array_indices_custom_builder!( - arr, - BooleanArray, - BooleanBuilder, - indices - ) - } - ArrowType::Date32 => { - remove_primitive_array_indices!(arr, ArrowDate32Type, indices) - } - ArrowType::Date64 => { - remove_primitive_array_indices!(arr, ArrowDate64Type, indices) - } - ArrowType::Time32(ArrowTimeUnit::Second) => { - remove_primitive_array_indices!(arr, ArrowTime32SecondType, indices) - } - ArrowType::Time32(ArrowTimeUnit::Millisecond) => { - remove_primitive_array_indices!(arr, ArrowTime32MillisecondType, indices) - } - 
ArrowType::Time64(ArrowTimeUnit::Microsecond) => { - remove_primitive_array_indices!(arr, ArrowTime64MicrosecondType, indices) - } - ArrowType::Time64(ArrowTimeUnit::Nanosecond) => { - remove_primitive_array_indices!(arr, ArrowTime64NanosecondType, indices) - } - ArrowType::Duration(ArrowTimeUnit::Second) => { - remove_primitive_array_indices!(arr, ArrowDurationSecondType, indices) - } - ArrowType::Duration(ArrowTimeUnit::Millisecond) => { - remove_primitive_array_indices!(arr, ArrowDurationMillisecondType, indices) - } - ArrowType::Duration(ArrowTimeUnit::Microsecond) => { - remove_primitive_array_indices!(arr, ArrowDurationMicrosecondType, indices) - } - ArrowType::Duration(ArrowTimeUnit::Nanosecond) => { - remove_primitive_array_indices!(arr, ArrowDurationNanosecondType, indices) - } - ArrowType::Timestamp(ArrowTimeUnit::Second, _) => { - remove_primitive_array_indices!(arr, ArrowTimestampSecondType, indices) - } - ArrowType::Timestamp(ArrowTimeUnit::Millisecond, _) => { - remove_primitive_array_indices!(arr, ArrowTimestampMillisecondType, indices) - } - ArrowType::Timestamp(ArrowTimeUnit::Microsecond, _) => { - remove_primitive_array_indices!(arr, ArrowTimestampMicrosecondType, indices) - } - ArrowType::Timestamp(ArrowTimeUnit::Nanosecond, _) => { - remove_primitive_array_indices!(arr, ArrowTimestampNanosecondType, indices) - } - ArrowType::Utf8 => { - remove_array_indices_custom_builder!(arr, StringArray, StringBuilder, indices) - } - ArrowType::Binary => { - remove_array_indices_custom_builder!(arr, BinaryArray, BinaryBuilder, indices) - } - ArrowType::FixedSizeBinary(size) => remove_fixed_size_binary_array_indices!( - arr, - FixedSizeBinaryArray, - FixedSizeBinaryBuilder, - indices, - size - ), - _ => Err(ParquetError::General(format!( - "ListArray of type List({:?}) is not supported by array_reader", - item_type - ))), - } -} - -/// Implementation of ListArrayReader. Nested lists and lists of structs are not yet supported. -impl ArrayReader for ListArrayReader { - fn as_any(&self) -> &dyn Any { - self - } - - /// Returns data type. - /// This must be a List. - fn get_data_type(&self) -> &ArrowType { - &self.data_type - } - - fn next_batch(&mut self, batch_size: usize) -> Result { - let next_batch_array = self.item_reader.next_batch(batch_size)?; - let item_type = self.item_reader.get_data_type().clone(); - - if next_batch_array.len() == 0 { - return Ok(new_empty_array(&self.data_type)); - } - let def_levels = self - .item_reader - .get_def_levels() - .ok_or_else(|| ArrowError("item_reader def levels are None.".to_string()))?; - let rep_levels = self - .item_reader - .get_rep_levels() - .ok_or_else(|| ArrowError("item_reader rep levels are None.".to_string()))?; - - if !((def_levels.len() == rep_levels.len()) - && (rep_levels.len() == next_batch_array.len())) - { - return Err(ArrowError( - "Expected item_reader def_levels and rep_levels to be same length as batch".to_string(), - )); - } - - // List definitions can be encoded as 4 values: - // - n + 0: the list slot is null - // - n + 1: the list slot is not null, but is empty (i.e. []) - // - n + 2: the list slot is not null, but its child is empty (i.e. [ null ]) - // - n + 3: the list slot is not null, and its child is not empty - // Where n is the max definition level of the list's parent. - // If a Parquet schema's only leaf is the list, then n = 0. 
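// ---------------------------------------------------------------------------
// Illustrative sketch, not code from this change: the comment above describes
// the four definition levels a list slot can take. The function below is a
// plain-Rust restatement of how the next_batch code that follows turns
// (def, rep) pairs into list offsets and a validity mask; the sample levels
// encode [[1, null, 2], null, []] for a single optional leaf, an assumption
// chosen only for the example.
fn offsets_and_validity(def: &[i16], rep: &[i16], empty_def: i16) -> (Vec<i32>, Vec<bool>) {
    let mut offsets = Vec::new();
    let mut validity = Vec::new();
    let mut cur = 0i32;
    for (d, r) in def.iter().zip(rep) {
        if *r == 0 {
            // rep_level == 0 starts a new top-level list slot
            offsets.push(cur);
            // anything below the "empty list" level means the slot itself is null
            validity.push(*d >= empty_def);
        }
        if *d > empty_def {
            cur += 1; // this entry carries a real child slot (possibly a null child)
        }
    }
    offsets.push(cur);
    (offsets, validity)
}

fn main() {
    // def: 3 = child value present, 2 = child null, 1 = empty list, 0 = null list
    let def = [3i16, 2, 3, 0, 1];
    let rep = [0i16, 1, 1, 0, 0];
    let (offsets, validity) = offsets_and_validity(&def, &rep, 1);
    assert_eq!(offsets, vec![0, 3, 3, 3]); // only the first list has children
    assert_eq!(validity, vec![true, false, true]); // the middle list slot is null
}
// ---------------------------------------------------------------------------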
- - // If the list index is at empty definition, the child slot is null - let null_list_indices: Vec = def_levels - .iter() - .enumerate() - .filter_map(|(index, def)| { - if *def <= self.list_empty_def_level { - Some(index) - } else { - None - } - }) - .collect(); - let batch_values = match null_list_indices.len() { - 0 => next_batch_array.clone(), - _ => remove_indices(next_batch_array.clone(), item_type, null_list_indices)?, - }; - - // first item in each list has rep_level = 0, subsequent items have rep_level = 1 - let mut offsets: Vec = Vec::new(); - let mut cur_offset = OffsetSize::zero(); - def_levels.iter().zip(rep_levels).for_each(|(d, r)| { - if *r == 0 || d == &self.list_empty_def_level { - offsets.push(cur_offset); - } - if d > &self.list_empty_def_level { - cur_offset += OffsetSize::one(); - } - }); - offsets.push(cur_offset); - - let num_bytes = bit_util::ceil(offsets.len(), 8); - // TODO: A useful optimization is to use the null count to fill with - // 0 or null, to reduce individual bits set in a loop. - // To favour dense data, set every slot to true, then unset - let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true); - let null_slice = null_buf.as_slice_mut(); - let mut list_index = 0; - for i in 0..rep_levels.len() { - // If the level is lower than empty, then the slot is null. - // When a list is non-nullable, its empty level = null level, - // so this automatically factors that in. - if rep_levels[i] == 0 && def_levels[i] < self.list_empty_def_level { - bit_util::unset_bit(null_slice, list_index); - } - if rep_levels[i] == 0 { - list_index += 1; - } - } - let value_offsets = Buffer::from(&offsets.to_byte_slice()); - - let list_data = ArrayData::builder(self.get_data_type().clone()) - .len(offsets.len() - 1) - .add_buffer(value_offsets) - .add_child_data(batch_values.data().clone()) - .null_bit_buffer(null_buf.into()) - .offset(next_batch_array.offset()); - - let list_data = unsafe { list_data.build_unchecked() }; - - let result_array = GenericListArray::::from(list_data); - Ok(Arc::new(result_array)) - } - - fn get_def_levels(&self) -> Option<&[i16]> { - self.def_level_buffer - .as_ref() - .map(|buf| unsafe { buf.typed_data() }) - } - - fn get_rep_levels(&self) -> Option<&[i16]> { - self.rep_level_buffer - .as_ref() - .map(|buf| unsafe { buf.typed_data() }) - } -} - -/// Implementation of a map array reader. -pub struct MapArrayReader { - key_reader: Box, - value_reader: Box, - data_type: ArrowType, - map_def_level: i16, - map_rep_level: i16, - def_level_buffer: Option, - rep_level_buffer: Option, -} - -impl MapArrayReader { - pub fn new( - key_reader: Box, - value_reader: Box, - data_type: ArrowType, - def_level: i16, - rep_level: i16, - ) -> Self { - Self { - key_reader, - value_reader, - data_type, - map_def_level: rep_level, - map_rep_level: def_level, - def_level_buffer: None, - rep_level_buffer: None, - } - } -} - -impl ArrayReader for MapArrayReader { - fn as_any(&self) -> &dyn Any { - self - } - - fn get_data_type(&self) -> &ArrowType { - &self.data_type - } - - fn next_batch(&mut self, batch_size: usize) -> Result { - let key_array = self.key_reader.next_batch(batch_size)?; - let value_array = self.value_reader.next_batch(batch_size)?; - - // Check that key and value have the same lengths - let key_length = key_array.len(); - if key_length != value_array.len() { - return Err(general_err!( - "Map key and value should have the same lengths." 
- )); - } - - let def_levels = self - .key_reader - .get_def_levels() - .ok_or_else(|| ArrowError("item_reader def levels are None.".to_string()))?; - let rep_levels = self - .key_reader - .get_rep_levels() - .ok_or_else(|| ArrowError("item_reader rep levels are None.".to_string()))?; - - if !((def_levels.len() == rep_levels.len()) && (rep_levels.len() == key_length)) { - return Err(ArrowError( - "Expected item_reader def_levels and rep_levels to be same length as batch".to_string(), - )); - } - - let entry_data_type = if let ArrowType::Map(field, _) = &self.data_type { - field.data_type().clone() - } else { - return Err(ArrowError("Expected a map arrow type".to_string())); - }; - - let entry_data = ArrayDataBuilder::new(entry_data_type) - .len(key_length) - .add_child_data(key_array.data().clone()) - .add_child_data(value_array.data().clone()); - let entry_data = unsafe { entry_data.build_unchecked() }; - - let entry_len = rep_levels.iter().filter(|level| **level == 0).count(); - - // first item in each list has rep_level = 0, subsequent items have rep_level = 1 - let mut offsets: Vec = Vec::new(); - let mut cur_offset = 0; - def_levels.iter().zip(rep_levels).for_each(|(d, r)| { - if *r == 0 || d == &self.map_def_level { - offsets.push(cur_offset); - } - if d > &self.map_def_level { - cur_offset += 1; - } - }); - offsets.push(cur_offset); - - let num_bytes = bit_util::ceil(offsets.len(), 8); - // TODO: A useful optimization is to use the null count to fill with - // 0 or null, to reduce individual bits set in a loop. - // To favour dense data, set every slot to true, then unset - let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true); - let null_slice = null_buf.as_slice_mut(); - let mut list_index = 0; - for i in 0..rep_levels.len() { - // If the level is lower than empty, then the slot is null. - // When a list is non-nullable, its empty level = null level, - // so this automatically factors that in. - if rep_levels[i] == 0 && def_levels[i] < self.map_def_level { - // should be empty list - bit_util::unset_bit(null_slice, list_index); - } - if rep_levels[i] == 0 { - list_index += 1; - } - } - let value_offsets = Buffer::from(&offsets.to_byte_slice()); - - // Now we can build array data - let array_data = ArrayDataBuilder::new(self.data_type.clone()) - .len(entry_len) - .add_buffer(value_offsets) - .null_bit_buffer(null_buf.into()) - .add_child_data(entry_data); - - let array_data = unsafe { array_data.build_unchecked() }; - - Ok(Arc::new(MapArray::from(array_data))) - } - - fn get_def_levels(&self) -> Option<&[i16]> { - self.def_level_buffer - .as_ref() - .map(|buf| unsafe { buf.typed_data() }) - } - - fn get_rep_levels(&self) -> Option<&[i16]> { - self.rep_level_buffer - .as_ref() - .map(|buf| unsafe { buf.typed_data() }) - } -} - -/// Implementation of struct array reader. -pub struct StructArrayReader { - children: Vec>, - data_type: ArrowType, - struct_def_level: i16, - struct_rep_level: i16, - def_level_buffer: Option, - rep_level_buffer: Option, -} - -impl StructArrayReader { - /// Construct struct array reader. - pub fn new( - data_type: ArrowType, - children: Vec>, - def_level: i16, - rep_level: i16, - ) -> Self { - Self { - data_type, - children, - struct_def_level: def_level, - struct_rep_level: rep_level, - def_level_buffer: None, - rep_level_buffer: None, - } - } -} - -impl ArrayReader for StructArrayReader { - fn as_any(&self) -> &dyn Any { - self - } - - /// Returns data type. - /// This must be a struct. 
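// ---------------------------------------------------------------------------
// Illustrative sketch, not code from this change: MapArrayReader::next_batch
// above reads the key and value columns in lockstep, treats every rep_level of
// 0 as the start of a new map, and derives entry offsets the same way the list
// reader does. The plain-Rust walk below shows that bookkeeping for two maps,
// {"a": 1, "b": 2} and {"c": 3}, assuming every entry is present (no null
// handling), an assumption made only for this example.
fn main() {
    let keys = ["a", "b", "c"];
    let vals = [1, 2, 3];
    assert_eq!(keys.len(), vals.len()); // key and value children must line up

    let rep = [0i16, 1, 0];
    let map_count = rep.iter().filter(|r| **r == 0).count();
    assert_eq!(map_count, 2); // two top-level map slots

    let mut offsets = Vec::new();
    let mut cur = 0i32;
    for &r in rep.iter() {
        if r == 0 {
            offsets.push(cur); // a new map begins here
        }
        cur += 1; // with no nulls, every row is a real key/value entry
    }
    offsets.push(cur);
    assert_eq!(offsets, vec![0, 2, 3]); // map 0 -> entries 0..2, map 1 -> entries 2..3
}
// ---------------------------------------------------------------------------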
- fn get_data_type(&self) -> &ArrowType { - &self.data_type - } - - /// Read `batch_size` struct records. - /// - /// Definition levels of struct array is calculated as following: - /// ```ignore - /// def_levels[i] = min(child1_def_levels[i], child2_def_levels[i], ..., - /// childn_def_levels[i]); - /// ``` - /// - /// Repetition levels of struct array is calculated as following: - /// ```ignore - /// rep_levels[i] = child1_rep_levels[i]; - /// ``` - /// - /// The null bitmap of struct array is calculated from def_levels: - /// ```ignore - /// null_bitmap[i] = (def_levels[i] >= self.def_level); - /// ``` - fn next_batch(&mut self, batch_size: usize) -> Result { - if self.children.is_empty() { - self.def_level_buffer = None; - self.rep_level_buffer = None; - return Ok(Arc::new(StructArray::from(Vec::new()))); - } - - let children_array = self - .children - .iter_mut() - .map(|reader| reader.next_batch(batch_size)) - .try_fold( - Vec::new(), - |mut result, child_array| -> Result> { - result.push(child_array?); - Ok(result) - }, - )?; - - // check that array child data has same size - let children_array_len = - children_array.first().map(|arr| arr.len()).ok_or_else(|| { - general_err!("Struct array reader should have at least one child!") - })?; - - let all_children_len_eq = children_array - .iter() - .all(|arr| arr.len() == children_array_len); - if !all_children_len_eq { - return Err(general_err!("Not all children array length are the same!")); - } - - // calculate struct def level data - let buffer_size = children_array_len * size_of::(); - let mut def_level_data_buffer = MutableBuffer::new(buffer_size); - def_level_data_buffer.resize(buffer_size, 0); - - let def_level_data = def_level_data_buffer.typed_data_mut(); - - def_level_data - .iter_mut() - .for_each(|v| *v = self.struct_def_level); - - for child in &self.children { - if let Some(current_child_def_levels) = child.get_def_levels() { - if current_child_def_levels.len() != children_array_len { - return Err(general_err!("Child array length are not equal!")); - } else { - for i in 0..children_array_len { - def_level_data[i] = - min(def_level_data[i], current_child_def_levels[i]); - } - } - } - } - - // calculate bitmap for current array - let mut bitmap_builder = BooleanBufferBuilder::new(children_array_len); - for def_level in def_level_data { - let not_null = *def_level >= self.struct_def_level; - bitmap_builder.append(not_null); - } - - // Now we can build array data - let array_data = ArrayDataBuilder::new(self.data_type.clone()) - .len(children_array_len) - .null_bit_buffer(bitmap_builder.finish()) - .child_data( - children_array - .iter() - .map(|x| x.data().clone()) - .collect::>(), - ); - let array_data = unsafe { array_data.build_unchecked() }; - - // calculate struct rep level data, since struct doesn't add to repetition - // levels, here we just need to keep repetition levels of first array - // TODO: Verify that all children array reader has same repetition levels - let rep_level_data = self - .children - .first() - .ok_or_else(|| { - general_err!("Struct array reader should have at least one child!") - })? 
- .get_rep_levels() - .map(|data| -> Result { - let mut buffer = Int16BufferBuilder::new(children_array_len); - buffer.append_slice(data); - Ok(buffer.finish()) - }) - .transpose()?; - - self.def_level_buffer = Some(def_level_data_buffer.into()); - self.rep_level_buffer = rep_level_data; - Ok(Arc::new(StructArray::from(array_data))) - } - - fn get_def_levels(&self) -> Option<&[i16]> { - self.def_level_buffer - .as_ref() - .map(|buf| unsafe { buf.typed_data() }) - } - - fn get_rep_levels(&self) -> Option<&[i16]> { - self.rep_level_buffer - .as_ref() - .map(|buf| unsafe { buf.typed_data() }) - } -} - -/// Create array reader from parquet schema, column indices, and parquet file reader. -pub fn build_array_reader( - parquet_schema: SchemaDescPtr, - arrow_schema: Schema, - column_indices: T, - file_reader: Arc, -) -> Result> -where - T: IntoIterator, -{ - let mut leaves = HashMap::<*const Type, usize>::new(); - - let mut filtered_root_names = HashSet::::new(); - - for c in column_indices { - let column = parquet_schema.column(c).self_type() as *const Type; - - leaves.insert(column, c); - - let root = parquet_schema.get_column_root_ptr(c); - filtered_root_names.insert(root.name().to_string()); - } - - if leaves.is_empty() { - return Err(general_err!("Can't build array reader without columns!")); - } - - // Only pass root fields that take part in the projection - // to avoid traversal of columns that are not read. - // TODO: also prune unread parts of the tree in child structures - let filtered_root_fields = parquet_schema - .root_schema() - .get_fields() - .iter() - .filter(|field| filtered_root_names.contains(field.name())) - .cloned() - .collect::>(); - - let proj = Type::GroupType { - basic_info: parquet_schema.root_schema().get_basic_info().clone(), - fields: filtered_root_fields, - }; - - ArrayReaderBuilder::new( - Arc::new(proj), - Arc::new(arrow_schema), - Arc::new(leaves), - file_reader, - ) - .build_array_reader() -} - -/// Used to build array reader. -struct ArrayReaderBuilder { - root_schema: TypePtr, - arrow_schema: Arc, - // Key: columns that need to be included in final array builder - // Value: column index in schema - columns_included: Arc>, - file_reader: Arc, -} - -/// Used in type visitor. -#[derive(Clone)] -struct ArrayReaderBuilderContext { - def_level: i16, - rep_level: i16, - path: ColumnPath, -} - -impl Default for ArrayReaderBuilderContext { - fn default() -> Self { - Self { - def_level: 0i16, - rep_level: 0i16, - path: ColumnPath::new(Vec::new()), - } - } -} - -/// Create array reader by visiting schema. -impl<'a> TypeVisitor>, &'a ArrayReaderBuilderContext> - for ArrayReaderBuilder -{ - /// Build array reader for primitive type. - fn visit_primitive( - &mut self, - cur_type: TypePtr, - context: &'a ArrayReaderBuilderContext, - ) -> Result>> { - if self.is_included(cur_type.as_ref()) { - let mut new_context = context.clone(); - new_context.path.append(vec![cur_type.name().to_string()]); - - match cur_type.get_basic_info().repetition() { - Repetition::REPEATED => { - new_context.def_level += 1; - new_context.rep_level += 1; - } - Repetition::OPTIONAL => { - new_context.def_level += 1; - } - _ => (), - } - - let reader = - self.build_for_primitive_type_inner(cur_type.clone(), &new_context)?; - - if cur_type.get_basic_info().repetition() == Repetition::REPEATED { - Err(ArrowError( - "Reading repeated field is not supported yet!".to_string(), - )) - } else { - Ok(Some(reader)) - } - } else { - Ok(None) - } - } - - /// Build array reader for struct type. 
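// ---------------------------------------------------------------------------
// Illustrative sketch, not code from this change: StructArrayReader::next_batch
// above starts every row at the struct's own definition level, clamps it with
// each child's level, and marks a row valid when the result still reaches the
// struct's level. The inputs below are the same ones used by
// test_struct_array_reader later in this file, so the expected outputs match
// that test's assertions.
fn main() {
    let child1_def = [0i16, 1, 2, 3, 1];
    let child2_def = [0i16, 1, 3, 1, 2];
    let struct_def_level = 1i16;

    // def_level[i] = min(struct_def_level, child1[i], child2[i])
    let struct_def: Vec<i16> = child1_def
        .iter()
        .zip(&child2_def)
        .map(|(a, b)| struct_def_level.min(*a).min(*b))
        .collect();
    assert_eq!(struct_def, vec![0, 1, 1, 1, 1]);

    // null_bitmap[i] = (def_level[i] >= struct_def_level)
    let validity: Vec<bool> = struct_def
        .iter()
        .map(|d| *d >= struct_def_level)
        .collect();
    assert_eq!(validity, vec![false, true, true, true, true]); // only row 0 is null
}
// ---------------------------------------------------------------------------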
- fn visit_struct( - &mut self, - cur_type: Arc, - context: &'a ArrayReaderBuilderContext, - ) -> Result>> { - let mut new_context = context.clone(); - new_context.path.append(vec![cur_type.name().to_string()]); - - if cur_type.get_basic_info().has_repetition() { - match cur_type.get_basic_info().repetition() { - Repetition::REPEATED => { - new_context.def_level += 1; - new_context.rep_level += 1; - } - Repetition::OPTIONAL => { - new_context.def_level += 1; - } - _ => (), - } - } - - if let Some(reader) = self.build_for_struct_type_inner(&cur_type, &new_context)? { - if cur_type.get_basic_info().has_repetition() - && cur_type.get_basic_info().repetition() == Repetition::REPEATED - { - Err(ArrowError( - "Reading repeated field is not supported yet!".to_string(), - )) - } else { - Ok(Some(reader)) - } - } else { - Ok(None) - } - } - - /// Build array reader for map type. - fn visit_map( - &mut self, - map_type: Arc, - context: &'a ArrayReaderBuilderContext, - ) -> Result>> { - // Add map type to context - let mut new_context = context.clone(); - new_context.path.append(vec![map_type.name().to_string()]); - if let Repetition::OPTIONAL = map_type.get_basic_info().repetition() { - new_context.def_level += 1; - } - - // Add map entry (key_value) to context - let map_key_value = map_type.get_fields().first().ok_or_else(|| { - ArrowError("Map field must have a key_value entry".to_string()) - })?; - new_context - .path - .append(vec![map_key_value.name().to_string()]); - new_context.rep_level += 1; - - // Get key and value, and create context for each - let map_key = map_key_value - .get_fields() - .first() - .ok_or_else(|| ArrowError("Map entry must have a key".to_string()))?; - let map_value = map_key_value - .get_fields() - .get(1) - .ok_or_else(|| ArrowError("Map entry must have a value".to_string()))?; - - let key_reader = { - let mut key_context = new_context.clone(); - key_context.def_level += 1; - key_context.path.append(vec![map_key.name().to_string()]); - self.dispatch(map_key.clone(), &key_context)?.unwrap() - }; - let value_reader = { - let mut value_context = new_context.clone(); - if let Repetition::OPTIONAL = map_value.get_basic_info().repetition() { - value_context.def_level += 1; - } - self.dispatch(map_value.clone(), &value_context)?.unwrap() - }; - - let arrow_type = self - .arrow_schema - .field_with_name(map_type.name()) - .ok() - .map(|f| f.data_type().to_owned()) - .unwrap_or_else(|| { - ArrowType::Map( - Box::new(Field::new( - map_key_value.name(), - ArrowType::Struct(vec![ - Field::new( - map_key.name(), - key_reader.get_data_type().clone(), - false, - ), - Field::new( - map_value.name(), - value_reader.get_data_type().clone(), - map_value.is_optional(), - ), - ]), - map_type.is_optional(), - )), - false, - ) - }); - - let key_array_reader: Box = Box::new(MapArrayReader::new( - key_reader, - value_reader, - arrow_type, - new_context.def_level, - new_context.rep_level, - )); - - Ok(Some(key_array_reader)) - } - - /// Build array reader for list type. - fn visit_list_with_item( - &mut self, - list_type: Arc, - item_type: Arc, - context: &'a ArrayReaderBuilderContext, - ) -> Result>> { - let mut list_child = &list_type - .get_fields() - .first() - .ok_or_else(|| ArrowError("List field must have a child.".to_string()))? 
- .clone(); - let mut new_context = context.clone(); - - new_context.path.append(vec![list_type.name().to_string()]); - // We need to know at what definition a list or its child is null - let list_null_def = new_context.def_level; - let mut list_empty_def = new_context.def_level; - - // If the list's root is nullable - if let Repetition::OPTIONAL = list_type.get_basic_info().repetition() { - new_context.def_level += 1; - // current level is nullable, increment to get level for empty list slot - list_empty_def += 1; - } - - match list_child.get_basic_info().repetition() { - Repetition::REPEATED => { - new_context.def_level += 1; - new_context.rep_level += 1; - } - Repetition::OPTIONAL => { - new_context.def_level += 1; - } - _ => (), - } - - let item_reader = self - .dispatch(item_type.clone(), &new_context) - .unwrap() - .unwrap(); - - let item_reader_type = item_reader.get_data_type().clone(); - - match item_reader_type { - ArrowType::List(_) - | ArrowType::FixedSizeList(_, _) - | ArrowType::Struct(_) - | ArrowType::Dictionary(_, _) => Err(ArrowError(format!( - "reading List({:?}) into arrow not supported yet", - item_type - ))), - _ => { - // a list is a group type with a single child. The list child's - // name comes from the child's field name. - // if the child's name is "list" and it has a child, then use this child - if list_child.name() == "list" && !list_child.get_fields().is_empty() { - list_child = list_child.get_fields().first().unwrap(); - } - let arrow_type = self - .arrow_schema - .field_with_name(list_type.name()) - .ok() - .map(|f| f.data_type().to_owned()) - .unwrap_or_else(|| { - ArrowType::List(Box::new(Field::new( - list_child.name(), - item_reader_type.clone(), - list_child.is_optional(), - ))) - }); - - let list_array_reader: Box = match arrow_type { - ArrowType::List(_) => Box::new(ListArrayReader::::new( - item_reader, - arrow_type, - item_reader_type, - new_context.def_level, - new_context.rep_level, - list_null_def, - list_empty_def, - )), - ArrowType::LargeList(_) => Box::new(ListArrayReader::::new( - item_reader, - arrow_type, - item_reader_type, - new_context.def_level, - new_context.rep_level, - list_null_def, - list_empty_def, - )), - - _ => { - return Err(ArrowError(format!( - "creating ListArrayReader with type {:?} should be unreachable", - arrow_type - ))) - } - }; - - Ok(Some(list_array_reader)) - } - } - } -} - -impl<'a> ArrayReaderBuilder { - /// Construct array reader builder. - fn new( - root_schema: TypePtr, - arrow_schema: Arc, - columns_included: Arc>, - file_reader: Arc, - ) -> Self { - Self { - root_schema, - arrow_schema, - columns_included, - file_reader, - } - } - - /// Main entry point. - fn build_array_reader(&mut self) -> Result> { - let context = ArrayReaderBuilderContext::default(); - - self.visit_struct(self.root_schema.clone(), &context) - .and_then(|reader_opt| { - reader_opt.ok_or_else(|| general_err!("Failed to build array reader!")) - }) - } - - // Utility functions - - /// Check whether one column in included in this array reader builder. - fn is_included(&self, t: &Type) -> bool { - self.columns_included.contains_key(&(t as *const Type)) - } - - /// Creates primitive array reader for each primitive type. 
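// ---------------------------------------------------------------------------
// Illustrative sketch, not code from this change: visit_list_with_item above
// records the definition level at which a list slot is null versus merely
// empty. For the common three-level Parquet layout
//
//   optional group my_list (LIST) { repeated group list { optional int32 element; } }
//
// the walk below restates, by hand, the level increments the visitor performs;
// it is not a call into this crate.
fn main() {
    let mut def = 0i16;
    let mut rep = 0i16;

    let list_null_def = def; // a null list is encoded at the enclosing level
    let mut list_empty_def = def; // moves up once we enter the optional group
    def += 1; // the list group itself is OPTIONAL
    list_empty_def += 1; // a present-but-empty list

    def += 1; // the repeated "list" wrapper group
    rep += 1;

    def += 1; // the OPTIONAL element leaf

    assert_eq!((list_null_def, list_empty_def), (0, 1));
    assert_eq!((def, rep), (3, 1)); // max levels seen by the item reader
}
// ---------------------------------------------------------------------------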
- fn build_for_primitive_type_inner( - &self, - cur_type: TypePtr, - context: &'a ArrayReaderBuilderContext, - ) -> Result> { - let column_desc = Arc::new(ColumnDescriptor::new( - cur_type.clone(), - context.def_level, - context.rep_level, - context.path.clone(), - )); - let page_iterator = Box::new(FilePageIterator::new( - self.columns_included[&(cur_type.as_ref() as *const Type)], - self.file_reader.clone(), - )?); - - let arrow_type: Option = self - .get_arrow_field(&cur_type, context) - .map(|f| f.data_type().clone()); - - match cur_type.get_physical_type() { - PhysicalType::BOOLEAN => Ok(Box::new(PrimitiveArrayReader::::new( - page_iterator, - column_desc, - arrow_type, - )?)), - PhysicalType::INT32 => { - if let Some(ArrowType::Null) = arrow_type { - Ok(Box::new(NullArrayReader::::new( - page_iterator, - column_desc, - )?)) - } else { - Ok(Box::new(PrimitiveArrayReader::::new( - page_iterator, - column_desc, - arrow_type, - )?)) - } - } - PhysicalType::INT64 => Ok(Box::new(PrimitiveArrayReader::::new( - page_iterator, - column_desc, - arrow_type, - )?)), - PhysicalType::INT96 => { - // get the optional timezone information from arrow type - let timezone = arrow_type - .as_ref() - .map(|data_type| { - if let ArrowType::Timestamp(_, tz) = data_type { - tz.clone() - } else { - None - } - }) - .flatten(); - let converter = Int96Converter::new(Int96ArrayConverter { timezone }); - Ok(Box::new(ComplexObjectArrayReader::< - Int96Type, - Int96Converter, - >::new( - page_iterator, - column_desc, - converter, - arrow_type, - )?)) - } - PhysicalType::FLOAT => Ok(Box::new(PrimitiveArrayReader::::new( - page_iterator, - column_desc, - arrow_type, - )?)), - PhysicalType::DOUBLE => { - Ok(Box::new(PrimitiveArrayReader::::new( - page_iterator, - column_desc, - arrow_type, - )?)) - } - PhysicalType::BYTE_ARRAY => { - if cur_type.get_basic_info().converted_type() == ConvertedType::UTF8 { - if let Some(ArrowType::LargeUtf8) = arrow_type { - let converter = - LargeUtf8Converter::new(LargeUtf8ArrayConverter {}); - Ok(Box::new(ComplexObjectArrayReader::< - ByteArrayType, - LargeUtf8Converter, - >::new( - page_iterator, - column_desc, - converter, - arrow_type, - )?)) - } else { - use crate::arrow::arrow_array_reader::{ - ArrowArrayReader, StringArrayConverter, - }; - let converter = StringArrayConverter::new(); - Ok(Box::new(ArrowArrayReader::try_new( - *page_iterator, - column_desc, - converter, - arrow_type, - )?)) - } - } else if let Some(ArrowType::LargeBinary) = arrow_type { - let converter = - LargeBinaryConverter::new(LargeBinaryArrayConverter {}); - Ok(Box::new(ComplexObjectArrayReader::< - ByteArrayType, - LargeBinaryConverter, - >::new( - page_iterator, - column_desc, - converter, - arrow_type, - )?)) - } else { - let converter = BinaryConverter::new(BinaryArrayConverter {}); - Ok(Box::new(ComplexObjectArrayReader::< - ByteArrayType, - BinaryConverter, - >::new( - page_iterator, - column_desc, - converter, - arrow_type, - )?)) - } - } - PhysicalType::FIXED_LEN_BYTE_ARRAY - if cur_type.get_basic_info().converted_type() - == ConvertedType::DECIMAL => - { - let converter = DecimalConverter::new(DecimalArrayConverter::new( - cur_type.get_precision(), - cur_type.get_scale(), - )); - Ok(Box::new(ComplexObjectArrayReader::< - FixedLenByteArrayType, - DecimalConverter, - >::new( - page_iterator, - column_desc, - converter, - arrow_type, - )?)) - } - PhysicalType::FIXED_LEN_BYTE_ARRAY => { - let byte_width = match *cur_type { - Type::PrimitiveType { - ref type_length, .. 
- } => *type_length, - _ => { - return Err(ArrowError( - "Expected a physical type, not a group type".to_string(), - )) - } - }; - if cur_type.get_basic_info().converted_type() == ConvertedType::INTERVAL { - if byte_width != 12 { - return Err(ArrowError(format!( - "Parquet interval type should have length of 12, found {}", - byte_width - ))); - } - match arrow_type { - Some(ArrowType::Interval(IntervalUnit::DayTime)) => { - let converter = IntervalDayTimeConverter::new( - IntervalDayTimeArrayConverter {}, - ); - Ok(Box::new(ComplexObjectArrayReader::< - FixedLenByteArrayType, - IntervalDayTimeConverter, - >::new( - page_iterator, - column_desc, - converter, - arrow_type, - )?)) - } - Some(ArrowType::Interval(IntervalUnit::YearMonth)) => { - let converter = IntervalYearMonthConverter::new( - IntervalYearMonthArrayConverter {}, - ); - Ok(Box::new(ComplexObjectArrayReader::< - FixedLenByteArrayType, - IntervalYearMonthConverter, - >::new( - page_iterator, - column_desc, - converter, - arrow_type, - )?)) - } - Some(t) => Err(ArrowError(format!( - "Cannot write a Parquet interval to {:?}", - t - ))), - None => { - // we do not support an interval not matched to an Arrow type, - // because we risk data loss as we won't know which of the 12 bytes - // are or should be populated - Err(ArrowError( - "Cannot write a Parquet interval with no Arrow type specified. - There is a risk of data loss as Arrow either supports YearMonth or - DayTime precision. Without the Arrow type, we cannot infer the type. - ".to_string() - )) - } - } - } else { - let converter = FixedLenBinaryConverter::new( - FixedSizeArrayConverter::new(byte_width), - ); - Ok(Box::new(ComplexObjectArrayReader::< - FixedLenByteArrayType, - FixedLenBinaryConverter, - >::new( - page_iterator, - column_desc, - converter, - arrow_type, - )?)) - } - } - } - } - - /// Constructs struct array reader without considering repetition. - fn build_for_struct_type_inner( - &mut self, - cur_type: &Type, - context: &'a ArrayReaderBuilderContext, - ) -> Result>> { - let mut fields = Vec::with_capacity(cur_type.get_fields().len()); - let mut children_reader = Vec::with_capacity(cur_type.get_fields().len()); - - for child in cur_type.get_fields() { - let mut struct_context = context.clone(); - if let Some(child_reader) = self.dispatch(child.clone(), context)? { - // TODO: this results in calling get_arrow_field twice, it could be reused - // from child_reader above, by making child_reader carry its `Field` - struct_context.path.append(vec![child.name().to_string()]); - let field = match self.get_arrow_field(child, &struct_context) { - Some(f) => f.clone(), - _ => Field::new( - child.name(), - child_reader.get_data_type().clone(), - child.is_optional(), - ), - }; - fields.push(field); - children_reader.push(child_reader); - } - } - - if !fields.is_empty() { - let arrow_type = ArrowType::Struct(fields); - Ok(Some(Box::new(StructArrayReader::new( - arrow_type, - children_reader, - context.def_level, - context.rep_level, - )))) - } else { - Ok(None) - } - } - - fn get_arrow_field( - &self, - cur_type: &Type, - context: &'a ArrayReaderBuilderContext, - ) -> Option<&Field> { - let parts: Vec<&str> = context - .path - .parts() - .iter() - .map(|x| -> &str { x }) - .collect::>(); - - // If the parts length is one it'll have the top level "schema" type. If - // it's two then it'll be a top-level type that we can get from the arrow - // schema directly. 
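// ---------------------------------------------------------------------------
// Illustrative sketch, not code from this change: the comment above explains
// that for deeper columns the arrow Field is found by walking the column path
// one segment at a time. The toy `Node` tree below is invented for the example
// (the real code descends through arrow `Field`s inside struct data types);
// the names mirror the nulls.snappy.parquet schema used by the tests.
struct Node {
    name: &'static str,
    children: Vec<Node>,
}

fn find_by_path<'a>(roots: &'a [Node], parts: &[&str]) -> Option<&'a Node> {
    let mut current: Option<&'a Node> = None;
    for (i, part) in parts.iter().enumerate() {
        let candidates: &'a [Node] = match (i, current) {
            (0, _) => roots,                             // first segment: top-level fields
            (_, Some(node)) => node.children.as_slice(), // otherwise: children of the match
            (_, None) => return None,                    // lost the trail earlier
        };
        current = candidates.iter().find(|n| n.name == *part);
    }
    current
}

fn main() {
    let schema = vec![Node {
        name: "b_struct",
        children: vec![Node { name: "b_c_int", children: Vec::new() }],
    }];
    let found = find_by_path(&schema, &["b_struct", "b_c_int"]);
    assert_eq!(found.map(|n| n.name), Some("b_c_int"));
}
// ---------------------------------------------------------------------------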
- if parts.len() <= 2 { - self.arrow_schema.field_with_name(cur_type.name()).ok() - } else { - // If it's greater than two then we need to traverse the type path - // until we find the actual field we're looking for. - let mut field: Option<&Field> = None; - - for (i, part) in parts.iter().enumerate().skip(1) { - if i == 1 { - field = self.arrow_schema.field_with_name(part).ok(); - } else if let Some(f) = field { - if let ArrowType::Struct(fields) = f.data_type() { - field = fields.iter().find(|f| f.name() == part) - } else { - field = None - } - } else { - field = None - } - } - field - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::arrow::converter::{Utf8ArrayConverter, Utf8Converter}; - use crate::arrow::schema::parquet_to_arrow_schema; - use crate::basic::{Encoding, Type as PhysicalType}; - use crate::column::page::{Page, PageReader}; - use crate::data_type::{ByteArray, DataType, Int32Type, Int64Type}; - use crate::errors::Result; - use crate::file::reader::{FileReader, SerializedFileReader}; - use crate::schema::parser::parse_message_type; - use crate::schema::types::{ColumnDescPtr, SchemaDescriptor}; - use crate::util::test_common::page_util::{ - DataPageBuilder, DataPageBuilderImpl, InMemoryPageIterator, - }; - use crate::util::test_common::{get_test_file, make_pages}; - use arrow::array::{ - Array, ArrayRef, LargeListArray, ListArray, PrimitiveArray, StringArray, - StructArray, - }; - use arrow::datatypes::{ - ArrowPrimitiveType, DataType as ArrowType, Date32Type as ArrowDate32, Field, - Int32Type as ArrowInt32, Int64Type as ArrowInt64, - Time32MillisecondType as ArrowTime32MillisecondArray, - Time64MicrosecondType as ArrowTime64MicrosecondArray, - TimestampMicrosecondType as ArrowTimestampMicrosecondType, - TimestampMillisecondType as ArrowTimestampMillisecondType, - }; - use rand::distributions::uniform::SampleUniform; - use rand::{thread_rng, Rng}; - use std::any::Any; - use std::collections::VecDeque; - use std::sync::Arc; - - fn make_column_chunks( - column_desc: ColumnDescPtr, - encoding: Encoding, - num_levels: usize, - min_value: T::T, - max_value: T::T, - def_levels: &mut Vec, - rep_levels: &mut Vec, - values: &mut Vec, - page_lists: &mut Vec>, - use_v2: bool, - num_chunks: usize, - ) where - T::T: PartialOrd + SampleUniform + Copy, - { - for _i in 0..num_chunks { - let mut pages = VecDeque::new(); - let mut data = Vec::new(); - let mut page_def_levels = Vec::new(); - let mut page_rep_levels = Vec::new(); - - make_pages::( - column_desc.clone(), - encoding, - 1, - num_levels, - min_value, - max_value, - &mut page_def_levels, - &mut page_rep_levels, - &mut data, - &mut pages, - use_v2, - ); - - def_levels.append(&mut page_def_levels); - rep_levels.append(&mut page_rep_levels); - values.append(&mut data); - page_lists.push(Vec::from(pages)); - } - } - - #[test] - fn test_primitive_array_reader_empty_pages() { - // Construct column schema - let message_type = " - message test_schema { - REQUIRED INT32 leaf; - } - "; - - let schema = parse_message_type(message_type) - .map(|t| Arc::new(SchemaDescriptor::new(Arc::new(t)))) - .unwrap(); - - let column_desc = schema.column(0); - let page_iterator = EmptyPageIterator::new(schema); - - let mut array_reader = PrimitiveArrayReader::::new( - Box::new(page_iterator), - column_desc, - None, - ) - .unwrap(); - - // expect no values to be read - let array = array_reader.next_batch(50).unwrap(); - assert!(array.is_empty()); - } - - #[test] - fn test_primitive_array_reader_data() { - // Construct column schema - let 
message_type = " - message test_schema { - REQUIRED INT32 leaf; - } - "; - - let schema = parse_message_type(message_type) - .map(|t| Arc::new(SchemaDescriptor::new(Arc::new(t)))) - .unwrap(); - - let column_desc = schema.column(0); - - // Construct page iterator - { - let mut data = Vec::new(); - let mut page_lists = Vec::new(); - make_column_chunks::( - column_desc.clone(), - Encoding::PLAIN, - 100, - 1, - 200, - &mut Vec::new(), - &mut Vec::new(), - &mut data, - &mut page_lists, - true, - 2, - ); - let page_iterator = - InMemoryPageIterator::new(schema, column_desc.clone(), page_lists); - - let mut array_reader = PrimitiveArrayReader::::new( - Box::new(page_iterator), - column_desc, - None, - ) - .unwrap(); - - // Read first 50 values, which are all from the first column chunk - let array = array_reader.next_batch(50).unwrap(); - let array = array - .as_any() - .downcast_ref::>() - .unwrap(); - - assert_eq!( - &PrimitiveArray::::from(data[0..50].to_vec()), - array - ); - - // Read next 100 values, the first 50 ones are from the first column chunk, - // and the last 50 ones are from the second column chunk - let array = array_reader.next_batch(100).unwrap(); - let array = array - .as_any() - .downcast_ref::>() - .unwrap(); - - assert_eq!( - &PrimitiveArray::::from(data[50..150].to_vec()), - array - ); - - // Try to read 100 values, however there are only 50 values - let array = array_reader.next_batch(100).unwrap(); - let array = array - .as_any() - .downcast_ref::>() - .unwrap(); - - assert_eq!( - &PrimitiveArray::::from(data[150..200].to_vec()), - array - ); - } - } - - macro_rules! test_primitive_array_reader_one_type { - ($arrow_parquet_type:ty, $physical_type:expr, $converted_type_str:expr, $result_arrow_type:ty, $result_arrow_cast_type:ty, $result_primitive_type:ty) => {{ - let message_type = format!( - " - message test_schema {{ - REQUIRED {:?} leaf ({}); - }} - ", - $physical_type, $converted_type_str - ); - let schema = parse_message_type(&message_type) - .map(|t| Arc::new(SchemaDescriptor::new(Arc::new(t)))) - .unwrap(); - - let column_desc = schema.column(0); - - // Construct page iterator - { - let mut data = Vec::new(); - let mut page_lists = Vec::new(); - make_column_chunks::<$arrow_parquet_type>( - column_desc.clone(), - Encoding::PLAIN, - 100, - 1, - 200, - &mut Vec::new(), - &mut Vec::new(), - &mut data, - &mut page_lists, - true, - 2, - ); - let page_iterator = InMemoryPageIterator::new( - schema.clone(), - column_desc.clone(), - page_lists, - ); - let mut array_reader = PrimitiveArrayReader::<$arrow_parquet_type>::new( - Box::new(page_iterator), - column_desc.clone(), - None, - ) - .expect("Unable to get array reader"); - - let array = array_reader - .next_batch(50) - .expect("Unable to get batch from reader"); - - let result_data_type = <$result_arrow_type>::DATA_TYPE; - let array = array - .as_any() - .downcast_ref::>() - .expect( - format!( - "Unable to downcast {:?} to {:?}", - array.data_type(), - result_data_type - ) - .as_str(), - ); - - // create expected array as primitive, and cast to result type - let expected = PrimitiveArray::<$result_arrow_cast_type>::from( - data[0..50] - .iter() - .map(|x| *x as $result_primitive_type) - .collect::>(), - ); - let expected = Arc::new(expected) as ArrayRef; - let expected = arrow::compute::cast(&expected, &result_data_type) - .expect("Unable to cast expected array"); - assert_eq!(expected.data_type(), &result_data_type); - let expected = expected - .as_any() - .downcast_ref::>() - .expect( - format!( - "Unable to downcast 
expected {:?} to {:?}", - expected.data_type(), - result_data_type - ) - .as_str(), - ); - assert_eq!(expected, array); - } - }}; - } - - #[test] - fn test_primitive_array_reader_temporal_types() { - test_primitive_array_reader_one_type!( - Int32Type, - PhysicalType::INT32, - "DATE", - ArrowDate32, - ArrowInt32, - i32 - ); - test_primitive_array_reader_one_type!( - Int32Type, - PhysicalType::INT32, - "TIME_MILLIS", - ArrowTime32MillisecondArray, - ArrowInt32, - i32 - ); - test_primitive_array_reader_one_type!( - Int64Type, - PhysicalType::INT64, - "TIME_MICROS", - ArrowTime64MicrosecondArray, - ArrowInt64, - i64 - ); - test_primitive_array_reader_one_type!( - Int64Type, - PhysicalType::INT64, - "TIMESTAMP_MILLIS", - ArrowTimestampMillisecondType, - ArrowInt64, - i64 - ); - test_primitive_array_reader_one_type!( - Int64Type, - PhysicalType::INT64, - "TIMESTAMP_MICROS", - ArrowTimestampMicrosecondType, - ArrowInt64, - i64 - ); - } - - #[test] - fn test_primitive_array_reader_def_and_rep_levels() { - // Construct column schema - let message_type = " - message test_schema { - REPEATED Group test_mid { - OPTIONAL INT32 leaf; - } - } - "; - - let schema = parse_message_type(message_type) - .map(|t| Arc::new(SchemaDescriptor::new(Arc::new(t)))) - .unwrap(); - - let column_desc = schema.column(0); - - // Construct page iterator - { - let mut def_levels = Vec::new(); - let mut rep_levels = Vec::new(); - let mut page_lists = Vec::new(); - make_column_chunks::( - column_desc.clone(), - Encoding::PLAIN, - 100, - 1, - 200, - &mut def_levels, - &mut rep_levels, - &mut Vec::new(), - &mut page_lists, - true, - 2, - ); - - let page_iterator = - InMemoryPageIterator::new(schema, column_desc.clone(), page_lists); - - let mut array_reader = PrimitiveArrayReader::::new( - Box::new(page_iterator), - column_desc, - None, - ) - .unwrap(); - - let mut accu_len: usize = 0; - - // Read first 50 values, which are all from the first column chunk - let array = array_reader.next_batch(50).unwrap(); - assert_eq!( - Some(&def_levels[accu_len..(accu_len + array.len())]), - array_reader.get_def_levels() - ); - assert_eq!( - Some(&rep_levels[accu_len..(accu_len + array.len())]), - array_reader.get_rep_levels() - ); - accu_len += array.len(); - - // Read next 100 values, the first 50 ones are from the first column chunk, - // and the last 50 ones are from the second column chunk - let array = array_reader.next_batch(100).unwrap(); - assert_eq!( - Some(&def_levels[accu_len..(accu_len + array.len())]), - array_reader.get_def_levels() - ); - assert_eq!( - Some(&rep_levels[accu_len..(accu_len + array.len())]), - array_reader.get_rep_levels() - ); - accu_len += array.len(); - - // Try to read 100 values, however there are only 50 values - let array = array_reader.next_batch(100).unwrap(); - assert_eq!( - Some(&def_levels[accu_len..(accu_len + array.len())]), - array_reader.get_def_levels() - ); - assert_eq!( - Some(&rep_levels[accu_len..(accu_len + array.len())]), - array_reader.get_rep_levels() - ); - } - } - - #[test] - fn test_complex_array_reader_no_pages() { - let message_type = " - message test_schema { - REPEATED Group test_mid { - OPTIONAL BYTE_ARRAY leaf (UTF8); - } - } - "; - let schema = parse_message_type(message_type) - .map(|t| Arc::new(SchemaDescriptor::new(Arc::new(t)))) - .unwrap(); - let column_desc = schema.column(0); - let pages: Vec> = Vec::new(); - let page_iterator = InMemoryPageIterator::new(schema, column_desc.clone(), pages); - - let converter = Utf8Converter::new(Utf8ArrayConverter {}); - let mut 
array_reader = - ComplexObjectArrayReader::::new( - Box::new(page_iterator), - column_desc, - converter, - None, - ) - .unwrap(); - - let values_per_page = 100; // this value is arbitrary in this test - the result should always be an array of 0 length - let array = array_reader.next_batch(values_per_page).unwrap(); - assert_eq!(array.len(), 0); - } - - #[test] - fn test_complex_array_reader_def_and_rep_levels() { - // Construct column schema - let message_type = " - message test_schema { - REPEATED Group test_mid { - OPTIONAL BYTE_ARRAY leaf (UTF8); - } - } - "; - let num_pages = 2; - let values_per_page = 100; - let str_base = "Hello World"; - - let schema = parse_message_type(message_type) - .map(|t| Arc::new(SchemaDescriptor::new(Arc::new(t)))) - .unwrap(); - - let max_def_level = schema.column(0).max_def_level(); - let max_rep_level = schema.column(0).max_rep_level(); - - assert_eq!(max_def_level, 2); - assert_eq!(max_rep_level, 1); - - let mut rng = thread_rng(); - let column_desc = schema.column(0); - let mut pages: Vec> = Vec::new(); - - let mut rep_levels = Vec::with_capacity(num_pages * values_per_page); - let mut def_levels = Vec::with_capacity(num_pages * values_per_page); - let mut all_values = Vec::with_capacity(num_pages * values_per_page); - - for i in 0..num_pages { - let mut values = Vec::with_capacity(values_per_page); - - for _ in 0..values_per_page { - let def_level = rng.gen_range(0..max_def_level + 1); - let rep_level = rng.gen_range(0..max_rep_level + 1); - if def_level == max_def_level { - let len = rng.gen_range(1..str_base.len()); - let slice = &str_base[..len]; - values.push(ByteArray::from(slice)); - all_values.push(Some(slice.to_string())); - } else { - all_values.push(None) - } - rep_levels.push(rep_level); - def_levels.push(def_level) - } - - let range = i * values_per_page..(i + 1) * values_per_page; - let mut pb = - DataPageBuilderImpl::new(column_desc.clone(), values.len() as u32, true); - - pb.add_rep_levels(max_rep_level, &rep_levels.as_slice()[range.clone()]); - pb.add_def_levels(max_def_level, &def_levels.as_slice()[range]); - pb.add_values::(Encoding::PLAIN, values.as_slice()); - - let data_page = pb.consume(); - pages.push(vec![data_page]); - } - - let page_iterator = InMemoryPageIterator::new(schema, column_desc.clone(), pages); - - let converter = Utf8Converter::new(Utf8ArrayConverter {}); - let mut array_reader = - ComplexObjectArrayReader::::new( - Box::new(page_iterator), - column_desc, - converter, - None, - ) - .unwrap(); - - let mut accu_len: usize = 0; - - let array = array_reader.next_batch(values_per_page / 2).unwrap(); - assert_eq!(array.len(), values_per_page / 2); - assert_eq!( - Some(&def_levels[accu_len..(accu_len + array.len())]), - array_reader.get_def_levels() - ); - assert_eq!( - Some(&rep_levels[accu_len..(accu_len + array.len())]), - array_reader.get_rep_levels() - ); - accu_len += array.len(); - - // Read next values_per_page values, the first values_per_page/2 ones are from the first column chunk, - // and the last values_per_page/2 ones are from the second column chunk - let array = array_reader.next_batch(values_per_page).unwrap(); - assert_eq!(array.len(), values_per_page); - assert_eq!( - Some(&def_levels[accu_len..(accu_len + array.len())]), - array_reader.get_def_levels() - ); - assert_eq!( - Some(&rep_levels[accu_len..(accu_len + array.len())]), - array_reader.get_rep_levels() - ); - let strings = array.as_any().downcast_ref::().unwrap(); - for i in 0..array.len() { - if array.is_valid(i) { - assert_eq!( - all_values[i 
+ accu_len].as_ref().unwrap().as_str(), - strings.value(i) - ) - } else { - assert_eq!(all_values[i + accu_len], None) - } - } - accu_len += array.len(); - - // Try to read values_per_page values, however there are only values_per_page/2 values - let array = array_reader.next_batch(values_per_page).unwrap(); - assert_eq!(array.len(), values_per_page / 2); - assert_eq!( - Some(&def_levels[accu_len..(accu_len + array.len())]), - array_reader.get_def_levels() - ); - assert_eq!( - Some(&rep_levels[accu_len..(accu_len + array.len())]), - array_reader.get_rep_levels() - ); - } - - /// Array reader for test. - struct InMemoryArrayReader { - data_type: ArrowType, - array: ArrayRef, - def_levels: Option>, - rep_levels: Option>, - } - - impl InMemoryArrayReader { - pub fn new( - data_type: ArrowType, - array: ArrayRef, - def_levels: Option>, - rep_levels: Option>, - ) -> Self { - Self { - data_type, - array, - def_levels, - rep_levels, - } - } - } - - impl ArrayReader for InMemoryArrayReader { - fn as_any(&self) -> &dyn Any { - self - } - - fn get_data_type(&self) -> &ArrowType { - &self.data_type - } - - fn next_batch(&mut self, _batch_size: usize) -> Result { - Ok(self.array.clone()) - } - - fn get_def_levels(&self) -> Option<&[i16]> { - self.def_levels.as_deref() - } - - fn get_rep_levels(&self) -> Option<&[i16]> { - self.rep_levels.as_deref() - } - } - - /// Iterator for testing reading empty columns - struct EmptyPageIterator { - schema: SchemaDescPtr, - } - - impl EmptyPageIterator { - fn new(schema: SchemaDescPtr) -> Self { - EmptyPageIterator { schema } - } - } - - impl Iterator for EmptyPageIterator { - type Item = Result>; - - fn next(&mut self) -> Option { - None - } - } - - impl PageIterator for EmptyPageIterator { - fn schema(&mut self) -> Result { - Ok(self.schema.clone()) - } - - fn column_schema(&mut self) -> Result { - Ok(self.schema.column(0)) - } - } - - #[test] - fn test_struct_array_reader() { - let array_1 = Arc::new(PrimitiveArray::::from(vec![1, 2, 3, 4, 5])); - let array_reader_1 = InMemoryArrayReader::new( - ArrowType::Int32, - array_1.clone(), - Some(vec![0, 1, 2, 3, 1]), - Some(vec![1, 1, 1, 1, 1]), - ); - - let array_2 = Arc::new(PrimitiveArray::::from(vec![5, 4, 3, 2, 1])); - let array_reader_2 = InMemoryArrayReader::new( - ArrowType::Int32, - array_2.clone(), - Some(vec![0, 1, 3, 1, 2]), - Some(vec![1, 1, 1, 1, 1]), - ); - - let struct_type = ArrowType::Struct(vec![ - Field::new("f1", array_1.data_type().clone(), true), - Field::new("f2", array_2.data_type().clone(), true), - ]); - - let mut struct_array_reader = StructArrayReader::new( - struct_type, - vec![Box::new(array_reader_1), Box::new(array_reader_2)], - 1, - 1, - ); - - let struct_array = struct_array_reader.next_batch(5).unwrap(); - let struct_array = struct_array.as_any().downcast_ref::().unwrap(); - - assert_eq!(5, struct_array.len()); - assert_eq!( - vec![true, false, false, false, false], - (0..5) - .map(|idx| struct_array.data_ref().is_null(idx)) - .collect::>() - ); - assert_eq!( - Some(vec![0, 1, 1, 1, 1].as_slice()), - struct_array_reader.get_def_levels() - ); - assert_eq!( - Some(vec![1, 1, 1, 1, 1].as_slice()), - struct_array_reader.get_rep_levels() - ); - } - - #[test] - fn test_create_array_reader() { - let file = get_test_file("nulls.snappy.parquet"); - let file_reader = Arc::new(SerializedFileReader::new(file).unwrap()); - - let file_metadata = file_reader.metadata().file_metadata(); - let arrow_schema = parquet_to_arrow_schema( - file_metadata.schema_descr(), - 
file_metadata.key_value_metadata(), - ) - .unwrap(); - - let array_reader = build_array_reader( - file_reader.metadata().file_metadata().schema_descr_ptr(), - arrow_schema, - vec![0usize].into_iter(), - file_reader, - ) - .unwrap(); - - // Create arrow types - let arrow_type = ArrowType::Struct(vec![Field::new( - "b_struct", - ArrowType::Struct(vec![Field::new("b_c_int", ArrowType::Int32, true)]), - true, - )]); - - assert_eq!(array_reader.get_data_type(), &arrow_type); - } - - #[test] - fn test_list_array_reader() { - // [[1, null, 2], null, [3, 4]] - let array = Arc::new(PrimitiveArray::::from(vec![ - Some(1), - None, - Some(2), - None, - Some(3), - Some(4), - ])); - let item_array_reader = InMemoryArrayReader::new( - ArrowType::Int32, - array, - Some(vec![3, 2, 3, 0, 3, 3]), - Some(vec![0, 1, 1, 0, 0, 1]), - ); - - let mut list_array_reader = ListArrayReader::::new( - Box::new(item_array_reader), - ArrowType::List(Box::new(Field::new("item", ArrowType::Int32, true))), - ArrowType::Int32, - 1, - 1, - 0, - 1, - ); - - let next_batch = list_array_reader.next_batch(1024).unwrap(); - let list_array = next_batch.as_any().downcast_ref::().unwrap(); - - assert_eq!(3, list_array.len()); - // This passes as I expect - assert_eq!(1, list_array.null_count()); - - assert_eq!( - list_array - .value(0) - .as_any() - .downcast_ref::>() - .unwrap(), - &PrimitiveArray::::from(vec![Some(1), None, Some(2)]) - ); - - assert!(list_array.is_null(1)); - - assert_eq!( - list_array - .value(2) - .as_any() - .downcast_ref::>() - .unwrap(), - &PrimitiveArray::::from(vec![Some(3), Some(4)]) - ); - } - - #[test] - fn test_large_list_array_reader() { - // [[1, null, 2], null, [3, 4]] - let array = Arc::new(PrimitiveArray::::from(vec![ - Some(1), - None, - Some(2), - None, - Some(3), - Some(4), - ])); - let item_array_reader = InMemoryArrayReader::new( - ArrowType::Int32, - array, - Some(vec![3, 2, 3, 0, 3, 3]), - Some(vec![0, 1, 1, 0, 0, 1]), - ); - - let mut list_array_reader = ListArrayReader::::new( - Box::new(item_array_reader), - ArrowType::LargeList(Box::new(Field::new("item", ArrowType::Int32, true))), - ArrowType::Int32, - 1, - 1, - 0, - 1, - ); - - let next_batch = list_array_reader.next_batch(1024).unwrap(); - let list_array = next_batch - .as_any() - .downcast_ref::() - .unwrap(); - - assert_eq!(3, list_array.len()); - - assert_eq!( - list_array - .value(0) - .as_any() - .downcast_ref::>() - .unwrap(), - &PrimitiveArray::::from(vec![Some(1), None, Some(2)]) - ); - - assert!(list_array.is_null(1)); - - assert_eq!( - list_array - .value(2) - .as_any() - .downcast_ref::>() - .unwrap(), - &PrimitiveArray::::from(vec![Some(3), Some(4)]) - ); - } -} diff --git a/parquet/src/arrow/array_reader/builder.rs b/parquet/src/arrow/array_reader/builder.rs new file mode 100644 index 000000000000..e8c22f95aa0a --- /dev/null +++ b/parquet/src/arrow/array_reader/builder.rs @@ -0,0 +1,368 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use arrow::datatypes::{DataType, IntervalUnit, SchemaRef};
+
+use crate::arrow::array_reader::empty_array::make_empty_array_reader;
+use crate::arrow::array_reader::{
+    make_byte_array_dictionary_reader, make_byte_array_reader, ArrayReader,
+    ComplexObjectArrayReader, ListArrayReader, MapArrayReader, NullArrayReader,
+    PrimitiveArrayReader, RowGroupCollection, StructArrayReader,
+};
+use crate::arrow::buffer::converter::{
+    DecimalArrayConverter, DecimalConverter, FixedLenBinaryConverter,
+    FixedSizeArrayConverter, Int96ArrayConverter, Int96Converter,
+    IntervalDayTimeArrayConverter, IntervalDayTimeConverter,
+    IntervalYearMonthArrayConverter, IntervalYearMonthConverter,
+};
+use crate::arrow::schema::{convert_schema, ParquetField, ParquetFieldType};
+use crate::arrow::ProjectionMask;
+use crate::basic::Type as PhysicalType;
+use crate::data_type::{
+    BoolType, DoubleType, FixedLenByteArrayType, FloatType, Int32Type, Int64Type,
+    Int96Type,
+};
+use crate::errors::Result;
+use crate::schema::types::{ColumnDescriptor, ColumnPath, SchemaDescPtr, Type};
+
+/// Create array reader from parquet schema, projection mask, and parquet file reader.
+pub fn build_array_reader(
+    parquet_schema: SchemaDescPtr,
+    arrow_schema: SchemaRef,
+    mask: ProjectionMask,
+    row_groups: Box<dyn RowGroupCollection>,
+) -> Result<Box<dyn ArrayReader>> {
+    let field =
+        convert_schema(parquet_schema.as_ref(), mask, Some(arrow_schema.as_ref()))?;
+
+    match &field {
+        Some(field) => build_reader(field, row_groups.as_ref()),
+        None => Ok(make_empty_array_reader(row_groups.num_rows())),
+    }
+}
+
+fn build_reader(
+    field: &ParquetField,
+    row_groups: &dyn RowGroupCollection,
+) -> Result<Box<dyn ArrayReader>> {
+    match field.field_type {
+        ParquetFieldType::Primitive { .. } => build_primitive_reader(field, row_groups),
+        ParquetFieldType::Group { .. } => match &field.arrow_type {
+            DataType::Map(_, _) => build_map_reader(field, row_groups),
+            DataType::Struct(_) => build_struct_reader(field, row_groups),
+            DataType::List(_) => build_list_reader(field, false, row_groups),
+            DataType::LargeList(_) => build_list_reader(field, true, row_groups),
+            d => unimplemented!("reading group type {} not implemented", d),
+        },
+    }
+}
+
+/// Build array reader for map type.
+fn build_map_reader(
+    field: &ParquetField,
+    row_groups: &dyn RowGroupCollection,
+) -> Result<Box<dyn ArrayReader>> {
+    let children = field.children().unwrap();
+    assert_eq!(children.len(), 2);
+
+    let key_reader = build_reader(&children[0], row_groups)?;
+    let value_reader = build_reader(&children[1], row_groups)?;
+
+    Ok(Box::new(MapArrayReader::new(
+        key_reader,
+        value_reader,
+        field.arrow_type.clone(),
+        field.def_level,
+        field.rep_level,
+    )))
+}
+
+/// Build array reader for list type.
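// ---------------------------------------------------------------------------
// Illustrative sketch, not code from this change: build_reader above dispatches
// on the converted field tree, recursing into children so that struct, list and
// map readers end up wrapping the readers built for their leaves. The miniature
// model below captures that shape with invented types (`FieldKind`, `describe`);
// it is not the crate's API.
enum FieldKind {
    Primitive(&'static str),
    Struct(Vec<FieldKind>),
    List(Box<FieldKind>),
}

fn describe(field: &FieldKind) -> String {
    match field {
        FieldKind::Primitive(name) => format!("PrimitiveArrayReader({})", name),
        FieldKind::Struct(children) => {
            let inner: Vec<String> = children.iter().map(describe).collect();
            format!("StructArrayReader[{}]", inner.join(", "))
        }
        FieldKind::List(item) => format!("ListArrayReader<{}>", describe(item)),
    }
}

fn main() {
    let root = FieldKind::Struct(vec![
        FieldKind::Primitive("id"),
        FieldKind::List(Box::new(FieldKind::Primitive("tags"))),
    ]);
    assert_eq!(
        describe(&root),
        "StructArrayReader[PrimitiveArrayReader(id), ListArrayReader<PrimitiveArrayReader(tags)>]"
    );
}
// ---------------------------------------------------------------------------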
+fn build_list_reader( + field: &ParquetField, + is_large: bool, + row_groups: &dyn RowGroupCollection, +) -> Result> { + let children = field.children().unwrap(); + assert_eq!(children.len(), 1); + + let data_type = field.arrow_type.clone(); + let item_reader = build_reader(&children[0], row_groups)?; + let item_type = item_reader.get_data_type().clone(); + + match is_large { + false => Ok(Box::new(ListArrayReader::::new( + item_reader, + data_type, + item_type, + field.def_level, + field.rep_level, + field.nullable, + )) as _), + true => Ok(Box::new(ListArrayReader::::new( + item_reader, + data_type, + item_type, + field.def_level, + field.rep_level, + field.nullable, + )) as _), + } +} + +/// Creates primitive array reader for each primitive type. +fn build_primitive_reader( + field: &ParquetField, + row_groups: &dyn RowGroupCollection, +) -> Result> { + let (col_idx, primitive_type, type_len) = match &field.field_type { + ParquetFieldType::Primitive { + col_idx, + primitive_type, + } => match primitive_type.as_ref() { + Type::PrimitiveType { type_length, .. } => { + (*col_idx, primitive_type.clone(), *type_length) + } + Type::GroupType { .. } => unreachable!(), + }, + _ => unreachable!(), + }; + + let physical_type = primitive_type.get_physical_type(); + + // We don't track the column path in ParquetField as it adds a potential source + // of bugs when the arrow mapping converts more than one level in the parquet + // schema into a single arrow field. + // + // None of the readers actually use this field, but it is required for this type, + // so just stick a placeholder in + let column_desc = Arc::new(ColumnDescriptor::new( + primitive_type, + field.def_level, + field.rep_level, + ColumnPath::new(vec![]), + )); + + let page_iterator = row_groups.column_chunks(col_idx)?; + let null_mask_only = field.def_level == 1 && field.nullable; + let arrow_type = Some(field.arrow_type.clone()); + + match physical_type { + PhysicalType::BOOLEAN => Ok(Box::new( + PrimitiveArrayReader::::new_with_options( + page_iterator, + column_desc, + arrow_type, + null_mask_only, + )?, + )), + PhysicalType::INT32 => { + if let Some(DataType::Null) = arrow_type { + Ok(Box::new(NullArrayReader::::new( + page_iterator, + column_desc, + )?)) + } else { + Ok(Box::new( + PrimitiveArrayReader::::new_with_options( + page_iterator, + column_desc, + arrow_type, + null_mask_only, + )?, + )) + } + } + PhysicalType::INT64 => Ok(Box::new( + PrimitiveArrayReader::::new_with_options( + page_iterator, + column_desc, + arrow_type, + null_mask_only, + )?, + )), + PhysicalType::INT96 => { + // get the optional timezone information from arrow type + let timezone = arrow_type.as_ref().and_then(|data_type| { + if let DataType::Timestamp(_, tz) = data_type { + tz.clone() + } else { + None + } + }); + let converter = Int96Converter::new(Int96ArrayConverter { timezone }); + Ok(Box::new(ComplexObjectArrayReader::< + Int96Type, + Int96Converter, + >::new( + page_iterator, + column_desc, + converter, + arrow_type, + )?)) + } + PhysicalType::FLOAT => Ok(Box::new( + PrimitiveArrayReader::::new_with_options( + page_iterator, + column_desc, + arrow_type, + null_mask_only, + )?, + )), + PhysicalType::DOUBLE => Ok(Box::new( + PrimitiveArrayReader::::new_with_options( + page_iterator, + column_desc, + arrow_type, + null_mask_only, + )?, + )), + PhysicalType::BYTE_ARRAY => match arrow_type { + Some(DataType::Dictionary(_, _)) => make_byte_array_dictionary_reader( + page_iterator, + column_desc, + arrow_type, + null_mask_only, + ), + _ => 
make_byte_array_reader( + page_iterator, + column_desc, + arrow_type, + null_mask_only, + ), + }, + PhysicalType::FIXED_LEN_BYTE_ARRAY => match field.arrow_type { + DataType::Decimal(precision, scale) => { + let converter = DecimalConverter::new(DecimalArrayConverter::new( + precision as i32, + scale as i32, + )); + Ok(Box::new(ComplexObjectArrayReader::< + FixedLenByteArrayType, + DecimalConverter, + >::new( + page_iterator, + column_desc, + converter, + arrow_type, + )?)) + } + DataType::Interval(IntervalUnit::DayTime) => { + let converter = + IntervalDayTimeConverter::new(IntervalDayTimeArrayConverter {}); + Ok(Box::new(ComplexObjectArrayReader::< + FixedLenByteArrayType, + _, + >::new( + page_iterator, + column_desc, + converter, + arrow_type, + )?)) + } + DataType::Interval(IntervalUnit::YearMonth) => { + let converter = + IntervalYearMonthConverter::new(IntervalYearMonthArrayConverter {}); + Ok(Box::new(ComplexObjectArrayReader::< + FixedLenByteArrayType, + _, + >::new( + page_iterator, + column_desc, + converter, + arrow_type, + )?)) + } + _ => { + let converter = + FixedLenBinaryConverter::new(FixedSizeArrayConverter::new(type_len)); + Ok(Box::new(ComplexObjectArrayReader::< + FixedLenByteArrayType, + FixedLenBinaryConverter, + >::new( + page_iterator, + column_desc, + converter, + arrow_type, + )?)) + } + }, + } +} + +fn build_struct_reader( + field: &ParquetField, + row_groups: &dyn RowGroupCollection, +) -> Result> { + let children = field.children().unwrap(); + let children_reader = children + .iter() + .map(|child| build_reader(child, row_groups)) + .collect::>>()?; + + Ok(Box::new(StructArrayReader::new( + field.arrow_type.clone(), + children_reader, + field.def_level, + field.rep_level, + field.nullable, + )) as _) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::arrow::parquet_to_arrow_schema; + use crate::file::reader::{FileReader, SerializedFileReader}; + use crate::util::test_common::get_test_file; + use arrow::datatypes::Field; + use std::sync::Arc; + + #[test] + fn test_create_array_reader() { + let file = get_test_file("nulls.snappy.parquet"); + let file_reader: Arc = + Arc::new(SerializedFileReader::new(file).unwrap()); + + let file_metadata = file_reader.metadata().file_metadata(); + let mask = ProjectionMask::leaves(file_metadata.schema_descr(), [0]); + let arrow_schema = parquet_to_arrow_schema( + file_metadata.schema_descr(), + file_metadata.key_value_metadata(), + ) + .unwrap(); + + let array_reader = build_array_reader( + file_reader.metadata().file_metadata().schema_descr_ptr(), + Arc::new(arrow_schema), + mask, + Box::new(file_reader), + ) + .unwrap(); + + // Create arrow types + let arrow_type = DataType::Struct(vec![Field::new( + "b_struct", + DataType::Struct(vec![Field::new("b_c_int", DataType::Int32, true)]), + true, + )]); + + assert_eq!(array_reader.get_data_type(), &arrow_type); + } +} diff --git a/parquet/src/arrow/array_reader/byte_array.rs b/parquet/src/arrow/array_reader/byte_array.rs new file mode 100644 index 000000000000..9e0f83fa9450 --- /dev/null +++ b/parquet/src/arrow/array_reader/byte_array.rs @@ -0,0 +1,665 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::arrow::array_reader::{read_records, ArrayReader}; +use crate::arrow::buffer::offset_buffer::OffsetBuffer; +use crate::arrow::record_reader::buffer::ScalarValue; +use crate::arrow::record_reader::GenericRecordReader; +use crate::arrow::schema::parquet_to_arrow_field; +use crate::basic::{ConvertedType, Encoding}; +use crate::column::page::PageIterator; +use crate::column::reader::decoder::ColumnValueDecoder; +use crate::data_type::Int32Type; +use crate::encodings::{ + decoding::{Decoder, DeltaBitPackDecoder}, + rle::RleDecoder, +}; +use crate::errors::{ParquetError, Result}; +use crate::schema::types::ColumnDescPtr; +use crate::util::memory::ByteBufferPtr; +use arrow::array::{ArrayRef, OffsetSizeTrait}; +use arrow::buffer::Buffer; +use arrow::datatypes::DataType as ArrowType; +use std::any::Any; +use std::ops::Range; + +/// Returns an [`ArrayReader`] that decodes the provided byte array column +pub fn make_byte_array_reader( + pages: Box, + column_desc: ColumnDescPtr, + arrow_type: Option, + null_mask_only: bool, +) -> Result> { + // Check if Arrow type is specified, else create it from Parquet type + let data_type = match arrow_type { + Some(t) => t, + None => parquet_to_arrow_field(column_desc.as_ref())? + .data_type() + .clone(), + }; + + match data_type { + ArrowType::Binary | ArrowType::Utf8 => { + let reader = + GenericRecordReader::new_with_options(column_desc, null_mask_only); + Ok(Box::new(ByteArrayReader::::new( + pages, data_type, reader, + ))) + } + ArrowType::LargeUtf8 | ArrowType::LargeBinary => { + let reader = + GenericRecordReader::new_with_options(column_desc, null_mask_only); + Ok(Box::new(ByteArrayReader::::new( + pages, data_type, reader, + ))) + } + _ => Err(general_err!( + "invalid data type for byte array reader - {}", + data_type + )), + } +} + +/// An [`ArrayReader`] for variable length byte arrays +struct ByteArrayReader { + data_type: ArrowType, + pages: Box, + def_levels_buffer: Option, + rep_levels_buffer: Option, + record_reader: GenericRecordReader, ByteArrayColumnValueDecoder>, +} + +impl ByteArrayReader { + fn new( + pages: Box, + data_type: ArrowType, + record_reader: GenericRecordReader< + OffsetBuffer, + ByteArrayColumnValueDecoder, + >, + ) -> Self { + Self { + data_type, + pages, + def_levels_buffer: None, + rep_levels_buffer: None, + record_reader, + } + } +} + +impl ArrayReader for ByteArrayReader { + fn as_any(&self) -> &dyn Any { + self + } + + fn get_data_type(&self) -> &ArrowType { + &self.data_type + } + + fn next_batch(&mut self, batch_size: usize) -> Result { + read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)?; + let buffer = self.record_reader.consume_record_data()?; + let null_buffer = self.record_reader.consume_bitmap_buffer()?; + self.def_levels_buffer = self.record_reader.consume_def_levels()?; + self.rep_levels_buffer = self.record_reader.consume_rep_levels()?; + self.record_reader.reset(); + + Ok(buffer.into_array(null_buffer, self.data_type.clone())) + } + + fn get_def_levels(&self) -> Option<&[i16]> { + self.def_levels_buffer + .as_ref() + .map(|buf| buf.typed_data()) + } + + fn 
get_rep_levels(&self) -> Option<&[i16]> { + self.rep_levels_buffer + .as_ref() + .map(|buf| buf.typed_data()) + } +} + +/// A [`ColumnValueDecoder`] for variable length byte arrays +struct ByteArrayColumnValueDecoder { + dict: Option>, + decoder: Option, + validate_utf8: bool, +} + +impl ColumnValueDecoder + for ByteArrayColumnValueDecoder +{ + type Slice = OffsetBuffer; + + fn new(desc: &ColumnDescPtr) -> Self { + let validate_utf8 = desc.converted_type() == ConvertedType::UTF8; + Self { + dict: None, + decoder: None, + validate_utf8, + } + } + + fn set_dict( + &mut self, + buf: ByteBufferPtr, + num_values: u32, + encoding: Encoding, + _is_sorted: bool, + ) -> Result<()> { + if !matches!( + encoding, + Encoding::PLAIN | Encoding::RLE_DICTIONARY | Encoding::PLAIN_DICTIONARY + ) { + return Err(nyi_err!( + "Invalid/Unsupported encoding type for dictionary: {}", + encoding + )); + } + + let mut buffer = OffsetBuffer::default(); + let mut decoder = ByteArrayDecoderPlain::new( + buf, + num_values as usize, + Some(num_values as usize), + self.validate_utf8, + ); + decoder.read(&mut buffer, usize::MAX)?; + self.dict = Some(buffer); + Ok(()) + } + + fn set_data( + &mut self, + encoding: Encoding, + data: ByteBufferPtr, + num_levels: usize, + num_values: Option, + ) -> Result<()> { + self.decoder = Some(ByteArrayDecoder::new( + encoding, + data, + num_levels, + num_values, + self.validate_utf8, + )?); + Ok(()) + } + + fn read(&mut self, out: &mut Self::Slice, range: Range) -> Result { + let decoder = self + .decoder + .as_mut() + .ok_or_else(|| general_err!("no decoder set"))?; + + decoder.read(out, range.end - range.start, self.dict.as_ref()) + } +} + +/// A generic decoder from uncompressed parquet value data to [`OffsetBuffer`] +pub enum ByteArrayDecoder { + Plain(ByteArrayDecoderPlain), + Dictionary(ByteArrayDecoderDictionary), + DeltaLength(ByteArrayDecoderDeltaLength), + DeltaByteArray(ByteArrayDecoderDelta), +} + +impl ByteArrayDecoder { + pub fn new( + encoding: Encoding, + data: ByteBufferPtr, + num_levels: usize, + num_values: Option, + validate_utf8: bool, + ) -> Result { + let decoder = match encoding { + Encoding::PLAIN => ByteArrayDecoder::Plain(ByteArrayDecoderPlain::new( + data, + num_levels, + num_values, + validate_utf8, + )), + Encoding::RLE_DICTIONARY | Encoding::PLAIN_DICTIONARY => { + ByteArrayDecoder::Dictionary(ByteArrayDecoderDictionary::new( + data, num_levels, num_values, + )) + } + Encoding::DELTA_LENGTH_BYTE_ARRAY => ByteArrayDecoder::DeltaLength( + ByteArrayDecoderDeltaLength::new(data, validate_utf8)?, + ), + Encoding::DELTA_BYTE_ARRAY => ByteArrayDecoder::DeltaByteArray( + ByteArrayDecoderDelta::new(data, validate_utf8)?, + ), + _ => { + return Err(general_err!( + "unsupported encoding for byte array: {}", + encoding + )) + } + }; + + Ok(decoder) + } + + /// Read up to `len` values to `out` with the optional dictionary + pub fn read( + &mut self, + out: &mut OffsetBuffer, + len: usize, + dict: Option<&OffsetBuffer>, + ) -> Result { + match self { + ByteArrayDecoder::Plain(d) => d.read(out, len), + ByteArrayDecoder::Dictionary(d) => { + let dict = dict + .ok_or_else(|| general_err!("missing dictionary page for column"))?; + + d.read(out, dict, len) + } + ByteArrayDecoder::DeltaLength(d) => d.read(out, len), + ByteArrayDecoder::DeltaByteArray(d) => d.read(out, len), + } + } +} + +/// Decoder from [`Encoding::PLAIN`] data to [`OffsetBuffer`] +pub struct ByteArrayDecoderPlain { + buf: ByteBufferPtr, + offset: usize, + validate_utf8: bool, + + /// This is a maximum as the 
null count is not always known, e.g. value data from + /// a v1 data page + max_remaining_values: usize, +} + +impl ByteArrayDecoderPlain { + pub fn new( + buf: ByteBufferPtr, + num_levels: usize, + num_values: Option, + validate_utf8: bool, + ) -> Self { + Self { + buf, + validate_utf8, + offset: 0, + max_remaining_values: num_values.unwrap_or(num_levels), + } + } + + pub fn read( + &mut self, + output: &mut OffsetBuffer, + len: usize, + ) -> Result { + let initial_values_length = output.values.len(); + + let to_read = len.min(self.max_remaining_values); + output.offsets.reserve(to_read); + + let remaining_bytes = self.buf.len() - self.offset; + if remaining_bytes == 0 { + return Ok(0); + } + + let estimated_bytes = remaining_bytes + .checked_mul(to_read) + .map(|x| x / self.max_remaining_values) + .unwrap_or_default(); + + output.values.reserve(estimated_bytes); + + let mut read = 0; + + let buf = self.buf.as_ref(); + while self.offset < self.buf.len() && read != to_read { + if self.offset + 4 > buf.len() { + return Err(ParquetError::EOF("eof decoding byte array".into())); + } + let len_bytes: [u8; 4] = + buf[self.offset..self.offset + 4].try_into().unwrap(); + let len = u32::from_le_bytes(len_bytes); + + let start_offset = self.offset + 4; + let end_offset = start_offset + len as usize; + if end_offset > buf.len() { + return Err(ParquetError::EOF("eof decoding byte array".into())); + } + + output.try_push(&buf[start_offset..end_offset], self.validate_utf8)?; + + self.offset = end_offset; + read += 1; + } + self.max_remaining_values -= to_read; + + if self.validate_utf8 { + output.check_valid_utf8(initial_values_length)?; + } + Ok(to_read) + } +} + +/// Decoder from [`Encoding::DELTA_LENGTH_BYTE_ARRAY`] data to [`OffsetBuffer`] +pub struct ByteArrayDecoderDeltaLength { + lengths: Vec, + data: ByteBufferPtr, + length_offset: usize, + data_offset: usize, + validate_utf8: bool, +} + +impl ByteArrayDecoderDeltaLength { + fn new(data: ByteBufferPtr, validate_utf8: bool) -> Result { + let mut len_decoder = DeltaBitPackDecoder::::new(); + len_decoder.set_data(data.all(), 0)?; + let values = len_decoder.values_left(); + + let mut lengths = vec![0; values]; + len_decoder.get(&mut lengths)?; + + Ok(Self { + lengths, + data, + validate_utf8, + length_offset: 0, + data_offset: len_decoder.get_offset(), + }) + } + + fn read( + &mut self, + output: &mut OffsetBuffer, + len: usize, + ) -> Result { + let initial_values_length = output.values.len(); + + let to_read = len.min(self.lengths.len() - self.length_offset); + output.offsets.reserve(to_read); + + let src_lengths = &self.lengths[self.length_offset..self.length_offset + to_read]; + + let total_bytes: usize = src_lengths.iter().map(|x| *x as usize).sum(); + output.values.reserve(total_bytes); + + if self.data_offset + total_bytes > self.data.len() { + return Err(ParquetError::EOF( + "Insufficient delta length byte array bytes".to_string(), + )); + } + + let mut start_offset = self.data_offset; + for length in src_lengths { + let end_offset = start_offset + *length as usize; + output.try_push( + &self.data.as_ref()[start_offset..end_offset], + self.validate_utf8, + )?; + start_offset = end_offset; + } + + self.data_offset = start_offset; + self.length_offset += to_read; + + if self.validate_utf8 { + output.check_valid_utf8(initial_values_length)?; + } + Ok(to_read) + } +} + +/// Decoder from [`Encoding::DELTA_BYTE_ARRAY`] to [`OffsetBuffer`] +pub struct ByteArrayDecoderDelta { + prefix_lengths: Vec, + suffix_lengths: Vec, + data: ByteBufferPtr, + 
length_offset: usize, + data_offset: usize, + last_value: Vec, + validate_utf8: bool, +} + +impl ByteArrayDecoderDelta { + fn new(data: ByteBufferPtr, validate_utf8: bool) -> Result { + let mut prefix = DeltaBitPackDecoder::::new(); + prefix.set_data(data.all(), 0)?; + + let num_prefix = prefix.values_left(); + let mut prefix_lengths = vec![0; num_prefix]; + assert_eq!(prefix.get(&mut prefix_lengths)?, num_prefix); + + let mut suffix = DeltaBitPackDecoder::::new(); + suffix.set_data(data.start_from(prefix.get_offset()), 0)?; + + let num_suffix = suffix.values_left(); + let mut suffix_lengths = vec![0; num_suffix]; + assert_eq!(suffix.get(&mut suffix_lengths)?, num_suffix); + + if num_prefix != num_suffix { + return Err(general_err!(format!( + "inconsistent DELTA_BYTE_ARRAY lengths, prefixes: {}, suffixes: {}", + num_prefix, num_suffix + ))); + } + + Ok(Self { + prefix_lengths, + suffix_lengths, + data, + length_offset: 0, + data_offset: prefix.get_offset() + suffix.get_offset(), + last_value: vec![], + validate_utf8, + }) + } + + fn read( + &mut self, + output: &mut OffsetBuffer, + len: usize, + ) -> Result { + let initial_values_length = output.values.len(); + assert_eq!(self.prefix_lengths.len(), self.suffix_lengths.len()); + + let to_read = len.min(self.prefix_lengths.len() - self.length_offset); + + output.offsets.reserve(to_read); + + let length_range = self.length_offset..self.length_offset + to_read; + let iter = self.prefix_lengths[length_range.clone()] + .iter() + .zip(&self.suffix_lengths[length_range]); + + let data = self.data.as_ref(); + + for (prefix_length, suffix_length) in iter { + let prefix_length = *prefix_length as usize; + let suffix_length = *suffix_length as usize; + + if self.data_offset + suffix_length > self.data.len() { + return Err(ParquetError::EOF("eof decoding byte array".into())); + } + + self.last_value.truncate(prefix_length); + self.last_value.extend_from_slice( + &data[self.data_offset..self.data_offset + suffix_length], + ); + output.try_push(&self.last_value, self.validate_utf8)?; + + self.data_offset += suffix_length; + } + + self.length_offset += to_read; + + if self.validate_utf8 { + output.check_valid_utf8(initial_values_length)?; + } + Ok(to_read) + } +} + +/// Decoder from [`Encoding::RLE_DICTIONARY`] to [`OffsetBuffer`] +pub struct ByteArrayDecoderDictionary { + decoder: RleDecoder, + + index_buf: Box<[i32; 1024]>, + index_buf_len: usize, + index_offset: usize, + + /// This is a maximum as the null count is not always known, e.g. 
value data from + /// a v1 data page + max_remaining_values: usize, +} + +impl ByteArrayDecoderDictionary { + fn new(data: ByteBufferPtr, num_levels: usize, num_values: Option) -> Self { + let bit_width = data[0]; + let mut decoder = RleDecoder::new(bit_width); + decoder.set_data(data.start_from(1)); + + Self { + decoder, + index_buf: Box::new([0; 1024]), + index_buf_len: 0, + index_offset: 0, + max_remaining_values: num_values.unwrap_or(num_levels), + } + } + + fn read( + &mut self, + output: &mut OffsetBuffer, + dict: &OffsetBuffer, + len: usize, + ) -> Result { + if dict.is_empty() { + return Ok(0); // All data must be NULL + } + + let mut values_read = 0; + + while values_read != len && self.max_remaining_values != 0 { + if self.index_offset == self.index_buf_len { + let read = self.decoder.get_batch(self.index_buf.as_mut())?; + if read == 0 { + break; + } + self.index_buf_len = read; + self.index_offset = 0; + } + + let to_read = (len - values_read) + .min(self.index_buf_len - self.index_offset) + .min(self.max_remaining_values); + + output.extend_from_dictionary( + &self.index_buf[self.index_offset..self.index_offset + to_read], + dict.offsets.as_slice(), + dict.values.as_slice(), + )?; + + self.index_offset += to_read; + values_read += to_read; + self.max_remaining_values -= to_read; + } + Ok(values_read) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::arrow::array_reader::test_util::{byte_array_all_encodings, utf8_column}; + use crate::arrow::record_reader::buffer::ValuesBuffer; + use arrow::array::{Array, StringArray}; + + #[test] + fn test_byte_array_decoder() { + let (pages, encoded_dictionary) = + byte_array_all_encodings(vec!["hello", "world", "a", "b"]); + + let column_desc = utf8_column(); + let mut decoder = ByteArrayColumnValueDecoder::new(&column_desc); + + decoder + .set_dict(encoded_dictionary, 4, Encoding::RLE_DICTIONARY, false) + .unwrap(); + + for (encoding, page) in pages { + let mut output = OffsetBuffer::::default(); + decoder.set_data(encoding, page, 4, Some(4)).unwrap(); + + assert_eq!(decoder.read(&mut output, 0..1).unwrap(), 1); + + assert_eq!(output.values.as_slice(), "hello".as_bytes()); + assert_eq!(output.offsets.as_slice(), &[0, 5]); + + assert_eq!(decoder.read(&mut output, 1..2).unwrap(), 1); + assert_eq!(output.values.as_slice(), "helloworld".as_bytes()); + assert_eq!(output.offsets.as_slice(), &[0, 5, 10]); + + assert_eq!(decoder.read(&mut output, 2..4).unwrap(), 2); + assert_eq!(output.values.as_slice(), "helloworldab".as_bytes()); + assert_eq!(output.offsets.as_slice(), &[0, 5, 10, 11, 12]); + + assert_eq!(decoder.read(&mut output, 4..8).unwrap(), 0); + + let valid = vec![false, false, true, true, false, true, true, false, false]; + let valid_buffer = Buffer::from_iter(valid.iter().cloned()); + + output.pad_nulls(0, 4, valid.len(), valid_buffer.as_slice()); + let array = output.into_array(Some(valid_buffer), ArrowType::Utf8); + let strings = array.as_any().downcast_ref::().unwrap(); + + assert_eq!( + strings.iter().collect::>(), + vec![ + None, + None, + Some("hello"), + Some("world"), + None, + Some("a"), + Some("b"), + None, + None, + ] + ); + } + } + + #[test] + fn test_byte_array_decoder_nulls() { + let (pages, encoded_dictionary) = byte_array_all_encodings(Vec::<&str>::new()); + + let column_desc = utf8_column(); + let mut decoder = ByteArrayColumnValueDecoder::new(&column_desc); + + decoder + .set_dict(encoded_dictionary, 4, Encoding::RLE_DICTIONARY, false) + .unwrap(); + + for (encoding, page) in pages { + let mut output = 
OffsetBuffer::::default(); + decoder.set_data(encoding, page, 4, None).unwrap(); + assert_eq!(decoder.read(&mut output, 0..1024).unwrap(), 0); + } + } +} diff --git a/parquet/src/arrow/array_reader/byte_array_dictionary.rs b/parquet/src/arrow/array_reader/byte_array_dictionary.rs new file mode 100644 index 000000000000..0cd67206f000 --- /dev/null +++ b/parquet/src/arrow/array_reader/byte_array_dictionary.rs @@ -0,0 +1,551 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::marker::PhantomData; +use std::ops::Range; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, OffsetSizeTrait}; +use arrow::buffer::Buffer; +use arrow::datatypes::{ArrowNativeType, DataType as ArrowType}; + +use crate::arrow::array_reader::byte_array::{ByteArrayDecoder, ByteArrayDecoderPlain}; +use crate::arrow::array_reader::{read_records, ArrayReader}; +use crate::arrow::buffer::{ + dictionary_buffer::DictionaryBuffer, offset_buffer::OffsetBuffer, +}; +use crate::arrow::record_reader::buffer::{BufferQueue, ScalarValue}; +use crate::arrow::record_reader::GenericRecordReader; +use crate::arrow::schema::parquet_to_arrow_field; +use crate::basic::{ConvertedType, Encoding}; +use crate::column::page::PageIterator; +use crate::column::reader::decoder::ColumnValueDecoder; +use crate::encodings::rle::RleDecoder; +use crate::errors::{ParquetError, Result}; +use crate::schema::types::ColumnDescPtr; +use crate::util::bit_util::FromBytes; +use crate::util::memory::ByteBufferPtr; + +/// A macro to reduce verbosity of [`make_byte_array_dictionary_reader`] +macro_rules! 
make_reader { + ( + ($pages:expr, $column_desc:expr, $data_type:expr, $null_mask_only:expr) => match ($k:expr, $v:expr) { + $(($key_arrow:pat, $value_arrow:pat) => ($key_type:ty, $value_type:ty),)+ + } + ) => { + match (($k, $v)) { + $( + ($key_arrow, $value_arrow) => { + let reader = GenericRecordReader::new_with_options( + $column_desc, + $null_mask_only, + ); + Ok(Box::new(ByteArrayDictionaryReader::<$key_type, $value_type>::new( + $pages, $data_type, reader, + ))) + } + )+ + _ => Err(general_err!( + "unsupported data type for byte array dictionary reader - {}", + $data_type + )), + } + } +} + +/// Returns an [`ArrayReader`] that decodes the provided byte array column +/// +/// This will attempt to preserve any dictionary encoding present in the parquet data +/// +/// It will be unable to preserve the dictionary encoding if: +/// +/// * A single read spans across multiple column chunks +/// * A column chunk contains non-dictionary encoded pages +/// +/// It is therefore recommended that if `pages` contains data from multiple column chunks, +/// that the read batch size used is a divisor of the row group size +/// +pub fn make_byte_array_dictionary_reader( + pages: Box, + column_desc: ColumnDescPtr, + arrow_type: Option, + null_mask_only: bool, +) -> Result> { + // Check if Arrow type is specified, else create it from Parquet type + let data_type = match arrow_type { + Some(t) => t, + None => parquet_to_arrow_field(column_desc.as_ref())? + .data_type() + .clone(), + }; + + match &data_type { + ArrowType::Dictionary(key_type, value_type) => { + make_reader! { + (pages, column_desc, data_type, null_mask_only) => match (key_type.as_ref(), value_type.as_ref()) { + (ArrowType::UInt8, ArrowType::Binary | ArrowType::Utf8) => (u8, i32), + (ArrowType::UInt8, ArrowType::LargeBinary | ArrowType::LargeUtf8) => (u8, i64), + (ArrowType::Int8, ArrowType::Binary | ArrowType::Utf8) => (i8, i32), + (ArrowType::Int8, ArrowType::LargeBinary | ArrowType::LargeUtf8) => (i8, i64), + (ArrowType::UInt16, ArrowType::Binary | ArrowType::Utf8) => (u16, i32), + (ArrowType::UInt16, ArrowType::LargeBinary | ArrowType::LargeUtf8) => (u16, i64), + (ArrowType::Int16, ArrowType::Binary | ArrowType::Utf8) => (i16, i32), + (ArrowType::Int16, ArrowType::LargeBinary | ArrowType::LargeUtf8) => (i16, i64), + (ArrowType::UInt32, ArrowType::Binary | ArrowType::Utf8) => (u32, i32), + (ArrowType::UInt32, ArrowType::LargeBinary | ArrowType::LargeUtf8) => (u32, i64), + (ArrowType::Int32, ArrowType::Binary | ArrowType::Utf8) => (i32, i32), + (ArrowType::Int32, ArrowType::LargeBinary | ArrowType::LargeUtf8) => (i32, i64), + (ArrowType::UInt64, ArrowType::Binary | ArrowType::Utf8) => (u64, i32), + (ArrowType::UInt64, ArrowType::LargeBinary | ArrowType::LargeUtf8) => (u64, i64), + (ArrowType::Int64, ArrowType::Binary | ArrowType::Utf8) => (i64, i32), + (ArrowType::Int64, ArrowType::LargeBinary | ArrowType::LargeUtf8) => (i64, i64), + } + } + } + _ => Err(general_err!( + "invalid non-dictionary data type for byte array dictionary reader - {}", + data_type + )), + } +} + +/// An [`ArrayReader`] for dictionary encoded variable length byte arrays +/// +/// Will attempt to preserve any dictionary encoding present in the parquet data +struct ByteArrayDictionaryReader { + data_type: ArrowType, + pages: Box, + def_levels_buffer: Option, + rep_levels_buffer: Option, + record_reader: GenericRecordReader, DictionaryDecoder>, +} + +impl ByteArrayDictionaryReader +where + K: FromBytes + ScalarValue + Ord + ArrowNativeType, + V: ScalarValue + 
OffsetSizeTrait, +{ + fn new( + pages: Box, + data_type: ArrowType, + record_reader: GenericRecordReader< + DictionaryBuffer, + DictionaryDecoder, + >, + ) -> Self { + Self { + data_type, + pages, + def_levels_buffer: None, + rep_levels_buffer: None, + record_reader, + } + } +} + +impl ArrayReader for ByteArrayDictionaryReader +where + K: FromBytes + ScalarValue + Ord + ArrowNativeType, + V: ScalarValue + OffsetSizeTrait, +{ + fn as_any(&self) -> &dyn Any { + self + } + + fn get_data_type(&self) -> &ArrowType { + &self.data_type + } + + fn next_batch(&mut self, batch_size: usize) -> Result { + read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)?; + let buffer = self.record_reader.consume_record_data()?; + let null_buffer = self.record_reader.consume_bitmap_buffer()?; + let array = buffer.into_array(null_buffer, &self.data_type)?; + + self.def_levels_buffer = self.record_reader.consume_def_levels()?; + self.rep_levels_buffer = self.record_reader.consume_rep_levels()?; + self.record_reader.reset(); + + Ok(array) + } + + fn get_def_levels(&self) -> Option<&[i16]> { + self.def_levels_buffer + .as_ref() + .map(|buf| buf.typed_data()) + } + + fn get_rep_levels(&self) -> Option<&[i16]> { + self.rep_levels_buffer + .as_ref() + .map(|buf| buf.typed_data()) + } +} + +/// If the data is dictionary encoded decode the key data directly, so that the dictionary +/// encoding can be preserved. Otherwise fallback to decoding using [`ByteArrayDecoder`] +/// and compute a fresh dictionary in [`ByteArrayDictionaryReader::next_batch`] +enum MaybeDictionaryDecoder { + Dict { + decoder: RleDecoder, + /// This is a maximum as the null count is not always known, e.g. value data from + /// a v1 data page + max_remaining_values: usize, + }, + Fallback(ByteArrayDecoder), +} + +/// A [`ColumnValueDecoder`] for dictionary encoded variable length byte arrays +struct DictionaryDecoder { + /// The current dictionary + dict: Option, + + /// Dictionary decoder + decoder: Option, + + validate_utf8: bool, + + value_type: ArrowType, + + phantom: PhantomData<(K, V)>, +} + +impl ColumnValueDecoder for DictionaryDecoder +where + K: FromBytes + ScalarValue + Ord + ArrowNativeType, + V: ScalarValue + OffsetSizeTrait, +{ + type Slice = DictionaryBuffer; + + fn new(col: &ColumnDescPtr) -> Self { + let validate_utf8 = col.converted_type() == ConvertedType::UTF8; + + let value_type = match (V::IS_LARGE, col.converted_type() == ConvertedType::UTF8) + { + (true, true) => ArrowType::LargeUtf8, + (true, false) => ArrowType::LargeBinary, + (false, true) => ArrowType::Utf8, + (false, false) => ArrowType::Binary, + }; + + Self { + dict: None, + decoder: None, + validate_utf8, + value_type, + phantom: Default::default(), + } + } + + fn set_dict( + &mut self, + buf: ByteBufferPtr, + num_values: u32, + encoding: Encoding, + _is_sorted: bool, + ) -> Result<()> { + if !matches!( + encoding, + Encoding::PLAIN | Encoding::RLE_DICTIONARY | Encoding::PLAIN_DICTIONARY + ) { + return Err(nyi_err!( + "Invalid/Unsupported encoding type for dictionary: {}", + encoding + )); + } + + if K::from_usize(num_values as usize).is_none() { + return Err(general_err!("dictionary too large for index type")); + } + + let len = num_values as usize; + let mut buffer = OffsetBuffer::::default(); + let mut decoder = + ByteArrayDecoderPlain::new(buf, len, Some(len), self.validate_utf8); + decoder.read(&mut buffer, usize::MAX)?; + + let array = buffer.into_array(None, self.value_type.clone()); + self.dict = Some(Arc::new(array)); + Ok(()) + } + + fn 
set_data( + &mut self, + encoding: Encoding, + data: ByteBufferPtr, + num_levels: usize, + num_values: Option, + ) -> Result<()> { + let decoder = match encoding { + Encoding::RLE_DICTIONARY | Encoding::PLAIN_DICTIONARY => { + let bit_width = data[0]; + let mut decoder = RleDecoder::new(bit_width); + decoder.set_data(data.start_from(1)); + MaybeDictionaryDecoder::Dict { + decoder, + max_remaining_values: num_values.unwrap_or(num_levels), + } + } + _ => MaybeDictionaryDecoder::Fallback(ByteArrayDecoder::new( + encoding, + data, + num_levels, + num_values, + self.validate_utf8, + )?), + }; + + self.decoder = Some(decoder); + Ok(()) + } + + fn read(&mut self, out: &mut Self::Slice, range: Range) -> Result { + match self.decoder.as_mut().expect("decoder set") { + MaybeDictionaryDecoder::Fallback(decoder) => { + decoder.read(out.spill_values()?, range.end - range.start, None) + } + MaybeDictionaryDecoder::Dict { + decoder, + max_remaining_values, + } => { + let len = (range.end - range.start).min(*max_remaining_values); + + let dict = self + .dict + .as_ref() + .ok_or_else(|| general_err!("missing dictionary page for column"))?; + + assert_eq!(dict.data_type(), &self.value_type); + + if dict.is_empty() { + return Ok(0); // All data must be NULL + } + + match out.as_keys(dict) { + Some(keys) => { + // Happy path - can just copy keys + // Keys will be validated on conversion to arrow + let keys_slice = keys.spare_capacity_mut(range.start + len); + let len = decoder.get_batch(&mut keys_slice[range.start..])?; + Ok(len) + } + None => { + // Sad path - need to recompute dictionary + // + // This either means we crossed into a new column chunk whilst + // reading this batch, or encountered non-dictionary encoded data + let values = out.spill_values()?; + let mut keys = vec![K::default(); len]; + let len = decoder.get_batch(&mut keys)?; + + assert_eq!(dict.data_type(), &self.value_type); + + let dict_buffers = dict.data().buffers(); + let dict_offsets = dict_buffers[0].typed_data::(); + let dict_values = dict_buffers[1].as_slice(); + + values.extend_from_dictionary( + &keys[..len], + dict_offsets, + dict_values, + )?; + + Ok(len) + } + } + } + } + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::compute::cast; + + use crate::arrow::array_reader::test_util::{ + byte_array_all_encodings, encode_dictionary, utf8_column, + }; + use crate::arrow::record_reader::buffer::ValuesBuffer; + use crate::data_type::ByteArray; + + use super::*; + + fn utf8_dictionary() -> ArrowType { + ArrowType::Dictionary(Box::new(ArrowType::Int32), Box::new(ArrowType::Utf8)) + } + + #[test] + fn test_dictionary_preservation() { + let data_type = utf8_dictionary(); + + let data: Vec<_> = vec!["0", "1", "0", "1", "2", "1", "2"] + .into_iter() + .map(ByteArray::from) + .collect(); + let (dict, encoded) = encode_dictionary(&data); + + let column_desc = utf8_column(); + let mut decoder = DictionaryDecoder::::new(&column_desc); + + decoder + .set_dict(dict, 3, Encoding::RLE_DICTIONARY, false) + .unwrap(); + + decoder + .set_data(Encoding::RLE_DICTIONARY, encoded, 14, Some(data.len())) + .unwrap(); + + let mut output = DictionaryBuffer::::default(); + assert_eq!(decoder.read(&mut output, 0..3).unwrap(), 3); + + let mut valid = vec![false, false, true, true, false, true]; + let valid_buffer = Buffer::from_iter(valid.iter().cloned()); + output.pad_nulls(0, 3, valid.len(), valid_buffer.as_slice()); + + assert!(matches!(output, DictionaryBuffer::Dict { .. 
})); + + assert_eq!(decoder.read(&mut output, 0..4).unwrap(), 4); + + valid.extend_from_slice(&[false, false, true, true, false, true, true, false]); + let valid_buffer = Buffer::from_iter(valid.iter().cloned()); + output.pad_nulls(6, 4, 8, valid_buffer.as_slice()); + + assert!(matches!(output, DictionaryBuffer::Dict { .. })); + + let array = output.into_array(Some(valid_buffer), &data_type).unwrap(); + assert_eq!(array.data_type(), &data_type); + + let array = cast(&array, &ArrowType::Utf8).unwrap(); + let strings = array.as_any().downcast_ref::().unwrap(); + assert_eq!(strings.len(), 14); + + assert_eq!( + strings.iter().collect::>(), + vec![ + None, + None, + Some("0"), + Some("1"), + None, + Some("0"), + None, + None, + Some("1"), + Some("2"), + None, + Some("1"), + Some("2"), + None + ] + ) + } + + #[test] + fn test_dictionary_fallback() { + let data_type = utf8_dictionary(); + let data = vec!["hello", "world", "a", "b"]; + + let (pages, encoded_dictionary) = byte_array_all_encodings(data.clone()); + let num_encodings = pages.len(); + + let column_desc = utf8_column(); + let mut decoder = DictionaryDecoder::::new(&column_desc); + + decoder + .set_dict(encoded_dictionary, 4, Encoding::RLE_DICTIONARY, false) + .unwrap(); + + // Read all pages into single buffer + let mut output = DictionaryBuffer::::default(); + + for (encoding, page) in pages { + decoder.set_data(encoding, page, 4, Some(4)).unwrap(); + assert_eq!(decoder.read(&mut output, 0..1024).unwrap(), 4); + } + let array = output.into_array(None, &data_type).unwrap(); + assert_eq!(array.data_type(), &data_type); + + let array = cast(&array, &ArrowType::Utf8).unwrap(); + let strings = array.as_any().downcast_ref::().unwrap(); + assert_eq!(strings.len(), data.len() * num_encodings); + + // Should have a copy of `data` for each encoding + for i in 0..num_encodings { + assert_eq!( + strings + .iter() + .skip(i * data.len()) + .take(data.len()) + .map(|x| x.unwrap()) + .collect::>(), + data + ) + } + } + + #[test] + fn test_too_large_dictionary() { + let data: Vec<_> = (0..128) + .map(|x| ByteArray::from(x.to_string().as_str())) + .collect(); + let (dictionary, _) = encode_dictionary(&data); + + let column_desc = utf8_column(); + + let mut decoder = DictionaryDecoder::::new(&column_desc); + let err = decoder + .set_dict(dictionary.clone(), 128, Encoding::RLE_DICTIONARY, false) + .unwrap_err() + .to_string(); + + assert!(err.contains("dictionary too large for index type")); + + let mut decoder = DictionaryDecoder::::new(&column_desc); + decoder + .set_dict(dictionary, 128, Encoding::RLE_DICTIONARY, false) + .unwrap(); + } + + #[test] + fn test_nulls() { + let data_type = utf8_dictionary(); + let (pages, encoded_dictionary) = byte_array_all_encodings(Vec::<&str>::new()); + + let column_desc = utf8_column(); + let mut decoder = DictionaryDecoder::new(&column_desc); + + decoder + .set_dict(encoded_dictionary, 4, Encoding::PLAIN_DICTIONARY, false) + .unwrap(); + + for (encoding, page) in pages { + let mut output = DictionaryBuffer::::default(); + decoder.set_data(encoding, page, 8, None).unwrap(); + assert_eq!(decoder.read(&mut output, 0..1024).unwrap(), 0); + + output.pad_nulls(0, 0, 8, &[0]); + let array = output + .into_array(Some(Buffer::from(&[0])), &data_type) + .unwrap(); + + assert_eq!(array.len(), 8); + assert_eq!(array.null_count(), 8); + } + } +} diff --git a/parquet/src/arrow/array_reader/empty_array.rs b/parquet/src/arrow/array_reader/empty_array.rs new file mode 100644 index 000000000000..54b77becba04 --- /dev/null +++ 
b/parquet/src/arrow/array_reader/empty_array.rs @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::arrow::array_reader::ArrayReader; +use crate::errors::Result; +use arrow::array::{ArrayDataBuilder, ArrayRef, StructArray}; +use arrow::datatypes::DataType as ArrowType; +use std::any::Any; +use std::sync::Arc; + +/// Returns an [`ArrayReader`] that yields [`StructArray`] with no columns +/// but with row counts that correspond to the amount of data in the file +/// +/// This is useful for when projection eliminates all columns within a collection +pub fn make_empty_array_reader(row_count: usize) -> Box { + Box::new(EmptyArrayReader::new(row_count)) +} + +struct EmptyArrayReader { + data_type: ArrowType, + remaining_rows: usize, +} + +impl EmptyArrayReader { + pub fn new(row_count: usize) -> Self { + Self { + data_type: ArrowType::Struct(vec![]), + remaining_rows: row_count, + } + } +} + +impl ArrayReader for EmptyArrayReader { + fn as_any(&self) -> &dyn Any { + self + } + + fn get_data_type(&self) -> &ArrowType { + &self.data_type + } + + fn next_batch(&mut self, batch_size: usize) -> Result { + let len = self.remaining_rows.min(batch_size); + self.remaining_rows -= len; + + let data = ArrayDataBuilder::new(self.data_type.clone()) + .len(len) + .build() + .unwrap(); + + Ok(Arc::new(StructArray::from(data))) + } + + fn get_def_levels(&self) -> Option<&[i16]> { + None + } + + fn get_rep_levels(&self) -> Option<&[i16]> { + None + } +} diff --git a/parquet/src/arrow/array_reader/list_array.rs b/parquet/src/arrow/array_reader/list_array.rs new file mode 100644 index 000000000000..ab51cd87de1e --- /dev/null +++ b/parquet/src/arrow/array_reader/list_array.rs @@ -0,0 +1,612 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
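+//! [`ListArrayReader`] reconstructs Arrow `List`/`LargeList` arrays from Parquet
+//! repeated fields by replaying the definition and repetition levels emitted by
+//! its child reader.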
+ +use crate::arrow::array_reader::ArrayReader; +use crate::errors::ParquetError; +use crate::errors::Result; +use arrow::array::{ + new_empty_array, Array, ArrayData, ArrayRef, BooleanBufferBuilder, GenericListArray, + MutableArrayData, OffsetSizeTrait, +}; +use arrow::buffer::Buffer; +use arrow::datatypes::DataType as ArrowType; +use arrow::datatypes::ToByteSlice; +use std::any::Any; +use std::cmp::Ordering; +use std::marker::PhantomData; +use std::sync::Arc; + +/// Implementation of list array reader. +pub struct ListArrayReader { + item_reader: Box, + data_type: ArrowType, + item_type: ArrowType, + /// The definition level at which this list is not null + def_level: i16, + /// The repetition level that corresponds to a new value in this array + rep_level: i16, + /// If this list is nullable + nullable: bool, + _marker: PhantomData, +} + +impl ListArrayReader { + /// Construct list array reader. + pub fn new( + item_reader: Box, + data_type: ArrowType, + item_type: ArrowType, + def_level: i16, + rep_level: i16, + nullable: bool, + ) -> Self { + Self { + item_reader, + data_type, + item_type, + def_level, + rep_level, + nullable, + _marker: PhantomData, + } + } +} + +/// Implementation of ListArrayReader. Nested lists and lists of structs are not yet supported. +impl ArrayReader for ListArrayReader { + fn as_any(&self) -> &dyn Any { + self + } + + /// Returns data type. + /// This must be a List. + fn get_data_type(&self) -> &ArrowType { + &self.data_type + } + + fn next_batch(&mut self, batch_size: usize) -> Result { + let next_batch_array = self.item_reader.next_batch(batch_size)?; + + if next_batch_array.len() == 0 { + return Ok(new_empty_array(&self.data_type)); + } + + let def_levels = self + .item_reader + .get_def_levels() + .ok_or_else(|| general_err!("item_reader def levels are None."))?; + + let rep_levels = self + .item_reader + .get_rep_levels() + .ok_or_else(|| general_err!("item_reader rep levels are None."))?; + + if OffsetSize::from_usize(next_batch_array.len()).is_none() { + return Err(general_err!( + "offset of {} would overflow list array", + next_batch_array.len() + )); + } + + if !rep_levels.is_empty() && rep_levels[0] != 0 { + // This implies either the source data was invalid, or the leaf column + // reader did not correctly delimit semantic records + return Err(general_err!("first repetition level of batch must be 0")); + } + + // A non-nullable list has a single definition level indicating if the list is empty + // + // A nullable list has two definition levels associated with it: + // + // The first identifies if the list is null + // The second identifies if the list is empty + // + // The child data returned above is padded with a value for each not-fully defined level. + // Therefore null and empty lists will correspond to a value in the child array. + // + // Whilst nulls may have a non-zero slice in the offsets array, empty lists must + // be of zero length. As a result we MUST filter out values corresponding to empty + // lists, and for consistency we do the same for nulls. 
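+ // For example, for a nullable list of nullable Int32 (def_level = 2,
+ // rep_level = 1) the rows `[[1, null, 2], null, []]` arrive from the child
+ // reader as five slots:
+ //   values:     1, null, 2, <pad>, <pad>
+ //   def_levels: 3, 2, 3, 0, 1
+ //   rep_levels: 0, 1, 1, 0, 0
+ // The two padded slots (the null list and the empty list) are filtered out
+ // below, leaving a child array of length 3, offsets [0, 3, 3, 3] and
+ // validity [true, false, true].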
+ + // The output offsets for the computed ListArray + let mut list_offsets: Vec = + Vec::with_capacity(next_batch_array.len()); + + // The validity mask of the computed ListArray if nullable + let mut validity = self + .nullable + .then(|| BooleanBufferBuilder::new(next_batch_array.len())); + + // The offset into the filtered child data of the current level being considered + let mut cur_offset = 0; + + // Identifies the start of a run of values to copy from the source child data + let mut filter_start = None; + + // The number of child values skipped due to empty lists or nulls + let mut skipped = 0; + + // Builder used to construct the filtered child data, skipping empty lists and nulls + let mut child_data_builder = MutableArrayData::new( + vec![next_batch_array.data()], + false, + next_batch_array.len(), + ); + + def_levels.iter().zip(rep_levels).try_for_each(|(d, r)| { + match r.cmp(&self.rep_level) { + Ordering::Greater => { + // Repetition level greater than current => already handled by inner array + if *d < self.def_level { + return Err(general_err!( + "Encountered repetition level too large for definition level" + )); + } + } + Ordering::Equal => { + // New value in the current list + cur_offset += 1; + } + Ordering::Less => { + // Create new array slice + // Already checked that this cannot overflow + list_offsets.push(OffsetSize::from_usize(cur_offset).unwrap()); + + if *d >= self.def_level { + // Fully defined value + + // Record current offset if it is None + filter_start.get_or_insert(cur_offset + skipped); + + cur_offset += 1; + + if let Some(validity) = validity.as_mut() { + validity.append(true) + } + } else { + // Flush the current slice of child values if any + if let Some(start) = filter_start.take() { + child_data_builder.extend(0, start, cur_offset + skipped); + } + + if let Some(validity) = validity.as_mut() { + // Valid if empty list + validity.append(*d + 1 == self.def_level) + } + + skipped += 1; + } + } + } + Ok(()) + })?; + + list_offsets.push(OffsetSize::from_usize(cur_offset).unwrap()); + + let child_data = if skipped == 0 { + // No filtered values - can reuse original array + next_batch_array.data().clone() + } else { + // One or more filtered values - must build new array + if let Some(start) = filter_start.take() { + child_data_builder.extend(0, start, cur_offset + skipped) + } + + child_data_builder.freeze() + }; + + if cur_offset != child_data.len() { + return Err(general_err!("Failed to reconstruct list from level data")); + } + + let value_offsets = Buffer::from(&list_offsets.to_byte_slice()); + + let mut data_builder = ArrayData::builder(self.get_data_type().clone()) + .len(list_offsets.len() - 1) + .add_buffer(value_offsets) + .add_child_data(child_data); + + if let Some(mut builder) = validity { + assert_eq!(builder.len(), list_offsets.len() - 1); + data_builder = data_builder.null_bit_buffer(Some(builder.finish())) + } + + let list_data = unsafe { data_builder.build_unchecked() }; + + let result_array = GenericListArray::::from(list_data); + Ok(Arc::new(result_array)) + } + + fn get_def_levels(&self) -> Option<&[i16]> { + self.item_reader.get_def_levels() + } + + fn get_rep_levels(&self) -> Option<&[i16]> { + self.item_reader.get_rep_levels() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::arrow::array_reader::build_array_reader; + use crate::arrow::array_reader::list_array::ListArrayReader; + use crate::arrow::array_reader::test_util::InMemoryArrayReader; + use crate::arrow::{parquet_to_arrow_schema, ArrowWriter, ProjectionMask}; 
+ use crate::file::properties::WriterProperties; + use crate::file::reader::{FileReader, SerializedFileReader}; + use crate::schema::parser::parse_message_type; + use crate::schema::types::SchemaDescriptor; + use arrow::array::{Array, ArrayDataBuilder, PrimitiveArray}; + use arrow::datatypes::{Field, Int32Type as ArrowInt32, Int32Type}; + use std::sync::Arc; + + fn list_type( + data_type: ArrowType, + item_nullable: bool, + ) -> ArrowType { + let field = Box::new(Field::new("item", data_type, item_nullable)); + match OffsetSize::IS_LARGE { + true => ArrowType::LargeList(field), + false => ArrowType::List(field), + } + } + + fn downcast( + array: &ArrayRef, + ) -> &'_ GenericListArray { + array + .as_any() + .downcast_ref::>() + .unwrap() + } + + fn to_offsets(values: Vec) -> Buffer { + Buffer::from_iter( + values + .into_iter() + .map(|x| OffsetSize::from_usize(x).unwrap()), + ) + } + + fn test_nested_list() { + // 3 lists, with first and third nullable + // [ + // [ + // [[1, null], null, [4], []], + // [], + // [[7]], + // [[]], + // [[1, 2, 3], [4, null, 6], null] + // ], + // null, + // [], + // [[[11]]] + // ] + + let l3_item_type = ArrowType::Int32; + let l3_type = list_type::(l3_item_type.clone(), true); + + let l2_item_type = l3_type.clone(); + let l2_type = list_type::(l2_item_type.clone(), true); + + let l1_item_type = l2_type.clone(); + let l1_type = list_type::(l1_item_type.clone(), false); + + let leaf = PrimitiveArray::::from_iter(vec![ + Some(1), + None, + Some(4), + Some(7), + Some(1), + Some(2), + Some(3), + Some(4), + None, + Some(6), + Some(11), + ]); + + // [[1, null], null, [4], [], [7], [], [1, 2, 3], [4, null, 6], null, [11]] + let offsets = to_offsets::(vec![0, 2, 2, 3, 3, 4, 4, 7, 10, 10, 11]); + let l3 = ArrayDataBuilder::new(l3_type.clone()) + .len(10) + .add_buffer(offsets) + .add_child_data(leaf.data().clone()) + .null_bit_buffer(Some(Buffer::from([0b11111101, 0b00000010]))) + .build() + .unwrap(); + + // [[[1, null], null, [4], []], [], [[7]], [[]], [[1, 2, 3], [4, null, 6], null], [[11]]] + let offsets = to_offsets::(vec![0, 4, 4, 5, 6, 9, 10]); + let l2 = ArrayDataBuilder::new(l2_type.clone()) + .len(6) + .add_buffer(offsets) + .add_child_data(l3) + .build() + .unwrap(); + + let offsets = to_offsets::(vec![0, 5, 5, 5, 6]); + let l1 = ArrayDataBuilder::new(l1_type.clone()) + .len(4) + .add_buffer(offsets) + .add_child_data(l2) + .null_bit_buffer(Some(Buffer::from([0b00001101]))) + .build() + .unwrap(); + + let expected = GenericListArray::::from(l1); + + let values = Arc::new(PrimitiveArray::::from(vec![ + Some(1), + None, + None, + Some(4), + None, + None, + Some(7), + None, + Some(1), + Some(2), + Some(3), + Some(4), + None, + Some(6), + None, + None, + None, + Some(11), + ])); + + let item_array_reader = InMemoryArrayReader::new( + ArrowType::Int32, + values, + Some(vec![6, 5, 3, 6, 4, 2, 6, 4, 6, 6, 6, 6, 5, 6, 3, 0, 1, 6]), + Some(vec![0, 3, 2, 2, 2, 1, 1, 1, 1, 3, 3, 2, 3, 3, 2, 0, 0, 0]), + ); + + let l3 = ListArrayReader::::new( + Box::new(item_array_reader), + l3_type, + l3_item_type, + 5, + 3, + true, + ); + + let l2 = ListArrayReader::::new( + Box::new(l3), + l2_type, + l2_item_type, + 3, + 2, + false, + ); + + let mut l1 = ListArrayReader::::new( + Box::new(l2), + l1_type, + l1_item_type, + 2, + 1, + true, + ); + + let expected_1 = expected.slice(0, 2); + let expected_2 = expected.slice(2, 2); + + let actual = l1.next_batch(2).unwrap(); + assert_eq!(expected_1.as_ref(), actual.as_ref()); + + let actual = l1.next_batch(1024).unwrap(); + 
assert_eq!(expected_2.as_ref(), actual.as_ref()); + } + + fn test_required_list() { + // [[1, null, 2], [], [3, 4], [], [], [null, 1]] + let expected = + GenericListArray::::from_iter_primitive::(vec![ + Some(vec![Some(1), None, Some(2)]), + Some(vec![]), + Some(vec![Some(3), Some(4)]), + Some(vec![]), + Some(vec![]), + Some(vec![None, Some(1)]), + ]); + + let array = Arc::new(PrimitiveArray::::from(vec![ + Some(1), + None, + Some(2), + None, + Some(3), + Some(4), + None, + None, + None, + Some(1), + ])); + + let item_array_reader = InMemoryArrayReader::new( + ArrowType::Int32, + array, + Some(vec![2, 1, 2, 0, 2, 2, 0, 0, 1, 2]), + Some(vec![0, 1, 1, 0, 0, 1, 0, 0, 0, 1]), + ); + + let mut list_array_reader = ListArrayReader::::new( + Box::new(item_array_reader), + list_type::(ArrowType::Int32, true), + ArrowType::Int32, + 1, + 1, + false, + ); + + let actual = list_array_reader.next_batch(1024).unwrap(); + let actual = downcast::(&actual); + + assert_eq!(&expected, actual) + } + + fn test_nullable_list() { + // [[1, null, 2], null, [], [3, 4], [], [], null, [], [null, 1]] + let expected = + GenericListArray::::from_iter_primitive::(vec![ + Some(vec![Some(1), None, Some(2)]), + None, + Some(vec![]), + Some(vec![Some(3), Some(4)]), + Some(vec![]), + Some(vec![]), + None, + Some(vec![]), + Some(vec![None, Some(1)]), + ]); + + let array = Arc::new(PrimitiveArray::::from(vec![ + Some(1), + None, + Some(2), + None, + None, + Some(3), + Some(4), + None, + None, + None, + None, + None, + Some(1), + ])); + + let item_array_reader = InMemoryArrayReader::new( + ArrowType::Int32, + array, + Some(vec![3, 2, 3, 0, 1, 3, 3, 1, 1, 0, 1, 2, 3]), + Some(vec![0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]), + ); + + let mut list_array_reader = ListArrayReader::::new( + Box::new(item_array_reader), + list_type::(ArrowType::Int32, true), + ArrowType::Int32, + 2, + 1, + true, + ); + + let actual = list_array_reader.next_batch(1024).unwrap(); + let actual = downcast::(&actual); + + assert_eq!(&expected, actual) + } + + fn test_list_array() { + test_nullable_list::(); + test_required_list::(); + test_nested_list::(); + } + + #[test] + fn test_list_array_reader() { + test_list_array::(); + } + + #[test] + fn test_large_list_array_reader() { + test_list_array::() + } + + #[test] + fn test_nested_lists() { + // Construct column schema + let message_type = " + message table { + REPEATED group table_info { + REQUIRED BYTE_ARRAY name; + REPEATED group cols { + REQUIRED BYTE_ARRAY name; + REQUIRED INT32 type; + OPTIONAL INT32 length; + } + REPEATED group tags { + REQUIRED BYTE_ARRAY name; + REQUIRED INT32 type; + OPTIONAL INT32 length; + } + } + } + "; + + let schema = parse_message_type(message_type) + .map(|t| Arc::new(SchemaDescriptor::new(Arc::new(t)))) + .unwrap(); + + let arrow_schema = parquet_to_arrow_schema(schema.as_ref(), None).unwrap(); + + let file = tempfile::tempfile().unwrap(); + let props = WriterProperties::builder() + .set_max_row_group_size(200) + .build(); + + let writer = ArrowWriter::try_new( + file.try_clone().unwrap(), + Arc::new(arrow_schema), + Some(props), + ) + .unwrap(); + writer.close().unwrap(); + + let file_reader: Arc = + Arc::new(SerializedFileReader::new(file).unwrap()); + + let file_metadata = file_reader.metadata().file_metadata(); + let arrow_schema = parquet_to_arrow_schema( + file_metadata.schema_descr(), + file_metadata.key_value_metadata(), + ) + .unwrap(); + + let schema = file_metadata.schema_descr_ptr(); + let mask = ProjectionMask::leaves(&schema, vec![0]); + + let mut 
array_reader = build_array_reader( + schema, + Arc::new(arrow_schema), + mask, + Box::new(file_reader), + ) + .unwrap(); + + let batch = array_reader.next_batch(100).unwrap(); + assert_eq!(batch.data_type(), array_reader.get_data_type()); + assert_eq!( + batch.data_type(), + &ArrowType::Struct(vec![Field::new( + "table_info", + ArrowType::List(Box::new(Field::new( + "table_info", + ArrowType::Struct(vec![Field::new("name", ArrowType::Binary, false)]), + false + ))), + false + )]) + ); + assert_eq!(batch.len(), 0); + } +} diff --git a/parquet/src/arrow/array_reader/map_array.rs b/parquet/src/arrow/array_reader/map_array.rs new file mode 100644 index 000000000000..efeafe201497 --- /dev/null +++ b/parquet/src/arrow/array_reader/map_array.rs @@ -0,0 +1,168 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::arrow::array_reader::ArrayReader; +use crate::errors::ParquetError::ArrowError; +use crate::errors::{ParquetError, Result}; +use arrow::array::{ArrayDataBuilder, ArrayRef, MapArray}; +use arrow::buffer::{Buffer, MutableBuffer}; +use arrow::datatypes::DataType as ArrowType; +use arrow::datatypes::ToByteSlice; +use arrow::util::bit_util; +use std::any::Any; +use std::sync::Arc; + +/// Implementation of a map array reader. +pub struct MapArrayReader { + key_reader: Box, + value_reader: Box, + data_type: ArrowType, + map_def_level: i16, + map_rep_level: i16, +} + +impl MapArrayReader { + pub fn new( + key_reader: Box, + value_reader: Box, + data_type: ArrowType, + def_level: i16, + rep_level: i16, + ) -> Self { + Self { + key_reader, + value_reader, + data_type, + map_def_level: rep_level, + map_rep_level: def_level, + } + } +} + +impl ArrayReader for MapArrayReader { + fn as_any(&self) -> &dyn Any { + self + } + + fn get_data_type(&self) -> &ArrowType { + &self.data_type + } + + fn next_batch(&mut self, batch_size: usize) -> Result { + let key_array = self.key_reader.next_batch(batch_size)?; + let value_array = self.value_reader.next_batch(batch_size)?; + + // Check that key and value have the same lengths + let key_length = key_array.len(); + if key_length != value_array.len() { + return Err(general_err!( + "Map key and value should have the same lengths." 
+ )); + } + + let def_levels = self + .key_reader + .get_def_levels() + .ok_or_else(|| ArrowError("item_reader def levels are None.".to_string()))?; + let rep_levels = self + .key_reader + .get_rep_levels() + .ok_or_else(|| ArrowError("item_reader rep levels are None.".to_string()))?; + + if !((def_levels.len() == rep_levels.len()) && (rep_levels.len() == key_length)) { + return Err(ArrowError( + "Expected item_reader def_levels and rep_levels to be same length as batch".to_string(), + )); + } + + let entry_data_type = if let ArrowType::Map(field, _) = &self.data_type { + field.data_type().clone() + } else { + return Err(ArrowError("Expected a map arrow type".to_string())); + }; + + let entry_data = ArrayDataBuilder::new(entry_data_type) + .len(key_length) + .add_child_data(key_array.data().clone()) + .add_child_data(value_array.data().clone()); + let entry_data = unsafe { entry_data.build_unchecked() }; + + let entry_len = rep_levels.iter().filter(|level| **level == 0).count(); + + // first item in each list has rep_level = 0, subsequent items have rep_level = 1 + let mut offsets: Vec = Vec::new(); + let mut cur_offset = 0; + def_levels.iter().zip(rep_levels).for_each(|(d, r)| { + if *r == 0 || d == &self.map_def_level { + offsets.push(cur_offset); + } + if d > &self.map_def_level { + cur_offset += 1; + } + }); + offsets.push(cur_offset); + + let num_bytes = bit_util::ceil(offsets.len(), 8); + // TODO: A useful optimization is to use the null count to fill with + // 0 or null, to reduce individual bits set in a loop. + // To favour dense data, set every slot to true, then unset + let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true); + let null_slice = null_buf.as_slice_mut(); + let mut list_index = 0; + for i in 0..rep_levels.len() { + // If the level is lower than empty, then the slot is null. + // When a list is non-nullable, its empty level = null level, + // so this automatically factors that in. + if rep_levels[i] == 0 && def_levels[i] < self.map_def_level { + // should be empty list + bit_util::unset_bit(null_slice, list_index); + } + if rep_levels[i] == 0 { + list_index += 1; + } + } + let value_offsets = Buffer::from(&offsets.to_byte_slice()); + + // Now we can build array data + let array_data = ArrayDataBuilder::new(self.data_type.clone()) + .len(entry_len) + .add_buffer(value_offsets) + .null_bit_buffer(Some(null_buf.into())) + .add_child_data(entry_data); + + let array_data = unsafe { array_data.build_unchecked() }; + + Ok(Arc::new(MapArray::from(array_data))) + } + + fn get_def_levels(&self) -> Option<&[i16]> { + // Children definition levels should describe the same parent structure, + // so return key_reader only + self.key_reader.get_def_levels() + } + + fn get_rep_levels(&self) -> Option<&[i16]> { + // Children repetition levels should describe the same parent structure, + // so return key_reader only + self.key_reader.get_rep_levels() + } +} + +#[cfg(test)] +mod tests { + //TODO: Add unit tests (#1561) +} diff --git a/parquet/src/arrow/array_reader/mod.rs b/parquet/src/arrow/array_reader/mod.rs new file mode 100644 index 000000000000..6207b377d137 --- /dev/null +++ b/parquet/src/arrow/array_reader/mod.rs @@ -0,0 +1,1610 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Logic for reading into arrow arrays + +use std::any::Any; +use std::cmp::max; +use std::marker::PhantomData; +use std::result::Result::Ok; +use std::sync::Arc; +use std::vec::Vec; + +use arrow::array::{ + Array, ArrayData, ArrayDataBuilder, ArrayRef, BooleanArray, BooleanBufferBuilder, + DecimalArray, Int32Array, Int64Array, PrimitiveArray, StructArray, +}; +use arrow::buffer::Buffer; +use arrow::datatypes::{ + ArrowPrimitiveType, BooleanType as ArrowBooleanType, DataType as ArrowType, + Float32Type as ArrowFloat32Type, Float64Type as ArrowFloat64Type, + Int32Type as ArrowInt32Type, Int64Type as ArrowInt64Type, + UInt32Type as ArrowUInt32Type, UInt64Type as ArrowUInt64Type, +}; + +use crate::arrow::buffer::converter::Converter; +use crate::arrow::record_reader::buffer::{ScalarValue, ValuesBuffer}; +use crate::arrow::record_reader::{GenericRecordReader, RecordReader}; +use crate::arrow::schema::parquet_to_arrow_field; +use crate::basic::Type as PhysicalType; +use crate::column::page::PageIterator; +use crate::column::reader::decoder::ColumnValueDecoder; +use crate::column::reader::ColumnReaderImpl; +use crate::data_type::DataType; +use crate::errors::{ParquetError, ParquetError::ArrowError, Result}; +use crate::file::reader::{FilePageIterator, FileReader}; +use crate::schema::types::{ColumnDescPtr, SchemaDescPtr}; + +mod builder; +mod byte_array; +mod byte_array_dictionary; +mod empty_array; +mod list_array; +mod map_array; + +#[cfg(test)] +mod test_util; + +pub use builder::build_array_reader; +pub use byte_array::make_byte_array_reader; +pub use byte_array_dictionary::make_byte_array_dictionary_reader; +pub use list_array::ListArrayReader; +pub use map_array::MapArrayReader; + +/// Array reader reads parquet data into arrow array. +pub trait ArrayReader: Send { + fn as_any(&self) -> &dyn Any; + + /// Returns the arrow type of this array reader. + fn get_data_type(&self) -> &ArrowType; + + /// Reads at most `batch_size` records into an arrow array and return it. + fn next_batch(&mut self, batch_size: usize) -> Result; + + /// If this array has a non-zero definition level, i.e. has a nullable parent + /// array, returns the definition levels of data from the last call of `next_batch` + /// + /// Otherwise returns None + /// + /// This is used by parent [`ArrayReader`] to compute their null bitmaps + fn get_def_levels(&self) -> Option<&[i16]>; + + /// If this array has a non-zero repetition level, i.e. has a repeated parent + /// array, returns the repetition levels of data from the last call of `next_batch` + /// + /// Otherwise returns None + /// + /// This is used by parent [`ArrayReader`] to compute their array offsets + fn get_rep_levels(&self) -> Option<&[i16]>; +} + +/// A collection of row groups +pub trait RowGroupCollection { + /// Get schema of parquet file. 
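A hedged usage sketch of the ArrayReader trait defined above, assuming it sits inside this module so that ArrayReader, ArrayRef and Result are already in scope: drain a reader by calling next_batch until it yields a zero-length array, which is how the tests later in this file detect exhaustion. The helper name and batch size are illustrative only.

fn drain_reader(reader: &mut dyn ArrayReader, batch_size: usize) -> Result<Vec<ArrayRef>> {
    let mut batches = Vec::new();
    loop {
        let batch = reader.next_batch(batch_size)?;
        if batch.is_empty() {
            // an exhausted reader returns a zero-length array
            break;
        }
        batches.push(batch);
    }
    Ok(batches)
}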
+ fn schema(&self) -> Result; + + /// Get the numer of rows in this collection + fn num_rows(&self) -> usize; + + /// Returns an iterator over the column chunks for particular column + fn column_chunks(&self, i: usize) -> Result>; +} + +impl RowGroupCollection for Arc { + fn schema(&self) -> Result { + Ok(self.metadata().file_metadata().schema_descr_ptr()) + } + + fn num_rows(&self) -> usize { + self.metadata().file_metadata().num_rows() as usize + } + + fn column_chunks(&self, column_index: usize) -> Result> { + let iterator = FilePageIterator::new(column_index, Arc::clone(self))?; + Ok(Box::new(iterator)) + } +} + +/// Uses `record_reader` to read up to `batch_size` records from `pages` +/// +/// Returns the number of records read, which can be less than batch_size if +/// pages is exhausted. +fn read_records( + record_reader: &mut GenericRecordReader, + pages: &mut dyn PageIterator, + batch_size: usize, +) -> Result +where + V: ValuesBuffer + Default, + CV: ColumnValueDecoder, +{ + let mut records_read = 0usize; + while records_read < batch_size { + let records_to_read = batch_size - records_read; + + let records_read_once = record_reader.read_records(records_to_read)?; + records_read += records_read_once; + + // Record reader exhausted + if records_read_once < records_to_read { + if let Some(page_reader) = pages.next() { + // Read from new page reader (i.e. column chunk) + record_reader.set_page_reader(page_reader?)?; + } else { + // Page reader also exhausted + break; + } + } + } + Ok(records_read) +} + +/// A NullArrayReader reads Parquet columns stored as null int32s with an Arrow +/// NullArray type. +pub struct NullArrayReader +where + T: DataType, + T::T: ScalarValue, +{ + data_type: ArrowType, + pages: Box, + def_levels_buffer: Option, + rep_levels_buffer: Option, + column_desc: ColumnDescPtr, + record_reader: RecordReader, + _type_marker: PhantomData, +} + +impl NullArrayReader +where + T: DataType, + T::T: ScalarValue, +{ + /// Construct null array reader. + pub fn new(pages: Box, column_desc: ColumnDescPtr) -> Result { + let record_reader = RecordReader::::new(column_desc.clone()); + + Ok(Self { + data_type: ArrowType::Null, + pages, + def_levels_buffer: None, + rep_levels_buffer: None, + column_desc, + record_reader, + _type_marker: PhantomData, + }) + } +} + +/// Implementation of primitive array reader. +impl ArrayReader for NullArrayReader +where + T: DataType, + T::T: ScalarValue, +{ + fn as_any(&self) -> &dyn Any { + self + } + + /// Returns data type of primitive array. + fn get_data_type(&self) -> &ArrowType { + &self.data_type + } + + /// Reads at most `batch_size` records into array. + fn next_batch(&mut self, batch_size: usize) -> Result { + read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)?; + + // convert to arrays + let array = arrow::array::NullArray::new(self.record_reader.num_values()); + + // save definition and repetition buffers + self.def_levels_buffer = self.record_reader.consume_def_levels()?; + self.rep_levels_buffer = self.record_reader.consume_rep_levels()?; + + // Must consume bitmap buffer + self.record_reader.consume_bitmap_buffer()?; + + self.record_reader.reset(); + Ok(Arc::new(array)) + } + + fn get_def_levels(&self) -> Option<&[i16]> { + self.def_levels_buffer + .as_ref() + .map(|buf| buf.typed_data()) + } + + fn get_rep_levels(&self) -> Option<&[i16]> { + self.rep_levels_buffer + .as_ref() + .map(|buf| buf.typed_data()) + } +} + +/// Primitive array readers are leaves of array reader tree. 
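A self-contained sketch of the control flow in read_records above, with plain Vecs standing in for the record reader and page iterator (the names and data are illustrative): drain the current source, move to the next one when it comes up short, and stop early once every source is exhausted.

fn read_up_to(mut sources: Vec<Vec<i32>>, batch_size: usize) -> Vec<i32> {
    let mut out = Vec::new();
    while out.len() < batch_size {
        if sources.is_empty() {
            break; // all sources (column chunks) exhausted
        }
        let want = batch_size - out.len();
        let take = want.min(sources[0].len());
        out.extend(sources[0].drain(..take));
        if take < want {
            // the current source came up short, move on to the next one
            sources.remove(0);
        }
    }
    out
}

fn main() {
    assert_eq!(read_up_to(vec![vec![1, 2], vec![3, 4, 5]], 4), vec![1, 2, 3, 4]);
}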
They accept page iterator +/// and read them into primitive arrays. +pub struct PrimitiveArrayReader +where + T: DataType, + T::T: ScalarValue, +{ + data_type: ArrowType, + pages: Box, + def_levels_buffer: Option, + rep_levels_buffer: Option, + column_desc: ColumnDescPtr, + record_reader: RecordReader, +} + +impl PrimitiveArrayReader +where + T: DataType, + T::T: ScalarValue, +{ + /// Construct primitive array reader. + pub fn new( + pages: Box, + column_desc: ColumnDescPtr, + arrow_type: Option, + ) -> Result { + Self::new_with_options(pages, column_desc, arrow_type, false) + } + + /// Construct primitive array reader with ability to only compute null mask and not + /// buffer level data + pub fn new_with_options( + pages: Box, + column_desc: ColumnDescPtr, + arrow_type: Option, + null_mask_only: bool, + ) -> Result { + // Check if Arrow type is specified, else create it from Parquet type + let data_type = match arrow_type { + Some(t) => t, + None => parquet_to_arrow_field(column_desc.as_ref())? + .data_type() + .clone(), + }; + + let record_reader = + RecordReader::::new_with_options(column_desc.clone(), null_mask_only); + + Ok(Self { + data_type, + pages, + def_levels_buffer: None, + rep_levels_buffer: None, + column_desc, + record_reader, + }) + } +} + +/// Implementation of primitive array reader. +impl ArrayReader for PrimitiveArrayReader +where + T: DataType, + T::T: ScalarValue, +{ + fn as_any(&self) -> &dyn Any { + self + } + + /// Returns data type of primitive array. + fn get_data_type(&self) -> &ArrowType { + &self.data_type + } + + /// Reads at most `batch_size` records into array. + fn next_batch(&mut self, batch_size: usize) -> Result { + read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)?; + + let target_type = self.get_data_type().clone(); + let arrow_data_type = match T::get_physical_type() { + PhysicalType::BOOLEAN => ArrowBooleanType::DATA_TYPE, + PhysicalType::INT32 => { + match target_type { + ArrowType::UInt32 => { + // follow C++ implementation and use overflow/reinterpret cast from i32 to u32 which will map + // `i32::MIN..0` to `(i32::MAX as u32)..u32::MAX` + ArrowUInt32Type::DATA_TYPE + } + _ => ArrowInt32Type::DATA_TYPE, + } + } + PhysicalType::INT64 => { + match target_type { + ArrowType::UInt64 => { + // follow C++ implementation and use overflow/reinterpret cast from i64 to u64 which will map + // `i64::MIN..0` to `(i64::MAX as u64)..u64::MAX` + ArrowUInt64Type::DATA_TYPE + } + _ => ArrowInt64Type::DATA_TYPE, + } + } + PhysicalType::FLOAT => ArrowFloat32Type::DATA_TYPE, + PhysicalType::DOUBLE => ArrowFloat64Type::DATA_TYPE, + PhysicalType::INT96 + | PhysicalType::BYTE_ARRAY + | PhysicalType::FIXED_LEN_BYTE_ARRAY => { + unreachable!( + "PrimitiveArrayReaders don't support complex physical types" + ); + } + }; + + // Convert to arrays by using the Parquet physical type. 
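A tiny standalone check of the "overflow/reinterpret cast" mentioned in the comments above: an as-cast from i32 to u32 preserves the bit pattern, so the negative half of the i32 range lands in the upper half of the u32 range.

fn main() {
    assert_eq!((-1i32) as u32, u32::MAX);
    assert_eq!(i32::MIN as u32, 2_147_483_648); // one past i32::MAX as u32
    assert_eq!(1i32 as u32, 1);
}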
+ // The physical types are then cast to Arrow types if necessary + + let mut record_data = self.record_reader.consume_record_data()?; + + if T::get_physical_type() == PhysicalType::BOOLEAN { + let mut boolean_buffer = BooleanBufferBuilder::new(record_data.len()); + + for e in record_data.as_slice() { + boolean_buffer.append(*e > 0); + } + record_data = boolean_buffer.finish(); + } + + let array_data = ArrayDataBuilder::new(arrow_data_type) + .len(self.record_reader.num_values()) + .add_buffer(record_data) + .null_bit_buffer(self.record_reader.consume_bitmap_buffer()?); + + let array_data = unsafe { array_data.build_unchecked() }; + let array = match T::get_physical_type() { + PhysicalType::BOOLEAN => Arc::new(BooleanArray::from(array_data)) as ArrayRef, + PhysicalType::INT32 => { + Arc::new(PrimitiveArray::::from(array_data)) as ArrayRef + } + PhysicalType::INT64 => { + Arc::new(PrimitiveArray::::from(array_data)) as ArrayRef + } + PhysicalType::FLOAT => { + Arc::new(PrimitiveArray::::from(array_data)) as ArrayRef + } + PhysicalType::DOUBLE => { + Arc::new(PrimitiveArray::::from(array_data)) as ArrayRef + } + PhysicalType::INT96 + | PhysicalType::BYTE_ARRAY + | PhysicalType::FIXED_LEN_BYTE_ARRAY => { + unreachable!( + "PrimitiveArrayReaders don't support complex physical types" + ); + } + }; + + // cast to Arrow type + // We make a strong assumption here that the casts should be infallible. + // If the cast fails because of incompatible datatypes, then there might + // be a bigger problem with how Arrow schemas are converted to Parquet. + // + // As there is not always a 1:1 mapping between Arrow and Parquet, there + // are datatypes which we must convert explicitly. + // These are: + // - date64: we should cast int32 to date32, then date32 to date64. + let array = match target_type { + ArrowType::Date64 => { + // this is cheap as it internally reinterprets the data + let a = arrow::compute::cast(&array, &ArrowType::Date32)?; + arrow::compute::cast(&a, &target_type)? + } + ArrowType::Decimal(p, s) => { + let array = match array.data_type() { + ArrowType::Int32 => array + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|v| v.map(|v| v.into())) + .collect::(), + + ArrowType::Int64 => array + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|v| v.map(|v| v.into())) + .collect::(), + _ => { + return Err(ArrowError(format!( + "Cannot convert {:?} to decimal", + array.data_type() + ))) + } + } + .with_precision_and_scale(p, s)?; + + Arc::new(array) as ArrayRef + } + _ => arrow::compute::cast(&array, &target_type)?, + }; + + // save definition and repetition buffers + self.def_levels_buffer = self.record_reader.consume_def_levels()?; + self.rep_levels_buffer = self.record_reader.consume_rep_levels()?; + self.record_reader.reset(); + Ok(array) + } + + fn get_def_levels(&self) -> Option<&[i16]> { + self.def_levels_buffer + .as_ref() + .map(|buf| buf.typed_data()) + } + + fn get_rep_levels(&self) -> Option<&[i16]> { + self.rep_levels_buffer + .as_ref() + .map(|buf| buf.typed_data()) + } +} + +/// Primitive array readers are leaves of array reader tree. They accept page iterator +/// and read them into primitive arrays. 
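A standalone sketch of why the Date64 branch above casts through Date32 first: the Parquet value is days since the epoch, while Date64 carries milliseconds, so the second cast is a multiply by the number of milliseconds per day. The sample value is arbitrary.

fn main() {
    let days_since_epoch: i32 = 19_000;
    let millis_since_epoch = days_since_epoch as i64 * 86_400_000;
    assert_eq!(millis_since_epoch, 1_641_600_000_000);
}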
+pub struct ComplexObjectArrayReader +where + T: DataType, + C: Converter>, ArrayRef> + 'static, +{ + data_type: ArrowType, + pages: Box, + def_levels_buffer: Option>, + rep_levels_buffer: Option>, + column_desc: ColumnDescPtr, + column_reader: Option>, + converter: C, + _parquet_type_marker: PhantomData, + _converter_marker: PhantomData, +} + +impl ArrayReader for ComplexObjectArrayReader +where + T: DataType, + C: Converter>, ArrayRef> + Send + 'static, +{ + fn as_any(&self) -> &dyn Any { + self + } + + fn get_data_type(&self) -> &ArrowType { + &self.data_type + } + + fn next_batch(&mut self, batch_size: usize) -> Result { + // Try to initialize column reader + if self.column_reader.is_none() { + self.next_column_reader()?; + } + + let mut data_buffer: Vec = Vec::with_capacity(batch_size); + data_buffer.resize_with(batch_size, T::T::default); + + let mut def_levels_buffer = if self.column_desc.max_def_level() > 0 { + let mut buf: Vec = Vec::with_capacity(batch_size); + buf.resize_with(batch_size, || 0); + Some(buf) + } else { + None + }; + + let mut rep_levels_buffer = if self.column_desc.max_rep_level() > 0 { + let mut buf: Vec = Vec::with_capacity(batch_size); + buf.resize_with(batch_size, || 0); + Some(buf) + } else { + None + }; + + let mut num_read = 0; + + while self.column_reader.is_some() && num_read < batch_size { + let num_to_read = batch_size - num_read; + let cur_data_buf = &mut data_buffer[num_read..]; + let cur_def_levels_buf = + def_levels_buffer.as_mut().map(|b| &mut b[num_read..]); + let cur_rep_levels_buf = + rep_levels_buffer.as_mut().map(|b| &mut b[num_read..]); + let (data_read, levels_read) = + self.column_reader.as_mut().unwrap().read_batch( + num_to_read, + cur_def_levels_buf, + cur_rep_levels_buf, + cur_data_buf, + )?; + + // Fill space + if levels_read > data_read { + def_levels_buffer.iter().for_each(|def_levels_buffer| { + let (mut level_pos, mut data_pos) = (levels_read, data_read); + while level_pos > 0 && data_pos > 0 { + if def_levels_buffer[num_read + level_pos - 1] + == self.column_desc.max_def_level() + { + cur_data_buf.swap(level_pos - 1, data_pos - 1); + level_pos -= 1; + data_pos -= 1; + } else { + level_pos -= 1; + } + } + }); + } + + let values_read = max(levels_read, data_read); + num_read += values_read; + // current page exhausted && page iterator exhausted + if values_read < num_to_read && !self.next_column_reader()? 
{ + break; + } + } + + data_buffer.truncate(num_read); + def_levels_buffer + .iter_mut() + .for_each(|buf| buf.truncate(num_read)); + rep_levels_buffer + .iter_mut() + .for_each(|buf| buf.truncate(num_read)); + + self.def_levels_buffer = def_levels_buffer; + self.rep_levels_buffer = rep_levels_buffer; + + let data: Vec> = if self.def_levels_buffer.is_some() { + data_buffer + .into_iter() + .zip(self.def_levels_buffer.as_ref().unwrap().iter()) + .map(|(t, def_level)| { + if *def_level == self.column_desc.max_def_level() { + Some(t) + } else { + None + } + }) + .collect() + } else { + data_buffer.into_iter().map(Some).collect() + }; + + let mut array = self.converter.convert(data)?; + + if let ArrowType::Dictionary(_, _) = self.data_type { + array = arrow::compute::cast(&array, &self.data_type)?; + } + + Ok(array) + } + + fn get_def_levels(&self) -> Option<&[i16]> { + self.def_levels_buffer.as_deref() + } + + fn get_rep_levels(&self) -> Option<&[i16]> { + self.rep_levels_buffer.as_deref() + } +} + +impl ComplexObjectArrayReader +where + T: DataType, + C: Converter>, ArrayRef> + 'static, +{ + pub fn new( + pages: Box, + column_desc: ColumnDescPtr, + converter: C, + arrow_type: Option, + ) -> Result { + let data_type = match arrow_type { + Some(t) => t, + None => parquet_to_arrow_field(column_desc.as_ref())? + .data_type() + .clone(), + }; + + Ok(Self { + data_type, + pages, + def_levels_buffer: None, + rep_levels_buffer: None, + column_desc, + column_reader: None, + converter, + _parquet_type_marker: PhantomData, + _converter_marker: PhantomData, + }) + } + + fn next_column_reader(&mut self) -> Result { + Ok(match self.pages.next() { + Some(page) => { + self.column_reader = + Some(ColumnReaderImpl::::new(self.column_desc.clone(), page?)); + true + } + None => false, + }) + } +} + +/// Implementation of struct array reader. +pub struct StructArrayReader { + children: Vec>, + data_type: ArrowType, + struct_def_level: i16, + struct_rep_level: i16, + nullable: bool, +} + +impl StructArrayReader { + /// Construct struct array reader. + pub fn new( + data_type: ArrowType, + children: Vec>, + def_level: i16, + rep_level: i16, + nullable: bool, + ) -> Self { + Self { + data_type, + children, + struct_def_level: def_level, + struct_rep_level: rep_level, + nullable, + } + } +} + +impl ArrayReader for StructArrayReader { + fn as_any(&self) -> &dyn Any { + self + } + + /// Returns data type. + /// This must be a struct. + fn get_data_type(&self) -> &ArrowType { + &self.data_type + } + + /// Read `batch_size` struct records. 
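A self-contained sketch (assumed level values, not crate code) of the null-bitmap rule that StructArrayReader::next_batch just below documents and implements: a struct slot is valid when the first child's definition level reaches struct_def_level, and positions whose repetition level exceeds struct_rep_level belong to an inner list and are skipped.

fn struct_validity(
    def_levels: &[i16],
    rep_levels: &[i16],
    struct_def: i16,
    struct_rep: i16,
) -> Vec<bool> {
    def_levels
        .iter()
        .zip(rep_levels)
        .filter(|(_, r)| **r <= struct_rep) // positions inside an inner list are skipped
        .map(|(d, _)| *d >= struct_def)
        .collect()
}

fn main() {
    // struct_def = 1, struct_rep = 0; the last position has rep > 0 and is ignored
    assert_eq!(
        struct_validity(&[0, 1, 2, 2], &[0, 0, 0, 1], 1, 0),
        vec![false, true, true]
    );
}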
+ /// + /// Definition levels of struct array is calculated as following: + /// ```ignore + /// def_levels[i] = min(child1_def_levels[i], child2_def_levels[i], ..., + /// childn_def_levels[i]); + /// ``` + /// + /// Repetition levels of struct array is calculated as following: + /// ```ignore + /// rep_levels[i] = child1_rep_levels[i]; + /// ``` + /// + /// The null bitmap of struct array is calculated from def_levels: + /// ```ignore + /// null_bitmap[i] = (def_levels[i] >= self.def_level); + /// ``` + fn next_batch(&mut self, batch_size: usize) -> Result { + if self.children.is_empty() { + return Ok(Arc::new(StructArray::from(Vec::new()))); + } + + let children_array = self + .children + .iter_mut() + .map(|reader| reader.next_batch(batch_size)) + .collect::>>()?; + + // check that array child data has same size + let children_array_len = + children_array.first().map(|arr| arr.len()).ok_or_else(|| { + general_err!("Struct array reader should have at least one child!") + })?; + + let all_children_len_eq = children_array + .iter() + .all(|arr| arr.len() == children_array_len); + if !all_children_len_eq { + return Err(general_err!("Not all children array length are the same!")); + } + + // Now we can build array data + let mut array_data_builder = ArrayDataBuilder::new(self.data_type.clone()) + .len(children_array_len) + .child_data( + children_array + .iter() + .map(|x| x.data().clone()) + .collect::>(), + ); + + if self.nullable { + // calculate struct def level data + + // children should have consistent view of parent, only need to inspect first child + let def_levels = self.children[0] + .get_def_levels() + .expect("child with nullable parents must have definition level"); + + // calculate bitmap for current array + let mut bitmap_builder = BooleanBufferBuilder::new(children_array_len); + + match self.children[0].get_rep_levels() { + Some(rep_levels) => { + // Sanity check + assert_eq!(rep_levels.len(), def_levels.len()); + + for (rep_level, def_level) in rep_levels.iter().zip(def_levels) { + if rep_level > &self.struct_rep_level { + // Already handled by inner list - SKIP + continue; + } + bitmap_builder.append(*def_level >= self.struct_def_level) + } + } + None => { + for def_level in def_levels { + bitmap_builder.append(*def_level >= self.struct_def_level) + } + } + } + + if bitmap_builder.len() != children_array_len { + return Err(general_err!("Failed to decode level data for struct array")); + } + + array_data_builder = + array_data_builder.null_bit_buffer(Some(bitmap_builder.finish())); + } + + let array_data = unsafe { array_data_builder.build_unchecked() }; + Ok(Arc::new(StructArray::from(array_data))) + } + + fn get_def_levels(&self) -> Option<&[i16]> { + // Children definition levels should describe the same + // parent structure, so return first child's + self.children.first().and_then(|l| l.get_def_levels()) + } + + fn get_rep_levels(&self) -> Option<&[i16]> { + // Children definition levels should describe the same + // parent structure, so return first child's + self.children.first().and_then(|l| l.get_rep_levels()) + } +} + +#[cfg(test)] +mod tests { + use std::collections::VecDeque; + use std::sync::Arc; + + use rand::distributions::uniform::SampleUniform; + use rand::{thread_rng, Rng}; + + use crate::arrow::array_reader::test_util::InMemoryArrayReader; + use arrow::array::{ + Array, ArrayRef, ListArray, PrimitiveArray, StringArray, StructArray, + }; + use arrow::datatypes::{ + ArrowPrimitiveType, DataType as ArrowType, Date32Type as ArrowDate32, Field, + Int32Type as 
ArrowInt32, Int64Type as ArrowInt64, + Time32MillisecondType as ArrowTime32MillisecondArray, + Time64MicrosecondType as ArrowTime64MicrosecondArray, + TimestampMicrosecondType as ArrowTimestampMicrosecondType, + TimestampMillisecondType as ArrowTimestampMillisecondType, + }; + + use crate::arrow::buffer::converter::{Utf8ArrayConverter, Utf8Converter}; + use crate::basic::{Encoding, Type as PhysicalType}; + use crate::column::page::Page; + use crate::data_type::{ByteArray, ByteArrayType, DataType, Int32Type, Int64Type}; + use crate::schema::parser::parse_message_type; + use crate::schema::types::{ColumnDescPtr, SchemaDescriptor}; + use crate::util::test_common::make_pages; + use crate::util::test_common::page_util::{ + DataPageBuilder, DataPageBuilderImpl, InMemoryPageIterator, + }; + + use super::*; + + fn make_column_chunks( + column_desc: ColumnDescPtr, + encoding: Encoding, + num_levels: usize, + min_value: T::T, + max_value: T::T, + def_levels: &mut Vec, + rep_levels: &mut Vec, + values: &mut Vec, + page_lists: &mut Vec>, + use_v2: bool, + num_chunks: usize, + ) where + T::T: PartialOrd + SampleUniform + Copy, + { + for _i in 0..num_chunks { + let mut pages = VecDeque::new(); + let mut data = Vec::new(); + let mut page_def_levels = Vec::new(); + let mut page_rep_levels = Vec::new(); + + make_pages::( + column_desc.clone(), + encoding, + 1, + num_levels, + min_value, + max_value, + &mut page_def_levels, + &mut page_rep_levels, + &mut data, + &mut pages, + use_v2, + ); + + def_levels.append(&mut page_def_levels); + rep_levels.append(&mut page_rep_levels); + values.append(&mut data); + page_lists.push(Vec::from(pages)); + } + } + + #[test] + fn test_primitive_array_reader_empty_pages() { + // Construct column schema + let message_type = " + message test_schema { + REQUIRED INT32 leaf; + } + "; + + let schema = parse_message_type(message_type) + .map(|t| Arc::new(SchemaDescriptor::new(Arc::new(t)))) + .unwrap(); + + let column_desc = schema.column(0); + let page_iterator = test_util::EmptyPageIterator::new(schema); + + let mut array_reader = PrimitiveArrayReader::::new( + Box::new(page_iterator), + column_desc, + None, + ) + .unwrap(); + + // expect no values to be read + let array = array_reader.next_batch(50).unwrap(); + assert!(array.is_empty()); + } + + #[test] + fn test_primitive_array_reader_data() { + // Construct column schema + let message_type = " + message test_schema { + REQUIRED INT32 leaf; + } + "; + + let schema = parse_message_type(message_type) + .map(|t| Arc::new(SchemaDescriptor::new(Arc::new(t)))) + .unwrap(); + + let column_desc = schema.column(0); + + // Construct page iterator + { + let mut data = Vec::new(); + let mut page_lists = Vec::new(); + make_column_chunks::( + column_desc.clone(), + Encoding::PLAIN, + 100, + 1, + 200, + &mut Vec::new(), + &mut Vec::new(), + &mut data, + &mut page_lists, + true, + 2, + ); + let page_iterator = + InMemoryPageIterator::new(schema, column_desc.clone(), page_lists); + + let mut array_reader = PrimitiveArrayReader::::new( + Box::new(page_iterator), + column_desc, + None, + ) + .unwrap(); + + // Read first 50 values, which are all from the first column chunk + let array = array_reader.next_batch(50).unwrap(); + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + assert_eq!( + &PrimitiveArray::::from(data[0..50].to_vec()), + array + ); + + // Read next 100 values, the first 50 ones are from the first column chunk, + // and the last 50 ones are from the second column chunk + let array = 
array_reader.next_batch(100).unwrap(); + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + assert_eq!( + &PrimitiveArray::::from(data[50..150].to_vec()), + array + ); + + // Try to read 100 values, however there are only 50 values + let array = array_reader.next_batch(100).unwrap(); + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + assert_eq!( + &PrimitiveArray::::from(data[150..200].to_vec()), + array + ); + } + } + + macro_rules! test_primitive_array_reader_one_type { + ($arrow_parquet_type:ty, $physical_type:expr, $converted_type_str:expr, $result_arrow_type:ty, $result_arrow_cast_type:ty, $result_primitive_type:ty) => {{ + let message_type = format!( + " + message test_schema {{ + REQUIRED {:?} leaf ({}); + }} + ", + $physical_type, $converted_type_str + ); + let schema = parse_message_type(&message_type) + .map(|t| Arc::new(SchemaDescriptor::new(Arc::new(t)))) + .unwrap(); + + let column_desc = schema.column(0); + + // Construct page iterator + { + let mut data = Vec::new(); + let mut page_lists = Vec::new(); + make_column_chunks::<$arrow_parquet_type>( + column_desc.clone(), + Encoding::PLAIN, + 100, + 1, + 200, + &mut Vec::new(), + &mut Vec::new(), + &mut data, + &mut page_lists, + true, + 2, + ); + let page_iterator = InMemoryPageIterator::new( + schema.clone(), + column_desc.clone(), + page_lists, + ); + let mut array_reader = PrimitiveArrayReader::<$arrow_parquet_type>::new( + Box::new(page_iterator), + column_desc.clone(), + None, + ) + .expect("Unable to get array reader"); + + let array = array_reader + .next_batch(50) + .expect("Unable to get batch from reader"); + + let result_data_type = <$result_arrow_type>::DATA_TYPE; + let array = array + .as_any() + .downcast_ref::>() + .expect( + format!( + "Unable to downcast {:?} to {:?}", + array.data_type(), + result_data_type + ) + .as_str(), + ); + + // create expected array as primitive, and cast to result type + let expected = PrimitiveArray::<$result_arrow_cast_type>::from( + data[0..50] + .iter() + .map(|x| *x as $result_primitive_type) + .collect::>(), + ); + let expected = Arc::new(expected) as ArrayRef; + let expected = arrow::compute::cast(&expected, &result_data_type) + .expect("Unable to cast expected array"); + assert_eq!(expected.data_type(), &result_data_type); + let expected = expected + .as_any() + .downcast_ref::>() + .expect( + format!( + "Unable to downcast expected {:?} to {:?}", + expected.data_type(), + result_data_type + ) + .as_str(), + ); + assert_eq!(expected, array); + } + }}; + } + + #[test] + fn test_primitive_array_reader_temporal_types() { + test_primitive_array_reader_one_type!( + Int32Type, + PhysicalType::INT32, + "DATE", + ArrowDate32, + ArrowInt32, + i32 + ); + test_primitive_array_reader_one_type!( + Int32Type, + PhysicalType::INT32, + "TIME_MILLIS", + ArrowTime32MillisecondArray, + ArrowInt32, + i32 + ); + test_primitive_array_reader_one_type!( + Int64Type, + PhysicalType::INT64, + "TIME_MICROS", + ArrowTime64MicrosecondArray, + ArrowInt64, + i64 + ); + test_primitive_array_reader_one_type!( + Int64Type, + PhysicalType::INT64, + "TIMESTAMP_MILLIS", + ArrowTimestampMillisecondType, + ArrowInt64, + i64 + ); + test_primitive_array_reader_one_type!( + Int64Type, + PhysicalType::INT64, + "TIMESTAMP_MICROS", + ArrowTimestampMicrosecondType, + ArrowInt64, + i64 + ); + } + + #[test] + fn test_primitive_array_reader_def_and_rep_levels() { + // Construct column schema + let message_type = " + message test_schema { + REPEATED Group test_mid { + OPTIONAL INT32 leaf; + 
} + } + "; + + let schema = parse_message_type(message_type) + .map(|t| Arc::new(SchemaDescriptor::new(Arc::new(t)))) + .unwrap(); + + let column_desc = schema.column(0); + + // Construct page iterator + { + let mut def_levels = Vec::new(); + let mut rep_levels = Vec::new(); + let mut page_lists = Vec::new(); + make_column_chunks::( + column_desc.clone(), + Encoding::PLAIN, + 100, + 1, + 200, + &mut def_levels, + &mut rep_levels, + &mut Vec::new(), + &mut page_lists, + true, + 2, + ); + + let page_iterator = + InMemoryPageIterator::new(schema, column_desc.clone(), page_lists); + + let mut array_reader = PrimitiveArrayReader::::new( + Box::new(page_iterator), + column_desc, + None, + ) + .unwrap(); + + let mut accu_len: usize = 0; + + // Read first 50 values, which are all from the first column chunk + let array = array_reader.next_batch(50).unwrap(); + assert_eq!( + Some(&def_levels[accu_len..(accu_len + array.len())]), + array_reader.get_def_levels() + ); + assert_eq!( + Some(&rep_levels[accu_len..(accu_len + array.len())]), + array_reader.get_rep_levels() + ); + accu_len += array.len(); + + // Read next 100 values, the first 50 ones are from the first column chunk, + // and the last 50 ones are from the second column chunk + let array = array_reader.next_batch(100).unwrap(); + assert_eq!( + Some(&def_levels[accu_len..(accu_len + array.len())]), + array_reader.get_def_levels() + ); + assert_eq!( + Some(&rep_levels[accu_len..(accu_len + array.len())]), + array_reader.get_rep_levels() + ); + accu_len += array.len(); + + // Try to read 100 values, however there are only 50 values + let array = array_reader.next_batch(100).unwrap(); + assert_eq!( + Some(&def_levels[accu_len..(accu_len + array.len())]), + array_reader.get_def_levels() + ); + assert_eq!( + Some(&rep_levels[accu_len..(accu_len + array.len())]), + array_reader.get_rep_levels() + ); + } + } + + #[test] + fn test_complex_array_reader_no_pages() { + let message_type = " + message test_schema { + REPEATED Group test_mid { + OPTIONAL BYTE_ARRAY leaf (UTF8); + } + } + "; + let schema = parse_message_type(message_type) + .map(|t| Arc::new(SchemaDescriptor::new(Arc::new(t)))) + .unwrap(); + let column_desc = schema.column(0); + let pages: Vec> = Vec::new(); + let page_iterator = InMemoryPageIterator::new(schema, column_desc.clone(), pages); + + let converter = Utf8Converter::new(Utf8ArrayConverter {}); + let mut array_reader = + ComplexObjectArrayReader::::new( + Box::new(page_iterator), + column_desc, + converter, + None, + ) + .unwrap(); + + let values_per_page = 100; // this value is arbitrary in this test - the result should always be an array of 0 length + let array = array_reader.next_batch(values_per_page).unwrap(); + assert_eq!(array.len(), 0); + } + + #[test] + fn test_complex_array_reader_def_and_rep_levels() { + // Construct column schema + let message_type = " + message test_schema { + REPEATED Group test_mid { + OPTIONAL BYTE_ARRAY leaf (UTF8); + } + } + "; + let num_pages = 2; + let values_per_page = 100; + let str_base = "Hello World"; + + let schema = parse_message_type(message_type) + .map(|t| Arc::new(SchemaDescriptor::new(Arc::new(t)))) + .unwrap(); + + let max_def_level = schema.column(0).max_def_level(); + let max_rep_level = schema.column(0).max_rep_level(); + + assert_eq!(max_def_level, 2); + assert_eq!(max_rep_level, 1); + + let mut rng = thread_rng(); + let column_desc = schema.column(0); + let mut pages: Vec> = Vec::new(); + + let mut rep_levels = Vec::with_capacity(num_pages * values_per_page); + let mut 
def_levels = Vec::with_capacity(num_pages * values_per_page); + let mut all_values = Vec::with_capacity(num_pages * values_per_page); + + for i in 0..num_pages { + let mut values = Vec::with_capacity(values_per_page); + + for _ in 0..values_per_page { + let def_level = rng.gen_range(0..max_def_level + 1); + let rep_level = rng.gen_range(0..max_rep_level + 1); + if def_level == max_def_level { + let len = rng.gen_range(1..str_base.len()); + let slice = &str_base[..len]; + values.push(ByteArray::from(slice)); + all_values.push(Some(slice.to_string())); + } else { + all_values.push(None) + } + rep_levels.push(rep_level); + def_levels.push(def_level) + } + + let range = i * values_per_page..(i + 1) * values_per_page; + let mut pb = + DataPageBuilderImpl::new(column_desc.clone(), values.len() as u32, true); + + pb.add_rep_levels(max_rep_level, &rep_levels.as_slice()[range.clone()]); + pb.add_def_levels(max_def_level, &def_levels.as_slice()[range]); + pb.add_values::(Encoding::PLAIN, values.as_slice()); + + let data_page = pb.consume(); + pages.push(vec![data_page]); + } + + let page_iterator = InMemoryPageIterator::new(schema, column_desc.clone(), pages); + + let converter = Utf8Converter::new(Utf8ArrayConverter {}); + let mut array_reader = + ComplexObjectArrayReader::::new( + Box::new(page_iterator), + column_desc, + converter, + None, + ) + .unwrap(); + + let mut accu_len: usize = 0; + + let array = array_reader.next_batch(values_per_page / 2).unwrap(); + assert_eq!(array.len(), values_per_page / 2); + assert_eq!( + Some(&def_levels[accu_len..(accu_len + array.len())]), + array_reader.get_def_levels() + ); + assert_eq!( + Some(&rep_levels[accu_len..(accu_len + array.len())]), + array_reader.get_rep_levels() + ); + accu_len += array.len(); + + // Read next values_per_page values, the first values_per_page/2 ones are from the first column chunk, + // and the last values_per_page/2 ones are from the second column chunk + let array = array_reader.next_batch(values_per_page).unwrap(); + assert_eq!(array.len(), values_per_page); + assert_eq!( + Some(&def_levels[accu_len..(accu_len + array.len())]), + array_reader.get_def_levels() + ); + assert_eq!( + Some(&rep_levels[accu_len..(accu_len + array.len())]), + array_reader.get_rep_levels() + ); + let strings = array.as_any().downcast_ref::().unwrap(); + for i in 0..array.len() { + if array.is_valid(i) { + assert_eq!( + all_values[i + accu_len].as_ref().unwrap().as_str(), + strings.value(i) + ) + } else { + assert_eq!(all_values[i + accu_len], None) + } + } + accu_len += array.len(); + + // Try to read values_per_page values, however there are only values_per_page/2 values + let array = array_reader.next_batch(values_per_page).unwrap(); + assert_eq!(array.len(), values_per_page / 2); + assert_eq!( + Some(&def_levels[accu_len..(accu_len + array.len())]), + array_reader.get_def_levels() + ); + assert_eq!( + Some(&rep_levels[accu_len..(accu_len + array.len())]), + array_reader.get_rep_levels() + ); + } + + #[test] + fn test_complex_array_reader_dict_enc_string() { + use crate::encodings::encoding::{DictEncoder, Encoder}; + // Construct column schema + let message_type = " + message test_schema { + REPEATED Group test_mid { + OPTIONAL BYTE_ARRAY leaf (UTF8); + } + } + "; + let num_pages = 2; + let values_per_page = 100; + let str_base = "Hello World"; + + let schema = parse_message_type(message_type) + .map(|t| Arc::new(SchemaDescriptor::new(Arc::new(t)))) + .unwrap(); + let column_desc = schema.column(0); + let max_def_level = column_desc.max_def_level(); 
+ let max_rep_level = column_desc.max_rep_level(); + + assert_eq!(max_def_level, 2); + assert_eq!(max_rep_level, 1); + + let mut rng = thread_rng(); + let mut pages: Vec> = Vec::new(); + + let mut rep_levels = Vec::with_capacity(num_pages * values_per_page); + let mut def_levels = Vec::with_capacity(num_pages * values_per_page); + let mut all_values = Vec::with_capacity(num_pages * values_per_page); + + for i in 0..num_pages { + let mut dict_encoder = DictEncoder::::new(column_desc.clone()); + // add data page + let mut values = Vec::with_capacity(values_per_page); + + for _ in 0..values_per_page { + let def_level = rng.gen_range(0..max_def_level + 1); + let rep_level = rng.gen_range(0..max_rep_level + 1); + if def_level == max_def_level { + let len = rng.gen_range(1..str_base.len()); + let slice = &str_base[..len]; + values.push(ByteArray::from(slice)); + all_values.push(Some(slice.to_string())); + } else { + all_values.push(None) + } + rep_levels.push(rep_level); + def_levels.push(def_level) + } + + let range = i * values_per_page..(i + 1) * values_per_page; + let mut pb = + DataPageBuilderImpl::new(column_desc.clone(), values.len() as u32, true); + pb.add_rep_levels(max_rep_level, &rep_levels.as_slice()[range.clone()]); + pb.add_def_levels(max_def_level, &def_levels.as_slice()[range]); + let _ = dict_encoder.put(&values); + let indices = dict_encoder + .write_indices() + .expect("write_indices() should be OK"); + pb.add_indices(indices); + let data_page = pb.consume(); + // for each page log num_values vs actual values in page + // println!("page num_values: {}, values.len(): {}", data_page.num_values(), values.len()); + // add dictionary page + let dict = dict_encoder + .write_dict() + .expect("write_dict() should be OK"); + let dict_page = Page::DictionaryPage { + buf: dict, + num_values: dict_encoder.num_entries() as u32, + encoding: Encoding::RLE_DICTIONARY, + is_sorted: false, + }; + pages.push(vec![dict_page, data_page]); + } + + let page_iterator = InMemoryPageIterator::new(schema, column_desc.clone(), pages); + let converter = Utf8Converter::new(Utf8ArrayConverter {}); + let mut array_reader = + ComplexObjectArrayReader::::new( + Box::new(page_iterator), + column_desc, + converter, + None, + ) + .unwrap(); + + let mut accu_len: usize = 0; + + // println!("---------- reading a batch of {} values ----------", values_per_page / 2); + let array = array_reader.next_batch(values_per_page / 2).unwrap(); + assert_eq!(array.len(), values_per_page / 2); + assert_eq!( + Some(&def_levels[accu_len..(accu_len + array.len())]), + array_reader.get_def_levels() + ); + assert_eq!( + Some(&rep_levels[accu_len..(accu_len + array.len())]), + array_reader.get_rep_levels() + ); + accu_len += array.len(); + + // Read next values_per_page values, the first values_per_page/2 ones are from the first column chunk, + // and the last values_per_page/2 ones are from the second column chunk + // println!("---------- reading a batch of {} values ----------", values_per_page); + let array = array_reader.next_batch(values_per_page).unwrap(); + assert_eq!(array.len(), values_per_page); + assert_eq!( + Some(&def_levels[accu_len..(accu_len + array.len())]), + array_reader.get_def_levels() + ); + assert_eq!( + Some(&rep_levels[accu_len..(accu_len + array.len())]), + array_reader.get_rep_levels() + ); + let strings = array.as_any().downcast_ref::().unwrap(); + for i in 0..array.len() { + if array.is_valid(i) { + assert_eq!( + all_values[i + accu_len].as_ref().unwrap().as_str(), + strings.value(i) + ) + } else { + 
assert_eq!(all_values[i + accu_len], None) + } + } + accu_len += array.len(); + + // Try to read values_per_page values, however there are only values_per_page/2 values + // println!("---------- reading a batch of {} values ----------", values_per_page); + let array = array_reader.next_batch(values_per_page).unwrap(); + assert_eq!(array.len(), values_per_page / 2); + assert_eq!( + Some(&def_levels[accu_len..(accu_len + array.len())]), + array_reader.get_def_levels() + ); + assert_eq!( + Some(&rep_levels[accu_len..(accu_len + array.len())]), + array_reader.get_rep_levels() + ); + } + + #[test] + fn test_struct_array_reader() { + let array_1 = Arc::new(PrimitiveArray::::from(vec![1, 2, 3, 4, 5])); + let array_reader_1 = InMemoryArrayReader::new( + ArrowType::Int32, + array_1.clone(), + Some(vec![0, 1, 2, 3, 1]), + Some(vec![0, 1, 1, 1, 1]), + ); + + let array_2 = Arc::new(PrimitiveArray::::from(vec![5, 4, 3, 2, 1])); + let array_reader_2 = InMemoryArrayReader::new( + ArrowType::Int32, + array_2.clone(), + Some(vec![0, 1, 3, 1, 2]), + Some(vec![0, 1, 1, 1, 1]), + ); + + let struct_type = ArrowType::Struct(vec![ + Field::new("f1", array_1.data_type().clone(), true), + Field::new("f2", array_2.data_type().clone(), true), + ]); + + let mut struct_array_reader = StructArrayReader::new( + struct_type, + vec![Box::new(array_reader_1), Box::new(array_reader_2)], + 1, + 1, + true, + ); + + let struct_array = struct_array_reader.next_batch(5).unwrap(); + let struct_array = struct_array.as_any().downcast_ref::().unwrap(); + + assert_eq!(5, struct_array.len()); + assert_eq!( + vec![true, false, false, false, false], + (0..5) + .map(|idx| struct_array.data_ref().is_null(idx)) + .collect::>() + ); + assert_eq!( + Some(vec![0, 1, 2, 3, 1].as_slice()), + struct_array_reader.get_def_levels() + ); + assert_eq!( + Some(vec![0, 1, 1, 1, 1].as_slice()), + struct_array_reader.get_rep_levels() + ); + } + + #[test] + fn test_struct_array_reader_list() { + use arrow::datatypes::Int32Type; + // [ + // {foo: [1, 2, null], + // {foo: []}, + // {foo: null}, + // null, + // ] + + let expected_l = + Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), None]), + Some(vec![]), + None, + None, + ])); + + let validity = Buffer::from([0b00000111]); + let struct_fields = vec![( + Field::new("foo", expected_l.data_type().clone(), true), + expected_l.clone() as ArrayRef, + )]; + let expected = StructArray::from((struct_fields, validity)); + + let array = Arc::new(Int32Array::from_iter(vec![ + Some(1), + Some(2), + None, + None, + None, + None, + ])); + let reader = InMemoryArrayReader::new( + ArrowType::Int32, + array, + Some(vec![4, 4, 3, 2, 1, 0]), + Some(vec![0, 1, 1, 0, 0, 0]), + ); + + let list_reader = ListArrayReader::::new( + Box::new(reader), + expected_l.data_type().clone(), + ArrowType::Int32, + 3, + 1, + true, + ); + + let mut struct_reader = StructArrayReader::new( + expected.data_type().clone(), + vec![Box::new(list_reader)], + 1, + 0, + true, + ); + + let actual = struct_reader.next_batch(1024).unwrap(); + let actual = actual.as_any().downcast_ref::().unwrap(); + assert_eq!(actual, &expected) + } +} diff --git a/parquet/src/arrow/array_reader/test_util.rs b/parquet/src/arrow/array_reader/test_util.rs new file mode 100644 index 000000000000..0c044eb2df63 --- /dev/null +++ b/parquet/src/arrow/array_reader/test_util.rs @@ -0,0 +1,213 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, ArrayRef}; +use arrow::datatypes::DataType as ArrowType; +use std::any::Any; +use std::sync::Arc; + +use crate::arrow::array_reader::ArrayReader; +use crate::basic::{ConvertedType, Encoding, Type as PhysicalType}; +use crate::column::page::{PageIterator, PageReader}; +use crate::data_type::{ByteArray, ByteArrayType}; +use crate::encodings::encoding::{get_encoder, DictEncoder, Encoder}; +use crate::errors::Result; +use crate::schema::types::{ + ColumnDescPtr, ColumnDescriptor, ColumnPath, SchemaDescPtr, Type, +}; +use crate::util::memory::ByteBufferPtr; + +/// Returns a descriptor for a UTF-8 column +pub fn utf8_column() -> ColumnDescPtr { + let t = Type::primitive_type_builder("col", PhysicalType::BYTE_ARRAY) + .with_converted_type(ConvertedType::UTF8) + .build() + .unwrap(); + + Arc::new(ColumnDescriptor::new( + Arc::new(t), + 1, + 0, + ColumnPath::new(vec![]), + )) +} + +/// Encode `data` with the provided `encoding` +pub fn encode_byte_array(encoding: Encoding, data: &[ByteArray]) -> ByteBufferPtr { + let descriptor = utf8_column(); + let mut encoder = get_encoder::(descriptor, encoding).unwrap(); + + encoder.put(data).unwrap(); + encoder.flush_buffer().unwrap() +} + +/// Returns the encoded dictionary and value data +pub fn encode_dictionary(data: &[ByteArray]) -> (ByteBufferPtr, ByteBufferPtr) { + let mut dict_encoder = DictEncoder::::new(utf8_column()); + + dict_encoder.put(data).unwrap(); + let encoded_rle = dict_encoder.flush_buffer().unwrap(); + let encoded_dictionary = dict_encoder.write_dict().unwrap(); + + (encoded_dictionary, encoded_rle) +} + +/// Encodes `data` in all the possible encodings +/// +/// Returns an array of data with its associated encoding, along with an encoded dictionary +pub fn byte_array_all_encodings( + data: Vec>, +) -> (Vec<(Encoding, ByteBufferPtr)>, ByteBufferPtr) { + let data: Vec<_> = data.into_iter().map(Into::into).collect(); + let (encoded_dictionary, encoded_rle) = encode_dictionary(&data); + + // A column chunk with all the encodings! + let pages = vec![ + (Encoding::PLAIN, encode_byte_array(Encoding::PLAIN, &data)), + ( + Encoding::DELTA_BYTE_ARRAY, + encode_byte_array(Encoding::DELTA_BYTE_ARRAY, &data), + ), + ( + Encoding::DELTA_LENGTH_BYTE_ARRAY, + encode_byte_array(Encoding::DELTA_LENGTH_BYTE_ARRAY, &data), + ), + (Encoding::PLAIN_DICTIONARY, encoded_rle.clone()), + (Encoding::RLE_DICTIONARY, encoded_rle), + ]; + + (pages, encoded_dictionary) +} + +/// Array reader for test. 
+pub struct InMemoryArrayReader { + data_type: ArrowType, + array: ArrayRef, + def_levels: Option>, + rep_levels: Option>, + last_idx: usize, + cur_idx: usize, +} + +impl InMemoryArrayReader { + pub fn new( + data_type: ArrowType, + array: ArrayRef, + def_levels: Option>, + rep_levels: Option>, + ) -> Self { + assert!(def_levels + .as_ref() + .map(|d| d.len() == array.len()) + .unwrap_or(true)); + + assert!(rep_levels + .as_ref() + .map(|r| r.len() == array.len()) + .unwrap_or(true)); + + Self { + data_type, + array, + def_levels, + rep_levels, + cur_idx: 0, + last_idx: 0, + } + } +} + +impl ArrayReader for InMemoryArrayReader { + fn as_any(&self) -> &dyn Any { + self + } + + fn get_data_type(&self) -> &ArrowType { + &self.data_type + } + + fn next_batch(&mut self, batch_size: usize) -> Result { + assert_ne!(batch_size, 0); + // This replicates the logical normally performed by + // RecordReader to delimit semantic records + let read = match &self.rep_levels { + Some(rep_levels) => { + let rep_levels = &rep_levels[self.cur_idx..]; + let mut levels_read = 0; + let mut records_read = 0; + while levels_read < rep_levels.len() && records_read < batch_size { + if rep_levels[levels_read] == 0 { + records_read += 1; // Start of new record + } + levels_read += 1; + } + + // Find end of current record + while levels_read < rep_levels.len() && rep_levels[levels_read] != 0 { + levels_read += 1 + } + levels_read + } + None => batch_size.min(self.array.len() - self.cur_idx), + }; + + self.last_idx = self.cur_idx; + self.cur_idx += read; + Ok(self.array.slice(self.last_idx, read)) + } + + fn get_def_levels(&self) -> Option<&[i16]> { + self.def_levels + .as_ref() + .map(|l| &l[self.last_idx..self.cur_idx]) + } + + fn get_rep_levels(&self) -> Option<&[i16]> { + self.rep_levels + .as_ref() + .map(|l| &l[self.last_idx..self.cur_idx]) + } +} + +/// Iterator for testing reading empty columns +pub struct EmptyPageIterator { + schema: SchemaDescPtr, +} + +impl EmptyPageIterator { + pub fn new(schema: SchemaDescPtr) -> Self { + EmptyPageIterator { schema } + } +} + +impl Iterator for EmptyPageIterator { + type Item = Result>; + + fn next(&mut self) -> Option { + None + } +} + +impl PageIterator for EmptyPageIterator { + fn schema(&mut self) -> Result { + Ok(self.schema.clone()) + } + + fn column_schema(&mut self) -> Result { + Ok(self.schema.column(0)) + } +} diff --git a/parquet/src/arrow/arrow_array_reader.rs b/parquet/src/arrow/arrow_array_reader.rs deleted file mode 100644 index 3f2acf4568d7..000000000000 --- a/parquet/src/arrow/arrow_array_reader.rs +++ /dev/null @@ -1,1726 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
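A standalone sketch (not crate code) of the record-delimiting rule that InMemoryArrayReader::next_batch above borrows from RecordReader: a repetition level of 0 opens a new record, so reading batch_size records means consuming levels up to the start of the last requested record and then through its tail.

fn levels_for_records(rep_levels: &[i16], batch_size: usize) -> usize {
    let mut levels_read = 0;
    let mut records_read = 0;
    while levels_read < rep_levels.len() && records_read < batch_size {
        if rep_levels[levels_read] == 0 {
            records_read += 1; // start of a new record
        }
        levels_read += 1;
    }
    // consume the remainder of the record currently being read
    while levels_read < rep_levels.len() && rep_levels[levels_read] != 0 {
        levels_read += 1;
    }
    levels_read
}

fn main() {
    // records are [0,1,1], [0], [0,1]; two records span four level positions
    assert_eq!(levels_for_records(&[0, 1, 1, 0, 0, 1], 2), 4);
}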
- -use super::array_reader::ArrayReader; -use crate::arrow::schema::parquet_to_arrow_field; -use crate::basic::Encoding; -use crate::data_type::{ByteArray, ByteArrayType}; -use crate::decoding::{Decoder, DeltaByteArrayDecoder}; -use crate::errors::{ParquetError, Result}; -use crate::{ - column::page::{Page, PageIterator}, - memory::ByteBufferPtr, - schema::types::{ColumnDescPtr, ColumnDescriptor}, -}; -use arrow::{ - array::{ArrayRef, Int16Array}, - buffer::MutableBuffer, - datatypes::{DataType as ArrowType, ToByteSlice}, -}; -use std::{any::Any, collections::VecDeque, marker::PhantomData}; -use std::{cell::RefCell, rc::Rc}; - -struct UnzipIter { - shared_state: Rc>, - select_item_buffer: fn(&mut State) -> &mut VecDeque, - consume_source_item: fn(source_item: Source, state: &mut State) -> Target, -} - -impl UnzipIter { - fn new( - shared_state: Rc>, - item_buffer_selector: fn(&mut State) -> &mut VecDeque, - source_item_consumer: fn(source_item: Source, state: &mut State) -> Target, - ) -> Self { - Self { - shared_state, - select_item_buffer: item_buffer_selector, - consume_source_item: source_item_consumer, - } - } -} - -trait UnzipIterState { - type SourceIter: Iterator; - fn source_iter(&mut self) -> &mut Self::SourceIter; -} - -impl> Iterator - for UnzipIter -{ - type Item = Target; - - fn next(&mut self) -> Option { - let mut inner = self.shared_state.borrow_mut(); - // try to get one from the stored data - (self.select_item_buffer)(&mut *inner) - .pop_front() - .or_else(|| - // nothing stored, we need a new element. - inner.source_iter().next().map(|s| { - (self.consume_source_item)(s, &mut inner) - })) - } -} - -struct PageBufferUnzipIterState { - iter: It, - value_iter_buffer: VecDeque, - def_level_iter_buffer: VecDeque, - rep_level_iter_buffer: VecDeque, -} - -impl> UnzipIterState<(V, L, L)> - for PageBufferUnzipIterState -{ - type SourceIter = It; - - #[inline] - fn source_iter(&mut self) -> &mut Self::SourceIter { - &mut self.iter - } -} - -type ValueUnzipIter = - UnzipIter<(V, L, L), V, PageBufferUnzipIterState>; -type LevelUnzipIter = - UnzipIter<(V, L, L), L, PageBufferUnzipIterState>; -type PageUnzipResult = ( - ValueUnzipIter, - LevelUnzipIter, - LevelUnzipIter, -); - -fn unzip_iter>(it: It) -> PageUnzipResult { - let shared_data = Rc::new(RefCell::new(PageBufferUnzipIterState { - iter: it, - value_iter_buffer: VecDeque::new(), - def_level_iter_buffer: VecDeque::new(), - rep_level_iter_buffer: VecDeque::new(), - })); - - let value_iter = UnzipIter::new( - shared_data.clone(), - |state| &mut state.value_iter_buffer, - |(v, d, r), state| { - state.def_level_iter_buffer.push_back(d); - state.rep_level_iter_buffer.push_back(r); - v - }, - ); - - let def_level_iter = UnzipIter::new( - shared_data.clone(), - |state| &mut state.def_level_iter_buffer, - |(v, d, r), state| { - state.value_iter_buffer.push_back(v); - state.rep_level_iter_buffer.push_back(r); - d - }, - ); - - let rep_level_iter = UnzipIter::new( - shared_data, - |state| &mut state.rep_level_iter_buffer, - |(v, d, r), state| { - state.value_iter_buffer.push_back(v); - state.def_level_iter_buffer.push_back(d); - r - }, - ); - - (value_iter, def_level_iter, rep_level_iter) -} - -pub trait ArrayConverter { - fn convert_value_bytes( - &self, - value_decoder: &mut impl ValueDecoder, - num_values: usize, - ) -> Result; -} - -pub struct ArrowArrayReader<'a, C: ArrayConverter + 'a> { - column_desc: ColumnDescPtr, - data_type: ArrowType, - def_level_decoder: Box, - rep_level_decoder: Box, - value_decoder: Box, - 
last_def_levels: Option, - last_rep_levels: Option, - array_converter: C, -} - -pub(crate) struct ColumnChunkContext { - dictionary_values: Option>, -} - -impl ColumnChunkContext { - fn new() -> Self { - Self { - dictionary_values: None, - } - } - - fn set_dictionary(&mut self, dictionary_values: Vec) { - self.dictionary_values = Some(dictionary_values); - } -} - -type PageDecoderTuple = ( - Box, - Box, - Box, -); - -impl<'a, C: ArrayConverter + 'a> ArrowArrayReader<'a, C> { - pub fn try_new( - column_chunk_iterator: P, - column_desc: ColumnDescPtr, - array_converter: C, - arrow_type: Option, - ) -> Result { - let data_type = match arrow_type { - Some(t) => t, - None => parquet_to_arrow_field(column_desc.as_ref())? - .data_type() - .clone(), - }; - type PageIteratorItem = Result<(Page, Rc>)>; - let page_iter = column_chunk_iterator - // build iterator of pages across column chunks - .flat_map(|x| -> Box> { - // attach column chunk context - let context = Rc::new(RefCell::new(ColumnChunkContext::new())); - match x { - Ok(page_reader) => Box::new( - page_reader.map(move |pr| pr.map(|p| (p, context.clone()))), - ), - // errors from reading column chunks / row groups are propagated to page level - Err(e) => Box::new(std::iter::once(Err(e))), - } - }); - // capture a clone of column_desc in closure so that it can outlive current function - let map_page_fn_factory = |column_desc: ColumnDescPtr| { - move |x: Result<(Page, Rc>)>| { - x.and_then(|(page, context)| { - Self::map_page(page, context, column_desc.as_ref()) - }) - } - }; - let map_page_fn = map_page_fn_factory(column_desc.clone()); - // map page iterator into tuple of buffer iterators for (values, def levels, rep levels) - // errors from lower levels are surfaced through the value decoder iterator - let decoder_iter = page_iter.map(map_page_fn).map(|x| match x { - Ok(iter_tuple) => iter_tuple, - // errors from reading pages are propagated to decoder iterator level - Err(e) => Self::map_page_error(e), - }); - // split tuple iterator into separate iterators for (values, def levels, rep levels) - let (value_iter, def_level_iter, rep_level_iter) = unzip_iter(decoder_iter); - - Ok(Self { - column_desc, - data_type, - def_level_decoder: Box::new(CompositeValueDecoder::new(def_level_iter)), - rep_level_decoder: Box::new(CompositeValueDecoder::new(rep_level_iter)), - value_decoder: Box::new(CompositeValueDecoder::new(value_iter)), - last_def_levels: None, - last_rep_levels: None, - array_converter, - }) - } - - #[inline] - fn def_levels_available(column_desc: &ColumnDescriptor) -> bool { - column_desc.max_def_level() > 0 - } - - #[inline] - fn rep_levels_available(column_desc: &ColumnDescriptor) -> bool { - column_desc.max_rep_level() > 0 - } - - fn map_page_error(err: ParquetError) -> PageDecoderTuple { - ( - Box::new(::once(Err(err.clone()))), - Box::new(::once(Err(err.clone()))), - Box::new(::once(Err(err))), - ) - } - - // Split Result into Result<(Iterator, Iterator, Iterator)> - // this method could fail, e.g. if the page encoding is not supported - fn map_page( - page: Page, - column_chunk_context: Rc>, - column_desc: &ColumnDescriptor, - ) -> Result { - use crate::encodings::levels::LevelDecoder; - match page { - Page::DictionaryPage { - buf, - num_values, - encoding, - .. 
- } => { - let mut column_chunk_context = column_chunk_context.borrow_mut(); - if column_chunk_context.dictionary_values.is_some() { - return Err(general_err!( - "Column chunk cannot have more than one dictionary" - )); - } - // create plain decoder for dictionary values - let mut dict_decoder = Self::get_dictionary_page_decoder( - buf, - num_values as usize, - encoding, - column_desc, - )?; - // decode and cache dictionary values - let dictionary_values = dict_decoder.read_dictionary_values()?; - column_chunk_context.set_dictionary(dictionary_values); - - // a dictionary page doesn't return any values - Ok(( - Box::new(::empty()), - Box::new(::empty()), - Box::new(::empty()), - )) - } - Page::DataPage { - buf, - num_values, - encoding, - def_level_encoding, - rep_level_encoding, - statistics: _, - } => { - let mut buffer_ptr = buf; - // create rep level decoder iterator - let rep_level_iter: Box = - if Self::rep_levels_available(column_desc) { - let mut rep_decoder = LevelDecoder::v1( - rep_level_encoding, - column_desc.max_rep_level(), - ); - let rep_level_byte_len = - rep_decoder.set_data(num_values as usize, buffer_ptr.all()); - // advance buffer pointer - buffer_ptr = buffer_ptr.start_from(rep_level_byte_len); - Box::new(LevelValueDecoder::new(rep_decoder)) - } else { - Box::new(::once(Err(ParquetError::General( - "rep levels are not available".to_string(), - )))) - }; - // create def level decoder iterator - let def_level_iter: Box = - if Self::def_levels_available(column_desc) { - let mut def_decoder = LevelDecoder::v1( - def_level_encoding, - column_desc.max_def_level(), - ); - let def_levels_byte_len = - def_decoder.set_data(num_values as usize, buffer_ptr.all()); - // advance buffer pointer - buffer_ptr = buffer_ptr.start_from(def_levels_byte_len); - Box::new(LevelValueDecoder::new(def_decoder)) - } else { - Box::new(::once(Err(ParquetError::General( - "def levels are not available".to_string(), - )))) - }; - // create value decoder iterator - let value_iter = Self::get_value_decoder( - buffer_ptr, - num_values as usize, - encoding, - column_desc, - column_chunk_context, - )?; - Ok((value_iter, def_level_iter, rep_level_iter)) - } - Page::DataPageV2 { - buf, - num_values, - encoding, - num_nulls: _, - num_rows: _, - def_levels_byte_len, - rep_levels_byte_len, - is_compressed: _, - statistics: _, - } => { - let mut offset = 0; - // create rep level decoder iterator - let rep_level_iter: Box = - if Self::rep_levels_available(column_desc) { - let rep_levels_byte_len = rep_levels_byte_len as usize; - let mut rep_decoder = - LevelDecoder::v2(column_desc.max_rep_level()); - rep_decoder.set_data_range( - num_values as usize, - &buf, - offset, - rep_levels_byte_len, - ); - offset += rep_levels_byte_len; - Box::new(LevelValueDecoder::new(rep_decoder)) - } else { - Box::new(::once(Err(ParquetError::General( - "rep levels are not available".to_string(), - )))) - }; - // create def level decoder iterator - let def_level_iter: Box = - if Self::def_levels_available(column_desc) { - let def_levels_byte_len = def_levels_byte_len as usize; - let mut def_decoder = - LevelDecoder::v2(column_desc.max_def_level()); - def_decoder.set_data_range( - num_values as usize, - &buf, - offset, - def_levels_byte_len, - ); - offset += def_levels_byte_len; - Box::new(LevelValueDecoder::new(def_decoder)) - } else { - Box::new(::once(Err(ParquetError::General( - "def levels are not available".to_string(), - )))) - }; - - // create value decoder iterator - let values_buffer = buf.start_from(offset); - let 
value_iter = Self::get_value_decoder( - values_buffer, - num_values as usize, - encoding, - column_desc, - column_chunk_context, - )?; - Ok((value_iter, def_level_iter, rep_level_iter)) - } - } - } - - fn get_dictionary_page_decoder( - values_buffer: ByteBufferPtr, - num_values: usize, - mut encoding: Encoding, - column_desc: &ColumnDescriptor, - ) -> Result> { - if encoding == Encoding::PLAIN || encoding == Encoding::PLAIN_DICTIONARY { - encoding = Encoding::RLE_DICTIONARY - } - - if encoding == Encoding::RLE_DICTIONARY { - Ok( - Self::get_plain_value_decoder(values_buffer, num_values, column_desc) - .into_dictionary_decoder(), - ) - } else { - Err(nyi_err!( - "Invalid/Unsupported encoding type for dictionary: {}", - encoding - )) - } - } - - fn get_value_decoder( - values_buffer: ByteBufferPtr, - num_values: usize, - mut encoding: Encoding, - column_desc: &ColumnDescriptor, - column_chunk_context: Rc>, - ) -> Result> { - if encoding == Encoding::PLAIN_DICTIONARY { - encoding = Encoding::RLE_DICTIONARY; - } - - match encoding { - Encoding::PLAIN => { - Ok( - Self::get_plain_value_decoder(values_buffer, num_values, column_desc) - .into_value_decoder(), - ) - } - Encoding::RLE_DICTIONARY => { - if column_chunk_context.borrow().dictionary_values.is_some() { - let value_bit_len = Self::get_column_physical_bit_len(column_desc); - let dictionary_decoder: Box = if value_bit_len == 0 - { - Box::new(VariableLenDictionaryDecoder::new( - column_chunk_context, - values_buffer, - num_values, - )) - } else { - Box::new(FixedLenDictionaryDecoder::new( - column_chunk_context, - values_buffer, - num_values, - value_bit_len, - )) - }; - Ok(dictionary_decoder) - } else { - Err(general_err!("Dictionary values have not been initialized.")) - } - } - // Encoding::RLE => Box::new(RleValueDecoder::new()), - // Encoding::DELTA_BINARY_PACKED => Box::new(DeltaBitPackDecoder::new()), - // Encoding::DELTA_LENGTH_BYTE_ARRAY => Box::new(DeltaLengthByteArrayDecoder::new()), - Encoding::DELTA_BYTE_ARRAY => Ok(Box::new(DeltaByteArrayValueDecoder::new( - values_buffer, - num_values, - )?)), - e => return Err(nyi_err!("Encoding {} is not supported", e)), - } - } - - fn get_column_physical_bit_len(column_desc: &ColumnDescriptor) -> usize { - use crate::basic::Type as PhysicalType; - // parquet only supports a limited number of physical types - // later converters cast to a more specific arrow / logical type if necessary - match column_desc.physical_type() { - PhysicalType::BOOLEAN => 1, - PhysicalType::INT32 | PhysicalType::FLOAT => 32, - PhysicalType::INT64 | PhysicalType::DOUBLE => 64, - PhysicalType::INT96 => 96, - PhysicalType::BYTE_ARRAY => 0, - PhysicalType::FIXED_LEN_BYTE_ARRAY => column_desc.type_length() as usize * 8, - } - } - - fn get_plain_value_decoder( - values_buffer: ByteBufferPtr, - num_values: usize, - column_desc: &ColumnDescriptor, - ) -> Box { - let value_bit_len = Self::get_column_physical_bit_len(column_desc); - if value_bit_len == 0 { - Box::new(VariableLenPlainDecoder::new(values_buffer, num_values)) - } else { - Box::new(FixedLenPlainDecoder::new( - values_buffer, - num_values, - value_bit_len, - )) - } - } - - fn build_level_array( - level_decoder: &mut impl ValueDecoder, - batch_size: usize, - ) -> Result { - use arrow::datatypes::Int16Type; - let level_converter = PrimitiveArrayConverter::::new(); - let array_data = - level_converter.convert_value_bytes(level_decoder, batch_size)?; - Ok(Int16Array::from(array_data)) - } -} - -impl ArrayReader for ArrowArrayReader<'static, C> { - fn as_any(&self) 
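// Illustrative sketch (not part of the original diff): the physical-type
// width table used by `get_column_physical_bit_len` above. A width of 0
// marks variable-length (BYTE_ARRAY) values, which take the offset-based
// decoder instead of the fixed-length one. The enum here is hypothetical.
#[derive(Clone, Copy)]
enum PhysicalType {
    Boolean,
    Int32,
    Int64,
    Int96,
    Float,
    Double,
    ByteArray,
    FixedLenByteArray { type_length: usize },
}

fn physical_bit_len(t: PhysicalType) -> usize {
    match t {
        PhysicalType::Boolean => 1,
        PhysicalType::Int32 | PhysicalType::Float => 32,
        PhysicalType::Int64 | PhysicalType::Double => 64,
        PhysicalType::Int96 => 96,
        // variable-length values have no fixed bit width
        PhysicalType::ByteArray => 0,
        PhysicalType::FixedLenByteArray { type_length } => type_length * 8,
    }
}

fn main() {
    assert_eq!(physical_bit_len(PhysicalType::Int32), 32);
    assert_eq!(
        physical_bit_len(PhysicalType::FixedLenByteArray { type_length: 12 }),
        96
    );
}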
-> &dyn Any { - self - } - - fn get_data_type(&self) -> &ArrowType { - &self.data_type - } - - fn next_batch(&mut self, batch_size: usize) -> Result { - if Self::rep_levels_available(&self.column_desc) { - // read rep levels if available - let rep_level_array = - Self::build_level_array(&mut self.rep_level_decoder, batch_size)?; - self.last_rep_levels = Some(rep_level_array); - } - - // check if def levels are available - let (values_to_read, null_bitmap_array) = - if !Self::def_levels_available(&self.column_desc) { - // if no def levels - just read (up to) batch_size values - (batch_size, None) - } else { - // if def levels are available - they determine how many values will be read - // decode def levels, return first error if any - let def_level_array = - Self::build_level_array(&mut self.def_level_decoder, batch_size)?; - let def_level_count = def_level_array.len(); - // use eq_scalar to efficiently build null bitmap array from def levels - let null_bitmap_array = arrow::compute::eq_scalar( - &def_level_array, - self.column_desc.max_def_level(), - )?; - self.last_def_levels = Some(def_level_array); - // efficiently calculate values to read - let values_to_read = null_bitmap_array - .values() - .count_set_bits_offset(0, def_level_count); - let maybe_null_bitmap = if values_to_read != null_bitmap_array.len() { - Some(null_bitmap_array) - } else { - // shortcut if no NULLs - None - }; - (values_to_read, maybe_null_bitmap) - }; - - // read a batch of values - // converter only creates a no-null / all value array data - let mut value_array_data = self - .array_converter - .convert_value_bytes(&mut self.value_decoder, values_to_read)?; - - if let Some(null_bitmap_array) = null_bitmap_array { - // Only if def levels are available - insert null values efficiently using MutableArrayData. - // This will require value bytes to be copied again, but converter requirements are reduced. - // With a small number of NULLs, this will only be a few copies of large byte sequences. - let actual_batch_size = null_bitmap_array.len(); - // use_nulls is false, because null_bitmap_array is already calculated and re-used - let mut mutable = arrow::array::MutableArrayData::new( - vec![&value_array_data], - false, - actual_batch_size, - ); - // SlicesIterator slices only the true values, NULLs are inserted to fill any gaps - arrow::compute::SlicesIterator::new(&null_bitmap_array).for_each( - |(start, end)| { - // the gap needs to be filled with NULLs - if start > mutable.len() { - let nulls_to_add = start - mutable.len(); - mutable.extend_nulls(nulls_to_add); - } - // fill values, adjust start and end with NULL count so far - let nulls_added = mutable.null_count(); - mutable.extend(0, start - nulls_added, end - nulls_added); - }, - ); - // any remaining part is NULLs - if mutable.len() < actual_batch_size { - let nulls_to_add = actual_batch_size - mutable.len(); - mutable.extend_nulls(nulls_to_add); - } - - value_array_data = unsafe { - mutable - .into_builder() - .null_bit_buffer(null_bitmap_array.values().clone()) - .build_unchecked() - }; - } - let mut array = arrow::array::make_array(value_array_data); - if array.data_type() != &self.data_type { - // cast array to self.data_type if necessary - array = arrow::compute::cast(&array, &self.data_type)? 
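// Illustrative sketch (not part of the original diff): the null-insertion
// step of `next_batch` above, reduced to plain vectors. Definition levels
// equal to the maximum level mark present values; the densely decoded
// values are then scattered back into their original slots.
fn scatter_values<T: Clone>(
    def_levels: &[i16],
    max_def_level: i16,
    dense_values: &[T],
) -> Vec<Option<T>> {
    let mut dense = dense_values.iter();
    def_levels
        .iter()
        .map(|&level| {
            if level == max_def_level {
                // a value was written for this slot
                dense.next().cloned()
            } else {
                // anything below the max definition level is NULL at this leaf
                None
            }
        })
        .collect()
}

fn main() {
    let def_levels = [1, 0, 1, 1, 0];
    let dense_values = [10, 20, 30];
    assert_eq!(
        scatter_values(&def_levels, 1, &dense_values),
        vec![Some(10), None, Some(20), Some(30), None]
    );
}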
- } - Ok(array) - } - - fn get_def_levels(&self) -> Option<&[i16]> { - self.last_def_levels.as_ref().map(|x| x.values()) - } - - fn get_rep_levels(&self) -> Option<&[i16]> { - self.last_rep_levels.as_ref().map(|x| x.values()) - } -} - -use crate::encodings::rle::RleDecoder; - -pub trait ValueDecoder { - fn read_value_bytes( - &mut self, - num_values: usize, - read_bytes: &mut dyn FnMut(&[u8], usize), - ) -> Result; -} - -trait DictionaryValueDecoder { - fn read_dictionary_values(&mut self) -> Result>; -} - -trait PlainValueDecoder: ValueDecoder + DictionaryValueDecoder { - fn into_value_decoder(self: Box) -> Box; - fn into_dictionary_decoder(self: Box) -> Box; -} - -impl PlainValueDecoder for T -where - T: ValueDecoder + DictionaryValueDecoder + 'static, -{ - fn into_value_decoder(self: Box) -> Box { - self - } - - fn into_dictionary_decoder(self: Box) -> Box { - self - } -} - -impl dyn ValueDecoder { - fn empty() -> impl ValueDecoder { - SingleValueDecoder::new(Ok(0)) - } - - fn once(value: Result) -> impl ValueDecoder { - SingleValueDecoder::new(value) - } -} - -impl ValueDecoder for Box { - #[inline] - fn read_value_bytes( - &mut self, - num_values: usize, - read_bytes: &mut dyn FnMut(&[u8], usize), - ) -> Result { - self.as_mut().read_value_bytes(num_values, read_bytes) - } -} - -struct SingleValueDecoder { - value: Result, -} - -impl SingleValueDecoder { - fn new(value: Result) -> Self { - Self { value } - } -} - -impl ValueDecoder for SingleValueDecoder { - fn read_value_bytes( - &mut self, - _num_values: usize, - _read_bytes: &mut dyn FnMut(&[u8], usize), - ) -> Result { - self.value.clone() - } -} - -struct CompositeValueDecoder>> { - current_decoder: Option>, - decoder_iter: I, -} - -impl>> CompositeValueDecoder { - fn new(mut decoder_iter: I) -> Self { - let current_decoder = decoder_iter.next(); - Self { - current_decoder, - decoder_iter, - } - } -} - -impl>> ValueDecoder - for CompositeValueDecoder -{ - fn read_value_bytes( - &mut self, - num_values: usize, - read_bytes: &mut dyn FnMut(&[u8], usize), - ) -> Result { - let mut values_to_read = num_values; - while values_to_read > 0 { - let value_decoder = match self.current_decoder.as_mut() { - Some(d) => d, - // no more decoders - None => break, - }; - while values_to_read > 0 { - let values_read = - value_decoder.read_value_bytes(values_to_read, read_bytes)?; - if values_read > 0 { - values_to_read -= values_read; - } else { - // no more values in current decoder - self.current_decoder = self.decoder_iter.next(); - break; - } - } - } - - Ok(num_values - values_to_read) - } -} - -struct LevelValueDecoder { - level_decoder: crate::encodings::levels::LevelDecoder, - level_value_buffer: Vec, -} - -impl LevelValueDecoder { - fn new(level_decoder: crate::encodings::levels::LevelDecoder) -> Self { - Self { - level_decoder, - level_value_buffer: vec![0i16; 2048], - } - } -} - -impl ValueDecoder for LevelValueDecoder { - fn read_value_bytes( - &mut self, - num_values: usize, - read_bytes: &mut dyn FnMut(&[u8], usize), - ) -> Result { - let value_size = std::mem::size_of::(); - let mut total_values_read = 0; - while total_values_read < num_values { - let values_to_read = std::cmp::min( - num_values - total_values_read, - self.level_value_buffer.len(), - ); - let values_read = match self - .level_decoder - .get(&mut self.level_value_buffer[..values_to_read]) - { - Ok(values_read) => values_read, - Err(e) => return Err(e), - }; - if values_read > 0 { - let level_value_bytes = - &self.level_value_buffer.to_byte_slice()[..values_read * 
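// Illustrative sketch (not part of the original diff): the chaining
// behaviour of `CompositeValueDecoder` above. The current source is drained
// until it reports zero values, then the next page's decoder is picked up,
// so callers see one continuous value stream. Names are hypothetical.
struct Composite<I: Iterator<Item = Vec<u32>>> {
    current: Option<Vec<u32>>,
    rest: I,
}

impl<I: Iterator<Item = Vec<u32>>> Composite<I> {
    fn new(mut rest: I) -> Self {
        let current = rest.next();
        Self { current, rest }
    }

    fn read(&mut self, out: &mut Vec<u32>, num_values: usize) -> usize {
        let mut remaining = num_values;
        while remaining > 0 {
            let source = match self.current.as_mut() {
                Some(s) => s,
                None => break, // no more decoders
            };
            let take = remaining.min(source.len());
            if take == 0 {
                // current decoder is exhausted, move on to the next one
                self.current = self.rest.next();
                continue;
            }
            out.extend(source.drain(..take));
            remaining -= take;
        }
        num_values - remaining
    }
}

fn main() {
    let pages = vec![vec![1, 2, 3], vec![4, 5]];
    let mut decoder = Composite::new(pages.into_iter());
    let mut out = Vec::new();
    assert_eq!(decoder.read(&mut out, 4), 4);
    assert_eq!(out, vec![1, 2, 3, 4]);
}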
value_size]; - read_bytes(level_value_bytes, values_read); - total_values_read += values_read; - } else { - break; - } - } - Ok(total_values_read) - } -} - -pub(crate) struct FixedLenPlainDecoder { - data: ByteBufferPtr, - num_values: usize, - value_bit_len: usize, -} - -impl FixedLenPlainDecoder { - pub(crate) fn new( - data: ByteBufferPtr, - num_values: usize, - value_bit_len: usize, - ) -> Self { - Self { - data, - num_values, - value_bit_len, - } - } -} - -impl DictionaryValueDecoder for FixedLenPlainDecoder { - fn read_dictionary_values(&mut self) -> Result> { - let value_byte_len = self.value_bit_len / 8; - let available_values = self.data.len() / value_byte_len; - let values_to_read = std::cmp::min(available_values, self.num_values); - let byte_len = values_to_read * value_byte_len; - let values = vec![self.data.range(0, byte_len)]; - self.num_values = 0; - self.data.set_range(self.data.start(), 0); - Ok(values) - } -} - -impl ValueDecoder for FixedLenPlainDecoder { - fn read_value_bytes( - &mut self, - num_values: usize, - read_bytes: &mut dyn FnMut(&[u8], usize), - ) -> Result { - let available_values = self.data.len() * 8 / self.value_bit_len; - if available_values > 0 { - let values_to_read = std::cmp::min(available_values, num_values); - let byte_len = values_to_read * self.value_bit_len / 8; - read_bytes(&self.data.data()[..byte_len], values_to_read); - self.data - .set_range(self.data.start() + byte_len, self.data.len() - byte_len); - Ok(values_to_read) - } else { - Ok(0) - } - } -} - -pub(crate) struct VariableLenPlainDecoder { - data: ByteBufferPtr, - num_values: usize, - position: usize, -} - -impl VariableLenPlainDecoder { - pub(crate) fn new(data: ByteBufferPtr, num_values: usize) -> Self { - Self { - data, - num_values, - position: 0, - } - } -} - -impl DictionaryValueDecoder for VariableLenPlainDecoder { - fn read_dictionary_values(&mut self) -> Result> { - const LEN_SIZE: usize = std::mem::size_of::(); - let data = self.data.data(); - let data_len = data.len(); - let values_to_read = self.num_values; - let mut values = Vec::with_capacity(values_to_read); - let mut values_read = 0; - while self.position < data_len && values_read < values_to_read { - let len: usize = - read_num_bytes!(u32, LEN_SIZE, data[self.position..]) as usize; - self.position += LEN_SIZE; - if data_len < self.position + len { - return Err(eof_err!("Not enough bytes to decode")); - } - values.push(self.data.range(self.position, len)); - self.position += len; - values_read += 1; - } - self.num_values -= values_read; - Ok(values) - } -} - -impl ValueDecoder for VariableLenPlainDecoder { - fn read_value_bytes( - &mut self, - num_values: usize, - read_bytes: &mut dyn FnMut(&[u8], usize), - ) -> Result { - const LEN_SIZE: usize = std::mem::size_of::(); - let data = self.data.data(); - let data_len = data.len(); - let values_to_read = std::cmp::min(self.num_values, num_values); - let mut values_read = 0; - while self.position < data_len && values_read < values_to_read { - let len: usize = - read_num_bytes!(u32, LEN_SIZE, data[self.position..]) as usize; - self.position += LEN_SIZE; - if data_len < self.position + len { - return Err(eof_err!("Not enough bytes to decode")); - } - read_bytes(&data[self.position..][..len], 1); - self.position += len; - values_read += 1; - } - self.num_values -= values_read; - Ok(values_read) - } -} - -pub(crate) struct FixedLenDictionaryDecoder { - context_ref: Rc>, - key_data_bufer: ByteBufferPtr, - num_values: usize, - rle_decoder: RleDecoder, - value_byte_len: usize, - 
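// Illustrative sketch (not part of the original diff): the PLAIN layout
// consumed by `VariableLenPlainDecoder` above. Each value is a
// little-endian u32 length prefix followed by that many bytes.
fn decode_plain_byte_arrays(mut data: &[u8]) -> Result<Vec<Vec<u8>>, String> {
    let mut values = Vec::new();
    while !data.is_empty() {
        if data.len() < 4 {
            return Err("not enough bytes for the length prefix".to_string());
        }
        let len = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
        data = &data[4..];
        if data.len() < len {
            return Err("not enough bytes to decode the value".to_string());
        }
        values.push(data[..len].to_vec());
        data = &data[len..];
    }
    Ok(values)
}

fn main() {
    // encode two values, "hi" and "arrow", in the same layout
    let mut buf = Vec::new();
    for value in [&b"hi"[..], &b"arrow"[..]] {
        buf.extend_from_slice(&(value.len() as u32).to_le_bytes());
        buf.extend_from_slice(value);
    }
    let decoded = decode_plain_byte_arrays(&buf).unwrap();
    assert_eq!(decoded, vec![b"hi".to_vec(), b"arrow".to_vec()]);
}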
keys_buffer: Vec, -} - -impl FixedLenDictionaryDecoder { - pub(crate) fn new( - column_chunk_context: Rc>, - key_data_bufer: ByteBufferPtr, - num_values: usize, - value_bit_len: usize, - ) -> Self { - assert!( - value_bit_len % 8 == 0, - "value_bit_size must be a multiple of 8" - ); - // First byte in `data` is bit width - let bit_width = key_data_bufer.data()[0]; - let mut rle_decoder = RleDecoder::new(bit_width); - rle_decoder.set_data(key_data_bufer.start_from(1)); - - Self { - context_ref: column_chunk_context, - key_data_bufer, - num_values, - rle_decoder, - value_byte_len: value_bit_len / 8, - keys_buffer: vec![0; 2048], - } - } -} - -impl ValueDecoder for FixedLenDictionaryDecoder { - fn read_value_bytes( - &mut self, - num_values: usize, - read_bytes: &mut dyn FnMut(&[u8], usize), - ) -> Result { - if self.num_values == 0 { - return Ok(0); - } - let context = self.context_ref.borrow(); - let values = context.dictionary_values.as_ref().unwrap(); - let input_value_bytes = values[0].data(); - // read no more than available values or requested values - let values_to_read = std::cmp::min(self.num_values, num_values); - let mut values_read = 0; - while values_read < values_to_read { - // read values in batches of up to self.keys_buffer.len() - let keys_to_read = - std::cmp::min(values_to_read - values_read, self.keys_buffer.len()); - let keys_read = match self - .rle_decoder - .get_batch(&mut self.keys_buffer[..keys_to_read]) - { - Ok(keys_read) => keys_read, - Err(e) => return Err(e), - }; - if keys_read == 0 { - self.num_values = 0; - return Ok(values_read); - } - for i in 0..keys_read { - let key = self.keys_buffer[i] as usize; - read_bytes( - &input_value_bytes[key * self.value_byte_len..] - [..self.value_byte_len], - 1, - ); - } - values_read += keys_read; - } - self.num_values -= values_read; - Ok(values_read) - } -} - -pub(crate) struct VariableLenDictionaryDecoder { - context_ref: Rc>, - key_data_bufer: ByteBufferPtr, - num_values: usize, - rle_decoder: RleDecoder, - keys_buffer: Vec, -} - -impl VariableLenDictionaryDecoder { - pub(crate) fn new( - column_chunk_context: Rc>, - key_data_bufer: ByteBufferPtr, - num_values: usize, - ) -> Self { - // First byte in `data` is bit width - let bit_width = key_data_bufer.data()[0]; - let mut rle_decoder = RleDecoder::new(bit_width); - rle_decoder.set_data(key_data_bufer.start_from(1)); - - Self { - context_ref: column_chunk_context, - key_data_bufer, - num_values, - rle_decoder, - keys_buffer: vec![0; 2048], - } - } -} - -impl ValueDecoder for VariableLenDictionaryDecoder { - fn read_value_bytes( - &mut self, - num_values: usize, - read_bytes: &mut dyn FnMut(&[u8], usize), - ) -> Result { - if self.num_values == 0 { - return Ok(0); - } - let context = self.context_ref.borrow(); - let values = context.dictionary_values.as_ref().unwrap(); - let values_to_read = std::cmp::min(self.num_values, num_values); - let mut values_read = 0; - while values_read < values_to_read { - // read values in batches of up to self.keys_buffer.len() - let keys_to_read = - std::cmp::min(values_to_read - values_read, self.keys_buffer.len()); - let keys_read = match self - .rle_decoder - .get_batch(&mut self.keys_buffer[..keys_to_read]) - { - Ok(keys_read) => keys_read, - Err(e) => return Err(e), - }; - if keys_read == 0 { - self.num_values = 0; - return Ok(values_read); - } - for i in 0..keys_read { - let key = self.keys_buffer[i] as usize; - read_bytes(values[key].data(), 1); - } - values_read += keys_read; - } - self.num_values -= values_read; - Ok(values_read) 
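// Illustrative sketch (not part of the original diff): what the dictionary
// decoders above do once the RLE/bit-packed keys have been decoded into
// indices - each key simply selects a value from the dictionary that was
// cached when the dictionary page of this column chunk was read.
fn materialize_dictionary(keys: &[usize], dictionary: &[&str]) -> Vec<String> {
    keys.iter().map(|&key| dictionary[key].to_string()).collect()
}

fn main() {
    let dictionary = ["red", "green", "blue"];
    let keys = [0, 2, 2, 1, 0];
    assert_eq!(
        materialize_dictionary(&keys, &dictionary),
        vec!["red", "blue", "blue", "green", "red"]
    );
}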
- } -} - -pub(crate) struct DeltaByteArrayValueDecoder { - decoder: DeltaByteArrayDecoder, -} - -impl DeltaByteArrayValueDecoder { - pub fn new(data: ByteBufferPtr, num_values: usize) -> Result { - let mut decoder = DeltaByteArrayDecoder::new(); - decoder.set_data(data, num_values)?; - Ok(Self { decoder }) - } -} - -impl ValueDecoder for DeltaByteArrayValueDecoder { - fn read_value_bytes( - &mut self, - mut num_values: usize, - read_bytes: &mut dyn FnMut(&[u8], usize), - ) -> Result { - num_values = std::cmp::min(num_values, self.decoder.values_left()); - let mut values_read = 0; - let mut buf = [ByteArray::new()]; - while values_read < num_values { - let num_read = self.decoder.get(&mut buf)?; - debug_assert_eq!(num_read, 1); - - read_bytes(buf[0].data(), 1); - - values_read += 1; - } - Ok(values_read) - } -} - -use arrow::datatypes::ArrowPrimitiveType; - -pub struct PrimitiveArrayConverter { - _phantom_data: PhantomData, -} - -impl PrimitiveArrayConverter { - pub fn new() -> Self { - Self { - _phantom_data: PhantomData, - } - } -} - -impl ArrayConverter for PrimitiveArrayConverter { - fn convert_value_bytes( - &self, - value_decoder: &mut impl ValueDecoder, - num_values: usize, - ) -> Result { - let value_size = T::get_byte_width(); - let values_byte_capacity = num_values * value_size; - let mut values_buffer = MutableBuffer::new(values_byte_capacity); - - value_decoder.read_value_bytes(num_values, &mut |value_bytes, _| { - values_buffer.extend_from_slice(value_bytes); - })?; - - // calculate actual data_len, which may be different from the iterator's upper bound - let value_count = values_buffer.len() / value_size; - let array_data = arrow::array::ArrayData::builder(T::DATA_TYPE) - .len(value_count) - .add_buffer(values_buffer.into()); - let array_data = unsafe { array_data.build_unchecked() }; - Ok(array_data) - } -} - -pub struct StringArrayConverter {} - -impl StringArrayConverter { - pub fn new() -> Self { - Self {} - } -} - -impl ArrayConverter for StringArrayConverter { - fn convert_value_bytes( - &self, - value_decoder: &mut impl ValueDecoder, - num_values: usize, - ) -> Result { - use arrow::datatypes::ArrowNativeType; - let offset_size = std::mem::size_of::(); - let mut offsets_buffer = MutableBuffer::new((num_values + 1) * offset_size); - // allocate initial capacity of 1 byte for each item - let values_byte_capacity = num_values; - let mut values_buffer = MutableBuffer::new(values_byte_capacity); - - let mut length_so_far = i32::default(); - offsets_buffer.push(length_so_far); - - value_decoder.read_value_bytes(num_values, &mut |value_bytes, values_read| { - debug_assert_eq!( - values_read, 1, - "offset length value buffers can only contain bytes for a single value" - ); - length_so_far += - ::from_usize(value_bytes.len()).unwrap(); - // this should be safe because a ValueDecoder should not read more than num_values - unsafe { - offsets_buffer.push_unchecked(length_so_far); - } - values_buffer.extend_from_slice(value_bytes); - })?; - // calculate actual data_len, which may be different from the iterator's upper bound - let data_len = (offsets_buffer.len() / offset_size) - 1; - let array_data = arrow::array::ArrayData::builder(ArrowType::Utf8) - .len(data_len) - .add_buffer(offsets_buffer.into()) - .add_buffer(values_buffer.into()); - let array_data = unsafe { array_data.build_unchecked() }; - Ok(array_data) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::arrow::{ArrowReader, ParquetFileArrowReader}; - use crate::basic::ConvertedType; - use 
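// Illustrative sketch (not part of the original diff): the Utf8 layout
// assembled by `StringArrayConverter` above - one buffer with all value
// bytes concatenated, plus an i32 offsets buffer with one more entry than
// there are values, each entry holding the running byte length.
fn build_utf8_buffers(values: &[&str]) -> (Vec<i32>, Vec<u8>) {
    let mut offsets = Vec::with_capacity(values.len() + 1);
    let mut bytes = Vec::new();
    let mut length_so_far = 0i32;
    offsets.push(length_so_far);
    for value in values {
        length_so_far += value.len() as i32;
        offsets.push(length_so_far);
        bytes.extend_from_slice(value.as_bytes());
    }
    (offsets, bytes)
}

fn main() {
    let (offsets, bytes) = build_utf8_buffers(&["foo", "", "bar"]);
    assert_eq!(offsets, vec![0, 3, 3, 6]);
    assert_eq!(bytes, b"foobar".to_vec());
}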
crate::column::page::Page; - use crate::column::writer::ColumnWriter; - use crate::data_type::ByteArray; - use crate::data_type::ByteArrayType; - use crate::file::properties::WriterProperties; - use crate::file::reader::SerializedFileReader; - use crate::file::serialized_reader::SliceableCursor; - use crate::file::writer::{FileWriter, SerializedFileWriter, TryClone}; - use crate::schema::parser::parse_message_type; - use crate::schema::types::SchemaDescriptor; - use crate::util::test_common::page_util::{ - DataPageBuilder, DataPageBuilderImpl, InMemoryPageIterator, - }; - use crate::{ - basic::Encoding, column::page::PageReader, schema::types::SchemaDescPtr, - }; - use arrow::array::{PrimitiveArray, StringArray}; - use arrow::datatypes::Int32Type as ArrowInt32; - use rand::{distributions::uniform::SampleUniform, thread_rng, Rng}; - use std::io::{Cursor, Seek, SeekFrom, Write}; - use std::sync::{Arc, Mutex}; - - /// Iterator for testing reading empty columns - struct EmptyPageIterator { - schema: SchemaDescPtr, - } - - impl EmptyPageIterator { - fn new(schema: SchemaDescPtr) -> Self { - EmptyPageIterator { schema } - } - } - - impl Iterator for EmptyPageIterator { - type Item = Result>; - - fn next(&mut self) -> Option { - None - } - } - - impl PageIterator for EmptyPageIterator { - fn schema(&mut self) -> Result { - Ok(self.schema.clone()) - } - - fn column_schema(&mut self) -> Result { - Ok(self.schema.column(0)) - } - } - - #[test] - fn test_array_reader_empty_pages() { - // Construct column schema - let message_type = " - message test_schema { - REQUIRED INT32 leaf; - } - "; - - let schema = parse_message_type(message_type) - .map(|t| Arc::new(SchemaDescriptor::new(Arc::new(t)))) - .unwrap(); - - let column_desc = schema.column(0); - let page_iterator = EmptyPageIterator::new(schema); - - let converter = PrimitiveArrayConverter::::new(); - let mut array_reader = - ArrowArrayReader::try_new(page_iterator, column_desc, converter, None) - .unwrap(); - - // expect no values to be read - let array = array_reader.next_batch(50).unwrap(); - assert!(array.is_empty()); - } - - fn make_column_chunks( - column_desc: ColumnDescPtr, - encoding: Encoding, - num_levels: usize, - min_value: T::T, - max_value: T::T, - def_levels: &mut Vec, - rep_levels: &mut Vec, - values: &mut Vec, - page_lists: &mut Vec>, - use_v2: bool, - num_chunks: usize, - ) where - T::T: PartialOrd + SampleUniform + Copy, - { - for _i in 0..num_chunks { - let mut pages = VecDeque::new(); - let mut data = Vec::new(); - let mut page_def_levels = Vec::new(); - let mut page_rep_levels = Vec::new(); - - crate::util::test_common::make_pages::( - column_desc.clone(), - encoding, - 1, - num_levels, - min_value, - max_value, - &mut page_def_levels, - &mut page_rep_levels, - &mut data, - &mut pages, - use_v2, - ); - - def_levels.append(&mut page_def_levels); - rep_levels.append(&mut page_rep_levels); - values.append(&mut data); - page_lists.push(Vec::from(pages)); - } - } - - #[test] - fn test_primitive_array_reader_data() { - // Construct column schema - let message_type = " - message test_schema { - REQUIRED INT32 leaf; - } - "; - - let schema = parse_message_type(message_type) - .map(|t| Arc::new(SchemaDescriptor::new(Arc::new(t)))) - .unwrap(); - - let column_desc = schema.column(0); - - // Construct page iterator - { - let mut data = Vec::new(); - let mut page_lists = Vec::new(); - make_column_chunks::( - column_desc.clone(), - Encoding::PLAIN, - 100, - 1, - 200, - &mut Vec::new(), - &mut Vec::new(), - &mut data, - &mut page_lists, - 
true, - 2, - ); - let page_iterator = - InMemoryPageIterator::new(schema, column_desc.clone(), page_lists); - - let converter = PrimitiveArrayConverter::::new(); - let mut array_reader = - ArrowArrayReader::try_new(page_iterator, column_desc, converter, None) - .unwrap(); - - // Read first 50 values, which are all from the first column chunk - let array = array_reader.next_batch(50).unwrap(); - let array = array - .as_any() - .downcast_ref::>() - .unwrap(); - - assert_eq!( - &PrimitiveArray::::from(data[0..50].to_vec()), - array - ); - - // Read next 100 values, the first 50 ones are from the first column chunk, - // and the last 50 ones are from the second column chunk - let array = array_reader.next_batch(100).unwrap(); - let array = array - .as_any() - .downcast_ref::>() - .unwrap(); - - assert_eq!( - &PrimitiveArray::::from(data[50..150].to_vec()), - array - ); - - // Try to read 100 values, however there are only 50 values - let array = array_reader.next_batch(100).unwrap(); - let array = array - .as_any() - .downcast_ref::>() - .unwrap(); - - assert_eq!( - &PrimitiveArray::::from(data[150..200].to_vec()), - array - ); - } - } - - #[test] - fn test_primitive_array_reader_def_and_rep_levels() { - // Construct column schema - let message_type = " - message test_schema { - REPEATED Group test_mid { - OPTIONAL INT32 leaf; - } - } - "; - - let schema = parse_message_type(message_type) - .map(|t| Arc::new(SchemaDescriptor::new(Arc::new(t)))) - .unwrap(); - - let column_desc = schema.column(0); - - // Construct page iterator - { - let mut def_levels = Vec::new(); - let mut rep_levels = Vec::new(); - let mut page_lists = Vec::new(); - make_column_chunks::( - column_desc.clone(), - Encoding::PLAIN, - 100, - 1, - 200, - &mut def_levels, - &mut rep_levels, - &mut Vec::new(), - &mut page_lists, - true, - 2, - ); - - let page_iterator = - InMemoryPageIterator::new(schema, column_desc.clone(), page_lists); - - let converter = PrimitiveArrayConverter::::new(); - let mut array_reader = - ArrowArrayReader::try_new(page_iterator, column_desc, converter, None) - .unwrap(); - - let mut accu_len: usize = 0; - - // Read first 50 values, which are all from the first column chunk - let array = array_reader.next_batch(50).unwrap(); - assert_eq!( - Some(&def_levels[accu_len..(accu_len + array.len())]), - array_reader.get_def_levels() - ); - assert_eq!( - Some(&rep_levels[accu_len..(accu_len + array.len())]), - array_reader.get_rep_levels() - ); - accu_len += array.len(); - - // Read next 100 values, the first 50 ones are from the first column chunk, - // and the last 50 ones are from the second column chunk - let array = array_reader.next_batch(100).unwrap(); - assert_eq!( - Some(&def_levels[accu_len..(accu_len + array.len())]), - array_reader.get_def_levels() - ); - assert_eq!( - Some(&rep_levels[accu_len..(accu_len + array.len())]), - array_reader.get_rep_levels() - ); - accu_len += array.len(); - - // Try to read 100 values, however there are only 50 values - let array = array_reader.next_batch(100).unwrap(); - assert_eq!( - Some(&def_levels[accu_len..(accu_len + array.len())]), - array_reader.get_def_levels() - ); - assert_eq!( - Some(&rep_levels[accu_len..(accu_len + array.len())]), - array_reader.get_rep_levels() - ); - - assert_eq!(accu_len + array.len(), 200); - } - } - - #[test] - fn test_arrow_array_reader_string() { - // Construct column schema - let message_type = " - message test_schema { - REPEATED Group test_mid { - OPTIONAL BYTE_ARRAY leaf (UTF8); - } - } - "; - let num_pages = 2; - let 
values_per_page = 100; - let str_base = "Hello World"; - - let schema = parse_message_type(message_type) - .map(|t| Arc::new(SchemaDescriptor::new(Arc::new(t)))) - .unwrap(); - let column_desc = schema.column(0); - let max_def_level = column_desc.max_def_level(); - let max_rep_level = column_desc.max_rep_level(); - - assert_eq!(max_def_level, 2); - assert_eq!(max_rep_level, 1); - - let mut rng = thread_rng(); - let mut pages: Vec> = Vec::new(); - - let mut rep_levels = Vec::with_capacity(num_pages * values_per_page); - let mut def_levels = Vec::with_capacity(num_pages * values_per_page); - let mut all_values = Vec::with_capacity(num_pages * values_per_page); - - for i in 0..num_pages { - let mut values = Vec::with_capacity(values_per_page); - - for _ in 0..values_per_page { - let def_level = rng.gen_range(0..max_def_level + 1); - let rep_level = rng.gen_range(0..max_rep_level + 1); - if def_level == max_def_level { - let len = rng.gen_range(1..str_base.len()); - let slice = &str_base[..len]; - values.push(ByteArray::from(slice)); - all_values.push(Some(slice.to_string())); - } else { - all_values.push(None) - } - rep_levels.push(rep_level); - def_levels.push(def_level) - } - - let range = i * values_per_page..(i + 1) * values_per_page; - let mut pb = - DataPageBuilderImpl::new(column_desc.clone(), values.len() as u32, true); - - pb.add_rep_levels(max_rep_level, &rep_levels.as_slice()[range.clone()]); - pb.add_def_levels(max_def_level, &def_levels.as_slice()[range]); - pb.add_values::(Encoding::PLAIN, values.as_slice()); - - let data_page = pb.consume(); - pages.push(vec![data_page]); - } - - let page_iterator = InMemoryPageIterator::new(schema, column_desc.clone(), pages); - let converter = StringArrayConverter::new(); - let mut array_reader = - ArrowArrayReader::try_new(page_iterator, column_desc, converter, None) - .unwrap(); - - let mut accu_len: usize = 0; - - let array = array_reader.next_batch(values_per_page / 2).unwrap(); - assert_eq!(array.len(), values_per_page / 2); - assert_eq!( - Some(&def_levels[accu_len..(accu_len + array.len())]), - array_reader.get_def_levels() - ); - assert_eq!( - Some(&rep_levels[accu_len..(accu_len + array.len())]), - array_reader.get_rep_levels() - ); - accu_len += array.len(); - - // Read next values_per_page values, the first values_per_page/2 ones are from the first column chunk, - // and the last values_per_page/2 ones are from the second column chunk - let array = array_reader.next_batch(values_per_page).unwrap(); - assert_eq!(array.len(), values_per_page); - assert_eq!( - Some(&def_levels[accu_len..(accu_len + array.len())]), - array_reader.get_def_levels() - ); - assert_eq!( - Some(&rep_levels[accu_len..(accu_len + array.len())]), - array_reader.get_rep_levels() - ); - let strings = array.as_any().downcast_ref::().unwrap(); - for i in 0..array.len() { - if array.is_valid(i) { - assert_eq!( - all_values[i + accu_len].as_ref().unwrap().as_str(), - strings.value(i) - ) - } else { - assert_eq!(all_values[i + accu_len], None) - } - } - accu_len += array.len(); - - // Try to read values_per_page values, however there are only values_per_page/2 values - let array = array_reader.next_batch(values_per_page).unwrap(); - assert_eq!(array.len(), values_per_page / 2); - assert_eq!( - Some(&def_levels[accu_len..(accu_len + array.len())]), - array_reader.get_def_levels() - ); - assert_eq!( - Some(&rep_levels[accu_len..(accu_len + array.len())]), - array_reader.get_rep_levels() - ); - } - - /// Allows to write parquet into memory. 
Intended only for use in tests. - #[derive(Clone)] - struct VecWriter { - data: Arc>>>, - } - - impl VecWriter { - pub fn new() -> VecWriter { - VecWriter { - data: Arc::new(Mutex::new(Cursor::new(Vec::new()))), - } - } - - pub fn consume(self) -> Vec { - Arc::try_unwrap(self.data) - .unwrap() - .into_inner() - .unwrap() - .into_inner() - } - } - - impl TryClone for VecWriter { - fn try_clone(&self) -> std::io::Result { - Ok(self.clone()) - } - } - - impl Seek for VecWriter { - fn seek(&mut self, pos: SeekFrom) -> std::io::Result { - self.data.lock().unwrap().seek(pos) - } - - fn stream_position(&mut self) -> std::io::Result { - self.data.lock().unwrap().stream_position() - } - } - - impl Write for VecWriter { - fn write(&mut self, buf: &[u8]) -> std::io::Result { - self.data.lock().unwrap().write(buf) - } - - fn flush(&mut self) -> std::io::Result<()> { - self.data.lock().unwrap().flush() - } - } - - #[test] - fn test_string_delta_byte_array() { - use crate::basic; - use crate::schema::types::Type; - - let data = VecWriter::new(); - let schema = Arc::new( - Type::group_type_builder("string_test") - .with_fields(&mut vec![Arc::new( - Type::primitive_type_builder("c", basic::Type::BYTE_ARRAY) - .with_converted_type(ConvertedType::UTF8) - .build() - .unwrap(), - )]) - .build() - .unwrap(), - ); - // Disable dictionary and use the fallback encoding. - let p = Arc::new( - WriterProperties::builder() - .set_dictionary_enabled(false) - .set_encoding(Encoding::DELTA_BYTE_ARRAY) - .build(), - ); - // Write a few strings. - let mut w = SerializedFileWriter::new(data.clone(), schema, p).unwrap(); - let mut rg = w.next_row_group().unwrap(); - let mut c = rg.next_column().unwrap().unwrap(); - match &mut c { - ColumnWriter::ByteArrayColumnWriter(c) => { - c.write_batch( - &[ByteArray::from("foo"), ByteArray::from("bar")], - Some(&[0, 1, 0, 0, 1, 0]), - Some(&[0, 0, 0, 0, 0, 0]), - ) - .unwrap(); - } - _ => panic!("unexpected column"), - }; - rg.close_column(c).unwrap(); - w.close_row_group(rg).unwrap(); - w.close().unwrap(); - std::mem::drop(w); - - // Check we can read them back. - let r = SerializedFileReader::new(SliceableCursor::new(Arc::new(data.consume()))) - .unwrap(); - let mut r = ParquetFileArrowReader::new(Arc::new(r)); - - let batch = r - .get_record_reader_by_columns([0], 1024) - .unwrap() - .next() - .unwrap() - .unwrap(); - assert_eq!(batch.columns().len(), 1); - - let strings = batch - .column(0) - .as_any() - .downcast_ref::() - .unwrap(); - assert_eq!( - strings.into_iter().collect::>(), - vec![None, Some("foo"), None, None, Some("bar"), None] - ); - } -} diff --git a/parquet/src/arrow/arrow_reader.rs b/parquet/src/arrow/arrow_reader.rs index 761c5a6781bf..89406cd616a4 100644 --- a/parquet/src/arrow/arrow_reader.rs +++ b/parquet/src/arrow/arrow_reader.rs @@ -15,21 +15,24 @@ // specific language governing permissions and limitations // under the License. -//! Contains reader which reads parquet data into arrow array. +//! 
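// Illustrative sketch (not part of the original diff): the idea behind the
// DELTA_BYTE_ARRAY encoding exercised by `test_string_delta_byte_array`
// above. Each value is stored as the length of the prefix it shares with
// the previous value plus the remaining suffix bytes; the real encoding
// additionally delta-packs the lengths themselves.
fn decode_delta_byte_array(prefix_lengths: &[usize], suffixes: &[&str]) -> Vec<String> {
    let mut previous = String::new();
    let mut out = Vec::with_capacity(suffixes.len());
    for (&prefix_len, suffix) in prefix_lengths.iter().zip(suffixes) {
        let mut value = previous[..prefix_len].to_string();
        value.push_str(suffix);
        previous = value.clone();
        out.push(value);
    }
    out
}

fn main() {
    assert_eq!(
        decode_delta_byte_array(&[0, 3, 3], &["Hello", "icopter", "p"]),
        vec!["Hello", "Helicopter", "Help"]
    );
}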
Contains reader which reads parquet data into arrow [`RecordBatch`] -use crate::arrow::array_reader::{build_array_reader, ArrayReader, StructArrayReader}; -use crate::arrow::schema::parquet_to_arrow_schema; -use crate::arrow::schema::{ - parquet_to_arrow_schema_by_columns, parquet_to_arrow_schema_by_root_columns, -}; -use crate::errors::{ParquetError, Result}; -use crate::file::metadata::ParquetMetaData; -use crate::file::reader::FileReader; +use std::sync::Arc; + +use arrow::array::Array; use arrow::datatypes::{DataType as ArrowType, Schema, SchemaRef}; use arrow::error::Result as ArrowResult; use arrow::record_batch::{RecordBatch, RecordBatchReader}; use arrow::{array::StructArray, error::ArrowError}; -use std::sync::Arc; + +use crate::arrow::array_reader::{build_array_reader, ArrayReader}; +use crate::arrow::schema::parquet_to_arrow_schema; +use crate::arrow::schema::parquet_to_arrow_schema_by_columns; +use crate::arrow::ProjectionMask; +use crate::errors::Result; +use crate::file::metadata::{KeyValue, ParquetMetaData}; +use crate::file::reader::{ChunkReader, FileReader, SerializedFileReader}; +use crate::schema::types::SchemaDescriptor; /// Arrow reader api. /// With this api, user can get arrow schema from parquet file, and read parquet data @@ -41,15 +44,8 @@ pub trait ArrowReader { fn get_schema(&mut self) -> Result; /// Read parquet schema and convert it into arrow schema. - /// This schema only includes columns identified by `column_indices`. - /// To select leaf columns (i.e. `a.b.c` instead of `a`), set `leaf_columns = true` - fn get_schema_by_columns( - &mut self, - column_indices: T, - leaf_columns: bool, - ) -> Result - where - T: IntoIterator; + /// This schema only includes columns identified by `mask`. + fn get_schema_by_columns(&mut self, mask: ProjectionMask) -> Result; /// Returns record batch reader from whole parquet file. /// @@ -61,23 +57,48 @@ pub trait ArrowReader { fn get_record_reader(&mut self, batch_size: usize) -> Result; /// Returns record batch reader whose record batch contains columns identified by - /// `column_indices`. + /// `mask`. /// /// # Arguments /// - /// `column_indices`: The columns that should be included in record batches. + /// `mask`: The columns that should be included in record batches. /// `batch_size`: Please refer to `get_record_reader`. - fn get_record_reader_by_columns( + fn get_record_reader_by_columns( &mut self, - column_indices: T, + mask: ProjectionMask, batch_size: usize, - ) -> Result - where - T: IntoIterator; + ) -> Result; +} + +#[derive(Debug, Clone, Default)] +pub struct ArrowReaderOptions { + skip_arrow_metadata: bool, +} + +impl ArrowReaderOptions { + /// Create a new [`ArrowReaderOptions`] with the default settings + fn new() -> Self { + Self::default() + } + + /// Parquet files generated by some writers may contain embedded arrow + /// schema and metadata. This may not be correct or compatible with your system. 
+ /// + /// For example:[ARROW-16184](https://issues.apache.org/jira/browse/ARROW-16184) + /// + + /// Set `skip_arrow_metadata` to true, to skip decoding this + pub fn with_skip_arrow_metadata(self, skip_arrow_metadata: bool) -> Self { + Self { + skip_arrow_metadata, + } + } } pub struct ParquetFileArrowReader { file_reader: Arc, + + options: ArrowReaderOptions, } impl ArrowReader for ParquetFileArrowReader { @@ -85,66 +106,38 @@ impl ArrowReader for ParquetFileArrowReader { fn get_schema(&mut self) -> Result { let file_metadata = self.file_reader.metadata().file_metadata(); - parquet_to_arrow_schema( - file_metadata.schema_descr(), - file_metadata.key_value_metadata(), - ) + parquet_to_arrow_schema(file_metadata.schema_descr(), self.get_kv_metadata()) } - fn get_schema_by_columns( - &mut self, - column_indices: T, - leaf_columns: bool, - ) -> Result - where - T: IntoIterator, - { + fn get_schema_by_columns(&mut self, mask: ProjectionMask) -> Result { let file_metadata = self.file_reader.metadata().file_metadata(); - if leaf_columns { - parquet_to_arrow_schema_by_columns( - file_metadata.schema_descr(), - column_indices, - file_metadata.key_value_metadata(), - ) - } else { - parquet_to_arrow_schema_by_root_columns( - file_metadata.schema_descr(), - column_indices, - file_metadata.key_value_metadata(), - ) - } + parquet_to_arrow_schema_by_columns( + file_metadata.schema_descr(), + mask, + self.get_kv_metadata(), + ) } fn get_record_reader( &mut self, batch_size: usize, ) -> Result { - let column_indices = 0..self - .file_reader - .metadata() - .file_metadata() - .schema_descr() - .num_columns(); - - self.get_record_reader_by_columns(column_indices, batch_size) + self.get_record_reader_by_columns(ProjectionMask::all(), batch_size) } - fn get_record_reader_by_columns( + fn get_record_reader_by_columns( &mut self, - column_indices: T, + mask: ProjectionMask, batch_size: usize, - ) -> Result - where - T: IntoIterator, - { + ) -> Result { let array_reader = build_array_reader( self.file_reader .metadata() .file_metadata() .schema_descr_ptr(), - self.get_schema()?, - column_indices, - self.file_reader.clone(), + Arc::new(self.get_schema()?), + mask, + Box::new(self.file_reader.clone()), )?; ParquetRecordBatchReader::try_new(batch_size, array_reader) @@ -152,14 +145,77 @@ impl ArrowReader for ParquetFileArrowReader { } impl ParquetFileArrowReader { + /// Create a new [`ParquetFileArrowReader`] with the provided [`ChunkReader`] + /// + /// ```no_run + /// # use std::fs::File; + /// # use bytes::Bytes; + /// # use parquet::arrow::ParquetFileArrowReader; + /// + /// let file = File::open("file.parquet").unwrap(); + /// let reader = ParquetFileArrowReader::try_new(file).unwrap(); + /// + /// let bytes = Bytes::from(vec![]); + /// let reader = ParquetFileArrowReader::try_new(bytes).unwrap(); + /// ``` + pub fn try_new(chunk_reader: R) -> Result { + Self::try_new_with_options(chunk_reader, Default::default()) + } + + /// Create a new [`ParquetFileArrowReader`] with the provided [`ChunkReader`] + /// and [`ArrowReaderOptions`] + pub fn try_new_with_options( + chunk_reader: R, + options: ArrowReaderOptions, + ) -> Result { + let file_reader = Arc::new(SerializedFileReader::new(chunk_reader)?); + Ok(Self::new_with_options(file_reader, options)) + } + + /// Create a new [`ParquetFileArrowReader`] with the provided [`Arc`] pub fn new(file_reader: Arc) -> Self { - Self { file_reader } + Self::new_with_options(file_reader, Default::default()) } - // Expose the reader metadata + /// Create a new 
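// Illustrative usage sketch (not part of the original diff), based on the
// options API introduced above: open a reader that ignores any arrow schema
// embedded in the file's key-value metadata. The module paths and the file
// name are assumptions.
use std::fs::File;

use parquet::arrow::arrow_reader::ArrowReaderOptions;
use parquet::arrow::{ArrowReader, ParquetFileArrowReader};

fn main() {
    // placeholder file name
    let file = File::open("data.parquet").unwrap();
    // skip decoding the embedded arrow schema, if any
    let options = ArrowReaderOptions::default().with_skip_arrow_metadata(true);
    let mut reader = ParquetFileArrowReader::try_new_with_options(file, options).unwrap();
    println!("{:?}", reader.get_schema().unwrap());
}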
[`ParquetFileArrowReader`] with the provided [`Arc`] + /// and [`ArrowReaderOptions`] + pub fn new_with_options( + file_reader: Arc, + options: ArrowReaderOptions, + ) -> Self { + Self { + file_reader, + options, + } + } + + /// Expose the reader metadata + #[deprecated = "use metadata() instead"] pub fn get_metadata(&mut self) -> ParquetMetaData { self.file_reader.metadata().clone() } + + /// Returns the parquet metadata + pub fn metadata(&self) -> &ParquetMetaData { + self.file_reader.metadata() + } + + /// Returns the parquet schema + pub fn parquet_schema(&self) -> &SchemaDescriptor { + self.file_reader.metadata().file_metadata().schema_descr() + } + + /// Returns the key value metadata, returns `None` if [`ArrowReaderOptions::skip_arrow_metadata`] + fn get_kv_metadata(&self) -> Option<&Vec> { + if self.options.skip_arrow_metadata { + return None; + } + + self.file_reader + .metadata() + .file_metadata() + .key_value_metadata() + } } pub struct ParquetRecordBatchReader { @@ -181,20 +237,10 @@ impl Iterator for ParquetRecordBatchReader { "Struct array reader should return struct array".to_string(), ) }); + match struct_array { Err(err) => Some(Err(err)), - Ok(e) => { - match RecordBatch::try_new(self.schema.clone(), e.columns_ref()) { - Err(err) => Some(Err(err)), - Ok(record_batch) => { - if record_batch.num_rows() > 0 { - Some(Ok(record_batch)) - } else { - None - } - } - } - } + Ok(e) => (e.len() > 0).then(|| Ok(RecordBatch::from(e))), } } } @@ -212,12 +258,6 @@ impl ParquetRecordBatchReader { batch_size: usize, array_reader: Box, ) -> Result { - // Check that array reader is struct array reader - array_reader - .as_any() - .downcast_ref::() - .ok_or_else(|| general_err!("The input must be struct array reader!"))?; - let schema = match array_reader.get_data_type() { ArrowType::Struct(ref fields) => Schema::new(fields.clone()), _ => unreachable!("Struct array reader's data type is not struct!"), @@ -233,33 +273,45 @@ impl ParquetRecordBatchReader { #[cfg(test)] mod tests { - use crate::arrow::arrow_reader::{ArrowReader, ParquetFileArrowReader}; - use crate::arrow::converter::{ - Converter, FixedSizeArrayConverter, FromConverter, IntervalDayTimeArrayConverter, - Utf8ArrayConverter, + use bytes::Bytes; + use std::cmp::min; + use std::convert::TryFrom; + use std::fs::File; + use std::io::Seek; + use std::path::PathBuf; + use std::sync::Arc; + + use rand::{thread_rng, RngCore}; + use serde_json::json; + use serde_json::Value::{Array as JArray, Null as JNull, Object as JObject}; + use tempfile::tempfile; + + use arrow::array::*; + use arrow::datatypes::{DataType as ArrowDataType, Field, Schema}; + use arrow::error::Result as ArrowResult; + use arrow::record_batch::{RecordBatch, RecordBatchReader}; + + use crate::arrow::arrow_reader::{ + ArrowReader, ArrowReaderOptions, ParquetFileArrowReader, }; - use crate::column::writer::get_typed_column_writer_mut; + use crate::arrow::buffer::converter::{ + BinaryArrayConverter, Converter, FixedSizeArrayConverter, FromConverter, + IntervalDayTimeArrayConverter, LargeUtf8ArrayConverter, Utf8ArrayConverter, + }; + use crate::arrow::schema::add_encoded_arrow_schema_to_metadata; + use crate::arrow::{ArrowWriter, ProjectionMask}; + use crate::basic::{ConvertedType, Encoding, Repetition, Type as PhysicalType}; use crate::data_type::{ BoolType, ByteArray, ByteArrayType, DataType, FixedLenByteArray, - FixedLenByteArrayType, Int32Type, + FixedLenByteArrayType, Int32Type, Int64Type, }; use crate::errors::Result; - use crate::file::properties::WriterProperties; + 
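// Illustrative usage sketch (not part of the original diff) of the
// ProjectionMask API introduced above: select a subset of leaf columns and
// read only those into record batches. The file name and column indices
// are assumptions.
use std::fs::File;

use parquet::arrow::{ArrowReader, ParquetFileArrowReader, ProjectionMask};

fn main() {
    let file = File::open("data.parquet").unwrap(); // placeholder file
    let mut reader = ParquetFileArrowReader::try_new(file).unwrap();

    // select leaf columns 0 and 2 of the parquet schema
    let mask = ProjectionMask::leaves(reader.parquet_schema(), [0, 2]);
    let batches = reader.get_record_reader_by_columns(mask, 1024).unwrap();
    for batch in batches {
        println!("read {} rows", batch.unwrap().num_rows());
    }
}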
use crate::file::properties::{WriterProperties, WriterVersion}; use crate::file::reader::{FileReader, SerializedFileReader}; - use crate::file::writer::{FileWriter, SerializedFileWriter}; + use crate::file::writer::SerializedFileWriter; use crate::schema::parser::parse_message_type; - use crate::schema::types::TypePtr; - use crate::util::test_common::{get_temp_filename, RandGen}; - use arrow::array::*; - use arrow::record_batch::RecordBatchReader; - use rand::RngCore; - use serde_json::json; - use serde_json::Value::{Array as JArray, Null as JNull, Object as JObject}; - use std::cmp::min; - use std::convert::TryFrom; - use std::fs::File; - use std::path::{Path, PathBuf}; - use std::sync::Arc; + use crate::schema::types::{Type, TypePtr}; + use crate::util::test_common::RandGen; #[test] fn test_arrow_reader_all_columns() { @@ -300,12 +352,14 @@ mod tests { let parquet_file_reader = get_test_reader("parquet/generated_simple_numerics/blogs.parquet"); - let max_len = parquet_file_reader.metadata().file_metadata().num_rows() as usize; + let file_metadata = parquet_file_reader.metadata().file_metadata(); + let max_len = file_metadata.num_rows() as usize; + let mask = ProjectionMask::leaves(file_metadata.schema_descr(), [2]); let mut arrow_reader = ParquetFileArrowReader::new(parquet_file_reader); let mut record_batch_reader = arrow_reader - .get_record_reader_by_columns(vec![2], 60) + .get_record_reader_by_columns(mask, 60) .expect("Failed to read into array!"); // Verify that the schema was correctly parsed @@ -317,20 +371,77 @@ mod tests { } #[test] - fn test_bool_single_column_reader_test() { - let message_type = " - message test_schema { - REQUIRED BOOLEAN leaf; - } + fn test_null_column_reader_test() { + let mut file = tempfile::tempfile().unwrap(); + + let schema = " + message message { + OPTIONAL INT32 int32; + } "; + let schema = Arc::new(parse_message_type(schema).unwrap()); + + let def_levels = vec![vec![0, 0, 0], vec![0, 0, 0, 0]]; + generate_single_column_file_with_data::( + &[vec![], vec![]], + Some(&def_levels), + file.try_clone().unwrap(), // Cannot use &mut File (#1163) + schema, + Some(Field::new("int32", ArrowDataType::Null, true)), + &Default::default(), + ) + .unwrap(); - let converter = FromConverter::new(); - run_single_column_reader_tests::< - BoolType, - BooleanArray, - FromConverter>, BooleanArray>, - BoolType, - >(2, message_type, &converter); + file.rewind().unwrap(); + + let mut arrow_reader = ParquetFileArrowReader::try_new(file).unwrap(); + let record_reader = arrow_reader.get_record_reader(2).unwrap(); + + let batches = record_reader.collect::>>().unwrap(); + + assert_eq!(batches.len(), 4); + for batch in &batches[0..3] { + assert_eq!(batch.num_rows(), 2); + assert_eq!(batch.num_columns(), 1); + assert_eq!(batch.column(0).null_count(), 2); + } + + assert_eq!(batches[3].num_rows(), 1); + assert_eq!(batches[3].num_columns(), 1); + assert_eq!(batches[3].column(0).null_count(), 1); + } + + #[test] + fn test_primitive_single_column_reader_test() { + run_single_column_reader_tests::( + 2, + ConvertedType::NONE, + None, + &FromConverter::new(), + &[Encoding::PLAIN, Encoding::RLE, Encoding::RLE_DICTIONARY], + ); + run_single_column_reader_tests::( + 2, + ConvertedType::NONE, + None, + &FromConverter::new(), + &[ + Encoding::PLAIN, + Encoding::RLE_DICTIONARY, + Encoding::DELTA_BINARY_PACKED, + ], + ); + run_single_column_reader_tests::( + 2, + ConvertedType::NONE, + None, + &FromConverter::new(), + &[ + Encoding::PLAIN, + Encoding::RLE_DICTIONARY, + 
Encoding::DELTA_BINARY_PACKED, + ], + ); } struct RandFixedLenGen {} @@ -345,36 +456,36 @@ mod tests { #[test] fn test_fixed_length_binary_column_reader() { - let message_type = " - message test_schema { - REQUIRED FIXED_LEN_BYTE_ARRAY (20) leaf; - } - "; - let converter = FixedSizeArrayConverter::new(20); run_single_column_reader_tests::< FixedLenByteArrayType, FixedSizeBinaryArray, FixedSizeArrayConverter, RandFixedLenGen, - >(20, message_type, &converter); + >( + 20, + ConvertedType::NONE, + None, + &converter, + &[Encoding::PLAIN, Encoding::RLE_DICTIONARY], + ); } #[test] fn test_interval_day_time_column_reader() { - let message_type = " - message test_schema { - REQUIRED FIXED_LEN_BYTE_ARRAY (12) leaf (INTERVAL); - } - "; - let converter = IntervalDayTimeArrayConverter {}; run_single_column_reader_tests::< FixedLenByteArrayType, IntervalDayTimeArray, IntervalDayTimeArrayConverter, RandFixedLenGen, - >(12, message_type, &converter); + >( + 12, + ConvertedType::INTERVAL, + None, + &converter, + &[Encoding::PLAIN, Encoding::RLE_DICTIONARY], + ); } struct RandUtf8Gen {} @@ -387,19 +498,124 @@ mod tests { #[test] fn test_utf8_single_column_reader_test() { - let message_type = " - message test_schema { - REQUIRED BINARY leaf (UTF8); - } - "; + let encodings = &[ + Encoding::PLAIN, + Encoding::RLE_DICTIONARY, + Encoding::DELTA_LENGTH_BYTE_ARRAY, + Encoding::DELTA_BYTE_ARRAY, + ]; + + let converter = BinaryArrayConverter {}; + run_single_column_reader_tests::< + ByteArrayType, + BinaryArray, + BinaryArrayConverter, + RandUtf8Gen, + >(2, ConvertedType::NONE, None, &converter, encodings); + + let utf8_converter = Utf8ArrayConverter {}; + run_single_column_reader_tests::< + ByteArrayType, + StringArray, + Utf8ArrayConverter, + RandUtf8Gen, + >(2, ConvertedType::UTF8, None, &utf8_converter, encodings); - let converter = Utf8ArrayConverter {}; run_single_column_reader_tests::< ByteArrayType, StringArray, Utf8ArrayConverter, RandUtf8Gen, - >(2, message_type, &converter); + >( + 2, + ConvertedType::UTF8, + Some(ArrowDataType::Utf8), + &utf8_converter, + encodings, + ); + + let large_utf8_converter = LargeUtf8ArrayConverter {}; + run_single_column_reader_tests::< + ByteArrayType, + LargeStringArray, + LargeUtf8ArrayConverter, + RandUtf8Gen, + >( + 2, + ConvertedType::UTF8, + Some(ArrowDataType::LargeUtf8), + &large_utf8_converter, + encodings, + ); + + let small_key_types = [ArrowDataType::Int8, ArrowDataType::UInt8]; + for key in &small_key_types { + for encoding in encodings { + let mut opts = TestOptions::new(2, 20, 15).with_null_percent(50); + opts.encoding = *encoding; + + // Cannot run full test suite as keys overflow, run small test instead + single_column_reader_test::< + ByteArrayType, + StringArray, + Utf8ArrayConverter, + RandUtf8Gen, + >( + opts, + 2, + ConvertedType::UTF8, + Some(ArrowDataType::Dictionary( + Box::new(key.clone()), + Box::new(ArrowDataType::Utf8), + )), + &utf8_converter, + ); + } + } + + let key_types = [ + ArrowDataType::Int16, + ArrowDataType::UInt16, + ArrowDataType::Int32, + ArrowDataType::UInt32, + ArrowDataType::Int64, + ArrowDataType::UInt64, + ]; + + for key in &key_types { + run_single_column_reader_tests::< + ByteArrayType, + StringArray, + Utf8ArrayConverter, + RandUtf8Gen, + >( + 2, + ConvertedType::UTF8, + Some(ArrowDataType::Dictionary( + Box::new(key.clone()), + Box::new(ArrowDataType::Utf8), + )), + &utf8_converter, + encodings, + ); + + // https://github.com/apache/arrow-rs/issues/1179 + // run_single_column_reader_tests::< + // ByteArrayType, + // 
LargeStringArray, + // LargeUtf8ArrayConverter, + // RandUtf8Gen, + // >( + // 2, + // ConvertedType::UTF8, + // Some(ArrowDataType::Dictionary( + // Box::new(key.clone()), + // Box::new(ArrowDataType::LargeUtf8), + // )), + // &large_utf8_converter, + // encodings + // ); + } } #[test] @@ -409,9 +625,8 @@ mod tests { let file_variants = vec![("fixed_length", 25), ("int32", 4), ("int64", 10)]; for (prefix, target_precision) in file_variants { let path = format!("{}/{}_decimal.parquet", testdata, prefix); - let parquet_reader = - SerializedFileReader::try_from(File::open(&path).unwrap()).unwrap(); - let mut arrow_reader = ParquetFileArrowReader::new(Arc::new(parquet_reader)); + let file = File::open(&path).unwrap(); + let mut arrow_reader = ParquetFileArrowReader::try_new(file).unwrap(); let mut record_reader = arrow_reader.get_record_reader(32).unwrap(); @@ -429,25 +644,102 @@ mod tests { assert_eq!(col.scale(), 2); for (i, v) in expected.enumerate() { - assert_eq!(col.value(i), v * 100_i128); + assert_eq!(col.value(i).as_i128(), v * 100_i128); } } } /// Parameters for single_column_reader_test - #[derive(Debug)] + #[derive(Debug, Clone)] struct TestOptions { /// Number of row group to write to parquet (row group size = /// num_row_groups / num_rows) num_row_groups: usize, - /// Total number of rows + /// Total number of rows per row group num_rows: usize, /// Size of batches to read back record_batch_size: usize, - /// Total number of batches to attempt to read. - /// `record_batch_size` * `num_iterations` should be greater - /// than `num_rows` to ensure the data can be read back completely - num_iterations: usize, + /// Percentage of nulls in column or None if required + null_percent: Option, + /// Set write batch size + /// + /// This is the number of rows that are written at once to a page and + /// therefore acts as a bound on the page granularity of a row group + write_batch_size: usize, + /// Maximum size of page in bytes + max_data_page_size: usize, + /// Maximum size of dictionary page in bytes + max_dict_page_size: usize, + /// Writer version + writer_version: WriterVersion, + /// Encoding + encoding: Encoding, + } + + impl Default for TestOptions { + fn default() -> Self { + Self { + num_row_groups: 2, + num_rows: 100, + record_batch_size: 15, + null_percent: None, + write_batch_size: 64, + max_data_page_size: 1024 * 1024, + max_dict_page_size: 1024 * 1024, + writer_version: WriterVersion::PARQUET_1_0, + encoding: Encoding::PLAIN, + } + } + } + + impl TestOptions { + fn new(num_row_groups: usize, num_rows: usize, record_batch_size: usize) -> Self { + Self { + num_row_groups, + num_rows, + record_batch_size, + ..Default::default() + } + } + + fn with_null_percent(self, null_percent: usize) -> Self { + Self { + null_percent: Some(null_percent), + ..self + } + } + + fn with_max_data_page_size(self, max_data_page_size: usize) -> Self { + Self { + max_data_page_size, + ..self + } + } + + fn with_max_dict_page_size(self, max_dict_page_size: usize) -> Self { + Self { + max_dict_page_size, + ..self + } + } + + fn writer_props(&self) -> WriterProperties { + let builder = WriterProperties::builder() + .set_data_pagesize_limit(self.max_data_page_size) + .set_write_batch_size(self.write_batch_size) + .set_writer_version(self.writer_version); + + let builder = match self.encoding { + Encoding::RLE_DICTIONARY | Encoding::PLAIN_DICTIONARY => builder + .set_dictionary_enabled(true) + .set_dictionary_pagesize_limit(self.max_dict_page_size), + _ => builder + .set_dictionary_enabled(false) + 
.set_encoding(self.encoding), + }; + + builder.build() + } } /// Create a parquet file and then read it using @@ -458,53 +750,58 @@ mod tests { /// value generator fn run_single_column_reader_tests( rand_max: i32, - message_type: &str, + converted_type: ConvertedType, + arrow_type: Option, converter: &C, + encodings: &[Encoding], ) where T: DataType, G: RandGen, - A: PartialEq + Array + 'static, + A: Array + 'static, C: Converter>, A> + 'static, { let all_options = vec![ // choose record_batch_batch (15) so batches cross row // group boundaries (50 rows in 2 row groups) cases. - TestOptions { - num_row_groups: 2, - num_rows: 100, - record_batch_size: 15, - num_iterations: 50, - }, + TestOptions::new(2, 100, 15), // choose record_batch_batch (5) so batches sometime fall // on row group boundaries and (25 rows in 3 row groups // --> row groups of 10, 10, and 5). Tests buffer // refilling edge cases. - TestOptions { - num_row_groups: 3, - num_rows: 25, - record_batch_size: 5, - num_iterations: 50, - }, + TestOptions::new(3, 25, 5), // Choose record_batch_size (25) so all batches fall // exactly on row group boundary (25). Tests buffer // refilling edge cases. - TestOptions { - num_row_groups: 4, - num_rows: 100, - record_batch_size: 25, - num_iterations: 50, - }, + TestOptions::new(4, 100, 25), + // Set maximum page size so row groups have multiple pages + TestOptions::new(3, 256, 73).with_max_data_page_size(128), + // Set small dictionary page size to test dictionary fallback + TestOptions::new(3, 256, 57).with_max_dict_page_size(128), + // Test optional but with no nulls + TestOptions::new(2, 256, 127).with_null_percent(0), + // Test optional with nulls + TestOptions::new(2, 256, 93).with_null_percent(25), ]; all_options.into_iter().for_each(|opts| { - // Print out options to facilitate debugging failures on CI - println!("Running with Test Options: {:?}", opts); - single_column_reader_test::( - opts, - rand_max, - message_type, - converter, - ) + for writer_version in [WriterVersion::PARQUET_1_0, WriterVersion::PARQUET_2_0] + { + for encoding in encodings { + let opts = TestOptions { + writer_version, + encoding: *encoding, + ..opts + }; + + single_column_reader_test::( + opts, + rand_max, + converted_type, + arrow_type.clone(), + converter, + ) + } + } }); } @@ -514,93 +811,181 @@ mod tests { fn single_column_reader_test( opts: TestOptions, rand_max: i32, - message_type: &str, + converted_type: ConvertedType, + arrow_type: Option, converter: &C, ) where T: DataType, G: RandGen, - A: PartialEq + Array + 'static, + A: Array + 'static, C: Converter>, A> + 'static, { + // Print out options to facilitate debugging failures on CI + println!( + "Running single_column_reader_test ConvertedType::{}/ArrowType::{:?} with Options: {:?}", + converted_type, arrow_type, opts + ); + + let (repetition, def_levels) = match opts.null_percent.as_ref() { + Some(null_percent) => { + let mut rng = thread_rng(); + + let def_levels: Vec> = (0..opts.num_row_groups) + .map(|_| { + std::iter::from_fn(|| { + Some((rng.next_u32() as usize % 100 >= *null_percent) as i16) + }) + .take(opts.num_rows) + .collect() + }) + .collect(); + (Repetition::OPTIONAL, Some(def_levels)) + } + None => (Repetition::REQUIRED, None), + }; + let values: Vec> = (0..opts.num_row_groups) - .map(|_| G::gen_vec(rand_max, opts.num_rows)) + .map(|idx| { + let null_count = match def_levels.as_ref() { + Some(d) => d[idx].iter().filter(|x| **x == 0).count(), + None => 0, + }; + G::gen_vec(rand_max, opts.num_rows - null_count) + }) .collect(); - let 
path = get_temp_filename(); - - let schema = parse_message_type(message_type).map(Arc::new).unwrap(); + let len = match T::get_physical_type() { + crate::basic::Type::FIXED_LEN_BYTE_ARRAY => rand_max, + crate::basic::Type::INT96 => 12, + _ => -1, + }; - generate_single_column_file_with_data::(&values, path.as_path(), schema) - .unwrap(); + let mut fields = vec![Arc::new( + Type::primitive_type_builder("leaf", T::get_physical_type()) + .with_repetition(repetition) + .with_converted_type(converted_type) + .with_length(len) + .build() + .unwrap(), + )]; + + let schema = Arc::new( + Type::group_type_builder("test_schema") + .with_fields(&mut fields) + .build() + .unwrap(), + ); + + let arrow_field = arrow_type + .clone() + .map(|t| arrow::datatypes::Field::new("leaf", t, false)); + + let mut file = tempfile::tempfile().unwrap(); + + generate_single_column_file_with_data::( + &values, + def_levels.as_ref(), + file.try_clone().unwrap(), // Cannot use &mut File (#1163) + schema, + arrow_field, + &opts, + ) + .unwrap(); - let parquet_reader = - SerializedFileReader::try_from(File::open(&path).unwrap()).unwrap(); - let mut arrow_reader = ParquetFileArrowReader::new(Arc::new(parquet_reader)); + file.rewind().unwrap(); + let mut arrow_reader = ParquetFileArrowReader::try_new(file).unwrap(); let mut record_reader = arrow_reader .get_record_reader(opts.record_batch_size) .unwrap(); - let expected_data: Vec> = values - .iter() - .flat_map(|v| v.iter()) - .map(|b| Some(b.clone())) - .collect(); + let expected_data: Vec> = match def_levels { + Some(levels) => { + let mut values_iter = values.iter().flatten(); + levels + .iter() + .flatten() + .map(|d| match d { + 1 => Some(values_iter.next().cloned().unwrap()), + 0 => None, + _ => unreachable!(), + }) + .collect() + } + None => values.iter().flatten().map(|b| Some(b.clone())).collect(), + }; - for i in 0..opts.num_iterations { - let start = i * opts.record_batch_size; + assert_eq!(expected_data.len(), opts.num_rows * opts.num_row_groups); - let batch = record_reader.next(); - if start < expected_data.len() { - let end = min(start + opts.record_batch_size, expected_data.len()); - assert!(batch.is_some()); + let mut total_read = 0; + loop { + let maybe_batch = record_reader.next(); + if total_read < expected_data.len() { + let end = min(total_read + opts.record_batch_size, expected_data.len()); + let batch = maybe_batch.unwrap().unwrap(); + assert_eq!(end - total_read, batch.num_rows()); let mut data = vec![]; - data.extend_from_slice(&expected_data[start..end]); - - assert_eq!( - &converter.convert(data).unwrap(), - batch - .unwrap() - .unwrap() - .column(0) - .as_any() - .downcast_ref::() - .unwrap() - ); + data.extend_from_slice(&expected_data[total_read..end]); + + let a = converter.convert(data).unwrap(); + let mut b = Arc::clone(batch.column(0)); + + if let Some(arrow_type) = arrow_type.as_ref() { + assert_eq!(b.data_type(), arrow_type); + if let ArrowDataType::Dictionary(_, v) = arrow_type { + assert_eq!(a.data_type(), v.as_ref()); + b = arrow::compute::cast(&b, v.as_ref()).unwrap() + } + } + assert_eq!(a.data_type(), b.data_type()); + assert_eq!(a.data(), b.data(), "{:#?} vs {:#?}", a.data(), b.data()); + + total_read = end; } else { - assert!(batch.is_none()); + assert!(maybe_batch.is_none()); + break; } } } fn generate_single_column_file_with_data( values: &[Vec], - path: &Path, + def_levels: Option<&Vec>>, + file: File, schema: TypePtr, + field: Option, + opts: &TestOptions, ) -> Result { - let file = File::create(path)?; - let writer_props = 
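// Illustrative sketch (not part of this diff): the schema built programmatically
// above is the counterpart of the message-type string the removed code parsed,
// e.g. "message test_schema { OPTIONAL BYTE_ARRAY leaf (UTF8); }". The function
// name is hypothetical; every builder call appears elsewhere in this change.
fn example_utf8_leaf_schema() -> Result<TypePtr> {
    let leaf = Arc::new(
        Type::primitive_type_builder("leaf", PhysicalType::BYTE_ARRAY)
            .with_repetition(Repetition::OPTIONAL)
            .with_converted_type(ConvertedType::UTF8)
            .build()?,
    );
    Ok(Arc::new(
        Type::group_type_builder("test_schema")
            .with_fields(&mut vec![leaf])
            .build()?,
    ))
}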
Arc::new(WriterProperties::builder().build()); + let mut writer_props = opts.writer_props(); + if let Some(field) = field { + let arrow_schema = arrow::datatypes::Schema::new(vec![field]); + add_encoded_arrow_schema_to_metadata(&arrow_schema, &mut writer_props); + } - let mut writer = SerializedFileWriter::new(file, schema, writer_props)?; + let mut writer = SerializedFileWriter::new(file, schema, Arc::new(writer_props))?; - for v in values { + for (idx, v) in values.iter().enumerate() { + let def_levels = def_levels.map(|d| d[idx].as_slice()); let mut row_group_writer = writer.next_row_group()?; - let mut column_writer = row_group_writer - .next_column()? - .expect("Column writer is none!"); + { + let mut column_writer = row_group_writer + .next_column()? + .expect("Column writer is none!"); - get_typed_column_writer_mut::(&mut column_writer) - .write_batch(v, None, None)?; + column_writer + .typed::() + .write_batch(v, def_levels, None)?; - row_group_writer.close_column(column_writer)?; - writer.close_row_group(row_group_writer)? + column_writer.close()?; + } + row_group_writer.close()?; } writer.close() } - fn get_test_reader(file_name: &str) -> Arc { + fn get_test_reader(file_name: &str) -> Arc> { let file = get_test_file(file_name); let reader = @@ -657,9 +1042,8 @@ mod tests { // (see: ARROW-11452) let testdata = arrow::util::test_util::parquet_test_data(); let path = format!("{}/nested_structs.rust.parquet", testdata); - let parquet_file_reader = - SerializedFileReader::try_from(File::open(&path).unwrap()).unwrap(); - let mut arrow_reader = ParquetFileArrowReader::new(Arc::new(parquet_file_reader)); + let file = File::open(&path).unwrap(); + let mut arrow_reader = ParquetFileArrowReader::try_new(file).unwrap(); let record_batch_reader = arrow_reader .get_record_reader(60) .expect("Failed to read into array!"); @@ -667,15 +1051,49 @@ mod tests { for batch in record_batch_reader { batch.unwrap(); } + + let mask = ProjectionMask::leaves(arrow_reader.parquet_schema(), [3, 8, 10]); + let projected_reader = arrow_reader + .get_record_reader_by_columns(mask.clone(), 60) + .unwrap(); + let projected_schema = arrow_reader.get_schema_by_columns(mask).unwrap(); + + let expected_schema = Schema::new(vec![ + Field::new( + "roll_num", + ArrowDataType::Struct(vec![Field::new( + "count", + ArrowDataType::UInt64, + false, + )]), + false, + ), + Field::new( + "PC_CUR", + ArrowDataType::Struct(vec![ + Field::new("mean", ArrowDataType::Int64, false), + Field::new("sum", ArrowDataType::Int64, false), + ]), + false, + ), + ]); + + // Tests for #1652 and #1654 + assert_eq!(projected_reader.schema().as_ref(), &projected_schema); + assert_eq!(expected_schema, projected_schema); + + for batch in projected_reader { + let batch = batch.unwrap(); + assert_eq!(batch.schema().as_ref(), &projected_schema); + } } #[test] fn test_read_maps() { let testdata = arrow::util::test_util::parquet_test_data(); let path = format!("{}/nested_maps.snappy.parquet", testdata); - let parquet_file_reader = - SerializedFileReader::try_from(File::open(&path).unwrap()).unwrap(); - let mut arrow_reader = ParquetFileArrowReader::new(Arc::new(parquet_file_reader)); + let file = File::open(&path).unwrap(); + let mut arrow_reader = ParquetFileArrowReader::try_new(file).unwrap(); let record_batch_reader = arrow_reader .get_record_reader(60) .expect("Failed to read into array!"); @@ -684,4 +1102,333 @@ mod tests { batch.unwrap(); } } + + #[test] + fn test_nested_nullability() { + let message_type = "message nested { + OPTIONAL Group group { 
+ REQUIRED INT32 leaf; + } + }"; + + let file = tempfile::tempfile().unwrap(); + let schema = Arc::new(parse_message_type(message_type).unwrap()); + + { + // Write using low-level parquet API (#1167) + let writer_props = Arc::new(WriterProperties::builder().build()); + let mut writer = SerializedFileWriter::new( + file.try_clone().unwrap(), + schema, + writer_props, + ) + .unwrap(); + + { + let mut row_group_writer = writer.next_row_group().unwrap(); + let mut column_writer = row_group_writer.next_column().unwrap().unwrap(); + + column_writer + .typed::() + .write_batch(&[34, 76], Some(&[0, 1, 0, 1]), None) + .unwrap(); + + column_writer.close().unwrap(); + row_group_writer.close().unwrap(); + } + + writer.close().unwrap(); + } + + let mut reader = ParquetFileArrowReader::try_new(file).unwrap(); + let mask = ProjectionMask::leaves(reader.parquet_schema(), [0]); + + let reader = reader.get_record_reader_by_columns(mask, 1024).unwrap(); + + let expected_schema = Schema::new(vec![Field::new( + "group", + ArrowDataType::Struct(vec![Field::new("leaf", ArrowDataType::Int32, false)]), + true, + )]); + + let batch = reader.into_iter().next().unwrap().unwrap(); + assert_eq!(batch.schema().as_ref(), &expected_schema); + assert_eq!(batch.num_rows(), 4); + assert_eq!(batch.column(0).data().null_count(), 2); + } + + #[test] + fn test_invalid_utf8() { + // a parquet file with 1 column with invalid utf8 + let data = vec![ + 80, 65, 82, 49, 21, 6, 21, 22, 21, 22, 92, 21, 2, 21, 0, 21, 2, 21, 0, 21, 4, + 21, 0, 18, 28, 54, 0, 40, 5, 104, 101, 255, 108, 111, 24, 5, 104, 101, 255, + 108, 111, 0, 0, 0, 3, 1, 5, 0, 0, 0, 104, 101, 255, 108, 111, 38, 110, 28, + 21, 12, 25, 37, 6, 0, 25, 24, 2, 99, 49, 21, 0, 22, 2, 22, 102, 22, 102, 38, + 8, 60, 54, 0, 40, 5, 104, 101, 255, 108, 111, 24, 5, 104, 101, 255, 108, 111, + 0, 0, 0, 21, 4, 25, 44, 72, 4, 114, 111, 111, 116, 21, 2, 0, 21, 12, 37, 2, + 24, 2, 99, 49, 37, 0, 76, 28, 0, 0, 0, 22, 2, 25, 28, 25, 28, 38, 110, 28, + 21, 12, 25, 37, 6, 0, 25, 24, 2, 99, 49, 21, 0, 22, 2, 22, 102, 22, 102, 38, + 8, 60, 54, 0, 40, 5, 104, 101, 255, 108, 111, 24, 5, 104, 101, 255, 108, 111, + 0, 0, 0, 22, 102, 22, 2, 0, 40, 44, 65, 114, 114, 111, 119, 50, 32, 45, 32, + 78, 97, 116, 105, 118, 101, 32, 82, 117, 115, 116, 32, 105, 109, 112, 108, + 101, 109, 101, 110, 116, 97, 116, 105, 111, 110, 32, 111, 102, 32, 65, 114, + 114, 111, 119, 0, 130, 0, 0, 0, 80, 65, 82, 49, + ]; + + let file = Bytes::from(data); + let mut arrow_reader = ParquetFileArrowReader::try_new(file).unwrap(); + let mut record_batch_reader = arrow_reader + .get_record_reader_by_columns(ProjectionMask::all(), 10) + .unwrap(); + + let error = record_batch_reader.next().unwrap().unwrap_err(); + + assert!( + error.to_string().contains("invalid utf-8 sequence"), + "{}", + error + ); + } + + #[test] + fn test_dictionary_preservation() { + let mut fields = vec![Arc::new( + Type::primitive_type_builder("leaf", PhysicalType::BYTE_ARRAY) + .with_repetition(Repetition::OPTIONAL) + .with_converted_type(ConvertedType::UTF8) + .build() + .unwrap(), + )]; + + let schema = Arc::new( + Type::group_type_builder("test_schema") + .with_fields(&mut fields) + .build() + .unwrap(), + ); + + let dict_type = ArrowDataType::Dictionary( + Box::new(ArrowDataType::Int32), + Box::new(ArrowDataType::Utf8), + ); + + let arrow_field = Field::new("leaf", dict_type, true); + + let mut file = tempfile::tempfile().unwrap(); + + let values = vec![ + vec![ + ByteArray::from("hello"), + ByteArray::from("a"), + ByteArray::from("b"), + 
ByteArray::from("d"), + ], + vec![ + ByteArray::from("c"), + ByteArray::from("a"), + ByteArray::from("b"), + ], + ]; + + let def_levels = vec![ + vec![1, 0, 0, 1, 0, 0, 1, 1], + vec![0, 0, 1, 1, 0, 0, 1, 0, 0], + ]; + + let opts = TestOptions { + encoding: Encoding::RLE_DICTIONARY, + ..Default::default() + }; + + generate_single_column_file_with_data::( + &values, + Some(&def_levels), + file.try_clone().unwrap(), // Cannot use &mut File (#1163) + schema, + Some(arrow_field), + &opts, + ) + .unwrap(); + + file.rewind().unwrap(); + + let mut arrow_reader = ParquetFileArrowReader::try_new(file).unwrap(); + + let record_reader = arrow_reader.get_record_reader(3).unwrap(); + + let batches = record_reader + .collect::>>() + .unwrap(); + + assert_eq!(batches.len(), 6); + assert!(batches.iter().all(|x| x.num_columns() == 1)); + + let row_counts = batches + .iter() + .map(|x| (x.num_rows(), x.column(0).null_count())) + .collect::>(); + + assert_eq!( + row_counts, + vec![(3, 2), (3, 2), (3, 1), (3, 1), (3, 2), (2, 2)] + ); + + let get_dict = + |batch: &RecordBatch| batch.column(0).data().child_data()[0].clone(); + + // First and second batch in same row group -> same dictionary + assert_eq!(get_dict(&batches[0]), get_dict(&batches[1])); + // Third batch spans row group -> computed dictionary + assert_ne!(get_dict(&batches[1]), get_dict(&batches[2])); + assert_ne!(get_dict(&batches[2]), get_dict(&batches[3])); + // Fourth, fifth and sixth from same row group -> same dictionary + assert_eq!(get_dict(&batches[3]), get_dict(&batches[4])); + assert_eq!(get_dict(&batches[4]), get_dict(&batches[5])); + } + + #[test] + fn test_read_null_list() { + let testdata = arrow::util::test_util::parquet_test_data(); + let path = format!("{}/null_list.parquet", testdata); + let file = File::open(&path).unwrap(); + let mut arrow_reader = ParquetFileArrowReader::try_new(file).unwrap(); + let mut record_batch_reader = arrow_reader + .get_record_reader(60) + .expect("Failed to read into array!"); + + let batch = record_batch_reader.next().unwrap().unwrap(); + assert_eq!(batch.num_rows(), 1); + assert_eq!(batch.num_columns(), 1); + assert_eq!(batch.column(0).len(), 1); + + let list = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(list.len(), 1); + assert!(list.is_valid(0)); + + let val = list.value(0); + assert_eq!(val.len(), 0); + } + + #[test] + fn test_null_schema_inference() { + let testdata = arrow::util::test_util::parquet_test_data(); + let path = format!("{}/null_list.parquet", testdata); + let reader = + Arc::new(SerializedFileReader::try_from(File::open(&path).unwrap()).unwrap()); + + let arrow_field = Field::new( + "emptylist", + ArrowDataType::List(Box::new(Field::new("item", ArrowDataType::Null, true))), + true, + ); + + let options = ArrowReaderOptions::default().with_skip_arrow_metadata(true); + let mut arrow_reader = ParquetFileArrowReader::new_with_options(reader, options); + let schema = arrow_reader.get_schema().unwrap(); + assert_eq!(schema.fields().len(), 1); + assert_eq!(schema.field(0), &arrow_field); + } + + #[test] + fn test_skip_metadata() { + let col = Arc::new(TimestampNanosecondArray::from_iter_values(vec![0, 1, 2])); + let field = Field::new("col", col.data_type().clone(), true); + + let schema_without_metadata = Arc::new(Schema::new(vec![field.clone()])); + + let metadata = [("key".to_string(), "value".to_string())] + .into_iter() + .collect(); + + let schema_with_metadata = + Arc::new(Schema::new(vec![field.with_metadata(Some(metadata))])); + + 
assert_ne!(schema_with_metadata, schema_without_metadata); + + let batch = + RecordBatch::try_new(schema_with_metadata.clone(), vec![col as ArrayRef]) + .unwrap(); + + let file = |version: WriterVersion| { + let props = WriterProperties::builder() + .set_writer_version(version) + .build(); + + let file = tempfile().unwrap(); + let mut writer = ArrowWriter::try_new( + file.try_clone().unwrap(), + batch.schema(), + Some(props), + ) + .unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + file + }; + + let v1_reader = Arc::new( + SerializedFileReader::new(file(WriterVersion::PARQUET_1_0)).unwrap(), + ); + let v2_reader = Arc::new( + SerializedFileReader::new(file(WriterVersion::PARQUET_2_0)).unwrap(), + ); + + let mut arrow_reader = ParquetFileArrowReader::new(v1_reader.clone()); + assert_eq!( + &arrow_reader.get_schema().unwrap(), + schema_with_metadata.as_ref() + ); + + let options = ArrowReaderOptions::new().with_skip_arrow_metadata(true); + let mut arrow_reader = + ParquetFileArrowReader::new_with_options(v1_reader, options); + assert_eq!( + &arrow_reader.get_schema().unwrap(), + schema_without_metadata.as_ref() + ); + + let mut arrow_reader = ParquetFileArrowReader::new(v2_reader.clone()); + assert_eq!( + &arrow_reader.get_schema().unwrap(), + schema_with_metadata.as_ref() + ); + + let options = ArrowReaderOptions::new().with_skip_arrow_metadata(true); + let mut arrow_reader = + ParquetFileArrowReader::new_with_options(v2_reader, options); + assert_eq!( + &arrow_reader.get_schema().unwrap(), + schema_without_metadata.as_ref() + ); + } + + #[test] + fn test_empty_projection() { + let testdata = arrow::util::test_util::parquet_test_data(); + let path = format!("{}/alltypes_plain.parquet", testdata); + let file = File::open(&path).unwrap(); + + let mut arrow_reader = ParquetFileArrowReader::try_new(file).unwrap(); + let file_metadata = arrow_reader.metadata().file_metadata(); + let expected_rows = file_metadata.num_rows() as usize; + let schema = file_metadata.schema_descr_ptr(); + + let mask = ProjectionMask::leaves(&schema, []); + let batch_reader = arrow_reader.get_record_reader_by_columns(mask, 2).unwrap(); + + let mut total_rows = 0; + for maybe_batch in batch_reader { + let batch = maybe_batch.unwrap(); + total_rows += batch.num_rows(); + assert_eq!(batch.num_columns(), 0); + assert!(batch.num_rows() <= 2); + } + + assert_eq!(total_rows, expected_rows); + } } diff --git a/parquet/src/arrow/arrow_writer/levels.rs b/parquet/src/arrow/arrow_writer/levels.rs new file mode 100644 index 000000000000..073754262e29 --- /dev/null +++ b/parquet/src/arrow/arrow_writer/levels.rs @@ -0,0 +1,1403 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Parquet definition and repetition levels +//! +//! 
Contains the algorithm for computing definition and repetition levels. +//! The algorithm works by tracking the slots of an array that should +//! ultimately be populated when writing to Parquet. +//! Parquet achieves nesting through definition levels and repetition levels \[1\]. +//! Definition levels specify how many optional fields in the part for the column +//! are defined. +//! Repetition levels specify at what repeated field (list) in the path a column +//! is defined. +//! +//! In a nested data structure such as `a.b.c`, one can see levels as defining +//! whether a record is defined at `a`, `a.b`, or `a.b.c`. +//! Optional fields are nullable fields, thus if all 3 fields +//! are nullable, the maximum definition could be = 3 if there are no lists. +//! +//! The algorithm in this module computes the necessary information to enable +//! the writer to keep track of which columns are at which levels, and to extract +//! the correct values at the correct slots from Arrow arrays. +//! +//! It works by walking a record batch's arrays, keeping track of what values +//! are non-null, their positions and computing what their levels are. +//! +//! \[1\] [parquet-format#nested-encoding](https://github.com/apache/parquet-format#nested-encoding) + +use crate::errors::{ParquetError, Result}; +use arrow::array::{ + make_array, Array, ArrayData, ArrayRef, GenericListArray, MapArray, OffsetSizeTrait, + StructArray, +}; +use arrow::datatypes::{DataType, Field}; +use std::ops::Range; + +/// Performs a depth-first scan of the children of `array`, constructing [`LevelInfo`] +/// for each leaf column encountered +pub(crate) fn calculate_array_levels( + array: &ArrayRef, + field: &Field, +) -> Result> { + let mut builder = LevelInfoBuilder::try_new(field, Default::default())?; + builder.write(array, 0..array.len()); + Ok(builder.finish()) +} + +/// Returns true if the DataType can be represented as a primitive parquet column, +/// i.e. 
a leaf array with no children +fn is_leaf(data_type: &DataType) -> bool { + matches!( + data_type, + DataType::Null + | DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float16 + | DataType::Float32 + | DataType::Float64 + | DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Timestamp(_, _) + | DataType::Date32 + | DataType::Date64 + | DataType::Time32(_) + | DataType::Time64(_) + | DataType::Duration(_) + | DataType::Interval(_) + | DataType::Binary + | DataType::LargeBinary + | DataType::Decimal(_, _) + | DataType::FixedSizeBinary(_) + ) +} + +/// The definition and repetition level of an array within a potentially nested hierarchy +#[derive(Debug, Default, Clone, Copy)] +struct LevelContext { + /// The current repetition level + rep_level: i16, + /// The current definition level + def_level: i16, +} + +/// A helper to construct [`LevelInfo`] from a potentially nested [`Field`] +enum LevelInfoBuilder { + /// A primitive, leaf array + Primitive(LevelInfo), + /// A list array, contains the [`LevelInfoBuilder`] of the child and + /// the [`LevelContext`] of this list + List(Box, LevelContext), + /// A list array, contains the [`LevelInfoBuilder`] of its children and + /// the [`LevelContext`] of this struct array + Struct(Vec, LevelContext), +} + +impl LevelInfoBuilder { + /// Create a new [`LevelInfoBuilder`] for the given [`Field`] and parent [`LevelContext`] + fn try_new(field: &Field, parent_ctx: LevelContext) -> Result { + match field.data_type() { + d if is_leaf(d) => Ok(Self::Primitive(LevelInfo::new( + parent_ctx, + field.is_nullable(), + ))), + DataType::Dictionary(_, v) if is_leaf(v.as_ref()) => Ok(Self::Primitive( + LevelInfo::new(parent_ctx, field.is_nullable()), + )), + DataType::Struct(children) => { + let def_level = match field.is_nullable() { + true => parent_ctx.def_level + 1, + false => parent_ctx.def_level, + }; + + let ctx = LevelContext { + rep_level: parent_ctx.rep_level, + def_level, + }; + + let children = children + .iter() + .map(|f| Self::try_new(f, ctx)) + .collect::>()?; + + Ok(Self::Struct(children, ctx)) + } + DataType::List(child) + | DataType::LargeList(child) + | DataType::Map(child, _) => { + let def_level = match field.is_nullable() { + true => parent_ctx.def_level + 2, + false => parent_ctx.def_level + 1, + }; + + let ctx = LevelContext { + rep_level: parent_ctx.rep_level + 1, + def_level, + }; + + let child = Self::try_new(child.as_ref(), ctx)?; + Ok(Self::List(Box::new(child), ctx)) + } + d => Err(nyi_err!("Datatype {} is not yet supported", d)), + } + } + + /// Finish this [`LevelInfoBuilder`] returning the [`LevelInfo`] for the leaf columns + /// as enumerated by a depth-first search + fn finish(self) -> Vec { + match self { + LevelInfoBuilder::Primitive(v) => vec![v], + LevelInfoBuilder::List(v, _) => v.finish(), + LevelInfoBuilder::Struct(v, _) => { + v.into_iter().flat_map(|l| l.finish()).collect() + } + } + } + + /// Given an `array`, write the level data for the elements in `range` + fn write(&mut self, array: &ArrayRef, range: Range) { + match array.data_type() { + d if is_leaf(d) => self.write_leaf(array, range), + DataType::Dictionary(_, v) if is_leaf(v.as_ref()) => { + self.write_leaf(array, range) + } + DataType::Struct(_) => { + let array = array.as_any().downcast_ref::().unwrap(); + self.write_struct(array, range) + } + DataType::List(_) => { + let array = array + .as_any() + 
.downcast_ref::>() + .unwrap(); + self.write_list(array.value_offsets(), array.data(), range) + } + DataType::LargeList(_) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + self.write_list(array.value_offsets(), array.data(), range) + } + DataType::Map(_, _) => { + let array = array.as_any().downcast_ref::().unwrap(); + // A Map is just as ListArray with a StructArray child, we therefore + // treat it as such to avoid code duplication + self.write_list(array.value_offsets(), array.data(), range) + } + _ => unreachable!(), + } + } + + /// Write `range` elements from ListArray `array` + /// + /// Note: MapArrays are ListArray under the hood and so are dispatched to this method + fn write_list( + &mut self, + offsets: &[O], + list_data: &ArrayData, + range: Range, + ) { + let (child, ctx) = match self { + Self::List(child, ctx) => (child, ctx), + _ => unreachable!(), + }; + + let offsets = &offsets[range.start..range.end + 1]; + let child_array = make_array(list_data.child_data()[0].clone()); + + let write_non_null_slice = + |child: &mut LevelInfoBuilder, start_idx: usize, end_idx: usize| { + child.write(&child_array, start_idx..end_idx); + child.visit_leaves(|leaf| { + let rep_levels = leaf.rep_levels.as_mut().unwrap(); + let mut rev = rep_levels.iter_mut().rev(); + let mut remaining = end_idx - start_idx; + + loop { + let next = rev.next().unwrap(); + if *next > ctx.rep_level { + // Nested element - ignore + continue; + } + + remaining -= 1; + if remaining == 0 { + *next = ctx.rep_level - 1; + break; + } + } + }) + }; + + let write_empty_slice = |child: &mut LevelInfoBuilder| { + child.visit_leaves(|leaf| { + let rep_levels = leaf.rep_levels.as_mut().unwrap(); + rep_levels.push(ctx.rep_level - 1); + let def_levels = leaf.def_levels.as_mut().unwrap(); + def_levels.push(ctx.def_level - 1); + }) + }; + + let write_null_slice = |child: &mut LevelInfoBuilder| { + child.visit_leaves(|leaf| { + let rep_levels = leaf.rep_levels.as_mut().unwrap(); + rep_levels.push(ctx.rep_level - 1); + let def_levels = leaf.def_levels.as_mut().unwrap(); + def_levels.push(ctx.def_level - 2); + }) + }; + + match list_data.null_bitmap() { + Some(nulls) => { + let null_offset = list_data.offset() + range.start; + // TODO: Faster bitmask iteration (#1757) + for (idx, w) in offsets.windows(2).enumerate() { + let is_valid = nulls.is_set(idx + null_offset); + let start_idx = w[0].to_usize().unwrap(); + let end_idx = w[1].to_usize().unwrap(); + if !is_valid { + write_null_slice(child) + } else if start_idx == end_idx { + write_empty_slice(child) + } else { + write_non_null_slice(child, start_idx, end_idx) + } + } + } + None => { + for w in offsets.windows(2) { + let start_idx = w[0].to_usize().unwrap(); + let end_idx = w[1].to_usize().unwrap(); + if start_idx == end_idx { + write_empty_slice(child) + } else { + write_non_null_slice(child, start_idx, end_idx) + } + } + } + } + } + + /// Write `range` elements from StructArray `array` + fn write_struct(&mut self, array: &StructArray, range: Range) { + let (children, ctx) = match self { + Self::Struct(children, ctx) => (children, ctx), + _ => unreachable!(), + }; + + let write_null = |children: &mut [LevelInfoBuilder], range: Range| { + for child in children { + child.visit_leaves(|info| { + let len = range.end - range.start; + + let def_levels = info.def_levels.as_mut().unwrap(); + def_levels.extend(std::iter::repeat(ctx.def_level - 1).take(len)); + + if let Some(rep_levels) = info.rep_levels.as_mut() { + 
rep_levels.extend(std::iter::repeat(ctx.rep_level).take(len)); + } + }) + } + }; + + let write_non_null = |children: &mut [LevelInfoBuilder], range: Range| { + for (child_array, child) in array.columns().into_iter().zip(children) { + child.write(child_array, range.clone()) + } + }; + + match array.data().null_bitmap() { + Some(validity) => { + let null_offset = array.data().offset(); + let mut last_non_null_idx = None; + let mut last_null_idx = None; + + // TODO: Faster bitmask iteration (#1757) + for i in range.clone() { + match validity.is_set(i + null_offset) { + true => { + if let Some(last_idx) = last_null_idx.take() { + write_null(children, last_idx..i) + } + last_non_null_idx.get_or_insert(i); + } + false => { + if let Some(last_idx) = last_non_null_idx.take() { + write_non_null(children, last_idx..i) + } + last_null_idx.get_or_insert(i); + } + } + } + + if let Some(last_idx) = last_null_idx.take() { + write_null(children, last_idx..range.end) + } + + if let Some(last_idx) = last_non_null_idx.take() { + write_non_null(children, last_idx..range.end) + } + } + None => write_non_null(children, range), + } + } + + /// Write a primitive array, as defined by [`is_leaf`] + fn write_leaf(&mut self, array: &ArrayRef, range: Range) { + let info = match self { + Self::Primitive(info) => info, + _ => unreachable!(), + }; + + let len = range.end - range.start; + + match &mut info.def_levels { + Some(def_levels) => { + def_levels.reserve(len); + info.non_null_indices.reserve(len); + + match array.data().null_bitmap() { + Some(nulls) => { + let nulls_offset = array.data().offset(); + // TODO: Faster bitmask iteration (#1757) + for i in range { + match nulls.is_set(i + nulls_offset) { + true => { + def_levels.push(info.max_def_level); + info.non_null_indices.push(i) + } + false => def_levels.push(info.max_def_level - 1), + } + } + } + None => { + let iter = std::iter::repeat(info.max_def_level).take(len); + def_levels.extend(iter); + info.non_null_indices.extend(range); + } + } + } + None => info.non_null_indices.extend(range), + } + + if let Some(rep_levels) = &mut info.rep_levels { + rep_levels.extend(std::iter::repeat(info.max_rep_level).take(len)) + } + } + + /// Visits all children of this node in depth first order + fn visit_leaves(&mut self, visit: impl Fn(&mut LevelInfo) + Copy) { + match self { + LevelInfoBuilder::Primitive(info) => visit(info), + LevelInfoBuilder::List(c, _) => c.visit_leaves(visit), + LevelInfoBuilder::Struct(children, _) => { + for c in children { + c.visit_leaves(visit) + } + } + } + } +} +/// The data necessary to write a primitive Arrow array to parquet, taking into account +/// any non-primitive parents it may have in the arrow representation +#[derive(Debug, Eq, PartialEq, Clone)] +pub(crate) struct LevelInfo { + /// Array's definition levels + /// + /// Present if `max_def_level != 0` + def_levels: Option>, + + /// Array's optional repetition levels + /// + /// Present if `max_rep_level != 0` + rep_levels: Option>, + + /// The corresponding array identifying non-null slices of data + /// from the primitive array + non_null_indices: Vec, + + /// The maximum definition level for this leaf column + max_def_level: i16, + + /// The maximum repetition for this leaf column + max_rep_level: i16, +} + +impl LevelInfo { + fn new(ctx: LevelContext, is_nullable: bool) -> Self { + let max_rep_level = ctx.rep_level; + let max_def_level = match is_nullable { + true => ctx.def_level + 1, + false => ctx.def_level, + }; + + Self { + def_levels: (max_def_level != 0).then(Vec::new), 
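// Illustrative sketch (not part of this diff): how the maximum levels above grow
// along the path from the root to a leaf, mirroring `LevelInfoBuilder::try_new`.
// Each nullable field adds one definition level; each list adds one repetition
// level and one definition level (two if the list itself is nullable). The
// `PathStep`/`max_levels` names are hypothetical and use only std Rust.
#[derive(Clone, Copy)]
struct PathStep {
    nullable: bool,
    repeated: bool,
}

fn max_levels(path: &[PathStep]) -> (i16, i16) {
    let (mut max_def, mut max_rep) = (0i16, 0i16);
    for step in path {
        if step.repeated {
            max_rep += 1;
            // an empty list is distinguishable from a null list
            max_def += if step.nullable { 2 } else { 1 };
        } else if step.nullable {
            max_def += 1;
        }
    }
    (max_def, max_rep)
}
// A nullable List of non-null Int32 gives (2, 1); wrapped in a nullable struct
// it gives (3, 1), matching the expected LevelInfo values in the tests below.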
+ rep_levels: (max_rep_level != 0).then(Vec::new), + non_null_indices: vec![], + max_def_level, + max_rep_level, + } + } + + pub fn def_levels(&self) -> Option<&[i16]> { + self.def_levels.as_deref() + } + + pub fn rep_levels(&self) -> Option<&[i16]> { + self.rep_levels.as_deref() + } + + pub fn non_null_indices(&self) -> &[usize] { + &self.non_null_indices + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::sync::Arc; + + use arrow::array::*; + use arrow::buffer::Buffer; + use arrow::datatypes::{Int32Type, Schema, ToByteSlice}; + use arrow::record_batch::RecordBatch; + use arrow::util::pretty::pretty_format_columns; + + #[test] + fn test_calculate_array_levels_twitter_example() { + // based on the example at https://blog.twitter.com/engineering/en_us/a/2013/dremel-made-simple-with-parquet.html + // [[a, b, c], [d, e, f, g]], [[h], [i,j]] + + let leaf_type = Field::new("item", DataType::Int32, false); + let inner_type = DataType::List(Box::new(leaf_type)); + let inner_field = Field::new("l2", inner_type.clone(), false); + let outer_type = DataType::List(Box::new(inner_field)); + let outer_field = Field::new("l1", outer_type.clone(), false); + + let primitives = Int32Array::from_iter(0..10); + + // Cannot use from_iter_primitive as always infers nullable + let offsets = Buffer::from_iter([0_i32, 3, 7, 8, 10]); + let inner_list = ArrayDataBuilder::new(inner_type) + .len(4) + .add_buffer(offsets) + .add_child_data(primitives.data().clone()) + .build() + .unwrap(); + + let offsets = Buffer::from_iter([0_i32, 2, 4]); + let outer_list = ArrayDataBuilder::new(outer_type) + .len(2) + .add_buffer(offsets) + .add_child_data(inner_list) + .build() + .unwrap(); + let outer_list = make_array(outer_list); + + let levels = calculate_array_levels(&outer_list, &outer_field).unwrap(); + assert_eq!(levels.len(), 1); + + let expected = LevelInfo { + def_levels: Some(vec![2; 10]), + rep_levels: Some(vec![0, 2, 2, 1, 2, 2, 2, 0, 1, 2]), + non_null_indices: vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + max_def_level: 2, + max_rep_level: 2, + }; + assert_eq!(&levels[0], &expected); + } + + #[test] + fn test_calculate_one_level_1() { + // This test calculates the levels for a non-null primitive array + let array = Arc::new(Int32Array::from_iter(0..10)) as ArrayRef; + let field = Field::new("item", DataType::Int32, false); + + let levels = calculate_array_levels(&array, &field).unwrap(); + assert_eq!(levels.len(), 1); + + let expected_levels = LevelInfo { + def_levels: None, + rep_levels: None, + non_null_indices: (0..10).collect(), + max_def_level: 0, + max_rep_level: 0, + }; + assert_eq!(&levels[0], &expected_levels); + } + + #[test] + fn test_calculate_one_level_2() { + // This test calculates the levels for a nullable primitive array + let array = Arc::new(Int32Array::from_iter([ + Some(0), + None, + Some(0), + Some(0), + None, + ])) as ArrayRef; + let field = Field::new("item", DataType::Int32, true); + + let levels = calculate_array_levels(&array, &field).unwrap(); + assert_eq!(levels.len(), 1); + + let expected_levels = LevelInfo { + def_levels: Some(vec![1, 0, 1, 1, 0]), + rep_levels: None, + non_null_indices: vec![0, 2, 3], + max_def_level: 1, + max_rep_level: 0, + }; + assert_eq!(&levels[0], &expected_levels); + } + + #[test] + fn test_calculate_array_levels_1() { + let leaf_field = Field::new("item", DataType::Int32, false); + let list_type = DataType::List(Box::new(leaf_field)); + + // if all array values are defined (e.g. 
batch>) + // [[0], [1], [2], [3], [4]] + + let leaf_array = Int32Array::from_iter(0..5); + // Cannot use from_iter_primitive as always infers nullable + let offsets = Buffer::from_iter(0_i32..6); + let list = ArrayDataBuilder::new(list_type.clone()) + .len(5) + .add_buffer(offsets) + .add_child_data(leaf_array.data().clone()) + .build() + .unwrap(); + let list = make_array(list); + + let list_field = Field::new("list", list_type.clone(), false); + let levels = calculate_array_levels(&list, &list_field).unwrap(); + assert_eq!(levels.len(), 1); + + let expected_levels = LevelInfo { + def_levels: Some(vec![1; 5]), + rep_levels: Some(vec![0; 5]), + non_null_indices: (0..5).collect(), + max_def_level: 1, + max_rep_level: 1, + }; + assert_eq!(&levels[0], &expected_levels); + + // array: [[0, 0], NULL, [2, 2], [3, 3, 3, 3], [4, 4, 4]] + // all values are defined as we do not have nulls on the root (batch) + // repetition: + // 0: 0, 1 + // 1: 0 + // 2: 0, 1 + // 3: 0, 1, 1, 1 + // 4: 0, 1, 1 + let leaf_array = Int32Array::from_iter([0, 0, 2, 2, 3, 3, 3, 3, 4, 4, 4]); + let offsets = Buffer::from_iter([0_i32, 2, 2, 4, 8, 11]); + let list = ArrayDataBuilder::new(list_type.clone()) + .len(5) + .add_buffer(offsets) + .add_child_data(leaf_array.data().clone()) + .null_bit_buffer(Some(Buffer::from([0b00011101]))) + .build() + .unwrap(); + let list = make_array(list); + + let list_field = Field::new("list", list_type, true); + let levels = calculate_array_levels(&list, &list_field).unwrap(); + assert_eq!(levels.len(), 1); + + let expected_levels = LevelInfo { + def_levels: Some(vec![2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2]), + rep_levels: Some(vec![0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1]), + non_null_indices: (0..11).collect(), + max_def_level: 2, + max_rep_level: 1, + }; + assert_eq!(&levels[0], &expected_levels); + } + + #[test] + fn test_calculate_array_levels_2() { + // If some values are null + // + // This emulates an array in the form: > + // with values: + // - 0: [0, 1], but is null because of the struct + // - 1: [] + // - 2: [2, 3], but is null because of the struct + // - 3: [4, 5, 6, 7] + // - 4: [8, 9, 10] + // + // If the first values of a list are null due to a parent, we have to still account for them + // while indexing, because they would affect the way the child is indexed + // i.e. 
in the above example, we have to know that [0, 1] has to be skipped + let leaf = Int32Array::from_iter(0..11); + let leaf_field = Field::new("leaf", DataType::Int32, false); + + let list_type = DataType::List(Box::new(leaf_field)); + let list = ArrayData::builder(list_type.clone()) + .len(5) + .add_child_data(leaf.data().clone()) + .add_buffer(Buffer::from_iter([0_i32, 2, 2, 4, 8, 11])) + .build() + .unwrap(); + + let list = make_array(list); + let list_field = Field::new("list", list_type, true); + + let struct_array = + StructArray::from((vec![(list_field, list)], Buffer::from([0b00011010]))); + let array = Arc::new(struct_array) as ArrayRef; + + let struct_field = Field::new("struct", array.data_type().clone(), true); + + let levels = calculate_array_levels(&array, &struct_field).unwrap(); + assert_eq!(levels.len(), 1); + + let expected_levels = LevelInfo { + def_levels: Some(vec![0, 2, 0, 3, 3, 3, 3, 3, 3, 3]), + rep_levels: Some(vec![0, 0, 0, 0, 1, 1, 1, 0, 1, 1]), + non_null_indices: (4..11).collect(), + max_def_level: 3, + max_rep_level: 1, + }; + + assert_eq!(&levels[0], &expected_levels); + + // nested lists + + // 0: [[100, 101], [102, 103]] + // 1: [] + // 2: [[104, 105], [106, 107]] + // 3: [[108, 109], [110, 111], [112, 113], [114, 115]] + // 4: [[116, 117], [118, 119], [120, 121]] + + let leaf = Int32Array::from_iter(100..122); + let leaf_field = Field::new("leaf", DataType::Int32, true); + + let l1_type = DataType::List(Box::new(leaf_field)); + let offsets = Buffer::from_iter([0_i32, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22]); + let l1 = ArrayData::builder(l1_type.clone()) + .len(11) + .add_child_data(leaf.data().clone()) + .add_buffer(offsets) + .build() + .unwrap(); + + let l1_field = Field::new("l1", l1_type, true); + let l2_type = DataType::List(Box::new(l1_field)); + let l2 = ArrayData::builder(l2_type) + .len(5) + .add_child_data(l1) + .add_buffer(Buffer::from_iter([0, 2, 2, 4, 8, 11])) + .build() + .unwrap(); + + let l2 = make_array(l2); + let l2_field = Field::new("l2", l2.data_type().clone(), true); + + let levels = calculate_array_levels(&l2, &l2_field).unwrap(); + assert_eq!(levels.len(), 1); + + let expected_levels = LevelInfo { + def_levels: Some(vec![ + 5, 5, 5, 5, 1, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + ]), + rep_levels: Some(vec![ + 0, 2, 1, 2, 0, 0, 2, 1, 2, 0, 2, 1, 2, 1, 2, 1, 2, 0, 2, 1, 2, 1, 2, + ]), + non_null_indices: (0..22).collect(), + max_def_level: 5, + max_rep_level: 2, + }; + + assert_eq!(&levels[0], &expected_levels); + } + + #[test] + fn test_calculate_array_levels_nested_list() { + let leaf_field = Field::new("leaf", DataType::Int32, false); + let list_type = DataType::List(Box::new(leaf_field)); + + // if all array values are defined (e.g. 
batch>) + // The array at this level looks like: + // 0: [a] + // 1: [a] + // 2: [a] + // 3: [a] + + let leaf = Int32Array::from_iter([0; 4]); + let list = ArrayData::builder(list_type.clone()) + .len(4) + .add_buffer(Buffer::from_iter(0_i32..5)) + .add_child_data(leaf.data().clone()) + .build() + .unwrap(); + let list = make_array(list); + + let list_field = Field::new("list", list_type.clone(), false); + let levels = calculate_array_levels(&list, &list_field).unwrap(); + assert_eq!(levels.len(), 1); + + let expected_levels = LevelInfo { + def_levels: Some(vec![1; 4]), + rep_levels: Some(vec![0; 4]), + non_null_indices: (0..4).collect(), + max_def_level: 1, + max_rep_level: 1, + }; + assert_eq!(&levels[0], &expected_levels); + + // 0: null + // 1: [1, 2, 3] + // 2: [4, 5] + // 3: [6, 7] + let leaf = Int32Array::from_iter(0..8); + let list = ArrayData::builder(list_type.clone()) + .len(4) + .add_buffer(Buffer::from_iter([0_i32, 0, 3, 5, 7])) + .null_bit_buffer(Some(Buffer::from([0b00001110]))) + .add_child_data(leaf.data().clone()) + .build() + .unwrap(); + let list = make_array(list); + let list_field = Field::new("list", list_type, true); + + let struct_array = StructArray::from(vec![(list_field, list)]); + let array = Arc::new(struct_array) as ArrayRef; + + let struct_field = Field::new("struct", array.data_type().clone(), true); + let levels = calculate_array_levels(&array, &struct_field).unwrap(); + assert_eq!(levels.len(), 1); + + let expected_levels = LevelInfo { + def_levels: Some(vec![1, 3, 3, 3, 3, 3, 3, 3]), + rep_levels: Some(vec![0, 0, 1, 1, 0, 1, 0, 1]), + non_null_indices: (0..7).collect(), + max_def_level: 3, + max_rep_level: 1, + }; + assert_eq!(&levels[0], &expected_levels); + + // nested lists + // In a JSON syntax with the schema: >>>, this translates into: + // 0: {"struct": null } + // 1: {"struct": [ [201], [202, 203], [] ]} + // 2: {"struct": [ [204, 205, 206], [207, 208, 209, 210] ]} + // 3: {"struct": [ [], [211, 212, 213, 214, 215] ]} + + let leaf = Int32Array::from_iter(201..216); + let leaf_field = Field::new("leaf", DataType::Int32, false); + let list_1_type = DataType::List(Box::new(leaf_field)); + let list_1 = ArrayData::builder(list_1_type.clone()) + .len(7) + .add_buffer(Buffer::from_iter([0_i32, 1, 3, 3, 6, 10, 10, 15])) + .add_child_data(leaf.data().clone()) + .build() + .unwrap(); + + let list_1_field = Field::new("l1", list_1_type, true); + let list_2_type = DataType::List(Box::new(list_1_field)); + let list_2 = ArrayData::builder(list_2_type.clone()) + .len(4) + .add_buffer(Buffer::from_iter([0_i32, 0, 3, 5, 7])) + .null_bit_buffer(Some(Buffer::from([0b00001110]))) + .add_child_data(list_1) + .build() + .unwrap(); + + let list_2 = make_array(list_2); + let list_2_field = Field::new("list_2", list_2_type, true); + + let struct_array = + StructArray::from((vec![(list_2_field, list_2)], Buffer::from([0b00001111]))); + let struct_field = Field::new("struct", struct_array.data_type().clone(), true); + + let array = Arc::new(struct_array) as ArrayRef; + let levels = calculate_array_levels(&array, &struct_field).unwrap(); + assert_eq!(levels.len(), 1); + + let expected_levels = LevelInfo { + def_levels: Some(vec![1, 5, 5, 5, 4, 5, 5, 5, 5, 5, 5, 5, 4, 5, 5, 5, 5, 5]), + rep_levels: Some(vec![0, 0, 1, 2, 1, 0, 2, 2, 1, 2, 2, 2, 0, 1, 2, 2, 2, 2]), + non_null_indices: (0..15).collect(), + max_def_level: 5, + max_rep_level: 2, + }; + assert_eq!(&levels[0], &expected_levels); + } + + #[test] + fn test_calculate_nested_struct_levels() { + // tests a > + // array: 
+ // - {a: {b: {c: 1}}} + // - {a: {b: {c: null}}} + // - {a: {b: {c: 3}}} + // - {a: {b: null}} + // - {a: null}} + // - {a: {b: {c: 6}}} + + let c = Int32Array::from_iter([Some(1), None, Some(3), None, Some(5), Some(6)]); + let c_field = Field::new("c", DataType::Int32, true); + let b = StructArray::from(( + (vec![(c_field, Arc::new(c) as ArrayRef)]), + Buffer::from([0b00110111]), + )); + + let b_field = Field::new("b", b.data_type().clone(), true); + let a = StructArray::from(( + (vec![(b_field, Arc::new(b) as ArrayRef)]), + Buffer::from([0b00101111]), + )); + + let a_field = Field::new("a", a.data_type().clone(), true); + let a_array = Arc::new(a) as ArrayRef; + + let levels = calculate_array_levels(&a_array, &a_field).unwrap(); + assert_eq!(levels.len(), 1); + + let expected_levels = LevelInfo { + def_levels: Some(vec![3, 2, 3, 1, 0, 3]), + rep_levels: None, + non_null_indices: vec![0, 2, 5], + max_def_level: 3, + max_rep_level: 0, + }; + assert_eq!(&levels[0], &expected_levels); + } + + #[test] + fn list_single_column() { + // this tests the level generation from the arrow_writer equivalent test + + let a_values = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + let a_value_offsets = arrow::buffer::Buffer::from_iter([0_i32, 1, 3, 3, 6, 10]); + let a_list_type = + DataType::List(Box::new(Field::new("item", DataType::Int32, true))); + let a_list_data = ArrayData::builder(a_list_type.clone()) + .len(5) + .add_buffer(a_value_offsets) + .null_bit_buffer(Some(Buffer::from(vec![0b00011011]))) + .add_child_data(a_values.data().clone()) + .build() + .unwrap(); + + assert_eq!(a_list_data.null_count(), 1); + + let a = ListArray::from(a_list_data); + let values = Arc::new(a) as _; + + let item_field = Field::new("item", a_list_type, true); + let mut builder = + LevelInfoBuilder::try_new(&item_field, Default::default()).unwrap(); + builder.write(&values, 2..4); + let levels = builder.finish(); + + assert_eq!(levels.len(), 1); + + let list_level = levels.get(0).unwrap(); + + let expected_level = LevelInfo { + def_levels: Some(vec![0, 3, 3, 3]), + rep_levels: Some(vec![0, 0, 1, 1]), + non_null_indices: vec![3, 4, 5], + max_def_level: 3, + max_rep_level: 1, + }; + assert_eq!(list_level, &expected_level); + } + + #[test] + fn mixed_struct_list() { + // this tests the level generation from the equivalent arrow_writer_complex test + + // define schema + let struct_field_d = Field::new("d", DataType::Float64, true); + let struct_field_f = Field::new("f", DataType::Float32, true); + let struct_field_g = Field::new( + "g", + DataType::List(Box::new(Field::new("items", DataType::Int16, false))), + false, + ); + let struct_field_e = Field::new( + "e", + DataType::Struct(vec![struct_field_f.clone(), struct_field_g.clone()]), + true, + ); + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, true), + Field::new( + "c", + DataType::Struct(vec![struct_field_d.clone(), struct_field_e.clone()]), + true, // https://github.com/apache/arrow-rs/issues/245 + ), + ]); + + // create some data + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + let b = Int32Array::from(vec![Some(1), None, None, Some(4), Some(5)]); + let d = Float64Array::from(vec![None, None, None, Some(1.0), None]); + let f = Float32Array::from(vec![Some(0.0), None, Some(333.3), None, Some(5.25)]); + + let g_value = Int16Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + + // Construct a buffer for value offsets, for the nested array: + // [[1], [2, 3], null, [4, 5, 6], [7, 8, 9, 10]] + 
let g_value_offsets = + arrow::buffer::Buffer::from(&[0, 1, 3, 3, 6, 10].to_byte_slice()); + + // Construct a list array from the above two + let g_list_data = ArrayData::builder(struct_field_g.data_type().clone()) + .len(5) + .add_buffer(g_value_offsets) + .add_child_data(g_value.data().clone()) + .build() + .unwrap(); + let g = ListArray::from(g_list_data); + + let e = StructArray::from(vec![ + (struct_field_f, Arc::new(f) as ArrayRef), + (struct_field_g, Arc::new(g) as ArrayRef), + ]); + + let c = StructArray::from(vec![ + (struct_field_d, Arc::new(d) as ArrayRef), + (struct_field_e, Arc::new(e) as ArrayRef), + ]); + + // build a record batch + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(a), Arc::new(b), Arc::new(c)], + ) + .unwrap(); + + ////////////////////////////////////////////// + // calculate the list's level + let mut levels = vec![]; + batch + .columns() + .iter() + .zip(batch.schema().fields()) + .for_each(|(array, field)| { + let mut array_levels = calculate_array_levels(array, field).unwrap(); + levels.append(&mut array_levels); + }); + assert_eq!(levels.len(), 5); + + // test "a" levels + let list_level = levels.get(0).unwrap(); + + let expected_level = LevelInfo { + def_levels: None, + rep_levels: None, + non_null_indices: vec![0, 1, 2, 3, 4], + max_def_level: 0, + max_rep_level: 0, + }; + assert_eq!(list_level, &expected_level); + + // test "b" levels + let list_level = levels.get(1).unwrap(); + + let expected_level = LevelInfo { + def_levels: Some(vec![1, 0, 0, 1, 1]), + rep_levels: None, + non_null_indices: vec![0, 3, 4], + max_def_level: 1, + max_rep_level: 0, + }; + assert_eq!(list_level, &expected_level); + + // test "d" levels + let list_level = levels.get(2).unwrap(); + + let expected_level = LevelInfo { + def_levels: Some(vec![1, 1, 1, 2, 1]), + rep_levels: None, + non_null_indices: vec![3], + max_def_level: 2, + max_rep_level: 0, + }; + assert_eq!(list_level, &expected_level); + + // test "f" levels + let list_level = levels.get(3).unwrap(); + + let expected_level = LevelInfo { + def_levels: Some(vec![3, 2, 3, 2, 3]), + rep_levels: None, + non_null_indices: vec![0, 2, 4], + max_def_level: 3, + max_rep_level: 0, + }; + assert_eq!(list_level, &expected_level); + } + + #[test] + fn test_null_vs_nonnull_struct() { + // define schema + let offset_field = Field::new("offset", DataType::Int32, true); + let schema = Schema::new(vec![Field::new( + "some_nested_object", + DataType::Struct(vec![offset_field.clone()]), + false, + )]); + + // create some data + let offset = Int32Array::from(vec![1, 2, 3, 4, 5]); + + let some_nested_object = + StructArray::from(vec![(offset_field, Arc::new(offset) as ArrayRef)]); + + // build a record batch + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(some_nested_object)]) + .unwrap(); + + let struct_null_level = + calculate_array_levels(batch.column(0), batch.schema().field(0)); + + // create second batch + // define schema + let offset_field = Field::new("offset", DataType::Int32, true); + let schema = Schema::new(vec![Field::new( + "some_nested_object", + DataType::Struct(vec![offset_field.clone()]), + true, + )]); + + // create some data + let offset = Int32Array::from(vec![1, 2, 3, 4, 5]); + + let some_nested_object = + StructArray::from(vec![(offset_field, Arc::new(offset) as ArrayRef)]); + + // build a record batch + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(some_nested_object)]) + .unwrap(); + + let struct_non_null_level = + 
calculate_array_levels(batch.column(0), batch.schema().field(0)); + + // The 2 levels should not be the same + if struct_non_null_level == struct_null_level { + panic!("Levels should not be equal, to reflect the difference in struct nullness"); + } + } + + #[test] + fn test_map_array() { + // Note: we are using the JSON Arrow reader for brevity + let json_content = r#" + {"stocks":{"long": "$AAA", "short": "$BBB"}} + {"stocks":{"long": null, "long": "$CCC", "short": null}} + {"stocks":{"hedged": "$YYY", "long": null, "short": "$D"}} + "#; + let entries_struct_type = DataType::Struct(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Utf8, true), + ]); + let stocks_field = Field::new( + "stocks", + DataType::Map( + Box::new(Field::new("entries", entries_struct_type, false)), + false, + ), + // not nullable, so the keys have max level = 1 + false, + ); + let schema = Arc::new(Schema::new(vec![stocks_field])); + let builder = arrow::json::ReaderBuilder::new() + .with_schema(schema) + .with_batch_size(64); + let mut reader = builder.build(std::io::Cursor::new(json_content)).unwrap(); + + let batch = reader.next().unwrap().unwrap(); + + // calculate the map's level + let mut levels = vec![]; + batch + .columns() + .iter() + .zip(batch.schema().fields()) + .for_each(|(array, field)| { + let mut array_levels = calculate_array_levels(array, field).unwrap(); + levels.append(&mut array_levels); + }); + assert_eq!(levels.len(), 2); + + // test key levels + let list_level = levels.get(0).unwrap(); + + let expected_level = LevelInfo { + def_levels: Some(vec![1; 7]), + rep_levels: Some(vec![0, 1, 0, 1, 0, 1, 1]), + non_null_indices: vec![0, 1, 2, 3, 4, 5, 6], + max_def_level: 1, + max_rep_level: 1, + }; + assert_eq!(list_level, &expected_level); + + // test values levels + let list_level = levels.get(1).unwrap(); + + let expected_level = LevelInfo { + def_levels: Some(vec![2, 2, 2, 1, 2, 1, 2]), + rep_levels: Some(vec![0, 1, 0, 1, 0, 1, 1]), + non_null_indices: vec![0, 1, 2, 4, 6], + max_def_level: 2, + max_rep_level: 1, + }; + assert_eq!(list_level, &expected_level); + } + + #[test] + fn test_list_of_struct() { + // define schema + let int_field = Field::new("a", DataType::Int32, true); + let item_field = + Field::new("item", DataType::Struct(vec![int_field.clone()]), true); + let list_field = Field::new("list", DataType::List(Box::new(item_field)), true); + + let int_builder = Int32Builder::new(10); + let struct_builder = + StructBuilder::new(vec![int_field], vec![Box::new(int_builder)]); + let mut list_builder = ListBuilder::new(struct_builder); + + // [{a: 1}], [], null, [null, null], [{a: null}], [{a: 2}] + // + // [{a: 1}] + let values = list_builder.values(); + values + .field_builder::(0) + .unwrap() + .append_value(1) + .unwrap(); + values.append(true).unwrap(); + list_builder.append(true).unwrap(); + + // [] + list_builder.append(true).unwrap(); + + // null + list_builder.append(false).unwrap(); + + // [null, null] + let values = list_builder.values(); + values + .field_builder::(0) + .unwrap() + .append_null() + .unwrap(); + values.append(false).unwrap(); + values + .field_builder::(0) + .unwrap() + .append_null() + .unwrap(); + values.append(false).unwrap(); + list_builder.append(true).unwrap(); + + // [{a: null}] + let values = list_builder.values(); + values + .field_builder::(0) + .unwrap() + .append_null() + .unwrap(); + values.append(true).unwrap(); + list_builder.append(true).unwrap(); + + // [{a: 2}] + let values = list_builder.values(); + values + 
.field_builder::(0) + .unwrap() + .append_value(2) + .unwrap(); + values.append(true).unwrap(); + list_builder.append(true).unwrap(); + + let array = Arc::new(list_builder.finish()); + + let values_len = array.data().child_data()[0].len(); + assert_eq!(values_len, 5); + + let schema = Arc::new(Schema::new(vec![list_field])); + + let rb = RecordBatch::try_new(schema, vec![array]).unwrap(); + + let levels = calculate_array_levels(rb.column(0), rb.schema().field(0)).unwrap(); + let list_level = &levels[0]; + + let expected_level = LevelInfo { + def_levels: Some(vec![4, 1, 0, 2, 2, 3, 4]), + rep_levels: Some(vec![0, 0, 0, 0, 1, 0, 0]), + non_null_indices: vec![0, 4], + max_def_level: 4, + max_rep_level: 1, + }; + + assert_eq!(list_level, &expected_level); + } + + #[test] + fn test_struct_mask_list() { + // Test the null mask of a struct array masking out non-empty slices of a child ListArray + let inner = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2)]), + Some(vec![None]), + Some(vec![]), + Some(vec![Some(3), None]), // Masked by struct array + Some(vec![Some(4), Some(5)]), + None, // Masked by struct array + None, + ]); + + // This test assumes that nulls don't take up space + assert_eq!(inner.data().child_data()[0].len(), 7); + + let field = Field::new("list", inner.data_type().clone(), true); + let array = Arc::new(inner) as ArrayRef; + let nulls = Buffer::from([0b01010111]); + let struct_a = StructArray::from((vec![(field, array)], nulls)); + + let field = Field::new("struct", struct_a.data_type().clone(), true); + let array = Arc::new(struct_a) as ArrayRef; + let levels = calculate_array_levels(&array, &field).unwrap(); + + assert_eq!(levels.len(), 1); + + let expected_level = LevelInfo { + def_levels: Some(vec![4, 4, 3, 2, 0, 4, 4, 0, 1]), + rep_levels: Some(vec![0, 1, 0, 0, 0, 0, 1, 0, 0]), + non_null_indices: vec![0, 1, 5, 6], + max_def_level: 4, + max_rep_level: 1, + }; + + assert_eq!(&levels[0], &expected_level); + } + + #[test] + fn test_list_mask_struct() { + // Test the null mask of a struct array and the null mask of a list array + // masking out non-null elements of their children + + let a1 = Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![None]), // Masked by list array + Some(vec![]), // Masked by list array + Some(vec![Some(3), None]), + Some(vec![Some(4), Some(5), None, Some(6)]), // Masked by struct array + None, + None, + ])) as ArrayRef; + + let a2 = Arc::new(Int32Array::from_iter(vec![ + Some(1), // Masked by list array + Some(2), // Masked by list array + None, + Some(4), // Masked by struct array + Some(5), + None, + ])) as ArrayRef; + + let field_a1 = Field::new("list", a1.data_type().clone(), true); + let field_a2 = Field::new("integers", a2.data_type().clone(), true); + + let nulls = Buffer::from([0b00110111]); + let struct_a = Arc::new( + StructArray::try_from((vec![(field_a1, a1), (field_a2, a2)], nulls)).unwrap(), + ) as ArrayRef; + + let offsets = Buffer::from_iter([0_i32, 0, 2, 2, 3, 5, 5]); + let nulls = Buffer::from([0b00111100]); + + let list_type = DataType::List(Box::new(Field::new( + "struct", + struct_a.data_type().clone(), + true, + ))); + + let data = ArrayDataBuilder::new(list_type.clone()) + .len(6) + .null_bit_buffer(Some(nulls)) + .add_buffer(offsets) + .add_child_data(struct_a.data().clone()) + .build() + .unwrap(); + + let list = make_array(data); + let list_field = Field::new("col", list_type, true); + + let expected = vec![ + r#"+-------------------------------------+"#, + r#"| col |"#, + 
r#"+-------------------------------------+"#, + r#"| |"#, + r#"| |"#, + r#"| [] |"#, + r#"| [{"list": [3, ], "integers": null}] |"#, + r#"| [, {"list": null, "integers": 5}] |"#, + r#"| [] |"#, + r#"+-------------------------------------+"#, + ] + .join("\n"); + + let pretty = pretty_format_columns(list_field.name(), &[list.clone()]).unwrap(); + assert_eq!(pretty.to_string(), expected); + + let levels = calculate_array_levels(&list, &list_field).unwrap(); + + assert_eq!(levels.len(), 2); + + let expected_level = LevelInfo { + def_levels: Some(vec![0, 0, 1, 6, 5, 2, 3, 1]), + rep_levels: Some(vec![0, 0, 0, 0, 2, 0, 1, 0]), + non_null_indices: vec![1], + max_def_level: 6, + max_rep_level: 2, + }; + + assert_eq!(&levels[0], &expected_level); + + let expected_level = LevelInfo { + def_levels: Some(vec![0, 0, 1, 3, 2, 4, 1]), + rep_levels: Some(vec![0, 0, 0, 0, 0, 1, 0]), + non_null_indices: vec![4], + max_def_level: 4, + max_rep_level: 1, + }; + + assert_eq!(&levels[1], &expected_level); + } +} diff --git a/parquet/src/arrow/arrow_writer.rs b/parquet/src/arrow/arrow_writer/mod.rs similarity index 62% rename from parquet/src/arrow/arrow_writer.rs rename to parquet/src/arrow/arrow_writer/mod.rs index 8600eb0b5101..83f1bc70b525 100644 --- a/parquet/src/arrow/arrow_writer.rs +++ b/parquet/src/arrow/arrow_writer/mod.rs @@ -17,41 +17,57 @@ //! Contains writer which writes arrow data into parquet data. +use std::collections::VecDeque; +use std::io::Write; use std::sync::Arc; use arrow::array as arrow_array; +use arrow::array::ArrayRef; use arrow::datatypes::{DataType as ArrowDataType, IntervalUnit, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow_array::Array; -use super::levels::LevelInfo; use super::schema::{ - add_encoded_arrow_schema_to_metadata, decimal_length_from_precision, + add_encoded_arrow_schema_to_metadata, arrow_to_parquet_schema, + decimal_length_from_precision, }; use crate::column::writer::ColumnWriter; use crate::errors::{ParquetError, Result}; +use crate::file::metadata::RowGroupMetaDataPtr; use crate::file::properties::WriterProperties; -use crate::{ - data_type::*, - file::writer::{FileWriter, ParquetWriter, RowGroupWriter, SerializedFileWriter}, -}; +use crate::file::writer::{SerializedColumnWriter, SerializedRowGroupWriter}; +use crate::{data_type::*, file::writer::SerializedFileWriter}; +use levels::{calculate_array_levels, LevelInfo}; + +mod levels; /// Arrow writer /// -/// Writes Arrow `RecordBatch`es to a Parquet writer -pub struct ArrowWriter { +/// Writes Arrow `RecordBatch`es to a Parquet writer, buffering up `RecordBatch` in order +/// to produce row groups with `max_row_group_size` rows. Any remaining rows will be +/// flushed on close, leading the final row group in the output file to potentially +/// contain fewer than `max_row_group_size` rows +pub struct ArrowWriter { /// Underlying Parquet writer writer: SerializedFileWriter, + + /// For each column, maintain an ordered queue of arrays to write + buffer: Vec>, + + /// The total number of rows currently buffered + buffered_rows: usize, + /// A copy of the Arrow schema. 
/// /// The schema is used to verify that each record batch written has the correct schema arrow_schema: SchemaRef, + /// The length of arrays to write to each row group max_row_group_size: usize, } -impl ArrowWriter { +impl ArrowWriter { /// Try to create a new Arrow writer /// /// The writer will fail if: @@ -62,33 +78,35 @@ impl ArrowWriter { arrow_schema: SchemaRef, props: Option, ) -> Result { - let schema = crate::arrow::arrow_to_parquet_schema(&arrow_schema)?; + let schema = arrow_to_parquet_schema(&arrow_schema)?; // add serialized arrow schema let mut props = props.unwrap_or_else(|| WriterProperties::builder().build()); add_encoded_arrow_schema_to_metadata(&arrow_schema, &mut props); let max_row_group_size = props.max_row_group_size(); - let file_writer = SerializedFileWriter::new( - writer.try_clone()?, - schema.root_schema_ptr(), - Arc::new(props), - )?; + let file_writer = + SerializedFileWriter::new(writer, schema.root_schema_ptr(), Arc::new(props))?; Ok(Self { writer: file_writer, + buffer: vec![Default::default(); arrow_schema.fields().len()], + buffered_rows: 0, arrow_schema, max_row_group_size, }) } - /// Write a RecordBatch to writer + /// Returns metadata for any flushed row groups + pub fn flushed_row_groups(&self) -> &[RowGroupMetaDataPtr] { + self.writer.flushed_row_groups() + } + + /// Enqueues the provided `RecordBatch` to be written /// - /// The writer will slice the `batch` into `max_row_group_size`, - /// but if a batch has left-over rows less than the row group size, - /// the last row group will have fewer records. - /// This is currently a limitation because we close the row group - /// instead of keeping it open for the next batch. + /// If following this there are more than `max_row_group_size` rows buffered, + /// this will flush out one or more row groups with `max_row_group_size` rows, + /// and drop any fully written `RecordBatch` pub fn write(&mut self, batch: &RecordBatch) -> Result<()> { // validate batch schema against writer's supplied schema if self.arrow_schema != batch.schema() { @@ -96,58 +114,122 @@ impl ArrowWriter { "Record batch schema does not match writer schema".to_string(), )); } - // Track the number of rows being written in the batch. - // We currently do not have a way of slicing nested arrays, thus we - // track this manually. 
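As a rough usage sketch of the buffering behaviour described in the new doc comments above (the column name, batch sizes and in-memory destination are illustrative, mirroring the `test_aggregates_records` test further down): batches of 100, 50 and 300 rows written with `max_row_group_size = 200` are re-sliced into row groups of 200, 200 and 50 rows, with the trailing 50 rows flushed by `close`.

use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use parquet::arrow::ArrowWriter;
use parquet::file::properties::WriterProperties;

fn main() {
    let schema = Arc::new(Schema::new(vec![Field::new("int", DataType::Int32, false)]));

    // Close a row group every 200 buffered rows
    let props = WriterProperties::builder()
        .set_max_row_group_size(200)
        .build();

    // Any `W: Write` can be the destination, e.g. an in-memory buffer
    let mut buffer: Vec<u8> = Vec::new();
    let mut writer =
        ArrowWriter::try_new(&mut buffer, schema.clone(), Some(props)).unwrap();

    for len in [100, 50, 300] {
        let array = Arc::new(Int32Array::from_iter_values(0..len)) as ArrayRef;
        let batch = RecordBatch::try_new(schema.clone(), vec![array]).unwrap();
        // Enqueues the batch and flushes complete 200-row groups as they accumulate
        writer.write(&batch).unwrap();
    }

    // Two 200-row groups have been flushed at this point (see `flushed_row_groups`);
    // closing the writer flushes the remaining 50 rows as a final, smaller row group
    writer.close().unwrap();
}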
- let num_rows = batch.num_rows(); - let batches = (num_rows + self.max_row_group_size - 1) / self.max_row_group_size; - let min_batch = num_rows.min(self.max_row_group_size); - for batch_index in 0..batches { - // Determine the offset and length of arrays - let offset = batch_index * min_batch; - let length = (num_rows - offset).min(self.max_row_group_size); - - // Compute the definition and repetition levels of the batch - let batch_level = LevelInfo::new(offset, length); - let mut row_group_writer = self.writer.next_row_group()?; - for (array, field) in batch.columns().iter().zip(batch.schema().fields()) { - let mut levels = batch_level.calculate_array_levels(array, field); - // Reverse levels as we pop() them when writing arrays - levels.reverse(); - write_leaves(&mut row_group_writer, array, &mut levels)?; + + for (buffer, column) in self.buffer.iter_mut().zip(batch.columns()) { + buffer.push_back(column.clone()) + } + + self.buffered_rows += batch.num_rows(); + self.flush_completed()?; + + Ok(()) + } + + /// Flushes buffered data until there are less than `max_row_group_size` rows buffered + fn flush_completed(&mut self) -> Result<()> { + while self.buffered_rows >= self.max_row_group_size { + self.flush_rows(self.max_row_group_size)?; + } + Ok(()) + } + + /// Flushes all buffered rows into a new row group + pub fn flush(&mut self) -> Result<()> { + self.flush_rows(self.buffered_rows) + } + + /// Flushes `num_rows` from the buffer into a new row group + fn flush_rows(&mut self, num_rows: usize) -> Result<()> { + if num_rows == 0 { + return Ok(()); + } + + assert!( + num_rows <= self.buffered_rows, + "cannot flush {} rows only have {}", + num_rows, + self.buffered_rows + ); + + assert!( + num_rows <= self.max_row_group_size, + "cannot flush {} rows would exceed max row group size of {}", + num_rows, + self.max_row_group_size + ); + + let mut row_group_writer = self.writer.next_row_group()?; + + for (col_buffer, field) in self.buffer.iter_mut().zip(self.arrow_schema.fields()) + { + // Collect the number of arrays to append + let mut remaining = num_rows; + let mut arrays = Vec::with_capacity(col_buffer.len()); + while remaining != 0 { + match col_buffer.pop_front() { + Some(next) if next.len() > remaining => { + col_buffer + .push_front(next.slice(remaining, next.len() - remaining)); + arrays.push(next.slice(0, remaining)); + remaining = 0; + } + Some(next) => { + remaining -= next.len(); + arrays.push(next); + } + _ => break, + } } - self.writer.close_row_group(row_group_writer)?; + let mut levels = arrays + .iter() + .map(|array| { + let mut levels = calculate_array_levels(array, field)?; + // Reverse levels as we pop() them when writing arrays + levels.reverse(); + Ok(levels) + }) + .collect::>>()?; + + write_leaves(&mut row_group_writer, &arrays, &mut levels)?; } + row_group_writer.close()?; + self.buffered_rows -= num_rows; + Ok(()) } /// Close and finalize the underlying Parquet writer - pub fn close(&mut self) -> Result { + pub fn close(mut self) -> Result { + self.flush()?; self.writer.close() } } /// Convenience method to get the next ColumnWriter from the RowGroupWriter #[inline] -#[allow(clippy::borrowed_box)] -fn get_col_writer( - row_group_writer: &mut Box, -) -> Result { +fn get_col_writer<'a, W: Write>( + row_group_writer: &'a mut SerializedRowGroupWriter<'_, W>, +) -> Result> { let col_writer = row_group_writer .next_column()? 
.expect("Unable to get column writer"); Ok(col_writer) } -#[allow(clippy::borrowed_box)] -fn write_leaves( - mut row_group_writer: &mut Box, - array: &arrow_array::ArrayRef, - mut levels: &mut Vec, +fn write_leaves( + row_group_writer: &mut SerializedRowGroupWriter<'_, W>, + arrays: &[ArrayRef], + levels: &mut [Vec], ) -> Result<()> { - match array.data_type() { + assert_eq!(arrays.len(), levels.len()); + assert!(!arrays.is_empty()); + + let data_type = arrays.first().unwrap().data_type().clone(); + assert!(arrays.iter().all(|a| a.data_type() == &data_type)); + + match &data_type { ArrowDataType::Null | ArrowDataType::Boolean | ArrowDataType::Int8 @@ -173,62 +255,88 @@ fn write_leaves( | ArrowDataType::LargeUtf8 | ArrowDataType::Decimal(_, _) | ArrowDataType::FixedSizeBinary(_) => { - let mut col_writer = get_col_writer(&mut row_group_writer)?; - write_leaf( - &mut col_writer, - array, - levels.pop().expect("Levels exhausted"), - )?; - row_group_writer.close_column(col_writer)?; + let mut col_writer = get_col_writer(row_group_writer)?; + for (array, levels) in arrays.iter().zip(levels.iter_mut()) { + write_leaf( + col_writer.untyped(), + array, + levels.pop().expect("Levels exhausted"), + )?; + } + col_writer.close()?; Ok(()) } ArrowDataType::List(_) | ArrowDataType::LargeList(_) => { - // write the child list - let data = array.data(); - let child_array = arrow_array::make_array(data.child_data()[0].clone()); - write_leaves(&mut row_group_writer, &child_array, &mut levels)?; + let arrays: Vec<_> = arrays.iter().map(|array|{ + // write the child list + let data = array.data(); + arrow_array::make_array(data.child_data()[0].clone()) + }).collect(); + + write_leaves(row_group_writer, &arrays, levels)?; Ok(()) } - ArrowDataType::Struct(_) => { - let struct_array: &arrow_array::StructArray = array - .as_any() - .downcast_ref::() - .expect("Unable to get struct array"); - for field in struct_array.columns() { - write_leaves(&mut row_group_writer, field, &mut levels)?; + ArrowDataType::Struct(fields) => { + // Groups child arrays by field + let mut field_arrays = vec![Vec::with_capacity(arrays.len()); fields.len()]; + + for array in arrays { + let struct_array: &arrow_array::StructArray = array + .as_any() + .downcast_ref::() + .expect("Unable to get struct array"); + + assert_eq!(struct_array.columns().len(), fields.len()); + + for (child_array, field) in field_arrays.iter_mut().zip(struct_array.columns()) { + child_array.push(field.clone()) + } } + + for field in field_arrays { + write_leaves(row_group_writer, &field, levels)?; + } + Ok(()) } ArrowDataType::Map(_, _) => { - let map_array: &arrow_array::MapArray = array - .as_any() - .downcast_ref::() - .expect("Unable to get map array"); - write_leaves(&mut row_group_writer, &map_array.keys(), &mut levels)?; - write_leaves(&mut row_group_writer, &map_array.values(), &mut levels)?; + let mut keys = Vec::with_capacity(arrays.len()); + let mut values = Vec::with_capacity(arrays.len()); + for array in arrays { + let map_array: &arrow_array::MapArray = array + .as_any() + .downcast_ref::() + .expect("Unable to get map array"); + keys.push(map_array.keys()); + values.push(map_array.values()); + } + + write_leaves(row_group_writer, &keys, levels)?; + write_leaves(row_group_writer, &values, levels)?; Ok(()) } ArrowDataType::Dictionary(_, value_type) => { - // cast dictionary to a primitive - let array = arrow::compute::cast(array, value_type)?; - - let mut col_writer = get_col_writer(&mut row_group_writer)?; - write_leaf( - &mut col_writer, - &array, 
- levels.pop().expect("Levels exhausted"), - )?; - row_group_writer.close_column(col_writer)?; + let mut col_writer = get_col_writer(row_group_writer)?; + for (array, levels) in arrays.iter().zip(levels.iter_mut()) { + // cast dictionary to a primitive + let array = arrow::compute::cast(array, value_type)?; + write_leaf( + col_writer.untyped(), + &array, + levels.pop().expect("Levels exhausted"), + )?; + } + col_writer.close()?; Ok(()) } ArrowDataType::Float16 => Err(ParquetError::ArrowError( "Float16 arrays not supported".to_string(), )), - ArrowDataType::FixedSizeList(_, _) | ArrowDataType::Union(_) => { + ArrowDataType::FixedSizeList(_, _) | ArrowDataType::Union(_, _, _) => { Err(ParquetError::NYI( format!( "Attempting to write an Arrow type {:?} to parquet that is not yet implemented", - array.data_type() + data_type ) )) } @@ -236,30 +344,27 @@ fn write_leaves( } fn write_leaf( - writer: &mut ColumnWriter, - column: &arrow_array::ArrayRef, + writer: &mut ColumnWriter<'_>, + column: &ArrayRef, levels: LevelInfo, ) -> Result { - let indices = levels.filter_array_indices(); - // Slice array according to computed offset and length - let column = column.slice(levels.offset, levels.length); + let indices = levels.non_null_indices(); let written = match writer { ColumnWriter::Int32ColumnWriter(ref mut typed) => { let values = match column.data_type() { ArrowDataType::Date64 => { // If the column is a Date64, we cast it to a Date32, and then interpret that as Int32 let array = if let ArrowDataType::Date64 = column.data_type() { - let array = - arrow::compute::cast(&column, &ArrowDataType::Date32)?; + let array = arrow::compute::cast(column, &ArrowDataType::Date32)?; arrow::compute::cast(&array, &ArrowDataType::Int32)? } else { - arrow::compute::cast(&column, &ArrowDataType::Int32)? + arrow::compute::cast(column, &ArrowDataType::Int32)? }; let array = array .as_any() .downcast_ref::() .expect("Unable to get int32 array"); - get_numeric_array_slice::(array, &indices) + get_numeric_array_slice::(array, indices) } ArrowDataType::UInt32 => { // follow C++ implementation and use overflow/reinterpret cast from u32 to i32 which will map @@ -272,21 +377,21 @@ fn write_leaf( array, |x| x as i32, ); - get_numeric_array_slice::(&array, &indices) + get_numeric_array_slice::(&array, indices) } _ => { - let array = arrow::compute::cast(&column, &ArrowDataType::Int32)?; + let array = arrow::compute::cast(column, &ArrowDataType::Int32)?; let array = array .as_any() .downcast_ref::() .expect("Unable to get i32 array"); - get_numeric_array_slice::(array, &indices) + get_numeric_array_slice::(array, indices) } }; typed.write_batch( values.as_slice(), - Some(levels.definition.as_slice()), - levels.repetition.as_deref(), + levels.def_levels(), + levels.rep_levels(), )? } ColumnWriter::BoolColumnWriter(ref mut typed) => { @@ -295,9 +400,9 @@ fn write_leaf( .downcast_ref::() .expect("Unable to get boolean array"); typed.write_batch( - get_bool_array_slice(array, &indices).as_slice(), - Some(levels.definition.as_slice()), - levels.repetition.as_deref(), + get_bool_array_slice(array, indices).as_slice(), + levels.def_levels(), + levels.rep_levels(), )? 
} ColumnWriter::Int64ColumnWriter(ref mut typed) => { @@ -307,7 +412,7 @@ fn write_leaf( .as_any() .downcast_ref::() .expect("Unable to get i64 array"); - get_numeric_array_slice::(array, &indices) + get_numeric_array_slice::(array, indices) } ArrowDataType::UInt64 => { // follow C++ implementation and use overflow/reinterpret cast from u64 to i64 which will map @@ -320,21 +425,21 @@ fn write_leaf( array, |x| x as i64, ); - get_numeric_array_slice::(&array, &indices) + get_numeric_array_slice::(&array, indices) } _ => { - let array = arrow::compute::cast(&column, &ArrowDataType::Int64)?; + let array = arrow::compute::cast(column, &ArrowDataType::Int64)?; let array = array .as_any() .downcast_ref::() .expect("Unable to get i64 array"); - get_numeric_array_slice::(array, &indices) + get_numeric_array_slice::(array, indices) } }; typed.write_batch( values.as_slice(), - Some(levels.definition.as_slice()), - levels.repetition.as_deref(), + levels.def_levels(), + levels.rep_levels(), )? } ColumnWriter::Int96ColumnWriter(ref mut _typed) => { @@ -346,9 +451,9 @@ fn write_leaf( .downcast_ref::() .expect("Unable to get Float32 array"); typed.write_batch( - get_numeric_array_slice::(array, &indices).as_slice(), - Some(levels.definition.as_slice()), - levels.repetition.as_deref(), + get_numeric_array_slice::(array, indices).as_slice(), + levels.def_levels(), + levels.rep_levels(), )? } ColumnWriter::DoubleColumnWriter(ref mut typed) => { @@ -357,9 +462,9 @@ fn write_leaf( .downcast_ref::() .expect("Unable to get Float64 array"); typed.write_batch( - get_numeric_array_slice::(array, &indices).as_slice(), - Some(levels.definition.as_slice()), - levels.repetition.as_deref(), + get_numeric_array_slice::(array, indices).as_slice(), + levels.def_levels(), + levels.rep_levels(), )? } ColumnWriter::ByteArrayColumnWriter(ref mut typed) => match column.data_type() { @@ -370,8 +475,8 @@ fn write_leaf( .expect("Unable to get BinaryArray array"); typed.write_batch( get_binary_array(array).as_slice(), - Some(levels.definition.as_slice()), - levels.repetition.as_deref(), + levels.def_levels(), + levels.rep_levels(), )? } ArrowDataType::Utf8 => { @@ -381,8 +486,8 @@ fn write_leaf( .expect("Unable to get LargeBinaryArray array"); typed.write_batch( get_string_array(array).as_slice(), - Some(levels.definition.as_slice()), - levels.repetition.as_deref(), + levels.def_levels(), + levels.rep_levels(), )? } ArrowDataType::LargeBinary => { @@ -392,8 +497,8 @@ fn write_leaf( .expect("Unable to get LargeBinaryArray array"); typed.write_batch( get_large_binary_array(array).as_slice(), - Some(levels.definition.as_slice()), - levels.repetition.as_deref(), + levels.def_levels(), + levels.rep_levels(), )? } ArrowDataType::LargeUtf8 => { @@ -403,8 +508,8 @@ fn write_leaf( .expect("Unable to get LargeUtf8 array"); typed.write_batch( get_large_string_array(array).as_slice(), - Some(levels.definition.as_slice()), - levels.repetition.as_deref(), + levels.def_levels(), + levels.rep_levels(), )? 
} _ => unreachable!("Currently unreachable because data type not supported"), @@ -417,14 +522,22 @@ fn write_leaf( .as_any() .downcast_ref::() .unwrap(); - get_interval_ym_array_slice(array, &indices) + get_interval_ym_array_slice(array, indices) } IntervalUnit::DayTime => { let array = column .as_any() .downcast_ref::() .unwrap(); - get_interval_dt_array_slice(array, &indices) + get_interval_dt_array_slice(array, indices) + } + _ => { + return Err(ParquetError::NYI( + format!( + "Attempting to write an Arrow interval type {:?} to parquet that is not yet implemented", + interval_unit + ) + )); } }, ArrowDataType::FixedSizeBinary(_) => { @@ -432,14 +545,14 @@ fn write_leaf( .as_any() .downcast_ref::() .unwrap(); - get_fsb_array_slice(array, &indices) + get_fsb_array_slice(array, indices) } ArrowDataType::Decimal(_, _) => { let array = column .as_any() .downcast_ref::() .unwrap(); - get_decimal_array_slice(array, &indices) + get_decimal_array_slice(array, indices) } _ => { return Err(ParquetError::NYI( @@ -450,8 +563,8 @@ fn write_leaf( }; typed.write_batch( bytes.as_slice(), - Some(levels.definition.as_slice()), - levels.repetition.as_deref(), + levels.def_levels(), + levels.rep_levels(), )? } }; @@ -462,8 +575,8 @@ macro_rules! def_get_binary_array_fn { ($name:ident, $ty:ty) => { fn $name(array: &$ty) -> Vec { let mut byte_array = ByteArray::new(); - let ptr = crate::memory::ByteBufferPtr::new( - unsafe { array.value_data().typed_data::() }.to_vec(), + let ptr = crate::util::memory::ByteBufferPtr::new( + array.value_data().as_slice().to_vec(), ); byte_array.set_data(ptr); array @@ -484,6 +597,7 @@ macro_rules! def_get_binary_array_fn { }; } +// TODO: These methods don't handle non null indices correctly (#1753) def_get_binary_array_fn!(get_binary_array, arrow_array::BinaryArray); def_get_binary_array_fn!(get_string_array, arrow_array::StringArray); def_get_binary_array_fn!(get_large_binary_array, arrow_array::LargeBinaryArray); @@ -559,7 +673,7 @@ fn get_decimal_array_slice( let mut values = Vec::with_capacity(indices.len()); let size = decimal_length_from_precision(array.precision()); for i in indices { - let as_be_bytes = array.value(*i).to_be_bytes(); + let as_be_bytes = array.value(*i).as_i128().to_be_bytes(); let resized_value = as_be_bytes[(16 - size)..].to_vec(); values.push(FixedLenByteArray::from(ByteArray::from(resized_value))); } @@ -582,21 +696,23 @@ fn get_fsb_array_slice( mod tests { use super::*; + use bytes::Bytes; use std::fs::File; use std::sync::Arc; use arrow::datatypes::ToByteSlice; use arrow::datatypes::{DataType, Field, Schema, UInt32Type, UInt8Type}; + use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; + use arrow::util::pretty::pretty_format_batches; use arrow::{array::*, buffer::Buffer}; use crate::arrow::{ArrowReader, ParquetFileArrowReader}; + use crate::file::metadata::ParquetMetaData; use crate::file::{ reader::{FileReader, SerializedFileReader}, statistics::Statistics, - writer::InMemoryWriteableCursor, }; - use crate::util::test_common::get_temp_file; #[test] fn arrow_writer() { @@ -615,7 +731,7 @@ mod tests { RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]) .unwrap(); - roundtrip("test_arrow_write.parquet", batch, Some(SMALL_SIZE / 2)); + roundtrip(batch, Some(SMALL_SIZE / 2)); } #[test] @@ -634,19 +750,16 @@ mod tests { let expected_batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(a), Arc::new(b)]).unwrap(); - let cursor = InMemoryWriteableCursor::default(); + let mut buffer = vec![]; { - let mut 
writer = ArrowWriter::try_new(cursor.clone(), schema, None).unwrap(); + let mut writer = ArrowWriter::try_new(&mut buffer, schema, None).unwrap(); writer.write(&expected_batch).unwrap(); writer.close().unwrap(); } - let buffer = cursor.into_inner().unwrap(); - - let cursor = crate::file::serialized_reader::SliceableCursor::new(buffer); - let reader = SerializedFileReader::new(cursor).unwrap(); - let mut arrow_reader = ParquetFileArrowReader::new(Arc::new(reader)); + let cursor = Bytes::from(buffer); + let mut arrow_reader = ParquetFileArrowReader::try_new(cursor).unwrap(); let mut record_batch_reader = arrow_reader.get_record_reader(1024).unwrap(); let actual_batch = record_batch_reader @@ -676,11 +789,7 @@ mod tests { // build a record batch let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap(); - roundtrip( - "test_arrow_writer_non_null.parquet", - batch, - Some(SMALL_SIZE / 2), - ); + roundtrip(batch, Some(SMALL_SIZE / 2)); } #[test] @@ -709,7 +818,7 @@ mod tests { .len(5) .add_buffer(a_value_offsets) .add_child_data(a_values.data().clone()) - .null_bit_buffer(Buffer::from(vec![0b00011011])) + .null_bit_buffer(Some(Buffer::from(vec![0b00011011]))) .build() .unwrap(); let a = ListArray::from(a_list_data); @@ -721,7 +830,7 @@ mod tests { // This test fails if the max row group size is less than the batch's length // see https://github.com/apache/arrow-rs/issues/518 - roundtrip("test_arrow_writer_list.parquet", batch, None); + roundtrip(batch, None); } #[test] @@ -761,7 +870,7 @@ mod tests { // see https://github.com/apache/arrow-rs/issues/518 assert_eq!(batch.column(0).data().null_count(), 0); - roundtrip("test_arrow_writer_list_non_null.parquet", batch, None); + roundtrip(batch, None); } #[test] @@ -790,11 +899,7 @@ mod tests { ) .unwrap(); - roundtrip( - "test_arrow_writer_binary.parquet", - batch, - Some(SMALL_SIZE / 2), - ); + roundtrip(batch, Some(SMALL_SIZE / 2)); } #[test] @@ -802,22 +907,18 @@ mod tests { let decimal_field = Field::new("a", DataType::Decimal(5, 2), false); let schema = Schema::new(vec![decimal_field]); - let mut dec_builder = DecimalBuilder::new(4, 5, 2); - dec_builder.append_value(10_000).unwrap(); - dec_builder.append_value(50_000).unwrap(); - dec_builder.append_value(0).unwrap(); - dec_builder.append_value(-100).unwrap(); + let decimal_values = vec![10_000, 50_000, 0, -100] + .into_iter() + .map(Some) + .collect::() + .with_precision_and_scale(5, 2) + .unwrap(); - let decimal_values = dec_builder.finish(); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(decimal_values)]) .unwrap(); - roundtrip( - "test_arrow_writer_decimal.parquet", - batch, - Some(SMALL_SIZE / 2), - ); + roundtrip(batch, Some(SMALL_SIZE / 2)); } #[test] @@ -880,7 +981,7 @@ mod tests { .len(5) .add_buffer(g_value_offsets) .add_child_data(g_value.data().clone()) - .null_bit_buffer(Buffer::from(vec![0b00011011])) + .null_bit_buffer(Some(Buffer::from(vec![0b00011011]))) .build() .unwrap(); let h = ListArray::from(h_list_data); @@ -903,17 +1004,8 @@ mod tests { ) .unwrap(); - roundtrip( - "test_arrow_writer_complex.parquet", - batch.clone(), - Some(SMALL_SIZE / 2), - ); - - roundtrip( - "test_arrow_writer_complex_small_batch.parquet", - batch, - Some(SMALL_SIZE / 3), - ); + roundtrip(batch.clone(), Some(SMALL_SIZE / 2)); + roundtrip(batch, Some(SMALL_SIZE / 3)); } #[test] @@ -951,11 +1043,7 @@ mod tests { RecordBatch::try_new(Arc::new(schema), vec![Arc::new(some_nested_object)]) .unwrap(); - roundtrip( - "test_arrow_writer_complex_mixed.parquet", - batch, - 
Some(SMALL_SIZE / 2), - ); + roundtrip(batch, Some(SMALL_SIZE / 2)); } #[test] @@ -985,7 +1073,7 @@ mod tests { let mut reader = builder.build(std::io::Cursor::new(json_content)).unwrap(); let batch = reader.next().unwrap().unwrap(); - roundtrip("test_arrow_writer_map.parquet", batch, None); + roundtrip(batch, None); } #[test] @@ -1000,14 +1088,14 @@ mod tests { let c = Int32Array::from(vec![Some(1), None, Some(3), None, None, Some(6)]); let b_data = ArrayDataBuilder::new(field_b.data_type().clone()) .len(6) - .null_bit_buffer(Buffer::from(vec![0b00100111])) + .null_bit_buffer(Some(Buffer::from(vec![0b00100111]))) .add_child_data(c.data().clone()) .build() .unwrap(); let b = StructArray::from(b_data); let a_data = ArrayDataBuilder::new(field_a.data_type().clone()) .len(6) - .null_bit_buffer(Buffer::from(vec![0b00101111])) + .null_bit_buffer(Some(Buffer::from(vec![0b00101111]))) .add_child_data(b.data().clone()) .build() .unwrap(); @@ -1019,11 +1107,7 @@ mod tests { // build a racord batch let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap(); - roundtrip( - "test_arrow_writer_2_level_struct.parquet", - batch, - Some(SMALL_SIZE / 2), - ); + roundtrip(batch, Some(SMALL_SIZE / 2)); } #[test] @@ -1055,11 +1139,7 @@ mod tests { // build a racord batch let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap(); - roundtrip( - "test_arrow_writer_2_level_struct_non_null.parquet", - batch, - Some(SMALL_SIZE / 2), - ); + roundtrip(batch, Some(SMALL_SIZE / 2)); } #[test] @@ -1074,7 +1154,7 @@ mod tests { let c = Int32Array::from(vec![1, 2, 3, 4, 5, 6]); let b_data = ArrayDataBuilder::new(field_b.data_type().clone()) .len(6) - .null_bit_buffer(Buffer::from(vec![0b00100111])) + .null_bit_buffer(Some(Buffer::from(vec![0b00100111]))) .add_child_data(c.data().clone()) .build() .unwrap(); @@ -1093,21 +1173,13 @@ mod tests { // build a racord batch let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap(); - roundtrip( - "test_arrow_writer_2_level_struct_mixed_null.parquet", - batch, - Some(SMALL_SIZE / 2), - ); + roundtrip(batch, Some(SMALL_SIZE / 2)); } const SMALL_SIZE: usize = 7; - fn roundtrip( - filename: &str, - expected_batch: RecordBatch, - max_row_group_size: Option, - ) -> File { - let file = get_temp_file(filename, &[]); + fn roundtrip(expected_batch: RecordBatch, max_row_group_size: Option) -> File { + let file = tempfile::tempfile().unwrap(); let mut writer = ArrowWriter::try_new( file.try_clone().unwrap(), @@ -1122,8 +1194,8 @@ mod tests { writer.write(&expected_batch).unwrap(); writer.close().unwrap(); - let reader = SerializedFileReader::new(file.try_clone().unwrap()).unwrap(); - let mut arrow_reader = ParquetFileArrowReader::new(Arc::new(reader)); + let mut arrow_reader = + ParquetFileArrowReader::try_new(file.try_clone().unwrap()).unwrap(); let mut record_batch_reader = arrow_reader.get_record_reader(1024).unwrap(); let actual_batch = record_batch_reader @@ -1145,7 +1217,6 @@ mod tests { } fn one_column_roundtrip( - filename: &str, values: ArrayRef, nullable: bool, max_row_group_size: Option, @@ -1158,20 +1229,20 @@ mod tests { let expected_batch = RecordBatch::try_new(Arc::new(schema), vec![values]).unwrap(); - roundtrip(filename, expected_batch, max_row_group_size) + roundtrip(expected_batch, max_row_group_size) } - fn values_required(iter: I, filename: &str) + fn values_required(iter: I) where A: From> + Array + 'static, I: IntoIterator, { let raw_values: Vec<_> = iter.into_iter().collect(); let values = 
Arc::new(A::from(raw_values)); - one_column_roundtrip(filename, values, false, Some(SMALL_SIZE / 2)); + one_column_roundtrip(values, false, Some(SMALL_SIZE / 2)); } - fn values_optional(iter: I, filename: &str) + fn values_optional(iter: I) where A: From>> + Array + 'static, I: IntoIterator, @@ -1182,32 +1253,27 @@ mod tests { .map(|(i, v)| if i % 2 == 0 { None } else { Some(v) }) .collect(); let optional_values = Arc::new(A::from(optional_raw_values)); - one_column_roundtrip(filename, optional_values, true, Some(SMALL_SIZE / 2)); + one_column_roundtrip(optional_values, true, Some(SMALL_SIZE / 2)); } - fn required_and_optional(iter: I, filename: &str) + fn required_and_optional(iter: I) where A: From> + From>> + Array + 'static, I: IntoIterator + Clone, { - values_required::(iter.clone(), filename); - values_optional::(iter, filename); + values_required::(iter.clone()); + values_optional::(iter); } #[test] fn all_null_primitive_single_column() { let values = Arc::new(Int32Array::from(vec![None; SMALL_SIZE])); - one_column_roundtrip( - "all_null_primitive_single_column", - values, - true, - Some(SMALL_SIZE / 2), - ); + one_column_roundtrip(values, true, Some(SMALL_SIZE / 2)); } #[test] fn null_single_column() { let values = Arc::new(NullArray::new(SMALL_SIZE)); - one_column_roundtrip("null_single_column", values, true, Some(SMALL_SIZE / 2)); + one_column_roundtrip(values, true, Some(SMALL_SIZE / 2)); // null arrays are always nullable, a test with non-nullable nulls fails } @@ -1215,7 +1281,6 @@ mod tests { fn bool_single_column() { required_and_optional::( [true, false].iter().cycle().copied().take(SMALL_SIZE), - "bool_single_column", ); } @@ -1233,7 +1298,7 @@ mod tests { Schema::new(vec![Field::new("col", values.data_type().clone(), true)]); let expected_batch = RecordBatch::try_new(Arc::new(schema), vec![values]).unwrap(); - let file = get_temp_file("bool_large_single_column", &[]); + let file = tempfile::tempfile().unwrap(); let mut writer = ArrowWriter::try_new( file.try_clone().unwrap(), @@ -1247,67 +1312,52 @@ mod tests { #[test] fn i8_single_column() { - required_and_optional::(0..SMALL_SIZE as i8, "i8_single_column"); + required_and_optional::(0..SMALL_SIZE as i8); } #[test] fn i16_single_column() { - required_and_optional::(0..SMALL_SIZE as i16, "i16_single_column"); + required_and_optional::(0..SMALL_SIZE as i16); } #[test] fn i32_single_column() { - required_and_optional::(0..SMALL_SIZE as i32, "i32_single_column"); + required_and_optional::(0..SMALL_SIZE as i32); } #[test] fn i64_single_column() { - required_and_optional::(0..SMALL_SIZE as i64, "i64_single_column"); + required_and_optional::(0..SMALL_SIZE as i64); } #[test] fn u8_single_column() { - required_and_optional::(0..SMALL_SIZE as u8, "u8_single_column"); + required_and_optional::(0..SMALL_SIZE as u8); } #[test] fn u16_single_column() { - required_and_optional::( - 0..SMALL_SIZE as u16, - "u16_single_column", - ); + required_and_optional::(0..SMALL_SIZE as u16); } #[test] fn u32_single_column() { - required_and_optional::( - 0..SMALL_SIZE as u32, - "u32_single_column", - ); + required_and_optional::(0..SMALL_SIZE as u32); } #[test] fn u64_single_column() { - required_and_optional::( - 0..SMALL_SIZE as u64, - "u64_single_column", - ); + required_and_optional::(0..SMALL_SIZE as u64); } #[test] fn f32_single_column() { - required_and_optional::( - (0..SMALL_SIZE).map(|i| i as f32), - "f32_single_column", - ); + required_and_optional::((0..SMALL_SIZE).map(|i| i as f32)); } #[test] fn f64_single_column() { - 
required_and_optional::( - (0..SMALL_SIZE).map(|i| i as f64), - "f64_single_column", - ); + required_and_optional::((0..SMALL_SIZE).map(|i| i as f64)); } // The timestamp array types don't implement From> because they need the timezone @@ -1319,7 +1369,7 @@ mod tests { let raw_values: Vec<_> = (0..SMALL_SIZE as i64).collect(); let values = Arc::new(TimestampSecondArray::from_vec(raw_values, None)); - one_column_roundtrip("timestamp_second_single_column", values, false, Some(3)); + one_column_roundtrip(values, false, Some(3)); } #[test] @@ -1327,12 +1377,7 @@ mod tests { let raw_values: Vec<_> = (0..SMALL_SIZE as i64).collect(); let values = Arc::new(TimestampMillisecondArray::from_vec(raw_values, None)); - one_column_roundtrip( - "timestamp_millisecond_single_column", - values, - false, - Some(SMALL_SIZE / 2 + 1), - ); + one_column_roundtrip(values, false, Some(SMALL_SIZE / 2 + 1)); } #[test] @@ -1340,12 +1385,7 @@ mod tests { let raw_values: Vec<_> = (0..SMALL_SIZE as i64).collect(); let values = Arc::new(TimestampMicrosecondArray::from_vec(raw_values, None)); - one_column_roundtrip( - "timestamp_microsecond_single_column", - values, - false, - Some(SMALL_SIZE / 2 + 2), - ); + one_column_roundtrip(values, false, Some(SMALL_SIZE / 2 + 2)); } #[test] @@ -1353,20 +1393,12 @@ mod tests { let raw_values: Vec<_> = (0..SMALL_SIZE as i64).collect(); let values = Arc::new(TimestampNanosecondArray::from_vec(raw_values, None)); - one_column_roundtrip( - "timestamp_nanosecond_single_column", - values, - false, - Some(SMALL_SIZE / 2), - ); + one_column_roundtrip(values, false, Some(SMALL_SIZE / 2)); } #[test] fn date32_single_column() { - required_and_optional::( - 0..SMALL_SIZE as i32, - "date32_single_column", - ); + required_and_optional::(0..SMALL_SIZE as i32); } #[test] @@ -1374,92 +1406,69 @@ mod tests { // Date64 must be a multiple of 86400000, see ARROW-10925 required_and_optional::( (0..(SMALL_SIZE as i64 * 86400000)).step_by(86400000), - "date64_single_column", ); } #[test] fn time32_second_single_column() { - required_and_optional::( - 0..SMALL_SIZE as i32, - "time32_second_single_column", - ); + required_and_optional::(0..SMALL_SIZE as i32); } #[test] fn time32_millisecond_single_column() { - required_and_optional::( - 0..SMALL_SIZE as i32, - "time32_millisecond_single_column", - ); + required_and_optional::(0..SMALL_SIZE as i32); } #[test] fn time64_microsecond_single_column() { - required_and_optional::( - 0..SMALL_SIZE as i64, - "time64_microsecond_single_column", - ); + required_and_optional::(0..SMALL_SIZE as i64); } #[test] fn time64_nanosecond_single_column() { - required_and_optional::( - 0..SMALL_SIZE as i64, - "time64_nanosecond_single_column", - ); + required_and_optional::(0..SMALL_SIZE as i64); } #[test] #[should_panic(expected = "Converting Duration to parquet not supported")] fn duration_second_single_column() { - required_and_optional::( - 0..SMALL_SIZE as i64, - "duration_second_single_column", - ); + required_and_optional::(0..SMALL_SIZE as i64); } #[test] #[should_panic(expected = "Converting Duration to parquet not supported")] fn duration_millisecond_single_column() { - required_and_optional::( - 0..SMALL_SIZE as i64, - "duration_millisecond_single_column", - ); + required_and_optional::(0..SMALL_SIZE as i64); } #[test] #[should_panic(expected = "Converting Duration to parquet not supported")] fn duration_microsecond_single_column() { - required_and_optional::( - 0..SMALL_SIZE as i64, - "duration_microsecond_single_column", - ); + required_and_optional::(0..SMALL_SIZE 
as i64); } #[test] #[should_panic(expected = "Converting Duration to parquet not supported")] fn duration_nanosecond_single_column() { - required_and_optional::( - 0..SMALL_SIZE as i64, - "duration_nanosecond_single_column", - ); + required_and_optional::(0..SMALL_SIZE as i64); } #[test] fn interval_year_month_single_column() { - required_and_optional::( - 0..SMALL_SIZE as i32, - "interval_year_month_single_column", - ); + required_and_optional::(0..SMALL_SIZE as i32); } #[test] fn interval_day_time_single_column() { - required_and_optional::( - 0..SMALL_SIZE as i64, - "interval_day_time_single_column", - ); + required_and_optional::(0..SMALL_SIZE as i64); + } + + #[test] + #[should_panic( + expected = "Attempting to write an Arrow interval type MonthDayNano to parquet that is not yet implemented" + )] + fn interval_month_day_nano_single_column() { + required_and_optional::(0..SMALL_SIZE as i128); } #[test] @@ -1469,7 +1478,7 @@ mod tests { let many_vecs_iter = many_vecs.iter().map(|v| v.as_slice()); // BinaryArrays can't be built from Vec>, so only call `values_required` - values_required::(many_vecs_iter, "binary_single_column"); + values_required::(many_vecs_iter); } #[test] @@ -1479,10 +1488,7 @@ mod tests { let many_vecs_iter = many_vecs.iter().map(|v| v.as_slice()); // LargeBinaryArrays can't be built from Vec>, so only call `values_required` - values_required::( - many_vecs_iter, - "large_binary_single_column", - ); + values_required::(many_vecs_iter); } #[test] @@ -1494,12 +1500,7 @@ mod tests { builder.append_value(b"1112").unwrap(); let array = Arc::new(builder.finish()); - one_column_roundtrip( - "fixed_size_binary_single_column", - array, - true, - Some(SMALL_SIZE / 2), - ); + one_column_roundtrip(array, true, Some(SMALL_SIZE / 2)); } #[test] @@ -1507,7 +1508,7 @@ mod tests { let raw_values: Vec<_> = (0..SMALL_SIZE).map(|i| i.to_string()).collect(); let raw_strs = raw_values.iter().map(|s| s.as_str()); - required_and_optional::(raw_strs, "string_single_column"); + required_and_optional::(raw_strs); } #[test] @@ -1515,10 +1516,44 @@ mod tests { let raw_values: Vec<_> = (0..SMALL_SIZE).map(|i| i.to_string()).collect(); let raw_strs = raw_values.iter().map(|s| s.as_str()); - required_and_optional::( - raw_strs, - "large_string_single_column", - ); + required_and_optional::(raw_strs); + } + + #[test] + fn null_list_single_column() { + let null_field = Field::new("item", DataType::Null, true); + let list_field = + Field::new("emptylist", DataType::List(Box::new(null_field)), true); + + let schema = Schema::new(vec![list_field]); + + // Build [[], null, [null, null]] + let a_values = NullArray::new(2); + let a_value_offsets = arrow::buffer::Buffer::from(&[0, 0, 0, 2].to_byte_slice()); + let a_list_data = ArrayData::builder(DataType::List(Box::new(Field::new( + "item", + DataType::Null, + true, + )))) + .len(3) + .add_buffer(a_value_offsets) + .null_bit_buffer(Some(Buffer::from(vec![0b00000101]))) + .add_child_data(a_values.data().clone()) + .build() + .unwrap(); + + let a = ListArray::from(a_list_data); + + assert!(a.is_valid(0)); + assert!(!a.is_valid(1)); + assert!(a.is_valid(2)); + + assert_eq!(a.value(0).len(), 0); + assert_eq!(a.value(2).len(), 2); + assert_eq!(a.value(2).null_count(), 2); + + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap(); + roundtrip(batch, None); } #[test] @@ -1533,7 +1568,7 @@ mod tests { )))) .len(5) .add_buffer(a_value_offsets) - .null_bit_buffer(Buffer::from(vec![0b00011011])) + 
.null_bit_buffer(Some(Buffer::from(vec![0b00011011]))) .add_child_data(a_values.data().clone()) .build() .unwrap(); @@ -1543,7 +1578,7 @@ mod tests { let a = ListArray::from(a_list_data); let values = Arc::new(a); - one_column_roundtrip("list_single_column", values, true, Some(SMALL_SIZE / 2)); + one_column_roundtrip(values, true, Some(SMALL_SIZE / 2)); } #[test] @@ -1559,7 +1594,7 @@ mod tests { .len(5) .add_buffer(a_value_offsets) .add_child_data(a_values.data().clone()) - .null_bit_buffer(Buffer::from(vec![0b00011011])) + .null_bit_buffer(Some(Buffer::from(vec![0b00011011]))) .build() .unwrap(); @@ -1569,12 +1604,26 @@ mod tests { let a = LargeListArray::from(a_list_data); let values = Arc::new(a); - one_column_roundtrip( - "large_list_single_column", - values, - true, - Some(SMALL_SIZE / 2), - ); + one_column_roundtrip(values, true, Some(SMALL_SIZE / 2)); + } + + #[test] + fn list_nested_nulls() { + use arrow::datatypes::Int32Type; + let data = vec![ + Some(vec![Some(1)]), + Some(vec![Some(2), Some(3)]), + None, + Some(vec![Some(4), Some(5), None]), + Some(vec![None]), + Some(vec![Some(6), Some(7)]), + ]; + + let list = ListArray::from_iter_primitive::(data.clone()); + one_column_roundtrip(Arc::new(list), true, Some(SMALL_SIZE / 2)); + + let list = LargeListArray::from_iter_primitive::(data); + one_column_roundtrip(Arc::new(list), true, Some(SMALL_SIZE / 2)); } #[test] @@ -1584,7 +1633,7 @@ mod tests { let s = StructArray::from(vec![(struct_field_a, Arc::new(a_values) as ArrayRef)]); let values = Arc::new(s); - one_column_roundtrip("struct_single_column", values, false, Some(SMALL_SIZE / 2)); + one_column_roundtrip(values, false, Some(SMALL_SIZE / 2)); } #[test] @@ -1607,11 +1656,7 @@ mod tests { // build a record batch let expected_batch = RecordBatch::try_new(schema, vec![Arc::new(d)]).unwrap(); - roundtrip( - "test_arrow_writer_string_dictionary.parquet", - expected_batch, - Some(SMALL_SIZE / 2), - ); + roundtrip(expected_batch, Some(SMALL_SIZE / 2)); } #[test] @@ -1638,11 +1683,7 @@ mod tests { // build a record batch let expected_batch = RecordBatch::try_new(schema, vec![Arc::new(d)]).unwrap(); - roundtrip( - "test_arrow_writer_primitive_dictionary.parquet", - expected_batch, - Some(SMALL_SIZE / 2), - ); + roundtrip(expected_batch, Some(SMALL_SIZE / 2)); } #[test] @@ -1665,11 +1706,7 @@ mod tests { // build a record batch let expected_batch = RecordBatch::try_new(schema, vec![Arc::new(d)]).unwrap(); - roundtrip( - "test_arrow_writer_string_dictionary_unsigned_index.parquet", - expected_batch, - Some(SMALL_SIZE / 2), - ); + roundtrip(expected_batch, Some(SMALL_SIZE / 2)); } #[test] @@ -1684,7 +1721,7 @@ mod tests { u32::MAX - 1, u32::MAX, ])); - let file = one_column_roundtrip("u32_min_max_single_column", values, false, None); + let file = one_column_roundtrip(values, false, None); // check statistics are valid let reader = SerializedFileReader::new(file).unwrap(); @@ -1715,7 +1752,7 @@ mod tests { u64::MAX - 1, u64::MAX, ])); - let file = one_column_roundtrip("u64_min_max_single_column", values, false, None); + let file = one_column_roundtrip(values, false, None); // check statistics are valid let reader = SerializedFileReader::new(file).unwrap(); @@ -1738,7 +1775,7 @@ mod tests { fn statistics_null_counts_only_nulls() { // check that null-count statistics for "only NULL"-columns are correct let values = Arc::new(UInt64Array::from(vec![None, None])); - let file = one_column_roundtrip("null_counts", values, true, None); + let file = one_column_roundtrip(values, true, None); // 
check statistics are valid let reader = SerializedFileReader::new(file).unwrap(); @@ -1750,4 +1787,302 @@ mod tests { let stats = column.statistics().unwrap(); assert_eq!(stats.null_count(), 2); } + + #[test] + fn test_list_of_struct_roundtrip() { + // define schema + let int_field = Field::new("a", DataType::Int32, true); + let int_field2 = Field::new("b", DataType::Int32, true); + + let int_builder = Int32Builder::new(10); + let int_builder2 = Int32Builder::new(10); + + let struct_builder = StructBuilder::new( + vec![int_field, int_field2], + vec![Box::new(int_builder), Box::new(int_builder2)], + ); + let mut list_builder = ListBuilder::new(struct_builder); + + // Construct the following array + // [{a: 1, b: 2}], [], null, [null, null], [{a: null, b: 3}], [{a: 2, b: null}] + + // [{a: 1, b: 2}] + let values = list_builder.values(); + values + .field_builder::(0) + .unwrap() + .append_value(1) + .unwrap(); + values + .field_builder::(1) + .unwrap() + .append_value(2) + .unwrap(); + values.append(true).unwrap(); + list_builder.append(true).unwrap(); + + // [] + list_builder.append(true).unwrap(); + + // null + list_builder.append(false).unwrap(); + + // [null, null] + let values = list_builder.values(); + values + .field_builder::(0) + .unwrap() + .append_null() + .unwrap(); + values + .field_builder::(1) + .unwrap() + .append_null() + .unwrap(); + values.append(false).unwrap(); + values + .field_builder::(0) + .unwrap() + .append_null() + .unwrap(); + values + .field_builder::(1) + .unwrap() + .append_null() + .unwrap(); + values.append(false).unwrap(); + list_builder.append(true).unwrap(); + + // [{a: null, b: 3}] + let values = list_builder.values(); + values + .field_builder::(0) + .unwrap() + .append_null() + .unwrap(); + values + .field_builder::(1) + .unwrap() + .append_value(3) + .unwrap(); + values.append(true).unwrap(); + list_builder.append(true).unwrap(); + + // [{a: 2, b: null}] + let values = list_builder.values(); + values + .field_builder::(0) + .unwrap() + .append_value(2) + .unwrap(); + values + .field_builder::(1) + .unwrap() + .append_null() + .unwrap(); + values.append(true).unwrap(); + list_builder.append(true).unwrap(); + + let array = Arc::new(list_builder.finish()); + + one_column_roundtrip(array, true, Some(10)); + } + + fn row_group_sizes(metadata: &ParquetMetaData) -> Vec { + metadata.row_groups().iter().map(|x| x.num_rows()).collect() + } + + #[test] + fn test_aggregates_records() { + let arrays = [ + Int32Array::from((0..100).collect::>()), + Int32Array::from((0..50).collect::>()), + Int32Array::from((200..500).collect::>()), + ]; + + let schema = Arc::new(Schema::new(vec![Field::new( + "int", + ArrowDataType::Int32, + false, + )])); + + let file = tempfile::tempfile().unwrap(); + + let props = WriterProperties::builder() + .set_max_row_group_size(200) + .build(); + + let mut writer = + ArrowWriter::try_new(file.try_clone().unwrap(), schema.clone(), Some(props)) + .unwrap(); + + for array in arrays { + let batch = + RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap(); + writer.write(&batch).unwrap(); + } + + writer.close().unwrap(); + + let mut arrow_reader = ParquetFileArrowReader::try_new(file).unwrap(); + assert_eq!(&row_group_sizes(arrow_reader.metadata()), &[200, 200, 50]); + + let batches = arrow_reader + .get_record_reader(100) + .unwrap() + .collect::>>() + .unwrap(); + + assert_eq!(batches.len(), 5); + assert!(batches.iter().all(|x| x.num_columns() == 1)); + + let batch_sizes: Vec<_> = batches.iter().map(|x| x.num_rows()).collect(); 
+ + assert_eq!(&batch_sizes, &[100, 100, 100, 100, 50]); + + let values: Vec<_> = batches + .iter() + .flat_map(|x| { + x.column(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + .iter() + .cloned() + }) + .collect(); + + let expected_values: Vec<_> = + [0..100, 0..50, 200..500].into_iter().flatten().collect(); + assert_eq!(&values, &expected_values) + } + + #[test] + fn complex_aggregate() { + // Tests aggregating nested data + let field_a = Field::new("leaf_a", DataType::Int32, false); + let field_b = Field::new("leaf_b", DataType::Int32, true); + let struct_a = Field::new( + "struct_a", + DataType::Struct(vec![field_a.clone(), field_b.clone()]), + true, + ); + + let list_a = Field::new("list", DataType::List(Box::new(struct_a)), true); + let struct_b = + Field::new("struct_b", DataType::Struct(vec![list_a.clone()]), false); + + let schema = Arc::new(Schema::new(vec![struct_b])); + + // create nested data + let field_a_array = Int32Array::from(vec![1, 2, 3, 4, 5, 6]); + let field_b_array = + Int32Array::from_iter(vec![Some(1), None, Some(2), None, None, Some(6)]); + + let struct_a_array = StructArray::from(vec![ + (field_a.clone(), Arc::new(field_a_array) as ArrayRef), + (field_b.clone(), Arc::new(field_b_array) as ArrayRef), + ]); + + let list_data = ArrayDataBuilder::new(list_a.data_type().clone()) + .len(5) + .add_buffer(Buffer::from_iter(vec![ + 0_i32, 1_i32, 1_i32, 3_i32, 3_i32, 5_i32, + ])) + .null_bit_buffer(Some(Buffer::from_iter(vec![ + true, false, true, false, true, + ]))) + .child_data(vec![struct_a_array.data().clone()]) + .build() + .unwrap(); + + let list_a_array = Arc::new(ListArray::from(list_data)) as ArrayRef; + let struct_b_array = StructArray::from(vec![(list_a.clone(), list_a_array)]); + + let batch1 = RecordBatch::try_from_iter(vec![( + "struct_b", + Arc::new(struct_b_array) as ArrayRef, + )]) + .unwrap(); + + let field_a_array = Int32Array::from(vec![6, 7, 8, 9, 10]); + let field_b_array = Int32Array::from_iter(vec![None, None, None, Some(1), None]); + + let struct_a_array = StructArray::from(vec![ + (field_a, Arc::new(field_a_array) as ArrayRef), + (field_b, Arc::new(field_b_array) as ArrayRef), + ]); + + let list_data = ArrayDataBuilder::new(list_a.data_type().clone()) + .len(2) + .add_buffer(Buffer::from_iter(vec![0_i32, 4_i32, 5_i32])) + .child_data(vec![struct_a_array.data().clone()]) + .build() + .unwrap(); + + let list_a_array = Arc::new(ListArray::from(list_data)) as ArrayRef; + let struct_b_array = StructArray::from(vec![(list_a, list_a_array)]); + + let batch2 = RecordBatch::try_from_iter(vec![( + "struct_b", + Arc::new(struct_b_array) as ArrayRef, + )]) + .unwrap(); + + let batches = &[batch1, batch2]; + + // Verify data is as expected + + let expected = r#" + +-------------------------------------------------------------------------------------------------------------------------------------+ + | struct_b | + +-------------------------------------------------------------------------------------------------------------------------------------+ + | {"list": [{"leaf_a": 1, "leaf_b": 1}]} | + | {"list": null} | + | {"list": [{"leaf_a": 2, "leaf_b": null}, {"leaf_a": 3, "leaf_b": 2}]} | + | {"list": null} | + | {"list": [{"leaf_a": 4, "leaf_b": null}, {"leaf_a": 5, "leaf_b": null}]} | + | {"list": [{"leaf_a": 6, "leaf_b": null}, {"leaf_a": 7, "leaf_b": null}, {"leaf_a": 8, "leaf_b": null}, {"leaf_a": 9, "leaf_b": 1}]} | + | {"list": [{"leaf_a": 10, "leaf_b": null}]} | + 
+-------------------------------------------------------------------------------------------------------------------------------------+ + "#.trim().split('\n').map(|x| x.trim()).collect::>().join("\n"); + + let actual = pretty_format_batches(batches).unwrap().to_string(); + assert_eq!(actual, expected); + + // Write data + let file = tempfile::tempfile().unwrap(); + let props = WriterProperties::builder() + .set_max_row_group_size(6) + .build(); + + let mut writer = + ArrowWriter::try_new(file.try_clone().unwrap(), schema, Some(props)).unwrap(); + + for batch in batches { + writer.write(batch).unwrap(); + } + writer.close().unwrap(); + + // Read Data + // Should have written entire first batch and first row of second to the first row group + // leaving a single row in the second row group + + let mut arrow_reader = ParquetFileArrowReader::try_new(file).unwrap(); + assert_eq!(&row_group_sizes(arrow_reader.metadata()), &[6, 1]); + + let batches = arrow_reader + .get_record_reader(2) + .unwrap() + .collect::>>() + .unwrap(); + + assert_eq!(batches.len(), 4); + let batch_counts: Vec<_> = batches.iter().map(|x| x.num_rows()).collect(); + assert_eq!(&batch_counts, &[2, 2, 2, 1]); + + let actual = pretty_format_batches(&batches).unwrap().to_string(); + assert_eq!(actual, expected); + } } diff --git a/parquet/src/arrow/async_reader.rs b/parquet/src/arrow/async_reader.rs new file mode 100644 index 000000000000..3f14114e3c60 --- /dev/null +++ b/parquet/src/arrow/async_reader.rs @@ -0,0 +1,665 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Provides `async` API for reading parquet files as +//! [`RecordBatch`]es +//! +//! ``` +//! # #[tokio::main(flavor="current_thread")] +//! # async fn main() { +//! # +//! use arrow::record_batch::RecordBatch; +//! use arrow::util::pretty::pretty_format_batches; +//! use futures::TryStreamExt; +//! use tokio::fs::File; +//! +//! use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask}; +//! +//! # fn assert_batches_eq(batches: &[RecordBatch], expected_lines: &[&str]) { +//! # let formatted = pretty_format_batches(batches).unwrap().to_string(); +//! # let actual_lines: Vec<_> = formatted.trim().lines().collect(); +//! # assert_eq!( +//! # &actual_lines, expected_lines, +//! # "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", +//! # expected_lines, actual_lines +//! # ); +//! # } +//! +//! let testdata = arrow::util::test_util::parquet_test_data(); +//! let path = format!("{}/alltypes_plain.parquet", testdata); +//! let file = File::open(path).await.unwrap(); +//! +//! let builder = ParquetRecordBatchStreamBuilder::new(file) +//! .await +//! .unwrap() +//! .with_batch_size(3); +//! +//! let file_metadata = builder.metadata().file_metadata(); +//! 
let mask = ProjectionMask::roots(file_metadata.schema_descr(), [1, 2, 6]); +//! +//! let stream = builder.with_projection(mask).build().unwrap(); +//! let results = stream.try_collect::>().await.unwrap(); +//! assert_eq!(results.len(), 3); +//! +//! assert_batches_eq( +//! &results, +//! &[ +//! "+----------+-------------+-----------+", +//! "| bool_col | tinyint_col | float_col |", +//! "+----------+-------------+-----------+", +//! "| true | 0 | 0 |", +//! "| false | 1 | 1.1 |", +//! "| true | 0 | 0 |", +//! "| false | 1 | 1.1 |", +//! "| true | 0 | 0 |", +//! "| false | 1 | 1.1 |", +//! "| true | 0 | 0 |", +//! "| false | 1 | 1.1 |", +//! "+----------+-------------+-----------+", +//! ], +//! ); +//! # } +//! ``` + +use std::collections::VecDeque; +use std::fmt::Formatter; +use std::io::{Cursor, SeekFrom}; +use std::ops::Range; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use bytes::{Buf, Bytes}; +use futures::future::{BoxFuture, FutureExt}; +use futures::stream::Stream; +use parquet_format::PageType; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt}; + +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; + +use crate::arrow::array_reader::{build_array_reader, RowGroupCollection}; +use crate::arrow::arrow_reader::ParquetRecordBatchReader; +use crate::arrow::schema::parquet_to_arrow_schema; +use crate::arrow::ProjectionMask; +use crate::basic::Compression; +use crate::column::page::{Page, PageIterator, PageReader}; +use crate::compression::{create_codec, Codec}; +use crate::errors::{ParquetError, Result}; +use crate::file::footer::{decode_footer, decode_metadata}; +use crate::file::metadata::ParquetMetaData; +use crate::file::reader::SerializedPageReader; +use crate::file::serialized_reader::{decode_page, read_page_header}; +use crate::file::FOOTER_SIZE; +use crate::schema::types::{ColumnDescPtr, SchemaDescPtr, SchemaDescriptor}; + +/// The asynchronous interface used by [`ParquetRecordBatchStream`] to read parquet files +pub trait AsyncFileReader { + /// Retrieve the bytes in `range` + fn get_bytes(&mut self, range: Range) -> BoxFuture<'_, Result>; + + /// Provides asynchronous access to the [`ParquetMetaData`] of a parquet file, + /// allowing fine-grained control over how metadata is sourced, in particular allowing + /// for caching, pre-fetching, catalog metadata, etc... 
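A possible implementation of this trait, sketched after the `TestReader` used by the unit tests at the bottom of this module: it serves a parquet file that is already resident in memory and hands back a cached, pre-parsed copy of its metadata instead of re-reading the footer. The type name `InMemoryReader`, the helper `read_second_row_group`, the import paths and the parquet `async` cargo feature are assumptions here, as is the choice of row group and leaf columns in the usage function.

use std::ops::Range;
use std::sync::Arc;

use bytes::Bytes;
use futures::future::{BoxFuture, FutureExt};
use futures::TryStreamExt;

use arrow::record_batch::RecordBatch;
use parquet::arrow::async_reader::{AsyncFileReader, ParquetRecordBatchStreamBuilder};
use parquet::arrow::ProjectionMask;
use parquet::errors::Result;
use parquet::file::metadata::ParquetMetaData;

struct InMemoryReader {
    data: Bytes,
    metadata: Arc<ParquetMetaData>,
}

impl AsyncFileReader for InMemoryReader {
    fn get_bytes(&mut self, range: Range<usize>) -> BoxFuture<'_, Result<Bytes>> {
        // Everything is already in memory, so just hand back the requested slice
        futures::future::ready(Ok(self.data.slice(range))).boxed()
    }

    fn get_metadata(&mut self) -> BoxFuture<'_, Result<Arc<ParquetMetaData>>> {
        // Return the cached metadata rather than decoding the footer again
        futures::future::ready(Ok(self.metadata.clone())).boxed()
    }
}

async fn read_second_row_group(data: Bytes) -> Result<Vec<RecordBatch>> {
    // Parse the footer once up front so it can be served from the cache above
    let metadata = Arc::new(parquet::file::footer::parse_metadata(&data)?);
    let reader = InMemoryReader { data, metadata };

    let builder = ParquetRecordBatchStreamBuilder::new(reader).await?;
    // Restrict the stream to the first two leaf columns of the second row group
    let mask = ProjectionMask::leaves(builder.parquet_schema(), vec![0, 1]);
    let stream = builder
        .with_row_groups(vec![1])
        .with_projection(mask)
        .with_batch_size(1024)
        .build()?;

    stream.try_collect::<Vec<_>>().await
}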
+ fn get_metadata(&mut self) -> BoxFuture<'_, Result>>; +} + +impl AsyncFileReader for T { + fn get_bytes(&mut self, range: Range) -> BoxFuture<'_, Result> { + async move { + self.seek(SeekFrom::Start(range.start as u64)).await?; + + let to_read = range.end - range.start; + let mut buffer = Vec::with_capacity(to_read); + let read = self.take(to_read as u64).read_to_end(&mut buffer).await?; + if read != to_read { + eof_err!("expected to read {} bytes, got {}", to_read, read); + } + + Ok(buffer.into()) + } + .boxed() + } + + fn get_metadata(&mut self) -> BoxFuture<'_, Result>> { + const FOOTER_SIZE_I64: i64 = FOOTER_SIZE as i64; + async move { + self.seek(SeekFrom::End(-FOOTER_SIZE_I64)).await?; + + let mut buf = [0_u8; FOOTER_SIZE]; + self.read_exact(&mut buf).await?; + + let metadata_len = decode_footer(&buf)?; + self.seek(SeekFrom::End(-FOOTER_SIZE_I64 - metadata_len as i64)) + .await?; + + let mut buf = Vec::with_capacity(metadata_len); + self.read_to_end(&mut buf).await?; + + Ok(Arc::new(decode_metadata(&buf)?)) + } + .boxed() + } +} + +/// A builder used to construct a [`ParquetRecordBatchStream`] for a parquet file +/// +/// In particular, this handles reading the parquet file metadata, allowing consumers +/// to use this information to select what specific columns, row groups, etc... +/// they wish to be read by the resulting stream +/// +pub struct ParquetRecordBatchStreamBuilder { + input: T, + + metadata: Arc, + + schema: SchemaRef, + + batch_size: usize, + + row_groups: Option>, + + projection: ProjectionMask, +} + +impl ParquetRecordBatchStreamBuilder { + /// Create a new [`ParquetRecordBatchStreamBuilder`] with the provided parquet file + pub async fn new(mut input: T) -> Result { + let metadata = input.get_metadata().await?; + + let schema = Arc::new(parquet_to_arrow_schema( + metadata.file_metadata().schema_descr(), + metadata.file_metadata().key_value_metadata(), + )?); + + Ok(Self { + input, + metadata, + schema, + batch_size: 1024, + row_groups: None, + projection: ProjectionMask::all(), + }) + } + + /// Returns a reference to the [`ParquetMetaData`] for this parquet file + pub fn metadata(&self) -> &Arc { + &self.metadata + } + + /// Returns the parquet [`SchemaDescriptor`] for this parquet file + pub fn parquet_schema(&self) -> &SchemaDescriptor { + self.metadata.file_metadata().schema_descr() + } + + /// Returns the arrow [`SchemaRef`] for this parquet file + pub fn schema(&self) -> &SchemaRef { + &self.schema + } + + /// Set the size of [`RecordBatch`] to produce + pub fn with_batch_size(self, batch_size: usize) -> Self { + Self { batch_size, ..self } + } + + /// Only read data from the provided row group indexes + pub fn with_row_groups(self, row_groups: Vec) -> Self { + Self { + row_groups: Some(row_groups), + ..self + } + } + + /// Only read data from the provided column indexes + pub fn with_projection(self, mask: ProjectionMask) -> Self { + Self { + projection: mask, + ..self + } + } + + /// Build a new [`ParquetRecordBatchStream`] + pub fn build(self) -> Result> { + let num_row_groups = self.metadata.row_groups().len(); + + let row_groups = match self.row_groups { + Some(row_groups) => { + if let Some(col) = row_groups.iter().find(|x| **x >= num_row_groups) { + return Err(general_err!( + "row group {} out of bounds 0..{}", + col, + num_row_groups + )); + } + row_groups.into() + } + None => (0..self.metadata.row_groups().len()).collect(), + }; + + Ok(ParquetRecordBatchStream { + row_groups, + projection: self.projection, + batch_size: self.batch_size, + metadata: 
self.metadata, + schema: self.schema, + input: Some(self.input), + state: StreamState::Init, + }) + } +} + +enum StreamState { + /// At the start of a new row group, or the end of the parquet stream + Init, + /// Decoding a batch + Decoding(ParquetRecordBatchReader), + /// Reading data from input + Reading(BoxFuture<'static, Result<(T, InMemoryRowGroup)>>), + /// Error + Error, +} + +impl std::fmt::Debug for StreamState { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + StreamState::Init => write!(f, "StreamState::Init"), + StreamState::Decoding(_) => write!(f, "StreamState::Decoding"), + StreamState::Reading(_) => write!(f, "StreamState::Reading"), + StreamState::Error => write!(f, "StreamState::Error"), + } + } +} + +/// An asynchronous [`Stream`] of [`RecordBatch`] for a parquet file +pub struct ParquetRecordBatchStream { + metadata: Arc, + + schema: SchemaRef, + + batch_size: usize, + + projection: ProjectionMask, + + row_groups: VecDeque, + + /// This is an option so it can be moved into a future + input: Option, + + state: StreamState, +} + +impl std::fmt::Debug for ParquetRecordBatchStream { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ParquetRecordBatchStream") + .field("metadata", &self.metadata) + .field("schema", &self.schema) + .field("batch_size", &self.batch_size) + .field("projection", &self.projection) + .field("state", &self.state) + .finish() + } +} + +impl ParquetRecordBatchStream { + /// Returns the [`SchemaRef`] for this parquet file + pub fn schema(&self) -> &SchemaRef { + &self.schema + } +} + +impl Stream for ParquetRecordBatchStream +where + T: AsyncFileReader + Unpin + Send + 'static, +{ + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + loop { + match &mut self.state { + StreamState::Decoding(batch_reader) => match batch_reader.next() { + Some(Ok(batch)) => return Poll::Ready(Some(Ok(batch))), + Some(Err(e)) => { + self.state = StreamState::Error; + return Poll::Ready(Some(Err(ParquetError::ArrowError( + e.to_string(), + )))); + } + None => self.state = StreamState::Init, + }, + StreamState::Init => { + let row_group_idx = match self.row_groups.pop_front() { + Some(idx) => idx, + None => return Poll::Ready(None), + }; + + let metadata = self.metadata.clone(); + let mut input = match self.input.take() { + Some(input) => input, + None => { + self.state = StreamState::Error; + return Poll::Ready(Some(Err(general_err!( + "input stream lost" + )))); + } + }; + + let projection = self.projection.clone(); + self.state = StreamState::Reading( + async move { + let row_group_metadata = metadata.row_group(row_group_idx); + let mut column_chunks = + vec![None; row_group_metadata.columns().len()]; + + // TODO: Combine consecutive ranges + for (idx, chunk) in column_chunks.iter_mut().enumerate() { + if !projection.leaf_included(idx) { + continue; + } + + let column = row_group_metadata.column(idx); + let (start, length) = column.byte_range(); + + let data = input + .get_bytes(start as usize..(start + length) as usize) + .await?; + + *chunk = Some(InMemoryColumnChunk { + num_values: column.num_values(), + compression: column.compression(), + physical_type: column.column_type(), + data, + }); + } + + Ok(( + input, + InMemoryRowGroup { + schema: metadata.file_metadata().schema_descr_ptr(), + row_count: row_group_metadata.num_rows() as usize, + column_chunks, + }, + )) + } + .boxed(), + ) + } + StreamState::Reading(f) => { + let result = 
futures::ready!(f.poll_unpin(cx)); + self.state = StreamState::Init; + + let row_group: Box = match result { + Ok((input, row_group)) => { + self.input = Some(input); + Box::new(row_group) + } + Err(e) => { + self.state = StreamState::Error; + return Poll::Ready(Some(Err(e))); + } + }; + + let parquet_schema = self.metadata.file_metadata().schema_descr_ptr(); + + let array_reader = build_array_reader( + parquet_schema, + self.schema.clone(), + self.projection.clone(), + row_group, + )?; + + let batch_reader = + ParquetRecordBatchReader::try_new(self.batch_size, array_reader) + .expect("reader"); + + self.state = StreamState::Decoding(batch_reader) + } + StreamState::Error => return Poll::Pending, + } + } + } +} + +/// An in-memory collection of column chunks +struct InMemoryRowGroup { + schema: SchemaDescPtr, + column_chunks: Vec>, + row_count: usize, +} + +impl RowGroupCollection for InMemoryRowGroup { + fn schema(&self) -> Result { + Ok(self.schema.clone()) + } + + fn num_rows(&self) -> usize { + self.row_count + } + + fn column_chunks(&self, i: usize) -> Result> { + let page_reader = self.column_chunks[i].as_ref().unwrap().pages(); + + Ok(Box::new(ColumnChunkIterator { + schema: self.schema.clone(), + column_schema: self.schema.columns()[i].clone(), + reader: Some(page_reader), + })) + } +} + +/// Data for a single column chunk +#[derive(Clone)] +struct InMemoryColumnChunk { + num_values: i64, + compression: Compression, + physical_type: crate::basic::Type, + data: Bytes, +} + +impl InMemoryColumnChunk { + fn pages(&self) -> Result> { + let page_reader = SerializedPageReader::new( + self.data.clone().reader(), + self.num_values, + self.compression, + self.physical_type, + )?; + + Ok(Box::new(page_reader)) + } +} + +// A serialized implementation for Parquet [`PageReader`]. +struct InMemoryColumnChunkReader { + chunk: InMemoryColumnChunk, + decompressor: Option>, + offset: usize, + seen_num_values: i64, +} + +impl InMemoryColumnChunkReader { + /// Creates a new serialized page reader from file source. + pub fn new(chunk: InMemoryColumnChunk) -> Result { + let decompressor = create_codec(chunk.compression)?; + let result = Self { + chunk, + decompressor, + offset: 0, + seen_num_values: 0, + }; + Ok(result) + } +} + +impl Iterator for InMemoryColumnChunkReader { + type Item = Result; + + fn next(&mut self) -> Option { + self.get_next_page().transpose() + } +} + +impl PageReader for InMemoryColumnChunkReader { + fn get_next_page(&mut self) -> Result> { + while self.seen_num_values < self.chunk.num_values { + let mut cursor = Cursor::new(&self.chunk.data.as_ref()[self.offset..]); + let page_header = read_page_header(&mut cursor)?; + let compressed_size = page_header.compressed_page_size as usize; + + self.offset += cursor.position() as usize; + let start_offset = self.offset; + let end_offset = self.offset + compressed_size; + self.offset = end_offset; + + let buffer = self.chunk.data.slice(start_offset..end_offset); + + let result = match page_header.type_ { + PageType::DataPage | PageType::DataPageV2 => { + let decoded = decode_page( + page_header, + buffer.into(), + self.chunk.physical_type, + self.decompressor.as_mut(), + )?; + self.seen_num_values += decoded.num_values() as i64; + decoded + } + PageType::DictionaryPage => decode_page( + page_header, + buffer.into(), + self.chunk.physical_type, + self.decompressor.as_mut(), + )?, + _ => { + // For unknown page type (e.g., INDEX_PAGE), skip and read next. 
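// ---------------------------------------------------------------------------
// Editor's aside (not part of the diff): a minimal end-to-end sketch of how the
// async API introduced in this file is intended to be consumed. Everything in
// it is an assumption rather than something this PR pins down: the file name,
// the import paths, the crate feature that exposes this module, the use of a
// tokio runtime, and the choice of `tokio::fs::File` (which satisfies the
// blanket `AsyncFileReader` impl above via AsyncRead + AsyncSeek + Unpin + Send).
use futures::TryStreamExt;
use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let file = tokio::fs::File::open("data.parquet").await?;

    // The footer and metadata are fetched here, before any row group data
    let builder = ParquetRecordBatchStreamBuilder::new(file).await?;

    // Restrict decoding to the first two leaf columns
    let mask = ProjectionMask::leaves(builder.parquet_schema(), vec![0, 1]);
    let stream = builder
        .with_projection(mask)
        .with_batch_size(512)
        .build()?;

    // Each selected row group is fetched via `get_bytes` and decoded lazily
    let batches: Vec<_> = stream.try_collect().await?;
    println!("read {} batches", batches.len());
    Ok(())
}
// ---------------------------------------------------------------------------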
+ continue; + } + }; + + return Ok(Some(result)); + } + + // We are at the end of this column chunk and no more page left. Return None. + Ok(None) + } +} + +/// Implements [`PageIterator`] for a single column chunk, yielding a single [`PageReader`] +struct ColumnChunkIterator { + schema: SchemaDescPtr, + column_schema: ColumnDescPtr, + reader: Option>>, +} + +impl Iterator for ColumnChunkIterator { + type Item = Result>; + + fn next(&mut self) -> Option { + self.reader.take() + } +} + +impl PageIterator for ColumnChunkIterator { + fn schema(&mut self) -> Result { + Ok(self.schema.clone()) + } + + fn column_schema(&mut self) -> Result { + Ok(self.column_schema.clone()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::arrow::{ArrowReader, ParquetFileArrowReader}; + use arrow::error::Result as ArrowResult; + use futures::TryStreamExt; + use std::sync::Mutex; + + struct TestReader { + data: Bytes, + metadata: Arc, + requests: Arc>>>, + } + + impl AsyncFileReader for TestReader { + fn get_bytes(&mut self, range: Range) -> BoxFuture<'_, Result> { + self.requests.lock().unwrap().push(range.clone()); + futures::future::ready(Ok(self.data.slice(range))).boxed() + } + + fn get_metadata(&mut self) -> BoxFuture<'_, Result>> { + futures::future::ready(Ok(self.metadata.clone())).boxed() + } + } + + #[tokio::test] + async fn test_async_reader() { + let testdata = arrow::util::test_util::parquet_test_data(); + let path = format!("{}/alltypes_plain.parquet", testdata); + let data = Bytes::from(std::fs::read(path).unwrap()); + + let metadata = crate::file::footer::parse_metadata(&data).unwrap(); + let metadata = Arc::new(metadata); + + assert_eq!(metadata.num_row_groups(), 1); + + let async_reader = TestReader { + data: data.clone(), + metadata: metadata.clone(), + requests: Default::default(), + }; + + let requests = async_reader.requests.clone(); + let builder = ParquetRecordBatchStreamBuilder::new(async_reader) + .await + .unwrap(); + + let mask = ProjectionMask::leaves(builder.parquet_schema(), vec![1, 2]); + let stream = builder + .with_projection(mask.clone()) + .with_batch_size(1024) + .build() + .unwrap(); + + let async_batches: Vec<_> = stream.try_collect().await.unwrap(); + + let mut sync_reader = ParquetFileArrowReader::try_new(data).unwrap(); + let sync_batches = sync_reader + .get_record_reader_by_columns(mask, 1024) + .unwrap() + .collect::>>() + .unwrap(); + + assert_eq!(async_batches, sync_batches); + + let requests = requests.lock().unwrap(); + let (offset_1, length_1) = metadata.row_group(0).column(1).byte_range(); + let (offset_2, length_2) = metadata.row_group(0).column(2).byte_range(); + + assert_eq!( + &requests[..], + &[ + offset_1 as usize..(offset_1 + length_1) as usize, + offset_2 as usize..(offset_2 + length_2) as usize + ] + ); + } +} diff --git a/parquet/src/arrow/buffer/bit_util.rs b/parquet/src/arrow/buffer/bit_util.rs new file mode 100644 index 000000000000..192ab4b72163 --- /dev/null +++ b/parquet/src/arrow/buffer/bit_util.rs @@ -0,0 +1,94 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::util::bit_chunk_iterator::UnalignedBitChunk; +use std::ops::Range; + +/// Counts the number of set bits in the provided range +pub fn count_set_bits(bytes: &[u8], range: Range) -> usize { + let unaligned = UnalignedBitChunk::new(bytes, range.start, range.end - range.start); + unaligned.count_ones() +} + +/// Iterates through the set bit positions in `bytes` in reverse order +pub fn iter_set_bits_rev(bytes: &[u8]) -> impl Iterator + '_ { + let bit_length = bytes.len() * 8; + let unaligned = UnalignedBitChunk::new(bytes, 0, bit_length); + let mut chunk_end_idx = + bit_length + unaligned.lead_padding() + unaligned.trailing_padding(); + + let iter = unaligned + .prefix() + .into_iter() + .chain(unaligned.chunks().iter().cloned()) + .chain(unaligned.suffix().into_iter()); + + iter.rev().flat_map(move |mut chunk| { + let chunk_idx = chunk_end_idx - 64; + chunk_end_idx = chunk_idx; + std::iter::from_fn(move || { + if chunk != 0 { + let bit_pos = 63 - chunk.leading_zeros(); + chunk ^= 1 << bit_pos; + return Some(chunk_idx + (bit_pos as usize)); + } + None + }) + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::BooleanBufferBuilder; + use rand::prelude::*; + + #[test] + fn test_bit_fns() { + let mut rng = thread_rng(); + let mask_length = rng.gen_range(1..1024); + let bools: Vec<_> = std::iter::from_fn(|| Some(rng.next_u32() & 1 == 0)) + .take(mask_length) + .collect(); + + let mut nulls = BooleanBufferBuilder::new(mask_length); + bools.iter().for_each(|b| nulls.append(*b)); + + let actual: Vec<_> = iter_set_bits_rev(nulls.as_slice()).collect(); + let expected: Vec<_> = bools + .iter() + .enumerate() + .rev() + .filter_map(|(x, y)| y.then(|| x)) + .collect(); + assert_eq!(actual, expected); + + assert_eq!(iter_set_bits_rev(&[]).count(), 0); + assert_eq!(count_set_bits(&[], 0..0), 0); + assert_eq!(count_set_bits(&[0xFF], 1..1), 0); + + for _ in 0..20 { + let start = rng.gen_range(0..bools.len()); + let end = rng.gen_range(start..bools.len()); + + let actual = count_set_bits(nulls.as_slice(), start..end); + let expected = bools[start..end].iter().filter(|x| **x).count(); + + assert_eq!(actual, expected); + } + } +} diff --git a/parquet/src/arrow/converter.rs b/parquet/src/arrow/buffer/converter.rs similarity index 66% rename from parquet/src/arrow/converter.rs rename to parquet/src/arrow/buffer/converter.rs index 1672be9c0462..51e1d8290ee3 100644 --- a/parquet/src/arrow/converter.rs +++ b/parquet/src/arrow/buffer/converter.rs @@ -15,30 +15,20 @@ // specific language governing permissions and limitations // under the License. 
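The converter change in this hunk drops the per-value `DecimalBuilder` loop in favour of collecting directly into a `DecimalArray` and attaching precision and scale afterwards. A minimal sketch of the same pattern (illustrative only, assuming the `arrow` crate of this era where `DecimalArray` implements `FromIterator<Option<i128>>` and `with_precision_and_scale` returns an arrow `Result`):

    use arrow::array::DecimalArray;
    use arrow::error::Result as ArrowResult;

    // Collect optional i128 values straight into a DecimalArray and then attach
    // precision and scale, instead of appending value-by-value through a builder.
    fn to_decimal_array(
        values: Vec<Option<i128>>,
        precision: usize,
        scale: usize,
    ) -> ArrowResult<DecimalArray> {
        values
            .into_iter()
            .collect::<DecimalArray>()
            .with_precision_and_scale(precision, scale)
    }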
-use crate::data_type::{ByteArray, DataType, FixedLenByteArray, Int96}; -// TODO: clean up imports (best done when there are few moving parts) +use crate::data_type::{ByteArray, FixedLenByteArray, Int96}; use arrow::array::{ - Array, ArrayRef, BinaryBuilder, DecimalBuilder, FixedSizeBinaryBuilder, - IntervalDayTimeArray, IntervalDayTimeBuilder, IntervalYearMonthArray, - IntervalYearMonthBuilder, LargeBinaryBuilder, LargeStringBuilder, PrimitiveBuilder, - PrimitiveDictionaryBuilder, StringBuilder, StringDictionaryBuilder, + Array, ArrayRef, BinaryArray, BinaryBuilder, DecimalArray, FixedSizeBinaryArray, + FixedSizeBinaryBuilder, IntervalDayTimeArray, IntervalDayTimeBuilder, + IntervalYearMonthArray, IntervalYearMonthBuilder, LargeBinaryArray, + LargeBinaryBuilder, LargeStringArray, LargeStringBuilder, StringArray, StringBuilder, + TimestampNanosecondArray, }; -use arrow::compute::cast; use std::convert::{From, TryInto}; use std::sync::Arc; use crate::errors::Result; -use arrow::datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType}; - -use arrow::array::{ - BinaryArray, DecimalArray, DictionaryArray, FixedSizeBinaryArray, LargeBinaryArray, - LargeStringArray, PrimitiveArray, StringArray, TimestampNanosecondArray, -}; use std::marker::PhantomData; -use crate::data_type::Int32Type as ParquetInt32Type; -use arrow::datatypes::Int32Type; - /// A converter is used to consume record reader's content and convert it to arrow /// primitive array. pub trait Converter { @@ -100,21 +90,13 @@ impl DecimalArrayConverter { impl Converter>, DecimalArray> for DecimalArrayConverter { fn convert(&self, source: Vec>) -> Result { - let mut builder = DecimalBuilder::new( - source.len(), - self.precision as usize, - self.scale as usize, - ); - for v in source { - match v { - Some(array) => { - builder.append_value(Self::from_bytes_to_i128(array.data())) - } - None => builder.append_null(), - }? 
- } + let array = source + .into_iter() + .map(|array| array.map(|array| Self::from_bytes_to_i128(array.data()))) + .collect::() + .with_precision_and_scale(self.precision as usize, self.scale as usize)?; - Ok(builder.finish()) + Ok(array) } } /// An Arrow Interval converter, which reads the first 4 bytes of a Parquet interval, @@ -257,92 +239,6 @@ impl Converter>, LargeBinaryArray> for LargeBinaryArrayCon } } -pub struct StringDictionaryArrayConverter {} - -impl Converter>, DictionaryArray> - for StringDictionaryArrayConverter -{ - fn convert(&self, source: Vec>) -> Result> { - let data_size = source - .iter() - .map(|x| x.as_ref().map(|b| b.len()).unwrap_or(0)) - .sum(); - - let keys_builder = PrimitiveBuilder::::new(source.len()); - let values_builder = StringBuilder::with_capacity(source.len(), data_size); - - let mut builder = StringDictionaryBuilder::new(keys_builder, values_builder); - for v in source { - match v { - Some(array) => { - let _ = builder.append(array.as_utf8()?)?; - } - None => builder.append_null()?, - } - } - - Ok(builder.finish()) - } -} - -pub struct DictionaryArrayConverter -{ - _dict_value_source_marker: PhantomData, - _dict_value_target_marker: PhantomData, - _parquet_marker: PhantomData, -} - -impl - DictionaryArrayConverter -{ - pub fn new() -> Self { - Self { - _dict_value_source_marker: PhantomData, - _dict_value_target_marker: PhantomData, - _parquet_marker: PhantomData, - } - } -} - -impl - Converter::T>>, DictionaryArray> - for DictionaryArrayConverter -where - K: ArrowPrimitiveType, - DictValueSourceType: ArrowPrimitiveType, - DictValueTargetType: ArrowPrimitiveType, - ParquetType: DataType, - PrimitiveArray: From::T>>>, -{ - fn convert( - &self, - source: Vec::T>>, - ) -> Result> { - let keys_builder = PrimitiveBuilder::::new(source.len()); - let values_builder = PrimitiveBuilder::::new(source.len()); - - let mut builder = PrimitiveDictionaryBuilder::new(keys_builder, values_builder); - - let source_array: Arc = - Arc::new(PrimitiveArray::::from(source)); - let target_array = cast(&source_array, &DictValueTargetType::DATA_TYPE)?; - let target = target_array - .as_any() - .downcast_ref::>() - .unwrap(); - - for i in 0..target.len() { - if target.is_null(i) { - builder.append_null()?; - } else { - let _ = builder.append(target.value(i))?; - } - } - - Ok(builder.finish()) - } -} - pub type Utf8Converter = ArrayRefConverter>, StringArray, Utf8ArrayConverter>; pub type LargeUtf8Converter = @@ -354,21 +250,6 @@ pub type LargeBinaryConverter = ArrayRefConverter< LargeBinaryArray, LargeBinaryArrayConverter, >; -pub type StringDictionaryConverter = ArrayRefConverter< - Vec>, - DictionaryArray, - StringDictionaryArrayConverter, ->; -pub type DictionaryConverter = ArrayRefConverter< - Vec::T>>, - DictionaryArray, - DictionaryArrayConverter, ->; -pub type PrimitiveDictionaryConverter = ArrayRefConverter< - Vec::T>>, - DictionaryArray, - DictionaryArrayConverter, ->; pub type Int96Converter = ArrayRefConverter>, TimestampNanosecondArray, Int96ArrayConverter>; diff --git a/parquet/src/arrow/buffer/dictionary_buffer.rs b/parquet/src/arrow/buffer/dictionary_buffer.rs new file mode 100644 index 000000000000..ffa3a4843c50 --- /dev/null +++ b/parquet/src/arrow/buffer/dictionary_buffer.rs @@ -0,0 +1,380 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::arrow::buffer::offset_buffer::OffsetBuffer; +use crate::arrow::record_reader::buffer::{ + BufferQueue, ScalarBuffer, ScalarValue, ValuesBuffer, +}; +use crate::column::reader::decoder::ValuesBufferSlice; +use crate::errors::{ParquetError, Result}; +use arrow::array::{make_array, ArrayDataBuilder, ArrayRef, OffsetSizeTrait}; +use arrow::buffer::Buffer; +use arrow::datatypes::{ArrowNativeType, DataType as ArrowType}; +use std::sync::Arc; + +/// An array of variable length byte arrays that are potentially dictionary encoded +/// and can be converted into a corresponding [`ArrayRef`] +pub enum DictionaryBuffer { + Dict { + keys: ScalarBuffer, + values: ArrayRef, + }, + Values { + values: OffsetBuffer, + }, +} + +impl Default for DictionaryBuffer { + fn default() -> Self { + Self::Values { + values: Default::default(), + } + } +} + +impl + DictionaryBuffer +{ + pub fn len(&self) -> usize { + match self { + Self::Dict { keys, .. } => keys.len(), + Self::Values { values } => values.len(), + } + } + + /// Returns a mutable reference to a keys array + /// + /// Returns None if the dictionary needs to be recomputed + /// + /// # Panic + /// + /// Panics if the dictionary is too large for `K` + pub fn as_keys(&mut self, dictionary: &ArrayRef) -> Option<&mut ScalarBuffer> { + assert!(K::from_usize(dictionary.len()).is_some()); + + match self { + Self::Dict { keys, values } => { + // Need to discard fat pointer for equality check + // - https://stackoverflow.com/a/67114787 + // - https://github.com/rust-lang/rust/issues/46139 + let values_ptr = values.as_ref() as *const _ as *const (); + let dict_ptr = dictionary.as_ref() as *const _ as *const (); + if values_ptr == dict_ptr { + Some(keys) + } else if keys.is_empty() { + *values = Arc::clone(dictionary); + Some(keys) + } else { + None + } + } + Self::Values { values } if values.is_empty() => { + *self = Self::Dict { + keys: Default::default(), + values: Arc::clone(dictionary), + }; + match self { + Self::Dict { keys, .. } => Some(keys), + _ => unreachable!(), + } + } + _ => None, + } + } + + /// Returns a mutable reference to a values array + /// + /// If this is currently dictionary encoded, this will convert from the + /// dictionary encoded representation + pub fn spill_values(&mut self) -> Result<&mut OffsetBuffer> { + match self { + Self::Values { values } => Ok(values), + Self::Dict { keys, values } => { + let mut spilled = OffsetBuffer::default(); + let dict_buffers = values.data().buffers(); + let dict_offsets = dict_buffers[0].typed_data::(); + let dict_values = dict_buffers[1].as_slice(); + + if values.is_empty() { + // If dictionary is empty, zero pad offsets + spilled.offsets.resize(keys.len() + 1); + } else { + // Note: at this point null positions will have arbitrary dictionary keys + // and this will hydrate them to the corresponding byte array. 
This is + // likely sub-optimal, as we would prefer zero length null "slots", but + // spilling is already a degenerate case and so it is unclear if this is + // worth optimising for, e.g. by keeping a null mask around + spilled.extend_from_dictionary( + keys.as_slice(), + dict_offsets, + dict_values, + )?; + } + + *self = Self::Values { values: spilled }; + match self { + Self::Values { values } => Ok(values), + _ => unreachable!(), + } + } + } + } + + /// Converts this into an [`ArrayRef`] with the provided `data_type` and `null_buffer` + pub fn into_array( + self, + null_buffer: Option, + data_type: &ArrowType, + ) -> Result { + assert!(matches!(data_type, ArrowType::Dictionary(_, _))); + + match self { + Self::Dict { keys, values } => { + // Validate keys unless dictionary is empty + if !values.is_empty() { + let min = K::from_usize(0).unwrap(); + let max = K::from_usize(values.len()).unwrap(); + + // It may be possible to use SIMD here + if keys.as_slice().iter().any(|x| *x < min || *x >= max) { + return Err(general_err!( + "dictionary key beyond bounds of dictionary: 0..{}", + values.len() + )); + } + } + + let builder = ArrayDataBuilder::new(data_type.clone()) + .len(keys.len()) + .add_buffer(keys.into()) + .add_child_data(values.data().clone()) + .null_bit_buffer(null_buffer); + + let data = match cfg!(debug_assertions) { + true => builder.build().unwrap(), + false => unsafe { builder.build_unchecked() }, + }; + + Ok(make_array(data)) + } + Self::Values { values } => { + let value_type = match data_type { + ArrowType::Dictionary(_, v) => v.as_ref().clone(), + _ => unreachable!(), + }; + + // This will compute a new dictionary + let array = arrow::compute::cast( + &values.into_array(null_buffer, value_type), + data_type, + ) + .expect("cast should be infallible"); + + Ok(array) + } + } + } +} + +impl ValuesBufferSlice for DictionaryBuffer { + fn capacity(&self) -> usize { + usize::MAX + } +} + +impl ValuesBuffer + for DictionaryBuffer +{ + fn pad_nulls( + &mut self, + read_offset: usize, + values_read: usize, + levels_read: usize, + valid_mask: &[u8], + ) { + match self { + Self::Dict { keys, .. } => { + keys.resize(read_offset + levels_read); + keys.pad_nulls(read_offset, values_read, levels_read, valid_mask) + } + Self::Values { values, .. } => { + values.pad_nulls(read_offset, values_read, levels_read, valid_mask) + } + } + } +} + +impl BufferQueue + for DictionaryBuffer +{ + type Output = Self; + type Slice = Self; + + fn split_off(&mut self, len: usize) -> Self::Output { + match self { + Self::Dict { keys, values } => Self::Dict { + keys: keys.take(len), + values: values.clone(), + }, + Self::Values { values } => Self::Values { + values: values.split_off(len), + }, + } + } + + fn spare_capacity_mut(&mut self, _batch_size: usize) -> &mut Self::Slice { + self + } + + fn set_len(&mut self, len: usize) { + match self { + Self::Dict { keys, .. 
} => keys.set_len(len), + Self::Values { values } => values.set_len(len), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Array, StringArray}; + use arrow::compute::cast; + + #[test] + fn test_dictionary_buffer() { + let dict_type = + ArrowType::Dictionary(Box::new(ArrowType::Int32), Box::new(ArrowType::Utf8)); + + let d1: ArrayRef = + Arc::new(StringArray::from(vec!["hello", "world", "", "a", "b"])); + + let mut buffer = DictionaryBuffer::::default(); + + // Read some data preserving the dictionary + let values = &[1, 0, 3, 2, 4]; + buffer.as_keys(&d1).unwrap().extend_from_slice(values); + + let mut valid = vec![false, false, true, true, false, true, true, true]; + let valid_buffer = Buffer::from_iter(valid.iter().cloned()); + buffer.pad_nulls(0, values.len(), valid.len(), valid_buffer.as_slice()); + + // Split off some data + + let split = buffer.split_off(4); + let null_buffer = Buffer::from_iter(valid.drain(0..4)); + let array = split.into_array(Some(null_buffer), &dict_type).unwrap(); + assert_eq!(array.data_type(), &dict_type); + + let strings = cast(&array, &ArrowType::Utf8).unwrap(); + let strings = strings.as_any().downcast_ref::().unwrap(); + assert_eq!( + strings.iter().collect::>(), + vec![None, None, Some("world"), Some("hello")] + ); + + // Read some data not preserving the dictionary + + let values = buffer.spill_values().unwrap(); + let read_offset = values.len(); + values.try_push("bingo".as_bytes(), false).unwrap(); + values.try_push("bongo".as_bytes(), false).unwrap(); + + valid.extend_from_slice(&[false, false, true, false, true]); + let null_buffer = Buffer::from_iter(valid.iter().cloned()); + buffer.pad_nulls(read_offset, 2, 5, null_buffer.as_slice()); + + assert_eq!(buffer.len(), 9); + let split = buffer.split_off(9); + + let array = split.into_array(Some(null_buffer), &dict_type).unwrap(); + assert_eq!(array.data_type(), &dict_type); + + let strings = cast(&array, &ArrowType::Utf8).unwrap(); + let strings = strings.as_any().downcast_ref::().unwrap(); + assert_eq!( + strings.iter().collect::>(), + vec![ + None, + Some("a"), + Some(""), + Some("b"), + None, + None, + Some("bingo"), + None, + Some("bongo") + ] + ); + + // Can recreate with new dictionary as values is empty + assert!(matches!(&buffer, DictionaryBuffer::Values { .. })); + assert_eq!(buffer.len(), 0); + let d2 = Arc::new(StringArray::from(vec!["bingo", ""])) as ArrayRef; + buffer + .as_keys(&d2) + .unwrap() + .extend_from_slice(&[0, 1, 0, 1]); + + let array = buffer.split_off(4).into_array(None, &dict_type).unwrap(); + assert_eq!(array.data_type(), &dict_type); + + let strings = cast(&array, &ArrowType::Utf8).unwrap(); + let strings = strings.as_any().downcast_ref::().unwrap(); + assert_eq!( + strings.iter().collect::>(), + vec![Some("bingo"), Some(""), Some("bingo"), Some("")] + ); + + // Can recreate with new dictionary as keys empty + assert!(matches!(&buffer, DictionaryBuffer::Dict { .. 
})); + assert_eq!(buffer.len(), 0); + let d3 = Arc::new(StringArray::from(vec!["bongo"])) as ArrayRef; + buffer.as_keys(&d3).unwrap().extend_from_slice(&[0, 0]); + + // Cannot change dictionary as keys not empty + let d4 = Arc::new(StringArray::from(vec!["bananas"])) as ArrayRef; + assert!(buffer.as_keys(&d4).is_none()); + } + + #[test] + fn test_validates_keys() { + let dict_type = + ArrowType::Dictionary(Box::new(ArrowType::Int32), Box::new(ArrowType::Utf8)); + + let mut buffer = DictionaryBuffer::::default(); + let d = Arc::new(StringArray::from(vec!["", "f"])) as ArrayRef; + buffer.as_keys(&d).unwrap().extend_from_slice(&[0, 2, 0]); + + let err = buffer.into_array(None, &dict_type).unwrap_err().to_string(); + assert!( + err.contains("dictionary key beyond bounds of dictionary: 0..2"), + "{}", + err + ); + + let mut buffer = DictionaryBuffer::::default(); + let d = Arc::new(StringArray::from(vec![""])) as ArrayRef; + buffer.as_keys(&d).unwrap().extend_from_slice(&[0, 1, 0]); + + let err = buffer.spill_values().unwrap_err().to_string(); + assert!( + err.contains("dictionary key beyond bounds of dictionary: 0..1"), + "{}", + err + ); + } +} diff --git a/arrow/src/arch/mod.rs b/parquet/src/arrow/buffer/mod.rs similarity index 79% rename from arrow/src/arch/mod.rs rename to parquet/src/arrow/buffer/mod.rs index 56d8f4c0e2cf..5ee89aa1a782 100644 --- a/arrow/src/arch/mod.rs +++ b/parquet/src/arrow/buffer/mod.rs @@ -15,8 +15,9 @@ // specific language governing permissions and limitations // under the License. -/// -/// Arch module contains architecture specific code. -/// Be aware that not all machines have these specific operations available. -#[cfg(all(target_arch = "x86_64", feature = "avx512"))] -pub(crate) mod avx512; +//! Logic for reading data into arrow buffers + +pub mod bit_util; +pub mod converter; +pub mod dictionary_buffer; +pub mod offset_buffer; diff --git a/parquet/src/arrow/buffer/offset_buffer.rs b/parquet/src/arrow/buffer/offset_buffer.rs new file mode 100644 index 000000000000..2d73e3f146b6 --- /dev/null +++ b/parquet/src/arrow/buffer/offset_buffer.rs @@ -0,0 +1,380 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::arrow::buffer::bit_util::iter_set_bits_rev; +use crate::arrow::record_reader::buffer::{ + BufferQueue, ScalarBuffer, ScalarValue, ValuesBuffer, +}; +use crate::column::reader::decoder::ValuesBufferSlice; +use crate::errors::{ParquetError, Result}; +use arrow::array::{make_array, ArrayDataBuilder, ArrayRef, OffsetSizeTrait}; +use arrow::buffer::Buffer; +use arrow::datatypes::{ArrowNativeType, DataType as ArrowType}; + +/// A buffer of variable-sized byte arrays that can be converted into +/// a corresponding [`ArrayRef`] +#[derive(Debug)] +pub struct OffsetBuffer { + pub offsets: ScalarBuffer, + pub values: ScalarBuffer, +} + +impl Default for OffsetBuffer { + fn default() -> Self { + let mut offsets = ScalarBuffer::new(); + offsets.resize(1); + Self { + offsets, + values: ScalarBuffer::new(), + } + } +} + +impl OffsetBuffer { + /// Returns the number of byte arrays in this buffer + pub fn len(&self) -> usize { + self.offsets.len() - 1 + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// If `validate_utf8` this verifies that the first character of `data` is + /// the start of a UTF-8 codepoint + /// + /// Note: This does not verify that the entirety of `data` is valid + /// UTF-8. This should be done by calling [`Self::check_valid_utf8`] after + /// all data has been written + pub fn try_push(&mut self, data: &[u8], validate_utf8: bool) -> Result<()> { + if validate_utf8 { + if let Some(&b) = data.first() { + // A valid code-point iff it does not start with 0b10xxxxxx + // Bit-magic taken from `std::str::is_char_boundary` + if (b as i8) < -0x40 { + return Err(ParquetError::General( + "encountered non UTF-8 data".to_string(), + )); + } + } + } + + self.values.extend_from_slice(data); + + let index_offset = I::from_usize(self.values.len()) + .ok_or_else(|| general_err!("index overflow decoding byte array"))?; + + self.offsets.push(index_offset); + Ok(()) + } + + /// Extends this buffer with a list of keys + /// + /// For each value `key` in `keys` this will insert + /// `&dict_values[dict_offsets[key]..dict_offsets[key+1]]` + /// + /// Note: This will validate offsets are valid + pub fn extend_from_dictionary( + &mut self, + keys: &[K], + dict_offsets: &[V], + dict_values: &[u8], + ) -> Result<()> { + for key in keys { + let index = key.to_usize().unwrap(); + if index + 1 >= dict_offsets.len() { + return Err(general_err!( + "dictionary key beyond bounds of dictionary: 0..{}", + dict_offsets.len().saturating_sub(1) + )); + } + let start_offset = dict_offsets[index].to_usize().unwrap(); + let end_offset = dict_offsets[index + 1].to_usize().unwrap(); + + // Dictionary values are verified when decoding dictionary page + self.try_push(&dict_values[start_offset..end_offset], false)?; + } + Ok(()) + } + + /// Validates that `&self.values[start_offset..]` is a valid UTF-8 sequence + /// + /// This MUST be combined with validating that the offsets start on a character + /// boundary, otherwise it would be possible for the values array to be a valid UTF-8 + /// sequence, but not the individual string slices it contains + /// + /// [`Self::try_push`] can perform this validation check on insertion + pub fn check_valid_utf8(&self, start_offset: usize) -> Result<()> { + match std::str::from_utf8(&self.values.as_slice()[start_offset..]) { + Ok(_) => Ok(()), + Err(e) => Err(general_err!("encountered non UTF-8 data: {}", e)), + } + } + + /// Converts this into an [`ArrayRef`] with the provided `data_type` and `null_buffer` + pub fn into_array( + self, + null_buffer: Option, + 
data_type: ArrowType, + ) -> ArrayRef { + let array_data_builder = ArrayDataBuilder::new(data_type) + .len(self.len()) + .add_buffer(self.offsets.into()) + .add_buffer(self.values.into()) + .null_bit_buffer(null_buffer); + + let data = match cfg!(debug_assertions) { + true => array_data_builder.build().unwrap(), + false => unsafe { array_data_builder.build_unchecked() }, + }; + + make_array(data) + } +} + +impl BufferQueue for OffsetBuffer { + type Output = Self; + type Slice = Self; + + fn split_off(&mut self, len: usize) -> Self::Output { + assert!(self.offsets.len() > len, "{} > {}", self.offsets.len(), len); + let remaining_offsets = self.offsets.len() - len - 1; + let offsets = self.offsets.as_slice(); + + let end_offset = offsets[len]; + + let mut new_offsets = ScalarBuffer::new(); + new_offsets.reserve(remaining_offsets + 1); + for v in &offsets[len..] { + new_offsets.push(*v - end_offset) + } + + self.offsets.resize(len + 1); + + Self { + offsets: std::mem::replace(&mut self.offsets, new_offsets), + values: self.values.take(end_offset.to_usize().unwrap()), + } + } + + fn spare_capacity_mut(&mut self, _batch_size: usize) -> &mut Self::Slice { + self + } + + fn set_len(&mut self, len: usize) { + assert_eq!(self.offsets.len(), len + 1); + } +} + +impl ValuesBuffer for OffsetBuffer { + fn pad_nulls( + &mut self, + read_offset: usize, + values_read: usize, + levels_read: usize, + valid_mask: &[u8], + ) { + assert_eq!(self.offsets.len(), read_offset + values_read + 1); + self.offsets.resize(read_offset + levels_read + 1); + + let offsets = self.offsets.as_slice_mut(); + + let mut last_pos = read_offset + levels_read + 1; + let mut last_start_offset = I::from_usize(self.values.len()).unwrap(); + + let values_range = read_offset..read_offset + values_read; + for (value_pos, level_pos) in values_range + .clone() + .rev() + .zip(iter_set_bits_rev(valid_mask)) + { + assert!(level_pos >= value_pos); + assert!(level_pos < last_pos); + + let end_offset = offsets[value_pos + 1]; + let start_offset = offsets[value_pos]; + + // Fill in any nulls + for x in &mut offsets[level_pos + 1..last_pos] { + *x = end_offset; + } + + if level_pos == value_pos { + return; + } + + offsets[level_pos] = start_offset; + last_pos = level_pos; + last_start_offset = start_offset; + } + + // Pad leading nulls up to `last_offset` + for x in &mut offsets[values_range.start + 1..last_pos] { + *x = last_start_offset + } + } +} + +impl ValuesBufferSlice for OffsetBuffer { + fn capacity(&self) -> usize { + usize::MAX + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Array, LargeStringArray, StringArray}; + + #[test] + fn test_offset_buffer_empty() { + let buffer = OffsetBuffer::::default(); + let array = buffer.into_array(None, ArrowType::Utf8); + let strings = array.as_any().downcast_ref::().unwrap(); + assert_eq!(strings.len(), 0); + } + + #[test] + fn test_offset_buffer_append() { + let mut buffer = OffsetBuffer::::default(); + buffer.try_push("hello".as_bytes(), true).unwrap(); + buffer.try_push("bar".as_bytes(), true).unwrap(); + buffer + .extend_from_dictionary(&[1, 3, 0, 2], &[0, 2, 4, 5, 6], "abcdef".as_bytes()) + .unwrap(); + + let array = buffer.into_array(None, ArrowType::LargeUtf8); + let strings = array.as_any().downcast_ref::().unwrap(); + assert_eq!( + strings.iter().map(|x| x.unwrap()).collect::>(), + vec!["hello", "bar", "cd", "f", "ab", "e"] + ) + } + + #[test] + fn test_offset_buffer_split() { + let mut buffer = OffsetBuffer::::default(); + for v in ["hello", "world", "cupcakes", 
"a", "b", "c"] { + buffer.try_push(v.as_bytes(), false).unwrap() + } + let split = buffer.split_off(3); + + let array = split.into_array(None, ArrowType::Utf8); + let strings = array.as_any().downcast_ref::().unwrap(); + assert_eq!( + strings.iter().map(|x| x.unwrap()).collect::>(), + vec!["hello", "world", "cupcakes"] + ); + + buffer.try_push("test".as_bytes(), false).unwrap(); + let array = buffer.into_array(None, ArrowType::Utf8); + let strings = array.as_any().downcast_ref::().unwrap(); + assert_eq!( + strings.iter().map(|x| x.unwrap()).collect::>(), + vec!["a", "b", "c", "test"] + ); + } + + #[test] + fn test_offset_buffer_pad_nulls() { + let mut buffer = OffsetBuffer::::default(); + let values = ["a", "b", "c", "def", "gh"]; + for v in &values { + buffer.try_push(v.as_bytes(), false).unwrap() + } + + let valid = vec![ + true, false, false, true, false, true, false, true, true, false, false, + ]; + let valid_mask = Buffer::from_iter(valid.iter().cloned()); + + // Both trailing and leading nulls + buffer.pad_nulls(1, values.len() - 1, valid.len() - 1, valid_mask.as_slice()); + + let array = buffer.into_array(Some(valid_mask), ArrowType::Utf8); + let strings = array.as_any().downcast_ref::().unwrap(); + assert_eq!( + strings.iter().collect::>(), + vec![ + Some("a"), + None, + None, + Some("b"), + None, + Some("c"), + None, + Some("def"), + Some("gh"), + None, + None + ] + ); + } + + #[test] + fn test_utf8_validation() { + let valid_2_byte_utf8 = &[0b11001000, 0b10001000]; + std::str::from_utf8(valid_2_byte_utf8).unwrap(); + let valid_3_byte_utf8 = &[0b11101000, 0b10001000, 0b10001000]; + std::str::from_utf8(valid_3_byte_utf8).unwrap(); + let valid_4_byte_utf8 = &[0b11110010, 0b10101000, 0b10101001, 0b10100101]; + std::str::from_utf8(valid_4_byte_utf8).unwrap(); + + let mut buffer = OffsetBuffer::::default(); + buffer.try_push(valid_2_byte_utf8, true).unwrap(); + buffer.try_push(valid_3_byte_utf8, true).unwrap(); + buffer.try_push(valid_4_byte_utf8, true).unwrap(); + + // Cannot append string starting with incomplete codepoint + buffer.try_push(&valid_2_byte_utf8[1..], true).unwrap_err(); + buffer.try_push(&valid_3_byte_utf8[1..], true).unwrap_err(); + buffer.try_push(&valid_3_byte_utf8[2..], true).unwrap_err(); + buffer.try_push(&valid_4_byte_utf8[1..], true).unwrap_err(); + buffer.try_push(&valid_4_byte_utf8[2..], true).unwrap_err(); + buffer.try_push(&valid_4_byte_utf8[3..], true).unwrap_err(); + + // Can append data containing an incomplete codepoint + buffer.try_push(&[0b01111111, 0b10111111], true).unwrap(); + + assert_eq!(buffer.len(), 4); + assert_eq!(buffer.values.len(), 11); + + buffer.try_push(valid_3_byte_utf8, true).unwrap(); + + // Should fail due to incomplete codepoint + buffer.check_valid_utf8(0).unwrap_err(); + + // After broken codepoint -> success + buffer.check_valid_utf8(11).unwrap(); + + // Fails if run from middle of codepoint + buffer.check_valid_utf8(12).unwrap_err(); + } + + #[test] + fn test_pad_nulls_empty() { + let mut buffer = OffsetBuffer::::default(); + let valid_mask = Buffer::from_iter(std::iter::repeat(false).take(9)); + buffer.pad_nulls(0, 0, 9, valid_mask.as_slice()); + + let array = buffer.into_array(Some(valid_mask), ArrowType::Utf8); + let strings = array.as_any().downcast_ref::().unwrap(); + + assert_eq!(strings.len(), 9); + assert!(strings.iter().all(|x| x.is_none())) + } +} diff --git a/parquet/src/arrow/levels.rs b/parquet/src/arrow/levels.rs deleted file mode 100644 index c9b6052aeb87..000000000000 --- a/parquet/src/arrow/levels.rs +++ 
/dev/null @@ -1,1678 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Parquet definition and repetition levels -//! -//! Contains the algorithm for computing definition and repetition levels. -//! The algorithm works by tracking the slots of an array that should -//! ultimately be populated when writing to Parquet. -//! Parquet achieves nesting through definition levels and repetition levels \[1\]. -//! Definition levels specify how many optional fields in the part for the column -//! are defined. -//! Repetition levels specify at what repeated field (list) in the path a column -//! is defined. -//! -//! In a nested data structure such as `a.b.c`, one can see levels as defining -//! whether a record is defined at `a`, `a.b`, or `a.b.c`. -//! Optional fields are nullable fields, thus if all 3 fields -//! are nullable, the maximum definition could be = 3 if there are no lists. -//! -//! The algorithm in this module computes the necessary information to enable -//! the writer to keep track of which columns are at which levels, and to extract -//! the correct values at the correct slots from Arrow arrays. -//! -//! It works by walking a record batch's arrays, keeping track of what values -//! are non-null, their positions and computing what their levels are. -//! -//! \[1\] [parquet-format#nested-encoding](https://github.com/apache/parquet-format#nested-encoding) - -use arrow::array::{make_array, ArrayRef, MapArray, StructArray}; -use arrow::datatypes::{DataType, Field}; - -/// Keeps track of the level information per array that is needed to write an Arrow array to Parquet. -/// -/// When a nested schema is traversed, intermediate [LevelInfo] structs are created to track -/// the state of parent arrays. When a primitive Arrow array is encountered, a final [LevelInfo] -/// is created, and this is what is used to index into the array when writing data to Parquet. -#[derive(Debug, Eq, PartialEq, Clone)] -pub(crate) struct LevelInfo { - /// Array's definition levels - pub definition: Vec, - /// Array's optional repetition levels - pub repetition: Option>, - /// Array's offsets, 64-bit is used to accommodate large offset arrays - pub array_offsets: Vec, - // TODO: Convert to an Arrow Buffer after ARROW-10766 is merged. - /// Array's logical validity mask, whcih gets unpacked for list children. - /// If the parent of an array is null, all children are logically treated as - /// null. This mask keeps track of that. 
- /// - pub array_mask: Vec, - /// The maximum definition at this level, 0 at the record batch - pub max_definition: i16, - /// The type of array represented by this level info - pub level_type: LevelType, - /// The offset of the current level's array - pub offset: usize, - /// The length of the current level's array - pub length: usize, -} - -/// LevelType defines the type of level, and whether it is nullable or not -#[derive(Debug, Eq, PartialEq, Clone, Copy)] -pub(crate) enum LevelType { - Root, - List(bool), - Struct(bool), - Primitive(bool), -} - -impl LevelType { - #[inline] - const fn level_increment(&self) -> i16 { - match self { - LevelType::Root => 0, - // List repetition adds a constant 1 - LevelType::List(is_nullable) => 1 + *is_nullable as i16, - LevelType::Struct(is_nullable) | LevelType::Primitive(is_nullable) => { - *is_nullable as i16 - } - } - } -} - -impl LevelInfo { - /// Create a new [LevelInfo] by filling `length` slots, and setting an initial offset. - /// - /// This is a convenience function to populate the starting point of the traversal. - pub(crate) fn new(offset: usize, length: usize) -> Self { - Self { - // a batch has no definition level yet - definition: vec![0; length], - // a batch has no repetition as it is not a list - repetition: None, - // a batch has sequential offsets, should be num_rows + 1 - array_offsets: (0..=(length as i64)).collect(), - // all values at a batch-level are non-null - array_mask: vec![true; length], - max_definition: 0, - level_type: LevelType::Root, - offset, - length, - } - } - - /// Compute nested levels of the Arrow array, recursing into lists and structs. - /// - /// Returns a list of `LevelInfo`, where each level is for nested primitive arrays. - /// - /// The parent struct's nullness is tracked, as it determines whether the child - /// max_definition should be incremented. - /// The 'is_parent_struct' variable asks "is this field's parent a struct?". - /// * If we are starting at a [RecordBatch], this is `false`. - /// * If we are calculating a list's child, this is `false`. - /// * If we are calculating a struct (i.e. `field.data_type90 == Struct`), - /// this depends on whether the struct is a child of a struct. - /// * If we are calculating a field inside a [StructArray], this is 'true'. 
- pub(crate) fn calculate_array_levels( - &self, - array: &ArrayRef, - field: &Field, - ) -> Vec { - let (array_offsets, array_mask) = - Self::get_array_offsets_and_masks(array, self.offset, self.length); - match array.data_type() { - DataType::Null => vec![Self { - definition: self.definition.clone(), - repetition: self.repetition.clone(), - array_offsets, - array_mask, - max_definition: self.max_definition.max(1), - // Null type is always nullable - level_type: LevelType::Primitive(true), - offset: self.offset, - length: self.length, - }], - DataType::Boolean - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Float16 - | DataType::Float32 - | DataType::Float64 - | DataType::Utf8 - | DataType::LargeUtf8 - | DataType::Timestamp(_, _) - | DataType::Date32 - | DataType::Date64 - | DataType::Time32(_) - | DataType::Time64(_) - | DataType::Duration(_) - | DataType::Interval(_) - | DataType::Binary - | DataType::LargeBinary - | DataType::Decimal(_, _) - | DataType::FixedSizeBinary(_) => { - // we return a vector of 1 value to represent the primitive - vec![self.calculate_child_levels( - array_offsets, - array_mask, - LevelType::Primitive(field.is_nullable()), - )] - } - DataType::List(list_field) | DataType::LargeList(list_field) => { - let child_offset = array_offsets[0] as usize; - let child_len = *array_offsets.last().unwrap() as usize; - // Calculate the list level - let list_level = self.calculate_child_levels( - array_offsets, - array_mask, - LevelType::List(field.is_nullable()), - ); - - // Construct the child array of the list, and get its offset + mask - let array_data = array.data(); - let child_data = array_data.child_data().get(0).unwrap(); - let child_array = make_array(child_data.clone()); - let (child_offsets, child_mask) = Self::get_array_offsets_and_masks( - &child_array, - child_offset, - child_len - child_offset, - ); - - match child_array.data_type() { - // TODO: The behaviour of a > is untested - DataType::Null => vec![list_level], - DataType::Boolean - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Float16 - | DataType::Float32 - | DataType::Float64 - | DataType::Timestamp(_, _) - | DataType::Date32 - | DataType::Date64 - | DataType::Time32(_) - | DataType::Time64(_) - | DataType::Duration(_) - | DataType::Interval(_) - | DataType::Binary - | DataType::LargeBinary - | DataType::Utf8 - | DataType::LargeUtf8 - | DataType::Dictionary(_, _) - | DataType::Decimal(_, _) - | DataType::FixedSizeBinary(_) => { - vec![list_level.calculate_child_levels( - child_offsets, - child_mask, - LevelType::Primitive(list_field.is_nullable()), - )] - } - DataType::List(_) - | DataType::LargeList(_) - | DataType::Struct(_) - | DataType::Map(_, _) => { - list_level.calculate_array_levels(&child_array, list_field) - } - DataType::FixedSizeList(_, _) => unimplemented!(), - DataType::Union(_) => unimplemented!(), - } - } - DataType::Map(map_field, _) => { - // Calculate the map level - let map_level = self.calculate_child_levels( - array_offsets, - array_mask, - // A map is treated like a list as it has repetition - LevelType::List(field.is_nullable()), - ); - - let map_array = array.as_any().downcast_ref::().unwrap(); - - let key_array = map_array.keys(); - let value_array = map_array.values(); - - if let DataType::Struct(fields) = 
map_field.data_type() { - let key_field = &fields[0]; - let value_field = &fields[1]; - - let mut map_levels = vec![]; - - // Get key levels - let mut key_levels = - map_level.calculate_array_levels(&key_array, key_field); - map_levels.append(&mut key_levels); - - let mut value_levels = - map_level.calculate_array_levels(&value_array, value_field); - map_levels.append(&mut value_levels); - - map_levels - } else { - panic!( - "Map field should be a struct, found {:?}", - map_field.data_type() - ); - } - } - DataType::FixedSizeList(_, _) => unimplemented!(), - DataType::Struct(struct_fields) => { - let struct_array: &StructArray = array - .as_any() - .downcast_ref::() - .expect("Unable to get struct array"); - let struct_level = self.calculate_child_levels( - array_offsets, - array_mask, - LevelType::Struct(field.is_nullable()), - ); - let mut struct_levels = vec![]; - struct_array - .columns() - .into_iter() - .zip(struct_fields) - .for_each(|(child_array, child_field)| { - let mut levels = - struct_level.calculate_array_levels(child_array, child_field); - struct_levels.append(&mut levels); - }); - struct_levels - } - DataType::Union(_) => unimplemented!(), - DataType::Dictionary(_, _) => { - // Need to check for these cases not implemented in C++: - // - "Writing DictionaryArray with nested dictionary type not yet supported" - // - "Writing DictionaryArray with null encoded in dictionary type not yet supported" - // vec![self.get_primitive_def_levels(array, field, array_mask)] - vec![self.calculate_child_levels( - array_offsets, - array_mask, - LevelType::Primitive(field.is_nullable()), - )] - } - } - } - - /// Calculate child/leaf array levels. - /// - /// The algorithm works by incrementing definitions of array values based on whether: - /// - a value is optional or required (is_nullable) - /// - a list value is repeated + optional or required (is_list) - /// - /// A record batch always starts at a populated definition = level 0. - /// When a batch only has a primitive, i.e. `>, column `a` - /// can only have a maximum level of 1 if it is not null. - /// If it is not null, we increment by 1, such that the null slots will = level 1. - /// The above applies to types that have no repetition (anything not a list or map). - /// - /// If a batch has lists, then we increment by up to 2 levels: - /// - 1 level for the list (repeated) - /// - 1 level if the list itself is nullable (optional) - /// - /// A list's child then gets incremented using the above rules. - /// - /// *Exceptions* - /// - /// There are 2 exceptions from the above rules: - /// - /// 1. When at the root of the schema: We always increment the - /// level regardless of whether the child is nullable or not. If we do not do - /// this, we could have a non-nullable array having a definition of 0. - /// - /// 2. List parent, non-list child: We always increment the level in this case, - /// regardless of whether the child is nullable or not. - /// - /// *Examples* - /// - /// A batch with only a primitive that's non-nullable. ``: - /// * We don't increment the definition level as the array is not optional. - /// * This would leave us with a definition of 0, so the first exception applies. - /// * The definition level becomes 1. - /// - /// A batch with only a primitive that's nullable. ``: - /// * The definition level becomes 1, as we increment it once. - /// - /// A batch with a single non-nullable list (both list and child not null): - /// * We calculate the level twice, for the list, and for the child. 
- /// * At the list, the level becomes 1, where 0 indicates that the list is - /// empty, and 1 says it's not (determined through offsets). - /// * At the primitive level, the second exception applies. The level becomes 2. - fn calculate_child_levels( - &self, - // we use 64-bit offsets to also accommodate large arrays - array_offsets: Vec, - array_mask: Vec, - level_type: LevelType, - ) -> Self { - let min_len = *(array_offsets.last().unwrap()) as usize; - let mut definition = Vec::with_capacity(min_len); - let mut repetition = Vec::with_capacity(min_len); - let mut merged_array_mask = Vec::with_capacity(min_len); - - let max_definition = match (self.level_type, level_type) { - // Handle the illegal cases - (_, LevelType::Root) => { - unreachable!("Cannot have a root as a child") - } - (LevelType::Primitive(_), _) => { - unreachable!("Cannot have a primitive parent for any type") - } - // The general case - (_, _) => self.max_definition + level_type.level_increment(), - }; - - match (self.level_type, level_type) { - (LevelType::List(_), LevelType::List(is_nullable)) => { - // Parent is a list or descendant of a list, and child is a list - let reps = self.repetition.clone().unwrap(); - - // List is null, and not empty - let l1 = max_definition - is_nullable as i16; - // List is not null, but is empty - let l2 = max_definition - 1; - // List is not null, and not empty - let l3 = max_definition; - - let mut nulls_seen = 0; - - self.array_offsets.windows(2).for_each(|w| { - let start = w[0] as usize; - let end = w[1] as usize; - let parent_len = end - start; - - if parent_len == 0 { - // If the parent length is 0, there won't be a slot for the child - let index = start + nulls_seen - self.offset; - definition.push(self.definition[index]); - repetition.push(0); - merged_array_mask.push(self.array_mask[index]); - nulls_seen += 1; - } else { - (start..end).for_each(|parent_index| { - let index = parent_index + nulls_seen - self.offset; - let parent_index = parent_index - self.offset; - - // parent is either defined at this level, or earlier - let parent_def = self.definition[index]; - let parent_rep = reps[index]; - let parent_mask = self.array_mask[index]; - - // valid parent, index into children - let child_start = array_offsets[parent_index] as usize; - let child_end = array_offsets[parent_index + 1] as usize; - let child_len = child_end - child_start; - let child_mask = array_mask[parent_index]; - let merged_mask = parent_mask && child_mask; - - if child_len == 0 { - // Empty slot, i.e. 
{"parent": {"child": [] } } - // Nullness takes priority over emptiness - definition.push(if child_mask { l2 } else { l1 }); - repetition.push(parent_rep); - merged_array_mask.push(merged_mask); - } else { - (child_start..child_end).for_each(|child_index| { - let rep = match ( - parent_index == start, - child_index == child_start, - ) { - (true, true) => parent_rep, - (true, false) => parent_rep + 2, - (false, true) => parent_rep, - (false, false) => parent_rep + 1, - }; - - definition.push(if !parent_mask { - parent_def - } else if child_mask { - l3 - } else { - l1 - }); - repetition.push(rep); - merged_array_mask.push(merged_mask); - }); - } - }); - } - }); - - debug_assert_eq!(definition.len(), merged_array_mask.len()); - - let offset = *array_offsets.first().unwrap() as usize; - let length = *array_offsets.last().unwrap() as usize - offset; - - Self { - definition, - repetition: Some(repetition), - array_offsets, - array_mask: merged_array_mask, - max_definition, - level_type, - offset: offset + self.offset, - length, - } - } - (LevelType::List(_), _) => { - // List and primitive (or struct). - // The list can have more values than the primitive, indicating that there - // are slots where the list is empty. We use a counter to track this behaviour. - let mut nulls_seen = 0; - - // let child_max_definition = list_max_definition + is_nullable as i16; - // child values are a function of parent list offsets - let reps = self.repetition.as_deref().unwrap(); - self.array_offsets.windows(2).for_each(|w| { - let start = w[0] as usize; - let end = w[1] as usize; - let parent_len = end - start; - - if parent_len == 0 { - let index = start + nulls_seen - self.offset; - definition.push(self.definition[index]); - repetition.push(reps[index]); - merged_array_mask.push(self.array_mask[index]); - nulls_seen += 1; - } else { - // iterate through the array, adjusting child definitions for nulls - (start..end).for_each(|child_index| { - let index = child_index + nulls_seen - self.offset; - let child_mask = array_mask[child_index - self.offset]; - let parent_mask = self.array_mask[index]; - let parent_def = self.definition[index]; - - if !parent_mask || parent_def < self.max_definition { - definition.push(parent_def); - repetition.push(reps[index]); - merged_array_mask.push(parent_mask); - } else { - definition.push(max_definition - !child_mask as i16); - repetition.push(reps[index]); - merged_array_mask.push(child_mask); - } - }); - } - }); - - debug_assert_eq!(definition.len(), merged_array_mask.len()); - - let offset = *array_offsets.first().unwrap() as usize; - let length = *array_offsets.last().unwrap() as usize - offset; - - Self { - definition, - repetition: Some(repetition), - array_offsets: self.array_offsets.clone(), - array_mask: merged_array_mask, - max_definition, - level_type, - offset: offset + self.offset, - length, - } - } - (_, LevelType::List(is_nullable)) => { - // Encountering a list for the first time. 
- // Calculate the 2 list hierarchy definitions in advance - - // List is null, and not empty - let l1 = max_definition - 1 - is_nullable as i16; - // List is not null, but is empty - let l2 = max_definition - 1; - // List is not null, and not empty - let l3 = max_definition; - - self.definition - .iter() - .enumerate() - .for_each(|(parent_index, def)| { - let child_from = array_offsets[parent_index]; - let child_to = array_offsets[parent_index + 1]; - let child_len = child_to - child_from; - let child_mask = array_mask[parent_index]; - let parent_mask = self.array_mask[parent_index]; - - match (parent_mask, child_len) { - (true, 0) => { - // Empty slot, i.e. {"parent": {"child": [] } } - // Nullness takes priority over emptiness - definition.push(if child_mask { l2 } else { l1 }); - repetition.push(0); - merged_array_mask.push(child_mask); - } - (false, 0) => { - // Inherit the parent definition as parent was null - definition.push(*def); - repetition.push(0); - merged_array_mask.push(child_mask); - } - (true, _) => { - (child_from..child_to).for_each(|child_index| { - // l1 and l3 make sense as list is not empty, - // but we reflect that it's either null or not - definition.push(if child_mask { l3 } else { l1 }); - // Mark the first child slot as 0, and the next as 1 - repetition.push(if child_index == child_from { - 0 - } else { - 1 - }); - merged_array_mask.push(child_mask); - }); - } - (false, _) => { - (child_from..child_to).for_each(|child_index| { - // Inherit the parent definition as parent was null - definition.push(*def); - // mark the first child slot as 0, and the next as 1 - repetition.push(if child_index == child_from { - 0 - } else { - 1 - }); - merged_array_mask.push(false); - }); - } - } - }); - - debug_assert_eq!(definition.len(), merged_array_mask.len()); - - let offset = *array_offsets.first().unwrap() as usize; - let length = *array_offsets.last().unwrap() as usize - offset; - - Self { - definition, - repetition: Some(repetition), - array_offsets, - array_mask: merged_array_mask, - max_definition, - level_type, - offset, - length, - } - } - (_, _) => { - self.definition - .iter() - .zip(array_mask.into_iter().zip(&self.array_mask)) - .for_each(|(current_def, (child_mask, parent_mask))| { - merged_array_mask.push(*parent_mask && child_mask); - match (parent_mask, child_mask) { - (true, true) => { - definition.push(max_definition); - } - (true, false) => { - // The child is only legally null if its array is nullable. 
- // Thus parent's max_definition is lower - definition.push(if *current_def <= self.max_definition { - *current_def - } else { - self.max_definition - }); - } - // if the parent was false, retain its definitions - (false, _) => { - definition.push(*current_def); - } - } - }); - - debug_assert_eq!(definition.len(), merged_array_mask.len()); - - Self { - definition, - repetition: self.repetition.clone(), // it's None - array_offsets, - array_mask: merged_array_mask, - max_definition, - level_type, - // Inherit parent offset and length - offset: self.offset, - length: self.length, - } - } - } - } - - /// Get the offsets of an array as 64-bit values, and validity masks as booleans - /// - Primitive, binary and struct arrays' offsets will be a sequence, masks obtained - /// from validity bitmap - /// - List array offsets will be the value offsets, masks are computed from offsets - fn get_array_offsets_and_masks( - array: &ArrayRef, - offset: usize, - len: usize, - ) -> (Vec, Vec) { - match array.data_type() { - DataType::Null - | DataType::Boolean - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Float16 - | DataType::Float32 - | DataType::Float64 - | DataType::Timestamp(_, _) - | DataType::Date32 - | DataType::Date64 - | DataType::Time32(_) - | DataType::Time64(_) - | DataType::Duration(_) - | DataType::Interval(_) - | DataType::Binary - | DataType::LargeBinary - | DataType::Utf8 - | DataType::LargeUtf8 - | DataType::Struct(_) - | DataType::Dictionary(_, _) - | DataType::Decimal(_, _) => { - let array_mask = match array.data().null_buffer() { - Some(buf) => get_bool_array_slice(buf, array.offset() + offset, len), - None => vec![true; len], - }; - ((0..=(len as i64)).collect(), array_mask) - } - DataType::List(_) | DataType::Map(_, _) => { - let data = array.data(); - let offsets = unsafe { data.buffers()[0].typed_data::() }; - let offsets = offsets - .to_vec() - .into_iter() - .skip(offset) - .take(len + 1) - .map(|v| v as i64) - .collect::>(); - let array_mask = match array.data().null_buffer() { - Some(buf) => get_bool_array_slice(buf, array.offset() + offset, len), - None => vec![true; len], - }; - (offsets, array_mask) - } - DataType::LargeList(_) => { - let offsets = unsafe { array.data().buffers()[0].typed_data::() } - .iter() - .skip(offset) - .take(len + 1) - .copied() - .collect(); - let array_mask = match array.data().null_buffer() { - Some(buf) => get_bool_array_slice(buf, array.offset() + offset, len), - None => vec![true; len], - }; - (offsets, array_mask) - } - DataType::FixedSizeBinary(value_len) => { - let array_mask = match array.data().null_buffer() { - Some(buf) => get_bool_array_slice(buf, array.offset() + offset, len), - None => vec![true; len], - }; - let value_len = *value_len as i64; - ( - (0..=(len as i64)).map(|v| v * value_len).collect(), - array_mask, - ) - } - DataType::FixedSizeList(_, _) | DataType::Union(_) => { - unimplemented!("Getting offsets not yet implemented") - } - } - } - - /// Given a level's information, calculate the offsets required to index an array correctly. 
- pub(crate) fn filter_array_indices(&self) -> Vec { - // happy path if not dealing with lists - let is_nullable = match self.level_type { - LevelType::Primitive(is_nullable) => is_nullable, - _ => panic!( - "Cannot filter indices on a non-primitive array, found {:?}", - self.level_type - ), - }; - if self.repetition.is_none() { - return self - .definition - .iter() - .enumerate() - .filter_map(|(i, def)| { - if *def == self.max_definition { - Some(i) - } else { - None - } - }) - .collect(); - } - let mut filtered = vec![]; - // remove slots that are false from definition_mask - let mut index = 0; - self.definition.iter().for_each(|def| { - if *def == self.max_definition { - filtered.push(index); - } - if *def >= self.max_definition - is_nullable as i16 { - index += 1; - } - }); - filtered - } -} - -/// Convert an Arrow buffer to a boolean array slice -/// TODO: this was created for buffers, so might not work for bool array, might be slow too -#[inline] -fn get_bool_array_slice( - buffer: &arrow::buffer::Buffer, - offset: usize, - len: usize, -) -> Vec { - let data = buffer.as_slice(); - (offset..(len + offset)) - .map(|i| arrow::util::bit_util::get_bit(data, i)) - .collect() -} - -#[cfg(test)] -mod tests { - use super::*; - - use std::sync::Arc; - - use arrow::array::*; - use arrow::buffer::Buffer; - use arrow::datatypes::{Schema, ToByteSlice}; - use arrow::record_batch::RecordBatch; - - #[test] - fn test_calculate_array_levels_twitter_example() { - // based on the example at https://blog.twitter.com/engineering/en_us/a/2013/dremel-made-simple-with-parquet.html - // [[a, b, c], [d, e, f, g]], [[h], [i,j]] - let parent_levels = LevelInfo { - definition: vec![0, 0], - repetition: None, - array_offsets: vec![0, 1, 2], // 2 records, root offsets always sequential - array_mask: vec![true, true], // both lists defined - max_definition: 0, - level_type: LevelType::Root, - offset: 0, - length: 2, - }; - // offset into array, each level1 has 2 values - let array_offsets = vec![0, 2, 4]; - let array_mask = vec![true, true]; - - // calculate level1 levels - let levels = parent_levels.calculate_child_levels( - array_offsets.clone(), - array_mask, - LevelType::List(false), - ); - // - let expected_levels = LevelInfo { - definition: vec![1, 1, 1, 1], - repetition: Some(vec![0, 1, 0, 1]), - array_offsets, - array_mask: vec![true, true, true, true], - max_definition: 1, - level_type: LevelType::List(false), - offset: 0, - length: 4, - }; - // the separate asserts make it easier to see what's failing - assert_eq!(&levels.definition, &expected_levels.definition); - assert_eq!(&levels.repetition, &expected_levels.repetition); - assert_eq!(&levels.array_mask, &expected_levels.array_mask); - assert_eq!(&levels.array_offsets, &expected_levels.array_offsets); - assert_eq!(&levels.max_definition, &expected_levels.max_definition); - assert_eq!(&levels.level_type, &expected_levels.level_type); - // this assert is to help if there are more variables added to the struct - assert_eq!(&levels, &expected_levels); - - // level2 - let parent_levels = levels; - let array_offsets = vec![0, 3, 7, 8, 10]; - let array_mask = vec![true, true, true, true]; - let levels = parent_levels.calculate_child_levels( - array_offsets.clone(), - array_mask, - LevelType::List(false), - ); - let expected_levels = LevelInfo { - definition: vec![2, 2, 2, 2, 2, 2, 2, 2, 2, 2], - repetition: Some(vec![0, 2, 2, 1, 2, 2, 2, 0, 1, 2]), - array_offsets, - array_mask: vec![true; 10], - max_definition: 2, - level_type: LevelType::List(false), - 
offset: 0, - length: 10, - }; - assert_eq!(&levels.definition, &expected_levels.definition); - assert_eq!(&levels.repetition, &expected_levels.repetition); - assert_eq!(&levels.array_mask, &expected_levels.array_mask); - assert_eq!(&levels.max_definition, &expected_levels.max_definition); - assert_eq!(&levels.array_offsets, &expected_levels.array_offsets); - assert_eq!(&levels.level_type, &expected_levels.level_type); - assert_eq!(&levels, &expected_levels); - } - - #[test] - fn test_calculate_one_level_1() { - // This test calculates the levels for a non-null primitive array - let parent_levels = LevelInfo { - definition: vec![0; 10], - repetition: None, - array_offsets: (0..=10).collect(), - array_mask: vec![true; 10], - max_definition: 0, - level_type: LevelType::Root, - offset: 0, - length: 10, - }; - let array_offsets: Vec = (0..=10).collect(); - let array_mask = vec![true; 10]; - - let levels = parent_levels.calculate_child_levels( - array_offsets.clone(), - array_mask.clone(), - LevelType::Primitive(false), - ); - let expected_levels = LevelInfo { - // As it is non-null, definitions can be omitted - definition: vec![0; 10], - repetition: None, - array_offsets, - array_mask, - max_definition: 0, - level_type: LevelType::Primitive(false), - offset: 0, - length: 10, - }; - assert_eq!(&levels, &expected_levels); - } - - #[test] - fn test_calculate_one_level_2() { - // This test calculates the levels for a non-null primitive array - let parent_levels = LevelInfo { - definition: vec![0; 5], - repetition: None, - array_offsets: (0..=5).collect(), - array_mask: vec![true, true, true, true, true], - max_definition: 0, - level_type: LevelType::Root, - offset: 0, - length: 5, - }; - let array_offsets: Vec = (0..=5).collect(); - let array_mask = vec![true, false, true, true, false]; - - let levels = parent_levels.calculate_child_levels( - array_offsets.clone(), - array_mask.clone(), - LevelType::Primitive(true), - ); - let expected_levels = LevelInfo { - definition: vec![1, 0, 1, 1, 0], - repetition: None, - array_offsets, - array_mask, - max_definition: 1, - level_type: LevelType::Primitive(true), - offset: 0, - length: 5, - }; - assert_eq!(&levels, &expected_levels); - } - - #[test] - fn test_calculate_array_levels_1() { - // if all array values are defined (e.g. batch>) - // [[0], [1], [2], [3], [4]] - let parent_levels = LevelInfo { - definition: vec![0; 5], - repetition: None, - array_offsets: vec![0, 1, 2, 3, 4, 5], - array_mask: vec![true, true, true, true, true], - max_definition: 0, - level_type: LevelType::Root, - offset: 0, - length: 5, - }; - let array_offsets = vec![0, 2, 2, 4, 8, 11]; - let array_mask = vec![true, false, true, true, true]; - - let levels = parent_levels.calculate_child_levels( - array_offsets.clone(), - array_mask, - LevelType::List(true), - ); - // array: [[0, 0], _1_, [2, 2], [3, 3, 3, 3], [4, 4, 4]] - // all values are defined as we do not have nulls on the root (batch) - // repetition: - // 0: 0, 1 - // 1: - // 2: 0, 1 - // 3: 0, 1, 1, 1 - // 4: 0, 1, 1 - let expected_levels = LevelInfo { - // The levels are normally 2 because we: - // - Calculate the level at the list - // - Calculate the level at the list's child - // We do not do this in these tests, thus the levels are 1 less. 
- definition: vec![2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2], - repetition: Some(vec![0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1]), - array_offsets, - array_mask: vec![ - true, true, false, true, true, true, true, true, true, true, true, true, - ], - max_definition: 2, - level_type: LevelType::List(true), - offset: 0, - length: 11, // the child has 11 slots - }; - assert_eq!(&levels.definition, &expected_levels.definition); - assert_eq!(&levels.repetition, &expected_levels.repetition); - assert_eq!(&levels.array_offsets, &expected_levels.array_offsets); - assert_eq!(&levels.max_definition, &expected_levels.max_definition); - assert_eq!(&levels.level_type, &expected_levels.level_type); - assert_eq!(&levels, &expected_levels); - } - - #[test] - fn test_calculate_array_levels_2() { - // If some values are null - // - // This emulates an array in the form: > - // with values: - // - 0: [0, 1], but is null because of the struct - // - 1: [] - // - 2: [2, 3], but is null because of the struct - // - 3: [4, 5, 6, 7] - // - 4: [8, 9, 10] - // - // If the first values of a list are null due to a parent, we have to still account for them - // while indexing, because they would affect the way the child is indexed - // i.e. in the above example, we have to know that [0, 1] has to be skipped - let parent_levels = LevelInfo { - definition: vec![0, 1, 0, 1, 1], - repetition: None, - array_offsets: vec![0, 1, 2, 3, 4, 5], - array_mask: vec![false, true, false, true, true], - max_definition: 1, - level_type: LevelType::Struct(true), - offset: 0, - length: 5, - }; - let array_offsets = vec![0, 2, 2, 4, 8, 11]; - let array_mask = vec![true, false, true, true, true]; - - let levels = parent_levels.calculate_child_levels( - array_offsets.clone(), - array_mask, - LevelType::List(true), - ); - let expected_levels = LevelInfo { - // 0 1 [2] are 0 (not defined at level 1) - // [2] is 1, but has 0 slots so is not populated (defined at level 1 only) - // 2 3 [4] are 0 - // 4 5 6 7 [8] are 1 (defined at level 1 only) - // 8 9 10 [11] are 2 (defined at both levels) - definition: vec![0, 0, 1, 0, 0, 3, 3, 3, 3, 3, 3, 3], - repetition: Some(vec![0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1]), - array_offsets, - array_mask: vec![ - false, false, false, false, false, true, true, true, true, true, true, - true, - ], - max_definition: 3, - level_type: LevelType::List(true), - offset: 0, - length: 11, - }; - assert_eq!(&levels.definition, &expected_levels.definition); - assert_eq!(&levels.repetition, &expected_levels.repetition); - assert_eq!(&levels.array_offsets, &expected_levels.array_offsets); - assert_eq!(&levels.max_definition, &expected_levels.max_definition); - assert_eq!(&levels.level_type, &expected_levels.level_type); - assert_eq!(&levels, &expected_levels); - - // nested lists (using previous test) - let nested_parent_levels = levels; - let array_offsets = vec![0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22]; - let array_mask = vec![ - true, true, true, true, true, true, true, true, true, true, true, - ]; - let levels = nested_parent_levels.calculate_child_levels( - array_offsets.clone(), - array_mask, - LevelType::List(true), - ); - let expected_levels = LevelInfo { - // (def: 0) 0 1 [2] are 0 (take parent) - // (def: 0) 2 3 [4] are 0 (take parent) - // (def: 0) 4 5 [6] are 0 (take parent) - // (def: 0) 6 7 [8] are 0 (take parent) - // (def: 1) 8 9 [10] are 1 (take parent) - // (def: 1) 10 11 [12] are 1 (take parent) - // (def: 1) 12 23 [14] are 1 (take parent) - // (def: 1) 14 15 [16] are 1 (take parent) - // (def: 2) 16 17 [18] are 2 
(defined at all levels) - // (def: 2) 18 19 [20] are 2 (defined at all levels) - // (def: 2) 20 21 [22] are 2 (defined at all levels) - // - // 0 1 [2] are 0 (not defined at level 1) - // [2] is 1, but has 0 slots so is not populated (defined at level 1 only) - // 2 3 [4] are 0 - // 4 5 6 7 [8] are 1 (defined at level 1 only) - // 8 9 10 [11] are 2 (defined at both levels) - // - // 0: [[100, 101], [102, 103]] - // 1: [] - // 2: [[104, 105], [106, 107]] - // 3: [[108, 109], [110, 111], [112, 113], [114, 115]] - // 4: [[116, 117], [118, 119], [120, 121]] - definition: vec![ - 0, 0, 0, 0, 1, 0, 0, 0, 0, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, - ], - repetition: Some(vec![ - 0, 2, 1, 2, 0, 0, 2, 1, 2, 0, 2, 1, 2, 1, 2, 1, 2, 0, 2, 1, 2, 1, 2, - ]), - array_offsets, - array_mask: vec![ - false, false, false, false, false, false, false, false, false, true, - true, true, true, true, true, true, true, true, true, true, true, true, - true, - ], - max_definition: 5, - level_type: LevelType::List(true), - offset: 0, - length: 22, - }; - assert_eq!(&levels.definition, &expected_levels.definition); - assert_eq!(&levels.repetition, &expected_levels.repetition); - assert_eq!(&levels.array_offsets, &expected_levels.array_offsets); - assert_eq!(&levels.array_mask, &expected_levels.array_mask); - assert_eq!(&levels.max_definition, &expected_levels.max_definition); - assert_eq!(&levels.level_type, &expected_levels.level_type); - assert_eq!(&levels, &expected_levels); - } - - #[test] - fn test_calculate_array_levels_nested_list() { - // if all array values are defined (e.g. batch>) - // The array at this level looks like: - // 0: [a] - // 1: [a] - // 2: [a] - // 3: [a] - let parent_levels = LevelInfo { - definition: vec![1, 1, 1, 1], - repetition: None, - array_offsets: vec![0, 1, 2, 3, 4], - array_mask: vec![true, true, true, true], - max_definition: 1, - level_type: LevelType::Struct(true), - offset: 0, - length: 4, - }; - // 0: null ([], but mask is false, so it's not just an empty list) - // 1: [1, 2, 3] - // 2: [4, 5] - // 3: [6, 7] - let array_offsets = vec![0, 1, 4, 6, 8]; - let array_mask = vec![false, true, true, true]; - - let levels = parent_levels.calculate_child_levels( - array_offsets.clone(), - array_mask, - LevelType::List(true), - ); - // 0: [null], level 1 is defined, but not 2 - // 1: [1, 2, 3] - // 2: [4, 5] - // 3: [6, 7] - let expected_levels = LevelInfo { - definition: vec![1, 3, 3, 3, 3, 3, 3, 3], - repetition: Some(vec![0, 0, 1, 1, 0, 1, 0, 1]), - array_offsets, - array_mask: vec![false, true, true, true, true, true, true, true], - max_definition: 3, - level_type: LevelType::List(true), - offset: 0, - length: 8, - }; - assert_eq!(&levels.definition, &expected_levels.definition); - assert_eq!(&levels.repetition, &expected_levels.repetition); - assert_eq!(&levels.array_offsets, &expected_levels.array_offsets); - assert_eq!(&levels.max_definition, &expected_levels.max_definition); - assert_eq!(&levels.level_type, &expected_levels.level_type); - assert_eq!(&levels, &expected_levels); - - // nested lists (using previous test) - let nested_parent_levels = levels; - // 0: [null] (was a populated null slot at the parent) - // 1: [201] - // 2: [202, 203] - // 3: null ([]) - // 4: [204, 205, 206] - // 5: [207, 208, 209, 210] - // 6: [] (tests a non-null empty list slot) - // 7: [211, 212, 213, 214, 215] - let array_offsets = vec![0, 1, 2, 4, 4, 7, 11, 11, 16]; - // logically, the fist slot of the mask is false - let array_mask = vec![true, true, true, false, true, true, true, true]; - let 
levels = nested_parent_levels.calculate_child_levels( - array_offsets.clone(), - array_mask, - LevelType::List(true), - ); - // We have 7 array values, and at least 15 primitives (from array_offsets) - // 0: (-)[null], parent was null, no value populated here - // 1: (0)[201], (1)[202, 203], (2)[[null]] - // 2: (3)[204, 205, 206], (4)[207, 208, 209, 210] - // 3: (5)[[]], (6)[211, 212, 213, 214, 215] - // - // In a JSON syntax with the schema: >>>, this translates into: - // 0: {"struct": [ null ]} - // 1: {"struct": [ [201], [202, 203], [] ]} - // 2: {"struct": [ [204, 205, 206], [207, 208, 209, 210] ]} - // 3: {"struct": [ [], [211, 212, 213, 214, 215] ]} - let expected_levels = LevelInfo { - definition: vec![1, 5, 5, 5, 4, 5, 5, 5, 5, 5, 5, 5, 4, 5, 5, 5, 5, 5], - repetition: Some(vec![0, 0, 1, 2, 1, 0, 2, 2, 1, 2, 2, 2, 0, 1, 2, 2, 2, 2]), - array_mask: vec![ - false, true, true, true, false, true, true, true, true, true, true, true, - true, true, true, true, true, true, - ], - array_offsets, - max_definition: 5, - level_type: LevelType::List(true), - offset: 0, - length: 16, - }; - assert_eq!(&levels.definition, &expected_levels.definition); - assert_eq!(&levels.repetition, &expected_levels.repetition); - assert_eq!(&levels.array_offsets, &expected_levels.array_offsets); - assert_eq!(&levels.array_mask, &expected_levels.array_mask); - assert_eq!(&levels.max_definition, &expected_levels.max_definition); - assert_eq!(&levels.level_type, &expected_levels.level_type); - assert_eq!(&levels, &expected_levels); - } - - #[test] - fn test_calculate_nested_struct_levels() { - // tests a > - // array: - // - {a: {b: {c: 1}}} - // - {a: {b: {c: null}}} - // - {a: {b: {c: 3}}} - // - {a: {b: null}} - // - {a: null}} - // - {a: {b: {c: 6}}} - let a_levels = LevelInfo { - definition: vec![1, 1, 1, 1, 0, 1], - repetition: None, - array_offsets: (0..=6).collect(), - array_mask: vec![true, true, true, true, false, true], - max_definition: 1, - level_type: LevelType::Struct(true), - offset: 0, - length: 6, - }; - // b's offset and mask - let b_offsets: Vec = (0..=6).collect(); - let b_mask = vec![true, true, true, false, false, true]; - // b's expected levels - let b_expected_levels = LevelInfo { - definition: vec![2, 2, 2, 1, 0, 2], - repetition: None, - array_offsets: (0..=6).collect(), - array_mask: vec![true, true, true, false, false, true], - max_definition: 2, - level_type: LevelType::Struct(true), - offset: 0, - length: 6, - }; - let b_levels = a_levels.calculate_child_levels( - b_offsets.clone(), - b_mask, - LevelType::Struct(true), - ); - assert_eq!(&b_expected_levels, &b_levels); - - // c's offset and mask - let c_offsets = b_offsets; - let c_mask = vec![true, false, true, false, false, true]; - // c's expected levels - let c_expected_levels = LevelInfo { - definition: vec![3, 2, 3, 1, 0, 3], - repetition: None, - array_offsets: c_offsets.clone(), - array_mask: vec![true, false, true, false, false, true], - max_definition: 3, - level_type: LevelType::Struct(true), - offset: 0, - length: 6, - }; - let c_levels = - b_levels.calculate_child_levels(c_offsets, c_mask, LevelType::Struct(true)); - assert_eq!(&c_expected_levels, &c_levels); - } - - #[test] - fn list_single_column() { - // this tests the level generation from the arrow_writer equivalent test - - let a_values = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); - let a_value_offsets = - arrow::buffer::Buffer::from(&[0, 1, 3, 3, 6, 10].to_byte_slice()); - let a_list_type = - DataType::List(Box::new(Field::new("item", DataType::Int32, 
true))); - let a_list_data = ArrayData::builder(a_list_type.clone()) - .len(5) - .add_buffer(a_value_offsets) - .null_bit_buffer(Buffer::from(vec![0b00011011])) - .add_child_data(a_values.data().clone()) - .build() - .unwrap(); - - assert_eq!(a_list_data.null_count(), 1); - - let a = ListArray::from(a_list_data); - let values = Arc::new(a); - - let schema = Schema::new(vec![Field::new("item", a_list_type, true)]); - - let batch = RecordBatch::try_new(Arc::new(schema), vec![values]).unwrap(); - - let expected_batch_level = LevelInfo { - definition: vec![0; 2], - repetition: None, - array_offsets: (0..=2).collect(), - array_mask: vec![true, true], - max_definition: 0, - level_type: LevelType::Root, - offset: 2, - length: 2, - }; - - let batch_level = LevelInfo::new(2, 2); - assert_eq!(&batch_level, &expected_batch_level); - - // calculate the list's level - let mut levels = vec![]; - batch - .columns() - .iter() - .zip(batch.schema().fields()) - .for_each(|(array, field)| { - let mut array_levels = batch_level.calculate_array_levels(array, field); - levels.append(&mut array_levels); - }); - assert_eq!(levels.len(), 1); - - let list_level = levels.get(0).unwrap(); - - let expected_level = LevelInfo { - definition: vec![0, 3, 3, 3], - repetition: Some(vec![0, 0, 1, 1]), - array_offsets: vec![3, 3, 6], - array_mask: vec![false, true, true, true], - max_definition: 3, - level_type: LevelType::Primitive(true), - offset: 3, - length: 3, - }; - assert_eq!(&list_level.definition, &expected_level.definition); - assert_eq!(&list_level.repetition, &expected_level.repetition); - assert_eq!(&list_level.array_offsets, &expected_level.array_offsets); - assert_eq!(&list_level.array_mask, &expected_level.array_mask); - assert_eq!(&list_level.max_definition, &expected_level.max_definition); - assert_eq!(&list_level.level_type, &expected_level.level_type); - assert_eq!(list_level, &expected_level); - } - - #[test] - fn mixed_struct_list() { - // this tests the level generation from the equivalent arrow_writer_complex test - - // define schema - let struct_field_d = Field::new("d", DataType::Float64, true); - let struct_field_f = Field::new("f", DataType::Float32, true); - let struct_field_g = Field::new( - "g", - DataType::List(Box::new(Field::new("items", DataType::Int16, false))), - false, - ); - let struct_field_e = Field::new( - "e", - DataType::Struct(vec![struct_field_f.clone(), struct_field_g.clone()]), - true, - ); - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, true), - Field::new( - "c", - DataType::Struct(vec![struct_field_d.clone(), struct_field_e.clone()]), - true, // https://github.com/apache/arrow-rs/issues/245 - ), - ]); - - // create some data - let a = Int32Array::from(vec![1, 2, 3, 4, 5]); - let b = Int32Array::from(vec![Some(1), None, None, Some(4), Some(5)]); - let d = Float64Array::from(vec![None, None, None, Some(1.0), None]); - let f = Float32Array::from(vec![Some(0.0), None, Some(333.3), None, Some(5.25)]); - - let g_value = Int16Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); - - // Construct a buffer for value offsets, for the nested array: - // [[1], [2, 3], null, [4, 5, 6], [7, 8, 9, 10]] - let g_value_offsets = - arrow::buffer::Buffer::from(&[0, 1, 3, 3, 6, 10].to_byte_slice()); - - // Construct a list array from the above two - let g_list_data = ArrayData::builder(struct_field_g.data_type().clone()) - .len(5) - .add_buffer(g_value_offsets) - .add_child_data(g_value.data().clone()) - .build() - .unwrap(); - let g = 
ListArray::from(g_list_data); - - let e = StructArray::from(vec![ - (struct_field_f, Arc::new(f) as ArrayRef), - (struct_field_g, Arc::new(g) as ArrayRef), - ]); - - let c = StructArray::from(vec![ - (struct_field_d, Arc::new(d) as ArrayRef), - (struct_field_e, Arc::new(e) as ArrayRef), - ]); - - // build a record batch - let batch = RecordBatch::try_new( - Arc::new(schema), - vec![Arc::new(a), Arc::new(b), Arc::new(c)], - ) - .unwrap(); - - ////////////////////////////////////////////// - let expected_batch_level = LevelInfo { - definition: vec![0; 5], - repetition: None, - array_offsets: (0..=5).collect(), - array_mask: vec![true, true, true, true, true], - max_definition: 0, - level_type: LevelType::Root, - offset: 0, - length: 5, - }; - - let batch_level = LevelInfo::new(0, 5); - assert_eq!(&batch_level, &expected_batch_level); - - // calculate the list's level - let mut levels = vec![]; - batch - .columns() - .iter() - .zip(batch.schema().fields()) - .for_each(|(array, field)| { - let mut array_levels = batch_level.calculate_array_levels(array, field); - levels.append(&mut array_levels); - }); - assert_eq!(levels.len(), 5); - - // test "a" levels - let list_level = levels.get(0).unwrap(); - - let expected_level = LevelInfo { - definition: vec![0, 0, 0, 0, 0], - repetition: None, - array_offsets: vec![0, 1, 2, 3, 4, 5], - array_mask: vec![true, true, true, true, true], - max_definition: 0, - level_type: LevelType::Primitive(false), - offset: 0, - length: 5, - }; - assert_eq!(list_level, &expected_level); - - // test "b" levels - let list_level = levels.get(1).unwrap(); - - let expected_level = LevelInfo { - definition: vec![1, 0, 0, 1, 1], - repetition: None, - array_offsets: vec![0, 1, 2, 3, 4, 5], - array_mask: vec![true, false, false, true, true], - max_definition: 1, - level_type: LevelType::Primitive(true), - offset: 0, - length: 5, - }; - assert_eq!(list_level, &expected_level); - - // test "d" levels - let list_level = levels.get(2).unwrap(); - - let expected_level = LevelInfo { - definition: vec![1, 1, 1, 2, 1], - repetition: None, - array_offsets: vec![0, 1, 2, 3, 4, 5], - array_mask: vec![false, false, false, true, false], - max_definition: 2, - level_type: LevelType::Primitive(true), - offset: 0, - length: 5, - }; - assert_eq!(list_level, &expected_level); - - // test "f" levels - let list_level = levels.get(3).unwrap(); - - let expected_level = LevelInfo { - definition: vec![3, 2, 3, 2, 3], - repetition: None, - array_offsets: vec![0, 1, 2, 3, 4, 5], - array_mask: vec![true, false, true, false, true], - max_definition: 3, - level_type: LevelType::Primitive(true), - offset: 0, - length: 5, - }; - assert_eq!(list_level, &expected_level); - } - - #[test] - fn test_filter_array_indices() { - let level = LevelInfo { - definition: vec![3, 3, 3, 1, 3, 3, 3], - repetition: Some(vec![0, 1, 1, 0, 0, 1, 1]), - array_offsets: vec![0, 3, 3, 6], - array_mask: vec![true, true, true, false, true, true, true], - max_definition: 3, - level_type: LevelType::Primitive(true), - offset: 0, - length: 6, - }; - - let expected = vec![0, 1, 2, 3, 4, 5]; - let filter = level.filter_array_indices(); - assert_eq!(expected, filter); - } - - #[test] - fn test_null_vs_nonnull_struct() { - // define schema - let offset_field = Field::new("offset", DataType::Int32, true); - let schema = Schema::new(vec![Field::new( - "some_nested_object", - DataType::Struct(vec![offset_field.clone()]), - false, - )]); - - // create some data - let offset = Int32Array::from(vec![1, 2, 3, 4, 5]); - - let some_nested_object = 
- StructArray::from(vec![(offset_field, Arc::new(offset) as ArrayRef)]); - - // build a record batch - let batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(some_nested_object)]) - .unwrap(); - - let batch_level = LevelInfo::new(0, batch.num_rows()); - let struct_null_level = - batch_level.calculate_array_levels(batch.column(0), batch.schema().field(0)); - - // create second batch - // define schema - let offset_field = Field::new("offset", DataType::Int32, true); - let schema = Schema::new(vec![Field::new( - "some_nested_object", - DataType::Struct(vec![offset_field.clone()]), - true, - )]); - - // create some data - let offset = Int32Array::from(vec![1, 2, 3, 4, 5]); - - let some_nested_object = - StructArray::from(vec![(offset_field, Arc::new(offset) as ArrayRef)]); - - // build a record batch - let batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(some_nested_object)]) - .unwrap(); - - let batch_level = LevelInfo::new(0, batch.num_rows()); - let struct_non_null_level = - batch_level.calculate_array_levels(batch.column(0), batch.schema().field(0)); - - // The 2 levels should not be the same - if struct_non_null_level == struct_null_level { - panic!("Levels should not be equal, to reflect the difference in struct nullness"); - } - } - - #[test] - fn test_map_array() { - // Note: we are using the JSON Arrow reader for brevity - let json_content = r#" - {"stocks":{"long": "$AAA", "short": "$BBB"}} - {"stocks":{"long": null, "long": "$CCC", "short": null}} - {"stocks":{"hedged": "$YYY", "long": null, "short": "$D"}} - "#; - let entries_struct_type = DataType::Struct(vec![ - Field::new("key", DataType::Utf8, false), - Field::new("value", DataType::Utf8, true), - ]); - let stocks_field = Field::new( - "stocks", - DataType::Map( - Box::new(Field::new("entries", entries_struct_type, false)), - false, - ), - // not nullable, so the keys have max level = 1 - false, - ); - let schema = Arc::new(Schema::new(vec![stocks_field])); - let builder = arrow::json::ReaderBuilder::new() - .with_schema(schema) - .with_batch_size(64); - let mut reader = builder.build(std::io::Cursor::new(json_content)).unwrap(); - - let batch = reader.next().unwrap().unwrap(); - - let expected_batch_level = LevelInfo { - definition: vec![0; 3], - repetition: None, - array_offsets: (0..=3).collect(), - array_mask: vec![true, true, true], - max_definition: 0, - level_type: LevelType::Root, - offset: 0, - length: 3, - }; - - let batch_level = LevelInfo::new(0, 3); - assert_eq!(&batch_level, &expected_batch_level); - - // calculate the map's level - let mut levels = vec![]; - batch - .columns() - .iter() - .zip(batch.schema().fields()) - .for_each(|(array, field)| { - let mut array_levels = batch_level.calculate_array_levels(array, field); - levels.append(&mut array_levels); - }); - assert_eq!(levels.len(), 2); - - // test key levels - let list_level = levels.get(0).unwrap(); - - let expected_level = LevelInfo { - definition: vec![1; 7], - repetition: Some(vec![0, 1, 0, 1, 0, 1, 1]), - array_offsets: vec![0, 2, 4, 7], - array_mask: vec![true; 7], - max_definition: 1, - level_type: LevelType::Primitive(false), - offset: 0, - length: 7, - }; - assert_eq!(list_level, &expected_level); - - // test values levels - let list_level = levels.get(1).unwrap(); - - let expected_level = LevelInfo { - definition: vec![2, 2, 2, 1, 2, 1, 2], - repetition: Some(vec![0, 1, 0, 1, 0, 1, 1]), - array_offsets: vec![0, 2, 4, 7], - array_mask: vec![true, true, true, false, true, false, true], - max_definition: 2, - level_type: 
LevelType::Primitive(true), - offset: 0, - length: 7, - }; - assert_eq!(list_level, &expected_level); - } -} diff --git a/parquet/src/arrow/mod.rs b/parquet/src/arrow/mod.rs index 227bbdcd27a0..3aee7cf42cbc 100644 --- a/parquet/src/arrow/mod.rs +++ b/parquet/src/arrow/mod.rs @@ -15,16 +15,17 @@ // specific language governing permissions and limitations // under the License. +//! Provides API for reading/writing Arrow +//! [RecordBatch](arrow::record_batch::RecordBatch)es and +//! [Array](arrow::array::Array)s to/from Parquet Files. +//! //! [Apache Arrow](http://arrow.apache.org/) is a cross-language development platform for //! in-memory data. //! -//! This mod provides API for converting between arrow and parquet. -//! //!# Example of writing Arrow record batch to Parquet file //! //!```rust -//! use arrow::array::Int32Array; -//! use arrow::datatypes::{DataType, Field, Schema}; +//! use arrow::array::{Int32Array, ArrayRef}; //! use arrow::record_batch::RecordBatch; //! use parquet::arrow::arrow_writer::ArrowWriter; //! use parquet::file::properties::WriterProperties; @@ -32,25 +33,21 @@ //! use std::sync::Arc; //! let ids = Int32Array::from(vec![1, 2, 3, 4]); //! let vals = Int32Array::from(vec![5, 6, 7, 8]); -//! let schema = Arc::new(Schema::new(vec![ -//! Field::new("id", DataType::Int32, false), -//! Field::new("val", DataType::Int32, false), -//! ])); +//! let batch = RecordBatch::try_from_iter(vec![ +//! ("id", Arc::new(ids) as ArrayRef), +//! ("val", Arc::new(vals) as ArrayRef), +//! ]).unwrap(); //! //! let file = File::create("data.parquet").unwrap(); //! -//! let batch = -//! RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(ids), Arc::new(vals)]).unwrap(); -//! let batches = vec![batch]; -//! //! // Default writer properties //! let props = WriterProperties::builder().build(); //! -//! let mut writer = ArrowWriter::try_new(file, Arc::clone(&schema), Some(props)).unwrap(); +//! let mut writer = ArrowWriter::try_new(file, batch.schema(), Some(props)).unwrap(); +//! +//! writer.write(&batch).expect("Writing batch"); //! -//! for batch in batches { -//! writer.write(&batch).expect("Writing batch"); -//! } +//! // writer must be closed to write footer //! writer.close().unwrap(); //! ``` //! @@ -70,8 +67,8 @@ //! //! ```rust //! use arrow::record_batch::RecordBatchReader; -//! use parquet::file::reader::SerializedFileReader; -//! use parquet::arrow::{ParquetFileArrowReader, ArrowReader}; +//! use parquet::file::reader::{FileReader, SerializedFileReader}; +//! use parquet::arrow::{ParquetFileArrowReader, ArrowReader, ProjectionMask}; //! use std::sync::Arc; //! use std::fs::File; //! @@ -99,14 +96,18 @@ //! # writer.close().unwrap(); //! //! let file = File::open("data.parquet").unwrap(); -//! let file_reader = SerializedFileReader::new(file).unwrap(); -//! let mut arrow_reader = ParquetFileArrowReader::new(Arc::new(file_reader)); +//! +//! let mut arrow_reader = ParquetFileArrowReader::try_new(file).unwrap(); +//! let mask = ProjectionMask::leaves(arrow_reader.parquet_schema(), [0]); //! //! println!("Converted arrow schema is: {}", arrow_reader.get_schema().unwrap()); //! println!("Arrow schema after projection is: {}", -//! arrow_reader.get_schema_by_columns(vec![0], true).unwrap()); +//! arrow_reader.get_schema_by_columns(mask.clone()).unwrap()); +//! +//! let mut unprojected = arrow_reader.get_record_reader(2048).unwrap(); +//! println!("Unprojected reader schema: {}", unprojected.schema()); //! -//! 
let mut record_batch_reader = arrow_reader.get_record_reader(2048).unwrap(); +//! let mut record_batch_reader = arrow_reader.get_record_reader_by_columns(mask, 2048).unwrap(); //! //! for maybe_record_batch in record_batch_reader { //! let record_batch = maybe_record_batch.unwrap(); @@ -118,22 +119,114 @@ //!} //! ``` -pub mod array_reader; -pub mod arrow_array_reader; +experimental_mod!(array_reader); pub mod arrow_reader; pub mod arrow_writer; -pub mod converter; -pub(in crate::arrow) mod levels; -pub(in crate::arrow) mod record_reader; -pub mod schema; +mod buffer; + +#[cfg(feature = "async")] +pub mod async_reader; + +mod record_reader; +experimental_mod!(schema); pub use self::arrow_reader::ArrowReader; pub use self::arrow_reader::ParquetFileArrowReader; pub use self::arrow_writer::ArrowWriter; +#[cfg(feature = "async")] +pub use self::async_reader::ParquetRecordBatchStreamBuilder; +use crate::schema::types::SchemaDescriptor; + pub use self::schema::{ arrow_to_parquet_schema, parquet_to_arrow_schema, parquet_to_arrow_schema_by_columns, - parquet_to_arrow_schema_by_root_columns, }; /// Schema metadata key used to store serialized Arrow IPC schema pub const ARROW_SCHEMA_META_KEY: &str = "ARROW:schema"; + +/// A [`ProjectionMask`] identifies a set of columns within a potentially nested schema to project +/// +/// In particular, a [`ProjectionMask`] can be constructed from a list of leaf column indices +/// or root column indices where: +/// +/// * Root columns are the direct children of the root schema, enumerated in order +/// * Leaf columns are the child-less leaves of the schema as enumerated by a depth-first search +/// +/// For example, the schema +/// +/// ```ignore +/// message schema { +/// REQUIRED boolean leaf_1; +/// REQUIRED GROUP group { +/// OPTIONAL int32 leaf_2; +/// OPTIONAL int64 leaf_3; +/// } +/// } +/// ``` +/// +/// Has roots `["leaf_1", "group"]` and leaves `["leaf_1", "leaf_2", "leaf_3"]` +/// +/// For non-nested schemas, i.e. those containing only primitive columns, the root +/// and leaves are the same +/// +#[derive(Debug, Clone)] +pub struct ProjectionMask { + /// If present a leaf column should be included if the value at + /// the corresponding index is true + /// + /// If `None`, include all columns + mask: Option>, +} + +impl ProjectionMask { + /// Create a [`ProjectionMask`] which selects all columns + pub fn all() -> Self { + Self { mask: None } + } + + /// Create a [`ProjectionMask`] which selects only the specified leaf columns + /// + /// Note: repeated or out of order indices will not impact the final mask + /// + /// i.e. `[0, 1, 2]` will construct the same mask as `[1, 0, 0, 2]` + pub fn leaves( + schema: &SchemaDescriptor, + indices: impl IntoIterator, + ) -> Self { + let mut mask = vec![false; schema.num_columns()]; + for leaf_idx in indices { + mask[leaf_idx] = true; + } + Self { mask: Some(mask) } + } + + /// Create a [`ProjectionMask`] which selects only the specified root columns + /// + /// Note: repeated or out of order indices will not impact the final mask + /// + /// i.e. 
`[0, 1, 2]` will construct the same mask as `[1, 0, 0, 2]` + pub fn roots( + schema: &SchemaDescriptor, + indices: impl IntoIterator, + ) -> Self { + let num_root_columns = schema.root_schema().get_fields().len(); + let mut root_mask = vec![false; num_root_columns]; + for root_idx in indices { + root_mask[root_idx] = true; + } + + let mask = (0..schema.num_columns()) + .map(|leaf_idx| { + let root_idx = schema.get_column_root_idx(leaf_idx); + root_mask[root_idx] + }) + .collect(); + + Self { mask: Some(mask) } + } + + /// Returns true if the leaf column `leaf_idx` is included by the mask + pub fn leaf_included(&self, leaf_idx: usize) -> bool { + self.mask.as_ref().map(|m| m[leaf_idx]).unwrap_or(true) + } +} diff --git a/parquet/src/arrow/record_reader/buffer.rs b/parquet/src/arrow/record_reader/buffer.rs new file mode 100644 index 000000000000..7101eaa9ccc9 --- /dev/null +++ b/parquet/src/arrow/record_reader/buffer.rs @@ -0,0 +1,267 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::marker::PhantomData; + +use crate::arrow::buffer::bit_util::iter_set_bits_rev; +use arrow::buffer::{Buffer, MutableBuffer}; +use arrow::datatypes::ArrowNativeType; + +/// A buffer that supports writing new data to the end, and removing data from the front +/// +/// Used by [RecordReader](`super::RecordReader`) to buffer up values before returning a +/// potentially smaller number of values, corresponding to a whole number of semantic records +pub trait BufferQueue: Sized { + type Output: Sized; + + type Slice: ?Sized; + + /// Split out the first `len` items + /// + /// # Panics + /// + /// Implementations must panic if `len` is beyond the length of [`BufferQueue`] + /// + fn split_off(&mut self, len: usize) -> Self::Output; + + /// Returns a [`Self::Slice`] with at least `batch_size` capacity that can be used + /// to append data to the end of this [`BufferQueue`] + /// + /// NB: writes to the returned slice will not update the length of [`BufferQueue`] + /// instead a subsequent call should be made to [`BufferQueue::set_len`] + fn spare_capacity_mut(&mut self, batch_size: usize) -> &mut Self::Slice; + + /// Sets the length of the [`BufferQueue`]. 
+ /// + /// Intended to be used in combination with [`BufferQueue::spare_capacity_mut`] + /// + /// # Panics + /// + /// Implementations must panic if `len` is beyond the initialized length + /// + /// Implementations may panic if `set_len` is called with less than what has been written + /// + /// This distinction is to allow for implementations that return a default initialized + /// [BufferQueue::Slice`] which doesn't track capacity and length separately + /// + /// For example, [`BufferQueue`] returns a default-initialized `&mut [T]`, and does not + /// track how much of this slice is actually written to by the caller. This is still + /// safe as the slice is default-initialized. + /// + fn set_len(&mut self, len: usize); +} + +/// A marker trait for [scalar] types +/// +/// This means that a `[Self::default()]` of length `len` can be safely created from a +/// zero-initialized `[u8]` with length `len * std::mem::size_of::()` and +/// alignment of `std::mem::size_of::()` +/// +/// [scalar]: https://doc.rust-lang.org/book/ch03-02-data-types.html#scalar-types +/// +pub trait ScalarValue: Copy {} +impl ScalarValue for bool {} +impl ScalarValue for u8 {} +impl ScalarValue for i8 {} +impl ScalarValue for u16 {} +impl ScalarValue for i16 {} +impl ScalarValue for u32 {} +impl ScalarValue for i32 {} +impl ScalarValue for u64 {} +impl ScalarValue for i64 {} +impl ScalarValue for f32 {} +impl ScalarValue for f64 {} + +/// A typed buffer similar to [`Vec`] but using [`MutableBuffer`] for storage +#[derive(Debug)] +pub struct ScalarBuffer { + buffer: MutableBuffer, + + /// Length in elements of size T + len: usize, + + /// Placeholder to allow `T` as an invariant generic parameter + /// without making it !Send + _phantom: PhantomData T>, +} + +impl Default for ScalarBuffer { + fn default() -> Self { + Self::new() + } +} + +impl ScalarBuffer { + pub fn new() -> Self { + Self { + buffer: MutableBuffer::new(0), + len: 0, + _phantom: Default::default(), + } + } + + pub fn len(&self) -> usize { + self.len + } + + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + pub fn reserve(&mut self, additional: usize) { + self.buffer.reserve(additional * std::mem::size_of::()); + } + + pub fn resize(&mut self, len: usize) { + self.buffer.resize(len * std::mem::size_of::(), 0); + self.len = len; + } + + #[inline] + pub fn as_slice(&self) -> &[T] { + let (prefix, buf, suffix) = unsafe { self.buffer.as_slice().align_to::() }; + assert!(prefix.is_empty() && suffix.is_empty()); + buf + } + + #[inline] + pub fn as_slice_mut(&mut self) -> &mut [T] { + let (prefix, buf, suffix) = + unsafe { self.buffer.as_slice_mut().align_to_mut::() }; + assert!(prefix.is_empty() && suffix.is_empty()); + buf + } + + pub fn take(&mut self, len: usize) -> Self { + assert!(len <= self.len); + + let num_bytes = len * std::mem::size_of::(); + let remaining_bytes = self.buffer.len() - num_bytes; + // TODO: Optimize to reduce the copy + // create an empty buffer, as it will be resized below + let mut remaining = MutableBuffer::new(0); + remaining.resize(remaining_bytes, 0); + + let new_records = remaining.as_slice_mut(); + + new_records[0..remaining_bytes] + .copy_from_slice(&self.buffer.as_slice()[num_bytes..]); + + self.buffer.resize(num_bytes, 0); + self.len -= len; + + Self { + buffer: std::mem::replace(&mut self.buffer, remaining), + len, + _phantom: Default::default(), + } + } +} + +impl ScalarBuffer { + pub fn push(&mut self, v: T) { + self.buffer.push(v); + self.len += 1; + } + + pub fn extend_from_slice(&mut self, v: &[T]) { + 
self.buffer.extend_from_slice(v); + self.len += v.len(); + } +} + +impl From> for Buffer { + fn from(t: ScalarBuffer) -> Self { + t.buffer.into() + } +} + +impl BufferQueue for ScalarBuffer { + type Output = Buffer; + + type Slice = [T]; + + fn split_off(&mut self, len: usize) -> Self::Output { + self.take(len).into() + } + + fn spare_capacity_mut(&mut self, batch_size: usize) -> &mut Self::Slice { + self.buffer + .resize((self.len + batch_size) * std::mem::size_of::(), 0); + + let range = self.len..self.len + batch_size; + &mut self.as_slice_mut()[range] + } + + fn set_len(&mut self, len: usize) { + self.len = len; + + let new_bytes = self.len * std::mem::size_of::(); + assert!(new_bytes <= self.buffer.len()); + self.buffer.resize(new_bytes, 0); + } +} + +/// A [`BufferQueue`] capable of storing column values +pub trait ValuesBuffer: BufferQueue { + /// + /// If a column contains nulls, more level data may be read than value data, as null + /// values are not encoded. Therefore, first the levels data is read, the null count + /// determined, and then the corresponding number of values read to a [`ValuesBuffer`]. + /// + /// It is then necessary to move this values data into positions that correspond to + /// the non-null level positions. This is what this method does. + /// + /// It is provided with: + /// + /// - `read_offset` - the offset in [`ValuesBuffer`] to start null padding from + /// - `values_read` - the number of values read + /// - `levels_read` - the number of levels read + /// - `valid_mask` - a packed mask of valid levels + /// + fn pad_nulls( + &mut self, + read_offset: usize, + values_read: usize, + levels_read: usize, + valid_mask: &[u8], + ); +} + +impl ValuesBuffer for ScalarBuffer { + fn pad_nulls( + &mut self, + read_offset: usize, + values_read: usize, + levels_read: usize, + valid_mask: &[u8], + ) { + let slice = self.as_slice_mut(); + assert!(slice.len() >= read_offset + levels_read); + + let values_range = read_offset..read_offset + values_read; + for (value_pos, level_pos) in + values_range.rev().zip(iter_set_bits_rev(valid_mask)) + { + debug_assert!(level_pos >= value_pos); + if level_pos <= value_pos { + break; + } + slice[level_pos] = slice[value_pos]; + } + } +} diff --git a/parquet/src/arrow/record_reader/definition_levels.rs b/parquet/src/arrow/record_reader/definition_levels.rs new file mode 100644 index 000000000000..9cca25c8ae5c --- /dev/null +++ b/parquet/src/arrow/record_reader/definition_levels.rs @@ -0,0 +1,417 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
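// The `ValuesBuffer::pad_nulls` implementation above spreads `values_read` densely
// packed values out to the positions of the set bits in a validity mask, working
// from the back of the buffer so that no value is overwritten before it has been
// copied. A standalone sketch of the same idea over plain slices follows; the
// function name, the `i32` element type and the `&[bool]` mask are illustrative
// only and are not the crate's API.

fn pad_nulls_sketch(
    values: &mut [i32],
    values_read: usize,
    levels_read: usize,
    valid: &[bool],
) {
    assert!(values.len() >= levels_read);
    assert_eq!(valid.len(), levels_read);

    let mut value_pos = values_read;
    for level_pos in (0..levels_read).rev() {
        if valid[level_pos] {
            value_pos -= 1;
            if level_pos == value_pos {
                // Positions have converged, everything below is already in place
                break;
            }
            values[level_pos] = values[value_pos];
        }
    }
}

fn main() {
    // 3 values were decoded for 5 definition levels; slots 1 and 3 are null
    let mut values = vec![10, 20, 30, 0, 0];
    let valid = vec![true, false, true, false, true];
    pad_nulls_sketch(&mut values, 3, 5, &valid);
    // Valid slots now hold their values; null slots keep stale data, which the
    // null bitmap masks out when the arrow array is built
    assert_eq!(values[0], 10);
    assert_eq!(values[2], 20);
    assert_eq!(values[4], 30);
}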
+ +use std::ops::Range; + +use arrow::array::BooleanBufferBuilder; +use arrow::bitmap::Bitmap; +use arrow::buffer::Buffer; + +use crate::arrow::buffer::bit_util::count_set_bits; +use crate::arrow::record_reader::buffer::BufferQueue; +use crate::basic::Encoding; +use crate::column::reader::decoder::{ + ColumnLevelDecoder, ColumnLevelDecoderImpl, LevelsBufferSlice, +}; +use crate::errors::{ParquetError, Result}; +use crate::schema::types::ColumnDescPtr; +use crate::util::memory::ByteBufferPtr; + +use super::{buffer::ScalarBuffer, MIN_BATCH_SIZE}; + +enum BufferInner { + /// Compute levels and null mask + Full { + levels: ScalarBuffer, + nulls: BooleanBufferBuilder, + max_level: i16, + }, + /// Only compute null bitmask - requires max level to be 1 + /// + /// This is an optimisation for the common case of a nullable scalar column, as decoding + /// the definition level data is only required when decoding nested structures + /// + Mask { nulls: BooleanBufferBuilder }, +} + +pub struct DefinitionLevelBuffer { + inner: BufferInner, + + /// The length of this buffer + /// + /// Note: `buffer` and `builder` may contain more elements + len: usize, +} + +impl DefinitionLevelBuffer { + pub fn new(desc: &ColumnDescPtr, null_mask_only: bool) -> Self { + let inner = match null_mask_only { + true => { + assert_eq!( + desc.max_def_level(), + 1, + "max definition level must be 1 to only compute null bitmask" + ); + + assert_eq!( + desc.max_rep_level(), + 0, + "max repetition level must be 0 to only compute null bitmask" + ); + + BufferInner::Mask { + nulls: BooleanBufferBuilder::new(0), + } + } + false => BufferInner::Full { + levels: ScalarBuffer::new(), + nulls: BooleanBufferBuilder::new(0), + max_level: desc.max_def_level(), + }, + }; + + Self { inner, len: 0 } + } + + pub fn split_levels(&mut self, len: usize) -> Option { + match &mut self.inner { + BufferInner::Full { levels, .. } => { + let out = levels.split_off(len); + self.len = levels.len(); + Some(out) + } + BufferInner::Mask { .. } => None, + } + } + + pub fn set_len(&mut self, len: usize) { + assert_eq!(self.nulls().len(), len); + self.len = len; + } + + /// Split `len` levels out of `self` + pub fn split_bitmask(&mut self, len: usize) -> Bitmap { + let old_builder = match &mut self.inner { + BufferInner::Full { nulls, .. } => nulls, + BufferInner::Mask { nulls } => nulls, + }; + + // Compute the number of values left behind + let num_left_values = old_builder.len() - len; + let mut new_builder = + BooleanBufferBuilder::new(MIN_BATCH_SIZE.max(num_left_values)); + + // Copy across remaining values + new_builder.append_packed_range(len..old_builder.len(), old_builder.as_slice()); + + // Truncate buffer + old_builder.resize(len); + + // Swap into self + self.len = new_builder.len(); + Bitmap::from(std::mem::replace(old_builder, new_builder).finish()) + } + + pub fn nulls(&self) -> &BooleanBufferBuilder { + match &self.inner { + BufferInner::Full { nulls, .. 
} => nulls, + BufferInner::Mask { nulls } => nulls, + } + } +} + +impl LevelsBufferSlice for DefinitionLevelBuffer { + fn capacity(&self) -> usize { + usize::MAX + } + + fn count_nulls(&self, range: Range, _max_level: i16) -> usize { + let total_count = range.end - range.start; + let range = range.start + self.len..range.end + self.len; + total_count - count_set_bits(self.nulls().as_slice(), range) + } +} + +pub struct DefinitionLevelDecoder { + max_level: i16, + encoding: Encoding, + data: Option, + column_decoder: Option, + packed_decoder: Option, +} + +impl ColumnLevelDecoder for DefinitionLevelDecoder { + type Slice = DefinitionLevelBuffer; + + fn new(max_level: i16, encoding: Encoding, data: ByteBufferPtr) -> Self { + Self { + max_level, + encoding, + data: Some(data), + column_decoder: None, + packed_decoder: None, + } + } + + fn read( + &mut self, + writer: &mut Self::Slice, + range: Range, + ) -> crate::errors::Result { + match &mut writer.inner { + BufferInner::Full { + levels, + nulls, + max_level, + } => { + assert_eq!(self.max_level, *max_level); + assert_eq!(range.start + writer.len, nulls.len()); + + let decoder = match self.data.take() { + Some(data) => self.column_decoder.insert( + ColumnLevelDecoderImpl::new(self.max_level, self.encoding, data), + ), + None => self + .column_decoder + .as_mut() + .expect("consistent null_mask_only"), + }; + + levels.resize(range.end + writer.len); + + let slice = &mut levels.as_slice_mut()[writer.len..]; + let levels_read = decoder.read(slice, range.clone())?; + + nulls.reserve(levels_read); + for i in &slice[range.start..range.start + levels_read] { + nulls.append(i == max_level) + } + + Ok(levels_read) + } + BufferInner::Mask { nulls } => { + assert_eq!(self.max_level, 1); + assert_eq!(range.start + writer.len, nulls.len()); + + let decoder = match self.data.take() { + Some(data) => self + .packed_decoder + .insert(PackedDecoder::new(self.encoding, data)), + None => self + .packed_decoder + .as_mut() + .expect("consistent null_mask_only"), + }; + + decoder.read(nulls, range.end - range.start) + } + } + } +} + +/// An optimized decoder for decoding [RLE] and [BIT_PACKED] data with a bit width of 1 +/// directly into a bitmask +/// +/// This is significantly faster than decoding the data into `[i16]` and then computing +/// a bitmask from this, as not only can it skip this buffer allocation and construction, +/// but it can exploit properties of the encoded data to reduce work further +/// +/// In particular: +/// +/// * Packed runs are already bitmask encoded and can simply be appended +/// * Runs of 1 or 0 bits can be efficiently appended with byte (or larger) operations +/// +/// [RLE]: https://github.com/apache/parquet-format/blob/master/Encodings.md#run-length-encoding--bit-packing-hybrid-rle--3 +/// [BIT_PACKED]: https://github.com/apache/parquet-format/blob/master/Encodings.md#bit-packed-deprecated-bit_packed--4 +struct PackedDecoder { + data: ByteBufferPtr, + data_offset: usize, + rle_left: usize, + rle_value: bool, + packed_count: usize, + packed_offset: usize, +} + +impl PackedDecoder { + fn next_rle_block(&mut self) -> Result<()> { + let indicator_value = self.decode_header()?; + if indicator_value & 1 == 1 { + let len = (indicator_value >> 1) as usize; + self.packed_count = len * 8; + self.packed_offset = 0; + } else { + self.rle_left = (indicator_value >> 1) as usize; + let byte = *self.data.as_ref().get(self.data_offset).ok_or_else(|| { + ParquetError::EOF( + "unexpected end of file whilst decoding definition levels rle 
value" + .into(), + ) + })?; + + self.data_offset += 1; + self.rle_value = byte != 0; + } + Ok(()) + } + + /// Decodes a VLQ encoded little endian integer and returns it + fn decode_header(&mut self) -> Result { + let mut offset = 0; + let mut v: i64 = 0; + while offset < 10 { + let byte = *self + .data + .as_ref() + .get(self.data_offset + offset) + .ok_or_else(|| { + ParquetError::EOF( + "unexpected end of file whilst decoding definition levels rle header" + .into(), + ) + })?; + + v |= ((byte & 0x7F) as i64) << (offset * 7); + offset += 1; + if byte & 0x80 == 0 { + self.data_offset += offset; + return Ok(v); + } + } + Err(general_err!("too many bytes for VLQ")) + } +} + +impl PackedDecoder { + fn new(encoding: Encoding, data: ByteBufferPtr) -> Self { + match encoding { + Encoding::RLE => Self { + data, + data_offset: 0, + rle_left: 0, + rle_value: false, + packed_count: 0, + packed_offset: 0, + }, + Encoding::BIT_PACKED => Self { + data_offset: 0, + rle_left: 0, + rle_value: false, + packed_count: data.len() * 8, + packed_offset: 0, + data, + }, + _ => unreachable!("invalid level encoding: {}", encoding), + } + } + + fn read(&mut self, buffer: &mut BooleanBufferBuilder, len: usize) -> Result { + let mut read = 0; + while read != len { + if self.rle_left != 0 { + let to_read = self.rle_left.min(len - read); + buffer.append_n(to_read, self.rle_value); + self.rle_left -= to_read; + read += to_read; + } else if self.packed_count != self.packed_offset { + let to_read = (self.packed_count - self.packed_offset).min(len - read); + let offset = self.data_offset * 8 + self.packed_offset; + buffer.append_packed_range(offset..offset + to_read, self.data.as_ref()); + self.packed_offset += to_read; + read += to_read; + + if self.packed_offset == self.packed_count { + self.data_offset += self.packed_count / 8; + } + } else if self.data_offset == self.data.len() { + break; + } else { + self.next_rle_block()? 
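+                // The next loop iteration will drain the run decoded by `next_rle_block`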
+ } + } + Ok(read) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + use crate::basic::Type as PhysicalType; + use crate::encodings::rle::RleEncoder; + use crate::schema::types::{ColumnDescriptor, ColumnPath, Type}; + use rand::{thread_rng, Rng}; + + #[test] + fn test_packed_decoder() { + let mut rng = thread_rng(); + let len: usize = rng.gen_range(512..1024); + + let mut expected = BooleanBufferBuilder::new(len); + let mut encoder = RleEncoder::new(1, 1024); + for _ in 0..len { + let bool = rng.gen_bool(0.8); + assert!(encoder.put(bool as u64).unwrap()); + expected.append(bool); + } + assert_eq!(expected.len(), len); + + let encoded = encoder.consume().unwrap(); + let mut decoder = PackedDecoder::new(Encoding::RLE, ByteBufferPtr::new(encoded)); + + // Decode data in random length intervals + let mut decoded = BooleanBufferBuilder::new(len); + loop { + let remaining = len - decoded.len(); + if remaining == 0 { + break; + } + + let to_read = rng.gen_range(1..=remaining); + decoder.read(&mut decoded, to_read).unwrap(); + } + + assert_eq!(decoded.len(), len); + assert_eq!(decoded.as_slice(), expected.as_slice()); + } + + #[test] + fn test_split_off() { + let t = Type::primitive_type_builder("col", PhysicalType::INT32) + .build() + .unwrap(); + + let descriptor = Arc::new(ColumnDescriptor::new( + Arc::new(t), + 1, + 0, + ColumnPath::new(vec![]), + )); + + let mut buffer = DefinitionLevelBuffer::new(&descriptor, true); + match &mut buffer.inner { + BufferInner::Mask { nulls } => nulls.append_n(100, false), + _ => unreachable!(), + }; + + let bitmap = buffer.split_bitmask(19); + + // Should have split off 19 records leaving, 81 behind + assert_eq!(bitmap.bit_len(), 3 * 8); // Note: bitmask only tracks bytes not bits + assert_eq!(buffer.nulls().len(), 81); + } +} diff --git a/parquet/src/arrow/record_reader.rs b/parquet/src/arrow/record_reader/mod.rs similarity index 60% rename from parquet/src/arrow/record_reader.rs rename to parquet/src/arrow/record_reader/mod.rs index 4dd7da910fd0..023a538a2741 100644 --- a/parquet/src/arrow/record_reader.rs +++ b/parquet/src/arrow/record_reader/mod.rs @@ -16,75 +16,113 @@ // under the License. use std::cmp::{max, min}; -use std::mem::{replace, size_of}; -use crate::column::{page::PageReader, reader::ColumnReaderImpl}; +use arrow::bitmap::Bitmap; +use arrow::buffer::Buffer; + +use crate::arrow::record_reader::{ + buffer::{BufferQueue, ScalarBuffer, ValuesBuffer}, + definition_levels::{DefinitionLevelBuffer, DefinitionLevelDecoder}, +}; +use crate::column::{ + page::PageReader, + reader::{ + decoder::{ColumnLevelDecoderImpl, ColumnValueDecoder, ColumnValueDecoderImpl}, + GenericColumnReader, + }, +}; use crate::data_type::DataType; use crate::errors::{ParquetError, Result}; use crate::schema::types::ColumnDescPtr; -use arrow::array::BooleanBufferBuilder; -use arrow::bitmap::Bitmap; -use arrow::buffer::{Buffer, MutableBuffer}; + +pub(crate) mod buffer; +mod definition_levels; const MIN_BATCH_SIZE: usize = 1024; /// A `RecordReader` is a stateful column reader that delimits semantic records. -pub struct RecordReader { +pub type RecordReader = + GenericRecordReader::T>, ColumnValueDecoderImpl>; + +#[doc(hidden)] +/// A generic stateful column reader that delimits semantic records +/// +/// This type is hidden from the docs, and relies on private traits with no +/// public implementations. 
As such this type signature may be changed without +/// breaking downstream users as it can only be constructed through type aliases +pub struct GenericRecordReader { column_desc: ColumnDescPtr, - records: MutableBuffer, - def_levels: Option, - rep_levels: Option, - null_bitmap: Option, - column_reader: Option>, + records: V, + def_levels: Option, + rep_levels: Option>, + column_reader: + Option>, /// Number of records accumulated in records num_records: usize, /// Number of values `num_records` contains. num_values: usize, - values_seen: usize, /// Starts from 1, number of values have been written to buffer values_written: usize, - in_middle_of_record: bool, } -impl RecordReader { - pub fn new(column_schema: ColumnDescPtr) -> Self { - let (def_levels, null_map) = if column_schema.max_def_level() > 0 { - ( - Some(MutableBuffer::new(MIN_BATCH_SIZE)), - Some(BooleanBufferBuilder::new(0)), - ) - } else { - (None, None) - }; +impl GenericRecordReader +where + V: ValuesBuffer + Default, + CV: ColumnValueDecoder, +{ + /// Create a new [`GenericRecordReader`] + pub fn new(desc: ColumnDescPtr) -> Self { + Self::new_with_options(desc, false) + } - let rep_levels = if column_schema.max_rep_level() > 0 { - Some(MutableBuffer::new(MIN_BATCH_SIZE)) - } else { - None - }; + /// Create a new [`GenericRecordReader`] with the ability to only generate the bitmask + /// + /// If `null_mask_only` is true only the null bitmask will be generated and + /// [`Self::consume_def_levels`] and [`Self::consume_rep_levels`] will always return `None` + /// + /// It is insufficient to solely check that that the max definition level is 1 as we + /// need there to be no nullable parent array that will required decoded definition levels + /// + /// In particular consider the case of: + /// + /// ```ignore + /// message nested { + /// OPTIONAL Group group { + /// REQUIRED INT32 leaf; + /// } + /// } + /// ``` + /// + /// The maximum definition level of leaf is 1, however, we still need to decode the + /// definition levels so that the parent group can be constructed correctly + /// + pub(crate) fn new_with_options(desc: ColumnDescPtr, null_mask_only: bool) -> Self { + let def_levels = (desc.max_def_level() > 0) + .then(|| DefinitionLevelBuffer::new(&desc, null_mask_only)); + + let rep_levels = (desc.max_rep_level() > 0).then(ScalarBuffer::new); Self { - records: MutableBuffer::new(MIN_BATCH_SIZE), + records: Default::default(), def_levels, rep_levels, - null_bitmap: null_map, column_reader: None, - column_desc: column_schema, + column_desc: desc, num_records: 0, num_values: 0, - values_seen: 0, values_written: 0, - in_middle_of_record: false, } } /// Set the current page reader. pub fn set_page_reader(&mut self, page_reader: Box) -> Result<()> { - self.column_reader = - Some(ColumnReaderImpl::new(self.column_desc.clone(), page_reader)); + self.column_reader = Some(GenericColumnReader::new( + self.column_desc.clone(), + page_reader, + )); Ok(()) } @@ -107,25 +145,52 @@ impl RecordReader { loop { // Try to find some records from buffers that has been read into memory // but not counted as seen records. 
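// Editorial sketch, not part of this patch: for the `nested` message shown in the
// `new_with_options` doc comment above, the maximum definition level of `leaf` is 1.
// A definition level equal to that maximum marks a present value, while a smaller
// level marks a null introduced by the optional parent group. Assuming some
// already-decoded `def_levels` (hypothetical values), deriving the null mask is
// essentially:
//
//     let def_levels = [1i16, 0, 1, 1, 0];
//     let max_def_level = 1i16;
//     let null_mask: Vec<bool> =
//         def_levels.iter().map(|&l| l == max_def_level).collect();
//     assert_eq!(null_mask, vec![true, false, true, true, false]);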
- records_read += self.split_records(num_records - records_read)?; - - // Since page reader contains complete records, so if we reached end of a - // page reader, we should reach the end of a record - if end_of_column - && self.values_seen >= self.values_written - && self.in_middle_of_record - { - self.num_records += 1; - self.num_values = self.values_seen; - self.in_middle_of_record = false; - records_read += 1; + let (record_count, value_count) = + self.count_records(num_records - records_read); + + self.num_records += record_count; + self.num_values += value_count; + records_read += record_count; + + if records_read == num_records { + break; } - if (records_read >= num_records) || end_of_column { + if end_of_column { + // Since page reader contains complete records, if we reached end of a + // page reader, we should reach the end of a record + if self.rep_levels.is_some() { + self.num_records += 1; + self.num_values = self.values_written; + records_read += 1; + } break; } - let batch_size = max(num_records - records_read, MIN_BATCH_SIZE); + // If repetition levels present, we don't know how much more to read + // in order to read the requested number of records, therefore read at least + // MIN_BATCH_SIZE, otherwise read **exactly** what was requested. This helps + // to avoid a degenerate case where the buffers are never fully drained. + // + // Consider the scenario where the user is requesting batches of MIN_BATCH_SIZE. + // + // When transitioning across a row group boundary, this will read some remainder + // from the row group `r`, before reading MIN_BATCH_SIZE from the next row group, + // leaving `MIN_BATCH_SIZE + r` in the buffer. + // + // The client will then only split off the `MIN_BATCH_SIZE` they actually wanted, + // leaving behind `r`. This will continue indefinitely. + // + // Aside from wasting cycles splitting and shuffling buffers unnecessarily, this + // prevents dictionary preservation from functioning correctly as the buffer + // will never be emptied, allowing a new dictionary to be registered. + // + // This degenerate case can still occur for repeated fields, but + // it is avoided for the more common case of a non-repeated field + let batch_size = match &self.rep_levels { + Some(_) => max(num_records - records_read, MIN_BATCH_SIZE), + None => num_records - records_read, + }; // Try to more value from parquet pages let values_read = self.read_one_batch(batch_size)?; @@ -154,108 +219,31 @@ impl RecordReader { /// definition level values that have already been read into memory but not counted /// as record values, e.g. those from `self.num_values` to `self.values_written`. 
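// Editorial sketch, not part of this patch, of the `batch_size` selection explained
// in `read_records` above; the constant and the record counts are illustrative only:
//
//     const MIN_BATCH_SIZE: usize = 1024;
//     fn batch_size(has_rep_levels: bool, remaining_records: usize) -> usize {
//         match has_rep_levels {
//             // repeated field: record boundaries are unknown ahead of time,
//             // so read at least MIN_BATCH_SIZE to keep draining the buffers
//             true => remaining_records.max(MIN_BATCH_SIZE),
//             // non-repeated field: one value per record, read exactly what is needed
//             false => remaining_records,
//         }
//     }
//     assert_eq!(batch_size(false, 10), 10);
//     assert_eq!(batch_size(true, 10), 1024);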
pub fn consume_def_levels(&mut self) -> Result> { - let new_buffer = if let Some(ref mut def_levels_buf) = &mut self.def_levels { - let num_left_values = self.values_written - self.num_values; - // create an empty buffer, as it will be resized below - let mut new_buffer = MutableBuffer::new(0); - let num_bytes = num_left_values * size_of::(); - let new_len = self.num_values * size_of::(); - - new_buffer.resize(num_bytes, 0); - - let new_def_levels = new_buffer.as_slice_mut(); - let left_def_levels = &def_levels_buf.as_slice_mut()[new_len..]; - - new_def_levels[0..num_bytes].copy_from_slice(&left_def_levels[0..num_bytes]); - - def_levels_buf.resize(new_len, 0); - Some(new_buffer) - } else { - None - }; - - Ok(replace(&mut self.def_levels, new_buffer).map(|x| x.into())) + Ok(match self.def_levels.as_mut() { + Some(x) => x.split_levels(self.num_values), + None => None, + }) } /// Return repetition level data. /// The side effect is similar to `consume_def_levels`. pub fn consume_rep_levels(&mut self) -> Result> { - // TODO: Optimize to reduce the copy - let new_buffer = if let Some(ref mut rep_levels_buf) = &mut self.rep_levels { - let num_left_values = self.values_written - self.num_values; - // create an empty buffer, as it will be resized below - let mut new_buffer = MutableBuffer::new(0); - let num_bytes = num_left_values * size_of::(); - let new_len = self.num_values * size_of::(); - - new_buffer.resize(num_bytes, 0); - - let new_rep_levels = new_buffer.as_slice_mut(); - let left_rep_levels = &rep_levels_buf.as_slice_mut()[new_len..]; - - new_rep_levels[0..num_bytes].copy_from_slice(&left_rep_levels[0..num_bytes]); - - rep_levels_buf.resize(new_len, 0); - - Some(new_buffer) - } else { - None - }; - - Ok(replace(&mut self.rep_levels, new_buffer).map(|x| x.into())) + Ok(match self.rep_levels.as_mut() { + Some(x) => Some(x.split_off(self.num_values)), + None => None, + }) } /// Returns currently stored buffer data. /// The side effect is similar to `consume_def_levels`. - pub fn consume_record_data(&mut self) -> Result { - // TODO: Optimize to reduce the copy - let num_left_values = self.values_written - self.num_values; - // create an empty buffer, as it will be resized below - let mut new_buffer = MutableBuffer::new(0); - let num_bytes = num_left_values * T::get_type_size(); - let new_len = self.num_values * T::get_type_size(); - - new_buffer.resize(num_bytes, 0); - - let new_records = new_buffer.as_slice_mut(); - let left_records = &mut self.records.as_slice_mut()[new_len..]; - - new_records[0..num_bytes].copy_from_slice(&left_records[0..num_bytes]); - - self.records.resize(new_len, 0); - - Ok(replace(&mut self.records, new_buffer).into()) + pub fn consume_record_data(&mut self) -> Result { + Ok(self.records.split_off(self.num_values)) } /// Returns currently stored null bitmap data. /// The side effect is similar to `consume_def_levels`. 
pub fn consume_bitmap_buffer(&mut self) -> Result> { - // TODO: Optimize to reduce the copy - if self.column_desc.max_def_level() > 0 { - assert!(self.null_bitmap.is_some()); - let num_left_values = self.values_written - self.num_values; - let new_bitmap_builder = Some(BooleanBufferBuilder::new(max( - MIN_BATCH_SIZE, - num_left_values, - ))); - - let old_bitmap = replace(&mut self.null_bitmap, new_bitmap_builder) - .map(|mut builder| builder.finish()) - .unwrap(); - - let old_bitmap = Bitmap::from(old_bitmap); - - for i in self.num_values..self.values_written { - self.null_bitmap - .as_mut() - .unwrap() - .append(old_bitmap.is_set(i)); - } - - Ok(Some(old_bitmap.into_buffer())) - } else { - Ok(None) - } + Ok(self.consume_bitmap()?.map(|b| b.into_buffer())) } /// Reset state of record reader. @@ -265,49 +253,26 @@ impl RecordReader { self.values_written -= self.num_values; self.num_records = 0; self.num_values = 0; - self.values_seen = 0; - self.in_middle_of_record = false; } /// Returns bitmap data. pub fn consume_bitmap(&mut self) -> Result> { - self.consume_bitmap_buffer() - .map(|buffer| buffer.map(Bitmap::from)) + Ok(self + .def_levels + .as_mut() + .map(|levels| levels.split_bitmask(self.num_values))) } /// Try to read one batch of data. fn read_one_batch(&mut self, batch_size: usize) -> Result { - // Reserve spaces - self.records - .resize(self.records.len() + batch_size * T::get_type_size(), 0); - if let Some(ref mut buf) = self.rep_levels { - buf.resize(buf.len() + batch_size * size_of::(), 0); - } - if let Some(ref mut buf) = self.def_levels { - buf.resize(buf.len() + batch_size * size_of::(), 0); - } - - let values_written = self.values_written; - - // Convert mutable buffer spaces to mutable slices - let (prefix, values, suffix) = - unsafe { self.records.as_slice_mut().align_to_mut::() }; - assert!(prefix.is_empty() && suffix.is_empty()); - let values = &mut values[values_written..]; + let rep_levels = self + .rep_levels + .as_mut() + .map(|levels| levels.spare_capacity_mut(batch_size)); - let def_levels = self.def_levels.as_mut().map(|buf| { - let (prefix, def_levels, suffix) = - unsafe { buf.as_slice_mut().align_to_mut::() }; - assert!(prefix.is_empty() && suffix.is_empty()); - &mut def_levels[values_written..] - }); + let def_levels = self.def_levels.as_mut(); - let rep_levels = self.rep_levels.as_mut().map(|buf| { - let (prefix, rep_levels, suffix) = - unsafe { buf.as_slice_mut().align_to_mut::() }; - assert!(prefix.is_empty() && suffix.is_empty()); - &mut rep_levels[values_written..] - }); + let values = self.records.spare_capacity_mut(batch_size); let (values_read, levels_read) = self .column_reader @@ -315,98 +280,61 @@ impl RecordReader { .unwrap() .read_batch(batch_size, def_levels, rep_levels, values)?; - // get new references for the def levels. - let def_levels = self.def_levels.as_ref().map(|buf| { - let (prefix, def_levels, suffix) = - unsafe { buf.as_slice().align_to::() }; - assert!(prefix.is_empty() && suffix.is_empty()); - &def_levels[values_written..] - }); - - let max_def_level = self.column_desc.max_def_level(); - if values_read < levels_read { - let def_levels = def_levels.ok_or_else(|| { + let def_levels = self.def_levels.as_ref().ok_or_else(|| { general_err!( "Definition levels should exist when data is less than levels!" 
) })?; - // Fill spaces in column data with default values - let mut values_pos = values_read; - let mut level_pos = levels_read; - - while level_pos > values_pos { - if def_levels[level_pos - 1] == max_def_level { - // This values is not empty - // We use swap rather than assign here because T::T doesn't - // implement Copy - values.swap(level_pos - 1, values_pos - 1); - values_pos -= 1; - } else { - values[level_pos - 1] = T::T::default(); - } - - level_pos -= 1; - } - } - - // Fill in bitmap data - if let Some(null_buffer) = self.null_bitmap.as_mut() { - let def_levels = def_levels.ok_or_else(|| { - general_err!( - "Definition levels should exist when data is less than levels!" - ) - })?; - (0..levels_read) - .for_each(|idx| null_buffer.append(def_levels[idx] == max_def_level)); + self.records.pad_nulls( + self.values_written, + values_read, + levels_read, + def_levels.nulls().as_slice(), + ); } - let values_read = max(values_read, levels_read); + let values_read = max(levels_read, values_read); self.set_values_written(self.values_written + values_read)?; Ok(values_read) } - /// Split values into records according repetition definition and returns number of - /// records read. - #[allow(clippy::unnecessary_wraps)] - fn split_records(&mut self, records_to_read: usize) -> Result { - let rep_levels = self.rep_levels.as_ref().map(|buf| { - let (prefix, rep_levels, suffix) = - unsafe { buf.as_slice().align_to::() }; - assert!(prefix.is_empty() && suffix.is_empty()); - rep_levels - }); - - match rep_levels { + /// Inspects the buffered repetition levels in the range `self.num_values..self.values_written` + /// and returns the number of "complete" records along with the corresponding number of values + /// + /// A "complete" record is one where the buffer contains a subsequent repetition level of 0 + fn count_records(&self, records_to_read: usize) -> (usize, usize) { + match self.rep_levels.as_ref() { Some(buf) => { + let buf = buf.as_slice(); + let mut records_read = 0; + let mut end_of_last_record = self.num_values; - while (self.values_seen < self.values_written) - && (records_read < records_to_read) + for (current, item) in buf + .iter() + .enumerate() + .take(self.values_written) + .skip(self.num_values) { - if buf[self.values_seen] == 0 { - if self.in_middle_of_record { - records_read += 1; - self.num_records += 1; - self.num_values = self.values_seen; + if *item == 0 && current != self.num_values { + records_read += 1; + end_of_last_record = current; + + if records_read == records_to_read { + break; } - self.in_middle_of_record = true; } - self.values_seen += 1; } - Ok(records_read) + (records_read, end_of_last_record - self.num_values) } None => { let records_read = - min(records_to_read, self.values_written - self.values_seen); - self.num_records += records_read; - self.num_values += records_read; - self.values_seen += records_read; - self.in_middle_of_record = false; + min(records_to_read, self.values_written - self.num_values); - Ok(records_read) + (records_read, records_read) } } } @@ -414,17 +342,14 @@ impl RecordReader { #[allow(clippy::unnecessary_wraps)] fn set_values_written(&mut self, new_values_written: usize) -> Result<()> { self.values_written = new_values_written; - self.records - .resize(self.values_written * T::get_type_size(), 0); - - let new_levels_len = self.values_written * size_of::(); + self.records.set_len(self.values_written); if let Some(ref mut buf) = self.rep_levels { - buf.resize(new_levels_len, 0) + buf.set_len(self.values_written) }; if let Some(ref mut 
buf) = self.def_levels { - buf.resize(new_levels_len, 0) + buf.set_len(self.values_written) }; Ok(()) @@ -433,7 +358,12 @@ impl RecordReader { #[cfg(test)] mod tests { - use super::RecordReader; + use std::sync::Arc; + + use arrow::array::{Int16BufferBuilder, Int32BufferBuilder}; + use arrow::bitmap::Bitmap; + use arrow::buffer::Buffer; + use crate::basic::Encoding; use crate::column::page::Page; use crate::column::page::PageReader; @@ -442,12 +372,11 @@ mod tests { use crate::schema::parser::parse_message_type; use crate::schema::types::SchemaDescriptor; use crate::util::test_common::page_util::{DataPageBuilder, DataPageBuilderImpl}; - use arrow::array::{BooleanBufferBuilder, Int16BufferBuilder, Int32BufferBuilder}; - use arrow::bitmap::Bitmap; - use std::sync::Arc; + + use super::RecordReader; struct TestPageReader { - pages: Box>, + pages: Box + Send>, } impl TestPageReader { @@ -624,15 +553,6 @@ mod tests { assert_eq!(7, record_reader.num_values()); } - // Verify result record data - let mut bb = Int32BufferBuilder::new(7); - bb.append_slice(&[0, 7, 0, 6, 3, 0, 8]); - let expected_buffer = bb.finish(); - assert_eq!( - expected_buffer, - record_reader.consume_record_data().unwrap() - ); - // Verify result def levels let mut bb = Int16BufferBuilder::new(7); bb.append_slice(&[1i16, 2i16, 0i16, 2i16, 2i16, 0i16, 2i16]); @@ -643,13 +563,28 @@ mod tests { ); // Verify bitmap - let mut bb = BooleanBufferBuilder::new(7); - bb.append_slice(&[false, true, false, true, true, false, true]); - let expected_bitmap = Bitmap::from(bb.finish()); + let expected_valid = &[false, true, false, true, true, false, true]; + let expected_buffer = Buffer::from_iter(expected_valid.iter().cloned()); + let expected_bitmap = Bitmap::from(expected_buffer); assert_eq!( Some(expected_bitmap), record_reader.consume_bitmap().unwrap() ); + + // Verify result record data + let actual = record_reader.consume_record_data().unwrap(); + let actual_values = actual.typed_data::(); + + let expected = &[0, 7, 0, 6, 3, 0, 8]; + assert_eq!(actual_values.len(), expected.len()); + + // Only validate valid values are equal + let iter = expected_valid.iter().zip(actual_values).zip(expected); + for ((valid, actual), expected) in iter { + if *valid { + assert_eq!(actual, expected) + } + } } #[test] @@ -732,15 +667,6 @@ mod tests { assert_eq!(9, record_reader.num_values()); } - // Verify result record data - let mut bb = Int32BufferBuilder::new(9); - bb.append_slice(&[4, 0, 0, 7, 6, 3, 2, 8, 9]); - let expected_buffer = bb.finish(); - assert_eq!( - expected_buffer, - record_reader.consume_record_data().unwrap() - ); - // Verify result def levels let mut bb = Int16BufferBuilder::new(9); bb.append_slice(&[2i16, 0i16, 1i16, 2i16, 2i16, 2i16, 2i16, 2i16, 2i16]); @@ -751,13 +677,27 @@ mod tests { ); // Verify bitmap - let mut bb = BooleanBufferBuilder::new(9); - bb.append_slice(&[true, false, false, true, true, true, true, true, true]); - let expected_bitmap = Bitmap::from(bb.finish()); + let expected_valid = &[true, false, false, true, true, true, true, true, true]; + let expected_buffer = Buffer::from_iter(expected_valid.iter().cloned()); + let expected_bitmap = Bitmap::from(expected_buffer); assert_eq!( Some(expected_bitmap), record_reader.consume_bitmap().unwrap() ); + + // Verify result record data + let actual = record_reader.consume_record_data().unwrap(); + let actual_values = actual.typed_data::(); + let expected = &[4, 0, 0, 7, 6, 3, 2, 8, 9]; + assert_eq!(actual_values.len(), expected.len()); + + // Only validate valid values are 
equal + let iter = expected_valid.iter().zip(actual_values).zip(expected); + for ((valid, actual), expected) in iter { + if *valid { + assert_eq!(actual, expected) + } + } } #[test] diff --git a/parquet/src/arrow/schema.rs b/parquet/src/arrow/schema.rs index 5fe94cef94db..f3d0a3d9b36b 100644 --- a/parquet/src/arrow/schema.rs +++ b/parquet/src/arrow/schema.rs @@ -23,157 +23,67 @@ //! //! The interfaces for converting arrow schema to parquet schema is coming. -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::sync::Arc; -use arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}; +use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use arrow::ipc::writer; -use crate::errors::{ParquetError::ArrowError, Result}; +use crate::basic::{ + ConvertedType, LogicalType, Repetition, TimeUnit as ParquetTimeUnit, + Type as PhysicalType, +}; +use crate::errors::{ParquetError, Result}; use crate::file::{metadata::KeyValue, properties::WriterProperties}; use crate::schema::types::{ColumnDescriptor, SchemaDescriptor, Type, TypePtr}; -use crate::{ - basic::{ - ConvertedType, DecimalType, IntType, LogicalType, Repetition, TimeType, - TimeUnit as ParquetTimeUnit, TimestampType, Type as PhysicalType, - }, - errors::ParquetError, -}; + +mod complex; +mod primitive; + +use crate::arrow::ProjectionMask; +pub(crate) use complex::{convert_schema, ParquetField, ParquetFieldType}; /// Convert Parquet schema to Arrow schema including optional metadata. /// Attempts to decode any existing Arrow schema metadata, falling back /// to converting the Parquet schema column-wise pub fn parquet_to_arrow_schema( parquet_schema: &SchemaDescriptor, - key_value_metadata: &Option>, + key_value_metadata: Option<&Vec>, ) -> Result { - let mut metadata = parse_key_value_metadata(key_value_metadata).unwrap_or_default(); - metadata - .remove(super::ARROW_SCHEMA_META_KEY) - .map(|encoded| get_arrow_schema_from_metadata(&encoded)) - .unwrap_or(parquet_to_arrow_schema_by_columns( - parquet_schema, - 0..parquet_schema.columns().len(), - key_value_metadata, - )) -} - -/// Convert parquet schema to arrow schema including optional metadata, -/// only preserving some root columns. -/// This is useful if we have columns `a.b`, `a.c.e` and `a.d`, -/// and want `a` with all its child fields -pub fn parquet_to_arrow_schema_by_root_columns( - parquet_schema: &SchemaDescriptor, - column_indices: T, - key_value_metadata: &Option>, -) -> Result -where - T: IntoIterator, -{ - // Reconstruct the index ranges of the parent columns - // An Arrow struct gets represented by 1+ columns based on how many child fields the - // struct has. This means that getting fields 1 and 2 might return the struct twice, - // if field 1 is the struct having say 3 fields, and field 2 is a primitive. 
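// Editorial usage sketch, not part of this patch: with the reworked conversion API
// in this file, column selection goes through `ProjectionMask` rather than raw
// column indices. `parquet_schema` below is a placeholder `SchemaDescriptor`:
//
//     use crate::arrow::ProjectionMask;
//
//     // convert the full schema, with no key-value metadata
//     let full = parquet_to_arrow_schema(&parquet_schema, None)?;
//
//     // convert only leaf columns 0 and 3
//     let mask = ProjectionMask::leaves(&parquet_schema, [0, 3]);
//     let partial = parquet_to_arrow_schema_by_columns(&parquet_schema, mask, None)?;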
- // - // The below gets the parent columns, and counts the number of child fields in each parent, - // such that we would end up with: - // - field 1 - columns: [0, 1, 2] - // - field 2 - columns: [3] - let mut parent_columns = vec![]; - let mut curr_name = ""; - let mut prev_name = ""; - let mut indices = vec![]; - (0..(parquet_schema.num_columns())).for_each(|i| { - let p_type = parquet_schema.get_column_root(i); - curr_name = p_type.get_basic_info().name(); - if prev_name.is_empty() { - // first index - indices.push(i); - prev_name = curr_name; - } else if curr_name != prev_name { - prev_name = curr_name; - parent_columns.push((curr_name.to_string(), indices.clone())); - indices = vec![i]; - } else { - indices.push(i); - } - }); - // push the last column if indices has values - if !indices.is_empty() { - parent_columns.push((curr_name.to_string(), indices)); - } - - // gather the required leaf columns - let leaf_columns = column_indices - .into_iter() - .flat_map(|i| parent_columns[i].1.clone()); - - parquet_to_arrow_schema_by_columns(parquet_schema, leaf_columns, key_value_metadata) + parquet_to_arrow_schema_by_columns( + parquet_schema, + ProjectionMask::all(), + key_value_metadata, + ) } /// Convert parquet schema to arrow schema including optional metadata, /// only preserving some leaf columns. -pub fn parquet_to_arrow_schema_by_columns( +pub fn parquet_to_arrow_schema_by_columns( parquet_schema: &SchemaDescriptor, - column_indices: T, - key_value_metadata: &Option>, -) -> Result -where - T: IntoIterator, -{ + mask: ProjectionMask, + key_value_metadata: Option<&Vec>, +) -> Result { let mut metadata = parse_key_value_metadata(key_value_metadata).unwrap_or_default(); - let arrow_schema_metadata = metadata + let maybe_schema = metadata .remove(super::ARROW_SCHEMA_META_KEY) - .map(|encoded| get_arrow_schema_from_metadata(&encoded)) - .map_or(Ok(None), |v| v.map(Some))?; + .map(|value| get_arrow_schema_from_metadata(&value)) + .transpose()?; - // add the Arrow metadata to the Parquet metadata - if let Some(arrow_schema) = &arrow_schema_metadata { + // Add the Arrow metadata to the Parquet metadata skipping keys that collide + if let Some(arrow_schema) = &maybe_schema { arrow_schema.metadata().iter().for_each(|(k, v)| { - metadata.insert(k.clone(), v.clone()); + metadata.entry(k.clone()).or_insert(v.clone()); }); } - let mut base_nodes = Vec::new(); - let mut base_nodes_set = HashSet::new(); - let mut leaves = HashSet::new(); - - enum FieldType<'a> { - Parquet(&'a Type), - Arrow(Field), - } - - for c in column_indices { - let column = parquet_schema.column(c); - let name = column.name(); - - if let Some(field) = arrow_schema_metadata - .as_ref() - .and_then(|schema| schema.field_with_name(name).ok().cloned()) - { - base_nodes.push(FieldType::Arrow(field)); - } else { - let column = column.self_type() as *const Type; - let root = parquet_schema.get_column_root(c); - let root_raw_ptr = root as *const Type; - - leaves.insert(column); - if !base_nodes_set.contains(&root_raw_ptr) { - base_nodes.push(FieldType::Parquet(root)); - base_nodes_set.insert(root_raw_ptr); - } - } + match convert_schema(parquet_schema, mask, maybe_schema.as_ref())? 
{ + Some(field) => match field.arrow_type { + DataType::Struct(fields) => Ok(Schema::new_with_metadata(fields, metadata)), + _ => unreachable!(), + }, + None => Ok(Schema::new_with_metadata(vec![], metadata)), } - - base_nodes - .into_iter() - .map(|t| match t { - FieldType::Parquet(t) => ParquetTypeConverter::new(t, &leaves).to_field(), - FieldType::Arrow(f) => Ok(Some(f)), - }) - .collect::>>>() - .map(|result| result.into_iter().flatten().collect::>()) - .map(|fields| Schema::new_with_metadata(fields, metadata)) } /// Try to convert Arrow schema metadata into a schema @@ -190,24 +100,24 @@ fn get_arrow_schema_from_metadata(encoded_meta: &str) -> Result { Ok(message) => message .header_as_schema() .map(arrow::ipc::convert::fb_to_schema) - .ok_or(ArrowError("the message is not Arrow Schema".to_string())), + .ok_or(arrow_err!("the message is not Arrow Schema")), Err(err) => { // The flatbuffers implementation returns an error on verification error. - Err(ArrowError(format!( + Err(arrow_err!( "Unable to get root as message stored in {}: {:?}", super::ARROW_SCHEMA_META_KEY, err - ))) + )) } } } Err(err) => { // The C++ implementation returns an error if the schema can't be parsed. - Err(ArrowError(format!( + Err(arrow_err!( "Unable to decode the encoded schema stored in {}, {:?}", super::ARROW_SCHEMA_META_KEY, err - ))) + )) } } } @@ -274,7 +184,7 @@ pub fn arrow_to_parquet_schema(schema: &Schema) -> Result { } fn parse_key_value_metadata( - key_value_metadata: &Option>, + key_value_metadata: Option<&Vec>, ) -> Option> { match key_value_metadata { Some(key_values) => { @@ -299,14 +209,13 @@ fn parse_key_value_metadata( /// Convert parquet column schema to arrow field. pub fn parquet_to_arrow_field(parquet_column: &ColumnDescriptor) -> Result { - let schema = parquet_column.self_type(); - - let mut leaves = HashSet::new(); - leaves.insert(parquet_column.self_type() as *const Type); + let field = complex::convert_type(&parquet_column.self_type_ptr())?; - ParquetTypeConverter::new(schema, &leaves) - .to_field() - .map(|opt| opt.unwrap()) + Ok(Field::new( + parquet_column.name(), + field.arrow_type, + field.nullable, + )) } pub fn decimal_length_from_precision(precision: usize) -> usize { @@ -324,24 +233,24 @@ fn arrow_to_parquet_type(field: &Field) -> Result { // create type from field match field.data_type() { DataType::Null => Type::primitive_type_builder(name, PhysicalType::INT32) - .with_logical_type(Some(LogicalType::UNKNOWN(Default::default()))) + .with_logical_type(Some(LogicalType::Unknown)) .with_repetition(repetition) .build(), DataType::Boolean => Type::primitive_type_builder(name, PhysicalType::BOOLEAN) .with_repetition(repetition) .build(), DataType::Int8 => Type::primitive_type_builder(name, PhysicalType::INT32) - .with_logical_type(Some(LogicalType::INTEGER(IntType { + .with_logical_type(Some(LogicalType::Integer { bit_width: 8, is_signed: true, - }))) + })) .with_repetition(repetition) .build(), DataType::Int16 => Type::primitive_type_builder(name, PhysicalType::INT32) - .with_logical_type(Some(LogicalType::INTEGER(IntType { + .with_logical_type(Some(LogicalType::Integer { bit_width: 16, is_signed: true, - }))) + })) .with_repetition(repetition) .build(), DataType::Int32 => Type::primitive_type_builder(name, PhysicalType::INT32) @@ -351,85 +260,105 @@ fn arrow_to_parquet_type(field: &Field) -> Result { .with_repetition(repetition) .build(), DataType::UInt8 => Type::primitive_type_builder(name, PhysicalType::INT32) - .with_logical_type(Some(LogicalType::INTEGER(IntType { + 
.with_logical_type(Some(LogicalType::Integer { bit_width: 8, is_signed: false, - }))) + })) .with_repetition(repetition) .build(), DataType::UInt16 => Type::primitive_type_builder(name, PhysicalType::INT32) - .with_logical_type(Some(LogicalType::INTEGER(IntType { + .with_logical_type(Some(LogicalType::Integer { bit_width: 16, is_signed: false, - }))) + })) .with_repetition(repetition) .build(), DataType::UInt32 => Type::primitive_type_builder(name, PhysicalType::INT32) - .with_logical_type(Some(LogicalType::INTEGER(IntType { + .with_logical_type(Some(LogicalType::Integer { bit_width: 32, is_signed: false, - }))) + })) .with_repetition(repetition) .build(), DataType::UInt64 => Type::primitive_type_builder(name, PhysicalType::INT64) - .with_logical_type(Some(LogicalType::INTEGER(IntType { + .with_logical_type(Some(LogicalType::Integer { bit_width: 64, is_signed: false, - }))) + })) .with_repetition(repetition) .build(), - DataType::Float16 => Err(ArrowError("Float16 arrays not supported".to_string())), + DataType::Float16 => Err(arrow_err!("Float16 arrays not supported")), DataType::Float32 => Type::primitive_type_builder(name, PhysicalType::FLOAT) .with_repetition(repetition) .build(), DataType::Float64 => Type::primitive_type_builder(name, PhysicalType::DOUBLE) .with_repetition(repetition) .build(), - DataType::Timestamp(time_unit, zone) => Type::primitive_type_builder( - name, - PhysicalType::INT64, - ) - .with_logical_type(Some(LogicalType::TIMESTAMP(TimestampType { - is_adjusted_to_u_t_c: matches!(zone, Some(z) if !z.as_str().is_empty()), - unit: match time_unit { - TimeUnit::Second => ParquetTimeUnit::MILLIS(Default::default()), - TimeUnit::Millisecond => ParquetTimeUnit::MILLIS(Default::default()), - TimeUnit::Microsecond => ParquetTimeUnit::MICROS(Default::default()), - TimeUnit::Nanosecond => ParquetTimeUnit::NANOS(Default::default()), - }, - }))) - .with_repetition(repetition) - .build(), + DataType::Timestamp(TimeUnit::Second, _) => { + // Cannot represent seconds in LogicalType + Type::primitive_type_builder(name, PhysicalType::INT64) + .with_repetition(repetition) + .build() + } + DataType::Timestamp(time_unit, _) => { + Type::primitive_type_builder(name, PhysicalType::INT64) + .with_logical_type(Some(LogicalType::Timestamp { + is_adjusted_to_u_t_c: false, + unit: match time_unit { + TimeUnit::Second => unreachable!(), + TimeUnit::Millisecond => { + ParquetTimeUnit::MILLIS(Default::default()) + } + TimeUnit::Microsecond => { + ParquetTimeUnit::MICROS(Default::default()) + } + TimeUnit::Nanosecond => { + ParquetTimeUnit::NANOS(Default::default()) + } + }, + })) + .with_repetition(repetition) + .build() + } DataType::Date32 => Type::primitive_type_builder(name, PhysicalType::INT32) - .with_logical_type(Some(LogicalType::DATE(Default::default()))) + .with_logical_type(Some(LogicalType::Date)) .with_repetition(repetition) .build(), - // date64 is cast to date32 + // date64 is cast to date32 (#1666) DataType::Date64 => Type::primitive_type_builder(name, PhysicalType::INT32) - .with_logical_type(Some(LogicalType::DATE(Default::default()))) + .with_logical_type(Some(LogicalType::Date)) .with_repetition(repetition) .build(), - DataType::Time32(_) => Type::primitive_type_builder(name, PhysicalType::INT32) - .with_logical_type(Some(LogicalType::TIME(TimeType { + DataType::Time32(TimeUnit::Second) => { + // Cannot represent seconds in LogicalType + Type::primitive_type_builder(name, PhysicalType::INT32) + .with_repetition(repetition) + .build() + } + DataType::Time32(unit) => 
Type::primitive_type_builder(name, PhysicalType::INT32) + .with_logical_type(Some(LogicalType::Time { is_adjusted_to_u_t_c: false, - unit: ParquetTimeUnit::MILLIS(Default::default()), - }))) + unit: match unit { + TimeUnit::Millisecond => ParquetTimeUnit::MILLIS(Default::default()), + u => unreachable!("Invalid unit for Time32: {:?}", u), + }, + })) .with_repetition(repetition) .build(), DataType::Time64(unit) => Type::primitive_type_builder(name, PhysicalType::INT64) - .with_logical_type(Some(LogicalType::TIME(TimeType { + .with_logical_type(Some(LogicalType::Time { is_adjusted_to_u_t_c: false, unit: match unit { TimeUnit::Microsecond => ParquetTimeUnit::MICROS(Default::default()), TimeUnit::Nanosecond => ParquetTimeUnit::NANOS(Default::default()), u => unreachable!("Invalid unit for Time64: {:?}", u), }, - }))) + })) .with_repetition(repetition) .build(), - DataType::Duration(_) => Err(ArrowError( - "Converting Duration to parquet not supported".to_string(), - )), + DataType::Duration(_) => { + Err(arrow_err!("Converting Duration to parquet not supported",)) + } DataType::Interval(_) => { Type::primitive_type_builder(name, PhysicalType::FIXED_LEN_BYTE_ARRAY) .with_converted_type(ConvertedType::INTERVAL) @@ -465,17 +394,17 @@ fn arrow_to_parquet_type(field: &Field) -> Result { Type::primitive_type_builder(name, PhysicalType::FIXED_LEN_BYTE_ARRAY) .with_repetition(repetition) .with_length(decimal_length_from_precision(*precision) as i32) - .with_logical_type(Some(LogicalType::DECIMAL(DecimalType { + .with_logical_type(Some(LogicalType::Decimal { scale: *scale as i32, precision: *precision as i32, - }))) + })) .with_precision(*precision as i32) .with_scale(*scale as i32) .build() } DataType::Utf8 | DataType::LargeUtf8 => { Type::primitive_type_builder(name, PhysicalType::BYTE_ARRAY) - .with_logical_type(Some(LogicalType::STRING(Default::default()))) + .with_logical_type(Some(LogicalType::String)) .with_repetition(repetition) .build() } @@ -487,15 +416,15 @@ fn arrow_to_parquet_type(field: &Field) -> Result { .with_repetition(Repetition::REPEATED) .build()?, )]) - .with_logical_type(Some(LogicalType::LIST(Default::default()))) + .with_logical_type(Some(LogicalType::List)) .with_repetition(repetition) .build() } DataType::Struct(fields) => { if fields.is_empty() { - return Err(ArrowError( - "Parquet does not support writing empty structs".to_string(), - )); + return Err( + arrow_err!("Parquet does not support writing empty structs",), + ); } // recursively convert children to types/nodes let fields: Result> = fields @@ -527,16 +456,16 @@ fn arrow_to_parquet_type(field: &Field) -> Result { .with_repetition(Repetition::REPEATED) .build()?, )]) - .with_logical_type(Some(LogicalType::MAP(Default::default()))) + .with_logical_type(Some(LogicalType::Map)) .with_repetition(repetition) .build() } else { - Err(ArrowError( - "DataType::Map should contain a struct field child".to_string(), + Err(arrow_err!( + "DataType::Map should contain a struct field child", )) } } - DataType::Union(_) => unimplemented!("See ARROW-8817."), + DataType::Union(_, _, _) => unimplemented!("See ARROW-8817."), DataType::Dictionary(_, ref value) => { // Dictionary encoding not handled at the schema level let dict_field = Field::new(name, *value.clone(), field.is_nullable()); @@ -544,507 +473,19 @@ fn arrow_to_parquet_type(field: &Field) -> Result { } } } -/// This struct is used to group methods and data structures used to convert parquet -/// schema together. 
-struct ParquetTypeConverter<'a> { - schema: &'a Type, - /// This is the columns that need to be converted to arrow schema. - columns_to_convert: &'a HashSet<*const Type>, -} - -impl<'a> ParquetTypeConverter<'a> { - fn new(schema: &'a Type, columns_to_convert: &'a HashSet<*const Type>) -> Self { - Self { - schema, - columns_to_convert, - } - } - - fn clone_with_schema(&self, other: &'a Type) -> Self { - Self { - schema: other, - columns_to_convert: self.columns_to_convert, - } - } -} - -impl ParquetTypeConverter<'_> { - // Public interfaces. - - /// Converts parquet schema to arrow data type. - /// - /// This function discards schema name. - /// - /// If this schema is a primitive type and not included in the leaves, the result is - /// Ok(None). - /// - /// If this schema is a group type and none of its children is reserved in the - /// conversion, the result is Ok(None). - fn to_data_type(&self) -> Result> { - match self.schema { - Type::PrimitiveType { .. } => self.to_primitive_type(), - Type::GroupType { .. } => self.to_group_type(), - } - } - - /// Converts parquet schema to arrow field. - /// - /// This method is roughly the same as - /// [`to_data_type`](`ParquetTypeConverter::to_data_type`), except it reserves schema - /// name. - fn to_field(&self) -> Result> { - self.to_data_type().map(|opt| { - opt.map(|dt| Field::new(self.schema.name(), dt, self.is_nullable())) - }) - } - - // Utility functions. - - /// Checks whether this schema is nullable. - fn is_nullable(&self) -> bool { - let basic_info = self.schema.get_basic_info(); - if basic_info.has_repetition() { - match basic_info.repetition() { - Repetition::OPTIONAL => true, - Repetition::REPEATED => true, - Repetition::REQUIRED => false, - } - } else { - false - } - } - - fn is_repeated(&self) -> bool { - let basic_info = self.schema.get_basic_info(); - - basic_info.has_repetition() && basic_info.repetition() == Repetition::REPEATED - } - - fn is_self_included(&self) -> bool { - self.columns_to_convert - .contains(&(self.schema as *const Type)) - } - - // Functions for primitive types. - - /// Entry point for converting parquet primitive type to arrow type. - /// - /// This function takes care of repetition. - fn to_primitive_type(&self) -> Result> { - if self.is_self_included() { - self.to_primitive_type_inner().map(|dt| { - if self.is_repeated() { - Some(DataType::List(Box::new(Field::new( - self.schema.name(), - dt, - self.is_nullable(), - )))) - } else { - Some(dt) - } - }) - } else { - Ok(None) - } - } - - /// Converting parquet primitive type to arrow data type. 
- fn to_primitive_type_inner(&self) -> Result { - match self.schema.get_physical_type() { - PhysicalType::BOOLEAN => Ok(DataType::Boolean), - PhysicalType::INT32 => self.from_int32(), - PhysicalType::INT64 => self.from_int64(), - PhysicalType::INT96 => Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)), - PhysicalType::FLOAT => Ok(DataType::Float32), - PhysicalType::DOUBLE => Ok(DataType::Float64), - PhysicalType::BYTE_ARRAY => self.from_byte_array(), - PhysicalType::FIXED_LEN_BYTE_ARRAY => self.from_fixed_len_byte_array(), - } - } - - fn from_int32(&self) -> Result { - match ( - self.schema.get_basic_info().logical_type(), - self.schema.get_basic_info().converted_type(), - ) { - (None, ConvertedType::NONE) => Ok(DataType::Int32), - (Some(LogicalType::INTEGER(t)), _) => match (t.bit_width, t.is_signed) { - (8, true) => Ok(DataType::Int8), - (16, true) => Ok(DataType::Int16), - (32, true) => Ok(DataType::Int32), - (8, false) => Ok(DataType::UInt8), - (16, false) => Ok(DataType::UInt16), - (32, false) => Ok(DataType::UInt32), - _ => Err(ArrowError(format!( - "Cannot create INT32 physical type from {:?}", - t - ))), - }, - (Some(LogicalType::DECIMAL(_)), _) => Ok(self.to_decimal()), - (Some(LogicalType::DATE(_)), _) => Ok(DataType::Date32), - (Some(LogicalType::TIME(t)), _) => match t.unit { - ParquetTimeUnit::MILLIS(_) => Ok(DataType::Time32(TimeUnit::Millisecond)), - _ => Err(ArrowError(format!( - "Cannot create INT32 physical type from {:?}", - t.unit - ))), - }, - (None, ConvertedType::UINT_8) => Ok(DataType::UInt8), - (None, ConvertedType::UINT_16) => Ok(DataType::UInt16), - (None, ConvertedType::UINT_32) => Ok(DataType::UInt32), - (None, ConvertedType::INT_8) => Ok(DataType::Int8), - (None, ConvertedType::INT_16) => Ok(DataType::Int16), - (None, ConvertedType::INT_32) => Ok(DataType::Int32), - (None, ConvertedType::DATE) => Ok(DataType::Date32), - (None, ConvertedType::TIME_MILLIS) => { - Ok(DataType::Time32(TimeUnit::Millisecond)) - } - (None, ConvertedType::DECIMAL) => Ok(self.to_decimal()), - (logical, converted) => Err(ArrowError(format!( - "Unable to convert parquet INT32 logical type {:?} or converted type {}", - logical, converted - ))), - } - } - - fn from_int64(&self) -> Result { - match ( - self.schema.get_basic_info().logical_type(), - self.schema.get_basic_info().converted_type(), - ) { - (None, ConvertedType::NONE) => Ok(DataType::Int64), - (Some(LogicalType::INTEGER(t)), _) if t.bit_width == 64 => { - match t.is_signed { - true => Ok(DataType::Int64), - false => Ok(DataType::UInt64), - } - } - (Some(LogicalType::TIME(t)), _) => match t.unit { - ParquetTimeUnit::MILLIS(_) => Err(ArrowError( - "Cannot create INT64 from MILLIS time unit".to_string(), - )), - ParquetTimeUnit::MICROS(_) => Ok(DataType::Time64(TimeUnit::Microsecond)), - ParquetTimeUnit::NANOS(_) => Ok(DataType::Time64(TimeUnit::Nanosecond)), - }, - (Some(LogicalType::TIMESTAMP(t)), _) => Ok(DataType::Timestamp( - match t.unit { - ParquetTimeUnit::MILLIS(_) => TimeUnit::Millisecond, - ParquetTimeUnit::MICROS(_) => TimeUnit::Microsecond, - ParquetTimeUnit::NANOS(_) => TimeUnit::Nanosecond, - }, - if t.is_adjusted_to_u_t_c { - Some("UTC".to_string()) - } else { - None - }, - )), - (None, ConvertedType::INT_64) => Ok(DataType::Int64), - (None, ConvertedType::UINT_64) => Ok(DataType::UInt64), - (None, ConvertedType::TIME_MICROS) => { - Ok(DataType::Time64(TimeUnit::Microsecond)) - } - (None, ConvertedType::TIMESTAMP_MILLIS) => { - Ok(DataType::Timestamp(TimeUnit::Millisecond, None)) - } - (None, 
ConvertedType::TIMESTAMP_MICROS) => { - Ok(DataType::Timestamp(TimeUnit::Microsecond, None)) - } - (Some(LogicalType::DECIMAL(_)), _) => Ok(self.to_decimal()), - (None, ConvertedType::DECIMAL) => Ok(self.to_decimal()), - (logical, converted) => Err(ArrowError(format!( - "Unable to convert parquet INT64 logical type {:?} or converted type {}", - logical, converted - ))), - } - } - - fn from_fixed_len_byte_array(&self) -> Result { - match ( - self.schema.get_basic_info().logical_type(), - self.schema.get_basic_info().converted_type(), - ) { - (Some(LogicalType::DECIMAL(_)), _) => Ok(self.to_decimal()), - (None, ConvertedType::DECIMAL) => Ok(self.to_decimal()), - (None, ConvertedType::INTERVAL) => { - // There is currently no reliable way of determining which IntervalUnit - // to return. Thus without the original Arrow schema, the results - // would be incorrect if all 12 bytes of the interval are populated - Ok(DataType::Interval(IntervalUnit::DayTime)) - } - _ => { - let byte_width = match self.schema { - Type::PrimitiveType { - ref type_length, .. - } => *type_length, - _ => { - return Err(ArrowError( - "Expected a physical type, not a group type".to_string(), - )) - } - }; - - Ok(DataType::FixedSizeBinary(byte_width)) - } - } - } - - fn to_decimal(&self) -> DataType { - assert!(self.schema.is_primitive()); - DataType::Decimal( - self.schema.get_precision() as usize, - self.schema.get_scale() as usize, - ) - } - - fn from_byte_array(&self) -> Result { - match (self.schema.get_basic_info().logical_type(), self.schema.get_basic_info().converted_type()) { - (Some(LogicalType::STRING(_)), _) => Ok(DataType::Utf8), - (Some(LogicalType::JSON(_)), _) => Ok(DataType::Binary), - (Some(LogicalType::BSON(_)), _) => Ok(DataType::Binary), - (Some(LogicalType::ENUM(_)), _) => Ok(DataType::Binary), - (None, ConvertedType::NONE) => Ok(DataType::Binary), - (None, ConvertedType::JSON) => Ok(DataType::Binary), - (None, ConvertedType::BSON) => Ok(DataType::Binary), - (None, ConvertedType::ENUM) => Ok(DataType::Binary), - (None, ConvertedType::UTF8) => Ok(DataType::Utf8), - (logical, converted) => Err(ArrowError(format!( - "Unable to convert parquet BYTE_ARRAY logical type {:?} or converted type {}", - logical, converted - ))), - } - } - - // Functions for group types. - - /// Entry point for converting parquet group type. - /// - /// This function takes care of logical type and repetition. - fn to_group_type(&self) -> Result> { - match ( - self.schema.get_basic_info().logical_type(), - self.schema.get_basic_info().converted_type(), - ) { - (Some(LogicalType::LIST(_)), _) | (_, ConvertedType::LIST) => self.to_list(), - (Some(LogicalType::MAP(_)), _) - | (_, ConvertedType::MAP) - | (_, ConvertedType::MAP_KEY_VALUE) => self.to_map(), - (_, _) => { - if self.is_repeated() { - self.to_struct().map(|opt| { - opt.map(|dt| { - DataType::List(Box::new(Field::new( - self.schema.name(), - dt, - self.is_nullable(), - ))) - }) - }) - } else { - self.to_struct() - } - } - } - } - - /// Converts a parquet group type to arrow struct. - fn to_struct(&self) -> Result> { - match self.schema { - Type::PrimitiveType { .. 
} => Err(ParquetError::General(format!( - "{:?} is a struct type, and can't be processed as primitive.", - self.schema - ))), - Type::GroupType { - basic_info: _, - fields, - } => fields - .iter() - .map(|field_ptr| self.clone_with_schema(field_ptr).to_field()) - .collect::>>>() - .map(|result| result.into_iter().flatten().collect::>()) - .map(|fields| { - if fields.is_empty() { - None - } else { - Some(DataType::Struct(fields)) - } - }), - } - } - - /// Converts a parquet list to arrow list. - /// - /// To fully understand this algorithm, please refer to - /// [parquet doc](https://github.com/apache/parquet-format/blob/master/LogicalTypes.md). - fn to_list(&self) -> Result> { - match self.schema { - Type::PrimitiveType { .. } => Err(ParquetError::General(format!( - "{:?} is a list type and can't be processed as primitive.", - self.schema - ))), - Type::GroupType { - basic_info: _, - fields, - } if fields.len() == 1 => { - let list_item = fields.first().unwrap(); - let item_converter = self.clone_with_schema(list_item); - - let item_type = match list_item.as_ref() { - Type::PrimitiveType { .. } => { - if item_converter.is_repeated() { - item_converter.to_primitive_type_inner().map(Some) - } else { - Err(ArrowError( - "Primitive element type of list must be repeated." - .to_string(), - )) - } - } - Type::GroupType { - basic_info: _, - fields, - } => { - if fields.len() > 1 { - item_converter.to_struct() - } else if fields.len() == 1 - && list_item.name() != "array" - && list_item.name() != format!("{}_tuple", self.schema.name()) - { - let nested_item = fields.first().unwrap(); - let nested_item_converter = - self.clone_with_schema(nested_item); - - nested_item_converter.to_data_type() - } else { - item_converter.to_struct() - } - } - }; - - // Check that the name of the list child is "list", in which case we - // get the child nullability and name (normally "element") from the nested - // group type. - // Without this step, the child incorrectly inherits the parent's optionality - let (list_item_name, item_is_optional) = match &item_converter.schema { - Type::GroupType { basic_info, fields } - if basic_info.name() == "list" && fields.len() == 1 => - { - let field = fields.first().unwrap(); - (field.name(), field.is_optional()) - } - _ => (list_item.name(), list_item.is_optional()), - }; - - item_type.map(|opt| { - opt.map(|dt| { - DataType::List(Box::new(Field::new( - list_item_name, - dt, - item_is_optional, - ))) - }) - }) - } - _ => Err(ArrowError( - "Group element type of list can only contain one field.".to_string(), - )), - } - } - - /// Converts a parquet map to arrow map. - /// - /// To fully understand this algorithm, please refer to - /// [parquet doc](https://github.com/apache/parquet-format/blob/master/LogicalTypes.md). - fn to_map(&self) -> Result> { - match self.schema { - Type::PrimitiveType { .. } => Err(ParquetError::General(format!( - "{:?} is a map type and can't be processed as primitive.", - self.schema - ))), - Type::GroupType { - basic_info: _, - fields, - } if fields.len() == 1 => { - let key_item = fields.first().unwrap(); - - let (key_type, value_type) = match key_item.as_ref() { - Type::PrimitiveType { .. } => { - return Err(ArrowError( - "A map can only have a group child type (key_values)." - .to_string(), - )) - } - Type::GroupType { - basic_info: _, - fields, - } => { - if fields.len() != 2 { - return Err(ArrowError(format!("Map type should have 2 fields, a key and value. 
Found {} fields", fields.len()))); - } else { - let nested_key = fields.first().unwrap(); - let nested_key_converter = self.clone_with_schema(nested_key); - - let nested_value = fields.last().unwrap(); - let nested_value_converter = - self.clone_with_schema(nested_value); - - ( - nested_key_converter.to_data_type()?.map(|d| { - Field::new( - nested_key.name(), - d, - nested_key.is_optional(), - ) - }), - nested_value_converter.to_data_type()?.map(|d| { - Field::new( - nested_value.name(), - d, - nested_value.is_optional(), - ) - }), - ) - } - } - }; - - match (key_type, value_type) { - (Some(key), Some(value)) => Ok(Some(DataType::Map( - Box::new(Field::new( - key_item.name(), - DataType::Struct(vec![key, value]), - false, - )), - false, // There is no information to tell if keys are sorted - ))), - (None, None) => Ok(None), - (None, Some(_)) => Err(ArrowError( - "Could not convert the map key to a valid datatype".to_string(), - )), - (Some(_), None) => Err(ArrowError( - "Could not convert the map value to a valid datatype".to_string(), - )), - } - } - _ => Err(ArrowError( - "Group element type of map can only contain one field.".to_string(), - )), - } - } -} #[cfg(test)] mod tests { use super::*; - use std::{collections::HashMap, convert::TryFrom, sync::Arc}; + use std::{collections::HashMap, sync::Arc}; use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit}; - use crate::file::{metadata::KeyValue, reader::SerializedFileReader}; + use crate::file::metadata::KeyValue; use crate::{ arrow::{ArrowReader, ArrowWriter, ParquetFileArrowReader}, schema::{parser::parse_message_type, types::SchemaDescriptor}, - util::test_common::get_temp_file, }; #[test] @@ -1068,7 +509,7 @@ mod tests { let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); let converted_arrow_schema = - parquet_to_arrow_schema(&parquet_schema, &None).unwrap(); + parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let arrow_fields = vec![ Field::new("boolean", DataType::Boolean, false), @@ -1100,7 +541,7 @@ mod tests { let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); let converted_arrow_schema = - parquet_to_arrow_schema(&parquet_schema, &None).unwrap(); + parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let arrow_fields = vec![ Field::new("binary", DataType::Binary, false), @@ -1122,7 +563,7 @@ mod tests { let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); let converted_arrow_schema = - parquet_to_arrow_schema(&parquet_schema, &None).unwrap(); + parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let arrow_fields = vec![ Field::new("boolean", DataType::Boolean, false), @@ -1132,8 +573,8 @@ mod tests { let converted_arrow_schema = parquet_to_arrow_schema_by_columns( &parquet_schema, - vec![0usize, 1usize], - &None, + ProjectionMask::all(), + None, ) .unwrap(); assert_eq!(&arrow_fields, converted_arrow_schema.fields()); @@ -1253,7 +694,7 @@ mod tests { { arrow_fields.push(Field::new( "my_list", - DataType::List(Box::new(Field::new("element", DataType::Utf8, true))), + DataType::List(Box::new(Field::new("str", DataType::Utf8, false))), true, )); } @@ -1265,7 +706,7 @@ mod tests { { arrow_fields.push(Field::new( "my_list", - DataType::List(Box::new(Field::new("element", DataType::Int32, true))), + DataType::List(Box::new(Field::new("element", DataType::Int32, false))), true, )); } @@ -1284,7 +725,7 @@ mod tests { ]); arrow_fields.push(Field::new( "my_list", - DataType::List(Box::new(Field::new("element", arrow_struct, true))), + 
DataType::List(Box::new(Field::new("element", arrow_struct, false))), true, )); } @@ -1301,7 +742,7 @@ mod tests { DataType::Struct(vec![Field::new("str", DataType::Utf8, false)]); arrow_fields.push(Field::new( "my_list", - DataType::List(Box::new(Field::new("array", arrow_struct, true))), + DataType::List(Box::new(Field::new("array", arrow_struct, false))), true, )); } @@ -1318,7 +759,11 @@ mod tests { DataType::Struct(vec![Field::new("str", DataType::Utf8, false)]); arrow_fields.push(Field::new( "my_list", - DataType::List(Box::new(Field::new("my_list_tuple", arrow_struct, true))), + DataType::List(Box::new(Field::new( + "my_list_tuple", + arrow_struct, + false, + ))), true, )); } @@ -1328,8 +773,8 @@ mod tests { { arrow_fields.push(Field::new( "name", - DataType::List(Box::new(Field::new("name", DataType::Int32, true))), - true, + DataType::List(Box::new(Field::new("name", DataType::Int32, false))), + false, )); } @@ -1337,12 +782,12 @@ mod tests { let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); let converted_arrow_schema = - parquet_to_arrow_schema(&parquet_schema, &None).unwrap(); + parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let converted_fields = converted_arrow_schema.fields(); assert_eq!(arrow_fields.len(), converted_fields.len()); for i in 0..arrow_fields.len() { - assert_eq!(arrow_fields[i], converted_fields[i]); + assert_eq!(arrow_fields[i], converted_fields[i], "{}", i); } } @@ -1416,7 +861,7 @@ mod tests { let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); let converted_arrow_schema = - parquet_to_arrow_schema(&parquet_schema, &None).unwrap(); + parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let converted_fields = converted_arrow_schema.fields(); assert_eq!(arrow_fields.len(), converted_fields.len()); @@ -1495,7 +940,7 @@ mod tests { Field::new("str", DataType::Utf8, false), Field::new("num", DataType::Int32, false), ]), - false, + false, // (#1697) )), false, ), @@ -1520,7 +965,7 @@ mod tests { Field::new("key", DataType::Utf8, false), Field::new("value", DataType::Int32, true), ]), - false, + false, // (#1697) )), false, ), @@ -1532,7 +977,7 @@ mod tests { let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); let converted_arrow_schema = - parquet_to_arrow_schema(&parquet_schema, &None).unwrap(); + parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let converted_fields = converted_arrow_schema.fields(); assert_eq!(arrow_fields.len(), converted_fields.len()); @@ -1570,7 +1015,7 @@ mod tests { let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); let converted_arrow_schema = - parquet_to_arrow_schema(&parquet_schema, &None).unwrap(); + parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let converted_fields = converted_arrow_schema.fields(); assert_eq!(arrow_fields.len(), converted_fields.len()); @@ -1619,9 +1064,9 @@ mod tests { // required int64 leaf5; let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); + let mask = ProjectionMask::leaves(&parquet_schema, [3, 0, 4, 4]); let converted_arrow_schema = - parquet_to_arrow_schema_by_columns(&parquet_schema, vec![0, 3, 4], &None) - .unwrap(); + parquet_to_arrow_schema_by_columns(&parquet_schema, mask, None).unwrap(); let converted_fields = converted_arrow_schema.fields(); assert_eq!(arrow_fields.len(), converted_fields.len()); @@ -1634,15 +1079,15 @@ mod tests { fn test_nested_schema_partial_ordering() { let mut arrow_fields = Vec::new(); { + let group1_fields = vec![Field::new("leaf1", 
DataType::Int64, false)]; + let group1 = Field::new("group1", DataType::Struct(group1_fields), false); + arrow_fields.push(group1); + let group2_fields = vec![Field::new("leaf4", DataType::Int64, false)]; let group2 = Field::new("group2", DataType::Struct(group2_fields), false); arrow_fields.push(group2); arrow_fields.push(Field::new("leaf5", DataType::Int64, false)); - - let group1_fields = vec![Field::new("leaf1", DataType::Int64, false)]; - let group1 = Field::new("group1", DataType::Struct(group1_fields), false); - arrow_fields.push(group1); } let message_type = " @@ -1670,9 +1115,9 @@ mod tests { // required int64 leaf5; let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); + let mask = ProjectionMask::leaves(&parquet_schema, [3, 0, 4]); let converted_arrow_schema = - parquet_to_arrow_schema_by_columns(&parquet_schema, vec![3, 4, 0], &None) - .unwrap(); + parquet_to_arrow_schema_by_columns(&parquet_schema, mask, None).unwrap(); let converted_fields = converted_arrow_schema.fields(); assert_eq!(arrow_fields.len(), converted_fields.len()); @@ -1692,9 +1137,9 @@ mod tests { DataType::List(Box::new(Field::new( "innerGroup", DataType::Struct(vec![Field::new("leaf3", DataType::Int32, true)]), - true, + false, ))), - true, + false, ); let outer_group_list = Field::new( @@ -1705,9 +1150,9 @@ mod tests { Field::new("leaf2", DataType::Int32, true), inner_group_list, ]), - true, + false, ))), - true, + false, ); arrow_fields.push(outer_group_list); } @@ -1727,7 +1172,7 @@ mod tests { let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); let converted_arrow_schema = - parquet_to_arrow_schema(&parquet_schema, &None).unwrap(); + parquet_to_arrow_schema(&parquet_schema, None).unwrap(); let converted_fields = converted_arrow_schema.fields(); assert_eq!(arrow_fields.len(), converted_fields.len()); @@ -1782,8 +1227,8 @@ mod tests { Field::new("string", DataType::Utf8, true), Field::new( "bools", - DataType::List(Box::new(Field::new("bools", DataType::Boolean, true))), - true, + DataType::List(Box::new(Field::new("bools", DataType::Boolean, false))), + false, ), Field::new("date", DataType::Date32, true), Field::new("time_milli", DataType::Time32(TimeUnit::Millisecond), true), @@ -1956,16 +1401,17 @@ mod tests { "; let parquet_group_type = parse_message_type(message_type).unwrap(); - let mut key_value_metadata: Vec = Vec::new(); - key_value_metadata.push(KeyValue::new("foo".to_owned(), Some("bar".to_owned()))); - key_value_metadata.push(KeyValue::new("baz".to_owned(), None)); + let key_value_metadata = vec![ + KeyValue::new("foo".to_owned(), Some("bar".to_owned())), + KeyValue::new("baz".to_owned(), None), + ]; let mut expected_metadata: HashMap = HashMap::new(); expected_metadata.insert("foo".to_owned(), "bar".to_owned()); let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); let converted_arrow_schema = - parquet_to_arrow_schema(&parquet_schema, &Some(key_value_metadata)).unwrap(); + parquet_to_arrow_schema(&parquet_schema, Some(&key_value_metadata)).unwrap(); assert_eq!(converted_arrow_schema.metadata(), &expected_metadata); } @@ -2090,7 +1536,7 @@ mod tests { true, ), ]), - false, + false, // #1697 )), false, // fails to roundtrip keys_sorted ), @@ -2113,19 +1559,42 @@ mod tests { true, ), ]), - false, + false, // #1697 )), false, // fails to roundtrip keys_sorted ), true, ), + Field::new( + "c41", + DataType::Map( + Box::new(Field::new( + "my_entries", + DataType::Struct(vec![ + Field::new("my_key", DataType::Utf8, false), + Field::new( + 
"my_value", + DataType::List(Box::new(Field::new( + "item", + DataType::Utf8, + true, + ))), + true, + ), + ]), + false, + )), + false, // fails to roundtrip keys_sorted + ), + false, + ), ], metadata, ); // write to an empty parquet file so that schema is serialized - let file = get_temp_file("test_arrow_schema_roundtrip.parquet", &[]); - let mut writer = ArrowWriter::try_new( + let file = tempfile::tempfile().unwrap(); + let writer = ArrowWriter::try_new( file.try_clone().unwrap(), Arc::new(schema.clone()), None, @@ -2133,21 +1602,19 @@ mod tests { writer.close()?; // read file back - let parquet_reader = SerializedFileReader::try_from(file)?; - let mut arrow_reader = ParquetFileArrowReader::new(Arc::new(parquet_reader)); + let mut arrow_reader = ParquetFileArrowReader::try_new(file).unwrap(); let read_schema = arrow_reader.get_schema()?; assert_eq!(schema, read_schema); // read all fields by columns let partial_read_schema = - arrow_reader.get_schema_by_columns(0..(schema.fields().len()), false)?; + arrow_reader.get_schema_by_columns(ProjectionMask::all())?; assert_eq!(schema, partial_read_schema); Ok(()) } #[test] - #[ignore = "Roundtrip of lists currently fails because we don't check their types correctly in the Arrow schema"] fn test_arrow_schema_roundtrip_lists() -> Result<()> { let metadata: HashMap = [("Key".to_string(), "Value".to_string())] @@ -2195,8 +1662,8 @@ mod tests { ); // write to an empty parquet file so that schema is serialized - let file = get_temp_file("test_arrow_schema_roundtrip_lists.parquet", &[]); - let mut writer = ArrowWriter::try_new( + let file = tempfile::tempfile().unwrap(); + let writer = ArrowWriter::try_new( file.try_clone().unwrap(), Arc::new(schema.clone()), None, @@ -2204,14 +1671,13 @@ mod tests { writer.close()?; // read file back - let parquet_reader = SerializedFileReader::try_from(file)?; - let mut arrow_reader = ParquetFileArrowReader::new(Arc::new(parquet_reader)); + let mut arrow_reader = ParquetFileArrowReader::try_new(file).unwrap(); let read_schema = arrow_reader.get_schema()?; assert_eq!(schema, read_schema); // read all fields by columns let partial_read_schema = - arrow_reader.get_schema_by_columns(0..(schema.fields().len()), false)?; + arrow_reader.get_schema_by_columns(ProjectionMask::all())?; assert_eq!(schema, partial_read_schema); Ok(()) diff --git a/parquet/src/arrow/schema/complex.rs b/parquet/src/arrow/schema/complex.rs new file mode 100644 index 000000000000..d63ab5606b03 --- /dev/null +++ b/parquet/src/arrow/schema/complex.rs @@ -0,0 +1,586 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::arrow::schema::primitive::convert_primitive; +use crate::arrow::ProjectionMask; +use crate::basic::{ConvertedType, Repetition}; +use crate::errors::ParquetError; +use crate::errors::Result; +use crate::schema::types::{SchemaDescriptor, Type, TypePtr}; +use arrow::datatypes::{DataType, Field, Schema}; + +fn get_repetition(t: &Type) -> Repetition { + let info = t.get_basic_info(); + match info.has_repetition() { + true => info.repetition(), + false => Repetition::REQUIRED, + } +} + +/// Representation of a parquet file, in terms of arrow schema elements +pub struct ParquetField { + /// The level which represents an insertion into the current list + /// i.e. guaranteed to be > 0 for a list type + pub rep_level: i16, + /// The level at which this field is fully defined, + /// i.e. guaranteed to be > 0 for a nullable type + pub def_level: i16, + /// Whether this field is nullable + pub nullable: bool, + /// The arrow type of the column data + /// + /// Note: In certain cases the data stored in parquet may have been coerced + /// to a different type and will require conversion on read (e.g. Date64 and Interval) + pub arrow_type: DataType, + /// The type of this field + pub field_type: ParquetFieldType, +} + +impl ParquetField { + /// Converts `self` into an arrow list, with its current type as the field type + /// + /// This is used to convert repeated columns, into their arrow representation + fn into_list(self, name: &str) -> Self { + ParquetField { + rep_level: self.rep_level, + def_level: self.def_level, + nullable: false, + arrow_type: DataType::List(Box::new(Field::new( + name, + self.arrow_type.clone(), + false, + ))), + field_type: ParquetFieldType::Group { + children: vec![self], + }, + } + } + + /// Returns a list of [`ParquetField`] children if this is a group type + pub fn children(&self) -> Option<&[Self]> { + match &self.field_type { + ParquetFieldType::Primitive { .. 
} => None, + ParquetFieldType::Group { children } => Some(children), + } + } +} + +pub enum ParquetFieldType { + Primitive { + /// The index of the column in parquet + col_idx: usize, + /// The type of the column in parquet + primitive_type: TypePtr, + }, + Group { + children: Vec, + }, +} + +/// Encodes the context of the parent of the field currently under consideration +struct VisitorContext { + rep_level: i16, + def_level: i16, + /// An optional [`DataType`] sourced from the embedded arrow schema + data_type: Option, +} + +impl VisitorContext { + /// Compute the resulting definition level, repetition level and nullability + /// for a child field with the given [`Repetition`] + fn levels(&self, repetition: Repetition) -> (i16, i16, bool) { + match repetition { + Repetition::OPTIONAL => (self.def_level + 1, self.rep_level, true), + Repetition::REQUIRED => (self.def_level, self.rep_level, false), + Repetition::REPEATED => (self.def_level + 1, self.rep_level + 1, false), + } + } +} + +/// Walks the parquet schema in a depth-first fashion in order to map it to arrow data structures +/// +/// See [Logical Types] for more information on the conversion algorithm +/// +/// [Logical Types]: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md +struct Visitor { + /// The column index of the next leaf column + next_col_idx: usize, + + /// Mask of columns to include + mask: ProjectionMask, +} + +impl Visitor { + fn visit_primitive( + &mut self, + primitive_type: &TypePtr, + context: VisitorContext, + ) -> Result> { + let col_idx = self.next_col_idx; + self.next_col_idx += 1; + + if !self.mask.leaf_included(col_idx) { + return Ok(None); + } + + let repetition = get_repetition(primitive_type); + let (def_level, rep_level, nullable) = context.levels(repetition); + + let arrow_type = convert_primitive(primitive_type, context.data_type)?; + + let primitive_field = ParquetField { + rep_level, + def_level, + nullable, + arrow_type, + field_type: ParquetFieldType::Primitive { + primitive_type: primitive_type.clone(), + col_idx, + }, + }; + + Ok(Some(match repetition { + Repetition::REPEATED => primitive_field.into_list(primitive_type.name()), + _ => primitive_field, + })) + } + + fn visit_struct( + &mut self, + struct_type: &TypePtr, + context: VisitorContext, + ) -> Result> { + // The root type will not have a repetition level + let repetition = get_repetition(struct_type); + let (def_level, rep_level, nullable) = context.levels(repetition); + + let parquet_fields = struct_type.get_fields(); + + // Extract the arrow fields + let arrow_fields = match &context.data_type { + Some(DataType::Struct(fields)) => { + if fields.len() != parquet_fields.len() { + return Err(arrow_err!( + "incompatible arrow schema, expected {} struct fields got {}", + parquet_fields.len(), + fields.len() + )); + } + Some(fields) + } + Some(d) => { + return Err(arrow_err!( + "incompatible arrow schema, expected struct got {}", + d + )) + } + None => None, + }; + + let mut child_fields = Vec::with_capacity(parquet_fields.len()); + let mut children = Vec::with_capacity(parquet_fields.len()); + + // Perform a DFS of children + for (idx, parquet_field) in parquet_fields.iter().enumerate() { + let data_type = match arrow_fields { + Some(fields) => { + let field = &fields[idx]; + if field.name() != parquet_field.name() { + return Err(arrow_err!( + "incompatible arrow schema, expected field named {} got {}", + parquet_field.name(), + field.name() + )); + } + Some(field.data_type().clone()) + } + None => None, + }; + + let 
arrow_field = arrow_fields.map(|x| &x[idx]); + let child_ctx = VisitorContext { + rep_level, + def_level, + data_type, + }; + + if let Some(child) = self.dispatch(parquet_field, child_ctx)? { + // The child type returned may be different from what is encoded in the arrow + // schema in the event of a mismatch or a projection + child_fields.push(convert_field(parquet_field, &child, arrow_field)); + children.push(child); + } + } + + if children.is_empty() { + return Ok(None); + } + + let struct_field = ParquetField { + rep_level, + def_level, + nullable, + arrow_type: DataType::Struct(child_fields), + field_type: ParquetFieldType::Group { children }, + }; + + Ok(Some(match repetition { + Repetition::REPEATED => struct_field.into_list(struct_type.name()), + _ => struct_field, + })) + } + + fn visit_map( + &mut self, + map_type: &TypePtr, + context: VisitorContext, + ) -> Result> { + let rep_level = context.rep_level + 1; + let (def_level, nullable) = match get_repetition(map_type) { + Repetition::REQUIRED => (context.def_level + 1, false), + Repetition::OPTIONAL => (context.def_level + 2, true), + Repetition::REPEATED => return Err(arrow_err!("Map cannot be repeated")), + }; + + if map_type.get_fields().len() != 1 { + return Err(arrow_err!( + "Map field must have exactly one key_value child, found {}", + map_type.get_fields().len() + )); + } + + // Add map entry (key_value) to context + let map_key_value = &map_type.get_fields()[0]; + if map_key_value.get_basic_info().repetition() != Repetition::REPEATED { + return Err(arrow_err!("Child of map field must be repeated")); + } + + if map_key_value.get_fields().len() != 2 { + // According to the specification the values are optional (#1642) + return Err(arrow_err!( + "Child of map field must have two children, found {}", + map_key_value.get_fields().len() + )); + } + + // Get key and value, and create context for each + let map_key = &map_key_value.get_fields()[0]; + let map_value = &map_key_value.get_fields()[1]; + + if map_key.get_basic_info().repetition() != Repetition::REQUIRED { + return Err(arrow_err!("Map keys must be required")); + } + + if map_value.get_basic_info().repetition() == Repetition::REPEATED { + return Err(arrow_err!("Map values cannot be repeated")); + } + + // Extract the arrow fields + let (arrow_map, arrow_key, arrow_value, sorted) = match &context.data_type { + Some(DataType::Map(field, sorted)) => match field.data_type() { + DataType::Struct(fields) => { + if fields.len() != 2 { + return Err(arrow_err!( + "Map data type should contain struct with two children, got {}", + fields.len() + )); + } + + (Some(field), Some(&fields[0]), Some(&fields[1]), *sorted) + } + d => { + return Err(arrow_err!( + "Map data type should contain struct got {}", + d + )); + } + }, + Some(d) => { + return Err(arrow_err!( + "incompatible arrow schema, expected map got {}", + d + )) + } + None => (None, None, None, false), + }; + + let maybe_key = { + let context = VisitorContext { + rep_level, + def_level, + data_type: arrow_key.map(|x| x.data_type().clone()), + }; + + self.dispatch(map_key, context)? + }; + + let maybe_value = { + let context = VisitorContext { + rep_level, + def_level, + data_type: arrow_value.map(|x| x.data_type().clone()), + }; + + self.dispatch(map_value, context)? 
+ }; + + // Need both columns to be projected + match (maybe_key, maybe_value) { + (Some(key), Some(value)) => { + let key_field = convert_field(map_key, &key, arrow_key); + let value_field = convert_field(map_value, &value, arrow_value); + + let map_field = Field::new( + map_key_value.name(), + DataType::Struct(vec![key_field, value_field]), + false, // The inner map field is always non-nullable (#1697) + ) + .with_metadata(arrow_map.and_then(|f| f.metadata().cloned())); + + Ok(Some(ParquetField { + rep_level, + def_level, + nullable, + arrow_type: DataType::Map(Box::new(map_field), sorted), + field_type: ParquetFieldType::Group { + children: vec![key, value], + }, + })) + } + _ => Ok(None), + } + } + + fn visit_list( + &mut self, + list_type: &TypePtr, + context: VisitorContext, + ) -> Result> { + if list_type.is_primitive() { + return Err(arrow_err!( + "{:?} is a list type and can't be processed as primitive.", + list_type + )); + } + + let fields = list_type.get_fields(); + if fields.len() != 1 { + return Err(arrow_err!( + "list type must have a single child, found {}", + fields.len() + )); + } + + let repeated_field = &fields[0]; + if get_repetition(repeated_field) != Repetition::REPEATED { + return Err(arrow_err!("List child must be repeated")); + } + + // If the list is nullable + let (def_level, nullable) = match list_type.get_basic_info().repetition() { + Repetition::REQUIRED => (context.def_level, false), + Repetition::OPTIONAL => (context.def_level + 1, true), + Repetition::REPEATED => { + return Err(arrow_err!("List type cannot be repeated")) + } + }; + + let arrow_field = match &context.data_type { + Some(DataType::List(f)) => Some(f.as_ref()), + Some(DataType::LargeList(f)) => Some(f.as_ref()), + Some(DataType::FixedSizeList(f, _)) => Some(f.as_ref()), + Some(d) => { + return Err(arrow_err!( + "incompatible arrow schema, expected list got {}", + d + )) + } + None => None, + }; + + if repeated_field.is_primitive() { + // If the repeated field is not a group, then its type is the element type and elements are required. + // + // required/optional group my_list (LIST) { + // repeated int32 element; + // } + // + let context = VisitorContext { + rep_level: context.rep_level, + def_level, + data_type: arrow_field.map(|f| f.data_type().clone()), + }; + + return match self.visit_primitive(repeated_field, context) { + Ok(Some(mut field)) => { + // visit_primitive will infer a non-nullable list, update if necessary + field.nullable = nullable; + Ok(Some(field)) + } + r => r, + }; + } + + let items = repeated_field.get_fields(); + if items.len() != 1 + || repeated_field.name() == "array" + || repeated_field.name() == format!("{}_tuple", list_type.name()) + { + // If the repeated field is a group with multiple fields, then its type is the element type and elements are required. + // + // If the repeated field is a group with one field and is named either array or uses the LIST-annotated group's name + // with _tuple appended then the repeated type is the element type and elements are required. 
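// For reference, the two legacy shapes that fall through to this branch look
// roughly like the following (illustrative schemas, field names are made up):
//
//   required group my_list (LIST) {     // group with multiple fields:
//     repeated group element {          // the repeated group itself is the
//       required int32 a;               // element type, elements are required
//       required int32 b;
//     }
//   }
//
//   required group my_list (LIST) {     // single-field group named "array" or
//     repeated group my_list_tuple {    // "<list-name>_tuple": the repeated
//       required int32 a;               // group is again the element type
//     }
//   }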
+ let context = VisitorContext { + rep_level: context.rep_level, + def_level, + data_type: arrow_field.map(|f| f.data_type().clone()), + }; + + return match self.visit_struct(repeated_field, context) { + Ok(Some(mut field)) => { + field.nullable = nullable; + Ok(Some(field)) + } + r => r, + }; + } + + // Regular list handling logic + let item_type = &items[0]; + let rep_level = context.rep_level + 1; + let def_level = def_level + 1; + + let new_context = VisitorContext { + def_level, + rep_level, + data_type: arrow_field.map(|f| f.data_type().clone()), + }; + + match self.dispatch(item_type, new_context) { + Ok(Some(item)) => { + let item_field = Box::new(convert_field(item_type, &item, arrow_field)); + + // Use arrow type as hint for index size + let arrow_type = match context.data_type { + Some(DataType::LargeList(_)) => DataType::LargeList(item_field), + Some(DataType::FixedSizeList(_, len)) => { + DataType::FixedSizeList(item_field, len) + } + _ => DataType::List(item_field), + }; + + Ok(Some(ParquetField { + rep_level, + def_level, + nullable, + arrow_type, + field_type: ParquetFieldType::Group { + children: vec![item], + }, + })) + } + r => r, + } + } + + fn dispatch( + &mut self, + cur_type: &TypePtr, + context: VisitorContext, + ) -> Result> { + if cur_type.is_primitive() { + self.visit_primitive(cur_type, context) + } else { + match cur_type.get_basic_info().converted_type() { + ConvertedType::LIST => self.visit_list(cur_type, context), + ConvertedType::MAP | ConvertedType::MAP_KEY_VALUE => { + self.visit_map(cur_type, context) + } + _ => self.visit_struct(cur_type, context), + } + } + } +} + +/// Computes the [`Field`] for a child column +/// +/// The resulting [`Field`] will have the type dictated by `field`, a name +/// dictated by the `parquet_type`, and any metadata from `arrow_hint` +fn convert_field( + parquet_type: &Type, + field: &ParquetField, + arrow_hint: Option<&Field>, +) -> Field { + let name = parquet_type.name(); + let data_type = field.arrow_type.clone(); + let nullable = field.nullable; + + match arrow_hint { + Some(hint) => { + // If the inferred type is a dictionary, preserve dictionary metadata + let field = match (&data_type, hint.dict_id(), hint.dict_is_ordered()) { + (DataType::Dictionary(_, _), Some(id), Some(ordered)) => { + Field::new_dict(name, data_type, nullable, id, ordered) + } + _ => Field::new(name, data_type, nullable), + }; + + field.with_metadata(hint.metadata().cloned()) + } + None => Field::new(name, data_type, nullable), + } +} + +/// Computes the [`ParquetField`] for the provided [`SchemaDescriptor`] with `leaf_columns` listing +/// the indexes of leaf columns to project, and `embedded_arrow_schema` the optional +/// [`Schema`] embedded in the parquet metadata +/// +/// Note: This does not support out of order column projection +pub fn convert_schema( + schema: &SchemaDescriptor, + mask: ProjectionMask, + embedded_arrow_schema: Option<&Schema>, +) -> Result> { + let mut visitor = Visitor { + next_col_idx: 0, + mask, + }; + + let context = VisitorContext { + rep_level: 0, + def_level: 0, + data_type: embedded_arrow_schema.map(|s| DataType::Struct(s.fields().clone())), + }; + + visitor.dispatch(&schema.root_schema_ptr(), context) +} + +/// Computes the [`ParquetField`] for the provided `parquet_type` +pub fn convert_type(parquet_type: &TypePtr) -> Result { + let mut visitor = Visitor { + next_col_idx: 0, + mask: ProjectionMask::all(), + }; + + let context = VisitorContext { + rep_level: 0, + def_level: 0, + data_type: None, + }; + + 
Ok(visitor.dispatch(parquet_type, context)?.unwrap()) +} diff --git a/parquet/src/arrow/schema/primitive.rs b/parquet/src/arrow/schema/primitive.rs new file mode 100644 index 000000000000..0816b6b2f8e8 --- /dev/null +++ b/parquet/src/arrow/schema/primitive.rs @@ -0,0 +1,266 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::basic::{ + ConvertedType, LogicalType, TimeUnit as ParquetTimeUnit, Type as PhysicalType, +}; +use crate::errors::{ParquetError, Result}; +use crate::schema::types::{BasicTypeInfo, Type}; +use arrow::datatypes::{DataType, IntervalUnit, TimeUnit}; + +/// Converts [`Type`] to [`DataType`] with an optional `arrow_type_hint` +/// provided by the arrow schema +/// +/// Note: the values embedded in the schema are advisory, +pub fn convert_primitive( + parquet_type: &Type, + arrow_type_hint: Option, +) -> Result { + let physical_type = from_parquet(parquet_type)?; + Ok(match arrow_type_hint { + Some(hint) => apply_hint(physical_type, hint), + None => physical_type, + }) +} + +/// Uses an type hint from the embedded arrow schema to aid in faithfully +/// reproducing the data as it was written into parquet +fn apply_hint(parquet: DataType, hint: DataType) -> DataType { + match (&parquet, &hint) { + // Not all time units can be represented as LogicalType / ConvertedType + (DataType::Int32 | DataType::Int64, DataType::Timestamp(_, _)) => hint, + (DataType::Int32, DataType::Time32(_)) => hint, + (DataType::Int64, DataType::Time64(_)) => hint, + + // Date64 doesn't have a corresponding LogicalType / ConvertedType + (DataType::Int64, DataType::Date64) => hint, + + // Coerce Date32 back to Date64 (#1666) + (DataType::Date32, DataType::Date64) => hint, + + // Determine timezone + (DataType::Timestamp(p, None), DataType::Timestamp(h, Some(_))) if p == h => hint, + + // Determine offset size + (DataType::Utf8, DataType::LargeUtf8) => hint, + (DataType::Binary, DataType::LargeBinary) => hint, + + // Determine interval time unit (#1666) + (DataType::Interval(_), DataType::Interval(_)) => hint, + + // Potentially preserve dictionary encoding + (_, DataType::Dictionary(_, value)) => { + // Apply hint to inner type + let hinted = apply_hint(parquet, value.as_ref().clone()); + + // If matches dictionary value - preserve dictionary + // otherwise use hinted inner type + match &hinted == value.as_ref() { + true => hint, + false => hinted, + } + } + _ => parquet, + } +} + +fn from_parquet(parquet_type: &Type) -> Result { + match parquet_type { + Type::PrimitiveType { + physical_type, + basic_info, + type_length, + scale, + precision, + .. 
+ } => match physical_type { + PhysicalType::BOOLEAN => Ok(DataType::Boolean), + PhysicalType::INT32 => from_int32(basic_info, *scale, *precision), + PhysicalType::INT64 => from_int64(basic_info, *scale, *precision), + PhysicalType::INT96 => Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)), + PhysicalType::FLOAT => Ok(DataType::Float32), + PhysicalType::DOUBLE => Ok(DataType::Float64), + PhysicalType::BYTE_ARRAY => from_byte_array(basic_info), + PhysicalType::FIXED_LEN_BYTE_ARRAY => { + from_fixed_len_byte_array(basic_info, *scale, *precision, *type_length) + } + }, + Type::GroupType { .. } => unreachable!(), + } +} + +fn decimal_type(scale: i32, precision: i32) -> Result { + let scale = scale + .try_into() + .map_err(|_| arrow_err!("scale cannot be negative: {}", scale))?; + + let precision = precision + .try_into() + .map_err(|_| arrow_err!("precision cannot be negative: {}", precision))?; + + Ok(DataType::Decimal(precision, scale)) +} + +fn from_int32(info: &BasicTypeInfo, scale: i32, precision: i32) -> Result { + match (info.logical_type(), info.converted_type()) { + (None, ConvertedType::NONE) => Ok(DataType::Int32), + ( + Some( + ref t @ LogicalType::Integer { + bit_width, + is_signed, + }, + ), + _, + ) => match (bit_width, is_signed) { + (8, true) => Ok(DataType::Int8), + (16, true) => Ok(DataType::Int16), + (32, true) => Ok(DataType::Int32), + (8, false) => Ok(DataType::UInt8), + (16, false) => Ok(DataType::UInt16), + (32, false) => Ok(DataType::UInt32), + _ => Err(arrow_err!("Cannot create INT32 physical type from {:?}", t)), + }, + (Some(LogicalType::Decimal { scale, precision }), _) => { + decimal_type(scale, precision) + } + (Some(LogicalType::Date), _) => Ok(DataType::Date32), + (Some(LogicalType::Time { unit, .. }), _) => match unit { + ParquetTimeUnit::MILLIS(_) => Ok(DataType::Time32(TimeUnit::Millisecond)), + _ => Err(arrow_err!( + "Cannot create INT32 physical type from {:?}", + unit + )), + }, + // https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#unknown-always-null + (Some(LogicalType::Unknown), _) => Ok(DataType::Null), + (None, ConvertedType::UINT_8) => Ok(DataType::UInt8), + (None, ConvertedType::UINT_16) => Ok(DataType::UInt16), + (None, ConvertedType::UINT_32) => Ok(DataType::UInt32), + (None, ConvertedType::INT_8) => Ok(DataType::Int8), + (None, ConvertedType::INT_16) => Ok(DataType::Int16), + (None, ConvertedType::INT_32) => Ok(DataType::Int32), + (None, ConvertedType::DATE) => Ok(DataType::Date32), + (None, ConvertedType::TIME_MILLIS) => Ok(DataType::Time32(TimeUnit::Millisecond)), + (None, ConvertedType::DECIMAL) => decimal_type(scale, precision), + (logical, converted) => Err(arrow_err!( + "Unable to convert parquet INT32 logical type {:?} or converted type {}", + logical, + converted + )), + } +} + +fn from_int64(info: &BasicTypeInfo, scale: i32, precision: i32) -> Result { + match (info.logical_type(), info.converted_type()) { + (None, ConvertedType::NONE) => Ok(DataType::Int64), + ( + Some(LogicalType::Integer { + bit_width, + is_signed, + }), + _, + ) if bit_width == 64 => match is_signed { + true => Ok(DataType::Int64), + false => Ok(DataType::UInt64), + }, + (Some(LogicalType::Time { unit, .. 
}), _) => match unit { + ParquetTimeUnit::MILLIS(_) => { + Err(arrow_err!("Cannot create INT64 from MILLIS time unit",)) + } + ParquetTimeUnit::MICROS(_) => Ok(DataType::Time64(TimeUnit::Microsecond)), + ParquetTimeUnit::NANOS(_) => Ok(DataType::Time64(TimeUnit::Nanosecond)), + }, + ( + Some(LogicalType::Timestamp { + is_adjusted_to_u_t_c, + unit, + }), + _, + ) => Ok(DataType::Timestamp( + match unit { + ParquetTimeUnit::MILLIS(_) => TimeUnit::Millisecond, + ParquetTimeUnit::MICROS(_) => TimeUnit::Microsecond, + ParquetTimeUnit::NANOS(_) => TimeUnit::Nanosecond, + }, + if is_adjusted_to_u_t_c { + Some("UTC".to_string()) + } else { + None + }, + )), + (None, ConvertedType::INT_64) => Ok(DataType::Int64), + (None, ConvertedType::UINT_64) => Ok(DataType::UInt64), + (None, ConvertedType::TIME_MICROS) => Ok(DataType::Time64(TimeUnit::Microsecond)), + (None, ConvertedType::TIMESTAMP_MILLIS) => { + Ok(DataType::Timestamp(TimeUnit::Millisecond, None)) + } + (None, ConvertedType::TIMESTAMP_MICROS) => { + Ok(DataType::Timestamp(TimeUnit::Microsecond, None)) + } + (Some(LogicalType::Decimal { scale, precision }), _) => { + decimal_type(scale, precision) + } + (None, ConvertedType::DECIMAL) => decimal_type(scale, precision), + (logical, converted) => Err(arrow_err!( + "Unable to convert parquet INT64 logical type {:?} or converted type {}", + logical, + converted + )), + } +} + +fn from_byte_array(info: &BasicTypeInfo) -> Result { + match (info.logical_type(), info.converted_type()) { + (Some(LogicalType::String), _) => Ok(DataType::Utf8), + (Some(LogicalType::Json), _) => Ok(DataType::Binary), + (Some(LogicalType::Bson), _) => Ok(DataType::Binary), + (Some(LogicalType::Enum), _) => Ok(DataType::Binary), + (None, ConvertedType::NONE) => Ok(DataType::Binary), + (None, ConvertedType::JSON) => Ok(DataType::Binary), + (None, ConvertedType::BSON) => Ok(DataType::Binary), + (None, ConvertedType::ENUM) => Ok(DataType::Binary), + (None, ConvertedType::UTF8) => Ok(DataType::Utf8), + (logical, converted) => Err(arrow_err!( + "Unable to convert parquet BYTE_ARRAY logical type {:?} or converted type {}", + logical, + converted + )), + } +} + +fn from_fixed_len_byte_array( + info: &BasicTypeInfo, + scale: i32, + precision: i32, + type_length: i32, +) -> Result { + // TODO: This should check the type length for the decimal and interval types + match (info.logical_type(), info.converted_type()) { + (Some(LogicalType::Decimal { scale, precision }), _) => { + decimal_type(scale, precision) + } + (None, ConvertedType::DECIMAL) => decimal_type(scale, precision), + (None, ConvertedType::INTERVAL) => { + // There is currently no reliable way of determining which IntervalUnit + // to return. Thus without the original Arrow schema, the results + // would be incorrect if all 12 bytes of the interval are populated + Ok(DataType::Interval(IntervalUnit::DayTime)) + } + _ => Ok(DataType::FixedSizeBinary(type_length)), + } +} diff --git a/parquet/src/basic.rs b/parquet/src/basic.rs index 631257e0ed1d..59a0fe07b7de 100644 --- a/parquet/src/basic.rs +++ b/parquet/src/basic.rs @@ -41,7 +41,7 @@ pub use parquet_format::{ /// control the on disk storage format. /// For example INT16 is not included as a type since a good encoding of INT32 /// would handle this. -#[derive(Debug, Clone, Copy, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum Type { BOOLEAN, INT32, @@ -165,19 +165,31 @@ pub enum ConvertedType { /// [`ConvertedType`]. Please see the README.md for more details. 
#[derive(Debug, Clone, PartialEq)] pub enum LogicalType { - STRING(StringType), - MAP(MapType), - LIST(ListType), - ENUM(EnumType), - DECIMAL(DecimalType), - DATE(DateType), - TIME(TimeType), - TIMESTAMP(TimestampType), - INTEGER(IntType), - UNKNOWN(NullType), - JSON(JsonType), - BSON(BsonType), - UUID(UUIDType), + String, + Map, + List, + Enum, + Decimal { + scale: i32, + precision: i32, + }, + Date, + Time { + is_adjusted_to_u_t_c: bool, + unit: TimeUnit, + }, + Timestamp { + is_adjusted_to_u_t_c: bool, + unit: TimeUnit, + }, + Integer { + bit_width: i8, + is_signed: bool, + }, + Unknown, + Json, + Bson, + Uuid, } // ---------------------------------------------------------------------- @@ -250,6 +262,15 @@ pub enum Encoding { /// /// The ids are encoded using the RLE encoding. RLE_DICTIONARY, + + /// Encoding for floating-point data. + /// + /// K byte-streams are created where K is the size in bytes of the data type. + /// The individual bytes of an FP value are scattered to the corresponding stream and + /// the streams are concatenated. + /// This itself does not reduce the size of the data but can lead to better compression + /// afterwards. + BYTE_STREAM_SPLIT, } // ---------------------------------------------------------------------- @@ -326,21 +347,21 @@ impl ColumnOrder { // TODO: Should this take converted and logical type, for compatibility? match logical_type { Some(logical) => match logical { - LogicalType::STRING(_) - | LogicalType::ENUM(_) - | LogicalType::JSON(_) - | LogicalType::BSON(_) => SortOrder::UNSIGNED, - LogicalType::INTEGER(t) => match t.is_signed { + LogicalType::String + | LogicalType::Enum + | LogicalType::Json + | LogicalType::Bson => SortOrder::UNSIGNED, + LogicalType::Integer { is_signed, .. } => match is_signed { true => SortOrder::SIGNED, false => SortOrder::UNSIGNED, }, - LogicalType::MAP(_) | LogicalType::LIST(_) => SortOrder::UNDEFINED, - LogicalType::DECIMAL(_) => SortOrder::SIGNED, - LogicalType::DATE(_) => SortOrder::SIGNED, - LogicalType::TIME(_) => SortOrder::SIGNED, - LogicalType::TIMESTAMP(_) => SortOrder::SIGNED, - LogicalType::UNKNOWN(_) => SortOrder::UNDEFINED, - LogicalType::UUID(_) => SortOrder::UNSIGNED, + LogicalType::Map | LogicalType::List => SortOrder::UNDEFINED, + LogicalType::Decimal { .. } => SortOrder::SIGNED, + LogicalType::Date => SortOrder::SIGNED, + LogicalType::Time { .. } => SortOrder::SIGNED, + LogicalType::Timestamp { .. 
} => SortOrder::SIGNED, + LogicalType::Unknown => SortOrder::UNDEFINED, + LogicalType::Uuid => SortOrder::UNSIGNED, }, // Fall back to converted type None => Self::get_converted_sort_order(converted_type, physical_type), @@ -577,19 +598,31 @@ impl convert::From for Option { impl convert::From for LogicalType { fn from(value: parquet::LogicalType) -> Self { match value { - parquet::LogicalType::STRING(t) => LogicalType::STRING(t), - parquet::LogicalType::MAP(t) => LogicalType::MAP(t), - parquet::LogicalType::LIST(t) => LogicalType::LIST(t), - parquet::LogicalType::ENUM(t) => LogicalType::ENUM(t), - parquet::LogicalType::DECIMAL(t) => LogicalType::DECIMAL(t), - parquet::LogicalType::DATE(t) => LogicalType::DATE(t), - parquet::LogicalType::TIME(t) => LogicalType::TIME(t), - parquet::LogicalType::TIMESTAMP(t) => LogicalType::TIMESTAMP(t), - parquet::LogicalType::INTEGER(t) => LogicalType::INTEGER(t), - parquet::LogicalType::UNKNOWN(t) => LogicalType::UNKNOWN(t), - parquet::LogicalType::JSON(t) => LogicalType::JSON(t), - parquet::LogicalType::BSON(t) => LogicalType::BSON(t), - parquet::LogicalType::UUID(t) => LogicalType::UUID(t), + parquet::LogicalType::STRING(_) => LogicalType::String, + parquet::LogicalType::MAP(_) => LogicalType::Map, + parquet::LogicalType::LIST(_) => LogicalType::List, + parquet::LogicalType::ENUM(_) => LogicalType::Enum, + parquet::LogicalType::DECIMAL(t) => LogicalType::Decimal { + scale: t.scale, + precision: t.precision, + }, + parquet::LogicalType::DATE(_) => LogicalType::Date, + parquet::LogicalType::TIME(t) => LogicalType::Time { + is_adjusted_to_u_t_c: t.is_adjusted_to_u_t_c, + unit: t.unit, + }, + parquet::LogicalType::TIMESTAMP(t) => LogicalType::Timestamp { + is_adjusted_to_u_t_c: t.is_adjusted_to_u_t_c, + unit: t.unit, + }, + parquet::LogicalType::INTEGER(t) => LogicalType::Integer { + bit_width: t.bit_width, + is_signed: t.is_signed, + }, + parquet::LogicalType::UNKNOWN(_) => LogicalType::Unknown, + parquet::LogicalType::JSON(_) => LogicalType::Json, + parquet::LogicalType::BSON(_) => LogicalType::Bson, + parquet::LogicalType::UUID(_) => LogicalType::Uuid, } } } @@ -597,19 +630,39 @@ impl convert::From for LogicalType { impl convert::From for parquet::LogicalType { fn from(value: LogicalType) -> Self { match value { - LogicalType::STRING(t) => parquet::LogicalType::STRING(t), - LogicalType::MAP(t) => parquet::LogicalType::MAP(t), - LogicalType::LIST(t) => parquet::LogicalType::LIST(t), - LogicalType::ENUM(t) => parquet::LogicalType::ENUM(t), - LogicalType::DECIMAL(t) => parquet::LogicalType::DECIMAL(t), - LogicalType::DATE(t) => parquet::LogicalType::DATE(t), - LogicalType::TIME(t) => parquet::LogicalType::TIME(t), - LogicalType::TIMESTAMP(t) => parquet::LogicalType::TIMESTAMP(t), - LogicalType::INTEGER(t) => parquet::LogicalType::INTEGER(t), - LogicalType::UNKNOWN(t) => parquet::LogicalType::UNKNOWN(t), - LogicalType::JSON(t) => parquet::LogicalType::JSON(t), - LogicalType::BSON(t) => parquet::LogicalType::BSON(t), - LogicalType::UUID(t) => parquet::LogicalType::UUID(t), + LogicalType::String => parquet::LogicalType::STRING(Default::default()), + LogicalType::Map => parquet::LogicalType::MAP(Default::default()), + LogicalType::List => parquet::LogicalType::LIST(Default::default()), + LogicalType::Enum => parquet::LogicalType::ENUM(Default::default()), + LogicalType::Decimal { scale, precision } => { + parquet::LogicalType::DECIMAL(DecimalType { scale, precision }) + } + LogicalType::Date => parquet::LogicalType::DATE(Default::default()), + LogicalType::Time 
{ + is_adjusted_to_u_t_c, + unit, + } => parquet::LogicalType::TIME(TimeType { + is_adjusted_to_u_t_c, + unit, + }), + LogicalType::Timestamp { + is_adjusted_to_u_t_c, + unit, + } => parquet::LogicalType::TIMESTAMP(TimestampType { + is_adjusted_to_u_t_c, + unit, + }), + LogicalType::Integer { + bit_width, + is_signed, + } => parquet::LogicalType::INTEGER(IntType { + bit_width, + is_signed, + }), + LogicalType::Unknown => parquet::LogicalType::UNKNOWN(Default::default()), + LogicalType::Json => parquet::LogicalType::JSON(Default::default()), + LogicalType::Bson => parquet::LogicalType::BSON(Default::default()), + LogicalType::Uuid => parquet::LogicalType::UUID(Default::default()), } } } @@ -627,23 +680,26 @@ impl From> for ConvertedType { fn from(value: Option) -> Self { match value { Some(value) => match value { - LogicalType::STRING(_) => ConvertedType::UTF8, - LogicalType::MAP(_) => ConvertedType::MAP, - LogicalType::LIST(_) => ConvertedType::LIST, - LogicalType::ENUM(_) => ConvertedType::ENUM, - LogicalType::DECIMAL(_) => ConvertedType::DECIMAL, - LogicalType::DATE(_) => ConvertedType::DATE, - LogicalType::TIME(t) => match t.unit { + LogicalType::String => ConvertedType::UTF8, + LogicalType::Map => ConvertedType::MAP, + LogicalType::List => ConvertedType::LIST, + LogicalType::Enum => ConvertedType::ENUM, + LogicalType::Decimal { .. } => ConvertedType::DECIMAL, + LogicalType::Date => ConvertedType::DATE, + LogicalType::Time { unit, .. } => match unit { TimeUnit::MILLIS(_) => ConvertedType::TIME_MILLIS, TimeUnit::MICROS(_) => ConvertedType::TIME_MICROS, TimeUnit::NANOS(_) => ConvertedType::NONE, }, - LogicalType::TIMESTAMP(t) => match t.unit { + LogicalType::Timestamp { unit, .. } => match unit { TimeUnit::MILLIS(_) => ConvertedType::TIMESTAMP_MILLIS, TimeUnit::MICROS(_) => ConvertedType::TIMESTAMP_MICROS, TimeUnit::NANOS(_) => ConvertedType::NONE, }, - LogicalType::INTEGER(t) => match (t.bit_width, t.is_signed) { + LogicalType::Integer { + bit_width, + is_signed, + } => match (bit_width, is_signed) { (8, true) => ConvertedType::INT_8, (16, true) => ConvertedType::INT_16, (32, true) => ConvertedType::INT_32, @@ -654,10 +710,10 @@ impl From> for ConvertedType { (64, false) => ConvertedType::UINT_64, t => panic!("Integer type {:?} is not supported", t), }, - LogicalType::UNKNOWN(_) => ConvertedType::NONE, - LogicalType::JSON(_) => ConvertedType::JSON, - LogicalType::BSON(_) => ConvertedType::BSON, - LogicalType::UUID(_) => ConvertedType::NONE, + LogicalType::Unknown => ConvertedType::NONE, + LogicalType::Json => ConvertedType::JSON, + LogicalType::Bson => ConvertedType::BSON, + LogicalType::Uuid => ConvertedType::NONE, }, None => ConvertedType::NONE, } @@ -701,6 +757,7 @@ impl convert::From for Encoding { parquet::Encoding::DeltaLengthByteArray => Encoding::DELTA_LENGTH_BYTE_ARRAY, parquet::Encoding::DeltaByteArray => Encoding::DELTA_BYTE_ARRAY, parquet::Encoding::RleDictionary => Encoding::RLE_DICTIONARY, + parquet::Encoding::ByteStreamSplit => Encoding::BYTE_STREAM_SPLIT, } } } @@ -716,6 +773,7 @@ impl convert::From for parquet::Encoding { Encoding::DELTA_LENGTH_BYTE_ARRAY => parquet::Encoding::DeltaLengthByteArray, Encoding::DELTA_BYTE_ARRAY => parquet::Encoding::DeltaByteArray, Encoding::RLE_DICTIONARY => parquet::Encoding::RleDictionary, + Encoding::BYTE_STREAM_SPLIT => parquet::Encoding::ByteStreamSplit, } } } @@ -849,31 +907,31 @@ impl str::FromStr for LogicalType { fn from_str(s: &str) -> result::Result { match s { // The type is a placeholder that gets updated elsewhere - "INTEGER" 
=> Ok(LogicalType::INTEGER(IntType { + "INTEGER" => Ok(LogicalType::Integer { bit_width: 8, is_signed: false, - })), - "MAP" => Ok(LogicalType::MAP(MapType {})), - "LIST" => Ok(LogicalType::LIST(ListType {})), - "ENUM" => Ok(LogicalType::ENUM(EnumType {})), - "DECIMAL" => Ok(LogicalType::DECIMAL(DecimalType { + }), + "MAP" => Ok(LogicalType::Map), + "LIST" => Ok(LogicalType::List), + "ENUM" => Ok(LogicalType::Enum), + "DECIMAL" => Ok(LogicalType::Decimal { precision: -1, scale: -1, - })), - "DATE" => Ok(LogicalType::DATE(DateType {})), - "TIME" => Ok(LogicalType::TIME(TimeType { + }), + "DATE" => Ok(LogicalType::Date), + "TIME" => Ok(LogicalType::Time { is_adjusted_to_u_t_c: false, unit: TimeUnit::MILLIS(parquet::MilliSeconds {}), - })), - "TIMESTAMP" => Ok(LogicalType::TIMESTAMP(TimestampType { + }), + "TIMESTAMP" => Ok(LogicalType::Timestamp { is_adjusted_to_u_t_c: false, unit: TimeUnit::MILLIS(parquet::MilliSeconds {}), - })), - "STRING" => Ok(LogicalType::STRING(StringType {})), - "JSON" => Ok(LogicalType::JSON(JsonType {})), - "BSON" => Ok(LogicalType::BSON(BsonType {})), - "UUID" => Ok(LogicalType::UUID(UUIDType {})), - "UNKNOWN" => Ok(LogicalType::UNKNOWN(NullType {})), + }), + "STRING" => Ok(LogicalType::String), + "JSON" => Ok(LogicalType::Json), + "BSON" => Ok(LogicalType::Bson), + "UUID" => Ok(LogicalType::Uuid), + "UNKNOWN" => Ok(LogicalType::Unknown), "INTERVAL" => Err(general_err!("Interval logical type not yet supported")), other => Err(general_err!("Invalid logical type {}", other)), } @@ -1358,144 +1416,144 @@ mod tests { let logical_none: Option = None; assert_eq!(ConvertedType::from(logical_none), ConvertedType::NONE); assert_eq!( - ConvertedType::from(Some(LogicalType::DECIMAL(DecimalType { + ConvertedType::from(Some(LogicalType::Decimal { precision: 20, scale: 5 - }))), + })), ConvertedType::DECIMAL ); assert_eq!( - ConvertedType::from(Some(LogicalType::BSON(Default::default()))), + ConvertedType::from(Some(LogicalType::Bson)), ConvertedType::BSON ); assert_eq!( - ConvertedType::from(Some(LogicalType::JSON(Default::default()))), + ConvertedType::from(Some(LogicalType::Json)), ConvertedType::JSON ); assert_eq!( - ConvertedType::from(Some(LogicalType::STRING(Default::default()))), + ConvertedType::from(Some(LogicalType::String)), ConvertedType::UTF8 ); assert_eq!( - ConvertedType::from(Some(LogicalType::DATE(Default::default()))), + ConvertedType::from(Some(LogicalType::Date)), ConvertedType::DATE ); assert_eq!( - ConvertedType::from(Some(LogicalType::TIME(TimeType { + ConvertedType::from(Some(LogicalType::Time { unit: TimeUnit::MILLIS(Default::default()), is_adjusted_to_u_t_c: true, - }))), + })), ConvertedType::TIME_MILLIS ); assert_eq!( - ConvertedType::from(Some(LogicalType::TIME(TimeType { + ConvertedType::from(Some(LogicalType::Time { unit: TimeUnit::MICROS(Default::default()), is_adjusted_to_u_t_c: true, - }))), + })), ConvertedType::TIME_MICROS ); assert_eq!( - ConvertedType::from(Some(LogicalType::TIME(TimeType { + ConvertedType::from(Some(LogicalType::Time { unit: TimeUnit::NANOS(Default::default()), is_adjusted_to_u_t_c: false, - }))), + })), ConvertedType::NONE ); assert_eq!( - ConvertedType::from(Some(LogicalType::TIMESTAMP(TimestampType { + ConvertedType::from(Some(LogicalType::Timestamp { unit: TimeUnit::MILLIS(Default::default()), is_adjusted_to_u_t_c: true, - }))), + })), ConvertedType::TIMESTAMP_MILLIS ); assert_eq!( - ConvertedType::from(Some(LogicalType::TIMESTAMP(TimestampType { + ConvertedType::from(Some(LogicalType::Timestamp { unit: 
TimeUnit::MICROS(Default::default()), is_adjusted_to_u_t_c: false, - }))), + })), ConvertedType::TIMESTAMP_MICROS ); assert_eq!( - ConvertedType::from(Some(LogicalType::TIMESTAMP(TimestampType { + ConvertedType::from(Some(LogicalType::Timestamp { unit: TimeUnit::NANOS(Default::default()), is_adjusted_to_u_t_c: false, - }))), + })), ConvertedType::NONE ); assert_eq!( - ConvertedType::from(Some(LogicalType::INTEGER(IntType { + ConvertedType::from(Some(LogicalType::Integer { bit_width: 8, is_signed: false - }))), + })), ConvertedType::UINT_8 ); assert_eq!( - ConvertedType::from(Some(LogicalType::INTEGER(IntType { + ConvertedType::from(Some(LogicalType::Integer { bit_width: 8, is_signed: true - }))), + })), ConvertedType::INT_8 ); assert_eq!( - ConvertedType::from(Some(LogicalType::INTEGER(IntType { + ConvertedType::from(Some(LogicalType::Integer { bit_width: 16, is_signed: false - }))), + })), ConvertedType::UINT_16 ); assert_eq!( - ConvertedType::from(Some(LogicalType::INTEGER(IntType { + ConvertedType::from(Some(LogicalType::Integer { bit_width: 16, is_signed: true - }))), + })), ConvertedType::INT_16 ); assert_eq!( - ConvertedType::from(Some(LogicalType::INTEGER(IntType { + ConvertedType::from(Some(LogicalType::Integer { bit_width: 32, is_signed: false - }))), + })), ConvertedType::UINT_32 ); assert_eq!( - ConvertedType::from(Some(LogicalType::INTEGER(IntType { + ConvertedType::from(Some(LogicalType::Integer { bit_width: 32, is_signed: true - }))), + })), ConvertedType::INT_32 ); assert_eq!( - ConvertedType::from(Some(LogicalType::INTEGER(IntType { + ConvertedType::from(Some(LogicalType::Integer { bit_width: 64, is_signed: false - }))), + })), ConvertedType::UINT_64 ); assert_eq!( - ConvertedType::from(Some(LogicalType::INTEGER(IntType { + ConvertedType::from(Some(LogicalType::Integer { bit_width: 64, is_signed: true - }))), + })), ConvertedType::INT_64 ); assert_eq!( - ConvertedType::from(Some(LogicalType::LIST(Default::default()))), + ConvertedType::from(Some(LogicalType::List)), ConvertedType::LIST ); assert_eq!( - ConvertedType::from(Some(LogicalType::MAP(Default::default()))), + ConvertedType::from(Some(LogicalType::Map)), ConvertedType::MAP ); assert_eq!( - ConvertedType::from(Some(LogicalType::UUID(Default::default()))), + ConvertedType::from(Some(LogicalType::Uuid)), ConvertedType::NONE ); assert_eq!( - ConvertedType::from(Some(LogicalType::ENUM(Default::default()))), + ConvertedType::from(Some(LogicalType::Enum)), ConvertedType::ENUM ); assert_eq!( - ConvertedType::from(Some(LogicalType::UNKNOWN(Default::default()))), + ConvertedType::from(Some(LogicalType::Unknown)), ConvertedType::NONE ); } @@ -1776,85 +1834,82 @@ mod tests { // Unsigned comparison (physical type does not matter) let unsigned = vec![ - LogicalType::STRING(Default::default()), - LogicalType::JSON(Default::default()), - LogicalType::BSON(Default::default()), - LogicalType::ENUM(Default::default()), - LogicalType::UUID(Default::default()), - LogicalType::INTEGER(IntType { + LogicalType::String, + LogicalType::Json, + LogicalType::Bson, + LogicalType::Enum, + LogicalType::Uuid, + LogicalType::Integer { bit_width: 8, is_signed: false, - }), - LogicalType::INTEGER(IntType { + }, + LogicalType::Integer { bit_width: 16, is_signed: false, - }), - LogicalType::INTEGER(IntType { + }, + LogicalType::Integer { bit_width: 32, is_signed: false, - }), - LogicalType::INTEGER(IntType { + }, + LogicalType::Integer { bit_width: 64, is_signed: false, - }), + }, ]; check_sort_order(unsigned, SortOrder::UNSIGNED); // Signed comparison 
(physical type does not matter) let signed = vec![ - LogicalType::INTEGER(IntType { + LogicalType::Integer { bit_width: 8, is_signed: true, - }), - LogicalType::INTEGER(IntType { + }, + LogicalType::Integer { bit_width: 8, is_signed: true, - }), - LogicalType::INTEGER(IntType { + }, + LogicalType::Integer { bit_width: 8, is_signed: true, - }), - LogicalType::INTEGER(IntType { + }, + LogicalType::Integer { bit_width: 8, is_signed: true, - }), - LogicalType::DECIMAL(DecimalType { + }, + LogicalType::Decimal { scale: 20, precision: 4, - }), - LogicalType::DATE(Default::default()), - LogicalType::TIME(TimeType { + }, + LogicalType::Date, + LogicalType::Time { is_adjusted_to_u_t_c: false, unit: TimeUnit::MILLIS(Default::default()), - }), - LogicalType::TIME(TimeType { + }, + LogicalType::Time { is_adjusted_to_u_t_c: false, unit: TimeUnit::MICROS(Default::default()), - }), - LogicalType::TIME(TimeType { + }, + LogicalType::Time { is_adjusted_to_u_t_c: true, unit: TimeUnit::NANOS(Default::default()), - }), - LogicalType::TIMESTAMP(TimestampType { + }, + LogicalType::Timestamp { is_adjusted_to_u_t_c: false, unit: TimeUnit::MILLIS(Default::default()), - }), - LogicalType::TIMESTAMP(TimestampType { + }, + LogicalType::Timestamp { is_adjusted_to_u_t_c: false, unit: TimeUnit::MICROS(Default::default()), - }), - LogicalType::TIMESTAMP(TimestampType { + }, + LogicalType::Timestamp { is_adjusted_to_u_t_c: true, unit: TimeUnit::NANOS(Default::default()), - }), + }, ]; check_sort_order(signed, SortOrder::SIGNED); // Undefined comparison - let undefined = vec![ - LogicalType::LIST(Default::default()), - LogicalType::MAP(Default::default()), - ]; + let undefined = vec![LogicalType::List, LogicalType::Map]; check_sort_order(undefined, SortOrder::UNDEFINED); } diff --git a/parquet/src/bin/parquet-fromcsv-help.txt b/parquet/src/bin/parquet-fromcsv-help.txt new file mode 100644 index 000000000000..f4fe704ab267 --- /dev/null +++ b/parquet/src/bin/parquet-fromcsv-help.txt @@ -0,0 +1,66 @@ +Apache Arrow +Binary to convert csv to Parquet + +USAGE: + parquet [OPTIONS] --schema --input-file --output-file + +OPTIONS: + -b, --batch-size + batch size + + [env: PARQUET_FROM_CSV_BATCHSIZE=] + [default: 1000] + + -c, --parquet-compression + compression mode + + [default: SNAPPY] + + -d, --delimiter + field delimiter + + default value: when input_format==CSV: ',' when input_format==TSV: 'TAB' + + -D, --double-quote + double quote + + -e, --escape-char + escape charactor + + -f, --input-format + input file format + + [default: csv] + [possible values: csv, tsv] + + -h, --has-header + has header + + --help + Print help information + + -i, --input-file + input CSV file + + -m, --max-row-group-size + max row group size + + -o, --output-file + output Parquet file + + -q, --quote-char + quate charactor + + -r, --record-terminator + record terminator + + [possible values: lf, crlf, cr] + + -s, --schema + message schema for output Parquet + + -V, --version + Print version information + + -w, --writer-version + writer version diff --git a/parquet/src/bin/parquet-fromcsv.rs b/parquet/src/bin/parquet-fromcsv.rs new file mode 100644 index 000000000000..aa1d50563cd9 --- /dev/null +++ b/parquet/src/bin/parquet-fromcsv.rs @@ -0,0 +1,636 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Binary to convert a CSV file to a Parquet file +//! +//! # Install +//! +//! `parquet-fromcsv` can be installed using `cargo`: +//! +//! ```text +//! cargo install parquet --features=cli +//! ``` +//! +//! After this `parquet-fromcsv` should be available: +//! +//! ```text +//! parquet-fromcsv --schema message_schema_for_parquet.txt --input-file input.csv --output-file output.parquet +//! ``` +//! +//! The binary can also be built from the source code and run as follows: +//! +//! ```text +//! cargo run --features=cli --bin parquet-fromcsv -- --schema message_schema_for_parquet.txt \ +//! --input-file input.csv --output-file output.parquet +//! ``` +//! +//! # Options +//! +//! ```text +#![doc = include_str!("./parquet-fromcsv-help.txt")] // To update this file, run the test test_command_help +//! ``` +//! +//! ## Parquet file options +//! +//! - `-b`, `--batch-size` : Batch size for Parquet +//! - `-c`, `--parquet-compression` : Compression option for Parquet, default is SNAPPY +//! - `-s`, `--schema` : Path to message schema for generated Parquet file +//! - `-o`, `--output-file` : Path to output Parquet file +//! - `-w`, `--writer-version` : Writer version +//! - `-m`, `--max-row-group-size` : Max row group size +//! +//! ## Input file options +//! +//! - `-i`, `--input-file` : Path to input CSV file +//! - `-f`, `--input-format` : Dialect for input file, `csv` or `tsv`. +//! - `-d`, `--delimiter` : Field delimiter for CSV file, default depends on `--input-format` +//! - `-e`, `--escape-char` : Escape character for input file +//! - `-h`, `--has-header` : Input has a header +//! - `-r`, `--record-terminator` : Record terminator character for input, default is CRLF +//! - `-q`, `--quote-char` : Input quoting character +//! 
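To make the options above concrete, a hedged end-to-end sketch follows; the schema contents, CSV data, and file names are invented for illustration and are not part of this change:

```text
$ cat message_schema_for_parquet.txt
message example {
  required int64 id;
  optional binary name (STRING);
}

$ head -n 2 input.csv
id,name
1,alice

$ parquet-fromcsv --schema message_schema_for_parquet.txt \
      --input-file input.csv --output-file output.parquet --has-header
```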
+ +use std::{ + fmt::Display, + fs::{read_to_string, File}, + path::{Path, PathBuf}, + sync::Arc, +}; + +use arrow::{csv::ReaderBuilder, datatypes::Schema, error::ArrowError}; +use clap::{ArgEnum, Parser}; +use parquet::{ + arrow::{parquet_to_arrow_schema, ArrowWriter}, + basic::Compression, + errors::ParquetError, + file::properties::{WriterProperties, WriterVersion}, + schema::{parser::parse_message_type, types::SchemaDescriptor}, +}; + +#[derive(Debug)] +enum ParquetFromCsvError { + CommandLineParseError(clap::Error), + IoError(std::io::Error), + ArrowError(ArrowError), + ParquetError(ParquetError), + WithContext(String, Box), +} + +impl From for ParquetFromCsvError { + fn from(e: std::io::Error) -> Self { + Self::IoError(e) + } +} + +impl From for ParquetFromCsvError { + fn from(e: ArrowError) -> Self { + Self::ArrowError(e) + } +} + +impl From for ParquetFromCsvError { + fn from(e: ParquetError) -> Self { + Self::ParquetError(e) + } +} + +impl From for ParquetFromCsvError { + fn from(e: clap::Error) -> Self { + Self::CommandLineParseError(e) + } +} + +impl ParquetFromCsvError { + pub fn with_context>( + inner_error: E, + context: &str, + ) -> ParquetFromCsvError { + let inner = inner_error.into(); + ParquetFromCsvError::WithContext(context.to_string(), Box::new(inner)) + } +} + +impl Display for ParquetFromCsvError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ParquetFromCsvError::CommandLineParseError(e) => write!(f, "{}", e), + ParquetFromCsvError::IoError(e) => write!(f, "{}", e), + ParquetFromCsvError::ArrowError(e) => write!(f, "{}", e), + ParquetFromCsvError::ParquetError(e) => write!(f, "{}", e), + ParquetFromCsvError::WithContext(c, e) => { + writeln!(f, "{}", e)?; + write!(f, "context: {}", c) + } + } + } +} + +#[derive(Debug, Parser)] +#[clap(author, version, about("Binary to convert csv to Parquet"), long_about=None)] +struct Args { + /// Path to a text file containing a parquet schema definition + #[clap(short, long, help("message schema for output Parquet"))] + schema: PathBuf, + /// input CSV file path + #[clap(short, long, help("input CSV file"))] + input_file: PathBuf, + /// output Parquet file path + #[clap(short, long, help("output Parquet file"))] + output_file: PathBuf, + /// input file format + #[clap( + arg_enum, + short('f'), + long, + help("input file format"), + default_value_t=CsvDialect::Csv + )] + input_format: CsvDialect, + /// batch size + #[clap( + short, + long, + help("batch size"), + default_value_t = 1000, + env = "PARQUET_FROM_CSV_BATCHSIZE" + )] + batch_size: usize, + /// has header line + #[clap(short, long, help("has header"))] + has_header: bool, + /// field delimiter + /// + /// default value: + /// when input_format==CSV: ',' + /// when input_format==TSV: 'TAB' + #[clap(short, long, help("field delimiter"))] + delimiter: Option, + #[clap(arg_enum, short, long, help("record terminator"))] + record_terminator: Option, + #[clap(short, long, help("escape charactor"))] + escape_char: Option, + #[clap(short, long, help("quate charactor"))] + quote_char: Option, + #[clap(short('D'), long, help("double quote"))] + double_quote: Option, + #[clap(short('c'), long, help("compression mode"), default_value_t=Compression::SNAPPY)] + #[clap(parse(try_from_str =compression_from_str))] + parquet_compression: Compression, + + #[clap(short, long, help("writer version"))] + #[clap(parse(try_from_str =writer_version_from_str))] + writer_version: Option, + #[clap(short, long, help("max row group size"))] + 
max_row_group_size: Option, +} + +fn compression_from_str(cmp: &str) -> Result { + match cmp.to_uppercase().as_str() { + "UNCOMPRESSED" => Ok(Compression::UNCOMPRESSED), + "SNAPPY" => Ok(Compression::SNAPPY), + "GZIP" => Ok(Compression::GZIP), + "LZO" => Ok(Compression::LZO), + "BROTLI" => Ok(Compression::BROTLI), + "LZ4" => Ok(Compression::LZ4), + "ZSTD" => Ok(Compression::ZSTD), + v => Err( + format!("Unknown compression {0} : possible values UNCOMPRESSED, SNAPPY, GZIP, LZO, BROTLI, LZ4, ZSTD ",v) + ) + } +} + +fn writer_version_from_str(cmp: &str) -> Result { + match cmp.to_uppercase().as_str() { + "1" => Ok(WriterVersion::PARQUET_1_0), + "2" => Ok(WriterVersion::PARQUET_2_0), + v => Err(format!( + "Unknown writer version {0} : possible values 1, 2", + v + )), + } +} + +impl Args { + fn schema_path(&self) -> &Path { + self.schema.as_path() + } + fn get_delimiter(&self) -> u8 { + match self.delimiter { + Some(ch) => ch as u8, + None => match self.input_format { + CsvDialect::Csv => b',', + CsvDialect::Tsv => b'\t', + }, + } + } + fn get_terminator(&self) -> Option { + match self.record_terminator { + Some(RecordTerminator::LF) => Some(0x0a), + Some(RecordTerminator::CR) => Some(0x0d), + Some(RecordTerminator::Crlf) => None, + None => match self.input_format { + CsvDialect::Csv => None, + CsvDialect::Tsv => Some(0x0a), + }, + } + } + fn get_escape(&self) -> Option { + self.escape_char.map(|ch| ch as u8) + } + fn get_quote(&self) -> Option { + if self.quote_char.is_none() { + match self.input_format { + CsvDialect::Csv => Some(b'\"'), + CsvDialect::Tsv => None, + } + } else { + self.quote_char.map(|c| c as u8) + } + } +} + +#[derive(Debug, Clone, Copy, ArgEnum, PartialEq)] +enum CsvDialect { + Csv, + Tsv, +} + +#[derive(Debug, Clone, Copy, ArgEnum, PartialEq)] +enum RecordTerminator { + LF, + Crlf, + CR, +} + +fn configure_writer_properties(args: &Args) -> WriterProperties { + let mut properties_builder = + WriterProperties::builder().set_compression(args.parquet_compression); + if let Some(writer_version) = args.writer_version { + properties_builder = properties_builder.set_writer_version(writer_version); + } + if let Some(max_row_group_size) = args.max_row_group_size { + properties_builder = + properties_builder.set_max_row_group_size(max_row_group_size); + } + properties_builder.build() +} + +fn configure_reader_builder(args: &Args, arrow_schema: Arc) -> ReaderBuilder { + fn configure_reader ReaderBuilder>( + builder: ReaderBuilder, + value: Option, + fun: F, + ) -> ReaderBuilder { + if let Some(val) = value { + fun(builder, val) + } else { + builder + } + } + + let mut builder = ReaderBuilder::new() + .with_schema(arrow_schema) + .with_batch_size(args.batch_size) + .has_header(args.has_header) + .with_delimiter(args.get_delimiter()); + + builder = configure_reader( + builder, + args.get_terminator(), + ReaderBuilder::with_terminator, + ); + builder = configure_reader(builder, args.get_escape(), ReaderBuilder::with_escape); + builder = configure_reader(builder, args.get_quote(), ReaderBuilder::with_quote); + + builder +} + +fn arrow_schema_from_string(schema: &str) -> Result, ParquetFromCsvError> { + let schema = Arc::new(parse_message_type(schema)?); + let desc = SchemaDescriptor::new(schema); + let arrow_schema = Arc::new(parquet_to_arrow_schema(&desc, None)?); + Ok(arrow_schema) +} + +fn convert_csv_to_parquet(args: &Args) -> Result<(), ParquetFromCsvError> { + let schema = read_to_string(args.schema_path()).map_err(|e| { + ParquetFromCsvError::with_context( + e, + &format!("Failed 
to open schema file {:#?}", args.schema_path()), + ) + })?; + let arrow_schema = arrow_schema_from_string(&schema)?; + + // create output parquet writer + let parquet_file = File::create(&args.output_file).map_err(|e| { + ParquetFromCsvError::with_context( + e, + &format!("Failed to create output file {:#?}", &args.output_file), + ) + })?; + + let writer_properties = Some(configure_writer_properties(args)); + let mut arrow_writer = + ArrowWriter::try_new(parquet_file, arrow_schema.clone(), writer_properties) + .map_err(|e| { + ParquetFromCsvError::with_context(e, "Failed to create ArrowWriter") + })?; + + // open input file + let input_file = File::open(&args.input_file).map_err(|e| { + ParquetFromCsvError::with_context( + e, + &format!("Failed to open input file {:#?}", &args.input_file), + ) + })?; + // create input csv reader + let builder = configure_reader_builder(args, arrow_schema); + let reader = builder.build(input_file)?; + for batch_result in reader { + let batch = batch_result.map_err(|e| { + ParquetFromCsvError::with_context(e, "Failed to read RecordBatch from CSV") + })?; + arrow_writer.write(&batch).map_err(|e| { + ParquetFromCsvError::with_context(e, "Failed to write RecordBatch to parquet") + })?; + } + arrow_writer + .close() + .map_err(|e| ParquetFromCsvError::with_context(e, "Failed to close parquet"))?; + Ok(()) +} + +fn main() -> Result<(), ParquetFromCsvError> { + let args = Args::parse(); + convert_csv_to_parquet(&args) +} + +#[cfg(test)] +mod tests { + use std::{ + io::{Seek, SeekFrom, Write}, + path::{Path, PathBuf}, + }; + + use super::*; + use arrow::datatypes::{DataType, Field}; + use clap::{CommandFactory, Parser}; + use tempfile::NamedTempFile; + + #[test] + fn test_command_help() { + let mut cmd = Args::command(); + let dir = std::env::var("CARGO_MANIFEST_DIR").unwrap(); + let mut path_buf = PathBuf::from(dir); + path_buf.push("src"); + path_buf.push("bin"); + path_buf.push("parquet-fromcsv-help.txt"); + let expected = std::fs::read_to_string(path_buf).unwrap(); + let mut buffer_vec = Vec::new(); + let mut buffer = std::io::Cursor::new(&mut buffer_vec); + cmd.write_long_help(&mut buffer).unwrap(); + // Remove Parquet version string from the help text + let mut actual = String::from_utf8(buffer_vec).unwrap(); + let pos = actual.find('\n').unwrap() + 1; + actual = actual[pos..].to_string(); + assert_eq!( + expected, actual, + "help text not match. 
please update to \n---\n{}\n---\n", + actual + ) + } + + fn parse_args(mut extra_args: Vec<&str>) -> Result { + let mut args = vec![ + "test", + "--schema", + "test.schema", + "--input-file", + "infile.csv", + "--output-file", + "out.parquet", + ]; + args.append(&mut extra_args); + let args = Args::try_parse_from(args.iter())?; + Ok(args) + } + + #[test] + fn test_parse_arg_minimum() -> Result<(), ParquetFromCsvError> { + let args = parse_args(vec![])?; + + assert_eq!(args.schema, PathBuf::from(Path::new("test.schema"))); + assert_eq!(args.input_file, PathBuf::from(Path::new("infile.csv"))); + assert_eq!(args.output_file, PathBuf::from(Path::new("out.parquet"))); + // test default values + assert_eq!(args.input_format, CsvDialect::Csv); + assert_eq!(args.batch_size, 1000); + assert_eq!(args.has_header, false); + assert_eq!(args.delimiter, None); + assert_eq!(args.get_delimiter(), b','); + assert_eq!(args.record_terminator, None); + assert_eq!(args.get_terminator(), None); // CRLF + assert_eq!(args.quote_char, None); + assert_eq!(args.get_quote(), Some(b'\"')); + assert_eq!(args.double_quote, None); + assert_eq!(args.parquet_compression, Compression::SNAPPY); + Ok(()) + } + + #[test] + fn test_parse_arg_format_variants() -> Result<(), ParquetFromCsvError> { + let args = parse_args(vec!["--input-format", "csv"])?; + assert_eq!(args.input_format, CsvDialect::Csv); + assert_eq!(args.get_delimiter(), b','); + assert_eq!(args.get_terminator(), None); // CRLF + assert_eq!(args.get_quote(), Some(b'\"')); + assert_eq!(args.get_escape(), None); + let args = parse_args(vec!["--input-format", "tsv"])?; + assert_eq!(args.input_format, CsvDialect::Tsv); + assert_eq!(args.get_delimiter(), b'\t'); + assert_eq!(args.get_terminator(), Some(b'\x0a')); // LF + assert_eq!(args.get_quote(), None); // quote none + assert_eq!(args.get_escape(), None); + + let args = parse_args(vec!["--input-format", "csv", "--escape-char", "\\"])?; + assert_eq!(args.input_format, CsvDialect::Csv); + assert_eq!(args.get_delimiter(), b','); + assert_eq!(args.get_terminator(), None); // CRLF + assert_eq!(args.get_quote(), Some(b'\"')); + assert_eq!(args.get_escape(), Some(b'\\')); + + let args = parse_args(vec!["--input-format", "tsv", "--delimiter", ":"])?; + assert_eq!(args.input_format, CsvDialect::Tsv); + assert_eq!(args.get_delimiter(), b':'); + assert_eq!(args.get_terminator(), Some(b'\x0a')); // LF + assert_eq!(args.get_quote(), None); // quote none + assert_eq!(args.get_escape(), None); + + Ok(()) + } + + #[test] + #[should_panic] + fn test_parse_arg_format_error() { + parse_args(vec!["--input-format", "excel"]).unwrap(); + } + + #[test] + fn test_parse_arg_compression_format() { + let args = parse_args(vec!["--parquet-compression", "uncompressed"]).unwrap(); + assert_eq!(args.parquet_compression, Compression::UNCOMPRESSED); + let args = parse_args(vec!["--parquet-compression", "snappy"]).unwrap(); + assert_eq!(args.parquet_compression, Compression::SNAPPY); + let args = parse_args(vec!["--parquet-compression", "gzip"]).unwrap(); + assert_eq!(args.parquet_compression, Compression::GZIP); + let args = parse_args(vec!["--parquet-compression", "lzo"]).unwrap(); + assert_eq!(args.parquet_compression, Compression::LZO); + let args = parse_args(vec!["--parquet-compression", "lz4"]).unwrap(); + assert_eq!(args.parquet_compression, Compression::LZ4); + let args = parse_args(vec!["--parquet-compression", "brotli"]).unwrap(); + assert_eq!(args.parquet_compression, Compression::BROTLI); + let args = 
parse_args(vec!["--parquet-compression", "zstd"]).unwrap(); + assert_eq!(args.parquet_compression, Compression::ZSTD); + } + + #[test] + fn test_parse_arg_compression_format_fail() { + match parse_args(vec!["--parquet-compression", "zip"]) { + Ok(_) => panic!("unexpected success"), + Err(e) => assert_eq!( + format!("{}", e), + "error: Invalid value \"zip\" for '--parquet-compression ': Unknown compression ZIP : possible values UNCOMPRESSED, SNAPPY, GZIP, LZO, BROTLI, LZ4, ZSTD \n\nFor more information try --help\n"), + } + } + + fn assert_debug_text(debug_text: &str, name: &str, value: &str) { + let pattern = format!(" {}: {}", name, value); + assert!( + debug_text.contains(&pattern), + "\"{}\" not contains \"{}\"", + debug_text, + pattern + ) + } + + #[test] + fn test_configure_reader_builder() { + let args = Args { + schema: PathBuf::from(Path::new("schema.arvo")), + input_file: PathBuf::from(Path::new("test.csv")), + output_file: PathBuf::from(Path::new("out.parquet")), + batch_size: 1000, + input_format: CsvDialect::Csv, + has_header: false, + delimiter: None, + record_terminator: None, + escape_char: None, + quote_char: None, + double_quote: None, + parquet_compression: Compression::SNAPPY, + writer_version: None, + max_row_group_size: None, + }; + let arrow_schema = Arc::new(Schema::new(vec![ + Field::new("field1", DataType::Utf8, false), + Field::new("field2", DataType::Utf8, false), + Field::new("field3", DataType::Utf8, false), + Field::new("field4", DataType::Utf8, false), + Field::new("field5", DataType::Utf8, false), + ])); + + let reader_builder = configure_reader_builder(&args, arrow_schema.clone()); + let builder_debug = format!("{:?}", reader_builder); + assert_debug_text(&builder_debug, "has_header", "false"); + assert_debug_text(&builder_debug, "delimiter", "Some(44)"); + assert_debug_text(&builder_debug, "quote", "Some(34)"); + assert_debug_text(&builder_debug, "terminator", "None"); + assert_debug_text(&builder_debug, "batch_size", "1000"); + assert_debug_text(&builder_debug, "escape", "None"); + + let args = Args { + schema: PathBuf::from(Path::new("schema.arvo")), + input_file: PathBuf::from(Path::new("test.csv")), + output_file: PathBuf::from(Path::new("out.parquet")), + batch_size: 2000, + input_format: CsvDialect::Tsv, + has_header: true, + delimiter: None, + record_terminator: None, + escape_char: Some('\\'), + quote_char: None, + double_quote: None, + parquet_compression: Compression::SNAPPY, + writer_version: None, + max_row_group_size: None, + }; + let arrow_schema = Arc::new(Schema::new(vec![ + Field::new("field1", DataType::Utf8, false), + Field::new("field2", DataType::Utf8, false), + Field::new("field3", DataType::Utf8, false), + Field::new("field4", DataType::Utf8, false), + Field::new("field5", DataType::Utf8, false), + ])); + let reader_builder = configure_reader_builder(&args, arrow_schema.clone()); + let builder_debug = format!("{:?}", reader_builder); + assert_debug_text(&builder_debug, "has_header", "true"); + assert_debug_text(&builder_debug, "delimiter", "Some(9)"); + assert_debug_text(&builder_debug, "quote", "None"); + assert_debug_text(&builder_debug, "terminator", "Some(10)"); + assert_debug_text(&builder_debug, "batch_size", "2000"); + assert_debug_text(&builder_debug, "escape", "Some(92)"); + } + + #[test] + fn test_convert_csv_to_parquet() { + let schema = NamedTempFile::new().unwrap(); + let schema_text = r"message schema { + optional int32 id; + optional binary name (STRING); + }"; + 
schema.as_file().write_all(schema_text.as_bytes()).unwrap(); + + let mut input_file = NamedTempFile::new().unwrap(); + { + let csv = input_file.as_file_mut(); + for index in 1..2000 { + write!(csv, "{},\"name_{}\"\r\n", index, index).unwrap(); + } + csv.flush().unwrap(); + csv.seek(SeekFrom::Start(0)).unwrap(); + } + let output_parquet = NamedTempFile::new().unwrap(); + + let args = Args { + schema: PathBuf::from(schema.path()), + input_file: PathBuf::from(input_file.path()), + output_file: PathBuf::from(output_parquet.path()), + batch_size: 1000, + input_format: CsvDialect::Csv, + has_header: false, + delimiter: None, + record_terminator: None, + escape_char: None, + quote_char: None, + double_quote: None, + parquet_compression: Compression::SNAPPY, + writer_version: None, + max_row_group_size: None, + }; + convert_csv_to_parquet(&args).unwrap(); + } +} diff --git a/parquet/src/bin/parquet-read.rs b/parquet/src/bin/parquet-read.rs index aa3b8272dadf..0530afaa786a 100644 --- a/parquet/src/bin/parquet-read.rs +++ b/parquet/src/bin/parquet-read.rs @@ -21,94 +21,64 @@ //! //! `parquet-read` can be installed using `cargo`: //! ``` -//! cargo install parquet +//! cargo install parquet --features=cli //! ``` -//! After this `parquet-read` should be globally available: +//! After this `parquet-read` should be available: //! ``` //! parquet-read XYZ.parquet //! ``` //! //! The binary can also be built from the source code and run as follows: //! ``` -//! cargo run --bin parquet-read XYZ.parquet +//! cargo run --features=cli --bin parquet-read XYZ.parquet //! ``` //! -//! # Usage -//! ``` -//! parquet-read [num-records] -//! ``` -//! -//! ## Flags -//! -h, --help Prints help information -//! -j, --json Print Parquet file in JSON lines Format -//! -V, --version Prints version information -//! -//! ## Args -//! Path to a Parquet file -//! Number of records to read. When not provided, all records are read. -//! //! Note that `parquet-read` reads full file schema, no projection or filtering is //! applied. extern crate parquet; -use std::{env, fs::File, path::Path}; - -use clap::{crate_authors, crate_version, App, Arg}; - +use clap::Parser; use parquet::file::reader::{FileReader, SerializedFileReader}; use parquet::record::Row; +use std::{fs::File, path::Path}; + +#[derive(Debug, Parser)] +#[clap(author, version, about("Binary file to read data from a Parquet file"), long_about = None)] +struct Args { + #[clap(short, long, help("Path to a parquet file"))] + file_name: String, + #[clap( + short, + long, + default_value_t = 0_usize, + help("Number of records to read. When not provided or 0, all records are read") + )] + num_records: usize, + #[clap(short, long, help("Print Parquet file in JSON lines format"))] + json: bool, +} fn main() { - let app = App::new("parquet-read") - .version(crate_version!()) - .author(crate_authors!()) - .about("Read data from a Parquet file and print output in console, in either built-in or JSON format") - .arg( - Arg::with_name("file_path") - .value_name("file-path") - .required(true) - .index(1) - .help("Path to a parquet file"), - ) - .arg( - Arg::with_name("num_records") - .value_name("num-records") - .index(2) - .help( - "Number of records to read. 
When not provided, all records are read.", - ), - ) - .arg( - Arg::with_name("json") - .short("j") - .long("json") - .takes_value(false) - .help("Print Parquet file in JSON lines format"), - ); + let args = Args::parse(); - let matches = app.get_matches(); - let filename = matches.value_of("file_path").unwrap(); - let num_records: Option = if matches.is_present("num_records") { - match matches.value_of("num_records").unwrap().parse() { - Ok(value) => Some(value), - Err(e) => panic!("Error when reading value for [num-records], {}", e), - } - } else { - None - }; + let filename = args.file_name; + let num_records = args.num_records; + let json = args.json; - let json = matches.is_present("json"); let path = Path::new(&filename); - let file = File::open(&path).unwrap(); - let parquet_reader = SerializedFileReader::new(file).unwrap(); + let file = File::open(&path).expect("Unable to open file"); + let parquet_reader = + SerializedFileReader::new(file).expect("Failed to create reader"); // Use full schema as projected schema - let mut iter = parquet_reader.get_row_iter(None).unwrap(); + let mut iter = parquet_reader + .get_row_iter(None) + .expect("Failed to create row iterator"); let mut start = 0; - let end = num_records.unwrap_or(0); - let all_records = num_records.is_none(); + let end = num_records; + let all_records = end == 0; while all_records || start < end { match iter.next() { diff --git a/parquet/src/bin/parquet-rowcount.rs b/parquet/src/bin/parquet-rowcount.rs index 3c61bab882ae..d2f0311cf7a0 100644 --- a/parquet/src/bin/parquet-rowcount.rs +++ b/parquet/src/bin/parquet-rowcount.rs @@ -21,60 +21,46 @@ //! //! `parquet-rowcount` can be installed using `cargo`: //! ``` -//! cargo install parquet +//! cargo install parquet --features=cli //! ``` -//! After this `parquet-rowcount` should be globally available: +//! After this `parquet-rowcount` should be available: //! ``` //! parquet-rowcount XYZ.parquet //! ``` //! //! The binary can also be built from the source code and run as follows: //! ``` -//! cargo run --bin parquet-rowcount XYZ.parquet ABC.parquet ZXC.parquet +//! cargo run --features=cli --bin parquet-rowcount XYZ.parquet ABC.parquet ZXC.parquet //! ``` //! -//! # Usage -//! ``` -//! parquet-rowcount ... -//! ``` -//! -//! ## Flags -//! -h, --help Prints help information -//! -V, --version Prints version information -//! -//! ## Args -//! ... List of Parquet files to read from -//! //! Note that `parquet-rowcount` reads full file schema, no projection or filtering is //! applied. 
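Since `parquet-rowcount` only inspects footer metadata, the count can be computed without decoding any column data. A minimal sketch of that computation; the helper name and error handling are chosen here for illustration only:

```
use parquet::file::reader::{FileReader, SerializedFileReader};
use std::fs::File;

fn count_rows(path: &str) -> i64 {
    let file = File::open(path).expect("Unable to open file");
    let reader = SerializedFileReader::new(file).expect("Unable to read file");
    // Sum the per-row-group row counts recorded in the file footer.
    reader.metadata().row_groups().iter().map(|rg| rg.num_rows()).sum()
}
```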
extern crate parquet; - -use std::{env, fs::File, path::Path}; - -use clap::{crate_authors, crate_version, App, Arg}; - +use clap::Parser; use parquet::file::reader::{FileReader, SerializedFileReader}; +use std::{fs::File, path::Path}; + +#[derive(Debug, Parser)] +#[clap(author, version, about("Binary file to return the number of rows found from Parquet file(s)"), long_about = None)] +struct Args { + #[clap( + short, + long, + multiple_values(true), + help("List of Parquet files to read from separated by space") + )] + file_paths: Vec, +} fn main() { - let matches = App::new("parquet-rowcount") - .version(crate_version!()) - .author(crate_authors!()) - .about("Return number of rows in Parquet file") - .arg( - Arg::with_name("file_paths") - .value_name("file-paths") - .required(true) - .multiple(true) - .help("List of Parquet files to read from separated by space"), - ) - .get_matches(); + let args = Args::parse(); - let filenames: Vec<&str> = matches.values_of("file_paths").unwrap().collect(); - for filename in &filenames { - let path = Path::new(filename); - let file = File::open(path).unwrap(); - let parquet_reader = SerializedFileReader::new(file).unwrap(); + for filename in args.file_paths { + let path = Path::new(&filename); + let file = File::open(path).expect("Unable to open file"); + let parquet_reader = + SerializedFileReader::new(file).expect("Unable to read file"); let row_group_metadata = parquet_reader.metadata().row_groups(); let mut total_num_rows = 0; diff --git a/parquet/src/bin/parquet-schema.rs b/parquet/src/bin/parquet-schema.rs index 1b806372b107..b875b0e7102b 100644 --- a/parquet/src/bin/parquet-schema.rs +++ b/parquet/src/bin/parquet-schema.rs @@ -21,72 +21,44 @@ //! //! `parquet-schema` can be installed using `cargo`: //! ``` -//! cargo install parquet +//! cargo install parquet --features=cli //! ``` -//! After this `parquet-schema` should be globally available: +//! After this `parquet-schema` should be available: //! ``` //! parquet-schema XYZ.parquet //! ``` //! //! The binary can also be built from the source code and run as follows: //! ``` -//! cargo run --bin parquet-schema XYZ.parquet +//! cargo run --features=cli --bin parquet-schema XYZ.parquet //! ``` //! -//! # Usage -//! ``` -//! parquet-schema [FLAGS] -//! ``` -//! -//! ## Flags -//! -h, --help Prints help information -//! -V, --version Prints version information -//! -v, --verbose Enable printing full file metadata -//! -//! ## Args -//! Path to a Parquet file -//! //! Note that `verbose` is an optional boolean flag that allows to print schema only, //! when not provided or print full file metadata when provided. 
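The `verbose` flag described above simply selects between the two printers imported by the rewritten binary below. A rough sketch of that branch; the wrapper function and the `stdout` target are illustrative assumptions, not part of the patch:

```
use parquet::file::reader::{FileReader, SerializedFileReader};
use parquet::schema::printer::{print_file_metadata, print_parquet_metadata};
use std::fs::File;

fn print_schema(path: &str, verbose: bool) {
    let file = File::open(path).expect("Unable to open file");
    let reader = SerializedFileReader::new(file).expect("Failed to create reader");
    let metadata = reader.metadata();
    let mut out = std::io::stdout();
    if verbose {
        // Full Parquet metadata: row groups, column chunks, statistics.
        print_parquet_metadata(&mut out, metadata);
    } else {
        // Schema and file-level metadata only.
        print_file_metadata(&mut out, metadata.file_metadata());
    }
}
```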
extern crate parquet; - -use std::{env, fs::File, path::Path}; - -use clap::{crate_authors, crate_version, App, Arg}; - +use clap::Parser; use parquet::{ file::reader::{FileReader, SerializedFileReader}, schema::printer::{print_file_metadata, print_parquet_metadata}, }; +use std::{fs::File, path::Path}; -fn main() { - let matches = App::new("parquet-schema") - .version(crate_version!()) - .author(crate_authors!()) - .arg( - Arg::with_name("file_path") - .value_name("file-path") - .required(true) - .index(1) - .help("Path to a Parquet file"), - ) - .arg( - Arg::with_name("verbose") - .short("v") - .long("verbose") - .takes_value(false) - .help("Enable printing full file metadata"), - ) - .get_matches(); +#[derive(Debug, Parser)] +#[clap(author, version, about("Binary file to print the schema and metadata of a Parquet file"), long_about = None)] +struct Args { + #[clap(short, long)] + file_path: String, + #[clap(short, long, help("Enable printing full file metadata"))] + verbose: bool, +} - let filename = matches.value_of("file_path").unwrap(); +fn main() { + let args = Args::parse(); + let filename = args.file_path; let path = Path::new(&filename); - let file = match File::open(&path) { - Err(e) => panic!("Error when opening file {}: {}", path.display(), e), - Ok(f) => f, - }; - let verbose = matches.is_present("verbose"); + let file = File::open(&path).expect("Unable to open file"); + let verbose = args.verbose; match SerializedFileReader::new(file) { Err(e) => panic!("Error when parsing Parquet file: {}", e), diff --git a/parquet/src/column/mod.rs b/parquet/src/column/mod.rs index 7ed7bfc256e6..93a4f00d2eef 100644 --- a/parquet/src/column/mod.rs +++ b/parquet/src/column/mod.rs @@ -40,10 +40,11 @@ //! //! use parquet::{ //! column::{reader::ColumnReader, writer::ColumnWriter}, +//! data_type::Int32Type, //! file::{ //! properties::WriterProperties, //! reader::{FileReader, SerializedFileReader}, -//! writer::{FileWriter, SerializedFileWriter}, +//! writer::SerializedFileWriter, //! }, //! schema::parser::parse_message_type, //! }; @@ -65,20 +66,17 @@ //! let props = Arc::new(WriterProperties::builder().build()); //! let file = fs::File::create(path).unwrap(); //! let mut writer = SerializedFileWriter::new(file, schema, props).unwrap(); +//! //! let mut row_group_writer = writer.next_row_group().unwrap(); //! while let Some(mut col_writer) = row_group_writer.next_column().unwrap() { -//! match col_writer { -//! // You can also use `get_typed_column_writer` method to extract typed writer. -//! ColumnWriter::Int32ColumnWriter(ref mut typed_writer) => { -//! typed_writer -//! .write_batch(&[1, 2, 3], Some(&[3, 3, 3, 2, 2]), Some(&[0, 1, 0, 1, 1])) -//! .unwrap(); -//! } -//! _ => {} -//! } -//! row_group_writer.close_column(col_writer).unwrap(); +//! col_writer +//! .typed::() +//! .write_batch(&[1, 2, 3], Some(&[3, 3, 3, 2, 2]), Some(&[0, 1, 0, 1, 1])) +//! .unwrap(); +//! col_writer.close().unwrap(); //! } -//! writer.close_row_group(row_group_writer).unwrap(); +//! row_group_writer.close().unwrap(); +//! //! writer.close().unwrap(); //! //! // Reading data using column reader API. diff --git a/parquet/src/column/page.rs b/parquet/src/column/page.rs index acbf3eba410e..9364bd30fffd 100644 --- a/parquet/src/column/page.rs +++ b/parquet/src/column/page.rs @@ -189,7 +189,7 @@ impl PageWriteSpec { /// API for reading pages from a column chunk. /// This offers a iterator like API to get the next page. 
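To make the "iterator like API" mentioned above concrete, a consumer of a `PageReader` typically drains it as follows; this is a sketch that assumes the reader was obtained from a column chunk elsewhere:

```
use parquet::column::page::PageReader;
use parquet::errors::Result;

fn count_pages(mut reader: impl PageReader) -> Result<usize> {
    let mut pages = 0;
    // Ask for the next page until the column chunk is exhausted.
    while let Some(_page) = reader.get_next_page()? {
        pages += 1;
    }
    Ok(pages)
}
```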
-pub trait PageReader: Iterator> { +pub trait PageReader: Iterator> + Send { /// Gets the next page in the column chunk associated with this reader. /// Returns `None` if there are no pages left. fn get_next_page(&mut self) -> Result>; @@ -219,8 +219,8 @@ pub trait PageWriter { fn close(&mut self) -> Result<()>; } -/// An iterator over pages of some specific column in a parquet file. -pub trait PageIterator: Iterator>> { +/// An iterator over pages of one specific column in a parquet file. +pub trait PageIterator: Iterator>> + Send { /// Get schema of parquet file. fn schema(&mut self) -> Result; diff --git a/parquet/src/column/reader.rs b/parquet/src/column/reader.rs index 63be17b7dd1f..3a45ecf3ff8f 100644 --- a/parquet/src/column/reader.rs +++ b/parquet/src/column/reader.rs @@ -17,22 +17,21 @@ //! Contains column reader API. -use std::{ - cmp::{max, min}, - collections::HashMap, -}; +use std::cmp::{max, min}; use super::page::{Page, PageReader}; use crate::basic::*; -use crate::data_type::*; -use crate::encodings::{ - decoding::{get_decoder, Decoder, DictDecoder, PlainDecoder}, - levels::LevelDecoder, +use crate::column::reader::decoder::{ + ColumnLevelDecoder, ColumnValueDecoder, LevelsBufferSlice, ValuesBufferSlice, }; +use crate::data_type::*; use crate::errors::{ParquetError, Result}; use crate::schema::types::ColumnDescPtr; +use crate::util::bit_util::ceil; use crate::util::memory::ByteBufferPtr; +pub(crate) mod decoder; + /// Column reader for a Parquet type. pub enum ColumnReader { BoolColumnReader(ColumnReaderImpl), @@ -102,36 +101,65 @@ pub fn get_typed_column_reader( } /// Typed value reader for a particular primitive column. -pub struct ColumnReaderImpl { +pub type ColumnReaderImpl = GenericColumnReader< + decoder::ColumnLevelDecoderImpl, + decoder::ColumnLevelDecoderImpl, + decoder::ColumnValueDecoderImpl, +>; + +#[doc(hidden)] +/// Reads data for a given column chunk, using the provided decoders: +/// +/// - R: [`ColumnLevelDecoder`] used to decode repetition levels +/// - D: [`ColumnLevelDecoder`] used to decode definition levels +/// - V: [`ColumnValueDecoder`] used to decode value data +pub struct GenericColumnReader { descr: ColumnDescPtr, - def_level_decoder: Option, - rep_level_decoder: Option, + page_reader: Box, - current_encoding: Option, - // The total number of values stored in the data page. + /// The total number of values stored in the data page. num_buffered_values: u32, - // The number of values from the current data page that has been decoded into memory - // so far. + /// The number of values from the current data page that has been decoded into memory + /// so far. num_decoded_values: u32, - // Cache of decoders for existing encodings - decoders: HashMap>>, + /// The decoder for the definition levels if any + def_level_decoder: Option, + + /// The decoder for the repetition levels if any + rep_level_decoder: Option, + + /// The decoder for the values + values_decoder: V, } -impl ColumnReaderImpl { +impl GenericColumnReader +where + R: ColumnLevelDecoder, + D: ColumnLevelDecoder, + V: ColumnValueDecoder, +{ /// Creates new column reader based on column descriptor and page reader. 
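Spelled out with its type parameters, the new `ColumnReaderImpl` alias presumably reads as follows; this is a reconstruction based on the decoder types referenced and the `where` bounds that follow, not a verbatim quote of the patch:

```
pub type ColumnReaderImpl<T> = GenericColumnReader<
    decoder::ColumnLevelDecoderImpl,
    decoder::ColumnLevelDecoderImpl,
    decoder::ColumnValueDecoderImpl<T>,
>;
```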
pub fn new(descr: ColumnDescPtr, page_reader: Box) -> Self { + let values_decoder = V::new(&descr); + Self::new_with_decoder(descr, page_reader, values_decoder) + } + + fn new_with_decoder( + descr: ColumnDescPtr, + page_reader: Box, + values_decoder: V, + ) -> Self { Self { descr, def_level_decoder: None, rep_level_decoder: None, page_reader, - current_encoding: None, num_buffered_values: 0, num_decoded_values: 0, - decoders: HashMap::new(), + values_decoder, } } @@ -159,20 +187,20 @@ impl ColumnReaderImpl { pub fn read_batch( &mut self, batch_size: usize, - mut def_levels: Option<&mut [i16]>, - mut rep_levels: Option<&mut [i16]>, - values: &mut [T::T], + mut def_levels: Option<&mut D::Slice>, + mut rep_levels: Option<&mut R::Slice>, + values: &mut V::Slice, ) -> Result<(usize, usize)> { let mut values_read = 0; let mut levels_read = 0; // Compute the smallest batch size we can read based on provided slices - let mut batch_size = min(batch_size, values.len()); + let mut batch_size = min(batch_size, values.capacity()); if let Some(ref levels) = def_levels { - batch_size = min(batch_size, levels.len()); + batch_size = min(batch_size, levels.capacity()); } if let Some(ref levels) = rep_levels { - batch_size = min(batch_size, levels.len()); + batch_size = min(batch_size, levels.capacity()); } // Read exhaustively all pages until we read all batch_size values/levels @@ -200,57 +228,64 @@ impl ColumnReaderImpl { adjusted_size }; - let mut values_to_read = 0; - let mut num_def_levels = 0; - let mut num_rep_levels = 0; - // If the field is required and non-repeated, there are no definition levels - if self.descr.max_def_level() > 0 && def_levels.as_ref().is_some() { - if let Some(ref mut levels) = def_levels { - num_def_levels = self.read_def_levels( - &mut levels[levels_read..levels_read + iter_batch_size], - )?; - for i in levels_read..levels_read + num_def_levels { - if levels[i] == self.descr.max_def_level() { - values_to_read += 1; - } - } + let (num_def_levels, null_count) = match def_levels.as_mut() { + Some(levels) if self.descr.max_def_level() > 0 => { + let num_def_levels = self + .def_level_decoder + .as_mut() + .expect("def_level_decoder be set") + .read(*levels, levels_read..levels_read + iter_batch_size)?; + + let null_count = levels.count_nulls( + levels_read..levels_read + num_def_levels, + self.descr.max_def_level(), + ); + (num_def_levels, null_count) } - } else { - // If max definition level == 0, then it is REQUIRED field, read all - // values. If definition levels are not provided, we still - // read all values. - values_to_read = iter_batch_size; - } + _ => (0, 0), + }; - if self.descr.max_rep_level() > 0 && rep_levels.is_some() { - if let Some(ref mut levels) = rep_levels { - num_rep_levels = self.read_rep_levels( - &mut levels[levels_read..levels_read + iter_batch_size], - )?; - - // If definition levels are defined, check that rep levels == def - // levels - if def_levels.is_some() { - assert_eq!( - num_def_levels, num_rep_levels, - "Number of decoded rep / def levels did not match" - ); - } - } - } + let num_rep_levels = match rep_levels.as_mut() { + Some(levels) if self.descr.max_rep_level() > 0 => self + .rep_level_decoder + .as_mut() + .expect("rep_level_decoder be set") + .read(levels, levels_read..levels_read + iter_batch_size)?, + _ => 0, + }; // At this point we have read values, definition and repetition levels. // If both definition and repetition levels are defined, their counts // should be equal. Values count is always less or equal to definition levels. 
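As a concrete instance of the invariant described in that comment, consider the definition levels used in the module-level example earlier in this patch, with `max_def_level = 3`; the snippet is purely illustrative:

```
// Five slots are described, but a slot holds a value only when its
// definition level equals the maximum definition level.
let def_levels: [i16; 5] = [3, 3, 3, 2, 2];
let max_def_level: i16 = 3;

let null_count = def_levels.iter().filter(|&&l| l != max_def_level).count();
assert_eq!(null_count, 2);

// Only `num_def_levels - null_count` values are actually read,
// which is exactly what the check below enforces.
assert_eq!(def_levels.len() - null_count, 3);
```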
- // + if num_def_levels != 0 + && num_rep_levels != 0 + && num_rep_levels != num_def_levels + { + return Err(general_err!( + "inconsistent number of levels read - def: {}, rep: {}", + num_def_levels, + num_rep_levels + )); + } + // Note that if field is not required, but no definition levels are provided, // we would read values of batch size and (if provided, of course) repetition // levels of batch size - [!] they will not be synced, because only definition // levels enforce number of non-null values to read. - let curr_values_read = - self.read_values(&mut values[values_read..values_read + values_to_read])?; + let values_to_read = iter_batch_size - null_count; + let curr_values_read = self + .values_decoder + .read(values, values_read..values_read + values_to_read)?; + + if num_def_levels != 0 && curr_values_read != num_def_levels - null_count { + return Err(general_err!( + "insufficient values read from column - expected: {}, got: {}", + num_def_levels - null_count, + curr_values_read + )); + } // Update all "return" counters and internal state. @@ -275,8 +310,14 @@ impl ColumnReaderImpl { Some(current_page) => { match current_page { // 1. Dictionary page: configure dictionary for this page. - p @ Page::DictionaryPage { .. } => { - self.configure_dictionary(p)?; + Page::DictionaryPage { + buf, + num_values, + encoding, + is_sorted, + } => { + self.values_decoder + .set_dict(buf, num_values, encoding, is_sorted)?; continue; } // 2. Data page v1 @@ -291,41 +332,46 @@ impl ColumnReaderImpl { self.num_buffered_values = num_values; self.num_decoded_values = 0; - let mut buffer_ptr = buf; + let max_rep_level = self.descr.max_rep_level(); + let max_def_level = self.descr.max_def_level(); - if self.descr.max_rep_level() > 0 { - let mut rep_decoder = LevelDecoder::v1( + let mut offset = 0; + + if max_rep_level > 0 { + let (bytes_read, level_data) = parse_v1_level( + max_rep_level, + num_values, rep_level_encoding, - self.descr.max_rep_level(), - ); - let total_bytes = rep_decoder.set_data( - self.num_buffered_values as usize, - buffer_ptr.all(), - ); - buffer_ptr = buffer_ptr.start_from(total_bytes); - self.rep_level_decoder = Some(rep_decoder); + buf.start_from(offset), + )?; + offset += bytes_read; + + let decoder = + R::new(max_rep_level, rep_level_encoding, level_data); + + self.rep_level_decoder = Some(decoder); } - if self.descr.max_def_level() > 0 { - let mut def_decoder = LevelDecoder::v1( + if max_def_level > 0 { + let (bytes_read, level_data) = parse_v1_level( + max_def_level, + num_values, def_level_encoding, - self.descr.max_def_level(), - ); - let total_bytes = def_decoder.set_data( - self.num_buffered_values as usize, - buffer_ptr.all(), - ); - buffer_ptr = buffer_ptr.start_from(total_bytes); - self.def_level_decoder = Some(def_decoder); + buf.start_from(offset), + )?; + offset += bytes_read; + + let decoder = + D::new(max_def_level, def_level_encoding, level_data); + + self.def_level_decoder = Some(decoder); } - // Data page v1 does not have offset, all content of buffer - // should be passed - self.set_current_page_encoding( + self.values_decoder.set_data( encoding, - &buffer_ptr, - 0, + buf.start_from(offset), num_values as usize, + None, )?; return Ok(true); } @@ -334,53 +380,52 @@ impl ColumnReaderImpl { buf, num_values, encoding, - num_nulls: _, + num_nulls, num_rows: _, def_levels_byte_len, rep_levels_byte_len, is_compressed: _, statistics: _, } => { + if num_nulls > num_values { + return Err(general_err!("more nulls than values in page, contained {} values and {} nulls", 
num_values, num_nulls)); + } + self.num_buffered_values = num_values; self.num_decoded_values = 0; - let mut offset = 0; - // DataPage v2 only supports RLE encoding for repetition // levels if self.descr.max_rep_level() > 0 { - let mut rep_decoder = - LevelDecoder::v2(self.descr.max_rep_level()); - let bytes_read = rep_decoder.set_data_range( - self.num_buffered_values as usize, - &buf, - offset, - rep_levels_byte_len as usize, + let decoder = R::new( + self.descr.max_rep_level(), + Encoding::RLE, + buf.range(0, rep_levels_byte_len as usize), ); - offset += bytes_read; - self.rep_level_decoder = Some(rep_decoder); + self.rep_level_decoder = Some(decoder); } // DataPage v2 only supports RLE encoding for definition // levels if self.descr.max_def_level() > 0 { - let mut def_decoder = - LevelDecoder::v2(self.descr.max_def_level()); - let bytes_read = def_decoder.set_data_range( - self.num_buffered_values as usize, - &buf, - offset, - def_levels_byte_len as usize, + let decoder = D::new( + self.descr.max_def_level(), + Encoding::RLE, + buf.range( + rep_levels_byte_len as usize, + def_levels_byte_len as usize, + ), ); - offset += bytes_read; - self.def_level_decoder = Some(def_decoder); + self.def_level_decoder = Some(decoder); } - self.set_current_page_encoding( + self.values_decoder.set_data( encoding, - &buf, - offset, + buf.start_from( + (rep_levels_byte_len + def_levels_byte_len) as usize, + ), num_values as usize, + Some((num_values - num_nulls) as usize), )?; return Ok(true); } @@ -392,38 +437,6 @@ impl ColumnReaderImpl { Ok(true) } - /// Resolves and updates encoding and set decoder for the current page - fn set_current_page_encoding( - &mut self, - mut encoding: Encoding, - buffer_ptr: &ByteBufferPtr, - offset: usize, - len: usize, - ) -> Result<()> { - if encoding == Encoding::PLAIN_DICTIONARY { - encoding = Encoding::RLE_DICTIONARY; - } - - let decoder = if encoding == Encoding::RLE_DICTIONARY { - self.decoders - .get_mut(&encoding) - .expect("Decoder for dict should have been set") - } else { - // Search cache for data page decoder - #[allow(clippy::map_entry)] - if !self.decoders.contains_key(&encoding) { - // Initialize decoder for this page - let data_decoder = get_decoder::(self.descr.clone(), encoding)?; - self.decoders.insert(encoding, data_decoder); - } - self.decoders.get_mut(&encoding).unwrap() - }; - - decoder.set_data(buffer_ptr.start_from(offset), len as usize)?; - self.current_encoding = Some(encoding); - Ok(()) - } - #[inline] fn has_next(&mut self) -> Result { if self.num_buffered_values == 0 @@ -440,63 +453,29 @@ impl ColumnReaderImpl { Ok(true) } } +} - #[inline] - fn read_rep_levels(&mut self, buffer: &mut [i16]) -> Result { - let level_decoder = self - .rep_level_decoder - .as_mut() - .expect("rep_level_decoder be set"); - level_decoder.get(buffer) - } - - #[inline] - fn read_def_levels(&mut self, buffer: &mut [i16]) -> Result { - let level_decoder = self - .def_level_decoder - .as_mut() - .expect("def_level_decoder be set"); - level_decoder.get(buffer) - } - - #[inline] - fn read_values(&mut self, buffer: &mut [T::T]) -> Result { - let encoding = self - .current_encoding - .expect("current_encoding should be set"); - let current_decoder = self - .decoders - .get_mut(&encoding) - .unwrap_or_else(|| panic!("decoder for encoding {} should be set", encoding)); - current_decoder.get(buffer) - } - - #[inline] - fn configure_dictionary(&mut self, page: Page) -> Result { - let mut encoding = page.encoding(); - if encoding == Encoding::PLAIN || encoding == 
Encoding::PLAIN_DICTIONARY { - encoding = Encoding::RLE_DICTIONARY - } - - if self.decoders.contains_key(&encoding) { - return Err(general_err!("Column cannot have more than one dictionary")); +fn parse_v1_level( + max_level: i16, + num_buffered_values: u32, + encoding: Encoding, + buf: ByteBufferPtr, +) -> Result<(usize, ByteBufferPtr)> { + match encoding { + Encoding::RLE => { + let i32_size = std::mem::size_of::(); + let data_size = read_num_bytes!(i32, i32_size, buf.as_ref()) as usize; + Ok((i32_size + data_size, buf.range(i32_size, data_size))) } - - if encoding == Encoding::RLE_DICTIONARY { - let mut dictionary = PlainDecoder::::new(self.descr.type_length()); - let num_values = page.num_values(); - dictionary.set_data(page.buffer().clone(), num_values as usize)?; - - let mut decoder = DictDecoder::new(); - decoder.set_dict(Box::new(dictionary))?; - self.decoders.insert(encoding, Box::new(decoder)); - Ok(true) - } else { - Err(nyi_err!( - "Invalid/Unsupported encoding type for dictionary: {}", - encoding - )) + Encoding::BIT_PACKED => { + let bit_width = crate::util::bit_util::log2(max_level as u64 + 1) as u8; + let num_bytes = ceil( + (num_buffered_values as usize * bit_width as usize) as i64, + 8, + ) as usize; + Ok((num_bytes, buf.range(0, num_bytes))) } + _ => Err(general_err!("invalid level encoding: {}", encoding)), } } diff --git a/parquet/src/column/reader/decoder.rs b/parquet/src/column/reader/decoder.rs new file mode 100644 index 000000000000..76b0d079fb85 --- /dev/null +++ b/parquet/src/column/reader/decoder.rs @@ -0,0 +1,265 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
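Before the trait definitions below, it may help to see what the `[i16]` implementation of `count_nulls` computes; the standalone helper and the values here are illustrative and simply mirror that implementation:

```
fn count_nulls(levels: &[i16], range: std::ops::Range<usize>, max_level: i16) -> usize {
    // Every level in the range that is not the maximum definition level is a null.
    levels[range].iter().filter(|&&l| l != max_level).count()
}

// Six slots with max_def_level = 2: positions 1 and 4 are null.
let levels = [2i16, 1, 2, 2, 0, 2];
assert_eq!(count_nulls(&levels, 0..6, 2), 2);
assert_eq!(count_nulls(&levels, 2..4, 2), 0);
```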
+ +use std::collections::HashMap; +use std::ops::Range; + +use crate::basic::Encoding; +use crate::data_type::DataType; +use crate::encodings::{ + decoding::{get_decoder, Decoder, DictDecoder, PlainDecoder}, + rle::RleDecoder, +}; +use crate::errors::{ParquetError, Result}; +use crate::schema::types::ColumnDescPtr; +use crate::util::{bit_util::BitReader, memory::ByteBufferPtr}; + +/// A slice of levels buffer data that is written to by a [`ColumnLevelDecoder`] +pub trait LevelsBufferSlice { + /// Returns the capacity of this slice or `usize::MAX` if no limit + fn capacity(&self) -> usize; + + /// Count the number of levels in `range` not equal to `max_level` + fn count_nulls(&self, range: Range, max_level: i16) -> usize; +} + +impl LevelsBufferSlice for [i16] { + fn capacity(&self) -> usize { + self.len() + } + + fn count_nulls(&self, range: Range, max_level: i16) -> usize { + self[range].iter().filter(|i| **i != max_level).count() + } +} + +/// A slice of values buffer data that is written to by a [`ColumnValueDecoder`] +pub trait ValuesBufferSlice { + /// Returns the capacity of this slice or `usize::MAX` if no limit + fn capacity(&self) -> usize; +} + +impl ValuesBufferSlice for [T] { + fn capacity(&self) -> usize { + self.len() + } +} + +/// Decodes level data to a [`LevelsBufferSlice`] +pub trait ColumnLevelDecoder { + type Slice: LevelsBufferSlice + ?Sized; + + /// Create a new [`ColumnLevelDecoder`] + fn new(max_level: i16, encoding: Encoding, data: ByteBufferPtr) -> Self; + + /// Read level data into `out[range]` returning the number of levels read + /// + /// `range` is provided by the caller to allow for types such as default-initialized `[T]` + /// that only track capacity and not length + /// + /// # Panics + /// + /// Implementations may panic if `range` overlaps with already written data + /// + fn read(&mut self, out: &mut Self::Slice, range: Range) -> Result; +} + +/// Decodes value data to a [`ValuesBufferSlice`] +pub trait ColumnValueDecoder { + type Slice: ValuesBufferSlice + ?Sized; + + /// Create a new [`ColumnValueDecoder`] + fn new(col: &ColumnDescPtr) -> Self; + + /// Set the current dictionary page + fn set_dict( + &mut self, + buf: ByteBufferPtr, + num_values: u32, + encoding: Encoding, + is_sorted: bool, + ) -> Result<()>; + + /// Set the current data page + /// + /// - `encoding` - the encoding of the page + /// - `data` - a point to the page's uncompressed value data + /// - `num_levels` - the number of levels contained within the page, i.e. values including nulls + /// - `num_values` - the number of non-null values contained within the page (V2 page only) + /// + /// Note: data encoded with [`Encoding::RLE`] may not know its exact length, as the final + /// run may be zero-padded. As such if `num_values` is not provided (i.e. 
`None`), + /// subsequent calls to `ColumnValueDecoder::read` may yield more values than + /// non-null definition levels within the page + fn set_data( + &mut self, + encoding: Encoding, + data: ByteBufferPtr, + num_levels: usize, + num_values: Option, + ) -> Result<()>; + + /// Read values data into `out[range]` returning the number of values read + /// + /// `range` is provided by the caller to allow for types such as default-initialized `[T]` + /// that only track capacity and not length + /// + /// # Panics + /// + /// Implementations may panic if `range` overlaps with already written data + /// + fn read(&mut self, out: &mut Self::Slice, range: Range) -> Result; +} + +/// An implementation of [`ColumnValueDecoder`] for `[T::T]` +pub struct ColumnValueDecoderImpl { + descr: ColumnDescPtr, + + current_encoding: Option, + + // Cache of decoders for existing encodings + decoders: HashMap>>, +} + +impl ColumnValueDecoder for ColumnValueDecoderImpl { + type Slice = [T::T]; + + fn new(descr: &ColumnDescPtr) -> Self { + Self { + descr: descr.clone(), + current_encoding: None, + decoders: Default::default(), + } + } + + fn set_dict( + &mut self, + buf: ByteBufferPtr, + num_values: u32, + mut encoding: Encoding, + _is_sorted: bool, + ) -> Result<()> { + if encoding == Encoding::PLAIN || encoding == Encoding::PLAIN_DICTIONARY { + encoding = Encoding::RLE_DICTIONARY + } + + if self.decoders.contains_key(&encoding) { + return Err(general_err!("Column cannot have more than one dictionary")); + } + + if encoding == Encoding::RLE_DICTIONARY { + let mut dictionary = PlainDecoder::::new(self.descr.type_length()); + dictionary.set_data(buf, num_values as usize)?; + + let mut decoder = DictDecoder::new(); + decoder.set_dict(Box::new(dictionary))?; + self.decoders.insert(encoding, Box::new(decoder)); + Ok(()) + } else { + Err(nyi_err!( + "Invalid/Unsupported encoding type for dictionary: {}", + encoding + )) + } + } + + fn set_data( + &mut self, + mut encoding: Encoding, + data: ByteBufferPtr, + num_levels: usize, + num_values: Option, + ) -> Result<()> { + use std::collections::hash_map::Entry; + + if encoding == Encoding::PLAIN_DICTIONARY { + encoding = Encoding::RLE_DICTIONARY; + } + + let decoder = if encoding == Encoding::RLE_DICTIONARY { + self.decoders + .get_mut(&encoding) + .expect("Decoder for dict should have been set") + } else { + // Search cache for data page decoder + match self.decoders.entry(encoding) { + Entry::Occupied(e) => e.into_mut(), + Entry::Vacant(v) => { + let data_decoder = get_decoder::(self.descr.clone(), encoding)?; + v.insert(data_decoder) + } + } + }; + + decoder.set_data(data, num_values.unwrap_or(num_levels))?; + self.current_encoding = Some(encoding); + Ok(()) + } + + fn read(&mut self, out: &mut Self::Slice, range: Range) -> Result { + let encoding = self + .current_encoding + .expect("current_encoding should be set"); + + let current_decoder = self + .decoders + .get_mut(&encoding) + .unwrap_or_else(|| panic!("decoder for encoding {} should be set", encoding)); + + current_decoder.get(&mut out[range]) + } +} + +/// An implementation of [`ColumnLevelDecoder`] for `[i16]` +pub struct ColumnLevelDecoderImpl { + inner: LevelDecoderInner, +} + +enum LevelDecoderInner { + Packed(BitReader, u8), + Rle(RleDecoder), +} + +impl ColumnLevelDecoder for ColumnLevelDecoderImpl { + type Slice = [i16]; + + fn new(max_level: i16, encoding: Encoding, data: ByteBufferPtr) -> Self { + let bit_width = crate::util::bit_util::log2(max_level as u64 + 1) as u8; + match encoding { + 
Encoding::RLE => { + let mut decoder = RleDecoder::new(bit_width); + decoder.set_data(data); + Self { + inner: LevelDecoderInner::Rle(decoder), + } + } + Encoding::BIT_PACKED => Self { + inner: LevelDecoderInner::Packed(BitReader::new(data), bit_width), + }, + _ => unreachable!("invalid level encoding: {}", encoding), + } + } + + fn read(&mut self, out: &mut Self::Slice, range: Range) -> Result { + match &mut self.inner { + LevelDecoderInner::Packed(reader, bit_width) => { + Ok(reader.get_batch::(&mut out[range], *bit_width as usize)) + } + LevelDecoderInner::Rle(reader) => reader.get_batch(&mut out[range]), + } + } +} diff --git a/parquet/src/column/writer.rs b/parquet/src/column/writer.rs index 162941a54dea..d80cafe0e0a8 100644 --- a/parquet/src/column/writer.rs +++ b/parquet/src/column/writer.rs @@ -16,9 +16,9 @@ // under the License. //! Contains column writer API. -use std::{cmp, collections::VecDeque, convert::TryFrom, marker::PhantomData, sync::Arc}; +use std::{cmp, collections::VecDeque, convert::TryFrom, marker::PhantomData}; -use crate::basic::{Compression, Encoding, LogicalType, PageType, Type}; +use crate::basic::{Compression, ConvertedType, Encoding, LogicalType, PageType, Type}; use crate::column::page::{CompressedPage, Page, PageWriteSpec, PageWriter}; use crate::compression::{create_codec, Codec}; use crate::data_type::private::ParquetValueType; @@ -36,18 +36,18 @@ use crate::file::{ }; use crate::schema::types::ColumnDescPtr; use crate::util::bit_util::FromBytes; -use crate::util::memory::{ByteBufferPtr, MemTracker}; +use crate::util::memory::ByteBufferPtr; /// Column writer for a Parquet type. -pub enum ColumnWriter { - BoolColumnWriter(ColumnWriterImpl), - Int32ColumnWriter(ColumnWriterImpl), - Int64ColumnWriter(ColumnWriterImpl), - Int96ColumnWriter(ColumnWriterImpl), - FloatColumnWriter(ColumnWriterImpl), - DoubleColumnWriter(ColumnWriterImpl), - ByteArrayColumnWriter(ColumnWriterImpl), - FixedLenByteArrayColumnWriter(ColumnWriterImpl), +pub enum ColumnWriter<'a> { + BoolColumnWriter(ColumnWriterImpl<'a, BoolType>), + Int32ColumnWriter(ColumnWriterImpl<'a, Int32Type>), + Int64ColumnWriter(ColumnWriterImpl<'a, Int64Type>), + Int96ColumnWriter(ColumnWriterImpl<'a, Int96Type>), + FloatColumnWriter(ColumnWriterImpl<'a, FloatType>), + DoubleColumnWriter(ColumnWriterImpl<'a, DoubleType>), + ByteArrayColumnWriter(ColumnWriterImpl<'a, ByteArrayType>), + FixedLenByteArrayColumnWriter(ColumnWriterImpl<'a, FixedLenByteArrayType>), } pub enum Level { @@ -76,11 +76,11 @@ macro_rules! gen_stats_section { } /// Gets a specific column writer corresponding to column descriptor `descr`. -pub fn get_column_writer( +pub fn get_column_writer<'a>( descr: ColumnDescPtr, props: WriterPropertiesPtr, - page_writer: Box, -) -> ColumnWriter { + page_writer: Box, +) -> ColumnWriter<'a> { match descr.physical_type() { Type::BOOLEAN => ColumnWriter::BoolColumnWriter(ColumnWriterImpl::new( descr, @@ -139,9 +139,9 @@ pub fn get_typed_column_writer( } /// Similar to `get_typed_column_writer` but returns a reference. -pub fn get_typed_column_writer_ref( - col_writer: &ColumnWriter, -) -> &ColumnWriterImpl { +pub fn get_typed_column_writer_ref<'a, 'b: 'a, T: DataType>( + col_writer: &'b ColumnWriter<'a>, +) -> &'b ColumnWriterImpl<'a, T> { T::get_column_writer_ref(col_writer).unwrap_or_else(|| { panic!( "Failed to convert column writer into a typed column writer for `{}` type", @@ -151,9 +151,9 @@ pub fn get_typed_column_writer_ref( } /// Similar to `get_typed_column_writer` but returns a reference. 
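These typed accessors are how callers go from the untyped `ColumnWriter` enum to a writer they can feed values; a minimal sketch for an `INT32` column, where the values and the surrounding setup are assumptions rather than part of the patch:

```
use parquet::column::writer::{get_typed_column_writer_mut, ColumnWriter};
use parquet::data_type::Int32Type;
use parquet::errors::Result;

fn write_ints(col_writer: &mut ColumnWriter<'_>) -> Result<usize> {
    // Borrows the untyped writer as an Int32 writer; this panics if the
    // column's physical type is not INT32.
    let typed = get_typed_column_writer_mut::<Int32Type>(col_writer);
    typed.write_batch(&[1, 2, 3], None, None)
}
```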
-pub fn get_typed_column_writer_mut( - col_writer: &mut ColumnWriter, -) -> &mut ColumnWriterImpl { +pub fn get_typed_column_writer_mut<'a, 'b: 'a, T: DataType>( + col_writer: &'a mut ColumnWriter<'b>, +) -> &'a mut ColumnWriterImpl<'b, T> { T::get_column_writer_mut(col_writer).unwrap_or_else(|| { panic!( "Failed to convert column writer into a typed column writer for `{}` type", @@ -163,11 +163,11 @@ pub fn get_typed_column_writer_mut( } /// Typed column writer for a primitive column. -pub struct ColumnWriterImpl { +pub struct ColumnWriterImpl<'a, T: DataType> { // Column writer properties descr: ColumnDescPtr, props: WriterPropertiesPtr, - page_writer: Box, + page_writer: Box, has_dictionary: bool, dict_encoder: Option>, encoder: Box>, @@ -200,11 +200,11 @@ pub struct ColumnWriterImpl { _phantom: PhantomData, } -impl ColumnWriterImpl { +impl<'a, T: DataType> ColumnWriterImpl<'a, T> { pub fn new( descr: ColumnDescPtr, props: WriterPropertiesPtr, - page_writer: Box, + page_writer: Box, ) -> Self { let codec = props.compression(descr.path()); let compressor = create_codec(codec).unwrap(); @@ -213,7 +213,7 @@ impl ColumnWriterImpl { let dict_encoder = if props.dictionary_enabled(descr.path()) && has_dictionary_support(T::get_physical_type(), &props) { - Some(DictEncoder::new(descr.clone(), Arc::new(MemTracker::new()))) + Some(DictEncoder::new(descr.clone())) } else { None }; @@ -227,7 +227,6 @@ impl ColumnWriterImpl { props .encoding(descr.path()) .unwrap_or_else(|| fallback_encoding(T::get_physical_type(), &props)), - Arc::new(MemTracker::new()), ) .unwrap(); @@ -270,8 +269,8 @@ impl ColumnWriterImpl { values: &[T::T], def_levels: Option<&[i16]>, rep_levels: Option<&[i16]>, - min: &Option, - max: &Option, + min: Option<&T::T>, + max: Option<&T::T>, null_count: Option, distinct_count: Option, ) -> Result { @@ -376,9 +375,7 @@ impl ColumnWriterImpl { def_levels: Option<&[i16]>, rep_levels: Option<&[i16]>, ) -> Result { - self.write_batch_internal( - values, def_levels, rep_levels, &None, &None, None, None, - ) + self.write_batch_internal(values, def_levels, rep_levels, None, None, None, None) } /// Writer may optionally provide pre-calculated statistics for this batch, in which case we do @@ -389,8 +386,8 @@ impl ColumnWriterImpl { values: &[T::T], def_levels: Option<&[i16]>, rep_levels: Option<&[i16]>, - min: &Option, - max: &Option, + min: Option<&T::T>, + max: Option<&T::T>, nulls_count: Option, distinct_count: Option, ) -> Result { @@ -567,6 +564,14 @@ impl ColumnWriterImpl { /// Returns true if there is enough data for a data page, false otherwise. #[inline] fn should_add_data_page(&self) -> bool { + // This is necessary in the event of a much larger dictionary size than page size + // + // In such a scenario the dictionary decoder may return an estimated encoded + // size in excess of the page size limit, even when there are no buffered values + if self.num_buffered_values == 0 { + return false; + } + match self.dict_encoder { Some(ref encoder) => { encoder.estimated_data_encoded_size() >= self.props.data_pagesize_limit() @@ -994,12 +999,47 @@ impl ColumnWriterImpl { /// Evaluate `a > b` according to underlying logical type. fn compare_greater(&self, a: &T::T, b: &T::T) -> bool { - if let Some(LogicalType::INTEGER(int_type)) = self.descr.logical_type() { - if !int_type.is_signed { + if let Some(LogicalType::Integer { is_signed, .. 
}) = self.descr.logical_type() { + if !is_signed { // need to compare unsigned return a.as_u64().unwrap() > b.as_u64().unwrap(); } } + + match self.descr.converted_type() { + ConvertedType::UINT_8 + | ConvertedType::UINT_16 + | ConvertedType::UINT_32 + | ConvertedType::UINT_64 => { + return a.as_u64().unwrap() > b.as_u64().unwrap(); + } + _ => {} + }; + + if let Some(LogicalType::Decimal { .. }) = self.descr.logical_type() { + match self.descr.physical_type() { + Type::FIXED_LEN_BYTE_ARRAY | Type::BYTE_ARRAY => { + return compare_greater_byte_array_decimals( + a.as_bytes(), + b.as_bytes(), + ); + } + _ => {} + }; + } + + if self.descr.converted_type() == ConvertedType::DECIMAL { + match self.descr.physical_type() { + Type::FIXED_LEN_BYTE_ARRAY | Type::BYTE_ARRAY => { + return compare_greater_byte_array_decimals( + a.as_bytes(), + b.as_bytes(), + ); + } + _ => {} + }; + }; + a > b } } @@ -1043,23 +1083,70 @@ fn has_dictionary_support(kind: Type, props: &WriterProperties) -> bool { } } +/// Signed comparison of bytes arrays +fn compare_greater_byte_array_decimals(a: &[u8], b: &[u8]) -> bool { + let a_length = a.len(); + let b_length = b.len(); + + if a_length == 0 || b_length == 0 { + return a_length > 0; + } + + let first_a: u8 = a[0]; + let first_b: u8 = b[0]; + + // We can short circuit for different signed numbers or + // for equal length bytes arrays that have different first bytes. + // The equality requirement is necessary for sign extension cases. + // 0xFF10 should be equal to 0x10 (due to big endian sign extension). + if (0x80 & first_a) != (0x80 & first_b) + || (a_length == b_length && first_a != first_b) + { + return (first_a as i8) > (first_b as i8); + } + + // When the lengths are unequal and the numbers are of the same + // sign we need to do comparison by sign extending the shorter + // value first, and once we get to equal sized arrays, lexicographical + // unsigned comparison of everything but the first byte is sufficient. 
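The sign-extension rule above is easiest to see on the byte strings exercised by the unit tests added later in this file; within this module the following assertions hold:

```
// [0xFF, 0x23, 0x00, 0x00] is a negative big-endian two's-complement value,
// so it is less than the single zero byte.
assert!(!compare_greater_byte_array_decimals(&[255u8, 35u8, 0u8, 0u8], &[0u8]));
assert!(compare_greater_byte_array_decimals(&[0u8], &[255u8, 35u8, 0u8, 0u8]));

// With matching signs the shorter operand is sign-extended first,
// so the two-byte value 0x0100 still compares greater than 0x00.
assert!(compare_greater_byte_array_decimals(&[1u8, 0u8], &[0u8]));
```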
+ + let extension: u8 = if (first_a as i8) < 0 { 0xFF } else { 0 }; + + if a_length != b_length { + let not_equal = if a_length > b_length { + let lead_length = a_length - b_length; + (&a[0..lead_length]).iter().any(|&x| x != extension) + } else { + let lead_length = b_length - a_length; + (&b[0..lead_length]).iter().any(|&x| x != extension) + }; + + if not_equal { + let negative_values: bool = (first_a as i8) < 0; + let a_longer: bool = a_length > b_length; + return if negative_values { !a_longer } else { a_longer }; + } + } + + (a[1..]) > (b[1..]) +} + #[cfg(test)] mod tests { use rand::distributions::uniform::SampleUniform; + use std::sync::Arc; use crate::column::{ page::PageReader, reader::{get_column_reader, get_typed_column_reader, ColumnReaderImpl}, }; + use crate::file::writer::TrackedWrite; use crate::file::{ properties::WriterProperties, reader::SerializedPageReader, writer::SerializedPageWriter, }; use crate::schema::types::{ColumnDescriptor, ColumnPath, Type as SchemaType}; - use crate::util::{ - io::{FileSink, FileSource}, - test_common::{get_temp_file, random_numbers_range}, - }; + use crate::util::{io::FileSource, test_common::random_numbers_range}; use super::*; @@ -1469,6 +1556,84 @@ mod tests { } } + #[test] + fn test_column_writer_check_byte_array_min_max() { + let page_writer = get_test_page_writer(); + let props = Arc::new(WriterProperties::builder().build()); + let mut writer = + get_test_decimals_column_writer::(page_writer, 0, 0, props); + writer + .write_batch( + &[ + ByteArray::from(vec![ + 255u8, 255u8, 255u8, 255u8, 255u8, 255u8, 255u8, 255u8, 179u8, + 172u8, 19u8, 35u8, 231u8, 90u8, 0u8, 0u8, + ]), + ByteArray::from(vec![ + 255u8, 255u8, 255u8, 255u8, 255u8, 255u8, 255u8, 255u8, 228u8, + 62u8, 146u8, 152u8, 177u8, 56u8, 0u8, 0u8, + ]), + ByteArray::from(vec![ + 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, + 0u8, 0u8, 0u8, + ]), + ByteArray::from(vec![ + 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 41u8, 162u8, 36u8, 26u8, + 246u8, 44u8, 0u8, 0u8, + ]), + ], + None, + None, + ) + .unwrap(); + let (_bytes_written, _rows_written, metadata) = writer.close().unwrap(); + if let Some(stats) = metadata.statistics() { + assert!(stats.has_min_max_set()); + if let Statistics::ByteArray(stats) = stats { + assert_eq!( + stats.min(), + &ByteArray::from(vec![ + 255u8, 255u8, 255u8, 255u8, 255u8, 255u8, 255u8, 255u8, 179u8, + 172u8, 19u8, 35u8, 231u8, 90u8, 0u8, 0u8, + ]) + ); + assert_eq!( + stats.max(), + &ByteArray::from(vec![ + 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 41u8, 162u8, 36u8, 26u8, + 246u8, 44u8, 0u8, 0u8, + ]) + ); + } else { + panic!("expecting Statistics::ByteArray"); + } + } else { + panic!("metadata missing statistics"); + } + } + + #[test] + fn test_column_writer_uint32_converted_type_min_max() { + let page_writer = get_test_page_writer(); + let props = Arc::new(WriterProperties::builder().build()); + let mut writer = get_test_unsigned_int_given_as_converted_column_writer::< + Int32Type, + >(page_writer, 0, 0, props); + writer.write_batch(&[0, 1, 2, 3, 4, 5], None, None).unwrap(); + let (_bytes_written, _rows_written, metadata) = writer.close().unwrap(); + if let Some(stats) = metadata.statistics() { + assert!(stats.has_min_max_set()); + if let Statistics::Int32(stats) = stats { + assert_eq!(stats.min(), &0,); + assert_eq!(stats.max(), &5,); + } else { + panic!("expecting Statistics::Int32"); + } + } else { + panic!("metadata missing statistics"); + } + } + #[test] fn test_column_writer_precalculated_statistics() { let page_writer = 
get_test_page_writer(); @@ -1479,8 +1644,8 @@ mod tests { &[1, 2, 3, 4], None, None, - &Some(-17), - &Some(9000), + Some(&-17), + Some(&9000), Some(21), Some(55), ) @@ -1516,14 +1681,13 @@ mod tests { #[test] fn test_column_writer_empty_column_roundtrip() { let props = WriterProperties::builder().build(); - column_roundtrip::("test_col_writer_rnd_1", props, &[], None, None); + column_roundtrip::(props, &[], None, None); } #[test] fn test_column_writer_non_nullable_values_roundtrip() { let props = WriterProperties::builder().build(); column_roundtrip_random::( - "test_col_writer_rnd_2", props, 1024, std::i32::MIN, @@ -1537,7 +1701,6 @@ mod tests { fn test_column_writer_nullable_non_repeated_values_roundtrip() { let props = WriterProperties::builder().build(); column_roundtrip_random::( - "test_column_writer_nullable_non_repeated_values_roundtrip", props, 1024, std::i32::MIN, @@ -1551,7 +1714,6 @@ mod tests { fn test_column_writer_nullable_repeated_values_roundtrip() { let props = WriterProperties::builder().build(); column_roundtrip_random::( - "test_col_writer_rnd_3", props, 1024, std::i32::MIN, @@ -1568,7 +1730,6 @@ mod tests { .set_data_pagesize_limit(32) .build(); column_roundtrip_random::( - "test_col_writer_rnd_4", props, 1024, std::i32::MIN, @@ -1584,7 +1745,6 @@ mod tests { let props = WriterProperties::builder().set_write_batch_size(*i).build(); column_roundtrip_random::( - "test_col_writer_rnd_5", props, 1024, std::i32::MIN, @@ -1602,7 +1762,6 @@ mod tests { .set_dictionary_enabled(false) .build(); column_roundtrip_random::( - "test_col_writer_rnd_6", props, 1024, std::i32::MIN, @@ -1619,7 +1778,6 @@ mod tests { .set_dictionary_enabled(false) .build(); column_roundtrip_random::( - "test_col_writer_rnd_7", props, 1024, std::i32::MIN, @@ -1636,7 +1794,6 @@ mod tests { .set_compression(Compression::SNAPPY) .build(); column_roundtrip_random::( - "test_col_writer_rnd_8", props, 2048, std::i32::MIN, @@ -1653,7 +1810,6 @@ mod tests { .set_compression(Compression::SNAPPY) .build(); column_roundtrip_random::( - "test_col_writer_rnd_9", props, 2048, std::i32::MIN, @@ -1667,9 +1823,9 @@ mod tests { fn test_column_writer_add_data_pages_with_dict() { // ARROW-5129: Test verifies that we add data page in case of dictionary encoding // and no fallback occurred so far. 
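The properties that drive this regression test keep dictionary encoding enabled (the default) while forcing very small data pages, so several pages are cut before any fallback can happen. A sketch of such a configuration; the exact limits are illustrative:

```
use parquet::file::properties::WriterProperties;

// A tiny page size limit makes the writer emit a new dictionary-encoded
// data page after only a handful of values.
let props = WriterProperties::builder()
    .set_data_pagesize_limit(15)
    .set_write_batch_size(3)
    .build();
```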
- let file = get_temp_file("test_column_writer_add_data_pages_with_dict", &[]); - let sink = FileSink::new(&file); - let page_writer = Box::new(SerializedPageWriter::new(sink)); + let mut file = tempfile::tempfile().unwrap(); + let mut writer = TrackedWrite::new(&mut file); + let page_writer = Box::new(SerializedPageWriter::new(&mut writer)); let props = Arc::new( WriterProperties::builder() .set_data_pagesize_limit(15) // actually each page will have size 15-18 bytes @@ -1890,11 +2046,32 @@ mod tests { assert!(matches!(stats, Statistics::Double(_))); } + #[test] + fn test_compare_greater_byte_array_decimals() { + assert!(!compare_greater_byte_array_decimals(&[], &[],),); + assert!(compare_greater_byte_array_decimals(&[1u8,], &[],),); + assert!(!compare_greater_byte_array_decimals(&[], &[1u8,],),); + assert!(compare_greater_byte_array_decimals(&[1u8,], &[0u8,],),); + assert!(!compare_greater_byte_array_decimals(&[1u8,], &[1u8,],),); + assert!(compare_greater_byte_array_decimals(&[1u8, 0u8,], &[0u8,],),); + assert!(!compare_greater_byte_array_decimals( + &[0u8, 1u8,], + &[1u8, 0u8,], + ),); + assert!(!compare_greater_byte_array_decimals( + &[255u8, 35u8, 0u8, 0u8,], + &[0u8,], + ),); + assert!(compare_greater_byte_array_decimals( + &[0u8,], + &[255u8, 35u8, 0u8, 0u8,], + ),); + } + /// Performs write-read roundtrip with randomly generated values and levels. /// `max_size` is maximum number of values or levels (if `max_def_level` > 0) to write /// for a column. fn column_roundtrip_random( - file_name: &str, props: WriterProperties, max_size: usize, min_value: T::T, @@ -1931,20 +2108,19 @@ mod tests { let mut values: Vec = Vec::new(); random_numbers_range(num_values, min_value, max_value, &mut values); - column_roundtrip::(file_name, props, &values[..], def_levels, rep_levels); + column_roundtrip::(props, &values[..], def_levels, rep_levels); } /// Performs write-read roundtrip and asserts written values and levels. - fn column_roundtrip<'a, T: DataType>( - file_name: &'a str, + fn column_roundtrip( props: WriterProperties, values: &[T::T], def_levels: Option<&[i16]>, rep_levels: Option<&[i16]>, ) { - let file = get_temp_file(file_name, &[]); - let sink = FileSink::new(&file); - let page_writer = Box::new(SerializedPageWriter::new(sink)); + let mut file = tempfile::tempfile().unwrap(); + let mut writer = TrackedWrite::new(&mut file); + let page_writer = Box::new(SerializedPageWriter::new(&mut writer)); let max_def_level = match def_levels { Some(buf) => *buf.iter().max().unwrap_or(&0i16), @@ -2079,12 +2255,12 @@ mod tests { } /// Returns column writer. - fn get_test_column_writer( - page_writer: Box, + fn get_test_column_writer<'a, T: DataType>( + page_writer: Box, max_def_level: i16, max_rep_level: i16, props: WriterPropertiesPtr, - ) -> ColumnWriterImpl { + ) -> ColumnWriterImpl<'a, T> { let descr = Arc::new(get_test_column_descr::(max_def_level, max_rep_level)); let column_writer = get_column_writer(descr, props, page_writer); get_typed_column_writer::(column_writer) @@ -2158,4 +2334,94 @@ mod tests { panic!("metadata missing statistics"); } } + + /// Returns Decimals column writer. 
+ fn get_test_decimals_column_writer( + page_writer: Box, + max_def_level: i16, + max_rep_level: i16, + props: WriterPropertiesPtr, + ) -> ColumnWriterImpl<'static, T> { + let descr = Arc::new(get_test_decimals_column_descr::( + max_def_level, + max_rep_level, + )); + let column_writer = get_column_writer(descr, props, page_writer); + get_typed_column_writer::(column_writer) + } + + /// Returns decimals column reader. + fn get_test_decimals_column_reader( + page_reader: Box, + max_def_level: i16, + max_rep_level: i16, + ) -> ColumnReaderImpl { + let descr = Arc::new(get_test_decimals_column_descr::( + max_def_level, + max_rep_level, + )); + let column_reader = get_column_reader(descr, page_reader); + get_typed_column_reader::(column_reader) + } + + /// Returns descriptor for Decimal type with primitive column. + fn get_test_decimals_column_descr( + max_def_level: i16, + max_rep_level: i16, + ) -> ColumnDescriptor { + let path = ColumnPath::from("col"); + let tpe = SchemaType::primitive_type_builder("col", T::get_physical_type()) + .with_length(16) + .with_logical_type(Some(LogicalType::Decimal { + scale: 2, + precision: 3, + })) + .with_scale(2) + .with_precision(3) + .build() + .unwrap(); + ColumnDescriptor::new(Arc::new(tpe), max_def_level, max_rep_level, path) + } + + /// Returns column writer for UINT32 Column provided as ConvertedType only + fn get_test_unsigned_int_given_as_converted_column_writer<'a, T: DataType>( + page_writer: Box, + max_def_level: i16, + max_rep_level: i16, + props: WriterPropertiesPtr, + ) -> ColumnWriterImpl<'a, T> { + let descr = Arc::new(get_test_converted_type_unsigned_integer_column_descr::( + max_def_level, + max_rep_level, + )); + let column_writer = get_column_writer(descr, props, page_writer); + get_typed_column_writer::(column_writer) + } + + /// Returns column reader for UINT32 Column provided as ConvertedType only + fn get_test_unsigned_int_given_as_converted_column_reader( + page_reader: Box, + max_def_level: i16, + max_rep_level: i16, + ) -> ColumnReaderImpl { + let descr = Arc::new(get_test_converted_type_unsigned_integer_column_descr::( + max_def_level, + max_rep_level, + )); + let column_reader = get_column_reader(descr, page_reader); + get_typed_column_reader::(column_reader) + } + + /// Returns column descriptor for UINT32 Column provided as ConvertedType only + fn get_test_converted_type_unsigned_integer_column_descr( + max_def_level: i16, + max_rep_level: i16, + ) -> ColumnDescriptor { + let path = ColumnPath::from("col"); + let tpe = SchemaType::primitive_type_builder("col", T::get_physical_type()) + .with_converted_type(ConvertedType::UINT_32) + .build() + .unwrap(); + ColumnDescriptor::new(Arc::new(tpe), max_def_level, max_rep_level, path) + } } diff --git a/parquet/src/compression.rs b/parquet/src/compression.rs index a1155971fbd2..a5e49360a28a 100644 --- a/parquet/src/compression.rs +++ b/parquet/src/compression.rs @@ -20,38 +20,43 @@ //! See [`Compression`](crate::basic::Compression) enum for all available compression //! algorithms. //! -//! # Example -//! -//! ```no_run -//! use parquet::{basic::Compression, compression::create_codec}; -//! -//! let mut codec = match create_codec(Compression::SNAPPY) { -//! Ok(Some(codec)) => codec, -//! _ => panic!(), -//! }; -//! -//! let data = vec![b'p', b'a', b'r', b'q', b'u', b'e', b't']; -//! let mut compressed = vec![]; -//! codec.compress(&data[..], &mut compressed).unwrap(); -//! -//! let mut output = vec![]; -//! codec.decompress(&compressed[..], &mut output).unwrap(); -//! -//! 
assert_eq!(output, data); -//! ``` - +#[cfg_attr( + feature = "experimental", + doc = r##" +# Example + +```no_run +use parquet::{basic::Compression, compression::create_codec}; + +let mut codec = match create_codec(Compression::SNAPPY) { + Ok(Some(codec)) => codec, + _ => panic!(), +}; + +let data = vec![b'p', b'a', b'r', b'q', b'u', b'e', b't']; +let mut compressed = vec![]; +codec.compress(&data[..], &mut compressed).unwrap(); + +let mut output = vec![]; +codec.decompress(&compressed[..], &mut output).unwrap(); + +assert_eq!(output, data); +``` +"## +)] use crate::basic::Compression as CodecType; use crate::errors::{ParquetError, Result}; /// Parquet compression codec interface. -pub trait Codec { - /// Compresses data stored in slice `input_buf` and writes the compressed result +pub trait Codec: Send { + /// Compresses data stored in slice `input_buf` and appends the compressed result /// to `output_buf`. + /// /// Note that you'll need to call `clear()` before reusing the same `output_buf` /// across different `compress` calls. fn compress(&mut self, input_buf: &[u8], output_buf: &mut Vec) -> Result<()>; - /// Decompresses data stored in slice `input_buf` and writes output to `output_buf`. + /// Decompresses data stored in slice `input_buf` and appends output to `output_buf`. /// Returns the total number of bytes written. fn decompress(&mut self, input_buf: &[u8], output_buf: &mut Vec) -> Result; @@ -107,9 +112,10 @@ mod snappy_codec { output_buf: &mut Vec, ) -> Result { let len = decompress_len(input_buf)?; - output_buf.resize(len, 0); + let offset = output_buf.len(); + output_buf.resize(offset + len, 0); self.decoder - .decompress(input_buf, output_buf) + .decompress(input_buf, &mut output_buf[offset..]) .map_err(|e| e.into()) } @@ -336,13 +342,13 @@ mod tests { .expect("Error when compressing"); // Decompress with c2 - let mut decompressed_size = c2 + let decompressed_size = c2 .decompress(compressed.as_slice(), &mut decompressed) .expect("Error when decompressing"); assert_eq!(data.len(), decompressed_size); - decompressed.truncate(decompressed_size); assert_eq!(data, decompressed.as_slice()); + decompressed.clear(); compressed.clear(); // Compress with c2 @@ -350,12 +356,32 @@ mod tests { .expect("Error when compressing"); // Decompress with c1 - decompressed_size = c1 + let decompressed_size = c1 .decompress(compressed.as_slice(), &mut decompressed) .expect("Error when decompressing"); assert_eq!(data.len(), decompressed_size); - decompressed.truncate(decompressed_size); assert_eq!(data, decompressed.as_slice()); + + decompressed.clear(); + compressed.clear(); + + // Test does not trample existing data in output buffers + let prefix = &[0xDE, 0xAD, 0xBE, 0xEF]; + decompressed.extend_from_slice(prefix); + compressed.extend_from_slice(prefix); + + c2.compress(data, &mut compressed) + .expect("Error when compressing"); + + assert_eq!(&compressed[..4], prefix); + + let decompressed_size = c2 + .decompress(&compressed[4..], &mut decompressed) + .expect("Error when decompressing"); + + assert_eq!(data.len(), decompressed_size); + assert_eq!(data, &decompressed[4..]); + assert_eq!(&decompressed[..4], prefix); } fn test_codec(c: CodecType) { diff --git a/parquet/src/data_type.rs b/parquet/src/data_type.rs index 8c64e8629463..86ccefbd85eb 100644 --- a/parquet/src/data_type.rs +++ b/parquet/src/data_type.rs @@ -30,13 +30,13 @@ use crate::column::reader::{ColumnReader, ColumnReaderImpl}; use crate::column::writer::{ColumnWriter, ColumnWriterImpl}; use crate::errors::{ParquetError, Result}; 
use crate::util::{ - bit_util::{from_ne_slice, FromBytes}, - memory::{ByteBuffer, ByteBufferPtr}, + bit_util::{from_le_slice, from_ne_slice, FromBytes}, + memory::ByteBufferPtr, }; /// Rust representation for logical type INT96, value is backed by an array of `u32`. /// The type only takes 12 bytes, without extra padding. -#[derive(Clone, Debug, PartialOrd)] +#[derive(Clone, Debug, PartialOrd, Default)] pub struct Int96 { value: Option<[u32; 3]>, } @@ -75,12 +75,6 @@ impl Int96 { } } -impl Default for Int96 { - fn default() -> Self { - Self { value: None } - } -} - impl PartialEq for Int96 { fn eq(&self, other: &Int96) -> bool { match (&self.value, &other.value) { @@ -109,7 +103,7 @@ impl fmt::Display for Int96 { /// Rust representation for BYTE_ARRAY and FIXED_LEN_BYTE_ARRAY Parquet physical types. /// Value is backed by a byte buffer. -#[derive(Clone)] +#[derive(Clone, Default)] pub struct ByteArray { data: Option, } @@ -223,20 +217,6 @@ impl From for ByteArray { } } -impl From for ByteArray { - fn from(mut buf: ByteBuffer) -> ByteArray { - Self { - data: Some(buf.consume()), - } - } -} - -impl Default for ByteArray { - fn default() -> Self { - ByteArray { data: None } - } -} - impl PartialEq for ByteArray { fn eq(&self, other: &ByteArray) -> bool { match (&self.data, &other.data) { @@ -610,6 +590,8 @@ pub(crate) mod private { + super::FromBytes + super::SliceAsBytes + PartialOrd + + Send + + crate::encodings::decoding::private::GetDecoder { /// Encode the value directly from a higher level encoder fn encode( @@ -858,17 +840,12 @@ pub(crate) mod private { decoder.start += bytes_to_decode; let mut pos = 0; // position in byte array - for i in 0..num_values { + for item in buffer.iter_mut().take(num_values) { let elem0 = byteorder::LittleEndian::read_u32(&bytes[pos..pos + 4]); let elem1 = byteorder::LittleEndian::read_u32(&bytes[pos + 4..pos + 8]); let elem2 = byteorder::LittleEndian::read_u32(&bytes[pos + 8..pos + 12]); - buffer[i] - .as_mut_any() - .downcast_mut::() - .unwrap() - .set_data(elem0, elem1, elem2); - + item.set_data(elem0, elem1, elem2); pos += 12; } decoder.num_values -= num_values; @@ -1012,16 +989,15 @@ pub(crate) mod private { .as_mut() .expect("set_data should have been called"); let num_values = std::cmp::min(buffer.len(), decoder.num_values); - for i in 0..num_values { + + for item in buffer.iter_mut().take(num_values) { let len = decoder.type_length as usize; if data.len() < decoder.start + len { return Err(eof_err!("Not enough bytes to decode")); } - let val: &mut Self = buffer[i].as_mut_any().downcast_mut().unwrap(); - - val.set_data(data.range(decoder.start, len)); + item.set_data(data.range(decoder.start, len)); decoder.start += len; } decoder.num_values -= num_values; @@ -1048,7 +1024,7 @@ pub(crate) mod private { /// Contains the Parquet physical type information as well as the Rust primitive type /// presentation. -pub trait DataType: 'static { +pub trait DataType: 'static + Send { type T: private::ParquetValueType; /// Returns Parquet physical type. 
@@ -1061,19 +1037,21 @@ pub trait DataType: 'static { where Self: Sized; - fn get_column_writer(column_writer: ColumnWriter) -> Option> + fn get_column_writer( + column_writer: ColumnWriter<'_>, + ) -> Option> where Self: Sized; - fn get_column_writer_ref( - column_writer: &ColumnWriter, - ) -> Option<&ColumnWriterImpl> + fn get_column_writer_ref<'a, 'b: 'a>( + column_writer: &'b ColumnWriter<'a>, + ) -> Option<&'b ColumnWriterImpl<'a, Self>> where Self: Sized; - fn get_column_writer_mut( - column_writer: &mut ColumnWriter, - ) -> Option<&mut ColumnWriterImpl> + fn get_column_writer_mut<'a, 'b: 'a>( + column_writer: &'a mut ColumnWriter<'b>, + ) -> Option<&'a mut ColumnWriterImpl<'b, Self>> where Self: Sized; } @@ -1118,26 +1096,26 @@ macro_rules! make_type { } fn get_column_writer( - column_writer: ColumnWriter, - ) -> Option> { + column_writer: ColumnWriter<'_>, + ) -> Option> { match column_writer { ColumnWriter::$writer_ident(w) => Some(w), _ => None, } } - fn get_column_writer_ref( - column_writer: &ColumnWriter, - ) -> Option<&ColumnWriterImpl> { + fn get_column_writer_ref<'a, 'b: 'a>( + column_writer: &'a ColumnWriter<'b>, + ) -> Option<&'a ColumnWriterImpl<'b, Self>> { match column_writer { ColumnWriter::$writer_ident(w) => Some(w), _ => None, } } - fn get_column_writer_mut( - column_writer: &mut ColumnWriter, - ) -> Option<&mut ColumnWriterImpl> { + fn get_column_writer_mut<'a, 'b: 'a>( + column_writer: &'a mut ColumnWriter<'b>, + ) -> Option<&'a mut ColumnWriterImpl<'b, Self>> { match column_writer { ColumnWriter::$writer_ident(w) => Some(w), _ => None, @@ -1216,8 +1194,14 @@ make_type!( impl FromBytes for Int96 { type Buffer = [u8; 12]; - fn from_le_bytes(_bs: Self::Buffer) -> Self { - unimplemented!() + fn from_le_bytes(bs: Self::Buffer) -> Self { + let mut i = Int96::new(); + i.set_data( + from_le_slice(&bs[0..4]), + from_le_slice(&bs[4..8]), + from_le_slice(&bs[8..12]), + ); + i } fn from_be_bytes(_bs: Self::Buffer) -> Self { unimplemented!() @@ -1237,8 +1221,8 @@ impl FromBytes for Int96 { // appear to actual be converted directly from bytes impl FromBytes for ByteArray { type Buffer = [u8; 8]; - fn from_le_bytes(_bs: Self::Buffer) -> Self { - unreachable!() + fn from_le_bytes(bs: Self::Buffer) -> Self { + ByteArray::from(bs.to_vec()) } fn from_be_bytes(_bs: Self::Buffer) -> Self { unreachable!() @@ -1251,8 +1235,8 @@ impl FromBytes for ByteArray { impl FromBytes for FixedLenByteArray { type Buffer = [u8; 8]; - fn from_le_bytes(_bs: Self::Buffer) -> Self { - unreachable!() + fn from_le_bytes(bs: Self::Buffer) -> Self { + Self(ByteArray::from(bs.to_vec())) } fn from_be_bytes(_bs: Self::Buffer) -> Self { unreachable!() @@ -1338,8 +1322,7 @@ mod tests { ByteArray::from(ByteBufferPtr::new(vec![1u8, 2u8, 3u8, 4u8, 5u8])).data(), &[1u8, 2u8, 3u8, 4u8, 5u8] ); - let mut buf = ByteBuffer::new(); - buf.set_data(vec![6u8, 7u8, 8u8, 9u8, 10u8]); + let buf = vec![6u8, 7u8, 8u8, 9u8, 10u8]; assert_eq!(ByteArray::from(buf).data(), &[6u8, 7u8, 8u8, 9u8, 10u8]); } diff --git a/parquet/src/encodings/decoding.rs b/parquet/src/encodings/decoding.rs index 48fc108840e9..7c95d553234b 100644 --- a/parquet/src/encodings/decoding.rs +++ b/parquet/src/encodings/decoding.rs @@ -17,25 +17,132 @@ //! Contains all supported decoders for Parquet. 
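For background on the `Codec` doc changes in compression.rs above: compress and decompress now append to the output buffer rather than overwrite it. A minimal usage sketch, assuming the SNAPPY codec is compiled in and the compression module is visible (it may sit behind the crate's experimental feature):

use parquet::basic::Compression;
use parquet::compression::create_codec;

fn main() {
    let mut codec = create_codec(Compression::SNAPPY).unwrap().unwrap();

    let data = b"parquet".to_vec();
    let mut compressed = vec![0xDE, 0xAD]; // pre-existing prefix must survive
    codec.compress(&data, &mut compressed).unwrap();
    assert_eq!(compressed[..2], [0xDE, 0xAD]);

    let mut output = Vec::new();
    let written = codec.decompress(&compressed[2..], &mut output).unwrap();
    assert_eq!(written, data.len());
    assert_eq!(output, data);
}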
+use num::traits::WrappingAdd; +use num::FromPrimitive; use std::{cmp, marker::PhantomData, mem}; use super::rle::RleDecoder; use crate::basic::*; -use crate::data_type::private::*; +use crate::data_type::private::ParquetValueType; use crate::data_type::*; use crate::errors::{ParquetError, Result}; use crate::schema::types::ColumnDescPtr; use crate::util::{ - bit_util::{self, BitReader, FromBytes}, - memory::{ByteBuffer, ByteBufferPtr}, + bit_util::{self, BitReader}, + memory::ByteBufferPtr, }; +pub(crate) mod private { + use super::*; + + /// A trait that allows getting a [`Decoder`] implementation for a [`DataType`] with + /// the corresponding [`ParquetValueType`]. This is necessary to support + /// [`Decoder`] implementations that may not be applicable for all [`DataType`] + /// and by extension all [`ParquetValueType`] + pub trait GetDecoder { + fn get_decoder>( + descr: ColumnDescPtr, + encoding: Encoding, + ) -> Result>> { + get_decoder_default(descr, encoding) + } + } + + fn get_decoder_default( + descr: ColumnDescPtr, + encoding: Encoding, + ) -> Result>> { + match encoding { + Encoding::PLAIN => Ok(Box::new(PlainDecoder::new(descr.type_length()))), + Encoding::RLE_DICTIONARY | Encoding::PLAIN_DICTIONARY => Err(general_err!( + "Cannot initialize this encoding through this function" + )), + Encoding::RLE + | Encoding::DELTA_BINARY_PACKED + | Encoding::DELTA_BYTE_ARRAY + | Encoding::DELTA_LENGTH_BYTE_ARRAY => Err(general_err!( + "Encoding {} is not supported for type", + encoding + )), + e => Err(nyi_err!("Encoding {} is not supported", e)), + } + } + + impl GetDecoder for bool { + fn get_decoder>( + descr: ColumnDescPtr, + encoding: Encoding, + ) -> Result>> { + match encoding { + Encoding::RLE => Ok(Box::new(RleValueDecoder::new())), + _ => get_decoder_default(descr, encoding), + } + } + } + + impl GetDecoder for i32 { + fn get_decoder>( + descr: ColumnDescPtr, + encoding: Encoding, + ) -> Result>> { + match encoding { + Encoding::DELTA_BINARY_PACKED => Ok(Box::new(DeltaBitPackDecoder::new())), + _ => get_decoder_default(descr, encoding), + } + } + } + + impl GetDecoder for i64 { + fn get_decoder>( + descr: ColumnDescPtr, + encoding: Encoding, + ) -> Result>> { + match encoding { + Encoding::DELTA_BINARY_PACKED => Ok(Box::new(DeltaBitPackDecoder::new())), + _ => get_decoder_default(descr, encoding), + } + } + } + + impl GetDecoder for f32 {} + impl GetDecoder for f64 {} + + impl GetDecoder for ByteArray { + fn get_decoder>( + descr: ColumnDescPtr, + encoding: Encoding, + ) -> Result>> { + match encoding { + Encoding::DELTA_BYTE_ARRAY => Ok(Box::new(DeltaByteArrayDecoder::new())), + Encoding::DELTA_LENGTH_BYTE_ARRAY => { + Ok(Box::new(DeltaLengthByteArrayDecoder::new())) + } + _ => get_decoder_default(descr, encoding), + } + } + } + + impl GetDecoder for FixedLenByteArray { + fn get_decoder>( + descr: ColumnDescPtr, + encoding: Encoding, + ) -> Result>> { + match encoding { + Encoding::DELTA_BYTE_ARRAY => Ok(Box::new(DeltaByteArrayDecoder::new())), + _ => get_decoder_default(descr, encoding), + } + } + } + + impl GetDecoder for Int96 {} +} + // ---------------------------------------------------------------------- // Decoders /// A Parquet decoder for the data type `T`. -pub trait Decoder { +pub trait Decoder: Send { /// Sets the data to decode to be `data`, which should contain `num_values` of values /// to decode. 
fn set_data(&mut self, data: ByteBufferPtr, num_values: usize) -> Result<()>; @@ -109,20 +216,8 @@ pub fn get_decoder( descr: ColumnDescPtr, encoding: Encoding, ) -> Result>> { - let decoder: Box> = match encoding { - Encoding::PLAIN => Box::new(PlainDecoder::new(descr.type_length())), - Encoding::RLE_DICTIONARY | Encoding::PLAIN_DICTIONARY => { - return Err(general_err!( - "Cannot initialize this encoding through this function" - )); - } - Encoding::RLE => Box::new(RleValueDecoder::new()), - Encoding::DELTA_BINARY_PACKED => Box::new(DeltaBitPackDecoder::new()), - Encoding::DELTA_LENGTH_BYTE_ARRAY => Box::new(DeltaLengthByteArrayDecoder::new()), - Encoding::DELTA_BYTE_ARRAY => Box::new(DeltaByteArrayDecoder::new()), - e => return Err(nyi_err!("Encoding {} is not supported", e)), - }; - Ok(decoder) + use self::private::GetDecoder; + T::T::get_decoder(descr, encoding) } // ---------------------------------------------------------------------- @@ -338,189 +433,258 @@ pub struct DeltaBitPackDecoder { initialized: bool, // Header info - num_values: usize, - num_mini_blocks: i64, + + /// The number of values in each block + block_size: usize, + /// The number of values that remain to be read in the current page + values_left: usize, + /// The number of mini-blocks in each block + mini_blocks_per_block: usize, + /// The number of values in each mini block values_per_mini_block: usize, - values_current_mini_block: usize, - first_value: i64, - first_value_read: bool, // Per block info - min_delta: i64, - mini_block_idx: usize, - delta_bit_width: u8, - delta_bit_widths: ByteBuffer, - deltas_in_mini_block: Vec, // eagerly loaded deltas for a mini block - use_batch: bool, - - current_value: i64, - _phantom: PhantomData, + /// The minimum delta in the block + min_delta: T::T, + /// The byte offset of the end of the current block + block_end_offset: usize, + /// The index on the current mini block + mini_block_idx: usize, + /// The bit widths of each mini block in the current block + mini_block_bit_widths: Vec, + /// The number of values remaining in the current mini block + mini_block_remaining: usize, + + /// The first value from the block header if not consumed + first_value: Option, + /// The last value to compute offsets from + last_value: T::T, } -impl DeltaBitPackDecoder { +impl DeltaBitPackDecoder +where + T::T: Default + FromPrimitive + WrappingAdd + Copy, +{ /// Creates new delta bit packed decoder. pub fn new() -> Self { Self { bit_reader: BitReader::from(vec![]), initialized: false, - num_values: 0, - num_mini_blocks: 0, + block_size: 0, + values_left: 0, + mini_blocks_per_block: 0, values_per_mini_block: 0, - values_current_mini_block: 0, - first_value: 0, - first_value_read: false, - min_delta: 0, + min_delta: Default::default(), mini_block_idx: 0, - delta_bit_width: 0, - delta_bit_widths: ByteBuffer::new(), - deltas_in_mini_block: vec![], - use_batch: mem::size_of::() == 4, - current_value: 0, - _phantom: PhantomData, + mini_block_bit_widths: vec![], + mini_block_remaining: 0, + block_end_offset: 0, + first_value: None, + last_value: Default::default(), } } - /// Returns underlying bit reader offset. 
+ /// Returns the current offset pub fn get_offset(&self) -> usize { assert!(self.initialized, "Bit reader is not initialized"); - self.bit_reader.get_byte_offset() + match self.values_left { + // If we've exhausted this page report the end of the current block + // as we may not have consumed the trailing padding + // + // The max is necessary to handle pages which don't contain more than + // one value and therefore have no blocks, but still contain a page header + 0 => self.bit_reader.get_byte_offset().max(self.block_end_offset), + _ => self.bit_reader.get_byte_offset(), + } } - /// Initializes new mini block. + /// Initializes the next block and the first mini block within it #[inline] - fn init_block(&mut self) -> Result<()> { - self.min_delta = self + fn next_block(&mut self) -> Result<()> { + let min_delta = self .bit_reader .get_zigzag_vlq_int() .ok_or_else(|| eof_err!("Not enough data to decode 'min_delta'"))?; - self.delta_bit_widths.clear(); - for _ in 0..self.num_mini_blocks { - let w = self - .bit_reader - .get_aligned::(1) - .ok_or_else(|| eof_err!("Not enough data to decode 'width'"))?; - self.delta_bit_widths.push(w); + self.min_delta = T::T::from_i64(min_delta) + .ok_or_else(|| general_err!("'min_delta' too large"))?; + + self.mini_block_bit_widths.clear(); + self.bit_reader.get_aligned_bytes( + &mut self.mini_block_bit_widths, + self.mini_blocks_per_block as usize, + ); + + let mut offset = self.bit_reader.get_byte_offset(); + let mut remaining = self.values_left; + + // Compute the end offset of the current block + for b in &mut self.mini_block_bit_widths { + if remaining == 0 { + // Specification requires handling arbitrary bit widths + // for trailing mini blocks + *b = 0; + } + remaining = remaining.saturating_sub(self.values_per_mini_block); + offset += *b as usize * self.values_per_mini_block / 8; + } + self.block_end_offset = offset; + + if self.mini_block_bit_widths.len() != self.mini_blocks_per_block { + return Err(eof_err!("insufficient mini block bit widths")); } + self.mini_block_remaining = self.values_per_mini_block; self.mini_block_idx = 0; - self.delta_bit_width = self.delta_bit_widths.data()[0]; - self.values_current_mini_block = self.values_per_mini_block; + Ok(()) } - /// Loads delta into mini block. 
+ /// Initializes the next mini block #[inline] - fn load_deltas_in_mini_block(&mut self) -> Result<()> - where - T::T: FromBytes, - { - if self.use_batch { - self.deltas_in_mini_block - .resize(self.values_current_mini_block, T::T::default()); - let loaded = self.bit_reader.get_batch::( - &mut self.deltas_in_mini_block[..], - self.delta_bit_width as usize, - ); - assert!(loaded == self.values_current_mini_block); + fn next_mini_block(&mut self) -> Result<()> { + if self.mini_block_idx + 1 < self.mini_block_bit_widths.len() { + self.mini_block_idx += 1; + self.mini_block_remaining = self.values_per_mini_block; + Ok(()) } else { - self.deltas_in_mini_block.clear(); - for _ in 0..self.values_current_mini_block { - // TODO: load one batch at a time similar to int32 - let delta = self - .bit_reader - .get_value::(self.delta_bit_width as usize) - .ok_or_else(|| eof_err!("Not enough data to decode 'delta'"))?; - self.deltas_in_mini_block.push(delta); - } + self.next_block() } - - Ok(()) } } -impl Decoder for DeltaBitPackDecoder { +impl Decoder for DeltaBitPackDecoder +where + T::T: Default + FromPrimitive + WrappingAdd + Copy, +{ // # of total values is derived from encoding #[inline] fn set_data(&mut self, data: ByteBufferPtr, _index: usize) -> Result<()> { self.bit_reader = BitReader::new(data); self.initialized = true; - let block_size = self + // Read header information + self.block_size = self .bit_reader .get_vlq_int() - .ok_or_else(|| eof_err!("Not enough data to decode 'block_size'"))?; - self.num_mini_blocks = self + .ok_or_else(|| eof_err!("Not enough data to decode 'block_size'"))? + .try_into() + .map_err(|_| general_err!("invalid 'block_size'"))?; + + self.mini_blocks_per_block = self .bit_reader .get_vlq_int() - .ok_or_else(|| eof_err!("Not enough data to decode 'num_mini_blocks'"))?; - self.num_values = self + .ok_or_else(|| eof_err!("Not enough data to decode 'mini_blocks_per_block'"))? + .try_into() + .map_err(|_| general_err!("invalid 'mini_blocks_per_block'"))?; + + self.values_left = self .bit_reader .get_vlq_int() - .ok_or_else(|| eof_err!("Not enough data to decode 'num_values'"))? - as usize; - self.first_value = self + .ok_or_else(|| eof_err!("Not enough data to decode 'values_left'"))? 
+ .try_into() + .map_err(|_| general_err!("invalid 'values_left'"))?; + + let first_value = self .bit_reader .get_zigzag_vlq_int() .ok_or_else(|| eof_err!("Not enough data to decode 'first_value'"))?; + self.first_value = Some( + T::T::from_i64(first_value) + .ok_or_else(|| general_err!("first value too large"))?, + ); + + if self.block_size % 128 != 0 { + return Err(general_err!( + "'block_size' must be a multiple of 128, got {}", + self.block_size + )); + } + + if self.block_size % self.mini_blocks_per_block != 0 { + return Err(general_err!( + "'block_size' must be a multiple of 'mini_blocks_per_block' got {} and {}", + self.block_size, self.mini_blocks_per_block + )); + } + // Reset decoding state - self.first_value_read = false; self.mini_block_idx = 0; - self.delta_bit_widths.clear(); - self.values_current_mini_block = 0; + self.values_per_mini_block = self.block_size / self.mini_blocks_per_block; + self.mini_block_remaining = 0; + self.mini_block_bit_widths.clear(); - self.values_per_mini_block = (block_size / self.num_mini_blocks) as usize; - assert!(self.values_per_mini_block % 8 == 0); + if self.values_per_mini_block % 32 != 0 { + return Err(general_err!( + "'values_per_mini_block' must be a multiple of 32 got {}", + self.values_per_mini_block + )); + } Ok(()) } fn get(&mut self, buffer: &mut [T::T]) -> Result { assert!(self.initialized, "Bit reader is not initialized"); + if buffer.is_empty() { + return Ok(0); + } - let num_values = cmp::min(buffer.len(), self.num_values); - for i in 0..num_values { - if !self.first_value_read { - self.set_decoded_value(buffer, i, self.first_value); - self.current_value = self.first_value; - self.first_value_read = true; - continue; + let mut read = 0; + let to_read = buffer.len().min(self.values_left); + + if let Some(value) = self.first_value.take() { + self.last_value = value; + buffer[0] = value; + read += 1; + self.values_left -= 1; + } + + while read != to_read { + if self.mini_block_remaining == 0 { + self.next_mini_block()?; } - if self.values_current_mini_block == 0 { - self.mini_block_idx += 1; - if self.mini_block_idx < self.delta_bit_widths.size() { - self.delta_bit_width = - self.delta_bit_widths.data()[self.mini_block_idx]; - self.values_current_mini_block = self.values_per_mini_block; - } else { - self.init_block()?; - } - self.load_deltas_in_mini_block()?; + let bit_width = self.mini_block_bit_widths[self.mini_block_idx] as usize; + let batch_to_read = self.mini_block_remaining.min(to_read - read); + + let batch_read = self + .bit_reader + .get_batch(&mut buffer[read..read + batch_to_read], bit_width); + + if batch_read != batch_to_read { + return Err(general_err!( + "Expected to read {} values from miniblock got {}", + batch_to_read, + batch_read + )); + } + + // At this point we have read the deltas to `buffer` we now need to offset + // these to get back to the original values that were encoded + for v in &mut buffer[read..read + batch_read] { + // It is OK for deltas to contain "overflowed" values after encoding, + // e.g. i64::MAX - i64::MIN, so we use `wrapping_add` to "overflow" again and + // restore original value. + *v = v + .wrapping_add(&self.min_delta) + .wrapping_add(&self.last_value); + + self.last_value = *v; } - // we decrement values in current mini block, so we need to invert index for - // delta - let delta = self.get_delta( - self.deltas_in_mini_block.len() - self.values_current_mini_block, - ); - // It is OK for deltas to contain "overflowed" values after encoding, - // e.g. 
i64::MAX - i64::MIN, so we use `wrapping_add` to "overflow" again and - // restore original value. - self.current_value = self.current_value.wrapping_add(self.min_delta); - self.current_value = self.current_value.wrapping_add(delta as i64); - self.set_decoded_value(buffer, i, self.current_value); - self.values_current_mini_block -= 1; + read += batch_read; + self.mini_block_remaining -= batch_read; + self.values_left -= batch_read; } - self.num_values -= num_values; - Ok(num_values) + Ok(to_read) } fn values_left(&self) -> usize { - self.num_values + self.values_left } fn encoding(&self) -> Encoding { @@ -528,42 +692,6 @@ impl Decoder for DeltaBitPackDecoder { } } -/// Helper trait to define specific conversions when decoding values -trait DeltaBitPackDecoderConversion { - /// Sets decoded value based on type `T`. - fn get_delta(&self, index: usize) -> i64; - - fn set_decoded_value(&self, buffer: &mut [T::T], index: usize, value: i64); -} - -impl DeltaBitPackDecoderConversion for DeltaBitPackDecoder { - #[inline] - fn get_delta(&self, index: usize) -> i64 { - ensure_phys_ty!( - Type::INT32 | Type::INT64, - "DeltaBitPackDecoder only supports Int32Type and Int64Type" - ); - self.deltas_in_mini_block[index].as_i64().unwrap() - } - - #[inline] - fn set_decoded_value(&self, buffer: &mut [T::T], index: usize, value: i64) { - match T::get_physical_type() { - Type::INT32 => { - let val = buffer[index].as_mut_any().downcast_mut::().unwrap(); - - *val = value as i32; - } - Type::INT64 => { - let val = buffer[index].as_mut_any().downcast_mut::().unwrap(); - - *val = value; - } - _ => panic!("DeltaBitPackDecoder only supports Int32Type and Int64Type"), - }; - } -} - // ---------------------------------------------------------------------- // DELTA_LENGTH_BYTE_ARRAY Decoding @@ -636,11 +764,11 @@ impl Decoder for DeltaLengthByteArrayDecoder { let data = self.data.as_ref().unwrap(); let num_values = cmp::min(buffer.len(), self.num_values); - for i in 0..num_values { + + for item in buffer.iter_mut().take(num_values) { let len = self.lengths[self.current_idx] as usize; - buffer[i] - .as_mut_any() + item.as_mut_any() .downcast_mut::() .unwrap() .set_data(data.range(self.offset, len)); @@ -743,7 +871,7 @@ impl<'m, T: DataType> Decoder for DeltaByteArrayDecoder { ty @ Type::BYTE_ARRAY | ty @ Type::FIXED_LEN_BYTE_ARRAY => { let num_values = cmp::min(buffer.len(), self.num_values); let mut v: [ByteArray; 1] = [ByteArray::new(); 1]; - for i in 0..num_values { + for item in buffer.iter_mut().take(num_values) { // Process suffix // TODO: this is awkward - maybe we should add a non-vectorized API? 
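The `wrapping_add` comment above captures the core of the DELTA_BINARY_PACKED decode path: each delta unpacked from a mini block is offset by the block's min_delta and the previously produced value, and wrapping arithmetic lets deltas that overflowed during encoding round-trip. A tiny standalone sketch with made-up numbers:

fn main() {
    let first_value: i64 = 7; // from the page header
    let min_delta: i64 = -2; // from the block header
    let deltas = [0i64, 3, 1, 5]; // unpacked from a mini block

    let mut decoded = vec![first_value];
    let mut last = first_value;
    for d in deltas {
        // value[i] = value[i - 1] + min_delta + delta[i], with wrapping adds
        last = last.wrapping_add(min_delta).wrapping_add(d);
        decoded.push(last);
    }
    assert_eq!(decoded, [7, 5, 6, 5, 8]);
}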
let suffix_decoder = self.suffix_decoder.as_mut().expect("decoder not initialized"); @@ -761,12 +889,12 @@ impl<'m, T: DataType> Decoder for DeltaByteArrayDecoder { let data = ByteBufferPtr::new(result.clone()); match ty { - Type::BYTE_ARRAY => buffer[i] + Type::BYTE_ARRAY => item .as_mut_any() .downcast_mut::() .unwrap() .set_data(data), - Type::FIXED_LEN_BYTE_ARRAY => buffer[i] + Type::FIXED_LEN_BYTE_ARRAY => item .as_mut_any() .downcast_mut::() .unwrap() @@ -808,17 +936,18 @@ mod tests { use crate::schema::types::{ ColumnDescPtr, ColumnDescriptor, ColumnPath, Type as SchemaType, }; - use crate::util::{ - bit_util::set_array_bit, memory::MemTracker, test_common::RandGen, - }; + use crate::util::{bit_util::set_array_bit, test_common::RandGen}; #[test] fn test_get_decoders() { // supported encodings create_and_check_decoder::(Encoding::PLAIN, None); create_and_check_decoder::(Encoding::DELTA_BINARY_PACKED, None); - create_and_check_decoder::(Encoding::DELTA_LENGTH_BYTE_ARRAY, None); - create_and_check_decoder::(Encoding::DELTA_BYTE_ARRAY, None); + create_and_check_decoder::( + Encoding::DELTA_LENGTH_BYTE_ARRAY, + None, + ); + create_and_check_decoder::(Encoding::DELTA_BYTE_ARRAY, None); create_and_check_decoder::(Encoding::RLE, None); // error when initializing @@ -834,6 +963,18 @@ mod tests { "Cannot initialize this encoding through this function" )), ); + create_and_check_decoder::( + Encoding::DELTA_LENGTH_BYTE_ARRAY, + Some(general_err!( + "Encoding DELTA_LENGTH_BYTE_ARRAY is not supported for type" + )), + ); + create_and_check_decoder::( + Encoding::DELTA_BYTE_ARRAY, + Some(general_err!( + "Encoding DELTA_BYTE_ARRAY is not supported for type" + )), + ); // unsupported create_and_check_decoder::( @@ -1192,6 +1333,83 @@ mod tests { assert_eq!(result, vec![29, 43, 89]); } + #[test] + fn test_delta_bit_packed_padding() { + // Page header + let header = vec![ + // Page Header + + // Block Size - 256 + 128, + 2, + // Miniblocks in block, + 4, + // Total value count - 419 + 128 + 35, + 3, + // First value - 7 + 7, + ]; + + // Block Header + let block1_header = vec![ + 0, // Min delta + 0, 1, 0, 0, // Bit widths + ]; + + // Mini-block 1 - bit width 0 => 0 bytes + // Mini-block 2 - bit width 1 => 8 bytes + // Mini-block 3 - bit width 0 => 0 bytes + // Mini-block 4 - bit width 0 => 0 bytes + let block1 = vec![0xFF; 8]; + + // Block Header + let block2_header = vec![ + 0, // Min delta + 0, 1, 2, 0xFF, // Bit widths, including non-zero padding + ]; + + // Mini-block 1 - bit width 0 => 0 bytes + // Mini-block 2 - bit width 1 => 8 bytes + // Mini-block 3 - bit width 2 => 16 bytes + // Mini-block 4 - padding => no bytes + let block2 = vec![0xFF; 24]; + + let data: Vec = header + .into_iter() + .chain(block1_header) + .chain(block1) + .chain(block2_header) + .chain(block2) + .collect(); + + let length = data.len(); + + let ptr = ByteBufferPtr::new(data); + let mut reader = BitReader::new(ptr.clone()); + assert_eq!(reader.get_vlq_int().unwrap(), 256); + assert_eq!(reader.get_vlq_int().unwrap(), 4); + assert_eq!(reader.get_vlq_int().unwrap(), 419); + assert_eq!(reader.get_vlq_int().unwrap(), 7); + + // Test output buffer larger than needed and not exact multiple of block size + let mut output = vec![0_i32; 420]; + + let mut decoder = DeltaBitPackDecoder::::new(); + decoder.set_data(ptr.clone(), 0).unwrap(); + assert_eq!(decoder.get(&mut output).unwrap(), 419); + assert_eq!(decoder.get_offset(), length); + + // Test with truncated buffer + decoder.set_data(ptr.range(0, 12), 0).unwrap(); + let err = 
decoder.get(&mut output).unwrap_err().to_string(); + assert!( + err.contains("Expected to read 64 values from miniblock got 8"), + "{}", + err + ); + } + #[test] fn test_delta_byte_array_same_arrays() { let data = vec![ @@ -1250,8 +1468,7 @@ mod tests { // Encode data let mut encoder = - get_encoder::(col_descr.clone(), encoding, Arc::new(MemTracker::new())) - .expect("get encoder"); + get_encoder::(col_descr.clone(), encoding).expect("get encoder"); for v in &data[..] { encoder.put(&v[..]).expect("ok to encode"); @@ -1340,11 +1557,11 @@ mod tests { #[allow(clippy::wrong_self_convention)] fn to_byte_array(data: &[bool]) -> Vec { let mut v = vec![]; - for i in 0..data.len() { + for (i, item) in data.iter().enumerate() { if i % 8 == 0 { v.push(0); } - if data[i] { + if *item { set_array_bit(&mut v[..], i); } } diff --git a/parquet/src/encodings/encoding.rs b/parquet/src/encodings/encoding.rs index f452618d5dc2..4b9cf2e9bad5 100644 --- a/parquet/src/encodings/encoding.rs +++ b/parquet/src/encodings/encoding.rs @@ -28,7 +28,7 @@ use crate::schema::types::ColumnDescPtr; use crate::util::{ bit_util::{self, log2, num_required_bits, BitWriter}, hash_util, - memory::{Buffer, ByteBuffer, ByteBufferPtr, MemTrackerPtr}, + memory::ByteBufferPtr, }; // ---------------------------------------------------------------------- @@ -50,9 +50,9 @@ pub trait Encoder { let num_values = values.len(); let mut buffer = Vec::with_capacity(num_values); // TODO: this is pretty inefficient. Revisit in future. - for i in 0..num_values { + for (i, item) in values.iter().enumerate().take(num_values) { if bit_util::get_bit(valid_bits, i) { - buffer.push(values[i].clone()); + buffer.push(item.clone()); } } self.put(&buffer[..])?; @@ -76,10 +76,9 @@ pub trait Encoder { pub fn get_encoder( desc: ColumnDescPtr, encoding: Encoding, - mem_tracker: MemTrackerPtr, ) -> Result>> { let encoder: Box> = match encoding { - Encoding::PLAIN => Box::new(PlainEncoder::new(desc, mem_tracker, vec![])), + Encoding::PLAIN => Box::new(PlainEncoder::new(desc, vec![])), Encoding::RLE_DICTIONARY | Encoding::PLAIN_DICTIONARY => { return Err(general_err!( "Cannot initialize this encoding through this function" @@ -109,7 +108,7 @@ pub fn get_encoder( /// - BYTE_ARRAY - 4 byte length stored as little endian, followed by bytes. /// - FIXED_LEN_BYTE_ARRAY - just the bytes are stored. pub struct PlainEncoder { - buffer: ByteBuffer, + buffer: Vec, bit_writer: BitWriter, desc: ColumnDescPtr, _phantom: PhantomData, @@ -117,11 +116,9 @@ pub struct PlainEncoder { impl PlainEncoder { /// Creates new plain encoder. 
- pub fn new(desc: ColumnDescPtr, mem_tracker: MemTrackerPtr, vec: Vec) -> Self { - let mut byte_buffer = ByteBuffer::new().with_mem_tracker(mem_tracker); - byte_buffer.set_data(vec); + pub fn new(desc: ColumnDescPtr, buffer: Vec) -> Self { Self { - buffer: byte_buffer, + buffer, bit_writer: BitWriter::new(256), desc, _phantom: PhantomData, @@ -139,16 +136,15 @@ impl Encoder for PlainEncoder { } fn estimated_data_encoded_size(&self) -> usize { - self.buffer.size() + self.bit_writer.bytes_written() + self.buffer.len() + self.bit_writer.bytes_written() } #[inline] fn flush_buffer(&mut self) -> Result { - self.buffer.write_all(self.bit_writer.flush_buffer())?; - self.buffer.flush()?; + self.buffer + .extend_from_slice(self.bit_writer.flush_buffer()); self.bit_writer.clear(); - - Ok(self.buffer.consume()) + Ok(std::mem::take(&mut self.buffer).into()) } #[inline] @@ -189,35 +185,31 @@ pub struct DictEncoder { // Stores indices which map (many-to-one) to the values in the `uniques` array. // Here we are using fix-sized array with linear probing. // A slot with `HASH_SLOT_EMPTY` indicates the slot is not currently occupied. - hash_slots: Buffer, + hash_slots: Vec, // Indices that have not yet be written out by `write_indices()`. - buffered_indices: Buffer, + buffered_indices: Vec, // The unique observed values. - uniques: Buffer, + uniques: Vec, // Size in bytes needed to encode this dictionary. uniques_size_in_bytes: usize, - - // Tracking memory usage for the various data structures in this struct. - mem_tracker: MemTrackerPtr, } impl DictEncoder { /// Creates new dictionary encoder. - pub fn new(desc: ColumnDescPtr, mem_tracker: MemTrackerPtr) -> Self { - let mut slots = Buffer::new().with_mem_tracker(mem_tracker.clone()); + pub fn new(desc: ColumnDescPtr) -> Self { + let mut slots = vec![]; slots.resize(INITIAL_HASH_TABLE_SIZE, -1); Self { desc, hash_table_size: INITIAL_HASH_TABLE_SIZE, mod_bitmask: (INITIAL_HASH_TABLE_SIZE - 1) as u32, hash_slots: slots, - buffered_indices: Buffer::new().with_mem_tracker(mem_tracker.clone()), - uniques: Buffer::new().with_mem_tracker(mem_tracker.clone()), + buffered_indices: vec![], + uniques: vec![], uniques_size_in_bytes: 0, - mem_tracker, } } @@ -230,7 +222,7 @@ impl DictEncoder { /// Returns number of unique values (keys) in the dictionary. pub fn num_entries(&self) -> usize { - self.uniques.size() + self.uniques.len() } /// Returns size of unique values (keys) in the dictionary, in bytes. @@ -242,9 +234,8 @@ impl DictEncoder { /// the result. #[inline] pub fn write_dict(&self) -> Result { - let mut plain_encoder = - PlainEncoder::::new(self.desc.clone(), self.mem_tracker.clone(), vec![]); - plain_encoder.put(self.uniques.data())?; + let mut plain_encoder = PlainEncoder::::new(self.desc.clone(), vec![]); + plain_encoder.put(&self.uniques)?; plain_encoder.flush_buffer() } @@ -255,12 +246,11 @@ impl DictEncoder { let buffer_len = self.estimated_data_encoded_size(); let mut buffer: Vec = vec![0; buffer_len as usize]; buffer[0] = self.bit_width() as u8; - self.mem_tracker.alloc(buffer.capacity() as i64); // Write bit width in the first byte buffer.write_all((self.bit_width() as u8).as_bytes())?; let mut encoder = RleEncoder::new_from_buf(self.bit_width(), buffer, 1); - for index in self.buffered_indices.data() { + for index in &self.buffered_indices { if !encoder.put(*index as u64)? 
{ return Err(general_err!("Encoder doesn't have enough space")); } @@ -293,7 +283,7 @@ impl DictEncoder { #[inline(never)] fn insert_fresh_slot(&mut self, slot: usize, value: T::T) -> i32 { - let index = self.uniques.size() as i32; + let index = self.uniques.len() as i32; self.hash_slots[slot] = index; let (base_size, num_elements) = value.dict_encoding_size(); @@ -307,7 +297,7 @@ impl DictEncoder { self.uniques_size_in_bytes += unique_size; self.uniques.push(value); - if self.uniques.size() > (self.hash_table_size as f32 * MAX_HASH_LOAD) as usize { + if self.uniques.len() > (self.hash_table_size as f32 * MAX_HASH_LOAD) as usize { self.double_table_size(); } @@ -316,7 +306,7 @@ impl DictEncoder { #[inline] fn bit_width(&self) -> u8 { - let num_entries = self.uniques.size(); + let num_entries = self.uniques.len(); if num_entries == 0 { 0 } else if num_entries == 1 { @@ -328,7 +318,7 @@ impl DictEncoder { fn double_table_size(&mut self) { let new_size = self.hash_table_size * 2; - let mut new_hash_slots = Buffer::new().with_mem_tracker(self.mem_tracker.clone()); + let mut new_hash_slots = vec![]; new_hash_slots.resize(new_size, HASH_SLOT_EMPTY); for i in 0..self.hash_table_size { let index = self.hash_slots[i]; @@ -376,7 +366,7 @@ impl Encoder for DictEncoder { fn estimated_data_encoded_size(&self) -> usize { let bit_width = self.bit_width(); 1 + RleEncoder::min_buffer_size(bit_width) - + RleEncoder::max_buffer_size(bit_width, self.buffered_indices.size()) + + RleEncoder::max_buffer_size(bit_width, self.buffered_indices.len()) } #[inline] @@ -565,7 +555,7 @@ impl DeltaBitPackEncoder { return Ok(()); } - let mut min_delta = i64::max_value(); + let mut min_delta = i64::MAX; for i in 0..self.values_in_block { min_delta = cmp::min(min_delta, self.deltas[i]); } @@ -581,6 +571,13 @@ impl DeltaBitPackEncoder { // values left let n = cmp::min(self.mini_block_size, self.values_in_block); if n == 0 { + // Decoders should be agnostic to the padding value, we therefore use 0xFF + // when running tests. 
However, not all implementations may handle this correctly + // so pad with 0 when not running tests + let pad_value = cfg!(test).then(|| 0xFF).unwrap_or(0); + for j in i..self.num_mini_blocks { + self.bit_writer.write_at(offset + j, pad_value); + } break; } @@ -610,8 +607,8 @@ impl DeltaBitPackEncoder { self.values_in_block -= n; } - assert!( - self.values_in_block == 0, + assert_eq!( + self.values_in_block, 0, "Expected 0 values in block, found {}", self.values_in_block ); @@ -670,10 +667,9 @@ impl Encoder for DeltaBitPackEncoder { // Write page header with total values self.write_page_header(); - let mut buffer = ByteBuffer::new(); - buffer.write_all(self.page_header_writer.flush_buffer())?; - buffer.write_all(self.bit_writer.flush_buffer())?; - buffer.flush()?; + let mut buffer = Vec::new(); + buffer.extend_from_slice(self.page_header_writer.flush_buffer()); + buffer.extend_from_slice(self.bit_writer.flush_buffer()); // Reset state self.page_header_writer.clear(); @@ -683,7 +679,7 @@ impl Encoder for DeltaBitPackEncoder { self.current_value = 0; self.values_in_block = 0; - Ok(buffer.consume()) + Ok(buffer.into()) } } @@ -922,14 +918,11 @@ mod tests { use std::sync::Arc; - use crate::decoding::{get_decoder, Decoder, DictDecoder, PlainDecoder}; + use crate::encodings::decoding::{get_decoder, Decoder, DictDecoder, PlainDecoder}; use crate::schema::types::{ ColumnDescPtr, ColumnDescriptor, ColumnPath, Type as SchemaType, }; - use crate::util::{ - memory::MemTracker, - test_common::{random_bytes, RandGen}, - }; + use crate::util::test_common::{random_bytes, RandGen}; const TEST_SET_SIZE: usize = 1024; @@ -1128,11 +1121,13 @@ mod tests { let mut decoder = create_test_decoder::(0, Encoding::DELTA_BYTE_ARRAY); - let mut input = vec![]; - input.push(ByteArray::from("aa")); - input.push(ByteArray::from("aaa")); - input.push(ByteArray::from("aa")); - input.push(ByteArray::from("aaa")); + let input = vec![ + ByteArray::from("aa"), + ByteArray::from("aaa"), + ByteArray::from("aa"), + ByteArray::from("aaa"), + ]; + let mut output = vec![ByteArray::default(); input.len()]; let mut result = @@ -1277,8 +1272,7 @@ mod tests { err: Option, ) { let descr = create_test_col_desc_ptr(-1, T::get_physical_type()); - let mem_tracker = Arc::new(MemTracker::new()); - let encoder = get_encoder::(descr, encoding, mem_tracker); + let encoder = get_encoder::(descr, encoding); match err { Some(parquet_error) => { assert!(encoder.is_err()); @@ -1310,8 +1304,7 @@ mod tests { enc: Encoding, ) -> Box> { let desc = create_test_col_desc_ptr(type_len, T::get_physical_type()); - let mem_tracker = Arc::new(MemTracker::new()); - get_encoder(desc, enc, mem_tracker).unwrap() + get_encoder(desc, enc).unwrap() } fn create_test_decoder( @@ -1324,8 +1317,7 @@ mod tests { fn create_test_dict_encoder(type_len: i32) -> DictEncoder { let desc = create_test_col_desc_ptr(type_len, T::get_physical_type()); - let mem_tracker = Arc::new(MemTracker::new()); - DictEncoder::::new(desc, mem_tracker) + DictEncoder::::new(desc) } fn create_test_dict_decoder() -> DictDecoder { diff --git a/parquet/src/encodings/levels.rs b/parquet/src/encodings/levels.rs index b82d2959becf..c8682e06d391 100644 --- a/parquet/src/encodings/levels.rs +++ b/parquet/src/encodings/levels.rs @@ -50,11 +50,11 @@ pub fn max_buffer_size( } /// Encoder for definition/repetition levels. -/// Currently only supports RLE and BIT_PACKED (dev/null) encoding, including v2. +/// Currently only supports Rle and BitPacked (dev/null) encoding, including v2. 
pub enum LevelEncoder { - RLE(RleEncoder), - RLE_V2(RleEncoder), - BIT_PACKED(u8, BitWriter), + Rle(RleEncoder), + RleV2(RleEncoder), + BitPacked(u8, BitWriter), } impl LevelEncoder { @@ -68,7 +68,7 @@ impl LevelEncoder { pub fn v1(encoding: Encoding, max_level: i16, byte_buffer: Vec) -> Self { let bit_width = log2(max_level as u64 + 1) as u8; match encoding { - Encoding::RLE => LevelEncoder::RLE(RleEncoder::new_from_buf( + Encoding::RLE => LevelEncoder::Rle(RleEncoder::new_from_buf( bit_width, byte_buffer, mem::size_of::(), @@ -77,7 +77,7 @@ impl LevelEncoder { // Here we set full byte buffer without adjusting for num_buffered_values, // because byte buffer will already be allocated with size from // `max_buffer_size()` method. - LevelEncoder::BIT_PACKED( + LevelEncoder::BitPacked( bit_width, BitWriter::new_from_buf(byte_buffer, 0), ) @@ -90,7 +90,7 @@ impl LevelEncoder { /// repetition and definition levels. pub fn v2(max_level: i16, byte_buffer: Vec) -> Self { let bit_width = log2(max_level as u64 + 1) as u8; - LevelEncoder::RLE_V2(RleEncoder::new_from_buf(bit_width, byte_buffer, 0)) + LevelEncoder::RleV2(RleEncoder::new_from_buf(bit_width, byte_buffer, 0)) } /// Put/encode levels vector into this level encoder. @@ -103,8 +103,7 @@ impl LevelEncoder { pub fn put(&mut self, buffer: &[i16]) -> Result { let mut num_encoded = 0; match *self { - LevelEncoder::RLE(ref mut encoder) - | LevelEncoder::RLE_V2(ref mut encoder) => { + LevelEncoder::Rle(ref mut encoder) | LevelEncoder::RleV2(ref mut encoder) => { for value in buffer { if !encoder.put(*value as u64)? { return Err(general_err!("RLE buffer is full")); @@ -113,7 +112,7 @@ impl LevelEncoder { } encoder.flush()?; } - LevelEncoder::BIT_PACKED(bit_width, ref mut encoder) => { + LevelEncoder::BitPacked(bit_width, ref mut encoder) => { for value in buffer { if !encoder.put_value(*value as u64, bit_width as usize) { return Err(general_err!("Not enough bytes left")); @@ -131,7 +130,7 @@ impl LevelEncoder { #[inline] pub fn consume(self) -> Result> { match self { - LevelEncoder::RLE(encoder) => { + LevelEncoder::Rle(encoder) => { let mut encoded_data = encoder.consume()?; // Account for the buffer offset let encoded_len = encoded_data.len() - mem::size_of::(); @@ -140,8 +139,8 @@ impl LevelEncoder { encoded_data[0..len_bytes.len()].copy_from_slice(len_bytes); Ok(encoded_data) } - LevelEncoder::RLE_V2(encoder) => encoder.consume(), - LevelEncoder::BIT_PACKED(_, encoder) => Ok(encoder.consume()), + LevelEncoder::RleV2(encoder) => encoder.consume(), + LevelEncoder::BitPacked(_, encoder) => Ok(encoder.consume()), } } } @@ -150,9 +149,9 @@ impl LevelEncoder { /// Currently only supports RLE and BIT_PACKED encoding for Data Page v1 and /// RLE for Data Page v2. 
pub enum LevelDecoder { - RLE(Option, RleDecoder), - RLE_V2(Option, RleDecoder), - BIT_PACKED(Option, u8, BitReader), + Rle(Option, RleDecoder), + RleV2(Option, RleDecoder), + BitPacked(Option, u8, BitReader), } impl LevelDecoder { @@ -166,9 +165,9 @@ impl LevelDecoder { pub fn v1(encoding: Encoding, max_level: i16) -> Self { let bit_width = log2(max_level as u64 + 1) as u8; match encoding { - Encoding::RLE => LevelDecoder::RLE(None, RleDecoder::new(bit_width)), + Encoding::RLE => LevelDecoder::Rle(None, RleDecoder::new(bit_width)), Encoding::BIT_PACKED => { - LevelDecoder::BIT_PACKED(None, bit_width, BitReader::from(Vec::new())) + LevelDecoder::BitPacked(None, bit_width, BitReader::from(Vec::new())) } _ => panic!("Unsupported encoding type {}", encoding), } @@ -180,7 +179,7 @@ impl LevelDecoder { /// To set data for this decoder, use `set_data_range` method. pub fn v2(max_level: i16) -> Self { let bit_width = log2(max_level as u64 + 1) as u8; - LevelDecoder::RLE_V2(None, RleDecoder::new(bit_width)) + LevelDecoder::RleV2(None, RleDecoder::new(bit_width)) } /// Sets data for this level decoder, and returns total number of bytes set. @@ -194,21 +193,21 @@ impl LevelDecoder { #[inline] pub fn set_data(&mut self, num_buffered_values: usize, data: ByteBufferPtr) -> usize { match *self { - LevelDecoder::RLE(ref mut num_values, ref mut decoder) => { + LevelDecoder::Rle(ref mut num_values, ref mut decoder) => { *num_values = Some(num_buffered_values); let i32_size = mem::size_of::(); let data_size = read_num_bytes!(i32, i32_size, data.as_ref()) as usize; decoder.set_data(data.range(i32_size, data_size)); i32_size + data_size } - LevelDecoder::BIT_PACKED(ref mut num_values, bit_width, ref mut decoder) => { + LevelDecoder::BitPacked(ref mut num_values, bit_width, ref mut decoder) => { *num_values = Some(num_buffered_values); // Set appropriate number of bytes: if max size is larger than buffer - // set full buffer let num_bytes = ceil((num_buffered_values * bit_width as usize) as i64, 8); let data_size = cmp::min(num_bytes as usize, data.len()); - decoder.reset(data.range(data.start(), data_size)); + decoder.reset(data.range(0, data_size)); data_size } _ => panic!(), @@ -227,7 +226,7 @@ impl LevelDecoder { len: usize, ) -> usize { match *self { - LevelDecoder::RLE_V2(ref mut num_values, ref mut decoder) => { + LevelDecoder::RleV2(ref mut num_values, ref mut decoder) => { decoder.set_data(data.range(start, len)); *num_values = Some(num_buffered_values); len @@ -242,9 +241,9 @@ impl LevelDecoder { #[inline] pub fn is_data_set(&self) -> bool { match self { - LevelDecoder::RLE(ref num_values, _) => num_values.is_some(), - LevelDecoder::RLE_V2(ref num_values, _) => num_values.is_some(), - LevelDecoder::BIT_PACKED(ref num_values, ..) => num_values.is_some(), + LevelDecoder::Rle(ref num_values, _) => num_values.is_some(), + LevelDecoder::RleV2(ref num_values, _) => num_values.is_some(), + LevelDecoder::BitPacked(ref num_values, ..) 
=> num_values.is_some(), } } @@ -255,15 +254,15 @@ impl LevelDecoder { pub fn get(&mut self, buffer: &mut [i16]) -> Result { assert!(self.is_data_set(), "No data set for decoding"); match *self { - LevelDecoder::RLE(ref mut num_values, ref mut decoder) - | LevelDecoder::RLE_V2(ref mut num_values, ref mut decoder) => { + LevelDecoder::Rle(ref mut num_values, ref mut decoder) + | LevelDecoder::RleV2(ref mut num_values, ref mut decoder) => { // Max length we can read let len = cmp::min(num_values.unwrap(), buffer.len()); let values_read = decoder.get_batch::(&mut buffer[0..len])?; *num_values = num_values.map(|len| len - values_read); Ok(values_read) } - LevelDecoder::BIT_PACKED(ref mut num_values, bit_width, ref mut decoder) => { + LevelDecoder::BitPacked(ref mut num_values, bit_width, ref mut decoder) => { // When extracting values from bit reader, it might return more values // than left because of padding to a full byte, we use // num_values to track precise number of values. diff --git a/parquet/src/encodings/mod.rs b/parquet/src/encodings/mod.rs index 6046ddaec805..9577a8e624f6 100644 --- a/parquet/src/encodings/mod.rs +++ b/parquet/src/encodings/mod.rs @@ -18,4 +18,4 @@ pub mod decoding; pub mod encoding; pub mod levels; -pub(crate) mod rle; +experimental_mod_crate!(rle); diff --git a/parquet/src/encodings/rle.rs b/parquet/src/encodings/rle.rs index b2a23da7c0b8..64a112101f3c 100644 --- a/parquet/src/encodings/rle.rs +++ b/parquet/src/encodings/rle.rs @@ -310,6 +310,9 @@ impl RleEncoder { } } +/// Size, in number of `i32s` of buffer to use for RLE batch reading +const RLE_DECODER_INDEX_BUFFER_SIZE: usize = 1024; + /// A RLE/Bit-Packing hybrid decoder. pub struct RleDecoder { // Number of bits used to encode the value. Must be between [0, 64]. @@ -319,7 +322,7 @@ pub struct RleDecoder { bit_reader: Option, // Buffer used when `bit_reader` is not `None`, for batch reading. 
- index_buf: [i32; 1024], + index_buf: Option>, // The remaining number of values in RLE for this run rle_left: u32, @@ -338,7 +341,7 @@ impl RleDecoder { rle_left: 0, bit_packed_left: 0, bit_reader: None, - index_buf: [0; 1024], + index_buf: None, current_value: None, } } @@ -416,6 +419,11 @@ impl RleDecoder { &mut buffer[values_read..values_read + num_values], self.bit_width as usize, ); + if num_values == 0 { + // Handle writers which truncate the final block + self.bit_packed_left = 0; + continue; + } self.bit_packed_left -= num_values as u32; values_read += num_values; } else if !self.reload() { @@ -440,6 +448,8 @@ impl RleDecoder { let mut values_read = 0; while values_read < max_values { + let index_buf = self.index_buf.get_or_insert_with(|| Box::new([0; 1024])); + if self.rle_left > 0 { let num_values = cmp::min(max_values - values_read, self.rle_left as usize); @@ -456,19 +466,23 @@ impl RleDecoder { let mut num_values = cmp::min(max_values - values_read, self.bit_packed_left as usize); - num_values = cmp::min(num_values, self.index_buf.len()); + num_values = cmp::min(num_values, index_buf.len()); loop { num_values = bit_reader.get_batch::( - &mut self.index_buf[..num_values], + &mut index_buf[..num_values], self.bit_width as usize, ); + if num_values == 0 { + // Handle writers which truncate the final block + self.bit_packed_left = 0; + break; + } for i in 0..num_values { - buffer[values_read + i] - .clone_from(&dict[self.index_buf[i] as usize]) + buffer[values_read + i].clone_from(&dict[index_buf[i] as usize]) } self.bit_packed_left -= num_values as u32; values_read += num_values; - if num_values < self.index_buf.len() { + if num_values < index_buf.len() { break; } } @@ -660,13 +674,9 @@ mod tests { #[test] fn test_rle_specific_sequences() { let mut expected_buffer = Vec::new(); - let mut values = Vec::new(); - for _ in 0..50 { - values.push(0); - } - for _ in 0..50 { - values.push(1); - } + let mut values = vec![0; 50]; + values.resize(100, 1); + expected_buffer.push(50 << 1); expected_buffer.push(0); expected_buffer.push(50 << 1); @@ -692,9 +702,8 @@ mod tests { } let num_groups = bit_util::ceil(100, 8) as u8; expected_buffer.push(((num_groups << 1) as u8) | 1); - for _ in 1..(100 / 8) + 1 { - expected_buffer.push(0b10101010); - } + expected_buffer.resize(expected_buffer.len() + 100 / 8, 0b10101010); + // For the last 4 0 and 1's, padded with 0. 
expected_buffer.push(0b00001010); validate_rle( @@ -744,6 +753,42 @@ mod tests { } } + #[test] + fn test_truncated_rle() { + // The final bit packed run within a page may not be a multiple of 8 values + // Unfortunately the specification stores `(bit-packed-run-len) / 8` + // This means we don't necessarily know how many values are present + // and some writers may not add padding to compensate for this ambiguity + + // Bit pack encode 20 values with a bit width of 8 + let mut data: Vec = vec![ + (3 << 1) | 1, // bit-packed run of 3 * 8 + ]; + data.extend(std::iter::repeat(0xFF).take(20)); + let data = ByteBufferPtr::new(data); + + let mut decoder = RleDecoder::new(8); + decoder.set_data(data.clone()); + + let mut output = vec![0_u16; 100]; + let read = decoder.get_batch(&mut output).unwrap(); + + assert_eq!(read, 20); + assert!(output.iter().take(20).all(|x| *x == 255)); + + // Reset decoder + decoder.set_data(data); + + let dict: Vec = (0..256).collect(); + let mut output = vec![0_u16; 100]; + let read = decoder + .get_batch_with_dict(&dict, &mut output, 100) + .unwrap(); + + assert_eq!(read, 20); + assert!(output.iter().take(20).all(|x| *x == 255)); + } + #[test] fn test_rle_specific_roundtrip() { let bit_width = 1; diff --git a/parquet/src/errors.rs b/parquet/src/errors.rs index be1a22192954..c2fb5bd66cf9 100644 --- a/parquet/src/errors.rs +++ b/parquet/src/errors.rs @@ -17,7 +17,7 @@ //! Common Parquet errors and macros. -use std::{cell, convert, io, result, str}; +use std::{cell, io, result, str}; #[cfg(any(feature = "arrow", test))] use arrow::error::ArrowError; @@ -108,7 +108,7 @@ pub type Result = result::Result; // ---------------------------------------------------------------------- // Conversion from `ParquetError` to other types of `Error`s -impl convert::From for io::Error { +impl From for io::Error { fn from(e: ParquetError) -> Self { io::Error::new(io::ErrorKind::Other, e) } @@ -135,6 +135,15 @@ macro_rules! eof_err { ($fmt:expr, $($args:expr),*) => (ParquetError::EOF(format!($fmt, $($args),*))); } +#[cfg(any(feature = "arrow", test))] +macro_rules! arrow_err { + ($fmt:expr) => (ParquetError::ArrowError($fmt.to_owned())); + ($fmt:expr, $($args:expr),*) => (ParquetError::ArrowError(format!($fmt, $($args),*))); + ($e:expr, $fmt:expr) => (ParquetError::ArrowError($fmt.to_owned(), $e)); + ($e:ident, $fmt:expr, $($args:tt),*) => ( + ParquetError::ArrowError(&format!($fmt, $($args),*), $e)); +} + // ---------------------------------------------------------------------- // Convert parquet error into other errors diff --git a/parquet/src/file/footer.rs b/parquet/src/file/footer.rs index 2e572944868b..dc1d66d0fa44 100644 --- a/parquet/src/file/footer.rs +++ b/parquet/src/file/footer.rs @@ -15,11 +15,7 @@ // specific language governing permissions and limitations // under the License. 
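The footer.rs changes that follow replace the cursor juggling in parse_metadata with a direct read of the 8-byte footer plus a separate metadata-decoding helper. For orientation, here is a minimal standalone sketch (not part of the diff) of the footer layout those functions rely on: a 4-byte little-endian metadata length followed by the 4-byte magic "PAR1".

use byteorder::{ByteOrder, LittleEndian};

// Sketch only: given the last 8 bytes of a Parquet file, recover the length of
// the Thrift-encoded metadata that immediately precedes the footer.
fn metadata_len_from_footer(footer: &[u8; 8]) -> Result<usize, String> {
    // the footer must end with the 4-byte magic "PAR1"
    if footer[4..] != *b"PAR1" {
        return Err("Invalid Parquet file. Corrupt footer".to_string());
    }
    // the 4 bytes before the magic hold the metadata length, little-endian
    let metadata_len = LittleEndian::read_i32(&footer[..4]);
    usize::try_from(metadata_len)
        .map_err(|_| format!("Invalid metadata length {}", metadata_len))
}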
-use std::{ - cmp::min, - io::{Cursor, Read, Seek, SeekFrom}, - sync::Arc, -}; +use std::{io::Read, sync::Arc}; use byteorder::{ByteOrder, LittleEndian}; use parquet_format::{ColumnOrder as TColumnOrder, FileMetaData as TFileMetaData}; @@ -28,10 +24,7 @@ use thrift::protocol::TCompactInputProtocol; use crate::basic::ColumnOrder; use crate::errors::{ParquetError, Result}; -use crate::file::{ - metadata::*, reader::ChunkReader, DEFAULT_FOOTER_READ_SIZE, FOOTER_SIZE, - PARQUET_MAGIC, -}; +use crate::file::{metadata::*, reader::ChunkReader, FOOTER_SIZE, PARQUET_MAGIC}; use crate::schema::types::{self, SchemaDescriptor}; @@ -52,51 +45,42 @@ pub fn parse_metadata(chunk_reader: &R) -> Result file_size as usize { return Err(general_err!( - "Invalid Parquet file. Metadata length is less than zero ({})", - metadata_len + "Invalid Parquet file. Reported metadata length of {} + {} byte footer, but file is only {} bytes", + metadata_len, + FOOTER_SIZE, + file_size )); } - let footer_metadata_len = FOOTER_SIZE + metadata_len as usize; - // build up the reader covering the entire metadata - let mut default_end_cursor = Cursor::new(default_len_end_buf); - let metadata_read: Box; - if footer_metadata_len > file_size as usize { - return Err(general_err!( - "Invalid Parquet file. Metadata start is less than zero ({})", - file_size as i64 - footer_metadata_len as i64 + let mut metadata = Vec::with_capacity(metadata_len); + + let read = chunk_reader + .get_read(file_size - footer_metadata_len as u64, metadata_len)? + .read_to_end(&mut metadata)?; + + if read != metadata_len { + return Err(eof_err!( + "Expected to read {} bytes of metadata, got {}", + metadata_len, + read )); - } else if footer_metadata_len < DEFAULT_FOOTER_READ_SIZE { - // the whole metadata is in the bytes we already read - default_end_cursor.seek(SeekFrom::End(-(footer_metadata_len as i64)))?; - metadata_read = Box::new(default_end_cursor); - } else { - // the end of file read by default is not long enough, read missing bytes - let complementary_end_read = chunk_reader.get_read( - file_size - footer_metadata_len as u64, - FOOTER_SIZE + metadata_len as usize - default_end_len, - )?; - metadata_read = Box::new(complementary_end_read.chain(default_end_cursor)); } + decode_metadata(&metadata) +} + +/// Decodes [`ParquetMetaData`] from the provided bytes +pub fn decode_metadata(metadata_read: &[u8]) -> Result { // TODO: row group filtering let mut prot = TCompactInputProtocol::new(metadata_read); let t_file_metadata: TFileMetaData = TFileMetaData::read_from_in_protocol(&mut prot) @@ -120,6 +104,23 @@ pub fn parse_metadata(chunk_reader: &R) -> Result Result { + // check this is indeed a parquet file + if slice[4..] != PARQUET_MAGIC { + return Err(general_err!("Invalid Parquet file. Corrupt footer")); + } + + // get the metadata length from the footer + let metadata_len = LittleEndian::read_i32(&slice[..4]); + metadata_len.try_into().map_err(|_| { + general_err!( + "Invalid Parquet file. Metadata length is less than zero ({})", + metadata_len + ) + }) +} + /// Parses column orders from Thrift definition. /// If no column orders are defined, returns `None`. 
fn parse_column_orders( @@ -156,16 +157,16 @@ fn parse_column_orders( #[cfg(test)] mod tests { use super::*; + use bytes::Bytes; use crate::basic::SortOrder; use crate::basic::Type; use crate::schema::types::Type as SchemaType; - use crate::util::test_common::get_temp_file; use parquet_format::TypeDefinedOrder; #[test] fn test_parse_metadata_size_smaller_than_footer() { - let test_file = get_temp_file("corrupt-1.parquet", &[]); + let test_file = tempfile::tempfile().unwrap(); let reader_result = parse_metadata(&test_file); assert!(reader_result.is_err()); assert_eq!( @@ -176,8 +177,8 @@ mod tests { #[test] fn test_parse_metadata_corrupt_footer() { - let test_file = get_temp_file("corrupt-2.parquet", &[1, 2, 3, 4, 5, 6, 7, 8]); - let reader_result = parse_metadata(&test_file); + let data = Bytes::from(vec![1, 2, 3, 4, 5, 6, 7, 8]); + let reader_result = parse_metadata(&data); assert!(reader_result.is_err()); assert_eq!( reader_result.err().unwrap(), @@ -187,8 +188,7 @@ mod tests { #[test] fn test_parse_metadata_invalid_length() { - let test_file = - get_temp_file("corrupt-3.parquet", &[0, 0, 0, 255, b'P', b'A', b'R', b'1']); + let test_file = Bytes::from(vec![0, 0, 0, 255, b'P', b'A', b'R', b'1']); let reader_result = parse_metadata(&test_file); assert!(reader_result.is_err()); assert_eq!( @@ -201,13 +201,14 @@ mod tests { #[test] fn test_parse_metadata_invalid_start() { - let test_file = - get_temp_file("corrupt-4.parquet", &[255, 0, 0, 0, b'P', b'A', b'R', b'1']); + let test_file = Bytes::from(vec![255, 0, 0, 0, b'P', b'A', b'R', b'1']); let reader_result = parse_metadata(&test_file); assert!(reader_result.is_err()); assert_eq!( reader_result.err().unwrap(), - general_err!("Invalid Parquet file. Metadata start is less than zero (-255)") + general_err!( + "Invalid Parquet file. Reported metadata length of 255 + 8 byte footer, but file is only 8 bytes" + ) ); } diff --git a/parquet/src/file/metadata.rs b/parquet/src/file/metadata.rs index ca5a823c9d15..a3477dd75779 100644 --- a/parquet/src/file/metadata.rs +++ b/parquet/src/file/metadata.rs @@ -35,10 +35,12 @@ use std::sync::Arc; -use parquet_format::{ColumnChunk, ColumnMetaData, RowGroup}; +use parquet_format::{ColumnChunk, ColumnMetaData, PageLocation, RowGroup}; use crate::basic::{ColumnOrder, Compression, Encoding, Type}; use crate::errors::{ParquetError, Result}; +use crate::file::page_encoding_stats::{self, PageEncodingStats}; +use crate::file::page_index::index::Index; use crate::file::statistics::{self, Statistics}; use crate::schema::types::{ ColumnDescPtr, ColumnDescriptor, ColumnPath, SchemaDescPtr, SchemaDescriptor, @@ -50,6 +52,8 @@ use crate::schema::types::{ pub struct ParquetMetaData { file_metadata: FileMetaData, row_groups: Vec, + page_indexes: Option>, + offset_indexes: Option>>, } impl ParquetMetaData { @@ -59,6 +63,22 @@ impl ParquetMetaData { ParquetMetaData { file_metadata, row_groups, + page_indexes: None, + offset_indexes: None, + } + } + + pub fn new_with_page_index( + file_metadata: FileMetaData, + row_groups: Vec, + page_indexes: Option>, + offset_indexes: Option>>, + ) -> Self { + ParquetMetaData { + file_metadata, + row_groups, + page_indexes, + offset_indexes, } } @@ -82,6 +102,16 @@ impl ParquetMetaData { pub fn row_groups(&self) -> &[RowGroupMetaData] { &self.row_groups } + + /// Returns page indexes in this file. + pub fn page_indexes(&self) -> Option<&Vec> { + self.page_indexes.as_ref() + } + + /// Returns offset indexes in this file. 
+ pub fn offset_indexes(&self) -> Option<&Vec>> { + self.offset_indexes.as_ref() + } } pub type KeyValue = parquet_format::KeyValue; @@ -138,13 +168,13 @@ impl FileMetaData { /// ```shell /// parquet-mr version 1.8.0 (build 0fda28af84b9746396014ad6a415b90592a98b3b) /// ``` - pub fn created_by(&self) -> &Option { - &self.created_by + pub fn created_by(&self) -> Option<&str> { + self.created_by.as_deref() } /// Returns key_value_metadata of this file. - pub fn key_value_metadata(&self) -> &Option> { - &self.key_value_metadata + pub fn key_value_metadata(&self) -> Option<&Vec> { + self.key_value_metadata.as_ref() } /// Returns Parquet ['Type`] that describes schema in this file. @@ -187,12 +217,13 @@ impl FileMetaData { pub type RowGroupMetaDataPtr = Arc; /// Metadata for a row group. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct RowGroupMetaData { columns: Vec, num_rows: i64, total_byte_size: i64, schema_descr: SchemaDescPtr, + // Todo add filter result -> row range } impl RowGroupMetaData { @@ -269,6 +300,9 @@ impl RowGroupMetaData { total_byte_size: self.total_byte_size, num_rows: self.num_rows, sorting_columns: None, + file_offset: None, + total_compressed_size: None, + ordinal: None, } } } @@ -330,7 +364,7 @@ impl RowGroupMetaDataBuilder { } /// Metadata for a column chunk. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct ColumnChunkMetaData { column_type: Type, column_path: ColumnPath, @@ -346,6 +380,12 @@ pub struct ColumnChunkMetaData { index_page_offset: Option, dictionary_page_offset: Option, statistics: Option, + encoding_stats: Option>, + bloom_filter_offset: Option, + offset_index_offset: Option, + offset_index_length: Option, + column_index_offset: Option, + column_index_length: Option, } /// Represents common operations for a column chunk. @@ -418,21 +458,11 @@ impl ColumnChunkMetaData { self.data_page_offset } - /// Returns `true` if this column chunk contains a index page, `false` otherwise. - pub fn has_index_page(&self) -> bool { - self.index_page_offset.is_some() - } - /// Returns the offset for the index page. pub fn index_page_offset(&self) -> Option { self.index_page_offset } - /// Returns `true` if this column chunk contains a dictionary page, `false` otherwise. - pub fn has_dictionary_page(&self) -> bool { - self.dictionary_page_offset.is_some() - } - /// Returns the offset for the dictionary page, if any. pub fn dictionary_page_offset(&self) -> Option { self.dictionary_page_offset @@ -440,10 +470,9 @@ impl ColumnChunkMetaData { /// Returns the offset and length in bytes of the column chunk within the file pub fn byte_range(&self) -> (u64, u64) { - let col_start = if self.has_dictionary_page() { - self.dictionary_page_offset().unwrap() - } else { - self.data_page_offset() + let col_start = match self.dictionary_page_offset() { + Some(dictionary_page_offset) => dictionary_page_offset, + None => self.data_page_offset(), }; let col_len = self.compressed_size(); assert!( @@ -459,6 +488,37 @@ impl ColumnChunkMetaData { self.statistics.as_ref() } + /// Returns the offset for the page encoding stats, + /// or `None` if no page encoding stats are available. + pub fn page_encoding_stats(&self) -> Option<&Vec> { + self.encoding_stats.as_ref() + } + + /// Returns the offset for the bloom filter. + pub fn bloom_filter_offset(&self) -> Option { + self.bloom_filter_offset + } + + /// Returns the offset for the column index. 
+ pub fn column_index_offset(&self) -> Option { + self.column_index_offset + } + + /// Returns the offset for the column index length. + pub fn column_index_length(&self) -> Option { + self.column_index_length + } + + /// Returns the offset for the offset index. + pub fn offset_index_offset(&self) -> Option { + self.offset_index_offset + } + + /// Returns the offset for the offset index length. + pub fn offset_index_length(&self) -> Option { + self.offset_index_length + } + /// Method to convert from Thrift. pub fn from_thrift(column_descr: ColumnDescPtr, cc: ColumnChunk) -> Result { if cc.meta_data.is_none() { @@ -482,6 +542,16 @@ impl ColumnChunkMetaData { let index_page_offset = col_metadata.index_page_offset; let dictionary_page_offset = col_metadata.dictionary_page_offset; let statistics = statistics::from_thrift(column_type, col_metadata.statistics); + let encoding_stats = col_metadata + .encoding_stats + .as_ref() + .map(|vec| vec.iter().map(page_encoding_stats::from_thrift).collect()); + let bloom_filter_offset = col_metadata.bloom_filter_offset; + let offset_index_offset = cc.offset_index_offset; + let offset_index_length = cc.offset_index_length; + let column_index_offset = cc.column_index_offset; + let column_index_length = cc.column_index_length; + let result = ColumnChunkMetaData { column_type, column_path, @@ -497,6 +567,12 @@ impl ColumnChunkMetaData { index_page_offset, dictionary_page_offset, statistics, + encoding_stats, + bloom_filter_offset, + offset_index_offset, + offset_index_length, + column_index_offset, + column_index_length, }; Ok(result) } @@ -516,17 +592,23 @@ impl ColumnChunkMetaData { index_page_offset: self.index_page_offset, dictionary_page_offset: self.dictionary_page_offset, statistics: statistics::to_thrift(self.statistics.as_ref()), - encoding_stats: None, + encoding_stats: self + .encoding_stats + .as_ref() + .map(|vec| vec.iter().map(page_encoding_stats::to_thrift).collect()), + bloom_filter_offset: self.bloom_filter_offset, }; ColumnChunk { file_path: self.file_path().map(|s| s.to_owned()), file_offset: self.file_offset, meta_data: Some(column_metadata), - offset_index_offset: None, - offset_index_length: None, - column_index_offset: None, - column_index_length: None, + offset_index_offset: self.offset_index_offset, + offset_index_length: self.offset_index_length, + column_index_offset: self.column_index_offset, + column_index_length: self.column_index_length, + crypto_metadata: None, + encrypted_column_metadata: None, } } } @@ -545,6 +627,12 @@ pub struct ColumnChunkMetaDataBuilder { index_page_offset: Option, dictionary_page_offset: Option, statistics: Option, + encoding_stats: Option>, + bloom_filter_offset: Option, + offset_index_offset: Option, + offset_index_length: Option, + column_index_offset: Option, + column_index_length: Option, } impl ColumnChunkMetaDataBuilder { @@ -563,6 +651,12 @@ impl ColumnChunkMetaDataBuilder { index_page_offset: None, dictionary_page_offset: None, statistics: None, + encoding_stats: None, + bloom_filter_offset: None, + offset_index_offset: None, + offset_index_length: None, + column_index_offset: None, + column_index_length: None, } } @@ -632,6 +726,42 @@ impl ColumnChunkMetaDataBuilder { self } + /// Sets page encoding stats for this column chunk. + pub fn set_page_encoding_stats(mut self, value: Vec) -> Self { + self.encoding_stats = Some(value); + self + } + + /// Sets optional bloom filter offset in bytes. 
+ pub fn set_bloom_filter_offset(mut self, value: Option) -> Self { + self.bloom_filter_offset = value; + self + } + + /// Sets optional offset index offset in bytes. + pub fn set_offset_index_offset(mut self, value: Option) -> Self { + self.offset_index_offset = value; + self + } + + /// Sets optional offset index length in bytes. + pub fn set_offset_index_length(mut self, value: Option) -> Self { + self.offset_index_length = value; + self + } + + /// Sets optional column index offset in bytes. + pub fn set_column_index_offset(mut self, value: Option) -> Self { + self.column_index_offset = value; + self + } + + /// Sets optional column index length in bytes. + pub fn set_column_index_length(mut self, value: Option) -> Self { + self.column_index_length = value; + self + } + /// Builds column chunk metadata. pub fn build(self) -> Result { Ok(ColumnChunkMetaData { @@ -649,6 +779,12 @@ impl ColumnChunkMetaDataBuilder { index_page_offset: self.index_page_offset, dictionary_page_offset: self.dictionary_page_offset, statistics: self.statistics, + encoding_stats: self.encoding_stats, + bloom_filter_offset: self.bloom_filter_offset, + offset_index_offset: self.offset_index_offset, + offset_index_length: self.offset_index_length, + column_index_offset: self.column_index_offset, + column_index_length: self.column_index_length, }) } } @@ -656,6 +792,7 @@ impl ColumnChunkMetaDataBuilder { #[cfg(test)] mod tests { use super::*; + use crate::basic::{Encoding, PageType}; #[test] fn test_row_group_metadata_thrift_conversion() { @@ -711,17 +848,31 @@ mod tests { .set_total_uncompressed_size(3000) .set_data_page_offset(4000) .set_dictionary_page_offset(Some(5000)) + .set_page_encoding_stats(vec![ + PageEncodingStats { + page_type: PageType::DATA_PAGE, + encoding: Encoding::PLAIN, + count: 3, + }, + PageEncodingStats { + page_type: PageType::DATA_PAGE, + encoding: Encoding::RLE, + count: 5, + }, + ]) + .set_bloom_filter_offset(Some(6000)) + .set_offset_index_offset(Some(7000)) + .set_offset_index_length(Some(25)) + .set_column_index_offset(Some(8000)) + .set_column_index_length(Some(25)) .build() .unwrap(); - let col_chunk_exp = col_metadata.to_thrift(); - let col_chunk_res = - ColumnChunkMetaData::from_thrift(column_descr, col_chunk_exp.clone()) - .unwrap() - .to_thrift(); + ColumnChunkMetaData::from_thrift(column_descr, col_metadata.to_thrift()) + .unwrap(); - assert_eq!(col_chunk_res, col_chunk_exp); + assert_eq!(col_chunk_res, col_metadata); } #[test] diff --git a/parquet/src/file/mod.rs b/parquet/src/file/mod.rs index f85de98ccab6..66d8ce48e0a7 100644 --- a/parquet/src/file/mod.rs +++ b/parquet/src/file/mod.rs @@ -19,7 +19,7 @@ //! //! Provides access to file and row group readers and writers, record API, metadata, etc. //! -//! See [`reader::SerializedFileReader`](reader/struct.SerializedFileReader.html) or +//! See [`serialized_reader::SerializedFileReader`](serialized_reader/struct.SerializedFileReader.html) or //! [`writer::SerializedFileWriter`](writer/struct.SerializedFileWriter.html) for a //! starting reference, [`metadata::ParquetMetaData`](metadata/index.html) for file //! metadata, and [`statistics`](statistics/index.html) for working with statistics. @@ -32,7 +32,7 @@ //! use parquet::{ //! file::{ //! properties::WriterProperties, -//! writer::{FileWriter, SerializedFileWriter}, +//! writer::SerializedFileWriter, //! }, //! schema::parser::parse_message_type, //! }; @@ -51,9 +51,9 @@ //! let mut row_group_writer = writer.next_row_group().unwrap(); //! 
while let Some(mut col_writer) = row_group_writer.next_column().unwrap() { //! // ... write values to a column writer -//! row_group_writer.close_column(col_writer).unwrap(); +//! col_writer.close().unwrap() //! } -//! writer.close_row_group(row_group_writer).unwrap(); +//! row_group_writer.close().unwrap(); //! writer.close().unwrap(); //! //! let bytes = fs::read(&path).unwrap(); @@ -97,14 +97,14 @@ //! ``` pub mod footer; pub mod metadata; +pub mod page_encoding_stats; +pub mod page_index; pub mod properties; pub mod reader; pub mod serialized_reader; pub mod statistics; pub mod writer; -const FOOTER_SIZE: usize = 8; +/// The length of the parquet footer in bytes +pub const FOOTER_SIZE: usize = 8; const PARQUET_MAGIC: [u8; 4] = [b'P', b'A', b'R', b'1']; - -/// The number of bytes read at the end of the parquet file on first read -const DEFAULT_FOOTER_READ_SIZE: usize = 64 * 1024; diff --git a/parquet/src/file/page_encoding_stats.rs b/parquet/src/file/page_encoding_stats.rs new file mode 100644 index 000000000000..3180c7820802 --- /dev/null +++ b/parquet/src/file/page_encoding_stats.rs @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::basic::{Encoding, PageType}; +use parquet_format::{ + Encoding as TEncoding, PageEncodingStats as TPageEncodingStats, PageType as TPageType, +}; + +/// PageEncodingStats for a column chunk and data page. +#[derive(Clone, Debug, PartialEq)] +pub struct PageEncodingStats { + /// the page type (data/dic/...) + pub page_type: PageType, + /// encoding of the page + pub encoding: Encoding, + /// number of pages of this type with this encoding + pub count: i32, +} + +/// Converts Thrift definition into `PageEncodingStats`. +pub fn from_thrift(thrift_encoding_stats: &TPageEncodingStats) -> PageEncodingStats { + let page_type = PageType::from(thrift_encoding_stats.page_type); + let encoding = Encoding::from(thrift_encoding_stats.encoding); + let count = thrift_encoding_stats.count; + + PageEncodingStats { + page_type, + encoding, + count, + } +} + +/// Converts `PageEncodingStats` into Thrift definition. 
+pub fn to_thrift(encoding_stats: &PageEncodingStats) -> TPageEncodingStats { + let page_type = TPageType::from(encoding_stats.page_type); + let encoding = TEncoding::from(encoding_stats.encoding); + let count = encoding_stats.count; + + TPageEncodingStats { + page_type, + encoding, + count, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::basic::{Encoding, PageType}; + + #[test] + fn test_page_encoding_stats_from_thrift() { + let stats = PageEncodingStats { + page_type: PageType::DATA_PAGE, + encoding: Encoding::PLAIN, + count: 1, + }; + + assert_eq!(from_thrift(&to_thrift(&stats)), stats); + } +} diff --git a/parquet/src/file/page_index/index.rs b/parquet/src/file/page_index/index.rs new file mode 100644 index 000000000000..e97826c63b41 --- /dev/null +++ b/parquet/src/file/page_index/index.rs @@ -0,0 +1,209 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::basic::Type; +use crate::data_type::private::ParquetValueType; +use crate::data_type::Int96; +use crate::errors::ParquetError; +use crate::util::bit_util::from_le_slice; +use parquet_format::{BoundaryOrder, ColumnIndex}; +use std::fmt::Debug; + +/// The statistics in one page +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct PageIndex { + /// The minimum value, It is None when all values are null + pub min: Option, + /// The maximum value, It is None when all values are null + pub max: Option, + /// Null values in the page + pub null_count: Option, +} + +impl PageIndex { + pub fn min(&self) -> Option<&T> { + self.min.as_ref() + } + pub fn max(&self) -> Option<&T> { + self.max.as_ref() + } + pub fn null_count(&self) -> Option { + self.null_count + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum Index { + BOOLEAN(BooleanIndex), + INT32(NativeIndex), + INT64(NativeIndex), + INT96(NativeIndex), + FLOAT(NativeIndex), + DOUBLE(NativeIndex), + BYTE_ARRAY(ByteArrayIndex), + FIXED_LEN_BYTE_ARRAY(ByteArrayIndex), +} + +/// An index of a column of [`Type`] physical representation +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct NativeIndex { + /// The physical type + pub physical_type: Type, + /// The indexes, one item per page + pub indexes: Vec>, + /// the order + pub boundary_order: BoundaryOrder, +} + +impl NativeIndex { + /// Creates a new [`NativeIndex`] + pub(crate) fn try_new( + index: ColumnIndex, + physical_type: Type, + ) -> Result { + let len = index.min_values.len(); + + let null_counts = index + .null_counts + .map(|x| x.into_iter().map(Some).collect::>()) + .unwrap_or_else(|| vec![None; len]); + + let indexes = index + .min_values + .iter() + .zip(index.max_values.into_iter()) + .zip(index.null_pages.into_iter()) + .zip(null_counts.into_iter()) + .map(|(((min, max), is_null), null_count)| { + let (min, max) = if is_null { + (None, 
None) + } else { + let min = min.as_slice(); + let max = max.as_slice(); + (Some(from_le_slice::(min)), Some(from_le_slice::(max))) + }; + Ok(PageIndex { + min, + max, + null_count, + }) + }) + .collect::, ParquetError>>()?; + + Ok(Self { + physical_type, + indexes, + boundary_order: index.boundary_order, + }) + } +} + +/// An index of a column of bytes type +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ByteArrayIndex { + /// The physical type + pub physical_type: Type, + /// The indexes, one item per page + pub indexes: Vec>>, + pub boundary_order: BoundaryOrder, +} + +impl ByteArrayIndex { + pub(crate) fn try_new( + index: ColumnIndex, + physical_type: Type, + ) -> Result { + let len = index.min_values.len(); + + let null_counts = index + .null_counts + .map(|x| x.into_iter().map(Some).collect::>()) + .unwrap_or_else(|| vec![None; len]); + + let indexes = index + .min_values + .into_iter() + .zip(index.max_values.into_iter()) + .zip(index.null_pages.into_iter()) + .zip(null_counts.into_iter()) + .map(|(((min, max), is_null), null_count)| { + let (min, max) = if is_null { + (None, None) + } else { + (Some(min), Some(max)) + }; + Ok(PageIndex { + min, + max, + null_count, + }) + }) + .collect::, ParquetError>>()?; + + Ok(Self { + physical_type, + indexes, + boundary_order: index.boundary_order, + }) + } +} + +/// An index of a column of boolean physical type +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct BooleanIndex { + /// The indexes, one item per page + pub indexes: Vec>, + pub boundary_order: BoundaryOrder, +} + +impl BooleanIndex { + pub(crate) fn try_new(index: ColumnIndex) -> Result { + let len = index.min_values.len(); + + let null_counts = index + .null_counts + .map(|x| x.into_iter().map(Some).collect::>()) + .unwrap_or_else(|| vec![None; len]); + + let indexes = index + .min_values + .into_iter() + .zip(index.max_values.into_iter()) + .zip(index.null_pages.into_iter()) + .zip(null_counts.into_iter()) + .map(|(((min, max), is_null), null_count)| { + let (min, max) = if is_null { + (None, None) + } else { + let min = min[0] != 0; + let max = max[0] == 1; + (Some(min), Some(max)) + }; + Ok(PageIndex { + min, + max, + null_count, + }) + }) + .collect::, ParquetError>>()?; + + Ok(Self { + indexes, + boundary_order: index.boundary_order, + }) + } +} diff --git a/parquet/src/file/page_index/index_reader.rs b/parquet/src/file/page_index/index_reader.rs new file mode 100644 index 000000000000..8414480903fd --- /dev/null +++ b/parquet/src/file/page_index/index_reader.rs @@ -0,0 +1,167 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
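The new index_reader module that follows exposes read_columns_indexes and read_pages_locations for loading a row group's column index and offset index. A rough usage sketch (not part of the diff; it assumes the file was written with a page index and uses std::fs::File as the ChunkReader):

use parquet::file::footer::parse_metadata;
use parquet::file::page_index::index_reader::{read_columns_indexes, read_pages_locations};

fn dump_page_index_counts(file: &std::fs::File) -> Result<(), parquet::errors::ParquetError> {
    let metadata = parse_metadata(file)?;
    for row_group in metadata.row_groups() {
        // one Index per column chunk
        let column_indexes = read_columns_indexes(file, row_group.columns())?;
        // one Vec<PageLocation> per column chunk
        let page_locations = read_pages_locations(file, row_group.columns())?;
        println!(
            "{} column indexes, {} offset indexes",
            column_indexes.len(),
            page_locations.len()
        );
    }
    Ok(())
}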
+
+use crate::basic::Type;
+use crate::data_type::Int96;
+use crate::errors::ParquetError;
+use crate::file::metadata::ColumnChunkMetaData;
+use crate::file::page_index::index::{BooleanIndex, ByteArrayIndex, Index, NativeIndex};
+use crate::file::reader::ChunkReader;
+use parquet_format::{ColumnIndex, OffsetIndex, PageLocation};
+use std::io::{Cursor, Read};
+use thrift::protocol::TCompactInputProtocol;
+
+/// Reads the column indexes of all columns in a row group and converts them into [`Index`].
+/// Returns an empty vector if the column index is not available.
+pub fn read_columns_indexes<R: ChunkReader>(
+    reader: &R,
+    chunks: &[ColumnChunkMetaData],
+) -> Result<Vec<Index>, ParquetError> {
+    let (offset, lengths) = get_index_offset_and_lengths(chunks)?;
+    let length = lengths.iter().sum::<usize>();
+
+    // read all the needed data into a buffer
+    let mut reader = reader.get_read(offset, reader.len() as usize)?;
+    let mut data = vec![0; length];
+    reader.read_exact(&mut data)?;
+
+    let mut start = 0;
+    let data = lengths.into_iter().map(|length| {
+        let r = &data[start..start + length];
+        start += length;
+        r
+    });
+
+    chunks
+        .iter()
+        .zip(data)
+        .map(|(chunk, data)| {
+            let column_type = chunk.column_type();
+            deserialize_column_index(data, column_type)
+        })
+        .collect()
+}
+
+/// Reads the [`PageLocation`]s (offset index) of all columns in a row group.
+/// Returns an empty vector if the offset index is not available.
+pub fn read_pages_locations<R: ChunkReader>(
+    reader: &R,
+    chunks: &[ColumnChunkMetaData],
+) -> Result<Vec<Vec<PageLocation>>, ParquetError> {
+    let (offset, total_length) = get_location_offset_and_total_length(chunks)?;
+
+    // read all the needed data into a buffer
+    let mut reader = reader.get_read(offset, reader.len() as usize)?;
+    let mut data = vec![0; total_length];
+    reader.read_exact(&mut data)?;
+
+    let mut d = Cursor::new(data);
+    let mut result = vec![];
+
+    for _ in 0..chunks.len() {
+        let mut prot = TCompactInputProtocol::new(&mut d);
+        let offset = OffsetIndex::read_from_in_protocol(&mut prot)?;
+        result.push(offset.page_locations);
+    }
+    Ok(result)
+}
+
+// Get the file offsets of every ColumnChunk's column index.
+// If the offset is not available, return a zero offset with empty lengths.
+fn get_index_offset_and_lengths(
+    chunks: &[ColumnChunkMetaData],
+) -> Result<(u64, Vec<usize>), ParquetError> {
+    let first_col_metadata = if let Some(chunk) = chunks.first() {
+        chunk
+    } else {
+        return Ok((0, vec![]));
+    };
+
+    let offset: u64 = if let Some(offset) = first_col_metadata.column_index_offset() {
+        offset.try_into().unwrap()
+    } else {
+        return Ok((0, vec![]));
+    };
+
+    let lengths = chunks
+        .iter()
+        .map(|x| x.column_index_length())
+        .map(|maybe_length| {
+            let index_length = maybe_length.ok_or_else(|| {
+                ParquetError::General(
+                    "The column_index_length must exist if offset_index_offset exists"
+                        .to_string(),
+                )
+            })?;
+
+            Ok(index_length.try_into().unwrap())
+        })
+        .collect::<Result<Vec<_>, ParquetError>>()?;
+
+    Ok((offset, lengths))
+}
+
+// Get the file offset of the ColumnChunks' page locations (offset index).
+// If the offset is not available, return a zero offset with zero length.
+fn get_location_offset_and_total_length( + chunks: &[ColumnChunkMetaData], +) -> Result<(u64, usize), ParquetError> { + let metadata = if let Some(chunk) = chunks.first() { + chunk + } else { + return Ok((0, 0)); + }; + + let offset: u64 = if let Some(offset) = metadata.offset_index_offset() { + offset.try_into().unwrap() + } else { + return Ok((0, 0)); + }; + + let total_length = chunks + .iter() + .map(|x| x.offset_index_length().unwrap()) + .sum::() as usize; + Ok((offset, total_length)) +} + +fn deserialize_column_index( + data: &[u8], + column_type: Type, +) -> Result { + let mut d = Cursor::new(data); + let mut prot = TCompactInputProtocol::new(&mut d); + + let index = ColumnIndex::read_from_in_protocol(&mut prot)?; + + let index = match column_type { + Type::BOOLEAN => Index::BOOLEAN(BooleanIndex::try_new(index)?), + Type::INT32 => Index::INT32(NativeIndex::::try_new(index, column_type)?), + Type::INT64 => Index::INT64(NativeIndex::::try_new(index, column_type)?), + Type::INT96 => Index::INT96(NativeIndex::::try_new(index, column_type)?), + Type::FLOAT => Index::FLOAT(NativeIndex::::try_new(index, column_type)?), + Type::DOUBLE => Index::DOUBLE(NativeIndex::::try_new(index, column_type)?), + Type::BYTE_ARRAY => { + Index::BYTE_ARRAY(ByteArrayIndex::try_new(index, column_type)?) + } + Type::FIXED_LEN_BYTE_ARRAY => { + Index::FIXED_LEN_BYTE_ARRAY(ByteArrayIndex::try_new(index, column_type)?) + } + }; + + Ok(index) +} diff --git a/arrow/src/zz_memory_check.rs b/parquet/src/file/page_index/mod.rs similarity index 63% rename from arrow/src/zz_memory_check.rs rename to parquet/src/file/page_index/mod.rs index 70ec8ebdbdd2..fc87ef20448f 100644 --- a/arrow/src/zz_memory_check.rs +++ b/parquet/src/file/page_index/mod.rs @@ -15,17 +15,6 @@ // specific language governing permissions and limitations // under the License. -// This file is named like this so that it is the last one to be tested -// It contains no content, it has a single test that verifies that there is no memory leak -// on all unit-tests - -#[cfg(feature = "memory-check")] -mod tests { - use crate::memory::ALLOCATIONS; - - // verify that there is no data un-allocated - #[test] - fn test_memory_check() { - unsafe { assert_eq!(ALLOCATIONS.load(std::sync::atomic::Ordering::SeqCst), 0) } - } -} +pub mod index; +pub mod index_reader; +pub(crate) mod range; diff --git a/parquet/src/file/page_index/range.rs b/parquet/src/file/page_index/range.rs new file mode 100644 index 000000000000..06c06553ccd5 --- /dev/null +++ b/parquet/src/file/page_index/range.rs @@ -0,0 +1,474 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
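range.rs below implements a small row-range algebra: RowRanges is an ordered set of inclusive Range values (an alias for RangeInclusive<usize>) that is later used to turn a page-level filter mask into the row ranges to scan. An illustrative, test-style sketch of the semantics (not part of the diff):

#[test]
fn row_ranges_sketch() {
    let mut scanned = RowRanges::new_empty();
    scanned.add(Range::new(0, 9));   // rows 0..=9
    scanned.add(Range::new(10, 14)); // adjacent, so merged into 0..=14
    assert_eq!(scanned.count(), 1);
    assert_eq!(scanned.row_count(), 15);

    let mut other = RowRanges::new_empty();
    other.add(Range::new(5, 7));
    other.add(Range::new(20, 24));

    // union keeps disjoint runs separate: [0, 14] and [20, 24]
    assert_eq!(RowRanges::union(scanned.clone(), other.clone()).count(), 2);
    // intersection keeps only the overlap: rows 5, 6 and 7
    assert_eq!(RowRanges::intersection(scanned, other).row_count(), 3);
}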
+use crate::errors::ParquetError; +use parquet_format::PageLocation; +use std::cmp::Ordering; +use std::collections::VecDeque; +use std::ops::RangeInclusive; + +type Range = RangeInclusive; + +pub trait RangeOps { + fn is_before(&self, other: &Self) -> bool; + + fn is_after(&self, other: &Self) -> bool; + + fn count(&self) -> usize; + + fn union(left: &Range, right: &Range) -> Option; + + fn intersection(left: &Range, right: &Range) -> Option; +} + +impl RangeOps for Range { + fn is_before(&self, other: &Range) -> bool { + self.end() < other.start() + } + + fn is_after(&self, other: &Range) -> bool { + self.start() > other.end() + } + + fn count(&self) -> usize { + self.end() + 1 - self.start() + } + + /// Return the union of the two ranges, + /// Return `None` if there are hole between them. + fn union(left: &Range, right: &Range) -> Option { + if left.start() <= right.start() { + if left.end() + 1 >= *right.start() { + return Some(Range::new( + *left.start(), + std::cmp::max(*left.end(), *right.end()), + )); + } + } else if right.end() + 1 >= *left.start() { + return Some(Range::new( + *right.start(), + std::cmp::max(*left.end(), *right.end()), + )); + } + None + } + + /// Returns the intersection of the two ranges, + /// return null if they are not overlapped. + fn intersection(left: &Range, right: &Range) -> Option { + if left.start() <= right.start() { + if left.end() >= right.start() { + return Some(Range::new( + *right.start(), + std::cmp::min(*left.end(), *right.end()), + )); + } + } else if right.end() >= left.start() { + return Some(Range::new( + *left.start(), + std::cmp::min(*left.end(), *right.end()), + )); + } + None + } +} + +/// Struct representing row ranges in a row-group. These row ranges are calculated as a result of using +/// the column index on the filtering. +#[derive(Debug, Clone)] +pub struct RowRanges { + pub ranges: VecDeque, +} + +impl RowRanges { + //create an empty RowRanges + pub fn new_empty() -> Self { + RowRanges { + ranges: VecDeque::new(), + } + } + + pub fn count(&self) -> usize { + self.ranges.len() + } + + pub fn filter_with_mask(&self, mask: &[bool]) -> Result { + if self.ranges.len() != mask.len() { + return Err(ParquetError::General(format!( + "Mask size {} is not equal to number of pages {}", + mask.len(), + self.count() + ))); + } + let vec_range = mask + .iter() + .zip(self.ranges.clone()) + .filter_map(|(&f, r)| if f { Some(r) } else { None }) + .collect(); + Ok(RowRanges { ranges: vec_range }) + } + + /// Add a range to the end of the list of ranges. It maintains the disjunctive ascending order of the ranges by + /// trying to union the specified range to the last ranges in the list. The specified range shall be larger than + /// the last one or might be overlapped with some of the last ones. + /// [a, b] < [c, d] if b < c + pub fn add(&mut self, mut range: Range) { + let count = self.count(); + if count > 0 { + for i in 1..(count + 1) { + let index = count - i; + let last = self.ranges.get(index).unwrap(); + assert!(!last.is_after(&range), "Must add range in ascending!"); + // try to merge range + match Range::union(last, &range) { + None => { + break; + } + Some(r) => { + range = r; + self.ranges.remove(index); + } + } + } + } + self.ranges.push_back(range); + } + + /// Calculates the union of the two specified RowRanges object. The union of two range is calculated if there are no + /// elements between them. Otherwise, the two disjunctive ranges are stored separately. 
+    /// For example:
+    /// [113, 241] ∪ [221, 340] = [113, 340]
+    /// [113, 230] ∪ [231, 340] = [113, 340]
+    /// while
+    /// [113, 230] ∪ [232, 340] = [113, 230], [232, 340]
+    ///
+    /// The result RowRanges object will contain all the row indexes that were contained in one of the specified objects.
+    pub fn union(mut left: RowRanges, mut right: RowRanges) -> RowRanges {
+        let v1 = &mut left.ranges;
+        let v2 = &mut right.ranges;
+        let mut result = RowRanges::new_empty();
+        if v2.is_empty() {
+            left.clone()
+        } else {
+            let mut range2 = v2.pop_front().unwrap();
+            while !v1.is_empty() {
+                let range1 = v1.pop_front().unwrap();
+                if range1.is_after(&range2) {
+                    result.add(range2);
+                    range2 = range1;
+                    std::mem::swap(v1, v2);
+                } else {
+                    result.add(range1);
+                }
+            }
+
+            result.add(range2);
+            while !v2.is_empty() {
+                result.add(v2.pop_front().unwrap())
+            }
+
+            result
+        }
+    }
+
+    /// Calculates the intersection of the two specified RowRanges objects. Two ranges intersect if they have common
+    /// elements, otherwise the result is empty.
+    /// For example:
+    /// [113, 241] ∩ [221, 340] = [221, 241]
+    /// while
+    /// [113, 230] ∩ [231, 340] =
+    ///
+    /// The result RowRanges object will contain all the row indexes that were contained in both of the specified objects
+    #[allow(clippy::mut_range_bound)]
+    pub fn intersection(left: RowRanges, right: RowRanges) -> RowRanges {
+        let mut result = RowRanges::new_empty();
+        let mut right_index = 0;
+        for l in left.ranges.iter() {
+            for i in right_index..right.ranges.len() {
+                let r = right.ranges.get(i).unwrap();
+                if l.is_before(r) {
+                    break;
+                } else if l.is_after(r) {
+                    right_index = i + 1;
+                    continue;
+                }
+                if let Some(ra) = Range::intersection(l, r) {
+                    result.add(ra);
+                }
+            }
+        }
+        result
+    }
+
+    pub fn row_count(&self) -> usize {
+        self.ranges.iter().map(|x| x.count()).sum()
+    }
+
+    pub fn is_overlapping(&self, x: &Range) -> bool {
+        self.ranges
+            .binary_search_by(|y| -> Ordering {
+                if y.is_before(x) {
+                    Ordering::Less
+                } else if y.is_after(x) {
+                    Ordering::Greater
+                } else {
+                    Ordering::Equal
+                }
+            })
+            .is_ok()
+    }
+}
+
+/// Takes an array of [`PageLocation`] and a total number of rows, and based on the provided `page_mask`
+/// returns the corresponding [`RowRanges`] to scan
+pub fn compute_row_ranges(
+    page_mask: &[bool],
+    locations: &[PageLocation],
+    total_rows: usize,
+) -> Result<RowRanges, ParquetError> {
+    if page_mask.len() != locations.len() {
+        return Err(ParquetError::General(format!(
+            "Page_mask size {} is not equal to number of locations {}",
+            page_mask.len(),
+            locations.len(),
+        )));
+    }
+    let row_ranges = page_locations_to_row_ranges(locations, total_rows)?;
+    row_ranges.filter_with_mask(page_mask)
+}
+
+fn page_locations_to_row_ranges(
+    locations: &[PageLocation],
+    total_rows: usize,
+) -> Result<RowRanges, ParquetError> {
+    if locations.is_empty() || total_rows == 0 {
+        return Ok(RowRanges::new_empty());
+    }
+
+    // If we read directly from the parquet page index to construct locations,
+    // the location index should be continuous
+    let mut vec_range: VecDeque<Range> = locations
+        .windows(2)
+        .map(|x| {
+            let start = x[0].first_row_index as usize;
+            let end = (x[1].first_row_index - 1) as usize;
+            Range::new(start, end)
+        })
+        .collect();
+
+    let last = Range::new(
+        locations.last().unwrap().first_row_index as usize,
+        total_rows - 1,
+    );
+    vec_range.push_back(last);
+
+    Ok(RowRanges { ranges: vec_range })
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::basic::Type::INT32;
+    use crate::file::page_index::index::{NativeIndex, PageIndex};
+    use
crate::file::page_index::range::{compute_row_ranges, Range, RowRanges}; + use parquet_format::{BoundaryOrder, PageLocation}; + + #[test] + fn test_binary_search_overlap() { + let mut ranges = RowRanges::new_empty(); + ranges.add(Range::new(1, 3)); + ranges.add(Range::new(6, 7)); + + assert!(ranges.is_overlapping(&Range::new(1, 2))); + // include both [start, end] + assert!(ranges.is_overlapping(&Range::new(0, 1))); + assert!(ranges.is_overlapping(&Range::new(0, 3))); + + assert!(ranges.is_overlapping(&Range::new(0, 7))); + assert!(ranges.is_overlapping(&Range::new(2, 7))); + + assert!(!ranges.is_overlapping(&Range::new(4, 5))); + } + + #[test] + fn test_add_func_ascending_disjunctive() { + let mut ranges_1 = RowRanges::new_empty(); + ranges_1.add(Range::new(1, 3)); + ranges_1.add(Range::new(5, 6)); + ranges_1.add(Range::new(8, 9)); + assert_eq!(ranges_1.count(), 3); + } + + #[test] + fn test_add_func_ascending_merge() { + let mut ranges_1 = RowRanges::new_empty(); + ranges_1.add(Range::new(1, 3)); + ranges_1.add(Range::new(4, 5)); + ranges_1.add(Range::new(6, 7)); + assert_eq!(ranges_1.count(), 1); + } + + #[test] + #[should_panic(expected = "Must add range in ascending!")] + fn test_add_func_not_ascending() { + let mut ranges_1 = RowRanges::new_empty(); + ranges_1.add(Range::new(6, 7)); + ranges_1.add(Range::new(1, 3)); + ranges_1.add(Range::new(4, 5)); + assert_eq!(ranges_1.count(), 1); + } + + #[test] + fn test_union_func() { + let mut ranges_1 = RowRanges::new_empty(); + ranges_1.add(Range::new(1, 2)); + ranges_1.add(Range::new(3, 4)); + ranges_1.add(Range::new(5, 6)); + + let mut ranges_2 = RowRanges::new_empty(); + ranges_2.add(Range::new(2, 3)); + ranges_2.add(Range::new(4, 5)); + ranges_2.add(Range::new(6, 7)); + + let ranges = RowRanges::union(ranges_1, ranges_2); + assert_eq!(ranges.count(), 1); + let range = ranges.ranges.get(0).unwrap(); + assert_eq!(*range.start(), 1); + assert_eq!(*range.end(), 7); + + let mut ranges_a = RowRanges::new_empty(); + ranges_a.add(Range::new(1, 3)); + ranges_a.add(Range::new(5, 8)); + ranges_a.add(Range::new(11, 12)); + + let mut ranges_b = RowRanges::new_empty(); + ranges_b.add(Range::new(0, 2)); + ranges_b.add(Range::new(6, 7)); + ranges_b.add(Range::new(10, 11)); + + let ranges = RowRanges::union(ranges_a, ranges_b); + assert_eq!(ranges.count(), 3); + + let range_1 = ranges.ranges.get(0).unwrap(); + assert_eq!(*range_1.start(), 0); + assert_eq!(*range_1.end(), 3); + let range_2 = ranges.ranges.get(1).unwrap(); + assert_eq!(*range_2.start(), 5); + assert_eq!(*range_2.end(), 8); + let range_3 = ranges.ranges.get(2).unwrap(); + assert_eq!(*range_3.start(), 10); + assert_eq!(*range_3.end(), 12); + } + + #[test] + fn test_intersection_func() { + let mut ranges_1 = RowRanges::new_empty(); + ranges_1.add(Range::new(1, 2)); + ranges_1.add(Range::new(3, 4)); + ranges_1.add(Range::new(5, 6)); + + let mut ranges_2 = RowRanges::new_empty(); + ranges_2.add(Range::new(2, 3)); + ranges_2.add(Range::new(4, 5)); + ranges_2.add(Range::new(6, 7)); + + let ranges = RowRanges::intersection(ranges_1, ranges_2); + assert_eq!(ranges.count(), 1); + let range = ranges.ranges.get(0).unwrap(); + assert_eq!(*range.start(), 2); + assert_eq!(*range.end(), 6); + + let mut ranges_a = RowRanges::new_empty(); + ranges_a.add(Range::new(1, 3)); + ranges_a.add(Range::new(5, 8)); + ranges_a.add(Range::new(11, 12)); + + let mut ranges_b = RowRanges::new_empty(); + ranges_b.add(Range::new(0, 2)); + ranges_b.add(Range::new(6, 7)); + ranges_b.add(Range::new(10, 11)); + + let ranges = 
RowRanges::intersection(ranges_a, ranges_b); + assert_eq!(ranges.count(), 3); + + let range_1 = ranges.ranges.get(0).unwrap(); + assert_eq!(*range_1.start(), 1); + assert_eq!(*range_1.end(), 2); + let range_2 = ranges.ranges.get(1).unwrap(); + assert_eq!(*range_2.start(), 6); + assert_eq!(*range_2.end(), 7); + let range_3 = ranges.ranges.get(2).unwrap(); + assert_eq!(*range_3.start(), 11); + assert_eq!(*range_3.end(), 11); + } + + #[test] + fn test_compute_one() { + let locations = &[PageLocation { + offset: 50, + compressed_page_size: 10, + first_row_index: 0, + }]; + let total_rows = 10; + + let row_ranges = compute_row_ranges(&[true], locations, total_rows).unwrap(); + assert_eq!(row_ranges.count(), 1); + assert_eq!(row_ranges.ranges.get(0).unwrap(), &Range::new(0, 9)); + } + + #[test] + fn test_compute_multi() { + let index: NativeIndex = NativeIndex { + physical_type: INT32, + indexes: vec![ + PageIndex { + min: Some(0), + max: Some(10), + null_count: Some(0), + }, + PageIndex { + min: Some(15), + max: Some(20), + null_count: Some(0), + }, + ], + boundary_order: BoundaryOrder::Ascending, + }; + let locations = &[ + PageLocation { + offset: 100, + compressed_page_size: 10, + first_row_index: 0, + }, + PageLocation { + offset: 200, + compressed_page_size: 20, + first_row_index: 11, + }, + ]; + let total_rows = 20; + + //filter `x < 11` + let filter = + |page: &PageIndex| page.max.as_ref().map(|&x| x < 11).unwrap_or(false); + + let mask = index.indexes.iter().map(filter).collect::>(); + + let row_ranges = compute_row_ranges(&mask, locations, total_rows).unwrap(); + + assert_eq!(row_ranges.count(), 1); + assert_eq!(row_ranges.ranges.get(0).unwrap(), &Range::new(0, 10)); + } +} diff --git a/parquet/src/file/properties.rs b/parquet/src/file/properties.rs index c48e4e7a07b0..2baf93933cbc 100644 --- a/parquet/src/file/properties.rs +++ b/parquet/src/file/properties.rs @@ -62,7 +62,7 @@ const DEFAULT_DICTIONARY_ENABLED: bool = true; const DEFAULT_DICTIONARY_PAGE_SIZE_LIMIT: usize = DEFAULT_PAGE_SIZE; const DEFAULT_STATISTICS_ENABLED: bool = true; const DEFAULT_MAX_STATISTICS_SIZE: usize = 4096; -const DEFAULT_MAX_ROW_GROUP_SIZE: usize = 128 * 1024 * 1024; +const DEFAULT_MAX_ROW_GROUP_SIZE: usize = 1024 * 1024; const DEFAULT_CREATED_BY: &str = env!("PARQUET_CREATED_BY"); /// Parquet writer version. @@ -129,7 +129,7 @@ impl WriterProperties { self.write_batch_size } - /// Returns max size for a row group. + /// Returns maximum number of rows in a row group. pub fn max_row_group_size(&self) -> usize { self.max_row_group_size } @@ -145,8 +145,8 @@ impl WriterProperties { } /// Returns `key_value_metadata` KeyValue pairs. - pub fn key_value_metadata(&self) -> &Option> { - &self.key_value_metadata + pub fn key_value_metadata(&self) -> Option<&Vec> { + self.key_value_metadata.as_ref() } /// Returns encoding for a data page, when dictionary encoding is enabled. @@ -288,7 +288,7 @@ impl WriterPropertiesBuilder { self } - /// Sets max size for a row group. + /// Sets maximum number of rows in a row group. 
pub fn set_max_row_group_size(mut self, value: usize) -> Self { assert!(value > 0, "Cannot have a 0 max row group size"); self.max_row_group_size = value; @@ -523,7 +523,7 @@ mod tests { assert_eq!(props.max_row_group_size(), DEFAULT_MAX_ROW_GROUP_SIZE); assert_eq!(props.writer_version(), DEFAULT_WRITER_VERSION); assert_eq!(props.created_by(), DEFAULT_CREATED_BY); - assert_eq!(props.key_value_metadata(), &None); + assert_eq!(props.key_value_metadata(), None); assert_eq!(props.encoding(&ColumnPath::from("col")), None); assert_eq!( props.compression(&ColumnPath::from("col")), @@ -631,7 +631,9 @@ mod tests { assert_eq!(props.created_by(), "default"); assert_eq!( props.key_value_metadata(), - &Some(vec![KeyValue::new("key".to_string(), "value".to_string(),)]) + Some(&vec![ + KeyValue::new("key".to_string(), "value".to_string(),) + ]) ); assert_eq!( diff --git a/parquet/src/file/reader.rs b/parquet/src/file/reader.rs index aa8ba83a6c0e..d752273655c5 100644 --- a/parquet/src/file/reader.rs +++ b/parquet/src/file/reader.rs @@ -43,8 +43,8 @@ pub trait Length { /// The ChunkReader trait generates readers of chunks of a source. /// For a file system reader, each chunk might contain a clone of File bounded on a given range. /// For an object store reader, each read can be mapped to a range request. -pub trait ChunkReader: Length { - type T: Read; +pub trait ChunkReader: Length + Send + Sync { + type T: Read + Send; /// get a serialy readeable slice of the current reader /// This should fail if the slice exceeds the current bounds fn get_read(&self, start: u64, length: usize) -> Result; @@ -55,7 +55,7 @@ pub trait ChunkReader: Length { /// Parquet file reader API. With this, user can get metadata information about the /// Parquet file, can get reader for each row group, and access record iterator. -pub trait FileReader { +pub trait FileReader: Send + Sync { /// Get metadata information about this file. fn metadata(&self) -> &ParquetMetaData; @@ -76,7 +76,7 @@ pub trait FileReader { /// Parquet row group reader API. With this, user can get metadata information about the /// row group, as well as readers for each individual column chunk. -pub trait RowGroupReader { +pub trait RowGroupReader: Send + Sync { /// Get metadata information about this row group. fn metadata(&self) -> &RowGroupMetaData; @@ -139,7 +139,7 @@ pub trait RowGroupReader { /// Implementation of page iterator for parquet file. pub struct FilePageIterator { column_index: usize, - row_group_indices: Box>, + row_group_indices: Box + Send>, file_reader: Arc, } @@ -156,7 +156,7 @@ impl FilePageIterator { /// Create page iterator from parquet file reader with only some row groups. pub fn with_row_groups( column_index: usize, - row_group_indices: Box>, + row_group_indices: Box + Send>, file_reader: Arc, ) -> Result { // Check that column_index is valid diff --git a/parquet/src/file/serialized_reader.rs b/parquet/src/file/serialized_reader.rs index 9d6fb52491fb..6ff73e041e88 100644 --- a/parquet/src/file/serialized_reader.rs +++ b/parquet/src/file/serialized_reader.rs @@ -18,6 +18,7 @@ //! Contains implementations of the reader traits FileReader, RowGroupReader and PageReader //! 
Also contains implementations of the ChunkReader for files (with buffering) and byte arrays (RAM) +use bytes::{Buf, Bytes}; use std::{convert::TryFrom, fs::File, io::Read, path::Path, sync::Arc}; use parquet_format::{PageHeader, PageType}; @@ -27,6 +28,7 @@ use crate::basic::{Compression, Encoding, Type}; use crate::column::page::{Page, PageReader}; use crate::compression::{create_codec, Codec}; use crate::errors::{ParquetError, Result}; +use crate::file::page_index::index_reader; use crate::file::{footer, metadata::*, reader::*, statistics}; use crate::record::reader::RowIter; use crate::record::Row; @@ -35,6 +37,7 @@ use crate::util::{io::TryClone, memory::ByteBufferPtr}; // export `SliceableCursor` and `FileSource` publically so clients can // re-use the logic in their own ParquetFileWriter wrappers +#[allow(deprecated)] pub use crate::util::{cursor::SliceableCursor, io::FileSource}; // ---------------------------------------------------------------------- @@ -60,12 +63,35 @@ impl ChunkReader for File { } } +impl Length for Bytes { + fn len(&self) -> u64 { + self.len() as u64 + } +} + +impl TryClone for Bytes { + fn try_clone(&self) -> std::io::Result { + Ok(self.clone()) + } +} + +impl ChunkReader for Bytes { + type T = bytes::buf::Reader; + + fn get_read(&self, start: u64, length: usize) -> Result { + let start = start as usize; + Ok(self.slice(start..start + length).reader()) + } +} + +#[allow(deprecated)] impl Length for SliceableCursor { fn len(&self) -> u64 { SliceableCursor::len(self) } } +#[allow(deprecated)] impl ChunkReader for SliceableCursor { type T = SliceableCursor; @@ -127,6 +153,69 @@ pub struct SerializedFileReader { metadata: ParquetMetaData, } +/// A builder for [`ReadOptions`]. +/// For the predicates that are added to the builder, +/// they will be chained using 'AND' to filter the row groups. +pub struct ReadOptionsBuilder { + predicates: Vec bool>>, + enable_page_index: bool, +} + +impl ReadOptionsBuilder { + /// New builder + pub fn new() -> Self { + ReadOptionsBuilder { + predicates: vec![], + enable_page_index: false, + } + } + + /// Add a predicate on row group metadata to the reading option, + /// Filter only row groups that match the predicate criteria + pub fn with_predicate( + mut self, + predicate: Box bool>, + ) -> Self { + self.predicates.push(predicate); + self + } + + /// Add a range predicate on filtering row groups if their midpoints are within + /// the Closed-Open range `[start..end) {x | start <= x < end}` + pub fn with_range(mut self, start: i64, end: i64) -> Self { + assert!(start < end); + let predicate = move |rg: &RowGroupMetaData, _: usize| { + let mid = get_midpoint_offset(rg); + mid >= start && mid < end + }; + self.predicates.push(Box::new(predicate)); + self + } + + /// Enable page index in the reading option, + pub fn with_page_index(mut self) -> Self { + self.enable_page_index = true; + self + } + + /// Seal the builder and return the read options + pub fn build(self) -> ReadOptions { + ReadOptions { + predicates: self.predicates, + enable_page_index: self.enable_page_index, + } + } +} + +/// A collection of options for reading a Parquet file. +/// +/// Currently, only predicates on row group metadata are supported. +/// All predicates will be chained using 'AND' to filter the row groups. +pub struct ReadOptions { + predicates: Vec bool>>, + enable_page_index: bool, +} + impl SerializedFileReader { /// Creates file reader from a Parquet file. /// Returns error if Parquet file does not exist or is corrupt. 
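Together with new_with_options in the next hunk, the ReadOptionsBuilder above lets a caller pre-filter row groups and opt in to reading the page index. A rough usage sketch (not part of the diff; the file name and the 4 MiB cut-off are made up for illustration):

use parquet::file::metadata::RowGroupMetaData;
use parquet::file::reader::FileReader;
use parquet::file::serialized_reader::{ReadOptionsBuilder, SerializedFileReader};

fn open_with_options() -> Result<(), Box<dyn std::error::Error>> {
    let file = std::fs::File::open("data.parquet")?;
    let options = ReadOptionsBuilder::new()
        // keep only non-empty row groups...
        .with_predicate(Box::new(|rg: &RowGroupMetaData, _idx: usize| {
            rg.num_rows() > 0
        }))
        // ...whose midpoint offset falls within the first 4 MiB of the file
        .with_range(0, 4 * 1024 * 1024)
        // and also load the column and offset indexes into the metadata
        .with_page_index()
        .build();

    let reader = SerializedFileReader::new_with_options(file, options)?;
    println!("{} row groups kept", reader.metadata().num_row_groups());
    Ok(())
}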
@@ -138,25 +227,68 @@ impl SerializedFileReader { }) } - /// Filters row group metadata to only those row groups, - /// for which the predicate function returns true - pub fn filter_row_groups( - &mut self, - predicate: &dyn Fn(&RowGroupMetaData, usize) -> bool, - ) { + /// Creates file reader from a Parquet file with read options. + /// Returns error if Parquet file does not exist or is corrupt. + pub fn new_with_options(chunk_reader: R, options: ReadOptions) -> Result { + let metadata = footer::parse_metadata(&chunk_reader)?; + let mut predicates = options.predicates; + let row_groups = metadata.row_groups().to_vec(); let mut filtered_row_groups = Vec::::new(); - for (i, row_group_metadata) in self.metadata.row_groups().iter().enumerate() { - if predicate(row_group_metadata, i) { - filtered_row_groups.push(row_group_metadata.clone()); + for (i, rg_meta) in row_groups.into_iter().enumerate() { + let mut keep = true; + for predicate in &mut predicates { + if !predicate(&rg_meta, i) { + keep = false; + break; + } + } + if keep { + filtered_row_groups.push(rg_meta); } } - self.metadata = ParquetMetaData::new( - self.metadata.file_metadata().clone(), - filtered_row_groups, - ); + + if options.enable_page_index { + //Todo for now test data `data_index_bloom_encoding_stats.parquet` only have one rowgroup + //support multi after create multi-RG test data. + let cols = metadata.row_group(0); + let columns_indexes = + index_reader::read_columns_indexes(&chunk_reader, cols.columns())?; + let pages_locations = + index_reader::read_pages_locations(&chunk_reader, cols.columns())?; + + Ok(Self { + chunk_reader: Arc::new(chunk_reader), + metadata: ParquetMetaData::new_with_page_index( + metadata.file_metadata().clone(), + filtered_row_groups, + Some(columns_indexes), + Some(pages_locations), + ), + }) + } else { + Ok(Self { + chunk_reader: Arc::new(chunk_reader), + metadata: ParquetMetaData::new( + metadata.file_metadata().clone(), + filtered_row_groups, + ), + }) + } } } +/// Get midpoint offset for a row group +fn get_midpoint_offset(meta: &RowGroupMetaData) -> i64 { + let col = meta.column(0); + let mut offset = col.data_page_offset(); + if let Some(dic_offset) = col.dictionary_page_offset() { + if offset > dic_offset { + offset = dic_offset + } + }; + offset + meta.compressed_size() / 2 +} + impl FileReader for SerializedFileReader { fn metadata(&self) -> &ParquetMetaData { &self.metadata @@ -210,6 +342,7 @@ impl<'a, R: 'static + ChunkReader> RowGroupReader for SerializedRowGroupReader<' fn get_column_page_reader(&self, i: usize) -> Result> { let col = self.metadata.column(i); let (col_start, col_length) = col.byte_range(); + //Todo filter with multi row range let file_chunk = self.chunk_reader.get_read(col_start, col_length as usize)?; let page_reader = SerializedPageReader::new( file_chunk, @@ -225,6 +358,108 @@ impl<'a, R: 'static + ChunkReader> RowGroupReader for SerializedRowGroupReader<' } } +/// Reads a [`PageHeader`] from the provided [`Read`] +pub(crate) fn read_page_header(input: &mut T) -> Result { + let mut prot = TCompactInputProtocol::new(input); + let page_header = PageHeader::read_from_in_protocol(&mut prot)?; + Ok(page_header) +} + +/// Decodes a [`Page`] from the provided `buffer` +pub(crate) fn decode_page( + page_header: PageHeader, + buffer: ByteBufferPtr, + physical_type: Type, + decompressor: Option<&mut Box>, +) -> Result { + // When processing data page v2, depending on enabled compression for the + // page, we should account for uncompressed data ('offset') of + // 
repetition and definition levels. + // + // We always use 0 offset for other pages other than v2, `true` flag means + // that compression will be applied if decompressor is defined + let mut offset: usize = 0; + let mut can_decompress = true; + + if let Some(ref header_v2) = page_header.data_page_header_v2 { + offset = (header_v2.definition_levels_byte_length + + header_v2.repetition_levels_byte_length) as usize; + // When is_compressed flag is missing the page is considered compressed + can_decompress = header_v2.is_compressed.unwrap_or(true); + } + + // TODO: page header could be huge because of statistics. We should set a + // maximum page header size and abort if that is exceeded. + let buffer = match decompressor { + Some(decompressor) if can_decompress => { + let uncompressed_size = page_header.uncompressed_page_size as usize; + let mut decompressed = Vec::with_capacity(uncompressed_size); + let compressed = &buffer.as_ref()[offset..]; + decompressed.extend_from_slice(&buffer.as_ref()[..offset]); + decompressor.decompress(compressed, &mut decompressed)?; + + if decompressed.len() != uncompressed_size { + return Err(general_err!( + "Actual decompressed size doesn't match the expected one ({} vs {})", + decompressed.len(), + uncompressed_size + )); + } + + ByteBufferPtr::new(decompressed) + } + _ => buffer, + }; + + let result = match page_header.type_ { + PageType::DictionaryPage => { + assert!(page_header.dictionary_page_header.is_some()); + let dict_header = page_header.dictionary_page_header.as_ref().unwrap(); + let is_sorted = dict_header.is_sorted.unwrap_or(false); + Page::DictionaryPage { + buf: buffer, + num_values: dict_header.num_values as u32, + encoding: Encoding::from(dict_header.encoding), + is_sorted, + } + } + PageType::DataPage => { + assert!(page_header.data_page_header.is_some()); + let header = page_header.data_page_header.unwrap(); + Page::DataPage { + buf: buffer, + num_values: header.num_values as u32, + encoding: Encoding::from(header.encoding), + def_level_encoding: Encoding::from(header.definition_level_encoding), + rep_level_encoding: Encoding::from(header.repetition_level_encoding), + statistics: statistics::from_thrift(physical_type, header.statistics), + } + } + PageType::DataPageV2 => { + assert!(page_header.data_page_header_v2.is_some()); + let header = page_header.data_page_header_v2.unwrap(); + let is_compressed = header.is_compressed.unwrap_or(true); + Page::DataPageV2 { + buf: buffer, + num_values: header.num_values as u32, + encoding: Encoding::from(header.encoding), + num_nulls: header.num_nulls as u32, + num_rows: header.num_rows as u32, + def_levels_byte_len: header.definition_levels_byte_length as u32, + rep_levels_byte_len: header.repetition_levels_byte_length as u32, + is_compressed, + statistics: statistics::from_thrift(physical_type, header.statistics), + } + } + _ => { + // For unknown page type (e.g., INDEX_PAGE), skip and read next. + unimplemented!("Page type {:?} is not supported", page_header.type_) + } + }; + + Ok(result) +} + /// A serialized implementation for Parquet [`PageReader`]. pub struct SerializedPageReader { // The file source buffer which references exactly the bytes for the column trunk @@ -262,16 +497,9 @@ impl SerializedPageReader { }; Ok(result) } - - /// Reads Page header from Thrift. 
- fn read_page_header(&mut self) -> Result { - let mut prot = TCompactInputProtocol::new(&mut self.buf); - let page_header = PageHeader::read_from_in_protocol(&mut prot)?; - Ok(page_header) - } } -impl Iterator for SerializedPageReader { +impl Iterator for SerializedPageReader { type Item = Result; fn next(&mut self) -> Option { @@ -279,111 +507,43 @@ impl Iterator for SerializedPageReader { } } -impl PageReader for SerializedPageReader { +impl PageReader for SerializedPageReader { fn get_next_page(&mut self) -> Result> { while self.seen_num_values < self.total_num_values { - let page_header = self.read_page_header()?; - - // When processing data page v2, depending on enabled compression for the - // page, we should account for uncompressed data ('offset') of - // repetition and definition levels. - // - // We always use 0 offset for other pages other than v2, `true` flag means - // that compression will be applied if decompressor is defined - let mut offset: usize = 0; - let mut can_decompress = true; - - if let Some(ref header_v2) = page_header.data_page_header_v2 { - offset = (header_v2.definition_levels_byte_length - + header_v2.repetition_levels_byte_length) - as usize; - // When is_compressed flag is missing the page is considered compressed - can_decompress = header_v2.is_compressed.unwrap_or(true); - } - - let compressed_len = page_header.compressed_page_size as usize - offset; - let uncompressed_len = page_header.uncompressed_page_size as usize - offset; - // We still need to read all bytes from buffered stream - let mut buffer = vec![0; offset + compressed_len]; - self.buf.read_exact(&mut buffer)?; - - // TODO: page header could be huge because of statistics. We should set a - // maximum page header size and abort if that is exceeded. - if let Some(decompressor) = self.decompressor.as_mut() { - if can_decompress { - let mut decompressed_buffer = Vec::with_capacity(uncompressed_len); - let decompressed_size = decompressor - .decompress(&buffer[offset..], &mut decompressed_buffer)?; - if decompressed_size != uncompressed_len { - return Err(general_err!( - "Actual decompressed size doesn't match the expected one ({} vs {})", - decompressed_size, - uncompressed_len - )); - } - if offset == 0 { - buffer = decompressed_buffer; - } else { - // Prepend saved offsets to the buffer - buffer.truncate(offset); - buffer.append(&mut decompressed_buffer); - } - } + let page_header = read_page_header(&mut self.buf)?; + + let to_read = page_header.compressed_page_size as usize; + let mut buffer = Vec::with_capacity(to_read); + let read = (&mut self.buf) + .take(to_read as u64) + .read_to_end(&mut buffer)?; + + if read != to_read { + return Err(eof_err!( + "Expected to read {} bytes of page, read only {}", + to_read, + read + )); } + let buffer = ByteBufferPtr::new(buffer); let result = match page_header.type_ { - PageType::DictionaryPage => { - assert!(page_header.dictionary_page_header.is_some()); - let dict_header = - page_header.dictionary_page_header.as_ref().unwrap(); - let is_sorted = dict_header.is_sorted.unwrap_or(false); - Page::DictionaryPage { - buf: ByteBufferPtr::new(buffer), - num_values: dict_header.num_values as u32, - encoding: Encoding::from(dict_header.encoding), - is_sorted, - } - } - PageType::DataPage => { - assert!(page_header.data_page_header.is_some()); - let header = page_header.data_page_header.unwrap(); - self.seen_num_values += header.num_values as i64; - Page::DataPage { - buf: ByteBufferPtr::new(buffer), - num_values: header.num_values as u32, - encoding: 
Encoding::from(header.encoding), - def_level_encoding: Encoding::from( - header.definition_level_encoding, - ), - rep_level_encoding: Encoding::from( - header.repetition_level_encoding, - ), - statistics: statistics::from_thrift( - self.physical_type, - header.statistics, - ), - } - } - PageType::DataPageV2 => { - assert!(page_header.data_page_header_v2.is_some()); - let header = page_header.data_page_header_v2.unwrap(); - let is_compressed = header.is_compressed.unwrap_or(true); - self.seen_num_values += header.num_values as i64; - Page::DataPageV2 { - buf: ByteBufferPtr::new(buffer), - num_values: header.num_values as u32, - encoding: Encoding::from(header.encoding), - num_nulls: header.num_nulls as u32, - num_rows: header.num_rows as u32, - def_levels_byte_len: header.definition_levels_byte_length as u32, - rep_levels_byte_len: header.repetition_levels_byte_length as u32, - is_compressed, - statistics: statistics::from_thrift( - self.physical_type, - header.statistics, - ), - } + PageType::DataPage | PageType::DataPageV2 => { + let decoded = decode_page( + page_header, + buffer, + self.physical_type, + self.decompressor.as_mut(), + )?; + self.seen_num_values += decoded.num_values() as i64; + decoded } + PageType::DictionaryPage => decode_page( + page_header, + buffer, + self.physical_type, + self.decompressor.as_mut(), + )?, _ => { // For unknown page type (e.g., INDEX_PAGE), skip and read next. continue; @@ -400,10 +560,12 @@ impl PageReader for SerializedPageReader { #[cfg(test)] mod tests { use super::*; - use crate::basic::ColumnOrder; + use crate::basic::{self, ColumnOrder}; + use crate::file::page_index::index::Index; use crate::record::RowAccessor; use crate::schema::parser::parse_message_type; use crate::util::test_common::{get_test_file, get_test_path}; + use parquet_format::BoundaryOrder; use std::sync::Arc; #[test] @@ -412,7 +574,7 @@ mod tests { get_test_file("alltypes_plain.parquet") .read_to_end(&mut buf) .unwrap(); - let cursor = SliceableCursor::new(buf); + let cursor = Bytes::from(buf); let read_from_cursor = SerializedFileReader::new(cursor).unwrap(); let test_file = get_test_file("alltypes_plain.parquet"); @@ -531,9 +693,9 @@ mod tests { let file_metadata = metadata.file_metadata(); assert!(file_metadata.created_by().is_some()); assert_eq!( - file_metadata.created_by().as_ref().unwrap(), - "impala version 1.3.0-INTERNAL (build 8a48ddb1eff84592b3fc06bc6f51ec120e1fffc9)" - ); + file_metadata.created_by().unwrap(), + "impala version 1.3.0-INTERNAL (build 8a48ddb1eff84592b3fc06bc6f51ec120e1fffc9)" + ); assert!(file_metadata.key_value_metadata().is_none()); assert_eq!(file_metadata.num_rows(), 8); assert_eq!(file_metadata.version(), 1); @@ -621,7 +783,7 @@ mod tests { let file_metadata = metadata.file_metadata(); assert!(file_metadata.created_by().is_some()); assert_eq!( - file_metadata.created_by().as_ref().unwrap(), + file_metadata.created_by().unwrap(), "parquet-mr version 1.8.1 (build 4aba4dae7bb0d4edbcf7923ae1339f28fd3f7fcf)" ); assert!(file_metadata.key_value_metadata().is_some()); @@ -743,7 +905,6 @@ mod tests { .metadata .file_metadata() .key_value_metadata() - .as_ref() .unwrap(); assert_eq!(metadata.len(), 3); @@ -761,19 +922,187 @@ mod tests { } #[test] - fn test_file_reader_filter_row_groups() -> Result<()> { + fn test_file_reader_optional_metadata() { + // file with optional metadata: bloom filters, encoding stats, column index and offset index. 
+ let file = get_test_file("data_index_bloom_encoding_stats.parquet"); + let file_reader = Arc::new(SerializedFileReader::new(file).unwrap()); + + let row_group_metadata = file_reader.metadata.row_group(0); + let col0_metadata = row_group_metadata.column(0); + + // test optional bloom filter offset + assert_eq!(col0_metadata.bloom_filter_offset().unwrap(), 192); + + // test page encoding stats + let page_encoding_stats = + col0_metadata.page_encoding_stats().unwrap().get(0).unwrap(); + + assert_eq!(page_encoding_stats.page_type, basic::PageType::DATA_PAGE); + assert_eq!(page_encoding_stats.encoding, Encoding::PLAIN); + assert_eq!(page_encoding_stats.count, 1); + + // test optional column index offset + assert_eq!(col0_metadata.column_index_offset().unwrap(), 156); + assert_eq!(col0_metadata.column_index_length().unwrap(), 25); + + // test optional offset index offset + assert_eq!(col0_metadata.offset_index_offset().unwrap(), 181); + assert_eq!(col0_metadata.offset_index_length().unwrap(), 11); + } + + #[test] + fn test_file_reader_with_no_filter() -> Result<()> { let test_file = get_test_file("alltypes_plain.parquet"); - let mut reader = SerializedFileReader::new(test_file)?; + let origin_reader = SerializedFileReader::new(test_file)?; + // test initial number of row groups + let metadata = origin_reader.metadata(); + assert_eq!(metadata.num_row_groups(), 1); + Ok(()) + } + #[test] + fn test_file_reader_filter_row_groups_with_predicate() -> Result<()> { + let test_file = get_test_file("alltypes_plain.parquet"); + let read_options = ReadOptionsBuilder::new() + .with_predicate(Box::new(|_, _| false)) + .build(); + let reader = SerializedFileReader::new_with_options(test_file, read_options)?; + let metadata = reader.metadata(); + assert_eq!(metadata.num_row_groups(), 0); + Ok(()) + } + + #[test] + fn test_file_reader_filter_row_groups_with_range() -> Result<()> { + let test_file = get_test_file("alltypes_plain.parquet"); + let origin_reader = SerializedFileReader::new(test_file)?; // test initial number of row groups + let metadata = origin_reader.metadata(); + assert_eq!(metadata.num_row_groups(), 1); + let mid = get_midpoint_offset(metadata.row_group(0)); + + let test_file = get_test_file("alltypes_plain.parquet"); + let read_options = ReadOptionsBuilder::new().with_range(0, mid + 1).build(); + let reader = SerializedFileReader::new_with_options(test_file, read_options)?; + let metadata = reader.metadata(); + assert_eq!(metadata.num_row_groups(), 1); + + let test_file = get_test_file("alltypes_plain.parquet"); + let read_options = ReadOptionsBuilder::new().with_range(0, mid).build(); + let reader = SerializedFileReader::new_with_options(test_file, read_options)?; + let metadata = reader.metadata(); + assert_eq!(metadata.num_row_groups(), 0); + Ok(()) + } + + #[test] + fn test_file_reader_filter_row_groups_and_range() -> Result<()> { + let test_file = get_test_file("alltypes_plain.parquet"); + let origin_reader = SerializedFileReader::new(test_file)?; + let metadata = origin_reader.metadata(); + let mid = get_midpoint_offset(metadata.row_group(0)); + + // true, true predicate + let test_file = get_test_file("alltypes_plain.parquet"); + let read_options = ReadOptionsBuilder::new() + .with_predicate(Box::new(|_, _| true)) + .with_range(mid, mid + 1) + .build(); + let reader = SerializedFileReader::new_with_options(test_file, read_options)?; let metadata = reader.metadata(); assert_eq!(metadata.num_row_groups(), 1); - // test filtering out all row groups - reader.filter_row_groups(&|_, _| 
false); + // true, false predicate + let test_file = get_test_file("alltypes_plain.parquet"); + let read_options = ReadOptionsBuilder::new() + .with_predicate(Box::new(|_, _| true)) + .with_range(0, mid) + .build(); + let reader = SerializedFileReader::new_with_options(test_file, read_options)?; + let metadata = reader.metadata(); + assert_eq!(metadata.num_row_groups(), 0); + + // false, true predicate + let test_file = get_test_file("alltypes_plain.parquet"); + let read_options = ReadOptionsBuilder::new() + .with_predicate(Box::new(|_, _| false)) + .with_range(mid, mid + 1) + .build(); + let reader = SerializedFileReader::new_with_options(test_file, read_options)?; let metadata = reader.metadata(); assert_eq!(metadata.num_row_groups(), 0); + // false, false predicate + let test_file = get_test_file("alltypes_plain.parquet"); + let read_options = ReadOptionsBuilder::new() + .with_predicate(Box::new(|_, _| false)) + .with_range(0, mid) + .build(); + let reader = SerializedFileReader::new_with_options(test_file, read_options)?; + let metadata = reader.metadata(); + assert_eq!(metadata.num_row_groups(), 0); Ok(()) } + + #[test] + // Use java parquet-tools get below pageIndex info + // !``` + // parquet-tools column-index ./data_index_bloom_encoding_stats.parquet + // row group 0: + // column index for column String: + // Boudary order: ASCENDING + // page-0 : + // null count min max + // 0 Hello today + // + // offset index for column String: + // page-0 : + // offset compressed size first row index + // 4 152 0 + ///``` + // + fn test_page_index_reader() { + let test_file = get_test_file("data_index_bloom_encoding_stats.parquet"); + let builder = ReadOptionsBuilder::new(); + //enable read page index + let options = builder.with_page_index().build(); + let reader_result = SerializedFileReader::new_with_options(test_file, options); + let reader = reader_result.unwrap(); + + // Test contents in Parquet metadata + let metadata = reader.metadata(); + assert_eq!(metadata.num_row_groups(), 1); + + let page_indexes = metadata.page_indexes().unwrap(); + + // only one row group + assert_eq!(page_indexes.len(), 1); + let index = if let Index::BYTE_ARRAY(index) = page_indexes.get(0).unwrap() { + index + } else { + unreachable!() + }; + + assert_eq!(index.boundary_order, BoundaryOrder::Ascending); + let index_in_pages = &index.indexes; + + //only one page group + assert_eq!(index_in_pages.len(), 1); + + let page0 = index_in_pages.get(0).unwrap(); + let min = page0.min.as_ref().unwrap(); + let max = page0.max.as_ref().unwrap(); + assert_eq!("Hello", std::str::from_utf8(min.as_slice()).unwrap()); + assert_eq!("today", std::str::from_utf8(max.as_slice()).unwrap()); + + let offset_indexes = metadata.offset_indexes().unwrap(); + // only one row group + assert_eq!(offset_indexes.len(), 1); + let offset_index = offset_indexes.get(0).unwrap(); + let page_offset = offset_index.get(0).unwrap(); + + assert_eq!(4, page_offset.offset); + assert_eq!(152, page_offset.compressed_page_size); + assert_eq!(0, page_offset.first_row_index); + } } diff --git a/parquet/src/file/writer.rs b/parquet/src/file/writer.rs index e1c2dc6b616f..0a8fc331e7e1 100644 --- a/parquet/src/file/writer.rs +++ b/parquet/src/file/writer.rs @@ -18,189 +18,196 @@ //! Contains file writer API, and provides methods to write row groups and columns by //! using row group writers and column writers respectively. 
-use std::{ - io::{Seek, SeekFrom, Write}, - sync::Arc, -}; +use std::{io::Write, sync::Arc}; use byteorder::{ByteOrder, LittleEndian}; use parquet_format as parquet; use thrift::protocol::{TCompactOutputProtocol, TOutputProtocol}; use crate::basic::PageType; +use crate::column::writer::{get_typed_column_writer_mut, ColumnWriterImpl}; use crate::column::{ page::{CompressedPage, Page, PageWriteSpec, PageWriter}, writer::{get_column_writer, ColumnWriter}, }; +use crate::data_type::DataType; use crate::errors::{ParquetError, Result}; use crate::file::{ metadata::*, properties::WriterPropertiesPtr, statistics::to_thrift as statistics_to_thrift, FOOTER_SIZE, PARQUET_MAGIC, }; use crate::schema::types::{self, SchemaDescPtr, SchemaDescriptor, TypePtr}; -use crate::util::io::{FileSink, Position}; - -// Exposed publically so client code can implement [`ParquetWriter`] -pub use crate::util::io::TryClone; +use crate::util::io::TryClone; -// Exposed publically for convenience of writing Parquet to a buffer of bytes -pub use crate::util::cursor::InMemoryWriteableCursor; +/// A wrapper around a [`Write`] that keeps track of the number +/// of bytes that have been written +pub struct TrackedWrite { + inner: W, + bytes_written: usize, +} -// ---------------------------------------------------------------------- -// APIs for file & row group writers +impl TrackedWrite { + /// Create a new [`TrackedWrite`] from a [`Write`] + pub fn new(inner: W) -> Self { + Self { + inner, + bytes_written: 0, + } + } -/// Parquet file writer API. -/// Provides methods to write row groups sequentially. -/// -/// The main workflow should be as following: -/// - Create file writer, this will open a new file and potentially write some metadata. -/// - Request a new row group writer by calling `next_row_group`. -/// - Once finished writing row group, close row group writer by passing it into -/// `close_row_group` method - this will finalise row group metadata and update metrics. -/// - Write subsequent row groups, if necessary. -/// - After all row groups have been written, close the file writer using `close` method. -pub trait FileWriter { - /// Creates new row group from this file writer. - /// In case of IO error or Thrift error, returns `Err`. - /// - /// There is no limit on a number of row groups in a file; however, row groups have - /// to be written sequentially. Every time the next row group is requested, the - /// previous row group must be finalised and closed using `close_row_group` method. - fn next_row_group(&mut self) -> Result>; + /// Returns the number of bytes written to this instance + pub fn bytes_written(&self) -> usize { + self.bytes_written + } +} - /// Finalises and closes row group that was created using `next_row_group` method. - /// After calling this method, the next row group is available for writes. - fn close_row_group( - &mut self, - row_group_writer: Box, - ) -> Result<()>; +impl Write for TrackedWrite { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + let bytes = self.inner.write(buf)?; + self.bytes_written += bytes; + Ok(bytes) + } - /// Closes and finalises file writer, returning the file metadata. - /// - /// All row groups must be appended before this method is called. - /// No writes are allowed after this point. - /// - /// Can be called multiple times. It is up to implementation to either result in - /// no-op, or return an `Err` for subsequent calls. 
- fn close(&mut self) -> Result; + fn flush(&mut self) -> std::io::Result<()> { + self.inner.flush() + } } -/// Parquet row group writer API. -/// Provides methods to access column writers in an iterator-like fashion, order is -/// guaranteed to match the order of schema leaves (column descriptors). +/// Callback invoked on closing a column chunk, arguments are: /// -/// All columns should be written sequentially; the main workflow is: -/// - Request the next column using `next_column` method - this will return `None` if no -/// more columns are available to write. -/// - Once done writing a column, close column writer with `close_column` method - this -/// will finalise column chunk metadata and update row group metrics. -/// - Once all columns have been written, close row group writer with `close` method - -/// it will return row group metadata and is no-op on already closed row group. -pub trait RowGroupWriter { - /// Returns the next column writer, if available; otherwise returns `None`. - /// In case of any IO error or Thrift error, or if row group writer has already been - /// closed returns `Err`. - /// - /// To request the next column writer, the previous one must be finalised and closed - /// using `close_column`. - fn next_column(&mut self) -> Result>; +/// - the number of bytes written +/// - the number of rows written +/// - the column chunk metadata +/// +pub type OnCloseColumnChunk<'a> = + Box Result<()> + 'a>; - /// Closes column writer that was created using `next_column` method. - /// This should be called before requesting the next column writer. - fn close_column(&mut self, column_writer: ColumnWriter) -> Result<()>; +/// Callback invoked on closing a row group, arguments are: +/// +/// - the row group metadata +pub type OnCloseRowGroup<'a> = Box Result<()> + 'a>; - /// Closes this row group writer and returns row group metadata. - /// After calling this method row group writer must not be used. - /// - /// It is recommended to call this method before requesting another row group, but it - /// will be closed automatically before returning a new row group. - /// - /// Can be called multiple times. In subsequent calls will result in no-op and return - /// already created row group metadata. - fn close(&mut self) -> Result; -} +#[deprecated = "use std::io::Write"] +pub trait ParquetWriter: Write + std::io::Seek + TryClone {} +#[allow(deprecated)] +impl ParquetWriter for T {} // ---------------------------------------------------------------------- // Serialized impl for file & row group writers -pub trait ParquetWriter: Write + Seek + TryClone {} -impl ParquetWriter for T {} - -/// A serialized implementation for Parquet [`FileWriter`]. -/// See documentation on file writer for more information. -pub struct SerializedFileWriter { - buf: W, +/// Parquet file writer API. +/// Provides methods to write row groups sequentially. +/// +/// The main workflow should be as following: +/// - Create file writer, this will open a new file and potentially write some metadata. +/// - Request a new row group writer by calling `next_row_group`. +/// - Once finished writing row group, close row group writer by calling `close` +/// - Write subsequent row groups, if necessary. +/// - After all row groups have been written, close the file writer using `close` method. 
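A minimal sketch of the workflow listed above, assuming external crate usage; the module paths, the single-column INT32 schema string, and the in-memory buffer are illustrative choices rather than part of this change:

    use std::sync::Arc;
    use parquet::data_type::Int32Type;
    use parquet::file::properties::WriterProperties;
    use parquet::file::writer::SerializedFileWriter;
    use parquet::schema::parser::parse_message_type;

    fn write_one_row_group() -> Result<(), Box<dyn std::error::Error>> {
        let schema = Arc::new(parse_message_type(
            "message schema { REQUIRED INT32 id; }",
        )?);
        let props = Arc::new(WriterProperties::builder().build());

        // Any `Write` works now that seeking is no longer required; an in-memory
        // buffer keeps the sketch self-contained.
        let mut buffer = Vec::new();
        let mut writer = SerializedFileWriter::new(&mut buffer, schema, props)?;

        let mut row_group = writer.next_row_group()?;
        let mut column = row_group.next_column()?.expect("schema has one column");
        column.typed::<Int32Type>().write_batch(&[1, 2, 3], None, None)?;
        column.close()?;
        row_group.close()?;

        // Writes the footer and returns the Thrift file metadata.
        writer.close()?;
        println!("wrote {} bytes", buffer.len());
        Ok(())
    }

Note the design change relative to the removed traits: row group and column writers are now consumed by their own `close` methods, so the separate `close_row_group` / `close_column` calls on the parent writer disappear.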
+pub struct SerializedFileWriter { + buf: TrackedWrite, schema: TypePtr, descr: SchemaDescPtr, props: WriterPropertiesPtr, - total_num_rows: i64, row_groups: Vec, - previous_writer_closed: bool, - is_closed: bool, + row_group_index: usize, } -impl SerializedFileWriter { +impl SerializedFileWriter { /// Creates new file writer. - pub fn new( - mut buf: W, - schema: TypePtr, - properties: WriterPropertiesPtr, - ) -> Result { + pub fn new(buf: W, schema: TypePtr, properties: WriterPropertiesPtr) -> Result { + let mut buf = TrackedWrite::new(buf); Self::start_file(&mut buf)?; Ok(Self { buf, schema: schema.clone(), descr: Arc::new(SchemaDescriptor::new(schema)), props: properties, - total_num_rows: 0, - row_groups: Vec::new(), - previous_writer_closed: true, - is_closed: false, + row_groups: vec![], + row_group_index: 0, }) } - /// Writes magic bytes at the beginning of the file. - fn start_file(buf: &mut W) -> Result<()> { - buf.write_all(&PARQUET_MAGIC)?; - Ok(()) + /// Creates new row group from this file writer. + /// In case of IO error or Thrift error, returns `Err`. + /// + /// There is no limit on a number of row groups in a file; however, row groups have + /// to be written sequentially. Every time the next row group is requested, the + /// previous row group must be finalised and closed using `RowGroupWriter::close` method. + pub fn next_row_group(&mut self) -> Result> { + self.assert_previous_writer_closed()?; + self.row_group_index += 1; + + let row_groups = &mut self.row_groups; + let on_close = |metadata| { + row_groups.push(metadata); + Ok(()) + }; + + let row_group_writer = SerializedRowGroupWriter::new( + self.descr.clone(), + self.props.clone(), + &mut self.buf, + Some(Box::new(on_close)), + ); + Ok(row_group_writer) + } + + /// Returns metadata for any flushed row groups + pub fn flushed_row_groups(&self) -> &[RowGroupMetaDataPtr] { + &self.row_groups + } + + /// Closes and finalises file writer, returning the file metadata. + /// + /// All row groups must be appended before this method is called. + /// No writes are allowed after this point. + /// + /// Can be called multiple times. It is up to implementation to either result in + /// no-op, or return an `Err` for subsequent calls. + pub fn close(mut self) -> Result { + self.assert_previous_writer_closed()?; + let metadata = self.write_metadata()?; + Ok(metadata) } - /// Finalises active row group writer, otherwise no-op. - fn finalise_row_group_writer( - &mut self, - mut row_group_writer: Box, - ) -> Result<()> { - let row_group_metadata = row_group_writer.close()?; - self.total_num_rows += row_group_metadata.num_rows(); - self.row_groups.push(row_group_metadata); + /// Writes magic bytes at the beginning of the file. + fn start_file(buf: &mut TrackedWrite) -> Result<()> { + buf.write_all(&PARQUET_MAGIC)?; Ok(()) } /// Assembles and writes metadata at the end of the file. 
fn write_metadata(&mut self) -> Result { + let num_rows = self.row_groups.iter().map(|x| x.num_rows()).sum(); + + let row_groups = self + .row_groups + .as_slice() + .iter() + .map(|v| v.to_thrift()) + .collect(); + let file_metadata = parquet::FileMetaData { + num_rows, + row_groups, version: self.props.writer_version().as_num(), schema: types::to_thrift(self.schema.as_ref())?, - num_rows: self.total_num_rows as i64, - row_groups: self - .row_groups - .as_slice() - .iter() - .map(|v| v.to_thrift()) - .collect(), - key_value_metadata: self.props.key_value_metadata().to_owned(), + key_value_metadata: self.props.key_value_metadata().cloned(), created_by: Some(self.props.created_by().to_owned()), column_orders: None, + encryption_algorithm: None, + footer_signing_key_metadata: None, }; // Write file metadata - let start_pos = self.buf.seek(SeekFrom::Current(0))?; + let start_pos = self.buf.bytes_written(); { let mut protocol = TCompactOutputProtocol::new(&mut self.buf); file_metadata.write_to_out_protocol(&mut protocol)?; protocol.flush()?; } - let end_pos = self.buf.seek(SeekFrom::Current(0))?; + let end_pos = self.buf.bytes_written(); // Write footer let mut footer_buffer: [u8; FOOTER_SIZE] = [0; FOOTER_SIZE]; @@ -211,18 +218,9 @@ impl SerializedFileWriter { Ok(file_metadata) } - #[inline] - fn assert_closed(&self) -> Result<()> { - if self.is_closed { - Err(general_err!("File writer is closed")) - } else { - Ok(()) - } - } - #[inline] fn assert_previous_writer_closed(&self) -> Result<()> { - if !self.previous_writer_closed { + if self.row_group_index != self.row_groups.len() { Err(general_err!("Previous row group writer was not closed")) } else { Ok(()) @@ -230,157 +228,107 @@ impl SerializedFileWriter { } } -impl FileWriter for SerializedFileWriter { - #[inline] - fn next_row_group(&mut self) -> Result> { - self.assert_closed()?; - self.assert_previous_writer_closed()?; - let row_group_writer = SerializedRowGroupWriter::new( - self.descr.clone(), - self.props.clone(), - &self.buf, - ); - self.previous_writer_closed = false; - Ok(Box::new(row_group_writer)) - } - - #[inline] - fn close_row_group( - &mut self, - row_group_writer: Box, - ) -> Result<()> { - self.assert_closed()?; - let res = self.finalise_row_group_writer(row_group_writer); - self.previous_writer_closed = res.is_ok(); - res - } - - #[inline] - fn close(&mut self) -> Result { - self.assert_closed()?; - self.assert_previous_writer_closed()?; - let metadata = self.write_metadata()?; - self.is_closed = true; - Ok(metadata) - } -} - -/// A serialized implementation for Parquet [`RowGroupWriter`]. -/// Coordinates writing of a row group with column writers. -/// See documentation on row group writer for more information. -pub struct SerializedRowGroupWriter { +/// Parquet row group writer API. +/// Provides methods to access column writers in an iterator-like fashion, order is +/// guaranteed to match the order of schema leaves (column descriptors). +/// +/// All columns should be written sequentially; the main workflow is: +/// - Request the next column using `next_column` method - this will return `None` if no +/// more columns are available to write. +/// - Once done writing a column, close column writer with `close` +/// - Once all columns have been written, close row group writer with `close` method - +/// it will return row group metadata and is no-op on already closed row group. 
+pub struct SerializedRowGroupWriter<'a, W: Write> { descr: SchemaDescPtr, props: WriterPropertiesPtr, - buf: W, + buf: &'a mut TrackedWrite, total_rows_written: Option, total_bytes_written: u64, column_index: usize, - previous_writer_closed: bool, row_group_metadata: Option, column_chunks: Vec, + on_close: Option>, } -impl SerializedRowGroupWriter { +impl<'a, W: Write> SerializedRowGroupWriter<'a, W> { + /// Creates a new `SerializedRowGroupWriter` with: + /// + /// - `schema_descr` - the schema to write + /// - `properties` - writer properties + /// - `buf` - the buffer to write data to + /// - `on_close` - an optional callback that will invoked on [`Self::close`] pub fn new( schema_descr: SchemaDescPtr, properties: WriterPropertiesPtr, - buf: &W, + buf: &'a mut TrackedWrite, + on_close: Option>, ) -> Self { let num_columns = schema_descr.num_columns(); Self { + buf, + on_close, + total_rows_written: None, descr: schema_descr, props: properties, - buf: buf.try_clone().unwrap(), - total_rows_written: None, - total_bytes_written: 0, column_index: 0, - previous_writer_closed: true, row_group_metadata: None, column_chunks: Vec::with_capacity(num_columns), + total_bytes_written: 0, } } - /// Checks and finalises current column writer. - fn finalise_column_writer(&mut self, writer: ColumnWriter) -> Result<()> { - let (bytes_written, rows_written, metadata) = match writer { - ColumnWriter::BoolColumnWriter(typed) => typed.close()?, - ColumnWriter::Int32ColumnWriter(typed) => typed.close()?, - ColumnWriter::Int64ColumnWriter(typed) => typed.close()?, - ColumnWriter::Int96ColumnWriter(typed) => typed.close()?, - ColumnWriter::FloatColumnWriter(typed) => typed.close()?, - ColumnWriter::DoubleColumnWriter(typed) => typed.close()?, - ColumnWriter::ByteArrayColumnWriter(typed) => typed.close()?, - ColumnWriter::FixedLenByteArrayColumnWriter(typed) => typed.close()?, - }; - - // Update row group writer metrics - self.total_bytes_written += bytes_written; - self.column_chunks.push(metadata); - if let Some(rows) = self.total_rows_written { - if rows != rows_written { - return Err(general_err!( - "Incorrect number of rows, expected {} != {} rows", - rows, - rows_written - )); - } - } else { - self.total_rows_written = Some(rows_written); - } - - Ok(()) - } - - #[inline] - fn assert_closed(&self) -> Result<()> { - if self.row_group_metadata.is_some() { - Err(general_err!("Row group writer is closed")) - } else { - Ok(()) - } - } - - #[inline] - fn assert_previous_writer_closed(&self) -> Result<()> { - if !self.previous_writer_closed { - Err(general_err!("Previous column writer was not closed")) - } else { - Ok(()) - } - } -} - -impl RowGroupWriter for SerializedRowGroupWriter { - #[inline] - fn next_column(&mut self) -> Result> { - self.assert_closed()?; + /// Returns the next column writer, if available; otherwise returns `None`. + /// In case of any IO error or Thrift error, or if row group writer has already been + /// closed returns `Err`. 
+ pub fn next_column(&mut self) -> Result>> { self.assert_previous_writer_closed()?; if self.column_index >= self.descr.num_columns() { return Ok(None); } - let sink = FileSink::new(&self.buf); - let page_writer = Box::new(SerializedPageWriter::new(sink)); + let page_writer = Box::new(SerializedPageWriter::new(self.buf)); let column_writer = get_column_writer( self.descr.column(self.column_index), self.props.clone(), page_writer, ); self.column_index += 1; - self.previous_writer_closed = false; - Ok(Some(column_writer)) - } + let total_bytes_written = &mut self.total_bytes_written; + let total_rows_written = &mut self.total_rows_written; + let column_chunks = &mut self.column_chunks; + + let on_close = |bytes_written, rows_written, metadata| { + // Update row group writer metrics + *total_bytes_written += bytes_written; + column_chunks.push(metadata); + if let Some(rows) = *total_rows_written { + if rows != rows_written { + return Err(general_err!( + "Incorrect number of rows, expected {} != {} rows", + rows, + rows_written + )); + } + } else { + *total_rows_written = Some(rows_written); + } - #[inline] - fn close_column(&mut self, column_writer: ColumnWriter) -> Result<()> { - let res = self.finalise_column_writer(column_writer); - self.previous_writer_closed = res.is_ok(); - res + Ok(()) + }; + + Ok(Some(SerializedColumnWriter::new( + column_writer, + Some(Box::new(on_close)), + ))) } - #[inline] - fn close(&mut self) -> Result { + /// Closes this row group writer and returns row group metadata. + /// After calling this method row group writer must not be used. + /// + /// Can be called multiple times. In subsequent calls will result in no-op and return + /// already created row group metadata. + pub fn close(mut self) -> Result { if self.row_group_metadata.is_none() { self.assert_previous_writer_closed()?; @@ -391,25 +339,86 @@ impl RowGroupWriter for SerializedRowGroupWriter .set_num_rows(self.total_rows_written.unwrap_or(0) as i64) .build()?; - self.row_group_metadata = Some(Arc::new(row_group_metadata)); + let metadata = Arc::new(row_group_metadata); + self.row_group_metadata = Some(metadata.clone()); + + if let Some(on_close) = self.on_close.take() { + on_close(metadata)? 
+ } } let metadata = self.row_group_metadata.as_ref().unwrap().clone(); Ok(metadata) } + + #[inline] + fn assert_previous_writer_closed(&self) -> Result<()> { + if self.column_index != self.column_chunks.len() { + Err(general_err!("Previous column writer was not closed")) + } else { + Ok(()) + } + } +} + +/// A wrapper around a [`ColumnWriter`] that invokes a callback on [`Self::close`] +pub struct SerializedColumnWriter<'a> { + inner: ColumnWriter<'a>, + on_close: Option>, +} + +impl<'a> SerializedColumnWriter<'a> { + /// Create a new [`SerializedColumnWriter`] from a `[`ColumnWriter`] and an + /// optional callback to be invoked on [`Self::close`] + pub fn new( + inner: ColumnWriter<'a>, + on_close: Option>, + ) -> Self { + Self { inner, on_close } + } + + /// Returns a reference to an untyped [`ColumnWriter`] + pub fn untyped(&mut self) -> &mut ColumnWriter<'a> { + &mut self.inner + } + + /// Returns a reference to a typed [`ColumnWriterImpl`] + pub fn typed(&mut self) -> &mut ColumnWriterImpl<'a, T> { + get_typed_column_writer_mut(&mut self.inner) + } + + /// Close this [`SerializedColumnWriter] + pub fn close(mut self) -> Result<()> { + let (bytes_written, rows_written, metadata) = match self.inner { + ColumnWriter::BoolColumnWriter(typed) => typed.close()?, + ColumnWriter::Int32ColumnWriter(typed) => typed.close()?, + ColumnWriter::Int64ColumnWriter(typed) => typed.close()?, + ColumnWriter::Int96ColumnWriter(typed) => typed.close()?, + ColumnWriter::FloatColumnWriter(typed) => typed.close()?, + ColumnWriter::DoubleColumnWriter(typed) => typed.close()?, + ColumnWriter::ByteArrayColumnWriter(typed) => typed.close()?, + ColumnWriter::FixedLenByteArrayColumnWriter(typed) => typed.close()?, + }; + + if let Some(on_close) = self.on_close.take() { + on_close(bytes_written, rows_written, metadata)? + } + + Ok(()) + } } /// A serialized implementation for Parquet [`PageWriter`]. /// Writes and serializes pages and metadata into output stream. /// /// `SerializedPageWriter` should not be used after calling `close()`. -pub struct SerializedPageWriter { - sink: T, +pub struct SerializedPageWriter<'a, W> { + sink: &'a mut TrackedWrite, } -impl SerializedPageWriter { +impl<'a, W: Write> SerializedPageWriter<'a, W> { /// Creates new page writer. - pub fn new(sink: T) -> Self { + pub fn new(sink: &'a mut TrackedWrite) -> Self { Self { sink } } @@ -417,13 +426,13 @@ impl SerializedPageWriter { /// Returns number of bytes that have been written into the sink. #[inline] fn serialize_page_header(&mut self, header: parquet::PageHeader) -> Result { - let start_pos = self.sink.pos(); + let start_pos = self.sink.bytes_written(); { let mut protocol = TCompactOutputProtocol::new(&mut self.sink); header.write_to_out_protocol(&mut protocol)?; protocol.flush()?; } - Ok((self.sink.pos() - start_pos) as usize) + Ok(self.sink.bytes_written() - start_pos) } /// Serializes column chunk into Thrift. 
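The next hunk replaces the old `Seek`/`pos()` bookkeeping in `SerializedPageWriter` with offsets derived from `TrackedWrite::bytes_written()`. A minimal sketch of the behaviour it relies on, using a plain `Vec<u8>` sink purely for illustration:

    use std::io::Write;
    use parquet::file::writer::TrackedWrite;

    fn track() -> std::io::Result<()> {
        let mut sink = TrackedWrite::new(Vec::<u8>::new());

        // Offsets are recorded as the byte count before a write...
        let start = sink.bytes_written();
        sink.write_all(b"PAR1")?;
        // ...and sizes as the difference afterwards, with no `Seek` bound needed.
        assert_eq!(sink.bytes_written() - start, 4);
        Ok(())
    }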
@@ -437,7 +446,7 @@ impl SerializedPageWriter { } } -impl PageWriter for SerializedPageWriter { +impl<'a, W: Write> PageWriter for SerializedPageWriter<'a, W> { fn write_page(&mut self, page: CompressedPage) -> Result { let uncompressed_size = page.uncompressed_size(); let compressed_size = page.compressed_size(); @@ -504,7 +513,7 @@ impl PageWriter for SerializedPageWriter { } } - let start_pos = self.sink.pos(); + let start_pos = self.sink.bytes_written() as u64; let header_size = self.serialize_page_header(page_header)?; self.sink.write_all(page.data())?; @@ -514,7 +523,7 @@ impl PageWriter for SerializedPageWriter { spec.uncompressed_size = uncompressed_size + header_size; spec.compressed_size = compressed_size + header_size; spec.offset = start_pos; - spec.bytes_written = self.sink.pos() - start_pos; + spec.bytes_written = self.sink.bytes_written() as u64 - start_pos; // Number of values is incremented for data pages only if page_type == PageType::DATA_PAGE || page_type == PageType::DATA_PAGE_V2 { spec.num_values = num_values; @@ -537,65 +546,24 @@ impl PageWriter for SerializedPageWriter { mod tests { use super::*; + use bytes::Bytes; use std::{fs::File, io::Cursor}; - use crate::basic::{Compression, Encoding, IntType, LogicalType, Repetition, Type}; + use crate::basic::{Compression, Encoding, LogicalType, Repetition, Type}; use crate::column::page::PageReader; use crate::compression::{create_codec, Codec}; + use crate::data_type::Int32Type; use crate::file::{ properties::{WriterProperties, WriterVersion}, reader::{FileReader, SerializedFileReader, SerializedPageReader}, statistics::{from_thrift, to_thrift, Statistics}, }; use crate::record::RowAccessor; - use crate::util::{memory::ByteBufferPtr, test_common::get_temp_file}; - - #[test] - fn test_file_writer_error_after_close() { - let file = get_temp_file("test_file_writer_error_after_close", &[]); - let schema = Arc::new(types::Type::group_type_builder("schema").build().unwrap()); - let props = Arc::new(WriterProperties::builder().build()); - let mut writer = SerializedFileWriter::new(file, schema, props).unwrap(); - writer.close().unwrap(); - { - let res = writer.next_row_group(); - assert!(res.is_err()); - if let Err(err) = res { - assert_eq!(format!("{}", err), "Parquet error: File writer is closed"); - } - } - { - let res = writer.close(); - assert!(res.is_err()); - if let Err(err) = res { - assert_eq!(format!("{}", err), "Parquet error: File writer is closed"); - } - } - } - - #[test] - fn test_row_group_writer_error_after_close() { - let file = get_temp_file("test_file_writer_row_group_error_after_close", &[]); - let schema = Arc::new(types::Type::group_type_builder("schema").build().unwrap()); - let props = Arc::new(WriterProperties::builder().build()); - let mut writer = SerializedFileWriter::new(file, schema, props).unwrap(); - let mut row_group_writer = writer.next_row_group().unwrap(); - row_group_writer.close().unwrap(); - - let res = row_group_writer.next_column(); - assert!(res.is_err()); - if let Err(err) = res { - assert_eq!( - format!("{}", err), - "Parquet error: Row group writer is closed" - ); - } - } + use crate::util::memory::ByteBufferPtr; #[test] fn test_row_group_writer_error_not_all_columns_written() { - let file = - get_temp_file("test_row_group_writer_error_not_all_columns_written", &[]); + let file = tempfile::tempfile().unwrap(); let schema = Arc::new( types::Type::group_type_builder("schema") .with_fields(&mut vec![Arc::new( @@ -608,7 +576,7 @@ mod tests { ); let props = 
Arc::new(WriterProperties::builder().build()); let mut writer = SerializedFileWriter::new(file, schema, props).unwrap(); - let mut row_group_writer = writer.next_row_group().unwrap(); + let row_group_writer = writer.next_row_group().unwrap(); let res = row_group_writer.close(); assert!(res.is_err()); if let Err(err) = res { @@ -621,7 +589,7 @@ mod tests { #[test] fn test_row_group_writer_num_records_mismatch() { - let file = get_temp_file("test_row_group_writer_num_records_mismatch", &[]); + let file = tempfile::tempfile().unwrap(); let schema = Arc::new( types::Type::group_type_builder("schema") .with_fields(&mut vec![ @@ -646,29 +614,28 @@ mod tests { let mut row_group_writer = writer.next_row_group().unwrap(); let mut col_writer = row_group_writer.next_column().unwrap().unwrap(); - if let ColumnWriter::Int32ColumnWriter(ref mut typed) = col_writer { - typed.write_batch(&[1, 2, 3], None, None).unwrap(); - } - row_group_writer.close_column(col_writer).unwrap(); + col_writer + .typed::() + .write_batch(&[1, 2, 3], None, None) + .unwrap(); + col_writer.close().unwrap(); let mut col_writer = row_group_writer.next_column().unwrap().unwrap(); - if let ColumnWriter::Int32ColumnWriter(ref mut typed) = col_writer { - typed.write_batch(&[1, 2], None, None).unwrap(); - } + col_writer + .typed::() + .write_batch(&[1, 2], None, None) + .unwrap(); - let res = row_group_writer.close_column(col_writer); - assert!(res.is_err()); - if let Err(err) = res { - assert_eq!( - format!("{}", err), - "Parquet error: Incorrect number of rows, expected 3 != 2 rows" - ); - } + let err = col_writer.close().unwrap_err(); + assert_eq!( + err.to_string(), + "Parquet error: Incorrect number of rows, expected 3 != 2 rows" + ); } #[test] fn test_file_writer_empty_file() { - let file = get_temp_file("test_file_writer_write_empty_file", &[]); + let file = tempfile::tempfile().unwrap(); let schema = Arc::new( types::Type::group_type_builder("schema") @@ -681,7 +648,7 @@ mod tests { .unwrap(), ); let props = Arc::new(WriterProperties::builder().build()); - let mut writer = + let writer = SerializedFileWriter::new(file.try_clone().unwrap(), schema, props).unwrap(); writer.close().unwrap(); @@ -691,7 +658,7 @@ mod tests { #[test] fn test_file_writer_with_metadata() { - let file = get_temp_file("test_file_writer_write_with_metadata", &[]); + let file = tempfile::tempfile().unwrap(); let schema = Arc::new( types::Type::group_type_builder("schema") @@ -711,7 +678,7 @@ mod tests { )])) .build(), ); - let mut writer = + let writer = SerializedFileWriter::new(file.try_clone().unwrap(), schema, props).unwrap(); writer.close().unwrap(); @@ -730,11 +697,11 @@ mod tests { #[test] fn test_file_writer_v2_with_metadata() { - let file = get_temp_file("test_file_writer_v2_write_with_metadata", &[]); - let field_logical_type = Some(LogicalType::INTEGER(IntType { + let file = tempfile::tempfile().unwrap(); + let field_logical_type = Some(LogicalType::Integer { bit_width: 8, is_signed: false, - })); + }); let field = Arc::new( types::Type::primitive_type_builder("col1", Type::INT32) .with_logical_type(field_logical_type.clone()) @@ -757,7 +724,7 @@ mod tests { .set_writer_version(WriterVersion::PARQUET_2_0) .build(), ); - let mut writer = + let writer = SerializedFileWriter::new(file.try_clone().unwrap(), schema, props).unwrap(); writer.close().unwrap(); @@ -783,19 +750,19 @@ mod tests { #[test] fn test_file_writer_empty_row_groups() { - let file = get_temp_file("test_file_writer_write_empty_row_groups", &[]); + let file = 
tempfile::tempfile().unwrap(); test_file_roundtrip(file, vec![]); } #[test] fn test_file_writer_single_row_group() { - let file = get_temp_file("test_file_writer_write_single_row_group", &[]); + let file = tempfile::tempfile().unwrap(); test_file_roundtrip(file, vec![vec![1, 2, 3, 4, 5]]); } #[test] fn test_file_writer_multiple_row_groups() { - let file = get_temp_file("test_file_writer_write_multiple_row_groups", &[]); + let file = tempfile::tempfile().unwrap(); test_file_roundtrip( file, vec![ @@ -809,7 +776,7 @@ mod tests { #[test] fn test_file_writer_multiple_large_row_groups() { - let file = get_temp_file("test_file_writer_multiple_large_row_groups", &[]); + let file = tempfile::tempfile().unwrap(); test_file_roundtrip( file, vec![vec![123; 1024], vec![124; 1000], vec![125; 15], vec![]], @@ -970,8 +937,8 @@ mod tests { let mut buffer: Vec = vec![]; let mut result_pages: Vec = vec![]; { - let cursor = Cursor::new(&mut buffer); - let mut page_writer = SerializedPageWriter::new(cursor); + let mut writer = TrackedWrite::new(&mut buffer); + let mut page_writer = SerializedPageWriter::new(&mut writer); for page in compressed_pages { page_writer.write_page(page).unwrap(); @@ -1038,24 +1005,20 @@ mod tests { ); let mut rows: i64 = 0; - for subset in &data { + for (idx, subset) in data.iter().enumerate() { let mut row_group_writer = file_writer.next_row_group().unwrap(); - let col_writer = row_group_writer.next_column().unwrap(); - if let Some(mut writer) = col_writer { - match writer { - ColumnWriter::Int32ColumnWriter(ref mut typed) => { - rows += - typed.write_batch(&subset[..], None, None).unwrap() as i64; - } - _ => { - unimplemented!(); - } - } - row_group_writer.close_column(writer).unwrap(); + if let Some(mut writer) = row_group_writer.next_column().unwrap() { + rows += writer + .typed::() + .write_batch(&subset[..], None, None) + .unwrap() as i64; + writer.close().unwrap(); } - file_writer.close_row_group(row_group_writer).unwrap(); + let last_group = row_group_writer.close().unwrap(); + let flushed = file_writer.flushed_row_groups(); + assert_eq!(flushed.len(), idx + 1); + assert_eq!(flushed[idx].as_ref(), last_group.as_ref()); } - file_writer.close().unwrap(); let reader = assert_send(SerializedFileReader::new(file).unwrap()); @@ -1065,13 +1028,13 @@ mod tests { rows, "row count in metadata not equal to number of rows written" ); - for i in 0..reader.num_row_groups() { + for (i, item) in data.iter().enumerate().take(reader.num_row_groups()) { let row_group_reader = reader.get_row_group(i).unwrap(); let iter = row_group_reader.get_row_iter(None).unwrap(); let res = iter .map(|elem| elem.get_int(0).unwrap()) .collect::>(); - assert_eq!(res, data[i]); + assert_eq!(res, *item); } } @@ -1100,7 +1063,7 @@ mod tests { } fn test_bytes_roundtrip(data: Vec>) { - let cursor = InMemoryWriteableCursor::default(); + let mut buffer = vec![]; let schema = Arc::new( types::Type::group_type_builder("schema") @@ -1118,32 +1081,24 @@ mod tests { { let props = Arc::new(WriterProperties::builder().build()); let mut writer = - SerializedFileWriter::new(cursor.clone(), schema, props).unwrap(); + SerializedFileWriter::new(&mut buffer, schema, props).unwrap(); for subset in &data { let mut row_group_writer = writer.next_row_group().unwrap(); - let col_writer = row_group_writer.next_column().unwrap(); - if let Some(mut writer) = col_writer { - match writer { - ColumnWriter::Int32ColumnWriter(ref mut typed) => { - rows += typed.write_batch(&subset[..], None, None).unwrap() - as i64; - } - _ => { - 
unimplemented!(); - } - } - row_group_writer.close_column(writer).unwrap(); + if let Some(mut writer) = row_group_writer.next_column().unwrap() { + rows += writer + .typed::() + .write_batch(&subset[..], None, None) + .unwrap() as i64; + + writer.close().unwrap(); } - writer.close_row_group(row_group_writer).unwrap(); + row_group_writer.close().unwrap(); } - writer.close().unwrap(); } - let buffer = cursor.into_inner().unwrap(); - - let reading_cursor = crate::file::serialized_reader::SliceableCursor::new(buffer); + let reading_cursor = Bytes::from(buffer); let reader = SerializedFileReader::new(reading_cursor).unwrap(); assert_eq!(reader.num_row_groups(), data.len()); @@ -1152,13 +1107,13 @@ mod tests { rows, "row count in metadata not equal to number of rows written" ); - for i in 0..reader.num_row_groups() { + for (i, item) in data.iter().enumerate().take(reader.num_row_groups()) { let row_group_reader = reader.get_row_group(i).unwrap(); let iter = row_group_reader.get_row_iter(None).unwrap(); let res = iter .map(|elem| elem.get_int(0).unwrap()) .collect::>(); - assert_eq!(res, data[i]); + assert_eq!(res, *item); } } } diff --git a/parquet/src/lib.rs b/parquet/src/lib.rs index 900e2b5c2b61..e86b9e65917a 100644 --- a/parquet/src/lib.rs +++ b/parquet/src/lib.rs @@ -15,43 +15,82 @@ // specific language governing permissions and limitations // under the License. -#![allow(incomplete_features)] +//! This crate contains the official Native Rust implementation of +//! [Apache Parquet](https://parquet.apache.org/), part of +//! the [Apache Arrow](https://arrow.apache.org/) project. +//! +//! # Getting Started +//! Start with some examples: +//! +//! 1. [mod@file] for reading and writing parquet files using the +//! [ColumnReader](column::reader::ColumnReader) API. +//! +//! 2. [arrow] for reading and writing parquet files to Arrow +//! `RecordBatch`es +//! +//! 3. [arrow::async_reader] for `async` reading and writing parquet +//! files to Arrow `RecordBatch`es (requires the `async` feature). #![allow(dead_code)] #![allow(non_camel_case_types)] #![allow( - clippy::approx_constant, - clippy::cast_ptr_alignment, - clippy::float_cmp, - clippy::float_equality_without_abs, clippy::from_over_into, - clippy::many_single_char_names, - clippy::needless_range_loop, clippy::new_without_default, clippy::or_fun_call, - clippy::same_item_push, - clippy::too_many_arguments, - clippy::transmute_ptr_to_ptr, - clippy::upper_case_acronyms, - clippy::vec_init_then_push + clippy::too_many_arguments )] +/// Defines a module with an experimental public API +/// +/// The module will not be documented, and will only be public if the +/// experimental feature flag is enabled +/// +/// Experimental modules have no stability guarantees +macro_rules! experimental_mod { + ($module:ident $(, #[$meta:meta])*) => { + #[cfg(feature = "experimental")] + #[doc(hidden)] + $(#[$meta])* + pub mod $module; + #[cfg(not(feature = "experimental"))] + $(#[$meta])* + mod $module; + }; +} + +macro_rules! 
experimental_mod_crate { + ($module:ident $(, #[$meta:meta])*) => { + #[cfg(feature = "experimental")] + #[doc(hidden)] + $(#[$meta])* + pub mod $module; + #[cfg(not(feature = "experimental"))] + $(#[$meta])* + pub(crate) mod $module; + }; +} + #[macro_use] pub mod errors; pub mod basic; + #[macro_use] pub mod data_type; // Exported for external use, such as benchmarks +#[cfg(feature = "experimental")] +#[doc(hidden)] pub use self::encodings::{decoding, encoding}; + +#[cfg(feature = "experimental")] +#[doc(hidden)] pub use self::util::memory; -#[macro_use] -pub mod util; +experimental_mod!(util, #[macro_use]); #[cfg(any(feature = "arrow", test))] pub mod arrow; pub mod column; -pub mod compression; -mod encodings; +experimental_mod!(compression); +experimental_mod!(encodings); pub mod file; pub mod record; pub mod schema; diff --git a/parquet/src/record/api.rs b/parquet/src/record/api.rs index 0293ffee7da0..5df21e4b0d00 100644 --- a/parquet/src/record/api.rs +++ b/parquet/src/record/api.rs @@ -27,7 +27,7 @@ use crate::data_type::{ByteArray, Decimal, Int96}; use crate::errors::{ParquetError, Result}; use crate::schema::types::ColumnDescPtr; -#[cfg(feature = "cli")] +#[cfg(any(feature = "cli", test))] use serde_json::Value; /// Macro as a shortcut to generate 'not yet implemented' panic error. @@ -79,7 +79,7 @@ impl Row { } } - #[cfg(feature = "cli")] + #[cfg(any(feature = "cli", test))] pub fn to_json_value(&self) -> Value { Value::Object( self.fields @@ -650,7 +650,7 @@ impl Field { } } - #[cfg(feature = "cli")] + #[cfg(any(feature = "cli", test))] pub fn to_json_value(&self) -> Value { match &self { Field::Null => Value::Null, @@ -669,7 +669,7 @@ impl Field { Field::Double(n) => serde_json::Number::from_f64(*n) .map(Value::Number) .unwrap_or(Value::Null), - Field::Decimal(n) => Value::String(convert_decimal_to_string(&n)), + Field::Decimal(n) => Value::String(convert_decimal_to_string(n)), Field::Str(s) => Value::String(s.to_owned()), Field::Bytes(b) => Value::String(base64::encode(b.data())), Field::Date(d) => Value::String(convert_date_to_string(*d)), @@ -716,15 +716,19 @@ impl fmt::Display for Field { Field::Float(value) => { if !(1e-15..=1e19).contains(&value) { write!(f, "{:E}", value) + } else if value.trunc() == value { + write!(f, "{}.0", value) } else { - write!(f, "{:?}", value) + write!(f, "{}", value) } } Field::Double(value) => { if !(1e-15..=1e19).contains(&value) { write!(f, "{:E}", value) + } else if value.trunc() == value { + write!(f, "{}.0", value) } else { - write!(f, "{:?}", value) + write!(f, "{}", value) } } Field::Decimal(ref value) => { @@ -1372,8 +1376,8 @@ mod tests { assert_eq!(4, row.get_ushort(7).unwrap()); assert_eq!(5, row.get_uint(8).unwrap()); assert_eq!(6, row.get_ulong(9).unwrap()); - assert!(7.1 - row.get_float(10).unwrap() < f32::EPSILON); - assert!(8.1 - row.get_double(11).unwrap() < f64::EPSILON); + assert!((7.1 - row.get_float(10).unwrap()).abs() < f32::EPSILON); + assert!((8.1 - row.get_double(11).unwrap()).abs() < f64::EPSILON); assert_eq!("abc", row.get_string(12).unwrap()); assert_eq!(5, row.get_bytes(13).unwrap().len()); assert_eq!(7, row.get_decimal(14).unwrap().precision()); @@ -1520,10 +1524,10 @@ mod tests { Field::Float(9.2), Field::Float(10.3), ]); - assert!(10.3 - list.get_float(2).unwrap() < f32::EPSILON); + assert!((10.3 - list.get_float(2).unwrap()).abs() < f32::EPSILON); let list = make_list(vec![Field::Double(3.1415)]); - assert!(3.1415 - list.get_double(0).unwrap() < f64::EPSILON); + assert!((3.1415 - 
list.get_double(0).unwrap()).abs() < f64::EPSILON); let list = make_list(vec![Field::Str("abc".to_string())]); assert_eq!(&"abc".to_string(), list.get_string(0).unwrap()); @@ -1664,7 +1668,7 @@ mod tests { } #[test] - #[cfg(feature = "cli")] + #[cfg(any(feature = "cli", test))] fn test_to_json_value() { assert_eq!(Field::Null.to_json_value(), Value::Null); assert_eq!(Field::Bool(true).to_json_value(), Value::Bool(true)); @@ -1703,21 +1707,19 @@ mod tests { ); assert_eq!( Field::Float(5.0).to_json_value(), - Value::Number(serde_json::Number::from_f64(f64::from(5.0 as f32)).unwrap()) + Value::Number(serde_json::Number::from_f64(5.0).unwrap()) ); assert_eq!( Field::Float(5.1234).to_json_value(), - Value::Number( - serde_json::Number::from_f64(f64::from(5.1234 as f32)).unwrap() - ) + Value::Number(serde_json::Number::from_f64(5.1234_f32 as f64).unwrap()) ); assert_eq!( Field::Double(6.0).to_json_value(), - Value::Number(serde_json::Number::from_f64(6.0 as f64).unwrap()) + Value::Number(serde_json::Number::from_f64(6.0).unwrap()) ); assert_eq!( Field::Double(6.1234).to_json_value(), - Value::Number(serde_json::Number::from_f64(6.1234 as f64).unwrap()) + Value::Number(serde_json::Number::from_f64(6.1234).unwrap()) ); assert_eq!( Field::Str("abc".to_string()).to_json_value(), diff --git a/parquet/src/record/reader.rs b/parquet/src/record/reader.rs index 8f901f59c3d4..05b63661f09b 100644 --- a/parquet/src/record/reader.rs +++ b/parquet/src/record/reader.rs @@ -106,7 +106,7 @@ impl TreeBuilder { fn reader_tree( &self, field: TypePtr, - mut path: &mut Vec, + path: &mut Vec, mut curr_def_level: i16, mut curr_rep_level: i16, paths: &HashMap, @@ -136,7 +136,7 @@ impl TreeBuilder { .column_descr_ptr(); let col_reader = row_group_reader.get_column_reader(orig_index).unwrap(); let column = TripletIter::new(col_descr, col_reader, self.batch_size); - Reader::PrimitiveReader(field, column) + Reader::PrimitiveReader(field, Box::new(column)) } else { match field.get_basic_info().converted_type() { // List types @@ -160,7 +160,7 @@ impl TreeBuilder { // Support for backward compatible lists let reader = self.reader_tree( repeated_field, - &mut path, + path, curr_def_level, curr_rep_level, paths, @@ -180,7 +180,7 @@ impl TreeBuilder { let reader = self.reader_tree( child_field, - &mut path, + path, curr_def_level + 1, curr_rep_level + 1, paths, @@ -235,7 +235,7 @@ impl TreeBuilder { ); let key_reader = self.reader_tree( key_type.clone(), - &mut path, + path, curr_def_level + 1, curr_rep_level + 1, paths, @@ -245,7 +245,7 @@ impl TreeBuilder { let value_type = &key_value_type.get_fields()[1]; let value_reader = self.reader_tree( value_type.clone(), - &mut path, + path, curr_def_level + 1, curr_rep_level + 1, paths, @@ -278,7 +278,7 @@ impl TreeBuilder { let reader = self.reader_tree( Arc::new(required_field), - &mut path, + path, curr_def_level, curr_rep_level, paths, @@ -298,7 +298,7 @@ impl TreeBuilder { for child in field.get_fields() { let reader = self.reader_tree( child.clone(), - &mut path, + path, curr_def_level, curr_rep_level, paths, @@ -319,7 +319,7 @@ impl TreeBuilder { /// Reader tree for record assembly pub enum Reader { // Primitive reader with type information and triplet iterator - PrimitiveReader(TypePtr, TripletIter), + PrimitiveReader(TypePtr, Box), // Optional reader with definition level of a parent and a reader OptionReader(i16, Box), // Group (struct) reader with type information, definition level and list of child @@ -828,55 +828,25 @@ mod tests { // Convenient macros to assemble row, 
list, map, and group. macro_rules! row { - () => { + ($($e:tt)*) => { { - let result = Vec::new(); - make_row(result) - } - }; - ( $( $e:expr ), + ) => { - { - let mut result = Vec::new(); - $( - result.push($e); - )* - make_row(result) + make_row(vec![$($e)*]) } } } macro_rules! list { - () => { + ($($e:tt)*) => { { - let result = Vec::new(); - Field::ListInternal(make_list(result)) - } - }; - ( $( $e:expr ), + ) => { - { - let mut result = Vec::new(); - $( - result.push($e); - )* - Field::ListInternal(make_list(result)) + Field::ListInternal(make_list(vec![$($e)*])) } } } macro_rules! map { - () => { - { - let result = Vec::new(); - Field::MapInternal(make_map(result)) - } - }; - ( $( $e:expr ), + ) => { + ($($e:tt)*) => { { - let mut result = Vec::new(); - $( - result.push($e); - )* - Field::MapInternal(make_map(result)) + Field::MapInternal(make_map(vec![$($e)*])) } } } diff --git a/parquet/src/record/record_writer.rs b/parquet/src/record/record_writer.rs index 6668eec51494..fe803a7ff4ef 100644 --- a/parquet/src/record/record_writer.rs +++ b/parquet/src/record/record_writer.rs @@ -18,12 +18,12 @@ use crate::schema::types::TypePtr; use super::super::errors::ParquetError; -use super::super::file::writer::RowGroupWriter; +use super::super::file::writer::SerializedRowGroupWriter; pub trait RecordWriter { - fn write_to_row_group( + fn write_to_row_group( &self, - row_group_writer: &mut Box, + row_group_writer: &mut SerializedRowGroupWriter, ) -> Result<(), ParquetError>; /// Generated schema diff --git a/parquet/src/schema/parser.rs b/parquet/src/schema/parser.rs index ba1f566f1792..140e3e08500b 100644 --- a/parquet/src/schema/parser.rs +++ b/parquet/src/schema/parser.rs @@ -45,8 +45,7 @@ use std::sync::Arc; use crate::basic::{ - ConvertedType, DecimalType, IntType, LogicalType, Repetition, TimeType, TimeUnit, - TimestampType, Type as PhysicalType, + ConvertedType, LogicalType, Repetition, TimeUnit, Type as PhysicalType, }; use crate::errors::{ParquetError, Result}; use crate::schema::types::{Type, TypePtr}; @@ -77,7 +76,7 @@ impl<'a> Tokenizer<'a> { pub fn from_str(string: &'a str) -> Self { let vec = string .split_whitespace() - .flat_map(|t| Self::split_token(t)) + .flat_map(Self::split_token) .collect(); Tokenizer { tokens: vec, @@ -357,7 +356,7 @@ impl<'a> Parser<'a> { // Parse the concrete logical type if let Some(tpe) = &logical { match tpe { - LogicalType::DECIMAL(_) => { + LogicalType::Decimal { .. } => { if let Some("(") = self.tokenizer.next() { precision = parse_i32( self.tokenizer.next(), @@ -374,14 +373,11 @@ impl<'a> Parser<'a> { } else { scale = 0 } - logical = Some(LogicalType::DECIMAL(DecimalType { - scale, - precision, - })); + logical = Some(LogicalType::Decimal { scale, precision }); converted = ConvertedType::from(logical.clone()); } } - LogicalType::TIME(_) => { + LogicalType::Time { .. } => { if let Some("(") = self.tokenizer.next() { let unit = parse_timeunit( self.tokenizer.next(), @@ -395,10 +391,10 @@ impl<'a> Parser<'a> { "Failed to parse timezone info for TIME type", )?; assert_token(self.tokenizer.next(), ")")?; - logical = Some(LogicalType::TIME(TimeType { + logical = Some(LogicalType::Time { is_adjusted_to_u_t_c, unit, - })); + }); converted = ConvertedType::from(logical.clone()); } else { // Invalid token for unit @@ -406,7 +402,7 @@ impl<'a> Parser<'a> { } } } - LogicalType::TIMESTAMP(_) => { + LogicalType::Timestamp { .. 
} => { if let Some("(") = self.tokenizer.next() { let unit = parse_timeunit( self.tokenizer.next(), @@ -420,10 +416,10 @@ impl<'a> Parser<'a> { "Failed to parse timezone info for TIMESTAMP type", )?; assert_token(self.tokenizer.next(), ")")?; - logical = Some(LogicalType::TIMESTAMP(TimestampType { + logical = Some(LogicalType::Timestamp { is_adjusted_to_u_t_c, unit, - })); + }); converted = ConvertedType::from(logical.clone()); } else { // Invalid token for unit @@ -431,7 +427,7 @@ impl<'a> Parser<'a> { } } } - LogicalType::INTEGER(_) => { + LogicalType::Integer { .. } => { if let Some("(") = self.tokenizer.next() { let bit_width = parse_i32( self.tokenizer.next(), @@ -453,7 +449,7 @@ impl<'a> Parser<'a> { } } _ => { - return Err(general_err!("Logical type INTEGER cannot be used with physical type {}", physical_type)) + return Err(general_err!("Logical type Integer cannot be used with physical type {}", physical_type)) } } if let Some(",") = self.tokenizer.next() { @@ -463,10 +459,10 @@ impl<'a> Parser<'a> { "Failed to parse is_signed for INTEGER type", )?; assert_token(self.tokenizer.next(), ")")?; - logical = Some(LogicalType::INTEGER(IntType { + logical = Some(LogicalType::Integer { bit_width, is_signed, - })); + }); converted = ConvertedType::from(logical.clone()); } else { // Invalid token for unit @@ -925,10 +921,10 @@ mod tests { "f1", PhysicalType::FIXED_LEN_BYTE_ARRAY, ) - .with_logical_type(Some(LogicalType::DECIMAL(DecimalType { + .with_logical_type(Some(LogicalType::Decimal { precision: 9, scale: 3, - }))) + })) .with_converted_type(ConvertedType::DECIMAL) .with_length(5) .with_precision(9) @@ -941,10 +937,10 @@ mod tests { "f2", PhysicalType::FIXED_LEN_BYTE_ARRAY, ) - .with_logical_type(Some(LogicalType::DECIMAL(DecimalType { + .with_logical_type(Some(LogicalType::Decimal { precision: 38, scale: 18, - }))) + })) .with_converted_type(ConvertedType::DECIMAL) .with_length(16) .with_precision(38) @@ -992,9 +988,7 @@ mod tests { Arc::new( Type::group_type_builder("a1") .with_repetition(Repetition::OPTIONAL) - .with_logical_type(Some(LogicalType::LIST( - Default::default(), - ))) + .with_logical_type(Some(LogicalType::List)) .with_converted_type(ConvertedType::LIST) .with_fields(&mut vec![Arc::new( Type::primitive_type_builder( @@ -1012,9 +1006,7 @@ mod tests { Arc::new( Type::group_type_builder("b1") .with_repetition(Repetition::OPTIONAL) - .with_logical_type(Some(LogicalType::LIST( - Default::default(), - ))) + .with_logical_type(Some(LogicalType::List)) .with_converted_type(ConvertedType::LIST) .with_fields(&mut vec![Arc::new( Type::group_type_builder("b2") @@ -1101,7 +1093,7 @@ mod tests { ), Arc::new( Type::primitive_type_builder("_5", PhysicalType::INT32) - .with_logical_type(Some(LogicalType::DATE(Default::default()))) + .with_logical_type(Some(LogicalType::Date)) .with_converted_type(ConvertedType::DATE) .build() .unwrap(), @@ -1148,20 +1140,20 @@ mod tests { Arc::new( Type::primitive_type_builder("_1", PhysicalType::INT32) .with_repetition(Repetition::REQUIRED) - .with_logical_type(Some(LogicalType::INTEGER(IntType { + .with_logical_type(Some(LogicalType::Integer { bit_width: 8, is_signed: true, - }))) + })) .build() .unwrap(), ), Arc::new( Type::primitive_type_builder("_2", PhysicalType::INT32) .with_repetition(Repetition::REQUIRED) - .with_logical_type(Some(LogicalType::INTEGER(IntType { + .with_logical_type(Some(LogicalType::Integer { bit_width: 16, is_signed: false, - }))) + })) .build() .unwrap(), ), @@ -1179,49 +1171,49 @@ mod tests { ), Arc::new( 
Type::primitive_type_builder("_5", PhysicalType::INT32) - .with_logical_type(Some(LogicalType::DATE(Default::default()))) + .with_logical_type(Some(LogicalType::Date)) .build() .unwrap(), ), Arc::new( Type::primitive_type_builder("_6", PhysicalType::INT32) - .with_logical_type(Some(LogicalType::TIME(TimeType { + .with_logical_type(Some(LogicalType::Time { unit: TimeUnit::MILLIS(Default::default()), is_adjusted_to_u_t_c: false, - }))) + })) .build() .unwrap(), ), Arc::new( Type::primitive_type_builder("_7", PhysicalType::INT64) - .with_logical_type(Some(LogicalType::TIME(TimeType { + .with_logical_type(Some(LogicalType::Time { unit: TimeUnit::MICROS(Default::default()), is_adjusted_to_u_t_c: true, - }))) + })) .build() .unwrap(), ), Arc::new( Type::primitive_type_builder("_8", PhysicalType::INT64) - .with_logical_type(Some(LogicalType::TIMESTAMP(TimestampType { + .with_logical_type(Some(LogicalType::Timestamp { unit: TimeUnit::MILLIS(Default::default()), is_adjusted_to_u_t_c: true, - }))) + })) .build() .unwrap(), ), Arc::new( Type::primitive_type_builder("_9", PhysicalType::INT64) - .with_logical_type(Some(LogicalType::TIMESTAMP(TimestampType { + .with_logical_type(Some(LogicalType::Timestamp { unit: TimeUnit::NANOS(Default::default()), is_adjusted_to_u_t_c: false, - }))) + })) .build() .unwrap(), ), Arc::new( Type::primitive_type_builder("_10", PhysicalType::BYTE_ARRAY) - .with_logical_type(Some(LogicalType::STRING(Default::default()))) + .with_logical_type(Some(LogicalType::String)) .build() .unwrap(), ), diff --git a/parquet/src/schema/printer.rs b/parquet/src/schema/printer.rs index de1934ff5616..5cfd30dd977c 100644 --- a/parquet/src/schema/printer.rs +++ b/parquet/src/schema/printer.rs @@ -163,6 +163,31 @@ fn print_column_chunk_metadata( Some(stats) => stats.to_string(), }; writeln!(out, "statistics: {}", statistics_str); + let bloom_filter_offset_str = match cc_metadata.bloom_filter_offset() { + None => "N/A".to_owned(), + Some(bfo) => bfo.to_string(), + }; + writeln!(out, "bloom filter offset: {}", bloom_filter_offset_str); + let offset_index_offset_str = match cc_metadata.offset_index_offset() { + None => "N/A".to_owned(), + Some(oio) => oio.to_string(), + }; + writeln!(out, "offset index offset: {}", offset_index_offset_str); + let offset_index_length_str = match cc_metadata.offset_index_length() { + None => "N/A".to_owned(), + Some(oil) => oil.to_string(), + }; + writeln!(out, "offset index length: {}", offset_index_length_str); + let column_index_offset_str = match cc_metadata.column_index_offset() { + None => "N/A".to_owned(), + Some(cio) => cio.to_string(), + }; + writeln!(out, "column index offset: {}", column_index_offset_str); + let column_index_length_str = match cc_metadata.column_index_length() { + None => "N/A".to_owned(), + Some(cil) => cil.to_string(), + }; + writeln!(out, "column index length: {}", column_index_length_str); writeln!(out); } @@ -206,47 +231,52 @@ fn print_timeunit(unit: &TimeUnit) -> &str { #[inline] fn print_logical_and_converted( - logical_type: &Option, + logical_type: Option<&LogicalType>, converted_type: ConvertedType, precision: i32, scale: i32, ) -> String { match logical_type { Some(logical_type) => match logical_type { - LogicalType::INTEGER(t) => { - format!("INTEGER({},{})", t.bit_width, t.is_signed) + LogicalType::Integer { + bit_width, + is_signed, + } => { + format!("INTEGER({},{})", bit_width, is_signed) } - LogicalType::DECIMAL(t) => { - format!("DECIMAL({},{})", t.precision, t.scale) + LogicalType::Decimal { scale, precision } => { + 
format!("DECIMAL({},{})", precision, scale) } - LogicalType::TIMESTAMP(t) => { + LogicalType::Timestamp { + is_adjusted_to_u_t_c, + unit, + } => { format!( "TIMESTAMP({},{})", - print_timeunit(&t.unit), - t.is_adjusted_to_u_t_c + print_timeunit(unit), + is_adjusted_to_u_t_c ) } - LogicalType::TIME(t) => { - format!( - "TIME({},{})", - print_timeunit(&t.unit), - t.is_adjusted_to_u_t_c - ) + LogicalType::Time { + is_adjusted_to_u_t_c, + unit, + } => { + format!("TIME({},{})", print_timeunit(unit), is_adjusted_to_u_t_c) } - LogicalType::DATE(_) => "DATE".to_string(), - LogicalType::BSON(_) => "BSON".to_string(), - LogicalType::JSON(_) => "JSON".to_string(), - LogicalType::STRING(_) => "STRING".to_string(), - LogicalType::UUID(_) => "UUID".to_string(), - LogicalType::ENUM(_) => "ENUM".to_string(), - LogicalType::LIST(_) => "LIST".to_string(), - LogicalType::MAP(_) => "MAP".to_string(), - LogicalType::UNKNOWN(_) => "UNKNOWN".to_string(), + LogicalType::Date => "DATE".to_string(), + LogicalType::Bson => "BSON".to_string(), + LogicalType::Json => "JSON".to_string(), + LogicalType::String => "STRING".to_string(), + LogicalType::Uuid => "UUID".to_string(), + LogicalType::Enum => "ENUM".to_string(), + LogicalType::List => "LIST".to_string(), + LogicalType::Map => "MAP".to_string(), + LogicalType::Unknown => "UNKNOWN".to_string(), }, None => { // Also print converted type if it is available match converted_type { - ConvertedType::NONE => format!(""), + ConvertedType::NONE => String::new(), decimal @ ConvertedType::DECIMAL => { // For decimal type we should print precision and scale if they // are > 0, e.g. DECIMAL(9,2) - @@ -256,7 +286,7 @@ fn print_logical_and_converted( format!("({},{})", p, s) } (p, 0) if p > 0 => format!("({})", p), - _ => format!(""), + _ => String::new(), }; format!("{}{}", decimal, precision_scale) } @@ -290,7 +320,7 @@ impl<'a> Printer<'a> { // Also print logical type if it is available // If there is a logical type, do not print converted type let logical_type_str = print_logical_and_converted( - &basic_info.logical_type(), + basic_info.logical_type().as_ref(), basic_info.converted_type(), precision, scale, @@ -322,7 +352,7 @@ impl<'a> Printer<'a> { let r = basic_info.repetition(); write!(self.output, "{} group {} ", r, basic_info.name()); let logical_str = print_logical_and_converted( - &basic_info.logical_type(), + basic_info.logical_type().as_ref(), basic_info.converted_type(), 0, 0, @@ -354,10 +384,7 @@ mod tests { use std::sync::Arc; - use crate::basic::{ - DateType, DecimalType, IntType, LogicalType, Repetition, TimeType, TimestampType, - Type as PhysicalType, - }; + use crate::basic::{LogicalType, Repetition, Type as PhysicalType}; use crate::errors::Result; use crate::schema::{parser::parse_message_type, types::Type}; @@ -409,10 +436,10 @@ mod tests { build_primitive_type( "field", PhysicalType::INT32, - Some(LogicalType::INTEGER(IntType { + Some(LogicalType::Integer { bit_width: 32, is_signed: true, - })), + }), ConvertedType::NONE, Repetition::REQUIRED, ) @@ -423,10 +450,10 @@ mod tests { build_primitive_type( "field", PhysicalType::INT32, - Some(LogicalType::INTEGER(IntType { + Some(LogicalType::Integer { bit_width: 8, is_signed: false, - })), + }), ConvertedType::NONE, Repetition::OPTIONAL, ) @@ -437,10 +464,10 @@ mod tests { build_primitive_type( "field", PhysicalType::INT32, - Some(LogicalType::INTEGER(IntType { + Some(LogicalType::Integer { bit_width: 16, is_signed: true, - })), + }), ConvertedType::INT_16, Repetition::REPEATED, ) @@ -484,10 +511,10 @@ mod 
tests { build_primitive_type( "field", PhysicalType::INT64, - Some(LogicalType::TIMESTAMP(TimestampType { + Some(LogicalType::Timestamp { is_adjusted_to_u_t_c: true, unit: TimeUnit::MILLIS(Default::default()), - })), + }), ConvertedType::NONE, Repetition::REQUIRED, ) @@ -498,7 +525,7 @@ mod tests { build_primitive_type( "field", PhysicalType::INT32, - Some(LogicalType::DATE(DateType {})), + Some(LogicalType::Date), ConvertedType::NONE, Repetition::OPTIONAL, ) @@ -509,10 +536,10 @@ mod tests { build_primitive_type( "field", PhysicalType::INT32, - Some(LogicalType::TIME(TimeType { + Some(LogicalType::Time { unit: TimeUnit::MILLIS(Default::default()), is_adjusted_to_u_t_c: false, - })), + }), ConvertedType::TIME_MILLIS, Repetition::REQUIRED, ) @@ -545,7 +572,7 @@ mod tests { build_primitive_type( "field", PhysicalType::BYTE_ARRAY, - Some(LogicalType::JSON(Default::default())), + Some(LogicalType::Json), ConvertedType::JSON, Repetition::REQUIRED, ) @@ -556,7 +583,7 @@ mod tests { build_primitive_type( "field", PhysicalType::BYTE_ARRAY, - Some(LogicalType::BSON(Default::default())), + Some(LogicalType::Bson), ConvertedType::BSON, Repetition::REQUIRED, ) @@ -567,7 +594,7 @@ mod tests { build_primitive_type( "field", PhysicalType::BYTE_ARRAY, - Some(LogicalType::STRING(Default::default())), + Some(LogicalType::String), ConvertedType::NONE, Repetition::REQUIRED, ) @@ -609,7 +636,7 @@ mod tests { ), ( Type::primitive_type_builder("field", PhysicalType::FIXED_LEN_BYTE_ARRAY) - .with_logical_type(Some(LogicalType::UUID(Default::default()))) + .with_logical_type(Some(LogicalType::Uuid)) .with_length(16) .with_repetition(Repetition::REQUIRED) .build() @@ -621,10 +648,10 @@ mod tests { "decimal", PhysicalType::FIXED_LEN_BYTE_ARRAY, ) - .with_logical_type(Some(LogicalType::DECIMAL(DecimalType { + .with_logical_type(Some(LogicalType::Decimal { precision: 32, scale: 20, - }))) + })) .with_precision(32) .with_scale(20) .with_length(decimal_length_from_precision(32)) @@ -674,7 +701,7 @@ mod tests { .with_id(1) .build(); let f3 = Type::primitive_type_builder("f3", PhysicalType::BYTE_ARRAY) - .with_logical_type(Some(LogicalType::STRING(Default::default()))) + .with_logical_type(Some(LogicalType::String)) .with_id(1) .build(); let f4 = @@ -684,19 +711,20 @@ mod tests { .with_length(12) .with_id(2) .build(); - let mut struct_fields = Vec::new(); - struct_fields.push(Arc::new(f1.unwrap())); - struct_fields.push(Arc::new(f2.unwrap())); - struct_fields.push(Arc::new(f3.unwrap())); + + let mut struct_fields = vec![ + Arc::new(f1.unwrap()), + Arc::new(f2.unwrap()), + Arc::new(f3.unwrap()), + ]; let field = Type::group_type_builder("field") .with_repetition(Repetition::OPTIONAL) .with_fields(&mut struct_fields) .with_id(1) .build() .unwrap(); - let mut fields = Vec::new(); - fields.push(Arc::new(field)); - fields.push(Arc::new(f4.unwrap())); + + let mut fields = vec![Arc::new(field), Arc::new(f4.unwrap())]; let message = Type::group_type_builder("schema") .with_fields(&mut fields) .with_id(2) @@ -725,7 +753,7 @@ mod tests { let a1 = Type::group_type_builder("a1") .with_repetition(Repetition::OPTIONAL) - .with_logical_type(Some(LogicalType::LIST(Default::default()))) + .with_logical_type(Some(LogicalType::List)) .with_converted_type(ConvertedType::LIST) .with_fields(&mut vec![Arc::new(a2)]) .build() @@ -750,7 +778,7 @@ mod tests { let b1 = Type::group_type_builder("b1") .with_repetition(Repetition::OPTIONAL) - .with_logical_type(Some(LogicalType::LIST(Default::default()))) + .with_logical_type(Some(LogicalType::List)) 
.with_converted_type(ConvertedType::LIST) .with_fields(&mut vec![Arc::new(b2)]) .build() @@ -809,10 +837,10 @@ mod tests { fn test_print_and_parse_decimal() { let f1 = Type::primitive_type_builder("f1", PhysicalType::INT32) .with_repetition(Repetition::OPTIONAL) - .with_logical_type(Some(LogicalType::DECIMAL(DecimalType { + .with_logical_type(Some(LogicalType::Decimal { precision: 9, scale: 2, - }))) + })) .with_converted_type(ConvertedType::DECIMAL) .with_precision(9) .with_scale(2) @@ -821,10 +849,10 @@ mod tests { let f2 = Type::primitive_type_builder("f2", PhysicalType::INT32) .with_repetition(Repetition::OPTIONAL) - .with_logical_type(Some(LogicalType::DECIMAL(DecimalType { + .with_logical_type(Some(LogicalType::Decimal { precision: 9, scale: 0, - }))) + })) .with_converted_type(ConvertedType::DECIMAL) .with_precision(9) .with_scale(0) diff --git a/parquet/src/schema/types.rs b/parquet/src/schema/types.rs index d885ff1eb7f3..8d624fe3d185 100644 --- a/parquet/src/schema/types.rs +++ b/parquet/src/schema/types.rs @@ -22,7 +22,7 @@ use std::{collections::HashMap, convert::From, fmt, sync::Arc}; use parquet_format::SchemaElement; use crate::basic::{ - ConvertedType, LogicalType, Repetition, TimeType, TimeUnit, Type as PhysicalType, + ConvertedType, LogicalType, Repetition, TimeUnit, Type as PhysicalType, }; use crate::errors::{ParquetError, Result}; @@ -306,57 +306,57 @@ impl<'a> PrimitiveTypeBuilder<'a> { } // Check that logical type and physical type are compatible match (logical_type, self.physical_type) { - (LogicalType::MAP(_), _) | (LogicalType::LIST(_), _) => { + (LogicalType::Map, _) | (LogicalType::List, _) => { return Err(general_err!( "{:?} cannot be applied to a primitive type", logical_type )); } - (LogicalType::ENUM(_), PhysicalType::BYTE_ARRAY) => {} - (LogicalType::DECIMAL(t), _) => { + (LogicalType::Enum, PhysicalType::BYTE_ARRAY) => {} + (LogicalType::Decimal { scale, precision }, _) => { // Check that scale and precision are consistent with legacy values - if t.scale != self.scale { + if *scale != self.scale { return Err(general_err!( "DECIMAL logical type scale {} must match self.scale {}", - t.scale, + scale, self.scale )); } - if t.precision != self.precision { + if *precision != self.precision { return Err(general_err!( "DECIMAL logical type precision {} must match self.precision {}", - t.precision, + precision, self.precision )); } self.check_decimal_precision_scale()?; } - (LogicalType::DATE(_), PhysicalType::INT32) => {} + (LogicalType::Date, PhysicalType::INT32) => {} ( - LogicalType::TIME(TimeType { + LogicalType::Time { unit: TimeUnit::MILLIS(_), .. - }), + }, PhysicalType::INT32, ) => {} - (LogicalType::TIME(t), PhysicalType::INT64) => { - if t.unit == TimeUnit::MILLIS(Default::default()) { + (LogicalType::Time { unit, .. }, PhysicalType::INT64) => { + if *unit == TimeUnit::MILLIS(Default::default()) { return Err(general_err!( "Cannot use millisecond unit on INT64 type" )); } } - (LogicalType::TIMESTAMP(_), PhysicalType::INT64) => {} - (LogicalType::INTEGER(t), PhysicalType::INT32) - if t.bit_width <= 32 => {} - (LogicalType::INTEGER(t), PhysicalType::INT64) - if t.bit_width == 64 => {} + (LogicalType::Timestamp { .. }, PhysicalType::INT64) => {} + (LogicalType::Integer { bit_width, .. }, PhysicalType::INT32) + if *bit_width <= 32 => {} + (LogicalType::Integer { bit_width, .. 
}, PhysicalType::INT64) + if *bit_width == 64 => {} // Null type - (LogicalType::UNKNOWN(_), PhysicalType::INT32) => {} - (LogicalType::STRING(_), PhysicalType::BYTE_ARRAY) => {} - (LogicalType::JSON(_), PhysicalType::BYTE_ARRAY) => {} - (LogicalType::BSON(_), PhysicalType::BYTE_ARRAY) => {} - (LogicalType::UUID(_), PhysicalType::FIXED_LEN_BYTE_ARRAY) => {} + (LogicalType::Unknown, PhysicalType::INT32) => {} + (LogicalType::String, PhysicalType::BYTE_ARRAY) => {} + (LogicalType::Json, PhysicalType::BYTE_ARRAY) => {} + (LogicalType::Bson, PhysicalType::BYTE_ARRAY) => {} + (LogicalType::Uuid, PhysicalType::FIXED_LEN_BYTE_ARRAY) => {} (a, b) => { return Err(general_err!( "Cannot annotate {:?} from {} fields", @@ -467,13 +467,13 @@ impl<'a> PrimitiveTypeBuilder<'a> { return Err(general_err!("Invalid DECIMAL scale: {}", self.scale)); } - if self.scale >= self.precision { + if self.scale > self.precision { return Err(general_err!( - "Invalid DECIMAL: scale ({}) cannot be greater than or equal to precision \ + "Invalid DECIMAL: scale ({}) cannot be greater than precision \ ({})", - self.scale, - self.precision - )); + self.scale, + self.precision + )); } // Check precision and scale based on physical type limitations. @@ -838,6 +838,7 @@ impl ColumnDescriptor { /// A schema descriptor. This encapsulates the top-level schemas for all the columns, /// as well as all descriptors for all the primitive columns. +#[derive(PartialEq)] pub struct SchemaDescriptor { // The top-level schema (the "message" type). // This must be a `GroupType` where each field is a root column type in the schema. @@ -847,13 +848,13 @@ pub struct SchemaDescriptor { // `schema` in DFS order. leaves: Vec, - // Mapping from a leaf column's index to the root column type that it + // Mapping from a leaf column's index to the root column index that it // comes from. For instance: the leaf `a.b.c.d` would have a link back to `a`: // -- a <-----+ // -- -- b | // -- -- -- c | // -- -- -- -- d - leaf_to_base: Vec, + leaf_to_base: Vec, } impl fmt::Debug for SchemaDescriptor { @@ -871,9 +872,9 @@ impl SchemaDescriptor { assert!(tp.is_group(), "SchemaDescriptor should take a GroupType"); let mut leaves = vec![]; let mut leaf_to_base = Vec::new(); - for f in tp.get_fields() { + for (root_idx, f) in tp.get_fields().iter().enumerate() { let mut path = vec![]; - build_tree(f, f, 0, 0, &mut leaves, &mut leaf_to_base, &mut path); + build_tree(f, root_idx, 0, 0, &mut leaves, &mut leaf_to_base, &mut path); } Self { @@ -904,30 +905,35 @@ impl SchemaDescriptor { self.leaves.len() } - /// Returns column root [`Type`](crate::schema::types::Type) for a field position. + /// Returns column root [`Type`](crate::schema::types::Type) for a leaf position. pub fn get_column_root(&self, i: usize) -> &Type { let result = self.column_root_of(i); result.as_ref() } - /// Returns column root [`Type`](crate::schema::types::Type) pointer for a field + /// Returns column root [`Type`](crate::schema::types::Type) pointer for a leaf /// position. 
pub fn get_column_root_ptr(&self, i: usize) -> TypePtr { let result = self.column_root_of(i); result.clone() } - fn column_root_of(&self, i: usize) -> &Arc { + /// Returns the index of the root column for a field position + pub fn get_column_root_idx(&self, leaf: usize) -> usize { assert!( - i < self.leaves.len(), + leaf < self.leaves.len(), "Index out of bound: {} not in [0, {})", - i, + leaf, self.leaves.len() ); - self.leaf_to_base - .get(i) - .unwrap_or_else(|| panic!("Expected a value for index {} but found None", i)) + *self.leaf_to_base.get(leaf).unwrap_or_else(|| { + panic!("Expected a value for index {} but found None", leaf) + }) + } + + fn column_root_of(&self, i: usize) -> &TypePtr { + &self.schema.get_fields()[self.get_column_root_idx(i)] } /// Returns schema as [`Type`](crate::schema::types::Type). @@ -947,11 +953,11 @@ impl SchemaDescriptor { fn build_tree<'a>( tp: &'a TypePtr, - base_tp: &TypePtr, + root_idx: usize, mut max_rep_level: i16, mut max_def_level: i16, leaves: &mut Vec, - leaf_to_base: &mut Vec, + leaf_to_base: &mut Vec, path_so_far: &mut Vec<&'a str>, ) { assert!(tp.get_basic_info().has_repetition()); @@ -978,13 +984,13 @@ fn build_tree<'a>( max_rep_level, ColumnPath::new(path), ))); - leaf_to_base.push(base_tp.clone()); + leaf_to_base.push(root_idx); } Type::GroupType { ref fields, .. } => { for f in fields { build_tree( f, - base_tp, + root_idx, max_rep_level, max_def_level, leaves, @@ -1198,7 +1204,6 @@ fn to_thrift_helper(schema: &Type, elements: &mut Vec) { mod tests { use super::*; - use crate::basic::{DecimalType, IntType}; use crate::schema::parser::parse_message_type; // TODO: add tests for v2 types @@ -1206,10 +1211,10 @@ mod tests { #[test] fn test_primitive_type() { let mut result = Type::primitive_type_builder("foo", PhysicalType::INT32) - .with_logical_type(Some(LogicalType::INTEGER(IntType { + .with_logical_type(Some(LogicalType::Integer { bit_width: 32, is_signed: true, - }))) + })) .with_id(0) .build(); assert!(result.is_ok()); @@ -1221,10 +1226,10 @@ mod tests { assert_eq!(basic_info.repetition(), Repetition::OPTIONAL); assert_eq!( basic_info.logical_type(), - Some(LogicalType::INTEGER(IntType { + Some(LogicalType::Integer { bit_width: 32, is_signed: true - })) + }) ); assert_eq!(basic_info.converted_type(), ConvertedType::INT_32); assert_eq!(basic_info.id(), 0); @@ -1239,16 +1244,16 @@ mod tests { // Test illegal inputs with logical type result = Type::primitive_type_builder("foo", PhysicalType::INT64) .with_repetition(Repetition::REPEATED) - .with_logical_type(Some(LogicalType::INTEGER(IntType { + .with_logical_type(Some(LogicalType::Integer { is_signed: true, bit_width: 8, - }))) + })) .build(); assert!(result.is_err()); if let Err(e) = result { assert_eq!( format!("{}", e), - "Parquet error: Cannot annotate INTEGER(IntType { bit_width: 8, is_signed: true }) from INT64 fields" + "Parquet error: Cannot annotate Integer { bit_width: 8, is_signed: true } from INT64 fields" ); } @@ -1281,10 +1286,10 @@ mod tests { result = Type::primitive_type_builder("foo", PhysicalType::BYTE_ARRAY) .with_repetition(Repetition::REQUIRED) - .with_logical_type(Some(LogicalType::DECIMAL(DecimalType { + .with_logical_type(Some(LogicalType::Decimal { scale: 32, precision: 12, - }))) + })) .with_precision(-1) .with_scale(-1) .build(); @@ -1345,10 +1350,19 @@ mod tests { if let Err(e) = result { assert_eq!( format!("{}", e), - "Parquet error: Invalid DECIMAL: scale (2) cannot be greater than or equal to precision (1)" + "Parquet error: Invalid DECIMAL: scale (2) 
cannot be greater than precision (1)" ); } + // It is OK if precision == scale + result = Type::primitive_type_builder("foo", PhysicalType::BYTE_ARRAY) + .with_repetition(Repetition::REQUIRED) + .with_converted_type(ConvertedType::DECIMAL) + .with_precision(1) + .with_scale(1) + .build(); + assert!(result.is_ok()); + result = Type::primitive_type_builder("foo", PhysicalType::INT32) .with_repetition(Repetition::REQUIRED) .with_converted_type(ConvertedType::DECIMAL) @@ -1492,13 +1506,11 @@ mod tests { .build(); assert!(f2.is_ok()); - let mut fields = vec![]; - fields.push(Arc::new(f1.unwrap())); - fields.push(Arc::new(f2.unwrap())); + let mut fields = vec![Arc::new(f1.unwrap()), Arc::new(f2.unwrap())]; let result = Type::group_type_builder("foo") .with_repetition(Repetition::REPEATED) - .with_logical_type(Some(LogicalType::LIST(Default::default()))) + .with_logical_type(Some(LogicalType::List)) .with_fields(&mut fields) .with_id(1) .build(); @@ -1509,10 +1521,7 @@ mod tests { assert!(tp.is_group()); assert!(!tp.is_primitive()); assert_eq!(basic_info.repetition(), Repetition::REPEATED); - assert_eq!( - basic_info.logical_type(), - Some(LogicalType::LIST(Default::default())) - ); + assert_eq!(basic_info.logical_type(), Some(LogicalType::List)); assert_eq!(basic_info.converted_type(), ConvertedType::LIST); assert_eq!(basic_info.id(), 1); assert_eq!(tp.get_fields().len(), 2); diff --git a/parquet/src/schema/visitor.rs b/parquet/src/schema/visitor.rs index 8ed079fb4237..9d28fa5e8dcd 100644 --- a/parquet/src/schema/visitor.rs +++ b/parquet/src/schema/visitor.rs @@ -27,17 +27,30 @@ pub trait TypeVisitor { /// Default implementation when visiting a list. /// - /// It checks list type definition and calls `visit_list_with_item` with extracted + /// It checks list type definition and calls [`Self::visit_list_with_item`] with extracted /// item type. /// /// To fully understand this algorithm, please refer to /// [parquet doc](https://github.com/apache/parquet-format/blob/master/LogicalTypes.md). + /// + /// For example, a standard list type looks like: + /// + /// ```text + /// required/optional group my_list (LIST) { + // repeated group list { + // required/optional binary element (UTF8); + // } + // } + /// ``` + /// + /// In such a case, [`Self::visit_list_with_item`] will be called with `my_list` as the list + /// type, and `element` as the `item_type` + /// fn visit_list(&mut self, list_type: TypePtr, context: C) -> Result { match list_type.as_ref() { - Type::PrimitiveType { .. } => panic!( - "{:?} is a list type and can't be processed as primitive.", - list_type - ), + Type::PrimitiveType { .. 
} => { + panic!("{:?} is a list type and must be a group type", list_type) + } Type::GroupType { basic_info: _, fields, diff --git a/parquet/src/util/bit_packing.rs b/parquet/src/util/bit_packing.rs index 6b9673f6c307..758992ab2723 100644 --- a/parquet/src/util/bit_packing.rs +++ b/parquet/src/util/bit_packing.rs @@ -79,3584 +79,3584 @@ unsafe fn nullunpacker32(in_buf: *const u32, mut out: *mut u32) -> *const u32 { } unsafe fn unpack1_32(in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (*in_buf) & 1; + *out = (in_buf.read_unaligned()) & 1; out = out.offset(1); - *out = ((*in_buf) >> 1) & 1; + *out = ((in_buf.read_unaligned()) >> 1) & 1; out = out.offset(1); - *out = ((*in_buf) >> 2) & 1; + *out = ((in_buf.read_unaligned()) >> 2) & 1; out = out.offset(1); - *out = ((*in_buf) >> 3) & 1; + *out = ((in_buf.read_unaligned()) >> 3) & 1; out = out.offset(1); - *out = ((*in_buf) >> 4) & 1; + *out = ((in_buf.read_unaligned()) >> 4) & 1; out = out.offset(1); - *out = ((*in_buf) >> 5) & 1; + *out = ((in_buf.read_unaligned()) >> 5) & 1; out = out.offset(1); - *out = ((*in_buf) >> 6) & 1; + *out = ((in_buf.read_unaligned()) >> 6) & 1; out = out.offset(1); - *out = ((*in_buf) >> 7) & 1; + *out = ((in_buf.read_unaligned()) >> 7) & 1; out = out.offset(1); - *out = ((*in_buf) >> 8) & 1; + *out = ((in_buf.read_unaligned()) >> 8) & 1; out = out.offset(1); - *out = ((*in_buf) >> 9) & 1; + *out = ((in_buf.read_unaligned()) >> 9) & 1; out = out.offset(1); - *out = ((*in_buf) >> 10) & 1; + *out = ((in_buf.read_unaligned()) >> 10) & 1; out = out.offset(1); - *out = ((*in_buf) >> 11) & 1; + *out = ((in_buf.read_unaligned()) >> 11) & 1; out = out.offset(1); - *out = ((*in_buf) >> 12) & 1; + *out = ((in_buf.read_unaligned()) >> 12) & 1; out = out.offset(1); - *out = ((*in_buf) >> 13) & 1; + *out = ((in_buf.read_unaligned()) >> 13) & 1; out = out.offset(1); - *out = ((*in_buf) >> 14) & 1; + *out = ((in_buf.read_unaligned()) >> 14) & 1; out = out.offset(1); - *out = ((*in_buf) >> 15) & 1; + *out = ((in_buf.read_unaligned()) >> 15) & 1; out = out.offset(1); - *out = ((*in_buf) >> 16) & 1; + *out = ((in_buf.read_unaligned()) >> 16) & 1; out = out.offset(1); - *out = ((*in_buf) >> 17) & 1; + *out = ((in_buf.read_unaligned()) >> 17) & 1; out = out.offset(1); - *out = ((*in_buf) >> 18) & 1; + *out = ((in_buf.read_unaligned()) >> 18) & 1; out = out.offset(1); - *out = ((*in_buf) >> 19) & 1; + *out = ((in_buf.read_unaligned()) >> 19) & 1; out = out.offset(1); - *out = ((*in_buf) >> 20) & 1; + *out = ((in_buf.read_unaligned()) >> 20) & 1; out = out.offset(1); - *out = ((*in_buf) >> 21) & 1; + *out = ((in_buf.read_unaligned()) >> 21) & 1; out = out.offset(1); - *out = ((*in_buf) >> 22) & 1; + *out = ((in_buf.read_unaligned()) >> 22) & 1; out = out.offset(1); - *out = ((*in_buf) >> 23) & 1; + *out = ((in_buf.read_unaligned()) >> 23) & 1; out = out.offset(1); - *out = ((*in_buf) >> 24) & 1; + *out = ((in_buf.read_unaligned()) >> 24) & 1; out = out.offset(1); - *out = ((*in_buf) >> 25) & 1; + *out = ((in_buf.read_unaligned()) >> 25) & 1; out = out.offset(1); - *out = ((*in_buf) >> 26) & 1; + *out = ((in_buf.read_unaligned()) >> 26) & 1; out = out.offset(1); - *out = ((*in_buf) >> 27) & 1; + *out = ((in_buf.read_unaligned()) >> 27) & 1; out = out.offset(1); - *out = ((*in_buf) >> 28) & 1; + *out = ((in_buf.read_unaligned()) >> 28) & 1; out = out.offset(1); - *out = ((*in_buf) >> 29) & 1; + *out = ((in_buf.read_unaligned()) >> 29) & 1; out = out.offset(1); - *out = ((*in_buf) >> 30) & 1; + *out = 
((in_buf.read_unaligned()) >> 30) & 1; out = out.offset(1); - *out = (*in_buf) >> 31; + *out = (in_buf.read_unaligned()) >> 31; in_buf.offset(1) } unsafe fn unpack2_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (*in_buf) % (1u32 << 2); + *out = (in_buf.read_unaligned()) % (1u32 << 2); out = out.offset(1); - *out = ((*in_buf) >> 2) % (1u32 << 2); + *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 2); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 2); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 2); out = out.offset(1); - *out = ((*in_buf) >> 6) % (1u32 << 2); + *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 2); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 2); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 2); out = out.offset(1); - *out = ((*in_buf) >> 10) % (1u32 << 2); + *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 2); out = out.offset(1); - *out = ((*in_buf) >> 12) % (1u32 << 2); + *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 2); out = out.offset(1); - *out = ((*in_buf) >> 14) % (1u32 << 2); + *out = ((in_buf.read_unaligned()) >> 14) % (1u32 << 2); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 2); + *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 2); out = out.offset(1); - *out = ((*in_buf) >> 18) % (1u32 << 2); + *out = ((in_buf.read_unaligned()) >> 18) % (1u32 << 2); out = out.offset(1); - *out = ((*in_buf) >> 20) % (1u32 << 2); + *out = ((in_buf.read_unaligned()) >> 20) % (1u32 << 2); out = out.offset(1); - *out = ((*in_buf) >> 22) % (1u32 << 2); + *out = ((in_buf.read_unaligned()) >> 22) % (1u32 << 2); out = out.offset(1); - *out = ((*in_buf) >> 24) % (1u32 << 2); + *out = ((in_buf.read_unaligned()) >> 24) % (1u32 << 2); out = out.offset(1); - *out = ((*in_buf) >> 26) % (1u32 << 2); + *out = ((in_buf.read_unaligned()) >> 26) % (1u32 << 2); out = out.offset(1); - *out = ((*in_buf) >> 28) % (1u32 << 2); + *out = ((in_buf.read_unaligned()) >> 28) % (1u32 << 2); out = out.offset(1); - *out = (*in_buf) >> 30; + *out = (in_buf.read_unaligned()) >> 30; out = out.offset(1); in_buf = in_buf.offset(1); - *out = (*in_buf) % (1u32 << 2); + *out = (in_buf.read_unaligned()) % (1u32 << 2); out = out.offset(1); - *out = ((*in_buf) >> 2) % (1u32 << 2); + *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 2); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 2); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 2); out = out.offset(1); - *out = ((*in_buf) >> 6) % (1u32 << 2); + *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 2); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 2); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 2); out = out.offset(1); - *out = ((*in_buf) >> 10) % (1u32 << 2); + *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 2); out = out.offset(1); - *out = ((*in_buf) >> 12) % (1u32 << 2); + *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 2); out = out.offset(1); - *out = ((*in_buf) >> 14) % (1u32 << 2); + *out = ((in_buf.read_unaligned()) >> 14) % (1u32 << 2); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 2); + *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 2); out = out.offset(1); - *out = ((*in_buf) >> 18) % (1u32 << 2); + *out = ((in_buf.read_unaligned()) >> 18) % (1u32 << 2); out = out.offset(1); - *out = ((*in_buf) >> 20) % (1u32 << 2); + *out = ((in_buf.read_unaligned()) >> 20) % (1u32 << 2); out = out.offset(1); - *out = ((*in_buf) >> 22) % (1u32 << 2); + *out = ((in_buf.read_unaligned()) >> 22) % (1u32 << 2); out = 
out.offset(1); - *out = ((*in_buf) >> 24) % (1u32 << 2); + *out = ((in_buf.read_unaligned()) >> 24) % (1u32 << 2); out = out.offset(1); - *out = ((*in_buf) >> 26) % (1u32 << 2); + *out = ((in_buf.read_unaligned()) >> 26) % (1u32 << 2); out = out.offset(1); - *out = ((*in_buf) >> 28) % (1u32 << 2); + *out = ((in_buf.read_unaligned()) >> 28) % (1u32 << 2); out = out.offset(1); - *out = (*in_buf) >> 30; + *out = (in_buf.read_unaligned()) >> 30; in_buf.offset(1) } unsafe fn unpack3_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (*in_buf) % (1u32 << 3); + *out = (in_buf.read_unaligned()) % (1u32 << 3); out = out.offset(1); - *out = ((*in_buf) >> 3) % (1u32 << 3); + *out = ((in_buf.read_unaligned()) >> 3) % (1u32 << 3); out = out.offset(1); - *out = ((*in_buf) >> 6) % (1u32 << 3); + *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 3); out = out.offset(1); - *out = ((*in_buf) >> 9) % (1u32 << 3); + *out = ((in_buf.read_unaligned()) >> 9) % (1u32 << 3); out = out.offset(1); - *out = ((*in_buf) >> 12) % (1u32 << 3); + *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 3); out = out.offset(1); - *out = ((*in_buf) >> 15) % (1u32 << 3); + *out = ((in_buf.read_unaligned()) >> 15) % (1u32 << 3); out = out.offset(1); - *out = ((*in_buf) >> 18) % (1u32 << 3); + *out = ((in_buf.read_unaligned()) >> 18) % (1u32 << 3); out = out.offset(1); - *out = ((*in_buf) >> 21) % (1u32 << 3); + *out = ((in_buf.read_unaligned()) >> 21) % (1u32 << 3); out = out.offset(1); - *out = ((*in_buf) >> 24) % (1u32 << 3); + *out = ((in_buf.read_unaligned()) >> 24) % (1u32 << 3); out = out.offset(1); - *out = ((*in_buf) >> 27) % (1u32 << 3); + *out = ((in_buf.read_unaligned()) >> 27) % (1u32 << 3); out = out.offset(1); - *out = (*in_buf) >> 30; + *out = (in_buf.read_unaligned()) >> 30; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 1)) << (3 - 1); + *out |= ((in_buf.read_unaligned()) % (1u32 << 1)) << (3 - 1); out = out.offset(1); - *out = ((*in_buf) >> 1) % (1u32 << 3); + *out = ((in_buf.read_unaligned()) >> 1) % (1u32 << 3); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 3); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 3); out = out.offset(1); - *out = ((*in_buf) >> 7) % (1u32 << 3); + *out = ((in_buf.read_unaligned()) >> 7) % (1u32 << 3); out = out.offset(1); - *out = ((*in_buf) >> 10) % (1u32 << 3); + *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 3); out = out.offset(1); - *out = ((*in_buf) >> 13) % (1u32 << 3); + *out = ((in_buf.read_unaligned()) >> 13) % (1u32 << 3); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 3); + *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 3); out = out.offset(1); - *out = ((*in_buf) >> 19) % (1u32 << 3); + *out = ((in_buf.read_unaligned()) >> 19) % (1u32 << 3); out = out.offset(1); - *out = ((*in_buf) >> 22) % (1u32 << 3); + *out = ((in_buf.read_unaligned()) >> 22) % (1u32 << 3); out = out.offset(1); - *out = ((*in_buf) >> 25) % (1u32 << 3); + *out = ((in_buf.read_unaligned()) >> 25) % (1u32 << 3); out = out.offset(1); - *out = ((*in_buf) >> 28) % (1u32 << 3); + *out = ((in_buf.read_unaligned()) >> 28) % (1u32 << 3); out = out.offset(1); - *out = (*in_buf) >> 31; + *out = (in_buf.read_unaligned()) >> 31; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 2)) << (3 - 2); + *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (3 - 2); out = out.offset(1); - *out = ((*in_buf) >> 2) % (1u32 << 3); + *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 3); out = out.offset(1); - *out = ((*in_buf) >> 5) % (1u32 << 3); + 
*out = ((in_buf.read_unaligned()) >> 5) % (1u32 << 3); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 3); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 3); out = out.offset(1); - *out = ((*in_buf) >> 11) % (1u32 << 3); + *out = ((in_buf.read_unaligned()) >> 11) % (1u32 << 3); out = out.offset(1); - *out = ((*in_buf) >> 14) % (1u32 << 3); + *out = ((in_buf.read_unaligned()) >> 14) % (1u32 << 3); out = out.offset(1); - *out = ((*in_buf) >> 17) % (1u32 << 3); + *out = ((in_buf.read_unaligned()) >> 17) % (1u32 << 3); out = out.offset(1); - *out = ((*in_buf) >> 20) % (1u32 << 3); + *out = ((in_buf.read_unaligned()) >> 20) % (1u32 << 3); out = out.offset(1); - *out = ((*in_buf) >> 23) % (1u32 << 3); + *out = ((in_buf.read_unaligned()) >> 23) % (1u32 << 3); out = out.offset(1); - *out = ((*in_buf) >> 26) % (1u32 << 3); + *out = ((in_buf.read_unaligned()) >> 26) % (1u32 << 3); out = out.offset(1); - *out = (*in_buf) >> 29; + *out = (in_buf.read_unaligned()) >> 29; in_buf.offset(1) } unsafe fn unpack4_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (*in_buf) % (1u32 << 4); + *out = (in_buf.read_unaligned()) % (1u32 << 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 4); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 4); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 4); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 4); out = out.offset(1); - *out = ((*in_buf) >> 12) % (1u32 << 4); + *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 4); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 4); + *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 4); out = out.offset(1); - *out = ((*in_buf) >> 20) % (1u32 << 4); + *out = ((in_buf.read_unaligned()) >> 20) % (1u32 << 4); out = out.offset(1); - *out = ((*in_buf) >> 24) % (1u32 << 4); + *out = ((in_buf.read_unaligned()) >> 24) % (1u32 << 4); out = out.offset(1); - *out = ((*in_buf) >> 28) % (1u32 << 4); + *out = ((in_buf.read_unaligned()) >> 28) % (1u32 << 4); out = out.offset(1); in_buf = in_buf.offset(1); - *out = (*in_buf) % (1u32 << 4); + *out = (in_buf.read_unaligned()) % (1u32 << 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 4); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 4); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 4); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 4); out = out.offset(1); - *out = ((*in_buf) >> 12) % (1u32 << 4); + *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 4); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 4); + *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 4); out = out.offset(1); - *out = ((*in_buf) >> 20) % (1u32 << 4); + *out = ((in_buf.read_unaligned()) >> 20) % (1u32 << 4); out = out.offset(1); - *out = ((*in_buf) >> 24) % (1u32 << 4); + *out = ((in_buf.read_unaligned()) >> 24) % (1u32 << 4); out = out.offset(1); - *out = ((*in_buf) >> 28) % (1u32 << 4); + *out = ((in_buf.read_unaligned()) >> 28) % (1u32 << 4); out = out.offset(1); in_buf = in_buf.offset(1); - *out = (*in_buf) % (1u32 << 4); + *out = (in_buf.read_unaligned()) % (1u32 << 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 4); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 4); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 4); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 4); out = out.offset(1); - *out = ((*in_buf) >> 12) % (1u32 << 4); + *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 4); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 4); + *out = 
((in_buf.read_unaligned()) >> 16) % (1u32 << 4); out = out.offset(1); - *out = ((*in_buf) >> 20) % (1u32 << 4); + *out = ((in_buf.read_unaligned()) >> 20) % (1u32 << 4); out = out.offset(1); - *out = ((*in_buf) >> 24) % (1u32 << 4); + *out = ((in_buf.read_unaligned()) >> 24) % (1u32 << 4); out = out.offset(1); - *out = ((*in_buf) >> 28) % (1u32 << 4); + *out = ((in_buf.read_unaligned()) >> 28) % (1u32 << 4); out = out.offset(1); in_buf = in_buf.offset(1); - *out = (*in_buf) % (1u32 << 4); + *out = (in_buf.read_unaligned()) % (1u32 << 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 4); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 4); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 4); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 4); out = out.offset(1); - *out = ((*in_buf) >> 12) % (1u32 << 4); + *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 4); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 4); + *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 4); out = out.offset(1); - *out = ((*in_buf) >> 20) % (1u32 << 4); + *out = ((in_buf.read_unaligned()) >> 20) % (1u32 << 4); out = out.offset(1); - *out = ((*in_buf) >> 24) % (1u32 << 4); + *out = ((in_buf.read_unaligned()) >> 24) % (1u32 << 4); out = out.offset(1); - *out = ((*in_buf) >> 28) % (1u32 << 4); + *out = ((in_buf.read_unaligned()) >> 28) % (1u32 << 4); in_buf.offset(1) } unsafe fn unpack5_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (*in_buf) % (1u32 << 5); + *out = (in_buf.read_unaligned()) % (1u32 << 5); out = out.offset(1); - *out = ((*in_buf) >> 5) % (1u32 << 5); + *out = ((in_buf.read_unaligned()) >> 5) % (1u32 << 5); out = out.offset(1); - *out = ((*in_buf) >> 10) % (1u32 << 5); + *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 5); out = out.offset(1); - *out = ((*in_buf) >> 15) % (1u32 << 5); + *out = ((in_buf.read_unaligned()) >> 15) % (1u32 << 5); out = out.offset(1); - *out = ((*in_buf) >> 20) % (1u32 << 5); + *out = ((in_buf.read_unaligned()) >> 20) % (1u32 << 5); out = out.offset(1); - *out = ((*in_buf) >> 25) % (1u32 << 5); + *out = ((in_buf.read_unaligned()) >> 25) % (1u32 << 5); out = out.offset(1); - *out = (*in_buf) >> 30; + *out = (in_buf.read_unaligned()) >> 30; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 3)) << (5 - 3); + *out |= ((in_buf.read_unaligned()) % (1u32 << 3)) << (5 - 3); out = out.offset(1); - *out = ((*in_buf) >> 3) % (1u32 << 5); + *out = ((in_buf.read_unaligned()) >> 3) % (1u32 << 5); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 5); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 5); out = out.offset(1); - *out = ((*in_buf) >> 13) % (1u32 << 5); + *out = ((in_buf.read_unaligned()) >> 13) % (1u32 << 5); out = out.offset(1); - *out = ((*in_buf) >> 18) % (1u32 << 5); + *out = ((in_buf.read_unaligned()) >> 18) % (1u32 << 5); out = out.offset(1); - *out = ((*in_buf) >> 23) % (1u32 << 5); + *out = ((in_buf.read_unaligned()) >> 23) % (1u32 << 5); out = out.offset(1); - *out = ((*in_buf) >> 28) % (1u32 << 5); + *out = ((in_buf.read_unaligned()) >> 28) % (1u32 << 5); in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 1)) << (5 - 1); + *out |= ((in_buf.read_unaligned()) % (1u32 << 1)) << (5 - 1); out = out.offset(1); - *out = ((*in_buf) >> 1) % (1u32 << 5); + *out = ((in_buf.read_unaligned()) >> 1) % (1u32 << 5); out = out.offset(1); - *out = ((*in_buf) >> 6) % (1u32 << 5); + *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 5); out = out.offset(1); - *out = ((*in_buf) >> 11) % 
(1u32 << 5); + *out = ((in_buf.read_unaligned()) >> 11) % (1u32 << 5); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 5); + *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 5); out = out.offset(1); - *out = ((*in_buf) >> 21) % (1u32 << 5); + *out = ((in_buf.read_unaligned()) >> 21) % (1u32 << 5); out = out.offset(1); - *out = ((*in_buf) >> 26) % (1u32 << 5); + *out = ((in_buf.read_unaligned()) >> 26) % (1u32 << 5); out = out.offset(1); - *out = (*in_buf) >> 31; + *out = (in_buf.read_unaligned()) >> 31; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (5 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (5 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 5); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 5); out = out.offset(1); - *out = ((*in_buf) >> 9) % (1u32 << 5); + *out = ((in_buf.read_unaligned()) >> 9) % (1u32 << 5); out = out.offset(1); - *out = ((*in_buf) >> 14) % (1u32 << 5); + *out = ((in_buf.read_unaligned()) >> 14) % (1u32 << 5); out = out.offset(1); - *out = ((*in_buf) >> 19) % (1u32 << 5); + *out = ((in_buf.read_unaligned()) >> 19) % (1u32 << 5); out = out.offset(1); - *out = ((*in_buf) >> 24) % (1u32 << 5); + *out = ((in_buf.read_unaligned()) >> 24) % (1u32 << 5); out = out.offset(1); - *out = (*in_buf) >> 29; + *out = (in_buf.read_unaligned()) >> 29; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 2)) << (5 - 2); + *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (5 - 2); out = out.offset(1); - *out = ((*in_buf) >> 2) % (1u32 << 5); + *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 5); out = out.offset(1); - *out = ((*in_buf) >> 7) % (1u32 << 5); + *out = ((in_buf.read_unaligned()) >> 7) % (1u32 << 5); out = out.offset(1); - *out = ((*in_buf) >> 12) % (1u32 << 5); + *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 5); out = out.offset(1); - *out = ((*in_buf) >> 17) % (1u32 << 5); + *out = ((in_buf.read_unaligned()) >> 17) % (1u32 << 5); out = out.offset(1); - *out = ((*in_buf) >> 22) % (1u32 << 5); + *out = ((in_buf.read_unaligned()) >> 22) % (1u32 << 5); out = out.offset(1); - *out = (*in_buf) >> 27; + *out = (in_buf.read_unaligned()) >> 27; in_buf.offset(1) } unsafe fn unpack6_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (*in_buf) % (1u32 << 6); + *out = (in_buf.read_unaligned()) % (1u32 << 6); out = out.offset(1); - *out = ((*in_buf) >> 6) % (1u32 << 6); + *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 6); out = out.offset(1); - *out = ((*in_buf) >> 12) % (1u32 << 6); + *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 6); out = out.offset(1); - *out = ((*in_buf) >> 18) % (1u32 << 6); + *out = ((in_buf.read_unaligned()) >> 18) % (1u32 << 6); out = out.offset(1); - *out = ((*in_buf) >> 24) % (1u32 << 6); + *out = ((in_buf.read_unaligned()) >> 24) % (1u32 << 6); out = out.offset(1); - *out = (*in_buf) >> 30; + *out = (in_buf.read_unaligned()) >> 30; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (6 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (6 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 6); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 6); out = out.offset(1); - *out = ((*in_buf) >> 10) % (1u32 << 6); + *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 6); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 6); + *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 6); out = out.offset(1); - *out = ((*in_buf) >> 22) % (1u32 << 6); + *out = ((in_buf.read_unaligned()) >> 22) % (1u32 << 
6); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 2)) << (6 - 2); + *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (6 - 2); out = out.offset(1); - *out = ((*in_buf) >> 2) % (1u32 << 6); + *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 6); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 6); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 6); out = out.offset(1); - *out = ((*in_buf) >> 14) % (1u32 << 6); + *out = ((in_buf.read_unaligned()) >> 14) % (1u32 << 6); out = out.offset(1); - *out = ((*in_buf) >> 20) % (1u32 << 6); + *out = ((in_buf.read_unaligned()) >> 20) % (1u32 << 6); out = out.offset(1); - *out = (*in_buf) >> 26; + *out = (in_buf.read_unaligned()) >> 26; in_buf = in_buf.offset(1); out = out.offset(1); - *out = (*in_buf) % (1u32 << 6); + *out = (in_buf.read_unaligned()) % (1u32 << 6); out = out.offset(1); - *out = ((*in_buf) >> 6) % (1u32 << 6); + *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 6); out = out.offset(1); - *out = ((*in_buf) >> 12) % (1u32 << 6); + *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 6); out = out.offset(1); - *out = ((*in_buf) >> 18) % (1u32 << 6); + *out = ((in_buf.read_unaligned()) >> 18) % (1u32 << 6); out = out.offset(1); - *out = ((*in_buf) >> 24) % (1u32 << 6); + *out = ((in_buf.read_unaligned()) >> 24) % (1u32 << 6); out = out.offset(1); - *out = (*in_buf) >> 30; + *out = (in_buf.read_unaligned()) >> 30; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (6 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (6 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 6); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 6); out = out.offset(1); - *out = ((*in_buf) >> 10) % (1u32 << 6); + *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 6); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 6); + *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 6); out = out.offset(1); - *out = ((*in_buf) >> 22) % (1u32 << 6); + *out = ((in_buf.read_unaligned()) >> 22) % (1u32 << 6); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 2)) << (6 - 2); + *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (6 - 2); out = out.offset(1); - *out = ((*in_buf) >> 2) % (1u32 << 6); + *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 6); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 6); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 6); out = out.offset(1); - *out = ((*in_buf) >> 14) % (1u32 << 6); + *out = ((in_buf.read_unaligned()) >> 14) % (1u32 << 6); out = out.offset(1); - *out = ((*in_buf) >> 20) % (1u32 << 6); + *out = ((in_buf.read_unaligned()) >> 20) % (1u32 << 6); out = out.offset(1); - *out = (*in_buf) >> 26; + *out = (in_buf.read_unaligned()) >> 26; in_buf.offset(1) } unsafe fn unpack7_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (*in_buf) % (1u32 << 7); + *out = (in_buf.read_unaligned()) % (1u32 << 7); out = out.offset(1); - *out = ((*in_buf) >> 7) % (1u32 << 7); + *out = ((in_buf.read_unaligned()) >> 7) % (1u32 << 7); out = out.offset(1); - *out = ((*in_buf) >> 14) % (1u32 << 7); + *out = ((in_buf.read_unaligned()) >> 14) % (1u32 << 7); out = out.offset(1); - *out = ((*in_buf) >> 21) % (1u32 << 7); + *out = ((in_buf.read_unaligned()) >> 21) % (1u32 << 7); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; 
in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 3)) << (7 - 3); + *out |= ((in_buf.read_unaligned()) % (1u32 << 3)) << (7 - 3); out = out.offset(1); - *out = ((*in_buf) >> 3) % (1u32 << 7); + *out = ((in_buf.read_unaligned()) >> 3) % (1u32 << 7); out = out.offset(1); - *out = ((*in_buf) >> 10) % (1u32 << 7); + *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 7); out = out.offset(1); - *out = ((*in_buf) >> 17) % (1u32 << 7); + *out = ((in_buf.read_unaligned()) >> 17) % (1u32 << 7); out = out.offset(1); - *out = ((*in_buf) >> 24) % (1u32 << 7); + *out = ((in_buf.read_unaligned()) >> 24) % (1u32 << 7); out = out.offset(1); - *out = (*in_buf) >> 31; + *out = (in_buf.read_unaligned()) >> 31; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 6)) << (7 - 6); + *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (7 - 6); out = out.offset(1); - *out = ((*in_buf) >> 6) % (1u32 << 7); + *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 7); out = out.offset(1); - *out = ((*in_buf) >> 13) % (1u32 << 7); + *out = ((in_buf.read_unaligned()) >> 13) % (1u32 << 7); out = out.offset(1); - *out = ((*in_buf) >> 20) % (1u32 << 7); + *out = ((in_buf.read_unaligned()) >> 20) % (1u32 << 7); out = out.offset(1); - *out = (*in_buf) >> 27; + *out = (in_buf.read_unaligned()) >> 27; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 2)) << (7 - 2); + *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (7 - 2); out = out.offset(1); - *out = ((*in_buf) >> 2) % (1u32 << 7); + *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 7); out = out.offset(1); - *out = ((*in_buf) >> 9) % (1u32 << 7); + *out = ((in_buf.read_unaligned()) >> 9) % (1u32 << 7); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 7); + *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 7); out = out.offset(1); - *out = ((*in_buf) >> 23) % (1u32 << 7); + *out = ((in_buf.read_unaligned()) >> 23) % (1u32 << 7); out = out.offset(1); - *out = (*in_buf) >> 30; + *out = (in_buf.read_unaligned()) >> 30; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 5)) << (7 - 5); + *out |= ((in_buf.read_unaligned()) % (1u32 << 5)) << (7 - 5); out = out.offset(1); - *out = ((*in_buf) >> 5) % (1u32 << 7); + *out = ((in_buf.read_unaligned()) >> 5) % (1u32 << 7); out = out.offset(1); - *out = ((*in_buf) >> 12) % (1u32 << 7); + *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 7); out = out.offset(1); - *out = ((*in_buf) >> 19) % (1u32 << 7); + *out = ((in_buf.read_unaligned()) >> 19) % (1u32 << 7); out = out.offset(1); - *out = (*in_buf) >> 26; + *out = (in_buf.read_unaligned()) >> 26; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 1)) << (7 - 1); + *out |= ((in_buf.read_unaligned()) % (1u32 << 1)) << (7 - 1); out = out.offset(1); - *out = ((*in_buf) >> 1) % (1u32 << 7); + *out = ((in_buf.read_unaligned()) >> 1) % (1u32 << 7); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 7); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 7); out = out.offset(1); - *out = ((*in_buf) >> 15) % (1u32 << 7); + *out = ((in_buf.read_unaligned()) >> 15) % (1u32 << 7); out = out.offset(1); - *out = ((*in_buf) >> 22) % (1u32 << 7); + *out = ((in_buf.read_unaligned()) >> 22) % (1u32 << 7); out = out.offset(1); - *out = (*in_buf) >> 29; + *out = (in_buf.read_unaligned()) >> 29; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (7 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (7 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 7); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 
<< 7); out = out.offset(1); - *out = ((*in_buf) >> 11) % (1u32 << 7); + *out = ((in_buf.read_unaligned()) >> 11) % (1u32 << 7); out = out.offset(1); - *out = ((*in_buf) >> 18) % (1u32 << 7); + *out = ((in_buf.read_unaligned()) >> 18) % (1u32 << 7); out = out.offset(1); - *out = (*in_buf) >> 25; + *out = (in_buf.read_unaligned()) >> 25; in_buf.offset(1) } unsafe fn unpack8_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (*in_buf) % (1u32 << 8); + *out = (in_buf.read_unaligned()) % (1u32 << 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 8); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 8); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 8); + *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 8); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; out = out.offset(1); in_buf = in_buf.offset(1); - *out = (*in_buf) % (1u32 << 8); + *out = (in_buf.read_unaligned()) % (1u32 << 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 8); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 8); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 8); + *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 8); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; out = out.offset(1); in_buf = in_buf.offset(1); - *out = (*in_buf) % (1u32 << 8); + *out = (in_buf.read_unaligned()) % (1u32 << 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 8); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 8); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 8); + *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 8); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; out = out.offset(1); in_buf = in_buf.offset(1); - *out = (*in_buf) % (1u32 << 8); + *out = (in_buf.read_unaligned()) % (1u32 << 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 8); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 8); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 8); + *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 8); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; out = out.offset(1); in_buf = in_buf.offset(1); - *out = (*in_buf) % (1u32 << 8); + *out = (in_buf.read_unaligned()) % (1u32 << 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 8); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 8); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 8); + *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 8); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; out = out.offset(1); in_buf = in_buf.offset(1); - *out = (*in_buf) % (1u32 << 8); + *out = (in_buf.read_unaligned()) % (1u32 << 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 8); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 8); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 8); + *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 8); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; out = out.offset(1); in_buf = in_buf.offset(1); - *out = (*in_buf) % (1u32 << 8); + *out = (in_buf.read_unaligned()) % (1u32 << 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 8); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 8); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 8); + *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 8); out = out.offset(1); - *out 
= (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; out = out.offset(1); in_buf = in_buf.offset(1); - *out = (*in_buf) % (1u32 << 8); + *out = (in_buf.read_unaligned()) % (1u32 << 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 8); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 8); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 8); + *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 8); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf.offset(1) } unsafe fn unpack9_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (*in_buf) % (1u32 << 9); + *out = (in_buf.read_unaligned()) % (1u32 << 9); out = out.offset(1); - *out = ((*in_buf) >> 9) % (1u32 << 9); + *out = ((in_buf.read_unaligned()) >> 9) % (1u32 << 9); out = out.offset(1); - *out = ((*in_buf) >> 18) % (1u32 << 9); + *out = ((in_buf.read_unaligned()) >> 18) % (1u32 << 9); out = out.offset(1); - *out = (*in_buf) >> 27; + *out = (in_buf.read_unaligned()) >> 27; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (9 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (9 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 9); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 9); out = out.offset(1); - *out = ((*in_buf) >> 13) % (1u32 << 9); + *out = ((in_buf.read_unaligned()) >> 13) % (1u32 << 9); out = out.offset(1); - *out = ((*in_buf) >> 22) % (1u32 << 9); + *out = ((in_buf.read_unaligned()) >> 22) % (1u32 << 9); out = out.offset(1); - *out = (*in_buf) >> 31; + *out = (in_buf.read_unaligned()) >> 31; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (9 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (9 - 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 9); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 9); out = out.offset(1); - *out = ((*in_buf) >> 17) % (1u32 << 9); + *out = ((in_buf.read_unaligned()) >> 17) % (1u32 << 9); out = out.offset(1); - *out = (*in_buf) >> 26; + *out = (in_buf.read_unaligned()) >> 26; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 3)) << (9 - 3); + *out |= ((in_buf.read_unaligned()) % (1u32 << 3)) << (9 - 3); out = out.offset(1); - *out = ((*in_buf) >> 3) % (1u32 << 9); + *out = ((in_buf.read_unaligned()) >> 3) % (1u32 << 9); out = out.offset(1); - *out = ((*in_buf) >> 12) % (1u32 << 9); + *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 9); out = out.offset(1); - *out = ((*in_buf) >> 21) % (1u32 << 9); + *out = ((in_buf.read_unaligned()) >> 21) % (1u32 << 9); out = out.offset(1); - *out = (*in_buf) >> 30; + *out = (in_buf.read_unaligned()) >> 30; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 7)) << (9 - 7); + *out |= ((in_buf.read_unaligned()) % (1u32 << 7)) << (9 - 7); out = out.offset(1); - *out = ((*in_buf) >> 7) % (1u32 << 9); + *out = ((in_buf.read_unaligned()) >> 7) % (1u32 << 9); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 9); + *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 9); out = out.offset(1); - *out = (*in_buf) >> 25; + *out = (in_buf.read_unaligned()) >> 25; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 2)) << (9 - 2); + *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (9 - 2); out = out.offset(1); - *out = ((*in_buf) >> 2) % (1u32 << 9); + *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 9); out = out.offset(1); - *out = ((*in_buf) >> 11) % (1u32 << 9); + *out = ((in_buf.read_unaligned()) >> 11) % (1u32 << 9); out = out.offset(1); - *out 
= ((*in_buf) >> 20) % (1u32 << 9); + *out = ((in_buf.read_unaligned()) >> 20) % (1u32 << 9); out = out.offset(1); - *out = (*in_buf) >> 29; + *out = (in_buf.read_unaligned()) >> 29; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 6)) << (9 - 6); + *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (9 - 6); out = out.offset(1); - *out = ((*in_buf) >> 6) % (1u32 << 9); + *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 9); out = out.offset(1); - *out = ((*in_buf) >> 15) % (1u32 << 9); + *out = ((in_buf.read_unaligned()) >> 15) % (1u32 << 9); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 1)) << (9 - 1); + *out |= ((in_buf.read_unaligned()) % (1u32 << 1)) << (9 - 1); out = out.offset(1); - *out = ((*in_buf) >> 1) % (1u32 << 9); + *out = ((in_buf.read_unaligned()) >> 1) % (1u32 << 9); out = out.offset(1); - *out = ((*in_buf) >> 10) % (1u32 << 9); + *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 9); out = out.offset(1); - *out = ((*in_buf) >> 19) % (1u32 << 9); + *out = ((in_buf.read_unaligned()) >> 19) % (1u32 << 9); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 5)) << (9 - 5); + *out |= ((in_buf.read_unaligned()) % (1u32 << 5)) << (9 - 5); out = out.offset(1); - *out = ((*in_buf) >> 5) % (1u32 << 9); + *out = ((in_buf.read_unaligned()) >> 5) % (1u32 << 9); out = out.offset(1); - *out = ((*in_buf) >> 14) % (1u32 << 9); + *out = ((in_buf.read_unaligned()) >> 14) % (1u32 << 9); out = out.offset(1); - *out = (*in_buf) >> 23; + *out = (in_buf.read_unaligned()) >> 23; in_buf.offset(1) } unsafe fn unpack10_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (*in_buf) % (1u32 << 10); + *out = (in_buf.read_unaligned()) % (1u32 << 10); out = out.offset(1); - *out = ((*in_buf) >> 10) % (1u32 << 10); + *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 10); out = out.offset(1); - *out = ((*in_buf) >> 20) % (1u32 << 10); + *out = ((in_buf.read_unaligned()) >> 20) % (1u32 << 10); out = out.offset(1); - *out = (*in_buf) >> 30; + *out = (in_buf.read_unaligned()) >> 30; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (10 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (10 - 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 10); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 10); out = out.offset(1); - *out = ((*in_buf) >> 18) % (1u32 << 10); + *out = ((in_buf.read_unaligned()) >> 18) % (1u32 << 10); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 6)) << (10 - 6); + *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (10 - 6); out = out.offset(1); - *out = ((*in_buf) >> 6) % (1u32 << 10); + *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 10); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 10); + *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 10); out = out.offset(1); - *out = (*in_buf) >> 26; + *out = (in_buf.read_unaligned()) >> 26; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (10 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (10 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 10); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 10); out = out.offset(1); - *out = ((*in_buf) >> 14) % (1u32 << 10); + *out = ((in_buf.read_unaligned()) >> 14) % (1u32 
<< 10); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 2)) << (10 - 2); + *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (10 - 2); out = out.offset(1); - *out = ((*in_buf) >> 2) % (1u32 << 10); + *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 10); out = out.offset(1); - *out = ((*in_buf) >> 12) % (1u32 << 10); + *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 10); out = out.offset(1); - *out = (*in_buf) >> 22; + *out = (in_buf.read_unaligned()) >> 22; in_buf = in_buf.offset(1); out = out.offset(1); - *out = (*in_buf) % (1u32 << 10); + *out = (in_buf.read_unaligned()) % (1u32 << 10); out = out.offset(1); - *out = ((*in_buf) >> 10) % (1u32 << 10); + *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 10); out = out.offset(1); - *out = ((*in_buf) >> 20) % (1u32 << 10); + *out = ((in_buf.read_unaligned()) >> 20) % (1u32 << 10); out = out.offset(1); - *out = (*in_buf) >> 30; + *out = (in_buf.read_unaligned()) >> 30; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (10 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (10 - 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 10); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 10); out = out.offset(1); - *out = ((*in_buf) >> 18) % (1u32 << 10); + *out = ((in_buf.read_unaligned()) >> 18) % (1u32 << 10); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 6)) << (10 - 6); + *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (10 - 6); out = out.offset(1); - *out = ((*in_buf) >> 6) % (1u32 << 10); + *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 10); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 10); + *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 10); out = out.offset(1); - *out = (*in_buf) >> 26; + *out = (in_buf.read_unaligned()) >> 26; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (10 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (10 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 10); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 10); out = out.offset(1); - *out = ((*in_buf) >> 14) % (1u32 << 10); + *out = ((in_buf.read_unaligned()) >> 14) % (1u32 << 10); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 2)) << (10 - 2); + *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (10 - 2); out = out.offset(1); - *out = ((*in_buf) >> 2) % (1u32 << 10); + *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 10); out = out.offset(1); - *out = ((*in_buf) >> 12) % (1u32 << 10); + *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 10); out = out.offset(1); - *out = (*in_buf) >> 22; + *out = (in_buf.read_unaligned()) >> 22; in_buf.offset(1) } unsafe fn unpack11_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (*in_buf) % (1u32 << 11); + *out = (in_buf.read_unaligned()) % (1u32 << 11); out = out.offset(1); - *out = ((*in_buf) >> 11) % (1u32 << 11); + *out = ((in_buf.read_unaligned()) >> 11) % (1u32 << 11); out = out.offset(1); - *out = (*in_buf) >> 22; + *out = (in_buf.read_unaligned()) >> 22; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 1)) << (11 - 1); + *out |= ((in_buf.read_unaligned()) % (1u32 << 1)) << (11 - 1); out = out.offset(1); - *out = ((*in_buf) >> 1) % (1u32 << 11); + *out = 
((in_buf.read_unaligned()) >> 1) % (1u32 << 11); out = out.offset(1); - *out = ((*in_buf) >> 12) % (1u32 << 11); + *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 11); out = out.offset(1); - *out = (*in_buf) >> 23; + *out = (in_buf.read_unaligned()) >> 23; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 2)) << (11 - 2); + *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (11 - 2); out = out.offset(1); - *out = ((*in_buf) >> 2) % (1u32 << 11); + *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 11); out = out.offset(1); - *out = ((*in_buf) >> 13) % (1u32 << 11); + *out = ((in_buf.read_unaligned()) >> 13) % (1u32 << 11); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 3)) << (11 - 3); + *out |= ((in_buf.read_unaligned()) % (1u32 << 3)) << (11 - 3); out = out.offset(1); - *out = ((*in_buf) >> 3) % (1u32 << 11); + *out = ((in_buf.read_unaligned()) >> 3) % (1u32 << 11); out = out.offset(1); - *out = ((*in_buf) >> 14) % (1u32 << 11); + *out = ((in_buf.read_unaligned()) >> 14) % (1u32 << 11); out = out.offset(1); - *out = (*in_buf) >> 25; + *out = (in_buf.read_unaligned()) >> 25; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (11 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (11 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 11); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 11); out = out.offset(1); - *out = ((*in_buf) >> 15) % (1u32 << 11); + *out = ((in_buf.read_unaligned()) >> 15) % (1u32 << 11); out = out.offset(1); - *out = (*in_buf) >> 26; + *out = (in_buf.read_unaligned()) >> 26; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 5)) << (11 - 5); + *out |= ((in_buf.read_unaligned()) % (1u32 << 5)) << (11 - 5); out = out.offset(1); - *out = ((*in_buf) >> 5) % (1u32 << 11); + *out = ((in_buf.read_unaligned()) >> 5) % (1u32 << 11); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 11); + *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 11); out = out.offset(1); - *out = (*in_buf) >> 27; + *out = (in_buf.read_unaligned()) >> 27; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 6)) << (11 - 6); + *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (11 - 6); out = out.offset(1); - *out = ((*in_buf) >> 6) % (1u32 << 11); + *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 11); out = out.offset(1); - *out = ((*in_buf) >> 17) % (1u32 << 11); + *out = ((in_buf.read_unaligned()) >> 17) % (1u32 << 11); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 7)) << (11 - 7); + *out |= ((in_buf.read_unaligned()) % (1u32 << 7)) << (11 - 7); out = out.offset(1); - *out = ((*in_buf) >> 7) % (1u32 << 11); + *out = ((in_buf.read_unaligned()) >> 7) % (1u32 << 11); out = out.offset(1); - *out = ((*in_buf) >> 18) % (1u32 << 11); + *out = ((in_buf.read_unaligned()) >> 18) % (1u32 << 11); out = out.offset(1); - *out = (*in_buf) >> 29; + *out = (in_buf.read_unaligned()) >> 29; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (11 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (11 - 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 11); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 11); out = out.offset(1); - *out = ((*in_buf) >> 19) % (1u32 << 11); + *out = ((in_buf.read_unaligned()) >> 19) % (1u32 << 11); out = out.offset(1); - *out = (*in_buf) >> 30; + *out = 
(in_buf.read_unaligned()) >> 30; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 9)) << (11 - 9); + *out |= ((in_buf.read_unaligned()) % (1u32 << 9)) << (11 - 9); out = out.offset(1); - *out = ((*in_buf) >> 9) % (1u32 << 11); + *out = ((in_buf.read_unaligned()) >> 9) % (1u32 << 11); out = out.offset(1); - *out = ((*in_buf) >> 20) % (1u32 << 11); + *out = ((in_buf.read_unaligned()) >> 20) % (1u32 << 11); out = out.offset(1); - *out = (*in_buf) >> 31; + *out = (in_buf.read_unaligned()) >> 31; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 10)) << (11 - 10); + *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (11 - 10); out = out.offset(1); - *out = ((*in_buf) >> 10) % (1u32 << 11); + *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 11); out = out.offset(1); - *out = (*in_buf) >> 21; + *out = (in_buf.read_unaligned()) >> 21; in_buf.offset(1) } unsafe fn unpack12_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (*in_buf) % (1u32 << 12); + *out = (in_buf.read_unaligned()) % (1u32 << 12); out = out.offset(1); - *out = ((*in_buf) >> 12) % (1u32 << 12); + *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 12); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (12 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (12 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 12); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 12); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 12); + *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 12); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (12 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (12 - 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 12); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 12); out = out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); out = out.offset(1); - *out = (*in_buf) % (1u32 << 12); + *out = (in_buf.read_unaligned()) % (1u32 << 12); out = out.offset(1); - *out = ((*in_buf) >> 12) % (1u32 << 12); + *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 12); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (12 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (12 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 12); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 12); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 12); + *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 12); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (12 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (12 - 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 12); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 12); out = out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); out = out.offset(1); - *out = (*in_buf) % (1u32 << 12); + *out = (in_buf.read_unaligned()) % (1u32 << 12); out = out.offset(1); - *out = ((*in_buf) >> 12) % (1u32 << 12); + *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 12); out = out.offset(1); - *out = (*in_buf) 
>> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (12 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (12 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 12); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 12); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 12); + *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 12); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (12 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (12 - 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 12); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 12); out = out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); out = out.offset(1); - *out = (*in_buf) % (1u32 << 12); + *out = (in_buf.read_unaligned()) % (1u32 << 12); out = out.offset(1); - *out = ((*in_buf) >> 12) % (1u32 << 12); + *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 12); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (12 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (12 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 12); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 12); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 12); + *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 12); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (12 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (12 - 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 12); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 12); out = out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf.offset(1) } unsafe fn unpack13_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (*in_buf) % (1u32 << 13); + *out = (in_buf.read_unaligned()) % (1u32 << 13); out = out.offset(1); - *out = ((*in_buf) >> 13) % (1u32 << 13); + *out = ((in_buf.read_unaligned()) >> 13) % (1u32 << 13); out = out.offset(1); - *out = (*in_buf) >> 26; + *out = (in_buf.read_unaligned()) >> 26; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 7)) << (13 - 7); + *out |= ((in_buf.read_unaligned()) % (1u32 << 7)) << (13 - 7); out = out.offset(1); - *out = ((*in_buf) >> 7) % (1u32 << 13); + *out = ((in_buf.read_unaligned()) >> 7) % (1u32 << 13); out = out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 1)) << (13 - 1); + *out |= ((in_buf.read_unaligned()) % (1u32 << 1)) << (13 - 1); out = out.offset(1); - *out = ((*in_buf) >> 1) % (1u32 << 13); + *out = ((in_buf.read_unaligned()) >> 1) % (1u32 << 13); out = out.offset(1); - *out = ((*in_buf) >> 14) % (1u32 << 13); + *out = ((in_buf.read_unaligned()) >> 14) % (1u32 << 13); out = out.offset(1); - *out = (*in_buf) >> 27; + *out = (in_buf.read_unaligned()) >> 27; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (13 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (13 - 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 13); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 13); out = out.offset(1); - 
*out = (*in_buf) >> 21; + *out = (in_buf.read_unaligned()) >> 21; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 2)) << (13 - 2); + *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (13 - 2); out = out.offset(1); - *out = ((*in_buf) >> 2) % (1u32 << 13); + *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 13); out = out.offset(1); - *out = ((*in_buf) >> 15) % (1u32 << 13); + *out = ((in_buf.read_unaligned()) >> 15) % (1u32 << 13); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 9)) << (13 - 9); + *out |= ((in_buf.read_unaligned()) % (1u32 << 9)) << (13 - 9); out = out.offset(1); - *out = ((*in_buf) >> 9) % (1u32 << 13); + *out = ((in_buf.read_unaligned()) >> 9) % (1u32 << 13); out = out.offset(1); - *out = (*in_buf) >> 22; + *out = (in_buf.read_unaligned()) >> 22; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 3)) << (13 - 3); + *out |= ((in_buf.read_unaligned()) % (1u32 << 3)) << (13 - 3); out = out.offset(1); - *out = ((*in_buf) >> 3) % (1u32 << 13); + *out = ((in_buf.read_unaligned()) >> 3) % (1u32 << 13); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 13); + *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 13); out = out.offset(1); - *out = (*in_buf) >> 29; + *out = (in_buf.read_unaligned()) >> 29; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 10)) << (13 - 10); + *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (13 - 10); out = out.offset(1); - *out = ((*in_buf) >> 10) % (1u32 << 13); + *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 13); out = out.offset(1); - *out = (*in_buf) >> 23; + *out = (in_buf.read_unaligned()) >> 23; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (13 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (13 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 13); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 13); out = out.offset(1); - *out = ((*in_buf) >> 17) % (1u32 << 13); + *out = ((in_buf.read_unaligned()) >> 17) % (1u32 << 13); out = out.offset(1); - *out = (*in_buf) >> 30; + *out = (in_buf.read_unaligned()) >> 30; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 11)) << (13 - 11); + *out |= ((in_buf.read_unaligned()) % (1u32 << 11)) << (13 - 11); out = out.offset(1); - *out = ((*in_buf) >> 11) % (1u32 << 13); + *out = ((in_buf.read_unaligned()) >> 11) % (1u32 << 13); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 5)) << (13 - 5); + *out |= ((in_buf.read_unaligned()) % (1u32 << 5)) << (13 - 5); out = out.offset(1); - *out = ((*in_buf) >> 5) % (1u32 << 13); + *out = ((in_buf.read_unaligned()) >> 5) % (1u32 << 13); out = out.offset(1); - *out = ((*in_buf) >> 18) % (1u32 << 13); + *out = ((in_buf.read_unaligned()) >> 18) % (1u32 << 13); out = out.offset(1); - *out = (*in_buf) >> 31; + *out = (in_buf.read_unaligned()) >> 31; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 12)) << (13 - 12); + *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (13 - 12); out = out.offset(1); - *out = ((*in_buf) >> 12) % (1u32 << 13); + *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 13); out = out.offset(1); - *out = (*in_buf) >> 25; + *out = (in_buf.read_unaligned()) >> 25; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 6)) << (13 - 6); + *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (13 - 6); out = out.offset(1); - *out = 
((*in_buf) >> 6) % (1u32 << 13); + *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 13); out = out.offset(1); - *out = (*in_buf) >> 19; + *out = (in_buf.read_unaligned()) >> 19; in_buf.offset(1) } unsafe fn unpack14_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (*in_buf) % (1u32 << 14); + *out = (in_buf.read_unaligned()) % (1u32 << 14); out = out.offset(1); - *out = ((*in_buf) >> 14) % (1u32 << 14); + *out = ((in_buf.read_unaligned()) >> 14) % (1u32 << 14); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 10)) << (14 - 10); + *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (14 - 10); out = out.offset(1); - *out = ((*in_buf) >> 10) % (1u32 << 14); + *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 14); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 6)) << (14 - 6); + *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (14 - 6); out = out.offset(1); - *out = ((*in_buf) >> 6) % (1u32 << 14); + *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 14); out = out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 2)) << (14 - 2); + *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (14 - 2); out = out.offset(1); - *out = ((*in_buf) >> 2) % (1u32 << 14); + *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 14); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 14); + *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 14); out = out.offset(1); - *out = (*in_buf) >> 30; + *out = (in_buf.read_unaligned()) >> 30; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 12)) << (14 - 12); + *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (14 - 12); out = out.offset(1); - *out = ((*in_buf) >> 12) % (1u32 << 14); + *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 14); out = out.offset(1); - *out = (*in_buf) >> 26; + *out = (in_buf.read_unaligned()) >> 26; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (14 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (14 - 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 14); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 14); out = out.offset(1); - *out = (*in_buf) >> 22; + *out = (in_buf.read_unaligned()) >> 22; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (14 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (14 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 14); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 14); out = out.offset(1); - *out = (*in_buf) >> 18; + *out = (in_buf.read_unaligned()) >> 18; in_buf = in_buf.offset(1); out = out.offset(1); - *out = (*in_buf) % (1u32 << 14); + *out = (in_buf.read_unaligned()) % (1u32 << 14); out = out.offset(1); - *out = ((*in_buf) >> 14) % (1u32 << 14); + *out = ((in_buf.read_unaligned()) >> 14) % (1u32 << 14); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 10)) << (14 - 10); + *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (14 - 10); out = out.offset(1); - *out = ((*in_buf) >> 10) % (1u32 << 14); + *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 14); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - 
*out |= ((*in_buf) % (1u32 << 6)) << (14 - 6); + *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (14 - 6); out = out.offset(1); - *out = ((*in_buf) >> 6) % (1u32 << 14); + *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 14); out = out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 2)) << (14 - 2); + *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (14 - 2); out = out.offset(1); - *out = ((*in_buf) >> 2) % (1u32 << 14); + *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 14); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 14); + *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 14); out = out.offset(1); - *out = (*in_buf) >> 30; + *out = (in_buf.read_unaligned()) >> 30; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 12)) << (14 - 12); + *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (14 - 12); out = out.offset(1); - *out = ((*in_buf) >> 12) % (1u32 << 14); + *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 14); out = out.offset(1); - *out = (*in_buf) >> 26; + *out = (in_buf.read_unaligned()) >> 26; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (14 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (14 - 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 14); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 14); out = out.offset(1); - *out = (*in_buf) >> 22; + *out = (in_buf.read_unaligned()) >> 22; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (14 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (14 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 14); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 14); out = out.offset(1); - *out = (*in_buf) >> 18; + *out = (in_buf.read_unaligned()) >> 18; in_buf.offset(1) } unsafe fn unpack15_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (*in_buf) % (1u32 << 15); + *out = (in_buf.read_unaligned()) % (1u32 << 15); out = out.offset(1); - *out = ((*in_buf) >> 15) % (1u32 << 15); + *out = ((in_buf.read_unaligned()) >> 15) % (1u32 << 15); out = out.offset(1); - *out = (*in_buf) >> 30; + *out = (in_buf.read_unaligned()) >> 30; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 13)) << (15 - 13); + *out |= ((in_buf.read_unaligned()) % (1u32 << 13)) << (15 - 13); out = out.offset(1); - *out = ((*in_buf) >> 13) % (1u32 << 15); + *out = ((in_buf.read_unaligned()) >> 13) % (1u32 << 15); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 11)) << (15 - 11); + *out |= ((in_buf.read_unaligned()) % (1u32 << 11)) << (15 - 11); out = out.offset(1); - *out = ((*in_buf) >> 11) % (1u32 << 15); + *out = ((in_buf.read_unaligned()) >> 11) % (1u32 << 15); out = out.offset(1); - *out = (*in_buf) >> 26; + *out = (in_buf.read_unaligned()) >> 26; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 9)) << (15 - 9); + *out |= ((in_buf.read_unaligned()) % (1u32 << 9)) << (15 - 9); out = out.offset(1); - *out = ((*in_buf) >> 9) % (1u32 << 15); + *out = ((in_buf.read_unaligned()) >> 9) % (1u32 << 15); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 7)) << (15 - 7); + *out |= ((in_buf.read_unaligned()) % (1u32 << 7)) << (15 - 7); out = out.offset(1); - *out = ((*in_buf) >> 7) % (1u32 << 15); + *out = 
((in_buf.read_unaligned()) >> 7) % (1u32 << 15); out = out.offset(1); - *out = (*in_buf) >> 22; + *out = (in_buf.read_unaligned()) >> 22; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 5)) << (15 - 5); + *out |= ((in_buf.read_unaligned()) % (1u32 << 5)) << (15 - 5); out = out.offset(1); - *out = ((*in_buf) >> 5) % (1u32 << 15); + *out = ((in_buf.read_unaligned()) >> 5) % (1u32 << 15); out = out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 3)) << (15 - 3); + *out |= ((in_buf.read_unaligned()) % (1u32 << 3)) << (15 - 3); out = out.offset(1); - *out = ((*in_buf) >> 3) % (1u32 << 15); + *out = ((in_buf.read_unaligned()) >> 3) % (1u32 << 15); out = out.offset(1); - *out = (*in_buf) >> 18; + *out = (in_buf.read_unaligned()) >> 18; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 1)) << (15 - 1); + *out |= ((in_buf.read_unaligned()) % (1u32 << 1)) << (15 - 1); out = out.offset(1); - *out = ((*in_buf) >> 1) % (1u32 << 15); + *out = ((in_buf.read_unaligned()) >> 1) % (1u32 << 15); out = out.offset(1); - *out = ((*in_buf) >> 16) % (1u32 << 15); + *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 15); out = out.offset(1); - *out = (*in_buf) >> 31; + *out = (in_buf.read_unaligned()) >> 31; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 14)) << (15 - 14); + *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (15 - 14); out = out.offset(1); - *out = ((*in_buf) >> 14) % (1u32 << 15); + *out = ((in_buf.read_unaligned()) >> 14) % (1u32 << 15); out = out.offset(1); - *out = (*in_buf) >> 29; + *out = (in_buf.read_unaligned()) >> 29; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 12)) << (15 - 12); + *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (15 - 12); out = out.offset(1); - *out = ((*in_buf) >> 12) % (1u32 << 15); + *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 15); out = out.offset(1); - *out = (*in_buf) >> 27; + *out = (in_buf.read_unaligned()) >> 27; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 10)) << (15 - 10); + *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (15 - 10); out = out.offset(1); - *out = ((*in_buf) >> 10) % (1u32 << 15); + *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 15); out = out.offset(1); - *out = (*in_buf) >> 25; + *out = (in_buf.read_unaligned()) >> 25; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (15 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (15 - 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 15); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 15); out = out.offset(1); - *out = (*in_buf) >> 23; + *out = (in_buf.read_unaligned()) >> 23; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 6)) << (15 - 6); + *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (15 - 6); out = out.offset(1); - *out = ((*in_buf) >> 6) % (1u32 << 15); + *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 15); out = out.offset(1); - *out = (*in_buf) >> 21; + *out = (in_buf.read_unaligned()) >> 21; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (15 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (15 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 15); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 15); out = out.offset(1); - *out = (*in_buf) >> 19; + *out = (in_buf.read_unaligned()) >> 19; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 2)) << (15 - 2); + *out |= ((in_buf.read_unaligned()) % 
(1u32 << 2)) << (15 - 2); out = out.offset(1); - *out = ((*in_buf) >> 2) % (1u32 << 15); + *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 15); out = out.offset(1); - *out = (*in_buf) >> 17; + *out = (in_buf.read_unaligned()) >> 17; in_buf.offset(1) } unsafe fn unpack16_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (*in_buf) % (1u32 << 16); + *out = (in_buf.read_unaligned()) % (1u32 << 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; out = out.offset(1); in_buf = in_buf.offset(1); - *out = (*in_buf) % (1u32 << 16); + *out = (in_buf.read_unaligned()) % (1u32 << 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; out = out.offset(1); in_buf = in_buf.offset(1); - *out = (*in_buf) % (1u32 << 16); + *out = (in_buf.read_unaligned()) % (1u32 << 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; out = out.offset(1); in_buf = in_buf.offset(1); - *out = (*in_buf) % (1u32 << 16); + *out = (in_buf.read_unaligned()) % (1u32 << 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; out = out.offset(1); in_buf = in_buf.offset(1); - *out = (*in_buf) % (1u32 << 16); + *out = (in_buf.read_unaligned()) % (1u32 << 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; out = out.offset(1); in_buf = in_buf.offset(1); - *out = (*in_buf) % (1u32 << 16); + *out = (in_buf.read_unaligned()) % (1u32 << 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; out = out.offset(1); in_buf = in_buf.offset(1); - *out = (*in_buf) % (1u32 << 16); + *out = (in_buf.read_unaligned()) % (1u32 << 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; out = out.offset(1); in_buf = in_buf.offset(1); - *out = (*in_buf) % (1u32 << 16); + *out = (in_buf.read_unaligned()) % (1u32 << 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; out = out.offset(1); in_buf = in_buf.offset(1); - *out = (*in_buf) % (1u32 << 16); + *out = (in_buf.read_unaligned()) % (1u32 << 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; out = out.offset(1); in_buf = in_buf.offset(1); - *out = (*in_buf) % (1u32 << 16); + *out = (in_buf.read_unaligned()) % (1u32 << 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; out = out.offset(1); in_buf = in_buf.offset(1); - *out = (*in_buf) % (1u32 << 16); + *out = (in_buf.read_unaligned()) % (1u32 << 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; out = out.offset(1); in_buf = in_buf.offset(1); - *out = (*in_buf) % (1u32 << 16); + *out = (in_buf.read_unaligned()) % (1u32 << 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; out = out.offset(1); in_buf = in_buf.offset(1); - *out = (*in_buf) % (1u32 << 16); + *out = (in_buf.read_unaligned()) % (1u32 << 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; out = out.offset(1); in_buf = in_buf.offset(1); - *out = (*in_buf) % (1u32 << 16); + *out = (in_buf.read_unaligned()) % (1u32 << 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; out = out.offset(1); in_buf = in_buf.offset(1); - *out = (*in_buf) % (1u32 << 16); + *out = (in_buf.read_unaligned()) % (1u32 << 16); out = 
out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; out = out.offset(1); in_buf = in_buf.offset(1); - *out = (*in_buf) % (1u32 << 16); + *out = (in_buf.read_unaligned()) % (1u32 << 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf.offset(1) } unsafe fn unpack17_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (*in_buf) % (1u32 << 17); + *out = (in_buf.read_unaligned()) % (1u32 << 17); out = out.offset(1); - *out = (*in_buf) >> 17; + *out = (in_buf.read_unaligned()) >> 17; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 2)) << (17 - 2); + *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (17 - 2); out = out.offset(1); - *out = ((*in_buf) >> 2) % (1u32 << 17); + *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 17); out = out.offset(1); - *out = (*in_buf) >> 19; + *out = (in_buf.read_unaligned()) >> 19; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (17 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (17 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 17); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 17); out = out.offset(1); - *out = (*in_buf) >> 21; + *out = (in_buf.read_unaligned()) >> 21; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 6)) << (17 - 6); + *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (17 - 6); out = out.offset(1); - *out = ((*in_buf) >> 6) % (1u32 << 17); + *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 17); out = out.offset(1); - *out = (*in_buf) >> 23; + *out = (in_buf.read_unaligned()) >> 23; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (17 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (17 - 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 17); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 17); out = out.offset(1); - *out = (*in_buf) >> 25; + *out = (in_buf.read_unaligned()) >> 25; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 10)) << (17 - 10); + *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (17 - 10); out = out.offset(1); - *out = ((*in_buf) >> 10) % (1u32 << 17); + *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 17); out = out.offset(1); - *out = (*in_buf) >> 27; + *out = (in_buf.read_unaligned()) >> 27; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 12)) << (17 - 12); + *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (17 - 12); out = out.offset(1); - *out = ((*in_buf) >> 12) % (1u32 << 17); + *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 17); out = out.offset(1); - *out = (*in_buf) >> 29; + *out = (in_buf.read_unaligned()) >> 29; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 14)) << (17 - 14); + *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (17 - 14); out = out.offset(1); - *out = ((*in_buf) >> 14) % (1u32 << 17); + *out = ((in_buf.read_unaligned()) >> 14) % (1u32 << 17); out = out.offset(1); - *out = (*in_buf) >> 31; + *out = (in_buf.read_unaligned()) >> 31; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (17 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (17 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 1)) << (17 - 1); + *out |= ((in_buf.read_unaligned()) % (1u32 << 1)) << (17 - 1); out = out.offset(1); - *out = ((*in_buf) >> 1) % (1u32 << 17); + *out = ((in_buf.read_unaligned()) >> 1) % (1u32 << 17); 
out = out.offset(1); - *out = (*in_buf) >> 18; + *out = (in_buf.read_unaligned()) >> 18; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 3)) << (17 - 3); + *out |= ((in_buf.read_unaligned()) % (1u32 << 3)) << (17 - 3); out = out.offset(1); - *out = ((*in_buf) >> 3) % (1u32 << 17); + *out = ((in_buf.read_unaligned()) >> 3) % (1u32 << 17); out = out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 5)) << (17 - 5); + *out |= ((in_buf.read_unaligned()) % (1u32 << 5)) << (17 - 5); out = out.offset(1); - *out = ((*in_buf) >> 5) % (1u32 << 17); + *out = ((in_buf.read_unaligned()) >> 5) % (1u32 << 17); out = out.offset(1); - *out = (*in_buf) >> 22; + *out = (in_buf.read_unaligned()) >> 22; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 7)) << (17 - 7); + *out |= ((in_buf.read_unaligned()) % (1u32 << 7)) << (17 - 7); out = out.offset(1); - *out = ((*in_buf) >> 7) % (1u32 << 17); + *out = ((in_buf.read_unaligned()) >> 7) % (1u32 << 17); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 9)) << (17 - 9); + *out |= ((in_buf.read_unaligned()) % (1u32 << 9)) << (17 - 9); out = out.offset(1); - *out = ((*in_buf) >> 9) % (1u32 << 17); + *out = ((in_buf.read_unaligned()) >> 9) % (1u32 << 17); out = out.offset(1); - *out = (*in_buf) >> 26; + *out = (in_buf.read_unaligned()) >> 26; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 11)) << (17 - 11); + *out |= ((in_buf.read_unaligned()) % (1u32 << 11)) << (17 - 11); out = out.offset(1); - *out = ((*in_buf) >> 11) % (1u32 << 17); + *out = ((in_buf.read_unaligned()) >> 11) % (1u32 << 17); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 13)) << (17 - 13); + *out |= ((in_buf.read_unaligned()) % (1u32 << 13)) << (17 - 13); out = out.offset(1); - *out = ((*in_buf) >> 13) % (1u32 << 17); + *out = ((in_buf.read_unaligned()) >> 13) % (1u32 << 17); out = out.offset(1); - *out = (*in_buf) >> 30; + *out = (in_buf.read_unaligned()) >> 30; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 15)) << (17 - 15); + *out |= ((in_buf.read_unaligned()) % (1u32 << 15)) << (17 - 15); out = out.offset(1); - *out = (*in_buf) >> 15; + *out = (in_buf.read_unaligned()) >> 15; in_buf.offset(1) } unsafe fn unpack18_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (*in_buf) % (1u32 << 18); + *out = (in_buf.read_unaligned()) % (1u32 << 18); out = out.offset(1); - *out = (*in_buf) >> 18; + *out = (in_buf.read_unaligned()) >> 18; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (18 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (18 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 18); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 18); out = out.offset(1); - *out = (*in_buf) >> 22; + *out = (in_buf.read_unaligned()) >> 22; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (18 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (18 - 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 18); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 18); out = out.offset(1); - *out = (*in_buf) >> 26; + *out = (in_buf.read_unaligned()) >> 26; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 12)) << (18 - 12); + *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << 
(18 - 12); out = out.offset(1); - *out = ((*in_buf) >> 12) % (1u32 << 18); + *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 18); out = out.offset(1); - *out = (*in_buf) >> 30; + *out = (in_buf.read_unaligned()) >> 30; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (18 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (18 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 2)) << (18 - 2); + *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (18 - 2); out = out.offset(1); - *out = ((*in_buf) >> 2) % (1u32 << 18); + *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 18); out = out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 6)) << (18 - 6); + *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (18 - 6); out = out.offset(1); - *out = ((*in_buf) >> 6) % (1u32 << 18); + *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 18); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 10)) << (18 - 10); + *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (18 - 10); out = out.offset(1); - *out = ((*in_buf) >> 10) % (1u32 << 18); + *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 18); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 14)) << (18 - 14); + *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (18 - 14); out = out.offset(1); - *out = (*in_buf) >> 14; + *out = (in_buf.read_unaligned()) >> 14; in_buf = in_buf.offset(1); out = out.offset(1); - *out = (*in_buf) % (1u32 << 18); + *out = (in_buf.read_unaligned()) % (1u32 << 18); out = out.offset(1); - *out = (*in_buf) >> 18; + *out = (in_buf.read_unaligned()) >> 18; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (18 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (18 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 18); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 18); out = out.offset(1); - *out = (*in_buf) >> 22; + *out = (in_buf.read_unaligned()) >> 22; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (18 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (18 - 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 18); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 18); out = out.offset(1); - *out = (*in_buf) >> 26; + *out = (in_buf.read_unaligned()) >> 26; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 12)) << (18 - 12); + *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (18 - 12); out = out.offset(1); - *out = ((*in_buf) >> 12) % (1u32 << 18); + *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 18); out = out.offset(1); - *out = (*in_buf) >> 30; + *out = (in_buf.read_unaligned()) >> 30; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (18 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (18 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 2)) << (18 - 2); + *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (18 - 2); out = out.offset(1); - *out = ((*in_buf) >> 2) % (1u32 << 18); + *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 18); out = out.offset(1); - 
*out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 6)) << (18 - 6); + *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (18 - 6); out = out.offset(1); - *out = ((*in_buf) >> 6) % (1u32 << 18); + *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 18); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 10)) << (18 - 10); + *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (18 - 10); out = out.offset(1); - *out = ((*in_buf) >> 10) % (1u32 << 18); + *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 18); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 14)) << (18 - 14); + *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (18 - 14); out = out.offset(1); - *out = (*in_buf) >> 14; + *out = (in_buf.read_unaligned()) >> 14; in_buf.offset(1) } unsafe fn unpack19_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (*in_buf) % (1u32 << 19); + *out = (in_buf.read_unaligned()) % (1u32 << 19); out = out.offset(1); - *out = (*in_buf) >> 19; + *out = (in_buf.read_unaligned()) >> 19; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 6)) << (19 - 6); + *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (19 - 6); out = out.offset(1); - *out = ((*in_buf) >> 6) % (1u32 << 19); + *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 19); out = out.offset(1); - *out = (*in_buf) >> 25; + *out = (in_buf.read_unaligned()) >> 25; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 12)) << (19 - 12); + *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (19 - 12); out = out.offset(1); - *out = ((*in_buf) >> 12) % (1u32 << 19); + *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 19); out = out.offset(1); - *out = (*in_buf) >> 31; + *out = (in_buf.read_unaligned()) >> 31; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 18)) << (19 - 18); + *out |= ((in_buf.read_unaligned()) % (1u32 << 18)) << (19 - 18); out = out.offset(1); - *out = (*in_buf) >> 18; + *out = (in_buf.read_unaligned()) >> 18; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 5)) << (19 - 5); + *out |= ((in_buf.read_unaligned()) % (1u32 << 5)) << (19 - 5); out = out.offset(1); - *out = ((*in_buf) >> 5) % (1u32 << 19); + *out = ((in_buf.read_unaligned()) >> 5) % (1u32 << 19); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 11)) << (19 - 11); + *out |= ((in_buf.read_unaligned()) % (1u32 << 11)) << (19 - 11); out = out.offset(1); - *out = ((*in_buf) >> 11) % (1u32 << 19); + *out = ((in_buf.read_unaligned()) >> 11) % (1u32 << 19); out = out.offset(1); - *out = (*in_buf) >> 30; + *out = (in_buf.read_unaligned()) >> 30; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 17)) << (19 - 17); + *out |= ((in_buf.read_unaligned()) % (1u32 << 17)) << (19 - 17); out = out.offset(1); - *out = (*in_buf) >> 17; + *out = (in_buf.read_unaligned()) >> 17; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (19 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (19 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 19); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 19); out = out.offset(1); - *out = (*in_buf) >> 23; + *out = (in_buf.read_unaligned()) >> 23; in_buf = in_buf.offset(1); - *out |= 
((*in_buf) % (1u32 << 10)) << (19 - 10); + *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (19 - 10); out = out.offset(1); - *out = ((*in_buf) >> 10) % (1u32 << 19); + *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 19); out = out.offset(1); - *out = (*in_buf) >> 29; + *out = (in_buf.read_unaligned()) >> 29; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (19 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (19 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 3)) << (19 - 3); + *out |= ((in_buf.read_unaligned()) % (1u32 << 3)) << (19 - 3); out = out.offset(1); - *out = ((*in_buf) >> 3) % (1u32 << 19); + *out = ((in_buf.read_unaligned()) >> 3) % (1u32 << 19); out = out.offset(1); - *out = (*in_buf) >> 22; + *out = (in_buf.read_unaligned()) >> 22; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 9)) << (19 - 9); + *out |= ((in_buf.read_unaligned()) % (1u32 << 9)) << (19 - 9); out = out.offset(1); - *out = ((*in_buf) >> 9) % (1u32 << 19); + *out = ((in_buf.read_unaligned()) >> 9) % (1u32 << 19); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 15)) << (19 - 15); + *out |= ((in_buf.read_unaligned()) % (1u32 << 15)) << (19 - 15); out = out.offset(1); - *out = (*in_buf) >> 15; + *out = (in_buf.read_unaligned()) >> 15; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 2)) << (19 - 2); + *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (19 - 2); out = out.offset(1); - *out = ((*in_buf) >> 2) % (1u32 << 19); + *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 19); out = out.offset(1); - *out = (*in_buf) >> 21; + *out = (in_buf.read_unaligned()) >> 21; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (19 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (19 - 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 19); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 19); out = out.offset(1); - *out = (*in_buf) >> 27; + *out = (in_buf.read_unaligned()) >> 27; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 14)) << (19 - 14); + *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (19 - 14); out = out.offset(1); - *out = (*in_buf) >> 14; + *out = (in_buf.read_unaligned()) >> 14; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 1)) << (19 - 1); + *out |= ((in_buf.read_unaligned()) % (1u32 << 1)) << (19 - 1); out = out.offset(1); - *out = ((*in_buf) >> 1) % (1u32 << 19); + *out = ((in_buf.read_unaligned()) >> 1) % (1u32 << 19); out = out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 7)) << (19 - 7); + *out |= ((in_buf.read_unaligned()) % (1u32 << 7)) << (19 - 7); out = out.offset(1); - *out = ((*in_buf) >> 7) % (1u32 << 19); + *out = ((in_buf.read_unaligned()) >> 7) % (1u32 << 19); out = out.offset(1); - *out = (*in_buf) >> 26; + *out = (in_buf.read_unaligned()) >> 26; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 13)) << (19 - 13); + *out |= ((in_buf.read_unaligned()) % (1u32 << 13)) << (19 - 13); out = out.offset(1); - *out = (*in_buf) >> 13; + *out = (in_buf.read_unaligned()) >> 13; in_buf.offset(1) } unsafe fn unpack20_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (*in_buf) % (1u32 << 20); + *out = (in_buf.read_unaligned()) % (1u32 << 20); out = 
out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (20 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (20 - 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 20); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 20); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (20 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (20 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (20 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (20 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 20); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 20); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 12)) << (20 - 12); + *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (20 - 12); out = out.offset(1); - *out = (*in_buf) >> 12; + *out = (in_buf.read_unaligned()) >> 12; in_buf = in_buf.offset(1); out = out.offset(1); - *out = (*in_buf) % (1u32 << 20); + *out = (in_buf.read_unaligned()) % (1u32 << 20); out = out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (20 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (20 - 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 20); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 20); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (20 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (20 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (20 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (20 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 20); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 20); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 12)) << (20 - 12); + *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (20 - 12); out = out.offset(1); - *out = (*in_buf) >> 12; + *out = (in_buf.read_unaligned()) >> 12; in_buf = in_buf.offset(1); out = out.offset(1); - *out = (*in_buf) % (1u32 << 20); + *out = (in_buf.read_unaligned()) % (1u32 << 20); out = out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (20 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (20 - 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 20); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 20); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (20 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (20 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (20 - 4); + *out |= 
((in_buf.read_unaligned()) % (1u32 << 4)) << (20 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 20); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 20); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 12)) << (20 - 12); + *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (20 - 12); out = out.offset(1); - *out = (*in_buf) >> 12; + *out = (in_buf.read_unaligned()) >> 12; in_buf = in_buf.offset(1); out = out.offset(1); - *out = (*in_buf) % (1u32 << 20); + *out = (in_buf.read_unaligned()) % (1u32 << 20); out = out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (20 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (20 - 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 20); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 20); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (20 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (20 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (20 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (20 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 20); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 20); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 12)) << (20 - 12); + *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (20 - 12); out = out.offset(1); - *out = (*in_buf) >> 12; + *out = (in_buf.read_unaligned()) >> 12; in_buf.offset(1) } unsafe fn unpack21_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (*in_buf) % (1u32 << 21); + *out = (in_buf.read_unaligned()) % (1u32 << 21); out = out.offset(1); - *out = (*in_buf) >> 21; + *out = (in_buf.read_unaligned()) >> 21; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 10)) << (21 - 10); + *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (21 - 10); out = out.offset(1); - *out = ((*in_buf) >> 10) % (1u32 << 21); + *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 21); out = out.offset(1); - *out = (*in_buf) >> 31; + *out = (in_buf.read_unaligned()) >> 31; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 20)) << (21 - 20); + *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (21 - 20); out = out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 9)) << (21 - 9); + *out |= ((in_buf.read_unaligned()) % (1u32 << 9)) << (21 - 9); out = out.offset(1); - *out = ((*in_buf) >> 9) % (1u32 << 21); + *out = ((in_buf.read_unaligned()) >> 9) % (1u32 << 21); out = out.offset(1); - *out = (*in_buf) >> 30; + *out = (in_buf.read_unaligned()) >> 30; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 19)) << (21 - 19); + *out |= ((in_buf.read_unaligned()) % (1u32 << 19)) << (21 - 19); out = out.offset(1); - *out = (*in_buf) >> 19; + *out = (in_buf.read_unaligned()) >> 19; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (21 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (21 - 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 21); + *out 
= ((in_buf.read_unaligned()) >> 8) % (1u32 << 21); out = out.offset(1); - *out = (*in_buf) >> 29; + *out = (in_buf.read_unaligned()) >> 29; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 18)) << (21 - 18); + *out |= ((in_buf.read_unaligned()) % (1u32 << 18)) << (21 - 18); out = out.offset(1); - *out = (*in_buf) >> 18; + *out = (in_buf.read_unaligned()) >> 18; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 7)) << (21 - 7); + *out |= ((in_buf.read_unaligned()) % (1u32 << 7)) << (21 - 7); out = out.offset(1); - *out = ((*in_buf) >> 7) % (1u32 << 21); + *out = ((in_buf.read_unaligned()) >> 7) % (1u32 << 21); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 17)) << (21 - 17); + *out |= ((in_buf.read_unaligned()) % (1u32 << 17)) << (21 - 17); out = out.offset(1); - *out = (*in_buf) >> 17; + *out = (in_buf.read_unaligned()) >> 17; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 6)) << (21 - 6); + *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (21 - 6); out = out.offset(1); - *out = ((*in_buf) >> 6) % (1u32 << 21); + *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 21); out = out.offset(1); - *out = (*in_buf) >> 27; + *out = (in_buf.read_unaligned()) >> 27; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (21 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (21 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 5)) << (21 - 5); + *out |= ((in_buf.read_unaligned()) % (1u32 << 5)) << (21 - 5); out = out.offset(1); - *out = ((*in_buf) >> 5) % (1u32 << 21); + *out = ((in_buf.read_unaligned()) >> 5) % (1u32 << 21); out = out.offset(1); - *out = (*in_buf) >> 26; + *out = (in_buf.read_unaligned()) >> 26; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 15)) << (21 - 15); + *out |= ((in_buf.read_unaligned()) % (1u32 << 15)) << (21 - 15); out = out.offset(1); - *out = (*in_buf) >> 15; + *out = (in_buf.read_unaligned()) >> 15; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (21 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (21 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 21); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 21); out = out.offset(1); - *out = (*in_buf) >> 25; + *out = (in_buf.read_unaligned()) >> 25; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 14)) << (21 - 14); + *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (21 - 14); out = out.offset(1); - *out = (*in_buf) >> 14; + *out = (in_buf.read_unaligned()) >> 14; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 3)) << (21 - 3); + *out |= ((in_buf.read_unaligned()) % (1u32 << 3)) << (21 - 3); out = out.offset(1); - *out = ((*in_buf) >> 3) % (1u32 << 21); + *out = ((in_buf.read_unaligned()) >> 3) % (1u32 << 21); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 13)) << (21 - 13); + *out |= ((in_buf.read_unaligned()) % (1u32 << 13)) << (21 - 13); out = out.offset(1); - *out = (*in_buf) >> 13; + *out = (in_buf.read_unaligned()) >> 13; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 2)) << (21 - 2); + *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (21 - 2); out = out.offset(1); - *out = ((*in_buf) >> 2) % (1u32 << 21); + *out = ((in_buf.read_unaligned()) >> 2) % (1u32 
<< 21); out = out.offset(1); - *out = (*in_buf) >> 23; + *out = (in_buf.read_unaligned()) >> 23; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 12)) << (21 - 12); + *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (21 - 12); out = out.offset(1); - *out = (*in_buf) >> 12; + *out = (in_buf.read_unaligned()) >> 12; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 1)) << (21 - 1); + *out |= ((in_buf.read_unaligned()) % (1u32 << 1)) << (21 - 1); out = out.offset(1); - *out = ((*in_buf) >> 1) % (1u32 << 21); + *out = ((in_buf.read_unaligned()) >> 1) % (1u32 << 21); out = out.offset(1); - *out = (*in_buf) >> 22; + *out = (in_buf.read_unaligned()) >> 22; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 11)) << (21 - 11); + *out |= ((in_buf.read_unaligned()) % (1u32 << 11)) << (21 - 11); out = out.offset(1); - *out = (*in_buf) >> 11; + *out = (in_buf.read_unaligned()) >> 11; in_buf.offset(1) } unsafe fn unpack22_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (*in_buf) % (1u32 << 22); + *out = (in_buf.read_unaligned()) % (1u32 << 22); out = out.offset(1); - *out = (*in_buf) >> 22; + *out = (in_buf.read_unaligned()) >> 22; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 12)) << (22 - 12); + *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (22 - 12); out = out.offset(1); - *out = (*in_buf) >> 12; + *out = (in_buf.read_unaligned()) >> 12; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 2)) << (22 - 2); + *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (22 - 2); out = out.offset(1); - *out = ((*in_buf) >> 2) % (1u32 << 22); + *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 22); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 14)) << (22 - 14); + *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (22 - 14); out = out.offset(1); - *out = (*in_buf) >> 14; + *out = (in_buf.read_unaligned()) >> 14; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (22 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (22 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 22); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 22); out = out.offset(1); - *out = (*in_buf) >> 26; + *out = (in_buf.read_unaligned()) >> 26; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (22 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (22 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 6)) << (22 - 6); + *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (22 - 6); out = out.offset(1); - *out = ((*in_buf) >> 6) % (1u32 << 22); + *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 22); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 18)) << (22 - 18); + *out |= ((in_buf.read_unaligned()) % (1u32 << 18)) << (22 - 18); out = out.offset(1); - *out = (*in_buf) >> 18; + *out = (in_buf.read_unaligned()) >> 18; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (22 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (22 - 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 22); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 22); out = out.offset(1); - *out = (*in_buf) >> 30; + *out = (in_buf.read_unaligned()) >> 30; in_buf = 
in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 20)) << (22 - 20); + *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (22 - 20); out = out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 10)) << (22 - 10); + *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (22 - 10); out = out.offset(1); - *out = (*in_buf) >> 10; + *out = (in_buf.read_unaligned()) >> 10; in_buf = in_buf.offset(1); out = out.offset(1); - *out = (*in_buf) % (1u32 << 22); + *out = (in_buf.read_unaligned()) % (1u32 << 22); out = out.offset(1); - *out = (*in_buf) >> 22; + *out = (in_buf.read_unaligned()) >> 22; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 12)) << (22 - 12); + *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (22 - 12); out = out.offset(1); - *out = (*in_buf) >> 12; + *out = (in_buf.read_unaligned()) >> 12; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 2)) << (22 - 2); + *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (22 - 2); out = out.offset(1); - *out = ((*in_buf) >> 2) % (1u32 << 22); + *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 22); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 14)) << (22 - 14); + *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (22 - 14); out = out.offset(1); - *out = (*in_buf) >> 14; + *out = (in_buf.read_unaligned()) >> 14; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (22 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (22 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 22); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 22); out = out.offset(1); - *out = (*in_buf) >> 26; + *out = (in_buf.read_unaligned()) >> 26; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (22 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (22 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 6)) << (22 - 6); + *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (22 - 6); out = out.offset(1); - *out = ((*in_buf) >> 6) % (1u32 << 22); + *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 22); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 18)) << (22 - 18); + *out |= ((in_buf.read_unaligned()) % (1u32 << 18)) << (22 - 18); out = out.offset(1); - *out = (*in_buf) >> 18; + *out = (in_buf.read_unaligned()) >> 18; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (22 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (22 - 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 22); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 22); out = out.offset(1); - *out = (*in_buf) >> 30; + *out = (in_buf.read_unaligned()) >> 30; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 20)) << (22 - 20); + *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (22 - 20); out = out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 10)) << (22 - 10); + *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (22 - 10); out = out.offset(1); - *out = (*in_buf) >> 10; + *out = (in_buf.read_unaligned()) >> 10; in_buf.offset(1) } unsafe fn unpack23_32(mut in_buf: *const u32, 
mut out: *mut u32) -> *const u32 { - *out = (*in_buf) % (1u32 << 23); + *out = (in_buf.read_unaligned()) % (1u32 << 23); out = out.offset(1); - *out = (*in_buf) >> 23; + *out = (in_buf.read_unaligned()) >> 23; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 14)) << (23 - 14); + *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (23 - 14); out = out.offset(1); - *out = (*in_buf) >> 14; + *out = (in_buf.read_unaligned()) >> 14; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 5)) << (23 - 5); + *out |= ((in_buf.read_unaligned()) % (1u32 << 5)) << (23 - 5); out = out.offset(1); - *out = ((*in_buf) >> 5) % (1u32 << 23); + *out = ((in_buf.read_unaligned()) >> 5) % (1u32 << 23); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 19)) << (23 - 19); + *out |= ((in_buf.read_unaligned()) % (1u32 << 19)) << (23 - 19); out = out.offset(1); - *out = (*in_buf) >> 19; + *out = (in_buf.read_unaligned()) >> 19; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 10)) << (23 - 10); + *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (23 - 10); out = out.offset(1); - *out = (*in_buf) >> 10; + *out = (in_buf.read_unaligned()) >> 10; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 1)) << (23 - 1); + *out |= ((in_buf.read_unaligned()) % (1u32 << 1)) << (23 - 1); out = out.offset(1); - *out = ((*in_buf) >> 1) % (1u32 << 23); + *out = ((in_buf.read_unaligned()) >> 1) % (1u32 << 23); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 15)) << (23 - 15); + *out |= ((in_buf.read_unaligned()) % (1u32 << 15)) << (23 - 15); out = out.offset(1); - *out = (*in_buf) >> 15; + *out = (in_buf.read_unaligned()) >> 15; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 6)) << (23 - 6); + *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (23 - 6); out = out.offset(1); - *out = ((*in_buf) >> 6) % (1u32 << 23); + *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 23); out = out.offset(1); - *out = (*in_buf) >> 29; + *out = (in_buf.read_unaligned()) >> 29; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 20)) << (23 - 20); + *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (23 - 20); out = out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 11)) << (23 - 11); + *out |= ((in_buf.read_unaligned()) % (1u32 << 11)) << (23 - 11); out = out.offset(1); - *out = (*in_buf) >> 11; + *out = (in_buf.read_unaligned()) >> 11; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 2)) << (23 - 2); + *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (23 - 2); out = out.offset(1); - *out = ((*in_buf) >> 2) % (1u32 << 23); + *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 23); out = out.offset(1); - *out = (*in_buf) >> 25; + *out = (in_buf.read_unaligned()) >> 25; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (23 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (23 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 7)) << (23 - 7); + *out |= ((in_buf.read_unaligned()) % (1u32 << 7)) << (23 - 7); out = out.offset(1); - *out = ((*in_buf) >> 7) % (1u32 << 23); + *out = ((in_buf.read_unaligned()) >> 7) % (1u32 << 23); out = out.offset(1); - *out = (*in_buf) 
>> 30; + *out = (in_buf.read_unaligned()) >> 30; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 21)) << (23 - 21); + *out |= ((in_buf.read_unaligned()) % (1u32 << 21)) << (23 - 21); out = out.offset(1); - *out = (*in_buf) >> 21; + *out = (in_buf.read_unaligned()) >> 21; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 12)) << (23 - 12); + *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (23 - 12); out = out.offset(1); - *out = (*in_buf) >> 12; + *out = (in_buf.read_unaligned()) >> 12; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 3)) << (23 - 3); + *out |= ((in_buf.read_unaligned()) % (1u32 << 3)) << (23 - 3); out = out.offset(1); - *out = ((*in_buf) >> 3) % (1u32 << 23); + *out = ((in_buf.read_unaligned()) >> 3) % (1u32 << 23); out = out.offset(1); - *out = (*in_buf) >> 26; + *out = (in_buf.read_unaligned()) >> 26; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 17)) << (23 - 17); + *out |= ((in_buf.read_unaligned()) % (1u32 << 17)) << (23 - 17); out = out.offset(1); - *out = (*in_buf) >> 17; + *out = (in_buf.read_unaligned()) >> 17; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (23 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (23 - 8); out = out.offset(1); - *out = ((*in_buf) >> 8) % (1u32 << 23); + *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 23); out = out.offset(1); - *out = (*in_buf) >> 31; + *out = (in_buf.read_unaligned()) >> 31; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 22)) << (23 - 22); + *out |= ((in_buf.read_unaligned()) % (1u32 << 22)) << (23 - 22); out = out.offset(1); - *out = (*in_buf) >> 22; + *out = (in_buf.read_unaligned()) >> 22; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 13)) << (23 - 13); + *out |= ((in_buf.read_unaligned()) % (1u32 << 13)) << (23 - 13); out = out.offset(1); - *out = (*in_buf) >> 13; + *out = (in_buf.read_unaligned()) >> 13; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (23 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (23 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 23); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 23); out = out.offset(1); - *out = (*in_buf) >> 27; + *out = (in_buf.read_unaligned()) >> 27; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 18)) << (23 - 18); + *out |= ((in_buf.read_unaligned()) % (1u32 << 18)) << (23 - 18); out = out.offset(1); - *out = (*in_buf) >> 18; + *out = (in_buf.read_unaligned()) >> 18; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 9)) << (23 - 9); + *out |= ((in_buf.read_unaligned()) % (1u32 << 9)) << (23 - 9); out = out.offset(1); - *out = (*in_buf) >> 9; + *out = (in_buf.read_unaligned()) >> 9; in_buf.offset(1) } unsafe fn unpack24_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (*in_buf) % (1u32 << 24); + *out = (in_buf.read_unaligned()) % (1u32 << 24); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (24 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (24 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (24 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (24 - 8); out = out.offset(1); - *out = (*in_buf) >> 8; + *out = (in_buf.read_unaligned()) >> 8; in_buf = in_buf.offset(1); out = out.offset(1); - *out = (*in_buf) % (1u32 << 24); 
+ *out = (in_buf.read_unaligned()) % (1u32 << 24); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (24 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (24 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (24 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (24 - 8); out = out.offset(1); - *out = (*in_buf) >> 8; + *out = (in_buf.read_unaligned()) >> 8; in_buf = in_buf.offset(1); out = out.offset(1); - *out = (*in_buf) % (1u32 << 24); + *out = (in_buf.read_unaligned()) % (1u32 << 24); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (24 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (24 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (24 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (24 - 8); out = out.offset(1); - *out = (*in_buf) >> 8; + *out = (in_buf.read_unaligned()) >> 8; in_buf = in_buf.offset(1); out = out.offset(1); - *out = (*in_buf) % (1u32 << 24); + *out = (in_buf.read_unaligned()) % (1u32 << 24); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (24 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (24 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (24 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (24 - 8); out = out.offset(1); - *out = (*in_buf) >> 8; + *out = (in_buf.read_unaligned()) >> 8; in_buf = in_buf.offset(1); out = out.offset(1); - *out = (*in_buf) % (1u32 << 24); + *out = (in_buf.read_unaligned()) % (1u32 << 24); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (24 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (24 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (24 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (24 - 8); out = out.offset(1); - *out = (*in_buf) >> 8; + *out = (in_buf.read_unaligned()) >> 8; in_buf = in_buf.offset(1); out = out.offset(1); - *out = (*in_buf) % (1u32 << 24); + *out = (in_buf.read_unaligned()) % (1u32 << 24); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (24 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (24 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (24 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (24 - 8); out = out.offset(1); - *out = (*in_buf) >> 8; + *out = (in_buf.read_unaligned()) >> 8; in_buf = in_buf.offset(1); out = out.offset(1); - *out = (*in_buf) % (1u32 << 24); + *out = (in_buf.read_unaligned()) % (1u32 << 24); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = 
(in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (24 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (24 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (24 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (24 - 8); out = out.offset(1); - *out = (*in_buf) >> 8; + *out = (in_buf.read_unaligned()) >> 8; in_buf = in_buf.offset(1); out = out.offset(1); - *out = (*in_buf) % (1u32 << 24); + *out = (in_buf.read_unaligned()) % (1u32 << 24); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (24 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (24 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (24 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (24 - 8); out = out.offset(1); - *out = (*in_buf) >> 8; + *out = (in_buf.read_unaligned()) >> 8; in_buf.offset(1) } unsafe fn unpack25_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (*in_buf) % (1u32 << 25); + *out = (in_buf.read_unaligned()) % (1u32 << 25); out = out.offset(1); - *out = (*in_buf) >> 25; + *out = (in_buf.read_unaligned()) >> 25; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 18)) << (25 - 18); + *out |= ((in_buf.read_unaligned()) % (1u32 << 18)) << (25 - 18); out = out.offset(1); - *out = (*in_buf) >> 18; + *out = (in_buf.read_unaligned()) >> 18; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 11)) << (25 - 11); + *out |= ((in_buf.read_unaligned()) % (1u32 << 11)) << (25 - 11); out = out.offset(1); - *out = (*in_buf) >> 11; + *out = (in_buf.read_unaligned()) >> 11; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (25 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (25 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 25); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 25); out = out.offset(1); - *out = (*in_buf) >> 29; + *out = (in_buf.read_unaligned()) >> 29; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 22)) << (25 - 22); + *out |= ((in_buf.read_unaligned()) % (1u32 << 22)) << (25 - 22); out = out.offset(1); - *out = (*in_buf) >> 22; + *out = (in_buf.read_unaligned()) >> 22; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 15)) << (25 - 15); + *out |= ((in_buf.read_unaligned()) % (1u32 << 15)) << (25 - 15); out = out.offset(1); - *out = (*in_buf) >> 15; + *out = (in_buf.read_unaligned()) >> 15; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (25 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (25 - 8); out = out.offset(1); - *out = (*in_buf) >> 8; + *out = (in_buf.read_unaligned()) >> 8; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 1)) << (25 - 1); + *out |= ((in_buf.read_unaligned()) % (1u32 << 1)) << (25 - 1); out = out.offset(1); - *out = ((*in_buf) >> 1) % (1u32 << 25); + *out = ((in_buf.read_unaligned()) >> 1) % (1u32 << 25); out = out.offset(1); - *out = (*in_buf) >> 26; + *out = (in_buf.read_unaligned()) >> 26; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 19)) << (25 - 19); + *out |= ((in_buf.read_unaligned()) % (1u32 << 19)) << (25 - 19); out = out.offset(1); - *out = (*in_buf) >> 19; + *out = (in_buf.read_unaligned()) >> 19; 
in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 12)) << (25 - 12); + *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (25 - 12); out = out.offset(1); - *out = (*in_buf) >> 12; + *out = (in_buf.read_unaligned()) >> 12; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 5)) << (25 - 5); + *out |= ((in_buf.read_unaligned()) % (1u32 << 5)) << (25 - 5); out = out.offset(1); - *out = ((*in_buf) >> 5) % (1u32 << 25); + *out = ((in_buf.read_unaligned()) >> 5) % (1u32 << 25); out = out.offset(1); - *out = (*in_buf) >> 30; + *out = (in_buf.read_unaligned()) >> 30; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 23)) << (25 - 23); + *out |= ((in_buf.read_unaligned()) % (1u32 << 23)) << (25 - 23); out = out.offset(1); - *out = (*in_buf) >> 23; + *out = (in_buf.read_unaligned()) >> 23; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (25 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (25 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 9)) << (25 - 9); + *out |= ((in_buf.read_unaligned()) % (1u32 << 9)) << (25 - 9); out = out.offset(1); - *out = (*in_buf) >> 9; + *out = (in_buf.read_unaligned()) >> 9; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 2)) << (25 - 2); + *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (25 - 2); out = out.offset(1); - *out = ((*in_buf) >> 2) % (1u32 << 25); + *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 25); out = out.offset(1); - *out = (*in_buf) >> 27; + *out = (in_buf.read_unaligned()) >> 27; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 20)) << (25 - 20); + *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (25 - 20); out = out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 13)) << (25 - 13); + *out |= ((in_buf.read_unaligned()) % (1u32 << 13)) << (25 - 13); out = out.offset(1); - *out = (*in_buf) >> 13; + *out = (in_buf.read_unaligned()) >> 13; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 6)) << (25 - 6); + *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (25 - 6); out = out.offset(1); - *out = ((*in_buf) >> 6) % (1u32 << 25); + *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 25); out = out.offset(1); - *out = (*in_buf) >> 31; + *out = (in_buf.read_unaligned()) >> 31; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 24)) << (25 - 24); + *out |= ((in_buf.read_unaligned()) % (1u32 << 24)) << (25 - 24); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 17)) << (25 - 17); + *out |= ((in_buf.read_unaligned()) % (1u32 << 17)) << (25 - 17); out = out.offset(1); - *out = (*in_buf) >> 17; + *out = (in_buf.read_unaligned()) >> 17; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 10)) << (25 - 10); + *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (25 - 10); out = out.offset(1); - *out = (*in_buf) >> 10; + *out = (in_buf.read_unaligned()) >> 10; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 3)) << (25 - 3); + *out |= ((in_buf.read_unaligned()) % (1u32 << 3)) << (25 - 3); out = out.offset(1); - *out = ((*in_buf) >> 3) % (1u32 << 25); + *out = ((in_buf.read_unaligned()) >> 3) % (1u32 << 25); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % 
(1u32 << 21)) << (25 - 21); + *out |= ((in_buf.read_unaligned()) % (1u32 << 21)) << (25 - 21); out = out.offset(1); - *out = (*in_buf) >> 21; + *out = (in_buf.read_unaligned()) >> 21; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 14)) << (25 - 14); + *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (25 - 14); out = out.offset(1); - *out = (*in_buf) >> 14; + *out = (in_buf.read_unaligned()) >> 14; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 7)) << (25 - 7); + *out |= ((in_buf.read_unaligned()) % (1u32 << 7)) << (25 - 7); out = out.offset(1); - *out = (*in_buf) >> 7; + *out = (in_buf.read_unaligned()) >> 7; in_buf.offset(1) } unsafe fn unpack26_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (*in_buf) % (1u32 << 26); + *out = (in_buf.read_unaligned()) % (1u32 << 26); out = out.offset(1); - *out = (*in_buf) >> 26; + *out = (in_buf.read_unaligned()) >> 26; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 20)) << (26 - 20); + *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (26 - 20); out = out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 14)) << (26 - 14); + *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (26 - 14); out = out.offset(1); - *out = (*in_buf) >> 14; + *out = (in_buf.read_unaligned()) >> 14; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (26 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (26 - 8); out = out.offset(1); - *out = (*in_buf) >> 8; + *out = (in_buf.read_unaligned()) >> 8; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 2)) << (26 - 2); + *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (26 - 2); out = out.offset(1); - *out = ((*in_buf) >> 2) % (1u32 << 26); + *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 26); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 22)) << (26 - 22); + *out |= ((in_buf.read_unaligned()) % (1u32 << 22)) << (26 - 22); out = out.offset(1); - *out = (*in_buf) >> 22; + *out = (in_buf.read_unaligned()) >> 22; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (26 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (26 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 10)) << (26 - 10); + *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (26 - 10); out = out.offset(1); - *out = (*in_buf) >> 10; + *out = (in_buf.read_unaligned()) >> 10; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (26 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (26 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 26); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 26); out = out.offset(1); - *out = (*in_buf) >> 30; + *out = (in_buf.read_unaligned()) >> 30; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 24)) << (26 - 24); + *out |= ((in_buf.read_unaligned()) % (1u32 << 24)) << (26 - 24); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 18)) << (26 - 18); + *out |= ((in_buf.read_unaligned()) % (1u32 << 18)) << (26 - 18); out = out.offset(1); - *out = (*in_buf) >> 18; + *out = (in_buf.read_unaligned()) >> 18; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 12)) 
<< (26 - 12); + *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (26 - 12); out = out.offset(1); - *out = (*in_buf) >> 12; + *out = (in_buf.read_unaligned()) >> 12; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 6)) << (26 - 6); + *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (26 - 6); out = out.offset(1); - *out = (*in_buf) >> 6; + *out = (in_buf.read_unaligned()) >> 6; in_buf = in_buf.offset(1); out = out.offset(1); - *out = (*in_buf) % (1u32 << 26); + *out = (in_buf.read_unaligned()) % (1u32 << 26); out = out.offset(1); - *out = (*in_buf) >> 26; + *out = (in_buf.read_unaligned()) >> 26; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 20)) << (26 - 20); + *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (26 - 20); out = out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 14)) << (26 - 14); + *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (26 - 14); out = out.offset(1); - *out = (*in_buf) >> 14; + *out = (in_buf.read_unaligned()) >> 14; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (26 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (26 - 8); out = out.offset(1); - *out = (*in_buf) >> 8; + *out = (in_buf.read_unaligned()) >> 8; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 2)) << (26 - 2); + *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (26 - 2); out = out.offset(1); - *out = ((*in_buf) >> 2) % (1u32 << 26); + *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 26); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 22)) << (26 - 22); + *out |= ((in_buf.read_unaligned()) % (1u32 << 22)) << (26 - 22); out = out.offset(1); - *out = (*in_buf) >> 22; + *out = (in_buf.read_unaligned()) >> 22; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (26 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (26 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 10)) << (26 - 10); + *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (26 - 10); out = out.offset(1); - *out = (*in_buf) >> 10; + *out = (in_buf.read_unaligned()) >> 10; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (26 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (26 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 26); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 26); out = out.offset(1); - *out = (*in_buf) >> 30; + *out = (in_buf.read_unaligned()) >> 30; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 24)) << (26 - 24); + *out |= ((in_buf.read_unaligned()) % (1u32 << 24)) << (26 - 24); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 18)) << (26 - 18); + *out |= ((in_buf.read_unaligned()) % (1u32 << 18)) << (26 - 18); out = out.offset(1); - *out = (*in_buf) >> 18; + *out = (in_buf.read_unaligned()) >> 18; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 12)) << (26 - 12); + *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (26 - 12); out = out.offset(1); - *out = (*in_buf) >> 12; + *out = (in_buf.read_unaligned()) >> 12; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 6)) << (26 - 6); + *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << 
(26 - 6); out = out.offset(1); - *out = (*in_buf) >> 6; + *out = (in_buf.read_unaligned()) >> 6; in_buf.offset(1) } unsafe fn unpack27_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (*in_buf) % (1u32 << 27); + *out = (in_buf.read_unaligned()) % (1u32 << 27); out = out.offset(1); - *out = (*in_buf) >> 27; + *out = (in_buf.read_unaligned()) >> 27; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 22)) << (27 - 22); + *out |= ((in_buf.read_unaligned()) % (1u32 << 22)) << (27 - 22); out = out.offset(1); - *out = (*in_buf) >> 22; + *out = (in_buf.read_unaligned()) >> 22; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 17)) << (27 - 17); + *out |= ((in_buf.read_unaligned()) % (1u32 << 17)) << (27 - 17); out = out.offset(1); - *out = (*in_buf) >> 17; + *out = (in_buf.read_unaligned()) >> 17; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 12)) << (27 - 12); + *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (27 - 12); out = out.offset(1); - *out = (*in_buf) >> 12; + *out = (in_buf.read_unaligned()) >> 12; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 7)) << (27 - 7); + *out |= ((in_buf.read_unaligned()) % (1u32 << 7)) << (27 - 7); out = out.offset(1); - *out = (*in_buf) >> 7; + *out = (in_buf.read_unaligned()) >> 7; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 2)) << (27 - 2); + *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (27 - 2); out = out.offset(1); - *out = ((*in_buf) >> 2) % (1u32 << 27); + *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 27); out = out.offset(1); - *out = (*in_buf) >> 29; + *out = (in_buf.read_unaligned()) >> 29; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 24)) << (27 - 24); + *out |= ((in_buf.read_unaligned()) % (1u32 << 24)) << (27 - 24); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 19)) << (27 - 19); + *out |= ((in_buf.read_unaligned()) % (1u32 << 19)) << (27 - 19); out = out.offset(1); - *out = (*in_buf) >> 19; + *out = (in_buf.read_unaligned()) >> 19; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 14)) << (27 - 14); + *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (27 - 14); out = out.offset(1); - *out = (*in_buf) >> 14; + *out = (in_buf.read_unaligned()) >> 14; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 9)) << (27 - 9); + *out |= ((in_buf.read_unaligned()) % (1u32 << 9)) << (27 - 9); out = out.offset(1); - *out = (*in_buf) >> 9; + *out = (in_buf.read_unaligned()) >> 9; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (27 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (27 - 4); out = out.offset(1); - *out = ((*in_buf) >> 4) % (1u32 << 27); + *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 27); out = out.offset(1); - *out = (*in_buf) >> 31; + *out = (in_buf.read_unaligned()) >> 31; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 26)) << (27 - 26); + *out |= ((in_buf.read_unaligned()) % (1u32 << 26)) << (27 - 26); out = out.offset(1); - *out = (*in_buf) >> 26; + *out = (in_buf.read_unaligned()) >> 26; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 21)) << (27 - 21); + *out |= ((in_buf.read_unaligned()) % (1u32 << 21)) << (27 - 21); out = out.offset(1); - *out = (*in_buf) >> 21; + *out = (in_buf.read_unaligned()) >> 21; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (27 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (27 - 16); out = 
out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 11)) << (27 - 11); + *out |= ((in_buf.read_unaligned()) % (1u32 << 11)) << (27 - 11); out = out.offset(1); - *out = (*in_buf) >> 11; + *out = (in_buf.read_unaligned()) >> 11; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 6)) << (27 - 6); + *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (27 - 6); out = out.offset(1); - *out = (*in_buf) >> 6; + *out = (in_buf.read_unaligned()) >> 6; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 1)) << (27 - 1); + *out |= ((in_buf.read_unaligned()) % (1u32 << 1)) << (27 - 1); out = out.offset(1); - *out = ((*in_buf) >> 1) % (1u32 << 27); + *out = ((in_buf.read_unaligned()) >> 1) % (1u32 << 27); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 23)) << (27 - 23); + *out |= ((in_buf.read_unaligned()) % (1u32 << 23)) << (27 - 23); out = out.offset(1); - *out = (*in_buf) >> 23; + *out = (in_buf.read_unaligned()) >> 23; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 18)) << (27 - 18); + *out |= ((in_buf.read_unaligned()) % (1u32 << 18)) << (27 - 18); out = out.offset(1); - *out = (*in_buf) >> 18; + *out = (in_buf.read_unaligned()) >> 18; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 13)) << (27 - 13); + *out |= ((in_buf.read_unaligned()) % (1u32 << 13)) << (27 - 13); out = out.offset(1); - *out = (*in_buf) >> 13; + *out = (in_buf.read_unaligned()) >> 13; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (27 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (27 - 8); out = out.offset(1); - *out = (*in_buf) >> 8; + *out = (in_buf.read_unaligned()) >> 8; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 3)) << (27 - 3); + *out |= ((in_buf.read_unaligned()) % (1u32 << 3)) << (27 - 3); out = out.offset(1); - *out = ((*in_buf) >> 3) % (1u32 << 27); + *out = ((in_buf.read_unaligned()) >> 3) % (1u32 << 27); out = out.offset(1); - *out = (*in_buf) >> 30; + *out = (in_buf.read_unaligned()) >> 30; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 25)) << (27 - 25); + *out |= ((in_buf.read_unaligned()) % (1u32 << 25)) << (27 - 25); out = out.offset(1); - *out = (*in_buf) >> 25; + *out = (in_buf.read_unaligned()) >> 25; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 20)) << (27 - 20); + *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (27 - 20); out = out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 15)) << (27 - 15); + *out |= ((in_buf.read_unaligned()) % (1u32 << 15)) << (27 - 15); out = out.offset(1); - *out = (*in_buf) >> 15; + *out = (in_buf.read_unaligned()) >> 15; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 10)) << (27 - 10); + *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (27 - 10); out = out.offset(1); - *out = (*in_buf) >> 10; + *out = (in_buf.read_unaligned()) >> 10; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 5)) << (27 - 5); + *out |= ((in_buf.read_unaligned()) % (1u32 << 5)) << (27 - 5); out = out.offset(1); - *out = (*in_buf) >> 5; + *out = (in_buf.read_unaligned()) >> 5; in_buf.offset(1) } unsafe fn unpack28_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (*in_buf) % (1u32 << 28); + *out = (in_buf.read_unaligned()) % (1u32 << 28); out = out.offset(1); - *out 
= (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 24)) << (28 - 24); + *out |= ((in_buf.read_unaligned()) % (1u32 << 24)) << (28 - 24); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 20)) << (28 - 20); + *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (28 - 20); out = out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (28 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (28 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 12)) << (28 - 12); + *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (28 - 12); out = out.offset(1); - *out = (*in_buf) >> 12; + *out = (in_buf.read_unaligned()) >> 12; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (28 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (28 - 8); out = out.offset(1); - *out = (*in_buf) >> 8; + *out = (in_buf.read_unaligned()) >> 8; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (28 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (28 - 4); out = out.offset(1); - *out = (*in_buf) >> 4; + *out = (in_buf.read_unaligned()) >> 4; in_buf = in_buf.offset(1); out = out.offset(1); - *out = (*in_buf) % (1u32 << 28); + *out = (in_buf.read_unaligned()) % (1u32 << 28); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 24)) << (28 - 24); + *out |= ((in_buf.read_unaligned()) % (1u32 << 24)) << (28 - 24); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 20)) << (28 - 20); + *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (28 - 20); out = out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (28 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (28 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 12)) << (28 - 12); + *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (28 - 12); out = out.offset(1); - *out = (*in_buf) >> 12; + *out = (in_buf.read_unaligned()) >> 12; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (28 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (28 - 8); out = out.offset(1); - *out = (*in_buf) >> 8; + *out = (in_buf.read_unaligned()) >> 8; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (28 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (28 - 4); out = out.offset(1); - *out = (*in_buf) >> 4; + *out = (in_buf.read_unaligned()) >> 4; in_buf = in_buf.offset(1); out = out.offset(1); - *out = (*in_buf) % (1u32 << 28); + *out = (in_buf.read_unaligned()) % (1u32 << 28); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 24)) << (28 - 24); + *out |= ((in_buf.read_unaligned()) % (1u32 << 24)) << (28 - 24); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - 
*out |= ((*in_buf) % (1u32 << 20)) << (28 - 20); + *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (28 - 20); out = out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (28 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (28 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 12)) << (28 - 12); + *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (28 - 12); out = out.offset(1); - *out = (*in_buf) >> 12; + *out = (in_buf.read_unaligned()) >> 12; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (28 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (28 - 8); out = out.offset(1); - *out = (*in_buf) >> 8; + *out = (in_buf.read_unaligned()) >> 8; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (28 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (28 - 4); out = out.offset(1); - *out = (*in_buf) >> 4; + *out = (in_buf.read_unaligned()) >> 4; in_buf = in_buf.offset(1); out = out.offset(1); - *out = (*in_buf) % (1u32 << 28); + *out = (in_buf.read_unaligned()) % (1u32 << 28); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 24)) << (28 - 24); + *out |= ((in_buf.read_unaligned()) % (1u32 << 24)) << (28 - 24); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 20)) << (28 - 20); + *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (28 - 20); out = out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (28 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (28 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 12)) << (28 - 12); + *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (28 - 12); out = out.offset(1); - *out = (*in_buf) >> 12; + *out = (in_buf.read_unaligned()) >> 12; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (28 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (28 - 8); out = out.offset(1); - *out = (*in_buf) >> 8; + *out = (in_buf.read_unaligned()) >> 8; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (28 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (28 - 4); out = out.offset(1); - *out = (*in_buf) >> 4; + *out = (in_buf.read_unaligned()) >> 4; in_buf.offset(1) } unsafe fn unpack29_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (*in_buf) % (1u32 << 29); + *out = (in_buf.read_unaligned()) % (1u32 << 29); out = out.offset(1); - *out = (*in_buf) >> 29; + *out = (in_buf.read_unaligned()) >> 29; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 26)) << (29 - 26); + *out |= ((in_buf.read_unaligned()) % (1u32 << 26)) << (29 - 26); out = out.offset(1); - *out = (*in_buf) >> 26; + *out = (in_buf.read_unaligned()) >> 26; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 23)) << (29 - 23); + *out |= ((in_buf.read_unaligned()) % (1u32 << 23)) << (29 - 23); out = out.offset(1); - *out = (*in_buf) >> 23; + *out = (in_buf.read_unaligned()) >> 23; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 20)) << 
(29 - 20); + *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (29 - 20); out = out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 17)) << (29 - 17); + *out |= ((in_buf.read_unaligned()) % (1u32 << 17)) << (29 - 17); out = out.offset(1); - *out = (*in_buf) >> 17; + *out = (in_buf.read_unaligned()) >> 17; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 14)) << (29 - 14); + *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (29 - 14); out = out.offset(1); - *out = (*in_buf) >> 14; + *out = (in_buf.read_unaligned()) >> 14; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 11)) << (29 - 11); + *out |= ((in_buf.read_unaligned()) % (1u32 << 11)) << (29 - 11); out = out.offset(1); - *out = (*in_buf) >> 11; + *out = (in_buf.read_unaligned()) >> 11; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (29 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (29 - 8); out = out.offset(1); - *out = (*in_buf) >> 8; + *out = (in_buf.read_unaligned()) >> 8; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 5)) << (29 - 5); + *out |= ((in_buf.read_unaligned()) % (1u32 << 5)) << (29 - 5); out = out.offset(1); - *out = (*in_buf) >> 5; + *out = (in_buf.read_unaligned()) >> 5; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 2)) << (29 - 2); + *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (29 - 2); out = out.offset(1); - *out = ((*in_buf) >> 2) % (1u32 << 29); + *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 29); out = out.offset(1); - *out = (*in_buf) >> 31; + *out = (in_buf.read_unaligned()) >> 31; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 28)) << (29 - 28); + *out |= ((in_buf.read_unaligned()) % (1u32 << 28)) << (29 - 28); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 25)) << (29 - 25); + *out |= ((in_buf.read_unaligned()) % (1u32 << 25)) << (29 - 25); out = out.offset(1); - *out = (*in_buf) >> 25; + *out = (in_buf.read_unaligned()) >> 25; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 22)) << (29 - 22); + *out |= ((in_buf.read_unaligned()) % (1u32 << 22)) << (29 - 22); out = out.offset(1); - *out = (*in_buf) >> 22; + *out = (in_buf.read_unaligned()) >> 22; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 19)) << (29 - 19); + *out |= ((in_buf.read_unaligned()) % (1u32 << 19)) << (29 - 19); out = out.offset(1); - *out = (*in_buf) >> 19; + *out = (in_buf.read_unaligned()) >> 19; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (29 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (29 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 13)) << (29 - 13); + *out |= ((in_buf.read_unaligned()) % (1u32 << 13)) << (29 - 13); out = out.offset(1); - *out = (*in_buf) >> 13; + *out = (in_buf.read_unaligned()) >> 13; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 10)) << (29 - 10); + *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (29 - 10); out = out.offset(1); - *out = (*in_buf) >> 10; + *out = (in_buf.read_unaligned()) >> 10; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 7)) << (29 - 7); + *out |= ((in_buf.read_unaligned()) % (1u32 << 7)) << (29 - 7); out = out.offset(1); - *out = (*in_buf) >> 7; + *out = (in_buf.read_unaligned()) >> 7; in_buf = 
in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (29 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (29 - 4); out = out.offset(1); - *out = (*in_buf) >> 4; + *out = (in_buf.read_unaligned()) >> 4; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 1)) << (29 - 1); + *out |= ((in_buf.read_unaligned()) % (1u32 << 1)) << (29 - 1); out = out.offset(1); - *out = ((*in_buf) >> 1) % (1u32 << 29); + *out = ((in_buf.read_unaligned()) >> 1) % (1u32 << 29); out = out.offset(1); - *out = (*in_buf) >> 30; + *out = (in_buf.read_unaligned()) >> 30; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 27)) << (29 - 27); + *out |= ((in_buf.read_unaligned()) % (1u32 << 27)) << (29 - 27); out = out.offset(1); - *out = (*in_buf) >> 27; + *out = (in_buf.read_unaligned()) >> 27; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 24)) << (29 - 24); + *out |= ((in_buf.read_unaligned()) % (1u32 << 24)) << (29 - 24); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 21)) << (29 - 21); + *out |= ((in_buf.read_unaligned()) % (1u32 << 21)) << (29 - 21); out = out.offset(1); - *out = (*in_buf) >> 21; + *out = (in_buf.read_unaligned()) >> 21; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 18)) << (29 - 18); + *out |= ((in_buf.read_unaligned()) % (1u32 << 18)) << (29 - 18); out = out.offset(1); - *out = (*in_buf) >> 18; + *out = (in_buf.read_unaligned()) >> 18; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 15)) << (29 - 15); + *out |= ((in_buf.read_unaligned()) % (1u32 << 15)) << (29 - 15); out = out.offset(1); - *out = (*in_buf) >> 15; + *out = (in_buf.read_unaligned()) >> 15; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 12)) << (29 - 12); + *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (29 - 12); out = out.offset(1); - *out = (*in_buf) >> 12; + *out = (in_buf.read_unaligned()) >> 12; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 9)) << (29 - 9); + *out |= ((in_buf.read_unaligned()) % (1u32 << 9)) << (29 - 9); out = out.offset(1); - *out = (*in_buf) >> 9; + *out = (in_buf.read_unaligned()) >> 9; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 6)) << (29 - 6); + *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (29 - 6); out = out.offset(1); - *out = (*in_buf) >> 6; + *out = (in_buf.read_unaligned()) >> 6; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 3)) << (29 - 3); + *out |= ((in_buf.read_unaligned()) % (1u32 << 3)) << (29 - 3); out = out.offset(1); - *out = (*in_buf) >> 3; + *out = (in_buf.read_unaligned()) >> 3; in_buf.offset(1) } unsafe fn unpack30_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (*in_buf) % (1u32 << 30); + *out = (in_buf.read_unaligned()) % (1u32 << 30); out = out.offset(1); - *out = (*in_buf) >> 30; + *out = (in_buf.read_unaligned()) >> 30; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 28)) << (30 - 28); + *out |= ((in_buf.read_unaligned()) % (1u32 << 28)) << (30 - 28); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 26)) << (30 - 26); + *out |= ((in_buf.read_unaligned()) % (1u32 << 26)) << (30 - 26); out = out.offset(1); - *out = (*in_buf) >> 26; + *out = (in_buf.read_unaligned()) >> 26; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 24)) << (30 - 24); + *out |= ((in_buf.read_unaligned()) % (1u32 << 24)) << (30 - 24); out = 
out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 22)) << (30 - 22); + *out |= ((in_buf.read_unaligned()) % (1u32 << 22)) << (30 - 22); out = out.offset(1); - *out = (*in_buf) >> 22; + *out = (in_buf.read_unaligned()) >> 22; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 20)) << (30 - 20); + *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (30 - 20); out = out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 18)) << (30 - 18); + *out |= ((in_buf.read_unaligned()) % (1u32 << 18)) << (30 - 18); out = out.offset(1); - *out = (*in_buf) >> 18; + *out = (in_buf.read_unaligned()) >> 18; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (30 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (30 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 14)) << (30 - 14); + *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (30 - 14); out = out.offset(1); - *out = (*in_buf) >> 14; + *out = (in_buf.read_unaligned()) >> 14; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 12)) << (30 - 12); + *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (30 - 12); out = out.offset(1); - *out = (*in_buf) >> 12; + *out = (in_buf.read_unaligned()) >> 12; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 10)) << (30 - 10); + *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (30 - 10); out = out.offset(1); - *out = (*in_buf) >> 10; + *out = (in_buf.read_unaligned()) >> 10; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (30 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (30 - 8); out = out.offset(1); - *out = (*in_buf) >> 8; + *out = (in_buf.read_unaligned()) >> 8; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 6)) << (30 - 6); + *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (30 - 6); out = out.offset(1); - *out = (*in_buf) >> 6; + *out = (in_buf.read_unaligned()) >> 6; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (30 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (30 - 4); out = out.offset(1); - *out = (*in_buf) >> 4; + *out = (in_buf.read_unaligned()) >> 4; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 2)) << (30 - 2); + *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (30 - 2); out = out.offset(1); - *out = (*in_buf) >> 2; + *out = (in_buf.read_unaligned()) >> 2; in_buf = in_buf.offset(1); out = out.offset(1); - *out = (*in_buf) % (1u32 << 30); + *out = (in_buf.read_unaligned()) % (1u32 << 30); out = out.offset(1); - *out = (*in_buf) >> 30; + *out = (in_buf.read_unaligned()) >> 30; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 28)) << (30 - 28); + *out |= ((in_buf.read_unaligned()) % (1u32 << 28)) << (30 - 28); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 26)) << (30 - 26); + *out |= ((in_buf.read_unaligned()) % (1u32 << 26)) << (30 - 26); out = out.offset(1); - *out = (*in_buf) >> 26; + *out = (in_buf.read_unaligned()) >> 26; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 24)) << (30 - 24); + *out |= ((in_buf.read_unaligned()) % (1u32 << 24)) << (30 - 24); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 
24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 22)) << (30 - 22); + *out |= ((in_buf.read_unaligned()) % (1u32 << 22)) << (30 - 22); out = out.offset(1); - *out = (*in_buf) >> 22; + *out = (in_buf.read_unaligned()) >> 22; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 20)) << (30 - 20); + *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (30 - 20); out = out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 18)) << (30 - 18); + *out |= ((in_buf.read_unaligned()) % (1u32 << 18)) << (30 - 18); out = out.offset(1); - *out = (*in_buf) >> 18; + *out = (in_buf.read_unaligned()) >> 18; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (30 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (30 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 14)) << (30 - 14); + *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (30 - 14); out = out.offset(1); - *out = (*in_buf) >> 14; + *out = (in_buf.read_unaligned()) >> 14; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 12)) << (30 - 12); + *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (30 - 12); out = out.offset(1); - *out = (*in_buf) >> 12; + *out = (in_buf.read_unaligned()) >> 12; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 10)) << (30 - 10); + *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (30 - 10); out = out.offset(1); - *out = (*in_buf) >> 10; + *out = (in_buf.read_unaligned()) >> 10; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (30 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (30 - 8); out = out.offset(1); - *out = (*in_buf) >> 8; + *out = (in_buf.read_unaligned()) >> 8; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 6)) << (30 - 6); + *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (30 - 6); out = out.offset(1); - *out = (*in_buf) >> 6; + *out = (in_buf.read_unaligned()) >> 6; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (30 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (30 - 4); out = out.offset(1); - *out = (*in_buf) >> 4; + *out = (in_buf.read_unaligned()) >> 4; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 2)) << (30 - 2); + *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (30 - 2); out = out.offset(1); - *out = (*in_buf) >> 2; + *out = (in_buf.read_unaligned()) >> 2; in_buf.offset(1) } unsafe fn unpack31_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (*in_buf) % (1u32 << 31); + *out = (in_buf.read_unaligned()) % (1u32 << 31); out = out.offset(1); - *out = (*in_buf) >> 31; + *out = (in_buf.read_unaligned()) >> 31; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 30)) << (31 - 30); + *out |= ((in_buf.read_unaligned()) % (1u32 << 30)) << (31 - 30); out = out.offset(1); - *out = (*in_buf) >> 30; + *out = (in_buf.read_unaligned()) >> 30; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 29)) << (31 - 29); + *out |= ((in_buf.read_unaligned()) % (1u32 << 29)) << (31 - 29); out = out.offset(1); - *out = (*in_buf) >> 29; + *out = (in_buf.read_unaligned()) >> 29; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 28)) << (31 - 28); + *out |= ((in_buf.read_unaligned()) % (1u32 << 28)) << (31 - 28); out = out.offset(1); - *out = (*in_buf) >> 28; + *out = (in_buf.read_unaligned()) >> 28; in_buf = 
in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 27)) << (31 - 27); + *out |= ((in_buf.read_unaligned()) % (1u32 << 27)) << (31 - 27); out = out.offset(1); - *out = (*in_buf) >> 27; + *out = (in_buf.read_unaligned()) >> 27; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 26)) << (31 - 26); + *out |= ((in_buf.read_unaligned()) % (1u32 << 26)) << (31 - 26); out = out.offset(1); - *out = (*in_buf) >> 26; + *out = (in_buf.read_unaligned()) >> 26; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 25)) << (31 - 25); + *out |= ((in_buf.read_unaligned()) % (1u32 << 25)) << (31 - 25); out = out.offset(1); - *out = (*in_buf) >> 25; + *out = (in_buf.read_unaligned()) >> 25; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 24)) << (31 - 24); + *out |= ((in_buf.read_unaligned()) % (1u32 << 24)) << (31 - 24); out = out.offset(1); - *out = (*in_buf) >> 24; + *out = (in_buf.read_unaligned()) >> 24; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 23)) << (31 - 23); + *out |= ((in_buf.read_unaligned()) % (1u32 << 23)) << (31 - 23); out = out.offset(1); - *out = (*in_buf) >> 23; + *out = (in_buf.read_unaligned()) >> 23; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 22)) << (31 - 22); + *out |= ((in_buf.read_unaligned()) % (1u32 << 22)) << (31 - 22); out = out.offset(1); - *out = (*in_buf) >> 22; + *out = (in_buf.read_unaligned()) >> 22; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 21)) << (31 - 21); + *out |= ((in_buf.read_unaligned()) % (1u32 << 21)) << (31 - 21); out = out.offset(1); - *out = (*in_buf) >> 21; + *out = (in_buf.read_unaligned()) >> 21; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 20)) << (31 - 20); + *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (31 - 20); out = out.offset(1); - *out = (*in_buf) >> 20; + *out = (in_buf.read_unaligned()) >> 20; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 19)) << (31 - 19); + *out |= ((in_buf.read_unaligned()) % (1u32 << 19)) << (31 - 19); out = out.offset(1); - *out = (*in_buf) >> 19; + *out = (in_buf.read_unaligned()) >> 19; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 18)) << (31 - 18); + *out |= ((in_buf.read_unaligned()) % (1u32 << 18)) << (31 - 18); out = out.offset(1); - *out = (*in_buf) >> 18; + *out = (in_buf.read_unaligned()) >> 18; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 17)) << (31 - 17); + *out |= ((in_buf.read_unaligned()) % (1u32 << 17)) << (31 - 17); out = out.offset(1); - *out = (*in_buf) >> 17; + *out = (in_buf.read_unaligned()) >> 17; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 16)) << (31 - 16); + *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (31 - 16); out = out.offset(1); - *out = (*in_buf) >> 16; + *out = (in_buf.read_unaligned()) >> 16; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 15)) << (31 - 15); + *out |= ((in_buf.read_unaligned()) % (1u32 << 15)) << (31 - 15); out = out.offset(1); - *out = (*in_buf) >> 15; + *out = (in_buf.read_unaligned()) >> 15; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 14)) << (31 - 14); + *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (31 - 14); out = out.offset(1); - *out = (*in_buf) >> 14; + *out = (in_buf.read_unaligned()) >> 14; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 13)) << (31 - 13); + *out |= ((in_buf.read_unaligned()) % (1u32 << 13)) << (31 - 13); out = out.offset(1); - *out = (*in_buf) >> 13; + *out = (in_buf.read_unaligned()) >> 13; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % 
(1u32 << 12)) << (31 - 12); + *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (31 - 12); out = out.offset(1); - *out = (*in_buf) >> 12; + *out = (in_buf.read_unaligned()) >> 12; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 11)) << (31 - 11); + *out |= ((in_buf.read_unaligned()) % (1u32 << 11)) << (31 - 11); out = out.offset(1); - *out = (*in_buf) >> 11; + *out = (in_buf.read_unaligned()) >> 11; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 10)) << (31 - 10); + *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (31 - 10); out = out.offset(1); - *out = (*in_buf) >> 10; + *out = (in_buf.read_unaligned()) >> 10; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 9)) << (31 - 9); + *out |= ((in_buf.read_unaligned()) % (1u32 << 9)) << (31 - 9); out = out.offset(1); - *out = (*in_buf) >> 9; + *out = (in_buf.read_unaligned()) >> 9; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 8)) << (31 - 8); + *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (31 - 8); out = out.offset(1); - *out = (*in_buf) >> 8; + *out = (in_buf.read_unaligned()) >> 8; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 7)) << (31 - 7); + *out |= ((in_buf.read_unaligned()) % (1u32 << 7)) << (31 - 7); out = out.offset(1); - *out = (*in_buf) >> 7; + *out = (in_buf.read_unaligned()) >> 7; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 6)) << (31 - 6); + *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (31 - 6); out = out.offset(1); - *out = (*in_buf) >> 6; + *out = (in_buf.read_unaligned()) >> 6; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 5)) << (31 - 5); + *out |= ((in_buf.read_unaligned()) % (1u32 << 5)) << (31 - 5); out = out.offset(1); - *out = (*in_buf) >> 5; + *out = (in_buf.read_unaligned()) >> 5; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 4)) << (31 - 4); + *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (31 - 4); out = out.offset(1); - *out = (*in_buf) >> 4; + *out = (in_buf.read_unaligned()) >> 4; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 3)) << (31 - 3); + *out |= ((in_buf.read_unaligned()) % (1u32 << 3)) << (31 - 3); out = out.offset(1); - *out = (*in_buf) >> 3; + *out = (in_buf.read_unaligned()) >> 3; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 2)) << (31 - 2); + *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (31 - 2); out = out.offset(1); - *out = (*in_buf) >> 2; + *out = (in_buf.read_unaligned()) >> 2; in_buf = in_buf.offset(1); - *out |= ((*in_buf) % (1u32 << 1)) << (31 - 1); + *out |= ((in_buf.read_unaligned()) % (1u32 << 1)) << (31 - 1); out = out.offset(1); - *out = (*in_buf) >> 1; + *out = (in_buf.read_unaligned()) >> 1; in_buf.offset(1) } unsafe fn unpack32_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = 
in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf = in_buf.offset(1); out = out.offset(1); - *out = *in_buf; + *out = in_buf.read_unaligned(); in_buf.offset(1) } diff --git a/parquet/src/util/bit_util.rs b/parquet/src/util/bit_util.rs index 010ed32c8c8b..b535ee02a0ef 100644 --- a/parquet/src/util/bit_util.rs +++ b/parquet/src/util/bit_util.rs @@ -32,6 +32,17 @@ pub fn from_ne_slice(bs: &[u8]) -> T { T::from_ne_bytes(b) } +#[inline] +pub fn from_le_slice(bs: &[u8]) -> T { + let mut b = T::Buffer::default(); + { + let b = b.as_mut(); + let bs = &bs[..b.len()]; + b.copy_from_slice(bs); + } + T::from_le_bytes(b) +} + pub trait FromBytes: Sized { type Buffer: AsMut<[u8]> + Default; fn from_le_bytes(bs: Self::Buffer) -> Self; @@ -383,8 +394,8 @@ impl BitWriter { // TODO: should we return `Result` for this func? 
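The unpack hunks above replace every raw `*in_buf` dereference with `read_unaligned()`. As a standalone, hedged sketch (illustrative only — the helper name `read_u32_at` is not from the crate), this is the difference that matters when a `*const u32` is derived from an arbitrary byte offset:

```rust
/// Minimal illustration, not the crate's API: read a u32 from a byte slice
/// at an arbitrary (possibly unaligned) offset.
fn read_u32_at(bytes: &[u8], offset: usize) -> u32 {
    assert!(offset + 4 <= bytes.len());
    let ptr = bytes[offset..].as_ptr() as *const u32;
    // A plain `*ptr` (or `ptr.read()`) requires `ptr` to be 4-byte aligned;
    // violating that is undefined behaviour and is reported by Miri.
    // `read_unaligned` performs the same load without the alignment requirement.
    unsafe { ptr.read_unaligned() }
}

fn main() {
    // Offset 1 is deliberately misaligned relative to the Vec's allocation.
    let data = vec![0u8, 0x78, 0x56, 0x34, 0x12];
    // On a little-endian target this prints 0x12345678.
    println!("{:#x}", read_u32_at(&data, 1));
}
```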
return false; } - let mut ptr = result.unwrap(); - memcpy_value(&val, num_bytes, &mut ptr); + let ptr = result.unwrap(); + memcpy_value(&val, num_bytes, ptr); true } @@ -528,8 +539,14 @@ impl BitReader { Some(from_ne_slice(v.as_bytes())) } + /// Read multiple values from their packed representation + /// + /// # Panics + /// + /// This function panics if + /// - `bit_width` is larger than the bit-capacity of `T` + /// pub fn get_batch(&mut self, batch: &mut [T], num_bits: usize) -> usize { - assert!(num_bits <= 32); assert!(num_bits <= size_of::() * 8); let mut values_to_read = batch.len(); @@ -541,6 +558,17 @@ impl BitReader { let mut i = 0; + if num_bits > 32 { + // No fast path - read values individually + while i < values_to_read { + batch[i] = self + .get_value(num_bits) + .expect("expected to have more data"); + i += 1; + } + return values_to_read + } + // First align bit offset to byte offset if self.bit_offset != 0 { while i < values_to_read && self.bit_offset != 0 { @@ -551,41 +579,35 @@ impl BitReader { } } - unsafe { - let in_buf = &self.buffer.data()[self.byte_offset..]; - let mut in_ptr = in_buf as *const [u8] as *const u8 as *const u32; - // FIXME assert!(memory::is_ptr_aligned(in_ptr)); - if size_of::() == 4 { - while values_to_read - i >= 32 { - let out_ptr = &mut batch[i..] as *mut [T] as *mut T as *mut u32; - in_ptr = unpack32(in_ptr, out_ptr, num_bits); - self.byte_offset += 4 * num_bits; - i += 32; - } - } else { - let mut out_buf = [0u32; 32]; - let out_ptr = &mut out_buf as &mut [u32] as *mut [u32] as *mut u32; - while values_to_read - i >= 32 { - in_ptr = unpack32(in_ptr, out_ptr, num_bits); - self.byte_offset += 4 * num_bits; - for n in 0..32 { - // We need to copy from smaller size to bigger size to avoid - // overwriting other memory regions. - if size_of::() > size_of::() { - std::ptr::copy_nonoverlapping( - out_buf[n..].as_ptr() as *const u32, - &mut batch[i] as *mut T as *mut u32, - 1, - ); - } else { - std::ptr::copy_nonoverlapping( - out_buf[n..].as_ptr() as *const T, - &mut batch[i] as *mut T, - 1, - ); - } - i += 1; + let in_buf = &self.buffer.data()[self.byte_offset..]; + let mut in_ptr = in_buf as *const [u8] as *const u8 as *const u32; + if size_of::() == 4 { + while values_to_read - i >= 32 { + let out_ptr = &mut batch[i..] 
as *mut [T] as *mut T as *mut u32; + in_ptr = unsafe { unpack32(in_ptr, out_ptr, num_bits) }; + self.byte_offset += 4 * num_bits; + i += 32; + } + } else { + let mut out_buf = [0u32; 32]; + let out_ptr = &mut out_buf as &mut [u32] as *mut [u32] as *mut u32; + while values_to_read - i >= 32 { + in_ptr = unsafe { unpack32(in_ptr, out_ptr, num_bits) }; + self.byte_offset += 4 * num_bits; + + for out in out_buf { + // Zero-allocate buffer + let mut out_bytes = T::Buffer::default(); + let in_bytes = out.to_le_bytes(); + + { + let out_bytes = out_bytes.as_mut(); + let len = out_bytes.len().min(in_bytes.len()); + (&mut out_bytes[..len]).copy_from_slice(&in_bytes[..len]); } + + batch[i] = T::from_le_bytes(out_bytes); + i += 1; } } } @@ -603,6 +625,26 @@ impl BitReader { values_to_read } + /// Reads up to `num_bytes` to `buf` returning the number of bytes read + pub(crate) fn get_aligned_bytes( + &mut self, + buf: &mut Vec, + num_bytes: usize, + ) -> usize { + // Align to byte offset + self.byte_offset += ceil(self.bit_offset as i64, 8) as usize; + self.bit_offset = 0; + + let src = &self.buffer.data()[self.byte_offset..]; + let to_read = num_bytes.min(src.len()); + buf.extend_from_slice(&src[..to_read]); + + self.byte_offset += to_read; + self.reload_buffer_values(); + + to_read + } + /// Reads a `num_bytes`-sized value from this buffer and return it. /// `T` needs to be a little-endian native type. The value is assumed to be byte /// aligned so the bit reader will be advanced to the start of the next byte before @@ -1157,4 +1199,19 @@ mod tests { ); }); } + + #[test] + fn test_get_batch_zero_extend() { + let to_read = vec![0xFF; 4]; + let mut reader = BitReader::new(ByteBufferPtr::new(to_read)); + + // Create a non-zeroed output buffer + let mut output = [u64::MAX; 32]; + reader.get_batch(&mut output, 1); + + for v in output { + // Values should be read correctly + assert_eq!(v, 1); + } + } } diff --git a/parquet/src/util/cursor.rs b/parquet/src/util/cursor.rs index c847fc8f6823..706724dbf52a 100644 --- a/parquet/src/util/cursor.rs +++ b/parquet/src/util/cursor.rs @@ -15,18 +15,18 @@ // specific language governing permissions and limitations // under the License. +use crate::util::io::TryClone; use std::io::{self, Cursor, Error, ErrorKind, Read, Seek, SeekFrom, Write}; use std::sync::{Arc, Mutex}; use std::{cmp, fmt}; -use crate::file::writer::TryClone; - /// This is object to use if your file is already in memory. /// The sliceable cursor is similar to std::io::Cursor, except that it makes it easy to create "cursor slices". /// To achieve this, it uses Arc instead of shared references. Indeed reference fields are painful /// because the lack of Generic Associated Type implies that you would require complex lifetime propagation when /// returning such a cursor. 
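The `get_batch` hunk in `bit_util.rs` above replaces the old `ptr::copy_nonoverlapping` copy with a little-endian, zero-extending byte copy into `T::Buffer`, which is the property `test_get_batch_zero_extend` checks. A hedged sketch of that copy strategy, specialised to `u32 -> u64` for illustration (not the crate's generic code):

```rust
/// Widen a little-endian u32 into a larger integer by copying the overlapping
/// bytes into a zero-initialised buffer (zero-extension), mirroring the
/// `T::from_le_bytes` loop in the hunk above.
fn widen_le_u32_to_u64(v: u32) -> u64 {
    let in_bytes = v.to_le_bytes();
    let mut out_bytes = [0u8; 8]; // zeroed, like `T::Buffer::default()`
    let len = out_bytes.len().min(in_bytes.len());
    out_bytes[..len].copy_from_slice(&in_bytes[..len]);
    u64::from_le_bytes(out_bytes)
}

fn main() {
    // High bits must come out as zero even if the destination previously
    // held other data -- the behaviour `test_get_batch_zero_extend` asserts.
    assert_eq!(widen_le_u32_to_u64(1), 1u64);
    assert_eq!(widen_le_u32_to_u64(u32::MAX), 0xFFFF_FFFF_u64);
}
```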
#[allow(clippy::rc_buffer)] +#[deprecated = "use bytes::Bytes instead"] pub struct SliceableCursor { inner: Arc>, start: u64, @@ -34,6 +34,7 @@ pub struct SliceableCursor { pos: u64, } +#[allow(deprecated)] impl fmt::Debug for SliceableCursor { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("SliceableCursor") @@ -45,6 +46,7 @@ impl fmt::Debug for SliceableCursor { } } +#[allow(deprecated)] impl SliceableCursor { pub fn new(content: impl Into>>) -> Self { let inner = content.into(); @@ -91,6 +93,7 @@ impl SliceableCursor { } /// Implementation inspired by std::io::Cursor +#[allow(deprecated)] impl Read for SliceableCursor { fn read(&mut self, buf: &mut [u8]) -> io::Result { let n = Read::read(&mut self.remaining_slice(), buf)?; @@ -99,6 +102,7 @@ impl Read for SliceableCursor { } } +#[allow(deprecated)] impl Seek for SliceableCursor { fn seek(&mut self, pos: SeekFrom) -> io::Result { let new_pos = match pos { @@ -134,11 +138,13 @@ impl Seek for SliceableCursor { } /// Use this type to write Parquet to memory rather than a file. +#[deprecated = "use Vec instead"] #[derive(Debug, Default, Clone)] pub struct InMemoryWriteableCursor { buffer: Arc>>>, } +#[allow(deprecated)] impl InMemoryWriteableCursor { /// Consume this instance and return the underlying buffer as long as there are no other /// references to this instance. @@ -168,6 +174,7 @@ impl InMemoryWriteableCursor { } } +#[allow(deprecated)] impl TryClone for InMemoryWriteableCursor { fn try_clone(&self) -> std::io::Result { Ok(Self { @@ -176,6 +183,7 @@ impl TryClone for InMemoryWriteableCursor { } } +#[allow(deprecated)] impl Write for InMemoryWriteableCursor { fn write(&mut self, buf: &[u8]) -> std::io::Result { let mut inner = self.buffer.lock().unwrap(); @@ -188,6 +196,7 @@ impl Write for InMemoryWriteableCursor { } } +#[allow(deprecated)] impl Seek for InMemoryWriteableCursor { fn seek(&mut self, pos: SeekFrom) -> std::io::Result { let mut inner = self.buffer.lock().unwrap(); @@ -200,17 +209,19 @@ mod tests { use super::*; /// Create a SliceableCursor of all u8 values in ascending order + #[allow(deprecated)] fn get_u8_range() -> SliceableCursor { let data: Vec = (0u8..=255).collect(); SliceableCursor::new(data) } /// Reads all the bytes in the slice and checks that it matches the u8 range from start to end_included + #[allow(deprecated)] fn check_read_all(mut cursor: SliceableCursor, start: u8, end_included: u8) { let mut target = vec![]; let cursor_res = cursor.read_to_end(&mut target); println!("{:?}", cursor_res); - assert!(!cursor_res.is_err(), "reading error"); + assert!(cursor_res.is_ok(), "reading error"); assert_eq!((end_included - start) as usize + 1, cursor_res.unwrap()); assert_eq!((start..=end_included).collect::>(), target); } diff --git a/parquet/src/util/hash_util.rs b/parquet/src/util/hash_util.rs index 34d954c3b3e2..f3705bc32f5f 100644 --- a/parquet/src/util/hash_util.rs +++ b/parquet/src/util/hash_util.rs @@ -25,9 +25,9 @@ pub fn hash(data: &T, seed: u32) -> u32 { fn hash_(data: &[u8], seed: u32) -> u32 { #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - unsafe { + { if is_x86_feature_detected!("sse4.2") { - crc32_hash(data, seed) + unsafe { crc32_hash(data, seed) } } else { murmur_hash2_64a(data, seed as u64) as u32 } @@ -39,7 +39,7 @@ fn hash_(data: &[u8], seed: u32) -> u32 { target_arch = "riscv64", target_arch = "wasm32" ))] - unsafe { + { murmur_hash2_64a(data, seed as u64) as u32 } } @@ -48,19 +48,15 @@ const MURMUR_PRIME: u64 = 0xc6a4a7935bd1e995; const MURMUR_R: i32 = 
47; /// Rust implementation of MurmurHash2, 64-bit version for 64-bit platforms -/// -/// SAFTETY Only safe on platforms which support unaligned loads (like x86_64) -unsafe fn murmur_hash2_64a(data_bytes: &[u8], seed: u64) -> u64 { +fn murmur_hash2_64a(data_bytes: &[u8], seed: u64) -> u64 { let len = data_bytes.len(); let len_64 = (len / 8) * 8; - let data_bytes_64 = std::slice::from_raw_parts( - &data_bytes[0..len_64] as *const [u8] as *const u64, - len / 8, - ); let mut h = seed ^ (MURMUR_PRIME.wrapping_mul(data_bytes.len() as u64)); - for v in data_bytes_64 { - let mut k = *v; + for mut k in data_bytes + .chunks_exact(8) + .map(|chunk| u64::from_ne_bytes(chunk.try_into().unwrap())) + { k = k.wrapping_mul(MURMUR_PRIME); k ^= k >> MURMUR_R; k = k.wrapping_mul(MURMUR_PRIME); @@ -146,16 +142,14 @@ mod tests { #[test] fn test_murmur2_64a() { - unsafe { - let result = murmur_hash2_64a(b"hello", 123); - assert_eq!(result, 2597646618390559622); + let result = murmur_hash2_64a(b"hello", 123); + assert_eq!(result, 2597646618390559622); - let result = murmur_hash2_64a(b"helloworld", 123); - assert_eq!(result, 4934371746140206573); + let result = murmur_hash2_64a(b"helloworld", 123); + assert_eq!(result, 4934371746140206573); - let result = murmur_hash2_64a(b"helloworldparquet", 123); - assert_eq!(result, 2392198230801491746); - } + let result = murmur_hash2_64a(b"helloworldparquet", 123); + assert_eq!(result, 2392198230801491746); } #[test] diff --git a/parquet/src/util/io.rs b/parquet/src/util/io.rs index 44e99ac0a779..a7b5e73074c6 100644 --- a/parquet/src/util/io.rs +++ b/parquet/src/util/io.rs @@ -17,7 +17,9 @@ use std::{cell::RefCell, cmp, fmt, io::*}; -use crate::file::{reader::Length, writer::ParquetWriter}; +use crate::file::reader::Length; +#[allow(deprecated)] +use crate::file::writer::ParquetWriter; const DEFAULT_BUF_SIZE: usize = 8 * 1024; @@ -156,6 +158,8 @@ impl Length for FileSource { /// Struct that represents `File` output stream with position tracking. /// Used as a sink in file writer. +#[deprecated = "use TrackedWrite instead"] +#[allow(deprecated)] pub struct FileSink { buf: BufWriter, // This is not necessarily position in the underlying file, @@ -163,6 +167,7 @@ pub struct FileSink { pos: u64, } +#[allow(deprecated)] impl FileSink { /// Creates new file sink. /// Position is set to whatever position file has. 
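The `murmur_hash2_64a` rewrite above (in `hash_util.rs`) drops the `unsafe` `slice::from_raw_parts` reinterpretation in favour of `chunks_exact(8)` plus `u64::from_ne_bytes`, which is alignment-safe on every platform. A small hedged sketch of that word-reading pattern (illustrative only, not the crate's hash function; the real code also mixes in the tail bytes, which this sketch simply ignores):

```rust
/// Fold a byte slice as native-endian u64 words using the same safe
/// word-reading pattern as the murmur hash rewrite above.
fn sum_u64_words(data: &[u8]) -> u64 {
    data.chunks_exact(8)
        .map(|chunk| u64::from_ne_bytes(chunk.try_into().unwrap()))
        .fold(0u64, |acc, w| acc.wrapping_add(w))
}

fn main() {
    let bytes = [1u8; 20]; // 2 full words + 4 tail bytes (ignored here)
    assert_eq!(sum_u64_words(&bytes), 2 * u64::from_ne_bytes([1; 8]));
}
```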
@@ -176,6 +181,7 @@ impl FileSink { } } +#[allow(deprecated)] impl Write for FileSink { fn write(&mut self, buf: &[u8]) -> Result { let num_bytes = self.buf.write(buf)?; @@ -188,6 +194,7 @@ impl Write for FileSink { } } +#[allow(deprecated)] impl Position for FileSink { fn pos(&self) -> u64 { self.pos @@ -207,7 +214,7 @@ mod tests { use std::iter; - use crate::util::test_common::{get_temp_file, get_test_file}; + use crate::util::test_common::get_test_file; #[test] fn test_io_read_fully() { @@ -271,9 +278,10 @@ mod tests { } #[test] + #[allow(deprecated)] fn test_io_write_with_pos() { - let mut file = get_temp_file("file_sink_test", &[b'a', b'b', b'c']); - file.seek(SeekFrom::Current(3)).unwrap(); + let mut file = tempfile::tempfile().unwrap(); + file.write_all(&[b'a', b'b', b'c']).unwrap(); // Write into sink let mut sink = FileSink::new(&file); @@ -300,8 +308,9 @@ mod tests { .flatten() .take(3 * DEFAULT_BUF_SIZE) .collect(); - // always use different temp files as test might be run in parallel - let mut file = get_temp_file("large_file_sink_test", &patterned_data); + + let mut file = tempfile::tempfile().unwrap(); + file.write_all(&patterned_data).unwrap(); // seek the underlying file to the first 'd' file.seek(SeekFrom::Start(3)).unwrap(); diff --git a/parquet/src/util/memory.rs b/parquet/src/util/memory.rs index a9d0ba6a3d85..909878a6d538 100644 --- a/parquet/src/util/memory.rs +++ b/parquet/src/util/memory.rs @@ -17,526 +17,124 @@ //! Utility methods and structs for working with memory. +use bytes::Bytes; use std::{ fmt::{Debug, Display, Formatter, Result as FmtResult}, - io::{Result as IoResult, Write}, - mem, - ops::{Index, IndexMut}, - sync::{ - atomic::{AtomicI64, Ordering}, - Arc, Weak, - }, + ops::Index, }; -// ---------------------------------------------------------------------- -// Memory Tracker classes - -/// Reference counted pointer for [`MemTracker`]. -pub type MemTrackerPtr = Arc; -/// Non-owning reference for [`MemTracker`]. -pub type WeakMemTrackerPtr = Weak; - -/// Struct to track memory usage information. -#[derive(Debug)] -pub struct MemTracker { - // In the tuple, the first element is the current memory allocated (in bytes), - // and the second element is the maximum memory allocated so far (in bytes). - current_memory_usage: AtomicI64, - max_memory_usage: AtomicI64, -} - -impl MemTracker { - /// Creates new memory tracker. - #[inline] - pub fn new() -> MemTracker { - MemTracker { - current_memory_usage: Default::default(), - max_memory_usage: Default::default(), - } - } - - /// Returns the current memory consumption, in bytes. - pub fn memory_usage(&self) -> i64 { - self.current_memory_usage.load(Ordering::Acquire) - } - - /// Returns the maximum memory consumption so far, in bytes. - pub fn max_memory_usage(&self) -> i64 { - self.max_memory_usage.load(Ordering::Acquire) - } - - /// Adds `num_bytes` to the memory consumption tracked by this memory tracker. - #[inline] - pub fn alloc(&self, num_bytes: i64) { - let new_current = self - .current_memory_usage - .fetch_add(num_bytes, Ordering::Acquire) - + num_bytes; - self.max_memory_usage - .fetch_max(new_current, Ordering::Acquire); - } -} - -// ---------------------------------------------------------------------- -// Buffer classes - -/// Type alias for [`Buffer`]. -pub type ByteBuffer = Buffer; -/// Type alias for [`BufferPtr`]. -pub type ByteBufferPtr = BufferPtr; - -/// A resize-able buffer class with generic member, with optional memory tracker. 
-/// -/// Note that a buffer has two attributes: -/// `capacity` and `size`: the former is the total number of space reserved for -/// the buffer, while the latter is the actual number of elements. -/// Invariant: `capacity` >= `size`. -/// The total allocated bytes for a buffer equals to `capacity * sizeof()`. -pub struct Buffer { - data: Vec, - mem_tracker: Option, - type_length: usize, -} - -impl Buffer { - /// Creates new empty buffer. - pub fn new() -> Self { - Buffer { - data: vec![], - mem_tracker: None, - type_length: std::mem::size_of::(), - } - } - - /// Adds [`MemTracker`] for this buffer. - #[inline] - pub fn with_mem_tracker(mut self, mc: MemTrackerPtr) -> Self { - mc.alloc((self.data.capacity() * self.type_length) as i64); - self.mem_tracker = Some(mc); - self - } - - /// Returns slice of data in this buffer. - #[inline] - pub fn data(&self) -> &[T] { - self.data.as_slice() - } - - /// Sets data for this buffer. - #[inline] - pub fn set_data(&mut self, new_data: Vec) { - if let Some(ref mc) = self.mem_tracker { - let capacity_diff = new_data.capacity() as i64 - self.data.capacity() as i64; - mc.alloc(capacity_diff * self.type_length as i64); - } - self.data = new_data; - } - - /// Resizes underlying data in place to a new length `new_size`. - /// - /// If `new_size` is less than current length, data is truncated, otherwise, it is - /// extended to `new_size` with provided default value `init_value`. - /// - /// Memory tracker is also updated, if available. - #[inline] - pub fn resize(&mut self, new_size: usize, init_value: T) { - let old_capacity = self.data.capacity(); - self.data.resize(new_size, init_value); - if let Some(ref mc) = self.mem_tracker { - let capacity_diff = self.data.capacity() as i64 - old_capacity as i64; - mc.alloc(capacity_diff * self.type_length as i64); - } - } - - /// Clears underlying data. - #[inline] - pub fn clear(&mut self) { - self.data.clear() - } - - /// Reserves capacity `additional_capacity` for underlying data vector. - /// - /// Memory tracker is also updated, if available. - #[inline] - pub fn reserve(&mut self, additional_capacity: usize) { - let old_capacity = self.data.capacity(); - self.data.reserve(additional_capacity); - if self.data.capacity() > old_capacity { - if let Some(ref mc) = self.mem_tracker { - let capacity_diff = self.data.capacity() as i64 - old_capacity as i64; - mc.alloc(capacity_diff * self.type_length as i64); - } - } - } - - /// Returns [`BufferPtr`] with buffer data. - /// Buffer data is reset. - #[inline] - pub fn consume(&mut self) -> BufferPtr { - let old_data = mem::take(&mut self.data); - let mut result = BufferPtr::new(old_data); - if let Some(ref mc) = self.mem_tracker { - result = result.with_mem_tracker(mc.clone()); - } - result - } - - /// Adds `value` to the buffer. - #[inline] - pub fn push(&mut self, value: T) { - self.data.push(value) - } - - /// Returns current capacity for the buffer. - #[inline] - pub fn capacity(&self) -> usize { - self.data.capacity() - } - - /// Returns current size for the buffer. - #[inline] - pub fn size(&self) -> usize { - self.data.len() - } - - /// Returns `true` if memory tracker is added to buffer, `false` otherwise. - #[inline] - pub fn is_mem_tracked(&self) -> bool { - self.mem_tracker.is_some() - } - - /// Returns memory tracker associated with this buffer. - /// This may panic, if memory tracker is not set, use method above to check if - /// memory tracker is available. 
- #[inline] - pub fn mem_tracker(&self) -> &MemTrackerPtr { - self.mem_tracker.as_ref().unwrap() - } -} - -impl Index for Buffer { - type Output = T; - - fn index(&self, index: usize) -> &T { - &self.data[index] - } -} - -impl IndexMut for Buffer { - fn index_mut(&mut self, index: usize) -> &mut T { - &mut self.data[index] - } -} - -// TODO: implement this for other types -impl Write for Buffer { - #[inline] - fn write(&mut self, buf: &[u8]) -> IoResult { - let old_capacity = self.data.capacity(); - let bytes_written = self.data.write(buf)?; - if let Some(ref mc) = self.mem_tracker { - if self.data.capacity() - old_capacity > 0 { - mc.alloc((self.data.capacity() - old_capacity) as i64) - } - } - Ok(bytes_written) - } - - fn flush(&mut self) -> IoResult<()> { - // No-op - self.data.flush() - } -} - -impl AsRef<[u8]> for Buffer { - fn as_ref(&self) -> &[u8] { - self.data.as_slice() - } -} - -impl Drop for Buffer { - #[inline] - fn drop(&mut self) { - if let Some(ref mc) = self.mem_tracker { - mc.alloc(-((self.data.capacity() * self.type_length) as i64)); - } - } -} - // ---------------------------------------------------------------------- // Immutable Buffer (BufferPtr) classes /// An representation of a slice on a reference-counting and read-only byte array. /// Sub-slices can be further created from this. The byte array will be released /// when all slices are dropped. -#[allow(clippy::rc_buffer)] +/// +/// TODO: Remove and replace with [`bytes::Bytes`] #[derive(Clone, Debug)] -pub struct BufferPtr { - data: Arc>, - start: usize, - len: usize, - // TODO: will this create too many references? rethink about this. - mem_tracker: Option, +pub struct ByteBufferPtr { + data: Bytes, } -impl BufferPtr { +impl ByteBufferPtr { /// Creates new buffer from a vector. - pub fn new(v: Vec) -> Self { - let len = v.len(); - Self { - data: Arc::new(v), - start: 0, - len, - mem_tracker: None, - } + pub fn new(v: Vec) -> Self { + Self { data: v.into() } } /// Returns slice of data in this buffer. #[inline] - pub fn data(&self) -> &[T] { - &self.data[self.start..self.start + self.len] - } - - /// Updates this buffer with new `start` position and length `len`. - /// - /// Range should be within current start position and length. - #[inline] - pub fn with_range(mut self, start: usize, len: usize) -> Self { - self.set_range(start, len); - self - } - - /// Updates this buffer with new `start` position and length `len`. - /// - /// Range should be within current start position and length. - #[inline] - pub fn set_range(&mut self, start: usize, len: usize) { - assert!(self.start <= start && start + len <= self.start + self.len); - self.start = start; - self.len = len; - } - - /// Adds memory tracker to this buffer. - pub fn with_mem_tracker(mut self, mc: MemTrackerPtr) -> Self { - self.mem_tracker = Some(mc); - self - } - - /// Returns start position of this buffer. - #[inline] - pub fn start(&self) -> usize { - self.start + pub fn data(&self) -> &[u8] { + &self.data } /// Returns length of this buffer #[inline] pub fn len(&self) -> usize { - self.len + self.data.len() } /// Returns whether this buffer is empty #[inline] pub fn is_empty(&self) -> bool { - self.len == 0 - } - - /// Returns `true` if this buffer has memory tracker, `false` otherwise. - pub fn is_mem_tracked(&self) -> bool { - self.mem_tracker.is_some() + self.data.is_empty() } /// Returns a shallow copy of the buffer. /// Reference counted pointer to the data is copied. 
- pub fn all(&self) -> BufferPtr { - BufferPtr { - data: self.data.clone(), - start: self.start, - len: self.len, - mem_tracker: self.mem_tracker.as_ref().cloned(), - } + pub fn all(&self) -> Self { + self.clone() } /// Returns a shallow copy of the buffer that starts with `start` position. - pub fn start_from(&self, start: usize) -> BufferPtr { - assert!(start <= self.len); - BufferPtr { - data: self.data.clone(), - start: self.start + start, - len: self.len - start, - mem_tracker: self.mem_tracker.as_ref().cloned(), + pub fn start_from(&self, start: usize) -> Self { + Self { + data: self.data.slice(start..), } } /// Returns a shallow copy that is a range slice within this buffer. - pub fn range(&self, start: usize, len: usize) -> BufferPtr { - assert!(start + len <= self.len); - BufferPtr { - data: self.data.clone(), - start: self.start + start, - len, - mem_tracker: self.mem_tracker.as_ref().cloned(), + pub fn range(&self, start: usize, len: usize) -> Self { + Self { + data: self.data.slice(start..start + len), } } } -impl Index for BufferPtr { - type Output = T; +impl Index for ByteBufferPtr { + type Output = u8; - fn index(&self, index: usize) -> &T { - assert!(index < self.len); - &self.data[self.start + index] + fn index(&self, index: usize) -> &u8 { + &self.data[index] } } -impl Display for BufferPtr { +impl Display for ByteBufferPtr { fn fmt(&self, f: &mut Formatter) -> FmtResult { write!(f, "{:?}", self.data) } } -impl Drop for BufferPtr { - fn drop(&mut self) { - if let Some(ref mc) = self.mem_tracker { - if Arc::strong_count(&self.data) == 1 && Arc::weak_count(&self.data) == 0 { - mc.alloc(-(self.data.capacity() as i64)); - } - } - } -} - -impl AsRef<[u8]> for BufferPtr { +impl AsRef<[u8]> for ByteBufferPtr { #[inline] fn as_ref(&self) -> &[u8] { - &self.data[self.start..self.start + self.len] + &self.data } } -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_byte_buffer_mem_tracker() { - let mem_tracker = Arc::new(MemTracker::new()); - - let mut buffer = ByteBuffer::new().with_mem_tracker(mem_tracker.clone()); - buffer.set_data(vec![0; 10]); - assert_eq!(mem_tracker.memory_usage(), buffer.capacity() as i64); - buffer.set_data(vec![0; 20]); - let capacity = buffer.capacity() as i64; - assert_eq!(mem_tracker.memory_usage(), capacity); - - let max_capacity = { - let mut buffer2 = ByteBuffer::new().with_mem_tracker(mem_tracker.clone()); - buffer2.reserve(30); - assert_eq!( - mem_tracker.memory_usage(), - buffer2.capacity() as i64 + capacity - ); - buffer2.set_data(vec![0; 100]); - assert_eq!( - mem_tracker.memory_usage(), - buffer2.capacity() as i64 + capacity - ); - buffer2.capacity() as i64 + capacity - }; - - assert_eq!(mem_tracker.memory_usage(), capacity); - assert_eq!(mem_tracker.max_memory_usage(), max_capacity); - - buffer.reserve(40); - assert_eq!(mem_tracker.memory_usage(), buffer.capacity() as i64); - - buffer.consume(); - assert_eq!(mem_tracker.memory_usage(), buffer.capacity() as i64); +impl From> for ByteBufferPtr { + fn from(data: Vec) -> Self { + Self { data: data.into() } } +} - #[test] - fn test_byte_ptr_mem_tracker() { - let mem_tracker = Arc::new(MemTracker::new()); - - let mut buffer = ByteBuffer::new().with_mem_tracker(mem_tracker.clone()); - buffer.set_data(vec![0; 60]); - - { - let buffer_capacity = buffer.capacity() as i64; - let buf_ptr = buffer.consume(); - assert_eq!(mem_tracker.memory_usage(), buffer_capacity); - { - let buf_ptr1 = buf_ptr.all(); - { - let _ = buf_ptr.start_from(20); - assert_eq!(mem_tracker.memory_usage(), 
buffer_capacity); - } - assert_eq!(mem_tracker.memory_usage(), buffer_capacity); - let _ = buf_ptr1.range(30, 20); - assert_eq!(mem_tracker.memory_usage(), buffer_capacity); - } - assert_eq!(mem_tracker.memory_usage(), buffer_capacity); - } - assert_eq!(mem_tracker.memory_usage(), buffer.capacity() as i64); +impl From for ByteBufferPtr { + fn from(data: Bytes) -> Self { + Self { data } } +} - #[test] - fn test_byte_buffer() { - let mut buffer = ByteBuffer::new(); - assert_eq!(buffer.size(), 0); - assert_eq!(buffer.capacity(), 0); - - let mut buffer2 = ByteBuffer::new(); - buffer2.reserve(40); - assert_eq!(buffer2.size(), 0); - assert_eq!(buffer2.capacity(), 40); - - buffer.set_data((0..5).collect()); - assert_eq!(buffer.size(), 5); - assert_eq!(buffer[4], 4); - - buffer.set_data((0..20).collect()); - assert_eq!(buffer.size(), 20); - assert_eq!(buffer[10], 10); - - let expected: Vec = (0..20).collect(); - { - let data = buffer.data(); - assert_eq!(data, expected.as_slice()); - } - - buffer.reserve(40); - assert!(buffer.capacity() >= 40); - - let byte_ptr = buffer.consume(); - assert_eq!(buffer.size(), 0); - assert_eq!(byte_ptr.as_ref(), expected.as_slice()); - - let values: Vec = (0..30).collect(); - let _ = buffer.write(values.as_slice()); - let _ = buffer.flush(); - - assert_eq!(buffer.data(), values.as_slice()); - } +#[cfg(test)] +mod tests { + use super::*; #[test] fn test_byte_ptr() { let values = (0..50).collect(); let ptr = ByteBufferPtr::new(values); assert_eq!(ptr.len(), 50); - assert_eq!(ptr.start(), 0); assert_eq!(ptr[40], 40); let ptr2 = ptr.all(); assert_eq!(ptr2.len(), 50); - assert_eq!(ptr2.start(), 0); assert_eq!(ptr2[40], 40); let ptr3 = ptr.start_from(20); assert_eq!(ptr3.len(), 30); - assert_eq!(ptr3.start(), 20); assert_eq!(ptr3[0], 20); let ptr4 = ptr3.range(10, 10); assert_eq!(ptr4.len(), 10); - assert_eq!(ptr4.start(), 30); assert_eq!(ptr4[0], 30); let expected: Vec = (30..40).collect(); diff --git a/parquet/src/util/test_common/file_util.rs b/parquet/src/util/test_common/file_util.rs index 7393b55f1ed2..c2dcd677360d 100644 --- a/parquet/src/util/test_common/file_util.rs +++ b/parquet/src/util/test_common/file_util.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
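The `memory.rs` rewrite above collapses the old `Arc`-plus-offsets `BufferPtr` (and its memory tracker) into a thin wrapper over `bytes::Bytes`, which already provides cheap reference-counted slicing. A hedged sketch of the behaviour being relied on (assumes the `bytes` crate as a dependency; values mirror the `test_byte_ptr` test above):

```rust
use bytes::Bytes;

fn main() {
    // Bytes is a reference-counted, immutable byte buffer.
    let data = Bytes::from((0u8..50).collect::<Vec<u8>>());

    // Slicing is O(1): it bumps the refcount and records an offset/length,
    // which is what ByteBufferPtr::start_from / range now delegate to.
    let from_20 = data.slice(20..);
    let mid = data.slice(30..40);

    assert_eq!(from_20.len(), 30);
    assert_eq!(from_20[0], 20);
    assert_eq!(mid[0], 30);

    // The backing allocation is freed only when every slice is dropped.
    drop(data);
    assert_eq!(mid.len(), 10);
}
```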
-use std::{env, fs, io::Write, path::PathBuf, str::FromStr}; +use std::{fs, path::PathBuf, str::FromStr}; /// Returns path to the test parquet file in 'data' directory pub fn get_test_path(file_name: &str) -> PathBuf { @@ -36,38 +36,3 @@ pub fn get_test_file(file_name: &str) -> fs::File { ) }) } - -/// Returns file handle for a temp file in 'target' directory with a provided content -pub fn get_temp_file(file_name: &str, content: &[u8]) -> fs::File { - // build tmp path to a file in "target/debug/testdata" - let mut path_buf = env::current_dir().unwrap(); - path_buf.push("target"); - path_buf.push("debug"); - path_buf.push("testdata"); - fs::create_dir_all(&path_buf).unwrap(); - path_buf.push(file_name); - - // write file content - let mut tmp_file = fs::File::create(path_buf.as_path()).unwrap(); - tmp_file.write_all(content).unwrap(); - tmp_file.sync_all().unwrap(); - - // return file handle for both read and write - let file = fs::OpenOptions::new() - .read(true) - .write(true) - .open(path_buf.as_path()); - assert!(file.is_ok()); - file.unwrap() -} - -pub fn get_temp_filename() -> PathBuf { - let mut path_buf = env::current_dir().unwrap(); - path_buf.push("target"); - path_buf.push("debug"); - path_buf.push("testdata"); - fs::create_dir_all(&path_buf).unwrap(); - path_buf.push(rand::random::().to_string()); - - path_buf -} diff --git a/parquet/src/util/test_common/mod.rs b/parquet/src/util/test_common/mod.rs index ed65bbe8a820..f0beb16ca954 100644 --- a/parquet/src/util/test_common/mod.rs +++ b/parquet/src/util/test_common/mod.rs @@ -25,8 +25,6 @@ pub use self::rand_gen::random_numbers; pub use self::rand_gen::random_numbers_range; pub use self::rand_gen::RandGen; -pub use self::file_util::get_temp_file; -pub use self::file_util::get_temp_filename; pub use self::file_util::get_test_file; pub use self::file_util::get_test_path; diff --git a/parquet/src/util/test_common/page_util.rs b/parquet/src/util/test_common/page_util.rs index 581845a3c1cf..ffa559f3fecb 100644 --- a/parquet/src/util/test_common/page_util.rs +++ b/parquet/src/util/test_common/page_util.rs @@ -25,13 +25,10 @@ use crate::encodings::levels::LevelEncoder; use crate::errors::Result; use crate::schema::types::{ColumnDescPtr, SchemaDescPtr}; use crate::util::memory::ByteBufferPtr; -use crate::util::memory::MemTracker; -use crate::util::memory::MemTrackerPtr; use crate::util::test_common::random_numbers_range; use rand::distributions::uniform::SampleUniform; use std::collections::VecDeque; use std::mem; -use std::sync::Arc; pub trait DataPageBuilder { fn add_rep_levels(&mut self, max_level: i16, rep_levels: &[i16]); @@ -50,7 +47,6 @@ pub trait DataPageBuilder { pub struct DataPageBuilderImpl { desc: ColumnDescPtr, encoding: Option, - mem_tracker: MemTrackerPtr, num_values: u32, buffer: Vec, rep_levels_byte_len: u32, @@ -66,7 +62,6 @@ impl DataPageBuilderImpl { DataPageBuilderImpl { desc, encoding: None, - mem_tracker: Arc::new(MemTracker::new()), num_values, buffer: vec![], rep_levels_byte_len: 0, @@ -122,7 +117,7 @@ impl DataPageBuilder for DataPageBuilderImpl { ); self.encoding = Some(encoding); let mut encoder: Box> = - get_encoder::(self.desc.clone(), encoding, self.mem_tracker.clone()) + get_encoder::(self.desc.clone(), encoding) .expect("get_encoder() should be OK"); encoder.put(values).expect("put() should be OK"); let encoded_values = encoder @@ -177,13 +172,13 @@ impl> InMemoryPageReader

{ } } -impl<P: Iterator<Item = Page>> PageReader for InMemoryPageReader<P> { +impl<P: Iterator<Item = Page> + Send> PageReader for InMemoryPageReader<P> { fn get_next_page(&mut self) -> Result<Option<Page>> { Ok(self.page_iter.next()) } } -impl<P: Iterator<Item = Page>> Iterator for InMemoryPageReader<P> { +impl<P: Iterator<Item = Page> + Send> Iterator for InMemoryPageReader<P>

{ type Item = Result; fn next(&mut self) -> Option { @@ -223,7 +218,7 @@ impl>> Iterator for InMemoryPageIterator { } } -impl>> PageIterator for InMemoryPageIterator { +impl> + Send> PageIterator for InMemoryPageIterator { fn schema(&mut self) -> Result { Ok(self.schema.clone()) } @@ -252,8 +247,7 @@ pub fn make_pages( let max_def_level = desc.max_def_level(); let max_rep_level = desc.max_rep_level(); - let mem_tracker = Arc::new(MemTracker::new()); - let mut dict_encoder = DictEncoder::::new(desc.clone(), mem_tracker); + let mut dict_encoder = DictEncoder::::new(desc.clone()); for i in 0..num_pages { let mut num_values_cur_page = 0; diff --git a/parquet/src/util/test_common/rand_gen.rs b/parquet/src/util/test_common/rand_gen.rs index ea91b28d4963..d9c256577684 100644 --- a/parquet/src/util/test_common/rand_gen.rs +++ b/parquet/src/util/test_common/rand_gen.rs @@ -91,13 +91,8 @@ impl RandGen for ByteArrayType { impl RandGen for FixedLenByteArrayType { fn gen(len: i32) -> FixedLenByteArray { - let mut rng = thread_rng(); - let value_len = if len < 0 { - rng.gen_range(0..128) - } else { - len as usize - }; - let value = random_bytes(value_len); + assert!(len >= 0); + let value = random_bytes(len as usize); ByteArray::from(value).into() } } diff --git a/parquet/tests/boolean_writer.rs b/parquet/tests/boolean_writer.rs index b9d757e71a8e..dc2eccfbf3c3 100644 --- a/parquet/tests/boolean_writer.rs +++ b/parquet/tests/boolean_writer.rs @@ -15,11 +15,10 @@ // specific language governing permissions and limitations // under the License. -use parquet::column::writer::ColumnWriter; +use parquet::data_type::BoolType; use parquet::file::properties::WriterProperties; use parquet::file::reader::FileReader; use parquet::file::serialized_reader::SerializedFileReader; -use parquet::file::writer::FileWriter; use parquet::file::writer::SerializedFileWriter; use parquet::schema::parser::parse_message_type; use std::fs; @@ -53,25 +52,15 @@ fn it_writes_data_without_hanging() { while let Some(mut col_writer) = row_group_writer.next_column().expect("next column") { - match col_writer { - ColumnWriter::BoolColumnWriter(ref mut typed_writer) => { - typed_writer - .write_batch(&my_bool_values, None, None) - .expect("writing bool column"); - } - _ => { - panic!("only test boolean values"); - } - } - row_group_writer - .close_column(col_writer) - .expect("close column"); + col_writer + .typed::() + .write_batch(&my_bool_values, None, None) + .expect("writing bool column"); + + col_writer.close().expect("close column"); } let rg_md = row_group_writer.close().expect("close row group"); println!("total rows written: {}", rg_md.num_rows()); - writer - .close_row_group(row_group_writer) - .expect("close row groups"); } writer.close().expect("close writer"); diff --git a/parquet/tests/custom_writer.rs b/parquet/tests/custom_writer.rs deleted file mode 100644 index 0a57e79d9558..000000000000 --- a/parquet/tests/custom_writer.rs +++ /dev/null @@ -1,100 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::fs::File; -use std::{ - fs, - io::{prelude::*, SeekFrom}, - sync::Arc, -}; - -use parquet::file::writer::TryClone; -use parquet::{ - basic::Repetition, basic::Type, file::properties::WriterProperties, - file::writer::SerializedFileWriter, schema::types, -}; -use std::env; - -// Test creating some sort of custom writer to ensure the -// appropriate traits are exposed -struct CustomWriter { - file: File, -} - -impl Write for CustomWriter { - fn write(&mut self, buf: &[u8]) -> std::io::Result { - self.file.write(buf) - } - fn flush(&mut self) -> std::io::Result<()> { - self.file.flush() - } -} - -impl Seek for CustomWriter { - fn seek(&mut self, pos: SeekFrom) -> std::io::Result { - self.file.seek(pos) - } -} - -impl TryClone for CustomWriter { - fn try_clone(&self) -> std::io::Result { - use std::io::{Error, ErrorKind}; - Err(Error::new(ErrorKind::Other, "Clone not supported")) - } -} - -#[test] -fn test_custom_writer() { - let schema = Arc::new( - types::Type::group_type_builder("schema") - .with_fields(&mut vec![Arc::new( - types::Type::primitive_type_builder("col1", Type::INT32) - .with_repetition(Repetition::REQUIRED) - .build() - .unwrap(), - )]) - .build() - .unwrap(), - ); - let props = Arc::new(WriterProperties::builder().build()); - - let file = get_temp_file("test_custom_file_writer"); - let test_file = file.try_clone().unwrap(); - - let writer = CustomWriter { file }; - - // test is that this file can be created - let file_writer = SerializedFileWriter::new(writer, schema, props).unwrap(); - std::mem::drop(file_writer); - - // ensure the file now exists and has non zero size - let metadata = test_file.metadata().unwrap(); - assert!(metadata.len() > 0); -} - -/// Returns file handle for a temp file in 'target' directory with a provided content -fn get_temp_file(file_name: &str) -> fs::File { - // build tmp path to a file in "target/debug/testdata" - let mut path_buf = env::current_dir().unwrap(); - path_buf.push("target"); - path_buf.push("debug"); - path_buf.push("testdata"); - fs::create_dir_all(&path_buf).unwrap(); - path_buf.push(file_name); - - File::create(path_buf).unwrap() -} diff --git a/parquet_derive/Cargo.toml b/parquet_derive/Cargo.toml index 31ccba44a263..680074d08705 100644 --- a/parquet_derive/Cargo.toml +++ b/parquet_derive/Cargo.toml @@ -17,7 +17,7 @@ [package] name = "parquet_derive" -version = "7.0.0-SNAPSHOT" +version = "16.0.0" license = "Apache-2.0" description = "Derive macros for the Rust implementation of Apache Parquet" homepage = "https://github.com/apache/arrow-rs" @@ -25,7 +25,8 @@ repository = "https://github.com/apache/arrow-rs" authors = ["Apache Arrow "] keywords = [ "parquet" ] readme = "README.md" -edition = "2018" +edition = "2021" +rust-version = "1.57" [lib] proc-macro = true @@ -34,4 +35,4 @@ proc-macro = true proc-macro2 = "1.0" quote = "1.0" syn = { version = "1.0", features = ["full", "extra-traits"] } -parquet = { path = "../parquet", version = "7.0.0-SNAPSHOT" } +parquet = { path = "../parquet", version = "16.0.0" } diff --git a/parquet_derive/README.md b/parquet_derive/README.md index 
d0804db2a1c6..4f390b0cd911 100644 --- a/parquet_derive/README.md +++ b/parquet_derive/README.md @@ -27,14 +27,17 @@ supported. Derive also has some support for the chrono time library. You must must enable the `chrono` feature to get this support. ## Usage + Add this to your Cargo.toml: + ```toml [dependencies] -parquet = "7.0.0-SNAPSHOT" -parquet_derive = "7.0.0-SNAPSHOT" +parquet = "16.0.0" +parquet_derive = "16.0.0" ``` and this to your crate root: + ```rust extern crate parquet; #[macro_use] extern crate parquet_derive; @@ -75,24 +78,29 @@ writer.close().unwrap(); ``` ## Features -- [X] Support writing `String`, `&str`, `bool`, `i32`, `f32`, `f64`, `Vec` + +- [x] Support writing `String`, `&str`, `bool`, `i32`, `f32`, `f64`, `Vec` - [ ] Support writing dictionaries -- [X] Support writing logical types like timestamp -- [X] Derive definition_levels for `Option` +- [x] Support writing logical types like timestamp +- [x] Derive definition_levels for `Option` - [ ] Derive definition levels for nested structures - [ ] Derive writing tuple struct - [ ] Derive writing `tuple` container types ## Requirements + - Same as `parquet-rs` ## Test + Testing a `*_derive` crate requires an intermediate crate. Go to `parquet_derive_test` and run `cargo test` for unit tests. ## Docs + To build documentation, run `cargo doc --no-deps`. To compile and view in the browser, run `cargo doc --no-deps --open`. ## License + Licensed under the Apache License, Version 2.0: http://www.apache.org/licenses/LICENSE-2.0. diff --git a/parquet_derive/src/lib.rs b/parquet_derive/src/lib.rs index 6d751500e949..fc7af20ca3f1 100644 --- a/parquet_derive/src/lib.rs +++ b/parquet_derive/src/lib.rs @@ -25,7 +25,7 @@ extern crate quote; extern crate parquet; -use syn::{parse_macro_input, Data, DataStruct, DeriveInput}; +use ::syn::{parse_macro_input, Data, DataStruct, DeriveInput}; mod parquet_field; @@ -85,10 +85,7 @@ pub fn parquet_record_writer(input: proc_macro::TokenStream) -> proc_macro::Toke Data::Union(_) => unimplemented!("Union currently is not supported"), }; - let field_infos: Vec<_> = fields - .iter() - .map(|f: &syn::Field| parquet_field::Field::from(f)) - .collect(); + let field_infos: Vec<_> = fields.iter().map(parquet_field::Field::from).collect(); let writer_snippets: Vec = field_infos.iter().map(|x| x.writer_snippet()).collect(); @@ -101,9 +98,9 @@ pub fn parquet_record_writer(input: proc_macro::TokenStream) -> proc_macro::Toke (quote! 
{ impl #generics RecordWriter<#derived_for #generics> for &[#derived_for #generics] { - fn write_to_row_group( + fn write_to_row_group( &self, - row_group_writer: &mut Box + row_group_writer: &mut parquet::file::writer::SerializedRowGroupWriter<'_, W> ) -> Result<(), parquet::errors::ParquetError> { let mut row_group_writer = row_group_writer; let records = &self; // Used by all the writer snippets to be more clear @@ -113,7 +110,7 @@ pub fn parquet_record_writer(input: proc_macro::TokenStream) -> proc_macro::Toke let mut some_column_writer = row_group_writer.next_column().unwrap(); if let Some(mut column_writer) = some_column_writer { #writer_snippets - row_group_writer.close_column(column_writer)?; + column_writer.close()?; } else { return Err(parquet::errors::ParquetError::General("Failed to get next column".into())) } diff --git a/parquet_derive/src/parquet_field.rs b/parquet_derive/src/parquet_field.rs index 36730c7713c5..835ac793e409 100644 --- a/parquet_derive/src/parquet_field.rs +++ b/parquet_derive/src/parquet_field.rs @@ -147,7 +147,7 @@ impl Field { // this expression just switches between non-nullable and nullable write statements let write_batch_expr = if definition_levels.is_some() { quote! { - if let #column_writer(ref mut typed) = column_writer { + if let #column_writer(ref mut typed) = column_writer.untyped() { typed.write_batch(&vals[..], Some(&definition_levels[..]), None)?; } else { panic!("Schema and struct disagree on type for {}", stringify!{#ident}) @@ -155,7 +155,7 @@ impl Field { } } else { quote! { - if let #column_writer(ref mut typed) = column_writer { + if let #column_writer(ref mut typed) = column_writer.untyped() { typed.write_batch(&vals[..], None, None)?; } else { panic!("Schema and struct disagree on type for {}", stringify!{#ident}) @@ -507,48 +507,48 @@ impl Type { match last_part.trim() { "bool" => quote! { None }, - "u8" => quote! { Some(LogicalType::INTEGER(IntType { + "u8" => quote! { Some(LogicalType::Integer { bit_width: 8, is_signed: false, - })) }, - "u16" => quote! { Some(LogicalType::INTEGER(IntType { + }) }, + "u16" => quote! { Some(LogicalType::Integer { bit_width: 16, is_signed: false, - })) }, - "u32" => quote! { Some(LogicalType::INTEGER(IntType { + }) }, + "u32" => quote! { Some(LogicalType::Integer { bit_width: 32, is_signed: false, - })) }, - "u64" => quote! { Some(LogicalType::INTEGER(IntType { + }) }, + "u64" => quote! { Some(LogicalType::Integer { bit_width: 64, is_signed: false, - })) }, - "i8" => quote! { Some(LogicalType::INTEGER(IntType { + }) }, + "i8" => quote! { Some(LogicalType::Integer { bit_width: 8, is_signed: true, - })) }, - "i16" => quote! { Some(LogicalType::INTEGER(IntType { + }) }, + "i16" => quote! { Some(LogicalType::Integer { bit_width: 16, is_signed: true, - })) }, + }) }, "i32" | "i64" => quote! { None }, "usize" => { - quote! { Some(LogicalType::INTEGER(IntType { + quote! { Some(LogicalType::Integer { bit_width: usize::BITS as i8, is_signed: false - })) } + }) } } "isize" => { - quote! { Some(LogicalType::INTEGER(IntType { + quote! { Some(LogicalType::Integer { bit_width: usize::BITS as i8, is_signed: true - })) } + }) } } - "NaiveDate" => quote! { Some(LogicalType::DATE(Default::default())) }, + "NaiveDate" => quote! { Some(LogicalType::Date) }, "NaiveDateTime" => quote! { None }, "f32" | "f64" => quote! { None }, - "String" | "str" => quote! { Some(LogicalType::STRING(Default::default())) }, - "Uuid" => quote! { Some(LogicalType::UUID(Default::default())) }, + "String" | "str" => quote! 
{ Some(LogicalType::String) }, + "Uuid" => quote! { Some(LogicalType::Uuid) }, f => unimplemented!("{} currently is not supported", f), } } @@ -666,7 +666,7 @@ mod test { { let vals : Vec < _ > = records . iter ( ) . map ( | rec | rec . counter as i64 ) . collect ( ); - if let parquet::column::writer::ColumnWriter::Int64ColumnWriter ( ref mut typed ) = column_writer { + if let parquet::column::writer::ColumnWriter::Int64ColumnWriter ( ref mut typed ) = column_writer.untyped() { typed . write_batch ( & vals [ .. ] , None , None ) ?; } else { panic!("Schema and struct disagree on type for {}" , stringify!{ counter } ) @@ -703,7 +703,7 @@ mod test { } }).collect(); - if let parquet::column::writer::ColumnWriter::ByteArrayColumnWriter ( ref mut typed ) = column_writer { + if let parquet::column::writer::ColumnWriter::ByteArrayColumnWriter ( ref mut typed ) = column_writer.untyped() { typed . write_batch ( & vals [ .. ] , Some(&definition_levels[..]) , None ) ? ; } else { panic!("Schema and struct disagree on type for {}" , stringify ! { optional_str } ) @@ -727,7 +727,7 @@ mod test { } }).collect(); - if let parquet::column::writer::ColumnWriter::ByteArrayColumnWriter ( ref mut typed ) = column_writer { + if let parquet::column::writer::ColumnWriter::ByteArrayColumnWriter ( ref mut typed ) = column_writer.untyped() { typed . write_batch ( & vals [ .. ] , Some(&definition_levels[..]) , None ) ? ; } else { panic!("Schema and struct disagree on type for {}" , stringify ! { optional_string } ) @@ -750,7 +750,7 @@ mod test { } }).collect(); - if let parquet::column::writer::ColumnWriter::Int32ColumnWriter ( ref mut typed ) = column_writer { + if let parquet::column::writer::ColumnWriter::Int32ColumnWriter ( ref mut typed ) = column_writer.untyped() { typed . write_batch ( & vals [ .. ] , Some(&definition_levels[..]) , None ) ? ; } else { panic!("Schema and struct disagree on type for {}" , stringify ! 
{ optional_dumb_int } ) @@ -769,7 +769,7 @@ mod test { }; let fields = extract_fields(snippet); - let processed: Vec<_> = fields.iter().map(|field| Field::from(field)).collect(); + let processed: Vec<_> = fields.iter().map(Field::from).collect(); let column_writers: Vec<_> = processed .iter() @@ -800,7 +800,7 @@ mod test { }; let fields = extract_fields(snippet); - let processed: Vec<_> = fields.iter().map(|field| Field::from(field)).collect(); + let processed: Vec<_> = fields.iter().map(Field::from).collect(); assert_eq!(processed.len(), 3); assert_eq!( @@ -840,8 +840,7 @@ mod test { }; let fields = extract_fields(snippet); - let converted_fields: Vec<_> = - fields.iter().map(|field| Type::from(field)).collect(); + let converted_fields: Vec<_> = fields.iter().map(Type::from).collect(); let inner_types: Vec<_> = converted_fields .iter() .map(|field| field.inner_type()) @@ -878,8 +877,7 @@ mod test { }; let fields = extract_fields(snippet); - let converted_fields: Vec<_> = - fields.iter().map(|field| Type::from(field)).collect(); + let converted_fields: Vec<_> = fields.iter().map(Type::from).collect(); let physical_types: Vec<_> = converted_fields .iter() .map(|ty| ty.physical_type()) @@ -911,8 +909,7 @@ mod test { }; let fields = extract_fields(snippet); - let converted_fields: Vec<_> = - fields.iter().map(|field| Type::from(field)).collect(); + let converted_fields: Vec<_> = fields.iter().map(Type::from).collect(); assert_eq!( converted_fields, @@ -938,7 +935,7 @@ mod test { }; let fields = extract_fields(snippet); - let types: Vec<_> = fields.iter().map(|field| Type::from(field)).collect(); + let types: Vec<_> = fields.iter().map(Type::from).collect(); assert_eq!( types, @@ -978,7 +975,7 @@ mod test { assert_eq!(when.writer_snippet().to_string(),(quote!{ { let vals : Vec<_> = records.iter().map(|rec| rec.henceforth.timestamp_millis() ).collect(); - if let parquet::column::writer::ColumnWriter::Int64ColumnWriter(ref mut typed) = column_writer { + if let parquet::column::writer::ColumnWriter::Int64ColumnWriter(ref mut typed) = column_writer.untyped() { typed.write_batch(&vals[..], None, None) ?; } else { panic!("Schema and struct disagree on type for {}" , stringify!{ henceforth }) @@ -998,7 +995,7 @@ mod test { } }).collect(); - if let parquet::column::writer::ColumnWriter::Int64ColumnWriter(ref mut typed) = column_writer { + if let parquet::column::writer::ColumnWriter::Int64ColumnWriter(ref mut typed) = column_writer.untyped() { typed.write_batch(&vals[..], Some(&definition_levels[..]), None) ?; } else { panic!("Schema and struct disagree on type for {}" , stringify!{ maybe_happened }) @@ -1021,7 +1018,7 @@ mod test { assert_eq!(when.writer_snippet().to_string(),(quote!{ { let vals : Vec<_> = records.iter().map(|rec| rec.henceforth.signed_duration_since(chrono::NaiveDate::from_ymd(1970, 1, 1)).num_days() as i32).collect(); - if let parquet::column::writer::ColumnWriter::Int32ColumnWriter(ref mut typed) = column_writer { + if let parquet::column::writer::ColumnWriter::Int32ColumnWriter(ref mut typed) = column_writer.untyped() { typed.write_batch(&vals[..], None, None) ?; } else { panic!("Schema and struct disagree on type for {}" , stringify!{ henceforth }) @@ -1041,7 +1038,7 @@ mod test { } }).collect(); - if let parquet::column::writer::ColumnWriter::Int32ColumnWriter(ref mut typed) = column_writer { + if let parquet::column::writer::ColumnWriter::Int32ColumnWriter(ref mut typed) = column_writer.untyped() { typed.write_batch(&vals[..], Some(&definition_levels[..]), None) ?; } else { 
panic!("Schema and struct disagree on type for {}" , stringify!{ maybe_happened }) @@ -1064,7 +1061,7 @@ mod test { assert_eq!(when.writer_snippet().to_string(),(quote!{ { let vals : Vec<_> = records.iter().map(|rec| (&rec.unique_id.to_string()[..]).into() ).collect(); - if let parquet::column::writer::ColumnWriter::ByteArrayColumnWriter(ref mut typed) = column_writer { + if let parquet::column::writer::ColumnWriter::ByteArrayColumnWriter(ref mut typed) = column_writer.untyped() { typed.write_batch(&vals[..], None, None) ?; } else { panic!("Schema and struct disagree on type for {}" , stringify!{ unique_id }) @@ -1084,7 +1081,7 @@ mod test { } }).collect(); - if let parquet::column::writer::ColumnWriter::ByteArrayColumnWriter(ref mut typed) = column_writer { + if let parquet::column::writer::ColumnWriter::ByteArrayColumnWriter(ref mut typed) = column_writer.untyped() { typed.write_batch(&vals[..], Some(&definition_levels[..]), None) ?; } else { panic!("Schema and struct disagree on type for {}" , stringify!{ maybe_unique_id }) diff --git a/parquet_derive/test/dependency/default-features/Cargo.toml b/parquet_derive/test/dependency/default-features/Cargo.toml deleted file mode 100644 index 2a98dd49657e..000000000000 --- a/parquet_derive/test/dependency/default-features/Cargo.toml +++ /dev/null @@ -1,30 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -[package] -name = "defeault-features" -description = "Models a user application of parquet_derive that uses no additional features of arrow" -version = "0.1.0" -edition = "2018" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -parquet_derive = { path = "../../../../parquet_derive", version = "7.0.0-SNAPSHOT" } - -# Keep this out of the default workspace -[workspace] diff --git a/parquet_derive/test/dependency/default-features/src/main.rs b/parquet_derive/test/dependency/default-features/src/main.rs deleted file mode 100644 index e7a11a969c03..000000000000 --- a/parquet_derive/test/dependency/default-features/src/main.rs +++ /dev/null @@ -1,3 +0,0 @@ -fn main() { - println!("Hello, world!"); -} diff --git a/parquet_derive_test/Cargo.toml b/parquet_derive_test/Cargo.toml index f48530b81da5..7bf6db6730e6 100644 --- a/parquet_derive_test/Cargo.toml +++ b/parquet_derive_test/Cargo.toml @@ -17,17 +17,18 @@ [package] name = "parquet_derive_test" -version = "7.0.0-SNAPSHOT" +version = "16.0.0" license = "Apache-2.0" description = "Integration test package for parquet-derive" homepage = "https://github.com/apache/arrow-rs" repository = "https://github.com/apache/arrow-rs" authors = ["Apache Arrow "] keywords = [ "parquet" ] -edition = "2018" +edition = "2021" publish = false +rust-version = "1.57" [dependencies] -parquet = { path = "../parquet", version = "7.0.0-SNAPSHOT" } -parquet_derive = { path = "../parquet_derive", version = "7.0.0-SNAPSHOT" } -chrono = "0.4.19" \ No newline at end of file +parquet = { path = "../parquet", version = "16.0.0" } +parquet_derive = { path = "../parquet_derive", version = "16.0.0" } +chrono = "0.4.19" diff --git a/parquet_derive_test/src/lib.rs b/parquet_derive_test/src/lib.rs index 2b7c060bbd5f..189802b9a527 100644 --- a/parquet_derive_test/src/lib.rs +++ b/parquet_derive_test/src/lib.rs @@ -54,10 +54,7 @@ mod tests { use super::*; use parquet::{ - file::{ - properties::WriterProperties, - writer::{FileWriter, SerializedFileWriter}, - }, + file::{properties::WriterProperties, writer::SerializedFileWriter}, schema::parser::parse_message_type, }; use std::{env, fs, io::Write, sync::Arc}; @@ -133,7 +130,7 @@ mod tests { let mut row_group = writer.next_row_group().unwrap(); drs.as_slice().write_to_row_group(&mut row_group).unwrap(); - writer.close_row_group(row_group).unwrap(); + row_group.close().unwrap(); writer.close().unwrap(); } diff --git a/rustfmt.toml b/rustfmt.toml index c49cccdd9f5d..4522e520a469 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -edition = "2018" +edition = "2021" max_width = 90 # ignore generated files diff --git a/testing b/testing index b658b087767b..d315f7985207 160000 --- a/testing +++ b/testing @@ -1 +1 @@ -Subproject commit b658b087767b041b2081766814655b4dd5a9a439 +Subproject commit d315f7985207d2d67fc2c8e41053e9d97d573f4b
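Taken together, code written against the 16.0.0 crates follows the shape of the updated `parquet_derive_test`: a struct deriving `ParquetRecordWriter` is written into a row group through the derived `RecordWriter` impl, and the row group and file writers are closed on themselves. A rough end-to-end sketch under those assumptions (the struct, its fields, the schema string, and the output path are illustrative, and the `RecordWriter` trait is assumed to be importable from `parquet::record`):

```rust
use std::{fs, sync::Arc};

use parquet::file::{properties::WriterProperties, writer::SerializedFileWriter};
use parquet::record::RecordWriter;
use parquet::schema::parser::parse_message_type;
use parquet_derive::ParquetRecordWriter;

#[derive(ParquetRecordWriter)]
struct Measurement {
    sensor: String,
    reading: f64,
    ok: bool,
}

fn write_measurements(samples: &[Measurement]) {
    // The schema must line up with the struct fields; both are illustrative here.
    let schema = Arc::new(
        parse_message_type(
            "message schema {
                REQUIRED BINARY sensor (UTF8);
                REQUIRED DOUBLE reading;
                REQUIRED BOOLEAN ok;
            }",
        )
        .expect("parse schema"),
    );
    let props = Arc::new(WriterProperties::builder().build());
    let file = fs::File::create("measurements.parquet").expect("create file");

    let mut writer =
        SerializedFileWriter::new(file, schema, props).expect("create writer");
    let mut row_group = writer.next_row_group().expect("next row group");

    // The derived impl targets `&[Measurement]` and takes the row group writer directly.
    samples
        .write_to_row_group(&mut row_group)
        .expect("write row group");

    // Close on the row group itself; `close_row_group` on the file writer is gone.
    row_group.close().expect("close row group");
    writer.close().expect("close writer");
}
```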