pytorch中的鉤子(Hook)有何作用?

時間 2021-05-11 23:22:14

1樓:OLDPAN

Hook的功能一樓說的很好了。

我這裡提一下pytorch中hook的Bug。Tensor的hook沒有問題,但是Module的hook在使用的時候是有乙個小問題需要注意的,也就是,具體可以看這裡:

pytorch中autograd以及hook函式詳解 - Oldpan的個人部落格

2樓:

正好最近也在看,就做一回文件搬運工吧。

首先明確一點,有哪些hook?

我看到的有3個:

1. torch.autograd.

Variable.register_hook (Python method, in Automatic differentiation package

2. torch.nn.Module.register_backward_hook (Python method, in torch.nn)

3. torch.nn.Module.register_forward_hook

第乙個是register_hook,是針對Variable物件的,後面的兩個:register_backward_hook和register_forward_hook是針對nn.Module這個物件的。

其次,明確一下,為什麼需要用hook

打個比方,有這麼個函式, , , 你想通過梯度下降法求最小值。在PyTorch裡面很容易實現,你只需要:

import

torch

from

torch.autograd

import

Variablex=

Variable

(torch

.randn(2

,1),requires_grad

=True)y

=x+2

z=torch

.mean

(torch

.pow(y

,2))lr

=1e-3z.

backward()x

.data

-=lr*x

.grad

.data

但問題是,如果我想要求中間變數 的梯度,系統會返回錯誤。

事實上,如果你輸入:

type(y.grad)

系統會告訴你:NoneType

這個問題在PyTorch的論壇上有人提問過,開發者說是因為當初開發時設計的是,對於中間變數,一旦它們完成了自身反傳的使命,就會被釋放掉。

因此,hook就派上用場了。簡而言之,register_hook的作用是,當反傳時,除了完成原有的反傳,額外多完成一些任務。你可以定義乙個中間變數的hook,將它的grad值列印出來,當然你也可以定義乙個全域性列表,將每次的grad值新增到裡面去。

import

torch

from

torch.autograd

import

Variable

grad_list=

defprint_grad

(grad

):grad_list.(

grad)x

=Variable

(torch

.randn(2

,1),requires_grad

=True)y

=x+2

z=torch

.mean

(torch

.pow(y

,2))lr

=1e-3y.

register_hook

(print_grad)z

.backward()x

.data

-=lr*x

.grad

.data

需要注意的是,register_hook函式接收的是乙個函式,這個函式有如下的形式:

hook(grad) -> Variable or None

也就是說,這個函式是擁有改變梯度值的威力的!

至於register_forward_hook和register_backward_hook的用法和這個大同小異。只不過物件從Variable改成了你自己定義的nn.Module。

當你訓練乙個網路,想要提取中間層的引數、或者特徵圖的時候,使用hook就能派上用場了。

pytorch中,相同的batchsize,多GPU會比單GPU快多少?雙路能是單路的兩倍嗎?

張懷文 相同的batchsize,多gpu會不會比單卡快都是個問題。我遇到過多卡時候,每個卡頻率撞牆的問題,甚至比單卡慢。就是快的時候,也不會有N張卡,提速N倍的能力。多卡並行相關程式的開銷 多卡策略的開銷都不小。 勒布朗詹姆斯哈登 不會,pytorch有很多計算都是只在第一塊卡上進行的。而且如果雙...

PyTorch 中,nn 與 nn functional 有什麼區別?

老實人 上面使用者有糖吃可好 講的已經挺好了,我再插兩句 在建圖過程中,往往有兩種層,一種含引數 有Variable,如全連線層,卷積層 Batch Normlization層等 另一種不含引數 無Variable,如Pooling層,Relu層,損失函式層等。閱讀原始碼發現 nn.裡面的是繼承自n...

pytorch 中的Dataset這個類為什麼可以呼叫 getitem ?

王小山 在DataLoder的iter中,會觸發子類Dataset中的getiterm函式讀取資料,並拼接成乙個batch返回,作為模型真正的輸入 操作符過載 如果乙個類定義了名為 getitem 的方法,x為該類的乙個例項 x i 可是為x.getitem x,i Goodbye響 我覺得題主你想...