With the release of the TensorFlow 2.0 alpha, TensorFlow.js has been updated to its first official release, 1.0, and the TensorFlow website now carries TensorFlow.js documentation, which shows that TensorFlow.js is no longer an experiment. As a browser engine developer, I am naturally very interested in it.
Over the years JavaScript has conquered one territory after another: Node.js on the server, a dominant role in mobile front-end development, desktop applications such as the popular Visual Studio Code, and now it is pushing into artificial intelligence. One cannot help marveling that a scripting language designed in a rush, and criticized ever since, has proved so resilient.
Enough small talk; let's see what it is like to train a model in the browser.
I previously wrote the series 《一步步提高手写数字的识别率(1) (2) (3)》. Handwritten digit recognition is an excellent starter project, so I will use it here to show how to train a model in the browser. Rather than beginning with the simplest linear regression model, we will go straight to a convolutional neural network.
As in Python, training a model in the browser with TensorFlow.js involves four main steps:
Load the data.
Define the model architecture.
Train the model and monitor its performance during training.
Evaluate the trained model.
Loading the data
If you have any machine learning background you will know the MNIST dataset: 28×28 grayscale images of handwritten digits, with 55,000 training samples, 10,000 test samples, and a further 5,000 cross-validation samples. TensorFlow's Python API ships a wrapper class that loads MNIST directly; in TensorFlow.js you have to write the loading code yourself:
const IMAGE_SIZE = 784;
const NUM_CLASSES = 10;
const NUM_DATASET_ELEMENTS = 65000;

const TRAIN_TEST_RATIO = 5 / 6;

const NUM_TRAIN_ELEMENTS = Math.floor(TRAIN_TEST_RATIO * NUM_DATASET_ELEMENTS);
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;

const MNIST_IMAGES_SPRITE_PATH = 'mnist_images.png';
const MNIST_LABELS_PATH = 'mnist_labels_uint8';

/**
 * A class that fetches the sprited MNIST dataset and returns shuffled batches.
 *
 * NOTE: This will get much easier. For now, we do data fetching and
 * manipulation manually.
 */
export class MnistData {
  constructor() {
    this.shuffledTrainIndex = 0;
    this.shuffledTestIndex = 0;
  }

  async load() {
    // Make a request for the MNIST sprited image.
    const img = new Image();
    const canvas = document.createElement('canvas');
    const ctx = canvas.getContext('2d');
    const imgRequest = new Promise((resolve, reject) => {
      img.crossOrigin = '';
      img.onload = () => {
        img.width = img.naturalWidth;
        img.height = img.naturalHeight;

        const datasetBytesBuffer =
            new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);

        const chunkSize = 5000;
        canvas.width = img.width;
        canvas.height = chunkSize;

        for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
          const datasetBytesView = new Float32Array(
              datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
              IMAGE_SIZE * chunkSize);
          ctx.drawImage(
              img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,
              chunkSize);

          const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);

          for (let j = 0; j < imageData.data.length / 4; j++) {
            // All channels hold an equal value since the image is grayscale, so
            // just read the red channel.
            datasetBytesView[j] = imageData.data[j * 4] / 255;
          }
        }
        this.datasetImages = new Float32Array(datasetBytesBuffer);

        resolve();
      };
      img.src = MNIST_IMAGES_SPRITE_PATH;
    });

    const labelsRequest = fetch(MNIST_LABELS_PATH);
    const [imgResponse, labelsResponse] =
        await Promise.all([imgRequest, labelsRequest]);

    this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

    // Create shuffled indices into the train/test set for when we select a
    // random dataset element for training / validation.
    this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
    this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);

    // Slice the images and labels into train and test sets.
    this.trainImages =
        this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.trainLabels =
        this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
    this.testLabels =
        this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
  }

  nextTrainBatch(batchSize) {
    return this.nextBatch(
        batchSize, [this.trainImages, this.trainLabels], () => {
          this.shuffledTrainIndex =
              (this.shuffledTrainIndex + 1) % this.trainIndices.length;
          return this.trainIndices[this.shuffledTrainIndex];
        });
  }

  nextTestBatch(batchSize) {
    return this.nextBatch(batchSize, [this.testImages, this.testLabels], () => {
      this.shuffledTestIndex =
          (this.shuffledTestIndex + 1) % this.testIndices.length;
      return this.testIndices[this.shuffledTestIndex];
    });
  }

  nextBatch(batchSize, data, index) {
    const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
    const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);

    for (let i = 0; i < batchSize; i++) {
      const idx = index();

      const image =
          data[0].slice(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
      batchImagesArray.set(image, i * IMAGE_SIZE);

      const label =
          data[1].slice(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES);
      batchLabelsArray.set(label, i * NUM_CLASSES);
    }

    const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
    const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);

    return {xs, labels};
  }
}
The code loads mnist_images.png, a sprite image stitched together from every image in the MNIST dataset (it is large, about 10 MB), plus mnist_labels_uint8, a binary file containing the labels for all of the MNIST samples.
Note that this is only one way to load MNIST; you could instead use a version of the dataset with one image file per handwritten digit and load the many image files in batches, as in the sketch below.
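For illustration, here is a minimal sketch of that per-image approach: it converts one image into a normalized input tensor with tf.browser.fromPixels. The file name digit_0.png is a hypothetical placeholder, not part of the dataset above.

async function loadDigitImage(url) {
  // Wait for the image to load; 'url' points at a single 28x28
  // grayscale digit image ('digit_0.png' is hypothetical).
  const img = new Image();
  img.crossOrigin = '';
  await new Promise((resolve, reject) => {
    img.onload = resolve;
    img.onerror = reject;
    img.src = url;
  });
  // Read one channel (the image is grayscale), scale to [0, 1]
  // and flatten to the [1, 784] shape used elsewhere in this post.
  return tf.tidy(() =>
      tf.browser.fromPixels(img, 1).toFloat().div(255).reshape([1, 784]));
}

// Usage: const x = await loadDigitImage('digit_0.png');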
Back to the sprite approach: the code above implements an MnistData class with two public methods (a short usage sketch follows the list):
nextTrainBatch(batchSize): returns a random batch of images and their labels from the training set.
nextTestBatch(batchSize): returns a batch of images and their labels from the test set.
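To make the returned shapes concrete, here is a minimal usage sketch (run it inside an async function; the batch size of 16 is arbitrary):

const data = new MnistData();
await data.load();

const batch = data.nextTrainBatch(16);
console.log(batch.xs.shape);      // [16, 784]: flattened 28x28 images
console.log(batch.labels.shape);  // [16, 10]: one-hot encoded labels

// Tensors are not garbage collected; free them when done.
batch.xs.dispose();
batch.labels.dispose();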
To check that this code works, we can write a snippet that displays the loaded data:
async function showExamples(data) {
  // Create a container in the visor
  const surface =
      tfvis.visor().surface({ name: 'Input Data Examples', tab: 'Input Data'});

  // Get the examples
  const examples = data.nextTestBatch(20);
  const numExamples = examples.xs.shape[0];

  // Create a canvas element to render each example
  for (let i = 0; i < numExamples; i++) {
    const imageTensor = tf.tidy(() => {
      // Reshape the image to 28x28 px
      return examples.xs
          .slice([i, 0], [1, examples.xs.shape[1]])
          .reshape([28, 28, 1]);
    });

    const canvas = document.createElement('canvas');
    canvas.width = 28;
    canvas.height = 28;
    canvas.style = 'margin: 4px;';
    await tf.browser.toPixels(imageTensor, canvas);
    surface.drawArea.appendChild(canvas);

    imageTensor.dispose();
  }
}

async function run() {
  const data = new MnistData();
  await data.load();
  await showExamples(data);
}

document.addEventListener('DOMContentLoaded', run);
Defining the model architecture
For background on convolutional neural networks, see 《一步步提高手写数字的识别率(3)》. The network defined here is:
CONV -> MAXPOOLING -> CONV -> MAXPOOLING -> FC -> SOFTMAX
Each convolutional layer uses the ReLU activation function. The code is as follows:
function getModel() {
  const model = tf.sequential();

  const IMAGE_WIDTH = 28;
  const IMAGE_HEIGHT = 28;
  const IMAGE_CHANNELS = 1;

  // In the first layer of our convolutional neural network we have
  // to specify the input shape. Then we specify some parameters for
  // the convolution operation that takes place in this layer.
  model.add(tf.layers.conv2d({
    inputShape: [IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_CHANNELS],
    kernelSize: 5,
    filters: 8,
    strides: 1,
    activation: 'relu',
    kernelInitializer: 'varianceScaling'
  }));

  // The MaxPooling layer acts as a sort of downsampling using max values
  // in a region instead of averaging.
  model.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));

  // Repeat another conv2d + maxPooling stack.
  // Note that we have more filters in the convolution.
  model.add(tf.layers.conv2d({
    kernelSize: 5,
    filters: 16,
    strides: 1,
    activation: 'relu',
    kernelInitializer: 'varianceScaling'
  }));
  model.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));

  // Now we flatten the output from the 2D filters into a 1D vector to prepare
  // it for input into our last layer. This is common practice when feeding
  // higher dimensional data to a final classification output layer.
  model.add(tf.layers.flatten());

  // Our last layer is a dense layer which has 10 output units, one for each
  // output class (i.e. 0, 1, 2, 3, 4, 5, 6, 7, 8, 9).
  const NUM_OUTPUT_CLASSES = 10;
  model.add(tf.layers.dense({
    units: NUM_OUTPUT_CLASSES,
    kernelInitializer: 'varianceScaling',
    activation: 'softmax'
  }));

  // Choose an optimizer, loss function and accuracy metric,
  // then compile and return the model
  const optimizer = tf.train.adam();
  model.compile({
    optimizer: optimizer,
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy'],
  });

  return model;
}
If you have written TensorFlow Python code before, the code above should be easy to follow.
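To double-check the layer output shapes and parameter counts, the layers API provides a Keras-style summary; a quick sketch (summary() prints to the browser console):

// Print each layer, its output shape and its parameter count.
const model = getModel();
model.summary();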
Training the model and monitoring its performance
When training in the browser, the images can likewise be fed in batches, and you can specify the batch size and the number of epochs:
async function train(model, data) {
  const metrics = ['loss', 'val_loss', 'acc', 'val_acc'];
  const container = {
    name: 'Model Training', styles: { height: '1000px' }
  };
  const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);

  const BATCH_SIZE = 512;
  const TRAIN_DATA_SIZE = 5500;
  const TEST_DATA_SIZE = 1000;

  const [trainXs, trainYs] = tf.tidy(() => {
    const d = data.nextTrainBatch(TRAIN_DATA_SIZE);
    return [
      d.xs.reshape([TRAIN_DATA_SIZE, 28, 28, 1]),
      d.labels
    ];
  });

  const [testXs, testYs] = tf.tidy(() => {
    const d = data.nextTestBatch(TEST_DATA_SIZE);
    return [
      d.xs.reshape([TEST_DATA_SIZE, 28, 28, 1]),
      d.labels
    ];
  });

  return model.fit(trainXs, trainYs, {
    batchSize: BATCH_SIZE,
    validationData: [testXs, testYs],
    epochs: 10,
    shuffle: true,
    callbacks: fitCallbacks
  });
}
Compared with the Python code, fit takes an extra callbacks parameter. Bear in mind that training takes a long time and must not block the browser's main thread, so the code is asynchronous almost everywhere; the callbacks give the main thread a chance to update. Here they come from the tfvis library, which visualizes the training process (much like TensorBoard), but renders it right on the web page.
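tfvis is optional: model.fit also accepts a plain callback object. A minimal sketch that merely logs metrics after each epoch, assuming the trainXs/trainYs/testXs/testYs tensors from the train() function above:

const history = await model.fit(trainXs, trainYs, {
  batchSize: 512,
  validationData: [testXs, testYs],
  epochs: 10,
  shuffle: true,
  callbacks: {
    // Called once per epoch with the metrics gathered so far.
    onEpochEnd: async (epoch, logs) => {
      console.log(`epoch ${epoch}: loss=${logs.loss.toFixed(4)}, ` +
                  `val_acc=${logs.val_acc.toFixed(4)}`);
    }
  }
});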
Evaluating the trained model
For evaluation we feed in the test set; the code again looks much like its Python counterpart:
function doPrediction(model, data, testDataSize = 500) {
  const IMAGE_WIDTH = 28;
  const IMAGE_HEIGHT = 28;
  const testData = data.nextTestBatch(testDataSize);
  const testxs = testData.xs.reshape(
      [testDataSize, IMAGE_WIDTH, IMAGE_HEIGHT, 1]);
  const labels = testData.labels.argMax(-1);
  const preds = model.predict(testxs).argMax(-1);

  testxs.dispose();
  return [preds, labels];
}
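doPrediction returns two tensors of class indices (predictions and ground truth), and the caller is responsible for disposing of them. If you just want a single overall accuracy number, tfjs-vis has a metrics helper for that; a hedged sketch:

async function showOverallAccuracy(model, data) {
  const [preds, labels] = doPrediction(model, data);
  // tfvis.metrics.accuracy compares two tensors of class indices
  // and resolves to a number between 0 and 1.
  const accuracy = await tfvis.metrics.accuracy(labels, preds);
  console.log(`overall accuracy: ${accuracy}`);
  preds.dispose();
  labels.dispose();
}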
If we would like a more intuitive picture of the per-class accuracy and of the misclassifications, we can again turn to the tfvis library:
// Human-readable names for the ten digit classes, used below.
const classNames = [
  'Zero', 'One', 'Two', 'Three', 'Four',
  'Five', 'Six', 'Seven', 'Eight', 'Nine'];

async function showAccuracy(model, data) {
  const [preds, labels] = doPrediction(model, data);
  const classAccuracy = await tfvis.metrics.perClassAccuracy(labels, preds);
  const container = {name: 'Accuracy', tab: 'Evaluation'};
  tfvis.show.perClassAccuracy(container, classAccuracy, classNames);

  labels.dispose();
}

async function showConfusion(model, data) {
  const [preds, labels] = doPrediction(model, data);
  const confusionMatrix = await tfvis.metrics.confusionMatrix(labels, preds);
  const container = {name: 'Confusion Matrix', tab: 'Evaluation'};
  tfvis.render.confusionMatrix(
      container, {values: confusionMatrix}, classNames);

  labels.dispose();
}
The evaluation results, the per-class accuracy table and the confusion matrix, are then displayed in the tfvis visor panel.
About TensorFlow.js
TensorFlow.js uses WebGL to speed up training. If the browser does not support WebGL nothing breaks; TensorFlow.js simply falls back to the CPU path, which is of course much slower.
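You can inspect which backend was chosen, or force the CPU path to compare speeds; a small sketch, assuming TensorFlow.js 1.x:

// Report which backend TensorFlow.js selected ('webgl' or 'cpu').
console.log(tf.getBackend());

// Force the CPU path, for example to measure the WebGL speedup.
tf.setBackend('cpu');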
Even though WebGL puts the GPU to work, training a large deep learning model in the browser is still unrealistic. In that case we can train the model on a server, convert it into a format TensorFlow.js understands, then load it in the browser and run inference there. I will cover that topic in a follow-up article.
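As a preview, loading a converted model takes a single call; a minimal sketch (the model URL is a hypothetical placeholder, and the code assumes an async context):

// 'https://example.com/model.json' is a hypothetical placeholder URL
// pointing at a model converted to the TensorFlow.js layers format.
const loadedModel = await tf.loadLayersModel('https://example.com/model.json');
const prediction = loadedModel.predict(tf.zeros([1, 28, 28, 1]));
prediction.print();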
The examples above come with complete code; click the "Read the original" link to jump to the sample code on my GitHub. You can also open http://ilego.club/ai/index.html directly in your browser and try machine learning in the browser for yourself.
References:
TensorFlow official website
TensorFlow.js — Handwritten digit recognition with CNNs