您好,欢迎来到步遥情感网。
搜索
您的当前位置:首页WGAN-gp模型——pytorch实现

WGAN-gp模型——pytorch实现

来源:步遥情感网

论文传送门:

WGAN存在的问题:在WGAN中,为使得判别器D(x)满足Lipschitz连续条件,从而对网络参数进行了[-c,c]的区间,使得网络参数分布极端,参数均接近于-c或c。

WGAN-gp的目的:解决WGAN参数分布极端的问题。 

WGAN-gp的方法:在判别器D的loss中增加梯度惩罚项,代替WGAN中对判别器D的参数区间,同样能保证D(x)满足Lipschitz连续条件。(证明过程见论文补充材料)

红框部分:与WGAN不同之处,即判别器D的loss增加梯度惩罚项和优化器选择Adam

梯度惩罚项的计算实现见代码70-87行,判别器D的损失函数修改见代码156行。

import os
import torch
from torch.utils.data import DataLoader

import torch.nn as nn

from torchvision import datasets, transforms
from torchvision.utils import save_image

from tqdm import tqdm


class Discriminator(nn.Module):  # 定义判别器(WS-divergence)
    def __init__(self, img_shape=(1, 28, 28)):  # 初始化方法
        super(Discriminator, self).__init__()  # 继承初始化方法
        self.img_shape = img_shape  # 图片形状

        self.linear1 = nn.Linear(self.img_shape[0] * self.img_shape[1] * self.img_shape[2], 512)  # linear映射
        self.linear2 = nn.Linear(512, 256)  # linear映射
        self.linear3 = nn.Linear(256, 1)  # linear映射
        self.leakyrelu = nn.LeakyReLU(0.2, inplace=True)  # leakyrelu激活函数

    def forward(self, x):  # 前传函数

因篇幅问题不能全部显示,请点此查看更多更全内容

Copyright © 2019- obuygou.com 版权所有 赣ICP备2024042798号-5

违法及侵权请联系:TEL:199 18 7713 E-MAIL:2724546146@qq.com

本站由北京市万商天勤律师事务所王兴未律师提供法律服务