share-copilot/source/server.go
ben Gutier 7abff5fe83 优化
2023-09-11 12:24:04 +08:00

122 lines
2.9 KiB
Go

package main
import (
"crypto/tls"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"log"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
)
// initConfig loads config.json from the directory that contains the running
// executable and decodes it into a Config.
//
// It panics when the executable path cannot be resolved, when the file cannot
// be opened, or when its contents cannot be decoded; the panic message names
// the resolved path and includes the underlying error.
func initConfig() Config {
	exePath, err := os.Executable()
	if err != nil {
		panic(err)
	}
	// Resolve config.json next to the binary rather than in the current
	// working directory, so the server can be started from any directory.
	configPath := filepath.Join(filepath.Dir(exePath), "config.json")

	configFile, err := os.Open(configPath)
	if err != nil {
		panic("open config file " + strconv.Quote(configPath) + ": " + err.Error())
	}
	// The file is only read, so a Close error cannot lose data; it is
	// deliberately ignored instead of panicking after a successful decode.
	defer configFile.Close()

	var config Config
	if err := json.NewDecoder(configFile).Decode(&config); err != nil {
		panic("config format err: " + err.Error())
	}
	return config
}
// setupGinEngine creates a bare gin engine (no default middleware) in release
// mode with gin's default output writer discarded, and trusts forwarding
// headers only from the local proxy at 127.0.0.1. It exits the process via
// log.Fatal if the trusted-proxy list is rejected.
func setupGinEngine() *gin.Engine {
	gin.DefaultWriter = io.Discard
	gin.SetMode(gin.ReleaseMode)

	r := gin.New()
	trustedProxies := []string{"127.0.0.1"}
	err := r.SetTrustedProxies(trustedProxies)
	if err != nil {
		log.Fatal(err)
	}
	return r
}
// setupRoutes registers the Copilot token endpoint on engine, guarded by the
// domain allow-list middleware built from the configured server domain.
func setupRoutes(engine *gin.Engine) {
	domainGuard := DomainMiddleware(configFile.Server.Domain)
	root := engine.Group("/", domainGuard)
	root.GET("/copilot_internal/v2/token", getGithubToken())
}
// DomainMiddleware returns a gin handler that lets a request continue only
// when the hostname in its Host header equals domain (compared
// case-insensitively, since DNS names are case-insensitive) or is exactly
// "127.0.0.1". Any other request gets a 403 "Forbidden" response and the
// handler chain is aborted.
func DomainMiddleware(domain string) gin.HandlerFunc {
	return func(c *gin.Context) {
		requestDomain := hostnameFromHost(c.Request.Host)
		if strings.EqualFold(requestDomain, domain) || requestDomain == "127.0.0.1" {
			c.Next()
			return
		}
		c.String(http.StatusForbidden, "Forbidden")
		c.Abort()
	}
}

// hostnameFromHost strips an optional ":port" suffix from a Host header value,
// keeping bracketed IPv6 literals intact (e.g. "[::1]:8080" -> "::1").
func hostnameFromHost(host string) string {
	if strings.HasPrefix(host, "[") {
		if end := strings.IndexByte(host, ']'); end > 0 {
			return host[1:end]
		}
		return host
	}
	if i := strings.IndexByte(host, ':'); i >= 0 {
		return host[:i]
	}
	return host
}
// initAndStartServer starts serving engine in a background goroutine on the
// configured host and port. When the port is 443 the server is run over TLS
// with the configured certificate and key files; any other port is served as
// plain HTTP through engine.Run. Whenever the listener returns, the process
// exits via log.Fatal with the returned error.
func initAndStartServer(engine *gin.Engine) {
	addr := configFile.Server.Host + ":" + strconv.Itoa(configFile.Server.Port)
	tlsServer := createTLSServer(engine, addr)

	go func() {
		var serveErr error
		if configFile.Server.Port == 443 {
			serveErr = tlsServer.ListenAndServeTLS(configFile.Server.CertPath, configFile.Server.KeyPath)
		} else {
			serveErr = engine.Run(addr)
		}
		log.Fatal(serveErr)
	}()
}
// createTLSServer builds an http.Server that listens on address, dispatches to
// engine, and advertises ALPN protocols for TLS connections.
//
// ALPN IDs must be the IANA-registered strings: "h2" for HTTP/2 over TLS and
// "http/1.1". The previous "http/1.2" and "http/2" values are not valid IDs.
// Because Go's TLS server selects by its own preference order, listing
// "http/1.1" first also meant HTTP/2 was never negotiated. "h2" is listed
// first here so HTTP/2-capable clients get it, with HTTP/1.1 as the fallback.
func createTLSServer(engine *gin.Engine, address string) *http.Server {
	return &http.Server{
		Addr:    address,
		Handler: engine,
		TLSConfig: &tls.Config{
			NextProtos: []string{"h2", "http/1.1"},
		},
	}
}
// proxyResp serializes respDataMap to JSON and writes it as an HTTP 200
// response with Content-Type "application/json", counting the request as a
// success.
//
// If respDataMap cannot be marshaled, an HTTP 500 JSON error is written
// instead and the function returns without counting a success or writing a
// second response.
func proxyResp(c *gin.Context, respDataMap map[string]interface{}) {
	responseJSON, err := json.Marshal(respDataMap)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "JSON marshaling error"})
		return
	}
	// NOTE(review): successCount is a package-level counter declared
	// elsewhere. Gin handlers run concurrently, so this increment is a data
	// race unless the counter is guarded or made atomic at its declaration.
	successCount++
	c.Data(http.StatusOK, "application/json", responseJSON)
}
// badRequest replies with HTTP 400 and a GitHub-style error body reporting
// "Bad credentials" together with a documentation link.
func badRequest(c *gin.Context) {
	body := gin.H{
		"message":           "Bad credentials",
		"documentation_url": "https://docs.github.com/rest",
	}
	c.JSON(http.StatusBadRequest, body)
}