mirror of
https://gitee.com/chuangxxt/share-copilot
synced 2025-04-16 10:13:26 +00:00
271 lines
7.0 KiB
Go
271 lines
7.0 KiB
Go
package main
|
||
|
||
import (
	"crypto/subtle"
	"crypto/tls"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"math/rand"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/go-resty/resty/v2"
	"github.com/nsf/termbox-go"
)
|
||
|
||
func main() {
|
||
//初始化配置文件
|
||
configFile = initConfig()
|
||
// 创建Gin引擎
|
||
gin.SetMode(gin.ReleaseMode)
|
||
gin.DefaultWriter = io.Discard
|
||
engine := gin.New()
|
||
// 设置信任的代理
|
||
err := engine.SetTrustedProxies([]string{"127.0.0.1"})
|
||
if err != nil {
|
||
log.Fatal(err)
|
||
}
|
||
// 定义路由
|
||
domainDefault := engine.Group("/", DomainMiddleware(configFile.Server.Domain))
|
||
domainDefault.GET("/copilot_internal/v2/token", getToken())
|
||
initServer(engine)
|
||
// 显示信息
|
||
showMsg()
|
||
}
|
||
func initServer(engine *gin.Engine) {
|
||
// 配置支持的应用程序协议
|
||
server := &http.Server{
|
||
Addr: ":443",
|
||
TLSConfig: &tls.Config{
|
||
NextProtos: []string{"http/1.1", "http/1.2", "http/2"}, // 支持的应用程序协议列表
|
||
},
|
||
Handler: engine,
|
||
}
|
||
// 启动Gin服务器并监听端口
|
||
listenAddress := configFile.Server.Host + ":" + strconv.Itoa(configFile.Server.Port)
|
||
go func() {
|
||
if configFile.Server.Port != 443 {
|
||
err := engine.Run(listenAddress)
|
||
if err != nil {
|
||
log.Fatal(err)
|
||
}
|
||
} else {
|
||
err := server.ListenAndServeTLS(configFile.Server.CertPath, configFile.Server.KeyPath)
|
||
if err != nil {
|
||
log.Fatal(err)
|
||
}
|
||
}
|
||
}()
|
||
}
|
||
func showMsg() {
|
||
exitChan := make(chan bool)
|
||
dataGetter := func() []InfoItem {
|
||
var url = ""
|
||
if configFile.Server.Port == 80 {
|
||
url = "http://" + configFile.Server.Domain
|
||
} else if configFile.Server.Port == 443 {
|
||
url = "https://" + configFile.Server.Domain
|
||
} else {
|
||
url = "http://" + configFile.Server.Domain + ":" + strconv.Itoa(configFile.Server.Port)
|
||
}
|
||
requestCountMutex.Lock()
|
||
sCount := successCount
|
||
gCount := githubApiCount
|
||
requestCountMutex.Unlock()
|
||
currentTime := time.Now().Format("2006-01-02 15:04:05")
|
||
if "00:00:00" == currentTime {
|
||
resetRequestCount()
|
||
}
|
||
return []InfoItem{
|
||
{
|
||
Title: "[Jetbrains]",
|
||
Value: url + "/copilot_internal/v2/token",
|
||
TitleColor: termbox.ColorGreen,
|
||
ValueColor: termbox.ColorYellow,
|
||
},
|
||
{
|
||
Title: "[Vscode]",
|
||
Value: url,
|
||
TitleColor: termbox.ColorGreen,
|
||
ValueColor: termbox.ColorYellow,
|
||
},
|
||
{
|
||
Title: "[User Request]:",
|
||
Value: strconv.Itoa(sCount),
|
||
TitleColor: termbox.ColorGreen,
|
||
ValueColor: termbox.ColorYellow,
|
||
},
|
||
{
|
||
Title: "[GithubApi request]:",
|
||
Value: strconv.Itoa(gCount),
|
||
TitleColor: termbox.ColorGreen,
|
||
ValueColor: termbox.ColorYellow,
|
||
},
|
||
{
|
||
Title: "Time:",
|
||
Value: currentTime,
|
||
TitleColor: termbox.ColorGreen,
|
||
ValueColor: termbox.ColorYellow,
|
||
},
|
||
}
|
||
}
|
||
exitChan = make(chan bool)
|
||
// 调用显示函数,传递数据获取器函数
|
||
go DisplayInfo(dataGetter, exitChan)
|
||
for {
|
||
ev := termbox.PollEvent()
|
||
if ev.Type == termbox.EventKey && ev.Key == termbox.KeyEsc {
|
||
exitChan <- true // 发送退出信号
|
||
break
|
||
}
|
||
}
|
||
|
||
}
|
||
func getToken() gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
// 请求计数
|
||
incrementRequestCount()
|
||
// 如果配置了verification,则需要获取请求头中的Authorization令牌
|
||
if configFile.Verification != "" {
|
||
token := c.GetHeader("Authorization")
|
||
tokenStr := strings.ReplaceAll(token, " ", "")
|
||
configCert := strings.ReplaceAll(configFile.Verification, " ", "")
|
||
if tokenStr != "token"+configCert {
|
||
// 拒绝不符合Verification的请求
|
||
c.JSON(http.StatusBadRequest, gin.H{
|
||
"message": "Bad credentials",
|
||
"documentation_url": "https://docs.github.com/rest"})
|
||
return
|
||
}
|
||
}
|
||
//判断时间戳key是否存在
|
||
if _, exists := responseData["expires_at"]; exists {
|
||
// 获取当前时间的Unix时间戳
|
||
currentTime := time.Now().Unix()
|
||
if expiresAt, ok := responseData["expires_at"].(float64); ok {
|
||
// 判断expires_at是否已经过期
|
||
expiresAtInt64 := int64(expiresAt)
|
||
//提前一分钟请求
|
||
if expiresAtInt64 > currentTime+60 {
|
||
//fmt.Println("\n未过期无需请求")
|
||
respProxy(c)
|
||
} else {
|
||
//fmt.Println("\n已过期重新请求")
|
||
getGithubApi(c)
|
||
respProxy(c)
|
||
}
|
||
} else {
|
||
fmt.Println("Age is not an int")
|
||
}
|
||
} else {
|
||
//向githubApi发送请求
|
||
//fmt.Println("\n第一次请求")
|
||
getGithubApi(c)
|
||
respProxy(c)
|
||
}
|
||
}
|
||
}
|
||
func respProxy(c *gin.Context) {
|
||
// 将map转换为JSON字符串
|
||
responseJSON, err := json.Marshal(responseData)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "JSON marshaling error"})
|
||
}
|
||
// 请求成功统计
|
||
incrementSuccessCount()
|
||
// 将JSON字符串作为响应体返回
|
||
c.Header("Content-Type", "application/json")
|
||
c.String(http.StatusOK, string(responseJSON))
|
||
}
|
||
func getGithubApi(c *gin.Context) {
|
||
githubApiCount++
|
||
// 设置请求头
|
||
headers := map[string]string{
|
||
"Authorization": "token " + getRandomToken(configFile.CopilotConfig.Token),
|
||
"editor-version": c.GetHeader("editor-version"),
|
||
"editor-plugin-version": c.GetHeader("editor-plugin-version"),
|
||
"user-agent": c.GetHeader("user-agent"),
|
||
"accept": c.GetHeader("accept"),
|
||
"accept-encoding": c.GetHeader("accept-encoding"),
|
||
}
|
||
// 发起GET请求
|
||
response, err := resty.New().R().
|
||
SetHeaders(headers).
|
||
Get(configFile.CopilotConfig.GithubApiUrl)
|
||
if err != nil {
|
||
// 处理请求错误
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
err = json.Unmarshal(response.Body(), &responseData)
|
||
if err != nil {
|
||
// 处理JSON解析错误
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "JSON parsing error"})
|
||
return
|
||
}
|
||
}
|
||
func initConfig() Config {
|
||
// 读取配置文件
|
||
exePath, err := os.Executable()
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
// 获取执行文件所在目录
|
||
exeDir := filepath.Dir(exePath)
|
||
configFile, err := os.Open(exeDir + "/config.json")
|
||
if err != nil {
|
||
panic("file \"./config.json\" not found")
|
||
}
|
||
defer func(configFile *os.File) {
|
||
err := configFile.Close()
|
||
if err != nil {
|
||
panic("close file \"./config.json\" err")
|
||
}
|
||
}(configFile)
|
||
decoder := json.NewDecoder(configFile)
|
||
config := Config{}
|
||
err = decoder.Decode(&config)
|
||
if err != nil {
|
||
panic("config format err")
|
||
}
|
||
return config
|
||
}
|
||
func incrementRequestCount() {
|
||
requestCount++
|
||
}
|
||
func incrementSuccessCount() {
|
||
successCount++
|
||
}
|
||
func resetRequestCount() {
|
||
requestCountMutex.Lock()
|
||
defer requestCountMutex.Unlock()
|
||
requestCount = 0
|
||
successCount = 0
|
||
}
|
||
// getRandomToken returns one element of tokens chosen at random. It returns
// the empty string when tokens is nil or empty. Each call seeds its own
// generator from the current time in nanoseconds.
func getRandomToken(tokens []string) string {
	count := len(tokens)
	if count == 0 {
		return ""
	}
	gen := rand.New(rand.NewSource(time.Now().UnixNano()))
	pick := gen.Intn(count)
	return tokens[pick]
}
|
||
func DomainMiddleware(domain string) gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
// 检查域名是否匹配
|
||
requestDomain := strings.Split(c.Request.Host, ":")[0]
|
||
if requestDomain == domain || requestDomain == "127.0.0.1" {
|
||
c.Next()
|
||
} else {
|
||
c.String(403, "Forbidden")
|
||
c.Abort()
|
||
}
|
||
}
|
||
}
|