[svn-upgrade] Integrating new upstream version, iodine (0.5.1)
[debian/iodine.git] / src / iodined.c
1 /*
2  * Copyright (c) 2006-2009 Bjorn Andersson <flex@kryo.se>, Erik Ekman <yarrick@kryo.se>
3  *
4  * Permission to use, copy, modify, and distribute this software for any
5  * purpose with or without fee is hereby granted, provided that the above
6  * copyright notice and this permission notice appear in all copies.
7  *
8  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
9  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
10  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
11  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
12  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
13  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
14  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
15  */
16
17 #include <stdio.h>
18 #include <stdint.h>
19 #include <stdlib.h>
20 #include <string.h>
21 #include <signal.h>
22 #include <unistd.h>
23 #include <sys/types.h>
24 #include <sys/param.h>
25 #include <sys/time.h>
26 #include <fcntl.h>
27 #include <time.h>
28 #include <zlib.h>
29
30 #ifdef WINDOWS32
31 #include "windows.h"
32 #include <winsock2.h>
33 #else
34 #include <arpa/nameser.h>
35 #ifdef DARWIN
36 #include <arpa/nameser8_compat.h>
37 #endif
38 #define _XPG4_2
39 #include <sys/socket.h>
40 #include <err.h>
41 #include <arpa/inet.h>
42 #include <netinet/in.h>
43 #include <netinet/in_systm.h>
44 #include <netinet/ip.h>
45 #include <grp.h>
46 #include <sys/uio.h>
47 #include <pwd.h>
48 #include <netdb.h>
49 #include <syslog.h>
50 #endif
51
52 #include "common.h"
53 #include "dns.h"
54 #include "encoding.h"
55 #include "base32.h"
56 #include "base64.h"
57 #include "user.h"
58 #include "login.h"
59 #include "tun.h"
60 #include "fw_query.h"
61 #include "version.h"
62
63 #ifdef WINDOWS32
64 WORD req_version = MAKEWORD(2, 2);
65 WSADATA wsa_data;
66 #endif
67
68 static int running = 1;
69 static char *topdomain;
70 static char password[33];
71 static struct encoder *b32;
72 static int created_users;
73
74 static int check_ip;
75 static int my_mtu;
76 static in_addr_t my_ip;
77 static int netmask;
78
79 static in_addr_t ns_ip;
80
81 static int bind_port;
82 static int debug;
83
84 #if !defined(BSD) && !defined(__GLIBC__)
85 static char *__progname;
86 #endif
87
88 static int read_dns(int, struct query *);
89 static void write_dns(int, struct query *, char *, int);
90
91 static void
92 sigint(int sig) 
93 {
94         running = 0;
95 }
96
97 #ifdef WINDOWS32
98 #define LOG_EMERG       0
99 #define LOG_ALERT       1
100 #define LOG_CRIT        2
101 #define LOG_ERR         3
102 #define LOG_WARNING     4
103 #define LOG_NOTICE      5
104 #define LOG_INFO        6
105 #define LOG_DEBUG       7
106 static void
107 syslog(int a, const char *str, ...)
108 {
109         /* TODO: implement (add to event log), move to common.c */
110         ;
111 }
112 #endif
113
114 static int
115 check_user_and_ip(int userid, struct query *q)
116 {
117         struct sockaddr_in *tempin;
118
119         if (userid < 0 || userid >= created_users ) {
120                 return 1; 
121         }
122         if (!users[userid].active) {
123                 return 1;
124         }
125
126         /* return early if IP checking is disabled */
127         if (!check_ip) {
128                 return 0;
129         }
130
131         tempin = (struct sockaddr_in *) &(q->from);
132         return memcmp(&(users[userid].host), &(tempin->sin_addr), sizeof(struct in_addr));
133 }
134
135 static int
136 tunnel_tun(int tun_fd, int dns_fd)
137 {
138         unsigned long outlen;
139         struct ip *header;
140         char out[64*1024];
141         char in[64*1024];
142         int userid;
143         int read;
144
145         if ((read = read_tun(tun_fd, in, sizeof(in))) <= 0)
146                 return 0;
147         
148         /* find target ip in packet, in is padded with 4 bytes TUN header */
149         header = (struct ip*) (in + 4);
150         userid = find_user_by_ip(header->ip_dst.s_addr);
151         if (userid < 0)
152                 return 0;
153
154         outlen = sizeof(out);
155         compress2((uint8_t*)out, &outlen, (uint8_t*)in, read, 9);
156
157         /* if another packet is queued, throw away this one. TODO build queue */
158         if (users[userid].outpacket.len == 0) {
159                 memcpy(users[userid].outpacket.data, out, outlen);
160                 users[userid].outpacket.len = outlen;
161                 users[userid].outpacket.offset = 0;
162                 users[userid].outpacket.sentlen = 0;
163                 users[userid].outpacket.seqno = (++users[userid].outpacket.seqno & 7);
164                 users[userid].outpacket.fragment = 0;
165                 return outlen;
166         } else {
167                 return 0;
168         }
169 }
170
171 typedef enum {
172         VERSION_ACK,
173         VERSION_NACK,
174         VERSION_FULL
175 } version_ack_t;
176
177 static void
178 send_version_response(int fd, version_ack_t ack, uint32_t payload, int userid, struct query *q)
179 {
180         char out[9];
181         
182         switch (ack) {
183         case VERSION_ACK:
184                 strncpy(out, "VACK", sizeof(out));
185                 break;
186         case VERSION_NACK:
187                 strncpy(out, "VNAK", sizeof(out));
188                 break;
189         case VERSION_FULL:
190                 strncpy(out, "VFUL", sizeof(out));
191                 break;
192         }
193         
194         out[4] = ((payload >> 24) & 0xff);
195         out[5] = ((payload >> 16) & 0xff);
196         out[6] = ((payload >> 8) & 0xff);
197         out[7] = ((payload) & 0xff);
198         out[8] = userid & 0xff;
199
200         write_dns(fd, q, out, sizeof(out));
201 }
202
203 static void
204 send_chunk(int dns_fd, int userid) {
205         char pkt[4096];
206         int datalen;
207         int last;
208
209         datalen = MIN(users[userid].fragsize, users[userid].outpacket.len - users[userid].outpacket.offset);
210
211         if (datalen && users[userid].outpacket.sentlen > 0 && 
212                         (
213                         users[userid].outpacket.seqno != users[userid].out_acked_seqno ||
214                         users[userid].outpacket.fragment != users[userid].out_acked_fragment
215                         )
216                 ) {
217
218                 /* Still waiting on latest ack, send nothing */
219                 datalen = 0;
220                 last = 0;
221                 /* TODO : count down and discard packet if no acks arrive within X queries */
222         } else {
223                 memcpy(&pkt[2], &users[userid].outpacket.data[users[userid].outpacket.offset], datalen);
224                 users[userid].outpacket.sentlen = datalen;
225                 last = (users[userid].outpacket.len == users[userid].outpacket.offset + users[userid].outpacket.sentlen);
226
227                 /* Increase fragment# when sending data with offset */
228                 if (users[userid].outpacket.offset && datalen)
229                         users[userid].outpacket.fragment++;
230         }
231
232         /* Build downstream data header (see doc/proto_xxxxxxxx.txt) */
233
234         /* First byte is 1 bit compression flag, 3 bits upstream seqno, 4 bits upstream fragment */
235         pkt[0] = (1<<7) | ((users[userid].inpacket.seqno & 7) << 4) | (users[userid].inpacket.fragment & 15);
236         /* Second byte is 3 bits downstream seqno, 4 bits downstream fragment, 1 bit last flag */
237         pkt[1] = ((users[userid].outpacket.seqno & 7) << 5) | 
238                 ((users[userid].outpacket.fragment & 15) << 1) | (last & 1);
239
240         if (debug >= 1) {
241                 fprintf(stderr, "OUT  pkt seq# %d, frag %d (last=%d), offset %d, fragsize %d, total %d, to user %d\n",
242                         users[userid].outpacket.seqno & 7, users[userid].outpacket.fragment & 15, 
243                         last, users[userid].outpacket.offset, datalen, users[userid].outpacket.len, userid);
244         }
245         write_dns(dns_fd, &users[userid].q, pkt, datalen + 2);
246         users[userid].q.id = 0;
247
248         if (users[userid].outpacket.len > 0 && 
249                 users[userid].outpacket.len == users[userid].outpacket.sentlen) {
250
251                 /* Whole packet was sent in one chunk, dont wait for ack */
252                 users[userid].outpacket.len = 0;
253                 users[userid].outpacket.offset = 0;
254                 users[userid].outpacket.sentlen = 0;
255         }
256 }
257
258 static void
259 update_downstream_seqno(int dns_fd, int userid, int down_seq, int down_frag)
260 {
261         /* If we just read a new packet from tun we have not sent a fragment of, just send it */
262         if (users[userid].outpacket.len > 0 && users[userid].outpacket.sentlen == 0) {
263                 send_chunk(dns_fd, userid);
264                 return;
265         }
266
267         /* otherwise, check if we received ack on a fragment and can send next */
268         if (users[userid].outpacket.len > 0 &&
269                 users[userid].outpacket.seqno == down_seq && users[userid].outpacket.fragment == down_frag) {
270
271                 if (down_seq != users[userid].out_acked_seqno || down_frag != users[userid].out_acked_fragment) {
272                         /* Received ACK on downstream fragment */
273                         users[userid].outpacket.offset += users[userid].outpacket.sentlen;
274                         users[userid].outpacket.sentlen = 0;
275
276                         /* Is packet done? */
277                         if (users[userid].outpacket.offset == users[userid].outpacket.len) {
278                                 users[userid].outpacket.len = 0;
279                                 users[userid].outpacket.offset = 0;
280                                 users[userid].outpacket.sentlen = 0;
281                         }
282
283                         users[userid].out_acked_seqno = down_seq;
284                         users[userid].out_acked_fragment = down_frag;
285
286                         /* Send reply if waiting */
287                         if (users[userid].outpacket.len > 0) {
288                                 send_chunk(dns_fd, userid);
289                         }
290                 }
291         }
292 }
293
294 static void
295 handle_null_request(int tun_fd, int dns_fd, struct query *q, int domain_len)
296 {
297         struct in_addr tempip;
298         struct ip *hdr;
299         unsigned long outlen;
300         char in[512];
301         char logindata[16];
302         char out[64*1024];
303         char unpacked[64*1024];
304         char *tmp[2];
305         int userid;
306         int touser;
307         int version;
308         int code;
309         int read;
310
311         userid = -1;
312
313         memcpy(in, q->name, MIN(domain_len, sizeof(in)));
314
315         if(in[0] == 'V' || in[0] == 'v') {
316                 read = unpack_data(unpacked, sizeof(unpacked), &(in[1]), domain_len - 1, b32);
317                 /* Version greeting, compare and send ack/nak */
318                 if (read > 4) { 
319                         /* Received V + 32bits version */
320                         version = (((unpacked[0] & 0xff) << 24) |
321                                            ((unpacked[1] & 0xff) << 16) |
322                                            ((unpacked[2] & 0xff) << 8) |
323                                            ((unpacked[3] & 0xff)));
324                 }
325
326                 if (version == VERSION) {
327                         userid = find_available_user();
328                         if (userid >= 0) {
329                                 struct sockaddr_in *tempin;
330
331                                 users[userid].seed = rand();
332                                 /* Store remote IP number */
333                                 tempin = (struct sockaddr_in *) &(q->from);
334                                 memcpy(&(users[userid].host), &(tempin->sin_addr), sizeof(struct in_addr));
335                                 
336                                 memcpy(&(users[userid].q), q, sizeof(struct query));
337                                 users[userid].encoder = get_base32_encoder();
338                                 send_version_response(dns_fd, VERSION_ACK, users[userid].seed, userid, q);
339                                 syslog(LOG_INFO, "accepted version for user #%d from %s",
340                                         userid, inet_ntoa(tempin->sin_addr));
341                                 users[userid].q.id = 0;
342                         } else {
343                                 /* No space for another user */
344                                 send_version_response(dns_fd, VERSION_FULL, created_users, 0, q);
345                                 syslog(LOG_INFO, "dropped user from %s, server full", 
346                                         inet_ntoa(((struct sockaddr_in *) &q->from)->sin_addr));
347                         }
348                 } else {
349                         send_version_response(dns_fd, VERSION_NACK, VERSION, 0, q);
350                         syslog(LOG_INFO, "dropped user from %s, sent bad version %08X", 
351                                 inet_ntoa(((struct sockaddr_in *) &q->from)->sin_addr), version);
352                 }
353                 return;
354         } else if(in[0] == 'L' || in[0] == 'l') {
355                 read = unpack_data(unpacked, sizeof(unpacked), &(in[1]), domain_len - 1, b32);
356                 /* Login phase, handle auth */
357                 userid = unpacked[0];
358
359                 if (check_user_and_ip(userid, q) != 0) {
360                         write_dns(dns_fd, q, "BADIP", 5);
361                         syslog(LOG_WARNING, "dropped login request from user #%d from unexpected source %s",
362                                 userid, inet_ntoa(((struct sockaddr_in *) &q->from)->sin_addr));
363                         return;
364                 } else {
365                         users[userid].last_pkt = time(NULL);
366                         login_calculate(logindata, 16, password, users[userid].seed);
367
368                         if (read >= 18 && (memcmp(logindata, unpacked+1, 16) == 0)) {
369                                 /* Login ok, send ip/mtu/netmask info */
370
371                                 tempip.s_addr = my_ip;
372                                 tmp[0] = strdup(inet_ntoa(tempip));
373                                 tempip.s_addr = users[userid].tun_ip;
374                                 tmp[1] = strdup(inet_ntoa(tempip));
375
376                                 read = snprintf(out, sizeof(out), "%s-%s-%d-%d", 
377                                                 tmp[0], tmp[1], my_mtu, netmask);
378
379                                 write_dns(dns_fd, q, out, read);
380                                 q->id = 0;
381                                 syslog(LOG_NOTICE, "accepted password from user #%d, given IP %s", userid, tmp[1]);
382
383                                 free(tmp[1]);
384                                 free(tmp[0]);
385                         } else {
386                                 write_dns(dns_fd, q, "LNAK", 4);
387                                 syslog(LOG_WARNING, "rejected login request from user #%d from %s, bad password",
388                                         userid, inet_ntoa(((struct sockaddr_in *) &q->from)->sin_addr));
389                         }
390                 }
391                 return;
392         } else if(in[0] == 'Z' || in[0] == 'z') {
393                 /* Check for case conservation and chars not allowed according to RFC */
394
395                 /* Reply with received hostname as data */
396                 write_dns(dns_fd, q, in, domain_len);
397                 return;
398         } else if(in[0] == 'S' || in[0] == 's') {
399                 int codec;
400                 struct encoder *enc;
401                 if (domain_len != 4) { /* len = 4, example: "S15." */
402                         write_dns(dns_fd, q, "BADLEN", 6);
403                         return;
404                 }
405
406                 userid = b32_8to5(in[1]);
407                 
408                 if (check_user_and_ip(userid, q) != 0) {
409                         write_dns(dns_fd, q, "BADIP", 5);
410                         return; /* illegal id */
411                 }
412                 
413                 codec = b32_8to5(in[2]);
414
415                 switch (codec) {
416                 case 5: /* 5 bits per byte = base32 */
417                         enc = get_base32_encoder();
418                         user_switch_codec(userid, enc);
419                         write_dns(dns_fd, q, enc->name, strlen(enc->name));
420                         break;
421                 case 6: /* 6 bits per byte = base64 */
422                         enc = get_base64_encoder();
423                         user_switch_codec(userid, enc);
424                         write_dns(dns_fd, q, enc->name, strlen(enc->name));
425                         break;
426                 default:
427                         write_dns(dns_fd, q, "BADCODEC", 8);
428                         break;
429                 }
430                 return;
431         } else if(in[0] == 'R' || in[0] == 'r') {
432                 int req_frag_size;
433
434                 /* Downstream fragsize probe packet */
435                 userid = (b32_8to5(in[1]) >> 1) & 15;
436                 if (check_user_and_ip(userid, q) != 0) {
437                         write_dns(dns_fd, q, "BADIP", 5);
438                         return; /* illegal id */
439                 }
440                                 
441                 req_frag_size = ((b32_8to5(in[1]) & 1) << 10) | ((b32_8to5(in[2]) & 31) << 5) | (b32_8to5(in[3]) & 31);
442                 if (req_frag_size < 2 || req_frag_size > 2047) {        
443                         write_dns(dns_fd, q, "BADFRAG", 7);
444                 } else {
445                         char buf[2048];
446
447                         memset(buf, 0, sizeof(buf));
448                         buf[0] = (req_frag_size >> 8) & 0xff;
449                         buf[1] = req_frag_size & 0xff;
450                         write_dns(dns_fd, q, buf, req_frag_size);
451                 }
452                 return;
453         } else if(in[0] == 'N' || in[0] == 'n') {
454                 int max_frag_size;
455
456                 read = unpack_data(unpacked, sizeof(unpacked), &(in[1]), domain_len - 1, b32);
457                 /* Downstream fragsize packet */
458                 userid = unpacked[0];
459                 if (check_user_and_ip(userid, q) != 0) {
460                         write_dns(dns_fd, q, "BADIP", 5);
461                         return; /* illegal id */
462                 }
463                                 
464                 max_frag_size = ((unpacked[1] & 0xff) << 8) | (unpacked[2] & 0xff);
465                 if (max_frag_size < 2) {        
466                         write_dns(dns_fd, q, "BADFRAG", 7);
467                 } else {
468                         users[userid].fragsize = max_frag_size;
469                         write_dns(dns_fd, q, &unpacked[1], 2);
470                 }
471                 return;
472         } else if(in[0] == 'P' || in[0] == 'p') {
473                 int dn_seq;
474                 int dn_frag;
475                 
476                 read = unpack_data(unpacked, sizeof(unpacked), &(in[1]), domain_len - 1, b32);
477                 /* Ping packet, store userid */
478                 userid = unpacked[0];
479                 if (check_user_and_ip(userid, q) != 0) {
480                         write_dns(dns_fd, q, "BADIP", 5);
481                         return; /* illegal id */
482                 }
483                                 
484                 if (debug >= 1) {
485                         fprintf(stderr, "PING pkt from user %d\n", userid);
486                 }
487
488                 if (users[userid].q.id != 0) {
489                         /* Send reply on earlier query before overwriting */
490                         send_chunk(dns_fd, userid);
491                 }
492
493                 dn_seq = unpacked[1] >> 4;
494                 dn_frag = unpacked[1] & 15;
495                 memcpy(&(users[userid].q), q, sizeof(struct query));
496                 users[userid].last_pkt = time(NULL);
497
498                 /* Update seqno and maybe send immediate response packet */
499                 update_downstream_seqno(dns_fd, userid, dn_seq, dn_frag);
500         } else if((in[0] >= '0' && in[0] <= '9')
501                         || (in[0] >= 'a' && in[0] <= 'f')
502                         || (in[0] >= 'A' && in[0] <= 'F')) {
503                 if ((in[0] >= '0' && in[0] <= '9'))
504                         code = in[0] - '0';
505                 if ((in[0] >= 'a' && in[0] <= 'f'))
506                         code = in[0] - 'a' + 10;
507                 if ((in[0] >= 'A' && in[0] <= 'F'))
508                         code = in[0] - 'A' + 10;
509
510                 userid = code;
511                 /* Check user and sending ip number */
512                 if (check_user_and_ip(userid, q) != 0) {
513                         write_dns(dns_fd, q, "BADIP", 5);
514                 } else {
515                         /* Decode data header */
516                         int up_seq = (b32_8to5(in[1]) >> 2) & 7;
517                         int up_frag = ((b32_8to5(in[1]) & 3) << 2) | ((b32_8to5(in[2]) >> 3) & 3);
518                         int dn_seq = (b32_8to5(in[2]) & 7);
519                         int dn_frag = b32_8to5(in[3]) >> 1;
520                         int lastfrag = b32_8to5(in[3]) & 1;
521
522                         if (users[userid].q.id != 0) {
523                                 /* Send reply on earlier query before overwriting */
524                                 send_chunk(dns_fd, userid);
525                         }
526
527                         /* Update query and time info for user */
528                         users[userid].last_pkt = time(NULL);
529                         memcpy(&(users[userid].q), q, sizeof(struct query));
530
531                         if (up_seq == users[userid].inpacket.seqno && 
532                                 up_frag <= users[userid].inpacket.fragment) {
533                                 /* Got repeated old packet, skip it */
534                                 if (debug >= 1) {
535                                         fprintf(stderr, "IN   pkt seq# %d, frag %d, dropped duplicate\n",
536                                                 up_seq, up_frag);
537                                 }
538                                 /* Update seqno and maybe send immediate response packet */
539                                 update_downstream_seqno(dns_fd, userid, dn_seq, dn_frag);
540                                 return;
541                         }
542                         if (up_seq != users[userid].inpacket.seqno) {
543                                 /* New packet has arrived */
544                                 users[userid].inpacket.seqno = up_seq;
545                                 users[userid].inpacket.len = 0;
546                                 users[userid].inpacket.offset = 0;
547                         }
548                         users[userid].inpacket.fragment = up_frag;
549
550                         /* decode with this users encoding */
551                         read = unpack_data(unpacked, sizeof(unpacked), &(in[4]), domain_len - 4, 
552                                            users[userid].encoder);
553
554                         /* copy to packet buffer, update length */
555                         memcpy(users[userid].inpacket.data + users[userid].inpacket.offset, unpacked, read);
556                         users[userid].inpacket.len += read;
557                         users[userid].inpacket.offset += read;
558
559                         if (debug >= 1) {
560                                 fprintf(stderr, "IN   pkt seq# %d, frag %d (last=%d), fragsize %d, total %d, from user %d\n",
561                                         up_seq, up_frag, lastfrag, read, users[userid].inpacket.len, userid);
562                         }
563
564                         if (lastfrag & 1) { /* packet is complete */
565                                 int ret;
566                                 outlen = sizeof(out);
567                                 ret = uncompress((uint8_t*)out, &outlen, 
568                                            (uint8_t*)users[userid].inpacket.data, users[userid].inpacket.len);
569
570                                 if (ret == Z_OK) {
571                                         hdr = (struct ip*) (out + 4);
572                                         touser = find_user_by_ip(hdr->ip_dst.s_addr);
573
574                                         if (touser == -1) {
575                                                 /* send the uncompressed packet to tun device */
576                                                 write_tun(tun_fd, out, outlen);
577                                         } else {
578                                                 /* send the compressed packet to other client
579                                                  * if another packet is queued, throw away this one. TODO build queue */
580                                                 if (users[touser].outpacket.len == 0) {
581                                                         memcpy(users[touser].outpacket.data, users[userid].inpacket.data, users[userid].inpacket.len);
582                                                         users[touser].outpacket.len = users[userid].inpacket.len;
583                                                 }
584                                         }
585                                 } else {
586                                         fprintf(stderr, "Discarded data, uncompress() result: %d\n", ret);
587                                 }
588                                 users[userid].inpacket.len = users[userid].inpacket.offset = 0;
589                         }
590                         /* Update seqno and maybe send immediate response packet */
591                         update_downstream_seqno(dns_fd, userid, dn_seq, dn_frag);
592                 }
593         }
594 }
595
596 static void
597 handle_ns_request(int dns_fd, struct query *q)
598 {
599         char buf[64*1024];
600         int len;
601
602         if (ns_ip != INADDR_ANY) {
603                 memcpy(&q->destination.s_addr, &ns_ip, sizeof(in_addr_t));
604         }
605
606         len = dns_encode_ns_response(buf, sizeof(buf), q, topdomain);
607         
608         if (debug >= 2) {
609                 struct sockaddr_in *tempin;
610                 tempin = (struct sockaddr_in *) &(q->from);
611                 fprintf(stderr, "TX: client %s, type %d, name %s, %d bytes NS reply\n", 
612                         inet_ntoa(tempin->sin_addr), q->type, q->name, len);
613         }
614         if (sendto(dns_fd, buf, len, 0, (struct sockaddr*)&q->from, q->fromlen) <= 0) {
615                 warn("ns reply send error");
616         }
617 }
618
619 static void
620 forward_query(int bind_fd, struct query *q)
621 {
622         char buf[64*1024];
623         int len;
624         struct fw_query fwq;
625         struct sockaddr_in *myaddr;
626         in_addr_t newaddr;
627
628         len = dns_encode(buf, sizeof(buf), q, QR_QUERY, q->name, strlen(q->name));
629
630         /* Store sockaddr for q->id */
631         memcpy(&(fwq.addr), &(q->from), q->fromlen);
632         fwq.addrlen = q->fromlen;
633         fwq.id = q->id;
634         fw_query_put(&fwq);
635
636         newaddr = inet_addr("127.0.0.1");
637         myaddr = (struct sockaddr_in *) &(q->from);
638         memcpy(&(myaddr->sin_addr), &newaddr, sizeof(in_addr_t));
639         myaddr->sin_port = htons(bind_port);
640         
641         if (debug >= 2) {
642                 fprintf(stderr, "TX: NS reply \n");
643         }
644
645         if (sendto(bind_fd, buf, len, 0, (struct sockaddr*)&q->from, q->fromlen) <= 0) {
646                 warn("forward query error");
647         }
648 }
649   
650 static int
651 tunnel_bind(int bind_fd, int dns_fd)
652 {
653         char packet[64*1024];
654         struct sockaddr_in from;
655         socklen_t fromlen;
656         struct fw_query *query;
657         unsigned short id;
658         int r;
659
660         fromlen = sizeof(struct sockaddr);
661         r = recvfrom(bind_fd, packet, sizeof(packet), 0, 
662                 (struct sockaddr*)&from, &fromlen);
663
664         if (r <= 0)
665                 return 0;
666
667         id = dns_get_id(packet, r);
668         
669         if (debug >= 2) {
670                 fprintf(stderr, "RX: Got response on query %u from DNS\n", (id & 0xFFFF));
671         }
672
673         /* Get sockaddr from id */
674         fw_query_get(id, &query);
675         if (!query && debug >= 2) {
676                 fprintf(stderr, "Lost sender of id %u, dropping reply\n", (id & 0xFFFF));
677                 return 0;
678         }
679
680         if (debug >= 2) {
681                 struct sockaddr_in *in;
682                 in = (struct sockaddr_in *) &(query->addr);
683                 fprintf(stderr, "TX: client %s id %u, %d bytes\n",
684                         inet_ntoa(in->sin_addr), (id & 0xffff), r);
685         }
686         
687         if (sendto(dns_fd, packet, r, 0, (const struct sockaddr *) &(query->addr), 
688                 query->addrlen) <= 0) {
689                 warn("forward reply error");
690         }
691
692         return 0;
693 }
694
695 static int
696 tunnel_dns(int tun_fd, int dns_fd, int bind_fd)
697 {
698         struct query q;
699         int read;
700         char *domain;
701         int domain_len;
702         int inside_topdomain;
703
704         if ((read = read_dns(dns_fd, &q)) <= 0)
705                 return 0;
706
707         if (debug >= 2) {
708                 struct sockaddr_in *tempin;
709                 tempin = (struct sockaddr_in *) &(q.from);
710                 fprintf(stderr, "RX: client %s, type %d, name %s\n", 
711                         inet_ntoa(tempin->sin_addr), q.type, q.name);
712         }
713         
714         domain = strstr(q.name, topdomain);
715         inside_topdomain = 0;
716         if (domain) {
717                 domain_len = (int) (domain - q.name); 
718                 if (domain_len + strlen(topdomain) == strlen(q.name)) {
719                         inside_topdomain = 1;
720                 }
721         }
722         
723         if (inside_topdomain) {
724                 /* This is a query we can handle */
725                 switch (q.type) {
726                 case T_NULL:
727                         handle_null_request(tun_fd, dns_fd, &q, domain_len);
728                         break;
729                 case T_NS:
730                         handle_ns_request(dns_fd, &q);
731                         break;
732                 default:
733                         break;
734                 }
735         } else {
736                 /* Forward query to other port ? */
737                 if (bind_fd) {
738                         forward_query(bind_fd, &q);
739                 }
740         }
741         return 0;
742 }
743
744 static int
745 tunnel(int tun_fd, int dns_fd, int bind_fd)
746 {
747         struct timeval tv;
748         fd_set fds;
749         int i;
750
751         while (running) {
752                 int maxfd;
753                 if (users_waiting_on_reply()) {
754                         tv.tv_sec = 0;
755                         tv.tv_usec = 15000;
756                 } else {
757                         tv.tv_sec = 1;
758                         tv.tv_usec = 0;
759                 }
760
761                 FD_ZERO(&fds);
762
763                 FD_SET(dns_fd, &fds);
764                 maxfd = dns_fd;
765
766                 if (bind_fd) {
767                         /* wait for replies from real DNS */
768                         FD_SET(bind_fd, &fds);
769                         maxfd = MAX(bind_fd, maxfd);
770                 }
771
772                 /* TODO : use some kind of packet queue */
773                 if(!all_users_waiting_to_send()) {
774                         FD_SET(tun_fd, &fds);
775                         maxfd = MAX(tun_fd, maxfd);
776                 }
777
778                 i = select(maxfd + 1, &fds, NULL, NULL, &tv);
779                 
780                 if(i < 0) {
781                         if (running) 
782                                 warn("select");
783                         return 1;
784                 }
785
786                 if (i==0) {     
787                         int j;
788                         for (j = 0; j < USERS; j++) {
789                                 if (users[j].q.id != 0) {
790                                         send_chunk(dns_fd, j);
791                                 }
792                         }
793                 } else {
794                         if(FD_ISSET(tun_fd, &fds)) {
795                                 tunnel_tun(tun_fd, dns_fd);
796                                 continue;
797                         }
798                         if(FD_ISSET(dns_fd, &fds)) {
799                                 tunnel_dns(tun_fd, dns_fd, bind_fd);
800                                 continue;
801                         } 
802                         if(FD_ISSET(bind_fd, &fds)) {
803                                 tunnel_bind(bind_fd, dns_fd);
804                                 continue;
805                         }
806                 }
807         }
808
809         return 0;
810 }
811
812 static int
813 read_dns(int fd, struct query *q)
814 {
815         struct sockaddr_in from;
816         socklen_t addrlen;
817         char packet[64*1024];
818         int r;
819 #ifndef WINDOWS32
820         char address[96];
821         struct msghdr msg;
822         struct iovec iov;
823         struct cmsghdr *cmsg;
824
825         addrlen = sizeof(struct sockaddr);
826         iov.iov_base = packet;
827         iov.iov_len = sizeof(packet);
828
829         msg.msg_name = (caddr_t) &from;
830         msg.msg_namelen = (unsigned) addrlen;
831         msg.msg_iov = &iov;
832         msg.msg_iovlen = 1;
833         msg.msg_control = address;
834         msg.msg_controllen = sizeof(address);
835         msg.msg_flags = 0;
836         
837         r = recvmsg(fd, &msg, 0);
838 #else
839         addrlen = sizeof(struct sockaddr);
840         r = recvfrom(fd, packet, sizeof(packet), 0, (struct sockaddr*)&from, &addrlen);
841 #endif /* !WINDOWS32 */
842
843         if (r > 0) {
844                 dns_decode(NULL, 0, q, QR_QUERY, packet, r);
845                 memcpy((struct sockaddr*)&q->from, (struct sockaddr*)&from, addrlen);
846                 q->fromlen = addrlen;
847                 
848 #ifndef WINDOWS32
849                 for (cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL; 
850                         cmsg = CMSG_NXTHDR(&msg, cmsg)) { 
851                         
852                         if (cmsg->cmsg_level == IPPROTO_IP && 
853                                 cmsg->cmsg_type == DSTADDR_SOCKOPT) { 
854                                 
855                                 q->destination = *dstaddr(cmsg); 
856                                 break;
857                         } 
858                 }
859 #endif
860
861                 return strlen(q->name);
862         } else if (r < 0) { 
863                 /* Error */
864                 warn("read dns");
865         }
866
867         return 0;
868 }
869
870 static void
871 write_dns(int fd, struct query *q, char *data, int datalen)
872 {
873         char buf[64*1024];
874         int len;
875
876         len = dns_encode(buf, sizeof(buf), q, QR_ANSWER, data, datalen);
877         
878         if (debug >= 2) {
879                 struct sockaddr_in *tempin;
880                 tempin = (struct sockaddr_in *) &(q->from);
881                 fprintf(stderr, "TX: client %s, type %d, name %s, %d bytes data\n", 
882                         inet_ntoa(tempin->sin_addr), q->type, q->name, datalen);
883         }
884
885         sendto(fd, buf, len, 0, (struct sockaddr*)&q->from, q->fromlen);
886 }
887
888 static void
889 usage() {
890         extern char *__progname;
891
892         fprintf(stderr, "Usage: %s [-v] [-h] [-c] [-s] [-f] [-D] [-u user] "
893                 "[-t chrootdir] [-d device] [-m mtu] "
894                 "[-l ip address to listen on] [-p port] [-n external ip] [-b dnsport] [-P password]"
895                 " tunnel_ip[/netmask] topdomain\n", __progname);
896         exit(2);
897 }
898
899 static void
900 help() {
901         extern char *__progname;
902
903         fprintf(stderr, "iodine IP over DNS tunneling server\n");
904         fprintf(stderr, "Usage: %s [-v] [-h] [-c] [-s] [-f] [-D] [-u user] "
905                 "[-t chrootdir] [-d device] [-m mtu] "
906                 "[-l ip address to listen on] [-p port] [-n external ip] [-b dnsport] [-P password]"
907                 " tunnel_ip[/netmask] topdomain\n", __progname);
908         fprintf(stderr, "  -v to print version info and exit\n");
909         fprintf(stderr, "  -h to print this help and exit\n");
910         fprintf(stderr, "  -c to disable check of client IP/port on each request\n");
911         fprintf(stderr, "  -s to skip creating and configuring the tun device, "
912                 "which then has to be created manually\n");
913         fprintf(stderr, "  -f to keep running in foreground\n");
914         fprintf(stderr, "  -D to increase debug level\n");
915         fprintf(stderr, "  -u name to drop privileges and run as user 'name'\n");
916         fprintf(stderr, "  -t dir to chroot to directory dir\n");
917         fprintf(stderr, "  -d device to set tunnel device name\n");
918         fprintf(stderr, "  -m mtu to set tunnel device mtu\n");
919         fprintf(stderr, "  -l ip address to listen on for incoming dns traffic "
920                 "(default 0.0.0.0)\n");
921         fprintf(stderr, "  -p port to listen on for incoming dns traffic (default 53)\n");
922         fprintf(stderr, "  -n ip to respond with to NS queries\n");
923         fprintf(stderr, "  -b port to forward normal DNS queries to (on localhost)\n");
924         fprintf(stderr, "  -P password used for authentication (max 32 chars will be used)\n");
925         fprintf(stderr, "tunnel_ip is the IP number of the local tunnel interface.\n");
926         fprintf(stderr, "   /netmask sets the size of the tunnel network.\n");
927         fprintf(stderr, "topdomain is the FQDN that is delegated to this server.\n");
928         exit(0);
929 }
930
931 static void
932 version() {
933         printf("iodine IP over DNS tunneling server\n");
934         printf("version: 0.5.1 from 2009-03-21\n");
935         exit(0);
936 }
937
938 int
939 main(int argc, char **argv)
940 {
941         extern char *__progname;
942         in_addr_t listen_ip;
943 #ifndef WINDOWS32
944         struct passwd *pw;
945 #endif
946         int foreground;
947         char *username;
948         char *newroot;
949         char *device;
950         int dnsd_fd;
951         int tun_fd;
952
953         /* settings for forwarding normal DNS to 
954          * local real DNS server */
955         int bind_fd;
956         int bind_enable;
957         
958         int choice;
959         int port;
960         int mtu;
961         int skipipconfig;
962         char *netsize;
963
964         username = NULL;
965         newroot = NULL;
966         device = NULL;
967         foreground = 0;
968         bind_enable = 0;
969         bind_fd = 0;
970         mtu = 1200;
971         listen_ip = INADDR_ANY;
972         port = 53;
973         ns_ip = INADDR_ANY;
974         check_ip = 1;
975         skipipconfig = 0;
976         debug = 0;
977         netmask = 27;
978
979         b32 = get_base32_encoder();
980         
981 #ifdef WINDOWS32
982         WSAStartup(req_version, &wsa_data);
983 #endif
984
985 #if !defined(BSD) && !defined(__GLIBC__)
986         __progname = strrchr(argv[0], '/');
987         if (__progname == NULL)
988                 __progname = argv[0];
989         else
990                 __progname++;
991 #endif
992
993         memset(password, 0, sizeof(password));
994         srand(time(NULL));
995         fw_query_init();
996         
997         while ((choice = getopt(argc, argv, "vcsfhDu:t:d:m:l:p:n:b:P:")) != -1) {
998                 switch(choice) {
999                 case 'v':
1000                         version();
1001                         break;
1002                 case 'c':
1003                         check_ip = 0;
1004                         break;
1005                 case 's':
1006                         skipipconfig = 1;
1007                         break;
1008                 case 'f':
1009                         foreground = 1;
1010                         break;
1011                 case 'h':
1012                         help();
1013                         break;
1014                 case 'D':
1015                         debug++;
1016                         break;
1017                 case 'u':
1018                         username = optarg;
1019                         break;
1020                 case 't':
1021                         newroot = optarg;
1022                         break;
1023                 case 'd':
1024                         device = optarg;
1025                         break;
1026                 case 'm':
1027                         mtu = atoi(optarg);
1028                         break;
1029                 case 'l':
1030                         listen_ip = inet_addr(optarg);
1031                         break;
1032                 case 'p':
1033                         port = atoi(optarg);
1034                         break;
1035                 case 'n':
1036                         ns_ip = inet_addr(optarg);
1037                         break;
1038                 case 'b':
1039                         bind_enable = 1;
1040                         bind_port = atoi(optarg);
1041                         break;
1042                 case 'P':
1043                         strncpy(password, optarg, sizeof(password));
1044                         password[sizeof(password)-1] = 0;
1045                         
1046                         /* XXX: find better way of cleaning up ps(1) */
1047                         memset(optarg, 0, strlen(optarg)); 
1048                         break;
1049                 default:
1050                         usage();
1051                         break;
1052                 }
1053         }
1054
1055         argc -= optind;
1056         argv += optind;
1057
1058         check_superuser(usage);
1059
1060         if (argc != 2) 
1061                 usage();
1062         
1063         netsize = strchr(argv[0], '/');
1064         if (netsize) {
1065                 *netsize = 0;
1066                 netsize++;
1067                 netmask = atoi(netsize);
1068         }
1069
1070         my_ip = inet_addr(argv[0]);
1071         
1072         if (my_ip == INADDR_NONE) {
1073                 warnx("Bad IP address to use inside tunnel.\n");
1074                 usage();
1075         }
1076
1077         topdomain = strdup(argv[1]);
1078         if(strlen(topdomain) <= 128) {
1079                 if(check_topdomain(topdomain)) {
1080                         warnx("Topdomain contains invalid characters.\n");
1081                         usage();
1082                 }
1083         } else {
1084                 warnx("Use a topdomain max 128 chars long.\n");
1085                 usage();
1086         }
1087
1088         if (username != NULL) {
1089 #ifndef WINDOWS32
1090                 if ((pw = getpwnam(username)) == NULL) {
1091                         warnx("User %s does not exist!\n", username);
1092                         usage();
1093                 }
1094 #endif
1095         }
1096
1097         if (mtu <= 0) {
1098                 warnx("Bad MTU given.\n");
1099                 usage();
1100         }
1101         
1102         if(port < 1 || port > 65535) {
1103                 warnx("Bad port number given.\n");
1104                 usage();
1105         }
1106         
1107         if(bind_enable) {
1108                 if (bind_port < 1 || bind_port > 65535 || bind_port == port) {
1109                         warnx("Bad DNS server port number given.\n");
1110                         usage();
1111                         /* NOTREACHED */
1112                 }
1113                 fprintf(stderr, "Requests for domains outside of %s will be forwarded to port %d\n",
1114                         topdomain, bind_port);
1115         }
1116         
1117         if (port != 53) {
1118                 fprintf(stderr, "ALERT! Other dns servers expect you to run on port 53.\n");
1119                 fprintf(stderr, "You must manually forward port 53 to port %d for things to work.\n", port);
1120         }
1121
1122         if (debug) {
1123                 fprintf(stderr, "Debug level %d enabled, will stay in foreground.\n", debug);
1124                 fprintf(stderr, "Add more -D switches to set higher debug level.\n");
1125                 foreground = 1;
1126         }
1127
1128         if (listen_ip == INADDR_NONE) {
1129                 warnx("Bad IP address to listen on.\n");
1130                 usage();
1131         }
1132         
1133         if (ns_ip == INADDR_NONE) {
1134                 warnx("Bad IP address to return as nameserver.\n");
1135                 usage();
1136         }
1137         if (netmask > 30 || netmask < 8) {
1138                 warnx("Bad netmask (%d bits). Use 8-30 bits.\n", netmask);
1139                 usage();
1140         }
1141         
1142         if (strlen(password) == 0)
1143                 read_password(password, sizeof(password));
1144
1145         if ((tun_fd = open_tun(device)) == -1)
1146                 goto cleanup0;
1147         if (!skipipconfig)
1148                 if (tun_setip(argv[0], netmask) != 0 || tun_setmtu(mtu) != 0)
1149                         goto cleanup1;
1150         if ((dnsd_fd = open_dns(port, listen_ip)) == -1) 
1151                 goto cleanup2;
1152         if (bind_enable)
1153                 if ((bind_fd = open_dns(0, INADDR_ANY)) == -1)
1154                         goto cleanup3;
1155
1156         my_mtu = mtu;
1157
1158         created_users = init_users(my_ip, netmask);
1159         
1160         if (created_users < USERS) {
1161                 fprintf(stderr, "Limiting to %d simultaneous users because of netmask /%d\n",
1162                         created_users, netmask);
1163         }
1164         fprintf(stderr, "Listening to dns for domain %s\n", topdomain);
1165
1166         if (foreground == 0) 
1167                 do_detach();
1168         
1169         if (newroot != NULL)
1170                 do_chroot(newroot);
1171
1172         signal(SIGINT, sigint);
1173         if (username != NULL) {
1174 #ifndef WINDOWS32
1175                 gid_t gids[1];
1176                 gids[0] = pw->pw_gid;
1177                 if (setgroups(1, gids) < 0 || setgid(pw->pw_gid) < 0 || setuid(pw->pw_uid) < 0) {
1178                         warnx("Could not switch to user %s!\n", username);
1179                         usage();
1180                 }
1181 #endif
1182         }
1183
1184 #ifndef WINDOWS32
1185         openlog(__progname, LOG_NOWAIT, LOG_DAEMON);
1186 #endif
1187         syslog(LOG_INFO, "started, listening on port %d", port);
1188         
1189         tunnel(tun_fd, dnsd_fd, bind_fd);
1190
1191         syslog(LOG_INFO, "stopping");
1192 cleanup3:
1193         close_dns(bind_fd);
1194 cleanup2:
1195         close_dns(dnsd_fd);
1196 cleanup1:
1197         close_tun(tun_fd);      
1198 cleanup0:
1199
1200         return 0;
1201 }