Files
Cifar100/README_4GB.md
2025-10-22 11:41:55 +08:00

3.6 KiB
Raw Blame History

WideResNet 4G显存优化

本项目针对WideResNet模型在4G显存上的训练进行了全面优化解决了准确率瓶颈问题并提供了多种显存配置选项。

优化措施

1. 数据增强增强

  • 添加了颜色抖动ColorJitter
  • 添加了随机旋转RandomRotation
  • 添加了随机仿射变换RandomAffine
  • 添加了随机擦除RandomErasing

2. 模型结构优化

  • 将激活函数从ReLU替换为GELU
  • 在全连接层前添加了额外的Dropout层
  • 优化了Block结构
  • 添加了专门针对WideResNet的权重初始化方法

3. 训练策略改进

  • 标签平滑Label Smoothing使用平滑因子0.1
  • 余弦退火学习率调度器替代原来的MultiStepLR
  • Nesterov动量在SGD优化器中启用nesterov=True
  • 学习率预热前5个epoch进行学习率预热
  • 梯度累积:减少显存使用
  • 早停机制添加30个epoch的耐心值

配置选项

项目提供了四种预定义配置适用于不同显存大小的GPU

4G显存配置默认

config_name = "4gb"
  • 模型深度: 22
  • 宽度因子: 4
  • 批量大小: 64
  • 梯度累积步数: 2

2G显存配置

config_name = "2gb"
  • 模型深度: 16
  • 宽度因子: 2
  • 批量大小: 32
  • 梯度累积步数: 4

8G显存配置

config_name = "8gb"
  • 模型深度: 28
  • 宽度因子: 10
  • 批量大小: 128
  • 梯度累积步数: 1

高性能配置

config_name = "high_performance"
  • 模型深度: 34
  • 宽度因子: 10
  • 批量大小: 128
  • 梯度累积步数: 1

使用方法

1. 选择配置

main.py中修改配置名称:

config_name = "4gb"  # 可以改为 "2gb", "8gb" 或 "high_performance"

2. 运行训练

python main.py

3. 测试显存使用

运行显存测试脚本,查看不同配置的显存使用情况:

python test_memory.py

文件结构

Cifar100/
├── src/cifar100/
│   ├── __init__.py
│   ├── config.py         # 配置文件
│   ├── data.py           # 数据加载和增强
│   ├── model.py          # WideResNet模型定义
│   ├── trainer.py        # 训练器
│   └── visualizer.py     # 训练可视化
├── main.py               # 主训练脚本
├── test_memory.py        # 显存测试脚本
└── plots/                # 训练图表保存目录

预期效果

这些优化措施应该能够:

  1. 提高模型泛化能力:通过增强的数据增强和标签平滑
  2. 加速收敛通过余弦退火调度器和Nesterov动量
  3. 防止过拟合通过标签平滑、Dropout和早停机制
  4. 突破准确率瓶颈综合以上所有改进应该能显著超过52%-53%的准确率

自定义配置

如果预定义配置不满足需求,可以在 src/cifar100/config.py中添加自定义配置:

CUSTOM_CONFIG = {
    "depth": 28,
    "width_factor": 8,
    "dropout": 0.3,
    "batch_size": 96,
    "learning_rate": 0.1,
    "accumulation_steps": 1,
    "warmup_epochs": 5,
    "description": "自定义配置"
}

然后在 get_config函数中添加对应的映射。

注意事项

  1. 确保PyTorch版本支持CUDA
  2. 如果遇到显存不足,尝试减小批量大小或增加梯度累积步数
  3. 训练过程中会保存最佳模型和定期检查点
  4. 训练图表会保存在 ./plots目录中

性能基准

在4G显存的GPU上使用默认配置WRN-22-4

  • 训练时间: 约2-3小时200个epoch
  • 峰值显存使用: 约3.5GB
  • 预期测试准确率: 70%+