MXNet中x.grad源码追溯
生活随笔
收集整理的这篇文章主要介绍了
MXNet中x.grad源码追溯
小编觉得挺不错的,现在分享给大家,帮大家做个参考.
Python测试代码如https://zh.gluon.ai/chapter_prerequisite/autograd.html
本文追溯x.grad这一行代码的调用
grad调用的是函数MXNDArrayGetGrad,/usr/local/lib/python3.7/dist-packages/mxnet-1.5.0-py3.7.egg/mxnet/ndarray/ndarray.py
MXNDArrayGetGrad的源码依旧是在文件src/c_api/c_api.cc中,
NDArray ret = arr->grad();
ret就是获取到的梯度
这里grad的源码文件为src/ndarray/ndarray.cc,
Imperative::AGInfo& info = Imperative::AGInfo::Get(entry_.node);return info.out_grads[0];
这里Imperative::AGInfo::Get的源码文件为 include/mxnet/imperative.h
return dmlc::get<AGInfo>(node->info);
这里get的源码文件为3rdparty/dmlc-core/include/dmlc/any.h
return *any::TypeInfo<T>::get_ptr(&(src.data_));
这个get_ptr调用的是同文件中的如下代码:
template<typename T>
class any::TypeOnHeap {public:inline static T* get_ptr(any::Data* data) {return static_cast<T*>(data->pheap);}
回到上面的代码,那个entry_是NDArrary类的一个对象:
/*! \brief node entry for autograd */nnvm::NodeEntry entry_;
NodeEntry 源码文件为include/nnvm/node.h,
#大体来讲,梯度就是arr->entry_.node->info.data_.pheap;
总结
以上是生活随笔为你收集整理的MXNet中x.grad源码追溯的全部内容,希望文章能够帮你解决所遇到的问题。
- 上一篇: MXNET源码中TShape值的获取和打
- 下一篇: mxnet 中的 DepthwiseCo