# Source code for vlkit.datasets.filelist

import torch
import os.path as osp
from torchvision.datasets.folder import pil_loader
import warnings


class FileListDataset(torch.utils.data.Dataset):
    """Image dataset described by a plain-text file list.

    Every non-blank line of the list holds an image path, optionally followed
    by an integer class label, separated by whitespace::

        cat/001.jpg 0
        dog/002.jpg 1
        unlabeled/003.jpg

    Only the first two fields of a line are used; anything after the label is
    ignored. Lines without a label are kept and yield ``label = -1`` when
    indexed.

    Args:
        filelist: Path to the text file describing the dataset.
        root_dir: Optional directory joined in front of every listed path.
        transform: Optional callable applied to each loaded image.

    Attributes:
        items: List of ``(path, label)`` or ``(path,)`` tuples. Labels are
            stored as the strings read from the file.

    Raises:
        AssertionError: If ``filelist`` is not an existing file.
    """

    def __init__(self, filelist, root_dir=None, transform=None):
        assert osp.isfile(filelist), filelist
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform

        items = []
        with open(filelist, 'r') as f:
            for line in f:
                # split() without an argument tolerates tabs and repeated
                # spaces, and returns [] for blank lines.
                fields = line.split()
                if not fields:
                    continue
                path = fields[0]
                if root_dir is not None:
                    path = osp.join(root_dir, path)
                # Keep the label only when the line provides one, so that
                # unlabeled entries work with or without root_dir.
                items.append((path,) + tuple(fields[1:2]))
        self.items = items

    def __getitem__(self, index):
        """Load the sample at ``index``.

        Returns:
            A dict with keys ``img`` (the image returned by ``pil_loader``,
            after ``transform`` if one was given), ``label`` (int, ``-1``
            when the list has no label for this entry) and ``path``.

        Raises:
            AssertionError: If the image path is not an existing file.
            ValueError: If the label cannot be converted to ``int``.
        """
        item = self.items[index]
        path = item[0]
        if len(item) >= 2:
            label = int(item[1])
        else:
            label = -1
            warnings.warn('Use default label = -1, check your list !')
        assert osp.isfile(path), path

        img = pil_loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return dict(
            img=img,
            label=label,
            path=path
        )

    def __len__(self):
        """Return the number of entries read from the file list."""
        return len(self.items)

    @property
    def num_classes(self):
        """Number of distinct labels among the entries that have one."""
        # Unlabeled entries are 1-tuples; skip them instead of indexing
        # past their end.
        return len({item[1] for item in self.items if len(item) >= 2})