当前位置:   article > 正文

最新JAVA的NLP工具DJL_java nlp

java nlp

零、其他:NLP工具包
LingPipe

是alias公司开发的一款自然语言处理软件包。

主题分类(Top Classification)
命名实体识别(Named Entity Recognition)
词性标注(Part-of Speech Tagging)
句题检测(Sentence Detection)
查询拼写检查(Query Spell Checking)
兴趣短语检测(Interseting Phrase Detection)
聚类(Clustering)
字符语言建模(Character Language Modeling)
医学文献下载/解析/索引(MEDLINE Download, Parsing and Indexing)
数据库文本挖掘(Database Text Mining)
中文分词(Chinese Word Segmentation)
情感分析(Sentiment Analysis)
语言辨别(Language Identification)

HanLP

HanLP是一系列模型与算法组成的NLP工具包,目标是普及自然语言处理在生产环境中的应用。

FudanNLP

FNLP主要是为中文自然语言处理而开发的工具包,也包含为实现这些任务的机器学习算法和数据集。 本工具包及其包含数据集使用LGPL3.0许可证。

    信息检索: 文本分类 新闻聚类
    中文处理: 中文分词 词性标注 实体名识别 关键词抽取 依存句法分析 时间短语识别
    结构化学习: 在线学习 层次分类 聚类

Apache OpenNLP

OpenNLP支持最常见的NLP任务:

例如标记化,句子分段,词性标记,命名实体提取,分块,解析,语言检测和共指解析
1
Stanford CoreNLP

斯坦福大学开发的自然语言处理工具套件,包括词性标注、命名实体识别、共指消解等 NLP 任务:


一、简介
开源库以Java构建和部署深度学习、编写一次即可在任何地方运行。使用DJL开发模型并在您选择的引擎上运行。直观的API使用本机Java概念并抽象化了深度学习所涉及的复杂性。引入您自己的模型,或使用我们库中的最新模型在几分钟内进行部署。

二、开源地址:
https://github.com/awslabs/djl

三、例子或者用法

1、Single-shot object detection example
2、Train your first model
3、Image classification example
4、Transfer learning example
5、Train SSD model example
6、Multi-threaded inference example
7、Bert question and answer example
8、Instance segmentation example
9、Pose estimation example
10、Action recognition example
11、Multi-label dataset training example

四、官网地址:

https://djl.ai/

五、代码如下:
1、依赖:

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>2.2.5.RELEASE</version>
        <relativePath/> <!-- lookup parent from repository -->
    </parent>
    <groupId>com.xxxx</groupId>
    <artifactId>bigdata</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>bigdata</name>
    <description>Demo project for Spring Boot</description>

    <properties>
        <java.version>1.8</java.version>
        <spring-cloud.version>Hoxton.SR1</spring-cloud.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.cloud</groupId>
            <artifactId>spring-cloud-starter-aws</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>examples</artifactId>
            <version>0.3.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>0.3.0</version>
        </dependency>

        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <optional>true</optional>
        </dependency>

        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
            <version>0.3.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
            <version>0.3.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-model-zoo</artifactId>
            <version>0.3.0</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
            <exclusions>
                <exclusion>
                    <groupId>org.junit.vintage</groupId>
                    <artifactId>junit-vintage-engine</artifactId>
                </exclusion>
            </exclusions>
        </dependency>
    </dependencies>

    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>org.springframework.cloud</groupId>
                <artifactId>spring-cloud-dependencies</artifactId>
                <version>${spring-cloud.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>

    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
            </plugin>
        </plugins>
    </build>

</project>



2、问答案例:

package com.xxxx.bigdata.nlputils;

import ai.djl.Application;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.mxnet.zoo.nlp.qa.QAInput;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An example of inference using BertQA.
 *
 * <p>See:
 *
 * <ul>
 *   <li>the <a href="https://github.com/awslabs/djl/blob/master/jupyter/BERTQA.ipynb">jupyter
 *       demo</a> with more information about BERT.
 *   <li>the <a
 *       href="https://github.com/awslabs/djl/blob/master/examples/docs/BERT_question_and_answer.md">docs</a>
 *       for information about running this example.
 * </ul>
 */
public final class BertQaInference {

    private static final Logger logger = LoggerFactory.getLogger(BertQaInference.class);

    private BertQaInference() {}

    public static void main(String[] args) throws IOException, TranslateException, ModelException {
        String answer = BertQaInference.predict();
        logger.info("Answer: {}", answer);
    }

    public static String predict() throws IOException, TranslateException, ModelException {
        String question = "When did BBC Japan start broadcasting?";
        String paragraph =
                "BBC Japan was a general entertainment Channel.\n"
                        + "Which operated between December 2004 and April 2006.\n"
                        + "It ceased operations after its Japanese distributor folded.";

        QAInput input = new QAInput(question, paragraph, 384);
        logger.info("Paragraph: {}", input.getParagraph());
        logger.info("Question: {}", input.getQuestion());

        Criteria<QAInput, String> criteria =
                Criteria.builder()
                        .optApplication(Application.NLP.QUESTION_ANSWER)
                        .setTypes(QAInput.class, String.class)
                        .optFilter("backbone", "bert")
                        .optFilter("dataset", "book_corpus_wiki_en_uncased")
                        .optProgress(new ProgressBar())
                        .build();

        try (ZooModel<QAInput, String> model = ModelZoo.loadModel(criteria)) {
            try (Predictor<QAInput, String> predictor = model.newPredictor()) {
                return predictor.predict(input);
            }
        }
    }
}


3、训练模型

package com.xxxx.bigdata.nlputils;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.basicdataset.CaptchaDataset;
import ai.djl.basicmodelzoo.cv.classification.ResNetV1;
import ai.djl.examples.training.util.Arguments;
import ai.djl.examples.training.util.ExampleTrainingResult;
import ai.djl.examples.training.util.TrainingUtils;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.Dataset.Usage;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.SimpleCompositeLoss;
import ai.djl.training.loss.SoftmaxCrossEntropyLoss;
import ai.djl.training.util.ProgressBar;
import java.io.IOException;
import java.nio.file.Paths;

/**
 * An example of training a CAPTCHA solving model.
 *
 * <p>See this <a
 * href="https://github.com/awslabs/djl/blob/master/examples/docs/train_captcha.md">doc</a> for
 * information about this example.
 */
public final class TrainCaptcha {

    private TrainCaptcha() {}

    public static void main(String[] args) throws Exception{
        TrainCaptcha.runExample(args);
    }

    public static ExampleTrainingResult runExample(String[] args)
            throws Exception {
        Arguments arguments = Arguments.parseArgs(args);

        try (Model model = Model.newInstance()) {
            model.setBlock(getBlock());

            // get training and validation dataset
            RandomAccessDataset trainingSet = getDataset(Usage.TRAIN, arguments);
            RandomAccessDataset validateSet = getDataset(Usage.VALIDATION, arguments);

            // setup training configuration
            DefaultTrainingConfig config = setupTrainingConfig(arguments);

            ExampleTrainingResult result;
            try (Trainer trainer = model.newTrainer(config)) {
                trainer.setMetrics(new Metrics());

                Shape inputShape =
                        new Shape(1, 1, CaptchaDataset.IMAGE_HEIGHT, CaptchaDataset.IMAGE_WIDTH);

                // initialize trainer with proper input shape
                trainer.initialize(inputShape);

                TrainingUtils.fit(
                        trainer,
                        arguments.getEpoch(),
                        trainingSet,
                        validateSet,
                        arguments.getOutputDir(),
                        "captcha");

                result = new ExampleTrainingResult(trainer);
            }
            model.save(Paths.get(arguments.getOutputDir()), "captcha");
            return result;
        }
    }

    private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) {
        SimpleCompositeLoss loss = new SimpleCompositeLoss();
        for (int i = 0; i < CaptchaDataset.CAPTCHA_LENGTH; i++) {
            loss.addLoss(new SoftmaxCrossEntropyLoss("loss_digit_" + i), i);
        }

        DefaultTrainingConfig config =
                new DefaultTrainingConfig(loss)
                        .optDevices(Device.getDevices(arguments.getMaxGpus()))
                        .addTrainingListeners(
                                TrainingListener.Defaults.logging(arguments.getModelDir(),arguments.getBatchSize(),arguments.getEpoch(),arguments.getMaxGpus(),arguments.getOutputDir()));

        for (int i = 0; i < CaptchaDataset.CAPTCHA_LENGTH; i++) {
            config.addEvaluator(new Accuracy("acc_digit_" + i, i));
        }

        return config;
    }

    private static RandomAccessDataset getDataset(Dataset.Usage usage, Arguments arguments)
            throws IOException {
        CaptchaDataset dataset =
                CaptchaDataset.builder()
                        .optUsage(usage)
                        .setSampling(arguments.getBatchSize(), true)
                        .optMaxIteration(arguments.getMaxIterations())
                        .build();
        dataset.prepare(new ProgressBar());
        return dataset;
    }

    private static Block getBlock() {
        Block resnet =
                ResNetV1.builder()
                        .setNumLayers(50)
                        .setImageShape(
                                new Shape(
                                        1, CaptchaDataset.IMAGE_HEIGHT, CaptchaDataset.IMAGE_WIDTH))
                        .setOutSize(CaptchaDataset.CAPTCHA_OPTIONS * CaptchaDataset.CAPTCHA_LENGTH)
                        .build();

        return new SequentialBlock()
                .add(resnet)
                .add(
                        resnetOutputList -> {
                            NDArray resnetOutput = resnetOutputList.singletonOrThrow();
                            NDList splitOutput =
                                    resnetOutput
                                            .reshape(
                                                    -1,
                                                    CaptchaDataset.CAPTCHA_LENGTH,
                                                    CaptchaDataset.CAPTCHA_OPTIONS)
                                            .split(CaptchaDataset.CAPTCHA_LENGTH, 1);

                            NDList output = new NDList(CaptchaDataset.CAPTCHA_LENGTH);
                            for (NDArray outputDigit : splitOutput) {
                                output.add(outputDigit.squeeze(1));
                            }
                            return output;
                        });
    }
}

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/374193?site
推荐阅读
相关标签
  

闽ICP备14008679号