diff --git a/.config/nextest.toml b/.config/nextest.toml new file mode 100644 index 0000000..bac2ec1 --- /dev/null +++ b/.config/nextest.toml @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[profile.default] +slow-timeout = { period = "60s", terminate-after = 3, grace-period = "30s" } diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..d62b1e9 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,184 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
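+# Summary of the jobs below: spell check (typos), cargo check and clippy across feature combinations, rustfmt, taplo TOML formatting, license header and cargo-deny license checks, plus test coverage collected with cargo-llvm-cov and nextest.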
+ +on: + pull_request: + types: [opened, synchronize, reopened, ready_for_review] + paths-ignore: + - 'docs/**' + - 'config/**' + - '**.md' + - '.dockerignore' + - 'docker/**' + - '.gitignore' + push: + branches: + - develop + - main + paths-ignore: + - 'docs/**' + - 'config/**' + - '**.md' + - '.dockerignore' + - 'docker/**' + - '.gitignore' + workflow_dispatch: + +name: CI + +env: + RUST_TOOLCHAIN: stable + +jobs: + typos: + name: Spell Check with Typos + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: crate-ci/typos@v1.13.10 + + check: + name: Check + if: github.event.pull_request.draft == false + runs-on: ubuntu-latest + timeout-minutes: 60 + strategy: + matrix: + features: + - '' + - '--no-default-features' + - '--all-features' + steps: + - uses: actions/checkout@v3 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ env.RUST_TOOLCHAIN }} + - name: Rust Cache + uses: Swatinem/rust-cache@v2 + - name: Run cargo check + run: cargo check --workspace --all-targets ${{ matrix.features }} + + toml: + name: Toml Check + if: github.event.pull_request.draft == false + runs-on: ubuntu-latest + timeout-minutes: 60 + steps: + - uses: actions/checkout@v3 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ env.RUST_TOOLCHAIN }} + - name: Rust Cache + uses: Swatinem/rust-cache@v2 + - name: Install taplo + run: cargo install taplo-cli --version ^0.8 --locked + - name: Run taplo + run: taplo format --check + + fmt: + name: Rustfmt + if: github.event.pull_request.draft == false + runs-on: ubuntu-latest + timeout-minutes: 60 + steps: + - uses: actions/checkout@v3 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ env.RUST_TOOLCHAIN }} + components: rustfmt + - name: Rust Cache + uses: Swatinem/rust-cache@v2 + - name: Run cargo fmt + run: cargo fmt --all -- --check + + clippy: + name: Clippy + if: github.event.pull_request.draft == false + runs-on: ubuntu-latest + timeout-minutes: 60 + strategy: + matrix: + features: + - '' + - '--no-default-features' + - '--all-features' + steps: + - uses: actions/checkout@v3 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ env.RUST_TOOLCHAIN }} + components: clippy + - name: Rust Cache + uses: Swatinem/rust-cache@v2 + - name: Run cargo clippy + run: cargo clippy --workspace --all-targets ${{ matrix.features }} -- -D warnings + + license-header: + name: Check license header + if: github.event.pull_request.draft == false + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Check license headers + uses: korandoru/hawkeye@v5 + + cargo-deny: + name: Cargo Deny License Check + if: github.event.pull_request.draft == false + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: EmbarkStudios/cargo-deny-action@v1 + with: + command: check license + + coverage: + if: github.event.pull_request.draft == false + runs-on: ubuntu-latest + timeout-minutes: 60 + needs: [clippy] + steps: + - uses: actions/checkout@v3 + - uses: KyleMayes/install-llvm-action@v1 + with: + version: "14.0" + - name: Install toolchain + uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ env.RUST_TOOLCHAIN }} + components: llvm-tools-preview + - name: Rust Cache + uses: Swatinem/rust-cache@v2 + - name: Install latest nextest release + uses: taiki-e/install-action@nextest + - name: Install cargo-llvm-cov + uses: taiki-e/install-action@cargo-llvm-cov + - name: Collect coverage data + run: cargo llvm-cov nextest --workspace --lcov --output-path lcov.info --all-features + env: + 
CARGO_BUILD_RUSTFLAGS: "-C link-arg=-fuse-ld=lld" + RUST_BACKTRACE: 1 + CARGO_INCREMENTAL: 0 + UNITTEST_LOG_DIR: "__unittest_logs" + - name: Codecov upload + uses: codecov/codecov-action@v2 + with: + token: ${{ secrets.CODECOV_TOKEN }} + files: ./lcov.info + flags: rust + fail_ci_if_error: false + verbose: true diff --git a/.gitignore b/.gitignore index d01bd1a..3f1214c 100644 --- a/.gitignore +++ b/.gitignore @@ -18,4 +18,14 @@ Cargo.lock # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ \ No newline at end of file +.idea/ + +venv +/benchmark_data + +private/ +*.txt + +/perf.* +/flamegraph.svg + diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..b7cf9cf --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "orc-rust" +version = "0.4.1" +edition = "2021" +homepage = "https://github.com/datafusion-contrib/datafusion-orc" +repository = "https://github.com/datafusion-contrib/datafusion-orc" +authors = ["Weny ", "Jeffrey Vo "] +license = "Apache-2.0" +description = "Implementation of Apache ORC file format using Apache Arrow in-memory format" +keywords = ["arrow", "orc", "arrow-rs", "datafusion"] +include = ["src/**/*.rs", "Cargo.toml"] +rust-version = "1.73" + +[package.metadata.docs.rs] +all-features = true + +[dependencies] +arrow = { version = "52", features = ["prettyprint", "chrono-tz"] } +bytemuck = { version = "1.18.0", features = ["must_cast"] } +bytes = "1.4" +chrono = { version = "0.4.37", default-features = false, features = ["std"] } +chrono-tz = "0.9" +fallible-streaming-iterator = { version = "0.1" } +flate2 = "1" +lz4_flex = "0.11" +lzokay-native = "0.1" +num = "0.4.1" +prost = { version = "0.12" } +snafu = "0.8" +snap = "1.1" +zstd = "0.12" + +# async support +async-trait = { version = "0.1.77", optional = true } +futures = { version = "0.3", optional = true, default-features = false, features = ["std"] } +futures-util = { version = "0.3", optional = true } +tokio = { version = "1.28", optional = true, features = [ + "io-util", + "sync", + "fs", + "macros", + "rt", + "rt-multi-thread", +] } + +# cli +anyhow = { version = "1.0", optional = true } +clap = { version = "4.5.4", features = ["derive"], optional = true } + +# opendal +opendal = { version = "0.48", optional = true, default-features = false } + +[dev-dependencies] +arrow-ipc = { version = "52.0.0", features = ["lz4"] } +arrow-json = "52.0.0" +criterion = { version = "0.5", default-features = false, features = ["async_tokio"] } +opendal = { version = "0.48", default-features = 
false, features = ["services-memory"] } +pretty_assertions = "1.3.0" +proptest = "1.0.0" +serde_json = { version = "1.0", default-features = false, features = ["std"] } + +[features] +default = ["async"] + +async = ["async-trait", "futures", "futures-util", "tokio"] +cli = ["anyhow", "clap"] +# Enable opendal support. +opendal = ["dep:opendal"] + +[[bench]] +name = "arrow_reader" +harness = false +required-features = ["async"] +# There is an issue when publishing if the path isn't specified, so add it here +path = "./benches/arrow_reader.rs" + +[profile.bench] +debug = true + +[[bin]] +name = "orc-metadata" +required-features = ["cli"] + +[[bin]] +name = "orc-export" +required-features = ["cli"] + +[[bin]] +name = "orc-stats" +required-features = ["cli"] diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..829157b --- /dev/null +++ b/Makefile @@ -0,0 +1,17 @@ +.PHONY: fmt +fmt: ## Format all the Rust code. + cargo fmt --all + + +.PHONY: clippy +clippy: ## Check clippy rules. + cargo clippy --workspace --all-targets -- -D warnings + + +.PHONY: fmt-toml +fmt-toml: ## Format all TOML files. + taplo format --option "indent_string= " + +.PHONY: check-toml +check-toml: ## Check all TOML files. + taplo format --check --option "indent_string= " \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..6532c32 --- /dev/null +++ b/README.md @@ -0,0 +1,124 @@ +[![test](https://github.com/datafusion-contrib/datafusion-orc/actions/workflows/ci.yml/badge.svg)](https://github.com/datafusion-contrib/datafusion-orc/actions/workflows/ci.yml) +[![codecov](https://codecov.io/gh/WenyXu/orc-rs/branch/main/graph/badge.svg?token=2CSHZX02XM)](https://codecov.io/gh/WenyXu/orc-rs) +[![Crates.io](https://img.shields.io/crates/v/orc-rust)](https://crates.io/crates/orc-rust) +[![Crates.io](https://img.shields.io/crates/d/orc-rust)](https://crates.io/crates/orc-rust) + +# orc-rust + +A native Rust implementation of the [Apache ORC](https://orc.apache.org) file format, +providing APIs to read data into [Apache Arrow](https://arrow.apache.org) in-memory arrays. + +See the [documentation](https://docs.rs/orc-rust/latest/orc_rust/) for examples of how to use this crate. + +## Supported features + +This crate currently only supports reading ORC files into Arrow arrays. Write support is planned +(see [Roadmap](#roadmap)). The features listed below relate only to reading ORC files. +At this time, we aim to support the [ORCv1](https://orc.apache.org/specification/ORCv1/) specification only. + +- Read synchronously & asynchronously (using Tokio) +- All compression types (Zlib, Snappy, Lzo, Lz4, Zstd) +- All ORC data types +- All encodings +- Rudimentary support for retrieving statistics +- Retrieving user metadata into Arrow schema metadata + +## Roadmap + +The long-term vision for this crate is to be feature-complete enough to be donated to the +[arrow-rs](https://github.com/apache/arrow-rs) project. + +The following is a rough roadmap of features to be implemented, from highest to lowest priority. + +- Performance enhancements +- DataFusion integration +- Predicate pushdown +- Row indices +- Bloom filters +- Write from Arrow arrays +- Encryption + +A non-Arrow API interface is not planned at the moment. Feel free to raise an issue if there is such +a use case. + +## Version compatibility + +No guarantees are provided about stability across versions. 
We will endeavour to keep the top-level APIs +(`ArrowReader` and `ArrowStreamReader`) as stable as we can, but other APIs may change as we +explore the interface we want the library to expose. + +Versions will be released on an ad-hoc basis (with no fixed schedule). + +## Mapping ORC types to Arrow types + +The following table lists how ORC data types are read into Arrow data types: + +| ORC Data Type | Arrow Data Type | Notes | +| ----------------- | -------------------------- | ----- | +| Boolean | Boolean | | +| TinyInt | Int8 | | +| SmallInt | Int16 | | +| Int | Int32 | | +| BigInt | Int64 | | +| Float | Float32 | | +| Double | Float64 | | +| String | Utf8 | | +| Char | Utf8 | | +| VarChar | Utf8 | | +| Binary | Binary | | +| Decimal | Decimal128 | | +| Date | Date32 | | +| Timestamp | Timestamp(Nanosecond, None) | ¹ | +| Timestamp instant | Timestamp(Nanosecond, UTC) | ¹ | +| Struct | Struct | | +| List | List | | +| Map | Map | | +| Union | Union(_, Sparse) | ² | + +¹: `ArrowReaderBuilder::with_schema` allows configuring different time units or decoding to +`Decimal128(38, 9)` (i128 of non-leap nanoseconds since UNIX epoch). +Overflows may happen while decoding to a non-Seconds time unit, and result in an `OrcError`. +Loss of precision may happen while decoding to a non-Nanosecond time unit, and result in an `OrcError`. +`Decimal128(38, 9)` avoids both overflows and loss of precision. + +²: Currently only supports a maximum of 127 variants + +## Contributing + +All contributions are welcome! Feel free to raise an issue if you have a feature request, bug report, +or a question. Feel free to raise a Pull Request without raising an issue first, as long as the Pull +Request is descriptive enough. + +Some tools we use, in addition to standard `cargo`, that require installation are: + +- [taplo](https://taplo.tamasfe.dev/) +- [typos](https://crates.io/crates/typos) + +```shell +cargo install typos-cli +cargo install taplo-cli +``` + +```shell +# Building the crate +cargo build + +# Running the test suite +cargo test + +# Simple benchmarks +cargo bench + +# Formatting TOML files +taplo format + +# Detect any typos in the codebase +typos +``` + +To regenerate/update the [proto.rs](src/proto.rs) file, execute the [regen.sh](regen.sh) script. + +```shell +./regen.sh +``` + diff --git a/benches/arrow_reader.rs b/benches/arrow_reader.rs new file mode 100644 index 0000000..10f2899 --- /dev/null +++ b/benches/arrow_reader.rs @@ -0,0 +1,70 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
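+// Benchmarks end-to-end reads of a zlib-compressed ORC file through both the synchronous and the Tokio-based asynchronous readers.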
+ +use std::fs::File; + +use criterion::{criterion_group, criterion_main, Criterion}; +use futures_util::TryStreamExt; +use orc_rust::arrow_reader::ArrowReaderBuilder; + +fn basic_path(path: &str) -> String { + let dir = env!("CARGO_MANIFEST_DIR"); + format!("{}/tests/basic/data/{}", dir, path) +} + +// demo-12-zlib.orc +// 1,920,800 total rows +// Columns: +// - Int32 +// - Dictionary(UInt64, Utf8) +// - Dictionary(UInt64, Utf8) +// - Dictionary(UInt64, Utf8) +// - Int32 +// - Dictionary(UInt64, Utf8) +// - Int32 +// - Int32 +// - Int32 + +async fn async_read_all() { + let file = "demo-12-zlib.orc"; + let file_path = basic_path(file); + let f = tokio::fs::File::open(file_path).await.unwrap(); + let reader = ArrowReaderBuilder::try_new_async(f) + .await + .unwrap() + .build_async(); + let _ = reader.try_collect::<Vec<_>>().await.unwrap(); +} + +fn sync_read_all() { + let file = "demo-12-zlib.orc"; + let file_path = basic_path(file); + let f = File::open(file_path).unwrap(); + let reader = ArrowReaderBuilder::try_new(f).unwrap().build(); + let _ = reader.collect::<Result<Vec<_>, _>>().unwrap(); +} + +fn criterion_benchmark(c: &mut Criterion) { + c.bench_function("sync reader", |b| b.iter(sync_read_all)); + c.bench_function("async reader", |b| { + b.to_async(tokio::runtime::Runtime::new().unwrap()) + .iter(async_read_all); + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 0000000..f6ab7b3 --- /dev/null +++ b/codecov.yml @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# codecov config +coverage: + status: + project: + default: + threshold: 1% + patch: off +ignore: + - "**/error*.rs" # ignore all error.rs files +comment: # this is a top-level key + layout: "diff" diff --git a/deny.toml b/deny.toml new file mode 100644 index 0000000..da1913e --- /dev/null +++ b/deny.toml @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
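+# Allow-list of dependency licenses, enforced in CI by the cargo-deny job (`cargo deny check license`).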
+ +[licenses] +allow = [ + "Apache-2.0", + "Apache-2.0 WITH LLVM-exception", + "MIT", + "BSD-2-Clause", + "BSD-3-Clause", + "CC0-1.0", +] +exceptions = [ + { allow = ["Unicode-DFS-2016"], name = "unicode-ident" }, +] +version = 2 diff --git a/format/orc_proto.proto b/format/orc_proto.proto new file mode 100644 index 0000000..ff71659 --- /dev/null +++ b/format/orc_proto.proto @@ -0,0 +1,452 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +syntax = "proto2"; + +package orc.proto; + +option java_package = "org.apache.orc"; + +message IntegerStatistics { + optional sint64 minimum = 1; + optional sint64 maximum = 2; + optional sint64 sum = 3; +} + +message DoubleStatistics { + optional double minimum = 1; + optional double maximum = 2; + optional double sum = 3; +} + +message StringStatistics { + optional string minimum = 1; + optional string maximum = 2; + // sum will store the total length of all strings in a stripe + optional sint64 sum = 3; + // If the minimum or maximum value was longer than 1024 bytes, store a lower or upper + // bound instead of the minimum or maximum values above. 
+ optional string lowerBound = 4; + optional string upperBound = 5; +} + +message BucketStatistics { + repeated uint64 count = 1 [packed=true]; +} + +message DecimalStatistics { + optional string minimum = 1; + optional string maximum = 2; + optional string sum = 3; +} + +message DateStatistics { + // min,max values saved as days since epoch + optional sint32 minimum = 1; + optional sint32 maximum = 2; +} + +message TimestampStatistics { + // min,max values saved as milliseconds since epoch + optional sint64 minimum = 1; + optional sint64 maximum = 2; + optional sint64 minimumUtc = 3; + optional sint64 maximumUtc = 4; + // store the lower 6 TS digits for min/max to achieve nanosecond precision + optional int32 minimumNanos = 5; + optional int32 maximumNanos = 6; +} + +message BinaryStatistics { + // sum will store the total binary blob length in a stripe + optional sint64 sum = 1; +} + +// Statistics for list and map +message CollectionStatistics { + optional uint64 minChildren = 1; + optional uint64 maxChildren = 2; + optional uint64 totalChildren = 3; +} + +message ColumnStatistics { + optional uint64 numberOfValues = 1; + optional IntegerStatistics intStatistics = 2; + optional DoubleStatistics doubleStatistics = 3; + optional StringStatistics stringStatistics = 4; + optional BucketStatistics bucketStatistics = 5; + optional DecimalStatistics decimalStatistics = 6; + optional DateStatistics dateStatistics = 7; + optional BinaryStatistics binaryStatistics = 8; + optional TimestampStatistics timestampStatistics = 9; + optional bool hasNull = 10; + optional uint64 bytesOnDisk = 11; + optional CollectionStatistics collectionStatistics = 12; +} + +message RowIndexEntry { + repeated uint64 positions = 1 [packed=true]; + optional ColumnStatistics statistics = 2; +} + +message RowIndex { + repeated RowIndexEntry entry = 1; +} + +message BloomFilter { + optional uint32 numHashFunctions = 1; + repeated fixed64 bitset = 2; + optional bytes utf8bitset = 3; +} + +message BloomFilterIndex { + repeated BloomFilter bloomFilter = 1; +} + +message Stream { + // if you add new index stream kinds, you need to make sure to update + // StreamName to ensure it is added to the stripe in the right area + enum Kind { + PRESENT = 0; + DATA = 1; + LENGTH = 2; + DICTIONARY_DATA = 3; + DICTIONARY_COUNT = 4; + SECONDARY = 5; + ROW_INDEX = 6; + BLOOM_FILTER = 7; + BLOOM_FILTER_UTF8 = 8; + // Virtual stream kinds to allocate space for encrypted index and data. + ENCRYPTED_INDEX = 9; + ENCRYPTED_DATA = 10; + + // stripe statistics streams + STRIPE_STATISTICS = 100; + // A virtual stream kind that is used for setting the encryption IV. 
+ FILE_STATISTICS = 101; + } + optional Kind kind = 1; + optional uint32 column = 2; + optional uint64 length = 3; +} + +message ColumnEncoding { + enum Kind { + DIRECT = 0; + DICTIONARY = 1; + DIRECT_V2 = 2; + DICTIONARY_V2 = 3; + } + optional Kind kind = 1; + optional uint32 dictionarySize = 2; + + // The encoding of the bloom filters for this column: + // 0 or missing = none or original + // 1 = ORC-135 (utc for timestamps) + optional uint32 bloomEncoding = 3; +} + +message StripeEncryptionVariant { + repeated Stream streams = 1; + repeated ColumnEncoding encoding = 2; +} + +// each stripe looks like: +// index streams +// unencrypted +// variant 1..N +// data streams +// unencrypted +// variant 1..N +// footer + +message StripeFooter { + repeated Stream streams = 1; + repeated ColumnEncoding columns = 2; + optional string writerTimezone = 3; + // one for each column encryption variant + repeated StripeEncryptionVariant encryption = 4; +} + +// the file tail looks like: +// encrypted stripe statistics: ColumnarStripeStatistics (order by variant) +// stripe statistics: Metadata +// footer: Footer +// postscript: PostScript +// psLen: byte + +message StringPair { + optional string key = 1; + optional string value = 2; +} + +message Type { + enum Kind { + BOOLEAN = 0; + BYTE = 1; + SHORT = 2; + INT = 3; + LONG = 4; + FLOAT = 5; + DOUBLE = 6; + STRING = 7; + BINARY = 8; + TIMESTAMP = 9; + LIST = 10; + MAP = 11; + STRUCT = 12; + UNION = 13; + DECIMAL = 14; + DATE = 15; + VARCHAR = 16; + CHAR = 17; + TIMESTAMP_INSTANT = 18; + } + optional Kind kind = 1; + repeated uint32 subtypes = 2 [packed=true]; + repeated string fieldNames = 3; + optional uint32 maximumLength = 4; + optional uint32 precision = 5; + optional uint32 scale = 6; + repeated StringPair attributes = 7; +} + +message StripeInformation { + // the global file offset of the start of the stripe + optional uint64 offset = 1; + // the number of bytes of index + optional uint64 indexLength = 2; + // the number of bytes of data + optional uint64 dataLength = 3; + // the number of bytes in the stripe footer + optional uint64 footerLength = 4; + // the number of rows in this stripe + optional uint64 numberOfRows = 5; + // If this is present, the reader should use this value for the encryption + // stripe id for setting the encryption IV. Otherwise, the reader should + // use one larger than the previous stripe's encryptStripeId. + // For unmerged ORC files, the first stripe will use 1 and the rest of the + // stripes won't have it set. For merged files, the stripe information + // will be copied from their original files and thus the first stripe of + // each of the input files will reset it to 1. + // Note that 1 was chosen, because protobuf v3 doesn't serialize + // primitive types that are the default (e.g. 0). + optional uint64 encryptStripeId = 6; + // For each encryption variant, the new encrypted local key to use + // until we find a replacement. + repeated bytes encryptedLocalKeys = 7; +} + +message UserMetadataItem { + optional string name = 1; + optional bytes value = 2; +} + +// StripeStatistics (1 per stripe), each of which contains the +// ColumnStatistics for each column. +// This message type is only used in ORC v0 and v1. +message StripeStatistics { + repeated ColumnStatistics colStats = 1; +} + +// This message type is only used in ORC v0 and v1. 
+message Metadata { + repeated StripeStatistics stripeStats = 1; +} + +// In ORC v2 (and for encrypted columns in v1), each column has +// its column statistics written separately. +message ColumnarStripeStatistics { + // one value for each stripe in the file + repeated ColumnStatistics colStats = 1; +} + +enum EncryptionAlgorithm { + UNKNOWN_ENCRYPTION = 0; // used for detecting future algorithms + AES_CTR_128 = 1; + AES_CTR_256 = 2; +} + +message FileStatistics { + repeated ColumnStatistics column = 1; +} + +// How was the data masked? This isn't necessary for reading the file, but +// is documentation about how the file was written. +message DataMask { + // the kind of masking, which may include third party masks + optional string name = 1; + // parameters for the mask + repeated string maskParameters = 2; + // the unencrypted column roots this mask was applied to + repeated uint32 columns = 3 [packed = true]; +} + +// Information about the encryption keys. +message EncryptionKey { + optional string keyName = 1; + optional uint32 keyVersion = 2; + optional EncryptionAlgorithm algorithm = 3; +} + +// The description of an encryption variant. +// Each variant is a single subtype that is encrypted with a single key. +message EncryptionVariant { + // the column id of the root + optional uint32 root = 1; + // The master key that was used to encrypt the local key, referenced as + // an index into the Encryption.key list. + optional uint32 key = 2; + // the encrypted key for the file footer + optional bytes encryptedKey = 3; + // the stripe statistics for this variant + repeated Stream stripeStatistics = 4; + // encrypted file statistics as a FileStatistics + optional bytes fileStatistics = 5; +} + +// Which KeyProvider encrypted the local keys. +enum KeyProviderKind { + UNKNOWN = 0; + HADOOP = 1; + AWS = 2; + GCP = 3; + AZURE = 4; +} + +message Encryption { + // all of the masks used in this file + repeated DataMask mask = 1; + // all of the keys used in this file + repeated EncryptionKey key = 2; + // The encrypted variants. + // Readers should prefer the first variant for which the user has access to + // the corresponding key. If they don't have access to any of the keys, + // they should get the unencrypted masked data. + repeated EncryptionVariant variants = 3; + // How are the local keys encrypted? + optional KeyProviderKind keyProvider = 4; +} + +enum CalendarKind { + UNKNOWN_CALENDAR = 0; + // A hybrid Julian/Gregorian calendar with a cutover point in October 1582. + JULIAN_GREGORIAN = 1; + // A calendar that extends the Gregorian calendar back forever. + PROLEPTIC_GREGORIAN = 2; +} + +message Footer { + optional uint64 headerLength = 1; + optional uint64 contentLength = 2; + repeated StripeInformation stripes = 3; + repeated Type types = 4; + repeated UserMetadataItem metadata = 5; + optional uint64 numberOfRows = 6; + repeated ColumnStatistics statistics = 7; + optional uint32 rowIndexStride = 8; + + // Each implementation that writes ORC files should register for a code + // 0 = ORC Java + // 1 = ORC C++ + // 2 = Presto + // 3 = Scritchley Go from https://github.com/scritchley/orc + // 4 = Trino + optional uint32 writer = 9; + + // information about the encryption in this file + optional Encryption encryption = 10; + optional CalendarKind calendar = 11; + + // informative description about the version of the software that wrote + // the file. It is assumed to be within a given writer, so for example + // ORC 1.7.2 = "1.7.2". It may include suffixes, such as "-SNAPSHOT". 
+ optional string softwareVersion = 12; +} + +enum CompressionKind { + NONE = 0; + ZLIB = 1; + SNAPPY = 2; + LZO = 3; + LZ4 = 4; + ZSTD = 5; +} + +// Serialized length must be less than 255 bytes +message PostScript { + optional uint64 footerLength = 1; + optional CompressionKind compression = 2; + optional uint64 compressionBlockSize = 3; + // the version of the file format + // [0, 11] = Hive 0.11 + // [0, 12] = Hive 0.12 + repeated uint32 version = 4 [packed = true]; + optional uint64 metadataLength = 5; + + // The version of the writer that wrote the file. This number is + // updated when we make fixes or large changes to the writer so that + // readers can detect whether a given bug is present in the data. + // + // Only the Java ORC writer may use values under 6 (or missing) so that + // readers that predate ORC-202 treat the new writers correctly. Each + // writer should assign their own sequence of versions starting from 6. + // + // Version of the ORC Java writer: + // 0 = original + // 1 = HIVE-8732 fixed (fixed stripe/file maximum statistics & + // string statistics use utf8 for min/max) + // 2 = HIVE-4243 fixed (use real column names from Hive tables) + // 3 = HIVE-12055 added (vectorized writer implementation) + // 4 = HIVE-13083 fixed (decimals write present stream correctly) + // 5 = ORC-101 fixed (bloom filters use utf8 consistently) + // 6 = ORC-135 fixed (timestamp statistics use utc) + // 7 = ORC-517 fixed (decimal64 min/max incorrect) + // 8 = ORC-203 added (trim very long string statistics) + // 9 = ORC-14 added (column encryption) + // + // Version of the ORC C++ writer: + // 6 = original + // + // Version of the Presto writer: + // 6 = original + // + // Version of the Scritchley Go writer: + // 6 = original + // + // Version of the Trino writer: + // 6 = original + // + optional uint32 writerVersion = 6; + + // the number of bytes in the encrypted stripe statistics + optional uint64 stripeStatisticsLength = 7; + + // Leave this last in the record + optional string magic = 8000; +} + +// The contents of the file tail that must be serialized. +// This gets serialized as part of OrcSplit, also used by footer cache. +message FileTail { + optional PostScript postscript = 1; + optional Footer footer = 2; + optional uint64 fileLength = 3; + optional uint64 postscriptLength = 4; +} diff --git a/gen/Cargo.toml b/gen/Cargo.toml new file mode 100644 index 0000000..b9f2767 --- /dev/null +++ b/gen/Cargo.toml @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
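+# Helper crate run by regen.sh (via `cargo run --manifest-path gen/Cargo.toml`) to regenerate src/proto.rs from format/orc_proto.proto using prost-build.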
+ +[package] +name = "gen" +description = "Code generation for datafusion-orc" +version = "0.1.0" +edition = "2021" +rust-version = "1.70" +license = "Apache-2.0" +publish = false + +[dependencies] +prost-build = { version = "=0.12.1", default-features = false } diff --git a/gen/src/main.rs b/gen/src/main.rs new file mode 100644 index 0000000..c1cba08 --- /dev/null +++ b/gen/src/main.rs @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::fs::{remove_file, OpenOptions}; +use std::io::{Read, Write}; + +fn main() -> Result<(), Box<dyn std::error::Error>> { + prost_build::Config::new() + .out_dir("src/") + .compile_well_known_types() + .extern_path(".google.protobuf", "::pbjson_types") + .compile_protos(&["format/orc_proto.proto"], &["format"])?; + + // read file contents to string + let mut file = OpenOptions::new().read(true).open("src/orc.proto.rs")?; + let mut buffer = String::new(); + file.read_to_string(&mut buffer)?; + // append warning that file was auto-generated + let mut file = OpenOptions::new() + .write(true) + .truncate(true) + .create(true) + .open("src/proto.rs")?; + file.write_all("// This file was automatically generated through the regen.sh script, and should not be edited.\n\n".as_bytes())?; + file.write_all(buffer.as_bytes())?; + + // since we renamed the file to proto.rs to avoid a period in the name + remove_file("src/orc.proto.rs")?; + + // As the proto file is checked in, the build should not fail if the file is not found + Ok(()) +} diff --git a/licenserc.toml b/licenserc.toml new file mode 100644 index 0000000..c9a40f4 --- /dev/null +++ b/licenserc.toml @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +headerPath = "Apache-2.0-ASF.txt" + +excludes = [ + "**/*.md" +] + diff --git a/regen.sh b/regen.sh new file mode 100755 index 0000000..87086c8 --- /dev/null +++ b/regen.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +cd $SCRIPT_DIR && cargo run --manifest-path gen/Cargo.toml +rustfmt src/proto.rs diff --git a/scripts/README.md b/scripts/README.md new file mode 100644 index 0000000..9fd9403 --- /dev/null +++ b/scripts/README.md @@ -0,0 +1,17 @@ +## Generate data + +Set up the virtual environment with dependencies on PyArrow, PySpark and PyOrc +to generate the reference data: + +```bash +# Run once +./scripts/setup-venv.sh +./scripts/prepare-test-data.sh +``` + +Then execute the tests: + +```bash +cargo test +``` + diff --git a/scripts/convert_tpch.py b/scripts/convert_tpch.py new file mode 100644 index 0000000..524201e --- /dev/null +++ b/scripts/convert_tpch.py @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
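+# Converts the pipe-delimited TPC-H .tbl files under benchmark_data/ into ORC files, applying the explicit Arrow schemas defined below.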
+ +import pyarrow as pa +from pyarrow import orc +from pyarrow import csv + +tables = [ + "customer", + "lineitem", + "nation", + "orders", + "part", + "partsupp", + "region", + "supplier" +] + +# Datatypes based on: +# https://github.com/apache/datafusion/blob/3b93cc952b889cec2364ad2490ae18ecddb3ca49/benchmarks/src/tpch/mod.rs#L50-L134 +schemas = { + "customer": pa.schema([ + pa.field("c_custkey", pa.int64()), + pa.field("c_name", pa.string()), + pa.field("c_address", pa.string()), + pa.field("c_nationkey", pa.int64()), + pa.field("c_phone", pa.string()), + pa.field("c_acctbal", pa.decimal128(15, 2)), + pa.field("c_mktsegment", pa.string()), + pa.field("c_comment", pa.string()), + ]), + "lineitem": pa.schema([ + pa.field("l_orderkey", pa.int64()), + pa.field("l_partkey", pa.int64()), + pa.field("l_suppkey", pa.int64()), + pa.field("l_linenumber", pa.int32()), + pa.field("l_quantity", pa.decimal128(15, 2)), + pa.field("l_extendedprice", pa.decimal128(15, 2)), + pa.field("l_discount", pa.decimal128(15, 2)), + pa.field("l_tax", pa.decimal128(15, 2)), + pa.field("l_returnflag", pa.string()), + pa.field("l_linestatus", pa.string()), + pa.field("l_shipdate", pa.date32()), + pa.field("l_commitdate", pa.date32()), + pa.field("l_receiptdate", pa.date32()), + pa.field("l_shipinstruct", pa.string()), + pa.field("l_shipmode", pa.string()), + pa.field("l_comment", pa.string()), + ]), + "nation": pa.schema([ + pa.field("n_nationkey", pa.int64()), + pa.field("n_name", pa.string()), + pa.field("n_regionkey", pa.int64()), + pa.field("n_comment", pa.string()), + ]), + "orders": pa.schema([ + pa.field("o_orderkey", pa.int64()), + pa.field("o_custkey", pa.int64()), + pa.field("o_orderstatus", pa.string()), + pa.field("o_totalprice", pa.decimal128(15, 2)), + pa.field("o_orderdate", pa.date32()), + pa.field("o_orderpriority", pa.string()), + pa.field("o_clerk", pa.string()), + pa.field("o_shippriority", pa.int32()), + pa.field("o_comment", pa.string()), + ]), + "part": pa.schema([ + pa.field("p_partkey", pa.int64()), + pa.field("p_name", pa.string()), + pa.field("p_mfgr", pa.string()), + pa.field("p_brand", pa.string()), + pa.field("p_type", pa.string()), + pa.field("p_size", pa.int32()), + pa.field("p_container", pa.string()), + pa.field("p_retailprice", pa.decimal128(15, 2)), + pa.field("p_comment", pa.string()), + ]), + "partsupp": pa.schema([ + pa.field("ps_partkey", pa.int64()), + pa.field("ps_suppkey", pa.int64()), + pa.field("ps_availqty", pa.int32()), + pa.field("ps_supplycost", pa.decimal128(15, 2)), + pa.field("ps_comment", pa.string()), + ]), + "region": pa.schema([ + pa.field("r_regionkey", pa.int64()), + pa.field("r_name", pa.string()), + pa.field("r_comment", pa.string()), + ]), + "supplier": pa.schema([ + pa.field("s_suppkey", pa.int64()), + pa.field("s_name", pa.string()), + pa.field("s_address", pa.string()), + pa.field("s_nationkey", pa.int64()), + pa.field("s_phone", pa.string()), + pa.field("s_acctbal", pa.decimal128(15, 2)), + pa.field("s_comment", pa.string()), + ]), +} + +for table in tables: + schema = schemas[table] + tbl = csv.read_csv( + f"benchmark_data/{table}.tbl", + read_options=csv.ReadOptions(column_names=schema.names), + parse_options=csv.ParseOptions(delimiter="|"), + convert_options=csv.ConvertOptions(column_types=schema), + ) + orc.write_table(tbl, f"benchmark_data/{table}.orc") diff --git a/scripts/generate-tpch.sh b/scripts/generate-tpch.sh new file mode 100755 index 0000000..78931e0 --- /dev/null +++ b/scripts/generate-tpch.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash + +# 
Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +BASE_DIR=$SCRIPT_DIR/.. +DATA_DIR=$BASE_DIR/benchmark_data +VENV_BIN=$BASE_DIR/venv/bin + +SCALE_FACTOR=${1:-1} + +# Generate TBL data +mkdir -p $DATA_DIR +docker run --rm \ + -v $DATA_DIR:/data \ + ghcr.io/scalytics/tpch-docker:main -vf -s $SCALE_FACTOR +# Removing trailing | +sed -i 's/.$//' benchmark_data/*.tbl +$VENV_BIN/python $SCRIPT_DIR/convert_tpch.py +echo "Done" diff --git a/scripts/generate_arrow.py b/scripts/generate_arrow.py new file mode 100644 index 0000000..6515f21 --- /dev/null +++ b/scripts/generate_arrow.py @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Requires pyarrow to be installed +import glob +from pyarrow import orc, feather + +dir = "tests/integration/data" + +files = glob.glob(f"{dir}/expected/*") +files = [file.removeprefix(f"{dir}/expected/").removesuffix(".jsn.gz") for file in files] + +ignore_files = [ + "TestOrcFile.testTimestamp" # Root data type isn't struct +] + +files = [file for file in files if file not in ignore_files] + +for file in files: + print(f"Converting {file} from ORC to feather") + table = orc.read_table(f"{dir}/{file}.orc") + feather.write_feather(table, f"{dir}/expected_arrow/{file}.feather") diff --git a/scripts/generate_orc.py b/scripts/generate_orc.py new file mode 100644 index 0000000..a962bf8 --- /dev/null +++ b/scripts/generate_orc.py @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import shutil +import glob +from datetime import date as dt +from decimal import Decimal as Dec +from pyspark.sql import SparkSession +from pyspark.sql.types import * + +dir = "tests/basic/data" + +# We're using Spark because it supports lzo compression writing +# (PyArrow supports all except lzo writing) + +spark = SparkSession.builder.getOrCreate() + +# TODO: how to do char and varchar? +# TODO: struct, list, map, union +df = spark.createDataFrame( + [ # bool, int8, int16, int32, int64, float32, float64, decimal, binary, utf8, date32 + ( None, None, None, None, None, None, None, None, None, None, None), + ( True, 0, 0, 0, 0, 0.0, 0.0, Dec(0), "".encode(), "", dt(1970, 1, 1)), + (False, 1, 1, 1, 1, 1.0, 1.0, Dec(1), "a".encode(), "a", dt(1970, 1, 2)), + (False, -1, -1, -1, -1, -1.0, -1.0, Dec(-1), " ".encode(), " ", dt(1969, 12, 31)), + ( True, 127, (1 << 15) - 1, (1 << 31) - 1, (1 << 63) - 1, float("inf"), float("inf"), Dec(123456789.12345), "encode".encode(), "encode", dt(9999, 12, 31)), + ( True, -128, -(1 << 15), -(1 << 31), -(1 << 63), float("-inf"), float("-inf"), Dec(-999999999.99999), "decode".encode(), "decode", dt(1582, 10, 15)), + ( True, 50, 50, 50, 50, 3.1415927, 3.14159265359, Dec(-31256.123), "大熊和奏".encode(), "大熊和奏", dt(1582, 10, 16)), + ( True, 51, 51, 51, 51, -3.1415927, -3.14159265359, Dec(1241000), "斉藤朱夏".encode(), "斉藤朱夏", dt(2000, 1, 1)), + ( True, 52, 52, 52, 52, 1.1, 1.1, Dec(1.1), "鈴原希実".encode(), "鈴原希実", dt(3000, 12, 31)), + (False, 53, 53, 53, 53, -1.1, -1.1, Dec(0.99999), "🤔".encode(), "🤔", dt(1900, 1, 1)), + ( None, None, None, None, None, None, None, None, None, None, None), + ], + StructType( + [ + StructField("boolean", BooleanType()), + StructField( "int8", ByteType()), + StructField( "int16", ShortType()), + StructField( "int32", IntegerType()), + StructField( "int64", LongType()), + StructField("float32", FloatType()), + StructField("float64", DoubleType()), + StructField("decimal", DecimalType(15, 5)), + StructField( "binary", BinaryType()), + StructField( "utf8", StringType()), + StructField( "date32", DateType()), + ] + ), +).coalesce(1) + +compression = ["none", "snappy", "zlib", "lzo", "zstd", "lz4"] +for c in compression: + df.write.format("orc")\ + .option("compression", c)\ + .mode("overwrite")\ + .save(f"{dir}/alltypes.{c}") + # Since Spark saves into a directory + # Move out and rename the expected single ORC file (because of coalesce above) + orc_file = glob.glob(f"{dir}/alltypes.{c}/*.orc")[0] + shutil.move(orc_file, f"{dir}/alltypes.{c}.orc") + shutil.rmtree(f"{dir}/alltypes.{c}") diff --git a/scripts/generate_orc_timestamps.py b/scripts/generate_orc_timestamps.py new file mode 100644 index 0000000..aac13f0 --- /dev/null +++ b/scripts/generate_orc_timestamps.py @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from datetime import datetime as dttm +import pyarrow as pa +from pyarrow import orc +from pyarrow import parquet +import pyorc + +dir = "tests/basic/data" + +schema = pa.schema([ + pa.field('timestamp_notz', pa.timestamp("ns")), + pa.field('timestamp_utc', pa.timestamp("ns", tz="UTC")), +]) + +# TODO test with other non-UTC timezones +arr = pa.array([ + None, + dttm(1970, 1, 1, 0, 0, 0), + dttm(1970, 1, 2, 23, 59, 59), + dttm(1969, 12, 31, 23, 59, 59), + dttm(2262, 4, 11, 11, 47, 16), + dttm(2001, 4, 13, 2, 14, 0), + dttm(2000, 1, 1, 23, 10, 10), + dttm(1900, 1, 1, 14, 25, 14), +]) +table = pa.Table.from_arrays([arr, arr], schema=schema) +orc.write_table(table, f"{dir}/pyarrow_timestamps.orc") + + +# pyarrow overflows when trying to write this, so we have to use pyorc instead +class TimestampConverter: + @staticmethod + def from_orc(obj, tz): + return obj + @staticmethod + def to_orc(obj, tz): + return obj +schema = pyorc.Struct( + id=pyorc.Int(), + timestamp=pyorc.Timestamp() +) +with open(f"{dir}/overflowing_timestamps.orc", "wb") as f: + with pyorc.Writer( + f, + schema, + converters={pyorc.TypeKind.TIMESTAMP: TimestampConverter}, + ) as writer: + writer.write((1, (12345678, 0))) + writer.write((2, (-62135596800, 0))) + writer.write((3, (12345678, 0))) diff --git a/scripts/prepare-test-data.sh b/scripts/prepare-test-data.sh new file mode 100755 index 0000000..4583313 --- /dev/null +++ b/scripts/prepare-test-data.sh @@ -0,0 +1,31 @@ +#!/usr/bin/env bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +BASE_DIR=$SCRIPT_DIR/.. +VENV_BIN=$BASE_DIR/venv/bin + +cd $BASE_DIR +$VENV_BIN/python $SCRIPT_DIR/write.py +$VENV_BIN/python $SCRIPT_DIR/generate_orc.py +$VENV_BIN/python $SCRIPT_DIR/generate_orc_timestamps.py +$VENV_BIN/python $SCRIPT_DIR/generate_arrow.py + +echo "Done" + diff --git a/scripts/setup-venv.sh b/scripts/setup-venv.sh new file mode 100755 index 0000000..3e8bff7 --- /dev/null +++ b/scripts/setup-venv.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +BASE_DIR=$SCRIPT_DIR/.. +VENV_BIN=$BASE_DIR/venv/bin + +python3 -m venv $BASE_DIR/venv + +$VENV_BIN/pip install -U pyorc pyspark pyarrow + +echo "Done" + diff --git a/scripts/write.py b/scripts/write.py new file mode 100644 index 0000000..e4d08ad --- /dev/null +++ b/scripts/write.py @@ -0,0 +1,213 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
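+# Writes the small ORC fixture files under tests/basic/data that the test suite reads back: nested structs/lists/maps, integer RLE cases, dictionary- and direct-encoded strings, timestamps and dates.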
+ +# Copied from https://github.com/DataEngineeringLabs/orc-format/blob/416490db0214fc51d53289253c0ee91f7fc9bc17/write.py +import random +import datetime +import pyorc + +dir = "tests/basic/data" + +data = { + "a": [1.0, 2.0, None, 4.0, 5.0], + "b": [True, False, None, True, False], + "str_direct": ["a", "cccccc", None, "ddd", "ee"], + "d": ["a", "bb", None, "ccc", "ddd"], + "e": ["ddd", "cc", None, "bb", "a"], + "f": ["aaaaa", "bbbbb", None, "ccccc", "ddddd"], + "int_short_repeated": [5, 5, None, 5, 5], + "int_neg_short_repeated": [-5, -5, None, -5, -5], + "int_delta": [1, 2, None, 4, 5], + "int_neg_delta": [5, 4, None, 2, 1], + "int_direct": [1, 6, None, 3, 2], + "int_neg_direct": [-1, -6, None, -3, -2], + "bigint_direct": [1, 6, None, 3, 2], + "bigint_neg_direct": [-1, -6, None, -3, -2], + "bigint_other": [5, -5, 1, 5, 5], + "utf8_increase": ["a", "bb", "ccc", "dddd", "eeeee"], + "utf8_decrease": ["eeeee", "dddd", "ccc", "bb", "a"], + "timestamp_simple": [datetime.datetime(2023, 4, 1, 20, 15, 30, 2000), datetime.datetime.fromtimestamp(int('1629617204525777000')/1000000000), datetime.datetime(2023, 1, 1), datetime.datetime(2023, 2, 1), datetime.datetime(2023, 3, 1)], + "date_simple": [datetime.date(2023, 4, 1), datetime.date(2023, 3, 1), datetime.date(2023, 1, 1), datetime.date(2023, 2, 1), datetime.date(2023, 3, 1)], + "tinyint_simple": [-1, None, 1, 127, -127] +} + +def infer_schema(data): + schema = "struct<" + for key, value in data.items(): + dt = type(value[0]) + if dt == float: + dt = "float" + elif dt == int: + dt = "int" + elif dt == bool: + dt = "boolean" + elif dt == str: + dt = "string" + elif dt == dict: + dt = infer_schema(value[0]) + elif key.startswith("timestamp"): + dt = "timestamp" + elif key.startswith("date"): + dt = "date" + else: + print(key,value,dt) + raise NotImplementedError + if key.startswith("double"): + dt = "double" + if key.startswith("bigint"): + dt = "bigint" + if key.startswith("tinyint"): + dt = "tinyint" + schema += key + ":" + dt + "," + + schema = schema[:-1] + ">" + return schema + + + +def _write( + schema: str, + data, + file_name: str, + compression=pyorc.CompressionKind.NONE, + dict_key_size_threshold=0.0, +): + output = open(file_name, "wb") + writer = pyorc.Writer( + output, + schema, + dict_key_size_threshold=dict_key_size_threshold, + # use a small number to ensure that compression crosses value boundaries + compression_block_size=32, + compression=compression, + ) + num_rows = len(list(data.values())[0]) + for x in range(num_rows): + row = tuple(values[x] for values in data.values()) + writer.write(row) + writer.close() + + with open(file_name, "rb") as f: + reader = pyorc.Reader(f) + list(reader) + +nested_struct = { + "nest": [ + (1.0,True), + (3.0,None), + (None,None), + None, + (-3.0,None) + ], +} + +_write("struct<nest:struct<a:float,b:boolean>>", nested_struct, f"{dir}/nested_struct.orc") + + +nested_array = { + "value": [ + [1, None, 3, 43, 5], + [5, None, 32, 4, 15], + [16, None, 3, 4, 5, 6], + None, + [3, None], + ], +} + +_write("struct<value:array<int>>", nested_array, f"{dir}/nested_array.orc") + + +nested_array_float = { + "value": [ + [1.0, 3.0], + [None, 2.0], + ], +} + +_write("struct<value:array<float>>", nested_array_float, f"{dir}/nested_array_float.orc") + +nested_array_struct = { + "value": [ + [(1.0, 1, "01"), (2.0, 2, "02")], + [None, (3.0, 3, "03")], + ], +} + +_write("struct<value:array<struct<a:float,b:int,c:string>>>", nested_array_struct, f"{dir}/nested_array_struct.orc") + +nested_map = { + "map": [ + {"zero": 0, "one": 1}, + None, + {"two": 2, "tree": 3}, + {"one": 1, "two": 2, "nill": None}, + ], +} + 
+_write("struct>", nested_map, f"{dir}/nested_map.orc") + +nested_map_struct = { + "map": [ + {"01": (1.0, 1, "01"), "02": (2.0, 1, "02")}, + None, + {"03": (3.0, 3, "03"), "04": (4.0, 4, "04")}, + ], +} + +_write("struct>>", nested_map_struct, f"{dir}/nested_map_struct.orc") + + +_write( + infer_schema(data), + data, + f"{dir}/test.orc", +) + +data_boolean = { + "long": [True] * 32, +} + +_write("struct", data_boolean, f"{dir}/long_bool.orc") + +_write("struct", data_boolean, f"{dir}/long_bool_gzip.orc", pyorc.CompressionKind.ZLIB) + +data_dict = { + "dict": ["abcd", "efgh"] * 32, +} + +_write("struct", data_dict, f"{dir}/string_long.orc") + +data_dict = { + "dict": ["abc", "efgh"] * 32, +} + +_write("struct", data_dict, f"{dir}/string_dict.orc", dict_key_size_threshold=0.1) + +_write("struct", data_dict, f"{dir}/string_dict_gzip.orc", pyorc.CompressionKind.ZLIB) + +data_dict = { + "dict": ["abcd", "efgh"] * (10**4 // 2), +} + +_write("struct", data_dict, f"{dir}/string_long_long.orc") +_write("struct", data_dict, f"{dir}/string_long_long_gzip.orc", pyorc.CompressionKind.ZLIB) + +long_f32 = { + "dict": [random.uniform(0, 1) for _ in range(10**6)], +} + +_write("struct", long_f32, f"{dir}/f32_long_long_gzip.orc", pyorc.CompressionKind.ZLIB) diff --git a/src/array_decoder/decimal.rs b/src/array_decoder/decimal.rs new file mode 100644 index 0000000..40e766e --- /dev/null +++ b/src/array_decoder/decimal.rs @@ -0,0 +1,156 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+use std::cmp::Ordering;
+use std::sync::Arc;
+
+use arrow::array::ArrayRef;
+use arrow::buffer::NullBuffer;
+use arrow::datatypes::Decimal128Type;
+use snafu::ResultExt;
+
+use crate::encoding::decimal::UnboundedVarintStreamDecoder;
+use crate::encoding::integer::get_rle_reader;
+use crate::encoding::PrimitiveValueDecoder;
+use crate::error::ArrowSnafu;
+use crate::proto::stream::Kind;
+use crate::stripe::Stripe;
+use crate::{column::Column, error::Result};
+
+use super::{ArrayBatchDecoder, PresentDecoder, PrimitiveArrayDecoder};
+
+pub fn new_decimal_decoder(
+    column: &Column,
+    stripe: &Stripe,
+    precision: u32,
+    fixed_scale: u32,
+) -> Result<Box<dyn ArrayBatchDecoder>> {
+    let varint_iter = stripe.stream_map().get(column, Kind::Data);
+    let varint_iter = Box::new(UnboundedVarintStreamDecoder::new(varint_iter));
+
+    // Scale is specified on a per varint basis (in addition to being encoded in the type)
+    let scale_iter = stripe.stream_map().get(column, Kind::Secondary);
+    let scale_iter = get_rle_reader::<i32, _>(column, scale_iter)?;
+
+    let present = PresentDecoder::from_stripe(stripe, column);
+
+    let iter = DecimalScaleRepairDecoder {
+        varint_iter,
+        scale_iter,
+        fixed_scale,
+    };
+    let iter = Box::new(iter);
+
+    Ok(Box::new(DecimalArrayDecoder::new(
+        precision as u8,
+        fixed_scale as i8,
+        iter,
+        present,
+    )))
+}
+
+/// Wrapper around PrimitiveArrayDecoder<Decimal128Type> to allow specifying the
+/// precision and scale of the output decimal array.
+pub struct DecimalArrayDecoder {
+    precision: u8,
+    scale: i8,
+    inner: PrimitiveArrayDecoder<Decimal128Type>,
+}
+
+impl DecimalArrayDecoder {
+    pub fn new(
+        precision: u8,
+        scale: i8,
+        iter: Box<dyn PrimitiveValueDecoder<i128> + Send>,
+        present: Option<PresentDecoder>,
+    ) -> Self {
+        let inner = PrimitiveArrayDecoder::<Decimal128Type>::new(iter, present);
+        Self {
+            precision,
+            scale,
+            inner,
+        }
+    }
+}
+
+impl ArrayBatchDecoder for DecimalArrayDecoder {
+    fn next_batch(
+        &mut self,
+        batch_size: usize,
+        parent_present: Option<&NullBuffer>,
+    ) -> Result<ArrayRef> {
+        let array = self
+            .inner
+            .next_primitive_batch(batch_size, parent_present)?
+            .with_precision_and_scale(self.precision, self.scale)
+            .context(ArrowSnafu)?;
+        let array = Arc::new(array) as ArrayRef;
+        Ok(array)
+    }
+}
+
+/// This decoder fixes the scale of each decoded varint, as scale is specified
+/// on a per-varint basis and needs to align with the scale specified in the type.
+struct DecimalScaleRepairDecoder {
+    varint_iter: Box<dyn PrimitiveValueDecoder<i128> + Send>,
+    scale_iter: Box<dyn PrimitiveValueDecoder<i32> + Send>,
+    fixed_scale: u32,
+}
+
+impl PrimitiveValueDecoder<i128> for DecimalScaleRepairDecoder {
+    fn decode(&mut self, out: &mut [i128]) -> Result<()> {
+        // TODO: can probably optimize, reuse buffers?
+        let mut varint = vec![0; out.len()];
+        let mut scale = vec![0; out.len()];
+        self.varint_iter.decode(&mut varint)?;
+        self.scale_iter.decode(&mut scale)?;
+        for (index, (&varint, &scale)) in varint.iter().zip(scale.iter()).enumerate() {
+            out[index] = fix_i128_scale(varint, self.fixed_scale, scale);
+        }
+        Ok(())
+    }
+}
+
+fn fix_i128_scale(i: i128, fixed_scale: u32, varying_scale: i32) -> i128 {
+    // TODO: Verify with C++ impl in ORC repo, which does this cast
+    //       Not sure why scale stream can be signed if it gets casted to unsigned anyway
+    //       https://github.com/apache/orc/blob/0014bec1e4cdd1206f5bae4f5c2000b9300c6eb1/c%2B%2B/src/ColumnReader.cc#L1459-L1476
+    let varying_scale = varying_scale as u32;
+    match fixed_scale.cmp(&varying_scale) {
+        Ordering::Less => {
+            // fixed_scale < varying_scale
+            // Current scale of number is greater than scale of the array type
+            // So need to divide to align the scale
+            // TODO: this differs from C++ implementation, need to verify
+            let scale_factor = varying_scale - fixed_scale;
+            // TODO: replace with lookup table?
+            let scale_factor = 10_i128.pow(scale_factor);
+            i / scale_factor
+        }
+        Ordering::Equal => i,
+        Ordering::Greater => {
+            // fixed_scale > varying_scale
+            // Current scale of number is smaller than scale of the array type
+            // So need to multiply to align the scale
+            // TODO: this differs from C++ implementation, need to verify
+            let scale_factor = fixed_scale - varying_scale;
+            // TODO: replace with lookup table?
+            let scale_factor = 10_i128.pow(scale_factor);
+            i * scale_factor
+        }
+    }
+}
diff --git a/src/array_decoder/list.rs b/src/array_decoder/list.rs
new file mode 100644
index 0000000..7a5ccfb
--- /dev/null
+++ b/src/array_decoder/list.rs
@@ -0,0 +1,88 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
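+
+// ORC lists carry a LENGTH stream rather than offsets; the decoder below sums
+// the lengths to size the child batch and converts them to Arrow offsets. A
+// sketch of that conversion (illustrative values):
+//
+//     let offsets = OffsetBuffer::<i32>::from_lengths([2usize, 0, 3]);
+//     // offsets now holds [0, 2, 2, 5]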
+
+use std::sync::Arc;
+
+use arrow::array::{ArrayRef, ListArray};
+use arrow::buffer::{NullBuffer, OffsetBuffer};
+use arrow::datatypes::{Field, FieldRef};
+use snafu::ResultExt;
+
+use crate::array_decoder::derive_present_vec;
+use crate::column::Column;
+use crate::encoding::integer::get_unsigned_rle_reader;
+use crate::encoding::PrimitiveValueDecoder;
+use crate::proto::stream::Kind;
+
+use crate::error::{ArrowSnafu, Result};
+use crate::stripe::Stripe;
+
+use super::{array_decoder_factory, ArrayBatchDecoder, PresentDecoder};
+
+pub struct ListArrayDecoder {
+    inner: Box<dyn ArrayBatchDecoder>,
+    present: Option<PresentDecoder>,
+    lengths: Box<dyn PrimitiveValueDecoder<i64> + Send>,
+    field: FieldRef,
+}
+
+impl ListArrayDecoder {
+    pub fn new(column: &Column, field: Arc<Field>, stripe: &Stripe) -> Result<Self> {
+        let present = PresentDecoder::from_stripe(stripe, column);
+
+        let child = &column.children()[0];
+        let inner = array_decoder_factory(child, field.clone(), stripe)?;
+
+        let reader = stripe.stream_map().get(column, Kind::Length);
+        let lengths = get_unsigned_rle_reader(column, reader);
+
+        Ok(Self {
+            inner,
+            present,
+            lengths,
+            field,
+        })
+    }
+}
+
+impl ArrayBatchDecoder for ListArrayDecoder {
+    fn next_batch(
+        &mut self,
+        batch_size: usize,
+        parent_present: Option<&NullBuffer>,
+    ) -> Result<ArrayRef> {
+        let present =
+            derive_present_vec(&mut self.present, parent_present, batch_size).transpose()?;
+
+        let mut lengths = vec![0; batch_size];
+        if let Some(present) = &present {
+            self.lengths.decode_spaced(&mut lengths, present)?;
+        } else {
+            self.lengths.decode(&mut lengths)?;
+        }
+        let total_length: i64 = lengths.iter().sum();
+        // Fetch child array as one Array with total_length elements
+        let child_array = self.inner.next_batch(total_length as usize, None)?;
+        let offsets = OffsetBuffer::from_lengths(lengths.into_iter().map(|l| l as usize));
+        let null_buffer = present.map(NullBuffer::from);
+
+        let array = ListArray::try_new(self.field.clone(), offsets, child_array, null_buffer)
+            .context(ArrowSnafu)?;
+        let array = Arc::new(array);
+        Ok(array)
+    }
+}
diff --git a/src/array_decoder/map.rs b/src/array_decoder/map.rs
new file mode 100644
index 0000000..7c01988
--- /dev/null
+++ b/src/array_decoder/map.rs
@@ -0,0 +1,105 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
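+
+// Arrow has no standalone map child layout: a MapArray wraps a non-nullable
+// "entries" StructArray of (key, value) children plus offsets, so the decoder
+// below reuses the list-style LENGTH handling and builds that struct
+// explicitly. E.g. (illustrative) the ORC rows {"a": 1} and {"b": 2, "c": 3}
+// decode to keys ["a", "b", "c"], values [1, 2, 3], offsets [0, 1, 3].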
+
+use std::sync::Arc;
+
+use arrow::array::{ArrayRef, MapArray, StructArray};
+use arrow::buffer::{NullBuffer, OffsetBuffer};
+use arrow::datatypes::{Field, Fields};
+use snafu::ResultExt;
+
+use crate::array_decoder::derive_present_vec;
+use crate::column::Column;
+use crate::encoding::integer::get_unsigned_rle_reader;
+use crate::encoding::PrimitiveValueDecoder;
+use crate::error::{ArrowSnafu, Result};
+use crate::proto::stream::Kind;
+use crate::stripe::Stripe;
+
+use super::{array_decoder_factory, ArrayBatchDecoder, PresentDecoder};
+
+pub struct MapArrayDecoder {
+    keys: Box<dyn ArrayBatchDecoder>,
+    values: Box<dyn ArrayBatchDecoder>,
+    present: Option<PresentDecoder>,
+    lengths: Box<dyn PrimitiveValueDecoder<i64> + Send>,
+    fields: Fields,
+}
+
+impl MapArrayDecoder {
+    pub fn new(
+        column: &Column,
+        keys_field: Arc<Field>,
+        values_field: Arc<Field>,
+        stripe: &Stripe,
+    ) -> Result<Self> {
+        let present = PresentDecoder::from_stripe(stripe, column);
+
+        let keys_column = &column.children()[0];
+        let keys = array_decoder_factory(keys_column, keys_field.clone(), stripe)?;
+
+        let values_column = &column.children()[1];
+        let values = array_decoder_factory(values_column, values_field.clone(), stripe)?;
+
+        let reader = stripe.stream_map().get(column, Kind::Length);
+        let lengths = get_unsigned_rle_reader(column, reader);
+
+        let fields = Fields::from(vec![keys_field, values_field]);
+
+        Ok(Self {
+            keys,
+            values,
+            present,
+            lengths,
+            fields,
+        })
+    }
+}
+
+impl ArrayBatchDecoder for MapArrayDecoder {
+    fn next_batch(
+        &mut self,
+        batch_size: usize,
+        parent_present: Option<&NullBuffer>,
+    ) -> Result<ArrayRef> {
+        let present =
+            derive_present_vec(&mut self.present, parent_present, batch_size).transpose()?;
+
+        let mut lengths = vec![0; batch_size];
+        if let Some(present) = &present {
+            self.lengths.decode_spaced(&mut lengths, present)?;
+        } else {
+            self.lengths.decode(&mut lengths)?;
+        }
+        let total_length: i64 = lengths.iter().sum();
+        // Fetch key and value arrays, each as one array with total_length elements
+        let keys_array = self.keys.next_batch(total_length as usize, None)?;
+        let values_array = self.values.next_batch(total_length as usize, None)?;
+        // Compose the keys + values array into a StructArray with two entries
+        let entries =
+            StructArray::try_new(self.fields.clone(), vec![keys_array, values_array], None)
+                .context(ArrowSnafu)?;
+        let offsets = OffsetBuffer::from_lengths(lengths.into_iter().map(|l| l as usize));
+
+        let field = Arc::new(Field::new_struct("entries", self.fields.clone(), false));
+        let array =
+            MapArray::try_new(field, offsets, entries, present, false).context(ArrowSnafu)?;
+        let array = Arc::new(array);
+        Ok(array)
+    }
+}
diff --git a/src/array_decoder/mod.rs b/src/array_decoder/mod.rs
new file mode 100644
index 0000000..df4eeda
--- /dev/null
+++ b/src/array_decoder/mod.rs
@@ -0,0 +1,451 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use arrow::array::{ArrayRef, BooleanArray, BooleanBufferBuilder, PrimitiveArray};
+use arrow::buffer::NullBuffer;
+use arrow::datatypes::ArrowNativeTypeOp;
+use arrow::datatypes::ArrowPrimitiveType;
+use arrow::datatypes::{DataType as ArrowDataType, Field};
+use arrow::datatypes::{
+    Date32Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, SchemaRef,
+};
+use arrow::record_batch::{RecordBatch, RecordBatchOptions};
+use snafu::{ensure, ResultExt};
+
+use crate::column::Column;
+use crate::encoding::boolean::BooleanDecoder;
+use crate::encoding::byte::ByteRleDecoder;
+use crate::encoding::float::FloatDecoder;
+use crate::encoding::integer::get_rle_reader;
+use crate::encoding::PrimitiveValueDecoder;
+use crate::error::{
+    self, MismatchedSchemaSnafu, Result, UnexpectedSnafu, UnsupportedTypeVariantSnafu,
+};
+use crate::proto::stream::Kind;
+use crate::schema::DataType;
+use crate::stripe::Stripe;
+
+use self::decimal::new_decimal_decoder;
+use self::list::ListArrayDecoder;
+use self::map::MapArrayDecoder;
+use self::string::{new_binary_decoder, new_string_decoder};
+use self::struct_decoder::StructArrayDecoder;
+use self::timestamp::{new_timestamp_decoder, new_timestamp_instant_decoder};
+use self::union::UnionArrayDecoder;
+
+mod decimal;
+mod list;
+mod map;
+mod string;
+mod struct_decoder;
+mod timestamp;
+mod union;
+
+pub trait ArrayBatchDecoder: Send {
+    /// Used as base for decoding ORC columns into Arrow arrays. Provide an input `batch_size`
+    /// which specifies the upper limit of the number of values returned in the output array.
+    ///
+    /// If a parent nested type (e.g. Struct) indicates a null in its PRESENT stream,
+    /// then the child doesn't have a value (similar to other nullability). So we need
+    /// to take care to insert these null values, as Arrow requires the child to hold
+    /// data in the null slot of the child.
+    // TODO: encode nullability in generic -> for a given column in a stripe, we will always know
+    //       upfront if we need to bother with nulls or not, so we don't need to keep checking this
+    //       for every invocation of next_batch
+    //       NOTE: null parent may have non-null child, so would still have to account for this
+    fn next_batch(
+        &mut self,
+        batch_size: usize,
+        parent_present: Option<&NullBuffer>,
+    ) -> Result<ArrayRef>;
+}
+
+struct PrimitiveArrayDecoder<T: ArrowPrimitiveType> {
+    iter: Box<dyn PrimitiveValueDecoder<T::Native> + Send>,
+    present: Option<PresentDecoder>,
+}
+
+impl<T: ArrowPrimitiveType> PrimitiveArrayDecoder<T> {
+    pub fn new(
+        iter: Box<dyn PrimitiveValueDecoder<T::Native> + Send>,
+        present: Option<PresentDecoder>,
+    ) -> Self {
+        Self { iter, present }
+    }
+
+    fn next_primitive_batch(
+        &mut self,
+        batch_size: usize,
+        parent_present: Option<&NullBuffer>,
+    ) -> Result<PrimitiveArray<T>> {
+        let present =
+            derive_present_vec(&mut self.present, parent_present, batch_size).transpose()?;
+        let mut data = vec![T::Native::ZERO; batch_size];
+        match present {
+            Some(present) => {
+                self.iter.decode_spaced(data.as_mut_slice(), &present)?;
+                let array = PrimitiveArray::<T>::new(data.into(), Some(present));
+                Ok(array)
+            }
+            None => {
+                self.iter.decode(data.as_mut_slice())?;
+                let array = PrimitiveArray::<T>::from_iter_values(data);
+                Ok(array)
+            }
+        }
+    }
+}
+
+impl<T: ArrowPrimitiveType> ArrayBatchDecoder for PrimitiveArrayDecoder<T> {
+    fn next_batch(
+        &mut self,
+        batch_size: usize,
+        parent_present: Option<&NullBuffer>,
+    ) -> Result<ArrayRef> {
+        let array = self.next_primitive_batch(batch_size, parent_present)?;
+        let array = Arc::new(array) as ArrayRef;
+        Ok(array)
+    }
+}
+
+type Int64ArrayDecoder = PrimitiveArrayDecoder<Int64Type>;
+type Int32ArrayDecoder = PrimitiveArrayDecoder<Int32Type>;
+type Int16ArrayDecoder = PrimitiveArrayDecoder<Int16Type>;
+type Int8ArrayDecoder = PrimitiveArrayDecoder<Int8Type>;
+type Float32ArrayDecoder = PrimitiveArrayDecoder<Float32Type>;
+type Float64ArrayDecoder = PrimitiveArrayDecoder<Float64Type>;
+type DateArrayDecoder = PrimitiveArrayDecoder<Date32Type>; // TODO: does ORC encode as i64 or i32?
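+
+// The `decode`/`decode_spaced` pair is what lets one value decoder serve both
+// nullable and non-nullable columns: `decode` fills the whole slice from the
+// DATA stream, while `decode_spaced` writes only the slots the present buffer
+// marks valid. A sketch of the spaced layout (illustrative):
+//
+//     present: [true, false, true]
+//     stream:  [a, b]
+//     output:  [a, <fill>, b]   // the null slot keeps the default fill value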
+
+struct BooleanArrayDecoder {
+    iter: Box<dyn PrimitiveValueDecoder<bool> + Send>,
+    present: Option<PresentDecoder>,
+}
+
+impl BooleanArrayDecoder {
+    pub fn new(
+        iter: Box<dyn PrimitiveValueDecoder<bool> + Send>,
+        present: Option<PresentDecoder>,
+    ) -> Self {
+        Self { iter, present }
+    }
+}
+
+impl ArrayBatchDecoder for BooleanArrayDecoder {
+    fn next_batch(
+        &mut self,
+        batch_size: usize,
+        parent_present: Option<&NullBuffer>,
+    ) -> Result<ArrayRef> {
+        let present =
+            derive_present_vec(&mut self.present, parent_present, batch_size).transpose()?;
+        let mut data = vec![false; batch_size];
+        let array = match present {
+            Some(present) => {
+                self.iter.decode_spaced(data.as_mut_slice(), &present)?;
+                BooleanArray::new(data.into(), Some(present))
+            }
+            None => {
+                self.iter.decode(data.as_mut_slice())?;
+                BooleanArray::from(data)
+            }
+        };
+        Ok(Arc::new(array))
+    }
+}
+
+struct PresentDecoder {
+    // TODO: ideally directly reference BooleanDecoder, doing this way to avoid
+    //       the generic propagation that would be required (BooleanDecoder<R: Read>)
+    inner: Box<dyn PrimitiveValueDecoder<bool> + Send>,
+}
+
+impl PresentDecoder {
+    fn from_stripe(stripe: &Stripe, column: &Column) -> Option<Self> {
+        stripe
+            .stream_map()
+            .get_opt(column, Kind::Present)
+            .map(|stream| {
+                let inner = Box::new(BooleanDecoder::new(stream));
+                PresentDecoder { inner }
+            })
+    }
+
+    fn next_buffer(&mut self, size: usize) -> Result<NullBuffer> {
+        let mut data = vec![false; size];
+        self.inner.decode(&mut data)?;
+        Ok(NullBuffer::from(data))
+    }
+}
+
+fn merge_parent_present(
+    parent_present: &NullBuffer,
+    present: Result<NullBuffer>,
+) -> Result<NullBuffer> {
+    let present = present?;
+    let non_null_count = parent_present.len() - parent_present.null_count();
+    debug_assert!(present.len() == non_null_count);
+    let mut builder = BooleanBufferBuilder::new(parent_present.len());
+    builder.append_n(parent_present.len(), false);
+    for (idx, p) in parent_present.valid_indices().zip(present.iter()) {
+        builder.set_bit(idx, p);
+    }
+    Ok(builder.finish().into())
+}
+
+fn derive_present_vec(
+    present: &mut Option<PresentDecoder>,
+    parent_present: Option<&NullBuffer>,
+    batch_size: usize,
+) -> Option<Result<NullBuffer>> {
+    match (present, parent_present) {
+        (Some(present), Some(parent_present)) => {
+            let element_count = parent_present.len() - parent_present.null_count();
+            let present = present.next_buffer(element_count);
+            Some(merge_parent_present(parent_present, present))
+        }
+        (Some(present), None) => Some(present.next_buffer(batch_size)),
+        (None, Some(parent_present)) => Some(Ok(parent_present.clone())),
+        (None, None) => None,
+    }
+}
+
+pub struct NaiveStripeDecoder {
+    stripe: Stripe,
+    schema_ref: SchemaRef,
+    decoders: Vec<Box<dyn ArrayBatchDecoder>>,
+    index: usize,
+    batch_size: usize,
+    number_of_rows: usize,
+}
+
+impl Iterator for NaiveStripeDecoder {
+    type Item = Result<RecordBatch>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.index < self.number_of_rows {
+            let record = self
+                .decode_next_batch(self.number_of_rows - self.index)
+                .transpose()?;
+            self.index += self.batch_size;
+            Some(record)
+        } else {
+            None
+        }
+    }
+}
+
+pub fn array_decoder_factory(
+    column: &Column,
+    field: Arc<Field>,
+    stripe: &Stripe,
+) -> Result<Box<dyn ArrayBatchDecoder>> {
+    let decoder: Box<dyn ArrayBatchDecoder> = match (column.data_type(), field.data_type()) {
+        // TODO: try make branches more generic, reduce duplication
+        (DataType::Boolean { .. }, ArrowDataType::Boolean) => {
+            let iter = stripe.stream_map().get(column, Kind::Data);
+            let iter = Box::new(BooleanDecoder::new(iter));
+            let present = PresentDecoder::from_stripe(stripe, column);
+            Box::new(BooleanArrayDecoder::new(iter, present))
+        }
+        (DataType::Byte { .. }, ArrowDataType::Int8) => {
+            let iter = stripe.stream_map().get(column, Kind::Data);
+            let iter = Box::new(ByteRleDecoder::new(iter));
+            let present = PresentDecoder::from_stripe(stripe, column);
+            Box::new(Int8ArrayDecoder::new(iter, present))
+        }
+        (DataType::Short { .. }, ArrowDataType::Int16) => {
+            let iter = stripe.stream_map().get(column, Kind::Data);
+            let iter = get_rle_reader(column, iter)?;
+            let present = PresentDecoder::from_stripe(stripe, column);
+            Box::new(Int16ArrayDecoder::new(iter, present))
+        }
+        (DataType::Int { .. }, ArrowDataType::Int32) => {
+            let iter = stripe.stream_map().get(column, Kind::Data);
+            let iter = get_rle_reader(column, iter)?;
+            let present = PresentDecoder::from_stripe(stripe, column);
+            Box::new(Int32ArrayDecoder::new(iter, present))
+        }
+        (DataType::Long { .. }, ArrowDataType::Int64) => {
+            let iter = stripe.stream_map().get(column, Kind::Data);
+            let iter = get_rle_reader(column, iter)?;
+            let present = PresentDecoder::from_stripe(stripe, column);
+            Box::new(Int64ArrayDecoder::new(iter, present))
+        }
+        (DataType::Float { .. }, ArrowDataType::Float32) => {
+            let iter = stripe.stream_map().get(column, Kind::Data);
+            let iter = Box::new(FloatDecoder::new(iter));
+            let present = PresentDecoder::from_stripe(stripe, column);
+            Box::new(Float32ArrayDecoder::new(iter, present))
+        }
+        (DataType::Double { .. }, ArrowDataType::Float64) => {
+            let iter = stripe.stream_map().get(column, Kind::Data);
+            let iter = Box::new(FloatDecoder::new(iter));
+            let present = PresentDecoder::from_stripe(stripe, column);
+            Box::new(Float64ArrayDecoder::new(iter, present))
+        }
+        (DataType::String { .. }, ArrowDataType::Utf8)
+        | (DataType::Varchar { .. }, ArrowDataType::Utf8)
+        | (DataType::Char { .. }, ArrowDataType::Utf8) => new_string_decoder(column, stripe)?,
+        (DataType::Binary { .. }, ArrowDataType::Binary) => new_binary_decoder(column, stripe)?,
+        (
+            DataType::Decimal {
+                precision, scale, ..
+            },
+            ArrowDataType::Decimal128(a_precision, a_scale),
+        ) if *precision as u8 == *a_precision && *scale as i8 == *a_scale => {
+            new_decimal_decoder(column, stripe, *precision, *scale)?
+        }
+        (DataType::Timestamp { .. }, field_type) => {
+            new_timestamp_decoder(column, field_type.clone(), stripe)?
+        }
+        (DataType::TimestampWithLocalTimezone { .. }, field_type) => {
+            new_timestamp_instant_decoder(column, field_type.clone(), stripe)?
+        }
+        (DataType::Date { .. }, ArrowDataType::Date32) => {
+            // TODO: allow Date64
+            let iter = stripe.stream_map().get(column, Kind::Data);
+            let iter = get_rle_reader(column, iter)?;
+            let present = PresentDecoder::from_stripe(stripe, column);
+            Box::new(DateArrayDecoder::new(iter, present))
+        }
+        (DataType::Struct { .. }, ArrowDataType::Struct(fields)) => {
+            Box::new(StructArrayDecoder::new(column, fields.clone(), stripe)?)
+        }
+        (DataType::List { .. }, ArrowDataType::List(field)) => {
+            // TODO: add support for ArrowDataType::LargeList
+            Box::new(ListArrayDecoder::new(column, field.clone(), stripe)?)
+        }
+        (DataType::Map { .. }, ArrowDataType::Map(entries, sorted)) => {
+            ensure!(!sorted, UnsupportedTypeVariantSnafu { msg: "Sorted map" });
+            let ArrowDataType::Struct(entries) = entries.data_type() else {
+                UnexpectedSnafu {
+                    msg: "arrow Map with non-Struct entry type".to_owned(),
+                }
+                .fail()?
+            };
+            ensure!(
+                entries.len() == 2,
+                UnexpectedSnafu {
+                    msg: format!(
+                        "arrow Map with {} columns per entry (expected 2)",
+                        entries.len()
+                    )
+                }
+            );
+            let keys_field = entries[0].clone();
+            let values_field = entries[1].clone();
+
+            Box::new(MapArrayDecoder::new(
+                column,
+                keys_field,
+                values_field,
+                stripe,
+            )?)
+        }
+        (DataType::Union { .. }, ArrowDataType::Union(fields, _)) => {
+            Box::new(UnionArrayDecoder::new(column, fields.clone(), stripe)?)
+        }
+        (data_type, field_type) => {
+            return MismatchedSchemaSnafu {
+                orc_type: data_type.clone(),
+                arrow_type: field_type.clone(),
+            }
+            .fail()
+        }
+    };
+
+    Ok(decoder)
+}
+
+impl NaiveStripeDecoder {
+    fn inner_decode_next_batch(&mut self, remaining: usize) -> Result<Vec<ArrayRef>> {
+        let chunk = self.batch_size.min(remaining);
+
+        let mut fields = Vec::with_capacity(self.stripe.columns().len());
+
+        for decoder in &mut self.decoders {
+            let array = decoder.next_batch(chunk, None)?;
+            if array.is_empty() {
+                break;
+            } else {
+                fields.push(array);
+            }
+        }
+
+        Ok(fields)
+    }
+
+    fn decode_next_batch(&mut self, remaining: usize) -> Result<Option<RecordBatch>> {
+        let fields = self.inner_decode_next_batch(remaining)?;
+
+        if fields.is_empty() {
+            if remaining == 0 {
+                Ok(None)
+            } else {
+                // In case of empty projection, we need to create a RecordBatch with `row_count` only
+                // to reflect the row number
+                Ok(Some(
+                    RecordBatch::try_new_with_options(
+                        Arc::clone(&self.schema_ref),
+                        fields,
+                        &RecordBatchOptions::new()
+                            .with_row_count(Some(self.batch_size.min(remaining))),
+                    )
+                    .context(error::ConvertRecordBatchSnafu)?,
+                ))
+            }
+        } else {
+            //TODO(weny): any better way?
+            let fields = self
+                .schema_ref
+                .fields
+                .into_iter()
+                .map(|field| field.name())
+                .zip(fields)
+                .collect::<Vec<_>>();
+
+            Ok(Some(
+                RecordBatch::try_from_iter(fields).context(error::ConvertRecordBatchSnafu)?,
+            ))
+        }
+    }
+
+    pub fn new(stripe: Stripe, schema_ref: SchemaRef, batch_size: usize) -> Result<Self> {
+        let mut decoders = Vec::with_capacity(stripe.columns().len());
+        let number_of_rows = stripe.number_of_rows();
+
+        for (col, field) in stripe
+            .columns()
+            .iter()
+            .zip(schema_ref.fields.iter().cloned())
+        {
+            let decoder = array_decoder_factory(col, field, &stripe)?;
+            decoders.push(decoder);
+        }
+
+        Ok(Self {
+            stripe,
+            schema_ref,
+            decoders,
+            index: 0,
+            batch_size,
+            number_of_rows,
+        })
+    }
+}
diff --git a/src/array_decoder/string.rs b/src/array_decoder/string.rs
new file mode 100644
index 0000000..dda72f0
--- /dev/null
+++ b/src/array_decoder/string.rs
@@ -0,0 +1,195 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
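+
+// ORC strings come in two encodings: DIRECT (LENGTH + DATA streams holding the
+// raw bytes in row order) and DICTIONARY (LENGTH + DICTIONARY_DATA holding
+// each distinct value once, with DATA holding per-row indexes into it). E.g.
+// rows ["hi", "yo", "hi"] may be stored (illustrative) as the dictionary
+// ["hi", "yo"] with indexes [0, 1, 0].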
+
+use std::io::Read;
+use std::marker::PhantomData;
+use std::sync::Arc;
+
+use arrow::array::{ArrayRef, DictionaryArray, GenericByteArray, StringArray};
+use arrow::buffer::{Buffer, NullBuffer, OffsetBuffer};
+use arrow::compute::kernels::cast;
+use arrow::datatypes::{ByteArrayType, DataType, GenericBinaryType, GenericStringType};
+use snafu::ResultExt;
+
+use crate::array_decoder::derive_present_vec;
+use crate::column::Column;
+use crate::compression::Decompressor;
+use crate::encoding::integer::get_unsigned_rle_reader;
+use crate::encoding::PrimitiveValueDecoder;
+use crate::error::{ArrowSnafu, IoSnafu, Result};
+use crate::proto::column_encoding::Kind as ColumnEncodingKind;
+use crate::proto::stream::Kind;
+use crate::stripe::Stripe;
+
+use super::{ArrayBatchDecoder, Int64ArrayDecoder, PresentDecoder};
+
+// TODO: reduce duplication with string below
+pub fn new_binary_decoder(column: &Column, stripe: &Stripe) -> Result<Box<dyn ArrayBatchDecoder>> {
+    let present = PresentDecoder::from_stripe(stripe, column);
+
+    let lengths = stripe.stream_map().get(column, Kind::Length);
+    let lengths = get_unsigned_rle_reader(column, lengths);
+
+    let bytes = Box::new(stripe.stream_map().get(column, Kind::Data));
+    Ok(Box::new(BinaryArrayDecoder::new(bytes, lengths, present)))
+}
+
+pub fn new_string_decoder(column: &Column, stripe: &Stripe) -> Result<Box<dyn ArrayBatchDecoder>> {
+    let kind = column.encoding().kind();
+    let present = PresentDecoder::from_stripe(stripe, column);
+
+    let lengths = stripe.stream_map().get(column, Kind::Length);
+    let lengths = get_unsigned_rle_reader(column, lengths);
+
+    match kind {
+        ColumnEncodingKind::Direct | ColumnEncodingKind::DirectV2 => {
+            let bytes = Box::new(stripe.stream_map().get(column, Kind::Data));
+            Ok(Box::new(DirectStringArrayDecoder::new(
+                bytes, lengths, present,
+            )))
+        }
+        ColumnEncodingKind::Dictionary | ColumnEncodingKind::DictionaryV2 => {
+            let bytes = Box::new(stripe.stream_map().get(column, Kind::DictionaryData));
+            // TODO: is this always guaranteed to be set for all dictionaries?
+            let dictionary_size = column.dictionary_size();
+            // We assume here we have fetched all the dictionary strings (according to size above)
+            let dictionary_strings = DirectStringArrayDecoder::new(bytes, lengths, None)
+                .next_byte_batch(dictionary_size, None)?;
+            let dictionary_strings = Arc::new(dictionary_strings);
+
+            let indexes = stripe.stream_map().get(column, Kind::Data);
+            let indexes = get_unsigned_rle_reader(column, indexes);
+            let indexes = Int64ArrayDecoder::new(indexes, present);
+
+            Ok(Box::new(DictionaryStringArrayDecoder::new(
+                indexes,
+                dictionary_strings,
+            )?))
+        }
+    }
+}
+
+// TODO: check this offset size type
+pub type DirectStringArrayDecoder = GenericByteArrayDecoder<GenericStringType<i32>>;
+pub type BinaryArrayDecoder = GenericByteArrayDecoder<GenericBinaryType<i32>>;
+
+pub struct GenericByteArrayDecoder<T: ByteArrayType> {
+    bytes: Box<Decompressor>,
+    lengths: Box<dyn PrimitiveValueDecoder<i64> + Send>,
+    present: Option<PresentDecoder>,
+    phantom: PhantomData<T>,
+}
+
+impl<T: ByteArrayType> GenericByteArrayDecoder<T> {
+    fn new(
+        bytes: Box<Decompressor>,
+        lengths: Box<dyn PrimitiveValueDecoder<i64> + Send>,
+        present: Option<PresentDecoder>,
+    ) -> Self {
+        Self {
+            bytes,
+            lengths,
+            present,
+            phantom: Default::default(),
+        }
+    }
+
+    fn next_byte_batch(
+        &mut self,
+        batch_size: usize,
+        parent_present: Option<&NullBuffer>,
+    ) -> Result<GenericByteArray<T>> {
+        let present =
+            derive_present_vec(&mut self.present, parent_present, batch_size).transpose()?;
+
+        let mut lengths = vec![0; batch_size];
+        if let Some(present) = &present {
+            self.lengths.decode_spaced(&mut lengths, present)?;
+        } else {
+            self.lengths.decode(&mut lengths)?;
+        }
+        let total_length: i64 = lengths.iter().sum();
+        // Fetch all data bytes at once
+        let mut bytes = Vec::with_capacity(total_length as usize);
+        self.bytes
+            .by_ref()
+            .take(total_length as u64)
+            .read_to_end(&mut bytes)
+            .context(IoSnafu)?;
+        let bytes = Buffer::from(bytes);
+        let offsets =
+            OffsetBuffer::<T::Offset>::from_lengths(lengths.into_iter().map(|l| l as usize));
+
+        let null_buffer = match present {
+            // Edge case where keys of map cannot have a null buffer
+            Some(present) if present.null_count() == 0 => None,
+            _ => present,
+        };
+        let array =
+            GenericByteArray::<T>::try_new(offsets, bytes, null_buffer).context(ArrowSnafu)?;
+        Ok(array)
+    }
+}
+
+impl<T: ByteArrayType> ArrayBatchDecoder for GenericByteArrayDecoder<T> {
+    fn next_batch(
+        &mut self,
+        batch_size: usize,
+        parent_present: Option<&NullBuffer>,
+    ) -> Result<ArrayRef> {
+        let array = self.next_byte_batch(batch_size, parent_present)?;
+        let array = Arc::new(array) as ArrayRef;
+        Ok(array)
+    }
+}
+
+pub struct DictionaryStringArrayDecoder {
+    indexes: Int64ArrayDecoder,
+    dictionary: Arc<StringArray>,
+}
+
+impl DictionaryStringArrayDecoder {
+    fn new(indexes: Int64ArrayDecoder, dictionary: Arc<StringArray>) -> Result<Self> {
+        Ok(Self {
+            indexes,
+            dictionary,
+        })
+    }
+}
+
+impl ArrayBatchDecoder for DictionaryStringArrayDecoder {
+    fn next_batch(
+        &mut self,
+        batch_size: usize,
+        parent_present: Option<&NullBuffer>,
+    ) -> Result<ArrayRef> {
+        let keys = self
+            .indexes
+            .next_primitive_batch(batch_size, parent_present)?;
+        // TODO: ORC spec states: For dictionary encodings the dictionary is sorted
+        //       (in lexicographical order of bytes in the UTF-8 encodings).
+        //       So we can set the is_ordered property here?
+        let array = DictionaryArray::try_new(keys, self.dictionary.clone()).context(ArrowSnafu)?;
+        // Cast back to StringArray to ensure all stripes have consistent datatype
+        // TODO: Is there any way to preserve the dictionary encoding?
+        //       This costs performance.
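+        //       E.g. keeping DataType::Dictionary(Int64, Utf8) would need the
+        //       reader schema and every stripe to agree on it, even though the
+        //       dictionary itself is rebuilt per stripe.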
+        let array = cast(&array, &DataType::Utf8).context(ArrowSnafu)?;
+
+        let array = Arc::new(array);
+        Ok(array)
+    }
+}
diff --git a/src/array_decoder/struct_decoder.rs b/src/array_decoder/struct_decoder.rs
new file mode 100644
index 0000000..b09bb3b
--- /dev/null
+++ b/src/array_decoder/struct_decoder.rs
@@ -0,0 +1,79 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use arrow::{
+    array::{ArrayRef, StructArray},
+    buffer::NullBuffer,
+    datatypes::Fields,
+};
+use snafu::ResultExt;
+
+use crate::error::Result;
+use crate::stripe::Stripe;
+use crate::{column::Column, error::ArrowSnafu};
+
+use super::{array_decoder_factory, derive_present_vec, ArrayBatchDecoder, PresentDecoder};
+
+pub struct StructArrayDecoder {
+    fields: Fields,
+    decoders: Vec<Box<dyn ArrayBatchDecoder>>,
+    present: Option<PresentDecoder>,
+}
+
+impl StructArrayDecoder {
+    pub fn new(column: &Column, fields: Fields, stripe: &Stripe) -> Result<Self> {
+        let present = PresentDecoder::from_stripe(stripe, column);
+
+        let decoders = column
+            .children()
+            .iter()
+            .zip(fields.iter().cloned())
+            .map(|(child, field)| array_decoder_factory(child, field, stripe))
+            .collect::<Result<Vec<_>>>()?;
+
+        Ok(Self {
+            decoders,
+            present,
+            fields,
+        })
+    }
+}
+
+impl ArrayBatchDecoder for StructArrayDecoder {
+    fn next_batch(
+        &mut self,
+        batch_size: usize,
+        parent_present: Option<&NullBuffer>,
+    ) -> Result<ArrayRef> {
+        let present =
+            derive_present_vec(&mut self.present, parent_present, batch_size).transpose()?;
+
+        let child_arrays = self
+            .decoders
+            .iter_mut()
+            .map(|child| child.next_batch(batch_size, present.as_ref()))
+            .collect::<Result<Vec<_>>>()?;
+
+        let null_buffer = present.map(NullBuffer::from);
+        let array = StructArray::try_new(self.fields.clone(), child_arrays, null_buffer)
+            .context(ArrowSnafu)?;
+        let array = Arc::new(array);
+        Ok(array)
+    }
+}
diff --git a/src/array_decoder/timestamp.rs b/src/array_decoder/timestamp.rs
new file mode 100644
index 0000000..b6c45b5
--- /dev/null
+++ b/src/array_decoder/timestamp.rs
@@ -0,0 +1,314 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use crate::{
+    array_decoder::ArrowDataType,
+    column::Column,
+    encoding::{
+        integer::{get_rle_reader, get_unsigned_rle_reader},
+        timestamp::{TimestampDecoder, TimestampNanosecondAsDecimalDecoder},
+        PrimitiveValueDecoder,
+    },
+    error::{MismatchedSchemaSnafu, Result},
+    proto::stream::Kind,
+    stripe::Stripe,
+};
+use arrow::datatypes::{
+    ArrowTimestampType, Decimal128Type, DecimalType, TimeUnit, TimestampMicrosecondType,
+    TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
+};
+use arrow::{array::ArrayRef, buffer::NullBuffer};
+use chrono::offset::TimeZone;
+use chrono::TimeDelta;
+use chrono_tz::{Tz, UTC};
+
+use super::{
+    decimal::DecimalArrayDecoder, ArrayBatchDecoder, PresentDecoder, PrimitiveArrayDecoder,
+};
+use crate::error::UnsupportedTypeVariantSnafu;
+
+const NANOSECONDS_IN_SECOND: i128 = 1_000_000_000;
+const NANOSECOND_DIGITS: i8 = 9;
+
+/// Seconds from ORC epoch of 1 January 2015, which serves as the 0
+/// point for all timestamp values, to the UNIX epoch of 1 January 1970.
+const ORC_EPOCH_UTC_SECONDS_SINCE_UNIX_EPOCH: i64 = 1_420_070_400;
+
+fn get_inner_timestamp_decoder<T: ArrowTimestampType>(
+    column: &Column,
+    stripe: &Stripe,
+    seconds_since_unix_epoch: i64,
+) -> Result<PrimitiveArrayDecoder<T>> {
+    let data = stripe.stream_map().get(column, Kind::Data);
+    let data = get_rle_reader(column, data)?;
+
+    let secondary = stripe.stream_map().get(column, Kind::Secondary);
+    let secondary = get_unsigned_rle_reader(column, secondary);
+
+    let present = PresentDecoder::from_stripe(stripe, column);
+
+    let iter = Box::new(TimestampDecoder::<T>::new(
+        seconds_since_unix_epoch,
+        data,
+        secondary,
+    ));
+    Ok(PrimitiveArrayDecoder::<T>::new(iter, present))
+}
+
+fn get_timestamp_decoder<T: ArrowTimestampType>(
+    column: &Column,
+    stripe: &Stripe,
+    seconds_since_unix_epoch: i64,
+) -> Result<Box<dyn ArrayBatchDecoder>> {
+    let inner = get_inner_timestamp_decoder::<T>(column, stripe, seconds_since_unix_epoch)?;
+    match stripe.writer_tz() {
+        Some(writer_tz) => Ok(Box::new(TimestampOffsetArrayDecoder { inner, writer_tz })),
+        None => Ok(Box::new(inner)),
+    }
+}
+
+fn get_timestamp_instant_decoder<T: ArrowTimestampType>(
+    column: &Column,
+    stripe: &Stripe,
+) -> Result<Box<dyn ArrayBatchDecoder>> {
+    // TIMESTAMP_INSTANT is encoded as UTC so we don't check writer timezone in stripe
+    let inner =
+        get_inner_timestamp_decoder::<T>(column, stripe, ORC_EPOCH_UTC_SECONDS_SINCE_UNIX_EPOCH)?;
+    Ok(Box::new(TimestampInstantArrayDecoder(inner)))
+}
+
+fn decimal128_decoder(
+    column: &Column,
+    stripe: &Stripe,
+    seconds_since_unix_epoch: i64,
+    writer_tz: Option<Tz>,
+) -> Result<DecimalArrayDecoder> {
+    let data = stripe.stream_map().get(column, Kind::Data);
+    let data = get_rle_reader(column, data)?;
+
+    let secondary = stripe.stream_map().get(column, Kind::Secondary);
+    let secondary = get_rle_reader(column, secondary)?;
+
+    let present = PresentDecoder::from_stripe(stripe, column);
+
+    let iter = TimestampNanosecondAsDecimalDecoder::new(seconds_since_unix_epoch, data, secondary);
+
+    let iter: Box<dyn PrimitiveValueDecoder<i128> + Send> = match writer_tz {
+        Some(UTC) | None => Box::new(iter),
+        Some(writer_tz) => Box::new(TimestampNanosecondAsDecimalWithTzDecoder(iter, writer_tz)),
+    };
+
+    Ok(DecimalArrayDecoder::new(
+        Decimal128Type::MAX_PRECISION,
+        NANOSECOND_DIGITS,
+        iter,
+        present,
+    ))
+}
+
+/// Decodes a TIMESTAMP column stripe into batches of Timestamp{Nano,Micro,Milli,}secondArrays
+/// with no timezone. Will convert timestamps from writer timezone to UTC if a writer timezone
+/// is specified for the stripe.
+pub fn new_timestamp_decoder(
+    column: &Column,
+    field_type: ArrowDataType,
+    stripe: &Stripe,
+) -> Result<Box<dyn ArrayBatchDecoder>> {
+    let seconds_since_unix_epoch = match stripe.writer_tz() {
+        Some(writer_tz) => {
+            // If writer timezone exists then we must take the ORC epoch according
+            // to that timezone, and find seconds since UTC UNIX epoch for the base.
+            writer_tz
+                .with_ymd_and_hms(2015, 1, 1, 0, 0, 0)
+                .unwrap()
+                .timestamp()
+        }
+        None => {
+            // No writer timezone, we can assume UTC, so we can use known fixed value
+            // for the base offset.
+            ORC_EPOCH_UTC_SECONDS_SINCE_UNIX_EPOCH
+        }
+    };
+
+    match field_type {
+        ArrowDataType::Timestamp(TimeUnit::Second, None) => {
+            get_timestamp_decoder::<TimestampSecondType>(column, stripe, seconds_since_unix_epoch)
+        }
+        ArrowDataType::Timestamp(TimeUnit::Millisecond, None) => {
+            get_timestamp_decoder::<TimestampMillisecondType>(
+                column,
+                stripe,
+                seconds_since_unix_epoch,
+            )
+        }
+        ArrowDataType::Timestamp(TimeUnit::Microsecond, None) => {
+            get_timestamp_decoder::<TimestampMicrosecondType>(
+                column,
+                stripe,
+                seconds_since_unix_epoch,
+            )
+        }
+        ArrowDataType::Timestamp(TimeUnit::Nanosecond, None) => {
+            get_timestamp_decoder::<TimestampNanosecondType>(
+                column,
+                stripe,
+                seconds_since_unix_epoch,
+            )
+        }
+        ArrowDataType::Decimal128(Decimal128Type::MAX_PRECISION, NANOSECOND_DIGITS) => {
+            Ok(Box::new(decimal128_decoder(
+                column,
+                stripe,
+                seconds_since_unix_epoch,
+                stripe.writer_tz(),
+            )?))
+        }
+        _ => MismatchedSchemaSnafu {
+            orc_type: column.data_type().clone(),
+            arrow_type: field_type,
+        }
+        .fail(),
+    }
+}
+
+/// Decodes a TIMESTAMP_INSTANT column stripe into batches of
+/// Timestamp{Nano,Micro,Milli,}secondArrays with UTC timezone.
+pub fn new_timestamp_instant_decoder(
+    column: &Column,
+    field_type: ArrowDataType,
+    stripe: &Stripe,
+) -> Result<Box<dyn ArrayBatchDecoder>> {
+    match field_type {
+        ArrowDataType::Timestamp(TimeUnit::Second, Some(tz)) if tz.as_ref() == "UTC" => {
+            get_timestamp_instant_decoder::<TimestampSecondType>(column, stripe)
+        }
+        ArrowDataType::Timestamp(TimeUnit::Millisecond, Some(tz)) if tz.as_ref() == "UTC" => {
+            get_timestamp_instant_decoder::<TimestampMillisecondType>(column, stripe)
+        }
+        ArrowDataType::Timestamp(TimeUnit::Microsecond, Some(tz)) if tz.as_ref() == "UTC" => {
+            get_timestamp_instant_decoder::<TimestampMicrosecondType>(column, stripe)
+        }
+        ArrowDataType::Timestamp(TimeUnit::Nanosecond, Some(tz)) if tz.as_ref() == "UTC" => {
+            get_timestamp_instant_decoder::<TimestampNanosecondType>(column, stripe)
+        }
+        ArrowDataType::Timestamp(_, Some(_)) => UnsupportedTypeVariantSnafu {
+            msg: "Non-UTC Arrow timestamps",
+        }
+        .fail(),
+        ArrowDataType::Decimal128(Decimal128Type::MAX_PRECISION, NANOSECOND_DIGITS) => {
+            Ok(Box::new(decimal128_decoder(
+                column,
+                stripe,
+                ORC_EPOCH_UTC_SECONDS_SINCE_UNIX_EPOCH,
+                None,
+            )?))
+        }
+        _ => MismatchedSchemaSnafu {
+            orc_type: column.data_type().clone(),
+            arrow_type: field_type,
+        }
+        .fail()?,
+    }
+}
+
+/// Wrapper around PrimitiveArrayDecoder to decode timestamps which are encoded in
+/// timezone of the writer to their UTC value.
+struct TimestampOffsetArrayDecoder<T: ArrowTimestampType> {
+    inner: PrimitiveArrayDecoder<T>,
+    writer_tz: chrono_tz::Tz,
+}
+
+impl<T: ArrowTimestampType> ArrayBatchDecoder for TimestampOffsetArrayDecoder<T> {
+    fn next_batch(
+        &mut self,
+        batch_size: usize,
+        parent_present: Option<&NullBuffer>,
+    ) -> Result<ArrayRef> {
+        let array = self
+            .inner
+            .next_primitive_batch(batch_size, parent_present)?;
+
+        let convert_timezone = |ts| {
+            // Convert from writer timezone to reader timezone (which we default to UTC)
+            // TODO: more efficient way of doing this?
+            self.writer_tz
+                .timestamp_nanos(ts)
+                .naive_local()
+                .and_utc()
+                .timestamp_nanos_opt()
+        };
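+        let array = array
+            // first try to convert all non-nullable batches to non-nullable batches
+            .try_unary::<_, T, _>(|ts| convert_timezone(ts).ok_or(()))
+            // in the rare case one of the values is out of range for the time unit
+            // (e.g. for nanoseconds), try again by allowing a nullable batch as output
+            .unwrap_or_else(|()| array.unary_opt::<_, T>(convert_timezone));
+        let array = Arc::new(array) as ArrayRef;
+        Ok(array)
+    }
+}
+
+/// Wrapper around PrimitiveArrayDecoder to allow specifying the timezone of the output
+/// timestamp array as UTC.
+struct TimestampInstantArrayDecoder<T: ArrowTimestampType>(PrimitiveArrayDecoder<T>);
+
+impl<T: ArrowTimestampType> ArrayBatchDecoder for TimestampInstantArrayDecoder<T> {
+    fn next_batch(
+        &mut self,
+        batch_size: usize,
+        parent_present: Option<&NullBuffer>,
+    ) -> Result<ArrayRef> {
+        let array = self
+            .0
+            .next_primitive_batch(batch_size, parent_present)?
+            .with_timezone("UTC");
+        let array = Arc::new(array) as ArrayRef;
+        Ok(array)
+    }
+}
+
+struct TimestampNanosecondAsDecimalWithTzDecoder(TimestampNanosecondAsDecimalDecoder, Tz);
+
+impl TimestampNanosecondAsDecimalWithTzDecoder {
+    fn next_inner(&self, ts: i128) -> i128 {
+        let seconds = ts.div_euclid(NANOSECONDS_IN_SECOND);
+        let nanoseconds = ts.rem_euclid(NANOSECONDS_IN_SECOND);
+
+        // The addition may panic, because chrono stores dates in an i32,
+        // which can be overflowed with an i64 of seconds.
+        let dt = (self.1.timestamp_nanos(0)
+            + TimeDelta::new(seconds as i64, nanoseconds as u32)
+                .expect("TimeDelta duration out of bound"))
+        .naive_local()
+        .and_utc();
+
+        (dt.timestamp() as i128) * NANOSECONDS_IN_SECOND + (dt.timestamp_subsec_nanos() as i128)
+    }
+}
+
+impl PrimitiveValueDecoder<i128> for TimestampNanosecondAsDecimalWithTzDecoder {
+    fn decode(&mut self, out: &mut [i128]) -> Result<()> {
+        self.0.decode(out)?;
+        for x in out.iter_mut() {
+            *x = self.next_inner(*x);
+        }
+        Ok(())
+    }
+}
diff --git a/src/array_decoder/union.rs b/src/array_decoder/union.rs
new file mode 100644
index 0000000..39af4ea
--- /dev/null
+++ b/src/array_decoder/union.rs
@@ -0,0 +1,137 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Sparse unions materialise every child at full length; the tag buffer picks
+// which child supplies each slot. E.g. tags [0, 1, 0] over children A and B
+// (illustrative): slots 0 and 2 read from A, slot 1 from B, and the unused
+// positions in each child are padded with nulls by the decoder below.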
+
+use std::sync::Arc;
+
+use arrow::array::{ArrayRef, BooleanBufferBuilder, UnionArray};
+use arrow::buffer::{Buffer, NullBuffer};
+use arrow::datatypes::UnionFields;
+use snafu::ResultExt;
+
+use crate::column::Column;
+use crate::encoding::byte::ByteRleDecoder;
+use crate::encoding::PrimitiveValueDecoder;
+use crate::error::ArrowSnafu;
+use crate::error::Result;
+use crate::proto::stream::Kind;
+use crate::stripe::Stripe;
+
+use super::{array_decoder_factory, derive_present_vec, ArrayBatchDecoder, PresentDecoder};
+
+/// Decode ORC Union column into batches of Arrow Sparse UnionArrays.
+pub struct UnionArrayDecoder {
+    // fields and variants should have same length
+    // TODO: encode this assumption into types
+    fields: UnionFields,
+    variants: Vec<Box<dyn ArrayBatchDecoder>>,
+    tags: Box<dyn PrimitiveValueDecoder<i8> + Send>,
+    present: Option<PresentDecoder>,
+}
+
+impl UnionArrayDecoder {
+    pub fn new(column: &Column, fields: UnionFields, stripe: &Stripe) -> Result<Self> {
+        let present = PresentDecoder::from_stripe(stripe, column);
+
+        let tags = stripe.stream_map().get(column, Kind::Data);
+        let tags = Box::new(ByteRleDecoder::new(tags));
+
+        let variants = column
+            .children()
+            .iter()
+            .zip(fields.iter())
+            .map(|(child, (_id, field))| array_decoder_factory(child, field.clone(), stripe))
+            .collect::<Result<Vec<_>>>()?;
+
+        Ok(Self {
+            fields,
+            variants,
+            tags,
+            present,
+        })
+    }
+}
+
+impl ArrayBatchDecoder for UnionArrayDecoder {
+    fn next_batch(
+        &mut self,
+        batch_size: usize,
+        parent_present: Option<&NullBuffer>,
+    ) -> Result<ArrayRef> {
+        let present =
+            derive_present_vec(&mut self.present, parent_present, batch_size).transpose()?;
+        let mut tags = vec![0; batch_size];
+        match &present {
+            Some(present) => {
+                // Since UnionArrays don't have nullability, we rely on child arrays.
+                // So we default to first child (tag 0) for any nulls from this parent Union.
+                self.tags.decode_spaced(&mut tags, present)?;
+            }
+            None => {
+                self.tags.decode(&mut tags)?;
+            }
+        }
+
+        // Calculate nullability for children
+        let mut children_nullability = (0..self.variants.len())
+            .map(|index| {
+                let mut child_present = BooleanBufferBuilder::new(batch_size);
+                child_present.append_n(batch_size, false);
+                for idx in tags
+                    .iter()
+                    .enumerate()
+                    // Where the parent expects the value of the child, we set to non-null.
+                    // Otherwise for the sparse spots, we leave as null in children.
+                    .filter_map(|(idx, &tag)| (tag as usize == index).then_some(idx))
+                {
+                    child_present.set_bit(idx, true);
+                }
+                child_present
+            })
+            .collect::<Vec<_>>();
+        // If parent says a slot is null, we need to ensure the first child (0-index) also
+        // encodes this information, since as mentioned before, Arrow UnionArrays don't store
+        // nullability and rely on their children. We default to first child to encode this
+        // information so need to enforce that here.
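+        // E.g. (illustrative) stream tags [0, 1] under parent present
+        // [t, f, t]: decode_spaced yields tags [0, 0, 1] (the null slot falls
+        // back to tag 0), child 0 starts as [t, t, f] and has slot 1 cleared
+        // below, and child 1 is [f, f, t].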
+        if let Some(present) = &present {
+            let first_child = &mut children_nullability[0];
+            for idx in present
+                .iter()
+                .enumerate()
+                .filter_map(|(idx, parent_present)| (!parent_present).then_some(idx))
+            {
+                first_child.set_bit(idx, false);
+            }
+        }
+
+        let child_arrays = self
+            .variants
+            .iter_mut()
+            .zip(children_nullability)
+            .map(|(decoder, mut present)| {
+                let present = NullBuffer::from(present.finish());
+                decoder.next_batch(batch_size, Some(&present))
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        // Currently default to decoding as Sparse UnionArray so no value offsets
+        let type_ids = Buffer::from_vec(tags.clone()).into();
+        let array = UnionArray::try_new(self.fields.clone(), type_ids, None, child_arrays)
+            .context(ArrowSnafu)?;
+        let array = Arc::new(array);
+        Ok(array)
+    }
+}
diff --git a/src/arrow_reader.rs b/src/arrow_reader.rs
new file mode 100644
index 0000000..fd0eb53
--- /dev/null
+++ b/src/arrow_reader.rs
@@ -0,0 +1,227 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::collections::HashMap;
+use std::ops::Range;
+use std::sync::Arc;
+
+use arrow::datatypes::SchemaRef;
+use arrow::error::ArrowError;
+use arrow::record_batch::{RecordBatch, RecordBatchReader};
+
+use crate::array_decoder::NaiveStripeDecoder;
+use crate::error::Result;
+use crate::projection::ProjectionMask;
+use crate::reader::metadata::{read_metadata, FileMetadata};
+use crate::reader::ChunkReader;
+use crate::schema::RootDataType;
+use crate::stripe::{Stripe, StripeMetadata};
+
+const DEFAULT_BATCH_SIZE: usize = 8192;
+
+pub struct ArrowReaderBuilder<R> {
+    pub(crate) reader: R,
+    pub(crate) file_metadata: Arc<FileMetadata>,
+    pub(crate) batch_size: usize,
+    pub(crate) projection: ProjectionMask,
+    pub(crate) schema_ref: Option<SchemaRef>,
+    pub(crate) file_byte_range: Option<Range<usize>>,
+}
+
+impl<R> ArrowReaderBuilder<R> {
+    pub(crate) fn new(reader: R, file_metadata: Arc<FileMetadata>) -> Self {
+        Self {
+            reader,
+            file_metadata,
+            batch_size: DEFAULT_BATCH_SIZE,
+            projection: ProjectionMask::all(),
+            schema_ref: None,
+            file_byte_range: None,
+        }
+    }
+
+    pub fn file_metadata(&self) -> &FileMetadata {
+        &self.file_metadata
+    }
+
+    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
+        self.batch_size = batch_size;
+        self
+    }
+
+    pub fn with_projection(mut self, projection: ProjectionMask) -> Self {
+        self.projection = projection;
+        self
+    }
+
+    pub fn with_schema(mut self, schema: SchemaRef) -> Self {
+        self.schema_ref = Some(schema);
+        self
+    }
+
+    /// Specifies a range of file bytes; only stripes whose offsets fall within
+    /// this range will be read.
+    pub fn with_file_byte_range(mut self, range: Range<usize>) -> Self {
+        self.file_byte_range = Some(range);
+        self
+    }
+
+    /// Returns the currently computed schema
+    ///
+    /// Unless [`with_schema`](Self::with_schema) was called, this is computed dynamically
+    /// based on the current projection and the underlying file format.
+    pub fn schema(&self) -> SchemaRef {
+        let projected_data_type = self
+            .file_metadata
+            .root_data_type()
+            .project(&self.projection);
+        let metadata = self
+            .file_metadata
+            .user_custom_metadata()
+            .iter()
+            .map(|(key, value)| (key.clone(), String::from_utf8_lossy(value).to_string()))
+            .collect::<HashMap<_, _>>();
+        self.schema_ref
+            .clone()
+            .unwrap_or_else(|| Arc::new(projected_data_type.create_arrow_schema(&metadata)))
+    }
+}
+
+impl<R: ChunkReader> ArrowReaderBuilder<R> {
+    pub fn try_new(mut reader: R) -> Result<Self> {
+        let file_metadata = Arc::new(read_metadata(&mut reader)?);
+        Ok(Self::new(reader, file_metadata))
+    }
+
+    pub fn build(self) -> ArrowReader<R> {
+        let schema_ref = self.schema();
+        let projected_data_type = self
+            .file_metadata
+            .root_data_type()
+            .project(&self.projection);
+        let cursor = Cursor {
+            reader: self.reader,
+            file_metadata: self.file_metadata,
+            projected_data_type,
+            stripe_index: 0,
+            file_byte_range: self.file_byte_range,
+        };
+        ArrowReader {
+            cursor,
+            schema_ref,
+            current_stripe: None,
+            batch_size: self.batch_size,
+        }
+    }
+}
+
+pub struct ArrowReader<R> {
+    cursor: Cursor<R>,
+    schema_ref: SchemaRef,
+    current_stripe: Option<Box<dyn Iterator<Item = Result<RecordBatch>> + Send>>,
+    batch_size: usize,
+}
+
+impl<R> ArrowReader<R> {
+    pub fn total_row_count(&self) -> u64 {
+        self.cursor.file_metadata.number_of_rows()
+    }
+}
+
+impl<R: ChunkReader> ArrowReader<R> {
+    fn try_advance_stripe(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
+        let stripe = self.cursor.next().transpose()?;
+        match stripe {
+            Some(stripe) => {
+                let decoder =
+                    NaiveStripeDecoder::new(stripe, self.schema_ref.clone(), self.batch_size)?;
+                self.current_stripe = Some(Box::new(decoder));
+                self.next().transpose()
+            }
+            None => Ok(None),
+        }
+    }
+}
+
+impl<R: ChunkReader> RecordBatchReader for ArrowReader<R> {
+    fn schema(&self) -> SchemaRef {
+        self.schema_ref.clone()
+    }
+}
+
+impl<R: ChunkReader> Iterator for ArrowReader<R> {
+    type Item = Result<RecordBatch, ArrowError>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        match self.current_stripe.as_mut() {
+            Some(stripe) => {
+                match stripe
+                    .next()
+                    .map(|batch| batch.map_err(|err| ArrowError::ExternalError(Box::new(err))))
+                {
+                    Some(rb) => Some(rb),
+                    None => self.try_advance_stripe().transpose(),
+                }
+            }
+            None => self.try_advance_stripe().transpose(),
+        }
+    }
+}
+
+pub(crate) struct Cursor<R> {
+    pub reader: R,
+    pub file_metadata: Arc<FileMetadata>,
+    pub projected_data_type: RootDataType,
+    pub stripe_index: usize,
+    pub file_byte_range: Option<Range<usize>>,
+}
+
+impl<R: ChunkReader> Cursor<R> {
+    fn get_stripe_metadatas(&self) -> Vec<StripeMetadata> {
+        if let Some(range) = self.file_byte_range.clone() {
+            self.file_metadata
+                .stripe_metadatas()
+                .iter()
+                .filter(|info| {
+                    let offset = info.offset() as usize;
+                    range.contains(&offset)
+                })
+                .map(|info| info.to_owned())
+                .collect::<Vec<_>>()
+        } else {
+            self.file_metadata.stripe_metadatas().to_vec()
+        }
+    }
+}
+
+impl<R: ChunkReader> Iterator for Cursor<R> {
+    type Item = Result<Stripe>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        self.get_stripe_metadatas()
+            .get(self.stripe_index)
+            .map(|info| {
+                let stripe = Stripe::new(
+                    &mut self.reader,
+                    &self.file_metadata,
+                    &self.projected_data_type.clone(),
+                    info,
+                );
+                self.stripe_index += 1;
+                stripe
+            })
+    }
+}
diff --git a/src/arrow_writer.rs b/src/arrow_writer.rs
new file mode 100644
index 0000000..a8b1dd9
--- /dev/null
+++ b/src/arrow_writer.rs
@@ -0,0 +1,477 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::io::Write; + +use arrow::{ + array::RecordBatch, + datatypes::{DataType as ArrowDataType, SchemaRef}, +}; +use prost::Message; +use snafu::{ensure, ResultExt}; + +use crate::{ + error::{IoSnafu, Result, UnexpectedSnafu}, + memory::EstimateMemory, + proto, + writer::stripe::{StripeInformation, StripeWriter}, +}; + +/// Construct an [`ArrowWriter`] to encode [`RecordBatch`]es into a single +/// ORC file. +pub struct ArrowWriterBuilder { + writer: W, + schema: SchemaRef, + batch_size: usize, + stripe_byte_size: usize, +} + +impl ArrowWriterBuilder { + /// Create a new [`ArrowWriterBuilder`], which will write an ORC file to + /// the provided writer, with the expected Arrow schema. + pub fn new(writer: W, schema: SchemaRef) -> Self { + Self { + writer, + schema, + batch_size: 1024, + // 64 MiB + stripe_byte_size: 64 * 1024 * 1024, + } + } + + /// Batch size controls the encoding behaviour, where `batch_size` values + /// are encoded at a time. Default is `1024`. + pub fn with_batch_size(mut self, batch_size: usize) -> Self { + self.batch_size = batch_size; + self + } + + /// The approximate size of stripes. Default is `64MiB`. + pub fn with_stripe_byte_size(mut self, stripe_byte_size: usize) -> Self { + self.stripe_byte_size = stripe_byte_size; + self + } + + /// Construct an [`ArrowWriter`] ready to encode [`RecordBatch`]es into + /// an ORC file. + pub fn try_build(mut self) -> Result> { + // Required magic "ORC" bytes at start of file + self.writer.write_all(b"ORC").context(IoSnafu)?; + let writer = StripeWriter::new(self.writer, &self.schema); + Ok(ArrowWriter { + writer, + schema: self.schema, + batch_size: self.batch_size, + stripe_byte_size: self.stripe_byte_size, + written_stripes: vec![], + // Accounting for the 3 magic bytes above + total_bytes_written: 3, + }) + } +} + +/// Encodes [`RecordBatch`]es into an ORC file. Will encode `batch_size` rows +/// at a time into a stripe, flushing the stripe to the underlying writer when +/// it's estimated memory footprint exceeds the configures `stripe_byte_size`. +pub struct ArrowWriter { + writer: StripeWriter, + schema: SchemaRef, + batch_size: usize, + stripe_byte_size: usize, + written_stripes: Vec, + /// Used to keep track of progress in file so far (instead of needing Seek on the writer) + total_bytes_written: u64, +} + +impl ArrowWriter { + /// Encode the provided batch at `batch_size` rows at a time, flushing any + /// stripes that exceed the configured stripe size. 
+ pub fn write(&mut self, batch: &RecordBatch) -> Result<()> { + ensure!( + batch.schema() == self.schema, + UnexpectedSnafu { + msg: "RecordBatch doesn't match expected schema" + } + ); + + for offset in (0..batch.num_rows()).step_by(self.batch_size) { + let length = self.batch_size.min(batch.num_rows() - offset); + let batch = batch.slice(offset, length); + self.writer.encode_batch(&batch)?; + + // TODO: be able to flush whilst writing a batch (instead of between batches) + // Flush stripe when it exceeds estimated configured size + if self.writer.estimate_memory_size() > self.stripe_byte_size { + self.flush_stripe()?; + } + } + Ok(()) + } + + /// Flush any buffered data that hasn't been written, and write the stripe + /// footer metadata. + pub fn flush_stripe(&mut self) -> Result<()> { + let info = self.writer.finish_stripe(self.total_bytes_written)?; + self.total_bytes_written += info.total_byte_size(); + self.written_stripes.push(info); + Ok(()) + } + + /// Flush the current stripe if it is still in progress, and write the tail + /// metadata and close the writer. + pub fn close(mut self) -> Result<()> { + // Flush in-progress stripe + if self.writer.row_count > 0 { + self.flush_stripe()?; + } + let footer = serialize_footer(&self.written_stripes, &self.schema); + let footer = footer.encode_to_vec(); + let postscript = serialize_postscript(footer.len() as u64); + let postscript = postscript.encode_to_vec(); + let postscript_len = postscript.len() as u8; + + let mut writer = self.writer.finish(); + writer.write_all(&footer).context(IoSnafu)?; + writer.write_all(&postscript).context(IoSnafu)?; + // Postscript length as last byte + writer.write_all(&[postscript_len]).context(IoSnafu)?; + + // TODO: return file metadata + Ok(()) + } +} + +fn serialize_schema(schema: &SchemaRef) -> Vec { + let mut types = vec![]; + + let field_names = schema + .fields() + .iter() + .map(|f| f.name().to_owned()) + .collect(); + // TODO: consider nested types + let subtypes = (1..(schema.fields().len() as u32 + 1)).collect(); + let root_type = proto::Type { + kind: Some(proto::r#type::Kind::Struct.into()), + subtypes, + field_names, + maximum_length: None, + precision: None, + scale: None, + attributes: vec![], + }; + types.push(root_type); + for field in schema.fields() { + let t = match field.data_type() { + ArrowDataType::Float32 => proto::Type { + kind: Some(proto::r#type::Kind::Float.into()), + ..Default::default() + }, + ArrowDataType::Float64 => proto::Type { + kind: Some(proto::r#type::Kind::Double.into()), + ..Default::default() + }, + ArrowDataType::Int8 => proto::Type { + kind: Some(proto::r#type::Kind::Byte.into()), + ..Default::default() + }, + ArrowDataType::Int16 => proto::Type { + kind: Some(proto::r#type::Kind::Short.into()), + ..Default::default() + }, + ArrowDataType::Int32 => proto::Type { + kind: Some(proto::r#type::Kind::Int.into()), + ..Default::default() + }, + ArrowDataType::Int64 => proto::Type { + kind: Some(proto::r#type::Kind::Long.into()), + ..Default::default() + }, + ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 => proto::Type { + kind: Some(proto::r#type::Kind::String.into()), + ..Default::default() + }, + ArrowDataType::Binary | ArrowDataType::LargeBinary => proto::Type { + kind: Some(proto::r#type::Kind::Binary.into()), + ..Default::default() + }, + ArrowDataType::Boolean => proto::Type { + kind: Some(proto::r#type::Kind::Boolean.into()), + ..Default::default() + }, + // TODO: support more types + _ => unimplemented!("unsupported datatype"), + }; + types.push(t); + } + 
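+    // `types` now holds the flat list: the root struct at index 0 followed by
+    // one entry per top-level field, matching the `subtypes` ids 1..=N above.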
types +} + +fn serialize_footer(stripes: &[StripeInformation], schema: &SchemaRef) -> proto::Footer { + let body_length = stripes + .iter() + .map(|s| s.index_length + s.data_length + s.footer_length) + .sum::(); + let number_of_rows = stripes.iter().map(|s| s.row_count as u64).sum::(); + let stripes = stripes.iter().map(From::from).collect(); + let types = serialize_schema(schema); + proto::Footer { + header_length: Some(3), + content_length: Some(body_length + 3), + stripes, + types, + metadata: vec![], + number_of_rows: Some(number_of_rows), + statistics: vec![], + row_index_stride: None, + writer: Some(u32::MAX), + encryption: None, + calendar: None, + software_version: None, + } +} + +fn serialize_postscript(footer_length: u64) -> proto::PostScript { + proto::PostScript { + footer_length: Some(footer_length), + compression: Some(proto::CompressionKind::None.into()), // TODO: support compression + compression_block_size: None, + version: vec![0, 12], + metadata_length: Some(0), // TODO: statistics + writer_version: Some(u32::MAX), // TODO: check which version to use + stripe_statistics_length: None, + magic: Some("ORC".to_string()), + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::{ + array::{ + Array, BinaryArray, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, + Int64Array, Int8Array, LargeBinaryArray, LargeStringArray, RecordBatchReader, + StringArray, + }, + compute::concat_batches, + datatypes::{DataType as ArrowDataType, Field, Schema}, + }; + use bytes::Bytes; + + use crate::ArrowReaderBuilder; + + use super::*; + + fn roundtrip(batches: &[RecordBatch]) -> Vec { + let mut f = vec![]; + let mut writer = ArrowWriterBuilder::new(&mut f, batches[0].schema()) + .try_build() + .unwrap(); + for batch in batches { + writer.write(batch).unwrap(); + } + writer.close().unwrap(); + + let f = Bytes::from(f); + let reader = ArrowReaderBuilder::try_new(f).unwrap().build(); + reader.collect::, _>>().unwrap() + } + + #[test] + fn test_roundtrip_write() { + let f32_array = Arc::new(Float32Array::from(vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0])); + let f64_array = Arc::new(Float64Array::from(vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0])); + let int8_array = Arc::new(Int8Array::from(vec![0, 1, 2, 3, 4, 5, 6])); + let int16_array = Arc::new(Int16Array::from(vec![0, 1, 2, 3, 4, 5, 6])); + let int32_array = Arc::new(Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6])); + let int64_array = Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, 6])); + let utf8_array = Arc::new(StringArray::from(vec![ + "Hello", + "there", + "楡井希実", + "💯", + "ORC", + "", + "123", + ])); + let binary_array = Arc::new(BinaryArray::from(vec![ + "Hello".as_bytes(), + "there".as_bytes(), + "楡井希実".as_bytes(), + "💯".as_bytes(), + "ORC".as_bytes(), + "".as_bytes(), + "123".as_bytes(), + ])); + let boolean_array = Arc::new(BooleanArray::from(vec![ + true, false, true, false, true, true, false, + ])); + let schema = Schema::new(vec![ + Field::new("f32", ArrowDataType::Float32, false), + Field::new("f64", ArrowDataType::Float64, false), + Field::new("int8", ArrowDataType::Int8, false), + Field::new("int16", ArrowDataType::Int16, false), + Field::new("int32", ArrowDataType::Int32, false), + Field::new("int64", ArrowDataType::Int64, false), + Field::new("utf8", ArrowDataType::Utf8, false), + Field::new("binary", ArrowDataType::Binary, false), + Field::new("boolean", ArrowDataType::Boolean, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + f32_array, + f64_array, + int8_array, + 
int16_array, + int32_array, + int64_array, + utf8_array, + binary_array, + boolean_array, + ], + ) + .unwrap(); + + let rows = roundtrip(&[batch.clone()]); + assert_eq!(batch, rows[0]); + } + + #[test] + fn test_roundtrip_write_large_type() { + let large_utf8_array = Arc::new(LargeStringArray::from(vec![ + "Hello", + "there", + "楡井希実", + "💯", + "ORC", + "", + "123", + ])); + let large_binary_array = Arc::new(LargeBinaryArray::from(vec![ + "Hello".as_bytes(), + "there".as_bytes(), + "楡井希実".as_bytes(), + "💯".as_bytes(), + "ORC".as_bytes(), + "".as_bytes(), + "123".as_bytes(), + ])); + let schema = Schema::new(vec![ + Field::new("large_utf8", ArrowDataType::LargeUtf8, false), + Field::new("large_binary", ArrowDataType::LargeBinary, false), + ]); + let batch = + RecordBatch::try_new(Arc::new(schema), vec![large_utf8_array, large_binary_array]) + .unwrap(); + + let rows = roundtrip(&[batch]); + + // Currently we read all String/Binary columns from ORC as plain StringArray/BinaryArray + let utf8_array = Arc::new(StringArray::from(vec![ + "Hello", + "there", + "楡井希実", + "💯", + "ORC", + "", + "123", + ])); + let binary_array = Arc::new(BinaryArray::from(vec![ + "Hello".as_bytes(), + "there".as_bytes(), + "楡井希実".as_bytes(), + "💯".as_bytes(), + "ORC".as_bytes(), + "".as_bytes(), + "123".as_bytes(), + ])); + let schema = Schema::new(vec![ + Field::new("large_utf8", ArrowDataType::Utf8, false), + Field::new("large_binary", ArrowDataType::Binary, false), + ]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![utf8_array, binary_array]).unwrap(); + assert_eq!(batch, rows[0]); + } + + #[test] + fn test_write_small_stripes() { + // Set small stripe size to ensure writing across multiple stripes works + let data: Vec = (0..1_000_000).collect(); + let int64_array = Arc::new(Int64Array::from(data)); + let schema = Schema::new(vec![Field::new("int64", ArrowDataType::Int64, true)]); + + let batch = RecordBatch::try_new(Arc::new(schema), vec![int64_array]).unwrap(); + + let mut f = vec![]; + let mut writer = ArrowWriterBuilder::new(&mut f, batch.schema()) + .with_stripe_byte_size(256) + .try_build() + .unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + + let f = Bytes::from(f); + let reader = ArrowReaderBuilder::try_new(f).unwrap().build(); + let schema = reader.schema(); + // Current reader doesn't read a batch across stripe boundaries, so we expect + // more than one batch to prove multiple stripes are being written here + let rows = reader.collect::, _>>().unwrap(); + assert!( + rows.len() > 1, + "must have written more than 1 stripe (each stripe read as separate recordbatch)" + ); + let actual = concat_batches(&schema, rows.iter()).unwrap(); + assert_eq!(batch, actual); + } + + #[test] + fn test_write_inconsistent_null_buffers() { + // When writing arrays where null buffer can appear/disappear between writes + let schema = Arc::new(Schema::new(vec![Field::new( + "int64", + ArrowDataType::Int64, + true, + )])); + + // Ensure first batch has array with no null buffer + let array_no_nulls = Arc::new(Int64Array::from(vec![1, 2, 3])); + assert!(array_no_nulls.nulls().is_none()); + // But subsequent batch has array with null buffer + let array_with_nulls = Arc::new(Int64Array::from(vec![None, Some(4), None])); + assert!(array_with_nulls.nulls().is_some()); + + let batch1 = RecordBatch::try_new(schema.clone(), vec![array_no_nulls]).unwrap(); + let batch2 = RecordBatch::try_new(schema.clone(), vec![array_with_nulls]).unwrap(); + + // ORC writer should be able to handle this gracefully + 
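+        // The two written batches land in a single stripe, so the reader
+        // returns one RecordBatch whose column now carries a null buffer
+        // spanning all six values: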
let expected_array = Arc::new(Int64Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + Some(4), + None, + ])); + let expected_batch = RecordBatch::try_new(schema, vec![expected_array]).unwrap(); + + let rows = roundtrip(&[batch1, batch2]); + assert_eq!(expected_batch, rows[0]); + } +} diff --git a/src/async_arrow_reader.rs b/src/async_arrow_reader.rs new file mode 100644 index 0000000..94a0565 --- /dev/null +++ b/src/async_arrow_reader.rs @@ -0,0 +1,228 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::fmt::Formatter; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use arrow::datatypes::SchemaRef; +use arrow::error::ArrowError; +use arrow::record_batch::RecordBatch; +use futures::future::BoxFuture; +use futures::{ready, Stream}; +use futures_util::FutureExt; + +use crate::array_decoder::NaiveStripeDecoder; +use crate::arrow_reader::Cursor; +use crate::error::Result; +use crate::reader::metadata::read_metadata_async; +use crate::reader::AsyncChunkReader; +use crate::stripe::{Stripe, StripeMetadata}; +use crate::ArrowReaderBuilder; + +type BoxedDecoder = Box> + Send>; + +enum StreamState { + /// At the start of a new row group, or the end of the file stream + Init, + /// Decoding a batch + Decoding(BoxedDecoder), + /// Reading data from input + Reading(BoxFuture<'static, Result<(StripeFactory, Option)>>), + /// Error + Error, +} + +impl std::fmt::Debug for StreamState { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + StreamState::Init => write!(f, "StreamState::Init"), + StreamState::Decoding(_) => write!(f, "StreamState::Decoding"), + StreamState::Reading(_) => write!(f, "StreamState::Reading"), + StreamState::Error => write!(f, "StreamState::Error"), + } + } +} + +impl From> for StripeFactory { + fn from(c: Cursor) -> Self { + Self { + inner: c, + is_end: false, + } + } +} + +struct StripeFactory { + inner: Cursor, + is_end: bool, +} + +pub struct ArrowStreamReader { + factory: Option>>, + batch_size: usize, + schema_ref: SchemaRef, + state: StreamState, +} + +impl StripeFactory { + async fn read_next_stripe_inner(&mut self, info: &StripeMetadata) -> Result { + let inner = &mut self.inner; + + inner.stripe_index += 1; + + Stripe::new_async( + &mut inner.reader, + &inner.file_metadata, + &inner.projected_data_type, + info, + ) + .await + } + + async fn read_next_stripe(mut self) -> Result<(Self, Option)> { + let info = self + .inner + .file_metadata + .stripe_metadatas() + .get(self.inner.stripe_index) + .cloned(); + + if let Some(info) = info { + if let Some(range) = self.inner.file_byte_range.clone() { + let offset = info.offset() as usize; + if !range.contains(&offset) { + self.inner.stripe_index += 1; + return Ok((self, None)); + } + } + match 
self.read_next_stripe_inner(&info).await { + Ok(stripe) => Ok((self, Some(stripe))), + Err(err) => Err(err), + } + } else { + self.is_end = true; + Ok((self, None)) + } + } +} + +impl ArrowStreamReader { + pub(crate) fn new(cursor: Cursor, batch_size: usize, schema_ref: SchemaRef) -> Self { + Self { + factory: Some(Box::new(cursor.into())), + batch_size, + schema_ref, + state: StreamState::Init, + } + } + + pub fn schema(&self) -> SchemaRef { + self.schema_ref.clone() + } + + fn poll_next_inner( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + loop { + match &mut self.state { + StreamState::Decoding(decoder) => match decoder.next() { + Some(Ok(batch)) => { + return Poll::Ready(Some(Ok(batch))); + } + Some(Err(e)) => { + self.state = StreamState::Error; + return Poll::Ready(Some(Err(e))); + } + None => self.state = StreamState::Init, + }, + StreamState::Init => { + let factory = self.factory.take().expect("lost factory"); + if factory.is_end { + return Poll::Ready(None); + } + + let fut = factory.read_next_stripe().boxed(); + + self.state = StreamState::Reading(fut) + } + StreamState::Reading(f) => match ready!(f.poll_unpin(cx)) { + Ok((factory, Some(stripe))) => { + self.factory = Some(Box::new(factory)); + match NaiveStripeDecoder::new( + stripe, + self.schema_ref.clone(), + self.batch_size, + ) { + Ok(decoder) => { + self.state = StreamState::Decoding(Box::new(decoder)); + } + Err(e) => { + self.state = StreamState::Error; + return Poll::Ready(Some(Err(e))); + } + } + } + Ok((factory, None)) => { + self.factory = Some(Box::new(factory)); + // All rows skipped, read next row group + self.state = StreamState::Init; + } + Err(e) => { + self.state = StreamState::Error; + return Poll::Ready(Some(Err(e))); + } + }, + StreamState::Error => return Poll::Ready(None), // Ends the stream as error happens. + } + } + } +} + +impl Stream for ArrowStreamReader { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_next_inner(cx) + .map_err(|e| ArrowError::ExternalError(Box::new(e))) + } +} + +impl ArrowReaderBuilder { + pub async fn try_new_async(mut reader: R) -> Result { + let file_metadata = Arc::new(read_metadata_async(&mut reader).await?); + Ok(Self::new(reader, file_metadata)) + } + + pub fn build_async(self) -> ArrowStreamReader { + let projected_data_type = self + .file_metadata() + .root_data_type() + .project(&self.projection); + let schema_ref = self.schema(); + let cursor = Cursor { + reader: self.reader, + file_metadata: self.file_metadata, + projected_data_type, + stripe_index: 0, + file_byte_range: self.file_byte_range, + }; + ArrowStreamReader::new(cursor, self.batch_size, schema_ref) + } +} diff --git a/src/bin/orc-export.rs b/src/bin/orc-export.rs new file mode 100644 index 0000000..257c03f --- /dev/null +++ b/src/bin/orc-export.rs @@ -0,0 +1,151 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{fs::File, io, path::PathBuf}; + +use anyhow::Result; +use arrow::{array::RecordBatch, csv, datatypes::DataType, error::ArrowError, json}; +use clap::{Parser, ValueEnum}; +use json::writer::{JsonFormat, LineDelimited}; +use orc_rust::{projection::ProjectionMask, reader::metadata::read_metadata, ArrowReaderBuilder}; + +#[derive(Parser)] +#[command(name = "orc-export")] +#[command(version, about = "Export data from orc file to csv", long_about = None)] +struct Cli { + /// Path to the orc file + file: PathBuf, + /// Output file. If not provided output will be printed on console + #[arg(short, long)] + output_file: Option, + /// Output format. If not provided then the output is csv + #[arg(value_enum, short, long, default_value_t = FileFormat::Csv)] + format: FileFormat, + /// export only first N records + #[arg(short, long, value_name = "N")] + num_rows: Option, + /// export only provided columns. Comma separated list + #[arg(short, long, value_delimiter = ',')] + columns: Option>, +} + +#[derive(Clone, Debug, PartialEq, ValueEnum)] +enum FileFormat { + /// Output data in csv format + Csv, + /// Output data in json format + Json, +} + +#[allow(clippy::large_enum_variant)] +enum OutputWriter { + Csv(csv::Writer), + Json(json::Writer), +} + +impl OutputWriter +where + W: io::Write, + F: JsonFormat, +{ + fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> { + match self { + OutputWriter::Csv(w) => w.write(batch), + OutputWriter::Json(w) => w.write(batch), + } + } + + fn finish(&mut self) -> Result<(), ArrowError> { + match self { + OutputWriter::Csv(_) => Ok(()), + OutputWriter::Json(w) => w.finish(), + } + } +} + +fn main() -> Result<()> { + let cli = Cli::parse(); + + // Prepare reader + let mut f = File::open(&cli.file)?; + let metadata = read_metadata(&mut f)?; + + // Select columns which should be exported (Binary and Decimal are not supported) + let cols: Vec = metadata + .root_data_type() + .children() + .iter() + .enumerate() + // TODO: handle nested types + .filter(|(_, nc)| match nc.data_type().to_arrow_data_type() { + DataType::Binary => false, + DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => { + matches!(cli.format, FileFormat::Csv) + } + _ => { + if let Some(cols) = &cli.columns { + cols.iter().any(|c| nc.name().eq(c)) + } else { + true + } + } + }) + .map(|(i, _)| i) + .collect(); + + let projection = ProjectionMask::roots(metadata.root_data_type(), cols); + let reader = ArrowReaderBuilder::try_new(f)? + .with_projection(projection) + .build(); + + // Prepare writer + let writer: Box = if let Some(output) = cli.output_file { + Box::new(File::create(output)?) 
+ } else { + Box::new(io::stdout()) + }; + + let mut output_writer = match cli.format { + FileFormat::Json => { + OutputWriter::Json(json::WriterBuilder::new().build::<_, LineDelimited>(writer)) + } + _ => OutputWriter::Csv(csv::WriterBuilder::new().with_header(true).build(writer)), + }; + + // Convert data + let mut num_rows = cli.num_rows.unwrap_or(u64::MAX); + for mut batch in reader.flatten() { + // Restrict rows + if num_rows < batch.num_rows() as u64 { + batch = batch.slice(0, num_rows as usize); + } + + // Save + output_writer.write(&batch)?; + + // Have we reached limit on the number of rows? + if num_rows > batch.num_rows() as u64 { + num_rows -= batch.num_rows() as u64; + } else { + break; + } + } + + output_writer.finish()?; + + Ok(()) +} diff --git a/src/bin/orc-metadata.rs b/src/bin/orc-metadata.rs new file mode 100644 index 0000000..9c15ca8 --- /dev/null +++ b/src/bin/orc-metadata.rs @@ -0,0 +1,77 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{error::Error, fs::File, path::PathBuf, sync::Arc}; + +use clap::Parser; +use orc_rust::{reader::metadata::read_metadata, stripe::Stripe}; + +#[derive(Parser)] +#[command(version, about, long_about = None)] +struct Cli { + /// ORC file path + file: PathBuf, + + /// Display data for all stripes + #[arg(short, long)] + stripes: bool, +} + +fn main() -> Result<(), Box> { + let cli = Cli::parse(); + + let mut f = File::open(cli.file)?; + let metadata = Arc::new(read_metadata(&mut f)?); + + // TODO: better way to handle this printing? + println!( + "compression: {}", + metadata + .compression() + .map(|c| c.to_string()) + .unwrap_or("None".to_string()) + ); + println!("file format version: {}", metadata.file_format_version()); + println!("number of rows: {}", metadata.number_of_rows()); + println!("number of stripes: {}", metadata.stripe_metadatas().len()); + + // TODO: nesting types indentation is messed up + println!("schema:\n{}", metadata.root_data_type()); + if cli.stripes { + println!("\n=== Stripes ==="); + for (i, stripe_metadata) in metadata.stripe_metadatas().iter().enumerate() { + let stripe = Stripe::new( + &mut f, + &metadata, + metadata.root_data_type(), + stripe_metadata, + )?; + println!("stripe index: {i}"); + println!("number of rows: {}", stripe.number_of_rows()); + println!( + "writer timezone: {}", + stripe + .writer_tz() + .map(|tz| tz.to_string()) + .unwrap_or("None".to_string()) + ); + println!(); + } + } + + Ok(()) +} diff --git a/src/bin/orc-stats.rs b/src/bin/orc-stats.rs new file mode 100644 index 0000000..1113a01 --- /dev/null +++ b/src/bin/orc-stats.rs @@ -0,0 +1,149 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{fs::File, path::PathBuf, sync::Arc}; + +use anyhow::Result; +use arrow::temporal_conversions::{date32_to_datetime, timestamp_ms_to_datetime}; +use clap::Parser; +use orc_rust::{reader::metadata::read_metadata, statistics::ColumnStatistics}; + +#[derive(Parser)] +#[command(name = "orc-stats")] +#[command(version, about = "Print column and stripe stats from the orc file", long_about = None)] +struct Cli { + /// Path to the orc file + file: PathBuf, +} + +fn print_column_stats(col_stats: &ColumnStatistics) { + if let Some(tstats) = col_stats.type_statistics() { + match tstats { + orc_rust::statistics::TypeStatistics::Integer { min, max, sum } => { + println!("* Data type Integer"); + println!("* Minimum: {}", min); + println!("* Maximum: {}", max); + if let Some(sum) = sum { + println!("* Sum: {}", sum); + } + } + orc_rust::statistics::TypeStatistics::Double { min, max, sum } => { + println!("* Data type Double"); + println!("* Minimum: {}", min); + println!("* Maximum: {}", max); + if let Some(sum) = sum { + println!("* Sum: {}", sum); + } + } + orc_rust::statistics::TypeStatistics::String { min, max, sum } => { + println!("* Data type String"); + println!("* Minimum: {}", min); + println!("* Maximum: {}", max); + println!("* Sum: {}", sum); + } + orc_rust::statistics::TypeStatistics::Bucket { true_count } => { + println!("* Data type Bucket"); + println!("* True count: {}", true_count); + } + orc_rust::statistics::TypeStatistics::Decimal { min, max, sum } => { + println!("* Data type Decimal"); + println!("* Minimum: {}", min); + println!("* Maximum: {}", max); + println!("* Sum: {}", sum); + } + orc_rust::statistics::TypeStatistics::Date { min, max } => { + println!("* Data type Date"); + if let Some(dt) = date32_to_datetime(*min) { + println!("* Minimum: {}", dt); + } + if let Some(dt) = date32_to_datetime(*max) { + println!("* Maximum: {}", dt); + } + } + orc_rust::statistics::TypeStatistics::Binary { sum } => { + println!("* Data type Binary"); + println!("* Sum: {}", sum); + } + orc_rust::statistics::TypeStatistics::Timestamp { + min, + max, + min_utc, + max_utc, + } => { + println!("* Data type Timestamp"); + println!("* Minimum: {}", min); + println!("* Maximum: {}", max); + if let Some(ts) = timestamp_ms_to_datetime(*min_utc) { + println!("* Minimum UTC: {}", ts); + } + if let Some(ts) = timestamp_ms_to_datetime(*max_utc) { + println!("* Maximum UTC: {}", ts); + } + } + orc_rust::statistics::TypeStatistics::Collection { + min_children, + max_children, + total_children, + } => { + println!("* Data type Collection"); + println!("* Minimum children: {}", min_children); + println!("* Maximum children: {}", max_children); + println!("* Total children: {}", total_children); + } + } + } + + println!("* Num values: {}", col_stats.number_of_values()); + println!("* Has nulls: {}", 
col_stats.has_null()); + println!(); +} + +fn main() -> Result<()> { + let cli = Cli::parse(); + + let mut f = File::open(&cli.file)?; + let metadata = Arc::new(read_metadata(&mut f)?); + + println!("# Column stats"); + println!( + "File {:?} has {} columns", + cli.file, + metadata.column_file_statistics().len() + ); + println!(); + for (idx, col_stats) in metadata.column_file_statistics().iter().enumerate() { + println!("## Column {idx}"); + print_column_stats(col_stats); + } + + println!("# Stripe stats"); + println!( + "File {:?} has {} stripes", + cli.file, + metadata.stripe_metadatas().len() + ); + println!(); + for (idm, sm) in metadata.stripe_metadatas().iter().enumerate() { + println!("----- Stripe {idm} -----\n"); + for (idc, col_stats) in sm.column_statistics().iter().enumerate() { + println!("## Column {idc}"); + print_column_stats(col_stats); + } + } + + Ok(()) +} diff --git a/src/column.rs b/src/column.rs new file mode 100644 index 0000000..aaacb31 --- /dev/null +++ b/src/column.rs @@ -0,0 +1,141 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use bytes::Bytes; +use snafu::ResultExt; + +use crate::error::{IoSnafu, Result}; +use crate::proto::{ColumnEncoding, StripeFooter}; +use crate::reader::ChunkReader; +use crate::schema::DataType; + +#[derive(Clone, Debug)] +pub struct Column { + footer: Arc, + name: String, + data_type: DataType, +} + +impl Column { + pub fn new(name: &str, data_type: &DataType, footer: &Arc) -> Self { + Self { + footer: footer.clone(), + data_type: data_type.clone(), + name: name.to_string(), + } + } + + pub fn dictionary_size(&self) -> usize { + let column = self.data_type.column_index(); + self.footer.columns[column] + .dictionary_size + .unwrap_or_default() as usize + } + + pub fn encoding(&self) -> ColumnEncoding { + let column = self.data_type.column_index(); + self.footer.columns[column].clone() + } + + pub fn data_type(&self) -> &DataType { + &self.data_type + } + + pub fn name(&self) -> &str { + &self.name + } + + pub fn column_id(&self) -> u32 { + self.data_type.column_index() as u32 + } + + pub fn children(&self) -> Vec { + match &self.data_type { + DataType::Boolean { .. } + | DataType::Byte { .. } + | DataType::Short { .. } + | DataType::Int { .. } + | DataType::Long { .. } + | DataType::Float { .. } + | DataType::Double { .. } + | DataType::String { .. } + | DataType::Varchar { .. } + | DataType::Char { .. } + | DataType::Binary { .. } + | DataType::Decimal { .. } + | DataType::Timestamp { .. } + | DataType::TimestampWithLocalTimezone { .. } + | DataType::Date { .. } => vec![], + DataType::Struct { children, .. 
} => children + .iter() + .map(|col| Column { + footer: self.footer.clone(), + name: col.name().to_string(), + data_type: col.data_type().clone(), + }) + .collect(), + DataType::List { child, .. } => { + vec![Column { + footer: self.footer.clone(), + name: "item".to_string(), + data_type: *child.clone(), + }] + } + DataType::Map { key, value, .. } => { + vec![ + Column { + footer: self.footer.clone(), + name: "key".to_string(), + data_type: *key.clone(), + }, + Column { + footer: self.footer.clone(), + name: "value".to_string(), + data_type: *value.clone(), + }, + ] + } + DataType::Union { variants, .. } => { + // TODO: might need corrections + variants + .iter() + .enumerate() + .map(|(index, data_type)| Column { + footer: self.footer.clone(), + name: format!("{index}"), + data_type: data_type.clone(), + }) + .collect() + } + } + } + + pub fn read_stream(reader: &mut R, start: u64, length: u64) -> Result { + reader.get_bytes(start, length).context(IoSnafu) + } + + #[cfg(feature = "async")] + pub async fn read_stream_async( + reader: &mut R, + start: u64, + length: u64, + ) -> Result { + reader.get_bytes(start, length).await.context(IoSnafu) + } +} diff --git a/src/compression.rs b/src/compression.rs new file mode 100644 index 0000000..ffd6038 --- /dev/null +++ b/src/compression.rs @@ -0,0 +1,371 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Modified from https://github.com/DataEngineeringLabs/orc-format/blob/416490db0214fc51d53289253c0ee91f7fc9bc17/src/read/decompress/mod.rs +//! Related code for handling decompression of ORC files. + +use std::io::Read; + +use bytes::{Bytes, BytesMut}; +use fallible_streaming_iterator::FallibleStreamingIterator; +use snafu::ResultExt; + +use crate::error::{self, OrcError, Result}; +use crate::proto::{self, CompressionKind}; + +// Spec states default is 256K +const DEFAULT_COMPRESSION_BLOCK_SIZE: u64 = 256 * 1024; + +#[derive(Clone, Copy, Debug)] +pub struct Compression { + compression_type: CompressionType, + /// No compression chunk will decompress to larger than this size. + /// Use to size the scratch buffer appropriately. 
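+    /// When the file's postscript does not specify a compression block size,
+    /// the spec default of 256 KiB is assumed (see
+    /// `DEFAULT_COMPRESSION_BLOCK_SIZE` above).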
+ max_decompressed_block_size: usize, +} + +impl std::fmt::Display for Compression { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{} ({} byte max block size)", + self.compression_type, self.max_decompressed_block_size + ) + } +} + +impl Compression { + pub fn compression_type(&self) -> CompressionType { + self.compression_type + } + + pub(crate) fn from_proto( + kind: proto::CompressionKind, + compression_block_size: Option, + ) -> Option { + let max_decompressed_block_size = + compression_block_size.unwrap_or(DEFAULT_COMPRESSION_BLOCK_SIZE) as usize; + match kind { + CompressionKind::None => None, + CompressionKind::Zlib => Some(Self { + compression_type: CompressionType::Zlib, + max_decompressed_block_size, + }), + CompressionKind::Snappy => Some(Self { + compression_type: CompressionType::Snappy, + max_decompressed_block_size, + }), + CompressionKind::Lzo => Some(Self { + compression_type: CompressionType::Lzo, + max_decompressed_block_size, + }), + CompressionKind::Lz4 => Some(Self { + compression_type: CompressionType::Lz4, + max_decompressed_block_size, + }), + CompressionKind::Zstd => Some(Self { + compression_type: CompressionType::Zstd, + max_decompressed_block_size, + }), + } + } +} + +#[derive(Clone, Copy, Debug)] +pub enum CompressionType { + Zlib, + Snappy, + Lzo, + Lz4, + Zstd, +} + +impl std::fmt::Display for CompressionType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + +/// Indicates length of block and whether it's compressed or not. +#[derive(Debug, PartialEq, Eq)] +enum CompressionHeader { + Original(u32), + Compressed(u32), +} + +/// ORC files are compressed in blocks, with a 3 byte header at the start +/// of these blocks indicating the length of the block and whether it's +/// compressed or not. 
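+///
+/// The three little-endian header bytes pack `(length << 1) | is_original`
+/// into 24 bits. For example (mirroring the tests below), `[0x0b, 0x00, 0x00]`
+/// decodes to an original block of length 5, while `[0x40, 0x0d, 0x03]`
+/// decodes to a compressed block of length 100_000.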
+fn decode_header(bytes: [u8; 3]) -> CompressionHeader {
+    let bytes = [bytes[0], bytes[1], bytes[2], 0];
+    let length_and_flag = u32::from_le_bytes(bytes);
+    let is_original = length_and_flag & 1 == 1;
+    let length = length_and_flag >> 1;
+    if is_original {
+        CompressionHeader::Original(length)
+    } else {
+        CompressionHeader::Compressed(length)
+    }
+}
+
+pub(crate) trait DecompressorVariant: Send {
+    fn decompress_block(&self, compressed_bytes: &[u8], scratch: &mut Vec<u8>) -> Result<()>;
+}
+
+#[derive(Debug, Clone, Copy)]
+struct Zlib;
+#[derive(Debug, Clone, Copy)]
+struct Zstd;
+#[derive(Debug, Clone, Copy)]
+struct Snappy;
+#[derive(Debug, Clone, Copy)]
+struct Lzo;
+#[derive(Debug, Clone, Copy)]
+struct Lz4 {
+    max_decompressed_block_size: usize,
+}
+
+impl DecompressorVariant for Zlib {
+    fn decompress_block(&self, compressed_bytes: &[u8], scratch: &mut Vec<u8>) -> Result<()> {
+        let mut gz = flate2::read::DeflateDecoder::new(compressed_bytes);
+        scratch.clear();
+        gz.read_to_end(scratch).context(error::IoSnafu)?;
+        Ok(())
+    }
+}
+
+impl DecompressorVariant for Zstd {
+    fn decompress_block(&self, compressed_bytes: &[u8], scratch: &mut Vec<u8>) -> Result<()> {
+        let mut reader =
+            zstd::Decoder::new(compressed_bytes).context(error::BuildZstdDecoderSnafu)?;
+        scratch.clear();
+        reader.read_to_end(scratch).context(error::IoSnafu)?;
+        Ok(())
+    }
+}
+
+impl DecompressorVariant for Snappy {
+    fn decompress_block(&self, compressed_bytes: &[u8], scratch: &mut Vec<u8>) -> Result<()> {
+        let len =
+            snap::raw::decompress_len(compressed_bytes).context(error::BuildSnappyDecoderSnafu)?;
+        scratch.resize(len, 0);
+        let mut decoder = snap::raw::Decoder::new();
+        decoder
+            .decompress(compressed_bytes, scratch)
+            .context(error::BuildSnappyDecoderSnafu)?;
+        Ok(())
+    }
+}
+
+impl DecompressorVariant for Lzo {
+    fn decompress_block(&self, compressed_bytes: &[u8], scratch: &mut Vec<u8>) -> Result<()> {
+        let decompressed = lzokay_native::decompress_all(compressed_bytes, None)
+            .context(error::BuildLzoDecoderSnafu)?;
+        // TODO: better way to utilize scratch here
+        scratch.clear();
+        scratch.extend(decompressed);
+        Ok(())
+    }
+}
+
+impl DecompressorVariant for Lz4 {
+    fn decompress_block(&self, compressed_bytes: &[u8], scratch: &mut Vec<u8>) -> Result<()> {
+        let decompressed =
+            lz4_flex::block::decompress(compressed_bytes, self.max_decompressed_block_size)
+                .context(error::BuildLz4DecoderSnafu)?;
+        // TODO: better way to utilize scratch here
+        scratch.clear();
+        scratch.extend(decompressed);
+        Ok(())
+    }
+}
+
+// TODO: push this earlier so we don't check this variant each time
+fn get_decompressor_variant(
+    Compression {
+        compression_type,
+        max_decompressed_block_size,
+    }: Compression,
+) -> Box<dyn DecompressorVariant> {
+    match compression_type {
+        CompressionType::Zlib => Box::new(Zlib),
+        CompressionType::Snappy => Box::new(Snappy),
+        CompressionType::Lzo => Box::new(Lzo),
+        CompressionType::Lz4 => Box::new(Lz4 {
+            max_decompressed_block_size,
+        }),
+        CompressionType::Zstd => Box::new(Zstd),
+    }
+}
+
+enum State {
+    Original(Bytes),
+    Compressed(Vec<u8>),
+}
+
+struct DecompressorIter {
+    stream: BytesMut,
+    current: Option<State>, // when we have compression but the value is original
+    compression: Option<Box<dyn DecompressorVariant>>,
+    scratch: Vec<u8>,
+}
+
+impl DecompressorIter {
+    fn new(stream: Bytes, compression: Option<Compression>, scratch: Vec<u8>) -> Self {
+        Self {
+            stream: BytesMut::from(stream.as_ref()),
+            current: None,
+            compression: compression.map(get_decompressor_variant),
+            scratch,
+        }
+    }
+}
+
+impl FallibleStreamingIterator for DecompressorIter {
+    type Item = [u8];
+    type Error = OrcError;
+
+    #[inline]
+    fn advance(&mut self) -> Result<(), Self::Error> {
+        if self.stream.is_empty() {
+            self.current = None;
+            return Ok(());
+        }
+
+        match &self.compression {
+            Some(compression) => {
+                // TODO: take scratch from current State::Compressed for re-use
+                let header = self.stream.split_to(3);
+                let header = [header[0], header[1], header[2]];
+                match decode_header(header) {
+                    CompressionHeader::Original(length) => {
+                        let original = self.stream.split_to(length as usize);
+                        self.current = Some(State::Original(original.into()));
+                    }
+                    CompressionHeader::Compressed(length) => {
+                        let compressed = self.stream.split_to(length as usize);
+                        compression.decompress_block(&compressed, &mut self.scratch)?;
+                        self.current = Some(State::Compressed(std::mem::take(&mut self.scratch)));
+                    }
+                };
+                Ok(())
+            }
+            None => {
+                // TODO: take scratch from current State::Compressed for re-use
+                self.current = Some(State::Original(self.stream.clone().into()));
+                self.stream.clear();
+                Ok(())
+            }
+        }
+    }
+
+    #[inline]
+    fn get(&self) -> Option<&Self::Item> {
+        self.current.as_ref().map(|x| match x {
+            State::Original(x) => x.as_ref(),
+            State::Compressed(x) => x.as_ref(),
+        })
+    }
+}
+
+/// A [`Read`]er fulfilling the ORC specification of reading compressed data.
+pub(crate) struct Decompressor {
+    decompressor: DecompressorIter,
+    offset: usize,
+    is_first: bool,
+}
+
+impl Decompressor {
+    /// Creates a new [`Decompressor`] that will use `scratch` as a temporary region.
+    pub fn new(stream: Bytes, compression: Option<Compression>, scratch: Vec<u8>) -> Self {
+        Self {
+            decompressor: DecompressorIter::new(stream, compression, scratch),
+            offset: 0,
+            is_first: true,
+        }
+    }
+
+    // TODO: remove need for this upstream
+    pub fn empty() -> Self {
+        Self {
+            decompressor: DecompressorIter::new(Bytes::new(), None, vec![]),
+            offset: 0,
+            is_first: true,
+        }
+    }
+}
+
+impl std::io::Read for Decompressor {
+    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
+        if self.is_first {
+            self.is_first = false;
+            self.decompressor.advance().unwrap();
+        }
+        let current = self.decompressor.get();
+        let current = if let Some(current) = current {
+            if current.len() == self.offset {
+                self.decompressor.advance().unwrap();
+                self.offset = 0;
+                let current = self.decompressor.get();
+                if let Some(current) = current {
+                    current
+                } else {
+                    return Ok(0);
+                }
+            } else {
+                &current[self.offset..]
+            }
+        } else {
+            return Ok(0);
+        };
+
+        if current.len() >= buf.len() {
+            buf.copy_from_slice(&current[..buf.len()]);
+            self.offset += buf.len();
+            Ok(buf.len())
+        } else {
+            buf[..current.len()].copy_from_slice(current);
+            self.offset += current.len();
+            Ok(current.len())
+        }
+    }
+}

+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn decode_uncompressed() {
+        // 5 uncompressed = [0x0b, 0x00, 0x00] = [0b1011, 0, 0]
+        let bytes = [0b1011, 0, 0];
+
+        let expected = CompressionHeader::Original(5);
+        let actual = decode_header(bytes);
+        assert_eq!(expected, actual);
+    }
+
+    #[test]
+    fn decode_compressed() {
+        // 100_000 compressed = [0x40, 0x0d, 0x03] = [0b01000000, 0b00001101, 0b00000011]
+        let bytes = [0b0100_0000, 0b0000_1101, 0b0000_0011];
+        let expected = CompressionHeader::Compressed(100_000);
+        let actual = decode_header(bytes);
+        assert_eq!(expected, actual);
+    }
+}
diff --git a/src/encoding/boolean.rs b/src/encoding/boolean.rs
new file mode 100644
index 0000000..f86e585
--- /dev/null
+++ b/src/encoding/boolean.rs
@@ -0,0 +1,170 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::io::Read;
+
+use arrow::{
+    array::BooleanBufferBuilder,
+    buffer::{BooleanBuffer, NullBuffer},
+};
+use bytes::Bytes;
+
+use crate::{error::Result, memory::EstimateMemory};
+
+use super::{
+    byte::{ByteRleDecoder, ByteRleEncoder},
+    PrimitiveValueDecoder, PrimitiveValueEncoder,
+};
+
+pub struct BooleanDecoder<R: Read> {
+    decoder: ByteRleDecoder<R>,
+    data: u8,
+    bits_in_data: usize,
+}
+
+impl<R: Read> BooleanDecoder<R> {
+    pub fn new(reader: R) -> Self {
+        Self {
+            decoder: ByteRleDecoder::new(reader),
+            bits_in_data: 0,
+            data: 0,
+        }
+    }
+
+    pub fn value(&mut self) -> bool {
+        let value = (self.data & 0x80) != 0;
+        self.data <<= 1;
+        self.bits_in_data -= 1;
+
+        value
+    }
+}
+
+impl<R: Read> PrimitiveValueDecoder<bool> for BooleanDecoder<R> {
+    // TODO: can probably implement this better
+    fn decode(&mut self, out: &mut [bool]) -> Result<()> {
+        for x in out.iter_mut() {
+            // read more data if necessary
+            if self.bits_in_data == 0 {
+                let mut data = [0];
+                self.decoder.decode(&mut data)?;
+                self.data = data[0] as u8;
+                self.bits_in_data = 8;
+            }
+            *x = self.value();
+        }
+        Ok(())
+    }
+}
+
+/// ORC encodes validity starting from MSB, whilst Arrow encodes it
+/// from LSB. After bytes are filled with the present bits, they are
+/// further encoded via Byte RLE.
+pub struct BooleanEncoder {
+    // TODO: can we refactor to not need two separate buffers?
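+    // The builder accumulates raw present bits in Arrow (LSB-first) order;
+    // finish() then bit-reverses each byte into ORC's MSB-first order before
+    // feeding the bytes through the Byte RLE encoder.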
+ byte_encoder: ByteRleEncoder, + builder: BooleanBufferBuilder, +} + +impl EstimateMemory for BooleanEncoder { + fn estimate_memory_size(&self) -> usize { + self.builder.len() / 8 + } +} + +impl BooleanEncoder { + pub fn new() -> Self { + Self { + byte_encoder: ByteRleEncoder::new(), + builder: BooleanBufferBuilder::new(8), + } + } + + pub fn extend(&mut self, null_buffer: &NullBuffer) { + let bb = null_buffer.inner(); + self.extend_bb(bb); + } + + pub fn extend_bb(&mut self, bb: &BooleanBuffer) { + self.builder.append_buffer(bb); + } + + /// Extend with n true bits. + pub fn extend_present(&mut self, n: usize) { + self.builder.append_n(n, true); + } + + pub fn extend_boolean(&mut self, b: bool) { + self.builder.append(b); + } + + /// Produce ORC present stream bytes and reset internal builder. + pub fn finish(&mut self) -> Bytes { + // TODO: don't throw away allocation? + let bb = self.builder.finish(); + // We use BooleanBufferBuilder so offset is 0 + let bytes = bb.values(); + // Reverse bits as ORC stores from MSB + let bytes = bytes.iter().map(|b| b.reverse_bits()).collect::>(); + for &b in bytes.as_slice() { + self.byte_encoder.write_one(b as i8); + } + self.byte_encoder.take_inner() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn basic() { + let expected = vec![false; 800]; + let data = [0x61u8, 0x00]; + let data = &mut data.as_ref(); + let mut decoder = BooleanDecoder::new(data); + let mut actual = vec![true; expected.len()]; + decoder.decode(&mut actual).unwrap(); + assert_eq!(actual, expected) + } + + #[test] + fn literals() { + let expected = vec![ + false, true, false, false, false, true, false, false, // 0b01000100 + false, true, false, false, false, true, false, true, // 0b01000101 + ]; + let data = [0xfeu8, 0b01000100, 0b01000101]; + let data = &mut data.as_ref(); + let mut decoder = BooleanDecoder::new(data); + let mut actual = vec![true; expected.len()]; + decoder.decode(&mut actual).unwrap(); + assert_eq!(actual, expected) + } + + #[test] + fn another() { + // "For example, the byte sequence [0xff, 0x80] would be one true followed by seven false values." + let expected = vec![true, false, false, false, false, false, false, false]; + let data = [0xff, 0x80]; + let data = &mut data.as_ref(); + let mut decoder = BooleanDecoder::new(data); + let mut actual = vec![true; expected.len()]; + decoder.decode(&mut actual).unwrap(); + assert_eq!(actual, expected) + } +} diff --git a/src/encoding/byte.rs b/src/encoding/byte.rs new file mode 100644 index 0000000..d2ad199 --- /dev/null +++ b/src/encoding/byte.rs @@ -0,0 +1,340 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
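+
+// Byte RLE wire format, in brief: a control byte below 0x80 encodes a Run of
+// (control + 3) copies of the byte that follows; a control byte of 0x80 or
+// above encodes (0x100 - control) literal bytes, emitted verbatim. For
+// example, [0x61, 0x00] decodes to one hundred zeros, and [0xfe, 0x44, 0x45]
+// to the two literals 0x44 and 0x45 (see the decoder tests below).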
+ +use bytemuck::must_cast_slice; +use bytes::{BufMut, BytesMut}; +use snafu::ResultExt; + +use crate::{ + error::{IoSnafu, Result}, + memory::EstimateMemory, +}; +use std::io::Read; + +use super::{rle::GenericRle, util::read_u8, PrimitiveValueEncoder}; + +const MAX_LITERAL_LENGTH: usize = 128; +const MIN_REPEAT_LENGTH: usize = 3; +const MAX_REPEAT_LENGTH: usize = 130; + +pub struct ByteRleEncoder { + writer: BytesMut, + /// Literal values to encode. + literals: [u8; MAX_LITERAL_LENGTH], + /// Represents the number of elements currently in `literals` if Literals, + /// otherwise represents the length of the Run. + num_literals: usize, + /// Tracks if current Literal sequence will turn into a Run sequence due to + /// repeated values at the end of the value sequence. + tail_run_length: usize, + /// If in Run sequence or not, and keeps the corresponding value. + run_value: Option, +} + +impl ByteRleEncoder { + /// Incrementally encode bytes using Run Length Encoding, where the subencodings are: + /// - Run: at least 3 repeated values in sequence (up to `MAX_REPEAT_LENGTH`) + /// - Literals: disparate values (up to `MAX_LITERAL_LENGTH` length) + /// + /// How the relevant encodings are chosen: + /// - Keep of track of values as they come, starting off assuming Literal sequence + /// - Keep track of latest value, to see if we are encountering a sequence of repeated + /// values (Run sequence) + /// - If this tail end exceeds the required minimum length, flush the current Literal + /// sequence (or switch to Run if entire current sequence is the repeated value) + /// - Whether in Literal or Run mode, keep buffering values and flushing when max length + /// reached or encoding is broken (e.g. non-repeated value found in Run mode) + fn process_value(&mut self, value: u8) { + // Adapted from https://github.com/apache/orc/blob/main/java/core/src/java/org/apache/orc/impl/RunLengthByteWriter.java + if self.num_literals == 0 { + // Start off in Literal mode + self.run_value = None; + self.literals[0] = value; + self.num_literals = 1; + self.tail_run_length = 1; + } else if let Some(run_value) = self.run_value { + // Run mode + + if value == run_value { + // Continue buffering for Run sequence, flushing if reaching max length + self.num_literals += 1; + if self.num_literals == MAX_REPEAT_LENGTH { + write_run(&mut self.writer, run_value, MAX_REPEAT_LENGTH); + self.clear_state(); + } + } else { + // Run is broken, flush then start again in Literal mode + write_run(&mut self.writer, run_value, self.num_literals); + self.run_value = None; + self.literals[0] = value; + self.num_literals = 1; + self.tail_run_length = 1; + } + } else { + // Literal mode + + // tail_run_length tracks length of repetition of last value + if value == self.literals[self.num_literals - 1] { + self.tail_run_length += 1; + } else { + self.tail_run_length = 1; + } + + if self.tail_run_length == MIN_REPEAT_LENGTH { + // When the tail end of the current sequence is enough for a Run sequence + + if self.num_literals + 1 == MIN_REPEAT_LENGTH { + // If current values are enough for a Run sequence, switch to Run encoding + self.run_value = Some(value); + self.num_literals += 1; + } else { + // Flush the current Literal sequence, then switch to Run encoding + // We don't flush the tail end which is a Run sequence + let len = self.num_literals - (MIN_REPEAT_LENGTH - 1); + let literals = &self.literals[..len]; + write_literals(&mut self.writer, literals); + self.run_value = Some(value); + self.num_literals = MIN_REPEAT_LENGTH; + } + } else { 
+                // Continue buffering for Literal sequence, flushing if reaching max length
+                self.literals[self.num_literals] = value;
+                self.num_literals += 1;
+                if self.num_literals == MAX_LITERAL_LENGTH {
+                    // Entire literals is filled, pass in as is
+                    write_literals(&mut self.writer, &self.literals);
+                    self.clear_state();
+                }
+            }
+        }
+    }
+
+    fn clear_state(&mut self) {
+        self.run_value = None;
+        self.tail_run_length = 0;
+        self.num_literals = 0;
+    }
+
+    /// Flush any buffered values to writer in appropriate sequence.
+    fn flush(&mut self) {
+        if self.num_literals != 0 {
+            if let Some(value) = self.run_value {
+                write_run(&mut self.writer, value, self.num_literals);
+            } else {
+                let literals = &self.literals[..self.num_literals];
+                write_literals(&mut self.writer, literals);
+            }
+            self.clear_state();
+        }
+    }
+}
+
+impl EstimateMemory for ByteRleEncoder {
+    fn estimate_memory_size(&self) -> usize {
+        self.writer.len() + self.num_literals
+    }
+}
+
+/// i8 to match with Arrow Int8 type.
+impl PrimitiveValueEncoder<i8> for ByteRleEncoder {
+    fn new() -> Self {
+        Self {
+            writer: BytesMut::new(),
+            literals: [0; MAX_LITERAL_LENGTH],
+            num_literals: 0,
+            tail_run_length: 0,
+            run_value: None,
+        }
+    }
+
+    fn write_one(&mut self, value: i8) {
+        self.process_value(value as u8);
+    }
+
+    fn take_inner(&mut self) -> bytes::Bytes {
+        self.flush();
+        std::mem::take(&mut self.writer).into()
+    }
+}
+
+fn write_run(writer: &mut BytesMut, value: u8, run_length: usize) {
+    debug_assert!(
+        (MIN_REPEAT_LENGTH..=MAX_REPEAT_LENGTH).contains(&run_length),
+        "Byte RLE Run sequence must be in range 3..=130"
+    );
+    // [3, 130] to [0, 127]
+    let header = run_length - MIN_REPEAT_LENGTH;
+    writer.put_u8(header as u8);
+    writer.put_u8(value);
+}
+
+fn write_literals(writer: &mut BytesMut, literals: &[u8]) {
+    debug_assert!(
+        (1..=MAX_LITERAL_LENGTH).contains(&literals.len()),
+        "Byte RLE Literal sequence must be in range 1..=128"
+    );
+    // [1, 128] to [-1, -128], then writing as a byte
+    let header = -(literals.len() as i32);
+    writer.put_u8(header as u8);
+    writer.put_slice(literals);
+}
+
+pub struct ByteRleDecoder<R> {
+    reader: R,
+    /// Values that have been decoded but not yet emitted.
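+    /// One decoded batch holds at most a maximal Run (130 bytes) or a
+    /// maximal Literal sequence (128 bytes).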
+    leftovers: Vec<u8>,
+    /// Index into leftovers to make it act like a queue; indicates the
+    /// next element available to read
+    index: usize,
+}
+
+impl<R: Read> ByteRleDecoder<R> {
+    pub fn new(reader: R) -> Self {
+        Self {
+            reader,
+            leftovers: Vec::with_capacity(MAX_REPEAT_LENGTH),
+            index: 0,
+        }
+    }
+}
+
+impl<R: Read> GenericRle<i8> for ByteRleDecoder<R> {
+    fn advance(&mut self, n: usize) {
+        self.index += n
+    }
+
+    fn available(&self) -> &[i8] {
+        let bytes = &self.leftovers[self.index..];
+        must_cast_slice(bytes)
+    }
+
+    fn decode_batch(&mut self) -> Result<()> {
+        self.index = 0;
+        self.leftovers.clear();
+
+        let header = read_u8(&mut self.reader)?;
+        if header < 0x80 {
+            // Run of repeated value
+            let length = header as usize + MIN_REPEAT_LENGTH;
+            let value = read_u8(&mut self.reader)?;
+            self.leftovers.extend(std::iter::repeat(value).take(length));
+        } else {
+            // List of values
+            let length = 0x100 - header as usize;
+            self.leftovers.resize(length, 0);
+            self.reader
+                .read_exact(&mut self.leftovers)
+                .context(IoSnafu)?;
+        }
+        Ok(())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::io::Cursor;
+
+    use crate::encoding::PrimitiveValueDecoder;
+
+    use super::*;
+
+    use proptest::prelude::*;
+
+    // TODO: have tests varying the out buffer, to ensure decode() is called
+    // multiple times
+
+    fn test_helper(data: &[u8], expected: &[i8]) {
+        let mut reader = ByteRleDecoder::new(Cursor::new(data));
+        let mut actual = vec![0; expected.len()];
+        reader.decode(&mut actual).unwrap();
+        assert_eq!(actual, expected);
+    }
+
+    #[test]
+    fn reader_test() {
+        let data = [0x61u8, 0x00];
+        let expected = [0; 100];
+        test_helper(&data, &expected);
+
+        let data = [0x01, 0x01];
+        let expected = [1; 4];
+        test_helper(&data, &expected);
+
+        let data = [0xfe, 0x44, 0x45];
+        let expected = [0x44, 0x45];
+        test_helper(&data, &expected);
+    }
+
+    fn roundtrip_byte_rle_helper(values: &[i8]) -> Result<Vec<i8>> {
+        let mut writer = ByteRleEncoder::new();
+        writer.write_slice(values);
+        writer.flush();
+
+        let buf = writer.take_inner();
+        let mut cursor = Cursor::new(&buf);
+        let mut reader = ByteRleDecoder::new(&mut cursor);
+        let mut actual = vec![0; values.len()];
+        reader.decode(&mut actual)?;
+        Ok(actual)
+    }
+
+    #[derive(Debug, Clone)]
+    enum ByteSequence {
+        Run(i8, usize),
+        Literals(Vec<i8>),
+    }
+
+    fn byte_sequence_strategy() -> impl Strategy<Value = ByteSequence> {
+        // We limit the max length of the sequences to 140 to try get more interleaving
+        prop_oneof![
+            (any::<i8>(), 1..140_usize).prop_map(|(a, b)| ByteSequence::Run(a, b)),
+            prop::collection::vec(any::<i8>(), 1..140).prop_map(ByteSequence::Literals)
+        ]
+    }
+
+    fn generate_bytes_from_sequences(sequences: Vec<ByteSequence>) -> Vec<i8> {
+        let mut bytes = vec![];
+        for sequence in sequences {
+            match sequence {
+                ByteSequence::Run(value, length) => {
+                    bytes.extend(std::iter::repeat(value).take(length))
+                }
+                ByteSequence::Literals(literals) => bytes.extend(literals),
+            }
+        }
+        bytes
+    }
+
+    proptest! {
+        #[test]
+        fn roundtrip_byte_rle_pure_random(values: Vec<i8>) {
+            // Biased towards literal sequences due to purely random values
+            let out = roundtrip_byte_rle_helper(&values).unwrap();
+            prop_assert_eq!(out, values);
+        }
+
+        #[test]
+        fn roundtrip_byte_rle_biased(
+            sequences in prop::collection::vec(byte_sequence_strategy(), 1..200)
+        ) {
+            // Intentionally introduce run sequences so the input isn't entirely random literals
+            let values = generate_bytes_from_sequences(sequences);
+            let out = roundtrip_byte_rle_helper(&values).unwrap();
+            prop_assert_eq!(out, values);
+        }
+    }
+}
diff --git a/src/encoding/decimal.rs b/src/encoding/decimal.rs
new file mode 100644
index 0000000..f722cf0
--- /dev/null
+++ b/src/encoding/decimal.rs
@@ -0,0 +1,45 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::io::Read;
+
+use crate::error::Result;
+
+use super::{
+    integer::{read_varint_zigzagged, SignedEncoding},
+    PrimitiveValueDecoder,
+};
+
+/// Read stream of zigzag encoded varints as i128 (unbounded).
+pub struct UnboundedVarintStreamDecoder<R> {
+    reader: R,
+}
+
+impl<R: Read> UnboundedVarintStreamDecoder<R> {
+    pub fn new(reader: R) -> Self {
+        Self { reader }
+    }
+}
+
+impl<R: Read> PrimitiveValueDecoder<i128> for UnboundedVarintStreamDecoder<R> {
+    fn decode(&mut self, out: &mut [i128]) -> Result<()> {
+        for x in out.iter_mut() {
+            *x = read_varint_zigzagged::<i128, _, SignedEncoding>(&mut self.reader)?;
+        }
+        Ok(())
+    }
+}
diff --git a/src/encoding/float.rs b/src/encoding/float.rs
new file mode 100644
index 0000000..5b9fa7e
--- /dev/null
+++ b/src/encoding/float.rs
@@ -0,0 +1,179 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::{io::Read, marker::PhantomData};
+
+use bytemuck::{must_cast_slice, must_cast_slice_mut};
+use bytes::{Bytes, BytesMut};
+use snafu::ResultExt;
+
+use crate::{
+    error::{IoSnafu, Result},
+    memory::EstimateMemory,
+};
+
+use super::{PrimitiveValueDecoder, PrimitiveValueEncoder};
+
+/// Collect all the required traits we need on floats.
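+/// Since encoding and decoding are plain byte copies, special values such as
+/// NaN survive a roundtrip unchanged (see the NaN checks in the tests below).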
+pub trait Float:
+    num::Float + std::fmt::Debug + bytemuck::NoUninit + bytemuck::AnyBitPattern
+{
+}
+impl Float for f32 {}
+impl Float for f64 {}
+
+pub struct FloatDecoder<F: Float, R: Read> {
+    reader: R,
+    phantom: std::marker::PhantomData<F>,
+}
+
+impl<F: Float, R: Read> FloatDecoder<F, R> {
+    pub fn new(reader: R) -> Self {
+        Self {
+            reader,
+            phantom: Default::default(),
+        }
+    }
+}
+
+impl<F: Float, R: Read> PrimitiveValueDecoder<F> for FloatDecoder<F, R> {
+    fn decode(&mut self, out: &mut [F]) -> Result<()> {
+        let bytes = must_cast_slice_mut::<F, u8>(out);
+        self.reader.read_exact(bytes).context(IoSnafu)?;
+        Ok(())
+    }
+}
+
+/// No special run encoding for floats/doubles; they are stored as their IEEE 754
+/// floating point bit layout. This encoder simply copies incoming floats/doubles
+/// to its internal byte buffer.
+pub struct FloatEncoder<F: Float> {
+    data: BytesMut,
+    _phantom: PhantomData<F>,
+}
+
+impl<F: Float> EstimateMemory for FloatEncoder<F> {
+    fn estimate_memory_size(&self) -> usize {
+        self.data.len()
+    }
+}
+
+impl<F: Float> PrimitiveValueEncoder<F> for FloatEncoder<F> {
+    fn new() -> Self {
+        Self {
+            data: BytesMut::new(),
+            _phantom: Default::default(),
+        }
+    }
+
+    fn write_one(&mut self, value: F) {
+        self.write_slice(&[value]);
+    }
+
+    fn write_slice(&mut self, values: &[F]) {
+        let bytes = must_cast_slice::<F, u8>(values);
+        self.data.extend_from_slice(bytes);
+    }
+
+    fn take_inner(&mut self) -> Bytes {
+        std::mem::take(&mut self.data).into()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::f32::consts as f32c;
+    use std::f64::consts as f64c;
+    use std::io::Cursor;
+
+    use proptest::prelude::*;
+
+    use super::*;
+
+    fn roundtrip_helper<F: Float>(input: &[F]) -> Result<Vec<F>> {
+        let mut encoder = FloatEncoder::<F>::new();
+        encoder.write_slice(input);
+        let bytes = encoder.take_inner();
+        let bytes = Cursor::new(bytes);
+
+        let mut iter = FloatDecoder::<F, _>::new(bytes);
+        let mut actual = vec![F::zero(); input.len()];
+        iter.decode(&mut actual)?;
+
+        Ok(actual)
+    }
+
+    fn assert_roundtrip<F: Float>(input: Vec<F>) {
+        let actual = roundtrip_helper(&input).unwrap();
+        assert_eq!(input, actual);
+    }
+
+    proptest! {
+        #[test]
+        fn roundtrip_f32(values: Vec<f32>) {
+            let out = roundtrip_helper(&values)?;
+            prop_assert_eq!(out, values);
+        }
+
+        #[test]
+        fn roundtrip_f64(values: Vec<f64>) {
+            let out = roundtrip_helper(&values)?;
+            prop_assert_eq!(out, values);
+        }
+    }
+
+    #[test]
+    fn test_float_edge_cases() {
+        assert_roundtrip::<f32>(vec![]);
+        assert_roundtrip::<f64>(vec![]);
+
+        assert_roundtrip(vec![f32c::PI]);
+        assert_roundtrip(vec![f64c::PI]);
+
+        let actual = roundtrip_helper(&[f32::NAN]).unwrap();
+        assert!(actual[0].is_nan());
+        let actual = roundtrip_helper(&[f64::NAN]).unwrap();
+        assert!(actual[0].is_nan());
+    }
+
+    #[test]
+    fn test_float_many() {
+        assert_roundtrip(vec![
+            f32::NEG_INFINITY,
+            f32::MIN,
+            -1.0,
+            -0.0,
+            0.0,
+            1.0,
+            f32c::SQRT_2,
+            f32::MAX,
+            f32::INFINITY,
+        ]);
+
+        assert_roundtrip(vec![
+            f64::NEG_INFINITY,
+            f64::MIN,
+            -1.0,
+            -0.0,
+            0.0,
+            1.0,
+            f64c::SQRT_2,
+            f64::MAX,
+            f64::INFINITY,
+        ]);
+    }
+}
diff --git a/src/encoding/integer/mod.rs b/src/encoding/integer/mod.rs
new file mode 100644
index 0000000..f652d4e
--- /dev/null
+++ b/src/encoding/integer/mod.rs
@@ -0,0 +1,327 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Run length encoding & decoding for integers
+
+use std::{
+    fmt,
+    io::Read,
+    ops::{BitOrAssign, ShlAssign},
+};
+
+use num::{traits::CheckedShl, PrimInt, Signed};
+use rle_v1::RleV1Decoder;
+use rle_v2::RleV2Decoder;
+use snafu::ResultExt;
+use util::{
+    get_closest_aligned_bit_width, signed_msb_decode, signed_zigzag_decode, signed_zigzag_encode,
+};
+
+use crate::{
+    column::Column,
+    error::{InvalidColumnEncodingSnafu, IoSnafu, Result},
+    proto::column_encoding::Kind as ProtoColumnKind,
+};
+
+use super::PrimitiveValueDecoder;
+
+pub mod rle_v1;
+pub mod rle_v2;
+mod util;
+
+// TODO: consider having a separate varint.rs
+pub use util::read_varint_zigzagged;
+
+pub fn get_unsigned_rle_reader<R: Read + Send + 'static>(
+    column: &Column,
+    reader: R,
+) -> Box<dyn PrimitiveValueDecoder<i64> + Send> {
+    match column.encoding().kind() {
+        ProtoColumnKind::Direct | ProtoColumnKind::Dictionary => {
+            Box::new(RleV1Decoder::<i64, _, UnsignedEncoding>::new(reader))
+        }
+        ProtoColumnKind::DirectV2 | ProtoColumnKind::DictionaryV2 => {
+            Box::new(RleV2Decoder::<i64, _, UnsignedEncoding>::new(reader))
+        }
+    }
+}
+
+pub fn get_rle_reader<N: NInt, R: Read + Send + 'static>(
+    column: &Column,
+    reader: R,
+) -> Result<Box<dyn PrimitiveValueDecoder<N> + Send>> {
+    match column.encoding().kind() {
+        ProtoColumnKind::Direct => Ok(Box::new(RleV1Decoder::<N, _, SignedEncoding>::new(reader))),
+        ProtoColumnKind::DirectV2 => {
+            Ok(Box::new(RleV2Decoder::<N, _, SignedEncoding>::new(reader)))
+        }
+        k => InvalidColumnEncodingSnafu {
+            name: column.name(),
+            encoding: k,
+        }
+        .fail(),
+    }
+}
+
+pub trait EncodingSign: Send + 'static {
+    // TODO: have separate type/trait to represent Zigzag encoded NInt?
+    fn zigzag_decode<N: VarintSerde>(v: N) -> N;
+    fn zigzag_encode<N: VarintSerde>(v: N) -> N;
+
+    fn decode_signed_msb<N: NInt>(v: N, encoded_byte_size: usize) -> N;
+}
+
+pub struct SignedEncoding;
+
+impl EncodingSign for SignedEncoding {
+    #[inline]
+    fn zigzag_decode<N: VarintSerde>(v: N) -> N {
+        signed_zigzag_decode(v)
+    }
+
+    #[inline]
+    fn zigzag_encode<N: VarintSerde>(v: N) -> N {
+        signed_zigzag_encode(v)
+    }
+
+    #[inline]
+    fn decode_signed_msb<N: NInt>(v: N, encoded_byte_size: usize) -> N {
+        signed_msb_decode(v, encoded_byte_size)
+    }
+}
+
+pub struct UnsignedEncoding;
+
+impl EncodingSign for UnsignedEncoding {
+    #[inline]
+    fn zigzag_decode<N: VarintSerde>(v: N) -> N {
+        v
+    }
+
+    #[inline]
+    fn zigzag_encode<N: VarintSerde>(v: N) -> N {
+        v
+    }
+
+    #[inline]
+    fn decode_signed_msb<N: NInt>(v: N, _encoded_byte_size: usize) -> N {
+        v
+    }
+}
+
+pub trait VarintSerde: PrimInt + CheckedShl + BitOrAssign + Signed {
+    const BYTE_SIZE: usize;
+
+    /// Calculate the minimum bit size required to represent this value, by truncating
+    /// the leading zeros.
+    #[inline]
+    fn bits_used(self) -> usize {
+        Self::BYTE_SIZE * 8 - self.leading_zeros() as usize
+    }
+
+    /// Feeds [`Self::bits_used`] into a mapping to get an aligned bit width.
+    fn closest_aligned_bit_width(self) -> usize {
+        get_closest_aligned_bit_width(self.bits_used())
+    }
+
+    fn from_u8(b: u8) -> Self;
+}
+
+/// Helps generalise the decoder efforts to be specific to supported integers,
+/// instead of decoding to u64/i64 for all types and then downcasting.
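+/// For example, an ORC SHORT column can be decoded straight into an i16
+/// buffer instead of materialising i64 values and downcasting each one.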
+pub trait NInt: + VarintSerde + ShlAssign + fmt::Debug + fmt::Display + fmt::Binary + Send + Sync + 'static +{ + type Bytes: AsRef<[u8]> + AsMut<[u8]> + Default + Clone + Copy + fmt::Debug; + + #[inline] + fn empty_byte_array() -> Self::Bytes { + Self::Bytes::default() + } + + /// Should truncate any extra bits. + fn from_i64(u: i64) -> Self; + + fn from_be_bytes(b: Self::Bytes) -> Self; + + // TODO: use num_traits::ToBytes instead + fn to_be_bytes(self) -> Self::Bytes; + + fn add_i64(self, i: i64) -> Option; + + fn sub_i64(self, i: i64) -> Option; + + // TODO: use Into instead? + fn as_i64(self) -> i64; + + fn read_big_endian(reader: &mut impl Read, byte_size: usize) -> Result { + debug_assert!( + byte_size <= Self::BYTE_SIZE, + "byte_size cannot exceed max byte size of self" + ); + let mut buffer = Self::empty_byte_array(); + // Read into back part of buffer since is big endian. + // So if smaller than N::BYTE_SIZE bytes, most significant bytes will be 0. + reader + .read_exact(&mut buffer.as_mut()[Self::BYTE_SIZE - byte_size..]) + .context(IoSnafu)?; + Ok(Self::from_be_bytes(buffer)) + } +} + +impl VarintSerde for i16 { + const BYTE_SIZE: usize = 2; + + #[inline] + fn from_u8(b: u8) -> Self { + b as Self + } +} + +impl VarintSerde for i32 { + const BYTE_SIZE: usize = 4; + + #[inline] + fn from_u8(b: u8) -> Self { + b as Self + } +} + +impl VarintSerde for i64 { + const BYTE_SIZE: usize = 8; + + #[inline] + fn from_u8(b: u8) -> Self { + b as Self + } +} + +impl VarintSerde for i128 { + const BYTE_SIZE: usize = 16; + + #[inline] + fn from_u8(b: u8) -> Self { + b as Self + } +} + +// We only implement for i16, i32, i64 and u64. +// ORC supports only signed Short, Integer and Long types for its integer types, +// and i8 is encoded as bytes. u64 is used for other encodings such as Strings +// (to encode length, etc.). 
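+//
+// Editor's illustration of `read_big_endian` above: with N = i16, reading 2
+// bytes [0x01, 0x02] yields 0x0102 = 258, while reading just 1 byte [0x02]
+// yields 2, because partial reads fill only the least significant end of the
+// big-endian buffer (the leading bytes stay zero).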
+ +impl NInt for i16 { + type Bytes = [u8; 2]; + + #[inline] + fn from_i64(i: i64) -> Self { + i as Self + } + + #[inline] + fn from_be_bytes(b: Self::Bytes) -> Self { + Self::from_be_bytes(b) + } + + #[inline] + fn to_be_bytes(self) -> Self::Bytes { + self.to_be_bytes() + } + + #[inline] + fn add_i64(self, i: i64) -> Option { + i.try_into().ok().and_then(|i| self.checked_add(i)) + } + + #[inline] + fn sub_i64(self, i: i64) -> Option { + i.try_into().ok().and_then(|i| self.checked_sub(i)) + } + + #[inline] + fn as_i64(self) -> i64 { + self as i64 + } +} + +impl NInt for i32 { + type Bytes = [u8; 4]; + + #[inline] + fn from_i64(i: i64) -> Self { + i as Self + } + + #[inline] + fn from_be_bytes(b: Self::Bytes) -> Self { + Self::from_be_bytes(b) + } + + #[inline] + fn to_be_bytes(self) -> Self::Bytes { + self.to_be_bytes() + } + + #[inline] + fn add_i64(self, i: i64) -> Option { + i.try_into().ok().and_then(|i| self.checked_add(i)) + } + + #[inline] + fn sub_i64(self, i: i64) -> Option { + i.try_into().ok().and_then(|i| self.checked_sub(i)) + } + + #[inline] + fn as_i64(self) -> i64 { + self as i64 + } +} + +impl NInt for i64 { + type Bytes = [u8; 8]; + + #[inline] + fn from_i64(i: i64) -> Self { + i as Self + } + + #[inline] + fn from_be_bytes(b: Self::Bytes) -> Self { + Self::from_be_bytes(b) + } + + #[inline] + fn to_be_bytes(self) -> Self::Bytes { + self.to_be_bytes() + } + + #[inline] + fn add_i64(self, i: i64) -> Option { + self.checked_add(i) + } + + #[inline] + fn sub_i64(self, i: i64) -> Option { + self.checked_sub(i) + } + + #[inline] + fn as_i64(self) -> i64 { + self + } +} diff --git a/src/encoding/integer/rle_v1.rs b/src/encoding/integer/rle_v1.rs new file mode 100644 index 0000000..02c0495 --- /dev/null +++ b/src/encoding/integer/rle_v1.rs @@ -0,0 +1,186 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Handling decoding of Integer Run Length Encoded V1 data in ORC files + +use std::{io::Read, marker::PhantomData}; + +use snafu::OptionExt; + +use crate::{ + encoding::{ + rle::GenericRle, + util::{read_u8, try_read_u8}, + }, + error::{OutOfSpecSnafu, Result}, +}; + +use super::{util::read_varint_zigzagged, EncodingSign, NInt}; + +const MAX_RUN_LENGTH: usize = 130; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +enum EncodingType { + Run { length: usize, delta: i8 }, + Literals { length: usize }, +} + +impl EncodingType { + /// Decode header byte to determine sub-encoding. + /// Runs start with a positive byte, and literals with a negative byte. 
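+    /// For example, header byte 0x61 (97) begins a run of 97 + 3 = 100 values
+    /// (a delta byte and a base varint follow), while 0xfb (-5 as i8) begins a
+    /// literal sequence of 5 varints.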
+ fn from_header(reader: &mut R) -> Result> { + let opt_encoding = match try_read_u8(reader)?.map(|b| b as i8) { + Some(header) if header < 0 => { + let length = header.unsigned_abs() as usize; + Some(Self::Literals { length }) + } + Some(header) => { + let length = header as u8 as usize + 3; + let delta = read_u8(reader)? as i8; + Some(Self::Run { length, delta }) + } + None => None, + }; + Ok(opt_encoding) + } +} + +/// Decodes a stream of Integer Run Length Encoded version 1 bytes. +pub struct RleV1Decoder { + reader: R, + decoded_ints: Vec, + current_head: usize, + sign: PhantomData, +} + +impl RleV1Decoder { + pub fn new(reader: R) -> Self { + Self { + reader, + decoded_ints: Vec::with_capacity(MAX_RUN_LENGTH), + current_head: 0, + sign: Default::default(), + } + } +} + +fn read_literals( + reader: &mut R, + out_ints: &mut Vec, + length: usize, +) -> Result<()> { + for _ in 0..length { + let lit = read_varint_zigzagged::<_, _, S>(reader)?; + out_ints.push(lit); + } + Ok(()) +} + +fn read_run( + reader: &mut R, + out_ints: &mut Vec, + length: usize, + delta: i8, +) -> Result<()> { + let mut base = read_varint_zigzagged::<_, _, S>(reader)?; + // Account for base value + let length = length - 1; + out_ints.push(base); + if delta < 0 { + let delta = delta.unsigned_abs(); + let delta = N::from_u8(delta); + for _ in 0..length { + base = base.checked_sub(&delta).context(OutOfSpecSnafu { + msg: "over/underflow when decoding patched base integer", + })?; + out_ints.push(base); + } + } else { + let delta = delta as u8; + let delta = N::from_u8(delta); + for _ in 0..length { + base = base.checked_add(&delta).context(OutOfSpecSnafu { + msg: "over/underflow when decoding patched base integer", + })?; + out_ints.push(base); + } + } + Ok(()) +} + +impl GenericRle for RleV1Decoder { + fn advance(&mut self, n: usize) { + self.current_head += n; + } + + fn available(&self) -> &[N] { + &self.decoded_ints[self.current_head..] + } + + fn decode_batch(&mut self) -> Result<()> { + self.current_head = 0; + self.decoded_ints.clear(); + + match EncodingType::from_header(&mut self.reader)? { + Some(EncodingType::Literals { length }) => { + read_literals::<_, _, S>(&mut self.reader, &mut self.decoded_ints, length) + } + Some(EncodingType::Run { length, delta }) => { + read_run::<_, _, S>(&mut self.reader, &mut self.decoded_ints, length, delta) + } + None => Ok(()), + } + } +} + +#[cfg(test)] +mod tests { + use std::io::Cursor; + + use crate::encoding::{integer::UnsignedEncoding, PrimitiveValueDecoder}; + + use super::*; + + fn test_helper(data: &[u8], expected: &[i64]) { + let mut reader = RleV1Decoder::::new(Cursor::new(data)); + let mut actual = vec![0; expected.len()]; + reader.decode(&mut actual).unwrap(); + assert_eq!(actual, expected); + } + + #[test] + fn test_run() -> Result<()> { + let data = [0x61, 0x00, 0x07]; + let expected = [7; 100]; + test_helper(&data, &expected); + + let data = [0x61, 0xff, 0x64]; + let expected = (1..=100).rev().collect::>(); + test_helper(&data, &expected); + + Ok(()) + } + + #[test] + fn test_literal() -> Result<()> { + let data = [0xfb, 0x02, 0x03, 0x06, 0x07, 0xb]; + let expected = vec![2, 3, 6, 7, 11]; + test_helper(&data, &expected); + + Ok(()) + } +} diff --git a/src/encoding/integer/rle_v2/delta.rs b/src/encoding/integer/rle_v2/delta.rs new file mode 100644 index 0000000..63e81b7 --- /dev/null +++ b/src/encoding/integer/rle_v2/delta.rs @@ -0,0 +1,289 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::io::Read; + +use bytes::{BufMut, BytesMut}; +use snafu::OptionExt; + +use crate::{ + encoding::{ + integer::{ + rle_v2::{EncodingType, MAX_RUN_LENGTH}, + util::{ + extract_run_length_from_header, read_ints, read_varint_zigzagged, + rle_v2_decode_bit_width, rle_v2_encode_bit_width, write_aligned_packed_ints, + write_varint_zigzagged, + }, + EncodingSign, SignedEncoding, VarintSerde, + }, + util::read_u8, + }, + error::{OrcError, OutOfSpecSnafu, Result}, +}; + +use super::NInt; + +/// We use i64 and u64 for delta to make things easier and to avoid edge cases, +/// as for example for i16, the delta may be too large to represent in an i16. +// TODO: expand on the above +pub fn read_delta_values( + reader: &mut R, + out_ints: &mut Vec, + deltas: &mut Vec, + header: u8, +) -> Result<()> { + // Encoding format: + // 2 bytes header + // - 2 bits for encoding type (constant 3) + // - 5 bits for encoded delta bitwidth (0 to 64) + // - 9 bits for run length (1 to 512) + // Base value (signed or unsigned) varint + // Delta value signed varint + // Sequence of delta values + + let encoded_delta_bit_width = (header >> 1) & 0x1f; + // Uses same encoding table as for direct & patched base, + // but special case where 0 indicates 0 width (for fixed delta) + let delta_bit_width = if encoded_delta_bit_width == 0 { + encoded_delta_bit_width as usize + } else { + rle_v2_decode_bit_width(encoded_delta_bit_width) + }; + + let second_byte = read_u8(reader)?; + let length = extract_run_length_from_header(header, second_byte); + + let base_value = read_varint_zigzagged::(reader)?; + out_ints.push(base_value); + + // Always signed since can be decreasing sequence + let delta_base = read_varint_zigzagged::(reader)?; + // TODO: does this get inlined? + let op: fn(N, i64) -> Option = if delta_base.is_positive() { + |acc, delta| acc.add_i64(delta) + } else { + |acc, delta| acc.sub_i64(delta) + }; + let delta_base = delta_base.abs(); // TODO: i64::MIN? 
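+
+    // Worked example (editor's illustration): with a fixed delta (width 0),
+    // base 2, delta +3 and run length 5, the decoded run is [2, 5, 8, 11, 14];
+    // with a non-zero width, the deltas after the first are bit-packed instead.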
+ + if delta_bit_width == 0 { + // If width is 0 then all values have fixed delta of delta_base + // Skip first value since that's base_value + (1..length).try_fold(base_value, |acc, _| { + let acc = op(acc, delta_base).context(OutOfSpecSnafu { + msg: "over/underflow when decoding delta integer", + })?; + out_ints.push(acc); + Ok::<_, OrcError>(acc) + })?; + } else { + deltas.clear(); + // Add delta base and first value + let second_value = op(base_value, delta_base).context(OutOfSpecSnafu { + msg: "over/underflow when decoding delta integer", + })?; + out_ints.push(second_value); + // Run length includes base value and first delta, so skip them + let length = length - 2; + + // Unpack the delta values + read_ints(deltas, length, delta_bit_width, reader)?; + let mut acc = second_value; + // Each element is the delta, so find actual value using running accumulator + for delta in deltas { + acc = op(acc, *delta).context(OutOfSpecSnafu { + msg: "over/underflow when decoding delta integer", + })?; + out_ints.push(acc); + } + } + Ok(()) +} + +pub fn write_varying_delta( + writer: &mut BytesMut, + base_value: N, + first_delta: i64, + max_delta: i64, + subsequent_deltas: &[i64], +) { + debug_assert!( + max_delta > 0, + "varying deltas must have at least one non-zero delta" + ); + let bit_width = max_delta.closest_aligned_bit_width(); + // We can't have bit width of 1 for delta as that would get decoded as + // 0 bit width on reader, which indicates fixed delta, so bump 1 to 2 + // in this case. + let bit_width = if bit_width == 1 { 2 } else { bit_width }; + // Add 2 to len for the base_value and first_delta + let header = derive_delta_header(bit_width, subsequent_deltas.len() + 2); + writer.put_slice(&header); + + write_varint_zigzagged::<_, S>(writer, base_value); + // First delta always signed to indicate increasing/decreasing sequence + write_varint_zigzagged::<_, SignedEncoding>(writer, first_delta); + + // Bitpacked deltas + write_aligned_packed_ints(writer, bit_width, subsequent_deltas); +} + +pub fn write_fixed_delta( + writer: &mut BytesMut, + base_value: N, + fixed_delta: i64, + subsequent_deltas_len: usize, +) { + // Assuming len excludes base_value and first delta + let header = derive_delta_header(0, subsequent_deltas_len + 2); + writer.put_slice(&header); + + write_varint_zigzagged::<_, S>(writer, base_value); + // First delta always signed to indicate increasing/decreasing sequence + write_varint_zigzagged::<_, SignedEncoding>(writer, fixed_delta); +} + +fn derive_delta_header(delta_width: usize, run_length: usize) -> [u8; 2] { + debug_assert!( + (1..=MAX_RUN_LENGTH).contains(&run_length), + "delta run length cannot exceed 512 values" + ); + // [1, 512] to [0, 511] + let run_length = run_length as u16 - 1; + // 0 is special value to indicate fixed delta + let delta_width = if delta_width == 0 { + 0 + } else { + rle_v2_encode_bit_width(delta_width) + }; + // No need to mask as we guarantee max length is 512 + let encoded_length_high_bit = (run_length >> 8) as u8; + let encoded_length_low_bits = (run_length & 0xFF) as u8; + + let header1 = EncodingType::Delta.to_header() | delta_width << 1 | encoded_length_high_bit; + let header2 = encoded_length_low_bits; + + [header1, header2] +} + +#[cfg(test)] +mod tests { + use std::io::Cursor; + + use crate::encoding::integer::UnsignedEncoding; + + use super::*; + + // TODO: figure out how to write proptests for these + + #[test] + fn test_fixed_delta_positive() { + let mut buf = BytesMut::new(); + let mut out = vec![]; + let mut deltas = 
vec![];
+        write_fixed_delta::<i64, UnsignedEncoding>(&mut buf, 0, 10, 100 - 2);
+        let header = buf[0];
+        read_delta_values::<i64, _, UnsignedEncoding>(
+            &mut Cursor::new(&buf[1..]),
+            &mut out,
+            &mut deltas,
+            header,
+        )
+        .unwrap();
+
+        let expected = (0..100).map(|i| i * 10).collect::<Vec<_>>();
+        assert_eq!(expected, out);
+    }
+
+    #[test]
+    fn test_fixed_delta_negative() {
+        let mut buf = BytesMut::new();
+        let mut out = vec![];
+        let mut deltas = vec![];
+        write_fixed_delta::<i64, UnsignedEncoding>(&mut buf, 10_000, -63, 150 - 2);
+        let header = buf[0];
+        read_delta_values::<i64, _, UnsignedEncoding>(
+            &mut Cursor::new(&buf[1..]),
+            &mut out,
+            &mut deltas,
+            header,
+        )
+        .unwrap();
+
+        let expected = (0..150).map(|i| 10_000 - i * 63).collect::<Vec<_>>();
+        assert_eq!(expected, out);
+    }
+
+    #[test]
+    fn test_varying_delta_positive() {
+        let deltas = [
+            1, 6, 98, 12, 65, 9, 0, 0, 1, 128, 643, 129, 469, 123, 4572, 124,
+        ];
+        let max = *deltas.iter().max().unwrap();
+
+        let mut buf = BytesMut::new();
+        let mut out = vec![];
+        // Scratch buffer for the decoder; named distinctly so it doesn't
+        // shadow the input deltas above (shadowing would make this test vacuous).
+        let mut scratch = vec![];
+        write_varying_delta::<i64, UnsignedEncoding>(&mut buf, 0, 10, max, &deltas);
+        let header = buf[0];
+        read_delta_values::<i64, _, UnsignedEncoding>(
+            &mut Cursor::new(&buf[1..]),
+            &mut out,
+            &mut scratch,
+            header,
+        )
+        .unwrap();
+
+        let mut expected = vec![0, 10];
+        let mut i = 1;
+        for d in deltas {
+            expected.push(d + expected[i]);
+            i += 1;
+        }
+        assert_eq!(expected, out);
+    }
+
+    #[test]
+    fn test_varying_delta_negative() {
+        let deltas = [
+            1, 6, 98, 12, 65, 9, 0, 0, 1, 128, 643, 129, 469, 123, 4572, 124,
+        ];
+        let max = *deltas.iter().max().unwrap();
+
+        let mut buf = BytesMut::new();
+        let mut out = vec![];
+        // Scratch buffer for the decoder, distinct from the input deltas above.
+        let mut scratch = vec![];
+        write_varying_delta::<i64, UnsignedEncoding>(&mut buf, 10_000, -1, max, &deltas);
+        let header = buf[0];
+        read_delta_values::<i64, _, UnsignedEncoding>(
+            &mut Cursor::new(&buf[1..]),
+            &mut out,
+            &mut scratch,
+            header,
+        )
+        .unwrap();
+
+        let mut expected = vec![10_000, 9_999];
+        let mut i = 1;
+        for d in deltas {
+            expected.push(expected[i] - d);
+            i += 1;
+        }
+        assert_eq!(expected, out);
+    }
+}
diff --git a/src/encoding/integer/rle_v2/direct.rs b/src/encoding/integer/rle_v2/direct.rs
new file mode 100644
index 0000000..f05ec85
--- /dev/null
+++ b/src/encoding/integer/rle_v2/direct.rs
@@ -0,0 +1,158 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+ +use std::io::Read; + +use bytes::{BufMut, BytesMut}; + +use crate::{ + encoding::{ + integer::{ + rle_v2::{EncodingType, MAX_RUN_LENGTH}, + util::{ + extract_run_length_from_header, read_ints, rle_v2_decode_bit_width, + rle_v2_encode_bit_width, write_aligned_packed_ints, + }, + EncodingSign, + }, + util::read_u8, + }, + error::{OutOfSpecSnafu, Result}, +}; + +use super::NInt; + +pub fn read_direct_values( + reader: &mut R, + out_ints: &mut Vec, + header: u8, +) -> Result<()> { + let encoded_bit_width = (header >> 1) & 0x1F; + let bit_width = rle_v2_decode_bit_width(encoded_bit_width); + + if (N::BYTE_SIZE * 8) < bit_width { + return OutOfSpecSnafu { + msg: "byte width of direct encoding exceeds byte size of integer being decoded to", + } + .fail(); + } + + let second_byte = read_u8(reader)?; + let length = extract_run_length_from_header(header, second_byte); + + // Write the unpacked values and zigzag decode to result buffer + read_ints(out_ints, length, bit_width, reader)?; + + for lit in out_ints.iter_mut() { + *lit = S::zigzag_decode(*lit); + } + + Ok(()) +} + +/// `values` and `max` must be zigzag encoded. If `max` is not provided, it is derived +/// by iterating over `values`. +pub fn write_direct(writer: &mut BytesMut, values: &[N], max: Option) { + debug_assert!( + (1..=MAX_RUN_LENGTH).contains(&values.len()), + "direct run length cannot exceed 512 values" + ); + + let max = max.unwrap_or_else(|| { + // Assert guards that values is non-empty + *values.iter().max_by_key(|x| x.bits_used()).unwrap() + }); + + let bit_width = max.closest_aligned_bit_width(); + let encoded_bit_width = rle_v2_encode_bit_width(bit_width); + // From [1, 512] to [0, 511] + let encoded_length = values.len() as u16 - 1; + // No need to mask as we guarantee max length is 512 + let encoded_length_high_bit = (encoded_length >> 8) as u8; + let encoded_length_low_bits = (encoded_length & 0xFF) as u8; + + let header1 = + EncodingType::Direct.to_header() | (encoded_bit_width << 1) | encoded_length_high_bit; + let header2 = encoded_length_low_bits; + + writer.put_u8(header1); + writer.put_u8(header2); + write_aligned_packed_ints(writer, bit_width, values); +} + +#[cfg(test)] +mod tests { + use std::io::Cursor; + + use proptest::prelude::*; + + use crate::encoding::integer::{SignedEncoding, UnsignedEncoding}; + + use super::*; + + fn roundtrip_direct_helper(values: &[N]) -> Result> { + let mut buf = BytesMut::new(); + let mut out = vec![]; + + write_direct(&mut buf, values, None); + let header = buf[0]; + read_direct_values::<_, _, S>(&mut Cursor::new(&buf[1..]), &mut out, header)?; + + Ok(out) + } + + #[test] + fn test_direct_edge_case() { + let values: Vec = vec![109, -17809, -29946, -17285]; + let encoded = values + .iter() + .map(|&v| SignedEncoding::zigzag_encode(v)) + .collect::>(); + let out = roundtrip_direct_helper::<_, SignedEncoding>(&encoded).unwrap(); + assert_eq!(out, values); + } + + proptest! 
{ + #[test] + fn roundtrip_direct_i16(values in prop::collection::vec(any::(), 1..=512)) { + let encoded = values.iter().map(|v| SignedEncoding::zigzag_encode(*v)).collect::>(); + let out = roundtrip_direct_helper::<_, SignedEncoding>(&encoded)?; + prop_assert_eq!(out, values); + } + + #[test] + fn roundtrip_direct_i32(values in prop::collection::vec(any::(), 1..=512)) { + let encoded = values.iter().map(|v| SignedEncoding::zigzag_encode(*v)).collect::>(); + let out = roundtrip_direct_helper::<_, SignedEncoding>(&encoded)?; + prop_assert_eq!(out, values); + } + + #[test] + fn roundtrip_direct_i64(values in prop::collection::vec(any::(), 1..=512)) { + let encoded = values.iter().map(|v| SignedEncoding::zigzag_encode(*v)).collect::>(); + let out = roundtrip_direct_helper::<_, SignedEncoding>(&encoded)?; + prop_assert_eq!(out, values); + } + + #[test] + fn roundtrip_direct_i64_unsigned(values in prop::collection::vec(0..=i64::MAX, 1..=512)) { + let encoded = values.iter().map(|v| UnsignedEncoding::zigzag_encode(*v)).collect::>(); + let out = roundtrip_direct_helper::<_, UnsignedEncoding>(&encoded)?; + prop_assert_eq!(out, values); + } + } +} diff --git a/src/encoding/integer/rle_v2/mod.rs b/src/encoding/integer/rle_v2/mod.rs new file mode 100644 index 0000000..ed871cf --- /dev/null +++ b/src/encoding/integer/rle_v2/mod.rs @@ -0,0 +1,677 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{io::Read, marker::PhantomData}; + +use bytes::BytesMut; + +use crate::{ + encoding::{rle::GenericRle, util::try_read_u8, PrimitiveValueEncoder}, + error::Result, + memory::EstimateMemory, +}; + +use self::{ + delta::{read_delta_values, write_fixed_delta, write_varying_delta}, + direct::{read_direct_values, write_direct}, + patched_base::{read_patched_base, write_patched_base}, + short_repeat::{read_short_repeat_values, write_short_repeat}, +}; + +use super::{util::calculate_percentile_bits, EncodingSign, NInt, VarintSerde}; + +mod delta; +mod direct; +mod patched_base; +mod short_repeat; + +const MAX_RUN_LENGTH: usize = 512; +/// Minimum number of repeated values required to use Short Repeat sub-encoding +const SHORT_REPEAT_MIN_LENGTH: usize = 3; +const SHORT_REPEAT_MAX_LENGTH: usize = 10; +const BASE_VALUE_LIMIT: i64 = 1 << 56; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +// TODO: put header data in here, e.g. base value, len, etc. +enum EncodingType { + ShortRepeat, + Direct, + PatchedBase, + Delta, +} + +impl EncodingType { + /// Checking highest two bits for encoding type. 
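+    /// For example: 0b00xxxxxx => ShortRepeat, 0b01xxxxxx => Direct,
+    /// 0b10xxxxxx => PatchedBase, 0b11xxxxxx => Delta.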
+ #[inline] + fn from_header(header: u8) -> Self { + match header & 0b_1100_0000 { + 0b_1100_0000 => Self::Delta, + 0b_1000_0000 => Self::PatchedBase, + 0b_0100_0000 => Self::Direct, + 0b_0000_0000 => Self::ShortRepeat, + _ => unreachable!(), + } + } + + /// Return byte with highest two bits set according to variant. + #[inline] + fn to_header(self) -> u8 { + match self { + EncodingType::Delta => 0b_1100_0000, + EncodingType::PatchedBase => 0b_1000_0000, + EncodingType::Direct => 0b_0100_0000, + EncodingType::ShortRepeat => 0b_0000_0000, + } + } +} + +pub struct RleV2Decoder { + reader: R, + decoded_ints: Vec, + /// Indexes into decoded_ints to make it act like a queue + current_head: usize, + deltas: Vec, + sign: PhantomData, +} + +impl RleV2Decoder { + pub fn new(reader: R) -> Self { + Self { + reader, + decoded_ints: Vec::with_capacity(MAX_RUN_LENGTH), + current_head: 0, + deltas: Vec::with_capacity(MAX_RUN_LENGTH), + sign: Default::default(), + } + } +} + +impl GenericRle for RleV2Decoder { + fn advance(&mut self, n: usize) { + self.current_head += n; + } + + fn available(&self) -> &[N] { + &self.decoded_ints[self.current_head..] + } + + fn decode_batch(&mut self) -> Result<()> { + self.current_head = 0; + self.decoded_ints.clear(); + let header = match try_read_u8(&mut self.reader)? { + Some(byte) => byte, + None => return Ok(()), + }; + + match EncodingType::from_header(header) { + EncodingType::ShortRepeat => read_short_repeat_values::<_, _, S>( + &mut self.reader, + &mut self.decoded_ints, + header, + )?, + EncodingType::Direct => { + read_direct_values::<_, _, S>(&mut self.reader, &mut self.decoded_ints, header)? + } + EncodingType::PatchedBase => { + read_patched_base::<_, _, S>(&mut self.reader, &mut self.decoded_ints, header)? + } + EncodingType::Delta => read_delta_values::<_, _, S>( + &mut self.reader, + &mut self.decoded_ints, + &mut self.deltas, + header, + )?, + } + + Ok(()) + } +} + +struct DeltaEncodingCheckResult { + base_value: N, + min: N, + max: N, + first_delta: i64, + max_delta: i64, + is_monotonic: bool, + is_fixed_delta: bool, + adjacent_deltas: Vec, +} + +/// Calculate the necessary values to determine if sequence can be delta encoded. +fn delta_encoding_check(literals: &[N]) -> DeltaEncodingCheckResult { + let base_value = literals[0]; + let mut min = base_value.min(literals[1]); + let mut max = base_value.max(literals[1]); + // Saturating should be fine here (and below) as we later check the + // difference between min & max and defer to direct encoding if it + // is too large (so the corrupt delta here won't actually be used). + // TODO: is there a more explicit way of ensuring this behaviour? 
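+    // e.g. for literals [i64::MAX, i64::MIN] the subtraction here saturates,
+    // but max - min also overflows, so the caller falls back to direct
+    // encoding and the saturated delta is never used.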
+ let first_delta = literals[1].as_i64().saturating_sub(base_value.as_i64()); + let mut current_delta; + let mut max_delta = 0; + + let mut is_increasing = first_delta.is_positive(); + let mut is_decreasing = first_delta.is_negative(); + let mut is_fixed_delta = true; + + let mut adjacent_deltas = vec![]; + + // We've already preprocessed the first step above + for i in 2..literals.len() { + let l1 = literals[i]; + let l0 = literals[i - 1]; + + min = min.min(l1); + max = max.max(l1); + + current_delta = l1.as_i64().saturating_sub(l0.as_i64()); + + is_increasing &= current_delta >= 0; + is_decreasing &= current_delta <= 0; + + is_fixed_delta &= current_delta == first_delta; + let current_delta = current_delta.saturating_abs(); + adjacent_deltas.push(current_delta); + max_delta = max_delta.max(current_delta); + } + let is_monotonic = is_increasing || is_decreasing; + + DeltaEncodingCheckResult { + base_value, + min, + max, + first_delta, + max_delta, + is_monotonic, + is_fixed_delta, + adjacent_deltas, + } +} + +/// Runs are guaranteed to have length > 1. +#[derive(Debug, Clone, Eq, PartialEq)] +enum RleV2EncodingState { + /// When buffer is empty and no values to encode. + Empty, + /// Special state for first value as we determine after the first + /// value whether to go fixed or variable run. + One(N), + /// Run of identical value of specified count. + FixedRun { value: N, count: usize }, + /// Run of variable values. + VariableRun { literals: Vec }, +} + +impl Default for RleV2EncodingState { + fn default() -> Self { + Self::Empty + } +} + +pub struct RleV2Encoder { + /// Stores the run length encoded sequences. + data: BytesMut, + /// Used in state machine for determining which sub-encoding + /// for a sequence to use. + state: RleV2EncodingState, + phantom: PhantomData, +} + +impl RleV2Encoder { + // Algorithm adapted from: + // https://github.com/apache/orc/blob/main/java/core/src/java/org/apache/orc/impl/RunLengthIntegerWriterV2.java + + /// Process each value to build up knowledge to determine which encoding to use. We attempt + /// to identify runs of identical values (fixed runs), otherwise falling back to variable + /// runs (varying values). + /// + /// When in a fixed run state, as long as identical values are found, we keep incrementing + /// the run length up to a maximum of 512, flushing to fixed delta run if so. If we encounter + /// a differing value, we flush to short repeat or fixed delta depending on the length and + /// reset the state (if the current run is small enough, we switch direct to variable run). + /// + /// When in a variable run state, if we find 3 identical values in a row as the latest values, + /// we flush the variable run to a sub-encoding then switch to fixed run, otherwise continue + /// incrementing the run length up to a max length of 512, before flushing and resetting the + /// state. For a variable run, extra logic must take place to determine which sub-encoding to + /// use when flushing, see [`Self::determine_variable_run_encoding`] for more details. 
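+    /// For example (illustrative): feeding 1, 1, 1, 1, 7 builds a fixed run of
+    /// four 1's; the differing 7 flushes that run as a Short Repeat (length in
+    /// [3, 10]) and starts a new potential run with 7.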
+ fn process_value(&mut self, value: N) { + match &mut self.state { + // When we start, or when a run was flushed to a sub-encoding + RleV2EncodingState::Empty => { + self.state = RleV2EncodingState::One(value); + } + // Here we determine if we look like we're in a fixed run or variable run + RleV2EncodingState::One(one_value) => { + if value == *one_value { + self.state = RleV2EncodingState::FixedRun { value, count: 2 }; + } else { + // TODO: alloc here + let mut literals = Vec::with_capacity(MAX_RUN_LENGTH); + literals.push(*one_value); + literals.push(value); + self.state = RleV2EncodingState::VariableRun { literals }; + } + } + // When we're in a run of identical values + RleV2EncodingState::FixedRun { + value: fixed_value, + count, + } => { + if value == *fixed_value { + // Continue fixed run, flushing to delta when max length reached + *count += 1; + if *count == MAX_RUN_LENGTH { + write_fixed_delta::<_, S>(&mut self.data, value, 0, *count - 2); + self.state = RleV2EncodingState::Empty; + } + } else { + // If fixed run is broken by a different value. + match count { + // Note that count cannot be 0 or 1 here as that is encoded + // by Empty and One states in self.state + 2 => { + // If fixed run is smaller than short repeat then just include + // it at the start of the variable run we're switching to. + // TODO: alloc here + let mut literals = Vec::with_capacity(MAX_RUN_LENGTH); + literals.push(*fixed_value); + literals.push(*fixed_value); + literals.push(value); + self.state = RleV2EncodingState::VariableRun { literals }; + } + SHORT_REPEAT_MIN_LENGTH..=SHORT_REPEAT_MAX_LENGTH => { + // If we have enough values for a Short Repeat, then encode as + // such. + write_short_repeat::<_, S>(&mut self.data, *fixed_value, *count); + self.state = RleV2EncodingState::One(value); + } + _ => { + // Otherwise if too large, use Delta encoding. + write_fixed_delta::<_, S>(&mut self.data, *fixed_value, 0, *count - 2); + self.state = RleV2EncodingState::One(value); + } + } + } + } + // When we're in a run of varying values + RleV2EncodingState::VariableRun { literals } => { + let length = literals.len(); + let last_value = literals[length - 1]; + let second_last_value = literals[length - 2]; + if value == last_value && value == second_last_value { + // Last 3 values (including current new one) are identical. Break the current + // variable run, flushing it to a sub-encoding, then switch to a fixed run + // state. + + // Pop off the last two values (which are identical to value) and flush + // the variable run to writer + literals.truncate(literals.len() - 2); + determine_variable_run_encoding::<_, S>(&mut self.data, literals); + + self.state = RleV2EncodingState::FixedRun { value, count: 3 }; + } else { + // Continue variable run, flushing sub-encoding if max length reached + literals.push(value); + if literals.len() == MAX_RUN_LENGTH { + determine_variable_run_encoding::<_, S>(&mut self.data, literals); + self.state = RleV2EncodingState::Empty; + } + } + } + } + } + + /// Flush any buffered values to the writer. 
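+    /// e.g. a pending run of exactly two identical values is flushed as Direct,
+    /// which has the smallest overhead, since Short Repeat needs at least 3.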
+ fn flush(&mut self) { + let state = std::mem::take(&mut self.state); + match state { + RleV2EncodingState::Empty => {} + RleV2EncodingState::One(value) => { + let value = S::zigzag_encode(value); + write_direct(&mut self.data, &[value], Some(value)); + } + RleV2EncodingState::FixedRun { value, count: 2 } => { + // Direct has smallest overhead + let value = S::zigzag_encode(value); + write_direct(&mut self.data, &[value, value], Some(value)); + } + RleV2EncodingState::FixedRun { value, count } if count <= SHORT_REPEAT_MAX_LENGTH => { + // Short repeat must have length [3, 10] + write_short_repeat::<_, S>(&mut self.data, value, count); + } + RleV2EncodingState::FixedRun { value, count } => { + write_fixed_delta::<_, S>(&mut self.data, value, 0, count - 2); + } + RleV2EncodingState::VariableRun { mut literals } => { + determine_variable_run_encoding::<_, S>(&mut self.data, &mut literals); + } + } + } +} + +impl EstimateMemory for RleV2Encoder { + fn estimate_memory_size(&self) -> usize { + self.data.len() + } +} + +impl PrimitiveValueEncoder for RleV2Encoder { + fn new() -> Self { + Self { + data: BytesMut::new(), + state: RleV2EncodingState::Empty, + phantom: Default::default(), + } + } + + fn write_one(&mut self, value: N) { + self.process_value(value); + } + + fn take_inner(&mut self) -> bytes::Bytes { + self.flush(); + std::mem::take(&mut self.data).into() + } +} + +fn determine_variable_run_encoding( + writer: &mut BytesMut, + literals: &mut [N], +) { + // Direct will have smallest overhead for tiny runs + if literals.len() <= SHORT_REPEAT_MIN_LENGTH { + for v in literals.iter_mut() { + *v = S::zigzag_encode(*v); + } + write_direct(writer, literals, None); + return; + } + + // Invariant: literals.len() > 3 + let DeltaEncodingCheckResult { + base_value, + min, + max, + first_delta, + max_delta, + is_monotonic, + is_fixed_delta, + adjacent_deltas, + } = delta_encoding_check(literals); + + // Quick check for delta overflow, if so just move to Direct as it has less + // overhead than Patched Base. + // TODO: should min/max be N or i64 here? + if max.checked_sub(&min).is_none() { + for v in literals.iter_mut() { + *v = S::zigzag_encode(*v); + } + write_direct(writer, literals, None); + return; + } + + // Any subtractions here on are safe due to above check + + if is_fixed_delta { + write_fixed_delta::<_, S>(writer, literals[0], first_delta, literals.len() - 2); + return; + } + + // First delta used to indicate if increasing or decreasing, so must be non-zero + if first_delta != 0 && is_monotonic { + write_varying_delta::<_, S>(writer, base_value, first_delta, max_delta, &adjacent_deltas); + return; + } + + // In Java implementation, Patched Base encoding base value cannot exceed 56 + // bits in value otherwise it can overflow the maximum 8 bytes used to encode + // the value when signed MSB encoding is used (adds an extra bit). + let min = min.as_i64(); + if min.abs() >= BASE_VALUE_LIMIT && min != i64::MIN { + for v in literals.iter_mut() { + *v = S::zigzag_encode(*v); + } + write_direct(writer, literals, None); + return; + } + + // TODO: another allocation here + let zigzag_literals = literals + .iter() + .map(|&v| S::zigzag_encode(v)) + .collect::>(); + let zigzagged_90_percentile_bit_width = calculate_percentile_bits(&zigzag_literals, 0.90); + // TODO: can derive from min/max? 
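+    // Editor's illustration: if 90% of the zigzagged values fit in 8 bits but
+    // the widest needs 32 (a spread greater than 1), Patched Base stays in
+    // play; when the widths are nearly uniform, Direct is chosen below.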
+ let zigzagged_100_percentile_bit_width = calculate_percentile_bits(&zigzag_literals, 1.00); + // If variation of bit width between largest value and lower 90% of values isn't + // significant enough, just use direct encoding as patched base wouldn't be as + // efficient. + if (zigzagged_100_percentile_bit_width.saturating_sub(zigzagged_90_percentile_bit_width)) <= 1 { + // TODO: pass through the 100p here + write_direct(writer, &zigzag_literals, None); + return; + } + + // Base value for patched base is the minimum value + // Patch data values are the literals with the base value subtracted + // We use base_reduced_literals to store these base reduced literals + let mut max_data_value = 0; + let mut base_reduced_literals = vec![]; + for l in literals.iter() { + // All base reduced literals become positive here + let base_reduced_literal = l.as_i64() - min; + base_reduced_literals.push(base_reduced_literal); + max_data_value = max_data_value.max(base_reduced_literal); + } + + // Aka 100th percentile + let base_reduced_literals_max_bit_width = max_data_value.closest_aligned_bit_width(); + // 95th percentile width is used to find the 5% of values to encode with patches + let base_reduced_literals_95th_percentile_bit_width = + calculate_percentile_bits(&base_reduced_literals, 0.95); + + // Patch only if we have outliers, based on bit width + if base_reduced_literals_max_bit_width != base_reduced_literals_95th_percentile_bit_width { + write_patched_base( + writer, + &mut base_reduced_literals, + min, + base_reduced_literals_max_bit_width, + base_reduced_literals_95th_percentile_bit_width, + ); + } else { + // TODO: pass through the 100p here + write_direct(writer, &zigzag_literals, None); + } +} + +#[cfg(test)] +mod tests { + + use std::io::Cursor; + + use proptest::prelude::*; + + use crate::encoding::{ + integer::{SignedEncoding, UnsignedEncoding}, + PrimitiveValueDecoder, + }; + + use super::*; + + // TODO: have tests varying the out buffer, to ensure decode() is called + // multiple times + + fn test_helper(data: &[u8], expected: &[i64]) { + let mut reader = RleV2Decoder::::new(Cursor::new(data)); + let mut actual = vec![0; expected.len()]; + reader.decode(&mut actual).unwrap(); + assert_eq!(actual, expected); + } + + #[test] + fn reader_test() { + let data = [2, 1, 64, 5, 80, 1, 1]; + let expected = [1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1]; + test_helper::(&data, &expected); + + // direct + let data = [0x5e, 0x03, 0x5c, 0xa1, 0xab, 0x1e, 0xde, 0xad, 0xbe, 0xef]; + let expected = [23713, 43806, 57005, 48879]; + test_helper::(&data, &expected); + + // patched base + let data = [ + 102, 9, 0, 126, 224, 7, 208, 0, 126, 79, 66, 64, 0, 127, 128, 8, 2, 0, 128, 192, 8, 22, + 0, 130, 0, 8, 42, + ]; + let expected = [ + 2030, 2000, 2020, 1000000, 2040, 2050, 2060, 2070, 2080, 2090, + ]; + test_helper::(&data, &expected); + + let data = [196, 9, 2, 2, 74, 40, 166]; + let expected = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]; + test_helper::(&data, &expected); + + let data = [0xc6, 0x09, 0x02, 0x02, 0x22, 0x42, 0x42, 0x46]; + let expected = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]; + test_helper::(&data, &expected); + + let data = [7, 1]; + let expected = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]; + test_helper::(&data, &expected); + } + + #[test] + fn short_repeat() { + let data = [0x0a, 0x27, 0x10]; + let expected = [10000, 10000, 10000, 10000, 10000]; + test_helper::(&data, &expected); + } + + #[test] + fn direct() { + let data = [0x5e, 0x03, 0x5c, 0xa1, 0xab, 0x1e, 0xde, 0xad, 0xbe, 0xef]; + let expected = 
[23713, 43806, 57005, 48879]; + test_helper::(&data, &expected); + } + + #[test] + fn direct_signed() { + let data = [110, 3, 0, 185, 66, 1, 86, 60, 1, 189, 90, 1, 125, 222]; + let expected = [23713, 43806, 57005, 48879]; + test_helper::(&data, &expected); + } + + #[test] + fn delta() { + let data = [0xc6, 0x09, 0x02, 0x02, 0x22, 0x42, 0x42, 0x46]; + let expected = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]; + test_helper::(&data, &expected); + } + + #[test] + fn patched_base() { + let data = [ + 0x8e, 0x09, 0x2b, 0x21, 0x07, 0xd0, 0x1e, 0x00, 0x14, 0x70, 0x28, 0x32, 0x3c, 0x46, + 0x50, 0x5a, 0xfc, 0xe8, + ]; + let expected = [ + 2030, 2000, 2020, 1000000, 2040, 2050, 2060, 2070, 2080, 2090, + ]; + test_helper::(&data, &expected); + } + + #[test] + fn patched_base_1() { + let data = vec![ + 144, 109, 4, 164, 141, 16, 131, 194, 0, 240, 112, 64, 60, 84, 24, 3, 193, 201, 128, + 120, 60, 33, 4, 244, 3, 193, 192, 224, 128, 56, 32, 15, 22, 131, 129, 225, 0, 112, 84, + 86, 14, 8, 106, 193, 192, 228, 160, 64, 32, 14, 213, 131, 193, 192, 240, 121, 124, 30, + 18, 9, 132, 67, 0, 224, 120, 60, 28, 14, 32, 132, 65, 192, 240, 160, 56, 61, 91, 7, 3, + 193, 192, 240, 120, 76, 29, 23, 7, 3, 220, 192, 240, 152, 60, 52, 15, 7, 131, 129, 225, + 0, 144, 56, 30, 14, 44, 140, 129, 194, 224, 120, 0, 28, 15, 8, 6, 129, 198, 144, 128, + 104, 36, 27, 11, 38, 131, 33, 48, 224, 152, 60, 111, 6, 183, 3, 112, 0, 1, 78, 5, 46, + 2, 1, 1, 141, 3, 1, 1, 138, 22, 0, 65, 1, 4, 0, 225, 16, 209, 192, 4, 16, 8, 36, 16, 3, + 48, 1, 3, 13, 33, 0, 176, 0, 1, 94, 18, 0, 68, 0, 33, 1, 143, 0, 1, 7, 93, 0, 25, 0, 5, + 0, 2, 0, 4, 0, 1, 0, 1, 0, 2, 0, 16, 0, 1, 11, 150, 0, 3, 0, 1, 0, 1, 99, 157, 0, 1, + 140, 54, 0, 162, 1, 130, 0, 16, 112, 67, 66, 0, 2, 4, 0, 0, 224, 0, 1, 0, 16, 64, 16, + 91, 198, 1, 2, 0, 32, 144, 64, 0, 12, 2, 8, 24, 0, 64, 0, 1, 0, 0, 8, 48, 51, 128, 0, + 2, 12, 16, 32, 32, 71, 128, 19, 76, + ]; + // expected data generated from Orc Java implementation + let expected = vec![ + 20, 2, 3, 2, 1, 3, 17, 71, 35, 2, 1, 139, 2, 2, 3, 1783, 475, 2, 1, 1, 3, 1, 3, 2, 32, + 1, 2, 3, 1, 8, 30, 1, 3, 414, 1, 1, 135, 3, 3, 1, 414, 2, 1, 2, 2, 594, 2, 5, 6, 4, 11, + 1, 2, 2, 1, 1, 52, 4, 1, 2, 7, 1, 17, 334, 1, 2, 1, 2, 2, 6, 1, 266, 1, 2, 217, 2, 6, + 2, 13, 2, 2, 1, 2, 3, 5, 1, 2, 1, 7244, 11813, 1, 33, 2, -13, 1, 2, 3, 13, 1, 92, 3, + 13, 5, 14, 9, 141, 12, 6, 15, 25, -1, -1, -1, 23, 1, -1, -1, -71, -2, -1, -1, -1, -1, + 2, 1, 4, 34, 5, 78, 8, 1, 2, 2, 1, 9, 10, 2, 1, 4, 13, 1, 5, 4, 4, 19, 5, -1, -1, -1, + 34, -17, -200, -1, -943, -13, -3, 1, 2, -1, -1, 1, 8, -1, 1483, -2, -1, -1, -12751, -1, + -1, -1, 66, 1, 3, 8, 131, 14, 5, 1, 2, 2, 1, 1, 8, 1, 1, 2, 1, 5, 9, 2, 3, 112, 13, 2, + 2, 1, 5, 10, 3, 1, 1, 13, 2, 3, 4, 1, 3, 1, 1, 2, 1, 1, 2, 4, 2, 207, 1, 1, 2, 4, 3, 3, + 2, 2, 16, + ]; + test_helper::(&data, &expected); + } + + // TODO: be smarter about prop test here, generate different patterns of ints + // - e.g. increasing/decreasing sequences, outliers, repeated + // - to ensure all different subencodings are being used (and might make shrinking better) + // currently 99% of the time here the subencoding will be Direct due to random generation + + fn roundtrip_helper(values: &[N]) -> Result> { + let mut writer = RleV2Encoder::::new(); + writer.write_slice(values); + let data = writer.take_inner(); + + let mut reader = RleV2Decoder::::new(Cursor::new(data)); + let mut actual = vec![N::zero(); values.len()]; + reader.decode(&mut actual).unwrap(); + + Ok(actual) + } + + proptest! 
{ + #[test] + fn roundtrip_i16(values in prop::collection::vec(any::(), 1..1_000)) { + let out = roundtrip_helper::<_, SignedEncoding>(&values)?; + prop_assert_eq!(out, values); + } + + #[test] + fn roundtrip_i32(values in prop::collection::vec(any::(), 1..1_000)) { + let out = roundtrip_helper::<_, SignedEncoding>(&values)?; + prop_assert_eq!(out, values); + } + + #[test] + fn roundtrip_i64(values in prop::collection::vec(any::(), 1..1_000)) { + let out = roundtrip_helper::<_, SignedEncoding>(&values)?; + prop_assert_eq!(out, values); + } + + #[test] + fn roundtrip_i64_unsigned(values in prop::collection::vec(0..=i64::MAX, 1..1_000)) { + let out = roundtrip_helper::<_, UnsignedEncoding>(&values)?; + prop_assert_eq!(out, values); + } + } +} diff --git a/src/encoding/integer/rle_v2/patched_base.rs b/src/encoding/integer/rle_v2/patched_base.rs new file mode 100644 index 0000000..c08800c --- /dev/null +++ b/src/encoding/integer/rle_v2/patched_base.rs @@ -0,0 +1,414 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::io::Read; + +use bytes::{BufMut, BytesMut}; +use snafu::OptionExt; + +use super::{EncodingType, NInt}; +use crate::{ + encoding::{ + integer::{ + util::{ + encode_bit_width, extract_run_length_from_header, get_closest_fixed_bits, + read_ints, rle_v2_decode_bit_width, signed_msb_encode, write_packed_ints, + }, + EncodingSign, VarintSerde, + }, + util::read_u8, + }, + error::{OutOfSpecSnafu, Result}, +}; + +pub fn read_patched_base( + reader: &mut R, + out_ints: &mut Vec, + header: u8, +) -> Result<()> { + let encoded_bit_width = (header >> 1) & 0x1F; + let value_bit_width = rle_v2_decode_bit_width(encoded_bit_width); + // Bit width derived from u8 above, so impossible to overflow u32 + let value_bit_width_u32 = u32::try_from(value_bit_width).unwrap(); + + let second_byte = read_u8(reader)?; + let length = extract_run_length_from_header(header, second_byte); + + let third_byte = read_u8(reader)?; + let fourth_byte = read_u8(reader)?; + + // Base width is one off + let base_byte_width = ((third_byte >> 5) & 0x07) as usize + 1; + + let patch_bit_width = rle_v2_decode_bit_width(third_byte & 0x1f); + + // Patch gap width is one off + let patch_gap_bit_width = ((fourth_byte >> 5) & 0x07) as usize + 1; + + let patch_total_bit_width = patch_bit_width + patch_gap_bit_width; + if patch_total_bit_width > 64 { + return OutOfSpecSnafu { + msg: "combined patch width and patch gap width cannot be greater than 64 bits", + } + .fail(); + } + + let patch_list_length = (fourth_byte & 0x1f) as usize; + + let base = N::read_big_endian(reader, base_byte_width)?; + let base = S::decode_signed_msb(base, base_byte_width); + + // Get data values + // TODO: this should read into Vec + // as base reduced values can exceed N::max() + // (e.g. 
if base is N::min() and this is signed type) + read_ints(out_ints, length, value_bit_width, reader)?; + + // Get patches that will be applied to base values. + // At most they might be u64 in width (because of check above). + let ceil_patch_total_bit_width = get_closest_fixed_bits(patch_total_bit_width); + let mut patches: Vec = Vec::with_capacity(patch_list_length); + read_ints( + &mut patches, + patch_list_length, + ceil_patch_total_bit_width, + reader, + )?; + + // TODO: document and explain below logic + let mut patch_index = 0; + let patch_mask = (1 << patch_bit_width) - 1; + let mut current_gap = patches[patch_index] >> patch_bit_width; + let mut current_patch = patches[patch_index] & patch_mask; + let mut actual_gap = 0; + + while current_gap == 255 && current_patch == 0 { + actual_gap += 255; + patch_index += 1; + current_gap = patches[patch_index] >> patch_bit_width; + current_patch = patches[patch_index] & patch_mask; + } + actual_gap += current_gap; + + for (idx, value) in out_ints.iter_mut().enumerate() { + if idx == actual_gap as usize { + let patch_bits = + current_patch + .checked_shl(value_bit_width_u32) + .context(OutOfSpecSnafu { + msg: "Overflow while shifting patch bits by value_bit_width", + })?; + // Safe conversion without loss as we check the bit width prior + let patch_bits = N::from_i64(patch_bits); + let patched_value = *value | patch_bits; + + *value = patched_value.checked_add(&base).context(OutOfSpecSnafu { + msg: "over/underflow when decoding patched base integer", + })?; + + patch_index += 1; + + if patch_index < patches.len() { + current_gap = patches[patch_index] >> patch_bit_width; + current_patch = patches[patch_index] & patch_mask; + actual_gap = 0; + + while current_gap == 255 && current_patch == 0 { + actual_gap += 255; + patch_index += 1; + current_gap = patches[patch_index] >> patch_bit_width; + current_patch = patches[patch_index] & patch_mask; + } + + actual_gap += current_gap; + actual_gap += idx as i64; + } + } else { + *value = value.checked_add(&base).context(OutOfSpecSnafu { + msg: "over/underflow when decoding patched base integer", + })?; + } + } + + Ok(()) +} + +fn derive_patches( + base_reduced_literals: &mut [i64], + patch_bits_width: usize, + max_base_value_bit_width: usize, +) -> (Vec, usize) { + // Values with bits exceeding this mask will be patched. + let max_base_value_mask = (1 << max_base_value_bit_width) - 1; + // Used to encode gaps greater than 255 (no patch bits, just used for gap). + let jump_patch = 255 << patch_bits_width; + + // At most 5% of values that must be patched. + // (Since max buffer length is 512, at most this can be 26) + let mut patches: Vec = Vec::with_capacity(26); + let mut last_patch_index = 0; + // Needed to determine bit width of patch gaps to encode in header. + let mut max_gap = 0; + for (idx, lit) in base_reduced_literals + .iter_mut() + .enumerate() + // Find all values which need to be patched (the 5% of values larger than the others) + .filter(|(_, &mut lit)| lit > max_base_value_mask) + { + // Convert to unsigned to ensure leftmost bits are 0 + let patch_bits = (*lit as u64) >> max_base_value_bit_width; + + // Gaps can at most be 255 (since gap bit width cannot exceed 8; in spec it states + // the header has only 3 bits to encode the size of the patch gap, so 8 is the largest + // value). + // + // Therefore if gap is found greater than 255 then we insert an empty patch with gap of 255 + // (and the empty patch will have no effect when reading as patching using empty bits will + // be a no-op). 
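+        //
+        // For example, a gap of 300 is emitted as one jump patch (gap 255 with
+        // zero patch bits) followed by the real patch carrying the remaining
+        // gap of 45.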
+ // + // Extra special case if gap is 511, we unroll into inserting two empty patches (instead of + // relying on a loop). Max buffer size cannot exceed 512 so this is the largest possible gap. + let gap = idx - last_patch_index; + let gap = if gap == 511 { + max_gap = 255; + patches.push(jump_patch); + patches.push(jump_patch); + 1 + } else if gap > 255 { + max_gap = 255; + patches.push(jump_patch); + gap - 255 + } else { + max_gap = max_gap.max(gap); + gap + }; + let patch = patch_bits | (gap << patch_bits_width) as u64; + patches.push(patch as i64); + + last_patch_index = idx; + + // Stripping patch bits + *lit &= max_base_value_mask; + } + + // If only one element to be patched, and is the very first one. + // Patch gap width minimum is 1. + let patch_gap_width = if max_gap == 0 { + 1 + } else { + (max_gap as i16).bits_used() + }; + + (patches, patch_gap_width) +} + +pub fn write_patched_base( + writer: &mut BytesMut, + base_reduced_literals: &mut [i64], + base: i64, + brl_100p_bit_width: usize, + brl_95p_bit_width: usize, +) { + let patch_bits_width = brl_100p_bit_width - brl_95p_bit_width; + let patch_bits_width = get_closest_fixed_bits(patch_bits_width); + // According to spec, each patch (patch bits + gap) must be <= 64 bits. + // So we adjust accordingly here if we hit this edge case where patch_width + // is 64 bits (which would have no space for gap). + let (patch_bits_width, brl_95p_bit_width) = if patch_bits_width == 64 { + (56, 8) + } else { + (patch_bits_width, brl_95p_bit_width) + }; + + let (patches, patch_gap_width) = + derive_patches(base_reduced_literals, patch_bits_width, brl_95p_bit_width); + + let encoded_bit_width = encode_bit_width(brl_95p_bit_width) as u8; + + // [1, 512] to [0, 511] + let run_length = base_reduced_literals.len() as u16 - 1; + + // No need to mask as we guarantee max length is 512 + let encoded_length_high_bit = (run_length >> 8) as u8; + let encoded_length_low_bits = (run_length & 0xFF) as u8; + + // +1 to account for sign bit + let base_bit_width = get_closest_fixed_bits(base.abs().bits_used() + 1); + let base_byte_width = base_bit_width.div_ceil(8).max(1); + let msb_encoded_min = signed_msb_encode(base, base_byte_width); + // [1, 8] to [0, 7] + let encoded_base_width = base_byte_width - 1; + let encoded_patch_bits_width = encode_bit_width(patch_bits_width); + let encoded_patch_gap_width = patch_gap_width - 1; + + let header1 = + EncodingType::PatchedBase.to_header() | encoded_bit_width << 1 | encoded_length_high_bit; + let header2 = encoded_length_low_bits; + let header3 = (encoded_base_width as u8) << 5 | encoded_patch_bits_width as u8; + let header4 = (encoded_patch_gap_width as u8) << 5 | patches.len() as u8; + writer.put_slice(&[header1, header2, header3, header4]); + + // Write out base value as big endian bytes + let base_bytes = msb_encoded_min.to_be_bytes(); + // 8 since i64 + let base_bytes = &base_bytes.as_ref()[8 - base_byte_width..]; + writer.put_slice(base_bytes); + + // Writing base reduced literals followed by patch list + let bit_width = get_closest_fixed_bits(brl_95p_bit_width); + write_packed_ints(writer, bit_width, base_reduced_literals); + let bit_width = get_closest_fixed_bits(patch_gap_width + patch_bits_width); + write_packed_ints(writer, bit_width, &patches); +} + +#[cfg(test)] +mod tests { + use std::io::Cursor; + + use proptest::prelude::*; + + use crate::encoding::integer::{util::calculate_percentile_bits, SignedEncoding}; + + use super::*; + + #[derive(Debug)] + struct PatchesStrategy { + base: i64, + 
+        base_reduced_values: Vec<i64>,
+        patches: Vec<i64>,
+        patch_indices: Vec<usize>,
+        base_index: usize,
+    }
+
+    fn patches_strategy() -> impl Strategy<Value = PatchesStrategy> {
+        // TODO: clean this up a bit
+        prop::collection::vec(0..1_000_000_i64, 20..=512)
+            .prop_flat_map(|base_reduced_values| {
+                let base_strategy = -1_000_000_000..1_000_000_000_i64;
+                let max_patches_length = (base_reduced_values.len() as f32 * 0.05).ceil() as usize;
+                let base_reduced_values_strategy = Just(base_reduced_values);
+                let patches_strategy = prop::collection::vec(
+                    1_000_000_000_000_000..1_000_000_000_000_000_000_i64,
+                    1..=max_patches_length,
+                );
+                (
+                    base_strategy,
+                    base_reduced_values_strategy,
+                    patches_strategy,
+                )
+            })
+            .prop_flat_map(|(base, base_reduced_values, patches)| {
+                let base_strategy = Just(base);
+                // +1 for the base index, so we don't have to deduplicate separately
+                let patch_indices_strategy =
+                    prop::collection::hash_set(0..base_reduced_values.len(), patches.len() + 1);
+                let base_reduced_values_strategy = Just(base_reduced_values);
+                let patches_strategy = Just(patches);
+                (
+                    base_strategy,
+                    base_reduced_values_strategy,
+                    patches_strategy,
+                    patch_indices_strategy,
+                )
+            })
+            .prop_map(|(base, base_reduced_values, patches, patch_indices)| {
+                let mut patch_indices = patch_indices.into_iter().collect::<Vec<_>>();
+                let base_index = patch_indices.pop().unwrap();
+                PatchesStrategy {
+                    base,
+                    base_reduced_values,
+                    patches,
+                    patch_indices,
+                    base_index,
+                }
+            })
+    }
+
+    fn roundtrip_patched_base_helper(
+        base_reduced_literals: &[i64],
+        base: i64,
+        brl_95p_bit_width: usize,
+        brl_100p_bit_width: usize,
+    ) -> Result<Vec<i64>> {
+        let mut base_reduced_literals = base_reduced_literals.to_vec();
+
+        let mut buf = BytesMut::new();
+        let mut out = vec![];
+
+        write_patched_base(
+            &mut buf,
+            &mut base_reduced_literals,
+            base,
+            brl_100p_bit_width,
+            brl_95p_bit_width,
+        );
+        let header = buf[0];
+        read_patched_base::<i64, _, SignedEncoding>(&mut Cursor::new(&buf[1..]), &mut out, header)?;
+
+        Ok(out)
+    }
+
+    fn form_patched_base_values(
+        base_reduced_values: &[i64],
+        patches: &[i64],
+        patch_indices: &[usize],
+        base_index: usize,
+    ) -> Vec<i64> {
+        let mut base_reduced_values = base_reduced_values.to_vec();
+        for (&patch, &index) in patches.iter().zip(patch_indices) {
+            base_reduced_values[index] = patch;
+        }
+        // Need at least one zero to represent the base
+        base_reduced_values[base_index] = 0;
+        base_reduced_values
+    }
+
+    fn form_expected_values(base: i64, base_reduced_values: &[i64]) -> Vec<i64> {
+        base_reduced_values.iter().map(|&v| base + v).collect()
+    }
+
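+    // The property test below overlays sparse large outliers onto small
+    // base-reduced values, then discards generated cases where the 95th and
+    // 100th percentile bit widths coincide, since those would never exercise
+    // the patch list.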
+    proptest! {
+        #[test]
+        fn roundtrip_patched_base_i64(patches_strategy in patches_strategy()) {
+            let PatchesStrategy {
+                base,
+                base_reduced_values,
+                patches,
+                patch_indices,
+                base_index
+            } = patches_strategy;
+            let base_reduced_values = form_patched_base_values(
+                &base_reduced_values,
+                &patches,
+                &patch_indices,
+                base_index
+            );
+            let expected = form_expected_values(base, &base_reduced_values);
+            let brl_95p_bit_width = calculate_percentile_bits(&base_reduced_values, 0.95);
+            let brl_100p_bit_width = calculate_percentile_bits(&base_reduced_values, 1.0);
+            // Need enough outliers to require patching
+            prop_assume!(brl_95p_bit_width != brl_100p_bit_width);
+            let actual = roundtrip_patched_base_helper(
+                &base_reduced_values,
+                base,
+                brl_95p_bit_width,
+                brl_100p_bit_width
+            )?;
+            prop_assert_eq!(actual, expected);
+        }
+    }
+}
diff --git a/src/encoding/integer/rle_v2/short_repeat.rs b/src/encoding/integer/rle_v2/short_repeat.rs
new file mode 100644
index 0000000..304d2cb
--- /dev/null
+++ b/src/encoding/integer/rle_v2/short_repeat.rs
@@ -0,0 +1,132 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
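+
+// Worked example (mirroring the ORC spec, for illustration): an unsigned run
+// of 10000 repeated five times encodes to [0x0a, 0x27, 0x10]. The header 0x0a
+// is 0b00_001_010: sub-encoding 0, value width 1 + 1 = 2 bytes, repeat count
+// 2 + 3 = 5; the trailing bytes 0x27 0x10 are 10000 in big endian.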
+ +use std::io::Read; + +use bytes::{BufMut, BytesMut}; + +use crate::{ + encoding::integer::{rle_v2::EncodingType, EncodingSign}, + error::{OutOfSpecSnafu, Result}, +}; + +use super::{NInt, SHORT_REPEAT_MIN_LENGTH}; + +pub fn read_short_repeat_values( + reader: &mut R, + out_ints: &mut Vec, + header: u8, +) -> Result<()> { + // Header byte: + // + // eeww_wccc + // 7 0 LSB + // + // ee = Sub-encoding bits, always 00 + // www = Value width bits + // ccc = Repeat count bits + + let byte_width = (header >> 3) & 0x07; // Encoded as 0 to 7 + let byte_width = byte_width as usize + 1; // Decode to 1 to 8 bytes + + if N::BYTE_SIZE < byte_width { + return OutOfSpecSnafu { + msg: + "byte width of short repeat encoding exceeds byte size of integer being decoded to", + } + .fail(); + } + + let run_length = (header & 0x07) as usize + SHORT_REPEAT_MIN_LENGTH; + + // Value that is being repeated is encoded as value_byte_width bytes in big endian format + let val = N::read_big_endian(reader, byte_width)?; + let val = S::zigzag_decode(val); + + out_ints.extend(std::iter::repeat(val).take(run_length)); + + Ok(()) +} + +pub fn write_short_repeat(writer: &mut BytesMut, value: N, count: usize) { + debug_assert!((SHORT_REPEAT_MIN_LENGTH..=10).contains(&count)); + + let value = S::zigzag_encode(value); + + // Take max in case value = 0 + let byte_size = value.bits_used().div_ceil(8).max(1) as u8; + let encoded_byte_size = byte_size - 1; + let encoded_count = (count - SHORT_REPEAT_MIN_LENGTH) as u8; + + let header = EncodingType::ShortRepeat.to_header() | (encoded_byte_size << 3) | encoded_count; + let bytes = value.to_be_bytes(); + let bytes = &bytes.as_ref()[N::BYTE_SIZE - byte_size as usize..]; + + writer.put_u8(header); + writer.put_slice(bytes); +} + +#[cfg(test)] +mod tests { + use std::io::Cursor; + + use proptest::prelude::*; + + use crate::encoding::integer::{SignedEncoding, UnsignedEncoding}; + + use super::*; + + fn roundtrip_short_repeat_helper( + value: N, + count: usize, + ) -> Result> { + let mut buf = BytesMut::new(); + let mut out = vec![]; + + write_short_repeat::<_, S>(&mut buf, value, count); + let header = buf[0]; + read_short_repeat_values::<_, _, S>(&mut Cursor::new(&buf[1..]), &mut out, header)?; + + Ok(out) + } + + proptest! { + #[test] + fn roundtrip_short_repeat_i16(value: i16, count in 3_usize..=10) { + let out = roundtrip_short_repeat_helper::<_, SignedEncoding>(value, count)?; + prop_assert_eq!(out, vec![value; count]); + } + + #[test] + fn roundtrip_short_repeat_i32(value: i32, count in 3_usize..=10) { + let out = roundtrip_short_repeat_helper::<_, SignedEncoding>(value, count)?; + prop_assert_eq!(out, vec![value; count]); + } + + #[test] + fn roundtrip_short_repeat_i64(value: i64, count in 3_usize..=10) { + let out = roundtrip_short_repeat_helper::<_, SignedEncoding>(value, count)?; + prop_assert_eq!(out, vec![value; count]); + } + + #[test] + fn roundtrip_short_repeat_i64_unsigned(value in 0..=i64::MAX, count in 3_usize..=10) { + let out = roundtrip_short_repeat_helper::<_, UnsignedEncoding>(value, count)?; + prop_assert_eq!(out, vec![value; count]); + } + } +} diff --git a/src/encoding/integer/util.rs b/src/encoding/integer/util.rs new file mode 100644 index 0000000..f96e3ec --- /dev/null +++ b/src/encoding/integer/util.rs @@ -0,0 +1,928 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::io::Read; + +use bytes::{BufMut, BytesMut}; +use num::Signed; +use snafu::OptionExt; + +use crate::{ + encoding::util::read_u8, + error::{Result, VarintTooLargeSnafu}, +}; + +use super::{EncodingSign, NInt, VarintSerde}; + +/// Extracting run length from first two header bytes. +/// +/// Run length encoded as range [0, 511], so adjust to actual +/// value in range [1, 512]. +/// +/// Used for patched base, delta, and direct sub-encodings. +pub fn extract_run_length_from_header(first_byte: u8, second_byte: u8) -> usize { + let length = ((first_byte as u16 & 0x01) << 8) | (second_byte as u16); + (length + 1) as usize +} + +/// Read bitpacked integers into provided buffer. `bit_size` can be any value from 1 to 64, +/// inclusive. +pub fn read_ints( + buffer: &mut Vec, + expected_no_of_ints: usize, + bit_size: usize, + r: &mut impl Read, +) -> Result<()> { + debug_assert!( + (1..=64).contains(&bit_size), + "bit_size must be in range [1, 64]" + ); + match bit_size { + 1 => unrolled_unpack_1(buffer, expected_no_of_ints, r), + 2 => unrolled_unpack_2(buffer, expected_no_of_ints, r), + 4 => unrolled_unpack_4(buffer, expected_no_of_ints, r), + n if n % 8 == 0 => unrolled_unpack_byte_aligned(buffer, expected_no_of_ints, r, n / 8), + n => unrolled_unpack_unaligned(buffer, expected_no_of_ints, r, n), + } +} + +/// Decode numbers with bit width of 1 from read stream +fn unrolled_unpack_1( + buffer: &mut Vec, + expected_num_of_ints: usize, + reader: &mut impl Read, +) -> Result<()> { + for _ in 0..(expected_num_of_ints / 8) { + let byte = read_u8(reader)?; + let nums = [ + (byte >> 7) & 1, + (byte >> 6) & 1, + (byte >> 5) & 1, + (byte >> 4) & 1, + (byte >> 3) & 1, + (byte >> 2) & 1, + (byte >> 1) & 1, + byte & 1, + ]; + buffer.extend(nums.map(N::from_u8)); + } + + // Less than full byte at end, extract these trailing numbers + let remainder = expected_num_of_ints % 8; + if remainder > 0 { + let byte = read_u8(reader)?; + for i in 0..remainder { + let shift = 7 - i; + let n = N::from_u8((byte >> shift) & 1); + buffer.push(n); + } + } + + Ok(()) +} + +/// Decode numbers with bit width of 2 from read stream +fn unrolled_unpack_2( + buffer: &mut Vec, + expected_num_of_ints: usize, + reader: &mut impl Read, +) -> Result<()> { + for _ in 0..(expected_num_of_ints / 4) { + let byte = read_u8(reader)?; + let nums = [(byte >> 6) & 3, (byte >> 4) & 3, (byte >> 2) & 3, byte & 3]; + buffer.extend(nums.map(N::from_u8)); + } + + // Less than full byte at end, extract these trailing numbers + let remainder = expected_num_of_ints % 4; + if remainder > 0 { + let byte = read_u8(reader)?; + for i in 0..remainder { + let shift = 6 - (i * 2); + let n = N::from_u8((byte >> shift) & 3); + buffer.push(n); + } + } + + Ok(()) +} + +/// Decode numbers with bit width of 4 from read stream +fn unrolled_unpack_4( + buffer: &mut Vec, + expected_num_of_ints: usize, + reader: &mut impl Read, +) -> Result<()> { + for _ in 0..(expected_num_of_ints 
/ 2) { + let byte = read_u8(reader)?; + let nums = [(byte >> 4) & 15, byte & 15]; + buffer.extend(nums.map(N::from_u8)); + } + + // At worst have 1 trailing 4-bit number + let remainder = expected_num_of_ints % 2; + if remainder > 0 { + let byte = read_u8(reader)?; + let n = N::from_u8((byte >> 4) & 15); + buffer.push(n); + } + + Ok(()) +} + +/// When the bitpacked integers have a width that isn't byte aligned +fn unrolled_unpack_unaligned( + buffer: &mut Vec, + expected_num_of_ints: usize, + reader: &mut impl Read, + bit_size: usize, +) -> Result<()> { + debug_assert!( + bit_size <= (N::BYTE_SIZE * 8), + "bit_size cannot exceed size of N" + ); + + let mut bits_left = 0; + let mut current_bits = N::zero(); + for _ in 0..expected_num_of_ints { + let mut result = N::zero(); + let mut bits_left_to_read = bit_size; + + // No bounds check as we rely on caller doing this check + // (since we know bit_size in advance) + + // bits_left_to_read and bits_left can never exceed 8 + // So safe to convert either to N + while bits_left_to_read > bits_left { + // TODO: explain this logic a bit + result <<= bits_left; + let mask = ((1_u16 << bits_left) - 1) as u8; + let mask = N::from_u8(mask); + result |= current_bits & mask; + bits_left_to_read -= bits_left; + + let byte = read_u8(reader)?; + current_bits = N::from_u8(byte); + + bits_left = 8; + } + + if bits_left_to_read > 0 { + result <<= bits_left_to_read; + bits_left -= bits_left_to_read; + let bits = current_bits >> bits_left; + let mask = ((1_u16 << bits_left_to_read) - 1) as u8; + let mask = N::from_u8(mask); + result |= bits & mask; + } + + buffer.push(result); + } + + Ok(()) +} + +/// Decode bitpacked integers which are byte aligned +#[inline] +fn unrolled_unpack_byte_aligned( + buffer: &mut Vec, + expected_num_of_ints: usize, + r: &mut impl Read, + num_bytes: usize, +) -> Result<()> { + debug_assert!( + num_bytes <= N::BYTE_SIZE, + "num_bytes cannot exceed size of integer being decoded into" + ); + // TODO: can probably read direct into buffer? read_big_endian() decodes + // into an intermediary buffer. + for _ in 0..expected_num_of_ints { + let num = N::read_big_endian(r, num_bytes)?; + buffer.push(num); + } + Ok(()) +} + +/// Write bit packed integers, where we expect the `bit_width` to be aligned +/// by [`get_closest_aligned_bit_width`], and we write the bytes as big endian. +pub fn write_aligned_packed_ints(writer: &mut BytesMut, bit_width: usize, values: &[N]) { + debug_assert!( + bit_width == 1 || bit_width == 2 || bit_width == 4 || bit_width % 8 == 0, + "bit_width must be 1, 2, 4 or a multiple of 8" + ); + match bit_width { + 1 => unrolled_pack_1(writer, values), + 2 => unrolled_pack_2(writer, values), + 4 => unrolled_pack_4(writer, values), + n => unrolled_pack_bytes(writer, n / 8, values), + } +} + +/// Similar to [`write_aligned_packed_ints`] but the `bit_width` allows any value +/// in the range `[1, 64]`. 
+pub fn write_packed_ints(writer: &mut BytesMut, bit_width: usize, values: &[N]) { + debug_assert!( + (1..=64).contains(&bit_width), + "bit_width must be in the range [1, 64]" + ); + if bit_width == 1 || bit_width == 2 || bit_width == 4 || bit_width % 8 == 0 { + write_aligned_packed_ints(writer, bit_width, values); + } else { + write_unaligned_packed_ints(writer, bit_width, values) + } +} + +fn write_unaligned_packed_ints(writer: &mut BytesMut, bit_width: usize, values: &[N]) { + debug_assert!( + (1..=64).contains(&bit_width), + "bit_width must be in the range [1, 64]" + ); + let mut bits_left = 8; + let mut current_byte = 0; + for &value in values { + let mut bits_to_write = bit_width; + // This loop will write 8 bits at a time into current_byte, except for the + // first iteration after a previous value has been written. The previous + // value may have bits left over, still in current_byte, which is represented + // by 8 - bits_left (aka bits_left is the amount of space left in current_byte). + while bits_to_write > bits_left { + // Writing from most significant bits first. + let shift = bits_to_write - bits_left; + // Shift so bits to write are in least significant 8 bits. + // Masking out higher bits so conversion to u8 is safe. + let bits = value.unsigned_shr(shift as u32) & N::from_u8(0xFF); + current_byte |= bits.to_u8().unwrap(); + bits_to_write -= bits_left; + + writer.put_u8(current_byte); + current_byte = 0; + bits_left = 8; + } + + // If there are trailing bits then include these into current_byte. + bits_left -= bits_to_write; + let bits = (value << bits_left) & N::from_u8(0xFF); + current_byte |= bits.to_u8().unwrap(); + + if bits_left == 0 { + writer.put_u8(current_byte); + current_byte = 0; + bits_left = 8; + } + } + // Flush any remaining bits + if bits_left != 8 { + writer.put_u8(current_byte); + } +} + +fn unrolled_pack_1(writer: &mut BytesMut, values: &[N]) { + let mut iter = values.chunks_exact(8); + for chunk in &mut iter { + let n1 = chunk[0].to_u8().unwrap() & 0x01; + let n2 = chunk[1].to_u8().unwrap() & 0x01; + let n3 = chunk[2].to_u8().unwrap() & 0x01; + let n4 = chunk[3].to_u8().unwrap() & 0x01; + let n5 = chunk[4].to_u8().unwrap() & 0x01; + let n6 = chunk[5].to_u8().unwrap() & 0x01; + let n7 = chunk[6].to_u8().unwrap() & 0x01; + let n8 = chunk[7].to_u8().unwrap() & 0x01; + let byte = + (n1 << 7) | (n2 << 6) | (n3 << 5) | (n4 << 4) | (n5 << 3) | (n6 << 2) | (n7 << 1) | n8; + writer.put_u8(byte); + } + let remainder = iter.remainder(); + if !remainder.is_empty() { + let mut byte = 0; + for (i, n) in remainder.iter().enumerate() { + let n = n.to_u8().unwrap(); + byte |= (n & 0x03) << (7 - i); + } + writer.put_u8(byte); + } +} + +fn unrolled_pack_2(writer: &mut BytesMut, values: &[N]) { + let mut iter = values.chunks_exact(4); + for chunk in &mut iter { + let n1 = chunk[0].to_u8().unwrap() & 0x03; + let n2 = chunk[1].to_u8().unwrap() & 0x03; + let n3 = chunk[2].to_u8().unwrap() & 0x03; + let n4 = chunk[3].to_u8().unwrap() & 0x03; + let byte = (n1 << 6) | (n2 << 4) | (n3 << 2) | n4; + writer.put_u8(byte); + } + let remainder = iter.remainder(); + if !remainder.is_empty() { + let mut byte = 0; + for (i, n) in remainder.iter().enumerate() { + let n = n.to_u8().unwrap(); + byte |= (n & 0x03) << (6 - i * 2); + } + writer.put_u8(byte); + } +} + +fn unrolled_pack_4(writer: &mut BytesMut, values: &[N]) { + let mut iter = values.chunks_exact(2); + for chunk in &mut iter { + let n1 = chunk[0].to_u8().unwrap() & 0x0F; + let n2 = chunk[1].to_u8().unwrap() & 0x0F; + let byte = 
(n1 << 4) | n2; + writer.put_u8(byte); + } + let remainder = iter.remainder(); + if !remainder.is_empty() { + let byte = remainder[0].to_u8().unwrap() & 0x0F; + let byte = byte << 4; + writer.put_u8(byte); + } +} + +fn unrolled_pack_bytes(writer: &mut BytesMut, byte_size: usize, values: &[N]) { + for num in values { + let bytes = num.to_be_bytes(); + let bytes = &bytes.as_ref()[N::BYTE_SIZE - byte_size..]; + writer.put_slice(bytes); + } +} + +/// Decoding table for RLEv2 sub-encodings bit width. +/// +/// Used by Direct, Patched Base and Delta. By default this assumes non-delta +/// (0 maps to 1), so Delta handles this discrepancy at the caller side. +/// +/// Input must be a 5-bit integer (max value is 31). +pub fn rle_v2_decode_bit_width(encoded: u8) -> usize { + debug_assert!(encoded < 32, "encoded bit width cannot exceed 5 bits"); + match encoded { + 0..=23 => encoded as usize + 1, + 24 => 26, + 25 => 28, + 26 => 30, + 27 => 32, + 28 => 40, + 29 => 48, + 30 => 56, + 31 => 64, + _ => unreachable!(), + } +} + +/// Inverse of [`rle_v2_decode_bit_width`]. +/// +/// Assumes supported bit width is passed in. Will panic on invalid +/// inputs that aren't defined in the ORC bit width encoding table +/// (such as 50). +pub fn rle_v2_encode_bit_width(width: usize) -> u8 { + debug_assert!(width <= 64, "bit width cannot exceed 64"); + match width { + 64 => 31, + 56 => 30, + 48 => 29, + 40 => 28, + 32 => 27, + 30 => 26, + 28 => 25, + 26 => 24, + 1..=24 => width as u8 - 1, + _ => unreachable!(), + } +} + +pub fn get_closest_fixed_bits(n: usize) -> usize { + match n { + 0 => 1, + 1..=24 => n, + 25..=26 => 26, + 27..=28 => 28, + 29..=30 => 30, + 31..=32 => 32, + 33..=40 => 40, + 41..=48 => 48, + 49..=56 => 56, + 57..=64 => 64, + _ => unreachable!(), + } +} + +pub fn encode_bit_width(n: usize) -> usize { + let n = get_closest_fixed_bits(n); + match n { + 1..=24 => n - 1, + 25..=26 => 24, + 27..=28 => 25, + 29..=30 => 26, + 31..=32 => 27, + 33..=40 => 28, + 41..=48 => 29, + 49..=56 => 30, + 57..=64 => 31, + _ => unreachable!(), + } +} + +fn decode_bit_width(n: usize) -> usize { + match n { + 0..=23 => n + 1, + 24 => 26, + 25 => 28, + 26 => 30, + 27 => 32, + 28 => 40, + 29 => 48, + 30 => 56, + 31 => 64, + _ => unreachable!(), + } +} + +/// Converts width of 64 bits or less to an aligned width, either rounding +/// up to the nearest multiple of 8, or rounding up to 1, 2 or 4. +pub fn get_closest_aligned_bit_width(width: usize) -> usize { + debug_assert!(width <= 64, "bit width cannot exceed 64"); + match width { + 0..=1 => 1, + 2 => 2, + 3..=4 => 4, + 5..=8 => 8, + 9..=16 => 16, + 17..=24 => 24, + 25..=32 => 32, + 33..=40 => 40, + 41..=48 => 48, + 49..=54 => 56, + 55..=64 => 64, + _ => unreachable!(), + } +} + +/// Decode Base 128 Unsigned Varint +fn read_varint(reader: &mut R) -> Result { + // Varints are encoded as sequence of bytes. + // Where the high bit of a byte is set to 1 if the varint + // continues into the next byte. Eventually it should terminate + // with a byte with high bit of 0. + let mut num = N::zero(); + let mut offset = 0; + loop { + let byte = read_u8(reader)?; + let is_last_byte = byte & 0x80 == 0; + let without_continuation_bit = byte & 0x7F; + num |= N::from_u8(without_continuation_bit) + // Ensure we don't overflow + .checked_shl(offset) + .context(VarintTooLargeSnafu)?; + // Since high bit doesn't contribute to final number, + // we need to shift in multiples of 7 to account for this. 
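+        // e.g. decoding [0x81, 0x01]: the first byte contributes 1 at offset 0
+        // and signals continuation; the second byte contributes 1 << 7 = 128,
+        // giving 129 (matching the test vectors at the bottom of this file).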
+        offset += 7;
+        if is_last_byte {
+            break;
+        }
+    }
+    Ok(num)
+}
+
+/// Encode Base 128 Unsigned Varint
+fn write_varint<N: VarintSerde>(writer: &mut BytesMut, value: N) {
+    // Take max in case value = 0.
+    // Divide by 7 as high bit is always used as continuation flag.
+    let byte_size = value.bits_used().div_ceil(7).max(1);
+    // By default we'll have continuation bit set
+    // TODO: can probably do without Vec allocation?
+    let mut bytes = vec![0x80; byte_size];
+    // Then just clear for the last one
+    let i = bytes.len() - 1;
+    bytes[i] = 0;
+
+    // Encoding 7 bits at a time into bytes
+    let mask = N::from_u8(0x7F);
+    for (i, b) in bytes.iter_mut().enumerate() {
+        let shift = i * 7;
+        *b |= ((value >> shift) & mask).to_u8().unwrap();
+    }
+
+    writer.put_slice(&bytes);
+}
+
+pub fn read_varint_zigzagged<N: VarintSerde, R: Read, S: EncodingSign>(
+    reader: &mut R,
+) -> Result<N> {
+    let unsigned = read_varint::<N, _>(reader)?;
+    Ok(S::zigzag_decode(unsigned))
+}
+
+pub fn write_varint_zigzagged<N: VarintSerde, S: EncodingSign>(writer: &mut BytesMut, value: N) {
+    let value = S::zigzag_encode(value);
+    write_varint(writer, value)
+}
+
+/// Zigzag encoding stores the sign bit in the least significant bit.
+#[inline]
+pub fn signed_zigzag_decode<N: VarintSerde + Signed>(encoded: N) -> N {
+    let without_sign_bit = encoded.unsigned_shr(1);
+    let sign_bit = encoded & N::one();
+    // If positive, sign_bit is 0
+    //   Negating 0 and doing bitwise XOR will just return without_sign_bit
+    //   Since A ^ 0 = A
+    // If negative, sign_bit is 1
+    //   Negating turns to 11111...11
+    //   Then A ^ 1 = ~A (negating every bit in A)
+    without_sign_bit ^ -sign_bit
+}
+
+/// Opposite of [`signed_zigzag_decode`].
+#[inline]
+pub fn signed_zigzag_encode<N: VarintSerde + Signed>(value: N) -> N {
+    let l = N::BYTE_SIZE * 8 - 1;
+    (value << 1_usize) ^ (value >> l)
+}
+
+/// MSB indicates if value is negated (1 if negative, else positive). Note we
+/// take the MSB of the encoded number which might be smaller than N, hence
+/// we need the encoded number byte size to find this MSB.
+#[inline]
+pub fn signed_msb_decode<N: NInt + Signed>(encoded: N, encoded_byte_size: usize) -> N {
+    let msb_mask = N::one() << (encoded_byte_size * 8 - 1);
+    let is_positive = (encoded & msb_mask) == N::zero();
+    let clean_sign_bit_mask = !msb_mask;
+    let encoded = encoded & clean_sign_bit_mask;
+    if is_positive {
+        encoded
+    } else {
+        -encoded
+    }
+}
+
+/// Inverse of [`signed_msb_decode`].
+#[inline]
+// TODO: bound this to only allow i64 input? might mess up for i32::MIN?
+pub fn signed_msb_encode<N: NInt + Signed>(value: N, encoded_byte_size: usize) -> N {
+    let is_signed = value.is_negative();
+    // 0 if unsigned, 1 if signed
+    let sign_bit = N::from_u8(is_signed as u8);
+    let value = value.abs();
+    let encoded_msb = sign_bit << (encoded_byte_size * 8 - 1);
+    encoded_msb | value
+}
+
+/// Get the nth percentile, where input percentile must be in range (0.0, 1.0].
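+///
+/// Works by building a 32-bucket histogram of encoded bit widths, then
+/// scanning from the widest bucket down while up to `(1 - percentile) * len`
+/// values are allowed to exceed the returned width; e.g. with 100 values and
+/// `percentile = 0.95`, at most 5 values may be wider than the result.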
+pub fn calculate_percentile_bits(values: &[N], percentile: f32) -> usize { + debug_assert!( + percentile > 0.0 && percentile <= 1.0, + "percentile must be in range (0.0, 1.0]" + ); + + let mut histogram = [0; 32]; + for n in values { + // Map into range [0, 31] + let encoded_bit_width = encode_bit_width(n.bits_used()); + histogram[encoded_bit_width] += 1; + } + + // Then calculate the percentile here + let count = values.len() as f32; + let mut per_len = ((1.0 - percentile) * count) as usize; + for i in (0..32).rev() { + if let Some(a) = per_len.checked_sub(histogram[i]) { + per_len = a; + } else { + return decode_bit_width(i); + } + } + + // If percentile is in correct input range then we should always return above + unreachable!() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + encoding::integer::{SignedEncoding, UnsignedEncoding}, + error::Result, + }; + use proptest::prelude::*; + use std::io::Cursor; + + #[test] + fn test_zigzag_decode() { + assert_eq!(0, signed_zigzag_decode(0)); + assert_eq!(-1, signed_zigzag_decode(1)); + assert_eq!(1, signed_zigzag_decode(2)); + assert_eq!(-2, signed_zigzag_decode(3)); + assert_eq!(2, signed_zigzag_decode(4)); + assert_eq!(-3, signed_zigzag_decode(5)); + assert_eq!(3, signed_zigzag_decode(6)); + assert_eq!(-4, signed_zigzag_decode(7)); + assert_eq!(4, signed_zigzag_decode(8)); + assert_eq!(-5, signed_zigzag_decode(9)); + + assert_eq!(9_223_372_036_854_775_807, signed_zigzag_decode(-2_i64)); + assert_eq!(-9_223_372_036_854_775_808, signed_zigzag_decode(-1_i64)); + } + + #[test] + fn test_zigzag_encode() { + assert_eq!(0, signed_zigzag_encode(0)); + assert_eq!(1, signed_zigzag_encode(-1)); + assert_eq!(2, signed_zigzag_encode(1)); + assert_eq!(3, signed_zigzag_encode(-2)); + assert_eq!(4, signed_zigzag_encode(2)); + assert_eq!(5, signed_zigzag_encode(-3)); + assert_eq!(6, signed_zigzag_encode(3)); + assert_eq!(7, signed_zigzag_encode(-4)); + assert_eq!(8, signed_zigzag_encode(4)); + assert_eq!(9, signed_zigzag_encode(-5)); + + assert_eq!(-2_i64, signed_zigzag_encode(9_223_372_036_854_775_807)); + assert_eq!(-1_i64, signed_zigzag_encode(-9_223_372_036_854_775_808)); + } + + #[test] + fn roundtrip_zigzag_edge_cases() { + let value = 0_i16; + assert_eq!(signed_zigzag_decode(signed_zigzag_encode(value)), value); + let value = i16::MAX; + assert_eq!(signed_zigzag_decode(signed_zigzag_encode(value)), value); + + let value = 0_i32; + assert_eq!(signed_zigzag_decode(signed_zigzag_encode(value)), value); + let value = i32::MAX; + assert_eq!(signed_zigzag_decode(signed_zigzag_encode(value)), value); + let value = i32::MIN; + assert_eq!(signed_zigzag_decode(signed_zigzag_encode(value)), value); + + let value = 0_i64; + assert_eq!(signed_zigzag_decode(signed_zigzag_encode(value)), value); + let value = i64::MAX; + assert_eq!(signed_zigzag_decode(signed_zigzag_encode(value)), value); + let value = i64::MIN; + assert_eq!(signed_zigzag_decode(signed_zigzag_encode(value)), value); + } + + proptest! 
{ + #[test] + fn roundtrip_zigzag_i16(value: i16) { + let out = signed_zigzag_decode(signed_zigzag_encode(value)); + prop_assert_eq!(value, out); + } + + #[test] + fn roundtrip_zigzag_i32(value: i32) { + let out = signed_zigzag_decode(signed_zigzag_encode(value)); + prop_assert_eq!(value, out); + } + + #[test] + fn roundtrip_zigzag_i64(value: i64) { + let out = signed_zigzag_decode(signed_zigzag_encode(value)); + prop_assert_eq!(value, out); + } + } + + fn generate_msb_test_value( + seed_value: N, + byte_size: usize, + signed: bool, + ) -> N { + // We mask out to values that can fit within the specified byte_size. + let shift = (N::BYTE_SIZE - byte_size) * 8; + let mask = N::max_value().unsigned_shr(shift as u32); + // And remove the msb since we manually set a value to signed based on the signed parameter. + let mask = mask >> 1; + let value = seed_value & mask; + // This guarantees values that can fit within byte_size when they are msb encoded, both + // signed and unsigned. + if signed { + -value + } else { + value + } + } + + #[test] + fn roundtrip_msb_edge_cases() { + // Testing all cases of max values for byte_size + signed combinations + for byte_size in 1..=2 { + for signed in [true, false] { + let value = generate_msb_test_value(i16::MAX, byte_size, signed); + let out = signed_msb_decode(signed_msb_encode(value, byte_size), byte_size); + assert_eq!(value, out); + } + } + + for byte_size in 1..=4 { + for signed in [true, false] { + let value = generate_msb_test_value(i32::MAX, byte_size, signed); + let out = signed_msb_decode(signed_msb_encode(value, byte_size), byte_size); + assert_eq!(value, out); + } + } + + for byte_size in 1..=8 { + for signed in [true, false] { + let value = generate_msb_test_value(i64::MAX, byte_size, signed); + let out = signed_msb_decode(signed_msb_encode(value, byte_size), byte_size); + assert_eq!(value, out); + } + } + } + + proptest! { + #[test] + fn roundtrip_msb_i16(value: i16, byte_size in 1..=2_usize, signed: bool) { + let value = generate_msb_test_value(value, byte_size, signed); + let out = signed_msb_decode(signed_msb_encode(value, byte_size), byte_size); + prop_assert_eq!(value, out); + } + + #[test] + fn roundtrip_msb_i32(value: i32, byte_size in 1..=4_usize, signed: bool) { + let value = generate_msb_test_value(value, byte_size, signed); + let out = signed_msb_decode(signed_msb_encode(value, byte_size), byte_size); + prop_assert_eq!(value, out); + } + + #[test] + fn roundtrip_msb_i64(value: i64, byte_size in 1..=8_usize, signed: bool) { + let value = generate_msb_test_value(value, byte_size, signed); + let out = signed_msb_decode(signed_msb_encode(value, byte_size), byte_size); + prop_assert_eq!(value, out); + } + } + + #[test] + fn test_read_varint() -> Result<()> { + fn test_assert(serialized: &[u8], expected: i64) -> Result<()> { + let mut reader = Cursor::new(serialized); + assert_eq!( + expected, + read_varint_zigzagged::(&mut reader)? 
+            );
+            Ok(())
+        }
+
+        test_assert(&[0x00], 0)?;
+        test_assert(&[0x01], 1)?;
+        test_assert(&[0x7f], 127)?;
+        test_assert(&[0x80, 0x01], 128)?;
+        test_assert(&[0x81, 0x01], 129)?;
+        test_assert(&[0xff, 0x7f], 16_383)?;
+        test_assert(&[0x80, 0x80, 0x01], 16_384)?;
+        test_assert(&[0x81, 0x80, 0x01], 16_385)?;
+
+        // when too large
+        let err = read_varint_zigzagged::<i64, _, UnsignedEncoding>(&mut Cursor::new(&[
+            0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01,
+        ]));
+        assert!(err.is_err());
+        assert_eq!(
+            "Varint being decoded is too large",
+            err.unwrap_err().to_string()
+        );
+
+        // when unexpected end to stream
+        let err =
+            read_varint_zigzagged::<i64, _, UnsignedEncoding>(&mut Cursor::new(&[0x80, 0x80]));
+        assert!(err.is_err());
+        assert_eq!(
+            "Failed to read, source: failed to fill whole buffer",
+            err.unwrap_err().to_string()
+        );
+
+        Ok(())
+    }
+
+    fn roundtrip_varint<N: VarintSerde, S: EncodingSign>(value: N) -> N {
+        let mut buf = BytesMut::new();
+        write_varint_zigzagged::<_, S>(&mut buf, value);
+        read_varint_zigzagged::<_, _, S>(&mut Cursor::new(&buf)).unwrap()
+    }
+
+    proptest! {
+        #[test]
+        fn roundtrip_varint_i16(value: i16) {
+            let out = roundtrip_varint::<_, SignedEncoding>(value);
+            prop_assert_eq!(out, value);
+        }
+
+        #[test]
+        fn roundtrip_varint_i32(value: i32) {
+            let out = roundtrip_varint::<_, SignedEncoding>(value);
+            prop_assert_eq!(out, value);
+        }
+
+        #[test]
+        fn roundtrip_varint_i64(value: i64) {
+            let out = roundtrip_varint::<_, SignedEncoding>(value);
+            prop_assert_eq!(out, value);
+        }
+
+        #[test]
+        fn roundtrip_varint_i128(value: i128) {
+            let out = roundtrip_varint::<_, SignedEncoding>(value);
+            prop_assert_eq!(out, value);
+        }
+
+        #[test]
+        fn roundtrip_varint_u64(value in 0..=i64::MAX) {
+            let out = roundtrip_varint::<_, UnsignedEncoding>(value);
+            prop_assert_eq!(out, value);
+        }
+    }
+
+    #[test]
+    fn roundtrip_varint_edge_cases() {
+        let value = 0_i16;
+        assert_eq!(roundtrip_varint::<_, SignedEncoding>(value), value);
+        let value = i16::MIN;
+        assert_eq!(roundtrip_varint::<_, SignedEncoding>(value), value);
+        let value = i16::MAX;
+        assert_eq!(roundtrip_varint::<_, SignedEncoding>(value), value);
+
+        let value = 0_i32;
+        assert_eq!(roundtrip_varint::<_, SignedEncoding>(value), value);
+        let value = i32::MIN;
+        assert_eq!(roundtrip_varint::<_, SignedEncoding>(value), value);
+        let value = i32::MAX;
+        assert_eq!(roundtrip_varint::<_, SignedEncoding>(value), value);
+
+        let value = 0_i64;
+        assert_eq!(roundtrip_varint::<_, SignedEncoding>(value), value);
+        let value = i64::MIN;
+        assert_eq!(roundtrip_varint::<_, SignedEncoding>(value), value);
+        let value = i64::MAX;
+        assert_eq!(roundtrip_varint::<_, SignedEncoding>(value), value);
+
+        let value = 0_i128;
+        assert_eq!(roundtrip_varint::<_, SignedEncoding>(value), value);
+        let value = i128::MIN;
+        assert_eq!(roundtrip_varint::<_, SignedEncoding>(value), value);
+        let value = i128::MAX;
+        assert_eq!(roundtrip_varint::<_, SignedEncoding>(value), value);
+    }
+
+    /// Easier to generate values than bound them, instead of generating correctly bounded
+    /// values. In this case, bounds are that no value will exceed the `bit_width` in terms
+    /// of bit size.
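+    ///
+    /// For example, `bit_width = 4` masks each value with `0xF`, so every
+    /// masked value fits in 4 bits.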
+ fn mask_to_bit_width(values: &[N], bit_width: usize) -> Vec { + let shift = N::BYTE_SIZE * 8 - bit_width; + let mask = N::max_value().unsigned_shr(shift as u32); + values.iter().map(|&v| v & mask).collect() + } + + fn roundtrip_packed_ints_serde(values: &[N], bit_width: usize) -> Result> { + let mut buf = BytesMut::new(); + let mut out = vec![]; + write_packed_ints(&mut buf, bit_width, values); + read_ints(&mut out, values.len(), bit_width, &mut Cursor::new(buf))?; + Ok(out) + } + + proptest! { + #[test] + fn roundtrip_packed_ints_serde_i64( + values in prop::collection::vec(any::(), 1..=512), + bit_width in 1..=64_usize + ) { + let values = mask_to_bit_width(&values, bit_width); + let out = roundtrip_packed_ints_serde(&values, bit_width)?; + prop_assert_eq!(out, values); + } + + #[test] + fn roundtrip_packed_ints_serde_i32( + values in prop::collection::vec(any::(), 1..=512), + bit_width in 1..=32_usize + ) { + let values = mask_to_bit_width(&values, bit_width); + let out = roundtrip_packed_ints_serde(&values, bit_width)?; + prop_assert_eq!(out, values); + } + + #[test] + fn roundtrip_packed_ints_serde_i16( + values in prop::collection::vec(any::(), 1..=512), + bit_width in 1..=16_usize + ) { + let values = mask_to_bit_width(&values, bit_width); + let out = roundtrip_packed_ints_serde(&values, bit_width)?; + prop_assert_eq!(out, values); + } + } +} diff --git a/src/encoding/mod.rs b/src/encoding/mod.rs new file mode 100644 index 0000000..871ae0b --- /dev/null +++ b/src/encoding/mod.rs @@ -0,0 +1,157 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Encoding/decoding logic for writing/reading primitive values from ORC types. + +use arrow::buffer::NullBuffer; +use bytes::Bytes; + +use crate::{error::Result, memory::EstimateMemory}; + +pub mod boolean; +pub mod byte; +pub mod decimal; +pub mod float; +pub mod integer; +mod rle; +pub mod timestamp; +mod util; + +/// Encodes primitive values into an internal buffer, usually with a specialized run length +/// encoding for better compression. +pub trait PrimitiveValueEncoder: EstimateMemory +where + V: Copy, +{ + fn new() -> Self; + + fn write_one(&mut self, value: V); + + fn write_slice(&mut self, values: &[V]) { + for &value in values { + self.write_one(value); + } + } + + /// Take the encoded bytes, replacing it with an empty buffer. + // TODO: Figure out how to retain the allocation instead of handing + // it off each time. + fn take_inner(&mut self) -> Bytes; +} + +pub trait PrimitiveValueDecoder { + /// Decode out.len() values into out at a time, failing if it cannot fill + /// the buffer. + fn decode(&mut self, out: &mut [V]) -> Result<()>; + + /// Decode into `out` according to the `true` elements in `present`. + /// + /// `present` must be the same length as `out`. 
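+    ///
+    /// For example, with `out = [-1, -1, -1]` and `present = [true, false, true]`,
+    /// the two decoded values end up at indices 0 and 2, while index 1 keeps its
+    /// original contents.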
+ fn decode_spaced(&mut self, out: &mut [V], present: &NullBuffer) -> Result<()> { + debug_assert_eq!(out.len(), present.len()); + + // First get all the non-null values into a contiguous range. + let non_null_count = present.len() - present.null_count(); + if non_null_count == 0 { + // All nulls, don't bother decoding anything + return Ok(()); + } + // We read into the back because valid_indices() below is not reversible, + // so we just reverse our algorithm. + let range_start = out.len() - non_null_count; + self.decode(&mut out[range_start..])?; + if non_null_count == present.len() { + // No nulls, don't need to space out + return Ok(()); + } + + // From the head of the contiguous range (at the end of the buffer) we swap + // with the null elements to ensure it matches with the present buffer. + let head_indices = range_start..out.len(); + for (correct_index, head_index) in present.valid_indices().zip(head_indices) { + // head_index points to the value we need to move to correct_index + out.swap(correct_index, head_index); + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use proptest::prelude::*; + + use super::*; + + /// Emits numbers increasing from 0. + struct DummyDecoder; + + impl PrimitiveValueDecoder for DummyDecoder { + fn decode(&mut self, out: &mut [i32]) -> Result<()> { + let values = (0..out.len()).map(|x| x as i32).collect::>(); + out.copy_from_slice(&values); + Ok(()) + } + } + + fn gen_spaced_dummy_decoder_expected(present: &[bool]) -> Vec { + let mut value = 0; + let mut expected = vec![]; + for &is_present in present { + if is_present { + expected.push(value); + value += 1; + } else { + expected.push(-1); + } + } + expected + } + + proptest! { + #[test] + fn decode_spaced_proptest(present: Vec) { + let mut decoder = DummyDecoder; + let mut out = vec![-1; present.len()]; + decoder.decode_spaced(&mut out, &NullBuffer::from(present.clone())).unwrap(); + let expected = gen_spaced_dummy_decoder_expected(&present); + prop_assert_eq!(out, expected); + } + } + + #[test] + fn decode_spaced_edge_cases() { + let mut decoder = DummyDecoder; + let len = 10; + + // all present + let mut out = vec![-1; len]; + let present = vec![true; len]; + let present = NullBuffer::from(present); + decoder.decode_spaced(&mut out, &present).unwrap(); + let expected: Vec<_> = (0..len).map(|i| i as i32).collect(); + assert_eq!(out, expected); + + // all null + let mut out = vec![-1; len]; + let present = vec![false; len]; + let present = NullBuffer::from(present); + decoder.decode_spaced(&mut out, &present).unwrap(); + let expected = vec![-1; len]; + assert_eq!(out, expected); + } +} diff --git a/src/encoding/rle.rs b/src/encoding/rle.rs new file mode 100644 index 0000000..a330efa --- /dev/null +++ b/src/encoding/rle.rs @@ -0,0 +1,99 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::error::{OutOfSpecSnafu, Result}; + +use super::PrimitiveValueDecoder; + +mod sealed { + use std::io::Read; + + use crate::encoding::{ + byte::ByteRleDecoder, + integer::{rle_v1::RleV1Decoder, rle_v2::RleV2Decoder, EncodingSign, NInt}, + }; + + pub trait Rle {} + + impl Rle for ByteRleDecoder {} + impl Rle for RleV1Decoder {} + impl Rle for RleV2Decoder {} +} + +/// Generic decoding behaviour for run length encoded values, such as integers (v1 and v2) +/// and bytes. +/// +/// Assumes an internal buffer which acts like a (single headed) queue where values are first +/// decoded into, before being copied out into the output buffer (usually an Arrow array). +pub trait GenericRle { + /// Consume N elements from internal buffer to signify the values having been copied out. + fn advance(&mut self, n: usize); + + /// All values available in internal buffer, respecting the current advancement level. + fn available(&self) -> &[V]; + + /// This should clear the internal buffer and populate it with the next round of decoded + /// values. + // TODO: Have a version that copies directly into the output buffer (e.g. Arrow array). + // Currently we always decode to the internal buffer first, even if we can copy + // directly to the output and skip the middle man. Ideally the internal buffer + // should only be used for leftovers between calls to PrimitiveValueDecoder::decode. + fn decode_batch(&mut self) -> Result<()>; +} + +impl + sealed::Rle> PrimitiveValueDecoder for G { + fn decode(&mut self, out: &mut [V]) -> Result<()> { + let available = self.available(); + // If we have enough leftover to copy, can skip decoding more. + if available.len() >= out.len() { + out.copy_from_slice(&available[..out.len()]); + self.advance(out.len()); + return Ok(()); + } + + // Otherwise progressively decode and copy over chunks. + let len_to_copy = out.len(); + let mut copied = 0; + while copied < len_to_copy { + if self.available().is_empty() { + self.decode_batch()?; + } + + let copying = self.available().len(); + // At most, we fill to exact length of output buffer (don't overflow). + let copying = copying.min(len_to_copy - copied); + + let out = &mut out[copied..]; + out[..copying].copy_from_slice(&self.available()[..copying]); + + copied += copying; + self.advance(copying); + } + + // We always expect to be able to fill the output buffer; it is up to the + // caller to control that size. + if copied != out.len() { + // TODO: more descriptive error + OutOfSpecSnafu { + msg: "Array length less than expected", + } + .fail() + } else { + Ok(()) + } + } +} diff --git a/src/encoding/timestamp.rs b/src/encoding/timestamp.rs new file mode 100644 index 0000000..d011311 --- /dev/null +++ b/src/encoding/timestamp.rs @@ -0,0 +1,183 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::marker::PhantomData; + +use arrow::datatypes::{ArrowTimestampType, TimeUnit}; +use snafu::ensure; + +use crate::{ + encoding::PrimitiveValueDecoder, + error::{DecodeTimestampSnafu, Result}, +}; + +const NANOSECONDS_IN_SECOND: i64 = 1_000_000_000; + +pub struct TimestampDecoder { + base_from_epoch: i64, + data: Box + Send>, + secondary: Box + Send>, + _marker: PhantomData, +} + +impl TimestampDecoder { + pub fn new( + base_from_epoch: i64, + data: Box + Send>, + secondary: Box + Send>, + ) -> Self { + Self { + base_from_epoch, + data, + secondary, + _marker: PhantomData, + } + } +} + +impl PrimitiveValueDecoder for TimestampDecoder { + fn decode(&mut self, out: &mut [T::Native]) -> Result<()> { + // TODO: can probably optimize, reuse buffers? + let mut data = vec![0; out.len()]; + let mut secondary = vec![0; out.len()]; + self.data.decode(&mut data)?; + self.secondary.decode(&mut secondary)?; + for (index, (&seconds_since_orc_base, &nanoseconds)) in + data.iter().zip(secondary.iter()).enumerate() + { + out[index] = + decode_timestamp::(self.base_from_epoch, seconds_since_orc_base, nanoseconds)?; + } + Ok(()) + } +} + +/// Arrow TimestampNanosecond type cannot represent the full datetime range of +/// the ORC Timestamp type, so this iterator provides the ability to decode the +/// raw nanoseconds without restricting it to the Arrow TimestampNanosecond range. +pub struct TimestampNanosecondAsDecimalDecoder { + base_from_epoch: i64, + data: Box + Send>, + secondary: Box + Send>, +} + +impl TimestampNanosecondAsDecimalDecoder { + pub fn new( + base_from_epoch: i64, + data: Box + Send>, + secondary: Box + Send>, + ) -> Self { + Self { + base_from_epoch, + data, + secondary, + } + } +} + +impl PrimitiveValueDecoder for TimestampNanosecondAsDecimalDecoder { + fn decode(&mut self, out: &mut [i128]) -> Result<()> { + // TODO: can probably optimize, reuse buffers? + let mut data = vec![0; out.len()]; + let mut secondary = vec![0; out.len()]; + self.data.decode(&mut data)?; + self.secondary.decode(&mut secondary)?; + for (index, (&seconds_since_orc_base, &nanoseconds)) in + data.iter().zip(secondary.iter()).enumerate() + { + out[index] = + decode_timestamp_as_i128(self.base_from_epoch, seconds_since_orc_base, nanoseconds); + } + Ok(()) + } +} + +fn decode(base: i64, seconds_since_orc_base: i64, nanoseconds: i64) -> (i128, i64, u64) { + let data = seconds_since_orc_base; + // TODO: is this a safe cast? + let mut nanoseconds = nanoseconds as u64; + // Last 3 bits indicate how many trailing zeros were truncated + let zeros = nanoseconds & 0x7; + nanoseconds >>= 3; + // Multiply by powers of 10 to get back the trailing zeros + // TODO: would it be more efficient to unroll this? 
(if LLVM doesn't already do so) + if zeros != 0 { + nanoseconds *= 10_u64.pow(zeros as u32 + 1); + } + let seconds_since_epoch = data + base; + // Timestamps below the UNIX epoch with nanoseconds > 999_999 need to be + // adjusted to have 1 second subtracted due to ORC-763: + // https://issues.apache.org/jira/browse/ORC-763 + let seconds = if seconds_since_epoch < 0 && nanoseconds > 999_999 { + seconds_since_epoch - 1 + } else { + seconds_since_epoch + }; + // Convert into nanoseconds since epoch, which Arrow uses as native representation + // of timestamps + // The timestamp may overflow i64 as ORC encodes them as a pair of (seconds, nanoseconds) + // while we encode them as a single i64 of nanoseconds in Arrow. + let nanoseconds_since_epoch = + (seconds as i128 * NANOSECONDS_IN_SECOND as i128) + (nanoseconds as i128); + (nanoseconds_since_epoch, seconds, nanoseconds) +} + +fn decode_timestamp( + base: i64, + seconds_since_orc_base: i64, + nanoseconds: i64, +) -> Result { + let (nanoseconds_since_epoch, seconds, nanoseconds) = + decode(base, seconds_since_orc_base, nanoseconds); + + let nanoseconds_in_timeunit = match T::UNIT { + TimeUnit::Second => 1_000_000_000, + TimeUnit::Millisecond => 1_000_000, + TimeUnit::Microsecond => 1_000, + TimeUnit::Nanosecond => 1, + }; + + // Error if loss of precision + // TODO: make this configurable (e.g. can succeed but truncate) + ensure!( + nanoseconds_since_epoch % nanoseconds_in_timeunit == 0, + DecodeTimestampSnafu { + seconds, + nanoseconds, + to_time_unit: T::UNIT, + } + ); + + // Convert to i64 and error if overflow + let num_since_epoch = (nanoseconds_since_epoch / nanoseconds_in_timeunit) + .try_into() + .or_else(|_| { + DecodeTimestampSnafu { + seconds, + nanoseconds, + to_time_unit: T::UNIT, + } + .fail() + })?; + + Ok(num_since_epoch) +} + +fn decode_timestamp_as_i128(base: i64, seconds_since_orc_base: i64, nanoseconds: i64) -> i128 { + let (nanoseconds_since_epoch, _, _) = decode(base, seconds_since_orc_base, nanoseconds); + nanoseconds_since_epoch +} diff --git a/src/encoding/util.rs b/src/encoding/util.rs new file mode 100644 index 0000000..862d12f --- /dev/null +++ b/src/encoding/util.rs @@ -0,0 +1,38 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::io::Read; + +use snafu::ResultExt; + +use crate::error::{self, Result}; + +/// Read single byte. +#[inline] +pub fn read_u8(reader: &mut impl Read) -> Result { + let mut byte = [0]; + reader.read_exact(&mut byte).context(error::IoSnafu)?; + Ok(byte[0]) +} + +/// Like [`read_u8()`] but returns `Ok(None)` if reader has reached EOF. 
+/// Like [`read_u8()`] but returns `Ok(None)` if the reader has reached EOF.
+#[inline]
+pub fn try_read_u8(reader: &mut impl Read) -> Result<Option<u8>> {
+    let mut byte = [0];
+    let length = reader.read(&mut byte).context(error::IoSnafu)?;
+    Ok((length > 0).then_some(byte[0]))
+}
diff --git a/src/error.rs b/src/error.rs
new file mode 100644
index 0000000..02713e2
--- /dev/null
+++ b/src/error.rs
@@ -0,0 +1,176 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::io;
+
+use arrow::datatypes::DataType as ArrowDataType;
+use arrow::datatypes::TimeUnit;
+use arrow::error::ArrowError;
+use snafu::prelude::*;
+use snafu::Location;
+
+use crate::proto;
+use crate::schema::DataType;
+
+// TODO: consolidate error types? better to have a smaller set?
+#[derive(Debug, Snafu)]
+#[snafu(visibility(pub))]
+pub enum OrcError {
+    #[snafu(display("Failed to read, source: {}", source))]
+    IoError {
+        source: std::io::Error,
+        #[snafu(implicit)]
+        location: Location,
+    },
+
+    #[snafu(display("Empty file"))]
+    EmptyFile {
+        #[snafu(implicit)]
+        location: Location,
+    },
+
+    #[snafu(display("Out of spec, message: {}", msg))]
+    OutOfSpec {
+        msg: String,
+        #[snafu(implicit)]
+        location: Location,
+    },
+
+    #[snafu(display("Failed to decode float, source: {}", source))]
+    DecodeFloat {
+        #[snafu(implicit)]
+        location: Location,
+        source: std::io::Error,
+    },
+
+    #[snafu(display(
+        "Overflow while decoding timestamp (seconds={}, nanoseconds={}) to {:?}",
+        seconds,
+        nanoseconds,
+        to_time_unit,
+    ))]
+    DecodeTimestamp {
+        #[snafu(implicit)]
+        location: Location,
+        seconds: i64,
+        nanoseconds: u64,
+        to_time_unit: TimeUnit,
+    },
+
+    #[snafu(display("Failed to decode proto, source: {}", source))]
+    DecodeProto {
+        #[snafu(implicit)]
+        location: Location,
+        source: prost::DecodeError,
+    },
+
+    #[snafu(display("No types found"))]
+    NoTypes {
+        #[snafu(implicit)]
+        location: Location,
+    },
+
+    #[snafu(display("unsupported type variant: {}", msg))]
+    UnsupportedTypeVariant {
+        #[snafu(implicit)]
+        location: Location,
+        msg: &'static str,
+    },
+
+    #[snafu(display(
+        "Cannot decode ORC type {:?} into Arrow type {:?}",
+        orc_type,
+        arrow_type,
+    ))]
+    MismatchedSchema {
+        #[snafu(implicit)]
+        location: Location,
+        orc_type: DataType,
+        arrow_type: ArrowDataType,
+    },
+
+    #[snafu(display("Invalid encoding for column '{}': {:?}", name, encoding))]
+    InvalidColumnEncoding {
+        #[snafu(implicit)]
+        location: Location,
+        name: String,
+        encoding: proto::column_encoding::Kind,
+    },
+
+    #[snafu(display("Failed to convert to record batch: {}", source))]
+    ConvertRecordBatch {
+        #[snafu(implicit)]
+        location: Location,
+        source: ArrowError,
+    },
+
+    #[snafu(display("Varint being decoded is too large"))]
+    VarintTooLarge {
+        #[snafu(implicit)]
+        location: Location,
+    },
+
+    #[snafu(display("unexpected: {}", msg))]
+    Unexpected {
+        #[snafu(implicit)]
+        location: Location,
+        msg: String,
+    },
+
+    #[snafu(display("Failed to build zstd decoder: {}", source))]
+    BuildZstdDecoder {
+        #[snafu(implicit)]
+        location: Location,
+        source: io::Error,
+    },
+
+    #[snafu(display("Failed to build snappy decoder: {}", source))]
+    BuildSnappyDecoder {
+        #[snafu(implicit)]
+        location: Location,
+        source: snap::Error,
+    },
+
+    #[snafu(display("Failed to build lzo decoder: {}", source))]
+    BuildLzoDecoder {
+        #[snafu(implicit)]
+        location: Location,
+        source: lzokay_native::Error,
+    },
+
+    #[snafu(display("Failed to build lz4 decoder: {}", source))]
+    BuildLz4Decoder {
+        #[snafu(implicit)]
+        location: Location,
+        source: lz4_flex::block::DecompressError,
+    },
+
+    #[snafu(display("Arrow error: {}", source))]
+    Arrow {
+        source: arrow::error::ArrowError,
+        #[snafu(implicit)]
+        location: Location,
+    },
+}
+
+pub type Result<T> = std::result::Result<T, OrcError>;
+
+impl From<OrcError> for ArrowError {
+    fn from(value: OrcError) -> Self {
+        ArrowError::ExternalError(Box::new(value))
+    }
+}
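Because `OrcError` converts into `ArrowError` via the `From` impl above, crate-internal results can be surfaced through Arrow-facing interfaces with a plain `?`. A minimal sketch (the `into_arrow` function is illustrative, not part of this diff):

```rust
use arrow::error::ArrowError;

// Illustrative only: exercises the From<OrcError> for ArrowError impl above.
fn into_arrow(res: Result<u64>) -> std::result::Result<u64, ArrowError> {
    // `?` applies the conversion, wrapping the OrcError as ArrowError::ExternalError
    Ok(res?)
}
```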
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..f5274eb
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,70 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! A native Rust implementation of the [Apache ORC](https://orc.apache.org) file format,
+//! providing APIs to read data into [Apache Arrow](https://arrow.apache.org) in-memory arrays.
+//!
+//! # Example read usage
+//!
+//! ```no_run
+//! # use std::fs::File;
+//! # use orc_rust::arrow_reader::ArrowReaderBuilder;
+//! let file = File::open("/path/to/file.orc").unwrap();
+//! let reader = ArrowReaderBuilder::try_new(file).unwrap().build();
+//! let record_batches = reader.collect::<Result<Vec<_>, _>>().unwrap();
+//! ```
+//!
+//! # Example write usage
+//!
+//! ```no_run
+//! # use std::fs::File;
+//! # use arrow::array::RecordBatch;
+//! # use orc_rust::arrow_writer::ArrowWriterBuilder;
+//! # fn get_record_batch() -> RecordBatch {
+//! #     unimplemented!()
+//! # }
+//! let file = File::create("/path/to/file.orc").unwrap();
+//! let batch = get_record_batch();
+//! let mut writer = ArrowWriterBuilder::new(file, batch.schema())
+//!     .try_build()
+//!     .unwrap();
+//! writer.write(&batch).unwrap();
+//! writer.close().unwrap();
+//! ```
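The read path also composes with the projection API added in src/projection.rs below. A hedged sketch of selecting columns by name; `with_projection`, `file_metadata()`, and `root_data_type()` are assumed builder/metadata accessors that this diff does not show:

```rust
// Sketch only: the builder and metadata accessors used here are assumptions.
use std::fs::File;
use orc_rust::arrow_reader::ArrowReaderBuilder;
use orc_rust::projection::ProjectionMask;

fn read_selected_columns() {
    let file = File::open("/path/to/file.orc").unwrap();
    let builder = ArrowReaderBuilder::try_new(file).unwrap();
    // Select two root columns by name; nested children are pulled in automatically.
    let mask = ProjectionMask::named_roots(builder.file_metadata().root_data_type(), &["a", "b"]);
    let reader = builder.with_projection(mask).build();
    // ... iterate record batches as in the read example above
    let _ = reader;
}
```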
+
+mod array_decoder;
+pub mod arrow_reader;
+pub mod arrow_writer;
+#[cfg(feature = "async")]
+pub mod async_arrow_reader;
+mod column;
+pub mod compression;
+mod encoding;
+pub mod error;
+mod memory;
+pub mod projection;
+mod proto;
+pub mod reader;
+pub mod schema;
+pub mod statistics;
+pub mod stripe;
+mod writer;
+
+pub use arrow_reader::{ArrowReader, ArrowReaderBuilder};
+pub use arrow_writer::{ArrowWriter, ArrowWriterBuilder};
+#[cfg(feature = "async")]
+pub use async_arrow_reader::ArrowStreamReader;
diff --git a/src/memory.rs b/src/memory.rs
new file mode 100644
index 0000000..f5e5c92
--- /dev/null
+++ b/src/memory.rs
@@ -0,0 +1,23 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// Estimating memory usage is important when writing files, as we finish
+/// writing a stripe according to a set size threshold.
+pub trait EstimateMemory {
+    /// Approximate current memory usage in bytes.
+    fn estimate_memory_size(&self) -> usize;
+}
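To make the trait's intent concrete, a hedged sketch of an in-crate implementor; `StripeBuffer` and `should_flush` are hypothetical names, and since `memory` is a private module this would live inside the crate rather than in downstream code:

```rust
use crate::memory::EstimateMemory;

// Hypothetical column buffer whose memory footprint is its accumulated bytes.
struct StripeBuffer {
    values: Vec<u8>,
}

impl EstimateMemory for StripeBuffer {
    fn estimate_memory_size(&self) -> usize {
        self.values.len()
    }
}

// Writer-side sketch: finish the current stripe once buffered data
// crosses the configured size threshold.
fn should_flush(buffer: &StripeBuffer, stripe_size_threshold: usize) -> bool {
    buffer.estimate_memory_size() >= stripe_size_threshold
}
```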
diff --git a/src/projection.rs b/src/projection.rs
new file mode 100644
index 0000000..70815db
--- /dev/null
+++ b/src/projection.rs
@@ -0,0 +1,78 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::schema::RootDataType;
+
+// TODO: be able to nest project (project columns within struct type)
+
+/// Specifies which column indices to project from an ORC type.
+#[derive(Debug, Clone)]
+pub struct ProjectionMask {
+    /// Indices of columns in the ORC type; these can refer to nested types,
+    /// not only root level columns.
+    indices: Option<Vec<usize>>,
+}
+
+impl ProjectionMask {
+    /// Project all columns.
+    pub fn all() -> Self {
+        Self { indices: None }
+    }
+
+    /// Project only specific columns from the root type by column index.
+    pub fn roots(root_data_type: &RootDataType, indices: impl IntoIterator<Item = usize>) -> Self {
+        // TODO: return error if column index not found?
+        let input_indices = indices.into_iter().collect::<Vec<_>>();
+        // By default always project root
+        let mut indices = vec![0];
+        root_data_type
+            .children()
+            .iter()
+            .filter(|col| input_indices.contains(&col.data_type().column_index()))
+            .for_each(|col| indices.extend(col.data_type().all_indices()));
+        Self {
+            indices: Some(indices),
+        }
+    }
+
+    /// Project only specific columns from the root type by column name.
+    pub fn named_roots<T>(root_data_type: &RootDataType, names: &[T]) -> Self
+    where
+        T: AsRef<str>,
+    {
+        // TODO: return error if column name not found?
+        // By default always project root
+        let mut indices = vec![0];
+        let names = names.iter().map(AsRef::as_ref).collect::<Vec<_>>();
+        root_data_type
+            .children()
+            .iter()
+            .filter(|col| names.contains(&col.name()))
+            .for_each(|col| indices.extend(col.data_type().all_indices()));
+        Self {
+            indices: Some(indices),
+        }
+    }
+
+    /// Check whether an ORC column is projected or not, by index.
+    pub fn is_index_projected(&self, index: usize) -> bool {
+        match &self.indices {
+            Some(indices) => indices.contains(&index),
+            None => true,
+        }
+    }
+}
diff --git a/src/proto.rs b/src/proto.rs
new file mode 100644
index 0000000..ae71cdb
--- /dev/null
+++ b/src/proto.rs
@@ -0,0 +1,829 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// This file was automatically generated through the regen.sh script, and should not be edited.
+
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct IntegerStatistics {
+    #[prost(sint64, optional, tag = "1")]
+    pub minimum: ::core::option::Option<i64>,
+    #[prost(sint64, optional, tag = "2")]
+    pub maximum: ::core::option::Option<i64>,
+    #[prost(sint64, optional, tag = "3")]
+    pub sum: ::core::option::Option<i64>,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct DoubleStatistics {
+    #[prost(double, optional, tag = "1")]
+    pub minimum: ::core::option::Option<f64>,
+    #[prost(double, optional, tag = "2")]
+    pub maximum: ::core::option::Option<f64>,
+    #[prost(double, optional, tag = "3")]
+    pub sum: ::core::option::Option<f64>,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct StringStatistics {
+    #[prost(string, optional, tag = "1")]
+    pub minimum: ::core::option::Option<::prost::alloc::string::String>,
+    #[prost(string, optional, tag = "2")]
+    pub maximum: ::core::option::Option<::prost::alloc::string::String>,
+    /// sum will store the total length of all strings in a stripe
+    #[prost(sint64, optional, tag = "3")]
+    pub sum: ::core::option::Option<i64>,
+    /// If the minimum or maximum value was longer than 1024 bytes, store a lower or upper
+    /// bound instead of the minimum or maximum values above.
+ #[prost(string, optional, tag = "4")] + pub lower_bound: ::core::option::Option<::prost::alloc::string::String>, + #[prost(string, optional, tag = "5")] + pub upper_bound: ::core::option::Option<::prost::alloc::string::String>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct BucketStatistics { + #[prost(uint64, repeated, tag = "1")] + pub count: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct DecimalStatistics { + #[prost(string, optional, tag = "1")] + pub minimum: ::core::option::Option<::prost::alloc::string::String>, + #[prost(string, optional, tag = "2")] + pub maximum: ::core::option::Option<::prost::alloc::string::String>, + #[prost(string, optional, tag = "3")] + pub sum: ::core::option::Option<::prost::alloc::string::String>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct DateStatistics { + /// min,max values saved as days since epoch + #[prost(sint32, optional, tag = "1")] + pub minimum: ::core::option::Option, + #[prost(sint32, optional, tag = "2")] + pub maximum: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct TimestampStatistics { + /// min,max values saved as milliseconds since epoch + #[prost(sint64, optional, tag = "1")] + pub minimum: ::core::option::Option, + #[prost(sint64, optional, tag = "2")] + pub maximum: ::core::option::Option, + #[prost(sint64, optional, tag = "3")] + pub minimum_utc: ::core::option::Option, + #[prost(sint64, optional, tag = "4")] + pub maximum_utc: ::core::option::Option, + /// store the lower 6 TS digits for min/max to achieve nanosecond precision + #[prost(int32, optional, tag = "5")] + pub minimum_nanos: ::core::option::Option, + #[prost(int32, optional, tag = "6")] + pub maximum_nanos: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct BinaryStatistics { + /// sum will store the total binary blob length in a stripe + #[prost(sint64, optional, tag = "1")] + pub sum: ::core::option::Option, +} +/// Statistics for list and map +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CollectionStatistics { + #[prost(uint64, optional, tag = "1")] + pub min_children: ::core::option::Option, + #[prost(uint64, optional, tag = "2")] + pub max_children: ::core::option::Option, + #[prost(uint64, optional, tag = "3")] + pub total_children: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ColumnStatistics { + #[prost(uint64, optional, tag = "1")] + pub number_of_values: ::core::option::Option, + #[prost(message, optional, tag = "2")] + pub int_statistics: ::core::option::Option, + #[prost(message, optional, tag = "3")] + pub double_statistics: ::core::option::Option, + #[prost(message, optional, tag = "4")] + pub string_statistics: ::core::option::Option, + #[prost(message, optional, tag = "5")] + pub bucket_statistics: ::core::option::Option, + #[prost(message, optional, tag = "6")] + pub decimal_statistics: ::core::option::Option, + #[prost(message, optional, tag = "7")] + pub date_statistics: ::core::option::Option, + #[prost(message, optional, tag = "8")] + pub binary_statistics: ::core::option::Option, + #[prost(message, optional, tag = 
"9")] + pub timestamp_statistics: ::core::option::Option, + #[prost(bool, optional, tag = "10")] + pub has_null: ::core::option::Option, + #[prost(uint64, optional, tag = "11")] + pub bytes_on_disk: ::core::option::Option, + #[prost(message, optional, tag = "12")] + pub collection_statistics: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct RowIndexEntry { + #[prost(uint64, repeated, tag = "1")] + pub positions: ::prost::alloc::vec::Vec, + #[prost(message, optional, tag = "2")] + pub statistics: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct RowIndex { + #[prost(message, repeated, tag = "1")] + pub entry: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct BloomFilter { + #[prost(uint32, optional, tag = "1")] + pub num_hash_functions: ::core::option::Option, + #[prost(fixed64, repeated, packed = "false", tag = "2")] + pub bitset: ::prost::alloc::vec::Vec, + #[prost(bytes = "vec", optional, tag = "3")] + pub utf8bitset: ::core::option::Option<::prost::alloc::vec::Vec>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct BloomFilterIndex { + #[prost(message, repeated, tag = "1")] + pub bloom_filter: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Stream { + #[prost(enumeration = "stream::Kind", optional, tag = "1")] + pub kind: ::core::option::Option, + #[prost(uint32, optional, tag = "2")] + pub column: ::core::option::Option, + #[prost(uint64, optional, tag = "3")] + pub length: ::core::option::Option, +} +/// Nested message and enum types in `Stream`. +pub mod stream { + /// if you add new index stream kinds, you need to make sure to update + /// StreamName to ensure it is added to the stripe in the right area + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] + #[repr(i32)] + pub enum Kind { + Present = 0, + Data = 1, + Length = 2, + DictionaryData = 3, + DictionaryCount = 4, + Secondary = 5, + RowIndex = 6, + BloomFilter = 7, + BloomFilterUtf8 = 8, + /// Virtual stream kinds to allocate space for encrypted index and data. + EncryptedIndex = 9, + EncryptedData = 10, + /// stripe statistics streams + StripeStatistics = 100, + /// A virtual stream kind that is used for setting the encryption IV. + FileStatistics = 101, + } + impl Kind { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Kind::Present => "PRESENT", + Kind::Data => "DATA", + Kind::Length => "LENGTH", + Kind::DictionaryData => "DICTIONARY_DATA", + Kind::DictionaryCount => "DICTIONARY_COUNT", + Kind::Secondary => "SECONDARY", + Kind::RowIndex => "ROW_INDEX", + Kind::BloomFilter => "BLOOM_FILTER", + Kind::BloomFilterUtf8 => "BLOOM_FILTER_UTF8", + Kind::EncryptedIndex => "ENCRYPTED_INDEX", + Kind::EncryptedData => "ENCRYPTED_DATA", + Kind::StripeStatistics => "STRIPE_STATISTICS", + Kind::FileStatistics => "FILE_STATISTICS", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. 
+ pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "PRESENT" => Some(Self::Present), + "DATA" => Some(Self::Data), + "LENGTH" => Some(Self::Length), + "DICTIONARY_DATA" => Some(Self::DictionaryData), + "DICTIONARY_COUNT" => Some(Self::DictionaryCount), + "SECONDARY" => Some(Self::Secondary), + "ROW_INDEX" => Some(Self::RowIndex), + "BLOOM_FILTER" => Some(Self::BloomFilter), + "BLOOM_FILTER_UTF8" => Some(Self::BloomFilterUtf8), + "ENCRYPTED_INDEX" => Some(Self::EncryptedIndex), + "ENCRYPTED_DATA" => Some(Self::EncryptedData), + "STRIPE_STATISTICS" => Some(Self::StripeStatistics), + "FILE_STATISTICS" => Some(Self::FileStatistics), + _ => None, + } + } + } +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ColumnEncoding { + #[prost(enumeration = "column_encoding::Kind", optional, tag = "1")] + pub kind: ::core::option::Option, + #[prost(uint32, optional, tag = "2")] + pub dictionary_size: ::core::option::Option, + /// The encoding of the bloom filters for this column: + /// 0 or missing = none or original + /// 1 = ORC-135 (utc for timestamps) + #[prost(uint32, optional, tag = "3")] + pub bloom_encoding: ::core::option::Option, +} +/// Nested message and enum types in `ColumnEncoding`. +pub mod column_encoding { + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] + #[repr(i32)] + pub enum Kind { + Direct = 0, + Dictionary = 1, + DirectV2 = 2, + DictionaryV2 = 3, + } + impl Kind { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Kind::Direct => "DIRECT", + Kind::Dictionary => "DICTIONARY", + Kind::DirectV2 => "DIRECT_V2", + Kind::DictionaryV2 => "DICTIONARY_V2", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. 
+ pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "DIRECT" => Some(Self::Direct), + "DICTIONARY" => Some(Self::Dictionary), + "DIRECT_V2" => Some(Self::DirectV2), + "DICTIONARY_V2" => Some(Self::DictionaryV2), + _ => None, + } + } + } +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct StripeEncryptionVariant { + #[prost(message, repeated, tag = "1")] + pub streams: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag = "2")] + pub encoding: ::prost::alloc::vec::Vec, +} +// each stripe looks like: +// index streams +// unencrypted +// variant 1..N +// data streams +// unencrypted +// variant 1..N +// footer + +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct StripeFooter { + #[prost(message, repeated, tag = "1")] + pub streams: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag = "2")] + pub columns: ::prost::alloc::vec::Vec, + #[prost(string, optional, tag = "3")] + pub writer_timezone: ::core::option::Option<::prost::alloc::string::String>, + /// one for each column encryption variant + #[prost(message, repeated, tag = "4")] + pub encryption: ::prost::alloc::vec::Vec, +} +// the file tail looks like: +// encrypted stripe statistics: ColumnarStripeStatistics (order by variant) +// stripe statistics: Metadata +// footer: Footer +// postscript: PostScript +// psLen: byte + +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct StringPair { + #[prost(string, optional, tag = "1")] + pub key: ::core::option::Option<::prost::alloc::string::String>, + #[prost(string, optional, tag = "2")] + pub value: ::core::option::Option<::prost::alloc::string::String>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Type { + #[prost(enumeration = "r#type::Kind", optional, tag = "1")] + pub kind: ::core::option::Option, + #[prost(uint32, repeated, tag = "2")] + pub subtypes: ::prost::alloc::vec::Vec, + #[prost(string, repeated, tag = "3")] + pub field_names: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + #[prost(uint32, optional, tag = "4")] + pub maximum_length: ::core::option::Option, + #[prost(uint32, optional, tag = "5")] + pub precision: ::core::option::Option, + #[prost(uint32, optional, tag = "6")] + pub scale: ::core::option::Option, + #[prost(message, repeated, tag = "7")] + pub attributes: ::prost::alloc::vec::Vec, +} +/// Nested message and enum types in `Type`. +pub mod r#type { + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] + #[repr(i32)] + pub enum Kind { + Boolean = 0, + Byte = 1, + Short = 2, + Int = 3, + Long = 4, + Float = 5, + Double = 6, + String = 7, + Binary = 8, + Timestamp = 9, + List = 10, + Map = 11, + Struct = 12, + Union = 13, + Decimal = 14, + Date = 15, + Varchar = 16, + Char = 17, + TimestampInstant = 18, + } + impl Kind { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. 
+ pub fn as_str_name(&self) -> &'static str { + match self { + Kind::Boolean => "BOOLEAN", + Kind::Byte => "BYTE", + Kind::Short => "SHORT", + Kind::Int => "INT", + Kind::Long => "LONG", + Kind::Float => "FLOAT", + Kind::Double => "DOUBLE", + Kind::String => "STRING", + Kind::Binary => "BINARY", + Kind::Timestamp => "TIMESTAMP", + Kind::List => "LIST", + Kind::Map => "MAP", + Kind::Struct => "STRUCT", + Kind::Union => "UNION", + Kind::Decimal => "DECIMAL", + Kind::Date => "DATE", + Kind::Varchar => "VARCHAR", + Kind::Char => "CHAR", + Kind::TimestampInstant => "TIMESTAMP_INSTANT", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "BOOLEAN" => Some(Self::Boolean), + "BYTE" => Some(Self::Byte), + "SHORT" => Some(Self::Short), + "INT" => Some(Self::Int), + "LONG" => Some(Self::Long), + "FLOAT" => Some(Self::Float), + "DOUBLE" => Some(Self::Double), + "STRING" => Some(Self::String), + "BINARY" => Some(Self::Binary), + "TIMESTAMP" => Some(Self::Timestamp), + "LIST" => Some(Self::List), + "MAP" => Some(Self::Map), + "STRUCT" => Some(Self::Struct), + "UNION" => Some(Self::Union), + "DECIMAL" => Some(Self::Decimal), + "DATE" => Some(Self::Date), + "VARCHAR" => Some(Self::Varchar), + "CHAR" => Some(Self::Char), + "TIMESTAMP_INSTANT" => Some(Self::TimestampInstant), + _ => None, + } + } + } +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct StripeInformation { + /// the global file offset of the start of the stripe + #[prost(uint64, optional, tag = "1")] + pub offset: ::core::option::Option, + /// the number of bytes of index + #[prost(uint64, optional, tag = "2")] + pub index_length: ::core::option::Option, + /// the number of bytes of data + #[prost(uint64, optional, tag = "3")] + pub data_length: ::core::option::Option, + /// the number of bytes in the stripe footer + #[prost(uint64, optional, tag = "4")] + pub footer_length: ::core::option::Option, + /// the number of rows in this stripe + #[prost(uint64, optional, tag = "5")] + pub number_of_rows: ::core::option::Option, + /// If this is present, the reader should use this value for the encryption + /// stripe id for setting the encryption IV. Otherwise, the reader should + /// use one larger than the previous stripe's encryptStripeId. + /// For unmerged ORC files, the first stripe will use 1 and the rest of the + /// stripes won't have it set. For merged files, the stripe information + /// will be copied from their original files and thus the first stripe of + /// each of the input files will reset it to 1. + /// Note that 1 was choosen, because protobuf v3 doesn't serialize + /// primitive types that are the default (eg. 0). + #[prost(uint64, optional, tag = "6")] + pub encrypt_stripe_id: ::core::option::Option, + /// For each encryption variant, the new encrypted local key to use + /// until we find a replacement. 
+ #[prost(bytes = "vec", repeated, tag = "7")] + pub encrypted_local_keys: ::prost::alloc::vec::Vec<::prost::alloc::vec::Vec>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct UserMetadataItem { + #[prost(string, optional, tag = "1")] + pub name: ::core::option::Option<::prost::alloc::string::String>, + #[prost(bytes = "vec", optional, tag = "2")] + pub value: ::core::option::Option<::prost::alloc::vec::Vec>, +} +/// StripeStatistics (1 per a stripe), which each contain the +/// ColumnStatistics for each column. +/// This message type is only used in ORC v0 and v1. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct StripeStatistics { + #[prost(message, repeated, tag = "1")] + pub col_stats: ::prost::alloc::vec::Vec, +} +/// This message type is only used in ORC v0 and v1. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Metadata { + #[prost(message, repeated, tag = "1")] + pub stripe_stats: ::prost::alloc::vec::Vec, +} +/// In ORC v2 (and for encrypted columns in v1), each column has +/// their column statistics written separately. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ColumnarStripeStatistics { + /// one value for each stripe in the file + #[prost(message, repeated, tag = "1")] + pub col_stats: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FileStatistics { + #[prost(message, repeated, tag = "1")] + pub column: ::prost::alloc::vec::Vec, +} +/// How was the data masked? This isn't necessary for reading the file, but +/// is documentation about how the file was written. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct DataMask { + /// the kind of masking, which may include third party masks + #[prost(string, optional, tag = "1")] + pub name: ::core::option::Option<::prost::alloc::string::String>, + /// parameters for the mask + #[prost(string, repeated, tag = "2")] + pub mask_parameters: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + /// the unencrypted column roots this mask was applied to + #[prost(uint32, repeated, tag = "3")] + pub columns: ::prost::alloc::vec::Vec, +} +/// Information about the encryption keys. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct EncryptionKey { + #[prost(string, optional, tag = "1")] + pub key_name: ::core::option::Option<::prost::alloc::string::String>, + #[prost(uint32, optional, tag = "2")] + pub key_version: ::core::option::Option, + #[prost(enumeration = "EncryptionAlgorithm", optional, tag = "3")] + pub algorithm: ::core::option::Option, +} +/// The description of an encryption variant. +/// Each variant is a single subtype that is encrypted with a single key. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct EncryptionVariant { + /// the column id of the root + #[prost(uint32, optional, tag = "1")] + pub root: ::core::option::Option, + /// The master key that was used to encrypt the local key, referenced as + /// an index into the Encryption.key list. 
+ #[prost(uint32, optional, tag = "2")] + pub key: ::core::option::Option, + /// the encrypted key for the file footer + #[prost(bytes = "vec", optional, tag = "3")] + pub encrypted_key: ::core::option::Option<::prost::alloc::vec::Vec>, + /// the stripe statistics for this variant + #[prost(message, repeated, tag = "4")] + pub stripe_statistics: ::prost::alloc::vec::Vec, + /// encrypted file statistics as a FileStatistics + #[prost(bytes = "vec", optional, tag = "5")] + pub file_statistics: ::core::option::Option<::prost::alloc::vec::Vec>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Encryption { + /// all of the masks used in this file + #[prost(message, repeated, tag = "1")] + pub mask: ::prost::alloc::vec::Vec, + /// all of the keys used in this file + #[prost(message, repeated, tag = "2")] + pub key: ::prost::alloc::vec::Vec, + /// The encrypted variants. + /// Readers should prefer the first variant that the user has access to + /// the corresponding key. If they don't have access to any of the keys, + /// they should get the unencrypted masked data. + #[prost(message, repeated, tag = "3")] + pub variants: ::prost::alloc::vec::Vec, + /// How are the local keys encrypted? + #[prost(enumeration = "KeyProviderKind", optional, tag = "4")] + pub key_provider: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Footer { + #[prost(uint64, optional, tag = "1")] + pub header_length: ::core::option::Option, + #[prost(uint64, optional, tag = "2")] + pub content_length: ::core::option::Option, + #[prost(message, repeated, tag = "3")] + pub stripes: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag = "4")] + pub types: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag = "5")] + pub metadata: ::prost::alloc::vec::Vec, + #[prost(uint64, optional, tag = "6")] + pub number_of_rows: ::core::option::Option, + #[prost(message, repeated, tag = "7")] + pub statistics: ::prost::alloc::vec::Vec, + #[prost(uint32, optional, tag = "8")] + pub row_index_stride: ::core::option::Option, + /// Each implementation that writes ORC files should register for a code + /// 0 = ORC Java + /// 1 = ORC C++ + /// 2 = Presto + /// 3 = Scritchley Go from + /// 4 = Trino + #[prost(uint32, optional, tag = "9")] + pub writer: ::core::option::Option, + /// information about the encryption in this file + #[prost(message, optional, tag = "10")] + pub encryption: ::core::option::Option, + #[prost(enumeration = "CalendarKind", optional, tag = "11")] + pub calendar: ::core::option::Option, + /// informative description about the version of the software that wrote + /// the file. It is assumed to be within a given writer, so for example + /// ORC 1.7.2 = "1.7.2". It may include suffixes, such as "-SNAPSHOT". 
+ #[prost(string, optional, tag = "12")] + pub software_version: ::core::option::Option<::prost::alloc::string::String>, +} +/// Serialized length must be less that 255 bytes +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PostScript { + #[prost(uint64, optional, tag = "1")] + pub footer_length: ::core::option::Option, + #[prost(enumeration = "CompressionKind", optional, tag = "2")] + pub compression: ::core::option::Option, + #[prost(uint64, optional, tag = "3")] + pub compression_block_size: ::core::option::Option, + /// the version of the file format + /// \[0, 11\] = Hive 0.11 + /// \[0, 12\] = Hive 0.12 + #[prost(uint32, repeated, tag = "4")] + pub version: ::prost::alloc::vec::Vec, + #[prost(uint64, optional, tag = "5")] + pub metadata_length: ::core::option::Option, + /// The version of the writer that wrote the file. This number is + /// updated when we make fixes or large changes to the writer so that + /// readers can detect whether a given bug is present in the data. + /// + /// Only the Java ORC writer may use values under 6 (or missing) so that + /// readers that predate ORC-202 treat the new writers correctly. Each + /// writer should assign their own sequence of versions starting from 6. + /// + /// Version of the ORC Java writer: + /// 0 = original + /// 1 = HIVE-8732 fixed (fixed stripe/file maximum statistics & + /// string statistics use utf8 for min/max) + /// 2 = HIVE-4243 fixed (use real column names from Hive tables) + /// 3 = HIVE-12055 added (vectorized writer implementation) + /// 4 = HIVE-13083 fixed (decimals write present stream correctly) + /// 5 = ORC-101 fixed (bloom filters use utf8 consistently) + /// 6 = ORC-135 fixed (timestamp statistics use utc) + /// 7 = ORC-517 fixed (decimal64 min/max incorrect) + /// 8 = ORC-203 added (trim very long string statistics) + /// 9 = ORC-14 added (column encryption) + /// + /// Version of the ORC C++ writer: + /// 6 = original + /// + /// Version of the Presto writer: + /// 6 = original + /// + /// Version of the Scritchley Go writer: + /// 6 = original + /// + /// Version of the Trino writer: + /// 6 = original + /// + #[prost(uint32, optional, tag = "6")] + pub writer_version: ::core::option::Option, + /// the number of bytes in the encrypted stripe statistics + #[prost(uint64, optional, tag = "7")] + pub stripe_statistics_length: ::core::option::Option, + /// Leave this last in the record + #[prost(string, optional, tag = "8000")] + pub magic: ::core::option::Option<::prost::alloc::string::String>, +} +/// The contents of the file tail that must be serialized. +/// This gets serialized as part of OrcSplit, also used by footer cache. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FileTail { + #[prost(message, optional, tag = "1")] + pub postscript: ::core::option::Option, + #[prost(message, optional, tag = "2")] + pub footer: ::core::option::Option