
Comments (30)

aymericdamien avatar aymericdamien commented on May 17, 2024 2

Hi, you can add any new loss function to the tflearn/objectives.py file (https://github.com/tflearn/tflearn/blob/master/tflearn/objectives.py). To use it, just pass its name to the regression layer.

from tflearn.

aymericdamien avatar aymericdamien commented on May 17, 2024 1

TFLearn can now accept a custom function as the loss; it just has to match the following pattern:

def custom_objective(y_pred, y_true):
    ...
    return loss_tensor  # must return a Tensor


Holded avatar Holded commented on May 17, 2024 1

@aymericdamien @r3db @selcouthlyBlue
Hi, to integrate the CTC loss, I added the following code to the tflearn/objectives.py file:

def ctc_loss(y_pred, y_true):
    with tf.name_scope(None):
        return tf.nn.ctc_loss(y_pred, y_true, 320)

where 320 is the maximum length of the input sequence.

However, this raises a TypeError ("Expected labels(first argument) to be a SparseTensor").
Why is y_pred not a SparseTensor? How can I resolve this problem?


liuhuiwisdom avatar liuhuiwisdom commented on May 17, 2024

Hi, I'd like to know where to add this code:
def custom_objective(y_pred, y_true)
Should it be added to tflearn/objectives.py? Thanks!


r3db avatar r3db commented on May 17, 2024

@liuhuiwisdom

You can add the code anywhere, for example:

def my_objective(y_pred, y_true):
    with tf.name_scope(None):
        return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y_pred, labels=y_true))

And then call it like this:

regression(network, optimizer='adam', learning_rate=0.01, loss=my_objective, name='target')
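
For intuition, the value this objective computes can be reproduced outside TensorFlow. The NumPy sketch below is only an illustration: the helper name mean_softmax_cross_entropy is made up here and is not part of TFLearn or TensorFlow, and it assumes y_true holds one-hot rows and y_pred holds unnormalized logits.

```python
import numpy as np

def mean_softmax_cross_entropy(y_pred, y_true):
    """NumPy stand-in for reduce_mean(softmax_cross_entropy_with_logits(...))."""
    y_pred = np.asarray(y_pred, dtype=np.float64)
    y_true = np.asarray(y_true, dtype=np.float64)
    # Subtract the row maximum before exponentiating, for numerical stability.
    shifted = y_pred - y_pred.max(axis=1, keepdims=True)
    log_softmax = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Cross-entropy per example, then the mean over the batch.
    per_example = -(y_true * log_softmax).sum(axis=1)
    return per_example.mean()

# Hypothetical toy batch: uniform logits over two classes.
logits = np.zeros((2, 2))
targets = np.array([[1.0, 0.0], [0.0, 1.0]])
loss = mean_softmax_cross_entropy(logits, targets)  # equals log(2) for uniform logits
```

Whatever the loss is, the function passed to regression() should end in a reduction like this so that it returns a single scalar per batch.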


selcouthlyBlue avatar selcouthlyBlue commented on May 17, 2024

@Holded care to post a gist of your code that'll reproduce that?


Holded avatar Holded commented on May 17, 2024

The code is as follows, thank you @selcouthlyBlue
https://github.com/Holded/TFLearn_CTC-Loss


selcouthlyBlue avatar selcouthlyBlue commented on May 17, 2024

Where did you get the speech_data module?


Holded avatar Holded commented on May 17, 2024

It's a Python file; I've added it to the repository linked above. @selcouthlyBlue


selcouthlyBlue avatar selcouthlyBlue commented on May 17, 2024

Apparently, your labels are not sparse. You must first convert them to sparse labels. Here's how I do it:

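import numpy as np

def convert_to_sparse(labels, dtype=np.int32):
    """Convert a list of label sequences to (indices, values, shape)."""
    indices = []
    values = []

    # One (row, column) pair per label: row = sequence number, column = position.
    for n, seq in enumerate(labels):
        indices.extend(zip([n] * len(seq), range(len(seq))))
        values.extend(seq)

    indices = np.asarray(indices, dtype=dtype)
    values = np.asarray(values, dtype=dtype)
    # Dense shape: number of sequences x length of the longest sequence.
    shape = np.asarray([len(labels), indices.max(0)[1] + 1], dtype=dtype)

    return indices, values, shape

Then try feeding that to ctc_loss.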
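
As a quick check of what this helper returns, here is a small self-contained run on a made-up batch of two label sequences (the label values are arbitrary). The function body is repeated so the snippet runs on its own.

```python
import numpy as np

def convert_to_sparse(labels, dtype=np.int32):
    indices = []
    values = []
    for n, seq in enumerate(labels):
        # (row, column) pairs: row is the sequence number, column the position.
        indices.extend(zip([n] * len(seq), range(len(seq))))
        values.extend(seq)
    indices = np.asarray(indices, dtype=dtype)
    values = np.asarray(values, dtype=dtype)
    shape = np.asarray([len(labels), indices.max(0)[1] + 1], dtype=dtype)
    return indices, values, shape

# Hypothetical labels: two sequences of different lengths.
labels = [[1, 2, 3], [4, 5]]
indices, values, shape = convert_to_sparse(labels)

print(indices.tolist())  # [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]]
print(values.tolist())   # [1, 2, 3, 4, 5]
print(shape.tolist())    # [2, 3]
```

The three arrays come back in the same order as the fields of tf.SparseTensorValue (indices, values, dense_shape).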


Holded avatar Holded commented on May 17, 2024

I've added the convert_to_sparse function to speech_data and now convert the labels in demo.py.
However, the same error is raised. The modified files are at the link above. Where did I go wrong?
Thank you @selcouthlyBlue


selcouthlyBlue avatar selcouthlyBlue commented on May 17, 2024

I can't really run your code since I don't have the data.

Try putting the result of convert_to_sparse in tf.SparseTensorValue()


Holded avatar Holded commented on May 17, 2024

I've now also wrapped the result of convert_to_sparse in tf.SparseTensorValue().
However, the same error is raised.
I've uploaded four data files so you can test the code. Thank you very much! @selcouthlyBlue


Holded avatar Holded commented on May 17, 2024

Hi, I've changed the label variable Y to a SparseTensor and verified that the conversion itself works.
However, the same error is still raised. The modified files are here:
https://github.com/Holded/TFLearn_CTC-Loss
Where did I go wrong? Thank you for your help! @selcouthlyBlue


selcouthlyBlue avatar selcouthlyBlue commented on May 17, 2024

I would like to apologize for the late reply. I was on a small vacation during the holidays.

I've reproduced the error in my own project, although I get a different error when running your code. Now on to figuring out how to fix it. I'm not sure about this, but I think tflearn's DNN assumes that both X and Y are dense.


selcouthlyBlue avatar selcouthlyBlue commented on May 17, 2024

One other option is to use tflearn's trainer to use custom loss functions, and placeholders (there's a sparse_placeholder for sparse labels).


selcouthlyBlue avatar selcouthlyBlue commented on May 17, 2024

I created a workaround that avoids the SparseTensor errors. I used TFLearn's Trainer class to be able to use my own model and sparse_placeholders. However, I'm encountering a recursion error when I try to start the training.

To reproduce it, simply clone this repository and run main/train_using_tflearn_trainer.py. @Holded


Holded avatar Holded commented on May 17, 2024

Sorry for the late reply. I have a final exam in three days, so I won't be able to verify your solution until then.
Thanks for your help and kind reply!


selcouthlyBlue avatar selcouthlyBlue commented on May 17, 2024

You're welcome! I wish ya good luck in your exam :)


selcouthlyBlue avatar selcouthlyBlue commented on May 17, 2024

I've addressed the recursion error. Apparently, the label_error_rate metric is causing the error, so I removed that one. Now, I'm encountering this error:

  File "...Anaconda3\lib\threading.py", line 914, in _bootstrap_inner
    self.run()
  File "...\Anaconda3\lib\threading.py", line 862, in run
    self._target(*self._args, **self._kwargs)
  File "...Anaconda3\lib\site-packages\tflearn\data_flow.py", line 187, in fill_feed_dict_queue
    data = self.retrieve_data(batch_ids)
  File "...Anaconda3\lib\site-packages\tflearn\data_flow.py", line 222, in retrieve_data
    utils.slice_array(self.feed_dict[key], batch_ids)
  File "...Anaconda3\lib\site-packages\tflearn\utils.py", line 187, in slice_array
    return X[start]
TypeError: only integer scalar arrays can be converted to a scalar index


selcouthlyBlue avatar selcouthlyBlue commented on May 17, 2024

I managed to address the error by replacing the sparse_placeholder with this:

    label_indices = tf.placeholder(tf.int64)
    label_shape = tf.placeholder(tf.int64)
    label_values = tf.placeholder(tf.int64)
    Y = tf.SparseTensor(indices=label_indices, dense_shape=label_shape, values=label_values)
    Y = tf.cast(Y, dtype=tf.int32)

So the feed dict now looks like this:

    trainer.fit(feed_dicts={X: x_train,
                            label_indices: y_train[0],
                            label_values: y_train[1],
                            label_shape: y_train[2],
                            seq_lens: dataset_utils.get_seq_lens(x_train)},
                ...)

But now I'm encountering this error:

    Traceback (most recent call last):
      File ".../Optimized_OCR/main/train_using_tflearn_trainer.py", line 58, in <module>
        tf.app.run(main=main)
      File "...\Anaconda3\lib\site-packages\tensorflow\python\platform\app.py", line 48, in run
        _sys.exit(main(_sys.argv[:1] + flags_passthrough))
      File ".../Optimized_OCR/main/train_using_tflearn_trainer.py", line 54, in main
        n_epoch=1)
      File "...\Anaconda3\lib\site-packages\tflearn\helpers\trainer.py", line 338, in fit
        show_metric)
      File "...\Anaconda3\lib\site-packages\tflearn\helpers\trainer.py", line 817, in _train
        feed_batch)
      File "...\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 889, in run
        run_metadata_ptr)
      File "...\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1120, in _run
        feed_dict_tensor, options, run_metadata)
      File "...\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1317, in _do_run
        options, run_metadata)
      File "...\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1336, in _do_call
        raise type(e)(node_def, op, message)
    tensorflow.python.framework.errors_impl.InvalidArgumentError: Retval[0] does not have value


Holded avatar Holded commented on May 17, 2024

Sorry for the late response.
I tried to reproduce the error on my computer with your project, but I'm encountering a different error with your code:

    Number of images read: 0/2
    Done reading images
    WARNING:tensorflow:<tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl.LSTMCell object at 0x7f8f794d2790>: Using a concatenated state is slower and will soon be deprecated. Use state_is_tuple=True.
    WARNING:tensorflow:<tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl.LSTMCell object at 0x7f8f794d2790>: The input_size parameter is deprecated.
    WARNING:tensorflow:<tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl.LSTMCell object at 0x7f8f794d20d0>: Using a concatenated state is slower and will soon be deprecated. Use state_is_tuple=True.
    WARNING:tensorflow:<tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl.LSTMCell object at 0x7f8f794d20d0>: The input_size parameter is deprecated.
    WARNING:tensorflow:<tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl.LSTMCell object at 0x7f8f794d2290>: Using a concatenated state is slower and will soon be deprecated. Use state_is_tuple=True.
    WARNING:tensorflow:<tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl.LSTMCell object at 0x7f8f794d2290>: The input_size parameter is deprecated.
    WARNING:tensorflow:<tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl.LSTMCell object at 0x7f8f794fd550>: Using a concatenated state is slower and will soon be deprecated. Use state_is_tuple=True.
    WARNING:tensorflow:<tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl.LSTMCell object at 0x7f8f794fd550>: The input_size parameter is deprecated.
    Traceback (most recent call last):
      File "train_using_tflearn_trainer.py", line 59, in <module>
        tf.app.run(main=main)
      File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py", line 44, in run
        _sys.exit(main(_sys.argv[:1] + flags_passthrough))
      File "train_using_tflearn_trainer.py", line 46, in main
        net = dnn(X)
      File "train_using_tflearn_trainer.py", line 33, in dnn
        bidirectional_grid_rnn = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, x, dtype=tf.float32)
      File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 350, in bidirectional_dynamic_rnn
        time_major=time_major, scope=fw_scope)
      File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 545, in dynamic_rnn
        dtype=dtype)
      File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 712, in _dynamic_rnn_loop
        swap_memory=swap_memory)
      File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 2615, in while_loop
        result = context.BuildLoop(cond, body, loop_vars, shape_invariants)
      File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 2448, in BuildLoop
        pred, body, original_loop_vars, loop_vars, shape_invariants)
      File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 2398, in _BuildLoop
        body_result = body(*packed_vars_for_body)
      File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 697, in _time_step
        (output, new_state) = call_cell()
      File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 683, in <lambda>
        call_cell = lambda: cell(input_t, state)
      File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py", line 183, in __call__
        if inputs is not None and sum(inputs.get_shape().as_list()) > 0 and len(
    TypeError: unsupported operand type(s) for +: 'int' and 'NoneType'


selcouthlyBlue avatar selcouthlyBlue commented on May 17, 2024

I guess I should mention that I'm using TensorFlow 1.4.


selcouthlyBlue avatar selcouthlyBlue commented on May 17, 2024

Try replacing the GridLSTMCells with BasicLSTMCells from tf.contrib.


Holded avatar Holded commented on May 17, 2024

I can't update TensorFlow because I don't have permission to modify the server.
I'm now encountering this error:

    Number of images read: 0/2
    Done reading images
    Traceback (most recent call last):
      File "train_using_tflearn_trainer.py", line 61, in <module>
        tf.app.run(main=main)
      File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py", line 44, in run
        _sys.exit(main(_sys.argv[:1] + flags_passthrough))
      File "train_using_tflearn_trainer.py", line 48, in main
        net = dnn(X)
      File "train_using_tflearn_trainer.py", line 35, in dnn
        bidirectional_grid_rnn = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, x, dtype=tf.float32)
      File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 363, in bidirectional_dynamic_rnn
        seq_dim=time_dim, batch_dim=batch_dim)
      File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/array_ops.py", line 2331, in reverse_sequence
        name=name)
      File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 2863, in reverse_sequence
        batch_dim=batch_dim, name=name)
      File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/op_def_library.py", line 509, in apply_op
        (input_name, err))
    ValueError: Tried to convert 'seq_lengths' to a tensor and failed. Error: None values not supported.


selcouthlyBlue avatar selcouthlyBlue commented on May 17, 2024

I take it you encountered that error after replacing the GridLSTMCells? Can you print the tensorflow version you're using?

As for that error, try to remove the seq_lens placeholder and set the ctc_loss sequence_length parameter to 320. If that doesn't work, try to use the trainer for your own project and tell me how it goes.
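
One caveat on the fixed length: as far as I know, tf.nn.ctc_loss in TensorFlow 1.x expects sequence_length to be a 1-D int32 vector with one entry per batch item rather than a single number, so a fixed length of 320 would be repeated for every example. A minimal NumPy sketch, assuming inputs padded to a common length and shaped [batch, time, features] (the helper name fixed_seq_lens and the feature size 13 are hypothetical):

```python
import numpy as np

def fixed_seq_lens(batch, dtype=np.int32):
    """Return one sequence length per batch item for a fully padded batch."""
    batch = np.asarray(batch)
    # Axis 0 is the batch size, axis 1 the (padded) number of time steps.
    return np.full(batch.shape[0], batch.shape[1], dtype=dtype)

# Hypothetical batch: 4 examples, 320 time steps, 13 features each.
x = np.zeros((4, 320, 13), dtype=np.float32)
seq_lens = fixed_seq_lens(x)
```

If the examples have different true lengths, the vector should hold each example's unpadded length instead of the padded one.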


selcouthlyBlue avatar selcouthlyBlue commented on May 17, 2024

I think one other way is to have the ctc_loss accept the dense labels and do the conversion internally. I made a feature request for such an op. So far, I've successfully converted a dense tensor into a sparse one and the issue now lies in ctc_loss.


Holded avatar Holded commented on May 17, 2024

@selcouthlyBlue Thank you for your help! Using my own placeholders seems more complicated because it requires changing the code. I agree that having ctc_loss do the conversion internally is the better option.


selcouthlyBlue avatar selcouthlyBlue commented on May 17, 2024

You're welcome! I'm encountering the same problem myself for my project. Maybe you can help me with the feature request :))


selcouthlyBlue avatar selcouthlyBlue commented on May 17, 2024

This should work:

def ctc_loss(y_pred, y_true):
    with tf.name_scope("CTCLoss"):
        indices = tf.where(tf.not_equal(y_true, tf.constant(0, dtype=y_true.dtype)))
        values = tf.gather_nd(y_true, indices)
        shape = tf.shape(y_true, out_type=tf.int64)
        sparse_y_true = tf.SparseTensor(
            indices,
            values,
            shape
        )
        return tf.nn.ctc_loss(inputs=y_pred, labels=sparse_y_true, sequence_length=320)

@Holded
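
To see what the conversion inside this function does, the same three steps can be mimicked in NumPy on a small zero-padded label batch. This is only an illustrative stand-in (np.argwhere in place of tf.where, fancy indexing in place of tf.gather_nd), and the batch values are made up. Note the assumption baked into the approach: 0 is treated as padding, so a genuine label with id 0 would be dropped.

```python
import numpy as np

def dense_to_sparse(y_true, pad_value=0):
    """NumPy mirror of the dense-to-sparse steps used in ctc_loss above."""
    y_true = np.asarray(y_true)
    # Positions of all non-padding entries, one (row, column) pair each.
    indices = np.argwhere(y_true != pad_value)
    # Gather the label values at those positions.
    values = y_true[tuple(indices.T)]
    # Dense shape of the original padded batch.
    shape = np.asarray(y_true.shape, dtype=np.int64)
    return indices, values, shape

# Hypothetical padded batch: the second sequence is one label shorter.
dense = np.array([[3, 1, 2],
                  [5, 4, 0]])
indices, values, shape = dense_to_sparse(dense)

print(indices.tolist())  # [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]]
print(values.tolist())   # [3, 1, 2, 5, 4]
print(shape.tolist())    # [2, 3]
```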
