Skip to content

Commit

Permalink
fix(interfaces): add custom json schema to node interface so it can s…
Browse files Browse the repository at this point in the history
…erialize torch objects
  • Loading branch information
LilithWittmann committed Jul 6, 2024
1 parent 32fac94 commit 65d2125
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions causy/interfaces.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import multiprocessing
from abc import ABC, abstractmethod
from typing import List, Dict, Optional, Union, TypeVar, Generic, Any
from typing_extensions import Annotated
import logging

from pydantic import BaseModel, computed_field, Field
from pydantic import BaseModel, computed_field, Field, PlainValidator, WithJsonSchema
import torch

from causy.graph_utils import (
Expand Down Expand Up @@ -47,7 +48,16 @@ class NodeInterface(BaseModel, ABC):

name: str
id: str
values: Optional[torch.DoubleTensor] = Field(exclude=True, default=None)
values: Annotated[
Optional[torch.DoubleTensor],
WithJsonSchema(
{
"type": "array",
"items": {"type": "number"},
"description": "Node values",
}
),
] = Field(exclude=True, default=None)

class Config:
arbitrary_types_allowed = True
Expand Down

0 comments on commit 65d2125

Please sign in to comment.