google / trax
Trax — Deep Learning with Clear Code and Speed
License: Apache License 2.0
I think in resnet.py, the padding option for MaxPool should be 'SAME'. The shape of the output of MaxPool and the Resnet50ConvBlock right after it becomes B x 55 x 55 x C instead of B x 56 x 56 x C. See Keras and PyTorch.
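For reference, the shape arithmetic behind the two options (an illustrative computation, not code from resnet.py; the 112x112 input is the usual ResNet-50 stem output):
import math

# Illustrative numbers for the ResNet-50 stem: 112x112 feature map, 3x3 max pool, stride 2.
size, pool, stride = 112, 3, 2
valid = math.floor((size - pool) / stride) + 1  # 'VALID' padding -> 55
same = math.ceil(size / stride)                 # 'SAME' padding  -> 56
print(valid, same)  # 55 56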
Very interesting and useful library, thanks! My question: how should one arrange the my_input function in order to run multiple batches with Reformer? The text-generation Colab only covers a single batch. I got some useful hints from the configs (batch_fn), but the arrangement of the input is still not clear to me. I have a sequence of 4M tokens and a vocab size of 50000 for a language-model problem.
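A minimal sketch of one possible arrangement (my_batch_generator, train_token_ids and eval_token_ids are illustrative names, not a Trax API): train_stream is a callable taking n_devices and yielding already-batched (inputs, targets) numpy arrays, as with trax.supervised.Inputs elsewhere in these issues.
import numpy as np
import trax

SEQ_LEN = 512     # illustrative values
BATCH_SIZE = 8

def my_batch_generator(token_ids):
    # Hypothetical generator: slice a long token stream into
    # (batch, seq_len) chunks for language modelling (targets == inputs).
    tokens_per_batch = BATCH_SIZE * SEQ_LEN
    for start in range(0, len(token_ids) - tokens_per_batch + 1, tokens_per_batch):
        chunk = np.asarray(token_ids[start:start + tokens_per_batch], dtype=np.int64)
        batch = chunk.reshape(BATCH_SIZE, SEQ_LEN)
        yield batch, batch

# train_token_ids / eval_token_ids: your tokenized corpora (not defined here).
inputs = trax.supervised.Inputs(
    train_stream=lambda n_devices: my_batch_generator(train_token_ids),
    eval_stream=lambda n_devices: my_batch_generator(eval_token_ids),
)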
In addition to saving the latest checkpoint, save the best checkpoint according to the main evaluation metric.
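A minimal sketch of the requested behaviour (hypothetical code, not an existing Trax hook; the metric name and save function are placeholders):
best_metric = float('-inf')

def maybe_save_best(eval_metrics, save_checkpoint_fn):
    # Hypothetical hook: keep a separate copy of the checkpoint whenever
    # the main evaluation metric improves.
    global best_metric
    metric = eval_metrics['accuracy']  # placeholder for the main metric
    if metric > best_metric:
        best_metric = metric
        save_checkpoint_fn('model_best.pkl')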
Hi. I am trying to train ReformerLM. I took the training-loop code from this tutorial: https://colab.research.google.com/github/google/trax/blob/master/trax/intro.ipynb#scrollTo=djTiSLcaNFGa. Training starts normally, but Trax does not utilise the second GPU at all. The model is loaded onto the second GPU, but its GPU-Util stays at 0% while the first GPU is at 100%.
I tried changing the batch size (it is currently 8), but if I increase it to 10, training fails with an OOM error.
Could you please provide example code for multi-GPU training?
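As a first check, it may help to confirm that both GPUs are actually visible to JAX (a sketch that only inspects devices; the calls below are standard jax functions):
import jax

print(jax.device_count())        # should be 2 if both GPUs are visible
print(jax.local_device_count())
print(jax.devices())             # lists the devices JAX will use

# Trax's Trainer shards each batch across these devices (via pmap),
# so the batch size typically needs to be divisible by jax.device_count().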
OS: ubuntu 18.04
$ pip freeze | grep tensor
mesh-tensorflow==0.1.9
tensor2tensor==1.15.4
tensorboard==2.1.0
tensorflow==2.1.0
tensorflow-datasets==2.0.0
tensorflow-estimator==2.1.0
tensorflow-gan==2.0.0
tensorflow-hub==0.7.0
tensorflow-metadata==0.21.1
tensorflow-probability==0.7.0
$ pip freeze | grep jax
jax==0.1.59
jaxlib==0.1.39
$ python -V
Python 3.7.6
I'm attempting to train a Transformer model for machine translation with a shared vocabulary. As expected, the input and target sequences have different lengths. I was expecting Trax to detect this and pad the sequences accordingly, but I didn't find examples or documentation for this exact problem. Any advice would be greatly appreciated.
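One common workaround (a sketch using a hypothetical pad_batch helper, not a Trax API) is to pad and stack groups of variable-length (inputs, targets) pairs into fixed-shape batches before yielding them from the input stream:
import numpy as np

def pad_batch(examples, pad_id=0):
    # Hypothetical helper: stack variable-length (inputs, targets) pairs
    # into two fixed-shape int64 arrays, padding with pad_id.
    max_in = max(len(x) for x, _ in examples)
    max_tgt = max(len(y) for _, y in examples)
    inputs = np.full((len(examples), max_in), pad_id, dtype=np.int64)
    targets = np.full((len(examples), max_tgt), pad_id, dtype=np.int64)
    for i, (x, y) in enumerate(examples):
        inputs[i, :len(x)] = x
        targets[i, :len(y)] = y
    return inputs, targets

# Example: two pairs of different lengths become a (2, 4) and a (2, 5) array.
batch = pad_batch([([1, 2, 3], [4, 5]), ([6, 7, 8, 9], [1, 2, 3, 4, 5])])
print(batch[0].shape, batch[1].shape)  # (2, 4) (2, 5)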
OS: MacOS 10.14.6 (18G3020)
$ pip freeze | grep tensor
mesh-tensorflow==0.1.12
tensor2tensor==1.15.4
tensorboard==2.1.1
tensorflow==2.1.0
tensorflow-datasets==2.1.0
tensorflow-estimator==2.1.0
tensorflow-gan==2.0.0
tensorflow-hub==0.7.0
tensorflow-metadata==0.21.1
tensorflow-probability==0.7.0
$ pip freeze | grep jax
jax==0.1.62
jaxlib==0.1.42
$ python -V
Python 3.7.6
# Steps to reproduce:
1. Download a parallel text corpus.
2. Create a vocabulary and tokenize the source and target text and save as TFRecords.
3. Run this following code to train a Transformer model:
import os

import tensorflow as tf

import trax
from src.common.params import Paths


def train_model(inputs: trax.supervised.Inputs, model_function, output_dir):
    trainer = trax.supervised.Trainer(
        model=model_function,
        loss_fn=trax.layers.CrossEntropyLoss,
        optimizer=trax.optimizers.Adafactor,
        lr_schedule=trax.lr.MultifactorSchedule,
        inputs=inputs,
        output_dir=output_dir
    )
    n_epochs = 10
    train_steps = 10
    eval_steps = 1
    for _ in range(n_epochs):
        trainer.train_epoch(train_steps, eval_steps)


def parse_example(serialized_example):
    """Return inputs and targets Tensors from a serialized tf.Example."""
    data_fields = {
        "inputs": tf.io.VarLenFeature(tf.int64),
        "targets": tf.io.VarLenFeature(tf.int64)
    }
    parsed = tf.io.parse_single_example(serialized_example, data_fields)
    inputs = tf.sparse.to_dense(parsed["inputs"])
    targets = tf.sparse.to_dense(parsed["targets"])
    return inputs, targets


def file_length(filename):
    with open(filename) as f:
        for i, l in enumerate(f):
            pass
    return i + 1


def main():
    ML_ROOT = os.path.join(Paths.data_root, 'machine_translation')
    trax_path = os.path.join(ML_ROOT, 'trax')
    tokenizer_path = os.path.join(trax_path, 'subtoken.vocab')
    tokenized_records_path = os.path.join(trax_path, 'tokenized')
    os.makedirs(trax_path, exist_ok=True)
    os.makedirs(tokenized_records_path, exist_ok=True)

    tf_record_filenames = [os.path.join(tokenized_records_path, p)
                           for p in os.listdir(tokenized_records_path)]
    dataset = tf.data.TFRecordDataset(tf_record_filenames).map(parse_example)

    inputs = trax.supervised.Inputs(
        train_stream=lambda _: dataset.as_numpy_iterator(),
        eval_stream=lambda _: dataset.as_numpy_iterator()
    )

    # Peek into the inputs.
    data_stream = inputs.train_stream(n_devices=1)
    for _ in range(10):
        sample_input, sample_target = next(data_stream)
        print('-' * 100)
        print("Inputs: %s, len: %s" % (str(sample_input), str(len(sample_input))))
        print("Targets: %s, len: %s" % (str(sample_target), str(len(sample_target))))

    vocab_size = file_length(tokenizer_path)
    print('Vocab size:', vocab_size)

    def transformer(mode):
        return trax.models.Transformer(
            vocab_size,
            mode=mode
        )

    train_model(inputs, model_function=transformer,
                output_dir=os.path.expanduser('~/train_dir/'))


if __name__ == '__main__':
    main()
----------------------------------------------------------------------------------------------------
Inputs: [704 656 32 769 2 588 820 936 2 47 4 1], len: 12
Targets: [946 947 950 942 937 462 2 7 5 238 21 377 336 4 1], len: 15
----------------------------------------------------------------------------------------------------
Inputs: [798 128 221 866 2 249 37 471 912 4 1], len: 11
Targets: [946 947 950 944 937 2 338 188 190 383 301 106 4 1], len: 14
----------------------------------------------------------------------------------------------------
Inputs: [ 55 641 34 425 685 53 426 391 356 426 4 1], len: 12
Targets: [946 947 948 953 937 216 10 46 230 104 196 172 364 187 187 21 4 1], len: 18
----------------------------------------------------------------------------------------------------
Inputs: [167 475 45 20 122 139 48 2 56 148 76 56 246 33 53 424 299 209
220 687 2 4 1], len: 23
Targets: [988 240 64 127 719 29 92 346 380 109 206 292 163 378 5 67 4 1], len: 18
----------------------------------------------------------------------------------------------------
Inputs: [ 66 156 84 687 43 246 56 687 2 40 246 81 687 739 258 56 148 30
687 50 156 83 54 246 56 687 2 697 156 30 687 45 4 1], len: 34
Targets: [946 947 950 944 937 2 71 698 980 932 827 827 2 13 322 4 1], len: 17
----------------------------------------------------------------------------------------------------
Inputs: [129 665 256 145 470 4 1], len: 7
Targets: [946 947 950 944 937 2 802 5 12 4 1], len: 11
----------------------------------------------------------------------------------------------------
Inputs: [225 255 861 85 148 45 28 56 148 50 687 95 246 62 148 165 2 431
51 181 246 393 739 86 66 148 432 41 4 1], len: 30
Targets: [946 947 950 944 937 2 123 380 304 77 11 745 16 67 153 709 126 618
462 2 4 1], len: 22
----------------------------------------------------------------------------------------------------
Inputs: [438 118 95 128 903 162 69 282 239 4 1], len: 11
Targets: [946 947 950 944 937 10 18 311 394 404 311 778 119 238 64 27 4 1], len: 18
----------------------------------------------------------------------------------------------------
Inputs: [101 602 195 310 37 19 1], len: 7
Targets: [946 947 942 951 937 261 12 231 261 17 1], len: 11
----------------------------------------------------------------------------------------------------
Inputs: [162 253 481 128 78 141 161 145 4 1], len: 10
Targets: [946 947 950 944 937 10 91 213 357 5 106 6 410 189 613 10 4 1], len: 18
Vocab size: 992
# Error logs:
/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/lib/xla_bridge.py:123: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
Traceback (most recent call last):
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 480, in _forward_abstract
input_signature, weight_signature, self.state, rng)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/math/jax.py", line 175, in shape_fun
jax_shapes = jax.eval_shape(f, *args, **kwargs)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/api.py", line 2104, in eval_shape
*map(abstractify, args_flat))
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 274, in abstract_eval_fun
instantiate=True)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 358, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/linear_util.py", line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/linear_util.py", line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 477, in call_on_input
return self.forward_with_state(x, weights=weights, state=state, rng=rng)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 220, in forward_with_state
return self.forward(inputs, weights), state
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 580, in _forward
raw_output = raw_fn(x, weights=weights, **self._kwargs) # pylint: disable=protected-access
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/attention.py", line 42, in ShiftRight
pad_widths[1] = (n_shifts, 0) # Padding on axis=1
IndexError: list assignment index out of range
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 310, in init
weights, state = self.new_weights_and_state(input_signature)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/combinators.py", line 92, in new_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 485, in _forward_abstract
trace)
trax.layers.base.LayerError: Exception passing through layer ShiftRight (in _forward_abstract):
layer created in file [...]/trax/models/transformer.py, line 291
layer input shapes: ShapeDtype{shape:(1,), dtype:int32}
File [...]/trax/math/jax.py, line 175, in shape_fun
jax_shapes = jax.eval_shape(f, *args, **kwargs)
File [...]/site-packages/jax/api.py, line 2104, in eval_shape
*map(abstractify, args_flat))
File [...]/jax/interpreters/partial_eval.py, line 274, in abstract_eval_fun
instantiate=True)
File [...]/jax/interpreters/partial_eval.py, line 358, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File [...]/trax/layers/base.py, line 477, in call_on_input
return self.forward_with_state(x, weights=weights, state=state, rng=rng)
File [...]/trax/layers/base.py, line 220, in forward_with_state
return self.forward(inputs, weights), state
File [...]/trax/layers/base.py, line 580, in _forward
raw_output = raw_fn(x, weights=weights, **self._kwargs) # pylint: disable=protected-access
File [...]/trax/layers/attention.py, line 42, in ShiftRight
pad_widths[1] = (n_shifts, 0) # Padding on axis=1
IndexError: list assignment index out of range
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 310, in init
weights, state = self.new_weights_and_state(input_signature)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/combinators.py", line 91, in new_weights_and_state
weights_or_empty, state = sublayer.init(inputs)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 321, in init
input_signature, trace)
trax.layers.base.LayerError: Exception passing through layer Serial (in init):
layer created in file [...]/trax/models/transformer.py, line 301
layer input shapes: (ShapeDtype{shape:(1,), dtype:int64}, ShapeDtype{shape:(1,), dtype:int64})
File [...]/trax/layers/combinators.py, line 92, in new_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)
LayerError: Exception passing through layer ShiftRight (in _forward_abstract):
layer created in file [...]/trax/models/transformer.py, line 291
layer input shapes: ShapeDtype{shape:(1,), dtype:int32}
File [...]/trax/math/jax.py, line 175, in shape_fun
jax_shapes = jax.eval_shape(f, *args, **kwargs)
File [...]/site-packages/jax/api.py, line 2104, in eval_shape
*map(abstractify, args_flat))
File [...]/jax/interpreters/partial_eval.py, line 274, in abstract_eval_fun
instantiate=True)
File [...]/jax/interpreters/partial_eval.py, line 358, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File [...]/trax/layers/base.py, line 477, in call_on_input
return self.forward_with_state(x, weights=weights, state=state, rng=rng)
File [...]/trax/layers/base.py, line 220, in forward_with_state
return self.forward(inputs, weights), state
File [...]/trax/layers/base.py, line 580, in _forward
raw_output = raw_fn(x, weights=weights, **self._kwargs) # pylint: disable=protected-access
File [...]/trax/layers/attention.py, line 42, in ShiftRight
pad_widths[1] = (n_shifts, 0) # Padding on axis=1
IndexError: list assignment index out of range
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "src/machine_translation/trax/main.py", line 97, in <module>
main()
File "src/machine_translation/trax/main.py", line 93, in main
output_dir=os.path.expanduser('~/train_dir/'))
File "src/machine_translation/trax/main.py", line 21, in train_model
output_dir=output_dir
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/supervised/trainer_lib.py", line 217, in __init__
self.reset(output_dir)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/supervised/trainer_lib.py", line 297, in reset
opt_state, model_state = self._new_opt_state_and_model_state()
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/supervised/trainer_lib.py", line 170, in <lambda>
model_target_shape, self._inputs.target_dtype, init_rng))
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/api.py", line 150, in f_jitted
name=flat_fun.__name__)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/core.py", line 895, in call_bind
outs = primitive.impl(f, *args, **params)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/interpreters/xla.py", line 457, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, name, *map(arg_spec, args))
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/linear_util.py", line 220, in memoized_fun
ans = call(fun, *args)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/interpreters/xla.py", line 474, in _xla_callable
fun, pvals, instantiate=False, stage_out_calls=True, bottom=True)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 358, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/jax/linear_util.py", line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/supervised/trainer_lib.py", line 159, in new_opt_state_and_model_state
weights, state = m.init(input_signature)
File "/Users/nathanielbush/.virtualenvs/trax/lib/python3.7/site-packages/trax/layers/base.py", line 321, in init
input_signature, trace)
trax.layers.base.LayerError: Exception passing through layer Serial (in init):
layer created in file [...]/trax/supervised/trainer_lib.py, line 157
layer input shapes: (ShapeDtype{shape:(1,), dtype:int64}, ShapeDtype{shape:(1,), dtype:int64})
File [...]/trax/layers/combinators.py, line 91, in new_weights_and_state
weights_or_empty, state = sublayer.init(inputs)
LayerError: Exception passing through layer Serial (in init):
layer created in file [...]/trax/models/transformer.py, line 301
layer input shapes: (ShapeDtype{shape:(1,), dtype:int64}, ShapeDtype{shape:(1,), dtype:int64})
File [...]/trax/layers/combinators.py, line 92, in new_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)
LayerError: Exception passing through layer ShiftRight (in _forward_abstract):
layer created in file [...]/trax/models/transformer.py, line 291
layer input shapes: ShapeDtype{shape:(1,), dtype:int32}
File [...]/trax/math/jax.py, line 175, in shape_fun
jax_shapes = jax.eval_shape(f, *args, **kwargs)
File [...]/site-packages/jax/api.py, line 2104, in eval_shape
*map(abstractify, args_flat))
File [...]/jax/interpreters/partial_eval.py, line 274, in abstract_eval_fun
instantiate=True)
File [...]/jax/interpreters/partial_eval.py, line 358, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File [...]/trax/layers/base.py, line 477, in call_on_input
return self.forward_with_state(x, weights=weights, state=state, rng=rng)
File [...]/trax/layers/base.py, line 220, in forward_with_state
return self.forward(inputs, weights), state
File [...]/trax/layers/base.py, line 580, in _forward
raw_output = raw_fn(x, weights=weights, **self._kwargs) # pylint: disable=protected-access
File [...]/trax/layers/attention.py, line 42, in ShiftRight
pad_widths[1] = (n_shifts, 0) # Padding on axis=1
IndexError: list assignment index out of range
I think there might be a bug in sublayer rng setting. In the layers/base.py file, the current code looks like:
Lines 487 to 495 in 0a2be4a
Shouldn't it be
sublayer._set_rng_recursive(rng)
instead of
sublayer._rng = rng
?
I am not sure if this is the desired behavior for how base Layers should be used, but see below for a minimal example of the current behavior and of what the behavior becomes after changing that line.
N/A
# MINIMUM WORKING EXAMPLE
from jax import numpy as np
from jax import random as jax_random
from trax import layers as tl
ser1 = tl.Serial([
    tl.Dense(2),
    tl.Dense(2)
])
ser2 = tl.Serial([
    tl.Dense(2),
    tl.Dense(2)
])
double_ser = tl.Serial(ser1, ser2)
rng = jax_random.PRNGKey(0)
rng, subkey = jax_random.split(rng)
weights, state = double_ser.init(np.zeros([1,2]), rng=subkey)
print(weights[0][0])
print(weights[0][1])
print(weights[1][0])
print(weights[1][1])
WITHOUT CHANGE (current behavior, all weights are the same):
(DeviceArray([[-0.35201126, 0.34358203],
[ 0.0111863 , -0.12183081]], dtype=float32), DeviceArray([ 2.1635685e-07, -5.2678536e-07], dtype=float32))
(DeviceArray([[-0.35201126, 0.34358203],
[ 0.0111863 , -0.12183081]], dtype=float32), DeviceArray([ 2.1635685e-07, -5.2678536e-07], dtype=float32))
(DeviceArray([[-0.35201126, 0.34358203],
[ 0.0111863 , -0.12183081]], dtype=float32), DeviceArray([ 2.1635685e-07, -5.2678536e-07], dtype=float32))
(DeviceArray([[-0.35201126, 0.34358203],
[ 0.0111863 , -0.12183081]], dtype=float32), DeviceArray([ 2.1635685e-07, -5.2678536e-07], dtype=float32))
WITH CHANGE (all weights are different):
(DeviceArray([[ 0.38822186, 0.5617548 ],
[-0.3487 , -0.47204715]], dtype=float32), DeviceArray([-3.6210415e-07, 2.5783100e-07], dtype=float32))
(DeviceArray([[ 1.195412 , -0.91699356],
[-0.75880295, 0.7693857 ]], dtype=float32), DeviceArray([3.0248168e-07, 1.8994491e-07], dtype=float32))
(DeviceArray([[-1.1131637 , 1.1483511 ],
[ 0.5354116 , 0.78174126]], dtype=float32), DeviceArray([-8.4182403e-07, -1.8476302e-06], dtype=float32))
(DeviceArray([[-0.39427942, -0.25487363],
[-0.61524516, -0.75742614]], dtype=float32), DeviceArray([-6.4497357e-07, 9.1538845e-07], dtype=float32))
I'm getting a segmentation fault when trying to import trax.
Has anyone encountered the same problem?
OS: Ubuntu 18.04.3 LTS
Docker image: tensorflow/tensorflow:2.1.0-gpu-py3
$ pip freeze | grep tensor
mesh-tensorflow==0.1.12
tensor2tensor==1.15.4
tensorboard==2.1.1
tensorflow-datasets==2.1.0
tensorflow-estimator==2.1.0
tensorflow-gan==2.0.0
tensorflow-gpu==2.1.0
tensorflow-hub==0.7.0
tensorflow-metadata==0.21.1
tensorflow-probability==0.7.0
$ pip freeze | grep jax
jax==0.1.59
jaxlib==0.1.40
$ pip freeze | grep trax
trax==1.2.3
$ pip freeze | grep matplotlib
matplotlib==3.2.0
# also tried matplotlib==2.2.5 with same results
$ python -V
Python 3.6.9
$ lshw -C display
*-display
description: VGA compatible controller
product: GP104 [GeForce GTX 1080]
vendor: NVIDIA Corporation
physical id: 0
bus info: pci@0000:01:00.0
version: a1
width: 64 bits
clock: 33MHz
capabilities: vga_controller bus_master cap_list rom
configuration: driver=nvidia latency=0
resources: irq:126 memory:ee000000-eeffffff memory:d0000000-dfffffff memory:e0000000-e1ffffff ioport:e000(size=128) memory:ef000000-ef07ffff
# Steps to reproduce:
Python 3.6.9 (default, Nov 7 2019, 10:44:02)
[GCC 8.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import tensorflow as tf
>>> tf.__version__
'2.1.0'
>>> print(tf.config.list_physical_devices('GPU'))
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
>>> import trax
Segmentation fault (core dumped)
Python 3.6.9 (default, Nov 7 2019, 10:44:02)
[GCC 8.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import faulthandler
>>> faulthandler.enable()
>>> import trax
Fatal Python error: Segmentation fault
Thread 0x00007f77d6bb5740 (most recent call first):
File "/tmp/env/lib/python3.6/site-packages/matplotlib/font_manager.py", line 1007 in addfont
File "/tmp/env/lib/python3.6/site-packages/matplotlib/font_manager.py", line 991 in __init__
File "/tmp/env/lib/python3.6/site-packages/matplotlib/font_manager.py", line 1334 in _rebuild
File "/tmp/env/lib/python3.6/site-packages/matplotlib/font_manager.py", line 1343 in <module>
File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
File "<frozen importlib._bootstrap_external>", line 678 in exec_module
File "<frozen importlib._bootstrap>", line 665 in _load_unlocked
File "<frozen importlib._bootstrap>", line 955 in _find_and_load_unlocked
File "<frozen importlib._bootstrap>", line 971 in _find_and_load
File "/tmp/env/lib/python3.6/site-packages/matplotlib/contour.py", line 16 in <module>
File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
File "<frozen importlib._bootstrap_external>", line 678 in exec_module
File "<frozen importlib._bootstrap>", line 665 in _load_unlocked
File "<frozen importlib._bootstrap>", line 955 in _find_and_load_unlocked
File "<frozen importlib._bootstrap>", line 971 in _find_and_load
File "/tmp/env/lib/python3.6/site-packages/matplotlib/colorbar.py", line 31 in <module>
File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
File "<frozen importlib._bootstrap_external>", line 678 in exec_module
File "<frozen importlib._bootstrap>", line 665 in _load_unlocked
File "<frozen importlib._bootstrap>", line 955 in _find_and_load_unlocked
File "<frozen importlib._bootstrap>", line 971 in _find_and_load
File "/tmp/env/lib/python3.6/site-packages/matplotlib/pyplot.py", line 32 in <module>
File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
File "<frozen importlib._bootstrap_external>", line 678 in exec_module
File "<frozen importlib._bootstrap>", line 665 in _load_unlocked
File "<frozen importlib._bootstrap>", line 955 in _find_and_load_unlocked
File "<frozen importlib._bootstrap>", line 971 in _find_and_load
File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
File "<frozen importlib._bootstrap>", line 1023 in _handle_fromlist
File "/tmp/env/lib/python3.6/site-packages/matplotlib/__init__.py", line 1258 in use
File "/tmp/env/lib/python3.6/site-packages/matplotlib/cbook/deprecation.py", line 358 in wrapper
File "/tmp/env/lib/python3.6/site-packages/matplotlib/cbook/deprecation.py", line 296 in wrapper
File "/tmp/env/lib/python3.6/site-packages/tensor2tensor/data_generators/video_generated.py", line 35 in <module>
File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
File "<frozen importlib._bootstrap_external>", line 678 in exec_module
File "<frozen importlib._bootstrap>", line 665 in _load_unlocked
File "<frozen importlib._bootstrap>", line 955 in _find_and_load_unlocked
File "<frozen importlib._bootstrap>", line 971 in _find_and_load
File "<frozen importlib._bootstrap>", line 994 in _gcd_import
File "/usr/lib/python3.6/importlib/__init__.py", line 126 in import_module
File "/tmp/env/lib/python3.6/site-packages/tensor2tensor/data_generators/all_problems.py", line 140 in import_modules
File "/tmp/env/lib/python3.6/site-packages/tensor2tensor/problems_colab.py", line 36 in <module>
File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
File "<frozen importlib._bootstrap_external>", line 678 in exec_module
File "<frozen importlib._bootstrap>", line 665 in _load_unlocked
File "<frozen importlib._bootstrap>", line 955 in _find_and_load_unlocked
File "<frozen importlib._bootstrap>", line 971 in _find_and_load
File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
File "<frozen importlib._bootstrap>", line 1023 in _handle_fromlist
File "/tmp/env/lib/python3.6/site-packages/trax/supervised/inputs.py", line 31 in <module>
File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
File "<frozen importlib._bootstrap_external>", line 678 in exec_module
File "<frozen importlib._bootstrap>", line 665 in _load_unlocked
File "<frozen importlib._bootstrap>", line 955 in _find_and_load_unlocked
File "<frozen importlib._bootstrap>", line 971 in _find_and_load
File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
File "<frozen importlib._bootstrap>", line 1023 in _handle_fromlist
File "/tmp/env/lib/python3.6/site-packages/trax/supervised/__init__.py", line 18 in <module>
File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
File "<frozen importlib._bootstrap_external>", line 678 in exec_module
File "<frozen importlib._bootstrap>", line 665 in _load_unlocked
File "<frozen importlib._bootstrap>", line 955 in _find_and_load_unlocked
File "<frozen importlib._bootstrap>", line 971 in _find_and_load
File "/tmp/env/lib/python3.6/site-packages/trax/rl/simulated_env_problem.py", line 35 in <module>
File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
File "<frozen importlib._bootstrap_external>", line 678 in exec_module
File "<frozen importlib._bootstrap>", line 665 in _load_unlocked
File "<frozen importlib._bootstrap>", line 955 in _find_and_load_unlocked
File "<frozen importlib._bootstrap>", line 971 in _find_and_load
File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
File "<frozen importlib._bootstrap>", line 1023 in _handle_fromlist
File "/tmp/env/lib/python3.6/site-packages/trax/rl/__init__.py", line 24 in <module>
File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
File "<frozen importlib._bootstrap_external>", line 678 in exec_module
File "<frozen importlib._bootstrap>", line 665 in _load_unlocked
File "<frozen importlib._bootstrap>", line 955 in _find_and_load_unlocked
File "<frozen importlib._bootstrap>", line 971 in _find_and_load
File "/tmp/env/lib/python3.6/site-packages/trax/learning_rate.py", line 294 in <module>
File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
File "<frozen importlib._bootstrap_external>", line 678 in exec_module
File "<frozen importlib._bootstrap>", line 665 in _load_unlocked
File "<frozen importlib._bootstrap>", line 955 in _find_and_load_unlocked
File "<frozen importlib._bootstrap>", line 971 in _find_and_load
File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
File "<frozen importlib._bootstrap>", line 1023 in _handle_fromlist
File "/tmp/env/lib/python3.6/site-packages/trax/__init__.py", line 19 in <module>
File "<frozen importlib._bootstrap>", line 219 in _call_with_frames_removed
File "<frozen importlib._bootstrap_external>", line 678 in exec_module
File "<frozen importlib._bootstrap>", line 665 in _load_unlocked
File "<frozen importlib._bootstrap>", line 955 in _find_and_load_unlocked
File "<frozen importlib._bootstrap>", line 971 in _find_and_load
File "<stdin>", line 1 in <module>
Segmentation fault (core dumped)
As shown in previous examples, loss_fn should be usable like this:
trainer = trax.supervised.Trainer(
    model=eval(train_model.selector),
    loss_fn=trax.layers.CrossEntropyLoss,
    optimizer=trax.optimizers.Adam,
    lr_schedule=trax.lr.MultifactorSchedule,
    inputs=trax.supervised.inputs.Inputs(train_stream),
    output_dir=output_dir,
)
However, since the latest upgrade to 1.2.4 this no longer works.
In trainer_lib, the loss_fn gets passed to a Serial constructor:
trax/trax/supervised/trainer_lib.py
Line 130 in 93f2bd4
which in turn runs _ensure_flat in its constructor:
trax/trax/layers/combinators.py
Line 47 in 5b15659
However, all objects in layers have to be of type base.Layer:
def _ensure_flat(layers):
  """Ensures that layers is a single flat list of Layer instances."""
  if len(layers) == 1 and layers[0] is None:
    layers = ()
  else:
    layers = _deep_flatten(layers)
  for obj in layers:
    if not isinstance(obj, base.Layer):
      raise ValueError(
          f'Found nonlayer object ({obj}) in layers: {layers}')
  return layers
See
trax/trax/layers/combinators.py
Line 775 in 5b15659
Thus we'll see an exception:
ValueError: Found nonlayer object (<function CrossEntropyLoss at 0x7fc5be59a9e0>) in layers:
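A possible workaround, judging from the error above (an untested sketch using the same names as the snippet at the top of this issue; it assumes that calling trax.layers.CrossEntropyLoss() returns a Layer instance, which is what _ensure_flat checks for):
trainer = trax.supervised.Trainer(
    model=eval(train_model.selector),
    loss_fn=trax.layers.CrossEntropyLoss(),  # note the call: pass a Layer instance, not the bare function
    optimizer=trax.optimizers.Adam,
    lr_schedule=trax.lr.MultifactorSchedule,
    inputs=trax.supervised.inputs.Inputs(train_stream),
    output_dir=output_dir,
)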
In the linked line, q_start is checked to be of type int. This means that if q_start = DeviceArray(0, dtype=int32), for example, the code jumps directly into "handling one token at a time" and it is checked that q_len == 1.
Removing the isinstance(q_start, int) check works fine if I want to handle more than one token at a time.
What is this isinstance check good for?
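A quick illustration of why the check misfires for device/traced values (a small sketch of standard JAX behaviour, not Trax code):
import jax.numpy as jnp

q_start = jnp.array(0, dtype=jnp.int32)  # a DeviceArray, like the one produced inside the model
print(isinstance(q_start, int))          # False: device arrays are not Python ints
print(isinstance(0, int))                # True: only a plain Python int passes the check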
Dear people,
What would be the advantage of Trax vs TF or PyTorch?
Best,
T.C
As per our (currently my) discussion in the Trax gitter, there is significant interest in having Trax codelabs. Here's a live prototype (or a screenshot) to get a feel for it:
As I mentioned in the Trax gitter, while this might mirror the content of a notebook someone could just use on Colab, it could provide some added benefit by helping to reduce cognitive load. And, as I also mentioned there, there can be colabs as well as codelabs compiled from the colabs 😉
E.g. a notebook and a markdown file that is used to generate the codelab in the linked prototype.
Hosting this on our docs site just for demo purposes. The prototype was a fork of the open-wc codelab package, which was generated using https://github.com/googlecodelabs/tools, iiuc.
Related but separate is the possibility of a Trax docs site? 😏 We did ours with Vuepress (again forked from open-wc), which is working out really well.
I modified the ende reformer config to train my own reformer model for a low resource language pair (18000 sentence pairs). Note that I am using GPUs and not TPUs. I found that the reformer encoder-decoder trains very slowly (1 second per batch on a V100). Is this normal? I was under the impression that the reformer trains fast. Am I missing something?
...
OS: CentOS
$ pip freeze | grep tensor
mesh-tensorflow==0.1.9
tensor2tensor==1.14.0
tensorboard==1.15.0
tensorflow-datasets==1.3.2
tensorflow-estimator==1.15.1
tensorflow-gan==2.0.0
tensorflow-gpu==1.15.0
tensorflow-hub==0.7.0
tensorflow-metadata==0.21.0
tensorflow-probability==0.7.0
$ pip freeze | grep jax
jax==0.1.57
jaxlib==0.1.37
I have installed the GPU versions of this.
$ python -V
Python 3.6.8
# Steps to reproduce:
...
I run: python -m trax.trainer --config_file=$PWD/trax/configs/reformer_wmt_ende.gin
# Error logs:
...
N/A
I was trying to pip install trax on my local computer but I am unable to complete the installation due to some errors.
OS: Windows 10
Version 1909
OS Build 18363.628
$ pip freeze | grep tensor
-
$ pip freeze | grep jax
-
$ python -V
Python 3.7.4
# Steps to reproduce:
...
Open CMD and run pip install trax
# Error logs:
...
Collecting trax
Using cached trax-1.2.2-py2.py3-none-any.whl (311 kB)
Collecting jax
Using cached jax-0.1.58.tar.gz (262 kB)
Requirement already satisfied: numpy in c:\users\yuqua\appdata\local\programs\python\python37\lib\site-packages (from tr
ax) (1.18.1)
Requirement already satisfied: scipy in c:\users\yuqua\appdata\local\programs\python\python37\lib\site-packages (from tr
ax) (1.4.1)
Collecting gin-config
Using cached gin_config-0.3.0-py3-none-any.whl (44 kB)
Collecting funcsigs
Using cached funcsigs-1.0.2-py2.py3-none-any.whl (17 kB)
Requirement already satisfied: absl-py in c:\users\yuqua\appdata\local\programs\python\python37\lib\site-packages (from
trax) (0.9.0)
Collecting tensorflow-datasets
Using cached tensorflow_datasets-2.0.0-py3-none-any.whl (3.1 MB)
Collecting tensor2tensor
Using cached tensor2tensor-1.15.4-py2.py3-none-any.whl (1.4 MB)
Requirement already satisfied: six in c:\users\yuqua\appdata\local\programs\python\python37\lib\site-packages (from trax
) (1.12.0)
Collecting gym
Using cached gym-0.15.6.tar.gz (1.6 MB)
ERROR: Could not find a version that satisfies the requirement jaxlib (from trax) (from versions: none)
ERROR: No matching distribution found for jaxlib (from trax)
Hello,
I am training my Reformer model with the parameters
n_encoder_layers = 3,
n_decoder_layers = 3,
d_model = 512,
ff_size = 2048,
attention_heads_num = 8,
dropout = 0.1,
max_len = 250
on a Google Colab TPU. My sequences are padded to the same length, so I feed these shapes to the Reformer:
x.shape = (batch_size, 256)
y.shape = (batch_size, 128)
The largest batch size the TPU accepts is 2048. A training step lasts approximately 1.15 seconds, which means I can push 2048 * (256 + 128) = 786432 int64 numbers through all 8 TPU cores per step. If I choose a bigger batch (e.g. 4096), it doesn't fit into memory and shows this error stack trace:
RuntimeError: Resource exhausted: Ran out of memory in memory space hbm. Used 14.93G of 8.00G hbm. Exceeded hbm capacity by 6.93G.
Total hbm usage >= 14.93G:
reserved 529.00M
program 14.41G
arguments unknown size
Output size unknown.
Program hbm requirement 14.41G:
reserved 4.0K
global 100.0K
HLO temp 14.41G (98.3% utilization: Unpadded (14.16G) Padded (14.40G), 0.0% fragmentation (7.25M))
Largest program allocations in hbm:
1. Size: 5.97G
Shape: f32[512,100,30000]{2,1,0:T(8,128)}
Unpadded size: 5.72G
Extra memory due to padding: 250.62M (1.0x expansion)
XLA label: %copy.1598 = f32[512,100,30000]{2,1,0:T(8,128)} copy(f32[512,100,30000]{0,2,1} %copy.1597)
Allocation type: HLO temp
==========================
2. Size: 5.72G
Shape: f32[512,100,30000]{0,2,1}
Unpadded size: 5.72G
XLA label: %copy.1597 = f32[512,100,30000]{0,2,1} copy(f32[512,100,30000]{2,1,0:T(8,128)} %get-tuple-element.2671)
Allocation type: HLO temp
==========================
3. Size: 1.00G
Operator: op_type="dot_general" op_name="pmap(mapped_update)/dot_general[ dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1)))\n precision=None ]"
Shape: f32[512,8,256,256]{2,3,1,0:T(8,128)}
Unpadded size: 1.00G
XLA label: %fusion.15833 = (f32[512,8,256]{2,1,0:T(8,128)}, f32[512,8,256,256]{2,3,1,0:T(8,128)}) fusion(f32[512,256]{1,0:T(8,128)} %get-tuple-element.2417, bf16[512,256,8,64]{1,3,2,0:T(8,128)(2,1)} %fusion.727, bf16[512,256,8,64]{1,3,2,0:T(8,128)(2,1)} %fusion.233),...
Allocation type: HLO temp
==========================
4. Size: 256.00M
Operator: op_type="dot_general" op_name="dot_general[ dimension_numbers=(((2,), (0,)), ((), ()))\n precision=None ]"
Shape: f32[512,256,512]{2,1,0:T(8,128)}
Unpadded size: 256.00M
XLA label: %fusion.659 = f32[512,256,512]{2,1,0:T(8,128)} fusion(f32[512,256,512]{2,1,0:T(8,128)} %get-tuple-element.2643, f32[512]{0:T(512)} %get-tuple-element.2794, f32[512,512]{1,0:T(8,128)} %reshape.4754, f32[512,256,2048]{2,1,0:T(8,128)} %fusion.276, f32[512,204...
Allocation type: HLO temp
==========================
5. Size: 256.00M
Operator: op_type="reduce_sum" op_name="pmap(mapped_update)/reduce_sum[ axes=(2,)\n input_shape=(512, 256, 512) ]"
Shape: f32[512,256,512]{2,1,0:T(8,128)}
Unpadded size: 256.00M
XLA label: %fusion.15774 = (f32[512,256]{1,0:T(8,128)}, f32[512,256,512]{2,1,0:T(8,128)}, f32[512,256]{1,0:T(8,128)}, f32[512,256,512]{2,1,0:T(8,128)}) fusion(f32[512,256,512]{2,1,0:T(8,128)} %get-tuple-element.2319, f32[512,256,512]{2,1,0:T(8,128)} %fusion.659, f32[...
Allocation type: HLO temp
==========================
6. Size: 256.00M
Operator: op_type="reduce_sum" op_name="reduce_sum[ axes=(2,)\n input_shape=(512, 256, 512) ]"
Shape: f32[512,256,512]{2,1,0:T(8,128)}
Unpadded size: 256.00M
XLA label: %fusion.15776 = (f32[512,256]{1,0:T(8,128)}, f32[512,256,512]{2,1,0:T(8,128)}) fusion(f32[512,256,512]{2,1,0:T(8,128)} %get-tuple-element.2337, f32[512,256,512]{2,1,0:T(8,128)} %fusion.33), kind=kLoop, calls=%fused_computation.15091, metadata={op_type="red...
Allocation type: HLO temp
==========================
7. Size: 256.00M
Operator: op_type="dot_general" op_name="pmap(mapped_update)/dot_general[ dimension_numbers=(((2,), (0,)), ((), ()))\n precision=None ]"
Shape: f32[512,256,512]{2,1,0:T(8,128)}
Unpadded size: 256.00M
XLA label: %fusion.15832 = (f32[512,256]{1,0:T(8,128)}, f32[512,256,512]{2,1,0:T(8,128)}) fusion(f32[512,256,512]{2,1,0:T(8,128)} %fusion.659, f32[512,512]{1,0:T(8,128)} %reshape.4767, f32[512]{0:T(512)} %get-tuple-element.2794, f32[2048,512,1]{1,0,2:T(8,128)} %bitca...
Allocation type: HLO temp
==========================
8. Size: 128.00M
Operator: op_type="gather" op_name="pmap(mapped_update)/gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(1, 2), collapsed_slice_dims=(0,), start_index_map=(0,))\n operand_shape=(512, 256, 512)\n slice_sizes=(1, 256, 512) ]"
Shape: bf16[512,256,512]{2,1,0:T(8,128)(2,1)}
Unpadded size: 128.00M
XLA label: %fusion.5 = bf16[512,256,512]{2,1,0:T(8,128)(2,1)} fusion(bf16[512,256,512]{2,1,0:T(8,128)(2,1)} %fusion.593, s32[512]{0} %get-tuple-element.2747), kind=kCustom, calls=%fused_computation.5, metadata={op_type="gather" op_name="pmap(mapped_update)/gather[ di...
Allocation type: HLO temp
==========================
9. Size: 128.00M
Operator: op_type="gather" op_name="gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(1, 2), collapsed_slice_dims=(0,), start_index_map=(0,))\n operand_shape=(512, 256, 512)\n slice_sizes=(1, 256, 512) ]"
Shape: bf16[512,256,512]{2,1,0:T(8,128)(2,1)}
Unpadded size: 128.00M
XLA label: %fusion.10 = bf16[512,256,512]{2,1,0:T(8,128)(2,1)} fusion(bf16[512,256,512]{2,1,0:T(8,128)(2,1)} %fusion.660, s32[512]{0} %get-tuple-element.2748), kind=kCustom, calls=%fused_computation.10, metadata={op_type="gather" op_name="gather[ dimension_numbers=Ga...
Allocation type: HLO temp
==========================
10. Size: 128.00M
Operator: op_type="dot_general" op_name="pmap(mapped_update)/dot_general[ dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1)))\n precision=None ]"
Shape: bf16[512,8,256,64]{0,3,2,1:T(8,128)(2,1)}
Unpadded size: 128.00M
XLA label: %copy.980 = bf16[512,8,256,64]{0,3,2,1:T(8,128)(2,1)} copy(bf16[512,8,256,64]{2,3,1,0:T(8,128)(2,1)} %fusion.349), metadata={op_type="dot_general" op_name="pmap(mapped_update)/dot_general[ dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1)))\n ...
Allocation type: HLO temp
==========================
11. Size: 100.00M
Shape: f32[512,100,512]{2,0,1}
Unpadded size: 100.00M
XLA label: %copy.1046 = f32[512,100,512]{2,0,1} copy(f32[512,100,512]{2,1,0:T(8,128)} %get-tuple-element.2670)
Allocation type: HLO temp
==========================
12. Size: 100.00M
Shape: f32[512,100,512]{2,0,1}
Unpadded size: 100.00M
XLA label: %copy.1044 = f32[512,100,512]{2,0,1} copy(f32[512,100,512]{2,1,0:T(8,128)} %get-tuple-element.2224)
Allocation type: HLO temp
==========================
13. Size: 100.00M
Shape: f32[512,100,512]{2,0,1}
Unpadded size: 100.00M
XLA label: %copy.1052 = f32[512,100,512]{2,0,1} copy(f32[512,100,512]{2,1,0:T(8,128)} %get-tuple-element.2222)
Allocation type: HLO temp
==========================
14. Size: 4.00M
Operator: op_type="mul" op_name="pmap(mapped_update)/mul"
Shape: f32[512,2048]{1,0:T(8,128)}
Unpadded size: 4.00M
XLA label: %reshape.4770 = f32[512,2048]{1,0:T(8,128)} reshape(f32[1048576]{0:T(1024)} %fusion.1786), metadata={op_type="mul" op_name="pmap(mapped_update)/mul"}
Allocation type: HLO temp
==========================
15. Size: 4.00M
Operator: op_type="mul" op_name="pmap(mapped_update)/mul"
Shape: f32[512,2048]{1,0:T(8,128)}
Unpadded size: 4.00M
XLA label: %reshape.4773 = f32[512,2048]{1,0:T(8,128)} reshape(f32[1048576]{0:T(1024)} %fusion.1785), metadata={op_type="mul" op_name="pmap(mapped_update)/mul"}
Allocation type: HLO temp
==========================
16. Size: 4.00M
Operator: op_type="dot_general" op_name="pmap(mapped_update)/dot_general[ dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1)))\n precision=None ]"
Shape: f32[512,8,256]{2,1,0:T(8,128)}
Unpadded size: 4.00M
XLA label: %fusion.15833 = (f32[512,8,256]{2,1,0:T(8,128)}, f32[512,8,256,256]{2,3,1,0:T(8,128)}) fusion(f32[512,256]{1,0:T(8,128)} %get-tuple-element.2417, bf16[512,256,8,64]{1,3,2,0:T(8,128)(2,1)} %fusion.727, bf16[512,256,8,64]{1,3,2,0:T(8,128)(2,1)} %fusion.233),...
Allocation type: HLO temp
==========================
17. Size: 4.00M
Operator: op_type="reduce_sum" op_name="pmap(mapped_update)/reduce_sum[ axes=(3,)\n input_shape=(512, 8, 256, 256) ]"
Shape: f32[512,8,256]{2,1,0:T(8,128)}
Unpadded size: 4.00M
XLA label: %fusion.366 = (f32[512,8,256]{2,1,0:T(8,128)}, f32[512,8,256]{2,1,0:T(8,128)}) fusion(f32[512,8,256,256]{2,3,1,0:T(8,128)} %get-tuple-element.2649, f32[512,8,256]{2,1,0:T(8,128)} %get-tuple-element.2648, f32[512,8,256]{2,1,0:T(8,128)} %fusion.1865), kind=k...
Allocation type: HLO temp
==========================
18. Size: 4.00M
Operator: op_type="reduce_sum" op_name="pmap(mapped_update)/reduce_sum[ axes=(3,)\n input_shape=(512, 8, 256, 256) ]"
Shape: f32[512,8,256]{2,1,0:T(8,128)}
Unpadded size: 4.00M
XLA label: %fusion.366 = (f32[512,8,256]{2,1,0:T(8,128)}, f32[512,8,256]{2,1,0:T(8,128)}) fusion(f32[512,8,256,256]{2,3,1,0:T(8,128)} %get-tuple-element.2649, f32[512,8,256]{2,1,0:T(8,128)} %get-tuple-element.2648, f32[512,8,256]{2,1,0:T(8,128)} %fusion.1865), kind=k...
Allocation type: HLO temp
==========================
19. Size: 4.00M
Operator: op_type="mul" op_name="pmap(mapped_update)/mul"
Shape: f32[512,2048]{1,0:T(8,128)}
Unpadded size: 4.00M
XLA label: %reshape.4764 = f32[512,2048]{1,0:T(8,128)} reshape(f32[1048576]{0:T(1024)} %fusion.1789), metadata={op_type="mul" op_name="pmap(mapped_update)/mul"}
Allocation type: HLO temp
==========================
20. Size: 4.00M
Operator: op_type="mul" op_name="pmap(mapped_update)/mul"
Shape: f32[512,2048]{1,0:T(8,128)}
Unpadded size: 4.00M
XLA label: %reshape.4766 = f32[512,2048]{1,0:T(8,128)} reshape(f32[1048576]{0:T(1024)} %fusion.1788), metadata={op_type="mul" op_name="pmap(mapped_update)/mul"}
Allocation type: HLO temp
==========================
However, I wonder how the authors of this quite popular paper - https://ufal.mff.cuni.cz/pbml/110/art-popel-bojar.pdf - were able to fit bigger batches onto a GTX 1080 Ti GPU and reach such high throughput on a single GPU (page 9: for batch size 500 they report 43400 steps per hour, i.e. 21 700 000 examples per hour), whereas my Reformer model only reaches (2048 / 1.15) * 60 * 60 = 4 915 199 examples per hour. Am I mistaken, or what am I doing wrong?
Thanks.
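As a rough sanity check on the largest allocation in the error log above (an illustrative computation only; 512 is the per-core batch and 30000 is presumably the output vocabulary size):
per_core_batch, seq_len, vocab = 512, 100, 30000
logits_bytes = per_core_batch * seq_len * vocab * 4  # float32
print(logits_bytes / 2**30)  # ~5.72 GiB, matching the reported f32[512,100,30000] buffer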
I'm running the command:
python -m trax.trainer --config_file=$PWD/trax/configs/mlp_mnist.gin
It trains the model, prints "Finished training.", and then hangs forever. The process is not using CPU, but it does not exit either. It never returns to the shell and I have to terminate it with Ctrl-C.
Trax: 1.2.2
OS: Ubuntu 18.04
$ pip freeze | grep tensor
mesh-tensorflow==0.1.4
neptune-tensorboard==0.3.8
tensor2tensor==1.14.1
tensorboard==1.15.0
tensorflow==1.15.0
tensorflow-datasets==1.3.0
tensorflow-estimator==1.15.1
tensorflow-gan==2.0.0
tensorflow-hub==0.7.0
tensorflow-metadata==0.15.1
tensorflow-probability==0.8.0
$ pip freeze | grep jax
jax==0.1.51
jaxlib==0.1.32
$ python -V
Python 3.7.3
python -m trax.trainer --config_file=$PWD/trax/configs/mlp_mnist.gin
I0217 13:01:30.008547 140681594124096 trainer_lib.py:751] Step 2000: Ran 200 train steps in 4.19 secs
Step 2000: Ran 200 train steps in 4.19 secs
I0217 13:01:30.012206 140681594124096 trainer_lib.py:751] Step 2000: Evaluation
Step 2000: Evaluation
I0217 13:01:30.073523 140681594124096 trainer_lib.py:751] Step 2000: train accuracy | 0.99804688
Step 2000: train accuracy | 0.99804688
I0217 13:01:30.074985 140681594124096 trainer_lib.py:751] Step 2000: train loss | 0.01065689
Step 2000: train loss | 0.01065689
I0217 13:01:30.076281 140681594124096 trainer_lib.py:751] Step 2000: train neg_log_perplexity | 0.01065689
Step 2000: train neg_log_perplexity | 0.01065689
I0217 13:01:30.077118 140681594124096 trainer_lib.py:751] Step 2000: train weights_per_batch_per_core | 256.00000000
Step 2000: train weights_per_batch_per_core | 256.00000000
I0217 13:01:30.423881 140681594124096 trainer_lib.py:751] Step 2000: eval accuracy | 0.96406251
Step 2000: eval accuracy | 0.96406251
I0217 13:01:30.424737 140681594124096 trainer_lib.py:751] Step 2000: eval loss | 0.62180674
Step 2000: eval loss | 0.62180674
I0217 13:01:30.426048 140681594124096 trainer_lib.py:751] Step 2000: eval neg_log_perplexity | 0.62180674
Step 2000: eval neg_log_perplexity | 0.62180674
I0217 13:01:30.426502 140681594124096 trainer_lib.py:751] Step 2000: eval weights_per_batch_per_core | 256.00000000
Step 2000: eval weights_per_batch_per_core | 256.00000000
I0217 13:01:30.427090 140681594124096 trainer_lib.py:751] Step 2000: Finished evaluation
Step 2000: Finished evaluation
I0217 13:01:30.445652 140681594124096 trainer_lib.py:751] Model saved to /home/pawel/trax/MLP_mnist_20200217_1300/model.pkl
I0217 13:01:30.446032 140681594124096 trainer_lib.py:751] Step 2000: Training done
Step 2000: Training done
I0217 13:01:30.446371 140681594124096 trainer_lib.py:751] Finished training.
Finished training.
I'm trying to train the basic Reformer (not the ReformerLM) on a long sequence of text, based on the language-generation example: simply replacing the ReformerLM class with Reformer and removing the mask. But feeding in the entire Crime and Punishment book throws the following error:
TypeError: requesting more random bits than a single call provides.
Everything works fine if I cut the input down to smaller sequences. The example can be seen in the following notebook:
https://colab.research.google.com/drive/1C9KOHOfuVhoOqzRKx_rRaeOV3jPvuZ_L
Failed to sample from the Reformer model after training on my local machine.
No code was changed.
OS: Ubuntu 16.04 LTS
$ pip freeze | grep tensor
mesh-tensorflow==0.1.9
tensor2tensor==1.14.1
tensorboard==1.15.0
tensorboardcolab==0.0.22
tensorflow-datasets==2.0.0
tensorflow-estimator==1.15.1
tensorflow-gan==2.0.0
tensorflow-gpu==1.15.0
tensorflow-hub==0.7.0
tensorflow-metadata==0.21.1
tensorflow-privacy==0.2.2
tensorflow-probability==0.7.0
tensorflow-serving-api-gpu==1.13.0
tensorflow-tensorboard==0.4.0
$ pip freeze | grep jax
jax==0.1.57
jaxlib==0.1.37
$ python -V
Python 3.6.9 :: Anaconda, Inc.
# Steps to reproduce:
Run the Text Generation notebook on the local machine:
https://colab.research.google.com/github/google/trax/blob/master/trax/models/reformer/text_generation.ipynb#scrollTo=favRDt3U4CJY
# Error logs:
...
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/trax/layers/base.py in _forward_internal(self, x, weights, state, rng)
452 outputs, s = self.forward_with_state(
--> 453 x, weights=weights, state=state, rng=rng)
454 else:
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/trax/layers/combinators.py in forward_with_state(self, xs, weights, state, **kwargs)
59 self._validate_forward_inputs(xs)
---> 60 rngs = _pop_rng_and_split(kwargs, self._n_layers)
61 if not self.sublayers: # No-op: leave args unchanged.
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/trax/layers/combinators.py in _pop_rng_and_split(args_dict, n_copies)
688 return (None,) * n_copies
--> 689 return math.random.split(rng, n_copies)
690
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/trax/math/backend.py in split(self, prng, num)
122 def split(self, prng, num=2):
--> 123 return backend()['random_split'](prng, num)
124
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/random.py in split(key, num)
243 """
--> 244 return _split(key, num)
245
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/api.py in f_jitted(*args, **kwargs)
148 flat_fun, out_tree = flatten_fun(f, in_tree)
--> 149 out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend)
150 return tree_unflatten(out_tree(), out)
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/core.py in call_bind(primitive, f, *args, **params)
604 tracers = map(top_trace.full_raise, args)
--> 605 outs = map(full_lower, top_trace.process_call(primitive, f, tracers, params))
606 return apply_todos(env_trace_todo(), outs)
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/interpreters/partial_eval.py in process_call(self, call_primitive, f, tracers, params)
125 fun, aux = partial_eval(f, self, in_pvs)
--> 126 out_flat = call_primitive.bind(fun, *in_consts, **params)
127 out_pvs, jaxpr, env = aux()
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/core.py in call_bind(primitive, f, *args, **params)
601 with new_sublevel():
--> 602 outs = primitive.impl(f, *args, **params)
603 else:
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, *args, **params)
441 backend = params['backend']
--> 442 compiled_fun = _xla_callable(fun, device, backend, *map(arg_spec, args))
443 try:
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
222 else:
--> 223 ans = call(fun, *args)
224 cache[key] = (ans, fun.stores)
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, *arg_specs)
458 with core.new_master(pe.StagingJaxprTrace, True) as master:
--> 459 jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
460 assert not env # no subtraces here
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
151
--> 152 ans = self.f(*args, **dict(self.params, **kwargs))
153 del args
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/random.py in _split(key, num)
248 counts = lax.tie_in(key, lax.iota(onp.uint32, num * 2))
--> 249 return lax.reshape(threefry_2x32(key, counts), (num, 2))
250
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/api.py in f_jitted(*args, **kwargs)
148 flat_fun, out_tree = flatten_fun(f, in_tree)
--> 149 out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend)
150 return tree_unflatten(out_tree(), out)
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/core.py in call_bind(primitive, f, *args, **params)
604 tracers = map(top_trace.full_raise, args)
--> 605 outs = map(full_lower, top_trace.process_call(primitive, f, tracers, params))
606 return apply_todos(env_trace_todo(), outs)
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/interpreters/partial_eval.py in process_call(self, call_primitive, f, tracers, params)
125 fun, aux = partial_eval(f, self, in_pvs)
--> 126 out_flat = call_primitive.bind(fun, *in_consts, **params)
127 out_pvs, jaxpr, env = aux()
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/core.py in call_bind(primitive, f, *args, **params)
604 tracers = map(top_trace.full_raise, args)
--> 605 outs = map(full_lower, top_trace.process_call(primitive, f, tracers, params))
606 return apply_todos(env_trace_todo(), outs)
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/interpreters/partial_eval.py in process_call(self, call_primitive, f, tracers, params)
125 fun, aux = partial_eval(f, self, in_pvs)
--> 126 out_flat = call_primitive.bind(fun, *in_consts, **params)
127 out_pvs, jaxpr, env = aux()
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/core.py in call_bind(primitive, f, *args, **params)
601 with new_sublevel():
--> 602 outs = primitive.impl(f, *args, **params)
603 else:
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, *args, **params)
441 backend = params['backend']
--> 442 compiled_fun = _xla_callable(fun, device, backend, *map(arg_spec, args))
443 try:
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
222 else:
--> 223 ans = call(fun, *args)
224 cache[key] = (ans, fun.stores)
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, *arg_specs)
458 with core.new_master(pe.StagingJaxprTrace, True) as master:
--> 459 jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
460 assert not env # no subtraces here
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
151
--> 152 ans = self.f(*args, **dict(self.params, **kwargs))
153 del args
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/random.py in threefry_2x32(keypair, count)
215 """
--> 216 key1, key2 = keypair
217 if not lax.dtype(key1) == lax.dtype(key2) == lax.dtype(count) == onp.uint32:
ValueError: not enough values to unpack (expected 2, got 1)
During handling of the above exception, another exception occurred:
LayerError Traceback (most recent call last)
<ipython-input-30-58abfd2a9337> in <module>
1 # Sample from the Reformer language model, given a prefix.
----> 2 samples = sample(length=128, prompt="There was a time when")
3 for ids in samples:
4 print(TOKENIZER.DecodeIds(ids.tolist()))
<ipython-input-29-f9fee3fa3424> in sample(length, prompt)
19 model_weights,
20 cur_state,
---> 21 rngs)
22
23 if prompt is not None and iteration < prompt.shape[1]:
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/api.py in f_jitted(*args, **kwargs)
147 _check_args(args_flat)
148 flat_fun, out_tree = flatten_fun(f, in_tree)
--> 149 out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend)
150 return tree_unflatten(out_tree(), out)
151
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/core.py in call_bind(primitive, f, *args, **params)
600 if top_trace is None:
601 with new_sublevel():
--> 602 outs = primitive.impl(f, *args, **params)
603 else:
604 tracers = map(top_trace.full_raise, args)
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, *args, **params)
440 device = params['device']
441 backend = params['backend']
--> 442 compiled_fun = _xla_callable(fun, device, backend, *map(arg_spec, args))
443 try:
444 return compiled_fun(*args)
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
221 fun.populate_stores(stores)
222 else:
--> 223 ans = call(fun, *args)
224 cache[key] = (ans, fun.stores)
225 return ans
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, *arg_specs)
457 pvals = [pe.PartialVal((aval, core.unit)) for aval in abstract_args]
458 with core.new_master(pe.StagingJaxprTrace, True) as master:
--> 459 jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
460 assert not env # no subtraces here
461 del master, env
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
150 gen = None
151
--> 152 ans = self.f(*args, **dict(self.params, **kwargs))
153 del args
154 while stack:
~/anaconda3/envs/mindlogic/lib/python3.6/site-packages/trax/layers/base.py in _forward_internal(self, x, weights, state, rng)
460 name, trace = self.__class__.__name__, _short_traceback()
461 raise LayerError(name, '_forward_internal',
--> 462 self._caller, signature(x), trace)
463
464 def _forward_abstract(self, input_signature):
LayerError: Exception passing through layer Serial (in _forward_internal):
layer created in file [...]/models/reformer/reformer.py, line 612
layer input shapes: ShapeDtype{shape:(1, 1, 1), dtype:int32}
File [...]/trax/layers/combinators.py, line 60, in forward_with_state
rngs = _pop_rng_and_split(kwargs, self._n_layers)
File [...]/trax/layers/combinators.py, line 689, in _pop_rng_and_split
return math.random.split(rng, n_copies)
File [...]/trax/math/backend.py, line 123, in split
return backend()['random_split'](prng, num)
File [...]/site-packages/jax/random.py, line 244, in split
return _split(key, num)
File [...]/site-packages/jax/api.py, line 149, in f_jitted
out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend)
File [...]/site-packages/jax/core.py, line 605, in call_bind
outs = map(full_lower, top_trace.process_call(primitive, f, tracers, params))
File [...]/jax/interpreters/partial_eval.py, line 126, in process_call
out_flat = call_primitive.bind(fun, *in_consts, **params)
File [...]/site-packages/jax/core.py, line 602, in call_bind
outs = primitive.impl(f, *args, **params)
File [...]/jax/interpreters/xla.py, line 442, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, *map(arg_spec, args))
File [...]/site-packages/jax/linear_util.py, line 223, in memoized_fun
ans = call(fun, *args)
File [...]/jax/interpreters/xla.py, line 459, in _xla_callable
jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
File [...]/site-packages/jax/linear_util.py, line 152, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File [...]/site-packages/jax/random.py, line 249, in _split
return lax.reshape(threefry_2x32(key, counts), (num, 2))
File [...]/site-packages/jax/api.py, line 149, in f_jitted
out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend)
File [...]/site-packages/jax/core.py, line 605, in call_bind
outs = map(full_lower, top_trace.process_call(primitive, f, tracers, params))
File [...]/jax/interpreters/partial_eval.py, line 126, in process_call
out_flat = call_primitive.bind(fun, *in_consts, **params)
File [...]/site-packages/jax/core.py, line 605, in call_bind
outs = map(full_lower, top_trace.process_call(primitive, f, tracers, params))
File [...]/jax/interpreters/partial_eval.py, line 126, in process_call
out_flat = call_primitive.bind(fun, *in_consts, **params)
File [...]/site-packages/jax/core.py, line 602, in call_bind
outs = primitive.impl(f, *args, **params)
File [...]/jax/interpreters/xla.py, line 442, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, *map(arg_spec, args))
File [...]/site-packages/jax/linear_util.py, line 223, in memoized_fun
ans = call(fun, *args)
File [...]/jax/interpreters/xla.py, line 459, in _xla_callable
jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
File [...]/site-packages/jax/linear_util.py, line 152, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File [...]/site-packages/jax/random.py, line 216, in threefry_2x32
key1, key2 = keypair
ValueError: not enough values to unpack (expected 2, got 1)
Have you implemented LSH for Enc-Dec attention? I know the motivation behind full attention was that Enc-Dec is mostly used for MT, where full attention should be OK. But I'm using it for longer sequences and I'm hitting an OOM issue. I wanted to know whether you have implemented LSH for Enc-Dec attention.
When setting the output_dir argument to None when creating a trax.supervised.Trainer, the reset method is never called (lines 216/217 in trainer_lib: if output_dir is not None: self.reset(output_dir)), which breaks the Trainer: the reset method sets up the train stream among other things, and the model can't function without it. Surely you would still want a train stream even when not writing the model to an output dir? Is there a reason for the if condition in line 216?
# Steps to reproduce:
Create a trainer with any model without setting an output_dir.
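For example, a minimal sketch of what I mean (the model and the toy input stream here are just placeholders, the only essential part is output_dir=None; the Inputs keyword names are from memory and may differ across Trax versions):
import numpy as np
import trax

def toy_stream(n_devices):
    # Infinite stream of tiny random (inputs, targets) batches, just enough
    # to construct a Trainer.
    del n_devices
    while True:
        x = np.random.randint(0, 32, size=(4, 16), dtype=np.int32)
        yield (x, x)

trainer = trax.supervised.Trainer(
    model=trax.models.TransformerLM,
    loss_fn=trax.layers.CrossEntropyLoss,
    optimizer=trax.optimizers.Adafactor,
    lr_schedule=trax.lr.MultifactorSchedule,
    inputs=trax.supervised.Inputs(train_stream=toy_stream,
                                  eval_stream=toy_stream),
    output_dir=None)  # reset() is never called, so no train stream is set up

trainer.train_epoch(1, 1)  # breaks here because the train stream was never created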
# Error logs:
...
After using Colab for training and then loading the model in prediction mode, it runs out of memory on the second prediction run on the TPU runtime.
https://colab.research.google.com/drive/1v2q5Qp2-68hLG-uTZ3gZZHvkm9Ovbpkc
Reformer model details:
def reformer(mode):
    return trax.models.reformer.ReformerLM(
        d_model=32,
        d_ff=128,
        n_layers=8,
        vocab_size=1024,
        mode=mode)
Sequence Length = 100
batch size = 128
...
OS: Google Colab
$ pip freeze | grep tensor
mesh-tensorflow==0.1.13
tensor2tensor==1.15.4
tensorboard==2.2.0
tensorboard-plugin-wit==1.6.0.post2
tensorboardcolab==0.0.22
tensorflow==2.2.0rc2
tensorflow-addons==0.8.3
tensorflow-datasets==2.1.0
tensorflow-estimator==2.2.0rc0
tensorflow-gan==2.0.0
tensorflow-gcs-config==2.1.8
tensorflow-hub==0.7.0
tensorflow-metadata==0.21.1
tensorflow-privacy==0.2.2
tensorflow-probability==0.7.0
$ pip freeze | grep jax
jax==0.1.59
jaxlib==0.1.39
$ python -V
Python 3.6.9
# Steps to reproduce:
Run all cells up to the "Speed" markdown cell
# Error logs:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng)
443 else:
--> 444 outputs, s = self._do_custom_gradients(x, weights, state, rng=rng)
445 self._state = s
16 frames
RuntimeError: Failed precondition: Dependency failed: Could not allocate 419430400 bytes in memory 0x0x0_HBM0; 370884608 bytes allocatable, 376881152 bytes available
During handling of the above exception, another exception occurred:
LayerError Traceback (most recent call last)
LayerError: Exception passing through layer ReversibleSerial (in pure_fn):
layer created in file [...]/models/reformer/reformer.py, line 802
layer input shapes: (ShapeDtype{shape:(100, 1, 32), dtype:float32}, ShapeDtype{shape:(100, 1, 32), dtype:float32})
File [...]/trax/layers/base.py, line 562, in _do_custom_gradients
output, state = _do_forward(x, weights)
File [...]/dist-packages/jax/api.py, line 1460, in __call__
num_consts=len(consts))
File [...]/dist-packages/jax/core.py, line 179, in bind
return self.impl(*args, **kwargs)
File [...]/dist-packages/jax/api.py, line 1511, in fun_impl
return core.eval_jaxpr(params['jaxpr'], consts, *args)
File [...]/dist-packages/jax/core.py, line 249, in eval_jaxpr
ans = eqn.primitive.bind(*(subfuns + in_vals), **params)
File [...]/dist-packages/jax/core.py, line 179, in bind
return self.impl(*args, **kwargs)
File [...]/dist-packages/jax/api.py, line 1511, in fun_impl
return core.eval_jaxpr(params['jaxpr'], consts, *args)
File [...]/dist-packages/jax/core.py, line 249, in eval_jaxpr
ans = eqn.primitive.bind(*(subfuns + in_vals), **params)
File [...]/dist-packages/jax/core.py, line 179, in bind
return self.impl(*args, **kwargs)
File [...]/jax/interpreters/xla.py, line 159, in apply_primitive
return compiled_fun(*args)
File [...]/jax/interpreters/xla.py, line 246, in _execute_compiled_primitive
out_buf = compiled.Execute(input_bufs)
RuntimeError: Failed precondition: Dependency failed: Could not allocate 419430400 bytes in memory 0x0x0_HBM0; 370884608 bytes allocatable, 376881152 bytes available
During handling of the above exception, another exception occurred:
LayerError Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/trax/layers/base.py in pure_fn(self, x, weights, state, rng)
449 name, trace = self.__class__.__name__, _short_traceback()
450 raise LayerError(name, 'pure_fn',
--> 451 self._caller, signature(x), trace)
452
453 def output_signature(self, input_signature):
LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/models/reformer/reformer.py, line 811
layer input shapes: ShapeDtype{shape:(100, 1), dtype:int32}
File [...]/trax/layers/combinators.py, line 77, in forward_with_state
outputs, s = layer.pure_fn(inputs, w, s, rng)
LayerError: Exception passing through layer ReversibleSerial (in pure_fn):
layer created in file [...]/models/reformer/reformer.py, line 802
layer input shapes: (ShapeDtype{shape:(100, 1, 32), dtype:float32}, ShapeDtype{shape:(100, 1, 32), dtype:float32})
File [...]/trax/layers/base.py, line 562, in _do_custom_gradients
output, state = _do_forward(x, weights)
File [...]/dist-packages/jax/api.py, line 1460, in __call__
num_consts=len(consts))
File [...]/dist-packages/jax/core.py, line 179, in bind
return self.impl(*args, **kwargs)
File [...]/dist-packages/jax/api.py, line 1511, in fun_impl
return core.eval_jaxpr(params['jaxpr'], consts, *args)
File [...]/dist-packages/jax/core.py, line 249, in eval_jaxpr
ans = eqn.primitive.bind(*(subfuns + in_vals), **params)
File [...]/dist-packages/jax/core.py, line 179, in bind
return self.impl(*args, **kwargs)
File [...]/dist-packages/jax/api.py, line 1511, in fun_impl
return core.eval_jaxpr(params['jaxpr'], consts, *args)
File [...]/dist-packages/jax/core.py, line 249, in eval_jaxpr
ans = eqn.primitive.bind(*(subfuns + in_vals), **params)
File [...]/dist-packages/jax/core.py, line 179, in bind
return self.impl(*args, **kwargs)
File [...]/jax/interpreters/xla.py, line 159, in apply_primitive
return compiled_fun(*args)
File [...]/jax/interpreters/xla.py, line 246, in _execute_compiled_primitive
out_buf = compiled.Execute(input_bufs)
RuntimeError: Failed precondition: Dependency failed: Could not allocate 419430400 bytes in memory 0x0x0_HBM0; 370884608 bytes allocatable, 376881152 bytes available
Could model.pkl be converted to a TensorFlow model?
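For context, what I mean by "converted": a minimal sketch of loading the checkpoint and turning the weights into TF tensors, assuming model.pkl is a plain pickle of nested numpy arrays (that format is my assumption; it may differ by Trax version):
import pickle
import numpy as np
import tensorflow as tf

with open('model.pkl', 'rb') as f:
    ckpt = pickle.load(f)  # nested tuples/lists/dicts; the leaves should be numpy arrays

def to_tf(tree):
    # Recursively convert every numpy-array leaf to a tf.Tensor.
    if isinstance(tree, np.ndarray):
        return tf.convert_to_tensor(tree)
    if isinstance(tree, (list, tuple)):
        return [to_tf(t) for t in tree]
    if isinstance(tree, dict):
        return {k: to_tf(v) for k, v in tree.items()}
    return tree

tf_weights = to_tf(ckpt)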
...
OS: <your answer here>
$ pip freeze | grep tensor
# your output here
$ pip freeze | grep jax
# your output here
$ python -V
# your output here
# Steps to reproduce:
...
# Error logs:
...
Hi all, I'm building a dataset with TensorFlow and Trax in an Ubuntu Docker container, but I hit a segmentation fault.
When I run the code without importing trax, there is no error. Please help me.
FROM tensorflow/tensorflow:latest-gpu-py3
RUN apt-get -y update
RUN apt-get -y upgrade
RUN apt-get install -y less wget git
# for error of matplotlib + trax
RUN apt-get install -y python3-cairocffi python3-gi gir1.2-gtk-3.0
RUN pip install -U pip
RUN pip install -U six
RUN pip install -U matplotlib==3.1.3
RUN pip install --upgrade https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.43-cp36-none-linux_x86_64.whl
RUN pip install --upgrade jax
WORKDIR /tmp/docker_works
RUN git clone https://github.com/google/trax.git
WORKDIR /tmp/docker_works/trax
RUN sed -i '1s/^/import tensorflow\n/' ./trax/models/research/bert.py
RUN sed -i -e "s/from tensorflow.train import load_checkpoint//g" ./trax/models/research/bert.py
RUN sed -i -e "s/load_checkpoint/tensorflow.train.load_checkpoint/g" ./trax/models/research/bert.py
RUN python setup.py install
WORKDIR /tmp/docker_works
import matplotlib as mlp
mlp.use('Agg')
import trax
import faulthandler
faulthandler.enable()
import pickle
import random
import numpy as np
import tensorflow as tf
if __name__ == "__main__":
    with tf.io.TFRecordWriter('./data/tmp.tfrecord') as writer:
        for i in range(10):
            example = tf.train.Example(features=tf.train.Features(
                feature={
                    'input_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=range(10))),
                    'labels': tf.train.Feature(int64_list=tf.train.Int64List(value=range(10))),
                }))
            writer.write(example.SerializeToString())
    dataset = tf.data.TFRecordDataset('./data/tmp.tfrecord')
    print(dataset)
$ python script/sample_with_trax.py
Fatal Python error: Segmentation fault
Current thread 0x00007eff32052740 (most recent call first):
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/context.py", line 1081 in _initialize_physical_devices
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/context.py", line 815 in config
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/context.py", line 496 in ensure_initialized
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/constant_op.py", line 95 in convert_to_eager_tensor
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/constant_op.py", line 266 in _constant_impl
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/constant_op.py", line 258 in constant
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/constant_op.py", line 317 in _constant_tensor_conversion_function
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py", line 1302 in convert_to_tensor
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/data/ops/readers.py", line 55 in _create_or_validate_filenames_dataset
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/data/ops/readers.py", line 316 in __init__
File "script/sample_with_trax.py", line 20 in <module>
Segmentation fault (core dumped)
Without "import trax":
$ python script/sample_with_trax.py
<TFRecordDatasetV2 shapes: (), types: tf.string>
I generated BPE vocab (size=320) and model files and compared them with the Crime and Punishment files; everything went OK:
Number of tokens: 750515
and (device count, tokens per device) = (8, 1048576)
until training, when I got:
RuntimeError: Resource exhausted: Ran out of memory in memory space hbm. Used 18.74G of 8.00G hbm. Exceeded hbm capacity by 10.74G.
Wow, a 50% increase in tokens jumps memory from 512M to 18G?
Or did I miss something else?
Hello, if possible, I would appreciate it if you could upload a gin config to train t2t_translate_ende_wmt32k with LSHSelfAttention. I have tried many times, but it still returns an error.
...
OS: <your answer here>
$ pip freeze | grep tensor
# your output here
$ pip freeze | grep jax
# your output here
$ python -V
# your output here
# Steps to reproduce:
...
# Error logs:
...
In models/bert.py the line
from tensorflow.train import load_checkpoint
crashes with
Traceback (most recent call last):
File "math_trax.py", line 19, in <module>
import trax
File "/root/.local/lib/python3.6/site-packages/trax/__init__.py", line 19, in <module>
from trax import lr_schedules as lr
File "/root/.local/lib/python3.6/site-packages/trax/lr_schedules.py", line 37, in <module>
from trax import models as trax_models
File "/root/.local/lib/python3.6/site-packages/trax/models/__init__.py", line 32, in <module>
from trax.models.research import bert
File "/root/.local/lib/python3.6/site-packages/trax/models/research/bert.py", line 20, in <module>
from tensorflow.train import load_checkpoint
ModuleNotFoundError: No module named 'tensorflow.train'
If we change that (in bert.py) to
#from tensorflow.train import load_checkpoint
from tensorflow_core._api.v2.train import load_checkpoint
the import works.
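As an alternative sketch that avoids reaching into the private tensorflow_core path (tf.train.load_checkpoint is the public TF2 API, though I haven't verified this against every TF build):
# In trax/models/research/bert.py, instead of "from tensorflow.train import load_checkpoint":
import tensorflow as tf

load_checkpoint = tf.train.load_checkpoint  # attribute access goes through TF's lazy module loader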
Is it my environment or setup? Sorry if so; I tried to rule that out, but to no avail.
...
OS: ubuntu 18.04 in a docker container
mesh-tensorflow==0.1.11
tensor2tensor==1.15.4
tensorboard==2.1.0
tensorflow==2.1.0
tensorflow-datasets==2.1.0
tensorflow-estimator==2.1.0
tensorflow-gan==2.0.0
tensorflow-hub==0.7.0
tensorflow-metadata==0.21.1
tensorflow-probability==0.7.0
$ pip freeze | grep jax
jax==0.1.59
jaxlib==0.1.39
(I know I need to upgrade, but I believe that's unrelated)
$ python3 -V
Python 3.6.9
# Steps to reproduce:
import trax
Traceback (most recent call last):
File "math_trax.py", line 19, in <module>
import trax
File "/root/.local/lib/python3.6/site-packages/trax/__init__.py", line 19, in <module>
from trax import lr_schedules as lr
File "/root/.local/lib/python3.6/site-packages/trax/lr_schedules.py", line 37, in <module>
from trax import models as trax_models
File "/root/.local/lib/python3.6/site-packages/trax/models/__init__.py", line 32, in <module>
from trax.models.research import bert
File "/root/.local/lib/python3.6/site-packages/trax/models/research/bert.py", line 20, in <module>
from tensorflow.train import load_checkpoint
ModuleNotFoundError: No module named 'tensorflow.train'
I have managed to adapt the colab code for learning document representations, and the training and generation phases work smoothly. I adapted the sample() method to return the final state after processing the document. However, this final state seems to be a complex list containing a variety of information. What I want is the hidden representation of the topmost layer, which I assume represents the whole document. Is there any way to obtain that hidden representation? I am providing the relevant part of my code below. Any suggestions will be appreciated.
...
OS: Ubuntu 16.04 (Irrelevant)
$ pip freeze | grep tensor
mesh-tensorflow==0.1.9
tensor2tensor==1.14.0
tensorboard==1.15.0
tensorflow-datasets==1.3.2
tensorflow-estimator==1.15.1
tensorflow-gan==2.0.0
tensorflow-gpu==1.15.0
tensorflow-hub==0.7.0
tensorflow-metadata==0.21.0
tensorflow-probability==0.7.0
$ pip freeze | grep jax
jax==0.1.57
jaxlib==0.1.37
$ python -V
Python 3.6.8
# Steps to reproduce:
The relevant part of the code I wrote:
# Prepare a jitted copy of the model.
jit_model_infer = trax.layers.base._accelerate(
    model_infer._forward_internal, trax.math.device_count())

# Set up the initial state for sampling.
infer_state = model_infer.new_weights_and_state(
    trax.supervised.trainer_lib.ShapeDtype((1, 1), dtype=np.int32))[1]
infer_state = trainer._for_n_devices(infer_state)

def docvector(length=0, prompt=None):
    """Sample from the ReformerLM model."""
    model_weights = trainer._opt_state[0][0]
    length = len(prompt.split(" "))
    # Token id 0 is the equivalent of a "start" token.
    cur_inputs = np.zeros((trax.math.device_count(), 1, 1), dtype=np.int32)
    cur_state = infer_state
    rngs = trax.math.random.split(
        trax.math.random.get_prng(0), trax.math.device_count())
    all_samples = []
    # Prompt is the input document as a string.
    prompt = np.asarray(
        [TOKENIZER.EncodeAsIds(prompt)] * trax.math.device_count())
    logits, cur_state = jit_model_infer(
        cur_inputs,
        model_weights,
        cur_state,
        rngs)
    for iteration in range(length):
        cur_samples = onp.array(prompt[:, iteration], dtype=int)
        cur_inputs = np.array(cur_samples[:, None, None])
        logits, cur_state = jit_model_infer(
            cur_inputs,
            model_weights,
            cur_state,
            rngs)
    # This is a list of lists/dictionaries/tensors/tuples.
    # How do I get the final hidden state?
    return cur_state
# Error logs:
N/A
It would be great if pretrained Reformer models become available (e.g., trained on BooksCorpus and English Wikipedia).
Coming from tensor2tensor, I was wondering whether the Reformer model would also be a candidate for speech recognition? Looking at the examples, there is none for ASR.
Would it be possible to train an ASR model with the Reformer, or would code changes be necessary? If so, can we estimate how much would have to change in the model implementation?
Thank you for any insight into this!
Since Trax is a successor of tensor2tensor (according to the release notes of tensor2tensor v1.15.0), it would be helpful if you could provide examples for more advanced machine learning tasks. An outstanding feature of tensor2tensor are the numerous (and useful) examples which Trax is currently lacking. Such examples would especially be helpful for machine learning tasks with complex input transformations like speech recognition or translation with subword encodings.
I'm editing the trax codebase, and for debugging purposes I have to print values eagerly. But because of the jit, everything is a jit abstract expression (google/jax#196). How do you debug the code without being able to print anything? I assume that if I disable the jit, everything would be executed eagerly. Am I right? If yes, is there any way to disable jit?
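For what it's worth, plain JAX has a switch for this; a minimal sketch, not Trax-specific (Trax's own jit/pmap wrappers may still need to be bypassed separately):
import jax
import jax.numpy as jnp

@jax.jit
def loss(x):
    y = jnp.sum(x * x)
    print('y =', y)  # under jit this prints an abstract tracer
    return y

with jax.disable_jit():    # everything below runs eagerly, op by op
    loss(jnp.arange(3.0))  # now print() shows a concrete value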
Training a Transformer converges.
Then beam_search fails, though: when n_devices == 1, some reshapes crash in decode().
OS:
ubuntu 18.04
CUDA 10.1
1 GPU environment
$ pip freeze | grep tensor
mesh-tensorflow==0.1.11
tensor2tensor==1.15.4
tensorboard==2.1.0
tensorflow==2.1.0
tensorflow-datasets==2.1.0
tensorflow-estimator==2.1.0
tensorflow-gan==2.0.0
tensorflow-hub==0.7.0
tensorflow-metadata==0.21.1
tensorflow-probability==0.7.0
$ pip freeze | grep jax
jax==0.1.59
jaxlib==0.1.39
$ python -V
Python 3.6.9
Traceback (most recent call last):
File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 442, in pure_fn
x, weights=weights, state=state, rng=rng)
File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 220, in forward_with_state
return self.forward(inputs, weights), state
File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 580, in _forward
raw_output = raw_fn(x, weights=weights, **self._kwargs) # pylint: disable=protected-access
File "/root/.local/lib/python3.6/site-packages/trax/layers/attention.py", line 51, in PaddingMask
return np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))
File "/root/.local/lib/python3.6/site-packages/jax/numpy/lax_numpy.py", line 921, in reshape
return a.reshape(newshape, order=order) # forward to method for ndarrays
File "/root/.local/lib/python3.6/site-packages/jax/numpy/lax_numpy.py", line 959, in _reshape_method
return _reshape(a, newshape, order=order)
File "/root/.local/lib/python3.6/site-packages/jax/numpy/lax_numpy.py", line 938, in _reshape
return lax.reshape(a, computed_newshape, None)
File "/root/.local/lib/python3.6/site-packages/jax/lax/lax.py", line 640, in reshape
old_sizes=onp.shape(operand))
File "/root/.local/lib/python3.6/site-packages/jax/core.py", line 182, in bind
out_tracer = top_trace.process_primitive(self, tracers, kwargs)
File "/root/.local/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 98, in process_primitive
return self.default_process_primitive(primitive, tracers, params)
File "/root/.local/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 106, in default_process_primitive
out_aval = primitive.abstract_eval(*avals, **params)
File "/root/.local/lib/python3.6/site-packages/jax/lax/lax.py", line 1523, in standard_abstract_eval
return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
File "/root/.local/lib/python3.6/site-packages/jax/lax/lax.py", line 2582, in _reshape_shape_rule
raise TypeError(msg.format(new_sizes, onp.shape(operand)))
TypeError: reshape total size must be unchanged, got new_sizes (1, 1, 1, 30) for shape (1, 2, 30).
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 480, in _forward_abstract
input_signature, weight_signature, self.state, rng)
File "/root/.local/lib/python3.6/site-packages/trax/math/jax.py", line 175, in shape_fun
jax_shapes = jax.eval_shape(f, *args, **kwargs)
File "/root/.local/lib/python3.6/site-packages/jax/api.py", line 2042, in eval_shape
out = pe.abstract_eval_fun(fun.call_wrapped, *map(abstractify, args_flat))
File "/root/.local/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 273, in abstract_eval_fun
instantiate=True)
File "/root/.local/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 354, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/root/.local/lib/python3.6/site-packages/jax/linear_util.py", line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/root/.local/lib/python3.6/site-packages/jax/linear_util.py", line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 477, in call_on_input
return self.forward_with_state(x, weights=weights, state=state, rng=rng)
File "/root/.local/lib/python3.6/site-packages/trax/layers/combinators.py", line 238, in forward_with_state
sub_outputs, sub_state = layer.pure_fn(x, w, s, r)
File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 451, in pure_fn
self._caller, signature(x), trace)
trax.layers.base.LayerError: Exception passing through layer PaddingMask (in pure_fn):
layer created in file [...]/trax/models/transformer.py, line 286
layer input shapes: ShapeDtype{shape:(1, 2, 30), dtype:int32}
File [...]/trax/layers/base.py, line 220, in forward_with_state
return self.forward(inputs, weights), state
File [...]/trax/layers/base.py, line 580, in _forward
raw_output = raw_fn(x, weights=weights, **self._kwargs) # pylint: disable=protected-access
File [...]/trax/layers/attention.py, line 51, in PaddingMask
return np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))
File [...]/jax/numpy/lax_numpy.py, line 921, in reshape
return a.reshape(newshape, order=order) # forward to method for ndarrays
File [...]/jax/numpy/lax_numpy.py, line 959, in _reshape_method
return _reshape(a, newshape, order=order)
File [...]/jax/numpy/lax_numpy.py, line 938, in _reshape
return lax.reshape(a, computed_newshape, None)
File [...]/jax/lax/lax.py, line 640, in reshape
old_sizes=onp.shape(operand))
File [...]/site-packages/jax/core.py, line 182, in bind
out_tracer = top_trace.process_primitive(self, tracers, kwargs)
File [...]/jax/interpreters/partial_eval.py, line 98, in process_primitive
return self.default_process_primitive(primitive, tracers, params)
File [...]/jax/interpreters/partial_eval.py, line 106, in default_process_primitive
out_aval = primitive.abstract_eval(*avals, **params)
File [...]/jax/lax/lax.py, line 1523, in standard_abstract_eval
return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
File [...]/jax/lax/lax.py, line 2582, in _reshape_shape_rule
raise TypeError(msg.format(new_sizes, onp.shape(operand)))
TypeError: reshape total size must be unchanged, got new_sizes (1, 1, 1, 30) for shape (1, 2, 30).
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 310, in init
weights, state = self.new_weights_and_state(input_signature)
File "/root/.local/lib/python3.6/site-packages/trax/layers/combinators.py", line 92, in new_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)
File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 485, in _forward_abstract
trace)
trax.layers.base.LayerError: Exception passing through layer Parallel (in _forward_abstract):
layer created in file [...]/trax/layers/combinators.py, line 468
layer input shapes: (ShapeDtype{shape:(1, 2, 30), dtype:int32}, ShapeDtype{shape:(1, 2, 30), dtype:int32})
File [...]/trax/math/jax.py, line 175, in shape_fun
jax_shapes = jax.eval_shape(f, *args, **kwargs)
File [...]/site-packages/jax/api.py, line 2042, in eval_shape
out = pe.abstract_eval_fun(fun.call_wrapped, *map(abstractify, args_flat))
File [...]/jax/interpreters/partial_eval.py, line 273, in abstract_eval_fun
instantiate=True)
File [...]/jax/interpreters/partial_eval.py, line 354, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File [...]/trax/layers/base.py, line 477, in call_on_input
return self.forward_with_state(x, weights=weights, state=state, rng=rng)
File [...]/trax/layers/combinators.py, line 238, in forward_with_state
sub_outputs, sub_state = layer.pure_fn(x, w, s, r)
LayerError: Exception passing through layer PaddingMask (in pure_fn):
layer created in file [...]/trax/models/transformer.py, line 286
layer input shapes: ShapeDtype{shape:(1, 2, 30), dtype:int32}
File [...]/trax/layers/base.py, line 220, in forward_with_state
return self.forward(inputs, weights), state
File [...]/trax/layers/base.py, line 580, in _forward
raw_output = raw_fn(x, weights=weights, **self._kwargs) # pylint: disable=protected-access
File [...]/trax/layers/attention.py, line 51, in PaddingMask
return np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))
File [...]/jax/numpy/lax_numpy.py, line 921, in reshape
return a.reshape(newshape, order=order) # forward to method for ndarrays
File [...]/jax/numpy/lax_numpy.py, line 959, in _reshape_method
return _reshape(a, newshape, order=order)
File [...]/jax/numpy/lax_numpy.py, line 938, in _reshape
return lax.reshape(a, computed_newshape, None)
File [...]/jax/lax/lax.py, line 640, in reshape
old_sizes=onp.shape(operand))
File [...]/site-packages/jax/core.py, line 182, in bind
out_tracer = top_trace.process_primitive(self, tracers, kwargs)
File [...]/jax/interpreters/partial_eval.py, line 98, in process_primitive
return self.default_process_primitive(primitive, tracers, params)
File [...]/jax/interpreters/partial_eval.py, line 106, in default_process_primitive
out_aval = primitive.abstract_eval(*avals, **params)
File [...]/jax/lax/lax.py, line 1523, in standard_abstract_eval
return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
File [...]/jax/lax/lax.py, line 2582, in _reshape_shape_rule
raise TypeError(msg.format(new_sizes, onp.shape(operand)))
TypeError: reshape total size must be unchanged, got new_sizes (1, 1, 1, 30) for shape (1, 2, 30).
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 310, in init
weights, state = self.new_weights_and_state(input_signature)
File "/root/.local/lib/python3.6/site-packages/trax/layers/combinators.py", line 91, in new_weights_and_state
weights_or_empty, state = sublayer.init(inputs)
File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 321, in init
input_signature, trace)
trax.layers.base.LayerError: Exception passing through layer Serial (in init):
layer created in file [...]/trax/layers/combinators.py, line 470
layer input shapes: ShapeDtype{shape:(1, 2, 30), dtype:int32}
File [...]/trax/layers/combinators.py, line 92, in new_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)
LayerError: Exception passing through layer Parallel (in _forward_abstract):
layer created in file [...]/trax/layers/combinators.py, line 468
layer input shapes: (ShapeDtype{shape:(1, 2, 30), dtype:int32}, ShapeDtype{shape:(1, 2, 30), dtype:int32})
File [...]/trax/math/jax.py, line 175, in shape_fun
jax_shapes = jax.eval_shape(f, *args, **kwargs)
File [...]/site-packages/jax/api.py, line 2042, in eval_shape
out = pe.abstract_eval_fun(fun.call_wrapped, *map(abstractify, args_flat))
File [...]/jax/interpreters/partial_eval.py, line 273, in abstract_eval_fun
instantiate=True)
File [...]/jax/interpreters/partial_eval.py, line 354, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File [...]/trax/layers/base.py, line 477, in call_on_input
return self.forward_with_state(x, weights=weights, state=state, rng=rng)
File [...]/trax/layers/combinators.py, line 238, in forward_with_state
sub_outputs, sub_state = layer.pure_fn(x, w, s, r)
LayerError: Exception passing through layer PaddingMask (in pure_fn):
layer created in file [...]/trax/models/transformer.py, line 286
layer input shapes: ShapeDtype{shape:(1, 2, 30), dtype:int32}
File [...]/trax/layers/base.py, line 220, in forward_with_state
return self.forward(inputs, weights), state
File [...]/trax/layers/base.py, line 580, in _forward
raw_output = raw_fn(x, weights=weights, **self._kwargs) # pylint: disable=protected-access
File [...]/trax/layers/attention.py, line 51, in PaddingMask
return np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))
File [...]/jax/numpy/lax_numpy.py, line 921, in reshape
return a.reshape(newshape, order=order) # forward to method for ndarrays
File [...]/jax/numpy/lax_numpy.py, line 959, in _reshape_method
return _reshape(a, newshape, order=order)
File [...]/jax/numpy/lax_numpy.py, line 938, in _reshape
return lax.reshape(a, computed_newshape, None)
File [...]/jax/lax/lax.py, line 640, in reshape
old_sizes=onp.shape(operand))
File [...]/site-packages/jax/core.py, line 182, in bind
out_tracer = top_trace.process_primitive(self, tracers, kwargs)
File [...]/jax/interpreters/partial_eval.py, line 98, in process_primitive
return self.default_process_primitive(primitive, tracers, params)
File [...]/jax/interpreters/partial_eval.py, line 106, in default_process_primitive
out_aval = primitive.abstract_eval(*avals, **params)
File [...]/jax/lax/lax.py, line 1523, in standard_abstract_eval
return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
File [...]/jax/lax/lax.py, line 2582, in _reshape_shape_rule
raise TypeError(msg.format(new_sizes, onp.shape(operand)))
TypeError: reshape total size must be unchanged, got new_sizes (1, 1, 1, 30) for shape (1, 2, 30).
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "math_trax.py", line 565, in
seqs, scores = beam_decoder.decode(inputs=batch, batch_size=iBatch_size)#, )targets_prefix=prefix_for_bs,
File "/root/.local/lib/python3.6/site-packages/trax/models/beam_search.py", line 602, in decode
dummy=np.zeros(n_devices))
File "/root/.local/lib/python3.6/site-packages/jax/api.py", line 146, in f_jitted
name=flat_fun.name)
File "/root/.local/lib/python3.6/site-packages/jax/core.py", line 642, in call_bind
outs = primitive.impl(f, *args, **params)
File "/root/.local/lib/python3.6/site-packages/jax/interpreters/xla.py", line 448, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, name, *map(arg_spec, args))
File "/root/.local/lib/python3.6/site-packages/jax/linear_util.py", line 220, in memoized_fun
ans = call(fun, *args)
File "/root/.local/lib/python3.6/site-packages/jax/interpreters/xla.py", line 465, in _xla_callable
jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
File "/root/.local/lib/python3.6/site-packages/jax/linear_util.py", line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/root/.local/lib/python3.6/site-packages/trax/models/beam_search.py", line 535, in _unreplicated_beam_search
self._get_initial_state(inputs, targets_prefix, batch_size),
File "/root/.local/lib/python3.6/site-packages/trax/models/beam_search.py", line 490, in _get_initial_state
_, initial_state = self.model(mode='predict').init(signature)
File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 321, in init
input_signature, trace)
trax.layers.base.LayerError: Exception passing through layer Serial (in init):
layer created in file [...]/trax/models/transformer.py, line 301
layer input shapes: (ShapeDtype{shape:(1, 2, 30), dtype:int32}, ShapeDtype{shape:(2, 1), dtype:int32})
File [...]/trax/layers/combinators.py, line 91, in new_weights_and_state
weights_or_empty, state = sublayer.init(inputs)
LayerError: Exception passing through layer Serial (in init):
layer created in file [...]/trax/layers/combinators.py, line 470
layer input shapes: ShapeDtype{shape:(1, 2, 30), dtype:int32}
File [...]/trax/layers/combinators.py, line 92, in new_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)
LayerError: Exception passing through layer Parallel (in _forward_abstract):
layer created in file [...]/trax/layers/combinators.py, line 468
layer input shapes: (ShapeDtype{shape:(1, 2, 30), dtype:int32}, ShapeDtype{shape:(1, 2, 30), dtype:int32})
File [...]/trax/math/jax.py, line 175, in shape_fun
jax_shapes = jax.eval_shape(f, *args, **kwargs)
File [...]/site-packages/jax/api.py, line 2042, in eval_shape
out = pe.abstract_eval_fun(fun.call_wrapped, *map(abstractify, args_flat))
File [...]/jax/interpreters/partial_eval.py, line 273, in abstract_eval_fun
instantiate=True)
File [...]/jax/interpreters/partial_eval.py, line 354, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File [...]/trax/layers/base.py, line 477, in call_on_input
return self.forward_with_state(x, weights=weights, state=state, rng=rng)
File [...]/trax/layers/combinators.py, line 238, in forward_with_state
sub_outputs, sub_state = layer.pure_fn(x, w, s, r)
LayerError: Exception passing through layer PaddingMask (in pure_fn):
layer created in file [...]/trax/models/transformer.py, line 286
layer input shapes: ShapeDtype{shape:(1, 2, 30), dtype:int32}
File [...]/trax/layers/base.py, line 220, in forward_with_state
return self.forward(inputs, weights), state
File [...]/trax/layers/base.py, line 580, in _forward
raw_output = raw_fn(x, weights=weights, **self._kwargs) # pylint: disable=protected-access
File [...]/trax/layers/attention.py, line 51, in PaddingMask
return np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))
File [...]/jax/numpy/lax_numpy.py, line 921, in reshape
return a.reshape(newshape, order=order) # forward to method for ndarrays
File [...]/jax/numpy/lax_numpy.py, line 959, in _reshape_method
return _reshape(a, newshape, order=order)
File [...]/jax/numpy/lax_numpy.py, line 938, in _reshape
return lax.reshape(a, computed_newshape, None)
File [...]/jax/lax/lax.py, line 640, in reshape
old_sizes=onp.shape(operand))
File [...]/site-packages/jax/core.py, line 182, in bind
out_tracer = top_trace.process_primitive(self, tracers, kwargs)
File [...]/jax/interpreters/partial_eval.py, line 98, in process_primitive
return self.default_process_primitive(primitive, tracers, params)
File [...]/jax/interpreters/partial_eval.py, line 106, in default_process_primitive
out_aval = primitive.abstract_eval(*avals, **params)
File [...]/jax/lax/lax.py, line 1523, in standard_abstract_eval
return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
File [...]/jax/lax/lax.py, line 2582, in _reshape_shape_rule
raise TypeError(msg.format(new_sizes, onp.shape(operand)))
TypeError: reshape total size must be unchanged, got new_sizes (1, 1, 1, 30) for shape (1, 2, 30).
REMARK:
1 is n_devices, 2 is batch size, 30 is max_len
# Steps to reproduce:
I tried to force your machine_translation.ipynb in Colab to use the GPU but didn't succeed. Maybe it's fastest for you to check what happens with only 1 GPU, since the Colab itself runs smoothly (on a TPU).
# Error logs:
...
Hello, I was wondering how large the batch size can be for TPU training. Right now I'm training a vanilla Transformer model in Colab and I can barely fit into TPU memory. My batch size is 128, sequences are padded with the padded_batch function, and max_len is 512. It seems I'm missing something, because it's a bit suspicious that the TPU cannot handle batches of a higher magnitude (like 2048).
I also tried to run the TPU profiler, but I could not, since the model doesn't output anything to keep track of.
So my question is: what are the best practices for training the Trax Transformer on TPUs?
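For a sense of scale, here is a rough back-of-the-envelope estimate (my own arithmetic, not anything from the Trax code) of just one of the attention-logits tensors with the shape that shows up in the log below:
# Size of one attention-logits tensor f32[64, 8, 512, 512] (allocations 8-9 in the log below).
batch, n_heads, max_len = 64, 8, 512
bytes_per_f32 = 4
attn_bytes = batch * n_heads * max_len * max_len * bytes_per_f32
print(attn_bytes / 2**20, 'MiB')  # -> 512.0 MiB for a single layer, before any padding
Several tensors of this size (plus the bernoulli dropout masks, which the log says pad out 4x) are live at the same time, so it seems plausible that max_len 512 fills the 8G of HBM long before batch sizes like 2048 become feasible.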
RuntimeError: Resource exhausted: Ran out of memory in memory space hbm. Used 12.64G of 8.00G hbm. Exceeded hbm capacity by 4.64G.
Total hbm usage >= 12.64G:
reserved 529.00M
program 12.13G
arguments unknown size
Output size unknown.
Program hbm requirement 12.13G:
reserved 4.0K
global 196.0K
HLO temp 12.13G (58.5% utilization: Unpadded (7.09G) Padded (12.12G), 0.1% fragmentation (10.34M))
Largest program allocations in hbm:
1. Size: 937.50M
Operator: op_type="dot_general" op_name="pmap(mapped_update)/dot_general[ dimension_numbers=(((2,), (0,)), ((), ()))\n precision=None ]"
Shape: f32[64,128,30000]{1,2,0:T(8,128)}
Unpadded size: 937.50M
XLA label: %fusion.1546 = (f32[64,128]{1,0:T(8,128)}, f32[64,128]{1,0:T(8,128)}, f32[64,128,30000]{1,2,0:T(8,128)}) fusion(f32[64,128]{1,0:T(8,128)} %fusion.9002.remat3, f32[64,128]{1,0:T(8,128)} %fusion.28213.remat, f32[30000]{0:T(1024)} %get-tuple-element.4759, f32...
Allocation type: HLO temp
==========================
2. Size: 512.00M
Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
Shape: pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)}
Unpadded size: 128.00M
Extra memory due to padding: 384.00M (4.0x expansion)
XLA label: %copy.249 = pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)} copy(pred[64,8,512,512]{3,2,1,0:T(8,128)E(32)} %reshape.4769), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
Allocation type: HLO temp
==========================
3. Size: 512.00M
Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
Shape: pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)}
Unpadded size: 128.00M
Extra memory due to padding: 384.00M (4.0x expansion)
XLA label: %copy.248 = pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)} copy(pred[64,8,512,512]{3,2,1,0:T(8,128)E(32)} %reshape.4765), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
Allocation type: HLO temp
==========================
4. Size: 512.00M
Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
Shape: pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)}
Unpadded size: 128.00M
Extra memory due to padding: 384.00M (4.0x expansion)
XLA label: %copy.247 = pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)} copy(pred[64,8,512,512]{3,2,1,0:T(8,128)E(32)} %reshape.4761), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
Allocation type: HLO temp
==========================
5. Size: 512.00M
Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
Shape: pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)}
Unpadded size: 128.00M
Extra memory due to padding: 384.00M (4.0x expansion)
XLA label: %copy.246 = pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)} copy(pred[64,8,512,512]{3,2,1,0:T(8,128)E(32)} %reshape.4757), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
Allocation type: HLO temp
==========================
6. Size: 512.00M
Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
Shape: pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)}
Unpadded size: 128.00M
Extra memory due to padding: 384.00M (4.0x expansion)
XLA label: %copy.245 = pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)} copy(pred[64,8,512,512]{3,2,1,0:T(8,128)E(32)} %reshape.4753), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
Allocation type: HLO temp
==========================
7. Size: 512.00M
Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
Shape: pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)}
Unpadded size: 128.00M
Extra memory due to padding: 384.00M (4.0x expansion)
XLA label: %copy.244 = pred[64,8,512,512]{2,3,1,0:T(8,128)E(32)} copy(pred[64,8,512,512]{3,2,1,0:T(8,128)E(32)} %reshape.4747), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
Allocation type: HLO temp
==========================
8. Size: 512.00M
Operator: op_type="dot_general" op_name="pmap(mapped_update)/dot_general[ dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1)))\n precision=None ]"
Shape: f32[64,8,512,512]{2,3,1,0:T(8,128)}
Unpadded size: 512.00M
XLA label: %fusion.2186 = (f32[64,8,512]{2,1,0:T(8,128)}, f32[64,8,512]{2,1,0:T(8,128)}, f32[64,8,512,512]{2,3,1,0:T(8,128)}) fusion(f32[64,8,512]{2,1,0:T(8,128)} %fusion.2753, pred[64,512]{1,0:T(8,128)E(32)} %get-tuple-element.4382, f32[64,8,512]{2,1,0:T(8,128)} %fu...
Allocation type: HLO temp
==========================
9. Size: 512.00M
Operator: op_type="dot_general" op_name="pmap(mapped_update)/dot_general[ dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1)))\n precision=None ]"
Shape: f32[64,8,512,512]{2,3,1,0:T(8,128)}
Unpadded size: 512.00M
XLA label: %convolution-base-dilated.117.remat5 = f32[64,8,512,512]{2,3,1,0:T(8,128)} convolution(bf16[64,512,8,64]{1,3,2,0:T(8,128)(2,1)} %bitcast.312, bf16[64,512,8,64]{1,3,2,0:T(8,128)(2,1)} %bitcast.314), window={size=64x8 stride=63x7 lhs_dilate=64x8}, dim_labels...
Allocation type: HLO temp
==========================
10. Size: 256.00M
Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
Shape: pred[64,512,2048]{2,1,0:T(8,128)E(32)}
Unpadded size: 64.00M
Extra memory due to padding: 192.00M (4.0x expansion)
XLA label: %reshape.4751 = pred[64,512,2048]{2,1,0:T(8,128)E(32)} reshape(pred[67108864]{0:T(1024)E(32)} %fusion.3020), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
Allocation type: HLO temp
==========================
11. Size: 256.00M
Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
Shape: pred[64,512,2048]{2,1,0:T(8,128)E(32)}
Unpadded size: 64.00M
Extra memory due to padding: 192.00M (4.0x expansion)
XLA label: %reshape.4755 = pred[64,512,2048]{2,1,0:T(8,128)E(32)} reshape(pred[67108864]{0:T(1024)E(32)} %fusion.3021), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
Allocation type: HLO temp
==========================
12. Size: 256.00M
Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
Shape: pred[64,512,2048]{2,1,0:T(8,128)E(32)}
Unpadded size: 64.00M
Extra memory due to padding: 192.00M (4.0x expansion)
XLA label: %reshape.4759 = pred[64,512,2048]{2,1,0:T(8,128)E(32)} reshape(pred[67108864]{0:T(1024)E(32)} %fusion.3022), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
Allocation type: HLO temp
==========================
13. Size: 256.00M
Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
Shape: pred[64,512,2048]{2,1,0:T(8,128)E(32)}
Unpadded size: 64.00M
Extra memory due to padding: 192.00M (4.0x expansion)
XLA label: %reshape.4763 = pred[64,512,2048]{2,1,0:T(8,128)E(32)} reshape(pred[67108864]{0:T(1024)E(32)} %fusion.3023), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
Allocation type: HLO temp
==========================
14. Size: 256.00M
Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
Shape: pred[64,512,2048]{2,1,0:T(8,128)E(32)}
Unpadded size: 64.00M
Extra memory due to padding: 192.00M (4.0x expansion)
XLA label: %reshape.4767 = pred[64,512,2048]{2,1,0:T(8,128)E(32)} reshape(pred[67108864]{0:T(1024)E(32)} %fusion.3024), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
Allocation type: HLO temp
==========================
15. Size: 256.00M
Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
Shape: pred[64,512,2048]{2,1,0:T(8,128)E(32)}
Unpadded size: 64.00M
Extra memory due to padding: 192.00M (4.0x expansion)
XLA label: %reshape.4771 = pred[64,512,2048]{2,1,0:T(8,128)E(32)} reshape(pred[67108864]{0:T(1024)E(32)} %fusion.3025), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
Allocation type: HLO temp
==========================
16. Size: 128.00M
Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
Shape: pred[64,8,128,512]{3,2,1,0:T(8,128)E(32)}
Unpadded size: 32.00M
Extra memory due to padding: 96.00M (4.0x expansion)
XLA label: %reshape.4740 = pred[64,8,128,512]{3,2,1,0:T(8,128)E(32)} reshape(pred[33554432]{0:T(1024)E(32)} %fusion.2426), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
Allocation type: HLO temp
==========================
17. Size: 128.00M
Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
Shape: pred[64,8,128,512]{3,2,1,0:T(8,128)E(32)}
Unpadded size: 32.00M
Extra memory due to padding: 96.00M (4.0x expansion)
XLA label: %reshape.4745 = pred[64,8,128,512]{3,2,1,0:T(8,128)E(32)} reshape(pred[33554432]{0:T(1024)E(32)} %fusion.2431), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
Allocation type: HLO temp
==========================
18. Size: 128.00M
Operator: op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"
Shape: pred[64,8,128,512]{3,2,1,0:T(8,128)E(32)}
Unpadded size: 32.00M
Extra memory due to padding: 96.00M (4.0x expansion)
XLA label: %reshape.4791 = pred[64,8,128,512]{3,2,1,0:T(8,128)E(32)} reshape(pred[33554432]{0:T(1024)E(32)} %fusion.2427), metadata={op_type="lt" op_name="pmap(mapped_update)/jit(_bernoulli)/lt"}
Allocation type: HLO temp
==========================
19. Size: 128.00M
Operator: op_type="dot_general" op_name="pmap(mapped_update)/dot_general[ dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1)))\n precision=None ]"
Shape: f32[64,8,128,512]{3,2,1,0:T(8,128)}
Unpadded size: 128.00M
XLA label: %fusion.4304 = (f32[64,8,128]{2,1,0:T(8,128)}, f32[64,8,128,512]{3,2,1,0:T(8,128)}) fusion(pred[64,512]{1,0:T(8,128)E(32)} %get-tuple-element.4384, bf16[64,128,8,64]{1,3,2,0:T(8,128)(2,1)} %bitcast.85, bf16[64,512,8,64]{1,3,2,0:T(8,128)(2,1)} %bitcast.389)...
Allocation type: HLO temp
==========================
20. Size: 128.00M
Operator: op_type="dot_general" op_name="pmap(mapped_update)/dot_general[ dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1)))\n precision=None ]"
Shape: f32[64,8,128,512]{3,2,1,0:T(8,128)}
Unpadded size: 128.00M
XLA label: %fusion.4305 = (f32[64,8,128]{2,1,0:T(8,128)}, f32[64,8,128,512]{3,2,1,0:T(8,128)}) fusion(pred[64,512]{1,0:T(8,128)E(32)} %get-tuple-element.4384, bf16[64,128,8,64]{1,3,2,0:T(8,128)(2,1)} %bitcast.464, bf16[64,512,8,64]{1,3,2,0:T(8,128)(2,1)} %bitcast.466...
Allocation type: HLO temp
==========================
Fails on import trax
Singularity image bootstrapped from "docker ubuntu:latest"
OS: Ubuntu 18.04 LTS
$ pip freeze | grep tensor
mesh-tensorflow==0.1.7
tensor2tensor==1.15.2
tensorboard==2.0.2
tensorflow==2.0.0
tensorflow-datasets==1.3.2
tensorflow-estimator==2.0.1
tensorflow-gan==2.0.0
tensorflow-hub==0.7.0
tensorflow-metadata==0.15.1
tensorflow-probability==0.7.0
$ pip freeze | grep jax
jax==0.1.55
jaxlib==0.1.37
$ python -V
Python 3.6.9
python3 -c 'import trax'
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/venv/lib/python3.6/site-packages/trax/__init__.py", line 21, in <module>
from trax import learning_rate as lr
File "/venv/lib/python3.6/site-packages/trax/learning_rate.py", line 294, in <module>
from trax.rl import online_tune
File "/venv/lib/python3.6/site-packages/trax/rl/__init__.py", line 24, in <module>
from trax.rl import simulated_env_problem
File "/venv/lib/python3.6/site-packages/trax/rl/simulated_env_problem.py", line 29, in <module>
from trax import trainer_lib
File "/venv/lib/python3.6/site-packages/trax/trainer_lib.py", line 41, in <module>
from trax import jaxboard
File "/venv/lib/python3.6/site-packages/trax/jaxboard.py", line 38, in <module>
from tensorflow import HistogramProto
ImportError: cannot import name 'HistogramProto'
I try to run the Reformer model with the configuration reformer_enwik8.gin and get an error: Can't find ptxas binary in ${CUDA_DIR}/bin.
...
OS: Ubuntu 18.04.3 LTS
$ pip freeze | grep tensor
mesh-tensorflow==0.1.7
tensor2tensor==1.15.4
tensorboard==1.15.0
tensorflow-datasets==1.3.2
tensorflow-estimator==1.15.1
tensorflow-gan==2.0.0
tensorflow-gpu==1.15.0
tensorflow-hub==0.7.0
tensorflow-metadata==0.15.2
tensorflow-probability==0.7.0
tensorrt==6.0.1.4
$ pip freeze | grep jax
jax==0.1.57
jaxlib==0.1.37
$ python -V
python 3.6.8
$ nvcc --version
cuda10.0 (/usr/local/cuda --> /usr/local/cuda-10.0, but /usr/local/cuda-10.1 exists)
GPU: 2080TI * 4
# Steps to reproduce:
Just run trainer.py in trax/trax using the configuration reformer_enwik8.gin.
# Error logs:
[[[ I removed some routine dataset info here ]]]
I0119 09:32:55.178084 140128464549696 problem.py:651] Reading data files from /root/tensorflow_datasets/t2t_enwik8_l65k/enwik8_l65k-dev*
INFO:tensorflow:partition: 0 num_data_files: 1
I0119 09:32:55.179685 140128464549696 problem.py:677] partition: 0 num_data_files: 1
I0119 09:32:56.124050 140128464549696 inputs.py:443] Heuristically setting bucketing to False based on shapes of target tensors.
I0119 09:32:56.131589 140128464549696 inputs.py:443] Heuristically setting bucketing to False based on shapes of target tensors.
I0119 09:32:56.136316 140128464549696 inputs.py:443] Heuristically setting bucketing to False based on shapes of target tensors.
I0119 09:33:05.191175 140128464549696 trainer_lib.py:754] Model loaded from ../checkpoints/model.pkl at step 0
Model loaded from ../checkpoints/model.pkl at step 0
I0119 09:33:05.192780 140128464549696 trainer_lib.py:754] Step 0: Starting training using 1 devices
Step 0: Starting training using 1 devices
I0119 09:33:05.194077 140128464549696 trainer_lib.py:754] Step 0: Total number of trainable weights: 215865602
Step 0: Total number of trainable weights: 215865602
2020-01-19 09:33:09.105234: E external/org_tensorflow/tensorflow/core/platform/default/subprocess.cc:208] Start cannot fork() child process: Cannot allocate memory
2020-01-19 09:33:09.105464: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:73] Can't find ptxas binary in ${CUDA_DIR}/bin. Will back to the GPU driver for PTX -> sass compilation. This is OK so long as you don't see a warning below about an out-of-date driver version.
2020-01-19 09:33:09.105489: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:74] Searched for CUDA in the following directories:
2020-01-19 09:33:09.105517: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77] ./cuda_sdk_lib
2020-01-19 09:33:09.105532: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77] /usr/local/cuda
2020-01-19 09:33:09.105554: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:77] .
2020-01-19 09:33:09.105567: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:79] You can choose the search directory by setting xla_gpu_cuda_data_dir in HloModule's DebugOptions. For most apps, setting the environment variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work.
2020-01-19 09:33:09.193084: E external/org_tensorflow/tensorflow/core/platform/default/subprocess.cc:208] Start cannot fork() child process: Cannot allocate memory
[... the same "Cannot allocate memory" error and ptxas warning block repeats many more times; repetitions omitted ...]
2020-01-19 09:33:31.094227: E external/org_tensorflow/tensorflow/core/platform/default/subprocess.cc:208] Start cannot fork() child process: Cannot allocate memory
2020-01-19 09:33:31.094430: W external/org_tensorflow/tensorflow/stream_executor/gpu/redzone_allocator.cc:312] Internal: Failed to launch ptxas
Relying on driver to perform ptx compilation. This message will be only logged once.
2020-01-19 09:33:31.177827: E external/org_tensorflow/tensorflow/core/platform/default/subprocess.cc:208] Start cannot fork() child process: Cannot allocate memory
2020-01-19 09:33:31.255405: E external/org_tensorflow/tensorflow/core/platform/default/subprocess.cc:208] Start cannot fork() child process: Cannot allocate memory
Traceback (most recent call last):
File "/home/xxx/pycharm_proj/trax/trax/trainer.py", line 195, in <module>
app.run(main)
File "/usr/local/lib/python3.6/dist-packages/absl/app.py", line 299, in run
_run_main(main, args)
File "/usr/local/lib/python3.6/dist-packages/absl/app.py", line 250, in _run_main
sys.exit(main(argv))
File "/home/xxx/pycharm_proj/trax/trax/trainer.py", line 189, in main
trainer_lib.train(output_dir=output_dir)
File "/usr/local/lib/python3.6/dist-packages/gin/config.py", line 1078, in gin_wrapper
utils.augment_exception_message_and_reraise(e, err_str)
File "/usr/local/lib/python3.6/dist-packages/gin/utils.py", line 49, in augment_exception_message_and_reraise
six.raise_from(proxy.with_traceback(exception.__traceback__), None)
File "<string>", line 3, in raise_from
File "/usr/local/lib/python3.6/dist-packages/gin/config.py", line 1055, in gin_wrapper
return fn(*new_args, **new_kwargs)
File "/home/xxx/pycharm_proj/trax/trax/supervised/trainer_lib.py", line 641, in train
trainer.train_epoch(epoch_steps, eval_steps)
File "/home/xxx/pycharm_proj/trax/trax/supervised/trainer_lib.py", line 305, in train_epoch
self.train_step(batch)
File "/home/xxx/pycharm_proj/trax/trax/supervised/trainer_lib.py", line 337, in train_step
self._step, opt_state, batch, self._model_state, self._rngs)
File "/usr/local/lib/python3.6/dist-packages/jax/api.py", line 149, in f_jitted
out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend)
File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 602, in call_bind
outs = primitive.impl(f, *args, **params)
File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py", line 442, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, *map(arg_spec, args))
File "/usr/local/lib/python3.6/dist-packages/jax/linear_util.py", line 223, in memoized_fun
ans = call(fun, *args)
File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py", line 499, in _xla_callable
compiled = built.Compile(compile_options=options, backend=xb.get_backend(backend))
File "/usr/local/lib/python3.6/dist-packages/jaxlib/xla_client.py", line 609, in Compile
return backend.compile(self.computation, compile_options)
File "/usr/local/lib/python3.6/dist-packages/jaxlib/xla_client.py", line 161, in compile
compile_options.device_assignment)
RuntimeError: Internal: Failed to launch ptxas
Is there a way to incorporate custom input embeddings while retaining the library's Transformer abstractions?
Right now training and evaluation are interleaved. This means that increasing the eval steps or frequency slows down training.
It would be great to have the option to spawn a separate process for evaluation.
Hello, do you have any tutorial on how to extract the reversible self-attention layer as a tf.layer? Is that possible? Would it be possible to just take the self-attention layer and integrate it into BERT? It would be amazing! Any tutorial on how to integrate JAX with TF would also be great. Thanks!
Hi, first of all, thanks to everyone for this great project.
I just want to point out that 4-space and 2-space indentation are currently used interchangeably, even within the same file (trainer.py).
I know code style is a sensitive topic, so I just want to ask whether it would be possible to standardize this (on 4 spaces), or whether there are good reasons not to that I am not aware of?
Thanks.
Hello,
please help us understand where things are heading now that T2T is being discontinued. We have a lot of scripts interacting with the T2T ecosystem. Please help us understand what the outlook is for Trax as a replacement. Will there be migration help? How does it integrate with TPUs on GCP? I couldn't find any announcements about what T2T users should expect.
Thx
Phillip
I hit an input-layer issue while testing Resnet50 with CIFAR-10 as a toy example:
ValueError: number of inputs (2) to Serial.forward less than n_in (3)
https://gist.github.com/rodrigobaron/e0874af5e8e32b18411fa4bb30e49174
...
Jax==0.1.52
Trax==1.2.2
Tensorflow==1.15.0
OS: Linux (Google Colab)
# Steps to reproduce:
Import https://gist.github.com/rodrigobaron/e0874af5e8e32b18411fa4bb30e49174 on Google Colab and run with GPU runtime.
# Error logs:
LayerError: Exception passing through layer Serial (in _forward_internal):
layer created in file [...]/trax/supervised/trainer_lib.py, line 674
layer input shapes: (ShapeDtype{shape:(32, 32, 32, 3), dtype:float32}, ShapeDtype{shape:(32, 10), dtype:float32})
File [...]/trax/layers/combinators.py, line 59, in forward_with_state
self._validate_forward_inputs(xs)
File [...]/trax/layers/combinators.py, line 137, in _validate_forward_inputs
' ({})'.format(len(xs), self.n_in))
ValueError: number of inputs (2) to Serial.forward less than n_in (3)
It is extremely difficult to debug a nested Serial layer's stack, especially when I'm using layers like Branch and SerialWithSideOutputs (layers that are built from other basic combinators), because the error stack shows them all as just a Serial layer.
I've made some small changes to base.Layer and combinators.py so that base.Layer supports overriding the layer name (self.__class__.__name__) with a user-supplied name, and it has helped a lot when debugging large models.
# base.py
class Layer(object):
  def __init__(self, n_in=1, n_out=1, name=None):  # Added name
    self._name = name or self.__class__.__name__
    ...
  # Replace self.__class__.__name__ in LayerError calls with self._name

# combinators.py
def Branch(*layers, name='Branch'):
  return Serial(..., name=name)
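For illustration, here is a small, self-contained sketch of the pattern being proposed (it deliberately does not use Trax's real base.Layer; the class and function bodies below are only stand-ins):

# Standalone sketch: user-overridable layer names that surface in error messages.
class Layer(object):
  def __init__(self, n_in=1, n_out=1, name=None):
    self._name = name or self.__class__.__name__
    self.n_in, self.n_out = n_in, n_out

class Serial(Layer):
  def __init__(self, *layers, name=None):
    super().__init__(name=name)
    self._layers = layers

  def forward(self, xs):
    for layer in self._layers:
      try:
        xs = layer.forward(xs)
      except Exception as e:
        # The message now names e.g. 'QKBranch' instead of a generic 'Serial'.
        raise RuntimeError('Exception in layer %s' % layer._name) from e
    return xs

def Branch(*layers, name='Branch'):
  # Branch is still built out of Serial, but keeps its own, more informative name.
  return Serial(*layers, name=name)

block = Serial(Branch(name='QKBranch'), name='AttentionBlock')

With this, a failure inside the branch would report 'QKBranch' rather than yet another anonymous Serial, which is what makes large nested models easier to localize.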
If this seems OK, I'd be glad to make a PR for this.
Hello. I've been playing around with both the T2T and Trax libraries for a while. Since Trax has several bugs during inference, I've decided to switch to T2T. However, it seems to me that the Transformer in Tensor2Tensor is not the same as the one in Trax.
In Tensor2Tensor I create my Transformer model this way:
hparams_my = {
    'batch_size': 128,
    'batch_shuffle_size': 128,
    'use_fixed_batch_size': True,
    'num_hidden_layers': 1,
    'max_input_seq_length': 252,
    'max_target_seq_length': 252,
    'max_length': 252,
    'symbol_modality_num_shards': 1,
    'filter_size': 2048,
    'dropout': 0.1
}
In Trax:
Transformer(input_vocab_size=127,
            output_vocab_size=127,
            d_model=512,
            d_ff=2048,
            n_encoder_layers=1,
            n_decoder_layers=1,
            n_heads=8,
            dropout=0.1,
            max_len=2048,
            mode='train',
            ff_activation=tl.Relu)
After I run training with T2T, I get this message:
(btw, 2 times)
INFO:tensorflow:Trainable Variables Total size: 7433728
INFO:tensorflow:Trainable Variables Total size: 7433728
Whereas in Trax, after I call trainer.print_n_weights(), I get:
Step 0: Total number of trainable weights: 7614591
I would like to note that when I train my Transformer model in Trax, it converges almost immediately (which makes sense given the nature of the task: simple sequence copying with small changes), while with T2T the loss stays around 3-4 and there is no convergence at all.
Could anybody tell me what I have to do? It seems like a common problem with T2T Transformer convergence, but I want to emphasise that the Transformer in Trax appears to be a different model...
Apologies in advance for a question that might seem off topic.
Given how young JAX still is, it seems there is no convergence yet on a high-level library for deep and reinforcement learning.
Do you personally have any plan to merge with Flax?
The question comes from the desire to contribute efficiently to these libraries, where "efficiently" means with a low probability of the work being superseded.
I'd like to see a proper comparison against a Transformer (GPT-2) on text generation with the same number of parameters: how the Reformer compares when trained on sequences of the same length, and when it uses a bigger context window.
Thanks a lot for your unique contribution, but substantial empirical and qualitative evidence is still lacking.
...
OS: Ubuntu 18.04/NVDIA DGX Station (Desktop)
$ pip freeze | grep tensor
bert-tensorflow==1.0.1
mesh-tensorflow==0.1.9
tensor2tensor==1.15.4
tensorboard==2.0.2
tensorboardX==1.9
tensorflow-datasets==1.3.2
tensorflow-estimator==2.0.1
tensorflow-gan==2.0.0
tensorflow-gpu==2.0.0
tensorflow-hub==0.7.0
tensorflow-metadata==0.15.1
tensorflow-probability==0.7.0
$ pip freeze | grep jax
jax==0.1.57
jaxlib==0.1.37
$ python -V
Python 3.7.5
# Steps to reproduce:
Ran the text generation notebook code on my own machine:
https://colab.research.google.com/github/google/trax/blob/master/trax/models/reformer/text_generation.ipynb#scrollTo=PdAwmpS220ub
import numpy as onp
import trax

# PAD_AMOUNT and IDS (the tokenized text) are defined earlier in the notebook.
def my_inputs(n_devices):
  while True:
    inputs = []
    mask = []
    pad_amounts = onp.random.choice(PAD_AMOUNT, n_devices)
    for i in range(n_devices):
      inputs.append(onp.pad(IDS, (pad_amounts[i], PAD_AMOUNT - pad_amounts[i]),
                            mode='constant'))
      mask.append(onp.pad(onp.ones_like(IDS, dtype=onp.float32),
                          (pad_amounts[i], PAD_AMOUNT - pad_amounts[i]),
                          mode='constant'))
    inputs = onp.stack(inputs)
    mask = onp.stack(mask)
    yield (inputs, inputs, mask)  # Inputs and targets are the same; mask weights the loss.

print("(device count, tokens per device) = ",
      next(my_inputs(trax.math.device_count()))[0].shape)
...
(device count, tokens per device) = (1, 524288)
/home/sn/anaconda3/envs/py37/lib/python3.7/site-packages/jax/lib/xla_bridge.py:119: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
flake8 testing of https://github.com/google/trax on Python 3.8.0
$ flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
./trax/tf_numpy/jax_tests/lax_numpy_test.py:588:12: F821 undefined name 'FLAGS'
if not FLAGS.jax_enable_x64 and any(
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:740:17: F821 undefined name 'dtypes'
tol_spec = {dtypes.bfloat16: 3e-1, onp.float16: 0.15}
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:1574:13: F821 undefined name 'dtypes'
dtype = dtypes.canonicalize_dtype(dtype)
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:1618:13: F821 undefined name 'api'
csame = api.jit(same)
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:1651:11: F821 undefined name 'api'
fun = api.jit(fun)
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:1663:11: F821 undefined name 'api'
fun = api.jit(fun)
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:1671:6: F821 undefined name 'api'
@api.jit
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:1691:42: F821 undefined name 'api'
self.assertRaises(TypeError, lambda: api.jit(g)(x, y))
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:1692:42: F821 undefined name 'api'
self.assertRaises(TypeError, lambda: api.jit(f)(x, y))
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:1697:6: F821 undefined name 'api'
@api.jit
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:1705:6: F821 undefined name 'api'
@api.jit
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:1723:12: F821 undefined name 'api'
cfoo = api.jit(foo)
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:1820:9: F821 undefined name 'lax'
x = lax.add(lnp.eye(3, dtype=lnp.float_), 0.)
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2006:23: F821 undefined name 'dtypes'
dtype = onp.dtype(dtypes.canonicalize_dtype(dtype)).type
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2191:14: F821 undefined name 'api'
result = api.grad(test_fail)(x)
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2220:39: F821 undefined name 'api'
self.assertAllClose(onp.int64(7), api.jit(lambda x: x)(onp.longlong(7)),
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2249:50: F821 undefined name 'lax'
self.assertTrue(type(lnp.arange(77)) == type(lax.iota(onp.int32, 77)))
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2255:26: F821 undefined name 'lax'
type(lax.iota(onp.int32, 77)))
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2274:9: F821 undefined name 'api'
f = api.grad(lambda x: lnp.sum(lnp.tanh(x)))
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2284:11: F821 undefined name 'jax'
y = jax.ops.index_add(onp.ones(10,), [2, 4, 5], u)
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2287:14: F821 undefined name 'lax'
return lax.tie_in(y, 7.)
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2289:40: F821 undefined name 'api'
self.assertAllClose(onp.zeros(3,), api.grad(f)(onp.ones(3,)),
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2295:9: F821 undefined name 'api'
f = api.grad(lambda x: lnp.sum(1 / (1 + lnp.exp(-x))))
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2311:23: F821 undefined name 'dtypes'
dtype = onp.dtype(dtypes.canonicalize_dtype(dtype)).type
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2333:14: F821 undefined name 'api'
@partial(api.jit, static_argnums=(1,))
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2520:13: F821 undefined name 'FLAGS'
not FLAGS.jax_enable_x64):
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2604:19: F821 undefined name 'FLAGS'
prev_flag = FLAGS.jax_numpy_rank_promotion
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2605:7: F821 undefined name 'FLAGS'
FLAGS.jax_numpy_rank_promotion = "allow"
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2608:7: F821 undefined name 'FLAGS'
FLAGS.jax_numpy_rank_promotion = prev_flag
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2611:19: F821 undefined name 'FLAGS'
prev_flag = FLAGS.jax_numpy_rank_promotion
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2612:7: F821 undefined name 'FLAGS'
FLAGS.jax_numpy_rank_promotion = "raise"
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2615:7: F821 undefined name 'FLAGS'
FLAGS.jax_numpy_rank_promotion = prev_flag
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2618:19: F821 undefined name 'FLAGS'
prev_flag = FLAGS.jax_numpy_rank_promotion
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2619:7: F821 undefined name 'FLAGS'
FLAGS.jax_numpy_rank_promotion = "warn"
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2633:7: F821 undefined name 'FLAGS'
FLAGS.jax_numpy_rank_promotion = prev_flag
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2638:6: F821 undefined name 'api'
@api.jit
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2643:6: F821 undefined name 'api'
@api.jit
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2656:15: F821 undefined name 'jax'
y = y + jax.grad(lambda z: lnp.sum(lnp.maximum(z, 0.)))(x)
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2659:19: F821 undefined name 'lax'
f = lambda y: lax.fori_loop(0, 5, body, (y, y))
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2660:15: F821 undefined name 'linear_util'
wrapped = linear_util.wrap_init(f)
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2661:10: F821 undefined name 'partial_eval'
pv = partial_eval.PartialVal(
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2662:8: F821 undefined name 'jax'
(jax.ShapedArray((3, 4), onp.float32), jax.core.unit))
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2662:46: F821 undefined name 'jax'
(jax.ShapedArray((3, 4), onp.float32), jax.core.unit))
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2663:20: F821 undefined name 'partial_eval'
_, _, consts = partial_eval.trace_to_jaxpr(wrapped, [pv])
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2709:15: F821 undefined name 'lax'
HIGHEST = lax.Precision.HIGHEST
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2786:20: F821 undefined name 'dtypes'
return lnp.finfo(dtypes.canonicalize_dtype(dtype)).bits
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2803:5: F821 undefined name 'check_grads'
check_grads(op, args, order, ["fwd", "rev"], tol, tol)
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2813:5: F821 undefined name 'check_grads'
check_grads(op, (special_value,), order, ["fwd", "rev"],
^
./trax/tf_numpy/jax_tests/lax_numpy_test.py:2825:5: F821 undefined name 'check_grads'
check_grads(f, (1.,), order=1)
^
49 F821 undefined name 'FLAGS'
49
https://flake8.pycqa.org/en/latest/user/error-codes.html
On the flake8 test selection: this PR does not focus on "style violations" (the majority of flake8 error codes, which psf/black can autocorrect). Instead, these tests focus on runtime safety and correctness.
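For context, here is a tiny illustrative example (not taken from the repository) of what F821 catches and why it belongs with the safety checks rather than the style ones: the flagged name would raise a NameError the moment the line executes.

# Illustration of flake8's F821 (undefined name): 'FLAGS' is never imported
# or defined here, so calling this function raises NameError at runtime.
def uses_undefined_flag():
  return FLAGS.jax_enable_x64  # F821 undefined name 'FLAGS'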
OS: <your answer here>
$ pip freeze | grep tensor
# your output here
$ pip freeze | grep jax
# your output here
$ python -V
# your output here
# Steps to reproduce:
...
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# Error logs:
...
Hello, I'm facing a problem while trying to work with the Trax Trainer class. I have loaded my dataset from a TFRecords file and created a Dataset instance using the tf.data Dataset API. Then I tried to feed my dataset to the Trax trainer, but got the error below. Could you please tell me how to accomplish this? I haven't found anything explaining how to use the Dataset API with the Trax library. Thanks!
OS: Google Colab notebook
Pass a dataset iterator to the trax Inputs class:
TypeError: Argument '[[ 2 16 9 ... 0 0 0]
[ 2 16 9 ... 0 0 0]
[ 2 16 9 ... 0 0 0]
...
[ 2 70 21 ... 0 0 0]
[ 2 16 9 ... 0 0 0]
[ 2 47 14 ... 0 0 0]]' of type <class 'tensorflow.python.framework.ops.EagerTensor'> is not a valid JAX type
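For what it's worth, one possible workaround is to convert the tf.data pipeline to numpy before handing it to Trax, since the trainer works on plain numpy arrays rather than EagerTensors. This is only a minimal sketch, not from the issue; make_dataset and the 'inputs'/'targets' feature names are placeholders for the user's own pipeline:

import numpy as np
import tensorflow_datasets as tfds
import trax

def dataset_to_trax_stream(dataset):
  # tfds.as_numpy turns a tf.data.Dataset of EagerTensors into numpy arrays.
  for example in tfds.as_numpy(dataset):
    yield np.asarray(example['inputs']), np.asarray(example['targets'])

# make_dataset(...) is a hypothetical helper returning the batched tf.data.Dataset.
# trax.supervised.Inputs expects a function of n_devices that returns a generator.
inputs = trax.supervised.Inputs(lambda _: dataset_to_trax_stream(make_dataset('train')))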
I'm trying to train the ReformerLM model from this tutorial
and can't get the Reformer to use the GPU.
OS: ubuntu 18.04
$ pip freeze | grep tensor
mesh-tensorflow==0.1.9
tensor2tensor==1.15.4
tensorboard==2.1.0
tensorflow==2.1.0
tensorflow-datasets==2.0.0
tensorflow-estimator==2.1.0
tensorflow-gan==2.0.0
tensorflow-hub==0.7.0
tensorflow-metadata==0.21.1
tensorflow-probability==0.7.0
$ pip freeze | grep jax
jax==0.1.59
jaxlib==0.1.39
$ python -V
python -V
### For bugs: reproduction and error logs
pip install -q -U trax
pip install -q tensorflow
import trax
from trax.models import ReformerLM
import os
import numpy as np
import tensorflow as tf
import jax

def copy_task(batch_size, vocab_size, length):
  """This task is to copy a random string w, so the input is 0w0w."""
  while True:
    assert length % 2 == 0
    w_length = (length // 2) - 1
    w = np.random.randint(low=1, high=vocab_size-1,
                          size=(batch_size, w_length))
    zero = np.zeros([batch_size, 1], np.int32)
    loss_weights = np.concatenate([np.zeros((batch_size, w_length)),
                                   np.ones((batch_size, w_length+2))], axis=1)
    x = np.concatenate([zero, w, zero, w], axis=1)
    yield (x, x, loss_weights)  # Here inputs and targets are the same.

copy_inputs = trax.supervised.Inputs(lambda _: copy_task(16, 32, 10))

data_stream = copy_inputs.train_stream(1)
inputs, targets, mask = next(data_stream)
print("Inputs[0]: %s" % str(inputs[0]))
print("Targets[0]: %s" % str(targets[0]))
print("Mask[0]: %s" % str(mask[0]))

def tiny_transformer_lm(mode):
  return trax.models.TransformerLM(  # You can try trax_models.ReformerLM too.
      d_model=32, d_ff=128, n_layers=2, vocab_size=32, mode=mode)

output_dir = os.path.expanduser('~/train_dir/')
!rm -f ~/train_dir/model.pkl  # Remove old model.

trainer = trax.supervised.Trainer(
    model=tiny_transformer_lm,
    loss_fn=trax.layers.CrossEntropyLoss,
    optimizer=trax.optimizers.Adafactor,  # Change optimizer params here.
    lr_schedule=trax.lr.MultifactorSchedule,  # Change lr schedule here.
    inputs=copy_inputs,
    output_dir=output_dir,
    has_weights=True)  # Because we have loss mask, this API may change.

n_epochs = 3
train_steps = 500
eval_steps = 2
for _ in range(n_epochs):
  trainer.train_epoch(train_steps, eval_steps)
/opt/anaconda/envs/trax_3_7/lib/python3.7/site-packages/jax/lib/xla_bridge.py:122: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
If I call jax.devices() directly, I get:
[CpuDevice(id=0)]
but TensorFlow has no problem detecting the GPUs:
tf.config.list_physical_devices()
[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
PhysicalDevice(name='/physical_device:XLA_CPU:0', device_type='XLA_CPU'),
PhysicalDevice(name='/physical_device:XLA_GPU:0', device_type='XLA_GPU'),
PhysicalDevice(name='/physical_device:XLA_GPU:1', device_type='XLA_GPU')]
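As a quick sanity check (a minimal sketch using the standard JAX API, not taken from the issue), you can ask JAX which backend its jaxlib was built for; a CPU-only jaxlib wheel would explain the warning above even though TensorFlow sees both GPUs:

import jax
from jax.lib import xla_bridge

print(jax.devices())                       # [CpuDevice(id=0)] -> JAX only sees the CPU
print(xla_bridge.get_backend().platform)   # 'cpu' here suggests a CPU-only jaxlib wheel

If it does report 'cpu', the usual fix is to install the CUDA-enabled jaxlib wheel matching the local CUDA version, per the JAX installation instructions.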
How can one use Trax to implement an NER model using the Reformer?
Trax is a library for deep learning that focuses on sequence models and reinforcement learning. It combines performance with code clarity and maintained documentation and tests.
...
Sorry to bother you, I'll be brief. I don't think the "maintained documentation" part of the statement is true (yet?). I like the work and I respect every project that goes deep into neural network implementation, but I feel there is a critical lack of documentation for this project.
I was taking a look at Flax's Read the Docs site and, although the projects have different motivations, I believe there should be something similar for Trax.
Again, sorry to bother you. I wish the project all the luck and success.