mirror of
https://github.com/s0md3v/roop.git
synced 2025-12-06 18:08:29 +00:00
Encode and decode execution providers for easier usage
This commit is contained in:
parent
7e8080732c
commit
319738f8d7
17
roop/core.py
17
roop/core.py
@ -49,7 +49,7 @@ def parse_args() -> None:
|
||||
parser.add_argument('--video-quality', help='adjust output video quality', dest='video_quality', type=int, default=18)
|
||||
parser.add_argument('--max-memory', help='maximum amount of RAM in GB to be used', dest='max_memory', type=int, default=suggest_max_memory())
|
||||
parser.add_argument('--cpu-cores', help='number of CPU cores to use', dest='cpu_cores', type=int, default=suggest_cpu_cores())
|
||||
parser.add_argument('--execution-provider', help='execution provider', dest='execution_provider', default=['CPUExecutionProvider'], choices=onnxruntime.get_available_providers(), nargs='+')
|
||||
parser.add_argument('--execution-provider', help='execution provider', dest='execution_provider', default=['cpu'], choices=suggest_execution_providers(), nargs='+')
|
||||
parser.add_argument('--execution-threads', help='number of threads to be use for the GPU', dest='execution_threads', type=int, default=suggest_execution_threads())
|
||||
|
||||
args = parser.parse_known_args()[0]
|
||||
@ -67,10 +67,19 @@ def parse_args() -> None:
|
||||
roop.globals.video_quality = args.video_quality
|
||||
roop.globals.max_memory = args.max_memory
|
||||
roop.globals.cpu_cores = args.cpu_cores
|
||||
roop.globals.execution_providers = args.execution_provider
|
||||
roop.globals.execution_providers = decode_execution_providers(args.execution_provider)
|
||||
roop.globals.execution_threads = args.execution_threads
|
||||
|
||||
|
||||
def encode_execution_providers(execution_providers: List[str]) -> List[str]:
|
||||
return [execution_provider.replace('ExecutionProvider', '').lower() for execution_provider in execution_providers]
|
||||
|
||||
|
||||
def decode_execution_providers(execution_providers: List[str]) -> List[str]:
|
||||
return [provider for provider, encoded_execution_provider in zip(onnxruntime.get_available_providers(), encode_execution_providers(onnxruntime.get_available_providers()))
|
||||
if any(execution_provider in encoded_execution_provider for execution_provider in execution_providers)]
|
||||
|
||||
|
||||
def suggest_max_memory() -> int:
|
||||
if platform.system().lower() == 'darwin':
|
||||
return 4
|
||||
@ -83,6 +92,10 @@ def suggest_cpu_cores() -> int:
|
||||
return int(max(psutil.cpu_count() / 2, 1))
|
||||
|
||||
|
||||
def suggest_execution_providers() -> List[str]:
|
||||
return encode_execution_providers(onnxruntime.get_available_providers())
|
||||
|
||||
|
||||
def suggest_execution_threads() -> int:
|
||||
if 'DmlExecutionProvider' in roop.globals.execution_providers:
|
||||
return 1
|
||||
|
||||
Loading…
Reference in New Issue
Block a user