From 6e0b8f546bdeae3fbf724949bf4c78e06ddf8198 Mon Sep 17 00:00:00 2001 From: Merith-TK Date: Tue, 4 Feb 2025 17:05:54 +0000 Subject: [PATCH] overhaul code --- main.go | 152 +++++++++++++++++++++++++++++++++----------------------- 1 file changed, 91 insertions(+), 61 deletions(-) diff --git a/main.go b/main.go index 0411269..575146b 100644 --- a/main.go +++ b/main.go @@ -1,12 +1,16 @@ package main import ( + "context" + "flag" "io" - "io/ioutil" "log" "net" "os" + "os/signal" "strings" + "sync" + "syscall" "github.com/yosuke-furukawa/json5/encoding/json5" ) @@ -26,7 +30,7 @@ func handleTCPConnection(src net.Conn, targetAddr string) { dst, err := net.Dial("tcp", targetAddr) if err != nil { - log.Printf("Unable to connect to target: %v\n", err) + log.Printf("[ERROR] Unable to connect to target: %v\n", err) return } defer dst.Close() @@ -46,122 +50,148 @@ func handleTCPConnection(src net.Conn, targetAddr string) { <-done } -func startTCPProxy(localAddr, targetAddr string) { +func startTCPProxy(ctx context.Context, wg *sync.WaitGroup, localAddr, targetAddr string) { + defer wg.Done() + listener, err := net.Listen("tcp", localAddr) if err != nil { - log.Fatalf("Unable to listen on %s: %v\n", localAddr, err) + log.Fatalf("[ERROR] Unable to listen on %s: %v\n", localAddr, err) } defer listener.Close() - log.Printf("Listening on %s (TCP), forwarding to %s\n", localAddr, targetAddr) + log.Printf("[INFO] Listening on %s (TCP), forwarding to %s\n", localAddr, targetAddr) for { - conn, err := listener.Accept() - if err != nil { - log.Printf("Failed to accept connection: %v\n", err) - continue + select { + case <-ctx.Done(): + log.Printf("[INFO] Shutting down TCP proxy on %s\n", localAddr) + return + default: + conn, err := listener.Accept() + if err != nil { + log.Printf("[ERROR] Failed to accept connection: %v\n", err) + continue + } + go handleTCPConnection(conn, targetAddr) } - - go handleTCPConnection(conn, targetAddr) } } -func startUDPProxy(localAddr, targetAddr string) { +func startUDPProxy(ctx context.Context, wg *sync.WaitGroup, localAddr, targetAddr string) { + defer wg.Done() + localConn, err := net.ListenPacket("udp", localAddr) if err != nil { - log.Fatalf("Unable to listen on %s: %v\n", localAddr, err) + log.Fatalf("[ERROR] Unable to listen on %s: %v\n", localAddr, err) } defer localConn.Close() remoteAddr, err := net.ResolveUDPAddr("udp", targetAddr) if err != nil { - log.Fatalf("Unable to resolve target address: %v\n", err) + log.Fatalf("[ERROR] Unable to resolve target address: %v\n", err) } buf := make([]byte, 4096) - log.Printf("Listening on %s (UDP), forwarding to %s\n", localAddr, targetAddr) + log.Printf("[INFO] Listening on %s (UDP), forwarding to %s\n", localAddr, targetAddr) for { - n, addr, err := localConn.ReadFrom(buf) - if err != nil { - log.Printf("Failed to read from connection: %v\n", err) - continue + select { + case <-ctx.Done(): + log.Printf("[INFO] Shutting down UDP proxy on %s\n", localAddr) + return + default: + n, addr, err := localConn.ReadFrom(buf) + if err != nil { + log.Printf("[ERROR] Failed to read from connection: %v\n", err) + continue + } + + go func(data []byte, addr net.Addr) { + remoteConn, err := net.DialUDP("udp", nil, remoteAddr) + if err != nil { + log.Printf("[ERROR] Unable to connect to target: %v\n", err) + return + } + defer remoteConn.Close() + + _, err = remoteConn.Write(data) + if err != nil { + log.Printf("[ERROR] Failed to write to target: %v\n", err) + return + } + + n, _, err := remoteConn.ReadFrom(data) + if err != nil { + log.Printf("[ERROR] Failed to read from target: %v\n", err) + return + } + + _, err = localConn.WriteTo(data[:n], addr) + if err != nil { + log.Printf("[ERROR] Failed to write back to source: %v\n", err) + } + }(buf[:n], addr) } - - go func(data []byte, addr net.Addr) { - remoteConn, err := net.DialUDP("udp", nil, remoteAddr) - if err != nil { - log.Printf("Unable to connect to target: %v\n", err) - return - } - defer remoteConn.Close() - - _, err = remoteConn.Write(data) - if err != nil { - log.Printf("Failed to write to target: %v\n", err) - return - } - - n, _, err := remoteConn.ReadFrom(data) - if err != nil { - log.Printf("Failed to read from target: %v\n", err) - return - } - - _, err = localConn.WriteTo(data[:n], addr) - if err != nil { - log.Printf("Failed to write back to source: %v\n", err) - } - }(buf[:n], addr) } } func main() { - if len(os.Args) < 2 { - log.Fatalf("Usage: %s ", os.Args[0]) + flag.Parse() + configPath := os.Getenv("GOPROXY_CONFIG") + if configPath == "" || flag.Arg(0) == "" { + configPath = "goproxy.json" // Default path for Docker } - configFile := os.Args[1] - - configData, err := ioutil.ReadFile(configFile) + configData, err := os.ReadFile(configPath) if err != nil { - log.Fatalf("Failed to read config file: %v\n", err) + log.Fatalf("[ERROR] Failed to read config file (%s): %v\n", configPath, err) } var config ProxyConfig if err := json5.Unmarshal(configData, &config); err != nil { - log.Fatalf("Failed to parse config file: %v\n", err) + log.Fatalf("[ERROR] Failed to parse config file: %v\n", err) } + ctx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup + for _, proxy := range config.Proxy { - // Default type to "both" if not provided if proxy.Type == "" { proxy.Type = "both" } - // Default local to the port of remote if not provided if proxy.Local == "" { parts := strings.Split(proxy.Remote, ":") if len(parts) == 2 { proxy.Local = ":" + parts[1] } else { - log.Fatalf("Invalid remote address format: %s\n", proxy.Remote) + log.Fatalf("[ERROR] Invalid remote address format: %s\n", proxy.Remote) } } + wg.Add(1) switch proxy.Type { case "tcp": - go startTCPProxy(proxy.Local, proxy.Remote) + go startTCPProxy(ctx, &wg, proxy.Local, proxy.Remote) case "udp": - go startUDPProxy(proxy.Local, proxy.Remote) + go startUDPProxy(ctx, &wg, proxy.Local, proxy.Remote) case "both": - go startTCPProxy(proxy.Local, proxy.Remote) - go startUDPProxy(proxy.Local, proxy.Remote) + go startTCPProxy(ctx, &wg, proxy.Local, proxy.Remote) + go startUDPProxy(ctx, &wg, proxy.Local, proxy.Remote) default: - log.Printf("Unknown proxy type: %s\n", proxy.Type) + log.Printf("[WARNING] Unknown proxy type: %s\n", proxy.Type) + wg.Done() } } - select {} + // Handle termination signals + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) + + <-sigChan + log.Println("[INFO] Shutting down proxy server...") + cancel() + wg.Wait() + log.Println("[INFO] Proxy server stopped.") }