当前位置: 首页>编程语言>正文

可使用的 ESRGAN 超分模型

在 Kaggle 中使用

!pip install git+https://github.com/sberbank-ai/Real-ESRGAN.git
import os

# HF_ENDPOINT must be set BEFORE huggingface_hub is imported: the library reads
# the variable once, at import time, so assigning it after the import silently
# keeps using the default endpoint instead of the mirror.
# NOTE: if huggingface_hub was already imported earlier in this kernel session,
# restart the kernel for the mirror setting to take effect.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

import torch
from huggingface_hub import hf_hub_download

# Download the 4x-UltraSharp ESRGAN checkpoint.
model_path = hf_hub_download(
    repo_id="Shandypur/ESRGAN-4x-UltraSharp",
    filename="4x-UltraSharp.pth",
    repo_type="model",
)
# map_location="cpu" restores every tensor on the CPU regardless of the device
# it was saved from, so this also works on a CPU-only runtime.
weights = torch.load(model_path, map_location="cpu")

# .pth -> .pth key conversion: rename the checkpoint's keys to the names used by
# the RealESRGAN model that loads "./state_dict.pth" below.
# Each entry is [new_name, old_name]; every occurrence of old_name in a key is
# replaced by new_name. Rules are applied in list order and the order matters:
# "sub.23" must become "conv_body" before the generic "sub" -> "body" rule, and
# "model.0" must become "conv_first" before the ".0.w" -> ".w" rule strips it.
map_key = [
    ["conv_body", "sub.23"],
    ["body", "sub"],
    ["rdb", "RDB"],
    ["", "model.1."],
    ["conv_first", "model.0"],
    ["conv_up1", "model.3"],
    ["conv_up2", "model.6"],
    ["conv_hr", "model.8"],
    ["conv_last", "model.10"],
    [".w", ".0.w"],
    [".b", ".0.b"],
]


def _rename_key(key):
    """Return *key* with every [new, old] rule in ``map_key`` applied in order."""
    for new, old in map_key:
        key = key.replace(old, new)
    return key


state_dict = {_rename_key(k): v for k, v in weights.items()}

# Where the converted weights are saved.
model_path_pt = "./state_dict.pth"
torch.save(state_dict, model_path_pt)

import time

import numpy as np
import torch
from PIL import Image
from RealESRGAN import RealESRGAN

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the 4x model and load the converted weights (saved to model_path_pt earlier).
model = RealESRGAN(device, scale=4)
model.load_weights(model_path_pt)

# Inference. The path variable is defined and read under the same name
# (a mismatch here previously raised NameError).
path_to_image = "/kaggle/input/esr-low-img/1Z412103326-1-1200.jpg"
image = Image.open(path_to_image).convert('RGB')
sr_image = model.predict(image)

# Save the result; the file name is prefixed with the current Unix timestamp,
# so each run writes a new file.
sr_image.save(str(time.time()) + 'sr_image.png')

转换为 ONNX 模型

!pip install onnxruntime
# Export to ONNX -- this is only an illustrative example.
import numpy as np
import onnxruntime as ort
import torch

# Only the `model` component of the author's custom RealESRGAN wrapper, which
# subclasses torch.nn.Module, can be exported to ONNX.
rdb = model.model.eval()

# Create the example input on the same device as the network's weights;
# hard-coding "cuda" fails on CPU-only runtimes or when the model sits on the CPU.
export_device = next(rdb.parameters()).device
dummy_input = torch.randn(1, 3, 224, 224, device=export_device)  # example input

# Export and save the ONNX model.
onnx_model_path = "./RDB.onnx"
torch.onnx.export(
    rdb,
    dummy_input,
    onnx_model_path,
    input_names=["input"],
    output_names=["output"],
    # Keep batch and spatial sizes symbolic so the graph is not locked to the
    # 224x224 example. Output H/W get their own names because the 4x model's
    # output is larger than its input.
    dynamic_axes={
        "input": {0: "batch", 2: "height", 3: "width"},
        "output": {0: "batch", 2: "out_height", 3: "out_width"},
    },
)


# Inference with the exported ONNX model.
# onnxruntime >= 1.9 GPU builds require an explicit provider list; the CPU
# provider is always available. Prepend "CUDAExecutionProvider" to use a GPU
# when onnxruntime-gpu is installed.
ort_session = ort.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])

# Look up the graph's input and output names.
input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name

# Example input data: random values in [0, 1), float32 like the export's dummy input.
input_data = np.random.rand(1, 3, 224, 224).astype(np.float32)

# Run inference; `run` returns one array per requested output name.
results = ort_session.run([output_name], {input_name: input_data})

# Process the inference result.
output = results[0]
# ...

转换为 Paddle 模型

Paddle 中的超分体验链接

!pip install paddlepaddle 
# Convert to a Paddle .pdparams file.
import os

# HF_ENDPOINT must be set BEFORE huggingface_hub is imported: the library reads
# it once, at import time. If huggingface_hub was already imported earlier in
# this kernel session, restart the kernel for the mirror to take effect.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

import paddle
import torch
from huggingface_hub import hf_hub_download

model_path = hf_hub_download(
    repo_id="Shandypur/ESRGAN-4x-UltraSharp",
    filename="4x-UltraSharp.pth",
    repo_type="model",
)
# Load onto the CPU: Tensor.numpy() below only works on CPU tensors, so this
# keeps the conversion working however the checkpoint was saved.
weights = torch.load(model_path, map_location="cpu")

# Key mapping to the Paddle generator's parameter names (presumably PaddleGAN's
# RRDBNet naming -- confirm). Each entry is [new_name, old_name], applied in
# list order; "model.1.sub.23" must be renamed before the generic "model.1.sub",
# and "model.0" before the ".0.w" -> ".w" rule.
map_key_pd = [
    ["trunk_conv", "model.1.sub.23"],
    ["RRDB_trunk", "model.1.sub"],
    ["conv_first", "model.0"],
    ["upconv1", "model.3"],
    ["upconv2", "model.6"],
    ["HRconv", "model.8"],
    ["conv_last", "model.10"],
    [".w", ".0.w"],
    [".b", ".0.b"],
]


def _rename_key_pd(key):
    """Return *key* with every [new, old] rule in ``map_key_pd`` applied in order."""
    for new, old in map_key_pd:
        key = key.replace(old, new)
    return key


# Rename each key and convert torch tensor -> numpy -> paddle tensor.
# Inspect the result with state_dict.keys() if needed.
state_dict = {
    _rename_key_pd(k): paddle.to_tensor(v.numpy()) for k, v in weights.items()
}

# Nest the weights under the "generator" key (inspect with
# state_dict_pd["generator"].keys()).
state_dict_pd = {"generator": state_dict}

# Save.
# NOTE(review): the file name spells ESRGAN/UltraSharp as "ERRGAN_UltralSharp";
# kept unchanged because other code may load it by this exact path.
model_path_pd = "./ERRGAN_UltralSharp_X4.pdparams"
paddle.save(state_dict_pd, model_path_pd)

https://www.xamrdz.com/lan/5kb1936383.html

相关文章: