@Service
public class AiServiceImpl implements AiService {
private final List<RoleContent> historyList = new ArrayList<>();
private StringBuilder totalAnswer = new StringBuilder();
private static final Gson gson = new Gson();
@Value("${ai.hostUrl}")
private String hostUrl;
@Value("${ai.domain}")
private String domain;
@Value("${ai.appid}")
private String appid;
@Value("${ai.apiSecret}")
private String apiSecret;
@Value("${ai.apiKey}")
private String apiKey;
@Override
public String getAnswer(String question) {
if (question == null || question.trim().isEmpty()) {
return "请输入有效问题";
}
try {
String authUrl = getAuthUrl(hostUrl, apiKey, apiSecret);
OkHttpClient client = new OkHttpClient.Builder().build();
String url = authUrl.toString().replace("http://", "ws://").replace("https://", "wss://");
Request request = new Request.Builder().url(url).build();
CountDownLatch latch = new CountDownLatch(1);
WebSocket webSocket = client.newWebSocket(request, new AiWebSocketListener(latch, question));
boolean completed = latch.await(30, TimeUnit.SECONDS);
if (!completed) {
webSocket.close(1000, "请求超时");
return "请求超时,请重试";
}
log.info("AI 回答:{}", totalAnswer);
return totalAnswer.toString();
} catch (Exception e) {
e.printStackTrace();
return "处理请求时出错:" + e.getMessage();
}
}
class AiWebSocketListener extends WebSocketListener {
private final CountDownLatch latch;
private final String question;
private final AtomicBoolean isResponseComplete = new AtomicBoolean(false);
public AiWebSocketListener(CountDownLatch latch, String question) {
this.latch = latch;
this.question = question;
}
@Override
public void onOpen(WebSocket webSocket, Response response) {
totalAnswer.setLength(0);
super.onOpen(webSocket, response);
sendRequest(webSocket, question);
}
@Override
public void onMessage(WebSocket webSocket, String text) {
try {
JsonParse myJsonParse = gson.fromJson(text, JsonParse.class);
if (myJsonParse.header.code != 0) {
System.err.println("发生错误,错误码为:" + myJsonParse.header.code);
webSocket.close(1000, "Error code: " + myJsonParse.header.code);
latch.countDown();
return;
}
List<Text> textList = myJsonParse.payload.choices.text;
for (Text temp : textList) {
totalAnswer.append(temp.content);
}
if (myJsonParse.header.status == 2) {
addToHistory(question, totalAnswer.toString());
isResponseComplete.set(true);
latch.countDown();
}
} catch (Exception e) {
e.printStackTrace();
latch.countDown();
}
}
@Override
public void onFailure(WebSocket webSocket, Throwable t, Response response) {
super.onFailure(webSocket, t, response);
System.err.println("WebSocket 连接失败");
if (response != null) {
System.err.println("错误码:" + response.code());
}
t.printStackTrace();
latch.countDown();
}
@Override
public void onClosed(WebSocket webSocket, int code, String reason) {
super.onClosed(webSocket, code, reason);
if (!isResponseComplete.get()) {
latch.countDown();
}
}
private void sendRequest(WebSocket webSocket, String question) {
try {
JSONObject requestJson = new JSONObject();
JSONObject header = new JSONObject();
header.put("app_id", appid);
header.put("uid", UUID.randomUUID().toString().substring(0, 10));
JSONObject parameter = new JSONObject();
JSONObject chat = new JSONObject();
chat.put("domain", domain);
chat.put("temperature", 0.5);
chat.put("max_tokens", 4096);
parameter.put("chat", chat);
JSONObject payload = new JSONObject();
JSONObject message = new JSONObject();
JSONArray text = new JSONArray();
RoleContent roleSystem = new RoleContent();
roleSystem.setRole("system");
roleSystem.setContent("你是一个智能客服,接下来你要用客服的语气和我对话");
text.add(JSON.toJSON(roleSystem));
for (RoleContent history : historyList) {
text.add(JSON.toJSON(history));
}
RoleContent userQuestion = new RoleContent();
userQuestion.setRole("user");
userQuestion.setContent(question);
text.add(JSON.toJSON(userQuestion));
message.put("text", text);
payload.put("message", message);
requestJson.put("header", header);
requestJson.put("parameter", parameter);
requestJson.put("payload", payload);
webSocket.send(requestJson.toString());
} catch (Exception e) {
e.printStackTrace();
latch.countDown();
}
}
}
private void addToHistory(String question, String answer) {
if (historyList.size() >= 10) {
historyList.remove(0);
}
RoleContent userContent = new RoleContent();
userContent.setRole("user");
userContent.setContent(question);
historyList.add(userContent);
RoleContent aiContent = new RoleContent();
aiContent.setRole("assistant");
aiContent.setContent(answer);
historyList.add(aiContent);
}
private String getAuthUrl(String hostUrl, String apiKey, String apiSecret) throws Exception {
URL url = new URL(hostUrl);
SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US);
format.setTimeZone(TimeZone.getTimeZone("GMT"));
String date = format.format(new Date());
String preStr = "host: " + url.getHost() + "\n"
+ "date: " + date + "\n"
+ "GET " + url.getPath() + " HTTP/1.1";
Mac mac = Mac.getInstance("HmacSHA256");
SecretKeySpec spec = new SecretKeySpec(apiSecret.getBytes(StandardCharsets.UTF_8), "HmacSHA256");
mac.init(spec);
byte[] hexDigits = mac.doFinal(preStr.getBytes(StandardCharsets.UTF_8));
String sha = Base64.getEncoder().encodeToString(hexDigits);
String authorization = String.format(
"api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"",
apiKey, "hmac-sha256", "host date request-line", sha
);
HttpUrl httpUrl = Objects.requireNonNull(HttpUrl.parse("https://" + url.getHost() + url.getPath()))
.newBuilder()
.addQueryParameter("authorization", Base64.getEncoder().encodeToString(authorization.getBytes(StandardCharsets.UTF_8)))
.addQueryParameter("date", date)
.addQueryParameter("host", url.getHost())
.build();
return httpUrl.toString();
}
static class JsonParse {
Header header;
Payload payload;
}
static class Header {
int code;
int status;
String sid;
}
static class Payload {
Choices choices;
}
static class Choices {
List<Text> text;
}
static class Text {
String role;
String content;
}
static class RoleContent {
String role;
String content;
public String getRole() { return role; }
public void setRole(String role) { this.role = role; }
public String getContent() { return content; }
public void setContent(String content) { this.content = content; }
}
}