Fixing Streamlit bugs.
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
Streamlit Web界面
|
||||
为Deep Search Agent提供友好的Web界面
|
||||
为Insight Agent提供友好的Web界面
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -19,54 +19,25 @@ from config import DEEPSEEK_API_KEY, KIMI_API_KEY, DB_HOST, DB_USER, DB_PASSWORD
|
||||
def main():
|
||||
"""主函数"""
|
||||
st.set_page_config(
|
||||
page_title="Deep Search Agent",
|
||||
page_icon="🔍",
|
||||
page_title="Insight Agent",
|
||||
page_icon="",
|
||||
layout="wide"
|
||||
)
|
||||
|
||||
st.title("Insight Engine Agent")
|
||||
st.markdown("基于DeepSeek的本地舆情数据库深度分析AI代理")
|
||||
|
||||
# 侧边栏配置
|
||||
with st.sidebar:
|
||||
st.header("配置")
|
||||
|
||||
# 模型选择
|
||||
llm_provider = st.selectbox("LLM提供商", ["deepseek", "openai", "kimi"])
|
||||
|
||||
# 高级配置
|
||||
st.subheader("高级配置")
|
||||
max_reflections = st.slider("反思次数", 1, 5, 2)
|
||||
|
||||
# 根据选择的模型动态调整默认值
|
||||
if llm_provider == "kimi":
|
||||
default_content_length = 500000 # Kimi支持长文本,使用更大的默认值
|
||||
max_limit = 1000000 # 提高上限
|
||||
st.info("💡 Kimi模型支持超长文本处理,建议使用更大的内容长度以充分利用其能力")
|
||||
else:
|
||||
default_content_length = 200000
|
||||
max_limit = 500000
|
||||
|
||||
max_content_length = st.number_input("最大内容长度", 10000, max_limit, default_content_length)
|
||||
|
||||
# 初始化所有可能的变量
|
||||
openai_key = ""
|
||||
kimi_key = ""
|
||||
|
||||
if llm_provider == "deepseek":
|
||||
model_name = st.selectbox("DeepSeek模型", ["deepseek-chat"])
|
||||
elif llm_provider == "openai":
|
||||
model_name = st.selectbox("OpenAI模型", ["gpt-4o-mini", "gpt-4o"])
|
||||
openai_key = st.text_input("OpenAI API Key", type="password",
|
||||
value="")
|
||||
else: # kimi
|
||||
model_name = st.selectbox("Kimi模型", ["kimi-k2-0711-preview"])
|
||||
kimi_key = st.text_input("Kimi API Key", type="password",
|
||||
value="")
|
||||
|
||||
|
||||
st.title("Insight Agent")
|
||||
st.markdown("私有舆情数据库深度分析AI代理")
|
||||
|
||||
# ----- 配置被硬编码 -----
|
||||
# 强制使用 Kimi
|
||||
llm_provider = "kimi"
|
||||
model_name = "kimi-k2-0711-preview"
|
||||
# 默认高级配置
|
||||
max_reflections = 2
|
||||
max_content_length = 500000 # Kimi支持长文本
|
||||
|
||||
# 主界面
|
||||
col1, col2 = st.columns([2, 1])
|
||||
|
||||
|
||||
with col1:
|
||||
st.header("研究查询")
|
||||
query = st.text_area(
|
||||
@@ -74,21 +45,7 @@ def main():
|
||||
placeholder="例如:2025年人工智能发展趋势",
|
||||
height=100
|
||||
)
|
||||
|
||||
# 预设查询示例
|
||||
st.subheader("示例查询")
|
||||
example_queries = [
|
||||
"2025年人工智能发展趋势",
|
||||
"深度学习在医疗领域的应用",
|
||||
"区块链技术的最新发展",
|
||||
"可持续能源技术趋势",
|
||||
"量子计算的发展现状"
|
||||
]
|
||||
|
||||
selected_example = st.selectbox("选择示例查询", ["自定义"] + example_queries)
|
||||
if selected_example != "自定义":
|
||||
query = selected_example
|
||||
|
||||
|
||||
with col2:
|
||||
st.header("状态信息")
|
||||
if 'agent' in st.session_state and hasattr(st.session_state.agent, 'state'):
|
||||
@@ -98,41 +55,36 @@ def main():
|
||||
st.progress(progress['progress_percentage'] / 100)
|
||||
else:
|
||||
st.info("尚未开始研究")
|
||||
|
||||
|
||||
# 执行按钮
|
||||
col1, col2, col3 = st.columns([1, 1, 1])
|
||||
with col2:
|
||||
col1_btn, col2_btn, col3_btn = st.columns([1, 1, 1])
|
||||
with col2_btn:
|
||||
start_research = st.button("开始研究", type="primary", use_container_width=True)
|
||||
|
||||
|
||||
# 验证配置
|
||||
if start_research:
|
||||
if not query.strip():
|
||||
st.error("请输入研究查询")
|
||||
return
|
||||
|
||||
if llm_provider == "openai" and not openai_key:
|
||||
st.error("请提供OpenAI API Key")
|
||||
|
||||
# 由于强制使用Kimi,只检查KIMI_API_KEY
|
||||
if not KIMI_API_KEY:
|
||||
st.error("请在您的配置文件(config.py)中设置KIMI_API_KEY")
|
||||
return
|
||||
|
||||
if llm_provider == "kimi" and not kimi_key and not KIMI_API_KEY:
|
||||
st.error("请提供Kimi API Key或在配置文件中设置KIMI_API_KEY")
|
||||
return
|
||||
|
||||
|
||||
# 自动使用配置文件中的API密钥和数据库配置
|
||||
deepseek_key = DEEPSEEK_API_KEY
|
||||
kimi_key_final = kimi_key if kimi_key else KIMI_API_KEY
|
||||
db_host = DB_HOST
|
||||
db_user = DB_USER
|
||||
db_password = DB_PASSWORD
|
||||
db_name = DB_NAME
|
||||
db_port = DB_PORT
|
||||
db_charset = DB_CHARSET
|
||||
|
||||
|
||||
# 创建配置
|
||||
config = Config(
|
||||
deepseek_api_key=deepseek_key if llm_provider == "deepseek" else None,
|
||||
openai_api_key=openai_key if llm_provider == "openai" else None,
|
||||
kimi_api_key=kimi_key_final if llm_provider == "kimi" else None,
|
||||
deepseek_api_key=None,
|
||||
openai_api_key=None,
|
||||
kimi_api_key=KIMI_API_KEY, # 强制使用配置文件中的Kimi Key
|
||||
db_host=db_host,
|
||||
db_user=db_user,
|
||||
db_password=db_password,
|
||||
@@ -140,14 +92,14 @@ def main():
|
||||
db_port=db_port,
|
||||
db_charset=db_charset,
|
||||
default_llm_provider=llm_provider,
|
||||
deepseek_model=model_name if llm_provider == "deepseek" else "deepseek-chat",
|
||||
openai_model=model_name if llm_provider == "openai" else "gpt-4o-mini",
|
||||
kimi_model=model_name if llm_provider == "kimi" else "kimi-k2-0711-preview",
|
||||
deepseek_model="deepseek-chat", # 保留默认值以兼容
|
||||
openai_model="gpt-4o-mini", # 保留默认值以兼容
|
||||
kimi_model=model_name,
|
||||
max_reflections=max_reflections,
|
||||
max_content_length=max_content_length,
|
||||
output_dir="insight_engine_streamlit_reports"
|
||||
)
|
||||
|
||||
|
||||
# 执行研究
|
||||
execute_research(query, config)
|
||||
|
||||
@@ -158,51 +110,51 @@ def execute_research(query: str, config: Config):
|
||||
# 创建进度条
|
||||
progress_bar = st.progress(0)
|
||||
status_text = st.empty()
|
||||
|
||||
|
||||
# 初始化Agent
|
||||
status_text.text("正在初始化Agent...")
|
||||
agent = DeepSearchAgent(config)
|
||||
st.session_state.agent = agent
|
||||
|
||||
|
||||
progress_bar.progress(10)
|
||||
|
||||
|
||||
# 生成报告结构
|
||||
status_text.text("正在生成报告结构...")
|
||||
agent._generate_report_structure(query)
|
||||
progress_bar.progress(20)
|
||||
|
||||
|
||||
# 处理段落
|
||||
total_paragraphs = len(agent.state.paragraphs)
|
||||
for i in range(total_paragraphs):
|
||||
status_text.text(f"正在处理段落 {i+1}/{total_paragraphs}: {agent.state.paragraphs[i].title}")
|
||||
|
||||
status_text.text(f"正在处理段落 {i + 1}/{total_paragraphs}: {agent.state.paragraphs[i].title}")
|
||||
|
||||
# 初始搜索和总结
|
||||
agent._initial_search_and_summary(i)
|
||||
progress_value = 20 + (i + 0.5) / total_paragraphs * 60
|
||||
progress_bar.progress(int(progress_value))
|
||||
|
||||
|
||||
# 反思循环
|
||||
agent._reflection_loop(i)
|
||||
agent.state.paragraphs[i].research.mark_completed()
|
||||
|
||||
|
||||
progress_value = 20 + (i + 1) / total_paragraphs * 60
|
||||
progress_bar.progress(int(progress_value))
|
||||
|
||||
|
||||
# 生成最终报告
|
||||
status_text.text("正在生成最终报告...")
|
||||
final_report = agent._generate_final_report()
|
||||
progress_bar.progress(90)
|
||||
|
||||
|
||||
# 保存报告
|
||||
status_text.text("正在保存报告...")
|
||||
agent._save_report(final_report)
|
||||
progress_bar.progress(100)
|
||||
|
||||
|
||||
status_text.text("研究完成!")
|
||||
|
||||
|
||||
# 显示结果
|
||||
display_results(agent, final_report)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"研究过程中发生错误: {str(e)}")
|
||||
|
||||
@@ -210,61 +162,41 @@ def execute_research(query: str, config: Config):
|
||||
def display_results(agent: DeepSearchAgent, final_report: str):
|
||||
"""显示研究结果"""
|
||||
st.header("研究结果")
|
||||
|
||||
# 结果标签页
|
||||
tab1, tab2, tab3 = st.tabs(["最终报告", "详细信息", "下载"])
|
||||
|
||||
|
||||
# 结果标签页(已移除下载选项)
|
||||
tab1, tab2 = st.tabs(["最终报告", "详细信息"])
|
||||
|
||||
with tab1:
|
||||
st.markdown(final_report)
|
||||
|
||||
|
||||
with tab2:
|
||||
# 段落详情
|
||||
st.subheader("段落详情")
|
||||
for i, paragraph in enumerate(agent.state.paragraphs):
|
||||
with st.expander(f"段落 {i+1}: {paragraph.title}"):
|
||||
with st.expander(f"段落 {i + 1}: {paragraph.title}"):
|
||||
st.write("**预期内容:**", paragraph.content)
|
||||
st.write("**最终内容:**", paragraph.research.latest_summary[:300] + "..."
|
||||
if len(paragraph.research.latest_summary) > 300
|
||||
else paragraph.research.latest_summary)
|
||||
st.write("**最终内容:**", paragraph.research.latest_summary[:300] + "..."
|
||||
if len(paragraph.research.latest_summary) > 300
|
||||
else paragraph.research.latest_summary)
|
||||
st.write("**搜索次数:**", paragraph.research.get_search_count())
|
||||
st.write("**反思次数:**", paragraph.research.reflection_iteration)
|
||||
|
||||
|
||||
# 搜索历史
|
||||
st.subheader("搜索历史")
|
||||
all_searches = []
|
||||
for paragraph in agent.state.paragraphs:
|
||||
all_searches.extend(paragraph.research.search_history)
|
||||
|
||||
|
||||
if all_searches:
|
||||
for i, search in enumerate(all_searches):
|
||||
with st.expander(f"搜索 {i+1}: {search.query}"):
|
||||
with st.expander(f"搜索 {i + 1}: {search.query}"):
|
||||
st.write("**URL:**", search.url)
|
||||
st.write("**标题:**", search.title)
|
||||
st.write("**内容预览:**", search.content[:200] + "..." if len(search.content) > 200 else search.content)
|
||||
st.write("**内容预览:**",
|
||||
search.content[:200] + "..." if len(search.content) > 200 else search.content)
|
||||
if search.score:
|
||||
st.write("**相关度评分:**", search.score)
|
||||
|
||||
with tab3:
|
||||
# 下载选项
|
||||
st.subheader("下载报告")
|
||||
|
||||
# Markdown下载
|
||||
st.download_button(
|
||||
label="下载Markdown报告",
|
||||
data=final_report,
|
||||
file_name=f"deep_search_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md",
|
||||
mime="text/markdown"
|
||||
)
|
||||
|
||||
# JSON状态下载
|
||||
state_json = agent.state.to_json()
|
||||
st.download_button(
|
||||
label="下载状态文件",
|
||||
data=state_json,
|
||||
file_name=f"deep_search_state_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
|
||||
mime="application/json"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
Streamlit Web界面
|
||||
为Deep Search Agent提供友好的Web界面
|
||||
为Media Agent提供友好的Web界面
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -19,38 +19,25 @@ from config import DEEPSEEK_API_KEY, BOCHA_Web_Search_API_KEY, GEMINI_API_KEY
|
||||
def main():
|
||||
"""主函数"""
|
||||
st.set_page_config(
|
||||
page_title="Deep Search Agent",
|
||||
page_icon="🔍",
|
||||
page_title="Media Agent",
|
||||
page_icon="",
|
||||
layout="wide"
|
||||
)
|
||||
|
||||
st.title("Deep Search Agent")
|
||||
st.markdown("基于DeepSeek的无框架深度搜索AI代理")
|
||||
|
||||
# 侧边栏配置
|
||||
with st.sidebar:
|
||||
st.header("配置")
|
||||
|
||||
# 高级配置
|
||||
st.subheader("高级配置")
|
||||
max_reflections = st.slider("反思次数", 1, 5, 2)
|
||||
max_content_length = st.number_input("最大内容长度", 1000, 50000, 20000)
|
||||
|
||||
# 模型选择
|
||||
llm_provider = st.selectbox("LLM提供商", ["deepseek", "openai", "gemini"])
|
||||
|
||||
openai_key = "" # 初始化变量
|
||||
if llm_provider == "deepseek":
|
||||
model_name = st.selectbox("DeepSeek模型", ["deepseek-chat"])
|
||||
elif llm_provider == "openai":
|
||||
model_name = st.selectbox("OpenAI模型", ["gpt-4o-mini", "gpt-4o"])
|
||||
openai_key = st.text_input("OpenAI API Key", type="password", value="")
|
||||
else: # gemini
|
||||
model_name = st.selectbox("Gemini模型", ["gemini-2.5-pro"])
|
||||
|
||||
|
||||
st.title("Media Agent")
|
||||
st.markdown("具备强大多模态能力的AI代理")
|
||||
|
||||
# ----- 配置被硬编码 -----
|
||||
# 强制使用 Gemini
|
||||
llm_provider = "gemini"
|
||||
model_name = "gemini-2.5-pro"
|
||||
# 默认高级配置
|
||||
max_reflections = 2
|
||||
max_content_length = 20000
|
||||
|
||||
# 主界面
|
||||
col1, col2 = st.columns([2, 1])
|
||||
|
||||
|
||||
with col1:
|
||||
st.header("研究查询")
|
||||
query = st.text_area(
|
||||
@@ -58,21 +45,7 @@ def main():
|
||||
placeholder="例如:2025年人工智能发展趋势",
|
||||
height=100
|
||||
)
|
||||
|
||||
# 预设查询示例
|
||||
st.subheader("示例查询")
|
||||
example_queries = [
|
||||
"2025年人工智能发展趋势",
|
||||
"深度学习在医疗领域的应用",
|
||||
"区块链技术的最新发展",
|
||||
"可持续能源技术趋势",
|
||||
"量子计算的发展现状"
|
||||
]
|
||||
|
||||
selected_example = st.selectbox("选择示例查询", ["自定义"] + example_queries)
|
||||
if selected_example != "自定义":
|
||||
query = selected_example
|
||||
|
||||
|
||||
with col2:
|
||||
st.header("状态信息")
|
||||
if 'agent' in st.session_state and hasattr(st.session_state.agent, 'state'):
|
||||
@@ -82,42 +55,45 @@ def main():
|
||||
st.progress(progress['progress_percentage'] / 100)
|
||||
else:
|
||||
st.info("尚未开始研究")
|
||||
|
||||
|
||||
# 执行按钮
|
||||
col1, col2, col3 = st.columns([1, 1, 1])
|
||||
with col2:
|
||||
col1_btn, col2_btn, col3_btn = st.columns([1, 1, 1])
|
||||
with col2_btn:
|
||||
start_research = st.button("开始研究", type="primary", use_container_width=True)
|
||||
|
||||
|
||||
# 验证配置
|
||||
if start_research:
|
||||
if not query.strip():
|
||||
st.error("请输入研究查询")
|
||||
return
|
||||
|
||||
if llm_provider == "openai" and not openai_key:
|
||||
st.error("请提供OpenAI API Key")
|
||||
|
||||
# 由于强制使用Gemini,检查相关的API密钥
|
||||
if not GEMINI_API_KEY:
|
||||
st.error("请在您的配置文件(config.py)中设置GEMINI_API_KEY")
|
||||
return
|
||||
|
||||
if not BOCHA_Web_Search_API_KEY:
|
||||
st.error("请在您的配置文件(config.py)中设置BOCHA_Web_Search_API_KEY")
|
||||
return
|
||||
|
||||
# 自动使用配置文件中的API密钥
|
||||
deepseek_key = DEEPSEEK_API_KEY
|
||||
gemini_key = GEMINI_API_KEY # 使用config.py中的Gemini API密钥
|
||||
gemini_key = GEMINI_API_KEY
|
||||
bocha_key = BOCHA_Web_Search_API_KEY
|
||||
|
||||
|
||||
# 创建配置
|
||||
config = Config(
|
||||
deepseek_api_key=deepseek_key if llm_provider == "deepseek" else None,
|
||||
openai_api_key=openai_key if llm_provider == "openai" else None,
|
||||
gemini_api_key=gemini_key if llm_provider == "gemini" else None,
|
||||
deepseek_api_key=None,
|
||||
openai_api_key=None,
|
||||
gemini_api_key=gemini_key,
|
||||
bocha_api_key=bocha_key,
|
||||
default_llm_provider=llm_provider,
|
||||
deepseek_model=model_name if llm_provider == "deepseek" else "deepseek-chat",
|
||||
openai_model=model_name if llm_provider == "openai" else "gpt-4o-mini",
|
||||
gemini_model=model_name if llm_provider == "gemini" else "gemini-2.5-pro",
|
||||
deepseek_model="deepseek-chat", # 保留默认值以兼容
|
||||
openai_model="gpt-4o-mini", # 保留默认值以兼容
|
||||
gemini_model=model_name,
|
||||
max_reflections=max_reflections,
|
||||
max_content_length=max_content_length,
|
||||
output_dir="media_engine_streamlit_reports"
|
||||
)
|
||||
|
||||
|
||||
# 执行研究
|
||||
execute_research(query, config)
|
||||
|
||||
@@ -128,51 +104,51 @@ def execute_research(query: str, config: Config):
|
||||
# 创建进度条
|
||||
progress_bar = st.progress(0)
|
||||
status_text = st.empty()
|
||||
|
||||
|
||||
# 初始化Agent
|
||||
status_text.text("正在初始化Agent...")
|
||||
agent = DeepSearchAgent(config)
|
||||
st.session_state.agent = agent
|
||||
|
||||
|
||||
progress_bar.progress(10)
|
||||
|
||||
|
||||
# 生成报告结构
|
||||
status_text.text("正在生成报告结构...")
|
||||
agent._generate_report_structure(query)
|
||||
progress_bar.progress(20)
|
||||
|
||||
|
||||
# 处理段落
|
||||
total_paragraphs = len(agent.state.paragraphs)
|
||||
for i in range(total_paragraphs):
|
||||
status_text.text(f"正在处理段落 {i+1}/{total_paragraphs}: {agent.state.paragraphs[i].title}")
|
||||
|
||||
status_text.text(f"正在处理段落 {i + 1}/{total_paragraphs}: {agent.state.paragraphs[i].title}")
|
||||
|
||||
# 初始搜索和总结
|
||||
agent._initial_search_and_summary(i)
|
||||
progress_value = 20 + (i + 0.5) / total_paragraphs * 60
|
||||
progress_bar.progress(int(progress_value))
|
||||
|
||||
|
||||
# 反思循环
|
||||
agent._reflection_loop(i)
|
||||
agent.state.paragraphs[i].research.mark_completed()
|
||||
|
||||
|
||||
progress_value = 20 + (i + 1) / total_paragraphs * 60
|
||||
progress_bar.progress(int(progress_value))
|
||||
|
||||
|
||||
# 生成最终报告
|
||||
status_text.text("正在生成最终报告...")
|
||||
final_report = agent._generate_final_report()
|
||||
progress_bar.progress(90)
|
||||
|
||||
|
||||
# 保存报告
|
||||
status_text.text("正在保存报告...")
|
||||
agent._save_report(final_report)
|
||||
progress_bar.progress(100)
|
||||
|
||||
|
||||
status_text.text("研究完成!")
|
||||
|
||||
|
||||
# 显示结果
|
||||
display_results(agent, final_report)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"研究过程中发生错误: {str(e)}")
|
||||
|
||||
@@ -180,61 +156,41 @@ def execute_research(query: str, config: Config):
|
||||
def display_results(agent: DeepSearchAgent, final_report: str):
|
||||
"""显示研究结果"""
|
||||
st.header("研究结果")
|
||||
|
||||
# 结果标签页
|
||||
tab1, tab2, tab3 = st.tabs(["最终报告", "详细信息", "下载"])
|
||||
|
||||
|
||||
# 结果标签页(已移除下载选项)
|
||||
tab1, tab2 = st.tabs(["最终报告", "详细信息"])
|
||||
|
||||
with tab1:
|
||||
st.markdown(final_report)
|
||||
|
||||
|
||||
with tab2:
|
||||
# 段落详情
|
||||
st.subheader("段落详情")
|
||||
for i, paragraph in enumerate(agent.state.paragraphs):
|
||||
with st.expander(f"段落 {i+1}: {paragraph.title}"):
|
||||
with st.expander(f"段落 {i + 1}: {paragraph.title}"):
|
||||
st.write("**预期内容:**", paragraph.content)
|
||||
st.write("**最终内容:**", paragraph.research.latest_summary[:300] + "..."
|
||||
if len(paragraph.research.latest_summary) > 300
|
||||
else paragraph.research.latest_summary)
|
||||
st.write("**最终内容:**", paragraph.research.latest_summary[:300] + "..."
|
||||
if len(paragraph.research.latest_summary) > 300
|
||||
else paragraph.research.latest_summary)
|
||||
st.write("**搜索次数:**", paragraph.research.get_search_count())
|
||||
st.write("**反思次数:**", paragraph.research.reflection_iteration)
|
||||
|
||||
|
||||
# 搜索历史
|
||||
st.subheader("搜索历史")
|
||||
all_searches = []
|
||||
for paragraph in agent.state.paragraphs:
|
||||
all_searches.extend(paragraph.research.search_history)
|
||||
|
||||
|
||||
if all_searches:
|
||||
for i, search in enumerate(all_searches):
|
||||
with st.expander(f"搜索 {i+1}: {search.query}"):
|
||||
with st.expander(f"搜索 {i + 1}: {search.query}"):
|
||||
st.write("**URL:**", search.url)
|
||||
st.write("**标题:**", search.title)
|
||||
st.write("**内容预览:**", search.content[:200] + "..." if len(search.content) > 200 else search.content)
|
||||
st.write("**内容预览:**",
|
||||
search.content[:200] + "..." if len(search.content) > 200 else search.content)
|
||||
if search.score:
|
||||
st.write("**相关度评分:**", search.score)
|
||||
|
||||
with tab3:
|
||||
# 下载选项
|
||||
st.subheader("下载报告")
|
||||
|
||||
# Markdown下载
|
||||
st.download_button(
|
||||
label="下载Markdown报告",
|
||||
data=final_report,
|
||||
file_name=f"deep_search_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md",
|
||||
mime="text/markdown"
|
||||
)
|
||||
|
||||
# JSON状态下载
|
||||
state_json = agent.state.to_json()
|
||||
st.download_button(
|
||||
label="下载状态文件",
|
||||
data=state_json,
|
||||
file_name=f"deep_search_state_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
|
||||
mime="application/json"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
Streamlit Web界面
|
||||
为Deep Search Agent提供友好的Web界面
|
||||
为Query Agent提供友好的Web界面
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -19,36 +19,25 @@ from config import DEEPSEEK_API_KEY, TAVILY_API_KEY
|
||||
def main():
|
||||
"""主函数"""
|
||||
st.set_page_config(
|
||||
page_title="Deep Search Agent",
|
||||
page_icon="🔍",
|
||||
page_title="Query Agent",
|
||||
page_icon="",
|
||||
layout="wide"
|
||||
)
|
||||
|
||||
st.title("Deep Search Agent")
|
||||
st.markdown("基于DeepSeek的无框架深度搜索AI代理")
|
||||
|
||||
# 侧边栏配置
|
||||
with st.sidebar:
|
||||
st.header("配置")
|
||||
|
||||
# 高级配置
|
||||
st.subheader("高级配置")
|
||||
max_reflections = st.slider("反思次数", 1, 5, 2)
|
||||
max_content_length = st.number_input("最大内容长度", 1000, 50000, 20000)
|
||||
|
||||
# 模型选择
|
||||
llm_provider = st.selectbox("LLM提供商", ["deepseek", "openai"])
|
||||
|
||||
if llm_provider == "deepseek":
|
||||
model_name = st.selectbox("DeepSeek模型", ["deepseek-chat"])
|
||||
else:
|
||||
model_name = st.selectbox("OpenAI模型", ["gpt-4o-mini", "gpt-4o"])
|
||||
openai_key = st.text_input("OpenAI API Key", type="password",
|
||||
value="")
|
||||
|
||||
|
||||
st.title("Query Agent")
|
||||
st.markdown("具备强大网页搜索能力的AI代理")
|
||||
|
||||
# ----- 配置被硬编码 -----
|
||||
# 强制使用 DeepSeek
|
||||
llm_provider = "deepseek"
|
||||
model_name = "deepseek-chat"
|
||||
# 默认高级配置
|
||||
max_reflections = 2
|
||||
max_content_length = 20000
|
||||
|
||||
# 主界面
|
||||
col1, col2 = st.columns([2, 1])
|
||||
|
||||
|
||||
with col1:
|
||||
st.header("研究查询")
|
||||
query = st.text_area(
|
||||
@@ -56,21 +45,7 @@ def main():
|
||||
placeholder="例如:2025年人工智能发展趋势",
|
||||
height=100
|
||||
)
|
||||
|
||||
# 预设查询示例
|
||||
st.subheader("示例查询")
|
||||
example_queries = [
|
||||
"2025年人工智能发展趋势",
|
||||
"深度学习在医疗领域的应用",
|
||||
"区块链技术的最新发展",
|
||||
"可持续能源技术趋势",
|
||||
"量子计算的发展现状"
|
||||
]
|
||||
|
||||
selected_example = st.selectbox("选择示例查询", ["自定义"] + example_queries)
|
||||
if selected_example != "自定义":
|
||||
query = selected_example
|
||||
|
||||
|
||||
with col2:
|
||||
st.header("状态信息")
|
||||
if 'agent' in st.session_state and hasattr(st.session_state.agent, 'state'):
|
||||
@@ -80,39 +55,43 @@ def main():
|
||||
st.progress(progress['progress_percentage'] / 100)
|
||||
else:
|
||||
st.info("尚未开始研究")
|
||||
|
||||
|
||||
# 执行按钮
|
||||
col1, col2, col3 = st.columns([1, 1, 1])
|
||||
with col2:
|
||||
col1_btn, col2_btn, col3_btn = st.columns([1, 1, 1])
|
||||
with col2_btn:
|
||||
start_research = st.button("开始研究", type="primary", use_container_width=True)
|
||||
|
||||
|
||||
# 验证配置
|
||||
if start_research:
|
||||
if not query.strip():
|
||||
st.error("请输入研究查询")
|
||||
return
|
||||
|
||||
if llm_provider == "openai" and not openai_key:
|
||||
st.error("请提供OpenAI API Key")
|
||||
|
||||
# 由于强制使用DeepSeek,检查相关的API密钥
|
||||
if not DEEPSEEK_API_KEY:
|
||||
st.error("请在您的配置文件(config.py)中设置DEEPSEEK_API_KEY")
|
||||
return
|
||||
|
||||
if not TAVILY_API_KEY:
|
||||
st.error("请在您的配置文件(config.py)中设置TAVILY_API_KEY")
|
||||
return
|
||||
|
||||
# 自动使用配置文件中的API密钥
|
||||
deepseek_key = DEEPSEEK_API_KEY
|
||||
tavily_key = TAVILY_API_KEY
|
||||
|
||||
|
||||
# 创建配置
|
||||
config = Config(
|
||||
deepseek_api_key=deepseek_key if llm_provider == "deepseek" else None,
|
||||
openai_api_key=openai_key if llm_provider == "openai" else None,
|
||||
deepseek_api_key=deepseek_key,
|
||||
openai_api_key=None,
|
||||
tavily_api_key=tavily_key,
|
||||
default_llm_provider=llm_provider,
|
||||
deepseek_model=model_name if llm_provider == "deepseek" else "deepseek-chat",
|
||||
openai_model=model_name if llm_provider == "openai" else "gpt-4o-mini",
|
||||
deepseek_model=model_name,
|
||||
openai_model="gpt-4o-mini", # 保留默认值以兼容
|
||||
max_reflections=max_reflections,
|
||||
max_content_length=max_content_length,
|
||||
output_dir="query_engine_streamlit_reports"
|
||||
)
|
||||
|
||||
|
||||
# 执行研究
|
||||
execute_research(query, config)
|
||||
|
||||
@@ -123,51 +102,51 @@ def execute_research(query: str, config: Config):
|
||||
# 创建进度条
|
||||
progress_bar = st.progress(0)
|
||||
status_text = st.empty()
|
||||
|
||||
|
||||
# 初始化Agent
|
||||
status_text.text("正在初始化Agent...")
|
||||
agent = DeepSearchAgent(config)
|
||||
st.session_state.agent = agent
|
||||
|
||||
|
||||
progress_bar.progress(10)
|
||||
|
||||
|
||||
# 生成报告结构
|
||||
status_text.text("正在生成报告结构...")
|
||||
agent._generate_report_structure(query)
|
||||
progress_bar.progress(20)
|
||||
|
||||
|
||||
# 处理段落
|
||||
total_paragraphs = len(agent.state.paragraphs)
|
||||
for i in range(total_paragraphs):
|
||||
status_text.text(f"正在处理段落 {i+1}/{total_paragraphs}: {agent.state.paragraphs[i].title}")
|
||||
|
||||
status_text.text(f"正在处理段落 {i + 1}/{total_paragraphs}: {agent.state.paragraphs[i].title}")
|
||||
|
||||
# 初始搜索和总结
|
||||
agent._initial_search_and_summary(i)
|
||||
progress_value = 20 + (i + 0.5) / total_paragraphs * 60
|
||||
progress_bar.progress(int(progress_value))
|
||||
|
||||
|
||||
# 反思循环
|
||||
agent._reflection_loop(i)
|
||||
agent.state.paragraphs[i].research.mark_completed()
|
||||
|
||||
|
||||
progress_value = 20 + (i + 1) / total_paragraphs * 60
|
||||
progress_bar.progress(int(progress_value))
|
||||
|
||||
|
||||
# 生成最终报告
|
||||
status_text.text("正在生成最终报告...")
|
||||
final_report = agent._generate_final_report()
|
||||
progress_bar.progress(90)
|
||||
|
||||
|
||||
# 保存报告
|
||||
status_text.text("正在保存报告...")
|
||||
agent._save_report(final_report)
|
||||
progress_bar.progress(100)
|
||||
|
||||
|
||||
status_text.text("研究完成!")
|
||||
|
||||
|
||||
# 显示结果
|
||||
display_results(agent, final_report)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"研究过程中发生错误: {str(e)}")
|
||||
|
||||
@@ -175,61 +154,41 @@ def execute_research(query: str, config: Config):
|
||||
def display_results(agent: DeepSearchAgent, final_report: str):
|
||||
"""显示研究结果"""
|
||||
st.header("研究结果")
|
||||
|
||||
# 结果标签页
|
||||
tab1, tab2, tab3 = st.tabs(["最终报告", "详细信息", "下载"])
|
||||
|
||||
|
||||
# 结果标签页(已移除下载选项)
|
||||
tab1, tab2 = st.tabs(["最终报告", "详细信息"])
|
||||
|
||||
with tab1:
|
||||
st.markdown(final_report)
|
||||
|
||||
|
||||
with tab2:
|
||||
# 段落详情
|
||||
st.subheader("段落详情")
|
||||
for i, paragraph in enumerate(agent.state.paragraphs):
|
||||
with st.expander(f"段落 {i+1}: {paragraph.title}"):
|
||||
with st.expander(f"段落 {i + 1}: {paragraph.title}"):
|
||||
st.write("**预期内容:**", paragraph.content)
|
||||
st.write("**最终内容:**", paragraph.research.latest_summary[:300] + "..."
|
||||
if len(paragraph.research.latest_summary) > 300
|
||||
else paragraph.research.latest_summary)
|
||||
st.write("**最终内容:**", paragraph.research.latest_summary[:300] + "..."
|
||||
if len(paragraph.research.latest_summary) > 300
|
||||
else paragraph.research.latest_summary)
|
||||
st.write("**搜索次数:**", paragraph.research.get_search_count())
|
||||
st.write("**反思次数:**", paragraph.research.reflection_iteration)
|
||||
|
||||
|
||||
# 搜索历史
|
||||
st.subheader("搜索历史")
|
||||
all_searches = []
|
||||
for paragraph in agent.state.paragraphs:
|
||||
all_searches.extend(paragraph.research.search_history)
|
||||
|
||||
|
||||
if all_searches:
|
||||
for i, search in enumerate(all_searches):
|
||||
with st.expander(f"搜索 {i+1}: {search.query}"):
|
||||
with st.expander(f"搜索 {i + 1}: {search.query}"):
|
||||
st.write("**URL:**", search.url)
|
||||
st.write("**标题:**", search.title)
|
||||
st.write("**内容预览:**", search.content[:200] + "..." if len(search.content) > 200 else search.content)
|
||||
st.write("**内容预览:**",
|
||||
search.content[:200] + "..." if len(search.content) > 200 else search.content)
|
||||
if search.score:
|
||||
st.write("**相关度评分:**", search.score)
|
||||
|
||||
with tab3:
|
||||
# 下载选项
|
||||
st.subheader("下载报告")
|
||||
|
||||
# Markdown下载
|
||||
st.download_button(
|
||||
label="下载Markdown报告",
|
||||
data=final_report,
|
||||
file_name=f"deep_search_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md",
|
||||
mime="text/markdown"
|
||||
)
|
||||
|
||||
# JSON状态下载
|
||||
state_json = agent.state.to_json()
|
||||
st.download_button(
|
||||
label="下载状态文件",
|
||||
data=state_json,
|
||||
file_name=f"deep_search_state_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
|
||||
mime="application/json"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
Reference in New Issue
Block a user