代码和上一讲只是相差了保存权重和增加了预测功能 2.callback权重填写函数 3.训练时调用 4 预测部分 #预测函数 对于前面几张的预测效果
1.权重保存的路径和读取方法checkpoint_save_path = "./checkpoint/mnist.ckpt" if os.path.exists(checkpoint_save_path+'.index'): print('---------------------load the model------------------') model.load_weights(checkpoint_save_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, save_best_only=True, save_weights_only=True)
history = model.fit(x_train,y_train,batch_size=32,epochs=10,validation_data=(x_test,y_test),validation_freq=2,callbacks=[cp_callback])
①网络复现checkpoint_save_path = "./checkpoint/mnist.ckpt" model = tf.keras.models.Sequential([ tf.keras.layers.Flatten(), tf.keras.layers.Dense(128,activation='relu'), tf.keras.layers.Dense(10,activation='softmax') ]) model.load_weights(checkpoint_save_path)
def preidct(img_path): img = Image.open(img_path) img =img.resize((28,28),Image.ANTIALIAS) img_arr = np.array(img.convert('L')) import matplotlib.pyplot as plt # plt.imshow(img,cmap='gray') # plt.show() for i in range(28): for j in range(28): if img_arr[i][j] < 200: img_arr[i][j] = 255 else: img_arr[i][j] = 0 # plt.imshow(img_arr,cmap='gray') # plt.show() img_arr = img_arr/255.0 x_predict = img_arr[tf.newaxis,...] result = model.predict(x_predict) pred = tf.argmax(result,axis=1) print('n') class_names=['T-shirt/top','Trouser','Pullover','Dress','Coat', 'Sandal','Shirt','Sneaker','Bag','Ankle boot'] pred = pred.numpy() print(class_names[pred[0]])
本网页所有视频内容由 imoviebox边看边下-网页视频下载, iurlBox网页地址收藏管理器 下载并得到。
ImovieBox网页视频下载器 下载地址: ImovieBox网页视频下载器-最新版本下载
本文章由: imapbox邮箱云存储,邮箱网盘,ImageBox 图片批量下载器,网页图片批量下载专家,网页图片批量下载器,获取到文章图片,imoviebox网页视频批量下载器,下载视频内容,为您提供.
阅读和此文章类似的: 全球云计算