Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 61 additions & 52 deletions cogs/commands/market.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import heapq
import time
from collections import defaultdict

from discord.ext import commands
from discord.ext.commands import Bot, Context, check, clean_content
Expand All @@ -14,11 +15,12 @@


class Order:
def __init__(self, price, order_type, user_id):
def __init__(self, price, order_type, user_id, qty, order_time):
self.user_id = user_id
self.price = price
self.order_type = order_type
self.order_time = time.time()
self.qty = qty
self.order_time = order_time

def __lt__(self, other):
if self.order_type == 'ask':
Expand All @@ -37,7 +39,7 @@ def __gt__(self, other):
return self.price < other.price or (self.price == other.price and self.order_time > other.order_time)

def __str__(self):
return f'{self.order_type} <@{self.price}> <@{self.user_id}>'
return f'{self.order_type} {self.qty}@<@{self.price}> <@{self.user_id}>'

class Market:
def __init__(self, stock_name):
Expand All @@ -50,23 +52,22 @@ def __init__(self, stock_name):
self.last_trade = None
self.open = True

def bid(self, price, user_id):
self.bids.append(Order(price, 'bid', user_id))
heapq.heapify(self.bids)
def bid(self, price, user_id, qty, order_time=None):
order_time = time.time() if order_time is None else order_time
heapq.heappush(self.bids, Order(price, 'bid', user_id, qty, order_time))
return self.match()

def ask(self, price, user_id):
self.asks.append(Order(price, 'ask', user_id))
heapq.heapify(self.asks)
def ask(self, price, user_id, qty, order_time=None):
order_time = time.time() if order_time is None else order_time
heapq.heappush(self.asks, Order(price, 'ask', user_id, qty, order_time))
return self.match()

def match(self):
if len(self.bids) == 0 or len(self.asks) == 0:
return None

if self.bids[0].price >= self.asks[0].price:
matched = []
while len(self.bids) > 0 and len(self.asks) > 0 and self.bids[0].price >= self.asks[0].price:
bid = heapq.heappop(self.bids)
ask = heapq.heappop(self.asks)
qty = min(bid.qty, ask.qty)

if bid.user_id not in self.trade_history:
self.trade_history[bid.user_id] = []
Expand All @@ -76,33 +77,33 @@ def match(self):

earliest_trade = min(bid, ask, key=lambda x: x.order_time)

bid.price = earliest_trade.price
ask.price = earliest_trade.price
bought = Order(earliest_trade.price, 'bid', bid.user_id, qty, bid.order_time)
sold = Order(earliest_trade.price, 'ask', ask.user_id, qty, ask.order_time)

self.trade_history[bid.user_id].append(bid)
self.trade_history[ask.user_id].append(ask)

self.last_trade = f"<@{bid.user_id}> bought from <@{ask.user_id}> at {bid.price}"
self.trade_history[bid.user_id].append(bought)
self.trade_history[ask.user_id].append(sold)

return self.last_trade
return None
self.last_trade = f"<@{bid.user_id}> bought {qty} from <@{ask.user_id}> at {bought.price}"

if ask.qty > qty:
heapq.heappush(self.asks, Order(ask.price, 'ask', ask.user_id, ask.qty - qty, ask.order_time))
elif bid.qty > qty:
heapq.heappush(self.bids, Order(bid.price, 'bid', bid.user_id, bid.qty - qty, bid.order_time))

matched.append(self.last_trade)

return "\n".join(matched) if len(matched) > 0 else None



def close_market(self, valuation):
user_to_profit = {}
for user in self.trade_history:
user_valuation = 0
for trade in self.trade_history[user]:
if trade.order_type == 'bid':
user_valuation -= trade.price
user_valuation += valuation
else:
user_valuation += trade.price
user_valuation -= valuation

user_to_profit[user] = user_valuation

closing = valuation * sum(trade.qty if trade.order_type == 'bid' else -trade.qty for trade in self.trade_history[user])
# Note: accumulating _value_ not position, so signs are reversed
pnl = sum(trade.price * (trade.qty if trade.order_type == 'ask' else -trade.qty) for trade in self.trade_history[user])
user_to_profit[user] = closing + pnl

self.open = False

return user_to_profit
Expand Down Expand Up @@ -135,13 +136,17 @@ def __str__(self):
ret_str = "Market is: "
ret_str += "OPEN\n\n" if self.open else "CLOSED\n\n"

# Count bids and asks for each price level
bid_counts = {}
ask_counts = {}
# Count bids and asks and sum quantity for each price level
bid_counts = defaultdict(lambda: [0,0])
ask_counts = defaultdict(lambda: [0,0])
for bid in self.bids:
bid_counts[bid.price] = bid_counts.get(bid.price, 0) + 1
level = bid_counts[bid.price]
level[0] += 1
level[1] += bid.qty
for ask in self.asks:
ask_counts[ask.price] = ask_counts.get(ask.price, 0) + 1
level = ask_counts[ask.price]
level[0] += 1
level[1] += ask.qty

# Get price levels; highest first
all_prices = sorted(set(bid_counts.keys()).union(set(ask_counts.keys())), reverse=True)
Expand All @@ -152,13 +157,13 @@ def __str__(self):
order_book_lines.append("No outstanding orders\n")
else:
order_book_lines.append("```")
order_book_lines.append(f"{'Bid Volume':<15} | {'Price':<10} | {'Ask Volume'}")
order_book_lines.append(f"{'Bid Orders':<15} | {'Bid Volume':<15} | {'Price':<10} | {'Ask Volume':<15} | {'Ask Orders'}")

for price in all_prices:
bid_vol = bid_counts.get(price, " " * 15)
ask_vol = ask_counts.get(price, " " * 10)
bid_vol = bid_counts.get(price, [" " * 15] * 2)
ask_vol = ask_counts.get(price, [" " * 10] * 2)
formatted_price = f"{price:.2f}"
order_book_lines.append(f"{str(bid_vol):<15} | {str(formatted_price):<10} | {str(ask_vol)}")
order_book_lines.append(f"{str(bid_vol[0]):<15} | {str(bid_vol[1]):<15} | {str(formatted_price):<10} | {str(ask_vol[1]):<15} | {str(ask_vol[0])}")

order_book_lines.append("```")

Expand Down Expand Up @@ -200,9 +205,9 @@ async def view_market(self, ctx: Context, *, market: clean_content):
await ctx.reply(market_str, ephemeral=True)

@commands.hybrid_command(help=LONG_HELP_TEXT, brief=SHORT_HELP_TEXT)
async def bid_market(self, ctx: Context, price: float, *, market: clean_content):
async def bid_market(self, ctx: Context, price: float, qty: int, *, market: clean_content):
"""You would place a bid by using this command
'!bid_market 100 "AAPL"'
'!bid_market 123.4 15 "AAPL"'
"""
if market not in self.live_markets:
await ctx.reply("Market does not exist", ephemeral=True)
Expand All @@ -214,15 +219,19 @@ async def bid_market(self, ctx: Context, price: float, *, market: clean_content)
await ctx.reply("Market is closed", ephemeral=True)
return

did_trade = market_obj.bid(price, ctx.author.id)
did_trade = market_obj.bid(price, ctx.author.id, qty)

await ctx.reply("Bid placed", ephemeral=True)

if did_trade is not None:
await ctx.reply(did_trade, ephemeral=False)

@commands.hybrid_command(help=LONG_HELP_TEXT, brief=SHORT_HELP_TEXT)
async def ask_market(self, ctx: Context, price: float, *, market: clean_content):
async def ask_market(self, ctx: Context, price: float, qty: int, *, market: clean_content):
"""You would place an ask by using this command
'!ask_market 123.4 15 "AAPL"'
"""

if market not in self.live_markets:
await ctx.reply("Market does not exist", ephemeral=True)
return
Expand All @@ -234,7 +243,7 @@ async def ask_market(self, ctx: Context, price: float, *, market: clean_content)
return


did_trade = market_obj.ask(price, ctx.author.id)
did_trade = market_obj.ask(price, ctx.author.id, qty)

await ctx.reply("Ask placed", ephemeral=True)

Expand All @@ -250,14 +259,14 @@ async def positions_market(self, ctx: Context, *, market: clean_content):
market_obj = self.live_markets[market]

user_trades = market_obj.trade_history.get(ctx.author.id, [])
user_asks = [trade.price for trade in user_trades if trade.order_type == 'ask']
user_bids = [trade.price for trade in user_trades if trade.order_type == 'bid']
user_asks = "\n".join(f"{trade.qty}@{trade.price}" for trade in user_trades if trade.order_type == 'ask')
user_bids = "\n".join(f"{trade.qty}@{trade.price}" for trade in user_trades if trade.order_type == 'bid')
net = sum(trade.qty if trade.order_type == 'bid' else -trade.qty for trade in user_trades)

positions = f"Positions for <@{ctx.author.id}> in {market_obj.stock_name}\n"
positions += "Bids\n"
positions += "\n".join([str(bid) for bid in user_bids])
positions += "\n\nAsks\n"
positions += "\n".join([str(ask) for ask in user_asks])
positions += f"Net position: {net}\n"
positions += f"Bids\n{user_bids}"
positions += f"Asks\n{user_asks}"

await ctx.reply(str(positions), ephemeral=True)

Expand Down
86 changes: 86 additions & 0 deletions tests/test_market.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from cogs.commands.market import Market


def test_can_place_orders():
m = Market("TEST")
assert m.ask(102, 1, 1, 1) is None
assert m.bid(101, 2, 3, 2) is None
assert m.ask(102, 3, 4, 3) is None
assert len(m.asks) == 2
assert len(m.bids) == 1

assert str(m) == """Market is: OPEN

📊 **TEST Order Book** 📊
```
Bid Orders | Bid Volume | Price | Ask Volume | Ask Orders
| | 102.00 | 5 | 2
1 | 3 | 101.00 | |
```
Last Trade: None"""

def test_single_match():
m = Market("test")
assert m.ask(101, 1, 1) is None
assert (matched := m.bid(101, 2, 1)) is not None
assert matched == "<@2> bought 1 from <@1> at 101"
assert len(m.asks) == 0
assert len(m.bids) == 0

def test_partial_match():
m = Market("test")
assert m.ask(102, 1, 100, 1) is None
assert (o := m.bid(102, 2, 50, 2)) is not None
assert o == "<@2> bought 50 from <@1> at 102"
assert len(m.bids) == 0
assert len(m.asks) == 1
assert m.asks[0].qty == 50
assert m.asks[0].order_time == 1

def test_multi_match():
m = Market("test")
assert m.ask(102, 1, 1, 1) is None
assert m.ask(102, 2, 1, 2) is None
assert m.ask(102, 3, 1, 4) is None
assert m.bid(102, 4, 2, 5) == """<@4> bought 1 from <@1> at 102
<@4> bought 1 from <@2> at 102"""
assert len(m.bids) == 0
assert len(m.asks) == 1
assert m.asks[0].user_id == 3
assert len(m.trade_history[1]) == 1
assert len(m.trade_history[2]) == 1
assert 3 not in m.trade_history
assert len(m.trade_history[4]) == 2

def test_turning():
m = Market("test")
assert m.ask(102, 1, 1, 1) is None
assert m.bid(102, 2, 100, 2) == """<@2> bought 1 from <@1> at 102"""
assert len(m.asks) == 0
assert len(m.bids) == 1
assert m.bids[0].qty == 99
assert m.bids[0].order_time == 2
Copy link
Contributor

@jfitz02 jfitz02 Apr 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
assert m.bids[0].order_time == 2
assert m.bids[0].order_time == 2
def test_time():
m = Market("test")
assert m.ask(102, 1, 1) is None
assert m.ask(102, 2, 1) is None
ask_times = [ask.order_time for ask in m.asks]
# Testing that all ask times are distinct
assert all([ask_times.count(ask.order_time) == 1 for ask in m.asks])
assert m.bid(102, 3, 1) == """<@3> bought 1 from <@1> at 102"""

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wrote this in GitHub and have not tested it, but this is just the gist of testing times are ok

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


def test_multi_level_clear():
m = Market("test")
assert m.ask(100, 1, 1, 1) is None
assert m.ask(101, 1, 1, 2) is None
assert m.ask(102, 1, 1, 3) is None
assert m.ask(103, 1, 100, 4) is None
assert m.bid(103, 2, 10, 5) == """<@2> bought 1 from <@1> at 100
<@2> bought 1 from <@1> at 101
<@2> bought 1 from <@1> at 102
<@2> bought 7 from <@1> at 103"""
assert len(m.bids) == 0
assert len(m.asks) == 1
assert m.asks[0].qty == 93
assert m.asks[0].price == 103
assert len(m.trade_history[1]) == 4
assert len(m.trade_history[2]) == 4

def test_times():
m = Market("test")
assert m.ask(100, 1 ,1) is None
assert m.ask(101, 2, 1) is None
assert len(set([o.order_time for o in m.asks])) == len(m.asks)
assert m.bid(100, 2, 1) == "<@2> bought 1 from <@1> at 100"