From 8788c0995fb5013c5fa950949cc587016ddd3d94 Mon Sep 17 00:00:00 2001 From: eliphatfs Date: Tue, 27 Jun 2023 23:08:12 +0800 Subject: [PATCH 1/3] Fast inference script (experimental). --- inference_realesrgan_video_fast.py | 467 +++++++++++++++++++++++++++++ 1 file changed, 467 insertions(+) create mode 100644 inference_realesrgan_video_fast.py diff --git a/inference_realesrgan_video_fast.py b/inference_realesrgan_video_fast.py new file mode 100644 index 000000000..e472a9e97 --- /dev/null +++ b/inference_realesrgan_video_fast.py @@ -0,0 +1,467 @@ +import argparse +import cv2 +import glob +import mimetypes +import numpy as np +import os +import shutil +import subprocess +import torch +from basicsr.archs.rrdbnet_arch import RRDBNet +from basicsr.utils.download_util import load_file_from_url +from os import path as osp +from tqdm import tqdm +import torch.nn.functional as F + +from realesrgan import RealESRGANer +from realesrgan.archs.srvgg_arch import SRVGGNetCompact + +try: + import ffmpeg +except ImportError: + import pip + pip.main(['install', '--user', 'ffmpeg-python']) + import ffmpeg + + +def get_video_meta_info(video_path): + ret = {} + probe = ffmpeg.probe(video_path) + video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video'] + has_audio = any(stream['codec_type'] == 'audio' for stream in probe['streams']) + ret['width'] = video_streams[0]['width'] + ret['height'] = video_streams[0]['height'] + ret['fps'] = eval(video_streams[0]['avg_frame_rate']) + ret['audio'] = ffmpeg.input(video_path).audio if has_audio else None + ret['nb_frames'] = int(video_streams[0]['nb_frames']) + return ret + + +def get_sub_video(args, num_process, process_idx): + if num_process == 1: + return args.input + meta = get_video_meta_info(args.input) + duration = int(meta['nb_frames'] / meta['fps']) + part_time = duration // num_process + print(f'duration: {duration}, part_time: {part_time}') + os.makedirs(osp.join(args.output, f'{args.video_name}_inp_tmp_videos'), exist_ok=True) + out_path = osp.join(args.output, f'{args.video_name}_inp_tmp_videos', f'{process_idx:03d}.mp4') + cmd = [ + args.ffmpeg_bin, f'-i {args.input}', '-ss', f'{part_time * process_idx}', + f'-to {part_time * (process_idx + 1)}' if process_idx != num_process - 1 else '', '-async 1', out_path, '-y' + ] + print(' '.join(cmd)) + subprocess.call(' '.join(cmd), shell=True) + return out_path + + +class Reader: + + def __init__(self, args, total_workers=1, worker_idx=0): + self.args = args + input_type = mimetypes.guess_type(args.input)[0] + self.input_type = 'folder' if input_type is None else input_type + self.paths = [] # for image&folder type + self.audio = None + self.input_fps = None + if self.input_type.startswith('video'): + video_path = get_sub_video(args, total_workers, worker_idx) + self.stream_reader = ( + ffmpeg.input(video_path).output('pipe:', format='rawvideo', pix_fmt='rgb24', + loglevel='error').run_async( + pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin)) + meta = get_video_meta_info(video_path) + self.width = meta['width'] + self.height = meta['height'] + self.input_fps = meta['fps'] + self.audio = meta['audio'] + self.nb_frames = meta['nb_frames'] + + else: + if self.input_type.startswith('image'): + self.paths = [args.input] + else: + paths = sorted(glob.glob(os.path.join(args.input, '*'))) + tot_frames = len(paths) + num_frame_per_worker = tot_frames // total_workers + (1 if tot_frames % total_workers else 0) + self.paths = paths[num_frame_per_worker * worker_idx:num_frame_per_worker 
* (worker_idx + 1)]
+
+            self.nb_frames = len(self.paths)
+            assert self.nb_frames > 0, 'empty folder'
+            from PIL import Image
+            tmp_img = Image.open(self.paths[0])
+            self.width, self.height = tmp_img.size
+        self.idx = 0
+
+    def get_resolution(self):
+        return self.height, self.width
+
+    def get_fps(self):
+        if self.args.fps is not None:
+            return self.args.fps
+        elif self.input_fps is not None:
+            return self.input_fps
+        return 24
+
+    def get_audio(self):
+        return self.audio
+
+    def __len__(self):
+        return self.nb_frames
+
+    def get_frame_from_stream(self):
+        img_bytes = self.stream_reader.stdout.read(self.width * self.height * 3)  # 3 bytes for one pixel
+        if not img_bytes:
+            return None
+        img = np.frombuffer(img_bytes, np.uint8).reshape([self.height, self.width, 3])
+        return img
+
+    def get_frame_from_list(self):
+        if self.idx >= self.nb_frames:
+            return None
+        img = cv2.imread(self.paths[self.idx])
+        self.idx += 1
+        return img
+
+    def get_frame(self):
+        if self.input_type.startswith('video'):
+            return self.get_frame_from_stream()
+        else:
+            return self.get_frame_from_list()
+
+    def close(self):
+        if self.input_type.startswith('video'):
+            self.stream_reader.stdin.close()
+            self.stream_reader.wait()
+
+
+class Writer:
+
+    def __init__(self, args, audio, height, width, video_save_path, fps):
+        out_width, out_height = int(width * args.outscale), int(height * args.outscale)
+        if out_height > 2160:
+            print('You are generating a video that is larger than 4K, which will be very slow due to IO speed.',
+                  'We highly recommend decreasing the outscale (i.e., -s).')
+
+        if audio is not None:
+            self.stream_writer = (
+                ffmpeg.input('pipe:', format='rawvideo', pix_fmt='rgb24', s=f'{out_width}x{out_height}',
+                             framerate=fps).output(
+                                 audio,
+                                 video_save_path,
+                                 pix_fmt='yuv420p',
+                                 vcodec='libx264',
+                                 loglevel='error',
+                                 acodec='copy').overwrite_output().run_async(
+                                     pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin))
+        else:
+            self.stream_writer = (
+                ffmpeg.input('pipe:', format='rawvideo', pix_fmt='rgb24', s=f'{out_width}x{out_height}',
+                             framerate=fps).output(
+                                 video_save_path, pix_fmt='yuv420p', vcodec='libx264',
+                                 loglevel='error').overwrite_output().run_async(
+                                     pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin))
+
+    @profile
+    def write_frame(self, frame: np.ndarray):
+        assert frame.dtype == np.uint8
+        frame = frame.data
+        self.stream_writer.stdin.write(frame)
+
+    def close(self):
+        self.stream_writer.stdin.close()
+        self.stream_writer.wait()
+
+
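+# The helpers below replace RealESRGANer's per-image pre-processing with a batched
+# variant: convert() turns one HWC uint8 frame into a normalized NCHW float tensor
+# on the target device, and pre_process_batched() pads the whole stacked batch so
+# that a single forward pass upscales several frames at once.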
+def convert(self: RealESRGANer, img):
+    img = torch.from_numpy(np.array(np.transpose(img, (2, 0, 1))))
+    img = img.unsqueeze(0).to(self.device)
+    if self.half:
+        img = img.half()
+    else:
+        img = img.float()
+    return img / 255.0
+
+def pre_process_batched(self: RealESRGANer):
+    # pre_pad
+    if self.pre_pad != 0:
+        self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
+    # mod pad for divisible borders
+    if self.scale == 2:
+        self.mod_scale = 2
+    elif self.scale == 1:
+        self.mod_scale = 4
+    if self.mod_scale is not None:
+        self.mod_pad_h, self.mod_pad_w = 0, 0
+        _, _, h, w = self.img.size()
+        if (h % self.mod_scale != 0):
+            self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
+        if (w % self.mod_scale != 0):
+            self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
+        self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
+
+
+@torch.no_grad()
+@profile
+def batch_enhance_rgb(self: RealESRGANer, imgs, outscale=None, alpha_upsampler='realesrgan'):
+    tensors = []
+    for img in imgs:
+        # img: numpy
+        if np.max(img) > 256:  # 16-bit image
+            assert False, '16-bit images are not supported by the fast path'
+        if len(img.shape) == 2:  # gray image
+            assert False, 'gray images are not supported by the fast path'
+        elif img.shape[2] == 4:  # RGBA image with alpha channel
+            assert False, 'RGBA images are not supported by the fast path'
+        tensors.append(convert(self, img))
+
+    self.img = torch.cat(tensors)
+    # ------------------- process image (without the alpha channel) ------------------- #
+    pre_process_batched(self)
+    if self.tile_size > 0:
+        self.tile_process()
+    else:
+        self.process()
+    output_img = self.post_process()
+    if outscale is not None and outscale != float(self.scale):
+        output_img = F.interpolate(output_img, scale_factor=outscale / float(self.scale), mode='area')
+    output_img = (output_img * 255).clamp_(0, 255).byte().permute(0, 2, 3, 1).contiguous().cpu().numpy()
+    for output in output_img:
+        yield output
+
+
+@profile
+def inference_video(args, video_save_path, device=None, total_workers=1, worker_idx=0):
+    # ---------------------- determine models according to model names ---------------------- #
+    args.model_name = args.model_name.split('.pth')[0]
+    if args.model_name == 'RealESRGAN_x4plus':  # x4 RRDBNet model
+        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+        netscale = 4
+        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
+    elif args.model_name == 'RealESRNet_x4plus':  # x4 RRDBNet model
+        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+        netscale = 4
+        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
+    elif args.model_name == 'RealESRGAN_x4plus_anime_6B':  # x4 RRDBNet model with 6 blocks
+        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
+        netscale = 4
+        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
+    elif args.model_name == 'RealESRGAN_x2plus':  # x2 RRDBNet model
+        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
+        netscale = 2
+        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
+    elif args.model_name == 'realesr-animevideov3':  # x4 VGG-style model (XS size)
+        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
+        netscale = 4
+        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth']
+    elif args.model_name == 'realesr-general-x4v3':  # x4 VGG-style model (S size)
+        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
+        netscale = 4
+        file_url = [
+            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
+            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
+        ]
+
+    # ---------------------- determine model paths ---------------------- #
+    model_path = os.path.join('weights', args.model_name + '.pth')
+    if not os.path.isfile(model_path):
+        ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+        for url in file_url:
+            # model_path will be updated
+            model_path = load_file_from_url(
+                url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
+
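+    # DNI (deep network interpolation): when denoise_strength != 1, the weights of
+    # the general model and its weight-denoised (wdn) twin are linearly blended, so
+    # --denoise_strength smoothly trades noise removal against detail retention.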
+    # use dni to control the denoise strength
+    dni_weight = None
+    if args.model_name == 'realesr-general-x4v3' and args.denoise_strength != 1:
+        wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
+        model_path = [model_path, wdn_model_path]
+        dni_weight = [args.denoise_strength, 1 - args.denoise_strength]
+
+    # restorer
+    upsampler = RealESRGANer(
+        scale=netscale,
+        model_path=model_path,
+        dni_weight=dni_weight,
+        model=model,
+        tile=args.tile,
+        tile_pad=args.tile_pad,
+        pre_pad=args.pre_pad,
+        half=not args.fp32,
+        device=device,
+    )
+
+    if 'anime' in args.model_name and args.face_enhance:
+        print('face_enhance is not supported in anime models; we turned this option off for you. '
+              'If you insist on turning it on, please manually comment the relevant lines of code.')
+        args.face_enhance = False
+
+    if args.face_enhance:  # Use GFPGAN for face enhancement
+        from gfpgan import GFPGANer
+        face_enhancer = GFPGANer(
+            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
+            upscale=args.outscale,
+            arch='clean',
+            channel_multiplier=2,
+            bg_upsampler=upsampler)  # TODO support custom device
+    else:
+        face_enhancer = None
+
+    reader = Reader(args, total_workers, worker_idx)
+    audio = reader.get_audio()
+    height, width = reader.get_resolution()
+    fps = reader.get_fps()
+    writer = Writer(args, audio, height, width, video_save_path, fps)
+
+    pbar = tqdm(total=len(reader), unit='frame', desc='inference')
+    queue = []
+    assert not args.face_enhance, 'face_enhance is not supported in the fast batched path'
+    while True:
+        img = reader.get_frame()
+        if img is None:
+            break
+        queue.append(img)
+        if len(queue) == args.batch:
+            try:
+                output = list(batch_enhance_rgb(upsampler, queue, outscale=args.outscale))
+                queue.clear()
+            except RuntimeError as error:
+                print('Error', error)
+                print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
+            else:
+                for frame in output:
+                    writer.write_frame(frame)
+            pbar.update(args.batch)
+            torch.cuda.synchronize(device)
+    if len(queue):
+        for frame in batch_enhance_rgb(upsampler, queue, outscale=args.outscale):
+            writer.write_frame(frame)
+            pbar.update(1)
+        queue.clear()
+        torch.cuda.synchronize(device)
+    reader.close()
+    writer.close()
+
+
+def run(args):
+    args.video_name = osp.splitext(os.path.basename(args.input))[0]
+    video_save_path = osp.join(args.output, f'{args.video_name}_{args.suffix}.mp4')
+
+    if args.extract_frame_first:
+        tmp_frames_folder = osp.join(args.output, f'{args.video_name}_inp_tmp_frames')
+        os.makedirs(tmp_frames_folder, exist_ok=True)
+        os.system(f'ffmpeg -i {args.input} -qscale:v 1 -qmin 1 -qmax 1 -vsync 0 {tmp_frames_folder}/frame%08d.png')
+        args.input = tmp_frames_folder
+
+    num_gpus = torch.cuda.device_count()
+    num_process = num_gpus * args.num_process_per_gpu
+    if num_process == 1:
+        inference_video(args, video_save_path)
+        return
+
+    ctx = torch.multiprocessing.get_context('spawn')
+    pool = ctx.Pool(num_process)
+    os.makedirs(osp.join(args.output, f'{args.video_name}_out_tmp_videos'), exist_ok=True)
+    pbar = tqdm(total=num_process, unit='sub_video', desc='inference')
+    for i in range(num_process):
+        sub_video_save_path = osp.join(args.output, f'{args.video_name}_out_tmp_videos', f'{i:03d}.mp4')
+        pool.apply_async(
+            inference_video,
+            args=(args, sub_video_save_path, torch.device(i % num_gpus), num_process, i),
+            callback=lambda arg: pbar.update(1))
+    pool.close()
+    pool.join()
+
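+    # Each worker encoded its slice to <video_name>_out_tmp_videos/NNN.mp4; the
+    # ffmpeg concat demuxer below stitches the sub videos back together without
+    # re-encoding (-c copy).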
+    # combine sub videos
+    # prepare vidlist.txt
+    with open(f'{args.output}/{args.video_name}_vidlist.txt', 'w') as f:
+        for i in range(num_process):
+            f.write(f'file \'{args.video_name}_out_tmp_videos/{i:03d}.mp4\'\n')
+
+    cmd = [
+        args.ffmpeg_bin, '-f', 'concat', '-safe', '0', '-i', f'{args.output}/{args.video_name}_vidlist.txt', '-c',
+        'copy', f'{video_save_path}'
+    ]
+    print(' '.join(cmd))
+    subprocess.call(cmd)
+    shutil.rmtree(osp.join(args.output, f'{args.video_name}_out_tmp_videos'))
+    if osp.exists(osp.join(args.output, f'{args.video_name}_inp_tmp_videos')):
+        shutil.rmtree(osp.join(args.output, f'{args.video_name}_inp_tmp_videos'))
+    os.remove(f'{args.output}/{args.video_name}_vidlist.txt')
+
+
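+# Example invocation (paths and flag values are illustrative):
+#   python inference_realesrgan_video_fast.py -i inputs/video.mp4 \
+#       -n realesr-animevideov3 -s 2 --batch 8
+# A larger --batch keeps the GPU busier but raises peak VRAM use; lower it (or set
+# --tile) if you hit CUDA out-of-memory errors.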
+def main():
+    """Inference demo for Real-ESRGAN.
+    It is mainly for restoring anime videos.
+
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-i', '--input', type=str, default='inputs', help='Input video, image or folder')
+    parser.add_argument(
+        '-n',
+        '--model_name',
+        type=str,
+        default='realesr-animevideov3',
+        help=('Model names: realesr-animevideov3 | RealESRGAN_x4plus_anime_6B | RealESRGAN_x4plus | RealESRNet_x4plus |'
+              ' RealESRGAN_x2plus | realesr-general-x4v3. '
+              'Default: realesr-animevideov3'))
+    parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
+    parser.add_argument(
+        '-dn',
+        '--denoise_strength',
+        type=float,
+        default=0.5,
+        help=('Denoise strength. 0 for weak denoising (keeps noise), 1 for strong denoising. '
+              'Only used for the realesr-general-x4v3 model'))
+    parser.add_argument('-s', '--outscale', type=float, default=4, help='The final upsampling scale of the image')
+    parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored video')
+    parser.add_argument('-t', '--tile', type=int, default=0, help='Tile size, 0 for no tile during testing')
+    parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
+    parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border')
+    parser.add_argument('--face_enhance', action='store_true', help='Use GFPGAN to enhance face')
+    parser.add_argument(
+        '--fp32', action='store_true', help='Use fp32 precision during inference. Default: fp16 (half precision).')
+    parser.add_argument('--fps', type=float, default=None, help='FPS of the output video')
+    parser.add_argument('--ffmpeg_bin', type=str, default='ffmpeg', help='The path to ffmpeg')
+    parser.add_argument('--extract_frame_first', action='store_true')
+    parser.add_argument('--num_process_per_gpu', type=int, default=1)
+    parser.add_argument('--batch', type=int, default=4)
+
+    parser.add_argument(
+        '--alpha_upsampler',
+        type=str,
+        default='realesrgan',
+        help='The upsampler for the alpha channels. Options: realesrgan | bicubic')
+    parser.add_argument(
+        '--ext',
+        type=str,
+        default='auto',
+        help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
+    args = parser.parse_args()
+
+    args.input = args.input.rstrip('/').rstrip('\\')
+    os.makedirs(args.output, exist_ok=True)
+
+    if mimetypes.guess_type(args.input)[0] is not None and mimetypes.guess_type(args.input)[0].startswith('video'):
+        is_video = True
+    else:
+        is_video = False
+
+    if is_video and args.input.endswith('.flv'):
+        mp4_path = args.input.replace('.flv', '.mp4')
+        os.system(f'ffmpeg -i {args.input} -codec copy {mp4_path}')
+        args.input = mp4_path
+
+    if args.extract_frame_first and not is_video:
+        args.extract_frame_first = False
+
+    run(args)
+
+    if args.extract_frame_first:
+        tmp_frames_folder = osp.join(args.output, f'{args.video_name}_inp_tmp_frames')
+        shutil.rmtree(tmp_frames_folder)
+
+
+if __name__ == '__main__':
+    main()

From aa723c5b0833beb673762b266748cafb54566db6 Mon Sep 17 00:00:00 2001
From: eliphatfs
Date: Wed, 28 Jun 2023 08:29:19 +0800
Subject: [PATCH 2/3] Remove profiling instrumentation.

---
 inference_realesrgan_video_fast.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/inference_realesrgan_video_fast.py b/inference_realesrgan_video_fast.py
index e472a9e97..4088c076e 100644
--- a/inference_realesrgan_video_fast.py
+++ b/inference_realesrgan_video_fast.py
@@ -162,7 +162,6 @@ def __init__(self, args, audio, height, width, video_save_path, fps):
                                  loglevel='error').overwrite_output().run_async(
                                      pipe_stdin=True, pipe_stdout=True, cmd=args.ffmpeg_bin))
 
-    @profile
     def write_frame(self, frame: np.ndarray):
         assert frame.dtype == np.uint8
         frame = frame.data
@@ -203,7 +202,6 @@ def pre_process_batched(self: RealESRGANer):
 
 
 @torch.no_grad()
-@profile
 def batch_enhance_rgb(self: RealESRGANer, imgs, outscale=None, alpha_upsampler='realesrgan'):
     tensors = []
     for img in imgs:
@@ -231,7 +229,6 @@ def batch_enhance_rgb(self: RealESRGANer, imgs, outscale=None, alpha_upsampler='
     yield output
 
 
-@profile
 def inference_video(args, video_save_path, device=None, total_workers=1, worker_idx=0):
     # ---------------------- determine models according to model names ---------------------- #
     args.model_name = args.model_name.split('.pth')[0]

From a706f392e552d1a7e5108747d48c9ceeb64ef64d Mon Sep 17 00:00:00 2001
From: eliphatfs
Date: Wed, 28 Jun 2023 09:05:35 +0800
Subject: [PATCH 3/3] Remove unsupported parameters.

---
 inference_realesrgan_video_fast.py | 27 +++++++++++++++------------
 1 file changed, 15 insertions(+), 12 deletions(-)

diff --git a/inference_realesrgan_video_fast.py b/inference_realesrgan_video_fast.py
index 4088c076e..21d1ec8f4 100644
--- a/inference_realesrgan_video_fast.py
+++ b/inference_realesrgan_video_fast.py
@@ -416,26 +416,29 @@ def main():
     parser.add_argument('-t', '--tile', type=int, default=0, help='Tile size, 0 for no tile during testing')
     parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
     parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border')
-    parser.add_argument('--face_enhance', action='store_true', help='Use GFPGAN to enhance face')
+    # parser.add_argument('--face_enhance', action='store_true', help='Use GFPGAN to enhance face')
     parser.add_argument(
         '--fp32', action='store_true', help='Use fp32 precision during inference. 
Default: fp16 (half precision).') parser.add_argument('--fps', type=float, default=None, help='FPS of the output video') parser.add_argument('--ffmpeg_bin', type=str, default='ffmpeg', help='The path to ffmpeg') - parser.add_argument('--extract_frame_first', action='store_true') + # parser.add_argument('--extract_frame_first', action='store_true') parser.add_argument('--num_process_per_gpu', type=int, default=1) parser.add_argument('--batch', type=int, default=4) - parser.add_argument( - '--alpha_upsampler', - type=str, - default='realesrgan', - help='The upsampler for the alpha channels. Options: realesrgan | bicubic') - parser.add_argument( - '--ext', - type=str, - default='auto', - help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs') + # parser.add_argument( + # '--alpha_upsampler', + # type=str, + # default='realesrgan', + # help='The upsampler for the alpha channels. Options: realesrgan | bicubic') + # parser.add_argument( + # '--ext', + # type=str, + # default='auto', + # help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs') args = parser.parse_args() + args.extract_frame_first = False + args.face_enhance = False + # args.alpha_upsampler = 'bicubic' args.input = args.input.rstrip('/').rstrip('\\') os.makedirs(args.output, exist_ok=True)