赞
踩
view
和 reshape
都用于调整张量的形状,它们的参数是新的形状,每个维度的大小可以指定为具体的数值或者 -1
。-1
表示这个维度的大小由张量的总元素数量自动推断。
new_shape
:这是一个 tuple 或者一个 list,定义了新的形状。每个元素代表对应维度的大小。-1
:特殊值,表示该维度的大小由其他维度自动推断。假设有一个张量 tensor
,形状为 [batch_size, seq_len, num_labels]
。
import torch
tensor = torch.randn(4, 3, 5) # 示例张量,形状为 (4, 3, 5)
要将其形状调整为 [12, 5]
,可以使用 view
或 reshape
。
# 使用 view
reshaped_tensor_view = tensor.view(-1, 5)
print("View tensor shape:", reshaped_tensor_view.shape) # 输出: torch.Size([12, 5])
# 使用 reshape
reshaped_tensor_reshape = tensor.reshape(-1, 5)
print("Reshape tensor shape:", reshaped_tensor_reshape.shape) # 输出: torch.Size([12, 5])
view
和 reshape
在具体应用中的参数解释在序列标记分类任务中,我们通常需要将 logits 和标签调整为适合计算损失的形状。
假设 logits 的形状为 [batch_size, seq_len, num_labels]
,我们希望将其调整为 [batch_size * seq_len, num_labels]
,以便与标签 [batch_size * seq_len]
对应。
以下是使用 view
和 reshape
的示例:
import torch import torch.nn as nn from transformers import BertTokenizer, BertForTokenClassification # 初始化模型和tokenizer model_name = 'bert-base-uncased' tokenizer = BertTokenizer.from_pretrained(model_name) model = BertForTokenClassification.from_pretrained(model_name, num_labels=5) # 假设有5个分类 # 假设输入文本 text = "I love natural language processing." inputs = tokenizer(text, return_tensors="pt") # 获取模型输出 outputs = model(**inputs) seq_logits = outputs.logits # 假设标签映射 tags_to_idx = {'O': 0, 'B-PER': 1, 'I-PER': 2, 'B-LOC': 3, 'I-LOC': 4} tags = torch.tensor([[0, 0, 0, 0, 1, 2, 3, 4]]) # 示例标签,形状为 (batch_size, seq_len) # 使用 reshape 调整形状 pred = seq_logits.reshape([-1, len(tags_to_idx)]) label = tags.reshape([-1]) ignore_index = tags_to_idx["O"] # 计算损失 criterion = nn.CrossEntropyLoss(ignore_index=ignore_index) loss = criterion(pred, label) print("Loss with reshape:", loss.item()) # 使用 view 调整形状 pred_view = seq_logits.view(-1, len(tags_to_idx)) label_view = tags.view(-1) # 计算损失 loss_view = criterion(pred_view, label_view) print("Loss with view:", loss_view.item())
seq_logits.reshape([-1, len(tags_to_idx)])
和 seq_logits.view(-1, len(tags_to_idx)])
:
-1
:表示这个维度的大小由其他维度自动推断。这里是将 [batch_size, seq_len, num_labels]
调整为 [batch_size * seq_len, num_labels]
。len(tags_to_idx)
:表示 num_labels
,即分类的数量。假设有一个四维张量,形状为 [2, 2, 3, 4]
,我们希望将其调整为 [4, 3, 4]
:
import torch
tensor = torch.randn(2, 2, 3, 4)
print("Original shape:", tensor.shape) # 输出: torch.Size([2, 2, 3, 4])
# 使用 view 调整形状
view_tensor = tensor.view(4, 3, 4)
print("View tensor shape:", view_tensor.shape) # 输出: torch.Size([4, 3, 4])
# 使用 reshape 调整形状
reshape_tensor = tensor.reshape(4, 3, 4)
print("Reshape tensor shape:", reshape_tensor.shape) # 输出: torch.Size([4, 3, 4])
import torch tensor = torch.tensor([[1, 2, 3], [4, 5, 6]]) transpose_tensor = tensor.t() # 转置张量 print("Transpose shape:", transpose_tensor.shape) # 输出: torch.Size([3, 2]) # 使用 view(会报错,因为内存不连续) try: view_tensor = transpose_tensor.view(-1) except RuntimeError as e: print("Error using view:", e) # 使用 contiguous 方法确保内存连续 contiguous_tensor = transpose_tensor.contiguous() view_tensor = contiguous_tensor.view(-1) print("Contiguous view tensor:", view_tensor) print("Contiguous view tensor shape:", view_tensor.shape) # 输出: torch.Size([6]) # 使用 reshape reshape_tensor = transpose_tensor.reshape(-1) print("Reshape tensor:", reshape_tensor) print("Reshape tensor shape:", reshape_tensor.shape) # 输出: torch.Size([6])
view
和 reshape
参数:
-1
表示该维度的大小由其他维度自动推断。view
的限制:要求输入张量是连续的。reshape
的灵活性:可以处理非连续内存的张量。通过这些详细的例子和解释,你可以更好地理解如何使用 view
和 reshape
来调整张量的形状。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。