一个小坑,多输入的pytorch模型在导出pt文件供libtorch调用时候,python下模型的forward方法不能使用tuple的形式传入inputs。

pytorch下

pytorch下多输入比较方便,修改模型的forward方法就可以。

数据输入

在处理图片时经常需要多输入,比如分类时输入额外特征、检测时输入定界框,而pytorch已经实现了基本的Dataset类,实现多输入使用的就是派生一个自定义的Dataset然后实现数据读取以及__getitem__、__len__方法供Dataloader调用。

下面是自己用到的代码,自定义还是比较简单。


import os
import numpy as np
from torch.utils.data import DataLoader,Dataset
from torchvision.datasets.folder import default_loader


class CustomDataset(Dataset):
    def __init__(self,
                 img_path,
                 txt_path,
                 loader = default_loader,
                 img_transform=None,
                 ):
        with open(txt_path, 'r') as f:
            lines = f.readlines()
            self.imgs = [
                os.path.join(img_path, i.split(',')[0].partition('\\')[2]) for i in lines
            ]
            self.label_list = [i.split('\\')[1] for i in lines]
            self.feature_list = np.array([list(map(float,[i.split(',')[1],i.split(',')[2],
                                            i.split(',')[3],i.split(',')[4],
                                            i.split(',')[5],i.split(',')[6]]))
                                 for i in lines])
        self.feature_list[:,2:] = self.feature_list[:,2:] / 28
        self.img_transform = img_transform
        self.loader = loader
        self.labels = list(set(self.label_list))
        self.labels.sort()
        self.class_to_idx = dict(zip(self.labels ,range(len(self.labels))))
        self.label_list = [self.class_to_idx[c] for c in self.label_list]

    def __getitem__(self, index):
        img_path = self.imgs[index]
        extra_feature = self.feature_list[index]
        label = self.label_list[index]
        img = self.loader(img_path)
        if self.img_transform is not None:
            img = self.img_transform(img)
        return img, extra_feature, label

    def __len__(self):
        return len(self.label_list)

libtorch下

libtorch下模型的forward方法的输入是一个向量,如果模型的forward方法每个参数对应一个输入的话,在对应位置输入就没问题,但是如果python模型使用了tuple来传inputs,那可能遇到下面几种参数不匹配的错误,在libtorch中调用模型的forward时无法将输入的std::vector<torch::IValue>转换为(Tensor, Tensor)。

Expected value of type (Tensor, Tensor) for argument 'argument_1' in position 0, but instead got value of type Tensor. Declaration: forward((Tensor, Tensor) argument_1) -> Tensor

或者

Expected at most 1 argument(s) for operator 'forward', but received 2 argument(s). Declaration: forward((Tensor, Tensor) argument_1) -> Tensor 

另外还有在torch.jit.trace阶段可能碰到的错误,可以多包裹一层tuple解决,例如((sample_input_1, sample_input_2),)。

TypeError: forward() takes 2 positional arguments but 3 were given