"""
显存使用测试脚本
用于测试不同WideResNet配置的显存使用情况
"""
import torch
import torch.nn as nn
from src.cifar100.model import WideResNet
from src.cifar100.config import get_config, print_config
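
# NOTE: get_config(name) is assumed (based on the calls below) to return a dict
# with at least the keys "depth", "width_factor", "dropout", "batch_size",
# and "accumulation_steps"; print_config pretty-prints it.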


def test_memory_usage(config_name):
    """
    Measure GPU memory usage for the given configuration.

    Args:
        config_name: Name of the configuration to test.
    """
    # Fetch the configuration
    config = get_config(config_name)
    print_config(config)

    # Check for a GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    if not torch.cuda.is_available():
        print("No CUDA device detected; cannot measure GPU memory usage.")
        return
    # Clear the cache and reset the peak-memory counter so that
    # max_memory_allocated() below reflects only this configuration
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    # Record baseline memory
    initial_memory = torch.cuda.memory_allocated() / 1024**2  # MB
    print(f"Initial memory usage: {initial_memory:.2f} MB")

    # Build the model
    model = WideResNet(
        depth=config["depth"],
        width_factor=config["width_factor"],
        dropout=config["dropout"],
        in_channels=3,
        labels=100,
    ).to(device)

    # Record memory after the model is loaded
    model_memory = torch.cuda.memory_allocated() / 1024**2  # MB
    print(f"Memory after loading model: {model_memory:.2f} MB")
    print(f"Memory used by model parameters: {model_memory - initial_memory:.2f} MB")

    # Count model parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Measure memory for a forward pass
    batch_size = config["batch_size"]
    print(f"\nTesting forward pass (batch size: {batch_size})...")

    # Create random input
    inputs = torch.randn(batch_size, 3, 32, 32).to(device)

    # Record memory before the forward pass
    before_forward = torch.cuda.memory_allocated() / 1024**2  # MB

    # Forward pass
    with torch.no_grad():
        outputs = model(inputs)

    # Record memory after the forward pass
    after_forward = torch.cuda.memory_allocated() / 1024**2  # MB
    print(f"Memory before forward pass: {before_forward:.2f} MB")
    print(f"Memory after forward pass: {after_forward:.2f} MB")
    print(f"Memory added by forward pass: {after_forward - before_forward:.2f} MB")

    # Measure memory during training
    print(f"\nTesting training step (gradient accumulation steps: {config['accumulation_steps']})...")

    # Create random labels
    targets = torch.randint(0, 100, (batch_size,)).to(device)

    # Create optimizer and loss function
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    criterion = nn.CrossEntropyLoss()

    # Record memory before training
    before_training = torch.cuda.memory_allocated() / 1024**2  # MB

    # Training step with gradient accumulation: gradients accumulate across
    # the inner loop and the optimizer steps once afterwards, which is why
    # each loss is scaled by accumulation_steps
    model.train()
    optimizer.zero_grad()
    for step in range(config["accumulation_steps"]):
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss = loss / config["accumulation_steps"]
        loss.backward()
    optimizer.step()
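
    # For SGD with momentum, expect the gradients to add roughly one extra
    # copy of the parameters and the momentum buffers another, on top of the
    # activations saved for the backward pass.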

    # Record memory after training
    after_training = torch.cuda.memory_allocated() / 1024**2  # MB
    print(f"Memory before training: {before_training:.2f} MB")
    print(f"Memory after training: {after_training:.2f} MB")
    print(f"Memory added by training: {after_training - before_training:.2f} MB")

    # Record peak memory
    peak_memory = torch.cuda.max_memory_allocated() / 1024**2  # MB
    print(f"\nPeak memory usage: {peak_memory:.2f} MB")

    # Clean up
    del model, inputs, targets, outputs, loss
    torch.cuda.empty_cache()

    return peak_memory


def main():
    print("=" * 60)
    print("WideResNet GPU memory usage test")
    print("=" * 60)

    # Test every configuration
    configs = ["2gb", "4gb", "8gb", "high_performance"]
    results = {}
    for config_name in configs:
        print(f"\n{'='*60}")
        print(f"Testing configuration: {config_name}")
        print(f"{'='*60}")
        try:
            peak_memory = test_memory_usage(config_name)
            results[config_name] = peak_memory
        except Exception as e:
            print(f"Error while testing configuration {config_name}: {str(e)}")
            results[config_name] = None

    # Print a summary of the results
    print(f"\n{'='*60}")
    print("Summary of results")
    print(f"{'='*60}")
    for config_name, peak_memory in results.items():
        if peak_memory is not None:
            print(f"{config_name:>15}: {peak_memory:>8.2f} MB")
        else:
            print(f"{config_name:>15}: {'test failed':>8}")

    # Offer recommendations
    print(f"\n{'='*60}")
    print("Recommendations")
    print(f"{'='*60}")
    if "4gb" in results and results["4gb"] is not None:
        if results["4gb"] < 4000:
            print("The 4GB configuration fits your GPU.")
        else:
            print("The 4GB configuration may exceed your GPU's memory; consider the 2GB configuration.")
    if "2gb" in results and results["2gb"] is not None:
        if results["2gb"] < 2000:
            print("The 2GB configuration fits your GPU.")
        else:
            print("The 2GB configuration may exceed your GPU's memory.")


if __name__ == "__main__":
    main()
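
# Typical invocation (assuming the repo root is on PYTHONPATH so that
# src.cifar100 resolves):
#   python Cifar100/test_memory.py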