Comments (30)
Hi, you can add any new loss function to the tflearn/objectives.py file (https://github.com/tflearn/tflearn/blob/master/tflearn/objectives.py). To use it, you just need to pass its name to the regression layer.
from tflearn.
TFLearn can now accept a custom function for the loss; it just has to match the following pattern:
def custom_objective(y_pred, y_true):
    ...  # compute the loss from y_pred and y_true
    return loss  # must be a Tensor
from tflearn.
@aymericdamien @r3db @selcouthlyBlue
Hi, in order to integrate the CTC loss, I added the following code to the tflearn/objectives.py file:
def ctc_loss(y_pred, y_true):
    with tf.name_scope(None):
        return tf.nn.ctc_loss(y_pred, y_true, 320)
where 320 is the maximum length of the input sequence.
However, this raises a TypeError ("Expected labels (first argument) to be a SparseTensor").
Why is y_pred not a SparseTensor? How can I resolve this problem?
from tflearn.
Hi, I want to know where to add this code:
def custom_objective(y_pred, y_true):
Should it go in tflearn/objectives.py? Thanks!
from tflearn.
You can add the code anywhere, for example:
def my_objective(y_pred, y_true):
    with tf.name_scope(None):
        return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y_pred, labels=y_true))
And then call it like this:
regression(network, optimizer='adam', learning_rate=0.01, loss=my_objective, name='target')
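To make it concrete what my_objective returns, here is a plain-Python sketch of the same arithmetic (softmax cross-entropy per example, then the mean over the batch). It is only an illustration, not TensorFlow code; the names softmax_cross_entropy and my_objective_plain are made up for this example.

```python
import math


def softmax_cross_entropy(logits, labels):
    """Cross-entropy between labels and softmax(logits) for one example."""
    # Subtract the max logit before exponentiating for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    log_probs = [z - log_sum_exp for z in logits]
    return -sum(t * lp for t, lp in zip(labels, log_probs))


def my_objective_plain(y_pred, y_true):
    """Mean of the per-example losses, mirroring tf.reduce_mean over the batch."""
    losses = [softmax_cross_entropy(p, t) for p, t in zip(y_pred, y_true)]
    return sum(losses) / len(losses)


# Two classes with equal logits: the loss is log(2) for each example.
loss = my_objective_plain([[0.0, 0.0], [0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]])
```

The point is that the custom objective has to reduce to a single scalar loss, which is what the regression layer then minimizes.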
from tflearn.
@Holded care to post a gist of your code that'll reproduce that?
from tflearn.
The code is as follows, thank you @selcouthlyBlue
https://github.com/Holded/TFLearn_CTC-Loss
from tflearn.
Where did you get the speech_data module?
from tflearn.
It's a Python file; I've added it to the repository linked above. @selcouthlyBlue
from tflearn.
Apparently, your labels are not sparse. You must first convert them to sparse labels. Here's how I do it:
import numpy as np

def convert_to_sparse(labels, dtype=np.int32):
    indices = []
    values = []
    # One (sequence index, position) pair per label.
    for n, seq in enumerate(labels):
        indices.extend(zip([n] * len(seq), range(len(seq))))
        values.extend(seq)
    indices = np.asarray(indices, dtype=dtype)
    values = np.asarray(values, dtype=dtype)
    # Dense shape: number of sequences by length of the longest sequence.
    shape = np.asarray([len(labels), np.asarray(indices).max(0)[1] + 1], dtype=dtype)
    return indices, values, shape
Then try feeding that to ctc_loss.
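For anyone unsure what the three return values look like, here is a small NumPy-free restatement of the same idea on a toy batch. This is a sketch for illustration only; convert_to_sparse_plain is a hypothetical name, and the toy labels are made up.

```python
def convert_to_sparse_plain(labels):
    """Return (indices, values, shape) for a list of variable-length label sequences."""
    indices, values = [], []
    for n, seq in enumerate(labels):
        for t, label in enumerate(seq):
            indices.append((n, t))  # (sequence index, position within the sequence)
            values.append(label)
    # Dense shape: number of sequences by length of the longest sequence.
    shape = (len(labels), max(len(seq) for seq in labels))
    return indices, values, shape


indices, values, shape = convert_to_sparse_plain([[1, 2, 3], [4, 5]])
# indices -> [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1)]
# values  -> [1, 2, 3, 4, 5]
# shape   -> (2, 3)
```

One thing to watch, as far as I know: tf.SparseTensor documents indices and dense_shape as int64, so the np.int32 default above may need to be np.int64 for those two arrays.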
from tflearn.
I've added the convert_to_sparse function to speech_data and now convert the labels in demo.py.
However, the same error is raised. The modified files are at the link above. Where did I go wrong?
Thank you @selcouthlyBlue
from tflearn.
I can't really run your code since I don't have the data.
Try putting the result of convert_to_sparse in tf.SparseTensorValue()
from tflearn.
I've now also put the result of convert_to_sparse in tf.SparseTensorValue().
However, the same error is raised.
I've uploaded four data files so you can test the code. Thank you very much! @selcouthlyBlue
from tflearn.
Hi, I've converted the label variable Y to a SparseTensor and verified that the conversion itself works.
However, the same error is still raised. The modified files are here:
https://github.com/Holded/TFLearn_CTC-Loss
Where did I go wrong? Thank you for your help! @selcouthlyBlue
from tflearn.
I would like to apologize for the late reply. I was on a small vacation during the holidays.
I've reproduced the error in my own project, though I'm encountering a different error using your code. Now on to figuring out how to fix it. I'm not sure about this, but I think tflearn's DNN assumes that both X and Y are dense.
from tflearn.
One other option is to use tflearn's Trainer, which lets you supply your own loss function and placeholders (there's a sparse_placeholder for sparse labels).
from tflearn.
I created a workaround that avoids the SparseTensor errors. I used TFLearn's Trainer class to be able to use my own model and sparse_placeholders. However, I'm encountering a recursion error when I try to start the training.
To reproduce it, simply clone this repository and run main/train_using_tflearn_trainer.py. @Holded
from tflearn.
Sorry for the late reply. I have a final exam in three days, so I can't verify your solution until after that.
Thanks for your help and kind reply!
from tflearn.
You're welcome! I wish ya good luck in your exam :)
from tflearn.
I've addressed the recursion error. Apparently, the label_error_rate metric is causing the error, so I removed that one. Now, I'm encountering this error:
  File "...Anaconda3\lib\threading.py", line 914, in _bootstrap_inner
    self.run()
  File "...\Anaconda3\lib\threading.py", line 862, in run
    self._target(*self._args, **self._kwargs)
  File "...Anaconda3\lib\site-packages\tflearn\data_flow.py", line 187, in fill_feed_dict_queue
    data = self.retrieve_data(batch_ids)
  File "...Anaconda3\lib\site-packages\tflearn\data_flow.py", line 222, in retrieve_data
    utils.slice_array(self.feed_dict[key], batch_ids)
  File "...Anaconda3\lib\site-packages\tflearn\utils.py", line 187, in slice_array
    return X[start]
TypeError: only integer scalar arrays can be converted to a scalar index
from tflearn.
I managed to address the error by replacing the sparse_placeholder with this:
label_indices = tf.placeholder(tf.int64)
label_shape = tf.placeholder(tf.int64)
label_values = tf.placeholder(tf.int64)
Y = tf.SparseTensor(indices=label_indices, dense_shape=label_shape, values=label_values)
Y = tf.cast(Y, dtype=tf.int32)
So the feed dict now looks like this:
trainer.fit(feed_dicts=
            {X: x_train,
             label_indices: y_train[0],
             label_values: y_train[1],
             label_shape: y_train[2],
             seq_lens: dataset_utils.get_seq_lens(x_train)},
            ...)
But now I'm encountering this error:
Traceback (most recent call last):
  File ".../Optimized_OCR/main/train_using_tflearn_trainer.py", line 58, in <module>
    tf.app.run(main=main)
  File "...\Anaconda3\lib\site-packages\tensorflow\python\platform\app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File ".../Optimized_OCR/main/train_using_tflearn_trainer.py", line 54, in main
    n_epoch=1)
  File "...\Anaconda3\lib\site-packages\tflearn\helpers\trainer.py", line 338, in fit
    show_metric)
  File "...\Anaconda3\lib\site-packages\tflearn\helpers\trainer.py", line 817, in _train
    feed_batch)
  File "...\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 889, in run
    run_metadata_ptr)
  File "...\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1120, in _run
    feed_dict_tensor, options, run_metadata)
  File "...\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1317, in _do_run
    options, run_metadata)
  File "...\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1336, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Retval[0] does not have value
from tflearn.
Sorry for the late response.
I've reproduced the error on my computer with your project. But I'm encountering a different error using your code.
Number of images read: 0/2
Done reading images
WARNING:tensorflow:<tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl.LSTMCell object at 0x7f8f794d2790>: Using a concatenated state is slower and will soon be deprecated. Use state_is_tuple=True.
WARNING:tensorflow:<tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl.LSTMCell object at 0x7f8f794d2790>: The input_size parameter is deprecated.
WARNING:tensorflow:<tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl.LSTMCell object at 0x7f8f794d20d0>: Using a concatenated state is slower and will soon be deprecated. Use state_is_tuple=True.
WARNING:tensorflow:<tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl.LSTMCell object at 0x7f8f794d20d0>: The input_size parameter is deprecated.
WARNING:tensorflow:<tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl.LSTMCell object at 0x7f8f794d2290>: Using a concatenated state is slower and will soon be deprecated. Use state_is_tuple=True.
WARNING:tensorflow:<tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl.LSTMCell object at 0x7f8f794d2290>: The input_size parameter is deprecated.
WARNING:tensorflow:<tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl.LSTMCell object at 0x7f8f794fd550>: Using a concatenated state is slower and will soon be deprecated. Use state_is_tuple=True.
WARNING:tensorflow:<tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl.LSTMCell object at 0x7f8f794fd550>: The input_size parameter is deprecated.
Traceback (most recent call last):
  File "train_using_tflearn_trainer.py", line 59, in <module>
    tf.app.run(main=main)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py", line 44, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "train_using_tflearn_trainer.py", line 46, in main
    net = dnn(X)
  File "train_using_tflearn_trainer.py", line 33, in dnn
    bidirectional_grid_rnn = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, x, dtype=tf.float32)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 350, in bidirectional_dynamic_rnn
    time_major=time_major, scope=fw_scope)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 545, in dynamic_rnn
    dtype=dtype)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 712, in _dynamic_rnn_loop
    swap_memory=swap_memory)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 2615, in while_loop
    result = context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 2448, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 2398, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 697, in _time_step
    (output, new_state) = call_cell()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 683, in <lambda>
    call_cell = lambda: cell(input_t, state)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py", line 183, in __call__
    if inputs is not None and sum(inputs.get_shape().as_list()) > 0 and len(
TypeError: unsupported operand type(s) for +: 'int' and 'NoneType'
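A note on that last frame, in case it helps: the failing line sums inputs.get_shape().as_list(), and as_list() reports None for any dimension that is not known statically (an unspecified batch size, for instance). The plain-Python sketch below reproduces the same TypeError without TensorFlow. The shape values and the helper name fully_defined are hypothetical, chosen only for illustration.

```python
# A static shape as as_list() might report it when the batch size is unknown.
# The concrete numbers are made up for this example.
static_shape = [None, 320, 64]

try:
    sum(static_shape)  # 0 + None fails on the very first element
    message = ""
except TypeError as err:
    message = str(err)


def fully_defined(shape):
    """True only when every dimension is known, so summing it is safe."""
    return all(dim is not None for dim in shape)
```

So one thing worth checking is whether the tensor passed to the cell has any unknown dimensions at graph-construction time.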
from tflearn.
I guess I should mention that I'm using Tensorflow 1.4
from tflearn.
Try replacing the GridLSTMCells with BasicLSTMCells from tf contrib
from tflearn.
I can't update TensorFlow because I don't have permission to modify the server.
I'm now encountering this error:
Number of images read: 0/2
Done reading images
Traceback (most recent call last):
  File "train_using_tflearn_trainer.py", line 61, in <module>
    tf.app.run(main=main)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py", line 44, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "train_using_tflearn_trainer.py", line 48, in main
    net = dnn(X)
  File "train_using_tflearn_trainer.py", line 35, in dnn
    bidirectional_grid_rnn = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, x, dtype=tf.float32)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 363, in bidirectional_dynamic_rnn
    seq_dim=time_dim, batch_dim=batch_dim)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/array_ops.py", line 2331, in reverse_sequence
    name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 2863, in reverse_sequence
    batch_dim=batch_dim, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/op_def_library.py", line 509, in apply_op
    (input_name, err))
ValueError: Tried to convert 'seq_lengths' to a tensor and failed. Error: None values not supported.
from tflearn.
I take it you encountered that error after replacing the GridLSTMCells? Can you print the tensorflow version you're using?
As for that error, try to remove the seq_lens placeholder and set the ctc_loss sequence_length parameter to 320. If that doesn't work, try to use the trainer for your own project and tell me how it goes.
from tflearn.
I think one other way is to have the ctc_loss accept the dense labels and do the conversion internally. I made a feature request for such an op. So far, I've successfully converted a dense tensor into a sparse one and the issue now lies in ctc_loss.
from tflearn.
@selcouthlyBlue Thank you for your help! Using my own placeholders seems more complicated because it requires changing the code. I agree with you that ctc_loss should do the conversion internally.
from tflearn.
You're welcome! I'm encountering the same problem myself for my project. Maybe you can help me with the feature request :))
from tflearn.
This should work:
def ctc_loss(y_pred, y_true):
    with tf.name_scope("CTCLoss"):
        # Keep only the non-zero entries of the dense labels (0 is treated as padding).
        indices = tf.where(tf.not_equal(y_true, tf.constant(0, dtype=y_true.dtype)))
        values = tf.gather_nd(y_true, indices)
        shape = tf.shape(y_true, out_type=tf.int64)
        sparse_y_true = tf.SparseTensor(
            indices,
            values,
            shape
        )
        return tf.nn.ctc_loss(inputs=y_pred, labels=sparse_y_true, sequence_length=320)
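To show what the tf.where / tf.gather_nd step produces, here is a plain-Python sketch of the same dense-to-sparse conversion on a toy zero-padded batch. It is an illustration under the assumption that 0 is used only as padding; dense_to_sparse_plain is a hypothetical name and the label values are made up.

```python
def dense_to_sparse_plain(y_true, pad=0):
    """Collect the positions and values of all non-padding labels in a dense batch."""
    indices, values = [], []
    for n, row in enumerate(y_true):
        for t, label in enumerate(row):
            if label != pad:  # mirrors tf.not_equal(y_true, 0)
                indices.append((n, t))
                values.append(label)
    shape = (len(y_true), len(y_true[0]))  # mirrors tf.shape(y_true)
    return indices, values, shape


dense = [[3, 1, 2, 0],
         [5, 0, 0, 0]]
indices, values, shape = dense_to_sparse_plain(dense)
# indices -> [(0, 0), (0, 1), (0, 2), (1, 0)]
# values  -> [3, 1, 2, 5]
# shape   -> (2, 4)
```

Two caveats. Since every 0 is dropped, class index 0 cannot also be a real label with this approach. Also, as far as I know, tf.nn.ctc_loss documents sequence_length as a 1-D int32 vector with one entry per batch element, so the scalar 320 may need to be expanded to that shape.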
from tflearn.