一个小坑,多输入的pytorch模型在导出pt文件供libtorch调用时候,python下模型的forward方法不能使用tuple的形式传入inputs。
pytorch下
pytorch下多输入比较方便,修改模型的forward方法就可以。
数据输入
在处理图片时经常需要多输入,比如分类时输入额外特征、检测时输入定界框,而pytorch已经实现了基本的Dataset类,实现多输入使用的就是派生一个自定义的Dataset然后实现数据读取以及__getitem__、__len__方法供Dataloader调用。
下面是自己用到的代码,自定义还是比较简单。
import os
import numpy as np
from torch.utils.data import DataLoader,Dataset
from torchvision.datasets.folder import default_loader
class CustomDataset(Dataset):
def __init__(self,
img_path,
txt_path,
loader = default_loader,
img_transform=None,
):
with open(txt_path, 'r') as f:
lines = f.readlines()
self.imgs = [
os.path.join(img_path, i.split(',')[0].partition('\\')[2]) for i in lines
]
self.label_list = [i.split('\\')[1] for i in lines]
self.feature_list = np.array([list(map(float,[i.split(',')[1],i.split(',')[2],
i.split(',')[3],i.split(',')[4],
i.split(',')[5],i.split(',')[6]]))
for i in lines])
self.feature_list[:,2:] = self.feature_list[:,2:] / 28
self.img_transform = img_transform
self.loader = loader
self.labels = list(set(self.label_list))
self.labels.sort()
self.class_to_idx = dict(zip(self.labels ,range(len(self.labels))))
self.label_list = [self.class_to_idx[c] for c in self.label_list]
def __getitem__(self, index):
img_path = self.imgs[index]
extra_feature = self.feature_list[index]
label = self.label_list[index]
img = self.loader(img_path)
if self.img_transform is not None:
img = self.img_transform(img)
return img, extra_feature, label
def __len__(self):
return len(self.label_list)
libtorch下
libtorch下模型的forward方法的输入是一个向量,如果模型的forward方法每个参数对应一个输入的话,在对应位置输入就没问题,但是如果python模型使用了tuple来传inputs,那可能遇到下面几种参数不匹配的错误,在libtorch中调用模型的forward时无法将输入的std::vector<torch::IValue>转换为(Tensor, Tensor)。
Expected value of type (Tensor, Tensor) for argument 'argument_1' in position 0, but instead got value of type Tensor. Declaration: forward((Tensor, Tensor) argument_1) -> Tensor
或者
Expected at most 1 argument(s) for operator 'forward', but received 2 argument(s). Declaration: forward((Tensor, Tensor) argument_1) -> Tensor
另外还有在torch.jit.trace阶段可能碰到的错误,可以多包裹一层tuple解决,例如((sample_input_1, sample_input_2),)。
TypeError: forward() takes 2 positional arguments but 3 were given