mirror of
https://gitee.com/chuangxxt/share-copilot
synced 2025-04-16 10:33:25 +00:00
修复多个token请求混乱问题
This commit is contained in:
parent
6ac0e083e6
commit
8cb096239b
BIN
share-copilot
BIN
share-copilot
Binary file not shown.
@ -21,7 +21,9 @@ type Config struct {
|
||||
|
||||
var (
|
||||
//初始化需要返回给客户端的响应体
|
||||
responseData map[string]interface{}
|
||||
tokenMap = make(map[string]map[string]interface{})
|
||||
//有效的token列表
|
||||
validTokenList = make(map[string]bool)
|
||||
requestCountMutex sync.Mutex
|
||||
githubApiCount = 0
|
||||
requestCount = 0
|
||||
|
241
source/main.go
241
source/main.go
@ -15,6 +15,7 @@ import (
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -30,13 +31,18 @@ func main() {
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// 初始化有效的token列表
|
||||
initValidTokenList()
|
||||
// 定义路由
|
||||
domainDefault := engine.Group("/", DomainMiddleware(configFile.Server.Domain))
|
||||
domainDefault.GET("/copilot_internal/v2/token", getToken())
|
||||
domainDefault.GET("/copilot_internal/v2/token", getGithubToken())
|
||||
// 初始化服务器
|
||||
initServer(engine)
|
||||
// 显示信息
|
||||
showMsg()
|
||||
}
|
||||
|
||||
// 初始化服务器
|
||||
func initServer(engine *gin.Engine) {
|
||||
// 配置支持的应用程序协议
|
||||
server := &http.Server{
|
||||
@ -62,47 +68,28 @@ func initServer(engine *gin.Engine) {
|
||||
}
|
||||
}()
|
||||
}
|
||||
func showMsg() {
|
||||
var url = ""
|
||||
if configFile.Server.Port == 80 {
|
||||
url = "http://" + configFile.Server.Domain
|
||||
} else if configFile.Server.Port == 443 {
|
||||
url = "https://" + configFile.Server.Domain
|
||||
} else {
|
||||
url = "http://" + configFile.Server.Domain + ":" + strconv.Itoa(configFile.Server.Port)
|
||||
}
|
||||
var jetStr = color.WhiteString("[Jetbrains]")
|
||||
var vsStr = color.WhiteString("[Vscode]")
|
||||
|
||||
fmt.Println(jetStr + ": " + color.HiBlueString(url+"/copilot_internal/v2/token"))
|
||||
fmt.Println(vsStr + ": " + color.HiBlueString(url))
|
||||
fmt.Println(color.WhiteString("-----------------------------------------------------------------------"))
|
||||
for {
|
||||
requestCountMutex.Lock()
|
||||
sCount := successCount
|
||||
tCount := requestCount
|
||||
gCount := githubApiCount
|
||||
requestCountMutex.Unlock()
|
||||
currentTime := time.Now().Format("2006-01-02 15:04:05")
|
||||
if "00:00:00" == currentTime {
|
||||
resetRequestCount()
|
||||
}
|
||||
var s2 = color.WhiteString("[Succeed]")
|
||||
var s3 = color.WhiteString("[Failed]")
|
||||
var s4 = color.WhiteString("[GithubApi]")
|
||||
// 打印文本
|
||||
fmt.Printf("\033[G%s - %s: %s %s: %s %s: %s ",
|
||||
color.HiYellowString(currentTime),
|
||||
s2, color.GreenString(strconv.Itoa(sCount)),
|
||||
s3, color.RedString(strconv.Itoa(tCount-sCount)),
|
||||
s4, color.CyanString(strconv.Itoa(gCount)))
|
||||
time.Sleep(1 * time.Second) //
|
||||
// 初始化有效的token列表
|
||||
func initValidTokenList() {
|
||||
//为了安全起见,应该等待请求完成并处理其响应。
|
||||
var wg sync.WaitGroup
|
||||
for _, token := range configFile.CopilotConfig.Token {
|
||||
wg.Add(1)
|
||||
go func(token string) {
|
||||
defer wg.Done()
|
||||
if getGithubApi(token) {
|
||||
validTokenList[token] = true
|
||||
}
|
||||
}(token)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
func getToken() gin.HandlerFunc {
|
||||
|
||||
// 获取Github的token
|
||||
func getGithubToken() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 请求计数
|
||||
incrementRequestCount()
|
||||
requestCount++
|
||||
// 如果配置了verification,则需要获取请求头中的Authorization令牌
|
||||
if configFile.Verification != "" {
|
||||
token := c.GetHeader("Authorization")
|
||||
@ -110,61 +97,91 @@ func getToken() gin.HandlerFunc {
|
||||
configCert := strings.ReplaceAll(configFile.Verification, " ", "")
|
||||
if tokenStr != "token"+configCert {
|
||||
// 拒绝不符合Verification的请求
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"message": "Bad credentials",
|
||||
"documentation_url": "https://docs.github.com/rest"})
|
||||
badRequest(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
//判断时间戳key是否存在
|
||||
if _, exists := responseData["expires_at"]; exists {
|
||||
// 获取当前时间的Unix时间戳
|
||||
currentTime := time.Now().Unix()
|
||||
if expiresAt, ok := responseData["expires_at"].(float64); ok {
|
||||
// 判断expires_at是否已经过期
|
||||
expiresAtInt64 := int64(expiresAt)
|
||||
//提前一分钟请求
|
||||
if expiresAtInt64 > currentTime+60 {
|
||||
//fmt.Println("\n未过期无需请求")
|
||||
respProxy(c)
|
||||
//从有效的token列表中随机获取一个token
|
||||
token := getRandomToken(validTokenList)
|
||||
//判断tokenMap 里的token是否存在
|
||||
if _, exists := tokenMap[token]; exists {
|
||||
respDataMap := tokenMap[token]
|
||||
//判断时间戳key是否存在
|
||||
if _, exists := respDataMap["expires_at"]; exists {
|
||||
// 获取当前时间的Unix时间戳
|
||||
currentTime := time.Now().Unix()
|
||||
if expiresAt, ok := respDataMap["expires_at"].(float64); ok {
|
||||
// 判断expires_at是否已经过期
|
||||
expiresAtInt64 := int64(expiresAt)
|
||||
//提前一分钟请求
|
||||
if expiresAtInt64 > currentTime+60 {
|
||||
//fmt.Println("\n未过期无需请求")
|
||||
proxyResp(c, tokenMap[token])
|
||||
} else {
|
||||
//fmt.Println("\n已过期重新请求")
|
||||
if getGithubApi(token) {
|
||||
proxyResp(c, tokenMap[token])
|
||||
return
|
||||
} else {
|
||||
badRequest(c)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
//fmt.Println("\n已过期重新请求")
|
||||
getGithubApi(c)
|
||||
respProxy(c)
|
||||
badRequest(c)
|
||||
}
|
||||
} else {
|
||||
fmt.Println("Age is not an int")
|
||||
//tokenMap里的token对应的返回体不存在expires_at
|
||||
if getGithubApi(token) {
|
||||
proxyResp(c, tokenMap[token])
|
||||
return
|
||||
} else {
|
||||
badRequest(c)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
//向githubApi发送请求
|
||||
//fmt.Println("\n第一次请求")
|
||||
getGithubApi(c)
|
||||
respProxy(c)
|
||||
//不存在则githubApi发送请求,并存到tokenMap
|
||||
if getGithubApi(token) {
|
||||
proxyResp(c, tokenMap[token])
|
||||
return
|
||||
} else {
|
||||
badRequest(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
func respProxy(c *gin.Context) {
|
||||
|
||||
// 请求错误
|
||||
func badRequest(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"message": "Bad credentials",
|
||||
"documentation_url": "https://docs.github.com/rest"})
|
||||
}
|
||||
|
||||
// 本服务器响应
|
||||
func proxyResp(c *gin.Context, respDataMap map[string]interface{}) {
|
||||
// 将map转换为JSON字符串
|
||||
responseJSON, err := json.Marshal(responseData)
|
||||
responseJSON, err := json.Marshal(respDataMap)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "JSON marshaling error"})
|
||||
}
|
||||
// 请求成功统计
|
||||
incrementSuccessCount()
|
||||
successCount++
|
||||
// 将JSON字符串作为响应体返回
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.String(http.StatusOK, string(responseJSON))
|
||||
}
|
||||
func getGithubApi(c *gin.Context) {
|
||||
|
||||
// 请求githubApi
|
||||
func getGithubApi(token string) bool {
|
||||
githubApiCount++
|
||||
// 设置请求头
|
||||
headers := map[string]string{
|
||||
"Authorization": "token " + getRandomToken(configFile.CopilotConfig.Token),
|
||||
"editor-version": c.GetHeader("editor-version"),
|
||||
"Authorization": "token " + token,
|
||||
/*"editor-version": c.GetHeader("editor-version"),
|
||||
"editor-plugin-version": c.GetHeader("editor-plugin-version"),
|
||||
"user-agent": c.GetHeader("user-agent"),
|
||||
"accept": c.GetHeader("accept"),
|
||||
"accept-encoding": c.GetHeader("accept-encoding"),
|
||||
"accept-encoding": c.GetHeader("accept-encoding"),*/
|
||||
}
|
||||
// 发起GET请求
|
||||
response, err := resty.New().R().
|
||||
@ -172,17 +189,28 @@ func getGithubApi(c *gin.Context) {
|
||||
Get(configFile.CopilotConfig.GithubApiUrl)
|
||||
if err != nil {
|
||||
// 处理请求错误
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
err = json.Unmarshal(response.Body(), &responseData)
|
||||
if err != nil {
|
||||
// 处理JSON解析错误
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "JSON parsing error"})
|
||||
return
|
||||
// 判断响应状态码
|
||||
if response.StatusCode() == http.StatusOK {
|
||||
// 响应状态码为200 OK
|
||||
respDataMap := map[string]interface{}{}
|
||||
err = json.Unmarshal(response.Body(), &respDataMap)
|
||||
if err != nil {
|
||||
// 处理JSON解析错误
|
||||
return false
|
||||
}
|
||||
//token map
|
||||
tokenMap[token] = respDataMap
|
||||
return true
|
||||
} else {
|
||||
// 处理其他状态码
|
||||
delete(validTokenList, token)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 初始化配置文件
|
||||
func initConfig() Config {
|
||||
// 读取配置文件
|
||||
exePath, err := os.Executable()
|
||||
@ -209,26 +237,30 @@ func initConfig() Config {
|
||||
}
|
||||
return config
|
||||
}
|
||||
func incrementRequestCount() {
|
||||
requestCount++
|
||||
}
|
||||
func incrementSuccessCount() {
|
||||
successCount++
|
||||
}
|
||||
|
||||
// 重置请求计数
|
||||
func resetRequestCount() {
|
||||
requestCountMutex.Lock()
|
||||
defer requestCountMutex.Unlock()
|
||||
requestCount = 0
|
||||
successCount = 0
|
||||
}
|
||||
func getRandomToken(tokens []string) string {
|
||||
if len(tokens) == 0 {
|
||||
|
||||
// 从map中随机获取一个key
|
||||
func getRandomToken(m map[string]bool) string {
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
if len(keys) == 0 {
|
||||
return "" // 返回空字符串或处理其他错误情况
|
||||
}
|
||||
r := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
randomIndex := r.Intn(len(tokens))
|
||||
return tokens[randomIndex]
|
||||
randomIndex := r.Intn(len(keys))
|
||||
return keys[randomIndex]
|
||||
}
|
||||
|
||||
// DomainMiddleware 域名中间件
|
||||
func DomainMiddleware(domain string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 检查域名是否匹配
|
||||
@ -241,3 +273,44 @@ func DomainMiddleware(domain string) gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 显示信息
|
||||
func showMsg() {
|
||||
var url = ""
|
||||
if configFile.Server.Port == 80 {
|
||||
url = "http://" + configFile.Server.Domain
|
||||
} else if configFile.Server.Port == 443 {
|
||||
url = "https://" + configFile.Server.Domain
|
||||
} else {
|
||||
url = "http://" + configFile.Server.Domain + ":" + strconv.Itoa(configFile.Server.Port)
|
||||
}
|
||||
var jetStr = color.WhiteString("[Jetbrains]")
|
||||
var vsStr = color.WhiteString("[Vscode]")
|
||||
var valid = color.WhiteString("[Valid tokens]")
|
||||
|
||||
fmt.Println(jetStr + ": " + color.HiBlueString(url+"/copilot_internal/v2/token"))
|
||||
fmt.Println(vsStr + ": " + color.HiBlueString(url))
|
||||
fmt.Println(valid + ": " + color.HiBlueString(strconv.Itoa(len(validTokenList))))
|
||||
fmt.Println(color.WhiteString("-----------------------------------------------------------------------"))
|
||||
for {
|
||||
requestCountMutex.Lock()
|
||||
sCount := successCount
|
||||
tCount := requestCount
|
||||
gCount := githubApiCount
|
||||
requestCountMutex.Unlock()
|
||||
currentTime := time.Now().Format("2006-01-02 15:04:05")
|
||||
if "00:00:00" == currentTime {
|
||||
resetRequestCount()
|
||||
}
|
||||
var s2 = color.WhiteString("[Succeed]")
|
||||
var s3 = color.WhiteString("[Failed]")
|
||||
var s4 = color.WhiteString("[GithubApi]")
|
||||
// 打印文本
|
||||
fmt.Printf("\033[G%s - %s: %s %s: %s %s: %s ",
|
||||
color.HiYellowString(currentTime),
|
||||
s2, color.GreenString(strconv.Itoa(sCount)),
|
||||
s3, color.RedString(strconv.Itoa(tCount-sCount)),
|
||||
s4, color.CyanString(strconv.Itoa(gCount)))
|
||||
time.Sleep(1 * time.Second) //
|
||||
}
|
||||
}
|
||||
|
86
source/test.go
Normal file
86
source/test.go
Normal file
@ -0,0 +1,86 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func getGithubTest(c *gin.Context, token string) {
|
||||
githubApiCount++
|
||||
data1 := `
|
||||
{
|
||||
"chat_enabled": false,
|
||||
"code_quote_enabled": false,
|
||||
"code_quote_v2_enabled": false,
|
||||
"copilotignore_enabled": false,
|
||||
"expires_at": 3194360727,
|
||||
"prompt_8k": true,
|
||||
"public_suggestions": "disabled",
|
||||
"refresh_in": 1500,
|
||||
"sku": "free_educational",
|
||||
"telemetry": "disabled",
|
||||
"token": "tid=;exp=1694360727;sku=free_educational;st=dotcom;8kp=1:",
|
||||
"tracking_id": ""
|
||||
}
|
||||
`
|
||||
data2 := `
|
||||
{
|
||||
"chat_enabled": false,
|
||||
"code_quote_enabled": false,
|
||||
"code_quote_v2_enabled": false,
|
||||
"copilotignore_enabled": false,
|
||||
"expires_at": 2294360727,
|
||||
"prompt_8k": true,
|
||||
"public_suggestions": "disabled",
|
||||
"refresh_in": 1500,
|
||||
"sku": "free_educational",
|
||||
"telemetry": "disabled",
|
||||
"token": "tid=;exp=1694360727;sku=free_educational;st=dotcom;8kp=1:",
|
||||
"tracking_id": ""
|
||||
}
|
||||
`
|
||||
data3 := `
|
||||
{
|
||||
"chat_enabled": false,
|
||||
"code_quote_enabled": false,
|
||||
"code_quote_v2_enabled": false,
|
||||
"copilotignore_enabled": false,
|
||||
"expires_at": 3394360727,
|
||||
"prompt_8k": true,
|
||||
"public_suggestions": "disabled",
|
||||
"refresh_in": 1500,
|
||||
"sku": "free_educational",
|
||||
"telemetry": "disabled",
|
||||
"token": "tid=;exp=1694360727;sku=free_educational;st=dotcom;8kp=1:",
|
||||
"tracking_id": ""
|
||||
}
|
||||
`
|
||||
//响应体map
|
||||
var respDataMap = make(map[string]interface{})
|
||||
if token == "1" {
|
||||
err := json.Unmarshal([]byte(data1), &respDataMap)
|
||||
if err != nil {
|
||||
// 处理JSON解析错误
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "JSON parsing error"})
|
||||
return
|
||||
}
|
||||
} else if token == "2" {
|
||||
err := json.Unmarshal([]byte(data2), &respDataMap)
|
||||
if err != nil {
|
||||
// 处理JSON解析错误
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "JSON parsing error"})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
err := json.Unmarshal([]byte(data3), &respDataMap)
|
||||
if err != nil {
|
||||
// 处理JSON解析错误
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "JSON parsing error"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
//token map
|
||||
tokenMap[token] = respDataMap
|
||||
}
|
Loading…
Reference in New Issue
Block a user