go语言实现web反向代理(三)
在上一篇文章《go语言实现web反向代理(二)》完成了通过broker
向web服务器传递客户端真实IP的功能,但这个功能其实只实现了一半。
因为一条tcp连接上不一定只传递一次http请求,如果浏览器对连接复用,传递多次请求,那么按上文所述的逻辑,就会出现在同一个tcp连接上,只有第一次的http请求里带有客户端真实IP,后续所有请求都没有这个信息的情况。
因此本文着手解决这个问题。
要解决这个问题首先要解决如何在一个tcp连接中将一个或多个http请求分隔出来? 本文不再用手动硬解的方式,而是利用go语言自带的现成的库来解决。
以下示例代码为清晰表达,省去所有枝叶代码。
func handleHTTPClient(clientConn net.Conn) {
clientAddr := clientConn.RemoteAddr().String()
reader := bufio.NewReader(clientConn)
for {
// 从tcp流中解析出一个http.Request
req, err := http.ReadRequest(reader)
// 解析失败则退出
if err != nil {
if err != io.EOF {
log.Printf("read request, err: %v", err)
return
}
break
}
// 获得http头中的host信息
// 注意Host信息必需通过这种方式获取
// req.Header.Get("Host")是获取不到的
// host := req.Host
// 在请求头中插入客户端的真实ip信息到http头里
clientIP, _, _ := net.SplitHostPort(clientAddr)
req.Header.Add("X-Real-IP", clientIP)
// 添加其它信息
// ...
// 将新的http.Request写入buffer
buffer := bytes.NewBuffer(make([]byte, 0, 4096))
req.Write(buffer)
// 创建一个broker到web服务器的tcp连接
hostConn, _ := net.Dial("tcp", "web服务器的地址")
// 将新的http.Request数据传递给后端的web服务器
io.Copy(hostConn, buffer)
}
}
更详细的逻辑如下
func handleHTTPClient(clientConn net.Conn) {
if clientConn == nil {
log.Printf("nil client")
return
}
var err error
defer clientConn.Close()
clientAddr := clientConn.RemoteAddr().String()
log.Printf("accept new connection from %v", clientAddr)
// client的链接可能单方面关闭
// 为了避免另一方向的数据拷贝中断
// 将他们放进两个goroutine并等待
var wg sync.WaitGroup
wg.Add(2)
var chanConn = make(chan net.Conn, 1)
// 将虚拟主机返回的数据传递给前端
go func() {
defer wg.Done()
hostConn := <-chanConn
if hostConn == nil {
return
}
defer hostConn.(*net.TCPConn).CloseRead()
defer func() {
switch c := clientConn.(type) {
case *net.TCPConn:
c.CloseWrite()
case *tls.Conn:
c.CloseWrite()
default:
}
}()
_, err = io.Copy(clientConn, hostConn)
hostAddr := hostConn.LocalAddr().String()
log.Printf("copy from host %v to client %v. local %v err %v\n", hostAddr, clientAddr, hostConn.LocalAddr().String(), err)
}()
// 再将余下的数据传递虚拟主机
go func() {
defer wg.Done()
// 通知从web服务器拷贝数据到客户端的协程退出
defer close(chanConn)
var hostConn net.Conn
defer func() {
if hostConn != nil {
hostConn.(*net.TCPConn).CloseWrite()
} else {
log.Printf("host conn is nil")
}
}()
defer func() {
if c, ok := clientConn.(*net.TCPConn); ok {
c.CloseRead()
}
}()
// 存储解析出的虚拟主机名
var host string
reader := bufio.NewReader(clientConn)
for {
req, err := http.ReadRequest(reader)
if err != nil {
if err != io.EOF {
log.Printf("read request, err: %v", err)
return
}
break
}
const maxBufferSize = 4096
buffer := bytes.NewBuffer(make([]byte, 0, maxBufferSize))
if len(host) == 0 {
host = req.Host
log.Printf("host: %v", host)
// 找出虚拟主机地址
var hostAddr string
if _, ok := clientConn.(*net.TCPConn); ok {
hostAddr, err = getVirtualHTTPHostAddr(host)
} else {
hostAddr, err = getVirtualHTTPSHostAddr(host)
}
if err != nil {
log.Printf("err: unsupport host %v, err: %v", host, err)
return
}
// 连接虚拟主机
hostConn, err = net.Dial("tcp", hostAddr)
if err != nil {
log.Printf("err: connect to virtual host %v[%v] failed: %v", host, hostAddr, err)
return
}
chanConn <- hostConn
}
// 插入客户端的真实ip信息到http头里
clientIP, _, _ := net.SplitHostPort(clientAddr)
req.Header.Add("X-Real-IP", clientIP)
// 插入broker信息
localIP, _, _ := net.SplitHostPort(hostConn.LocalAddr().String())
prevXForwardFor := req.Header.Get("X-Forwarded-For")
if len(prevXForwardFor) != 0 {
prevXForwardFor = "," + prevXForwardFor
}
req.Header.Add("X-Forwarded-For", localIP+prevXForwardFor)
req.Write(buffer)
_, err = io.Copy(hostConn, buffer)
log.Printf("copy from client %v to host %v. local %v err %v\n", clientAddr, "", hostConn.LocalAddr().String(), err)
}
}()
wg.Wait()
}
除此之外如果想在这个反向代理上跑websocket
协议会发现还是不行,这这是因为websocket
第一个http.Request
之后基本就是二进制数据了,后续不太会有正常的http.Request
头了,所以如果websocket
请求走到这个broker
上,除第一个请求头外,第二次尝试解析请求头的时候就会失败退出,因此针对这种情况可以简单实现如下:
for {
// ...
req.Header.Add("X-Forwarded-For", localIP+prevXForwardFor)
req.Write(buffer)
_, err = io.Copy(hostConn, buffer)
log.Printf("copy from client %v to server %v. local %v err %v\n", clientAddr, hostConn.RemoteAddr().String(), hostConn.LocalAddr().String(), err)
// 如果是websocket就不再进行后续解析,直接拷贝数据
if req.Header.Get("Upgrade") == "websocket" {
log.Printf("begin to copy websocket data from client %v to server %v.", clientAddr, hostConn.RemoteAddr().String())
_, err = io.Copy(hostConn, reader)
log.Printf("copy websocket from client %v to server %v. local %v err %v\n", clientAddr, hostConn.RemoteAddr().String(), hostConn.LocalAddr().String(), err)
break
}
}
完整源代码请参考https://github.com/acevest/webbroker