#include <stdio.h> #include <string.h> #include <core.h> #include "pxe.h" /* DNS CLASS values we care about */ #define CLASS_IN 1 /* DNS TYPE values we care about */ #define TYPE_A 1 #define TYPE_CNAME 5 /* * The DNS header structure */ struct dnshdr { uint16_t id; uint16_t flags; /* number of entries in the question section */ uint16_t qdcount; /* number of resource records in the answer section */ uint16_t ancount; /* number of name server resource records in the authority records section*/ uint16_t nscount; /* number of resource records in the additional records section */ uint16_t arcount; } __attribute__ ((packed)); /* * The DNS query structure */ struct dnsquery { uint16_t qtype; uint16_t qclass; } __attribute__ ((packed)); /* * The DNS Resource recodes structure */ struct dnsrr { uint16_t type; uint16_t class; uint32_t ttl; uint16_t rdlength; /* The lenght of this rr data */ char rdata[]; } __attribute__ ((packed)); #define DNS_PORT htons(53) /* Default DNS port */ #define DNS_MAX_SERVERS 4 /* Max no of DNS servers */ uint32_t dns_server[DNS_MAX_SERVERS] = {0, }; /* * Turn a string in _src_ into a DNS "label set" in _dst_; returns the * number of dots encountered. On return, *dst is updated. */ int dns_mangle(char **dst, const char *p) { char *q = *dst; char *count_ptr; char c; int dots = 0; count_ptr = q; *q++ = 0; while (1) { c = *p++; if (c == 0 || c == ':' || c == '/') break; if (c == '.') { dots++; count_ptr = q; *q++ = 0; continue; } *count_ptr += 1; *q++ = c; } if (*count_ptr) *q++ = 0; /* update the strings */ *dst = q; return dots; } /* * Compare two sets of DNS labels, in _s1_ and _s2_; the one in _s2_ * is allowed pointers relative to a packet in buf. * */ static bool dns_compare(const void *s1, const void *s2, const void *buf) { const uint8_t *q = s1; const uint8_t *p = s2; unsigned int c0, c1; while (1) { c0 = p[0]; if (c0 >= 0xc0) { /* Follow pointer */ c1 = p[1]; p = (const uint8_t *)buf + ((c0 - 0xc0) << 8) + c1; } else if (c0) { c0++; /* Include the length byte */ if (memcmp(q, p, c0)) return false; q += c0; p += c0; } else { return *q == 0; } } } /* * Copy a DNS label into a buffer, considering the possibility that we might * have to follow pointers relative to "buf". * Returns a pointer to the first free byte *after* the terminal null. */ static void *dns_copylabel(void *dst, const void *src, const void *buf) { uint8_t *q = dst; const uint8_t *p = src; unsigned int c0, c1; while (1) { c0 = p[0]; if (c0 >= 0xc0) { /* Follow pointer */ c1 = p[1]; p = (const uint8_t *)buf + ((c0 - 0xc0) << 8) + c1; } else if (c0) { c0++; /* Include the length byte */ memcpy(q, p, c0); p += c0; q += c0; } else { *q++ = 0; return q; } } } /* * Skip past a DNS label set in DS:SI */ static char *dns_skiplabel(char *label) { uint8_t c; while (1) { c = *label++; if (c >= 0xc0) return ++label; /* pointer is two bytes */ if (c == 0) return label; label += c; } } extern const uint8_t TimeoutTable[]; extern uint16_t get_port(void); extern void free_port(uint16_t port); /* * parse the ip_str and return the ip address with *res. * return true if the whole string was consumed and the result * was valid. * */ static bool parse_dotquad(const char *ip_str, uint32_t *res) { const char *p = ip_str; uint8_t part = 0; uint32_t ip = 0; int i; for (i = 0; i < 4; i++) { while (is_digit(*p)) { part = part * 10 + *p - '0'; p++; } if (i != 3 && *p != '.') return false; ip = (ip << 8) | part; part = 0; p++; } p--; *res = htonl(ip); return *p == '\0'; } /* * Actual resolver function * Points to a null-terminated or :-terminated string in _name_ * and returns the ip addr in _ip_ if it exists and can be found. * If _ip_ = 0 on exit, the lookup failed. _name_ will be updated * * XXX: probably need some caching here. */ __export uint32_t dns_resolv(const char *name) { static char __lowmem DNSSendBuf[PKTBUF_SIZE]; static char __lowmem DNSRecvBuf[PKTBUF_SIZE]; char *p; int err; int dots; int same; int rd_len; int ques, reps; /* number of questions and replies */ uint8_t timeout; const uint8_t *timeout_ptr = TimeoutTable; uint32_t oldtime; uint32_t srv; uint32_t *srv_ptr; struct dnshdr *hd1 = (struct dnshdr *)DNSSendBuf; struct dnshdr *hd2 = (struct dnshdr *)DNSRecvBuf; struct dnsquery *query; struct dnsrr *rr; static __lowmem struct s_PXENV_UDP_WRITE udp_write; static __lowmem struct s_PXENV_UDP_READ udp_read; uint16_t local_port; uint32_t result = 0; /* * Return failure on an empty input... this can happen during * some types of URL parsing, and this is the easiest place to * check for it. */ if (!name || !*name) return 0; /* If it is a valid dot quad, just return that value */ if (parse_dotquad(name, &result)) return result; /* Make sure we have at least one valid DNS server */ if (!dns_server[0]) return 0; /* Get a local port number */ local_port = get_port(); /* First, fill the DNS header struct */ hd1->id++; /* New query ID */ hd1->flags = htons(0x0100); /* Recursion requested */ hd1->qdcount = htons(1); /* One question */ hd1->ancount = 0; /* No answers */ hd1->nscount = 0; /* No NS */ hd1->arcount = 0; /* No AR */ p = DNSSendBuf + sizeof(struct dnshdr); dots = dns_mangle(&p, name); /* store the CNAME */ if (!dots) { p--; /* Remove final null */ /* Uncompressed DNS label set so it ends in null */ p = stpcpy(p, LocalDomain); } /* Fill the DNS query packet */ query = (struct dnsquery *)p; query->qtype = htons(TYPE_A); query->qclass = htons(CLASS_IN); p += sizeof(struct dnsquery); /* Now send it to name server */ timeout_ptr = TimeoutTable; timeout = *timeout_ptr++; srv_ptr = dns_server; while (timeout) { srv = *srv_ptr++; if (!srv) { srv_ptr = dns_server; srv = *srv_ptr++; } udp_write.status = 0; udp_write.ip = srv; udp_write.gw = gateway(srv); udp_write.src_port = local_port; udp_write.dst_port = DNS_PORT; udp_write.buffer_size = p - DNSSendBuf; udp_write.buffer = FAR_PTR(DNSSendBuf); err = pxe_call(PXENV_UDP_WRITE, &udp_write); if (err || udp_write.status) continue; oldtime = jiffies(); do { if (jiffies() - oldtime >= timeout) goto again; udp_read.status = 0; udp_read.src_ip = srv; udp_read.dest_ip = IPInfo.myip; udp_read.s_port = DNS_PORT; udp_read.d_port = local_port; udp_read.buffer_size = PKTBUF_SIZE; udp_read.buffer = FAR_PTR(DNSRecvBuf); err = pxe_call(PXENV_UDP_READ, &udp_read); } while (err || udp_read.status || hd2->id != hd1->id); if ((hd2->flags ^ 0x80) & htons(0xf80f)) goto badness; ques = htons(hd2->qdcount); /* Questions */ reps = htons(hd2->ancount); /* Replies */ p = DNSRecvBuf + sizeof(struct dnshdr); while (ques--) { p = dns_skiplabel(p); /* Skip name */ p += 4; /* Skip question trailer */ } /* Parse the replies */ while (reps--) { same = dns_compare(DNSSendBuf + sizeof(struct dnshdr), p, DNSRecvBuf); p = dns_skiplabel(p); rr = (struct dnsrr *)p; rd_len = ntohs(rr->rdlength); if (same && ntohs(rr->class) == CLASS_IN) { switch (ntohs(rr->type)) { case TYPE_A: if (rd_len == 4) { result = *(uint32_t *)rr->rdata; goto done; } break; case TYPE_CNAME: dns_copylabel(DNSSendBuf + sizeof(struct dnshdr), rr->rdata, DNSRecvBuf); /* * We should probably rescan the packet from the top * here, and technically we might have to send a whole * new request here... */ break; default: break; } } /* not the one we want, try next */ p += sizeof(struct dnsrr) + rd_len; } badness: /* * ; We got back no data from this server. ; Unfortunately, for a recursive, non-authoritative ; query there is no such thing as an NXDOMAIN reply, ; which technically means we can't draw any ; conclusions. However, in practice that means the ; domain doesn't exist. If this turns out to be a ; problem, we may want to add code to go through all ; the servers before giving up. ; If the DNS server wasn't capable of recursion, and ; isn't capable of giving us an authoritative reply ; (i.e. neither AA or RA set), then at least try a ; different setver... */ if (hd2->flags == htons(0x480)) continue; break; /* failed */ again: continue; } done: free_port(local_port); /* Return port number to the free pool */ return result; }