diff --git a/pomdp_py/representations/distribution/particles.pyx b/pomdp_py/representations/distribution/particles.pyx index 7997dd9..323208f 100644 --- a/pomdp_py/representations/distribution/particles.pyx +++ b/pomdp_py/representations/distribution/particles.pyx @@ -53,6 +53,14 @@ cdef class WeightedParticles(GenerativeDistribution): def frozen(self): return self._frozen + @property + def hist(self): + return self._hist + + @property + def hist_valid(self): + return self._hist_valid + def add(self, particle): """add(self, particle) particle: (value, weight) tuple""" @@ -77,7 +85,7 @@ cdef class WeightedParticles(GenerativeDistribution): def __eq__(self, other): if isinstance(other, WeightedParticles): - return self._hist == other._hist + return self._hist == other.hist return False def __getitem__(self, value): @@ -161,7 +169,7 @@ cdef class WeightedParticles(GenerativeDistribution): Returns a new set of weighted particles with unique values and weights aggregated (taken average). """ - return WeightedParticles.from_histogram(self.get_histogram()) + return WeightedParticles.from_histogram(self.get_histogram(), frozen=self._frozen) cdef class Particles(WeightedParticles):