zyx 2 жил өмнө
parent
commit
ba7ba6880b

+ 3 - 0
src/main/java/com/sxtvs/ChatgptApplication.java

@@ -3,6 +3,9 @@ package com.sxtvs;
 import org.springframework.boot.SpringApplication;
 import org.springframework.boot.autoconfigure.SpringBootApplication;
 
+import java.net.http.HttpClient;
+import java.net.http.HttpRequest;
+
 @SpringBootApplication
 public class ChatgptApplication {
 

+ 25 - 13
src/main/java/com/sxtvs/api/chatgpt/controller/ChatGptController.java

@@ -28,6 +28,8 @@ import org.springframework.web.bind.annotation.RequestMapping;
 import org.springframework.web.bind.annotation.RestController;
 
 import java.nio.charset.StandardCharsets;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
 import java.util.stream.Collectors;
 
 @RestController
@@ -41,27 +43,37 @@ public class ChatGptController {
 
     private final CloseableHttpClient client = HttpClients.createDefault();
 
-    @SneakyThrows
+    private final Lock lock = new ReentrantLock();
+
     @RequestMapping("completions")
+    @SneakyThrows
     public CompletionsResponseDto completions(@RequestBody CompletionsRequestDto dto) {
+
+
         var paramsDto = new CompletionsParamsDto(dto.getPrompt()
                 .stream()
                 .map(CompletionsRequestDto.Prompt::getText)
                 .collect(Collectors.joining("\n")));
         var params = objectMapper.writeValueAsString(paramsDto);
         logger.info("request", params, "dto", objectMapper.writeValueAsString(dto));
-
-        var httpPost = new HttpPost("https://api.openai.com/v1/completions");
-        httpPost.setHeader("Authorization", "Bearer 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");
-        httpPost.setEntity(new StringEntity(params, ContentType.APPLICATION_JSON));
-        var response = client.execute(httpPost);
-        @Cleanup
-        var content = response.getEntity().getContent();
-        var result = IOUtils.toString(content, StandardCharsets.UTF_8);
-        logger.info("response", result);
-        if (response.getStatusLine().getStatusCode() >= HttpStatus.SC_REDIRECTION) {
-            logger.error("error", result);
-            throw new BizException(1000, "服务器走丢了 请重试");
+        lock.lock();
+        String result;
+        try {
+            var httpPost = new HttpPost("https://api.openai.com/v1/completions");
+            httpPost.setHeader("Authorization", "Bearer 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");
+            httpPost.setEntity(new StringEntity(params, ContentType.APPLICATION_JSON));
+            @Cleanup
+            var response = client.execute(httpPost);
+            @Cleanup
+            var content = response.getEntity().getContent();
+            result = IOUtils.toString(content, StandardCharsets.UTF_8);
+            logger.info("response", result);
+            if (response.getStatusLine().getStatusCode() >= HttpStatus.SC_REDIRECTION) {
+                logger.error("error", result);
+                throw new BizException(1000, "服务器走丢了 请重试");
+            }
+        } finally {
+            lock.unlock();
         }
         var gptResponse = objectMapper.readValue(result, GptResponse.class);