ZYB ARTICLES REPOS

go语言实现web反向代理(一)

背景

现在我手里有很多个域名,我想把这些域名各自映射到不同的web服务程序,但是只有一个云主机,也就是说只有一个80端口,不同的web服务程序又不能共享80端口,将各个web服务程序放在一起又太乱。因此决定写一个反向代理broker,将80端口收到的所有http请求转发到对应的虚拟web服务程序。以实现一个主机支撑多个域名访问不同的web服务程序。

对于https443端口也同理。

HTTP 的部分

首先监听80端口,对来自客户端的每个连接请求开一个协程进行处理。

func httpHandler() {
	listener, err := net.Listen("tcp", ":80")
	if err != nil {
		log.Fatalf("listen failed: %v", err)
	}

	for {
		client, err := listener.Accept()
		if err != nil {
			log.Printf("accept new client failed: %v", err)
			continue
		}

		go handleClient(client)
	}
}

对于每个连接请求的第一次HTTP请求的时候,我们需要分析HTTP头里的Host字段,通过这个字段我们才能知道客户端具体需要访问的是哪个虚拟主机。

在解析的时候需要对缓冲区进行限制,以避免有人恶意提交非法数据

func handleHTTPClient(clientConn net.Conn) {
	if clientConn == nil {
		log.Printf("nil client")
		return
	}

	defer clientConn.Close()

	clientAddr := clientConn.RemoteAddr().String()
	log.Printf("accept new connection from %v", clientAddr)

	// 存储解析出的虚拟主机名
	var host string

	// 要解析http头,就要读出数据,而这部分数据还必需发送到虚拟主机
	// 因此还需要将它们缓存起来
	const maxBufferSize = 4096
	buffer := bytes.NewBuffer(make([]byte, 0, maxBufferSize))

	// 开始解析Host
	scanner := bufio.NewScanner(clientConn)

	// 限制扫描的时候分配的缓冲区大小
	// 如果不限制,则会使用scanner默认的缓冲区大小:MaxScanTokenSize = 64 * 1024
	scannerBuffer := make([]byte, 256)
	scanner.Buffer(scannerBuffer, len(scannerBuffer))

	for scanner.Scan() {
		if buffer.Len() > maxBufferSize {
			log.Printf("err: invalid http header, the header is too large")
			return
		}

		line := scanner.Text()
		buffer.WriteString(line + "\r\n")

		if len(line) < 6 {
			continue
		}

		prefix := strings.ToLower(line[:5])
		if prefix != "host:" {
			continue
		}

		fields := strings.Split(line, ":")
		if len(fields) != 2 && len(fields) != 3 {
			log.Printf("err: invalid host field in http header: %v", line)
			return
		}

		host = fields[1]

		break
	}

	if len(host) == 0 {
		log.Printf("err: empty hostname")
		return
    }
    
    // ....
}

在获得请求头里的host之后,我们需要查询配置,它对应的后端服务器的地址应该是多少 以下只是一个简要的示例代码,实际中可以通过读配置文件来解决

var virtualHTTPHostAddr = map[string]string{}
func init() {
    virtualHTTPHostAddr["abc.com"] = "1.2.3.4:7777"
    virtualHTTPHostAddr["xyz.com"] = "1.2.3.4:8888"
}
func getVirtualHTTPHostAddr(host string) (string, error) {
	host = strings.TrimSpace(host)
	v, ok := virtualHTTPHostAddr[host]
	if !ok {
		return "", fmt.Errorf("can not find %v", host)
	}

	return v, nil
}

有了域名到虚拟主机的映射配置后,我们可以继续编写handleHTTPClient的逻辑,接下来需要完成

建立一个到虚拟主机的连接hostConn,然后在hostConnbroker已经接受的客户端连接clientConn之间转发数据

func handleHTTPClient(clientConn net.Conn) {
	// 省略部分代码见上文
	// ...

	// 找出虚拟主机地址
	hostAddr, err := getVirtualHTTPHostAddr(host)
	if err != nil {
		log.Printf("err: unsupport host %v, err: %v", host, err)
		return
	}

	// 连接虚拟主机
	hostConn, err := net.Dial("tcp", hostAddr)
	if err != nil {
		log.Printf("err: connect to virtual host %v[%v] failed: %v", host, hostAddr, err)
		return
	}

	defer hostConn.Close()

	// client的链接可能单方面关闭
	// 为了避免另一方向的数据拷贝中断
	// 将他们放进两个goroutine并等待
	var wg sync.WaitGroup
	wg.Add(2)

	// 再将余下的数据传递虚拟主机
	go func() {
		defer wg.Done()
		defer hostConn.(*net.TCPConn).CloseWrite()
		defer clientConn.(*net.TCPConn).CloseRead()
		_, err = io.Copy(hostConn, io.MultiReader(buffer, clientConn))
		log.Printf("copy from client %v to host %v. err %v", clientAddr, hostAddr, err)
	}()

	// 将虚拟主机返回的数据传递给前端
	go func() {
		defer wg.Done()
		defer hostConn.(*net.TCPConn).CloseRead()
		defer clientConn.(*net.TCPConn).CloseWrite()
		_, err = io.Copy(clientConn, hostConn)
		log.Printf("copy from host %v to client %v. err %v", hostAddr, clientAddr, err)
	}()

	wg.Wait()
}

HTTPS部分

现在苹果已经限制纯http访问了,因此还有必要支持一下https

但是https的数据是加密的,不像http那样将虚拟主机记在明文的http头的Host字段里

如果https协议需要解密才能得到虚拟主机的话,那就是一个死锁的问题:服务器需要解密才能拿到虚拟主机信息,客户端需要服务器返回才能解密数据,为此TLS协议(ver:1.2)在客户端发往服务器的第一个握手包里带了SNI (Server Name Indication)信息。

所以httpshttp最大的不同就是需要解析SNI,其它逻辑大致相同

同样第一步先监听443端口,接受并处理每一新来自客户端的连接

func httpsHandler() {
	listener, err := net.Listen("tcp", ":443")
	if err != nil {
		log.Fatalf("listen failed: %v", err)
	}

	for {
		client, err := listener.Accept()
		if err != nil {
			log.Printf("accept new client failed: %v", err)
			continue
		}

		go handleHTTPSClient(client)
	}
}

handleHTTPSClienthandleHTTPClient别无二致,只是getSNI部分不同。


func handleHTTPSClient(clientConn net.Conn) {
	if clientConn == nil {
		log.Printf("nil client")
		return
	}

	defer clientConn.Close()

	clientAddr := clientConn.RemoteAddr().String()
	log.Printf("accept new https connection from %v", clientAddr)

	host, reader, err := getSNI(clientConn)
	if err != nil {
		log.Printf("err: parse SNI failed: %v", err)
		return
	}

	// 连接虚拟主机
	hostAddr, err := getVirtualHTTPSHostAddr(host)
	if err != nil {
		log.Printf("err: unsupport host %v, err: %v", host, err)
		return
	}

	hostConn, err := net.Dial("tcp", hostAddr)
	if err != nil {
		log.Printf("err: connect to virtual host %v[%v] failed: %v", host, hostAddr, err)
		return
	}

	defer hostConn.Close()

	var wg sync.WaitGroup
	wg.Add(2)

	go func() {
		defer wg.Done()
		defer hostConn.(*net.TCPConn).CloseWrite()
		defer clientConn.(*net.TCPConn).CloseRead()
		_, err = io.Copy(hostConn, reader)
		log.Printf("copy from client %v to host %v. err %v", clientAddr, hostAddr, err)
	}()

	go func() {
		defer wg.Done()
		defer hostConn.(*net.TCPConn).CloseRead()
		defer clientConn.(*net.TCPConn).CloseWrite()
		_, err = io.Copy(clientConn, hostConn)
		log.Printf("copy from host %v to client %v. err %v", hostAddr, clientAddr, err)
	}()

	wg.Wait()
}

getVirtualHTTPSHostAddr的逻辑与getVirtualHTTPHostAddr几乎一致,略。

getSNI就是获取第一个握手包里的Server Name Indication

func getSNI(reader io.Reader) (string, io.Reader, error) {
	var err error

	buffer := new(bytes.Buffer)
	r := io.TeeReader(reader, buffer)

	var hello *tls.ClientHelloInfo
	err = tls.Server(readOnlyConn{reader: r}, &tls.Config{
		GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
			hello = new(tls.ClientHelloInfo)
			*hello = *argHello
			return nil, nil
		},
	}).Handshake()

	if hello == nil {
		return "", nil, err
	}
	serverName := hello.ServerName

	return serverName, io.MultiReader(buffer, reader), nil
}

为此还实现了一个只读的net.Conn的接口

type readOnlyConn struct {
	reader io.Reader
}

func (c readOnlyConn) Read(p []byte) (int, error)         { return c.reader.Read(p) }
func (c readOnlyConn) Write(p []byte) (int, error)        { return 0, io.ErrClosedPipe }
func (c readOnlyConn) Close() error                       { return nil }
func (c readOnlyConn) LocalAddr() net.Addr                { return nil }
func (c readOnlyConn) RemoteAddr() net.Addr               { return nil }
func (c readOnlyConn) SetDeadline(t time.Time) error      { return nil }
func (c readOnlyConn) SetReadDeadline(t time.Time) error  { return nil }
func (c readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }

解析SNI信息更详细的原理请参考这篇文章 《Writing an SNI Proxy in 115 Lines of Go

最后

httpsHandlerhttpHandler放进main函数就可以正常服务了

func main() {
	go httpsHandler()
	httpHandler()
}