diff --git a/roop/core.py b/roop/core.py index d18be4a..61169fa 100755 --- a/roop/core.py +++ b/roop/core.py @@ -182,6 +182,10 @@ def start() -> None: if predict_image(roop.globals.target_path) > 0.85: destroy() roop.swapper.process_image(roop.globals.source_path, roop.globals.target_path, roop.globals.output_path) + if roop.globals.gpu_vendor == 'nvidia' and 'face-enhancer' in roop.globals.frame_processor: + roop.enhancer.process_image(roop.globals.source_path, roop.globals.target_path, roop.globals.output_path) + elif 'face-enhancer' in roop.globals.frame_processor: + print('face-enhancer is only supported on CUDA') if is_image(roop.globals.target_path): update_status('Swapping to image succeed!') else: @@ -201,12 +205,11 @@ def start() -> None: conditional_process_video(roop.globals.source_path, temp_frame_paths, roop.swapper.process_video) if roop.globals.gpu_vendor == 'nvidia': torch.cuda.empty_cache() - if roop.globals.cpu_vendor == 'nvidia' and 'face-enhancer' in roop.globals.frame_processor: - update_status('enhancinging in progress...') - conditional_process_video(roop.globals.source_path, temp_frame_paths, roop.enhancer.process_video) - else: - if 'face-enhancer' in roop.globals.frame_processor: - print('face-enhancer only surpported on CUDA') + if roop.globals.gpu_vendor == 'nvidia' and 'face-enhancer' in roop.globals.frame_processor: + update_status('enhancinging in progress...') + conditional_process_video(roop.globals.source_path, temp_frame_paths, roop.enhancer.process_video) + elif 'face-enhancer' in roop.globals.frame_processor: + print('face-enhancer is only supported on CUDA') if roop.globals.keep_fps: update_status('Detecting fps...') fps = detect_fps(roop.globals.target_path) diff --git a/roop/enhancer.py b/roop/enhancer.py index 72fd26d..1d7fdf7 100644 --- a/roop/enhancer.py +++ b/roop/enhancer.py @@ -136,6 +136,13 @@ def restore_face(face_in_tensor): return restored_face +def process_image(source_path: str, image_path: str, output_file: str) -> None: + source_face = None + image = cv2.imread(image_path) + result = process_faces(source_face, image) + cv2.imwrite(output_file, result) + + def process_frames(source_path: str, frame_paths: list[str], progress=None) -> None: source_face = None for frame_path in frame_paths: