diff --git a/source/.gitignore b/source/.gitignore index f790e77..1117f4d 100644 --- a/source/.gitignore +++ b/source/.gitignore @@ -5,3 +5,4 @@ /config.json /*.exe /share-copilot +/*.bat diff --git a/source/build.bat b/source/build.bat deleted file mode 100644 index d78103a..0000000 --- a/source/build.bat +++ /dev/null @@ -1,10 +0,0 @@ -@echo off -SET CGO_ENABLED=0 -SET GOOS=linux -SET GOARCH=amd64 -go build -@echo off -SET CGO_ENABLED=0 -SET GOOS=windows -SET GOARCH=amd64 -go build \ No newline at end of file diff --git a/source/error.log b/source/error.log deleted file mode 100644 index e69de29..0000000 diff --git a/source/getCopilotToke.go b/source/getCopilotToke.go new file mode 100644 index 0000000..f6827cd --- /dev/null +++ b/source/getCopilotToke.go @@ -0,0 +1,142 @@ +package main + +import ( + "encoding/json" + "errors" + "github.com/gin-gonic/gin" + "github.com/go-resty/resty/v2" + "math/rand" + "net/http" + "strings" + "sync" + "time" +) + +// 初始化有效的github token列表 +func initValidTokenList() { + //为了安全起见,应该等待请求完成并处理其响应。 + var wg sync.WaitGroup + for _, token := range configFile.CopilotConfig.Token { + wg.Add(1) + go func(token string) { + defer wg.Done() + if getGithubApi(token) { + validTokenList[token] = true + } + }(token) + } + wg.Wait() +} + +// 请求github api +func getGithubApi(token string) bool { + githubApiCount++ + // 设置请求头 + headers := map[string]string{ + "Authorization": "token " + token, + /*"editor-version": c.GetHeader("editor-version"), + "editor-plugin-version": c.GetHeader("editor-plugin-version"), + "user-agent": c.GetHeader("user-agent"), + "accept": c.GetHeader("accept"), + "accept-encoding": c.GetHeader("accept-encoding"),*/ + } + // 发起GET请求 + response, err := resty.New().R(). + SetHeaders(headers). + Get(configFile.CopilotConfig.GithubApiUrl) + if err != nil { + // 处理请求错误 + return false + } + // 判断响应状态码 + if response.StatusCode() == http.StatusOK { + // 响应状态码为200 OK + respDataMap := map[string]interface{}{} + err = json.Unmarshal(response.Body(), &respDataMap) + if err != nil { + // 处理JSON解析错误 + return false + } + //token map + tokenMap[token] = respDataMap + return true + } else { + // 处理其他状态码 + delete(validTokenList, token) + return false + } +} + +// 获取copilot token +func getGithubToken() gin.HandlerFunc { + return func(c *gin.Context) { + requestCount++ + if err := verifyRequest(c); err != nil { + badRequest(c) + return + } + token := getRandomToken(validTokenList) + if respDataMap, exists := getTokenData(token); exists { + if !isTokenExpired(respDataMap) { + proxyResp(c, respDataMap) + return + } + } + if getGithubApi(token) { + proxyResp(c, tokenMap[token]) + } else { + badRequest(c) + } + } +} + +// 验证请求代理请求token +func verifyRequest(c *gin.Context) error { + if configFile.Verification != "" { + token := c.GetHeader("Authorization") + tokenStr := strings.ReplaceAll(token, " ", "") + configCert := strings.ReplaceAll(configFile.Verification, " ", "") + if tokenStr != "token"+configCert { + return errors.New("verification failed") + } + } + return nil +} + +// 从map中获取github token对应的copilot token +func getTokenData(token string) (map[string]interface{}, bool) { + respDataMap, exists := tokenMap[token] + return respDataMap, exists +} + +// 检测copilot token是否过期 +func isTokenExpired(respDataMap map[string]interface{}) bool { + if expiresAt, ok := respDataMap["expires_at"].(float64); ok { + currentTime := time.Now().Unix() + expiresAtInt64 := int64(expiresAt) + return expiresAtInt64 <= currentTime+60 + } + return true +} + +// 重置请求计数 +func resetRequestCount() { + requestCountMutex.Lock() + defer requestCountMutex.Unlock() + requestCount = 0 + successCount = 0 +} + +// 从map中随机获取一个github token +func getRandomToken(m map[string]bool) string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + if len(keys) == 0 { + return "" // 返回空字符串或处理其他错误情况 + } + r := rand.New(rand.NewSource(time.Now().UnixNano())) + randomIndex := r.Intn(len(keys)) + return keys[randomIndex] +} diff --git a/source/main.go b/source/main.go index dcd91f6..411290e 100644 --- a/source/main.go +++ b/source/main.go @@ -1,316 +1,21 @@ package main -import ( - "crypto/tls" - "encoding/json" - "fmt" - "github.com/fatih/color" - "github.com/gin-gonic/gin" - "github.com/go-resty/resty/v2" - "io" - "log" - "math/rand" - "net/http" - "os" - "path/filepath" - "strconv" - "strings" - "sync" - "time" -) - func main() { - //初始化配置文件 + // 初始化配置文件 configFile = initConfig() + // 创建Gin引擎 - gin.SetMode(gin.ReleaseMode) - gin.DefaultWriter = io.Discard - engine := gin.New() - // 设置信任的代理 - err := engine.SetTrustedProxies([]string{"127.0.0.1"}) - if err != nil { - log.Fatal(err) - } + engine := setupGinEngine() + // 初始化有效的token列表 initValidTokenList() + // 定义路由 - domainDefault := engine.Group("/", DomainMiddleware(configFile.Server.Domain)) - domainDefault.GET("/copilot_internal/v2/token", getGithubToken()) - // 初始化服务器 - initServer(engine) + setupRoutes(engine) + + // 初始化并启动服务器 + initAndStartServer(engine) + // 显示信息 showMsg() } - -// 初始化服务器 -func initServer(engine *gin.Engine) { - // 配置支持的应用程序协议 - server := &http.Server{ - Addr: ":443", - TLSConfig: &tls.Config{ - NextProtos: []string{"http/1.1", "http/1.2", "http/2"}, // 支持的应用程序协议列表 - }, - Handler: engine, - } - // 启动Gin服务器并监听端口 - listenAddress := configFile.Server.Host + ":" + strconv.Itoa(configFile.Server.Port) - go func() { - if configFile.Server.Port != 443 { - err := engine.Run(listenAddress) - if err != nil { - log.Fatal(err) - } - } else { - err := server.ListenAndServeTLS(configFile.Server.CertPath, configFile.Server.KeyPath) - if err != nil { - log.Fatal(err) - } - } - }() -} - -// 初始化有效的token列表 -func initValidTokenList() { - //为了安全起见,应该等待请求完成并处理其响应。 - var wg sync.WaitGroup - for _, token := range configFile.CopilotConfig.Token { - wg.Add(1) - go func(token string) { - defer wg.Done() - if getGithubApi(token) { - validTokenList[token] = true - } - }(token) - } - wg.Wait() -} - -// 获取Github的token -func getGithubToken() gin.HandlerFunc { - return func(c *gin.Context) { - // 请求计数 - requestCount++ - // 如果配置了verification,则需要获取请求头中的Authorization令牌 - if configFile.Verification != "" { - token := c.GetHeader("Authorization") - tokenStr := strings.ReplaceAll(token, " ", "") - configCert := strings.ReplaceAll(configFile.Verification, " ", "") - if tokenStr != "token"+configCert { - // 拒绝不符合Verification的请求 - badRequest(c) - return - } - } - //从有效的token列表中随机获取一个token - token := getRandomToken(validTokenList) - //判断tokenMap 里的token是否存在 - if _, exists := tokenMap[token]; exists { - respDataMap := tokenMap[token] - //判断时间戳key是否存在 - if _, exists := respDataMap["expires_at"]; exists { - // 获取当前时间的Unix时间戳 - currentTime := time.Now().Unix() - if expiresAt, ok := respDataMap["expires_at"].(float64); ok { - // 判断expires_at是否已经过期 - expiresAtInt64 := int64(expiresAt) - //提前一分钟请求 - if expiresAtInt64 > currentTime+60 { - //fmt.Println("\n未过期无需请求") - proxyResp(c, tokenMap[token]) - } else { - //fmt.Println("\n已过期重新请求") - if getGithubApi(token) { - proxyResp(c, tokenMap[token]) - return - } else { - badRequest(c) - } - } - } else { - badRequest(c) - } - } else { - //tokenMap里的token对应的返回体不存在expires_at - if getGithubApi(token) { - proxyResp(c, tokenMap[token]) - return - } else { - badRequest(c) - } - } - } else { - //不存在则githubApi发送请求,并存到tokenMap - if getGithubApi(token) { - proxyResp(c, tokenMap[token]) - return - } else { - badRequest(c) - } - } - } -} - -// 请求错误 -func badRequest(c *gin.Context) { - c.JSON(http.StatusBadRequest, gin.H{ - "message": "Bad credentials", - "documentation_url": "https://docs.github.com/rest"}) -} - -// 本服务器响应 -func proxyResp(c *gin.Context, respDataMap map[string]interface{}) { - // 将map转换为JSON字符串 - responseJSON, err := json.Marshal(respDataMap) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "JSON marshaling error"}) - } - // 请求成功统计 - successCount++ - // 将JSON字符串作为响应体返回 - c.Header("Content-Type", "application/json") - c.String(http.StatusOK, string(responseJSON)) -} - -// 请求githubApi -func getGithubApi(token string) bool { - githubApiCount++ - // 设置请求头 - headers := map[string]string{ - "Authorization": "token " + token, - /*"editor-version": c.GetHeader("editor-version"), - "editor-plugin-version": c.GetHeader("editor-plugin-version"), - "user-agent": c.GetHeader("user-agent"), - "accept": c.GetHeader("accept"), - "accept-encoding": c.GetHeader("accept-encoding"),*/ - } - // 发起GET请求 - response, err := resty.New().R(). - SetHeaders(headers). - Get(configFile.CopilotConfig.GithubApiUrl) - if err != nil { - // 处理请求错误 - return false - } - // 判断响应状态码 - if response.StatusCode() == http.StatusOK { - // 响应状态码为200 OK - respDataMap := map[string]interface{}{} - err = json.Unmarshal(response.Body(), &respDataMap) - if err != nil { - // 处理JSON解析错误 - return false - } - //token map - tokenMap[token] = respDataMap - return true - } else { - // 处理其他状态码 - delete(validTokenList, token) - return false - } -} - -// 初始化配置文件 -func initConfig() Config { - // 读取配置文件 - exePath, err := os.Executable() - if err != nil { - panic(err) - } - // 获取执行文件所在目录 - exeDir := filepath.Dir(exePath) - configFile, err := os.Open(exeDir + "/config.json") - if err != nil { - panic("file \"./config.json\" not found") - } - defer func(configFile *os.File) { - err := configFile.Close() - if err != nil { - panic("close file \"./config.json\" err") - } - }(configFile) - decoder := json.NewDecoder(configFile) - config := Config{} - err = decoder.Decode(&config) - if err != nil { - panic("config format err") - } - return config -} - -// 重置请求计数 -func resetRequestCount() { - requestCountMutex.Lock() - defer requestCountMutex.Unlock() - requestCount = 0 - successCount = 0 -} - -// 从map中随机获取一个key -func getRandomToken(m map[string]bool) string { - keys := make([]string, 0, len(m)) - for k := range m { - keys = append(keys, k) - } - if len(keys) == 0 { - return "" // 返回空字符串或处理其他错误情况 - } - r := rand.New(rand.NewSource(time.Now().UnixNano())) - randomIndex := r.Intn(len(keys)) - return keys[randomIndex] -} - -// DomainMiddleware 域名中间件 -func DomainMiddleware(domain string) gin.HandlerFunc { - return func(c *gin.Context) { - // 检查域名是否匹配 - requestDomain := strings.Split(c.Request.Host, ":")[0] - if requestDomain == domain || requestDomain == "127.0.0.1" { - c.Next() - } else { - c.String(403, "Forbidden") - c.Abort() - } - } -} - -// 显示信息 -func showMsg() { - var url = "" - if configFile.Server.Port == 80 { - url = "http://" + configFile.Server.Domain - } else if configFile.Server.Port == 443 { - url = "https://" + configFile.Server.Domain - } else { - url = "http://" + configFile.Server.Domain + ":" + strconv.Itoa(configFile.Server.Port) - } - var jetStr = color.WhiteString("[Jetbrains]") - var vsStr = color.WhiteString("[Vscode]") - var valid = color.WhiteString("[Valid tokens]") - - fmt.Println(jetStr + ": " + color.HiBlueString(url+"/copilot_internal/v2/token")) - fmt.Println(vsStr + ": " + color.HiBlueString(url)) - fmt.Println(valid + ": " + color.HiBlueString(strconv.Itoa(len(validTokenList)))) - fmt.Println(color.WhiteString("-----------------------------------------------------------------------")) - for { - requestCountMutex.Lock() - sCount := successCount - tCount := requestCount - gCount := githubApiCount - requestCountMutex.Unlock() - currentTime := time.Now().Format("2006-01-02 15:04:05") - if "00:00:00" == currentTime { - resetRequestCount() - } - var s2 = color.WhiteString("[Succeed]") - var s3 = color.WhiteString("[Failed]") - var s4 = color.WhiteString("[GithubApi]") - // 打印文本 - fmt.Printf("\033[G%s - %s: %s %s: %s %s: %s ", - color.HiYellowString(currentTime), - s2, color.GreenString(strconv.Itoa(sCount)), - s3, color.RedString(strconv.Itoa(tCount-sCount)), - s4, color.CyanString(strconv.Itoa(gCount))) - time.Sleep(1 * time.Second) // - } -} diff --git a/source/server.go b/source/server.go new file mode 100644 index 0000000..6395ab4 --- /dev/null +++ b/source/server.go @@ -0,0 +1,121 @@ +package main + +import ( + "crypto/tls" + "encoding/json" + "github.com/gin-gonic/gin" + "io" + "log" + "net/http" + "os" + "path/filepath" + "strconv" + "strings" +) + +// 初始化配置文件 +func initConfig() Config { + // 读取配置文件 + exePath, err := os.Executable() + if err != nil { + panic(err) + } + // 获取执行文件所在目录 + exeDir := filepath.Dir(exePath) + configFile, err := os.Open(exeDir + "/config.json") + if err != nil { + panic("file \"./config.json\" not found") + } + defer func(configFile *os.File) { + err := configFile.Close() + if err != nil { + panic("close file \"./config.json\" err") + } + }(configFile) + decoder := json.NewDecoder(configFile) + config := Config{} + err = decoder.Decode(&config) + if err != nil { + panic("config format err") + } + return config +} + +// 创建和配置Gin引擎 +func setupGinEngine() *gin.Engine { + gin.SetMode(gin.ReleaseMode) + gin.DefaultWriter = io.Discard + engine := gin.New() + // 设置信任的代理 + if err := engine.SetTrustedProxies([]string{"127.0.0.1"}); err != nil { + log.Fatal(err) + } + return engine +} + +// 定义路由和中间件 +func setupRoutes(engine *gin.Engine) { + domainDefault := engine.Group("/", DomainMiddleware(configFile.Server.Domain)) + domainDefault.GET("/copilot_internal/v2/token", getGithubToken()) +} + +// DomainMiddleware 域名中间件 +func DomainMiddleware(domain string) gin.HandlerFunc { + return func(c *gin.Context) { + // 检查域名是否匹配 + requestDomain := strings.Split(c.Request.Host, ":")[0] + if requestDomain == domain || requestDomain == "127.0.0.1" { + c.Next() + } else { + c.String(403, "Forbidden") + c.Abort() + } + } +} + +// 初始化和启动服务器 +func initAndStartServer(engine *gin.Engine) { + listenAddress := configFile.Server.Host + ":" + strconv.Itoa(configFile.Server.Port) + server := createTLSServer(engine, listenAddress) + go func() { + if configFile.Server.Port != 443 { + err := engine.Run(listenAddress) + log.Fatal(err) + } else { + err := server.ListenAndServeTLS(configFile.Server.CertPath, configFile.Server.KeyPath) + log.Fatal(err) + } + }() +} + +// 创建TLS服务器配置 +func createTLSServer(engine *gin.Engine, address string) *http.Server { + return &http.Server{ + Addr: address, + TLSConfig: &tls.Config{ + NextProtos: []string{"http/1.1", "http/1.2", "http/2"}, + }, + Handler: engine, + } +} + +// 本服务器响应 +func proxyResp(c *gin.Context, respDataMap map[string]interface{}) { + // 将map转换为JSON字符串 + responseJSON, err := json.Marshal(respDataMap) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "JSON marshaling error"}) + } + // 请求成功统计 + successCount++ + // 将JSON字符串作为响应体返回 + c.Header("Content-Type", "application/json") + c.String(http.StatusOK, string(responseJSON)) +} + +// 请求错误 +func badRequest(c *gin.Context) { + c.JSON(http.StatusBadRequest, gin.H{ + "message": "Bad credentials", + "documentation_url": "https://docs.github.com/rest"}) +} diff --git a/source/showMsg.go b/source/showMsg.go new file mode 100644 index 0000000..b56a42f --- /dev/null +++ b/source/showMsg.go @@ -0,0 +1,48 @@ +package main + +import ( + "fmt" + "github.com/fatih/color" + "strconv" + "time" +) + +// 控制台显示信息 +func showMsg() { + var url = "" + if configFile.Server.Port == 80 { + url = "http://" + configFile.Server.Domain + } else if configFile.Server.Port == 443 { + url = "https://" + configFile.Server.Domain + } else { + url = "http://" + configFile.Server.Domain + ":" + strconv.Itoa(configFile.Server.Port) + } + var jetStr = color.WhiteString("[Jetbrains]") + var vsStr = color.WhiteString("[Vscode]") + var valid = color.WhiteString("[Valid tokens]") + fmt.Println(jetStr + ": " + color.HiBlueString(url+"/copilot_internal/v2/token")) + fmt.Println(vsStr + ": " + color.HiBlueString(url)) + fmt.Println(valid + ": " + color.HiBlueString(strconv.Itoa(len(validTokenList)))) + fmt.Println(color.WhiteString("-----------------------------------------------------------------------")) + for { + requestCountMutex.Lock() + sCount := successCount + tCount := requestCount + gCount := githubApiCount + requestCountMutex.Unlock() + currentTime := time.Now().Format("2006-01-02 15:04:05") + if "00:00:00" == currentTime { + resetRequestCount() + } + var s2 = color.WhiteString("[Succeed]") + var s3 = color.WhiteString("[Failed]") + var s4 = color.WhiteString("[GithubApi]") + // 打印文本 + fmt.Printf("\033[G%s - %s: %s %s: %s %s: %s ", + color.HiYellowString(currentTime), + s2, color.GreenString(strconv.Itoa(sCount)), + s3, color.RedString(strconv.Itoa(tCount-sCount)), + s4, color.CyanString(strconv.Itoa(gCount))) + time.Sleep(1 * time.Second) // + } +}