diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ea90310..824f986 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -26,7 +26,6 @@ jobs: with: python-version: 3.9 - run: pip install -r requirements.txt gdown - - run: gdown 13QpWFWJ37EB-nHrEOY64CEtQWY-tz7DZ - run: python run.py -f=.github/examples/face.jpg -t=.github/examples/target.mp4 -o=.github/examples/output.mp4 - run: ffmpeg -i .github/examples/snapshot.mp4 -i .github/examples/output.mp4 -filter_complex psnr -f null - diff --git a/.gitignore b/.gitignore index 3e5e7f2..e25e7ce 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ .idea +models temp __pycache__ diff --git a/roop/core.py b/roop/core.py index c204d2b..a43b061 100755 --- a/roop/core.py +++ b/roop/core.py @@ -126,9 +126,6 @@ def pre_check() -> None: quit('Python version is not supported - please upgrade to 3.9 or higher.') if not shutil.which('ffmpeg'): quit('ffmpeg is not installed!') - model_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), '../inswapper_128.onnx') - if not os.path.isfile(model_path): - quit('File "inswapper_128.onnx" does not exist!') if roop.globals.gpu_vendor == 'apple': if 'CoreMLExecutionProvider' not in roop.globals.providers: quit('You are using --gpu=apple flag but CoreML is not available or properly installed on your system.') @@ -248,6 +245,8 @@ def destroy() -> None: def run() -> None: parse_args() pre_check() + if 'face-swapper' in roop.globals.frame_processors: + roop.swapper.pre_check() limit_resources() if roop.globals.headless: start() diff --git a/roop/swapper.py b/roop/swapper.py index 05943fe..374a65c 100644 --- a/roop/swapper.py +++ b/roop/swapper.py @@ -7,16 +7,22 @@ import insightface import threading import roop.globals from roop.analyser import get_one_face, get_many_faces +from roop.utilities import conditional_download FACE_SWAPPER = None THREAD_LOCK = threading.Lock() +def pre_check() -> None: + download_directory_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), '../models') + conditional_download(download_directory_path, ['https://huggingface.co/deepinsight/inswapper/resolve/main/inswapper_128.onnx']) + + def get_face_swapper() -> None: global FACE_SWAPPER with THREAD_LOCK: if FACE_SWAPPER is None: - model_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), '../inswapper_128.onnx') + model_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), '../models/inswapper_128.onnx') FACE_SWAPPER = insightface.model_zoo.get_model(model_path, providers=roop.globals.providers) return FACE_SWAPPER @@ -84,7 +90,8 @@ def process_image(source_path: str, target_path: str, output_path: str) -> None: def process_video(source_path: str, temp_frame_paths: List[str], mode: str) -> None: progress_bar_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]' - with tqdm(total=len(temp_frame_paths), desc='Processing', unit='frame', dynamic_ncols=True, bar_format=progress_bar_format) as progress: + total = len(temp_frame_paths) + with tqdm(total=total, desc='Processing', unit='frame', dynamic_ncols=True, bar_format=progress_bar_format) as progress: if mode == 'cpu': progress.set_postfix({'mode': mode, 'cores': roop.globals.cpu_cores, 'memory': roop.globals.max_memory}) process_frames(source_path, temp_frame_paths, progress) diff --git a/roop/utilities.py b/roop/utilities.py index 2fcf0fe..7ff612a 100644 --- a/roop/utilities.py +++ b/roop/utilities.py @@ -2,10 +2,12 @@ import glob import os import shutil import subprocess +import urllib from pathlib import Path from typing import List import cv2 from PIL import Image +from tqdm import tqdm import roop.globals @@ -112,3 +114,15 @@ def is_video(video_path: str) -> bool: except Exception: pass return False + + +def conditional_download(download_directory_path: str, urls: List[str]): + if not os.path.exists(download_directory_path): + os.makedirs(download_directory_path) + for url in urls: + download_file_path = os.path.join(download_directory_path, os.path.basename(url)) + if not os.path.exists(download_file_path): + request = urllib.request.urlopen(url) + total = int(request.headers.get('Content-Length', 0)) + with tqdm(total=total, desc='Downloading', unit='B', unit_scale=True, unit_divisor=1024) as progress: + urllib.request.urlretrieve(url, download_file_path, reporthook=lambda count, block_size, total_size: progress.update(block_size))