Skip to content

Commit

Permalink
fixup! discojs/model: expose batch generator
Browse files Browse the repository at this point in the history
  • Loading branch information
tharvik committed Jun 26, 2024
1 parent a98a0a2 commit 01313bd
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 10 deletions.
4 changes: 2 additions & 2 deletions discojs/src/models/gpt/gpt.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ describe('gpt-tfjs', function() {
}).repeat().batch(64)

const model = new GPT(config)
for await (const epoch of model.train(tokenDataset, undefined, 5))
for await (const _ of epoch);
for (let i = 0; i < 5; i++)
for await (const _ of model.train(tokenDataset, undefined));
const generation = await model.generate("Lorem ipsum dolor", tokenizer, 1)
expect(generation).equal(data) // Assert that the model completes 'Lorem ipsum dolor' with 'sit'
})
Expand Down
4 changes: 3 additions & 1 deletion docs/examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ As you can see in `training.ts` a client is represented by a `Disco` object:

```js
const disco = new Disco(task, { url, scheme: "federated" });
await disco.fit(dataset); // Start training on the dataset
for await (const round of disco.fit(dataset))
for await (const epoch of round)
for await (const batch of epoch);
await disco.close();
```

Expand Down
8 changes: 6 additions & 2 deletions docs/examples/training.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@ import { startServer } from 'server'
async function runUser (url: URL, task: Task, dataset: data.DataSplit): Promise<void> {
// Create Disco object associated with the server url, the training scheme
const disco = new Disco(task, { url, scheme: 'federated' })
for await (const _ of disco.fit(dataset)); // Start training on the dataset

// Stop training and disconnect from the remote server
// Run training on the dataset
for await (const round of disco.fit(dataset))
for await (const epoch of round)
for await (const _ of epoch);

// Disconnect from the remote server
await disco.close()
}

Expand Down
4 changes: 3 additions & 1 deletion docs/examples/wikitext.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ async function main(): Promise<void> {
const aggregator = new aggregators.MeanAggregator()
const client = new clients.federated.FederatedClient(url, task, aggregator)
const disco = new Disco(task, { scheme: 'federated', client, aggregator })
for await (const _ of disco.fit(dataset));
for await (const round of disco.fit(dataset))
for await (const epoch of round)
for await (const _ of epoch);

// Get the model and complete the prompt
if (aggregator.model === undefined) {
Expand Down
18 changes: 14 additions & 4 deletions server/tests/e2e/federated.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ import { assert, expect } from 'chai'
import type { RoundLogs, WeightsContainer } from '@epfml/discojs'
import {
Disco, client as clients, data,
aggregator as aggregators, defaultTasks
aggregator as aggregators, defaultTasks,
async_iterator
} from '@epfml/discojs'
import { NodeImageLoader, NodeTabularLoader, NodeTextLoader } from '@epfml/discojs-node'

Expand Down Expand Up @@ -68,7 +69,10 @@ describe("end-to-end federated", function () {

let logs = List<RoundLogs>()
for await (const round of disco.fit(data)) {
logs = logs.push(round)
const [roundGen, roundLogs] = async_iterator.split(round)
for await (const epoch of roundGen)
for await (const _ of epoch);
logs = logs.push(await roundLogs)
}
await disco.close()

Expand Down Expand Up @@ -99,7 +103,10 @@ describe("end-to-end federated", function () {

let logs = List<RoundLogs>()
for await (const round of disco.fit(dataSplit)) {
logs = logs.push(round)
const [roundGen, roundLogs] = async_iterator.split(round)
for await (const epoch of roundGen)
for await (const _ of epoch);
logs = logs.push(await roundLogs)
}
await disco.close()

Expand Down Expand Up @@ -132,7 +139,10 @@ describe("end-to-end federated", function () {

let logs = List<RoundLogs>()
for await (const round of disco.fit(data)) {
logs = logs.push(round)
const [roundGen, roundLogs] = async_iterator.split(round)
for await (const epoch of roundGen)
for await (const _ of epoch);
logs = logs.push(await roundLogs)
}
await disco.close()

Expand Down

0 comments on commit 01313bd

Please sign in to comment.