from langchain.llm import LLM
from typing import Optional, List, Mapping, Any
class CustomLLM(LLM):
    """Toy LLM whose "generation" is the first ``n`` characters of the prompt."""

    # Number of leading prompt characters returned as the generation.
    n: int = 10

    @property
    def _llm_type(self) -> str:
        """Type identifier reported to LangChain."""
        return "custom_simple"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager=None,
        **kwargs: Any,
    ) -> str:
        """Return ``prompt[:n]``, cut at the first occurrence of any stop sequence.

        Args:
            prompt: Input text; its first ``n`` characters form the output.
            stop: Optional stop sequences. As with any LangChain LLM, the
                output ends right before the earliest match; empty strings
                are ignored because they would match at position 0.
            run_manager: Accepted for LangChain compatibility; unused.
            **kwargs: Ignored.

        Returns:
            The (possibly truncated) prefix of *prompt*.
        """
        text = prompt[: self.n]
        if not stop:
            return text
        cut = len(text)
        for token in stop:
            if not token:
                continue
            idx = text.find(token)
            if idx != -1 and idx < cut:
                cut = idx
        return text[:cut]

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Parameters that distinguish this model configuration."""
        return {"n": self.n}
import json
import ssl
import time
import hashlib
import hmac
import base64
from urllib.parse import urlencode, urlparse
from datetime import datetime
from wsgiref.handlers import format_date_time
from threading import Thread
from typing import Optional, List, Mapping, Any
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
import websocket
class SparkLLM(LLM):
    """LangChain ``LLM`` backed by the iFlytek Spark chat WebSocket API.

    Each call opens a signed WebSocket connection, sends the prompt as a single
    user message and appends every streamed chunk to ``answer``. The accumulated
    text is returned once a frame with ``status == 2`` arrives or the connection
    closes. Failures are returned as an error-message string rather than raised.
    """

    appid: Optional[str] = None
    api_secret: Optional[str] = None
    api_key: Optional[str] = None
    # Spark domain: "general", "generalv2" or "generalv3". Missing or unknown
    # values fall back to "generalv3".
    model: Optional[str] = None
    # Raw text streamed back by the most recent call (reset at the start of each
    # call). Stop sequences are applied only to the value returned by _call.
    answer: str = ""

    @property
    def _llm_type(self) -> str:
        """Type identifier reported to LangChain."""
        return "SparkLLM"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Parameters that distinguish one configuration from another.

        LangChain builds the LLM's description (and hence its cache key) from
        these, so the model is included. Credentials are deliberately left out.
        """
        return {"model": self.model}

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Send *prompt* to Spark and return the complete reply.

        Args:
            prompt: Text sent as a single ``"user"`` message.
            stop: Optional stop sequences; the reply is cut right before the
                earliest occurrence of any of them.
            run_manager: Accepted for LangChain compatibility; unused.
            **kwargs: Ignored.

        Returns:
            The reply text. On failure it returns an error string instead:
            ``"Error occurred during connection: ..."`` or
            ``"Request error: ..."``.
        """
        self.answer = ""
        spark_urls = {
            "general": "wss://spark-api.xf-yun.com/v1.1/chat",
            "generalv2": "wss://spark-api.xf-yun.com/v2.1/chat",
            "generalv3": "wss://spark-api.xf-yun.com/v3.1/chat",
        }
        # Choose the domain and URL together, so the domain in the request body
        # always matches the endpoint it is sent to.
        domain = self.model if self.model in spark_urls else "generalv3"
        spark_url = spark_urls[domain]
        try:
            ws = websocket.WebSocketApp(
                self._create_url(spark_url),
                on_message=self.on_message,
                on_error=self.on_error,
                on_close=self.on_close,
                on_open=lambda ws: self.run(ws, prompt, domain),
            )
            ws.appid = self.appid
            ws.question = prompt
            ws.domain = domain
            # Set by on_message / on_error when the request fails.
            ws.error = None
            # Verify the server certificate: the signed URL carries the API key.
            ws.run_forever(sslopt={"cert_reqs": ssl.CERT_REQUIRED})
        except Exception as e:
            return f"Error occurred during connection: {str(e)}"
        if ws.error:
            return ws.error
        return self._truncate_at_stop(self.answer, stop)

    def _create_url(self, spark_url: str) -> str:
        """Return *spark_url* with the HMAC-SHA256 authentication query appended.

        The request line, host and current HTTP date are signed with
        ``api_secret``. The base64-encoded authorization string, the date and
        the host are then passed as query parameters.

        Raises:
            ValueError: If any of ``appid``, ``api_key``, ``api_secret`` is unset.
        """
        if not (self.appid and self.api_key and self.api_secret):
            raise ValueError("appid, api_key and api_secret must all be set")
        parsed = urlparse(spark_url)
        host, path = parsed.netloc, parsed.path
        date = format_date_time(time.time())
        signature_origin = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
        digest = hmac.new(
            self.api_secret.encode("utf-8"),
            signature_origin.encode("utf-8"),
            digestmod=hashlib.sha256,
        ).digest()
        signature = base64.b64encode(digest).decode("utf-8")
        authorization_origin = (
            f'api_key="{self.api_key}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature}"'
        )
        authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode("utf-8")
        query = urlencode({"authorization": authorization, "date": date, "host": host})
        return f"{spark_url}?{query}"

    @staticmethod
    def _truncate_at_stop(text: str, stop: Optional[List[str]]) -> str:
        """Cut *text* right before the earliest non-empty stop sequence, if any."""
        if not stop:
            return text
        cut = len(text)
        for token in stop:
            if not token:
                continue
            idx = text.find(token)
            if idx != -1 and idx < cut:
                cut = idx
        return text[:cut]

    def on_message(self, ws, message):
        """Handle one streamed frame: append its text or record a server error."""
        data = json.loads(message)
        code = data["header"]["code"]
        if code != 0:
            print(f'Request error: {code}, {data}')
            # Remember the failure so _call reports it instead of returning "".
            ws.error = f"Request error: {code}, {data}"
            ws.close()
            return
        choices = data["payload"]["choices"]
        self.answer += choices["text"][0]["content"]
        # status 2 is treated as the final chunk of the reply.
        if choices["status"] == 2:
            ws.close()

    def on_error(self, ws, error):
        """Log a transport-level error and record it, unless an error is already recorded."""
        print(f"WebSocket Error: {error}")
        if not getattr(ws, "error", None):
            ws.error = f"Error occurred during connection: {error}"

    def on_close(self, ws, one=None, two=None):
        """No-op close handler.

        Recent websocket-client versions pass the close status code and reason
        as two extra arguments. The defaults also let versions that pass none
        call this handler.
        """

    def run(self, ws, question, domain):
        """Send the chat request as soon as the connection opens."""
        data = json.dumps(self.gen_params(ws.appid, domain, question))
        ws.send(data)

    def gen_params(self, appid, domain, question):
        """Build the Spark chat request body.

        Args:
            appid: Application id placed in the request header.
            domain: Spark domain matching the endpoint URL.
            question: Either a plain string, sent as one ``"user"`` message,
                or an iterable of ``{"role": ..., "content": ...}`` dicts
                that is sent as-is (e.g. to include conversation history).

        Returns:
            The request dict, ready for ``json.dumps``.
        """
        if isinstance(question, str):
            messages = [{"role": "user", "content": question}]
        else:
            messages = list(question)
        return {
            "header": {"app_id": appid, "uid": "user_123"},
            "parameter": {
                "chat": {
                    "domain": domain,
                    "temperature": 0.5,
                    "max_tokens": 2048,
                }
            },
            "payload": {"message": {"text": messages}},
        }
import os

# Spark credentials are read from the environment; fail early, naming every
# variable that is missing or empty.
xh_app_id = os.getenv("IFLYTEK_APP_ID")
xh_api_secret = os.getenv("IFLYTEK_API_SECRET")
xh_api_key = os.getenv("IFLYTEK_API_KEY")
_missing_env = [
    name
    for name, value in (
        ("IFLYTEK_APP_ID", xh_app_id),
        ("IFLYTEK_API_SECRET", xh_api_secret),
        ("IFLYTEK_API_KEY", xh_api_key),
    )
    if not value
]
if _missing_env:
    raise ValueError("Missing environment variables: " + ", ".join(_missing_env))

modal = "generalv3"
llm = SparkLLM(appid=xh_app_id, api_secret=xh_api_secret, api_key=xh_api_key, model=modal)
response = llm.invoke("你好,请介绍一下 LangChain")
print(response)