3.6 KiB
3.6 KiB
WideResNet 4G显存优化
本项目针对WideResNet模型在4G显存上的训练进行了全面优化,解决了准确率瓶颈问题,并提供了多种显存配置选项。
优化措施
1. 数据增强增强
- 添加了颜色抖动(ColorJitter)
- 添加了随机旋转(RandomRotation)
- 添加了随机仿射变换(RandomAffine)
- 添加了随机擦除(RandomErasing)
2. 模型结构优化
- 将激活函数从ReLU替换为GELU
- 在全连接层前添加了额外的Dropout层
- 优化了Block结构
- 添加了专门针对WideResNet的权重初始化方法
3. 训练策略改进
- 标签平滑(Label Smoothing):使用平滑因子0.1
- 余弦退火学习率调度器:替代原来的MultiStepLR
- Nesterov动量:在SGD优化器中启用nesterov=True
- 学习率预热:前5个epoch进行学习率预热
- 梯度累积:减少显存使用
- 早停机制:添加30个epoch的耐心值
配置选项
项目提供了四种预定义配置,适用于不同显存大小的GPU:
4G显存配置(默认)
config_name = "4gb"
- 模型深度: 22
- 宽度因子: 4
- 批量大小: 64
- 梯度累积步数: 2
2G显存配置
config_name = "2gb"
- 模型深度: 16
- 宽度因子: 2
- 批量大小: 32
- 梯度累积步数: 4
8G显存配置
config_name = "8gb"
- 模型深度: 28
- 宽度因子: 10
- 批量大小: 128
- 梯度累积步数: 1
高性能配置
config_name = "high_performance"
- 模型深度: 34
- 宽度因子: 10
- 批量大小: 128
- 梯度累积步数: 1
使用方法
1. 选择配置
在 main.py中修改配置名称:
config_name = "4gb" # 可以改为 "2gb", "8gb" 或 "high_performance"
2. 运行训练
python main.py
3. 测试显存使用
运行显存测试脚本,查看不同配置的显存使用情况:
python test_memory.py
文件结构
Cifar100/
├── src/cifar100/
│ ├── __init__.py
│ ├── config.py # 配置文件
│ ├── data.py # 数据加载和增强
│ ├── model.py # WideResNet模型定义
│ ├── trainer.py # 训练器
│ └── visualizer.py # 训练可视化
├── main.py # 主训练脚本
├── test_memory.py # 显存测试脚本
└── plots/ # 训练图表保存目录
预期效果
这些优化措施应该能够:
- 提高模型泛化能力:通过增强的数据增强和标签平滑
- 加速收敛:通过余弦退火调度器和Nesterov动量
- 防止过拟合:通过标签平滑、Dropout和早停机制
- 突破准确率瓶颈:综合以上所有改进,应该能显著超过52%-53%的准确率
自定义配置
如果预定义配置不满足需求,可以在 src/cifar100/config.py中添加自定义配置:
CUSTOM_CONFIG = {
"depth": 28,
"width_factor": 8,
"dropout": 0.3,
"batch_size": 96,
"learning_rate": 0.1,
"accumulation_steps": 1,
"warmup_epochs": 5,
"description": "自定义配置"
}
然后在 get_config函数中添加对应的映射。
注意事项
- 确保PyTorch版本支持CUDA
- 如果遇到显存不足,尝试减小批量大小或增加梯度累积步数
- 训练过程中会保存最佳模型和定期检查点
- 训练图表会保存在
./plots目录中
性能基准
在4G显存的GPU上,使用默认配置(WRN-22-4):
- 训练时间: 约2-3小时(200个epoch)
- 峰值显存使用: 约3.5GB
- 预期测试准确率: 70%+