图像分类实战——阿猫阿狗的识别
概述
做一个阿猫阿狗图片分类的功能,数据集和模型均来自Huggingface
数据加载
数据集共包含1w个数据,分为训练集8000,测试集2000。原始数据集在Hugginface上,需要联网下载,可以下载离线数据集然后从本地加载(模型和数据在文末的地址)。
数据示例

1
| image_data = load_dataset('Bingsu/Cat_and_Dog')
|
数据处理
1 2 3 4 5 6 7 8 9
| image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
def image_detail(examples): images =[Image.open(io.BytesIO(image['bytes'])) for image in examples['image']] inputs = image_processor(images, return_tensors="pt") inputs['labels']=examples['labels'] return inputs
image_data_pt=image_data.map(image_detail,batched=True,remove_columns=image_data["train"].column_names)
|
加载模型
1 2 3
| id2label={0:'cat', 1:'dog'} label2id={'cat':0, 'dog':1} model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-50",num_labels=2,id2label=id2label,label2id=label2id,ignore_mismatched_sizes=True)
|
评估函数
如果你不想在训练过程中对模型进行评估,可以不写。只需要在最后做一次评估就行。
1 2 3 4 5 6 7 8 9 10 11
| import evaluate
acc_metric = evaluate.load("accuracy") f1_metirc = evaluate.load("f1",average='micro') def eval_metric(eval_predict): predictions, labels = eval_predict predictions = predictions.argmax(axis=-1) acc = acc_metric.compute(predictions=predictions, references=labels) f1 = f1_metirc.compute(predictions=predictions, references=labels,average='micro') acc.update(f1) return acc
|
训练参数
1 2 3 4 5 6 7 8 9 10 11
| train_args =TrainingArguments(output_dir="./checkpoints", per_device_train_batch_size=512, per_device_eval_batch_size=1, num_train_epochs=1, logging_steps=10, evaluation_strategy="epoch", save_strategy="epoch", learning_rate=2e-5, metric_for_best_model="f1", load_best_model_at_end=True, )
|
训练器
1 2 3 4 5
| trainer = Trainer(model=model, args=train_args, train_dataset=image_data_pt["train"], eval_dataset=image_data_pt["test"], compute_metrics=eval_metric)
|
模型训练
配置好所有参数后,使用训练器开启训练
我在训练过程中的评估结果
| Epoch |
Training Loss |
Validation Loss |
Accuracy |
F1 |
| 1 |
No log |
0.676835 |
0.662000 |
0.662000 |
模型推理
1 2 3 4 5 6 7 8 9 10 11 12 13
| from transformers import pipeline pipe = pipeline("image-classification", model=model,image_processor=image_processor,device=0)
train_image = image_data['test']['image'] show_examples=train_image[:15] for image in show_examples: _image = Image.open(io.BytesIO(image['bytes'])) _image=_image.resize((100,100)) res=pipe(_image) result = sorted(res,key=lambda x:x['score'],reverse=True)[0]['label'] _image.show(title='result') print(result) print('------------')
|
我只取了测试集中的前15条进行推理预测,预测结果如下:

可以看得出来,第一个小猫就被识别错了,哈哈,效果一般般啦。原因有很多,跟数据集的大小,训练的次数,学习率等多种因素有关,感兴趣的小伙伴可以自行尝试下。