@hardmaru I've tried following your super helpful Sketch_RNN_TF_To_JS_Tutorial notebook, but ran into errors using a model I've just trained.
With the version of sketch-rnn-js in this repo I get this error:
```
numjs.js:5040 Uncaught Error: all the input arrays must have same number of dimensions
    at new ValueError (numjs.js:5040)
    at Object.concatenate (numjs.js:6156)
    at new LSTMCell (sketch_rnn.js:383)
    at load_model (sketch_rnn.js:584)
    at new SketchRNN (sketch_rnn.js:1089)
    at sketch_rnn.js:345
    at XMLHttpRequest.xobj.onreadystatechange (sketch_rnn.js:275)
```
Using your latest magenta-js version of sketch-rnn, I get this error instead:

```
magentasketch.js:23262 Uncaught (in promise) Error: Constructing tensor of shape (NaN) should match the length of values (1024)
    at Object.assert (magentasketch.js:23262)
    at new Tensor (magentasketch.js:22128)
    at Function.Tensor.make (magentasketch.js:22143)
    at tensor (magentasketch.js:20430)
    at Object.tensor2d (magentasketch.js:20469)
    at SketchRNN.instantiateFromJSON (magentasketch.js:42407)
    at SketchRNN.<anonymous> (magentasketch.js:42430)
    at step (magentasketch.js:42371)
    at Object.next (magentasketch.js:42352)
    at fulfilled (magentasketch.js:42343)
```
More info on training this model:
- it uses giraffe.npz from quickdraw_dataset/sketchrnn
- I've used the following command to train:
```
sketch_rnn_train --data_dir=datasets\quickdraw_dataset --hparams="data_set=[giraffe.npz],num_steps=200000,conditional=0,dec_rnn_size=1024"
```
- the training environment is Windows 10 with Python 3.6.5, using `magenta-gpu` installed via pip: magenta version 0.3.12, tensorflow version 1.10.0
- the Sketch RNN TF to JS python script has been minimally tweaked: it simply adds a few print statements to check progress/errors and skips drawing the SVG:
```python
# import the required libraries
import numpy as np
import time
import random
import codecs
import collections
import os
import math
import json
import tensorflow as tf
from six.moves import xrange
# libraries required for visualisation:
from IPython.display import SVG, display
import svgwrite # conda install -c omnia svgwrite=1.1.6
import PIL
from PIL import Image
import matplotlib.pyplot as plt
from magenta.models.sketch_rnn.sketch_rnn_train import *
from magenta.models.sketch_rnn.model import *
from magenta.models.sketch_rnn.utils import *
from magenta.models.sketch_rnn.rnn import *
print('import complete')
# set numpy output to something sensible
np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)
# little function that displays vector images and saves them to .svg
def draw_strokes(data, factor=0.2, svg_filename='sample.svg'):
  tf.gfile.MakeDirs(os.path.dirname(svg_filename))
  min_x, max_x, min_y, max_y = get_bounds(data, factor)
  dims = (50 + max_x - min_x, 50 + max_y - min_y)
  dwg = svgwrite.Drawing(svg_filename, size=dims)
  dwg.add(dwg.rect(insert=(0, 0), size=dims, fill='white'))
  lift_pen = 1
  abs_x = 25 - min_x
  abs_y = 25 - min_y
  p = "M%s,%s " % (abs_x, abs_y)
  command = "m"
  for i in xrange(len(data)):
    if (lift_pen == 1):
      command = "m"
    elif (command != "l"):
      command = "l"
    else:
      command = ""
    x = float(data[i, 0]) / factor
    y = float(data[i, 1]) / factor
    lift_pen = data[i, 2]
    p += command + str(x) + "," + str(y) + " "
  the_color = "black"
  stroke_width = 1
  dwg.add(dwg.path(p).stroke(the_color, stroke_width).fill("none"))
  dwg.save()
  display(SVG(dwg.tostring()))
# generate a 2D grid of many vector drawings
def make_grid_svg(s_list, grid_space=10.0, grid_space_x=16.0):
  def get_start_and_end(x):
    x = np.array(x)
    x = x[:, 0:2]
    x_start = x[0]
    x_end = x.sum(axis=0)
    x = x.cumsum(axis=0)
    x_max = x.max(axis=0)
    x_min = x.min(axis=0)
    center_loc = (x_max + x_min) * 0.5
    return x_start - center_loc, x_end
  x_pos = 0.0
  y_pos = 0.0
  result = [[x_pos, y_pos, 1]]
  for sample in s_list:
    s = sample[0]
    grid_loc = sample[1]
    grid_y = grid_loc[0] * grid_space + grid_space * 0.5
    grid_x = grid_loc[1] * grid_space_x + grid_space_x * 0.5
    start_loc, delta_pos = get_start_and_end(s)
    loc_x = start_loc[0]
    loc_y = start_loc[1]
    new_x_pos = grid_x + loc_x
    new_y_pos = grid_y + loc_y
    result.append([new_x_pos - x_pos, new_y_pos - y_pos, 0])
    result += s.tolist()
    result[-1][2] = 1
    x_pos = new_x_pos + delta_pos[0]
    y_pos = new_y_pos + delta_pos[1]
  return np.array(result)
# TODO: make these args
data_dir = './datasets/quickdraw_dataset'
model_dir = './checkpoints'
[train_set, valid_set, test_set, hps_model, eval_hps_model, sample_hps_model] = load_env(data_dir, model_dir)
print('loaded env')
[hps_model, eval_hps_model, sample_hps_model] = load_model(model_dir)
print('loaded model',model_dir)
# construct the sketch-rnn model here:
reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)
print('preparing interactive session')
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
print('interactive session ready')
def decode(z_input=None, draw_mode=True, temperature=0.1, factor=0.2):
  z = None
  if z_input is not None:
    z = [z_input]
  sample_strokes, m = sample(sess, sample_model, seq_len=eval_model.hps.max_seq_len, temperature=temperature, z=z)
  strokes = to_normal_strokes(sample_strokes)
  if draw_mode:
    draw_strokes(strokes, factor)
  return strokes
print('loading checkpoints')
# loads the weights from checkpoint into our model
load_checkpoint(sess, model_dir)
print('loaded checkpoints')
# randomly unconditionally generate 10 examples
N = 10
reconstructions = []
for i in range(N):
  reconstructions.append([decode(temperature=0.5, draw_mode=False), [0, i]])
# stroke_grid = make_grid_svg(reconstructions)
# draw_strokes(stroke_grid)
def get_model_params():
  # get trainable params.
  model_names = []
  model_params = []
  model_shapes = []
  with sess.as_default():
    t_vars = tf.trainable_variables()
    for var in t_vars:
      param_name = var.name
      p = sess.run(var)
      model_names.append(param_name)
      params = p
      model_params.append(params)
      model_shapes.append(p.shape)
  return model_params, model_shapes, model_names
def quantize_params(params, max_weight=10.0, factor=32767):
  result = []
  max_weight = np.abs(max_weight)
  for p in params:
    r = np.array(p)
    r /= max_weight
    r[r > 1.0] = 1.0
    r[r < -1.0] = -1.0
    result.append(np.round(r * factor).flatten().astype(np.int).tolist())
  return result
model_params, model_shapes, model_names = get_model_params()
print('got model params')
print('model_names',model_names)
# scale factor converts "model-coordinates" to "pixel coordinates" for your JS canvas demo later on.
# the larger it is, the larger your drawings (in pixel space) will be.
# I recommend setting this to 100.0 and iterating the value in the json file later on when you build the JS part.
scale_factor = 200.0
metainfo = {"mode":2,"version":6,"max_seq_len":train_set.max_seq_length,"name":"custom","scale_factor":scale_factor}
model_params_quantized = quantize_params(model_params)
print('quantized params')
model_blob = [metainfo, model_shapes, model_params_quantized]
# TODO: add filename arg
with open("giraffe.gen.full.json", 'w') as outfile:
  json.dump(model_blob, outfile, separators=(',', ':'))
print('complete!')
```
A few more details:
- The tensorflow graph looks different. The Python notebook's graph looks like this:

```
['vector_rnn/RNN/output_w:0',
 'vector_rnn/RNN/output_b:0',
 'vector_rnn/RNN/LSTMCell/W_xh:0',
 'vector_rnn/RNN/LSTMCell/W_hh:0',
 'vector_rnn/RNN/LSTMCell/bias:0']
```
however the one I got looks like this:

```
['vector_rnn/ENC_RNN/fw/LSTMCell/W_xh:0',
 'vector_rnn/ENC_RNN/fw/LSTMCell/W_hh:0',
 'vector_rnn/ENC_RNN/fw/LSTMCell/bias:0',
 'vector_rnn/ENC_RNN/bw/LSTMCell/W_xh:0',
 'vector_rnn/ENC_RNN/bw/LSTMCell/W_hh:0',
 'vector_rnn/ENC_RNN/bw/LSTMCell/bias:0',
 'vector_rnn/ENC_RNN_mu/super_linear_w:0',
 'vector_rnn/ENC_RNN_mu/super_linear_b:0',
 'vector_rnn/ENC_RNN_sigma/super_linear_w:0',
 'vector_rnn/ENC_RNN_sigma/super_linear_b:0',
 'vector_rnn/linear/super_linear_w:0',
 'vector_rnn/linear/super_linear_b:0',
 'vector_rnn/RNN/output_w:0',
 'vector_rnn/RNN/output_b:0',
 'vector_rnn/RNN/LSTMCell/W_xh:0',
 'vector_rnn/RNN/LSTMCell/W_hh:0',
 'vector_rnn/RNN/LSTMCell/bias:0']
```
- The shape list that follows the model meta info in the JSON differs from the existing pre-trained sketch-rnn models (see the sketch after this list for how I print it). For example, hand.gen.json has this structure:

```
[[1024,123],[123],[5,4096],[1024,4096],[4096]]
```

while the freshly generated giraffe.gen.json has this structure:

```
[[5,1024],[256,1024],[1024],[5,1024],[256,1024],[1024],[512,128],[128],[512,128],[128],[128,1024],[1024],[512,123],[123],[133,2048],[512,2048],[2048]]
```
- I have uploaded the model here
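For reference, this is the quick sketch I use to print that shape list from a model JSON (it just assumes the `[metainfo, shapes, params]` layout that the script above writes out; filenames are from my setup):

```python
import json

# Minimal sketch: print the shape list stored in a sketch-rnn model JSON,
# assuming the [metainfo, shapes, params] layout the script above writes.
def print_model_shapes(path):
  with open(path) as f:
    metainfo, shapes, _params = json.load(f)
  print(metainfo.get("name"), shapes)

print_model_shapes("giraffe.gen.full.json")
print_model_shapes("hand.gen.json")  # assuming a local copy of the pre-trained model
```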
Somehow I'm generating a different network? Maybe I just need the RNN and LSTM cell layers with their weights and biases, and not the whole encoder? If so, how do I do that?
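In case it clarifies what I mean, here is an untested sketch of that filtering, reusing the session from the script above and assuming the five decoder names from the notebook's variable list are exactly what sketch-rnn-js expects:

```python
# Untested sketch: export only the decoder tensors, assuming the five names
# below (taken from the notebook's variable list) are what the JS loader
# expects, in this order.
DECODER_VARS = [
  'vector_rnn/RNN/output_w:0',
  'vector_rnn/RNN/output_b:0',
  'vector_rnn/RNN/LSTMCell/W_xh:0',
  'vector_rnn/RNN/LSTMCell/W_hh:0',
  'vector_rnn/RNN/LSTMCell/bias:0',
]

def get_decoder_params():
  model_names, model_params, model_shapes = [], [], []
  with sess.as_default():
    by_name = {v.name: v for v in tf.trainable_variables()}
    for name in DECODER_VARS:  # preserve the expected ordering
      p = sess.run(by_name[name])
      model_names.append(name)
      model_params.append(p)
      model_shapes.append(p.shape)
  return model_params, model_shapes, model_names
```

Would feeding these into quantize_params() and the JSON dump above be enough?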
Otherwise, what am I missing or doing wrong?
How can I train a sketch-rnn model using a dataset from quickdraw?
Thank you,
George