# A simple reverse proxy that forwards connections to specified backends
# based on the first received Host: header.
#
# by J.D. Zamfirescu
#
# Targets Python 2 (asyncore, str-based socket data).  Every print() call
# takes a single argument so output is identical under Python 2 and 3.

import select
import socket
import getopt
import sys
import re
import asyncore
import os
import errno
from collections import deque

# Routing table: list of (compiled host regex, backend address, backend port).
# Entries are tried in order; the first regex whose match() succeeds on the
# request's Host value selects the backend.  Rebuilt by parseArgs().
testers = []

# Largest request head HostNegotiator buffers while waiting for the blank line
# that ends the headers; larger requests are answered with HEADERS_TOO_LARGE.
MAX_HEADER_BYTES = 64 * 1024

HEADERS_TOO_LARGE = ("HTTP/1.0 400 BAD REQUEST\r\nConnection: close\r\n\r\n"
                     "Request headers too large.\r\n")


class AbstractHandler:
    """No-op base for objects that receive events from Socket wrappers.

    Every callback receives the Socket that raised the event, so a single
    handler can own several sockets (ProxyObject owns both the client and
    the backend side).
    """

    def doConnect(self, sock):
        pass

    def doClose(self, sock):
        pass

    def doAccept(self, sock):
        pass

    def doRead(self, sock):
        pass

    def doWrite(self, sock):
        pass

    def writable(self, sock):
        return True

    def readable(self, sock):
        return True


class Socket(asyncore.dispatcher):
    """asyncore dispatcher that forwards every event to ``self.parent``.

    parent -- AbstractHandler receiving the events; it is reassigned when a
              connection is handed from HostNegotiator to ProxyObject.
    sock   -- an existing socket to wrap, or None to create a new
              AF_INET / SOCK_STREAM socket.
    """

    def __init__(self, parent, sock=None):
        asyncore.dispatcher.__init__(self, sock)
        self.parent = parent
        if sock is None:
            self.create_socket(socket.AF_INET, socket.SOCK_STREAM)

    def handle_connect(self):
        self.parent.doConnect(self)

    def handle_close(self):
        self.parent.doClose(self)

    def handle_accept(self):
        self.parent.doAccept(self)

    def handle_read(self):
        self.parent.doRead(self)

    def handle_write(self):
        self.parent.doWrite(self)

    def writable(self):
        return self.parent.writable(self)

    def readable(self):
        return self.parent.readable(self)


class ProxyObject(AbstractHandler):
    """A class to handle connecting to real server.

    Pipes data between a client Socket and the backend selected from
    ``testers``.  ``toServer`` / ``toClient`` queue chunks still to be sent to
    each side; ``clientReading`` / ``serverReading`` record whether each side
    may still produce data.  A side is closed once its outgoing queue is empty
    and the opposite side has stopped reading.
    """

    def __init__(self, client, data, host, remote_info):
        """Take over ``client`` and start forwarding to the backend for ``host``.

        client      -- Socket of the accepted client; its parent becomes self.
        data        -- request data already read from the client; it is the
                       first thing sent to the backend.
        host        -- Host header value used to pick a backend.
        remote_info -- client address as returned by accept(), for logging.

        If no tester matches, a 404 is queued for the client and no backend
        connection is made (``self.server`` is None).
        """
        self.reqhost = host
        self.client = client
        self.client.parent = self
        self.data = data
        self.host = None
        self.remote_info = remote_info
        for t in testers:
            if t[0].match(host):
                self.host = t[1:3]  # (address, port) for connect()
                break
        self.toClient = deque()
        self.toServer = deque()
        self.serverWriting = False
        if self.host is None:
            self.server = None
            self.toClient.append("HTTP/1.0 404 NOT FOUND\r\nConnection: close\r\n\r\nUnknown host: "+host+".\r\n")
            self.clientReading = True
            self.serverReading = False
        else:
            self.server = Socket(self)
            self.server.connect(self.host)
            self.toServer.append(data)
            if showClient:
                print("CLIENT "+str(remote_info)+" OPEN to: "+host+", with:\n=====\n"+data+"\n=====\n")
            self.clientReading = True
            self.serverReading = True

    def doConnect(self, sock):
        # Informational only; nothing in this module reads serverWriting.
        self.serverWriting = True

    def doRead(self, sock):
        """Read one chunk from ``sock`` and queue it for the opposite side."""
        if sock is self.client:
            q = self.toServer

            def p(x):
                if showClient:
                    print("CLIENT "+str(self.remote_info)+" SEND to: "+str(self.host)+", with:\n=====\n"+x+"\n=====\n")
        elif sock is self.server:
            q = self.toClient

            def p(x):
                if showServer:
                    print("SERVER "+str(self.remote_info)+" SEND from: "+str(self.host)+", with:\n=====\n"+x+"\n=====\n")
        else:
            self.doClose(sock)
            return
        try:
            data = sock.recv(8192)
        except socket.error as err:
            # Any receive error ends this side.  A refused backend connection
            # also gets an error page for the client; the numeric value of
            # ECONNREFUSED is platform-specific (61 on BSD/macOS, 111 on Linux).
            code = err.args[0] if err.args else None
            if sock is self.server and code == errno.ECONNREFUSED:
                self.toClient.append("HTTP/1.0 500 ERROR\r\nConnection: close\r\n\r\nCouldn't forward your request for host: "+self.reqhost+".\r\n")
            self.doClose(sock)
            return
        if len(data) == 0:
            # EOF: asyncore's recv() has already reported it via handle_close.
            sock.close()
        else:
            p(data)
            q.append(data)

    def doClose(self, sock):
        """Handle one side going away.

        ``sock`` itself is always closed.  asyncore also reports failed sends
        through handle_close, and a dead socket left registered with data
        still queued for it would be retried by the loop indefinitely.  The
        opposite side is closed as well unless data is still queued for it;
        in that case doWrite closes it once the queue drains.
        """
        if sock is self.client:
            self.clientReading = False
            if not self.toServer and self.server is not None:
                self.server.close()
        elif sock is self.server:
            self.serverReading = False
            if not self.toClient:
                self.client.close()
        sock.close()

    def doWrite(self, sock):
        """Send the next queued chunk to ``sock``.

        ``sock`` is closed once its queue is drained and the opposite side has
        stopped reading.
        """
        if sock is self.client:
            q, peerReading = self.toClient, self.serverReading
        elif sock is self.server:
            q, peerReading = self.toServer, self.clientReading
        else:
            self.doClose(sock)
            return
        if not q:
            return
        data = q.popleft()
        count = sock.send(data)
        if count < len(data):
            # Partial (or would-block) send: put the remainder back in front.
            q.appendleft(data[count:])
        elif not q and not peerReading:
            sock.close()

    def readable(self, sock):
        if sock is self.client:
            return self.clientReading
        if sock is self.server:
            return self.serverReading
        return False

    def writable(self, sock):
        if sock is self.client:
            return len(self.toClient) > 0
        if sock is self.server:
            return len(self.toServer) > 0
        return False


def findHostHeader(data):
    """Locate the Host header in a raw HTTP request head.

    Only the header section is searched: lines after the request line, up to
    the first blank line.  Header names match case-insensitively, and the
    value has surrounding whitespace stripped.

    Returns (value, end), where ``end`` is the index in ``data`` of the CRLF
    that terminates the Host line.  Returns None if no Host header is present.
    """
    headEnd = data.find("\r\n\r\n")
    if headEnd < 0:
        headEnd = len(data)
    pos = data.find("\r\n")  # end of the request line
    while 0 <= pos < headEnd:
        lineStart = pos + 2
        lineEnd = data.find("\r\n", lineStart)
        if lineEnd < 0:
            break  # incomplete last line
        name, sep, value = data[lineStart:lineEnd].partition(":")
        if sep and name.lower() == "host":
            return value.strip(), lineEnd
        pos = lineEnd
    return None


class HostNegotiator(AbstractHandler):
    """A class to handle host negotiation.

    Buffers a client's request until the blank line ending the headers
    arrives.  The connection is then handed to a ProxyObject, with an
    X-Forwarded-For header added after the Host line.  If no Host header was
    sent, the client gets a 404.  Request heads over MAX_HEADER_BYTES get a
    400 instead.
    """

    def __init__(self, sock, remote_info):
        self.sock = Socket(self, sock)
        self.data = ""
        self.reading = True
        self.response = ""
        self.remote_info = remote_info

    def handleNoHost(self):
        self.response = "HTTP/1.0 404 NOT FOUND\r\nConnection: close\r\n\r\nYour request didn't include a Host: header.\r\n"

    def addForwardedForHeader(self, data, pos):
        """Return ``data`` with an X-Forwarded-For line inserted at ``pos``.

        ``pos`` should be the index of the CRLF ending a header line.
        """
        return data[:pos]+"\r\nX-Forwarded-For: "+self.remote_info[0]+data[pos:]

    def findHost(self):
        found = findHostHeader(self.data)
        if found is None:
            self.handleNoHost()
        else:
            value, hostEnd = found
            ProxyObject(self.sock, self.addForwardedForHeader(self.data, hostEnd), value, self.remote_info)
        self.reading = False

    def doClose(self, sock):
        sock.close()

    def doRead(self, sock):
        data = sock.recv(8192)
        if len(data) == 0:
            self.doClose(sock)
            return
        self.data += data
        if "\r\n\r\n" in self.data:
            self.findHost()
        elif len(self.data) > MAX_HEADER_BYTES:
            # Stop buffering untrusted input; doWrite sends this and closes.
            self.response = HEADERS_TOO_LARGE
            self.reading = False

    def doWrite(self, sock):
        count = sock.send(self.response)
        self.response = self.response[count:]
        if len(self.response) == 0:
            self.reading = False
            sock.close()

    def readable(self, sock):
        if sock is self.sock:
            return self.reading
        return False

    def writable(self, sock):
        return len(self.response) > 0


class Acceptor(AbstractHandler):
    """A class to accept incoming requests.

    port    -- TCP port to listen on (all IPv4 interfaces).
    backlog -- listen() queue length; was hard-coded to 1, which drops
               connections under any concurrent load.
    """

    def __init__(self, port, backlog=socket.SOMAXCONN):
        self.sock = Socket(self)
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.sock.bind(("0.0.0.0", port))
        self.sock.listen(backlog)

    def doAccept(self, sock):
        if watcher is not None:
            try:
                watcher.check()
            except (EnvironmentError, re.error, ValueError, IndexError) as err:
                # parseArgs only replaces the table on success, so the old
                # routing stays in effect; still accept this connection.
                print("config reload failed, keeping previous config: " + str(err))
        pair = sock.accept()
        if pair is None:
            # asyncore returns None when the connection vanished before accept.
            return
        conn, addr = pair
        HostNegotiator(conn, addr)


def parseArgs(args):
    """Rebuild the global ``testers`` from alternating regex / "address:port" tokens.

    The new table replaces ``testers`` only after every pair has parsed, so
    a bad pattern (re.error) or port (ValueError / IndexError) leaves the
    previous table in place.  An empty or odd-length ``args`` prints usage
    and exits the process.  This also applies to config-file reloads.
    """
    global testers
    if len(args) == 0 or len(args) % 2 != 0:
        usage()
        sys.exit(1)
    newTesters = []
    for i in range(0, len(args), 2):
        pattern, target = args[i], args[i + 1]
        parts = target.split(":")
        newTesters.append((re.compile(pattern), parts[0], int(parts[1])))
    testers = newTesters


def usage():
    print("usage: " + sys.argv[0] + " [-p port] [-f configFile] [pair1 pair2 ...]")
    print(" pairN: ")
    print(" hostRegex - a regex to match the Host: header's value")
    print(" address:port - the server to forward to")


class FileWatcher:
    """Call ``cb`` with a file's contents now, and again whenever its mtime advances."""

    def __init__(self, path, cb):
        self.path = path
        self.lastupdate = os.path.getmtime(self.path)
        self.cb = cb
        self.cb(self._read())

    def _read(self):
        # Closed deterministically (the old file(...).read() leaked the handle).
        with open(self.path) as f:
            return f.read()

    def check(self):
        """Re-run the callback if the file changed since the last successful load."""
        # Read mtime once: a change between the load and a second stat would
        # otherwise be recorded as already loaded.
        mtime = os.path.getmtime(self.path)
        if mtime > self.lastupdate:
            self.cb(self._read())
            self.lastupdate = mtime


watcher = None      # FileWatcher for -f/--config; checked on every accept
showClient = False  # -c/--showClient: log data received from clients
showServer = False  # -s/--showServer: log data received from backends


def main():
    global showClient, showServer, watcher
    try:
        opts, args = getopt.getopt(sys.argv[1:], "hp:f:cs", ["help", "port=", "config=", "showClient", "showServer"])
    except getopt.GetoptError as err:
        print(str(err))
        usage()
        sys.exit(1)
    port = 80
    configFile = None
    for o, a in opts:
        if o in ("-p", "--port"):
            port = int(a)
        elif o in ("-h", "--help"):
            usage()
            sys.exit()
        elif o in ("-f", "--config"):
            configFile = a
        elif o in ("-c", "--showClient"):
            showClient = True
        elif o in ("-s", "--showServer"):
            showServer = True
        else:
            usage()
            sys.exit()
    if configFile is not None:
        watcher = FileWatcher(configFile, lambda d: parseArgs(d.split()))
    else:
        parseArgs(args)
    server = Acceptor(port)
    print("listening on " + str(port))
    try:
        asyncore.loop()
    except KeyboardInterrupt:
        print("finishing")
        server.sock.close()


if __name__ == "__main__":
    main()