diff --git a/tensorflow_federated/cc/core/impl/aggregation/core/BUILD b/tensorflow_federated/cc/core/impl/aggregation/core/BUILD index f16873195..167694e04 100644 --- a/tensorflow_federated/cc/core/impl/aggregation/core/BUILD +++ b/tensorflow_federated/cc/core/impl/aggregation/core/BUILD @@ -70,6 +70,7 @@ cc_library( "agg_vector_iterator.h", "datatype.h", "input_tensor_list.h", + "mutable_unowned_string_data.h", "mutable_vector_data.h", "tensor.h", "tensor_data.h", @@ -566,6 +567,17 @@ cc_test( ], ) +cc_test( + name = "mutable_unowned_string_data_test", + srcs = ["mutable_unowned_string_data_test.cc"], + deps = [ + ":tensor", + "//tensorflow_federated/cc/testing:oss_test_main", + "//tensorflow_federated/cc/testing:status_matchers", + "@com_google_absl//absl/strings:string_view", + ], +) + cc_test( name = "vector_string_data_test", srcs = [ diff --git a/tensorflow_federated/cc/core/impl/aggregation/core/mutable_unowned_string_data.h b/tensorflow_federated/cc/core/impl/aggregation/core/mutable_unowned_string_data.h new file mode 100644 index 000000000..266cd113e --- /dev/null +++ b/tensorflow_federated/cc/core/impl/aggregation/core/mutable_unowned_string_data.h @@ -0,0 +1,72 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TENSORFLOW_FEDERATED_CC_CORE_IMPL_AGGREGATION_CORE_MUTABLE_UNOWNED_STRING_DATA_H_ +#define TENSORFLOW_FEDERATED_CC_CORE_IMPL_AGGREGATION_CORE_MUTABLE_UNOWNED_STRING_DATA_H_ + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow_federated/cc/core/impl/aggregation/core/tensor_data.h" + +namespace tensorflow_federated { +namespace aggregation { + +// MutableUnownedStringData implements TensorData by wrapping std::vector and +// using it as backing storage for string_view objects. MutableUnownedStringData +// can be mutated using std::vector methods. The MutableUnownedStringData object +// does not own the string values. Use MutableStringData instead if you want +// a TensorData object that owns the strings. +class MutableUnownedStringData : public std::vector, + public TensorData { + public: + // Derive constructors from the base vector class. + using std::vector::vector; + + ~MutableUnownedStringData() override = default; + + // Implementation of the base class methods. + size_t byte_size() const override { + return this->size() * sizeof(absl::string_view); + } + const void* data() const override { + return this->std::vector::data(); + } + + // Copy the MutableUnownedStringData into a string. + std::string EncodeContent() { + return std::string(reinterpret_cast(this->data()), + this->byte_size()); + } + + // Create and return a new MutableUnownedStringData populated with the data + // from content. + static std::unique_ptr CreateFromEncodedContent( + const std::string& content) { + const absl::string_view* data = + reinterpret_cast(content.data()); + return std::make_unique( + data, data + content.size() / sizeof(absl::string_view)); + } +}; + +} // namespace aggregation +} // namespace tensorflow_federated + +#endif // TENSORFLOW_FEDERATED_CC_CORE_IMPL_AGGREGATION_CORE_MUTABLE_UNOWNED_STRING_DATA_H_ diff --git a/tensorflow_federated/cc/core/impl/aggregation/core/mutable_unowned_string_data_test.cc b/tensorflow_federated/cc/core/impl/aggregation/core/mutable_unowned_string_data_test.cc new file mode 100644 index 000000000..b97afae84 --- /dev/null +++ b/tensorflow_federated/cc/core/impl/aggregation/core/mutable_unowned_string_data_test.cc @@ -0,0 +1,62 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorflow_federated/cc/core/impl/aggregation/core/mutable_unowned_string_data.h" + +#include +#include + +#include "googlemock/include/gmock/gmock.h" +#include "googletest/include/gtest/gtest.h" +#include "absl/strings/string_view.h" +#include "tensorflow_federated/cc/testing/status_matchers.h" + +namespace tensorflow_federated { +namespace aggregation { +namespace { + +TEST(MutableUnownedStringDataTest, MutableUnownedStringDataValid) { + std::string string_1 = "foo"; + std::string string_2 = "bar"; + std::string string_3 = "baz"; + MutableUnownedStringData vector_data; + vector_data.push_back(absl::string_view(string_1)); + vector_data.push_back(absl::string_view(string_2)); + vector_data.push_back(absl::string_view(string_3)); + EXPECT_THAT(vector_data.CheckValid(), IsOk()); +} + +TEST(MutableUnownedStringDataTest, EncodeDecodeSucceeds) { + std::string string_1 = "foo"; + std::string string_2 = "bar"; + std::string string_3 = "baz"; + MutableUnownedStringData vector_data; + vector_data.push_back(absl::string_view(string_1)); + vector_data.push_back(absl::string_view(string_2)); + vector_data.push_back(absl::string_view(string_3)); + std::string encoded_vector_data = vector_data.EncodeContent(); + EXPECT_THAT(vector_data.CheckValid(), IsOk()); + auto decoded_vector_data = + MutableUnownedStringData::CreateFromEncodedContent(encoded_vector_data); + EXPECT_THAT(decoded_vector_data->CheckValid(), IsOk()); + EXPECT_EQ((*decoded_vector_data)[0], string_1); + EXPECT_EQ((*decoded_vector_data)[1], string_2); + EXPECT_EQ((*decoded_vector_data)[2], string_3); +} + +} // namespace +} // namespace aggregation +} // namespace tensorflow_federated diff --git a/tensorflow_federated/cc/core/impl/aggregation/core/mutable_vector_data.h b/tensorflow_federated/cc/core/impl/aggregation/core/mutable_vector_data.h index 9fec2ece4..ef68a76ae 100644 --- a/tensorflow_federated/cc/core/impl/aggregation/core/mutable_vector_data.h +++ b/tensorflow_federated/cc/core/impl/aggregation/core/mutable_vector_data.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include "tensorflow_federated/cc/core/impl/aggregation/core/tensor_data.h" @@ -30,7 +31,7 @@ namespace aggregation { // MutableVectorData implements TensorData by wrapping std::vector and using it // as a backing storage. MutableVectorData can be mutated using std::vector // methods. -template +template , int> = 0> class MutableVectorData : public std::vector, public TensorData { public: // Derive constructors from the base vector class. diff --git a/tensorflow_federated/cc/core/impl/aggregation/testing/BUILD b/tensorflow_federated/cc/core/impl/aggregation/testing/BUILD index 324986b3e..980755718 100644 --- a/tensorflow_federated/cc/core/impl/aggregation/testing/BUILD +++ b/tensorflow_federated/cc/core/impl/aggregation/testing/BUILD @@ -81,5 +81,9 @@ cc_library( testonly = True, hdrs = ["test_data.h"], visibility = ["//visibility:public"], - deps = ["//tensorflow_federated/cc/core/impl/aggregation/core:tensor"], + deps = [ + "//tensorflow_federated/cc/core/impl/aggregation/core:mutable_string_data", + "//tensorflow_federated/cc/core/impl/aggregation/core:tensor", + "@com_google_absl//absl/strings:string_view", + ], ) diff --git a/tensorflow_federated/cc/core/impl/aggregation/testing/test_data.h b/tensorflow_federated/cc/core/impl/aggregation/testing/test_data.h index 098b0f8f2..e6e3e2f57 100644 --- a/tensorflow_federated/cc/core/impl/aggregation/testing/test_data.h +++ b/tensorflow_federated/cc/core/impl/aggregation/testing/test_data.h @@ -19,18 +19,41 @@ #include #include +#include +#include +#include "absl/strings/string_view.h" +#include "tensorflow_federated/cc/core/impl/aggregation/core/mutable_string_data.h" +#include "tensorflow_federated/cc/core/impl/aggregation/core/mutable_unowned_string_data.h" #include "tensorflow_federated/cc/core/impl/aggregation/core/mutable_vector_data.h" +#include "tensorflow_federated/cc/core/impl/aggregation/core/tensor_data.h" namespace tensorflow_federated::aggregation { -// Creates test tensor data based on a vector. +// Creates test tensor data based on a vector for all arithmetic types. template -std::unique_ptr> CreateTestData( - std::initializer_list values) { +std::unique_ptr CreateTestData(std::initializer_list values) { return std::make_unique>(values); } +// Creates test tensor data based on a vector. +template <> +std::unique_ptr CreateTestData( + std::initializer_list values) { + return std::make_unique(values); +} + +// Creates test tensor data based on a vector. +template <> +std::unique_ptr CreateTestData( + std::initializer_list values) { + auto data = std::make_unique(values.size()); + for (auto value : values) { + data->Add(std::move(value)); + } + return data; +} + } // namespace tensorflow_federated::aggregation #endif // THIRD_PARTY_TENSORFLOW_FEDERATED_CC_CORE_IMPL_AGGREGATION_TESTING_TEST_DATA_H_