# 微调预训练模型——文本分类

文本分类任务有9个级别：

```python
GLUE_TASKS = ["cola", "mnli", "mnli-mm", "mrpc",
              "qnli", "qqp", "rte", "sst2", "stsb", "wnli"]
```

* [CoLA](https://nyu-mll.github.io/CoLA/) (Corpus of Linguistic Acceptability) 鉴别一个句子是否语法正确.
* [MNLI](https://arxiv.org/abs/1704.05426) (Multi-Genre Natural Language Inference) 给定一个假设，判断另一个句子与该假设的关系：entails, contradicts 或者 unrelated。
* [MRPC](https://www.microsoft.com/en-us/download/details.aspx?id=52398) (Microsoft Research Paraphrase Corpus) 判断两个句子是否互为paraphrases.
* [QNLI](https://rajpurkar.github.io/SQuAD-explorer/) (Question-answering Natural Language Inference) 判断第2句是否包含第1句问题的答案。
* [QQP](https://data.quora.com/First-Quora-Dataset-Release-Question-Pairs) (Quora Question Pairs2) 判断两个问句是否语义相同。
* [RTE](https://aclweb.org/aclwiki/Recognizing_Textual_Entailment) (Recognizing Textual Entailment)判断一个句子是否与假设成entail关系。
* [SST-2](https://nlp.stanford.edu/sentiment/index.html) (Stanford Sentiment Treebank) 判断一个句子的情感正负向.
* [STS-B](http://ixa2.si.ehu.es/stswiki/index.php/STSbenchmark) (Semantic Textual Similarity Benchmark) 判断两个句子的相似性（分数为1-5分）。
* [WNLI](https://cs.nyu.edu/faculty/davise/papers/WinogradSchemas/WS.html) (Winograd Natural Language Inference) Determine if a sentence with an anonymous pronoun and a sentence with this pronoun replaced are entailed or not.&#x20;

本次的notebook以CoLA任务为例：

```python
# 任务为CoLA任务
task = "cola"
# BERT模型
model_checkpoint = "distilbert-base-uncased"
# 根据GPU调整batch_size大小，避免显存溢出
batch_size = 16
```

## 加载数据

```python
from datasets import load_dataset, load_metric
actual_task = "mnli" if task == "mnli-mm" else task
dataset = load_dataset("glue", actual_task) 
metric = load_metric('glue', actual_task)
```

这个`datasets`对象本身是一种[`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict)数据结构. 对于训练集、验证集和测试集，只需要使用对应的key（train，validation，test）即可得到相应的数据。

下面是结构：

```
DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 8551
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1043
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1063
    })
})
```

metric评估分数的api

`metric.compute(predictions=fake_preds, references=fake_labels)`

计算分数，每一个文本分类任务所对应的metic有所不同，具体如下:

* for CoLA: [Matthews Correlation Coefficient](https://en.wikipedia.org/wiki/Matthews_correlation_coefficient)
* for MNLI (matched or mismatched): Accuracy
* for MRPC: Accuracy and [F1 score](https://en.wikipedia.org/wiki/F1_score)
* for QNLI: Accuracy
* for QQP: Accuracy and [F1 score](https://en.wikipedia.org/wiki/F1_score)
* for RTE: Accuracy
* for SST-2: Accuracy
* for STS-B: [Pearson Correlation Coefficient](https://en.wikipedia.org/wiki/Pearson_correlation_coefficient) and \[Spearman's\_Rank\_Correlation\_Coefficient]\(<https://en.wikipedia.org/wiki/Spearman's_rank_correlation_coefficient>)
* for WNLI: Accuracy

所以metric要和任务对齐。

## 数据预处理

在将数据喂入模型之前，我们需要对数据进行预处理。预处理的工具叫`Tokenizer`。`Tokenizer`首先对输入进行tokenize，然后将tokens转化为预模型中需要对应的token ID，再转化为模型需要的输入格式。

为了达到数据预处理的目的，我们使用`AutoTokenizer.from_pretrained`方法实例化我们的tokenizer，这样可以确保：

* 我们得到一个与预训练模型一一对应的tokenizer。
* 使用指定的模型checkpoint对应的tokenizer的时候，我们也下载了模型需要的词表库vocabulary，准确来说是tokens vocabulary。

这个被下载的tokens vocabulary会被缓存起来，从而再次使用的时候不会重新下载。

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)
```

注意：`use_fast=True`要求tokenizer必须是

```python
transformers.PreTrainedTokenizerFast
```

类型，因为我们在预处理的时候需要用到fast tokenizer的一些特殊特性（比如多线程快速tokenizer）。如果对应的模型没有fast tokenizer，去掉这个选项即可。

几乎所有模型对应的tokenizer都有对应的fast tokenizer。我们可以在[模型tokenizer对应表](https://huggingface.co/transformers/index.html#bigtable)里查看所有预训练模型对应的tokenizer所拥有的特点。

tokenizer既可以对单个文本进行预处理，也可以对一对文本进行预处理，tokenizer预处理后得到的数据满足预训练模型输入格式

```python
tokenizer("Hello, this one sentence!", "And this sentence goes with it.")
```

取决于我们选择的预训练模型，我们将会看到tokenizer有不同的返回，tokenizer和预训练模型是一一对应的，更多信息可以在[这里](https://huggingface.co/transformers/preprocessing.html)进行学习。

预处理代码:

```python
def preprocess_function(examples):

  if sentence2_key is None:

     return tokenizer(examples[sentence1_key], truncation=True)

  return tokenizer(examples[sentence1_key], examples[sentence2_key], truncation=True)
```

预处理函数可以处理单个样本，也可以对多个样本进行处理。如果输入是多个样本，那么返回的是一个list：

```python
preprocess_function(dataset['train'][:5])
```

## 微调预训练模型

既然数据已经准备好了，现在我们需要下载并加载我们的预训练模型，然后微调预训练模型。既然我们是做seq2seq任务，那么我们需要一个能解决这个任务的模型类。我们使用`AutoModelForSequenceClassification` 这个类。和tokenizer相似，`from_pretrained`方法同样可以帮助我们下载并加载模型，同时也会对模型进行缓存，就不会重复下载模型啦。

需要注意的是：STS-B是一个回归问题，MNLI是一个3分类问题：

```python
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

num_labels = 3 if task.startswith("mnli") else 1 if task=="stsb" else 2
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)
```

由于我们微调的任务是文本分类任务，而我们加载的是预训练的语言模型，所以会提示我们加载模型的时候扔掉了一些不匹配的神经网络参数.

### 训练参数

为了能够得到一个`Trainer`训练工具，我们还需要3个要素，其中最重要的是训练的设定/参数 [`TrainingArguments`](https://huggingface.co/transformers/main_classes/trainer.html#transformers.TrainingArguments)。这个训练设定包含了能够定义训练过程的所有属性。

```python
metric_name = "pearson" if task == "stsb" else "matthews_correlation" if task == "cola" else "accuracy"

args = TrainingArguments(
    "test-glue",
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=5,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
)
```

上面evaluation\_strategy = "epoch"参数告诉训练代码：我们每个epcoh会做一次验证评估。

由于不同的任务需要不同的评测指标，我们定一个函数来根据任务名字得到评价方法:

```python
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    if task != "stsb":
        predictions = np.argmax(predictions, axis=1)
    else:
        predictions = predictions[:, 0]
    return metric.compute(predictions=predictions, references=labels)
```

传给Trainer

```python
validation_key = "validation_mismatched" if task == "mnli-mm" else "validation_matched" if task == "mnli" else "validation"
trainer = Trainer(
    model,
    args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset[validation_key],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)
```

训练与评估

```python
trainer.train()
```

```
trainer.evaluate()
```

> ```
> The following columns in the evaluation set  don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: idx, sentence.
> ***** Running Evaluation *****
>   Num examples = 1043
>   Batch size = 16
> ```
>
> \[66/66 00:02]
>
> ```
> {'epoch': 5.0,
>  'eval_loss': 0.7388790845870972,
>  'eval_matthews_correlation': 0.5308757570358055,
>  'eval_runtime': 2.0787,
>  'eval_samples_per_second': 501.765,
>  'eval_steps_per_second': 31.751}
> ```

## 超参数搜索

Trainer支持超参数搜索，使用[optuna](https://optuna.org/) or [Ray Tune](https://docs.ray.io/en/latest/tune/)代码库。需要安装：

! pip install optuna

! pip install ray\[tune]

`Trainer`在搜索时将会返回多个训练好的模型，所以需要传入一个定义好的模型从而让`Trainer`可以不断重新初始化该传入的模型：

```python
def model_init():

  return AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)
```

传给Trainer

```python
trainer = Trainer(

  model_init=model_init,

  args=args,

  train_dataset=encoded_dataset["train"],

  eval_dataset=encoded_dataset[validation_key],

  tokenizer=tokenizer,

  compute_metrics=compute_metrics

)
```

调用方法`hyperparameter_search`。这个过程可能很久，我们可以先用部分数据集进行超参搜索，再进行全量训练。 比如使用1/10的数据进行搜索：

```python
best_run = trainer.hyperparameter_search(n_trials=10, direction="maximize")
```

```
best_run
```

> ```
> BestRun(run_id='2', objective=0.5255827322697563, hyperparameters={'learning_rate': 4.5335767669016705e-05, 'num_train_epochs': 5, 'seed': 34, 'per_device_train_batch_size': 16})
> ```

给`trainer`设置参数

```python
for n, v in best_run.hyperparameters.items():

  setattr(trainer.args, n, v)
```

```python
trainer.train()
```


---

# Agent Instructions: Querying This Documentation

If you need additional information that is not directly available in this page, you can query the documentation dynamically by asking a question.

Perform an HTTP GET request on the current page URL with the `ask` query parameter:

```
GET https://jon-xia.gitbook.io/workspace/ji-qi-xue-xi-23/datawhale/transformer/t6.md?ask=<question>
```

The question should be specific, self-contained, and written in natural language.
The response will contain a direct answer to the question and relevant excerpts and sources from the documentation.

Use this mechanism when the answer is not explicitly present in the current page, you need clarification or additional context, or you want to retrieve related documentation sections.
