diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0b8f4ce..420a32d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,7 +15,7 @@ jobs: - run: pip install flake8 - run: pip install mypy - run: flake8 run.py roop - - run: mypy --config-file mypi.ini run.py roop + - run: mypy run.py roop test: runs-on: ubuntu-latest steps: diff --git a/README.md b/README.md index 58e3e02..ca3e92a 100644 --- a/README.md +++ b/README.md @@ -31,30 +31,21 @@ Additional command line arguments are given below. To learn out what they do, ch ``` options: - -h, --help show this help message and exit - -s SOURCE_PATH, --source SOURCE_PATH - select an source image - -t TARGET_PATH, --target TARGET_PATH - select an target image or video - -o OUTPUT_PATH, --output OUTPUT_PATH - select output file or directory - --frame-processor {face_swapper,face_enhancer} [{face_swapper,face_enhancer} ...] - pipeline of frame processors - --keep-fps keep original fps - --keep-audio keep original audio - --keep-frames keep temporary frames - --many-faces process every face - --video-encoder {libx264,libx265,libvpx-vp9} - adjust output video encoder - --video-quality VIDEO_QUALITY - adjust output video quality - --max-memory MAX_MEMORY - maximum amount of RAM in GB - --execution-provider {cpu,...} [{cpu,...} ...] - execution provider - --execution-threads EXECUTION_THREADS - number of execution threads - -v, --version show program's version number and exit + -h, --help show this help message and exit + -s SOURCE_PATH, --source SOURCE_PATH select an source image + -t TARGET_PATH, --target TARGET_PATH select an target image or video + -o OUTPUT_PATH, --output OUTPUT_PATH select output file or directory + --frame-processor FRAME_PROCESSOR [FRAME_PROCESSOR ...] frame processors (choices: face_swapper, face_enhancer, ...) + --keep-fps keep original fps + --keep-audio keep original audio + --keep-frames keep temporary frames + --many-faces process every face + --video-encoder {libx264,libx265,libvpx-vp9} adjust output video encoder + --video-quality [0-51] adjust output video quality + --max-memory MAX_MEMORY maximum amount of RAM in GB + --execution-provider {cpu} [{cpu} ...] available execution provider (choices: cpu, ...) + --execution-threads EXECUTION_THREADS number of execution threads + -v, --version show program's version number and exit ``` Looking for a CLI mode? Using the -s/--source argument will make the run program in cli mode. diff --git a/gui-demo.png b/gui-demo.png index f26b691..b76a54d 100644 Binary files a/gui-demo.png and b/gui-demo.png differ diff --git a/mypi.ini b/mypy.ini similarity index 100% rename from mypi.ini rename to mypy.ini diff --git a/roop/core.py b/roop/core.py index 09603c3..b70d854 100755 --- a/roop/core.py +++ b/roop/core.py @@ -33,11 +33,11 @@ warnings.filterwarnings('ignore', category=UserWarning, module='torchvision') def parse_args() -> None: signal.signal(signal.SIGINT, lambda signal_number, frame: destroy()) - program = argparse.ArgumentParser() + program = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=100)) program.add_argument('-s', '--source', help='select an source image', dest='source_path') program.add_argument('-t', '--target', help='select an target image or video', dest='target_path') program.add_argument('-o', '--output', help='select output file or directory', dest='output_path') - program.add_argument('--frame-processor', help='pipeline of frame processors', dest='frame_processor', default=['face_swapper'], choices=['face_swapper', 'face_enhancer'], nargs='+') + program.add_argument('--frame-processor', help='frame processors (choices: face_swapper, face_enhancer, ...)', dest='frame_processor', default=['face_swapper'], nargs='+') program.add_argument('--keep-fps', help='keep original fps', dest='keep_fps', action='store_true', default=False) program.add_argument('--keep-audio', help='keep original audio', dest='keep_audio', action='store_true', default=True) program.add_argument('--keep-frames', help='keep temporary frames', dest='keep_frames', action='store_true', default=False) @@ -45,16 +45,10 @@ def parse_args() -> None: program.add_argument('--video-encoder', help='adjust output video encoder', dest='video_encoder', default='libx264', choices=['libx264', 'libx265', 'libvpx-vp9']) program.add_argument('--video-quality', help='adjust output video quality', dest='video_quality', type=int, default=18, choices=range(52), metavar='[0-51]') program.add_argument('--max-memory', help='maximum amount of RAM in GB', dest='max_memory', type=int, default=suggest_max_memory()) - program.add_argument('--execution-provider', help='execution provider', dest='execution_provider', default=['cpu'], choices=suggest_execution_providers(), nargs='+') + program.add_argument('--execution-provider', help='available execution provider (choices: cpu, ...)', dest='execution_provider', default=['cpu'], choices=suggest_execution_providers(), nargs='+') program.add_argument('--execution-threads', help='number of execution threads', dest='execution_threads', type=int, default=suggest_execution_threads()) program.add_argument('-v', '--version', action='version', version=f'{roop.metadata.name} {roop.metadata.version}') - # register deprecated args - program.add_argument('-f', '--face', help=argparse.SUPPRESS, dest='source_path_deprecated') - program.add_argument('--cpu-cores', help=argparse.SUPPRESS, dest='cpu_cores_deprecated', type=int) - program.add_argument('--gpu-vendor', help=argparse.SUPPRESS, dest='gpu_vendor_deprecated') - program.add_argument('--gpu-threads', help=argparse.SUPPRESS, dest='gpu_threads_deprecated', type=int) - args = program.parse_args() roop.globals.source_path = args.source_path @@ -72,27 +66,6 @@ def parse_args() -> None: roop.globals.execution_providers = decode_execution_providers(args.execution_provider) roop.globals.execution_threads = args.execution_threads - # translate deprecated args - if args.source_path_deprecated: - print('\033[33mArgument -f and --face are deprecated. Use -s and --source instead.\033[0m') - roop.globals.source_path = args.source_path_deprecated - roop.globals.output_path = normalize_output_path(args.source_path_deprecated, roop.globals.target_path, args.output_path) - if args.cpu_cores_deprecated: - print('\033[33mArgument --cpu-cores is deprecated. Use --execution-threads instead.\033[0m') - roop.globals.execution_threads = args.cpu_cores_deprecated - if args.gpu_vendor_deprecated == 'apple': - print('\033[33mArgument --gpu-vendor apple is deprecated. Use --execution-provider coreml instead.\033[0m') - roop.globals.execution_providers = decode_execution_providers(['coreml']) - if args.gpu_vendor_deprecated == 'nvidia': - print('\033[33mArgument --gpu-vendor nvidia is deprecated. Use --execution-provider cuda instead.\033[0m') - roop.globals.execution_providers = decode_execution_providers(['cuda']) - if args.gpu_vendor_deprecated == 'amd': - print('\033[33mArgument --gpu-vendor amd is deprecated. Use --execution-provider cuda instead.\033[0m') - roop.globals.execution_providers = decode_execution_providers(['rocm']) - if args.gpu_threads_deprecated: - print('\033[33mArgument --gpu-threads is deprecated. Use --execution-threads instead.\033[0m') - roop.globals.execution_threads = args.gpu_threads_deprecated - def encode_execution_providers(execution_providers: List[str]) -> List[str]: return [execution_provider.replace('ExecutionProvider', '').lower() for execution_provider in execution_providers] @@ -125,7 +98,9 @@ def limit_resources() -> None: # prevent tensorflow memory leak gpus = tensorflow.config.experimental.list_physical_devices('GPU') for gpu in gpus: - tensorflow.config.experimental.set_memory_growth(gpu, True) + tensorflow.config.experimental.set_virtual_device_configuration(gpu, [ + tensorflow.config.experimental.VirtualDeviceConfiguration(memory_limit=1024) + ]) # limit memory usage if roop.globals.max_memory: memory = roop.globals.max_memory * 1024 ** 3 @@ -173,6 +148,7 @@ def start() -> None: for frame_processor in get_frame_processors_modules(roop.globals.frame_processors): update_status('Progressing...', frame_processor.NAME) frame_processor.process_image(roop.globals.source_path, roop.globals.output_path, roop.globals.output_path) + frame_processor.post_process() release_resources() if is_image(roop.globals.target_path): update_status('Processing to image succeed!') @@ -190,6 +166,7 @@ def start() -> None: for frame_processor in get_frame_processors_modules(roop.globals.frame_processors): update_status('Progressing...', frame_processor.NAME) frame_processor.process_video(roop.globals.source_path, temp_frame_paths) + frame_processor.post_process() release_resources() # handles fps if roop.globals.keep_fps: diff --git a/roop/face_analyser.py b/roop/face_analyser.py index ba7803e..9c0afe4 100644 --- a/roop/face_analyser.py +++ b/roop/face_analyser.py @@ -1,3 +1,4 @@ +import threading from typing import Any import insightface @@ -5,14 +6,16 @@ import roop.globals from roop.typing import Frame FACE_ANALYSER = None +THREAD_LOCK = threading.Lock() def get_face_analyser() -> Any: global FACE_ANALYSER - if FACE_ANALYSER is None: - FACE_ANALYSER = insightface.app.FaceAnalysis(name='buffalo_l', providers=roop.globals.execution_providers) - FACE_ANALYSER.prepare(ctx_id=0, det_size=(640, 640)) + with THREAD_LOCK: + if FACE_ANALYSER is None: + FACE_ANALYSER = insightface.app.FaceAnalysis(name='buffalo_l', providers=roop.globals.execution_providers) + FACE_ANALYSER.prepare(ctx_id=0, det_size=(640, 640)) return FACE_ANALYSER diff --git a/roop/metadata.py b/roop/metadata.py index 69c387e..35b0f02 100644 --- a/roop/metadata.py +++ b/roop/metadata.py @@ -1,2 +1,2 @@ name = 'roop' -version = '1.0.1' +version = '1.1.0' diff --git a/roop/processors/frame/core.py b/roop/processors/frame/core.py index a07a9a6..c225f9d 100644 --- a/roop/processors/frame/core.py +++ b/roop/processors/frame/core.py @@ -1,6 +1,8 @@ -import sys +import os import importlib -from concurrent.futures import ThreadPoolExecutor +import psutil +from concurrent.futures import ThreadPoolExecutor, as_completed +from queue import Queue from types import ModuleType from typing import Any, List, Callable from tqdm import tqdm @@ -12,8 +14,10 @@ FRAME_PROCESSORS_INTERFACE = [ 'pre_check', 'pre_start', 'process_frame', + 'process_frames', 'process_image', - 'process_video' + 'process_video', + 'post_process' ] @@ -22,9 +26,9 @@ def load_frame_processor_module(frame_processor: str) -> Any: frame_processor_module = importlib.import_module(f'roop.processors.frame.{frame_processor}') for method_name in FRAME_PROCESSORS_INTERFACE: if not hasattr(frame_processor_module, method_name): - sys.exit() - except ImportError: - sys.exit() + raise NotImplementedError + except (ImportError, NotImplementedError): + quit(f'Frame processor {frame_processor} crashed.') return frame_processor_module @@ -38,19 +42,47 @@ def get_frame_processors_modules(frame_processors: List[str]) -> List[ModuleType return FRAME_PROCESSORS_MODULES -def multi_process_frame(source_path: str, temp_frame_paths: List[str], process_frames: Callable[[str, List[str], Any], None], progress: Any = None) -> None: +def multi_process_frame(source_path: str, temp_frame_paths: List[str], process_frames: Callable[[str, List[str], Any], None], update: Callable[[], None]) -> None: with ThreadPoolExecutor(max_workers=roop.globals.execution_threads) as executor: futures = [] - for path in temp_frame_paths: - future = executor.submit(process_frames, source_path, [path], progress) + queue = create_queue(temp_frame_paths) + queue_per_future = len(temp_frame_paths) // roop.globals.execution_threads + while not queue.empty(): + future = executor.submit(process_frames, source_path, pick_queue(queue, queue_per_future), update) futures.append(future) - for future in futures: + for future in as_completed(futures): future.result() +def create_queue(temp_frame_paths: List[str]) -> Queue[str]: + queue: Queue[str] = Queue() + for frame_path in temp_frame_paths: + queue.put(frame_path) + return queue + + +def pick_queue(queue: Queue[str], queue_per_future: int) -> List[str]: + queues = [] + for _ in range(queue_per_future): + if not queue.empty(): + queues.append(queue.get()) + return queues + + def process_video(source_path: str, frame_paths: list[str], process_frames: Callable[[str, List[str], Any], None]) -> None: progress_bar_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]' total = len(frame_paths) with tqdm(total=total, desc='Processing', unit='frame', dynamic_ncols=True, bar_format=progress_bar_format) as progress: - progress.set_postfix({'execution_providers': roop.globals.execution_providers, 'threads': roop.globals.execution_threads, 'memory': roop.globals.max_memory}) - multi_process_frame(source_path, frame_paths, process_frames, progress) + multi_process_frame(source_path, frame_paths, process_frames, lambda: update_progress(progress)) + + +def update_progress(progress: Any = None) -> None: + process = psutil.Process(os.getpid()) + memory_usage = process.memory_info().rss / 1024 / 1024 / 1024 + progress.set_postfix({ + 'memory_usage': '{:.2f}'.format(memory_usage).zfill(5) + 'GB', + 'execution_providers': roop.globals.execution_providers, + 'execution_threads': roop.globals.execution_threads + }) + progress.refresh() + progress.update(1) diff --git a/roop/processors/frame/face_enhancer.py b/roop/processors/frame/face_enhancer.py index 50c7c54..3ff92ce 100644 --- a/roop/processors/frame/face_enhancer.py +++ b/roop/processors/frame/face_enhancer.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import Any, List, Callable import cv2 import threading import gfpgan @@ -16,6 +16,17 @@ THREAD_LOCK = threading.Lock() NAME = 'ROOP.FACE-ENHANCER' +def get_face_enhancer() -> Any: + global FACE_ENHANCER + + with THREAD_LOCK: + if FACE_ENHANCER is None: + model_path = resolve_relative_path('../models/GFPGANv1.4.pth') + # todo: set models path https://github.com/TencentARC/GFPGAN/issues/399 + FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1) # type: ignore[attr-defined] + return FACE_ENHANCER + + def pre_check() -> bool: download_directory_path = resolve_relative_path('../models') conditional_download(download_directory_path, ['https://huggingface.co/henryruhs/roop/resolve/main/GFPGANv1.4.pth']) @@ -29,15 +40,10 @@ def pre_start() -> bool: return True -def get_face_enhancer() -> Any: +def post_process() -> None: global FACE_ENHANCER - with THREAD_LOCK: - if FACE_ENHANCER is None: - model_path = resolve_relative_path('../models/GFPGANv1.4.pth') - # todo: set models path https://github.com/TencentARC/GFPGAN/issues/399 - FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1) # type: ignore[attr-defined] - return FACE_ENHANCER + FACE_ENHANCER = None def enhance_face(temp_frame: Frame) -> Frame: @@ -56,13 +62,13 @@ def process_frame(source_face: Face, temp_frame: Frame) -> Frame: return temp_frame -def process_frames(source_path: str, temp_frame_paths: List[str], progress: Any = None) -> None: +def process_frames(source_path: str, temp_frame_paths: List[str], update: Callable[[], None]) -> None: for temp_frame_path in temp_frame_paths: temp_frame = cv2.imread(temp_frame_path) result = process_frame(None, temp_frame) cv2.imwrite(temp_frame_path, result) - if progress: - progress.update(1) + if update: + update() def process_image(source_path: str, target_path: str, output_path: str) -> None: diff --git a/roop/processors/frame/face_swapper.py b/roop/processors/frame/face_swapper.py index 35063d2..c53b5b8 100644 --- a/roop/processors/frame/face_swapper.py +++ b/roop/processors/frame/face_swapper.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import Any, List, Callable import cv2 import insightface import threading @@ -15,6 +15,16 @@ THREAD_LOCK = threading.Lock() NAME = 'ROOP.FACE-SWAPPER' +def get_face_swapper() -> Any: + global FACE_SWAPPER + + with THREAD_LOCK: + if FACE_SWAPPER is None: + model_path = resolve_relative_path('../models/inswapper_128.onnx') + FACE_SWAPPER = insightface.model_zoo.get_model(model_path, providers=roop.globals.execution_providers) + return FACE_SWAPPER + + def pre_check() -> bool: download_directory_path = resolve_relative_path('../models') conditional_download(download_directory_path, ['https://huggingface.co/henryruhs/roop/resolve/main/inswapper_128.onnx']) @@ -34,14 +44,10 @@ def pre_start() -> bool: return True -def get_face_swapper() -> Any: +def post_process() -> None: global FACE_SWAPPER - with THREAD_LOCK: - if FACE_SWAPPER is None: - model_path = resolve_relative_path('../models/inswapper_128.onnx') - FACE_SWAPPER = insightface.model_zoo.get_model(model_path, providers=roop.globals.execution_providers) - return FACE_SWAPPER + FACE_SWAPPER = None def swap_face(source_face: Face, target_face: Face, temp_frame: Frame) -> Frame: @@ -61,18 +67,14 @@ def process_frame(source_face: Face, temp_frame: Frame) -> Frame: return temp_frame -def process_frames(source_path: str, temp_frame_paths: List[str], progress: Any = None) -> None: +def process_frames(source_path: str, temp_frame_paths: List[str], update: Callable[[], None]) -> None: source_face = get_one_face(cv2.imread(source_path)) for temp_frame_path in temp_frame_paths: temp_frame = cv2.imread(temp_frame_path) - try: - result = process_frame(source_face, temp_frame) - cv2.imwrite(temp_frame_path, result) - except Exception as exception: - print(exception) - pass - if progress: - progress.update(1) + result = process_frame(source_face, temp_frame) + cv2.imwrite(temp_frame_path, result) + if update: + update() def process_image(source_path: str, target_path: str, output_path: str) -> None: diff --git a/roop/ui.json b/roop/ui.json index 752210c..4930991 100644 --- a/roop/ui.json +++ b/roop/ui.json @@ -153,6 +153,6 @@ } }, "RoopDonate": { - "text_color": ["gray74", "gray60"] + "text_color": ["#3a7ebf", "gray60"] } } diff --git a/roop/ui.py b/roop/ui.py index f413f20..ba693da 100644 --- a/roop/ui.py +++ b/roop/ui.py @@ -94,7 +94,7 @@ def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.C status_label = ctk.CTkLabel(root, text=None, justify='center') status_label.place(relx=0.1, rely=0.9, relwidth=0.8) - donate_label = ctk.CTkLabel(root, text='Become a GitHub Sponsor', justify='center', cursor='hand2') + donate_label = ctk.CTkLabel(root, text='^_^ Donate to project ^_^', justify='center', cursor='hand2') donate_label.place(relx=0.1, rely=0.95, relwidth=0.8) donate_label.configure(text_color=ctk.ThemeManager.theme.get('RoopDonate').get('text_color')) donate_label.bind('