mirror of
https://github.com/s0md3v/roop.git
synced 2025-12-06 18:08:29 +00:00
Use face enhancer device according to execution provider
This commit is contained in:
parent
b104741e68
commit
b710cc8258
@ -23,10 +23,18 @@ def get_face_enhancer() -> Any:
|
||||
if FACE_ENHANCER is None:
|
||||
model_path = resolve_relative_path('../models/GFPGANv1.4.pth')
|
||||
# todo: set models path https://github.com/TencentARC/GFPGAN/issues/399
|
||||
FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1) # type: ignore[attr-defined]
|
||||
FACE_ENHANCER = gfpgan.GFPGANer(model_path=model_path, upscale=1, device=get_device()) # type: ignore[attr-defined]
|
||||
return FACE_ENHANCER
|
||||
|
||||
|
||||
def get_device() -> str:
|
||||
if 'CUDAExecutionProvider' in roop.globals.execution_providers:
|
||||
return 'cuda'
|
||||
if 'CoreMLExecutionProvider' in roop.globals.execution_providers:
|
||||
return 'mps'
|
||||
return 'cpu'
|
||||
|
||||
|
||||
def pre_check() -> bool:
|
||||
download_directory_path = resolve_relative_path('../models')
|
||||
conditional_download(download_directory_path, ['https://huggingface.co/henryruhs/roop/resolve/main/GFPGANv1.4.pth'])
|
||||
|
||||
Loading…
Reference in New Issue
Block a user