代码:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @ Date : 2022/12/21 13:19
# @ Author : paperClub
import os

# Must be set before TensorFlow is imported; '2' hides INFO and WARNING log output.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow

# Under TF 2.x, use the v1 compatibility API (tf.Session, tf.GraphDef, ...) and switch
# off the 2.x behaviors so the graph/session code below works unchanged.
# Parse the major version properly instead of looking at the first character only.
if int(tensorflow.__version__.split('.')[0]) >= 2:
    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()
else:
    import tensorflow as tf


def load_model(pb_model_file):
    """Load a serialized GraphDef (.pb) file and return a session bound to it.

    The GraphDef is imported into a fresh ``tf.Graph`` with an empty name
    prefix, so tensor names in the graph match the names stored in the file
    (e.g. ``'input_1:0'``).

    Args:
        pb_model_file: Path to the binary GraphDef file (presumably a frozen
            graph exported from a trained model -- confirm for your model).

    Returns:
        A ``tf.Session`` whose ``graph`` is the imported graph.

    Raises:
        OSError: If ``pb_model_file`` cannot be opened.
    """
    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        # Close the file deterministically instead of leaking the handle.
        with open(pb_model_file, 'rb') as model_file:
            graph_def.ParseFromString(model_file.read())
        # name="" keeps the original op names (no "import/" prefix).
        tf.import_graph_def(graph_def, name="")
        # NOTE(review): import_graph_def does not populate the variable
        # collections, so this initializer is expected to be a no-op; kept
        # from the original code for graphs that may declare variables.
        init = tf.global_variables_initializer()

    session = tf.Session(graph=graph)
    session.run(init)
    return session


session = None
# Lazily build the session once; reuses an existing one if already loaded.
if session is None:
    pb_model_file = "./tf2_model.pb"
    session = load_model(pb_model_file)

layer_input = 'input_1:0'  # Fill in according to the actual model (input tensor name).
layer_output = 'output'  # Fill in according to the actual model (output tensor name).
# Placeholder: replace with the actual (preprocessed) input data for the model.
img_input = ''


def _as_tensor_name(name):
    """Return *name* in ``<op_name>:<output_index>`` form.

    ``Graph.get_tensor_by_name`` rejects bare operation names such as
    ``'output'``; when no output index is given, the first output (``:0``)
    is assumed.
    """
    return name if ':' in name else name + ':0'


feed_input = session.graph.get_tensor_by_name(_as_tensor_name(layer_input))
feches = session.graph.get_tensor_by_name(_as_tensor_name(layer_output))
res = session.run(feches, feed_dict={feed_input: img_input})