fix(sentiment_analyzer): fix type warning from pyright
This commit is contained in:
@@ -142,6 +142,7 @@ class WeiboMultilingualSentimentAnalyzer:
|
||||
"""Select the best available torch device."""
|
||||
if not TORCH_AVAILABLE:
|
||||
return None
|
||||
assert torch is not None
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
mps_backend = getattr(torch.backends, "mps", None)
|
||||
@@ -177,6 +178,8 @@ class WeiboMultilingualSentimentAnalyzer:
|
||||
|
||||
try:
|
||||
print("正在加载多语言情感分析模型...")
|
||||
assert AutoTokenizer is not None
|
||||
assert AutoModelForSequenceClassification is not None
|
||||
|
||||
# 使用多语言情感分析模型
|
||||
model_name = "tabularisai/multilingual-sentiment-analysis"
|
||||
@@ -300,7 +303,7 @@ class WeiboMultilingualSentimentAnalyzer:
|
||||
error_message="输入文本为空或无效内容",
|
||||
analysis_performed=False,
|
||||
)
|
||||
|
||||
assert self.tokenizer is not None
|
||||
# 分词编码
|
||||
inputs = self.tokenizer(
|
||||
processed_text,
|
||||
@@ -314,11 +317,13 @@ class WeiboMultilingualSentimentAnalyzer:
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||
|
||||
# 预测
|
||||
assert torch is not None
|
||||
assert self.model is not None
|
||||
with torch.no_grad():
|
||||
outputs = self.model(**inputs)
|
||||
logits = outputs.logits
|
||||
probabilities = torch.softmax(logits, dim=1)
|
||||
prediction = torch.argmax(probabilities, dim=1).item()
|
||||
prediction = int(torch.argmax(probabilities, dim=1).item())
|
||||
|
||||
# 构建结果
|
||||
confidence = probabilities[0][prediction].item()
|
||||
|
||||
Reference in New Issue
Block a user