基于Tensorflow Hub进行迁移学习完成人脸BMI指数预测

Tensorflow Hub

TF Hub是一个通过复用Tensorflow models来完成迁移学习的模型库,目前有自然语言、图像和视频三大类,具体可参考下面链接(部分页面需要翻墙):

https://www.tensorflow.org/hub

模型结构

首先对人物图片进行人脸识别,然后利用tfhub中inception v3模型提取feature vector,最后使用SVR模型完成基于人脸的BMI指数预测。

参考论文链接:https://arxiv.org/abs/1703.03156

人脸识别

这里介绍golang版本解决方案,python的资源丰富,例如face_recognition等。go-face提供了纯go版本的人脸识别功能,不需要安装opencv等复杂的环境依赖,相关的依赖也可以通过apt-get方式快速安装,值得注意的是其需要人脸识别的模型文件shape_predictor和dlib_face_recognition,具体介绍可以参考其github主页。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
//主要代码
const dataDir = "testdata"
func main() {
// Init the recognizer.
rec, err := face.NewRecognizer(dataDir)
if err != nil {
log.Fatalf("Can't init face recognizer: %v", err)
}
// Free the resources when you're finished.
defer rec.Close()
// Test
testImage := filepath.Join(dataDir, "face.jpg")
// Recognize faces on that image.
faces, err := rec.RecognizeFile(testImagePristin)
if err != nil {
log.Fatalf("Can't recognize: %v", err)
}
}

tfhub

这里使用google发布的inception_v3模型,由于网络原因,如果在代码中无法下载可以选择手动下载并指定路径,下载时url为:

https://storage.googleapis.com/tfhub-modules/google/imagenet/inception_v3/feature_vector/1.tar.gz

模型下载完成并指定路径后可直接在hub中使用并获得输入图片的feature vector。

1
2
3
4
height, width = hub.get_expected_image_size(module_spec)
resized_input_tensor = tf.placeholder(tf.float32, [None, height, width, 3], name="input_tensor")
m = hub.Module(module_spec)
bottleneck_tensor = m(resized_input_tensor)

SVR模型

对于一般的回归问题,给定训练样本,模型希望学习得到一个f(x)与y尽可能的接近,只有f(x)和y完全相同时,损失才为零,而支持向量回归可以容忍f(x)与y之前最多有ε的偏差,当且仅当f(x)与y的差别绝对值大于ε时,才计算损失。此时相当于以f(x)为中心,构建一个宽度为2ε的间隔带,若训练样本落入此间隔带,则认为是被预测正确的。如下图所示:
图片
参考链接:

https://blog.csdn.net/zb123455445/article/details/78354489

Tensorflow中实现SVR模型首先设置和初始化W, b和ε,通过W*x+b获得final_tensor,最后计算loss,公式为:

1
2
with tf.name_scope('loss'):
loss = tf.reduce_mean(tf.maximum(0., tf.subtract(tf.abs(tf.subtract(final_tensor, ground_truth_input)), epsilon)))

模型代码

训练代码参考了tensorflow提供的鲜花分类的retrain.py代码,主要对loss函数,数据处理和模型导出做了修改。

参考连接:https://github.com/tensorflow/hub/raw/master/examples/image_retraining/retrain.py