[Zerobubble] Merge Main. #6107

Merged 190 commits on Nov 5, 2024

Commits
f5a52e1
fp8 operators for compressed communication
BurkeHulk Jul 1, 2024
6991819
Merge branch 'hpcaitech:main' into feature/fp8_comm
BurkeHulk Jul 4, 2024
e17f835
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 4, 2024
dbfa7d3
fix typo
GuangyaoZhang Jul 10, 2024
1e19594
fix scaling algorithm in FP8 casting
BurkeHulk Jul 12, 2024
e881901
support fp8 communication in pipeline parallelism
BurkeHulk Jul 12, 2024
6601874
add fp8_communication flag in the script
BurkeHulk Jul 12, 2024
1f1b856
Merge remote-tracking branch 'origin/feature/fp8_comm' into feature/f…
BurkeHulk Jul 12, 2024
51f916b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2024
9470701
Merge pull request #5885 from BurkeHulk/feature/fp8_comm
BurkeHulk Jul 16, 2024
457a0de
shardformer fp8
GuangyaoZhang Jul 8, 2024
5a310b9
fix rebase
GuangyaoZhang Jul 17, 2024
6a20f07
remove all to all
GuangyaoZhang Jul 17, 2024
d0bdb51
Merge pull request #5899 from BurkeHulk/SP_fp8
GuangyaoZhang Jul 18, 2024
5b969fd
fix shardformer fp8 communication training degradation
GuangyaoZhang Jul 18, 2024
62661cd
Merge pull request #5921 from BurkeHulk/fp8_fix
GuangyaoZhang Jul 18, 2024
5fd0592
[fp8] support all-gather flat tensor (#5932)
ver217 Jul 24, 2024
ae486ce
[fp8] add fp8 comm for low level zero
ver217 Aug 2, 2024
91e596d
[test] add zero fp8 test case
ver217 Aug 2, 2024
c297e21
Merge pull request #5961 from ver217/feature/zeor-fp8
BurkeHulk Aug 2, 2024
53cb960
[Feature] llama shardformer fp8 support (#5938)
GuangyaoZhang Aug 5, 2024
0c10afd
[FP8] rebase main (#5963)
flybird11111 Aug 6, 2024
afb26de
[fp8]support all2all fp8 (#5953)
flybird11111 Aug 6, 2024
76ea164
[fp8] add fp8 linear (#5967)
ver217 Aug 7, 2024
ccabcf6
[fp8] support fp8 amp for hybrid parallel plugin (#5975)
ver217 Aug 7, 2024
7739629
fix (#5976)
flybird11111 Aug 7, 2024
b480eec
[Feature]: support FP8 communication in DDP, FSDP, Gemini (#5928)
BurkeHulk Aug 8, 2024
4b9bec8
[test ci]Feature/fp8 comm (#5981)
flybird11111 Aug 8, 2024
8241c0c
[fp8] support gemini plugin (#5978)
ver217 Aug 9, 2024
e4aadee
[fp8] use torch compile (torch >= 2.3.0) (#5979)
botbw Aug 9, 2024
f1a3a32
[fp8]Moe support fp8 communication (#5977)
flybird11111 Aug 9, 2024
b2483c8
[fp8] support hybrid parallel plugin (#5982)
wangbluo Aug 12, 2024
0978080
[fp8] refactor fp8 linear with compile (#5993)
ver217 Aug 13, 2024
597b206
[fp8] support asynchronous FP8 communication (#5997)
flybird11111 Aug 14, 2024
88fa096
[fp8] update torch.compile for linear_fp8 to >= 2.4.0 (#6004)
botbw Aug 15, 2024
1a2e90d
[fp8] linear perf enhancement
botbw Aug 15, 2024
20722a8
[fp8]update reduce-scatter test (#6002)
flybird11111 Aug 15, 2024
3f09a61
[fp8] add use_fp8 option for MoeHybridParallelPlugin (#6009)
wangbluo Aug 16, 2024
0a51319
[fp8] zero support fp8 linear. (#6006)
flybird11111 Aug 16, 2024
4cf79fa
merge
wangbluo Aug 17, 2024
81272e9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 17, 2024
02636c5
fix the merge
wangbluo Aug 19, 2024
52289e4
Merge branch 'fp8_merge' of https://github.com/wangbluo/ColossalAI in…
wangbluo Aug 19, 2024
1a5847e
fix the merge
wangbluo Aug 19, 2024
3353042
fix the merge
wangbluo Aug 19, 2024
64aad96
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 19, 2024
4c82bfc
fix the merge
wangbluo Aug 19, 2024
0d8e82a
Merge branch 'fp8_merge' of https://github.com/wangbluo/ColossalAI in…
wangbluo Aug 19, 2024
12b4401
fix
wangbluo Aug 19, 2024
2eb3683
fix
wangbluo Aug 19, 2024
88b3f06
fix the merge
wangbluo Aug 19, 2024
1f703e0
fix
wangbluo Aug 19, 2024
5382311
fix
wangbluo Aug 20, 2024
f7acfa1
fix
wangbluo Aug 20, 2024
2ee6235
fix
wangbluo Aug 20, 2024
2e4cbe3
fix
wangbluo Aug 20, 2024
2d362ac
fix merge
wangbluo Aug 20, 2024
eb5ba40
fix the merge
wangbluo Aug 21, 2024
193030f
fix
wangbluo Aug 21, 2024
6aface9
fix
wangbluo Aug 21, 2024
698c8b9
fix
wangbluo Aug 21, 2024
8b8e282
fix
wangbluo Aug 21, 2024
eea37da
[fp8] Merge feature/fp8_comm to main branch of Colossalai (#6016)
wangbluo Aug 22, 2024
d77e66a
Merge pull request #6023 from wangbluo/fp8_merge
wangbluo Aug 22, 2024
971b16a
fix
wangbluo Aug 22, 2024
a292554
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 22, 2024
afe845f
Merge pull request #6024 from wangbluo/fix_merge
wangbluo Aug 22, 2024
caab4a3
Merge branch 'main' into feature/fp8_comm
ver217 Aug 22, 2024
0bc9a87
Update train_dpo.py
flybird11111 Aug 23, 2024
3b0df30
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2024
9e76764
Update low_level_zero_plugin.py
flybird11111 Aug 23, 2024
0bf46c5
Merge pull request #6029 from hpcaitech/flybird11111-patch-1
wangbluo Aug 23, 2024
dae3999
fix
wangbluo Aug 26, 2024
80d24ae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 26, 2024
4a6f31e
Merge pull request #6033 from wangbluo/fix
wangbluo Aug 26, 2024
17904cb
Merge pull request #6012 from hpcaitech/feature/fp8_comm
ver217 Aug 27, 2024
d383449
[CI] Remove triton version for compatibility bug; update req torch >=…
Edenzzzz Aug 27, 2024
cc1b0ef
[plugin] hotfix zero plugin (#6036)
ver217 Aug 28, 2024
4a68efb
[Colossal-LLaMA] Refactor latest APIs (#6030)
TongLi3701 Aug 28, 2024
0d3a85d
add fused norm (#6038)
TongLi3701 Aug 28, 2024
e96a076
[FP8] unsqueeze scale to make it compatible with torch.compile (#6040)
GuangyaoZhang Aug 29, 2024
e9032fb
[colossalai/checkpoint_io/...] fix bug in load_state_dict_into_model;…
flymin Sep 2, 2024
c650a90
[Hotfix] Remove deprecated install (#6042)
TongLi3701 Sep 3, 2024
c3b5caf
[fp8] optimize all-gather (#6043)
ver217 Sep 3, 2024
26e5539
[fp8] fix linear hook (#6046)
ver217 Sep 3, 2024
5ce6dd7
[fp8] disable all_to_all_fp8 in intranode (#6045)
BurkeHulk Sep 9, 2024
b3db105
[release] update version (#6041)
ver217 Sep 10, 2024
8fd25d6
[Feature] Split cross-entropy computation in SP (#5959)
Edenzzzz Sep 10, 2024
c54c4fc
[hotfix] moe hybrid parallelism benchmark & follow-up fix (#6048)
botbw Sep 10, 2024
13946c4
[fp8] hotfix backward hook (#6053)
ver217 Sep 11, 2024
a35a078
[doc] update sp doc (#6055)
flybird11111 Sep 11, 2024
fdd84b9
fix the sp
wangbluo Sep 13, 2024
216d54e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 13, 2024
0a01e2a
fix the attn
wangbluo Sep 13, 2024
683179c
fix
wangbluo Sep 13, 2024
6eb8832
fix
wangbluo Sep 13, 2024
f393867
fix
wangbluo Sep 13, 2024
dc03217
fix
wangbluo Sep 13, 2024
696fced
[fp8] fix missing fp8_comm flag in mixtral (#6057)
botbw Sep 13, 2024
0b14a55
fix
wangbluo Sep 13, 2024
0ad3129
fix
wangbluo Sep 13, 2024
b582319
fix
wangbluo Sep 13, 2024
f20b066
[fp8] Disable all_gather intranode. Disable Redundant all_gather fp8 …
GuangyaoZhang Sep 14, 2024
bdb125f
[doc] FP8 training and communication document (#6050)
GuangyaoZhang Sep 14, 2024
827ef3e
fix
wangbluo Sep 14, 2024
37e3523
Merge pull request #6061 from wangbluo/sp_fix
wangbluo Sep 14, 2024
10e4f7d
fix
wangbluo Sep 16, 2024
63314ce
Merge pull request #6064 from wangbluo/fix_attn
wangbluo Sep 18, 2024
4fa6b95
[moe] add parallel strategy for shared_expert && fix test for deepsee…
botbw Sep 18, 2024
f9546ba
[ColossalEval] support for vllm (#6056)
Camille7777 Sep 18, 2024
dabc2e7
[release] update version (#6062)
ver217 Sep 19, 2024
cbaa104
release FP8 news (#6068)
binmakeswell Sep 25, 2024
cfd9eda
fix the ring attn
wangbluo Sep 25, 2024
65c8297
fix the attn
wangbluo Sep 25, 2024
6fb1322
fix
wangbluo Sep 25, 2024
91ed32c
fix
wangbluo Sep 25, 2024
6705dad
fix
wangbluo Sep 25, 2024
f4daf04
add funding news (#6072)
binmakeswell Sep 26, 2024
3fab921
fix
wangbluo Sep 26, 2024
3532f77
fix
wangbluo Oct 9, 2024
3f5bec8
[feat] support zbv in mixtral benchmark;
duanjunwen Oct 9, 2024
b635dd0
fix
wangbluo Oct 9, 2024
9ee80fc
[fix] MixtralForCausalLMPolicy get_held_layer support zbv;
duanjunwen Oct 10, 2024
72b507a
[feat] update MixtralPipelineForwards --> mixtral_model_forward; supp…
duanjunwen Oct 10, 2024
646b3c5
[shardformer] fix linear 1d row and support uneven splits for fused q…
ver217 Oct 10, 2024
e234dfa
[feat] support MixtralPipelineForwards--> mixtral_for_causal_lm_forwa…
duanjunwen Oct 10, 2024
f98384a
fix
wangbluo Oct 10, 2024
5ecc27e
fix
wangbluo Oct 10, 2024
6b2c506
Update README.md (#6087)
supercooledith Oct 10, 2024
efe3042
fix
wangbluo Oct 10, 2024
dc2cdaf
[shardformer] optimize seq parallelism (#6086)
ver217 Oct 11, 2024
0002ae5
fix
wangbluo Oct 11, 2024
1507a75
fix
wangbluo Oct 11, 2024
0ca16d5
[fix] fix llama, mixtral benchmark zbv loss none bug; update mixtral …
duanjunwen Oct 11, 2024
4e0e99b
fix the test
wangbluo Oct 11, 2024
703bb5c
fix the test
wangbluo Oct 11, 2024
4c8e85e
[Coati] Train DPO using PP (#6054)
TongLi3701 Oct 11, 2024
e1e86f9
fix
wangbluo Oct 14, 2024
d891e50
fix
wangbluo Oct 14, 2024
cfade4c
[feat] Linear1D_COL/ROW support zbv WeightGradStore;
duanjunwen Oct 14, 2024
a11b4b5
[feat] support use_zbv in llama, mixtral modeling; only replace Linea…
duanjunwen Oct 14, 2024
abd4551
[fix] fix test case; moe error in second iter
duanjunwen Oct 14, 2024
160e9a4
[feat]EPMixtralSparseMoeBlock (op in MOE) support zbv;
duanjunwen Oct 14, 2024
23199e3
fix
wangbluo Oct 14, 2024
3201377
fix
wangbluo Oct 14, 2024
fe9208f
fix
wangbluo Oct 14, 2024
8ff7d0c
fix
wangbluo Oct 14, 2024
3dc08c8
fix
wangbluo Oct 15, 2024
6be9862
fix
wangbluo Oct 15, 2024
fd92789
fix
wangbluo Oct 15, 2024
bc7eead
fix
wangbluo Oct 15, 2024
9912cc8
[fix] fix bwd b; now bwd w only for Layer replaced by Linear1D_Col/Ro…
duanjunwen Oct 15, 2024
52dcc73
Merge branch 'feature/zerobubble' of github.com:hpcaitech/ColossalAI …
duanjunwen Oct 15, 2024
83cf2f8
fix
wangbluo Oct 15, 2024
dcd41d0
Merge pull request #6071 from wangbluo/ring_attention
wangbluo Oct 15, 2024
90939b7
[fix] debug zbv llama test;
duanjunwen Oct 15, 2024
62c13e7
[Ring Attention] Improve comments (#6085)
Edenzzzz Oct 16, 2024
e76308c
[fix] rm use_zbv flag in Shardconfig; rm debug info;
duanjunwen Oct 16, 2024
705b18e
[fix] add & fix llama test
duanjunwen Oct 16, 2024
cd61353
[pipeline] hotfix backward for multiple outputs (#6090)
ver217 Oct 16, 2024
2bcd0b6
[ckpt] add safetensors util
botbw Oct 14, 2024
3b1d7d1
[chore] refactor
botbw Oct 14, 2024
5ddad48
[fp8] add fallback and make compile option configurable (#6092)
ver217 Oct 18, 2024
58d8b8a
[misc] fit torch api upgradation and remove legecy import (#6093)
ver217 Oct 18, 2024
19baab5
[release] update version (#6094)
ver217 Oct 21, 2024
b10339d
fix lora ckpt save format (ColoTensor to Tensor)
BurkeHulk Oct 21, 2024
6d6cafa
pre-commit fix
BurkeHulk Oct 21, 2024
dee63cc
Merge pull request #6096 from BurkeHulk/hotfix/lora_ckpt
BurkeHulk Oct 21, 2024
80a8ca9
[extension] hotfix compile check (#6099)
ver217 Oct 24, 2024
4294ae8
[doc] sora solution news (#6100)
binmakeswell Oct 24, 2024
2eca112
[feat] support meta cache, meta_grad_send, meta_tensor_send; fix runt…
duanjunwen Oct 24, 2024
89a9a60
[MCTS] Add self-refined MCTS (#6098)
TongLi3701 Oct 24, 2024
d0ec221
[fix] fix fail case test_shard_llama
duanjunwen Oct 25, 2024
cc0dfdd
[fix] fix test_shard_llama
duanjunwen Oct 25, 2024
03fa79a
[fix] fix llama modeling policy;
duanjunwen Oct 25, 2024
6377aa0
[fix] fix test_shard_llama ci;
duanjunwen Oct 28, 2024
5aee426
[fix] fix test zerobubble
duanjunwen Oct 28, 2024
fafe049
[fix] fix handle name; rm useless comments;
duanjunwen Oct 29, 2024
fa3ccda
[fix] fix send recv signature;
duanjunwen Oct 29, 2024
982e4ee
[fix] fix comment in llama & benchmark
duanjunwen Oct 29, 2024
d2e05a9
[feat] support no tensor parallel Linear in shardformer; Add test for…
duanjunwen Oct 30, 2024
5f09243
[fix] fix linear (no tp) ops func name;
duanjunwen Oct 31, 2024
c2e8f61
[checkpointio] fix hybrid plugin model save (#6106)
ver217 Oct 31, 2024
2f583c1
[pre-commit.ci] pre-commit autoupdate (#6078)
pre-commit-ci[bot] Oct 31, 2024
1d328ff
Merge branch 'main' into dev/zero_bubble
duanjunwen Nov 1, 2024
c82c75a
Merge branch 'feature/zerobubble' of github.com:hpcaitech/ColossalAI …
duanjunwen Nov 1, 2024
3b5c314
[fix] fix fp8 args in HybridParallel
duanjunwen Nov 1, 2024
5b5fbcf
[fix] fix hybridparall use_fp8 config
duanjunwen Nov 1, 2024
0218e67
[fix] fix use_fp8 flag
duanjunwen Nov 1, 2024
8e40087
[fix] fix model zoo init
duanjunwen Nov 1, 2024
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -15,21 +15,21 @@ repos:
         args: ["--profile", "black"] # avoid conflict with black

   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 24.8.0
+    rev: 24.10.0
     hooks:
       - id: black
         name: black formatter
         args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310']

   - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v18.1.8
+    rev: v19.1.2
     hooks:
       - id: clang-format
         name: clang formatter
         types_or: [c++, c]

   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.6.0
+    rev: v5.0.0
     hooks:
       - id: check-yaml
       - id: check-merge-conflict
30 changes: 25 additions & 5 deletions README.md
@@ -25,16 +25,36 @@

</div>

## GPU Cloud HPC-AI.COM Coming!!

For a limited time, you can access an H100 Server for just $1! This is your chance to leverage premium GPU power at an unbeatable price.
Plus, when you refer a friend, you’ll receive 20% cashback or compute credits equal to 100% of their top-up!

Our platform offers on-demand premium compute, ensuring safe, permanent data storage even after stopping your instance.
Don’t miss this incredible opportunity to accelerate your AI projects!

Unlock premium GPUs and register now at [HPC-AI.COM](https://hpc-ai.com) to receive $10!

Special Bonuses:

* Top up $1,000 and receive 300 credits
* Top up $500 and receive 100 credits

<div align="center">
<a href="https://youtu.be/ilMQpU71ddI?si=J4JSPzZ03ycYmlki">
<img src="https://github.com/hpcaitech/public_assets/blob/main/colossalai/img/HPCAICOM241010.jpg" width="700" />
</a>
</div>


## Latest News
* [2024/10] [How to build a low-cost Sora-like app? Solutions for you](https://company.hpc-ai.com/blog/how-to-build-a-low-cost-sora-like-app-solutions-for-you)
* [2024/09] [Singapore Startup HPC-AI Tech Secures 50 Million USD in Series A Funding to Build the Video Generation AI Model and GPU Platform](https://company.hpc-ai.com/blog/singapore-startup-hpc-ai-tech-secures-50-million-usd-in-series-a-funding-to-build-the-video-generation-ai-model-and-gpu-platform)
* [2024/09] [Reducing AI Large Model Training Costs by 30% Requires Just a Single Line of Code From FP8 Mixed Precision Training Upgrades](https://company.hpc-ai.com/blog/reducing-ai-large-model-training-costs-by-30-requires-just-a-single-line-of-code-from-fp8-mixed-precision-training-upgrades)
* [2024/06] [Open-Sora Continues Open Source: Generate Any 16-Second 720p HD Video with One Click, Model Weights Ready to Use](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use)
* [2024/05] [Large AI Models Inference Speed Doubled, Colossal-Inference Open Source Release](https://hpc-ai.com/blog/colossal-inference)
* [2024/04] [Open-Sora Unveils Major Upgrade: Embracing Open Source with Single-Shot 16-Second Video Generation and 720p Resolution](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source)
* [2024/04] [Most cost-effective solutions for inference, fine-tuning and pretraining, tailored to LLaMA3 series](https://hpc-ai.com/blog/most-cost-effective-solutions-for-inference-fine-tuning-and-pretraining-tailored-to-llama3-series)
* [2024/03] [314 Billion Parameter Grok-1 Inference Accelerated by 3.8x, Efficient and Easy-to-Use PyTorch+HuggingFace version is Here](https://hpc-ai.com/blog/314-billion-parameter-grok-1-inference-accelerated-by-3.8x-efficient-and-easy-to-use-pytorchhuggingface-version-is-here)
* [2024/03] [Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models](https://hpc-ai.com/blog/open-sora-v1.0)
* [2024/03] [Open-Sora: Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora)
* [2024/01] [Inference Performance Improved by 46%, Open Source Solution Breaks the Length Limit of LLM for Multi-Round Conversations](https://hpc-ai.com/blog/Colossal-AI-SwiftInfer)
* [2023/07] [HPC-AI Tech Raises 22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth)

## Table of Contents
<ul>
21 changes: 18 additions & 3 deletions applications/ColossalChat/README.md
@@ -27,11 +27,11 @@
 - [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization-simpo)
 - [Alternative Option For RLHF: ORPO](#alternative-option-for-rlhf-odds-ratio-preference-optimization-orpo)
 - [Alternative Option For RLHF: KTO](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
+- [O1 Journey](#o1-journey)
+- [Inference with Self-refined MCTS](#inference-with-self-refined-mcts)
 - [FAQ](#faq)
 - [How to save/load checkpoint](#faq)
 - [How to train with limited resources](#faq)
-- [The Plan](#the-plan)
-- [Real-time progress](#real-time-progress)
 - [Invitation to open-source contribution](#invitation-to-open-source-contribution)
 - [Quick Preview](#quick-preview)
 - [Authors](#authors)
@@ -272,7 +272,7 @@ Odds Ratio Preference Optimization (ORPO) from this [paper](https://arxiv.org/pd
 ## Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)
 We support the method introduced in the paper [KTO: Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/pdf/2402.01306) (KTO), an alignment method that directly maximizes the "human utility" of generation results. Read this [README](./examples/README.md) for more information.

-### Inference Quantization and Serving - After Training
+## Inference Quantization and Serving - After Training

 We provide an online inference server and a benchmark. We aim to run inference on a single GPU, so quantization is essential when using large models.

@@ -281,6 +281,21 @@ We support 8-bit quantization (RTN), 4-bit quantization (GPTQ), and FP16 inference.
 Online inference server scripts can help you deploy your own services.
 For more details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference).

+## O1 Journey
+### Inference with Self-refined MCTS
+We provide an implementation of the MCT Self-Refine (MCTSr) algorithm, an innovative integration of Large Language Models with Monte Carlo Tree Search.
+To run inference with MCTS, simply use the following script.
+```python
+from coati.reasoner.guided_search.mcts import MCTS
+from coati.reasoner.guided_search.prompt_store.qwen import Qwen32B_prompt_CFG
+
+problem = "How Many R in 'Strawberry'"
+
+search_tree = MCTS(problem=problem, max_simulations=8, cfg=Qwen32B_prompt_CFG)
+answer = search_tree.simulate()
+print(answer)
+```
+
 ## Coati7B examples

 ### Generation
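
A note on the KTO paragraph in the diff above: the objective is compact enough to state in a few lines. Below is a minimal sketch of the KTO loss from the cited paper, assuming summed per-token log-probabilities under the policy and a frozen reference model; the function name and the batch-level estimate of the reference point `z0` are illustrative, not the repository's implementation.

```python
import torch


def kto_loss_sketch(policy_logps, ref_logps, is_desirable, beta=0.1, lambda_d=1.0, lambda_u=1.0):
    """Simplified KTO objective: push desirable completions above a reference
    point z0 and undesirable ones below it, weighted by lambda_d / lambda_u."""
    logratio = policy_logps - ref_logps  # implied reward, up to the beta factor
    # Crude batch estimate of the reference point (the paper estimates a KL
    # term over mismatched pairs); detached so it acts as a fixed baseline.
    z0 = logratio.mean().clamp(min=0).detach()
    v = torch.where(
        is_desirable,
        torch.sigmoid(beta * (logratio - z0)),
        torch.sigmoid(beta * (z0 - logratio)),
    )
    w = torch.where(is_desirable, lambda_d * torch.ones_like(v), lambda_u * torch.ones_like(v))
    return (w * (1.0 - v)).mean()


# Toy check: two completions, the first labeled desirable.
logp_pi = torch.tensor([-12.0, -15.0])
logp_ref = torch.tensor([-13.0, -14.0])
print(kto_loss_sketch(logp_pi, logp_ref, torch.tensor([True, False])))
```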
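Similarly, for the quantization options the inference section lists, RTN ("round-to-nearest") is the simplest of the three: weights are scaled per output channel and rounded to the nearest integer grid point. A minimal sketch for illustration only, not the code behind the inference server:

```python
import torch


def rtn_quantize_per_channel(w: torch.Tensor, n_bits: int = 8):
    """Round-to-nearest weight quantization with one scale per output channel."""
    qmax = 2 ** (n_bits - 1) - 1  # 127 for int8
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / qmax
    q = torch.round(w / scale).clamp(-qmax - 1, qmax).to(torch.int8)
    return q, scale


w = torch.randn(4096, 4096)
q, scale = rtn_quantize_per_channel(w)
w_hat = q.float() * scale  # dequantize
print((w - w_hat).abs().max())  # reconstruction error stays small
```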
5 changes: 3 additions & 2 deletions applications/ColossalChat/coati/models/loss.py
@@ -153,10 +153,11 @@ def forward(
         else:
             # If no reference model is provided
             ref_logratios = 0.0
+
         pi_logratios = logprob_actor_chosen.sum(-1) - logprob_actor_reject.sum(-1)
         logits = pi_logratios - ref_logratios - self.gamma / self.beta
         losses = -torch.nn.functional.logsigmoid(self.beta * logits)
-
+        loss = losses.mean()
         # Calculate rewards for logging
         if logprob_ref_chosen is not None:
             chosen_rewards = self.beta * (logprob_actor_chosen.sum(-1) - logprob_ref_chosen.sum(-1)).detach()
@@ -167,7 +167,7 @@
         else:
             rejected_rewards = self.beta * logprob_actor_reject.sum(-1).detach()

-        return losses, chosen_rewards, rejected_rewards
+        return loss, chosen_rewards, rejected_rewards


 class LogSigLoss(nn.Module):
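
For context on the change above: `forward` now reduces the per-sample `losses` to a scalar before returning, and the surrounding objective is DPO with a target-reward margin (standard DPO when `gamma` is zero; `ref_logratios` is zero when no reference model is given). A condensed, hypothetical restatement of that computation, assuming summed log-ratios as inputs:

```python
import torch
import torch.nn.functional as F


def margin_dpo_loss(pi_logratios, ref_logratios, beta=0.1, gamma=0.0):
    """DPO with a target-reward margin: gamma = 0 recovers standard DPO."""
    logits = pi_logratios - ref_logratios - gamma / beta
    losses = -F.logsigmoid(beta * logits)  # per-sample losses, shape (batch,)
    return losses.mean()  # scalar, matching the new return value


# With no reference model, the code above uses ref_logratios = 0.0.
print(margin_dpo_loss(torch.randn(8), torch.zeros(8), gamma=1.4))
```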
4 changes: 2 additions & 2 deletions applications/ColossalChat/coati/models/utils.py
@@ -50,8 +50,8 @@ def _log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.
         torch.Tensor: The log probabilities corresponding to the labels.
     """
     log_probs = F.log_softmax(logits, dim=-1)
-    log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
-    return log_probs_labels.squeeze(-1)
+    per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
+    return per_label_logps.squeeze(-1)


 def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
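
The rename above is purely cosmetic, but the gather pattern it clarifies is worth spelling out. A tiny self-contained example (the shapes are illustrative):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 4, 10)           # (batch, seq_len, vocab)
labels = torch.randint(0, 10, (2, 4))    # (batch, seq_len)

log_probs = F.log_softmax(logits, dim=-1)
# Pick out, at every position, the log-probability of the observed label.
per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
print(per_label_logps.shape)  # torch.Size([2, 4])
```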
26 changes: 26 additions & 0 deletions applications/ColossalChat/coati/reasoner/guided_search/llm.py
@@ -0,0 +1,26 @@
+import openai
+from openai.types.chat.chat_completion import ChatCompletion
+from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
+
+API_KEY = "Dummy API Key"
+
+
+def get_client(base_url: str | None = None) -> openai.Client:
+    return openai.Client(api_key=API_KEY, base_url=base_url)
+
+
+def chat_completion(
+    messages: list[ChatCompletionMessageParam],
+    model: str,
+    base_url: str | None = None,
+    temperature: float = 0.8,
+    **kwargs,
+) -> ChatCompletion:
+    client = get_client(base_url)
+    response = client.chat.completions.create(
+        model=model,
+        messages=messages,
+        temperature=temperature,
+        **kwargs,
+    )
+    return response
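
The new helper above targets any OpenAI-compatible endpoint (the dummy API key suggests a self-hosted server). A hypothetical usage sketch; the base URL and model name below are assumptions about a local deployment, not values from this PR:

```python
from coati.reasoner.guided_search.llm import chat_completion

# Assumed deployment: an OpenAI-compatible server (e.g. vLLM) on localhost.
response = chat_completion(
    messages=[{"role": "user", "content": "How many r's are in 'strawberry'?"}],
    model="Qwen2.5-32B-Instruct",
    base_url="http://localhost:8000/v1",
)
print(response.choices[0].message.content)
```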