Source code for bob.learn.tensorflow.utils.keras

import logging

import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.python.util import nest

logger = logging.getLogger(__name__)

SINGLE_LAYER_OUTPUT_ERROR_MSG = (
    "All layers in a Sequential model should have "
    "a single output tensor. For multi-output "
    "layers, use the functional API."
)


[docs]def keras_channels_index(): return -3 if K.image_data_format() == "channels_first" else -1
[docs]def keras_model_weights_as_initializers_for_variables(model): """Changes the initialization operations of variables in the model to take the current value as the initial values. This is useful when you want to restore a pre-trained Keras model inside the model_fn of an estimator. Parameters ---------- model : object A Keras model. """ sess = tf.compat.v1.keras.backend.get_session() n = len(model.variables) logger.debug("Initializing %d variables with their current weights", n) for variable in model.variables: value = variable.eval(sess) initial_value = tf.constant(value=value, dtype=value.dtype.name) variable._initializer_op = variable.assign(initial_value) variable._initial_value = initial_value
def _create_var_map(variables, normalizer=None): if normalizer is None: def normalizer(name): return name.split(":")[0] assignment_map = {normalizer(v.name): v for v in variables} assert len(assignment_map) return assignment_map
[docs]def restore_model_variables_from_checkpoint( model, checkpoint, session=None, normalizer=None ): if session is None: session = tf.compat.v1.keras.backend.get_session() var_list = _create_var_map(model.variables, normalizer=normalizer) saver = tf.compat.v1.train.Saver(var_list=var_list) ckpt_state = tf.train.get_checkpoint_state(checkpoint) logger.info("Loading checkpoint %s", ckpt_state.model_checkpoint_path) saver.restore(session, ckpt_state.model_checkpoint_path)
[docs]def initialize_model_from_checkpoint(model, checkpoint, normalizer=None): assignment_map = _create_var_map(model.variables, normalizer=normalizer) tf.compat.v1.train.init_from_checkpoint(checkpoint, assignment_map=assignment_map)
[docs]def model_summary(model, do_print=False): from tensorflow.keras.backend import count_params if model.__class__.__name__ == "Sequential": sequential_like = True elif not model._is_graph_network: # We treat subclassed models as a simple sequence of layers, for logging # purposes. sequential_like = True else: sequential_like = True nodes_by_depth = model._nodes_by_depth.values() nodes = [] for v in nodes_by_depth: if (len(v) > 1) or ( len(v) == 1 and len(nest.flatten(v[0].inbound_layers)) > 1 ): # if the model has multiple nodes # or if the nodes have multiple inbound_layers # the model is no longer sequential sequential_like = False break nodes += v if sequential_like: # search for shared layers for layer in model.layers: flag = False for node in layer._inbound_nodes: if node in nodes: if flag: sequential_like = False break else: flag = True if not sequential_like: break if sequential_like: # header names for the different log elements to_display = ["Layer (type)", "Details", "Output Shape", "Number of Parameters"] else: # header names for the different log elements to_display = [ "Layer (type)", "Details", "Output Shape", "Number of Parameters", "Connected to", ] relevant_nodes = [] for v in model._nodes_by_depth.values(): relevant_nodes += v rows = [to_display] def print_row(fields): for i, v in enumerate(fields): if isinstance(v, int): fields[i] = f"{v:,}" rows.append(fields) def layer_details(layer): cls_name = layer.__class__.__name__ details = [] if "Conv" in cls_name and "ConvBlock" not in cls_name: details += [f"filters={layer.filters}"] details += [f"kernel_size={layer.kernel_size}"] if "Pool" in cls_name and "Global" not in cls_name: details += [f"pool_size={layer.pool_size}"] if ( "Conv" in cls_name and "ConvBlock" not in cls_name or "Pool" in cls_name and "Global" not in cls_name ): details += [f"strides={layer.strides}"] if ( "ZeroPad" in cls_name or cls_name in ("Conv1D", "Conv2D", "Conv3D") or "Pool" in cls_name and "Global" not in cls_name ): details += [f"padding={layer.padding}"] if "Cropping" in cls_name: details += [f"cropping={layer.cropping}"] if cls_name == "Dense": details += [f"units={layer.units}"] if cls_name in ("Conv1D", "Conv2D", "Conv3D") or cls_name == "Dense": act = layer.activation.__name__ if act != "linear": details += [f"activation={act}"] if cls_name == "Dropout": details += [f"drop_rate={layer.rate}"] if cls_name == "Concatenate": details += [f"axis={layer.axis}"] if cls_name == "Activation": act = layer.get_config()["activation"] details += [f"activation={act}"] if "InceptionModule" in cls_name: details += [f"b1_c1={layer.filter_1x1}"] details += [f"b2_c1={layer.filter_3x3_reduce}"] details += [f"b2_c2={layer.filter_3x3}"] details += [f"b3_c1={layer.filter_5x5_reduce}"] details += [f"b3_c2={layer.filter_5x5}"] details += [f"b4_c1={layer.pool_proj}"] if cls_name == "LRN": details += [f"depth_radius={layer.depth_radius}"] details += [f"alpha={layer.alpha}"] details += [f"beta={layer.beta}"] if cls_name == "ConvBlock": details += [f"filters={layer.num_filters}"] details += [f"bottleneck={layer.bottleneck}"] details += [f"dropout_rate={layer.dropout_rate}"] if cls_name == "DenseBlock": details += [f"layers={layer.num_layers}"] details += [f"growth_rate={layer.growth_rate}"] details += [f"bottleneck={layer.bottleneck}"] details += [f"dropout_rate={layer.dropout_rate}"] if cls_name == "TransitionBlock": details += [f"filters={layer.num_filters}"] if cls_name == "InceptionA": details += [f"pool_filters={layer.pool_filters}"] if cls_name == "InceptionResnetBlock": details += [f"block_type={layer.block_type}"] details += [f"scale={layer.scale}"] details += [f"n={layer.n}"] if cls_name == "ReductionA": details += [f"k={layer.k}"] details += [f"kl={layer.kl}"] details += [f"km={layer.km}"] details += [f"n={layer.n}"] if cls_name == "ReductionB": details += [f"k={layer.k}"] details += [f"kl={layer.kl}"] details += [f"km={layer.km}"] details += [f"n={layer.n}"] details += [f"no={layer.no}"] details += [f"p={layer.p}"] details += [f"pq={layer.pq}"] if cls_name == "ScaledResidual": details += [f"scale={layer.scale}"] return ", ".join(details) def print_layer_summary(layer): """Prints a summary for a single layer. Arguments: layer: target layer. """ try: output_shape = layer.output_shape except AttributeError: output_shape = "multiple" except RuntimeError: # output_shape unknown in Eager mode. output_shape = "?" name = layer.name cls_name = layer.__class__.__name__ fields = [ name + " (" + cls_name + ")", layer_details(layer), output_shape, layer.count_params(), ] print_row(fields) def print_layer_summary_with_connections(layer): """Prints a summary for a single layer (including topological connections). Arguments: layer: target layer. """ try: output_shape = layer.output_shape except AttributeError: output_shape = "multiple" connections = [] for node in layer._inbound_nodes: if relevant_nodes and node not in relevant_nodes: # node is not part of the current network continue for inbound_layer, node_index, tensor_index, _ in node.iterate_inbound(): connections.append( "{}[{}][{}]".format(inbound_layer.name, node_index, tensor_index) ) name = layer.name cls_name = layer.__class__.__name__ if not connections: first_connection = "" else: first_connection = connections[0] fields = [ name + " (" + cls_name + ")", layer_details(layer), output_shape, layer.count_params(), first_connection, ] print_row(fields) if len(connections) > 1: for i in range(1, len(connections)): fields = ["", "", "", "", connections[i]] print_row(fields) layers = model.layers for i in range(len(layers)): if sequential_like: print_layer_summary(layers[i]) else: print_layer_summary_with_connections(layers[i]) model._check_trainable_weights_consistency() if hasattr(model, "_collected_trainable_weights"): trainable_count = count_params(model._collected_trainable_weights) else: trainable_count = count_params(model.trainable_weights) non_trainable_count = count_params(model.non_trainable_weights) print_row([]) print_row( [ "Model", f"Parameters: total={trainable_count + non_trainable_count:,}, trainable={trainable_count:,}", ] ) if do_print: from tabulate import tabulate print() print(tabulate(rows, headers="firstrow")) print() return rows