summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authordjm@openbsd.org <djm@openbsd.org>2015-01-30 01:13:33 +0000
committerDamien Miller <djm@mindrot.org>2015-01-30 12:18:59 +1100
commit4509b5d4a4fa645a022635bfa7e86d09b285001f (patch)
treecb94ac37e4d5c59a3a5c2cde3b6c76363e7035d3
parent669aee994348468af8b4b2ebd29b602cf2860b22 (diff)
upstream commit
avoid more fatal/exit in the packet.c paths that ssh-keyscan uses; feedback and "looks good" markus@
-rw-r--r--dispatch.c8
-rw-r--r--opacket.c49
-rw-r--r--opacket.h9
-rw-r--r--packet.c220
-rw-r--r--packet.h12
-rw-r--r--ssh-keyscan.c5
-rw-r--r--ssh_api.c5
-rw-r--r--ssherr.c6
-rw-r--r--ssherr.h4
9 files changed, 204 insertions, 114 deletions
diff --git a/dispatch.c b/dispatch.c
index 9ff5e3daf..900a2c10c 100644
--- a/dispatch.c
+++ b/dispatch.c
@@ -1,4 +1,4 @@
1/* $OpenBSD: dispatch.c,v 1.24 2015/01/28 22:05:31 djm Exp $ */ 1/* $OpenBSD: dispatch.c,v 1.25 2015/01/30 01:13:33 djm Exp $ */
2/* 2/*
3 * Copyright (c) 2000 Markus Friedl. All rights reserved. 3 * Copyright (c) 2000 Markus Friedl. All rights reserved.
4 * 4 *
@@ -49,9 +49,9 @@ dispatch_protocol_error(int type, u_int32_t seq, void *ctx)
49 fatal("protocol error"); 49 fatal("protocol error");
50 if ((r = sshpkt_start(ssh, SSH2_MSG_UNIMPLEMENTED)) != 0 || 50 if ((r = sshpkt_start(ssh, SSH2_MSG_UNIMPLEMENTED)) != 0 ||
51 (r = sshpkt_put_u32(ssh, seq)) != 0 || 51 (r = sshpkt_put_u32(ssh, seq)) != 0 ||
52 (r = sshpkt_send(ssh)) != 0) 52 (r = sshpkt_send(ssh)) != 0 ||
53 fatal("%s: %s", __func__, ssh_err(r)); 53 (r = ssh_packet_write_wait(ssh)) != 0)
54 ssh_packet_write_wait(ssh); 54 sshpkt_fatal(ssh, __func__, r);
55 return 0; 55 return 0;
56} 56}
57 57
diff --git a/opacket.c b/opacket.c
index a137b5a8a..7618eae48 100644
--- a/opacket.c
+++ b/opacket.c
@@ -223,6 +223,8 @@ void
223packet_set_connection(int fd_in, int fd_out) 223packet_set_connection(int fd_in, int fd_out)
224{ 224{
225 active_state = ssh_packet_set_connection(active_state, fd_in, fd_out); 225 active_state = ssh_packet_set_connection(active_state, fd_in, fd_out);
226 if (active_state == NULL)
227 fatal("%s: ssh_packet_set_connection failed", __func__);
226} 228}
227 229
228void 230void
@@ -255,20 +257,8 @@ packet_read_seqnr(u_int32_t *seqnr)
255 u_char type; 257 u_char type;
256 int r; 258 int r;
257 259
258 if ((r = ssh_packet_read_seqnr(active_state, &type, seqnr)) != 0) { 260 if ((r = ssh_packet_read_seqnr(active_state, &type, seqnr)) != 0)
259 switch (r) { 261 sshpkt_fatal(active_state, __func__, r);
260 case SSH_ERR_CONN_CLOSED:
261 logit("Connection closed by %.200s",
262 ssh_remote_ipaddr(active_state));
263 cleanup_exit(255);
264 case SSH_ERR_CONN_TIMEOUT:
265 logit("Connection to %.200s timed out while "
266 "waiting to read", ssh_remote_ipaddr(active_state));
267 cleanup_exit(255);
268 default:
269 fatal("%s: %s", __func__, ssh_err(r));
270 }
271 }
272 return type; 262 return type;
273} 263}
274 264
@@ -279,7 +269,7 @@ packet_read_poll_seqnr(u_int32_t *seqnr)
279 int r; 269 int r;
280 270
281 if ((r = ssh_packet_read_poll_seqnr(active_state, &type, seqnr))) 271 if ((r = ssh_packet_read_poll_seqnr(active_state, &type, seqnr)))
282 fatal("%s: %s", __func__, ssh_err(r)); 272 sshpkt_fatal(active_state, __func__, r);
283 return type; 273 return type;
284} 274}
285 275
@@ -296,5 +286,32 @@ packet_process_incoming(const char *buf, u_int len)
296 int r; 286 int r;
297 287
298 if ((r = ssh_packet_process_incoming(active_state, buf, len)) != 0) 288 if ((r = ssh_packet_process_incoming(active_state, buf, len)) != 0)
299 fatal("%s: %s", __func__, ssh_err(r)); 289 sshpkt_fatal(active_state, __func__, r);
290}
291
292void
293packet_write_wait(void)
294{
295 int r;
296
297 if ((r = ssh_packet_write_wait(active_state)) != 0)
298 sshpkt_fatal(active_state, __func__, r);
299}
300
301void
302packet_write_poll(void)
303{
304 int r;
305
306 if ((r = ssh_packet_write_poll(active_state)) != 0)
307 sshpkt_fatal(active_state, __func__, r);
308}
309
310void
311packet_read_expect(int expected_type)
312{
313 int r;
314
315 if ((r = ssh_packet_read_expect(active_state, expected_type)) != 0)
316 sshpkt_fatal(active_state, __func__, r);
300} 317}
diff --git a/opacket.h b/opacket.h
index 261ed1f81..e563d8d3b 100644
--- a/opacket.h
+++ b/opacket.h
@@ -45,6 +45,9 @@ void packet_set_connection(int, int);
45int packet_read_seqnr(u_int32_t *); 45int packet_read_seqnr(u_int32_t *);
46int packet_read_poll_seqnr(u_int32_t *); 46int packet_read_poll_seqnr(u_int32_t *);
47void packet_process_incoming(const char *buf, u_int len); 47void packet_process_incoming(const char *buf, u_int len);
48void packet_write_wait(void);
49void packet_write_poll(void);
50void packet_read_expect(int expected_type);
48#define packet_set_timeout(timeout, count) \ 51#define packet_set_timeout(timeout, count) \
49 ssh_packet_set_timeout(active_state, (timeout), (count)) 52 ssh_packet_set_timeout(active_state, (timeout), (count))
50#define packet_connection_is_on_socket() \ 53#define packet_connection_is_on_socket() \
@@ -85,8 +88,6 @@ void packet_process_incoming(const char *buf, u_int len);
85 ssh_packet_send(active_state) 88 ssh_packet_send(active_state)
86#define packet_read() \ 89#define packet_read() \
87 ssh_packet_read(active_state) 90 ssh_packet_read(active_state)
88#define packet_read_expect(expected_type) \
89 ssh_packet_read_expect(active_state, (expected_type))
90#define packet_get_int64() \ 91#define packet_get_int64() \
91 ssh_packet_get_int64(active_state) 92 ssh_packet_get_int64(active_state)
92#define packet_get_bignum(value) \ 93#define packet_get_bignum(value) \
@@ -105,10 +106,6 @@ void packet_process_incoming(const char *buf, u_int len);
105 ssh_packet_send_debug(active_state, (fmt), ##args) 106 ssh_packet_send_debug(active_state, (fmt), ##args)
106#define packet_disconnect(fmt, args...) \ 107#define packet_disconnect(fmt, args...) \
107 ssh_packet_disconnect(active_state, (fmt), ##args) 108 ssh_packet_disconnect(active_state, (fmt), ##args)
108#define packet_write_poll() \
109 ssh_packet_write_poll(active_state)
110#define packet_write_wait() \
111 ssh_packet_write_wait(active_state)
112#define packet_have_data_to_write() \ 109#define packet_have_data_to_write() \
113 ssh_packet_have_data_to_write(active_state) 110 ssh_packet_have_data_to_write(active_state)
114#define packet_not_very_much_data_to_write() \ 111#define packet_not_very_much_data_to_write() \
diff --git a/packet.c b/packet.c
index eb178f149..f9ce08412 100644
--- a/packet.c
+++ b/packet.c
@@ -1,4 +1,4 @@
1/* $OpenBSD: packet.c,v 1.204 2015/01/28 21:15:47 djm Exp $ */ 1/* $OpenBSD: packet.c,v 1.205 2015/01/30 01:13:33 djm Exp $ */
2/* 2/*
3 * Author: Tatu Ylonen <ylo@cs.hut.fi> 3 * Author: Tatu Ylonen <ylo@cs.hut.fi>
4 * Copyright (c) 1995 Tatu Ylonen <ylo@cs.hut.fi>, Espoo, Finland 4 * Copyright (c) 1995 Tatu Ylonen <ylo@cs.hut.fi>, Espoo, Finland
@@ -272,20 +272,26 @@ ssh_packet_set_connection(struct ssh *ssh, int fd_in, int fd_out)
272 const struct sshcipher *none = cipher_by_name("none"); 272 const struct sshcipher *none = cipher_by_name("none");
273 int r; 273 int r;
274 274
275 if (none == NULL) 275 if (none == NULL) {
276 fatal("%s: cannot load cipher 'none'", __func__); 276 error("%s: cannot load cipher 'none'", __func__);
277 return NULL;
278 }
277 if (ssh == NULL) 279 if (ssh == NULL)
278 ssh = ssh_alloc_session_state(); 280 ssh = ssh_alloc_session_state();
279 if (ssh == NULL) 281 if (ssh == NULL) {
280 fatal("%s: cound not allocate state", __func__); 282 error("%s: cound not allocate state", __func__);
283 return NULL;
284 }
281 state = ssh->state; 285 state = ssh->state;
282 state->connection_in = fd_in; 286 state->connection_in = fd_in;
283 state->connection_out = fd_out; 287 state->connection_out = fd_out;
284 if ((r = cipher_init(&state->send_context, none, 288 if ((r = cipher_init(&state->send_context, none,
285 (const u_char *)"", 0, NULL, 0, CIPHER_ENCRYPT)) != 0 || 289 (const u_char *)"", 0, NULL, 0, CIPHER_ENCRYPT)) != 0 ||
286 (r = cipher_init(&state->receive_context, none, 290 (r = cipher_init(&state->receive_context, none,
287 (const u_char *)"", 0, NULL, 0, CIPHER_DECRYPT)) != 0) 291 (const u_char *)"", 0, NULL, 0, CIPHER_DECRYPT)) != 0) {
288 fatal("%s: cipher_init failed: %s", __func__, ssh_err(r)); 292 error("%s: cipher_init failed: %s", __func__, ssh_err(r));
293 return NULL;
294 }
289 state->newkeys[MODE_IN] = state->newkeys[MODE_OUT] = NULL; 295 state->newkeys[MODE_IN] = state->newkeys[MODE_OUT] = NULL;
290 deattack_init(&state->deattack); 296 deattack_init(&state->deattack);
291 return ssh; 297 return ssh;
@@ -893,8 +899,8 @@ ssh_packet_send1(struct ssh *ssh)
893 899
894 /* 900 /*
895 * Note that the packet is now only buffered in output. It won't be 901 * Note that the packet is now only buffered in output. It won't be
896 * actually sent until packet_write_wait or packet_write_poll is 902 * actually sent until ssh_packet_write_wait or ssh_packet_write_poll
897 * called. 903 * is called.
898 */ 904 */
899 r = 0; 905 r = 0;
900 out: 906 out:
@@ -1263,8 +1269,12 @@ ssh_packet_read_seqnr(struct ssh *ssh, u_char *typep, u_int32_t *seqnr_p)
1263 if (setp == NULL) 1269 if (setp == NULL)
1264 return SSH_ERR_ALLOC_FAIL; 1270 return SSH_ERR_ALLOC_FAIL;
1265 1271
1266 /* Since we are blocking, ensure that all written packets have been sent. */ 1272 /*
1267 ssh_packet_write_wait(ssh); 1273 * Since we are blocking, ensure that all written packets have
1274 * been sent.
1275 */
1276 if ((r = ssh_packet_write_wait(ssh)) != 0)
1277 return r;
1268 1278
1269 /* Stay in the loop until we have received a complete packet. */ 1279 /* Stay in the loop until we have received a complete packet. */
1270 for (;;) { 1280 for (;;) {
@@ -1351,16 +1361,22 @@ ssh_packet_read(struct ssh *ssh)
1351 * that given, and gives a fatal error and exits if there is a mismatch. 1361 * that given, and gives a fatal error and exits if there is a mismatch.
1352 */ 1362 */
1353 1363
1354void 1364int
1355ssh_packet_read_expect(struct ssh *ssh, int expected_type) 1365ssh_packet_read_expect(struct ssh *ssh, u_int expected_type)
1356{ 1366{
1357 int type; 1367 int r;
1368 u_char type;
1358 1369
1359 type = ssh_packet_read(ssh); 1370 if ((r = ssh_packet_read_seqnr(ssh, &type, NULL)) != 0)
1360 if (type != expected_type) 1371 return r;
1361 ssh_packet_disconnect(ssh, 1372 if (type != expected_type) {
1373 if ((r = sshpkt_disconnect(ssh,
1362 "Protocol error: expected packet type %d, got %d", 1374 "Protocol error: expected packet type %d, got %d",
1363 expected_type, type); 1375 expected_type, type)) != 0)
1376 return r;
1377 return SSH_ERR_PROTOCOL_ERROR;
1378 }
1379 return 0;
1364} 1380}
1365 1381
1366/* Checks if a full packet is available in the data received so far via 1382/* Checks if a full packet is available in the data received so far via
@@ -1377,6 +1393,7 @@ ssh_packet_read_poll1(struct ssh *ssh, u_char *typep)
1377{ 1393{
1378 struct session_state *state = ssh->state; 1394 struct session_state *state = ssh->state;
1379 u_int len, padded_len; 1395 u_int len, padded_len;
1396 const char *emsg;
1380 const u_char *cp; 1397 const u_char *cp;
1381 u_char *p; 1398 u_char *p;
1382 u_int checksum, stored_checksum; 1399 u_int checksum, stored_checksum;
@@ -1389,9 +1406,12 @@ ssh_packet_read_poll1(struct ssh *ssh, u_char *typep)
1389 return 0; 1406 return 0;
1390 /* Get length of incoming packet. */ 1407 /* Get length of incoming packet. */
1391 len = PEEK_U32(sshbuf_ptr(state->input)); 1408 len = PEEK_U32(sshbuf_ptr(state->input));
1392 if (len < 1 + 2 + 2 || len > 256 * 1024) 1409 if (len < 1 + 2 + 2 || len > 256 * 1024) {
1393 ssh_packet_disconnect(ssh, "Bad packet length %u.", 1410 if ((r = sshpkt_disconnect(ssh, "Bad packet length %u",
1394 len); 1411 len)) != 0)
1412 return r;
1413 return SSH_ERR_CONN_CORRUPT;
1414 }
1395 padded_len = (len + 8) & ~7; 1415 padded_len = (len + 8) & ~7;
1396 1416
1397 /* Check if the packet has been entirely received. */ 1417 /* Check if the packet has been entirely received. */
@@ -1410,19 +1430,27 @@ ssh_packet_read_poll1(struct ssh *ssh, u_char *typep)
1410 * Ariel Futoransky(futo@core-sdi.com) 1430 * Ariel Futoransky(futo@core-sdi.com)
1411 */ 1431 */
1412 if (!state->receive_context.plaintext) { 1432 if (!state->receive_context.plaintext) {
1433 emsg = NULL;
1413 switch (detect_attack(&state->deattack, 1434 switch (detect_attack(&state->deattack,
1414 sshbuf_ptr(state->input), padded_len)) { 1435 sshbuf_ptr(state->input), padded_len)) {
1415 case DEATTACK_OK: 1436 case DEATTACK_OK:
1416 break; 1437 break;
1417 case DEATTACK_DETECTED: 1438 case DEATTACK_DETECTED:
1418 ssh_packet_disconnect(ssh, 1439 emsg = "crc32 compensation attack detected";
1419 "crc32 compensation attack: network attack detected" 1440 break;
1420 );
1421 case DEATTACK_DOS_DETECTED: 1441 case DEATTACK_DOS_DETECTED:
1422 ssh_packet_disconnect(ssh, 1442 emsg = "deattack denial of service detected";
1423 "deattack denial of service detected"); 1443 break;
1424 default: 1444 default:
1425 ssh_packet_disconnect(ssh, "deattack error"); 1445 emsg = "deattack error";
1446 break;
1447 }
1448 if (emsg != NULL) {
1449 error("%s", emsg);
1450 if ((r = sshpkt_disconnect(ssh, "%s", emsg)) != 0 ||
1451 (r = ssh_packet_write_wait(ssh)) != 0)
1452 return r;
1453 return SSH_ERR_CONN_CORRUPT;
1426 } 1454 }
1427 } 1455 }
1428 1456
@@ -1451,16 +1479,24 @@ ssh_packet_read_poll1(struct ssh *ssh, u_char *typep)
1451 goto out; 1479 goto out;
1452 1480
1453 /* Test check bytes. */ 1481 /* Test check bytes. */
1454 if (len != sshbuf_len(state->incoming_packet)) 1482 if (len != sshbuf_len(state->incoming_packet)) {
1455 ssh_packet_disconnect(ssh, 1483 error("%s: len %d != sshbuf_len %zd", __func__,
1456 "packet_read_poll1: len %d != sshbuf_len %zd.",
1457 len, sshbuf_len(state->incoming_packet)); 1484 len, sshbuf_len(state->incoming_packet));
1485 if ((r = sshpkt_disconnect(ssh, "invalid packet length")) != 0 ||
1486 (r = ssh_packet_write_wait(ssh)) != 0)
1487 return r;
1488 return SSH_ERR_CONN_CORRUPT;
1489 }
1458 1490
1459 cp = sshbuf_ptr(state->incoming_packet) + len - 4; 1491 cp = sshbuf_ptr(state->incoming_packet) + len - 4;
1460 stored_checksum = PEEK_U32(cp); 1492 stored_checksum = PEEK_U32(cp);
1461 if (checksum != stored_checksum) 1493 if (checksum != stored_checksum) {
1462 ssh_packet_disconnect(ssh, 1494 error("Corrupted check bytes on input");
1463 "Corrupted check bytes on input."); 1495 if ((r = sshpkt_disconnect(ssh, "connection corrupted")) != 0 ||
1496 (r = ssh_packet_write_wait(ssh)) != 0)
1497 return r;
1498 return SSH_ERR_CONN_CORRUPT;
1499 }
1464 if ((r = sshbuf_consume_end(state->incoming_packet, 4)) < 0) 1500 if ((r = sshbuf_consume_end(state->incoming_packet, 4)) < 0)
1465 goto out; 1501 goto out;
1466 1502
@@ -1478,9 +1514,13 @@ ssh_packet_read_poll1(struct ssh *ssh, u_char *typep)
1478 state->p_read.bytes += padded_len + 4; 1514 state->p_read.bytes += padded_len + 4;
1479 if ((r = sshbuf_get_u8(state->incoming_packet, typep)) != 0) 1515 if ((r = sshbuf_get_u8(state->incoming_packet, typep)) != 0)
1480 goto out; 1516 goto out;
1481 if (*typep < SSH_MSG_MIN || *typep > SSH_MSG_MAX) 1517 if (*typep < SSH_MSG_MIN || *typep > SSH_MSG_MAX) {
1482 ssh_packet_disconnect(ssh, 1518 error("Invalid ssh1 packet type: %d", *typep);
1483 "Invalid ssh1 packet type: %d", *typep); 1519 if ((r = sshpkt_disconnect(ssh, "invalid packet type")) != 0 ||
1520 (r = ssh_packet_write_wait(ssh)) != 0)
1521 return r;
1522 return SSH_ERR_PROTOCOL_ERROR;
1523 }
1484 r = 0; 1524 r = 0;
1485 out: 1525 out:
1486 return r; 1526 return r;
@@ -1634,7 +1674,6 @@ ssh_packet_read_poll2(struct ssh *ssh, u_char *typep, u_int32_t *seqnr_p)
1634 if ((r = sshbuf_consume(state->input, mac->mac_len)) != 0) 1674 if ((r = sshbuf_consume(state->input, mac->mac_len)) != 0)
1635 goto out; 1675 goto out;
1636 } 1676 }
1637 /* XXX now it's safe to use fatal/packet_disconnect */
1638 if (seqnr_p != NULL) 1677 if (seqnr_p != NULL)
1639 *seqnr_p = state->p_read.seqnr; 1678 *seqnr_p = state->p_read.seqnr;
1640 if (++state->p_read.seqnr == 0) 1679 if (++state->p_read.seqnr == 0)
@@ -1648,9 +1687,13 @@ ssh_packet_read_poll2(struct ssh *ssh, u_char *typep, u_int32_t *seqnr_p)
1648 /* get padlen */ 1687 /* get padlen */
1649 padlen = sshbuf_ptr(state->incoming_packet)[4]; 1688 padlen = sshbuf_ptr(state->incoming_packet)[4];
1650 DBG(debug("input: padlen %d", padlen)); 1689 DBG(debug("input: padlen %d", padlen));
1651 if (padlen < 4) 1690 if (padlen < 4) {
1652 ssh_packet_disconnect(ssh, 1691 if ((r = sshpkt_disconnect(ssh,
1653 "Corrupted padlen %d on input.", padlen); 1692 "Corrupted padlen %d on input.", padlen)) != 0 ||
1693 (r = ssh_packet_write_wait(ssh)) != 0)
1694 return r;
1695 return SSH_ERR_CONN_CORRUPT;
1696 }
1654 1697
1655 /* skip packet size + padlen, discard padding */ 1698 /* skip packet size + padlen, discard padding */
1656 if ((r = sshbuf_consume(state->incoming_packet, 4 + 1)) != 0 || 1699 if ((r = sshbuf_consume(state->incoming_packet, 4 + 1)) != 0 ||
@@ -1677,9 +1720,13 @@ ssh_packet_read_poll2(struct ssh *ssh, u_char *typep, u_int32_t *seqnr_p)
1677 */ 1720 */
1678 if ((r = sshbuf_get_u8(state->incoming_packet, typep)) != 0) 1721 if ((r = sshbuf_get_u8(state->incoming_packet, typep)) != 0)
1679 goto out; 1722 goto out;
1680 if (*typep < SSH2_MSG_MIN || *typep >= SSH2_MSG_LOCAL_MIN) 1723 if (*typep < SSH2_MSG_MIN || *typep >= SSH2_MSG_LOCAL_MIN) {
1681 ssh_packet_disconnect(ssh, 1724 if ((r = sshpkt_disconnect(ssh,
1682 "Invalid ssh2 packet type: %d", *typep); 1725 "Invalid ssh2 packet type: %d", *typep)) != 0 ||
1726 (r = ssh_packet_write_wait(ssh)) != 0)
1727 return r;
1728 return SSH_ERR_PROTOCOL_ERROR;
1729 }
1683 if (*typep == SSH2_MSG_NEWKEYS) 1730 if (*typep == SSH2_MSG_NEWKEYS)
1684 r = ssh_set_newkeys(ssh, MODE_IN); 1731 r = ssh_set_newkeys(ssh, MODE_IN);
1685 else if (*typep == SSH2_MSG_USERAUTH_SUCCESS && !state->server_side) 1732 else if (*typep == SSH2_MSG_USERAUTH_SUCCESS && !state->server_side)
@@ -1816,9 +1863,8 @@ ssh_packet_remaining(struct ssh *ssh)
1816 * message is printed immediately, but only if the client is being executed 1863 * message is printed immediately, but only if the client is being executed
1817 * in verbose mode. These messages are primarily intended to ease debugging 1864 * in verbose mode. These messages are primarily intended to ease debugging
1818 * authentication problems. The length of the formatted message must not 1865 * authentication problems. The length of the formatted message must not
1819 * exceed 1024 bytes. This will automatically call packet_write_wait. 1866 * exceed 1024 bytes. This will automatically call ssh_packet_write_wait.
1820 */ 1867 */
1821
1822void 1868void
1823ssh_packet_send_debug(struct ssh *ssh, const char *fmt,...) 1869ssh_packet_send_debug(struct ssh *ssh, const char *fmt,...)
1824{ 1870{
@@ -1846,7 +1892,29 @@ ssh_packet_send_debug(struct ssh *ssh, const char *fmt,...)
1846 (r = sshpkt_send(ssh)) != 0) 1892 (r = sshpkt_send(ssh)) != 0)
1847 fatal("%s: %s", __func__, ssh_err(r)); 1893 fatal("%s: %s", __func__, ssh_err(r));
1848 } 1894 }
1849 ssh_packet_write_wait(ssh); 1895 if ((r = ssh_packet_write_wait(ssh)) != 0)
1896 fatal("%s: %s", __func__, ssh_err(r));
1897}
1898
1899/*
1900 * Pretty-print connection-terminating errors and exit.
1901 */
1902void
1903sshpkt_fatal(struct ssh *ssh, const char *tag, int r)
1904{
1905 switch (r) {
1906 case SSH_ERR_CONN_CLOSED:
1907 logit("Connection closed by %.200s", ssh_remote_ipaddr(ssh));
1908 cleanup_exit(255);
1909 case SSH_ERR_CONN_TIMEOUT:
1910 logit("Connection to %.200s timed out while "
1911 "waiting to write", ssh_remote_ipaddr(ssh));
1912 cleanup_exit(255);
1913 default:
1914 fatal("%s%sConnection to %.200s: %s",
1915 tag != NULL ? tag : "", tag != NULL ? ": " : "",
1916 ssh_remote_ipaddr(ssh), ssh_err(r));
1917 }
1850} 1918}
1851 1919
1852/* 1920/*
@@ -1855,7 +1923,6 @@ ssh_packet_send_debug(struct ssh *ssh, const char *fmt,...)
1855 * should not contain a newline. The length of the formatted message must 1923 * should not contain a newline. The length of the formatted message must
1856 * not exceed 1024 bytes. 1924 * not exceed 1024 bytes.
1857 */ 1925 */
1858
1859void 1926void
1860ssh_packet_disconnect(struct ssh *ssh, const char *fmt,...) 1927ssh_packet_disconnect(struct ssh *ssh, const char *fmt,...)
1861{ 1928{
@@ -1879,30 +1946,26 @@ ssh_packet_disconnect(struct ssh *ssh, const char *fmt,...)
1879 /* Display the error locally */ 1946 /* Display the error locally */
1880 logit("Disconnecting: %.100s", buf); 1947 logit("Disconnecting: %.100s", buf);
1881 1948
1882 /* Send the disconnect message to the other side, and wait for it to get sent. */ 1949 /*
1883 if (compat20) { 1950 * Send the disconnect message to the other side, and wait
1884 if ((r = sshpkt_start(ssh, SSH2_MSG_DISCONNECT)) != 0 || 1951 * for it to get sent.
1885 (r = sshpkt_put_u32(ssh, SSH2_DISCONNECT_PROTOCOL_ERROR)) != 0 || 1952 */
1886 (r = sshpkt_put_cstring(ssh, buf)) != 0 || 1953 if ((r = sshpkt_disconnect(ssh, "%s", buf)) != 0)
1887 (r = sshpkt_put_cstring(ssh, "")) != 0 || 1954 sshpkt_fatal(ssh, __func__, r);
1888 (r = sshpkt_send(ssh)) != 0) 1955
1889 fatal("%s: %s", __func__, ssh_err(r)); 1956 if ((r = ssh_packet_write_wait(ssh)) != 0)
1890 } else { 1957 sshpkt_fatal(ssh, __func__, r);
1891 if ((r = sshpkt_start(ssh, SSH_MSG_DISCONNECT)) != 0 ||
1892 (r = sshpkt_put_cstring(ssh, buf)) != 0 ||
1893 (r = sshpkt_send(ssh)) != 0)
1894 fatal("%s: %s", __func__, ssh_err(r));
1895 }
1896 ssh_packet_write_wait(ssh);
1897 1958
1898 /* Close the connection. */ 1959 /* Close the connection. */
1899 ssh_packet_close(ssh); 1960 ssh_packet_close(ssh);
1900 cleanup_exit(255); 1961 cleanup_exit(255);
1901} 1962}
1902 1963
1903/* Checks if there is any buffered output, and tries to write some of the output. */ 1964/*
1904 1965 * Checks if there is any buffered output, and tries to write some of
1905void 1966 * the output.
1967 */
1968int
1906ssh_packet_write_poll(struct ssh *ssh) 1969ssh_packet_write_poll(struct ssh *ssh)
1907{ 1970{
1908 struct session_state *state = ssh->state; 1971 struct session_state *state = ssh->state;
@@ -1916,33 +1979,33 @@ ssh_packet_write_poll(struct ssh *ssh)
1916 if (len == -1) { 1979 if (len == -1) {
1917 if (errno == EINTR || errno == EAGAIN || 1980 if (errno == EINTR || errno == EAGAIN ||
1918 errno == EWOULDBLOCK) 1981 errno == EWOULDBLOCK)
1919 return; 1982 return 0;
1920 fatal("Write failed: %.100s", strerror(errno)); 1983 return SSH_ERR_SYSTEM_ERROR;
1921 } 1984 }
1922 if (len == 0 && !cont) 1985 if (len == 0 && !cont)
1923 fatal("Write connection closed"); 1986 return SSH_ERR_CONN_CLOSED;
1924 if ((r = sshbuf_consume(state->output, len)) != 0) 1987 if ((r = sshbuf_consume(state->output, len)) != 0)
1925 fatal("%s: %s", __func__, ssh_err(r)); 1988 return r;
1926 } 1989 }
1990 return 0;
1927} 1991}
1928 1992
1929/* 1993/*
1930 * Calls packet_write_poll repeatedly until all pending output data has been 1994 * Calls packet_write_poll repeatedly until all pending output data has been
1931 * written. 1995 * written.
1932 */ 1996 */
1933 1997int
1934void
1935ssh_packet_write_wait(struct ssh *ssh) 1998ssh_packet_write_wait(struct ssh *ssh)
1936{ 1999{
1937 fd_set *setp; 2000 fd_set *setp;
1938 int ret, ms_remain = 0; 2001 int ret, r, ms_remain = 0;
1939 struct timeval start, timeout, *timeoutp = NULL; 2002 struct timeval start, timeout, *timeoutp = NULL;
1940 struct session_state *state = ssh->state; 2003 struct session_state *state = ssh->state;
1941 2004
1942 setp = (fd_set *)calloc(howmany(state->connection_out + 1, 2005 setp = (fd_set *)calloc(howmany(state->connection_out + 1,
1943 NFDBITS), sizeof(fd_mask)); 2006 NFDBITS), sizeof(fd_mask));
1944 if (setp == NULL) 2007 if (setp == NULL)
1945 fatal("%s: calloc failed", __func__); 2008 return SSH_ERR_ALLOC_FAIL;
1946 ssh_packet_write_poll(ssh); 2009 ssh_packet_write_poll(ssh);
1947 while (ssh_packet_have_data_to_write(ssh)) { 2010 while (ssh_packet_have_data_to_write(ssh)) {
1948 memset(setp, 0, howmany(state->connection_out + 1, 2011 memset(setp, 0, howmany(state->connection_out + 1,
@@ -1973,13 +2036,16 @@ ssh_packet_write_wait(struct ssh *ssh)
1973 } 2036 }
1974 } 2037 }
1975 if (ret == 0) { 2038 if (ret == 0) {
1976 logit("Connection to %.200s timed out while " 2039 free(setp);
1977 "waiting to write", ssh_remote_ipaddr(ssh)); 2040 return SSH_ERR_CONN_TIMEOUT;
1978 cleanup_exit(255); 2041 }
2042 if ((r = ssh_packet_write_poll(ssh)) != 0) {
2043 free(setp);
2044 return r;
1979 } 2045 }
1980 ssh_packet_write_poll(ssh);
1981 } 2046 }
1982 free(setp); 2047 free(setp);
2048 return 0;
1983} 2049}
1984 2050
1985/* Returns true if there is buffered data to write to the connection. */ 2051/* Returns true if there is buffered data to write to the connection. */
diff --git a/packet.h b/packet.h
index 8a9d0f6c6..01df9f413 100644
--- a/packet.h
+++ b/packet.h
@@ -1,4 +1,4 @@
1/* $OpenBSD: packet.h,v 1.65 2015/01/28 21:15:47 djm Exp $ */ 1/* $OpenBSD: packet.h,v 1.66 2015/01/30 01:13:33 djm Exp $ */
2 2
3/* 3/*
4 * Author: Tatu Ylonen <ylo@cs.hut.fi> 4 * Author: Tatu Ylonen <ylo@cs.hut.fi>
@@ -90,7 +90,7 @@ int ssh_packet_send2_wrapped(struct ssh *);
90int ssh_packet_send2(struct ssh *); 90int ssh_packet_send2(struct ssh *);
91 91
92int ssh_packet_read(struct ssh *); 92int ssh_packet_read(struct ssh *);
93void ssh_packet_read_expect(struct ssh *, int type); 93int ssh_packet_read_expect(struct ssh *, u_int type);
94int ssh_packet_read_poll(struct ssh *); 94int ssh_packet_read_poll(struct ssh *);
95int ssh_packet_read_poll1(struct ssh *, u_char *); 95int ssh_packet_read_poll1(struct ssh *, u_char *);
96int ssh_packet_read_poll2(struct ssh *, u_char *, u_int32_t *seqnr_p); 96int ssh_packet_read_poll2(struct ssh *, u_char *, u_int32_t *seqnr_p);
@@ -112,8 +112,8 @@ typedef void (ssh_packet_comp_free_func)(void *, void *);
112void ssh_packet_set_compress_hooks(struct ssh *, void *, 112void ssh_packet_set_compress_hooks(struct ssh *, void *,
113 ssh_packet_comp_alloc_func *, ssh_packet_comp_free_func *); 113 ssh_packet_comp_alloc_func *, ssh_packet_comp_free_func *);
114 114
115void ssh_packet_write_poll(struct ssh *); 115int ssh_packet_write_poll(struct ssh *);
116void ssh_packet_write_wait(struct ssh *); 116int ssh_packet_write_wait(struct ssh *);
117int ssh_packet_have_data_to_write(struct ssh *); 117int ssh_packet_have_data_to_write(struct ssh *);
118int ssh_packet_not_very_much_data_to_write(struct ssh *); 118int ssh_packet_not_very_much_data_to_write(struct ssh *);
119 119
@@ -148,8 +148,10 @@ void *ssh_packet_get_output(struct ssh *);
148/* new API */ 148/* new API */
149int sshpkt_start(struct ssh *ssh, u_char type); 149int sshpkt_start(struct ssh *ssh, u_char type);
150int sshpkt_send(struct ssh *ssh); 150int sshpkt_send(struct ssh *ssh);
151int sshpkt_disconnect(struct ssh *, const char *fmt, ...) __attribute__((format(printf, 2, 3))); 151int sshpkt_disconnect(struct ssh *, const char *fmt, ...)
152 __attribute__((format(printf, 2, 3)));
152int sshpkt_add_padding(struct ssh *, u_char); 153int sshpkt_add_padding(struct ssh *, u_char);
154void sshpkt_fatal(struct ssh *ssh, const char *tag, int r);
153 155
154int sshpkt_put(struct ssh *ssh, const void *v, size_t len); 156int sshpkt_put(struct ssh *ssh, const void *v, size_t len);
155int sshpkt_putb(struct ssh *ssh, const struct sshbuf *b); 157int sshpkt_putb(struct ssh *ssh, const struct sshbuf *b);
diff --git a/ssh-keyscan.c b/ssh-keyscan.c
index e59eacace..989f7ecce 100644
--- a/ssh-keyscan.c
+++ b/ssh-keyscan.c
@@ -1,4 +1,4 @@
1/* $OpenBSD: ssh-keyscan.c,v 1.97 2015/01/28 21:15:47 djm Exp $ */ 1/* $OpenBSD: ssh-keyscan.c,v 1.98 2015/01/30 01:13:33 djm Exp $ */
2/* 2/*
3 * Copyright 1995, 1996 by David Mazieres <dm@lcs.mit.edu>. 3 * Copyright 1995, 1996 by David Mazieres <dm@lcs.mit.edu>.
4 * 4 *
@@ -466,7 +466,8 @@ congreet(int s)
466 return; 466 return;
467 } 467 }
468 *cp = '\0'; 468 *cp = '\0';
469 c->c_ssh = ssh_packet_set_connection(NULL, s, s); 469 if ((c->c_ssh = ssh_packet_set_connection(NULL, s, s)) == NULL)
470 fatal("ssh_packet_set_connection failed");
470 ssh_set_app_data(c->c_ssh, c); /* back link */ 471 ssh_set_app_data(c->c_ssh, c); /* back link */
471 if (sscanf(buf, "SSH-%d.%d-%[^\n]\n", 472 if (sscanf(buf, "SSH-%d.%d-%[^\n]\n",
472 &remote_major, &remote_minor, remote_version) == 3) 473 &remote_major, &remote_minor, remote_version) == 3)
diff --git a/ssh_api.c b/ssh_api.c
index 9794e0e57..7097c063c 100644
--- a/ssh_api.c
+++ b/ssh_api.c
@@ -1,4 +1,4 @@
1/* $OpenBSD: ssh_api.c,v 1.2 2015/01/26 06:10:03 djm Exp $ */ 1/* $OpenBSD: ssh_api.c,v 1.3 2015/01/30 01:13:33 djm Exp $ */
2/* 2/*
3 * Copyright (c) 2012 Markus Friedl. All rights reserved. 3 * Copyright (c) 2012 Markus Friedl. All rights reserved.
4 * 4 *
@@ -85,7 +85,8 @@ ssh_init(struct ssh **sshp, int is_server, struct kex_params *kex_params)
85 called = 1; 85 called = 1;
86 } 86 }
87 87
88 ssh = ssh_packet_set_connection(NULL, -1, -1); 88 if ((ssh = ssh_packet_set_connection(NULL, -1, -1)) == NULL)
89 return SSH_ERR_ALLOC_FAIL;
89 if (is_server) 90 if (is_server)
90 ssh_packet_set_server(ssh); 91 ssh_packet_set_server(ssh);
91 92
diff --git a/ssherr.c b/ssherr.c
index 0b79fbb00..5c29c467c 100644
--- a/ssherr.c
+++ b/ssherr.c
@@ -1,4 +1,4 @@
1/* $OpenBSD: ssherr.c,v 1.2 2015/01/28 21:15:47 djm Exp $ */ 1/* $OpenBSD: ssherr.c,v 1.3 2015/01/30 01:13:33 djm Exp $ */
2/* 2/*
3 * Copyright (c) 2011 Damien Miller 3 * Copyright (c) 2011 Damien Miller
4 * 4 *
@@ -129,6 +129,10 @@ ssh_err(int n)
129 return "Connection closed"; 129 return "Connection closed";
130 case SSH_ERR_CONN_TIMEOUT: 130 case SSH_ERR_CONN_TIMEOUT:
131 return "Connection timed out"; 131 return "Connection timed out";
132 case SSH_ERR_CONN_CORRUPT:
133 return "Connection corrupted";
134 case SSH_ERR_PROTOCOL_ERROR:
135 return "Protocol error";
132 default: 136 default:
133 return "unknown error"; 137 return "unknown error";
134 } 138 }
diff --git a/ssherr.h b/ssherr.h
index ac5cd159a..6f771b4b7 100644
--- a/ssherr.h
+++ b/ssherr.h
@@ -1,4 +1,4 @@
1/* $OpenBSD: ssherr.h,v 1.2 2015/01/28 21:15:47 djm Exp $ */ 1/* $OpenBSD: ssherr.h,v 1.3 2015/01/30 01:13:33 djm Exp $ */
2/* 2/*
3 * Copyright (c) 2011 Damien Miller 3 * Copyright (c) 2011 Damien Miller
4 * 4 *
@@ -75,6 +75,8 @@
75#define SSH_ERR_KEY_REVOKED -51 75#define SSH_ERR_KEY_REVOKED -51
76#define SSH_ERR_CONN_CLOSED -52 76#define SSH_ERR_CONN_CLOSED -52
77#define SSH_ERR_CONN_TIMEOUT -53 77#define SSH_ERR_CONN_TIMEOUT -53
78#define SSH_ERR_CONN_CORRUPT -54
79#define SSH_ERR_PROTOCOL_ERROR -55
78 80
79/* Translate a numeric error code to a human-readable error string */ 81/* Translate a numeric error code to a human-readable error string */
80const char *ssh_err(int n); 82const char *ssh_err(int n);