TFEstimatorBase

Abstract Base Class for Estimating Synthesizer Parameters using TensorFlow.

Wraps TensorFlow library calls and provides functionality for training and making synthesizer parameter predictions. Inheriting classes must implement the build_model() method, which is automatically called upon construction and sets up the neural network model.

Examples

Here is an example of extending TFEstimatorBase to create a simple multi-layer perceptron network:

import tensorflow as tf
from tensorflow.keras import layers
from spiegelib.estimator import TFEstimatorBase

class MLP(TFEstimatorBase):

    def __init__(self, input_shape, num_outputs, **kwargs):

        # Call TFEstimatorBase constructor
        super().__init__(input_shape, num_outputs, **kwargs)

    def build_model(self):

        # Model must be defined in the model attribute
        self.model = tf.keras.Sequential()
        self.model.add(layers.Dense(50, input_shape=self.input_shape, activation='relu'))
        self.model.add(layers.Dense(40, activation='relu'))
        self.model.add(layers.Dense(30, activation='relu'))
        self.model.add(layers.Dense(self.num_outputs))

        self.model.compile(
            optimizer=tf.optimizers.Adam(),
            loss=TFEstimatorBase.rms_error,
            metrics=['accuracy']
        )

For a detailed example of running a synthesizer sound matching experiment using deep learning models like this, see the FM Sound Match Example.

class spiegelib.estimator.TFEstimatorBase(input_shape=None, num_outputs=None, weights_path='', callbacks=[])

Bases: spiegelib.estimator.estimator_base.EstimatorBase

Parameters
  • input_shape (tuple, optional) – Shape of matrix that will be passed to model input

  • num_outputs (int, optional) – Number of outputs the model has. If estimating synthesizer parameters, this will typically be the number of parameters.

  • weights_path (string, optional) – If given, model weights will be loaded from this file

  • callbacks (list, optional) – A list of callbacks to be passed into model fit method

model

Attribute for the model, see TensorFlow docs

Type

tf.keras.Model

abstract build_model()

Abstract method that should contain the model definition when implemented

add_training_data(input, output, batch_size=64, shuffle_size=None)

Creates a tf.data.Dataset from the training data, then shuffles and batches it for training. Stores the result in the train_data attribute.

Parameters
  • input (np.ndarray) – training data tensor (e.g., audio features)

  • output (np.ndarray) – ground truth data (e.g., parameter values)

  • batch_size (int, optional) – If provided, data will be split into batches of this size. None or 0 disables batching. Defaults to 64.

  • shuffle_size (int, optional) – If provided, data will be shuffled with a buffer of this size. None or 0 disables shuffling. Defaults to None.
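
The shuffle and batch behaviour described above can be approximated with the standard tf.data API. This is a minimal sketch rather than the library's actual implementation: make_dataset is a hypothetical helper, and the array shapes (100 examples, 13 features, 9 parameters) are made up for illustration.

```python
import numpy as np
import tensorflow as tf


def make_dataset(inputs, outputs, batch_size=64, shuffle_size=None):
    """Hypothetical helper mirroring the documented arguments."""
    # Pair each input example with its ground truth output
    dataset = tf.data.Dataset.from_tensor_slices((inputs, outputs))

    # None or 0 means no shuffling
    if shuffle_size:
        dataset = dataset.shuffle(shuffle_size)

    # None or 0 means no batching
    if batch_size:
        dataset = dataset.batch(batch_size)

    return dataset


# Made-up data: 100 examples of 13 audio features and 9 parameter values
features = np.random.rand(100, 13).astype(np.float32)
params = np.random.rand(100, 9).astype(np.float32)

train_data = make_dataset(features, params, batch_size=32, shuffle_size=100)
first_x, first_y = next(iter(train_data))
num_batches = len(list(train_data))
```

With 100 examples and a batch size of 32, the final batch holds the remaining 4 examples, because batch() keeps a partial last batch by default.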

add_testing_data(input, output, batch_size=64)

Creates a tf.data.Dataset from the validation data and optionally batches it. Stores the result in the test_data attribute.

Parameters
  • input (np.ndarray) – validation data (e.g., audio features)

  • output (np.ndarray) – ground truth data (e.g., parameter values)

  • batch_size (int, optional) – If provided, data will be split into batches of this size. None or 0 disables batching. Defaults to 64.

fit(epochs=1, callbacks=[], **kwargs)

Trains the model for a fixed number of epochs on the training data, using validation data if it has been added to this estimator.

Parameters
  • epochs (int, optional) – Number of epochs to train the model for.

  • callbacks (list, optional) – List of callback functions for training.

  • kwargs – Keyword args passed to the model fit method. See TensorFlow Docs.

predict(input)

Run prediction on input

Parameters
  • input (np.ndarray) – matrix of input data to run predictions on. Can be a single instance or a batch.

load_weights(filepath, **kwargs)

Load model weights from H5 or TensorFlow file

Parameters
  • filepath (string) – location of saved model weights

  • kwargs – optional keyword arguments passed to tf load_weights method, see TensorFlow Docs.

save_weights(filepath, **kwargs)

Save model weights to an HDF5 or TensorFlow file.

Parameters
  • filepath (str) – filepath to save model weights. Using a file suffix of ‘.h5’ or ‘.keras’ will save in HDF5 format. Otherwise the weights will be saved in TensorFlow format.

  • kwargs – optional keyword arguments passed to tf save_weights method, see TensorFlow Docs.

save_model(filepath, **kwargs)

Save entire model

Parameters
  • filepath (str) – path to SavedModel or H5 file to save the model.

  • kwargs – optional keyword arguments passed to tf save method, see TensorFlow Docs.

static load(filepath, **kwargs)

Load entire model and return an instantiated TFEstimatorBase class with the saved model loaded into it.

Parameters
  • filepath (str) – path to SavedModel or H5 file of saved model.

  • kwargs – Keyword arguments to pass into load_model function. See TensorFlow Docs.

static rms_error(y_true, y_pred)

Static method for calculating root mean squared error between predictions and targets

Parameters
  • y_true (Tensor) – Ground truth labels

  • y_pred (Tensor) – Predictions
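
As a rough illustration of the quantity involved, root mean squared error is the square root of the mean of the squared differences between predictions and targets. The sketch below uses NumPy instead of TensorFlow tensors and assumes the mean is taken over all elements; the library's own method operates on tensors and may reduce over a different axis. The function name rms_error_numpy is hypothetical.

```python
import numpy as np


def rms_error_numpy(y_true, y_pred):
    """NumPy sketch of root mean squared error (not the library's code)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)

    # Square the element-wise differences, average them, then take the root
    return float(np.sqrt(np.mean((y_pred - y_true) ** 2)))


# Two of three predictions match exactly; the third is off by 1.0,
# so the error is sqrt(1/3)
example_error = rms_error_numpy([0.0, 0.5, 1.0], [0.0, 0.5, 0.0])
```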