Skip to content
This repository was archived by the owner on Jan 22, 2025. It is now read-only.

Commit

Permalink
refactor(experimental): refactor ScalarEnums in codecs (#2100)
Browse files Browse the repository at this point in the history
This PR refactors the definition of a "Scalar Enum" in the context of codecs which unlocks stricter types and adds support for hybrid scalar enums (both numeric and lexical).

- The `ScalarEnum` type now simply defines the "lookup object" created by the `enum` constructor.
- The `ScalarEnumFrom<T>` and `ScalarEnumTo<T>` types describe the input/output to encode/decode respectively, given a `enum` constructor.
- The Scalar Enum `Encoder`, `Decoder` and `Codec` functions now use these types to create stronger type constraints. This means, TypeScript will now fail if we provide the wrong key or value for the enum.
- The `getScalarEnumStats` helper function was also updated to provide non-duplicated keys and values for the `enum` constructor. Whilst it was not an issue for full-numeric or full-lexical enums, the previous implementation made it impossible to encode/decode hybrid enums. Tests were added for that scenario.

As an added bonus, this PR should make the closed PR #2091 obsolete.
  • Loading branch information
lorisleiva authored Feb 7, 2024
1 parent 1a0940d commit 606040b
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 75 deletions.
48 changes: 40 additions & 8 deletions packages/codecs-data-structures/src/__tests__/scalar-enum-test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,22 @@ describe('getScalarEnumCodec', () => {
LEFT = 'Left',
RIGHT = 'Right',
}
enum Hybrid {
NUMERIC,
LEXICAL = 'Lexical',
}

it('encodes scalar enums', () => {
// Bad.
expect(scalarEnum(Feedback).encode(Feedback.BAD)).toStrictEqual(b('00'));
expect(scalarEnum(Feedback).encode('BAD')).toStrictEqual(b('00'));
expect(scalarEnum(Feedback).encode('0')).toStrictEqual(b('00'));
expect(scalarEnum(Feedback).encode(0)).toStrictEqual(b('00'));
expect(scalarEnum(Feedback).read(b('00'), 0)).toStrictEqual([Feedback.BAD, 1]);
expect(scalarEnum(Feedback).read(b('ffff00'), 2)).toStrictEqual([Feedback.BAD, 3]);

// Good.
expect(scalarEnum(Feedback).encode(Feedback.GOOD)).toStrictEqual(b('01'));
expect(scalarEnum(Feedback).encode('GOOD')).toStrictEqual(b('01'));
expect(scalarEnum(Feedback).encode('1')).toStrictEqual(b('01'));
expect(scalarEnum(Feedback).encode(1)).toStrictEqual(b('01'));
expect(scalarEnum(Feedback).read(b('01'), 0)).toStrictEqual([Feedback.GOOD, 1]);
expect(scalarEnum(Feedback).read(b('ffff01'), 2)).toStrictEqual([Feedback.GOOD, 3]);
Expand All @@ -43,9 +45,10 @@ describe('getScalarEnumCodec', () => {
expect(u64Feedback.read(b('0100000000000000'), 0)).toStrictEqual([Feedback.GOOD, 8]);

// Invalid examples.
// @ts-expect-error Invalid scalar enum variant.
expect(() => scalarEnum(Feedback).encode('Missing')).toThrow(
'Invalid scalar enum variant. ' +
'Expected one of [0, 1, BAD, GOOD] or a number between 0 and 1, ' +
'Expected one of [BAD, GOOD] or a number between 0 and 1, ' +
'got "Missing".',
);
expect(() => scalarEnum(Feedback).read(new Uint8Array([2]), 0)).toThrow(
Expand All @@ -56,34 +59,35 @@ describe('getScalarEnumCodec', () => {
it('encodes lexical scalar enums', () => {
// Up.
expect(scalarEnum(Direction).encode(Direction.UP)).toStrictEqual(b('00'));
expect(scalarEnum(Direction).encode('UP')).toStrictEqual(b('00'));
expect(scalarEnum(Direction).encode('Up' as Direction)).toStrictEqual(b('00'));
expect(scalarEnum(Direction).encode('UP' as Direction)).toStrictEqual(b('00'));
expect(scalarEnum(Direction).read(b('00'), 0)).toStrictEqual([Direction.UP, 1]);
expect(scalarEnum(Direction).read(b('ffff00'), 2)).toStrictEqual([Direction.UP, 3]);

// Down.
expect(scalarEnum(Direction).encode(Direction.DOWN)).toStrictEqual(b('01'));
expect(scalarEnum(Direction).encode('DOWN')).toStrictEqual(b('01'));
expect(scalarEnum(Direction).encode('Down' as Direction)).toStrictEqual(b('01'));
expect(scalarEnum(Direction).encode('DOWN' as Direction)).toStrictEqual(b('01'));
expect(scalarEnum(Direction).read(b('01'), 0)).toStrictEqual([Direction.DOWN, 1]);
expect(scalarEnum(Direction).read(b('ffff01'), 2)).toStrictEqual([Direction.DOWN, 3]);

// Left.
expect(scalarEnum(Direction).encode(Direction.LEFT)).toStrictEqual(b('02'));
expect(scalarEnum(Direction).encode('LEFT')).toStrictEqual(b('02'));
expect(scalarEnum(Direction).encode('Left' as Direction)).toStrictEqual(b('02'));
expect(scalarEnum(Direction).encode('LEFT' as Direction)).toStrictEqual(b('02'));
expect(scalarEnum(Direction).read(b('02'), 0)).toStrictEqual([Direction.LEFT, 1]);
expect(scalarEnum(Direction).read(b('ffff02'), 2)).toStrictEqual([Direction.LEFT, 3]);

// Right.
expect(scalarEnum(Direction).encode(Direction.RIGHT)).toStrictEqual(b('03'));
expect(scalarEnum(Direction).encode('RIGHT')).toStrictEqual(b('03'));
expect(scalarEnum(Direction).encode('Right' as Direction)).toStrictEqual(b('03'));
expect(scalarEnum(Direction).encode('RIGHT' as Direction)).toStrictEqual(b('03'));
expect(scalarEnum(Direction).read(b('03'), 0)).toStrictEqual([Direction.RIGHT, 1]);
expect(scalarEnum(Direction).read(b('ffff03'), 2)).toStrictEqual([Direction.RIGHT, 3]);

// Invalid examples.
expect(() => scalarEnum(Direction).encode('Diagonal' as unknown as Direction)).toThrow(
// @ts-expect-error Invalid scalar enum variant.
expect(() => scalarEnum(Direction).encode('Diagonal')).toThrow(
'Invalid scalar enum variant. ' +
'Expected one of [UP, DOWN, LEFT, RIGHT, Up, Down, Left, Right] ' +
'or a number between 0 and 3, got "Diagonal".',
Expand All @@ -93,10 +97,38 @@ describe('getScalarEnumCodec', () => {
);
});

it('encodes hybrid scalar enums', () => {
// Numeric.
expect(scalarEnum(Hybrid).encode(Hybrid.NUMERIC)).toStrictEqual(b('00'));
expect(scalarEnum(Hybrid).encode('NUMERIC')).toStrictEqual(b('00'));
expect(scalarEnum(Hybrid).encode(0)).toStrictEqual(b('00'));
expect(scalarEnum(Hybrid).read(b('00'), 0)).toStrictEqual([Hybrid.NUMERIC, 1]);
expect(scalarEnum(Hybrid).read(b('ffff00'), 2)).toStrictEqual([Hybrid.NUMERIC, 3]);

// Lexical.
expect(scalarEnum(Hybrid).encode(Hybrid.LEXICAL)).toStrictEqual(b('01'));
expect(scalarEnum(Hybrid).encode('LEXICAL')).toStrictEqual(b('01'));
expect(scalarEnum(Hybrid).encode('Lexical' as Hybrid)).toStrictEqual(b('01'));
expect(scalarEnum(Hybrid).read(b('01'), 0)).toStrictEqual([Hybrid.LEXICAL, 1]);
expect(scalarEnum(Hybrid).read(b('ffff01'), 2)).toStrictEqual([Hybrid.LEXICAL, 3]);

// Invalid examples.
// @ts-expect-error Invalid scalar enum variant.
expect(() => scalarEnum(Hybrid).encode('Missing')).toThrow(
'Invalid scalar enum variant. ' +
'Expected one of [NUMERIC, LEXICAL, Lexical] ' +
'or a number between 0 and 1, got "Missing".',
);
expect(() => scalarEnum(Hybrid).read(new Uint8Array([2]), 0)).toThrow(
'Enum discriminator out of range. Expected a number between 0 and 1, got 2.',
);
});

it('has the right sizes', () => {
expect(scalarEnum(Empty).fixedSize).toBe(1);
expect(scalarEnum(Feedback).fixedSize).toBe(1);
expect(scalarEnum(Direction).fixedSize).toBe(1);
expect(scalarEnum(Hybrid).fixedSize).toBe(1);
expect(scalarEnum(Feedback, { size: u32() }).fixedSize).toBe(4);
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,24 @@ enum Feedback {
BAD,
GOOD,
}
type FeedbackInput = Feedback | keyof typeof Feedback;

enum Direction {
UP = 'Up',
DOWN = 'Down',
LEFT = 'Left',
RIGHT = 'Right',
}
type DirectionInput = Direction | keyof typeof Direction;

{
// [getScalarEnumEncoder]: It knows if the encoder is fixed size or variable size.
getScalarEnumEncoder(Feedback) satisfies FixedSizeEncoder<Feedback, 1>;
getScalarEnumEncoder(Direction) satisfies FixedSizeEncoder<Direction, 1>;
getScalarEnumEncoder(Feedback, { size: getU32Encoder() }) satisfies FixedSizeEncoder<Feedback, 4>;
getScalarEnumEncoder(Feedback, { size: {} as VariableSizeEncoder<number> }) satisfies VariableSizeEncoder<Feedback>;
getScalarEnumEncoder(Feedback) satisfies FixedSizeEncoder<FeedbackInput, 1>;
getScalarEnumEncoder(Direction) satisfies FixedSizeEncoder<DirectionInput, 1>;
getScalarEnumEncoder(Feedback, { size: getU32Encoder() }) satisfies FixedSizeEncoder<FeedbackInput, 4>;
getScalarEnumEncoder(Feedback, {
size: {} as VariableSizeEncoder<number>,
}) satisfies VariableSizeEncoder<FeedbackInput>;
}

{
Expand All @@ -39,8 +44,11 @@ enum Direction {

{
// [getScalarEnumCodec]: It knows if the codec is fixed size or variable size.
getScalarEnumCodec(Feedback) satisfies FixedSizeCodec<Feedback, Feedback, 1>;
getScalarEnumCodec(Direction) satisfies FixedSizeCodec<Direction, Direction, 1>;
getScalarEnumCodec(Feedback, { size: getU32Codec() }) satisfies FixedSizeCodec<Feedback, Feedback, 4>;
getScalarEnumCodec(Feedback, { size: {} as VariableSizeCodec<number> }) satisfies VariableSizeCodec<Feedback>;
getScalarEnumCodec(Feedback) satisfies FixedSizeCodec<FeedbackInput, Feedback, 1>;
getScalarEnumCodec(Direction) satisfies FixedSizeCodec<DirectionInput, Direction, 1>;
getScalarEnumCodec(Feedback, { size: getU32Codec() }) satisfies FixedSizeCodec<FeedbackInput, Feedback, 4>;
getScalarEnumCodec(Feedback, { size: {} as VariableSizeCodec<number> }) satisfies VariableSizeCodec<
FeedbackInput,
Feedback
>;
}
142 changes: 84 additions & 58 deletions packages/codecs-data-structures/src/scalar-enum.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,36 @@ import {
} from '@solana/codecs-numbers';

/**
* Defines a scalar enum as a type from its constructor.
* Defines the "lookup object" of a scalar enum.
*
* @example
* ```ts
* enum Direction { Left, Right };
* type DirectionType = ScalarEnum<Direction>;
* ```
*/
export type ScalarEnum<T> = ({ [key: number | string]: string | number | T } | number | T) & NonNullable<unknown>;
export type ScalarEnum = { [key: string]: string | number };

/**
* Returns the allowed input for a scalar enum.
*
* @example
* ```ts
* enum Direction { Left, Right };
* type DirectionInput = ScalarEnumFrom<Direction>; // "Left" | "Right" | 0 | 1
* ```
*/
export type ScalarEnumFrom<TEnum extends ScalarEnum> = keyof TEnum | TEnum[keyof TEnum];

/**
* Returns all the available variants of a scalar enum.
*
* @example
* ```ts
* enum Direction { Left, Right };
* type DirectionOutput = ScalarEnumFrom<Direction>; // 0 | 1
* ```
*/
export type ScalarEnumTo<TEnum extends ScalarEnum> = TEnum[keyof TEnum];

/** Defines the config for scalar enum codecs. */
export type ScalarEnumCodecConfig<TDiscriminator extends NumberCodec | NumberEncoder | NumberDecoder> = {
Expand All @@ -49,37 +70,37 @@ export type ScalarEnumCodecConfig<TDiscriminator extends NumberCodec | NumberEnc
* @param constructor - The constructor of the scalar enum.
* @param config - A set of config for the encoder.
*/
export function getScalarEnumEncoder<TFrom, TFromConstructor extends ScalarEnum<TFrom>>(
constructor: TFromConstructor,
): FixedSizeEncoder<TFrom, 1>;
export function getScalarEnumEncoder<TFrom, TFromConstructor extends ScalarEnum<TFrom>, TSize extends number>(
constructor: TFromConstructor,
export function getScalarEnumEncoder<TEnum extends ScalarEnum>(
constructor: TEnum,
): FixedSizeEncoder<ScalarEnumFrom<TEnum>, 1>;
export function getScalarEnumEncoder<TEnum extends ScalarEnum, TSize extends number>(
constructor: TEnum,
config: ScalarEnumCodecConfig<NumberEncoder> & { size: FixedSizeNumberEncoder<TSize> },
): FixedSizeEncoder<TFrom, TSize>;
export function getScalarEnumEncoder<TFrom, TFromConstructor extends ScalarEnum<TFrom>>(
constructor: TFromConstructor,
): FixedSizeEncoder<ScalarEnumFrom<TEnum>, TSize>;
export function getScalarEnumEncoder<TEnum extends ScalarEnum>(
constructor: TEnum,
config?: ScalarEnumCodecConfig<NumberEncoder>,
): VariableSizeEncoder<TFrom>;
export function getScalarEnumEncoder<TFrom, TFromConstructor extends ScalarEnum<TFrom>>(
constructor: TFromConstructor,
): VariableSizeEncoder<ScalarEnumFrom<TEnum>>;
export function getScalarEnumEncoder<TEnum extends ScalarEnum>(
constructor: TEnum,
config: ScalarEnumCodecConfig<NumberEncoder> = {},
): Encoder<TFrom> {
): Encoder<ScalarEnumFrom<TEnum>> {
const prefix = config.size ?? getU8Encoder();
const { minRange, maxRange, stringValues, enumKeys, enumValues } = getScalarEnumStats(constructor);
return mapEncoder(prefix, (value: TFrom): number => {
const { minRange, maxRange, allStringInputs, enumKeys, enumValues } = getScalarEnumStats(constructor);
return mapEncoder(prefix, (value: ScalarEnumFrom<TEnum>): number => {
const isInvalidNumber = typeof value === 'number' && (value < minRange || value > maxRange);
const isInvalidString = typeof value === 'string' && !stringValues.includes(value);
const isInvalidString = typeof value === 'string' && !allStringInputs.includes(value);
if (isInvalidNumber || isInvalidString) {
// TODO: Coded error.
throw new Error(
`Invalid scalar enum variant. ` +
`Expected one of [${stringValues.join(', ')}] ` +
`Expected one of [${allStringInputs.join(', ')}] ` +
`or a number between ${minRange} and ${maxRange}, ` +
`got "${value}".`,
);
}
if (typeof value === 'number') return value;
const valueIndex = enumValues.indexOf(value);
const valueIndex = enumValues.indexOf(value as string);
if (valueIndex >= 0) return valueIndex;
return enumKeys.indexOf(value as string);
});
Expand All @@ -91,24 +112,24 @@ export function getScalarEnumEncoder<TFrom, TFromConstructor extends ScalarEnum<
* @param constructor - The constructor of the scalar enum.
* @param config - A set of config for the decoder.
*/
export function getScalarEnumDecoder<TTo, TToConstructor extends ScalarEnum<TTo>>(
constructor: TToConstructor,
): FixedSizeDecoder<TTo, 1>;
export function getScalarEnumDecoder<TTo, TToConstructor extends ScalarEnum<TTo>, TSize extends number>(
constructor: TToConstructor,
export function getScalarEnumDecoder<TEnum extends ScalarEnum>(
constructor: TEnum,
): FixedSizeDecoder<ScalarEnumTo<TEnum>, 1>;
export function getScalarEnumDecoder<TEnum extends ScalarEnum, TSize extends number>(
constructor: TEnum,
config: ScalarEnumCodecConfig<NumberDecoder> & { size: FixedSizeNumberDecoder<TSize> },
): FixedSizeDecoder<TTo, TSize>;
export function getScalarEnumDecoder<TTo, TToConstructor extends ScalarEnum<TTo>>(
constructor: TToConstructor,
): FixedSizeDecoder<ScalarEnumTo<TEnum>, TSize>;
export function getScalarEnumDecoder<TEnum extends ScalarEnum>(
constructor: TEnum,
config?: ScalarEnumCodecConfig<NumberDecoder>,
): VariableSizeDecoder<TTo>;
export function getScalarEnumDecoder<TTo, TToConstructor extends ScalarEnum<TTo>>(
constructor: TToConstructor,
): VariableSizeDecoder<ScalarEnumTo<TEnum>>;
export function getScalarEnumDecoder<TEnum extends ScalarEnum>(
constructor: TEnum,
config: ScalarEnumCodecConfig<NumberDecoder> = {},
): Decoder<TTo> {
): Decoder<ScalarEnumTo<TEnum>> {
const prefix = config.size ?? getU8Decoder();
const { minRange, maxRange, isNumericEnum, enumValues } = getScalarEnumStats(constructor);
return mapDecoder(prefix, (value: number | bigint): TTo => {
const { minRange, maxRange, enumKeys } = getScalarEnumStats(constructor);
return mapDecoder(prefix, (value: number | bigint): ScalarEnumTo<TEnum> => {
const valueAsNumber = Number(value);
if (valueAsNumber < minRange || valueAsNumber > maxRange) {
// TODO: Coded error.
Expand All @@ -117,7 +138,7 @@ export function getScalarEnumDecoder<TTo, TToConstructor extends ScalarEnum<TTo>
`Expected a number between ${minRange} and ${maxRange}, got ${valueAsNumber}.`,
);
}
return (isNumericEnum ? valueAsNumber : enumValues[valueAsNumber]) as TTo;
return constructor[enumKeys[valueAsNumber]] as ScalarEnumTo<TEnum>;
});
}

Expand All @@ -127,45 +148,50 @@ export function getScalarEnumDecoder<TTo, TToConstructor extends ScalarEnum<TTo>
* @param constructor - The constructor of the scalar enum.
* @param config - A set of config for the codec.
*/
export function getScalarEnumCodec<TFrom, TFromConstructor extends ScalarEnum<TFrom>>(
constructor: TFromConstructor,
): FixedSizeCodec<TFrom, TFrom, 1>;
export function getScalarEnumCodec<TFrom, TFromConstructor extends ScalarEnum<TFrom>, TSize extends number>(
constructor: TFromConstructor,
export function getScalarEnumCodec<TEnum extends ScalarEnum>(
constructor: TEnum,
): FixedSizeCodec<ScalarEnumFrom<TEnum>, ScalarEnumTo<TEnum>, 1>;
export function getScalarEnumCodec<TEnum extends ScalarEnum, TSize extends number>(
constructor: TEnum,
config: ScalarEnumCodecConfig<NumberCodec> & { size: FixedSizeNumberCodec<TSize> },
): FixedSizeCodec<TFrom, TFrom, TSize>;
export function getScalarEnumCodec<TFrom, TFromConstructor extends ScalarEnum<TFrom>>(
constructor: TFromConstructor,
): FixedSizeCodec<ScalarEnumFrom<TEnum>, ScalarEnumTo<TEnum>, TSize>;
export function getScalarEnumCodec<TEnum extends ScalarEnum>(
constructor: TEnum,
config?: ScalarEnumCodecConfig<NumberCodec>,
): VariableSizeCodec<TFrom>;
export function getScalarEnumCodec<TFrom, TFromConstructor extends ScalarEnum<TFrom>>(
constructor: TFromConstructor,
): VariableSizeCodec<ScalarEnumFrom<TEnum>, ScalarEnumTo<TEnum>>;
export function getScalarEnumCodec<TEnum extends ScalarEnum>(
constructor: TEnum,
config: ScalarEnumCodecConfig<NumberCodec> = {},
): Codec<TFrom> {
): Codec<ScalarEnumFrom<TEnum>, ScalarEnumTo<TEnum>> {
return combineCodec(getScalarEnumEncoder(constructor, config), getScalarEnumDecoder(constructor, config));
}

function getScalarEnumStats<TFrom>(constructor: ScalarEnum<TFrom>): {
function getScalarEnumStats<TEnum extends ScalarEnum>(
constructor: TEnum,
): {
allStringInputs: string[];
enumKeys: string[];
enumValues: TFrom[];
isNumericEnum: boolean;
enumValues: (string | number)[];
minRange: number;
maxRange: number;
stringValues: string[];
} {
const enumKeys = Object.keys(constructor);
const enumValues = Object.values(constructor);
const isNumericEnum = enumValues.some(v => typeof v === 'number');
const numericValues = Object.values(constructor).filter(v => typeof v === 'number') as number[];
const deduplicatedConstructor = Object.fromEntries(
Object.entries(constructor).slice(numericValues.length),
) as Record<string, string | number>;
const enumKeys = Object.keys(deduplicatedConstructor);
const enumValues = Object.values(deduplicatedConstructor);
const minRange = 0;
const maxRange = isNumericEnum ? enumValues.length / 2 - 1 : enumValues.length - 1;
const stringValues: string[] = isNumericEnum ? [...enumKeys] : [...new Set([...enumKeys, ...enumValues])];
const maxRange = enumValues.length - 1;
const allStringInputs: string[] = [
...new Set([...enumKeys, ...enumValues.filter((v): v is string => typeof v === 'string')]),
];

return {
allStringInputs,
enumKeys,
enumValues,
isNumericEnum,
maxRange,
minRange,
stringValues,
};
}
2 changes: 1 addition & 1 deletion packages/codecs-data-structures/tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
"extends": "tsconfig/base.json",
"include": ["src"],
"compilerOptions": {
"lib": ["DOM", "ES2017", "ES2022.Error"]
"lib": ["DOM", "ES2019", "ES2022.Error"]
}
}

0 comments on commit 606040b

Please sign in to comment.