From 889ca47de0aa7db2faa6b0801d10135fc4d29789 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Tue, 13 Jun 2023 14:52:21 +0200 Subject: [PATCH] Introduce dynamic frame processors --- roop/core.py | 48 ++--- roop/{analyser.py => face_analyser.py} | 0 roop/frame_processors/__init__.py | 0 roop/frame_processors/core.py | 23 ++ .../face_enhancer.py} | 4 +- .../face_swapper.py} | 200 +++++++++--------- roop/ui.py | 15 +- 7 files changed, 158 insertions(+), 132 deletions(-) rename roop/{analyser.py => face_analyser.py} (100%) create mode 100644 roop/frame_processors/__init__.py create mode 100644 roop/frame_processors/core.py rename roop/{enhancer.py => frame_processors/face_enhancer.py} (97%) rename roop/{swapper.py => frame_processors/face_swapper.py} (92%) diff --git a/roop/core.py b/roop/core.py index 2aedece..0c72d65 100755 --- a/roop/core.py +++ b/roop/core.py @@ -2,6 +2,9 @@ import os import sys + +from roop.frame_processors.core import get_frame_processor_modules + # single thread doubles cuda performance - needs to be set before torch import if any(arg.startswith('--execution-provider') for arg in sys.argv): os.environ['OMP_NUM_THREADS'] = '1' @@ -23,10 +26,10 @@ import cv2 import roop.globals import roop.ui as ui -import roop.swapper -import roop.enhancer +import roop.frame_processors.face_swapper +import roop.frame_processors.face_enhancer from roop.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp -from roop.analyser import get_one_face +from roop.face_analyser import get_one_face if 'ROCMExecutionProvider' in roop.globals.execution_providers: del torch @@ -40,7 +43,7 @@ def parse_args() -> None: parser.add_argument('-f', '--face', help='use a face image', dest='source_path') parser.add_argument('-t', '--target', help='replace image or video with face', dest='target_path') parser.add_argument('-o', '--output', help='save output to this file', 
dest='output_path') - parser.add_argument('--frame-processor', help='list of frame processors to run', dest='frame_processor', default=['face-swapper'], choices=['face-swapper', 'face-enhancer'], nargs='+') + parser.add_argument('--frame-processor', help='list of frame processors to run', dest='frame_processor', default=['face_swapper'], choices=['face_swapper', 'face_enhancer'], nargs='+') parser.add_argument('--keep-fps', help='maintain original fps', dest='keep_fps', action='store_true', default=False) parser.add_argument('--keep-audio', help='maintain original audio', dest='keep_audio', action='store_true', default=True) parser.add_argument('--keep-frames', help='keep frames directory', dest='keep_frames', action='store_true', default=False) @@ -70,6 +73,10 @@ def parse_args() -> None: roop.globals.execution_providers = decode_execution_providers(args.execution_provider) roop.globals.execution_threads = args.execution_threads + # limit face enhancer to cuda + if 'CUDAExecutionProvider' not in roop.globals.execution_providers and 'face_enhancer' in roop.globals.frame_processors: + roop.globals.frame_processors.remove('face_enhancer') + def encode_execution_providers(execution_providers: List[str]) -> List[str]: return [execution_provider.replace('ExecutionProvider', '').lower() for execution_provider in execution_providers] @@ -172,12 +179,11 @@ def start() -> None: if has_image_extension(roop.globals.target_path): if predict_image(roop.globals.target_path) > 0.85: destroy() - if 'face-swapper' in roop.globals.frame_processors: - update_status('Swapping in progress...') - roop.swapper.process_image(roop.globals.source_path, roop.globals.target_path, roop.globals.output_path) - if 'CUDAExecutionProvider' in roop.globals.execution_providers and 'face-enhancer' in roop.globals.frame_processors: - update_status('Enhancing in progress...') - roop.enhancer.process_image(roop.globals.source_path, roop.globals.target_path, roop.globals.output_path) + for frame_processor 
in roop.globals.frame_processors: + update_status(f'{frame_processor} in progress...') + module = get_frame_processor_modules(frame_processor) + module.process_image(roop.globals.source_path, roop.globals.target_path, roop.globals.output_path) + release_resources() if is_image(roop.globals.target_path): update_status('Processing to image succeed!') else: @@ -192,16 +198,11 @@ def start() -> None: update_status('Extracting frames...') extract_frames(roop.globals.target_path) temp_frame_paths = get_temp_frame_paths(roop.globals.target_path) - if 'face-swapper' in roop.globals.frame_processors: - update_status('Swapping in progress...') - conditional_process_video(roop.globals.source_path, temp_frame_paths, roop.swapper.process_video) - release_resources() - # limit to one execution thread - roop.globals.execution_threads = 1 - if 'CUDAExecutionProvider' in roop.globals.execution_providers and 'face-enhancer' in roop.globals.frame_processors: - update_status('Enhancing in progress...') - conditional_process_video(roop.globals.source_path, temp_frame_paths, roop.enhancer.process_video) - release_resources() + for frame_processor in roop.globals.frame_processors: + update_status(f'{frame_processor} in progress...') + module = get_frame_processor_modules(frame_processor) + conditional_process_video(roop.globals.source_path, temp_frame_paths, module.process_video) + release_resources() if roop.globals.keep_fps: update_status('Detecting fps...') fps = detect_fps(roop.globals.target_path) @@ -234,10 +235,9 @@ def destroy() -> None: def run() -> None: parse_args() pre_check() - if 'face-swapper' in roop.globals.frame_processors: - roop.swapper.pre_check() - if 'face-enhancer' in roop.globals.frame_processors: - roop.enhancer.pre_check() + for frame_processor in roop.globals.frame_processors: + module = get_frame_processor_modules(frame_processor) + module.pre_check() limit_resources() if roop.globals.headless: start() diff --git a/roop/analyser.py b/roop/face_analyser.py 
similarity index 100% rename from roop/analyser.py rename to roop/face_analyser.py diff --git a/roop/frame_processors/__init__.py b/roop/frame_processors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/roop/frame_processors/core.py b/roop/frame_processors/core.py new file mode 100644 index 0000000..473c4e9 --- /dev/null +++ b/roop/frame_processors/core.py @@ -0,0 +1,23 @@ +import sys +import importlib +from typing import Any + +import torch + +import roop.globals + +if 'ROCMExecutionProvider' in roop.globals.execution_providers: + del torch + +FRAME_PROCESSOR_MODULES = None + + +def get_frame_processor_modules(frame_processor: str) -> Any: + global FRAME_PROCESSOR_MODULES + + if not FRAME_PROCESSOR_MODULES: + try: + FRAME_PROCESSOR_MODULES = importlib.import_module(f'roop.frame_processors.{frame_processor}') + except ImportError: + sys.exit() + return FRAME_PROCESSOR_MODULES diff --git a/roop/enhancer.py b/roop/frame_processors/face_enhancer.py similarity index 97% rename from roop/enhancer.py rename to roop/frame_processors/face_enhancer.py index a0a8b86..55c3a7e 100644 --- a/roop/enhancer.py +++ b/roop/frame_processors/face_enhancer.py @@ -18,14 +18,14 @@ THREAD_LOCK = threading.Lock() def pre_check() -> None: - download_directory_path = resolve_relative_path('../models') + download_directory_path = resolve_relative_path('../../models') conditional_download(download_directory_path, ['https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth']) def get_code_former(): global CODE_FORMER with THREAD_LOCK: - model_path = resolve_relative_path('../models/codeformer.pth') + model_path = resolve_relative_path('../../models/codeformer.pth') if CODE_FORMER is None: model = torch.load(model_path)['params_ema'] CODE_FORMER = ARCH_REGISTRY.get('CodeFormer')( diff --git a/roop/swapper.py b/roop/frame_processors/face_swapper.py similarity index 92% rename from roop/swapper.py rename to roop/frame_processors/face_swapper.py index 
b896e1c..fee5c79 100644 --- a/roop/swapper.py +++ b/roop/frame_processors/face_swapper.py @@ -1,100 +1,100 @@ -from typing import Any, List -from tqdm import tqdm -import cv2 -import insightface -import threading - -import roop.globals -from roop.analyser import get_one_face, get_many_faces -from roop.utilities import conditional_download, resolve_relative_path - -FACE_SWAPPER = None -THREAD_LOCK = threading.Lock() - - -def pre_check() -> None: - download_directory_path = resolve_relative_path('../models') - conditional_download(download_directory_path, ['https://huggingface.co/deepinsight/inswapper/resolve/main/inswapper_128.onnx']) - - -def get_face_swapper() -> None: - global FACE_SWAPPER - - with THREAD_LOCK: - if FACE_SWAPPER is None: - model_path = resolve_relative_path('../models/inswapper_128.onnx') - FACE_SWAPPER = insightface.model_zoo.get_model(model_path, providers=roop.globals.execution_providers) - return FACE_SWAPPER - - -def swap_face(source_face: Any, target_face: Any, temp_frame: Any) -> Any: - if target_face: - return get_face_swapper().get(temp_frame, target_face, source_face, paste_back=True) - return temp_frame - - -def process_faces(source_face: Any, temp_frame: Any) -> Any: - if roop.globals.many_faces: - many_faces = get_many_faces(temp_frame) - if many_faces: - for target_face in many_faces: - temp_frame = swap_face(source_face, target_face, temp_frame) - else: - target_face = get_one_face(temp_frame) - if target_face: - temp_frame = swap_face(source_face, target_face, temp_frame) - return temp_frame - - -def process_frames(source_path: str, temp_frame_paths: List[str], progress=None) -> None: - source_face = get_one_face(cv2.imread(source_path)) - for temp_frame_path in temp_frame_paths: - temp_frame = cv2.imread(temp_frame_path) - try: - result = process_faces(source_face, temp_frame) - cv2.imwrite(temp_frame_path, result) - except Exception as exception: - print(exception) - pass - if progress: - progress.update(1) - - -def 
multi_process_frame(source_path: str, temp_frame_paths: List[str], progress) -> None: - threads = [] - frames_per_thread = len(temp_frame_paths) // roop.globals.execution_threads - remaining_frames = len(temp_frame_paths) % roop.globals.execution_threads - start_index = 0 - # create threads by frames - for _ in range(roop.globals.execution_threads): - end_index = start_index + frames_per_thread - if remaining_frames > 0: - end_index += 1 - remaining_frames -= 1 - thread_paths = temp_frame_paths[start_index:end_index] - thread = threading.Thread(target=process_frames, args=(source_path, thread_paths, progress)) - threads.append(thread) - thread.start() - start_index = end_index - # join threads - for thread in threads: - thread.join() - - -def process_image(source_path: str, target_path: str, output_path: str) -> None: - source_face = get_one_face(cv2.imread(source_path)) - target_frame = cv2.imread(target_path) - result = process_faces(source_face, target_frame) - cv2.imwrite(output_path, result) - - -def process_video(source_path: str, temp_frame_paths: List[str], mode: str) -> None: - progress_bar_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]' - total = len(temp_frame_paths) - with tqdm(total=total, desc='Processing', unit='frame', dynamic_ncols=True, bar_format=progress_bar_format) as progress: - if mode == 'multi-processing': - progress.set_postfix({'mode': mode, 'cores': roop.globals.cpu_cores, 'memory': roop.globals.max_memory}) - process_frames(source_path, temp_frame_paths, progress) - elif mode == 'multi-threading': - progress.set_postfix({'mode': mode, 'threads': roop.globals.execution_threads, 'memory': roop.globals.max_memory}) - multi_process_frame(source_path, temp_frame_paths, progress) +from typing import Any, List +from tqdm import tqdm +import cv2 +import insightface +import threading + +import roop.globals +from roop.face_analyser import get_one_face, get_many_faces +from roop.utilities import 
conditional_download, resolve_relative_path + +FACE_SWAPPER = None +THREAD_LOCK = threading.Lock() + + +def pre_check() -> None: + download_directory_path = resolve_relative_path('../../models') + conditional_download(download_directory_path, ['https://huggingface.co/deepinsight/inswapper/resolve/main/inswapper_128.onnx']) + + +def get_face_swapper() -> None: + global FACE_SWAPPER + + with THREAD_LOCK: + if FACE_SWAPPER is None: + model_path = resolve_relative_path('../../models/inswapper_128.onnx') + FACE_SWAPPER = insightface.model_zoo.get_model(model_path, providers=roop.globals.execution_providers) + return FACE_SWAPPER + + +def swap_face(source_face: Any, target_face: Any, temp_frame: Any) -> Any: + if target_face: + return get_face_swapper().get(temp_frame, target_face, source_face, paste_back=True) + return temp_frame + + +def process_faces(source_face: Any, temp_frame: Any) -> Any: + if roop.globals.many_faces: + many_faces = get_many_faces(temp_frame) + if many_faces: + for target_face in many_faces: + temp_frame = swap_face(source_face, target_face, temp_frame) + else: + target_face = get_one_face(temp_frame) + if target_face: + temp_frame = swap_face(source_face, target_face, temp_frame) + return temp_frame + + +def process_frames(source_path: str, temp_frame_paths: List[str], progress=None) -> None: + source_face = get_one_face(cv2.imread(source_path)) + for temp_frame_path in temp_frame_paths: + temp_frame = cv2.imread(temp_frame_path) + try: + result = process_faces(source_face, temp_frame) + cv2.imwrite(temp_frame_path, result) + except Exception as exception: + print(exception) + pass + if progress: + progress.update(1) + + +def multi_process_frame(source_path: str, temp_frame_paths: List[str], progress) -> None: + threads = [] + frames_per_thread = len(temp_frame_paths) // roop.globals.execution_threads + remaining_frames = len(temp_frame_paths) % roop.globals.execution_threads + start_index = 0 + # create threads by frames + for _ in 
range(roop.globals.execution_threads): + end_index = start_index + frames_per_thread + if remaining_frames > 0: + end_index += 1 + remaining_frames -= 1 + thread_paths = temp_frame_paths[start_index:end_index] + thread = threading.Thread(target=process_frames, args=(source_path, thread_paths, progress)) + threads.append(thread) + thread.start() + start_index = end_index + # join threads + for thread in threads: + thread.join() + + +def process_image(source_path: str, target_path: str, output_path: str) -> None: + source_face = get_one_face(cv2.imread(source_path)) + target_frame = cv2.imread(target_path) + result = process_faces(source_face, target_frame) + cv2.imwrite(output_path, result) + + +def process_video(source_path: str, temp_frame_paths: List[str], mode: str) -> None: + progress_bar_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]' + total = len(temp_frame_paths) + with tqdm(total=total, desc='Processing', unit='frame', dynamic_ncols=True, bar_format=progress_bar_format) as progress: + if mode == 'multi-processing': + progress.set_postfix({'mode': mode, 'cores': roop.globals.cpu_cores, 'memory': roop.globals.max_memory}) + process_frames(source_path, temp_frame_paths, progress) + elif mode == 'multi-threading': + progress.set_postfix({'mode': mode, 'threads': roop.globals.execution_threads, 'memory': roop.globals.max_memory}) + multi_process_frame(source_path, temp_frame_paths, progress) diff --git a/roop/ui.py b/roop/ui.py index 385442c..ca4dfa8 100644 --- a/roop/ui.py +++ b/roop/ui.py @@ -6,9 +6,9 @@ import cv2 from PIL import Image, ImageTk, ImageOps import roop.globals -from roop.analyser import get_one_face +from roop.face_analyser import get_one_face from roop.capturer import get_video_frame, get_video_frame_total -from roop.swapper import process_faces +from roop.frame_processors.core import get_frame_processor_modules from roop.utilities import is_image, is_video, resolve_relative_path WINDOW_HEIGHT = 700 @@ 
-204,10 +204,13 @@ def init_preview() -> None: def update_preview(frame_number: int = 0) -> None: if roop.globals.source_path and roop.globals.target_path: - video_frame = process_faces( - get_one_face(cv2.imread(roop.globals.source_path)), - get_video_frame(roop.globals.target_path, frame_number) - ) + for frame_processor in roop.globals.frame_processors: + module = get_frame_processor_modules(frame_processor) + module.process_image(roop.globals.source_path, roop.globals.target_path, roop.globals.output_path) + video_frame = module.process_faces( + get_one_face(cv2.imread(roop.globals.source_path)), + get_video_frame(roop.globals.target_path, frame_number) + ) image = Image.fromarray(cv2.cvtColor(video_frame, cv2.COLOR_BGR2RGB)) image = ImageOps.contain(image, (PREVIEW_MAX_WIDTH, PREVIEW_MAX_HEIGHT), Image.LANCZOS) image = ImageTk.PhotoImage(image)