diff --git a/src/refiners/foundationals/swin/mvanet/converter.py b/src/refiners/foundationals/swin/mvanet/converter.py
index daacd69..26ac970 100644
--- a/src/refiners/foundationals/swin/mvanet/converter.py
+++ b/src/refiners/foundationals/swin/mvanet/converter.py
@@ -11,7 +11,7 @@ def convert_weights(official_state_dict: dict[str, Tensor]) -> dict[str, Tensor]
         r"multifieldcrossatt.attention.5",
         r"dec_blk\d+\.linear[12]",
         r"dec_blk[1234]\.attention\.[4567]",
-        # We don't need the sideout weights
+        # We don't need the sideout weights for inference
         r"sideout\d+",
     ]
     state_dict = {k: v for k, v in official_state_dict.items() if not any(re.match(rm, k) for rm in rm_list)}
diff --git a/src/refiners/foundationals/swin/mvanet/mclm.py b/src/refiners/foundationals/swin/mvanet/mclm.py
index 041308a..b860d06 100644
--- a/src/refiners/foundationals/swin/mvanet/mclm.py
+++ b/src/refiners/foundationals/swin/mvanet/mclm.py
@@ -132,6 +132,8 @@ class MCLM(fl.Chain):
         positional_embedding = PositionEmbeddingSine(num_pos_feats=emb_dim // 2, device=device)
 
         # LayerNorms in MCLM share their weights.
+        # We use the `proxy` trick below so they can be present only
+        # once in the tree but called in two different places.
         ln1 = fl.LayerNorm(emb_dim, device=device)
         ln2 = fl.LayerNorm(emb_dim, device=device)
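
For context (not part of the patch): below is a minimal plain-PyTorch sketch of the weight-sharing idea the new MCLM comment describes, i.e. one layer registered once in the module tree but called from two places. This is not the refiners `proxy` mechanism itself, and `SharedNormBlock` with its inputs are hypothetical names used only for illustration.

```python
# Illustrative sketch (plain PyTorch, not the refiners `fl` API):
# a single LayerNorm is registered once, so its weights appear once
# in the state dict, yet it normalizes two different inputs.
import torch
import torch.nn as nn


class SharedNormBlock(nn.Module):  # hypothetical example class
    def __init__(self, emb_dim: int) -> None:
        super().__init__()
        self.ln = nn.LayerNorm(emb_dim)  # registered once in the module tree

    def forward(self, global_tokens: torch.Tensor, local_tokens: torch.Tensor) -> torch.Tensor:
        # The same LayerNorm instance (same weights) is called in two places.
        return self.ln(global_tokens) + self.ln(local_tokens)


x_global = torch.randn(2, 16, 128)
x_local = torch.randn(2, 16, 128)
out = SharedNormBlock(128)(x_global, x_local)
print(out.shape)  # torch.Size([2, 16, 128])
```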