赞
踩
import argparse
import csv
import os
import random
from argparse import Namespace
from pprint import pprint

import wandb

from wandb_train import main
-
-
# Random hyperparameter search for the SAINT model.
#
# Each trial draws one value per hyperparameter (seeded by the trial index, so
# the sequence of configurations is reproducible), trains via
# wandb_train.main(params), and appends the best metrics plus the
# configuration as one row of ./result/<dataset>_saint_2.csv.

DATASET_NAMES = ["assist2009", "assist2012", "assist2017", "nips_task34"]
DEFAULT_DATASET = DATASET_NAMES[3]

# Candidate values per hyperparameter. Insertion order is the draw order, which
# keeps the per-seed configurations identical to the original script
# (dropout, emb_size, num_attn_heads, n_blocks, learning_rate).
SEARCH_SPACE = {
    "dropout": [0.1, 0.2, 0.3, 0.4],
    "emb_size": [64, 128, 256],
    "num_attn_heads": [1, 2, 4, 8],
    "n_blocks": [1, 2, 4, 8],
    "learning_rate": [0.1, 0.01, 0.001, 0.003, 0.005],
}

# Number of sampled trials. Draws are with replacement, so a configuration can
# repeat (the grid above has 4*3*4*4*5 = 960 combinations).
NUM_TRIALS = 900

RESULT_HEADER = [
    "best_auc", "best_acc", "best_epoch", "dataset_name", "dropout",
    "emb_size", "learning_rate", "num_attn_heads", "n_blocks",
]


def sample_hyperparameters(seed, search_space=None):
    """Draw one value per hyperparameter using an RNG seeded with *seed*.

    A private ``random.Random`` is used so the global RNG state is left
    untouched; for the same seed it yields the same values as seeding the
    global RNG and indexing with ``randint(0, len - 1)``.

    Args:
        seed: Seed for the draw (the trial index in this script).
        search_space: Mapping of name -> candidate list; defaults to
            SEARCH_SPACE. Values are drawn in the mapping's iteration order.

    Returns:
        dict mapping each hyperparameter name to its sampled value.
    """
    space = SEARCH_SPACE if search_space is None else search_space
    rng = random.Random(seed)
    return {name: rng.choice(values) for name, values in space.items()}


def build_parser(hparams):
    """Build the training argument parser with *hparams* as defaults.

    Explicit command-line flags still override the sampled defaults; note that
    such an override then applies to every trial of the search.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_name", type=str, default=DEFAULT_DATASET)
    parser.add_argument("--model_name", type=str, default="saint")
    parser.add_argument("--emb_type", type=str, default="qid")
    parser.add_argument("--save_dir", type=str, default="saved_model")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--fold", type=int, default=0)

    parser.add_argument("--dropout", type=float, default=hparams["dropout"])
    parser.add_argument("--emb_size", type=int, default=hparams["emb_size"])
    parser.add_argument("--learning_rate", type=float, default=hparams["learning_rate"])
    parser.add_argument("--num_attn_heads", type=int, default=hparams["num_attn_heads"])
    parser.add_argument("--n_blocks", type=int, default=hparams["n_blocks"])
    parser.add_argument("--use_wandb", type=int, default=0)
    parser.add_argument("--add_uuid", type=int, default=1)
    return parser


def result_path(dataset_name, result_dir="./result"):
    """Return the CSV path that collects search results for *dataset_name*."""
    return os.path.join(result_dir, f"{dataset_name}_saint_2.csv")


def ensure_header(file_path):
    """Create *file_path* (and its directory) containing only RESULT_HEADER.

    Does nothing if the file already exists.

    Returns:
        True if the file already existed, False if it was just created.
    """
    if os.path.exists(file_path):
        return True
    # The result directory is not guaranteed to exist on a fresh checkout.
    os.makedirs(os.path.dirname(file_path) or ".", exist_ok=True)
    with open(file_path, "w", newline="", encoding="utf-8") as file:
        csv.writer(file).writerow(RESULT_HEADER)
    return False


def append_result(file_path, row):
    """Append one result *row* to *file_path*, writing the header first if needed.

    Ensuring the header here (not only at startup) keeps files for a
    --dataset_name override from ending up header-less.
    """
    ensure_header(file_path)
    with open(file_path, "a", newline="", encoding="utf-8") as file:
        csv.writer(file).writerow(row)


def unpack_metrics(result):
    """Split a training result into ``(best_auc, best_acc, best_epoch)``.

    NOTE(review): assumes wandb_train.main(params) returns exactly these three
    values in this order (the original script expected them from
    train_model()); confirm against wandb_train.main.

    Raises:
        RuntimeError: if *result* is not a 3-item iterable.
    """
    try:
        best_auc, best_acc, best_epoch = result
    except (TypeError, ValueError) as err:
        raise RuntimeError(
            "wandb_train.main(params) must return (best_auc, best_acc, best_epoch); "
            f"got {result!r}"
        ) from err
    return best_auc, best_acc, best_epoch


if __name__ == "__main__":
    if ensure_header(result_path(DEFAULT_DATASET)):
        print("文件存在")  # "file exists": results will be appended to it

    for trial in range(NUM_TRIALS):
        hparams = sample_hyperparameters(trial)
        params = vars(build_parser(hparams).parse_args())

        best_auc, best_acc, best_epoch = unpack_metrics(main(params))

        # Column order follows RESULT_HEADER: metrics first, then the config.
        row = [best_auc, best_acc, best_epoch] + [params[key] for key in RESULT_HEADER[3:]]
        append_result(result_path(params["dataset_name"]), row)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。