share-copilot/source/main.go
2023-09-10 23:10:55 +08:00

244 lines
6.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"crypto/tls"
"encoding/json"
"fmt"
"github.com/fatih/color"
"github.com/gin-gonic/gin"
"github.com/go-resty/resty/v2"
"io"
"log"
"math/rand"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"time"
)
// main loads the configuration, wires up the Gin router, starts the listener
// in the background and then hands control to the status display, which
// loops forever.
func main() {
	// Load the configuration into the package-level configFile.
	configFile = initConfig()

	// Build a bare Gin engine in release mode and silence Gin's own output.
	gin.SetMode(gin.ReleaseMode)
	gin.DefaultWriter = io.Discard
	router := gin.New()

	// Only a proxy on the local machine is trusted.
	if err := router.SetTrustedProxies([]string{"127.0.0.1"}); err != nil {
		log.Fatal(err)
	}

	// Register the token route behind the domain check.
	guarded := router.Group("/", DomainMiddleware(configFile.Server.Domain))
	guarded.GET("/copilot_internal/v2/token", getToken())

	initServer(router)

	// Print the endpoint URLs and the live counters; this call does not return.
	showMsg()
}
// initServer starts the listener for engine in a background goroutine and
// returns immediately.
//
// When the configured port is 443 the server speaks TLS using the configured
// certificate and key paths; for any other port it serves plain HTTP through
// engine.Run. A listener failure terminates the process via log.Fatal.
func initServer(engine *gin.Engine) {
	listenAddress := configFile.Server.Host + ":" + strconv.Itoa(configFile.Server.Port)

	go func() {
		if configFile.Server.Port != 443 {
			if err := engine.Run(listenAddress); err != nil {
				log.Fatal(err)
			}
			return
		}

		server := &http.Server{
			// Bind to the configured host, exactly like the plain-HTTP branch,
			// rather than a hard-coded ":443".
			Addr: listenAddress,
			TLSConfig: &tls.Config{
				// ALPN protocol IDs: "h2" is HTTP/2, "http/1.1" is HTTP/1.1.
				// Strings such as "http/1.2" or "http/2" are not valid IDs.
				NextProtos: []string{"h2", "http/1.1"},
			},
			Handler: engine,
		}
		if err := server.ListenAndServeTLS(configFile.Server.CertPath, configFile.Server.KeyPath); err != nil {
			log.Fatal(err)
		}
	}()
}
// showMsg prints the endpoint URLs once and then redraws a one-line status
// (timestamp, succeeded, failed and GitHub API call counts) every second.
// It never returns.
//
// The request counters are reset whenever the local calendar day changes.
func showMsg() {
	// Derive the public base URL from the configured port.
	var baseURL string
	switch configFile.Server.Port {
	case 80:
		baseURL = "http://" + configFile.Server.Domain
	case 443:
		baseURL = "https://" + configFile.Server.Domain
	default:
		baseURL = "http://" + configFile.Server.Domain + ":" + strconv.Itoa(configFile.Server.Port)
	}

	jetStr := color.WhiteString("[Jetbrains]")
	vsStr := color.WhiteString("[Vscode]")
	fmt.Println(jetStr + ": " + color.HiBlueString(baseURL+"/copilot_internal/v2/token"))
	fmt.Println(vsStr + ": " + color.HiBlueString(baseURL))
	fmt.Println(color.WhiteString("-----------------------------------------------------------------------"))

	// The labels do not change between redraws, so build them once.
	succeededLabel := color.WhiteString("[Succeed]")
	failedLabel := color.WhiteString("[Failed]")
	githubLabel := color.WhiteString("[GithubApi]")

	// Track the day of the last reset. Comparing days instead of matching the
	// exact "00:00:00" second means a late tick cannot skip the reset.
	lastResetDay := time.Now().YearDay()

	for {
		now := time.Now()
		if day := now.YearDay(); day != lastResetDay {
			// resetRequestCount takes requestCountMutex itself, so it must be
			// called while the lock is not held.
			resetRequestCount()
			lastResetDay = day
		}

		// Take a consistent snapshot of the counters.
		requestCountMutex.Lock()
		sCount := successCount
		tCount := requestCount
		gCount := githubApiCount
		requestCountMutex.Unlock()

		// "\033[G" moves the cursor back to column one so the line is redrawn in place.
		fmt.Printf("\033[G%s - %s: %s %s: %s %s: %s ",
			color.HiYellowString(now.Format("2006-01-02 15:04:05")),
			succeededLabel, color.GreenString(strconv.Itoa(sCount)),
			failedLabel, color.RedString(strconv.Itoa(tCount-sCount)),
			githubLabel, color.CyanString(strconv.Itoa(gCount)))

		time.Sleep(1 * time.Second)
	}
}
// getToken returns the handler for the token endpoint.
//
// The handler counts the request, optionally checks the Authorization header
// against the configured verification string, refreshes the cached token from
// the GitHub API when the cache is missing or about to expire, and finally
// writes the cached token as JSON.
func getToken() gin.HandlerFunc {
	return func(c *gin.Context) {
		// Count every incoming request.
		incrementRequestCount()

		// When a verification string is configured, the Authorization header
		// must equal "token <verification>"; spaces are ignored on both sides.
		if configFile.Verification != "" {
			presented := strings.ReplaceAll(c.GetHeader("Authorization"), " ", "")
			expected := "token" + strings.ReplaceAll(configFile.Verification, " ", "")
			if presented != expected {
				// Reject requests that do not match the verification string.
				c.JSON(http.StatusBadRequest, gin.H{
					"message":           "Bad credentials",
					"documentation_url": "https://docs.github.com/rest"})
				return
			}
		}

		// The cache is usable only if it has a numeric expires_at that is
		// more than one minute in the future. A missing or non-numeric value
		// is treated as "needs refresh" instead of leaving the request
		// without a response.
		// NOTE(review): responseData is shared between requests and no
		// synchronization is visible in this function — confirm whether
		// concurrent handlers need a lock.
		cacheValid := false
		if expiresAt, ok := responseData["expires_at"].(float64); ok {
			cacheValid = int64(expiresAt) > time.Now().Unix()+60
		}

		if !cacheValid {
			getGithubApi(c)
			// getGithubApi writes an error response itself when the refresh
			// fails; do not append a second body or count a success.
			if c.Writer.Written() {
				return
			}
		}

		respProxy(c)
	}
}
// respProxy serializes the cached responseData to JSON and writes it to the
// client with status 200, counting the request as successful.
//
// If serialization fails it answers with status 500 and stops, so that no
// success is counted and no second body is written.
func respProxy(c *gin.Context) {
	// Convert the cached map to JSON.
	responseJSON, err := json.Marshal(responseData)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "JSON marshaling error"})
		return
	}

	// Record the successful request.
	incrementSuccessCount()

	// Send the JSON bytes as the response body.
	c.Data(http.StatusOK, "application/json", responseJSON)
}
// getGithubApi fetches a fresh token document from the configured GitHub API
// URL and stores the decoded JSON in responseData.
//
// The upstream request is authorized with one of the configured tokens and
// forwards a fixed set of headers from the incoming request. On any failure
// (transport error, upstream status >= 400, or invalid JSON) an error
// response is written to c and responseData is not replaced with an error body.
func getGithubApi(c *gin.Context) {
	// Bound the upstream call so a stalled connection cannot hold the handler open forever.
	const upstreamTimeout = 30 * time.Second

	// The counter is read under requestCountMutex elsewhere, so update it under the same lock.
	requestCountMutex.Lock()
	githubApiCount++
	requestCountMutex.Unlock()

	// Build the upstream request headers.
	headers := map[string]string{
		"Authorization":         "token " + getRandomToken(configFile.CopilotConfig.Token),
		"editor-version":        c.GetHeader("editor-version"),
		"editor-plugin-version": c.GetHeader("editor-plugin-version"),
		"user-agent":            c.GetHeader("user-agent"),
		"accept":                c.GetHeader("accept"),
		"accept-encoding":       c.GetHeader("accept-encoding"),
	}

	// Issue the GET request.
	response, err := resty.New().
		SetTimeout(upstreamTimeout).
		R().
		SetHeaders(headers).
		Get(configFile.CopilotConfig.GithubApiUrl)
	if err != nil {
		// Transport-level failure.
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	// Pass upstream errors through unchanged instead of caching their body
	// as if it were a token document.
	if response.IsError() {
		c.Data(response.StatusCode(), response.Header().Get("Content-Type"), response.Body())
		return
	}

	if err := json.Unmarshal(response.Body(), &responseData); err != nil {
		// The upstream body was not valid JSON.
		c.JSON(http.StatusInternalServerError, gin.H{"error": "JSON parsing error"})
		return
	}
}
// initConfig reads config.json from the directory that contains the running
// executable and decodes it into a Config.
//
// It panics if the executable path cannot be determined, or if the file
// cannot be opened, decoded or closed; the panic message names the file and
// the underlying error.
func initConfig() Config {
	exePath, err := os.Executable()
	if err != nil {
		panic(err)
	}

	// The configuration lives next to the executable, not in the working directory.
	configPath := filepath.Join(filepath.Dir(exePath), "config.json")

	file, err := os.Open(configPath)
	if err != nil {
		panic(fmt.Sprintf("open config file %q: %v", configPath, err))
	}
	defer func() {
		if err := file.Close(); err != nil {
			panic(fmt.Sprintf("close config file %q: %v", configPath, err))
		}
	}()

	config := Config{}
	if err := json.NewDecoder(file).Decode(&config); err != nil {
		panic(fmt.Sprintf("decode config file %q: %v", configPath, err))
	}
	return config
}
func incrementRequestCount() {
requestCount++
}
func incrementSuccessCount() {
successCount++
}
// resetRequestCount zeroes requestCount and successCount while holding
// requestCountMutex.
func resetRequestCount() {
	requestCountMutex.Lock()
	requestCount, successCount = 0, 0
	requestCountMutex.Unlock()
}
// getRandomToken picks one entry of tokens at random, using a generator
// seeded from the current time. It returns the empty string when tokens has
// no entries.
func getRandomToken(tokens []string) string {
	count := len(tokens)
	if count == 0 {
		return ""
	}
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
	return tokens[rng.Intn(count)]
}
// DomainMiddleware returns a Gin middleware that only lets a request through
// when the host part of its Host header equals domain (compared
// case-insensitively, since host names are not case-sensitive) or is
// "127.0.0.1". Every other request is answered with 403 "Forbidden" and
// aborted.
func DomainMiddleware(domain string) gin.HandlerFunc {
	return func(c *gin.Context) {
		// Drop an optional ":port" suffix before comparing.
		requestDomain := strings.Split(c.Request.Host, ":")[0]
		if strings.EqualFold(requestDomain, domain) || requestDomain == "127.0.0.1" {
			c.Next()
			return
		}
		c.String(http.StatusForbidden, "Forbidden")
		c.Abort()
	}
}