学术, 机器学习

PyTorch张量数组的合并

这是之前的一篇:PyTorch基础(Tensor数据类型)

在机器学习中,可能会涉及到多个数据合并后再进行训练或者预测。本篇主要讲 PyTorch 张量数组的合并,合并的方法和NumPy会有些相似,但书写有所不同。PyTorch 用的是 torch.cat,NumPy 中用的是 np.concatenate (NumPy 中可能还有 numpy.vstack、numpy.hstack、np.row_stack、np.column_stack、numpy.stack、np.append 等方法,这里不再给出例子)。

PyTorch 数组合并的代码例子:

"""
This code is supported by the website: https://www.guanjihuan.com
The newest version of this code is on the web page: https://www.guanjihuan.com/archives/41194
"""

import torch

# 定义两个张量
tensor1 = torch.randn(2, 3)
tensor2 = torch.randn(2, 3)
print(tensor1)
print(tensor2)
print()

# 第一维度的数据合并(需要其他的维度保持一致)
result1 = torch.cat((tensor1, tensor2), dim=0)
print(result1)
print(result1.shape)
print()

# 第二维度的数据合并(需要其他的维度保持一致)
result2 = torch.cat((tensor1, tensor2), dim=1)
print(result2)
print(result2.shape)
print()



# 定义多个张量
tensor1 = torch.randn(2, 10)
tensor2 = torch.randn(2, 20)
tensor3 = torch.randn(2, 30)
tensor4 = torch.randn(2, 50)

# 将这些张量放在一个列表中
tensors = [tensor1, tensor2, tensor3, tensor4]

# 第二维度的数据合并(确保所有张量的第一维度相同)
result3 = torch.cat(tensors, dim=1)
print(result3.shape)

运行结果:

tensor([[ 0.6541, -1.1516, -0.8723],
        [ 0.0564, -0.7304, -0.2876]])
tensor([[ 0.5717, -0.3831, -0.8642],
        [-0.5465, -1.6053, -0.8270]])

tensor([[ 0.6541, -1.1516, -0.8723],
        [ 0.0564, -0.7304, -0.2876],
        [ 0.5717, -0.3831, -0.8642],
        [-0.5465, -1.6053, -0.8270]])
torch.Size([4, 3])

tensor([[ 0.6541, -1.1516, -0.8723,  0.5717, -0.3831, -0.8642],        
        [ 0.0564, -0.7304, -0.2876, -0.5465, -1.6053, -0.8270]])       
torch.Size([2, 6])

torch.Size([2, 110])

NumPy 数组合并的代码例子:

"""
This code is supported by the website: https://www.guanjihuan.com
The newest version of this code is on the web page: https://www.guanjihuan.com/archives/41194
"""

import numpy as np

a = np.random.rand(2, 3)
b = np.random.rand(2, 3)

# 第一维度的数据合并(需要其他的维度保持一致)
concatenated_1 = np.concatenate((a, b), axis=0)
print(concatenated_1.shape)

# 第二维度的数据合并(需要其他的维度保持一致)
concatenated_2 = np.concatenate((a, b), axis=1)
print(concatenated_2.shape)

运行结果:

(4, 3)
(2, 6)
406 次浏览

【说明:本站主要是个人的一些笔记和代码分享,内容可能会不定期修改。为了使全网显示的始终是最新版本,这里的文章未经同意请勿转载。引用请注明出处:https://www.guanjihuan.com

发表评论

您的邮箱地址不会被公开。 必填项已用 * 标注

Captcha Code