Files
Cifar100/main.py
2025-10-22 11:41:55 +08:00

164 lines
5.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import os
import gc
import numpy as np
import random
from src.cifar100.data import Data
from src.cifar100.model import WideResNet
from src.cifar100.trainer import Trainer
from src.cifar100.visualizer import TrainingVisualizer
from src.cifar100.config import get_config, print_config
def set_seed(seed=42):
"""设置随机种子以确保结果可重现"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# 确保CUDA操作是确定性的
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def main():
# 设置随机种子
set_seed(42)
# 获取配置 - 默认使用4G显存配置
config_name = "4gb" # 可以改为 "2gb", "8gb" 或 "high_performance"
config = get_config(config_name)
print_config(config)
# 从配置中获取参数
batch_size = config["batch_size"]
threads = 2 # 减少线程数量以降低内存使用
epochs = 200
learning_rate = config["learning_rate"]
accumulation_steps = config["accumulation_steps"]
warmup_epochs = config["warmup_epochs"]
# 设置环境变量以优化内存使用
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
# 检查是否有GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# 如果使用GPU设置一些优化选项
if torch.cuda.is_available():
# 启用cudnn基准测试优化卷积操作
torch.backends.cudnn.benchmark = True
# 禁用cudnn确定性可能提高性能
torch.backends.cudnn.deterministic = False
# 加载数据
data = Data(batch_size, threads)
# 创建模型 - 使用配置中的参数
model = WideResNet(
depth=config["depth"],
width_factor=config["width_factor"],
dropout=config["dropout"],
in_channels=3,
labels=100
).to(device)
# 创建训练器 - 启用标签平滑
trainer = Trainer(model, device, use_label_smoothing=True, smoothing=0.1)
# 创建可视化器
visualizer = TrainingVisualizer(save_dir="./plots")
# 定义优化器和学习率调度器 - 使用余弦退火调度器
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4, nesterov=True)
# 使用余弦退火调度器通常比MultiStepLR效果更好
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-4)
trainer.set_optimizer(optimizer, scheduler)
# 添加学习率预热功能 - 对WideResNet特别有效
def warmup_lr(epoch, warmup_epochs=warmup_epochs, base_lr=learning_rate):
if epoch < warmup_epochs:
return base_lr * (epoch + 1) / warmup_epochs
return base_lr
# 保存初始学习率
base_lr = learning_rate
# 训练模型
best_acc = 0
patience = 30 # 早停耐心值
patience_counter = 0
for epoch in range(epochs):
print(f'Epoch: {epoch+1}/{epochs}')
# 应用学习率预热
if epoch < warmup_epochs: # 前warmup_epochs个epoch进行预热
current_lr = warmup_lr(epoch, warmup_epochs=warmup_epochs, base_lr=base_lr)
for param_group in optimizer.param_groups:
param_group['lr'] = current_lr
# 使用梯度累积accumulation_steps表示每accumulation_steps个batch更新一次权重
train_metrics = trainer.train_epoch(data.train, accumulation_steps=accumulation_steps)
# 测试模型
test_metrics = trainer.evaluate(data.test)
# 更新可视化器数据
visualizer.update(
epoch=epoch+1,
train_loss=train_metrics["loss"],
train_acc=train_metrics["accuracy"],
test_loss=test_metrics["loss"],
test_acc=test_metrics["accuracy"]
)
print(f'Epoch: {epoch+1}/{epochs} | Train Loss: {train_metrics["loss"]:.3f} | Train Acc: {train_metrics["accuracy"]:.3f}%')
print(f'Epoch: {epoch+1}/{epochs} | Test Loss: {test_metrics["loss"]:.3f} | Test Acc: {test_metrics["accuracy"]:.3f}%')
print(f'Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}')
# 保存最佳模型
current_acc = test_metrics["accuracy"]
if current_acc > best_acc:
best_acc = current_acc
trainer.save_model(f'best_model.pth')
print(f'New best accuracy: {best_acc:.3f}% - Model saved: best_model.pth')
patience_counter = 0 # 重置耐心计数器
else:
patience_counter += 1
# 每10个epoch显示一次训练进度图表
if (epoch + 1) % 10 == 0:
visualizer.plot_metrics(save=True, show=False)
print("训练进度图表已更新")
# 保存模型
if (epoch + 1) % 20 == 0:
model_name = f'wrn{config["depth"]}_{config["width_factor"]}_epoch_{epoch+1}.pth'
trainer.save_model(model_name)
print(f'Model saved: {model_name}')
# 早停检查
if patience_counter >= patience:
print(f"Early stopping triggered after {patience} epochs without improvement.")
break
# 每个epoch结束后进行垃圾回收
gc.collect()
torch.cuda.empty_cache()
# 训练结束后,显示最终结果
print("\n训练完成!生成最终可视化图表...")
visualizer.plot_metrics(save=True, show=True)
visualizer.plot_combined(save=True, show=True)
visualizer.save_data()
visualizer.print_summary()
print(f"最佳测试准确率: {best_acc:.3f}%")
if __name__ == "__main__":
main()