share-copilot/source/getCopilotToke.go
ben Gutier 7abff5fe83 优化
2023-09-11 12:24:04 +08:00

143 lines
3.4 KiB
Go

package main
import (
"encoding/json"
"errors"
"github.com/gin-gonic/gin"
"github.com/go-resty/resty/v2"
"math/rand"
"net/http"
"strings"
"sync"
"time"
)
// initValidTokenList probes every GitHub token listed in
// configFile.CopilotConfig.Token and records the ones accepted by
// getGithubApi in validTokenList.
//
// One goroutine is started per token and the function returns only after all
// of them have finished. The workers never touch validTokenList themselves:
// a Go map must not be written from several goroutines at once, so each
// worker reports an accepted token over a channel and the map is filled in
// from the calling goroutine once every probe is done.
//
// NOTE(review): getGithubApi updates shared state of its own; that state has
// to be synchronized inside getGithubApi, it cannot be protected from here.
func initValidTokenList() {
	tokens := configFile.CopilotConfig.Token

	// Buffered to the number of tokens so that a worker never blocks on send.
	accepted := make(chan string, len(tokens))

	var wg sync.WaitGroup
	for _, token := range tokens {
		wg.Add(1)
		go func(token string) {
			defer wg.Done()
			if getGithubApi(token) {
				accepted <- token
			}
		}(token)
	}

	// Wait for every request to complete before looking at the results.
	wg.Wait()
	close(accepted)

	for token := range accepted {
		validTokenList[token] = true
	}
}
// githubAPIMu serializes the writes getGithubApi makes to githubApiCount,
// tokenMap and validTokenList. getGithubApi is called from several goroutines
// at once, and unsynchronized map writes abort the program.
//
// NOTE(review): code that reads those variables elsewhere is not covered by
// this mutex.
var githubAPIMu sync.Mutex

// githubAPITimeout bounds one request to the GitHub API so that a stalled
// connection cannot block a caller forever.
const githubAPITimeout = 30 * time.Second

// getGithubApi exchanges a GitHub token for its Copilot token data.
//
// It sends a GET request to configFile.CopilotConfig.GithubApiUrl with the
// header "Authorization: token <token>" and increments githubApiCount once
// per call.
//
// Results:
//   - 200 OK with a JSON object body: the decoded object is stored in
//     tokenMap[token] and true is returned.
//   - any other status code: token is removed from validTokenList and false
//     is returned.
//   - transport error (including timeout) or undecodable body: false is
//     returned and validTokenList is left untouched.
func getGithubApi(token string) bool {
	githubAPIMu.Lock()
	githubApiCount++
	githubAPIMu.Unlock()

	// Only the Authorization header is sent; editor and client headers are
	// not forwarded to GitHub.
	headers := map[string]string{
		"Authorization": "token " + token,
	}

	response, err := resty.New().
		SetTimeout(githubAPITimeout).
		R().
		SetHeaders(headers).
		Get(configFile.CopilotConfig.GithubApiUrl)
	if err != nil {
		// Request could not be completed.
		return false
	}

	if response.StatusCode() != http.StatusOK {
		// GitHub rejected the token: stop handing it out.
		githubAPIMu.Lock()
		delete(validTokenList, token)
		githubAPIMu.Unlock()
		return false
	}

	respDataMap := map[string]interface{}{}
	if err := json.Unmarshal(response.Body(), &respDataMap); err != nil {
		// Body was not a JSON object.
		return false
	}

	// Cache the Copilot token data under the GitHub token it belongs to.
	githubAPIMu.Lock()
	tokenMap[token] = respDataMap
	githubAPIMu.Unlock()
	return true
}
// getGithubToken returns the gin handler that serves a Copilot token.
//
// For every request the handler:
//  1. increments requestCount under requestCountMutex, the same lock
//     resetRequestCount holds while zeroing the counter;
//  2. rejects the request through badRequest when verifyRequest fails;
//  3. picks a random GitHub token from validTokenList and rejects the request
//     when there is none, instead of calling GitHub with empty credentials;
//  4. answers through proxyResp with the cached data from getTokenData when
//     it exists and isTokenExpired reports it as still usable;
//  5. otherwise refreshes the data with getGithubApi and answers with the
//     new tokenMap entry, or rejects the request if the refresh failed.
func getGithubToken() gin.HandlerFunc {
	return func(c *gin.Context) {
		requestCountMutex.Lock()
		requestCount++
		requestCountMutex.Unlock()

		if err := verifyRequest(c); err != nil {
			badRequest(c)
			return
		}

		token := getRandomToken(validTokenList)
		if token == "" {
			// No valid GitHub token is available right now.
			badRequest(c)
			return
		}

		// Serve from the cache while the Copilot token is still fresh.
		if respDataMap, exists := getTokenData(token); exists && !isTokenExpired(respDataMap) {
			proxyResp(c, respDataMap)
			return
		}

		// Missing or expired: fetch a new Copilot token from GitHub.
		if !getGithubApi(token) {
			badRequest(c)
			return
		}
		proxyResp(c, tokenMap[token])
	}
}
// verifyRequest checks the Authorization header of a proxied request against
// the secret in configFile.Verification.
//
// When no secret is configured every request passes. Otherwise the header,
// with all spaces removed, must equal "token" followed by the secret with all
// spaces removed; any other value yields a "verification failed" error.
func verifyRequest(c *gin.Context) error {
	if configFile.Verification == "" {
		return nil
	}

	// Spaces are ignored on both sides of the comparison.
	stripSpaces := func(s string) string {
		return strings.ReplaceAll(s, " ", "")
	}

	presented := stripSpaces(c.GetHeader("Authorization"))
	expected := "token" + stripSpaces(configFile.Verification)
	if presented == expected {
		return nil
	}
	return errors.New("verification failed")
}
// getTokenData looks up the Copilot token data cached in tokenMap for the
// given GitHub token. The boolean reports whether an entry exists; when it
// does not, the returned map is nil.
func getTokenData(token string) (map[string]interface{}, bool) {
	if cached, found := tokenMap[token]; found {
		return cached, true
	}
	return nil, false
}
// isTokenExpired reports whether the Copilot token data should no longer be
// used.
//
// The "expires_at" entry is read as a Unix timestamp in seconds. The token
// counts as expired when that instant is 60 seconds or less away, or already
// past. Data without a float64 "expires_at" entry is always treated as
// expired.
func isTokenExpired(respDataMap map[string]interface{}) bool {
	expiry, isNumber := respDataMap["expires_at"].(float64)
	if !isNumber {
		return true
	}

	// Leave a 60-second safety margin before the real expiry.
	cutoff := time.Now().Unix() + 60
	return int64(expiry) <= cutoff
}
// resetRequestCount sets requestCount and successCount back to zero while
// holding requestCountMutex.
func resetRequestCount() {
	requestCountMutex.Lock()
	requestCount, successCount = 0, 0
	requestCountMutex.Unlock()
}
// getRandomToken returns one key of m chosen at random, or the empty string
// when m has no entries. Every key is a candidate regardless of the boolean
// stored under it.
func getRandomToken(m map[string]bool) string {
	if len(m) == 0 {
		// Nothing to choose from; callers receive "".
		return ""
	}

	candidates := make([]string, 0, len(m))
	for name := range m {
		candidates = append(candidates, name)
	}

	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
	return candidates[rng.Intn(len(candidates))]
}