This post briefly summarizes the paper "Heterogeneous Graph Attention Network for Malicious Domain Detection" and reproduces it; the reproduction code is provided below.
| Method category | Advantages | Disadvantages |
|---|---|---|
| Blacklist-based | Simple to implement | Hard to maintain and easily bypassed |
| Based on domain-name character features | Effective against evasion techniques such as Domain-Flux, Fast-Flux, and Double-Flux | Hand-crafting features is difficult |
| Deep learning methods | Features are extracted automatically | Can be bypassed by carefully crafted adversarial domains |
| Based on domain association features | Features are hard to bypass; detection is effective | ~ |
The overall system design of the paper is shown in the figure below:

The paper models the DNS scene as a heterogeneous information network composed of domains, clients, IP addresses, and the associations among them. It then jointly weighs the importance of different neighbors and of different meta-paths to capture key information at multiple granularities, which makes it possible to distinguish malicious domains from benign ones effectively.
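Before the full data-loading code, here is a small illustrative sketch of that modeling. The node and edge type names mirror those used in the reproduction code below; the toy node IDs are made up purely for illustration and are not from the paper's dataset.

```python
import dgl

# Toy DNS heterograph: 2 clients, 3 domains, 2 IPs (node IDs are illustrative only).
g = dgl.heterograph({
    ("client", "query", "domain"): ([0, 0, 1], [0, 1, 2]),
    ("domain", "queried", "client"): ([0, 1, 2], [0, 0, 1]),
    ("domain", "resolve", "ip"): ([0, 1, 2], [0, 0, 1]),
    ("ip", "resolved", "domain"): ([0, 0, 1], [0, 1, 2]),
    ("domain", "cname", "domain"): ([1], [2]),
})

# Domain-centric meta-paths: domain-client-domain, domain-IP-domain, domain-CNAME-domain.
# Each meta-path induces a homogeneous graph over domain nodes, i.e. one attention "view".
meta_paths = [["queried", "query"], ["resolve", "resolved"], ["cname"]]
for mp in meta_paths:
    mp_g = dgl.metapath_reachable_graph(g, mp)
    print(mp, mp_g.num_edges())
```

In the HAN architecture, node-level attention weights the neighbors inside each of these meta-path graphs, and semantic-level attention then weights the meta-path graphs against each other.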
```python
import datetime
import errno
import os
import pickle
import random
from pprint import pprint

import pandas as pd
import dgl
from transformers import BertTokenizer, BertModel
import numpy as np
import torch
from dgl.data.utils import _get_dgl_url, download, get_download_dir
from scipy import io as sio, sparse


def set_random_seed(seed=0):
    """Set random seed.

    Parameters
    ----------
    seed : int
        Random seed to use
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)


def mkdir_p(path, log=True):
    """Create a directory for the specified path.

    Parameters
    ----------
    path : str
        Path name
    log : bool
        Whether to print result for directory creation
    """
    try:
        os.makedirs(path)
        if log:
            print("Created directory {}".format(path))
    except OSError as exc:
        if exc.errno == errno.EEXIST and os.path.isdir(path) and log:
            print("Directory {} already exists.".format(path))
        else:
            raise


def get_date_postfix():
    """Get a date based postfix for directory name.

    Returns
    -------
    post_fix : str
    """
    dt = datetime.datetime.now()
    post_fix = "{}_{:02d}-{:02d}-{:02d}".format(
        dt.date(), dt.hour, dt.minute, dt.second
    )
    return post_fix


def setup_log_dir(args, sampling=False):
    """Name and create directory for logging.

    Parameters
    ----------
    args : dict
        Configuration

    Returns
    -------
    log_dir : str
        Path for logging directory
    sampling : bool
        Whether we are using sampling based training
    """
    date_postfix = get_date_postfix()
    log_dir = os.path.join(
        args["log_dir"], "{}_{}".format(args["dataset"], date_postfix)
    )
    if sampling:
        log_dir = log_dir + "_sampling"
    mkdir_p(log_dir)
    return log_dir


# The configuration below is from the paper.
default_configure = {
    "lr": 0.001,  # Learning rate
    "num_heads": [8],  # Number of attention heads for node-level attention
    "hidden_units": 8,
    "dropout": 0.6,
    "weight_decay": 0.001,
    "num_epochs": 80,
    "patience": 100,
}

sampling_configure = {"batch_size": 50}


def setup(args):
    args.update(default_configure)
    set_random_seed(args["seed"])
    args["dataset"] = "ACMRaw" if args["hetero"] else "ACM"
    args["device"] = "cpu"  # "cuda:0" if torch.cuda.is_available() else "cpu"
    args["log_dir"] = setup_log_dir(args)
    return args


def setup_for_sampling(args):
    args.update(default_configure)
    args.update(sampling_configure)
    set_random_seed()
    args["device"] = "cuda:0" if torch.cuda.is_available() else "cpu"
    args["log_dir"] = setup_log_dir(args, sampling=True)
    return args


def get_binary_mask(total_size, indices):
    # Build a 0/1 mask with ones at the specified positions.
    mask = torch.zeros(total_size)
    mask[indices] = 1
    return mask.byte()


def load_acm(remove_self_loop):
    url = "dataset/ACM3025.pkl"
    data_path = get_download_dir() + "/ACM3025.pkl"
    download(_get_dgl_url(url), path=data_path)

    with open(data_path, "rb") as f:
        data = pickle.load(f)

    # `data` is a dict whose "label" and "feature" entries are sparse matrices;
    # convert both to dense PyTorch tensors.
    labels, features = (
        torch.from_numpy(data["label"].todense()).long(),
        torch.from_numpy(data["feature"].todense()).float(),
    )

    # Extract the label vector
    num_classes = labels.shape[1]
    labels = labels.nonzero()[:, 1]

    if remove_self_loop:
        num_nodes = data["label"].shape[0]
        # Subtract the identity matrix to drop self loops
        data["PAP"] = sparse.csr_matrix(data["PAP"] - np.eye(num_nodes))
        data["PLP"] = sparse.csr_matrix(data["PLP"] - np.eye(num_nodes))

    # Adjacency matrices for meta path based neighbors
    # (Mufei): I verified both of them are binary adjacency matrices with self loops
    author_g = dgl.from_scipy(data["PAP"])  # meta-path based adjacency matrices
    subject_g = dgl.from_scipy(data["PLP"])
    gs = [author_g, subject_g]

    # Convert the index arrays from NumPy to long tensors and squeeze the extra
    # leading dimension; "train_idx" holds the indices of the training samples.
    train_idx = torch.from_numpy(data["train_idx"]).long().squeeze(0)
    val_idx = torch.from_numpy(data["val_idx"]).long().squeeze(0)
    test_idx = torch.from_numpy(data["test_idx"]).long().squeeze(0)

    num_nodes = author_g.num_nodes()
    train_mask = get_binary_mask(num_nodes, train_idx)  # binary mask (row vector)
    val_mask = get_binary_mask(num_nodes, val_idx)
    test_mask = get_binary_mask(num_nodes, test_idx)

    print("dataset loaded")
    pprint(
        {
            "dataset": "ACM",
            "train": train_mask.sum().item() / num_nodes,
            "val": val_mask.sum().item() / num_nodes,
            "test": test_mask.sum().item() / num_nodes,
        }
    )

    return (
        gs,
        features,
        labels,
        num_classes,
        train_idx,
        val_idx,
        test_idx,
        train_mask,
        val_mask,
        test_mask,
    )


def get_keys_by_value(d, value):
    return [key for key, val in d.items() if val == value]


def load_acm_raw(remove_self_loop):
    assert not remove_self_loop
    domain_to_numeric = {}

    # Load the domain dataset (path of the dataset file)
    domain_df = pd.read_csv('/HAN_pytorch_bert/data/domain.csv')
    domains = domain_df['domain']
    labels = domain_df['label'].tolist()
    labels = [x - 1 for x in labels]

    edge_df = pd.DataFrame(columns=['src_node', 'dst_node', 'edge_type'])

    # Load the node-to-integer mapping files; DGL node IDs must be integers.
    file1 = "/HAN_pytorch_bert/data/domain_to_numeric.pkl"
    file2 = "/HAN_pytorch_bert/data/client_to_numeric.pkl"
    file3 = "/HAN_pytorch_bert/data/ip_to_numeric.pkl"

    # Load the dictionaries with pickle
    with open(file1, 'rb') as file:
        domain_to_numeric = pickle.load(file)
    with open(file2, 'rb') as file:
        client_to_numeric = pickle.load(file)
    with open(file3, 'rb') as file:
        ip_to_numeric = pickle.load(file)

    # Load the edge file
    edge_df = pd.read_csv("HAN_pytorch_bert/data/edges.csv", header=None)
    query_edge0, resolve_edge0, cname_edge0 = [], [], []
    query_edge1, resolve_edge1, cname_edge1 = [], [], []
    query_value = 'query'
    resolve_value = 'resolve'
    cname_value = 'CNAME'
    for index, row in edge_df.iterrows():
        if row[2] == query_value:
            query_edge0.append(client_to_numeric[row[0]])
            query_edge1.append(domain_to_numeric[row[1]])
        if row[2] == resolve_value:
            resolve_edge0.append(domain_to_numeric[row[0]])
            resolve_edge1.append(ip_to_numeric[row[1]])
        if row[2] == cname_value:
            cname_edge0.append(domain_to_numeric[row[0]])
            cname_edge1.append(domain_to_numeric[row[1]])

    G = dgl.heterograph({
        ('client', 'query', 'domain'): (query_edge0, query_edge1),    # 'query' edges
        ('domain', 'queried', 'client'): (query_edge1, query_edge0),  # reverse 'queried' edges
        ('domain', 'resolve', 'ip'): (resolve_edge0, resolve_edge1),  # 'resolve' edges
        ('ip', 'resolved', 'domain'): (resolve_edge1, resolve_edge0), # reverse 'resolved' edges
        # 'cname' edges, added in both directions under a single relation
        ('domain', 'cname', 'domain'): (cname_edge0 + cname_edge1, cname_edge1 + cname_edge0),
    })
    print(G)

    # Graph pruning
    # 1. Remove domains queried by more than 35% of the clients
    client_num = G.number_of_nodes('client')
    threshold_percentage = 0.35
    threshold_count = threshold_percentage * client_num
    # Count how many times each domain is queried
    domain_query_counts = G.in_degrees(etype='query')
    domain_query_counts = domain_query_counts.numpy()
    # Find the domains to remove
    domains_to_remove = [node_id for node_id, count in enumerate(domain_query_counts) if count > threshold_count]

    new_labels = []
    for i, element in enumerate(labels):
        if i not in domains_to_remove:
            new_labels.append(element)
    domain_to_numeric = {key: value for key, value in domain_to_numeric.items() if value not in domains_to_remove}

    # Remove the selected domain nodes and their edges
    G.remove_nodes(domains_to_remove, ntype='domain')
    num_specified_nodes = G.number_of_nodes('domain')

    # 2. Remove clients that queried fewer than 4 or more than 100 domains
    # Count how many domains each client queried
    client_query_counts = G.out_degrees(etype='query')
    client_query_counts = client_query_counts.numpy()
    # Find the clients to remove
    clients_to_remove = [node_id for node_id, count in enumerate(client_query_counts) if count < 4 or count > 100]
    # Remove the selected client nodes and their edges
    G.remove_nodes(clients_to_remove, ntype='client')

    # 3. Remove IP addresses that only a single domain resolves to
    domain_resolved_counts = G.in_degrees(etype='resolve')
    domain_resolved_counts = domain_resolved_counts.numpy()
    # Find the IP addresses to remove
    ips_to_remove = [node_id for node_id, count in enumerate(domain_resolved_counts) if count == 1]
    # Remove the selected IP nodes and their edges
    G.remove_nodes(ips_to_remove, ntype='ip')

    # Load the BERT model and tokenizer (path of the pretrained model)
    tokenizer = BertTokenizer.from_pretrained("/home/bert-domain")
    bert_model = BertModel.from_pretrained("/home/bert-domain")
    domain_names = domain_df['domain'].unique()

    # Generate a BERT embedding for each domain name
    domain_embeddings = {}
    h = torch.zeros(len(new_labels), 768)
    i = 0
    for key, node_id in domain_to_numeric.items():
        # Tokenize and add special tokens
        tokens = tokenizer.tokenize(key)
        inputs = tokenizer.encode(" ".join(tokens), return_tensors="pt")
        # Run the BERT model
        with torch.no_grad():
            outputs = bert_model(inputs)
        # Mean-pool the last hidden states as the initial embedding of the domain node
        initial_embedding = outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
        values = torch.tensor(initial_embedding)
        h[i, :] = values
        # Record the embedding keyed by node id
        domain_embeddings[node_id] = initial_embedding
        i = i + 1

    features = h.float()
    labels = torch.LongTensor(new_labels)
    num_classes = 2

    # Random 75% / 25% train / validation split over the N domain nodes
    N = labels.shape[0]
    train_count = int(0.75 * N)
    all_indices = np.arange(N)
    np.random.shuffle(all_indices)
    train_idx = all_indices[:train_count]
    val_idx = all_indices[train_count:]

    num_nodes = N
    train_mask = get_binary_mask(num_nodes, train_idx)
    val_mask = get_binary_mask(num_nodes, val_idx)

    return (
        G,
        features,
        labels,
        num_classes,
        train_idx,
        val_idx,
    )


def load_data(dataset, remove_self_loop=False):
    if dataset == "ACM":
        return load_acm(remove_self_loop)
    elif dataset == "ACMRaw":
        return load_acm_raw(remove_self_loop)
    else:
        raise NotImplementedError("Unsupported dataset {}".format(dataset))


class EarlyStopping(object):
    def __init__(self, patience=10):
        dt = datetime.datetime.now()
        self.filename = "early_stop_{}_{:02d}-{:02d}-{:02d}.pth".format(
            dt.date(), dt.hour, dt.minute, dt.second
        )
        self.patience = patience
        self.counter = 0
        self.best_acc = None
        self.best_loss = None
        self.early_stop = False

    def step(self, loss, acc, model):
        if self.best_loss is None:
            self.best_acc = acc
            self.best_loss = loss
            self.save_checkpoint(model)
        elif (loss > self.best_loss) and (acc < self.best_acc):
            self.counter += 1
            print(
                f"EarlyStopping counter: {self.counter} out of {self.patience}"
            )
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            if (loss <= self.best_loss) and (acc >= self.best_acc):
                self.save_checkpoint(model)
            self.best_loss = np.min((loss, self.best_loss))
            self.best_acc = np.max((acc, self.best_acc))
            self.counter = 0
        return self.early_stop

    # The two methods below save and load model checkpoints; a checkpoint holds the
    # model parameters and training state and can be used to resume training or for inference.
    def save_checkpoint(self, model):
        """Saves model when validation loss decreases."""
        torch.save(model.state_dict(), self.filename)

    def load_checkpoint(self, model):
        """Load the latest checkpoint."""
        model.load_state_dict(torch.load(self.filename))
```
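For completeness, below is a hedged sketch of how the loader and EarlyStopping above might be driven from a training script. The model itself is not included in this post's snippet: the `model_hetero.HAN` import, its constructor arguments, and the `utils.py` filename follow the DGL HAN example and should be treated as assumptions rather than code from the paper's authors.

```python
import torch
import torch.nn.functional as F

from utils import EarlyStopping, load_data, setup   # assumes the code above is saved as utils.py
from model_hetero import HAN                         # HAN model from the DGL HAN example (assumption)

args = setup({"seed": 0, "hetero": True, "log_dir": "results"})
g, features, labels, num_classes, train_idx, val_idx = load_data(args["dataset"])
train_idx = torch.as_tensor(train_idx, dtype=torch.long)
val_idx = torch.as_tensor(val_idx, dtype=torch.long)

# Domain-centric meta-paths matching the heterograph built in load_acm_raw().
meta_paths = [["queried", "query"], ["resolve", "resolved"], ["cname"]]

model = HAN(
    meta_paths=meta_paths,
    in_size=features.shape[1],
    hidden_size=args["hidden_units"],
    out_size=num_classes,
    num_heads=args["num_heads"],
    dropout=args["dropout"],
).to(args["device"])

stopper = EarlyStopping(patience=args["patience"])
optimizer = torch.optim.Adam(
    model.parameters(), lr=args["lr"], weight_decay=args["weight_decay"]
)

for epoch in range(args["num_epochs"]):
    model.train()
    logits = model(g, features)
    loss = F.cross_entropy(logits[train_idx], labels[train_idx])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Evaluate on the held-out 25% split and let EarlyStopping track the best model.
    model.eval()
    with torch.no_grad():
        logits = model(g, features)
        val_loss = F.cross_entropy(logits[val_idx], labels[val_idx]).item()
        val_acc = (logits[val_idx].argmax(dim=1) == labels[val_idx]).float().mean().item()
    print(f"Epoch {epoch:03d} | train loss {loss.item():.4f} | val loss {val_loss:.4f} | val acc {val_acc:.4f}")
    if stopper.step(val_loss, val_acc, model):
        break

stopper.load_checkpoint(model)
```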