Use face enhancer device according to execution provider

This commit is contained in:
henryruhs 2023-07-01 23:28:38 +02:00
parent b104741e68
commit b710cc8258

View File

@ -23,10 +23,18 @@ def get_face_enhancer() -> Any:
if FACE_ENHANCER is None:
model_path = resolve_relative_path('../models/GFPGANv1.4.pth')
# todo: set models path https://github.com/TencentARC/GFPGAN/issues/399
FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1) # type: ignore[attr-defined]
FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1, device=get_device()) # type: ignore[attr-defined]
return FACE_ENHANCER
def get_device() -> str:
if 'CUDAExecutionProvider' in roop.globals.execution_providers:
return 'cuda'
if 'CoreMLExecutionProvider' in roop.globals.execution_providers:
return 'mps'
return 'cpu'
def pre_check() -> bool:
download_directory_path = resolve_relative_path('../models')
conditional_download(download_directory_path, ['https://huggingface.co/henryruhs/roop/resolve/main/GFPGANv1.4.pth'])