go语言实现web反向代理(一)
背景
现在我手里有很多个域名,我想把这些域名各自映射到不同的web
服务程序,但是只有一个云主机,也就是说只有一个80
端口,不同的web
服务程序又不能共享80
端口,将各个web
服务程序放在一起又太乱。因此决定写一个反向代理broker
,将80
端口收到的所有http
请求转发到对应的虚拟web
服务程序。以实现一个主机支撑多个域名访问不同的web
服务程序。
对于https
的443
端口也同理。
HTTP 的部分
首先监听80端口,对来自客户端的每个连接请求开一个协程进行处理。
func httpHandler() {
listener, err := net.Listen("tcp", ":80")
if err != nil {
log.Fatalf("listen failed: %v", err)
}
for {
client, err := listener.Accept()
if err != nil {
log.Printf("accept new client failed: %v", err)
continue
}
go handleClient(client)
}
}
对于每个连接请求的第一次HTTP
请求的时候,我们需要分析HTTP
头里的Host
字段,通过这个字段我们才能知道客户端具体需要访问的是哪个虚拟主机。
在解析的时候需要对缓冲区进行限制,以避免有人恶意提交非法数据
func handleHTTPClient(clientConn net.Conn) {
if clientConn == nil {
log.Printf("nil client")
return
}
defer clientConn.Close()
clientAddr := clientConn.RemoteAddr().String()
log.Printf("accept new connection from %v", clientAddr)
// 存储解析出的虚拟主机名
var host string
// 要解析http头,就要读出数据,而这部分数据还必需发送到虚拟主机
// 因此还需要将它们缓存起来
const maxBufferSize = 4096
buffer := bytes.NewBuffer(make([]byte, 0, maxBufferSize))
// 开始解析Host
scanner := bufio.NewScanner(clientConn)
// 限制扫描的时候分配的缓冲区大小
// 如果不限制,则会使用scanner默认的缓冲区大小:MaxScanTokenSize = 64 * 1024
scannerBuffer := make([]byte, 256)
scanner.Buffer(scannerBuffer, len(scannerBuffer))
for scanner.Scan() {
if buffer.Len() > maxBufferSize {
log.Printf("err: invalid http header, the header is too large")
return
}
line := scanner.Text()
buffer.WriteString(line + "\r\n")
if len(line) < 6 {
continue
}
prefix := strings.ToLower(line[:5])
if prefix != "host:" {
continue
}
fields := strings.Split(line, ":")
if len(fields) != 2 && len(fields) != 3 {
log.Printf("err: invalid host field in http header: %v", line)
return
}
host = fields[1]
break
}
if len(host) == 0 {
log.Printf("err: empty hostname")
return
}
// ....
}
在获得请求头里的host之后,我们需要查询配置,它对应的后端服务器的地址应该是多少 以下只是一个简要的示例代码,实际中可以通过读配置文件来解决
var virtualHTTPHostAddr = map[string]string{}
func init() {
virtualHTTPHostAddr["abc.com"] = "1.2.3.4:7777"
virtualHTTPHostAddr["xyz.com"] = "1.2.3.4:8888"
}
func getVirtualHTTPHostAddr(host string) (string, error) {
host = strings.TrimSpace(host)
v, ok := virtualHTTPHostAddr[host]
if !ok {
return "", fmt.Errorf("can not find %v", host)
}
return v, nil
}
有了域名到虚拟主机的映射配置后,我们可以继续编写handleHTTPClient
的逻辑,接下来需要完成
建立一个到虚拟主机的连接hostConn
,然后在hostConn
和broker
已经接受的客户端连接clientConn
之间转发数据
func handleHTTPClient(clientConn net.Conn) {
// 省略部分代码见上文
// ...
// 找出虚拟主机地址
hostAddr, err := getVirtualHTTPHostAddr(host)
if err != nil {
log.Printf("err: unsupport host %v, err: %v", host, err)
return
}
// 连接虚拟主机
hostConn, err := net.Dial("tcp", hostAddr)
if err != nil {
log.Printf("err: connect to virtual host %v[%v] failed: %v", host, hostAddr, err)
return
}
defer hostConn.Close()
// client的链接可能单方面关闭
// 为了避免另一方向的数据拷贝中断
// 将他们放进两个goroutine并等待
var wg sync.WaitGroup
wg.Add(2)
// 再将余下的数据传递虚拟主机
go func() {
defer wg.Done()
defer hostConn.(*net.TCPConn).CloseWrite()
defer clientConn.(*net.TCPConn).CloseRead()
_, err = io.Copy(hostConn, io.MultiReader(buffer, clientConn))
log.Printf("copy from client %v to host %v. err %v", clientAddr, hostAddr, err)
}()
// 将虚拟主机返回的数据传递给前端
go func() {
defer wg.Done()
defer hostConn.(*net.TCPConn).CloseRead()
defer clientConn.(*net.TCPConn).CloseWrite()
_, err = io.Copy(clientConn, hostConn)
log.Printf("copy from host %v to client %v. err %v", hostAddr, clientAddr, err)
}()
wg.Wait()
}
HTTPS部分
现在苹果已经限制纯http
访问了,因此还有必要支持一下https
。
但是https
的数据是加密的,不像http
那样将虚拟主机记在明文的http
头的Host
字段里
如果https
协议需要解密才能得到虚拟主机的话,那就是一个死锁的问题:服务器需要解密才能拿到虚拟主机信息,客户端需要服务器返回才能解密数据,为此TLS
协议(ver:1.2
)在客户端发往服务器的第一个握手包里带了SNI
(Server Name Indication
)信息。
所以https
与http
最大的不同就是需要解析SNI
,其它逻辑大致相同
同样第一步先监听443
端口,接受并处理每一新来自客户端的连接
func httpsHandler() {
listener, err := net.Listen("tcp", ":443")
if err != nil {
log.Fatalf("listen failed: %v", err)
}
for {
client, err := listener.Accept()
if err != nil {
log.Printf("accept new client failed: %v", err)
continue
}
go handleHTTPSClient(client)
}
}
handleHTTPSClient
与handleHTTPClient
别无二致,只是getSNI
部分不同。
func handleHTTPSClient(clientConn net.Conn) {
if clientConn == nil {
log.Printf("nil client")
return
}
defer clientConn.Close()
clientAddr := clientConn.RemoteAddr().String()
log.Printf("accept new https connection from %v", clientAddr)
host, reader, err := getSNI(clientConn)
if err != nil {
log.Printf("err: parse SNI failed: %v", err)
return
}
// 连接虚拟主机
hostAddr, err := getVirtualHTTPSHostAddr(host)
if err != nil {
log.Printf("err: unsupport host %v, err: %v", host, err)
return
}
hostConn, err := net.Dial("tcp", hostAddr)
if err != nil {
log.Printf("err: connect to virtual host %v[%v] failed: %v", host, hostAddr, err)
return
}
defer hostConn.Close()
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
defer hostConn.(*net.TCPConn).CloseWrite()
defer clientConn.(*net.TCPConn).CloseRead()
_, err = io.Copy(hostConn, reader)
log.Printf("copy from client %v to host %v. err %v", clientAddr, hostAddr, err)
}()
go func() {
defer wg.Done()
defer hostConn.(*net.TCPConn).CloseRead()
defer clientConn.(*net.TCPConn).CloseWrite()
_, err = io.Copy(clientConn, hostConn)
log.Printf("copy from host %v to client %v. err %v", hostAddr, clientAddr, err)
}()
wg.Wait()
}
getVirtualHTTPSHostAddr
的逻辑与getVirtualHTTPHostAddr
几乎一致,略。
getSNI
就是获取第一个握手包里的Server Name Indication
func getSNI(reader io.Reader) (string, io.Reader, error) {
var err error
buffer := new(bytes.Buffer)
r := io.TeeReader(reader, buffer)
var hello *tls.ClientHelloInfo
err = tls.Server(readOnlyConn{reader: r}, &tls.Config{
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
hello = new(tls.ClientHelloInfo)
*hello = *argHello
return nil, nil
},
}).Handshake()
if hello == nil {
return "", nil, err
}
serverName := hello.ServerName
return serverName, io.MultiReader(buffer, reader), nil
}
为此还实现了一个只读的net.Conn
的接口
type readOnlyConn struct {
reader io.Reader
}
func (c readOnlyConn) Read(p []byte) (int, error) { return c.reader.Read(p) }
func (c readOnlyConn) Write(p []byte) (int, error) { return 0, io.ErrClosedPipe }
func (c readOnlyConn) Close() error { return nil }
func (c readOnlyConn) LocalAddr() net.Addr { return nil }
func (c readOnlyConn) RemoteAddr() net.Addr { return nil }
func (c readOnlyConn) SetDeadline(t time.Time) error { return nil }
func (c readOnlyConn) SetReadDeadline(t time.Time) error { return nil }
func (c readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }
解析SNI
信息更详细的原理请参考这篇文章 《Writing an SNI Proxy in 115 Lines of Go》
最后
将httpsHandler
和httpHandler
放进main
函数就可以正常服务了
func main() {
go httpsHandler()
httpHandler()
}