更新mini的讯飞对接方法
This commit is contained in:
@@ -5,8 +5,6 @@ Micropython library for the MINI_XUNFEI(ASR, LLM)
|
|||||||
=======================================================
|
=======================================================
|
||||||
@dahanzimin From the Mixly Team
|
@dahanzimin From the Mixly Team
|
||||||
|
|
||||||
语音听写(流式版) WebAPI 文档 https://www.xfyun.cn/doc/asr/voicedictation/API.html
|
|
||||||
大模型(Spark4.0 Ultra)WebAPI 文档 https://www.xfyun.cn/doc/spark/Web.html
|
|
||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
import hmac
|
import hmac
|
||||||
@@ -55,7 +53,7 @@ class ASR_WebSocket(Ws_Param):
|
|||||||
|
|
||||||
def connect(self):
|
def connect(self):
|
||||||
self.ws = websocket.connect(self.create_url())
|
self.ws = websocket.connect(self.create_url())
|
||||||
self.ws.settimeout(2000)
|
self.ws.settimeout(1000)
|
||||||
|
|
||||||
def _frame(self, status, buf):
|
def _frame(self, status, buf):
|
||||||
return {"status": status, "format": "audio/L16;rate=8000", "audio": str(b64encode(buf), 'utf-8'), "encoding": "raw"}
|
return {"status": status, "format": "audio/L16;rate=8000", "audio": str(b64encode(buf), 'utf-8'), "encoding": "raw"}
|
||||||
@@ -65,7 +63,7 @@ class ASR_WebSocket(Ws_Param):
|
|||||||
msg = json.loads(message)
|
msg = json.loads(message)
|
||||||
code = msg["code"]
|
code = msg["code"]
|
||||||
if code != 0:
|
if code != 0:
|
||||||
raise AttributeError("On message sid:%s call error:%s code is:%s" % (msg["sid"], msg["message"], code))
|
raise AttributeError("%s Code:%s" % (msg["message"], code))
|
||||||
else:
|
else:
|
||||||
data = msg["data"]["result"]["ws"]
|
data = msg["data"]["result"]["ws"]
|
||||||
for i in data:
|
for i in data:
|
||||||
@@ -116,29 +114,40 @@ class ASR_WebSocket(Ws_Param):
|
|||||||
return msg
|
return msg
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
onboard_bot.pcm_en(False) #PCM关闭
|
onboard_bot.pcm_en(False) #PCM关闭
|
||||||
print("run:%s" % (e))
|
print("Run error: %s" % (e))
|
||||||
|
|
||||||
#大模型
|
#大模型
|
||||||
class LLM_WebSocket(Ws_Param):
|
class LLM_WebSocket(Ws_Param):
|
||||||
def __init__(self, APPID, APIKey, APISecret, answers=50, url='ws://spark-api.xf-yun.com/v4.0/chat'):
|
Model_url = {
|
||||||
super().__init__(APPID, APIKey, APISecret, url)
|
"Spark Ultra-32K": ("ws://spark-api.xf-yun.com/v4.0/chat", "4.0Ultra"),
|
||||||
|
"Spark Max-32K": ("ws://spark-api.xf-yun.com/chat/max-32k", "max-32k"),
|
||||||
|
"Spark Max": ("ws://spark-api.xf-yun.com/v3.5/chat", "generalv3.5"),
|
||||||
|
"Spark Pro-128K": (" ws://spark-api.xf-yun.com/chat/pro-128k", "pro-128k"),
|
||||||
|
"Spark Pro": ("ws://spark-api.xf-yun.com/v3.1/chat", "generalv3"),
|
||||||
|
"Spark Lite": ("ws://spark-api.xf-yun.com/v1.1/chat", "lite"),
|
||||||
|
"Spark kjwx": ("ws://spark-openapi-n.cn-huabei-1.xf-yun.com/v1.1/chat_kjwx", "kjwx"),
|
||||||
|
"Spark X1-32K": ("ws://spark-api.xf-yun.com/v1/x1", "x1"),
|
||||||
|
"Spark Customize": ("ws://sparkcube-api.xf-yun.com/v1/customize", "max"),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, APPID, APIKey, APISecret, model='Spark Ultra-32K', system="你是知识渊博的助理,习惯简短表达", answers=50):
|
||||||
self.ws = None
|
self.ws = None
|
||||||
self.answers =answers
|
self.answers = answers
|
||||||
self._messages = [{
|
self._url = self.Model_url[model] if model in self.Model_url else model
|
||||||
"role": "system",
|
super().__init__(APPID, APIKey, APISecret, self._url[0])
|
||||||
"content": "你是知识渊博的助理,习惯简短表达"
|
self._function = [{}, []] #[回调函数, 功能描述]
|
||||||
}]
|
self._messages = [{"role": "system", "content": system}]
|
||||||
|
|
||||||
def connect(self):
|
def connect(self):
|
||||||
self.ws = websocket.connect(self.create_url())
|
self.ws = websocket.connect(self.create_url())
|
||||||
self.ws.settimeout(1000)
|
self.ws.settimeout(1000)
|
||||||
|
|
||||||
def _params(self, domain):
|
def _params(self):
|
||||||
d = {
|
d = {
|
||||||
"header": {"app_id": self.APPID},
|
"header": {"app_id": self.APPID},
|
||||||
"parameter": {
|
"parameter": {
|
||||||
"chat": {
|
"chat": {
|
||||||
"domain": domain,
|
"domain": self._url[1],
|
||||||
"temperature": 0.8,
|
"temperature": 0.8,
|
||||||
"max_tokens": 2048,
|
"max_tokens": 2048,
|
||||||
"top_k": 5,
|
"top_k": 5,
|
||||||
@@ -151,8 +160,22 @@ class LLM_WebSocket(Ws_Param):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if self._function[1]:
|
||||||
|
d["payload"]["functions"] = {"text": self._function[1]}
|
||||||
self.ws.send(json.dumps(d))
|
self.ws.send(json.dumps(d))
|
||||||
|
|
||||||
|
def function_call(self, callback, name, description, *args):
|
||||||
|
"""功能回调名称, 描述, (参数名, 类型, 描述) """
|
||||||
|
properties = {"type": "object", "properties":{}, "required":[]}
|
||||||
|
for arg in args:
|
||||||
|
if len(arg) >= 3:
|
||||||
|
properties["properties"][arg[0]] = {"type": arg[1], "description": arg[2]}
|
||||||
|
if arg[0] not in properties["required"]:
|
||||||
|
properties["required"].append(arg[0])
|
||||||
|
else:
|
||||||
|
raise AttributeError('Invalid Input , format is (name, type, description)')
|
||||||
|
self._function[0][name] = callback
|
||||||
|
self._function[1].append({"name": name, "description": description, "parameters": properties})
|
||||||
|
|
||||||
def empty_history(self):
|
def empty_history(self):
|
||||||
self._messages = []
|
self._messages = []
|
||||||
@@ -163,37 +186,47 @@ class LLM_WebSocket(Ws_Param):
|
|||||||
"content": content
|
"content": content
|
||||||
})
|
})
|
||||||
|
|
||||||
def on_message(self, message):
|
def on_message(self, message, reas):
|
||||||
result = ""
|
result = ""
|
||||||
msg = json.loads(message)
|
msg = json.loads(message)
|
||||||
code = msg['header']['code']
|
code = msg['header']['code']
|
||||||
if code != 0:
|
if code != 0:
|
||||||
raise AttributeError("On message sid:%s code is:%s" % (msg["header"]["sid"], code))
|
raise AttributeError("%s Code:%s" % (msg["header"]["message"], code))
|
||||||
else:
|
else:
|
||||||
choices = msg["payload"]["choices"]
|
choices = msg["payload"]["choices"]
|
||||||
result += choices["text"][0]["content"]
|
text = choices["text"][0]
|
||||||
if choices["status"] == 2:
|
#推理
|
||||||
return result, False
|
if "reasoning_content" in text and reas:
|
||||||
|
print("reasoning: ", text["reasoning_content"])
|
||||||
|
#回调
|
||||||
|
if "function_call" in text:
|
||||||
|
if str(text['function_call']['name']) in self._function[0] and text['function_call']['arguments']:
|
||||||
|
self._function[0][text['function_call']['name']](json.loads(text['function_call']['arguments']))
|
||||||
|
#答复
|
||||||
|
if "content" in text:
|
||||||
|
result += text["content"]
|
||||||
|
if choices["status"] == 2:
|
||||||
|
return result, False
|
||||||
return result, True
|
return result, True
|
||||||
|
|
||||||
def receive_messages(self):
|
def receive_messages(self, reas):
|
||||||
msg = ""
|
msg = ""
|
||||||
while True:
|
while True:
|
||||||
t = self.on_message(self.ws.recv())
|
t = self.on_message(self.ws.recv(), reas)
|
||||||
msg += t[0]
|
msg += t[0]
|
||||||
if not t[1]:
|
if not t[1]:
|
||||||
break
|
break
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
def run(self, question, domain="4.0Ultra"):
|
def run(self, question, reas=True):
|
||||||
try:
|
try:
|
||||||
self.connect()
|
self.connect()
|
||||||
self.add_history("user", question)
|
self.add_history("user", question)
|
||||||
self._params(domain)
|
self._params()
|
||||||
while self.answers < len(self._messages):
|
while self.answers < len(self._messages):
|
||||||
del self._messages[0]
|
del self._messages[0]
|
||||||
msg = self.receive_messages()
|
msg = self.receive_messages(reas)
|
||||||
self.ws.close()
|
self.ws.close()
|
||||||
return msg
|
return msg
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("run:%s" % (e))
|
print("Run error: %s" % (e))
|
||||||
|
|||||||
Reference in New Issue
Block a user