Refactored absolute imports to relative imports #18

Open · wants to merge 2 commits into main
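
Every hunk below applies the same transformation: imports that name the picotron package absolutely become package-relative, with one leading dot for the current package and two for its parent. A minimal sketch of the before/after, using lines taken from the first two hunks (the comments are editorial additions, not part of the diff); these forms only resolve when the files are imported as part of the package rather than executed as standalone scripts:

# picotron/checkpoint.py -- importing siblings inside the picotron package
# before: from picotron.utils import assert_no_meta_tensors, print
# before: import picotron.process_group_manager as pgm
from .utils import assert_no_meta_tensors, print        # "." = the current package (picotron)
from . import process_group_manager as pgm              # relative spelling of "import ... as pgm"

# picotron/context_parallel/context_parallel.py -- one package level deeper
# before: import picotron.process_group_manager as pgm
# before: from picotron.context_parallel.cp_communications import ContextCommunicate
from .. import process_group_manager as pgm             # ".." = the parent package (picotron)
from .cp_communications import ContextCommunicate       # "." = context_parallel itself
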
Empty file added __init__.py
Empty file.
Empty file added picotron/__init__.py
Empty file.
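
These empty markers (and the matching ones added deeper in the tree) turn each directory into a regular Python package, which is what lets the dotted imports in the hunks below resolve: a leading-dot import only works from a module that is loaded as part of a package, never from a file executed directly as a script. A self-contained sketch of that rule, with made-up names that are not from this repository:

# Build a throwaway package with an empty __init__.py, analogous to the files
# added in this PR, and confirm that a "from . import ..." inside it resolves
# once the directory is importable as a package.
import importlib, os, sys, tempfile

root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "demo_pkg"))
open(os.path.join(root, "demo_pkg", "__init__.py"), "w").close()   # empty package marker
with open(os.path.join(root, "demo_pkg", "sibling.py"), "w") as f:
    f.write("NAME = 'sibling'\n")
with open(os.path.join(root, "demo_pkg", "mod.py"), "w") as f:
    f.write("from . import sibling\n")                              # the relative import under test

sys.path.insert(0, root)
mod = importlib.import_module("demo_pkg.mod")   # loaded as part of the package
print(mod.__package__, mod.sibling.NAME)        # -> demo_pkg sibling
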
6 changes: 3 additions & 3 deletions picotron/checkpoint.py
@@ -7,10 +7,10 @@
 from safetensors import safe_open
 import contextlib
 
-from picotron.utils import assert_no_meta_tensors, print
-import picotron.process_group_manager as pgm
+from .utils import assert_no_meta_tensors, print
+from . import process_group_manager as pgm
 
-from picotron.pipeline_parallel.pipeline_parallel import PipelineParallel
+from .pipeline_parallel.pipeline_parallel import PipelineParallel
 
 @contextlib.contextmanager
 def init_model_with_dematerialized_weights(include_buffers: bool = False):
Empty file.
4 changes: 2 additions & 2 deletions picotron/context_parallel/context_parallel.py
@@ -4,8 +4,8 @@
 import torch.nn.functional as F
 from typing import Any, Optional, Tuple
 
-import picotron.process_group_manager as pgm
-from picotron.context_parallel.cp_communications import ContextCommunicate
+from .. import process_group_manager as pgm
+from .cp_communications import ContextCommunicate
 
 def apply_context_parallel(model):
     os.environ["CONTEXT_PARALLEL"] = "1" if pgm.process_group_manager.cp_world_size > 1 else "0"
2 changes: 1 addition & 1 deletion picotron/context_parallel/cp_communications.py
@@ -3,7 +3,7 @@
 from torch import distributed as dist
 from typing import List
 
-import picotron.process_group_manager as pgm
+from .. import process_group_manager as pgm
 
 STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1"
 
4 changes: 2 additions & 2 deletions picotron/data.py
@@ -5,9 +5,9 @@
 from functools import partial
 from datasets import Features, Sequence, Value, load_dataset
 from transformers import AutoTokenizer
-from picotron.utils import print
+from .utils import print
 
-import picotron.process_group_manager as pgm
+from . import process_group_manager as pgm
 
 class MicroBatchDataLoader(DataLoader):
     def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, device, subset_name=None, split="train", num_samples=None, pin_memory=True):
Empty file.
4 changes: 2 additions & 2 deletions picotron/data_parallel/data_parallel.py
@@ -4,8 +4,8 @@
 from torch import nn
 from torch.autograd import Variable
 
-from picotron.data_parallel.bucket import BucketManager
-import picotron.process_group_manager as pgm
+from .bucket import BucketManager
+from .. import process_group_manager as pgm
 
 class DataParallelNaive(nn.Module):
     """
4 changes: 2 additions & 2 deletions picotron/model.py
@@ -3,11 +3,11 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from picotron.context_parallel import context_parallel
+from .context_parallel import context_parallel
 from flash_attn.flash_attn_interface import flash_attn_func
 from flash_attn.layers.rotary import apply_rotary_emb
 from flash_attn.ops.triton.layer_norm import layer_norm_fn
-import picotron.process_group_manager as pgm
+from . import process_group_manager as pgm
 
 def apply_rotary_pos_emb(x, cos, sin):
     #TODO: Maybe do class RotaryEmbedding(nn.Module) later
Empty file.
4 changes: 2 additions & 2 deletions picotron/pipeline_parallel/pipeline_parallel.py
@@ -2,8 +2,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-import picotron.process_group_manager as pgm
-from picotron.pipeline_parallel.pp_communications import pipeline_communicate, bidirectional_pipeline_communicate
+from .. import process_group_manager as pgm
+from .pp_communications import pipeline_communicate, bidirectional_pipeline_communicate
 
 class PipelineParallel(nn.Module):
     def __init__(self, model, config):
2 changes: 1 addition & 1 deletion picotron/pipeline_parallel/pp_communications.py
@@ -1,7 +1,7 @@
 import os
 import torch
 import torch.distributed as dist
-import picotron.process_group_manager as pgm
+from .. import process_group_manager as pgm
 
 STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1"
 
Empty file.
4 changes: 2 additions & 2 deletions picotron/tensor_parallel/tensor_parallel.py
@@ -3,8 +3,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import picotron.process_group_manager as pgm
-from picotron.tensor_parallel.tp_communications import ReduceFromModelParallelRegion, GatherFromModelParallelRegion, linear_with_all_reduce, linear_with_async_all_reduce
+from .. import process_group_manager as pgm
+from .tp_communications import ReduceFromModelParallelRegion, GatherFromModelParallelRegion, linear_with_all_reduce, linear_with_async_all_reduce
 
 def apply_tensor_parallel(model):
 
2 changes: 1 addition & 1 deletion picotron/tensor_parallel/tp_communications.py
@@ -1,6 +1,6 @@
 import torch.distributed as dist
 import torch
-import picotron.process_group_manager as pgm
+from .. import process_group_manager as pgm
 import torch.nn.functional as F
 
 from typing import Tuple
4 changes: 1 addition & 3 deletions picotron/utils.py
@@ -3,10 +3,8 @@
 import numpy as np
 import builtins
 import fcntl
-
+from . import process_group_manager as pgm
 import huggingface_hub
-
-import picotron.process_group_manager as pgm
 import torch, torch.distributed as dist
 
 def print(*args, is_print_rank=True, **kwargs):
Empty file added tests/__init__.py
Empty file.
23 changes: 11 additions & 12 deletions train.py
@@ -12,18 +12,17 @@
 import torch, torch.distributed as dist
 from torch.optim import AdamW
 from transformers import AutoConfig
-from picotron.context_parallel.context_parallel import apply_context_parallel
-from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel
-import picotron.process_group_manager as pgm
-from picotron.utils import average_loss_across_dp_cp_ranks, set_all_seed, print, to_readable_format, get_mfu, get_num_params
-from picotron.checkpoint import CheckpointManager
-from picotron.checkpoint import init_model_with_dematerialized_weights, init_model_with_materialized_weights
-from picotron.data import MicroBatchDataLoader
-from picotron.process_group_manager import setup_process_group_manager
-from picotron.pipeline_parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
-from picotron.data_parallel.data_parallel import DataParallelBucket
-from picotron.model import Llama
-from picotron.utils import download_model
+from .picotron.context_parallel.context_parallel import apply_context_parallel
+from .picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel
+from .picotron import process_group_manager as pgm
+from .picotron.utils import average_loss_across_dp_cp_ranks, set_all_seed, print, to_readable_format, get_mfu, get_num_params
+from .picotron.checkpoint import CheckpointManager
+from .picotron.checkpoint import init_model_with_dematerialized_weights, init_model_with_materialized_weights
+from .picotron.data import MicroBatchDataLoader
+from .picotron.process_group_manager import setup_process_group_manager
+from .picotron.pipeline_parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
+from .picotron.data_parallel.data_parallel import DataParallelBucket
+from .picotron.model import Llama
 import wandb
 
 def train_step(model, data_loader, device):
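
With train.py itself now using leading-dot imports, and an __init__.py added next to it, the script is meant to be imported as a module of the directory that contains it; started directly as "python train.py", the first import would raise ImportError: attempted relative import with no known parent package. A hedged sketch of the module-style launch, assuming the checkout directory sits on sys.path under the illustrative name repo_root (the real directory name is not specified by this diff):

# Equivalent of running "python -m repo_root.train" from the directory that
# contains the checkout; "repo_root" is an illustrative stand-in, not a name
# taken from this PR.
import runpy

runpy.run_module("repo_root.train", run_name="__main__")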