move the model_path to the method

This commit is contained in:
Moeblack 2023-06-12 22:02:19 +08:00 committed by henryruhs
parent eaf899043c
commit 36446a92ca
2 changed files with 9 additions and 10 deletions

View File

@ -16,23 +16,20 @@ if 'ROCMExecutionProvider' in roop.globals.providers:
CODE_FORMER = None
THREAD_LOCK = threading.Lock()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ckpt_path = os.path.join("models", "codeformer.pth")
def pre_check() -> None:
pretrain_model_url = [
'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
]
download_directory_path = resolve_relative_path('../models')
conditional_download(download_directory_path, pretrain_model_url)
download_directory_path = resolve_relative_path('../models/codeformer.pth')
conditional_download(download_directory_path, ['https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'])
def get_code_former():
global CODE_FORMER
with THREAD_LOCK:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = os.path.join("models", "codeformer.pth")
if CODE_FORMER is None:
checkpoint = torch.load(ckpt_path)["params_ema"]
model = torch.load(model_path)["params_ema"]
CODE_FORMER = ARCH_REGISTRY.get("CodeFormer")(
dim_embd=512,
codebook_size=1024,
@ -40,12 +37,13 @@ def get_code_former():
n_layers=9,
connect_list=["32", "64", "128", "256"],
).to(device)
CODE_FORMER.load_state_dict(checkpoint)
CODE_FORMER.load_state_dict(model)
CODE_FORMER.eval()
return CODE_FORMER
def get_face_enhancer(FACE_ENHANCER):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if FACE_ENHANCER is None:
FACE_ENHANCER = FaceRestoreHelper(
upscale_factor = int(2),
@ -91,6 +89,7 @@ def process_faces(source_face: any, frame: any) -> any:
def normalize_face(face):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
face_in_tensor = img2tensor(face / 255.0, bgr2rgb=True, float32=True)
normalize(face_in_tensor, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
return face_in_tensor.unsqueeze(0).to(device)

View File

@ -14,7 +14,7 @@ THREAD_LOCK = threading.Lock()
def pre_check() -> None:
download_directory_path = resolve_relative_path('../models')
download_directory_path = resolve_relative_path('../models/inswapper_128.onnx')
conditional_download(download_directory_path, ['https://huggingface.co/deepinsight/inswapper/resolve/main/inswapper_128.onnx'])