From 7abcc8b89b5e1ee82b64d5fba0ff363c41a95e8b Mon Sep 17 00:00:00 2001 From: Jagrut <202702863+jagrut-thakare@users.noreply.github.com> Date: Mon, 14 Jul 2025 13:05:34 +0530 Subject: [PATCH] v3 - Multifaceswap Support --- .gitignore | 1 + README.md | 8 +++++ roop/core.py | 4 +-- roop/globals.py | 2 +- roop/processors/frame/face_swapper.py | 50 +++++++++++++++++++-------- roop/utilities.py | 17 +++++---- 6 files changed, 59 insertions(+), 23 deletions(-) diff --git a/.gitignore b/.gitignore index e25e7ce..ce97100 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ models temp __pycache__ +*.jpg \ No newline at end of file diff --git a/README.md b/README.md index aa4f66c..918d16b 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,14 @@ Start the program with arguments: ``` python run.py [options] +``` +## For multiface swap + +``` +python run.py --many-faces --source PATH1 PATH2 --target TARGET_PATH --output ./result.jpg +``` + +``` -h, --help show this help message and exit -s SOURCE_PATH, --source SOURCE_PATH select an source image diff --git a/roop/core.py b/roop/core.py index 7e5a46f..9fe0df7 100755 --- a/roop/core.py +++ b/roop/core.py @@ -29,7 +29,7 @@ warnings.filterwarnings('ignore', category=UserWarning, module='torchvision') def parse_args() -> None: signal.signal(signal.SIGINT, lambda signal_number, frame: destroy()) program = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=100)) - program.add_argument('-s', '--source', help='select an source image', dest='source_path') + program.add_argument('-s', '--source', help='select source image(s)', dest='source_path', nargs='+') program.add_argument('-t', '--target', help='select an target image or video', dest='target_path') program.add_argument('-o', '--output', help='select output file or directory', dest='output_path') program.add_argument('--frame-processor', help='frame processors (choices: face_swapper, face_enhancer, ...)', dest='frame_processor', default=['face_swapper'], nargs='+') @@ -51,7 +51,7 @@ def parse_args() -> None: args = program.parse_args() - roop.globals.source_path = args.source_path + roop.globals.source_path = args.source_path # Now a list of paths roop.globals.target_path = args.target_path roop.globals.output_path = normalize_output_path(roop.globals.source_path, roop.globals.target_path, args.output_path) roop.globals.headless = roop.globals.source_path is not None and roop.globals.target_path is not None and roop.globals.output_path is not None diff --git a/roop/globals.py b/roop/globals.py index 3eca8d0..894bb3c 100644 --- a/roop/globals.py +++ b/roop/globals.py @@ -1,6 +1,6 @@ from typing import List, Optional -source_path: Optional[str] = None +source_path: Optional[List[str]] = None target_path: Optional[str] = None output_path: Optional[str] = None headless: Optional[bool] = None diff --git a/roop/processors/frame/face_swapper.py b/roop/processors/frame/face_swapper.py index da68956..52db31a 100644 --- a/roop/processors/frame/face_swapper.py +++ b/roop/processors/frame/face_swapper.py @@ -39,12 +39,26 @@ def pre_check() -> bool: def pre_start() -> bool: + source_paths = roop.globals.source_path + if isinstance(source_paths, list): + images = [cv2.imread(p) for p in source_paths] + # process each image as needed + else: + image = cv2.imread(source_paths) + # process single image as needed + if not is_image(roop.globals.source_path): update_status('Select an image for source path.', NAME) return False - elif not get_one_face(cv2.imread(roop.globals.source_path)): - update_status('No face in source path detected.', NAME) - return False + if isinstance(roop.globals.source_path, list): + for p in roop.globals.source_path: + if not get_one_face(cv2.imread(p)): + update_status(f'No face detected in source path: {p}', NAME) + return False + else: + if not get_one_face(cv2.imread(roop.globals.source_path)): + update_status('No face in source path detected.', NAME) + return False if not is_image(roop.globals.target_path) and not is_video(roop.globals.target_path): update_status('Select an image or video for target path.', NAME) return False @@ -60,41 +74,49 @@ def swap_face(source_face: Face, target_face: Face, temp_frame: Frame) -> Frame: return get_face_swapper().get(temp_frame, target_face, source_face, paste_back=True) -def process_frame(source_face: Face, reference_face: Face, temp_frame: Frame) -> Frame: +def process_frame(source_faces: List[Face], reference_face: Face, temp_frame: Frame) -> Frame: if roop.globals.many_faces: many_faces = get_many_faces(temp_frame) if many_faces: - for target_face in many_faces: + for i, target_face in enumerate(many_faces): + # Use corresponding source face or fallback to first + source_face = source_faces[i] if i < len(source_faces) else source_faces[0] temp_frame = swap_face(source_face, target_face, temp_frame) else: target_face = find_similar_face(temp_frame, reference_face) if target_face: - temp_frame = swap_face(source_face, target_face, temp_frame) + temp_frame = swap_face(source_faces[0], target_face, temp_frame) return temp_frame -def process_frames(source_path: str, temp_frame_paths: List[str], update: Callable[[], None]) -> None: - source_face = get_one_face(cv2.imread(source_path)) +def process_frames(source_paths: List[str], temp_frame_paths: List[str], update: Callable[[], None]) -> None: + source_faces = [get_one_face(cv2.imread(path)) for path in source_paths] reference_face = None if roop.globals.many_faces else get_face_reference() for temp_frame_path in temp_frame_paths: temp_frame = cv2.imread(temp_frame_path) - result = process_frame(source_face, reference_face, temp_frame) + if temp_frame is None: + update_status(f'Could not load frame: {temp_frame_path}', NAME) + continue + result = process_frame(source_faces, reference_face, temp_frame) cv2.imwrite(temp_frame_path, result) if update: update() -def process_image(source_path: str, target_path: str, output_path: str) -> None: - source_face = get_one_face(cv2.imread(source_path)) +def process_image(source_paths: List[str], target_path: str, output_path: str) -> None: + source_faces = [get_one_face(cv2.imread(path)) for path in source_paths] target_frame = cv2.imread(target_path) + if target_frame is None: + update_status(f'Could not load target image: {target_path}', NAME) + return reference_face = None if roop.globals.many_faces else get_one_face(target_frame, roop.globals.reference_face_position) - result = process_frame(source_face, reference_face, target_frame) + result = process_frame(source_faces, reference_face, target_frame) cv2.imwrite(output_path, result) -def process_video(source_path: str, temp_frame_paths: List[str]) -> None: +def process_video(source_paths: List[str], temp_frame_paths: List[str]) -> None: if not roop.globals.many_faces and not get_face_reference(): reference_frame = cv2.imread(temp_frame_paths[roop.globals.reference_frame_number]) reference_face = get_one_face(reference_frame, roop.globals.reference_face_position) set_face_reference(reference_face) - roop.processors.frame.core.process_video(source_path, temp_frame_paths, process_frames) + roop.processors.frame.core.process_video(source_paths, temp_frame_paths, process_frames) diff --git a/roop/utilities.py b/roop/utilities.py index 31ba7f2..33fc203 100644 --- a/roop/utilities.py +++ b/roop/utilities.py @@ -84,12 +84,15 @@ def get_temp_output_path(target_path: str) -> str: return os.path.join(temp_directory_path, TEMP_VIDEO_FILE) -def normalize_output_path(source_path: str, target_path: str, output_path: str) -> Optional[str]: - if source_path and target_path and output_path: +def normalize_output_path(source_path, target_path, output_path): + # If source_path is a list, use the first item for naming + if isinstance(source_path, list): + source_name, _ = os.path.splitext(os.path.basename(source_path[0])) + else: source_name, _ = os.path.splitext(os.path.basename(source_path)) - target_name, target_extension = os.path.splitext(os.path.basename(target_path)) - if os.path.isdir(output_path): - return os.path.join(output_path, source_name + '-' + target_name + target_extension) + target_name, _ = os.path.splitext(os.path.basename(target_path)) + if output_path is None: + return f'{source_name}_to_{target_name}.jpg' return output_path @@ -119,7 +122,9 @@ def has_image_extension(image_path: str) -> bool: return image_path.lower().endswith(('png', 'jpg', 'jpeg', 'webp')) -def is_image(image_path: str) -> bool: +def is_image(image_path): + if isinstance(image_path, list): + return all(is_image(p) for p in image_path) if image_path and os.path.isfile(image_path): mimetype, _ = mimetypes.guess_type(image_path) return bool(mimetype and mimetype.startswith('image/'))