added optional target face

This commit is contained in:
HuJohner 2023-07-07 11:53:04 +01:00
parent 3d02b26766
commit 3fec88237a
5 changed files with 60 additions and 7 deletions

View File

@ -16,6 +16,7 @@ import argparse
import torch
import onnxruntime
import tensorflow
import cv2
import roop.globals
import roop.metadata
@ -23,6 +24,7 @@ import roop.ui as ui
from roop.predicter import predict_image, predict_video
from roop.processors.frame.core import get_frame_processors_modules
from roop.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp, normalize_output_path
from roop.face_analyser import get_one_face
if 'ROCMExecutionProvider' in roop.globals.execution_providers:
del torch
@ -136,10 +138,24 @@ def update_status(message: str, scope: str = 'ROOP.CORE') -> None:
ui.update_status(message)
def process_target_face() -> None:
    """Resolve the optional target face and store it in ``roop.globals.target_face``.

    The target face is cleared when no target face image has been selected
    (``roop.globals.target_face_path`` is ``None``) or when
    ``roop.globals.many_faces`` is enabled. Otherwise the image is loaded and
    the face returned by ``get_one_face`` is stored.

    If the image cannot be read, or no face is found in it, the target face
    is left as ``None`` and a status message announces the fallback.
    """
    if roop.globals.target_face_path is None or roop.globals.many_faces:
        roop.globals.target_face = None
        return
    update_status('Target face selected. Processing...')
    target_image = cv2.imread(roop.globals.target_face_path)
    # cv2.imread reports a missing or undecodable file by returning None
    # rather than raising, so guard before handing the frame to the analyser.
    found_face = get_one_face(target_image) if target_image is not None else None
    roop.globals.target_face = found_face
    if found_face is None:
        update_status('Failed to find face. Falling back to no target face')
    else:
        update_status('Successfully found target face')
def start() -> None:
for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
if not frame_processor.pre_start():
return
# process target face
process_target_face()
# process image to image
if has_image_extension(roop.globals.target_path):
if predict_image(roop.globals.target_path):

View File

@ -1,6 +1,7 @@
import threading
from typing import Any
import insightface
import numpy
import roop.globals
from roop.typing import Frame
@ -25,6 +26,15 @@ def get_one_face(frame: Frame) -> Any:
return min(face, key=lambda x: x.bbox[0])
except ValueError:
return None
def get_target_face(frame: Frame) -> Any:
    """Return the face in *frame* that best matches the selected target face.

    Each face reported by the face analyser is scored by the cosine
    similarity between its embedding and the embedding of
    ``roop.globals.target_face``; the highest-scoring face is returned.

    Returns ``None`` when ``max`` raises ``ValueError``, which is what happens
    when the analyser reports no faces. No minimum similarity is enforced, so
    the closest face is returned even if the match is poor.

    The caller must ensure ``roop.globals.target_face`` is not ``None``.
    """
    faces = get_face_analyser().get(frame)
    try:
        target_embedding = roop.globals.target_face.embedding
        # The target's norm is identical for every candidate, so compute it
        # once instead of once per detected face inside the key function.
        target_norm = numpy.linalg.norm(target_embedding)
        return max(
            faces,
            key=lambda candidate: numpy.dot(target_embedding, candidate.embedding)
            / (target_norm * numpy.linalg.norm(candidate.embedding)),
        )
    except ValueError:
        return None
def get_many_faces(frame: Frame) -> Any:

View File

@ -1,6 +1,8 @@
from typing import List
# Path of the source image. NOTE(review): presumably the face to swap in; its use is not visible here — confirm.
source_path = None
# Path of the optional image containing the face to look for in the target; None when none is selected.
target_face_path = None
# Face object found in the image at target_face_path, or None when no target face is in effect.
target_face = None
# Path of the target media (image or video) to be processed.
target_path = None
# Path for the result. NOTE(review): its use is not visible here — confirm against callers.
output_path = None
# Identifiers of the frame processors to load and run.
frame_processors: List[str] = []

View File

@ -6,7 +6,7 @@ import threading
import roop.globals
import roop.processors.frame.core
from roop.core import update_status
from roop.face_analyser import get_one_face, get_many_faces
from roop.face_analyser import get_one_face, get_target_face, get_many_faces
from roop.typing import Face, Frame
from roop.utilities import conditional_download, resolve_relative_path, is_image, is_video
@ -61,7 +61,10 @@ def process_frame(source_face: Face, temp_frame: Frame) -> Frame:
for target_face in many_faces:
temp_frame = swap_face(source_face, target_face, temp_frame)
else:
target_face = get_one_face(temp_frame)
if roop.globals.target_face is None:
target_face = get_one_face(temp_frame)
else:
target_face = get_target_face(temp_frame)
if target_face:
temp_frame = swap_face(source_face, target_face, temp_frame)
return temp_frame

View File

@ -28,6 +28,7 @@ RECENT_DIRECTORY_OUTPUT = None
preview_label = None
preview_slider = None
source_label = None
target_face_label = None
target_label = None
status_label = None
@ -42,7 +43,7 @@ def init(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.CTk:
def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.CTk:
global source_label, target_label, status_label
global source_label, target_face_label, target_label, status_label
ctk.deactivate_automatic_dpi_awareness()
ctk.set_appearance_mode('system')
@ -55,16 +56,22 @@ def create_root(start: Callable[[], None], destroy: Callable[[], None]) -> ctk.C
root.protocol('WM_DELETE_WINDOW', lambda: destroy())
source_label = ctk.CTkLabel(root, text=None)
source_label.place(relx=0.1, rely=0.1, relwidth=0.3, relheight=0.25)
source_label.place(relx=0.1, rely=0.1, relwidth=0.23, relheight=0.25)
target_face_label = ctk.CTkLabel(root, text=None)
target_face_label.place(relx=0.38, rely=0.1, relwidth=0.23, relheight=0.25)
target_label = ctk.CTkLabel(root, text=None)
target_label.place(relx=0.6, rely=0.1, relwidth=0.3, relheight=0.25)
target_label.place(relx=0.66, rely=0.1, relwidth=0.23, relheight=0.25)
source_button = ctk.CTkButton(root, text='Select a face', cursor='hand2', command=lambda: select_source_path())
source_button.place(relx=0.1, rely=0.4, relwidth=0.3, relheight=0.1)
source_button.place(relx=0.1, rely=0.4, relwidth=0.23, relheight=0.1)
target_face_button = ctk.CTkButton(root, text='Select a target face', cursor='hand2', command=lambda: select_target_face_path())
target_face_button.place(relx=0.38, rely=0.4, relwidth=0.23, relheight=0.1)
target_button = ctk.CTkButton(root, text='Select a target', cursor='hand2', command=lambda: select_target_path())
target_button.place(relx=0.6, rely=0.4, relwidth=0.3, relheight=0.1)
target_button.place(relx=0.66, rely=0.4, relwidth=0.23, relheight=0.1)
keep_fps_value = ctk.BooleanVar(value=roop.globals.keep_fps)
keep_fps_checkbox = ctk.CTkSwitch(root, text='Keep fps', variable=keep_fps_value, cursor='hand2', command=lambda: setattr(roop.globals, 'keep_fps', not roop.globals.keep_fps))
@ -140,6 +147,21 @@ def select_source_path() -> None:
source_label.configure(image=None)
def select_target_face_path() -> None:
    """Ask the user for a target face image and refresh the stored path and preview.

    The preview window is withdrawn, then a file dialog opens in the most
    recently used source directory. If the chosen file passes ``is_image``,
    its path is stored in ``roop.globals.target_face_path``, the recent
    directory is updated and a 200x200 preview is shown on the target face
    label. Otherwise the stored path and the label image are both cleared.
    """
    global RECENT_DIRECTORY_SOURCE

    PREVIEW.withdraw()
    chosen_path = ctk.filedialog.askopenfilename(title='select a target image', initialdir=RECENT_DIRECTORY_SOURCE)

    # Anything that is not an image (including a cancelled dialog) resets the selection.
    if not is_image(chosen_path):
        roop.globals.target_face_path = None
        target_face_label.configure(image=None)
        return

    roop.globals.target_face_path = chosen_path
    RECENT_DIRECTORY_SOURCE = os.path.dirname(chosen_path)
    preview_image = render_image_preview(chosen_path, (200, 200))
    target_face_label.configure(image=preview_image)
def select_target_path() -> None:
global RECENT_DIRECTORY_TARGET