go语言实现web反向代理(二)
在上一篇文章《go语言实现web反向代理(一)》中实现了基本的功能,但是有一个问题:在broker
之后的web服务器看到的客户端地址是broker
的地址,而不是真正的客户端。因此作为改进需要实现将真正的客户端地址传递给web服务器的功能。
实现方式比较简单,就是修改客户端传到broker
的请求头,在里面添加上客户端真实地址信息再传递给web服务器。
其中X-Real-IP
就是客户端的真实地址。
func handleHTTPClient(clientConn net.Conn) {
if clientConn == nil {
log.Printf("nil client")
return
}
defer clientConn.Close()
clientAddr := clientConn.RemoteAddr().String()
log.Printf("accept new connection from %v", clientAddr)
// 存储解析出的虚拟主机名
var host string
// 要解析http头,就要读出数据,而这部分数据还必需发送到虚拟主机
// 因此还需要将它们缓存起来
const maxBufferSize = 4096
buffer := bytes.NewBuffer(make([]byte, 0, maxBufferSize))
reader := bufio.NewReader(clientConn)
for {
if buffer.Len() > maxBufferSize {
log.Printf("err: invalid http header, the header is too large")
return
}
data, _, err := reader.ReadLine()
if err != nil {
if err == io.EOF {
break
}
}
line := string(data)
buffer.WriteString(string(line) + "\r\n")
if len(line) < 6 {
continue
}
prefix := strings.ToLower(line[:5])
if prefix != "host:" {
continue
}
fields := strings.Split(line, ":")
if len(fields) != 2 && len(fields) != 3 {
log.Printf("err: invalid host field in http header: %v", line)
return
}
host = fields[1]
// 顺便插入客户端的真实ip:port信息到http头里
clientIP, _, _ := net.SplitHostPort(clientAddr)
buffer.WriteString("X-Real-IP: " + clientIP + "\r\n")
break
}
if len(host) == 0 {
log.Printf("err: empty hostname")
return
}
// 找出虚拟主机地址
hostAddr, err := getVirtualHTTPHostAddr(host)
if err != nil {
log.Printf("err: unsupport host %v, err: %v", host, err)
return
}
// 连接虚拟主机
hostConn, err := net.Dial("tcp", hostAddr)
if err != nil {
log.Printf("err: connect to virtual host %v[%v] failed: %v", host, hostAddr, err)
return
}
defer hostConn.Close()
localIP, _, _ := net.SplitHostPort(hostConn.LocalAddr().String())
buffer.WriteString("X-Forwarded-For: " + localIP + "\r\n")
// client的链接可能单方面关闭
// 为了避免另一方向的数据拷贝中断
// 将他们放进两个goroutine并等待
var wg sync.WaitGroup
wg.Add(2)
// 再将余下的数据传递虚拟主机
go func() {
defer wg.Done()
defer hostConn.(*net.TCPConn).CloseWrite()
defer func() {
if c, ok := clientConn.(*net.TCPConn); ok {
c.CloseRead()
}
}()
_, err = io.Copy(hostConn, io.MultiReader(buffer, reader))
log.Printf("copy from client %v to host %v. local %v err %v\n", clientAddr, hostAddr, hostConn.LocalAddr().String(), err)
}()
// 将虚拟主机返回的数据传递给前端
go func() {
defer wg.Done()
defer hostConn.(*net.TCPConn).CloseRead()
defer func() {
switch c := clientConn.(type) {
case *net.TCPConn:
c.CloseWrite()
case *tls.Conn:
c.CloseWrite()
default:
}
}()
_, err = io.Copy(clientConn, hostConn)
log.Printf("copy from host %v to client %v. local %v err %v\n", hostAddr, clientAddr, hostConn.LocalAddr().String(), err)
}()
wg.Wait()
}
至于这段代码没有沿用之前的scanner
的方式的原因,是因为执行了scanner.Scan()
之后会根据缓冲区大小读取并缓存clientConn
里的数据。因此,如果在for
循环解析出了Host
字段之后再将余下的数据传到web服务器,就会发现中间丢失了一段数据。
func handleHTTPClient(clientConn net.Conn) {
// ...
const maxBufferSize = 4096
buffer := bytes.NewBuffer(make([]byte, 0, maxBufferSize))
reader := io.NewReader(clientConn)
scanner := bufio.NewScanner(reader)
for scanner.Scan() {
if buffer.Len() > maxBufferSize {
log.Printf("err: invalid http header, the header is too large")
return
}
line := scanner.Text()
// ...
}
// ...
}
http部分的比较好解决,接下来就是https的部分,https的数据是加密的,因此需要解密,因此就需要配置后端所有虚拟主机的证书:
// virtualHTTPSHosts 是一个类似 map[string]VirtualHostConfig的字典
// VirtualHostConfig 是一个类似如下定义的结构体
//type VirtualHostConfig struct {
// Domain string
// Host string
// Cert string
// Key string
//}
func httpsServer() {
// 加载所有的证书
tlsCfg := &tls.Config{}
for _, cfg := range virtualHTTPSHosts {
cert, err := tls.LoadX509KeyPair(cfg.Cert, cfg.Key)
if err != nil {
log.Fatal(err)
}
tlsCfg.Certificates = append(tlsCfg.Certificates, cert)
}
tlsCfg.BuildNameToCertificate()
tlsCfg.Time = time.Now
tlsCfg.Rand = rand.Reader
// 监听443端口
listener, err := net.Listen("tcp", ":443")
if err != nil {
log.Fatalf("listen failed: %v", err)
}
for {
connp, err := listener.Accept()
if err != nil {
log.Printf("accept new client failed: %v", err)
continue
}
// 开始服务每一个https连接
go func() {
client := tls.Server(connp, tlsCfg)
handleHTTPClient(client)
}()
}
}
这里broker
相当于解密了https的所有数据,其后的处理就完全可以复用处理http请求的逻辑。
最后,有时候需要强制这个网站按https访问,因此我们可以将所有http的请求全部重定向到https。
func main() {
var forceHTTPS bool
flag.BoolVar(&forceHTTPS, "forcehttps", true, "use https only")
flag.Parse()
if forceHTTPS {
go httpForceHTTPS()
} else {
go httpServer()
}
httpsServer()
}
func httpForceHTTPS() {
var err error
defer log.Printf("http server exited. err: ", err)
m := http.NewServeMux()
m.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
url := "https://" + r.Host + r.URL.Path
if len(r.URL.RawQuery) > 0 {
url += "?" + r.URL.RawQuery
}
http.Redirect(w, r, url, http.StatusSeeOther)
})
err = http.ListenAndServe(":80", m)
}
func httpServer() {
listener, err := net.Listen("tcp", ":80")
if err != nil {
log.Fatalf("listen failed: %v", err)
}
for {
client, err := listener.Accept()
if err != nil {
log.Printf("accept new client failed: %v", err)
continue
}
go handleHTTPClient(client)
}
}