diff --git a/examples/dns.c b/examples/dns.c index e6dee66..d5ee585 100644 --- a/examples/dns.c +++ b/examples/dns.c @@ -1,3 +1,6 @@ +#include +#include +#include #include "../src/hammer.h" #include "dns_common.h" #include "dns.h" @@ -41,15 +44,16 @@ struct dns_qname get_qname(const HParsedToken *t) { const HParsedToken *labels = t->seq->elements[0]; struct dns_qname ret = { .qlen = labels->seq->used, - .labels = h_arena_malloc(t->seq->arena, sizeof(ret.labels)*labels->seq->used) + .labels = h_arena_malloc(t->seq->arena, sizeof(*ret.labels)*labels->seq->used) }; // i is which label we're on for (size_t i=0; iseq->used; ++i) { ret.labels[i].len = labels->seq->elements[i]->seq->used; - ret.labels[i].label = h_arena_malloc(t->seq->arena, sizeof(uint8_t)*ret.labels[i].len); + ret.labels[i].label = h_arena_malloc(t->seq->arena, ret.labels[i].len + 1); // j is which char of the label we're on for (size_t j=0; jseq->elements[i]->seq->elements[j]->uint; + ret.labels[i].label[ret.labels[i].len] = 0; } return ret; } @@ -279,10 +283,11 @@ void set_rr(struct dns_rr rr, HCountedArray *rdata) { } const HParsedToken* pack_dns_struct(const HParseResult *p) { - HParsedToken *ret = h_arena_malloc(p->arena, sizeof(HParsedToken*)); + h_pprint(stdout, p->ast, 0, 2); + HParsedToken *ret = h_arena_malloc(p->arena, sizeof(HParsedToken)); ret->token_type = TT_USER; - dns_message_t *msg = h_arena_malloc(p->arena, sizeof(dns_message_t*)); + dns_message_t *msg = h_arena_malloc(p->arena, sizeof(dns_message_t)); HParsedToken *hdr = p->ast->seq->elements[0]; struct dns_header header = { @@ -388,7 +393,7 @@ const HParser* init_parser() { h_int_range(h_uint16(), 255, 255), NULL); - const HParser *dns_question = h_sequence(h_sequence(h_many1(h_length_value(h_uint8(), + const HParser *dns_question = h_sequence(h_sequence(h_many1(h_length_value(h_int_range(h_uint8(), 1, 255), h_uint8())), h_ch('\x00'), NULL), // QNAME @@ -405,16 +410,128 @@ const HParser* init_parser() { NULL); - dns_message = (HParser*)h_attr_bool(h_sequence(dns_header, - h_many(dns_question), - h_many(dns_rr), - h_end_p(), - NULL), - validate_dns); + dns_message = (HParser*)h_action(h_attr_bool(h_sequence(dns_header, + h_many(dns_question), + h_many(dns_rr), + h_end_p(), + NULL), + validate_dns), + pack_dns_struct); return dns_message; } +int start_listening() { + // return: fd + int sock; + struct sockaddr_in addr; + + sock = socket(PF_INET, SOCK_DGRAM, 0); + if (sock < 0) + err(1, "Failed to open listning socket"); + addr.sin_family = AF_INET; + addr.sin_port = htons(53); + addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + int optval = 1; + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval)); + if (bind(sock, (struct sockaddr*)&addr, sizeof(addr)) < 0) + err(1, "Bind failed"); + return sock; +} + +const int TYPE_MAX = 16; +typedef const char* cstr; +const char* TYPE_STR[17] = { + "nil", "A", "NS", "MD", + "MF", "CNAME", "SOA", "MB", + "MG", "MR", "NULL", "WKS", + "PTR", "HINFO", "MINFO", "MX", + "TXT" +}; + +const int CLASS_MAX = 4; +const char* CLASS_STR[5] = { + "nil", "IN", "CS", "CH", "HS" +}; + + +void format_qname(struct dns_qname *name, uint8_t **dest) { + uint8_t *rp = *dest; + for (size_t j = 0; j < name->qlen; j++) { + *rp++ = name->labels[j].len; + for (size_t k = 0; k < name->labels[j].len; k++) + *rp++ = name->labels[j].label[k]; + } + *rp++ = 0; + *dest = rp; +} + + int main(int argc, char** argv) { + const HParser *parser = init_parser(); + + + // set up a listening socket... + int sock = start_listening(); + + uint8_t packet[8192]; // static buffer for simplicity + ssize_t packet_size; + struct sockaddr_in remote; + socklen_t remote_len; + + + while (1) { + remote_len = sizeof(remote); + packet_size = recvfrom(sock, packet, sizeof(packet), 0, (struct sockaddr*)&remote, &remote_len); + // dump the packet... + for (int i = 0; i < packet_size; i++) + printf(".%02hhx", packet[i]); + + printf("\n"); + + HParseResult *content = h_parse(parser, packet, packet_size); + if (!content) { + printf("Invalid packet; ignoring\n"); + continue; + } + dns_message_t *message = content->ast->user; + (void)message; + for (size_t i = 0; i < message->header.question_count; i++) { + struct dns_question *question = &message->questions[i]; + printf("Recieved %s %s request for ", CLASS_STR[question->qclass], TYPE_STR[question->qtype]); + for (size_t j = 0; j < question->qname.qlen; j++) + printf("%s.", question->qname.labels[j].label); + printf("\n"); + + } + printf("%p\n", content); + + + // Traditional response for this time of year... + uint8_t response_buf[4096]; + uint8_t *rp = response_buf; + // write out header... + *rp++ = message->header.id >> 8; + *rp++ = message->header.id & 0xff; + *rp++ = 0x80 | (message->header.opcode << 3) | message->header.rd; + *rp++ = 0x3; // change to 0 for no error... + *rp++ = 0; *rp++ = 1; // QDCOUNT + *rp++ = 0; *rp++ = 0; // ANCOUNT + *rp++ = 0; *rp++ = 0; // NSCOUNT + *rp++ = 0; *rp++ = 0; // ARCOUNT + // encode the first question... + { + struct dns_question *question = &message->questions[0]; + format_qname(&question->qname, &rp); + *rp++ = (question->qtype >> 8) & 0xff; + *rp++ = (question->qtype ) & 0xff; + *rp++ = (question->qclass >> 8) & 0xff; + *rp++ = (question->qclass ) & 0xff; + } + // send response. + sendto(sock, response_buf, (rp - response_buf), 0, (struct sockaddr*)&remote, remote_len); + } return 0; } + + diff --git a/src/pprint.c b/src/pprint.c index aa3a914..8dc5852 100644 --- a/src/pprint.c +++ b/src/pprint.c @@ -63,7 +63,11 @@ void h_pprint(FILE* stream, const HParsedToken* tok, int indent, int delta) { h_pprint(stream, tok->seq->elements[i], indent + delta, delta); } fprintf(stream, "%*s]\n", indent, ""); - } // TODO: implement this + } + break; + case TT_USER: + fprintf(stream, "%*sUSER\n", indent, ""); + break; default: g_assert_not_reached(); }