-
Notifications
You must be signed in to change notification settings - Fork 80
/
Copy pathprimitives.py
859 lines (704 loc) · 25.1 KB
/
primitives.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import math
import os
from typing import Callable, List, Optional, Tuple
import numpy as np
# `import importlib` alone does not guarantee that the `importlib.util`
# submodule is loaded, so import it explicitly before calling
# `importlib.util.find_spec` below.
import importlib.util

# Optional accelerated backends. Each flag records whether a backend can be
# used; that backend's symbols are imported only when its flag is True.
deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
# Only probe the DeepSpeed4Science sub-package when DeepSpeed itself is
# present: find_spec on a dotted name imports the parent package first.
ds4s_is_installed = (
    deepspeed_is_installed
    and importlib.util.find_spec("deepspeed.ops.deepspeed4science") is not None
)
if deepspeed_is_installed:
    import deepspeed
if ds4s_is_installed:
    from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention

fa_is_installed = importlib.util.find_spec("flash_attn") is not None
if fa_is_installed:
    from flash_attn.bert_padding import unpad_input
    from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func

# NOTE: despite its name, this flag only reflects the LAYERNORM_TYPE
# environment variable; it does not check that the fused kernel is importable.
fastln_is_installed = os.getenv("LAYERNORM_TYPE", None) == "fast_layernorm"
if fastln_is_installed:
    # LayerNorm is a time bottleneck, so we use a custom implementation.
    from protenix.model.layer_norm.layer_norm import FusedLayerNorm
import torch
import torch.nn as nn
from scipy.stats import truncnorm
from protenix.openfold_local.utils.checkpointing import get_checkpoint_fn
from protenix.openfold_local.utils.precision_utils import is_fp16_enabled
from protenix.openfold_local.utils.tensor_utils import (
flatten_final_dims,
permute_final_dims,
)
# Default chunk sizes for low-memory attention (`_lma`), used when callers
# do not supply their own.
DEFAULT_LMA_Q_CHUNK_SIZE = 2**10  # 1024 queries per chunk
DEFAULT_LMA_KV_CHUNK_SIZE = 2**12  # 4096 keys/values per chunk
def _prod(nums):
out = 1
for n in nums:
out = out * n
return out
def _calculate_fan(linear_weight_shape, fan="fan_in"):
fan_out, fan_in = linear_weight_shape
if fan == "fan_in":
f = fan_in
elif fan == "fan_out":
f = fan_out
elif fan == "fan_avg":
f = (fan_in + fan_out) / 2
else:
raise ValueError("Invalid fan option")
return f
def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
    """
    Fill ``weights`` in place with truncated-normal samples.

    Samples come from a normal truncated at +/-2 (in units of the
    untruncated std) and are rescaled so the truncated distribution has
    variance ``scale / max(1, fan)``, where ``fan`` is selected by
    ``_calculate_fan`` from the 2-D weight shape.

    Sampling goes through SciPy's ``truncnorm.rvs`` without a
    ``random_state``, so it draws from NumPy's global RNG rather than
    torch's.
    """
    weight_shape = weights.shape
    fan_value = _calculate_fan(weight_shape, fan)
    lower, upper = -2, 2

    # Divide by the std of the unit truncated normal so that the truncated
    # samples end up with exactly the target standard deviation.
    unit_trunc_std = truncnorm.std(a=lower, b=upper, loc=0, scale=1)
    target_std = math.sqrt(scale / max(1, fan_value)) / unit_trunc_std

    flat_samples = truncnorm.rvs(
        a=lower, b=upper, loc=0, scale=target_std, size=_prod(weight_shape)
    )
    shaped_samples = np.reshape(flat_samples, weight_shape)

    with torch.no_grad():
        weights.copy_(torch.tensor(shaped_samples, device=weights.device))
def lecun_normal_init_(weights):
    """LeCun init: truncated normal with variance 1 / max(1, fan_in), in place."""
    trunc_normal_init_(weights, scale=1.0, fan="fan_in")
def he_normal_init_(weights):
    """He init: truncated normal with variance 2 / max(1, fan_in), in place."""
    trunc_normal_init_(weights, scale=2.0, fan="fan_in")
def glorot_uniform_init_(weights):
    """Glorot/Xavier uniform initialization (gain 1), applied in place."""
    nn.init.xavier_uniform_(weights, gain=1.0)
def final_init_(weights):
    """Zero ``weights`` in place (without recording the op in autograd)."""
    nn.init.zeros_(weights)
@torch.no_grad()
def gating_init_(weights):
    """Zero gating weights in place; Linear pairs this with a bias of 1."""
    weights.zero_()
def normal_init_(weights):
    """Kaiming normal init with linear gain, i.e. std = 1 / sqrt(fan_in), in place."""
    nn.init.kaiming_normal_(weights, a=0, mode="fan_in", nonlinearity="linear")
@torch.no_grad()
def ipa_point_weights_init_(weights):
    """Fill ``weights`` in place with softplus^-1(1) = log(e - 1)."""
    weights.fill_(0.541324854612918)
class Linear(nn.Linear):
    """
    A Linear layer with built-in nonstandard initializations. Called just
    like torch.nn.Linear.

    Implements the initializers in 1.11.4, plus some additional ones found
    in the code.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        bias: bool = True,
        init: str = "default",
        init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
        precision=None,
    ):
        """
        Args:
            in_dim:
                The final dimension of inputs to the layer
            out_dim:
                The final dimension of layer outputs
            bias:
                Whether to learn an additive bias. True by default
            init:
                The initializer to use. Choose from:

                "default": LeCun fan-in truncated normal initialization
                "relu": He initialization w/ truncated normal distribution
                "glorot": Fan-average Glorot uniform initialization
                "gating": Weights=0, Bias=1
                "normal": Normal initialization with std=1/sqrt(fan_in)
                "final": Weights=0, Bias=0

                Overridden by init_fn if the latter is not None.
            init_fn:
                A custom initializer taking weight and bias as inputs.
                Overrides init if not None. Called under torch.no_grad();
                bias is None when bias=False.
            precision:
                Optional dtype. When set, forward() casts the input, weight
                and bias to this dtype, computes the product with CUDA
                autocast disabled, and casts the result back to the input's
                dtype.
        Raises:
            ValueError: if init_fn is None and init is not a known option.
        """
        super(Linear, self).__init__(in_dim, out_dim, bias=bias)

        # Bias starts at zero for every built-in initializer ("gating"
        # overrides it with 1 below).
        if bias:
            with torch.no_grad():
                self.bias.fill_(0)

        with torch.no_grad():
            if init_fn is not None:
                init_fn(self.weight, self.bias)
            else:
                if init == "default":
                    lecun_normal_init_(self.weight)
                elif init == "relu":
                    he_normal_init_(self.weight)
                elif init == "glorot":
                    glorot_uniform_init_(self.weight)
                elif init == "gating":
                    gating_init_(self.weight)
                    if bias:
                        self.bias.fill_(1.0)
                elif init == "normal":
                    normal_init_(self.weight)
                elif init == "final":
                    final_init_(self.weight)
                else:
                    raise ValueError("Invalid init string.")

        self.precision = precision

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Apply the linear map, returning a tensor of the input's dtype on the
        explicit-precision and bfloat16 paths.
        """
        d = input.dtype

        # Explicit precision: compute entirely in self.precision, then cast
        # back. torch.autocast(device_type="cuda") replaces the deprecated
        # torch.cuda.amp.autocast with identical semantics.
        if self.precision is not None:
            with torch.autocast(device_type="cuda", enabled=False):
                bias = (
                    self.bias.to(dtype=self.precision)
                    if self.bias is not None
                    else None
                )
                return nn.functional.linear(
                    input.to(dtype=self.precision),
                    self.weight.to(dtype=self.precision),
                    bias,
                ).to(dtype=d)

        deepspeed_is_initialized = (
            deepspeed_is_installed and deepspeed.comm.comm.is_initialized()
        )
        # bfloat16 input (outside DeepSpeed): keep the matmul in bfloat16 by
        # disabling CUDA autocast and casting the parameters to match.
        if d is torch.bfloat16 and not deepspeed_is_initialized:
            with torch.autocast(device_type="cuda", enabled=False):
                bias = self.bias.to(dtype=d) if self.bias is not None else None
                return nn.functional.linear(input, self.weight.to(dtype=d), bias)

        return nn.functional.linear(input, self.weight, self.bias)
class OpenFoldLayerNorm(nn.Module):
    """
    Layer normalization over the final dimension with a learnable weight and
    bias.

    For bfloat16 inputs (unless DeepSpeed's communication backend is
    initialized) the normalization runs with CUDA autocast disabled and the
    affine parameters cast to bfloat16, so it is not autocast to float32.
    """

    def __init__(self, c_in, eps=1e-5):
        """
        Args:
            c_in:
                Size of the final (normalized) dimension
            eps:
                Value added to the variance for numerical stability
        """
        super(OpenFoldLayerNorm, self).__init__()

        # `normalized_shape` argument for nn.functional.layer_norm.
        self.c_in = (c_in,)
        self.eps = eps

        self.weight = nn.Parameter(torch.ones(c_in))
        self.bias = nn.Parameter(torch.zeros(c_in))

    def forward(self, x):
        d = x.dtype
        deepspeed_is_initialized = (
            deepspeed_is_installed and deepspeed.comm.comm.is_initialized()
        )
        if d is torch.bfloat16 and not deepspeed_is_initialized:
            # torch.autocast(device_type="cuda") replaces the deprecated
            # torch.cuda.amp.autocast (FutureWarning since PyTorch 2.4).
            with torch.autocast(device_type="cuda", enabled=False):
                out = nn.functional.layer_norm(
                    x,
                    self.c_in,
                    self.weight.to(dtype=d),
                    self.bias.to(dtype=d),
                    self.eps,
                )
        else:
            out = nn.functional.layer_norm(
                x,
                self.c_in,
                self.weight,
                self.bias,
                self.eps,
            )

        return out
# Keep the function name for code simplicity
def LayerNorm(c_in, eps: float = 1e-5):
    """
    Build the layer-norm module selected at import time.

    Returns ``FusedLayerNorm(c_in, eps)`` when ``fastln_is_installed`` is True
    (LAYERNORM_TYPE == "fast_layernorm"); otherwise
    ``OpenFoldLayerNorm(c_in, eps)``.
    """
    norm_cls = FusedLayerNorm if fastln_is_installed else OpenFoldLayerNorm
    return norm_cls(c_in, eps)
@torch.jit.ignore
def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Softmax, but without automatic casting to fp32 when the input is of
    type bfloat16.

    For bfloat16 input (unless DeepSpeed's communication backend is
    initialized) the softmax runs with CUDA autocast disabled; otherwise it
    is a plain ``torch.nn.functional.softmax``.
    """
    d = t.dtype
    deepspeed_is_initialized = (
        deepspeed_is_installed and deepspeed.comm.comm.is_initialized()
    )
    if d is torch.bfloat16 and not deepspeed_is_initialized:
        # torch.autocast(device_type="cuda") replaces the deprecated
        # torch.cuda.amp.autocast (FutureWarning since PyTorch 2.4).
        with torch.autocast(device_type="cuda", enabled=False):
            s = torch.nn.functional.softmax(t, dim=dim)
    else:
        s = torch.nn.functional.softmax(t, dim=dim)

    return s
# @torch.jit.script
def _attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    biases: List[torch.Tensor],
) -> torch.Tensor:
    """
    Reference softmax attention with additive biases.

    Args:
        query: [*, H, Q, C_hidden] (any scaling already applied)
        key: [*, H, K, C_hidden]
        value: [*, H, K, C_hidden]
        biases: tensors added in place to the [*, H, Q, K] logits
    Returns:
        [*, H, Q, C_hidden]
    """
    # [*, H, C_hidden, K]
    key_t = permute_final_dims(key, (1, 0))

    # [*, H, Q, K]
    scores = torch.matmul(query, key_t)
    for bias in biases:
        scores += bias

    weights = softmax_no_cast(scores, -1)

    # [*, H, Q, C_hidden]
    return torch.matmul(weights, value)
@torch.jit.ignore
def _attention_chunked_trainable(
query,
key,
value,
biases,
chunk_size,
chunk_dim,
checkpoint,
):
if checkpoint and len(biases) > 2:
raise ValueError("Checkpointed version permits only permits two bias terms")
def _checkpointable_attention(q, k, v, b1, b2):
bs = [b for b in [b1, b2] if b is not None]
a = _attention(q, k, v, bs)
return a
o_chunks = []
checkpoint_fn = get_checkpoint_fn()
count = query.shape[chunk_dim]
for start in range(0, count, chunk_size):
end = start + chunk_size
idx = [slice(None)] * len(query.shape)
idx[chunk_dim] = slice(start, end)
idx_tup = tuple(idx)
q_chunk = query[idx_tup]
k_chunk = key[idx_tup]
v_chunk = value[idx_tup]
def _slice_bias(b):
idx[chunk_dim] = (
slice(start, end) if b.shape[chunk_dim] != 1 else slice(None)
)
return b[tuple(idx)]
if checkpoint:
bias_1_chunk, bias_2_chunk = [
_slice_bias(b) if b is not None else None
for b in (biases + [None, None])[:2]
]
o_chunk = checkpoint_fn(
_checkpointable_attention,
q_chunk,
k_chunk,
v_chunk,
bias_1_chunk,
bias_2_chunk,
)
else:
bias_chunks = [_slice_bias(b) for b in biases]
o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks)
o_chunk = o_chunk.transpose(-2, -3)
o_chunks.append(o_chunk)
o = torch.cat(o_chunks, dim=chunk_dim)
return o
class Attention(nn.Module):
    """
    Standard multi-head attention using AlphaFold's default layer
    initialization. Allows multiple bias vectors.
    """

    def __init__(
        self,
        c_q: int,
        c_k: int,
        c_v: int,
        c_hidden: int,
        no_heads: int,
        gating: bool = True,
    ):
        """
        Args:
            c_q:
                Input dimension of query data
            c_k:
                Input dimension of key data
            c_v:
                Input dimension of value data
            c_hidden:
                Per-head hidden dimension
            no_heads:
                Number of attention heads
            gating:
                Whether the output should be gated using query data
        """
        super(Attention, self).__init__()

        self.c_q = c_q
        self.c_k = c_k
        self.c_v = c_v
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.gating = gating

        # DISCREPANCY: c_hidden is not the per-head channel dimension, as
        # stated in the supplement, but the overall channel dimension.
        # NOTE(review): the projections below output c_hidden * no_heads
        # channels, so in this code c_hidden acts as a per-head size
        # (matching the docstring); confirm whether this comment is stale.

        self.linear_q = Linear(
            self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot"
        )
        self.linear_k = Linear(
            self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot"
        )
        self.linear_v = Linear(
            self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot"
        )
        self.linear_o = Linear(self.c_hidden * self.no_heads, self.c_q, init="final")

        self.linear_g = None
        if self.gating:
            self.linear_g = Linear(
                self.c_q, self.c_hidden * self.no_heads, init="gating"
            )

        self.sigmoid = nn.Sigmoid()

    def _prep_qkv(
        self, q_x: torch.Tensor, kv_x: torch.Tensor, apply_scale: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Project inputs to per-head queries, keys and values.

        Returns q, k, v of shape [*, H, Q/K, C_hidden]; q is divided by
        sqrt(c_hidden) when apply_scale is True.
        """
        # [*, Q/K/V, H * C_hidden]
        q = self.linear_q(q_x)
        k = self.linear_k(kv_x)
        v = self.linear_v(kv_x)

        # [*, Q/K, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))
        k = k.view(k.shape[:-1] + (self.no_heads, -1))
        v = v.view(v.shape[:-1] + (self.no_heads, -1))

        # [*, H, Q/K, C_hidden]
        q = q.transpose(-2, -3)
        k = k.transpose(-2, -3)
        v = v.transpose(-2, -3)

        if apply_scale:
            q /= math.sqrt(self.c_hidden)

        return q, k, v

    def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor:
        """
        Optionally gate o ([*, Q, H, C_hidden]) with sigmoid(linear_g(q_x)),
        merge heads, and project back to [*, Q, C_q].
        """
        if self.linear_g is not None:
            g = self.sigmoid(self.linear_g(q_x))

            # [*, Q, H, C_hidden]
            g = g.view(g.shape[:-1] + (self.no_heads, -1))
            o = o * g

        # [*, Q, H * C_hidden]
        o = flatten_final_dims(o, 2)

        # [*, Q, C_q]
        o = self.linear_o(o)

        return o

    def forward(
        self,
        q_x: torch.Tensor,
        kv_x: torch.Tensor,
        biases: Optional[List[torch.Tensor]] = None,
        use_memory_efficient_kernel: bool = False,
        use_deepspeed_evo_attention: bool = False,
        use_lma: bool = False,
        lma_q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE,
        lma_kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE,
        use_flash: bool = False,
        flash_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            q_x:
                [*, Q, C_q] query data
            kv_x:
                [*, K, C_k] key data
            biases:
                List of biases that broadcast to [*, H, Q, K]
            use_memory_efficient_kernel:
                Custom memory-efficient attention kernel. Not supported in
                this implementation: requesting it raises
                NotImplementedError, except when fp16 is enabled, in which
                case the flag is ignored and the stock path is used.
            use_deepspeed_evo_attention:
                Whether to use DeepSpeed memory-efficient attention kernel.
                If none of the "use_<...>" flags are True, a stock PyTorch
                implementation is used instead
            use_lma:
                Whether to use low-memory attention (Staats & Rabe 2021). If
                none of the "use_<...>" flags are True, a stock PyTorch
                implementation is used instead
            lma_q_chunk_size:
                Query chunk size (for LMA)
            lma_kv_chunk_size:
                Key/Value chunk size (for LMA)
            use_flash:
                Whether to use FlashAttention. Incompatible with biases;
                use flash_mask for masking
            flash_mask:
                Key mask passed to the FlashAttention path
        Returns:
            [*, Q, C_q] attention update
        Raises:
            ValueError: for inconsistent option combinations.
            NotImplementedError: if use_memory_efficient_kernel is requested
                while fp16 is not enabled.
        """
        if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None):
            raise ValueError(
                "If use_lma is specified, lma_q_chunk_size and "
                "lma_kv_chunk_size must be provided"
            )

        if use_flash and biases is not None:
            raise ValueError(
                "use_flash is incompatible with the bias option. For masking, "
                "use flash_mask instead"
            )

        attn_options = [
            use_memory_efficient_kernel,
            use_deepspeed_evo_attention,
            use_lma,
            use_flash,
        ]
        if sum(attn_options) > 1:
            raise ValueError("Choose at most one alternative attention algorithm")

        if biases is None:
            biases = []

        # The memory-efficient kernel is unavailable here. It is silently
        # dropped under fp16 (falling through to the stock path) and rejected
        # otherwise -- before spending time on the QKV projections.
        if use_memory_efficient_kernel and not is_fp16_enabled():
            raise NotImplementedError(
                "use_memory_efficient_kernel=True not supported!!!"
            )

        # DeepSpeed attention kernel applies scaling internally
        q, k, v = self._prep_qkv(q_x, kv_x, apply_scale=not use_deepspeed_evo_attention)

        if use_deepspeed_evo_attention:
            if len(biases) > 2:
                raise ValueError(
                    "If use_deepspeed_evo_attention is True, you may only "
                    "provide up to two bias terms"
                )
            # Already [*, Q, H, C_hidden]
            o = _deepspeed_evo_attn(q, k, v, biases)
        elif use_lma:
            biases = [
                b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
                for b in biases
            ]
            o = _lma(q, k, v, biases, lma_q_chunk_size, lma_kv_chunk_size)
            o = o.transpose(-2, -3)
        elif use_flash:
            # Already [*, Q, H, C_hidden]
            o = _flash_attn(q, k, v, flash_mask)
        else:
            o = _attention(q, k, v, biases)
            o = o.transpose(-2, -3)

        o = self._wrap_up(o, q_x)

        return o
class GlobalAttention(nn.Module):
    """
    Gated global attention.

    One query per row is formed from the mask-weighted mean of ``m`` over
    dim -2. Keys and values are a single shared ``c_hidden``-wide projection
    of ``m`` (not split per head). Queries, gates and the output projection
    use ``c_hidden * no_heads`` channels.
    """

    def __init__(self, c_in, c_hidden, no_heads, inf, eps):
        """
        Args:
            c_in:
                Channel dimension of the input ``m`` (and of the output)
            c_hidden:
                Per-head hidden dimension
            no_heads:
                Number of attention heads
            inf:
                Magnitude of the additive mask bias; for a 0/1 mask, masked
                keys receive ``-inf`` and unmasked keys 0
            eps:
                Added to the mask count when averaging the query
        """
        super().__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.inf = inf
        self.eps = eps

        all_heads = c_hidden * no_heads

        # Creation order is kept (q, k, v, g, o): each Linear consumes RNG
        # state during initialization.
        self.linear_q = Linear(c_in, all_heads, bias=False, init="glorot")
        self.linear_k = Linear(c_in, c_hidden, bias=False, init="glorot")
        self.linear_v = Linear(c_in, c_hidden, bias=False, init="glorot")
        self.linear_g = Linear(c_in, all_heads, init="gating")
        self.linear_o = Linear(all_heads, c_in, init="final")

        self.sigmoid = nn.Sigmoid()

    def forward(
        self,
        m: torch.Tensor,
        mask: torch.Tensor,
        use_lma: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_res, N_seq, C_in] input
            mask:
                [*, N_res, N_seq] mask (presumably 0/1; confirm at call sites)
            use_lma:
                Use chunked low-memory attention (``_lma``) with the default
                chunk sizes
        Returns:
            [*, N_res, N_seq, C_in] update
        """
        # Masked mean over N_seq -> one query per residue: [*, N_res, C_in]
        masked_sum = torch.sum(m * mask.unsqueeze(-1), dim=-2)
        query = masked_sum / (torch.sum(mask, dim=-1)[..., None] + self.eps)

        # [*, N_res, H * C_hidden], scaled by 1 / sqrt(C_hidden)
        query = self.linear_q(query) * self.c_hidden ** (-0.5)
        # [*, N_res, H, C_hidden]
        query = query.view(query.shape[:-1] + (self.no_heads, -1))

        # Shared single-head keys/values: [*, N_res, N_seq, C_hidden]
        key = self.linear_k(m)
        value = self.linear_v(m)

        # Additive mask bias broadcast over heads: [*, N_res, 1, N_seq]
        mask_bias = (self.inf * (mask - 1))[..., :, None, :]

        if use_lma:
            attended = _lma(
                query,
                key,
                value,
                [mask_bias],
                DEFAULT_LMA_Q_CHUNK_SIZE,
                DEFAULT_LMA_KV_CHUNK_SIZE,
            )
        else:
            # [*, N_res, H, N_seq]
            logits = torch.matmul(query, key.transpose(-1, -2))
            logits += mask_bias
            # [*, N_res, H, C_hidden]
            attended = torch.matmul(softmax_no_cast(logits), value)

        # Per-position gates: [*, N_res, N_seq, H * C_hidden] -> split heads
        gate = self.sigmoid(self.linear_g(m))
        # [*, N_res, N_seq, H, C_hidden]
        gate = gate.view(gate.shape[:-1] + (self.no_heads, -1))

        # Broadcast the per-residue result over N_seq, gate it, then merge
        # heads: [*, N_res, N_seq, H * C_hidden]
        gated = (attended.unsqueeze(-3) * gate).flatten(start_dim=-2)

        # [*, N_res, N_seq, C_in]
        return self.linear_o(gated)
@torch.jit.ignore
def _deepspeed_evo_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    biases: List[torch.Tensor],
):
    """
    Compute attention using the DeepSpeed DS4Sci_EvoformerAttention kernel.

    Args:
        q:
            [*, H, Q, C_hidden] query data (unscaled; the kernel scales)
        k:
            [*, H, K, C_hidden] key data
        v:
            [*, H, V, C_hidden] value data
        biases:
            List of biases that broadcast to [*, H, Q, K]
    Returns:
        [*, Q, H, C_hidden] attention output in q's original dtype
    Raises:
        ValueError: if DeepSpeed4Science is not available.
    """
    if not ds4s_is_installed:
        raise ValueError(
            "_deepspeed_evo_attn requires that DeepSpeed be installed "
            "and that the deepspeed.ops.deepspeed4science package exists"
        )

    def _to_two_batch_dims(t):
        # Pad with leading singleton dims, or flatten everything after the
        # first batch dim, so exactly two dims precede the last three.
        n_batch = len(t.shape[:-3])
        if n_batch < 2:
            return t.reshape(*((1,) * (2 - n_batch) + t.shape))
        if n_batch > 2:
            return t.reshape(*((t.shape[0], -1) + t.shape[-3:]))
        return t

    # [*, Q/K, H, C_hidden]
    q, k, v = (t.transpose(-2, -3) for t in (q, k, v))

    # The kernel expects [B, N, Q/K, H, C_hidden]; remember the original
    # layout so the output can be restored.
    orig_shape = q.shape
    if len(orig_shape[:-3]) != 2:
        q, k, v = (_to_two_batch_dims(t) for t in (q, k, v))
        biases = [_to_two_batch_dims(b) for b in biases]

    # The kernel only accepts bf16/fp16, so other dtypes (e.g. fp32 at
    # inference) are run in bf16 and cast back afterwards.
    orig_dtype = q.dtype
    if orig_dtype in [torch.bfloat16, torch.float16]:
        o = DS4Sci_EvoformerAttention(q, k, v, biases)
    else:
        low = torch.bfloat16
        o = DS4Sci_EvoformerAttention(
            q.to(dtype=low),
            k.to(dtype=low),
            v.to(dtype=low),
            [b.to(dtype=low) for b in biases],
        )
        o = o.to(dtype=orig_dtype)

    return o.reshape(orig_shape)
def _lma(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
biases: List[torch.Tensor],
q_chunk_size: int,
kv_chunk_size: int,
):
no_q, no_kv = q.shape[-2], k.shape[-2]
# [*, H, Q, C_hidden]
o = q.new_zeros(q.shape)
for q_s in range(0, no_q, q_chunk_size):
q_chunk = q[..., q_s : q_s + q_chunk_size, :]
large_bias_chunks = [b[..., q_s : q_s + q_chunk_size, :] for b in biases]
maxes = []
weights = []
values = []
for kv_s in range(0, no_kv, kv_chunk_size):
k_chunk = k[..., kv_s : kv_s + kv_chunk_size, :]
v_chunk = v[..., kv_s : kv_s + kv_chunk_size, :]
small_bias_chunks = [
b[..., kv_s : kv_s + kv_chunk_size] for b in large_bias_chunks
]
a = torch.einsum(
"...hqd,...hkd->...hqk",
q_chunk,
k_chunk,
)
for b in small_bias_chunks:
a += b
max_a = torch.max(a, dim=-1, keepdim=True)[0]
exp_a = torch.exp(a - max_a)
exp_v = torch.einsum("...hvf,...hqv->...hqf", v_chunk, exp_a)
maxes.append(max_a.detach().squeeze(-1))
weights.append(torch.sum(exp_a, dim=-1))
values.append(exp_v)
chunk_max = torch.stack(maxes, dim=-3)
chunk_weights = torch.stack(weights, dim=-3)
chunk_values = torch.stack(values, dim=-4)
global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
max_diffs = torch.exp(chunk_max - global_max)
chunk_values = chunk_values * max_diffs.unsqueeze(-1)
chunk_weights = chunk_weights * max_diffs
all_values = torch.sum(chunk_values, dim=-4)
all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4)
q_chunk_out = all_values / all_weights
o[..., q_s : q_s + q_chunk_size, :] = q_chunk_out
return o
@torch.jit.ignore
def _flash_attn(q, k, v, kv_mask):
    """
    Attention through FlashAttention's unpadded, kv-packed kernel.

    Args:
        q:
            [*, H, N, C] queries, already scaled (softmax_scale=1.0 is used)
        k, v:
            [*, H, K, C] keys / values
        kv_mask:
            Key mask handed to ``unpad_input`` together with kv flattened to
            [B_flat, K, 2 * H * C] (presumably [B_flat, K]; confirm shape).
    Returns:
        [*, N, H, C] in q's original dtype (computation runs in fp16).
    Raises:
        ValueError: if FlashAttention is not installed.
    """
    if not fa_is_installed:
        raise ValueError("_flash_attn requires that FlashAttention be installed")

    batch_dims = q.shape[:-3]
    no_heads, n, c = q.shape[-3:]
    in_dtype = q.dtype

    # fp16 inputs, heads moved after the sequence dim: [*, N, H, C]
    q, k, v = (t.half().transpose(-2, -3) for t in (q, k, v))
    kv_mask = kv_mask.half()

    # Collapse all batch dims: [B_flat, N, H, C]
    q, k, v = (t.reshape(-1, *t.shape[-3:]) for t in (q, k, v))
    flat_batch = q.shape[0]

    # Queries are dense, so every sequence has length n: [B_flat * N, H, C]
    q = q.reshape(-1, *q.shape[-2:])
    q_cu_seqlens = torch.arange(
        0, (flat_batch + 1) * n, step=n, dtype=torch.int32, device=q.device
    )

    # Pack keys and values, drop masked positions, then unpack heads.
    # [B_flat, N, 2, H, C]
    kv = torch.stack([k, v], dim=-3)
    packed_shape = kv.shape
    # [B_flat, N, 2 * H * C]
    kv = kv.reshape(*packed_shape[:-3], -1)
    kv_unpad, _, kv_cu_seqlens, kv_max_s = unpad_input(kv, kv_mask)
    kv_unpad = kv_unpad.reshape(-1, *packed_shape[-3:])

    result = flash_attn_unpadded_kvpacked_func(
        q,
        kv_unpad,
        q_cu_seqlens,
        kv_cu_seqlens,
        n,
        kv_max_s,
        dropout_p=0.0,
        softmax_scale=1.0,  # q has been scaled already
    )

    # [*, N, H, C]
    return result.reshape(*batch_dims, n, no_heads, c).to(dtype=in_dtype)