更新mini支持讯飞语音识别
This commit is contained in:
@@ -0,0 +1,199 @@
|
||||
"""
|
||||
MINI_XUNFEI
|
||||
|
||||
MicroPython library for MINI_XUNFEI (ASR, LLM)
|
||||
=======================================================
|
||||
@dahanzimin From the Mixly Team
|
||||
|
||||
Streaming voice dictation (ASR) WebAPI documentation: https://www.xfyun.cn/doc/asr/voicedictation/API.html
|
||||
Large language model (Spark4.0 Ultra) WebAPI documentation: https://www.xfyun.cn/doc/spark/Web.html
|
||||
"""
|
||||
import time
|
||||
import hmac
|
||||
import json
|
||||
import hashlib
|
||||
import rtctime
|
||||
import websocket
|
||||
from mixgo_mini import onboard_bot
|
||||
from base64 import b64decode, b64encode
|
||||
from urllib import urlencode, urlparse
|
||||
|
||||
class Ws_Param:
    """Credentials and endpoint shared by the iFlytek WebSocket clients.

    Stores the application id, API key, API secret and the service URL, and
    builds the signed URL used to open the WebSocket connection.
    """

    def __init__(self, APPID, APIKey, APISecret, Spark_url):
        """Remember the credentials and keep a parsed copy of the URL.

        Args:
            APPID: Application id sent in request payloads.
            APIKey: API key embedded in the authorization string.
            APISecret: Secret used as the HMAC key when signing.
            Spark_url: WebSocket endpoint URL of the service.
        """
        self.APPID = APPID
        self.APIKey = APIKey
        self.APISecret = APISecret
        self.url = Spark_url
        # Parsed once so that host and path are available when signing.
        self.urlparse = urlparse(Spark_url)

    def create_url(self):
        """Return the endpoint URL with the signed query string appended.

        The string "host / date / request-line" is signed with HMAC-SHA256
        using the API secret; the base64 signature is wrapped in an
        authorization string which is itself base64-encoded and sent as a
        query parameter together with the date and host.

        Returns:
            The URL to connect to, as a string.
        """
        # presumably an RFC 1123 formatted timestamp -- confirm in rtctime
        stamp = rtctime.rfc1123_time()
        host = self.urlparse.netloc

        to_sign = "\n".join((
            "host: " + host,
            "date: " + stamp,
            "GET " + self.urlparse.path + " HTTP/1.1",
        ))
        digest = hmac.new(
            self.APISecret.encode('utf-8'),
            to_sign.encode('utf-8'),
            digestmod=hashlib.sha256,
        ).digest()
        signature = b64encode(digest).decode('utf-8')

        auth_plain = 'api_key="{}", algorithm="hmac-sha256", headers="host date request-line", signature="{}"'.format(
            self.APIKey, signature
        )
        query = {
            "authorization": b64encode(auth_plain.encode('utf-8')).decode('utf-8'),
            "date": stamp,
            "host": host,
        }
        return self.url + '?' + urlencode(query)
|
||||
|
||||
# Speech dictation (ASR)
|
||||
class ASR_WebSocket(Ws_Param):
    """Streaming speech-dictation (ASR) client.

    Reads PCM audio from ``onboard_bot``, sends it over a WebSocket as
    base64-encoded JSON frames and concatenates the recognised text returned
    by the server.
    """

    def __init__(self, APPID, APIKey, APISecret, url='ws://iat-api.xfyun.cn/v2/iat'):
        """Store the credentials and the default recognition parameters.

        Args:
            APPID: Application id sent with the first audio frame.
            APIKey: API key used when signing the connection URL.
            APISecret: API secret used when signing the connection URL.
            url: WebSocket endpoint of the dictation service.
        """
        super().__init__(APPID, APIKey, APISecret, url)
        self.ws = None
        # "business" section sent once, together with the first audio frame.
        self.business = {
            "domain": "iat",
            "language": "zh_cn",
            "accent": "mandarin",
            "vinfo": 1,
            "vad_eos": 1000,
            "nbest": 1,
            "wbest": 1,
        }

    def connect(self):
        """Open the WebSocket with a freshly signed URL and set its timeout."""
        self.ws = websocket.connect(self.create_url())
        self.ws.settimeout(2000)

    def _close(self):
        """Close the current WebSocket, if any, and forget it."""
        ws, self.ws = self.ws, None
        if ws is None:
            return
        try:
            ws.close()
        except Exception:
            # Best-effort cleanup: a failing close must not hide the result
            # (or the original error) of the run that is finishing.
            pass

    def _frame(self, status, buf):
        """Build the "data" section for one audio frame.

        Args:
            status: Frame status; this class sends 0 for the first frame,
                1 for following frames and 2 for the final frame.
            buf: Raw audio bytes; they are base64-encoded into the frame.

        Returns:
            A dict ready to be placed under the "data" key of a request.
        """
        return {
            "status": status,
            "format": "audio/L16;rate=8000",
            "audio": str(b64encode(buf), 'utf-8'),
            "encoding": "raw",
        }

    def on_message(self, message):
        """Decode one server message.

        Args:
            message: JSON text received from the WebSocket.

        Returns:
            A tuple ``(text, more)``: the text carried by this message and
            ``False`` once the server reports status 2, ``True`` otherwise.

        Raises:
            AttributeError: If the message carries a non-zero error code.
        """
        msg = json.loads(message)
        code = msg["code"]
        if code != 0:
            raise AttributeError("On message sid:%s call error:%s code is:%s" % (msg["sid"], msg["message"], code))

        data = msg["data"]
        words = []
        for segment in data["result"]["ws"]:
            for candidate in segment["cw"]:
                words.append(candidate["w"])
        return "".join(words), data["status"] != 2

    def receive_messages(self):
        """Read messages until the final one and return the joined text."""
        parts = []
        while True:
            text, more = self.on_message(self.ws.recv())
            parts.append(text)
            if not more:
                break
        return "".join(parts)

    def run(self, seconds=3, ibuf=1600, timeout=2000):
        """Record audio, stream it to the service and return the transcript.

        Args:
            seconds: Recording length; the total number of bytes to send is
                ``ibuf * seconds * 10``.
            ibuf: Number of bytes read and sent per frame.
            timeout: Maximum time in milliseconds to wait for new PCM data
                before giving up.

        Returns:
            The recognised text, or ``None`` when any step failed (the error
            is printed, not raised).
        """
        try:
            state = 0
            self.connect()
            last_read = time.ticks_ms()
            # The original author notes one read is ~100 ms, hence 10 per second.
            remaining = int(ibuf * seconds * 10)
            onboard_bot.pcm_en(True)  # enable PCM capture
            while remaining > 0:
                if onboard_bot.pcm_any():
                    remaining -= ibuf
                    last_read = time.ticks_ms()
                    buf = onboard_bot.pcm_read(ibuf)
                    if state == 0:
                        # First frame also carries the app id and parameters.
                        d = {"common": {"app_id": self.APPID}, "business": self.business, "data": self._frame(state, buf)}
                        state = 1
                    else:
                        # Following frames carry audio only.
                        d = {"data": self._frame(state, buf)}
                    self.ws.send(json.dumps(d))
                # Give up when no PCM data has arrived for too long.
                if time.ticks_diff(time.ticks_ms(), last_read) > timeout:
                    raise OSError("Timeout pcm read error")
            # Final frame tells the server that the audio is complete.
            self.ws.send(json.dumps({"data": self._frame(2, b'\x00')}))
            onboard_bot.pcm_en(False)  # disable PCM capture
            return self.receive_messages()
        except Exception as e:
            onboard_bot.pcm_en(False)  # disable PCM capture
            print("run:%s" % (e))
            return None
        finally:
            # Close the socket on every path; previously it leaked on errors.
            self._close()
|
||||
|
||||
# Large language model (LLM)
|
||||
class LLM_WebSocket(Ws_Param):
    """Chat client for the large-language-model WebSocket service.

    Keeps a message history, sends it with every question and concatenates
    the streamed answer.
    """

    def __init__(self, APPID, APIKey, APISecret, answers=50, url='ws://spark-api.xf-yun.com/v4.0/chat'):
        """Store the credentials and seed the history with a system prompt.

        Args:
            APPID: Application id sent in the request header.
            APIKey: API key used when signing the connection URL.
            APISecret: API secret used when signing the connection URL.
            answers: Maximum number of history entries kept after a request.
            url: WebSocket endpoint of the chat service.
        """
        super().__init__(APPID, APIKey, APISecret, url)
        self.ws = None
        self.answers = answers
        self._messages = [{
            "role": "system",
            "content": "你是知识渊博的助理,习惯简短表达"
        }]

    def connect(self):
        """Open the WebSocket with a freshly signed URL and set its timeout."""
        self.ws = websocket.connect(self.create_url())
        self.ws.settimeout(1000)

    def _close(self):
        """Close the current WebSocket, if any, and forget it."""
        ws, self.ws = self.ws, None
        if ws is None:
            return
        try:
            ws.close()
        except Exception:
            # Best-effort cleanup: a failing close must not hide the result
            # (or the original error) of the run that is finishing.
            pass

    def _params(self, domain):
        """Send the chat request (parameters plus the whole history).

        Args:
            domain: Value placed in the request's "domain" field.
        """
        d = {
            "header": {"app_id": self.APPID},
            "parameter": {
                "chat": {
                    "domain": domain,
                    "temperature": 0.8,
                    "max_tokens": 2048,
                    "top_k": 5,
                    "auditing": "default"
                }
            },
            "payload": {
                "message": {
                    "text": self._messages
                }
            }
        }
        self.ws.send(json.dumps(d))

    def empty_history(self):
        """Discard the whole history, including the system prompt."""
        self._messages = []

    def add_history(self, role, content):
        """Append one entry with the given role and content to the history."""
        self._messages.append({
            "role": role,
            "content": content
        })

    def on_message(self, message):
        """Decode one server message.

        Args:
            message: JSON text received from the WebSocket.

        Returns:
            A tuple ``(text, more)``: the text carried by this message and
            ``False`` once the server reports status 2, ``True`` otherwise.

        Raises:
            AttributeError: If the message header carries a non-zero code.
        """
        msg = json.loads(message)
        code = msg['header']['code']
        if code != 0:
            raise AttributeError("On message sid:%s code is:%s" % (msg["header"]["sid"], code))

        choices = msg["payload"]["choices"]
        return choices["text"][0]["content"], choices["status"] != 2

    def receive_messages(self):
        """Read messages until the final one and return the joined text."""
        parts = []
        while True:
            text, more = self.on_message(self.ws.recv())
            parts.append(text)
            if not more:
                break
        return "".join(parts)

    def run(self, question, domain="4.0Ultra"):
        """Ask one question and return the complete answer.

        The question is appended to the history as a "user" entry before the
        request is sent. The answer is not added to the history here; use
        ``add_history`` for that.

        Args:
            question: Text of the user question.
            domain: Value placed in the request's "domain" field.

        Returns:
            The answer text, or ``None`` when any step failed (the error is
            printed, not raised).
        """
        try:
            self.connect()
            self.add_history("user", question)
            self._params(domain)
            # Trim after sending: drop the oldest entries until at most
            # ``answers`` remain.
            while self.answers < len(self._messages):
                del self._messages[0]
            return self.receive_messages()
        except Exception as e:
            print("run:%s" % (e))
            return None
        finally:
            # Close the socket on every path; previously it leaked on errors.
            self._close()
|
||||
Reference in New Issue
Block a user