/*
 * The contents of this file are subject to the AOLserver Public License
 * Version 1.1 (the "License"); you may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 * http://aolserver.com/.
 *
 * Software distributed under the License is distributed on an "AS IS"
 * basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See
 * the License for the specific language governing rights and limitations
 * under the License.
 *
 * The Original Code is AOLserver Code and related documentation
 * distributed by AOL.
 * 
 * The Initial Developer of the Original Code is America Online,
 * Inc. Portions created by AOL are Copyright (C) 1999 America Online,
 * Inc. All Rights Reserved.
 *
 * Alternatively, the contents of this file may be used under the terms
 * of the GNU General Public License (the "GPL"), in which case the
 * provisions of GPL are applicable instead of those above.  If you wish
 * to allow use of your version of this file only under the terms of the
 * GPL and not to allow others to use your version of this file under the
 * License, indicate your decision by deleting the provisions above and
 * replace them with the notice and other provisions required by the GPL.
 * If you do not delete the provisions above, a recipient may use your
 * version of this file under either the License or the GPL.
 */

/* 
 * nsvhr.c --
 *
 *	Virtual Hosting Redirector.
 *
 */

/*
 * Version identification string built from the CVS header plus the
 * compile date and time.  NOTE(review): this static is not referenced
 * anywhere in this file, so compilers may warn that it is unused and
 * optimizing builds may drop it from the object file.
 */
static const char *RCSID = "@(#) $Header: /cvsroot/aolserver/nsvhr/nsvhr.c,v 1.2 2002/11/23 06:59:53 dossy Exp $, compiled: " __DATE__ " " __TIME__;

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/uio.h>
#include <sys/un.h>
#include <assert.h>
#include <errno.h>
#include <limits.h>
#include <poll.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include "ns.h"
 
#define BUFSIZE 	512
#define MAPS 		"maps"
#define CONFIG_METHOD 	"Method"
#define CONFIG_ERRORURL "ErrorUrl"
#define CONFIG_BUSYURL  "BusyUrl"
#define CONFIG_TIMEOUT	"Timeout"
#define DEFAULT_TIMEOUT 30
#define HTTP_CONF	"http"
#define TCP_CONF	"tcp"
#define UNIX_CONF	"unix"
#define HTTP_EOL	"\r\n"

#define MODULES		"modules"
#define NSSOCK_DRIVER_NAME 	"nssock"
#define NSUNIX_DRIVER_NAME 	"nsunix"

#ifdef HAVE_CMMSG
/*
 * my_cmsghdr --
 *
 *	A control-message header immediately followed by room for one
 *	file descriptor: the ancillary data for passing a socket with
 *	sendmsg() and SCM_RIGHTS (the modern replacement for the
 *	msg_accrights field).
 *
 *	NOTE(review): this layout assumes that `sock` lands exactly at
 *	CMSG_DATA(&cmsg), i.e. that CMSG_ALIGN(sizeof(struct cmsghdr)) is
 *	sizeof(struct cmsghdr) on the target ABI.  Trailing padding can
 *	make sizeof(my_cmsghdr) larger than CMSG_LEN(sizeof(int)) (e.g. on
 *	LP64), so cmsg_len should be set with CMSG_LEN(), not sizeof().
 *	The guard macro is spelled HAVE_CMMSG (double M) -- confirm the
 *	build defines exactly this name.
 */

typedef struct {
    struct cmsghdr cmsg;
    int            sock;
} my_cmsghdr;
#endif

/*
 * Transport used to reach the backend server of a virtual host.
 */

typedef enum {
    TCP = 0,    /* TCP/IP socket ("http:" and "tcp:" map entries) */
    UDS = 1     /* Unix domain socket ("unix:" map entries) */
} Protocol;

/*
 * Where requests for one virtual host are forwarded.  The `protocol`
 * field selects which member of the union `u` is meaningful.
 */

typedef struct {
    Protocol protocol;              /* selects the active union member */
    union {
        struct {
            char *hostname;         /* backend host name */
            int   port;             /* backend TCP port */
        } tcp;                      /* valid when protocol == TCP */
        struct {
            char *filename;         /* path of the unix domain socket */
        } uds;                      /* valid when protocol == UDS */
    } u;
} Location;


/*
 * Forward declarations of the module's internal routines:
 *   LocationSplit  - parse a map URL into a Location
 *   VHRProc        - request handler registered for proxied methods
 *   TCPProxy       - relay a request to a TCP backend
 *   SockWrite      - write a NUL-terminated string to a socket
 *   TimedSockDump  - copy a backend socket's output to the connection
 *   UDSProxy       - hand the connection off over a unix domain socket
 */
static int LocationSplit(char *url, Location *loc);
static int VHRProc(void *context, Ns_Conn *conn);
static int TCPProxy(Ns_Conn  *conn, Location *loc);
static void SockWrite(int sock, char *string);
static int TimedSockDump(int sock, Ns_Conn *conn, int timeout);
static int UDSProxy(Ns_Conn *conn, Location *loc);

/*
 * Module state, filled in by Ns_ModuleInit and read by the handlers.
 */

static Tcl_HashTable  map;          /* Host header value -> Location * */
static char          *errorUrl;     /* redirect on unknown host / proxy failure; NULL if unset */
static char          *busyUrl;      /* redirect on backend timeout; NULL if unset */
static int            gTimeout;     /* backend timeout in seconds; 0 waits forever */

int Ns_ModuleVersion = 1;


/*
 *----------------------------------------------------------------------
 *
 * Ns_ModuleInit --
 *
 *	Nsvhr module init routine.
 *
 * Results:
 *	NS_OK if initialized ok, NS_ERROR otherwise.
 *
 * Side effects:
 *	Registers a callback for everything. Muhaha.
 *
 *----------------------------------------------------------------------
 */

NS_EXPORT int
Ns_ModuleInit(char *server, char *module)
{
    char          *path, *maps;
    Ns_Set        *set;
    int            i;
    char          *host, *url;

    path = Ns_ConfigGetPath(server, module, NULL);
    maps = Ns_ConfigGetPath(server, module, MAPS, NULL);
    if (path == NULL) {
        Ns_Log(Warning, "nsvhr: no config path [ns/server/%s/module/%s]",
            server, module);
        return NS_OK;
    }
    set = Ns_ConfigGetSection(maps);
    if (set == NULL) {
        Ns_Log(Warning, "nsvhr: no config path [ns/server/%s/module/%s/%s]",
            server, module, MAPS);
        return NS_OK;
    }
    Tcl_InitHashTable(&map, TCL_STRING_KEYS);

    for (i=0; i < Ns_SetSize(set); i++) {
        Tcl_HashEntry *hePtr;
        int            new;

        host = Ns_SetKey(set, i);
        url = Ns_SetValue(set, i);
        hePtr = Tcl_CreateHashEntry(&map, host, &new);
        if (new && hePtr != NULL) {
            Location *loc = ns_malloc(sizeof(Location));

            bzero((void *) loc, sizeof(Location));
            if (LocationSplit(url, loc) != NS_OK) {
                return NS_ERROR;
            }

            switch (loc->protocol) {
            case TCP:
                Ns_Log(Notice, "nsvhr: redirecting: host: %s -> tcp://%s:%d",
                    host, loc->u.tcp.hostname, loc->u.tcp.port);
                break;
            case UDS:
                Ns_Log(Notice, "nsvhr: redirecting: host: %s -> unix:%s",
                    host, loc->u.uds.filename);
                break;
            }

            Tcl_SetHashValue(hePtr, loc);
        }
    }

    set = Ns_ConfigGetSection(path);
    if (set == NULL) {
        Ns_RegisterRequest(server, "HEAD", "/*", VHRProc, NULL, NULL,
            NS_OP_NODELETE);
        Ns_RegisterRequest(server, "GET", "/*", VHRProc, NULL, NULL,
            NS_OP_NODELETE);
        Ns_RegisterRequest(server, "POST", "/*", VHRProc, NULL, NULL,
            NS_OP_NODELETE);
    }
    for (i=0; i < Ns_SetSize(set); i++) {
        char *key, *value;

        key = Ns_SetKey(set, i);
        value = Ns_SetValue(set, i);
        if (!strcasecmp(key, CONFIG_METHOD)) {
            Ns_RegisterRequest(server, value, "/*", VHRProc, NULL, NULL,
                NS_OP_NODELETE);
        }
    }
    errorUrl = Ns_ConfigGetValue(path, CONFIG_ERRORURL);
    busyUrl = Ns_ConfigGetValue(path, CONFIG_BUSYURL);
    if (Ns_ConfigGetInt(path, CONFIG_TIMEOUT, &gTimeout) != NS_TRUE) {
        gTimeout = DEFAULT_TIMEOUT;
    }

    return NS_OK;
}


/*
 *----------------------------------------------------------------------
 *
 * LocationSplit --
 *
 *	Parse a map value such as "http://llamas.office.mit.edu:8000",
 *	"tcp://host:port" or "unix://name" into a Location.  "http" and
 *	"tcp" URLs become TCP locations (the port defaults to 80);
 *	"unix" URLs become UDS locations whose socket file is
 *	"modules/nsunix/<name>".
 *
 * Results:
 *	NS_OK if the URL was parsed; NS_ERROR (after logging) if it
 *	cannot be parsed, has an empty host, names an unsupported
 *	protocol, or carries a port outside 1..65535.
 *
 * Side effects:
 *	On success stores an ns_strdup'ed host name or socket file name
 *	in *loc (the caller owns it).  *loc is not modified on failure.
 *
 *----------------------------------------------------------------------
 */

static int
LocationSplit(char *url, Location *loc)
{
    char       *temp;
    char       *protocol;
    char       *hostname;
    char       *port;
    char       *path;
    char       *tail;
    char       *end;
    long        portnum;
    Ns_DString  ds;
    int         status = NS_ERROR;

    assert(url != NULL);
    assert(loc != NULL);

    Ns_DStringInit(&ds);

    /*
     * Ns_ParseUrl splits its argument in place, so work on a copy.
     */
    temp = ns_strdup(url);
    if (Ns_ParseUrl(temp, &protocol, &hostname, &port, &path, &tail) !=
            NS_OK || protocol == NULL || hostname == NULL
            || *hostname == '\0') {
        Ns_Log(Error, "nsvhr: cannot parse '%s'", url);
        goto done;
    }

    if (STREQ(protocol, HTTP_CONF) || STREQ(protocol, TCP_CONF)) {
        if (port == NULL) {
            portnum = 80;
        } else {
            portnum = strtol(port, &end, 10);
            if (end == port || *end != '\0' || portnum < 1
                    || portnum > 65535) {
                Ns_Log(Error, "nsvhr: invalid port '%s' in '%s'", port, url);
                goto done;
            }
        }
        loc->protocol = TCP;
        loc->u.tcp.hostname = ns_strdup(hostname);
        loc->u.tcp.port = (int) portnum;
    } else if (STREQ(protocol, UNIX_CONF)) {
        /*
         * For unix URLs the "host" part names the socket file below
         * the nsunix module directory.
         */
        Ns_DStringVarAppend(&ds, MODULES, "/", NSUNIX_DRIVER_NAME,
                "/", hostname, NULL);
        loc->protocol = UDS;
        loc->u.uds.filename = ns_strdup(Ns_DStringValue(&ds));
    } else {
        /*
         * Previously this only logged and fell through as TCP; an
         * unknown scheme is a configuration error.
         */
        Ns_Log(Error, "nsvhr: unsupported protocol '%s' in '%s'",
            protocol, url);
        goto done;
    }

    status = NS_OK;

done:
    Ns_DStringFree(&ds);
    ns_free(temp);

    return status;
}


/*
 *----------------------------------------------------------------------
 *
 * VHRProc --
 *
 *	The callback for all requests.
 *
 * Results:
 *	NS_OK if ok, NS_ERROR otherwise.
 *
 * Side effects:
 *	Something will be returned and the connection will be closed.
 *
 *----------------------------------------------------------------------
 */

static int
VHRProc(void *context, Ns_Conn *conn)
{
    Tcl_HashEntry *hePtr = NULL;
    char          *host;
    Ns_Set        *headers;
    Location      *locPtr;

    assert(conn != NULL && context == NULL);

    headers = Ns_ConnHeaders(conn);
    host = Ns_SetIGet(headers, "Host");

    if (host != NULL) {
        Ns_StrToLower(host);
        hePtr = Tcl_FindHashEntry(&map, host);
    }

    if (hePtr == NULL) {
        if (errorUrl != NULL) {
            Ns_ConnReturnRedirect(conn, errorUrl);
        } else {
            Ns_ConnReturnNotFound(conn);
        }
        return NS_OK;
    }

    /*
     * OK, we have a host to go to!
     */

    locPtr = Tcl_GetHashValue(hePtr);
    assert(locPtr != NULL);

    switch (locPtr->protocol) {
    case TCP:
        if (TCPProxy(conn, locPtr) == NS_ERROR) {
            if (errorUrl != NULL) {
                Ns_ConnReturnRedirect(conn, errorUrl);
            } else {
                Ns_ConnReturnNotFound(conn);
            }
        }
        break;
    case UDS:
        if (UDSProxy(conn, locPtr) == NS_ERROR) {
            if (errorUrl != NULL) {
                Ns_ConnReturnRedirect(conn, errorUrl);
            } else {
                Ns_ConnReturnNotFound(conn);
            }
        }
        break;
    }

    return NS_OK;
}


/*
 *----------------------------------------------------------------------
 *
 * TCPProxy --
 *
 *	Go forth and proxy the existing connection over a TCP socket:
 *	forward the request line, the headers (with "Connection: close")
 *	and any request body to the backend, then relay the backend's
 *	output back to the client.
 *
 * Results:
 *	NS_ERROR if the backend cannot be reached or the request body
 *	cannot be forwarded (nothing has been sent to the client yet);
 *	NS_OK otherwise.
 *
 * Side effects:
 *	Sets "Connection: close" in the connection's header set.  Writes
 *	the backend response (or a 408 / BusyUrl redirect on timeout) to
 *	the client and closes the connection.
 *
 *----------------------------------------------------------------------
 */

static int
TCPProxy(Ns_Conn *conn, Location *loc)
{
    int         sock;
    int         i;
    Ns_DString  request;
    Ns_Set     *headersPtr;

    sock = Ns_SockConnect(loc->u.tcp.hostname, loc->u.tcp.port);
    if (sock == INVALID_SOCKET) {
        return NS_ERROR;
    }

    Ns_DStringInit(&request);
    Ns_DStringVarAppend(&request, conn->request->line, HTTP_EOL, NULL);

    /*
     * Pass-through headers to the other connection.  Replace
     * "Connection:" header with "Connection: close" to satisfy the
     * HTTP 1.1 RFC.
     */
    headersPtr = Ns_ConnHeaders(conn);
    if (headersPtr != NULL) {
        Ns_SetUpdate(headersPtr, "Connection", "close");
        for (i = 0; i < Ns_SetSize(headersPtr); i++) {
            Ns_DStringVarAppend(&request, Ns_SetKey(headersPtr, i),
                ": ", Ns_SetValue(headersPtr, i), HTTP_EOL, NULL);
        }
    }

    /*
     * The blank line ending the header block is required even when
     * there are no headers.
     */
    Ns_DStringAppend(&request, HTTP_EOL);

#ifdef DEBUG
    Ns_Log(Notice, "nsvhr: request is %s", request.string);
#endif

    SockWrite(sock, Ns_DStringValue(&request));

    /*
     * Stream the body straight to the backend instead of appending it
     * to the header string: SockWrite stops at the first NUL byte,
     * which would truncate binary POST/PUT bodies.
     */
    if (conn->contentLength > 0
            && Ns_ConnCopyToFd(conn, (size_t) conn->contentLength, sock)
                != NS_OK) {
        Ns_Log(Warning, "nsvhr: could not forward request body to %s:%d",
            loc->u.tcp.hostname, loc->u.tcp.port);
        close(sock);
        Ns_DStringFree(&request);
        return NS_ERROR;
    }

    if (TimedSockDump(sock, conn, gTimeout) == NS_ERROR) {
        /*
         * NOTE(review): if part of the backend response was already
         * relayed before the timeout, this response follows it on the
         * wire.
         */
        if (busyUrl == NULL) {
            Ns_ConnReturnNotice(conn, 408, "408 Request Timeout",
                "The server has timed out while attempting to "
                "fulfill your request.");
        } else {
            Ns_ConnReturnRedirect(conn, busyUrl);
        }
    } else {
        Ns_ConnClose(conn);
    }
    close(sock);
    Ns_DStringFree(&request);
    return NS_OK;
}


/*
 *----------------------------------------------------------------------
 *
 * SockWrite -
 *
 *	Write all of a string to a socket.
 *
 * Results:
 *	None.
 *
 * Side effects:
 *	Will fail if strlen(string) == 0.
 *
 *----------------------------------------------------------------------
 */

static void
SockWrite(int sock, char *string)
{
    int wrote, towrite;

    assert(string != NULL);
    assert(sock != INVALID_SOCKET);

    towrite = strlen(string);
    assert(towrite > 0);

    do {
        wrote = write(sock, string, (size_t) towrite);
        string += wrote;
        towrite -= wrote;
    } while (wrote > 0);
}



/*
 *----------------------------------------------------------------------
 *
 * TimedSockDump --
 *
 *	Copy bytes from sock to conn until the backend closes sock, a
 *	read error occurs, writing to conn fails, or no data arrives for
 *	`timeout` seconds.  The wait restarts each time data arrives.  A
 *	timeout of 0 or less waits forever (careful!).
 *
 * Results:
 *	NS_ERROR on timeout or if waiting on sock fails; NS_OK otherwise.
 *	Read errors are logged but still return NS_OK, as does a failed
 *	write to conn.
 *
 * Side effects:
 *	Puts sock in blocking mode, reads from it and writes to conn.
 *	The caller must close sock afterwards.
 *
 *----------------------------------------------------------------------
 */

static int
TimedSockDump(int sock, Ns_Conn *conn, int timeout)
{
    struct pollfd  pfd;
    char           buffer[BUFSIZE];
    ssize_t        nread;
    int            waitMs, ready;

    assert(conn != NULL && sock != INVALID_SOCKET);

    /*
     * poll() instead of select(): FD_SET on a descriptor >= FD_SETSIZE
     * writes past the end of the fd_set, and a busy server can easily
     * hand out such descriptors.  Clamp to avoid int overflow.
     */
    if (timeout <= 0) {
        waitMs = -1;
    } else if (timeout > INT_MAX / 1000) {
        waitMs = INT_MAX;
    } else {
        waitMs = timeout * 1000;
    }

    Ns_SockSetBlocking(sock);

    for (;;) {
        pfd.fd = sock;
        pfd.events = POLLIN;
        pfd.revents = 0;

        ready = poll(&pfd, 1, waitMs);
        if (ready < 0) {
            if (errno == EINTR) {
                continue;
            }
            Ns_Log(Warning,
                "nsvhr: poll error while redirecting to host %s: %s",
                Ns_SetIGet(Ns_ConnHeaders(conn), "Host"), strerror(errno));
            return NS_ERROR;
        }
        if (ready == 0) {
            /* Timeout */
            Ns_Log(Warning, "nsvhr: timeout while redirecting to host %s",
                Ns_SetIGet(Ns_ConnHeaders(conn), "Host"));
            return NS_ERROR;
        }

        nread = read(sock, buffer, sizeof(buffer));
        if (nread < 0) {
            if (errno == EINTR) {
                continue;
            }
            Ns_Log(Warning,
                "nsvhr: read error while redirecting to host %s: %s",
                Ns_SetIGet(Ns_ConnHeaders(conn), "Host"),
                strerror(errno));
            return NS_OK;
        }
        if (nread == 0) {
            /* Backend closed its end: the response is complete. */
            return NS_OK;
        }
        if (Ns_WriteConn(conn, buffer, (int) nread) != NS_OK) {
            /* The client went away; nothing more to relay. */
            return NS_OK;
        }
    }
}



/*
 *----------------------------------------------------------------------
 *
 *  UDSProxy --
 *
 *	UDSProxy copies the initial data of this connection and passes
 *	the file descriptor of the socket of this connection and its
 *	initial request data in a message over a unix domain socket.
 *
 * Results:
 *	NS_OK if ok, NS_ERROR otherwise.
 *
 * Side effects:
 *	Close this connection.
 *
 *----------------------------------------------------------------------
 */

static int
UDSProxy(Ns_Conn *conn, Location *loc)
{
    int         	sock;
    int			dup_sock;
    Ns_DString  	request;
    Ns_Set     	       *headersPtr;
    struct sockaddr_un  addr;
    int   	       	unix_sock;
    int   	       	addr_len;
    struct msghdr      	msg;
    struct iovec       	iov[1];
#ifdef HAVE_CMMSG
    my_cmsghdr          ancillary;
#endif
    int 		retcode = NS_ERROR;

    Ns_DStringInit(&request);

    /*
     * Create a unix domain socket to pass the connection socket file
     * descriptor and the request data.
     */

    unix_sock = socket(AF_UNIX, SOCK_STREAM, 0);
    if (unix_sock < 0) {
        Ns_Log(Error, "nsvhr: could not create unix domain socket: %s",
            strerror(errno));
        goto done;
    }

    bzero((char *) &addr, sizeof(addr));
    strcpy(addr.sun_path, loc->u.uds.filename);
    addr.sun_family = AF_UNIX;
    addr_len = sizeof(addr.sun_family) + strlen(addr.sun_path);
    if (connect(unix_sock, (struct sockaddr *) &addr, (socklen_t) addr_len) < 0) {
        Ns_Log(Error, "nsvhr: could not connect to unix:%s: %s",
            loc->u.uds.filename, strerror(errno));
        goto done;
    }

    /*
     * Collect the request headers from this connection into a string.
     */

    Ns_DStringVarAppend(&request, conn->request->line, HTTP_EOL, NULL);

    headersPtr = Ns_ConnHeaders(conn);

    /*
     * Pass-through headers to the other connection.  Replace
     * "Connection:" header with "Connection: close" to satisfy the
     * HTTP 1.1 RFC.
     */
    if (headersPtr != NULL) {
        int i;
        Ns_SetUpdate(headersPtr, "Connection", "close");
        for (i=0; i < Ns_SetSize(headersPtr); i++) {
            Ns_DStringVarAppend(&request, Ns_SetKey(headersPtr, i),
                ": ", Ns_SetValue(headersPtr, i), HTTP_EOL, NULL);
        }
        Ns_DStringAppend(&request, HTTP_EOL);
    }

    /*
     * Copy the rest of a POST/PUT through the proxy.
     */

    if (conn->contentLength > 0) {
        Ns_ConnCopyToDString(conn, (unsigned) conn->contentLength, &request);
    }

    /*
     * Obtain the file descriptor of the socket of this connection.
     * This file descriptor is dup'ed before it is closed so this
     * connection can be closed and still have a file descriptor of 
     * socket.
     */

    sock = Ns_ConnSock(conn);
    assert (sock > 0);
    dup_sock = dup(sock);
    assert(dup_sock > 0);

    /*
     * Build a message to send over unix domain socket.
     */

    iov[0].iov_base = Ns_DStringValue(&request);
    iov[0].iov_len = Ns_DStringLength(&request);

    msg.msg_iov = iov;
    msg.msg_iovlen = 1;
    msg.msg_name = NULL;
    msg.msg_namelen = 0;
#ifdef HAVE_CMMSG
    msg.msg_control = &ancillary;
    msg.msg_controllen = sizeof(ancillary);
    ancillary.cmsg.cmsg_len = sizeof(ancillary);
    ancillary.cmsg.cmsg_level = SOL_SOCKET;
    ancillary.cmsg.cmsg_type = SCM_RIGHTS;
    ancillary.sock = dup_sock;
#else
    msg.msg_accrights = (caddr_t) &dup_sock;
    msg.msg_accrightslen = sizeof(dup_sock);
#endif

    if (sendmsg(unix_sock, &msg, 0) < 0) {
        Ns_Log(Error, "nsvhr: sendmsg() failed: %s", strerror(errno));
        goto done;
    }

    close(unix_sock);
    close(dup_sock);
    close(sock);
    Ns_ConnClose(conn);
    retcode = NS_OK;

done:
    Ns_DStringFree(&request);

    return retcode;
}

