share-copilot/source/main.go
2023-09-11 01:31:07 +08:00

317 lines
8.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"crypto/tls"
"encoding/json"
"fmt"
"github.com/fatih/color"
"github.com/gin-gonic/gin"
"github.com/go-resty/resty/v2"
"io"
"log"
"math/rand"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
)
// main loads the configuration, probes the configured GitHub tokens,
// registers the token route, starts the listener in the background and then
// hands control to showMsg, whose status loop never returns.
func main() {
	configFile = initConfig()

	// Quiet engine: release mode, and gin's default writer is discarded.
	gin.SetMode(gin.ReleaseMode)
	gin.DefaultWriter = io.Discard
	engine := gin.New()

	// Only the loopback address is treated as a trusted proxy.
	if err := engine.SetTrustedProxies([]string{"127.0.0.1"}); err != nil {
		log.Fatal(err)
	}

	initValidTokenList()

	// Every route sits behind the domain check.
	routes := engine.Group("/", DomainMiddleware(configFile.Server.Domain))
	routes.GET("/copilot_internal/v2/token", getGithubToken())

	initServer(engine)
	showMsg()
}
// initServer starts the HTTP(S) listener in a background goroutine and
// returns immediately.
//
// When configFile.Server.Port is 443 the server speaks TLS with the
// configured certificate and key; any other port is served as plain HTTP via
// engine.Run. Both branches listen on Server.Host:Server.Port. A listener
// error terminates the process through log.Fatal.
func initServer(engine *gin.Engine) {
	listenAddress := configFile.Server.Host + ":" + strconv.Itoa(configFile.Server.Port)

	if configFile.Server.Port != 443 {
		go func() {
			if err := engine.Run(listenAddress); err != nil {
				log.Fatal(err)
			}
		}()
		return
	}

	server := &http.Server{
		// Use the configured host as well, consistent with the plain-HTTP branch
		// (a hard-coded ":443" silently ignored Server.Host).
		Addr:    listenAddress,
		Handler: engine,
		TLSConfig: &tls.Config{
			// ALPN entries must be registered protocol IDs: "h2" and "http/1.1".
			// "http/1.2" and "http/2" do not exist. http/1.1 stays first so the
			// server-preferred protocol is unchanged for clients offering both.
			NextProtos: []string{"http/1.1", "h2"},
		},
	}
	go func() {
		err := server.ListenAndServeTLS(configFile.Server.CertPath, configFile.Server.KeyPath)
		if err != nil {
			log.Fatal(err)
		}
	}()
}
// initValidTokenList probes every token in configFile.CopilotConfig.Token
// against the GitHub API in parallel and records each one for which
// getGithubApi succeeds in validTokenList. It blocks until all probes finish.
func initValidTokenList() {
	var wg sync.WaitGroup
	for _, token := range configFile.CopilotConfig.Token {
		wg.Add(1)
		go func(token string) {
			defer wg.Done()
			if !getGithubApi(token) {
				return
			}
			// Go maps are not safe for concurrent writes: two probes finishing at
			// the same time would abort the process with "concurrent map writes".
			// requestCountMutex is the package-level lock used for shared state.
			requestCountMutex.Lock()
			validTokenList[token] = true
			requestCountMutex.Unlock()
		}(token)
	}
	wg.Wait()
}
// getGithubToken returns the handler for the Copilot token endpoint.
//
// If configFile.Verification is set, the Authorization header must equal
// "token" + Verification once all spaces are removed from both; otherwise the
// request is rejected with badRequest. The handler then picks a random token
// from validTokenList:
//   - if tokenMap holds a response for it whose numeric "expires_at" is more
//     than 60 seconds in the future, that cached response is returned;
//   - if "expires_at" is present but not a number, the request is rejected;
//   - otherwise a fresh response is fetched with getGithubApi and returned,
//     or the request is rejected if the fetch fails.
func getGithubToken() gin.HandlerFunc {
	return func(c *gin.Context) {
		requestCountMutex.Lock()
		requestCount++
		requestCountMutex.Unlock()

		if configFile.Verification != "" {
			tokenStr := strings.ReplaceAll(c.GetHeader("Authorization"), " ", "")
			configCert := strings.ReplaceAll(configFile.Verification, " ", "")
			if tokenStr != "token"+configCert {
				// Reject requests that do not carry the configured credential.
				badRequest(c)
				return
			}
		}

		// Handlers run concurrently and getGithubApi writes both maps, so the
		// reads must hold the same lock to avoid a fatal concurrent map access.
		requestCountMutex.Lock()
		token := getRandomToken(validTokenList)
		respDataMap, cached := tokenMap[token]
		requestCountMutex.Unlock()

		// No valid token is left; a GitHub request with an empty credential
		// cannot succeed.
		if token == "" {
			badRequest(c)
			return
		}

		if cached {
			if rawExpiresAt, present := respDataMap["expires_at"]; present {
				expiresAt, isNumber := rawExpiresAt.(float64)
				if !isNumber {
					badRequest(c)
					return
				}
				// Refresh one minute early so a nearly expired token is never served.
				if int64(expiresAt) > time.Now().Unix()+60 {
					proxyResp(c, respDataMap)
					return
				}
			}
		}

		// Cache miss, no expires_at, or (nearly) expired: fetch a fresh response.
		if !getGithubApi(token) {
			badRequest(c)
			return
		}
		// getGithubApi stores a newly built map and never mutates it afterwards,
		// so it can be marshalled after the lock is released.
		requestCountMutex.Lock()
		respDataMap = tokenMap[token]
		requestCountMutex.Unlock()
		proxyResp(c, respDataMap)
	}
}
// badRequest answers with HTTP 400 and a JSON body carrying a
// "Bad credentials" message and a link to the GitHub REST docs.
func badRequest(c *gin.Context) {
	body := gin.H{
		"message":           "Bad credentials",
		"documentation_url": "https://docs.github.com/rest",
	}
	c.JSON(http.StatusBadRequest, body)
}
// proxyResp sends respDataMap to the client as an application/json body with
// status 200 and counts the request as successful. If respDataMap cannot be
// marshalled it answers 500 instead and the success counter is left unchanged.
func proxyResp(c *gin.Context, respDataMap map[string]interface{}) {
	responseJSON, err := json.Marshal(respDataMap)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "JSON marshaling error"})
		// Stop here: falling through would write a second (200) response and
		// count the failure as a success.
		return
	}

	// Counters are read by showMsg under requestCountMutex; update them under it too.
	requestCountMutex.Lock()
	successCount++
	requestCountMutex.Unlock()

	c.Header("Content-Type", "application/json")
	c.String(http.StatusOK, string(responseJSON))
}
// 请求githubApi
func getGithubApi(token string) bool {
githubApiCount++
// 设置请求头
headers := map[string]string{
"Authorization": "token " + token,
/*"editor-version": c.GetHeader("editor-version"),
"editor-plugin-version": c.GetHeader("editor-plugin-version"),
"user-agent": c.GetHeader("user-agent"),
"accept": c.GetHeader("accept"),
"accept-encoding": c.GetHeader("accept-encoding"),*/
}
// 发起GET请求
response, err := resty.New().R().
SetHeaders(headers).
Get(configFile.CopilotConfig.GithubApiUrl)
if err != nil {
// 处理请求错误
return false
}
// 判断响应状态码
if response.StatusCode() == http.StatusOK {
// 响应状态码为200 OK
respDataMap := map[string]interface{}{}
err = json.Unmarshal(response.Body(), &respDataMap)
if err != nil {
// 处理JSON解析错误
return false
}
//token map
tokenMap[token] = respDataMap
return true
} else {
// 处理其他状态码
delete(validTokenList, token)
return false
}
}
// initConfig loads config.json from the directory containing the running
// executable and decodes it into a Config.
//
// It panics if the executable path cannot be resolved, the file cannot be
// opened, or its content cannot be decoded. The panic message names the
// resolved path and includes the underlying error.
func initConfig() Config {
	exePath, err := os.Executable()
	if err != nil {
		panic(err)
	}
	// Resolved next to the binary, not relative to the working directory.
	configPath := filepath.Join(filepath.Dir(exePath), "config.json")

	file, err := os.Open(configPath)
	if err != nil {
		panic(fmt.Sprintf("open config file %q: %v", configPath, err))
	}
	// The file is only read, so a Close error cannot lose data; it is
	// deliberately ignored rather than turned into a panic.
	defer file.Close()

	config := Config{}
	if err := json.NewDecoder(file).Decode(&config); err != nil {
		panic(fmt.Sprintf("decode config file %q: %v", configPath, err))
	}
	return config
}
// resetRequestCount zeroes the total and successful request counters while
// holding requestCountMutex. githubApiCount is left unchanged.
func resetRequestCount() {
	requestCountMutex.Lock()
	requestCount, successCount = 0, 0
	requestCountMutex.Unlock()
}
// getRandomToken returns a randomly chosen key of m regardless of its value,
// or "" when m is empty. Each call seeds a fresh generator from the clock.
func getRandomToken(m map[string]bool) string {
	if len(m) == 0 {
		return ""
	}
	candidates := make([]string, 0, len(m))
	for key := range m {
		candidates = append(candidates, key)
	}
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
	return candidates[rng.Intn(len(candidates))]
}
// DomainMiddleware returns middleware that lets a request through only when
// the part of its Host header before the first ':' equals domain or is
// "127.0.0.1". The domain comparison is case-insensitive because host names
// are. Every other request gets 403 "Forbidden" and is aborted.
func DomainMiddleware(domain string) gin.HandlerFunc {
	return func(c *gin.Context) {
		requestDomain := strings.Split(c.Request.Host, ":")[0]
		if strings.EqualFold(requestDomain, domain) || requestDomain == "127.0.0.1" {
			c.Next()
			return
		}
		c.String(http.StatusForbidden, "Forbidden")
		c.Abort()
	}
}
// showMsg prints the client endpoint URLs and the number of valid tokens,
// then loops forever, redrawing a one-line status (successes, failures and
// GitHub API calls) about once per second. The total/success counters are
// reset whenever the local calendar date changes. It never returns.
func showMsg() {
	var url string
	switch configFile.Server.Port {
	case 80:
		url = "http://" + configFile.Server.Domain
	case 443:
		url = "https://" + configFile.Server.Domain
	default:
		url = "http://" + configFile.Server.Domain + ":" + strconv.Itoa(configFile.Server.Port)
	}
	fmt.Println(color.WhiteString("[Jetbrains]") + ": " + color.HiBlueString(url+"/copilot_internal/v2/token"))
	fmt.Println(color.WhiteString("[Vscode]") + ": " + color.HiBlueString(url))
	fmt.Println(color.WhiteString("[Valid tokens]") + ": " + color.HiBlueString(strconv.Itoa(len(validTokenList))))
	fmt.Println(color.WhiteString("-----------------------------------------------------------------------"))

	// The labels never change; build them once instead of on every tick.
	succeedLabel := color.WhiteString("[Succeed]")
	failedLabel := color.WhiteString("[Failed]")
	githubApiLabel := color.WhiteString("[GithubApi]")

	// Reset on a change of calendar date. Matching the full
	// "2006-01-02 15:04:05" string against "00:00:00" can never succeed, and
	// an exact-second check would also miss a tick that skips midnight.
	lastDay := time.Now().Format("2006-01-02")
	for {
		now := time.Now()
		if today := now.Format("2006-01-02"); today != lastDay {
			resetRequestCount()
			lastDay = today
		}

		requestCountMutex.Lock()
		sCount := successCount
		tCount := requestCount
		gCount := githubApiCount
		requestCountMutex.Unlock()

		// "\033[G" moves the cursor to column 1 so the line is redrawn in place.
		fmt.Printf("\033[G%s - %s: %s %s: %s %s: %s ",
			color.HiYellowString(now.Format("2006-01-02 15:04:05")),
			succeedLabel, color.GreenString(strconv.Itoa(sCount)),
			failedLabel, color.RedString(strconv.Itoa(tCount-sCount)),
			githubApiLabel, color.CyanString(strconv.Itoa(gCount)))
		time.Sleep(1 * time.Second)
	}
}