Skip to content

Commit

Permalink
Update beta sampling code in augment.py
Browse files Browse the repository at this point in the history
The function `_sample_from_beta(alpha, beta, shape)` in `MixupAndCutmix` class, is not having the same functionality as `numpy.random.beta`. So `tfm.vision.augment.MixupAndCutmix._sample_from_beta(0.2, 0.2, tf.shape( tf.range(10000))).numpy()` is also deviating as well. So suggesting the fix keeping `alpha=alpha, beta=1.0` in  `_sample_from_beta`. The reproduced [gist](https://colab.sandbox.google.com/gist/LakshmiKalaKadali/06533824610d6e85ea4aa3c6399819e6/tf_model_13490.ipynb#scrollTo=zSlE-3YDjL91) also attached. 

This PR closes [#13490](#13490)

Thank You
  • Loading branch information
LakshmiKalaKadali authored Jan 30, 2025
1 parent 3d5e05f commit bdbcbaa
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions official/vision/ops/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2697,8 +2697,8 @@ def distort(self, images: tf.Tensor,

@staticmethod
def _sample_from_beta(alpha, beta, shape):
sample_alpha = tf.random.gamma(shape, 1., beta=alpha)
sample_beta = tf.random.gamma(shape, 1., beta=beta)
sample_alpha = tf.random.gamma(shape, alpha, beta=1.0)
sample_beta = tf.random.gamma(shape, alpha, beta=1.0)
return sample_alpha / (sample_alpha + sample_beta)

def _cutmix(self, images: tf.Tensor,
Expand Down

0 comments on commit bdbcbaa

Please sign in to comment.