学术, 机器学习

由于内存限制导致数据无法一次性从文件加载到PyTorch 中的简单解决方法

如果训练的数据比较大,无法一次性加载,简单的解决方法是分批从文件中加载。为了演示方便,这里的 x_train, y_train 数组都是直接定义,且都是小数组,用于模拟从文件中加载大数组。说明:PyTorch 库中可能有其他的解决方案,代码好像会有点麻烦,因此这里暂不考虑。

1. 从文件中一次性加载到DataLoader

from torch.utils.data import DataLoader, TensorDataset
import torch
x_train = torch.randn(500, 20) # 全部数据加载
y_train = torch.randn(500, 1) # 全部数据加载
train_dataset = TensorDataset(x_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
for batch_x, batch_y in train_loader:
    print(batch_x.shape)
    print(batch_y.shape)

运行结果:

torch.Size([32, 20])
torch.Size([32, 1])
torch.Size([32, 20])
torch.Size([32, 1])
torch.Size([32, 20])
torch.Size([32, 1])
torch.Size([32, 20])
torch.Size([32, 1])
torch.Size([32, 20])
torch.Size([32, 1])
torch.Size([32, 20])
torch.Size([32, 1])
torch.Size([32, 20])
torch.Size([32, 1])
torch.Size([32, 20])
torch.Size([32, 1])
torch.Size([32, 20])
torch.Size([32, 1])
torch.Size([32, 20])
torch.Size([32, 1])
torch.Size([32, 20])
torch.Size([32, 1])
torch.Size([32, 20])
torch.Size([32, 1])
torch.Size([32, 20])
torch.Size([32, 1])
torch.Size([32, 20])
torch.Size([32, 1])
torch.Size([32, 20])
torch.Size([32, 1])
torch.Size([20, 20])
torch.Size([20, 1])

更多阅读:批量训练中迭代次数的计算

如果 x_train 和 y_train 的数据量很多,不只是 500 个,而是 1e6 个,甚至更多,无法一次性加载训练,那么可以考虑分批从文件中加载。

2. 分批从文件中加载

尽管无法在整个数据集上进行一次性的 shuffle,但通过将数据分批从文件中加载,并对每个数据集进行独立的 shuffle 和训练,可以达到类似的效果。

代码示例:

from torch.utils.data import DataLoader, TensorDataset
import torch
for i0 in range(5):
    x_train = torch.randn(100, 20) # 小文件加载
    y_train = torch.randn(100, 1) # 小文件加载
    train_dataset = TensorDataset(x_train, y_train)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    for batch_x, batch_y in train_loader:
        print(batch_x.shape)
        print(batch_y.shape)
    if i0 == 0:
        print('Training model...')
    else:
        print('Continue training model...')

运行结果:

torch.Size([32, 20])
torch.Size([32, 1])
torch.Size([32, 20])
torch.Size([32, 1])
torch.Size([32, 20])
torch.Size([32, 1])
torch.Size([4, 20])
torch.Size([4, 1])
Training model...
torch.Size([32, 20])
torch.Size([32, 1])
torch.Size([32, 20])
torch.Size([32, 1])
torch.Size([32, 20])
torch.Size([32, 1])
torch.Size([4, 20])
torch.Size([4, 1])
Continue training model...
torch.Size([32, 20])
torch.Size([32, 1])
torch.Size([32, 20])
torch.Size([32, 1])
torch.Size([32, 20])
torch.Size([32, 1])
torch.Size([4, 20])
torch.Size([4, 1])
Continue training model...
torch.Size([32, 20])
torch.Size([32, 1])
torch.Size([32, 20])
torch.Size([32, 1])
torch.Size([32, 20])
torch.Size([32, 1])
torch.Size([4, 20])
torch.Size([4, 1])
Continue training model...
torch.Size([32, 20])
torch.Size([32, 1])
torch.Size([32, 20])
torch.Size([32, 1])
torch.Size([32, 20])
torch.Size([32, 1])
torch.Size([4, 20])
torch.Size([4, 1])
Continue training model...

3. 使用额外代码打乱文件中的数据,然后再分批从文件中加载

如果每个文件数据都比较特别,分批文件加载训练时,可能会导致损失函数波动比较大,因此对文件中的数据进行打乱(shuffle)是十分必要的。数据的合并参考这篇博文:Pytorch张量数组的合并

需要注意的是:由于 x_train 和 y_train 是一一对应的关系,因此不能简单的随机排列。这里只是分别从小文件中读取数据,保存为新的小文件,没有做额外的随机处理。

代码示例:

import torch

def load_data_from_file(filename):
    x_train = torch.randn(100, 20) # 小文件加载
    y_train = torch.randn(100, 1) # 小文件加载
    return x_train, y_train

x_train_new = []  # 新的小文件
y_train_new = []  # 新的小文件
n = 5
for i0_new in range(n):
    for i0 in range(n):
        x_train, y_train = load_data_from_file(filename=str(i0))
        if i0 == 0:
            x_train_new = x_train[i0_new*int(100/n):(i0_new+1)*int(100/n), :]
            y_train_new = y_train[i0_new*int(100/n):(i0_new+1)*int(100/n), :]
        else:
            x_train_new = torch.cat((x_train_new, x_train[i0_new*int(100/n):(i0_new+1)*int(100/n), :]), dim=0)
            y_train_new = torch.cat((y_train_new, y_train[i0_new*int(100/n):(i0_new+1)*int(100/n), :]), dim=0)
    print(x_train_new.shape)
    print(y_train_new.shape)
    print('Save new file!')

运行结果:

torch.Size([100, 20])
torch.Size([100, 1])
Save new file!
torch.Size([100, 20])
torch.Size([100, 1])
Save new file!
torch.Size([100, 20])
torch.Size([100, 1])
Save new file!
torch.Size([100, 20])
torch.Size([100, 1])
Save new file!
torch.Size([100, 20])
torch.Size([100, 1])
Save new file!
506 次浏览

【说明:本站主要是个人的一些笔记和代码分享,内容可能会不定期修改。为了使全网显示的始终是最新版本,这里的文章未经同意请勿转载。引用请注明出处:https://www.guanjihuan.com

发表评论

您的邮箱地址不会被公开。 必填项已用 * 标注

Captcha Code