- 机器学习系统:设计和实现
- 麦络 董豪编著
- 306字
- 2024-12-27 20:30:18
2.2.5 训练及保存模型
MindSpore提供了回调(Callback)机制,可以在训练过程中执行自定义逻辑。代码2.6使用框架提供的ModelCheckpoint函数,ModelCheckpoint函数可以保存网络模型和参数,以便进行后续的Fine-tuning(微调)操作。
代码2.6 定义模型保存
![](https://epubservercos.yuewen.com/2564F9/31398141107520606/epubprivate/OEBPS/Images/Figure-P31_10673.jpg?sign=1739503284-TyzsbIGWzmhrEzZjyXqg8HESiIEDqHxs-0-7426fa6a4eaf951e8af89f8e0f8c1f26)
通过MindSpore提供的model.train接口可以方便地进行网络的训练,同时使用Loss-Monitor可以监控训练过程中损失(loss)值的变化,如代码2.7所示。
代码2.7 定义模型训练
![](https://epubservercos.yuewen.com/2564F9/31398141107520606/epubprivate/OEBPS/Images/Figure-P31_10674.jpg?sign=1739503284-SeMFMaqNBoxkDk4dMfHmjiBsVRsStVdA-0-67b90f15b8dc796031b00b937b16ed87)
其中,dataset_sink_mode用于控制数据是否下沉,数据下沉是指数据通过通道直接传送到设备(Device)上,可以加快训练速度,dataset_sink_mode为真(True),表示数据下沉,否则为非下沉。
有了数据集、模型、损失函数、优化器后就可以进行训练了。代码2.8把train_epoch设置为1,对数据集进行1次迭代训练。在train_net方法中,加载了之前下载的训练数据集,mnist_path是MNIST数据集路径。
代码2.8 训练模型
![](https://epubservercos.yuewen.com/2564F9/31398141107520606/epubprivate/OEBPS/Images/Figure-P31_10675.jpg?sign=1739503284-xV8lcG4ISYyo6wZ0xIRiGWl19g9daOhN-0-4367be63c47de45e91df7a8894d66224)