欢迎访问 生活随笔!

生活随笔

当前位置: 首页 > 人文社科 > 生活经验 >内容正文

生活经验

MXNET源码中NDArray数据的获取和打印

发布时间:2023/11/27 生活经验 45 豆豆
生活随笔 收集整理的这篇文章主要介绍了 MXNET源码中NDArray数据的获取和打印 小编觉得挺不错的,现在分享给大家,帮大家做个参考.

虽然本人也很想写一个系列的分析文章,奈何水平不足,零零碎碎学到一点就写一点吧

本人是想学习MXNET的源码,首先想要添加一些打印,debug一下,第一个问题是如何在C++源码中打印出NDArray结构的值,

今天尝试如下,可以打印出来,

文件 incubator-mxnet/src/c_api/c_api.cc 中,函数MXNDArraySlice修改如下:

int MXNDArraySlice(NDArrayHandle handle,mx_uint slice_begin,mx_uint slice_end,NDArrayHandle *out) {NDArray *ptr = new NDArray();API_BEGIN();std::cout << "slice_begin:" << slice_begin << std::endl;std::cout << "slice_end:" << slice_end << std::endl;*ptr = static_cast<NDArray*>(handle)->SliceWithRecord(slice_begin, slice_end);*out = ptr;float *p = (float *)ptr->data().dptr_;std::cout << "p[0] = " << p[0] << std::endl;std::cout << "p[1] = " << p[1] << std::endl;API_END_HANDLE_ERROR(delete ptr);
}

Python测试代码如下

from mxnet import autograd, nd
import mxnet
print(mxnet.__version__)x = nd.arange(2, 7).reshape((5, 1))
print(x[2:4].asnumpy())

打印结果为:

#python3 mxnet_test.py 
1.5.0
slice_begin:2
slice_end:4
p[0] = 4
p[1] = 5
[[4.][5.]]

Great,可以验证出来实际的数值就是在NDArray的data()函数的dptr_指针中,

 

____________________________________________

但是在操作时有时会无法得到预期的结果,如同文件中函数MXNDArrayGetGrad,如果按照上面的代码进行打印的话,会发现打印出的值全为0,这时需要在代码中添加一行WaitToRead,如下可正常打印

int MXNDArrayGetGrad(NDArrayHandle handle, NDArrayHandle *out) {API_BEGIN();NDArray *arr = static_cast<NDArray*>(handle);NDArray ret = arr->grad();if (ret.is_none()) {*out = NULL;} else {std::cout << "ret.shape().ndim() = " << ret.shape().ndim() << std::endl;std::cout << "ret.shape()[0] = " << ret.shape()[0] << std::endl;std::cout << "ret.shape()[1] = " << ret.shape()[1] << std::endl;*out = new NDArray(ret);ret.WaitToRead();float *p_float = (float *)(ret.data().dptr_);for (int i = 0; i < ret.shape()[0] * ret.shape()[1]; i++){std::cout << "p_float[" << i << "] = " << p_float[i] << std::endl;}}API_END();
}

Python 测试代码为:

from mxnet import autograd, nd
import mxnet
print(mxnet.__version__)x = nd.arange(2, 7).reshape((5, 1))x.attach_grad()with autograd.record():y = 2 * nd.dot(x.T, x)y.backward()# assert (x.grad - 4 * x).norm().asscalar() == 0
print(x.grad)

输出为:

# python3 autograd_test.py 
1.5.0
ret.shape().ndim() = 2
ret.shape()[0] = 5
ret.shape()[1] = 1
p_float[0] = 8
p_float[1] = 12
p_float[2] = 16
p_float[3] = 20
p_float[4] = 24[[ 8.][12.][16.][20.][24.]]
<NDArray 5x1 @cpu(0)>

 

总结

以上是生活随笔为你收集整理的MXNET源码中NDArray数据的获取和打印的全部内容,希望文章能够帮你解决所遇到的问题。

如果觉得生活随笔网站内容还不错,欢迎将生活随笔推荐给好友。