diff --git a/share-copilot b/share-copilot index 287c03b..39c916d 100644 Binary files a/share-copilot and b/share-copilot differ diff --git a/source/define.go b/source/define.go index 3e87c74..e1fd34a 100644 --- a/source/define.go +++ b/source/define.go @@ -21,7 +21,9 @@ type Config struct { var ( //初始化需要返回给客户端的响应体 - responseData map[string]interface{} + tokenMap = make(map[string]map[string]interface{}) + //有效的token列表 + validTokenList = make(map[string]bool) requestCountMutex sync.Mutex githubApiCount = 0 requestCount = 0 diff --git a/source/main.go b/source/main.go index 40aa77c..dcd91f6 100644 --- a/source/main.go +++ b/source/main.go @@ -15,6 +15,7 @@ import ( "path/filepath" "strconv" "strings" + "sync" "time" ) @@ -30,13 +31,18 @@ func main() { if err != nil { log.Fatal(err) } + // 初始化有效的token列表 + initValidTokenList() // 定义路由 domainDefault := engine.Group("/", DomainMiddleware(configFile.Server.Domain)) - domainDefault.GET("/copilot_internal/v2/token", getToken()) + domainDefault.GET("/copilot_internal/v2/token", getGithubToken()) + // 初始化服务器 initServer(engine) // 显示信息 showMsg() } + +// 初始化服务器 func initServer(engine *gin.Engine) { // 配置支持的应用程序协议 server := &http.Server{ @@ -62,47 +68,28 @@ func initServer(engine *gin.Engine) { } }() } -func showMsg() { - var url = "" - if configFile.Server.Port == 80 { - url = "http://" + configFile.Server.Domain - } else if configFile.Server.Port == 443 { - url = "https://" + configFile.Server.Domain - } else { - url = "http://" + configFile.Server.Domain + ":" + strconv.Itoa(configFile.Server.Port) - } - var jetStr = color.WhiteString("[Jetbrains]") - var vsStr = color.WhiteString("[Vscode]") - fmt.Println(jetStr + ": " + color.HiBlueString(url+"/copilot_internal/v2/token")) - fmt.Println(vsStr + ": " + color.HiBlueString(url)) - fmt.Println(color.WhiteString("-----------------------------------------------------------------------")) - for { - requestCountMutex.Lock() - sCount := successCount - tCount := requestCount - gCount := githubApiCount - requestCountMutex.Unlock() - currentTime := time.Now().Format("2006-01-02 15:04:05") - if "00:00:00" == currentTime { - resetRequestCount() - } - var s2 = color.WhiteString("[Succeed]") - var s3 = color.WhiteString("[Failed]") - var s4 = color.WhiteString("[GithubApi]") - // 打印文本 - fmt.Printf("\033[G%s - %s: %s %s: %s %s: %s ", - color.HiYellowString(currentTime), - s2, color.GreenString(strconv.Itoa(sCount)), - s3, color.RedString(strconv.Itoa(tCount-sCount)), - s4, color.CyanString(strconv.Itoa(gCount))) - time.Sleep(1 * time.Second) // +// 初始化有效的token列表 +func initValidTokenList() { + //为了安全起见,应该等待请求完成并处理其响应。 + var wg sync.WaitGroup + for _, token := range configFile.CopilotConfig.Token { + wg.Add(1) + go func(token string) { + defer wg.Done() + if getGithubApi(token) { + validTokenList[token] = true + } + }(token) } + wg.Wait() } -func getToken() gin.HandlerFunc { + +// 获取Github的token +func getGithubToken() gin.HandlerFunc { return func(c *gin.Context) { // 请求计数 - incrementRequestCount() + requestCount++ // 如果配置了verification,则需要获取请求头中的Authorization令牌 if configFile.Verification != "" { token := c.GetHeader("Authorization") @@ -110,61 +97,91 @@ func getToken() gin.HandlerFunc { configCert := strings.ReplaceAll(configFile.Verification, " ", "") if tokenStr != "token"+configCert { // 拒绝不符合Verification的请求 - c.JSON(http.StatusBadRequest, gin.H{ - "message": "Bad credentials", - "documentation_url": "https://docs.github.com/rest"}) + badRequest(c) return } } - //判断时间戳key是否存在 - if _, exists := responseData["expires_at"]; exists { - // 获取当前时间的Unix时间戳 - currentTime := time.Now().Unix() - if expiresAt, ok := responseData["expires_at"].(float64); ok { - // 判断expires_at是否已经过期 - expiresAtInt64 := int64(expiresAt) - //提前一分钟请求 - if expiresAtInt64 > currentTime+60 { - //fmt.Println("\n未过期无需请求") - respProxy(c) + //从有效的token列表中随机获取一个token + token := getRandomToken(validTokenList) + //判断tokenMap 里的token是否存在 + if _, exists := tokenMap[token]; exists { + respDataMap := tokenMap[token] + //判断时间戳key是否存在 + if _, exists := respDataMap["expires_at"]; exists { + // 获取当前时间的Unix时间戳 + currentTime := time.Now().Unix() + if expiresAt, ok := respDataMap["expires_at"].(float64); ok { + // 判断expires_at是否已经过期 + expiresAtInt64 := int64(expiresAt) + //提前一分钟请求 + if expiresAtInt64 > currentTime+60 { + //fmt.Println("\n未过期无需请求") + proxyResp(c, tokenMap[token]) + } else { + //fmt.Println("\n已过期重新请求") + if getGithubApi(token) { + proxyResp(c, tokenMap[token]) + return + } else { + badRequest(c) + } + } } else { - //fmt.Println("\n已过期重新请求") - getGithubApi(c) - respProxy(c) + badRequest(c) } } else { - fmt.Println("Age is not an int") + //tokenMap里的token对应的返回体不存在expires_at + if getGithubApi(token) { + proxyResp(c, tokenMap[token]) + return + } else { + badRequest(c) + } } } else { - //向githubApi发送请求 - //fmt.Println("\n第一次请求") - getGithubApi(c) - respProxy(c) + //不存在则githubApi发送请求,并存到tokenMap + if getGithubApi(token) { + proxyResp(c, tokenMap[token]) + return + } else { + badRequest(c) + } } } } -func respProxy(c *gin.Context) { + +// 请求错误 +func badRequest(c *gin.Context) { + c.JSON(http.StatusBadRequest, gin.H{ + "message": "Bad credentials", + "documentation_url": "https://docs.github.com/rest"}) +} + +// 本服务器响应 +func proxyResp(c *gin.Context, respDataMap map[string]interface{}) { // 将map转换为JSON字符串 - responseJSON, err := json.Marshal(responseData) + responseJSON, err := json.Marshal(respDataMap) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "JSON marshaling error"}) } // 请求成功统计 - incrementSuccessCount() + successCount++ // 将JSON字符串作为响应体返回 c.Header("Content-Type", "application/json") c.String(http.StatusOK, string(responseJSON)) } -func getGithubApi(c *gin.Context) { + +// 请求githubApi +func getGithubApi(token string) bool { githubApiCount++ // 设置请求头 headers := map[string]string{ - "Authorization": "token " + getRandomToken(configFile.CopilotConfig.Token), - "editor-version": c.GetHeader("editor-version"), + "Authorization": "token " + token, + /*"editor-version": c.GetHeader("editor-version"), "editor-plugin-version": c.GetHeader("editor-plugin-version"), "user-agent": c.GetHeader("user-agent"), "accept": c.GetHeader("accept"), - "accept-encoding": c.GetHeader("accept-encoding"), + "accept-encoding": c.GetHeader("accept-encoding"),*/ } // 发起GET请求 response, err := resty.New().R(). @@ -172,17 +189,28 @@ func getGithubApi(c *gin.Context) { Get(configFile.CopilotConfig.GithubApiUrl) if err != nil { // 处理请求错误 - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return + return false } - - err = json.Unmarshal(response.Body(), &responseData) - if err != nil { - // 处理JSON解析错误 - c.JSON(http.StatusInternalServerError, gin.H{"error": "JSON parsing error"}) - return + // 判断响应状态码 + if response.StatusCode() == http.StatusOK { + // 响应状态码为200 OK + respDataMap := map[string]interface{}{} + err = json.Unmarshal(response.Body(), &respDataMap) + if err != nil { + // 处理JSON解析错误 + return false + } + //token map + tokenMap[token] = respDataMap + return true + } else { + // 处理其他状态码 + delete(validTokenList, token) + return false } } + +// 初始化配置文件 func initConfig() Config { // 读取配置文件 exePath, err := os.Executable() @@ -209,26 +237,30 @@ func initConfig() Config { } return config } -func incrementRequestCount() { - requestCount++ -} -func incrementSuccessCount() { - successCount++ -} + +// 重置请求计数 func resetRequestCount() { requestCountMutex.Lock() defer requestCountMutex.Unlock() requestCount = 0 successCount = 0 } -func getRandomToken(tokens []string) string { - if len(tokens) == 0 { + +// 从map中随机获取一个key +func getRandomToken(m map[string]bool) string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + if len(keys) == 0 { return "" // 返回空字符串或处理其他错误情况 } r := rand.New(rand.NewSource(time.Now().UnixNano())) - randomIndex := r.Intn(len(tokens)) - return tokens[randomIndex] + randomIndex := r.Intn(len(keys)) + return keys[randomIndex] } + +// DomainMiddleware 域名中间件 func DomainMiddleware(domain string) gin.HandlerFunc { return func(c *gin.Context) { // 检查域名是否匹配 @@ -241,3 +273,44 @@ func DomainMiddleware(domain string) gin.HandlerFunc { } } } + +// 显示信息 +func showMsg() { + var url = "" + if configFile.Server.Port == 80 { + url = "http://" + configFile.Server.Domain + } else if configFile.Server.Port == 443 { + url = "https://" + configFile.Server.Domain + } else { + url = "http://" + configFile.Server.Domain + ":" + strconv.Itoa(configFile.Server.Port) + } + var jetStr = color.WhiteString("[Jetbrains]") + var vsStr = color.WhiteString("[Vscode]") + var valid = color.WhiteString("[Valid tokens]") + + fmt.Println(jetStr + ": " + color.HiBlueString(url+"/copilot_internal/v2/token")) + fmt.Println(vsStr + ": " + color.HiBlueString(url)) + fmt.Println(valid + ": " + color.HiBlueString(strconv.Itoa(len(validTokenList)))) + fmt.Println(color.WhiteString("-----------------------------------------------------------------------")) + for { + requestCountMutex.Lock() + sCount := successCount + tCount := requestCount + gCount := githubApiCount + requestCountMutex.Unlock() + currentTime := time.Now().Format("2006-01-02 15:04:05") + if "00:00:00" == currentTime { + resetRequestCount() + } + var s2 = color.WhiteString("[Succeed]") + var s3 = color.WhiteString("[Failed]") + var s4 = color.WhiteString("[GithubApi]") + // 打印文本 + fmt.Printf("\033[G%s - %s: %s %s: %s %s: %s ", + color.HiYellowString(currentTime), + s2, color.GreenString(strconv.Itoa(sCount)), + s3, color.RedString(strconv.Itoa(tCount-sCount)), + s4, color.CyanString(strconv.Itoa(gCount))) + time.Sleep(1 * time.Second) // + } +} diff --git a/source/test.go b/source/test.go new file mode 100644 index 0000000..1913fb8 --- /dev/null +++ b/source/test.go @@ -0,0 +1,86 @@ +package main + +import ( + "encoding/json" + "github.com/gin-gonic/gin" + "net/http" +) + +func getGithubTest(c *gin.Context, token string) { + githubApiCount++ + data1 := ` + { + "chat_enabled": false, + "code_quote_enabled": false, + "code_quote_v2_enabled": false, + "copilotignore_enabled": false, + "expires_at": 3194360727, + "prompt_8k": true, + "public_suggestions": "disabled", + "refresh_in": 1500, + "sku": "free_educational", + "telemetry": "disabled", + "token": "tid=;exp=1694360727;sku=free_educational;st=dotcom;8kp=1:", + "tracking_id": "" + } + ` + data2 := ` + { + "chat_enabled": false, + "code_quote_enabled": false, + "code_quote_v2_enabled": false, + "copilotignore_enabled": false, + "expires_at": 2294360727, + "prompt_8k": true, + "public_suggestions": "disabled", + "refresh_in": 1500, + "sku": "free_educational", + "telemetry": "disabled", + "token": "tid=;exp=1694360727;sku=free_educational;st=dotcom;8kp=1:", + "tracking_id": "" + } + ` + data3 := ` + { + "chat_enabled": false, + "code_quote_enabled": false, + "code_quote_v2_enabled": false, + "copilotignore_enabled": false, + "expires_at": 3394360727, + "prompt_8k": true, + "public_suggestions": "disabled", + "refresh_in": 1500, + "sku": "free_educational", + "telemetry": "disabled", + "token": "tid=;exp=1694360727;sku=free_educational;st=dotcom;8kp=1:", + "tracking_id": "" + } + ` + //响应体map + var respDataMap = make(map[string]interface{}) + if token == "1" { + err := json.Unmarshal([]byte(data1), &respDataMap) + if err != nil { + // 处理JSON解析错误 + c.JSON(http.StatusInternalServerError, gin.H{"error": "JSON parsing error"}) + return + } + } else if token == "2" { + err := json.Unmarshal([]byte(data2), &respDataMap) + if err != nil { + // 处理JSON解析错误 + c.JSON(http.StatusInternalServerError, gin.H{"error": "JSON parsing error"}) + return + } + } else { + err := json.Unmarshal([]byte(data3), &respDataMap) + if err != nil { + // 处理JSON解析错误 + c.JSON(http.StatusInternalServerError, gin.H{"error": "JSON parsing error"}) + return + } + } + + //token map + tokenMap[token] = respDataMap +}