Skip to content

Commit

Permalink
Prevent MutableVectorData from being used for std::string.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 705972216
  • Loading branch information
TensorFlow Federated Team authored and copybara-github committed Dec 17, 2024
1 parent 8fa237b commit e1b0be3
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 1 deletion.
12 changes: 12 additions & 0 deletions tensorflow_federated/cc/core/impl/aggregation/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ cc_library(
"agg_vector_iterator.h",
"datatype.h",
"input_tensor_list.h",
"mutable_unowned_string_data.h",
"mutable_vector_data.h",
"tensor.h",
"tensor_data.h",
Expand Down Expand Up @@ -566,6 +567,17 @@ cc_test(
],
)

# Unit test for MutableUnownedStringData (mutable_unowned_string_data.h).
# NOTE(review): the header is presumably provided by ":tensor", and the
# gtest/gmock headers by ":oss_test_main" -- confirm against those targets.
cc_test(
name = "mutable_unowned_string_data_test",
srcs = ["mutable_unowned_string_data_test.cc"],
deps = [
":tensor",
"//tensorflow_federated/cc/testing:oss_test_main",
"//tensorflow_federated/cc/testing:status_matchers",
"@com_google_absl//absl/strings:string_view",
],
)

cc_test(
name = "vector_string_data_test",
srcs = [
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef TENSORFLOW_FEDERATED_CC_CORE_IMPL_AGGREGATION_CORE_MUTABLE_UNOWNED_STRING_DATA_H_
#define TENSORFLOW_FEDERATED_CC_CORE_IMPL_AGGREGATION_CORE_MUTABLE_UNOWNED_STRING_DATA_H_

#include <cstddef>
#include <cstring>
#include <memory>
#include <string>
#include <vector>

#include "absl/strings/string_view.h"
#include "tensorflow_federated/cc/core/impl/aggregation/core/tensor_data.h"

namespace tensorflow_federated {
namespace aggregation {

// MutableUnownedStringData implements TensorData by wrapping std::vector and
// using it as backing storage for string_view objects. MutableUnownedStringData
// can be mutated using std::vector methods. The MutableUnownedStringData object
// does not own the string values. Use MutableStringData instead if you want
// a TensorData object that owns the strings.
class MutableUnownedStringData : public std::vector<absl::string_view>,
public TensorData {
public:
// Derive constructors from the base vector class.
using std::vector<absl::string_view>::vector;

~MutableUnownedStringData() override = default;

// Implementation of the base class methods.
size_t byte_size() const override {
return this->size() * sizeof(absl::string_view);
}
const void* data() const override {
return this->std::vector<absl::string_view>::data();
}

// Copy the MutableUnownedStringData into a string.
std::string EncodeContent() {
return std::string(reinterpret_cast<const char*>(this->data()),
this->byte_size());
}

// Create and return a new MutableUnownedStringData populated with the data
// from content.
static std::unique_ptr<MutableUnownedStringData> CreateFromEncodedContent(
const std::string& content) {
const absl::string_view* data =
reinterpret_cast<const absl::string_view*>(content.data());
return std::make_unique<MutableUnownedStringData>(
data, data + content.size() / sizeof(absl::string_view));
}
};

} // namespace aggregation
} // namespace tensorflow_federated

#endif // TENSORFLOW_FEDERATED_CC_CORE_IMPL_AGGREGATION_CORE_MUTABLE_UNOWNED_STRING_DATA_H_
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tensorflow_federated/cc/core/impl/aggregation/core/mutable_unowned_string_data.h"

#include <string>
#include <vector>

#include "googlemock/include/gmock/gmock.h"
#include "googletest/include/gtest/gtest.h"
#include "absl/strings/string_view.h"
#include "tensorflow_federated/cc/testing/status_matchers.h"

namespace tensorflow_federated {
namespace aggregation {
namespace {

TEST(MutableUnownedStringDataTest, MutableUnownedStringDataValid) {
  // Backing storage for the views; declared before `unowned` so it outlives
  // the views stored in it.
  const std::string foo = "foo";
  const std::string bar = "bar";
  const std::string baz = "baz";

  MutableUnownedStringData unowned;
  for (const std::string* backing : {&foo, &bar, &baz}) {
    unowned.push_back(absl::string_view(*backing));
  }

  // The populated container must pass the validity check for string_view
  // elements.
  EXPECT_THAT(unowned.CheckValid<absl::string_view>(), IsOk());
}

// Round-trips three views through EncodeContent() and
// CreateFromEncodedContent() and checks that the decoded views compare equal
// to the original strings.
TEST(MutableUnownedStringDataTest, EncodeDecodeSucceeds) {
  // The strings must stay alive for the whole test: the encoded content is a
  // byte copy of the views, not of the characters they refer to.
  std::string string_1 = "foo";
  std::string string_2 = "bar";
  std::string string_3 = "baz";
  MutableUnownedStringData vector_data;
  vector_data.push_back(absl::string_view(string_1));
  vector_data.push_back(absl::string_view(string_2));
  vector_data.push_back(absl::string_view(string_3));
  std::string encoded_vector_data = vector_data.EncodeContent();
  EXPECT_THAT(vector_data.CheckValid<absl::string_view>(), IsOk());
  auto decoded_vector_data =
      MutableUnownedStringData::CreateFromEncodedContent(encoded_vector_data);
  EXPECT_THAT(decoded_vector_data->CheckValid<absl::string_view>(), IsOk());
  // Stop here on a size mismatch so the element checks below can never index
  // past the end of the decoded vector.
  ASSERT_EQ(decoded_vector_data->size(), 3u);
  EXPECT_EQ((*decoded_vector_data)[0], string_1);
  EXPECT_EQ((*decoded_vector_data)[1], string_2);
  EXPECT_EQ((*decoded_vector_data)[2], string_3);
}

} // namespace
} // namespace aggregation
} // namespace tensorflow_federated
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <cstddef>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>

#include "tensorflow_federated/cc/core/impl/aggregation/core/tensor_data.h"
Expand All @@ -30,7 +31,7 @@ namespace aggregation {
// MutableVectorData implements TensorData by wrapping std::vector and using it
// as a backing storage. MutableVectorData can be mutated using std::vector
// methods.
template <typename T>
template <typename T, std::enable_if_t<std::is_arithmetic_v<T>, int> = 0>
class MutableVectorData : public std::vector<T>, public TensorData {
public:
// Derive constructors from the base vector class.
Expand Down

0 comments on commit e1b0be3

Please sign in to comment.