diff --git a/src/copilot_api.go b/src/copilot_api.go new file mode 100644 index 0000000..9bb7d8a --- /dev/null +++ b/src/copilot_api.go @@ -0,0 +1,112 @@ +package main + +import ( + "encoding/json" + "github.com/gin-gonic/gin" + "github.com/go-resty/resty/v2" + "math/rand" + "net/http" + "time" +) + +// 初始化有效的github token列表 +func initValidGhuTokenMap() { + for _, token := range configFile.CopilotConfig.Token { + if getGithubApi(token) { + // 有效的token + validGhuTokenMap[token] = true + } + } +} + +// 代理服务器返回copilot token +func proxyResp(c *gin.Context, respDataMap map[string]interface{}) { + // 将map转换为JSON字符串 + responseJSON, _ := json.Marshal(respDataMap) + // 请求成功统计 + successCount++ + // 将响应体作为JSON字符串返回 + c.Header("Content-Type", "application/json") + c.String(http.StatusOK, string(responseJSON)) +} + +// 获取有效的 copilot token +func getCopilotToken() gin.HandlerFunc { + return func(c *gin.Context) { + //用户请求代理服务器的计数 + requestCount++ + //随机从有效的github token列表中获取一个token + token := getRandomToken(validGhuTokenMap) + //通过token取对应的copilot token的map数据 + if respDataMap, exists := getTokenData(token); exists { + //没过期直接返回 + if !isTokenExpired(respDataMap) { + proxyResp(c, respDataMap) + return + } + } + //过期了或者没取到,重新获取 + if getGithubApi(token) { + proxyResp(c, copilotTokenMap[token]) + } else { + // 重新获取失败,返回自定义消息 400 + diyBadRequest(c, 400, "can't get copilot token") + } + } +} + +// 请求github api +func getGithubApi(token string) bool { + // githubApi请求计数 + githubApiCount++ + // 设置请求头 + headers := map[string]string{ + "Authorization": "token " + token, + } + // 发起GET请求 + response, _ := resty.New().R(). + SetHeaders(headers). + Get(configFile.CopilotConfig.GithubApiUrl) + // 判断响应状态码 + if response.StatusCode() == http.StatusOK { + // 响应状态码为200 OK + respDataMap := map[string]interface{}{} + _ = json.Unmarshal(response.Body(), &respDataMap) + copilotTokenMap[token] = respDataMap + return true + } else { + // 响应状态码不为200 map删除无效token + delete(validGhuTokenMap, token) + return false + } +} + +// 从map中获取github token对应的copilot token +func getTokenData(token string) (map[string]interface{}, bool) { + respDataMap, exists := copilotTokenMap[token] + return respDataMap, exists +} + +// 检测copilot token是否过期 +func isTokenExpired(respDataMap map[string]interface{}) bool { + if expiresAt, ok := respDataMap["expires_at"].(float64); ok { + currentTime := time.Now().Unix() + expiresAtInt64 := int64(expiresAt) + return expiresAtInt64 <= currentTime+60 + } + return true +} + +// 从map中随机获取一个github token +func getRandomToken(m map[string]bool) string { + ghuTokenArray := make([]string, 0, len(m)) + for k := range m { + ghuTokenArray = append(ghuTokenArray, k) + } + if len(ghuTokenArray) == 0 { + return "" // 没有有效的token,返回空字符串 + } + r := rand.New(rand.NewSource(time.Now().UnixNano())) + randomIndex := r.Intn(len(ghuTokenArray)) + return ghuTokenArray[randomIndex] +} diff --git a/src/getCopilotToke.go b/src/getCopilotToke.go deleted file mode 100644 index f6827cd..0000000 --- a/src/getCopilotToke.go +++ /dev/null @@ -1,142 +0,0 @@ -package main - -import ( - "encoding/json" - "errors" - "github.com/gin-gonic/gin" - "github.com/go-resty/resty/v2" - "math/rand" - "net/http" - "strings" - "sync" - "time" -) - -// 初始化有效的github token列表 -func initValidTokenList() { - //为了安全起见,应该等待请求完成并处理其响应。 - var wg sync.WaitGroup - for _, token := range configFile.CopilotConfig.Token { - wg.Add(1) - go func(token string) { - defer wg.Done() - if getGithubApi(token) { - validTokenList[token] = true - } - }(token) - } - wg.Wait() -} - -// 请求github api -func getGithubApi(token string) bool { - githubApiCount++ - // 设置请求头 - headers := map[string]string{ - "Authorization": "token " + token, - /*"editor-version": c.GetHeader("editor-version"), - "editor-plugin-version": c.GetHeader("editor-plugin-version"), - "user-agent": c.GetHeader("user-agent"), - "accept": c.GetHeader("accept"), - "accept-encoding": c.GetHeader("accept-encoding"),*/ - } - // 发起GET请求 - response, err := resty.New().R(). - SetHeaders(headers). - Get(configFile.CopilotConfig.GithubApiUrl) - if err != nil { - // 处理请求错误 - return false - } - // 判断响应状态码 - if response.StatusCode() == http.StatusOK { - // 响应状态码为200 OK - respDataMap := map[string]interface{}{} - err = json.Unmarshal(response.Body(), &respDataMap) - if err != nil { - // 处理JSON解析错误 - return false - } - //token map - tokenMap[token] = respDataMap - return true - } else { - // 处理其他状态码 - delete(validTokenList, token) - return false - } -} - -// 获取copilot token -func getGithubToken() gin.HandlerFunc { - return func(c *gin.Context) { - requestCount++ - if err := verifyRequest(c); err != nil { - badRequest(c) - return - } - token := getRandomToken(validTokenList) - if respDataMap, exists := getTokenData(token); exists { - if !isTokenExpired(respDataMap) { - proxyResp(c, respDataMap) - return - } - } - if getGithubApi(token) { - proxyResp(c, tokenMap[token]) - } else { - badRequest(c) - } - } -} - -// 验证请求代理请求token -func verifyRequest(c *gin.Context) error { - if configFile.Verification != "" { - token := c.GetHeader("Authorization") - tokenStr := strings.ReplaceAll(token, " ", "") - configCert := strings.ReplaceAll(configFile.Verification, " ", "") - if tokenStr != "token"+configCert { - return errors.New("verification failed") - } - } - return nil -} - -// 从map中获取github token对应的copilot token -func getTokenData(token string) (map[string]interface{}, bool) { - respDataMap, exists := tokenMap[token] - return respDataMap, exists -} - -// 检测copilot token是否过期 -func isTokenExpired(respDataMap map[string]interface{}) bool { - if expiresAt, ok := respDataMap["expires_at"].(float64); ok { - currentTime := time.Now().Unix() - expiresAtInt64 := int64(expiresAt) - return expiresAtInt64 <= currentTime+60 - } - return true -} - -// 重置请求计数 -func resetRequestCount() { - requestCountMutex.Lock() - defer requestCountMutex.Unlock() - requestCount = 0 - successCount = 0 -} - -// 从map中随机获取一个github token -func getRandomToken(m map[string]bool) string { - keys := make([]string, 0, len(m)) - for k := range m { - keys = append(keys, k) - } - if len(keys) == 0 { - return "" // 返回空字符串或处理其他错误情况 - } - r := rand.New(rand.NewSource(time.Now().UnixNano())) - randomIndex := r.Intn(len(keys)) - return keys[randomIndex] -} diff --git a/src/define.go b/src/global.go similarity index 51% rename from src/define.go rename to src/global.go index e1fd34a..7c3e31f 100644 --- a/src/define.go +++ b/src/global.go @@ -1,9 +1,11 @@ package main import ( + "github.com/gin-gonic/gin" "sync" ) +// Config 配置文件结构体 type Config struct { Server struct { Domain string `json:"domain"` @@ -19,14 +21,21 @@ type Config struct { Verification string `json:"verification"` } +// 全局变量 var ( - //初始化需要返回给客户端的响应体 - tokenMap = make(map[string]map[string]interface{}) - //有效的token列表 - validTokenList = make(map[string]bool) + copilotGinEngine *gin.Engine + //有效ghu_token的map + validGhuTokenMap = make(map[string]bool) + //与有效ghu_token对于的co_token的map + copilotTokenMap = make(map[string]map[string]interface{}) + //服务器配置文件 + configFile Config + //请求计数锁 requestCountMutex sync.Mutex - githubApiCount = 0 - requestCount = 0 - successCount = 0 - configFile Config + //githubApi请求计数 + githubApiCount = 0 + //总请求计数 + requestCount = 0 + //请求成功计数 + successCount = 0 ) diff --git a/src/main.go b/src/main.go index 411290e..8942926 100644 --- a/src/main.go +++ b/src/main.go @@ -1,21 +1,13 @@ package main -func main() { - // 初始化配置文件 - configFile = initConfig() - - // 创建Gin引擎 - engine := setupGinEngine() - - // 初始化有效的token列表 - initValidTokenList() - - // 定义路由 - setupRoutes(engine) - - // 初始化并启动服务器 - initAndStartServer(engine) - - // 显示信息 - showMsg() +func init() { + loadConfig() // 1.加载服务器配置文件 + initGinEngine() // 2.初始化Gin引擎 + initValidGhuTokenMap() // 3.初始化有效Ghu_token +} + +func main() { + Routes() // 1.url路由 + StartServer() // 2.启动服务器 + showMsg() // 3.控制台信息显示 } diff --git a/src/middleware.go b/src/middleware.go new file mode 100644 index 0000000..b0da43c --- /dev/null +++ b/src/middleware.go @@ -0,0 +1,37 @@ +package main + +import ( + "github.com/gin-gonic/gin" + "strings" +) + +//1.DomainMiddleware 域名中间件 +//2.VerifyRequestMiddleware 代理服务器自定义的token验证中间件 + +// DomainMiddleware 域名中间件 +func DomainMiddleware(domain string) gin.HandlerFunc { + return func(c *gin.Context) { + // 获取请求的域名(不包括端口) + requestDomain := strings.Split(c.Request.Host, ":")[0] + // 使用子域名的情况下,检查域名是否匹配或为本地地址 + if requestDomain == domain || requestDomain == "127.0.0.1" || requestDomain == "localhost" { + c.Next() + return + } + c.AbortWithStatus(403) + } +} + +// VerifyRequestMiddleware 代理服务器自定义的token验证中间件 +func VerifyRequestMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + if configFile.Verification != "" { + authHeader := c.GetHeader("Authorization") + if authHeader != "token "+configFile.Verification { + c.JSON(401, gin.H{"error": "Unauthorized"}) + return + } + } + c.Next() + } +} diff --git a/src/routers.go b/src/routers.go new file mode 100644 index 0000000..9bca1ba --- /dev/null +++ b/src/routers.go @@ -0,0 +1,7 @@ +package main + +// Routes 自定义代理服务器路由 附加中间件(域名验证和请求验证) +func Routes() { + copilotApi := copilotGinEngine.Group("/copilot_internal", DomainMiddleware(configFile.Server.Domain), VerifyRequestMiddleware()) + copilotApi.GET("/v2/token", getCopilotToken()) +} diff --git a/src/server.go b/src/server.go index 6395ab4..fb2434b 100644 --- a/src/server.go +++ b/src/server.go @@ -10,76 +10,57 @@ import ( "os" "path/filepath" "strconv" - "strings" ) -// 初始化配置文件 -func initConfig() Config { - // 读取配置文件 - exePath, err := os.Executable() - if err != nil { - panic(err) - } - // 获取执行文件所在目录 +// 加载服务器配置文件 +func loadConfig() { + // 获取可执行文件所在目录 ./config.json + exePath, _ := os.Executable() exeDir := filepath.Dir(exePath) - configFile, err := os.Open(exeDir + "/config.json") + configFileTemp, err := os.Open(exeDir + "/config.json") if err != nil { panic("file \"./config.json\" not found") } + //函数退出时关闭文件流 defer func(configFile *os.File) { - err := configFile.Close() - if err != nil { - panic("close file \"./config.json\" err") - } - }(configFile) - decoder := json.NewDecoder(configFile) - config := Config{} - err = decoder.Decode(&config) + _ = configFile.Close() + }(configFileTemp) + //解析json + decoder := json.NewDecoder(configFileTemp) + err = decoder.Decode(&configFile) if err != nil { panic("config format err") } - return config } -// 创建和配置Gin引擎 -func setupGinEngine() *gin.Engine { +// 初始化Gin引擎 +func initGinEngine() { + // 设置gin模式为发布模式 gin.SetMode(gin.ReleaseMode) + //关闭gin日志输出 自认为小项目没什么用 gin.DefaultWriter = io.Discard - engine := gin.New() - // 设置信任的代理 - if err := engine.SetTrustedProxies([]string{"127.0.0.1"}); err != nil { + // 创建Gin引擎 + copilotGinEngine = gin.New() + // 设置信任的前置代理 用nginx反代需要 不写这个编译会有个警告看着难受 + if err := copilotGinEngine.SetTrustedProxies([]string{"127.0.0.1"}); err != nil { log.Fatal(err) } - return engine } -// 定义路由和中间件 -func setupRoutes(engine *gin.Engine) { - domainDefault := engine.Group("/", DomainMiddleware(configFile.Server.Domain)) - domainDefault.GET("/copilot_internal/v2/token", getGithubToken()) -} - -// DomainMiddleware 域名中间件 -func DomainMiddleware(domain string) gin.HandlerFunc { - return func(c *gin.Context) { - // 检查域名是否匹配 - requestDomain := strings.Split(c.Request.Host, ":")[0] - if requestDomain == domain || requestDomain == "127.0.0.1" { - c.Next() - } else { - c.String(403, "Forbidden") - c.Abort() - } - } -} - -// 初始化和启动服务器 -func initAndStartServer(engine *gin.Engine) { +// StartServer 启动服务器并监听ip和端口 +func StartServer() { + //监听地址是host+port listenAddress := configFile.Server.Host + ":" + strconv.Itoa(configFile.Server.Port) - server := createTLSServer(engine, listenAddress) + server := &http.Server{ + Addr: listenAddress, + TLSConfig: &tls.Config{ + NextProtos: []string{"http/1.1", "http/1.2", "http/2"}, + }, + Handler: copilotGinEngine, + } go func() { if configFile.Server.Port != 443 { - err := engine.Run(listenAddress) + err := copilotGinEngine.Run(listenAddress) log.Fatal(err) } else { err := server.ListenAndServeTLS(configFile.Server.CertPath, configFile.Server.KeyPath) @@ -88,34 +69,10 @@ func initAndStartServer(engine *gin.Engine) { }() } -// 创建TLS服务器配置 -func createTLSServer(engine *gin.Engine, address string) *http.Server { - return &http.Server{ - Addr: address, - TLSConfig: &tls.Config{ - NextProtos: []string{"http/1.1", "http/1.2", "http/2"}, - }, - Handler: engine, - } -} - -// 本服务器响应 -func proxyResp(c *gin.Context, respDataMap map[string]interface{}) { - // 将map转换为JSON字符串 - responseJSON, err := json.Marshal(respDataMap) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "JSON marshaling error"}) - } - // 请求成功统计 - successCount++ - // 将JSON字符串作为响应体返回 - c.Header("Content-Type", "application/json") - c.String(http.StatusOK, string(responseJSON)) -} - -// 请求错误 -func badRequest(c *gin.Context) { - c.JSON(http.StatusBadRequest, gin.H{ - "message": "Bad credentials", - "documentation_url": "https://docs.github.com/rest"}) +// 自定义请求失败返回的状态码和错误信息 +func diyBadRequest(c *gin.Context, code int, errorMessage string) { + c.JSON(code, gin.H{ + "message": errorMessage, + "documentation_url": "https://docs.github.com/rest", + }) } diff --git a/src/showMsg.go b/src/show_msg.go similarity index 71% rename from src/showMsg.go rename to src/show_msg.go index a55a9cd..8a496cf 100644 --- a/src/showMsg.go +++ b/src/show_msg.go @@ -7,7 +7,15 @@ import ( "time" ) -// 控制台显示信息 +// 重置计数 +func resetRequestCount() { + requestCountMutex.Lock() + defer requestCountMutex.Unlock() + requestCount = 0 + successCount = 0 +} + +// 控制台显示信息 无关紧要的内容 func showMsg() { fmt.Println(color.WhiteString("-----------------------------------------------------------------------")) fmt.Println(color.HiBlueString(" _ _ _ _ \n ___| |__ __ _ _ __ ___ ___ ___ _ __ (_) | ___ | |_ \n/ __| '_ \\ / _` | '__/ _ \\_____ / __/ _ \\| '_ \\| | |/ _ \\| __|\n\\__ \\ | | | (_| | | | __/_____| (_| (_) | |_) | | | (_) | |_ \n|___/_| |_|\\__,_|_| \\___| \\___\\___/| .__/|_|_|\\___/ \\__|\n |_| \n")) @@ -21,13 +29,10 @@ func showMsg() { } else { url = "http://" + configFile.Server.Domain + ":" + strconv.Itoa(configFile.Server.Port) } - var jetStr = color.WhiteString("[Jetbrains]") - var vsStr = color.WhiteString("[Vscode]") - var valid = color.WhiteString("[Valid tokens]") - fmt.Println(jetStr + ": " + color.HiBlueString(url+"/copilot_internal/v2/token")) - fmt.Println(vsStr + ": " + color.HiBlueString(url)) - fmt.Println(valid + ": " + color.HiBlueString(strconv.Itoa(len(validTokenList)))) - fmt.Println(color.WhiteString("-----------------------------------------------------------------------")) + jetStr, vsStr, valid := color.WhiteString("[Jetbrains]"), color.WhiteString("[Vscode]"), color.WhiteString("[Valid tokens]") + fmt.Printf("%s: %s/copilot_internal/v2/token\n%s: %s\n%s: %d\n", + jetStr, color.HiBlueString(url), vsStr, color.HiBlueString(url), valid, len(validGhuTokenMap)) + fmt.Println("-----------------------------------------------------------------------") for { requestCountMutex.Lock() sCount := successCount @@ -38,9 +43,7 @@ func showMsg() { if "00:00:00" == currentTime { resetRequestCount() } - var s2 = color.WhiteString("[Succeed]") - var s3 = color.WhiteString("[Failed]") - var s4 = color.WhiteString("[GithubApi]") + s2, s3, s4 := color.WhiteString("[Succeed]"), color.WhiteString("[Failed]"), color.WhiteString("[GithubApi]") // 打印文本 fmt.Printf("\033[G%s - %s: %s %s: %s %s: %s ", color.HiYellowString(currentTime), diff --git a/src/test.go b/src/test.go index 1913fb8..6b6d211 100644 --- a/src/test.go +++ b/src/test.go @@ -82,5 +82,5 @@ func getGithubTest(c *gin.Context, token string) { } //token map - tokenMap[token] = respDataMap + copilotTokenMap[token] = respDataMap }