improve comments

Pierre Chapuis 2024-08-26 13:42:02 +02:00
parent 10dfa73a09
commit 2fc93cb334
2 changed files with 3 additions and 1 deletion


@@ -11,7 +11,7 @@ def convert_weights(official_state_dict: dict[str, Tensor]) -> dict[str, Tensor]
r"multifieldcrossatt.attention.5",
r"dec_blk\d+\.linear[12]",
r"dec_blk[1234]\.attention\.[4567]",
# We don't need the sideout weights
# We don't need the sideout weights for inference
r"sideout\d+",
]
state_dict = {k: v for k, v in official_state_dict.items() if not any(re.match(rm, k) for rm in rm_list)}
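For context, a minimal self-contained sketch of how this filter behaves (the key names below are made up for illustration): any key whose start matches one of the rm_list patterns, such as the sideout weights, is dropped from the converted state dict.

import re

rm_list = [r"sideout\d+"]
official_state_dict = {"sideout3.weight": 0.0, "conv1.weight": 1.0}  # dummy values, hypothetical keys
state_dict = {k: v for k, v in official_state_dict.items() if not any(re.match(rm, k) for rm in rm_list)}
assert list(state_dict) == ["conv1.weight"]  # sideout3.weight was filtered out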


@@ -132,6 +132,8 @@ class MCLM(fl.Chain):
         positional_embedding = PositionEmbeddingSine(num_pos_feats=emb_dim // 2, device=device)
         # LayerNorms in MCLM share their weights.
+        # We use the `proxy` trick below so they can be present only
+        # once in the tree but called in two different places.
         ln1 = fl.LayerNorm(emb_dim, device=device)
         ln2 = fl.LayerNorm(emb_dim, device=device)
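To make the new comment concrete, here is a plain PyTorch sketch of the weight-sharing idea (this is not refiners' actual `proxy` mechanism, which is not shown here; the module and names below are illustrative only): a single LayerNorm is registered once in the module tree and invoked from two call sites, so both uses share the same parameters.

import torch
from torch import nn

class SharedNormExample(nn.Module):  # hypothetical module, for illustration only
    def __init__(self, emb_dim: int) -> None:
        super().__init__()
        self.ln = nn.LayerNorm(emb_dim)  # present only once in the tree...

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # ...but called in two different places, so both calls use the same weights
        return self.ln(a), self.ln(b)

m = SharedNormExample(emb_dim=64)
out_a, out_b = m(torch.randn(2, 16, 64), torch.randn(2, 16, 64))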