您現在的位置是:首頁 > 音樂首頁音樂

FastAI實踐:貓狗品種分類

由 鴨鴿營—浩 發表于 音樂2021-06-26
簡介fit_one_cycle(4, slice(lr)) # 進行訓練,4為epoch數,slice(lr)是學習率範圍 modelPath= learn

犬類包含貓嗎

1、資料集介紹  The Oxford-IIIT Pet Dataset是一個寵物影象資料集,包含37種寵物,其中有犬類25類,貓類12類,每種寵物200張左右寵物圖片,並同時包含寵物輪廓標註資訊。  地址:http://www。robots。ox。ac。uk/~vgg/data/pets/2、FastAI使用神經網路訓練的步驟: 1。建立神經網路模型,預先訓練好的比較好,表達能力更強 2。查詢合適的學習率進行學習 3。進行學習(該步驟其實為遷移學習:使用pretrained模型加上自己的分類,然後進行學習) 4。保留當前狀態 5。unfreeze:解凍模型所有權重,準備 6。查詢合適的學習率。 7。再次進行學習,(進行fine tune:微調,在原有模型上加上自己的最後fully connection layer,然後進行所有權重訓練。)3、完整程式碼# import * 對於程式設計人員來說會有點難以接受,主要目的是方便實驗from fastai import *from fastai。vision import *# 不怕不會,怕不知道怎麼學,在進行一切之前,首先要學會如下:#doc(get_transforms()) # 獲得該方法的文件,學會使用該方法# 使用上面的方法結合訪問官方文件介紹網站來進行學習,官方文件url文中最開始有提供。#FastAi程式碼要在if __name__ == ‘__main__’:下面執行if __name__ == ‘__main__’: # 1。示例中的資料為PETS,是一個用來進行貓狗品種分類的資料集。 path = untar_data(URLs。PETS) # 該方法首先下載資料集到目錄中,再進行解壓,返回解壓路徑 print(path) # 列印path,其為下載的目錄 print(path。ls()) # 看看解壓路徑中的檔案 paths_img = path / ‘images’ # 該路徑為影象檔案路徑,也就是真正的建模資料集 images = get_image_files(paths_img) # 該方法為獲取路徑中的所有檔案 print(images[:5]) # 2。載入路徑中的資料集 pat = r‘/([^/]+)_\d+。jpg$’ # 用於匹配圖片名字的表示式 tfms = get_transforms() # 這個用來進行資料增強 data = ImageDataBunch。from_name_re(paths_img, images, pat=pat, ds_tfms=tfms, size=224,bs=16) # 該方法使用資料載入物件ImageDataBunch的from_name_re(使用正則匹配檔名稱的方法),進行資料讀取。 data。show_batch(rows=3, figsize=(6, 8)) # 看看資料的樣子 print(data。c) # 列印資料集中的分類數量 data。normalize(imagenet_stats) # normalize影象,消除因為啟用函式而可能發生的梯度爆炸和梯度消失情況,加快梯度下降,很有用,詳細內容請看文章開始提供的pytorch教學連結裡面關於BN(batch_normal)的影片 learn = create_cnn(data, models。densenet161, metrics=accuracy) # 該方法建立了一個resnet34網路結構的CNN網路,在訓練中列印metrics對應的方法來展示精度,注意:metrics不影響訓練精度。 learn。lr_find() # 查詢合適的學習率 # 學習率應該選擇loss向下降程度最大點所對應的學習率 lr = 1e-3 learn。fit_one_cycle(4, slice(lr)) # 進行訓練,4為epoch數,slice(lr)是學習率範圍 modelPath= learn。save(‘stage_1。h5’, return_path=True) # 儲存第一階段遷移學習的引數,return_path=True是列印儲存目錄 print(modelPath) learn。unfreeze() # 該方法為解凍所有權重 learn。lr_find() #learn。recorder。plot() learn。fit_one_cycle(15, slice(lr)) # 開始fine tune,該lr為上面一步影象中下降最快的點對應的lr learn。save(‘stage_2。h5’) # 儲存第二階段fine tune的引數 # 4。檢視分類錯的資料都是哪些 interp = ClassificationInterpretation。from_learner(learn) losses, idxs = interp。top_losses() # 查詢分類錯誤的影象的索引 interp。plot_top_losses(9, figsize=(15, 11)) # 打印出分類錯誤的影象 # 5。分類一個試試 img = open_image(images[0]) img。show() label, idx, probability = learn。predict(img) print(‘預測的分類為{}, 機率為{:。4f}’。format(label, probability[idx]))參考文章: