[svn-upgrade] Integrating new upstream version, iodine (0.5.0)
[debian/iodine.git] / src / iodined.c
1 /*
2  * Copyright (c) 2006-2009 Bjorn Andersson <flex@kryo.se>, Erik Ekman <yarrick@kryo.se>
3  *
4  * Permission to use, copy, modify, and distribute this software for any
5  * purpose with or without fee is hereby granted, provided that the above
6  * copyright notice and this permission notice appear in all copies.
7  *
8  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
9  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
10  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
11  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
12  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
13  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
14  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
15  */
16
17 #include <stdio.h>
18 #include <stdint.h>
19 #include <stdlib.h>
20 #include <string.h>
21 #include <signal.h>
22 #include <unistd.h>
23 #include <sys/types.h>
24 #include <sys/param.h>
25 #include <sys/time.h>
26 #define _XPG4_2
27 #include <sys/socket.h>
28 #include <sys/uio.h>
29 #include <fcntl.h>
30 #include <err.h>
31 #include <grp.h>
32 #include <time.h>
33 #include <pwd.h>
34 #include <arpa/inet.h>
35 #include <netinet/in.h>
36 #include <netinet/in_systm.h>
37 #include <netinet/ip.h>
38 #include <zlib.h>
39 #include <arpa/nameser.h>
40 #ifdef DARWIN
41 #include <arpa/nameser8_compat.h>
42 #endif
43
44 #include "common.h"
45 #include "dns.h"
46 #include "encoding.h"
47 #include "base32.h"
48 #include "base64.h"
49 #include "user.h"
50 #include "login.h"
51 #include "tun.h"
52 #include "fw_query.h"
53 #include "version.h"
54
/*
 * Main-loop run flag. It is cleared from the SIGINT handler, so it must be a
 * volatile sig_atomic_t: a plain int written from signal context is not
 * guaranteed to be observed by (or even safely written for) the main loop.
 */
static volatile sig_atomic_t running = 1;
static char *topdomain;		/* names ending in this suffix are handled locally */
static char password[33];	/* login secret; presumably 32 chars + NUL — confirm */
static struct encoder *b32;	/* codec used to decode V/L/N/P request payloads */
static int created_users;	/* number of usable user slots (upper bound for ids) */

static int check_ip;		/* non-zero: require requests to come from the user's stored IP */
static int my_mtu;		/* sent to the client on successful login */
static in_addr_t my_ip;		/* server tunnel address, sent on login */
static int netmask;		/* tunnel netmask bits, sent on login */

static in_addr_t ns_ip;		/* if not INADDR_ANY, copied into NS replies' destination */

static int bind_port;		/* 127.0.0.1 port that non-tunnel queries are forwarded to */
static int debug;		/* verbosity: >= 1 packet trace, >= 2 DNS trace */

#if !defined(BSD) && !defined(__GLIBC__)
/* fallback definition where the C library does not provide __progname */
static char *__progname;
#endif

static int read_dns(int, struct query *);
static void write_dns(int, struct query *, char *, int);
77
78 static void
79 sigint(int sig) 
80 {
81         running = 0;
82 }
83
84 static int
85 check_user_and_ip(int userid, struct query *q)
86 {
87         struct sockaddr_in *tempin;
88
89         if (userid < 0 || userid >= created_users ) {
90                 return 1; 
91         }
92         if (!users[userid].active) {
93                 return 1;
94         }
95
96         /* return early if IP checking is disabled */
97         if (!check_ip) {
98                 return 0;
99         }
100
101         tempin = (struct sockaddr_in *) &(q->from);
102         return memcmp(&(users[userid].host), &(tempin->sin_addr), sizeof(struct in_addr));
103 }
104
105 static int
106 tunnel_tun(int tun_fd, int dns_fd)
107 {
108         unsigned long outlen;
109         struct ip *header;
110         char out[64*1024];
111         char in[64*1024];
112         int userid;
113         int read;
114
115         if ((read = read_tun(tun_fd, in, sizeof(in))) <= 0)
116                 return 0;
117         
118         /* find target ip in packet, in is padded with 4 bytes TUN header */
119         header = (struct ip*) (in + 4);
120         userid = find_user_by_ip(header->ip_dst.s_addr);
121         if (userid < 0)
122                 return 0;
123
124         outlen = sizeof(out);
125         compress2((uint8_t*)out, &outlen, (uint8_t*)in, read, 9);
126
127         /* if another packet is queued, throw away this one. TODO build queue */
128         if (users[userid].outpacket.len == 0) {
129                 memcpy(users[userid].outpacket.data, out, outlen);
130                 users[userid].outpacket.len = outlen;
131                 users[userid].outpacket.offset = 0;
132                 users[userid].outpacket.sentlen = 0;
133                 users[userid].outpacket.seqno = (++users[userid].outpacket.seqno & 7);
134                 users[userid].outpacket.fragment = 0;
135                 return outlen;
136         } else {
137                 return 0;
138         }
139 }
140
/* Outcome reported to a client's version ('V') handshake */
typedef enum {
	VERSION_ACK = 0,	/* version matched, user slot allocated ("VACK") */
	VERSION_NACK = 1,	/* version mismatch ("VNAK") */
	VERSION_FULL = 2	/* no free user slot ("VFUL") */
} version_ack_t;
146
147 static void
148 send_version_response(int fd, version_ack_t ack, uint32_t payload, int userid, struct query *q)
149 {
150         char out[9];
151         
152         switch (ack) {
153         case VERSION_ACK:
154                 strncpy(out, "VACK", sizeof(out));
155                 break;
156         case VERSION_NACK:
157                 strncpy(out, "VNAK", sizeof(out));
158                 break;
159         case VERSION_FULL:
160                 strncpy(out, "VFUL", sizeof(out));
161                 break;
162         }
163         
164         out[4] = ((payload >> 24) & 0xff);
165         out[5] = ((payload >> 16) & 0xff);
166         out[6] = ((payload >> 8) & 0xff);
167         out[7] = ((payload) & 0xff);
168         out[8] = userid & 0xff;
169
170         write_dns(fd, q, out, sizeof(out));
171 }
172
173 static void
174 send_chunk(int dns_fd, int userid) {
175         char pkt[4096];
176         int datalen;
177         int last;
178
179         datalen = MIN(users[userid].fragsize, users[userid].outpacket.len - users[userid].outpacket.offset);
180
181         if (datalen && users[userid].outpacket.sentlen > 0 && 
182                         (
183                         users[userid].outpacket.seqno != users[userid].out_acked_seqno ||
184                         users[userid].outpacket.fragment != users[userid].out_acked_fragment
185                         )
186                 ) {
187
188                 /* Still waiting on latest ack, send nothing */
189                 datalen = 0;
190                 last = 0;
191                 /* TODO : count down and discard packet if no acks arrive within X queries */
192         } else {
193                 memcpy(&pkt[2], &users[userid].outpacket.data[users[userid].outpacket.offset], datalen);
194                 users[userid].outpacket.sentlen = datalen;
195                 last = (users[userid].outpacket.len == users[userid].outpacket.offset + users[userid].outpacket.sentlen);
196
197                 /* Increase fragment# when sending data with offset */
198                 if (users[userid].outpacket.offset && datalen)
199                         users[userid].outpacket.fragment++;
200         }
201
202         /* Build downstream data header (see doc/proto_xxxxxxxx.txt) */
203
204         /* First byte is 1 bit compression flag, 3 bits upstream seqno, 4 bits upstream fragment */
205         pkt[0] = (1<<7) | ((users[userid].inpacket.seqno & 7) << 4) | (users[userid].inpacket.fragment & 15);
206         /* Second byte is 3 bits downstream seqno, 4 bits downstream fragment, 1 bit last flag */
207         pkt[1] = ((users[userid].outpacket.seqno & 7) << 5) | 
208                 ((users[userid].outpacket.fragment & 15) << 1) | (last & 1);
209
210         if (debug >= 1) {
211                 printf("OUT  pkt seq# %d, frag %d (last=%d), offset %d, fragsize %d, total %d, to user %d\n",
212                         users[userid].outpacket.seqno & 7, users[userid].outpacket.fragment & 15, 
213                         last, users[userid].outpacket.offset, datalen, users[userid].outpacket.len, userid);
214         }
215         write_dns(dns_fd, &users[userid].q, pkt, datalen + 2);
216         users[userid].q.id = 0;
217
218         if (users[userid].outpacket.len > 0 && 
219                 users[userid].outpacket.len == users[userid].outpacket.sentlen) {
220
221                 /* Whole packet was sent in one chunk, dont wait for ack */
222                 users[userid].outpacket.len = 0;
223                 users[userid].outpacket.offset = 0;
224                 users[userid].outpacket.sentlen = 0;
225         }
226 }
227
228 static void
229 update_downstream_seqno(int dns_fd, int userid, int down_seq, int down_frag)
230 {
231         /* If we just read a new packet from tun we have not sent a fragment of, just send it */
232         if (users[userid].outpacket.len > 0 && users[userid].outpacket.sentlen == 0) {
233                 send_chunk(dns_fd, userid);
234                 return;
235         }
236
237         /* otherwise, check if we received ack on a fragment and can send next */
238         if (users[userid].outpacket.len > 0 &&
239                 users[userid].outpacket.seqno == down_seq && users[userid].outpacket.fragment == down_frag) {
240
241                 if (down_seq != users[userid].out_acked_seqno || down_frag != users[userid].out_acked_fragment) {
242                         /* Received ACK on downstream fragment */
243                         users[userid].outpacket.offset += users[userid].outpacket.sentlen;
244                         users[userid].outpacket.sentlen = 0;
245
246                         /* Is packet done? */
247                         if (users[userid].outpacket.offset == users[userid].outpacket.len) {
248                                 users[userid].outpacket.len = 0;
249                                 users[userid].outpacket.offset = 0;
250                                 users[userid].outpacket.sentlen = 0;
251                         }
252
253                         users[userid].out_acked_seqno = down_seq;
254                         users[userid].out_acked_fragment = down_frag;
255
256                         /* Send reply if waiting */
257                         if (users[userid].outpacket.len > 0) {
258                                 send_chunk(dns_fd, userid);
259                         }
260                 }
261         }
262 }
263
264 static void
265 handle_null_request(int tun_fd, int dns_fd, struct query *q, int domain_len)
266 {
267         struct in_addr tempip;
268         struct ip *hdr;
269         unsigned long outlen;
270         char in[512];
271         char logindata[16];
272         char out[64*1024];
273         char unpacked[64*1024];
274         char *tmp[2];
275         int userid;
276         int touser;
277         int version;
278         int code;
279         int read;
280
281         userid = -1;
282
283         memcpy(in, q->name, MIN(domain_len, sizeof(in)));
284
285         if(in[0] == 'V' || in[0] == 'v') {
286                 read = unpack_data(unpacked, sizeof(unpacked), &(in[1]), domain_len - 1, b32);
287                 /* Version greeting, compare and send ack/nak */
288                 if (read > 4) { 
289                         /* Received V + 32bits version */
290                         version = (((unpacked[0] & 0xff) << 24) |
291                                            ((unpacked[1] & 0xff) << 16) |
292                                            ((unpacked[2] & 0xff) << 8) |
293                                            ((unpacked[3] & 0xff)));
294                 }
295
296                 if (version == VERSION) {
297                         userid = find_available_user();
298                         if (userid >= 0) {
299                                 struct sockaddr_in *tempin;
300
301                                 users[userid].seed = rand();
302                                 /* Store remote IP number */
303                                 tempin = (struct sockaddr_in *) &(q->from);
304                                 memcpy(&(users[userid].host), &(tempin->sin_addr), sizeof(struct in_addr));
305                                 
306                                 memcpy(&(users[userid].q), q, sizeof(struct query));
307                                 users[userid].encoder = get_base32_encoder();
308                                 send_version_response(dns_fd, VERSION_ACK, users[userid].seed, userid, q);
309                                 users[userid].q.id = 0;
310                         } else {
311                                 /* No space for another user */
312                                 send_version_response(dns_fd, VERSION_FULL, created_users, 0, q);
313                         }
314                 } else {
315                         send_version_response(dns_fd, VERSION_NACK, VERSION, 0, q);
316                 }
317                 return;
318         } else if(in[0] == 'L' || in[0] == 'l') {
319                 read = unpack_data(unpacked, sizeof(unpacked), &(in[1]), domain_len - 1, b32);
320                 /* Login phase, handle auth */
321                 userid = unpacked[0];
322
323                 if (check_user_and_ip(userid, q) != 0) {
324                         write_dns(dns_fd, q, "BADIP", 5);
325                         return;
326                 } else {
327                         users[userid].last_pkt = time(NULL);
328                         login_calculate(logindata, 16, password, users[userid].seed);
329
330                         if (read >= 18 && (memcmp(logindata, unpacked+1, 16) == 0)) {
331                                 /* Login ok, send ip/mtu/netmask info */
332
333                                 tempip.s_addr = my_ip;
334                                 tmp[0] = strdup(inet_ntoa(tempip));
335                                 tempip.s_addr = users[userid].tun_ip;
336                                 tmp[1] = strdup(inet_ntoa(tempip));
337
338                                 read = snprintf(out, sizeof(out), "%s-%s-%d-%d", 
339                                                 tmp[0], tmp[1], my_mtu, netmask);
340
341                                 write_dns(dns_fd, q, out, read);
342                                 q->id = 0;
343
344                                 free(tmp[1]);
345                                 free(tmp[0]);
346                         } else {
347                                 write_dns(dns_fd, q, "LNAK", 4);
348                         }
349                 }
350                 return;
351         } else if(in[0] == 'Z' || in[0] == 'z') {
352                 /* Check for case conservation and chars not allowed according to RFC */
353
354                 /* Reply with received hostname as data */
355                 write_dns(dns_fd, q, in, domain_len);
356                 return;
357         } else if(in[0] == 'S' || in[0] == 's') {
358                 int codec;
359                 struct encoder *enc;
360                 if (domain_len != 4) { /* len = 4, example: "S15." */
361                         write_dns(dns_fd, q, "BADLEN", 6);
362                         return;
363                 }
364
365                 userid = b32_8to5(in[1]);
366                 
367                 if (check_user_and_ip(userid, q) != 0) {
368                         write_dns(dns_fd, q, "BADIP", 5);
369                         return; /* illegal id */
370                 }
371                 
372                 codec = b32_8to5(in[2]);
373
374                 switch (codec) {
375                 case 5: /* 5 bits per byte = base32 */
376                         enc = get_base32_encoder();
377                         user_switch_codec(userid, enc);
378                         write_dns(dns_fd, q, enc->name, strlen(enc->name));
379                         break;
380                 case 6: /* 6 bits per byte = base64 */
381                         enc = get_base64_encoder();
382                         user_switch_codec(userid, enc);
383                         write_dns(dns_fd, q, enc->name, strlen(enc->name));
384                         break;
385                 default:
386                         write_dns(dns_fd, q, "BADCODEC", 8);
387                         break;
388                 }
389                 return;
390         } else if(in[0] == 'R' || in[0] == 'r') {
391                 int req_frag_size;
392
393                 /* Downstream fragsize probe packet */
394                 userid = (b32_8to5(in[1]) >> 1) & 15;
395                 if (check_user_and_ip(userid, q) != 0) {
396                         write_dns(dns_fd, q, "BADIP", 5);
397                         return; /* illegal id */
398                 }
399                                 
400                 req_frag_size = ((b32_8to5(in[1]) & 1) << 10) | ((b32_8to5(in[2]) & 31) << 5) | (b32_8to5(in[3]) & 31);
401                 if (req_frag_size < 2 || req_frag_size > 2047) {        
402                         write_dns(dns_fd, q, "BADFRAG", 7);
403                 } else {
404                         char buf[2048];
405
406                         memset(buf, 0, sizeof(buf));
407                         buf[0] = (req_frag_size >> 8) & 0xff;
408                         buf[1] = req_frag_size & 0xff;
409                         write_dns(dns_fd, q, buf, req_frag_size);
410                 }
411                 return;
412         } else if(in[0] == 'N' || in[0] == 'n') {
413                 int max_frag_size;
414
415                 read = unpack_data(unpacked, sizeof(unpacked), &(in[1]), domain_len - 1, b32);
416                 /* Downstream fragsize packet */
417                 userid = unpacked[0];
418                 if (check_user_and_ip(userid, q) != 0) {
419                         write_dns(dns_fd, q, "BADIP", 5);
420                         return; /* illegal id */
421                 }
422                                 
423                 max_frag_size = ((unpacked[1] & 0xff) << 8) | (unpacked[2] & 0xff);
424                 if (max_frag_size < 2) {        
425                         write_dns(dns_fd, q, "BADFRAG", 7);
426                 } else {
427                         users[userid].fragsize = max_frag_size;
428                         write_dns(dns_fd, q, &unpacked[1], 2);
429                 }
430                 return;
431         } else if(in[0] == 'P' || in[0] == 'p') {
432                 int dn_seq;
433                 int dn_frag;
434                 
435                 read = unpack_data(unpacked, sizeof(unpacked), &(in[1]), domain_len - 1, b32);
436                 /* Ping packet, store userid */
437                 userid = unpacked[0];
438                 if (check_user_and_ip(userid, q) != 0) {
439                         write_dns(dns_fd, q, "BADIP", 5);
440                         return; /* illegal id */
441                 }
442                                 
443                 if (debug >= 1) {
444                         printf("PING pkt from user %d\n", userid);
445                 }
446
447                 if (users[userid].q.id != 0) {
448                         /* Send reply on earlier query before overwriting */
449                         send_chunk(dns_fd, userid);
450                 }
451
452                 dn_seq = unpacked[1] >> 4;
453                 dn_frag = unpacked[1] & 15;
454                 memcpy(&(users[userid].q), q, sizeof(struct query));
455                 users[userid].last_pkt = time(NULL);
456
457                 /* Update seqno and maybe send immediate response packet */
458                 update_downstream_seqno(dns_fd, userid, dn_seq, dn_frag);
459         } else if((in[0] >= '0' && in[0] <= '9')
460                         || (in[0] >= 'a' && in[0] <= 'f')
461                         || (in[0] >= 'A' && in[0] <= 'F')) {
462                 if ((in[0] >= '0' && in[0] <= '9'))
463                         code = in[0] - '0';
464                 if ((in[0] >= 'a' && in[0] <= 'f'))
465                         code = in[0] - 'a' + 10;
466                 if ((in[0] >= 'A' && in[0] <= 'F'))
467                         code = in[0] - 'A' + 10;
468
469                 userid = code;
470                 /* Check user and sending ip number */
471                 if (check_user_and_ip(userid, q) != 0) {
472                         write_dns(dns_fd, q, "BADIP", 5);
473                 } else {
474                         /* Decode data header */
475                         int up_seq = (b32_8to5(in[1]) >> 2) & 7;
476                         int up_frag = ((b32_8to5(in[1]) & 3) << 2) | ((b32_8to5(in[2]) >> 3) & 3);
477                         int dn_seq = (b32_8to5(in[2]) & 7);
478                         int dn_frag = b32_8to5(in[3]) >> 1;
479                         int lastfrag = b32_8to5(in[3]) & 1;
480
481                         if (users[userid].q.id != 0) {
482                                 /* Send reply on earlier query before overwriting */
483                                 send_chunk(dns_fd, userid);
484                         }
485
486                         /* Update query and time info for user */
487                         users[userid].last_pkt = time(NULL);
488                         memcpy(&(users[userid].q), q, sizeof(struct query));
489
490                         if (up_seq == users[userid].inpacket.seqno && 
491                                 up_frag <= users[userid].inpacket.fragment) {
492                                 /* Got repeated old packet, skip it */
493                                 if (debug >= 1) {
494                                         printf("IN   pkt seq# %d, frag %d, dropped duplicate\n",
495                                                 up_seq, up_frag);
496                                 }
497                                 /* Update seqno and maybe send immediate response packet */
498                                 update_downstream_seqno(dns_fd, userid, dn_seq, dn_frag);
499                                 return;
500                         }
501                         if (up_seq != users[userid].inpacket.seqno) {
502                                 /* New packet has arrived */
503                                 users[userid].inpacket.seqno = up_seq;
504                                 users[userid].inpacket.len = 0;
505                                 users[userid].inpacket.offset = 0;
506                         }
507                         users[userid].inpacket.fragment = up_frag;
508
509                         /* decode with this users encoding */
510                         read = unpack_data(unpacked, sizeof(unpacked), &(in[4]), domain_len - 4, 
511                                            users[userid].encoder);
512
513                         /* copy to packet buffer, update length */
514                         memcpy(users[userid].inpacket.data + users[userid].inpacket.offset, unpacked, read);
515                         users[userid].inpacket.len += read;
516                         users[userid].inpacket.offset += read;
517
518                         if (debug >= 1) {
519                                 printf("IN   pkt seq# %d, frag %d (last=%d), fragsize %d, total %d, from user %d\n",
520                                         up_seq, up_frag, lastfrag, read, users[userid].inpacket.len, userid);
521                         }
522
523                         if (lastfrag & 1) { /* packet is complete */
524                                 int ret;
525                                 outlen = sizeof(out);
526                                 ret = uncompress((uint8_t*)out, &outlen, 
527                                            (uint8_t*)users[userid].inpacket.data, users[userid].inpacket.len);
528
529                                 if (ret == Z_OK) {
530                                         hdr = (struct ip*) (out + 4);
531                                         touser = find_user_by_ip(hdr->ip_dst.s_addr);
532
533                                         if (touser == -1) {
534                                                 /* send the uncompressed packet to tun device */
535                                                 write_tun(tun_fd, out, outlen);
536                                         } else {
537                                                 /* send the compressed packet to other client
538                                                  * if another packet is queued, throw away this one. TODO build queue */
539                                                 if (users[touser].outpacket.len == 0) {
540                                                         memcpy(users[touser].outpacket.data, users[userid].inpacket.data, users[userid].inpacket.len);
541                                                         users[touser].outpacket.len = users[userid].inpacket.len;
542                                                 }
543                                         }
544                                 } else {
545                                         printf("Discarded data, uncompress() result: %d\n", ret);
546                                 }
547                                 users[userid].inpacket.len = users[userid].inpacket.offset = 0;
548                         }
549                         /* Update seqno and maybe send immediate response packet */
550                         update_downstream_seqno(dns_fd, userid, dn_seq, dn_frag);
551                 }
552         }
553 }
554
555 static void
556 handle_ns_request(int dns_fd, struct query *q)
557 {
558         char buf[64*1024];
559         int len;
560
561         if (ns_ip != INADDR_ANY) {
562                 memcpy(&q->destination.s_addr, &ns_ip, sizeof(in_addr_t));
563         }
564
565         len = dns_encode_ns_response(buf, sizeof(buf), q, topdomain);
566         
567         if (debug >= 2) {
568                 struct sockaddr_in *tempin;
569                 tempin = (struct sockaddr_in *) &(q->from);
570                 printf("TX: client %s, type %d, name %s, %d bytes NS reply\n", 
571                         inet_ntoa(tempin->sin_addr), q->type, q->name, len);
572         }
573         if (sendto(dns_fd, buf, len, 0, (struct sockaddr*)&q->from, q->fromlen) <= 0) {
574                 warn("ns reply send error");
575         }
576 }
577
578 static void
579 forward_query(int bind_fd, struct query *q)
580 {
581         char buf[64*1024];
582         int len;
583         struct fw_query fwq;
584         struct sockaddr_in *myaddr;
585         in_addr_t newaddr;
586
587         len = dns_encode(buf, sizeof(buf), q, QR_QUERY, q->name, strlen(q->name));
588
589         /* Store sockaddr for q->id */
590         memcpy(&(fwq.addr), &(q->from), q->fromlen);
591         fwq.addrlen = q->fromlen;
592         fwq.id = q->id;
593         fw_query_put(&fwq);
594
595         newaddr = inet_addr("127.0.0.1");
596         myaddr = (struct sockaddr_in *) &(q->from);
597         memcpy(&(myaddr->sin_addr), &newaddr, sizeof(in_addr_t));
598         myaddr->sin_port = htons(bind_port);
599         
600         if (debug >= 2) {
601                 printf("TX: NS reply \n");
602         }
603
604         if (sendto(bind_fd, buf, len, 0, (struct sockaddr*)&q->from, q->fromlen) <= 0) {
605                 warn("forward query error");
606         }
607 }
608   
609 static int
610 tunnel_bind(int bind_fd, int dns_fd)
611 {
612         char packet[64*1024];
613         struct sockaddr_in from;
614         socklen_t fromlen;
615         struct fw_query *query;
616         short id;
617         int r;
618
619         fromlen = sizeof(struct sockaddr);
620         r = recvfrom(bind_fd, packet, sizeof(packet), 0, 
621                 (struct sockaddr*)&from, &fromlen);
622
623         if (r <= 0)
624                 return 0;
625
626         id = dns_get_id(packet, r);
627         
628         if (debug >= 2) {
629                 printf("RX: Got response on query %u from DNS\n", (id & 0xFFFF));
630         }
631
632         /* Get sockaddr from id */
633         fw_query_get(id, &query);
634         if (!query && debug >= 2) {
635                 printf("Lost sender of id %u, dropping reply\n", (id & 0xFFFF));
636                 return 0;
637         }
638
639         if (debug >= 2) {
640                 struct sockaddr_in *in;
641                 in = (struct sockaddr_in *) &(query->addr);
642                 printf("TX: client %s id %u, %d bytes\n",
643                         inet_ntoa(in->sin_addr), (id & 0xffff), r);
644         }
645         
646         if (sendto(dns_fd, packet, r, 0, (const struct sockaddr *) &(query->addr), 
647                 query->addrlen) <= 0) {
648                 warn("forward reply error");
649         }
650
651         return 0;
652 }
653
654 static int
655 tunnel_dns(int tun_fd, int dns_fd, int bind_fd)
656 {
657         struct query q;
658         int read;
659         char *domain;
660         int domain_len;
661         int inside_topdomain;
662
663         if ((read = read_dns(dns_fd, &q)) <= 0)
664                 return 0;
665
666         if (debug >= 2) {
667                 struct sockaddr_in *tempin;
668                 tempin = (struct sockaddr_in *) &(q.from);
669                 printf("RX: client %s, type %d, name %s\n", 
670                         inet_ntoa(tempin->sin_addr), q.type, q.name);
671         }
672         
673         domain = strstr(q.name, topdomain);
674         inside_topdomain = 0;
675         if (domain) {
676                 domain_len = (int) (domain - q.name); 
677                 if (domain_len + strlen(topdomain) == strlen(q.name)) {
678                         inside_topdomain = 1;
679                 }
680         }
681         
682         if (inside_topdomain) {
683                 /* This is a query we can handle */
684                 switch (q.type) {
685                 case T_NULL:
686                         handle_null_request(tun_fd, dns_fd, &q, domain_len);
687                         break;
688                 case T_NS:
689                         handle_ns_request(dns_fd, &q);
690                         break;
691                 default:
692                         break;
693                 }
694         } else {
695                 /* Forward query to other port ? */
696                 if (bind_fd) {
697                         forward_query(bind_fd, &q);
698                 }
699         }
700         return 0;
701 }
702
703 static int
704 tunnel(int tun_fd, int dns_fd, int bind_fd)
705 {
706         struct timeval tv;
707         fd_set fds;
708         int i;
709
710         while (running) {
711                 int maxfd;
712                 if (users_waiting_on_reply()) {
713                         tv.tv_sec = 0;
714                         tv.tv_usec = 15000;
715                 } else {
716                         tv.tv_sec = 1;
717                         tv.tv_usec = 0;
718                 }
719
720                 FD_ZERO(&fds);
721
722                 FD_SET(dns_fd, &fds);
723                 maxfd = dns_fd;
724
725                 if (bind_fd) {
726                         /* wait for replies from real DNS */
727                         FD_SET(bind_fd, &fds);
728                         maxfd = MAX(bind_fd, maxfd);
729                 }
730
731                 /* TODO : use some kind of packet queue */
732                 if(!all_users_waiting_to_send()) {
733                         FD_SET(tun_fd, &fds);
734                         maxfd = MAX(tun_fd, maxfd);
735                 }
736
737                 i = select(maxfd + 1, &fds, NULL, NULL, &tv);
738                 
739                 if(i < 0) {
740                         if (running) 
741                                 warn("select");
742                         return 1;
743                 }
744
745                 if (i==0) {     
746                         int j;
747                         for (j = 0; j < USERS; j++) {
748                                 if (users[j].q.id != 0) {
749                                         send_chunk(dns_fd, j);
750                                 }
751                         }
752                 } else {
753                         if(FD_ISSET(tun_fd, &fds)) {
754                                 tunnel_tun(tun_fd, dns_fd);
755                                 continue;
756                         }
757                         if(FD_ISSET(dns_fd, &fds)) {
758                                 tunnel_dns(tun_fd, dns_fd, bind_fd);
759                                 continue;
760                         } 
761                         if(FD_ISSET(bind_fd, &fds)) {
762                                 tunnel_bind(bind_fd, dns_fd);
763                                 continue;
764                         }
765                 }
766         }
767
768         return 0;
769 }
770
771 static int
772 read_dns(int fd, struct query *q)
773 {
774         struct sockaddr_in from;
775         socklen_t addrlen;
776         char packet[64*1024];
777         char address[96];
778         struct msghdr msg;
779         struct iovec iov;
780         struct cmsghdr *cmsg;
781         int r;
782
783         addrlen = sizeof(struct sockaddr);
784         iov.iov_base = packet;
785         iov.iov_len = sizeof(packet);
786
787         msg.msg_name = (caddr_t) &from;
788         msg.msg_namelen = (unsigned) addrlen;
789         msg.msg_iov = &iov;
790         msg.msg_iovlen = 1;
791         msg.msg_control = address;
792         msg.msg_controllen = sizeof(address);
793         msg.msg_flags = 0;
794         
795         r = recvmsg(fd, &msg, 0);
796
797         if (r > 0) {
798                 dns_decode(NULL, 0, q, QR_QUERY, packet, r);
799                 memcpy((struct sockaddr*)&q->from, (struct sockaddr*)&from, addrlen);
800                 q->fromlen = addrlen;
801                 
802                 for (cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL; 
803                         cmsg = CMSG_NXTHDR(&msg, cmsg)) { 
804                         
805                         if (cmsg->cmsg_level == IPPROTO_IP && 
806                                 cmsg->cmsg_type == DSTADDR_SOCKOPT) { 
807                                 
808                                 q->destination = *dstaddr(cmsg); 
809                                 break;
810                         } 
811                 }
812
813                 return strlen(q->name);
814         } else if (r < 0) { 
815                 /* Error */
816                 warn("read dns");
817         }
818
819         return 0;
820 }
821
822 static void
823 write_dns(int fd, struct query *q, char *data, int datalen)
824 {
825         char buf[64*1024];
826         int len;
827
828         len = dns_encode(buf, sizeof(buf), q, QR_ANSWER, data, datalen);
829         
830         if (debug >= 2) {
831                 struct sockaddr_in *tempin;
832                 tempin = (struct sockaddr_in *) &(q->from);
833                 printf("TX: client %s, type %d, name %s, %d bytes data\n", 
834                         inet_ntoa(tempin->sin_addr), q->type, q->name, datalen);
835         }
836
837         sendto(fd, buf, len, 0, (struct sockaddr*)&q->from, q->fromlen);
838 }
839
/*
 * Print the command-line synopsis and exit with status 2.
 * It is only used for invalid arguments or failed setup checks. The text
 * therefore goes to stderr, next to the warnx() diagnostic that usually
 * precedes it, rather than polluting stdout.
 */
static void
usage() {
	extern char *__progname;

	fprintf(stderr, "Usage: %s [-v] [-h] [-c] [-s] [-f] [-D] [-u user] "
		"[-t chrootdir] [-d device] [-m mtu] "
		"[-l ip address to listen on] [-p port] [-n external ip] [-b dnsport] [-P password]"
		" tunnel_ip[/netmask] topdomain\n", __progname);
	exit(2);
}
850
/*
 * Print the program banner, the usage synopsis and a description of
 * every option to stdout, then exit with status 0.
 */
static void
help() {
	extern char *__progname;
	/* Option descriptions, emitted verbatim in this order. */
	static const char *const help_lines[] = {
		"  -v to print version info and exit\n",
		"  -h to print this help and exit\n",
		"  -c to disable check of client IP/port on each request\n",
		"  -s to skip creating and configuring the tun device, "
			"which then has to be created manually\n",
		"  -f to keep running in foreground\n",
		"  -D to increase debug level\n",
		"  -u name to drop privileges and run as user 'name'\n",
		"  -t dir to chroot to directory dir\n",
		"  -d device to set tunnel device name\n",
		"  -m mtu to set tunnel device mtu\n",
		"  -l ip address to listen on for incoming dns traffic "
			"(default 0.0.0.0)\n",
		"  -p port to listen on for incoming dns traffic (default 53)\n",
		"  -n ip to respond with to NS queries\n",
		"  -b port to forward normal DNS queries to (on localhost)\n",
		"  -P password used for authentication (max 32 chars will be used)\n",
		"tunnel_ip is the IP number of the local tunnel interface.\n",
		"   /netmask sets the size of the tunnel network.\n",
		"topdomain is the FQDN that is delegated to this server.\n"
	};
	size_t line;

	fputs("iodine IP over DNS tunneling server\n", stdout);
	printf("Usage: %s [-v] [-h] [-c] [-s] [-f] [-D] [-u user] "
		"[-t chrootdir] [-d device] [-m mtu] "
		"[-l ip address to listen on] [-p port] [-n external ip] [-b dnsport] [-P password]"
		" tunnel_ip[/netmask] topdomain\n", __progname);

	for (line = 0; line < sizeof(help_lines) / sizeof(help_lines[0]); line++)
		fputs(help_lines[line], stdout);

	exit(0);
}
882
/*
 * Print the program name and the hard-coded release string to stdout,
 * then exit with status 0.
 */
static void
version() {
	fputs("iodine IP over DNS tunneling server\n", stdout);
	fputs("version: 0.5.0 from 2009-01-23\n", stdout);
	exit(0);
}
889
890 int
891 main(int argc, char **argv)
892 {
893         in_addr_t listen_ip;
894         struct passwd *pw;
895         int foreground;
896         char *username;
897         char *newroot;
898         char *device;
899         int dnsd_fd;
900         int tun_fd;
901
902         /* settings for forwarding normal DNS to 
903          * local real DNS server */
904         int bind_fd;
905         int bind_enable;
906         
907         int choice;
908         int port;
909         int mtu;
910         int skipipconfig;
911         char *netsize;
912
913         username = NULL;
914         newroot = NULL;
915         device = NULL;
916         foreground = 0;
917         bind_enable = 0;
918         bind_fd = 0;
919         mtu = 1024;
920         listen_ip = INADDR_ANY;
921         port = 53;
922         ns_ip = INADDR_ANY;
923         check_ip = 1;
924         skipipconfig = 0;
925         debug = 0;
926         netmask = 27;
927
928         b32 = get_base32_encoder();
929
930 #if !defined(BSD) && !defined(__GLIBC__)
931         __progname = strrchr(argv[0], '/');
932         if (__progname == NULL)
933                 __progname = argv[0];
934         else
935                 __progname++;
936 #endif
937
938         memset(password, 0, sizeof(password));
939         srand(time(NULL));
940         fw_query_init();
941         
942         while ((choice = getopt(argc, argv, "vcsfhDu:t:d:m:l:p:n:b:P:")) != -1) {
943                 switch(choice) {
944                 case 'v':
945                         version();
946                         break;
947                 case 'c':
948                         check_ip = 0;
949                         break;
950                 case 's':
951                         skipipconfig = 1;
952                         break;
953                 case 'f':
954                         foreground = 1;
955                         break;
956                 case 'h':
957                         help();
958                         break;
959                 case 'D':
960                         debug++;
961                         break;
962                 case 'u':
963                         username = optarg;
964                         break;
965                 case 't':
966                         newroot = optarg;
967                         break;
968                 case 'd':
969                         device = optarg;
970                         break;
971                 case 'm':
972                         mtu = atoi(optarg);
973                         break;
974                 case 'l':
975                         listen_ip = inet_addr(optarg);
976                         break;
977                 case 'p':
978                         port = atoi(optarg);
979                         break;
980                 case 'n':
981                         ns_ip = inet_addr(optarg);
982                         break;
983                 case 'b':
984                         bind_enable = 1;
985                         bind_port = atoi(optarg);
986                         break;
987                 case 'P':
988                         strncpy(password, optarg, sizeof(password));
989                         password[sizeof(password)-1] = 0;
990                         
991                         /* XXX: find better way of cleaning up ps(1) */
992                         memset(optarg, 0, strlen(optarg)); 
993                         break;
994                 default:
995                         usage();
996                         break;
997                 }
998         }
999
1000         argc -= optind;
1001         argv += optind;
1002
1003         if (geteuid() != 0) {
1004                 warnx("Run as root and you'll be happy.\n");
1005                 usage();
1006         }
1007
1008         if (argc != 2) 
1009                 usage();
1010         
1011         netsize = strchr(argv[0], '/');
1012         if (netsize) {
1013                 *netsize = 0;
1014                 netsize++;
1015                 netmask = atoi(netsize);
1016         }
1017
1018         my_ip = inet_addr(argv[0]);
1019         
1020         if (my_ip == INADDR_NONE) {
1021                 warnx("Bad IP address to use inside tunnel.\n");
1022                 usage();
1023         }
1024
1025         topdomain = strdup(argv[1]);
1026         if(strlen(topdomain) <= 128) {
1027                 if(check_topdomain(topdomain)) {
1028                         warnx("Topdomain contains invalid characters.\n");
1029                         usage();
1030                 }
1031         } else {
1032                 warnx("Use a topdomain max 128 chars long.\n");
1033                 usage();
1034         }
1035
1036         if (username != NULL) {
1037                 if ((pw = getpwnam(username)) == NULL) {
1038                         warnx("User %s does not exist!\n", username);
1039                         usage();
1040                 }
1041         }
1042
1043         if (mtu <= 0) {
1044                 warnx("Bad MTU given.\n");
1045                 usage();
1046         }
1047         
1048         if(port < 1 || port > 65535) {
1049                 warnx("Bad port number given.\n");
1050                 usage();
1051         }
1052         
1053         if(bind_enable) {
1054                 if (bind_port < 1 || bind_port > 65535 || bind_port == port) {
1055                         warnx("Bad DNS server port number given.\n");
1056                         usage();
1057                         /* NOTREACHED */
1058                 }
1059                 printf("Requests for domains outside of %s will be forwarded to port %d\n",
1060                         topdomain, bind_port);
1061         }
1062         
1063         if (port != 53) {
1064                 printf("ALERT! Other dns servers expect you to run on port 53.\n");
1065                 printf("You must manually forward port 53 to port %d for things to work.\n", port);
1066         }
1067
1068         if (debug) {
1069                 printf("Debug level %d enabled, will stay in foreground.\n", debug);
1070                 printf("Add more -D switches to set higher debug level.\n");
1071                 foreground = 1;
1072         }
1073
1074         if (listen_ip == INADDR_NONE) {
1075                 warnx("Bad IP address to listen on.\n");
1076                 usage();
1077         }
1078         
1079         if (ns_ip == INADDR_NONE) {
1080                 warnx("Bad IP address to return as nameserver.\n");
1081                 usage();
1082         }
1083         if (netmask > 30 || netmask < 8) {
1084                 warnx("Bad netmask (%d bits). Use 8-30 bits.\n", netmask);
1085                 usage();
1086         }
1087         
1088         if (strlen(password) == 0)
1089                 read_password(password, sizeof(password));
1090
1091         if ((tun_fd = open_tun(device)) == -1)
1092                 goto cleanup0;
1093         if (!skipipconfig)
1094                 if (tun_setip(argv[0], netmask) != 0 || tun_setmtu(mtu) != 0)
1095                         goto cleanup1;
1096         if ((dnsd_fd = open_dns(port, listen_ip)) == -1) 
1097                 goto cleanup2;
1098         if (bind_enable)
1099                 if ((bind_fd = open_dns(0, INADDR_ANY)) == -1)
1100                         goto cleanup3;
1101
1102         my_mtu = mtu;
1103
1104         created_users = init_users(my_ip, netmask);
1105         
1106         if (created_users < USERS) {
1107                 printf("Limiting to %d simultaneous users because of netmask /%d\n",
1108                         created_users, netmask);
1109         }
1110         printf("Listening to dns for domain %s\n", topdomain);
1111
1112         if (foreground == 0) 
1113                 do_detach();
1114         
1115         if (newroot != NULL)
1116                 do_chroot(newroot);
1117
1118         signal(SIGINT, sigint);
1119         if (username != NULL) {
1120                 gid_t gids[1];
1121                 gids[0] = pw->pw_gid;
1122                 if (setgroups(1, gids) < 0 || setgid(pw->pw_gid) < 0 || setuid(pw->pw_uid) < 0) {
1123                         warnx("Could not switch to user %s!\n", username);
1124                         usage();
1125                 }
1126         }
1127         
1128         tunnel(tun_fd, dnsd_fd, bind_fd);
1129
1130 cleanup3:
1131         close_dns(bind_fd);
1132 cleanup2:
1133         close_dns(dnsd_fd);
1134 cleanup1:
1135         close_tun(tun_fd);      
1136 cleanup0:
1137
1138         return 0;
1139 }