11.15-16
11.15-16
終於找出困擾我一天的bug了,哭哭哦
陣列與集合陣列!
執行程式碼時發現:
preds = np.argmax(probs, axis=1)
probs好像是一個二維陣列
竟然返回的是列最大值的索引!
我非常驚訝,再三確認了argmax的函式。
沒錯啊 argmax:返回陣列最大值的索引,如果陣列是二維陣列的話:axis=0返回的是每列最大的索引,axis=1返回的是每行最大的索引。
錯誤的原因在於probs不是陣列!是一個集合,裡面有一個陣列!
陣列與集合陣列的區別在於:
1.陣列(np)可以呼叫shape函式,而集合陣列不能呼叫
2.使用print方法時,返回的東西是不同的。如圖:
可以看到,使用print函式時,如果是集合陣列,他會比陣列多[array(…)]這種東西!
使用argmax時自然也有不同,如果對陣列使用argmax,結果是一個躺平的陣列。如果時對集合的陣列使用argmax,結果是一個二維陣列!
發現問題後(將集合陣列命名為probs),我第一反應是在probs套一層np,將probs變為三維陣列,再用三維陣列套上一個axis=2來尋找最大值下標,但是問題來了,argmax(三維陣列)的結果是兩位陣列。
和上圖對比,結果是我們想要的結果,但是多了一層[ ]。對於三位陣列的argmax,貼一個還不錯的解釋:最後我把問題分享給了室友,室友一語中的:在probs後面加一個[0]就好了,意思就是取陣列集合中的第一個元素,該元素型別為陣列。淦!
這樣總算大功告成了。
我也想了一下,這個probs函式是呼叫nlp模型的predict函式得到的結果。我呼叫的是我們大佬的函式,為什麼他的lstm的predict就能執行得好好的,我的bert就返回的是一個list呢?可能是因為bert的引數更復雜吧。
查閱了資料,這個是tf.karas.Model庫(也是大佬程式碼的函式原型)對predict返回值的定義:
可能是因為bert的predict用的是transfrom包的。
天哪,水了這麼久,程式碼還沒跑完,先寫這麼多吧。