Encode and decode execution providers for easier usage

This commit is contained in:
henryruhs 2023-06-13 12:09:10 +02:00
parent 7e8080732c
commit 319738f8d7

View File

@ -49,7 +49,7 @@ def parse_args() -> None:
parser.add_argument('--video-quality', help='adjust output video quality', dest='video_quality', type=int, default=18)
parser.add_argument('--max-memory', help='maximum amount of RAM in GB to be used', dest='max_memory', type=int, default=suggest_max_memory())
parser.add_argument('--cpu-cores', help='number of CPU cores to use', dest='cpu_cores', type=int, default=suggest_cpu_cores())
parser.add_argument('--execution-provider', help='execution provider', dest='execution_provider', default=['CPUExecutionProvider'], choices=onnxruntime.get_available_providers(), nargs='+')
parser.add_argument('--execution-provider', help='execution provider', dest='execution_provider', default=['cpu'], choices=suggest_execution_providers(), nargs='+')
parser.add_argument('--execution-threads', help='number of threads to be use for the GPU', dest='execution_threads', type=int, default=suggest_execution_threads())
args = parser.parse_known_args()[0]
@ -67,10 +67,19 @@ def parse_args() -> None:
roop.globals.video_quality = args.video_quality
roop.globals.max_memory = args.max_memory
roop.globals.cpu_cores = args.cpu_cores
roop.globals.execution_providers = args.execution_provider
roop.globals.execution_providers = decode_execution_providers(args.execution_provider)
roop.globals.execution_threads = args.execution_threads
def encode_execution_providers(execution_providers: List[str]) -> List[str]:
return [execution_provider.replace('ExecutionProvider', '').lower() for execution_provider in execution_providers]
def decode_execution_providers(execution_providers: List[str]) -> List[str]:
return [provider for provider, encoded_execution_provider in zip(onnxruntime.get_available_providers(), encode_execution_providers(onnxruntime.get_available_providers()))
if any(execution_provider in encoded_execution_provider for execution_provider in execution_providers)]
def suggest_max_memory() -> int:
if platform.system().lower() == 'darwin':
return 4
@ -83,6 +92,10 @@ def suggest_cpu_cores() -> int:
return int(max(psutil.cpu_count() / 2, 1))
def suggest_execution_providers() -> List[str]:
return encode_execution_providers(onnxruntime.get_available_providers())
def suggest_execution_threads() -> int:
if 'DmlExecutionProvider' in roop.globals.execution_providers:
return 1