diff --git a/roop/core.py b/roop/core.py index e6c740e..36f292d 100755 --- a/roop/core.py +++ b/roop/core.py @@ -18,7 +18,7 @@ import tensorflow import roop.globals import roop.metadata from roop.predictor import predict_image, predict_video -from roop.processors.frame.core import get_frame_processors_modules +from roop.processors.frame.core import get_frame_processors_modules, list_frame_processors_names from roop.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp, normalize_output_path warnings.filterwarnings('ignore', category=FutureWarning, module='insightface') @@ -31,7 +31,7 @@ def parse_args() -> None: program.add_argument('-s', '--source', help='select an source image', dest='source_path') program.add_argument('-t', '--target', help='select an target image or video', dest='target_path') program.add_argument('-o', '--output', help='select output file or directory', dest='output_path') - program.add_argument('--frame-processor', help='frame processors (choices: face_swapper, face_enhancer, ...)', dest='frame_processor', default=['face_swapper'], nargs='+') + program.add_argument('--frame-processor', help='list of available frame processors', dest='frame_processor', default=['face_swapper'], choices=list_frame_processors_names(), nargs='+') program.add_argument('--keep-fps', help='keep target fps', dest='keep_fps', action='store_true') program.add_argument('--keep-frames', help='keep temporary frames', dest='keep_frames', action='store_true') program.add_argument('--skip-audio', help='skip target audio', dest='skip_audio', action='store_true') diff --git a/roop/processors/frame/__modules__/__init__.py b/roop/processors/frame/__modules__/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/roop/processors/frame/face_enhancer.py b/roop/processors/frame/__modules__/face_enhancer.py similarity index 100% rename from roop/processors/frame/face_enhancer.py rename to roop/processors/frame/__modules__/face_enhancer.py diff --git a/roop/processors/frame/face_swapper.py b/roop/processors/frame/__modules__/face_swapper.py similarity index 100% rename from roop/processors/frame/face_swapper.py rename to roop/processors/frame/__modules__/face_swapper.py diff --git a/roop/processors/frame/frame_enhancer.py b/roop/processors/frame/__modules__/frame_enhancer.py similarity index 100% rename from roop/processors/frame/frame_enhancer.py rename to roop/processors/frame/__modules__/frame_enhancer.py diff --git a/roop/processors/frame/core.py b/roop/processors/frame/core.py index 68a9e94..0cf2a3f 100644 --- a/roop/processors/frame/core.py +++ b/roop/processors/frame/core.py @@ -5,13 +5,14 @@ import psutil from concurrent.futures import ThreadPoolExecutor, as_completed from queue import Queue from types import ModuleType -from typing import Any, List, Callable +from typing import Any, List, Callable, Optional from tqdm import tqdm import roop +from roop.utilities import list_module_names FRAME_PROCESSORS_MODULES: List[ModuleType] = [] -FRAME_PROCESSORS_INTERFACE = [ +FRAME_PROCESSORS_METHODS = [ 'pre_check', 'pre_start', 'process_frame', @@ -24,8 +25,8 @@ FRAME_PROCESSORS_INTERFACE = [ def load_frame_processor_module(frame_processor: str) -> Any: try: - frame_processor_module = importlib.import_module(f'roop.processors.frame.{frame_processor}') - for method_name in FRAME_PROCESSORS_INTERFACE: + frame_processor_module = importlib.import_module(f'roop.processors.frame.__modules__.{frame_processor}') + for method_name in FRAME_PROCESSORS_METHODS: if not hasattr(frame_processor_module, method_name): raise NotImplementedError except ModuleNotFoundError: @@ -51,6 +52,10 @@ def clear_frame_processors_modules() -> None: FRAME_PROCESSORS_MODULES = [] +def list_frame_processors_names() -> Optional[List[str]]: + return list_module_names('roop/processors/frame/__modules__') + + def multi_process_frame(source_path: str, temp_frame_paths: List[str], process_frames: Callable[[str, List[str], Any], None], update: Callable[[], None]) -> None: with ThreadPoolExecutor(max_workers=roop.globals.execution_threads) as executor: futures = [] diff --git a/roop/uis/preview.py b/roop/uis/preview.py index 6029848..f42113d 100644 --- a/roop/uis/preview.py +++ b/roop/uis/preview.py @@ -9,7 +9,7 @@ from roop.core import destroy from roop.face_analyser import get_one_face from roop.face_reference import get_face_reference, set_face_reference from roop.predictor import predict_frame -from roop.processors.frame.core import get_frame_processors_modules, load_frame_processor_module +from roop.processors.frame.core import load_frame_processor_module from roop.typing import Frame from roop.uis import core as ui from roop.utilities import is_video, is_image @@ -78,7 +78,8 @@ def get_preview_frame(temp_frame: Frame) -> Frame: reference_face = get_one_face(reference_frame, roop.globals.reference_face_position) set_face_reference(reference_face) reference_face = get_face_reference() if not roop.globals.many_faces else None - for frame_processor_module in get_frame_processors_modules(roop.globals.frame_processors): + for frame_processor in roop.globals.frame_processors: + frame_processor_module = load_frame_processor_module(frame_processor) if frame_processor_module.pre_start(): temp_frame = frame_processor_module.process_frame( source_face, diff --git a/roop/uis/settings.py b/roop/uis/settings.py index 3e292cc..e4b6b53 100644 --- a/roop/uis/settings.py +++ b/roop/uis/settings.py @@ -3,7 +3,7 @@ import gradio import onnxruntime import roop.globals -from roop.processors.frame.core import clear_frame_processors_modules +from roop.processors.frame.core import list_frame_processors_names, clear_frame_processors_modules from roop.uis import core as ui NAME = 'ROOP.UIS.OUTPUT' @@ -14,7 +14,7 @@ def render() -> None: with gradio.Box(): frame_processors_checkbox_group = gradio.CheckboxGroup( label='frame_processors', - choices=['face_swapper', 'face_enhancer', 'frame_enhancer'], + choices=list_frame_processors_names(), value=roop.globals.frame_processors ) ui.register_component('frame_processors_checkbox_group', frame_processors_checkbox_group) diff --git a/roop/utilities.py b/roop/utilities.py index 6bfaa5e..5424607 100644 --- a/roop/utilities.py +++ b/roop/utilities.py @@ -151,3 +151,10 @@ def conditional_download(download_directory_path: str, urls: List[str]) -> None: def resolve_relative_path(path: str) -> str: return os.path.abspath(os.path.join(os.path.dirname(__file__), path)) + + +def list_module_names(path: str) -> Optional[List[str]]: + if os.path.exists(path): + files = os.listdir(path) + return [Path(file).stem for file in files if not Path(file).stem.startswith('__')] + return None