当前位置:   article > 正文

代码智能:问题与解法_代码智能任务

代码智能任务

代码智能:问题与解法

在基于预训练大模型引发自然语言处理革命的今天,代码智能技术也在迅速跟进发展。
那么,代码智能主要在做一些什么样的事情呢?可能很多同学会有比较科幻的想法,比如程序员要失业了之类的。
但是,其实很多工作并没有那么神秘,非常基础。那么我们用代码智能要解决什么问题呢?

  • 判断两段代码是不是实现相似的功能
  • 搜索跟当前代码段最相似的代码
  • 检测代码是否有bug
  • 自动修复代码中的bug
  • 给一段代码自动写注释
  • 根据文本推荐最相似的代码段
  • 根据文本生成代码

看了之后是不是觉得更玄幻了?这么困难的问题怎么搞得定?
诚实地讲,这其中的每个子问题都很困难,就算是人类学习起来也很困难。
不过,正像是人类也是一步一步学会的一样,机器也在不断地进步。我们需要的不一定是万能的机器神,也是和我们一样普通的机器人,它们有很大的局限,但是它们可以帮助我们减轻不少工作量。

而且,最后一节我们将揭晓,处理这么多如此复杂问题的方法,却非常简单,一把梭哈,我们只用一个模型就能搞定。

codeBert

下面我们就详细看一看这些问题的细节。

问题:克隆检测 Clone Detection

万地高楼平地起,代码智能任务首先从克隆检测做起。
所谓克隆检测,就是寻找写法和功能上相似的代码。
不要小看代码重复,它会显著地降低代码智能训练的有效性。
我们看下图,训练集中有重复,测试集中有重复,它们的交集中仍然有重复,在论文《The Adverse Effects of Code Duplication in Machine Learning Models of Code》中有详细的分析。

code duplicate

预测两段代码是否相似

以下的例子来自BigCloneBench数据集. 论文地址在:https://arxiv.org/pdf/2002.08653.pdf

下面我们举几个例子来看什么算相似:

代码1:

    private StringBuffer encoder(String arg) {
        if (arg == null) {
            arg = "";
        }
        MessageDigest md5 = null;
        try {
            md5 = MessageDigest.getInstance("MD5");
            md5.update(arg.getBytes(SysConstant.charset));
        } catch (Exception e) {
            e.printStackTrace();
        }
        return toHex(md5.digest());
    }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

代码2:

    public String kodetu(String testusoila) {
        MessageDigest md = null;
        try {
            md = MessageDigest.getInstance("SHA");
            md.update(testusoila.getBytes("UTF-8"));
        } catch (NoSuchAlgorithmException e) {
            new MezuLeiho("Ez da zifraketa algoritmoa aurkitu", "Ados", "Zifraketa Arazoa", JOptionPane.ERROR_MESSAGE);
            e.printStackTrace();
        } catch (UnsupportedEncodingException e) {
            new MezuLeiho("Errorea kodetzerakoan", "Ados", "Kodeketa Errorea", JOptionPane.ERROR_MESSAGE);
            e.printStackTrace();
        }
        byte raw[] = md.digest();
        String hash = (new BASE64Encoder()).encode(raw);
        return hash;
    }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

代码2的字符串是用巴斯克语写的。它们用的算法也有区别,判空和异常处理也有不同,但是我们认为它们是很类似的,属于克隆识别认为相同或高度相似的。

我们再看一对例子:

代码1:

    public static void test(String args[]) {
        int trace;
        int bytes_read = 0;
        int last_contentLenght = 0;
        try {
            BufferedReader reader;
            URL url;
            url = new URL(args[0]);
            URLConnection istream = url.openConnection();
            last_contentLenght = istream.getContentLength();
            reader = new BufferedReader(new InputStreamReader(istream.getInputStream()));
            System.out.println(url.toString());
            String line;
            trace = t2pNewTrace();
            while ((line = reader.readLine()) != null) {
                bytes_read = bytes_read + line.length() + 1;
                t2pProcessLine(trace, line);
            }
            t2pHandleEventPairs(trace);
            t2pSort(trace, 0);
            t2pExportTrace(trace, new String("pngtest2.png"), 1000, 700, (float) 0, (float) 33);
            t2pExportTrace(trace, new String("pngtest3.png"), 1000, 700, (float) 2.3, (float) 2.44);
            System.out.println("Press any key to contiune read from stream !!!");
            System.out.println(t2pGetProcessName(trace, 0));
            System.in.read();
            istream = url.openConnection();
            if (last_contentLenght != istream.getContentLength()) {
                istream = url.openConnection();
                istream.setRequestProperty("Range", "bytes=" + Integer.toString(bytes_read) + "-");
                System.out.println(Integer.toString(istream.getContentLength()));
                reader = new BufferedReader(new InputStreamReader(istream.getInputStream()));
                while ((line = reader.readLine()) != null) {
                    System.out.println(line);
                    t2pProcessLine(trace, line);
                }
            } else System.out.println("File not changed !");
            t2pDeleteTrace(trace);
        } catch (MalformedURLException e) {
            System.out.println("MalformedURLException !!!");
        } catch (IOException e) {
            System.out.println("File not found " + args[0]);
        }
        ;
    }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44

代码2:

    private static String loadUrlToString(String a_url) throws IOException {
        URL l_url1 = new URL(a_url);
        BufferedReader br = new BufferedReader(new InputStreamReader(l_url1.openStream()));
        String l_content = "";
        String l_ligne = null;
        l_content = br.readLine();
        while ((l_ligne = br.readLine()) != null) {
            l_content += AA.SL + l_ligne;
        }
        return l_content;
    }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

这个虽然没有涉及小语种,但是明显代码长度差异巨大。不过,我们仍然认为它们是相似的。

我们看一对不相似的吧:
代码1:

    private void setNodekeyInJsonResponse(String service) throws Exception {
        String filename = this.baseDirectory + service + ".json";
        Scanner s = new Scanner(new File(filename));
        PrintWriter fw = new PrintWriter(new File(filename + ".new"));
        while (s.hasNextLine()) {
            fw.println(s.nextLine().replaceAll("NODEKEY", this.key));
        }
        s.close();
        fw.close();
        (new File(filename + ".new")).renameTo(new File(filename));
    }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

代码2:

    public void transform(String style, String spec, OutputStream out) throws IOException {
        URL url = new URL(rootURL, spec);
        InputStream in = new PatchXMLSymbolsStream(new StripDoctypeStream(url.openStream()));
        transform(style, in, out);
        in.close();
    }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

不相似的就不解释了。

BigCloneBench数据集,就是提供了两段代码,以及它们是否相似的人工打标的结果。

数据分为train.txt, valid.txt, test.txt三个集合,它们的格式都是同样的:

idx1 idx2 0/1
  • 1

其中idx1和idx2是两段代码在data.jsonl中的索引值,最后一个是它们是否相似的人工打标的值。
代码都保存在data.jsonl中,格式为:

{"func":"代码","idx":"idx值"}
  • 1

我们以训练集train.txt为例,其前两行是这样的:

13988825	8660836	0
80378	18548122	1
  • 1
  • 2

13988825在data.jsonl中对应的结构是这样的:

{"func": "    private void setNodekeyInJsonResponse(String service) throws Exception {\n        String filename = this.baseDirectory + service + \".json\";\n        Scanner s = new Scanner(new File(filename));\n        PrintWriter fw = new PrintWriter(new File(filename + \".new\"));\n        while (s.hasNextLine()) {\n            fw.println(s.nextLine().replaceAll(\"NODEKEY\", this.key));\n        }\n        s.close();\n        fw.close();\n        (new File(filename + \".new\")).renameTo(new File(filename));\n    }\n", "idx": "13988825"}
  • 1

8660836对应的是:

{"func": "    public void transform(String style, String spec, OutputStream out) throws IOException {\n        URL url = new URL(rootURL, spec);\n        InputStream in = new PatchXMLSymbolsStream(new StripDoctypeStream(url.openStream()));\n        transform(style, in, out);\n        in.close();\n    }\n", "idx": "8660836"}
  • 1

而它们的结果是不相似。大家看到,这个例子就是刚才上面我们写的第三个例子。

搜索跟当前代码段语义最相似的代码段

这个我们使用北大李戈李师团队的POJ-104数据集。

这个数据集需要到https://drive.google.com/uc?id=0B2i-vWnOu7MxVlJwQXN6eVNONUU去下载。

每个代码段用一个index来描述,然后code字段是完整的代码。我们来看个例子:

{
        "label":"1",
        "index":"0",
        "code":"
int f(int a,int x)
{
 int count=1,i;
 for(i=x;i<a;i++)
  if(a%i==0)
   count+=f(a/i,i);
 if(i==a)
  return count;
 else
  return 0;
}

void main()
{
 int n,a;
 scanf(\"%d\",&n);
 for(;n>0;n--)
 {
  scanf(\"%d\",&a);
  if(a==1||a==2)
   printf(\"1\
\");
  else
   printf(\"%d\
\",f(a,2));
 }
}
"
    }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33

然后,这个任务的目的就是求出针对某一段代码最相似的代码段。以取top 2为例:输出的样例如下:

{"index": "0", "answers": ["3", "2"]}
{"index": "1", "answers": ["0", "4"]}
{"index": "2", "answers": ["0", "1"]}
{"index": "4", "answers": ["1", "5"]}
{"index": "3", "answers": ["4", "2"]}
{"index": "5", "answers": ["4", "3"]}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

也就是说,针对于代码index 0, 最相似的代码段是 index 3和2.

index 3是这样的:

void qut(int a,int b);                                       //????
int num=0;                                                    //?????????
int main()
{
 int i,n,g[1000];                                         //?????????
 cin>>n;
 for(i=0;i<n;i++)                                         //??????
  cin>>g[i];
 for(i=0;i<n;i++)
 {
     qut(g[i],1);                                         //????
  cout<<num<<endl;
              num=0;
 }
 return 0;
}

void qut(int a,int b)  
{
 int i;
 if (a>=b)  
 {
  num++;  
  if (b==1)                                      
   b++;
  for (i=b;i<=a;i++) 
  {
   if (a%i==0) 
   {
    qut(a/i,i);                                 //??a%i==0,??
   }
  }
 }
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34

问题:缺陷检测

缺陷检测的数据集非常简单粗暴,就是一段打标的代码,标识是不是有漏洞。

我们看个有漏洞的例子:

{
        "project":"FFmpeg",
        "commit_id":"aba232cfa9b193604ed98f3fa505378d006b1b3b",
        "target":1,
        "func":"static int r3d_read_rdvo(AVFormatContext *s, Atom *atom)

{

    R3DContext *r3d = s->priv_data;

    AVStream *st = s->streams[0];

    int i;



    r3d->video_offsets_count = (atom->size - 8) / 4;

    r3d->video_offsets = av_malloc(atom->size);

    if (!r3d->video_offsets)

        return AVERROR(ENOMEM);



    for (i = 0; i < r3d->video_offsets_count; i++) {

        r3d->video_offsets[i] = avio_rb32(s->pb);

        if (!r3d->video_offsets[i]) {

            r3d->video_offsets_count = i;

            break;

        }

        av_dlog(s, \"video offset %d: %#x\
\", i, r3d->video_offsets[i]);

    }



    if (st->r_frame_rate.num)

        st->duration = av_rescale_q(r3d->video_offsets_count,

                                    (AVRational){st->r_frame_rate.den,

                                                 st->r_frame_rate.num},

                                    st->time_base);

    av_dlog(s, \"duration %\"PRId64\"\
\", st->duration);



    return 0;

}
",
        "idx":5
    }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66

信息就这么多,至于哪行是什么问题,训练集中没有。

当然,数据集里大部分还是没有漏洞的,比如第一条:

{"project": "FFmpeg", "commit_id": "973b1a6b9070e2bf17d17568cbaf4043ce931f51", "target": 0, "func": "static av_cold int vdadec_init(AVCodecContext *avctx)\n\n{\n\n    VDADecoderContext *ctx = avctx->priv_data;\n\n    struct vda_context *vda_ctx = &ctx->vda_ctx;\n\n    OSStatus status;\n\n    int ret;\n\n\n\n    ctx->h264_initialized = 0;\n\n\n\n    /* init pix_fmts of codec */\n\n    if (!ff_h264_vda_decoder.pix_fmts) {\n\n        if (kCFCoreFoundationVersionNumber < kCFCoreFoundationVersionNumber10_7)\n\n            ff_h264_vda_decoder.pix_fmts = vda_pixfmts_prior_10_7;\n\n        else\n\n            ff_h264_vda_decoder.pix_fmts = vda_pixfmts;\n\n    }\n\n\n\n    /* init vda */\n\n    memset(vda_ctx, 0, sizeof(struct vda_context));\n\n    vda_ctx->width = avctx->width;\n\n    vda_ctx->height = avctx->height;\n\n    vda_ctx->format = 'avc1';\n\n    vda_ctx->use_sync_decoding = 1;\n\n    vda_ctx->use_ref_buffer = 1;\n\n    ctx->pix_fmt = avctx->get_format(avctx, avctx->codec->pix_fmts);\n\n    switch (ctx->pix_fmt) {\n\n    case AV_PIX_FMT_UYVY422:\n\n        vda_ctx->cv_pix_fmt_type = '2vuy';\n\n        break;\n\n    case AV_PIX_FMT_YUYV422:\n\n        vda_ctx->cv_pix_fmt_type = 'yuvs';\n\n        break;\n\n    case AV_PIX_FMT_NV12:\n\n        vda_ctx->cv_pix_fmt_type = '420v';\n\n        break;\n\n    case AV_PIX_FMT_YUV420P:\n\n        vda_ctx->cv_pix_fmt_type = 'y420';\n\n        break;\n\n    default:\n\n        av_log(avctx, AV_LOG_ERROR, \"Unsupported pixel format: %d\\n\", avctx->pix_fmt);\n\n        goto failed;\n\n    }\n\n    status = ff_vda_create_decoder(vda_ctx,\n\n                                   avctx->extradata, avctx->extradata_size);\n\n    if (status != kVDADecoderNoErr) {\n\n        av_log(avctx, AV_LOG_ERROR,\n\n                \"Failed to init VDA decoder: %d.\\n\", status);\n\n        goto failed;\n\n    }\n\n    avctx->hwaccel_context = vda_ctx;\n\n\n\n    /* changes callback functions */\n\n    avctx->get_format = get_format;\n\n    avctx->get_buffer2 = get_buffer2;\n\n#if FF_API_GET_BUFFER\n\n    // force the old get_buffer to be empty\n\n    avctx->get_buffer = NULL;\n\n#endif\n\n\n\n    /* init H.264 decoder */\n\n    ret = ff_h264_decoder.init(avctx);\n\n    if (ret < 0) {\n\n        av_log(avctx, AV_LOG_ERROR, \"Failed to open H.264 decoder.\\n\");\n\n        goto failed;\n\n    }\n\n    ctx->h264_initialized = 1;\n\n\n\n    return 0;\n\n\n\nfailed:\n\n    vdadec_close(avctx);\n\n    return -1;\n\n}\n", "idx": 0}
  • 1

推理搞起来也是十分省事了,就是对应每个index给个0或1的结果:

0	0
1	1
2	1
3	0
4	0
  • 1
  • 2
  • 3
  • 4
  • 5

问题:代码自动修复

有了识别代码漏洞的,更进一步就是学习自动修复代码的了。

代码自动修复的题目也很简单,一段是有bug的代码,另一段是修复之后的代码。

我们来看一个例子:

有bug的代码是这样的:

public java.lang.String METHOD_1 ( ) { return new TYPE_1 ( STRING_1 ) . format ( VAR_1 [ ( ( VAR_1 . length ) - 1 ) ] . getTime ( ) ) ; }
  • 1

修复之后是这样子的:

public java.lang.String METHOD_1 ( ) { return new TYPE_1 ( STRING_1 ) . format ( VAR_1 [ ( ( type ) - 1 ) ] . getTime ( ) ) ; }
  • 1

也真难为算法了,人看起来都有点费事。

问题:代码互译

比如实现C#语言和Java语言的互译。我们只要有一系列代码的C#写法和Java写法,就可以进行学习进行互译。

我们来看一对例子。
先看C#代码:

public virtual ListSpeechSynthesisTasksResponse ListSpeechSynthesisTasks(ListSpeechSynthesisTasksRequest request){
  var options = new InvokeOptions();
  options.RequestMarshaller = ListSpeechSynthesisTasksRequestMarshaller.Instance;
  options.ResponseUnmarshaller = ListSpeechSynthesisTasksResponseUnmarshaller.Instance;
  return Invoke<ListSpeechSynthesisTasksResponse>(request, options);
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

对应的Java

public ListSpeechSynthesisTasksResult listSpeechSynthesisTasks(ListSpeechSynthesisTasksRequest request) {
  request = beforeClientExecution(request);
  return executeListSpeechSynthesisTasks(request);
}
  • 1
  • 2
  • 3
  • 4

代码互译

问题:给代码写注释

在训练素材中,有代码和注释,这个任务的目的为新代码写注释。评价指标是对于生成的注释的语言准确度。

这个我们使用CodeSearchNet数据集。

这个数据集中的每条记录的格式如下:

  • repo: 仓库名
  • path: 文件名
  • func_name: 函数或方法名
  • original_string: 未经处理的源字符串
  • language: 编程语言
  • code/function: 代码信息
  • code_tokens/function_tokens: 分词之后的代码结果
  • docstring: 注释字符串信息
  • docstring_tokens: docstring分词之后的结果
  • url: 自然语言的唯一标识号
  • idx: 代码段的唯一标识号

我们来看个例子:

{"repo": "ciena-blueplanet/bunsen-core", "path": "src/reducer.js", "func_name": "", "original_string": "function
(state, action) {\n    return _.defaults({\n      isValidating: action.isValidating,\n      lastAction: IS_VALIDA
TING\n    }, state)\n  }", "language": "javascript", "code": "function (state, action) {\n    return _.defaults({
\n      isValidating: action.isValidating,\n      lastAction: IS_VALIDATING\n    }, state)\n  }", "code_tokens":
["function", "(", "state", ",", "action", ")", "{", "return", "_", ".", "defaults", "(", "{", "isValidating", ":"
, "action", ".", "isValidating", ",", "lastAction", ":", "IS_VALIDATING", "}", ",", "state", ")", "}"], "docstrin
g": "Update is validating result\n@param {State} state - state to update\n@param {Action} action - action\n@retur
ns {State} - updated state", "docstring_tokens": ["Update", "is", "validating", "result"], "sha": "993c67e314e2b7
5003a1ff4c2f0cb667715562b2", "url": "https://github.com/ciena-blueplanet/bunsen-core/blob/993c67e314e2b75003a1ff4
c2f0cb667715562b2/src/reducer.js#L394-L399", "partition": "train"}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

对于生成的自然语言,我们采用《ORANGE: a Method for Evaluating Automatic Evaluation Metrics for Machine Translation 》论文的方法进行评分。

问题:为自然语言文本匹配最合适的代码段

我们仍然使用上一节的CodeSearchNet数据集。

这个搜索的结果类似于下面这样:

{"url": "url0", "answers": [10,11,12,13,14]}
{"url": "url1", "answers": [10,12,11,13,14]}
{"url": "url2", "answers": [13,11,12,10,14]}
{"url": "url3", "answers": [10,14,12,13,11]}
{"url": "url4", "answers": [10,11,12,13,14]}
  • 1
  • 2
  • 3
  • 4
  • 5

配上UI,大致实现的效果是这样的:

文本转代码

或者是这样:
webquery

问题:根据自然语言生成代码

这是终极任务,就是根据一段文本描述硬生生地生成一段代码出来。

格式非常简单,就一段代码和一段文本。

我们来看个训练样本的例子:

{"code": "void function ( Binder arg0 ) { EventBus loc0 = new EventBus ( ) ; AmbariEventPublisher loc1 = new AmbariEventPublisher ( ) ; repla
ceEventBus ( AmbariEventPublisher . class , loc1 , loc0 ) ; arg0 . bind ( AmbariEventPublisher . class ) . toInstance ( loc1 ) ; }", "nl": "force the eventb us from ambarievent publisher to be serialand synchronous . concode_field_sep PlaceHolder placeHolder concode_field_sep void registerAlertListeners concode_elem_sep EventBus synchronizeAlertEventPublisher concode_elem_sep void replaceEventBus concode_elem_sep void registerAmbariListeners"}
  • 1
  • 2

这NL部分有点乱啊,没办法,为了增加数据量,没有那么多人手打精确的标。

我们再看一个:

{"code": "byte [ ] function ( Class < ? > arg0 , Configuration arg1 ) { return AuthenticationTokenSerializer . serialize ( org . apache . acc
umulo . core . client . mapreduce . lib . impl . ConfiguratorBase . getAuthenticationToken ( arg0 , arg1 ) ) ; }", "nl": "do n't use this . n
o , really , do n't use this . you already have an authenticationtoken with org.apache.accumulo.core.client.mapreduce.lib.impl.configuratorba
se #getauthenticationtoken class , configuration . you do n't need to construct it yourself . gets the password from the configuration . warn
ing : the password is stored in the configuration and shared with all mapreduce tasks ; it is base64 encoded to provide a charset safe conver
sion to a string , and is not intended to be secure . concode_field_sep PlaceHolder placeHolder concode_field_sep String getPrincipal concode
_elem_sep void setLogLevel concode_elem_sep Level getLogLevel concode_elem_sep Boolean isConnectorInfoSet concode_elem_sep String getTokenCla
ss concode_elem_sep void setZooKeeperInstance concode_elem_sep void setMockInstance concode_elem_sep Instance getInstance concode_elem_sep St
ring enumToConfKey concode_elem_sep void setConnectorInfo"}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

是不是质量也没好到哪儿去?这就是CONCODE数据集的样子。

解法:基于大规模预训练模型的多任务学习

402年前,当努尔哈赤面临明朝多路大军的围困的时候,采取了“凭你几路来,我只一路去”的战术赢得了萨尔浒之战的立国之战。
我们同样学习古人的智慧,任你数据集千变万化,我们的工具就只用一个 - 大规模预训练模型。

下面是预训练模型的简要发展史:

image.png

以开头我们展示的微软的codebert模型为例,我们要处理上面最复杂的代码生成任务,只要一条命令就可以搞定:

python -m torch.distributed.launch --nproc_per_node=$PER_NODE_GPU run.py \
        --data_dir=$DATADIR \
        --langs=$LANG \
        --output_dir=$OUTPUTDIR \
        --pretrain_dir=$PRETRAINDIR \
        --log_file=$LOGFILE \
        --model_type=gpt2 \
        --block_size=512 \
        --do_train \
        --node_index 0 \
        --gpu_per_node $PER_NODE_GPU \
        --learning_rate=5e-5 \
        --weight_decay=0.01 \
        --evaluate_during_training \
        --per_gpu_train_batch_size=6 \
        --per_gpu_eval_batch_size=12 \
        --gradient_accumulation_steps=2 \
        --num_train_epochs=30 \
        --logging_steps=100 \
        --save_steps=5000 \
        --overwrite_output_dir \
        --seed=42
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

如果使用两张2 NVIDIA P100 GPU卡的话,22小时左右就可以训练完。

推理呢,也是一条语句就搞定:

python -u run.py \
        --data_dir=$DATADIR \
        --langs=$LANG \
        --output_dir=$OUTPUTDIR \
        --pretrain_dir=$PRETRAINDIR \
        --log_file=$LOGFILE \
        --model_type=gpt2 \
        --block_size=512 \
        --do_infer \
        --logging_steps=100 \
        --seed=42
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

只用一张P100卡,大约40分钟就可以搞定。

有了上面的基础,我们就可以去打比赛啦。上面介绍的数据集,全都是比赛的赛题:
截屏2021-07-02 下午12.56.58.png

上面提到的数据集,可以在https://github.com/microsoft/CodeXGLUE下载到。

欢迎来到代码智能的世界!

附录:快速上手指南

放翁云:纸上得来终觉浅,绝知此事要躬行。
下面我们就落地下,将代码智能模型的训练和推理跑起来~~~

  • 第一步:安装transformers框架,因为codebert是基于这个框架的:
pip install transformers --user
  • 1
  • 第二步:安装PyTorch或者Tensorflow作为Transformers的后端,以2021年7月5日这个时间点,需要的PyTorch版本至少是1.5.0以上。驱动能搞定的话,索性就安装最新的吧:
pip install torch torchvision torchtext torchaudio --user
  • 1
  • 第三步,下载微软的数据集
git clone https://github.com/microsoft/CodeXGLUE
  • 1
  • 第四步,我们先玩玩BigCloneBench吧

到Code-Code/Clone-detection-BigCloneBench/code目录下,运行:

python run.py     --output_dir=./saved_models     --model_type=roberta     --config_name=microsoft/codebert-base     --model_name_or_path=microsoft/codebert-base     --tokenizer_name=roberta-base     --do_train     --train_data_file=../dataset/train.txt     --eval_data_file=../dataset/valid.txt     --test_data_file=../dataset/test.txt     --epoch 2     --block_size 400     --train_batch_size 16     --eval_batch_size 32     --learning_rate 5e-5     --max_grad_norm 1.0     --evaluate_during_training     --seed 123456 2>&1| tee train.log
  • 1

然后训练就运行起来了:

07/05/2021 16:29:24 - INFO - __main__ -   ***** Running training *****
07/05/2021 16:29:24 - INFO - __main__ -     Num examples = 90102
07/05/2021 16:29:24 - INFO - __main__ -     Num Epochs = 2
07/05/2021 16:29:24 - INFO - __main__ -     Instantaneous batch size per GPU = 8
07/05/2021 16:29:24 - INFO - __main__ -     Total train batch size (w. parallel, distributed & accumulation) = 16
07/05/2021 16:29:24 - INFO - __main__ -     Gradient Accumulation steps = 1
07/05/2021 16:29:24 - INFO - __main__ -     Total optimization steps = 11264
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

在两张V100卡大约需要训练40分钟左右。
训练之后是验证,会将目前最好的结果保存到checkpoint中以备推理时使用

07/05/2021 17:10:04 - INFO - __main__ -   ***** Running evaluation  ***** 40950/41541 [00:10<00:00, 2785.61it/s]
07/05/2021 17:10:04 - INFO - __main__ -     Num examples = 41541
07/05/2021 17:10:04 - INFO - __main__ -     Batch size = 32
07/05/2021 17:16:05 - INFO - __main__ -   ***** Eval results  *****
07/05/2021 17:16:05 - INFO - __main__ -     eval_f1 = 0.9531
07/05/2021 17:16:05 - INFO - __main__ -     eval_precision = 0.9579
07/05/2021 17:16:05 - INFO - __main__ -     eval_recall = 0.9484
07/05/2021 17:16:05 - INFO - __main__ -     eval_threshold = 0.97
07/05/2021 17:16:06 - INFO - __main__ -     ********************
07/05/2021 17:16:06 - INFO - __main__ -     Best f1:0.9531
07/05/2021 17:16:06 - INFO - __main__ -     ********************
07/05/2021 17:16:08 - INFO - __main__ -   Saving model checkpoint to ./saved_models/checkpoint-best-f1/model.bin
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

一次训练两轮,第二轮效果提升到0.97多:

07/05/2021 17:56:43 - INFO - __main__ -   ***** Running evaluation  ***** 40950/41541 [00:12<00:00, 3535.62it/s]
07/05/2021 17:56:43 - INFO - __main__ -     Num examples = 41541
07/05/2021 17:56:43 - INFO - __main__ -     Batch size = 32
[W pthreadpool-cpp.cc:90] Warning: Leaking Caffe2 thread-pool after fork. (function pthreadpool)
[W pthreadpool-cpp.cc:90] Warning: Leaking Caffe2 thread-pool after fork. (function pthreadpool)
[W pthreadpool-cpp.cc:90] Warning: Leaking Caffe2 thread-pool after fork. (function pthreadpool)
[W pthreadpool-cpp.cc:90] Warning: Leaking Caffe2 thread-pool after fork. (function pthreadpool)
07/05/2021 18:02:44 - INFO - __main__ -   ***** Eval results  *****
07/05/2021 18:02:44 - INFO - __main__ -     eval_f1 = 0.9701
07/05/2021 18:02:44 - INFO - __main__ -     eval_precision = 0.9772
07/05/2021 18:02:44 - INFO - __main__ -     eval_recall = 0.9633
07/05/2021 18:02:44 - INFO - __main__ -     eval_threshold = 0.97
07/05/2021 18:02:45 - INFO - __main__ -     ********************
07/05/2021 18:02:45 - INFO - __main__ -     Best f1:0.9701
07/05/2021 18:02:45 - INFO - __main__ -     ********************
07/05/2021 18:02:47 - INFO - __main__ -   Saving model checkpoint to ./saved_models/checkpoint-best-f1/model.bin
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

然后我们用训好的模型进行推理吧:

python run.py \
    --output_dir=./saved_models \
    --model_type=roberta \
    --config_name=microsoft/codebert-base \
    --model_name_or_path=microsoft/codebert-base \
    --tokenizer_name=roberta-base \
    --do_eval \
    --do_test \
    --train_data_file=../dataset/train.txt \
    --eval_data_file=../dataset/valid.txt \
    --test_data_file=../dataset/test.txt \
    --epoch 2 \
    --block_size 400 \
    --train_batch_size 16 \
    --eval_batch_size 32 \
    --learning_rate 5e-5 \
    --max_grad_norm 1.0 \
    --evaluate_during_training \
    --seed 123456 2>&1| tee test.log
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

最后我们运行evaluator.py来查看测试结果:

python ../evaluator/evaluator.py -a ../dataset/test.txt -p saved_models/predictions.txt
  • 1

输出如下:

{'Recall': 0.9677421599288263, 'Prediction': 0.9557057904236594, 'F1': 0.9616080550111168}
  • 1

准确率0.956, 召回率0.968,还不错~

跟CodeXGLUE的排行榜比一比:

跟榜上的CodeBert的结果基本一致

GraphCodeBert

要提升性能,我们可以用GraphCodeBert来替换CodeBert.

我们先下载GraphCodeBert的代码:

git clone https://github.com/microsoft/CodeBERT
  • 1

然后转到GraphCodeBERT/clonedetection目录,解压dataset.zip:

unzip dataset.zip
  • 1

然后就可以像训练codebert一样训练graphcodebert了:

mkdir saved_models
python run.py \
    --output_dir=saved_models \
    --config_name=microsoft/graphcodebert-base \
    --model_name_or_path=microsoft/graphcodebert-base \
    --tokenizer_name=microsoft/graphcodebert-base \
    --do_train \
    --train_data_file=dataset/train.txt \
    --eval_data_file=dataset/valid.txt \
    --test_data_file=dataset/test.txt \
    --epoch 1 \
    --code_length 512 \
    --data_flow_length 128 \
    --train_batch_size 16 \
    --eval_batch_size 32 \
    --learning_rate 2e-5 \
    --max_grad_norm 1.0 \
    --evaluate_during_training \
    --seed 123456 2>&1| tee saved_models/train.log
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

上面的参数是按4个V100 GPU来调的,如果只有两块V100,可以将–code_length改成256.
CodeBert 40分钟左右一轮,GraphCodeBert大约需要6个半小时一轮。

然后我们进行推理:

python run.py     --output_dir=saved_models     --config_name=microsoft/graphcodebert-base     --model_name_or_path=microsoft/graphcodebert-base     --tokenizer_name=microsoft/graphcodebert-base     --do_eval     --do_test     --train_data_file=dataset/train.txt     --eval_data_file=dataset/valid.txt     --test_data_file=dataset/test.txt     --epoch 1     --code_length 256     --data_flow_length 128     --train_batch_size 16     --eval_batch_size 32     --learning_rate 2e-5     --max_grad_norm 1.0     --evaluate_during_training     --seed 123456 2>&1| tee saved_models/test.log
  • 1

最后我们解读一下结果吧:

python evaluator/evaluator.py -a dataset/test.txt -p saved_models/predictions.txt 2>&1| tee saved_models/score.log
  • 1

结果如下:

{'Recall': 0.9589415798936043, 'Prediction': 0.962620653900429, 'F1': 0.9607703728051462}
  • 1
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/703400
推荐阅读
相关标签
  

闽ICP备14008679号