diff options
Diffstat (limited to 'server/red-stream.cpp')
-rw-r--r-- | server/red-stream.cpp | 1219 |
1 files changed, 1219 insertions, 0 deletions
diff --git a/server/red-stream.cpp b/server/red-stream.cpp new file mode 100644 index 00000000..89222702 --- /dev/null +++ b/server/red-stream.cpp @@ -0,0 +1,1219 @@ +/* -*- Mode: C; c-basic-offset: 4; indent-tabs-mode: nil -*- */ +/* + Copyright (C) 2009, 2013 Red Hat, Inc. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, see <http://www.gnu.org/licenses/>. +*/ +#include <config.h> + +#include <cerrno> + +#include <fcntl.h> +#include <unistd.h> +#ifndef _WIN32 +#include <netdb.h> +#include <sys/socket.h> +#include <netinet/tcp.h> +#include <netinet/in.h> +#else +#include <ws2tcpip.h> +#endif + +#include <glib.h> + +#include <openssl/err.h> + +#include <common/log.h> + +#include "main-dispatcher.h" +#include "net-utils.h" +#include "red-common.h" +#include "red-stream.h" +#include "reds.h" +#include "websocket.h" + +// compatibility for *BSD systems +#if !defined(TCP_CORK) && !defined(_WIN32) +#define TCP_CORK TCP_NOPUSH +#endif + +struct AsyncRead { + void *opaque; + uint8_t *now; + uint8_t *end; + AsyncReadDone done; + AsyncReadError error; +}; + +#if HAVE_SASL +#include <sasl/sasl.h> + +struct RedSASL { + sasl_conn_t *conn; + + /* If we want to negotiate an SSF layer with client */ + int wantSSF :1; + /* If we are now running the SSF layer */ + int runSSF :1; + + /* + * Buffering encoded data to allow more clear data + * to be stuffed onto the output buffer + */ + const uint8_t *encoded; + unsigned int encodedLength; + unsigned int encodedOffset; + + SpiceBuffer inbuffer; +}; +#endif + +struct RedStreamPrivate { + SSL *ssl; + +#if HAVE_SASL + RedSASL sasl; +#endif + + AsyncRead async_read; + + RedsWebSocket *ws; + + /* life time of info: + * allocated when creating RedStream. + * deallocated when main_dispatcher handles the SPICE_CHANNEL_EVENT_DISCONNECTED + * event, either from same thread or by call back from main thread. */ + SpiceChannelEventInfo* info; + bool use_cork; + bool corked; + + ssize_t (*read)(RedStream *s, void *buf, size_t nbyte); + ssize_t (*write)(RedStream *s, const void *buf, size_t nbyte); + ssize_t (*writev)(RedStream *s, const struct iovec *iov, int iovcnt); + + RedsState *reds; + SpiceCoreInterfaceInternal *core; +}; + +#ifndef _WIN32 +/** + * Set TCP_CORK on socket + */ +/* NOTE: enabled must be int */ +static int socket_set_cork(int socket, int enabled) +{ + SPICE_VERIFY(sizeof(enabled) == sizeof(int)); + return setsockopt(socket, IPPROTO_TCP, TCP_CORK, &enabled, sizeof(enabled)); +} +#else +static inline int socket_set_cork(int socket, int enabled) +{ + return -1; +} +#endif + +static ssize_t stream_write_cb(RedStream *s, const void *buf, size_t size) +{ + return socket_write(s->socket, buf, size); +} + +static ssize_t stream_writev_cb(RedStream *s, const struct iovec *iov, int iovcnt) +{ + ssize_t ret = 0; + do { + int tosend; + ssize_t n, expected = 0; + int i; +#ifdef IOV_MAX + tosend = MIN(iovcnt, IOV_MAX); +#else + tosend = iovcnt; +#endif + for (i = 0; i < tosend; i++) { + expected += iov[i].iov_len; + } + n = socket_writev(s->socket, iov, tosend); + if (n <= expected) { + if (n > 0) + ret += n; + return ret == 0 ? n : ret; + } + ret += n; + iov += tosend; + iovcnt -= tosend; + } while(iovcnt > 0); + + return ret; +} + +static ssize_t stream_read_cb(RedStream *s, void *buf, size_t size) +{ + return socket_read(s->socket, buf, size); +} + +static ssize_t stream_ssl_error(RedStream *s, int return_code) +{ + SPICE_GNUC_UNUSED int ssl_error; + + ssl_error = SSL_get_error(s->priv->ssl, return_code); + + // OpenSSL can to return SSL_ERROR_WANT_READ if we attempt to read + // data and the socket did not receive all SSL packet. + // Under Windows errno is not set so potentially caller can detect + // the wrong error so we need to set errno. +#ifdef _WIN32 + if (ssl_error == SSL_ERROR_WANT_READ || ssl_error == SSL_ERROR_WANT_WRITE) { + errno = EAGAIN; + } else { + errno = EPIPE; + } +#endif + + // red_peer_receive is expected to receive -1 on errors while + // OpenSSL documentation just state a <0 value + return -1; +} + +static ssize_t stream_ssl_write_cb(RedStream *s, const void *buf, size_t size) +{ + int return_code; + + return_code = SSL_write(s->priv->ssl, buf, size); + + if (return_code < 0) { + return stream_ssl_error(s, return_code); + } + + return return_code; +} + +static ssize_t stream_ssl_read_cb(RedStream *s, void *buf, size_t size) +{ + int return_code; + + return_code = SSL_read(s->priv->ssl, buf, size); + + if (return_code < 0) { + return stream_ssl_error(s, return_code); + } + + return return_code; +} + +void red_stream_remove_watch(RedStream* s) +{ + red_watch_remove(s->watch); + s->watch = nullptr; +} + +#if HAVE_SASL +static ssize_t red_stream_sasl_read(RedStream *s, uint8_t *buf, size_t nbyte); +#endif + +ssize_t red_stream_read(RedStream *s, void *buf, size_t nbyte) +{ + ssize_t ret; + +#if HAVE_SASL + if (s->priv->sasl.conn && s->priv->sasl.runSSF) { + ret = red_stream_sasl_read(s, (uint8_t*) buf, nbyte); + } else +#endif + ret = s->priv->read(s, buf, nbyte); + + return ret; +} + +bool red_stream_write_all(RedStream *stream, const void *in_buf, size_t n) +{ + const uint8_t *buf = (uint8_t *)in_buf; + + while (n) { + int now = red_stream_write(stream, buf, n); + if (now <= 0) { + if (now == -1 && (errno == EINTR || errno == EAGAIN)) { + continue; + } + return false; + } + n -= now; + buf += now; + } + return true; +} + +bool red_stream_set_auto_flush(RedStream *s, bool auto_flush) +{ + if (s->priv->use_cork == !auto_flush) { + return true; + } + + s->priv->use_cork = !auto_flush; + if (s->priv->use_cork) { + if (socket_set_cork(s->socket, 1)) { + s->priv->use_cork = false; + return false; + } + s->priv->corked = true; + } else if (s->priv->corked) { + socket_set_cork(s->socket, 0); + s->priv->corked = false; + } + return true; +} + +void red_stream_flush(RedStream *s) +{ + if (s->priv->corked) { + socket_set_cork(s->socket, 0); + socket_set_cork(s->socket, 1); + } +} + +#if HAVE_SASL +static ssize_t red_stream_sasl_write(RedStream *s, const void *buf, size_t nbyte); +#endif + +ssize_t red_stream_write(RedStream *s, const void *buf, size_t nbyte) +{ + ssize_t ret; + +#if HAVE_SASL + if (s->priv->sasl.conn && s->priv->sasl.runSSF) { + ret = red_stream_sasl_write(s, buf, nbyte); + } else +#endif + ret = s->priv->write(s, buf, nbyte); + + return ret; +} + +int red_stream_get_family(const RedStream *s) +{ + spice_return_val_if_fail(s != nullptr, -1); + + if (s->socket == -1) + return -1; + + return s->priv->info->laddr_ext.ss_family; +} + +bool red_stream_is_plain_unix(const RedStream *s) +{ + spice_return_val_if_fail(s != nullptr, false); + + if (red_stream_get_family(s) != AF_UNIX) { + return false; + } + +#if HAVE_SASL + if (s->priv->sasl.conn) { + return false; + } +#endif + if (s->priv->ssl) { + return false; + } + + return true; + +} + +/** + * red_stream_set_no_delay: + * @stream: a #RedStream + * @no_delay: whether to enable TCP_NODELAY on @@stream + * + * Returns: #true if the operation succeeded, #false otherwise. + */ +bool red_stream_set_no_delay(RedStream *stream, bool no_delay) +{ + return red_socket_set_no_delay(stream->socket, no_delay); +} + +int red_stream_get_no_delay(RedStream *stream) +{ + return red_socket_get_no_delay(stream->socket); +} + +#ifndef _WIN32 +int red_stream_send_msgfd(RedStream *stream, int fd) +{ + struct msghdr msgh = { nullptr, }; + struct iovec iov; + int r; + + const size_t fd_size = 1 * sizeof(int); + struct cmsghdr *cmsg; + union { + struct cmsghdr hdr; + char data[CMSG_SPACE(fd_size)]; + } control; + + spice_return_val_if_fail(red_stream_is_plain_unix(stream), -1); + + /* set the payload */ + iov.iov_base = const_cast<char *>("@"); + iov.iov_len = 1; + msgh.msg_iovlen = 1; + msgh.msg_iov = &iov; + + if (fd != -1) { + msgh.msg_control = control.data; + msgh.msg_controllen = sizeof(control.data); + /* CMSG_SPACE() might be larger than CMSG_LEN() as it can include some + * padding. We set the whole control data to 0 to avoid valgrind warnings + */ + memset(control.data, 0, sizeof(control.data)); + + cmsg = CMSG_FIRSTHDR(&msgh); + cmsg->cmsg_len = CMSG_LEN(fd_size); + cmsg->cmsg_level = SOL_SOCKET; + cmsg->cmsg_type = SCM_RIGHTS; + memcpy(CMSG_DATA(cmsg), &fd, fd_size); + } + + do { + r = sendmsg(stream->socket, &msgh, MSG_NOSIGNAL); + } while (r < 0 && (errno == EINTR || errno == EAGAIN)); + + return r; +} +#endif + +ssize_t red_stream_writev(RedStream *s, const struct iovec *iov, int iovcnt) +{ + int i; + int n; + ssize_t ret = 0; + + if (s->priv->writev != nullptr && iovcnt > 1) { + return s->priv->writev(s, iov, iovcnt); + } + + for (i = 0; i < iovcnt; ++i) { + n = red_stream_write(s, iov[i].iov_base, iov[i].iov_len); + if (n <= 0) + return ret == 0 ? n : ret; + ret += n; + } + + return ret; +} + +void red_stream_free(RedStream *s) +{ + if (!s) { + return; + } + + red_stream_push_channel_event(s, SPICE_CHANNEL_EVENT_DISCONNECTED); + +#if HAVE_SASL + if (s->priv->sasl.conn) { + s->priv->sasl.runSSF = s->priv->sasl.wantSSF = 0; + s->priv->sasl.encodedLength = s->priv->sasl.encodedOffset = 0; + s->priv->sasl.encoded = NULL; + sasl_dispose(&s->priv->sasl.conn); + s->priv->sasl.conn = NULL; + } +#endif + + if (s->priv->ssl) { + SSL_free(s->priv->ssl); + } + + websocket_free(s->priv->ws); + + red_stream_remove_watch(s); + socket_close(s->socket); + + g_free(s); +} + +void red_stream_push_channel_event(RedStream *s, int event) +{ + RedsState *reds = s->priv->reds; + MainDispatcher *md = reds_get_main_dispatcher(reds); + md->channel_event(event, s->priv->info); +} + +static void red_stream_set_socket(RedStream *stream, int socket) +{ + stream->socket = socket; + /* deprecated fields. Filling them for backward compatibility */ + stream->priv->info->llen = sizeof(stream->priv->info->laddr); + stream->priv->info->plen = sizeof(stream->priv->info->paddr); + getsockname(stream->socket, &stream->priv->info->laddr, &stream->priv->info->llen); + getpeername(stream->socket, &stream->priv->info->paddr, &stream->priv->info->plen); + + stream->priv->info->flags |= SPICE_CHANNEL_EVENT_FLAG_ADDR_EXT; + stream->priv->info->llen_ext = sizeof(stream->priv->info->laddr_ext); + stream->priv->info->plen_ext = sizeof(stream->priv->info->paddr_ext); + getsockname(stream->socket, reinterpret_cast<struct sockaddr *>(&stream->priv->info->laddr_ext), + &stream->priv->info->llen_ext); + getpeername(stream->socket, reinterpret_cast<struct sockaddr *>(&stream->priv->info->paddr_ext), + &stream->priv->info->plen_ext); +} + + +void red_stream_set_channel(RedStream *stream, int connection_id, + int channel_type, int channel_id) +{ + stream->priv->info->connection_id = connection_id; + stream->priv->info->type = channel_type; + stream->priv->info->id = channel_id; + if (red_stream_is_ssl(stream)) { + stream->priv->info->flags |= SPICE_CHANNEL_EVENT_FLAG_TLS; + } +} + +RedStream *red_stream_new(RedsState *reds, int socket) +{ + RedStream *stream; + + stream = static_cast<RedStream *>(g_malloc0(sizeof(RedStream) + sizeof(RedStreamPrivate))); + stream->priv = reinterpret_cast<RedStreamPrivate *>(stream + 1); + stream->priv->info = g_new0(SpiceChannelEventInfo, 1); + stream->priv->reds = reds; + stream->priv->core = reds_get_core_interface(reds); + red_stream_set_socket(stream, socket); + + stream->priv->read = stream_read_cb; + stream->priv->write = stream_write_cb; + stream->priv->writev = stream_writev_cb; + + return stream; +} + +void red_stream_set_core_interface(RedStream *stream, SpiceCoreInterfaceInternal *core) +{ + red_stream_remove_watch(stream); + stream->priv->core = core; +} + +bool red_stream_is_ssl(RedStream *stream) +{ + return (stream->priv->ssl != nullptr); +} + +static void red_stream_disable_writev(RedStream *stream) +{ + stream->priv->writev = nullptr; +} + +RedStreamSslStatus red_stream_ssl_accept(RedStream *stream) +{ + int ssl_error; + int return_code; + + return_code = SSL_accept(stream->priv->ssl); + if (return_code == 1) { + return RED_STREAM_SSL_STATUS_OK; + } + +#ifndef SSL_OP_NO_RENEGOTIATION + // With OpenSSL 1.0.2 and earlier: disable client-side renegotiation + stream->priv->ssl->s3->flags |= SSL3_FLAGS_NO_RENEGOTIATE_CIPHERS; +#endif + + ssl_error = SSL_get_error(stream->priv->ssl, return_code); + if (return_code == -1 && (ssl_error == SSL_ERROR_WANT_READ || + ssl_error == SSL_ERROR_WANT_WRITE)) { + if (ssl_error == SSL_ERROR_WANT_READ) { + return RED_STREAM_SSL_STATUS_WAIT_FOR_READ; + } + return RED_STREAM_SSL_STATUS_WAIT_FOR_WRITE; + } + + red_dump_openssl_errors(); + spice_warning("SSL_accept failed, error=%d", ssl_error); + SSL_free(stream->priv->ssl); + stream->priv->ssl = nullptr; + + return RED_STREAM_SSL_STATUS_ERROR; +} + +RedStreamSslStatus red_stream_enable_ssl(RedStream *stream, SSL_CTX *ctx) +{ + BIO *sbio; + + // Handle SSL handshaking + if (!(sbio = BIO_new_socket(stream->socket, BIO_NOCLOSE))) { + spice_warning("could not allocate ssl bio socket"); + return RED_STREAM_SSL_STATUS_ERROR; + } + + stream->priv->ssl = SSL_new(ctx); + if (!stream->priv->ssl) { + spice_warning("could not allocate ssl context"); + BIO_free(sbio); + return RED_STREAM_SSL_STATUS_ERROR; + } + + SSL_set_bio(stream->priv->ssl, sbio, sbio); + + stream->priv->write = stream_ssl_write_cb; + stream->priv->read = stream_ssl_read_cb; + red_stream_disable_writev(stream); + + return red_stream_ssl_accept(stream); +} + +void red_stream_set_async_error_handler(RedStream *stream, + AsyncReadError error_handler) +{ + stream->priv->async_read.error = error_handler; +} + +static inline void async_read_clear_handlers(RedStream *stream) +{ + AsyncRead *async = &stream->priv->async_read; + red_stream_remove_watch(stream); + async->now = nullptr; + async->end = nullptr; +} + +static void async_read_handler(G_GNUC_UNUSED int fd, + G_GNUC_UNUSED int event, + RedStream *stream) +{ + AsyncRead *async = &stream->priv->async_read; + SpiceCoreInterfaceInternal *core = stream->priv->core; + + for (;;) { + int n = async->end - async->now; + + spice_assert(n > 0); + n = red_stream_read(stream, async->now, n); + if (n <= 0) { + int err = n < 0 ? errno: 0; + switch (err) { + case EAGAIN: + if (!stream->watch) { + stream->watch = core->watch_new(stream->socket, + SPICE_WATCH_EVENT_READ, + async_read_handler, stream); + } + return; + case EINTR: + break; + default: + async_read_clear_handlers(stream); + if (async->error) { + async->error(async->opaque, err); + } + return; + } + } else { + async->now += n; + if (async->now == async->end) { + async_read_clear_handlers(stream); + async->done(async->opaque); + return; + } + } + } +} + +void red_stream_async_read(RedStream *stream, + uint8_t *data, size_t size, + AsyncReadDone read_done_cb, + void *opaque) +{ + AsyncRead *async = &stream->priv->async_read; + + g_return_if_fail(async->now == nullptr && async->end == nullptr); + if (size == 0) { + read_done_cb(opaque); + return; + } + async->now = data; + async->end = async->now + size; + async->done = read_done_cb; + async->opaque = opaque; + async_read_handler(0, 0, stream); + +} + +#if HAVE_SASL +static bool red_stream_write_u8(RedStream *s, uint8_t n) +{ + return red_stream_write_all(s, &n, sizeof(uint8_t)); +} + +static bool red_stream_write_u32_le(RedStream *s, uint32_t n) +{ + n = GUINT32_TO_LE(n); + return red_stream_write_all(s, &n, sizeof(uint32_t)); +} + +static ssize_t red_stream_sasl_write(RedStream *s, const void *buf, size_t nbyte) +{ + ssize_t ret; + + if (!s->priv->sasl.encoded) { + int err; + err = sasl_encode(s->priv->sasl.conn, (char *)buf, nbyte, + (const char **)&s->priv->sasl.encoded, + &s->priv->sasl.encodedLength); + if (err != SASL_OK) { + spice_warning("sasl_encode error: %d", err); + errno = EIO; + return -1; + } + + if (s->priv->sasl.encodedLength == 0) { + return 0; + } + + if (!s->priv->sasl.encoded) { + spice_warning("sasl_encode didn't return a buffer!"); + return 0; + } + + s->priv->sasl.encodedOffset = 0; + } + + ret = s->priv->write(s, s->priv->sasl.encoded + s->priv->sasl.encodedOffset, + s->priv->sasl.encodedLength - s->priv->sasl.encodedOffset); + + if (ret <= 0) { + return ret; + } + + s->priv->sasl.encodedOffset += ret; + if (s->priv->sasl.encodedOffset == s->priv->sasl.encodedLength) { + s->priv->sasl.encoded = NULL; + s->priv->sasl.encodedOffset = s->priv->sasl.encodedLength = 0; + return nbyte; + } + + /* we didn't flush the encoded buffer */ + errno = EAGAIN; + return -1; +} + +static ssize_t red_stream_sasl_read(RedStream *s, uint8_t *buf, size_t nbyte) +{ + uint8_t encoded[4096]; + const char *decoded; + unsigned int decodedlen; + int err; + int n, offset; + + offset = spice_buffer_copy(&s->priv->sasl.inbuffer, buf, nbyte); + if (offset > 0) { + spice_buffer_remove(&s->priv->sasl.inbuffer, offset); + if (offset == nbyte) + return offset; + nbyte -= offset; + buf += offset; + } + + n = s->priv->read(s, encoded, sizeof(encoded)); + if (n <= 0) { + return offset > 0 ? offset : n; + } + + err = sasl_decode(s->priv->sasl.conn, + (char *)encoded, n, + &decoded, &decodedlen); + if (err != SASL_OK) { + spice_warning("sasl_decode error: %d", err); + errno = EIO; + return offset > 0 ? offset : -1; + } + + if (decodedlen == 0) { + errno = EAGAIN; + return offset > 0 ? offset : -1; + } + + n = MIN(nbyte, decodedlen); + memcpy(buf, decoded, n); + spice_buffer_append(&s->priv->sasl.inbuffer, decoded + n, decodedlen - n); + return offset + n; +} + +static char *addr_to_string(const char *format, + struct sockaddr_storage *sa, + socklen_t salen) +{ + char host[NI_MAXHOST]; + char serv[NI_MAXSERV]; + int err; + + // makes it work on no-glibc avoiding getnameinfo returning error + if (sa->ss_family == AF_UNIX) { + return g_strdup("localhost;"); + } + + if ((err = getnameinfo((struct sockaddr *)sa, salen, + host, sizeof(host), + serv, sizeof(serv), + NI_NUMERICHOST | NI_NUMERICSERV)) != 0) { + spice_warning("Cannot resolve address %d: %s", + err, gai_strerror(err)); + return NULL; + } + + return g_strdup_printf(format, host, serv); +} + +static char *red_stream_get_local_address(RedStream *stream) +{ + return addr_to_string("%s;%s", &stream->priv->info->laddr_ext, + stream->priv->info->llen_ext); +} + +static char *red_stream_get_remote_address(RedStream *stream) +{ + return addr_to_string("%s;%s", &stream->priv->info->paddr_ext, + stream->priv->info->plen_ext); +} + +static int auth_sasl_check_ssf(RedSASL *sasl, int *runSSF) +{ + const void *val; + int err, ssf; + + *runSSF = 0; + if (!sasl->wantSSF) { + return 1; + } + + err = sasl_getprop(sasl->conn, SASL_SSF, &val); + if (err != SASL_OK) { + return 0; + } + + ssf = *(const int *)val; + spice_debug("negotiated an SSF of %d", ssf); + if (ssf < 56) { + return 0; /* 56 is good for Kerberos */ + } + + *runSSF = 1; + + /* We have a SSF that's good enough */ + return 1; +} + +struct RedSASLAuth { + RedStream *stream; + // list of mechanisms allowed, allocated and freed by SASL + char *mechlist; + // mech received + char *mechname; + uint32_t len; + char *data; + // callback to call if success + RedSaslResult result_cb; + void *result_opaque; + // saved Async callback, we need to call if failed as + // we need to chain it in order to use a different opaque data + AsyncReadError saved_error_cb; +}; + +static void red_sasl_auth_free(RedSASLAuth *auth) +{ + g_free(auth->data); + g_free(auth->mechname); + g_free(auth->mechlist); + g_free(auth); +} + +// handle SASL termination, either success or error +// NOTE: After this function is called usually there should be a +// return or the function should exit +static void red_sasl_async_result(RedSASLAuth *auth, RedSaslError err) +{ + red_stream_set_async_error_handler(auth->stream, auth->saved_error_cb); + auth->result_cb(auth->result_opaque, err); + red_sasl_auth_free(auth); +} + +static void red_sasl_error(void *opaque, int err) +{ + RedSASLAuth *auth = (RedSASLAuth*) opaque; + red_stream_set_async_error_handler(auth->stream, auth->saved_error_cb); + if (auth->saved_error_cb) { + auth->saved_error_cb(auth->result_opaque, err); + } + red_sasl_auth_free(auth); +} + +/* + * Step Msg + * + * Input from client: + * + * u32 clientin-length + * u8-array clientin-string + * + * Output to client: + * + * u32 serverout-length + * u8-array serverout-strin + * u8 continue + */ +#define SASL_MAX_MECHNAME_LEN 100 +#define SASL_DATA_MAX_LEN (1024 * 1024) + +static void red_sasl_handle_auth_steplen(void *opaque); + +/* + * Start Msg + * + * Input from client: + * + * u32 clientin-length + * u8-array clientin-string + * + * Output to client: + * + * u32 serverout-length + * u8-array serverout-strin + * u8 continue + */ + +static void red_sasl_handle_auth_step(void *opaque) +{ + RedSASLAuth *auth = (RedSASLAuth*) opaque; + RedStream *stream = auth->stream; + const char *serverout; + unsigned int serveroutlen; + int err; + char *clientdata = NULL; + RedSASL *sasl = &stream->priv->sasl; + uint32_t datalen = auth->len; + + /* NB, distinction of NULL vs "" is *critical* in SASL */ + if (datalen) { + clientdata = auth->data; + clientdata[datalen - 1] = '\0'; /* Wire includes '\0', but make sure */ + datalen--; /* Don't count NULL byte when passing to _start() */ + } + + if (auth->mechname != NULL) { + spice_debug("Start SASL auth with mechanism %s. Data %p (%d bytes)", + auth->mechname, clientdata, datalen); + err = sasl_server_start(sasl->conn, + auth->mechname, + clientdata, + datalen, + &serverout, + &serveroutlen); + g_free(auth->mechname); + auth->mechname = NULL; + } else { + spice_debug("Step using SASL Data %p (%d bytes)", clientdata, datalen); + err = sasl_server_step(sasl->conn, + clientdata, + datalen, + &serverout, + &serveroutlen); + } + if (err != SASL_OK && + err != SASL_CONTINUE) { + spice_warning("sasl step failed %d (%s)", + err, sasl_errdetail(sasl->conn)); + return red_sasl_async_result(auth, RED_SASL_ERROR_GENERIC); + } + + if (serveroutlen > SASL_DATA_MAX_LEN) { + spice_warning("sasl step reply data too long %d", + serveroutlen); + return red_sasl_async_result(auth, RED_SASL_ERROR_GENERIC); + } + + spice_debug("SASL return data %d bytes, %p", serveroutlen, serverout); + + if (serveroutlen) { + serveroutlen += 1; + red_stream_write_u32_le(stream, serveroutlen); + red_stream_write_all(stream, serverout, serveroutlen); + } else { + red_stream_write_u32_le(stream, 0); + } + + /* Whether auth is complete */ + red_stream_write_u8(stream, err == SASL_CONTINUE ? 0 : 1); + + if (err == SASL_CONTINUE) { + spice_debug("%s", "Authentication must continue"); + /* Wait for step length */ + red_stream_async_read(stream, (uint8_t *)&auth->len, sizeof(uint32_t), + red_sasl_handle_auth_steplen, auth); + return; + } else { + int ssf; + + if (auth_sasl_check_ssf(sasl, &ssf) == 0) { + spice_warning("Authentication rejected for weak SSF"); + goto authreject; + } + + spice_debug("Authentication successful"); + red_stream_write_u32_le(stream, SPICE_LINK_ERR_OK); /* Accept auth */ + + /* + * Delay writing in SSF encoded until now + */ + sasl->runSSF = ssf; + red_stream_disable_writev(stream); /* make sure writev isn't called directly anymore */ + + return red_sasl_async_result(auth, RED_SASL_ERROR_OK); + } + +authreject: + red_stream_write_u32_le(stream, 1); /* Reject auth */ + red_stream_write_u32_le(stream, sizeof("Authentication failed")); + red_stream_write_all(stream, "Authentication failed", sizeof("Authentication failed")); + + red_sasl_async_result(auth, RED_SASL_ERROR_AUTH_FAILED); +} + +static void red_sasl_handle_auth_steplen(void *opaque) +{ + RedSASLAuth *auth = (RedSASLAuth*) opaque; + + auth->len = GUINT32_FROM_LE(auth->len); + uint32_t len = auth->len; + spice_debug("Got steplen %d", len); + if (len > SASL_DATA_MAX_LEN) { + spice_warning("Too much SASL data %d", len); + return red_sasl_async_result((RedSASLAuth*) opaque, auth->mechname ? RED_SASL_ERROR_INVALID_DATA : RED_SASL_ERROR_GENERIC); + } + + auth->data = (char*) g_realloc(auth->data, len); + red_stream_async_read(auth->stream, (uint8_t *)auth->data, len, + red_sasl_handle_auth_step, auth); +} + + + +static void red_sasl_handle_auth_mechname(void *opaque) +{ + RedSASLAuth *auth = (RedSASLAuth*) opaque; + + auth->mechname[auth->len] = '\0'; + spice_debug("Got client mechname '%s' check against '%s'", + auth->mechname, auth->mechlist); + + char quoted_mechname[SASL_MAX_MECHNAME_LEN + 4]; + sprintf(quoted_mechname, ",%s,", auth->mechname); + + if (strchr(auth->mechname, ',') || strstr(auth->mechlist, quoted_mechname) == NULL) { + return red_sasl_async_result(auth, RED_SASL_ERROR_INVALID_DATA); + } + + spice_debug("Validated mechname '%s'", auth->mechname); + + red_stream_async_read(auth->stream, (uint8_t *)&auth->len, sizeof(uint32_t), + red_sasl_handle_auth_steplen, auth); +} + +static void red_sasl_handle_auth_mechlen(void *opaque) +{ + RedSASLAuth *auth = (RedSASLAuth*) opaque; + + auth->len = GUINT32_FROM_LE(auth->len); + uint32_t len = auth->len; + if (len < 1 || len > SASL_MAX_MECHNAME_LEN) { + spice_warning("Got bad client mechname len %d", len); + return red_sasl_async_result(auth, RED_SASL_ERROR_GENERIC); + } + + auth->mechname = (char*) g_malloc(len + 1); + + spice_debug("Wait for client mechname"); + red_stream_async_read(auth->stream, (uint8_t *)auth->mechname, len, + red_sasl_handle_auth_mechname, auth); +} + +bool red_sasl_start_auth(RedStream *stream, RedSaslResult result_cb, void *result_opaque) +{ + const char *mechlist = NULL; + sasl_security_properties_t secprops; + int err; + char *localAddr, *remoteAddr; + int mechlistlen; + RedSASL *sasl = &stream->priv->sasl; + RedSASLAuth *auth; + + if (!(localAddr = red_stream_get_local_address(stream))) { + goto error; + } + + if (!(remoteAddr = red_stream_get_remote_address(stream))) { + g_free(localAddr); + goto error; + } + + err = sasl_server_new("spice", + NULL, /* FQDN - just delegates to gethostname */ + NULL, /* User realm */ + localAddr, + remoteAddr, + NULL, /* Callbacks, not needed */ + SASL_SUCCESS_DATA, + &sasl->conn); + g_free(localAddr); + g_free(remoteAddr); + localAddr = remoteAddr = NULL; + + if (err != SASL_OK) { + spice_warning("sasl context setup failed %d (%s)", + err, sasl_errstring(err, NULL, NULL)); + sasl->conn = NULL; + goto error; + } + + /* Inform SASL that we've got an external SSF layer from TLS */ + if (stream->priv->ssl) { + sasl_ssf_t ssf; + + ssf = SSL_get_cipher_bits(stream->priv->ssl, NULL); + err = sasl_setprop(sasl->conn, SASL_SSF_EXTERNAL, &ssf); + if (err != SASL_OK) { + spice_warning("cannot set SASL external SSF %d (%s)", + err, sasl_errstring(err, NULL, NULL)); + goto error_dispose; + } + } else { + sasl->wantSSF = 1; + } + + memset(&secprops, 0, sizeof secprops); + /* Inform SASL that we've got an external SSF layer from TLS */ + if (stream->priv->ssl) { + /* If we've got TLS (or UNIX domain sock), we don't care about SSF */ + secprops.min_ssf = 0; + secprops.max_ssf = 0; + secprops.maxbufsize = 8192; + secprops.security_flags = 0; + } else { + /* Plain TCP, better get an SSF layer */ + secprops.min_ssf = 56; /* Good enough to require kerberos */ + secprops.max_ssf = 100000; /* Arbitrary big number */ + secprops.maxbufsize = 8192; + /* Forbid any anonymous or trivially crackable auth */ + secprops.security_flags = + SASL_SEC_NOANONYMOUS | SASL_SEC_NOPLAINTEXT; + } + + err = sasl_setprop(sasl->conn, SASL_SEC_PROPS, &secprops); + if (err != SASL_OK) { + spice_warning("cannot set SASL security props %d (%s)", + err, sasl_errstring(err, NULL, NULL)); + goto error_dispose; + } + + err = sasl_listmech(sasl->conn, + NULL, /* Don't need to set user */ + ",", /* Prefix */ + ",", /* Separator */ + ",", /* Suffix */ + &mechlist, + NULL, + NULL); + if (err != SASL_OK || mechlist == NULL) { + spice_warning("cannot list SASL mechanisms %d (%s)", + err, sasl_errdetail(sasl->conn)); + goto error_dispose; + } + + spice_debug("Available mechanisms for client: '%s'", mechlist); + + mechlistlen = strlen(mechlist); + if (!red_stream_write_u32_le(stream, mechlistlen) + || !red_stream_write_all(stream, mechlist, mechlistlen)) { + spice_warning("SASL mechanisms write error"); + goto error; + } + + auth = g_new0(RedSASLAuth, 1); + auth->stream = stream; + auth->result_cb = result_cb; + auth->result_opaque = result_opaque; + auth->saved_error_cb = stream->priv->async_read.error; + auth->mechlist = g_strdup(mechlist); + + spice_debug("Wait for client mechname length"); + red_stream_set_async_error_handler(stream, red_sasl_error); + red_stream_async_read(stream, (uint8_t *)&auth->len, sizeof(uint32_t), + red_sasl_handle_auth_mechlen, auth); + + return true; + +error_dispose: + sasl_dispose(&sasl->conn); + sasl->conn = NULL; +error: + return false; +} +#endif + +static ssize_t stream_websocket_read(RedStream *s, void *buf, size_t size) +{ + unsigned flags; + int len; + + do { + len = websocket_read(s->priv->ws, static_cast<uint8_t *>(buf), size, &flags); + } while (len == 0 && flags != 0); + return len; +} + +static ssize_t stream_websocket_write(RedStream *s, const void *buf, size_t size) +{ + return websocket_write(s->priv->ws, buf, size, WEBSOCKET_BINARY_FINAL); +} + +static ssize_t stream_websocket_writev(RedStream *s, const struct iovec *iov, int iovcnt) +{ + return websocket_writev(s->priv->ws, iov, iovcnt, WEBSOCKET_BINARY_FINAL); +} + +/* + If we detect that a newly opened stream appears to be using + the WebSocket protocol, we will put in place cover functions + that will speak WebSocket to the client, but allow the server + to continue to use normal stream read/write/writev semantics. +*/ +bool red_stream_is_websocket(RedStream *stream, const void *buf, size_t len) +{ + if (stream->priv->ws) { + return false; + } + + stream->priv->ws = + websocket_new(buf, len, stream, reinterpret_cast<websocket_read_cb_t>(stream->priv->read), + reinterpret_cast<websocket_write_cb_t>(stream->priv->write), + reinterpret_cast<websocket_writev_cb_t>(stream->priv->writev)); + if (stream->priv->ws) { + stream->priv->read = stream_websocket_read; + stream->priv->write = stream_websocket_write; + + if (stream->priv->writev) { + stream->priv->writev = stream_websocket_writev; + } + + return true; + } + + return false; +} |