TFEstimatorBase
Abstract Base Class for Estimating Synthesizer Parameters using TensorFlow.
Wraps TensorFlow library calls and provides functionality for training and making
synthesizer parameter predictions. Inheriting classes must implement the build_model()
method, which is automatically called upon construction and sets up the neural network
model.
Examples
Here is an example of extending TFEstimatorBase to create a simple multi-layer perceptron network:
```python
import tensorflow as tf
from tensorflow.keras import layers
from spiegelib.estimator import TFEstimatorBase


class MLP(TFEstimatorBase):

    def __init__(self, input_shape, num_outputs, **kwargs):
        # Call the TFEstimatorBase constructor, which in turn calls build_model()
        super().__init__(input_shape, num_outputs, **kwargs)

    def build_model(self):
        # The model must be assigned to the model attribute
        self.model = tf.keras.Sequential()
        self.model.add(layers.Dense(50, input_shape=self.input_shape, activation='relu'))
        self.model.add(layers.Dense(40, activation='relu'))
        self.model.add(layers.Dense(30, activation='relu'))
        self.model.add(layers.Dense(self.num_outputs))

        self.model.compile(
            optimizer=tf.optimizers.Adam(),
            loss=TFEstimatorBase.rms_error,
            metrics=['accuracy']
        )
```
For a detailed example of running a synthesizer sound matching experiment using deep learning models like this, see the FM Sound Match Example.

class spiegelib.estimator.TFEstimatorBase(input_shape=None, num_outputs=None, weights_path='', callbacks=[])
Bases: spiegelib.estimator.estimator_base.EstimatorBase
Parameters
input_shape (tuple, optional) – Shape of the matrix that will be passed to the model input
num_outputs (int, optional) – Number of outputs the model has. When estimating synthesizer parameters, this is typically the number of parameters.
weights_path (string, optional) – If given, model weights will be loaded from this file
callbacks (list, optional) – A list of callbacks to be passed into the model's fit method

model
Attribute holding the model; see the TensorFlow docs.
Type
tf.keras.Model
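
abstract build_model()
Abstract method that must contain the model definition. Implementations define the neural network and assign it to the model attribute. It is called automatically when the estimator is constructed.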
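The contract described here (the constructor calls build_model() automatically, and subclasses must implement it) is the classic template-method pattern. The sketch below illustrates it with Python's abc module; SketchEstimatorBase, TinyModel and Incomplete are hypothetical stand-ins with no TensorFlow involved, not spiegelib's actual implementation.

```python
from abc import ABC, abstractmethod


class SketchEstimatorBase(ABC):
    """Hypothetical stand-in for TFEstimatorBase showing only the build_model contract."""

    def __init__(self, input_shape=None, num_outputs=None):
        self.input_shape = input_shape
        self.num_outputs = num_outputs
        self.model = None
        # The subclass hook runs automatically during construction
        self.build_model()

    @abstractmethod
    def build_model(self):
        """Subclasses must assign a model object to self.model."""


class TinyModel(SketchEstimatorBase):
    def build_model(self):
        # Placeholder "model" (layer sizes only), standing in for a tf.keras.Sequential
        self.model = {"layers": [50, 40, 30, self.num_outputs]}


class Incomplete(SketchEstimatorBase):
    pass  # forgets to implement build_model


estimator = TinyModel(input_shape=(10,), num_outputs=4)
print(estimator.model)  # {'layers': [50, 40, 30, 4]}

try:
    Incomplete()
except TypeError as err:
    # ABC refuses to instantiate a class with unimplemented abstract methods
    print("cannot instantiate:", err)
```

Declaring build_model with @abstractmethod means a subclass that forgets it fails at construction time rather than later, when the missing model would first be used.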

add_training_data(input, output, batch_size=64, shuffle_size=None)
Creates a tf.data.Dataset from the training data and optionally shuffles and batches it. The result is stored in the train_data attribute.
Parameters
input (np.ndarray) – Training data tensor (e.g., audio features)
output (np.ndarray) – Ground truth data (e.g., parameter values)
batch_size (int, optional) – If provided, data is grouped into batches of this size. None or 0 disables batching. Defaults to 64.
shuffle_size (int, optional) – If provided, data is shuffled using a buffer of this size. None or 0 disables shuffling. Defaults to None.
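The shuffle_size and batch_size options follow tf.data semantics, where shuffling draws uniformly from a fixed-size buffer instead of permuting the whole dataset. The pure-Python sketch below imitates that behaviour so the effect of the two settings is visible. buffered_shuffle and batch are hypothetical helpers, not spiegelib or TensorFlow functions; the "None or 0 disables" handling mirrors the parameter descriptions above.

```python
import random


def buffered_shuffle(items, buffer_size, seed=None):
    """Approximate tf.data-style shuffling with a fixed-size buffer (hypothetical helper)."""
    if not buffer_size:  # None or 0: keep the original order
        yield from items
        return
    rng = random.Random(seed)
    buffer = []
    for item in items:
        buffer.append(item)
        if len(buffer) == buffer_size:
            # Emit a random element from the full buffer; the next item refills it
            yield buffer.pop(rng.randrange(len(buffer)))
    # Drain whatever is left once the input is exhausted
    while buffer:
        yield buffer.pop(rng.randrange(len(buffer)))


def batch(items, batch_size):
    """Group items into lists of batch_size; the final batch may be smaller."""
    if not batch_size:  # None or 0: no batching, yield single elements
        yield from items
        return
    current = []
    for item in items:
        current.append(item)
        if len(current) == batch_size:
            yield current
            current = []
    if current:
        yield current


# Pair each input row with its target, as tf.data.Dataset.from_tensor_slices((input, output)) would
pairs = list(zip(range(10), range(100, 110)))
batches = list(batch(buffered_shuffle(pairs, buffer_size=4, seed=0), batch_size=3))
print([len(b) for b in batches])  # [3, 3, 3, 1]
```

Because inputs and targets are shuffled as pairs, each feature row stays aligned with its ground-truth parameters. A buffer much smaller than the dataset only shuffles locally, so a shuffle_size close to the number of training examples gives a more thorough shuffle at the cost of memory.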

add_testing_data(input, output, batch_size=64)
Creates a tf.data.Dataset from the validation data and optionally batches it. The result is stored in the test_data attribute.
Parameters
input (np.ndarray) – Validation data (e.g., audio features)
output (np.ndarray) – Ground truth data (e.g., parameter values)
batch_size (int, optional) – If provided, data is grouped into batches of this size. None or 0 disables batching. Defaults to 64.

fit(epochs=1, callbacks=[], **kwargs)
Trains the model for a fixed number of epochs on the training data, and also evaluates on the validation data if it has been added to this estimator.
Parameters
epochs (int, optional) – Number of epochs to train the model for.
callbacks (list, optional) – List of callbacks to use during training.
kwargs – Keyword arguments passed to the model's fit method. See the TensorFlow docs.
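Since fit forwards its keyword arguments to the model's fit method, it presumably wraps the Keras training loop. As a rough mental model of "a fixed number of epochs, with validation when validation data has been added", the NumPy sketch below trains a linear model by mini-batch gradient descent and records per-epoch loss and val_loss, the same key names Keras uses in History.history. fit_sketch, mse and the linear model are hypothetical illustrations; the real method trains whatever network build_model defined.

```python
import numpy as np


def mse(w, x, y):
    """Mean squared error of the linear prediction x @ w against y."""
    return float(np.mean((x @ w - y) ** 2))


def fit_sketch(train, test=None, epochs=1, batch_size=16, lr=0.1, seed=0):
    """Hypothetical stand-in for fit(): mini-batch gradient descent on a linear model."""
    x_train, y_train = train
    rng = np.random.default_rng(seed)
    w = np.zeros((x_train.shape[1], y_train.shape[1]))
    history = {"loss": []}
    if test is not None:
        history["val_loss"] = []
    for _ in range(epochs):
        order = rng.permutation(len(x_train))  # reshuffle the training rows every epoch
        for start in range(0, len(order), batch_size):
            idx = order[start:start + batch_size]
            xb, yb = x_train[idx], y_train[idx]
            # Exact gradient of np.mean((xb @ w - yb) ** 2) with respect to w
            grad = 2.0 * xb.T @ (xb @ w - yb) / yb.size
            w -= lr * grad
        history["loss"].append(mse(w, x_train, y_train))
        if test is not None:
            # Validation runs only when validation data was supplied
            history["val_loss"].append(mse(w, *test))
    return w, history


rng = np.random.default_rng(1)
true_w = rng.normal(size=(5, 2))
x = rng.normal(size=(200, 5))
y = x @ true_w
w, history = fit_sketch((x[:160], y[:160]), test=(x[160:], y[160:]), epochs=20)
print(len(history["loss"]), sorted(history))  # 20 ['loss', 'val_loss']
```

One entry per epoch is appended to each history list, and val_loss only exists when validation data was provided, which matches how the training and validation datasets are described for this estimator.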

predict(input)
Runs prediction on the input.
Parameters
input (np.ndarray) – Matrix of input data to run predictions on. Can be a single instance or a batch.
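Keras models expect a leading batch dimension, and the input_shape passed to a Keras layer (as in the MLP example) excludes that dimension. Whether predict adds the batch axis itself for a single instance is not stated here, so the hypothetical helper as_batch below shows one way a caller could normalise input before prediction, assuming input_shape is the per-instance shape.

```python
import numpy as np


def as_batch(x, input_shape):
    """Return x with a leading batch axis (hypothetical helper, not part of spiegelib)."""
    x = np.asarray(x)
    shape = tuple(input_shape)
    if x.shape == shape:
        # Single instance: add a batch axis of size 1
        return x[np.newaxis, ...]
    if x.shape[1:] == shape:
        # Already a batch of instances
        return x
    raise ValueError(f"expected {shape} or a batch of {shape}, got {x.shape}")


single = np.zeros(13)
many = np.zeros((8, 13))
print(as_batch(single, (13,)).shape)  # (1, 13)
print(as_batch(many, (13,)).shape)    # (8, 13)
```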

load_weights(filepath, **kwargs)
Loads model weights from an HDF5 or TensorFlow file.
Parameters
filepath (string) – Location of the saved model weights
kwargs – Optional keyword arguments passed to the tf load_weights method; see the TensorFlow docs.

save_weights(filepath, **kwargs)
Saves model weights to an HDF5 or TensorFlow file.
Parameters
filepath (str) – Filepath to save the model weights to. A file suffix of ‘.h5’ or ‘.keras’ saves in HDF5 format; any other path saves in TensorFlow format.
kwargs – Optional keyword arguments passed to the tf save_weights method; see the TensorFlow docs.
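The suffix rule for save_weights can be stated as a small function. weights_format is a hypothetical helper that mirrors the rule exactly as documented here and only inspects the path; TensorFlow's own check may accept additional suffixes depending on the version, and newer Keras 3 releases instead require a .weights.h5 suffix for save_weights.

```python
def weights_format(filepath):
    """Return 'h5' or 'tf' following the documented save_weights suffix rule (hypothetical helper)."""
    # '.h5' and '.keras' suffixes select HDF5; every other path selects the TensorFlow format
    return "h5" if str(filepath).endswith((".h5", ".keras")) else "tf"


for path in ["weights.h5", "weights.keras", "checkpoints/weights", "weights.ckpt"]:
    print(path, "->", weights_format(path))
```

Using an explicit '.h5' suffix keeps the weights in a single portable file, while a plain path produces TensorFlow checkpoint files; load_weights should then be given the same path.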

save_model(filepath, **kwargs)
Saves the entire model.
Parameters
filepath (str) – Path to the SavedModel or H5 file to save the model to.
kwargs – Optional keyword arguments passed to the tf save method; see the TensorFlow docs.

static load(filepath, **kwargs)
Loads an entire model and returns an instantiated TFEstimatorBase with the saved model loaded into it.
Parameters
filepath (str) – Path to the SavedModel or H5 file of the saved model.
kwargs – Keyword arguments passed to the load_model function. See the TensorFlow docs.

static rms_error(y_true, y_pred)
Static method that calculates the root mean squared error between predictions and targets.
Parameters
y_true (Tensor) – Ground truth labels
y_pred (Tensor) – Predictions
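For reference, root mean squared error is sqrt(mean((y_pred - y_true)^2)). The NumPy function below computes it over every element and is meant only for checking values by hand; it is a sketch, and the TensorFlow static method may reduce over different axes (for example per sample) than this version does.

```python
import numpy as np


def rms_error(y_true, y_pred):
    """Root mean squared error over all elements (NumPy sketch, not the TensorFlow method)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.sqrt(np.mean(np.square(y_pred - y_true))))


# Errors are 0, 0, 3, 4 -> mean of squares = (0 + 0 + 9 + 16) / 4 = 6.25 -> sqrt = 2.5
print(rms_error([0.0, 1.0, 2.0, 3.0], [0.0, 1.0, 5.0, 7.0]))  # 2.5
```

Taking the square root keeps the error in the same units as the synthesizer parameters, which makes the loss easier to interpret than plain mean squared error.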