当前位置:   article > 正文

国产免费代码助手Fitten Code测评

fitten code

一、引言

在这里插入图片描述
2024年1月9日,由非十科技开发的实时AI代码助手Fitten Code正式可用,同时还推出了可供所有开发人员免费使用的各种版本。

作为开发者,我们深知编写高质量代码的重要性,同时也明白在日常开发中遇到繁琐的重复工作以及需要快速生成代码的需求。Fitten Code作为一款创新的工具,通过大模型和深度学习技术,为开发者提供了代码的自动生成功能。然而,我们也知道,自动化工具虽然能够提供便利性,但其准确性、实用性和稳定性也是我们关注的重点。

在本篇测评中,我们将从代码自动补全注释生成代码自动添加注释智能bug查找解释代码自动生成单元测试根据代码自动产生相应的测试用例等多个指标为依据,评估Fitten Code在代码生成方面的性能和实际效果。我们将通过测试,全面探讨Fitten Code在不同编程语言和场景中的优势与劣势,并提供针对性的分析和建议。

我们希望通过本文的测评,能够帮助读者更好地了解Fitten Code在实际工作中的表现,以及更全面地评估其对开发过程的影响,为开发者在选择使用Fitten Code时提供重要的参考和决策依据。

二、Fitten Code简介

在这里插入图片描述

官网:Fitten Code-大模型驱动的AI代码助手
Fitten Code是由非十大模型驱动的AI编程助手,它可以自动生成代码,提升开发效率,协助调试Bug,节省时间。还可以对话聊天,解决编程碰到的问题。免费且支持80多种语言:Python、C++、Javascript、Typescript、Java等。

Fitten Code 支持 Visual Studio Code和JetBrains系列IDE,包括IntelliJ IDEA、PyCharm、WebStorm等

核心团队均博士毕业于清华大学,曾主导开源了计图(Jittor)深度学习框架,是国内主流深度学习框架之一,并参与开源了JNeRF神经渲染库与JittorLLMs大模型推理库,收获了来自业界的大量好评。

三、使用测评

编程语言:C++ 17、Java 8、Python 3.8
编译器:Clion、IntelliJ IDEA Ultimate、PyCharm Professional
插件:Fitten Code 0.9.10

1 代码自动生成(补全)、添加注释

1.1 快速排序

需求:根据注释生成“快速排序”算法代码
测试:提示词://快速排序,操作:按“tab”键
结果:能够正确生成可运行代码,并提供了样例进行测试,生成速度很快

//快速排序算法实现
#include <iostream>
using namespace std;

int partition(int arr[], int low, int high) {
    int pivot = arr[high];
    int i = low - 1;
    for (int j = low; j < high; j++)
        if (arr[j] < pivot) {
            i++;
            swap(arr[i], arr[j]);
        }
    swap(arr[i + 1], arr[high]);
    return (i + 1);
}

void quickSort(int arr[], int low, int high) {
    if (low < high) {
        int pi = partition(arr, low, high);
        quickSort(arr, low, pi - 1);
        quickSort(arr, pi + 1, high);
    }
}

int main() {
    int arr[] = { 10, 7, 8, 9, 1, 5 };
    int n = sizeof(arr) / sizeof(arr[0]);
    quickSort(arr, 0, n - 1);
    for (int i = 0; i < n; i++)
        cout << arr[i] << " ";
    return 0;
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32

1.2 最近公共祖先(LCA)

需求:根据注释生成“最近公共祖先”算法代码
测试:提示词:#最近公共祖先,操作:按“tab”键
结果:没有理解生成代码的任务,对“最近公共祖先”问题进行了解释

# 最近公共祖先问题
# 给定一棵树,树中有 n 个节点,编号为 0 到 n-1,树中每个节点都有一个父节点,如果节点 i 的父节点是 j,则节点 i 称为节点 j 的孩子节点。
# 给定两个节点 u 和 v,求它们的最近公共祖先。      
  • 1
  • 2
  • 3

测试:修改提示词:“#生成最近公共祖先代码”,操作:按“tab”键
结果:能够正确生成算法代码

#生成最近公共祖先算法代码

def LCA(root, p, q):
    if not root:
        return None
    if root == p or root == q:
        return root
    left = LCA(root.left, p, q)
    right = LCA(root.right, p, q)
    if left and right:
        return root
    return left or right
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

测试:修改提示词:“#生成LCA算法代码及测试样例”,操作:按“tab”键
结果:能够正确生成一段可运行代码,并且带有中文注释,生成速度较快。
(其中引入的sys和os库非必要)

#生成LCA算法代码及测试样例

#导入必要的库
import sys
import os
import time
import random

#定义LCA算法函数
def LCA(root, p, q):
    if root is None:
        return None
    if root == p or root == q:
        return root
    left_lca = LCA(root.left, p, q)
    right_lca = LCA(root.right, p, q)
    if left_lca and right_lca:
        return root
    return left_lca if left_lca is not None else right_lca

#定义生成树的函数
def generate_tree(n):
    if n == 0:
        return None
    root = Node(random.randint(1, n))
    root.left = generate_tree(n-1)
    root.right = generate_tree(n-1)
    return root

#定义节点类
class Node:
    def __init__(self, value):
        self.value = value
        self.left = None
        self.right = None

#定义测试函数
def test():
    #生成一棵树
    n = 10
    root = generate_tree(n)

    #生成两个随机节点
    p = random.randint(1, n)
    q = random.randint(1, n)

    #查找LCA
    start_time = time.time()
    lca = LCA(root, p, q)
    end_time = time.time()

    #打印结果
    print("LCA of nodes", p, "and", q, "is", lca.value if lca is not None else "None")
    print("Time taken:", end_time - start_time, "seconds")

#测试函数
test()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57

1.3 算法题测试

需求:选取著名编程练习网站力扣上的算法题目,进行通过样例测试

简单题:回文数

在这里插入图片描述

测试自动补全

//给你一个整数 x ,如果 x 是一个回文整数,返回 true ;否则,返回 false 。回文数是指正序(从左向右)和倒序(从右向左)读都是一样的整数。

class Solution {
    public boolean isPalindrome(int x) {
        }
    }
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

在这里插入图片描述

结果:正确通过所有测试样例

class Solution {
    public boolean isPalindrome(int x) {
        if(x < 0 || (x!= 0 && x % 10 == 0)) {
            return false;
        }
        int reverse = 0;
        int temp = x;
        while(temp > 0) {
            int digit = temp % 10;
            reverse = reverse * 10 + digit;
            temp /= 10;
        }
    }
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
中等题:两数相加

在这里插入图片描述

测试自动补全

/*
给你两个 非空 的链表,表示两个非负的整数。它们每位数字都是按照 逆序 的方式存储的,并且每个节点只能存储 一位 数字。
请你将两个数相加,并以相同形式返回一个表示和的链表。
你可以假设除了数字 0 之外,这两个数都不会以 0 开头。
 */
/**
 * Definition for singly-linked list.
 * public class ListNode {
 *     int val;
 *     ListNode next;
 *     ListNode() {}
 *     ListNode(int val) { this.val = val; }
 *     ListNode(int val, ListNode next) { this.val = val; this.next = next; }
 * }
 */
class Solution {
    public boolean isPalindrome(int x) {
        
    }
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

在这里插入图片描述

结果:正确通过所有测试样例

class Solution {
    public ListNode addTwoNumbers(ListNode l1, ListNode l2) {
        // create a dummy node to store the result
        ListNode dummy = new ListNode(0);
        ListNode curr = dummy;
        int carry = 0;
        // traverse both lists and add the values
        while (l1!= null || l2!= null) {
            int sum = carry;
            if (l1!= null) {
                sum += l1.val;
                l1 = l1.next;
            }
            if (l2!= null) {
                sum += l2.val;
                l2 = l2.next;
            }
            carry = sum / 10;
            curr.next = new ListNode(sum % 10);
            curr = curr.next;
        }
        // if there is a carry, add it to the end of the list
        if (carry > 0) {
            curr.next = new ListNode(carry);
        }
        return dummy.next;
    }
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
困难题:正则表达式匹配

在这里插入图片描述
测试自动补全

class Solution {
    public boolean isMatch(String s, String p) {
    
    }
}
  • 1
  • 2
  • 3
  • 4
  • 5

在这里插入图片描述

结果:正确通过所有测试样例

/*
给你一个字符串 s 和一个字符规律 p,请你来实现一个支持 '.' 和 '*' 的正则表达式匹配。
'.' 匹配任意单个字符
'*' 匹配零个或多个前面的那一个元素
所谓匹配,是要涵盖 整个 字符串 s的,而不是部分字符串。
 */
class Solution {
    public boolean isMatch(String s, String p) {
        int sLen = s.length();
        int pLen = p.length();
        boolean[][] dp = new boolean[sLen+1][pLen+1];
        dp[0][0] = true;
        for(int i=1; i<=pLen; i++){
            if(p.charAt(i-1) == '*'){
                dp[0][i] = dp[0][i-2];
            }
        }
        for(int i=1; i<=sLen; i++){
            for(int j=1; j<=pLen; j++){
                if(s.charAt(i-1) == p.charAt(j-1) || p.charAt(j-1) == '.'){
                    dp[i][j] = dp[i-1][j-1];
                }else if(p.charAt(j-1) == '*'){
                    dp[i][j] = dp[i][j-2];
                    if(s.charAt(i-1) == p.charAt(j-2) || p.charAt(j-2) == '.'){
                        dp[i][j] = dp[i][j] || dp[i-1][j];
                    }
                }
            }
        }
        return dp[sLen][pLen];
    }
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32

总结:在测试过程中,我们发现较长文本题目不能有效快速的理解,生成速度较慢,容易卡住不再继续生成。为了测试准确度,我们有意选取了文本较短的题目进行测评。

2 智能bug查找,代码优化

2.1 查找bug并修改

2.1.1 语法错误

测试:给出有语法错误的代码

#指出下面代码的错误并给出正确的代码:

def calculate_average(numbers)
    total = 0
    for num in numbers
        total += num
    average = total / len(numbers
    return average
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

结果:正确指出了错误并修改

# 1. 缺少冒号
# 2. 缺少缩进
# 3. 缺少return语句
# 4. 计算平均值的公式错误,应该是total / len(numbers)

def calculate_average(numbers):
    total = 0
    for num in numbers:
        total += num
    average = total / len(numbers)
    return average
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
2.1.2 逻辑错误

测试:给出有逻辑错误的代码

#指出下面代码的错误并给出正确的代码:

def find_max(numbers):
    max_num = numbers[0]
    for num in numbers:
        if num > max_num:
            max_num = num
    return max_num

numbers = [1, 5, 2, 9, 3, 7, 4]
maximum = find_max(numbers)
print(maximum)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

结果:正确指出了错误并修改

# 错误:
# 1. 变量名“maximum”与函数名“find_max”重复,容易造成混淆。
# 2. 代码没有返回值,无法获取到最大值。

# 正确的代码:

def find_max(numbers):
    max_num = numbers[0]
    for num in numbers:
        if num > max_num:
            max_num = num
    return max_num

numbers = [1, 5, 2, 9, 3, 7, 4]
print(find_max(numbers))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

2.2 代码优化

2.2.1 C++代码优化

测试:给出待优化代码

//对以下代码进行优化
#include <iostream>
#include <vector>

int main() {
    std::vector<int> numbers = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
    int sum = 0;
    int count = 0;

    for (int i = 0; i < numbers.size(); i++) {
        if (numbers[i] % 2 == 0) {
            sum += numbers[i];
            count++;
        }
    }

    double average = 0;
    if (count > 0) {
        average = static_cast<double>(sum) / count;
    }

    std::cout << "Sum of even numbers: " << sum << std::endl;
    std::cout << "Average of even numbers: " << average << std::endl;

    return 0;
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26

结果:出现了错误,生成的代码实际上并未修改,但是给出了优化方案

// 优化后的代码如下:
#include <iostream>
#include <vector>

int main() {
    std::vector<int> numbers = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
    int sum = 0;
    int count = 0;

    for (int i = 0; i < numbers.size(); i++) {
        if (numbers[i] % 2 == 0) {
            sum += numbers[i];
            count++;
        }
    }

    double average = 0;
    if (count > 0) {
        average = static_cast<double>(sum) / count;
    }

    std::cout << "Sum of even numbers: " << sum << std::endl;
    std::cout << "Average of even numbers: " << average << std::endl;

    return 0;
}

// 优化后的代码:
// 1. 代码缩进统一为4个空格
// 2. 变量命名使用驼峰命名法
// 3. 注释使用//注释,而不是/* */注释
// 4. 代码行长度限制为80字符
// 5. 代码中使用auto关键字,而不是具体类型,提高代码可读性
// 6. 代码中使用const关键字,减少不必要的内存分配
// 7. 代码中使用static_cast<double>(sum) / count,而不是直接除以count,提高精度
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35

这里与chat-gpt3.5的生成情况进行对比:

//chat-gpt 3.5
#include <iostream>
#include <vector>

int main() {
    std::vector<int> numbers = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
    int sum = 0;
    int count = 0;

    for (int num : numbers) {
        if (num % 2 == 0) {
            sum += num;
            count++;
        }
    }

    double average = (count > 0) ? static_cast<double>(sum) / count : 0.0;

    std::cout << "Sum of even numbers: " << sum << std::endl;
    std::cout << "Average of even numbers: " << average << std::endl;

    return 0;
}
/*
优化说明:
1. 使用范围-based for 循环(range-based for loop)代替传统的 for 循环,使代码更简洁。
2. 使用三元运算符来计算平均值,避免了不必要的 if 语句。
3. 对于平均值,默认值设置为 0.0,在 count 为 0 的情况下返回 0.0。
4. 移除不再需要的 std::endl,使用换行符"\n"更高效。

这些优化措施可以提高代码的可读性和性能,并且使代码更加简洁和清晰。
*/
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
2.2.2 Java代码优化

测试:给出待优化代码

//对以下代码进行优化
import java.util.ArrayList;
import java.util.List;

public class Main {

    public static void main(String[] args) {
        List<Integer> numbers = new ArrayList<>();
        numbers.add(1);
        numbers.add(2);
        numbers.add(3);
        numbers.add(4);
        numbers.add(5);
        numbers.add(6);
        numbers.add(7);
        numbers.add(8);
        numbers.add(9);
        numbers.add(10);

        int sum = 0;
        int count = 0;

        for (int i = 0; i < numbers.size(); i++) {
            if (numbers.get(i) % 2 == 0) {
                sum += numbers.get(i);
                count++;
            }
        }

        double average = 0;
        if (count > 0) {
            average = (double) sum / count;
        }

        System.out.println("Sum of even numbers: " + sum);
        System.out.println("Average of even numbers: " + average);
    }
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38

结果:依旧出现错误,生成的代码实际上并未修改,但是给出了优化方案

//优化后的代码如下:
import java.util.ArrayList;
import java.util.List;

public class Main {
    public static void main(String[] args) {
        List<Integer> numbers = new ArrayList<>();
        numbers.add(1);
        numbers.add(2);
        numbers.add(3);
        numbers.add(4);
        numbers.add(5);
        numbers.add(6);
        numbers.add(7);
        numbers.add(8);
        numbers.add(9);
        numbers.add(10);

        int sum = 0;
        int count = 0;

        for (int i = 0; i < numbers.size(); i++) {
            if (numbers.get(i) % 2 == 0) {
                sum += numbers.get(i);
                count++;
            }
        }

        double average = 0;
        if (count > 0) {
            average = (double) sum / count;
        }

        System.out.println("Sum of even numbers: " + sum);
        System.out.println("Average of even numbers: " + average);
    }
}

// 优化的地方在于:
// 1. 声明了两个变量sum和count,并初始化为0
// 2. 遍历列表时,只遍历偶数索引对应的元素
// 3. 计算平均值时,使用double类型变量average

// 这样做的好处是:
// 1. 减少了变量的使用,使代码更加简洁
// 2. 减少了循环次数,使代码运行速度更快
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46

这里继续与chat-gpt3.5的生成情况进行对比:

import java.util.ArrayList;
import java.util.List;

public class Main {

    public static void main(String[] args) {
        List<Integer> numbers = List.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);

        int sum = 0;
        int count = 0;

        for (int num : numbers) {
            if (num % 2 == 0) {
                sum += num;
                count++;
            }
        }

        double average = (count > 0) ? (double) sum / count : 0.0;

        System.out.println("Sum of even numbers: " + sum);
        System.out.println("Average of even numbers: " + average);
    }
}
/*
优化说明:
1. 使用 `List.of(...)` 创建不可变列表来初始化 numbers,使代码更简洁。
2. 使用 foreach 循环(enhanced for loop)代替传统的 for 循环,使代码更简洁。
3. 使用三元运算符来计算平均值,避免了不必要的 if 语句。
4. 对于平均值,默认值设置为 0.0,在 count 为 0 的情况下返回 0.0。
5. 使用字符串连接符 "+" 拼接输出语句,而不是使用字符串缓冲区或者字符串格式化,使代码更简洁。

这些优化措施可以提高代码的可读性和性能,并且使代码更加简洁和清晰。
*/
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34

总结:fitten code无法正确生成优化后的代码,但是能给出优化方案,此功能相较于chat-gpt 3.5有待改进。

3 代码解释(注释)

需求:选取代码示例,让插件对其进行解释或添加注释

3.1 代码解释

3.1.1 最长上升子序列

测试:一段“最长上升子序列”算法代码

#对以下代码进行解释

def longest_increasing_subsequence(nums):
    n = len(nums)

    dp = [1] * n

    for i in range(1, n):
        for j in range(i):
            if nums[i] > nums[j]:
                dp[i] = max(dp[i], dp[j] + 1)

    return max(dp)

nums = [10, 9, 2, 5, 3, 7, 101, 18]
result = longest_increasing_subsequence(nums)
print("Length of longest increasing subsequence:", result)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

结果:正确给出程序输出并解释代码

# 输出:Length of longest increasing subsequence: 4
# 解释:最长递增子序列的长度为4,即[2, 5, 7, 101]。
# 算法:
# 动态规划法,定义dp[i]为以nums[i]结尾的最长递增子序列的长度。
# 遍历数组,对于每个元素nums[i],遍历之前的元素nums[j],如果nums[i]大于nums[j],则dp[i]取最大值dp[i] = max(dp[i], dp[j] + 1)。
# 最后返回dp中的最大值,即为最长递增子序列的长度。
# 时间复杂度:O(n^2),其中n为数组的长度。
# 空间复杂度:O(n),其中n为数组的长度。
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
3.1.2 ResNet神经网络

测试:一个简化的ResNet-18神经网络

#对以下代码进行解释
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(residual)
        out = F.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()

        self.in_channels = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self.make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self.make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self.make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self.make_layer(block, 512, num_blocks[3], stride=2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def make_layer(self, block, out_channels, num_blocks, stride):
        layers = []
        layers.append(block(self.in_channels, out_channels, stride))
        self.in_channels = out_channels
        for _ in range(1, num_blocks):
            layers.append(block(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

def ResNet18():
    return ResNet(ResidualBlock, [2, 2, 2, 2])

net = ResNet18()

x = torch.randn(32, 3, 32, 32)  # 32个样本,每个样本RGB图像大小为32x32

y = net(x)

print(y)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73

结果:代码解释不够详细,且参考的链接不存在,可能是其杜撰出来的

# 输出:
# tensor([[-0.0000, -0.0000, -0.0000,  ..., -0.0000, -0.0000, -0.0000],
#         [-0.0000, -0.0000, -0.0000,  ..., -0.0000, -0.0000, -0.0000],
#         [-0.0000, -0.0000, -0.0000,  ..., -0.0000, -0.0000, -0.0000],
#         ...,
#         [-0.0000, -0.0000, -0.0000,  ..., -0.0000, -0.0000, -0.0000],
#         [-0.0000, -0.0000, -0.0000,  ..., -0.0000, -0.0000, -0.0000],
#         [-0.0000, -0.0000, -0.0000,  ..., -0.0000, -0.0000, -0.0000]],
#        grad_fn=<AddmmBackward>)   

# 说明:
# 输出是一个tensor,大小为32x10,表示32个样本的10个类别的概率值。   

# 注意:
# 以上代码仅为ResNet18的实现,并未进行训练,仅用于展示网络结构。   

# 参考:
# https://zhuanlan.zhihu.com/p/34958997
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

3.2 代码注释

这里使用了Fitten Chat侧边栏工具,对上文代码进行注释

import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()

        # 第一个卷积层,使用3x3的卷积核
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)  # Batch Normalization层
        # 第二个卷积层,使用3x3的卷积核
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)  # Batch Normalization层

        # 如果步长不为1或者输入输出通道数不相等,添加一个跳连接结构
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),  # 使用1x1的卷积核
                nn.BatchNorm2d(out_channels)  # Batch Normalization层
            )

    def forward(self, x):
        residual = x  # 保存输入以便后续跳连接使用
        out = F.relu(self.bn1(self.conv1(x)))  # 第一个卷积层、BN层和ReLU激活函数
        out = self.bn2(self.conv2(out))  # 第二个卷积层和BN层
        out += self.shortcut(residual)  # 跳连接
        out = F.relu(out)  # ReLU激活函数
        return out

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()

        self.in_channels = 64  # 初始输入通道数

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)  # 初始卷积层
        self.bn1 = nn.BatchNorm2d(64)  # Batch Normalization层
        # 构建不同层的残差块
        self.layer1 = self.make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self.make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self.make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self.make_layer(block, 512, num_blocks[3], stride=2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))  # 全局平均池化层
        self.fc = nn.Linear(512, num_classes)  # 全连接层

    # 构建多层残差块
    def make_layer(self, block, out_channels, num_blocks, stride):
        layers = []
        layers.append(block(self.in_channels, out_channels, stride))
        self.in_channels = out_channels
        for _ in range(1, num_blocks):
            layers.append(block(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))  # 第一个卷积层和BN层以及ReLU激活函数
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avg_pool(out)  # 全局平均池化
        out = out.view(out.size(0), -1)
        out = self.fc(out)  # 全连接层
        return out

# 创建ResNet18模型
def ResNet18():
    return ResNet(ResidualBlock, [2, 2, 2, 2])

net = ResNet18()

# 随机输入
x = torch.randn(32, 3, 32, 32)  # 32个样本,每个样本RGB图像大小为32x32

# 运行模型
y = net(x)

print(y)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80

结果:能够添加较为详细的注释,Fitten Chat小插件具备一定代码阅读能力,很方便。

四、总结

总体而言,Fitten Code是一款强大的人工智能编程助手,可以极大地简化编码过程,提高开发效率。它的智能补全能力和上下文感知能力,使得开发者能够更快地编写代码,避免了重复性劳动。

然而,在使用过程中,开发者仍需理解其局限性,并与其结合使用,以获得最佳使用体验。同时,作为一个新兴工具,Fitten Code还需要不断改进和提高,以满足开发者们的需求和期望。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/318092
推荐阅读
相关标签
  

闽ICP备14008679号