Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent MutableVectorData from being used for std::string. #4994

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions tensorflow_federated/cc/core/impl/aggregation/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ cc_library(
"agg_vector_iterator.h",
"datatype.h",
"input_tensor_list.h",
"mutable_unowned_string_data.h",
"mutable_vector_data.h",
"tensor.h",
"tensor_data.h",
Expand Down Expand Up @@ -566,6 +567,17 @@ cc_test(
],
)

# Unit tests for MutableUnownedStringData (mutable_unowned_string_data.h).
# NOTE(review): the header under test appears to be exported through the hdrs
# of the library depended on as ":tensor" - confirm against the full BUILD file.
cc_test(
name = "mutable_unowned_string_data_test",
srcs = ["mutable_unowned_string_data_test.cc"],
deps = [
":tensor",
"//tensorflow_federated/cc/testing:oss_test_main",
"//tensorflow_federated/cc/testing:status_matchers",
"@com_google_absl//absl/strings:string_view",
],
)

cc_test(
name = "vector_string_data_test",
srcs = [
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef TENSORFLOW_FEDERATED_CC_CORE_IMPL_AGGREGATION_CORE_MUTABLE_UNOWNED_STRING_DATA_H_
#define TENSORFLOW_FEDERATED_CC_CORE_IMPL_AGGREGATION_CORE_MUTABLE_UNOWNED_STRING_DATA_H_

#include <cstddef>
#include <memory>
#include <string>
#include <vector>

#include "absl/strings/string_view.h"
#include "tensorflow_federated/cc/core/impl/aggregation/core/tensor_data.h"

namespace tensorflow_federated {
namespace aggregation {

// MutableUnownedStringData implements TensorData by wrapping std::vector and
// using it as backing storage for string_view objects. MutableUnownedStringData
// can be mutated using std::vector methods. The MutableUnownedStringData object
// does not own the string values. Use MutableStringData instead if you want
// a TensorData object that owns the strings.
class MutableUnownedStringData : public std::vector<absl::string_view>,
public TensorData {
public:
// Derive constructors from the base vector class.
using std::vector<absl::string_view>::vector;

~MutableUnownedStringData() override = default;

// Implementation of the base class methods.
size_t byte_size() const override {
return this->size() * sizeof(absl::string_view);
}
const void* data() const override {
return this->std::vector<absl::string_view>::data();
}

// Copy the MutableUnownedStringData into a string.
std::string EncodeContent() {
return std::string(reinterpret_cast<const char*>(this->data()),
this->byte_size());
}

// Create and return a new MutableUnownedStringData populated with the data
// from content.
static std::unique_ptr<MutableUnownedStringData> CreateFromEncodedContent(
const std::string& content) {
const absl::string_view* data =
reinterpret_cast<const absl::string_view*>(content.data());
return std::make_unique<MutableUnownedStringData>(
data, data + content.size() / sizeof(absl::string_view));
}
};

} // namespace aggregation
} // namespace tensorflow_federated

#endif // TENSORFLOW_FEDERATED_CC_CORE_IMPL_AGGREGATION_CORE_MUTABLE_UNOWNED_STRING_DATA_H_
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tensorflow_federated/cc/core/impl/aggregation/core/mutable_unowned_string_data.h"

#include <string>
#include <vector>

#include "googlemock/include/gmock/gmock.h"
#include "googletest/include/gtest/gtest.h"
#include "absl/strings/string_view.h"
#include "tensorflow_federated/cc/testing/status_matchers.h"

namespace tensorflow_federated {
namespace aggregation {
namespace {

// Checks that a MutableUnownedStringData holding views over live std::string
// objects reports itself as valid absl::string_view tensor data.
TEST(MutableUnownedStringDataTest, MutableUnownedStringDataValid) {
  // Backing storage for the views; declared first so it outlives them.
  std::string foo = "foo";
  std::string bar = "bar";
  std::string baz = "baz";

  MutableUnownedStringData views;
  for (const std::string* backing : {&foo, &bar, &baz}) {
    views.push_back(absl::string_view(*backing));
  }

  EXPECT_THAT(views.CheckValid<absl::string_view>(), IsOk());
}

// Round-trips a MutableUnownedStringData through EncodeContent() and
// CreateFromEncodedContent() and checks that the decoded views compare equal
// to the original strings.
TEST(MutableUnownedStringDataTest, EncodeDecodeSucceeds) {
  // Backing storage for the views; it stays alive for the whole test so the
  // decoded views can still be compared against it.
  std::string foo = "foo";
  std::string bar = "bar";
  std::string baz = "baz";

  MutableUnownedStringData original;
  for (const std::string* backing : {&foo, &bar, &baz}) {
    original.push_back(absl::string_view(*backing));
  }

  std::string encoded = original.EncodeContent();
  EXPECT_THAT(original.CheckValid<absl::string_view>(), IsOk());

  auto decoded = MutableUnownedStringData::CreateFromEncodedContent(encoded);
  EXPECT_THAT(decoded->CheckValid<absl::string_view>(), IsOk());

  const MutableUnownedStringData& round_tripped = *decoded;
  EXPECT_EQ(round_tripped[0], foo);
  EXPECT_EQ(round_tripped[1], bar);
  EXPECT_EQ(round_tripped[2], baz);
}

} // namespace
} // namespace aggregation
} // namespace tensorflow_federated
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <cstddef>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>

#include "tensorflow_federated/cc/core/impl/aggregation/core/tensor_data.h"
Expand All @@ -30,7 +31,7 @@ namespace aggregation {
// MutableVectorData implements TensorData by wrapping std::vector and using it
// as a backing storage. MutableVectorData can be mutated using std::vector
// methods.
template <typename T>
template <typename T, std::enable_if_t<std::is_arithmetic_v<T>, int> = 0>
class MutableVectorData : public std::vector<T>, public TensorData {
public:
// Derive constructors from the base vector class.
Expand Down
6 changes: 5 additions & 1 deletion tensorflow_federated/cc/core/impl/aggregation/testing/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -81,5 +81,9 @@ cc_library(
testonly = True,
hdrs = ["test_data.h"],
visibility = ["//visibility:public"],
deps = ["//tensorflow_federated/cc/core/impl/aggregation/core:tensor"],
deps = [
"//tensorflow_federated/cc/core/impl/aggregation/core:mutable_string_data",
"//tensorflow_federated/cc/core/impl/aggregation/core:tensor",
"@com_google_absl//absl/strings:string_view",
],
)
29 changes: 26 additions & 3 deletions tensorflow_federated/cc/core/impl/aggregation/testing/test_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,41 @@

#include <initializer_list>
#include <memory>
#include <string>
#include <utility>

#include "absl/strings/string_view.h"
#include "tensorflow_federated/cc/core/impl/aggregation/core/mutable_string_data.h"
#include "tensorflow_federated/cc/core/impl/aggregation/core/mutable_unowned_string_data.h"
#include "tensorflow_federated/cc/core/impl/aggregation/core/mutable_vector_data.h"
#include "tensorflow_federated/cc/core/impl/aggregation/core/tensor_data.h"

namespace tensorflow_federated::aggregation {

// Creates test tensor data based on a vector<T>.
// Creates test tensor data based on a vector<T> for all arithmetic types.
template <typename T>
std::unique_ptr<MutableVectorData<T>> CreateTestData(
std::initializer_list<T> values) {
std::unique_ptr<TensorData> CreateTestData(std::initializer_list<T> values) {
return std::make_unique<MutableVectorData<T>>(values);
}

// Creates test tensor data based on a vector<absl::string_view>. The values
// are handed to MutableUnownedStringData, which stores only the views, so the
// character data they refer to must outlive the returned object.
//
// Declared inline because an explicit specialization of a function template is
// an ordinary function definition: without `inline`, including this header in
// more than one translation unit would define it more than once.
template <>
inline std::unique_ptr<TensorData> CreateTestData(
    std::initializer_list<absl::string_view> values) {
  return std::make_unique<MutableUnownedStringData>(values);
}

// Creates test tensor data based on a vector<std::string>. Each value is
// copied and the copy is handed to MutableStringData via Add().
//
// Declared inline because an explicit specialization of a function template is
// an ordinary function definition: without `inline`, including this header in
// more than one translation unit would define it more than once.
template <>
inline std::unique_ptr<TensorData> CreateTestData(
    std::initializer_list<std::string> values) {
  // NOTE(review): the constructor argument is presumably the expected number
  // of strings - confirm against MutableStringData.
  auto data = std::make_unique<MutableStringData>(values.size());
  // initializer_list elements are const, so take a copy of each string and
  // then move that copy into the container.
  for (std::string value : values) {
    data->Add(std::move(value));
  }
  return data;
}

} // namespace tensorflow_federated::aggregation

#endif // THIRD_PARTY_TENSORFLOW_FEDERATED_CC_CORE_IMPL_AGGREGATION_TESTING_TEST_DATA_H_