Introduce dynamic frame procesors

This commit is contained in:
henryruhs 2023-06-13 14:52:21 +02:00
parent 96a403c98d
commit 889ca47de0
7 changed files with 158 additions and 132 deletions

View File

@ -2,6 +2,9 @@
import os
import sys
from roop.frame_processors.core import get_frame_processor_modules
# single thread doubles cuda performance - needs to be set before torch import
if any(arg.startswith('--execution-provider') for arg in sys.argv):
os.environ['OMP_NUM_THREADS'] = '1'
@ -23,10 +26,10 @@ import cv2
import roop.globals
import roop.ui as ui
import roop.swapper
import roop.enhancer
import roop.frame_processors.face_swapper
import roop.frame_processors.face_enhancer
from roop.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp
from roop.analyser import get_one_face
from roop.face_analyser import get_one_face
if 'ROCMExecutionProvider' in roop.globals.execution_providers:
del torch
@ -40,7 +43,7 @@ def parse_args() -> None:
parser.add_argument('-f', '--face', help='use a face image', dest='source_path')
parser.add_argument('-t', '--target', help='replace image or video with face', dest='target_path')
parser.add_argument('-o', '--output', help='save output to this file', dest='output_path')
parser.add_argument('--frame-processor', help='list of frame processors to run', dest='frame_processor', default=['face-swapper'], choices=['face-swapper', 'face-enhancer'], nargs='+')
parser.add_argument('--frame-processor', help='list of frame processors to run', dest='frame_processor', default=['face_swapper'], choices=['face_swapper', 'face_enhancer'], nargs='+')
parser.add_argument('--keep-fps', help='maintain original fps', dest='keep_fps', action='store_true', default=False)
parser.add_argument('--keep-audio', help='maintain original audio', dest='keep_audio', action='store_true', default=True)
parser.add_argument('--keep-frames', help='keep frames directory', dest='keep_frames', action='store_true', default=False)
@ -70,6 +73,10 @@ def parse_args() -> None:
roop.globals.execution_providers = decode_execution_providers(args.execution_provider)
roop.globals.execution_threads = args.execution_threads
# limit face enhancer to cuda
if 'CUDAExecutionProvider' not in roop.globals.execution_providers and 'face_enhancer' in roop.globals.frame_processors:
roop.globals.frame_processors.remove('face_enhancer')
def encode_execution_providers(execution_providers: List[str]) -> List[str]:
return [execution_provider.replace('ExecutionProvider', '').lower() for execution_provider in execution_providers]
@ -172,12 +179,11 @@ def start() -> None:
if has_image_extension(roop.globals.target_path):
if predict_image(roop.globals.target_path) > 0.85:
destroy()
if 'face-swapper' in roop.globals.frame_processors:
update_status('Swapping in progress...')
roop.swapper.process_image(roop.globals.source_path, roop.globals.target_path, roop.globals.output_path)
if 'CUDAExecutionProvider' in roop.globals.execution_providers and 'face-enhancer' in roop.globals.frame_processors:
update_status('Enhancing in progress...')
roop.enhancer.process_image(roop.globals.source_path, roop.globals.target_path, roop.globals.output_path)
for frame_processor in roop.globals.frame_processors:
update_status(f'{frame_processor} in progress...')
module = get_frame_processor_modules(frame_processor)
module.process_image(roop.globals.source_path, roop.globals.target_path, roop.globals.output_path)
release_resources()
if is_image(roop.globals.target_path):
update_status('Processing to image succeed!')
else:
@ -192,16 +198,11 @@ def start() -> None:
update_status('Extracting frames...')
extract_frames(roop.globals.target_path)
temp_frame_paths = get_temp_frame_paths(roop.globals.target_path)
if 'face-swapper' in roop.globals.frame_processors:
update_status('Swapping in progress...')
conditional_process_video(roop.globals.source_path, temp_frame_paths, roop.swapper.process_video)
release_resources()
# limit to one execution thread
roop.globals.execution_threads = 1
if 'CUDAExecutionProvider' in roop.globals.execution_providers and 'face-enhancer' in roop.globals.frame_processors:
update_status('Enhancing in progress...')
conditional_process_video(roop.globals.source_path, temp_frame_paths, roop.enhancer.process_video)
release_resources()
for frame_processor in roop.globals.frame_processors:
update_status(f'{frame_processor} in progress...')
module = get_frame_processor_modules(frame_processor)
conditional_process_video(roop.globals.source_path, temp_frame_paths, module.process_video)
release_resources()
if roop.globals.keep_fps:
update_status('Detecting fps...')
fps = detect_fps(roop.globals.target_path)
@ -234,10 +235,9 @@ def destroy() -> None:
def run() -> None:
parse_args()
pre_check()
if 'face-swapper' in roop.globals.frame_processors:
roop.swapper.pre_check()
if 'face-enhancer' in roop.globals.frame_processors:
roop.enhancer.pre_check()
for frame_processor in roop.globals.frame_processors:
module = get_frame_processor_modules(frame_processor)
module.pre_check()
limit_resources()
if roop.globals.headless:
start()

View File

View File

@ -0,0 +1,23 @@
import sys
import importlib
from typing import Any
import torch
import roop.globals
if 'ROCMExecutionProvider' in roop.globals.execution_providers:
del torch
FRAME_PROCESSOR_MODULES = None
def get_frame_processor_modules(frame_processor: str) -> Any:
global FRAME_PROCESSOR_MODULES
if not FRAME_PROCESSOR_MODULES:
try:
FRAME_PROCESSOR_MODULES = importlib.import_module(f'roop.frame_processors.{frame_processor}')
except ImportError:
sys.exit()
return FRAME_PROCESSOR_MODULES

View File

@ -18,14 +18,14 @@ THREAD_LOCK = threading.Lock()
def pre_check() -> None:
download_directory_path = resolve_relative_path('../models')
download_directory_path = resolve_relative_path('../../models')
conditional_download(download_directory_path, ['https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'])
def get_code_former():
global CODE_FORMER
with THREAD_LOCK:
model_path = resolve_relative_path('../models/codeformer.pth')
model_path = resolve_relative_path('../../models/codeformer.pth')
if CODE_FORMER is None:
model = torch.load(model_path)['params_ema']
CODE_FORMER = ARCH_REGISTRY.get('CodeFormer')(

View File

@ -1,100 +1,100 @@
from typing import Any, List
from tqdm import tqdm
import cv2
import insightface
import threading
import roop.globals
from roop.analyser import get_one_face, get_many_faces
from roop.utilities import conditional_download, resolve_relative_path
FACE_SWAPPER = None
THREAD_LOCK = threading.Lock()
def pre_check() -> None:
download_directory_path = resolve_relative_path('../models')
conditional_download(download_directory_path, ['https://huggingface.co/deepinsight/inswapper/resolve/main/inswapper_128.onnx'])
def get_face_swapper() -> None:
global FACE_SWAPPER
with THREAD_LOCK:
if FACE_SWAPPER is None:
model_path = resolve_relative_path('../models/inswapper_128.onnx')
FACE_SWAPPER = insightface.model_zoo.get_model(model_path, providers=roop.globals.execution_providers)
return FACE_SWAPPER
def swap_face(source_face: Any, target_face: Any, temp_frame: Any) -> Any:
if target_face:
return get_face_swapper().get(temp_frame, target_face, source_face, paste_back=True)
return temp_frame
def process_faces(source_face: Any, temp_frame: Any) -> Any:
if roop.globals.many_faces:
many_faces = get_many_faces(temp_frame)
if many_faces:
for target_face in many_faces:
temp_frame = swap_face(source_face, target_face, temp_frame)
else:
target_face = get_one_face(temp_frame)
if target_face:
temp_frame = swap_face(source_face, target_face, temp_frame)
return temp_frame
def process_frames(source_path: str, temp_frame_paths: List[str], progress=None) -> None:
source_face = get_one_face(cv2.imread(source_path))
for temp_frame_path in temp_frame_paths:
temp_frame = cv2.imread(temp_frame_path)
try:
result = process_faces(source_face, temp_frame)
cv2.imwrite(temp_frame_path, result)
except Exception as exception:
print(exception)
pass
if progress:
progress.update(1)
def multi_process_frame(source_path: str, temp_frame_paths: List[str], progress) -> None:
threads = []
frames_per_thread = len(temp_frame_paths) // roop.globals.execution_threads
remaining_frames = len(temp_frame_paths) % roop.globals.execution_threads
start_index = 0
# create threads by frames
for _ in range(roop.globals.execution_threads):
end_index = start_index + frames_per_thread
if remaining_frames > 0:
end_index += 1
remaining_frames -= 1
thread_paths = temp_frame_paths[start_index:end_index]
thread = threading.Thread(target=process_frames, args=(source_path, thread_paths, progress))
threads.append(thread)
thread.start()
start_index = end_index
# join threads
for thread in threads:
thread.join()
def process_image(source_path: str, target_path: str, output_path: str) -> None:
source_face = get_one_face(cv2.imread(source_path))
target_frame = cv2.imread(target_path)
result = process_faces(source_face, target_frame)
cv2.imwrite(output_path, result)
def process_video(source_path: str, temp_frame_paths: List[str], mode: str) -> None:
progress_bar_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]'
total = len(temp_frame_paths)
with tqdm(total=total, desc='Processing', unit='frame', dynamic_ncols=True, bar_format=progress_bar_format) as progress:
if mode == 'multi-processing':
progress.set_postfix({'mode': mode, 'cores': roop.globals.cpu_cores, 'memory': roop.globals.max_memory})
process_frames(source_path, temp_frame_paths, progress)
elif mode == 'multi-threading':
progress.set_postfix({'mode': mode, 'threads': roop.globals.execution_threads, 'memory': roop.globals.max_memory})
multi_process_frame(source_path, temp_frame_paths, progress)
from typing import Any, List
from tqdm import tqdm
import cv2
import insightface
import threading
import roop.globals
from roop.face_analyser import get_one_face, get_many_faces
from roop.utilities import conditional_download, resolve_relative_path
FACE_SWAPPER = None
THREAD_LOCK = threading.Lock()
def pre_check() -> None:
download_directory_path = resolve_relative_path('../../models')
conditional_download(download_directory_path, ['https://huggingface.co/deepinsight/inswapper/resolve/main/inswapper_128.onnx'])
def get_face_swapper() -> None:
global FACE_SWAPPER
with THREAD_LOCK:
if FACE_SWAPPER is None:
model_path = resolve_relative_path('../../models/inswapper_128.onnx')
FACE_SWAPPER = insightface.model_zoo.get_model(model_path, providers=roop.globals.execution_providers)
return FACE_SWAPPER
def swap_face(source_face: Any, target_face: Any, temp_frame: Any) -> Any:
if target_face:
return get_face_swapper().get(temp_frame, target_face, source_face, paste_back=True)
return temp_frame
def process_faces(source_face: Any, temp_frame: Any) -> Any:
if roop.globals.many_faces:
many_faces = get_many_faces(temp_frame)
if many_faces:
for target_face in many_faces:
temp_frame = swap_face(source_face, target_face, temp_frame)
else:
target_face = get_one_face(temp_frame)
if target_face:
temp_frame = swap_face(source_face, target_face, temp_frame)
return temp_frame
def process_frames(source_path: str, temp_frame_paths: List[str], progress=None) -> None:
source_face = get_one_face(cv2.imread(source_path))
for temp_frame_path in temp_frame_paths:
temp_frame = cv2.imread(temp_frame_path)
try:
result = process_faces(source_face, temp_frame)
cv2.imwrite(temp_frame_path, result)
except Exception as exception:
print(exception)
pass
if progress:
progress.update(1)
def multi_process_frame(source_path: str, temp_frame_paths: List[str], progress) -> None:
threads = []
frames_per_thread = len(temp_frame_paths) // roop.globals.execution_threads
remaining_frames = len(temp_frame_paths) % roop.globals.execution_threads
start_index = 0
# create threads by frames
for _ in range(roop.globals.execution_threads):
end_index = start_index + frames_per_thread
if remaining_frames > 0:
end_index += 1
remaining_frames -= 1
thread_paths = temp_frame_paths[start_index:end_index]
thread = threading.Thread(target=process_frames, args=(source_path, thread_paths, progress))
threads.append(thread)
thread.start()
start_index = end_index
# join threads
for thread in threads:
thread.join()
def process_image(source_path: str, target_path: str, output_path: str) -> None:
source_face = get_one_face(cv2.imread(source_path))
target_frame = cv2.imread(target_path)
result = process_faces(source_face, target_frame)
cv2.imwrite(output_path, result)
def process_video(source_path: str, temp_frame_paths: List[str], mode: str) -> None:
progress_bar_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]'
total = len(temp_frame_paths)
with tqdm(total=total, desc='Processing', unit='frame', dynamic_ncols=True, bar_format=progress_bar_format) as progress:
if mode == 'multi-processing':
progress.set_postfix({'mode': mode, 'cores': roop.globals.cpu_cores, 'memory': roop.globals.max_memory})
process_frames(source_path, temp_frame_paths, progress)
elif mode == 'multi-threading':
progress.set_postfix({'mode': mode, 'threads': roop.globals.execution_threads, 'memory': roop.globals.max_memory})
multi_process_frame(source_path, temp_frame_paths, progress)

View File

@ -6,9 +6,9 @@ import cv2
from PIL import Image, ImageTk, ImageOps
import roop.globals
from roop.analyser import get_one_face
from roop.face_analyser import get_one_face
from roop.capturer import get_video_frame, get_video_frame_total
from roop.swapper import process_faces
from roop.frame_processors.core import get_frame_processor_modules
from roop.utilities import is_image, is_video, resolve_relative_path
WINDOW_HEIGHT = 700
@ -204,10 +204,13 @@ def init_preview() -> None:
def update_preview(frame_number: int = 0) -> None:
if roop.globals.source_path and roop.globals.target_path:
video_frame = process_faces(
get_one_face(cv2.imread(roop.globals.source_path)),
get_video_frame(roop.globals.target_path, frame_number)
)
for frame_processor in roop.globals.frame_processors:
module = get_frame_processor_modules(frame_processor)
module.process_image(roop.globals.source_path, roop.globals.target_path, roop.globals.output_path)
video_frame = module.process_faces(
get_one_face(cv2.imread(roop.globals.source_path)),
get_video_frame(roop.globals.target_path, frame_number)
)
image = Image.fromarray(cv2.cvtColor(video_frame, cv2.COLOR_BGR2RGB))
image = ImageOps.contain(image, (PREVIEW_MAX_WIDTH, PREVIEW_MAX_HEIGHT), Image.LANCZOS)
image = ImageTk.PhotoImage(image)