From 3328b37f7b54351914ccc78196e2808c62da82a2 Mon Sep 17 00:00:00 2001 From: 11136543 Date: Sun, 4 Aug 2024 21:33:38 +0800 Subject: [PATCH 1/4] read encrypted ORC file --- c++/include/orc/Common.hh | 20 ++ c++/include/orc/Reader.hh | 6 +- c++/src/CMakeLists.txt | 8 +- c++/src/Compression.cc | 8 +- c++/src/Compression.hh | 4 + c++/src/Options.hh | 8 + c++/src/Reader.cc | 325 ++++++++++++++++++++++----- c++/src/Reader.hh | 13 +- c++/src/StripeStream.cc | 110 +++++---- c++/src/StripeStream.hh | 218 ++++++++++++++---- c++/src/TypeImpl.cc | 1 - c++/src/io/InputStream.cc | 96 +++++++- c++/src/io/InputStream.hh | 32 ++- c++/src/security/InMemoryKeystore.cc | 247 ++++++++++++++++++++ c++/src/security/InMemoryKeystore.hh | 182 +++++++++++++++ c++/src/security/ReaderEncryption.cc | 278 +++++++++++++++++++++++ c++/src/security/ReaderEncryption.hh | 215 ++++++++++++++++++ 17 files changed, 1616 insertions(+), 155 deletions(-) create mode 100644 c++/src/security/InMemoryKeystore.cc create mode 100644 c++/src/security/InMemoryKeystore.hh create mode 100644 c++/src/security/ReaderEncryption.cc create mode 100644 c++/src/security/ReaderEncryption.hh diff --git a/c++/include/orc/Common.hh b/c++/include/orc/Common.hh index f3cfc64339..725b95112f 100644 --- a/c++/include/orc/Common.hh +++ b/c++/include/orc/Common.hh @@ -224,6 +224,26 @@ namespace orc { * Get the writer timezone. */ virtual const std::string& getWriterTimezone() const = 0; + /** + *Get the list of Localkeys for all columns in the Stripe, with each encrypted column + *corresponding to a Localkey. + * @return + */ + virtual std::shared_ptr>> + getEncryptedLocalKeys() const = 0; + /** + *Get the Localkey for a specific column in this Stripe. + * @param col + * @return + */ + virtual std::vector& getEncryptedLocalKeyByVariantId(int col) const = 0; + /** + * In general, only the first stripe in an ORC file will store the LocalKey.In this case, the + * stripeId and originalStripeId are equal. If an ORC file has multiple stripes storing the + * LocalKey, the values of stripeId and originalStripeId may not be equal. + * @return + */ + virtual long getOriginalStripeId() const = 0; }; // Return true if val1 < val2; otherwise return false diff --git a/c++/include/orc/Reader.hh b/c++/include/orc/Reader.hh index 4b254593ee..1994285695 100644 --- a/c++/include/orc/Reader.hh +++ b/c++/include/orc/Reader.hh @@ -39,7 +39,7 @@ namespace orc { // classes that hold data members so we can maintain binary compatibility struct ReaderOptionsPrivate; struct RowReaderOptionsPrivate; - + class KeyProvider; /** * Expose the reader metrics including the latency and * number of calls of the decompression/decoding/IO modules. @@ -138,6 +138,10 @@ namespace orc { * Get the reader metrics. */ ReaderMetrics* getReaderMetrics() const; + + ReaderOptions& setKeyProvider(std::shared_ptr keyProvider); + + std::shared_ptr getKeyProvider() const; }; /** diff --git a/c++/src/CMakeLists.txt b/c++/src/CMakeLists.txt index 5b7840301c..5f6cbf0711 100644 --- a/c++/src/CMakeLists.txt +++ b/c++/src/CMakeLists.txt @@ -144,7 +144,6 @@ add_custom_command(OUTPUT orc_proto.pb.h orc_proto.pb.cc --cpp_out="${CMAKE_CURRENT_BINARY_DIR}" ../../orc-format_ep-prefix/src/orc-format_ep/src/main/proto/orc/proto/orc_proto.proto ) - set(SOURCE_FILES "${CMAKE_CURRENT_BINARY_DIR}/Adaptor.hh" orc_proto.pb.h @@ -187,7 +186,10 @@ set(SOURCE_FILES Timezone.cc TypeImpl.cc Vector.cc - Writer.cc) + Writer.cc + security/InMemoryKeystore.cc + security/ReaderEncryption.cc +) if(BUILD_LIBHDFSPP) set(SOURCE_FILES ${SOURCE_FILES} OrcHdfsFile.cc) @@ -211,6 +213,8 @@ target_link_libraries (orc $ $ $ + ${OPENSSL_CRYPTO_LIBRARY} + ${OPENSSL_SSL_LIBRARY} ) target_include_directories (orc diff --git a/c++/src/Compression.cc b/c++/src/Compression.cc index 535018dcb6..e02afc81d3 100644 --- a/c++/src/Compression.cc +++ b/c++/src/Compression.cc @@ -1272,7 +1272,13 @@ namespace orc { throw NotImplementedYet("compression codec"); } } - + std::unique_ptr createDecompressorAndDecryption( + CompressionKind kind, std::unique_ptr input, uint64_t blockSize, + MemoryPool& pool, ReaderMetrics* metrics,std::vector key, + std::vector iv,const EVP_CIPHER* cipher){ + auto dec = std::make_unique(std::move(input),key,iv,cipher,pool); + return createDecompressor(kind,std::move(dec),blockSize,pool,metrics); + } std::unique_ptr createDecompressor( CompressionKind kind, std::unique_ptr input, uint64_t blockSize, MemoryPool& pool, ReaderMetrics* metrics) { diff --git a/c++/src/Compression.hh b/c++/src/Compression.hh index 24170c56b4..361beece63 100644 --- a/c++/src/Compression.hh +++ b/c++/src/Compression.hh @@ -36,6 +36,10 @@ namespace orc { CompressionKind kind, std::unique_ptr input, uint64_t bufferSize, MemoryPool& pool, ReaderMetrics* metrics); + std::unique_ptr createDecompressorAndDecryption(CompressionKind kind, std::unique_ptr input, uint64_t blockSize, + MemoryPool& pool, ReaderMetrics* metrics,std::vector key, + std::vector iv,const EVP_CIPHER* cipher); + /** * Create a compressor for the given compression kind. * @param kind the compression type to implement diff --git a/c++/src/Options.hh b/c++/src/Options.hh index daf9d52e1c..ca161b6c51 100644 --- a/c++/src/Options.hh +++ b/c++/src/Options.hh @@ -43,6 +43,7 @@ namespace orc { MemoryPool* memoryPool; std::string serializedTail; ReaderMetrics* metrics; + std::shared_ptr keyProvider; ReaderOptionsPrivate() { tailLocation = std::numeric_limits::max(); @@ -121,7 +122,14 @@ namespace orc { std::ostream* ReaderOptions::getErrorStream() const { return privateBits_->errorStream; } + ReaderOptions& ReaderOptions::setKeyProvider(std::shared_ptr keyProvider) { + privateBits_->keyProvider = keyProvider; + return *this; + } + std::shared_ptr ReaderOptions::getKeyProvider() const { + return privateBits_->keyProvider; + } /** * RowReaderOptions Implementation */ diff --git a/c++/src/Reader.cc b/c++/src/Reader.cc index 2966c2c2ee..3c185a4cbe 100644 --- a/c++/src/Reader.cc +++ b/c++/src/Reader.cc @@ -347,6 +347,10 @@ namespace orc { return contents_->blockSize; } + ReaderEncryption* RowReaderImpl::getReaderEncryption() const { + return contents_->encryption.get(); + } + const std::vector RowReaderImpl::getSelectedColumns() const { return selectedColumns_; } @@ -442,40 +446,62 @@ namespace orc { // obtain row indexes for selected columns uint64_t offset = currentStripeInfo_.offset(); - for (int i = 0; i < currentStripeFooter_.streams_size(); ++i) { - const proto::Stream& pbStream = currentStripeFooter_.streams(i); - uint64_t colId = pbStream.column(); - if (selectedColumns_[colId] && pbStream.has_kind() && - (pbStream.kind() == proto::Stream_Kind_ROW_INDEX || - pbStream.kind() == proto::Stream_Kind_BLOOM_FILTER_UTF8)) { - std::unique_ptr inStream = createDecompressor( - getCompression(), - std::unique_ptr(new SeekableFileInputStream( - contents_->stream.get(), offset, pbStream.length(), *contents_->pool)), - getCompressionSize(), *contents_->pool, contents_->readerMetrics); - - if (pbStream.kind() == proto::Stream_Kind_ROW_INDEX) { + std::vector> streams; + // Obtain a new list of index stream objects. + StripeStreamsImpl::findStreamsByArea(const_cast(currentStripeFooter_), + offset, Area::INDEX, contents_->encryption.get(), streams); + int num_streams = streams.size(); + for (int i = 0; i < num_streams; ++i) { + // const proto::Stream& pbStream = currentStripeFooter.streams(i); + const StreamInformation* stream = streams.at(i).get(); + proto::Stream_Kind streamKind = static_cast(stream->getKind()); + uint64_t columnId = stream->getColumnId(); + if (selectedColumns_[columnId] && (streamKind == proto::Stream_Kind_ROW_INDEX || + streamKind == proto::Stream_Kind_BLOOM_FILTER_UTF8)) { + auto inputStream = std::unique_ptr(new SeekableFileInputStream( + contents_->stream.get(), stream->getOffset(), stream->getLength(), *contents_->pool)); + ReaderEncryptionVariant* variant = contents_->encryption->getVariant(columnId); + std::unique_ptr pbStream = nullptr; + if (variant != nullptr) { + ReaderEncryptionKey* encryptionKey = variant->getKeyDescription(); + const int ivLength = encryptionKey->getAlgorithm()->getIvLength(); + std::vector iv(ivLength); + const EVP_CIPHER* cipher = encryptionKey->getAlgorithm()->createCipher(); + long originalStripeId = contents_->stripeList.at(currentStripe_)->getOriginalStripeId(); + orc::CryptoUtil::modifyIvForStream(columnId, proto::Stream_Kind_ROW_INDEX, + originalStripeId, iv.data(), ivLength); + //Get the FooterKey of this column + std::vector key = variant->getStripeKey(currentStripe_)->getEncoded(); + pbStream = createDecompressorAndDecryption( + getCompression(), std::move(inputStream), getCompressionSize(), *contents_->pool, + contents_->readerMetrics, key, iv, const_cast(cipher)); + } else { + pbStream = + createDecompressor(contents_->compression, std::move(inputStream), contents_->blockSize, + *(contents_->pool), contents_->readerMetrics); + } + if (streamKind == proto::Stream_Kind_ROW_INDEX) { proto::RowIndex rowIndex; - if (!rowIndex.ParseFromZeroCopyStream(inStream.get())) { + if (!rowIndex.ParseFromZeroCopyStream(pbStream.get())) { throw ParseError("Failed to parse the row index"); } - rowIndexes_[colId] = rowIndex; + rowIndexes_[columnId] = rowIndex; } else if (!skipBloomFilters_) { // Stream_Kind_BLOOM_FILTER_UTF8 proto::BloomFilterIndex pbBFIndex; - if (!pbBFIndex.ParseFromZeroCopyStream(inStream.get())) { + if (!pbBFIndex.ParseFromZeroCopyStream(pbStream.get())) { throw ParseError("Failed to parse bloom filter index"); } BloomFilterIndex bfIndex; for (int j = 0; j < pbBFIndex.bloom_filter_size(); j++) { bfIndex.entries.push_back(BloomFilterUTF8Utils::deserialize( - pbStream.kind(), currentStripeFooter_.columns(static_cast(pbStream.column())), + streamKind, currentStripeFooter_.columns(static_cast(columnId)), pbBFIndex.bloom_filter(j))); } // add bloom filters to result for one column - bloomFilterIndex_[pbStream.column()] = bfIndex; + bloomFilterIndex_[columnId] = bfIndex; } } - offset += pbStream.length(); + // offset += pbStream.length(); } } @@ -554,6 +580,16 @@ namespace orc { contents_->schema = convertType(footer_->types(0), *footer_); contents_->blockSize = getCompressionBlockSize(*contents_->postscript); contents_->compression = convertCompressionKind(*contents_->postscript); + //Encapsulate the stream information, it will be needed for decrypting the data later. + parseStripeList_(); + if (options_.getKeyProvider() != nullptr && footer_->has_encryption()) { + contents_->encryption = std::unique_ptr(new ReaderEncryption( + contents_, getEncryptStripeStatisticsOffset(), options_.getKeyProvider())); + //If there is an encrypted column, the FileStat of the Footer needs to be updated. + updateCryptedFileStatistics_(contents_); + } else { + contents_->encryption = std::unique_ptr(new ReaderEncryption()); + } } std::string ReaderImpl::getSerializedFileTail() const { @@ -607,7 +643,19 @@ namespace orc { stripeInfo.footer_length(), stripeInfo.number_of_rows(), contents_->stream.get(), *contents_->pool, contents_->compression, contents_->blockSize, contents_->readerMetrics)); } - + void ReaderImpl::parseStripeList_() { + long previousStripeId = 0; + std::shared_ptr>> previousKeys; + for (int i = 0; i < contents_->footer->stripes_size(); i++) { + proto::StripeInformation stripeProto = contents_->footer->stripes(i); + StripeInformation* stripe = new StripeInformationImpl( + &stripeProto,contents_->encryption.get(), previousStripeId, previousKeys, contents_->stream.get(), + *contents_->pool, contents_->compression, contents_->blockSize, contents_->readerMetrics); + contents_->stripeList.push_back(std::unique_ptr(stripe)); + previousStripeId = stripe->getOriginalStripeId(); + previousKeys = stripe->getEncryptedLocalKeys(); + } + } FileVersion ReaderImpl::getFormatVersion() const { if (contents_->postscript->version_size() != 2) { return FileVersion::v_0_11(); @@ -672,6 +720,11 @@ namespace orc { return fileLength_; } + uint64_t ReaderImpl::getEncryptStripeStatisticsOffset() const { + return getFileLength() - 1 - getFilePostscriptLength() - getFileFooterLength() - + getStripeStatisticsLength() - getPostscript()->stripe_statistics_length(); + } + uint64_t ReaderImpl::getRowIndexStride() const { return footer_->row_index_stride(); } @@ -701,41 +754,63 @@ namespace orc { const proto::StripeInformation& stripeInfo, uint64_t stripeIndex, const proto::StripeFooter& currentStripeFooter, std::vector>* indexStats) const { - int num_streams = currentStripeFooter.streams_size(); - uint64_t offset = stripeInfo.offset(); - uint64_t indexEnd = stripeInfo.offset() + stripeInfo.index_length(); - for (int i = 0; i < num_streams; i++) { - const proto::Stream& stream = currentStripeFooter.streams(i); - StreamKind streamKind = static_cast(stream.kind()); - uint64_t length = static_cast(stream.length()); - if (streamKind == StreamKind::StreamKind_ROW_INDEX) { - if (offset + length > indexEnd) { - std::stringstream msg; - msg << "Malformed RowIndex stream meta in stripe " << stripeIndex - << ": streamOffset=" << offset << ", streamLength=" << length - << ", stripeOffset=" << stripeInfo.offset() - << ", stripeIndexLength=" << stripeInfo.index_length(); - throw ParseError(msg.str()); - } - std::unique_ptr pbStream = - createDecompressor(contents_->compression, - std::unique_ptr(new SeekableFileInputStream( - contents_->stream.get(), offset, length, *contents_->pool)), - contents_->blockSize, *(contents_->pool), contents_->readerMetrics); - - proto::RowIndex rowIndex; - if (!rowIndex.ParseFromZeroCopyStream(pbStream.get())) { - throw ParseError("Failed to parse RowIndex from stripe footer"); - } - int num_entries = rowIndex.entry_size(); - size_t column = static_cast(stream.column()); - for (int j = 0; j < num_entries; j++) { - const proto::RowIndexEntry& entry = rowIndex.entry(j); - (*indexStats)[column].push_back(entry.statistics()); - } - } - offset += length; - } + uint64_t offset = stripeInfo.offset(); + uint64_t indexEnd = stripeInfo.offset() + stripeInfo.index_length(); + std::vector> streams; + // Search for the index stream, first encrypted and then unencrypted. + StripeStreamsImpl::findStreamsByArea(const_cast(currentStripeFooter), + offset, Area::INDEX, contents_->encryption.get(), streams); + int num_streams = streams.size(); + for (int i = 0; i < num_streams; i++) { + const StreamInformation* stream = streams.at(i).get(); + StreamKind streamKind = static_cast(stream->getKind()); + uint64_t length = static_cast(stream->getLength()); + if (streamKind == StreamKind::StreamKind_ROW_INDEX) { + if (offset + length > indexEnd) { + std::stringstream msg; + msg << "Malformed RowIndex stream meta in stripe " << stripeIndex + << ": streamOffset=" << offset << ", streamLength=" << length + << ", stripeOffset=" << stripeInfo.offset() + << ", stripeIndexLength=" << stripeInfo.index_length(); + throw ParseError(msg.str()); + } + auto inputStream = std::unique_ptr(new SeekableFileInputStream( + contents_->stream.get(), stream->getOffset(), stream->getLength(), *contents_->pool)); + int columnId = stream->getColumnId(); + // Check if this column is encrypted. + ReaderEncryptionVariant* variant = contents_->encryption->getVariant(columnId); + std::unique_ptr pbStream = nullptr; + if (variant != nullptr) { + ReaderEncryptionKey* encryptionKey = variant->getKeyDescription(); + const int ivLength = encryptionKey->getAlgorithm()->getIvLength(); + std::vector iv(ivLength); + const EVP_CIPHER* cipher = encryptionKey->getAlgorithm()->createCipher(); + long originalStripeId = contents_->stripeList.at(stripeIndex)->getOriginalStripeId(); + orc::CryptoUtil::modifyIvForStream(columnId, proto::Stream_Kind_ROW_INDEX, + originalStripeId, iv.data(), ivLength); + // Obtain the FooterKey for this column + std::vector key = variant->getStripeKey(stripeIndex)->getEncoded(); + pbStream = createDecompressorAndDecryption( + getCompression(), std::move(inputStream), getCompressionSize(), *contents_->pool, + contents_->readerMetrics, key, iv, const_cast(cipher)); + } else { + pbStream = + createDecompressor(contents_->compression, std::move(inputStream), contents_->blockSize, + *(contents_->pool), contents_->readerMetrics); + } + proto::RowIndex rowIndex; + if (!rowIndex.ParseFromZeroCopyStream(pbStream.get())) { + throw ParseError("Failed to parse RowIndex from stripe footer"); + } + int num_entries = rowIndex.entry_size(); + size_t column = static_cast(stream->getColumnId()); + for (int j = 0; j < num_entries; j++) { + const proto::RowIndexEntry& entry = rowIndex.entry(j); + (*indexStats)[column].push_back(entry.statistics()); + } + } + offset += length; + } } bool ReaderImpl::hasMetadataValue(const std::string& key) const { @@ -750,7 +825,80 @@ namespace orc { const Type& ReaderImpl::getType() const { return *(contents_->schema.get()); } - + proto::ColumnStatistics* getStripeStatisticsFromVariant_(std::shared_ptr contents, + ReaderEncryptionVariant* variant, uint64_t stripeIndex) { + orc::proto::ColumnarStripeStatistics* stripeStatCache = variant->getAllEncryptStripeStatistics(); + // If the strip stat for this variant has been previously read + if (stripeStatCache != nullptr) { + return stripeStatCache->mutable_col_stats(stripeIndex); + } + // create the objects + uint64_t offset = variant->getStripeStatsOffset(); + // int root = variant->getColumn()->getColumnId(); + const proto::EncryptionVariant& vproto = variant->getEncryptionVariantProto(); + int size = vproto.stripe_statistics_size(); + for (int i = 0; i < size; i++) { + proto::Stream stream = vproto.stripe_statistics(i); + uint64_t length = stream.length(); + // int col = stream.column(); + proto::Stream_Kind kind = stream.kind(); + if (kind == proto::Stream_Kind::Stream_Kind_STRIPE_STATISTICS) { + const EVP_CIPHER* cipher = variant->getKeyDescription()->getAlgorithm()->createCipher(); + const int ivLength = variant->getKeyDescription()->getAlgorithm()->getIvLength(); + // Retrieve the FooterKey of this column." + std::vector key = variant->getFileFooterKey()->getEncoded(); + // Initialize IV + std::vector iv(ivLength); + orc::CryptoUtil::modifyIvForStream(variant->getRoot()->getColumnId(), + proto::Stream_Kind_STRIPE_STATISTICS, + contents->stripeList.size() + 1, iv.data(), ivLength); + auto inputStream = std::make_unique( + contents->stream.get(), offset, length, *contents->pool, length); + // Put the data into the decryption stream. + std::unique_ptr decompressStream = createDecompressorAndDecryption( + contents->compression, std::move(inputStream), contents->blockSize, *contents->pool, + contents->readerMetrics, key, iv, const_cast(cipher)); + // Read the decrypted and decompressed raw data, and perform deserialization. + orc::proto::ColumnarStripeStatistics* columnarStripeStatistics = + new orc::proto::ColumnarStripeStatistics(); + bool success = columnarStripeStatistics->ParseFromZeroCopyStream(decompressStream.get()); + if (success) { + // Cache the result. + variant->setAllEncryptStripeStatistics(columnarStripeStatistics); + return columnarStripeStatistics->mutable_col_stats(stripeIndex); + } else { + throw ParseError("decryptFileStats error " + std::to_string(variant->getVariantId())); + } + } + offset += length; + } + return nullptr; + } + void updateCryptedStripeStatistics_(std::shared_ptr contents,uint64_t stripeIndex) { + ::orc::proto::StripeStatistics* statistics = + contents->metadata->mutable_stripe_stats(static_cast(stripeIndex)); + if (contents->encryption->getVariants().size() > 0) { + for (uint64_t c = contents->schema->getColumnId(); + c <= contents->schema->getMaximumColumnId(); ++c) { + ReaderEncryptionVariant* variant = contents->encryption->getVariant(c); + if (variant != nullptr) { + Type* variantType = variant->getRoot(); + // Read the element at the stripeIndex of the c-th column. + ::orc::proto::ColumnStatistics* colStats = + getStripeStatisticsFromVariant_(contents,variant,stripeIndex); + if (colStats != nullptr) { + for (uint64_t sub = c; sub <= variantType->getMaximumColumnId(); ++sub) { + //Update the statistic of the c-th column of the stripeIndex-th stripe. + ::orc::proto::ColumnStatistics* resultElem = statistics->mutable_col_stats(c); + resultElem->Clear(); + resultElem->CopyFrom(*colStats); + } + c = variantType->getMaximumColumnId(); + } + } + } + } + } std::unique_ptr ReaderImpl::getStripeStatistics(uint64_t stripeIndex) const { if (!isMetadataLoaded_) { readMetadata(); @@ -811,6 +959,10 @@ namespace orc { if (!contents_->metadata->ParseFromZeroCopyStream(pbStream.get())) { throw ParseError("Failed to parse the metadata"); } + int size = contents_->metadata->stripe_stats().size(); + for(int i= 0; istripeList.at(currentStripe_).get(); + StripeStreamsImpl stripeStreams(*this, currentStripe_,stripeInformation->getOriginalStripeId(), + currentStripeInfo_,currentStripeFooter_, currentStripeInfo_.offset(), *contents_->stream, writerTimezone, readerTimezone_); reader_ = buildReader(*contents_->schema, stripeStreams, useTightNumericVector_, throwOnSchemaEvolutionOverflow_, /*convertToReadType=*/true); @@ -1218,7 +1371,6 @@ namespace orc { getColumnIds(type->getSubtype(i), columnIds); } } - std::unique_ptr RowReaderImpl::createRowBatch(uint64_t capacity) const { // If the read type is specified, then check that the selected schema matches the read type // on the first call to createRowBatch. @@ -1485,5 +1637,58 @@ namespace orc { InputStream::~InputStream(){ // PASS }; - + orc::proto::FileStatistics* decryptFileStats(std::shared_ptr contents, + ReaderEncryptionVariant* variant){ + //find from cache + orc::proto::FileStatistics* fileStatCache = variant->getFileStatistics(); + if(fileStatCache != nullptr){ + return variant->getFileStatistics(); + } + const EVP_CIPHER* cipher = variant->getKeyDescription()->getAlgorithm()->createCipher(); + std::string stats = contents->footer->encryption().variants(variant->getVariantId()).file_statistics(); + // Place the raw data into a SeekableArrayInputStream. + auto inputStream = std::make_unique(stats.c_str(), stats.size(), + contents->blockSize); + const int ivLength = variant->getKeyDescription()->getAlgorithm()->getIvLength(); + std::vector iv(ivLength); + orc::CryptoUtil::modifyIvForStream(variant->getRoot()->getColumnId(), + proto::Stream_Kind_FILE_STATISTICS, + contents->footer->stripes_size() + 1, iv.data(), ivLength); + + std::vector key = variant->getFileFooterKey()->getEncoded(); + std::unique_ptr decompressStream = createDecompressorAndDecryption( + contents->compression, std::move(inputStream), contents->blockSize, *contents->pool, + contents->readerMetrics, key, iv, const_cast(cipher)); + + //Read, decrypt, decompress the raw data, and perform deserialization. + orc::proto::FileStatistics* fileStatistics = new orc::proto::FileStatistics(); + bool success = fileStatistics->ParseFromZeroCopyStream(decompressStream.get()); + if (success) { + variant->setFileStatistics(fileStatistics); + return fileStatistics; + } else { + throw ParseError("decryptFileStats error " + std::to_string(variant->getVariantId())); + } + } + void updateCryptedFileStatistics_(std::shared_ptr contents){ + for(int i = 0; i < contents->footer->statistics_size(); i++) { + ReaderEncryptionVariant* variant = contents->encryption->getVariant(i); + if (variant != nullptr) { + //The lifecycle of fileStatistics is now managed by the variant + orc::proto::FileStatistics* fileStatistics = decryptFileStats(contents,variant); + ::orc::proto::ColumnStatistics* old = contents->footer->mutable_statistics(static_cast(i)); + old->Clear(); + old->CopyFrom(fileStatistics->column(0)); + //return fileStatistics->mutable_column(0); + } + } + } + std::vector getAllFileStat(std::shared_ptr contents){ + std::vector result ; + for(int i = 0; i < contents->footer->statistics_size(); i++) { + orc::proto::ColumnStatistics* stat = contents->footer->mutable_statistics(i); + result.push_back(stat); + } + return result; + } } // namespace orc diff --git a/c++/src/Reader.hh b/c++/src/Reader.hh index 630d812c38..4fe5866dd5 100644 --- a/c++/src/Reader.hh +++ b/c++/src/Reader.hh @@ -29,6 +29,8 @@ #include "SchemaEvolution.hh" #include "TypeImpl.hh" #include "sargs/SargsApplier.hh" +#include "io/InputStream.hh" +#include "security/ReaderEncryption.hh" namespace orc { @@ -70,6 +72,10 @@ namespace orc { bool isDecimalAsLong; std::unique_ptr metadata; ReaderMetrics* readerMetrics; + std::vector> stripeList; + // encryption + std::unique_ptr encryption; + std::unique_ptr tail; }; proto::StripeFooter getStripeFooter(const proto::StripeInformation& info, @@ -245,6 +251,7 @@ namespace orc { const SchemaEvolution* getSchemaEvolution() const { return &schemaEvolution_; } + ReaderEncryption* getReaderEncryption() const; }; class ReaderImpl : public Reader { @@ -269,7 +276,7 @@ namespace orc { const proto::StripeInformation& stripeInfo, uint64_t stripeIndex, const proto::StripeFooter& currentStripeFooter, std::vector >* indexStats) const; - + void parseStripeList_(); // metadata mutable bool isMetadataLoaded_; @@ -329,6 +336,7 @@ namespace orc { uint64_t getFileFooterLength() const override; uint64_t getFilePostscriptLength() const override; uint64_t getFileLength() const override; + uint64_t getEncryptStripeStatisticsOffset() const; std::unique_ptr getStatistics() const override; @@ -375,6 +383,9 @@ namespace orc { std::map getBloomFilters( uint32_t stripeIndex, const std::set& included) const override; }; + orc::proto::FileStatistics* decryptFileStats(std::shared_ptr contents,ReaderEncryptionVariant* variant); + void updateCryptedFileStatistics_(std::shared_ptr contents); + std::vector getAllFileStat(std::shared_ptr contents); } // namespace orc #endif diff --git a/c++/src/StripeStream.cc b/c++/src/StripeStream.cc index f4345c0871..a7ab0f4e91 100644 --- a/c++/src/StripeStream.cc +++ b/c++/src/StripeStream.cc @@ -25,20 +25,39 @@ namespace orc { - StripeStreamsImpl::StripeStreamsImpl(const RowReaderImpl& reader, uint64_t index, - const proto::StripeInformation& stripeInfo, - const proto::StripeFooter& footer, uint64_t stripeStart, - InputStream& input, const Timezone& writerTimezone, - const Timezone& readerTimezone) - : reader_(reader), - stripeInfo_(stripeInfo), - footer_(footer), - stripeIndex_(index), - stripeStart_(stripeStart), - input_(input), - writerTimezone_(writerTimezone), - readerTimezone_(readerTimezone) { - // PASS + StripeStreamsImpl::StripeStreamsImpl(const RowReaderImpl& _reader, uint64_t _index,long originalStripeId, + const proto::StripeInformation& _stripeInfo, + const proto::StripeFooter& _footer, uint64_t _stripeStart, + InputStream& _input, const Timezone& _writerTimezone, + const Timezone& _readerTimezone) + : reader_(_reader), + stripeInfo_(_stripeInfo), + footer_(_footer), + stripeIndex_(_index), + originalStripeId(originalStripeId), + stripeStart_(_stripeStart), + input_(_input), + writerTimezone_(_writerTimezone), + readerTimezone_(_readerTimezone) { + // +-----------------+---------------+-----------------+---------------+ + // | | | | | + // | unencrypted | encrypted | unencrypted | encrypted | + // | index | index | data | data | + // | | | | | + // +-----------------+---------------+-----------------+---------------+ + // The above refers to the storage layout of indexes and data, hence we need to follow this order to seek the stream. + // Look for the index stream, first encrypted stream and then unencrypted stream. + long currentOffset = _stripeStart; + currentOffset = StripeStreamsImpl::findStreamsByArea(const_cast(footer_),currentOffset, Area::INDEX,reader_.getReaderEncryption(),streams); + //Look for the data stream, first the encrypted stream and then the unencrypted stream. + findStreamsByArea(const_cast(footer_),currentOffset, Area::DATA,reader_.getReaderEncryption(),streams); + + for (size_t i = 0; i < streams.size(); i++) { + std::shared_ptr stream = streams.at(i); + std::string key = + std::to_string(stream->getColumnId()) + ":" + std::to_string(stream->getKind()); + streamMap.emplace(key, std::shared_ptr(stream)); + } } StripeStreamsImpl::~StripeStreamsImpl() { @@ -62,7 +81,14 @@ namespace orc { } proto::ColumnEncoding StripeStreamsImpl::getEncoding(uint64_t columnId) const { - return footer_.columns(static_cast(columnId)); + // The encoding of encrypted columns needs to be obtained in a special way. + ReaderEncryptionVariant* variant = this->reader_.getFileContents().encryption->getVariant(columnId); + if(variant != nullptr){ + int subColumn = columnId - variant->getRoot()->getColumnId(); + return footer_.encryption().Get(variant->getVariantId()).encoding(subColumn); + }else{ + return footer_.columns(static_cast(columnId)); + } } const Timezone& StripeStreamsImpl::getWriterTimezone() const { @@ -79,35 +105,33 @@ namespace orc { std::unique_ptr StripeStreamsImpl::getStream(uint64_t columnId, proto::Stream_Kind kind, - bool shouldStream) const { - uint64_t offset = stripeStart_; - uint64_t dataEnd = - stripeInfo_.offset() + stripeInfo_.index_length() + stripeInfo_.data_length(); - MemoryPool* pool = reader_.getFileContents().pool; - for (int i = 0; i < footer_.streams_size(); ++i) { - const proto::Stream& stream = footer_.streams(i); - if (stream.has_kind() && stream.kind() == kind && - stream.column() == static_cast(columnId)) { - uint64_t streamLength = stream.length(); - uint64_t myBlock = shouldStream ? input_.getNaturalReadSize() : streamLength; - if (offset + streamLength > dataEnd) { - std::stringstream msg; - msg << "Malformed stream meta at stream index " << i << " in stripe " << stripeIndex_ - << ": streamOffset=" << offset << ", streamLength=" << streamLength - << ", stripeOffset=" << stripeInfo_.offset() - << ", stripeIndexLength=" << stripeInfo_.index_length() - << ", stripeDataLength=" << stripeInfo_.data_length(); - throw ParseError(msg.str()); - } - return createDecompressor(reader_.getCompression(), - std::make_unique( - &input_, offset, stream.length(), *pool, myBlock), - reader_.getCompressionSize(), *pool, - reader_.getFileContents().readerMetrics); - } - offset += stream.length(); + bool shouldStream) const { MemoryPool* pool = reader_.getFileContents().pool; + const std::string skey = std::to_string(columnId) + ":" + std::to_string(kind); + StreamInformation* streamInformation = streamMap[skey].get(); + if(streamInformation == nullptr){ + return nullptr; + } + uint64_t myBlock = shouldStream ? input_.getNaturalReadSize() : streamInformation->getLength(); + auto inputStream = std::make_unique( + &input_, streamInformation->getOffset(), streamInformation->getLength(), *pool, myBlock); + ReaderEncryptionVariant* variant = reader_.getReaderEncryption()->getVariant(columnId); + if (variant != nullptr) { + ReaderEncryptionKey* encryptionKey = variant->getKeyDescription(); + const int ivLength = encryptionKey->getAlgorithm()->getIvLength(); + std::vector iv(ivLength); + orc::CryptoUtil::modifyIvForStream(columnId, kind, originalStripeId, iv.data(), ivLength); + const EVP_CIPHER* cipher = encryptionKey->getAlgorithm()->createCipher(); + std::vector key = variant->getStripeKey(stripeIndex_)->getEncoded(); + std::unique_ptr decompressStream = createDecompressorAndDecryption( + reader_.getCompression(), std::move(inputStream), reader_.getCompressionSize(), *pool, + reader_.getFileContents().readerMetrics, key, + iv, const_cast(cipher)); + return decompressStream; + } else { + return createDecompressor(reader_.getCompression(), + std::move(inputStream), + reader_.getCompressionSize(), *pool,reader_.getFileContents().readerMetrics); } - return nullptr; } MemoryPool& StripeStreamsImpl::getMemoryPool() const { diff --git a/c++/src/StripeStream.hh b/c++/src/StripeStream.hh index ad82d472c2..b23cb08100 100644 --- a/c++/src/StripeStream.hh +++ b/c++/src/StripeStream.hh @@ -26,10 +26,59 @@ #include "ColumnReader.hh" #include "Timezone.hh" #include "TypeImpl.hh" +#include "security/ReaderEncryption.hh" namespace orc { class RowReaderImpl; + /** + * StreamInformation Implementation + */ + class StreamInformationImpl : public StreamInformation { + private: + StreamKind kind; + uint64_t column; + uint64_t offset; + uint64_t length; + proto::Stream_Kind originalKind; + + public: + StreamInformationImpl(uint64_t _offset, const proto::Stream& stream) + : kind(static_cast(stream.kind())), + column(stream.column()), + offset(_offset), + length(stream.length()) { + // PASS + } + StreamInformationImpl(proto::Stream_Kind originalKind, uint64_t column, uint64_t _offset, + uint64_t length) + : kind(static_cast(originalKind)), + column(column), + offset(_offset), + length(length), + originalKind(originalKind) { + // PASS + } + ~StreamInformationImpl() override; + + StreamKind getKind() const override { + return kind; + } + + uint64_t getColumnId() const override { + return column; + } + + uint64_t getOffset() const override { + return offset; + } + + uint64_t getLength() const override { + return length; + } + }; + + enum class Area { DATA, INDEX, FOOTER }; /** * StripeStream Implementation @@ -41,13 +90,39 @@ namespace orc { const proto::StripeInformation& stripeInfo_; const proto::StripeFooter& footer_; const uint64_t stripeIndex_; + long originalStripeId = 0; const uint64_t stripeStart_; InputStream& input_; const Timezone& writerTimezone_; const Timezone& readerTimezone_; - + mutable std::map> streamMap; + std::vector> streams; + static long handleStream(long offset, const proto::Stream& stream, Area area, + ReaderEncryptionVariant* variant, ReaderEncryption* encryption, + std::vector>& streams) { + int column = stream.column(); + if (stream.has_kind()) { + proto::Stream_Kind kind = stream.kind(); + // If there are no encrypted columns + /* if (encryption->getKeys().empty()) { + StreamInformationImpl* info = new StreamInformationImpl(kind, column, offset, stream.length()); + streams.push_back(std::shared_ptr(info)); + return stream.length(); + }*/ + if (getArea(kind) != area || kind == proto::Stream_Kind::Stream_Kind_ENCRYPTED_INDEX || + kind == proto::Stream_Kind::Stream_Kind_ENCRYPTED_DATA) { + //Ignore the placeholder that should not be included in the offset calculation. + return 0; + } + if (encryption->getVariant(column) == variant) { + StreamInformationImpl* info = new StreamInformationImpl(kind, column, offset, stream.length()); + streams.push_back(std::shared_ptr(info)); + } + } + return stream.length(); + } public: - StripeStreamsImpl(const RowReaderImpl& reader, uint64_t index, + StripeStreamsImpl(const RowReaderImpl& reader, uint64_t index, long originalStripeId, const proto::StripeInformation& stripeInfo, const proto::StripeFooter& footer, uint64_t stripeStart, InputStream& input, const Timezone& writerTimezone, const Timezone& readerTimezone); @@ -79,44 +154,42 @@ namespace orc { int32_t getForcedScaleOnHive11Decimal() const override; const SchemaEvolution* getSchemaEvolution() const override; - }; - - /** - * StreamInformation Implementation - */ - - class StreamInformationImpl : public StreamInformation { - private: - StreamKind kind_; - uint64_t column_; - uint64_t offset_; - uint64_t length_; - - public: - StreamInformationImpl(uint64_t offset, const proto::Stream& stream) - : kind_(static_cast(stream.kind())), - column_(stream.column()), - offset_(offset), - length_(stream.length()) { - // PASS - } - - ~StreamInformationImpl() override; - - StreamKind getKind() const override { - return kind_; - } - - uint64_t getColumnId() const override { - return column_; - } - - uint64_t getOffset() const override { - return offset_; + // + static Area getArea(proto::Stream_Kind kind) { + switch (kind) { + case proto::Stream_Kind::Stream_Kind_FILE_STATISTICS: + case proto::Stream_Kind::Stream_Kind_STRIPE_STATISTICS: + return Area::FOOTER; + case proto::Stream_Kind::Stream_Kind_ROW_INDEX: + case proto::Stream_Kind::Stream_Kind_DICTIONARY_COUNT: + case proto::Stream_Kind::Stream_Kind_BLOOM_FILTER: + case proto::Stream_Kind::Stream_Kind_BLOOM_FILTER_UTF8: + case proto::Stream_Kind::Stream_Kind_ENCRYPTED_INDEX: + return Area::INDEX; + default: + return Area::DATA; + } } - - uint64_t getLength() const override { - return length_; + static long findStreamsByArea(proto::StripeFooter& footer, long currentOffset, Area area, + ReaderEncryption* encryption, + std::vector>& streams) { + // Look for the unencrypted stream. + for (const proto::Stream& stream : footer.streams()) { + currentOffset += handleStream(currentOffset, stream, area, nullptr, encryption, streams); + } + //If there are encrypted columns + if (!encryption->getKeys().empty()) { + std::vector>& vList = encryption->getVariants(); + for (std::vector::size_type i = 0; i < vList.size(); i++) { + ReaderEncryptionVariant* variant = vList.at(i).get(); + int variantId = variant->getVariantId(); + const proto::StripeEncryptionVariant& stripeVariant = footer.encryption(variantId); + for (const proto::Stream& stream : stripeVariant.streams()) { + currentOffset += handleStream(currentOffset, stream, area, variant, encryption, streams); + } + } + } + return currentOffset; } }; @@ -137,6 +210,9 @@ namespace orc { mutable std::unique_ptr stripeFooter_; ReaderMetrics* metrics_; void ensureStripeFooterLoaded() const; + std::shared_ptr>> encryptedKeys; + long originalStripeId = 0; + ReaderEncryption* encryption; public: StripeInformationImpl(uint64_t offset, uint64_t indexLength, uint64_t dataLength, @@ -155,7 +231,42 @@ namespace orc { metrics_(metrics) { // PASS } - + StripeInformationImpl(proto::StripeInformation* stripeInfo, ReaderEncryption* encryption, + long previousOriginalStripeId, + std::shared_ptr>> previousKeys, InputStream* _stream, + MemoryPool& _memory, CompressionKind _compression, uint64_t _blockSize, + ReaderMetrics* _metrics) + : offset_(stripeInfo->offset()), + indexLength_(stripeInfo->index_length()), + dataLength_(stripeInfo->data_length()), + footerLength_(stripeInfo->footer_length()), + numRows_(stripeInfo->number_of_rows()), + stream_(_stream), + memory_(_memory), + compression_(_compression), + blockSize_(_blockSize), + metrics_(_metrics), + encryption(encryption) { + // It is usually the first strip that has this value. + if (stripeInfo->has_encrypt_stripe_id()) { + originalStripeId = stripeInfo->encrypt_stripe_id(); + } else { + originalStripeId = previousOriginalStripeId + 1; + } + // The value is generally present in the first strip. + // Each encrypted column corresponds to a key. + if (stripeInfo->encrypted_local_keys_size() != 0) { + encryptedKeys = std::shared_ptr>>( + new std::vector>()); + for (int i = 0; i < static_cast(stripeInfo->encrypted_local_keys_size()); i++) { + std::string str = stripeInfo->encrypted_local_keys(i); + std::vector chars(str.begin(), str.end()); + encryptedKeys->push_back(chars); + } + } else { + encryptedKeys = std::shared_ptr>>(previousKeys); + } + } virtual ~StripeInformationImpl() override { // PASS } @@ -192,20 +303,39 @@ namespace orc { ColumnEncodingKind getColumnEncoding(uint64_t colId) const override { ensureStripeFooterLoaded(); - return static_cast( - stripeFooter_->columns(static_cast(colId)).kind()); + ReaderEncryptionVariant* variant = this->encryption->getVariant(colId); + if (variant != nullptr) { + int subColumn = colId - variant->getRoot()->getColumnId(); + return static_cast( + stripeFooter_->encryption().Get(variant->getVariantId()).encoding(subColumn).kind()); + } else { + return static_cast(stripeFooter_->columns(static_cast(colId)).kind()); + } } uint64_t getDictionarySize(uint64_t colId) const override { ensureStripeFooterLoaded(); - return static_cast( - stripeFooter_->columns(static_cast(colId)).dictionary_size()); + ReaderEncryptionVariant* variant = this->encryption->getVariant(colId); + if (variant != nullptr) { + int subColumn = colId - variant->getRoot()->getColumnId(); + return static_cast( + stripeFooter_->encryption().Get(variant->getVariantId()).encoding(subColumn).dictionary_size()); + } else { + return static_cast(stripeFooter_->columns(static_cast(colId)).dictionary_size()); + } } const std::string& getWriterTimezone() const override { ensureStripeFooterLoaded(); return stripeFooter_->writer_timezone(); } + std::shared_ptr>> getEncryptedLocalKeys() const override { + return this->encryptedKeys; + } + std::vector& getEncryptedLocalKeyByVariantId(int col) const override { + return getEncryptedLocalKeys()->at(col); + } + long getOriginalStripeId() const override { return originalStripeId; } }; } // namespace orc diff --git a/c++/src/TypeImpl.cc b/c++/src/TypeImpl.cc index cbc7b82796..9542c33d1e 100644 --- a/c++/src/TypeImpl.cc +++ b/c++/src/TypeImpl.cc @@ -837,5 +837,4 @@ namespace orc { } return nullptr; } - } // namespace orc diff --git a/c++/src/io/InputStream.cc b/c++/src/io/InputStream.cc index 06ef40bd4c..84ee7287d9 100644 --- a/c++/src/io/InputStream.cc +++ b/c++/src/io/InputStream.cc @@ -209,5 +209,99 @@ namespace orc { result << input_->getName() << " from " << start_ << " for " << length_; return result.str(); } + DecryptionInputStream::DecryptionInputStream(std::unique_ptr input, + std::vector key, + std::vector iv, + const EVP_CIPHER* cipher,MemoryPool& pool) + : input_(std::move(input)), + key_(key), + iv_(iv), + cipher(cipher), + pool(pool){ + EVP_CIPHER_CTX* ctx = EVP_CIPHER_CTX_new(); + if (ctx == nullptr) { + throw std::runtime_error("Failed to create EVP cipher context"); + } + int ret = EVP_DecryptInit_ex(ctx, cipher, NULL, key_.data(), iv_.data()); + if (ret != 1) { + EVP_CIPHER_CTX_free(ctx); + EVP_CIPHER_free(const_cast(cipher)); + throw std::runtime_error("Failed to initialize EVP cipher context"); + } + ctx_ = ctx; + outputBuffer_.reset(new DataBuffer(pool)); + inputBuffer_.reset(new DataBuffer(pool)); + } + + DecryptionInputStream::~DecryptionInputStream() { + EVP_CIPHER_CTX_free(ctx_); + EVP_CIPHER_free(const_cast(cipher)); + } + + bool DecryptionInputStream::Next(const void** data, int* size) { + int bytesRead = 0; + //const void* ptr; + const void* inptr = static_cast(inputBuffer_->data()); + input_->Next(&inptr, &bytesRead); + if (bytesRead == 0) { + return false; + } + // decrypt data + const unsigned char* result = static_cast(inptr); + int outlen = 0; + //int blockSize = EVP_CIPHER_block_size(this->cipher); + outputBuffer_->resize(bytesRead); + int ret = EVP_DecryptUpdate(ctx_, outputBuffer_->data(), &outlen, result, bytesRead); + if (ret != 1) { + throw std::runtime_error("Failed to decrypt data"); + } + outputBuffer_->resize(outlen); + *data = outputBuffer_->data(); + *size = outputBuffer_->size(); + return true; + } + void DecryptionInputStream::BackUp(int count) { + this->input_->BackUp(count); + } -} // namespace orc + bool DecryptionInputStream::Skip(int count) { + return this->input_->Skip(count); + } + + google::protobuf::int64 DecryptionInputStream::ByteCount() const { + return input_->ByteCount(); + } + + void DecryptionInputStream::seek(PositionProvider& position) { + //std::cout<<"PPP:DecryptionInputStream::seek:"<seek(position); + } + void DecryptionInputStream::changeIv(long offset) { + int blockSize = EVP_CIPHER_key_length(cipher); + long encryptionBlocks = offset / blockSize; + long extra = offset % blockSize; + std::fill(iv_.end() - 8, iv_.end(), 0); + if (encryptionBlocks != 0) { + // Add the encryption blocks into the initial iv, to compensate for + // skipping over decrypting those bytes. + int posn = iv_.size() - 1; + while (encryptionBlocks > 0) { + long sum = (iv_[posn] & 0xff) + encryptionBlocks; + iv_[posn--] = (unsigned char) sum; + encryptionBlocks = sum / 0x100; + } + } + EVP_DecryptInit_ex(ctx_, cipher, NULL, key_.data(), iv_.data()); + // If the range starts at an offset that doesn't match the encryption + // block, we need to advance some bytes within an encryption block. + if (extra > 0) { + std::vector decrypted(extra); + int decrypted_len; + EVP_DecryptUpdate(ctx_, decrypted.data(), &decrypted_len, decrypted.data(), extra); + } + } + std::string DecryptionInputStream::getName() const { + return "DecryptionInputStream("+input_->getName()+")"; + } +}// namespace orc diff --git a/c++/src/io/InputStream.hh b/c++/src/io/InputStream.hh index 07aa623b5f..9bb8dc5b22 100644 --- a/c++/src/io/InputStream.hh +++ b/c++/src/io/InputStream.hh @@ -22,7 +22,14 @@ #include "Adaptor.hh" #include "orc/OrcFile.hh" #include "wrap/zero-copy-stream-wrapper.h" - +#include +#include +#include +#include +#include +#include +#include +#include #include #include #include @@ -103,7 +110,30 @@ namespace orc { virtual void seek(PositionProvider& position) override; virtual std::string getName() const override; }; + class DecryptionInputStream : public SeekableInputStream { + public: + DecryptionInputStream(std::unique_ptr input,std::vector key, + std::vector iv,const EVP_CIPHER* cipher,MemoryPool& pool); + virtual ~DecryptionInputStream(); + virtual bool Next(const void** data, int* size) override; + virtual void BackUp(int count) override; + virtual bool Skip(int count) override; + virtual google::protobuf::int64 ByteCount() const override; + virtual void seek(PositionProvider& position) override; + virtual std::string getName() const override; + void changeIv(long offset); + + private: + std::unique_ptr input_; + std::vector key_; + std::vector iv_; + EVP_CIPHER_CTX* ctx_; + const EVP_CIPHER* cipher; + MemoryPool& pool; + std::unique_ptr> inputBuffer_; + std::unique_ptr> outputBuffer_; + }; } // namespace orc #endif // ORC_INPUTSTREAM_HH diff --git a/c++/src/security/InMemoryKeystore.cc b/c++/src/security/InMemoryKeystore.cc new file mode 100644 index 0000000000..37edd10ecd --- /dev/null +++ b/c++/src/security/InMemoryKeystore.cc @@ -0,0 +1,247 @@ +// Copyright 2010-present vivo, Inc. All rights reserved. +// +//Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "InMemoryKeystore.hh" +#include +#include +#include +#include +#include +#include +#include +namespace orc { + InMemoryKeystore::InMemoryKeystore() : random(rd()) {} + + InMemoryKeystore::~InMemoryKeystore() { + keys.clear(); + currentVersion.clear(); + } + + InMemoryKeystore* InMemoryKeystore::addKey(std::string name, EncryptionAlgorithm* algorithm, + std::string material) { + return addKey(name, 0, algorithm, material); + } + InMemoryKeystore* InMemoryKeystore::addKey(std::string keyName, int version, + EncryptionAlgorithm* algorithm, + std::string password) { + std::vector material = std::vector(password.begin(), password.end()); + return addKey(keyName,version,algorithm,material); + } + InMemoryKeystore* InMemoryKeystore::addKey(std::string keyName, int version, + EncryptionAlgorithm* algorithm, + std::vector material) { + // Test whether platform supports the algorithm + if (!testSupportsAes256() && (algorithm != EncryptionAlgorithm::AES_CTR_128)) { + algorithm = EncryptionAlgorithm::AES_CTR_128; + } + int masterKeyLength = material.size(); + int algorithmLength = algorithm->getKeyLength(); + std::vector buffer(algorithmLength, 0); + // std::vector* buffer = new std::vector(algorithmLength); + int dataLen = algorithmLength > masterKeyLength ? masterKeyLength : algorithmLength; + std::copy(material.begin(), material.begin() + dataLen, buffer.begin()); + + std::unique_ptr key = std::make_unique(keyName, version, algorithm, buffer); + + // Check whether the key is already present and has a smaller version + if (currentVersion.find(keyName) != currentVersion.end()) { + int oldV = currentVersion[keyName]; + if (oldV >= version) { + throw std::runtime_error("Key " + key->toString() + " with equal or higher version " + + std::to_string(version) + " already exists"); + } + } + //std::pair> keyPair(buildVersionName(keyName, version), std::move(key)); + //keys[](keyPair); + keys.emplace(buildVersionName(keyName, version),std::move(key)); + currentVersion[keyName] = version; + + return this; + } + bool InMemoryKeystore::testSupportsAes256() { + const EVP_CIPHER * cipher = EVP_aes_256_cbc(); + int maxKeyLength = EVP_CIPHER_key_length(cipher); + EVP_CIPHER_free(const_cast(cipher)); + return maxKeyLength >= 256 / 8; + } + + std::string InMemoryKeystore::buildVersionName(const std::string& name, int version) { + return name + "@" + std::to_string(version); + } + + std::vector InMemoryKeystore::getKeyNames() { + std::vector names; + for (const auto& kv : currentVersion) { + names.push_back(kv.first); + } + return names; + } + + std::unique_ptr& InMemoryKeystore::getCurrentKeyVersion(std::string keyName) { + int version = currentVersion[keyName]; + const std::string keyVersionName = buildVersionName(keyName, version); + return keys[keyVersionName]; + } + + LocalKey* InMemoryKeystore::createLocalKey(KeyMetadata* key) { + const std::string keyVersion = buildVersionName(key->getKeyName(), key->getVersion()); + if (keys.count(keyVersion) == 0) { + throw std::invalid_argument("Unknown key " + keyVersion); + } + SecurityKey* secret = keys[keyVersion].get(); + EncryptionAlgorithm* algorithm = secret->getAlgorithm(); + std::vector encryptedKey(algorithm->getKeyLength()); + //new std::vector(algorithm->getKeyLength()); + RAND_bytes(encryptedKey.data(), encryptedKey.size()); + + std::vector iv(algorithm->getIvLength()); + std::copy(encryptedKey.begin(), encryptedKey.begin() + algorithm->getIvLength(), iv.begin()); + + EVP_CIPHER_CTX* ctx = EVP_CIPHER_CTX_new(); + if (ctx == nullptr) { + throw std::runtime_error("EVP_CIPHER_CTX_new() failed"); + } + const EVP_CIPHER* cipher = algorithm->createCipher(); + if (EVP_DecryptInit_ex(ctx, cipher, nullptr, secret->getMaterial().data(), + iv.data()) != 1) { + EVP_CIPHER_CTX_free(ctx); + EVP_CIPHER_free(const_cast(cipher)); + throw std::runtime_error("EVP_DecryptInit_ex() failed"); + } + int blockSize = EVP_CIPHER_block_size(cipher); + std::vector decryptedKey(encryptedKey.size() + blockSize); + // new std::vector(encryptedKey->size() + blockSize); + int outlen = 0; + if (EVP_DecryptUpdate(ctx, decryptedKey.data(), &outlen, encryptedKey.data(), + encryptedKey.size()) != 1) { + EVP_CIPHER_CTX_free(ctx); + EVP_CIPHER_free(const_cast(cipher)); + throw std::runtime_error("EVP_DecryptUpdate() failed"); + } + int plaintext_len = outlen; + if (EVP_DecryptFinal_ex(ctx, decryptedKey.data() + outlen, &outlen) != 1) { + EVP_CIPHER_CTX_free(ctx); + EVP_CIPHER_free(const_cast(cipher)); + throw std::runtime_error("EVP_DecryptFinal_ex() failed"); + } + plaintext_len += outlen; + decryptedKey.resize(plaintext_len); + EVP_CIPHER_CTX_free(ctx); + EVP_CIPHER_free(const_cast(cipher)); + LocalKey* localKey = new LocalKey(algorithm->getAlgorithm(), decryptedKey, encryptedKey); + return localKey; + } + + std::shared_ptr InMemoryKeystore::decryptLocalKey( + KeyMetadata* key, std::vector encryptedKey) { + const std::string keyVersion = buildVersionName(key->getKeyName(), key->getVersion()); + if (keys.count(keyVersion) == 0) { + throw std::invalid_argument("Unknown key " + keyVersion); + } + SecurityKey* secret = keys[keyVersion].get(); + EncryptionAlgorithm* algorithm = secret->getAlgorithm(); + + std::vector iv(algorithm->getIvLength()); + std::copy(encryptedKey.begin(), encryptedKey.begin() + algorithm->getIvLength(), iv.begin()); + + EVP_CIPHER_CTX* ctx = EVP_CIPHER_CTX_new(); + if (ctx == nullptr) { + throw std::runtime_error("EVP_CIPHER_CTX_new() failed"); + } + const evp_cipher_st* cipher = algorithm->createCipher(); + if (EVP_DecryptInit_ex(ctx, cipher, nullptr, secret->getMaterial().data(), + iv.data()) != 1) { + EVP_CIPHER_CTX_free(ctx); + EVP_CIPHER_free(const_cast(cipher)); + throw std::runtime_error("EVP_DecryptInit_ex() failed"); + } + int blockSize = EVP_CIPHER_block_size(cipher); + std::vector decryptedKey((encryptedKey.size() + blockSize)); + //new std::vector(encryptedKey->size() + blockSize); + int outlen = 0; + if (EVP_DecryptUpdate(ctx, decryptedKey.data(), &outlen, encryptedKey.data(), + encryptedKey.size()) != 1) { + EVP_CIPHER_CTX_free(ctx); + EVP_CIPHER_free(const_cast(cipher)); + throw std::runtime_error("EVP_DecryptUpdate() failed"); + } + int plaintext_len = outlen; + if (EVP_DecryptFinal_ex(ctx, decryptedKey.data() + outlen, &outlen) != 1) { + EVP_CIPHER_CTX_free(ctx); + EVP_CIPHER_free(const_cast(cipher)); + throw std::runtime_error("EVP_DecryptFinal_ex() failed"); + } + plaintext_len += outlen; + decryptedKey.resize(plaintext_len); + EVP_CIPHER_CTX_free(ctx); + EVP_CIPHER_free(const_cast(cipher)); + return std::make_unique(decryptedKey, algorithm->getAlgorithm()); + } + LocalKey::LocalKey(std::string algorithm, std::vector decryptedKey, + std::vector encryptedKey):encryptedKey(encryptedKey) { + if (!decryptedKey.empty()) { + setDecryptedKey(algorithm, decryptedKey); + } + } + + void LocalKey::setDecryptedKey(std::string algorithm, std::vector decrypt) { + this->decryptedKey = std::make_unique(decrypt, algorithm); + } + void LocalKey::setDecryptedKey(std::shared_ptr secretKeySpec) { + this->decryptedKey = secretKeySpec; + } + + std::shared_ptr LocalKey::getDecryptedKey() { + return decryptedKey; + } + + std::vector LocalKey::getEncryptedKey() { + return encryptedKey; + } + EncryptionAlgorithm::EncryptionAlgorithm(const std::string& algorithm, const std::string& mode, + int keyLength) + : algorithm(algorithm), mode(mode), keyLength(keyLength) {} + EncryptionAlgorithm* EncryptionAlgorithm::AES_CTR_128 = + new EncryptionAlgorithm("AES", "CTR/NoPadding", 16); + EncryptionAlgorithm* EncryptionAlgorithm::AES_CTR_256 = + new EncryptionAlgorithm("AES", "CTR/NoPadding", 32); + int EncryptionAlgorithm::getIvLength() { + return 16; + } + + int EncryptionAlgorithm::getKeyLength() { + return this->keyLength; + } + const std::string EncryptionAlgorithm::getAlgorithm() { + return this->algorithm; + } + + const std::string EncryptionAlgorithm::toString() { + return algorithm + std::to_string(keyLength * 8); + } + const EVP_CIPHER* EncryptionAlgorithm::createCipher() { + if (algorithm == "AES" && mode == "CTR/NoPadding" && keyLength == 16){ + return EVP_aes_128_ctr(); + }else if (algorithm == "AES" && mode == "CTR/NoPadding" && keyLength == 32) + return EVP_aes_256_ctr(); + else { + throw std::invalid_argument("Bad algorithm or padding"); + } + return NULL; + } +} // namespace orc diff --git a/c++/src/security/InMemoryKeystore.hh b/c++/src/security/InMemoryKeystore.hh new file mode 100644 index 0000000000..0125d403fa --- /dev/null +++ b/c++/src/security/InMemoryKeystore.hh @@ -0,0 +1,182 @@ +// Copyright 2010-present vivo, Inc. All rights reserved. +// +//Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef CPPDEMO_INMEMORYKEYSTORE_H +#define CPPDEMO_INMEMORYKEYSTORE_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace orc { + class EncryptionAlgorithm { + public: + EncryptionAlgorithm(const std::string& algorithm, const std::string& mode, int keyLength); + //Get the EVP_CIPHER and use it for every decryption + const EVP_CIPHER* createCipher(); + int getIvLength(); + //algorithm name like AES + const std::string getAlgorithm(); + const std::string toString(); + static EncryptionAlgorithm* AES_CTR_128; + static EncryptionAlgorithm* AES_CTR_256; + //The length of the algorithm key + int getKeyLength(); + + private: + std::string algorithm; + std::string mode; + int keyLength; + }; + //describing the KeyProvider key + class KeyMetadata { + private: + std::string keyName; + int version; + EncryptionAlgorithm* algorithm; + + public: + KeyMetadata(std::string key, int version, EncryptionAlgorithm* algorithm) + : keyName(key), version(version), algorithm(algorithm) {} + + std::string getKeyName() const { + return keyName; + } + + EncryptionAlgorithm* getAlgorithm() { + return algorithm; + } + + int getVersion() const { + return version; + } + + std::string toString() { + return keyName + "@" + std::to_string(version) + " " + algorithm->toString(); + } + }; + //The key used for decrypting data + class SecurityKey : public KeyMetadata { + private: + std::vector material; + public: + SecurityKey(std::string keyName, int version, EncryptionAlgorithm* algorithm, + std::vector& material) + : KeyMetadata(keyName, version, algorithm), material(material) {} + std::vector getMaterial() { + return material; + } + }; + + //Use the SecurityKey to encrypt the encryptedKey and obtain the reuslt + class SecretKeySpec { + private: + std::vector key; + std::string algorithm; + + public: + SecretKeySpec(std::vector var1, const std::string& var2) { + if (!var1.empty() && !var2.empty()) { + key = std::vector(var1); + algorithm = var2; + } else { + throw std::invalid_argument("Missing argument"); + } + } + std::string getAlgorithm() const { + return algorithm; + } + + std::vector getEncoded() const { + return key; + } + }; + + //Use the SecurityKey to encrypt the encryptedKey with the algorithm and obtain the decryptedKey + class LocalKey { + public: + LocalKey(std::string algorithm, std::vector decryptedKey , + std::vector encryptedKey); + + void setDecryptedKey(std::string algorithm, std::vector key); + void setDecryptedKey(std::shared_ptr secretKeySpec); + + + std::shared_ptr getDecryptedKey(); + + std::vector getEncryptedKey(); + + private: + std::vector encryptedKey; + std::shared_ptr decryptedKey = nullptr; + }; + + class KeyProvider { + public: + //virtual ~KeyProvider() = 0; + virtual std::vector getKeyNames() = 0; + //obtain the master key based on the keyName for decrypting the localKey + virtual std::unique_ptr& getCurrentKeyVersion(std::string keyName) = 0; + //decrypt the random number using the master key to generate the local key + virtual LocalKey* createLocalKey(KeyMetadata* key) = 0; + //Use the key to find the master key in the KeyProvider, then decrypt the encryptedKey and return the plaintext + virtual std::shared_ptr decryptLocalKey(KeyMetadata* key, std::vector encryptedKey) = 0; + }; + //Used for storing the master key and generating multiple local keys when writing data + class InMemoryKeystore : public KeyProvider{ + private: + //Used for generating local keys + std::random_device rd; + std::mt19937 random; + //the master key stored in this keys + std::map> keys; + // The latest version of the master key used for storage + std::map currentVersion; + + public: + InMemoryKeystore(); + ~InMemoryKeystore(); + //get the names of all master keys + std::vector getKeyNames() override; + //add master key + InMemoryKeystore* addKey(std::string name, EncryptionAlgorithm* algorithm, + std::string password); + // add master key with version + InMemoryKeystore* addKey(std::string keyName, int version, EncryptionAlgorithm* algorithm, + std::string password); + InMemoryKeystore* addKey(std::string keyName, int version, + EncryptionAlgorithm* algorithm, + std::vector material); + + std::unique_ptr& getCurrentKeyVersion(std::string keyName) override; + + LocalKey* createLocalKey(KeyMetadata* key) override; + + std::shared_ptr decryptLocalKey(KeyMetadata* key, std::vector encryptedKey) override; + //Check if Aes256 encryption and decryption is supported + bool testSupportsAes256(); + std::string buildVersionName(const std::string& name, int version); + }; + +} +#endif //CPPDEMO_INMEMORYKEYSTORE_H diff --git a/c++/src/security/ReaderEncryption.cc b/c++/src/security/ReaderEncryption.cc new file mode 100644 index 0000000000..43c2f23522 --- /dev/null +++ b/c++/src/security/ReaderEncryption.cc @@ -0,0 +1,278 @@ +// Copyright 2010-present vivo, Inc. All rights reserved. +// +//Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ReaderEncryption.hh" +#include "../Reader.hh" +namespace orc { + EncryptionAlgorithm* fromSerialization(orc::proto::EncryptionAlgorithm& serialization) { + switch (serialization) { + case orc::proto::EncryptionAlgorithm::AES_CTR_128: + return EncryptionAlgorithm::AES_CTR_128; + case orc::proto::EncryptionAlgorithm::AES_CTR_256: + return EncryptionAlgorithm::AES_CTR_256; + default: + throw std::invalid_argument("Unknown code in encryption algorithm"); + } + } + ReaderEncryption::ReaderEncryption(){}; + + ReaderEncryption::ReaderEncryption(std::shared_ptr contents, + long stripeStatisticsOffset, + std::shared_ptr provider) { + if (nullptr != contents->footer || contents->footer->has_encryption()) { + keyProvider = provider; + const proto::Encryption& encrypt = contents->footer->encryption(); + //Initialize keys. + int keysSize = encrypt.key_size(); + for (int k = 0; k < keysSize; ++k) { + keys.push_back(std::make_unique(encrypt.key(k))); + } + //Initialize variants + int variantsSize = encrypt.variants_size(); + //The offset of the stripe statistics information for the column that is encrypted. + uint64_t offset = stripeStatisticsOffset; + for (int v = 0; v < variantsSize; ++v) { + const proto::EncryptionVariant& variant = encrypt.variants(v); + std::unique_ptr& k = keys[static_cast(variant.key())]; + ReaderEncryptionVariant* readerEncryptionVariant = + new ReaderEncryptionVariant(variant,k.get(), v, contents, offset, provider); + variants.push_back(std::unique_ptr(readerEncryptionVariant)); + offset += variants[v]->getStripeStatisticsLength(); + } + //ColumnVariants's size is columnSize +1, and for unencrypted columns, its value is null. + int columnVariantsSize = contents->schema->getMaximumColumnId() + 1; + columnVariants.resize(columnVariantsSize); + for (auto it = variants.begin(); it != variants.end(); ++it) { + std::shared_ptr variant = *it; + Type* root = variant->getRoot(); + //If nested columns are encrypted, they share the same variant object with their parent and sibling columns. + for (int c = root->getColumnId(); c <= static_cast(root->getMaximumColumnId()); ++c) { + // set the variant if it is the first one that we've found + if (columnVariants[c] == nullptr) { + columnVariants[c] = std::shared_ptr(variant); + } + } + } + } + } + ReaderEncryptionVariant* ReaderEncryption::getVariant(int column) { + if (columnVariants.size() == 0) { + return nullptr; + } else { + //If a given variant cannot obtain the decrypted footer key, then look for a variant object that can be decrypted for that column. + //Look for a variant object that can be decrypted within its subcolumns. + while (columnVariants[column] != nullptr && + !columnVariants[column]->getKeyDescription()->isAvailable()) { + if (keyProvider != nullptr) { + columnVariants[column].reset(findNextVariant(column, columnVariants[column]->getVariantId())); + } + } + return columnVariants[column].get(); + } + } + + ReaderEncryptionVariant* ReaderEncryption::findNextVariant(int column, int lastVariant) { + for (int v = lastVariant + 1; v < static_cast(this->variants.size()); ++v) { + Type* root = variants[v]->getRoot(); + if (static_cast(root->getColumnId()) <= column && + column <= static_cast(root->getMaximumColumnId())) { + return variants[v].get(); + } + } + return nullptr; + } + std::vector>& ReaderEncryption::getKeys() { + return keys; + } + std::vector>& ReaderEncryption::getVariants() { + return variants; + } + + ReaderEncryptionKey::ReaderEncryptionKey(const orc::proto::EncryptionKey& key) + : name(key.key_name()), version(key.key_version()), state(State::UNTRIED) { + orc::proto::EncryptionAlgorithm al = key.algorithm(); + algorithm = orc::fromSerialization(al); + } + std::string ReaderEncryptionKey::getKeyName() const { + return name; + } + + int ReaderEncryptionKey::getKeyVersion() { + return version; + } + + orc::EncryptionAlgorithm* ReaderEncryptionKey::getAlgorithm() { + return algorithm; + } + + std::vector& ReaderEncryptionKey::getEncryptionRoots() { + return roots; + } + + std::unique_ptr ReaderEncryptionKey::getMetadata() { + return std::make_unique(name, version, algorithm); + } + + ReaderEncryptionKey::State& ReaderEncryptionKey::getState() { + return this->state; + } + + void ReaderEncryptionKey::setFailure() { + this->state = State::SUCCESS; + } + + void ReaderEncryptionKey::setSuccess() { + this->state = State::SUCCESS; + } + + void ReaderEncryptionKey::addVariant(ReaderEncryptionVariant* newVariant) { + roots.push_back(newVariant); + } + bool ReaderEncryptionKey::isAvailable() { + if (getState() == ReaderEncryptionKey::State::SUCCESS) { + return true; + } else if (getState() == ReaderEncryptionKey::State::UNTRIED && !getEncryptionRoots().empty()) { + // Check to see if we can decrypt the footer key of the first variant. + try { + return getEncryptionRoots()[0]->getFileFooterKey() != nullptr; + } catch (const std::exception& e) { + setFailure(); + } + } + return false; + } + ReaderEncryptionVariant::ReaderEncryptionVariant( + const proto::EncryptionVariant& proto, + ReaderEncryptionKey* key, + int variantId, + std::shared_ptr contents, + uint64_t stripeStatsOffset, + std::shared_ptr provider) + : proto(proto), + key(key), + variantId(variantId), + stripeStatsOffset(stripeStatsOffset), + provider(provider){ + if (proto.has_root()) { + column = contents->schema->getTypeByColumnId(proto.root()); + } else { + column = contents->schema.get(); + } + // Each strip in this column has its own key. + size_t stripeCount = contents->stripeList.size(); + if (proto.has_encrypted_key()) { + std::string algorithm = key->getAlgorithm()->getAlgorithm(); + for (size_t s = 0; s < stripeCount; ++s) { + StripeInformation* stripe = contents->stripeList.at(s).get(); + std::vector colKey = stripe->getEncryptedLocalKeyByVariantId(variantId); + std::vector deKey; + LocalKey* localKey = new LocalKey(algorithm,deKey, colKey); + localKeys.push_back(std::shared_ptr(localKey)); + } + // The FileStat and StripStat information for decrypting this column. + std::string footerKeyStr = proto.encrypted_key(); + std::vector enKey(footerKeyStr.begin(), footerKeyStr.end()); + std::vector deKey; + footerKey =std::shared_ptr(new LocalKey(algorithm, deKey,enKey)); + key->addVariant(this); + } else { + footerKey = nullptr; + } + } + ReaderEncryptionKey* ReaderEncryptionVariant::getKeyDescription() { + return this->key; + } + Type* ReaderEncryptionVariant::getRoot() { + return const_cast(this->column); + } + + int ReaderEncryptionVariant::getVariantId() { + return this->variantId; + } + + std::shared_ptr ReaderEncryptionVariant::getFileFooterKey() { + if (this->key != nullptr && this->provider != nullptr) { + return getDecryptedKey(footerKey.get());; + } else { + return nullptr; + } + } + long ReaderEncryptionVariant::getStripeStatisticsLength() { + long result = 0; + // proto.stripestatistics() + for (int i = 0; i < proto.stripe_statistics_size(); i++) { + result += proto.stripe_statistics(i).length(); + } + return result; + } + uint64_t ReaderEncryptionVariant::getStripeStatsOffset(){ + return stripeStatsOffset; + } + orc::proto::ColumnarStripeStatistics* ReaderEncryptionVariant::getAllEncryptStripeStatistics(){ + return allEncryptionStripeStatistics.get(); + } + void ReaderEncryptionVariant::setAllEncryptStripeStatistics(orc::proto::ColumnarStripeStatistics* tmp){ + allEncryptionStripeStatistics = std::unique_ptr(tmp); + } + void ReaderEncryptionVariant::setFileStatistics(orc::proto::FileStatistics* tmp){ + fileStatistics = std::unique_ptr(tmp); + } + orc::proto::FileStatistics* ReaderEncryptionVariant::getFileStatistics(){ + return fileStatistics.get(); + } + const proto::EncryptionVariant& ReaderEncryptionVariant::getEncryptionVariantProto(){ + return this->proto; + } + std::shared_ptr ReaderEncryptionVariant::getStripeKey(long stripe) { + if (key == nullptr || provider == nullptr) return nullptr; + return getDecryptedKey(localKeys[stripe].get()); + } + + std::shared_ptr ReaderEncryptionVariant::getDecryptedKey(LocalKey* localKey) const { + std::shared_ptr result = localKey->getDecryptedKey(); + //Check if localkey has been decrypted. + if (result == nullptr) { + switch (key->getState()) { + case ReaderEncryptionKey::State::UNTRIED: + try { + result = provider->decryptLocalKey(key->getMetadata().get(), localKey->getEncryptedKey()); + } catch (const std::exception& e) { + std::cout << "Can't decrypt using key " << key->getKeyName() << std::endl; + } + if (result != nullptr) { + localKey->setDecryptedKey(result); + key->setSuccess(); + } else { + key->setFailure(); + } + break; + case ReaderEncryptionKey::State::SUCCESS: + result = provider->decryptLocalKey(key->getMetadata().get(), localKey->getEncryptedKey()); + if (result == nullptr) { + throw std::runtime_error("Can't decrypt local key " + key->getKeyName()); + } + localKey->setDecryptedKey(result); + break; + case ReaderEncryptionKey::State::FAILURE: + return nullptr; + } + } + return result; + }; + +} // namespace orc \ No newline at end of file diff --git a/c++/src/security/ReaderEncryption.hh b/c++/src/security/ReaderEncryption.hh new file mode 100644 index 0000000000..840f1c1246 --- /dev/null +++ b/c++/src/security/ReaderEncryption.hh @@ -0,0 +1,215 @@ +// Copyright 2010-present vivo, Inc. All rights reserved. +// +//Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ORC_READERENCRYPTION_H +#define ORC_READERENCRYPTION_H + +#include +#include "InMemoryKeystore.hh" +#include "orc/Common.hh" +#include "orc/Statistics.hh" +#include "orc/Type.hh" +#include "orc_proto.pb.h" +namespace orc { + class ReaderEncryptionKey; + struct FileContents; + class EncryptionVariant { + public: + std::shared_ptr getKeyDescription(); + + Type* getRoot(); + + int getVariantId(); + + std::shared_ptr getFileFooterKey(); + + std::shared_ptr getStripeKey(long stripe); + }; + + // 加密变体,每个列一个 + // 每个strip 都有个StripeKey,用于解密数据的 + // 每个列都有一个FileFooterKey,用于解密FileStat,StripStat + class ReaderEncryptionVariant : public EncryptionVariant { + public: + ReaderEncryptionVariant(const proto::EncryptionVariant& proto, + ReaderEncryptionKey* key, + int variantId, + std::shared_ptr contents, + uint64_t stripeStatsOffset, + std::shared_ptr provider); + //Type* getColumn(); + long getStripeStatisticsLength(); + ReaderEncryptionKey* getKeyDescription(); + Type* getRoot(); + int getVariantId(); + /** + * Get the footer key of this column and decrypt it. + * @return + */ + std::shared_ptr getFileFooterKey(); + /** + * Get the local key of a specific stripe in this column and decrypt it. + * @param stripe + * @return + */ + std::shared_ptr getStripeKey(long stripe); + /** + * Get the length of the statistical information of the column + * @return + */ + uint64_t getStripeStatsOffset(); + /** + * All the statistical information of this column after stripping is stored here for caching, + * avoiding multiple readings and affecting performance. + * @return + */ + orc::proto::ColumnarStripeStatistics* getAllEncryptStripeStatistics(); + /** + * All the statistical information of this column after stripping is stored here for caching, + * avoiding multiple readings and affecting performance. + * @return + */ + void setAllEncryptStripeStatistics(orc::proto::ColumnarStripeStatistics* tmp); + void setFileStatistics(orc::proto::FileStatistics* tmp); + orc::proto::FileStatistics* getFileStatistics(); + const proto::EncryptionVariant& getEncryptionVariantProto(); + private: + const proto::EncryptionVariant& proto; + ReaderEncryptionKey* key; + const Type* column; + //The indexes of all encrypted columns. + int variantId; + //Each stripe in this column has a local key. + std::vector> localKeys; + //All stripes in this column have a footer key. + std::shared_ptr footerKey; + // + uint64_t stripeStatsOffset; + std::shared_ptr provider; + // All the statistical information of this column after stripping is stored here for caching, + // avoiding multiple readings and affecting performance. + std::unique_ptr allEncryptionStripeStatistics = nullptr; + std::unique_ptr fileStatistics = nullptr; + /** + * Decrypt local key, + * @param pKey + * @return + */ + std::shared_ptr getDecryptedKey(LocalKey* pKey) const; + }; + // 定义 加密key的信息,不存储密钥 + class ReaderEncryptionKey { + public: + enum class State { UNTRIED, FAILURE, SUCCESS }; + ReaderEncryptionKey(); + ReaderEncryptionKey(const proto::EncryptionKey& key); + std::string getKeyName() const; + int getKeyVersion(); + //Check if the footerkey can be decrypted. + bool isAvailable(); + EncryptionAlgorithm* getAlgorithm(); + std::vector& getEncryptionRoots(); + std::unique_ptr getMetadata(); + ReaderEncryptionKey::State& getState(); + void setFailure(); + void setSuccess(); + void addVariant(ReaderEncryptionVariant* newVariant); + bool operator==(const ReaderEncryptionKey& other) const { + return (name == other.name) && (version == other.version) && (algorithm == other.algorithm); + } + + bool operator!=(const ReaderEncryptionKey& other) const { + return !(*this == other); + } + + private: + std::string name; + int version; + EncryptionAlgorithm* algorithm; + std::vector roots; + State state; + }; + class ReaderEncryption { + public: + ReaderEncryption(std::shared_ptr contents, + long stripeStatisticsOffset, + std::shared_ptr provider); + ReaderEncryption(); + std::vector>& getKeys(); + ReaderEncryptionVariant* getVariant(int column); + std::vector>& getVariants(); + + private: + std::shared_ptr keyProvider; + std::vector> keys; + std::vector> variants; + //ColumnVariants's size is columnSize +1, and for unencrypted columns, its value is null. + std::vector> columnVariants; + ReaderEncryptionVariant* findNextVariant(int column, int lastVariant); + }; + class CryptoUtil { + private: + static const int COLUMN_ID_LENGTH = 3; + static const int KIND_LENGTH = 2; + static const int STRIPE_ID_LENGTH = 3; + static const int MIN_COUNT_BYTES = 8; + + static const int MAX_COLUMN = 0xffffff; + static const int MAX_KIND = 0xffff; + static const int MAX_STRIPE = 0xffffff; + + public: + static unsigned char* modifyIvForStream(int columnId, proto::Stream_Kind kind, long stripeId, + unsigned char* iv, int ivLength) { + if (columnId < 0 || columnId > MAX_COLUMN) { + throw std::invalid_argument("ORC encryption is limited to " + std::to_string(MAX_COLUMN) + + " columns. Value = " + std::to_string(columnId)); + } + int k = kind; + if (k < 0 || k > MAX_KIND) { + throw std::invalid_argument("ORC encryption is limited to " + std::to_string(MAX_KIND) + + " stream kinds. Value = " + std::to_string(k)); + } + if (ivLength - (COLUMN_ID_LENGTH + KIND_LENGTH + STRIPE_ID_LENGTH) < MIN_COUNT_BYTES) { + throw std::invalid_argument("Not enough space in the iv for the count"); + } + iv[0] = static_cast(columnId >> 16); + iv[1] = static_cast(columnId >> 8); + iv[2] = static_cast(columnId); + iv[COLUMN_ID_LENGTH] = static_cast(k >> 8); + iv[COLUMN_ID_LENGTH + 1] = static_cast(k); + modifyIvForStripe(stripeId, iv, ivLength); + return iv; + } + + static unsigned char* modifyIvForStripe(long stripeId, unsigned char* iv, int ivLength) { + if (stripeId < 1 || stripeId > MAX_STRIPE) { + throw std::invalid_argument("ORC encryption is limited to " + std::to_string(MAX_STRIPE) + + " stripes. Value = " + std::to_string(stripeId)); + } + iv[COLUMN_ID_LENGTH + KIND_LENGTH] = static_cast(stripeId >> 16); + iv[COLUMN_ID_LENGTH + KIND_LENGTH + 1] = static_cast(stripeId >> 8); + iv[COLUMN_ID_LENGTH + KIND_LENGTH + 2] = static_cast(stripeId); + for (int i = COLUMN_ID_LENGTH + KIND_LENGTH + STRIPE_ID_LENGTH; i < ivLength; ++i) { + iv[i] = 0; + } + return iv; + } + }; +} // namespace orc +#endif // ORC_READERENCRYPTION_H From 5ecbd4118e909fd62faa7b67970d5dd9d81470e3 Mon Sep 17 00:00:00 2001 From: zxf216 <524627843@qq.com> Date: Wed, 14 Aug 2024 07:04:32 +0800 Subject: [PATCH 2/4] Update c++/include/orc/Common.hh Co-authored-by: Gang Wu --- c++/include/orc/Common.hh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/c++/include/orc/Common.hh b/c++/include/orc/Common.hh index 725b95112f..1e8a03b24e 100644 --- a/c++/include/orc/Common.hh +++ b/c++/include/orc/Common.hh @@ -243,7 +243,7 @@ namespace orc { * LocalKey, the values of stripeId and originalStripeId may not be equal. * @return */ - virtual long getOriginalStripeId() const = 0; + virtual int64_t getOriginalStripeId() const = 0; }; // Return true if val1 < val2; otherwise return false From 5dc3331e041a227d03d3db87e3f3c5623ef5bbeb Mon Sep 17 00:00:00 2001 From: zxf216 <524627843@qq.com> Date: Wed, 14 Aug 2024 07:15:16 +0800 Subject: [PATCH 3/4] Update c++/include/orc/Common.hh Co-authored-by: Gang Wu --- c++/include/orc/Common.hh | 2 -- 1 file changed, 2 deletions(-) diff --git a/c++/include/orc/Common.hh b/c++/include/orc/Common.hh index 1e8a03b24e..7e3a99bcd4 100644 --- a/c++/include/orc/Common.hh +++ b/c++/include/orc/Common.hh @@ -233,8 +233,6 @@ namespace orc { getEncryptedLocalKeys() const = 0; /** *Get the Localkey for a specific column in this Stripe. - * @param col - * @return */ virtual std::vector& getEncryptedLocalKeyByVariantId(int col) const = 0; /** From f8201b7e2856b00f0f40be5475a3b14cd03da229 Mon Sep 17 00:00:00 2001 From: zxf216 <524627843@qq.com> Date: Wed, 14 Aug 2024 07:15:29 +0800 Subject: [PATCH 4/4] Update c++/include/orc/Common.hh Co-authored-by: Gang Wu --- c++/include/orc/Common.hh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/c++/include/orc/Common.hh b/c++/include/orc/Common.hh index 7e3a99bcd4..5a9f63e97d 100644 --- a/c++/include/orc/Common.hh +++ b/c++/include/orc/Common.hh @@ -234,7 +234,7 @@ namespace orc { /** *Get the Localkey for a specific column in this Stripe. */ - virtual std::vector& getEncryptedLocalKeyByVariantId(int col) const = 0; + virtual std::vector& getEncryptedLocalKeyByVariantId(int32_t col) const = 0; /** * In general, only the first stripe in an ORC file will store the LocalKey.In this case, the * stripeId and originalStripeId are equal. If an ORC file has multiple stripes storing the