fix(sentiment_analyzer): fix type warning from pyright

This commit is contained in:
ghmark675
2025-11-10 19:14:19 +08:00
committed by 666ghj
parent aa11c529c8
commit 92482a416e

View File

@@ -142,6 +142,7 @@ class WeiboMultilingualSentimentAnalyzer:
"""Select the best available torch device."""
if not TORCH_AVAILABLE:
return None
assert torch is not None
if torch.cuda.is_available():
return torch.device("cuda")
mps_backend = getattr(torch.backends, "mps", None)
@@ -177,6 +178,8 @@ class WeiboMultilingualSentimentAnalyzer:
try:
print("正在加载多语言情感分析模型...")
assert AutoTokenizer is not None
assert AutoModelForSequenceClassification is not None
# 使用多语言情感分析模型
model_name = "tabularisai/multilingual-sentiment-analysis"
@@ -300,7 +303,7 @@ class WeiboMultilingualSentimentAnalyzer:
error_message="输入文本为空或无效内容",
analysis_performed=False,
)
assert self.tokenizer is not None
# 分词编码
inputs = self.tokenizer(
processed_text,
@@ -314,11 +317,13 @@ class WeiboMultilingualSentimentAnalyzer:
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# 预测
assert torch is not None
assert self.model is not None
with torch.no_grad():
outputs = self.model(**inputs)
logits = outputs.logits
probabilities = torch.softmax(logits, dim=1)
prediction = torch.argmax(probabilities, dim=1).item()
prediction = int(torch.argmax(probabilities, dim=1).item())
# 构建结果
confidence = probabilities[0][prediction].item()