feat: enhance stacking model pipeline with configuration management, logging, and performance monitoring

Refactor the main pipeline code around a more robust architecture:
- Add a unified configuration system for centralized management of data, model, training, and visualization parameters
- Implement a complete logging system with file and console output and structured run logs
- Integrate performance monitoring tools, including real-time tracking and analysis of execution time and memory usage
- Add model checkpointing with automatic saving and loading of models and evaluation results
- Add a data validator to ensure data integrity and the correctness of predictions
- Implement a model factory pattern to make it easier to extend and manage different model types
- Improve exception handling with dedicated exception classes
- Update dependencies, adding psutil for system resource monitoring

Output file management improvements:
- Save evaluation results and trained models to the outputs directory
- Add log files to .gitignore to keep version control clean
- Provide a complete comparison of model evaluation metrics
.gitignore (vendored, 3 changed lines)
@@ -8,6 +8,9 @@ uv.lock
.py[ocd]
**/__pycache__/

# log file
*.log

# python distribution
dist/

outputs/evaluation_results.json (new file, 27 lines)
@@ -0,0 +1,27 @@
{
    "Stacking Model": {
        "rmse": 64.64642644074704,
        "mae": 47.963500739758054,
        "r2": 0.8930611738230659
    },
    "Ridge": {
        "rmse": 20.258368087201287,
        "mae": 16.172561165969487,
        "r2": 0.9894984045767364
    },
    "XGBoost": {
        "rmse": 77.03564948928194,
        "mae": 58.84620600162227,
        "r2": 0.848144731434714
    },
    "LightGBM": {
        "rmse": 72.76199680324312,
        "mae": 57.38448079041826,
        "r2": 0.8645261150841663
    },
    "MLP": {
        "rmse": 23.392914827573286,
        "mae": 18.809170451107256,
        "r2": 0.9859971948234603
    }
}
outputs/stacking_model.pkl (new binary file)
Binary file not shown.
pyproject.toml
@@ -16,6 +16,7 @@ dependencies = [
    "xgboost>=3.1.2",
    "matplotlib>=3.8.0",
    "seaborn>=0.13.2",
    "psutil>=7.1.3",
]

[tool.uv]
src/stacking_model/__init__.py
@@ -1,11 +1,21 @@
from .data import DataGenerator
from .models import StackingModel, BaseModels
from .utils import Preprocessor, ModelEvaluator
from .models import StackingModel, BaseModels, ModelFactory
from .utils import Checkpoint, Preprocessor, logger, ModelEvaluator, PerformanceMonitor, Validator
from .visualization import ModelPlotter
from . import config, exceptions

__all__ = [
    "DataGenerator",
    "StackingModel",
    "BaseModels",
    "Checkpoint",
    "ModelFactory",
    "Preprocessor",
    "logger",
    "PerformanceMonitor",
    "Validator",
    "ModelEvaluator",
    "ModelPlotter",
    "config",
    "exceptions",
]
src/stacking_model/config.py (new file, 55 lines)
@@ -0,0 +1,55 @@
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class DataConfig:
    """Data configuration"""
    n_samples: int = 2000
    n_features: int = 20
    n_informative: int = 15
    test_size: float = 0.2
    val_size: float = 0.2
    random_state: int = 42
    noise: float = 20.0


@dataclass
class ModelConfig:
    """Model configuration"""
    base_models: List[str] = field(default_factory=lambda: ["ridge", "lasso", "elastic_net", "svr", "knn"])
    meta_learner: str = "linear_regression"
    random_state: int = 42


@dataclass
class TrainingConfig:
    """Training configuration"""
    epochs: int = 10
    batch_size: int = 32
    learning_rate: float = 0.01
    early_stopping: bool = True
    patience: int = 3


@dataclass
class VisualizationConfig:
    """Visualization configuration"""
    figsize: Tuple[int, int] = (12, 8)
    style: str = "seaborn-v0_8-darkgrid"
    dpi: int = 300
    save_plots: bool = True
    output_dir: str = "outputs"


@dataclass
class Config:
    """Top-level configuration"""
    data: DataConfig = field(default_factory=DataConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    visualization: VisualizationConfig = field(default_factory=VisualizationConfig)


# Global configuration instance
default_config = Config()
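A minimal usage sketch for the configuration module, assuming the package is importable as stacking_model; the override values below are illustrative only:

from stacking_model.config import Config, DataConfig, default_config

# Read from the shared default instance
print(default_config.data.n_samples)             # 2000
print(default_config.visualization.output_dir)   # "outputs"

# Or build a per-run configuration with selective overrides
custom = Config(data=DataConfig(n_samples=500, test_size=0.3))
print(custom.data.test_size)                     # 0.3
print(custom.model.meta_learner)                 # "linear_regression"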
src/stacking_model/exceptions.py (new file, 26 lines)
@@ -0,0 +1,26 @@
"""Custom exception classes"""


class StackingModelException(Exception):
    """Base exception"""
    pass


class DataException(StackingModelException):
    """Data-related exceptions"""
    pass


class ModelException(StackingModelException):
    """Model-related exceptions"""
    pass


class ConfigException(StackingModelException):
    """Configuration-related exceptions"""
    pass


class ValidationException(StackingModelException):
    """Validation exceptions"""
    pass
@@ -1,178 +1,219 @@
# import os
from pathlib import Path

from stacking_model.config import default_config
from stacking_model.data import DataGenerator
from stacking_model.models import StackingModel
from stacking_model.utils import Preprocessor, ModelEvaluator
from stacking_model.visualization import ModelPlotter
from stacking_model.utils.logger import logger
from stacking_model.utils.checkpoint import Checkpoint
from stacking_model.utils.validator import Validator
from stacking_model.utils.metrics import PerformanceMonitor, Timer, MemoryMonitor
from sklearn.model_selection import train_test_split


@PerformanceMonitor.measure_time_and_memory
def main():
    """Main entry point of the pipeline."""
    print("=" * 60)
    print("Stacking Model Pipeline")
    print("=" * 60)

    output_dir = Path("outputs")

    # Initialization
    logger.section("Stacking Model Pipeline")
    output_dir = Path(default_config.visualization.output_dir)
    output_dir.mkdir(exist_ok=True)
    checkpoint = Checkpoint(str(output_dir))

    # 1. Generate data
    print("\nStep 1: Generating Data")
    print("-" * 60)

    generator = DataGenerator()
    X, y = generator.generate_data(n_samples=2000, n_features=20)
    print(f"Data shape: X={X.shape}, y={y.shape}")
    print(f"Features: {', '.join(X.columns[:5])}...")
    print(f"Target stats - Mean: {y.mean():.2f}, Std: {y.std():.2f}")

    # 2. Preprocessing
    print("\n\nStep 2: Preprocessing Data")
    print("-" * 60)

    preprocessor = Preprocessor(test_size=0.2, random_state=42)
    X_train, X_test, y_train, y_test = preprocessor.split_and_scale(X, y)

    # Further split the training set into train and validation sets
    from sklearn.model_selection import train_test_split
    X_train_split, X_val, y_train_split, y_val = train_test_split(
        X_train, y_train, test_size=0.2, random_state=42
    )

    print(f"Train set: {X_train_split.shape}")
    print(f"Val set: {X_val.shape}")
    print(f"Test set: {X_test.shape}")

    # 3. Train the stacking model
    print("\n\nStep 3: Training Stacking Model")
    print("-" * 60)

    stacking_model = StackingModel(random_state=42)
    stacking_model.fit(X_train_split, y_train_split, X_val, y_val)

    # 4. Evaluation
    print("\n\nStep 4: Model Evaluation")
    print("-" * 60)

    evaluator = ModelEvaluator()

    # Evaluate the stacking model
    y_pred_stacking = stacking_model.predict(X_test)
    stacking_results = evaluator.evaluate(
        y_test, y_pred_stacking, "Stacking Model"
    )

    # Evaluate the base models
    print("\n\nBase Models Evaluation on Test Set:")
    print("-" * 60)

    base_models = stacking_model.get_base_models()
    base_results = {}
    base_predictions = {}

    for name, model in base_models.items():
        y_pred_base = model.predict(X_test)
        base_results[name] = evaluator.evaluate(y_test, y_pred_base, name)
        base_predictions[name] = y_pred_base

    # 5. Compare results
    print("\n\nStep 5: Final Comparison")
    all_results = {
        "Stacking Model": stacking_results,
        **base_results,
    }
    evaluator.print_comparison(all_results)

    # 6. Visualization
    print("\n\nStep 6: Visualization")
    print("-" * 60)

    plotter = ModelPlotter(figsize=(14, 8))

    # 6.1 Prediction comparison
    print("Generating prediction comparison plots...")
    all_predictions = {
        "Stacking Model": y_pred_stacking,
        **base_predictions,
    }
    plotter.plot_predictions_comparison(
        y_test.values,
        all_predictions,
        save_path=str(output_dir / "predictions_comparison.png")
    )

    # 6.2 Residual analysis
    print("Generating residuals plots...")
    plotter.plot_residuals(
        y_test.values,
        all_predictions,
        save_path=str(output_dir / "residuals_plot.png")
    )

    # 6.3 Residual distribution
    print("Generating residuals distribution plots...")
    plotter.plot_residuals_distribution(
        y_test.values,
        all_predictions,
        save_path=str(output_dir / "residuals_distribution.png")
    )

    # 6.4 Metrics comparison bar chart
    print("Generating metrics comparison bar chart...")
    plotter.plot_metrics_comparison(
        all_results,
        metrics=["r2", "rmse", "mae"],
        save_path=str(output_dir / "metrics_comparison.png")
    )

    # 6.5 Model performance heatmap
    print("Generating model comparison heatmap...")
    plotter.plot_model_comparison_heatmap(
        all_results,
        metrics=["r2", "rmse", "mae"],
        save_path=str(output_dir / "model_heatmap.png")
    )

    # 6.6 Feature importance
    print("Generating feature importance plots...")
    for name, model in base_models.items():
        if hasattr(model, 'feature_importances_'):
            plotter.plot_feature_importance(
                feature_names=X_test.columns.tolist(),
                importances=model.feature_importances_,
                model_name=name,
                top_n=15,
                save_path=str(output_dir / f"feature_importance_{name.lower().replace(' ', '_')}.png")
    try:
        # 1. Generate data
        logger.subsection("Step 1: Generating Data")
        with Timer("Data generation"):
            generator = DataGenerator()
            X, y = generator.generate_data(
                n_samples=default_config.data.n_samples,
                n_features=default_config.data.n_features,
            )

    # 7. Advantage analysis
    print("\n" + "=" * 60)
    print("Analysis")
    print("=" * 60)

    stacking_r2 = stacking_results["r2"]
    max_base_r2 = max(v["r2"] for v in base_results.values())
    improvement = ((stacking_r2 - max_base_r2) / abs(max_base_r2)) * 100

    print(f"\nStacking R² vs Best Base Model:")
    print(f"  Stacking R²: {stacking_r2:.4f}")
    print(f"  Best Base Model R²: {max_base_r2:.4f}")
    print(f"  Improvement: {improvement:+.2f}%")

    stacking_rmse = stacking_results["rmse"]
    min_base_rmse = min(v["rmse"] for v in base_results.values())
    rmse_improvement = ((min_base_rmse - stacking_rmse) / min_base_rmse) * 100

    print(f"\nStacking RMSE vs Best Base Model:")
    print(f"  Stacking RMSE: {stacking_rmse:.4f}")
    print(f"  Best Base RMSE: {min_base_rmse:.4f}")
    print(f"  Improvement: {rmse_improvement:+.2f}%")

    print("\n" + "=" * 60)
    print("Pipeline Complete")
    print("=" * 60 + "\n")

        Validator.validate_data(X, y)
        logger.info(f"Data shape: X={X.shape}, y={y.shape}")
        logger.info(f"Features: {', '.join(X.columns[:5].tolist())}...")
        logger.info(f"Target stats - Mean: {y.mean():.2f}, Std: {y.std():.2f}")

        # 2. Preprocessing
        logger.subsection("Step 2: Preprocessing Data")
        with Timer("Data preprocessing") as timer:
            preprocessor = Preprocessor(
                test_size=default_config.data.test_size,
                random_state=default_config.data.random_state
            )
            X_train, X_test, y_train, y_test = preprocessor.split_and_scale(X, y)
            X_train_split, X_val, y_train_split, y_val = train_test_split(
                X_train, y_train,
                test_size=default_config.data.val_size,
                random_state=default_config.data.random_state
            )

        logger.info(f"Train set: {X_train_split.shape}")
        logger.info(f"Val set: {X_val.shape}")
        logger.info(f"Test set: {X_test.shape}")
        logger.info(f"Preprocessing took {timer.elapsed:.2f}s")

        # 3. Train the model
        logger.subsection("Step 3: Training Stacking Model")
        with MemoryMonitor("Model training") as mem_monitor:
            with Timer("Stacking model training"):
                stacking_model = StackingModel(
                    random_state=default_config.model.random_state
                )
                stacking_model.fit(X_train_split, y_train_split, X_val, y_val)
                checkpoint.save_model(stacking_model, "stacking_model")

            logger.success("Stacking model trained")
            logger.info(f"Memory used: {mem_monitor.current:.2f}MB")

        # 4. Evaluation
        logger.subsection("Step 4: Model Evaluation")
        with Timer("Model evaluation"):
            evaluator = ModelEvaluator()

            # Evaluate the stacking model
            y_pred_stacking = stacking_model.predict(X_test)
            stacking_results = evaluator.evaluate(y_test, y_pred_stacking, "Stacking Model")
            Validator.validate_predictions(y_test, y_pred_stacking)

            # Evaluate the base models
            logger.subsection("Step 4.2: Base Models Evaluation")
            base_models = stacking_model.get_base_models()
            base_results = {}
            base_predictions = {}

            for name, model in base_models.items():
                y_pred = model.predict(X_test)
                base_results[name] = evaluator.evaluate(y_test, y_pred, name)
                base_predictions[name] = y_pred

            # Merge all results
            all_results = {"Stacking Model": stacking_results, **base_results}
            evaluator.print_comparison(all_results)

        # 5. Visualization
        logger.subsection("Step 5: Visualization")
        plotter = ModelPlotter(
            figsize=default_config.visualization.figsize,
            style=default_config.visualization.style
        )

        all_predictions = {"Stacking Model": y_pred_stacking, **base_predictions}

        # 5.1 Prediction comparison
        with Timer("Generating prediction comparison plots"):
            logger.info("Generating prediction comparison plots...")
            plotter.plot_predictions_comparison(
                y_test.values,
                all_predictions,
                save_path=str(output_dir / "predictions_comparison.png")
            )
            logger.success("Prediction comparison saved")

        # 5.2 Residual analysis
        with Timer("Generating residuals plots"):
            logger.info("Generating residuals plots...")
            plotter.plot_residuals(
                y_test.values,
                all_predictions,
                save_path=str(output_dir / "residuals_plot.png")
            )
            logger.success("Residuals plot saved")

        # 5.3 Residual distribution
        with Timer("Generating residuals distribution plots"):
            logger.info("Generating residuals distribution plots...")
            plotter.plot_residuals_distribution(
                y_test.values,
                all_predictions,
                save_path=str(output_dir / "residuals_distribution.png")
            )
            logger.success("Residuals distribution saved")

        # 5.4 Metrics comparison
        with Timer("Generating metrics comparison bar chart"):
            logger.info("Generating metrics comparison bar chart...")
            plotter.plot_metrics_comparison(
                all_results,
                metrics=["r2", "rmse", "mae"],
                save_path=str(output_dir / "metrics_comparison.png")
            )
            logger.success("Metrics comparison saved")

        # 5.5 Heatmap
        with Timer("Generating model comparison heatmap"):
            logger.info("Generating model comparison heatmap...")
            plotter.plot_model_comparison_heatmap(
                all_results,
                metrics=["r2", "rmse", "mae"],
                save_path=str(output_dir / "model_heatmap.png")
            )
            logger.success("Model heatmap saved")

        # 5.6 Feature importance
        logger.subsection("Step 5.6: Feature Importance")
        feature_count = 0
        with MemoryMonitor("Feature importance generation"):
            for name, model in base_models.items():
                if hasattr(model, 'feature_importances_'):
                    with Timer(f"Generating feature importance for {name}"):
                        plotter.plot_feature_importance(
                            feature_names=X_test.columns.tolist(),
                            importances=model.feature_importances_,
                            model_name=name,
                            top_n=15,
                            save_path=str(output_dir / f"feature_importance_{name.lower().replace(' ', '_')}.png")
                        )
                    feature_count += 1
        logger.success(f"Feature importance plots saved ({feature_count} models)")

        # 6. Save results
        logger.subsection("Step 6: Saving Results")
        with Timer("Saving results"):
            checkpoint.save_results(all_results, "evaluation_results")
            logger.success("Results saved to checkpoint")

        # 7. Performance analysis
        logger.section("Analysis")
        stacking_r2 = stacking_results["r2"]
        max_base_r2 = max(v["r2"] for v in base_results.values())
        improvement_r2 = ((stacking_r2 - max_base_r2) / abs(max_base_r2)) * 100

        stacking_rmse = stacking_results["rmse"]
        min_base_rmse = min(v["rmse"] for v in base_results.values())
        improvement_rmse = ((min_base_rmse - stacking_rmse) / min_base_rmse) * 100

        logger.info("")
        logger.info("Stacking R2 vs Best Base Model:")
        logger.info(f"  Stacking R2: {stacking_r2:.4f}")
        logger.info(f"  Best Base Model R2: {max_base_r2:.4f}")
        logger.info(f"  Improvement: {improvement_r2:+.2f}%")

        logger.info("")
        logger.info("Stacking RMSE vs Best Base Model:")
        logger.info(f"  Stacking RMSE: {stacking_rmse:.4f}")
        logger.info(f"  Best Base RMSE: {min_base_rmse:.4f}")
        logger.info(f"  Improvement: {improvement_rmse:+.2f}%")

        logger.info("")
        logger.info(f"All outputs saved to: {output_dir.absolute()}")

        logger.section("Pipeline Complete!")

    except Exception as e:
        logger.failure(str(e))
        import traceback
        logger.error(traceback.format_exc())
        raise


if __name__ == "__main__":
    # Initialize the log file
    log_dir = Path("logs")
    log_dir.mkdir(exist_ok=True)
    logger.setup_file_handler(str(log_dir / "pipeline.log"))

    main()

src/stacking_model/models/__init__.py
@@ -1,4 +1,5 @@
from .base_models import BaseModels
from .factory import ModelFactory
from .stacking import StackingModel

__all__ = ["BaseModels", "StackingModel"]
__all__ = ["BaseModels", "ModelFactory", "StackingModel"]
src/stacking_model/models/factory.py (new file, 50 lines)
@@ -0,0 +1,50 @@
from typing import Dict, Type
from sklearn.linear_model import Ridge, Lasso, ElasticNet
from sklearn.svm import SVR
from sklearn.neighbors import KNeighborsRegressor
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor


class ModelFactory:
    """Model factory"""

    _models: Dict[str, Type] = {
        'ridge': Ridge,
        'lasso': Lasso,
        'elastic_net': ElasticNet,
        'svr': SVR,
        'knn': KNeighborsRegressor,
        'random_forest': RandomForestRegressor,
        'gradient_boosting': GradientBoostingRegressor,
    }

    @classmethod
    def create_model(cls, model_name: str, **kwargs):
        """
        Create a model instance.

        Args:
            model_name: name of the model
            **kwargs: model parameters

        Returns:
            A model instance
        """
        if model_name not in cls._models:
            raise ValueError(
                f"Unknown model: {model_name}. "
                f"Available models: {list(cls._models.keys())}"
            )

        model_class = cls._models[model_name]
        return model_class(**kwargs)

    @classmethod
    def register_model(cls, name: str, model_class: Type) -> None:
        """Register a custom model."""
        cls._models[name] = model_class

    @classmethod
    def list_models(cls) -> list:
        """List all available models."""
        return list(cls._models.keys())
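A brief usage sketch for ModelFactory; the hist_gradient_boosting registration is a hypothetical extension, not part of this commit:

from sklearn.ensemble import HistGradientBoostingRegressor
from stacking_model.models import ModelFactory

# Create a registered model with custom hyperparameters
ridge = ModelFactory.create_model("ridge", alpha=1.0)

# Register an extra regressor under a new name (hypothetical extension)
ModelFactory.register_model("hist_gradient_boosting", HistGradientBoostingRegressor)
print(ModelFactory.list_models())
# ['ridge', 'lasso', 'elastic_net', 'svr', 'knn',
#  'random_forest', 'gradient_boosting', 'hist_gradient_boosting']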
src/stacking_model/utils/__init__.py
@@ -1,4 +1,8 @@
from .preprocessing import Preprocessor
from .checkpoint import Checkpoint
from .evaluation import ModelEvaluator
from .logger import logger
from .metrics import PerformanceMonitor, MemoryMonitor, Timer
from .validator import Validator
from .preprocessing import Preprocessor

__all__ = ["Preprocessor", "ModelEvaluator"]
__all__ = ["Checkpoint", "ModelEvaluator", "logger", "PerformanceMonitor", "MemoryMonitor", "Timer", "Preprocessor", "Validator"]
src/stacking_model/utils/checkpoint.py (new file, 79 lines)
@@ -0,0 +1,79 @@
import pickle
import json
from pathlib import Path
from typing import Any, Dict, Optional
import numpy as np
from stacking_model.utils.logger import logger


class Checkpoint:
    """Model checkpoint management"""

    def __init__(self, checkpoint_dir: str = "checkpoints"):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(exist_ok=True)

    def save_model(self, model: Any, model_name: str) -> str:
        """Save a model."""
        try:
            path = self.checkpoint_dir / f"{model_name}.pkl"
            with open(path, 'wb') as f:
                pickle.dump(model, f)
            logger.info(f"✓ Saved model: {path}")
            return str(path)
        except Exception as e:
            logger.error(f"Failed to save model: {e}")
            raise

    def load_model(self, model_name: str) -> Any:
        """Load a model."""
        try:
            path = self.checkpoint_dir / f"{model_name}.pkl"
            with open(path, 'rb') as f:
                model = pickle.load(f)
            logger.info(f"✓ Loaded model: {path}")
            return model
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def save_results(self, results: Dict, filename: str) -> str:
        """Save results."""
        try:
            path = self.checkpoint_dir / f"{filename}.json"

            # Convert numpy types to native Python types
            results_serializable = self._convert_to_serializable(results)

            with open(path, 'w') as f:
                json.dump(results_serializable, f, indent=4)
            logger.info(f"✓ Saved results: {path}")
            return str(path)
        except Exception as e:
            logger.error(f"Failed to save results: {e}")
            raise

    def load_results(self, filename: str) -> Dict:
        """Load results."""
        try:
            path = self.checkpoint_dir / f"{filename}.json"
            with open(path, 'r') as f:
                results = json.load(f)
            logger.info(f"✓ Loaded results: {path}")
            return results
        except Exception as e:
            logger.error(f"Failed to load results: {e}")
            raise

    @staticmethod
    def _convert_to_serializable(obj: Any) -> Any:
        """Recursively convert non-serializable objects."""
        if isinstance(obj, dict):
            return {k: Checkpoint._convert_to_serializable(v) for k, v in obj.items()}
        elif isinstance(obj, (list, tuple)):
            return [Checkpoint._convert_to_serializable(item) for item in obj]
        elif isinstance(obj, (np.integer, np.floating)):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        return obj
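A short usage sketch for Checkpoint, assuming the package is installed; the Ridge toy model and the demo file names are placeholders:

from sklearn.linear_model import Ridge
from stacking_model.utils import Checkpoint

ckpt = Checkpoint("outputs")

# Persist any picklable model plus a plain metrics dict, then restore both
model_path = ckpt.save_model(Ridge().fit([[0.0], [1.0]], [0.0, 1.0]), "ridge_demo")
results_path = ckpt.save_results({"Ridge": {"rmse": 0.0, "mae": 0.0, "r2": 1.0}}, "demo_results")

restored_model = ckpt.load_model("ridge_demo")
restored_results = ckpt.load_results("demo_results")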
src/stacking_model/utils/logger.py (new file, 86 lines)
@@ -0,0 +1,86 @@
import logging
from pathlib import Path


class Logger:
    """Unified logging management"""

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(Logger, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return

        self.logger = logging.getLogger("StackingModel")
        self.logger.setLevel(logging.DEBUG)
        self.logger.handlers.clear()

        # Plain console handler
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        console_format = logging.Formatter(
            '%(asctime)s - %(levelname)-8s - %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
        console_handler.setFormatter(console_format)
        self.logger.addHandler(console_handler)

        self._initialized = True

    def setup_file_handler(self, log_file: str) -> None:
        """Add a file handler."""
        try:
            Path(log_file).parent.mkdir(parents=True, exist_ok=True)
            file_handler = logging.FileHandler(log_file, encoding='utf-8')
            file_handler.setLevel(logging.DEBUG)
            file_format = logging.Formatter(
                '%(asctime)s - %(levelname)-8s - %(message)s',
                datefmt='%Y-%m-%d %H:%M:%S'
            )
            file_handler.setFormatter(file_format)
            self.logger.addHandler(file_handler)
        except Exception as e:
            print(f"Warning: Failed to setup file handler: {e}")

    def info(self, message: str) -> None:
        self.logger.info(message)

    def warning(self, message: str) -> None:
        self.logger.warning(message)

    def error(self, message: str) -> None:
        self.logger.error(message)

    def debug(self, message: str) -> None:
        self.logger.debug(message)

    def section(self, title: str) -> None:
        """Print a section divider."""
        self.info("")
        self.info("=" * 70)
        self.info(f" {title}")
        self.info("=" * 70)

    def subsection(self, title: str) -> None:
        """Print a subsection heading."""
        self.info("")
        self.info(title)
        self.info("-" * 70)

    def success(self, message: str) -> None:
        """Print a success message."""
        self.info(f"[SUCCESS] {message}")

    def failure(self, message: str) -> None:
        """Print a failure message."""
        self.error(f"[FAIL] {message}")


# Singleton instance
logger = Logger()
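A minimal sketch of how the singleton logger is meant to be used, based on the methods above; the log path is illustrative:

from stacking_model.utils import logger

# Console output is configured on import; file logging is opt-in
logger.setup_file_handler("logs/pipeline.log")

logger.section("Stacking Model Pipeline")
logger.subsection("Step 1: Generating Data")
logger.info("Data shape: X=(2000, 20), y=(2000,)")
logger.success("Data generated")        # logged as "[SUCCESS] Data generated"
logger.failure("Something went wrong")  # logged at ERROR level as "[FAIL] ..."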
src/stacking_model/utils/metrics.py (new file, 335 lines)
@@ -0,0 +1,335 @@
import functools
import time
import tracemalloc
from typing import Any, Callable, Optional
from stacking_model.utils.logger import logger


class PerformanceMonitor:
    """Performance monitoring decorators"""

    @staticmethod
    def measure_time(func: Callable) -> Callable:
        """
        Decorator that measures a function's execution time.

        Args:
            func: the decorated function

        Returns:
            The wrapped function

        Example:
            @PerformanceMonitor.measure_time
            def my_function():
                pass
        """
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            start_time = time.perf_counter()

            try:
                result = func(*args, **kwargs)
                return result
            finally:
                end_time = time.perf_counter()
                elapsed_time = end_time - start_time

                # Format the elapsed time for display
                if elapsed_time < 1:
                    time_str = f"{elapsed_time*1000:.2f}ms"
                elif elapsed_time < 60:
                    time_str = f"{elapsed_time:.2f}s"
                else:
                    minutes = int(elapsed_time // 60)
                    seconds = elapsed_time % 60
                    time_str = f"{minutes}m {seconds:.2f}s"

                logger.info(f"[TIME] {func.__name__}() executed in {time_str}")

        return wrapper

    @staticmethod
    def memory_usage(func: Callable) -> Callable:
        """
        Decorator that measures a function's memory usage.

        Args:
            func: the decorated function

        Returns:
            The wrapped function

        Example:
            @PerformanceMonitor.memory_usage
            def my_function():
                pass
        """
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            # Start memory tracing
            tracemalloc.start()

            try:
                result = func(*args, **kwargs)
                return result
            finally:
                # Collect memory statistics
                current, peak = tracemalloc.get_traced_memory()
                tracemalloc.stop()

                # Format the memory usage for display
                current_mb = current / 1024 / 1024
                peak_mb = peak / 1024 / 1024

                logger.info(
                    f"[MEMORY] {func.__name__}() - "
                    f"Current: {current_mb:.2f}MB, Peak: {peak_mb:.2f}MB"
                )

        return wrapper

    @staticmethod
    def measure_time_and_memory(func: Callable) -> Callable:
        """
        Decorator that measures both execution time and memory usage.

        Args:
            func: the decorated function

        Returns:
            The wrapped function

        Example:
            @PerformanceMonitor.measure_time_and_memory
            def my_function():
                pass
        """
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            start_time = time.perf_counter()
            tracemalloc.start()

            try:
                result = func(*args, **kwargs)
                return result
            finally:
                end_time = time.perf_counter()
                current, peak = tracemalloc.get_traced_memory()
                tracemalloc.stop()

                elapsed_time = end_time - start_time
                current_mb = current / 1024 / 1024
                peak_mb = peak / 1024 / 1024

                # Format the elapsed time for display
                if elapsed_time < 1:
                    time_str = f"{elapsed_time*1000:.2f}ms"
                elif elapsed_time < 60:
                    time_str = f"{elapsed_time:.2f}s"
                else:
                    minutes = int(elapsed_time // 60)
                    seconds = elapsed_time % 60
                    time_str = f"{minutes}m {seconds:.2f}s"

                logger.info(
                    f"[PERF] {func.__name__}() - "
                    f"Time: {time_str}, Memory: {current_mb:.2f}MB (Peak: {peak_mb:.2f}MB)"
                )

        return wrapper

    @staticmethod
    def profile(func: Callable, verbose: bool = False) -> Callable:
        """
        Detailed profiling decorator.

        Args:
            func: the decorated function
            verbose: whether to print detailed information

        Returns:
            The wrapped function

        Example:
            @PerformanceMonitor.profile(verbose=True)
            def my_function():
                pass
        """
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            start_time = time.perf_counter()
            tracemalloc.start()

            try:
                result = func(*args, **kwargs)
                return result
            finally:
                end_time = time.perf_counter()
                current, peak = tracemalloc.get_traced_memory()
                tracemalloc.stop()

                elapsed_time = end_time - start_time
                current_mb = current / 1024 / 1024
                peak_mb = peak / 1024 / 1024

                # Format the elapsed time for display
                if elapsed_time < 1:
                    time_str = f"{elapsed_time*1000:.2f}ms"
                elif elapsed_time < 60:
                    time_str = f"{elapsed_time:.2f}s"
                else:
                    minutes = int(elapsed_time // 60)
                    seconds = elapsed_time % 60
                    time_str = f"{minutes}m {seconds:.2f}s"

                if verbose:
                    logger.subsection(f"Performance Profile: {func.__name__}()")
                    logger.info(f"  Execution Time: {time_str}")
                    logger.info(f"  Current Memory: {current_mb:.2f}MB")
                    logger.info(f"  Peak Memory: {peak_mb:.2f}MB")

                    # Include argument information
                    if args or kwargs:
                        logger.info(f"  Arguments:")
                        if args:
                            logger.info(f"    - args: {len(args)} positional arguments")
                        if kwargs:
                            logger.info(f"    - kwargs: {', '.join(kwargs.keys())}")
                else:
                    logger.info(
                        f"[PROFILE] {func.__name__}() - "
                        f"Time: {time_str}, Memory: {current_mb:.2f}MB"
                    )

        return wrapper


class Timer:
    """Context-manager style timer"""

    def __init__(self, name: str = "Operation"):
        """
        Initialize the timer.

        Args:
            name: name of the operation

        Example:
            with Timer("Data loading"):
                # do work
                pass
        """
        self.name = name
        self.start_time: Optional[float] = None
        self.end_time: Optional[float] = None

    def __enter__(self):
        """Enter the context."""
        self.start_time = time.perf_counter()
        logger.info(f"[START] {self.name}...")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit the context."""
        self.end_time = time.perf_counter()

        # Type check: make sure start_time is not None
        if self.start_time is None:
            logger.warning(f"Timer for {self.name} was not properly started")
            return False

        elapsed_time = self.end_time - self.start_time

        # Format the elapsed time for display
        if elapsed_time < 1:
            time_str = f"{elapsed_time*1000:.2f}ms"
        elif elapsed_time < 60:
            time_str = f"{elapsed_time:.2f}s"
        else:
            minutes = int(elapsed_time // 60)
            seconds = elapsed_time % 60
            time_str = f"{minutes}m {seconds:.2f}s"

        if exc_type is not None:
            logger.failure(f"{self.name} failed after {time_str}")
        else:
            logger.success(f"{self.name} completed in {time_str}")

        return False

    @property
    def elapsed(self) -> float:
        """Elapsed time in seconds."""
        if self.start_time is None:
            return 0.0
        if self.end_time is None:
            return time.perf_counter() - self.start_time
        return self.end_time - self.start_time


class MemoryMonitor:
    """Memory-monitoring context manager"""

    def __init__(self, name: str = "Operation"):
        """
        Initialize the memory monitor.

        Args:
            name: name of the operation

        Example:
            with MemoryMonitor("Data processing"):
                # do work
                pass
        """
        self.name = name
        self.start_memory: Optional[int] = None
        self.peak_memory: Optional[int] = None

    def __enter__(self):
        """Enter the context."""
        tracemalloc.start()
        self.start_memory = tracemalloc.get_traced_memory()[0]
        logger.info(f"[MEMORY] Starting memory monitoring for {self.name}...")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit the context."""
        current, peak = tracemalloc.get_traced_memory()
        tracemalloc.stop()

        current_mb = current / 1024 / 1024
        peak_mb = peak / 1024 / 1024

        # Type check: make sure start_memory is not None
        if self.start_memory is None:
            logger.warning(f"Memory monitor for {self.name} was not properly started")
            return False

        start_mb = self.start_memory / 1024 / 1024
        delta_mb = (current - self.start_memory) / 1024 / 1024

        if exc_type is not None:
            logger.failure(
                f"{self.name} failed - Current: {current_mb:.2f}MB, Peak: {peak_mb:.2f}MB"
            )
        else:
            logger.success(
                f"{self.name} - Start: {start_mb:.2f}MB, Current: {current_mb:.2f}MB, "
                f"Peak: {peak_mb:.2f}MB, Delta: {delta_mb:+.2f}MB"
            )

        return False

    @property
    def current(self) -> float:
        """Current memory usage in MB."""
        current, _ = tracemalloc.get_traced_memory()
        return current / 1024 / 1024

    @property
    def peak(self) -> float:
        """Peak memory usage in MB."""
        _, peak = tracemalloc.get_traced_memory()
        return peak / 1024 / 1024
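A usage sketch covering the decorator and the two context managers, assuming the package is importable; the workloads below are placeholders:

import time
from stacking_model.utils import PerformanceMonitor, Timer, MemoryMonitor

@PerformanceMonitor.measure_time_and_memory
def build_features():
    # Placeholder workload standing in for real feature engineering
    return [i * i for i in range(200_000)]

features = build_features()  # logs "[PERF] build_features() - Time: ..., Memory: ..."

with Timer("Feature selection") as timer:
    time.sleep(0.2)          # logs [START] / [SUCCESS] lines around the block
print(f"elapsed: {timer.elapsed:.2f}s")

with MemoryMonitor("Buffer allocation"):
    buffer = bytearray(10_000_000)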
src/stacking_model/utils/validator.py (new file, 61 lines)
@@ -0,0 +1,61 @@
import numpy as np
import pandas as pd
from stacking_model.exceptions import ValidationException
from stacking_model.utils.logger import logger


class Validator:
    """Data and parameter validation"""

    @staticmethod
    def validate_data(X: pd.DataFrame, y: pd.Series) -> None:
        """Validate the input data."""
        if X is None or y is None:
            raise ValidationException("Data cannot be None")

        if len(X) != len(y):
            raise ValidationException(
                f"X and y must have same length. Got {len(X)} and {len(y)}"
            )

        if len(X) < 10:
            raise ValidationException(
                f"Data too small. Need at least 10 samples, got {len(X)}"
            )

        if X.isnull().any().any():
            logger.warning("X contains NaN values")

        if y.isnull().any():
            logger.warning("y contains NaN values")

        logger.info(f"✓ Data validation passed: {X.shape[0]} samples, {X.shape[1]} features")

    @staticmethod
    def validate_train_test_split(test_size: float) -> None:
        """Validate the test-set ratio."""
        if not 0 < test_size < 1:
            raise ValidationException(
                f"test_size must be between 0 and 1. Got {test_size}"
            )

    @staticmethod
    def validate_model_params(model_name: str, params: dict) -> None:
        """Validate model parameters."""
        if not isinstance(params, dict):
            raise ValidationException("Model params must be a dictionary")

        logger.info(f"✓ Model params validation passed: {model_name}")

    @staticmethod
    def validate_predictions(y_true, y_pred) -> None:
        """Validate predictions."""
        if len(y_true) != len(y_pred):
            raise ValidationException(
                f"Prediction length mismatch: {len(y_true)} vs {len(y_pred)}"
            )

        if np.isnan(y_pred).any():
            raise ValidationException("Predictions contain NaN values")

        logger.info(f"✓ Predictions validation passed")
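A small sketch of the validator in use, with synthetic data; the column names and values are arbitrary:

import numpy as np
import pandas as pd
from stacking_model.exceptions import ValidationException
from stacking_model.utils import Validator

X = pd.DataFrame(np.random.rand(50, 3), columns=["f0", "f1", "f2"])
y = pd.Series(np.random.rand(50))

Validator.validate_data(X, y)             # logs "Data validation passed: 50 samples, 3 features"
Validator.validate_train_test_split(0.2)  # passes silently

try:
    Validator.validate_predictions(y, np.full(50, np.nan))
except ValidationException as e:
    print(f"rejected: {e}")               # "Predictions contain NaN values"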