Skip to content

Commit

Permalink
Clean up Python bindings and improve code formatting in wtsne.hpp
Browse files Browse the repository at this point in the history
  • Loading branch information
absternator committed Feb 25, 2025
1 parent 0243945 commit fce872b
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
3 changes: 1 addition & 2 deletions src/python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
#include "sound.hpp"
#include "wtsne.hpp"

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace py = pybind11;

PYBIND11_MODULE(SCE, m) {
m.doc() = "Stochastic cluster embedding";
m.attr("version") = VERSION_INFO;
Expand Down
11 changes: 5 additions & 6 deletions src/wtsne.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,11 @@ std::vector<double> conditional_probabilities(const std::vector<uint64_t> &I,
real_t beta_min = -std::numeric_limits<real_t>::max();
real_t beta_max = std::numeric_limits<real_t>::max();
real_t beta = 1.0;

real_t sum_Pi, sum_disti_Pi, entropy, entropy_diff;
for (int l = 0; l < n_steps; ++l) {
sum_Pi = 0.0;
for (uint64_t j = row_start_idx[sample_idx]; j < row_start_idx[sample_idx + 1]; ++j)
{
for (uint64_t j = row_start_idx[sample_idx];
j < row_start_idx[sample_idx + 1]; ++j) {
P[j] = std::exp(-dists[j] * beta);
sum_Pi += P[j];
}
Expand All @@ -91,8 +90,8 @@ std::vector<double> conditional_probabilities(const std::vector<uint64_t> &I,
}
sum_disti_Pi = 0.0;

for (uint64_t j = row_start_idx[sample_idx]; j < row_start_idx[sample_idx + 1]; ++j)
{
for (uint64_t j = row_start_idx[sample_idx];
j < row_start_idx[sample_idx + 1]; ++j) {
P[j] /= sum_Pi;
sum_disti_Pi += dists[j] * P[j];
}
Expand Down Expand Up @@ -148,7 +147,7 @@ wtsne_init(const std::vector<uint64_t> &I, const std::vector<uint64_t> &J,
// Preprocess distances
std::vector<double> P =
conditional_probabilities<real_t>(I, J, dists, nn, perplexity, n_threads);

// Normalise distances and weights
normalise_vector(P, true, n_threads);
normalise_vector(weights, true, n_threads);
Expand Down

0 comments on commit fce872b

Please sign in to comment.