diff options
author | djm@openbsd.org <djm@openbsd.org> | 2015-01-30 02:13:33 +0100 |
---|---|---|
committer | Damien Miller <djm@mindrot.org> | 2015-01-30 02:18:59 +0100 |
commit | 4509b5d4a4fa645a022635bfa7e86d09b285001f (patch) | |
tree | cb94ac37e4d5c59a3a5c2cde3b6c76363e7035d3 /packet.c | |
parent | upstream commit (diff) | |
download | openssh-4509b5d4a4fa645a022635bfa7e86d09b285001f.tar.xz openssh-4509b5d4a4fa645a022635bfa7e86d09b285001f.zip |
upstream commit
avoid more fatal/exit in the packet.c paths that
ssh-keyscan uses; feedback and "looks good" markus@
Diffstat (limited to 'packet.c')
-rw-r--r-- | packet.c | 220 |
1 files changed, 143 insertions, 77 deletions
@@ -1,4 +1,4 @@ -/* $OpenBSD: packet.c,v 1.204 2015/01/28 21:15:47 djm Exp $ */ +/* $OpenBSD: packet.c,v 1.205 2015/01/30 01:13:33 djm Exp $ */ /* * Author: Tatu Ylonen <ylo@cs.hut.fi> * Copyright (c) 1995 Tatu Ylonen <ylo@cs.hut.fi>, Espoo, Finland @@ -272,20 +272,26 @@ ssh_packet_set_connection(struct ssh *ssh, int fd_in, int fd_out) const struct sshcipher *none = cipher_by_name("none"); int r; - if (none == NULL) - fatal("%s: cannot load cipher 'none'", __func__); + if (none == NULL) { + error("%s: cannot load cipher 'none'", __func__); + return NULL; + } if (ssh == NULL) ssh = ssh_alloc_session_state(); - if (ssh == NULL) - fatal("%s: cound not allocate state", __func__); + if (ssh == NULL) { + error("%s: cound not allocate state", __func__); + return NULL; + } state = ssh->state; state->connection_in = fd_in; state->connection_out = fd_out; if ((r = cipher_init(&state->send_context, none, (const u_char *)"", 0, NULL, 0, CIPHER_ENCRYPT)) != 0 || (r = cipher_init(&state->receive_context, none, - (const u_char *)"", 0, NULL, 0, CIPHER_DECRYPT)) != 0) - fatal("%s: cipher_init failed: %s", __func__, ssh_err(r)); + (const u_char *)"", 0, NULL, 0, CIPHER_DECRYPT)) != 0) { + error("%s: cipher_init failed: %s", __func__, ssh_err(r)); + return NULL; + } state->newkeys[MODE_IN] = state->newkeys[MODE_OUT] = NULL; deattack_init(&state->deattack); return ssh; @@ -893,8 +899,8 @@ ssh_packet_send1(struct ssh *ssh) /* * Note that the packet is now only buffered in output. It won't be - * actually sent until packet_write_wait or packet_write_poll is - * called. + * actually sent until ssh_packet_write_wait or ssh_packet_write_poll + * is called. */ r = 0; out: @@ -1263,8 +1269,12 @@ ssh_packet_read_seqnr(struct ssh *ssh, u_char *typep, u_int32_t *seqnr_p) if (setp == NULL) return SSH_ERR_ALLOC_FAIL; - /* Since we are blocking, ensure that all written packets have been sent. */ - ssh_packet_write_wait(ssh); + /* + * Since we are blocking, ensure that all written packets have + * been sent. + */ + if ((r = ssh_packet_write_wait(ssh)) != 0) + return r; /* Stay in the loop until we have received a complete packet. */ for (;;) { @@ -1351,16 +1361,22 @@ ssh_packet_read(struct ssh *ssh) * that given, and gives a fatal error and exits if there is a mismatch. */ -void -ssh_packet_read_expect(struct ssh *ssh, int expected_type) +int +ssh_packet_read_expect(struct ssh *ssh, u_int expected_type) { - int type; + int r; + u_char type; - type = ssh_packet_read(ssh); - if (type != expected_type) - ssh_packet_disconnect(ssh, + if ((r = ssh_packet_read_seqnr(ssh, &type, NULL)) != 0) + return r; + if (type != expected_type) { + if ((r = sshpkt_disconnect(ssh, "Protocol error: expected packet type %d, got %d", - expected_type, type); + expected_type, type)) != 0) + return r; + return SSH_ERR_PROTOCOL_ERROR; + } + return 0; } /* Checks if a full packet is available in the data received so far via @@ -1377,6 +1393,7 @@ ssh_packet_read_poll1(struct ssh *ssh, u_char *typep) { struct session_state *state = ssh->state; u_int len, padded_len; + const char *emsg; const u_char *cp; u_char *p; u_int checksum, stored_checksum; @@ -1389,9 +1406,12 @@ ssh_packet_read_poll1(struct ssh *ssh, u_char *typep) return 0; /* Get length of incoming packet. */ len = PEEK_U32(sshbuf_ptr(state->input)); - if (len < 1 + 2 + 2 || len > 256 * 1024) - ssh_packet_disconnect(ssh, "Bad packet length %u.", - len); + if (len < 1 + 2 + 2 || len > 256 * 1024) { + if ((r = sshpkt_disconnect(ssh, "Bad packet length %u", + len)) != 0) + return r; + return SSH_ERR_CONN_CORRUPT; + } padded_len = (len + 8) & ~7; /* Check if the packet has been entirely received. */ @@ -1410,19 +1430,27 @@ ssh_packet_read_poll1(struct ssh *ssh, u_char *typep) * Ariel Futoransky(futo@core-sdi.com) */ if (!state->receive_context.plaintext) { + emsg = NULL; switch (detect_attack(&state->deattack, sshbuf_ptr(state->input), padded_len)) { case DEATTACK_OK: break; case DEATTACK_DETECTED: - ssh_packet_disconnect(ssh, - "crc32 compensation attack: network attack detected" - ); + emsg = "crc32 compensation attack detected"; + break; case DEATTACK_DOS_DETECTED: - ssh_packet_disconnect(ssh, - "deattack denial of service detected"); + emsg = "deattack denial of service detected"; + break; default: - ssh_packet_disconnect(ssh, "deattack error"); + emsg = "deattack error"; + break; + } + if (emsg != NULL) { + error("%s", emsg); + if ((r = sshpkt_disconnect(ssh, "%s", emsg)) != 0 || + (r = ssh_packet_write_wait(ssh)) != 0) + return r; + return SSH_ERR_CONN_CORRUPT; } } @@ -1451,16 +1479,24 @@ ssh_packet_read_poll1(struct ssh *ssh, u_char *typep) goto out; /* Test check bytes. */ - if (len != sshbuf_len(state->incoming_packet)) - ssh_packet_disconnect(ssh, - "packet_read_poll1: len %d != sshbuf_len %zd.", + if (len != sshbuf_len(state->incoming_packet)) { + error("%s: len %d != sshbuf_len %zd", __func__, len, sshbuf_len(state->incoming_packet)); + if ((r = sshpkt_disconnect(ssh, "invalid packet length")) != 0 || + (r = ssh_packet_write_wait(ssh)) != 0) + return r; + return SSH_ERR_CONN_CORRUPT; + } cp = sshbuf_ptr(state->incoming_packet) + len - 4; stored_checksum = PEEK_U32(cp); - if (checksum != stored_checksum) - ssh_packet_disconnect(ssh, - "Corrupted check bytes on input."); + if (checksum != stored_checksum) { + error("Corrupted check bytes on input"); + if ((r = sshpkt_disconnect(ssh, "connection corrupted")) != 0 || + (r = ssh_packet_write_wait(ssh)) != 0) + return r; + return SSH_ERR_CONN_CORRUPT; + } if ((r = sshbuf_consume_end(state->incoming_packet, 4)) < 0) goto out; @@ -1478,9 +1514,13 @@ ssh_packet_read_poll1(struct ssh *ssh, u_char *typep) state->p_read.bytes += padded_len + 4; if ((r = sshbuf_get_u8(state->incoming_packet, typep)) != 0) goto out; - if (*typep < SSH_MSG_MIN || *typep > SSH_MSG_MAX) - ssh_packet_disconnect(ssh, - "Invalid ssh1 packet type: %d", *typep); + if (*typep < SSH_MSG_MIN || *typep > SSH_MSG_MAX) { + error("Invalid ssh1 packet type: %d", *typep); + if ((r = sshpkt_disconnect(ssh, "invalid packet type")) != 0 || + (r = ssh_packet_write_wait(ssh)) != 0) + return r; + return SSH_ERR_PROTOCOL_ERROR; + } r = 0; out: return r; @@ -1634,7 +1674,6 @@ ssh_packet_read_poll2(struct ssh *ssh, u_char *typep, u_int32_t *seqnr_p) if ((r = sshbuf_consume(state->input, mac->mac_len)) != 0) goto out; } - /* XXX now it's safe to use fatal/packet_disconnect */ if (seqnr_p != NULL) *seqnr_p = state->p_read.seqnr; if (++state->p_read.seqnr == 0) @@ -1648,9 +1687,13 @@ ssh_packet_read_poll2(struct ssh *ssh, u_char *typep, u_int32_t *seqnr_p) /* get padlen */ padlen = sshbuf_ptr(state->incoming_packet)[4]; DBG(debug("input: padlen %d", padlen)); - if (padlen < 4) - ssh_packet_disconnect(ssh, - "Corrupted padlen %d on input.", padlen); + if (padlen < 4) { + if ((r = sshpkt_disconnect(ssh, + "Corrupted padlen %d on input.", padlen)) != 0 || + (r = ssh_packet_write_wait(ssh)) != 0) + return r; + return SSH_ERR_CONN_CORRUPT; + } /* skip packet size + padlen, discard padding */ if ((r = sshbuf_consume(state->incoming_packet, 4 + 1)) != 0 || @@ -1677,9 +1720,13 @@ ssh_packet_read_poll2(struct ssh *ssh, u_char *typep, u_int32_t *seqnr_p) */ if ((r = sshbuf_get_u8(state->incoming_packet, typep)) != 0) goto out; - if (*typep < SSH2_MSG_MIN || *typep >= SSH2_MSG_LOCAL_MIN) - ssh_packet_disconnect(ssh, - "Invalid ssh2 packet type: %d", *typep); + if (*typep < SSH2_MSG_MIN || *typep >= SSH2_MSG_LOCAL_MIN) { + if ((r = sshpkt_disconnect(ssh, + "Invalid ssh2 packet type: %d", *typep)) != 0 || + (r = ssh_packet_write_wait(ssh)) != 0) + return r; + return SSH_ERR_PROTOCOL_ERROR; + } if (*typep == SSH2_MSG_NEWKEYS) r = ssh_set_newkeys(ssh, MODE_IN); else if (*typep == SSH2_MSG_USERAUTH_SUCCESS && !state->server_side) @@ -1816,9 +1863,8 @@ ssh_packet_remaining(struct ssh *ssh) * message is printed immediately, but only if the client is being executed * in verbose mode. These messages are primarily intended to ease debugging * authentication problems. The length of the formatted message must not - * exceed 1024 bytes. This will automatically call packet_write_wait. + * exceed 1024 bytes. This will automatically call ssh_packet_write_wait. */ - void ssh_packet_send_debug(struct ssh *ssh, const char *fmt,...) { @@ -1846,7 +1892,29 @@ ssh_packet_send_debug(struct ssh *ssh, const char *fmt,...) (r = sshpkt_send(ssh)) != 0) fatal("%s: %s", __func__, ssh_err(r)); } - ssh_packet_write_wait(ssh); + if ((r = ssh_packet_write_wait(ssh)) != 0) + fatal("%s: %s", __func__, ssh_err(r)); +} + +/* + * Pretty-print connection-terminating errors and exit. + */ +void +sshpkt_fatal(struct ssh *ssh, const char *tag, int r) +{ + switch (r) { + case SSH_ERR_CONN_CLOSED: + logit("Connection closed by %.200s", ssh_remote_ipaddr(ssh)); + cleanup_exit(255); + case SSH_ERR_CONN_TIMEOUT: + logit("Connection to %.200s timed out while " + "waiting to write", ssh_remote_ipaddr(ssh)); + cleanup_exit(255); + default: + fatal("%s%sConnection to %.200s: %s", + tag != NULL ? tag : "", tag != NULL ? ": " : "", + ssh_remote_ipaddr(ssh), ssh_err(r)); + } } /* @@ -1855,7 +1923,6 @@ ssh_packet_send_debug(struct ssh *ssh, const char *fmt,...) * should not contain a newline. The length of the formatted message must * not exceed 1024 bytes. */ - void ssh_packet_disconnect(struct ssh *ssh, const char *fmt,...) { @@ -1879,30 +1946,26 @@ ssh_packet_disconnect(struct ssh *ssh, const char *fmt,...) /* Display the error locally */ logit("Disconnecting: %.100s", buf); - /* Send the disconnect message to the other side, and wait for it to get sent. */ - if (compat20) { - if ((r = sshpkt_start(ssh, SSH2_MSG_DISCONNECT)) != 0 || - (r = sshpkt_put_u32(ssh, SSH2_DISCONNECT_PROTOCOL_ERROR)) != 0 || - (r = sshpkt_put_cstring(ssh, buf)) != 0 || - (r = sshpkt_put_cstring(ssh, "")) != 0 || - (r = sshpkt_send(ssh)) != 0) - fatal("%s: %s", __func__, ssh_err(r)); - } else { - if ((r = sshpkt_start(ssh, SSH_MSG_DISCONNECT)) != 0 || - (r = sshpkt_put_cstring(ssh, buf)) != 0 || - (r = sshpkt_send(ssh)) != 0) - fatal("%s: %s", __func__, ssh_err(r)); - } - ssh_packet_write_wait(ssh); + /* + * Send the disconnect message to the other side, and wait + * for it to get sent. + */ + if ((r = sshpkt_disconnect(ssh, "%s", buf)) != 0) + sshpkt_fatal(ssh, __func__, r); + + if ((r = ssh_packet_write_wait(ssh)) != 0) + sshpkt_fatal(ssh, __func__, r); /* Close the connection. */ ssh_packet_close(ssh); cleanup_exit(255); } -/* Checks if there is any buffered output, and tries to write some of the output. */ - -void +/* + * Checks if there is any buffered output, and tries to write some of + * the output. + */ +int ssh_packet_write_poll(struct ssh *ssh) { struct session_state *state = ssh->state; @@ -1916,33 +1979,33 @@ ssh_packet_write_poll(struct ssh *ssh) if (len == -1) { if (errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK) - return; - fatal("Write failed: %.100s", strerror(errno)); + return 0; + return SSH_ERR_SYSTEM_ERROR; } if (len == 0 && !cont) - fatal("Write connection closed"); + return SSH_ERR_CONN_CLOSED; if ((r = sshbuf_consume(state->output, len)) != 0) - fatal("%s: %s", __func__, ssh_err(r)); + return r; } + return 0; } /* * Calls packet_write_poll repeatedly until all pending output data has been * written. */ - -void +int ssh_packet_write_wait(struct ssh *ssh) { fd_set *setp; - int ret, ms_remain = 0; + int ret, r, ms_remain = 0; struct timeval start, timeout, *timeoutp = NULL; struct session_state *state = ssh->state; setp = (fd_set *)calloc(howmany(state->connection_out + 1, NFDBITS), sizeof(fd_mask)); if (setp == NULL) - fatal("%s: calloc failed", __func__); + return SSH_ERR_ALLOC_FAIL; ssh_packet_write_poll(ssh); while (ssh_packet_have_data_to_write(ssh)) { memset(setp, 0, howmany(state->connection_out + 1, @@ -1973,13 +2036,16 @@ ssh_packet_write_wait(struct ssh *ssh) } } if (ret == 0) { - logit("Connection to %.200s timed out while " - "waiting to write", ssh_remote_ipaddr(ssh)); - cleanup_exit(255); + free(setp); + return SSH_ERR_CONN_TIMEOUT; + } + if ((r = ssh_packet_write_poll(ssh)) != 0) { + free(setp); + return r; } - ssh_packet_write_poll(ssh); } free(setp); + return 0; } /* Returns true if there is buffered data to write to the connection. */ |