赞
踩
背景: 分别在不同的平台上运行PyTorch算子,将算子的描述、延迟、输出结果的统计信息等记录到MongoDB中,方便后续的对比和分析。
备注:
docker pull mongo
docker run -d --name mongodb -p 27017:27017 -v ~/mongo-data:/data/db mongo
docker run --name redis -p 1968:6379 -d redis:5.0.14
安装pymongo和redis客户端
pip3 install pymongo
pip3 install redis -i https://pypi.tuna.tsinghua.edu.cn/simple
demo
# Lifetime, in seconds, of the Redis key that marks a record as already stored.
CACHE_TTL_SECONDS = 100000


def get_database():
    """Connect to the local MongoDB and Redis servers.

    Returns:
        A ``(db, r)`` tuple: the ``torch_ops_benchmark`` MongoDB database and
        a Redis client (localhost:1968, db 0) used as an existence cache.
    """
    # Function-scope imports: the drivers are only needed once a connection
    # is actually requested.
    from pymongo import MongoClient
    import redis

    client = MongoClient('mongodb://localhost:27017/')
    db = client['torch_ops_benchmark']
    r = redis.Redis(host='localhost', port=1968, db=0)
    return (db, r)


def is_record_exsist(dbs, index, device_type):
    """Return whether a result for ``index`` on ``device_type`` is stored.

    Redis is consulted first. On a cache miss (for example after the key
    expired) the ``reports`` collection in MongoDB is queried, and a hit
    there re-populates the cache so the next lookup is served from Redis.

    Args:
        dbs: The ``(db, r)`` tuple returned by ``get_database``.
        index: Record index; it is stored and queried as a string.
        device_type: Name of the per-device field inside the document.

    Returns:
        ``True`` if the document for ``index`` has a ``device_type`` field.
    """
    db, r = dbs
    cache_key = f'{index}_{device_type}'
    if r.get(cache_key) is not None:
        return True

    collection = db['reports']
    query = {"index": f"{index}", f"{device_type}": {"$exists": True}}
    # limit=1: only existence matters, so stop counting at the first match.
    found = collection.count_documents(query, limit=1) > 0
    if found:
        r.set(cache_key, "", ex=CACHE_TTL_SECONDS)
    return found


def save_record(dbs, index, device_type, model_name, op_type, inputs_desc,
                error, avg_ms_latency, data_desc):
    """Upsert the result of one operator run on one device.

    The document is keyed by ``index``. The operator description fields are
    (re)written at the top level and the per-device result is stored under
    the ``device_type`` field. After a successful write the Redis existence
    key for ``(index, device_type)`` is set.

    Args:
        dbs: The ``(db, r)`` tuple returned by ``get_database``.
        index: Record index; it is stored as a string.
        device_type: Name of the field that holds this device's result.
        model_name: Stored as ``model_name``.
        op_type: Stored as ``op_type``.
        inputs_desc: Stored as ``inputs_desc``.
        error: Stored as ``error`` in the device sub-document.
        avg_ms_latency: Stored as ``avg_ms_latency`` in the sub-document.
        data_desc: Stored as ``data_desc`` in the device sub-document.

    Raises:
        AssertionError: If MongoDB reports that nothing was matched or
            upserted.
    """
    db, r = dbs
    collection = db['reports']
    query = {"index": f"{index}"}
    new_data = {"$set": {
        "index": f"{index}",
        "model_name": model_name,
        "op_type": op_type,
        "inputs_desc": inputs_desc,
        f"{device_type}": {
            "error": error,
            "avg_ms_latency": avg_ms_latency,
            "data_desc": data_desc,
        },
    }}
    result = collection.update_one(query, new_data, upsert=True)
    if result.matched_count == 0 and result.upserted_id is None:
        # Raised explicitly (not via ``assert``) so the check survives -O.
        raise AssertionError(
            "Nothing was inserted or updated, something went wrong.")
    r.set(f'{index}_{device_type}', "", ex=CACHE_TTL_SECONDS)


def main():
    """Write 80000 demo records for each device type, skipping existing ones."""
    model_name = "Llama2_7b_hf_b2_s256"
    op_type = "aten._softmax_backward_data.default"
    inputs_desc_str = "[Tensor(shape:(2,32,256,256)-torch.float32),Tensor(shape:(2,32,256,256)-torch.float32),int(-1),dtype(torch.float16)]#{}"
    data_desc = "65536,0.99951,0.00000"
    device_types = ("RTX3060", "RTX3090")
    total_ops = 80000
    avg_ms_latency = 100

    dbs = get_database()
    for op_index in range(total_ops):
        index = f"{op_index:010d}"
        for device_type in device_types:
            # Skip records that already exist.
            if is_record_exsist(dbs, index, device_type):
                continue
            # Store a placeholder with error=-1 first, so that a process
            # that exits while the operator is running leaves the record
            # flagged as failed.
            save_record(dbs, index, device_type, model_name, op_type,
                        inputs_desc_str, -1, avg_ms_latency, "")
            # Overwrite the placeholder with the actual data.
            save_record(dbs, index, device_type, model_name, op_type,
                        inputs_desc_str, 0, avg_ms_latency, data_desc)
        print("[{:08d}/{:08d}]".format(op_index, total_ops))


if __name__ == "__main__":
    main()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。