本文深入探討了pytorch多標簽圖像分類任務中常見的批次大小不一致問題。通過分析自定義模型中卷積層輸出尺寸與全連接層輸入尺寸不匹配的根本原因,詳細闡述了如何精確計算張量形變后的維度,并提供修正后的PyTorch模型代碼。教程強調了張量尺寸追蹤的重要性,以及如何正確使用view操作和nn.Linear層,以確保模型輸入輸出批次的一致性,從而解決訓練過程中ValueError報錯。
1. 引言:多標簽分類與模型架構挑戰
在圖像識別任務中,多標簽分類(multi-label classification)是一種常見的場景,即一張圖像可能同時包含多個獨立的類別標簽(例如,一張藝術品圖像可能同時被標記為“印象派”、“風景畫”和“莫奈”)。為了實現這類任務,通常會采用多頭(multi-head)模型架構,即在共享的特征提取器之后,為每個分類任務設置獨立的分類頭。
在PyTorch中構建自定義模型時,尤其是在卷積層和全連接層之間進行張量形變(flattening)時,很容易出現張量尺寸計算錯誤,導致模型輸入批次與輸出批次不一致的問題。這會直接導致訓練循環中計算損失時出現ValueError: Expected input batch_size (…) to match target batch_size (…)的錯誤。
2. 問題描述與初步嘗試
本教程將以一個具體的案例來闡述這一問題。用戶嘗試為一個Wikiart數據集構建一個多標簽分類模型,需要同時預測藝術家(artist)、風格(style)和流派(genre)三個標簽。
最初,用戶嘗試基于Hugging Face的ResNetForImageClassification修改其分類頭,以適應多標簽任務。然而,直接修改model.classifier屬性并不能讓模型在forward方法中自動包含新增的多個分類頭,torchinfo的摘要也證實了這一點,模型結構仍然是單分類輸出。
# 初始嘗試:修改預訓練模型的分類頭 (不適用多頭輸出) # model2.classifier_artist = torch.nn.Sequential(...) # model2.classifier_style = torch.nn.Sequential(...) # model2.classifier_genre = torch.nn.Sequential(...)
由于預訓練模型修改的復雜性,用戶轉向了構建一個自定義的PyTorch模型WikiartModel。該模型包含共享的卷積層用于特征提取,然后分叉出三個獨立的線性分類頭。
import torch import torch.nn as nn import torch.nn.functional as F class WikiartModel(nn.Module): def __init__(self, num_artists, num_genres, num_styles): super(WikiartModel, self).__init__() # 共享卷積層 self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding =1) self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1) self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1) self.pool = nn.MaxPool2d(2, 2) # 最大池化層 # 藝術家分類分支 self.fc_artist1 = nn.Linear(256 * 16 * 16, 512) # 錯誤:輸入特征維度計算有誤 self.fc_artist2 = nn.Linear(512, num_artists) # 流派分類分支 self.fc_genre1 = nn.Linear(256 * 16 * 16, 512) # 錯誤:輸入特征維度計算有誤 self.fc_genre2 = nn.Linear(512, num_genres) # 風格分類分支 self.fc_style1 = nn.Linear(256 * 16 * 16, 512) # 錯誤:輸入特征維度計算有誤 self.fc_style2 = nn.Linear(512, num_styles) def forward(self, x): # 共享卷積層處理 x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = self.pool(F.relu(self.conv3(x))) # 張量形變:將多維特征圖展平為一維向量 x = x.view(-1, 256 * 16 * 16) # 錯誤:展平后的維度計算有誤,且-1可能導致意外行為 # 藝術家分類分支 artists_out = F.relu(self.fc_artist1(x)) artists_out = self.fc_artist2(artists_out) # 流派分類分支 genre_out = F.relu(self.fc_genre1(x)) genre_out = self.fc_genre2(genre_out) # 風格分類分支 style_out = F.relu(self.fc_style1(x)) style_out = self.fc_style2(style_out) return artists_out, genre_out, style_out # 設置類別數量 num_artists = 129 num_genres = 11 num_styles = 27
當輸入數據批次大小為32(即輸入張量形狀為[32, 3, 224, 224])時,torchinfo顯示的模型輸出批次大小為98,而不是預期的32,這導致了訓練循環中損失計算的ValueError。
3. 根本原因分析:張量尺寸計算錯誤
問題的核心在于卷積層輸出的特征圖尺寸與全連接層nn.Linear的in_features參數不匹配,以及forward方法中x.view操作的錯誤。
讓我們逐步分析輸入張量[32, 3, 224, 224]經過卷積和池化層后的尺寸變化:
- 輸入: [Batch_Size, Channels, Height, Width] -> [32, 3, 224, 224]
- self.conv1: nn.Conv2d(3, 64, kernel_size=3, padding=1)
- 輸出尺寸公式:H_out = (H_in + 2*padding – kernel_size)/stride + 1
- 224 + 2*1 – 3 / 1 + 1 = 224
- 輸出: [32, 64, 224, 224]
- self.pool: nn.MaxPool2d(2, 2) (kernel_size=2, stride=2)
- 輸出尺寸:H_out = H_in / stride
- 224 / 2 = 112
- 輸出: [32, 64, 112, 112]
- self.conv2: nn.Conv2d(64, 128, kernel_size=3, padding=1)
- 輸出: [32, 128, 112, 112]
- self.pool: nn.MaxPool2d(2, 2)
- 輸出: [32, 128, 56, 56]
- self.conv3: nn.Conv2d(128, 256, kernel_size=3, padding=1)
- 輸出: [32, 256, 56, 56]
- self.pool: nn.MaxPool2d(2, 2)
- 最終特征圖輸出: [32, 256, 28, 28]
因此,在進入全連接層之前,特征圖的尺寸應該是 [Batch_Size, 256, 28, 28]。 當將其展平為一維向量時,除了批次大小之外的維度都應相乘:256 * 28 * 28 = 200704。
然而,原始代碼中nn.Linear的in_features參數被錯誤地設置為256 * 16 * 16,這顯然與實際的256 * 28 * 28不符。 同時,x.view(-1, 256 * 16 * 16)中的-1表示PyTorch會自動推斷該維度,但由于其后指定的維度256 * 16 * 16與實際的展平尺寸不匹配,導致PyTorch在嘗試展平時,不得不調整批次大小以滿足總元素數量,從而產生了98這個錯誤的批次大小。
4. 解決方案:精確計算與正確形變
要解決此問題,需要進行兩處關鍵修改:
- 修正nn.Linear的in_features參數: 將其更改為卷積層最終輸出特征圖的展平尺寸,即 256 * 28 * 28。
- 修正x.view操作: 確保展平操作正確,并且批次大小能夠正確傳遞。推薦使用x.view(x.size(0), -1),其中x.size(0)明確指定了當前張量的批次大小,而-1則讓PyTorch自動計算剩余維度的乘積。
以下是修正后的WikiartModel代碼:
import torch import torch.nn as nn import torch.nn.functional as F class WikiartModel(nn.Module): def __init__(self, num_artists, num_genres, num_styles): super(WikiartModel, self).__init__() # 共享卷積層 self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding =1) self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1) self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1) self.pool = nn.MaxPool2d(2, 2) # 計算卷積層最終輸出的特征圖尺寸(例如,對于224x224輸入,經過三次conv+pool后為28x28) # 建議在模型初始化時或通過一個小的dummy_input計算得出 # 確保這里的尺寸與實際計算結果一致 self.feature_map_size = 28 # 經過三次池化后,224 -> 112 -> 56 -> 28 self.flattened_features = 256 * self.feature_map_size * self.feature_map_size # 256 * 28 * 28 = 200704 # 藝術家分類分支 self.fc_artist1 = nn.Linear(self.flattened_features, 512) # 修正此處輸入特征維度 self.fc_artist2 = nn.Linear(512, num_artists) # 流派分類分支 self.fc_genre1 = nn.Linear(self.flattened_features, 512) # 修正此處輸入特征維度 self.fc_genre2 = nn.Linear(512, num_genres) # 風格分類分支 self.fc_style1 = nn.Linear(self.flattened_features, 512) # 修正此處輸入特征維度 self.fc_style2 = nn.Linear(512, num_styles) def forward(self, x): # 共享卷積層處理 x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = self.pool(F.relu(self.conv3(x))) # 張量形變:展平張量,保留批次大小 # x.size(0) 獲取當前批次大小,-1讓PyTorch自動計算剩余維度 x = x.view(x.size(0), -1) # 藝術家分類分支 artists_out = F.relu(self.fc_artist1(x)) artists_out = self.fc_artist2(artists_out) # 流派分類分支 genre_out = F.relu(self.fc_genre1(x)) genre_out = self.fc_genre2(genre_out) # 風格分類分支 style_out = F.relu(self.fc_style1(x)) style_out = self.fc_style2(style_out) return artists_out, genre_out, style_out # 設置類別數量 num_artists = 129 num_genres = 11 num_styles = 27 # 實例化模型并進行測試 (示例) model = WikiartModel(num_artists, num_genres, num_styles) dummy_input = torch.randn(32, 3, 224, 224) # 批次大小為32的模擬輸入 artist_output, genre_output, style_output = model(dummy_input) print(f"Artist Output Shape: {artist_output.shape}") # 預期: [32, 129] print(f"Genre Output Shape: {genre_output.shape}") # 預期: [32, 11] print(f"Style Output Shape: {style_output.shape}") # 預期: [32, 27] # 此時,torchinfo的輸出也將顯示正確的批次大小 # from torchinfo import summary # summary(model, input_size=(32, 3, 224, 224))
5. 注意事項與最佳實踐
- 張量尺寸追蹤的重要性: 在構建自定義神經網絡時,務必在每個層之后打?。ɑ蚴褂谜{試工具如torchinfo)張量的形狀(tensor.shape或tensor.size()),以確保數據流經網絡時尺寸符合預期。這是解決這類問題的最有效方法。
- x.view(x.size(0), -1)的優勢: 使用x.size(0)明確指定批次大小,而不是依賴-1來推斷所有維度,可以避免在其他維度計算錯誤時導致批次大小被錯誤推斷。這使得代碼更健壯,不易出錯。
- 動態計算展平尺寸: 對于更復雜的模型或可變輸入尺寸,可以在forward方法中動態計算展平尺寸。例如,在展平之前,可以使用num_features = x.numel() // x.size(0)來獲取每個樣本的特征數量,然后將其用于nn.Linear層的初始化(如果模型結構允許)。但通常,對于固定輸入尺寸的模型,預先計算好nn.Linear的in_features是更常見的做法。
- 預訓練模型的使用: 如果希望利用預訓練模型(如ResNet)的強大特征提取能力,并進行多標簽分類,正確的做法是加載預訓練模型,凍結其特征提取層,然后替換或在其之上添加自定義的多個分類頭。這通常涉及到直接修改模型的classifier或fc屬性,并確保forward方法能夠正確地將特征傳遞給這些新的分類頭。對于像Hugging Face的ResNetForImageClassification,可能需要更深入地了解其內部結構或繼承并重寫其forward方法以實現多頭輸出。
6. 總結
在PyTorch中構建自定義神經網絡時,管理張量尺寸是至關重要的一環。批次大小不一致的問題通常源于卷積層輸出與全連接層輸入之間的尺寸不匹配,以及view操作的誤用。通過精確計算卷積層輸出的特征圖尺寸,并采用x.view(x.size(0), -1)這種健壯的展平方式,可以有效解決這類問題,確保數據在網絡中順暢流動,并避免訓練過程中的ValueError。養成良好的張量尺寸追蹤習慣,將大大提高模型開發的效率和準確性。