欢迎访问 生活随笔!

生活随笔

当前位置: 首页 >

11_拼接与拆分,cat,stack,split,chunk

发布时间:2024/9/27 48 豆豆
生活随笔 收集整理的这篇文章主要介绍了 11_拼接与拆分,cat,stack,split,chunk 小编觉得挺不错的,现在分享给大家,帮大家做个参考.

1.11.拼接与拆分
1.11.1.cat
1.11.2.Stack
1.11.3.split
1.11.4.chunk

1.11.拼接与拆分

1.11.1.cat

numpy中使用concat,在pytorch中使用更加简写的 cat
完成一个拼接
两个向量维度相同,想要拼接的维度上的值可以不同,但是其它维度上的值必须相同。

举个例子:还是按照前面的,想将这两组班级的成绩合并起来
a[class 1-4, students, scores]
b[class 5-9, students, scores]

# -*- coding: UTF-8 -*-import torcha = torch.rand(4, 32, 8) b = torch.rand(5, 32, 8)print(torch.cat([a, b], dim=0).shape) """输出结果:torch.Size([9, 32, 8]) 结果就是9个班级的成绩 """

理解cat:
行拼接:[4, 4] 与 [5, 4] 以 dim=0(行)进行拼接 —> [9, 4] 9个班的成绩合起来。
列拼接:[4, 5] 与 [4, 3] 以 dim=1(列)进行拼接 —> [4, 8] 每个班合成8项成绩

理解Cat

# -*- coding: UTF-8 -*-import torcha1 = torch.rand(4, 3, 32, 32) a2 = torch.rand(5, 3, 32, 32) print(torch.cat([a1, a2], dim=0).shape) # 合并第1维 理解上相当于合并batch """ 输出结果:torch.Size([9, 3, 32, 32]) """a2 = torch.rand(4, 1, 32, 32) print(torch.cat([a1,a2],dim=1).shape) # 合并第2维 理解上相当于合并为 rgba """ 输出结果:torch.Size([4, 4, 32, 32]) """a1 = torch.rand(4, 3, 16, 32) a2 = torch.rand(4, 3, 16, 32) print(torch.cat([a1, a2], dim=3).shape) # 合并第3维 理解上相当于合并照片的上下两半 """ 输出结果:torch.Size([4, 3, 16, 64]) """a1 = torch.rand(4, 3, 32, 32) print(torch.cat([a1, a2], dim=0).shape) """ RuntimeError: Sizes of tensors must match except in dimension 0. Got 32 and 16 in dimension 2 (The offending index is 1) """

1.11.2.Stack

创造一个新的维度(代表了新的组别)
要求两个tensor的size完全相同

# -*- coding: UTF-8 -*-import torcha1 = torch.rand(4, 3, 16, 32) a2 = torch.rand(4, 3, 16, 32) print(torch.cat([a1, a2], dim=2).shape) # 合并照片的上下部分 """ 输入结果:torch.Size([4, 3, 32, 32]) """# 添加了一个维度 一个值代表上半部分,一个值代表下半部分。 这显然是没有cat合适的。 print(torch.stack([a1, a2],dim=2).shape) """ 输入结果:torch.Size([4, 3, 2, 16, 32]) """a = torch.rand(32, 8) b = torch.rand(32, 8) # 将两个班级的学生成绩合并,添加一个新的维度,这个维度的每个值代表一个班级。显然是比cat合适的。 print(torch.stack([a,b],dim=0).shape) """ 输出结果:torch.Size([2, 32, 8]) """

1.11.3.split

按长度进行拆分:单元长度/数量
长度相同给一个固定值
长度不同给一个列表

# -*- coding: UTF-8 -*-import torcha = torch.rand(32, 8) b = torch.rand(32, 8) c = torch.rand(32, 8) d = torch.rand(32, 8) e = torch.rand(32, 8) f = torch.rand(32, 8) s = torch.stack([a, b, c, d, e, f], dim=0) print(s.shape) """ 输出结果:torch.Size([6, 32, 8]) """ aa,bb = s.split(3, dim=0) # 按数量切分,可以使用一个常数 print(aa.shape, bb.shape) """ 输出结果:torch.Size([3, 32, 8]) torch.Size([3, 32, 8]) """ cc, dd, ee = s.split([3, 2, 1], dim=0) # 按单位长度切分,可以使用一个列表 print(cc.shape, dd.shape, ee.shape) """ 输出结果: torch.Size([3, 32, 8]) torch.Size([2, 32, 8]) torch.Size([1, 32, 8]) 看到结果第一列,分别是:3,2,1 """print(s)ff, gg = s.split(6, dim=0) # 只切了一半,有一半不存在,所以报错 """ ValueError: not enough values to unpack (expected 2, got 1) """

1.11.4.chunk

按照量进行拆分

# -*- coding: UTF-8 -*-import torcha = torch.rand(32, 8) b = torch.rand(32, 8) c = torch.rand(32, 8) d = torch.rand(32, 8) e = torch.rand(32, 8) f = torch.rand(32, 8) s = torch.stack([a, b, c, d, e, f], dim=0) print(s.shape) """ 输出结果:torch.Size([6, 32, 8]) """aa, bb = s.chunk(2, dim=0) print(aa.shape, bb.shape) """ 输出结果:torch.Size([3, 32, 8]) torch.Size([3, 32, 8]) """cc, dd = s.split(3, dim=0) print(cc.shape, dd.shape) """ 输出结果:torch.Size([3, 32, 8]) torch.Size([3, 32, 8]) """

注意:对于按数量切分:chunk中的参数是要切成几份;split的常量是每份有几个。

总结

以上是生活随笔为你收集整理的11_拼接与拆分,cat,stack,split,chunk的全部内容,希望文章能够帮你解决所遇到的问题。

如果觉得生活随笔网站内容还不错,欢迎将生活随笔推荐给好友。