Skip to content

Commit faa844a

Browse files
authored
Rectifying Simba Implementation (#4)
1 parent d5d6234 commit faa844a

1 file changed

Lines changed: 6 additions & 4 deletions

File tree

scale_rl/agents/simba/simba_network.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ def setup(self):
2323
)
2424
self.encoder = nn.Sequential(
2525
[
26-
PreLNResidualBlock(hidden_dim=self.hidden_dim)
27-
for _ in range(self.num_blocks)
26+
*[PreLNResidualBlock(hidden_dim=self.hidden_dim)
27+
for _ in range(self.num_blocks)],
28+
nn.LayerNorm(),
2829
]
2930
)
3031
self.predictor = NormalTanhPolicy(self.action_dim)
@@ -51,8 +52,9 @@ def setup(self):
5152
)
5253
self.encoder = nn.Sequential(
5354
[
54-
PreLNResidualBlock(hidden_dim=self.hidden_dim)
55-
for _ in range(self.num_blocks)
55+
*[PreLNResidualBlock(hidden_dim=self.hidden_dim)
56+
for _ in range(self.num_blocks)],
57+
nn.LayerNorm(),
5658
]
5759
)
5860
self.predictor = LinearCritic()

0 commit comments

Comments
 (0)