欢迎访问 生活随笔!

生活随笔

当前位置: 首页 > 编程资源 > 编程问答 >内容正文

编程问答

torch.distributions.normal,torch.distributions.normal.log_prob,torch.distributions.normal.rsample

发布时间:2023/12/8 编程问答 43 豆豆
生活随笔 收集整理的这篇文章主要介绍了 torch.distributions.normal,torch.distributions.normal.log_prob,torch.distributions.normal.rsample 小编觉得挺不错的,现在分享给大家,帮大家做个参考.

pytorch的torch.distributions中可以定义正态分布
如下:

import torch from torch.distributions import Normal mean=torch.Tensor([0,2]) normal=Normal(mean,1)

sample()

sample()就是直接在定义的正太分布(均值为mean,标准差std是1)上采样:

c=normal.sample() print("c:",c)

输出:

c: tensor([-1.3362, 3.1730])

rsample()

rsample()不是在定义的正太分布上采样,而是先对标准正太分布N(0,1)N(0,1)N(0,1)进行采样,然后输出:mean+std×采样值mean+std\times采样值mean+std×

a=normal.rsample()

输出:

a: tensor([ 0.0530, 2.8396])

log_prob(value)

log_prob(value)是计算value在定义的正态分布(mean,1)中对应的概率的对数,正太分布概率密度函数是f(x)=12πσe−(x−μ)22σ2f(x)=\frac{1}{\sqrt{2\pi}\sigma}e^{-\frac{(x-\mu)^2}{2\sigma^2}}f(x)=2πσ1e2σ2(xμ)2,对其取对数可得log(f(x))=−(x−μ)22σ2−log(σ)−log(2π)log(f(x))=-\frac{(x-\mu)^2}{2\sigma^2}-log(\sigma)-log(\sqrt{2\pi})log(f(x))=2σ2(xμ)2log(σ)log(2π)
这里我们通过对数概率还原其对应的真实概率:

print("c log_prob:",normal.log_prob(c).exp())

输出:

c log_prob: tensor([ 0.1634, 0.2005])

总结

以上是生活随笔为你收集整理的torch.distributions.normal,torch.distributions.normal.log_prob,torch.distributions.normal.rsample的全部内容,希望文章能够帮你解决所遇到的问题。

如果觉得生活随笔网站内容还不错,欢迎将生活随笔推荐给好友。