Update jepa #52

Open: wants to merge 5 commits into base: main
8 changes: 8 additions & 0 deletions .gitignore
@@ -1,2 +1,10 @@
.*.swp
*.pyc
*.tar

bin/
dist/
.vscode/
logs/

jepa_src/jepa.egg-info/
434 changes: 56 additions & 378 deletions README.md

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions app/main.py
@@ -13,7 +13,7 @@
import yaml

from app.scaffold import main as app_main
from src.utils.distributed import init_distributed
from jepa_src.utils.distributed import init_distributed

parser = argparse.ArgumentParser()
parser.add_argument(
@@ -30,7 +30,7 @@ def process_main(rank, fname, world_size, devices):
os.environ['CUDA_VISIBLE_DEVICES'] = str(devices[rank].split(':')[-1])

import logging
from src.utils.logging import get_logger
from jepa_src.utils.logging import get_logger
logger = get_logger(force=True)
if rank == 0:
logger.setLevel(logging.INFO)
2 changes: 1 addition & 1 deletion app/main_distributed.py
@@ -13,7 +13,7 @@
import submitit

from app.scaffold import main as app_main
from src.utils.logging import get_logger
from jepa_src.utils.logging import get_logger

logger = get_logger(force=True)

16 changes: 8 additions & 8 deletions app/vjepa/train.py
@@ -26,19 +26,19 @@
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel

from src.datasets.data_manager import init_data
from src.masks.random_tube import MaskCollator as TubeMaskCollator
from src.masks.multiblock3d import MaskCollator as MB3DMaskCollator
from src.masks.utils import apply_masks
from src.utils.distributed import init_distributed, AllReduce
from src.utils.logging import (
from jepa_src.datasets.data_manager import init_data
from jepa_src.masks.random_tube import MaskCollator as TubeMaskCollator
from jepa_src.masks.multiblock3d import MaskCollator as MB3DMaskCollator
from jepa_src.masks.utils import apply_masks
from jepa_src.utils.distributed import init_distributed, AllReduce
from jepa_src.utils.logging import (
CSVLogger,
gpu_timer,
get_logger,
grad_logger,
adamw_logger,
AverageMeter)
from src.utils.tensors import repeat_interleave_batch
from jepa_src.utils.tensors import repeat_interleave_batch

from app.vjepa.utils import (
load_checkpoint,
@@ -77,7 +77,7 @@ def main(args, resume_preempt=False):
skip_batches = cfgs_meta.get('skip_batches', -1)
use_sdpa = cfgs_meta.get('use_sdpa', False)
which_dtype = cfgs_meta.get('dtype')
logger.info(f'{which_dtype=}')
logger.info(f'{which_dtype}')
if which_dtype.lower() == 'bfloat16':
dtype = torch.bfloat16
mixed_precision = True
4 changes: 2 additions & 2 deletions app/vjepa/transforms.py
@@ -8,8 +8,8 @@
import torch
import torchvision.transforms as transforms

import src.datasets.utils.video.transforms as video_transforms
from src.datasets.utils.video.randerase import RandomErasing
import jepa_src.datasets.utils.video.transforms as video_transforms
from jepa_src.datasets.utils.video.randerase import RandomErasing


def make_transforms(
10 changes: 5 additions & 5 deletions app/vjepa/utils.py
@@ -13,13 +13,13 @@

import torch

import src.models.vision_transformer as video_vit
import src.models.predictor as vit_pred
from src.models.utils.multimask import MultiMaskWrapper, PredictorMultiMaskWrapper
from src.utils.schedulers import (
import jepa_src.models.vision_transformer as video_vit
import jepa_src.models.predictor as vit_pred
from jepa_src.models.utils.multimask import MultiMaskWrapper, PredictorMultiMaskWrapper
from jepa_src.utils.schedulers import (
WarmupCosineSchedule,
CosineWDSchedule)
from src.utils.tensors import trunc_normal_
from jepa_src.utils.tensors import trunc_normal_

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger()
Empty file added build/lib/jepa_src/__init__.py
Empty file.
Empty file.
@@ -48,7 +48,7 @@ def init_data(
if (data.lower() == 'imagenet') \
or (data.lower() == 'inat21') \
or (data.lower() == 'places205'):
from src.datasets.image_dataset import make_imagedataset
from jepa_src.datasets.image_dataset import make_imagedataset
dataset, data_loader, dist_sampler = make_imagedataset(
transform=transform,
batch_size=batch_size,
@@ -66,7 +66,7 @@
subset_file=subset_file)

elif data.lower() == 'videodataset':
from src.datasets.video_dataset import make_videodataset
from jepa_src.datasets.video_dataset import make_videodataset
dataset, data_loader, dist_sampler = make_videodataset(
data_paths=root_path,
batch_size=batch_size,
Empty file.
Empty file.
@@ -17,8 +17,8 @@
import torchvision.transforms.functional as F
from torchvision import transforms

import src.datasets.utils.video.functional as FF
from src.datasets.utils.video.randaugment import rand_augment_transform
import jepa_src.datasets.utils.video.functional as FF
from jepa_src.datasets.utils.video.randaugment import rand_augment_transform


_pil_interpolation_to_str = {
@@ -18,7 +18,7 @@

import torch

from src.datasets.utils.weighted_sampler import DistributedWeightedSampler
from jepa_src.datasets.utils.weighted_sampler import DistributedWeightedSampler

_GLOBAL_SEED = 0
logger = getLogger()
@@ -188,15 +188,15 @@ def loadvideo_decord(self, sample):

fname = sample
if not os.path.exists(fname):
warnings.warn(f'video path not found {fname=}')
warnings.warn(f'video path not found {fname}')
return [], None

_fsize = os.path.getsize(fname)
if _fsize < 1 * 1024: # avoid hanging issue
warnings.warn(f'video too short {fname=}')
warnings.warn(f'video too short {fname}')
return [], None
if _fsize > self.filter_long_videos:
warnings.warn(f'skipping long video of size {_fsize=} (bytes)')
warnings.warn(f'skipping long video of size {_fsize} (bytes)')
return [], None

try:
Empty file.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Empty file.
@@ -10,12 +10,12 @@
import torch
import torch.nn as nn

from src.models.utils.modules import (
from jepa_src.models.utils.modules import (
Block,
CrossAttention,
CrossAttentionBlock
)
from src.utils.tensors import trunc_normal_
from jepa_src.utils.tensors import trunc_normal_


class AttentivePooler(nn.Module):
@@ -11,13 +11,13 @@
import torch
import torch.nn as nn

from src.models.utils.modules import Block
from src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed
from src.utils.tensors import (
from jepa_src.models.utils.modules import Block
from jepa_src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed
from jepa_src.utils.tensors import (
trunc_normal_,
repeat_interleave_batch
)
from src.masks.utils import apply_masks
from jepa_src.masks.utils import apply_masks


class VisionTransformerPredictor(nn.Module):
Empty file.
@@ -9,6 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F

import jepa_src.utils.functional as JF

class MLP(nn.Module):
def __init__(
@@ -65,7 +66,7 @@ def forward(self, x, mask=None):

if self.use_sdpa:
with torch.backends.cuda.sdp_kernel():
x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.proj_drop_prob)
x = JF.scaled_dot_product_attention(q, k, v, dropout_p=self.proj_drop_prob)
attn = None
else:
attn = (q @ k.transpose(-2, -1)) * self.scale # [B, num_heads, D, D]
@@ -147,7 +148,7 @@ def forward(self, q, x):

if self.use_sdpa:
with torch.backends.cuda.sdp_kernel():
q = F.scaled_dot_product_attention(q, k, v)
q = JF.scaled_dot_product_attention(q, k, v)
else:
xattn = (q @ k.transpose(-2, -1)) * self.scale
xattn = xattn.softmax(dim=-1) # (batch_size, num_heads, query_len, seq_len)
@@ -11,11 +11,11 @@
import torch
import torch.nn as nn

from src.models.utils.patch_embed import PatchEmbed, PatchEmbed3D
from src.models.utils.modules import Block
from src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed
from src.utils.tensors import trunc_normal_
from src.masks.utils import apply_masks
from jepa_src.models.utils.patch_embed import PatchEmbed, PatchEmbed3D
from jepa_src.models.utils.modules import Block
from jepa_src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed
from jepa_src.utils.tensors import trunc_normal_
from jepa_src.masks.utils import apply_masks


class VisionTransformer(nn.Module):
Empty file.
File renamed without changes.
30 changes: 30 additions & 0 deletions build/lib/jepa_src/utils/functional.py
@@ -0,0 +1,30 @@
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, dropout_p=0.0):
"""
Computes scaled dot product attention.

Args:
q (torch.Tensor): Query tensor of shape (batch_size, num_heads, seq_len_q, head_dim).
k (torch.Tensor): Key tensor of shape (batch_size, num_heads, seq_len_k, head_dim).
v (torch.Tensor): Value tensor of shape (batch_size, num_heads, seq_len_v, head_dim).
dropout_p (float, optional): Dropout probability. Default is 0.0.

Returns:
torch.Tensor: Output tensor of shape (batch_size, num_heads, seq_len_q, head_dim).
"""
# Compute attention scores
attn_scores = torch.matmul(q, k.transpose(-2, -1))
attn_scores = attn_scores / torch.sqrt(torch.tensor(k.size(-1), dtype=torch.float32))

# Apply softmax to attention scores
attn_probs = F.softmax(attn_scores, dim=-1)

# Apply dropout to attention probabilities
attn_probs = F.dropout(attn_probs, p=dropout_p)

# Compute attention output
attn_output = torch.matmul(attn_probs, v)

return attn_output
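
For reference, here is a minimal sanity-check sketch (not part of this diff) verifying that the pure-PyTorch fallback above matches torch.nn.functional.scaled_dot_product_attention when dropout is disabled; the tensor shapes are illustrative assumptions.

# Sanity-check sketch, not part of this diff. Assumes the package is importable
# as jepa_src and uses dropout_p=0.0 so both paths are deterministic and comparable.
import torch
import torch.nn.functional as F
from jepa_src.utils.functional import scaled_dot_product_attention

q = torch.randn(2, 4, 16, 8)  # (batch, heads, seq_len_q, head_dim)
k = torch.randn(2, 4, 16, 8)  # (batch, heads, seq_len_k, head_dim)
v = torch.randn(2, 4, 16, 8)  # (batch, heads, seq_len_v, head_dim)

out_fallback = scaled_dot_product_attention(q, k, v, dropout_p=0.0)
out_fused = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)

# Both paths should agree up to floating-point tolerance.
print(torch.allclose(out_fallback, out_fused, atol=1e-5))  # expected: True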
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
6 changes: 6 additions & 0 deletions build/lib/vjepa_encoder/__init__.py
@@ -0,0 +1,6 @@
from vjepa_encoder.vision_encoder import JepaEncoder

__all__ = [
"JepaEncoder",

]
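
For context (not part of this diff), this __init__.py re-exports the encoder at the package top level, so downstream code can import it directly. A minimal sketch, assuming the vjepa_encoder package has been installed (for example in editable mode); JepaEncoder's constructor and methods are not shown in this diff, so only the import is exercised.

# Minimal import sketch; assumes the vjepa_encoder package is installed.
# JepaEncoder's API is not part of this diff, so nothing is instantiated here.
from vjepa_encoder import JepaEncoder

print(JepaEncoder)  # the class re-exported from vjepa_encoder.vision_encoder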