这是之前的一篇:PyTorch基础(Tensor数据类型)。
在机器学习中,可能会涉及到多个数据合并后再进行训练或者预测。本篇主要讲 PyTorch 张量数组的合并,合并的方法和NumPy会有些相似,但书写有所不同。PyTorch 用的是 torch.cat,NumPy 中用的是 np.concatenate (NumPy 中可能还有 numpy.vstack、numpy.hstack、np.row_stack、np.column_stack、numpy.stack、np.append 等方法,这里不再给出例子)。
PyTorch 数组合并的代码例子:
"""
This code is supported by the website: https://www.guanjihuan.com
The newest version of this code is on the web page: https://www.guanjihuan.com/archives/41194
"""
import torch
# 定义两个张量
tensor1 = torch.randn(2, 3)
tensor2 = torch.randn(2, 3)
print(tensor1)
print(tensor2)
print()
# 第一维度的数据合并(需要其他的维度保持一致)
result1 = torch.cat((tensor1, tensor2), dim=0)
print(result1)
print(result1.shape)
print()
# 第二维度的数据合并(需要其他的维度保持一致)
result2 = torch.cat((tensor1, tensor2), dim=1)
print(result2)
print(result2.shape)
print()
# 定义多个张量
tensor1 = torch.randn(2, 10)
tensor2 = torch.randn(2, 20)
tensor3 = torch.randn(2, 30)
tensor4 = torch.randn(2, 50)
# 将这些张量放在一个列表中
tensors = [tensor1, tensor2, tensor3, tensor4]
# 第二维度的数据合并(确保所有张量的第一维度相同)
result3 = torch.cat(tensors, dim=1)
print(result3.shape)
运行结果:
tensor([[ 0.6541, -1.1516, -0.8723],
[ 0.0564, -0.7304, -0.2876]])
tensor([[ 0.5717, -0.3831, -0.8642],
[-0.5465, -1.6053, -0.8270]])
tensor([[ 0.6541, -1.1516, -0.8723],
[ 0.0564, -0.7304, -0.2876],
[ 0.5717, -0.3831, -0.8642],
[-0.5465, -1.6053, -0.8270]])
torch.Size([4, 3])
tensor([[ 0.6541, -1.1516, -0.8723, 0.5717, -0.3831, -0.8642],
[ 0.0564, -0.7304, -0.2876, -0.5465, -1.6053, -0.8270]])
torch.Size([2, 6])
torch.Size([2, 110])
NumPy 数组合并的代码例子:
"""
This code is supported by the website: https://www.guanjihuan.com
The newest version of this code is on the web page: https://www.guanjihuan.com/archives/41194
"""
import numpy as np
a = np.random.rand(2, 3)
b = np.random.rand(2, 3)
# 第一维度的数据合并(需要其他的维度保持一致)
concatenated_1 = np.concatenate((a, b), axis=0)
print(concatenated_1.shape)
# 第二维度的数据合并(需要其他的维度保持一致)
concatenated_2 = np.concatenate((a, b), axis=1)
print(concatenated_2.shape)
运行结果:
(4, 3)
(2, 6)
【说明:本站主要是个人的一些笔记和代码分享,内容可能会不定期修改。为了使全网显示的始终是最新版本,这里的文章未经同意请勿转载。引用请注明出处:https://www.guanjihuan.com】