Multi-Style Image Transfer on Android

Posted by pqpo on 2019-04-22

Image style transfer uses machine learning to convert an image from one visual style to another. This article starts from the history of style transfer networks and walks through the algorithms step by step, then shows how to build a single-model, multi-style training network, and finally covers the engineering work needed to run the trained model on Android.


What is image style transfer? A machine learning model learns the style of a particular image and then applies that style to any other image. The effect looks like this:

(Figure: example style transfer results)

The best-known example of style transfer on mobile is Prisma:

(Figure: Prisma)

A Brief History of Style Transfer Networks

《A Neural Algorithm of Artistic Style》: In the first generation of style transfer networks, stylization itself is a training process: the style image and the content image are fed in, and the stylized output is produced by optimization. Training drives down a content loss and a style loss, yielding an image that preserves the content while taking on the target style. The drawback is obvious: it is slow. But it laid the foundation for everything that followed.

《Perceptual Losses for Real-Time Style Transfer and Super-Resolution》: Next came what is known as fast style transfer, which adds a transform network on top of the previous approach. By training the parameters of the transform network, a single model can stylize arbitrary images; since inference is a single forward pass, it is dramatically faster than the previous generation. It likewise uses a pretrained VGG network to extract features and trains to reduce the content and style losses. The limitation is that each network produces only one style; supporting multiple styles requires training multiple models.

《A Learned Representation For Artistic Style》: Building on the previous generation, this work supports multiple styles in one fast transfer network. The structure is essentially the same; the key difference is that a Conditional Instance Normalization layer replaces the normalization used before. It can be thought of as several sets of normalization parameters merged into one layer, with the set selected according to the input style. The advantages are that a single model supports multiple styles, is fast, and is small; the downside is that it only supports the styles it was trained on.

《Meta Networks for Neural Style Transfer》: Finally, there is a network that supports arbitrary styles on arbitrary images. It goes a step further by introducing a MetaNet: part of the transform network's parameters are generated by the MetaNet and part are learned during training, so it can stylize any content image with any style image. The downsides are a larger model and a more complex network, which makes it less suitable for style transfer on mobile devices.

After comparing the options, the third approach was chosen as the best fit for mobile.

How It Works

The model structure for single-style and multi-style transfer is largely the same, as shown below:

(Figure: transform network plus loss network architecture)

It consists of a transform network and a loss network. The diagram uses VGG-16 as the loss network, but other image-classification networks such as VGG-19 work too. During training, only the parameters of the transform network are updated; VGG-16 is a pretrained image-classification model used to extract features and compute the losses.

Here is the training procedure the diagram describes in detail. The feature layers depend on which loss network is chosen; VGG-16 is used as the example here:

  1. The content image passes through the transform network to produce the stylized image Y
  2. The stylized image is fed through the loss network to extract the feature layer (relu3_3)
  3. The original content image is fed through the loss network to extract the feature layer (relu3_3)
  4. The feature layers from steps 2 and 3 are used to compute the content loss content_loss
  5. The style image is fed through the loss network to extract the feature layers (relu1_2, relu2_2, relu3_3, relu4_3)
  6. The model from step 2 (stylized image through the loss network) is used to extract the feature layers (relu1_2, relu2_2, relu3_3, relu4_3)
  7. The feature layers from steps 5 and 6 are used to compute the style loss style_loss
  8. Training minimizes the combined content and style losses, content_loss + style_loss (summarized as a formula below)
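
Written out compactly (my notation, not taken from the original papers): let T be the transform network, phi_l the loss-network feature map at layer l, G(.) the Gram matrix, x the content image and s the style image. Training then minimizes

    \mathcal{L} = \lambda_c \, \| \phi_{relu3\_3}(T(x)) - \phi_{relu3\_3}(x) \|_2^2 + \lambda_s \sum_l \| G(\phi_l(T(x))) - G(\phi_l(s)) \|_F^2

where \lambda_c and \lambda_s correspond to content_weight and style_weight in the implementation below.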

Implementation

First, define the transform network. Its layers and parameters are as follows:

(Figure: transform network layers and parameters)

The transform network is implemented like this:

def net(x, style_control=None, reuse=False, alpha=1.0):
    with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
        # Downsampling convolutions
        x = conv_layer(x, int(alpha * 32), 9, 1, style_control=style_control, name='conv1')
        x = conv_layer(x, int(alpha * 64), 3, 2, style_control=style_control, name='conv2')
        x = conv_layer(x, int(alpha * 128), 3, 2, style_control=style_control, name='conv3')
        # Five residual blocks
        x = residual_block(x, int(alpha * 128), 3, style_control=style_control, name='res1')
        x = residual_block(x, int(alpha * 128), 3, style_control=style_control, name='res2')
        x = residual_block(x, int(alpha * 128), 3, style_control=style_control, name='res3')
        x = residual_block(x, int(alpha * 128), 3, style_control=style_control, name='res4')
        x = residual_block(x, int(alpha * 128), 3, style_control=style_control, name='res5')
        # Upsampling with transposed convolutions
        x = conv_tranpose_layer(x, int(alpha * 64), 3, 2, style_control=style_control, name='up_conv1')
        x = pooling(x)
        x = conv_tranpose_layer(x, int(alpha * 32), 3, 2, style_control=style_control, name='up_conv2')
        x = pooling(x)
        # Final 9x9 convolution down to 3 channels, scaled to pixel values by a sigmoid
        x = conv_layer(x, 3, 9, 1, relu=False, style_control=style_control, name='output')
        preds = tf.nn.sigmoid(x) * 255.
    return preds

The final activation is a sigmoid, whose output lies in the range 0-1, so it is multiplied by 255 to convert it to pixel values.

Each layer is implemented as follows:

def conv_layer(net, num_filters, filter_size, strides, style_control=None, relu=True, name='conv'):
    with tf.variable_scope(name):
        b,w,h,c = net.get_shape().as_list()
        weights_shape = [filter_size, filter_size, c, num_filters]
        weights_init = tf.get_variable(name, shape=weights_shape, initializer=tf.truncated_normal_initializer(stddev=.01))
        strides_shape = [1, strides, strides, 1]
        p = int((filter_size - 1) / 2)
        if strides == 1:
            net = tf.pad(net, [[0, 0], [p, p], [p, p], [0, 0]], "REFLECT")
            net = tf.nn.conv2d(net, weights_init, strides_shape, padding="VALID")
        else:
            net = tf.nn.conv2d(net, weights_init, strides_shape, padding="SAME")
        net = conditional_instance_norm(net, style_control=style_control)
        if relu:
            net = tf.nn.relu(net)

    return net

def conv_tranpose_layer(net, num_filters, filter_size, strides, style_control=None, name='conv_t'):
    with tf.variable_scope(name):
        b, w, h, c = net.get_shape().as_list()
        weights_shape = [filter_size, filter_size, num_filters, c]
        weights_init = tf.get_variable(name, shape=weights_shape, initializer=tf.truncated_normal_initializer(stddev=.01))
        batch_size, rows, cols, in_channels = [i.value for i in net.get_shape()]
        new_rows, new_cols = int(rows * strides), int(cols * strides)
        new_shape = [batch_size, new_rows, new_cols, num_filters]
        tf_shape = tf.stack(new_shape)
        strides_shape = [1,strides,strides,1]

        p = int((filter_size - 1) / 2)
        if strides == 1:
            net = tf.pad(net, [[0, 0], [p, p], [p, p], [0, 0]], "REFLECT")
            net = tf.nn.conv2d_transpose(net, weights_init, tf_shape, strides_shape, padding="VALID")
        else:
            net = tf.nn.conv2d_transpose(net, weights_init, tf_shape, strides_shape, padding="SAME")
        net = conditional_instance_norm(net, style_control=style_control)

    return tf.nn.relu(net)

def residual_block(net, num_filters=128, filter_size=3, style_control=None, name='res'):
    with tf.variable_scope(name+'_a'):
        tmp = conv_layer(net, num_filters, filter_size, 1, style_control=style_control)
    with tf.variable_scope(name+'_b'):
        output = net + conv_layer(tmp, num_filters, filter_size, 1, style_control=style_control, relu=False)
    return output

Every layer ends with the normalization function conditional_instance_norm, and this is exactly where the multi-style network differs from the single-style one. The single-style network uses the following normalization:

def instance_norm(net, train=True, name='in'):
    with tf.variable_scope(name):
        batch, rows, cols, channels = [i.value for i in net.get_shape()]
        var_shape = [channels]
        mu, sigma_sq = tf.nn.moments(net, [1,2], keep_dims=True)
        shift = tf.get_variable('shift', shape=var_shape, initializer=tf.constant_initializer(0.))
        scale = tf.get_variable('scale', shape=var_shape, initializer=tf.constant_initializer(1.))
        epsilon = 1e-3
        normalized = (net-mu)/(sigma_sq + epsilon)**(.5)
    return scale * normalized + shift

tf.nn.moments computes the mean and variance of the input over the spatial dimensions; the input is then normalized by subtracting the mean and dividing by the square root of the variance (plus a small epsilon). Finally, the normalized result is multiplied by scale and shifted by shift: scale * normalized + shift, where scale and shift are trainable parameters (some papers call them gamma and beta).
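
In formula form, matching the code above (mu and sigma^2 are the per-channel mean and variance over the spatial dimensions, epsilon = 1e-3):

    IN(x) = scale \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + shift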

The difference in multi-style transfer is that there is one (scale, shift) pair per style (the gamma and beta in the figure), and the pair matching the selected style is picked at run time. This also means that the more styles a model is trained with, the larger it becomes, as shown below:

(Figure: conditional instance normalization with per-style gamma/beta parameters)

conditional_instance_norm is implemented as follows:

def conditional_instance_norm(net, style_control=None, name='cond_in'):
    with tf.variable_scope(name):
        batch, rows, cols, channels = [i.value for i in net.get_shape()]
        mu, sigma_sq = tf.nn.moments(net, [1,2], keep_dims=True)
        var_shape = [channels]
        # One (shift, scale) pair per style.
        shift = []
        scale = []
        for i in range(style_control.shape[0]):
            with tf.variable_scope('{0}'.format(i) + '_style'):
                shift.append(tf.get_variable('shift', shape=var_shape, initializer=tf.constant_initializer(0.)))
                scale.append(tf.get_variable('scale', shape=var_shape, initializer=tf.constant_initializer(1.)))
        shift = tf.convert_to_tensor(shift)
        scale = tf.convert_to_tensor(scale)
        epsilon = 1e-3
        normalized = (net-mu)/(sigma_sq + epsilon)**(.5)
        # Pick the (scale, shift) pairs for the non-zero entries of style_control
        # and combine them weighted by their values.
        idx = tf.where(tf.not_equal(style_control, tf.constant(0, dtype=tf.float32)))
        style_select = tf.gather(style_control, idx)
        scale_select = tf.gather_nd(scale, idx)
        shift_select = tf.gather_nd(shift, idx)
        style_scale = tf.reduce_sum(scale_select * style_select, axis=0)
        style_shift = tf.reduce_sum(shift_select * style_select, axis=0)
        style_sum = tf.reduce_sum(style_control)
        style_scale = style_scale / style_sum
        style_shift = style_shift / style_sum
        output = style_scale * normalized + style_shift
    return output

The style_control input is a one-hot vector indicating which style is being used. For example, with 5 styles in total, style_control is [1, 0, 0, 0, 0] when training the first style.
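
Since conditional_instance_norm weights each selected (scale, shift) pair by its value in style_control and divides by the sum of the vector, style_control does not strictly have to be one-hot at inference time. A minimal sketch with illustrative values (the blend is my extrapolation from the code above, not something demonstrated in the article):

import tensorflow as tf

# One-hot vector used while training the third of five styles.
style_control = tf.constant([0., 0., 1., 0., 0.], dtype=tf.float32)

# Hypothetical 50/50 blend of style 0 and style 2 at inference time,
# relying on the weighted sum inside conditional_instance_norm.
blended_control = tf.constant([0.5, 0., 0.5, 0., 0.], dtype=tf.float32)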

Next, define the loss network. VGG-19 is chosen here, with average pooling used instead of max pooling. Define the network and load the pretrained weights:

def net(input_image, data):
    layers = (
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
        'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3',
        'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
        'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3',
        'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
        'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3',
        'relu5_3', 'conv5_4', 'relu5_4'
    )
    weights = data['layers'][0]
    net = {}
    current = input_image
    net['input'] = input_image
    for i, name in enumerate(layers):
        kind = name[:4]
        if kind == 'conv':
            kernels, bias = weights[i][0][0][0][0]
            kernels = np.transpose(kernels, (1, 0, 2, 3))
            bias = bias.reshape(-1)
            current = _conv_layer(current, kernels, bias)
        elif kind == 'relu':
            current = tf.nn.relu(current)
        elif kind == 'pool':
            current = _pool_layer(current)
        net[name] = current
    return [net['relu1_1'], net['relu2_1'], net['relu3_1'], net['relu4_1'], net['relu5_1'], net['relu4_2']]
def _conv_layer(input, weights, bias):
    conv = tf.nn.conv2d(input, tf.constant(weights), strides=(1, 1, 1, 1), padding='SAME')
    return tf.nn.bias_add(conv, bias)
def _pool_layer(input):
#    return tf.nn.max_pool(input, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1),
    return tf.nn.avg_pool(input, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1),
            padding='SAME')

Note that the loss network does not participate in training, so its parameters must stay fixed, which is why the weights are wrapped in tf.constant rather than variables. The data argument is the pretrained model, which has to be downloaded in advance and loaded:

weights = scipy.io.loadmat('net/vgg19.mat')

With both networks in place, define the content input, the style input, and the transfer output:

# content_input
content_input = tf.placeholder(tf.float32, shape=batch_shape, name='content_input')
# style_input
style_img = get_img(style_target[style_index])
style_input = tf.constant(style_img[np.newaxis, ...], dtype=tf.float32)
# output
style_control = [1 if i == style_index else 0 for i in range(style_num)]
transform_output = transform_net.net(content_input, alpha=alpha,
                                     style_control=tf.constant(style_control, dtype=tf.float32))

Then subtract the channel means from the content, style, and transfer images before feeding them into the loss network (VGG-19):

vgg_mean = tf.constant(np.array([123.68, 116.779, 103.939]).reshape((1, 1, 1, 3)), dtype='float32')
content_feats = vgg.net(content_input - vgg_mean, weights)
style_feats = vgg.net(style_input - vgg_mean, weights)
transform_feats = vgg.net(transform_output - vgg_mean, weights)

Compute the content loss from the loss network's feature outputs:

c_loss = content_weight * euclidean_loss(transform_feats[-1], content_feats[-1])
def euclidean_loss(input_, target_):
    b,w,h,c = input_.get_shape().as_list()
    return 2 * tf.nn.l2_loss(input_- target_) / b/w/h/c

The inputs are the relu4_2 feature layers of the content image and the transfer image from the VGG-19 network. content_weight is the content-loss weight, set to 1.5 here and tunable.

Compute the style loss from the loss network's feature outputs:

s_loss = style_weight * sum([style_loss(transform_feats[i], style_feats[i]) for i in range(5)])
def style_loss(input_, style_):
    b,h,w,c = input_.get_shape().as_list()
    input_gram = gram_matrix(input_)
    style_gram = gram_matrix(style_)
    return 2 * tf.nn.l2_loss(input_gram - style_gram)/b/c/c
def gram_matrix(net):
    b,h,w,c = net.get_shape().as_list()
    feats = tf.reshape(net, (b, h*w, c))
    feats_T = tf.transpose(feats, perm=[0,2,1])
    grams = tf.matmul(feats_T, feats) / h/w/c
    return grams

The style loss takes the feature layers relu1_1, relu2_1, relu3_1, relu4_1 and relu5_1 of the style image and the transfer image from the VGG network, computes the Gram matrix of each layer, and sums the per-layer losses to obtain the total style loss.
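
In equation form (mirroring gram_matrix and style_loss above): a feature map of shape (H, W, C) is reshaped to a matrix F of size HW x C, and then

    G = \frac{F^\top F}{H W C}, \qquad \mathcal{L}_{style} = \sum_l \frac{2 \, \| G_l^{transfer} - G_l^{style} \|_F^2}{b \, C_l^2}

where b is the batch size and C_l the channel count at layer l; the sum is then scaled by style_weight.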

To make the output smoother, a total variation regularizer is added as well; all the losses are then summed to give the final loss:

def total_variation(preds):
    # total variation denoising
    b, w, h, c = preds.get_shape().as_list()
    y_tv = tf.nn.l2_loss(preds[:, 1:, :, :] - preds[:, :w - 1, :, :])
    x_tv = tf.nn.l2_loss(preds[:, :, 1:, :] - preds[:, :, :h - 1, :])
    tv_loss = 2 * (x_tv + y_tv) / b / w / h / c
    return tv_loss

tv_loss = total_variation_weight * total_variation(transform_output)
loss = c_loss + s_loss + tv_loss

Then define the optimizer to minimize the total loss:

t_vars = tf.trainable_variables()
var_list = [var for var in t_vars if '{0}'.format(style_index) + '_style' in var.name]
print(var_list)
if style_index == 0:
    train_opt = tf.train.AdamOptimizer(learning_rate, momentum).minimize(loss)
else:
    train_opt = tf.train.AdamOptimizer(learning_rate, momentum).minimize(loss, var_list=var_list)

There is a special case here: when training the first style, all parameters are optimized; for subsequent styles, the convolution kernels are frozen and only the parameters of the conditional_instance_norm layers are trained.

Finally, train batch by batch to obtain the model:

with tf.Session() as session:
    writer_train = tf.summary.FileWriter(tensorboard_dir, session=session)
    writer_train.add_graph(session.graph)
    session.run(tf.global_variables_initializer())
    saver = tf.train.Saver(var_list=t_vars)
    checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
    if checkpoint_file:
        print('restore checkpoint: {}'.format(checkpoint_file))
        saver.restore(session, checkpoint_file)
    for epoch in range(epochs):
        num_examples = len(content_targets)
        iterations = 0
        while iterations * batch_size < num_examples:
            start_time = time.time()
            curr = iterations * batch_size
            step = curr + batch_size
            # Assemble one batch of content images.
            content_batch = np.zeros(batch_shape, dtype=np.float32)
            for j, img_p in enumerate(content_targets[curr:step]):
                content_batch[j] = get_img(img_p, batch_shape[1:]).astype(np.float32)
            iterations += 1
            assert content_batch.shape[0] == batch_size
            feed_dict = {
                content_input: content_batch
            }
            global_step += 1
            session.run(train_opt, feed_dict=feed_dict)
            if iterations % 10 == 0:
                summary = session.run(summary_merge, feed_dict=feed_dict)
                writer_train.add_summary(summary, global_step=global_step)
                writer_train.flush()
            end_time = time.time()
            delta_time = end_time - start_time
            print("%s, batch time: %s" % (global_step, delta_time))
            if iterations > 0 and iterations % 100 == 0:
                save_model(saver, session)
    save_model(saver, session)
print('train style end: {}'.format(style_index))

Training is computationally heavy, so the GPU build of TensorFlow (tensorflow-gpu) is recommended. I trained on a machine from Google Cloud Platform, which gives a $300 credit to first-time users.

To make porting to Android easier, rebuild the forward pass and save it in pb format:

style_index = 0
style_target = glob.glob(style_dir)
style_num = len(style_target)
style_control_array = [1 if i == style_index else 0 for i in range(style_num)]
print('style control: {}'.format(style_control_array))
img_data = get_img(transfer_image_file, (transfer_image_size, transfer_image_size, 3))
im_input_4d = img_data[np.newaxis, ...]
im_b, im_h, im_w, im_c = np.shape(im_input_4d)
img = tf.placeholder(tf.float32, [1, transfer_image_size, transfer_image_size, 3], name='input')
style_control = tf.placeholder(tf.float32, [style_num], name='style_num')
with tf.Session() as sess:
    preds = transform_net.net(img, style_control=style_control, alpha=alpha)
    print([node.name for node in sess.graph.as_graph_def().node])
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
    if checkpoint_file:
        print('restore checkpoint: {}'.format(checkpoint_file))
        saver.restore(sess, checkpoint_file)
    output = tf.gather(preds, 0, name="out_img")
    out = sess.run(output, feed_dict={img: im_input_4d, style_control: style_control_array})
    scm.imsave(os.path.join(os.path.abspath(os.path.dirname(__file__)), 'transfer.jpg'), out)
    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['out_img'])
    with tf.gfile.GFile('./model.pb', mode='wb') as f:
        f.write(constant_graph.SerializeToString())

This loads the ckpt checkpoint trained earlier and finally freezes the model into pb format. The inputs are the image's RGB pixel array (input) and the one-hot style selector (style_num); the output is the RGB array of the stylized image (out_img).
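
For a quick sanity check on a desktop machine, the frozen graph can be loaded back and run using the node names listed above. This is a minimal sketch; transfer_image_size and style_num here are assumptions that must match the values used when the graph was frozen (the article mentions 16 styles below):

import numpy as np
import tensorflow as tf

transfer_image_size = 512  # assumption: must match the size used when freezing the graph
style_num = 16             # assumption: number of styles baked into the model

graph_def = tf.GraphDef()
with tf.gfile.GFile('./model.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name='')

with tf.Session(graph=graph) as sess:
    # Dummy input image and a one-hot vector selecting the first style.
    img = np.zeros((1, transfer_image_size, transfer_image_size, 3), dtype=np.float32)
    style = np.eye(style_num, dtype=np.float32)[0]
    out = sess.run('out_img:0', feed_dict={'input:0': img, 'style_num:0': style})
    print(out.shape)  # (transfer_image_size, transfer_image_size, 3)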

Porting the Model

The model produced by the steps above is roughly 8 MB; it can be shrunk through model quantization. What is model quantization? Models are normally stored as float32 values; with little impact on accuracy, simpler numeric types can be used instead, which both shrinks the model and speeds up computation. 8-bit quantization is common and can in theory reduce model size by a factor of 4. See: Reducing Core ML 2 Model Size by 4X Using Quantization in iOS 12

The quantization tooling in recent TensorFlow releases seems to only produce the tflite format, but quantization failed for this model with an unsupported-op error, so I fell back to TensorFlow Mobile instead of TensorFlow Lite and used the quantization script from the tools directory of an older TensorFlow release:

python /tensorflow/tools/quantization/quantize_graph.py --output_node_names=out_img --output=XXX --mode=eightbit --input=XXX

After quantization the model shrinks to 2 MB and supports 16 different styles.

Copy the model file into the assets directory of the Android project and add the TensorFlow Mobile dependency:

implementation 'org.tensorflow:tensorflow-android:1.13.1'

The Android-side implementation:

public Bitmap stylizeImage(Bitmap bitmap, int model) {
    Log.w("launched", "stylized in tensor module");
    TensorFlowInferenceInterface inferenceInterface = new TensorFlowInferenceInterface(context.getAssets(), MODEL_FILE);
    bitmap = Bitmap.createScaledBitmap(bitmap, desiredSize, desiredSize, false);
    bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());

    long time = System.currentTimeMillis();
    getBitmapPixels(bitmap, floatValues);
    Log.w(TAG, "getBitmapPixels time:" + (System.currentTimeMillis() - time));

    for (int i = 0; i < NUM_STYLES; ++i) {
        styleVals[i] = 0f;
    }
    styleVals[model] = 1f;

    time = System.currentTimeMillis();
    // Copy the input data into TensorFlow.
    Log.w("tensor", "Width: " + bitmap.getWidth() + ", Height: " + bitmap.getHeight());
    inferenceInterface.feed(INPUT_NODE, floatValues, 1, bitmap.getWidth(), bitmap.getHeight(), 3);
    inferenceInterface.feed(STYLE_NODE, styleVals, NUM_STYLES);

    inferenceInterface.run(new String[]{OUTPUT_NODE}, false);
    inferenceInterface.fetch(OUTPUT_NODE, floatValues);

    Log.w(TAG, "run model time:" + (System.currentTimeMillis() - time));
    time = System.currentTimeMillis();

    mergePixels(floatValues, intValues);

    bitmap.setPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
    Log.w(TAG, "return bitmap time:" + + (System.currentTimeMillis() - time));
    return bitmap;
}

First, TensorFlowInferenceInterface is initialized with the model path. The Bitmap's pixels are read and converted into an RGB float array by getBitmapPixels, and the style_num input is initialized as a one-hot vector based on the index of the selected style. The model then produces the stylized output; since it is an RGB float array, mergePixels converts it back into an int pixel array. For performance, both array conversions (getBitmapPixels and mergePixels) are implemented natively, which is about an order of magnitude faster. The implementation is below:

extern "C" JNIEXPORT void
JNICALL
Java_me_pqpo_awesomeimage_ImageStyle_mergePixels(JNIEnv *env, jobject obj, jfloatArray pix_, jintArray mergedPix_) {
    jfloat *pix = env->GetFloatArrayElements(pix_, NULL);
    jint *mergedPix = env->GetIntArrayElements(mergedPix_, NULL);
    int len = env->GetArrayLength(mergedPix_);
    for (int i = 0; i < len; ++i) {
        mergedPix[i] =
            0xFF000000
            | (((int) (pix[i * 3])) << 16)
            | (((int) (pix[i * 3 + 1])) << 8)
            | ((int) (pix[i * 3 + 2]));
     }
    return;
}

extern "C" JNIEXPORT void
JNICALL
Java_me_pqpo_awesomeimage_ImageStyle_getBitmapPixels(JNIEnv *env, jobject obj, jobject srcBitmap, jfloatArray pix_) {
    jfloat *pix = env->GetFloatArrayElements(pix_, NULL);

    void *srcPixels = 0;
    AndroidBitmapInfo srcBitmapInfo;
    try {
        AndroidBitmap_getInfo(env, srcBitmap, &srcBitmapInfo);
        AndroidBitmap_lockPixels(env, srcBitmap, &srcPixels);

        uint32_t srcHeight = srcBitmapInfo.height;
        uint32_t srcWidth = srcBitmapInfo.width;

        for (int i = 0; i < srcHeight * srcWidth; ++i) {
            int val = static_cast<int*>(srcPixels)[i];
            pix[i * 3] = static_cast<jfloat>(((val) & 0xFF));
            pix[i * 3 + 1] = static_cast<jfloat>(((val >> 8) & 0xFF));
            pix[i * 3 + 2] = static_cast<jfloat>(((val >> 16) & 0xFF));
        }

        AndroidBitmap_unlockPixels(env, srcBitmap);
        return;
    } catch (...) {
        AndroidBitmap_unlockPixels(env, srcBitmap);
        jclass je = env->FindClass("java/lang/Exception");
        env -> ThrowNew(je, "unknown");
        return;
    }
    return;
}

Usage:

ImageStyle imageStyle = new ImageStyle(MainActivity.this);
Bitmap bitmap = BitmapFactory.decodeResource(MainActivity.this.getResources(), R.mipmap.tubingen);
Bitmap styleBitmap = imageStyle.stylizeImage(bitmap, 0);

With the code above, stylizing a 1024*1024 image on Android takes about 18 s. Next, the network is pruned to shrink the model further and speed up the conversion; in the end, stylizing a 1024*1024 image takes about 5 s.

Pruning the Network

The main idea is to narrow the convolution layers and remove residual blocks. Below is the pruned transform network, adapted from: github.com/fritzlabs/f…

def net_small(x, style_control=None, reuse=False, alpha=1.0):
    with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
        x = conv_layer(x, int(alpha * 32), 9, 1, style_control=style_control, name='conv1')
        x = conv_layer(x, int(alpha * 32), 3, 2, style_control=style_control, name='conv2')
        x = conv_layer(x, int(alpha * 32), 3, 2, style_control=style_control, name='conv3')
        x = residual_block(x, int(alpha * 32), 3, style_control=style_control, name='res1')
        x = residual_block(x, int(alpha * 32), 3, style_control=style_control, name='res2')
        x = residual_block(x, int(alpha * 32), 3, style_control=style_control, name='res3')
        x = conv_tranpose_layer(x, int(alpha * 32), 3, 2, style_control=style_control, name='up_conv1')
        x = pooling(x)
        x = conv_tranpose_layer(x, int(alpha * 32), 3, 2, style_control=style_control, name='up_conv2')
        x = pooling(x)
        x = conv_layer(x, 3, 9, 1, relu=False, style_control=style_control, name='output')
        preds = tf.nn.sigmoid(x) * 255.
    return preds

The number of filters in every convolution layer is reduced to 32, two residual blocks are removed, and the hyperparameter alpha is introduced to shrink the width even further. The original author reports that alpha can go as low as 0.3, at which point the quantized model shrinks to 17 KB; in my experiments the quality loss at that setting is noticeable, so experiment and pick the parameters whose output quality you can accept.
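
As an example, exporting the pruned network only requires swapping the graph-construction call in the freezing script above. This is a sketch; it assumes net_small lives in the same transform_net module as net, and alpha=0.5 is just an illustrative value:

# Build the pruned transform network at half of the already reduced width.
preds = transform_net.net_small(img, style_control=style_control, alpha=0.5)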

