Different types of base models adapted for each agent.
This commit is contained in:
@@ -7,9 +7,9 @@ import json
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from typing import Optional, Dict, Any, List, Union
|
||||
|
||||
from .llms import DeepSeekLLM, OpenAILLM, BaseLLM
|
||||
from .llms import DeepSeekLLM, OpenAILLM, KimiLLM, BaseLLM
|
||||
from .nodes import (
|
||||
ReportStructureNode,
|
||||
FirstSearchNode,
|
||||
@@ -19,7 +19,7 @@ from .nodes import (
|
||||
ReportFormattingNode
|
||||
)
|
||||
from .state import State
|
||||
from .tools import MediaCrawlerDB, DBResponse, keyword_optimizer
|
||||
from .tools import MediaCrawlerDB, DBResponse, keyword_optimizer, multilingual_sentiment_analyzer
|
||||
from .utils import Config, load_config, format_search_results_for_prompt
|
||||
|
||||
|
||||
@@ -50,6 +50,9 @@ class DeepSearchAgent:
|
||||
# 初始化搜索工具集
|
||||
self.search_agency = MediaCrawlerDB()
|
||||
|
||||
# 初始化情感分析器
|
||||
self.sentiment_analyzer = multilingual_sentiment_analyzer
|
||||
|
||||
# 初始化节点
|
||||
self._initialize_nodes()
|
||||
|
||||
@@ -62,6 +65,7 @@ class DeepSearchAgent:
|
||||
print(f"Deep Search Agent 已初始化")
|
||||
print(f"使用LLM: {self.llm_client.get_model_info()}")
|
||||
print(f"搜索工具集: MediaCrawlerDB (支持5种本地数据库查询工具)")
|
||||
print(f"情感分析: WeiboMultilingualSentiment (支持22种语言的情感分析)")
|
||||
|
||||
def _initialize_llm(self) -> BaseLLM:
|
||||
"""初始化LLM客户端"""
|
||||
@@ -75,6 +79,11 @@ class DeepSearchAgent:
|
||||
api_key=self.config.openai_api_key,
|
||||
model_name=self.config.openai_model
|
||||
)
|
||||
elif self.config.default_llm_provider == "kimi":
|
||||
return KimiLLM(
|
||||
api_key=self.config.kimi_api_key,
|
||||
model_name=self.config.kimi_model
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"不支持的LLM提供商: {self.config.default_llm_provider}")
|
||||
|
||||
@@ -113,7 +122,7 @@ class DeepSearchAgent:
|
||||
|
||||
def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> DBResponse:
|
||||
"""
|
||||
执行指定的数据库查询工具(集成关键词优化中间件)
|
||||
执行指定的数据库查询工具(集成关键词优化中间件和情感分析)
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称,可选值:
|
||||
@@ -122,11 +131,13 @@ class DeepSearchAgent:
|
||||
- "search_topic_by_date": 按日期搜索话题
|
||||
- "get_comments_for_topic": 获取话题评论
|
||||
- "search_topic_on_platform": 平台定向搜索
|
||||
- "analyze_sentiment": 对查询结果进行情感分析
|
||||
query: 搜索关键词/话题
|
||||
**kwargs: 额外参数(如start_date, end_date, platform, limit等)
|
||||
**kwargs: 额外参数(如start_date, end_date, platform, limit, enable_sentiment等)
|
||||
enable_sentiment: 是否自动对搜索结果进行情感分析(默认True)
|
||||
|
||||
Returns:
|
||||
DBResponse对象
|
||||
DBResponse对象(可能包含情感分析结果)
|
||||
"""
|
||||
print(f" → 执行数据库查询工具: {tool_name}")
|
||||
|
||||
@@ -134,7 +145,36 @@ class DeepSearchAgent:
|
||||
if tool_name == "search_hot_content":
|
||||
time_period = kwargs.get("time_period", "week")
|
||||
limit = kwargs.get("limit", 100)
|
||||
return self.search_agency.search_hot_content(time_period=time_period, limit=limit)
|
||||
response = self.search_agency.search_hot_content(time_period=time_period, limit=limit)
|
||||
|
||||
# 检查是否需要进行情感分析
|
||||
enable_sentiment = kwargs.get("enable_sentiment", True)
|
||||
if enable_sentiment and response.results and len(response.results) > 0:
|
||||
print(f" 🎭 开始对热点内容进行情感分析...")
|
||||
sentiment_analysis = self._perform_sentiment_analysis(response.results)
|
||||
if sentiment_analysis:
|
||||
# 将情感分析结果添加到响应的parameters中
|
||||
response.parameters["sentiment_analysis"] = sentiment_analysis
|
||||
print(f" ✅ 情感分析完成")
|
||||
|
||||
return response
|
||||
|
||||
# 独立情感分析工具
|
||||
if tool_name == "analyze_sentiment":
|
||||
texts = kwargs.get("texts", query) # 可以通过texts参数传递,或使用query
|
||||
sentiment_result = self.analyze_sentiment_only(texts)
|
||||
|
||||
# 构建DBResponse格式的响应
|
||||
return DBResponse(
|
||||
tool_name="analyze_sentiment",
|
||||
parameters={
|
||||
"texts": texts if isinstance(texts, list) else [texts],
|
||||
**kwargs
|
||||
},
|
||||
results=[], # 情感分析不返回搜索结果
|
||||
results_count=0,
|
||||
metadata=sentiment_result
|
||||
)
|
||||
|
||||
# 对于需要搜索词的工具,使用关键词优化中间件
|
||||
optimized_response = keyword_optimizer.optimize_keywords(
|
||||
@@ -154,31 +194,35 @@ class DeepSearchAgent:
|
||||
|
||||
try:
|
||||
if tool_name == "search_topic_globally":
|
||||
limit_per_table = kwargs.get("limit_per_table", 100)
|
||||
# 使用配置文件中的默认值,忽略agent提供的limit_per_table参数
|
||||
limit_per_table = self.config.default_search_topic_globally_limit_per_table
|
||||
response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=limit_per_table)
|
||||
elif tool_name == "search_topic_by_date":
|
||||
start_date = kwargs.get("start_date")
|
||||
end_date = kwargs.get("end_date")
|
||||
limit_per_table = kwargs.get("limit_per_table", 100)
|
||||
# 使用配置文件中的默认值,忽略agent提供的limit_per_table参数
|
||||
limit_per_table = self.config.default_search_topic_by_date_limit_per_table
|
||||
if not start_date or not end_date:
|
||||
raise ValueError("search_topic_by_date工具需要start_date和end_date参数")
|
||||
response = self.search_agency.search_topic_by_date(topic=keyword, start_date=start_date, end_date=end_date, limit_per_table=limit_per_table)
|
||||
elif tool_name == "get_comments_for_topic":
|
||||
limit = kwargs.get("limit", 500) // len(optimized_response.optimized_keywords)
|
||||
# 使用配置文件中的默认值,按关键词数量分配,但保证最小值
|
||||
limit = self.config.default_get_comments_for_topic_limit // len(optimized_response.optimized_keywords)
|
||||
limit = max(limit, 50)
|
||||
response = self.search_agency.get_comments_for_topic(topic=keyword, limit=limit)
|
||||
elif tool_name == "search_topic_on_platform":
|
||||
platform = kwargs.get("platform")
|
||||
start_date = kwargs.get("start_date")
|
||||
end_date = kwargs.get("end_date")
|
||||
limit = kwargs.get("limit", 200) // len(optimized_response.optimized_keywords)
|
||||
# 使用配置文件中的默认值,按关键词数量分配,但保证最小值
|
||||
limit = self.config.default_search_topic_on_platform_limit // len(optimized_response.optimized_keywords)
|
||||
limit = max(limit, 30)
|
||||
if not platform:
|
||||
raise ValueError("search_topic_on_platform工具需要platform参数")
|
||||
response = self.search_agency.search_topic_on_platform(platform=platform, topic=keyword, start_date=start_date, end_date=end_date, limit=limit)
|
||||
else:
|
||||
print(f" 未知的搜索工具: {tool_name},使用默认全局搜索")
|
||||
response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=100)
|
||||
response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=self.config.default_search_topic_globally_limit_per_table)
|
||||
|
||||
# 收集结果
|
||||
if response.results:
|
||||
@@ -209,6 +253,16 @@ class DeepSearchAgent:
|
||||
results_count=len(unique_results)
|
||||
)
|
||||
|
||||
# 检查是否需要进行情感分析
|
||||
enable_sentiment = kwargs.get("enable_sentiment", True)
|
||||
if enable_sentiment and unique_results and len(unique_results) > 0:
|
||||
print(f" 🎭 开始对搜索结果进行情感分析...")
|
||||
sentiment_analysis = self._perform_sentiment_analysis(unique_results)
|
||||
if sentiment_analysis:
|
||||
# 将情感分析结果添加到响应的parameters中
|
||||
integrated_response.parameters["sentiment_analysis"] = sentiment_analysis
|
||||
print(f" ✅ 情感分析完成")
|
||||
|
||||
return integrated_response
|
||||
|
||||
def _deduplicate_results(self, results: List) -> List:
|
||||
@@ -227,6 +281,99 @@ class DeepSearchAgent:
|
||||
|
||||
return unique_results
|
||||
|
||||
def _perform_sentiment_analysis(self, results: List) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
对搜索结果执行情感分析
|
||||
|
||||
Args:
|
||||
results: 搜索结果列表
|
||||
|
||||
Returns:
|
||||
情感分析结果字典,如果失败则返回None
|
||||
"""
|
||||
try:
|
||||
# 初始化情感分析器(如果尚未初始化)
|
||||
if not self.sentiment_analyzer.is_initialized:
|
||||
print(" 初始化情感分析模型...")
|
||||
if not self.sentiment_analyzer.initialize():
|
||||
print(" ❌ 情感分析模型初始化失败")
|
||||
return None
|
||||
|
||||
# 将查询结果转换为字典格式
|
||||
results_dict = []
|
||||
for result in results:
|
||||
result_dict = {
|
||||
"content": result.title_or_content,
|
||||
"platform": result.platform,
|
||||
"author": result.author_nickname,
|
||||
"url": result.url,
|
||||
"publish_time": str(result.publish_time) if result.publish_time else None
|
||||
}
|
||||
results_dict.append(result_dict)
|
||||
|
||||
# 执行情感分析
|
||||
sentiment_analysis = self.sentiment_analyzer.analyze_query_results(
|
||||
query_results=results_dict,
|
||||
text_field="content",
|
||||
min_confidence=0.5
|
||||
)
|
||||
|
||||
return sentiment_analysis.get("sentiment_analysis")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ 情感分析过程中发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
def analyze_sentiment_only(self, texts: Union[str, List[str]]) -> Dict[str, Any]:
    """
    Run sentiment analysis on raw text, independent of any search.

    Args:
        texts: A single string, or a list of strings analyzed as a batch.

    Returns:
        A dict with ``success`` and ``results``. A single string yields
        ``total_analyzed == 1``; a batch additionally reports
        ``success_count``, ``failed_count`` and ``average_confidence``.
        On failure the dict holds ``success=False``, an ``error`` message
        and an empty ``results`` list.
    """
    print(" → 执行独立情感分析")

    def failed(reason: str) -> Dict[str, Any]:
        # Uniform error payload shared by every failure path below.
        return {"success": False, "error": reason, "results": []}

    try:
        analyzer = self.sentiment_analyzer

        # The model is loaded on demand the first time it is needed.
        if not analyzer.is_initialized:
            print(" 初始化情感分析模型...")
            if not analyzer.initialize():
                return failed("情感分析模型初始化失败")

        if isinstance(texts, str):
            single = analyzer.analyze_single_text(texts)
            return {
                "success": True,
                "total_analyzed": 1,
                "results": [single.__dict__],
            }

        batch = analyzer.analyze_batch(texts, show_progress=True)
        return {
            "success": True,
            "total_analyzed": batch.total_processed,
            "success_count": batch.success_count,
            "failed_count": batch.failed_count,
            "average_confidence": batch.average_confidence,
            "results": [entry.__dict__ for entry in batch.results],
        }

    except Exception as e:
        print(f" ❌ 情感分析过程中发生错误: {str(e)}")
        return failed(str(e))
|
||||
|
||||
def research(self, query: str, save_report: bool = True) -> str:
|
||||
"""
|
||||
执行深度研究
|
||||
@@ -356,17 +503,23 @@ class DeepSearchAgent:
|
||||
print(f" ⚠️ search_topic_on_platform工具缺少平台参数,改用全局搜索")
|
||||
search_tool = "search_topic_globally"
|
||||
|
||||
# 处理限制参数
|
||||
# 处理限制参数,使用配置文件中的默认值而不是agent提供的参数
|
||||
if search_tool == "search_hot_content":
|
||||
time_period = search_output.get("time_period", "week")
|
||||
limit = search_output.get("limit", 100)
|
||||
limit = self.config.default_search_hot_content_limit
|
||||
search_kwargs["time_period"] = time_period
|
||||
search_kwargs["limit"] = limit
|
||||
elif search_tool in ["search_topic_globally", "search_topic_by_date"]:
|
||||
limit_per_table = search_output.get("limit_per_table", 100)
|
||||
if search_tool == "search_topic_globally":
|
||||
limit_per_table = self.config.default_search_topic_globally_limit_per_table
|
||||
else: # search_topic_by_date
|
||||
limit_per_table = self.config.default_search_topic_by_date_limit_per_table
|
||||
search_kwargs["limit_per_table"] = limit_per_table
|
||||
elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]:
|
||||
limit = search_output.get("limit", 200)
|
||||
if search_tool == "get_comments_for_topic":
|
||||
limit = self.config.default_get_comments_for_topic_limit
|
||||
else: # search_topic_on_platform
|
||||
limit = self.config.default_search_topic_on_platform_limit
|
||||
search_kwargs["limit"] = limit
|
||||
|
||||
search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs)
|
||||
@@ -374,8 +527,11 @@ class DeepSearchAgent:
|
||||
# 转换为兼容格式
|
||||
search_results = []
|
||||
if search_response and search_response.results:
|
||||
# 每种搜索工具都有其特定的结果数量,这里取前100个作为上限
|
||||
max_results = min(len(search_response.results), 100)
|
||||
# 使用配置文件控制传递给LLM的结果数量,0表示不限制
|
||||
if self.config.max_search_results_for_llm > 0:
|
||||
max_results = min(len(search_response.results), self.config.max_search_results_for_llm)
|
||||
else:
|
||||
max_results = len(search_response.results) # 不限制,传递所有结果
|
||||
for result in search_response.results[:max_results]:
|
||||
search_results.append({
|
||||
'title': result.title_or_content,
|
||||
@@ -479,14 +635,23 @@ class DeepSearchAgent:
|
||||
# 处理限制参数
|
||||
if search_tool == "search_hot_content":
|
||||
time_period = reflection_output.get("time_period", "week")
|
||||
limit = reflection_output.get("limit", 10)
|
||||
# 使用配置文件中的默认值,不允许agent控制limit参数
|
||||
limit = self.config.default_search_hot_content_limit
|
||||
search_kwargs["time_period"] = time_period
|
||||
search_kwargs["limit"] = limit
|
||||
elif search_tool in ["search_topic_globally", "search_topic_by_date"]:
|
||||
limit_per_table = reflection_output.get("limit_per_table", 5)
|
||||
# 使用配置文件中的默认值,不允许agent控制limit_per_table参数
|
||||
if search_tool == "search_topic_globally":
|
||||
limit_per_table = self.config.default_search_topic_globally_limit_per_table
|
||||
else: # search_topic_by_date
|
||||
limit_per_table = self.config.default_search_topic_by_date_limit_per_table
|
||||
search_kwargs["limit_per_table"] = limit_per_table
|
||||
elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]:
|
||||
limit = reflection_output.get("limit", 20)
|
||||
# 使用配置文件中的默认值,不允许agent控制limit参数
|
||||
if search_tool == "get_comments_for_topic":
|
||||
limit = self.config.default_get_comments_for_topic_limit
|
||||
else: # search_topic_on_platform
|
||||
limit = self.config.default_search_topic_on_platform_limit
|
||||
search_kwargs["limit"] = limit
|
||||
|
||||
search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs)
|
||||
@@ -494,8 +659,11 @@ class DeepSearchAgent:
|
||||
# 转换为兼容格式
|
||||
search_results = []
|
||||
if search_response and search_response.results:
|
||||
# 每种搜索工具都有其特定的结果数量,这里取前100个作为上限
|
||||
max_results = min(len(search_response.results), 100)
|
||||
# 使用配置文件控制传递给LLM的结果数量,0表示不限制
|
||||
if self.config.max_search_results_for_llm > 0:
|
||||
max_results = min(len(search_response.results), self.config.max_search_results_for_llm)
|
||||
else:
|
||||
max_results = len(search_response.results) # 不限制,传递所有结果
|
||||
for result in search_response.results[:max_results]:
|
||||
search_results.append({
|
||||
'title': result.title_or_content,
|
||||
|
||||
@@ -6,5 +6,6 @@ LLM调用模块
|
||||
from .base import BaseLLM
|
||||
from .deepseek import DeepSeekLLM
|
||||
from .openai_llm import OpenAILLM
|
||||
from .kimi import KimiLLM
|
||||
|
||||
__all__ = ["BaseLLM", "DeepSeekLLM", "OpenAILLM"]
|
||||
__all__ = ["BaseLLM", "DeepSeekLLM", "OpenAILLM", "KimiLLM"]
|
||||
|
||||
144
InsightEngine/llms/kimi.py
Normal file
144
InsightEngine/llms/kimi.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
Kimi LLM实现
|
||||
使用Moonshot AI的Kimi API进行文本生成
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
from openai import OpenAI
|
||||
# 假设 .base 模块和 BaseLLM 类已存在
|
||||
from .base import BaseLLM
|
||||
|
||||
|
||||
class KimiLLM(BaseLLM):
    """Kimi (Moonshot AI) chat client, accessed through the OpenAI-compatible SDK."""

    # Single source of truth for the endpoint: used both to build the client
    # and to report it from get_model_info().
    API_BASE_URL = "https://api.moonshot.cn/v1"

    # (prompt length threshold in characters, default max_tokens), checked
    # from the largest threshold down; the first threshold exceeded wins.
    _MAX_TOKENS_TIERS = (
        (100000, 81920),
        (50000, 40960),
        (20000, 16384),
        (5000, 8192),
    )
    # Default max_tokens for prompts at or below the smallest threshold.
    _SHORT_PROMPT_MAX_TOKENS = 4096

    # Sampling options forwarded to the API only when the caller passes them.
    _OPTIONAL_PARAMS = ("top_p", "presence_penalty", "frequency_penalty", "stop")

    def __init__(self, api_key: Optional[str] = None, model_name: Optional[str] = None):
        """
        Create the Kimi client.

        Args:
            api_key: Kimi API key. When ``None``, the ``KIMI_API_KEY``
                environment variable is used instead.
            model_name: Model to call. Falls back to ``get_default_model()``
                when omitted.

        Raises:
            ValueError: If no key is given and ``KIMI_API_KEY`` is unset or empty.
        """
        if api_key is None:
            api_key = os.getenv("KIMI_API_KEY")
            if not api_key:
                raise ValueError("Kimi API Key未找到!请设置KIMI_API_KEY环境变量或在初始化时提供")

        super().__init__(api_key, model_name)

        # The OpenAI SDK is pointed at Kimi's endpoint.
        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.API_BASE_URL,
        )

        self.default_model = model_name or self.get_default_model()

    def get_default_model(self) -> str:
        """Return the model name used when none is configured."""
        return "kimi-k2-0711-preview"

    @classmethod
    def _default_max_tokens(cls, input_length: int) -> int:
        """Pick the default max_tokens for a prompt of ``input_length`` characters."""
        for threshold, budget in cls._MAX_TOKENS_TIERS:
            if input_length > threshold:
                return budget
        return cls._SHORT_PROMPT_MAX_TOKENS

    def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
        """
        Send one system + user exchange to Kimi and return the reply text.

        When ``max_tokens`` is not supplied, a default is chosen from the
        combined prompt length. That length is measured in characters, not
        tokens, so the tiers are only a rough proxy.

        Args:
            system_prompt: System prompt.
            user_prompt: User input.
            **kwargs: Optional ``temperature`` (default 0.6), ``max_tokens``,
                ``top_p``, ``presence_penalty``, ``frequency_penalty``, ``stop``.

        Returns:
            The reply content passed through ``validate_response``, or ``""``
            when the response carries no choice or message.

        Raises:
            Exception: Any error from the API call is logged and re-raised.
        """
        try:
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ]

            input_length = len(system_prompt) + len(user_prompt)

            params = {
                "model": self.default_model,
                "messages": messages,
                # 0.6 follows the original author's note that Kimi recommends it.
                "temperature": kwargs.get("temperature", 0.6),
                "max_tokens": kwargs.get(
                    "max_tokens", self._default_max_tokens(input_length)
                ),
                "stream": False,
            }

            # Forward optional sampling parameters only when explicitly given.
            params.update(
                {name: kwargs[name] for name in self._OPTIONAL_PARAMS if name in kwargs}
            )

            # Debug output for Kimi calls.
            print(f"[Kimi] 输入长度: {input_length}, 使用max_tokens: {params['max_tokens']}")

            response = self.client.chat.completions.create(**params)

            if response.choices and response.choices[0].message:
                content = response.choices[0].message.content
                return self.validate_response(content)
            return ""

        except Exception as e:
            print(f"Kimi API调用错误: {str(e)}")
            # Bare raise re-raises the active exception with its traceback intact.
            raise

    def get_model_info(self) -> Dict[str, Any]:
        """
        Describe the provider, model and endpoint in use.

        Returns:
            Dict with ``provider``, ``model``, ``api_base`` and
            ``max_context_length`` entries.
        """
        return {
            "provider": "Kimi",
            "model": self.default_model,
            "api_base": self.API_BASE_URL,
            "max_context_length": "长文本支持(200K+ tokens)",
        }

    def invoke_long_context(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
        """
        Compatibility entry point for long-text requests; delegates to ``invoke``.

        Sets ``max_tokens`` to 16384 unless the caller provides a value.
        Because that default is passed explicitly, it replaces the
        length-based default ``invoke`` would otherwise choose, including
        the larger budgets used for prompts above 50000 characters.

        Args:
            system_prompt: System prompt.
            user_prompt: User input.
            **kwargs: Same options as ``invoke``.

        Returns:
            The reply text produced by ``invoke``.
        """
        kwargs.setdefault("max_tokens", 16384)
        return self.invoke(system_prompt, user_prompt, **kwargs)
|
||||
@@ -39,8 +39,8 @@ output_schema_first_search = {
|
||||
"end_date": {"type": "string", "description": "结束日期,格式YYYY-MM-DD,search_topic_by_date和search_topic_on_platform工具可能需要"},
|
||||
"platform": {"type": "string", "description": "平台名称,search_topic_on_platform工具必需,可选值:bilibili, weibo, douyin, kuaishou, xhs, zhihu, tieba"},
|
||||
"time_period": {"type": "string", "description": "时间周期,search_hot_content工具可选,可选值:24h, week, year"},
|
||||
"limit": {"type": "integer", "description": "结果数量限制,各工具可选参数"},
|
||||
"limit_per_table": {"type": "integer", "description": "每表结果数量限制,search_topic_globally和search_topic_by_date工具可选"}
|
||||
"enable_sentiment": {"type": "boolean", "description": "是否启用自动情感分析,默认为true,适用于除analyze_sentiment外的所有搜索工具"},
|
||||
"texts": {"type": "array", "items": {"type": "string"}, "description": "文本列表,仅用于analyze_sentiment工具"}
|
||||
},
|
||||
"required": ["search_query", "search_tool", "reasoning"]
|
||||
}
|
||||
@@ -88,8 +88,8 @@ output_schema_reflection = {
|
||||
"end_date": {"type": "string", "description": "结束日期,格式YYYY-MM-DD,search_topic_by_date和search_topic_on_platform工具可能需要"},
|
||||
"platform": {"type": "string", "description": "平台名称,search_topic_on_platform工具必需,可选值:bilibili, weibo, douyin, kuaishou, xhs, zhihu, tieba"},
|
||||
"time_period": {"type": "string", "description": "时间周期,search_hot_content工具可选,可选值:24h, week, year"},
|
||||
"limit": {"type": "integer", "description": "结果数量限制,各工具可选参数"},
|
||||
"limit_per_table": {"type": "integer", "description": "每表结果数量限制,search_topic_globally和search_topic_by_date工具可选"}
|
||||
"enable_sentiment": {"type": "boolean", "description": "是否启用自动情感分析,默认为true,适用于除analyze_sentiment外的所有搜索工具"},
|
||||
"texts": {"type": "array", "items": {"type": "string"}, "description": "文本列表,仅用于analyze_sentiment工具"}
|
||||
},
|
||||
"required": ["search_query", "search_tool", "reasoning"]
|
||||
}
|
||||
@@ -155,34 +155,40 @@ SYSTEM_PROMPT_FIRST_SEARCH = f"""
|
||||
{json.dumps(input_schema_first_search, indent=2, ensure_ascii=False)}
|
||||
</INPUT JSON SCHEMA>
|
||||
|
||||
你可以使用以下5种专业的本地舆情数据库查询工具来挖掘真实的民意和公众观点:
|
||||
你可以使用以下6种专业的本地舆情数据库查询工具来挖掘真实的民意和公众观点:
|
||||
|
||||
1. **search_hot_content** - 查找热点内容工具
|
||||
- 适用于:挖掘当前最受关注的舆情事件和话题
|
||||
- 特点:基于真实的点赞、评论、分享数据发现热门话题
|
||||
- 参数:time_period ('24h', 'week', 'year'),limit(数量限制)
|
||||
- 特点:基于真实的点赞、评论、分享数据发现热门话题,自动进行情感分析
|
||||
- 参数:time_period ('24h', 'week', 'year'),limit(数量限制),enable_sentiment(是否启用情感分析,默认True)
|
||||
|
||||
2. **search_topic_globally** - 全局话题搜索工具
|
||||
- 适用于:全面了解公众对特定话题的讨论和观点
|
||||
- 特点:覆盖B站、微博、抖音、快手、小红书、知乎、贴吧等主流平台的真实用户声音
|
||||
- 参数:limit_per_table(每个表的结果数量限制)
|
||||
- 特点:覆盖B站、微博、抖音、快手、小红书、知乎、贴吧等主流平台的真实用户声音,自动进行情感分析
|
||||
- 参数:limit_per_table(每个表的结果数量限制),enable_sentiment(是否启用情感分析,默认True)
|
||||
|
||||
3. **search_topic_by_date** - 按日期搜索话题工具
|
||||
- 适用于:追踪舆情事件的时间线发展和公众情绪变化
|
||||
- 特点:精确的时间范围控制,适合分析舆情演变过程
|
||||
- 特点:精确的时间范围控制,适合分析舆情演变过程,自动进行情感分析
|
||||
- 特殊要求:需要提供start_date和end_date参数,格式为'YYYY-MM-DD'
|
||||
- 参数:limit_per_table(每个表的结果数量限制)
|
||||
- 参数:limit_per_table(每个表的结果数量限制),enable_sentiment(是否启用情感分析,默认True)
|
||||
|
||||
4. **get_comments_for_topic** - 获取话题评论工具
|
||||
- 适用于:深度挖掘网民的真实态度、情感和观点
|
||||
- 特点:直接获取用户评论,了解民意走向和情感倾向
|
||||
- 参数:limit(评论总数量限制)
|
||||
- 特点:直接获取用户评论,了解民意走向和情感倾向,自动进行情感分析
|
||||
- 参数:limit(评论总数量限制),enable_sentiment(是否启用情感分析,默认True)
|
||||
|
||||
5. **search_topic_on_platform** - 平台定向搜索工具
|
||||
- 适用于:分析特定社交平台用户群体的观点特征
|
||||
- 特点:针对不同平台用户群体的观点差异进行精准分析
|
||||
- 特点:针对不同平台用户群体的观点差异进行精准分析,自动进行情感分析
|
||||
- 特殊要求:需要提供platform参数,可选start_date和end_date
|
||||
- 参数:platform(必须),start_date, end_date(可选),limit(数量限制)
|
||||
- 参数:platform(必须),start_date, end_date(可选),limit(数量限制),enable_sentiment(是否启用情感分析,默认True)
|
||||
|
||||
6. **analyze_sentiment** - 多语言情感分析工具
|
||||
- 适用于:对文本内容进行专门的情感倾向分析
|
||||
- 特点:支持中文、英文、西班牙文、阿拉伯文、日文、韩文等22种语言的情感分析,输出5级情感等级(非常负面、负面、中性、正面、非常正面)
|
||||
- 参数:texts(文本或文本列表),query也可用作单个文本输入
|
||||
- 用途:当搜索结果的情感倾向不明确或需要专门的情感分析时使用
|
||||
|
||||
**你的核心使命:挖掘真实的民意和人情味**
|
||||
|
||||
@@ -195,11 +201,16 @@ SYSTEM_PROMPT_FIRST_SEARCH = f"""
|
||||
- **贴近生活语言**:用简单、直接、口语化的词汇
|
||||
- **包含情感词汇**:网民常用的褒贬词、情绪词
|
||||
- **考虑话题热词**:相关的网络流行语、缩写、昵称
|
||||
4. **参数优化配置**:
|
||||
4. **情感分析策略选择**:
|
||||
- **自动情感分析**:默认启用(enable_sentiment: true),适用于搜索工具,能自动分析搜索结果的情感倾向
|
||||
- **专门情感分析**:当需要对特定文本进行详细情感分析时,使用analyze_sentiment工具
|
||||
- **关闭情感分析**:在某些特殊情况下(如纯事实性内容),可设置enable_sentiment: false
|
||||
5. **参数优化配置**:
|
||||
- search_topic_by_date: 必须提供start_date和end_date参数(格式:YYYY-MM-DD)
|
||||
- search_topic_on_platform: 必须提供platform参数(bilibili, weibo, douyin, kuaishou, xhs, zhihu, tieba之一)
|
||||
- 其他工具:合理配置limit参数以获取足够的样本(建议:search_hot_content limit>=100,search_topic_globally limit_per_table>=50,search_topic_by_date limit_per_table>=100,get_comments_for_topic limit>=500,search_topic_on_platform limit>=200)
|
||||
5. **阐述选择理由**:说明为什么这样的查询能够获得最真实的民意反馈
|
||||
- analyze_sentiment: 使用texts参数提供文本列表,或使用search_query作为单个文本
|
||||
- 系统自动配置数据量参数,无需手动设置limit或limit_per_table参数
|
||||
6. **阐述选择理由**:说明为什么这样的查询和情感分析策略能够获得最真实的民意反馈
|
||||
|
||||
**搜索词设计核心原则**:
|
||||
- **想象网友怎么说**:如果你是个普通网友,你会怎么讨论这个话题?
|
||||
@@ -251,7 +262,12 @@ SYSTEM_PROMPT_FIRST_SUMMARY = f"""
|
||||
2. **展现多元观点**:呈现不同平台、不同群体的观点差异和讨论重点
|
||||
3. **数据支撑分析**:用具体的点赞数、评论数、转发数等数据说明舆情热度
|
||||
4. **情感色彩描述**:准确描述公众的情感倾向(愤怒、支持、担忧、期待等)
|
||||
5. **避免套话官话**:使用贴近民众的语言,避免过度官方化的表述
|
||||
5. **智能运用情感分析**:
|
||||
- **整合情感数据**:如果搜索结果包含自动情感分析,要充分利用情感分布数据(如"正面情感占60%,负面情感占25%")
|
||||
- **情感趋势描述**:描述主要情感倾向和情感分布特征
|
||||
- **高置信度引用**:优先引用高置信度的情感分析结果
|
||||
- **情感细节分析**:结合具体的情感标签(非常正面、正面、中性、负面、非常负面)进行深度分析
|
||||
6. **避免套话官话**:使用贴近民众的语言,避免过度官方化的表述
|
||||
|
||||
撰写风格:
|
||||
- 语言生动,有感染力
|
||||
@@ -277,13 +293,14 @@ SYSTEM_PROMPT_REFLECTION = f"""
|
||||
{json.dumps(input_schema_reflection, indent=2, ensure_ascii=False)}
|
||||
</INPUT JSON SCHEMA>
|
||||
|
||||
你可以使用以下5种专业的本地舆情数据库查询工具来深度挖掘民意:
|
||||
你可以使用以下6种专业的本地舆情数据库查询工具来深度挖掘民意:
|
||||
|
||||
1. **search_hot_content** - 查找热点内容工具
|
||||
2. **search_topic_globally** - 全局话题搜索工具
|
||||
3. **search_topic_by_date** - 按日期搜索话题工具
|
||||
4. **get_comments_for_topic** - 获取话题评论工具
|
||||
5. **search_topic_on_platform** - 平台定向搜索工具
|
||||
1. **search_hot_content** - 查找热点内容工具(自动情感分析)
|
||||
2. **search_topic_globally** - 全局话题搜索工具(自动情感分析)
|
||||
3. **search_topic_by_date** - 按日期搜索话题工具(自动情感分析)
|
||||
4. **get_comments_for_topic** - 获取话题评论工具(自动情感分析)
|
||||
5. **search_topic_on_platform** - 平台定向搜索工具(自动情感分析)
|
||||
6. **analyze_sentiment** - 多语言情感分析工具(专门的情感分析)
|
||||
|
||||
**反思的核心目标:让报告更有人情味和真实感**
|
||||
|
||||
@@ -311,7 +328,7 @@ SYSTEM_PROMPT_REFLECTION = f"""
|
||||
4. **参数配置要求**:
|
||||
- search_topic_by_date: 必须提供start_date和end_date参数(格式:YYYY-MM-DD)
|
||||
- search_topic_on_platform: 必须提供platform参数(bilibili, weibo, douyin, kuaishou, xhs, zhihu, tieba之一)
|
||||
- 其他工具:合理配置参数以获取多样化的民意样本(建议:search_hot_content limit>=100,search_topic_globally limit_per_table>=50,search_topic_by_date limit_per_table>=100,get_comments_for_topic limit>=500,search_topic_on_platform limit>=200)
|
||||
- 系统自动配置数据量参数,无需手动设置limit或limit_per_table参数
|
||||
|
||||
5. **阐述补充理由**:明确说明为什么需要这些额外的民意数据
|
||||
|
||||
@@ -357,9 +374,13 @@ SYSTEM_PROMPT_REFLECTION_SUMMARY = f"""
|
||||
优化策略:
|
||||
1. **融入新的民意数据**:将补充搜索到的真实用户声音整合到段落中
|
||||
2. **丰富情感表达**:增加具体的情感描述和社会情绪分析
|
||||
3. **补充遗漏观点**:添加之前缺失的不同群体、平台的观点
|
||||
4. **强化数据支撑**:用具体数字和案例让分析更有说服力
|
||||
5. **优化语言表达**:让文字更生动、更贴近民众,减少官方套话
|
||||
3. **深化情感分析**:
|
||||
- **整合情感变化**:如果有新的情感分析数据,对比前后情感变化趋势
|
||||
- **细化情感层次**:区分不同群体、平台的情感差异
|
||||
- **量化情感描述**:用具体的情感分布数据支撑分析(如"新增数据显示负面情感比例上升至40%")
|
||||
4. **补充遗漏观点**:添加之前缺失的不同群体、平台的观点
|
||||
5. **强化数据支撑**:用具体数字和案例让分析更有说服力
|
||||
6. **优化语言表达**:让文字更生动、更贴近民众,减少官方套话
|
||||
|
||||
注意事项:
|
||||
- 保留段落的核心观点和重要信息
|
||||
|
||||
@@ -14,6 +14,13 @@ from .keyword_optimizer import (
|
||||
KeywordOptimizationResponse,
|
||||
keyword_optimizer
|
||||
)
|
||||
from .sentiment_analyzer import (
|
||||
WeiboMultilingualSentimentAnalyzer,
|
||||
SentimentResult,
|
||||
BatchSentimentResult,
|
||||
multilingual_sentiment_analyzer,
|
||||
analyze_sentiment
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MediaCrawlerDB",
|
||||
@@ -22,5 +29,10 @@ __all__ = [
|
||||
"print_response_summary",
|
||||
"KeywordOptimizer",
|
||||
"KeywordOptimizationResponse",
|
||||
"keyword_optimizer"
|
||||
"keyword_optimizer",
|
||||
"WeiboMultilingualSentimentAnalyzer",
|
||||
"SentimentResult",
|
||||
"BatchSentimentResult",
|
||||
"multilingual_sentiment_analyzer",
|
||||
"analyze_sentiment"
|
||||
]
|
||||
|
||||
@@ -228,7 +228,7 @@ class KeywordOptimizer:
|
||||
|
||||
# 清理和验证关键词
|
||||
cleaned_keywords = []
|
||||
for keyword in keywords[:20]: # 最多5个
|
||||
for keyword in keywords[:20]: # 最多20个
|
||||
keyword = keyword.strip().strip('"\'""''')
|
||||
if keyword and len(keyword) <= 20: # 合理长度
|
||||
cleaned_keywords.append(keyword)
|
||||
|
||||
445
InsightEngine/tools/sentiment_analyzer.py
Normal file
445
InsightEngine/tools/sentiment_analyzer.py
Normal file
@@ -0,0 +1,445 @@
|
||||
"""
|
||||
多语言情感分析工具
|
||||
基于WeiboMultilingualSentiment模型为InsightEngine提供情感分析功能
|
||||
"""
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
from dataclasses import dataclass
|
||||
import re
|
||||
|
||||
# Put the bundled WeiboMultilingualSentiment directory on sys.path so it can
# be imported. The project root is three directory levels above this file.
_module_path = os.path.abspath(__file__)
_package_dir = os.path.dirname(os.path.dirname(_module_path))
project_root = os.path.dirname(_package_dir)
weibo_sentiment_path = os.path.join(
    project_root,
    "SentimentAnalysisModel",
    "WeiboMultilingualSentiment",
)
sys.path.append(weibo_sentiment_path)
|
||||
|
||||
|
||||
@dataclass
class SentimentResult:
    """Outcome of analyzing the sentiment of a single text."""

    # The text exactly as the caller supplied it.
    text: str
    # Predicted sentiment label, or a status label when analysis did not succeed.
    sentiment_label: str
    # Probability of the predicted label; 0.0 when analysis did not succeed.
    confidence: float
    # Label name -> probability; empty when analysis did not succeed.
    probability_distribution: Dict[str, float]
    # False when the text could not be analyzed.
    success: bool = True
    # Human-readable reason for a failure, otherwise None.
    error_message: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
class BatchSentimentResult:
    """Aggregate outcome of analyzing a batch of texts."""

    # One SentimentResult per analyzed text.
    results: List[SentimentResult]
    # Number of texts processed in the batch.
    total_processed: int
    # Presumably the number of results with success=True; the producer is
    # not visible here, so confirm against analyze_batch.
    success_count: int
    # Presumably the number of results with success=False; confirm as above.
    failed_count: int
    # Average confidence reported for the batch.
    average_confidence: float
|
||||
|
||||
|
||||
class WeiboMultilingualSentimentAnalyzer:
|
||||
"""
|
||||
多语言情感分析器
|
||||
封装WeiboMultilingualSentiment模型,为AI Agent提供情感分析功能
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化情感分析器"""
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
self.device = None
|
||||
self.is_initialized = False
|
||||
|
||||
# 情感标签映射(5级分类)
|
||||
self.sentiment_map = {
|
||||
0: "非常负面",
|
||||
1: "负面",
|
||||
2: "中性",
|
||||
3: "正面",
|
||||
4: "非常正面"
|
||||
}
|
||||
|
||||
print("WeiboMultilingualSentimentAnalyzer 已创建,调用 initialize() 来加载模型")
|
||||
|
||||
def initialize(self) -> bool:
|
||||
"""
|
||||
初始化模型和分词器
|
||||
|
||||
Returns:
|
||||
是否初始化成功
|
||||
"""
|
||||
if self.is_initialized:
|
||||
print("模型已经初始化,无需重复加载")
|
||||
return True
|
||||
|
||||
try:
|
||||
print("正在加载多语言情感分析模型...")
|
||||
|
||||
# 使用多语言情感分析模型
|
||||
model_name = "tabularisai/multilingual-sentiment-analysis"
|
||||
local_model_path = os.path.join(weibo_sentiment_path, "model")
|
||||
|
||||
# 检查本地是否已有模型
|
||||
if os.path.exists(local_model_path):
|
||||
print("从本地加载模型...")
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(local_model_path)
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(local_model_path)
|
||||
else:
|
||||
print("首次使用,正在下载模型到本地...")
|
||||
# 下载并保存到本地
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
||||
|
||||
# 保存到本地
|
||||
os.makedirs(local_model_path, exist_ok=True)
|
||||
self.tokenizer.save_pretrained(local_model_path)
|
||||
self.model.save_pretrained(local_model_path)
|
||||
print(f"模型已保存到: {local_model_path}")
|
||||
|
||||
# 设置设备
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
self.is_initialized = True
|
||||
|
||||
print(f"模型加载成功! 使用设备: {self.device}")
|
||||
print("支持语言: 中文、英文、西班牙文、阿拉伯文、日文、韩文等22种语言")
|
||||
print("情感等级: 非常负面、负面、中性、正面、非常正面")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"模型加载失败: {e}")
|
||||
print("请检查网络连接或模型文件")
|
||||
self.is_initialized = False
|
||||
return False
|
||||
|
||||
def _preprocess_text(self, text: str) -> str:
|
||||
"""
|
||||
文本预处理
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
|
||||
Returns:
|
||||
处理后的文本
|
||||
"""
|
||||
# 基本文本清理
|
||||
if not text or not text.strip():
|
||||
return ""
|
||||
|
||||
# 去除多余空格
|
||||
text = re.sub(r'\s+', ' ', text.strip())
|
||||
|
||||
return text
|
||||
|
||||
def analyze_single_text(self, text: str) -> SentimentResult:
    """
    Run sentiment analysis on one text.

    Args:
        text: The text to analyse.

    Returns:
        A SentimentResult. ``success`` is False (with ``error_message`` set)
        when the model is not initialised, the text is empty after
        preprocessing, or an exception is raised during prediction.
    """
    def failed(label, message):
        # All failure results share the same shape; only label/message vary.
        return SentimentResult(
            text=text,
            sentiment_label=label,
            confidence=0.0,
            probability_distribution={},
            success=False,
            error_message=message
        )

    if not self.is_initialized:
        return failed("未初始化", "模型未初始化,请先调用 initialize() 方法")

    try:
        cleaned = self._preprocess_text(text)
        if not cleaned:
            return failed("输入错误", "输入文本为空或无效")

        # Tokenise, truncating to at most 512 tokens.
        encoded = self.tokenizer(
            cleaned,
            max_length=512,
            padding=True,
            truncation=True,
            return_tensors='pt'
        )
        # Move every input tensor to the device the model lives on.
        encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}

        with torch.no_grad():
            logits = self.model(**encoded).logits
            probabilities = torch.softmax(logits, dim=1)
            prediction = torch.argmax(probabilities, dim=1).item()

        class_probs = probabilities[0]
        confidence = class_probs[prediction].item()
        label = self.sentiment_map[prediction]

        # Labels are paired with probabilities in sentiment_map iteration order.
        prob_dist = {
            label_name: prob.item()
            for label_name, prob in zip(self.sentiment_map.values(), class_probs)
        }

        return SentimentResult(
            text=text,
            sentiment_label=label,
            confidence=confidence,
            probability_distribution=prob_dist,
            success=True
        )

    except Exception as e:
        return failed("分析失败", f"预测时发生错误: {str(e)}")
|
||||
|
||||
def analyze_batch(self, texts: List[str], show_progress: bool = True) -> BatchSentimentResult:
    """
    Analyse a list of texts one by one and aggregate the outcome.

    Args:
        texts: Texts to analyse.
        show_progress: Print a progress line per text (only when there is
            more than one text).

    Returns:
        A BatchSentimentResult holding every per-text result, the success and
        failure counts, and the mean confidence over successful results
        (0.0 when nothing succeeded).
    """
    if not texts:
        return BatchSentimentResult(
            results=[],
            total_processed=0,
            success_count=0,
            failed_count=0,
            average_confidence=0.0
        )

    results = []
    success_count = 0
    confidence_sum = 0.0

    for position, text in enumerate(texts, start=1):
        if show_progress and len(texts) > 1:
            print(f"处理进度: {position}/{len(texts)}")

        outcome = self.analyze_single_text(text)
        results.append(outcome)

        if not outcome.success:
            continue
        success_count += 1
        confidence_sum += outcome.confidence

    if success_count > 0:
        average_confidence = confidence_sum / success_count
    else:
        average_confidence = 0.0

    return BatchSentimentResult(
        results=results,
        total_processed=len(texts),
        success_count=success_count,
        failed_count=len(texts) - success_count,
        average_confidence=average_confidence
    )
|
||||
|
||||
def analyze_query_results(self, query_results: List[Dict[str, Any]],
                          text_field: str = "content",
                          min_confidence: float = 0.5) -> Dict[str, Any]:
    """
    Run sentiment analysis over a list of query result records.

    For each record the text is taken from the first non-empty field among
    ``text_field``, "title_or_content", "content", "title" and "text".
    Records without usable text are skipped.

    Args:
        query_results: Records to analyse; each is looked up by field name.
        text_field: Preferred field holding the text, "content" by default.
        min_confidence: Minimum confidence for a result to be listed in
            ``high_confidence_results``.

    Returns:
        ``{"sentiment_analysis": {...}}`` where the inner dict always has the
        keys ``total_analyzed``, ``success_rate``, ``average_confidence``,
        ``sentiment_distribution``, ``high_confidence_results`` and
        ``summary`` - including when there is nothing to analyse, so callers
        can read every key unconditionally.
    """
    def build_report(succeeded, processed, average_confidence,
                     distribution, high_confidence, summary):
        # Single place that defines the report schema for every return path.
        return {
            "sentiment_analysis": {
                "total_analyzed": succeeded,
                "success_rate": f"{succeeded}/{processed}",
                "average_confidence": round(average_confidence, 4),
                "sentiment_distribution": distribution,
                "high_confidence_results": high_confidence,
                "summary": summary
            }
        }

    if not query_results:
        return build_report(0, 0, 0.0, {}, [], "没有内容需要分析")

    # Candidate fields in priority order; dict.fromkeys drops the duplicate
    # that appears when text_field is itself one of the fallbacks.
    candidate_fields = list(dict.fromkeys(
        [text_field, "title_or_content", "content", "title", "text"]
    ))

    texts_to_analyze = []
    original_data = []
    for item in query_results:
        text_content = next(
            (str(item[field]) for field in candidate_fields
             if field in item and item[field]),
            ""
        )
        if text_content.strip():
            texts_to_analyze.append(text_content)
            original_data.append(item)

    if not texts_to_analyze:
        return build_report(0, 0, 0.0, {}, [], "查询结果中没有找到可分析的文本内容")

    print(f"正在对{len(texts_to_analyze)}条内容进行情感分析...")
    batch_result = self.analyze_batch(texts_to_analyze, show_progress=True)

    sentiment_distribution = {}
    high_confidence_results = []

    for result, original_item in zip(batch_result.results, original_data):
        if not result.success:
            continue

        sentiment = result.sentiment_label
        sentiment_distribution[sentiment] = sentiment_distribution.get(sentiment, 0) + 1

        if result.confidence >= min_confidence:
            # Previews are capped at 100 characters, with "..." marking the cut.
            preview = result.text[:100] + "..." if len(result.text) > 100 else result.text
            high_confidence_results.append({
                "original_data": original_item,
                "sentiment": sentiment,
                "confidence": result.confidence,
                "text_preview": preview
            })

    total_analyzed = batch_result.success_count
    if total_analyzed > 0:
        dominant_sentiment = max(sentiment_distribution.items(), key=lambda x: x[1])
        sentiment_summary = f"共分析{total_analyzed}条内容,主要情感倾向为'{dominant_sentiment[0]}'({dominant_sentiment[1]}条,占{dominant_sentiment[1]/total_analyzed*100:.1f}%)"
    else:
        sentiment_summary = "情感分析失败"

    # Every high-confidence result is returned; no cap is applied here.
    return build_report(
        total_analyzed,
        batch_result.total_processed,
        batch_result.average_confidence,
        sentiment_distribution,
        high_confidence_results,
        sentiment_summary
    )
|
||||
|
||||
def get_model_info(self) -> Dict[str, Any]:
    """
    Describe the model and the analyzer's current state.

    Returns:
        A dict with the model name, the list of supported languages, the
        sentiment labels from ``sentiment_map``, the initialisation flag and
        the device as a string ("未设置" when no device is set).
    """
    languages = [
        "中文", "英文", "西班牙文", "阿拉伯文", "日文", "韩文",
        "德文", "法文", "意大利文", "葡萄牙文", "俄文", "荷兰文",
        "波兰文", "土耳其文", "丹麦文", "希腊文", "芬兰文",
        "瑞典文", "挪威文", "匈牙利文", "捷克文", "保加利亚文"
    ]
    return dict(
        model_name="tabularisai/multilingual-sentiment-analysis",
        supported_languages=languages,
        sentiment_levels=list(self.sentiment_map.values()),
        is_initialized=self.is_initialized,
        device=str(self.device) if self.device else "未设置",
    )
|
||||
|
||||
|
||||
# Module-level shared analyzer instance (lazy initialization: the model is loaded by initialize(), not here).
|
||||
multilingual_sentiment_analyzer = WeiboMultilingualSentimentAnalyzer()
|
||||
|
||||
|
||||
def analyze_sentiment(text_or_texts: Union[str, List[str]],
                      initialize_if_needed: bool = True) -> Union[SentimentResult, BatchSentimentResult]:
    """
    Convenience wrapper around the module-level analyzer.

    Args:
        text_or_texts: A single text or a list of texts.
        initialize_if_needed: When True and the shared analyzer is not yet
            initialised, call ``initialize()`` first.

    Returns:
        A SentimentResult for a single string, otherwise a
        BatchSentimentResult. If initialisation fails, every input is
        reported as a failed result labelled "初始化失败"; for a list the
        batch contains one failed entry per text and ``total_processed``
        equals ``failed_count``, matching what ``analyze_batch`` reports.
    """
    analyzer = multilingual_sentiment_analyzer
    is_single = isinstance(text_or_texts, str)

    if initialize_if_needed and not analyzer.is_initialized and not analyzer.initialize():
        def init_failure(text):
            # Same failure shape the analyzer itself uses for unusable inputs.
            return SentimentResult(
                text=text,
                sentiment_label="初始化失败",
                confidence=0.0,
                probability_distribution={},
                success=False,
                error_message="模型初始化失败"
            )

        if is_single:
            return init_failure(text_or_texts)

        # Keep results aligned with the inputs so the counters stay
        # consistent (total_processed == success_count + failed_count).
        failures = [init_failure(text) for text in text_or_texts]
        return BatchSentimentResult(
            results=failures,
            total_processed=len(failures),
            success_count=0,
            failed_count=len(failures),
            average_confidence=0.0
        )

    if is_single:
        return analyzer.analyze_single_text(text_or_texts)
    return analyzer.analyze_batch(text_or_texts)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: one single-text call and one small multilingual batch.
    analyzer = WeiboMultilingualSentimentAnalyzer()

    if not analyzer.initialize():
        print("模型初始化失败,无法进行测试")
    else:
        # Single text
        result = analyzer.analyze_single_text("今天天气真好,心情特别棒!")
        print(f"单个文本分析: {result.sentiment_label} (置信度: {result.confidence:.4f})")

        # Batch of texts
        test_texts = [
            "这家餐厅的菜味道非常棒!",
            "服务态度太差了,很失望",
            "I absolutely love this product!",
            "The customer service was disappointing."
        ]
        batch_result = analyzer.analyze_batch(test_texts)
        print(f"\n批量分析: 成功 {batch_result.success_count}/{batch_result.total_processed}")

        for result in batch_result.results:
            print(f"'{result.text[:30]}...' -> {result.sentiment_label} ({result.confidence:.4f})")
|
||||
@@ -14,6 +14,7 @@ class Config:
|
||||
# API密钥
|
||||
deepseek_api_key: Optional[str] = None
|
||||
openai_api_key: Optional[str] = None
|
||||
kimi_api_key: Optional[str] = None
|
||||
|
||||
# 数据库配置
|
||||
db_host: Optional[str] = None
|
||||
@@ -24,13 +25,14 @@ class Config:
|
||||
db_charset: str = "utf8mb4"
|
||||
|
||||
# 模型配置
|
||||
default_llm_provider: str = "deepseek" # deepseek 或 openai
|
||||
default_llm_provider: str = "deepseek" # deepseek、openai 或 kimi
|
||||
deepseek_model: str = "deepseek-chat"
|
||||
openai_model: str = "gpt-4o-mini"
|
||||
kimi_model: str = "kimi-k2-0711-preview"
|
||||
|
||||
# 搜索配置
|
||||
search_timeout: int = 240
|
||||
max_content_length: int = 100000
|
||||
max_content_length: int = 500000 # 提高5倍以充分利用Kimi的长文本能力
|
||||
|
||||
# 数据库查询限制
|
||||
default_search_hot_content_limit: int = 100
|
||||
@@ -43,6 +45,10 @@ class Config:
|
||||
max_reflections: int = 3
|
||||
max_paragraphs: int = 6
|
||||
|
||||
# 结果处理限制
|
||||
max_search_results_for_llm: int = 0 # 0表示不限制,传递所有搜索结果给LLM
|
||||
max_high_confidence_sentiment_results: int = 0 # 0表示不限制,返回所有高置信度情感分析结果
|
||||
|
||||
# 输出配置
|
||||
output_dir: str = "reports"
|
||||
save_intermediate_states: bool = True
|
||||
@@ -102,6 +108,10 @@ class Config:
|
||||
|
||||
max_reflections=getattr(config_module, "MAX_REFLECTIONS", 2),
|
||||
max_paragraphs=getattr(config_module, "MAX_PARAGRAPHS", 5),
|
||||
|
||||
max_search_results_for_llm=getattr(config_module, "MAX_SEARCH_RESULTS_FOR_LLM", 0),
|
||||
max_high_confidence_sentiment_results=getattr(config_module, "MAX_HIGH_CONFIDENCE_SENTIMENT_RESULTS", 0),
|
||||
|
||||
output_dir=getattr(config_module, "OUTPUT_DIR", "reports"),
|
||||
save_intermediate_states=getattr(config_module, "SAVE_INTERMEDIATE_STATES", True)
|
||||
)
|
||||
@@ -120,6 +130,7 @@ class Config:
|
||||
return cls(
|
||||
deepseek_api_key=config_dict.get("DEEPSEEK_API_KEY"),
|
||||
openai_api_key=config_dict.get("OPENAI_API_KEY"),
|
||||
kimi_api_key=config_dict.get("KIMI_API_KEY"),
|
||||
|
||||
db_host=config_dict.get("DB_HOST"),
|
||||
db_user=config_dict.get("DB_USER"),
|
||||
@@ -131,9 +142,10 @@ class Config:
|
||||
default_llm_provider=config_dict.get("DEFAULT_LLM_PROVIDER", "deepseek"),
|
||||
deepseek_model=config_dict.get("DEEPSEEK_MODEL", "deepseek-chat"),
|
||||
openai_model=config_dict.get("OPENAI_MODEL", "gpt-4o-mini"),
|
||||
kimi_model=config_dict.get("KIMI_MODEL", "kimi-k2-0711-preview"),
|
||||
|
||||
search_timeout=int(config_dict.get("SEARCH_TIMEOUT", "240")),
|
||||
max_content_length=int(config_dict.get("SEARCH_CONTENT_MAX_LENGTH", "200000")),
|
||||
max_content_length=int(config_dict.get("SEARCH_CONTENT_MAX_LENGTH", "500000")),
|
||||
|
||||
default_search_hot_content_limit=int(config_dict.get("DEFAULT_SEARCH_HOT_CONTENT_LIMIT", "100")),
|
||||
default_search_topic_globally_limit_per_table=int(config_dict.get("DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE", "50")),
|
||||
@@ -143,6 +155,10 @@ class Config:
|
||||
|
||||
max_reflections=int(config_dict.get("MAX_REFLECTIONS", "2")),
|
||||
max_paragraphs=int(config_dict.get("MAX_PARAGRAPHS", "5")),
|
||||
|
||||
max_search_results_for_llm=int(config_dict.get("MAX_SEARCH_RESULTS_FOR_LLM", "0")),
|
||||
max_high_confidence_sentiment_results=int(config_dict.get("MAX_HIGH_CONFIDENCE_SENTIMENT_RESULTS", "0")),
|
||||
|
||||
output_dir=config_dict.get("OUTPUT_DIR", "reports"),
|
||||
save_intermediate_states=config_dict.get("SAVE_INTERMEDIATE_STATES", "true").lower() == "true"
|
||||
)
|
||||
|
||||
@@ -9,7 +9,7 @@ import re
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
from .llms import DeepSeekLLM, OpenAILLM, BaseLLM
|
||||
from .llms import DeepSeekLLM, OpenAILLM, GeminiLLM, BaseLLM
|
||||
from .nodes import (
|
||||
ReportStructureNode,
|
||||
FirstSearchNode,
|
||||
@@ -67,6 +67,11 @@ class DeepSearchAgent:
|
||||
api_key=self.config.openai_api_key,
|
||||
model_name=self.config.openai_model
|
||||
)
|
||||
elif self.config.default_llm_provider == "gemini":
|
||||
return GeminiLLM(
|
||||
api_key=self.config.gemini_api_key,
|
||||
model_name=self.config.gemini_model
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"不支持的LLM提供商: {self.config.default_llm_provider}")
|
||||
|
||||
|
||||
@@ -6,5 +6,6 @@ LLM调用模块
|
||||
from .base import BaseLLM
|
||||
from .deepseek import DeepSeekLLM
|
||||
from .openai_llm import OpenAILLM
|
||||
from .gemini_llm import GeminiLLM
|
||||
|
||||
__all__ = ["BaseLLM", "DeepSeekLLM", "OpenAILLM"]
|
||||
__all__ = ["BaseLLM", "DeepSeekLLM", "OpenAILLM", "GeminiLLM"]
|
||||
|
||||
95
MediaEngine/llms/gemini_llm.py
Normal file
95
MediaEngine/llms/gemini_llm.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
Gemini LLM实现
|
||||
使用Gemini 2.5-pro中转API进行文本生成
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
from openai import OpenAI
|
||||
from .base import BaseLLM
|
||||
|
||||
|
||||
class GeminiLLM(BaseLLM):
    """Gemini LLM client that talks to an OpenAI-compatible relay endpoint."""

    # Single source of truth for the relay endpoint, so the client and
    # get_model_info() can never report different URLs.
    API_BASE_URL = "https://www.chataiapi.com/v1"
    DEFAULT_MODEL = "gemini-2.5-pro"

    def __init__(self, api_key: Optional[str] = None, model_name: Optional[str] = None):
        """
        Create the Gemini client.

        Args:
            api_key: Gemini API key. When missing or empty, the
                GEMINI_API_KEY environment variable is used instead.
            model_name: Model to call; defaults to get_default_model().

        Raises:
            ValueError: If no API key is given and GEMINI_API_KEY is unset
                or empty.
        """
        # An empty string is treated like "not provided": fall back to the
        # environment, then fail early instead of at the first API call.
        api_key = api_key or os.getenv("GEMINI_API_KEY")
        if not api_key:
            raise ValueError("Gemini API Key未找到!请设置GEMINI_API_KEY环境变量或在初始化时提供")

        super().__init__(api_key, model_name)

        # The relay speaks the OpenAI wire format, so the OpenAI client is reused.
        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.API_BASE_URL
        )

        self.default_model = model_name or self.get_default_model()

    def get_default_model(self) -> str:
        """Return the default model name."""
        return self.DEFAULT_MODEL

    def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
        """
        Send one system + user prompt pair and return the reply text.

        Args:
            system_prompt: System prompt.
            user_prompt: User input.
            **kwargs: Only ``temperature`` (default 0.7) and ``max_tokens``
                (default 4000) are forwarded; other keys are ignored.

        Returns:
            The reply content passed through ``validate_response``, or ""
            when the response carries no choice/message.

        Raises:
            Exception: Any error from the API call is logged to stdout and
                re-raised unchanged.
        """
        try:
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ]

            params = {
                "model": self.default_model,
                "messages": messages,
                "temperature": kwargs.get("temperature", 0.7),
                "max_tokens": kwargs.get("max_tokens", 4000),
                "stream": False
            }

            response = self.client.chat.completions.create(**params)

            if response.choices and response.choices[0].message:
                content = response.choices[0].message.content
                return self.validate_response(content)
            return ""

        except Exception as e:
            print(f"Gemini API调用错误: {str(e)}")
            # Bare raise re-raises the active exception with its traceback intact.
            raise

    def get_model_info(self) -> Dict[str, Any]:
        """
        Return information about the configured model.

        Returns:
            A dict with the provider name, the model in use and the API base URL.
        """
        return {
            "provider": "Gemini",
            "model": self.default_model,
            "api_base": self.API_BASE_URL
        }
|
||||
@@ -14,12 +14,14 @@ class Config:
|
||||
# API密钥
|
||||
deepseek_api_key: Optional[str] = None
|
||||
openai_api_key: Optional[str] = None
|
||||
gemini_api_key: Optional[str] = None
|
||||
bocha_api_key: Optional[str] = None
|
||||
|
||||
# 模型配置
|
||||
default_llm_provider: str = "deepseek" # deepseek 或 openai
|
||||
default_llm_provider: str = "deepseek" # deepseek、openai 或 gemini
|
||||
deepseek_model: str = "deepseek-chat"
|
||||
openai_model: str = "gpt-4o-mini"
|
||||
gemini_model: str = "gemini-2.5-pro"
|
||||
|
||||
# 搜索配置
|
||||
search_timeout: int = 240
|
||||
@@ -44,6 +46,10 @@ class Config:
|
||||
print("错误: OpenAI API Key未设置")
|
||||
return False
|
||||
|
||||
if self.default_llm_provider == "gemini" and not self.gemini_api_key:
|
||||
print("错误: Gemini API Key未设置")
|
||||
return False
|
||||
|
||||
if not self.bocha_api_key:
|
||||
print("错误: Bocha API Key未设置")
|
||||
return False
|
||||
@@ -65,11 +71,12 @@ class Config:
|
||||
return cls(
|
||||
deepseek_api_key=getattr(config_module, "DEEPSEEK_API_KEY", None),
|
||||
openai_api_key=getattr(config_module, "OPENAI_API_KEY", None),
|
||||
gemini_api_key=getattr(config_module, "GEMINI_API_KEY", None),
|
||||
bocha_api_key=getattr(config_module, "BOCHA_API_KEY", None),
|
||||
default_llm_provider=getattr(config_module, "DEFAULT_LLM_PROVIDER", "deepseek"),
|
||||
deepseek_model=getattr(config_module, "DEEPSEEK_MODEL", "deepseek-chat"),
|
||||
openai_model=getattr(config_module, "OPENAI_MODEL", "gpt-4o-mini"),
|
||||
|
||||
gemini_model=getattr(config_module, "GEMINI_MODEL", "gemini-2.5-pro"),
|
||||
search_timeout=getattr(config_module, "SEARCH_TIMEOUT", 240),
|
||||
max_content_length=getattr(config_module, "SEARCH_CONTENT_MAX_LENGTH", 20000),
|
||||
max_reflections=getattr(config_module, "MAX_REFLECTIONS", 2),
|
||||
@@ -92,11 +99,12 @@ class Config:
|
||||
return cls(
|
||||
deepseek_api_key=config_dict.get("DEEPSEEK_API_KEY"),
|
||||
openai_api_key=config_dict.get("OPENAI_API_KEY"),
|
||||
gemini_api_key=config_dict.get("GEMINI_API_KEY"),
|
||||
bocha_api_key=config_dict.get("BOCHA_API_KEY"),
|
||||
default_llm_provider=config_dict.get("DEFAULT_LLM_PROVIDER", "deepseek"),
|
||||
deepseek_model=config_dict.get("DEEPSEEK_MODEL", "deepseek-chat"),
|
||||
openai_model=config_dict.get("OPENAI_MODEL", "gpt-4o-mini"),
|
||||
|
||||
gemini_model=config_dict.get("GEMINI_MODEL", "gemini-2.5-pro"),
|
||||
search_timeout=int(config_dict.get("SEARCH_TIMEOUT", "240")),
|
||||
max_content_length=int(config_dict.get("SEARCH_CONTENT_MAX_LENGTH", "20000")),
|
||||
max_reflections=int(config_dict.get("MAX_REFLECTIONS", "2")),
|
||||
|
||||
@@ -160,6 +160,31 @@ def show_multilingual_demo(tokenizer, model, device, sentiment_map):
|
||||
print(f"处理 {text} 时出错: {e}")
|
||||
|
||||
print("\n=== 示例结束 ===")
|
||||
|
||||
'''
|
||||
正在加载多语言情感分析模型...
|
||||
从本地加载模型...
|
||||
模型加载成功! 使用设备: cuda
|
||||
|
||||
============= 多语言情感分析 =============
|
||||
支持语言: 中文、英文、西班牙文、阿拉伯文、日文、韩文等22种语言
|
||||
情感等级: 非常负面、负面、中性、正面、非常正面
|
||||
输入文本进行分析 (输入 'q' 退出):
|
||||
输入 'demo' 查看多语言示例
|
||||
|
||||
请输入文本: 我喜欢你
|
||||
C:\Users\67093\.conda\envs\pytorch_python11\Lib\site-packages\transformers\models\distilbert\modeling_distilbert.py:401: UserWarning: 1Torch was not compiled with flash attention. (Triggered internally at C:\cb\pytorch_1000000000000\work\aten\src\ATen\native\transformers\cuda\sdp_utils.cpp:263.)
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
预测结果: 正面 (置信度: 0.5204)
|
||||
详细概率分布:
|
||||
非常负面: 0.0329
|
||||
负面: 0.0263
|
||||
中性: 0.1987
|
||||
正面: 0.5204
|
||||
非常正面: 0.2216
|
||||
|
||||
请输入文本:
|
||||
'''
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -10,10 +10,10 @@ from datetime import datetime
|
||||
import json
|
||||
|
||||
# 添加src目录到Python路径
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '.'))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from InsightEngine import DeepSearchAgent, Config
|
||||
from config import DEEPSEEK_API_KEY, DB_HOST, DB_USER, DB_PASSWORD, DB_NAME, DB_PORT, DB_CHARSET
|
||||
from config import DEEPSEEK_API_KEY, KIMI_API_KEY, DB_HOST, DB_USER, DB_PASSWORD, DB_NAME, DB_PORT, DB_CHARSET
|
||||
|
||||
|
||||
def main():
|
||||
@@ -31,20 +31,38 @@ def main():
|
||||
with st.sidebar:
|
||||
st.header("配置")
|
||||
|
||||
# 模型选择
|
||||
llm_provider = st.selectbox("LLM提供商", ["deepseek", "openai", "kimi"])
|
||||
|
||||
# 高级配置
|
||||
st.subheader("高级配置")
|
||||
max_reflections = st.slider("反思次数", 1, 5, 2)
|
||||
max_content_length = st.number_input("最大内容长度", 10000, 500000, 200000) # 提高10倍:1000-50000-20000 → 10000-500000-200000
|
||||
|
||||
# 模型选择
|
||||
llm_provider = st.selectbox("LLM提供商", ["deepseek", "openai"])
|
||||
# 根据选择的模型动态调整默认值
|
||||
if llm_provider == "kimi":
|
||||
default_content_length = 500000 # Kimi支持长文本,使用更大的默认值
|
||||
max_limit = 1000000 # 提高上限
|
||||
st.info("💡 Kimi模型支持超长文本处理,建议使用更大的内容长度以充分利用其能力")
|
||||
else:
|
||||
default_content_length = 200000
|
||||
max_limit = 500000
|
||||
|
||||
max_content_length = st.number_input("最大内容长度", 10000, max_limit, default_content_length)
|
||||
|
||||
# 初始化所有可能的变量
|
||||
openai_key = ""
|
||||
kimi_key = ""
|
||||
|
||||
if llm_provider == "deepseek":
|
||||
model_name = st.selectbox("DeepSeek模型", ["deepseek-chat"])
|
||||
else:
|
||||
elif llm_provider == "openai":
|
||||
model_name = st.selectbox("OpenAI模型", ["gpt-4o-mini", "gpt-4o"])
|
||||
openai_key = st.text_input("OpenAI API Key", type="password",
|
||||
value="")
|
||||
else: # kimi
|
||||
model_name = st.selectbox("Kimi模型", ["kimi-k2-0711-preview"])
|
||||
kimi_key = st.text_input("Kimi API Key", type="password",
|
||||
value="")
|
||||
|
||||
# 主界面
|
||||
col1, col2 = st.columns([2, 1])
|
||||
@@ -96,8 +114,13 @@ def main():
|
||||
st.error("请提供OpenAI API Key")
|
||||
return
|
||||
|
||||
if llm_provider == "kimi" and not kimi_key and not KIMI_API_KEY:
|
||||
st.error("请提供Kimi API Key或在配置文件中设置KIMI_API_KEY")
|
||||
return
|
||||
|
||||
# 自动使用配置文件中的API密钥和数据库配置
|
||||
deepseek_key = DEEPSEEK_API_KEY
|
||||
kimi_key_final = kimi_key if kimi_key else KIMI_API_KEY
|
||||
db_host = DB_HOST
|
||||
db_user = DB_USER
|
||||
db_password = DB_PASSWORD
|
||||
@@ -109,6 +132,7 @@ def main():
|
||||
config = Config(
|
||||
deepseek_api_key=deepseek_key if llm_provider == "deepseek" else None,
|
||||
openai_api_key=openai_key if llm_provider == "openai" else None,
|
||||
kimi_api_key=kimi_key_final if llm_provider == "kimi" else None,
|
||||
db_host=db_host,
|
||||
db_user=db_user,
|
||||
db_password=db_password,
|
||||
@@ -118,6 +142,7 @@ def main():
|
||||
default_llm_provider=llm_provider,
|
||||
deepseek_model=model_name if llm_provider == "deepseek" else "deepseek-chat",
|
||||
openai_model=model_name if llm_provider == "openai" else "gpt-4o-mini",
|
||||
kimi_model=model_name if llm_provider == "kimi" else "kimi-k2-0711-preview",
|
||||
max_reflections=max_reflections,
|
||||
max_content_length=max_content_length,
|
||||
output_dir="insight_engine_streamlit_reports"
|
||||
|
||||
@@ -13,7 +13,7 @@ import json
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from MediaEngine import DeepSearchAgent, Config
|
||||
from config import DEEPSEEK_API_KEY, BOCHA_Web_Search_API_KEY
|
||||
from config import DEEPSEEK_API_KEY, BOCHA_Web_Search_API_KEY, GEMINI_API_KEY
|
||||
|
||||
|
||||
def main():
|
||||
@@ -37,14 +37,16 @@ def main():
|
||||
max_content_length = st.number_input("最大内容长度", 1000, 50000, 20000)
|
||||
|
||||
# 模型选择
|
||||
llm_provider = st.selectbox("LLM提供商", ["deepseek", "openai"])
|
||||
llm_provider = st.selectbox("LLM提供商", ["deepseek", "openai", "gemini"])
|
||||
|
||||
openai_key = "" # 初始化变量
|
||||
if llm_provider == "deepseek":
|
||||
model_name = st.selectbox("DeepSeek模型", ["deepseek-chat"])
|
||||
else:
|
||||
elif llm_provider == "openai":
|
||||
model_name = st.selectbox("OpenAI模型", ["gpt-4o-mini", "gpt-4o"])
|
||||
openai_key = st.text_input("OpenAI API Key", type="password",
|
||||
value="")
|
||||
openai_key = st.text_input("OpenAI API Key", type="password", value="")
|
||||
else: # gemini
|
||||
model_name = st.selectbox("Gemini模型", ["gemini-2.5-pro"])
|
||||
|
||||
# 主界面
|
||||
col1, col2 = st.columns([2, 1])
|
||||
@@ -98,16 +100,19 @@ def main():
|
||||
|
||||
# 自动使用配置文件中的API密钥
|
||||
deepseek_key = DEEPSEEK_API_KEY
|
||||
gemini_key = GEMINI_API_KEY # 使用config.py中的Gemini API密钥
|
||||
bocha_key = BOCHA_Web_Search_API_KEY
|
||||
|
||||
# 创建配置
|
||||
config = Config(
|
||||
deepseek_api_key=deepseek_key if llm_provider == "deepseek" else None,
|
||||
openai_api_key=openai_key if llm_provider == "openai" else None,
|
||||
gemini_api_key=gemini_key if llm_provider == "gemini" else None,
|
||||
bocha_api_key=bocha_key,
|
||||
default_llm_provider=llm_provider,
|
||||
deepseek_model=model_name if llm_provider == "deepseek" else "deepseek-chat",
|
||||
openai_model=model_name if llm_provider == "openai" else "gpt-4o-mini",
|
||||
gemini_model=model_name if llm_provider == "gemini" else "gemini-2.5-pro",
|
||||
max_reflections=max_reflections,
|
||||
max_content_length=max_content_length,
|
||||
output_dir="media_engine_streamlit_reports"
|
||||
|
||||
42
config.py
42
config.py
@@ -1,25 +1,37 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
智能舆情分析平台配置文件
|
||||
存储数据库连接信息和API密钥
|
||||
Intelligent Public Opinion Analysis Platform Configuration File
|
||||
Stores database connection information and API keys
|
||||
"""
|
||||
|
||||
# MySQL数据库配置
|
||||
DB_HOST = "rm-2zeib6b13f6tt9kncoo.mysql.rds.aliyuncs.com"
|
||||
# MySQL Database Configuration
|
||||
DB_HOST = "your_database_host" # e.g., "localhost" or "127.0.0.1"
|
||||
DB_PORT = 3306
|
||||
DB_USER = "root"
|
||||
DB_PASSWORD = "mneDccc7sHHANtFk"
|
||||
DB_NAME = "media_crawler"
|
||||
DB_USER = "your_database_user"
|
||||
DB_PASSWORD = "your_database_password"
|
||||
DB_NAME = "your_database_name"
|
||||
DB_CHARSET = "utf8mb4"
|
||||
|
||||
# agent1 DeepSeek API密钥
|
||||
DEEPSEEK_API_KEY = "sk-4bbc57fadd234666a3840f1a7edc1f2e"
|
||||
# DeepSeek API Key
|
||||
# 申请地址https://www.deepseek.com/
|
||||
DEEPSEEK_API_KEY = "your_deepseek_api_key"
|
||||
|
||||
# agent2 DeepSeek API密钥
|
||||
DEEPSEEK_API_KEY_2 = "sk-b26405d2e02f475c960d21c2acce61e7"
|
||||
# Tavily Search API Key
|
||||
# 申请地址https://www.tavily.com/
|
||||
TAVILY_API_KEY = "your_tavily_api_key"
|
||||
|
||||
# Tavily搜索API密钥
|
||||
TAVILY_API_KEY = "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"
|
||||
# Kimi API Key
|
||||
# 申请地址https://www.kimi.com/
|
||||
KIMI_API_KEY = "your_kimi_api_key"
|
||||
|
||||
# 博查Web Search API密钥
|
||||
BOCHA_Web_Search_API_KEY = "sk-496b37a2a1ee4915b438dd822b03de8d"
|
||||
# Gemini API Key (via OpenAI format proxy)
|
||||
# 申请地址https://api.chataiapi.com/
|
||||
GEMINI_API_KEY = "your_gemini_api_key"
|
||||
|
||||
# Bocha Search API Key
|
||||
# 申请地址https://open.bochaai.com/
|
||||
BOCHA_Web_Search_API_KEY = "your_bocha_web_search_api_key"
|
||||
|
||||
# Guiji Flow API Key
|
||||
# 申请地址https://siliconflow.cn/
|
||||
GUIJI_QWEN3_API_KEY = "your_guiji_qwen3_api_key"
|
||||
Reference in New Issue
Block a user