zyx 2 年 前
コミット
4458daaa57

+ 28 - 14
src/main/java/com/sxtvs/api/chatgpt/service/ChatGptService.java

@@ -39,6 +39,7 @@ public class ChatGptService {
     private final CloseableHttpClient client = HttpClients.createDefault();
     private final CloseableHttpClient client = HttpClients.createDefault();
 
 
     private final ArrayBlockingQueue<String> queue = new ArrayBlockingQueue<>(1);
     private final ArrayBlockingQueue<String> queue = new ArrayBlockingQueue<>(1);
+
     {
     {
         try {
         try {
 //            queue.put("Bearer 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");
 //            queue.put("Bearer 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");
@@ -48,35 +49,48 @@ public class ChatGptService {
         }
         }
     }
     }
 
 
-
     @SneakyThrows
     @SneakyThrows
-    public CompletionsResponseDto completions(@RequestBody CompletionsRequestDto dto) {
-        var paramsDto = new CompletionsParamsDto(dto.getPrompt()
-                .stream()
-                .map(CompletionsRequestDto.Prompt::getText)
-                .collect(Collectors.joining("\n")));
-        var params = objectMapper.writeValueAsString(paramsDto);
-        logger.info("request", params, "dto", objectMapper.writeValueAsString(dto));
-        var token = queue.poll(1, TimeUnit.DAYS);
-        String result;
-        try {
+    private String request(String params) {
+        while (true) {
+//            var token = queue.poll(1, TimeUnit.DAYS);
+            String result;
+//            try {
             var httpPost = new HttpPost("https://api.openai.com/v1/completions");
             var httpPost = new HttpPost("https://api.openai.com/v1/completions");
-            httpPost.setHeader("Authorization", token);
+            httpPost.setHeader("Authorization", "Bearer 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");
             httpPost.setEntity(new StringEntity(params, ContentType.APPLICATION_JSON));
             httpPost.setEntity(new StringEntity(params, ContentType.APPLICATION_JSON));
             @Cleanup
             @Cleanup
             var response = client.execute(httpPost);
             var response = client.execute(httpPost);
             @Cleanup
             @Cleanup
             var content = response.getEntity().getContent();
             var content = response.getEntity().getContent();
             result = IOUtils.toString(content, StandardCharsets.UTF_8);
             result = IOUtils.toString(content, StandardCharsets.UTF_8);
+            if (result.contains("The server had an error while processing your request")) {
+                continue;
+            }
             logger.info("response", result);
             logger.info("response", result);
             if (response.getStatusLine().getStatusCode() >= HttpStatus.SC_REDIRECTION) {
             if (response.getStatusLine().getStatusCode() >= HttpStatus.SC_REDIRECTION) {
                 logger.error("error", result);
                 logger.error("error", result);
                 throw new BizException(1000, "服务器走丢了 请重试");
                 throw new BizException(1000, "服务器走丢了 请重试");
             }
             }
-        } finally {
-            queue.put(token);
+            return result;
+//            } finally {
+//                queue.put(token);
+//            }
         }
         }
 
 
+    }
+
+
+    @SneakyThrows
+    public CompletionsResponseDto completions(@RequestBody CompletionsRequestDto dto) {
+        var paramsDto = new CompletionsParamsDto(dto.getPrompt()
+                .stream()
+                .map(CompletionsRequestDto.Prompt::getText)
+                .collect(Collectors.joining("\n")));
+        var params = objectMapper.writeValueAsString(paramsDto);
+        logger.info("request", params, "dto", objectMapper.writeValueAsString(dto));
+
+        var result = request(params);
+
         var gptResponse = objectMapper.readValue(result, GptResponse.class);
         var gptResponse = objectMapper.readValue(result, GptResponse.class);
 
 
         var completionsResponseDto = new CompletionsResponseDto();
         var completionsResponseDto = new CompletionsResponseDto();

+ 1 - 8
src/test/java/com/sxtvs/ChatgptApplicationTests.java

@@ -1,19 +1,12 @@
 package com.sxtvs;
 package com.sxtvs;
 
 
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.Test;
-import org.springframework.boot.test.context.SpringBootTest;
-
-import java.util.concurrent.ArrayBlockingQueue;
-import java.util.concurrent.TimeUnit;
 
 
 
 
 class ChatgptApplicationTests {
 class ChatgptApplicationTests {
 
 
-//    @Test
+    @Test
     void contextLoads() throws InterruptedException {
     void contextLoads() throws InterruptedException {
-//        ArrayBlockingQueue<String> queue = new ArrayBlockingQueue<>(1);
-//        var poll = queue.poll(1, TimeUnit.DAYS);
-//        System.out.println(poll);
     }
     }
 
 
 }
 }