(刘二大人)PyTorch深度学习实践-多分类问题(Minist)

2023/11/29 10:19:48

1.首先解决加载数据集缓慢以及不成功问题

去Minist官网下载四个数据集,放到你的项目文件中,最好放在MINIST/raw文件夹中,切忌不要随便解压,这里我的路径为E:\learn_pytorch\LE\MNIST\raw

 

然后去你的pytorch环境中的lib库中找到site-packages中的torchvision包,修改minist.py的文件下载路径,我这里是Anaconda的虚拟环境

 

 直接将你的文件的加载路径放上去,然后把上面的去掉

 

 我们再次进行数据下载会发现已成功

2.完整代码实现

import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader

import torch.nn.functional as F
import torch.optim as optim

#准备数据集
trans = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3801,))])#这里第一个是均值,第二个是标准差
train_datasets = datasets.MNIST(root='E:\learn_pytorch\LE',train=True,transform=trans,download=True)
test_datasets = datasets.MNIST(root='E:\learn_pytorch\LE',train=False,transform=trans,download=True)

#进行数据集的加载
batch_size = 64
train_loader = DataLoader(dataset=train_datasets,batch_size=batch_size,shuffle=True)
test_loader = DataLoader(dataset=test_datasets,batch_size=batch_size,shuffle=False)

#进行模型的构建
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(784,512)
        self.linear2 = torch.nn.Linear(512,256)
        self.linear3 = torch.nn.Linear(256,128)
        self.linear4 = torch.nn.Linear(128,64)
        self.linear5 = torch.nn.Linear(64,10)

    def forward(self,x):
        x = x.view(-1,784)#在这里先对x进行操作,我们将其转换为张量
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = F.relu(self.linear3(x))
        x = F.relu(self.linear4(x))
        return self.linear5(x)

#进行实例化
huihui = Model()

#定义损失函数和优化器
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(huihui.parameters(),lr=0.01,momentum=0.5)

#我们将一轮epoch单独拿出来作为一个函数
def train(epoch):
    running_loss = 0.0
    for batch_id,data in enumerate(train_loader):
        inputs,targets = data
        optimizer.zero_grad()

        # Forward
        outputs = huihui(inputs)
        loss = loss_fn(outputs,targets)
        loss.backward()
        optimizer.step()
        #标签从0开始
        running_loss+=loss.item()
        if batch_id%300 == 299:
            print('[%d,%5d]  loss:%.3f' %(epoch+1,batch_id+1,running_loss/300))
            running_loss = 0.0


#定义测试集
def test():
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images,labels = data
            outputs = huihui(images)
            #这个torch.max函数可以返回最大值和最大值的下标,那个predict取的是最大值下标,只需要拿它和标签对比即可
            _,predict = torch.max(outputs.data,dim=1)
            total += labels.size(0)#总共有多少个标签样本,Nx1
            correct+=(predict==labels).sum().item()#将我们预测的最有可能的下标与真实标签对比,最后将这个标量取出来
        print('Accuracy on test set: %d %%' % (100*correct/total))


#进行训练和测试
if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()

3.结果展示(由于损失了很多空间特征,这种方法精度就是在97%了)

D:\Anaconda3\envs\pytorch\python.exe E:/learn_pytorch/LE/Minist.py
[1,  300]  loss:2.259
[1,  600]  loss:1.288
[1,  900]  loss:0.505
Accuracy on test set: 88 %
[2,  300]  loss:0.359
[2,  600]  loss:0.312
[2,  900]  loss:0.266
Accuracy on test set: 92 %
[3,  300]  loss:0.216
[3,  600]  loss:0.193
[3,  900]  loss:0.173
Accuracy on test set: 95 %
[4,  300]  loss:0.153
[4,  600]  loss:0.136
[4,  900]  loss:0.128
Accuracy on test set: 96 %
[5,  300]  loss:0.114
[5,  600]  loss:0.106
[5,  900]  loss:0.099
Accuracy on test set: 96 %
[6,  300]  loss:0.089
[6,  600]  loss:0.086
[6,  900]  loss:0.078
Accuracy on test set: 97 %
[7,  300]  loss:0.070
[7,  600]  loss:0.069
[7,  900]  loss:0.069
Accuracy on test set: 97 %
[8,  300]  loss:0.055
[8,  600]  loss:0.056
[8,  900]  loss:0.055
Accuracy on test set: 97 %
[9,  300]  loss:0.044
[9,  600]  loss:0.046
[9,  900]  loss:0.046
Accuracy on test set: 97 %
[10,  300]  loss:0.040
[10,  600]  loss:0.040
[10,  900]  loss:0.036
Accuracy on test set: 97 %

Process finished with exit code 0


http://www.jnnr.cn/a/89267.html

相关文章

ICCV 2021 | Y-Net:轨迹-场景信息的真正融合

今天没有多余的解释,直接开始吧~ 1. Y-Net网络结构 Y-Net的网络结构长什么样子呢?Y-Net的网络结构就长下图这样子。看上去我好像在自言自语,其实你仔细揣摩就会发现,我真的是在自言自语。可以看到说,Y-Net网络输入的是…

(附源码)计算机毕业设计SSM冷链物流管理系统

(附源码)计算机毕业设计SSM冷链物流管理系统 项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术&…

一维前缀和与一维差分

目录 一、前言 二、一维前缀和 1、构造数组 2、上题例 3、C代码 4、python代码 三、一维差分 1、构造数组 2、上题例 3、python代码 4、C代码 一、前言 对于学计算机的同学来说,学习算法是一件非常重要的事情,废话不多讲,我们来讲…

【FPGA】SCCB通信协议

文章目录一. 什么是SCCB协议?二. SCCB时序分析1. 起始信号2. 停止信号3. 数据传输3.1 三相写传输3.2 两相写传输:读数据第一阶段3.3 两相读传输:读数据第二阶段一. 什么是SCCB协议? SCCB (Serial Camera Control Bus)…

3.3.3JavaScript网页编程——WebAPI(JS之BOM含正则)

目录BOMwindow对象定时器-延时函数setTimeoutJS执行机制(执行栈、任务队列)面试要问location对象location.href (获取完整url或者赋值)location.search (获取?后面的)location.hash(获取#号后面的)location.reloadnavigator对象(检测浏览器移…

基于matlab的史密斯圆图演示仿真图

目录 1.算法概述 2.部分程序 3.算法部分仿真结果图 4.完整程序获取 1.算法概述 史密斯图表(Smith chart,又称史密斯圆图)是在反射系散平面上标绘有归一化输入阻抗(或导纳)等值圆族的计算图。是一款用于电机与电子工程学的图表&#xff0c…

注意!2022年下半年pets5正在报考中

2022年下半年全国外语水平考试(WSK-PETS5)网上报名时间为10月24日13时-10月26日16时,知识人网小编特别提醒报名者注意报名截止时间是16点(下午4点),万勿错过! 国家公派留学人员全国外语水平考试…

[C++] 初接触 泛型编程—— C++ 模板分析

泛型编程 C中引入了重载的概念,使得可以编写多个函数名相同但参数、返回值不同的函数,例如: 相同的函数名可以传入不同的参宿,进而调用不同的函数 但,即使有了重载,相同功能的函数 还要分别对不同的类型进…

关于UI测试的相关及技巧

一、关于UI测试 1、UI走查顺序 1.1、有空白页的页面优先测试(走查)空白页 1.2、按页面跳转流程把主线任务走一遍。 1.3、测试(走查)主线任务之外的页面。 1.4、对于复用以前组件的控件,主要看和以前是否一致&…

ClickHouse快速入门

ClickHouse可运行于任何x86 64位的Linux, FreeBSD, 或 Mac OS X ,及 AArch64, 或PowerPC64LE CPU架构。下列步骤将在Linux上安装和运行ClickHouse。 1. 启动Clickhouse 1)下载Clickhouse到本地的最简单方法是运行如下命令。如果操作系统支持,将会下载一…

大华股份流程IT总监金利红:数字化转型如何做到“五全一持续”?

企业数字化是为了业务成功,解决业务痛点,提高运作效率,控制风险。 通过数字化转型在企业经营管理上的应用,一是让业务管理有抓手,帮助公司从经营数据层层钻取到业务细节,进而清晰是什么原因导效问题的发生…

TEEOS的实例-在线支付系统

到了我最喜欢的环节了,我其实学习的过程中,对这些应用场景概念是十分的感兴趣的。 下面一起看看老师这本书的最后的一个部分应用篇—在线支付系统 内容来自《手机安全和可信应用开发指南》 1、简介 基于安全考虑,支付系统的最终结算由支付…

[实践篇]13.13 再来梳理一下HAB的设计原理

【QNX Hypervisor 2.2用户手册】目录(完结) 一,什么是HAB? HAB全称为Hypervisor Abstraction,即硬件抽象,主要用于提供对HOST OS硬件资源的访问。其核心架构包括API层,Core层和HAB-HYP插件层,如下: 通常情况下,OS通过驱动程序可以直接访问物理硬件,在HAB中的实现则…

vxe-table 实现嵌套子表格

目录一、功能说明1.图1:表格主界面2.图2:新增父表格3.图3:新增子表格4.图4:子表格二、代码实现一、功能说明 1.图1:表格主界面 说明: 1)新增:点击触发函数 addParent 添加父表格一行…

Creo/Proe草图无法标注尺寸怎么办?

之前一直用的solidworks,今天遇到一个及其尴尬的问题,就是跟着ICE的creo视频,但是没有办法自己标注尺寸 最后找了好久才发现,原来creo左键单击是选择,但是要出来尺寸需要点击中键!!&#xff01…

【Verilog】valid-ready双向握手机制 ——很绕但是很有意思

题干 描述 实现串行输入数据累加输出,输入端输入8bit数据,每当模块接收到4个输入数据后,输出端输出4个接收到数据的累加结果。输入端和输出端与上下游的交互采用valid-ready双向握手机制。要求上下游均能满速传输时,数据传输无气泡…

行走在加密世界 你需要了解这6个加密投资思维模型

加密世界沉浮5年,我总结了6个非常重要的思维模型,排名不分先后: 1)概率---随机性奖励 你是否认为这波机会是一旦错过,就会后悔一辈子?或许你也相信一种论调叫做,一个人的成功,只需…

Docker容器-----Consul(注册中心)部署

前言 Consul是HashiCorp公司推出的开源工具,用于实现分布式系统的服务发现与配置 与Docker等轻量级容器可无缝配合 一、Docker consul(注册中心) 1、什么是consul Consul是HashiCorp公司推出的开源工具,consul包含很多组件&am…

NFT标准:带有 EIP-3754 的普通NFT

NFT标准:带有 EIP-3754 的普通NFT NFT标准ERC721有点臃肿。这可能会导致一些公司只部分遵循ERC721来实现某个目标。例如NFT的订阅模型。或者有些公司可能只想实现一个没有URI的代币。EIP-3754被赋予生命来创建一个原子NFT标准,我们可以在其上构建抽象层。…

JMeter性能测试之运行内存设置

在进行大数据、高并发压测的过程性,有时会遇上JMeter卡死现象,使得测试无法进行,查看日志显示:java.lang.OutOfMemoryError: Java heap space 原因:运行jmeter机器的内存,占用较高,超过了jmete…
最新文章