# 在 Kaggle 中使用
!pip install git+https://github.com/sberbank-ai/Real-ESRGAN.git
import os

# huggingface_hub reads HF_ENDPOINT when it is first imported, so the mirror
# has to be configured *before* the import below. It is also passed to
# hf_hub_download explicitly, which works even if the library was imported
# earlier in the session.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

import torch
from huggingface_hub import hf_hub_download

# Download the ESRGAN 4x-UltraSharp checkpoint (original ESRGAN key layout).
model_path = hf_hub_download(
    repo_id="Shandypur/ESRGAN-4x-UltraSharp",
    filename="4x-UltraSharp.pth",
    repo_type="model",
    endpoint=os.environ["HF_ENDPOINT"],
)
# map_location="cpu" lets the checkpoint load on machines without a GPU.
weights = torch.load(model_path, map_location="cpu")
# pth -> pth conversion: rename the checkpoint keys to the names expected by
# the RealESRGAN package installed above.
# Each pair is [new_substring, old_substring]. The rules are applied in list
# order with str.replace, so order matters: "sub.23" (the trunk conv) must be
# renamed before the generic "sub" rule, and the "model.1." prefix is only
# stripped after the sub/RDB parts have been renamed.
map_key = [
    ["conv_body", "sub.23"],
    ["body", "sub"],
    ["rdb", "RDB"],
    ["", "model.1."],
    ["conv_first", "model.0"],
    ["conv_up1", "model.3"],
    ["conv_up2", "model.6"],
    ["conv_hr", "model.8"],
    ["conv_last", "model.10"],
    [".w", ".0.w"],
    [".b", ".0.b"],
]


def _remap_key(key, rules):
    """Return *key* with every ``[new, old]`` substring rule applied in order."""
    for new, old in rules:
        key = key.replace(old, new)
    return key


state_dict = {_remap_key(k, map_key): v for k, v in weights.items()}
# Two source keys collapsing onto one name would silently drop a tensor.
if len(state_dict) != len(weights):
    raise ValueError("key mapping produced duplicate parameter names")

# Where the converted weights are saved
model_path_pt = "./state_dict.pth"
torch.save(state_dict, model_path_pt)
import time

import numpy as np
import torch
from PIL import Image
from RealESRGAN import RealESRGAN

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the 4x model and load the converted weights
model = RealESRGAN(device, scale=4)
model.load_weights(model_path_pt)

# Inference
path_to_image = "/kaggle/input/esr-low-img/1Z412103326-1-1200.jpg"
# convert() returns a new, fully loaded image, so the file can be closed here.
with Image.open(path_to_image) as img:
    image = img.convert("RGB")
sr_image = model.predict(image)

# Save the result; the timestamp prefix keeps repeated runs from overwriting
# each other.
sr_image.save(str(time.time()) + "sr_image.png")
# 转换为 ONNX 模型
!pip install onnxruntime
# to onnx (example only)
import torch

# Only the `model` component inside the author's custom RealESRGAN wrapper
# (the part that inherits torch.nn.Module) can be exported to ONNX.
rdb = model.model.eval()

# Example input. It must live on the same device the model was built for
# (`device`), not a hard-coded "cuda", or the export fails on CPU machines.
dummy_input = torch.randn(1, 3, 224, 224).to(device)

# Convert the model to ONNX and save it
onnx_model_path = "./RDB.onnx"
torch.onnx.export(
    rdb,
    dummy_input,
    onnx_model_path,
    input_names=["input"],
    output_names=["output"],
    # Batch and spatial dims are symbolic, so the exported model accepts any
    # image size instead of only the 224x224 example. Output spatial dims get
    # their own names because they differ from the input's (upscaled).
    dynamic_axes={
        "input": {0: "batch", 2: "height", 3: "width"},
        "output": {0: "batch", 2: "out_height", 3: "out_width"},
    },
)
# Run inference with the ONNX model
import numpy as np
import onnxruntime as ort

# Pick execution providers explicitly: prefer CUDA when this onnxruntime build
# offers it, otherwise use CPU. Some GPU builds (since onnxruntime 1.9)
# refuse to create a session when `providers` is omitted.
available_providers = ort.get_available_providers()
providers = [
    p
    for p in ("CUDAExecutionProvider", "CPUExecutionProvider")
    if p in available_providers
]

# Load the ONNX model
ort_session = ort.InferenceSession(onnx_model_path, providers=providers)

# Input/output names are read from the session rather than hard-coded
input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name

# Example input data: random float32 array of shape (1, 3, 224, 224)
input_data = np.random.rand(1, 3, 224, 224).astype(np.float32)

# Run inference
results = ort_session.run([output_name], {input_name: input_data})

# Post-process the inference result
output = results[0]
# ...
# 转换为 Paddle 模型
# Paddle 中的超分体验链接（原文链接缺失，待补充）
!pip install paddlepaddle
# to paddle
import os

import paddle
import torch
from huggingface_hub import hf_hub_download

# huggingface_hub reads HF_ENDPOINT at import time and has most likely been
# imported already in this session, so setting the variable alone has no
# effect. Pass the mirror to hf_hub_download explicitly as well.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
model_path = hf_hub_download(
    repo_id="Shandypur/ESRGAN-4x-UltraSharp",
    filename="4x-UltraSharp.pth",
    repo_type="model",
    endpoint=os.environ["HF_ENDPOINT"],
)
# Load onto CPU so the tensors can be converted with .numpy() later.
weights = torch.load(model_path, map_location="cpu")
# Key mapping: original ESRGAN checkpoint names -> the Paddle RRDBNet names
# (presumably PaddleGAN's generator, given the "generator" wrapper key below;
# verify against the target model).
# Each pair is [new_substring, old_substring], applied in list order with
# str.replace, so "model.1.sub.23" (the trunk conv) must come before the
# generic "model.1.sub" rule.
map_key_pd = [
    ["trunk_conv", "model.1.sub.23"],
    ["RRDB_trunk", "model.1.sub"],
    ["conv_first", "model.0"],
    ["upconv1", "model.3"],
    ["upconv2", "model.6"],
    ["HRconv", "model.8"],
    ["conv_last", "model.10"],
    [".w", ".0.w"],
    [".b", ".0.b"],
]


def _remap_key(key, rules):
    """Return *key* with every ``[new, old]`` substring rule applied in order."""
    for new, old in rules:
        key = key.replace(old, new)
    return key


state_dict = {}
for k, v in weights.items():
    # detach().cpu() makes .numpy() safe for tensors that track grad or sit
    # on a GPU.
    state_dict[_remap_key(k, map_key_pd)] = paddle.to_tensor(
        v.detach().cpu().numpy()
    )
# Two source keys collapsing onto one name would silently drop a tensor.
if len(state_dict) != len(weights):
    raise ValueError("key mapping produced duplicate parameter names")

# Nest the weights under a "generator" key
state_dict_pd = {"generator": state_dict}

# Save
# NOTE(review): "ERRGAN_UltralSharp" looks like a typo for ESRGAN_UltraSharp;
# kept unchanged in case something loads the file by this name.
model_path_pd = "./ERRGAN_UltralSharp_X4.pdparams"
paddle.save(state_dict_pd, model_path_pd)