share-copilot/source/main.go
ben Gutier ca203c6c63 0.2
2023-09-10 15:52:26 +08:00

281 lines
7.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"crypto/tls"
"encoding/json"
"fmt"
"github.com/fatih/color"
"github.com/gin-gonic/gin"
"github.com/go-resty/resty/v2"
"io"
"log"
"math/rand"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
)
// configFile holds the settings decoded from config.json by initConfig at startup.
var configFile Config

// responseData caches the token payload most recently fetched from the GitHub
// API; respProxy marshals it as the response body returned to clients.
var responseData map[string]interface{}

// Request statistics shown on the console status line. showRequestCount and
// resetRequestCount hold requestCountMutex while reading or resetting
// requestCount and successCount; every writer of these counters should hold
// it as well.
var (
	requestCountMutex sync.Mutex
	requestCount      int
	successCount      int
	githubApiCount    int // number of upstream GitHub API calls attempted
)
// main loads config.json, serves GET /copilot_internal/v2/token for the
// configured domain, prints the endpoint URL, and then redraws the request
// counters on the console once per second for the life of the process.
//
// Port 443 is served over TLS with the configured certificate and key; any
// other port is served as plain HTTP. Both listen on Server.Host:Server.Port.
func main() {
	gin.SetMode(gin.ReleaseMode)
	gin.DefaultWriter = io.Discard // silence gin's own output
	configFile = initConfig()

	engine := gin.New()
	// Turn errors attached to the context via c.Error into a generic 500 and
	// log the details.
	engine.Use(func(c *gin.Context) {
		c.Next()
		if err := c.Errors.Last(); err != nil {
			log.Printf("Error: %v", err.Err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal Server Error"})
		}
	})
	// Only treat 127.0.0.1 as a trusted proxy (affects how gin derives the client IP).
	if err := engine.SetTrustedProxies([]string{"127.0.0.1"}); err != nil {
		log.Fatal(err)
	}

	domainDefault := engine.Group("/", DomainMiddleware(configFile.Server.Domain))
	domainDefault.GET("/copilot_internal/v2/token", getToken())

	listenAddress := configFile.Server.Host + ":" + strconv.Itoa(configFile.Server.Port)
	server := &http.Server{
		Addr:    listenAddress,
		Handler: engine,
		// Bound how long a client may take to send its request headers.
		ReadHeaderTimeout: 10 * time.Second,
		TLSConfig: &tls.Config{
			MinVersion: tls.VersionTLS12,
			// ALPN protocol IDs in server preference order; "h2" is HTTP/2 over TLS.
			NextProtos: []string{"h2", "http/1.1"},
		},
	}

	go func() {
		var err error
		if configFile.Server.Port == 443 {
			err = server.ListenAndServeTLS(configFile.Server.CertPath, configFile.Server.KeyPath)
		} else {
			err = server.ListenAndServe()
		}
		if err != nil {
			log.Fatal(err)
		}
	}()

	displayMsg()

	ticker := time.NewTicker(1 * time.Second)
	defer ticker.Stop()
	for range ticker.C {
		showRequestCount()
	}
}
// tokenCacheMutex serializes every read and refresh of the cached token
// payload (responseData). Without it, concurrent requests could read the map
// while another request replaces or fills it, and several requests could
// refresh an expired token from GitHub at the same time.
var tokenCacheMutex sync.Mutex

// getToken returns the handler for GET /copilot_internal/v2/token.
//
// When configFile.Verification is set, the Authorization header must equal
// "token" + Verification with all spaces removed from both sides; otherwise
// the handler answers 400 with a GitHub-style "Bad credentials" body.
//
// The cached payload is reused while its expires_at lies more than 60 seconds
// in the future. In every other case (first request, expired, or a
// non-numeric expires_at) a fresh payload is fetched via getGithubApi before
// responding via respProxy. If the fetch writes an error response, that error
// is the only response sent.
func getToken() gin.HandlerFunc {
	return func(c *gin.Context) {
		incrementRequestCount()

		if configFile.Verification != "" {
			tokenStr := strings.ReplaceAll(c.GetHeader("Authorization"), " ", "")
			configCert := strings.ReplaceAll(configFile.Verification, " ", "")
			if tokenStr != "token"+configCert {
				c.JSON(http.StatusBadRequest, gin.H{
					"message":           "Bad credentials",
					"documentation_url": "https://docs.github.com/rest"})
				return
			}
		}

		tokenCacheMutex.Lock()
		defer tokenCacheMutex.Unlock()

		refresh := true
		if raw, exists := responseData["expires_at"]; !exists {
			fmt.Println("\n第一次请求")
		} else if expiresAt, ok := raw.(float64); !ok {
			// An expiry we cannot read cannot be trusted; refresh rather than
			// leave the request unanswered.
			fmt.Println("\nexpires_at 格式错误，重新请求")
		} else if int64(expiresAt) > time.Now().Unix()+60 {
			// Refresh one minute before expiry (presumably so a client never
			// receives a token that is about to expire).
			fmt.Println("\n未过期无需请求")
			refresh = false
		} else {
			fmt.Println("\n已过期重新请求")
		}

		if refresh {
			getGithubApi(c)
			// getGithubApi reports failures by writing an error response
			// itself; don't follow it with a second body from respProxy.
			if c.Writer.Written() {
				return
			}
		}
		respProxy(c)
	}
}
// respProxy writes the cached token payload (responseData) to the client as
// a 200 "application/json" response and counts the request as successful.
// If nothing is cached, or the payload cannot be encoded, it writes an error
// response instead and does not count a success.
func respProxy(c *gin.Context) {
	if responseData == nil {
		c.JSON(http.StatusBadGateway, gin.H{"error": "no token available"})
		return
	}
	responseJSON, err := json.Marshal(responseData)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "JSON marshaling error"})
		return
	}
	incrementSuccessCount()
	c.Data(http.StatusOK, "application/json", responseJSON)
}
func getGithubApi(c *gin.Context) {
githubApiCount++
// 设置请求头
headers := map[string]string{
"Authorization": "token " + getRandomToken(configFile.CopilotConfig.Token),
"editor-version": c.GetHeader("editor-version"),
"editor-plugin-version": c.GetHeader("editor-plugin-version"),
"user-agent": c.GetHeader("user-agent"),
"accept": c.GetHeader("accept"),
"accept-encoding": c.GetHeader("accept-encoding"),
}
// 发起GET请求
response, err := resty.New().R().
SetHeaders(headers).
Get(configFile.CopilotConfig.GithubApiUrl)
if err != nil {
// 处理请求错误
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
err = json.Unmarshal(response.Body(), &responseData)
if err != nil {
// 处理JSON解析错误
c.JSON(http.StatusInternalServerError, gin.H{"error": "JSON parsing error"})
return
}
}
// initConfig redirects the standard logger to error.log (created or appended
// in the current working directory) and decodes config.json from the
// directory containing the executable.
//
// It exits via log.Fatal if error.log cannot be opened, and panics if the
// executable path cannot be resolved or config.json cannot be opened or
// decoded.
func initConfig() Config {
	logFile, err := os.OpenFile("error.log", os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644)
	if err != nil {
		log.Fatal(err)
	}
	// The log file is deliberately left open for the lifetime of the process:
	// the standard logger keeps writing to it after this function returns.
	log.SetOutput(logFile)

	exePath, err := os.Executable()
	if err != nil {
		panic(err)
	}
	configPath := filepath.Join(filepath.Dir(exePath), "config.json")
	file, err := os.Open(configPath)
	if err != nil {
		panic(fmt.Sprintf("file %q not found: %v", configPath, err))
	}
	// Read-only file: a Close error cannot lose data, so it is ignored.
	defer func() { _ = file.Close() }()

	config := Config{}
	if err := json.NewDecoder(file).Decode(&config); err != nil {
		panic(fmt.Sprintf("config format err: %v", err))
	}
	return config
}
func displayMsg() {
c := color.New(color.FgGreen)
fmt.Printf("\n------------------------------------------------------------\n")
if configFile.Server.Port == 80 {
_, err := c.Println("[ApiUri] http://" + configFile.Server.Domain + "/copilot_internal/v2/token\r")
if err != nil {
return
}
} else if configFile.Server.Port == 443 {
_, err := c.Println("[ApiUrl] https://" + configFile.Server.Domain + "/copilot_internal/v2/token\r")
if err != nil {
return
}
} else {
_, err := c.Println("[ApiUrl] http://" + configFile.Server.Domain + ":" + strconv.Itoa(configFile.Server.Port) + "/copilot_internal/v2/token\r")
if err != nil {
return
}
}
fmt.Printf("------------------------------------------------------------\n")
}
func incrementRequestCount() {
requestCount++
}
func incrementSuccessCount() {
successCount++
}
// lastCountDay is the local calendar date (YYYY-MM-DD) the request counters
// currently belong to. Only showRequestCount touches it, and main calls that
// from its single ticker loop, so it needs no lock.
var lastCountDay string

// showRequestCount redraws the one-line console status: current time, total
// requests, successes, and failures (total minus successes). The counters are
// reset after the snapshot whenever the local date has changed since the
// previous call, so each day starts from zero.
func showRequestCount() {
	requestCountMutex.Lock()
	count := requestCount
	sCount := successCount
	requestCountMutex.Unlock()

	now := time.Now()
	// Compare dates rather than matching the exact "00:00:00" second: a tick
	// that arrives late, or that the ticker drops, would otherwise skip the
	// daily reset entirely.
	today := now.Format("2006-01-02")
	if lastCountDay != "" && today != lastCountDay {
		resetRequestCount()
	}
	lastCountDay = today

	timeStr := fmt.Sprintf("\x1b[33m%s\x1b[0m", now.Format("2006-01-02 15:04:05")) // yellow
	countStr := fmt.Sprintf("\x1b[34m%d\x1b[0m", count)                            // blue
	successCountStr := fmt.Sprintf("\x1b[32m%d\x1b[0m", sCount)                    // green
	failureCountStr := fmt.Sprintf("\x1b[31m%d\x1b[0m", count-sCount)              // red
	// "\033[1G" moves the cursor to column 1 so each tick overwrites the same line.
	fmt.Printf("\033[1G%s - Total Count: %s | Success Count %s | Fail Count %s ", timeStr, countStr, successCountStr, failureCountStr)
}
// resetRequestCount sets both the total and the success counter back to zero
// while holding requestCountMutex.
func resetRequestCount() {
	requestCountMutex.Lock()
	requestCount, successCount = 0, 0
	requestCountMutex.Unlock()
}
// getRandomToken returns one element of tokens chosen uniformly at random, or
// "" when tokens is empty.
//
// It uses the package-level math/rand source, which is safe for concurrent
// use, instead of allocating and seeding a new generator on every call. Since
// Go 1.20 that source is seeded randomly at startup; on older toolchains the
// sequence is the same for every run unless rand.Seed is called.
func getRandomToken(tokens []string) string {
	switch len(tokens) {
	case 0:
		return "" // nothing configured; the caller decides how to handle it
	case 1:
		return tokens[0]
	default:
		return tokens[rand.Intn(len(tokens))]
	}
}
// DomainMiddleware returns a gin middleware that lets a request through only
// when the host part of its Host header equals domain (case-insensitively,
// since hostnames are case-insensitive). Any other request is answered with
// 403 "Forbidden" and aborted.
func DomainMiddleware(domain string) gin.HandlerFunc {
	return func(c *gin.Context) {
		if !strings.EqualFold(stripHostPort(c.Request.Host), domain) {
			c.String(http.StatusForbidden, "Forbidden")
			c.Abort()
			return
		}
		c.Next()
	}
}

// stripHostPort removes an optional ":port" suffix from a Host header value.
// Bracketed IPv6 literals such as "[::1]:443" yield the bare address "::1"
// instead of being cut at their first colon.
func stripHostPort(host string) string {
	if strings.HasPrefix(host, "[") {
		if end := strings.IndexByte(host, ']'); end > 0 {
			return host[1:end]
		}
		return host
	}
	if i := strings.IndexByte(host, ':'); i >= 0 {
		return host[:i]
	}
	return host
}