TensorFlow是一个开源的机器学习框架,上手简单,且在大量商业场景使用,可靠性高。
机器学习常用的语言主要是python、c,官方也提供了Java版本sdk,示例也很丰富。对于Rust语言,虽然官方文档没有写,其实官方也提供了rust版本的绑定,项目名称就是rust。
在Cargo.toml中配置版本,这里的0.18.0对应TensorFlow 2.8.0
tensorflow = "0.18.0"
这个依赖中的-sys模块会根据实际情况进行处理,如果本地没有Tensorflow的库文件,就会自动下载。如果没有预编译好的库文件,或者用户指定要从源码编译,那就会在本地编译。
我这里使用的是frozen graph,在rust代码中新建一个Graph进行加载即可。
let mut graph = Graph::new(); let model_file = MODEL_DIR .get_file("mobilenet/mobilenet_v2_1.4_224_frozen.pb") .unwrap(); let label_file = MODEL_DIR.get_file("mobilenet/label.txt").unwrap(); graph .import_graph_def(model_file.contents(), &ImportGraphDefOptions::new()) .unwrap();
为了最后打包方便,这里使用了include_dir包,会把模型和最终产物打包到一起。
这里的模型是一个mobilenet的图片分类模型,输入为224像素。图片的处理使用image包。
let img = image::open(photo.get_store_path())?; let resized = image::imageops::thumbnail(&img, 224, 224); let mut flattened: Vec<f32> = Vec::new(); for rgb in resized.pixels() { flattened.push(rgb[0] as f32 / 255.); flattened.push(rgb[1] as f32 / 255.); flattened.push(rgb[2] as f32 / 255.); } { let input = Tensor::new(&[1, 224, 224, 3]) .with_values(&flattened) .unwrap(); let session = Session::new(&SessionOptions::new(), &graph)?; let mut args = SessionRunArgs::new(); args.add_feed( &graph.operation_by_name_required("input").unwrap(), 0, &input, ); let prediction = args.request_fetch( &graph.operation_by_name_required("MobilenetV2/Predictions/Softmax")?, 0, ); session.run(&mut args).unwrap(); let prediction_res: Tensor<f32> = args.fetch(prediction)?; let mut i = 0; let mut json_vec: Vec<f32> = Vec::new(); while i < prediction_res.len() { json_vec.push(prediction_res[i]); i += 1; } let label = labels.get(imax(&json_vec).unwrap()).unwrap(); return Ok(label.to_string()); }
整体使用还是很简单的,只是部分文档缺失,需要研究下。如果有其他语言的Tensorflow使用经验会好很多。
目前Tensorflow Hub做的很好了,模型可以直接从上面下载,Savemodel也可以直接加载,也可以转为frozen后使用。
本文版权归作者所有,欢迎转载,但未经作者同意必须保留此段声明,且在文章页面明显位置给出原文连接,否则保留追究法律责任的权利。
转载自夜明的孤行灯
本文链接地址:https://www.huangyunkun.com/2022/08/15/use-tensorflow-with-rust/
Be First to Comment