1use std::cmp::Ordering;
2
3use primitive_types::{H160, U256};
4use zksync_vm2_interface::{
5 CallframeInterface, Event, Flags, GlobalStateInterface, HeapId, L2ToL1Log, StateInterface,
6 Tracer,
7};
8
9use crate::{
10 callframe::{Callframe, NearCallFrame},
11 decommit::is_kernel,
12 predication::{self, Predicate},
13 VirtualMachine, World,
14};
15
16impl<T: Tracer, W: World<T>> StateInterface for VirtualMachine<T, W> {
17 fn read_register(&self, register: u8) -> (U256, bool) {
18 (
19 self.state.registers[register as usize],
20 self.state.register_pointer_flags & (1 << register) != 0,
21 )
22 }
23
24 fn set_register(&mut self, register: u8, value: U256, is_pointer: bool) {
25 self.state.registers[register as usize] = value;
26
27 self.state.register_pointer_flags &= !(1 << register);
28 self.state.register_pointer_flags |= u16::from(is_pointer) << register;
29 }
30
31 fn number_of_callframes(&self) -> usize {
32 self.state
33 .previous_frames
34 .iter()
35 .map(|frame| frame.near_calls.len() + 1)
36 .sum::<usize>()
37 + self.state.current_frame.near_calls.len()
38 + 1
39 }
40
41 fn current_frame(&mut self) -> impl CallframeInterface + '_ {
42 let near_call = self.state.current_frame.near_calls.len().checked_sub(1);
43 CallframeWrapper {
44 frame: &mut self.state.current_frame,
45 near_call,
46 }
47 }
48
49 fn callframe(&mut self, mut n: usize) -> impl CallframeInterface + '_ {
50 for far_frame in std::iter::once(&mut self.state.current_frame)
51 .chain(self.state.previous_frames.iter_mut().rev())
52 {
53 let near_calls = far_frame.near_calls.len();
54 match n.cmp(&near_calls) {
55 Ordering::Less => {
56 return CallframeWrapper {
57 frame: far_frame,
58 near_call: Some(near_calls - 1 - n),
59 }
60 }
61 Ordering::Equal => {
62 return CallframeWrapper {
63 frame: far_frame,
64 near_call: None,
65 }
66 }
67 Ordering::Greater => n -= near_calls + 1,
68 }
69 }
70 panic!("Callframe index out of bounds")
71 }
72
73 fn read_heap_byte(&self, heap: HeapId, index: u32) -> u8 {
74 self.state.heaps[heap].read_byte(index)
75 }
76
77 fn read_heap_u256(&self, heap: HeapId, index: u32) -> U256 {
78 self.state.heaps[heap].read_u256(index)
79 }
80
81 fn write_heap_u256(&mut self, heap: HeapId, index: u32, value: U256) {
82 self.state.heaps.write_u256(heap, index, value);
83 }
84
85 fn flags(&self) -> Flags {
86 let flags = &self.state.flags;
87 Flags {
88 less_than: Predicate::IfLT.satisfied(flags),
89 greater: Predicate::IfGT.satisfied(flags),
90 equal: Predicate::IfEQ.satisfied(flags),
91 }
92 }
93
94 fn set_flags(&mut self, flags: Flags) {
95 self.state.flags = predication::Flags::new(flags.less_than, flags.equal, flags.greater);
96 }
97
98 fn transaction_number(&self) -> u16 {
99 self.state.transaction_number
100 }
101
102 fn set_transaction_number(&mut self, value: u16) {
103 self.state.transaction_number = value;
104 }
105
106 fn context_u128_register(&self) -> u128 {
107 self.state.context_u128
108 }
109
110 fn set_context_u128_register(&mut self, value: u128) {
111 self.state.context_u128 = value;
112 }
113
114 fn get_storage_state(&self) -> impl Iterator<Item = ((H160, U256), U256)> {
115 self.world_diff
116 .get_storage_state()
117 .iter()
118 .map(|(key, value)| (*key, *value))
119 }
120
121 fn get_transient_storage_state(&self) -> impl Iterator<Item = ((H160, U256), U256)> {
122 self.world_diff
123 .get_transient_storage_state()
124 .iter()
125 .map(|(key, value)| (*key, *value))
126 }
127
128 fn get_transient_storage(&self, address: H160, slot: U256) -> U256 {
129 self.world_diff
130 .get_transient_storage_state()
131 .get(&(address, slot))
132 .copied()
133 .unwrap_or_default()
134 }
135
136 fn write_transient_storage(&mut self, address: H160, slot: U256, value: U256) {
137 self.world_diff
138 .write_transient_storage(address, slot, value);
139 }
140
141 fn events(&self) -> impl Iterator<Item = Event> {
142 self.world_diff.events().iter().copied()
143 }
144
145 fn l2_to_l1_logs(&self) -> impl Iterator<Item = L2ToL1Log> {
146 self.world_diff.l2_to_l1_logs().iter().copied()
147 }
148
149 fn pubdata(&self) -> i32 {
150 self.world_diff.pubdata()
151 }
152
153 fn set_pubdata(&mut self, value: i32) {
154 self.world_diff.pubdata.0 = value;
155 }
156}
157
158struct CallframeWrapper<'a, T, W> {
159 frame: &'a mut Callframe<T, W>,
160 near_call: Option<usize>,
161}
162
163impl<T: Tracer, W: World<T>> CallframeInterface for CallframeWrapper<'_, T, W> {
164 fn address(&self) -> H160 {
165 self.frame.address
166 }
167
168 fn set_address(&mut self, address: H160) {
169 self.frame.address = address;
170 self.frame.is_kernel = is_kernel(address);
171 }
172
173 fn code_address(&self) -> H160 {
174 self.frame.code_address
175 }
176
177 fn set_code_address(&mut self, address: H160) {
178 self.frame.code_address = address;
179 }
180
181 fn caller(&self) -> H160 {
182 self.frame.caller
183 }
184
185 fn set_caller(&mut self, address: H160) {
186 self.frame.caller = address;
187 }
188
189 fn is_static(&self) -> bool {
190 self.frame.is_static
191 }
192
193 fn is_kernel(&self) -> bool {
194 self.frame.is_kernel
195 }
196
197 fn context_u128(&self) -> u128 {
198 self.frame.context_u128
199 }
200
201 fn set_context_u128(&mut self, value: u128) {
202 self.frame.context_u128 = value;
203 }
204
205 fn read_stack(&self, index: u16) -> (U256, bool) {
206 (
207 self.frame.stack.get(index),
208 self.frame.stack.get_pointer_flag(index),
209 )
210 }
211
212 fn write_stack(&mut self, index: u16, value: U256, is_pointer: bool) {
213 self.frame.stack.set(index, value);
214 if is_pointer {
215 self.frame.stack.set_pointer_flag(index);
216 } else {
217 self.frame.stack.clear_pointer_flag(index);
218 }
219 }
220
221 fn heap(&self) -> HeapId {
222 self.frame.heap
223 }
224
225 fn heap_bound(&self) -> u32 {
226 self.frame.heap_size
227 }
228
229 fn set_heap_bound(&mut self, value: u32) {
230 self.frame.heap_size = value;
231 }
232
233 fn aux_heap(&self) -> HeapId {
234 self.frame.aux_heap
235 }
236
237 fn aux_heap_bound(&self) -> u32 {
238 self.frame.aux_heap_size
239 }
240
241 fn set_aux_heap_bound(&mut self, value: u32) {
242 self.frame.aux_heap_size = value;
243 }
244
245 fn read_contract_code(&self, slot: u16) -> U256 {
246 self.frame.program.code_page()[slot as usize]
247 }
248
249 fn is_near_call(&self) -> bool {
252 self.near_call.is_some()
253 }
254
255 fn gas(&self) -> u32 {
256 if let Some(call) = self.near_call_on_top() {
257 call.previous_frame_gas
258 } else {
259 self.frame.gas
260 }
261 }
262
263 fn set_gas(&mut self, new_gas: u32) {
264 if let Some(call) = self.near_call_on_top_mut() {
265 call.previous_frame_gas = new_gas;
266 } else {
267 self.frame.gas = new_gas;
268 }
269 }
270
271 fn stack_pointer(&self) -> u16 {
272 if let Some(call) = self.near_call_on_top() {
273 call.previous_frame_sp
274 } else {
275 self.frame.sp
276 }
277 }
278
279 fn set_stack_pointer(&mut self, value: u16) {
280 if let Some(call) = self.near_call_on_top_mut() {
281 call.previous_frame_sp = value;
282 } else {
283 self.frame.sp = value;
284 }
285 }
286
287 #[allow(clippy::cast_possible_truncation)]
289 fn program_counter(&self) -> Option<u16> {
290 if let Some(call) = self.near_call_on_top() {
291 Some(call.previous_frame_pc)
292 } else {
293 let offset = self.frame.get_raw_pc();
294 if offset > u16::MAX as usize || self.frame.program.instruction(offset as u16).is_none()
295 {
296 None
297 } else {
298 Some(offset as u16)
299 }
300 }
301 }
302
303 fn set_program_counter(&mut self, value: u16) {
304 if let Some(call) = self.near_call_on_top_mut() {
305 call.previous_frame_pc = value;
306 } else {
307 self.frame.set_pc_from_u16(value);
308 }
309 }
310
311 fn exception_handler(&self) -> u16 {
312 if let Some(i) = self.near_call {
313 self.frame.near_calls[i].exception_handler
314 } else {
315 self.frame.exception_handler
316 }
317 }
318
319 fn set_exception_handler(&mut self, value: u16) {
320 if let Some(i) = self.near_call {
321 self.frame.near_calls[i].exception_handler = value;
322 } else {
323 self.frame.exception_handler = value;
324 }
325 }
326}
327
328impl<T, W> CallframeWrapper<'_, T, W> {
329 fn near_call_on_top(&self) -> Option<&NearCallFrame> {
330 let index = self.near_call.map_or(0, |i| i + 1);
331 self.frame.near_calls.get(index)
332 }
333
334 fn near_call_on_top_mut(&mut self) -> Option<&mut NearCallFrame> {
335 let index = self.near_call.map_or(0, |i| i + 1);
336 self.frame.near_calls.get_mut(index)
337 }
338}
339
340pub(crate) struct VmAndWorld<'a, T, W> {
341 pub vm: &'a mut VirtualMachine<T, W>,
342 pub world: &'a mut W,
343}
344
345impl<T: Tracer, W: World<T>> GlobalStateInterface for VmAndWorld<'_, T, W> {
346 fn get_storage(&mut self, address: H160, slot: U256) -> U256 {
347 self.vm
348 .world_diff
349 .just_read_storage(self.world, address, slot)
350 }
351}
352
353impl<T: Tracer, W: World<T>> StateInterface for VmAndWorld<'_, T, W> {
355 fn read_register(&self, register: u8) -> (U256, bool) {
356 self.vm.read_register(register)
357 }
358 fn set_register(&mut self, register: u8, value: U256, is_pointer: bool) {
359 self.vm.set_register(register, value, is_pointer);
360 }
361 fn current_frame(&mut self) -> impl CallframeInterface + '_ {
362 self.vm.current_frame()
363 }
364 fn number_of_callframes(&self) -> usize {
365 self.vm.number_of_callframes()
366 }
367 fn callframe(&mut self, n: usize) -> impl CallframeInterface + '_ {
368 self.vm.callframe(n)
369 }
370 fn read_heap_byte(&self, heap: HeapId, offset: u32) -> u8 {
371 self.vm.read_heap_byte(heap, offset)
372 }
373 fn read_heap_u256(&self, heap: HeapId, offset: u32) -> U256 {
374 self.vm.read_heap_u256(heap, offset)
375 }
376 fn write_heap_u256(&mut self, heap: HeapId, offset: u32, value: U256) {
377 self.vm.write_heap_u256(heap, offset, value);
378 }
379 fn flags(&self) -> Flags {
380 self.vm.flags()
381 }
382 fn set_flags(&mut self, flags: Flags) {
383 self.vm.set_flags(flags);
384 }
385 fn transaction_number(&self) -> u16 {
386 self.vm.transaction_number()
387 }
388 fn set_transaction_number(&mut self, value: u16) {
389 self.vm.set_transaction_number(value);
390 }
391 fn context_u128_register(&self) -> u128 {
392 self.vm.context_u128_register()
393 }
394 fn set_context_u128_register(&mut self, value: u128) {
395 self.vm.set_context_u128_register(value);
396 }
397 fn get_storage_state(&self) -> impl Iterator<Item = ((H160, U256), U256)> {
398 self.vm.get_storage_state()
399 }
400 fn get_transient_storage_state(&self) -> impl Iterator<Item = ((H160, U256), U256)> {
401 self.vm.get_transient_storage_state()
402 }
403 fn get_transient_storage(&self, address: H160, slot: U256) -> U256 {
404 self.vm.get_transient_storage(address, slot)
405 }
406 fn write_transient_storage(&mut self, address: H160, slot: U256, value: U256) {
407 self.vm.write_transient_storage(address, slot, value);
408 }
409 fn events(&self) -> impl Iterator<Item = Event> {
410 self.vm.events()
411 }
412 fn l2_to_l1_logs(&self) -> impl Iterator<Item = L2ToL1Log> {
413 self.vm.l2_to_l1_logs()
414 }
415 fn pubdata(&self) -> i32 {
416 self.vm.pubdata()
417 }
418 fn set_pubdata(&mut self, value: i32) {
419 self.vm.set_pubdata(value);
420 }
421}
422
#[cfg(all(test, not(feature = "single_instruction_test")))]
mod test {
    use primitive_types::H160;
    use zkevm_opcode_defs::ethereum_types::Address;
    use zksync_vm2_interface::opcodes;

    use super::*;
    use crate::{
        testonly::{initial_decommit, TestWorld},
        Instruction, Program, VirtualMachine,
    };

    /// VM settings shared by the tests: zeroed code hashes and hook address 0.
    fn test_settings() -> crate::Settings {
        crate::Settings {
            default_aa_code_hash: [0; 32],
            evm_interpreter_code_hash: [0; 32],
            hook_address: 0,
        }
    }

    /// `callframe(n)` must walk near and far frames from the top of the stack down.
    /// Each picked frame must report its own exception handler and gas.
    #[test]
    fn callframe_picking() {
        let contract = Address::from_low_u64_be(0x_1234_5678_90ab_cdef);
        let raw_program = Program::from_raw(vec![Instruction::from_invalid()], vec![]);
        let mut world = TestWorld::<()>::new(&[(contract, raw_program)]);
        let program = initial_decommit(&mut world, contract);

        let mut vm = VirtualMachine::new(
            contract,
            program.clone(),
            Address::zero(),
            &[],
            1000,
            test_settings(),
        );

        // The root frame gets id 0: it has no gas and exception handler 0.
        vm.state.current_frame.gas = 0;
        vm.state.current_frame.exception_handler = 0;
        let mut next_id: u16 = 1;

        // Every pushed frame uses its id both as its gas and as its exception handler.
        let push_far = |vm: &mut VirtualMachine<(), TestWorld<()>>, id: &mut u16| {
            vm.push_frame::<opcodes::Normal>(
                H160::from_low_u64_be(1),
                program.clone(),
                (*id).into(),
                *id,
                false,
                false,
                HeapId::from_u32_unchecked(5),
                vm.world_diff.snapshot(),
            );
            assert_eq!(vm.current_frame().gas(), u32::from(*id));
            *id += 1;
        };

        let push_near = |vm: &mut VirtualMachine<(), TestWorld<()>>, id: &mut u16| {
            let gas = u32::from(*id);
            // Top up first so that the caller keeps its own gas after the near call.
            vm.state.current_frame.gas += gas;
            vm.state
                .current_frame
                .push_near_call(gas, *id, vm.world_diff.snapshot());
            assert_eq!(vm.current_frame().gas(), gas);
            *id += 1;
        };

        // Resulting stack (bottom to top): root 0, far 1, near 2, far 3, far 4, near 5, near 6.
        push_far(&mut vm, &mut next_id);
        push_near(&mut vm, &mut next_id);
        push_far(&mut vm, &mut next_id);
        push_far(&mut vm, &mut next_id);
        push_near(&mut vm, &mut next_id);
        push_near(&mut vm, &mut next_id);

        for (index, expected) in (0..next_id).rev().enumerate() {
            assert_eq!(vm.callframe(index).exception_handler(), expected);
            assert_eq!(vm.callframe(index).gas(), u32::from(expected));
        }
    }

    /// A pc that points before the first instruction must not be reported as a valid offset.
    #[test]
    fn program_counter_with_pointer_before_program() {
        let contract = Address::from_low_u64_be(0x_1234_5678_90ab_cdef);
        let raw_program = Program::from_raw(vec![Instruction::from_invalid()], vec![]);
        let mut world = TestWorld::<()>::new(&[(contract, raw_program)]);
        let program = initial_decommit(&mut world, contract);

        let mut vm = VirtualMachine::new(
            contract,
            program,
            Address::zero(),
            &[],
            1000,
            test_settings(),
        );

        let start = vm.state.current_frame.pc;
        vm.state.current_frame.pc = start.wrapping_sub(1);

        assert_eq!(vm.current_frame().program_counter(), None);
    }
}