""" MINI_XUNFEI Micropython library for the MIC_ADC XUNFEI(ASR, LLM) ======================================================= @dahanzimin From the Mixly Team """ import time import hmac import json import hashlib import rtctime import websocket import adc_mic from base64 import b64decode, b64encode from urllib import urlencode, urlparse class Ws_Param: def __init__(self, APPID, APIKey, APISecret, Spark_url): self.APPID = APPID self.APIKey = APIKey self.APISecret = APISecret self.url = Spark_url self.urlparse = urlparse(Spark_url) def create_url(self): date = rtctime.rfc1123_time() signature_origin = "host: " + self.urlparse.netloc + "\n" signature_origin += "date: " + date + "\n" signature_origin += "GET " + self.urlparse.path + " HTTP/1.1" signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'), digestmod=hashlib.sha256).digest() signature_base64 = b64encode(signature_sha).decode('utf-8') authorization_origin = ('api_key="{}", algorithm="hmac-sha256", headers="host date request-line", signature="{}"'.format(self.APIKey, signature_base64)) authorization = b64encode(authorization_origin.encode('utf-8')).decode('utf-8') headers = {"authorization": authorization, "date": date, "host": self.urlparse.netloc} return self.url + '?' + urlencode(headers) #语音听写 class ASR_WebSocket(Ws_Param): def __init__(self, adcpin, APPID, APIKey, APISecret, url='ws://iat-api.xfyun.cn/v2/iat', sample_rate=8000): super().__init__(APPID, APIKey, APISecret, url) self.ws = None self.rate = sample_rate self.mic = adc_mic.ADCMic(adcpin, sample_rate=sample_rate) self.business = { "domain": "iat", "language": "zh_cn", "accent": "mandarin", "vad_eos": 1000, "nbest": 1, "wbest": 1, } def connect(self): self.ws = websocket.connect(self.create_url()) self.ws.settimeout(1000) def _frame(self, status, buf): if status == 0: return {"common": {"app_id": self.APPID}, "business": self.business, "data": {"status": status, "format": f"audio/L16;rate={self.rate}", "audio": str(b64encode(buf), 'utf-8'), "encoding": "raw"}} else: return {"data": {"status": status, "format": f"audio/L16;rate={self.rate}", "audio": str(b64encode(buf), 'utf-8'), "encoding": "raw"}} def on_message(self, message): result = "" msg = json.loads(message) code = msg["code"] if code != 0: raise AttributeError("%s Code:%s" % (msg["message"], code)) else: data = msg["data"]["result"]["ws"] for i in data: for w in i["cw"]: result += w["w"] if msg["data"]["status"]== 2: return result, False return result, True def receive_messages(self): msg = "" while True: t = self.on_message(self.ws.recv()) msg += t[0] if not t[1]: break return msg def run(self, seconds=3, pace=True): try: self.connect() _state = 0 ibuf = int(self.rate * 0.2) _buf = bytearray(ibuf) _size = int(ibuf * seconds * 10) #100ms/次 self.mic.start() if pace: print('[',end ="") while _size > 0: if self.mic.read_into(_buf): _size -= ibuf if pace: print('=',end ="") # 第一帧处理 if _state == 0: d = self._frame(_state, _buf) _state = 1 # 中间帧处理 else: d = self._frame(_state, _buf) self.ws.send(json.dumps(d)) # 最后一帧处理 self.mic.stop() d = self._frame(2, b'\x00') self.ws.send(json.dumps(d)) if pace: print(']') msg = self.receive_messages() return msg except Exception as e: onboard_bot.pcm_en(False) #PCM关闭 if "403 Forbidden" in str(e): raise OSError("Access denied, Please try updating clock time") else: print("Run error: %s" % (e)) finally: self.mic.stop() self.ws.close() #中英识别大模型 class IAT_WebSocket(ASR_WebSocket): def __init__(self, adcpin, APPID, APIKey, APISecret, url='ws://iat.xf-yun.com/v1', sample_rate=8000, accent="mandarin", res_id=None): super().__init__(adcpin, APPID, APIKey, APISecret, url, sample_rate) self.res_id = res_id self.business = { "domain": "slm", "language": "zh_cn", "accent": accent, "result": { "encoding": "utf8", "compress": "raw", "format": "plain" } } def _frame(self, status, buf): if status == 0: return {"header": {"status": status, "app_id": self.APPID, "res_id": self.res_id}, "parameter": {"iat": self.business}, "payload": {"audio": { "audio": str(b64encode(buf), 'utf-8'), "sample_rate": self.rate, "encoding": "raw"}}} else: return {"header": {"status": status, "app_id": self.APPID, "res_id": self.res_id}, "payload": {"audio": { "audio": str(b64encode(buf), 'utf-8'), "sample_rate": self.rate, "encoding": "raw"}}} def on_message(self, message): result = "" msg = json.loads(message) code = msg['header']["code"] if code != 0: raise AttributeError("%s Code:%s" % (msg['header']["message"], code)) else: if "payload" in msg: text = msg["payload"]["result"]["text"] data = json.loads(b64decode(text).decode())['ws'] for i in data: for w in i["cw"]: result += w["w"] if msg["header"]["status"]== 2: return result, False return result, True #大模型 class LLM_WebSocket(Ws_Param): Model_url = { "Spark Ultra-32K": ("ws://spark-api.xf-yun.com/v4.0/chat", "4.0Ultra"), "Spark Max-32K": ("ws://spark-api.xf-yun.com/chat/max-32k", "max-32k"), "Spark Max": ("ws://spark-api.xf-yun.com/v3.5/chat", "generalv3.5"), "Spark Pro-128K": (" ws://spark-api.xf-yun.com/chat/pro-128k", "pro-128k"), "Spark Pro": ("ws://spark-api.xf-yun.com/v3.1/chat", "generalv3"), "Spark Lite": ("ws://spark-api.xf-yun.com/v1.1/chat", "lite"), "Spark kjwx": ("ws://spark-openapi-n.cn-huabei-1.xf-yun.com/v1.1/chat_kjwx", "kjwx"), "Spark X1-32K": ("ws://spark-api.xf-yun.com/v1/x1", "x1"), "Spark Customize": ("ws://sparkcube-api.xf-yun.com/v1/customize", "max"), } def __init__(self, APPID, APIKey, APISecret, model='Spark Ultra-32K', system="你是知识渊博的助理,习惯简短表达", answers=50): self.ws = None self.answers = answers self._url = self.Model_url[model] if model in self.Model_url else model super().__init__(APPID, APIKey, APISecret, self._url[0]) self._function = [{}, []] #[回调函数, 功能描述] self._messages = [{"role": "system", "content": system}] def connect(self): self.ws = websocket.connect(self.create_url()) self.ws.settimeout(1000) def _params(self): d = { "header": {"app_id": self.APPID}, "parameter": { "chat": { "domain": self._url[1], "random_threshold": 0.5, "max_tokens": 2048, "auditing": "default" } }, "payload": { "message": { "text": self._messages } } } if self._function[1]: d["payload"]["functions"] = {"text": self._function[1]} self.ws.send(json.dumps(d)) def function_call(self, callback, name, description, params): """功能回调名称, 描述, ((参数名, 类型, 描述), ...)""" properties = {"type": "object", "properties":{}, "required":[]} for arg in params: if len(arg) >= 3: properties["properties"][arg[0]] = {"type": arg[1], "description": arg[2]} if arg[0] not in properties["required"]: properties["required"].append(arg[0]) else: raise AttributeError('Invalid Input , format is (name, type, description)') self._function[0][name] = callback self._function[1].append({"name": name, "description": description, "parameters": properties}) def empty_history(self): self._messages = [] def add_history(self, role, content): self._messages.append({ "role": role, "content": content }) def on_message(self, message, reas): result = "" msg = json.loads(message) code = msg['header']['code'] if code != 0: raise AttributeError("%s Code:%s" % (msg["header"]["message"], code)) else: choices = msg["payload"]["choices"] text = choices["text"][0] #推理 if "reasoning_content" in text and reas: print("reasoning: ", text["reasoning_content"]) #回调 if "tool_calls" in text: function = text['tool_calls'][0]['function'] if str(function['name']) in self._function[0] and function['arguments']: self._function[0][function['name']](json.loads(function['arguments'])) if "function_call" in text: if str(text['function_call']['name']) in self._function[0] and text['function_call']['arguments']: self._function[0][text['function_call']['name']](json.loads(text['function_call']['arguments'])) #答复 if "content" in text: result += text["content"] if choices["status"] == 2: return result, False return result, True def receive_messages(self, reas): msg = "" while True: t = self.on_message(self.ws.recv(), reas) msg += t[0] if not t[1]: break return msg def run(self, question, reas=True): try: self.connect() self.add_history("user", question) self._params() while self.answers < len(self._messages): del self._messages[0] msg = self.receive_messages(reas) return msg except Exception as e: if "403 Forbidden" in str(e): raise OSError("Access denied, Please try updating clock time") else: print("Run error: %s" % (e)) finally: self.ws.close()