Source code for hyperts.framework.nas.layers._layers

import gc
import numpy as np
from hyperts.framework.dl import layers

from hypernets.core.search_space import ModuleSpace
from hypernets.core.ops import Choice, ModuleChoice, InputChoice


[docs]def compile_layer(search_space, layer_class, name, **kwargs): if kwargs.get('name') is None: kwargs['name'] = name cache = search_space.__dict__.get('weights_cache') if cache is not None: layer = cache.retrieve(kwargs['name']) if layer is None: layer = layer_class(**kwargs) cache.put(kwargs['name'], layer) else: layer = layer_class(**kwargs) return layer
class HyperLayer(ModuleSpace):
    """Search-space module that wraps a layer class.

    The layer class is kept in ``keras_layer_class``; the actual layer
    instance is created in ``_compile`` through ``compile_layer`` and is
    applied to the inputs in ``_forward``.

    Parameters
    ----------
    keras_layer_class : class or callable used to build the layer.
    space, name, **hyperparams : forwarded unchanged to
        ``ModuleSpace.__init__``.
    """

    def __init__(self, keras_layer_class, space=None, name=None, **hyperparams):
        self.keras_layer_class = keras_layer_class
        ModuleSpace.__init__(self, space, name, **hyperparams)

    def _compile(self):
        """Build (or fetch from the cache) the wrapped layer instance."""
        layer_kwargs = self.param_values
        self.keras_layer = compile_layer(
            self.space, self.keras_layer_class, self.name, **layer_kwargs)

    def _forward(self, inputs):
        """Apply the compiled layer to ``inputs`` and return its output."""
        layer = self.keras_layer
        return layer(inputs)
class CalibrateSize(ModuleSpace):
    """Pick one tensor from a list and reduce it to the smallest common size.

    When called with a list, the tensor at index ``node`` is selected. If its
    time-step dimension (shape index 1) or its channel dimension (last shape
    index) differs from the minimum found across the list, the tensor is
    passed through ``layers.FactorizedReduce`` layers that are created on
    first use and kept in ``reduce0`` / ``reduce1``. Non-list inputs are
    returned unchanged.

    Parameters
    ----------
    node : index of the tensor to select from the input list.
    name_prefix : prefix used when naming the reduce layers.
    space, name, **hyperparams : forwarded unchanged to
        ``ModuleSpace.__init__``.
    """

    def __init__(self, node, name_prefix, space=None, name=None, **hyperparams):
        self.node = node
        self.name_prefix = name_prefix
        # Reduce layers are created lazily in ``_forward``.
        self.reduce0 = None
        self.reduce1 = None
        ModuleSpace.__init__(self, space, name, **hyperparams)

    def _compile(self):
        """Store the module-level ``compile_layer`` function on the instance."""
        self.compile_layer = compile_layer

    def factorized_reduce(self, name_posfix, period, filters, strides=1):
        """Build a ``layers.FactorizedReduce`` layer via ``compile_layer``.

        The layer is named
        ``{name_prefix}_factorized_reduce_{name_posfix}``.
        """
        layer_name = f'{self.name_prefix}_factorized_reduce_{name_posfix}'
        return self.compile_layer(
            search_space=self.space,
            layer_class=layers.FactorizedReduce,
            period=period,
            filters=filters,
            strides=strides,
            name=layer_name)

    def get_timestemp(self, x):
        """Return the size of shape index 1 of ``x``."""
        shape = x.get_shape().as_list()
        return shape[1]

    def get_channels(self, x):
        """Return the size of the last shape index of ``x``."""
        shape = x.get_shape().as_list()
        return shape[-1]

    def _forward(self, inputs):
        """Return ``inputs[node]``, reduced if it is not already minimal."""
        if not isinstance(inputs, list):
            return inputs

        steps = [self.get_timestemp(tensor) for tensor in inputs]
        channels = [self.get_channels(tensor) for tensor in inputs]
        target_steps = int(np.min(steps))
        target_channels = int(np.min(channels))

        selected = inputs[self.node]
        needs_step_reduce = steps[self.node] != target_steps
        needs_channel_reduce = channels[self.node] != target_channels

        # Both reduce layers are built with the same period and filters;
        # only their names differ.
        if needs_step_reduce and self.reduce0 is None:
            self.reduce0 = self.factorized_reduce(
                f'timestemp{self.node}', target_steps, target_channels//2)
        if needs_channel_reduce and self.reduce1 is None:
            self.reduce1 = self.factorized_reduce(
                f'variables{self.node}', target_steps, target_channels//2)

        if needs_step_reduce:
            selected = self.reduce0(selected)
        if needs_channel_reduce:
            selected = self.reduce1(selected)
        return selected
class SafeMerge(ModuleSpace):
    """Merge a list of tensors with ``layers.Add`` or ``layers.Concatenate``.

    Non-list inputs are passed through unchanged.

    Parameters
    ----------
    name_prefix : stored on the instance as ``name_prefix``.
    ops : ``'add'`` or ``'concat'`` (lower-cased on construction).
    space, name, **hyperparams : forwarded unchanged to
        ``ModuleSpace.__init__``.
    """

    def __init__(self, name_prefix, ops='add', space=None, name=None, **hyperparams):
        self.ops = ops.lower()
        self.name_prefix = name_prefix
        ModuleSpace.__init__(self, space, name, **hyperparams)

    def _compile(self):
        """Nothing to compile; the merge layer is created in ``_forward``."""
        pass

    def _on_params_ready(self):
        """Intentionally a no-op."""
        pass

    def _forward(self, inputs):
        """Merge ``inputs`` when it is a list, otherwise return it as-is.

        Raises
        ------
        ValueError
            If ``ops`` is neither ``'add'`` nor ``'concat'``.
        """
        if not isinstance(inputs, list):
            return inputs

        params = self.param_values
        if params.get('name') is None:
            params['name'] = self.name

        if self.ops == 'add':
            # Add receives only the name; Concatenate receives every param.
            merge_layer = layers.Add(name=params['name'])
        elif self.ops == 'concat':
            merge_layer = layers.Concatenate(**params)
        else:
            raise ValueError(f'Not supported operation:{self.ops}')
        return merge_layer(inputs)
def stem_ops(input, units=64):
    """Build the stem: a GRU followed by layer normalization.

    Parameters
    ----------
    input : upstream module to feed into the GRU, or None. When None, the
        GRU module itself is returned as the second element.
    units : ``units`` hyperparameter of the GRU.

    Returns
    -------
    tuple
        ``(layernorm_output, input)`` where ``input`` is the given argument,
        or the GRU module if the argument was None.
    """
    gru = HyperLayer(layers.GRU, units=units, return_sequences=True, name='stem_gru')
    if input is not None:
        gru(input)
    else:
        input = gru
    norm = HyperLayer(layers.LayerNormalization, name='stem_layernorm')
    return norm(gru), input
def cell_ops(inputs, name_prefix, block_no, node_no, cell_no, filters_or_units, kernel_size=(1, 3, 5)):
    """Build one cell: an input choice followed by a choice among five ops.

    The candidate operations are Conv1D, SeparableConv1D, GRU, LSTM and
    Identity, each wrapped in a ``HyperLayer``.

    Parameters
    ----------
    inputs : candidate inputs handed to ``InputChoice``.
    name_prefix, block_no, node_no, cell_no : combined into the prefix
        ``{name_prefix}_block{block_no}_node{node_no}_cell{cell_no}`` used
        for every module name in the cell.
    filters_or_units : a tuple/list gives each of the four parametrised ops
        its own ``Choice`` over those values; any other value is used
        directly by all four.
    kernel_size : candidate kernel sizes for the two convolution ops.

    Returns
    -------
    The output of the ``ModuleChoice`` applied to the chosen input.
    """
    name_prefix = f'{name_prefix}_block{block_no}_node{node_no}_cell{cell_no}'

    input_choice = InputChoice(
        inputs, num_chosen_most=1, name=f'{name_prefix}_inputchoice')
    chosen_input = input_choice(inputs)

    if isinstance(filters_or_units, (tuple, list)):
        # One independent Choice per operation, created in a fixed order.
        conv_filters, sepconv_filters, gru_units, lstm_units = [
            Choice(list(filters_or_units)) for _ in range(4)]
    else:
        conv_filters = sepconv_filters = gru_units = lstm_units = filters_or_units

    conv = HyperLayer(
        layers.Conv1D,
        filters=conv_filters,
        padding='same',
        activation='relu',
        kernel_size=Choice(list(kernel_size)),
        name=f'{name_prefix}_conv1d')
    sepconv = HyperLayer(
        layers.SeparableConv1D,
        filters=sepconv_filters,
        padding='same',
        activation='relu',
        kernel_size=Choice(list(kernel_size)),
        name=f'{name_prefix}_separableconv1d')
    gru = HyperLayer(
        layers.GRU,
        units=gru_units,
        return_sequences=True,
        name=f'{name_prefix}_gru')
    lstm = HyperLayer(
        layers.LSTM,
        units=lstm_units,
        return_sequences=True,
        name=f'{name_prefix}_lstm')
    identity = HyperLayer(layers.Identity, name=f'{name_prefix}_identity')

    candidates = [conv, sepconv, gru, lstm, identity]
    op_choice = ModuleChoice(candidates, name=f'{name_prefix}_modulechoice')
    return op_choice(chosen_input)
def node_ops(inputs, name_prefix, block_no, node_no, filters_or_units=(16, 32, 64), kernel_size=(1, 3, 5)):
    """Build one node: two cells, size calibration for each, then a merge.

    Parameters
    ----------
    inputs : candidate inputs passed to both cells.
    name_prefix, block_no, node_no : used to name the modules of the node.
    filters_or_units, kernel_size : forwarded to ``cell_ops``.

    Returns
    -------
    The output of ``merge_ops`` over the two calibrated cell outputs.
    """
    node_prefix = f'{name_prefix}_block{block_no}_node{node_no}'

    cell0, cell1 = [
        cell_ops(inputs, name_prefix, block_no, node_no, cell_no,
                 filters_or_units, kernel_size)
        for cell_no in (0, 1)
    ]

    # Each calibrator gets its own list object holding both cell outputs.
    calibrate0 = CalibrateSize(node=0, name_prefix=f'{node_prefix}_reduce0')
    out0 = calibrate0([cell0, cell1])
    calibrate1 = CalibrateSize(node=1, name_prefix=f'{node_prefix}_reduce1')
    out1 = calibrate1([cell0, cell1])

    return merge_ops(inputs=[out0, out1], name_prefix=node_prefix)
def merge_ops(inputs, name_prefix, ops='add'):
    """Merge ``inputs`` with an Add or Concatenate ``HyperLayer``.

    Parameters
    ----------
    inputs : modules to merge.
    name_prefix : prefix of the merge module name; the suffix is ``_add``
        or ``_concat`` depending on ``ops``.
    ops : ``'add'`` or ``'concat'``.

    Returns
    -------
    The output of the merge ``HyperLayer`` applied to ``inputs``.

    Raises
    ------
    ValueError
        If ``ops`` is neither ``'add'`` nor ``'concat'``.
    """
    if ops == 'add':
        layer_class, suffix = layers.Add, 'add'
    elif ops == 'concat':
        layer_class, suffix = layers.Concatenate, 'concat'
    else:
        raise ValueError(f'Not supported operation:{ops}')

    merge_layer = HyperLayer(layer_class, name=f'{name_prefix}_{suffix}')
    return merge_layer(inputs)
class LayerWeightsCache(object):
    """Dictionary-backed cache of layers with hit/miss counters.

    Attributes are plain and public:

    * ``cache`` -- dict mapping a key to the stored layer.
    * ``hit_counter`` -- number of ``retrieve`` calls that found an entry.
    * ``miss_counter`` -- number of ``retrieve`` calls that found nothing.
    """

    def __init__(self):
        self.reset()
        super(LayerWeightsCache, self).__init__()

    def reset(self):
        """Replace the cache with an empty dict and zero both counters."""
        self.cache = dict()
        self.hit_counter = 0
        self.miss_counter = 0

    def clear(self):
        """Drop the current cache, run a garbage collection, then reset."""
        del self.cache
        gc.collect()
        self.reset()

    def hit(self):
        """Record one cache hit."""
        self.hit_counter += 1

    def miss(self):
        """Record one cache miss."""
        self.miss_counter += 1

    def put(self, key, layer):
        """Store ``layer`` under ``key``.

        Raises
        ------
        AssertionError
            If a non-None entry already exists for ``key``.
        """
        # Raised explicitly instead of using an ``assert`` statement, which
        # is stripped under ``python -O`` and would then let a duplicate key
        # silently overwrite the cached layer. The exception type is kept.
        if self.cache.get(key) is not None:
            raise AssertionError(f'Duplicate keys are not allowed. key:{key}')
        self.cache[key] = layer

    def retrieve(self, key):
        """Return the entry for ``key`` (or None) and update the counters.

        An entry whose stored value is None is counted as a miss.
        """
        item = self.cache.get(key)
        if item is None:
            self.miss()
        else:
            self.hit()
        return item