sokrypton / colabdesign

Making Protein Design accessible to all via Google Colab!

Jupyter Notebook 43.68% Python 55.91% CSS 0.04% JavaScript 0.37%

colabdesign's People

Contributors

andrewfavor95, hunarbatra, keaunamani, sokrypton, stephen-rettie

colabdesign's Issues

GPU Not Utilized

I have installed the code on a cloud GPU instance at Lambda Labs as well as on a local HPC, but I am having the same issue on both: the GPU does not seem to be recognized by the software and is not utilized. I have tried restarting the kernel, but the same issue arises. My installation was via a Jupyter notebook, as below. For now I am testing with just a short cystine-dense peptide.

!pip install git+https://github.com/sokrypton/ColabDesign.git
!mkdir params
!curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar | tar x -C params

import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"  # set before importing JAX/TF so the log level takes effect

import numpy as np
from IPython.display import HTML
from colabdesign import mk_afdesign_model, clear_mem

model = mk_afdesign_model(protocol="fixbb")
model.prep_inputs(pdb_filename="6cdx.pdb", chain="A")
model.design_3stage()

At first I got this warning, which prompted me to set TF_CPP_MIN_LOG_LEVEL=0:
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

Now the output log is below:

/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/prep.py:248: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  batch = jax.tree_map(lambda x:x[has_ca], batch)
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/prep.py:273: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  o["template_features"] = jax.tree_map(lambda x:x[None],o["template_features"])
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/prep.py:376: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  return jax.tree_map(lambda x:x[None], inputs)
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/shared/utils.py:36: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  return jax.tree_map(lambda y:y, x)
/home/ubuntu/.local/lib/python3.8/site-packages/haiku/_src/data_structures.py:144: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
  leaves, treedef = jax.tree_flatten(tree)
/home/ubuntu/.local/lib/python3.8/site-packages/haiku/_src/data_structures.py:145: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
  return jax.tree_unflatten(treedef, leaves)
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/alphafold/model/modules.py:318: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  ret = impl(ensembled_batch=jax.tree_map(lambda x:x[0], batch),
/home/ubuntu/.local/lib/python3.8/site-packages/haiku/_src/stateful.py:32: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  return jax.tree_map(lambda x: x, bundle)
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/alphafold/model/prng.py:49: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  return jax.tree_map(SafeKey, tuple(new_keys))
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/alphafold/model/mapping.py:50: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
  values_tree_def = jax.tree_flatten(values)[1]
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/alphafold/model/mapping.py:54: FutureWarning: jax.tree_unflatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_unflatten instead.
  return jax.tree_unflatten(values_tree_def, flat_axes)
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/alphafold/model/mapping.py:128: FutureWarning: jax.tree_flatten is deprecated, and will be removed in a future release. Use jax.tree_util.tree_flatten instead.
  flat_sizes = jax.tree_flatten(in_sizes)[0]
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/alphafold/model/mapping.py:146: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  out_dtypes = jax.tree_map(lambda x: x.dtype, remainder_shape_dtype)
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/alphafold/model/mapping.py:147: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  out_shapes = jax.tree_map(lambda x: x.shape, remainder_shape_dtype)
/home/ubuntu/.local/lib/python3.8/site-packages/haiku/_src/stateful.py:314: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  params_after = jax.tree_map(
/home/ubuntu/.local/lib/python3.8/site-packages/haiku/_src/stateful.py:322: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  state_after = jax.tree_map(
/home/ubuntu/.local/lib/python3.8/site-packages/haiku/_src/stateful.py:575: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
  length = jax.tree_leaves(xs)[0].shape[0]
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/alphafold/model/quat_affine.py:304: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  rotation = jax.tree_map(expand_fn, rotation)
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/alphafold/model/quat_affine.py:305: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  translation = jax.tree_map(expand_fn, translation)
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/alphafold/model/quat_affine.py:330: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  rotation = jax.tree_map(expand_fn, rotation)
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/alphafold/model/quat_affine.py:331: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  translation = jax.tree_map(expand_fn, translation)
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/alphafold/model/all_atom.py:508: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  chi2_frame_to_frame = jax.tree_map(lambda x: x[:, 5], all_frames)
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/alphafold/model/all_atom.py:509: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  chi3_frame_to_frame = jax.tree_map(lambda x: x[:, 6], all_frames)
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/alphafold/model/all_atom.py:510: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  chi4_frame_to_frame = jax.tree_map(lambda x: x[:, 7], all_frames)
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/alphafold/model/all_atom.py:512: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  chi1_frame_to_backb = jax.tree_map(lambda x: x[:, 4], all_frames)
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/alphafold/model/all_atom.py:525: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  all_frames_to_backb = jax.tree_map(
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/alphafold/model/all_atom.py:535: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  jax.tree_map(lambda x: x[:, None], backb_to_global),
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/alphafold/model/all_atom.py:564: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  map_atoms_to_global = jax.tree_map(
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/alphafold/model/all_atom.py:585: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  pred_positions = jax.tree_map(lambda x: x * mask, pred_positions)
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/loss.py:389: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  sub_batch = jax.tree_map(lambda x: x, batch)
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/alphafold/model/all_atom.py:1067: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  jax.tree_map(lambda r: r[:, None], r3.invert_rigids(pred_frames)),
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/alphafold/model/all_atom.py:1068: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  jax.tree_map(lambda x: x[None, :], pred_positions))
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/alphafold/model/all_atom.py:1073: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  jax.tree_map(lambda r: r[:, None], r3.invert_rigids(target_frames)),
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/alphafold/model/all_atom.py:1074: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  jax.tree_map(lambda x: x[None, :], target_positions))
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/design.py:194: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  out["grad"] = jax.tree_map(lambda *x: jnp.stack(x).mean(0), *grad)
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/design.py:125: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  outs = jax.tree_map(lambda *x: jnp.stack(x), *outs)
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/design.py:129: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  self.grad = jax.tree_map(lambda x: x.mean(0), outs["grad"])
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/design.py:132: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  self.aux = jax.tree_map(lambda x:x[0], outs["aux"])
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/design.py:136: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  self.aux["losses"] = jax.tree_map(lambda x: x.mean(0), outs["aux"]["losses"])
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/shared/utils.py:36: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  return jax.tree_map(lambda y:y, x)
/home/ubuntu/.local/lib/python3.8/site-packages/colabdesign/af/design.py:223: FutureWarning: jax.tree_map is deprecated, and will be removed in a future release. Use jax.tree_util.tree_map instead.
  self.grad = jax.tree_map(lambda x:x*lr, self.grad)

I do not see anything in the log that immediately suggests why the GPU is not seen. I know it is not a misreported error, because when I run watch nvidia-smi I can see that no processes are running on the GPU and it sits at idle.
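
For anyone hitting the same thing: ColabDesign runs on JAX, so a quick check of whether JAX itself can see the GPU (independent of ColabDesign) is worth doing first. A minimal sketch, assuming a standard JAX install; if it reports "cpu", the installed jaxlib is CPU-only and a CUDA-enabled jaxlib is needed.

# Minimal JAX visibility check (run before importing colabdesign).
import jax

print("jax version:", jax.__version__)
print("default backend:", jax.default_backend())  # expect "gpu" on a working CUDA setup
print("devices:", jax.devices())                  # CPU-only output means jaxlib has no CUDA support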

Is this typical behavior? Or is it supposed to utilize GPU?

Thanks! Super excited about the program!

small fix required for newer numpy versions

In newer numpy versions dtype=np.int is no longer supported.

Found here, for example. To fix this, use either the Python built-in int or np.int32 / np.int64.

Before (fails on newer numpy):

restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=np.int)
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=np.int)
restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)

After:

restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=np.int32)
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=np.int32)
restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)

losses went into inf/nan if optimizer is not SGD

I tested other optimizers such as adam/rmsprop/momentum from jax.example_libraries.optimizers. As the design started, the whole structure predicted by AlphaFold2 collapsed into a "black hole" after 10-20 steps, then the loss value became inf/nan after 60-80 steps, yet the inputs seem normal. I tried several times per optimizer and got the same result.
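
For reference, this is the standard jax.example_libraries.optimizers pattern the report refers to, with a smaller step size and gradient clipping added, which are common first mitigations when a non-SGD optimizer diverges. This is only a hedged sketch with a toy loss_fn and parameters, not ColabDesign's internal optimizer hookup, and not a confirmed fix.

# Toy sketch of the jax.example_libraries.optimizers loop with gradient clipping.
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

def loss_fn(params):
    return jnp.sum(params ** 2)  # stand-in for the design loss

opt_init, opt_update, get_params = optimizers.adam(step_size=1e-3)  # smaller than the usual default
opt_state = opt_init(jnp.ones(5))

for step in range(100):
    loss, grads = jax.value_and_grad(loss_fn)(get_params(opt_state))
    grads = optimizers.clip_grads(grads, max_norm=1.0)  # clip to reduce the chance of inf/nan blow-ups
    opt_state = opt_update(step, grads, opt_state)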

partial hallucination with multiple chains

I would like to design a protein complex with partial hallucination. My input pdb has two chains, A and B. Chain A is the target, so the whole sequence is fixed. In chain B, I would like to do a partial hallucination on a loop (between B100-B116). My script is something like this:

af_model = mk_afdesign_model(protocol="partial",
                             use_templates=False,  # set True to constrain positions using template input
                             data_dir="/home/n721/AFD_params")

old_pos = "A1-A286,B1-B100,B116-B130"

af_model.prep_inputs(pdb_filename="ranked_0_286_HC.pdb", chain="A,B",
                     pos=old_pos,           # define positions to constrain
                     fix_seq=True,          # set True to constrain the sequence
                     use_sidechains=False)  # set True to restrain the sidechains

af_model.rewire(order=[0,1,2],  # set order of segments
                loops=[0,15],
                offset=0)

However, I got an output pdb with just one long chain that connects my chain A and chain B. I guess rewire will connect all fragments into one chain. How do I only connect the two fragments in chain B with partial hallucination while keeping chain A intact?

feature request: only search within a list of potential peptides

Hi,

Would it be possible for AfDesign to design peptides that bind a certain structure/domain but to do so for a given length and within a list (fasta file) of potential peptides?

AfDesign would not attempt a peptide that is not in the list, instead giving the best peptide within the list.

Thanks

Problem in designability_test.py

I tried to design a binder using the diffusion.ipynb Colab, but I got the following error.

{'pdb':'outputs/Actinbindinger_i5jed_0.pdb','loc':'outputs/Actinbindinger_i5jed','contigs':'A5-39/A51-375:150-150','copies':1,'num_seqs':8,'initial_guess':True,'use_multimer':False,'num_recycles':3,'rm_aa':'C','num_designs':1}
protocol=binder
running proteinMPNN...
running AlphaFold...
Traceback (most recent call last):
  File "/content/colabdesign/rf/designability_test.py", line 198, in <module>
    main(sys.argv[1:])
  File "/content/colabdesign/rf/designability_test.py", line 180, in main
    sub_seq_divided = "".join(np.insert(list(sub_seq),np.cumsum(af_model._lengths[:-1]),"/"))
  File "<__array_function__ internals>", line 180, in insert
  File "/usr/local/lib/python3.9/dist-packages/numpy/lib/function_base.py", line 5280, in insert
    raise IndexError(f"index {obj} is out of bounds for axis {axis} "
IndexError: index [360] is out of bounds for axis 0 with size 150
CPU times: user 348 ms, sys: 86.4 ms, total: 434 ms
Wall time: 59.6 s

My parameters are as follows:
Contig : A:100
PDB:1J6Z
HOTSPOTS : A143,A148,A341,A345,A349

num_seqs:8
initial_guess:True
num_recycles:3

A similar error did not occur a few days ago.

Obtaining designed sequences fails

On Colab, I try to run the design example with PDB 1TEN and everything works (animation, PDB structure visualization). However, when I then try model.get_seqs() I get:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
[<ipython-input-30-302e20bcfe65>](https://localhost:8080/#) in <module>()
----> 1 model.get_seqs()

2 frames
[/content/design.py](https://localhost:8080/#) in <listcomp>(.0)
    556     outs = self._outs if self._best_outs is None else self._best_outs
    557     x = outs["seq"].argmax(-1)
--> 558     return ["".join([order_restype[a] for a in s]) for s in x]
    559 
    560   def get_loss(self, x="loss"):

TypeError: unhashable type: 'DeviceArray'

What is the best way to obtain the designed sequences?
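
A workaround that has helped with similar "unhashable DeviceArray" errors (a sketch only, not the intended API) is to pull the argmax indices back to host Python ints before using them as lookup keys. The names below are toy stand-ins for the objects in the traceback.

import numpy as np

# toy stand-ins: order_restype maps index -> residue letter, logits plays the role of outs["seq"]
order_restype = {i: aa for i, aa in enumerate("ARNDCQEGHILKMFPSTWYV")}
logits = np.random.randn(2, 10, 20)

x = np.asarray(logits).argmax(-1)                              # host numpy array instead of a DeviceArray
seqs = ["".join(order_restype[int(a)] for a in s) for s in x]  # int(a) makes the key hashable
print(seqs)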

Binder Hallucination Memory Leak

I have written a small loop to generate as many binders as I want for further refinement/testing, but I am running into what seems to be a memory leak that causes the program to crash. Consistently, after ~332 iterations the software exhausts all VRAM and exits. This happens on machines with anywhere from 16 to 48 GB of VRAM. I have used clear_mem() and model.restart() in the loop to try to free the resources, but to no avail. I have also tried using %reset -f to clear the namespace. Has anyone experienced anything similar? Any advice on how to avoid this buildup?
Thanks!

Add seed

Hi! Is it possible to add (or use, if implemented already) a seed to reproduce the designs?

How to design more binders

Thank you to your team for this excellent work.
A single binder may not give good experimental results. How can binders be designed on a large scale?
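
For reference, a rough sketch of how one might generate many candidate binders in a loop, using only calls that appear elsewhere in these issues (mk_afdesign_model, prep_inputs, restart, design_3stage, get_seqs, save_pdb). The seed keyword on restart() is an assumption to verify against the current signature, and "target.pdb", the chain, and binder_len are placeholders.

from colabdesign import mk_afdesign_model, clear_mem

clear_mem()
af_model = mk_afdesign_model(protocol="binder")
af_model.prep_inputs(pdb_filename="target.pdb", chain="A", binder_len=20)  # placeholders

designs = []
for seed in range(10):                  # scale up for a larger campaign
    af_model.restart(seed=seed)         # assumes restart() accepts a seed keyword
    af_model.design_3stage(100, 100, 10)
    designs.append(af_model.get_seqs()[0])
    af_model.save_pdb(f"binder_{seed}.pdb")

print(designs)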

Homooligomer design in ProteinMPNN

Hi developers,

First, thanks for this awesome tool! I'm just starting to use it!

In my first tests, I tried to design a sequence from a structure using a GTPase (PDB ID 6N12, chains A and B). However, I'm getting the error message below in the cell "Run ProteinMPNN to design new sequences for given backbone". I tried different PDB files with the same issue.

Any ideas about what could be going on and how to get around it?

Thanks in advance,

Alessandro


AssertionError Traceback (most recent call last)

in

/content/colabdesign/mpnn/model.py in prep_inputs(self, pdb_filename, chain, homooligomer, ignore_missing, fix_pos, inverse, rm_aa, verbose, **kwargs)
80
81 if homooligomer:
---> 82 assert min(self._lengths) == max(self._lengths)
83 self._tied_lengths = True
84 self._len = self._lengths[0]

AssertionError:

The sequence in design.fasta is different from the sequence in best.pdb

I am trying to use RFdiffusion in ColabDesign for protein design, and I found that the sequence in design.fasta is different from the sequence in best.pdb. Is this normal? Below is one of my results:

design.fasta: TYKVVALFTG-----/----
best.pdb: TYKVVALFTGC-----/----

(the dashes represent other amino acids)
You can see that the sequence in best.pdb has one extra C compared to the sequence in design.fasta.

Patch fix_partial_contigs when residue numbering in PDB has a gap

In rf/utils.py, around line 78, should the three marked lines below be added, in case the residue numbering in the original PDB file has a gap? Thanks!

          if L > 0:
            new_contig.append(f"{L}-{L}")
            unseen = []
          ### in case residue numbering jumps
          elif len(seen)>0 and seen[-1][1]!=i-1:
              new_contig.append(f"{seen[0][0]}{seen[0][1]}-{seen[-1][1]}")
              seen = []
          ###
          seen.append([c,i])

out of memory for partial hallucination

I ran partial hallucination using the following code and got an out-of-memory issue. The complex is composed of two chains ['A', 'C']; chain A has 1505 residues. I need to repair side-chain structure using partial hallucination; the residues whose structure needs repair are listed in pos, 19 segments in total.

if __name__ == '__main__':
    pdb_files = ['NavPas_LqhaIT.pdb']   
    data_dir = '/ai/cheng/data/zhongke/cryo-EM-mod'

    for pdb_file in pdb_files:    
        pdb_path = os.path.join(data_dir, pdb_file)
        save_path = os.path.join(data_dir, pdb_file.replace('.pdb', '_repaired.pdb'))
        
        clear_mem()
        af_model = mk_afdesign_model(protocol="partial",
                             use_templates=False, # set True to constrain positions using template input
                             data_dir='/ai/cheng/gitlab/software_git/ColabDesign/params')   

        if pdb_file == 'NavPas_LqhaIT.pdb':
            af_model.prep_inputs(pdb_filename=pdb_path, 
                                 chain="A", 
                                 pos="295,297,300,302,304,320,321,323,324,1196,1197,1200,1201,1203,1204,1250,1251,1254,1258-1263",
                                 fix_seq=True, # set True to constrain the sequence
                                 )
            af_model.rewire(loops=[6]*18)
            # initialize with wildtype seq, fill in the rest with soft_gumbel distribution
            af_model.restart(mode=["soft","gumbel","wildtype"])
            af_model.design_3stage(100, 100, 10)
            af_model.save_pdb(save_path)

When I run the code, it reports the following error. The memory consumption is unexpectedly large; could you kindly help me check why?

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/ai/cheng/gitlab/software_git/ColabDesign/partial_hallucination.py", line 35, in <module>
    af_model.design_3stage(100, 100, 10)
  File "/ai/cheng/gitlab/software_git/ColabDesign/colabdesign/af/design.py", line 356, in design_3stage
    self.design_logits(soft_iters, e_soft=1,
  File "/ai/cheng/gitlab/software_git/ColabDesign/colabdesign/af/design.py", line 334, in design_logits
    self.design(iters, **kwargs)
  File "/ai/cheng/gitlab/software_git/ColabDesign/colabdesign/af/design.py", line 328, in design
    self.step(lr_scale=lr_scale, num_recycles=num_recycles,
  File "/ai/cheng/gitlab/software_git/ColabDesign/colabdesign/af/design.py", line 195, in step
    self.run(num_recycles=num_recycles, num_models=num_models, sample_models=sample_models,
  File "/ai/cheng/gitlab/software_git/ColabDesign/colabdesign/af/design.py", line 96, in run
    auxs.append(self._recycle(p, num_recycles=num_recycles, backprop=backprop))
  File "/ai/cheng/gitlab/software_git/ColabDesign/colabdesign/af/design.py", line 180, in _recycle
    aux = self._single(model_params, backprop)
  File "/ai/cheng/gitlab/software_git/ColabDesign/colabdesign/af/design.py", line 140, in _single
    (loss, aux), grad = self._model["grad_fn"](*flags)
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 219121747088 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:    1.14GiB
              constant allocation:    86.3KiB
        maybe_live_out allocation:  829.86MiB
     preallocated temp allocation:  204.07GiB
                 total allocation:  206.03GiB
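
Not an official recommendation, but the model __init__ signature quoted in the "File Traceback Error AlphaFold Design" issue below lists crop_len, crop_mode and subbatch_size among its arguments; passing something along these lines may reduce peak activation memory for a 1505-residue chain. Treat the exact keyword names and values as assumptions to verify against the current code.

# Hedged sketch: reduce peak memory for a very long chain via sub-batched evaluation.
from colabdesign import mk_afdesign_model, clear_mem

clear_mem()
af_model = mk_afdesign_model(protocol="partial",
                             use_templates=False,
                             subbatch_size=128,  # example value only; smaller = less memory, slower
                             data_dir="params")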

Out of memory error

Hi there! I found your colab coincidentally and wanted to try hallucinating some binders using it for the protein "8DYS" on PDB.

However, when using the recommended settings and initializing with "WEQLARDRSRFARR" (a known natural binder that we're trying to modulate), I get the following error:

Stage 1: running (logits → soft)
---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
[<ipython-input-4-09dc0c470929>](https://localhost:8080/#) in <module>
     36 if optimizer == "pssm_semigreedy":
---> 37   model.design_pssm_semigreedy(120, 32, **flags)
     38   pssm = softmax(model._tmp["seq_logits"],1)

23 frames
UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 13771167248 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:  444.12MiB
              constant allocation:    54.2KiB
        maybe_live_out allocation:   89.58MiB
     preallocated temp allocation:   12.83GiB
                 total allocation:   13.35GiB

Do you have any suggestions on how to get around the issues?

Thank you very much!

atom index in af.prep.prep_pdb

In the prep_pdb.py file, line 434

    im = ignore_missing[n] if isinstance(ignore_missing,list) else ignore_missing
    if im:
      r = batch["all_atom_mask"][:,0] == 1  #<===== THIS LINE
      batch = jax.tree_map(lambda x:x[r], batch)
      residue_index = batch["residue_index"] + last

Shouldn't one use batch["all_atom_mask"][:,1] if we want to check whether CA is present? The 0th atom is N, according to residue_constants.atom_types.

Thanks!
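
A quick way to settle the index question is to print the atom ordering directly; in AlphaFold's residue_constants the atom37 order begins N, CA, C, CB, O, so index 0 is N and index 1 is CA. The import path below is an assumption based on the repo layout mentioned in other issues.

# Confirm which atom37 index corresponds to CA.
from colabdesign.af.alphafold.common import residue_constants

print(residue_constants.atom_types[:5])           # expected: ['N', 'CA', 'C', 'CB', 'O']
print(residue_constants.atom_types.index("CA"))   # expected: 1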

Using RFdiffusion to design a binder, with some parts of the binder fixed

Hello,

I would like to address the following task using the RFdiffusion notebook:

I have a complex of two proteins (target-binder) in a single PDB; chain B is the target, which I want to preserve, and in chain A I want to preserve some parts and redesign others. I am also providing hotspots from the target.

My specification looks as follows:

Contigs: B123-370/0 A1-50/2/A53-54/9/A64-65/1/A67-122 (I also tried replacing the trailing "/0" with ":")
Hotspots: B181,B221,B229

In the RFdiffusion cell, the execution falls into "mode: fixed" (as there is at least one fixed residue in the contig).

In the designability test, the execution falls into "protocol: partial".
In the designability test, to get into "protocol: binder", I see I would need to provide chains where each chain is either entirely fixed or entirely free.

My main concern is what model is used by RFdiffusion in my case, e.g. are the hotspots even utilized? Are the weights used fine-tuned for the binder task?

Also, if I want to tackle a task like this, is there some other recommended setting for the pipeline? Ideally also utilizing the designability test for the binder task (to produce iPAE) - but this I can bypass. My main concern is the RFdiffusion part.

Thank you very much in advance to anyone willing to discuss this!
Petr Kouba

MPNN designed sequence not separated correctly with multiple targets and multiple binders

With a toy example, my contig string is 'A 20 E 10' with a modified pdb structure. The output sequence is:
TTCCPSIRSNFNVCRLPGTPEAICATYTGCIIIPGATCPGDYAN/RRRNNDKPVDMLYPMVAMEMTLGSEFEVME

In the above output, the two target sequences (chains A and E) are combined, followed by "/", then followed by the two binder sequences (20 + 10 residues). There is no "/" separator between the two targets, nor between the two binders.

Some initial tracing of the code points to line 27 in af.prep._prep_binder

self._lengths = [self._target_len, self._binder_len]

self._lengths only contains two elements. This is later passed on to the MPNN code, leading MPNN to think there are only two chains in the sequence output. I can certainly read the sequences from the final PDB file; I just wonder if this is something that needs to be fixed.

Thanks!

git clone error

During installation, I get an error. Does anyone have a solution to this problem?
(error screenshot attached)

OOM on AF DB proteins?

I'm creating a binder for this protein: https://alphafold.ebi.ac.uk/entry/Q8W3K0 and I'm getting this error on both T4s and A100s:

This error makes sense, but I'm confused as to how a protein that requires 100 GB could have been folded by AlphaFold in the first place. Shouldn't any protein that AlphaFold can take as input also work with ColabDesign? Or does ColabDesign take more memory?

Stage 1: running (logits → soft)
---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
[<ipython-input-3-09dc0c470929>](https://localhost:8080/#) in <module>
     36 if optimizer == "pssm_semigreedy":
---> 37   model.design_pssm_semigreedy(120, 32, **flags)
     38   pssm = softmax(model._tmp["seq_logits"],1)

25 frames
UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 164515000848 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation: 1006.63MiB
              constant allocation:    97.4KiB
        maybe_live_out allocation:  660.41MiB
     preallocated temp allocation:  153.22GiB
                 total allocation:  154.84GiB
Peak buffers:
	Buffer 1:
		Size: 22.78GiB
		XLA Label: copy
		Shape: f32[288,4,4,1152,1152]
		==========================

	Buffer 2:
		Size: 22.78GiB
		XLA Label: copy
		Shape: f32[288,4,4,1152,1152]
		==========================

	Buffer 3:
		Size: 22.78GiB
		Operator: op_name="jit(_model)/jit(main)/transpose(jvp(jit(apply)))/jit(apply_fn)/alphafold/alphafold_iteration/evoformer/__layer_stack_no_state_1/while/body/remat/evoformer_iteration/triangle_attention_starting_node/broadcast_in_dim[shape=(288, 4, 4, 1152, 1152) broadcast_dimensions=()]" source_file="/usr/local/lib/python3.9/dist-packages/haiku/_src/stateful.py" source_line=640
		XLA Label: broadcast
		Shape: f32[288,4,4,1152,1152]
		==========================

	Buffer 4:
		Size: 648.00MiB
		XLA Label: fusion
		Shape: f32[128,1152,1152]
		==========================

	Buffer 5:
		Size: 648.00MiB
		Operator: op_name="jit(_model)/jit(main)/transpose(jvp(jit(apply)))/jit(apply_fn)/alphafold/alphafold_iteration/evoformer/__layer_stack_no_state_1/while/body/remat/evoformer_iteration/triangle_multiplication_outgoing/gating_linear/...a, ah->...h/jit(_einsum)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/usr/local/lib/python3.9/dist-packages/colabdesign/af/alphafold/model/common_modules.py" source_line=118
		XLA Label: custom-call
		Shape: f32[1327104,128]
		==========================

	Buffer 6:
		Size: 648.00MiB
		Operator: op_name="jit(_model)/jit(main)/transpose(jvp(jit(apply)))/jit(apply_fn)/alphafold/alphafold_iteration/evoformer/__layer_stack_no_state_1/while/body/remat/evoformer_iteration/triangle_multiplication_outgoing/output_projection/...a, ah->...h/jit(_einsum)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/usr/local/lib/python3.9/dist-packages/colabdesign/af/alphafold/model/common_modules.py" source_line=118
		XLA Label: custom-call
		Shape: f32[1327104,128]
		==========================

	Buffer 7:
		Size: 648.00MiB
		XLA Label: fusion
		Shape: f32[128,1152,1152]
		==========================

	Buffer 8:
		Size: 648.00MiB
		Operator: op_name="jit(_model)/jit(main)/transpose(jvp(jit(apply)))/jit(apply_fn)/alphafold/alphafold_iteration/evoformer/__layer_stack_no_state_1/while/body/remat/evoformer_iteration/triangle_multiplication_outgoing/mul" source_file="/usr/local/lib/python3.9/dist-packages/haiku/_src/layer_norm.py" source_line=205
		XLA Label: fusion
		Shape: f32[128,1152,1152]
		==========================

	Buffer 9:
		Size: 648.00MiB
		Operator: op_name="jit(_model)/jit(main)/transpose(jvp(jit(apply)))/jit(apply_fn)/alphafold/alphafold_iteration/evoformer/__layer_stack_no_state_1/while/body/remat/evoformer_iteration/triangle_multiplication_outgoing/mul" source_file="/usr/local/lib/python3.9/dist-packages/haiku/_src/layer_norm.py" source_line=205
		XLA Label: fusion
		Shape: f32[128,1152,1152]
		==========================

	Buffer 10:
		Size: 648.00MiB
		Operator: op_name="jit(_model)/jit(main)/transpose(jvp(jit(apply)))/jit(apply_fn)/alphafold/alphafold_iteration/evoformer/__layer_stack_no_state_1/while/body/remat/evoformer_iteration/triangle_multiplication_outgoing/right_projection/...a, ah->...h/jit(_einsum)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/usr/local/lib/python3.9/dist-packages/colabdesign/af/alphafold/model/common_modules.py" source_line=118
		XLA Label: custom-call
		Shape: f32[1327104,128]
		==========================

	Buffer 11:
		Size: 648.00MiB
		Operator: op_name="jit(_model)/jit(main)/transpose(jvp(jit(apply)))/jit(apply_fn)/alphafold/alphafold_iteration/evoformer/__layer_stack_no_state_1/while/body/remat/evoformer_iteration/triangle_multiplication_outgoing/right_gate/...a, ah->...h/jit(_einsum)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/usr/local/lib/python3.9/dist-packages/colabdesign/af/alphafold/model/common_modules.py" source_line=118
		XLA Label: custom-call
		Shape: f32[1327104,128]
		==========================

	Buffer 12:
		Size: 648.00MiB
		Operator: op_name="jit(_model)/jit(main)/transpose(jvp(jit(apply)))/jit(apply_fn)/alphafold/alphafold_iteration/evoformer/__layer_stack_no_state_1/while/body/remat/evoformer_iteration/triangle_multiplication_outgoing/left_projection/...a, ah->...h/jit(_einsum)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/usr/local/lib/python3.9/dist-packages/colabdesign/af/alphafold/model/common_modules.py" source_line=118
		XLA Label: custom-call
		Shape: f32[1327104,128]
		==========================

	Buffer 13:
		Size: 648.00MiB
		Operator: op_name="jit(_model)/jit(main)/transpose(jvp(jit(apply)))/jit(apply_fn)/alphafold/alphafold_iteration/evoformer/__layer_stack_no_state_1/while/body/remat/evoformer_iteration/triangle_multiplication_outgoing/left_gate/...a, ah->...h/jit(_einsum)/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/usr/local/lib/python3.9/dist-packages/colabdesign/af/alphafold/model/common_modules.py" source_line=118
		XLA Label: custom-call
		Shape: f32[1327104,128]
		==========================

	Buffer 14:
		Size: 648.00MiB
		XLA Label: fusion
		Shape: f32[128,1152,1152]
		==========================

	Buffer 15:
		Size: 648.00MiB
		Operator: op_name="jit(_model)/jit(main)/transpose(jvp(jit(apply)))/jit(apply_fn)/alphafold/alphafold_iteration/structure_module/broadcast_in_dim[shape=(1152, 1152, 128) broadcast_dimensions=()]" source_file="/usr/local/lib/python3.9/dist-packages/haiku/_src/stateful.py" source_line=640
		XLA Label: broadcast
		Shape: f32[1152,1152,128]
		==========================

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

fix_pos in af/fixbb

Dear developers,

I'm trying to use ColabDesign AF to design a sequence with the fixbb protocol, fixing sites with a list of positions as follows:

from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord

model = mk_afdesign_model(protocol="fixbb")
model.prep_inputs(pdb_filename="mystruc.pdb", chain="A",
                  fix_pos="10,14,22,31,46,48,56,58,60,77,79,81,131,133,141,143,145,153,155,157,164,166,168")
model.restart()
model.design_3stage()
seq = SeqRecord(Seq(model.get_seqs()[0]), id=f"design{i}", description="")  # i = design index from an outer loop
model.save_pdb(f"fixed/org/design{i}.{model.protocol}.pdb")

I've assumed the positions need to be 0-based. Is that correct?

I must be doing something wrong, as the resulting sequence differs from the input PDB protein sequence at the sites given to fix_pos.

Any hint would be greatly appreciated.

Best,
Daniel

Sidechain rmsd loss throws out error

Firstly, thank you for this codebase; it is greatly accelerating our research.

I am trying to add a sidechain loss to certain residues in partial hallucination, but I am getting an error deep within the code.

When I prep_inputs with use_sidechains = True, I get the following error:
TypeError: take requires ndarray or scalar arguments, got <class 'colabdesign.af.alphafold.model.r3.Rigids'> at position 0.

I do not have the technical expertise to dive into the code and fix this issue, please help!

(screenshot attached)

WARNING: 'model_1_ptm' not found

Hello,
I'm trying to set up a local ColabDesign instance, but I encounter the error "WARNING: 'model_1_ptm' not found" while trying to run
model = mk_afdesign_model(protocol="fixbb")
Where can I find the model parameters, and where should I install them?
Thank you in advance,
DomML
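
This warning usually means the AlphaFold parameter files are missing from the directory the model looks in. A sketch of fetching them, mirroring the download quoted in the first issue of this thread (same URL); depending on the ColabDesign version, data_dir may need to be the directory containing the params/ folder or the folder holding the .npz files themselves, so check af/alphafold/model/data.py.

import os, tarfile, urllib.request

url = "https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar"
os.makedirs("params", exist_ok=True)
tar_path = "params/alphafold_params.tar"
if not os.path.exists(tar_path):
    urllib.request.urlretrieve(url, tar_path)   # large download
    with tarfile.open(tar_path) as tar:
        tar.extractall("params")                # leaves the params_model_*.npz files in ./params

from colabdesign import mk_afdesign_model
model = mk_afdesign_model(protocol="fixbb", data_dir=".")  # adjust data_dir if needed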

Translate colab for RFdiffusion to a jupyter notebook?

Hi, the Colab for RFdiffusion is great, but I have many disconnection problems with the free Google Colab runtime.

Would it be possible to translate the notebook to a Jupyter notebook in order to run it locally?

I tried to run the Colab locally, but there are many errors due to google.colab packages, etc.

Binder+MSA: Cannot concatenate arrays with different numbers of dimensions

Hi,

Thank you for this great setup! I came across a bug in binder prediction with an MSA (I understand that the latter was only tested for fixbb). I get the following error (also on Colab, not just locally):

Traceback (most recent call last):
  File "colabdesign/tools/af_design_motifs.py", line 799, in <module>
    design_model.design(50, weights={"plddt":0.1,"pae":0.1,"ent":ent})
  File "colabdesign/tools/af_design_motifs.py", line 634, in design
    self._state, outs, loss = step(self._k, self._state, subkey, opt)
  File "colabdesign/tools/af_design_motifs.py", line 589, in step
    (loss, outs), grad = self._grad(self._get_params(state), self._params[n], self._inputs, key, opt)
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/api.py", line 433, in cache_miss
    donated_invars=donated_invars, inline=inline)
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/core.py", line 1681, in bind
    return call_bind(self, fun, *args, **params)
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/core.py", line 1693, in call_bind
    outs = top_trace.process_call(primitive, fun, tracers, params)
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/core.py", line 594, in process_call
    return primitive.impl(f, *tracers, **params)
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/dispatch.py", line 143, in _xla_call_impl
    *unsafe_map(arg_spec, args))
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/linear_util.py", line 272, in memoized_fun
    ans = call(fun, *args)
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/dispatch.py", line 170, in _xla_callable_uncached
    *arg_specs).compile().unsafe_call
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/dispatch.py", line 198, in lower_xla_callable
    fun, abstract_args, pe.debug_info_final(fun, "jit"))
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 1680, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 1657, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/api.py", line 1073, in value_and_grad_f
    f_partial, *dyn_args, has_aux=True, reduce_axes=reduce_axes)
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/api.py", line 2528, in _vjp
    flat_fun, primals_flat, has_aux=True, reduce_axes=reduce_axes)
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/interpreters/ad.py", line 118, in vjp
    out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/interpreters/ad.py", line 103, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 522, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "colabdesign/tools/af_design_motifs.py", line 314, in mod
    seq_hard = jnp.concatenate([seq_target[None], seq_hard], 1)
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py", line 3430, in concatenate
    for i in range(0, len(arrays), k)]
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py", line 3430, in <listcomp>
    for i in range(0, len(arrays), k)]
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/lax/lax.py", line 557, in concatenate
    return concatenate_p.bind(*operands, dimension=dimension)
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/core.py", line 272, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/core.py", line 275, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/interpreters/ad.py", line 289, in process_primitive
    primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/interpreters/ad.py", line 440, in linear_jvp
    val_out = primitive.bind(*primals, **params)
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/core.py", line 272, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/core.py", line 275, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 1404, in process_primitive
    return self.default_process_primitive(primitive, tracers, params)
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 1408, in default_process_primitive
    out_avals = primitive.abstract_eval(*avals, **params)
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/lax/utils.py", line 66, in standard_abstract_eval
    return core.ShapedArray(shape_rule(*avals, **kwargs),
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/lax/lax.py", line 2762, in _concatenate_shape_rule
    raise TypeError(msg.format(", ".join(str(o.shape) for o in operands)))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Cannot concatenate arrays with different numbers of dimensions: got (1, 85, 20), (2, 13, 20).

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "colabdesign/tools/af_design_motifs.py", line 799, in <module>
    design_model.design(50, weights={"plddt":0.1,"pae":0.1,"ent":ent})
  File "colabdesign/tools/af_design_motifs.py", line 634, in design
    self._state, outs, loss = step(self._k, self._state, subkey, opt)
  File "colabdesign/tools/af_design_motifs.py", line 589, in step
    (loss, outs), grad = self._grad(self._get_params(state), self._params[n], self._inputs, key, opt)
  File "colabdesign/tools/af_design_motifs.py", line 314, in mod
    seq_hard = jnp.concatenate([seq_target[None], seq_hard], 1)
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py", line 3430, in concatenate
    for i in range(0, len(arrays), k)]
  File "colabdesign/colabdesign/local/lib/python3.7/site-packages/jax/_src/numpy/lax_numpy.py", line 3430, in <listcomp>
    for i in range(0, len(arrays), k)]
TypeError: Cannot concatenate arrays with different numbers of dimensions: got (1, 85, 20), (2, 13, 20).

The target is 85 and the binder is 13 residues long. How should the two arrays be concatenated? I could not figure it out.

Thank you for your amazing work!

TypeError: tree_map() missing 1 required positional argument: 'tree'

Hello!

I am running the following script for binder hallucination:

from colabdesign import mk_af_model

model = mk_af_model(protocol="binder")
model.prep_inputs(pdb_filename=target, binder_len=10)
model.design_3stage(100, 100, 10)

and I am getting the following error:

/home/yianni/.cache/pypoetry/virtualenvs/colabdesign-ho_SLdon-py3.8/lib/python3.8/site-packages/dm_haiku-0.0.7-py3.8.egg/haiku/_src/data_structures.py:37: FutureWarning: jax.tree_structure is deprecated, and will be removed in a future release. Use jax.tree_util.tree_structure instead.
  PyTreeDef = type(jax.tree_structure(None))
2022-09-01 23:34:49.083349: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-09-01 23:34:49.675704: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-09-01 23:34:49.675780: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-09-01 23:34:49.675793: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2022-09-01 23:34:50.497173: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudnn.so.8'; dlerror: libcudnn.so.8: cannot open shared object file: No such file or directory
2022-09-01 23:34:50.497203: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1934] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
ERROR: alphafold params not found
Traceback (most recent call last):
  File "test.py", line 7, in <module>
    model.design_3stage(100, 100, 10)
  File "/home/yianni/ColabDesign/colabdesign/af/design.py", line 393, in design_3stage
    self.design_logits(soft_iters, e_soft=1, **kwargs)
  File "/home/yianni/ColabDesign/colabdesign/af/design.py", line 366, in design_logits
    self.design(iters, **kwargs)
  File "/home/yianni/ColabDesign/colabdesign/af/design.py", line 361, in design
    self.step(lr_scale=lr_scale, backprop=backprop, repredict=repredict,
  File "/home/yianni/ColabDesign/colabdesign/af/design.py", line 193, in step
    self.run(backprop=backprop, callback=callback)
  File "/home/yianni/ColabDesign/colabdesign/af/design.py", line 105, in run
    aux = jax.tree_map(lambda *x: jnp.stack(x), *aux)
TypeError: tree_map() missing 1 required positional argument: 'tree'

Please let me know what I should do!

Thank you for your time.

Fix sequence at certain positions in Binder protocol

Hello.
Thanks for this amazing work.
I am trying to redesign a peptide binder while keeping the sequence at some positions.
For example, I would like to keep the prolines at positions 2,9 and 16 as present in the input peptide.
I tried the following.

pep_model = mk_afdesign_model(protocol="binder")
pep_model.prep_inputs(pdb_filename="./data/Complex.pdb", chain="A", binder_chain="B", hotspot = "37,38,67,68,69", fix_pos="2,9,16", fix_seq=True) 

However, the designed peptides don't retain prolines at these positions.
Am I doing something wrong here?
I would be really grateful for any suggestions.
Thanks.
Amin.

pLDDT increases and then decreases during the soft iterations

Hi, thanks for open-sourcing this work! I am exploring binder design (using peptide_binder_design.ipynb) with PDB ID 7BW1 and binder lengths between 6 and 8. What I have observed is that when setting the soft iterations to 300, the pLDDT increases to above 0.8 after about 100 iterations, and then decreases to below 0.4 by the end of the iterations. With such a low pLDDT, the subsequent hard iterations end with a pLDDT around 0.4. I am trying to understand why this happens and how to improve it. I have tried different numbers of soft and hard iterations, and also different numbers of tries, but nothing seems to help. Do you have any suggestions? Thanks!

Minimize sequence length while preserving some amount of structural stability.

Hey @sokrypton! Thanks for developing this package, its been incredibly useful to work with and learn from.

I'm currently working on a project where we aim to minimize sequence length by introducing deletions in specific domains of the protein, while preserving structural stability/binding. We have a large amount of enrichment data from large-scale deletion screens, so I was originally thinking about training a regression model and using it to rank deletion variants (or using ESM as a zero-shot variant effect predictor), but then I came across a lot of your conditional hallucination work! I was wondering if you had any pointers on ways to penalize sequence length while preserving specific domains and the overall structure in a conditional hallucination/inpainting task?

difference between plddt in the log and in the saved PDB!

I noticed that there is a discrepancy between the pLDDT values in the log (i.e. af_model.aux["log"]["plddt"]) and the value obtained by averaging the b-factors of the saved PDB file (where the PDB string is obtained by af_model.save_pdb(get_best=False)). The pLDDT in the log is usually larger.

This is the case even when num_models=1 and num_recycles=0. Is this due to slightly different normalisation, or am I missing something?

PS:
af_model = mk_af_model(use_multimer=True, use_templates=True, best_metric="dgram_cce")
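
One quick check is to recompute the mean from the saved PDB yourself, both over all atoms and over CA atoms only: the b-factor column holds the per-residue pLDDT repeated for every atom, so an all-atom average weights residues by their atom count and can differ from a per-residue average. A rough parser sketch ("design.pdb" is a placeholder filename); if the two numbers bracket the logged value, the discrepancy is likely just an averaging convention.

ca_b, all_b = [], []
with open("design.pdb") as fh:
    for line in fh:
        if line.startswith(("ATOM", "HETATM")):
            b = float(line[60:66])              # b-factor column (pLDDT here)
            all_b.append(b)
            if line[12:16].strip() == "CA":
                ca_b.append(b)

print("CA-only mean pLDDT :", sum(ca_b) / len(ca_b))
print("all-atom mean pLDDT:", sum(all_b) / len(all_b))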

binder redesign

Here is a working version of the "binder redesign" idea as posited on Twitter:

http://github.com/avilella/utils/blob/master/af/design.ipynb

I've tried it a few times and the results make me think there is scope for improvement. What would be the best way to improve the chances of this working? Of the three numbers below, which one should I attempt to increase or decrease first?

model.design_3stage(100,50,10)

Thx in advance

WARNING:absl:No GPU/TPU found, falling back to CPU.

When I ran the fixbb design using Afdesign on my computer, I got the following message:

/home/n721/pymol/lib/python3.7/site-packages/haiku/_src/data_structures.py:37: FutureWarning: jax.tree_structure is deprecated, and will be removed in a future release. Use jax.tree_util.tree_structure instead.
PyTreeDef = type(jax.tree_structure(None))
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

The job was running, as I can see messages like this:
1 models [3] recycles 0 hard 0 soft 0.00 temp 1 seqid 0.04 loss 4.96 plddt 0.41 pae 0.69 dgram_cce 4.95 ptm 0.13 rmsd 30.81
2 models [1] recycles 0 hard 0 soft 0.01 temp 1 seqid 0.07 loss 6.83 plddt 0.06 pae 0.22 dgram_cce 6.83 ptm 0.54 rmsd 38.16

However, I wonder why the GPU was not found. I am sure the computer has an RTX 3090, and TensorFlow can see it:
(tf) n721@n721-System-Product-Name:~/params$ python3 -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))"
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

Any help would be appreciated!
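
TensorFlow seeing the GPU does not mean JAX does, and ColabDesign runs on JAX, so a CPU-only jaxlib wheel is the usual culprit here. A quick check, plus (as a comment) the pip command that was documented for CUDA-enabled JAX around the time of this issue; verify against the current JAX install instructions.

# Check which backend JAX (not TensorFlow) is using.
import jax
print(jax.default_backend(), jax.devices())   # "cpu" here means jaxlib was installed without CUDA support

# If it prints "cpu", reinstalling a CUDA-enabled jaxlib is the usual fix, e.g.
#   pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html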

File Traceback Error AlphaFold Design

Hello there,
I have been using the AF Design software for the better part of a week now, and just this evening when trying to run both a partial and full design, I get an error:

BadZipFile Traceback (most recent call last)

in <module>()
1 clear_mem()
----> 2 af_model = mk_afdesign_model(protocol="fixbb")
3 af_model.prep_inputs(pdb_filename=get_pdb("1TEN"), chain="A")
4
5 print("length", af_model._len)

6 frames

/content/colabdesign/af/model.py in __init__(self, protocol, num_seq, num_models, sample_models, recycle_mode, num_recycles, use_templates, best_metric, crop_len, crop_mode, subbatch_size, debug, use_alphafold, use_openfold, loss_callback, data_dir)
99 self._model_params, self._model_names = [],[]
100 for model_name in model_names:
--> 101 params = data.get_model_haiku_params(model_name=model_name, data_dir=data_dir)
102 if params is not None:
103 if not use_templates:

/content/colabdesign/af/alphafold/model/data.py in get_model_haiku_params(model_name, data_dir)
38 if os.path.isfile(path):
39 with open(path, 'rb') as f:
---> 40 params = np.load(io.BytesIO(f.read()), allow_pickle=False)
41 return utils.flat_params_to_haiku(params)
42 else:

/usr/local/lib/python3.7/dist-packages/numpy/lib/npyio.py in load(file, mmap_mode, allow_pickle, fix_imports, encoding)
431 stack.pop_all()
432 ret = NpzFile(fid, own_fid=own_fid, allow_pickle=allow_pickle,
--> 433 pickle_kwargs=pickle_kwargs)
434 return ret
435 elif magic == format.MAGIC_PREFIX:

/usr/local/lib/python3.7/dist-packages/numpy/lib/npyio.py in __init__(self, fid, own_fid, allow_pickle, pickle_kwargs)
187 # Import is postponed to here since zipfile depends on gzip, an
188 # optional component of the so-called standard library.
--> 189 _zip = zipfile_factory(fid)
190 self._files = _zip.namelist()
191 self.files = []

/usr/local/lib/python3.7/dist-packages/numpy/lib/npyio.py in zipfile_factory(file, *args, **kwargs)
110 import zipfile
111 kwargs['allowZip64'] = True
--> 112 return zipfile.ZipFile(file, *args, **kwargs)
113
114

/usr/lib/python3.7/zipfile.py in __init__(self, file, mode, compression, allowZip64, compresslevel)
1256 try:
1257 if mode == 'r':
-> 1258 self._RealGetContents()
1259 elif mode in ('w', 'x'):
1260 # set the modified flag so central directory gets written

/usr/lib/python3.7/zipfile.py in _RealGetContents(self)
1323 raise BadZipFile("File is not a zip file")
1324 if not endrec:
-> 1325 raise BadZipFile("File is not a zip file")
1326 if self.debug > 1:
1327 print(endrec)

BadZipFile: File is not a zip file

Did something on the backend recently change such as a file rearrangement?
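
For what it's worth, the failing call in the traceback is np.load() on an AlphaFold parameter file, and .npz files are zip archives, so a BadZipFile almost always means the params download was truncated or corrupted rather than a backend code change. A small hedged check (the params directory name is whatever your notebook downloads into):

import os, zipfile

params_dir = "params"
for f in sorted(os.listdir(params_dir)):
    if f.endswith(".npz"):
        path = os.path.join(params_dir, f)
        ok = zipfile.is_zipfile(path)   # valid .npz files are zip archives
        print(f, os.path.getsize(path), "OK" if ok else "CORRUPT")
# if any file reports CORRUPT (or a suspiciously small size), re-download the
# params tarball your notebook uses and extract it again.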

Can we see generated sequences other than final one?

Hi, thank you for your great work!
I wonder if there is any way to see the sequences that are generated during each run.
I am running AFDesign in Colab at the moment (link below):
https://colab.research.google.com/github/sokrypton/ColabDesign/blob/main/af/examples/afdesign_hotspot_test.ipynb
I have found that the final output is not always the best one.
Is there any way to get the sequence and PDB for each of the models shown in the image below?
model.get_seqs() only gives the final sequence. ;(
[image: notebook output showing the per-model design log]

please help me!
Thank you!!
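
One workaround, sketched with only the calls that already appear in this repo's examples (design_3stage, get_seqs, save_pdb): split the optimization into several shorter design calls and record the sequence/structure after each chunk, instead of relying on the single final output. Whether repeated calls continue cleanly from the current state is an assumption to verify:

records = []
for step in range(5):
    model.design_3stage(20, 10, 5)            # short chunks of optimization
    seq = model.get_seqs()
    model.save_pdb(f"checkpoint_{step}.pdb")  # structure at this checkpoint
    records.append(seq)
    print(step, seq)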

why colabdesign alphafold limits on insertion code at chain

Dear sokrypton,
In colabdesign/af/alphafold/common/protein.py there is a restriction on insertion codes within a chain, as in the code snippet below.
Could you explain why AlphaFold design imposes this limit on PDB data? Thanks!

  for res in chain:
    if res.id[2] != ' ':
      raise ValueError(
          f'PDB contains an insertion code at chain {chain.id} and residue '
          f'index {res.id[1]}. These are not supported.')
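
Until the limitation is lifted, one workaround is to renumber the input so insertion codes disappear before calling prep_inputs(). A minimal, hypothetical pre-processing sketch (plain-text manipulation of ATOM/HETATM records, not part of ColabDesign; file names are placeholders):

def strip_insertion_codes(pdb_in, pdb_out):
    # renumber residues sequentially per chain and blank the insertion-code
    # column (column 27 of the PDB format) so protein.py accepts the file
    counter, last_key = {}, {}
    with open(pdb_in) as fin, open(pdb_out, "w") as fout:
        for line in fin:
            if line.startswith(("ATOM", "HETATM")):
                chain = line[21]
                key = (chain, line[22:26], line[26])   # original resseq + icode
                if last_key.get(chain) != key:
                    counter[chain] = counter.get(chain, 0) + 1
                    last_key[chain] = key
                line = line[:22] + f"{counter[chain]:>4d}" + " " + line[27:]
            fout.write(line)

strip_insertion_codes("input.pdb", "input_clean.pdb")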

how to interpret results of 'use_templates=True' versus 'False' in partial_hallucination_rewire?

I'm trying to use the partial_hallucination_rewire notebook to replace part of a protein with a linker-like sequence while otherwise preserving the overall predicted structure (code below). To preserve the exact sequence outside the hallucination region, I think I have to set fix_seq=True in model.prep_inputs(), but I'm unsure whether use_templates in mk_afdesign_model() should be True or False.

The goal with these partly hallucinated proteins is to disrupt a function specific to the hallucinated region while preserving functions localized to other parts, as well as overall stability/solubility. To assess which partly hallucinated variants are likely to have these properties, I'm looking at how well the structure outside the hallucinated region is preserved and at the pLDDT score both outside and within the hallucinated region.

These metrics look better when use_templates=True, but does that actually mean the predicted structure is more reliable, or is use_templates=True forcing the output to conform to the input by construction?

Possibly relevant: the protein is so large (~1000 AA) that my input PDB is only the relevant domain cut out of a larger AlphaFold model (otherwise I hit memory errors), and I'm unsure whether AlphaFold would predict this domain to fold on its own as it does in the larger structure (would it help to model this?).

Here's the code:

use_templates_setting = False
fix_seq_setting = True

model = mk_afdesign_model(protocol="partial",
                          use_templates=use_templates_setting) # set True to constrain positions using template input

# define positions we want to constrain (input PDB numbering)
input_len = 142
swap_region = [842, 875]  # region of the structure to be replaced by the loop

for loop_len in [4,5,6,7,8,9,10]: #not sure how long the inserted loop should be
                          
  new_len = input_len - (swap_region[1]-swap_region[0]+1) + loop_len
  old_pos = "771-" + str(swap_region[0]-1) + "," + str(swap_region[1]+1) + "-912"

  outputfile = '_'.join(["loop" + str(loop_len),
                        'template' + str(int(use_templates_setting)),
                        'seq' + str(int(fix_seq_setting))]) + '.pdb'

  print(new_len, old_pos)
  print(outputfile)

  model.prep_inputs("myprotein.pdb", chain="A",
                    pos=old_pos,               # define positions to constrain
                    length=new_len,            # set if the desired length differs from the input PDB
                    fix_seq=fix_seq_setting)   # set True to constrain the sequence

  model.rewire(loops=[loop_len])

  print(model.opt["pos"])

  model.restart()

  #balance weights [dgram_cce = restraint weight], [con = hallucination weight]
  model.set_weights(dgram_cce=1, con=0) #no idea what these numbers should be
  model.design_3stage(200,100,10) #no idea what these numbers should be

  model.save_pdb(outputfile)
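
On the use_templates question itself: with templates on, the constrained region is presumably fed back in as structural input, so some of the improvement may be by construction. A hedged sanity check is to re-predict the final design without templates and compare per-residue pLDDT inside vs. outside the hallucinated region; the aux["plddt"] access below is an assumption about where ColabDesign stores per-residue confidence, so adjust it to your version:

import numpy as np

plddt = np.asarray(model.aux["plddt"])      # assumed per-residue pLDDT of the last prediction
fixed_idx = np.asarray(model.opt["pos"])    # constrained positions (printed above)
halluc_idx = np.setdiff1d(np.arange(len(plddt)), fixed_idx)
print("mean pLDDT, fixed region       :", plddt[fixed_idx].mean())
print("mean pLDDT, hallucinated region:", plddt[halluc_idx].mean())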

Do the soft iterations in pssm_semigreedy obey the bias matrix

Hello.
This is somewhat related to #107
I am fixing the positions of some of the residues using the bias matrix as follows.

import numpy as np
from colabdesign.af.alphafold.common import residue_constants  # adjust the path to your install if needed

bias = np.zeros((af_model._binder_len, 20))
fixpos = [1, 11, 15]                                    # positions to pin to glutamate
nonfixpos = [0,2,3,4,5,6,7,8,9,10,12,13,14,16]
aa_order = residue_constants.restype_order
bias[fixpos, aa_order["E"]] = 1e9                       # force E at the fixed positions
bias[nonfixpos, aa_order["E"]] = -100                   # discourage E everywhere else

This ensures that I have E at the fixed positions in the designed sequences.
However, now if I try to add a term to the loss function to restrict the distances or angles between these residues using

import jax
import jax.numpy as jnp

def dist_loss(inputs, outputs):
  # penalize the OE1-OE1 distance between residues 215 and 225 once it
  # exceeds the desired value (elu is ~linear above 0 and ~flat below)
  positions = outputs["structure_module"]["final_atom_positions"]
  D1 = positions[215, residue_constants.atom_order["OE1"]]
  D2 = positions[225, residue_constants.atom_order["OE1"]]
  dist1 = jnp.sqrt(jnp.sum((D1 - D2)**2, axis=0))
  dist_desired = 12.0
  return {"dist": jax.nn.elu(dist1 - dist_desired)}

I get an error

ValueError: cannot convert float NaN to integer

However, if I use CA, I don't get this error.
Interestingly, if I use a very small number of soft iterations (e.g. 2), the script runs, but the optimization is not optimal, as expected.
It seems to me that this could happen if, during the soft iterations (especially when the number of soft iterations is not too small), the residue at the fixed position is not actually "E".
Is this true?
If so, is there a way to ensure that the positions are fixed even during the soft iterations?

I would be really grateful for any suggestions.
Best,
Amin.
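
In case it helps to frame the question: my understanding (an assumption about the implementation, not a statement of it) is that during soft iterations the sequence is a softmax over (logits + bias), so a +1e9 bias should already make the fixed positions effectively one-hot, while the -100 entries only discourage E elsewhere. A tiny sketch of that reading, reusing the bias array from above:

import jax
import jax.numpy as jnp

logits = jnp.zeros((17, 20))                       # binder_len x 20 design logits
soft_seq = jax.nn.softmax(logits + bias, axis=-1)  # "soft" sequence seen during soft iterations
print(soft_seq[1].argmax() == residue_constants.restype_order["E"])  # True at a fixed position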

Early stopping

When using model.design_3stage(), I observe a decrease in RMSD, but at some point the error increases again with each additional iteration. I suspect the last iteration is used as the design result; is there a way to use the "best" iteration instead? Thanks for your help.
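
A hedged pointer: mk_afdesign_model() exposes a best_metric argument (it shows up in the constructor signature quoted in the traceback above), and recent versions of the design functions accept a save_best flag so the best-scoring iteration is kept rather than the last one; both names are assumptions to verify against your installed version. A minimal sketch:

model = mk_afdesign_model(protocol="fixbb", best_metric="rmsd")
model.prep_inputs(pdb_filename="input.pdb", chain="A")
model.design_3stage(100, 100, 10, save_best=True)   # keep the best iteration, if supported
model.save_pdb("best.pdb")
model.get_seqs()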
