fix some bug

This commit is contained in:
luyoyu 2023-09-21 05:25:28 +08:00
parent fa6e9323cb
commit 6a244d0a0d

View File

@ -1,7 +1,6 @@
package main package main
import ( import (
"bytes"
"crypto/tls" "crypto/tls"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -13,7 +12,6 @@ import (
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
"regexp"
"strconv" "strconv"
) )
@ -84,67 +82,33 @@ func diyBadRequest(c *gin.Context, code int, errorMessage string) {
// registerProxyRoute 反向代理路由 // registerProxyRoute 反向代理路由
func registerProxyRoute(r *gin.Engine, routePath, targetURL string) { func registerProxyRoute(r *gin.Engine, routePath, targetURL string) {
target, err := url.Parse(targetURL) target, _ := url.Parse(targetURL)
if err != nil { isModMsgEnabled := configFile.IsModMsg && targetURL == completionUrl
panic(err) director := func(req *http.Request) {
req.URL.Host = target.Host
req.Host = target.Host
req.URL.Scheme = "https"
} }
proxy := &httputil.ReverseProxy{ proxy := &httputil.ReverseProxy{
Director: func(req *http.Request) { Director: director,
req.URL.Host = target.Host ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {},
req.URL.Scheme = "http"
req.Host = target.Host
},
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
},
} }
r.Any(routePath, func(c *gin.Context) { r.Any(routePath, func(c *gin.Context) {
proxy.Director = func(req *http.Request) { if isModMsgEnabled {
req.URL.Host = target.Host modProxyResp(c)
req.URL.Scheme = "https" } else {
req.Host = target.Host proxy.ServeHTTP(c.Writer, c.Request)
} }
if configFile.IsModMsg {
modProxyResp(c, proxy)
}
proxy.ServeHTTP(c.Writer, c.Request)
c.Writer.Header().Set("X-Forwarded-Proto", "http")
}) })
} }
// 修改代码提示内容 func modProxyResp(c *gin.Context) {
func modProxyResp(c *gin.Context, proxy *httputil.ReverseProxy) { c.Header("Content-Type", "text/plain;charset=utf-8")
// 在代理响应之前修改响应 c.Header("X-Forwarded-Proto", "http")
proxy.ModifyResponse = func(response *http.Response) error { runes := []rune(configFile.DiyMsg)
responseBuffer := new(bytes.Buffer) for _, r := range runes {
_, readErr := responseBuffer.ReadFrom(response.Body) modStr := fmt.Sprintf("data: {\"id\":\"cmpl-7xy1GLgssjHEubVrPyt534VRYVF0t\",\"model\":\"cushman-ml\",\"created\":1694526422,\"choices\":[{\"text\":\"%c\",\"index\":0,\"finish_reason\":null,\"logprobs\":null}]}\n", r)
if readErr != nil { c.String(200, modStr)
return readErr
}
responseData := responseBuffer.Bytes()
// 定义正则表达式模式
pattern := `(.*?)data:`
// 编译正则表达式
regex := regexp.MustCompile(pattern)
// 查找匹配项
match := regex.FindStringSubmatch(string(responseData))
var replacedData = ""
if len(match) >= 2 {
replacedData = match[1]
} else {
fmt.Println("No match found.")
}
// 将字符串转换为[]rune类型的字符数组
runes := []rune(configFile.DiyMsg)
newStr := ""
// 遍历字符数组,对每个字符进行处理
for _, r := range runes {
// 在字符前后添加内容,生成新的字符串
newStr += fmt.Sprintf("data: {\"id\":\"cmpl-7xy1GLgssjHEubVrPyt534VRYVF0t\",\"model\":\"cushman-ml\",\"created\":1694526422,\"choices\":[{\"text\":\"%c\",\"index\":0,\"finish_reason\":null,\"logprobs\":null}]}\n", r)
}
newStr = replacedData + newStr + "data: [DONE]\n"
fmt.Println(newStr)
c.Data(response.StatusCode, response.Header.Get("Content-Type"), []byte(newStr))
return nil
} }
c.String(200, "data: [DONE]\n")
} }