当前位置:   article > 正文

Python 实现类似 wandb.sweep 的搜参（超参数搜索）功能

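思路：与 wandb.sweep 类似，从给定的超参数候选值中随机选取组合，逐个调用 wandb_train 中的 main 进行训练；第二段代码放在训练函数里 train_model 之后，把最佳指标和对应的超参数追加写入 ./result/ 下的 CSV 文件，便于事后比较。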
  import argparse
  from argparse import Namespace
  from wandb_train import main
  from pprint import pprint
  import wandb
  import random
  import os
  import csv
  import itertools

  # Candidate values for each hyperparameter of the SAINT random-search sweep.
  config_dataset_name = ['assist2009', 'assist2012', 'assist2017', 'nips_task34']
  config_dropout = [0.1, 0.2, 0.3, 0.4]
  config_emb_size = [64, 128, 256]
  config_num_attn_heads = [1, 2, 4, 8]
  config_n_blocks = [1, 2, 4, 8]
  config_learning_rate = [0.1, 0.01, 0.001, 0.003, 0.005]

  # Column order of the results CSV; must match the row the training code appends.
  RESULT_COLUMNS = ['best_auc', 'best_acc', 'best_epoch', 'dataset_name', 'dropout',
                    'emb_size', 'learning_rate', 'num_attn_heads', 'n_blocks']


  def result_file_path(dataset_name):
      """Return the CSV path that collects sweep results for ``dataset_name``."""
      return './result/' + dataset_name + '_saint_2.csv'


  def ensure_result_file(file_path):
      """Create ``file_path`` containing only the header row, unless it already exists.

      The parent directory is created too, so a checkout without a ``result``
      folder no longer fails with ``FileNotFoundError``.
      """
      if os.path.exists(file_path):
          print("文件存在")  # "file exists": keep appending to the existing results
          return
      os.makedirs(os.path.dirname(file_path), exist_ok=True)
      with open(file_path, 'w', newline='') as file:
          csv.writer(file).writerow(RESULT_COLUMNS)


  def sample_trials(search_space, n_trials, seed=0):
      """Draw up to ``n_trials`` distinct configurations from the grid ``search_space``.

      Args:
          search_space: mapping of parameter name -> list of candidate values.
          n_trials: number of configurations wanted; capped at the grid size.
          seed: seed of a private ``random.Random``, so the draw is reproducible
              and neither reads nor disturbs the global RNG used by training.

      Returns:
          A list of ``{name: value}`` dicts. Sampling is without replacement, so
          no configuration is trained twice.
      """
      names = list(search_space)
      grid = list(itertools.product(*(search_space[name] for name in names)))
      rng = random.Random(seed)
      chosen = rng.sample(grid, min(n_trials, len(grid)))
      return [dict(zip(names, values)) for values in chosen]


  def build_parser(default_dataset_name):
      """Build the command-line parser for one sweep.

      The swept hyperparameters default to ``None``. A value given on the command
      line pins that parameter for every trial; otherwise each trial fills in its
      own sampled value.
      """
      parser = argparse.ArgumentParser()
      parser.add_argument("--dataset_name", type=str, default=default_dataset_name)
      parser.add_argument("--model_name", type=str, default="saint")
      parser.add_argument("--emb_type", type=str, default="qid")
      parser.add_argument("--save_dir", type=str, default="saved_model")
      parser.add_argument("--seed", type=int, default=42)
      parser.add_argument("--fold", type=int, default=0)
      parser.add_argument("--dropout", type=float, default=None)
      parser.add_argument("--emb_size", type=int, default=None)
      parser.add_argument("--learning_rate", type=float, default=None)
      parser.add_argument("--num_attn_heads", type=int, default=None)
      parser.add_argument("--n_blocks", type=int, default=None)
      parser.add_argument("--use_wandb", type=int, default=0)
      parser.add_argument("--add_uuid", type=int, default=1)
      return parser


  def run_sweep(n_trials=900, sweep_seed=0):
      """Train once per sampled configuration (random search, similar to ``wandb.sweep``).

      Args:
          n_trials: maximum number of configurations to train; capped at the
              grid size (960 points with the lists above).
          sweep_seed: seed for choosing configurations; independent of the
              training ``--seed`` that is passed on to ``main``.
      """
      base_params = vars(build_parser(config_dataset_name[3]).parse_args())
      # Write the header into the file of the dataset actually being trained. The
      # training code appends its rows to that same per-dataset path.
      ensure_result_file(result_file_path(base_params["dataset_name"]))
      search_space = {
          'dropout': config_dropout,
          'emb_size': config_emb_size,
          'learning_rate': config_learning_rate,
          'num_attn_heads': config_num_attn_heads,
          'n_blocks': config_n_blocks,
      }
      for trial in sample_trials(search_space, n_trials, sweep_seed):
          # Use a fresh dict per trial, so any change main() might make to params
          # cannot leak into the next run.
          params = dict(base_params)
          for name, value in trial.items():
              if params[name] is None:
                  params[name] = value
          main(params)


  if __name__ == "__main__":
      run_sweep()
  best_auc, best_acc, best_epoch = train_model(model, train_loader, valid_loader, num_epochs, opt, ckpt_path, None, None, save_model)

  # Append this run's best metrics plus its swept hyperparameters as one CSV row.
  import csv
  import os
  # Header and row are both derived from this list, so their column order cannot drift apart.
  result_columns = ['best_auc', 'best_acc', 'best_epoch', 'dataset_name', 'dropout',
                    'emb_size', 'learning_rate', 'num_attn_heads', 'n_blocks']
  data = [best_auc, best_acc, best_epoch] + [params[name] for name in result_columns[3:]]
  file_path = './result/' + params["dataset_name"] + '_saint_2.csv'
  # Create the folder if missing; otherwise open() raises FileNotFoundError.
  os.makedirs(os.path.dirname(file_path), exist_ok=True)
  with open(file_path, 'a', newline='') as file:
      writer = csv.writer(file)
      # open(..., 'a') has just created the file if needed; a size of 0 means
      # no header yet, so write one so this row never lands in a headerless file.
      if os.path.getsize(file_path) == 0:
          writer.writerow(result_columns)
      writer.writerow(data)

声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号