ZYB ARTICLES REPOS

go语言实现web反向代理(二)

在上一篇文章《go语言实现web反向代理(一)》中实现了基本的功能,但是有一个问题:在broker之后的web服务器看到的客户端地址是broker的地址,而不是真正的客户端。因此作为改进需要实现将真正的客户端地址传递给web服务器的功能。

实现方式比较简单,就是修改客户端传到broker的请求头,在里面添加上客户端真实地址信息再传递给web服务器。

其中X-Real-IP就是客户端的真实地址。

func handleHTTPClient(clientConn net.Conn) {
	if clientConn == nil {
		log.Printf("nil client")
		return
	}

	defer clientConn.Close()

	clientAddr := clientConn.RemoteAddr().String()
	log.Printf("accept new connection from %v", clientAddr)

	// 存储解析出的虚拟主机名
	var host string

	// 要解析http头,就要读出数据,而这部分数据还必需发送到虚拟主机
	// 因此还需要将它们缓存起来
	const maxBufferSize = 4096
	buffer := bytes.NewBuffer(make([]byte, 0, maxBufferSize))

	reader := bufio.NewReader(clientConn)

	for {
		if buffer.Len() > maxBufferSize {
			log.Printf("err: invalid http header, the header is too large")
			return
		}

		data, _, err := reader.ReadLine()
		if err != nil {
			if err == io.EOF {
				break
			}
		}
		line := string(data)
		buffer.WriteString(string(line) + "\r\n")
		if len(line) < 6 {
			continue
		}

		prefix := strings.ToLower(line[:5])
		if prefix != "host:" {
			continue
		}

		fields := strings.Split(line, ":")
		if len(fields) != 2 && len(fields) != 3 {
			log.Printf("err: invalid host field in http header: %v", line)
			return
		}

		host = fields[1]

		// 顺便插入客户端的真实ip:port信息到http头里
		clientIP, _, _ := net.SplitHostPort(clientAddr)
		buffer.WriteString("X-Real-IP: " + clientIP + "\r\n")

		break
	}

	if len(host) == 0 {
		log.Printf("err: empty hostname")
		return
	}

	// 找出虚拟主机地址
	hostAddr, err := getVirtualHTTPHostAddr(host)
	if err != nil {
		log.Printf("err: unsupport host %v, err: %v", host, err)
		return
	}

	// 连接虚拟主机
	hostConn, err := net.Dial("tcp", hostAddr)
	if err != nil {
		log.Printf("err: connect to virtual host %v[%v] failed: %v", host, hostAddr, err)
		return
	}

	defer hostConn.Close()

	localIP, _, _ := net.SplitHostPort(hostConn.LocalAddr().String())
	buffer.WriteString("X-Forwarded-For: " + localIP + "\r\n")

	// client的链接可能单方面关闭
	// 为了避免另一方向的数据拷贝中断
	// 将他们放进两个goroutine并等待
	var wg sync.WaitGroup
	wg.Add(2)

	// 再将余下的数据传递虚拟主机
	go func() {
		defer wg.Done()
		defer hostConn.(*net.TCPConn).CloseWrite()
		defer func() {
			if c, ok := clientConn.(*net.TCPConn); ok {
				c.CloseRead()
			}
		}()

		_, err = io.Copy(hostConn, io.MultiReader(buffer, reader))
		log.Printf("copy from client %v to host %v. local %v err %v\n", clientAddr, hostAddr, hostConn.LocalAddr().String(), err)
	}()

	// 将虚拟主机返回的数据传递给前端
	go func() {
		defer wg.Done()
		defer hostConn.(*net.TCPConn).CloseRead()
		defer func() {
			switch c := clientConn.(type) {
			case *net.TCPConn:
				c.CloseWrite()
			case *tls.Conn:
				c.CloseWrite()
			default:
			}
		}()

		_, err = io.Copy(clientConn, hostConn)
		log.Printf("copy from host %v to client %v. local %v err %v\n", hostAddr, clientAddr, hostConn.LocalAddr().String(), err)
	}()

	wg.Wait()
}

至于这段代码没有沿用之前的scanner的方式的原因,是因为执行了scanner.Scan()之后会根据缓冲区大小读取并缓存clientConn里的数据。因此,如果在for循环解析出了Host字段之后再将余下的数据传到web服务器,就会发现中间丢失了一段数据。

func handleHTTPClient(clientConn net.Conn) {
	// ...

	const maxBufferSize = 4096
	buffer := bytes.NewBuffer(make([]byte, 0, maxBufferSize))

	reader := io.NewReader(clientConn)
	scanner := bufio.NewScanner(reader)

	for scanner.Scan() {
		if buffer.Len() > maxBufferSize {
			log.Printf("err: invalid http header, the header is too large")
			return
		}

		line := scanner.Text()
		// ...
	}

	// ...
}

http部分的比较好解决,接下来就是https的部分,https的数据是加密的,因此需要解密,因此就需要配置后端所有虚拟主机的证书:

// virtualHTTPSHosts 是一个类似 map[string]VirtualHostConfig的字典
// VirtualHostConfig 是一个类似如下定义的结构体
//type VirtualHostConfig struct {
//	Domain string
//	Host   string
//	Cert   string
//	Key    string
//}
func httpsServer() {
	// 加载所有的证书
	tlsCfg := &tls.Config{}
	for _, cfg := range virtualHTTPSHosts {
		cert, err := tls.LoadX509KeyPair(cfg.Cert, cfg.Key)
		if err != nil {
			log.Fatal(err)
		}
		tlsCfg.Certificates = append(tlsCfg.Certificates, cert)
	}
	tlsCfg.BuildNameToCertificate()
	tlsCfg.Time = time.Now
	tlsCfg.Rand = rand.Reader

	// 监听443端口
	listener, err := net.Listen("tcp", ":443")
	if err != nil {
		log.Fatalf("listen failed: %v", err)
	}

	for {
		connp, err := listener.Accept()
		if err != nil {
			log.Printf("accept new client failed: %v", err)
			continue
		}

		// 开始服务每一个https连接
		go func() {
			client := tls.Server(connp, tlsCfg)
			handleHTTPClient(client)
		}()
	}
}

这里broker相当于解密了https的所有数据,其后的处理就完全可以复用处理http请求的逻辑。

最后,有时候需要强制这个网站按https访问,因此我们可以将所有http的请求全部重定向到https。

func main() {
	var forceHTTPS bool
	flag.BoolVar(&forceHTTPS, "forcehttps", true, "use https only")
	flag.Parse()

	if forceHTTPS {
		go httpForceHTTPS()
	} else {
		go httpServer()
	}
	httpsServer()
}

func httpForceHTTPS() {
	var err error
	defer log.Printf("http server exited. err: ", err)

	m := http.NewServeMux()
	m.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		url := "https://" + r.Host + r.URL.Path

		if len(r.URL.RawQuery) > 0 {
			url += "?" + r.URL.RawQuery
		}

		http.Redirect(w, r, url, http.StatusSeeOther)
	})

	err = http.ListenAndServe(":80", m)
}

func httpServer() {
	listener, err := net.Listen("tcp", ":80")
	if err != nil {
		log.Fatalf("listen failed: %v", err)
	}

	for {
		client, err := listener.Accept()
		if err != nil {
			log.Printf("accept new client failed: %v", err)
			continue
		}

		go handleHTTPClient(client)
	}
}