Test Script Issues #15

Open
djbyrne opened this issue Mar 18, 2024 · 0 comments

djbyrne commented Mar 18, 2024

Hi Hao,

First off, a big thank you for the huge amount of work that has gone into open-sourcing the implementation of your research; it is highly appreciated!

While going through the repo and trying to understand the method in depth, I discovered some issues with the test script.

  1. The test script does not appear to be running different attention methods; it only ever compares against the default setting. My initial impression from the code was that setting 'attention_label' would update the config and run the attention mechanism associated with that label (i.e. standard, ring_blockwise, etc.). However, on closer inspection this no longer seems to do anything: the method always runs according to what is defined in the base config via the scan_attention, scan_mlp, scan_layers and mesh_dim arguments. To actually compare methods, you have to update the config at each iteration:
```python
import copy

# llama_config, attention_types, FLAGS, FlaxLLaMAForCausalLMModule and
# get_float_dtype_by_name are defined/imported elsewhere in the test script.
models = []
for attention_type in attention_types:
    llama_config_copy = copy.deepcopy(llama_config)
    llama_config_copy.update(dict(attention_type=attention_type))
    if attention_type == 'standard':  # compare against the string, not a list
        # Plain attention: no scanning and no rematerialization.
        llama_config_copy.update(dict(
            scan_attention=False, scan_mlp=False, scan_layers=False,
            remat_attention='', remat_mlp='', mesh_dim='1,-1,2,1',
        ))
    elif attention_type == 'ring_blockwise':
        # Blockwise ring attention with 1024-sized query/key/MLP chunks.
        llama_config_copy.update(dict(
            scan_attention=True, scan_mlp=True, scan_layers=True,
            mesh_dim='1,1,2,-1',
            scan_query_chunk_size=1024, scan_key_chunk_size=1024,
            scan_mlp_chunk_size=1024,
        ))
    model = FlaxLLaMAForCausalLMModule(
        llama_config_copy, dtype=get_float_dtype_by_name(FLAGS.dtype)
    )
    models.append(model)
model = models[0]
```
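A table-driven variant of the loop above keeps each method's overrides in one place and fails loudly on an unknown attention type instead of silently falling through to the base config. This is a minimal sketch: METHOD_OVERRIDES and config_for are hypothetical names (not part of the repo), the override values are copied from the loop above, and the config is assumed to support update(dict) the way the test script's config object is used.

```python
import copy

# Hypothetical table of per-method config overrides (values copied from the
# loop above). Keys are attention_type strings; values go to config.update().
METHOD_OVERRIDES = {
    'standard': dict(
        scan_attention=False, scan_mlp=False, scan_layers=False,
        remat_attention='', remat_mlp='', mesh_dim='1,-1,2,1',
    ),
    'ring_blockwise': dict(
        scan_attention=True, scan_mlp=True, scan_layers=True,
        mesh_dim='1,1,2,-1', scan_query_chunk_size=1024,
        scan_key_chunk_size=1024, scan_mlp_chunk_size=1024,
    ),
}


def config_for(base_config, attention_type):
    """Return a deep copy of base_config with attention_type's overrides.

    base_config only needs an update(dict) method; a plain dict works here,
    and the test script's config object is assumed to behave the same way.
    An unknown attention_type raises KeyError rather than silently keeping
    the base settings. base_config itself is never modified.
    """
    overrides = METHOD_OVERRIDES[attention_type]
    config = copy.deepcopy(base_config)
    config.update(dict(attention_type=attention_type, **overrides))
    return config
```

Inside the test loop, the whole if/elif chain would then collapse to `llama_config_copy = config_for(llama_config, attention_type)`.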
  2. It appears that it isn't possible to change mesh_dim either, because the mesh is defined once at the start of the test and used as a context manager for the whole test. So I don't think we can switch between ring and blockwise during the test.

  3. It doesn't look like the grads being returned are a 'FrozenDict', so the unfreeze at line 163 is not needed (I think it's fine that they're not frozen in this case).

  4. After applying my naive updates to compare Standard with Ring, I am now seeing a larger diff in the logits and grads than expected:

```
standard
logits: 0.0 1.6717689 1.6717689
grads: 0.0 0.11031877 0.11031877

ring_blockwise
logits: 0.0044222176 1.6717689 1.6717689
grads: 6.278977e-05 0.11030923 0.11031877
```
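Regarding the mesh_dim point above: the two methods ask for different device meshes, so a single mesh created at the start of the test cannot serve both. The minimal sketch below shows this, assuming mesh_dim is a comma-separated shape in which one -1 entry is inferred from the device count (reshape-style semantics). resolve_mesh_shape is a hypothetical helper, not a function from the repo.

```python
def resolve_mesh_shape(mesh_dim, device_count):
    """Turn a mesh_dim string such as '1,-1,2,1' into a concrete shape.

    Assumes reshape-style semantics: at most one entry may be -1, and it is
    inferred so that the product of all entries equals device_count.
    """
    dims = [int(d) for d in mesh_dim.split(',')]
    if dims.count(-1) > 1:
        raise ValueError(f'at most one -1 allowed in {mesh_dim!r}')
    known = 1  # product of the explicitly given dimensions
    for d in dims:
        if d != -1:
            known *= d
    if -1 in dims:
        if known == 0 or device_count % known:
            raise ValueError(f'{mesh_dim!r} cannot split {device_count} devices')
        dims[dims.index(-1)] = device_count // known
    elif known != device_count:
        raise ValueError(f'{mesh_dim!r} needs {known} devices, got {device_count}')
    return tuple(dims)
```

With the 4 devices of a TPU v4-8, the standard setting '1,-1,2,1' resolves to (1, 2, 2, 1) while the ring_blockwise setting '1,1,2,-1' resolves to (1, 1, 2, 2). Under these assumptions, each method would need its own mesh, created and entered inside the loop rather than once for the whole test.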

Is this similar to your own results, or should the results be more closely aligned with Standard Attention? My understanding is that Blockwise Ring Attention is numerically equivalent. Could you please confirm whether my configs are correct for comparing these methods? There is a good chance I have made a mistake somewhere. For reference, I am running on a TPU v4-8, so I only have 4 devices.
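To illustrate what "numerically equivalent" means here: blockwise attention with an online (running-max) softmax computes exactly the same function as standard attention, so the two can differ only through floating-point rounding from a different summation order. The sketch below is a minimal pure-Python illustration for a single query. It is not the repo's JAX implementation, the function names and block_size are illustrative, and it does not model ring attention's distribution of key/value blocks across devices.

```python
import math


def standard_attention(q, keys, values):
    """Softmax(q . k_j / sqrt(d))-weighted sum of values, for one query q."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for a stable softmax
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(dim)]


def blockwise_attention(q, keys, values, block_size):
    """Same result, visiting keys/values one block at a time.

    Keeps a running max m, running normalizer l and unnormalized output acc,
    rescaling l and acc whenever a new block raises the max (online softmax).
    """
    scale = 1.0 / math.sqrt(len(q))
    dim = len(values[0])
    m, l, acc = -math.inf, 0.0, [0.0] * dim
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
        m_new = max(m, max(scores))
        correction = math.exp(m - m_new)  # exp(-inf) == 0.0 on the first block
        weights = [math.exp(s - m_new) for s in scores]
        l = l * correction + sum(weights)
        acc = [a * correction + sum(w * v[i] for w, v in zip(weights, v_blk))
               for i, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]
```

With random float64 inputs the two outputs typically agree to around 1e-15; in lower-precision dtypes the rounding gap between the two summation orders is correspondingly larger, so some difference is expected even when both implementations are correct.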

I would like to confirm whether you agree with these observations, or whether I have just done something silly when applying my changes. If these are in fact issues that have crept in, I am happy to submit a fix 😃

Cheers,

Donal
