diff --git a/src/server.go b/src/server.go index 98b0fd8..52eccca 100644 --- a/src/server.go +++ b/src/server.go @@ -1,7 +1,6 @@ package main import ( - "bytes" "crypto/tls" "encoding/json" "fmt" @@ -13,7 +12,6 @@ import ( "net/url" "os" "path/filepath" - "regexp" "strconv" ) @@ -84,67 +82,33 @@ func diyBadRequest(c *gin.Context, code int, errorMessage string) { // registerProxyRoute 反向代理路由 func registerProxyRoute(r *gin.Engine, routePath, targetURL string) { - target, err := url.Parse(targetURL) - if err != nil { - panic(err) + target, _ := url.Parse(targetURL) + isModMsgEnabled := configFile.IsModMsg && targetURL == completionUrl + director := func(req *http.Request) { + req.URL.Host = target.Host + req.Host = target.Host + req.URL.Scheme = "https" } proxy := &httputil.ReverseProxy{ - Director: func(req *http.Request) { - req.URL.Host = target.Host - req.URL.Scheme = "http" - req.Host = target.Host - }, - ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { - }, + Director: director, + ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {}, } r.Any(routePath, func(c *gin.Context) { - proxy.Director = func(req *http.Request) { - req.URL.Host = target.Host - req.URL.Scheme = "https" - req.Host = target.Host + if isModMsgEnabled { + modProxyResp(c) + } else { + proxy.ServeHTTP(c.Writer, c.Request) } - if configFile.IsModMsg { - modProxyResp(c, proxy) - } - proxy.ServeHTTP(c.Writer, c.Request) - - c.Writer.Header().Set("X-Forwarded-Proto", "http") }) } -// 修改代码提示内容 -func modProxyResp(c *gin.Context, proxy *httputil.ReverseProxy) { - // 在代理响应之前修改响应 - proxy.ModifyResponse = func(response *http.Response) error { - responseBuffer := new(bytes.Buffer) - _, readErr := responseBuffer.ReadFrom(response.Body) - if readErr != nil { - return readErr - } - responseData := responseBuffer.Bytes() - // 定义正则表达式模式 - pattern := `(.*?)data:` - // 编译正则表达式 - regex := regexp.MustCompile(pattern) - // 查找匹配项 - match := regex.FindStringSubmatch(string(responseData)) - var replacedData = "" - if len(match) >= 2 { - replacedData = match[1] - } else { - fmt.Println("No match found.") - } - // 将字符串转换为[]rune类型的字符数组 - runes := []rune(configFile.DiyMsg) - newStr := "" - // 遍历字符数组,对每个字符进行处理 - for _, r := range runes { - // 在字符前后添加内容,生成新的字符串 - newStr += fmt.Sprintf("data: {\"id\":\"cmpl-7xy1GLgssjHEubVrPyt534VRYVF0t\",\"model\":\"cushman-ml\",\"created\":1694526422,\"choices\":[{\"text\":\"%c\",\"index\":0,\"finish_reason\":null,\"logprobs\":null}]}\n", r) - } - newStr = replacedData + newStr + "data: [DONE]\n" - fmt.Println(newStr) - c.Data(response.StatusCode, response.Header.Get("Content-Type"), []byte(newStr)) - return nil +func modProxyResp(c *gin.Context) { + c.Header("Content-Type", "text/plain;charset=utf-8") + c.Header("X-Forwarded-Proto", "http") + runes := []rune(configFile.DiyMsg) + for _, r := range runes { + modStr := fmt.Sprintf("data: {\"id\":\"cmpl-7xy1GLgssjHEubVrPyt534VRYVF0t\",\"model\":\"cushman-ml\",\"created\":1694526422,\"choices\":[{\"text\":\"%c\",\"index\":0,\"finish_reason\":null,\"logprobs\":null}]}\n", r) + c.String(200, modStr) } + c.String(200, "data: [DONE]\n") }