Skip to content

Commit

Permalink
The result of the JSD calculation can be slightly greater than 1.0 due to…
Browse files Browse the repository at this point in the history
… floating-point error. This CL fixes the bug by capping the result at 1.0.

PiperOrigin-RevId: 692261708
  • Loading branch information
tfx-copybara committed Nov 1, 2024
1 parent 58af9e7 commit 036a880
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
1 change: 1 addition & 0 deletions tensorflow_data_validation/anomalies/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ cc_library(
"@com_github_tensorflow_metadata//tensorflow_metadata/proto/v0:metadata_v0_proto_cc_pb2",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/types:optional",
],
)

Expand Down
13 changes: 9 additions & 4 deletions tensorflow_data_validation/anomalies/metrics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,18 @@ limitations under the License.

#include <algorithm>
#include <cmath>
#include <limits>
#include <iterator>
#include <map>
#include <numeric>
#include <string>
#include <set>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/types/optional.h"
#include "tensorflow_data_validation/anomalies/map_util.h"
#include "tensorflow_data_validation/anomalies/statistics_view.h"
#include "tensorflow_data_validation/anomalies/status_util.h"
#include "tensorflow_metadata/proto/v0/schema.pb.h"
#include "tensorflow_metadata/proto/v0/statistics.pb.h"
Expand Down Expand Up @@ -356,6 +359,8 @@ absl::Status JensenShannonDivergence(Histogram& histogram_1,
KullbackLeiblerDivergence(histogram_2,
average_distribution_histogram)) /
2);
// Floating-point rounding can push the computed value slightly above 1.0, so
// cap the result at 1.0.
result = std::min(result, 1.0);
return absl::OkStatus();
}

Expand Down Expand Up @@ -405,7 +410,7 @@ absl::Status JensenShannonDivergence(const std::map<string, double>& map_1,
kl_sum += b_ele_prob * std::log2(b_ele_prob / m);
}
}
result = kl_sum/2;
result = std::min(kl_sum / 2, 1.0);

return absl::OkStatus();
}
Expand Down

0 comments on commit 036a880

Please sign in to comment.