[Feature] LLMEnv and DataLoadingPrimer #2818
base: gh/vmoens/96/base
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2818
Note: Links to docs will display an error until the docs builds have been completed.
❌ 3 New Failures, 2 Unrelated Failures
As of commit e8cd9f9 with merge base 8c9dc05:
NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
---|---|---|---|---|---|
test_simple | 0.6155s | 0.5268s | 1.8983 Ops/s | 1.9871 Ops/s | |
test_transformed | 1.1174s | 1.0255s | 0.9751 Ops/s | 0.9746 Ops/s | |
test_serial | 1.5981s | 1.5148s | 0.6602 Ops/s | 0.6583 Ops/s | |
test_parallel | 1.4020s | 1.3063s | 0.7655 Ops/s | 0.7705 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.2764ms | 30.1940μs | 33.1192 KOps/s | 32.9913 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 49.3920μs | 17.7845μs | 56.2287 KOps/s | 56.3239 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 45.7350μs | 17.1028μs | 58.4698 KOps/s | 58.2347 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 43.2110μs | 10.1216μs | 98.7984 KOps/s | 99.7412 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 76.9940μs | 31.9714μs | 31.2780 KOps/s | 31.0566 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 60.9640μs | 19.5093μs | 51.2576 KOps/s | 51.5504 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 47.9400μs | 18.9945μs | 52.6468 KOps/s | 52.4973 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 50.1140μs | 11.8135μs | 84.6491 KOps/s | 83.9012 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 98.5840μs | 33.8281μs | 29.5612 KOps/s | 29.1797 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 55.1240μs | 21.5644μs | 46.3726 KOps/s | 46.5297 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 52.0480μs | 18.9393μs | 52.8003 KOps/s | 52.0068 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 40.5660μs | 11.8784μs | 84.1863 KOps/s | 84.1305 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 79.4390μs | 35.8074μs | 27.9272 KOps/s | 28.0244 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 62.1770μs | 23.1232μs | 43.2466 KOps/s | 43.4051 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 47.4080μs | 20.6968μs | 48.3165 KOps/s | 48.2773 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 45.0940μs | 13.6396μs | 73.3161 KOps/s | 72.7004 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 85.2100μs | 34.4315μs | 29.0432 KOps/s | 29.2564 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 57.5180μs | 21.4513μs | 46.6171 KOps/s | 46.1809 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 52.2980μs | 21.4419μs | 46.6376 KOps/s | 45.5530 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 48.1300μs | 13.1904μs | 75.8127 KOps/s | 74.7390 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 88.7860μs | 35.7141μs | 28.0001 KOps/s | 28.0619 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 0.1074ms | 23.0578μs | 43.3694 KOps/s | 42.8549 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 2.6571ms | 23.7978μs | 42.0207 KOps/s | 42.4988 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 48.1600μs | 14.9444μs | 66.9149 KOps/s | 67.3371 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 89.5070μs | 37.3055μs | 26.8057 KOps/s | 26.5798 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 57.8980μs | 25.1210μs | 39.8073 KOps/s | 39.8357 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 51.9070μs | 23.2537μs | 43.0039 KOps/s | 42.4362 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 47.1180μs | 15.0165μs | 66.5935 KOps/s | 66.8625 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 79.3680μs | 38.9024μs | 25.7054 KOps/s | 25.5776 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 56.5060μs | 26.5254μs | 37.6997 KOps/s | 37.2743 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 61.3850μs | 24.8674μs | 40.2133 KOps/s | 38.5420 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 0.1807ms | 16.6823μs | 59.9437 KOps/s | 61.3209 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 10.8088ms | 10.1149ms | 98.8638 Ops/s | 97.9999 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 35.8081ms | 24.8311ms | 40.2721 Ops/s | 38.0607 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.2551ms | 0.1806ms | 5.5369 KOps/s | 5.6600 KOps/s | |
test_values[td1_return_estimate-False-False] | 28.5485ms | 25.2982ms | 39.5285 Ops/s | 41.4192 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 26.8510ms | 24.3805ms | 41.0164 Ops/s | 36.5255 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 37.5417ms | 36.0940ms | 27.7054 Ops/s | 28.2665 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 32.4431ms | 24.8415ms | 40.2552 Ops/s | 38.0030 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 11.8544ms | 8.7532ms | 114.2440 Ops/s | 117.1630 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 2.1835ms | 1.8789ms | 532.2389 Ops/s | 539.5031 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.5476ms | 0.3700ms | 2.7028 KOps/s | 2.6882 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 40.3582ms | 38.5860ms | 25.9161 Ops/s | 23.9830 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 4.5561ms | 3.4504ms | 289.8214 Ops/s | 289.7500 Ops/s | |
test_dqn_speed[False-None] | 6.2402ms | 1.4214ms | 703.5539 Ops/s | 710.3082 Ops/s | |
test_dqn_speed[False-backward] | 2.7064ms | 1.9316ms | 517.7068 Ops/s | 520.2567 Ops/s | |
test_dqn_speed[True-None] | 0.7265ms | 0.4767ms | 2.0977 KOps/s | 2.0785 KOps/s | |
test_dqn_speed[True-backward] | 1.3077ms | 1.0561ms | 946.9146 Ops/s | 1.0517 KOps/s | |
test_dqn_speed[reduce-overhead-None] | 0.7268ms | 0.4764ms | 2.0990 KOps/s | 2.0815 KOps/s | |
test_dqn_speed[reduce-overhead-backward] | 0.9463ms | 0.8995ms | 1.1117 KOps/s | 1.0983 KOps/s | |
test_ddpg_speed[False-None] | 3.5507ms | 2.9056ms | 344.1590 Ops/s | 341.0496 Ops/s | |
test_ddpg_speed[False-backward] | 4.2111ms | 4.0393ms | 247.5668 Ops/s | 246.1683 Ops/s | |
test_ddpg_speed[True-None] | 1.4189ms | 1.2100ms | 826.4646 Ops/s | 811.8646 Ops/s | |
test_ddpg_speed[True-backward] | 2.1677ms | 2.1106ms | 473.8021 Ops/s | 447.4851 Ops/s | |
test_ddpg_speed[reduce-overhead-None] | 1.6045ms | 1.2197ms | 819.8911 Ops/s | 788.8532 Ops/s | |
test_ddpg_speed[reduce-overhead-backward] | 2.1622ms | 2.1035ms | 475.3982 Ops/s | 470.6613 Ops/s | |
test_sac_speed[False-None] | 9.2255ms | 8.0267ms | 124.5834 Ops/s | 121.1360 Ops/s | |
test_sac_speed[False-backward] | 12.3078ms | 10.8202ms | 92.4193 Ops/s | 91.6627 Ops/s | |
test_sac_speed[True-None] | 2.7515ms | 2.0805ms | 480.6442 Ops/s | 476.6430 Ops/s | |
test_sac_speed[True-backward] | 3.8312ms | 3.7658ms | 265.5445 Ops/s | 245.6479 Ops/s | |
test_sac_speed[reduce-overhead-None] | 2.5534ms | 2.0788ms | 481.0387 Ops/s | 479.1013 Ops/s | |
test_sac_speed[reduce-overhead-backward] | 3.8632ms | 3.7676ms | 265.4239 Ops/s | 257.1135 Ops/s | |
test_redq_speed[False-None] | 15.5535ms | 13.0293ms | 76.7503 Ops/s | 66.8755 Ops/s | |
test_redq_speed[False-backward] | 24.9113ms | 22.6343ms | 44.1807 Ops/s | 40.7050 Ops/s | |
test_redq_speed[True-None] | 6.3211ms | 4.9000ms | 204.0828 Ops/s | 182.6440 Ops/s | |
test_redq_speed[True-backward] | 13.8845ms | 12.3041ms | 81.2737 Ops/s | 75.4412 Ops/s | |
test_redq_speed[reduce-overhead-None] | 5.8866ms | 4.8120ms | 207.8119 Ops/s | 153.8882 Ops/s | |
test_redq_speed[reduce-overhead-backward] | 14.3489ms | 12.6739ms | 78.9025 Ops/s | 73.2833 Ops/s | |
test_redq_deprec_speed[False-None] | 14.5112ms | 12.9824ms | 77.0271 Ops/s | 65.2692 Ops/s | |
test_redq_deprec_speed[False-backward] | 21.6702ms | 18.7735ms | 53.2665 Ops/s | 45.4356 Ops/s | |
test_redq_deprec_speed[True-None] | 4.7965ms | 3.9319ms | 254.3310 Ops/s | 211.1805 Ops/s | |
test_redq_deprec_speed[True-backward] | 10.0132ms | 9.4526ms | 105.7912 Ops/s | 105.6832 Ops/s | |
test_redq_deprec_speed[reduce-overhead-None] | 4.3521ms | 3.9350ms | 254.1318 Ops/s | 215.8349 Ops/s | |
test_redq_deprec_speed[reduce-overhead-backward] | 10.0058ms | 9.1627ms | 109.1378 Ops/s | 100.6448 Ops/s | |
test_td3_speed[False-None] | 9.3276ms | 8.2643ms | 121.0022 Ops/s | 115.4125 Ops/s | |
test_td3_speed[False-backward] | 13.4055ms | 10.8847ms | 91.8717 Ops/s | 84.1313 Ops/s | |
test_td3_speed[True-None] | 1.9979ms | 1.8259ms | 547.6703 Ops/s | 527.8033 Ops/s | |
test_td3_speed[True-backward] | 3.9899ms | 3.6838ms | 271.4598 Ops/s | 252.8562 Ops/s | |
test_td3_speed[reduce-overhead-None] | 2.1166ms | 1.8387ms | 543.8681 Ops/s | 524.1915 Ops/s | |
test_td3_speed[reduce-overhead-backward] | 3.9358ms | 3.5705ms | 280.0756 Ops/s | 227.7347 Ops/s | |
test_cql_speed[False-None] | 40.3355ms | 37.2983ms | 26.8109 Ops/s | 25.1637 Ops/s | |
test_cql_speed[False-backward] | 50.4941ms | 47.9077ms | 20.8735 Ops/s | 19.7521 Ops/s | |
test_cql_speed[True-None] | 18.5777ms | 16.6529ms | 60.0495 Ops/s | 57.8468 Ops/s | |
test_cql_speed[True-backward] | 26.0155ms | 23.8697ms | 41.8942 Ops/s | 41.8900 Ops/s | |
test_cql_speed[reduce-overhead-None] | 18.3829ms | 16.9950ms | 58.8408 Ops/s | 59.1469 Ops/s | |
test_cql_speed[reduce-overhead-backward] | 25.0438ms | 24.0386ms | 41.5997 Ops/s | 40.9670 Ops/s | |
test_a2c_speed[False-None] | 8.5938ms | 7.5188ms | 133.0004 Ops/s | 127.3067 Ops/s | |
test_a2c_speed[False-backward] | 17.2572ms | 15.3570ms | 65.1171 Ops/s | 62.6629 Ops/s | |
test_a2c_speed[True-None] | 5.6487ms | 3.9491ms | 253.2241 Ops/s | 252.9739 Ops/s | |
test_a2c_speed[True-backward] | 11.6051ms | 10.9691ms | 91.1648 Ops/s | 85.7791 Ops/s | |
test_a2c_speed[reduce-overhead-None] | 4.8856ms | 4.0043ms | 249.7320 Ops/s | 260.4836 Ops/s | |
test_a2c_speed[reduce-overhead-backward] | 12.2161ms | 10.8788ms | 91.9223 Ops/s | 87.3513 Ops/s | |
test_ppo_speed[False-None] | 9.0110ms | 8.1407ms | 122.8394 Ops/s | 121.0560 Ops/s | |
test_ppo_speed[False-backward] | 17.0710ms | 16.1300ms | 61.9963 Ops/s | 61.6946 Ops/s | |
test_ppo_speed[True-None] | 5.3503ms | 4.9275ms | 202.9416 Ops/s | 225.7494 Ops/s | |
test_ppo_speed[True-backward] | 11.9698ms | 10.9946ms | 90.9538 Ops/s | 92.3701 Ops/s | |
test_ppo_speed[reduce-overhead-None] | 5.0519ms | 4.6544ms | 214.8513 Ops/s | 225.1127 Ops/s | |
test_ppo_speed[reduce-overhead-backward] | 11.8601ms | 10.6556ms | 93.8473 Ops/s | 87.0575 Ops/s | |
test_reinforce_speed[False-None] | 8.3425ms | 7.0315ms | 142.2166 Ops/s | 137.0977 Ops/s | |
test_reinforce_speed[False-backward] | 11.0278ms | 10.5488ms | 94.7974 Ops/s | 93.1342 Ops/s | |
test_reinforce_speed[True-None] | 3.8887ms | 3.3325ms | 300.0754 Ops/s | 296.9898 Ops/s | |
test_reinforce_speed[True-backward] | 15.1774ms | 10.2476ms | 97.5834 Ops/s | 101.6584 Ops/s | |
test_reinforce_speed[reduce-overhead-None] | 4.5190ms | 3.4583ms | 289.1557 Ops/s | 306.6650 Ops/s | |
test_reinforce_speed[reduce-overhead-backward] | 11.0054ms | 9.8199ms | 101.8344 Ops/s | 96.8382 Ops/s | |
test_iql_speed[False-None] | 36.2013ms | 33.9301ms | 29.4724 Ops/s | 29.3483 Ops/s | |
test_iql_speed[False-backward] | 49.6386ms | 47.0694ms | 21.2452 Ops/s | 21.1083 Ops/s | |
test_iql_speed[True-None] | 12.6085ms | 11.7112ms | 85.3886 Ops/s | 83.4086 Ops/s | |
test_iql_speed[True-backward] | 24.0450ms | 22.7081ms | 44.0371 Ops/s | 43.0617 Ops/s | |
test_iql_speed[reduce-overhead-None] | 13.0597ms | 11.4631ms | 87.2365 Ops/s | 83.5773 Ops/s | |
test_iql_speed[reduce-overhead-backward] | 24.0840ms | 22.5550ms | 44.3362 Ops/s | 42.6116 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 6.7211ms | 4.8988ms | 204.1325 Ops/s | 202.4986 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.8350ms | 0.5222ms | 1.9148 KOps/s | 1.9348 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.7811ms | 0.4941ms | 2.0238 KOps/s | 2.0153 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 7.5276ms | 4.6634ms | 214.4372 Ops/s | 206.1709 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 2.8423ms | 0.5133ms | 1.9483 KOps/s | 1.9473 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.7691ms | 0.4942ms | 2.0234 KOps/s | 2.0406 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 1.8913ms | 1.6536ms | 604.7310 Ops/s | 602.8372 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 2.2761ms | 1.5790ms | 633.2999 Ops/s | 635.0601 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 7.4203ms | 4.7773ms | 209.3222 Ops/s | 198.1200 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 3.0166ms | 0.6590ms | 1.5175 KOps/s | 1.4800 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 1.1887ms | 0.6464ms | 1.5471 KOps/s | 1.5789 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 5.3074ms | 4.6162ms | 216.6291 Ops/s | 208.1750 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 2.6589ms | 0.5222ms | 1.9151 KOps/s | 1.9315 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.7136ms | 0.5028ms | 1.9890 KOps/s | 1.9903 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 5.0873ms | 4.6840ms | 213.4917 Ops/s | 210.7017 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 2.8530ms | 0.5209ms | 1.9197 KOps/s | 1.9821 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.7734ms | 0.4916ms | 2.0340 KOps/s | 2.0453 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 5.8432ms | 4.7374ms | 211.0857 Ops/s | 204.7809 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 2.3290ms | 0.6646ms | 1.5046 KOps/s | 1.4871 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.9056ms | 0.6414ms | 1.5591 KOps/s | 1.5785 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 7.3797ms | 4.4393ms | 225.2586 Ops/s | 242.8473 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 6.4336ms | 2.3145ms | 432.0567 Ops/s | 416.3116 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 5.6930ms | 1.4250ms | 701.7411 Ops/s | 683.3626 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.5636s | 15.5206ms | 64.4303 Ops/s | 241.8562 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 6.5083ms | 2.2917ms | 436.3639 Ops/s | 426.8602 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 1.7891ms | 1.2690ms | 788.0275 Ops/s | 722.4473 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 5.9199ms | 4.4755ms | 223.4363 Ops/s | 31.4893 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 8.3036ms | 2.4856ms | 402.3120 Ops/s | 398.2762 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 5.0796ms | 1.6335ms | 612.1937 Ops/s | 628.1660 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] | 12.4410ms | 12.1353ms | 82.4045 Ops/s | 77.0834 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] | 19.9168ms | 14.8793ms | 67.2076 Ops/s | 69.2355 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] | 22.1493ms | 20.9713ms | 47.6843 Ops/s | 46.3229 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] | 16.2478ms | 14.9010ms | 67.1098 Ops/s | 67.7133 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] | 22.0773ms | 20.8640ms | 47.9295 Ops/s | 46.8663 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] | 17.1365ms | 16.0454ms | 62.3230 Ops/s | 61.8740 Ops/s |
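
The Change column in these tables is empty in this capture. As a rough aid for reading them, the relative difference between Ops and Ops on Repo HEAD can be recomputed from the cell text. The sketch below is an illustration only: the helper names `parse_ops` and `relative_change` are made up here, and the unit prefixes handled (`Ops/s`, `KOps/s`, plus an assumed `MOps/s`) are inferred from the rows above rather than taken from the benchmark tool.

```python
import re

# Scale factors for the throughput units seen in the tables.
# "MOps/s" does not appear above; it is included as an assumption.
_UNIT_SCALE = {"Ops/s": 1.0, "KOps/s": 1e3, "MOps/s": 1e6}


def parse_ops(cell: str) -> float:
    """Convert a cell such as '32.9913 KOps/s' to operations per second."""
    match = re.fullmatch(r"\s*([0-9]*\.?[0-9]+)\s*([KM]?Ops/s)\s*", cell)
    if match is None:
        raise ValueError(f"unrecognized throughput cell: {cell!r}")
    value, unit = match.groups()
    return float(value) * _UNIT_SCALE[unit]


def relative_change(ops_pr: str, ops_head: str) -> float:
    """Return (PR - HEAD) / HEAD; a negative value means the PR is slower."""
    pr = parse_ops(ops_pr)
    head = parse_ops(ops_head)
    return (pr - head) / head


if __name__ == "__main__":
    # Row test_simple from the first table.
    print(f"{relative_change('1.8983 Ops/s', '1.9871 Ops/s'):+.2%}")
    # Row test_dqn_speed[True-backward]: the two cells use different units.
    print(f"{relative_change('946.9146 Ops/s', '1.0517 KOps/s'):+.2%}")
```

Single-run differences of a few percent are within the spread visible between the Max and Mean columns, so a figure computed this way is best treated as a hint rather than evidence of a regression.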

Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
---|---|---|---|---|---|
test_simple | 0.9355s | 0.8368s | 1.1950 Ops/s | 1.2070 Ops/s | |
test_transformed | 1.5154s | 1.4252s | 0.7016 Ops/s | 0.6893 Ops/s | |
test_serial | 2.4334s | 2.3321s | 0.4288 Ops/s | 0.4186 Ops/s | |
test_parallel | 1.9667s | 1.8709s | 0.5345 Ops/s | 0.5266 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.1570ms | 40.6503μs | 24.6000 KOps/s | 24.8684 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 51.0510μs | 24.2041μs | 41.3154 KOps/s | 41.8788 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 52.5510μs | 22.6578μs | 44.1349 KOps/s | 44.6626 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 38.4910μs | 13.3015μs | 75.1793 KOps/s | 75.9512 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 74.4110μs | 43.4679μs | 23.0055 KOps/s | 23.4617 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 56.8110μs | 26.2119μs | 38.1506 KOps/s | 39.2849 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 61.6210μs | 24.8897μs | 40.1773 KOps/s | 40.6875 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 41.8210μs | 15.6729μs | 63.8044 KOps/s | 65.1244 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 77.2520μs | 45.8903μs | 21.7911 KOps/s | 22.1557 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 60.4610μs | 28.6764μs | 34.8719 KOps/s | 36.6367 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 87.9720μs | 25.3553μs | 39.4396 KOps/s | 40.3050 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 0.1597ms | 15.8812μs | 62.9677 KOps/s | 64.7340 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 76.9020μs | 48.3974μs | 20.6623 KOps/s | 21.0679 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 68.2810μs | 31.1057μs | 32.1484 KOps/s | 32.6003 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 75.8520μs | 27.4234μs | 36.4652 KOps/s | 36.8020 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 54.8110μs | 18.0388μs | 55.4360 KOps/s | 56.3885 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 79.7910μs | 45.8599μs | 21.8055 KOps/s | 22.1263 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 64.0510μs | 29.1767μs | 34.2739 KOps/s | 34.6093 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 62.0810μs | 29.7227μs | 33.6443 KOps/s | 34.8063 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 46.0810μs | 17.5414μs | 57.0081 KOps/s | 56.8584 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 92.6220μs | 48.5713μs | 20.5883 KOps/s | 20.9695 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 0.1033ms | 31.5128μs | 31.7332 KOps/s | 32.4452 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 3.2943ms | 31.8939μs | 31.3539 KOps/s | 31.5858 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 64.2310μs | 20.1510μs | 49.6253 KOps/s | 50.6639 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 0.1785ms | 50.4430μs | 19.8244 KOps/s | 19.9151 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 65.9510μs | 33.7204μs | 29.6557 KOps/s | 30.0724 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 96.5020μs | 31.6394μs | 31.6062 KOps/s | 32.5661 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 58.6710μs | 19.9401μs | 50.1502 KOps/s | 50.5536 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 90.4020μs | 53.3619μs | 18.7400 KOps/s | 19.0612 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 69.0810μs | 35.9116μs | 27.8462 KOps/s | 28.1837 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 65.0920μs | 33.3722μs | 29.9651 KOps/s | 30.7013 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 94.0110μs | 22.0656μs | 45.3195 KOps/s | 45.4405 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 25.7662ms | 25.1947ms | 39.6909 Ops/s | 40.2234 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 99.6743ms | 2.8978ms | 345.0899 Ops/s | 341.3655 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.1052ms | 78.5875μs | 12.7247 KOps/s | 12.4859 KOps/s | |
test_values[td1_return_estimate-False-False] | 56.0966ms | 55.6889ms | 17.9569 Ops/s | 18.1158 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 1.3107ms | 1.0891ms | 918.1600 Ops/s | 919.7956 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 88.8139ms | 88.2694ms | 11.3290 Ops/s | 11.5096 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 1.4193ms | 1.0880ms | 919.1473 Ops/s | 922.5808 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 25.1212ms | 24.9530ms | 40.0754 Ops/s | 41.0182 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 1.0428ms | 0.7572ms | 1.3207 KOps/s | 1.3179 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.8167ms | 0.6700ms | 1.4926 KOps/s | 1.5023 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 1.6245ms | 1.4922ms | 670.1549 Ops/s | 673.6116 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 0.7825ms | 0.7171ms | 1.3945 KOps/s | 1.4639 KOps/s | |
test_dqn_speed[False-None] | 1.6561ms | 1.5007ms | 666.3529 Ops/s | 666.8581 Ops/s | |
test_dqn_speed[False-backward] | 2.1538ms | 2.1036ms | 475.3716 Ops/s | 473.4709 Ops/s | |
test_dqn_speed[True-None] | 0.9602ms | 0.5497ms | 1.8190 KOps/s | 1.8034 KOps/s | |
test_dqn_speed[True-backward] | 1.2013ms | 1.1243ms | 889.4648 Ops/s | 806.3904 Ops/s | |
test_dqn_speed[reduce-overhead-None] | 0.9886ms | 0.5732ms | 1.7446 KOps/s | 1.7538 KOps/s | |
test_dqn_speed[reduce-overhead-backward] | 1.0045ms | 0.9610ms | 1.0405 KOps/s | 922.7899 Ops/s | |
test_ddpg_speed[False-None] | 3.2144ms | 2.8185ms | 354.7978 Ops/s | 355.5134 Ops/s | |
test_ddpg_speed[False-backward] | 4.5796ms | 4.1065ms | 243.5175 Ops/s | 239.2478 Ops/s | |
test_ddpg_speed[True-None] | 1.4737ms | 1.3342ms | 749.5104 Ops/s | 747.8654 Ops/s | |
test_ddpg_speed[True-backward] | 2.6164ms | 2.5693ms | 389.2080 Ops/s | 384.6444 Ops/s | |
test_ddpg_speed[reduce-overhead-None] | 1.4177ms | 1.3396ms | 746.4947 Ops/s | 739.0424 Ops/s | |
test_ddpg_speed[reduce-overhead-backward] | 2.0825ms | 2.0335ms | 491.7592 Ops/s | 489.4150 Ops/s | |
test_sac_speed[False-None] | 8.3827ms | 8.0095ms | 124.8511 Ops/s | 123.5965 Ops/s | |
test_sac_speed[False-backward] | 11.6625ms | 11.1531ms | 89.6613 Ops/s | 89.2023 Ops/s | |
test_sac_speed[True-None] | 1.9375ms | 1.8285ms | 546.9041 Ops/s | 543.8527 Ops/s | |
test_sac_speed[True-backward] | 3.8101ms | 3.7417ms | 267.2597 Ops/s | 265.8085 Ops/s | |
test_sac_speed[reduce-overhead-None] | 21.5241ms | 11.9946ms | 83.3707 Ops/s | 83.1205 Ops/s | |
test_sac_speed[reduce-overhead-backward] | 1.8199ms | 1.7704ms | 564.8339 Ops/s | 554.3911 Ops/s | |
test_redq_speed[False-None] | 8.0688ms | 7.5969ms | 131.6318 Ops/s | 130.1279 Ops/s | |
test_redq_speed[False-backward] | 12.3529ms | 11.8098ms | 84.6755 Ops/s | 84.0201 Ops/s | |
test_redq_speed[True-None] | 2.6879ms | 2.3375ms | 427.8006 Ops/s | 423.7014 Ops/s | |
test_redq_speed[True-backward] | 4.6577ms | 4.2310ms | 236.3520 Ops/s | 245.1400 Ops/s | |
test_redq_speed[reduce-overhead-None] | 2.5622ms | 2.3484ms | 425.8227 Ops/s | 419.4649 Ops/s | |
test_redq_speed[reduce-overhead-backward] | 4.6208ms | 4.2439ms | 235.6340 Ops/s | 233.5670 Ops/s | |
test_redq_deprec_speed[False-None] | 9.4012ms | 9.0511ms | 110.4840 Ops/s | 109.8376 Ops/s | |
test_redq_deprec_speed[False-backward] | 12.6302ms | 12.2901ms | 81.3662 Ops/s | 81.0709 Ops/s | |
test_redq_deprec_speed[True-None] | 2.7297ms | 2.6298ms | 380.2499 Ops/s | 375.7173 Ops/s | |
test_redq_deprec_speed[True-backward] | 4.9358ms | 4.4876ms | 222.8342 Ops/s | 218.4316 Ops/s | |
test_redq_deprec_speed[reduce-overhead-None] | 2.8536ms | 2.6275ms | 380.5839 Ops/s | 377.0174 Ops/s | |
test_redq_deprec_speed[reduce-overhead-backward] | 4.7346ms | 4.4911ms | 222.6614 Ops/s | 220.1437 Ops/s | |
test_td3_speed[False-None] | 8.2499ms | 7.9525ms | 125.7466 Ops/s | 126.3694 Ops/s | |
test_td3_speed[False-backward] | 12.6160ms | 10.6366ms | 94.0154 Ops/s | 95.6580 Ops/s | |
test_td3_speed[True-None] | 1.6398ms | 1.6142ms | 619.5073 Ops/s | 604.3157 Ops/s | |
test_td3_speed[True-backward] | 3.9111ms | 3.3575ms | 297.8441 Ops/s | 297.7732 Ops/s | |
test_td3_speed[reduce-overhead-None] | 51.5542ms | 26.5733ms | 37.6318 Ops/s | 38.7573 Ops/s | |
test_td3_speed[reduce-overhead-backward] | 1.5192ms | 1.4669ms | 681.7046 Ops/s | 747.1874 Ops/s | |
test_cql_speed[False-None] | 17.0126ms | 16.7007ms | 59.8776 Ops/s | 59.5573 Ops/s | |
test_cql_speed[False-backward] | 23.1367ms | 22.3170ms | 44.8089 Ops/s | 45.6744 Ops/s | |
test_cql_speed[True-None] | 3.3830ms | 3.2618ms | 306.5799 Ops/s | 303.1949 Ops/s | |
test_cql_speed[True-backward] | 5.9325ms | 5.5347ms | 180.6773 Ops/s | 179.8800 Ops/s | |
test_cql_speed[reduce-overhead-None] | 21.2515ms | 13.2401ms | 75.5282 Ops/s | 75.3880 Ops/s | |
test_cql_speed[reduce-overhead-backward] | 2.1566ms | 1.9873ms | 503.1990 Ops/s | 497.5840 Ops/s | |
test_a2c_speed[False-None] | 3.3029ms | 3.1554ms | 316.9202 Ops/s | 307.7316 Ops/s | |
test_a2c_speed[False-backward] | 6.9327ms | 6.2950ms | 158.8565 Ops/s | 162.5240 Ops/s | |
test_a2c_speed[True-None] | 1.4895ms | 1.3502ms | 740.6528 Ops/s | 739.9032 Ops/s | |
test_a2c_speed[True-backward] | 3.1411ms | 3.0770ms | 324.9884 Ops/s | 326.3143 Ops/s | |
test_a2c_speed[reduce-overhead-None] | 16.0358ms | 9.1145ms | 109.7150 Ops/s | 111.0662 Ops/s | |
test_a2c_speed[reduce-overhead-backward] | 1.7425ms | 1.6103ms | 621.0196 Ops/s | 613.2876 Ops/s | |
test_ppo_speed[False-None] | 3.8404ms | 3.6461ms | 274.2636 Ops/s | 266.0117 Ops/s | |
test_ppo_speed[False-backward] | 7.4438ms | 6.9853ms | 143.1584 Ops/s | 141.1138 Ops/s | |
test_ppo_speed[True-None] | 1.6138ms | 1.4176ms | 705.4288 Ops/s | 694.8804 Ops/s | |
test_ppo_speed[True-backward] | 3.2919ms | 3.2465ms | 308.0213 Ops/s | 317.2224 Ops/s | |
test_ppo_speed[reduce-overhead-None] | 1.0678ms | 0.9691ms | 1.0318 KOps/s | 1.0206 KOps/s | |
test_ppo_speed[reduce-overhead-backward] | 1.7133ms | 1.5619ms | 640.2301 Ops/s | 680.7583 Ops/s | |
test_reinforce_speed[False-None] | 2.3945ms | 2.2423ms | 445.9770 Ops/s | 443.1683 Ops/s | |
test_reinforce_speed[False-backward] | 3.4351ms | 3.3532ms | 298.2255 Ops/s | 304.6989 Ops/s | |
test_reinforce_speed[True-None] | 1.4156ms | 1.2889ms | 775.8497 Ops/s | 763.5794 Ops/s | |
test_reinforce_speed[True-backward] | 3.1376ms | 3.0802ms | 324.6507 Ops/s | 325.3150 Ops/s | |
test_reinforce_speed[reduce-overhead-None] | 18.3877ms | 10.1621ms | 98.4049 Ops/s | 98.4444 Ops/s | |
test_reinforce_speed[reduce-overhead-backward] | 1.7414ms | 1.6421ms | 608.9616 Ops/s | 598.8892 Ops/s | |
test_iql_speed[False-None] | 9.6543ms | 9.1525ms | 109.2593 Ops/s | 107.1089 Ops/s | |
test_iql_speed[False-backward] | 13.6068ms | 13.0779ms | 76.4650 Ops/s | 75.4491 Ops/s | |
test_iql_speed[True-None] | 2.2923ms | 2.2071ms | 453.0749 Ops/s | 437.5687 Ops/s | |
test_iql_speed[True-backward] | 5.1302ms | 4.8695ms | 205.3608 Ops/s | 203.5956 Ops/s | |
test_iql_speed[reduce-overhead-None] | 0.5067s | 13.1104ms | 76.2754 Ops/s | 89.7102 Ops/s | |
test_iql_speed[reduce-overhead-backward] | 2.1183ms | 2.0582ms | 485.8636 Ops/s | 510.3422 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 7.7678ms | 6.3344ms | 157.8672 Ops/s | 156.1614 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.5712ms | 0.3204ms | 3.1214 KOps/s | 3.3427 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.5360ms | 0.3008ms | 3.3243 KOps/s | 2.9185 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.3935ms | 6.0087ms | 166.4264 Ops/s | 164.5069 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.6791ms | 0.2659ms | 3.7602 KOps/s | 3.7006 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.6468ms | 0.3515ms | 2.8450 KOps/s | 3.5039 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 1.5548ms | 1.3181ms | 758.6534 Ops/s | 764.1108 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 1.4258ms | 1.2138ms | 823.8532 Ops/s | 743.1378 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.3987ms | 6.1923ms | 161.4922 Ops/s | 159.8897 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.3211ms | 0.4991ms | 2.0034 KOps/s | 2.0380 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.8188ms | 0.4621ms | 2.1640 KOps/s | 2.1858 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 9.8946ms | 6.1520ms | 162.5483 Ops/s | 163.1232 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 2.3726ms | 0.3580ms | 2.7930 KOps/s | 3.1119 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.6490ms | 0.3232ms | 3.0942 KOps/s | 3.7307 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.3065ms | 5.9790ms | 167.2524 Ops/s | 164.8636 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 1.0091ms | 0.3796ms | 2.6346 KOps/s | 3.4374 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.4705ms | 0.2445ms | 4.0897 KOps/s | 3.2242 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.3784ms | 6.2095ms | 161.0433 Ops/s | 159.7686 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.9989ms | 0.4427ms | 2.2589 KOps/s | 2.3357 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.7653ms | 0.4664ms | 2.1442 KOps/s | 2.4376 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 7.1069ms | 5.5094ms | 181.5073 Ops/s | 179.2200 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 9.3073ms | 2.0791ms | 480.9781 Ops/s | 434.1147 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 7.0213ms | 1.2547ms | 797.0245 Ops/s | 800.1932 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.5027s | 15.5367ms | 64.3638 Ops/s | 180.9516 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 10.1118ms | 2.0933ms | 477.7158 Ops/s | 432.5211 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 3.5071ms | 1.1718ms | 853.3765 Ops/s | 823.0818 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 10.4174ms | 5.7821ms | 172.9462 Ops/s | 30.4822 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 6.9370ms | 2.1968ms | 455.2156 Ops/s | 441.4876 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 9.4043ms | 1.4249ms | 701.8119 Ops/s | 717.9699 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] | 13.7409ms | 13.4630ms | 74.2777 Ops/s | 69.9610 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] | 18.6681ms | 16.9916ms | 58.8525 Ops/s | 56.3520 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] | 18.3269ms | 18.0832ms | 55.2998 Ops/s | 52.8635 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] | 18.7739ms | 17.2384ms | 58.0101 Ops/s | 57.3477 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] | 18.5117ms | 18.1377ms | 55.1337 Ops/s | 53.4640 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] | 19.7766ms | 18.5390ms | 53.9404 Ops/s | 51.9188 Ops/s |
) -> LLMEnv: | ||
"""Creates an LLMEnv instance from a dataloader. | ||
|
||
This method creates an LLMEnv instance and appends a DataLoadingPrimer to it, which loads data from the provided dataloader. |
From a n00b perspective, the following wording would help me when reading this:
This method creates an LLMEnv instance and appends a DataLoadingPrimer to it, which loads data from the provided dataloader. | |
This method creates an LLMEnv instance and appends a DataLoadingPrimer to it, which populates ``data_keys`` (by default ``observation_key``) with data from the provided dataloader when the environment is reset. |
unbounded vocabulary. Defaults to ``None``. | ||
primers (Composite | None, optional): The primers to use for each key in the dataloader. | ||
Defaults to ``None``. | ||
data_keys (list[NestedKey] | None, optional): The keys to use for each item in the dataloader. |
data_keys (list[NestedKey] | None, optional): The keys to use for each item in the dataloader. | |
data_keys (list[NestedKey] | None, optional): The keys to use for each item in the dataloader. If not passed, ``observation_key`` will be populated with the data. |
integers representing a sequence of tokens. | ||
The action is also a string or a tensor of integers, which is concatenated to the previous observation to form the | ||
new observation. | ||
|
Prompts to the language model can be loaded when the environment is ``reset`` if the environment is created via :meth:`~from_dataloader` |
class LLMEnv(EnvBase): | ||
"""A text generation environment. | ||
|
||
This environment is designed to work with language models, where the observation is a string or a tensor of |
It might be helpful to mention that:
- by default, this is meant to track history for a prompt, and users can append transforms to tailor this to their use case (e.g. CoT);
- users must append a transform to set the "done" condition, which would trigger the loading of the next prompt.
(Is my understanding here correct?)
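To make the behaviour in the docstring concrete, here is a toy sketch in plain Python. It is not the `LLMEnv` from this PR and uses no torchrl API: `ToyTextEnv` and its methods are hypothetical names. It only illustrates the two ideas discussed here, namely that a reset pulls the next prompt from an iterable, and that each action is concatenated to the previous observation. The "done" condition is deliberately left out, since how it is set is still an open question in this thread.

```python
from typing import Iterable, Iterator


class ToyTextEnv:
    """Hypothetical stand-in for illustration; not the LLMEnv from this PR."""

    def __init__(self, prompts: Iterable[str]) -> None:
        # Any iterable of prompts plays the role of the dataloader here.
        self._prompts: Iterator[str] = iter(prompts)
        self.observation = ""

    def reset(self) -> str:
        # On reset, the next prompt becomes the initial observation.
        self.observation = next(self._prompts)
        return self.observation

    def step(self, action: str) -> str:
        # The new observation is the previous observation followed by the action.
        self.observation = self.observation + action
        return self.observation


env = ToyTextEnv(["Hello", "Bonjour"])
first = env.reset()
after_step = env.step(", world")
second = env.reset()
```

In this sketch the history is just the growing string; a token-based variant would concatenate integer sequences instead.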
data_keys (List[NestedKey] | None, optional): The keys to use for each item in the dataloader. Defaults to None. | ||
data_specs (List[TensorSpec] | None, optional): The specs to use for each item in the dataloader. Defaults to None. | ||
example_data (Any, optional): Example data to use for initializing the primer. Defaults to None. | ||
stack_method (Callable[[Any], Any] | Literal["as_nested_tensor", "as_padded_tensor"], optional): The method to use for stacking the data. Defaults to None. |
stack_method (Callable[[Any], Any] | Literal["as_nested_tensor", "as_padded_tensor"], optional): The method to use for stacking the data. Defaults to None. | |
stack_method (Callable[[Any], Any] | Literal["as_nested_tensor", "as_padded_tensor"], optional): The method to use for stacking the data. Defaults to ``maybe_dense_stack``. |
|
||
|
||
class DataLoadingPrimer(TensorDictPrimer): | ||
"""A primer that loads data from a dataloader and stacks it into a tensordict. |
nit: As a n00b, the term "stack" is not as familiar to me (though I know it is used extensively within tensordict). Feel free to ignore this suggestion though.
"""A primer that loads data from a dataloader and stacks it into a tensordict. | |
"""A primer that loads data from a dataloader and converts it into a tensordict using ``stack_method``. |
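For readers unfamiliar with the term, "stacking" here roughly means combining several items from the dataloader into one batch. The sketch below is a plain-Python illustration of one way to do that for variable-length token sequences, by padding to a common length. It is an assumption about what a padded stacking option does in spirit, not the PR's implementation; `pad_and_stack` and `pad_value` are hypothetical names.

```python
from typing import List, Sequence


def pad_and_stack(sequences: Sequence[Sequence[int]], pad_value: int = 0) -> List[List[int]]:
    """Pad every sequence to the longest length and return them as rows of one batch."""
    max_len = max((len(seq) for seq in sequences), default=0)
    return [list(seq) + [pad_value] * (max_len - len(seq)) for seq in sequences]


# Two token sequences of different lengths become a rectangular 2 x 3 batch.
batch = pad_and_stack([[1, 2, 3], [4]])
```

A nested (ragged) representation would instead keep each sequence at its own length and skip the padding step.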
Stack from ghstack (oldest at bottom):