summaryrefslogtreecommitdiffstats
path: root/packet.c
diff options
context:
space:
mode:
authordjm@openbsd.org <djm@openbsd.org>2015-01-30 02:13:33 +0100
committerDamien Miller <djm@mindrot.org>2015-01-30 02:18:59 +0100
commit4509b5d4a4fa645a022635bfa7e86d09b285001f (patch)
treecb94ac37e4d5c59a3a5c2cde3b6c76363e7035d3 /packet.c
parentupstream commit (diff)
downloadopenssh-4509b5d4a4fa645a022635bfa7e86d09b285001f.tar.xz
openssh-4509b5d4a4fa645a022635bfa7e86d09b285001f.zip
upstream commit
avoid more fatal/exit in the packet.c paths that ssh-keyscan uses; feedback and "looks good" markus@
Diffstat (limited to 'packet.c')
-rw-r--r--packet.c220
1 files changed, 143 insertions, 77 deletions
diff --git a/packet.c b/packet.c
index eb178f149..f9ce08412 100644
--- a/packet.c
+++ b/packet.c
@@ -1,4 +1,4 @@
-/* $OpenBSD: packet.c,v 1.204 2015/01/28 21:15:47 djm Exp $ */
+/* $OpenBSD: packet.c,v 1.205 2015/01/30 01:13:33 djm Exp $ */
/*
* Author: Tatu Ylonen <ylo@cs.hut.fi>
* Copyright (c) 1995 Tatu Ylonen <ylo@cs.hut.fi>, Espoo, Finland
@@ -272,20 +272,26 @@ ssh_packet_set_connection(struct ssh *ssh, int fd_in, int fd_out)
const struct sshcipher *none = cipher_by_name("none");
int r;
- if (none == NULL)
- fatal("%s: cannot load cipher 'none'", __func__);
+ if (none == NULL) {
+ error("%s: cannot load cipher 'none'", __func__);
+ return NULL;
+ }
if (ssh == NULL)
ssh = ssh_alloc_session_state();
- if (ssh == NULL)
- fatal("%s: cound not allocate state", __func__);
+ if (ssh == NULL) {
+ error("%s: cound not allocate state", __func__);
+ return NULL;
+ }
state = ssh->state;
state->connection_in = fd_in;
state->connection_out = fd_out;
if ((r = cipher_init(&state->send_context, none,
(const u_char *)"", 0, NULL, 0, CIPHER_ENCRYPT)) != 0 ||
(r = cipher_init(&state->receive_context, none,
- (const u_char *)"", 0, NULL, 0, CIPHER_DECRYPT)) != 0)
- fatal("%s: cipher_init failed: %s", __func__, ssh_err(r));
+ (const u_char *)"", 0, NULL, 0, CIPHER_DECRYPT)) != 0) {
+ error("%s: cipher_init failed: %s", __func__, ssh_err(r));
+ return NULL;
+ }
state->newkeys[MODE_IN] = state->newkeys[MODE_OUT] = NULL;
deattack_init(&state->deattack);
return ssh;
@@ -893,8 +899,8 @@ ssh_packet_send1(struct ssh *ssh)
/*
* Note that the packet is now only buffered in output. It won't be
- * actually sent until packet_write_wait or packet_write_poll is
- * called.
+ * actually sent until ssh_packet_write_wait or ssh_packet_write_poll
+ * is called.
*/
r = 0;
out:
@@ -1263,8 +1269,12 @@ ssh_packet_read_seqnr(struct ssh *ssh, u_char *typep, u_int32_t *seqnr_p)
if (setp == NULL)
return SSH_ERR_ALLOC_FAIL;
- /* Since we are blocking, ensure that all written packets have been sent. */
- ssh_packet_write_wait(ssh);
+ /*
+ * Since we are blocking, ensure that all written packets have
+ * been sent.
+ */
+ if ((r = ssh_packet_write_wait(ssh)) != 0)
+ return r;
/* Stay in the loop until we have received a complete packet. */
for (;;) {
@@ -1351,16 +1361,22 @@ ssh_packet_read(struct ssh *ssh)
* that given, and gives a fatal error and exits if there is a mismatch.
*/
-void
-ssh_packet_read_expect(struct ssh *ssh, int expected_type)
+int
+ssh_packet_read_expect(struct ssh *ssh, u_int expected_type)
{
- int type;
+ int r;
+ u_char type;
- type = ssh_packet_read(ssh);
- if (type != expected_type)
- ssh_packet_disconnect(ssh,
+ if ((r = ssh_packet_read_seqnr(ssh, &type, NULL)) != 0)
+ return r;
+ if (type != expected_type) {
+ if ((r = sshpkt_disconnect(ssh,
"Protocol error: expected packet type %d, got %d",
- expected_type, type);
+ expected_type, type)) != 0)
+ return r;
+ return SSH_ERR_PROTOCOL_ERROR;
+ }
+ return 0;
}
/* Checks if a full packet is available in the data received so far via
@@ -1377,6 +1393,7 @@ ssh_packet_read_poll1(struct ssh *ssh, u_char *typep)
{
struct session_state *state = ssh->state;
u_int len, padded_len;
+ const char *emsg;
const u_char *cp;
u_char *p;
u_int checksum, stored_checksum;
@@ -1389,9 +1406,12 @@ ssh_packet_read_poll1(struct ssh *ssh, u_char *typep)
return 0;
/* Get length of incoming packet. */
len = PEEK_U32(sshbuf_ptr(state->input));
- if (len < 1 + 2 + 2 || len > 256 * 1024)
- ssh_packet_disconnect(ssh, "Bad packet length %u.",
- len);
+ if (len < 1 + 2 + 2 || len > 256 * 1024) {
+ if ((r = sshpkt_disconnect(ssh, "Bad packet length %u",
+ len)) != 0)
+ return r;
+ return SSH_ERR_CONN_CORRUPT;
+ }
padded_len = (len + 8) & ~7;
/* Check if the packet has been entirely received. */
@@ -1410,19 +1430,27 @@ ssh_packet_read_poll1(struct ssh *ssh, u_char *typep)
* Ariel Futoransky(futo@core-sdi.com)
*/
if (!state->receive_context.plaintext) {
+ emsg = NULL;
switch (detect_attack(&state->deattack,
sshbuf_ptr(state->input), padded_len)) {
case DEATTACK_OK:
break;
case DEATTACK_DETECTED:
- ssh_packet_disconnect(ssh,
- "crc32 compensation attack: network attack detected"
- );
+ emsg = "crc32 compensation attack detected";
+ break;
case DEATTACK_DOS_DETECTED:
- ssh_packet_disconnect(ssh,
- "deattack denial of service detected");
+ emsg = "deattack denial of service detected";
+ break;
default:
- ssh_packet_disconnect(ssh, "deattack error");
+ emsg = "deattack error";
+ break;
+ }
+ if (emsg != NULL) {
+ error("%s", emsg);
+ if ((r = sshpkt_disconnect(ssh, "%s", emsg)) != 0 ||
+ (r = ssh_packet_write_wait(ssh)) != 0)
+ return r;
+ return SSH_ERR_CONN_CORRUPT;
}
}
@@ -1451,16 +1479,24 @@ ssh_packet_read_poll1(struct ssh *ssh, u_char *typep)
goto out;
/* Test check bytes. */
- if (len != sshbuf_len(state->incoming_packet))
- ssh_packet_disconnect(ssh,
- "packet_read_poll1: len %d != sshbuf_len %zd.",
+ if (len != sshbuf_len(state->incoming_packet)) {
+ error("%s: len %d != sshbuf_len %zd", __func__,
len, sshbuf_len(state->incoming_packet));
+ if ((r = sshpkt_disconnect(ssh, "invalid packet length")) != 0 ||
+ (r = ssh_packet_write_wait(ssh)) != 0)
+ return r;
+ return SSH_ERR_CONN_CORRUPT;
+ }
cp = sshbuf_ptr(state->incoming_packet) + len - 4;
stored_checksum = PEEK_U32(cp);
- if (checksum != stored_checksum)
- ssh_packet_disconnect(ssh,
- "Corrupted check bytes on input.");
+ if (checksum != stored_checksum) {
+ error("Corrupted check bytes on input");
+ if ((r = sshpkt_disconnect(ssh, "connection corrupted")) != 0 ||
+ (r = ssh_packet_write_wait(ssh)) != 0)
+ return r;
+ return SSH_ERR_CONN_CORRUPT;
+ }
if ((r = sshbuf_consume_end(state->incoming_packet, 4)) < 0)
goto out;
@@ -1478,9 +1514,13 @@ ssh_packet_read_poll1(struct ssh *ssh, u_char *typep)
state->p_read.bytes += padded_len + 4;
if ((r = sshbuf_get_u8(state->incoming_packet, typep)) != 0)
goto out;
- if (*typep < SSH_MSG_MIN || *typep > SSH_MSG_MAX)
- ssh_packet_disconnect(ssh,
- "Invalid ssh1 packet type: %d", *typep);
+ if (*typep < SSH_MSG_MIN || *typep > SSH_MSG_MAX) {
+ error("Invalid ssh1 packet type: %d", *typep);
+ if ((r = sshpkt_disconnect(ssh, "invalid packet type")) != 0 ||
+ (r = ssh_packet_write_wait(ssh)) != 0)
+ return r;
+ return SSH_ERR_PROTOCOL_ERROR;
+ }
r = 0;
out:
return r;
@@ -1634,7 +1674,6 @@ ssh_packet_read_poll2(struct ssh *ssh, u_char *typep, u_int32_t *seqnr_p)
if ((r = sshbuf_consume(state->input, mac->mac_len)) != 0)
goto out;
}
- /* XXX now it's safe to use fatal/packet_disconnect */
if (seqnr_p != NULL)
*seqnr_p = state->p_read.seqnr;
if (++state->p_read.seqnr == 0)
@@ -1648,9 +1687,13 @@ ssh_packet_read_poll2(struct ssh *ssh, u_char *typep, u_int32_t *seqnr_p)
/* get padlen */
padlen = sshbuf_ptr(state->incoming_packet)[4];
DBG(debug("input: padlen %d", padlen));
- if (padlen < 4)
- ssh_packet_disconnect(ssh,
- "Corrupted padlen %d on input.", padlen);
+ if (padlen < 4) {
+ if ((r = sshpkt_disconnect(ssh,
+ "Corrupted padlen %d on input.", padlen)) != 0 ||
+ (r = ssh_packet_write_wait(ssh)) != 0)
+ return r;
+ return SSH_ERR_CONN_CORRUPT;
+ }
/* skip packet size + padlen, discard padding */
if ((r = sshbuf_consume(state->incoming_packet, 4 + 1)) != 0 ||
@@ -1677,9 +1720,13 @@ ssh_packet_read_poll2(struct ssh *ssh, u_char *typep, u_int32_t *seqnr_p)
*/
if ((r = sshbuf_get_u8(state->incoming_packet, typep)) != 0)
goto out;
- if (*typep < SSH2_MSG_MIN || *typep >= SSH2_MSG_LOCAL_MIN)
- ssh_packet_disconnect(ssh,
- "Invalid ssh2 packet type: %d", *typep);
+ if (*typep < SSH2_MSG_MIN || *typep >= SSH2_MSG_LOCAL_MIN) {
+ if ((r = sshpkt_disconnect(ssh,
+ "Invalid ssh2 packet type: %d", *typep)) != 0 ||
+ (r = ssh_packet_write_wait(ssh)) != 0)
+ return r;
+ return SSH_ERR_PROTOCOL_ERROR;
+ }
if (*typep == SSH2_MSG_NEWKEYS)
r = ssh_set_newkeys(ssh, MODE_IN);
else if (*typep == SSH2_MSG_USERAUTH_SUCCESS && !state->server_side)
@@ -1816,9 +1863,8 @@ ssh_packet_remaining(struct ssh *ssh)
* message is printed immediately, but only if the client is being executed
* in verbose mode. These messages are primarily intended to ease debugging
* authentication problems. The length of the formatted message must not
- * exceed 1024 bytes. This will automatically call packet_write_wait.
+ * exceed 1024 bytes. This will automatically call ssh_packet_write_wait.
*/
-
void
ssh_packet_send_debug(struct ssh *ssh, const char *fmt,...)
{
@@ -1846,7 +1892,29 @@ ssh_packet_send_debug(struct ssh *ssh, const char *fmt,...)
(r = sshpkt_send(ssh)) != 0)
fatal("%s: %s", __func__, ssh_err(r));
}
- ssh_packet_write_wait(ssh);
+ if ((r = ssh_packet_write_wait(ssh)) != 0)
+ fatal("%s: %s", __func__, ssh_err(r));
+}
+
+/*
+ * Pretty-print connection-terminating errors and exit.
+ */
+void
+sshpkt_fatal(struct ssh *ssh, const char *tag, int r)
+{
+ switch (r) {
+ case SSH_ERR_CONN_CLOSED:
+ logit("Connection closed by %.200s", ssh_remote_ipaddr(ssh));
+ cleanup_exit(255);
+ case SSH_ERR_CONN_TIMEOUT:
+ logit("Connection to %.200s timed out while "
+ "waiting to write", ssh_remote_ipaddr(ssh));
+ cleanup_exit(255);
+ default:
+ fatal("%s%sConnection to %.200s: %s",
+ tag != NULL ? tag : "", tag != NULL ? ": " : "",
+ ssh_remote_ipaddr(ssh), ssh_err(r));
+ }
}
/*
@@ -1855,7 +1923,6 @@ ssh_packet_send_debug(struct ssh *ssh, const char *fmt,...)
* should not contain a newline. The length of the formatted message must
* not exceed 1024 bytes.
*/
-
void
ssh_packet_disconnect(struct ssh *ssh, const char *fmt,...)
{
@@ -1879,30 +1946,26 @@ ssh_packet_disconnect(struct ssh *ssh, const char *fmt,...)
/* Display the error locally */
logit("Disconnecting: %.100s", buf);
- /* Send the disconnect message to the other side, and wait for it to get sent. */
- if (compat20) {
- if ((r = sshpkt_start(ssh, SSH2_MSG_DISCONNECT)) != 0 ||
- (r = sshpkt_put_u32(ssh, SSH2_DISCONNECT_PROTOCOL_ERROR)) != 0 ||
- (r = sshpkt_put_cstring(ssh, buf)) != 0 ||
- (r = sshpkt_put_cstring(ssh, "")) != 0 ||
- (r = sshpkt_send(ssh)) != 0)
- fatal("%s: %s", __func__, ssh_err(r));
- } else {
- if ((r = sshpkt_start(ssh, SSH_MSG_DISCONNECT)) != 0 ||
- (r = sshpkt_put_cstring(ssh, buf)) != 0 ||
- (r = sshpkt_send(ssh)) != 0)
- fatal("%s: %s", __func__, ssh_err(r));
- }
- ssh_packet_write_wait(ssh);
+ /*
+ * Send the disconnect message to the other side, and wait
+ * for it to get sent.
+ */
+ if ((r = sshpkt_disconnect(ssh, "%s", buf)) != 0)
+ sshpkt_fatal(ssh, __func__, r);
+
+ if ((r = ssh_packet_write_wait(ssh)) != 0)
+ sshpkt_fatal(ssh, __func__, r);
/* Close the connection. */
ssh_packet_close(ssh);
cleanup_exit(255);
}
-/* Checks if there is any buffered output, and tries to write some of the output. */
-
-void
+/*
+ * Checks if there is any buffered output, and tries to write some of
+ * the output.
+ */
+int
ssh_packet_write_poll(struct ssh *ssh)
{
struct session_state *state = ssh->state;
@@ -1916,33 +1979,33 @@ ssh_packet_write_poll(struct ssh *ssh)
if (len == -1) {
if (errno == EINTR || errno == EAGAIN ||
errno == EWOULDBLOCK)
- return;
- fatal("Write failed: %.100s", strerror(errno));
+ return 0;
+ return SSH_ERR_SYSTEM_ERROR;
}
if (len == 0 && !cont)
- fatal("Write connection closed");
+ return SSH_ERR_CONN_CLOSED;
if ((r = sshbuf_consume(state->output, len)) != 0)
- fatal("%s: %s", __func__, ssh_err(r));
+ return r;
}
+ return 0;
}
/*
* Calls packet_write_poll repeatedly until all pending output data has been
* written.
*/
-
-void
+int
ssh_packet_write_wait(struct ssh *ssh)
{
fd_set *setp;
- int ret, ms_remain = 0;
+ int ret, r, ms_remain = 0;
struct timeval start, timeout, *timeoutp = NULL;
struct session_state *state = ssh->state;
setp = (fd_set *)calloc(howmany(state->connection_out + 1,
NFDBITS), sizeof(fd_mask));
if (setp == NULL)
- fatal("%s: calloc failed", __func__);
+ return SSH_ERR_ALLOC_FAIL;
ssh_packet_write_poll(ssh);
while (ssh_packet_have_data_to_write(ssh)) {
memset(setp, 0, howmany(state->connection_out + 1,
@@ -1973,13 +2036,16 @@ ssh_packet_write_wait(struct ssh *ssh)
}
}
if (ret == 0) {
- logit("Connection to %.200s timed out while "
- "waiting to write", ssh_remote_ipaddr(ssh));
- cleanup_exit(255);
+ free(setp);
+ return SSH_ERR_CONN_TIMEOUT;
+ }
+ if ((r = ssh_packet_write_poll(ssh)) != 0) {
+ free(setp);
+ return r;
}
- ssh_packet_write_poll(ssh);
}
free(setp);
+ return 0;
}
/* Returns true if there is buffered data to write to the connection. */