torch.distributions.normal,torch.distributions.normal.log_prob,torch.distributions.normal.rsample
pytorch的torch.distributions中可以定义正态分布
如下:
sample()
sample()就是直接在定义的正太分布(均值为mean,标准差std是1)上采样:
c=normal.sample() print("c:",c)输出:
c: tensor([-1.3362, 3.1730])rsample()
rsample()不是在定义的正太分布上采样,而是先对标准正太分布N(0,1)N(0,1)N(0,1)进行采样,然后输出:mean+std×采样值mean+std\times采样值mean+std×采样值
a=normal.rsample()输出:
a: tensor([ 0.0530, 2.8396])log_prob(value)
log_prob(value)是计算value在定义的正态分布(mean,1)中对应的概率的对数,正太分布概率密度函数是f(x)=12πσe−(x−μ)22σ2f(x)=\frac{1}{\sqrt{2\pi}\sigma}e^{-\frac{(x-\mu)^2}{2\sigma^2}}f(x)=2πσ1e−2σ2(x−μ)2,对其取对数可得log(f(x))=−(x−μ)22σ2−log(σ)−log(2π)log(f(x))=-\frac{(x-\mu)^2}{2\sigma^2}-log(\sigma)-log(\sqrt{2\pi})log(f(x))=−2σ2(x−μ)2−log(σ)−log(2π)
这里我们通过对数概率还原其对应的真实概率:
输出:
c log_prob: tensor([ 0.1634, 0.2005])总结
以上是生活随笔为你收集整理的torch.distributions.normal,torch.distributions.normal.log_prob,torch.distributions.normal.rsample的全部内容,希望文章能够帮你解决所遇到的问题。
- 上一篇: 顾客点餐系统-----后端代码编写(基于
- 下一篇: JAVA面向对象(2)