GAN初步-生成1010格式规律的向量

构建和训练GAN的推荐步骤:

(1)从真实数据集预 览数据;

(2)测试鉴别器至少具备从随机噪声中区分 真实数据的能力;

(3)测试未经训练的生成器能否创 建正确格式的数据;

(4)可视化观察损失值,了解训 练进展。

#真实的数据源
import torch
import torch.nn as nn
import pandas
import matplotlib.pyplot as plt
import random
import numpy
def synthetic_data():
    real_data = torch.FloatTensor([
        random.uniform(0.8,1.0),
        random.uniform(0.0,0.1),
        random.uniform(0.8,0.9),
        random.uniform(0.0,0.1)
    ])
    return real_data
#Generator
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        
        
        self.model = nn.Sequential(
        nn.Linear(1,3),
        nn.Sigmoid(),
        nn.Linear(3,4),
        nn.Sigmoid())
        
#         self.loss_function = nn.MSELoss()
        
        self.optimiser = torch.optim.SGD(self.parameters(),lr=0.01)
        self.counter = 0
        self.progress = []
        pass
    def forward(self,inputs):
        return self.model(inputs)
    
    def train(self,D,inputs,targets):
        g_outputs = self.forward(inputs)
        
        d_output = D.forward(g_outputs)
        
        loss = D.loss_function(d_output,targets)
        
        self.counter += 1
        
        if (self.counter % 10 == 0):
            self.progress.append(loss.item())
            pass
        if (self.counter % 10000 == 0):
#             print('countetr = ',self.counter)
            pass
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()
        
        pass
    def plot_progress(self):
        df = pandas.DataFrame(self.progress,columns=['loss'])
        df.plot(ylim=(0,1.0),figsize=(16,8),alpha=0.1,marker='.',
                grid=True,yticks=(0,0.25,0.5))
        pass
    
#descriminator
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        
        
        self.model = nn.Sequential(
        nn.Linear(4,3),
        nn.Sigmoid(),
        nn.Linear(3,1),
        nn.Sigmoid())
        
        self.loss_function = nn.MSELoss()
        
        self.optimiser = torch.optim.SGD(self.parameters(),lr=0.01)
        self.counter = 0
        self.progress = []
        pass
    def forward(self,inputs):
        return self.model(inputs)
    
    def train(self,inputs,targets):
        outputs = self.forward(inputs)
        
        loss = self.loss_function(outputs,targets)
        
        self.counter += 1
        
        if (self.counter % 10 == 0):
            self.progress.append(loss.item())
            pass
        if (self.counter % 10000 == 0):
            print('countetr = ',self.counter)
            pass
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()
        
        pass
    def plot_progress(self):
        df = pandas.DataFrame(self.progress,columns=['loss'])
        df.plot(ylim=(0,1.0),figsize=(16,8),alpha=0.1,marker='.',
                grid=True,yticks=(0,0.25,0.5))
        pass

# 记录训练过程
image_list=[]

for i in range(10000):
    D.train(synthetic_data(),torch.FloatTensor([1.0]))
    
    D.train(G.forward(torch.FloatTensor([0.5])).detach(),torch.FloatTensor([0.0]))
    G.train(D, torch.FloatTensor([0.5]), torch.FloatTensor([1.0]))
    if i%1000 == 0:
        image_list.append(G.forward(torch.FloatTensor([0.5])))
    
#     G.train(D,torch.FloatTensor([0.5]),torch.FloatTensor([1.0]))

    pass
image_list_ = []
for i in  range(len(image_list)):
    image_list_.append(image_list[i].detach().numpy())
    
plt.imshow(numpy.array(image_list_).T,interpolation='none',cmap='Blues')

原文地址:http://www.cnblogs.com/afengblog/p/16794588.html

1. 本站所有资源来源于用户上传和网络,如有侵权请邮件联系站长! 2. 分享目的仅供大家学习和交流,请务用于商业用途! 3. 如果你也有好源码或者教程,可以到用户中心发布,分享有积分奖励和额外收入! 4. 本站提供的源码、模板、插件等等其他资源,都不包含技术服务请大家谅解! 5. 如有链接无法下载、失效或广告,请联系管理员处理! 6. 本站资源售价只是赞助,收取费用仅维持本站的日常运营所需! 7. 如遇到加密压缩包,默认解压密码为"gltf",如遇到无法解压的请联系管理员! 8. 因为资源和程序源码均为可复制品,所以不支持任何理由的退款兑现,请斟酌后支付下载 声明:如果标题没有注明"已测试"或者"测试可用"等字样的资源源码均未经过站长测试.特别注意没有标注的源码不保证任何可用性