-
Notifications
You must be signed in to change notification settings - Fork 587
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Prevent MutableVectorData from being used for std::string.
PiperOrigin-RevId: 705972216
- Loading branch information
1 parent
8fa237b
commit e1b0be3
Showing
4 changed files
with
148 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
72 changes: 72 additions & 0 deletions
72
tensorflow_federated/cc/core/impl/aggregation/core/mutable_unowned_string_data.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
/* | ||
* Copyright 2024 Google LLC | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#ifndef TENSORFLOW_FEDERATED_CC_CORE_IMPL_AGGREGATION_CORE_MUTABLE_UNOWNED_STRING_DATA_H_ | ||
#define TENSORFLOW_FEDERATED_CC_CORE_IMPL_AGGREGATION_CORE_MUTABLE_UNOWNED_STRING_DATA_H_ | ||
|
||
#include <cstddef> | ||
#include <memory> | ||
#include <string> | ||
#include <vector> | ||
|
||
#include "absl/strings/string_view.h" | ||
#include "tensorflow_federated/cc/core/impl/aggregation/core/tensor_data.h" | ||
|
||
namespace tensorflow_federated { | ||
namespace aggregation { | ||
|
||
// MutableUnownedStringData implements TensorData by wrapping std::vector and | ||
// using it as backing storage for string_view objects. MutableUnownedStringData | ||
// can be mutated using std::vector methods. The MutableUnownedStringData object | ||
// does not own the string values. Use MutableStringData instead if you want | ||
// a TensorData object that owns the strings. | ||
class MutableUnownedStringData : public std::vector<absl::string_view>, | ||
public TensorData { | ||
public: | ||
// Derive constructors from the base vector class. | ||
using std::vector<absl::string_view>::vector; | ||
|
||
~MutableUnownedStringData() override = default; | ||
|
||
// Implementation of the base class methods. | ||
size_t byte_size() const override { | ||
return this->size() * sizeof(absl::string_view); | ||
} | ||
const void* data() const override { | ||
return this->std::vector<absl::string_view>::data(); | ||
} | ||
|
||
// Copy the MutableUnownedStringData into a string. | ||
std::string EncodeContent() { | ||
return std::string(reinterpret_cast<const char*>(this->data()), | ||
this->byte_size()); | ||
} | ||
|
||
// Create and return a new MutableUnownedStringData populated with the data | ||
// from content. | ||
static std::unique_ptr<MutableUnownedStringData> CreateFromEncodedContent( | ||
const std::string& content) { | ||
const absl::string_view* data = | ||
reinterpret_cast<const absl::string_view*>(content.data()); | ||
return std::make_unique<MutableUnownedStringData>( | ||
data, data + content.size() / sizeof(absl::string_view)); | ||
} | ||
}; | ||
|
||
} // namespace aggregation | ||
} // namespace tensorflow_federated | ||
|
||
#endif // TENSORFLOW_FEDERATED_CC_CORE_IMPL_AGGREGATION_CORE_MUTABLE_UNOWNED_STRING_DATA_H_ |
62 changes: 62 additions & 0 deletions
62
tensorflow_federated/cc/core/impl/aggregation/core/mutable_unowned_string_data_test.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
/* | ||
* Copyright 2024 Google LLC | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "tensorflow_federated/cc/core/impl/aggregation/core/mutable_unowned_string_data.h" | ||
|
||
#include <string> | ||
#include <vector> | ||
|
||
#include "googlemock/include/gmock/gmock.h" | ||
#include "googletest/include/gtest/gtest.h" | ||
#include "absl/strings/string_view.h" | ||
#include "tensorflow_federated/cc/testing/status_matchers.h" | ||
|
||
namespace tensorflow_federated { | ||
namespace aggregation { | ||
namespace { | ||
|
||
TEST(MutableUnownedStringDataTest, MutableUnownedStringDataValid) { | ||
std::string string_1 = "foo"; | ||
std::string string_2 = "bar"; | ||
std::string string_3 = "baz"; | ||
MutableUnownedStringData vector_data; | ||
vector_data.push_back(absl::string_view(string_1)); | ||
vector_data.push_back(absl::string_view(string_2)); | ||
vector_data.push_back(absl::string_view(string_3)); | ||
EXPECT_THAT(vector_data.CheckValid<absl::string_view>(), IsOk()); | ||
} | ||
|
||
TEST(MutableUnownedStringDataTest, EncodeDecodeSucceeds) { | ||
std::string string_1 = "foo"; | ||
std::string string_2 = "bar"; | ||
std::string string_3 = "baz"; | ||
MutableUnownedStringData vector_data; | ||
vector_data.push_back(absl::string_view(string_1)); | ||
vector_data.push_back(absl::string_view(string_2)); | ||
vector_data.push_back(absl::string_view(string_3)); | ||
std::string encoded_vector_data = vector_data.EncodeContent(); | ||
EXPECT_THAT(vector_data.CheckValid<absl::string_view>(), IsOk()); | ||
auto decoded_vector_data = | ||
MutableUnownedStringData::CreateFromEncodedContent(encoded_vector_data); | ||
EXPECT_THAT(decoded_vector_data->CheckValid<absl::string_view>(), IsOk()); | ||
EXPECT_EQ((*decoded_vector_data)[0], string_1); | ||
EXPECT_EQ((*decoded_vector_data)[1], string_2); | ||
EXPECT_EQ((*decoded_vector_data)[2], string_3); | ||
} | ||
|
||
} // namespace | ||
} // namespace aggregation | ||
} // namespace tensorflow_federated |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters