
Abstract Base Class for Estimating Synthesizer Parameters using TensorFlow.

Wraps TensorFlow library calls and provides functionality for training models and predicting synthesizer parameters. Inheriting classes must implement the build_model() method, which is called automatically on construction and sets up the neural network model.
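Having the constructor call build_model() is an instance of the template-method pattern: the base constructor calls a hook that every subclass must override. The stdlib-only sketch below illustrates the pattern with a hypothetical EstimatorSketch class. It is not spiegelib's implementation, and the attributes it sets are assumptions made for illustration.

```python
from abc import ABC, abstractmethod


class EstimatorSketch(ABC):
    """Hypothetical stand-in for TFEstimatorBase, reduced to the constructor pattern."""

    def __init__(self, input_shape=None, num_outputs=None):
        self.input_shape = input_shape
        self.num_outputs = num_outputs
        self.model = None
        # The base constructor calls the subclass hook, so every instance
        # has its model built as soon as it is created.
        self.build_model()

    @abstractmethod
    def build_model(self):
        """Subclasses must assign their network to self.model."""


class TinyEstimator(EstimatorSketch):
    def build_model(self):
        # Placeholder "model": just record what a real network would be built from.
        self.model = [self.input_shape, self.num_outputs]


est = TinyEstimator(input_shape=(13,), num_outputs=4)
print(est.model)  # [(13,), 4]
```

Because build_model() is abstract, instantiating the base class directly raises a TypeError. Only subclasses that define a model can be created.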


Here is an example of extending TFEstimatorBase to create a simple multi-layer perceptron network:

import tensorflow as tf
from tensorflow.keras import layers
from spiegelib.estimator import TFEstimatorBase

class MLP(TFEstimatorBase):

    def __init__(self, input_shape, num_outputs, **kwargs):

        # Call TFEstimatorBase constructor
        super().__init__(input_shape, num_outputs, **kwargs)

    def build_model(self):

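        # Model must be defined in the model attribute
        self.model = tf.keras.Sequential()
        self.model.add(layers.Dense(50, input_shape=self.input_shape, activation='relu'))
        self.model.add(layers.Dense(40, activation='relu'))
        self.model.add(layers.Dense(30, activation='relu'))

        # Output layer: one unit per estimated parameter
        # (assumes the base class stores num_outputs as self.num_outputs)
        self.model.add(layers.Dense(self.num_outputs))

        # Keras fit() requires a compiled model
        self.model.compile(optimizer='adam', loss=TFEstimatorBase.rms_error)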


For a detailed example of running a synthesizer sound matching experiment using deep learning models like this, see the FM Sound Match Example.

class spiegelib.estimator.TFEstimatorBase(input_shape=None, num_outputs=None, weights_path='', callbacks=[])

Bases: spiegelib.estimator.estimator_base.EstimatorBase

  • input_shape (tuple, optional) – Shape of the input matrix that will be passed to the model

  • num_outputs (int, optional) – Number of model outputs. When estimating synthesizer parameters, this is typically the number of parameters.

  • weights_path (string, optional) – If given, model weights are loaded from this file

  • callbacks (list, optional) – A list of callbacks to pass to the model's fit() method


model – Attribute holding the tf.keras model built by build_model(); see the TensorFlow docs.



abstract build_model()

Abstract method that must contain the model definition. Implementations assign the model to the model attribute.

add_training_data(input, output, batch_size=64, shuffle_size=None)

Creates a tf.data.Dataset from the training data, then shuffles and batches it for training. Stores the result in the train_data attribute.

  • input (np.ndarray) – training data tensor (e.g., audio features)

  • output (np.ndarray) – ground truth data (e.g., parameter values)

  • batch_size (int, optional) – If provided, data is grouped into batches of this size. None or 0 disables batching. Defaults to 64.

  • shuffle_size (int, optional) – If provided, data is shuffled using a buffer of this size. None or 0 disables shuffling. Defaults to None.
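The shuffle_size and batch_size options can be pictured with the pure-Python sketch below. It assumes, as the description suggests, that shuffling happens before batching and that each input stays paired with its ground-truth output. shuffle_with_buffer and prepare are hypothetical helpers, not spiegelib or TensorFlow functions, and the real method builds a tf.data.Dataset rather than lists.

```python
import random


def shuffle_with_buffer(items, buffer_size, seed=None):
    """Yield items in the order a fixed-size shuffle buffer would emit them.

    Sketch of buffered shuffling (the idea behind tf.data.Dataset.shuffle):
    hold up to buffer_size elements, emit a random one, refill from the input.
    """
    rng = random.Random(seed)
    it = iter(items)
    buffer = []
    # Fill the buffer with the first buffer_size elements.
    for item in it:
        buffer.append(item)
        if len(buffer) >= buffer_size:
            break
    # Emit a random buffered element and replace it with the next input element.
    for item in it:
        idx = rng.randrange(len(buffer))
        yield buffer[idx]
        buffer[idx] = item
    # Input exhausted: drain what is left in random order.
    rng.shuffle(buffer)
    yield from buffer


def prepare(inputs, outputs, batch_size=64, shuffle_size=None, seed=None):
    """Pair inputs with targets, optionally shuffle, then optionally batch."""
    data = list(zip(inputs, outputs))  # keeps each feature/target pair together
    if shuffle_size:  # None or 0: no shuffling
        data = list(shuffle_with_buffer(data, shuffle_size, seed))
    if not batch_size:  # None or 0: no batching
        return data
    # Consecutive slices; the final batch may be smaller than batch_size.
    return [data[i:i + batch_size] for i in range(0, len(data), batch_size)]


print(prepare([1, 2, 3, 4, 5], ["a", "b", "c", "d", "e"], batch_size=2))
# [[(1, 'a'), (2, 'b')], [(3, 'c'), (4, 'd')], [(5, 'e')]]
```

A buffer of size 1 leaves the order unchanged. A buffer at least as large as the dataset gives a full uniform shuffle, at the cost of holding every element in memory.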

add_testing_data(input, output, batch_size=64)

Creates a tf.data.Dataset from the validation data and optionally batches it. Stores the result in the test_data attribute.

  • input (np.ndarray) – validation data (e.g., audio features)

  • output (np.ndarray) – ground truth data (e.g., parameter values)

  • batch_size (int, optional) – If provided, data is grouped into batches of this size. None or 0 disables batching. Defaults to 64.

fit(epochs=1, callbacks=[], **kwargs)

Trains the model for a fixed number of epochs on the training data, validating against the validation data if it has been added to this estimator.

  • epochs (int, optional) – Number of epochs to train the model for.

  • callbacks (list, optional) – List of callback functions for training.

  • kwargs – Keyword arguments passed to the model's fit() method. See the TensorFlow docs.


predict(input)

Runs prediction on the input.

  • input (np.ndarray) – matrix of input data to run predictions on. Can be a single instance or a batch.
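Keras models predict on batches, so a single instance needs a leading batch axis before it reaches the model. The sketch below shows one way such input could be normalized. ensure_batch is a hypothetical helper, and this is not necessarily how TFEstimatorBase handles single instances internally.

```python
import numpy as np


def ensure_batch(x, input_shape):
    """Return x with a leading batch axis.

    A single instance whose shape equals input_shape becomes a batch of one;
    anything else is assumed to already be a batch and is returned unchanged.
    """
    x = np.asarray(x)
    if x.shape == tuple(input_shape):
        return x[np.newaxis, ...]
    return x


single = np.zeros(13)        # one feature vector
many = np.zeros((8, 13))     # eight feature vectors
print(ensure_batch(single, (13,)).shape)  # (1, 13)
print(ensure_batch(many, (13,)).shape)    # (8, 13)
```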

load_weights(filepath, **kwargs)

Loads model weights from an HDF5 or TensorFlow file.

  • filepath (string) – location of saved model weights

  • kwargs – optional keyword arguments passed to the tf.keras load_weights() method; see the TensorFlow docs.

save_weights(filepath, **kwargs)

Saves model weights to an HDF5 or TensorFlow file.

  • filepath (str) – filepath to save the model weights to. A file suffix of ‘.h5’ or ‘.keras’ saves in HDF5 format; any other path saves in TensorFlow format.

  • kwargs – optional keyword arguments passed to the tf.keras save_weights() method; see the TensorFlow docs.
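The suffix rule for save_weights() fits in a one-line predicate. This is only a restatement of the documented rule using a hypothetical weights_format helper; TensorFlow makes the actual decision internally.

```python
def weights_format(filepath):
    """Return 'h5' or 'tf' following the documented save_weights() suffix rule."""
    # '.h5' and '.keras' suffixes select HDF5; any other path selects TensorFlow format.
    return "h5" if str(filepath).endswith((".h5", ".keras")) else "tf"


for path in ["mlp_weights.h5", "mlp_weights.keras", "checkpoints/mlp"]:
    print(path, "->", weights_format(path))
# mlp_weights.h5 -> h5
# mlp_weights.keras -> h5
# checkpoints/mlp -> tf
```

Passing the same path to load_weights() later keeps saving and loading in the same format.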

save_model(filepath, **kwargs)

Saves the entire model.

  • filepath (str) – path to the SavedModel or H5 file to save the model to.

  • kwargs – optional keyword arguments passed to the tf.keras save() method; see the TensorFlow docs.

static load(filepath, **kwargs)

Loads an entire model and returns an instantiated TFEstimatorBase with the saved model loaded into it.

  • filepath (str) – path to the SavedModel or H5 file of the saved model.

  • kwargs – Keyword arguments passed to the load_model() function. See the TensorFlow docs.

static rms_error(y_true, y_pred)

Static method for calculating root mean squared error between predictions and targets

  • y_true (Tensor) – Ground truth labels

  • y_pred (Tensor) – Predictions
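For a sample with K outputs (for example, K synthesizer parameters), the root mean squared error is the square root of the mean squared difference between prediction and target. The plain-Python sketch below computes it for one sample. It assumes the mean is taken over the output dimension, which is how Keras-style losses usually reduce; the actual static method operates on Tensors, not lists.

```python
import math


def rms_error(y_true, y_pred):
    """Root mean squared error over one sample's outputs.

    RMSE = sqrt( (1/K) * sum_k (y_pred[k] - y_true[k]) ** 2 )
    """
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("y_true and y_pred must be non-empty and of equal length")
    squared = [(p - t) ** 2 for t, p in zip(y_true, y_pred)]
    return math.sqrt(sum(squared) / len(squared))


print(rms_error([0.0, 0.0], [3.0, 4.0]))  # sqrt(12.5), about 3.5355
```

In TensorFlow, the same computation can be composed from tf.square, tf.reduce_mean (with axis=-1 for a per-sample value) and tf.sqrt.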