mirror of
https://github.com/s0md3v/roop.git
synced 2025-12-06 18:08:29 +00:00
Introduce dynamic frame procesors
This commit is contained in:
parent
96a403c98d
commit
889ca47de0
48
roop/core.py
48
roop/core.py
@ -2,6 +2,9 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from roop.frame_processors.core import get_frame_processor_modules
|
||||
|
||||
# single thread doubles cuda performance - needs to be set before torch import
|
||||
if any(arg.startswith('--execution-provider') for arg in sys.argv):
|
||||
os.environ['OMP_NUM_THREADS'] = '1'
|
||||
@ -23,10 +26,10 @@ import cv2
|
||||
|
||||
import roop.globals
|
||||
import roop.ui as ui
|
||||
import roop.swapper
|
||||
import roop.enhancer
|
||||
import roop.frame_processors.face_swapper
|
||||
import roop.frame_processors.face_enhancer
|
||||
from roop.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp
|
||||
from roop.analyser import get_one_face
|
||||
from roop.face_analyser import get_one_face
|
||||
|
||||
if 'ROCMExecutionProvider' in roop.globals.execution_providers:
|
||||
del torch
|
||||
@ -40,7 +43,7 @@ def parse_args() -> None:
|
||||
parser.add_argument('-f', '--face', help='use a face image', dest='source_path')
|
||||
parser.add_argument('-t', '--target', help='replace image or video with face', dest='target_path')
|
||||
parser.add_argument('-o', '--output', help='save output to this file', dest='output_path')
|
||||
parser.add_argument('--frame-processor', help='list of frame processors to run', dest='frame_processor', default=['face-swapper'], choices=['face-swapper', 'face-enhancer'], nargs='+')
|
||||
parser.add_argument('--frame-processor', help='list of frame processors to run', dest='frame_processor', default=['face_swapper'], choices=['face_swapper', 'face_enhancer'], nargs='+')
|
||||
parser.add_argument('--keep-fps', help='maintain original fps', dest='keep_fps', action='store_true', default=False)
|
||||
parser.add_argument('--keep-audio', help='maintain original audio', dest='keep_audio', action='store_true', default=True)
|
||||
parser.add_argument('--keep-frames', help='keep frames directory', dest='keep_frames', action='store_true', default=False)
|
||||
@ -70,6 +73,10 @@ def parse_args() -> None:
|
||||
roop.globals.execution_providers = decode_execution_providers(args.execution_provider)
|
||||
roop.globals.execution_threads = args.execution_threads
|
||||
|
||||
# limit face enhancer to cuda
|
||||
if 'CUDAExecutionProvider' not in roop.globals.execution_providers and 'face_enhancer' in roop.globals.frame_processors:
|
||||
roop.globals.frame_processors.remove('face_enhancer')
|
||||
|
||||
|
||||
def encode_execution_providers(execution_providers: List[str]) -> List[str]:
|
||||
return [execution_provider.replace('ExecutionProvider', '').lower() for execution_provider in execution_providers]
|
||||
@ -172,12 +179,11 @@ def start() -> None:
|
||||
if has_image_extension(roop.globals.target_path):
|
||||
if predict_image(roop.globals.target_path) > 0.85:
|
||||
destroy()
|
||||
if 'face-swapper' in roop.globals.frame_processors:
|
||||
update_status('Swapping in progress...')
|
||||
roop.swapper.process_image(roop.globals.source_path, roop.globals.target_path, roop.globals.output_path)
|
||||
if 'CUDAExecutionProvider' in roop.globals.execution_providers and 'face-enhancer' in roop.globals.frame_processors:
|
||||
update_status('Enhancing in progress...')
|
||||
roop.enhancer.process_image(roop.globals.source_path, roop.globals.target_path, roop.globals.output_path)
|
||||
for frame_processor in roop.globals.frame_processors:
|
||||
update_status(f'{frame_processor} in progress...')
|
||||
module = get_frame_processor_modules(frame_processor)
|
||||
module.process_image(roop.globals.source_path, roop.globals.target_path, roop.globals.output_path)
|
||||
release_resources()
|
||||
if is_image(roop.globals.target_path):
|
||||
update_status('Processing to image succeed!')
|
||||
else:
|
||||
@ -192,16 +198,11 @@ def start() -> None:
|
||||
update_status('Extracting frames...')
|
||||
extract_frames(roop.globals.target_path)
|
||||
temp_frame_paths = get_temp_frame_paths(roop.globals.target_path)
|
||||
if 'face-swapper' in roop.globals.frame_processors:
|
||||
update_status('Swapping in progress...')
|
||||
conditional_process_video(roop.globals.source_path, temp_frame_paths, roop.swapper.process_video)
|
||||
release_resources()
|
||||
# limit to one execution thread
|
||||
roop.globals.execution_threads = 1
|
||||
if 'CUDAExecutionProvider' in roop.globals.execution_providers and 'face-enhancer' in roop.globals.frame_processors:
|
||||
update_status('Enhancing in progress...')
|
||||
conditional_process_video(roop.globals.source_path, temp_frame_paths, roop.enhancer.process_video)
|
||||
release_resources()
|
||||
for frame_processor in roop.globals.frame_processors:
|
||||
update_status(f'{frame_processor} in progress...')
|
||||
module = get_frame_processor_modules(frame_processor)
|
||||
conditional_process_video(roop.globals.source_path, temp_frame_paths, module.process_video)
|
||||
release_resources()
|
||||
if roop.globals.keep_fps:
|
||||
update_status('Detecting fps...')
|
||||
fps = detect_fps(roop.globals.target_path)
|
||||
@ -234,10 +235,9 @@ def destroy() -> None:
|
||||
def run() -> None:
|
||||
parse_args()
|
||||
pre_check()
|
||||
if 'face-swapper' in roop.globals.frame_processors:
|
||||
roop.swapper.pre_check()
|
||||
if 'face-enhancer' in roop.globals.frame_processors:
|
||||
roop.enhancer.pre_check()
|
||||
for frame_processor in roop.globals.frame_processors:
|
||||
module = get_frame_processor_modules(frame_processor)
|
||||
module.pre_check()
|
||||
limit_resources()
|
||||
if roop.globals.headless:
|
||||
start()
|
||||
|
||||
0
roop/frame_processors/__init__.py
Normal file
0
roop/frame_processors/__init__.py
Normal file
23
roop/frame_processors/core.py
Normal file
23
roop/frame_processors/core.py
Normal file
@ -0,0 +1,23 @@
|
||||
import sys
|
||||
import importlib
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
import roop.globals
|
||||
|
||||
if 'ROCMExecutionProvider' in roop.globals.execution_providers:
|
||||
del torch
|
||||
|
||||
FRAME_PROCESSOR_MODULES = None
|
||||
|
||||
|
||||
def get_frame_processor_modules(frame_processor: str) -> Any:
|
||||
global FRAME_PROCESSOR_MODULES
|
||||
|
||||
if not FRAME_PROCESSOR_MODULES:
|
||||
try:
|
||||
FRAME_PROCESSOR_MODULES = importlib.import_module(f'roop.frame_processors.{frame_processor}')
|
||||
except ImportError:
|
||||
sys.exit()
|
||||
return FRAME_PROCESSOR_MODULES
|
||||
@ -18,14 +18,14 @@ THREAD_LOCK = threading.Lock()
|
||||
|
||||
|
||||
def pre_check() -> None:
|
||||
download_directory_path = resolve_relative_path('../models')
|
||||
download_directory_path = resolve_relative_path('../../models')
|
||||
conditional_download(download_directory_path, ['https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'])
|
||||
|
||||
|
||||
def get_code_former():
|
||||
global CODE_FORMER
|
||||
with THREAD_LOCK:
|
||||
model_path = resolve_relative_path('../models/codeformer.pth')
|
||||
model_path = resolve_relative_path('../../models/codeformer.pth')
|
||||
if CODE_FORMER is None:
|
||||
model = torch.load(model_path)['params_ema']
|
||||
CODE_FORMER = ARCH_REGISTRY.get('CodeFormer')(
|
||||
@ -1,100 +1,100 @@
|
||||
from typing import Any, List
|
||||
from tqdm import tqdm
|
||||
import cv2
|
||||
import insightface
|
||||
import threading
|
||||
|
||||
import roop.globals
|
||||
from roop.analyser import get_one_face, get_many_faces
|
||||
from roop.utilities import conditional_download, resolve_relative_path
|
||||
|
||||
FACE_SWAPPER = None
|
||||
THREAD_LOCK = threading.Lock()
|
||||
|
||||
|
||||
def pre_check() -> None:
|
||||
download_directory_path = resolve_relative_path('../models')
|
||||
conditional_download(download_directory_path, ['https://huggingface.co/deepinsight/inswapper/resolve/main/inswapper_128.onnx'])
|
||||
|
||||
|
||||
def get_face_swapper() -> None:
|
||||
global FACE_SWAPPER
|
||||
|
||||
with THREAD_LOCK:
|
||||
if FACE_SWAPPER is None:
|
||||
model_path = resolve_relative_path('../models/inswapper_128.onnx')
|
||||
FACE_SWAPPER = insightface.model_zoo.get_model(model_path, providers=roop.globals.execution_providers)
|
||||
return FACE_SWAPPER
|
||||
|
||||
|
||||
def swap_face(source_face: Any, target_face: Any, temp_frame: Any) -> Any:
|
||||
if target_face:
|
||||
return get_face_swapper().get(temp_frame, target_face, source_face, paste_back=True)
|
||||
return temp_frame
|
||||
|
||||
|
||||
def process_faces(source_face: Any, temp_frame: Any) -> Any:
|
||||
if roop.globals.many_faces:
|
||||
many_faces = get_many_faces(temp_frame)
|
||||
if many_faces:
|
||||
for target_face in many_faces:
|
||||
temp_frame = swap_face(source_face, target_face, temp_frame)
|
||||
else:
|
||||
target_face = get_one_face(temp_frame)
|
||||
if target_face:
|
||||
temp_frame = swap_face(source_face, target_face, temp_frame)
|
||||
return temp_frame
|
||||
|
||||
|
||||
def process_frames(source_path: str, temp_frame_paths: List[str], progress=None) -> None:
|
||||
source_face = get_one_face(cv2.imread(source_path))
|
||||
for temp_frame_path in temp_frame_paths:
|
||||
temp_frame = cv2.imread(temp_frame_path)
|
||||
try:
|
||||
result = process_faces(source_face, temp_frame)
|
||||
cv2.imwrite(temp_frame_path, result)
|
||||
except Exception as exception:
|
||||
print(exception)
|
||||
pass
|
||||
if progress:
|
||||
progress.update(1)
|
||||
|
||||
|
||||
def multi_process_frame(source_path: str, temp_frame_paths: List[str], progress) -> None:
|
||||
threads = []
|
||||
frames_per_thread = len(temp_frame_paths) // roop.globals.execution_threads
|
||||
remaining_frames = len(temp_frame_paths) % roop.globals.execution_threads
|
||||
start_index = 0
|
||||
# create threads by frames
|
||||
for _ in range(roop.globals.execution_threads):
|
||||
end_index = start_index + frames_per_thread
|
||||
if remaining_frames > 0:
|
||||
end_index += 1
|
||||
remaining_frames -= 1
|
||||
thread_paths = temp_frame_paths[start_index:end_index]
|
||||
thread = threading.Thread(target=process_frames, args=(source_path, thread_paths, progress))
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
start_index = end_index
|
||||
# join threads
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
|
||||
def process_image(source_path: str, target_path: str, output_path: str) -> None:
|
||||
source_face = get_one_face(cv2.imread(source_path))
|
||||
target_frame = cv2.imread(target_path)
|
||||
result = process_faces(source_face, target_frame)
|
||||
cv2.imwrite(output_path, result)
|
||||
|
||||
|
||||
def process_video(source_path: str, temp_frame_paths: List[str], mode: str) -> None:
|
||||
progress_bar_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]'
|
||||
total = len(temp_frame_paths)
|
||||
with tqdm(total=total, desc='Processing', unit='frame', dynamic_ncols=True, bar_format=progress_bar_format) as progress:
|
||||
if mode == 'multi-processing':
|
||||
progress.set_postfix({'mode': mode, 'cores': roop.globals.cpu_cores, 'memory': roop.globals.max_memory})
|
||||
process_frames(source_path, temp_frame_paths, progress)
|
||||
elif mode == 'multi-threading':
|
||||
progress.set_postfix({'mode': mode, 'threads': roop.globals.execution_threads, 'memory': roop.globals.max_memory})
|
||||
multi_process_frame(source_path, temp_frame_paths, progress)
|
||||
from typing import Any, List
|
||||
from tqdm import tqdm
|
||||
import cv2
|
||||
import insightface
|
||||
import threading
|
||||
|
||||
import roop.globals
|
||||
from roop.face_analyser import get_one_face, get_many_faces
|
||||
from roop.utilities import conditional_download, resolve_relative_path
|
||||
|
||||
FACE_SWAPPER = None
|
||||
THREAD_LOCK = threading.Lock()
|
||||
|
||||
|
||||
def pre_check() -> None:
|
||||
download_directory_path = resolve_relative_path('../../models')
|
||||
conditional_download(download_directory_path, ['https://huggingface.co/deepinsight/inswapper/resolve/main/inswapper_128.onnx'])
|
||||
|
||||
|
||||
def get_face_swapper() -> None:
|
||||
global FACE_SWAPPER
|
||||
|
||||
with THREAD_LOCK:
|
||||
if FACE_SWAPPER is None:
|
||||
model_path = resolve_relative_path('../../models/inswapper_128.onnx')
|
||||
FACE_SWAPPER = insightface.model_zoo.get_model(model_path, providers=roop.globals.execution_providers)
|
||||
return FACE_SWAPPER
|
||||
|
||||
|
||||
def swap_face(source_face: Any, target_face: Any, temp_frame: Any) -> Any:
|
||||
if target_face:
|
||||
return get_face_swapper().get(temp_frame, target_face, source_face, paste_back=True)
|
||||
return temp_frame
|
||||
|
||||
|
||||
def process_faces(source_face: Any, temp_frame: Any) -> Any:
|
||||
if roop.globals.many_faces:
|
||||
many_faces = get_many_faces(temp_frame)
|
||||
if many_faces:
|
||||
for target_face in many_faces:
|
||||
temp_frame = swap_face(source_face, target_face, temp_frame)
|
||||
else:
|
||||
target_face = get_one_face(temp_frame)
|
||||
if target_face:
|
||||
temp_frame = swap_face(source_face, target_face, temp_frame)
|
||||
return temp_frame
|
||||
|
||||
|
||||
def process_frames(source_path: str, temp_frame_paths: List[str], progress=None) -> None:
|
||||
source_face = get_one_face(cv2.imread(source_path))
|
||||
for temp_frame_path in temp_frame_paths:
|
||||
temp_frame = cv2.imread(temp_frame_path)
|
||||
try:
|
||||
result = process_faces(source_face, temp_frame)
|
||||
cv2.imwrite(temp_frame_path, result)
|
||||
except Exception as exception:
|
||||
print(exception)
|
||||
pass
|
||||
if progress:
|
||||
progress.update(1)
|
||||
|
||||
|
||||
def multi_process_frame(source_path: str, temp_frame_paths: List[str], progress) -> None:
|
||||
threads = []
|
||||
frames_per_thread = len(temp_frame_paths) // roop.globals.execution_threads
|
||||
remaining_frames = len(temp_frame_paths) % roop.globals.execution_threads
|
||||
start_index = 0
|
||||
# create threads by frames
|
||||
for _ in range(roop.globals.execution_threads):
|
||||
end_index = start_index + frames_per_thread
|
||||
if remaining_frames > 0:
|
||||
end_index += 1
|
||||
remaining_frames -= 1
|
||||
thread_paths = temp_frame_paths[start_index:end_index]
|
||||
thread = threading.Thread(target=process_frames, args=(source_path, thread_paths, progress))
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
start_index = end_index
|
||||
# join threads
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
|
||||
def process_image(source_path: str, target_path: str, output_path: str) -> None:
|
||||
source_face = get_one_face(cv2.imread(source_path))
|
||||
target_frame = cv2.imread(target_path)
|
||||
result = process_faces(source_face, target_frame)
|
||||
cv2.imwrite(output_path, result)
|
||||
|
||||
|
||||
def process_video(source_path: str, temp_frame_paths: List[str], mode: str) -> None:
|
||||
progress_bar_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]'
|
||||
total = len(temp_frame_paths)
|
||||
with tqdm(total=total, desc='Processing', unit='frame', dynamic_ncols=True, bar_format=progress_bar_format) as progress:
|
||||
if mode == 'multi-processing':
|
||||
progress.set_postfix({'mode': mode, 'cores': roop.globals.cpu_cores, 'memory': roop.globals.max_memory})
|
||||
process_frames(source_path, temp_frame_paths, progress)
|
||||
elif mode == 'multi-threading':
|
||||
progress.set_postfix({'mode': mode, 'threads': roop.globals.execution_threads, 'memory': roop.globals.max_memory})
|
||||
multi_process_frame(source_path, temp_frame_paths, progress)
|
||||
15
roop/ui.py
15
roop/ui.py
@ -6,9 +6,9 @@ import cv2
|
||||
from PIL import Image, ImageTk, ImageOps
|
||||
|
||||
import roop.globals
|
||||
from roop.analyser import get_one_face
|
||||
from roop.face_analyser import get_one_face
|
||||
from roop.capturer import get_video_frame, get_video_frame_total
|
||||
from roop.swapper import process_faces
|
||||
from roop.frame_processors.core import get_frame_processor_modules
|
||||
from roop.utilities import is_image, is_video, resolve_relative_path
|
||||
|
||||
WINDOW_HEIGHT = 700
|
||||
@ -204,10 +204,13 @@ def init_preview() -> None:
|
||||
|
||||
def update_preview(frame_number: int = 0) -> None:
|
||||
if roop.globals.source_path and roop.globals.target_path:
|
||||
video_frame = process_faces(
|
||||
get_one_face(cv2.imread(roop.globals.source_path)),
|
||||
get_video_frame(roop.globals.target_path, frame_number)
|
||||
)
|
||||
for frame_processor in roop.globals.frame_processors:
|
||||
module = get_frame_processor_modules(frame_processor)
|
||||
module.process_image(roop.globals.source_path, roop.globals.target_path, roop.globals.output_path)
|
||||
video_frame = module.process_faces(
|
||||
get_one_face(cv2.imread(roop.globals.source_path)),
|
||||
get_video_frame(roop.globals.target_path, frame_number)
|
||||
)
|
||||
image = Image.fromarray(cv2.cvtColor(video_frame, cv2.COLOR_BGR2RGB))
|
||||
image = ImageOps.contain(image, (PREVIEW_MAX_WIDTH, PREVIEW_MAX_HEIGHT), Image.LANCZOS)
|
||||
image = ImageTk.PhotoImage(image)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user