Introduce dynamic frame procesors

This commit is contained in:
henryruhs 2023-06-13 14:52:21 +02:00
parent 96a403c98d
commit 889ca47de0
7 changed files with 158 additions and 132 deletions

View File

@ -2,6 +2,9 @@
import os
import sys
from roop.frame_processors.core import get_frame_processor_modules
# single thread doubles cuda performance - needs to be set before torch import
if any(arg.startswith('--execution-provider') for arg in sys.argv):
os.environ['OMP_NUM_THREADS'] = '1'
@ -23,10 +26,10 @@ import cv2
import roop.globals
import roop.ui as ui
import roop.swapper
import roop.enhancer
import roop.frame_processors.face_swapper
import roop.frame_processors.face_enhancer
from roop.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp
from roop.analyser import get_one_face
from roop.face_analyser import get_one_face
if 'ROCMExecutionProvider' in roop.globals.execution_providers:
del torch
@ -40,7 +43,7 @@ def parse_args() -> None:
parser.add_argument('-f', '--face', help='use a face image', dest='source_path')
parser.add_argument('-t', '--target', help='replace image or video with face', dest='target_path')
parser.add_argument('-o', '--output', help='save output to this file', dest='output_path')
parser.add_argument('--frame-processor', help='list of frame processors to run', dest='frame_processor', default=['face-swapper'], choices=['face-swapper', 'face-enhancer'], nargs='+')
parser.add_argument('--frame-processor', help='list of frame processors to run', dest='frame_processor', default=['face_swapper'], choices=['face_swapper', 'face_enhancer'], nargs='+')
parser.add_argument('--keep-fps', help='maintain original fps', dest='keep_fps', action='store_true', default=False)
parser.add_argument('--keep-audio', help='maintain original audio', dest='keep_audio', action='store_true', default=True)
parser.add_argument('--keep-frames', help='keep frames directory', dest='keep_frames', action='store_true', default=False)
@ -70,6 +73,10 @@ def parse_args() -> None:
roop.globals.execution_providers = decode_execution_providers(args.execution_provider)
roop.globals.execution_threads = args.execution_threads
# limit face enhancer to cuda
if 'CUDAExecutionProvider' not in roop.globals.execution_providers and 'face_enhancer' in roop.globals.frame_processors:
roop.globals.frame_processors.remove('face_enhancer')
def encode_execution_providers(execution_providers: List[str]) -> List[str]:
return [execution_provider.replace('ExecutionProvider', '').lower() for execution_provider in execution_providers]
@ -172,12 +179,11 @@ def start() -> None:
if has_image_extension(roop.globals.target_path):
if predict_image(roop.globals.target_path) > 0.85:
destroy()
if 'face-swapper' in roop.globals.frame_processors:
update_status('Swapping in progress...')
roop.swapper.process_image(roop.globals.source_path, roop.globals.target_path, roop.globals.output_path)
if 'CUDAExecutionProvider' in roop.globals.execution_providers and 'face-enhancer' in roop.globals.frame_processors:
update_status('Enhancing in progress...')
roop.enhancer.process_image(roop.globals.source_path, roop.globals.target_path, roop.globals.output_path)
for frame_processor in roop.globals.frame_processors:
update_status(f'{frame_processor} in progress...')
module = get_frame_processor_modules(frame_processor)
module.process_image(roop.globals.source_path, roop.globals.target_path, roop.globals.output_path)
release_resources()
if is_image(roop.globals.target_path):
update_status('Processing to image succeed!')
else:
@ -192,15 +198,10 @@ def start() -> None:
update_status('Extracting frames...')
extract_frames(roop.globals.target_path)
temp_frame_paths = get_temp_frame_paths(roop.globals.target_path)
if 'face-swapper' in roop.globals.frame_processors:
update_status('Swapping in progress...')
conditional_process_video(roop.globals.source_path, temp_frame_paths, roop.swapper.process_video)
release_resources()
# limit to one execution thread
roop.globals.execution_threads = 1
if 'CUDAExecutionProvider' in roop.globals.execution_providers and 'face-enhancer' in roop.globals.frame_processors:
update_status('Enhancing in progress...')
conditional_process_video(roop.globals.source_path, temp_frame_paths, roop.enhancer.process_video)
for frame_processor in roop.globals.frame_processors:
update_status(f'{frame_processor} in progress...')
module = get_frame_processor_modules(frame_processor)
conditional_process_video(roop.globals.source_path, temp_frame_paths, module.process_video)
release_resources()
if roop.globals.keep_fps:
update_status('Detecting fps...')
@ -234,10 +235,9 @@ def destroy() -> None:
def run() -> None:
parse_args()
pre_check()
if 'face-swapper' in roop.globals.frame_processors:
roop.swapper.pre_check()
if 'face-enhancer' in roop.globals.frame_processors:
roop.enhancer.pre_check()
for frame_processor in roop.globals.frame_processors:
module = get_frame_processor_modules(frame_processor)
module.pre_check()
limit_resources()
if roop.globals.headless:
start()

View File

View File

@ -0,0 +1,23 @@
import sys
import importlib
from typing import Any
import torch
import roop.globals
if 'ROCMExecutionProvider' in roop.globals.execution_providers:
del torch
FRAME_PROCESSOR_MODULES = None
def get_frame_processor_modules(frame_processor: str) -> Any:
global FRAME_PROCESSOR_MODULES
if not FRAME_PROCESSOR_MODULES:
try:
FRAME_PROCESSOR_MODULES = importlib.import_module(f'roop.frame_processors.{frame_processor}')
except ImportError:
sys.exit()
return FRAME_PROCESSOR_MODULES

View File

@ -18,14 +18,14 @@ THREAD_LOCK = threading.Lock()
def pre_check() -> None:
download_directory_path = resolve_relative_path('../models')
download_directory_path = resolve_relative_path('../../models')
conditional_download(download_directory_path, ['https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'])
def get_code_former():
global CODE_FORMER
with THREAD_LOCK:
model_path = resolve_relative_path('../models/codeformer.pth')
model_path = resolve_relative_path('../../models/codeformer.pth')
if CODE_FORMER is None:
model = torch.load(model_path)['params_ema']
CODE_FORMER = ARCH_REGISTRY.get('CodeFormer')(

View File

@ -5,7 +5,7 @@ import insightface
import threading
import roop.globals
from roop.analyser import get_one_face, get_many_faces
from roop.face_analyser import get_one_face, get_many_faces
from roop.utilities import conditional_download, resolve_relative_path
FACE_SWAPPER = None
@ -13,7 +13,7 @@ THREAD_LOCK = threading.Lock()
def pre_check() -> None:
download_directory_path = resolve_relative_path('../models')
download_directory_path = resolve_relative_path('../../models')
conditional_download(download_directory_path, ['https://huggingface.co/deepinsight/inswapper/resolve/main/inswapper_128.onnx'])
@ -22,7 +22,7 @@ def get_face_swapper() -> None:
with THREAD_LOCK:
if FACE_SWAPPER is None:
model_path = resolve_relative_path('../models/inswapper_128.onnx')
model_path = resolve_relative_path('../../models/inswapper_128.onnx')
FACE_SWAPPER = insightface.model_zoo.get_model(model_path, providers=roop.globals.execution_providers)
return FACE_SWAPPER

View File

@ -6,9 +6,9 @@ import cv2
from PIL import Image, ImageTk, ImageOps
import roop.globals
from roop.analyser import get_one_face
from roop.face_analyser import get_one_face
from roop.capturer import get_video_frame, get_video_frame_total
from roop.swapper import process_faces
from roop.frame_processors.core import get_frame_processor_modules
from roop.utilities import is_image, is_video, resolve_relative_path
WINDOW_HEIGHT = 700
@ -204,7 +204,10 @@ def init_preview() -> None:
def update_preview(frame_number: int = 0) -> None:
if roop.globals.source_path and roop.globals.target_path:
video_frame = process_faces(
for frame_processor in roop.globals.frame_processors:
module = get_frame_processor_modules(frame_processor)
module.process_image(roop.globals.source_path, roop.globals.target_path, roop.globals.output_path)
video_frame = module.process_faces(
get_one_face(cv2.imread(roop.globals.source_path)),
get_video_frame(roop.globals.target_path, frame_number)
)