Skip to content

Commit

Permalink
Fix multiprocessing issues in utilblocks
Browse files Browse the repository at this point in the history
Address the following issue with using num_proc>1 with Dataset.map():

```
File "/usr/lib64/python3.11/pickle.py", line 578, in save
    rv = reduce(self.proto)
         ^^^^^^^^^^^^^^^^^^
TypeError: cannot pickle 'SSLContext' object
```

The entire block object is being serialized and sent to the
multiprocessing worker. And now that this includes PipelineContext,
which includes the OpenAI client object, which includes SSLContext,
we hit a known issue: uqfoundation/dill#308

Signed-off-by: Mark McLoughlin <[email protected]>
  • Loading branch information
markmc committed Jul 11, 2024
1 parent 7cfbaa9 commit 9d92548
Showing 1 changed file with 36 additions and 17 deletions.
53 changes: 36 additions & 17 deletions src/instructlab/sdg/utilblocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,18 @@ def __init__(
self.column_name = column_name
self.num_procs = batch_kwargs.get("num_procs", 8)

def _generate(self, sample) -> dict:
sample = {**sample, **self.configs[sample[self.column_name]]}
return sample
# Using a static method to avoid serializing self when using multiprocessing
@staticmethod
def _map_populate(samples, configs, column_name, num_proc=1):
def populate(sample):
return {**sample, **configs[sample[column_name]]}

return samples.map(populate, num_proc)

def generate(self, samples) -> Dataset:
samples = samples.map(self._generate, num_proc=self.num_procs)
return samples
return self._map_populate_samples(
samples, self.configs, self.column_name, self.num_procs
)


class SelectorBlock(Block):
Expand All @@ -44,13 +49,23 @@ def __init__(self, ctx, choice_map, choice_col, output_col, **batch_kwargs) -> N
self.output_col = output_col
self.num_procs = batch_kwargs.get("num_procs", 8)

def _generate(self, sample) -> dict:
sample[self.output_col] = sample[self.choice_map[sample[self.choice_col]]]
return sample
# Using a static method to avoid serializing self when using multiprocessing
@staticmethod
def _map_select_choice(samples, choice_map, choice_col, output_col, num_proc=1):
def select_choice(sample) -> dict:
sample[output_col] = sample[choice_map[sample[choice_col]]]
return sample

return samples.map(select_choice, num_proc)

def generate(self, samples: Dataset) -> Dataset:
samples = samples.map(self._generate, num_proc=self.num_procs)
return samples
return self._map_select_choice(
samples,
self.choice_map,
self.choice_col,
self.output_col,
self.num_procs,
)


class CombineColumnsBlock(Block):
Expand All @@ -63,12 +78,16 @@ def __init__(
self.separator = separator
self.num_procs = batch_kwargs.get("num_procs", 8)

def _generate(self, sample) -> dict:
sample[self.output_col] = self.separator.join(
[sample[col] for col in self.columns]
)
return sample
# Using a static method to avoid serializing self when using multiprocessing
@staticmethod
def _map_combine(samples, columns, output_col, separator, num_proc=1):
def combine(sample):
sample[output_col] = separator.join([sample[col] for col in columns])
return sample

return samples.map(combine, num_proc=num_proc)

def generate(self, samples: Dataset) -> Dataset:
samples = samples.map(self._generate, num_proc=self.num_procs)
return samples
return self._map_combine(
samples, self.columns, self.output_col, self.separator, self.num_procs
)

0 comments on commit 9d92548

Please sign in to comment.