删除无用代码,优化逻辑

This commit is contained in:
luyoyu 2023-09-20 23:05:07 +08:00
parent 4fe39fff57
commit 78206866d4
9 changed files with 234 additions and 259 deletions

112
src/copilot_api.go Normal file
View File

@ -0,0 +1,112 @@
package main
import (
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/go-resty/resty/v2"
"math/rand"
"net/http"
"time"
)
// 初始化有效的github token列表
func initValidGhuTokenMap() {
for _, token := range configFile.CopilotConfig.Token {
if getGithubApi(token) {
// 有效的token
validGhuTokenMap[token] = true
}
}
}
// 代理服务器返回copilot token
func proxyResp(c *gin.Context, respDataMap map[string]interface{}) {
// 将map转换为JSON字符串
responseJSON, _ := json.Marshal(respDataMap)
// 请求成功统计
successCount++
// 将响应体作为JSON字符串返回
c.Header("Content-Type", "application/json")
c.String(http.StatusOK, string(responseJSON))
}
// 获取有效的 copilot token
func getCopilotToken() gin.HandlerFunc {
return func(c *gin.Context) {
//用户请求代理服务器的计数
requestCount++
//随机从有效的github token列表中获取一个token
token := getRandomToken(validGhuTokenMap)
//通过token取对应的copilot token的map数据
if respDataMap, exists := getTokenData(token); exists {
//没过期直接返回
if !isTokenExpired(respDataMap) {
proxyResp(c, respDataMap)
return
}
}
//过期了或者没取到,重新获取
if getGithubApi(token) {
proxyResp(c, copilotTokenMap[token])
} else {
// 重新获取失败,返回自定义消息 400
diyBadRequest(c, 400, "can't get copilot token")
}
}
}
// 请求github api
func getGithubApi(token string) bool {
// githubApi请求计数
githubApiCount++
// 设置请求头
headers := map[string]string{
"Authorization": "token " + token,
}
// 发起GET请求
response, _ := resty.New().R().
SetHeaders(headers).
Get(configFile.CopilotConfig.GithubApiUrl)
// 判断响应状态码
if response.StatusCode() == http.StatusOK {
// 响应状态码为200 OK
respDataMap := map[string]interface{}{}
_ = json.Unmarshal(response.Body(), &respDataMap)
copilotTokenMap[token] = respDataMap
return true
} else {
// 响应状态码不为200 map删除无效token
delete(validGhuTokenMap, token)
return false
}
}
// 从map中获取github token对应的copilot token
func getTokenData(token string) (map[string]interface{}, bool) {
respDataMap, exists := copilotTokenMap[token]
return respDataMap, exists
}
// 检测copilot token是否过期
func isTokenExpired(respDataMap map[string]interface{}) bool {
if expiresAt, ok := respDataMap["expires_at"].(float64); ok {
currentTime := time.Now().Unix()
expiresAtInt64 := int64(expiresAt)
return expiresAtInt64 <= currentTime+60
}
return true
}
// 从map中随机获取一个github token
func getRandomToken(m map[string]bool) string {
ghuTokenArray := make([]string, 0, len(m))
for k := range m {
ghuTokenArray = append(ghuTokenArray, k)
}
if len(ghuTokenArray) == 0 {
return "" // 没有有效的token返回空字符串
}
r := rand.New(rand.NewSource(time.Now().UnixNano()))
randomIndex := r.Intn(len(ghuTokenArray))
return ghuTokenArray[randomIndex]
}

View File

@ -1,142 +0,0 @@
package main
import (
"encoding/json"
"errors"
"github.com/gin-gonic/gin"
"github.com/go-resty/resty/v2"
"math/rand"
"net/http"
"strings"
"sync"
"time"
)
// 初始化有效的github token列表
func initValidTokenList() {
//为了安全起见,应该等待请求完成并处理其响应。
var wg sync.WaitGroup
for _, token := range configFile.CopilotConfig.Token {
wg.Add(1)
go func(token string) {
defer wg.Done()
if getGithubApi(token) {
validTokenList[token] = true
}
}(token)
}
wg.Wait()
}
// 请求github api
func getGithubApi(token string) bool {
githubApiCount++
// 设置请求头
headers := map[string]string{
"Authorization": "token " + token,
/*"editor-version": c.GetHeader("editor-version"),
"editor-plugin-version": c.GetHeader("editor-plugin-version"),
"user-agent": c.GetHeader("user-agent"),
"accept": c.GetHeader("accept"),
"accept-encoding": c.GetHeader("accept-encoding"),*/
}
// 发起GET请求
response, err := resty.New().R().
SetHeaders(headers).
Get(configFile.CopilotConfig.GithubApiUrl)
if err != nil {
// 处理请求错误
return false
}
// 判断响应状态码
if response.StatusCode() == http.StatusOK {
// 响应状态码为200 OK
respDataMap := map[string]interface{}{}
err = json.Unmarshal(response.Body(), &respDataMap)
if err != nil {
// 处理JSON解析错误
return false
}
//token map
tokenMap[token] = respDataMap
return true
} else {
// 处理其他状态码
delete(validTokenList, token)
return false
}
}
// 获取copilot token
func getGithubToken() gin.HandlerFunc {
return func(c *gin.Context) {
requestCount++
if err := verifyRequest(c); err != nil {
badRequest(c)
return
}
token := getRandomToken(validTokenList)
if respDataMap, exists := getTokenData(token); exists {
if !isTokenExpired(respDataMap) {
proxyResp(c, respDataMap)
return
}
}
if getGithubApi(token) {
proxyResp(c, tokenMap[token])
} else {
badRequest(c)
}
}
}
// 验证请求代理请求token
func verifyRequest(c *gin.Context) error {
if configFile.Verification != "" {
token := c.GetHeader("Authorization")
tokenStr := strings.ReplaceAll(token, " ", "")
configCert := strings.ReplaceAll(configFile.Verification, " ", "")
if tokenStr != "token"+configCert {
return errors.New("verification failed")
}
}
return nil
}
// 从map中获取github token对应的copilot token
func getTokenData(token string) (map[string]interface{}, bool) {
respDataMap, exists := tokenMap[token]
return respDataMap, exists
}
// 检测copilot token是否过期
func isTokenExpired(respDataMap map[string]interface{}) bool {
if expiresAt, ok := respDataMap["expires_at"].(float64); ok {
currentTime := time.Now().Unix()
expiresAtInt64 := int64(expiresAt)
return expiresAtInt64 <= currentTime+60
}
return true
}
// 重置请求计数
func resetRequestCount() {
requestCountMutex.Lock()
defer requestCountMutex.Unlock()
requestCount = 0
successCount = 0
}
// 从map中随机获取一个github token
func getRandomToken(m map[string]bool) string {
keys := make([]string, 0, len(m))
for k := range m {
keys = append(keys, k)
}
if len(keys) == 0 {
return "" // 返回空字符串或处理其他错误情况
}
r := rand.New(rand.NewSource(time.Now().UnixNano()))
randomIndex := r.Intn(len(keys))
return keys[randomIndex]
}

View File

@ -1,9 +1,11 @@
package main package main
import ( import (
"github.com/gin-gonic/gin"
"sync" "sync"
) )
// Config 配置文件结构体
type Config struct { type Config struct {
Server struct { Server struct {
Domain string `json:"domain"` Domain string `json:"domain"`
@ -19,14 +21,21 @@ type Config struct {
Verification string `json:"verification"` Verification string `json:"verification"`
} }
// 全局变量
var ( var (
//初始化需要返回给客户端的响应体 copilotGinEngine *gin.Engine
tokenMap = make(map[string]map[string]interface{}) //有效ghu_token的map
//有效的token列表 validGhuTokenMap = make(map[string]bool)
validTokenList = make(map[string]bool) //与有效ghu_token对于的co_token的map
copilotTokenMap = make(map[string]map[string]interface{})
//服务器配置文件
configFile Config
//请求计数锁
requestCountMutex sync.Mutex requestCountMutex sync.Mutex
githubApiCount = 0 //githubApi请求计数
requestCount = 0 githubApiCount = 0
successCount = 0 //总请求计数
configFile Config requestCount = 0
//请求成功计数
successCount = 0
) )

View File

@ -1,21 +1,13 @@
package main package main
func main() { func init() {
// 初始化配置文件 loadConfig() // 1.加载服务器配置文件
configFile = initConfig() initGinEngine() // 2.初始化Gin引擎
initValidGhuTokenMap() // 3.初始化有效Ghu_token
// 创建Gin引擎 }
engine := setupGinEngine()
func main() {
// 初始化有效的token列表 Routes() // 1.url路由
initValidTokenList() StartServer() // 2.启动服务器
showMsg() // 3.控制台信息显示
// 定义路由
setupRoutes(engine)
// 初始化并启动服务器
initAndStartServer(engine)
// 显示信息
showMsg()
} }

37
src/middleware.go Normal file
View File

@ -0,0 +1,37 @@
package main
import (
"github.com/gin-gonic/gin"
"strings"
)
//1.DomainMiddleware 域名中间件
//2.VerifyRequestMiddleware 代理服务器自定义的token验证中间件
// DomainMiddleware 域名中间件
func DomainMiddleware(domain string) gin.HandlerFunc {
return func(c *gin.Context) {
// 获取请求的域名(不包括端口)
requestDomain := strings.Split(c.Request.Host, ":")[0]
// 使用子域名的情况下,检查域名是否匹配或为本地地址
if requestDomain == domain || requestDomain == "127.0.0.1" || requestDomain == "localhost" {
c.Next()
return
}
c.AbortWithStatus(403)
}
}
// VerifyRequestMiddleware 代理服务器自定义的token验证中间件
func VerifyRequestMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
if configFile.Verification != "" {
authHeader := c.GetHeader("Authorization")
if authHeader != "token "+configFile.Verification {
c.JSON(401, gin.H{"error": "Unauthorized"})
return
}
}
c.Next()
}
}

7
src/routers.go Normal file
View File

@ -0,0 +1,7 @@
package main
// Routes 自定义代理服务器路由 附加中间件(域名验证和请求验证)
func Routes() {
copilotApi := copilotGinEngine.Group("/copilot_internal", DomainMiddleware(configFile.Server.Domain), VerifyRequestMiddleware())
copilotApi.GET("/v2/token", getCopilotToken())
}

View File

@ -10,76 +10,57 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings"
) )
// 初始化配置文件 // 加载服务器配置文件
func initConfig() Config { func loadConfig() {
// 读取配置文件 // 获取可执行文件所在目录 ./config.json
exePath, err := os.Executable() exePath, _ := os.Executable()
if err != nil {
panic(err)
}
// 获取执行文件所在目录
exeDir := filepath.Dir(exePath) exeDir := filepath.Dir(exePath)
configFile, err := os.Open(exeDir + "/config.json") configFileTemp, err := os.Open(exeDir + "/config.json")
if err != nil { if err != nil {
panic("file \"./config.json\" not found") panic("file \"./config.json\" not found")
} }
//函数退出时关闭文件流
defer func(configFile *os.File) { defer func(configFile *os.File) {
err := configFile.Close() _ = configFile.Close()
if err != nil { }(configFileTemp)
panic("close file \"./config.json\" err") //解析json
} decoder := json.NewDecoder(configFileTemp)
}(configFile) err = decoder.Decode(&configFile)
decoder := json.NewDecoder(configFile)
config := Config{}
err = decoder.Decode(&config)
if err != nil { if err != nil {
panic("config format err") panic("config format err")
} }
return config
} }
// 创建和配置Gin引擎 // 初始化Gin引擎
func setupGinEngine() *gin.Engine { func initGinEngine() {
// 设置gin模式为发布模式
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
//关闭gin日志输出 自认为小项目没什么用
gin.DefaultWriter = io.Discard gin.DefaultWriter = io.Discard
engine := gin.New() // 创建Gin引擎
// 设置信任的代理 copilotGinEngine = gin.New()
if err := engine.SetTrustedProxies([]string{"127.0.0.1"}); err != nil { // 设置信任的前置代理 用nginx反代需要 不写这个编译会有个警告看着难受
if err := copilotGinEngine.SetTrustedProxies([]string{"127.0.0.1"}); err != nil {
log.Fatal(err) log.Fatal(err)
} }
return engine
} }
// 定义路由和中间件 // StartServer 启动服务器并监听ip和端口
func setupRoutes(engine *gin.Engine) { func StartServer() {
domainDefault := engine.Group("/", DomainMiddleware(configFile.Server.Domain)) //监听地址是host+port
domainDefault.GET("/copilot_internal/v2/token", getGithubToken())
}
// DomainMiddleware 域名中间件
func DomainMiddleware(domain string) gin.HandlerFunc {
return func(c *gin.Context) {
// 检查域名是否匹配
requestDomain := strings.Split(c.Request.Host, ":")[0]
if requestDomain == domain || requestDomain == "127.0.0.1" {
c.Next()
} else {
c.String(403, "Forbidden")
c.Abort()
}
}
}
// 初始化和启动服务器
func initAndStartServer(engine *gin.Engine) {
listenAddress := configFile.Server.Host + ":" + strconv.Itoa(configFile.Server.Port) listenAddress := configFile.Server.Host + ":" + strconv.Itoa(configFile.Server.Port)
server := createTLSServer(engine, listenAddress) server := &http.Server{
Addr: listenAddress,
TLSConfig: &tls.Config{
NextProtos: []string{"http/1.1", "http/1.2", "http/2"},
},
Handler: copilotGinEngine,
}
go func() { go func() {
if configFile.Server.Port != 443 { if configFile.Server.Port != 443 {
err := engine.Run(listenAddress) err := copilotGinEngine.Run(listenAddress)
log.Fatal(err) log.Fatal(err)
} else { } else {
err := server.ListenAndServeTLS(configFile.Server.CertPath, configFile.Server.KeyPath) err := server.ListenAndServeTLS(configFile.Server.CertPath, configFile.Server.KeyPath)
@ -88,34 +69,10 @@ func initAndStartServer(engine *gin.Engine) {
}() }()
} }
// 创建TLS服务器配置 // 自定义请求失败返回的状态码和错误信息
func createTLSServer(engine *gin.Engine, address string) *http.Server { func diyBadRequest(c *gin.Context, code int, errorMessage string) {
return &http.Server{ c.JSON(code, gin.H{
Addr: address, "message": errorMessage,
TLSConfig: &tls.Config{ "documentation_url": "https://docs.github.com/rest",
NextProtos: []string{"http/1.1", "http/1.2", "http/2"}, })
},
Handler: engine,
}
}
// 本服务器响应
func proxyResp(c *gin.Context, respDataMap map[string]interface{}) {
// 将map转换为JSON字符串
responseJSON, err := json.Marshal(respDataMap)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "JSON marshaling error"})
}
// 请求成功统计
successCount++
// 将JSON字符串作为响应体返回
c.Header("Content-Type", "application/json")
c.String(http.StatusOK, string(responseJSON))
}
// 请求错误
func badRequest(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{
"message": "Bad credentials",
"documentation_url": "https://docs.github.com/rest"})
} }

View File

@ -7,7 +7,15 @@ import (
"time" "time"
) )
// 控制台显示信息 // 重置计数
func resetRequestCount() {
requestCountMutex.Lock()
defer requestCountMutex.Unlock()
requestCount = 0
successCount = 0
}
// 控制台显示信息 无关紧要的内容
func showMsg() { func showMsg() {
fmt.Println(color.WhiteString("-----------------------------------------------------------------------")) fmt.Println(color.WhiteString("-----------------------------------------------------------------------"))
fmt.Println(color.HiBlueString(" _ _ _ _ \n ___| |__ __ _ _ __ ___ ___ ___ _ __ (_) | ___ | |_ \n/ __| '_ \\ / _` | '__/ _ \\_____ / __/ _ \\| '_ \\| | |/ _ \\| __|\n\\__ \\ | | | (_| | | | __/_____| (_| (_) | |_) | | | (_) | |_ \n|___/_| |_|\\__,_|_| \\___| \\___\\___/| .__/|_|_|\\___/ \\__|\n |_| \n")) fmt.Println(color.HiBlueString(" _ _ _ _ \n ___| |__ __ _ _ __ ___ ___ ___ _ __ (_) | ___ | |_ \n/ __| '_ \\ / _` | '__/ _ \\_____ / __/ _ \\| '_ \\| | |/ _ \\| __|\n\\__ \\ | | | (_| | | | __/_____| (_| (_) | |_) | | | (_) | |_ \n|___/_| |_|\\__,_|_| \\___| \\___\\___/| .__/|_|_|\\___/ \\__|\n |_| \n"))
@ -21,13 +29,10 @@ func showMsg() {
} else { } else {
url = "http://" + configFile.Server.Domain + ":" + strconv.Itoa(configFile.Server.Port) url = "http://" + configFile.Server.Domain + ":" + strconv.Itoa(configFile.Server.Port)
} }
var jetStr = color.WhiteString("[Jetbrains]") jetStr, vsStr, valid := color.WhiteString("[Jetbrains]"), color.WhiteString("[Vscode]"), color.WhiteString("[Valid tokens]")
var vsStr = color.WhiteString("[Vscode]") fmt.Printf("%s: %s/copilot_internal/v2/token\n%s: %s\n%s: %d\n",
var valid = color.WhiteString("[Valid tokens]") jetStr, color.HiBlueString(url), vsStr, color.HiBlueString(url), valid, len(validGhuTokenMap))
fmt.Println(jetStr + ": " + color.HiBlueString(url+"/copilot_internal/v2/token")) fmt.Println("-----------------------------------------------------------------------")
fmt.Println(vsStr + ": " + color.HiBlueString(url))
fmt.Println(valid + ": " + color.HiBlueString(strconv.Itoa(len(validTokenList))))
fmt.Println(color.WhiteString("-----------------------------------------------------------------------"))
for { for {
requestCountMutex.Lock() requestCountMutex.Lock()
sCount := successCount sCount := successCount
@ -38,9 +43,7 @@ func showMsg() {
if "00:00:00" == currentTime { if "00:00:00" == currentTime {
resetRequestCount() resetRequestCount()
} }
var s2 = color.WhiteString("[Succeed]") s2, s3, s4 := color.WhiteString("[Succeed]"), color.WhiteString("[Failed]"), color.WhiteString("[GithubApi]")
var s3 = color.WhiteString("[Failed]")
var s4 = color.WhiteString("[GithubApi]")
// 打印文本 // 打印文本
fmt.Printf("\033[G%s - %s: %s %s: %s %s: %s ", fmt.Printf("\033[G%s - %s: %s %s: %s %s: %s ",
color.HiYellowString(currentTime), color.HiYellowString(currentTime),

View File

@ -82,5 +82,5 @@ func getGithubTest(c *gin.Context, token string) {
} }
//token map //token map
tokenMap[token] = respDataMap copilotTokenMap[token] = respDataMap
} }