
praxis's Introduction


What is Praxis?

Praxis is the layer library for Pax. Although Praxis is optimized for ML at scale, it also aims to be usable by other JAX-based ML projects.

Some examples of layers to be folded into Praxis are in the praxis/layers/ directory.

Copyright 2022 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
See the License for the specific language governing permissions and
limitations under the License.


praxis's Issues

Cross-layer attention weight sharing fails in different scopes

Hi. I am trying to share attention weights across layers, following this test case:

  def testSharedTemplateLayer(self):
    sub_params = pax_fiddle.Config(
        linears.FeedForward, input_dims=8, output_dims=8
    )
    # Only share the linear projection, not the entire FeedForward layer.
    sub_params.linear_tpl.shared_weight_layer_id = 'shared_weight'
    test_layer_p = pax_fiddle.Config(
        ...  # arguments truncated in the original post
    )
    x_in = jnp.ones([2, 8])
    with base_layer.JaxContext.new_context():
      prng_key = jax.random.PRNGKey(1234)
      layer = base_layer.instantiate(test_layer_p)
      init_vars = layer.init(prng_key, x_in)

However, the weights are not shared, because different scopes are used when the cache entry is set and when it is looked up.

  def lookup_shared_layer(
      self, root_scope: flax_core.Scope, shared_layer_id: str
  ) -> _SharedLayerCacheEntry | None:
    # The logging call name was lost in the original post; logging.info is assumed.
    logging.info('lookup_shared_layer called with id: %s in the scope of %s',
                 shared_layer_id, root_scope)
    return self._root_scope_to_shared_layers_map[root_scope][shared_layer_id]

  def set_shared_layer(self, root_scope: flax_core.Scope, shared_layer_id: str,
                       wrapper: _WrapperLayer, layer_hparams):
    logging.info('set_shared_layer called with id: %s in the scope of %s',
                 shared_layer_id, root_scope)
    existing = self.lookup_shared_layer(root_scope, shared_layer_id)
    assert existing is None
    self._root_scope_to_shared_layers_map[root_scope][
        shared_layer_id] = _SharedLayerCacheEntry(
            layer=wrapper.cld, hparams=layer_hparams.clone(), wrapper=wrapper)
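
The symptom described above can be reproduced with a toy model of such a cache. This is a minimal sketch, not Praxis code: `ToySharedLayerCache` and the plain `object()` scopes are hypothetical stand-ins, and the only assumption taken from the excerpt is that entries are keyed first by root scope and then by shared layer id.

```python
from collections import defaultdict


class ToySharedLayerCache:
    """Hypothetical stand-in for a cache keyed by (root scope, shared id)."""

    def __init__(self):
        self._by_scope = defaultdict(dict)

    def lookup(self, root_scope, shared_layer_id):
        # A miss returns None, mirroring the `| None` return annotation.
        return self._by_scope[root_scope].get(shared_layer_id)

    def set(self, root_scope, shared_layer_id, layer):
        assert self.lookup(root_scope, shared_layer_id) is None
        self._by_scope[root_scope][shared_layer_id] = layer


cache = ToySharedLayerCache()
scope_a, scope_b = object(), object()  # two distinct root scopes

cache.set(scope_a, 'shared_attn_0', 'attention layer created under scope_a')

same_scope_hit = cache.lookup(scope_a, 'shared_attn_0')
other_scope_hit = cache.lookup(scope_b, 'shared_attn_0')
print(same_scope_hit)   # the cached entry
print(other_scope_hit)  # None: a different root scope never sees the entry
```

If some transform causes a layer to be set up under a different root scope than the one used when the entry was stored, lookups would behave like `scope_b` here. Whether that is what happens in the reported setup is not established by this sketch.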

Specifically, I implemented a 24-layer Llama with StackedTransformer (not StackedTransformerRepeated) and set shared_weight_layer_id in an interleaved pattern with an interval of 6, inside the setup function of StackedTransformer. The main code differences in the following block are the share_interval field and the two lines in _layer_params that compute ii and set shared_weight_layer_id. I also set remat=True and checkpoint_policy=AutodiffCheckpointType.SAVE_NOTHING in StackedTransformer.

class StackedTransformer(base_layer.BaseLayer):
  use_cross_attention: bool = False
  mask_self_attention: bool = False
  num_layers: int = 0
  model_dims: int = 0
  hidden_dims: int = 0
  num_heads: int = 0
  dim_per_head: int | None = None
  dropout_prob: float = 0.0
  atten_dropout_prob: float | None = None
  residual_dropout_prob: float | None = None
  relu_dropout_prob: float | None = None
  residual_droppath_prob: float = 0.0
  input_dropout_prob: float = 0.0
  gating_func: str = 'top2'
  unadjusted_expert_capacity_factor: float = 2.0
  transformer_layer_params_tpl: LayerTpl | Sequence[LayerTpl] = template_field(
      ...  # default template truncated in the original post
  )
  packed_input: bool = False
  fold_padding_with_segment_mask: bool = False
  moe_layer_tpl: LayerTpl | None = template_field(TransformerFeedForwardMoe)
  num_experts: int = 0
  num_groups: int = 1
  min_group_size: int | None = None
  moe_layers: Sequence[int] | None = ()
  ngrammer_tpls: Sequence[LayerTpl] | None = template_field(None)
  remat: bool = False
  share_interval: int = 6
  checkpoint_policy: AutodiffCheckpointType = (
      ...  # default value truncated in the original post
  )

  def _clone_layer_params(self, layer_tpl: LayerTpl) -> LayerTpl:
    """Useful to let subclasses switch the class (e.g. Streaming version)."""
    return layer_tpl.clone()

  def setup(self) -> None:
    assert self.num_layers > 0
    assert self.model_dims > 0
    assert self.hidden_dims > 0
    assert self.num_heads > 0
    assert 0.0 <= self.dropout_prob < 1.0
    assert 0.0 <= self.input_dropout_prob < 1.0
    def _layer_params(i):
      """Construct i-th layer params."""
      if isinstance(self.transformer_layer_params_tpl, Sequence):
        factor = self.num_layers // len(self.transformer_layer_params_tpl)
        ii = i // factor
        p_i = self._clone_layer_params(self.transformer_layer_params_tpl[ii])
      else:
        p_i = self._clone_layer_params(self.transformer_layer_params_tpl)
      p_i.name = f'layer_{i}'
      ii = i % self.share_interval  # ii is in the range [0,5] when share_interval = 6
      p_i.tr_atten_tpl.shared_weight_layer_id = f'shared_attn_{ii}'
      p_i.use_cross_attention = self.use_cross_attention
      p_i.num_heads = self.num_heads
      p_i.dim_per_head = self.dim_per_head
      p_i.input_dims = self.model_dims
      p_i.packed_input = self.packed_input
      p_i.atten_dropout_prob = self.atten_dropout_prob or self.dropout_prob
      p_i.residual_dropout_prob = (
          self.residual_dropout_prob or self.dropout_prob
      )
      p_i.relu_dropout_prob = self.relu_dropout_prob or self.dropout_prob
      p_i.hidden_dims = self.hidden_dims
      if self.residual_droppath_prob > 0.0:
        p_i.residual_droppath_prob = (
            self.residual_droppath_prob * i / max(1, self.num_layers)
        )
      if self.moe_layers and i in self.moe_layers:
        assert self.num_experts > 0
        assert self.moe_layer_tpl is not None
        moe_p = self.moe_layer_tpl.clone()
        moe_p.num_experts = self.num_experts
        moe_p.num_groups = self.num_groups
        moe_p.min_group_size = self.min_group_size
        moe_p.gating_func = self.gating_func
        if moe_p.hidden_dims:
          # MoE hidden_dims could be different from FFN hidden_dims
          p_i.hidden_dims = moe_p.hidden_dims
        p_i.tr_fflayer_tpl = moe_p
      if self.ngrammer_tpls is not None:
        if self.ngrammer_tpls[i] is not None:
          p_i.ngrammer_tpl = self.ngrammer_tpls[i]
      return p_i

    if isinstance(self.transformer_layer_params_tpl, (list, tuple)):
      if self.num_layers % len(self.transformer_layer_params_tpl):
        raise ValueError(
            'num_layers should be divisible by transformer_layer_params_tpl'
        )

    layer_params = [_layer_params(i) for i in range(self.num_layers)]
    self.create_children('x_layers', layer_params)

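    if self.input_dropout_prob > 0.0:
      # The start of this call was lost in the original post; the child name
      # 'input_dropout' and the pax_fiddle.Config wrapper are assumed.
      self.create_child(
          'input_dropout',
          pax_fiddle.Config(
              stochastics.Dropout, keep_prob=1.0 - self.input_dropout_prob
          ),
      )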

Could you explain why the scopes are different when sharing attention weight across layers? Is it related to layer-wise checkpointing?
I would be grateful for a demonstration of how to share attention weights, or any other advice you might offer.
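
For reference, the interleaved pattern described in this issue (24 layers, `share_interval = 6`, id `shared_attn_{i % share_interval}`) groups layers as sketched below. The helper name `shared_attention_groups` is hypothetical and exists only to make the mapping explicit; it does not touch Praxis.

```python
from collections import defaultdict


def shared_attention_groups(num_layers, share_interval):
    """Maps each shared id to the list of layer indices that would use it."""
    groups = defaultdict(list)
    for i in range(num_layers):
        groups[f'shared_attn_{i % share_interval}'].append(i)
    return dict(groups)


groups = shared_attention_groups(num_layers=24, share_interval=6)
for shared_id, layers in groups.items():
    print(shared_id, layers)
```

With these settings there are six shared ids, and each one is requested by four layers spaced six apart, for example layers 0, 6, 12 and 18 for `shared_attn_0`.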

Praxis layers don't support user-specified collection names

We noticed that Praxis does not place variables in a user-specified collection; they stay in the default params collection instead. For example, in the following script we have a custom layer Foo whose variable input_scale should be created in the fp8_params collection. However, the output is XXX vars {'params': {'input_scale': None, 'w': None}}, meaning input_scale was put into the params collection. What we expect instead is XXX vars {'fp8_params': {'input_scale': None}, 'params': {'w': None}}.
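
The expected layout can be stated as a small grouping rule. The sketch below is plain Python and makes no claim about how Praxis or Flax route variables internally; `group_by_collection` and the `var_specs` mapping are hypothetical names used only to spell out the expectation.

```python
def group_by_collection(var_specs, default='params'):
    """Groups variable names by declared collection, falling back to `default`.

    var_specs maps a variable name to its list of declared collections.
    """
    tree = {}
    for name, collections in var_specs.items():
        for collection in (collections or [default]):
            tree.setdefault(collection, {})[name] = None
    return tree


# 'w' declares no collection; 'input_scale' asks for 'fp8_params'.
var_specs = {'w': [], 'input_scale': ['fp8_params']}
expected_tree = group_by_collection(var_specs)
print("XXX vars", expected_tree)
```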

The motivation is that we need to maintain a set of variables for fp8 support, and updating them requires a special process: (1) we use the custom_vjp mechanism to define how the grads of these variables are computed, and the grads are essentially the new variable values; (2) when applying grads, we use these grads to replace the variables. To facilitate this, we would like to declare a new collection to hold such variables.
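
The two-step process above can be sketched in plain JAX. This is an illustration under stated assumptions, not the actual fp8 implementation: the function `scale_inputs`, the max-abs update rule, the learning rate, and the two-dict parameter layout are all hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def scale_inputs(x, scale):
    return x * scale


def _scale_inputs_fwd(x, scale):
    return x * scale, (x, scale)


def _scale_inputs_bwd(residuals, g):
    x, scale = residuals
    dx = g * scale  # ordinary cotangent for the inputs
    # Hypothetical rule: the "gradient" returned for the scale is its new value.
    new_scale = jnp.max(jnp.abs(x)).reshape(scale.shape).astype(scale.dtype)
    return dx, new_scale


scale_inputs.defvjp(_scale_inputs_fwd, _scale_inputs_bwd)


def loss_fn(params, fp8_params, x):
    y = scale_inputs(x, fp8_params['input_scale']) @ params['w']
    return jnp.sum(y)


params = {'w': jnp.ones((4, 2))}
fp8_params = {'input_scale': jnp.ones((1,))}
x = jnp.arange(8.0).reshape(2, 4)

param_grads, fp8_grads = jax.grad(loss_fn, argnums=(0, 1))(params, fp8_params, x)

# Step (2): regular parameters take a gradient step ...
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, param_grads)
# ... while the fp8 variables are replaced by their "gradients".
fp8_params = fp8_grads
```

Keeping the two groups in separate trees is what makes the replace-instead-of-step update easy to express, which is the reason a dedicated collection is requested.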

I have created a branch to fix the above issue here, but it is very specific to our use case. Do you have any ideas or suggestions on how to improve it?

cc. @pjannaty @nluehr @reedwm

from typing import Optional

from jax import lax
from jax import numpy as jnp
from jax import random
import jax._src.test_util as jtu

from praxis import base_layer
from praxis import pax_fiddle
from praxis import pytypes

instantiate = base_layer.instantiate
WeightInit = base_layer.WeightInit
WeightHParams = base_layer.WeightHParams
template_field = base_layer.template_field
LayerTpl = pax_fiddle.Config[base_layer.BaseLayer]
JTensor = pytypes.JTensor

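class Dot(base_layer.BaseLayer):
  """Wrapper around lax.dot used in standard Pax layers."""

  def __call__(self, lhs: JTensor, rhs: JTensor) -> JTensor:
    # The callee was lost in the original post; lax.dot is assumed from the
    # otherwise unused `lax` import.
    return lax.dot(lhs, rhs)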

class Foo(base_layer.BaseLayer):
  input_dims: int = 0
  output_dims: int = 0
  weight_init: Optional[WeightInit] = None
  dot_tpl: LayerTpl = template_field(Dot)

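  def setup(self) -> None:
    wp = self.weight_split_dims_mapping
    # The start of this call was lost in the original post; the variable name
    # 'w' is taken from `self.theta.w` below, and any further arguments are
    # unknown.
    self.create_variable(
        'w',
        WeightHParams(
            shape=[self.input_dims, self.output_dims],
        ),
    )

    scale_args = {
        'shape': [1],
        'init': WeightInit.Constant(1.0),
        'dtype': jnp.float32,
        'mesh_shape': self.mesh_shape,
        'tensor_split_dims_mapping': None,
        'collections': ['fp8_params'],
    }
    self.create_variable('input_scale', WeightHParams(**scale_args))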

    self.create_child('dot', self.dot_tpl.clone())

  def __call__(self, inputs: JTensor) -> JTensor:
    """Apply projection to inputs.

    Args:
      inputs: The inputs JTensor.  Shaped [..., input_dims].

    Returns:
      Projected inputs.
    """
    ap = self.activation_split_dims_mapping

    original_shape = inputs.shape
    assert len(original_shape) >= 2

    inputs = jnp.asarray(inputs, self.dtype)
    kernel = jnp.asarray(self.theta.w, self.dtype)

    # Reshape the inputs to 2D matrix.
    inp_mat = jnp.reshape(inputs,
                          (-1, self.input_dims))

    inp_mat = inp_mat * self.theta.input_scale

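    # Actual dense layer math. The callee was lost in the original post; the
    # `dot` child created in setup() is assumed.
    out = self.dot(inp_mat, kernel)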

    # Reshape back the outputs.
    out = jnp.reshape(out, (*original_shape[0:-1], self.output_dims))

    return out

in_size, out_size = 16, 32

prng_key = random.PRNGKey(seed=123)
prng_key, init_key, random_key = random.split(prng_key, 3)
inputs = random.uniform(random_key, (48, in_size)).astype(jnp.bfloat16)

foo_kwargs = {'input_dims': in_size, 'output_dims': out_size,
              'dtype': jnp.bfloat16}
foo: Foo = instantiate(
    pax_fiddle.Config(Foo, name='foo', **foo_kwargs)
)

variables = foo.init(init_key, inputs)
var_tree = jtu.tree_map(lambda x: None, variables)
print("XXX vars", var_tree)

Any publicly available document?

Hi Praxis team, I had been using JAX and Flax for quite some time before finding Praxis. Flax is great, but not well suited for scaling. I also checked T5X and Flaxformer, but they do not seem very developer friendly, as their main functionality is defining transformer layers. My colleagues and I would love to move to Praxis in our future work. Is there any publicly available documentation that we can use?

Incorrect conversion from tf dtype to jax dtype

In the class DatasetInputSpecsProvider, when converting tf specs to jax:


as_numpy_dtype is called as a method, but it is actually an attribute (a property) of tf.dtypes.DType.

The code works for most dtypes but fails for tf.string, because the returned entity is a pointer to the object np datatype and not the object datatype itself.
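
The attribute-versus-call distinction can be shown without TensorFlow. In this sketch `FakeDType` is a hypothetical stand-in, not the real `tf.dtypes.DType`, and the builtin `float` and `object` types stand in for NumPy scalar types. It only illustrates the Python mechanics: calling the value returned by a property instantiates the type instead of returning it.

```python
class FakeDType:
    """Hypothetical stand-in for a dtype object exposing `as_numpy_dtype`."""

    def __init__(self, np_type):
        self._np_type = np_type

    @property
    def as_numpy_dtype(self):
        # A property: read it as an attribute, do not call it.
        return self._np_type


float_like = FakeDType(float)
object_like = FakeDType(object)

attr_result = float_like.as_numpy_dtype        # the type itself
called_float = float_like.as_numpy_dtype()     # an instance: float() is 0.0
called_object = object_like.as_numpy_dtype()   # a bare object() instance

print(attr_result, called_float, type(called_object))
```

For numeric types, the accidental instance may still carry enough information for downstream code to recover the dtype, which could explain why the bug goes unnoticed. For an object-like type, the instance carries none, which is consistent with the failure reported for tf.string.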

Support custom FP8 dtype in Pipelined Transformer

We have submitted two PRs to introduce a new custom data type for FP8 params, also known as OWG params, in this PR and this PR. The purpose of this custom data type is primarily for custom gradient accumulation using the max operation.

After the merger of the aforementioned PRs, we still require one additional change, likely to the LayerwiseShardablePipelined, to perform the type conversion outside the scan_fn. This is necessary because the custom data type needs to be recognized before being broadcast into the iterations within the scan_fn to ensure that autograd correctly applies the custom gradient accumulation.

I have prepared a self-contained Python code for this potential change, which you can find here.

Essentially, you can disregard the lines before line 199 as if they have already been merged. Line 243 represents the proposed dtype conversion to be added to LayerwiseShardablePipelined, where we convert all OWG params into the custom data type.

However, there is an issue regarding how to obtain the mask of the OWG params. As per my understanding, OWG params physically reside in the PARAMS category, and we have weight hparams to determine if they are OWG or not. However, such weight hparams seem inaccessible inside the LayerwiseShardablePipelined. In the provided code, I compute the owg_mask outside in line 263 and pass it as an input to the model.apply in line 263. Nevertheless, I feel this is not an ideal design since it modifies the model call signature and is specific only to the FP8 scenario.

Ideally, I believe that if we can compute the owg_mask inside the layer (similar to line 226) by accessing the weight hparams, that would be preferable. I've observed a similar example with bf16_accum_in_fp32 here, although it doesn't require any weight hparams.

To sum up, what is the best practice to obtain the owg_mask inside the LayerwiseShardablePipelined where the weight hparams are not available?
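
To make the question concrete, here is a minimal sketch of what computing the mask inside the layer might look like if a tree of weight hparams were reachable there. Everything in it is an assumption for illustration: the nested-dict layout, the `collections` key on each leaf, the marker string `OWG`, and the helper `build_mask` are hypothetical and are not Praxis APIs.

```python
OWG = 'overwrite_with_gradient'  # hypothetical marker name


def is_leaf_hparams(node):
    """A leaf is a dict that describes one variable."""
    return isinstance(node, dict) and 'collections' in node


def build_mask(hparams_tree, marker):
    """Returns a tree of booleans with the same structure as the variables."""
    if is_leaf_hparams(hparams_tree):
        return marker in hparams_tree['collections']
    return {k: build_mask(v, marker) for k, v in hparams_tree.items()}


weight_hparams = {
    'layer_0': {
        'w': {'collections': []},
        'input_scale': {'collections': [OWG]},
    },
}

owg_mask = build_mask(weight_hparams, OWG)
print(owg_mask)
```

A mask built this way has the same structure as the variable tree, so it could be zipped with the variables to decide which leaves to convert before they are passed into the scanned function, without changing the model call signature.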

(Note, to run the gist code, you need the latest jax build like 0.4.24.devxxxxx)

cc. @zhangqiaorjc
