读研接触深度学习也有一年了,在这一年期间复现别人的论文基本都是有现成的代码的,对代码的改动也是在AI的辅助下完成。直到最近有一个实验没有提供代码,才开始自己手搓,在复现实验的过程中可以说是毫无思路犹如无头苍蝇,代码不知道该从何写起、又该写些什么。这时候才惊觉自己的编程水平由于过于依赖AI,已经大幅退步了。于是痛定思痛,从现在开始边写边学。
遭遇的第一个难题是数据集的加载步骤,在此之前对于这部分的代码我仅仅是扫一眼完事,所以到自己写的时候脑子里一片空白。在查阅资料、询问ChatGpt、翻阅官网文档后,我大致理清了加载数据集的步骤,并归纳如下:
- 编写Dataset类
- 创建Dataloader
编写Dataset类
在PyTorch中,Dataset
类是一个抽象类,它为加载和处理数据提供了一个统一的接口。当你需要在PyTorch中使用自己的数据时,通常需要继承并实现Dataset
类来创建一个自定义的数据集对象。这种方式让数据加载更加灵活和模块化,特别是在进行机器学习和深度学习训练时。
通过Dataset类可以实现更加轻松方便地管理数据集,编写Dataset类主要需要实现以下三个方法:
__init__(self,)
:构造函数,用于初始化数据集对象。在这里,你可以加载数据文件、初始化转换等。
__len__(self)
: 必须返回整个数据集的大小。
__getitem__(self, index)
: 必须返回与给定索引index
对应的数据项。这里可以包括数据读取、预处理和返回数据项。
接下来就用我正在写的代码当做例子:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
| import torch from torch.utils.data import Dataset import os import numpy as np import time from utils import *
class ExampleDataset(Dataset): def __init__(self, root_dir, mode): super(ExampleDataset, self).__init__() assert mode in ["train", "validation", "test"], \ "Argument --mode could only be ['train', 'validation', 'test']" self.mode = mode self.root_dir = root_dir if mode in ['train', 'validation', 'test']: self.IDs = np.load(os.path.join(self.root_dir, 'ID_gt.npy')) if mode in ['train', 'validation']: self.other_gt = np.load(os.path.join(self.root_dir, 'other_gt.npy')) self.text_files = np.sort(os.listdir(os.path.join(self.root_dir, 'text')))
def __len__(self): return len(self.IDs)
def __iter__(self): return iter(self.IDs)
def __getitem__(self, idx): if torch.is_tensor(idx): idx = idx.tolist()
logger = get_logger("dataset", "stream") start_time = time.time() try: text = np.load(os.path.join(self.root_dir, 'text', self.text_files[idx])) except FileNotFoundError: logger.error(f"文件未找到: {self.text_files[idx]}") return None except Exception as e: logger.exception(f"加载数据时发生错误: {str(e)}") return None end_time = time.time() logger.info(f"加载样本 {idx} 耗时 {end_time - start_time:.4f} 秒") session = { 'ID': self.IDs[idx] } if self.mode != 'test': session['text'] = text
return session
|
__iter__
方法定义允许一个类的实例表现为一个可迭代对象。其效果就是在迭代上下文中,实例返回的值是self.IDs
。
Dataset
整体就是如此,主要的逻辑都需要在__init__
和__getitem__
两个方法中根据实际情况去实现。这段代码涉及到的logger
的配置就留到以后再去写吧。
创建Dataloader
在PyTorch中,Dataset
通常与DataLoader
一起使用,后者可以提供批处理、打乱数据、多进程加载等功能,使得数据加载更加高效和方便。
同样直接给出代码比较直观:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
| def get_dataloader(data_config): logger = get_logger('dataset', 'stream') try: dataset_mode = data_config['dataset_mode'] dataset_dir_key = f'{dataset_mode}_dataset_dir' if dataset_dir_key not in data_config: error_msg = f"数据集目录未指定: {dataset_dir_key} 需要在配置中。" logger.error(error_msg) raise ValueError(error_msg) dataset_dir = data_config[f'{dataset_mode}_dataset_dir'] dataset = ExampleDataset(dataset_dir, dataset_mode) dataloader_config = data_config['dataloader_config'] dataloader = DataLoader(dataset, **dataloader_config) return dataloader except KeyError as e: logger.exception("配置缺失/错误") return None
|
Dataloader
的使用比较简单,创建好Dataset
实例后,再配置好需要的参数即可。 dataloader_config
常用的参数有:
batch_size
:每一批数据的尺寸。
shuffle
:是否对数据进行洗牌打乱顺序。
sampler
:指定采样方法。
完整的参数列表可以参考官方文档。
Dataloader
是一个可迭代的对象,因此在训练循环中可以直接使用:
1 2 3
| for data, labels in dataloader: pass
|
使用Pytorch去加载数据集的流程大致如上,行笔至此,如之后发现还有需要完善的地方就留待之后再写吧。