summaryrefslogtreecommitdiff
path: root/testing/hstox/driver.c
diff options
context:
space:
mode:
Diffstat (limited to 'testing/hstox/driver.c')
-rw-r--r--testing/hstox/driver.c244
1 files changed, 244 insertions, 0 deletions
diff --git a/testing/hstox/driver.c b/testing/hstox/driver.c
new file mode 100644
index 00000000..81773190
--- /dev/null
+++ b/testing/hstox/driver.c
@@ -0,0 +1,244 @@
1#include <errno.h>
2#include <fcntl.h>
3#include <netdb.h>
4#include <signal.h>
5#include <signal.h>
6#include <stdarg.h>
7#include <stdio.h>
8#include <string.h>
9#include <sys/socket.h>
10#include <sys/types.h>
11#include <unistd.h>
12
13#include "driver.h"
14#include "errors.h"
15#include "methods.h"
16#include "util.h"
17
18#include <sodium.h>
19
20static void handle_interrupt(int signum)
21{
22 printf("Caught signal %d; exiting cleanly.\n", signum);
23 exit(0);
24}
25
26static int protocol_error(msgpack_packer *pk, char const *fmt, ...)
27{
28 msgpack_pack_array(pk, 4); // 4 elements in the array
29 msgpack_pack_uint8(pk, 1); // 1. type = response
30 // 2. We don't know the msgid, because the packet we received is not a valid
31 // msgpack-rpc packet.
32 msgpack_pack_uint64(pk, 0);
33
34 // 3. Error message.
35 va_list ap;
36 va_start(ap, fmt);
37 int res = msgpack_pack_vstringf(pk, fmt, ap);
38 va_end(ap);
39
40 // 4. No success result.
41 msgpack_pack_array(pk, 0);
42
43 return res;
44}
45
46static bool type_check(msgpack_packer *pk, msgpack_object req, int index,
47 msgpack_object_type type)
48{
49 if (req.via.array.ptr[index].type != type) {
50 protocol_error(pk, "element %d should be %s, but is %s", index, type_name(type),
51 type_name(req.via.array.ptr[index].type));
52 return false;
53 }
54
55 return true;
56}
57
58static int write_sample_input(msgpack_object req)
59{
60 static unsigned int n;
61
62 char filename[256];
63 msgpack_object_str name = req.via.array.ptr[2].via.str;
64 snprintf(filename, sizeof filename - name.size, "test-inputs/%04u-", n++);
65
66 assert(sizeof filename - strlen(filename) > name.size + 4);
67 memcpy(filename + strlen(filename) + name.size, ".mp", 4);
68 memcpy(filename + strlen(filename), name.ptr, name.size);
69
70 int fd = open(filename, O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR);
71
72 if (fd < 0)
73 // If we can't open the sample file, we just don't write it.
74 {
75 return E_OK;
76 }
77
78 check_return(E_WRITE, ftruncate(fd, 0));
79
80 msgpack_sbuffer sbuf __attribute__((__cleanup__(msgpack_sbuffer_destroy)));
81 msgpack_sbuffer_init(&sbuf);
82
83 msgpack_packer pk;
84 msgpack_packer_init(&pk, &sbuf, msgpack_sbuffer_write);
85
86 msgpack_pack_object(&pk, req);
87
88 check_return(E_WRITE, write(fd, sbuf.data, sbuf.size));
89
90 return E_OK;
91}
92
93static int handle_request(struct settings cfg, int write_fd, msgpack_object req)
94{
95 msgpack_sbuffer sbuf __attribute__((__cleanup__(msgpack_sbuffer_destroy))); /* buffer */
96 msgpack_sbuffer_init(&sbuf); /* initialize buffer */
97
98 msgpack_packer pk; /* packer */
99 msgpack_packer_init(&pk, &sbuf, msgpack_sbuffer_write); /* initialize packer */
100
101 if (req.type != MSGPACK_OBJECT_ARRAY) {
102 protocol_error(&pk, "expected array, but got %s", type_name(req.type));
103 } else if (req.via.array.size != 4) {
104 protocol_error(&pk, "array length should be 4, but is %d", req.via.array.size);
105 } else if (type_check(&pk, req, 0, MSGPACK_OBJECT_POSITIVE_INTEGER) &&
106 type_check(&pk, req, 1, MSGPACK_OBJECT_POSITIVE_INTEGER) &&
107 type_check(&pk, req, 2, MSGPACK_OBJECT_STR) &&
108 type_check(&pk, req, 3, MSGPACK_OBJECT_ARRAY)) {
109 if (cfg.collect_samples) {
110 propagate(write_sample_input(req));
111 }
112
113 uint64_t msgid = req.via.array.ptr[1].via.u64;
114 msgpack_object_str name = req.via.array.ptr[2].via.str;
115 msgpack_object_array args = req.via.array.ptr[3].via.array;
116
117 msgpack_pack_array(&pk, 4); // 4 elements in the array
118 msgpack_pack_uint8(&pk, 1); // 1. type = response
119 msgpack_pack_uint64(&pk, msgid); // 2. msgid
120
121 if (name.size == (sizeof "rpc.capabilities") - 1 &&
122 memcmp(name.ptr, "rpc.capabilities", name.size) == 0) {
123 // 3. Error.
124 msgpack_pack_string(&pk, "Capabilities negiotiation not implemented");
125 // 4. No result.
126 msgpack_pack_nil(&pk);
127 } else {
128 // if error is null, this writes 3. no error, and 4. result
129 char const *error =
130 call_method(name, args, &pk);
131
132 if (error) {
133 if (cfg.debug) {
134 printf("Error '%s' in request: ", error);
135 msgpack_object_print(stdout, req);
136 printf("\n");
137 }
138
139 msgpack_pack_string(&pk, error);
140 msgpack_pack_array(&pk, 0);
141 }
142 }
143 }
144
145 check_return(E_WRITE, write(write_fd, sbuf.data, sbuf.size));
146
147 return E_OK;
148}
149
150int communicate(struct settings cfg, int read_fd, int write_fd)
151{
152 msgpack_unpacker unp __attribute__((__cleanup__(msgpack_unpacker_destroy)));
153 msgpack_unpacker_init(&unp, 128);
154
155 while (true) {
156 char buf[64];
157 int size = check_return(E_READ, read(read_fd, buf, sizeof buf));
158
159 if (size == 0) {
160 break;
161 }
162
163 if (msgpack_unpacker_buffer_capacity(&unp) < size &&
164 !msgpack_unpacker_reserve_buffer(&unp, size)) {
165 return E_NOMEM;
166 }
167
168 memcpy(msgpack_unpacker_buffer(&unp), buf, size);
169 msgpack_unpacker_buffer_consumed(&unp, size);
170
171 msgpack_unpacked req __attribute__((__cleanup__(msgpack_unpacked_destroy)));
172 msgpack_unpacked_init(&req);
173
174 switch (msgpack_unpacker_next(&unp, &req)) {
175 case MSGPACK_UNPACK_SUCCESS:
176 propagate(handle_request(cfg, write_fd, req.data));
177 break;
178
179 case MSGPACK_UNPACK_EXTRA_BYTES:
180 printf("EXTRA_BYTES\n");
181 break;
182
183 case MSGPACK_UNPACK_CONTINUE:
184 break;
185
186 case MSGPACK_UNPACK_PARSE_ERROR:
187 return E_PARSE;
188
189 case MSGPACK_UNPACK_NOMEM_ERROR:
190 return E_NOMEM;
191 }
192 }
193
194 return E_OK;
195}
196
197static int closep(int *fd)
198{
199 return close(*fd);
200}
201
202static int run_tests(struct settings cfg, int port)
203{
204 int listen_fd __attribute__((__cleanup__(closep))) = 0;
205 listen_fd = check_return(E_SOCKET, socket(AF_INET, SOCK_STREAM, 0));
206 check_return(E_SOCKET, setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &(int) {
207 1
208 }, sizeof(int)));
209
210 struct sockaddr_in servaddr;
211 servaddr.sin_family = AF_INET;
212 servaddr.sin_addr.s_addr = htons(INADDR_ANY);
213 servaddr.sin_port = htons(port);
214
215 check_return(E_BIND, bind(listen_fd, (struct sockaddr *)&servaddr, sizeof servaddr));
216 check_return(E_LISTEN, listen(listen_fd, 10));
217
218 while (true) {
219 int comm_fd __attribute__((__cleanup__(closep))) = 0;
220 comm_fd = check_return(E_ACCEPT, accept(listen_fd, NULL, NULL));
221 propagate(communicate(cfg, comm_fd, comm_fd));
222 }
223
224 return E_OK;
225}
226
227uint32_t network_main(struct settings cfg, uint16_t port, unsigned int timeout)
228{
229 signal(SIGALRM, handle_interrupt);
230 signal(SIGINT, handle_interrupt);
231 check_return(E_SODIUM, sodium_init());
232
233 // Kill the process after `timeout` seconds so we don't get lingering
234 // processes bound to the test port when something goes wrong with a test run.
235 alarm(timeout);
236
237 int result = run_tests(cfg, port);
238
239 if (result == E_OK) {
240 return E_OK;
241 }
242
243 return result | (errno << 8);
244}