# aiphoto/u_2_net/my_u2net_test.py
#
# (Listing metadata from the source viewer: 107 lines, 2.9 KiB, Python.)
#
# Note from the source viewer: this file contains Unicode characters that
# might be confused with other characters. If this is intentional, the
# warning can be safely ignored.

import torch
from torch.autograd import Variable
from torchvision import transforms#, utils
# import torch.optim as optim
import numpy as np
from u_2_net.data_loader import RescaleT
from u_2_net.data_loader import ToTensorLab
from u_2_net.model import U2NET # full size version 173.6 MB
from PIL import Image
import os
def normPRED(d):
    """Min-max normalize the predicted SOD probability map to [0, 1].

    Args:
        d: prediction tensor of any shape.

    Returns:
        ``(d - min(d)) / (max(d) - min(d))`` as a new tensor. If every element
        of ``d`` has the same value (zero range), a tensor of zeros shaped like
        ``d`` is returned instead of the all-NaN result a 0/0 division gives.
    """
    ma = torch.max(d)
    mi = torch.min(d)
    value_range = ma - mi
    # A constant map would otherwise divide 0 by 0 and produce NaNs, which
    # then break the conversion to an image downstream.
    if value_range == 0:
        return torch.zeros_like(d)
    return (d - mi) / value_range
def preprocess(image):
    """Build the input sample dict consumed by the U2NET transforms.

    The transforms expect a ``{'imidx', 'image', 'label'}`` dict. Here the
    label is only a placeholder filled with zeros (float64).

    Args:
        image: NumPy array. A 2-D ``(H, W)`` array gains a trailing channel
            axis; a 3-D ``(H, W, C)`` array is used as-is.

    Returns:
        The sample dict after ``RescaleT(320)`` followed by
        ``ToTensorLab(flag=0)``.
    """
    rank = len(image.shape)
    hw = image.shape[0:2]
    if rank == 3:
        # Colour image: placeholder label of shape (H, W, 1).
        label = np.zeros(hw)[:, :, np.newaxis]
    elif rank == 2:
        # Grayscale image: give both the image and the label a channel axis.
        image = image[:, :, np.newaxis]
        label = np.zeros(hw)[:, :, np.newaxis]
    else:
        # Any other rank: image untouched, label shaped like its first two dims.
        label = np.zeros(hw)

    pipeline = transforms.Compose([RescaleT(320), ToTensorLab(flag=0)])
    return pipeline({'imidx': np.array([0]), 'image': image, 'label': label})
def pre_net():
    """Construct the full-size U2NET and load its pretrained weights.

    The checkpoint path is ``<this module's dir>/saved_models/u2net/u2net.pth``.
    Weights are mapped to the CPU while loading. The model is then moved to the
    GPU when CUDA is available and switched to evaluation mode.

    Returns:
        The ``U2NET(3, 1)`` model in eval mode.
    """
    # Use the u2net (full-size) model weights.
    name = 'u2net'
    here = os.path.dirname(__file__)
    print(here)
    weights_path = '/'.join([here, 'saved_models', name, name + '.pth'])
    print(weights_path)
    print("...load U2NET---173.6 MB")
    model = U2NET(3, 1)
    # Map tensors to the CPU so the checkpoint loads even without CUDA.
    checkpoint = torch.load(weights_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint)
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model
def pre_test_data(img):
    """Turn a raw image array into a batched float tensor for U2NET.

    Args:
        img: image as a NumPy array (see ``preprocess`` for accepted shapes).

    Returns:
        A ``torch.FloatTensor``-typed tensor with a leading batch dimension of
        size 1. It is moved to the GPU when CUDA is available.
    """
    # Release unused memory held by PyTorch's CUDA caching allocator.
    torch.cuda.empty_cache()
    sample = preprocess(img)
    # Add the batch dimension and cast to float32.
    inputs_test = sample['image'].unsqueeze(0).type(torch.FloatTensor)
    # ``torch.autograd.Variable`` is deprecated and has been a no-op wrapper
    # since PyTorch 0.4, so plain tensors are passed to the model directly.
    if torch.cuda.is_available():
        inputs_test = inputs_test.cuda()
    return inputs_test
def get_im(pred):
    """Render a prediction tensor as an RGB ``PIL.Image``.

    Args:
        pred: prediction tensor, intended to hold values in [0, 1] (such as
            the output of ``normPRED``). Singleton dimensions are squeezed.

    Returns:
        An RGB image built from the values scaled by 255.
    """
    mask = pred.squeeze().detach().cpu().numpy()
    return Image.fromarray(mask * 255).convert('RGB')
def test_seg_trimap(org, alpha, alpha_resize):
    """Predict a U2NET alpha (saliency) map for an image and save it twice.

    Args:
        org: path of the original input image (anything ``Image.open`` accepts).
        alpha: output path for the mask at the resolution the network produces.
        alpha_resize: output path for the same mask resized back to the
            original image's width and height.
    """
    # Use a context manager so the source file handle is closed once the
    # pixel data and size have been read.
    with Image.open(org) as image:
        print(image)
        img = np.array(image)
        sp = image.size  # PIL reports (width, height)
    net = pre_net()
    inputs_test = pre_test_data(img)
    # Inference only: do not build an autograd graph (saves memory).
    with torch.no_grad():
        # Only the first output (d1) is used; the remaining outputs are ignored.
        d1, *_ = net(inputs_test)
    # Take channel 0 and min-max normalize it.
    pred = d1[:, 0, :, :]
    pred = normPRED(pred)
    # Convert the tensor to an image.
    im = get_im(pred)
    im.save(alpha)
    # Resize back to the original image size.
    imo = im.resize((sp[0], sp[1]), resample=Image.BILINEAR)
    imo.save(alpha_resize)
# if __name__ == "__main__":
# test_seg_trimap("..\\img\\meinv.jpg","..\\img\\trimap\\meinv_alpha.png","..\\img\\trimap\\meinv_alpha_resize.png")
# #pil_wait_blue()