
[Canary test] GridMask: Vectorized function call on the batch dimension #146

Closed
wants to merge 3 commits into from

Conversation

bhack
Contributor

@bhack bhack commented Feb 22, 2022

As we have discussed in #143 (comment) this is just a canary (failing) test (check the CI):

ValueError: Input "maxval" of op 'RandomUniformInt' expected to be loop invariant.
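
For reference, this class of failure can be reproduced outside GridMask with a small sketch along these lines (the function body and names are illustrative, not the actual test; the per-image value stands in for grid_block):

```python
import tensorflow as tf

def per_image_fn(image):
    # This stands in for grid_block: it depends on the per-image values, so
    # under tf.vectorized_map the resulting maxval is not loop invariant.
    gridblock = tf.cast(tf.reduce_max(image) * 10.0, tf.int32) + 2
    length = tf.random.uniform(
        shape=[], minval=1, maxval=gridblock + 1, dtype=tf.int32)
    return image * tf.cast(length, image.dtype)

images = tf.random.uniform((4, 32, 32, 3))
# With the while_loop fallback disabled, pfor raises:
# ValueError: Input "maxval" of op 'RandomUniformInt' expected to be loop invariant.
out = tf.vectorized_map(per_image_fn, images, fallback_to_while_loop=False)
```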

As I mentioned in that thread, we really need to understand whether we want randomness inside the batch or between batches, and what impact that choice has on computing overhead, contribution speed/code readability, and network convergence.

Also, I don't know if @joker-eph or @qlzh727 could explain a bit the pros and cons of jit_compile on a function vs. using vectorized_map, or whether they are orthogonal.

With many CV transformations we cannot compile the function, as the underlying tf.raw_ops.ImageProjectiveTransformV3 op isn't supported by XLA.

/cc @chjort

@qlzh727
Member

qlzh727 commented Feb 22, 2022

Thanks for bringing up this issue.

  1. For image KPLs, we should have randomness within the batch, not just between batches.
  2. There will be computation overhead if we use tf.map_fn to achieve randomness within the batch (see the toy sketch after this list). We hope this can be reduced by tf.function with jit_compile=True (XLA support). The overhead is the price we pay for code readability, since the user only needs to implement the per-image augmentation and leave the batching/vectorization to the framework.
  3. When the image KPL is used in the data input pipeline, the individual KPL might not be a bottleneck, as long as it is not the slowest component. dataset.prefetch() lets the next batch's preprocessing proceed while the network is doing the forward/backward pass.
  4. I do see a common roadblock for using tf.vectorized_map here, since each per-image preprocess involves some RNG, which usually reads from a stateful random op, and that is very likely to cause tf.vectorized_map to fail.
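
To make the within-batch vs. between-batch distinction concrete, here is a toy brightness sketch (not GridMask and not the KPL API, just an illustration of the two policies):

```python
import tensorflow as tf

def augment_one(image):
    # Independent random brightness per image -> randomness within the batch.
    delta = tf.random.uniform([], -0.2, 0.2)
    return tf.image.adjust_brightness(image, delta)

images = tf.random.uniform((16, 64, 64, 3))

# Per-image randomness, paid for with a per-element loop:
within_batch = tf.map_fn(augment_one, images)

# A single draw applied to the whole batch -> randomness only between batches:
delta = tf.random.uniform([], -0.2, 0.2)
between_batches = tf.image.adjust_brightness(images, delta)
```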

@fchollet and @LukeWood for more inputs as well.

@joker-eph, do you have any suggestion about tf.vectorized_map that reads an RNG within the fn?

@bhack
Contributor Author

bhack commented Feb 22, 2022

  3. When the image KPL is used in the data input pipeline, the individual KPL might not be a bottleneck, as long as it is not the slowest component. dataset.prefetch() lets the next batch's preprocessing proceed while the network is doing the forward/backward pass.

Probably this is true when preprocessing is done on the CPU with enough computational margin, overlapping in parallel with the forward/backward kernels on the GPU stream.
It is less true when you are not handling very large models and, at the same time, you are accumulating many augmentations in your pipeline.

But we need to be performant enough, because if the GPU starts to be starved of data it is sometimes faster to move some preprocessing directly onto the GPU device. E.g. see the rationale of the NVIDIA DALI project.

So we need to be reasonable about performance, and at a minimum we need to be able to rely on jit_compile and/or tf.vectorized_map and have a clear interaction path and support from the TF compiler team when something is not working, consumes too much memory, is too slow, etc.

Other than getting a green light for jit_compile/tf.vectorized_map (see tf.bincount, tf.raw_ops.ImageProjectiveTransformV3, or random handling), I think the main problem is how to settle on the expected performance if we don't have a reference implementation and some reference devices.

More generally, see our endless thread on the forum /cc @jpienaar

  4. I do see a common roadblock for using tf.vectorized_map here, since each per-image preprocess involves some RNG, which usually reads from a stateful random op, and that is very likely to cause tf.vectorized_map to fail.

In the documentation we currently have:

Stateful kernels may mostly not be supported since these often imply a data dependency. We do support a limited set of such stateful kernels though (like RandomFoo, Variable operations like reads, etc).

@bhack
Contributor Author

bhack commented Feb 22, 2022

Just an extra note about 3): see my 2018 thread at NVIDIA/DALI#247 (comment)

@qlzh727
Member

qlzh727 commented Feb 22, 2022

Re 3: I shared a similar concern about performance and code readability for augmenting a single image vs. a batch of images when I worked on keras-team/keras@9628af8.

We could add augment_images(), which updates a batch of images, as the public API, allowing users to do vectorized augmentation when there is native op support.
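
A rough sketch of what such an API could look like (hypothetical layer and method names, with a toy brightness augmentation rather than a real KPL):

```python
import tensorflow as tf

class ToyRandomBrightness(tf.keras.layers.Layer):
    """Hypothetical sketch of a KPL exposing both per-image and batched hooks."""

    def augment_image(self, image):
        # Per-image path: used when no vectorized implementation is available.
        delta = tf.random.uniform([], -0.2, 0.2)
        return tf.image.adjust_brightness(image, delta)

    def augment_images(self, images):
        # Batched path: one independent draw per image, applied in a single op.
        batch_size = tf.shape(images)[0]
        deltas = tf.random.uniform([batch_size, 1, 1, 1], -0.2, 0.2)
        return tf.clip_by_value(images + deltas, 0.0, 1.0)

    def call(self, inputs):
        return self.augment_images(inputs)
```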

@bhack
Contributor Author

bhack commented Feb 22, 2022

We could add augment_images(), which updates a batch of images, as the public API, allowing users to do vectorized augmentation when there is native op support.

What do you mean here? Vectorized with vectorized_map?
Because earlier in this repo we forked vectorized native ops mainly for point 1) (randomness within the batch, not just between batches).

@qlzh727
Member

qlzh727 commented Feb 22, 2022

We could add augment_images(), which updates a batch of images, as the public API, allowing users to do vectorized augmentation when there is native op support.

What do you mean here? Vectorized with vectorized_map? Because earlier in this repo we forked vectorized native ops mainly for point 1) (randomness within the batch, not just between batches).

I mean that if the native vectorized ops support randomness within the batch, then we should use them. Otherwise, vectorized_map over single images should be preferred. If vectorized_map is blocked by some limit (e.g. a stateful op in the fn), we should fall back to tf.map_fn.
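
A rough eager-mode sketch of that preference order (the helper name is mine; the native-vectorized-op branch is omitted since it is op-specific):

```python
import tensorflow as tf

def batch_augment(images, augment_one):
    """Try tf.vectorized_map first; fall back to tf.map_fn if pfor cannot
    vectorize the fn (e.g. because of stateful RNG ops inside it)."""
    try:
        return tf.vectorized_map(
            augment_one, images, fallback_to_while_loop=False)
    except ValueError:
        return tf.map_fn(augment_one, images)
```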

@bhack
Contributor Author

bhack commented Feb 22, 2022

I mean that if the native vectorized ops support randomness within the batch, then we should use them. Otherwise, vectorized_map over single images should be preferred. If vectorized_map is blocked by some limit (e.g. a stateful op in the fn), we should fall back to tf.map_fn.

But this point is exactly where this thread started. Can you just approve the CI run so you can see the failing output?

If we need to vectorize but a "within the batch" randomization is mandatory, we may currently hit a systematic limit with tf.vectorized_map.
That's why I was asking whether some reference paper's ablation study shows a consistent training-convergence advantage for the within-the-batch policy.

At the same time we also have some generic compilation limits with the jit_compile approach for many transformations (like rotation), because of tf.raw_ops.ImageProjectiveTransformV3.

@qlzh727
Member

qlzh727 commented Feb 22, 2022

I don't have numerical results showing that randomness within the batch gives a convergence advantage, but from a math point of view the random image preprocessing shouldn't have any dependency on the batch size. In the extreme case, if you have a very large batch size that is close to the size of your training data, it means your model is overfitting to the specific augmentation within the batch. With distribution strategies and most data-parallel solutions we are using larger batch sizes overall, so it could potentially be a problem.

@bhack
Contributor Author

bhack commented Feb 22, 2022

In how many real cases do we have batch_size = the whole dataset?
Zero-shot/few-shot?
Is it plausible that we need to use augmentation in those two regimes?

Distributed training is a different topic, as we can find a way in TF to randomize per device/node batch.

@bhack
Contributor Author

bhack commented Feb 22, 2022

An alternative approach could be:

https://arxiv.org/abs/1901.09335

Edit: 4.4

Specifically, we achieve this effect by
synchronizing the random seeds of the dataset samplers in
every M nodes (but not the data augmentation seeds).
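
A hedged sketch of how the "different augmentation seed per node" idea from 4.4 could look in TF, using a stateless RNG keyed on the replica id (names and the brightness op are illustrative; assumes it runs inside a replica context, e.g. within strategy.run):

```python
import tensorflow as tf

def per_replica_augment(image, base_seed=1234):
    # Fold the replica id into a stateless seed so each node/replica draws
    # different augmentation randomness, while the dataset sampler seed
    # (not shown here) can stay synchronized.
    replica_id = tf.distribute.get_replica_context().replica_id_in_sync_group
    seed = tf.stack([tf.constant(base_seed, tf.int32),
                     tf.cast(replica_id, tf.int32)])
    delta = tf.random.stateless_uniform([], seed=seed, minval=-0.2, maxval=0.2)
    return tf.image.adjust_brightness(image, delta)
```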

@qlzh727
Member

qlzh727 commented Feb 22, 2022

The batch_size == len(dataset) case is an extreme one, but even without that assumption, with a large enough batch size the regularization and normalization we expect to get from the random augmentation is reduced if the whole batch gets the same random behavior. From a math point of view, you would like your model to learn from different views of the inputs, so that it generalizes and doesn't overfit to specific inputs it has already seen. Having the same randomness within the batch just adds to this problem, and we don't know how large the impact would be. The question to me is whether we want to trade computation performance against accuracy.

From the paper you provided, the need for creating different augmentations within the batch is clear.

@bhack
Contributor Author

bhack commented Feb 22, 2022

If you want a trade-off with large batches, you could augment not per single sample (as with map_fn + random) but a few samples/images with the same random draw, and then repeat with another sub-batch using a new random draw.

So you could vectorize over sub-batches to compose a single batch (a rough sketch follows the link below).

See also:
https://openaccess.thecvf.com/content_CVPR_2020/html/Hoffer_Augment_Your_Batch_Improving_Generalization_Through_Instance_Repetition_CVPR_2020_paper.html
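
A sketch of that sub-batch policy with a toy brightness augmentation (num_subbatches is a knob I'm inventing for illustration; the batch size must be divisible by it here):

```python
import tensorflow as tf

def augment_by_subbatch(images, num_subbatches=4):
    # One random draw shared within each sub-batch, a new draw per sub-batch:
    # a middle ground between per-image and per-batch randomness.
    subbatches = tf.split(images, num_subbatches, axis=0)
    augmented = []
    for sub in subbatches:
        delta = tf.random.uniform([], -0.2, 0.2)
        augmented.append(tf.image.adjust_brightness(sub, delta))
    return tf.concat(augmented, axis=0)
```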

@bhack
Contributor Author

bhack commented Feb 23, 2022

From the paper you provided, the need for creating different augmentations within the batch is clear.

I mentioned that for the distributed setting in 4.4, with a different, non-synchronized random augmentation seed for every node.

@bhack
Contributor Author

bhack commented Feb 23, 2022

Also, we have just merged this commit to the docs/website:

tensorflow/docs@f292e1c#diff-8e0fa0ce3afd5d39976905d770a157069fc2716bb47b449ec191027193656ac6R688

@bhack
Contributor Author

bhack commented Feb 23, 2022

We are almost discouraging the use of Keras in-model preprocessing with distribution strategies.

@bhack
Contributor Author

bhack commented Feb 23, 2022

I don't know if @w-xinyi has any specific feedback on this distributed-training and preprocessing claim that he committed to the docs/website.

@qlzh727
Member

qlzh727 commented Feb 23, 2022

Btw, in this change you are using fallback_to_while_loop=False, which raises an error when ops that cannot be vectorized are encountered. Could you change it to True and see if the error still occurs?

Note that from the API doc:

fallback_to_while_loop | If true, on failing to vectorize an operation, the unsupported op is wrapped in a tf.while_loop to execute the map iterations. Note that this fallback only happens for unsupported ops and other parts of fn are still vectorized. If false, on encountering an unsupported op, a ValueError is thrown. Note that the fallbacks can result in slowdowns since vectorization often yields speedup of one to two orders of magnitude.

Since only the RNG part will be in a tf.while_loop, I guess it will still achieve some performance gain since the main workload is in the image augmentation ops.

@bhack
Contributor Author

bhack commented Feb 23, 2022

Btw, in this change you are using fallback_to_while_loop=False, which raises an error when ops that cannot be vectorized are encountered. Could you change it to True and see if the error still occurs?

Yes, I know, but it was already failing with the fallback as well; I've switched it to true so you can see the CI output.

Beyond that, the doc's claim is a bit ambiguous: who is going to analyze the HLO graph in detail, with or without the fallback, if we don't have a non-fallback reference performance? This will not help us select our batch-assembly policy.

Also, if we are not going to consider instance repetition (duplication) with augmentation, we will also increase the I/O pressure on the filesystem.

@qlzh727
Member

qlzh727 commented Feb 23, 2022

Hmm, it is unclear to me why the intermediate results in the while loop have different shapes and cause the pfor conversion to fail. @wangpengmit and @mdanatg from the TF core team.

@bhack
Contributor Author

bhack commented Feb 24, 2022

Hmm, it is unclear to me why the intermediate results in the while loop have different shapes and cause the pfor conversion to fail. @wangpengmit and @mdanatg from the TF core team.

@qlzh727 Even if the fallback were working correctly, I doubt we would get much of a speed improvement with this approach. If you check the CI output, it is not only the random op itself: many other ops fall back to the while_loop.

WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting RandomUniformInt
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting RandomUniformInt
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting Fill
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting Range
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting Reshape
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting Fill
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting Range
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting Reshape
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting Fill
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting Range
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting Reshape
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting Fill
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting Range
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting Reshape
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting RngReadAndSkip
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting Bitcast
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting Bitcast
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting StatelessRandomUniformV2
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting ImageProjectiveTransformV3

@bhack
Contributor Author

bhack commented Mar 3, 2022

@qlzh727 Do you have any feedback from the XLA team members you mentioned here?

@qlzh727
Member

qlzh727 commented Mar 7, 2022

Sorry for the late reply. We haven't heard any updates from the core team yet on the pfor issue. I think we need to take a closer look at the implementation details of the grid_mask. It might be missing some shape information somewhere, so that each step of the iteration can get a result with a different shape depending on the input value.

@bhack
Contributor Author

bhack commented Mar 7, 2022

It might be missing some shape information somewhere, so that each step of the iteration can get a result with a different shape depending on the input value.

Or this could be caused by our randomization policy. ImageProjectiveTransformV3 in the list is also not covered by XLA.

@wangpengmit

maxval must be a loop invariant (e.g. a constant), so things like

shape=[], minval=1, maxval=gridblock + 1, dtype=tf.int32, seed=self.seed
don't fly in vectorized_map. Try using a constant maxval (e.g. 1.0) and then transforming the result of tf.random.uniform using gridblock.

@bhack
Contributor Author

bhack commented Mar 8, 2022

@wangpengmit Since grid_block is currently quite different for each image in the batch, I think this is really related to our in-batch randomization policy, as we discussed above in this thread.
As a side note, could you clarify whether jit_compile and vectorized_map are orthogonal (see also #165)?
Beyond this, what would be the final effect on performance when we don't have a supported op like ImageProjectiveTransformV3, for both the jit_compile and vectorized_map cases? Is the effect on performance similar?

We have tried to open tickets in the TF core repository when we found something similar, e.g. tensorflow/tensorflow#54479

@wangpengmit

jit_compile and vectorized_map are orthogonal.

Unsupported ops will cause vectorized_map to fall back to tf.while_loop, and jit_compile to error out completely.

For grid_block, how about

u = tf.random.uniform(shape=shape, minval=0.0, maxval=1.0, dtype=tf.float32)
length = tf.cast(u * gridblock + 1, dtype=tf.int32)

?
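
For reference, a self-contained version of that suggestion might look like the helper below (the name sample_length is mine); the cast keeps the result in [1, gridblock], matching the original minval=1, maxval=gridblock + 1 draw:

```python
import tensorflow as tf

def sample_length(gridblock):
    # Keep maxval constant (loop invariant) and scale the uniform sample,
    # instead of passing the per-image gridblock as maxval to the RNG.
    u = tf.random.uniform(shape=[], minval=0.0, maxval=1.0, dtype=tf.float32)
    return tf.cast(u * tf.cast(gridblock, tf.float32) + 1.0, dtype=tf.int32)
```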

@bhack
Contributor Author

bhack commented Mar 9, 2022

We could refactor things to work around the tf.random.uniform limits, but the fallback list is quite long:

WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting RandomUniformInt
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting RandomUniformInt
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting Fill
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting Range
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting Reshape
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting Fill
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting Range
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting Reshape
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting Fill
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting Range
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting Reshape
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting Fill
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting Range
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting Reshape
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting RngReadAndSkip
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting Bitcast
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting Bitcast
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting StatelessRandomUniformV2
WARNING  tensorflow:pfor.py:1081 Using a while_loop for converting ImageProjectiveTransformV3

@wangpengmit

wangpengmit commented Mar 9, 2022

Yeah, this is a vectorized_map quality issue which unfortunately doesn't have a quick fix. For stateless ops (e.g. ImageProjectiveTransformV3) we can gradually add vectorized_map rules for them. For stateful ops (e.g. RandomUniformInt, RngReadAndSkip) there really isn't anything vectorized_map can do other than fallback.

@bhack
Contributor Author

bhack commented Mar 9, 2022

Yeah, this is a vectorized_map quality issue which unfortunately doesn't have a quick fix. For stateless ops (e.g. ImageProjectiveTransformV3) we can gradually add vectorized_map rules for them. For stateful ops (e.g. RandomUniformInt, RngReadAndSkip) there really isn't anything vectorized_map can do other than fallback.

So I suppose we will not get a meaningful performance gain with all these fallbacks.

I have a few extra points:

  • Are the Fill, Range, Reshape, Bitcast, etc. fallbacks caused by vectorized_map's limited coverage (as with ImageProjectiveTransformV3), or are they caused by the random policy in these ops' args?
  • If they are part of a coverage limit, how should we handle these limits? Do you want us to open a ticket in the TF GitHub repository for each individual op?
  • Is there a way to retrieve the list of covered ops? E.g. the XLA list hasn't been updated in years, and I suppose we should assume the coverage is not the same, since they are orthogonal.
  • ImageProjectiveTransformV3 falls back with vectorized_map and fails with jit_compile. Do we need to open two separate tickets?

@wangpengmit

  • Are the Fill, Range, Reshape, Bitcast, etc. fallbacks caused by vectorized_map's limited coverage (as with ImageProjectiveTransformV3), or are they caused by the random policy in these ops' args?

I don't know too much about vectorized_map. These ops seem easy to support. I can't think of how the failure to handle them is related to random policy.

  • If they are part of a coverage limit, how should we handle these limits? Do you want us to open a ticket in the TF GitHub repository for each individual op?

Yes please. Either a bug per op or a single bug for all of them is fine. Please provide a minimal reproduction with just the op of interest in it, to show that vectorized_map can't handle that op.

  • Is there a way to retrieve the list of covered ops? E.g. the XLA list hasn't been updated in years, and I suppose we should assume the coverage is not the same, since they are orthogonal.

I don't know of such a list. It would be a nice thing to have, ideally a matrix of ops and the support they enjoy. Feel free to file a bug for this too.

  • ImageProjectiveTransformV3 falls back with vectorized_map and fails with jit_compile. Do we need to open two separate tickets?

Yes please. jit_compile and vectorized_map are orthogonal.

@bhack
Contributor Author

bhack commented Mar 10, 2022

I don't know too much about vectorized_map. These ops seem easy to support. I can't think of how the failure to handle them is related to random policy.

Do you know who the API/code owner is there? It is really hard to interact with the right owner in TF.

What I think is not helping in the current pfor implementation is that we don't log the main causes of the fallback.
I've added a few log entries in this section and now we have a better analysis of the fallback causes:

Not vectorized variant inputs: Fill
Variant outputs: Range
Not vectorized variant inputs: Reshape
Not vectorized variant inputs: Fill
Variant outputs: Range
Not vectorized variant inputs: Reshape
Not vectorized variant inputs: Fill
Variant outputs: Range
Not vectorized variant inputs: Reshape
Not vectorized variant inputs: Fill
Variant outputs: Range
Not vectorized variant inputs: Reshape
Variant outputs: RngReadAndSkip
Variant outputs: Bitcast
Variant outputs: Bitcast
Variant outputs: StatelessRandomUniformV2
Variant outputs: ImageProjectiveTransformV3

@bhack
Contributor Author

bhack commented Mar 10, 2022

The XLA ticket is at tensorflow/tensorflow#55194

Fallback loop "interplay/improved debugging/hints" at tensorflow/tensorflow#55192

@wangpengmit

Do you know who the API/code owner is there? It is really hard to interact with the right owner in TF.

pfor unfortunately falls into the category of areas that don't have an owner right now. We'll fix the ownership soon. (CCing @rohan100jain)

@bhack
Contributor Author

bhack commented Mar 15, 2022

@wangpengmit I've added this case in tensorflow/community#412.

Do we need to wait for a new Codeowner to review tensorflow/tensorflow#55192 ?

Instead, do we have a Codeowner for tensorflow/tensorflow#55194?

@wangpengmit

Do we need to wait for a new Codeowner to review tensorflow/tensorflow#55192 ?

I've reviewed it.

Instead, do we have a Codeowner for tensorflow/tensorflow#55194?

Added an assignee.

@LukeWood
Contributor

@bhack closing this as we are using BaseImageAugmentationLayer now. If there is a reason to, please feel free to re-open this PR to discuss further.

Thanks for the contribution and thanks for following up with the tensorflow team on the operations supported in vmap.

@wangpengmit

CCing @ishark .
