From 0cb7fc427dc68a43c472ac7468782f0b6b200ab7 Mon Sep 17 00:00:00 2001
From: Patrice Tisserand
Date: Mon, 15 Jul 2024 16:06:17 +0200
Subject: [PATCH] erc2981: add test cases for ERC2981Component
---
src/tests/mocks.cairo | 1 +
src/tests/mocks/erc2981_mocks.cairo | 45 +++++++++++
src/tests/token.cairo | 1 +
src/tests/token/erc2981.cairo | 1 +
src/tests/token/erc2981/test_erc2981.cairo | 90 ++++++++++++++++++++++
src/token/common/erc2981/erc2981.cairo | 25 +++++-
6 files changed, 161 insertions(+), 2 deletions(-)
create mode 100644 src/tests/mocks/erc2981_mocks.cairo
create mode 100644 src/tests/token/erc2981.cairo
create mode 100644 src/tests/token/erc2981/test_erc2981.cairo
diff --git a/src/tests/mocks.cairo b/src/tests/mocks.cairo
index 6e6e38156..9797f0252 100644
--- a/src/tests/mocks.cairo
+++ b/src/tests/mocks.cairo
@@ -4,6 +4,7 @@ pub(crate) mod erc1155_mocks;
pub(crate) mod erc1155_receiver_mocks;
pub(crate) mod erc20_mocks;
pub(crate) mod erc20_votes_mocks;
+pub(crate) mod erc2981_mocks;
pub(crate) mod erc721_mocks;
pub(crate) mod erc721_receiver_mocks;
pub(crate) mod eth_account_mocks;
diff --git a/src/tests/mocks/erc2981_mocks.cairo b/src/tests/mocks/erc2981_mocks.cairo
new file mode 100644
index 000000000..1f74f65ef
--- /dev/null
+++ b/src/tests/mocks/erc2981_mocks.cairo
@@ -0,0 +1,45 @@
+#[starknet::contract]
+pub(crate) mod ERC2981Mock {
+ use openzeppelin::introspection::src5::SRC5Component;
+ use openzeppelin::token::common::erc2981::ERC2981Component;
+ use starknet::ContractAddress;
+
+ component!(path: ERC2981Component, storage: erc2981, event: ERC2981Event);
+ component!(path: SRC5Component, storage: src5, event: SRC5Event);
+
+ #[storage]
+ struct Storage {
+ #[substorage(v0)]
+ erc2981: ERC2981Component::Storage,
+ #[substorage(v0)]
+ src5: SRC5Component::Storage
+ }
+
+ #[event]
+ #[derive(Drop, starknet::Event)]
+ enum Event {
+ #[flat]
+ ERC2981Event: ERC2981Component::Event,
+ #[flat]
+ SRC5Event: SRC5Component::Event
+ }
+
+
+ #[abi(embed_v0)]
+ impl ERC2981Impl = ERC2981Component::ERC2981Impl;
+ impl ERC2981InternalImpl = ERC2981Component::InternalImpl;
+
+ // SRC5
+ #[abi(embed_v0)]
+ impl SRC5Impl = SRC5Component::SRC5Impl;
+
+ #[constructor]
+ fn constructor(
+ ref self: ContractState,
+ owner: ContractAddress,
+ default_receiver: ContractAddress,
+ default_royalty_fraction: u256
+ ) {
+ self.erc2981.initializer(default_receiver, default_royalty_fraction);
+ }
+}
diff --git a/src/tests/token.cairo b/src/tests/token.cairo
index 04f631ea8..d366c7814 100644
--- a/src/tests/token.cairo
+++ b/src/tests/token.cairo
@@ -1,3 +1,4 @@
pub(crate) mod erc1155;
pub(crate) mod erc20;
+pub(crate) mod erc2981;
pub(crate) mod erc721;
diff --git a/src/tests/token/erc2981.cairo b/src/tests/token/erc2981.cairo
new file mode 100644
index 000000000..806e601d1
--- /dev/null
+++ b/src/tests/token/erc2981.cairo
@@ -0,0 +1 @@
+mod test_erc2981;
diff --git a/src/tests/token/erc2981/test_erc2981.cairo b/src/tests/token/erc2981/test_erc2981.cairo
new file mode 100644
index 000000000..e0f374a94
--- /dev/null
+++ b/src/tests/token/erc2981/test_erc2981.cairo
@@ -0,0 +1,90 @@
+use openzeppelin::introspection::interface::{ISRC5Dispatcher, ISRC5DispatcherTrait};
+
+use openzeppelin::tests::mocks::erc2981_mocks::ERC2981Mock;
+use openzeppelin::token::common::erc2981::ERC2981Component::{ERC2981Impl, InternalImpl};
+use openzeppelin::token::common::erc2981::ERC2981Component;
+use openzeppelin::token::common::erc2981::interface::IERC2981_ID;
+use openzeppelin::token::common::erc2981::{IERC2981Dispatcher, IERC2981DispatcherTrait};
+
+use starknet::{ContractAddress, contract_address_const};
+
+
+type ComponentState = ERC2981Component::ComponentState;
+
+fn COMPONENT_STATE() -> ComponentState {
+ ERC2981Component::component_state_for_testing()
+}
+
+fn OWNER() -> ContractAddress {
+ contract_address_const::<'OWNER'>()
+}
+
+fn DEFAULT_RECEIVER() -> ContractAddress {
+ contract_address_const::<'DEFAULT_RECEIVER'>()
+}
+
+fn RECEIVER() -> ContractAddress {
+ contract_address_const::<'RECEIVER'>()
+}
+
+// 0.5% (default denominator is 10000)
+fn DEFAULT_FEE_NUMERATOR() -> u256 {
+ 50
+}
+
+// 5% (default denominator is 10000)
+fn FEE_NUMERATOR() -> u256 {
+ 500
+}
+
+fn setup() -> ComponentState {
+ let mut state = COMPONENT_STATE();
+ state.initializer(DEFAULT_RECEIVER(), DEFAULT_FEE_NUMERATOR());
+ state
+}
+
+
+#[test]
+fn test_default_royalty() {
+ let mut state = setup();
+ let token_id = 12;
+ let sale_price = 1_000_000;
+ let (receiver, amount) = state.royalty_info(token_id, sale_price);
+ assert_eq!(receiver, DEFAULT_RECEIVER(), "Default receiver incorrect");
+ assert_eq!(amount, 5000, "Default fees incorrect");
+
+ state._set_default_royalty(RECEIVER(), FEE_NUMERATOR());
+
+ let (receiver, amount) = state.royalty_info(token_id, sale_price);
+ assert_eq!(receiver, RECEIVER(), "Default receiver incorrect");
+ assert_eq!(amount, 50000, "Default fees incorrect");
+}
+
+
+#[test]
+fn test_token_royalty_token() {
+ let mut state = setup();
+ let token_id = 12;
+ let another_token_id = 13;
+ let sale_price = 1_000_000;
+ let (receiver, amount) = state.royalty_info(token_id, sale_price);
+ assert_eq!(receiver, DEFAULT_RECEIVER(), "Default receiver incorrect");
+ assert_eq!(amount, 5000, "Wrong royalty amount");
+ let (receiver, amount) = state.royalty_info(another_token_id, sale_price);
+ assert_eq!(receiver, DEFAULT_RECEIVER(), "Default receiver incorrect");
+ assert_eq!(amount, 5000, "Wrong royalty amount");
+
+ state._set_token_royalty(token_id, RECEIVER(), FEE_NUMERATOR());
+ let (receiver, amount) = state.royalty_info(another_token_id, sale_price);
+ assert_eq!(receiver, DEFAULT_RECEIVER(), "Default receiver incorrect");
+ assert_eq!(amount, 5000, "Wrong royalty amount");
+ let (receiver, amount) = state.royalty_info(token_id, sale_price);
+ assert_eq!(receiver, RECEIVER(), "Default receiver incorrect");
+ assert_eq!(amount, 50000, "Wrong royalty amount");
+
+ state._reset_token_royalty(token_id);
+ let (receiver, amount) = state.royalty_info(token_id, sale_price);
+ assert_eq!(receiver, DEFAULT_RECEIVER(), "Default receiver incorrect");
+ assert_eq!(amount, 5000, "Wrong royalty amount");
+}
+
diff --git a/src/token/common/erc2981/erc2981.cairo b/src/token/common/erc2981/erc2981.cairo
index f6e3d831d..798e3d5af 100644
--- a/src/token/common/erc2981/erc2981.cairo
+++ b/src/token/common/erc2981/erc2981.cairo
@@ -7,8 +7,6 @@
pub mod ERC2981Component {
use core::num::traits::Zero;
- use openzeppelin::access::ownable::OwnableComponent::InternalTrait as OwnableInternalTrait;
- use openzeppelin::access::ownable::OwnableComponent;
use openzeppelin::introspection::src5::SRC5Component::InternalTrait as SRC5InternalTrait;
use openzeppelin::introspection::src5::SRC5Component::SRC5Impl;
use openzeppelin::introspection::src5::SRC5Component;
@@ -90,6 +88,13 @@ pub mod ERC2981Component {
10000
}
+ /// Returns the royalty information that all ids in this contract will default to.
+ fn _default_royalty(
+ self: @ComponentState
+ ) -> (ContractAddress, u256, u256) {
+ let royalty_info: RoyaltyInfo = self.default_royalty_info.read();
+ (royalty_info.receiver, royalty_info.royalty_fraction, self._fee_denominator())
+ }
/// Sets the royalty information that all ids in this contract will default to.
///
@@ -132,6 +137,22 @@ pub mod ERC2981Component {
.write(token_id, RoyaltyInfo { receiver, royalty_fraction: fee_numerator },)
}
+ /// Returns the royalty information that all ids in this contract will default to.
+ fn _token_royalty(
+ self: @ComponentState, token_id: u256
+ ) -> (ContractAddress, u256, u256) {
+ let royalty_info: RoyaltyInfo = self.token_royalty_info.read(token_id);
+ let mut receiver = royalty_info.receiver;
+ let mut royalty_fraction = royalty_info.royalty_fraction;
+
+ if receiver.is_zero() {
+ let default_royalty_info: RoyaltyInfo = self.default_royalty_info.read();
+ receiver = default_royalty_info.receiver;
+ royalty_fraction = default_royalty_info.royalty_fraction;
+ };
+ (receiver, royalty_fraction, self._fee_denominator())
+ }
+
/// Resets royalty information for the token id back to the global default.
fn _reset_token_royalty(ref self: ComponentState, token_id: u256) {