summaryrefslogtreecommitdiff
path: root/ssh-agent.c
diff options
context:
space:
mode:
authordjm@openbsd.org <djm@openbsd.org>2017-07-19 01:15:02 +0000
committerDamien Miller <djm@mindrot.org>2017-07-21 14:17:33 +1000
commitfd0e8fa5f89d21290b1fb5f9d110ca4f113d81d9 (patch)
treea9b803cc12096cf74eabe13ff7dab974ad3bd09c /ssh-agent.c
parentb1e72df2b813ecc15bd0152167bf4af5f91c36d3 (diff)
upstream commit
switch from select() to poll() for the ssh-agent mainloop; ok markus Upstream-ID: 4a94888ee67b3fd948fd10693973beb12f802448
Diffstat (limited to 'ssh-agent.c')
-rw-r--r--ssh-agent.c312
1 files changed, 185 insertions, 127 deletions
diff --git a/ssh-agent.c b/ssh-agent.c
index eb8c2043d..d858c2470 100644
--- a/ssh-agent.c
+++ b/ssh-agent.c
@@ -1,4 +1,4 @@
1/* $OpenBSD: ssh-agent.c,v 1.222 2017/07/01 13:50:45 djm Exp $ */ 1/* $OpenBSD: ssh-agent.c,v 1.223 2017/07/19 01:15:02 djm Exp $ */
2/* 2/*
3 * Author: Tatu Ylonen <ylo@cs.hut.fi> 3 * Author: Tatu Ylonen <ylo@cs.hut.fi>
4 * Copyright (c) 1995 Tatu Ylonen <ylo@cs.hut.fi>, Espoo, Finland 4 * Copyright (c) 1995 Tatu Ylonen <ylo@cs.hut.fi>, Espoo, Finland
@@ -60,6 +60,9 @@
60#ifdef HAVE_PATHS_H 60#ifdef HAVE_PATHS_H
61# include <paths.h> 61# include <paths.h>
62#endif 62#endif
63#ifdef HAVE_POLL_H
64# include <poll.h>
65#endif
63#include <signal.h> 66#include <signal.h>
64#include <stdarg.h> 67#include <stdarg.h>
65#include <stdio.h> 68#include <stdio.h>
@@ -91,6 +94,9 @@
91# define DEFAULT_PKCS11_WHITELIST "/usr/lib*/*,/usr/local/lib*/*" 94# define DEFAULT_PKCS11_WHITELIST "/usr/lib*/*,/usr/local/lib*/*"
92#endif 95#endif
93 96
97/* Maximum accepted message length */
98#define AGENT_MAX_LEN (256*1024)
99
94typedef enum { 100typedef enum {
95 AUTH_UNUSED, 101 AUTH_UNUSED,
96 AUTH_SOCKET, 102 AUTH_SOCKET,
@@ -634,30 +640,46 @@ send:
634 640
635/* dispatch incoming messages */ 641/* dispatch incoming messages */
636 642
637static void 643static int
638process_message(SocketEntry *e) 644process_message(u_int socknum)
639{ 645{
640 u_int msg_len; 646 u_int msg_len;
641 u_char type; 647 u_char type;
642 const u_char *cp; 648 const u_char *cp;
643 int r; 649 int r;
650 SocketEntry *e;
651
652 if (socknum >= sockets_alloc) {
653 fatal("%s: socket number %u >= allocated %u",
654 __func__, socknum, sockets_alloc);
655 }
656 e = &sockets[socknum];
644 657
645 if (sshbuf_len(e->input) < 5) 658 if (sshbuf_len(e->input) < 5)
646 return; /* Incomplete message. */ 659 return 0; /* Incomplete message header. */
647 cp = sshbuf_ptr(e->input); 660 cp = sshbuf_ptr(e->input);
648 msg_len = PEEK_U32(cp); 661 msg_len = PEEK_U32(cp);
649 if (msg_len > 256 * 1024) { 662 if (msg_len > AGENT_MAX_LEN) {
650 close_socket(e); 663 debug("%s: socket %u (fd=%d) message too long %u > %u",
651 return; 664 __func__, socknum, e->fd, msg_len, AGENT_MAX_LEN);
665 return -1;
652 } 666 }
653 if (sshbuf_len(e->input) < msg_len + 4) 667 if (sshbuf_len(e->input) < msg_len + 4)
654 return; 668 return 0; /* Incomplete message body. */
655 669
656 /* move the current input to e->request */ 670 /* move the current input to e->request */
657 sshbuf_reset(e->request); 671 sshbuf_reset(e->request);
658 if ((r = sshbuf_get_stringb(e->input, e->request)) != 0 || 672 if ((r = sshbuf_get_stringb(e->input, e->request)) != 0 ||
659 (r = sshbuf_get_u8(e->request, &type)) != 0) 673 (r = sshbuf_get_u8(e->request, &type)) != 0) {
674 if (r == SSH_ERR_MESSAGE_INCOMPLETE ||
675 r == SSH_ERR_STRING_TOO_LARGE) {
676 debug("%s: buffer error: %s", __func__, ssh_err(r));
677 return -1;
678 }
660 fatal("%s: buffer error: %s", __func__, ssh_err(r)); 679 fatal("%s: buffer error: %s", __func__, ssh_err(r));
680 }
681
682 debug("%s: socket %u (fd=%d) type %d", __func__, socknum, e->fd, type);
661 683
662 /* check wheter agent is locked */ 684 /* check wheter agent is locked */
663 if (locked && type != SSH_AGENTC_UNLOCK) { 685 if (locked && type != SSH_AGENTC_UNLOCK) {
@@ -671,10 +693,9 @@ process_message(SocketEntry *e)
671 /* send a fail message for all other request types */ 693 /* send a fail message for all other request types */
672 send_status(e, 0); 694 send_status(e, 0);
673 } 695 }
674 return; 696 return 0;
675 } 697 }
676 698
677 debug("type %d", type);
678 switch (type) { 699 switch (type) {
679 case SSH_AGENTC_LOCK: 700 case SSH_AGENTC_LOCK:
680 case SSH_AGENTC_UNLOCK: 701 case SSH_AGENTC_UNLOCK:
@@ -716,6 +737,7 @@ process_message(SocketEntry *e)
716 send_status(e, 0); 737 send_status(e, 0);
717 break; 738 break;
718 } 739 }
740 return 0;
719} 741}
720 742
721static void 743static void
@@ -757,19 +779,141 @@ new_socket(sock_type type, int fd)
757} 779}
758 780
759static int 781static int
760prepare_select(fd_set **fdrp, fd_set **fdwp, int *fdl, u_int *nallocp, 782handle_socket_read(u_int socknum)
761 struct timeval **tvpp) 783{
784 struct sockaddr_un sunaddr;
785 socklen_t slen;
786 uid_t euid;
787 gid_t egid;
788 int fd;
789
790 slen = sizeof(sunaddr);
791 fd = accept(sockets[socknum].fd, (struct sockaddr *)&sunaddr, &slen);
792 if (fd < 0) {
793 error("accept from AUTH_SOCKET: %s", strerror(errno));
794 return -1;
795 }
796 if (getpeereid(fd, &euid, &egid) < 0) {
797 error("getpeereid %d failed: %s", fd, strerror(errno));
798 close(fd);
799 return -1;
800 }
801 if ((euid != 0) && (getuid() != euid)) {
802 error("uid mismatch: peer euid %u != uid %u",
803 (u_int) euid, (u_int) getuid());
804 close(fd);
805 return -1;
806 }
807 new_socket(AUTH_CONNECTION, fd);
808 return 0;
809}
810
811static int
812handle_conn_read(u_int socknum)
813{
814 char buf[1024];
815 ssize_t len;
816 int r;
817
818 if ((len = read(sockets[socknum].fd, buf, sizeof(buf))) <= 0) {
819 if (len == -1) {
820 if (errno == EAGAIN || errno == EINTR)
821 return 0;
822 error("%s: read error on socket %u (fd %d): %s",
823 __func__, socknum, sockets[socknum].fd,
824 strerror(errno));
825 }
826 return -1;
827 }
828 if ((r = sshbuf_put(sockets[socknum].input, buf, len)) != 0)
829 fatal("%s: buffer error: %s", __func__, ssh_err(r));
830 explicit_bzero(buf, sizeof(buf));
831 process_message(socknum);
832 return 0;
833}
834
835static int
836handle_conn_write(u_int socknum)
837{
838 ssize_t len;
839 int r;
840
841 if (sshbuf_len(sockets[socknum].output) == 0)
842 return 0; /* shouldn't happen */
843 if ((len = write(sockets[socknum].fd,
844 sshbuf_ptr(sockets[socknum].output),
845 sshbuf_len(sockets[socknum].output))) <= 0) {
846 if (len == -1) {
847 if (errno == EAGAIN || errno == EINTR)
848 return 0;
849 error("%s: read error on socket %u (fd %d): %s",
850 __func__, socknum, sockets[socknum].fd,
851 strerror(errno));
852 }
853 return -1;
854 }
855 if ((r = sshbuf_consume(sockets[socknum].output, len)) != 0)
856 fatal("%s: buffer error: %s", __func__, ssh_err(r));
857 return 0;
858}
859
860static void
861after_poll(struct pollfd *pfd, size_t npfd)
762{ 862{
763 u_int i, sz; 863 size_t i;
764 int n = 0; 864 u_int socknum;
765 static struct timeval tv; 865
866 for (i = 0; i < npfd; i++) {
867 if (pfd[i].revents == 0)
868 continue;
869 /* Find sockets entry */
870 for (socknum = 0; socknum < sockets_alloc; socknum++) {
871 if (sockets[socknum].type != AUTH_SOCKET &&
872 sockets[socknum].type != AUTH_CONNECTION)
873 continue;
874 if (pfd[i].fd == sockets[socknum].fd)
875 break;
876 }
877 if (socknum >= sockets_alloc) {
878 error("%s: no socket for fd %d", __func__, pfd[i].fd);
879 continue;
880 }
881 /* Process events */
882 switch (sockets[socknum].type) {
883 case AUTH_SOCKET:
884 if ((pfd[i].revents & (POLLIN|POLLERR)) != 0 &&
885 handle_socket_read(socknum) != 0)
886 close_socket(&sockets[socknum]);
887 break;
888 case AUTH_CONNECTION:
889 if ((pfd[i].revents & (POLLIN|POLLERR)) != 0 &&
890 handle_conn_read(socknum) != 0) {
891 close_socket(&sockets[socknum]);
892 break;
893 }
894 if ((pfd[i].revents & (POLLOUT|POLLHUP)) != 0 &&
895 handle_conn_write(socknum) != 0)
896 close_socket(&sockets[socknum]);
897 break;
898 default:
899 break;
900 }
901 }
902}
903
904static int
905prepare_poll(struct pollfd **pfdp, size_t *npfdp, int *timeoutp)
906{
907 struct pollfd *pfd = *pfdp;
908 size_t i, j, npfd = 0;
766 time_t deadline; 909 time_t deadline;
767 910
911 /* Count active sockets */
768 for (i = 0; i < sockets_alloc; i++) { 912 for (i = 0; i < sockets_alloc; i++) {
769 switch (sockets[i].type) { 913 switch (sockets[i].type) {
770 case AUTH_SOCKET: 914 case AUTH_SOCKET:
771 case AUTH_CONNECTION: 915 case AUTH_CONNECTION:
772 n = MAXIMUM(n, sockets[i].fd); 916 npfd++;
773 break; 917 break;
774 case AUTH_UNUSED: 918 case AUTH_UNUSED:
775 break; 919 break;
@@ -778,28 +922,23 @@ prepare_select(fd_set **fdrp, fd_set **fdwp, int *fdl, u_int *nallocp,
778 break; 922 break;
779 } 923 }
780 } 924 }
925 if (npfd != *npfdp &&
926 (pfd = recallocarray(pfd, *npfdp, npfd, sizeof(*pfd))) == NULL)
927 fatal("%s: recallocarray failed", __func__);
928 *pfdp = pfd;
929 *npfdp = npfd;
781 930
782 sz = howmany(n+1, NFDBITS) * sizeof(fd_mask); 931 for (i = j = 0; i < sockets_alloc; i++) {
783 if (*fdrp == NULL || sz > *nallocp) {
784 free(*fdrp);
785 free(*fdwp);
786 *fdrp = xmalloc(sz);
787 *fdwp = xmalloc(sz);
788 *nallocp = sz;
789 }
790 if (n < *fdl)
791 debug("XXX shrink: %d < %d", n, *fdl);
792 *fdl = n;
793 memset(*fdrp, 0, sz);
794 memset(*fdwp, 0, sz);
795
796 for (i = 0; i < sockets_alloc; i++) {
797 switch (sockets[i].type) { 932 switch (sockets[i].type) {
798 case AUTH_SOCKET: 933 case AUTH_SOCKET:
799 case AUTH_CONNECTION: 934 case AUTH_CONNECTION:
800 FD_SET(sockets[i].fd, *fdrp); 935 pfd[j].fd = sockets[i].fd;
936 pfd[j].revents = 0;
937 /* XXX backoff when input buffer full */
938 pfd[j].events = POLLIN;
801 if (sshbuf_len(sockets[i].output) > 0) 939 if (sshbuf_len(sockets[i].output) > 0)
802 FD_SET(sockets[i].fd, *fdwp); 940 pfd[j].events |= POLLOUT;
941 j++;
803 break; 942 break;
804 default: 943 default:
805 break; 944 break;
@@ -810,99 +949,17 @@ prepare_select(fd_set **fdrp, fd_set **fdwp, int *fdl, u_int *nallocp,
810 deadline = (deadline == 0) ? parent_alive_interval : 949 deadline = (deadline == 0) ? parent_alive_interval :
811 MINIMUM(deadline, parent_alive_interval); 950 MINIMUM(deadline, parent_alive_interval);
812 if (deadline == 0) { 951 if (deadline == 0) {
813 *tvpp = NULL; 952 *timeoutp = INFTIM;
814 } else { 953 } else {
815 tv.tv_sec = deadline; 954 if (deadline > INT_MAX / 1000)
816 tv.tv_usec = 0; 955 *timeoutp = INT_MAX / 1000;
817 *tvpp = &tv; 956 else
957 *timeoutp = deadline * 1000;
818 } 958 }
819 return (1); 959 return (1);
820} 960}
821 961
822static void 962static void
823after_select(fd_set *readset, fd_set *writeset)
824{
825 struct sockaddr_un sunaddr;
826 socklen_t slen;
827 char buf[1024];
828 int len, sock, r;
829 u_int i, orig_alloc;
830 uid_t euid;
831 gid_t egid;
832
833 for (i = 0, orig_alloc = sockets_alloc; i < orig_alloc; i++)
834 switch (sockets[i].type) {
835 case AUTH_UNUSED:
836 break;
837 case AUTH_SOCKET:
838 if (FD_ISSET(sockets[i].fd, readset)) {
839 slen = sizeof(sunaddr);
840 sock = accept(sockets[i].fd,
841 (struct sockaddr *)&sunaddr, &slen);
842 if (sock < 0) {
843 error("accept from AUTH_SOCKET: %s",
844 strerror(errno));
845 break;
846 }
847 if (getpeereid(sock, &euid, &egid) < 0) {
848 error("getpeereid %d failed: %s",
849 sock, strerror(errno));
850 close(sock);
851 break;
852 }
853 if ((euid != 0) && (getuid() != euid)) {
854 error("uid mismatch: "
855 "peer euid %u != uid %u",
856 (u_int) euid, (u_int) getuid());
857 close(sock);
858 break;
859 }
860 new_socket(AUTH_CONNECTION, sock);
861 }
862 break;
863 case AUTH_CONNECTION:
864 if (sshbuf_len(sockets[i].output) > 0 &&
865 FD_ISSET(sockets[i].fd, writeset)) {
866 len = write(sockets[i].fd,
867 sshbuf_ptr(sockets[i].output),
868 sshbuf_len(sockets[i].output));
869 if (len == -1 && (errno == EAGAIN ||
870 errno == EWOULDBLOCK ||
871 errno == EINTR))
872 continue;
873 if (len <= 0) {
874 close_socket(&sockets[i]);
875 break;
876 }
877 if ((r = sshbuf_consume(sockets[i].output,
878 len)) != 0)
879 fatal("%s: buffer error: %s",
880 __func__, ssh_err(r));
881 }
882 if (FD_ISSET(sockets[i].fd, readset)) {
883 len = read(sockets[i].fd, buf, sizeof(buf));
884 if (len == -1 && (errno == EAGAIN ||
885 errno == EWOULDBLOCK ||
886 errno == EINTR))
887 continue;
888 if (len <= 0) {
889 close_socket(&sockets[i]);
890 break;
891 }
892 if ((r = sshbuf_put(sockets[i].input,
893 buf, len)) != 0)
894 fatal("%s: buffer error: %s",
895 __func__, ssh_err(r));
896 explicit_bzero(buf, sizeof(buf));
897 process_message(&sockets[i]);
898 }
899 break;
900 default:
901 fatal("Unknown type %d", sockets[i].type);
902 }
903}
904
905static void
906cleanup_socket(void) 963cleanup_socket(void)
907{ 964{
908 if (cleanup_pid != 0 && getpid() != cleanup_pid) 965 if (cleanup_pid != 0 && getpid() != cleanup_pid)
@@ -963,7 +1020,6 @@ main(int ac, char **av)
963 int sock, fd, ch, result, saved_errno; 1020 int sock, fd, ch, result, saved_errno;
964 u_int nalloc; 1021 u_int nalloc;
965 char *shell, *format, *pidstr, *agentsocket = NULL; 1022 char *shell, *format, *pidstr, *agentsocket = NULL;
966 fd_set *readsetp = NULL, *writesetp = NULL;
967#ifdef HAVE_SETRLIMIT 1023#ifdef HAVE_SETRLIMIT
968 struct rlimit rlim; 1024 struct rlimit rlim;
969#endif 1025#endif
@@ -971,9 +1027,11 @@ main(int ac, char **av)
971 extern char *optarg; 1027 extern char *optarg;
972 pid_t pid; 1028 pid_t pid;
973 char pidstrbuf[1 + 3 * sizeof pid]; 1029 char pidstrbuf[1 + 3 * sizeof pid];
974 struct timeval *tvp = NULL;
975 size_t len; 1030 size_t len;
976 mode_t prev_mask; 1031 mode_t prev_mask;
1032 int timeout = INFTIM;
1033 struct pollfd *pfd = NULL;
1034 size_t npfd = 0;
977 1035
978 ssh_malloc_init(); /* must be called before any mallocs */ 1036 ssh_malloc_init(); /* must be called before any mallocs */
979 /* Ensure that fds 0, 1 and 2 are open or directed to /dev/null */ 1037 /* Ensure that fds 0, 1 and 2 are open or directed to /dev/null */
@@ -1201,8 +1259,8 @@ skip:
1201 platform_pledge_agent(); 1259 platform_pledge_agent();
1202 1260
1203 while (1) { 1261 while (1) {
1204 prepare_select(&readsetp, &writesetp, &max_fd, &nalloc, &tvp); 1262 prepare_poll(&pfd, &npfd, &timeout);
1205 result = select(max_fd + 1, readsetp, writesetp, NULL, tvp); 1263 result = poll(pfd, npfd, timeout);
1206 saved_errno = errno; 1264 saved_errno = errno;
1207 if (parent_alive_interval != 0) 1265 if (parent_alive_interval != 0)
1208 check_parent_exists(); 1266 check_parent_exists();
@@ -1210,9 +1268,9 @@ skip:
1210 if (result < 0) { 1268 if (result < 0) {
1211 if (saved_errno == EINTR) 1269 if (saved_errno == EINTR)
1212 continue; 1270 continue;
1213 fatal("select: %s", strerror(saved_errno)); 1271 fatal("poll: %s", strerror(saved_errno));
1214 } else if (result > 0) 1272 } else if (result > 0)
1215 after_select(readsetp, writesetp); 1273 after_poll(pfd, npfd);
1216 } 1274 }
1217 /* NOTREACHED */ 1275 /* NOTREACHED */
1218} 1276}