使用 MobileNet 和 TensorFlow Lite 对图像进行分类

使用 MobileNet 和 TensorFlow Lite 对图像进行分类

MobileNet 是专为移动和嵌入式设备设计的卷积神经网络 (CNN)。MobileNet 通常用于图像分类。本教程演示了如何使用 MobileNet 和 TensorFlow Lite 对图像进行分类。

准备环境

下载预训练的 MobileNet 模型:

curl -Lo mobilenet_v1_1.0_224.tflite https://tfhub.dev/tensorflow/lite-model/mobilenet_v1_1.0_224/1/default/1?lite-format=tflite

该模型在 ImageNet 数据集上训练,共有 1001 个类别:第 0 类为“背景”(background),其后是 1000 个实际的 ImageNet 类别。模型的输入图像尺寸为 224×224。

下载 ImageNet 类标签:

curl -o labels.txt https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt

下载用于测试的图片:

curl -o test.bmp https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/lite/examples/label_image/testdata/grace_hopper.bmp

推理

程序首先从文件中读取类别标签并存入数组,然后将 TensorFlow Lite 模型加载到内存中。在预处理阶段,程序调整图像大小,将颜色通道顺序从 BGR 转换为 RGB,并将像素值归一化到 [-1, 1] 范围内。随后调用模型对图像进行推理,模型返回每个类别的概率列表,概率最大的索引即对应预测的类别。

from tflite_runtime.interpreter import Interpreter
import numpy as np
import cv2


def readLabels(labelsFile):
    """Read class labels, one per line, from ``labelsFile``.

    Surrounding whitespace (including the trailing newline) is stripped from
    every line, and empty lines are kept so that list position i stays aligned
    with model output index i.

    Returns:
        list[str]: the stripped lines, in file order.
    """
    with open(labelsFile, 'r') as handle:
        return list(map(str.strip, handle))


def preprocessImage(img):
    """Turn an OpenCV BGR image into a MobileNet input batch.

    The image is resized to 224x224 with area interpolation, converted from
    BGR to RGB, given a leading batch axis of size 1, and rescaled with
    ``x / 127.5 - 1`` (maps [0, 255] to [-1, 1], assuming 8-bit pixels).

    Returns:
        numpy.ndarray: float32 array with a leading batch dimension.
    """
    resized = cv2.resize(img, (224, 224), interpolation=cv2.INTER_AREA)
    rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    batch = rgb[np.newaxis, ...]
    scaled = batch * (1 / 127.5) - 1
    return scaled.astype(np.float32)


def main():
    """Classify one image with MobileNet and print the top prediction.

    Reads the class labels, loads and preprocesses the test image, runs the
    TensorFlow Lite interpreter, and prints "<probability>: <label> (<index>)"
    for the highest-scoring class.

    Raises:
        FileNotFoundError: if OpenCV cannot read the image file.
    """
    LABELS_FILE = 'labels.txt'
    IMG_FILE = 'test.bmp'
    MODEL_FILE = 'mobilenet_v1_1.0_224.tflite'

    labels = readLabels(LABELS_FILE)

    img = cv2.imread(IMG_FILE)
    # cv2.imread returns None instead of raising for a missing/unreadable file;
    # without this check the failure surfaces later as an obscure cv2.resize error.
    if img is None:
        raise FileNotFoundError('Cannot read image: %s' % IMG_FILE)
    img = preprocessImage(img)

    interpreter = Interpreter(MODEL_FILE)
    interpreter.allocate_tensors()

    inputDetails = interpreter.get_input_details()
    interpreter.set_tensor(inputDetails[0]['index'], img)
    interpreter.invoke()

    outputDetails = interpreter.get_output_details()
    outputData = interpreter.get_tensor(outputDetails[0]['index'])
    # Drop the batch dimension so argmax indexes classes directly.
    outputData = np.squeeze(outputData)

    idx = np.argmax(outputData)
    # Guard against a labels file that is shorter than the model's output.
    label = labels[idx] if idx < len(labels) else 'unknown'
    output = outputData[idx]

    print('%.6f: %s (%s)' % (output, label, idx))


# Run the demo only when executed as a script, not when this module is imported.
if __name__ == '__main__':
    main()
#include <algorithm>
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include <opencv2/opencv.hpp>
#include <tensorflow/lite/interpreter.h>
#include <tensorflow/lite/kernels/register.h>

using namespace cv;
using namespace tflite;

/// Append every line of @p labelsFile to @p labels, trimmed of surrounding whitespace.
///
/// Trimming matches the Python version (`line.strip()`) and removes the '\r'
/// left behind by CRLF files. Empty lines are kept so that labels[i] stays
/// aligned with model output index i. If the file cannot be opened, nothing
/// is appended.
/// @param labelsFile  Path to a text file with one label per line.
/// @param labels      Output vector; labels are appended in file order.
void readLabels(const char *labelsFile, std::vector<std::string> &labels)
{
    static const char *const kWhitespace = " \t\r\n\f\v";

    std::ifstream fin(labelsFile);
    std::string line;
    while (std::getline(fin, line)) {
        const std::string::size_type first = line.find_first_not_of(kWhitespace);
        if (first == std::string::npos) {
            labels.emplace_back();  // blank line: keep index alignment
            continue;
        }
        const std::string::size_type last = line.find_last_not_of(kWhitespace);
        labels.push_back(line.substr(first, last - first + 1));
    }
}

/// Prepare an image loaded by imread for MobileNet inference, in place.
///
/// Resizes to the 224x224 model input, converts OpenCV's BGR channel order to
/// RGB, and rescales pixel values with x / 127.5 - 1, which maps [0, 255] to
/// [-1, 1] (assuming 8-bit input), storing the result as CV_32FC3.
/// @param img  Image to transform; replaced by the float32 RGB result.
void preprocessImage(Mat &img)
{
    resize(img, img, Size(224, 224), 0, 0, INTER_AREA);
    // imread's default mode yields 3-channel BGR, so BGR2RGB is the matching
    // conversion code (the same one the Python version uses).
    cvtColor(img, img, COLOR_BGR2RGB);
    img.convertTo(img, CV_32FC3, 1 / 127.5f, -1);
}

int main()
{
    const int NUM_CLASSES = 1001;
    const char *LABELS_FILE = "labels.txt";
    const char *IMG_FILE = "test.bmp";
    const char *MODEL_FILE = "mobilenet_v1_1.0_224.tflite";

    std::vector<std::string> labels;
    readLabels(LABELS_FILE, labels);

    Mat img = imread(IMG_FILE);
    preprocessImage(img);

    std::unique_ptr<FlatBufferModel> model = FlatBufferModel::BuildFromFile(MODEL_FILE);
    ops::builtin::BuiltinOpResolver resolver;
    std::unique_ptr<Interpreter> interpreter;
    InterpreterBuilder(*model, resolver)(&interpreter);
    interpreter->AllocateTensors();

    auto *inputTensor = interpreter->typed_input_tensor<float>(0);
    memcpy(inputTensor, img.data, img.total() * img.elemSize());
    interpreter->Invoke();

    auto *outputData = interpreter->typed_output_tensor<float>(0);

    float *output = std::max_element(outputData, outputData + NUM_CLASSES);
    long idx = output - outputData;
    std::string label = labels[idx];

    std::cout << *output << ": " << label << " (" << idx << ")" << std::endl;

    return 0;
}

C++ 代码使用 CMake 构建(CMake 并非必需,也可以直接调用编译器并链接相同的库)。CMakeLists.txt 文件内容如下:
# Build the C++ MobileNet classifier from main.cpp.
cmake_minimum_required(VERSION 3.22)
project(app)

set(CMAKE_CXX_STANDARD 14)
# Fail at configure time instead of silently falling back to an older
# standard when the compiler does not support C++14.
set(CMAKE_CXX_STANDARD_REQUIRED ON)

add_executable(app main.cpp)

# NOTE(review): these library names assume TensorFlow Lite and OpenCV are
# installed in the linker's default search paths — confirm for your setup.
target_link_libraries(app PRIVATE tensorflow-lite opencv_core opencv_imgcodecs opencv_imgproc)

在我们的例子中,Python 和 C++ 程序分别输出以下结果:

0.912099: military uniform (653)
0.912099: military uniform (653)

模型以约 91% 的概率将图像正确分类为“military uniform”(军装)。

本文来自作者投稿,版权归原作者所有。如需转载,请注明出处:https://www.nxrte.com/jishu/17447.html

(0)

相关推荐

发表回复

登录后才能评论