本文针对场景识别,其中采用的模型是mobilenet_v2,如有需要,可替换成其他tf-slim支持模型; 我的数据集结构如下: 操作均在slim文件夹下 (2)在 (3)最后运行如下命令: 完成后可以看到如下终端输出: (1)将 (2)修改 通过以下命令训练: 训练时的终端输出如下: 通过tensorboard查看训练情况: 同样是在 (1)只用如下这个单步骤,会有问题,对于冻好的.pb文件会缺少第一个输入层,导致使用该.pb文件时报错说shape问题; (2)使用如下两个步骤可以正常使用,第一步先导出前向传播图,第二步再利用 预测代码(自己写的,较为简单): 终端输出: 预测正确right!
0. 前提
1. 准备工作
(1)tf-slim源码
git clone 官方代码仓:https://github.com/tensorflow/models
(2)准备完成自己的场景数据集
2. 将数据集转成tfrecord格式
(1)在datasets文件夹下,将download_and_convert_flowers.py
文件,复制一份重新命名为:download_and_convert_place6.py
# Open this file and make the 4 marked changes below ("CHANGE 1" .. "CHANGE 4");
# everything else is the same as download_and_convert_flowers.py.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os
import random
import sys

import tensorflow as tf

from datasets import dataset_utils

# The URL where the Flowers data can be downloaded.
# Only used by _clean_up_temporary_files(), which run() no longer calls.
_DATA_URL = 'https://download.tensorflow.org/example_images/flower_photos.tgz'

# The number of images in the validation set.
# NOTE: SPLITS_TO_SIZES in datasets/place6.py must match the split written
# here, i.e. {'train': <total> - _NUM_VALIDATION, 'validation': _NUM_VALIDATION}.
_NUM_VALIDATION = 350

# Seed for repeatability.
_RANDOM_SEED = 0

# The number of shards per dataset split.
_NUM_SHARDS = 5


class ImageReader(object):
    """Helper class that provides TensorFlow image coding utilities."""

    def __init__(self):
        # Initializes function that decodes RGB JPEG data.
        self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
        self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)

    def read_image_dims(self, sess, image_data):
        """Decode `image_data` and return its (height, width)."""
        image = self.decode_jpeg(sess, image_data)
        return image.shape[0], image.shape[1]

    def decode_jpeg(self, sess, image_data):
        """Decode encoded bytes to an (H, W, 3) array using `sess`."""
        image = sess.run(self._decode_jpeg,
                         feed_dict={self._decode_jpeg_data: image_data})
        assert len(image.shape) == 3
        assert image.shape[2] == 3
        return image


def _get_filenames_and_classes(dataset_dir):
    """Returns a list of filenames and inferred class names.

    Args:
      dataset_dir: A directory containing a `place_photos` folder, which holds
        one subdirectory per class name. Each subdirectory should contain PNG
        or JPG encoded images.

    Returns:
      A sorted list of image file paths (each joined onto `dataset_dir`), and
      the sorted list of subdirectory names, representing class names.
    """
    # CHANGE 1: read our own dataset from <dataset_dir>/place_photos.
    place_root = os.path.join(dataset_dir, 'place_photos')
    directories = []
    class_names = []
    for filename in os.listdir(place_root):
        path = os.path.join(place_root, filename)
        if os.path.isdir(path):
            directories.append(path)
            class_names.append(filename)

    photo_filenames = []
    for directory in directories:
        for filename in os.listdir(directory):
            path = os.path.join(directory, filename)
            photo_filenames.append(path)

    # os.listdir() returns entries in arbitrary order, so sort before the
    # seeded shuffle in run(); otherwise the train/validation split is not
    # actually repeatable across machines or filesystems.
    return sorted(photo_filenames), sorted(class_names)


def _get_dataset_filename(dataset_dir, split_name, shard_id):
    """Return the path of one output tfrecord shard."""
    # CHANGE 2: name the output files place6_*.
    output_filename = 'place6_%s_%05d-of-%05d.tfrecord' % (
        split_name, shard_id, _NUM_SHARDS)
    return os.path.join(dataset_dir, output_filename)


def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
    """Converts the given filenames to a TFRecord dataset.

    Args:
      split_name: The name of the dataset, either 'train' or 'validation'.
      filenames: A list of absolute paths to png or jpg images.
      class_names_to_ids: A dictionary from class names (strings) to ids
        (integers).
      dataset_dir: The directory where the converted datasets are stored.
    """
    assert split_name in ['train', 'validation']

    num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS)))

    with tf.Graph().as_default():
        image_reader = ImageReader()

        with tf.Session('') as sess:

            for shard_id in range(_NUM_SHARDS):
                output_filename = _get_dataset_filename(
                    dataset_dir, split_name, shard_id)

                with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
                    start_ndx = shard_id * num_per_shard
                    end_ndx = min((shard_id + 1) * num_per_shard, len(filenames))
                    for i in range(start_ndx, end_ndx):
                        # '\r' returns to the start of the line so the progress
                        # counter overwrites itself instead of printing a literal 'r'.
                        sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
                            i + 1, len(filenames), shard_id))
                        sys.stdout.flush()

                        # Read the filename; `with` closes the handle for every image.
                        with tf.gfile.GFile(filenames[i], 'rb') as image_file:
                            image_data = image_file.read()
                        height, width = image_reader.read_image_dims(sess, image_data)

                        # The class is the name of the image's parent directory.
                        class_name = os.path.basename(os.path.dirname(filenames[i]))
                        class_id = class_names_to_ids[class_name]

                        example = dataset_utils.image_to_tfexample(
                            image_data, b'jpg', height, width, class_id)
                        tfrecord_writer.write(example.SerializeToString())

    sys.stdout.write('\n')
    sys.stdout.flush()


def _clean_up_temporary_files(dataset_dir):
    """Removes temporary files used to create the dataset.

    Not called by run() in this version (see CHANGE 4 there).

    Args:
      dataset_dir: The directory where the temporary files are stored.
    """
    filename = _DATA_URL.split('/')[-1]
    filepath = os.path.join(dataset_dir, filename)
    tf.gfile.Remove(filepath)

    tmp_dir = os.path.join(dataset_dir, 'flower_photos')
    tf.gfile.DeleteRecursively(tmp_dir)


def _dataset_exists(dataset_dir):
    """Return True if every train/validation shard already exists."""
    for split_name in ['train', 'validation']:
        for shard_id in range(_NUM_SHARDS):
            output_filename = _get_dataset_filename(
                dataset_dir, split_name, shard_id)
            if not tf.gfile.Exists(output_filename):
                return False
    return True


def run(dataset_dir):
    """Runs the download and conversion operation.

    Args:
      dataset_dir: The dataset directory where the dataset is stored.
    """
    if not tf.gfile.Exists(dataset_dir):
        tf.gfile.MakeDirs(dataset_dir)

    if _dataset_exists(dataset_dir):
        print('Dataset files already exist. Exiting without re-creating them.')
        return

    # CHANGE 3: the images are already local, so skip the download.
    # dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir)
    photo_filenames, class_names = _get_filenames_and_classes(dataset_dir)
    class_names_to_ids = dict(zip(class_names, range(len(class_names))))

    # Divide into train and test:
    random.seed(_RANDOM_SEED)
    random.shuffle(photo_filenames)
    training_filenames = photo_filenames[_NUM_VALIDATION:]
    validation_filenames = photo_filenames[:_NUM_VALIDATION]

    # First, convert the training and validation sets.
    _convert_dataset('train', training_filenames, class_names_to_ids,
                     dataset_dir)
    _convert_dataset('validation', validation_filenames, class_names_to_ids,
                     dataset_dir)

    # Finally, write the labels file:
    labels_to_class_names = dict(zip(range(len(class_names)), class_names))
    dataset_utils.write_label_file(labels_to_class_names, dataset_dir)

    # CHANGE 4: do not clean up, so the original photos are not deleted.
    # _clean_up_temporary_files(dataset_dir)
    print('\nFinished converting the place6 dataset!')
slim
文件夹下打开download_and_convert_data.py
文件,添加如下两处代码:from datasets import download_and_convert_place6
elif FLAGS.dataset_name == 'place6': download_and_convert_place6.run(FLAGS.dataset_dir)
python download_and_convert_data.py --dataset_name=place6 --dataset_dir=dataset的位置
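数据集被分为 16848 张(训练集)和 350 张(验证集)。注意:下一步 place6.py 中的 SPLITS_TO_SIZES 必须与这里实际写入 tfrecord 的数量一致。按本文 _NUM_VALIDATION = 350 转换时,应设为 {'train': 16848, 'validation': 350};下文使用的 16198/1000 对应 _NUM_VALIDATION = 1000 的转换结果(两者总数均为 17198)。
3. 模型训练过程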
slim/datasets/flowers.py
复制一份,并且命名为place6.py
,将其中的3处内容修改如下:
可以知道,实际训练时的训练集设置为16198张图像,验证集设置为1000张图像;_FILE_PATTERN = 'place6_%s_*.tfrecord' SPLITS_TO_SIZES = {'train': 16198, 'validation': 1000} _NUM_CLASSES = 7 _ITEMS_TO_DESCRIPTIONS = { 'image': 'A color image of varying size.', 'label': 'A single integer between 0 and 6', }
slim/datasets/dataset_factory.py
内容,一共修改2处:from datasets import place6
datasets_map = { 'cifar10': cifar10, 'flowers': flowers, 'imagenet': imagenet, 'mnist': mnist, 'visualwakewords': visualwakewords, 'place6':place6, }
python3 train_image_classifier.py --train_dir=/*/research/slim/mobilenet_v2_place6/training --dataset_dir=/*/slim/intel_placedata/place6 --dataset_name=place6 --dataset_split_name=train --model_name=mobilenet_v2 --max_number_of_steps=40000 --learning_rate=0.001 --save_interval_secs=60 --save_summaries_secs=60 --log_every_n_steps=10 --optimizer=adam
~/*/slim/mobilenet_v2_place6$ tensorboard --logdir=training
4. 评价(evaluation)训练结果
slim
目录下,输入以下命令:python3 eval_image_classifier.py --checkpoint_path=/*/slim/mobilenet_v2_place6/training --eval_dir=/*/slim/mobilenet_v2_place6/eval --dataset_name=place6 --dataset_split_name=validation --dataset_dir=/*/slim/intel_placedata/place6 --model_name=mobilenet_v2
5. 最终冻成pb文件(使用tensorflow模块功能导出)
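两种方式(注意:下面命令中的 training1 目录和 model.ckpt-50000 需要替换成你实际的训练输出目录和确实存在的 checkpoint;按第 3 步的命令,训练目录是 training,且最后一个 checkpoint 最多为 model.ckpt-40000):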
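报错信息为:(1,224,224,3) 的 shape 与 (32,224,224,3) 的 shape 不一致。原因是训练时保存的 graph.pbtxt 中,网络输入来自训练数据管道,batch 大小固定为 32(train_image_classifier.py 默认的 batch_size),图中没有可以喂入单张图片的输入占位符;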
python3 -m tensorflow.python.tools.freeze_graph --input_graph /*/slim/mobilenet_v2_place6/training1/graph.pbtxt --input_checkpoint /*/slim/mobilenet_v2_place6/training1/model.ckpt-50000 --input_binary false --output_graph /*/slim/mobilenet_v2_place6/model/mobilenet_v2place6_frozen.pb --output_node_names MobilenetV2/Predictions/Reshape_1
前向传播图
+ckpt文件的参数导入到前向传播图中
得到最终的.pb模型,这个就可以使用:python export_inference_graph.py --alsologtostderr --dataset_dir=/*/slim/intel_placedata/place6 --dataset_name=place6 --model_name=mobilenet_v2 --image_size=224 --output_file=/*/slim/mobilenet_v2_place6/model/mobilenet_v2_inf.pb
python3 -m tensorflow.python.tools.freeze_graph --input_graph /*/slim/mobilenet_v2_place6/model/mobilenet_v2_inf.pb --input_checkpoint /*/slim/mobilenet_v2_place6/training1/model.ckpt-50000 --input_binary True --output_graph /*/slim/mobilenet_v2_place6/model/mobilenet_v2place61_frozen.pb --output_node_names MobilenetV2/Predictions/Reshape_1
6. 利用pb文件随机预测一张图
import tensorflow as tf
import numpy as np
import cv2
from datasets import dataset_utils
from IPython import display
import pylab
import PIL
# `import PIL` alone does not guarantee the PIL.Image submodule is loaded.
from PIL import Image

image_dir = '/*/slim/seg_pred/52.jpg'
dataset_dir = '/*/slim/intel_placedata/place6'
model_dir = '/*/slim/mobilenet_v2_place6/model/mobilenet_v2place61_frozen.pb'

display.display(display.Image(image_dir))

# Load the image and force 3-channel RGB, so grayscale/RGBA files do not break
# the (1, 224, 224, 3) reshape below. Then map pixels from [0, 255] to exactly
# [-1, 1] (x / 127.5 - 1), matching the subtract(0.5)/multiply(2.0) step of
# slim's inception_preprocessing, which slim uses for mobilenet_v2.
# NOTE(review): the eval path of inception_preprocessing also center-crops
# 87.5% before resizing; this script does not.
# np.float was removed in NumPy 1.24, so use an explicit float32.
img = np.array(Image.open(image_dir).convert('RGB').resize((224, 224))).astype(np.float32) / 127.5 - 1

# Read the frozen graph; `with` closes the file handle.
with open(model_dir, 'rb') as model_file:
    gd = tf.GraphDef.FromString(model_file.read())
inp, predictions = tf.import_graph_def(
    gd, return_elements=['input:0', 'MobilenetV2/Predictions/Reshape_1:0'])

with tf.Session(graph=inp.graph):
    # The graph expects a batch dimension: feed a batch of one image.
    x = predictions.eval(feed_dict={inp: img.reshape(1, 224, 224, 3)})

# Class-id -> class-name mapping written next to the tfrecords by the conversion step.
label_map = dataset_utils.read_label_file(dataset_dir)
print("Top 1 Prediction: ", x.argmax(), label_map[x.argmax()], x.max())
import tensorflow as tf
import numpy as np
import cv2
from datasets import dataset_utils
from IPython import display
import pylab
import PIL
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
import matplotlib.font_manager as fm

image_dir = '/*/slim/seg_pred/52.jpg'
dataset_dir = '/*/slim/intel_placedata/place6'
model_dir = '/*/slim/mobilenet_v2_place6/model/mobilenet_v2place61_frozen.pb'


# opencv
class TOD(object):
    """Classify one image with the frozen MobileNetV2 graph and display the result."""

    def __init__(self):
        # Path of the frozen .pb graph (not a checkpoint, despite the name).
        self.PATH_TO_CKPT = model_dir
        self.NUM_CLASSES = 7
        self.detection_graph = self._load_model()
        # Class-id -> class-name mapping written next to the tfrecords.
        self.label_map = dataset_utils.read_label_file(dataset_dir)

    def _load_model(self):
        """Load the frozen GraphDef at PATH_TO_CKPT into a fresh tf.Graph."""
        detection_graph = tf.Graph()
        with detection_graph.as_default():
            od_graph_def = tf.GraphDef()
            with tf.gfile.GFile(self.PATH_TO_CKPT, 'rb') as fid:
                serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            # name='' keeps the original tensor names ('input:0', ...).
            tf.import_graph_def(od_graph_def, name='')
        return detection_graph

    def visualization(self, image, str):
        """Draw the text `str` onto `image` in place and return it.

        Args:
          image: BGR uint8 array as returned by cv2.imread.
          str: the caption to draw. The parameter name is kept for
            compatibility even though it shadows the builtin here.
        """
        # PIL works in RGB while `image` is BGR. Without converting, 'red' is
        # written into OpenCV's blue channel and the caption displays blue.
        image_pil = Image.fromarray(cv2.cvtColor(np.uint8(image), cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(image_pil)
        # Use the DejaVu Sans font.
        font = ImageFont.truetype(fm.findfont(fm.FontProperties(family='DejaVu Sans')), 15)
        draw.text((10, 10), str, fill='red', font=font)  # 'fuchsia'
        np.copyto(image, cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR))
        return image

    def detect(self, image, resized):
        """Classify `resized`, draw the top-1 label and score on `image`, and show it.

        Args:
          image: the original BGR image; annotated in place and displayed.
          resized: the network input, a (224, 224, 3) RGB float array scaled
            to [-1, 1] (as prepared in __main__ below).
        """
        # The model expects a batch dimension: [1, 224, 224, 3].
        image_np_expanded = np.expand_dims(resized, axis=0)
        with tf.Session(graph=self.detection_graph) as sess:
            inp = self.detection_graph.get_tensor_by_name('input:0')
            predictions = self.detection_graph.get_tensor_by_name(
                'MobilenetV2/Predictions/Reshape_1:0')
            x = sess.run(predictions, feed_dict={inp: image_np_expanded})

        # x holds one row of per-class scores, so the flat argmax is the class id.
        label = str(self.label_map[x.argmax()])
        score = str(x.max())
        image = self.visualization(image, label + ":" + score)
        # print("Top 1 Prediction: ", x.argmax(), self.label_map[x.argmax()], x.max())

        cv2.namedWindow("detection", cv2.WINDOW_NORMAL)
        cv2.imshow("detection", image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()


if __name__ == '__main__':
    image = cv2.imread(image_dir)
    # cv2.imread returns None (it does not raise) when the file can't be read.
    if image is None:
        raise SystemExit('Could not read image: ' + image_dir)
    width = 224
    height = 224
    dim = (width, height)
    # OpenCV loads BGR, but the training tfrecords were decoded as RGB
    # (tf.image.decode_jpeg), so convert before feeding the network.
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Resize, then map pixel values from [0, 255] to [-1, 1].
    # np.float was removed in NumPy 1.24, so use float32 explicitly.
    resized = cv2.resize(rgb, dim).astype(np.float32) / 127.5 - 1
    detector = TOD()
    detector.detect(image, resized)
本网页所有视频内容由 imoviebox边看边下-网页视频下载, iurlBox网页地址收藏管理器 下载并得到。
ImovieBox网页视频下载器 下载地址: ImovieBox网页视频下载器-最新版本下载
本文章由: imapbox邮箱云存储,邮箱网盘,ImageBox 图片批量下载器,网页图片批量下载专家,网页图片批量下载器,获取到文章图片,imoviebox网页视频批量下载器,下载视频内容,为您提供.
阅读和此文章类似的: 全球云计算