Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 45 additions & 19 deletions aleph_message/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
from pydantic import BaseModel, Extra, Field, validator
from typing_extensions import TypeAlias

from ..utils import dump_content
from .abstract import BaseContent
from .base import Chain, HashType, MessageType
from .execution.base import MachineType, Payment, PaymentType # noqa
from .execution.instance import InstanceContent
from .execution.program import ProgramContent
from .execution.base import PaymentType, MachineType, Payment # noqa
from .item_hash import ItemHash, ItemType


Expand Down Expand Up @@ -126,7 +127,7 @@ class ForgetContent(BaseContent):
"""Content of a FORGET message"""

hashes: List[ItemHash]
aggregates: List[ItemHash] = Field(default_factory=list)
aggregates: Optional[List[ItemHash]] = None
reason: Optional[str] = None

def __hash__(self):
Expand Down Expand Up @@ -200,8 +201,45 @@ def check_item_content(cls, v: Optional[str], values) -> Optional[str]:
)
return v

@validator("content")
def check_content(cls, v, values):
    """Check that the content matches the serialized item_content.

    Only inline messages are checked; for any other item type the content
    is returned untouched. Every key of the serialized content must be
    present in the decoded 'item_content' with an equal value.

    Raises:
        ValueError: if 'item_content' is not valid JSON, is not a JSON
            object, lacks a key of the content, or holds a different value.
    """
    # pydantic leaves fields that failed their own validation out of
    # `values`; their error is already reported, and indexing them here
    # would raise a KeyError that is not turned into a validation error.
    if "item_type" not in values or "item_content" not in values:
        return v
    if values["item_type"] != ItemType.inline:
        return v

    try:
        item_content = json.loads(values["item_content"])
    except JSONDecodeError as error:
        raise ValueError(
            "Field 'item_content' does not appear to be valid JSON"
        ) from error
    if not isinstance(item_content, dict):
        # Valid JSON, but a list/number/string cannot be compared key by key.
        raise ValueError("Field 'item_content' does not appear to be a JSON object")

    json_dump = json.loads(v.json())
    for key, value in json_dump.items():
        if key not in item_content:
            # Report the missing key explicitly instead of letting the
            # lookup below fail with a KeyError.
            raise ValueError(
                f"Field 'content.{key}' is missing from 'item_content'"
            )
        if value != item_content[key]:
            cls._raise_value_error(item_content, key, value)
    return v

@staticmethod
def _raise_value_error(item_content, key, value):
"""Raise a ValueError with a message that explains the content/item_content mismatch"""
if isinstance(value, list):
for item in value:
if item not in item_content[key]:
raise ValueError(
f"Field 'content.{key}' does not match 'item_content.{key}': {item} != {item_content[key]}"
)
if isinstance(value, dict):
for item in value.items():
if item not in item_content[key].items():
raise ValueError(
f"Field 'content.{key}' does not match 'item_content.{key}': {value} != {item_content[key]}"
)
raise ValueError(
f"Field 'content.{key}' does not match 'item_content.{key}': {value} != {item_content[key]} or type mismatch ({type(value)} != {type(item_content[key])})"
)

@validator("item_hash")
def check_item_hash(cls, v: ItemHash, values) -> ItemHash:
"""Check that the 'item_hash' matches the 'item_content's SHA256 hash"""
item_type = values["item_type"]
if item_type == ItemType.inline:
item_content: str = values["item_content"]
Expand All @@ -225,13 +263,15 @@ def check_item_hash(cls, v: ItemHash, values) -> ItemHash:

@validator("confirmed")
def check_confirmed(cls, v, values):
    """Check that 'confirmed' is not True without 'confirmations'.

    Raises:
        ValueError: if the message is flagged as confirmed while
            'confirmations' is empty or unavailable.
    """
    # `values` lacks 'confirmations' when that field failed its own
    # validation; .get() keeps the outcome a validation error instead of
    # an uncaught KeyError.
    confirmations = values.get("confirmations")
    if v is True and not confirmations:
        raise ValueError("Message cannot be 'confirmed' without 'confirmations'")
    return v

@validator("time")
def convert_float_to_datetime(cls, v, values):
"""Converts a Unix timestamp to a datetime object"""
if isinstance(v, float):
v = datetime.datetime.fromtimestamp(v)
assert isinstance(v, datetime.datetime)
Expand Down Expand Up @@ -277,20 +317,6 @@ class ProgramMessage(BaseMessage):
type: Literal[MessageType.program]
content: ProgramContent

@validator("content")
def check_content(cls, v, values):
    """Check that an inline message's content equals its decoded 'item_content'.

    Messages whose item type is not inline are returned unchecked. When the
    two sides differ, each differing key is printed before the error is
    raised.

    Raises:
        ValueError: if the content and the decoded 'item_content' differ.
    """
    # pydantic leaves fields that failed their own validation out of
    # `values`; skip the comparison rather than fail with a KeyError.
    if "item_type" not in values or "item_content" not in values:
        return v
    if values["item_type"] != ItemType.inline:
        return v

    item_content = json.loads(values["item_content"])
    content_dict = v.dict(exclude_none=True)  # computed once, reused below
    if content_dict == item_content:
        return v

    if isinstance(item_content, dict):
        # Print differences
        for key, value in item_content.items():
            if key not in content_dict:
                # A key that exists only in item_content must not raise
                # KeyError while building the diagnostic output.
                print(f"{key}: <missing> != {value}")
            elif content_dict[key] != value:
                print(f"{key}: {content_dict[key]} != {value}")
    raise ValueError("Content and item_content differ")


class InstanceMessage(BaseMessage):
type: Literal[MessageType.instance]
Expand Down Expand Up @@ -337,12 +363,12 @@ def parse_message(message_dict: Dict) -> AlephMessage:


def add_item_content_and_hash(message_dict: Dict, inplace: bool = False):
# TODO: I really don't like this function. There is no validation of the
# message_dict, if it is indeed a real message, and can lead to unexpected results.
if not inplace:
message_dict = copy(message_dict)

message_dict["item_content"] = json.dumps(
message_dict["content"], separators=(",", ":")
)
message_dict["item_content"] = dump_content(message_dict["content"])
message_dict["item_hash"] = sha256(
message_dict["item_content"].encode()
).hexdigest()
Expand Down
5 changes: 5 additions & 0 deletions aleph_message/models/abstract.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from pydantic import BaseModel, Extra

from aleph_message.utils import dump_content


def hashable(obj):
"""Convert `obj` into a hashable object."""
Expand All @@ -26,3 +28,6 @@ class BaseContent(BaseModel):

class Config:
extra = Extra.forbid

def json(self, *args, **kwargs):
    """Return this content serialized by `dump_content`.

    Positional and keyword arguments are accepted but not forwarded.
    """
    serialized = dump_content(self)
    return serialized
8 changes: 2 additions & 6 deletions aleph_message/models/execution/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,10 @@

from pydantic import Field

from .environment import (
FunctionEnvironment,
HostRequirements,
MachineResources,
)
from ..abstract import BaseContent, HashableModel
from .base import Payment
from .environment import FunctionEnvironment, HostRequirements, MachineResources
from .volume import MachineVolume
from ..abstract import BaseContent, HashableModel


class BaseExecutableContent(HashableModel, BaseContent, ABC):
Expand Down
2 changes: 1 addition & 1 deletion aleph_message/models/execution/program.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Literal, Optional, List
from typing import List, Literal, Optional

from pydantic import Field

Expand Down
4 changes: 2 additions & 2 deletions aleph_message/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,12 +271,12 @@ def test_create_new_message():
"chain": "ETH",
"sender": "0x101d8D16372dBf5f1614adaE95Ee5CCE61998Fc9",
"type": "POST",
"time": "1625652287.017",
"time": 1625652287.017,
"item_type": "inline",
"content": {
"address": "0x101d8D16372dBf5f1614adaE95Ee5CCE61998Fc9",
"type": "test-message",
"time": "1625652287.017",
"time": 1625652287.017,
"content": {
"hello": "world",
},
Expand Down
49 changes: 48 additions & 1 deletion aleph_message/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,53 @@
from aleph_message.utils import Gigabytes, gigabyte_to_mebibyte
from datetime import date, datetime, time

import pytest
from pydantic import BaseModel

from aleph_message.utils import (
Gigabytes,
dump_content,
extended_json_encoder,
gigabyte_to_mebibyte,
)


def test_gigabyte_to_mebibyte():
    """Check the gigabyte-to-mebibyte conversion on two reference values."""
    expected_mebibytes = {1: 954, 100: 95368}
    for gigabytes, mebibytes in expected_mebibytes.items():
        assert gigabyte_to_mebibyte(Gigabytes(gigabytes)) == mebibytes


def test_extended_json_encoder():
    """Check how datetime, date and time objects are encoded."""
    current_moment = datetime.now()
    assert extended_json_encoder(current_moment) == current_moment.timestamp()

    current_day = date.today()
    assert extended_json_encoder(current_day) == current_day.toordinal()

    # 1 h + 2 min + 3 s + 4 microseconds, expressed in seconds.
    clock_reading = time(hour=1, minute=2, second=3, microsecond=4)
    assert extended_json_encoder(clock_reading) == 3723.000004


def test_dump_content():
    """Check that a dict and an equivalent model dump to the same compact JSON."""

    class TestModel(BaseModel):
        address: str
        time: float

    expected_json = '{"address":"0x1","time":1.0}'

    as_dict = {"address": "0x1", "time": 1.0}
    assert dump_content(as_dict) == expected_json

    as_model = TestModel(address="0x1", time=1.0)
    assert dump_content(as_model) == expected_json


@pytest.mark.parametrize("content", [1, "test", None, True])
def test_dump_content_invalid(content):
    """Check that input which is neither a dict nor a model raises TypeError."""
    with pytest.raises(TypeError):
        dump_content(content)
37 changes: 36 additions & 1 deletion aleph_message/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from __future__ import annotations

import json
import math
from typing import NewType
from datetime import date, datetime, time
from typing import Any, Dict, NewType, Union

from pydantic import BaseModel
from pydantic.json import pydantic_encoder

Megabytes = NewType("Megabytes", int)
Mebibytes = NewType("Mebibytes", int)
Expand All @@ -15,3 +20,33 @@ def gigabyte_to_mebibyte(n: Gigabytes) -> Mebibytes:
mebibyte = 2**20
gigabyte = 10**9
return Mebibytes(math.ceil(n * gigabyte / mebibyte))


def extended_json_encoder(obj: Any) -> Any:
    """
    Encode objects that the standard JSON encoder cannot handle.

    datetime objects become their timestamp, date objects their ordinal,
    time objects the number of seconds since midnight; anything else is
    delegated to pydantic's encoder.
    """
    # datetime must be tested before date: datetime is a subclass of date.
    if isinstance(obj, datetime):
        return obj.timestamp()
    if isinstance(obj, date):
        return obj.toordinal()
    if isinstance(obj, time):
        whole_seconds = obj.hour * 3600 + obj.minute * 60 + obj.second
        return whole_seconds + obj.microsecond / 1e6
    return pydantic_encoder(obj)


def dump_content(obj: Union[Dict, BaseModel]) -> str:
    """Serialize message content to a compact JSON string.

    None values are dropped: top-level entries for a dict, and via
    `exclude_none=True` for a pydantic model.

    Raises:
        TypeError: if `obj` is neither a dict nor a pydantic model.
    """
    if isinstance(obj, dict):
        # without None values
        payload = {key: value for key, value in obj.items() if value is not None}
    elif isinstance(obj, BaseModel):
        payload = obj.dict(exclude_none=True)
    else:
        raise TypeError(f"Invalid type: `{type(obj)}`")

    return json.dumps(payload, separators=(",", ":"), default=extended_json_encoder)