From: AceVest Date: Wed, 27 Mar 2019 16:25:15 +0000 (+0800) Subject: add port forward X-Git-Url: http://zhaoyanbai.com/repos/icons/static/gitweb.css?a=commitdiff_plain;h=359a5ab10cd8c6d4cb6f6b5bffbec41ad661a99a;p=acecode.git add port forward --- diff --git a/tools/hack/port_forward.go b/tools/hack/port_forward.go new file mode 100644 index 0000000..5deb1e4 --- /dev/null +++ b/tools/hack/port_forward.go @@ -0,0 +1,116 @@ +/* + * ------------------------------------------------------------------------ + * File Name: port_forward.go + * Author: Zhao Yanbai + * 2019-03-27 22:53:06 星期三 CST + * Description: none + * ------------------------------------------------------------------------ + */ +package main + +import ( + "flag" + "fmt" + "io" + "log" + "net" + "os" + // "sync" +) + +var logDebug *log.Logger +var logInfo *log.Logger +var logImpt *log.Logger +var logErr *log.Logger + +func CreateLogger(path string) error { + logfile, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + if err != nil { + fmt.Printf("failed open log file: %s\n", path) + return err + } + + logDebug = log.New(logfile, "[DEBUG]", log.Ldate|log.Lmicroseconds|log.Lshortfile) + logInfo = log.New(logfile, "[INFO]", log.Ldate|log.Lmicroseconds|log.Lshortfile) + logImpt = log.New(logfile, "[IMPT]", log.Ldate|log.Lmicroseconds|log.Lshortfile) + logErr = log.New(logfile, "[ERROR]", log.Ldate|log.Lmicroseconds|log.Lshortfile) + + logDebug.SetOutput(io.Writer(logfile)) + //logInfo.SetOutput(io.Writer(logfile)) + logInfo.SetOutput(io.MultiWriter(os.Stdout, logfile)) + logImpt.SetOutput(io.MultiWriter(os.Stdout, logfile)) + logErr.SetOutput(io.MultiWriter(os.Stderr, logfile)) + + return nil +} + +func main() { + + var local string + var remote string + var logPath string + flag.StringVar(&local, "local", "localhost:80", "local addr ") + flag.StringVar(&remote, "remote", "localhost:8080", "remote addr ") + flag.StringVar(&logPath, "log", "./port_forward.log", "log file path") + flag.Parse() + + err := CreateLogger(logPath) + if err != nil { + log.Fatal("failed create logger") + } + + defer logInfo.Println("Program Exited...") + + logImpt.Printf("forward from %s to %s\n", local, remote) + + listener, err := net.Listen("tcp", local) + if err != nil { + logErr.Println(err) + return + } + + for { + conn, err := listener.Accept() + if err != nil { + logErr.Println(err) + continue + } + + logInfo.Printf("new connection from %s\n", conn.RemoteAddr()) + + go handleNewConn(conn, remote) + } +} + +func handleNewConn(cli_conn net.Conn, remote string) { + defer cli_conn.Close() + + svr_conn, err := net.Dial("tcp", remote) + if err != nil { + logErr.Println(err) + return + } + + defer svr_conn.Close() + + //var wg sync.WaitGroup + //wg.Add(2) + end := make(chan bool) + + go func() { + io.Copy(cli_conn, svr_conn) + //wg.Done() + end <- true + }() + + go func() { + io.Copy(svr_conn, cli_conn) + //wg.Done() + end <- true + }() + + <-end + //wg.Wait() + + logInfo.Printf("connection from %s closed\n", cli_conn.RemoteAddr()) +}