* More accurate progress description

* Add thread lock to face analyser

* Use as_completed() for thread pool

* Show memory usage in progress bar

* Using Queue for dynamic thread processing

* Fix typing

* Introduce pick_quere() to allocate frames per future

* Bump version and add missing hook function

* Fix pick_queue()

* Introduce post process (#587)

* Introduce post_process to flush VRAM for example

* Delete frame processor instances

* Limit tensorflow usage to 1GB VRAM

* Set None instead of del

* Remove deprecated args

* Update gui preview

* Remove choices restriction from frame-processor and improve help output

* faithful donation label

* original donate button colors

* Introduce Frame processor xxx crashed

* ^_^ ^_^ ^_^ ^_^ ^_^

* Update GUI demo

---------

Co-authored-by: Somdev Sangwan <s0md3v@gmail.com>
This commit is contained in:
Henry Ruhs 2023-06-26 07:49:43 +02:00 committed by GitHub
parent b41149e4a2
commit 3d02b26766
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 113 additions and 102 deletions

View File

@ -15,7 +15,7 @@ jobs:
- run: pip install flake8 - run: pip install flake8
- run: pip install mypy - run: pip install mypy
- run: flake8 run.py roop - run: flake8 run.py roop
- run: mypy --config-file mypi.ini run.py roop - run: mypy run.py roop
test: test:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:

View File

@ -31,30 +31,21 @@ Additional command line arguments are given below. To learn out what they do, ch
``` ```
options: options:
-h, --help show this help message and exit -h, --help show this help message and exit
-s SOURCE_PATH, --source SOURCE_PATH -s SOURCE_PATH, --source SOURCE_PATH select an source image
select an source image -t TARGET_PATH, --target TARGET_PATH select an target image or video
-t TARGET_PATH, --target TARGET_PATH -o OUTPUT_PATH, --output OUTPUT_PATH select output file or directory
select an target image or video --frame-processor FRAME_PROCESSOR [FRAME_PROCESSOR ...] frame processors (choices: face_swapper, face_enhancer, ...)
-o OUTPUT_PATH, --output OUTPUT_PATH --keep-fps keep original fps
select output file or directory --keep-audio keep original audio
--frame-processor {face_swapper,face_enhancer} [{face_swapper,face_enhancer} ...] --keep-frames keep temporary frames
pipeline of frame processors --many-faces process every face
--keep-fps keep original fps --video-encoder {libx264,libx265,libvpx-vp9} adjust output video encoder
--keep-audio keep original audio --video-quality [0-51] adjust output video quality
--keep-frames keep temporary frames --max-memory MAX_MEMORY maximum amount of RAM in GB
--many-faces process every face --execution-provider {cpu} [{cpu} ...] available execution provider (choices: cpu, ...)
--video-encoder {libx264,libx265,libvpx-vp9} --execution-threads EXECUTION_THREADS number of execution threads
adjust output video encoder -v, --version show program's version number and exit
--video-quality VIDEO_QUALITY
adjust output video quality
--max-memory MAX_MEMORY
maximum amount of RAM in GB
--execution-provider {cpu,...} [{cpu,...} ...]
execution provider
--execution-threads EXECUTION_THREADS
number of execution threads
-v, --version show program's version number and exit
``` ```
Looking for a CLI mode? Using the -s/--source argument will make the run program in cli mode. Looking for a CLI mode? Using the -s/--source argument will make the run program in cli mode.

Binary file not shown.

Before

Width:  |  Height:  |  Size: 20 KiB

After

Width:  |  Height:  |  Size: 23 KiB

View File

View File

@ -33,11 +33,11 @@ warnings.filterwarnings('ignore', category=UserWarning, module='torchvision')
def parse_args() -> None: def parse_args() -> None:
signal.signal(signal.SIGINT, lambda signal_number, frame: destroy()) signal.signal(signal.SIGINT, lambda signal_number, frame: destroy())
program = argparse.ArgumentParser() program = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=100))
program.add_argument('-s', '--source', help='select an source image', dest='source_path') program.add_argument('-s', '--source', help='select an source image', dest='source_path')
program.add_argument('-t', '--target', help='select an target image or video', dest='target_path') program.add_argument('-t', '--target', help='select an target image or video', dest='target_path')
program.add_argument('-o', '--output', help='select output file or directory', dest='output_path') program.add_argument('-o', '--output', help='select output file or directory', dest='output_path')
program.add_argument('--frame-processor', help='pipeline of frame processors', dest='frame_processor', default=['face_swapper'], choices=['face_swapper', 'face_enhancer'], nargs='+') program.add_argument('--frame-processor', help='frame processors (choices: face_swapper, face_enhancer, ...)', dest='frame_processor', default=['face_swapper'], nargs='+')
program.add_argument('--keep-fps', help='keep original fps', dest='keep_fps', action='store_true', default=False) program.add_argument('--keep-fps', help='keep original fps', dest='keep_fps', action='store_true', default=False)
program.add_argument('--keep-audio', help='keep original audio', dest='keep_audio', action='store_true', default=True) program.add_argument('--keep-audio', help='keep original audio', dest='keep_audio', action='store_true', default=True)
program.add_argument('--keep-frames', help='keep temporary frames', dest='keep_frames', action='store_true', default=False) program.add_argument('--keep-frames', help='keep temporary frames', dest='keep_frames', action='store_true', default=False)
@ -45,16 +45,10 @@ def parse_args() -> None:
program.add_argument('--video-encoder', help='adjust output video encoder', dest='video_encoder', default='libx264', choices=['libx264', 'libx265', 'libvpx-vp9']) program.add_argument('--video-encoder', help='adjust output video encoder', dest='video_encoder', default='libx264', choices=['libx264', 'libx265', 'libvpx-vp9'])
program.add_argument('--video-quality', help='adjust output video quality', dest='video_quality', type=int, default=18, choices=range(52), metavar='[0-51]') program.add_argument('--video-quality', help='adjust output video quality', dest='video_quality', type=int, default=18, choices=range(52), metavar='[0-51]')
program.add_argument('--max-memory', help='maximum amount of RAM in GB', dest='max_memory', type=int, default=suggest_max_memory()) program.add_argument('--max-memory', help='maximum amount of RAM in GB', dest='max_memory', type=int, default=suggest_max_memory())
program.add_argument('--execution-provider', help='execution provider', dest='execution_provider', default=['cpu'], choices=suggest_execution_providers(), nargs='+') program.add_argument('--execution-provider', help='available execution provider (choices: cpu, ...)', dest='execution_provider', default=['cpu'], choices=suggest_execution_providers(), nargs='+')
program.add_argument('--execution-threads', help='number of execution threads', dest='execution_threads', type=int, default=suggest_execution_threads()) program.add_argument('--execution-threads', help='number of execution threads', dest='execution_threads', type=int, default=suggest_execution_threads())
program.add_argument('-v', '--version', action='version', version=f'{roop.metadata.name} {roop.metadata.version}') program.add_argument('-v', '--version', action='version', version=f'{roop.metadata.name} {roop.metadata.version}')
# register deprecated args
program.add_argument('-f', '--face', help=argparse.SUPPRESS, dest='source_path_deprecated')
program.add_argument('--cpu-cores', help=argparse.SUPPRESS, dest='cpu_cores_deprecated', type=int)
program.add_argument('--gpu-vendor', help=argparse.SUPPRESS, dest='gpu_vendor_deprecated')
program.add_argument('--gpu-threads', help=argparse.SUPPRESS, dest='gpu_threads_deprecated', type=int)
args = program.parse_args() args = program.parse_args()
roop.globals.source_path = args.source_path roop.globals.source_path = args.source_path
@ -72,27 +66,6 @@ def parse_args() -> None:
roop.globals.execution_providers = decode_execution_providers(args.execution_provider) roop.globals.execution_providers = decode_execution_providers(args.execution_provider)
roop.globals.execution_threads = args.execution_threads roop.globals.execution_threads = args.execution_threads
# translate deprecated args
if args.source_path_deprecated:
print('\033[33mArgument -f and --face are deprecated. Use -s and --source instead.\033[0m')
roop.globals.source_path = args.source_path_deprecated
roop.globals.output_path = normalize_output_path(args.source_path_deprecated, roop.globals.target_path, args.output_path)
if args.cpu_cores_deprecated:
print('\033[33mArgument --cpu-cores is deprecated. Use --execution-threads instead.\033[0m')
roop.globals.execution_threads = args.cpu_cores_deprecated
if args.gpu_vendor_deprecated == 'apple':
print('\033[33mArgument --gpu-vendor apple is deprecated. Use --execution-provider coreml instead.\033[0m')
roop.globals.execution_providers = decode_execution_providers(['coreml'])
if args.gpu_vendor_deprecated == 'nvidia':
print('\033[33mArgument --gpu-vendor nvidia is deprecated. Use --execution-provider cuda instead.\033[0m')
roop.globals.execution_providers = decode_execution_providers(['cuda'])
if args.gpu_vendor_deprecated == 'amd':
print('\033[33mArgument --gpu-vendor amd is deprecated. Use --execution-provider cuda instead.\033[0m')
roop.globals.execution_providers = decode_execution_providers(['rocm'])
if args.gpu_threads_deprecated:
print('\033[33mArgument --gpu-threads is deprecated. Use --execution-threads instead.\033[0m')
roop.globals.execution_threads = args.gpu_threads_deprecated
def encode_execution_providers(execution_providers: List[str]) -> List[str]: def encode_execution_providers(execution_providers: List[str]) -> List[str]:
return [execution_provider.replace('ExecutionProvider', '').lower() for execution_provider in execution_providers] return [execution_provider.replace('ExecutionProvider', '').lower() for execution_provider in execution_providers]
@ -125,7 +98,9 @@ def limit_resources() -> None:
# prevent tensorflow memory leak # prevent tensorflow memory leak
gpus = tensorflow.config.experimental.list_physical_devices('GPU') gpus = tensorflow.config.experimental.list_physical_devices('GPU')
for gpu in gpus: for gpu in gpus:
tensorflow.config.experimental.set_memory_growth(gpu, True) tensorflow.config.experimental.set_virtual_device_configuration(gpu, [
tensorflow.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)
])
# limit memory usage # limit memory usage
if roop.globals.max_memory: if roop.globals.max_memory:
memory = roop.globals.max_memory * 1024 ** 3 memory = roop.globals.max_memory * 1024 ** 3
@ -173,6 +148,7 @@ def start() -> None:
for frame_processor in get_frame_processors_modules(roop.globals.frame_processors): for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
update_status('Progressing...', frame_processor.NAME) update_status('Progressing...', frame_processor.NAME)
frame_processor.process_image(roop.globals.source_path, roop.globals.output_path, roop.globals.output_path) frame_processor.process_image(roop.globals.source_path, roop.globals.output_path, roop.globals.output_path)
frame_processor.post_process()
release_resources() release_resources()
if is_image(roop.globals.target_path): if is_image(roop.globals.target_path):
update_status('Processing to image succeed!') update_status('Processing to image succeed!')
@ -190,6 +166,7 @@ def start() -> None:
for frame_processor in get_frame_processors_modules(roop.globals.frame_processors): for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
update_status('Progressing...', frame_processor.NAME) update_status('Progressing...', frame_processor.NAME)
frame_processor.process_video(roop.globals.source_path, temp_frame_paths) frame_processor.process_video(roop.globals.source_path, temp_frame_paths)
frame_processor.post_process()
release_resources() release_resources()
# handles fps # handles fps
if roop.globals.keep_fps: if roop.globals.keep_fps:

View File

@ -1,3 +1,4 @@
import threading
from typing import Any from typing import Any
import insightface import insightface
@ -5,14 +6,16 @@ import roop.globals
from roop.typing import Frame from roop.typing import Frame
FACE_ANALYSER = None FACE_ANALYSER = None
THREAD_LOCK = threading.Lock()
def get_face_analyser() -> Any: def get_face_analyser() -> Any:
global FACE_ANALYSER global FACE_ANALYSER
if FACE_ANALYSER is None: with THREAD_LOCK:
FACE_ANALYSER = insightface.app.FaceAnalysis(name='buffalo_l', providers=roop.globals.execution_providers) if FACE_ANALYSER is None:
FACE_ANALYSER.prepare(ctx_id=0, det_size=(640, 640)) FACE_ANALYSER = insightface.app.FaceAnalysis(name='buffalo_l', providers=roop.globals.execution_providers)
FACE_ANALYSER.prepare(ctx_id=0, det_size=(640, 640))
return FACE_ANALYSER return FACE_ANALYSER

View File

@ -1,2 +1,2 @@
name = 'roop' name = 'roop'
version = '1.0.1' version = '1.1.0'

View File

@ -1,6 +1,8 @@
import sys import os
import importlib import importlib
from concurrent.futures import ThreadPoolExecutor import psutil
from concurrent.futures import ThreadPoolExecutor, as_completed
from queue import Queue
from types import ModuleType from types import ModuleType
from typing import Any, List, Callable from typing import Any, List, Callable
from tqdm import tqdm from tqdm import tqdm
@ -12,8 +14,10 @@ FRAME_PROCESSORS_INTERFACE = [
'pre_check', 'pre_check',
'pre_start', 'pre_start',
'process_frame', 'process_frame',
'process_frames',
'process_image', 'process_image',
'process_video' 'process_video',
'post_process'
] ]
@ -22,9 +26,9 @@ def load_frame_processor_module(frame_processor: str) -> Any:
frame_processor_module = importlib.import_module(f'roop.processors.frame.{frame_processor}') frame_processor_module = importlib.import_module(f'roop.processors.frame.{frame_processor}')
for method_name in FRAME_PROCESSORS_INTERFACE: for method_name in FRAME_PROCESSORS_INTERFACE:
if not hasattr(frame_processor_module, method_name): if not hasattr(frame_processor_module, method_name):
sys.exit() raise NotImplementedError
except ImportError: except (ImportError, NotImplementedError):
sys.exit() quit(f'Frame processor {frame_processor} crashed.')
return frame_processor_module return frame_processor_module
@ -38,19 +42,47 @@ def get_frame_processors_modules(frame_processors: List[str]) -> List[ModuleType
return FRAME_PROCESSORS_MODULES return FRAME_PROCESSORS_MODULES
def multi_process_frame(source_path: str, temp_frame_paths: List[str], process_frames: Callable[[str, List[str], Any], None], progress: Any = None) -> None: def multi_process_frame(source_path: str, temp_frame_paths: List[str], process_frames: Callable[[str, List[str], Any], None], update: Callable[[], None]) -> None:
with ThreadPoolExecutor(max_workers=roop.globals.execution_threads) as executor: with ThreadPoolExecutor(max_workers=roop.globals.execution_threads) as executor:
futures = [] futures = []
for path in temp_frame_paths: queue = create_queue(temp_frame_paths)
future = executor.submit(process_frames, source_path, [path], progress) queue_per_future = len(temp_frame_paths) // roop.globals.execution_threads
while not queue.empty():
future = executor.submit(process_frames, source_path, pick_queue(queue, queue_per_future), update)
futures.append(future) futures.append(future)
for future in futures: for future in as_completed(futures):
future.result() future.result()
def create_queue(temp_frame_paths: List[str]) -> Queue[str]:
queue: Queue[str] = Queue()
for frame_path in temp_frame_paths:
queue.put(frame_path)
return queue
def pick_queue(queue: Queue[str], queue_per_future: int) -> List[str]:
queues = []
for _ in range(queue_per_future):
if not queue.empty():
queues.append(queue.get())
return queues
def process_video(source_path: str, frame_paths: list[str], process_frames: Callable[[str, List[str], Any], None]) -> None: def process_video(source_path: str, frame_paths: list[str], process_frames: Callable[[str, List[str], Any], None]) -> None:
progress_bar_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]' progress_bar_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]'
total = len(frame_paths) total = len(frame_paths)
with tqdm(total=total, desc='Processing', unit='frame', dynamic_ncols=True, bar_format=progress_bar_format) as progress: with tqdm(total=total, desc='Processing', unit='frame', dynamic_ncols=True, bar_format=progress_bar_format) as progress:
progress.set_postfix({'execution_providers': roop.globals.execution_providers, 'threads': roop.globals.execution_threads, 'memory': roop.globals.max_memory}) multi_process_frame(source_path, frame_paths, process_frames, lambda: update_progress(progress))
multi_process_frame(source_path, frame_paths, process_frames, progress)
def update_progress(progress: Any = None) -> None:
process = psutil.Process(os.getpid())
memory_usage = process.memory_info().rss / 1024 / 1024 / 1024
progress.set_postfix({
'memory_usage': '{:.2f}'.format(memory_usage).zfill(5) + 'GB',
'execution_providers': roop.globals.execution_providers,
'execution_threads': roop.globals.execution_threads
})
progress.refresh()
progress.update(1)

View File

@ -1,4 +1,4 @@
from typing import Any, List from typing import Any, List, Callable
import cv2 import cv2
import threading import threading
import gfpgan import gfpgan
@ -16,6 +16,17 @@ THREAD_LOCK = threading.Lock()
NAME = 'ROOP.FACE-ENHANCER' NAME = 'ROOP.FACE-ENHANCER'
def get_face_enhancer() -> Any:
global FACE_ENHANCER
with THREAD_LOCK:
if FACE_ENHANCER is None:
model_path = resolve_relative_path('../models/GFPGANv1.4.pth')
# todo: set models path https://github.com/TencentARC/GFPGAN/issues/399
FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1) # type: ignore[attr-defined]
return FACE_ENHANCER
def pre_check() -> bool: def pre_check() -> bool:
download_directory_path = resolve_relative_path('../models') download_directory_path = resolve_relative_path('../models')
conditional_download(download_directory_path, ['https://huggingface.co/henryruhs/roop/resolve/main/GFPGANv1.4.pth']) conditional_download(download_directory_path, ['https://huggingface.co/henryruhs/roop/resolve/main/GFPGANv1.4.pth'])
@ -29,15 +40,10 @@ def pre_start() -> bool:
return True return True
def get_face_enhancer() -> Any: def post_process() -> None:
global FACE_ENHANCER global FACE_ENHANCER
with THREAD_LOCK: FACE_ENHANCER = None
if FACE_ENHANCER is None:
model_path = resolve_relative_path('../models/GFPGANv1.4.pth')
# todo: set models path https://github.com/TencentARC/GFPGAN/issues/399
FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1) # type: ignore[attr-defined]
return FACE_ENHANCER
def enhance_face(temp_frame: Frame) -> Frame: def enhance_face(temp_frame: Frame) -> Frame:
@ -56,13 +62,13 @@ def process_frame(source_face: Face, temp_frame: Frame) -> Frame:
return temp_frame return temp_frame
def process_frames(source_path: str, temp_frame_paths: List[str], progress: Any = None) -> None: def process_frames(source_path: str, temp_frame_paths: List[str], update: Callable[[], None]) -> None:
for temp_frame_path in temp_frame_paths: for temp_frame_path in temp_frame_paths:
temp_frame = cv2.imread(temp_frame_path) temp_frame = cv2.imread(temp_frame_path)
result = process_frame(None, temp_frame) result = process_frame(None, temp_frame)
cv2.imwrite(temp_frame_path, result) cv2.imwrite(temp_frame_path, result)
if progress: if update:
progress.update(1) update()
def process_image(source_path: str, target_path: str, output_path: str) -> None: def process_image(source_path: str, target_path: str, output_path: str) -> None:

View File

@ -1,4 +1,4 @@
from typing import Any, List from typing import Any, List, Callable
import cv2 import cv2
import insightface import insightface
import threading import threading
@ -15,6 +15,16 @@ THREAD_LOCK = threading.Lock()
NAME = 'ROOP.FACE-SWAPPER' NAME = 'ROOP.FACE-SWAPPER'
def get_face_swapper() -> Any:
global FACE_SWAPPER
with THREAD_LOCK:
if FACE_SWAPPER is None:
model_path = resolve_relative_path('../models/inswapper_128.onnx')
FACE_SWAPPER = insightface.model_zoo.get_model(model_path, providers=roop.globals.execution_providers)
return FACE_SWAPPER
def pre_check() -> bool: def pre_check() -> bool:
download_directory_path = resolve_relative_path('../models') download_directory_path = resolve_relative_path('../models')
conditional_download(download_directory_path, ['https://huggingface.co/henryruhs/roop/resolve/main/inswapper_128.onnx']) conditional_download(download_directory_path, ['https://huggingface.co/henryruhs/roop/resolve/main/inswapper_128.onnx'])
@ -34,14 +44,10 @@ def pre_start() -> bool:
return True return True
def get_face_swapper() -> Any: def post_process() -> None:
global FACE_SWAPPER global FACE_SWAPPER
with THREAD_LOCK: FACE_SWAPPER = None
if FACE_SWAPPER is None:
model_path = resolve_relative_path('../models/inswapper_128.onnx')
FACE_SWAPPER = insightface.model_zoo.get_model(model_path, providers=roop.globals.execution_providers)
return FACE_SWAPPER
def swap_face(source_face: Face, target_face: Face, temp_frame: Frame) -> Frame: def swap_face(source_face: Face, target_face: Face, temp_frame: Frame) -> Frame:
@ -61,18 +67,14 @@ def process_frame(source_face: Face, temp_frame: Frame) -> Frame:
return temp_frame return temp_frame
def process_frames(source_path: str, temp_frame_paths: List[str], progress: Any = None) -> None: def process_frames(source_path: str, temp_frame_paths: List[str], update: Callable[[], None]) -> None:
source_face = get_one_face(cv2.imread(source_path)) source_face = get_one_face(cv2.imread(source_path))
for temp_frame_path in temp_frame_paths: for temp_frame_path in temp_frame_paths:
temp_frame = cv2.imread(temp_frame_path) temp_frame = cv2.imread(temp_frame_path)
try: result = process_frame(source_face, temp_frame)
result = process_frame(source_face, temp_frame) cv2.imwrite(temp_frame_path, result)
cv2.imwrite(temp_frame_path, result) if update:
except Exception as exception: update()
print(exception)
pass
if progress:
progress.update(1)
def process_image(source_path: str, target_path: str, output_path: str) -> None: def process_image(source_path: str, target_path: str, output_path: str) -> None:

View File

@ -153,6 +153,6 @@
} }
}, },
"RoopDonate": { "RoopDonate": {
"text_color": ["gray74", "gray60"] "text_color": ["#3a7ebf", "gray60"]
} }
} }

View File

@ -94,7 +94,7 @@ def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.C
status_label = ctk.CTkLabel(root, text=None, justify='center') status_label = ctk.CTkLabel(root, text=None, justify='center')
status_label.place(relx=0.1, rely=0.9, relwidth=0.8) status_label.place(relx=0.1, rely=0.9, relwidth=0.8)
donate_label = ctk.CTkLabel(root, text='Become a GitHub Sponsor', justify='center', cursor='hand2') donate_label = ctk.CTkLabel(root, text='^_^ Donate to project ^_^', justify='center', cursor='hand2')
donate_label.place(relx=0.1, rely=0.95, relwidth=0.8) donate_label.place(relx=0.1, rely=0.95, relwidth=0.8)
donate_label.configure(text_color=ctk.ThemeManager.theme.get('RoopDonate').get('text_color')) donate_label.configure(text_color=ctk.ThemeManager.theme.get('RoopDonate').get('text_color'))
donate_label.bind('<Button>', lambda event: webbrowser.open('https://github.com/sponsors/s0md3v')) donate_label.bind('<Button>', lambda event: webbrowser.open('https://github.com/sponsors/s0md3v'))

View File

@ -108,7 +108,7 @@ def clean_temp(target_path: str) -> None:
def has_image_extension(image_path: str) -> bool: def has_image_extension(image_path: str) -> bool:
return image_path.lower().endswith(('png', 'jpg', 'jpeg')) return image_path.lower().endswith(('png', 'jpg', 'jpeg', 'webp'))
def is_image(image_path: str) -> bool: def is_image(image_path: str) -> bool: