Implementing Ruby's Enumerator in Ruby (using Fiber)
I wanted to understand how Ruby's Enumerator works, so I decided to implement it in Ruby for fun. Well, technically I had to understand it for a different but related problem, but I maintain that it was fun.
Tests first...
To make sure we're implementing the correct Enumerator behaviour, let's go to Ruby Spec and copy the specs for Enumerator.new, which test the most basic Enumerator behaviour we want to implement. I love it when the tests are written for you 😌
I've copied the spec and swapped out the class for our own SimpleEnumerator:
describe "SimpleEnumerator" do
  context "when passed a block" do
    it "defines iteration with block, yielder argument and calling << method" do
      enum = SimpleEnumerator.new do |yielder|
        a = 1
        loop do
          yielder << a
          a = a + 1
        end
      end

      enum.take(3).should == [1, 2, 3]
    end
  end
end
In case you're new to how Ruby enumerators work: we have a SimpleEnumerator, which takes a block argument. The block contains a loop that calls the method #<< with yielder as the receiver, then increments the number. Even though it looks like an endless loop with no stop condition, the method #take is somehow able to control the number of times the loop body runs, and then break out of it.
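Here's a quick sketch of that behaviour with Ruby's built-in Enumerator standing in for SimpleEnumerator. It uses plain p calls instead of the spec's should expectations. The block really does loop forever. #take, #first and #next simply stop asking for values once they have what they need.

```ruby
# Ruby's built-in Enumerator, given the same block as the spec above.
enum = Enumerator.new do |yielder|
  a = 1
  loop do
    yielder << a # hand the current value to whoever is consuming the enumerator
    a = a + 1
  end
end

p enum.take(3)  # => [1, 2, 3]
p enum.first(5) # => [1, 2, 3, 4, 5]

# External iteration: each call to #next runs the block just far enough
# to produce one more value.
p enum.next # => 1
p enum.next # => 2
```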
To make the code example easier to follow for this blog post, I'm going to remove the loop:
describe "SimpleEnumerator" do
  context "when passed a block" do
    it "defines iteration with block, yielder argument and calling << method" do
      enum = SimpleEnumerator.new do |yielder|
        puts "reached hi"
        yielder << "hi"
        puts "reached hello"
        yielder << "hello"
        puts "reached bye"
        yielder << "bye"
      end

      enum.take(1).should == ["hi"]
      enum.take(2).should == ["hi", "hello"]
      enum.take(3).should == ["hi", "hello", "bye"]
    end
  end
end
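To show what the puts calls reveal, here is a sketch that runs the same block against Ruby's built-in Enumerator. Instead of printing, it records the progress markers in an array. The log variable exists only for this illustration. Each #take call starts the block from the top and runs it only as far as needed.

```ruby
log = [] # records how far the block got, in place of the puts calls

enum = Enumerator.new do |yielder|
  log << "reached hi"
  yielder << "hi"
  log << "reached hello"
  yielder << "hello"
  log << "reached bye"
  yielder << "bye"
end

p enum.take(1) # => ["hi"]
p log          # => ["reached hi"]

log.clear
p enum.take(2) # => ["hi", "hello"]
p log          # => ["reached hi", "reached hello"]
```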
Here's the implementation of SimpleEnumerator, piece by piece. Note that it's not a complete implementation of Enumerator. It's written only to satisfy the basic requirements shown in the spec above.
Let's walk through it!
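For reference, here are the pieces from the sections that follow assembled into one runnable file. It is a sketch that only covers the spec. The puts calls are left out, and a hypothetical starts counter makes it visible how often the block runs from the top.

```ruby
# A minimal, fiber-based stand-in for Enumerator that only covers
# what the spec exercises.
class SimpleEnumerator
  def initialize(&block)
    raise ArgumentError unless block_given?
    @block = block # the generator block, run later inside a fiber
  end

  def take(num)
    # 1) create a fiber (a new one on every call, so each #take starts over)
    @fiber = Fiber.new do
      @block.call(EnumYielder.new)
    end

    # 2) collect values
    ary = []
    num.times do
      ary << self.next
    end
    ary
  end

  def next
    @fiber.resume # run the block until its next Fiber.yield
  end
end

class EnumYielder
  def <<(value)
    Fiber.yield(value) # pause the fiber and hand value back to #resume
  end
end

starts = 0 # how many times the block has started from the top
enum = SimpleEnumerator.new do |yielder|
  starts += 1
  yielder << "hi"
  yielder << "hello"
  yielder << "bye"
end

p enum.take(1) # => ["hi"]
p enum.take(2) # => ["hi", "hello"]
p enum.take(3) # => ["hi", "hello", "bye"]
p starts       # => 3
```

Because #take creates a fresh fiber each time, all three take calls start from "hi" and the block starts three times. That is exactly what the spec expects.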
SimpleEnumerator.new
class SimpleEnumerator
  def initialize(&block)
    raise ArgumentError unless block_given?
    @block = block
  end

  ...
end
When we create a new instance of SimpleEnumerator, we make sure a block was given, because the block argument is required. Then we save the block in an instance variable because we'll need it later.
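A short sketch of what the guard clause buys us, using only the constructor shown above. With a block, the block is stored but not run. Without one, ArgumentError is raised right away, rather than failing later when #take tries to call a nil block.

```ruby
# Just the constructor from above, so this sketch runs on its own.
class SimpleEnumerator
  def initialize(&block)
    raise ArgumentError unless block_given? # the block is required
    @block = block                          # stored for later, not run yet
  end
end

# With a block: the block is saved but not called.
ran = false
SimpleEnumerator.new { |_yielder| ran = true }
p ran # => false

# Without a block: the guard clause raises immediately.
error = begin
  SimpleEnumerator.new
  nil
rescue ArgumentError => e
  e
end
p error.class # => ArgumentError
```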
SimpleEnumerator#take
The next part is the #take method. In Ruby, this method is implemented in the Enumerable module, which is included in Enumerator. To keep things simple (as the class name suggests), we implement #take directly in SimpleEnumerator.
def take(num)
  # 1) create a fiber
  @fiber = Fiber.new do
    @block.call(EnumYielder.new)
  end

  # 2) collect values
  ary = []
  num.times do
    ary << self.next
  end
  ary
end
1) create a fiber
When #take is called, it creates a fiber. The fiber's block doesn't run yet. When it does, it will call the block that was given when the enumerator was initialized, passing it an EnumYielder object. (We'll talk about what EnumYielder does later.)
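If fibers are new to you, this standalone sketch shows the property #take relies on. Fiber.new only stores the block, and nothing runs until the first Fiber#resume. The events array is just for illustration.

```ruby
events = []

fiber = Fiber.new do
  events << :fiber_started
  :finished # the block's last value is what the final resume returns
end

events << :fiber_created
p events       # => [:fiber_created]

p fiber.resume # => :finished
p events       # => [:fiber_created, :fiber_started]
```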
2) collect values
After that, we create an empty array to store the items we want to "take", then loop num times. In each iteration, we append the result of self.next, a method that returns the next value produced by the SimpleEnumerator block.
Finally, we return the final array containing the objects we've collected.
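As mentioned above, Ruby's real #take comes from Enumerable. As an aside, here is a sketch of what leaning on Enumerable could look like: define #each, include Enumerable, and let Enumerable#take do the work. EachEnumerator and its nested Yielder are hypothetical names for this illustration, not part of Ruby. This version needs no fiber for #take, because Enumerable#take breaks out of #each as soon as it has enough values, and that also ends the infinite loop.

```ruby
# Hypothetical alternative: define #each and let Enumerable supply #take.
class EachEnumerator
  include Enumerable

  # Forwards every value pushed with << to the block given to #each.
  class Yielder
    def initialize(&consumer)
      @consumer = consumer
    end

    def <<(value)
      @consumer.call(value)
      self # return the yielder so calls can be chained: y << 1 << 2
    end
  end

  def initialize(&block)
    raise ArgumentError unless block_given?
    @block = block
  end

  # Runs the block, feeding each value straight to the caller's block.
  def each(&consumer)
    @block.call(Yielder.new(&consumer))
    self
  end
end

counter = EachEnumerator.new do |yielder|
  a = 1
  loop do
    yielder << a
    a = a + 1
  end
end

p counter.take(3)  # => [1, 2, 3]
p counter.first(2) # => [1, 2]
```

The trade-off is that this version can only run the block straight through, or until the caller breaks out. That is enough for internal iteration like #take, but not for pausing between values. Supporting #next, where the caller pulls values one at a time, is where the fiber earns its keep.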
SimpleEnumerator#next
As mentioned, self.next is supposed to return the next value from the SimpleEnumerator block. We can get that value by running the fiber we've created, using Fiber#resume.
def next
  @fiber.resume
end
Calling Fiber#resume runs the block we passed to SimpleEnumerator, shown again here:
enum = SimpleEnumerator.new do |yielder|
  puts "reached hi"
  yielder << "hi"
  puts "reached hello"
  yielder << "hello"
  puts "reached bye"
  yielder << "bye"
end
Note that the yielder block argument is an instance of EnumYielder, a class we've created ourselves. So when the first resume reaches yielder << "hi", we are passing the value "hi" into EnumYielder#<<(value).
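The round trip that #next relies on can be seen in a standalone fiber, with no enumerator involved. Each Fiber#resume runs the block until the next Fiber.yield and returns the value passed to it. Once the block finishes, resume returns the block's last value. Resuming a finished fiber raises FiberError, and the exact message differs between Ruby versions.

```ruby
fiber = Fiber.new do
  Fiber.yield "hi"    # first resume stops here and returns "hi"
  Fiber.yield "hello" # second resume continues from the line above
  "done"              # third resume runs to the end; the block's value is returned
end

p fiber.resume # => "hi"
p fiber.resume # => "hello"
p fiber.resume # => "done"

error = begin
  fiber.resume # the fiber has finished, so there is nothing left to resume
  nil
rescue FiberError => e
  e
end
p error.class # => FiberError
```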
EnumYielder#<<(value)
EnumYielder implements #<<, which simply calls Fiber.yield with the value it received.
class EnumYielder
  def <<(value)
    Fiber.yield(value)
  end
end
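Because EnumYielder delegates to Fiber.yield, it only works inside a fiber. This sketch shows both cases. Inside a fiber, << hands the value to whoever called resume. On the main (root) fiber there is nothing to yield back to, so Ruby raises FiberError.

```ruby
class EnumYielder
  def <<(value)
    Fiber.yield(value)
  end
end

yielder = EnumYielder.new

# Inside a fiber: << pauses the fiber and passes the value to resume.
fiber = Fiber.new { yielder << "hi" }
p fiber.resume # => "hi"

# Outside any fiber: Fiber.yield has no resumer to return to.
error = begin
  yielder << "hi"
  nil
rescue FiberError => e
  e
end
p error.class # => FiberError
```

Ruby's Enumerator::Yielder#<< returns the yielder itself, so calls can be chained (yielder << 1 << 2). This EnumYielder#<< instead returns whatever the next resume passes in, which is nil in SimpleEnumerator, so chaining would not work here.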
This yields control back to the caller, in this case self.next (which called Fiber#resume), along with the value. In other words, our "hi" string becomes the return value of Fiber#resume in self.next, and is passed back to the loop in #take to be appended to the array we created earlier.
If take(1) is called, the times loop ends here, returning ["hi"].
If the times loop runs more than once (as with take(2) and take(3) in our spec above), the next call to self.next resumes the fiber where it previously left off (at the Fiber.yield) to get the next value, "hello", and so on.
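One limitation is worth knowing about. The spec never asks for more values than the block produces. With the three-value block, a fourth resume runs the block to its end. So take(4) would append the block's return value: nil, because the final yielder << "bye" returns whatever the next resume passed in, which is nothing. take(5) would raise FiberError from resuming a finished fiber. Ruby's Enumerator#take(4) just returns the three values. Here is one possible fix, sketched with a hypothetical DONE sentinel that the fiber returns when the block finishes. Ruby's own Enumerator#next signals the end differently, by raising StopIteration.

```ruby
class EnumYielder
  def <<(value)
    Fiber.yield(value)
  end
end

class SimpleEnumerator
  DONE = Object.new # hypothetical sentinel meaning "the block has finished"

  def initialize(&block)
    raise ArgumentError unless block_given?
    @block = block
  end

  def take(num)
    @fiber = Fiber.new do
      @block.call(EnumYielder.new)
      DONE # returned by the resume that runs the block to completion
    end

    ary = []
    num.times do
      value = self.next
      break if value.equal?(DONE) # stop early instead of collecting DONE
      ary << value
    end
    ary
  end

  def next
    @fiber.resume
  end
end

enum = SimpleEnumerator.new do |yielder|
  yielder << "hi"
  yielder << "hello"
  yielder << "bye"
end

p enum.take(4) # => ["hi", "hello", "bye"]
p enum.take(5) # => ["hi", "hello", "bye"]
```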
Summary
By implementing a simple version of an enumerator in Ruby, we now better understand how Enumerator and Fiber work. Using fibers, SimpleEnumerator can easily resume the given block and yield control back out of it. In fact, Ruby's own Enumerator also uses fibers, namely for the methods #peek and #next.
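To see those fiber-backed methods in action, here is Ruby's built-in Enumerator doing external iteration. #next returns one value at a time. #peek returns the next value without advancing. Running past the end raises StopIteration, which Kernel#loop rescues quietly.

```ruby
enum = Enumerator.new do |yielder|
  yielder << "hi"
  yielder << "hello"
  yielder << "bye"
end

p enum.next # => "hi"
# peek looks ahead without consuming the value
p enum.peek # => "hello"
p enum.next # => "hello"
p enum.next # => "bye"

error = begin
  enum.next # no values left
  nil
rescue StopIteration => e
  e
end
p error.class # => StopIteration

enum.rewind # start over from the beginning of the block
collected = []
loop do
  collected << enum.next # loop ends cleanly when next raises StopIteration
end
p collected # => ["hi", "hello", "bye"]
```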