package main

import (
	"crypto/tls"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"math/rand"
	"net/http"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/fatih/color"
	"github.com/gin-gonic/gin"
	"github.com/go-resty/resty/v2"
)
// Config mirrors the layout of config.json, which initConfig decodes at start-up.
type Config struct {
	// Server controls where and how the proxy listens.
	Server struct {
		// Domain is the host name requests must be addressed to (checked by
		// DomainMiddleware); it is also the host shown in the start-up banner.
		Domain   string `json:"domain"`
		// Host is the listen host; main combines it with Port as "host:port".
		Host     string `json:"host"`
		// Port is the listen port. 443 selects HTTPS using CertPath/KeyPath;
		// any other value serves plain HTTP.
		Port     int    `json:"port"`
		// CertPath and KeyPath are the TLS certificate and private-key files,
		// only used when Port is 443.
		CertPath string `json:"certPath"`
		KeyPath  string `json:"keyPath"`
	} `json:"server"`
	// CopilotConfig describes the upstream API that token requests are forwarded to.
	CopilotConfig struct {
		// GithubApiUrl is the URL fetched with GET for every token request.
		GithubApiUrl string   `json:"github_api_url"`
		// Token lists upstream tokens; one is chosen at random per request.
		Token        []string `json:"token"`
	} `json:"copilot_config"`
	// Verification, when non-empty, is the secret clients must send as
	// "Authorization: token <Verification>" (all spaces are ignored on both
	// sides of the comparison). An empty value disables the check.
	Verification string `json:"verification"`
}
// Process-wide state shared by the request handler and the status display.
var (
	// requestCountMutex protects requestCount and successCount; every read or
	// write of those counters should hold it.
	requestCountMutex sync.Mutex
	// requestCount is the number of token requests received since start-up or
	// the last daily reset.
	requestCount      int
	// successCount is how many of those requests received a 200 from upstream.
	successCount      int
	// configFile is the configuration loaded by initConfig in main.
	configFile        = Config{}
)
// main loads config.json, registers the token route behind the domain check,
// starts the HTTP (or, on port 443, HTTPS) server in the background, prints
// the API URL and then redraws the request counters once per second for the
// lifetime of the process.
func main() {
	gin.SetMode(gin.ReleaseMode)
	// Discard gin's own output; diagnostics go through the standard logger,
	// which initConfig redirects to error.log.
	gin.DefaultWriter = io.Discard
	configFile = initConfig()

	engine := gin.New()
	// Error handler: log the last error a handler recorded via c.Error and,
	// unless a response has already been written, reply with a generic 500.
	// (Writing a second body after a handler already responded would only
	// produce "headers already written" warnings and a corrupted response.)
	engine.Use(func(c *gin.Context) {
		c.Next()
		if err := c.Errors.Last(); err != nil {
			log.Printf("Error: %v", err.Err)
			if !c.Writer.Written() {
				c.JSON(http.StatusInternalServerError, gin.H{"error": "Internal Server Error"})
			}
		}
	})
	// Only a reverse proxy on 127.0.0.1 is trusted to supply client-IP headers.
	if err := engine.SetTrustedProxies([]string{"127.0.0.1"}); err != nil {
		log.Fatal(err)
	}

	domainDefault := engine.Group("/", DomainMiddleware(configFile.Server.Domain))
	domainDefault.GET("/copilot_internal/v2/token", getToken())

	// One server serves both the HTTP and the HTTPS case, so both honour
	// Server.Host and share the same timeouts.
	//
	// NextProtos lists ALPN protocol IDs in server-preference order. HTTP/2's
	// registered ID is "h2" ("http/1.2" and "http/2" are not valid IDs), and
	// it must come before "http/1.1" or every client is steered to HTTP/1.1.
	//
	// ReadHeaderTimeout bounds how long a client may take to send its request
	// headers, so slow or idle connections cannot be held open indefinitely.
	listenAddress := configFile.Server.Host + ":" + strconv.Itoa(configFile.Server.Port)
	server := &http.Server{
		Addr:              listenAddress,
		Handler:           engine,
		ReadHeaderTimeout: 10 * time.Second,
		TLSConfig: &tls.Config{
			NextProtos: []string{"h2", "http/1.1"},
		},
	}
	go func() {
		var err error
		if configFile.Server.Port == 443 {
			err = server.ListenAndServeTLS(configFile.Server.CertPath, configFile.Server.KeyPath)
		} else {
			err = server.ListenAndServe()
		}
		if err != nil {
			log.Fatal(err)
		}
	}()

	displayMsg()
	ticker := time.NewTicker(1 * time.Second)
	defer ticker.Stop()
	for range ticker.C {
		showRequestCount()
	}
}
// getToken returns the handler for GET /copilot_internal/v2/token.
//
// For each request the handler:
//   - counts the request (and counts a success when upstream answers 200);
//   - when configFile.Verification is set, requires the Authorization header
//     to equal "token <Verification>" with all spaces ignored, and otherwise
//     answers 400 with a GitHub-style "Bad credentials" body;
//   - GETs configFile.CopilotConfig.GithubApiUrl with a randomly chosen
//     upstream token and a subset of the client's headers;
//   - relays the upstream body as application/json on 200, or replies
//     {"error": "Request failed"} with the upstream status code otherwise.
//
// The upstream client is built once per getToken call (main calls it once,
// when registering the route) rather than once per request, so connections
// to the upstream are reused. Its timeout keeps a stalled upstream from
// pinning handler goroutines forever.
func getToken() gin.HandlerFunc {
	client := resty.New().SetTimeout(30 * time.Second)

	return func(c *gin.Context) {
		incrementRequestCount()

		if configFile.Verification != "" {
			provided := strings.ReplaceAll(c.GetHeader("Authorization"), " ", "")
			expected := "token" + strings.ReplaceAll(configFile.Verification, " ", "")
			if provided != expected {
				// Reject requests that do not carry the configured verification secret.
				c.JSON(http.StatusBadRequest, gin.H{
					"message":           "Bad credentials",
					"documentation_url": "https://docs.github.com/rest"})
				return
			}
		}

		// Accept-Encoding is deliberately not forwarded. When a request sets
		// it explicitly, Go's transport stops decompressing transparently, and
		// an encoding resty does not decode itself (e.g. br) would reach the
		// client as raw compressed bytes without a Content-Encoding header.
		// Leaving it unset lets the transport negotiate gzip and decode it.
		headers := map[string]string{
			"Authorization":         "token " + getRandomToken(configFile.CopilotConfig.Token),
			"editor-version":        c.GetHeader("editor-version"),
			"editor-plugin-version": c.GetHeader("editor-plugin-version"),
			"user-agent":            c.GetHeader("user-agent"),
			"accept":                c.GetHeader("accept"),
		}
		response, err := client.R().
			SetHeaders(headers).
			Get(configFile.CopilotConfig.GithubApiUrl)
		if err != nil {
			log.Printf("upstream token request: %v", err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}
		if response.StatusCode() != http.StatusOK {
			c.JSON(response.StatusCode(), gin.H{"error": "Request failed"})
			return
		}

		incrementSuccessCount()
		// Relay the upstream body unchanged; the explicit JSON content type is
		// required by the clients consuming this endpoint.
		c.Data(http.StatusOK, "application/json", response.Body())
	}
}
// initConfig points the standard logger at error.log and loads config.json.
//
// config.json is looked up in the current working directory first and then
// in the directory containing the running executable. Failing to open
// error.log is fatal. A missing or undecodable config.json panics.
func initConfig() Config {
	// error.log becomes the destination of every later log call, so it must
	// stay open for the whole process lifetime. Closing it here (as a deferred
	// Close would) silently discards all subsequent log output.
	logFile, err := os.OpenFile("error.log", os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644)
	if err != nil {
		log.Fatal(err)
	}
	log.SetOutput(logFile)

	file, err := openConfigFile()
	if err != nil {
		panic(fmt.Sprintf("file \"config.json\" not found: %v", err))
	}
	// The file is only read, so a Close error carries no information worth acting on.
	defer file.Close()

	config := Config{}
	if err := json.NewDecoder(file).Decode(&config); err != nil {
		panic(fmt.Sprintf("config format err: %v", err))
	}
	return config
}

// openConfigFile opens ./config.json, falling back to config.json next to the
// running executable (useful when the binary is started from another directory).
func openConfigFile() (*os.File, error) {
	f, cwdErr := os.Open("./config.json")
	if cwdErr == nil {
		return f, nil
	}
	exePath, err := os.Executable()
	if err != nil {
		return nil, fmt.Errorf("working directory: %v; locating executable: %v", cwdErr, err)
	}
	f, err = os.Open(filepath.Join(filepath.Dir(exePath), "config.json"))
	if err != nil {
		return nil, fmt.Errorf("working directory: %v; executable directory: %v", cwdErr, err)
	}
	return f, nil
}
func displayMsg() {
	c := color.New(color.FgGreen)
	fmt.Printf("\n------------------------------------------------------------\n")
	if configFile.Server.Port == 80 {
		_, err := c.Println("[ApiUri] http://" + configFile.Server.Domain + "/copilot_internal/v2/token\r")
		if err != nil {
			return
		}
	} else if configFile.Server.Port == 443 {
		_, err := c.Println("[ApiUrl] https://" + configFile.Server.Domain + "/copilot_internal/v2/token\r")
		if err != nil {
			return
		}
	} else {
		_, err := c.Println("[ApiUrl] http://" + configFile.Server.Domain + ":" + strconv.Itoa(configFile.Server.Port) + "/copilot_internal/v2/token\r")
		if err != nil {
			return
		}
	}
	fmt.Printf("------------------------------------------------------------\n")
}
func incrementRequestCount() {
	requestCount++
}
func incrementSuccessCount() {
	successCount++
}
// lastCountDate is the local calendar date ("2006-01-02") seen by the previous
// showRequestCount call; an empty string means no call has happened yet. Only
// showRequestCount touches it, and main calls that from a single ticker loop,
// so it needs no lock.
var lastCountDate string

// showRequestCount redraws the one-line status (time, total, success and
// failure counts) at the start of the current terminal line. It resets the
// counters the first time it runs on a new local calendar day.
func showRequestCount() {
	now := time.Now()

	requestCountMutex.Lock()
	count := requestCount
	sCount := successCount
	requestCountMutex.Unlock()

	// Reset on date change rather than on an exact "00:00:00" match: a 1 s
	// ticker is not guaranteed to fire during that particular second, so the
	// exact-match check could skip a day's reset entirely.
	today := now.Format("2006-01-02")
	if lastCountDate != "" && today != lastCountDate {
		resetRequestCount()
	}
	lastCountDate = today

	// Yellow timestamp, blue total, green successes, red failures.
	timeStr := fmt.Sprintf("\x1b[33m%s\x1b[0m", now.Format("2006-01-02 15:04:05"))
	countStr := fmt.Sprintf("\x1b[34m%d\x1b[0m", count)
	successCountStr := fmt.Sprintf("\x1b[32m%d\x1b[0m", sCount)
	failureCountStr := fmt.Sprintf("\x1b[31m%d\x1b[0m", count-sCount)
	// "\033[1G" moves the cursor to column 1 so the line is overwritten in place.
	fmt.Printf("\033[1G%s - Total Count: %s | Success Count %s | Fail Count %s   ", timeStr, countStr, successCountStr, failureCountStr)
}
// resetRequestCount zeroes the total and success counters together while
// holding requestCountMutex, so no reader observes one reset without the other.
func resetRequestCount() {
	requestCountMutex.Lock()
	requestCount, successCount = 0, 0
	requestCountMutex.Unlock()
}
// tokenRand is the random source for getRandomToken. It is seeded once, rather
// than building and seeding a fresh rand.Source (a sizeable allocation plus
// seeding work) on every request. A *rand.Rand is not safe for concurrent
// use, so every access holds tokenRandMu.
var (
	tokenRandMu sync.Mutex
	tokenRand   = rand.New(rand.NewSource(time.Now().UnixNano()))
)

// getRandomToken returns a uniformly chosen element of tokens, or "" when
// tokens is empty. It may be called from concurrent request handlers.
func getRandomToken(tokens []string) string {
	if len(tokens) == 0 {
		return ""
	}
	tokenRandMu.Lock()
	i := tokenRand.Intn(len(tokens))
	tokenRandMu.Unlock()
	return tokens[i]
}
// DomainMiddleware returns a gin middleware that only lets a request through
// when the host part of its Host header matches domain. Any other request gets
// 403 "Forbidden" and the handler chain is aborted.
//
// The comparison is case-insensitive, because DNS names are case-insensitive
// and clients may legitimately send e.g. "Example.COM".
func DomainMiddleware(domain string) gin.HandlerFunc {
	return func(c *gin.Context) {
		if !strings.EqualFold(requestHostname(c.Request.Host), domain) {
			c.String(http.StatusForbidden, "Forbidden")
			c.Abort()
			return
		}
		c.Next()
	}
}

// requestHostname strips an optional ":port" suffix from a Host header value.
// A bracketed IPv6 literal such as "[::1]:8080" yields "::1". Naively
// splitting on the first colon would yield just "[".
func requestHostname(host string) string {
	if strings.HasPrefix(host, "[") {
		if end := strings.IndexByte(host, ']'); end > 0 {
			return host[1:end]
		}
		return host
	}
	if i := strings.IndexByte(host, ':'); i >= 0 {
		return host[:i]
	}
	return host
}