diff --git a/README.md b/README.md
index cd752bd0..84ccdffd 100644
--- a/README.md
+++ b/README.md
@@ -131,6 +131,13 @@ You can then try to offloading all weights to disk by
 python3 -m flexgen.flex_opt --model facebook/opt-175b --percent 0 0 100 0 100 0 --offload-dir YOUR_SSD_FOLDER
 ```
 
+### CPU and M1/M2 GPU platform
+To run models on a CPU platform, all you need to do is add an `--platform` argument:
+```
+python3 -m flexgen.flex_opt --model facebook/opt-1.3b --platform cpu
+```
+To run on M1/M2 platforms, [PyTorch nightly](https://pytorch.org/) is required for better kernel coverage and performance. Once you have PyTorch nightly installed, simply replace `cpu` with `mps:0`.
+
 ### How to set the offloading strategy and `--percent`?
 We will release an automatic policy optimizer later, but now you have to manually try a few strategies.
 The idea of high-throughput generation is to offload parameters and attention cache as much as possible to the CPU and disk if necessary.
@@ -191,7 +198,7 @@ See [flexgen/apps](flexgen/apps) for more example applications.
 ## Roadmap
 We plan to work on the following features. Community contributions are welcome.
 
-- [ ] Support Apple silicon M1/M2 deployment
+- [x] Support Apple silicon M1/M2 deployment
 - [ ] Support Colab deployment
 - [ ] Support more models (BLOOM, CodeGen, GLM)
 - [ ] Release the cost model and policy optimizer
diff --git a/flexgen/flex_opt.py b/flexgen/flex_opt.py
index 66fc8ced..16216c8a 100644
--- a/flexgen/flex_opt.py
+++ b/flexgen/flex_opt.py
@@ -43,6 +43,8 @@ class Policy:
     act_gpu_percent: float
     act_cpu_percent: float
 
+    only_cpu: bool
+
     # Whether to overlap the I/O and compute
     overlap: bool
 
@@ -101,12 +103,8 @@ def init_weight_list(weight_specs, policy, env):
         home = get_choice(mid_percent * 100, dev_percents, dev_choices)
         shape, dtype, filename = weight_specs[i]
 
-        if len(shape) < 2:
-            pin_memory = True
-            compress = False
-        else:
-            pin_memory = policy.pin_weight
-            compress = policy.compress_weight
+        pin_memory = policy.pin_weight
+        compress = policy.compress_weight
 
         if not compress:
             weight = home.allocate(shape, dtype, pin_memory=pin_memory)
@@ -614,10 +612,11 @@ def __init__(self,
         else:
             raise NotImplementedError()
 
-        # CUDA streams
-        self.load_weight_stream = torch.cuda.Stream()
-        self.load_cache_stream = torch.cuda.Stream()
-        self.store_cache_stream = torch.cuda.Stream()
+        if self.env.gpu.device_type == DeviceType.CUDA:
+            # CUDA streams
+            self.load_weight_stream = torch.cuda.Stream()
+            self.load_cache_stream = torch.cuda.Stream()
+            self.store_cache_stream = torch.cuda.Stream()
 
         # Intermediate tensors
         # The following buffers store values used
@@ -791,7 +790,8 @@ def compute_layer(self, i, j, k):
 
     def sync(self):
         self.env.disk.synchronize()
-        torch.cuda.synchronize()
+        if self.env.gpu.device_type == DeviceType.CUDA:
+            torch.cuda.synchronize()
 
     def init_all_weights(self):
         self.weight_home = array_1d(self.num_layers, ValueHolder)
@@ -1184,15 +1184,17 @@ def run_flexgen(args):
     warmup_inputs = get_test_inputs(32, num_prompts, tokenizer)
     inputs = get_test_inputs(prompt_len, num_prompts, tokenizer)
 
-    gpu = TorchDevice("cuda:0")
+    if args.platform == "cpu":
+        gpu = TorchDevice("cpu")
+    else:
+        gpu = TorchDevice(args.platform)
     cpu = TorchDevice("cpu")
-    disk = TorchDisk(args.offload_dir)
-    env = ExecutionEnv(gpu=gpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, cpu, disk]))
-
+    disk = TorchDisk(args.offload_dir, platform=args.platform)
+    env = ExecutionEnv(gpu=gpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, cpu, disk]), platform=args.platform)
     policy = Policy(args.gpu_batch_size, args.num_gpu_batches,
                     args.percent[0], args.percent[1],
                     args.percent[2], args.percent[3],
-                    args.percent[4], args.percent[5],
+                    args.percent[4], args.percent[5], args.platform == "cpu",
                     args.overlap, args.sep_layer, args.pin_weight,
                     args.cpu_cache_compute, args.attn_sparsity,
                     args.compress_weight,
@@ -1203,7 +1205,8 @@ def run_flexgen(args):
                                         group_dim=2, symmetric=False))
     assert not (args.compress_cache and args.attn_sparsity < 1.0), "Not implemented"
 
-    opt_config = get_opt_config(args.model)
+    # use float32 for CPU platform
+    opt_config = get_opt_config(args.model, dtype=np.float32 if args.platform == "cpu" else np.float16)
     cache_size = opt_config.cache_bytes(num_prompts, prompt_len + gen_len)
     hidden_size = opt_config.hidden_bytes(num_prompts, prompt_len + gen_len)
     print(f"model size: {opt_config.model_bytes()/GB:.3f} GB, "
@@ -1311,6 +1314,7 @@ def add_parser_arguments(parser):
     parser.add_argument("--overlap", type=str2bool, nargs='?',
         const=True, default=True)
+    parser.add_argument("--platform", type=str, default="cuda:0", help="the platform to run on, e.g. cuda:0, mps:0, or cpu")
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -1319,4 +1323,25 @@ def add_parser_arguments(parser):
 
     assert len(args.percent) == 6
 
+    if "cuda" in args.platform:
+        if not torch.cuda.is_available():
+            if torch.backends.mps.is_available():
+                args.platform = "mps:0"
+            else:
+                args.platform = "cpu"
+            print("CUDA devices not available, {} is used instead".format(args.platform))
+
+    if "mps" in args.platform:
+        if not torch.backends.mps.is_available():
+            args.platform = "cpu"
+            print("MPS devices not available, CPU is used instead")
+
+    if "cuda" not in args.platform:
+        # not clear how to enable overlap on MPS platform yet
+        args.overlap = False
+        args.pin_weight = False
+
+    if args.platform == "cpu":
+        args.percent = [0, 100, 0, 100, 0, 100]
+
     run_flexgen(args)
diff --git a/flexgen/opt_config.py b/flexgen/opt_config.py
index 5a8e9496..716229d9 100644
--- a/flexgen/opt_config.py
+++ b/flexgen/opt_config.py
@@ -54,6 +54,8 @@ def get_opt_config(name, **kwargs):
         name = name.split("/")[1]
     name = name.lower()
 
+    dtype = kwargs["dtype"] if "dtype" in kwargs else OptConfig.dtype
+
     # Handle opt-iml-30b and opt-iml-max-30b
     if "-iml-max" in name:
         arch_name = name.replace("-iml-max", "")
@@ -65,54 +67,54 @@
     if arch_name == "opt-125m":
         config = OptConfig(name=name,
             max_seq_len=2048, num_hidden_layers=12, n_head=12,
-            hidden_size=768, input_dim=768, ffn_embed_dim=768 * 4,
+            hidden_size=768, input_dim=768, ffn_embed_dim=768 * 4, dtype=dtype
         )
     elif arch_name == "opt-350m":
         config = OptConfig(name=name,
             max_seq_len=2048, num_hidden_layers=24, n_head=16,
-            hidden_size=1024, input_dim=1024, ffn_embed_dim=1024 * 4,
+            hidden_size=1024, input_dim=1024, ffn_embed_dim=1024 * 4, dtype=dtype
         )
         raise NotImplementedError("Not implemented because this model "
                                   "has a different architecture")
     elif arch_name == "opt-1.3b":
         config = OptConfig(name=name,
             max_seq_len=2048, num_hidden_layers=24, n_head=32,
-            hidden_size=2048, input_dim=2048, ffn_embed_dim=2048 * 4,
+            hidden_size=2048, input_dim=2048, ffn_embed_dim=2048 * 4, dtype=dtype
         )
     elif arch_name == "opt-2.7b":
         config = OptConfig(name=name,
             max_seq_len=2048, num_hidden_layers=32, n_head=32,
-            hidden_size=2560, input_dim=2560, ffn_embed_dim=2560 * 4,
+            hidden_size=2560, input_dim=2560, ffn_embed_dim=2560 * 4, dtype=dtype
         )
     elif arch_name == "opt-6.7b":
         config = OptConfig(name=name,
             max_seq_len=2048, num_hidden_layers=32, n_head=32,
-            hidden_size=4096, input_dim=4096, ffn_embed_dim=4096 * 4,
+            hidden_size=4096, input_dim=4096, ffn_embed_dim=4096 * 4, dtype=dtype
         )
     elif arch_name == "opt-13b":
         config = OptConfig(name=name,
             max_seq_len=2048, num_hidden_layers=40, n_head=40,
-            hidden_size=5120, input_dim=5120, ffn_embed_dim=5120 * 4,
+            hidden_size=5120, input_dim=5120, ffn_embed_dim=5120 * 4, dtype=dtype
         )
     elif arch_name == "opt-30b":
         config = OptConfig(name=name,
             max_seq_len=2048, num_hidden_layers=48, n_head=56,
-            hidden_size=7168, input_dim=7168, ffn_embed_dim=7168 * 4,
+            hidden_size=7168, input_dim=7168, ffn_embed_dim=7168 * 4, dtype=dtype
         )
     elif arch_name == "opt-66b":
         config = OptConfig(name=name,
             max_seq_len=2048, num_hidden_layers=64, n_head=72,
-            hidden_size=9216, input_dim=9216, ffn_embed_dim=9216 * 4,
+            hidden_size=9216, input_dim=9216, ffn_embed_dim=9216 * 4, dtype=dtype
         )
     elif arch_name == "opt-175b":
         config = OptConfig(name=name,
             max_seq_len=2048, num_hidden_layers=96, n_head=96,
-            hidden_size=12288, input_dim=12288, ffn_embed_dim=12288 * 4,
+            hidden_size=12288, input_dim=12288, ffn_embed_dim=12288 * 4, dtype=dtype
         )
     elif arch_name == "opt-175b-stage":
         config = OptConfig(name=name,
             max_seq_len=2048, num_hidden_layers=24, n_head=96,
-            hidden_size=12288, input_dim=12288, ffn_embed_dim=12288 * 4,
+            hidden_size=12288, input_dim=12288, ffn_embed_dim=12288 * 4, dtype=dtype
         )
     else:
         raise ValueError(f"Invalid model name: {name}")
diff --git a/flexgen/pytorch_backend.py b/flexgen/pytorch_backend.py
index 6700972d..482c8ead 100644
--- a/flexgen/pytorch_backend.py
+++ b/flexgen/pytorch_backend.py
@@ -32,6 +32,10 @@ def fix_recursive_import():
 class DeviceType(Enum):
     CPU = auto()
     CUDA = auto()
+    # Metal Performance Shaders (MPS) for Mac platforms
+    MPS = auto()
+    # Heterogeneous-computing Interface for Portability (HIP) for AMD platforms
+    HIP = auto()
     DISK = auto()
     MIXED = auto()
     COMPRESSED = auto()
@@ -42,6 +46,10 @@ def convert(name):
             return DeviceType.CPU
         elif name == "cuda":
             return DeviceType.CUDA
+        elif name == "mps":
+            return DeviceType.MPS
+        elif name == "hip":
+            return DeviceType.HIP
         elif name == "disk":
             return DeviceType.DISK
         elif name == "mixed":
@@ -182,8 +190,9 @@ def add_link(self, link):
         self.links[dst] = link
 
     def allocate(self, shape, dtype, pin_memory=None, name=None):
+        # set default pin_memory to be False
         if self.device_type == DeviceType.CPU:
-            pin_memory = True if pin_memory is None else pin_memory
+            pin_memory = False if pin_memory is None else pin_memory
         else:
             pin_memory = False
         dtype = np_dtype_to_torch_dtype[dtype]
@@ -270,7 +279,13 @@ def opt_output_embed(self, inputs, w_ln, b_ln, w_token, donate,
 
         b, s, h = inputs.shape
 
-        hidden = F.layer_norm(inputs.data, (h,), weight=w_ln.data, bias=b_ln.data)
+        # workaround for PyTorch MPS bug of varianceEps implementation
+        if self.device_type == DeviceType.MPS:
+            hidden = F.layer_norm(inputs.data.type(torch.float32), (h,), weight=w_ln.data.type(torch.float32), bias=b_ln.data.type(torch.float32))
+            hidden = hidden.type(torch.float16)
+        else:
+            hidden = F.layer_norm(inputs.data, (h,), weight=w_ln.data, bias=b_ln.data)
+
         if donate[0]: inputs.delete()
 
         # output embedding
@@ -291,8 +306,9 @@ def init_cache_one_gpu_batch(self, config, task, policy):
         shape = (prompt_len + gen_len - 1, gpu_batch_size * num_head, hidden_size // num_head)
         # NOTE: disable pin_memory due to high memory overhead
         pin_memory = False
-        k_cache = self.allocate(shape, np.float16, pin_memory=pin_memory)
-        v_cache = self.allocate(shape, np.float16, pin_memory=pin_memory)
+
+        k_cache = self.allocate(shape, np.float32 if policy.only_cpu else np.float16, pin_memory=pin_memory)
+        v_cache = self.allocate(shape, np.float32 if policy.only_cpu else np.float16, pin_memory=pin_memory)
         return k_cache, v_cache
 
     def mha(self, inputs, attention_mask, w_q, b_q, w_k, b_k, w_v, b_v,
@@ -309,7 +325,13 @@ def mha(self, inputs, attention_mask, w_q, b_q, w_k, b_k, w_v, b_v,
         head_dim = h // n_head
         scaling = head_dim ** -0.5
 
-        hidden = F.layer_norm(inputs.data, (h,), weight=w_ln.data, bias=b_ln.data)
+        # workaround for PyTorch MPS bug of varianceEps implementation
+        if self.device_type == DeviceType.MPS:
+            # hidden = F.layer_norm(inputs.data, (h,), weight=w_ln.data, bias=b_ln.data, eps=torch.tensor([1e-05]).type(torch.float16))
+            hidden = F.layer_norm(inputs.data.type(torch.float32), (h,), weight=w_ln.data.type(torch.float32), bias=b_ln.data.type(torch.float32))
+            hidden = hidden.type(torch.float16)
+        else:
+            hidden = F.layer_norm(inputs.data, (h,), weight=w_ln.data, bias=b_ln.data)
 
         # shape: (b, s, h)
         q = F.linear(hidden, w_q.data, bias=b_q.data) * scaling
@@ -380,7 +402,12 @@ def mha_gen(self, inputs, attention_mask, w_q, b_q, w_k, b_k, w_v, b_v,
         head_dim = h // n_head
         scaling = head_dim ** -0.5
 
-        hidden = F.layer_norm(inputs.data, (h,), weight=w_ln.data, bias=b_ln.data)
+        # workaround for PyTorch MPS bug of varianceEps implementation
+        if self.device_type == DeviceType.MPS:
+            hidden = F.layer_norm(inputs.data.type(torch.float32), (h,), weight=w_ln.data.type(torch.float32), bias=b_ln.data.type(torch.float32))
+            hidden = hidden.type(torch.float16)
+        else:
+            hidden = F.layer_norm(inputs.data, (h,), weight=w_ln.data, bias=b_ln.data)
 
         # shape: (b, 1, h)
         q = F.linear(hidden, w_q.data, bias=b_q.data) * scaling
@@ -416,14 +443,15 @@ def mha_gen(self, inputs, attention_mask, w_q, b_q, w_k, b_k, w_v, b_v,
                 # shape: (b * n_head, s, head_dim)
                 v = v.permute(1, 0, 2).reshape(b * n_head, src_s, head_dim)
 
-                if k.is_cuda:
-                    value = self._attention_value(q, k, v, attention_mask.data,
-                        b, src_s, tgt_s, n_head, head_dim)
-                else:
+                if self.device_type == DeviceType.CUDA and not k.is_cuda:
                     q = q.float().cpu()
                     k, v = k.float(), v.float()
                     value = self._attention_value(q, k, v, attention_mask.data,
                         b, src_s, tgt_s, n_head, head_dim).cuda().half()
+                else:
+                    value = self._attention_value(q, k, v, attention_mask.data,
+                        b, src_s, tgt_s, n_head, head_dim)
+
             else:  # Sparse attention
                 # shape: (s, b * n_head, head_dim)
                 k = k_cache.data[:src_s]
@@ -431,15 +459,15 @@ def mha_gen(self, inputs, attention_mask, w_q, b_q, w_k, b_k, w_v, b_v,
                 # shape: (b * n_head, head_dim, s)
                 k = k.permute(1, 2, 0).reshape(b * n_head, head_dim, src_s)
 
-                if k.is_cuda:
+                if self.device_type == DeviceType.CUDA and not k.is_cuda:
+                    q = q.float().cpu()
                     value = self._sparse_attention_value(q, k, v_new, v_cache,
                         attention_mask.data, b, src_s, tgt_s, n_head, head_dim,
-                        attn_sparsity)
+                        attn_sparsity).cuda().half()
                 else:
-                    q = q.float().cpu()
                     value = self._sparse_attention_value(q, k, v_new, v_cache,
                         attention_mask.data, b, src_s, tgt_s, n_head, head_dim,
-                        attn_sparsity).cuda().half()
+                        attn_sparsity)
         else:  # Mixed device attention
             assert attn_sparsity >= 1.0
             value = self._mixed_device_attention(q, k_cache, v_cache,
@@ -498,12 +526,12 @@ def _sparse_attention_value(self, q, k, v_new, v_cache, mask, b,
         attn_weights = torch.cat([topk_weights,
             attn_weights[:, :, -1].unsqueeze(-1)], dim=-1)
 
-        if k.is_cuda:
+        if self.device_type == DeviceType.CUDA and not k.is_cuda:
+            (v_home, v_buf) = v_cache
+        else:
             v_home = v_cache
             v_buf = self.allocate((topk+1, b*n_head, head_dim), np.float16)
             topk_indices = topk_indices.cpu()
-        else:
-            (v_home, v_buf) = v_cache
 
         # shape: (s, b * n_head, head_dim)
         indices_src = topk_indices
@@ -574,7 +602,13 @@ def mlp(self, inputs, wi, bi, wo, bo, w_ln, b_ln, donate):
 
         b, s, h = inputs.shape
 
-        out = F.layer_norm(inputs.data, (h,), weight=w_ln.data, bias=b_ln.data)
+        # workaround for PyTorch MPS bug of varianceEps implementation
+        if self.device_type == DeviceType.MPS:
+            out = F.layer_norm(inputs.data.type(torch.float32), (h,), weight=w_ln.data.type(torch.float32), bias=b_ln.data.type(torch.float32))
+            out = out.type(torch.float16)
+        else:
+            out = F.layer_norm(inputs.data, (h,), weight=w_ln.data, bias=b_ln.data)
+
         out = F.linear(out, wi.data, bias=bi.data)
         F.relu(out, inplace=True)
         out = F.linear(out, wo.data, bias=bo.data)
@@ -590,7 +624,7 @@ def mem_stats(self):
         if self.device_type == DeviceType.CUDA:
             cur_mem = torch.cuda.memory_allocated(self.dev)
             peak_mem = torch.cuda.max_memory_allocated(self.dev)
-        elif self.device_type == DeviceType.CPU:
+        elif self.device_type in [DeviceType.MPS, DeviceType.CPU]:
             cur_mem = cpu_mem_stats()
             peak_mem = 0
         else:
@@ -599,7 +633,8 @@ def mem_stats(self):
         return cur_mem, peak_mem
 
     def print_stats(self, output_file=None):
-        torch.cuda.synchronize()
+        if self.device_type == DeviceType.CUDA:
+            torch.cuda.synchronize()
         cur_mem, peak_mem = self.mem_stats()
 
         if output_file is not None:
@@ -621,7 +656,7 @@ def __str__(self):
 class TorchDisk:
     """Manage tensors stored on a disk."""
 
-    def __init__(self, path, mem_capacity=None, cuda_id=0, num_copy_threads=4):
+    def __init__(self, path, mem_capacity=None, platform="cuda", device_id=0, num_copy_threads=4):
         self.name = path
         self.path = os.path.abspath(os.path.expanduser(path))
         self.mem_capacity = mem_capacity
@@ -640,7 +675,7 @@ def __init__(self, path, mem_capacity=None, cuda_id=0, num_copy_threads=4):
         self.copy_queue = queue.Queue()
         self.copy_threads = [
             threading.Thread(
-                target=copy_worker_func, args=(self.copy_queue, cuda_id)
+                target=copy_worker_func, args=(self.copy_queue, platform, device_id)
             ) for _ in range(num_copy_threads)
         ]
         for t in self.copy_threads:
@@ -875,32 +910,33 @@ def map_to_torch_tensor(tensor, indices):
     return data[indices] if indices else data
 
 
-def copy_worker_func(queue, cuda_id):
+def copy_worker_func(queue, platform, device_id):
     """The copy worker thread."""
-    torch.cuda.set_device(cuda_id)
-    cpu_buf = torch.empty((1 * GB,), dtype=torch.float16, pin_memory=True)
-    copy_stream = torch.cuda.Stream()
+    if "cuda" in platform:
+        torch.cuda.set_device(device_id)
+        copy_stream = torch.cuda.Stream()
+    cpu_buf = torch.empty((1 * GB,), dtype=torch.float16, pin_memory=True)
 
-    with torch.cuda.stream(copy_stream):
-        while True:
-            item = queue.get()
-            if item is None:
-                queue.task_done()
-                return
+    while True:
+        item = queue.get()
+        if item is None:
+            queue.task_done()
+            return
 
-            dst, dst_indices, src, src_indices = item
-            src_data = map_to_torch_tensor(src, src_indices)
-            dst_data = map_to_torch_tensor(dst, dst_indices)
+        dst, dst_indices, src, src_indices = item
+        src_data = map_to_torch_tensor(src, src_indices)
+        dst_data = map_to_torch_tensor(dst, dst_indices)
 
-            if (src.device.device_type == DeviceType.CUDA or
-                dst.device.device_type == DeviceType.CUDA):
+        if (src.device.device_type == DeviceType.CUDA or
+            dst.device.device_type == DeviceType.CUDA):
+            with torch.cuda.stream(copy_stream):
                 # Use a pinned cpu buffer as a relay
                 size = np.prod(src_data.shape)
                 tmp_cpu_buf = cpu_buf[:size].view(src_data.shape)
                 tmp_cpu_buf.copy_(src_data)
                 dst_data.copy_(tmp_cpu_buf)
-            else:
-                dst_data.copy_(src_data)
+        else:
+            dst_data.copy_(src_data)
 
-            queue.task_done()
+        queue.task_done()
\ No newline at end of file
diff --git a/flexgen/utils.py b/flexgen/utils.py
index 7e62c9bb..3ddfa35e 100644
--- a/flexgen/utils.py
+++ b/flexgen/utils.py
@@ -38,6 +38,7 @@ class ExecutionEnv:
     cpu: Any = None
     disk: Any = None
    mixed: Any = None
+    platform: Any = None
 
     @classmethod
     def create(cls, offload_dir):
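Below is a minimal usage sketch of the `--platform` flag introduced by this patch. It assumes the patch has been applied and, for the `mps:0` case, that a PyTorch nightly build with MPS support is installed; `facebook/opt-1.3b` is simply the example model used in the README section above.
```
# Run fully on the CPU; in this mode the patch forces
# --percent 0 100 0 100 0 100 and loads weights in float32.
python3 -m flexgen.flex_opt --model facebook/opt-1.3b --platform cpu

# Run on an Apple M1/M2 GPU through MPS; if MPS (or CUDA) is unavailable,
# the script falls back to another platform and prints a notice.
python3 -m flexgen.flex_opt --model facebook/opt-1.3b --platform mps:0
```
On non-CUDA platforms the script also disables I/O-compute overlap and weight pinning, since the patch notes it is not yet clear how to enable overlap on MPS.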