While pretrained large models are revolutionizing natural language processing, code intelligence is advancing fast on their heels.
So what exactly does code intelligence do? Many readers may harbor science-fiction notions, such as programmers soon being out of a job.
In truth, much of the work is not mysterious at all, but quite fundamental. So which problems do we want code intelligence to solve?
Looking at the list of tasks, does it feel even more fantastical? How could problems this hard ever be cracked?
Honestly, every one of these subproblems is difficult, difficult even for humans to learn.
But just as humans learn step by step, machines keep improving too. What we need is not an omnipotent machine god but machine helpers as ordinary as ourselves: heavily limited, yet able to take real work off our plate.
What's more, as the final section will reveal, the way to handle all these complicated problems is remarkably simple. We go all in: a single model handles everything.
Let's now go through the problems in detail.
Just as every skyscraper rises from level ground, code intelligence starts with clone detection.
Clone detection means finding code that is similar in wording or in function.
Do not underestimate code duplication: it significantly degrades the effectiveness of training models of code.
Duplicates appear inside the training set, inside the test set, and again in the overlap between the two; the paper "The Adverse Effects of Code Duplication in Machine Learning Models of Code" analyzes the damage in detail.
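Deduplication has therefore become a standard preprocessing step. As a rough illustration (my own sketch, not the paper's pipeline), one can fingerprint every snippet after collapsing whitespace and keep only the first occurrence:

import hashlib
import re

def code_fingerprint(code: str) -> str:
    # Collapse whitespace so that formatting-only variants hash identically.
    normalized = re.sub(r"\s+", " ", code).strip()
    return hashlib.md5(normalized.encode("utf-8")).hexdigest()

def dedup(snippets):
    seen, unique = set(), []
    for code in snippets:
        fp = code_fingerprint(code)
        if fp not in seen:
            seen.add(fp)
            unique.append(code)
    return unique

Real pipelines go further and also catch near-duplicates, for example by comparing token multisets, but exact-duplicate removal alone already helps.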
The examples below come from the BigCloneBench dataset; the paper is at https://arxiv.org/pdf/2002.08653.pdf.
Here are a few examples of what counts as similar.
Code 1:
private StringBuffer encoder(String arg) {
    if (arg == null) {
        arg = "";
    }
    MessageDigest md5 = null;
    try {
        md5 = MessageDigest.getInstance("MD5");
        md5.update(arg.getBytes(SysConstant.charset));
    } catch (Exception e) {
        e.printStackTrace();
    }
    return toHex(md5.digest());
}
Code 2:
public String kodetu(String testusoila) {
    MessageDigest md = null;
    try {
        md = MessageDigest.getInstance("SHA");
        md.update(testusoila.getBytes("UTF-8"));
    } catch (NoSuchAlgorithmException e) {
        new MezuLeiho("Ez da zifraketa algoritmoa aurkitu", "Ados", "Zifraketa Arazoa", JOptionPane.ERROR_MESSAGE);
        e.printStackTrace();
    } catch (UnsupportedEncodingException e) {
        new MezuLeiho("Errorea kodetzerakoan", "Ados", "Kodeketa Errorea", JOptionPane.ERROR_MESSAGE);
        e.printStackTrace();
    }
    byte raw[] = md.digest();
    String hash = (new BASE64Encoder()).encode(raw);
    return hash;
}
The string literals in Code 2 are written in Basque. The two methods also use different digest algorithms, and their null checks and exception handling differ, yet we still regard them as very similar: clone detection treats such a pair as identical or highly similar.
Let's look at another pair:
Code 1:
public static void test(String args[]) {
    int trace;
    int bytes_read = 0;
    int last_contentLenght = 0;
    try {
        BufferedReader reader;
        URL url;
        url = new URL(args[0]);
        URLConnection istream = url.openConnection();
        last_contentLenght = istream.getContentLength();
        reader = new BufferedReader(new InputStreamReader(istream.getInputStream()));
        System.out.println(url.toString());
        String line;
        trace = t2pNewTrace();
        while ((line = reader.readLine()) != null) {
            bytes_read = bytes_read + line.length() + 1;
            t2pProcessLine(trace, line);
        }
        t2pHandleEventPairs(trace);
        t2pSort(trace, 0);
        t2pExportTrace(trace, new String("pngtest2.png"), 1000, 700, (float) 0, (float) 33);
        t2pExportTrace(trace, new String("pngtest3.png"), 1000, 700, (float) 2.3, (float) 2.44);
        System.out.println("Press any key to contiune read from stream !!!");
        System.out.println(t2pGetProcessName(trace, 0));
        System.in.read();
        istream = url.openConnection();
        if (last_contentLenght != istream.getContentLength()) {
            istream = url.openConnection();
            istream.setRequestProperty("Range", "bytes=" + Integer.toString(bytes_read) + "-");
            System.out.println(Integer.toString(istream.getContentLength()));
            reader = new BufferedReader(new InputStreamReader(istream.getInputStream()));
            while ((line = reader.readLine()) != null) {
                System.out.println(line);
                t2pProcessLine(trace, line);
            }
        } else System.out.println("File not changed !");
        t2pDeleteTrace(trace);
    } catch (MalformedURLException e) {
        System.out.println("MalformedURLException !!!");
    } catch (IOException e) {
        System.out.println("File not found " + args[0]);
    }
}
Code 2:
private static String loadUrlToString(String a_url) throws IOException {
    URL l_url1 = new URL(a_url);
    BufferedReader br = new BufferedReader(new InputStreamReader(l_url1.openStream()));
    String l_content = "";
    String l_ligne = null;
    l_content = br.readLine();
    while ((l_ligne = br.readLine()) != null) {
        l_content += AA.SL + l_ligne;
    }
    return l_content;
}
No minority language involved this time, but the lengths of the two snippets differ dramatically. We still consider them similar, since both read text from a URL line by line.
Now let's look at a dissimilar pair:
Code 1:
private void setNodekeyInJsonResponse(String service) throws Exception {
    String filename = this.baseDirectory + service + ".json";
    Scanner s = new Scanner(new File(filename));
    PrintWriter fw = new PrintWriter(new File(filename + ".new"));
    while (s.hasNextLine()) {
        fw.println(s.nextLine().replaceAll("NODEKEY", this.key));
    }
    s.close();
    fw.close();
    (new File(filename + ".new")).renameTo(new File(filename));
}
Code 2:
public void transform(String style, String spec, OutputStream out) throws IOException {
    URL url = new URL(rootURL, spec);
    InputStream in = new PatchXMLSymbolsStream(new StripDoctypeStream(url.openStream()));
    transform(style, in, out);
    in.close();
}
No explanation needed for the dissimilar pair.
The BigCloneBench dataset, then, provides pairs of code snippets along with a human-annotated label for whether they are clones.
The data is split into train.txt, valid.txt, and test.txt, all sharing the same format:
idx1 idx2 0/1
Here idx1 and idx2 are the indices of the two snippets in data.jsonl, and the last field is the human clone label.
The code itself is stored in data.jsonl, one record per line:
{"func": "<code>", "idx": "<index>"}
Taking the training set train.txt as an example, its first two lines are:
13988825 8660836 0
80378 18548122 1
The record for 13988825 in data.jsonl is:
{"func": " private void setNodekeyInJsonResponse(String service) throws Exception {\n String filename = this.baseDirectory + service + \".json\";\n Scanner s = new Scanner(new File(filename));\n PrintWriter fw = new PrintWriter(new File(filename + \".new\"));\n while (s.hasNextLine()) {\n fw.println(s.nextLine().replaceAll(\"NODEKEY\", this.key));\n }\n s.close();\n fw.close();\n (new File(filename + \".new\")).renameTo(new File(filename));\n }\n", "idx": "13988825"}
And 8660836 corresponds to:
{"func": " public void transform(String style, String spec, OutputStream out) throws IOException {\n URL url = new URL(rootURL, spec);\n InputStream in = new PatchXMLSymbolsStream(new StripDoctypeStream(url.openStream()));\n transform(style, in, out);\n in.close();\n }\n", "idx": "8660836"}
Their label says not similar. As you can see, this pair is exactly the third example we walked through above.
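To make the format concrete, here is a minimal Python sketch that joins train.txt against data.jsonl; pairs referencing a missing snippet are simply skipped:

import json

# Map each idx to its source code.
code_by_idx = {}
with open("data.jsonl", encoding="utf-8") as f:
    for line in f:
        record = json.loads(line)
        code_by_idx[record["idx"]] = record["func"]

# Each line of train.txt is "idx1 idx2 label".
pairs = []
with open("train.txt", encoding="utf-8") as f:
    for line in f:
        idx1, idx2, label = line.split()
        if idx1 in code_by_idx and idx2 in code_by_idx:
            pairs.append((code_by_idx[idx1], code_by_idx[idx2], int(label)))

print(pairs[0][2])  # 0 for the first pair above: not a clone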
For the related clone search task, which asks for the snippets most similar to a given piece of code, we use the POJ-104 dataset from Prof. Li Ge's group at Peking University.
The dataset can be downloaded from https://drive.google.com/uc?id=0B2i-vWnOu7MxVlJwQXN6eVNONUU.
Each snippet is identified by an index, and the code field holds the complete source. An example:
{ "label":"1", "index":"0", "code":" int f(int a,int x) { int count=1,i; for(i=x;i<a;i++) if(a%i==0) count+=f(a/i,i); if(i==a) return count; else return 0; } void main() { int n,a; scanf(\"%d\",&n); for(;n>0;n--) { scanf(\"%d\",&a); if(a==1||a==2) printf(\"1\\n\"); else printf(\"%d\\n\",f(a,2)); } } " }
The goal of the task is to find the snippets most similar to a given piece of code. Taking top 2 as an example, the output looks like this:
{"index": "0", "answers": ["3", "2"]}
{"index": "1", "answers": ["0", "4"]}
{"index": "2", "answers": ["0", "1"]}
{"index": "4", "answers": ["1", "5"]}
{"index": "3", "answers": ["4", "2"]}
{"index": "5", "answers": ["4", "3"]}
In other words, for code index 0 the most similar snippets are index 3 and index 2.
Index 3 looks like this (the runs of question marks are comments whose original encoding was lost):
void qut(int a, int b);  // ????
int num = 0;             // ?????????

int main() {
    int i, n, g[1000];   // ?????????
    cin >> n;
    for (i = 0; i < n; i++)  // ??????
        cin >> g[i];
    for (i = 0; i < n; i++) {
        qut(g[i], 1);    // ????
        cout << num << endl;
        num = 0;
    }
    return 0;
}

void qut(int a, int b) {
    int i;
    if (a >= b) {
        num++;
        if (b == 1) b++;
        for (i = b; i <= a; i++) {
            if (a % i == 0) {
                qut(a / i, i);  // ??a%i==0,??
            }
        }
    }
}
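How does one answer such queries? A common recipe, assuming each snippet has already been encoded into a vector by a model such as CodeBERT, is to rank all candidates by cosine similarity; a minimal sketch:

import numpy as np

def top_k_similar(query_vec, candidate_vecs, k=2):
    # Cosine similarity between the query vector and every candidate vector.
    candidates = np.asarray(candidate_vecs)
    sims = candidates @ query_vec / (
        np.linalg.norm(candidates, axis=1) * np.linalg.norm(query_vec) + 1e-8
    )
    # Positions of the k most similar candidates, best first.
    # (In practice the query itself must be excluded from the candidates.)
    return np.argsort(-sims)[:k]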
The defect detection dataset is blunt and simple: labeled snippets, where the label marks whether the code contains a vulnerability.
Here is a vulnerable example:
{ "project":"FFmpeg", "commit_id":"aba232cfa9b193604ed98f3fa505378d006b1b3b", "target":1, "func":"static int r3d_read_rdvo(AVFormatContext *s, Atom *atom) { R3DContext *r3d = s->priv_data; AVStream *st = s->streams[0]; int i; r3d->video_offsets_count = (atom->size - 8) / 4; r3d->video_offsets = av_malloc(atom->size); if (!r3d->video_offsets) return AVERROR(ENOMEM); for (i = 0; i < r3d->video_offsets_count; i++) { r3d->video_offsets[i] = avio_rb32(s->pb); if (!r3d->video_offsets[i]) { r3d->video_offsets_count = i; break; } av_dlog(s, \"video offset %d: %#x\\n\", i, r3d->video_offsets[i]); } if (st->r_frame_rate.num) st->duration = av_rescale_q(r3d->video_offsets_count, (AVRational){st->r_frame_rate.den, st->r_frame_rate.num}, st->time_base); av_dlog(s, \"duration %\"PRId64\"\\n\", st->duration); return 0; } ", "idx":5 }
That is all the information provided; which line is wrong, and why, is not in the training set.
Of course, most of the records carry no vulnerability, such as the first one:
{"project": "FFmpeg", "commit_id": "973b1a6b9070e2bf17d17568cbaf4043ce931f51", "target": 0, "func": "static av_cold int vdadec_init(AVCodecContext *avctx)\n\n{\n\n VDADecoderContext *ctx = avctx->priv_data;\n\n struct vda_context *vda_ctx = &ctx->vda_ctx;\n\n OSStatus status;\n\n int ret;\n\n\n\n ctx->h264_initialized = 0;\n\n\n\n /* init pix_fmts of codec */\n\n if (!ff_h264_vda_decoder.pix_fmts) {\n\n if (kCFCoreFoundationVersionNumber < kCFCoreFoundationVersionNumber10_7)\n\n ff_h264_vda_decoder.pix_fmts = vda_pixfmts_prior_10_7;\n\n else\n\n ff_h264_vda_decoder.pix_fmts = vda_pixfmts;\n\n }\n\n\n\n /* init vda */\n\n memset(vda_ctx, 0, sizeof(struct vda_context));\n\n vda_ctx->width = avctx->width;\n\n vda_ctx->height = avctx->height;\n\n vda_ctx->format = 'avc1';\n\n vda_ctx->use_sync_decoding = 1;\n\n vda_ctx->use_ref_buffer = 1;\n\n ctx->pix_fmt = avctx->get_format(avctx, avctx->codec->pix_fmts);\n\n switch (ctx->pix_fmt) {\n\n case AV_PIX_FMT_UYVY422:\n\n vda_ctx->cv_pix_fmt_type = '2vuy';\n\n break;\n\n case AV_PIX_FMT_YUYV422:\n\n vda_ctx->cv_pix_fmt_type = 'yuvs';\n\n break;\n\n case AV_PIX_FMT_NV12:\n\n vda_ctx->cv_pix_fmt_type = '420v';\n\n break;\n\n case AV_PIX_FMT_YUV420P:\n\n vda_ctx->cv_pix_fmt_type = 'y420';\n\n break;\n\n default:\n\n av_log(avctx, AV_LOG_ERROR, \"Unsupported pixel format: %d\\n\", avctx->pix_fmt);\n\n goto failed;\n\n }\n\n status = ff_vda_create_decoder(vda_ctx,\n\n avctx->extradata, avctx->extradata_size);\n\n if (status != kVDADecoderNoErr) {\n\n av_log(avctx, AV_LOG_ERROR,\n\n \"Failed to init VDA decoder: %d.\\n\", status);\n\n goto failed;\n\n }\n\n avctx->hwaccel_context = vda_ctx;\n\n\n\n /* changes callback functions */\n\n avctx->get_format = get_format;\n\n avctx->get_buffer2 = get_buffer2;\n\n#if FF_API_GET_BUFFER\n\n // force the old get_buffer to be empty\n\n avctx->get_buffer = NULL;\n\n#endif\n\n\n\n /* init H.264 decoder */\n\n ret = ff_h264_decoder.init(avctx);\n\n if (ret < 0) {\n\n av_log(avctx, AV_LOG_ERROR, \"Failed to open H.264 decoder.\\n\");\n\n goto failed;\n\n }\n\n ctx->h264_initialized = 1;\n\n\n\n return 0;\n\n\n\nfailed:\n\n vdadec_close(avctx);\n\n return -1;\n\n}\n", "idx": 0}
Inference is equally straightforward: output a 0 or 1 for each index:
0 0
1 1
2 1
3 0
4 0
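Scoring such a prediction file is plain accuracy. A minimal sketch, assuming the gold file and the prediction file both consist of "index label" lines as shown above:

def accuracy(gold_file, pred_file):
    def load(path):
        # One "index label" pair per line.
        with open(path) as f:
            return dict(line.split() for line in f if line.strip())
    gold, pred = load(gold_file), load(pred_file)
    correct = sum(pred.get(idx) == label for idx, label in gold.items())
    return correct / len(gold)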
Once we can recognize vulnerable code, the next step up is learning to repair code automatically.
The automated code repair task is also simply posed: one snippet containing a bug, paired with the snippet after the fix.
Let's look at an example.
The buggy code:
public java.lang.String METHOD_1 ( ) { return new TYPE_1 ( STRING_1 ) . format ( VAR_1 [ ( ( VAR_1 . length ) - 1 ) ] . getTime ( ) ) ; }
After the fix:
public java.lang.String METHOD_1 ( ) { return new TYPE_1 ( STRING_1 ) . format ( VAR_1 [ ( ( type ) - 1 ) ] . getTime ( ) ) ; }
Pity the algorithm: with identifiers abstracted away into tokens like METHOD_1 and VAR_1, even a human has to squint to spot the change.
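Since the buggy and fixed versions differ in only a handful of tokens, a token-level diff shows exactly what the model must learn to edit. A small illustration with Python's difflib on a fragment of the pair above:

import difflib

buggy = "format ( VAR_1 [ ( ( VAR_1 . length ) - 1 ) ] . getTime ( ) )".split()
fixed = "format ( VAR_1 [ ( ( type ) - 1 ) ] . getTime ( ) )".split()

# Report every span where the two token sequences disagree.
for op, a0, a1, b0, b1 in difflib.SequenceMatcher(a=buggy, b=fixed).get_opcodes():
    if op != "equal":
        print(op, buggy[a0:a1], "->", fixed[b0:b1])
# prints: replace ['VAR_1', '.', 'length'] -> ['type']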
Next is code translation, for example between C# and Java: given a corpus of paired C# and Java implementations, a model can learn to translate in both directions.
Let's look at one such pair.
First the C# code:
public virtual ListSpeechSynthesisTasksResponse ListSpeechSynthesisTasks(ListSpeechSynthesisTasksRequest request)
{
    var options = new InvokeOptions();
    options.RequestMarshaller = ListSpeechSynthesisTasksRequestMarshaller.Instance;
    options.ResponseUnmarshaller = ListSpeechSynthesisTasksResponseUnmarshaller.Instance;
    return Invoke<ListSpeechSynthesisTasksResponse>(request, options);
}
And the corresponding Java:
public ListSpeechSynthesisTasksResult listSpeechSynthesisTasks(ListSpeechSynthesisTasksRequest request) {
    request = beforeClientExecution(request);
    return executeListSpeechSynthesisTasks(request);
}
Next comes code summarization: the training material pairs code with its comments, and the task is to write a comment for unseen code. The evaluation measures the linguistic quality of the generated comment.
For this task we use the CodeSearchNet dataset.
Each record is a JSON object; let's look at an example:
{"repo": "ciena-blueplanet/bunsen-core", "path": "src/reducer.js", "func_name": "", "original_string": "function (state, action) {\n return _.defaults({\n isValidating: action.isValidating,\n lastAction: IS_VALIDATING\n }, state)\n }", "language": "javascript", "code": "function (state, action) {\n return _.defaults({\n isValidating: action.isValidating,\n lastAction: IS_VALIDATING\n }, state)\n }", "code_tokens": ["function", "(", "state", ",", "action", ")", "{", "return", "_", ".", "defaults", "(", "{", "isValidating", ":", "action", ".", "isValidating", ",", "lastAction", ":", "IS_VALIDATING", "}", ",", "state", ")", "}"], "docstring": "Update is validating result\n@param {State} state - state to update\n@param {Action} action - action\n@returns {State} - updated state", "docstring_tokens": ["Update", "is", "validating", "result"], "sha": "993c67e314e2b75003a1ff4c2f0cb667715562b2", "url": "https://github.com/ciena-blueplanet/bunsen-core/blob/993c67e314e2b75003a1ff4c2f0cb667715562b2/src/reducer.js#L394-L399", "partition": "train"}
The generated natural language is scored using the method from the paper "ORANGE: a Method for Evaluating Automatic Evaluation Metrics for Machine Translation".
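The ORANGE method boils down to a smoothed BLEU score. A minimal sketch of the idea with nltk; the benchmark ships its own scoring script, and the smoothing variant chosen here only approximates the paper's:

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

reference = "Update is validating result".lower().split()
candidate = "updates the validating result".split()

# method2 adds one to the higher-order n-gram counts, close in spirit to ORANGE smoothing.
smooth = SmoothingFunction().method2
print(sentence_bleu([reference], candidate, smoothing_function=smooth))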
The next task, code search, reuses the CodeSearchNet dataset from the previous section.
A search result looks like the following, one ranked list of candidate answers per query URL:
{"url": "url0", "answers": [10,11,12,13,14]}
{"url": "url1", "answers": [10,12,11,13,14]}
{"url": "url2", "answers": [13,11,12,10,14]}
{"url": "url3", "answers": [10,14,12,13,11]}
{"url": "url4", "answers": [10,11,12,13,14]}
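Ranked lists like these are usually scored with Mean Reciprocal Rank (MRR): the reciprocal of the position at which the correct answer appears, averaged over all queries. A minimal sketch with made-up gold labels:

def mean_reciprocal_rank(ranked, gold):
    # ranked: {url: [candidate ids, best first]}; gold: {url: correct id}.
    # Assumes the correct id always appears somewhere in the list.
    return sum(1.0 / (answers.index(gold[url]) + 1)
               for url, answers in ranked.items()) / len(ranked)

ranked = {"url0": [10, 11, 12, 13, 14], "url1": [10, 12, 11, 13, 14]}
gold = {"url0": 10, "url1": 11}
print(mean_reciprocal_rank(ranked, gold))  # (1/1 + 1/3) / 2 ≈ 0.667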
Wrapped in a UI, this is roughly what a code search product looks like.
Finally the ultimate task, code generation: conjuring up a piece of code from nothing but a textual description.
The format is as simple as can be: one piece of code plus one piece of text.
Here is a training sample:
{"code": "void function ( Binder arg0 ) { EventBus loc0 = new EventBus ( ) ; AmbariEventPublisher loc1 = new AmbariEventPublisher ( ) ; replaceEventBus ( AmbariEventPublisher . class , loc1 , loc0 ) ; arg0 . bind ( AmbariEventPublisher . class ) . toInstance ( loc1 ) ; }", "nl": "force the eventb us from ambarievent publisher to be serialand synchronous . concode_field_sep PlaceHolder placeHolder concode_field_sep void registerAlertListeners concode_elem_sep EventBus synchronizeAlertEventPublisher concode_elem_sep void replaceEventBus concode_elem_sep void registerAmbariListeners"}
The nl part is messy, admittedly; to get data at this scale, there simply were not enough hands to write clean annotations.
Let's look at another one:
{"code": "byte [ ] function ( Class < ? > arg0 , Configuration arg1 ) { return AuthenticationTokenSerializer . serialize ( org . apache . accumulo . core . client . mapreduce . lib . impl . ConfiguratorBase . getAuthenticationToken ( arg0 , arg1 ) ) ; }", "nl": "do n't use this . no , really , do n't use this . you already have an authenticationtoken with org.apache.accumulo.core.client.mapreduce.lib.impl.configuratorbase #getauthenticationtoken class , configuration . you do n't need to construct it yourself . gets the password from the configuration . warning : the password is stored in the configuration and shared with all mapreduce tasks ; it is base64 encoded to provide a charset safe conversion to a string , and is not intended to be secure . concode_field_sep PlaceHolder placeHolder concode_field_sep String getPrincipal concode_elem_sep void setLogLevel concode_elem_sep Level getLogLevel concode_elem_sep Boolean isConnectorInfoSet concode_elem_sep String getTokenClass concode_elem_sep void setZooKeeperInstance concode_elem_sep void setMockInstance concode_elem_sep Instance getInstance concode_elem_sep String enumToConfKey concode_elem_sep void setConnectorInfo"}
Not much better in quality, is it? That is what the CONCODE dataset looks like.
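The nl field packs the docstring together with class context (member fields and method signatures), glued by concode_field_sep and concode_elem_sep tokens. Here is a small sketch of how one might pull it apart; the reading of the separators is my own, not an official parser:

def parse_concode_nl(nl: str):
    # The docstring comes first; the remaining sections describe the class,
    # separated by concode_field_sep, with items split by concode_elem_sep.
    doc, _, context = nl.partition("concode_field_sep")
    members = [item.strip()
               for section in context.split("concode_field_sep")
               for item in section.split("concode_elem_sep")]
    return doc.strip(), members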
Four hundred and two years ago, when Nurhaci was surrounded by several converging Ming armies, he won the Battle of Sarhu, the founding battle of his state, with the tactic of "however many roads you come by, I march out by only one."
We borrow the same wisdom from the ancients: however much the datasets vary, we use exactly one tool, the large-scale pretrained model.
Pretrained models have gone through a brief but explosive history of development.
Taking the Microsoft CodeBERT models we showed at the beginning as an example, even the most complex task above, code generation, needs only one command:
python -m torch.distributed.launch --nproc_per_node=$PER_NODE_GPU run.py \
        --data_dir=$DATADIR \
        --langs=$LANG \
        --output_dir=$OUTPUTDIR \
        --pretrain_dir=$PRETRAINDIR \
        --log_file=$LOGFILE \
        --model_type=gpt2 \
        --block_size=512 \
        --do_train \
        --node_index 0 \
        --gpu_per_node $PER_NODE_GPU \
        --learning_rate=5e-5 \
        --weight_decay=0.01 \
        --evaluate_during_training \
        --per_gpu_train_batch_size=6 \
        --per_gpu_eval_batch_size=12 \
        --gradient_accumulation_steps=2 \
        --num_train_epochs=30 \
        --logging_steps=100 \
        --save_steps=5000 \
        --overwrite_output_dir \
        --seed=42
With two NVIDIA P100 GPUs, training completes in about 22 hours.
Inference is likewise a single command:
python -u run.py \
--data_dir=$DATADIR \
--langs=$LANG \
--output_dir=$OUTPUTDIR \
--pretrain_dir=$PRETRAINDIR \
--log_file=$LOGFILE \
--model_type=gpt2 \
--block_size=512 \
--do_infer \
--logging_steps=100 \
--seed=42
With a single P100, inference finishes in about 40 minutes.
With these basics in hand, we can go compete: every dataset introduced above is a task in a public contest.
All of the datasets mentioned can be downloaded from https://github.com/microsoft/CodeXGLUE.
Welcome to the world of code intelligence!
As the poet Lu You wrote: what is learned on paper always feels shallow; to know a thing through and through, you must practice it yourself.
So let's put this into practice and get the training and inference of a code-intelligence model running. First install the dependencies and clone the code:
pip install transformers --user
pip install torch torchvision torchtext torchaudio --user
git clone https://github.com/microsoft/CodeXGLUE
Then change into the Code-Code/Clone-detection-BigCloneBench/code directory and run:
python run.py \
    --output_dir=./saved_models \
    --model_type=roberta \
    --config_name=microsoft/codebert-base \
    --model_name_or_path=microsoft/codebert-base \
    --tokenizer_name=roberta-base \
    --do_train \
    --train_data_file=../dataset/train.txt \
    --eval_data_file=../dataset/valid.txt \
    --test_data_file=../dataset/test.txt \
    --epoch 2 \
    --block_size 400 \
    --train_batch_size 16 \
    --eval_batch_size 32 \
    --learning_rate 5e-5 \
    --max_grad_norm 1.0 \
    --evaluate_during_training \
    --seed 123456 2>&1 | tee train.log
Training then starts:
07/05/2021 16:29:24 - INFO - __main__ - ***** Running training *****
07/05/2021 16:29:24 - INFO - __main__ - Num examples = 90102
07/05/2021 16:29:24 - INFO - __main__ - Num Epochs = 2
07/05/2021 16:29:24 - INFO - __main__ - Instantaneous batch size per GPU = 8
07/05/2021 16:29:24 - INFO - __main__ - Total train batch size (w. parallel, distributed & accumulation) = 16
07/05/2021 16:29:24 - INFO - __main__ - Gradient Accumulation steps = 1
07/05/2021 16:29:24 - INFO - __main__ - Total optimization steps = 11264
On two V100s this trains in about 40 minutes. The step count checks out: 90102 examples / 16 per batch × 2 epochs ≈ 11264 optimization steps.
Training is followed by evaluation, and the best result so far is saved as a checkpoint for later inference:
07/05/2021 17:10:04 - INFO - __main__ - ***** Running evaluation ***** 40950/41541 [00:10<00:00, 2785.61it/s]
07/05/2021 17:10:04 - INFO - __main__ - Num examples = 41541
07/05/2021 17:10:04 - INFO - __main__ - Batch size = 32
07/05/2021 17:16:05 - INFO - __main__ - ***** Eval results *****
07/05/2021 17:16:05 - INFO - __main__ - eval_f1 = 0.9531
07/05/2021 17:16:05 - INFO - __main__ - eval_precision = 0.9579
07/05/2021 17:16:05 - INFO - __main__ - eval_recall = 0.9484
07/05/2021 17:16:05 - INFO - __main__ - eval_threshold = 0.97
07/05/2021 17:16:06 - INFO - __main__ - ********************
07/05/2021 17:16:06 - INFO - __main__ - Best f1:0.9531
07/05/2021 17:16:06 - INFO - __main__ - ********************
07/05/2021 17:16:08 - INFO - __main__ - Saving model checkpoint to ./saved_models/checkpoint-best-f1/model.bin
We train for two epochs in one run; in the second epoch the F1 rises above 0.97:
07/05/2021 17:56:43 - INFO - __main__ - ***** Running evaluation ***** 40950/41541 [00:12<00:00, 3535.62it/s]
07/05/2021 17:56:43 - INFO - __main__ - Num examples = 41541
07/05/2021 17:56:43 - INFO - __main__ - Batch size = 32
[W pthreadpool-cpp.cc:90] Warning: Leaking Caffe2 thread-pool after fork. (function pthreadpool)
[W pthreadpool-cpp.cc:90] Warning: Leaking Caffe2 thread-pool after fork. (function pthreadpool)
[W pthreadpool-cpp.cc:90] Warning: Leaking Caffe2 thread-pool after fork. (function pthreadpool)
[W pthreadpool-cpp.cc:90] Warning: Leaking Caffe2 thread-pool after fork. (function pthreadpool)
07/05/2021 18:02:44 - INFO - __main__ - ***** Eval results *****
07/05/2021 18:02:44 - INFO - __main__ - eval_f1 = 0.9701
07/05/2021 18:02:44 - INFO - __main__ - eval_precision = 0.9772
07/05/2021 18:02:44 - INFO - __main__ - eval_recall = 0.9633
07/05/2021 18:02:44 - INFO - __main__ - eval_threshold = 0.97
07/05/2021 18:02:45 - INFO - __main__ - ********************
07/05/2021 18:02:45 - INFO - __main__ - Best f1:0.9701
07/05/2021 18:02:45 - INFO - __main__ - ********************
07/05/2021 18:02:47 - INFO - __main__ - Saving model checkpoint to ./saved_models/checkpoint-best-f1/model.bin
Now let's run inference with the trained model:
python run.py \
    --output_dir=./saved_models \
    --model_type=roberta \
    --config_name=microsoft/codebert-base \
    --model_name_or_path=microsoft/codebert-base \
    --tokenizer_name=roberta-base \
    --do_eval \
    --do_test \
    --train_data_file=../dataset/train.txt \
    --eval_data_file=../dataset/valid.txt \
    --test_data_file=../dataset/test.txt \
    --epoch 2 \
    --block_size 400 \
    --train_batch_size 16 \
    --eval_batch_size 32 \
    --learning_rate 5e-5 \
    --max_grad_norm 1.0 \
    --evaluate_during_training \
    --seed 123456 2>&1 | tee test.log
Finally, run evaluator.py to score the test predictions:
python ../evaluator/evaluator.py -a ../dataset/test.txt -p saved_models/predictions.txt
The output is:
{'Recall': 0.9677421599288263, 'Prediction': 0.9557057904236594, 'F1': 0.9616080550111168}
Precision 0.956 and recall 0.968 (the evaluator's 'Prediction' key actually reports precision); not bad.
Compared with the CodeXGLUE leaderboard, this essentially matches the CodeBERT entry.
To improve on it, we can replace CodeBERT with GraphCodeBERT.
First download the GraphCodeBERT code:
git clone https://github.com/microsoft/CodeBERT
Then change into the GraphCodeBERT/clonedetection directory and unzip dataset.zip:
unzip dataset.zip
Now GraphCodeBERT trains just like CodeBERT did:
mkdir saved_models
python run.py \
    --output_dir=saved_models \
    --config_name=microsoft/graphcodebert-base \
    --model_name_or_path=microsoft/graphcodebert-base \
    --tokenizer_name=microsoft/graphcodebert-base \
    --do_train \
    --train_data_file=dataset/train.txt \
    --eval_data_file=dataset/valid.txt \
    --test_data_file=dataset/test.txt \
    --epoch 1 \
    --code_length 512 \
    --data_flow_length 128 \
    --train_batch_size 16 \
    --eval_batch_size 32 \
    --learning_rate 2e-5 \
    --max_grad_norm 1.0 \
    --evaluate_during_training \
    --seed 123456 2>&1 | tee saved_models/train.log
The parameters above are tuned for four V100 GPUs; with only two V100s, change --code_length to 256.
CodeBERT takes about 40 minutes per epoch, while GraphCodeBERT needs roughly six and a half hours per epoch.
Then run inference:
python run.py \
    --output_dir=saved_models \
    --config_name=microsoft/graphcodebert-base \
    --model_name_or_path=microsoft/graphcodebert-base \
    --tokenizer_name=microsoft/graphcodebert-base \
    --do_eval \
    --do_test \
    --train_data_file=dataset/train.txt \
    --eval_data_file=dataset/valid.txt \
    --test_data_file=dataset/test.txt \
    --epoch 1 \
    --code_length 256 \
    --data_flow_length 128 \
    --train_batch_size 16 \
    --eval_batch_size 32 \
    --learning_rate 2e-5 \
    --max_grad_norm 1.0 \
    --evaluate_during_training \
    --seed 123456 2>&1 | tee saved_models/test.log
Finally, let's score the results:
python evaluator/evaluator.py -a dataset/test.txt -p saved_models/predictions.txt 2>&1| tee saved_models/score.log
The result:
{'Recall': 0.9589415798936043, 'Prediction': 0.962620653900429, 'F1': 0.9607703728051462}