diff --git a/tensorflow_data_validation/anomalies/BUILD b/tensorflow_data_validation/anomalies/BUILD index 0ec1277..2f2dbba 100644 --- a/tensorflow_data_validation/anomalies/BUILD +++ b/tensorflow_data_validation/anomalies/BUILD @@ -116,6 +116,7 @@ cc_library( "@com_github_tensorflow_metadata//tensorflow_metadata/proto/v0:metadata_v0_proto_cc_pb2", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/types:optional", ], ) diff --git a/tensorflow_data_validation/anomalies/metrics.cc b/tensorflow_data_validation/anomalies/metrics.cc index e93768d..b98d4bd 100644 --- a/tensorflow_data_validation/anomalies/metrics.cc +++ b/tensorflow_data_validation/anomalies/metrics.cc @@ -17,15 +17,18 @@ limitations under the License. #include #include -#include +#include #include -#include -#include +#include +#include +#include #include #include "absl/log/check.h" #include "absl/status/status.h" +#include "absl/types/optional.h" #include "tensorflow_data_validation/anomalies/map_util.h" +#include "tensorflow_data_validation/anomalies/statistics_view.h" #include "tensorflow_data_validation/anomalies/status_util.h" #include "tensorflow_metadata/proto/v0/schema.pb.h" #include "tensorflow_metadata/proto/v0/statistics.pb.h" @@ -356,6 +359,8 @@ absl::Status JensenShannonDivergence(Histogram& histogram_1, KullbackLeiblerDivergence(histogram_2, average_distribution_histogram)) / 2); + // Due to precision limitations, the result will be capped at 1.0. + result = std::min(result, 1.0); return absl::OkStatus(); } @@ -405,7 +410,7 @@ absl::Status JensenShannonDivergence(const std::map& map_1, kl_sum += b_ele_prob * std::log2(b_ele_prob / m); } } - result = kl_sum/2; + result = std::min(kl_sum / 2, 1.0); return absl::OkStatus(); }