pytorch中hook有三种函数,分别是register_hook, register_backward_hook, register_forward_hook,第一个是针对Variable,后面两个是针对modules的
- register_hook函数 针对中间层的Variable的梯度进行处理,比如修改和打印
1 | # 打印中间层Variable的梯度 |
1 | # 修改中间层Variable的梯度 |
- register_backward_hook函数 该函数是注册在module上的,而不是在Variable上,同时该module必须是一个function,而不是有container的函数,里面不能包含多个module。
具体的形式是function(module, grad_in, grad_out),该函数可以返回一个新的grad_in用于替代原始的grad_in。而不是直接修改grad_in
1 | def bh(m,gi,go): |
- register_forward_hook 函数
该函数是先进行正常的Forward方法, 然后对于Forward以后的结果,进行自定义的处理。 注意该函数与register_backward_hook不同,他不能改变output。而register_backward_hook是可以用一个新的grad_input来替代grad。
PS: 注意在dataParalle中的load函数的话,一定要在model parallel之后,进行load,因为通过parallel以后,函数的keys会在原来的key的基础上加module.