赞
踩
# K-fold cross-validation, then training a final model on all non-test data.
#
# NOTE(review): `data`, `test_data` and `get_model` are not defined in this
# snippet and must be provided by the surrounding code. `get_model()` is called
# once per fold and once for the final model; presumably it returns a fresh,
# untrained model exposing `train()` and `evaluate()` -- confirm with caller.
import numpy as np

k = 4
num_validation_samples = len(data) // k
if num_validation_samples == 0:
    # With fewer samples than folds every validation slice would be empty.
    raise ValueError("need at least k samples for k-fold validation")

# Shuffle once, in place, so each fold is a random subset of the data.
np.random.shuffle(data)

validation_scores = []
for fold in range(k):
    start = num_validation_samples * fold
    stop = num_validation_samples * (fold + 1)
    # Hold out this fold for validation...
    validation_data = data[start:stop]
    # ...and train on everything else. np.concatenate is required here:
    # `+` on NumPy arrays adds element-wise instead of joining the parts.
    # Note: the final len(data) % k samples are never in a validation fold;
    # they always land in the training portion.
    training_data = np.concatenate([data[:start], data[stop:]])

    model = get_model()
    model.train(training_data)
    validation_score = model.evaluate(validation_data)
    validation_scores.append(validation_score)

# Overall validation score: the mean of the per-fold scores.
validation_score = np.average(validation_scores)

# Train the final model on all non-test data.
model = get_model()
model.train(data)
test_score = model.evaluate(test_data)

Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。