share-copilot/source/main.go
2023-09-09 11:49:30 +08:00

251 lines
6.9 KiB
Go

package main
import (
	"crypto/subtle"
	"crypto/tls"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"math/rand"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/fatih/color"
	"github.com/gin-gonic/gin"
	"github.com/go-resty/resty/v2"
)
// Config is the decoded contents of config.json (see initConfig).
type Config struct {
	// Server holds the listener settings used by main.
	Server struct {
		// Domain is the host name requests must be addressed to; DomainMiddleware
		// rejects any other Host with 403. It is also the host shown by displayMsg.
		Domain string `json:"domain"`
		// Host is the address part of the listen address (Host + ":" + Port).
		Host string `json:"host"`
		// Port is the listening port; 443 makes main serve TLS with CertPath/KeyPath.
		Port int `json:"port"`
		// CertPath is the TLS certificate file, used only when Port is 443.
		CertPath string `json:"certPath"`
		// KeyPath is the TLS private key file, used only when Port is 443.
		KeyPath string `json:"keyPath"`
	} `json:"server"`
	// CopilotConfig describes the upstream that tokens are fetched from.
	CopilotConfig struct {
		// GithubApiUrl is the URL getToken sends its upstream GET request to.
		GithubApiUrl string `json:"github_api_url"`
		// Token lists GitHub tokens; getToken picks one at random per request.
		Token []string `json:"token"`
	} `json:"copilot_config"`
	// Verification, when non-empty, is the secret callers must present as
	// "Authorization: token <Verification>" (spaces are ignored when comparing).
	Verification string `json:"verification"`
}
// Process-wide state shared by the HTTP handlers and the terminal status loop.
var (
	// requestCountMutex protects the two counters below; every access to them
	// should hold it.
	requestCountMutex sync.Mutex
	// requestCount is the number of token requests received; successCount is
	// how many of them were answered with an upstream 200.
	requestCount, successCount int
	// configFile is the configuration loaded by initConfig at startup.
	configFile Config
)
// main loads config.json, starts the token proxy server in the background and
// then redraws a one-line request counter on the terminal every second.
//
// When the configured port is 443 the server speaks TLS using the configured
// certificate and key; any other port is served as plain HTTP. Both modes
// listen on Server.Host:Server.Port.
func main() {
	gin.SetMode(gin.ReleaseMode)
	gin.DefaultWriter = io.Discard
	// Load the configuration (this also redirects the standard logger to error.log).
	configFile = initConfig()

	engine := gin.New()
	// Custom error handler: log the last error a handler recorded and, if the
	// handler has not already answered, reply with a generic 500.
	engine.Use(func(c *gin.Context) {
		c.Next()
		if err := c.Errors.Last(); err != nil {
			log.Printf("Error: %v", err.Err)
			// Writing a second response after the handler already wrote one
			// would corrupt the reply, so only answer if nothing was sent.
			if !c.Writer.Written() {
				c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal Server Error"})
			}
		}
	})
	// Only honour forwarding headers from a reverse proxy on this machine.
	if err := engine.SetTrustedProxies([]string{"127.0.0.1"}); err != nil {
		log.Fatal(err)
	}

	// Every route is restricted to the configured domain.
	domainDefault := engine.Group("/", DomainMiddleware(configFile.Server.Domain))
	domainDefault.GET("/copilot_internal/v2/token", getToken())

	listenAddress := configFile.Server.Host + ":" + strconv.Itoa(configFile.Server.Port)
	server := &http.Server{
		Addr:    listenAddress,
		Handler: engine,
		// Bound how long a client may take to send its request headers so
		// stalled connections cannot pin goroutines indefinitely.
		ReadHeaderTimeout: 10 * time.Second,
		TLSConfig: &tls.Config{
			// ALPN protocol IDs in server-preference order: "h2" is the
			// registered identifier for HTTP/2 over TLS ("http/2" and
			// "http/1.2" are not valid IDs and never matched a client).
			NextProtos: []string{"h2", "http/1.1"},
		},
	}
	go func() {
		var err error
		if configFile.Server.Port == 443 {
			err = server.ListenAndServeTLS(configFile.Server.CertPath, configFile.Server.KeyPath)
		} else {
			err = server.ListenAndServe()
		}
		// Listen* only return on failure (the server is never shut down).
		if err != nil {
			log.Fatal(err)
		}
	}()

	// Show the public endpoint URL, then refresh the counters once a second.
	displayMsg()
	ticker := time.NewTicker(1 * time.Second)
	defer ticker.Stop()
	for range ticker.C {
		showRequestCount()
	}
}
// getToken returns the handler for GET /copilot_internal/v2/token.
//
// The handler works in four steps:
//   - It counts the request.
//   - When configFile.Verification is non-empty, it requires the caller's
//     Authorization header to equal "token <Verification>", ignoring all spaces
//     on both sides. Otherwise it replies with a GitHub-style "Bad credentials"
//     body.
//   - It forwards accepted requests to configFile.CopilotConfig.GithubApiUrl,
//     authenticated with one randomly chosen configured token.
//   - It relays a 200 upstream body verbatim as application/json. Any other
//     upstream status is returned with a generic error body.
func getToken() gin.HandlerFunc {
	// getToken runs once, at route registration, so this client is shared by
	// every request. Upstream connections are pooled and reused instead of
	// building a new client and transport per request. The timeout keeps a
	// stalled upstream from holding handler goroutines indefinitely.
	client := resty.New().SetTimeout(30 * time.Second)
	return func(c *gin.Context) {
		incrementRequestCount()

		if configFile.Verification != "" {
			got := strings.ReplaceAll(c.GetHeader("Authorization"), " ", "")
			want := "token" + strings.ReplaceAll(configFile.Verification, " ", "")
			// Constant-time comparison: the shared secret must not leak
			// through response timing.
			if subtle.ConstantTimeCompare([]byte(got), []byte(want)) != 1 {
				c.JSON(http.StatusBadRequest, gin.H{
					"message":           "Bad credentials",
					"documentation_url": "https://docs.github.com/rest"})
				return
			}
		}

		githubToken := getRandomToken(configFile.CopilotConfig.Token)
		if githubToken == "" {
			// Without a token the upstream call can only fail; report the cause.
			c.JSON(http.StatusInternalServerError, gin.H{"error": "no GitHub token configured"})
			return
		}

		headers := map[string]string{"Authorization": "token " + githubToken}
		// Relay the editor's identification headers when the client sent them.
		// Accept-Encoding is intentionally NOT relayed. The body is written back
		// below without the upstream Content-Encoding header, so it must arrive
		// decoded. Leaving Accept-Encoding unset lets the HTTP transport
		// negotiate gzip itself and decompress transparently.
		for _, name := range []string{"editor-version", "editor-plugin-version", "user-agent", "accept"} {
			if v := c.GetHeader(name); v != "" {
				headers[name] = v
			}
		}

		response, err := client.R().
			SetHeaders(headers).
			Get(configFile.CopilotConfig.GithubApiUrl)
		if err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}
		if response.StatusCode() != http.StatusOK {
			c.JSON(response.StatusCode(), gin.H{"error": "Request failed"})
			return
		}

		incrementSuccessCount()
		// The Copilot client expects a JSON content type on the token payload.
		c.Data(http.StatusOK, "application/json", response.Body())
	}
}
// initConfig points the standard logger at error.log and loads the service
// configuration.
//
// error.log is opened (created if missing) in append mode in the current
// working directory. It is deliberately left open, because it stays the log
// output for the rest of the process.
//
// config.json is read from the directory containing the running executable
// when present there, otherwise from the current working directory. A missing
// or malformed configuration panics, since the server cannot run without it.
func initConfig() Config {
	logFile, err := os.OpenFile("error.log", os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644)
	if err != nil {
		log.Fatal(err)
	}
	// Do not close logFile here. Closing it on return would make every later
	// log call write to a closed file and be silently dropped.
	log.SetOutput(logFile)

	// Search order: first next to the executable, so the service can be
	// started from any directory. Then the working directory, which covers
	// `go run`, whose binary lives in a temporary build directory.
	candidates := []string{"./config.json"}
	if exePath, exeErr := os.Executable(); exeErr == nil {
		candidates = append([]string{filepath.Join(filepath.Dir(exePath), "config.json")}, candidates...)
	}
	var f *os.File
	for _, path := range candidates {
		if f, err = os.Open(path); err == nil {
			break
		}
	}
	if err != nil {
		panic("file \"./config.json\" not found")
	}
	// Read-only file: a Close error cannot lose data, so it is ignored.
	defer f.Close()

	var config Config
	if err := json.NewDecoder(f).Decode(&config); err != nil {
		panic(fmt.Sprintf("config format err: %v", err))
	}
	return config
}
func displayMsg() {
c := color.New(color.FgGreen)
fmt.Printf("\n------------------------------------------------------------\n")
if configFile.Server.Port == 80 {
_, err := c.Println("[ApiUri] http://" + configFile.Server.Domain + "/copilot_internal/v2/token\r")
if err != nil {
return
}
} else if configFile.Server.Port == 443 {
_, err := c.Println("[ApiUrl] https://" + configFile.Server.Domain + "/copilot_internal/v2/token\r")
if err != nil {
return
}
} else {
_, err := c.Println("[ApiUrl] http://" + configFile.Server.Domain + ":" + strconv.Itoa(configFile.Server.Port) + "/copilot_internal/v2/token\r")
if err != nil {
return
}
}
fmt.Printf("------------------------------------------------------------\n")
}
func incrementRequestCount() {
requestCount++
}
func incrementSuccessCount() {
successCount++
}
// lastCountDay is the local calendar day (layout "2006-01-02") that the
// request counters currently cover. Only showRequestCount reads or writes it.
// That function is driven from main's single ticker loop, so no lock is taken.
var lastCountDay string

// showRequestCount redraws a one-line status at the start of the current
// terminal line. The line shows the current time and the total, successful and
// failed (total minus successful) request counts, each in its own ANSI color.
//
// The counters are cleared on the first call after the local date changes.
// Comparing dates, rather than testing whether the clock reads exactly
// 00:00:00, means a delayed or dropped ticker tick cannot skip the daily
// reset. The line printed on that call still shows the previous day's totals.
func showRequestCount() {
	requestCountMutex.Lock()
	count := requestCount
	sCount := successCount
	requestCountMutex.Unlock()

	// Read the clock once so the displayed time and the day check agree.
	now := time.Now()
	today := now.Format("2006-01-02")
	switch {
	case lastCountDay == "":
		// First call: start tracking the day without clearing anything.
		lastCountDay = today
	case today != lastCountDay:
		lastCountDay = today
		resetRequestCount()
	}

	// Yellow timestamp.
	timeStr := fmt.Sprintf("\x1b[33m%s\x1b[0m", now.Format("2006-01-02 15:04:05"))
	// Blue total count.
	countStr := fmt.Sprintf("\x1b[34m%d\x1b[0m", count)
	// Green success count.
	successCountStr := fmt.Sprintf("\x1b[32m%d\x1b[0m", sCount)
	// Red failure count.
	failureCountStr := fmt.Sprintf("\x1b[31m%d\x1b[0m", count-sCount)
	// "\033[1G" moves the cursor to column 1 so each call overwrites the previous line.
	fmt.Printf("\033[1G%s - Total Count: %s | Success Count %s | Fail Count %s ", timeStr, countStr, successCountStr, failureCountStr)
}
// resetRequestCount sets both the total and the success counter back to zero
// while holding requestCountMutex, so readers never see one cleared without
// the other.
func resetRequestCount() {
	requestCountMutex.Lock()
	requestCount, successCount = 0, 0
	requestCountMutex.Unlock()
}
// getRandomToken returns one element of tokens chosen uniformly at random, or
// "" when tokens is empty.
//
// It draws from math/rand's package-level source, which is safe for concurrent
// use by the request handlers. Since Go 1.20 that source is also randomly
// seeded. The previous version built and seeded a fresh source (several KB of
// state) on every call, and calls that observed the same UnixNano value picked
// the same token.
func getRandomToken(tokens []string) string {
	if len(tokens) == 0 {
		return ""
	}
	return tokens[rand.Intn(len(tokens))]
}
// DomainMiddleware returns middleware that lets a request continue only when
// its Host header names domain. The comparison ignores any port and is
// case-insensitive, as DNS names are. Every other request gets 403
// "Forbidden" and the handler chain is aborted.
func DomainMiddleware(domain string) gin.HandlerFunc {
	return func(c *gin.Context) {
		host := c.Request.Host
		// SplitHostPort understands bracketed IPv6 literals ("[::1]:8080"),
		// which splitting on the first ':' mangled. It fails when there is no
		// port, and in that case Host is already the bare name.
		if h, _, err := net.SplitHostPort(host); err == nil {
			host = h
		}
		if !strings.EqualFold(host, domain) {
			c.String(http.StatusForbidden, "Forbidden")
			c.Abort()
			return
		}
		c.Next()
	}
}