164 lines
5.8 KiB
Python
164 lines
5.8 KiB
Python
import torch
|
||
import os
|
||
import gc
|
||
import numpy as np
|
||
import random
|
||
from src.cifar100.data import Data
|
||
from src.cifar100.model import WideResNet
|
||
from src.cifar100.trainer import Trainer
|
||
from src.cifar100.visualizer import TrainingVisualizer
|
||
from src.cifar100.config import get_config, print_config
|
||
|
||
|
||
def set_seed(seed=42):
|
||
"""设置随机种子以确保结果可重现"""
|
||
random.seed(seed)
|
||
np.random.seed(seed)
|
||
torch.manual_seed(seed)
|
||
torch.cuda.manual_seed(seed)
|
||
torch.cuda.manual_seed_all(seed)
|
||
# 确保CUDA操作是确定性的
|
||
torch.backends.cudnn.deterministic = True
|
||
torch.backends.cudnn.benchmark = False
|
||
|
||
|
||
def main():
|
||
# 设置随机种子
|
||
set_seed(42)
|
||
|
||
# 获取配置 - 默认使用4G显存配置
|
||
config_name = "4gb" # 可以改为 "2gb", "8gb" 或 "high_performance"
|
||
config = get_config(config_name)
|
||
print_config(config)
|
||
|
||
# 从配置中获取参数
|
||
batch_size = config["batch_size"]
|
||
threads = 2 # 减少线程数量以降低内存使用
|
||
epochs = 200
|
||
learning_rate = config["learning_rate"]
|
||
accumulation_steps = config["accumulation_steps"]
|
||
warmup_epochs = config["warmup_epochs"]
|
||
|
||
# 设置环境变量以优化内存使用
|
||
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
|
||
|
||
# 检查是否有GPU
|
||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||
print(f"Using device: {device}")
|
||
|
||
# 如果使用GPU,设置一些优化选项
|
||
if torch.cuda.is_available():
|
||
# 启用cudnn基准测试,优化卷积操作
|
||
torch.backends.cudnn.benchmark = True
|
||
# 禁用cudnn确定性,可能提高性能
|
||
torch.backends.cudnn.deterministic = False
|
||
|
||
# 加载数据
|
||
data = Data(batch_size, threads)
|
||
|
||
# 创建模型 - 使用配置中的参数
|
||
model = WideResNet(
|
||
depth=config["depth"],
|
||
width_factor=config["width_factor"],
|
||
dropout=config["dropout"],
|
||
in_channels=3,
|
||
labels=100
|
||
).to(device)
|
||
|
||
# 创建训练器 - 启用标签平滑
|
||
trainer = Trainer(model, device, use_label_smoothing=True, smoothing=0.1)
|
||
|
||
# 创建可视化器
|
||
visualizer = TrainingVisualizer(save_dir="./plots")
|
||
|
||
# 定义优化器和学习率调度器 - 使用余弦退火调度器
|
||
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4, nesterov=True)
|
||
# 使用余弦退火调度器,通常比MultiStepLR效果更好
|
||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-4)
|
||
trainer.set_optimizer(optimizer, scheduler)
|
||
|
||
# 添加学习率预热功能 - 对WideResNet特别有效
|
||
def warmup_lr(epoch, warmup_epochs=warmup_epochs, base_lr=learning_rate):
|
||
if epoch < warmup_epochs:
|
||
return base_lr * (epoch + 1) / warmup_epochs
|
||
return base_lr
|
||
|
||
# 保存初始学习率
|
||
base_lr = learning_rate
|
||
|
||
# 训练模型
|
||
best_acc = 0
|
||
patience = 30 # 早停耐心值
|
||
patience_counter = 0
|
||
|
||
for epoch in range(epochs):
|
||
print(f'Epoch: {epoch+1}/{epochs}')
|
||
|
||
# 应用学习率预热
|
||
if epoch < warmup_epochs: # 前warmup_epochs个epoch进行预热
|
||
current_lr = warmup_lr(epoch, warmup_epochs=warmup_epochs, base_lr=base_lr)
|
||
for param_group in optimizer.param_groups:
|
||
param_group['lr'] = current_lr
|
||
|
||
# 使用梯度累积,accumulation_steps表示每accumulation_steps个batch更新一次权重
|
||
train_metrics = trainer.train_epoch(data.train, accumulation_steps=accumulation_steps)
|
||
|
||
# 测试模型
|
||
test_metrics = trainer.evaluate(data.test)
|
||
|
||
# 更新可视化器数据
|
||
visualizer.update(
|
||
epoch=epoch+1,
|
||
train_loss=train_metrics["loss"],
|
||
train_acc=train_metrics["accuracy"],
|
||
test_loss=test_metrics["loss"],
|
||
test_acc=test_metrics["accuracy"]
|
||
)
|
||
|
||
print(f'Epoch: {epoch+1}/{epochs} | Train Loss: {train_metrics["loss"]:.3f} | Train Acc: {train_metrics["accuracy"]:.3f}%')
|
||
print(f'Epoch: {epoch+1}/{epochs} | Test Loss: {test_metrics["loss"]:.3f} | Test Acc: {test_metrics["accuracy"]:.3f}%')
|
||
print(f'Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}')
|
||
|
||
# 保存最佳模型
|
||
current_acc = test_metrics["accuracy"]
|
||
if current_acc > best_acc:
|
||
best_acc = current_acc
|
||
trainer.save_model(f'best_model.pth')
|
||
print(f'New best accuracy: {best_acc:.3f}% - Model saved: best_model.pth')
|
||
patience_counter = 0 # 重置耐心计数器
|
||
else:
|
||
patience_counter += 1
|
||
|
||
# 每10个epoch显示一次训练进度图表
|
||
if (epoch + 1) % 10 == 0:
|
||
visualizer.plot_metrics(save=True, show=False)
|
||
print("训练进度图表已更新")
|
||
|
||
# 保存模型
|
||
if (epoch + 1) % 20 == 0:
|
||
model_name = f'wrn{config["depth"]}_{config["width_factor"]}_epoch_{epoch+1}.pth'
|
||
trainer.save_model(model_name)
|
||
print(f'Model saved: {model_name}')
|
||
|
||
# 早停检查
|
||
if patience_counter >= patience:
|
||
print(f"Early stopping triggered after {patience} epochs without improvement.")
|
||
break
|
||
|
||
# 每个epoch结束后进行垃圾回收
|
||
gc.collect()
|
||
torch.cuda.empty_cache()
|
||
|
||
# 训练结束后,显示最终结果
|
||
print("\n训练完成!生成最终可视化图表...")
|
||
visualizer.plot_metrics(save=True, show=True)
|
||
visualizer.plot_combined(save=True, show=True)
|
||
visualizer.save_data()
|
||
visualizer.print_summary()
|
||
|
||
print(f"最佳测试准确率: {best_acc:.3f}%")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|