zyx 2 years ago
parent
commit
9d316a802f

+ 3 - 63
src/main/java/com/sxtvs/api/chatgpt/controller/ChatGptController.java

@@ -1,86 +1,26 @@
 package com.sxtvs.api.chatgpt.controller;
 package com.sxtvs.api.chatgpt.controller;
 
 
-import com.fasterxml.jackson.databind.ObjectMapper;
-import com.sxtvs.api.chatgpt.dto.CompletionsParamsDto;
 import com.sxtvs.api.chatgpt.dto.CompletionsRequestDto;
 import com.sxtvs.api.chatgpt.dto.CompletionsRequestDto;
 import com.sxtvs.api.chatgpt.dto.CompletionsResponseDto;
 import com.sxtvs.api.chatgpt.dto.CompletionsResponseDto;
-import com.sxtvs.api.chatgpt.dto.GptResponse;
-import com.sxtvs.core.sls.AliyunLogger;
-import com.sxtvs.core.sls.advice.BizException;
-import lombok.Cleanup;
+import com.sxtvs.api.chatgpt.service.ChatGptService;
 import lombok.SneakyThrows;
 import lombok.SneakyThrows;
-import org.apache.commons.io.IOUtils;
-import org.apache.hc.client5.http.classic.HttpClient;
-import org.apache.hc.client5.http.fluent.Request;
-import org.apache.hc.client5.http.impl.classic.HttpClientBuilder;
-import org.apache.hc.core5.http.HttpResponse;
-import org.apache.hc.core5.http.HttpStatus;
-import org.apache.http.HttpEntity;
-import org.apache.http.client.methods.CloseableHttpResponse;
-import org.apache.http.client.methods.HttpPost;
-import org.apache.http.entity.ContentType;
-import org.apache.http.entity.StringEntity;
-import org.apache.http.impl.client.CloseableHttpClient;
-import org.apache.http.impl.client.HttpClients;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.web.bind.annotation.RequestBody;
 import org.springframework.web.bind.annotation.RequestBody;
 import org.springframework.web.bind.annotation.RequestMapping;
 import org.springframework.web.bind.annotation.RequestMapping;
 import org.springframework.web.bind.annotation.RestController;
 import org.springframework.web.bind.annotation.RestController;
 
 
-import java.nio.charset.StandardCharsets;
-import java.util.concurrent.locks.Lock;
-import java.util.concurrent.locks.ReentrantLock;
-import java.util.stream.Collectors;
-
 @RestController
 @RestController
 public class ChatGptController {
 public class ChatGptController {
 
 
     @Autowired
     @Autowired
-    private ObjectMapper objectMapper;
-
-    @Autowired
-    private AliyunLogger logger;
+    private ChatGptService chatGptService;
 
 
-    private final CloseableHttpClient client = HttpClients.createDefault();
-
-    private final Lock lock = new ReentrantLock();
 
 
     @RequestMapping("completions")
     @RequestMapping("completions")
     @SneakyThrows
     @SneakyThrows
     public CompletionsResponseDto completions(@RequestBody CompletionsRequestDto dto) {
     public CompletionsResponseDto completions(@RequestBody CompletionsRequestDto dto) {
 
 
-
-        var paramsDto = new CompletionsParamsDto(dto.getPrompt()
-                .stream()
-                .map(CompletionsRequestDto.Prompt::getText)
-                .collect(Collectors.joining("\n")));
-        var params = objectMapper.writeValueAsString(paramsDto);
-        logger.info("request", params, "dto", objectMapper.writeValueAsString(dto));
-        lock.lock();
-        String result;
-        try {
-            var httpPost = new HttpPost("https://api.openai.com/v1/completions");
-            httpPost.setHeader("Authorization", "Bearer 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");
-            httpPost.setEntity(new StringEntity(params, ContentType.APPLICATION_JSON));
-            @Cleanup
-            var response = client.execute(httpPost);
-            @Cleanup
-            var content = response.getEntity().getContent();
-            result = IOUtils.toString(content, StandardCharsets.UTF_8);
-            logger.info("response", result);
-            if (response.getStatusLine().getStatusCode() >= HttpStatus.SC_REDIRECTION) {
-                logger.error("error", result);
-                throw new BizException(1000, "服务器走丢了 请重试");
-            }
-        } finally {
-            lock.unlock();
-        }
-        var gptResponse = objectMapper.readValue(result, GptResponse.class);
-
-        var completionsResponseDto = new CompletionsResponseDto();
-        var text = gptResponse.getChoices().stream().map(GptResponse.ChoicesDTO::getText).findFirst().orElse("").trim();
-        completionsResponseDto.setResult(text);
-        return completionsResponseDto;
+        return  chatGptService.completions(dto);
     }
     }
 
 
 }
 }

+ 86 - 0
src/main/java/com/sxtvs/api/chatgpt/service/ChatGptService.java

@@ -0,0 +1,86 @@
+package com.sxtvs.api.chatgpt.service;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.sxtvs.api.chatgpt.dto.*;
+import com.sxtvs.core.sls.AliyunLogger;
+import com.sxtvs.core.sls.advice.BizException;
+import lombok.Cleanup;
+import lombok.SneakyThrows;
+import org.apache.commons.io.IOUtils;
+import org.apache.hc.core5.http.HttpStatus;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.entity.ContentType;
+import org.apache.http.entity.StringEntity;
+import org.apache.http.impl.client.CloseableHttpClient;
+import org.apache.http.impl.client.HttpClients;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.stereotype.Service;
+import org.springframework.web.bind.annotation.RequestBody;
+
+import java.nio.charset.StandardCharsets;
+import java.util.HashMap;
+import java.util.List;
+import java.util.concurrent.SynchronousQueue;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
+import java.util.stream.Collectors;
+
+@Service
+public class ChatGptService {
+
+    @Autowired
+    private AliyunLogger logger;
+
+    @Autowired
+    private ObjectMapper objectMapper;
+
+    private final CloseableHttpClient client = HttpClients.createDefault();
+
+    private final SynchronousQueue<String> queue = new SynchronousQueue<>();
+
+    {
+        try {
+            queue.put("Bearer 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");
+            queue.put("Bearer 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");
+        } catch (InterruptedException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+
+    @SneakyThrows
+    public CompletionsResponseDto completions(@RequestBody CompletionsRequestDto dto) {
+        var paramsDto = new CompletionsParamsDto(dto.getPrompt()
+                .stream()
+                .map(CompletionsRequestDto.Prompt::getText)
+                .collect(Collectors.joining("\n")));
+        var params = objectMapper.writeValueAsString(paramsDto);
+        logger.info("request", params, "dto", objectMapper.writeValueAsString(dto));
+        var token = queue.poll();
+        String result;
+        try {
+            var httpPost = new HttpPost("https://api.openai.com/v1/completions");
+            httpPost.setHeader("Authorization", token);
+            httpPost.setEntity(new StringEntity(params, ContentType.APPLICATION_JSON));
+            @Cleanup
+            var response = client.execute(httpPost);
+            @Cleanup
+            var content = response.getEntity().getContent();
+            result = IOUtils.toString(content, StandardCharsets.UTF_8);
+            logger.info("response", result);
+            if (response.getStatusLine().getStatusCode() >= HttpStatus.SC_REDIRECTION) {
+                logger.error("error", result);
+                throw new BizException(1000, "服务器走丢了 请重试");
+            }
+        } finally {
+            queue.put(token);
+        }
+
+        var gptResponse = objectMapper.readValue(result, GptResponse.class);
+
+        var completionsResponseDto = new CompletionsResponseDto();
+        var text = gptResponse.getChoices().stream().map(GptResponse.ChoicesDTO::getText).findFirst().orElse("").trim();
+        completionsResponseDto.setResult(text);
+        return completionsResponseDto;
+    }
+}