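I have used the CMSIS-NN framework before to study a few classification problems, but I never systematically analyzed how the mainstream frameworks actually work. More and more people have been asking me to write something on the topic, and after putting it off for a long time I am finally getting around to it.
As of this post, the commit I checked out is: 8855f56500ff8efa449662a95fe69f24bb78c0a6
The main example starts from hello_world_test.cc; locate the following code in the TensorFlow source tree:
It begins by including a large number of headers: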
// A class that provides the operations the interpreter needs to run our model
#include "tensorflow/lite/micro/all_ops_resolver.h"
// The converted model: a FlatBuffer stored as binary data in an array
#include "tensorflow/lite/micro/examples/hello_world/model.h"
// A logging class for debugging
#include "tensorflow/lite/micro/micro_error_reporter.h"
// The TensorFlow Lite for Microcontrollers interpreter, which runs our model
#include "tensorflow/lite/micro/micro_interpreter.h"
// The test framework
#include "tensorflow/lite/micro/testing/micro_test.h"
// The schema that defines the data structures used to make sense of the model data
#include "tensorflow/lite/schema/schema_generated.h"
The next part of the code is wrapped by the test framework, that is, by macros from the testing header included above, like this:
TF_LITE_MICRO_TESTS_BEGIN

TF_LITE_MICRO_TEST(LoadModelAndPerformInference) {
  ......
}

TF_LITE_MICRO_TESTS_END
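The argument passed to TF_LITE_MICRO_TEST, LoadModelAndPerformInference, is the name of the test. It is printed together with the test result so you can see whether the test passed or failed. Before looking at the actual code, let's just run it: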
make -f tensorflow/lite/micro/tools/make/Makefile test_hello_world_test
At the end of the output you can see that the test passes.
tensorflow/lite/micro/tools/make/downloads/flatbuffers already exists, skipping the download.
tensorflow/lite/micro/tools/make/downloads/pigweed already exists, skipping the download.
g++ -std=c++11 -fno-rtti -fno-exceptions -fno-threadsafe-statics -fno-unwind-tables -ffunction-sections -fdata-sections -fmessage-length=0 -DTF_LITE_STATIC_MEMORY -DTF_LITE_DISABLE_X86_NEON -O3 -Werror -Wsign-compare -Wdouble-promotion -Wshadow -Wunused-variable -Wmissing-field-initializers -Wunused-function -Wswitch -Wvla -Wall -Wextra -Wstrict-aliasing -Wno-unused-parameter -DLINUX -DTF_LITE_USE_CTIME -I. -Itensorflow/lite/micro/tools/make/downloads/gemmlowp -Itensorflow/lite/micro/tools/make/downloads/flatbuffers/include -Itensorflow/lite/micro/tools/make/downloads/ruy -Itensorflow/lite/micro/tools/make/downloads/kissfft -c tensorflow/lite/micro/examples/hello_world/hello_world_test.cc -o tensorflow/lite/micro/tools/make/gen/linux_x86_64_default/obj/tensorflow/lite/micro/examples/hello_world/hello_world_test.o
g++ -std=c++11 -fno-rtti -fno-exceptions -fno-threadsafe-statics -fno-unwind-tables -ffunction-sections -fdata-sections -fmessage-length=0 -DTF_LITE_STATIC_MEMORY -DTF_LITE_DISABLE_X86_NEON -O3 -Werror -Wsign-compare -Wdouble-promotion -Wshadow -Wunused-variable -Wmissing-field-initializers -Wunused-function -Wswitch -Wvla -Wall -Wextra -Wstrict-aliasing -Wno-unused-parameter -DLINUX -DTF_LITE_USE_CTIME -I. -Itensorflow/lite/micro/tools/make/downloads/gemmlowp -Itensorflow/lite/micro/tools/make/downloads/flatbuffers/include -Itensorflow/lite/micro/tools/make/downloads/ruy -Itensorflow/lite/micro/tools/make/downloads/kissfft -o tensorflow/lite/micro/tools/make/gen/linux_x86_64_default/bin/hello_world_test tensorflow/lite/micro/tools/make/gen/linux_x86_64_default/obj/tensorflow/lite/micro/examples/hello_world/hello_world_test.o tensorflow/lite/micro/tools/make/gen/linux_x86_64_default/obj/tensorflow/lite/micro/examples/hello_world/model.o tensorflow/lite/micro/tools/make/gen/linux_x86_64_default/lib/libtensorflow-microlite.a -Wl,--fatal-warnings -Wl,--gc-sections -lm
tensorflow/lite/micro/tools/make/gen/linux_x86_64_default/bin/hello_world_test '~~~ALL TESTS PASSED~~~' linux
Testing LoadModelAndPerformInference
1/1 tests passed
~~~ALL TESTS PASSED~~~
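To get a feel for what a test wrapper like this does, and where lines such as "Testing LoadModelAndPerformInference" and "1/1 tests passed" could come from, here is a toy sketch. It is not the TFLM implementation: TOY_EXPECT_EQ, RunToyTest and ToySummary are names I made up for illustration, and the failure banner is invented.

```cpp
#include <cstdio>

// Toy bookkeeping; all names in this block are hypothetical.
namespace toy {
int tests_passed = 0;
int tests_failed = 0;
bool current_ok = true;
}  // namespace toy

// Records a failure instead of aborting, so later checks still run.
#define TOY_EXPECT_EQ(expected, actual)                      \
  do {                                                       \
    if (!((expected) == (actual))) {                         \
      std::printf("%s == %s failed\n", #expected, #actual);  \
      toy::current_ok = false;                               \
    }                                                        \
  } while (false)

// Runs one named test body and counts it as passed or failed.
template <typename Body>
void RunToyTest(const char* name, Body body) {
  std::printf("Testing %s\n", name);
  toy::current_ok = true;
  body();
  if (toy::current_ok) {
    ++toy::tests_passed;
  } else {
    ++toy::tests_failed;
  }
}

// Prints a summary and returns true when no test failed.
bool ToySummary() {
  const int total = toy::tests_passed + toy::tests_failed;
  std::printf("%d/%d tests passed\n", toy::tests_passed, total);
  const bool all_passed = (toy::tests_failed == 0);
  std::puts(all_passed ? "~~~ALL TESTS PASSED~~~" : "~~~SOME TESTS FAILED~~~");
  return all_passed;
}
```

Calling RunToyTest("Name", [] { TOY_EXPECT_EQ(2, 1 + 1); }) followed by ToySummary() prints output shaped like the log above.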
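The code first sets up a logger, which works much like printf. An example of its use follows immediately; we can try changing the logic so that it prints unconditionally, or write our own imitation to experiment with.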
// Set up logging
tflite::MicroErrorReporter micro_error_reporter;

if (model->version() != TFLITE_SCHEMA_VERSION) {
  TF_LITE_REPORT_ERROR(&micro_error_reporter,
                       "Model provided is schema version %d not equal "
                       "to supported version %d.\n",
                       model->version(), TFLITE_SCHEMA_VERSION);
}
It would print roughly the following (in practice the versions match, so nothing is printed!):
Model provided is schema version X not equal to supported version X.
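Taking up the suggestion to write an imitation, here is a minimal printf-style reporter. It is only a sketch for experimenting on a desktop: ToyReporter is a hypothetical class, not the MicroErrorReporter from TFLM, and it uses std::string for convenience, which real microcontroller code would likely avoid.

```cpp
#include <cstdarg>
#include <cstdio>
#include <string>

// Hypothetical reporter that formats like printf and remembers the last line.
class ToyReporter {
 public:
  // Formats the message into a fixed buffer, writes it to stderr, and
  // returns the number of characters the full message needs.
  int Report(const char* format, ...) {
    char buffer[128];
    va_list args;
    va_start(args, format);
    const int written = std::vsnprintf(buffer, sizeof(buffer), format, args);
    va_end(args);
    last_message_ = buffer;
    std::fputs(buffer, stderr);
    return written;
  }

  const std::string& last_message() const { return last_message_; }

 private:
  std::string last_message_;
};
```

Messages longer than the 128-byte buffer are truncated rather than overflowing, which is the main reason to prefer vsnprintf over vsprintf here.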
Before this print there is a GetModel call, which reads the model from our array and produces a TF Lite Model object. The model itself was created by create_sine_model.ipynb in the example code.
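Next the code defines several objects. First it creates the main operations class (the resolver). Then comes tensor_arena, the working memory of TF Lite: it plays the role that malloc would play elsewhere, except that the interpreter carves every tensor out of this fixed array instead of using the heap. How big it should be is hard to know in advance; the usual approach is to start with a generous number and shrink it step by step until you reach a value that is stable yet reasonably economical. Finally everything is wired together into an interpreter, which is exactly what the name says.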
// This pulls in all the operation implementations we need
tflite::AllOpsResolver resolver;

constexpr int kTensorArenaSize = 2000;
uint8_t tensor_arena[kTensorArenaSize];

// Build an interpreter to run the model with
tflite::MicroInterpreter interpreter(model, resolver, tensor_arena,
                                     kTensorArenaSize, &micro_error_reporter);

// Allocate memory from the tensor_arena for the model's tensors
TF_LITE_MICRO_EXPECT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
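To make the idea of "allocating from the tensor_arena" concrete, here is a toy bump allocator over a fixed byte array. This is my own illustration of the general technique, not how TFLM's allocator is actually written; ToyArena and its methods are hypothetical names.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical bump allocator: hands out consecutive slices of a fixed buffer.
class ToyArena {
 public:
  ToyArena(uint8_t* buffer, size_t size)
      : buffer_(buffer), size_(size), used_(0) {}

  // Returns a block of `bytes` whose offset is a multiple of `alignment`,
  // or nullptr when the arena cannot hold it. Assumes the buffer itself
  // starts at a suitably aligned address.
  void* Allocate(size_t bytes, size_t alignment = 16) {
    const size_t start = (used_ + alignment - 1) / alignment * alignment;
    if (start > size_ || bytes > size_ - start) return nullptr;
    used_ = start + bytes;
    return buffer_ + start;
  }

  // High-water mark, useful when shrinking the arena to a tight size.
  size_t used_bytes() const { return used_; }

 private:
  uint8_t* buffer_;
  size_t size_;
  size_t used_;
};
```

With an allocator of this shape, an arena that is too small shows up as a failed allocation at startup rather than as a crash later, which is what makes the "start big, then shrink" approach workable.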
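Next we need to feed in some input. First we request the input tensor and check that what we got is not NULL (just like memory allocation, you have to actually get the space). Once that succeeds, the model is effectively loaded, because the model was already fixed when the interpreter was declared. The EQ checks that follow assert the shape and type of the input, so in practice the whole request is just the single statement interpreter.input(0).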
// Obtain a pointer to the model's input tensor
TfLiteTensor* input = interpreter.input(0);

// Make sure the input has the properties we expect
TF_LITE_MICRO_EXPECT_NE(nullptr, input);
// The property "dims" tells us the tensor's shape. It has one element for
// each dimension. Our input is a 2D tensor containing 1 element, so "dims"
// should have size 2.
TF_LITE_MICRO_EXPECT_EQ(2, input->dims->size);
// The value of each element gives the length of the corresponding tensor.
// We should expect two single element tensors (one is contained within the
// other).
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]);
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[1]);
// The input is an 8 bit integer value
TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt8, input->type);
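Now inference can begin. The three steps here are: get the input quantization parameters, quantize the floating-point input to an integer (for speed), and place the quantized input in the model's input tensor. Then the model is run.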
// Get the input quantization parameters
float input_scale = input->params.scale;
int input_zero_point = input->params.zero_point;

// Quantize the input from floating-point to integer
int8_t x_quantized = x / input_scale + input_zero_point;
// Place the quantized input in the model's input tensor
input->data.int8[0] = x_quantized;

// Run the model and check that it succeeds
TfLiteStatus invoke_status = interpreter.Invoke();
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, invoke_status);
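The quantization step can be pulled out into a small helper. The sketch below is a slightly more defensive variant than the one-liner in the test: it rounds to nearest and clamps to the int8 range, both of which are my additions. QuantizeToInt8 is a hypothetical name, and the scale and zero point in the usage note are made-up values, not the ones stored in the sine model.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Affine quantization of a real value to int8: q = round(x / scale) + zero_point.
// Rounding and clamping are additions; the test code truncates and does not clamp.
int8_t QuantizeToInt8(float x, float scale, int zero_point) {
  const float q = std::round(x / scale) + static_cast<float>(zero_point);
  const float clamped = std::min(127.0f, std::max(-128.0f, q));
  return static_cast<int8_t>(clamped);
}
```

For example, with an assumed scale of 0.5 and zero point of -10, QuantizeToInt8(1.0f, 0.5f, -10) yields -8, and any input far outside the representable range saturates at -128 or 127 instead of wrapping around.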
Then we read the output, assert that it has the expected properties, and convert it back to float.
// Obtain a pointer to the output tensor and make sure it has the
// properties we expect. It should be the same as the input tensor.
TfLiteTensor* output = interpreter.output(0);
TF_LITE_MICRO_EXPECT_EQ(2, output->dims->size);
TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[0]);
TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[1]);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt8, output->type);

// Get the output quantization parameters
float output_scale = output->params.scale;
int output_zero_point = output->params.zero_point;

// Obtain the quantized output from model's output tensor
int8_t y_pred_quantized = output->data.int8[0];
// Dequantize the output from integer to floating-point
float y_pred = (y_pred_quantized - output_zero_point) * output_scale;
Then it checks whether the error is within tolerance; the remaining tests all follow the same pattern.
float epsilon = 0.05f;
TF_LITE_MICRO_EXPECT_NEAR(y_true, y_pred, epsilon);

// Run inference on several more values and confirm the expected outputs
x = 1.f;
y_true = sin(x);
input->data.int8[0] = x / input_scale + input_zero_point;
interpreter.Invoke();
y_pred = (output->data.int8[0] - output_zero_point) * output_scale;
TF_LITE_MICRO_EXPECT_NEAR(y_true, y_pred, epsilon);
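Invoke shows up several times here: every time a new input is written, it is called again (feeding an input is what produces the output, so there is no separate call on the output side), and it is also what actually performs the inference.
If you change the input parameters of the model, or tighten the error tolerance beyond what the model itself can achieve, the test will fail.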
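One way to see why a stricter tolerance eventually fails: an int8 output can only represent values in steps of the output scale, so even a perfect model cannot get closer than that grid allows. The sketch below uses a hypothetical output scale of 0.01 and zero point of 0 (not the real parameters of the sine model), and Dequantize and Near are names I chose for illustration.

```cpp
#include <cmath>
#include <cstdint>

// Converts an int8 output back to a real value: (q - zero_point) * scale.
float Dequantize(int8_t q, float scale, int zero_point) {
  return static_cast<float>(q - zero_point) * scale;
}

// True when the two values differ by at most epsilon.
bool Near(float expected, float actual, float epsilon) {
  return std::fabs(expected - actual) <= epsilon;
}
```

With those assumed parameters, the closest representable value to sin(1) ≈ 0.8415 is 84 * 0.01 = 0.84. That is comfortably within an epsilon of 0.05, but it would already fail an epsilon of 0.001, before any model error is even counted.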
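Besides this code, there are many other folders containing code for different microcontrollers.
For example, the ESP code contains only these.
As the headers suggest, everything else starts in main_functions.cc. The file opens with a lot of familiar includes; you can probably guess what each one is for without my spelling it out.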
#include "tensorflow/lite/micro/examples/hello_world/main_functions.h"

#include "tensorflow/lite/micro/all_ops_resolver.h"
#include "tensorflow/lite/micro/examples/hello_world/constants.h"
#include "tensorflow/lite/micro/examples/hello_world/model.h"
#include "tensorflow/lite/micro/examples/hello_world/output_handler.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/system_setup.h"
#include "tensorflow/lite/schema/schema_generated.h"
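It then defines some globals. The names should all be familiar, since they match the test example. The only newcomer is inference_count, which tracks how many inferences the program has run.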
// Globals, used for compatibility with Arduino-style sketches.
namespace {
tflite::ErrorReporter* error_reporter = nullptr;
const tflite::Model* model = nullptr;
tflite::MicroInterpreter* interpreter = nullptr;
TfLiteTensor* input = nullptr;
TfLiteTensor* output = nullptr;
int inference_count = 0;

constexpr int kTensorArenaSize = 2000;
uint8_t tensor_arena[kTensorArenaSize];
}  // namespace
Anyone who has played with Arduino will recognize the pattern: there is a setup function and a loop function. setup runs exactly once, and loop runs over and over.
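The calling convention behind setup and loop can be sketched in a few lines. The stubs below only count calls, and Run takes an iteration limit so the sketch terminates; real firmware would loop forever. Everything in the toy_sketch namespace is hypothetical scaffolding, not code from the example.

```cpp
// Hypothetical stand-ins for a sketch's setup() and loop().
namespace toy_sketch {
int setup_calls = 0;
int loop_calls = 0;

void setup() { ++setup_calls; }
void loop() { ++loop_calls; }

// Runs setup once, then loop `iterations` times.
// On a real board the loop would be `while (true) loop();`.
void Run(int iterations) {
  setup();
  for (int i = 0; i < iterations; ++i) {
    loop();
  }
}
}  // namespace toy_sketch
```

This is why anything that must survive between iterations, such as the interpreter and the tensor pointers, has to live in globals or statics rather than in locals of setup.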
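setup is the familiar code again. One thing may look puzzling: does it really access the output during setup? Not quite. It only stores a pointer to the output tensor in output; the tensor's memory was already reserved by AllocateTensors, and no output value is read here.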
// The name of this function is important for Arduino compatibility.
void setup() {
  tflite::InitializeTarget();

  // Set up logging. Google style is to avoid globals or statics because of
  // lifetime uncertainty, but since this has a trivial destructor it's okay.
  // NOLINTNEXTLINE(runtime-global-variables)
  static tflite::MicroErrorReporter micro_error_reporter;
  error_reporter = &micro_error_reporter;

  // Map the model into a usable data structure. This doesn't involve any
  // copying or parsing, it's a very lightweight operation.
  model = tflite::GetModel(g_model);
  if (model->version() != TFLITE_SCHEMA_VERSION) {
    TF_LITE_REPORT_ERROR(error_reporter,
                         "Model provided is schema version %d not equal "
                         "to supported version %d.",
                         model->version(), TFLITE_SCHEMA_VERSION);
    return;
  }

  // This pulls in all the operation implementations we need.
  // NOLINTNEXTLINE(runtime-global-variables)
  static tflite::AllOpsResolver resolver;

  // Build an interpreter to run the model with.
  static tflite::MicroInterpreter static_interpreter(
      model, resolver, tensor_arena, kTensorArenaSize, error_reporter);
  interpreter = &static_interpreter;

  // Allocate memory from the tensor_arena for the model's tensors.
  TfLiteStatus allocate_status = interpreter->AllocateTensors();
  if (allocate_status != kTfLiteOk) {
    TF_LITE_REPORT_ERROR(error_reporter, "AllocateTensors() failed");
    return;
  }

  // Obtain pointers to the model's input and output tensors.
  input = interpreter->input(0);
  output = interpreter->output(0);

  // Keep track of how many inferences we have performed.
  inference_count = 0;
}
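The loop code is also fairly familiar. kXrange and kInferencesPerCycle are defined among the constants: kXrange is 2π, and kInferencesPerCycle is the maximum number of inferences per cycle, which bounds inference_count. Once a full cycle has been completed, the counter goes back to 0 and inference continues (the sine is periodic, after all). Keep in mind that this is still feeding x in to get y out; the code is much the same as in the test, apart from the extra HandleOutput.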
// The name of this function is important for Arduino compatibility.
void loop() {
  // Calculate an x value to feed into the model. We compare the current
  // inference_count to the number of inferences per cycle to determine
  // our position within the range of possible x values the model was
  // trained on, and use this to calculate a value.
  float position = static_cast<float>(inference_count) /
                   static_cast<float>(kInferencesPerCycle);
  float x = position * kXrange;

  // Quantize the input from floating-point to integer
  int8_t x_quantized = x / input->params.scale + input->params.zero_point;
  // Place the quantized input in the model's input tensor
  input->data.int8[0] = x_quantized;

  // Run inference, and report any error
  TfLiteStatus invoke_status = interpreter->Invoke();
  if (invoke_status != kTfLiteOk) {
    TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed on x: %f\n",
                         static_cast<double>(x));
    return;
  }

  // Obtain the quantized output from model's output tensor
  int8_t y_quantized = output->data.int8[0];
  // Dequantize the output from integer to floating-point
  float y = (y_quantized - output->params.zero_point) * output->params.scale;

  // Output the results. A custom HandleOutput function can be implemented
  // for each supported hardware target.
  HandleOutput(error_reporter, x, y);

  // Increment the inference_counter, and reset it if we have reached
  // the total number per cycle
  inference_count += 1;
  if (inference_count >= kInferencesPerCycle) inference_count = 0;
}
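The way loop sweeps x across one period can be isolated from the model entirely. In this sketch kToyXrange is 2π as described above, while kToyInferencesPerCycle = 20 is an assumed value chosen for illustration (the real constant depends on the target). The function names are mine.

```cpp
// Assumed constants for illustration; the real ones live in the example's constants.
constexpr float kToyXrange = 2.0f * 3.14159265358979323846f;
constexpr int kToyInferencesPerCycle = 20;

// Maps the running inference counter to an x value in [0, 2*pi).
float XForInference(int inference_count) {
  const float position = static_cast<float>(inference_count) /
                         static_cast<float>(kToyInferencesPerCycle);
  return position * kToyXrange;
}

// Advances the counter and wraps it back to 0 at the end of a cycle.
int NextInferenceCount(int inference_count) {
  inference_count += 1;
  if (inference_count >= kToyInferencesPerCycle) inference_count = 0;
  return inference_count;
}
```

With 20 steps per cycle, step 10 lands on x ≈ π, and step 19 is followed by step 0 again, so x never reaches 2π itself.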
HandleOutput is what ties the result to the hardware. Here is how the Arduino example implements it:
// Animates a dot across the screen to represent the current x and y values
void HandleOutput(tflite::ErrorReporter* error_reporter, float x_value,
                  float y_value) {
  // Do this only once
  if (!initialized) {
    // Set the LED pin to output
    pinMode(led, OUTPUT);
    initialized = true;
  }

  // Calculate the brightness of the LED such that y=-1 is fully off
  // and y=1 is fully on. The LED's brightness can range from 0-255.
  int brightness = (int)(127.5f * (y_value + 1));

  // Set the brightness of the LED. If the specified pin does not support PWM,
  // this will result in the LED being on when y > 127, off otherwise.
  analogWrite(led, brightness);

  // Log the current brightness value for display in the Arduino plotter
  TF_LITE_REPORT_ERROR(error_reporter, "%d\n", brightness);
}
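The brightness mapping is easy to check on a desktop. The sketch below reuses the same formula and adds a clamp to 0-255, which is my addition: a model's prediction may land slightly outside [-1, 1], and the original handler does not guard against that. BrightnessFromY is a hypothetical name.

```cpp
#include <algorithm>

// Maps y in [-1, 1] to an LED brightness in [0, 255].
// The clamp is an addition for predictions slightly outside [-1, 1].
int BrightnessFromY(float y_value) {
  const int brightness = static_cast<int>(127.5f * (y_value + 1.0f));
  return std::min(255, std::max(0, brightness));
}
```

So y = -1 gives 0, y = 0 gives 127 (127.5 truncated), and y = 1 gives 255.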
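OK, so now we have a breathing LED powered by machine learning (very fancy). We have not yet covered how to run it on real hardware, so for now we can test it like this: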
make -f tensorflow/lite/micro/tools/make/Makefile hello_world
Then just run the binary it produces. Note that the values are printed in power-of-two form, so convert them by hand if you want decimals.
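If you would rather not convert by hand, the arithmetic is a single library call. This sketch assumes each printed value has the shape mantissa*2^exponent, which is my reading of "power-of-two form"; adjust it if your output looks different. FromPowerOfTwo is a hypothetical helper name.

```cpp
#include <cmath>

// Converts a value written as mantissa * 2^exponent into a plain double.
double FromPowerOfTwo(double mantissa, int exponent) {
  return std::ldexp(mantissa, exponent);
}
```

For instance, a value printed with mantissa 1.5 and exponent 2 is 6.0, and mantissa 1.0 with exponent -1 is 0.5.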