Wednesday 4 November 2015

Using Chain Of Responsibility instead of if/else statement

Currently I am involved in building a web application. One of its features lets users upload their photo. The system currently accepts only three image formats (JPG, BMP and PNG), but other formats may need to be supported in the future.

Each image file starts with a header that, among other things, specifies the file's format. In our case the headers are as follows:

File format    Header
JPG            [0xff, 0xd8]
BMP            [0x42, 0x4D]
PNG            [0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A]
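The following small sketch shows how the table above turns into code. It compares the first bytes of a buffer against each magic number. The helper name `StartsWithHeader` is hypothetical and not part of the application. It uses LINQ's `Take` and `SequenceEqual`, and it returns false for buffers shorter than the header instead of reading past their end.

```csharp
using System;
using System.Linq;

public static class HeaderCheck
{
    // Magic numbers from the table above.
    public static readonly byte[] Jpg = { 0xFF, 0xD8 };
    public static readonly byte[] Bmp = { 0x42, 0x4D }; // ASCII "BM"
    public static readonly byte[] Png = { 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A };

    // Hypothetical helper: true if the buffer begins with every byte of the header.
    public static bool StartsWithHeader(byte[] buffer, byte[] header)
    {
        if (buffer == null || buffer.Length < header.Length)
        {
            return false; // too short to contain the header
        }

        return buffer.Take(header.Length).SequenceEqual(header);
    }

    public static void Main()
    {
        var pngFile = new byte[] { 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00 };
        Console.WriteLine(StartsWithHeader(pngFile, Png));              // True
        Console.WriteLine(StartsWithHeader(pngFile, Jpg));              // False
        Console.WriteLine(StartsWithHeader(new byte[] { 0xFF }, Jpg));  // False (too short)
    }
}
```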

Our image decoding service is defined as:

public interface IImageDecodingService
{
    ImageFormat DecodeImage(byte[] imageBuffer);
}

And the ImageFormat is defined as an enumeration:

public enum ImageFormat
{
    Unknown,
    Bmp,
    Png,
    Jpeg
}

One possible implementation is:

public class ImageDecodingService : IImageDecodingService
{
    private readonly byte[] _jpgHeader = { 0xff, 0xd8 };
    private readonly byte[] _pngHeader = { 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A };
    private readonly byte[] _bmpHeader = { 0x42, 0x4D };

    public ImageFormat DecodeImage(byte[] imageBuffer)
    {
        if (ContainsHeader(imageBuffer, _jpgHeader))
            return ImageFormat.Jpeg;

        if (ContainsHeader(imageBuffer, _pngHeader))
            return ImageFormat.Png;

        if (ContainsHeader(imageBuffer, _bmpHeader))
            return ImageFormat.Bmp;

        return ImageFormat.Unknown;
    }

    protected static bool ContainsHeader(byte[] buffer, byte[] header)
    {
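        // A buffer shorter than the header cannot contain it (and would otherwise throw IndexOutOfRangeException).
        if (buffer == null || buffer.Length < header.Length)
        {
            return false;
        }

        for (int i = 0; i < header.Length; i += 1)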
        {
            if (header[i] != buffer[i])
            {
                return false;
            }
        }

        return true;
    }
}

The problem with this approach is that any time we need to support a new image format, we have to go and change this class. This breaks the Open/Closed Principle, which says that classes should be open for extension but closed for modification.

A better approach is to implement each decoder as a separate class and then chain them together using the Chain of Responsibility pattern.

To achieve this, first we need to implement the decoders. The interface for the decoders looks like:

public interface IImageDecoder
{
    ImageFormat DecodeImage(byte[] buffer);
}

Since the decoders are very similar, we can extract a base class that implements common methods as follows:

public abstract class BaseDecoder : IImageDecoder
{
    private ImageFormat _decodingFormat;

    protected BaseDecoder(ImageFormat decodingFormat)
    {
        _decodingFormat = decodingFormat;
    }

    protected abstract byte[] Header { get; }

    public ImageFormat DecodeImage(byte[] buffer)
    {
        if(ContainsHeader(buffer, Header))
        {
            return _decodingFormat;
        }

        return ImageFormat.Unknown;
    }

    private static bool ContainsHeader(byte[] buffer, byte[] header)
    {
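        // A buffer shorter than the header cannot contain it (and would otherwise throw IndexOutOfRangeException).
        if (buffer == null || buffer.Length < header.Length)
        {
            return false;
        }

        for (int i = 0; i < header.Length; i += 1)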
        {
            if (header[i] != buffer[i])
            {
                return false;
            }
        }

        return true;
    }
}

Now our decoders look like:

public sealed class JpegDecoder : BaseDecoder, IImageDecoder
{
    public JpegDecoder() : base(ImageFormat.Jpeg)
    { }

    protected override byte[] Header
    {
        get { return new byte[] { 0xff, 0xd8 }; }
    }
}

public sealed class BmpDecoder : BaseDecoder, IImageDecoder
{
    public BmpDecoder() : base(ImageFormat.Bmp)
    { }

    protected override byte[] Header
    {
        get { return new byte[] { 0x42, 0x4D }; } // BMP header ("BM"), not the JPG one
    }
}

public sealed class PngDecoder : BaseDecoder, IImageDecoder
{
    public PngDecoder() : base(ImageFormat.Png)
    { }

    protected override byte[] Header
    {
        get { return new byte[] { 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A }; }
    }
}

The last decoder simply returns ImageFormat.Unknown. Its implementation looks like:

public class UnknownImageDecoder : IImageDecoder
{
    public ImageFormat DecodeImage(byte[] buffer)
    {
        return ImageFormat.Unknown;
    }
}

The next step is to refactor our base class so it allows chaining. The refactored class looks like:

public abstract class BaseDecoder : IImageDecoder
{
    private readonly ImageFormat _decodingFormat;
    private IImageDecoder _nextChain;

    protected BaseDecoder(ImageFormat decodingFormat)
    {
        _decodingFormat = decodingFormat;
    }

    protected BaseDecoder(IImageDecoder nextChain, ImageFormat decodingFormat) : this(decodingFormat)
    {
        if (nextChain == null)
        {
            throw new ArgumentNullException("nextChain");
        }

        _nextChain = nextChain;
    }

    protected abstract byte[] Header { get; }

    public ImageFormat DecodeImage(byte[] buffer)
    {
        if (ContainsHeader(buffer, Header))
        {
            return _decodingFormat;
        }

        if (_nextChain != null)
        {
            return _nextChain.DecodeImage(buffer);
        }

        return ImageFormat.Unknown;
    }

    private static bool ContainsHeader(byte[] buffer, byte[] header)
    {
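        // A buffer shorter than the header cannot contain it (and would otherwise throw IndexOutOfRangeException).
        if (buffer == null || buffer.Length < header.Length)
        {
            return false;
        }

        for (int i = 0; i < header.Length; i += 1)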
        {
            if (header[i] != buffer[i])
            {
                return false;
            }
        }

        return true;
    }
}

As you can see, we now have two constructors. One takes only the ImageFormat. The other takes both the ImageFormat and the next IImageDecoder in the chain. The first constructor allows the decoder to be used on its own, whereas the second constructor lets us build the chain.

Pay attention to the DecodeImage(...) method. If this decoder does not recognise the image and a next decoder has been specified, it passes the responsibility on to the next decoder in the chain.
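The hand-off can be checked with a condensed, self-contained version of the chain. It keeps the article's class names but, for brevity, replaces the two constructors with one optional `nextChain` parameter. It also folds the header check into a length-guarded LINQ comparison and omits UnknownImageDecoder: the last link simply returns Unknown. The `ChainDemo` class exists only for illustration.

```csharp
using System;
using System.Linq;

public enum ImageFormat { Unknown, Bmp, Png, Jpeg }

public interface IImageDecoder
{
    ImageFormat DecodeImage(byte[] buffer);
}

public abstract class BaseDecoder : IImageDecoder
{
    private readonly ImageFormat _decodingFormat;
    private readonly IImageDecoder _nextChain; // null when the decoder is used on its own

    protected BaseDecoder(ImageFormat decodingFormat, IImageDecoder nextChain = null)
    {
        _decodingFormat = decodingFormat;
        _nextChain = nextChain;
    }

    protected abstract byte[] Header { get; }

    public ImageFormat DecodeImage(byte[] buffer)
    {
        byte[] header = Header;

        // Handle the request if this decoder recognises the header...
        if (buffer != null && buffer.Length >= header.Length && buffer.Take(header.Length).SequenceEqual(header))
        {
            return _decodingFormat;
        }

        // ...otherwise pass it down the chain, or give up at the end of it.
        return _nextChain != null ? _nextChain.DecodeImage(buffer) : ImageFormat.Unknown;
    }
}

public sealed class BmpDecoder : BaseDecoder
{
    public BmpDecoder(IImageDecoder next = null) : base(ImageFormat.Bmp, next) { }
    protected override byte[] Header => new byte[] { 0x42, 0x4D };
}

public sealed class JpegDecoder : BaseDecoder
{
    public JpegDecoder(IImageDecoder next = null) : base(ImageFormat.Jpeg, next) { }
    protected override byte[] Header => new byte[] { 0xFF, 0xD8 };
}

public sealed class PngDecoder : BaseDecoder
{
    public PngDecoder(IImageDecoder next = null) : base(ImageFormat.Png, next) { }
    protected override byte[] Header => new byte[] { 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A };
}

public static class ChainDemo
{
    public static void Main()
    {
        IImageDecoder chain = new BmpDecoder(new JpegDecoder(new PngDecoder()));
        var png = new byte[] { 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A };

        // BmpDecoder and JpegDecoder do not match, so the request reaches PngDecoder.
        Console.WriteLine(chain.DecodeImage(png));                       // Png
        Console.WriteLine(chain.DecodeImage(new byte[] { 0x00, 0x01 })); // Unknown
    }
}
```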

We also need to add the second constructor to our decoders:

public sealed class BmpDecoder : BaseDecoder
{
    public BmpDecoder() 
        : base(ImageFormat.Bmp)
    { }

    public BmpDecoder(IImageDecoder nextChain) 
        : base(nextChain, ImageFormat.Bmp)
    { } 

    protected override byte[] Header
    {
        get { return new byte[] { 0x42, 0x4D }; } // BMP header ("BM"), not the JPG one
    }
}

public sealed class PngDecoder : BaseDecoder
{
    public PngDecoder() 
        : base(ImageFormat.Png)
    { }

    public PngDecoder(IImageDecoder nextChain) 
        : base(nextChain, ImageFormat.Png)
    { }            

    protected override byte[] Header
    {
        get { return new byte[] { 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A }; }
    }
}

public sealed class JpegDecoder : BaseDecoder
{
    public JpegDecoder()
        : base(ImageFormat.Jpeg)
    { }

    public JpegDecoder(IImageDecoder nextChain)
        : base(nextChain, ImageFormat.Jpeg)
    { } 

    protected override byte[] Header
    {
        get { return new byte[] { 0xff, 0xd8 }; }
    }
}

To construct the chain, we need a factory that builds it and returns the first decoder. The interface for the factory looks like:

public interface IImageDecoderFactory
{
    IImageDecoder Create();
}

And the implementation looks like:

public class ImageDecoderFactory : IImageDecoderFactory
{
    public IImageDecoder Create()
    {
        return new BmpDecoder(new JpegDecoder(new PngDecoder(new UnknownImageDecoder())));
    }
}

Now our ImageDecodingService looks like:

public class ImageDecodingService : IImageDecodingService
{
    private readonly IImageDecoderFactory _imageDecoderFactory;

    public ImageDecodingService(IImageDecoderFactory imageDecoderFactory)
    {
        _imageDecoderFactory = imageDecoderFactory;
    }

    public ImageFormat DecodeImage(byte[] imageBuffer)
    {
        var decoder = _imageDecoderFactory.Create();
        return decoder.DecodeImage(imageBuffer);
    }
}

So, to support another format, we implement a decoder for it and add it to the factory. In a real-world application you would register the decoders with a DI container. The container would pass them to the factory, and the factory would chain them together. That way, supporting a new format only means writing the new decoder and registering it, and none of the existing decoders, the factory or the service has to change.
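Here is a minimal sketch of the DI-driven factory described above. It assumes the container can inject an ordered list of "link builders": functions that take the next decoder and return a decoder wrapping it. The names `ChainingImageDecoderFactory` and `HeaderDecoder` are hypothetical. `HeaderDecoder` is a stand-in for BmpDecoder, JpegDecoder and PngDecoder that keeps the sketch self-contained. The registration shape is not the API of any particular DI container.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public enum ImageFormat { Unknown, Bmp, Png, Jpeg }

public interface IImageDecoder { ImageFormat DecodeImage(byte[] buffer); }

public interface IImageDecoderFactory { IImageDecoder Create(); }

// End of the chain: recognises nothing.
public sealed class UnknownImageDecoder : IImageDecoder
{
    public ImageFormat DecodeImage(byte[] buffer) => ImageFormat.Unknown;
}

// Hypothetical stand-in for the article's decoders: one class parameterised by its header.
public sealed class HeaderDecoder : IImageDecoder
{
    private readonly ImageFormat _format;
    private readonly byte[] _header;
    private readonly IImageDecoder _next;

    public HeaderDecoder(ImageFormat format, byte[] header, IImageDecoder next)
    {
        if (next == null) throw new ArgumentNullException(nameof(next));
        _format = format;
        _header = header;
        _next = next;
    }

    public ImageFormat DecodeImage(byte[] buffer) =>
        buffer.Length >= _header.Length && buffer.Take(_header.Length).SequenceEqual(_header)
            ? _format
            : _next.DecodeImage(buffer);
}

// Hypothetical factory: each registered link wraps the decoder that follows it.
public sealed class ChainingImageDecoderFactory : IImageDecoderFactory
{
    private readonly IReadOnlyList<Func<IImageDecoder, IImageDecoder>> _links;

    public ChainingImageDecoderFactory(IEnumerable<Func<IImageDecoder, IImageDecoder>> links)
    {
        _links = links.ToList();
    }

    public IImageDecoder Create()
    {
        // Build from the tail backwards so the first registration ends up at the head.
        IImageDecoder chain = new UnknownImageDecoder();
        for (int i = _links.Count - 1; i >= 0; i--)
        {
            chain = _links[i](chain);
        }
        return chain;
    }
}

public static class FactoryDemo
{
    public static void Main()
    {
        var factory = new ChainingImageDecoderFactory(new Func<IImageDecoder, IImageDecoder>[]
        {
            next => new HeaderDecoder(ImageFormat.Bmp, new byte[] { 0x42, 0x4D }, next),
            next => new HeaderDecoder(ImageFormat.Jpeg, new byte[] { 0xFF, 0xD8 }, next),
            next => new HeaderDecoder(ImageFormat.Png, new byte[] { 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A }, next),
        });

        Console.WriteLine(factory.Create().DecodeImage(new byte[] { 0x42, 0x4D, 0x00 })); // Bmp
    }
}
```

With the article's own classes, a registration would look like `next => new JpegDecoder(next)`. In that design, adding a format means adding one more entry to the list rather than editing the factory.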

Tuesday 20 October 2015

Using Monads to manage exception handling

Lately I have been playing around with functional programming and F#, and I have been trying to apply functional programming ideas to C# as well. Even though C# is not a pure functional programming language, its support for delegates, lambda expressions and extension methods lets us write code in a functional style without much friction.

I will not go into the details of the functional programming paradigm. Instead, I will concentrate on how a specific functional structure, the monad, can remove a lot of boilerplate code.

We have all seen code like this:

Example 1:
    public void AddStudentToSchool(Guid schoolId, StudentDto studentDto)
    {
        using (ITransaction transaction = _transactionFactory.Create())
        {
            _securityService.Validate();

            School school = _schoolRepository.GetSchoolById(schoolId);
            if(school == null)
            {
                throw new ResourceNotFoundException("School does not exist");
            }

            try
            {
                Student student = new Student(studentDto);
                school.AddStudent(student);
            }
            catch(BusinessRuleException ex)
            {
                throw new ValidationException(ex.BrokenRules);
            }
            catch(InvalidStudentException ex)
            {
                throw new ValidationException(ex.BrokenRules);
            }
        }
    }


Wouldn’t it be more expressive, and make the intention easier to understand, if we could write the following instead:

Example 2:
public void AddStudentToSchool(Guid schoolId, StudentDto studentDto)
{
    using (ITransaction transaction = _transactionFactory.Create())
    {
        _securityService.Validate();

        School school = _schoolRepository.GetSchoolById(schoolId);

        school.ToOptional()
            .IfNullThrowResourceNotFound("School does not exist")
            .IfNotNull(() =>
                {
                    Student student = new Student(studentDto);
                    school.AddStudent(student);
                })
            .OnException<BusinessRuleException>(ex => { throw new ValidationException(ex.BrokenRules); })
            .OnException<InvalidStudentException>(ex => { throw new ValidationException(ex.BrokenRules); })
            .Execute();
    }
}

Well, with the help of extension methods, and by borrowing a structure from functional programming called a monad, we can.

Wikipedia defines the monad as follows:

“In functional programming, a monad is a structure that represents computations defined as sequences of steps: a type with a monad structure defines what it means to chain operations, or nest functions of that type together. This allows the programmer to build pipelines that process data in steps, in which each action is decorated with additional processing rules provided by the monad.”

From this definition, we are interested in the sentence “…This allows the programmer to build pipelines that process data in steps…”, which is exactly what we want to achieve.

I won’t go into the details of explaining monads, but in essence a monad is a pattern for doing function composition with ‘amplified’ types (as Mike Hadlow puts it very eloquently on his blog series about monads). 

Going back to our code in Example 1, if we study it, what we are essentially trying to do is:

1. Get the school from the repository.
2. If the school is null, throw an exception.
3. If the school is not null, try to execute more code (add the student to the school).
4. If an exception is thrown, wrap it in a ValidationException.

To convert the above steps into a chain of actions, we need to make sure that the output of each action is the input of the next action in the chain. The Maybe monad (known as Option in F#) is a good fit for this.

Before we dive into writing code, let’s look at what we need to implement for our type to be a monad. The definition on Wikipedia states the following:

For a type to be regarded as a monad, it needs to have a type constructor M and two operations, bind and return:

· The return operation takes a value from a plain type and puts it into a monadic container using the constructor, creating a monadic value.

· The bind operation takes as its arguments a monadic value and a function from a plain type to a monadic value, and returns a new monadic value.

The Maybe monad is very similar to the Nullable<T> type in C#, except that our Maybe type is represented by two different types: Some, which indicates the presence of a value, and None, which indicates the absence of a value. We will implement the Maybe type itself as an interface called IOptional<T>:

public interface IOptional<T>
{
}


The None type is implemented as:

public class None<T> : IOptional<T>
{
    public override string ToString()
    {
        return "None";
    }
}

And Some type is implemented as:

public class Some<T> : IOptional<T>
{
    public Some(T value)
    {
        Value = value;
    }

    public T Value { get; private set; }

    public override string ToString()
    {
        return Value.ToString();
    }
}

Now, to turn our IOptional<T> into a monad, we need to implement the return and bind operations. We will implement them as extension methods: ToOptional plays the role of return, and Bind plays the role of bind. The implementation looks like:

public static class OptionalExtensions
{
    public static IOptional<T> ToOptional<T>(this T value) where T : class
    {
        if (value == null)
        {
            return new None<T>();
        }

        return new Some<T>(value);
    }

    public static IOptional<T2> Bind<T1, T2>(this IOptional<T1> a, Func<T1, IOptional<T2>> func)
    {
        var val = a as Some<T1>;

        return val != null
            ? func(val.Value)
            : new None<T2>();
    }
}
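A short usage sketch shows how Bind threads a value through a step and short-circuits on None. The `Trimmed` step and the `BindDemo` class are hypothetical and exist only for illustration. The monad types repeat the definitions above in expression-bodied form so that the block compiles on its own.

```csharp
using System;

public interface IOptional<T> { }

public class None<T> : IOptional<T>
{
    public override string ToString() => "None";
}

public class Some<T> : IOptional<T>
{
    public Some(T value) { Value = value; }
    public T Value { get; private set; }
    public override string ToString() => Value.ToString();
}

public static class OptionalExtensions
{
    // Return: lift a plain value into the monad (null becomes None).
    public static IOptional<T> ToOptional<T>(this T value) where T : class =>
        value == null ? (IOptional<T>)new None<T>() : new Some<T>(value);

    // Bind: run func only when there is a value; otherwise propagate None.
    public static IOptional<T2> Bind<T1, T2>(this IOptional<T1> a, Func<T1, IOptional<T2>> func)
    {
        var val = a as Some<T1>;
        return val != null ? func(val.Value) : new None<T2>();
    }
}

public static class BindDemo
{
    // Hypothetical step: trims a string, treating blank input as "no value".
    static IOptional<string> Trimmed(string s) =>
        string.IsNullOrWhiteSpace(s) ? (IOptional<string>)new None<string>() : new Some<string>(s.Trim());

    public static void Main()
    {
        Console.WriteLine("  hello ".ToOptional().Bind(s => Trimmed(s)));        // hello
        Console.WriteLine("   ".ToOptional().Bind(s => Trimmed(s)));             // None
        Console.WriteLine(((string)null).ToOptional().Bind(s => Trimmed(s)));   // None (Trimmed never runs)
    }
}
```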

Now our IOptional<T> type is a monad. The next step is to build the methods that we will chain together. These will also be implemented as extension methods.

The first one throws a ResourceNotFoundException if there is no value. The implementation looks like:

public static IOptional<T> IfNullThrowResourceNotFound<T>(this IOptional<T> value, string message)
    where T : class           
{
    Some<T> some = value as Some<T>;
    if (some == null)
    {
        throw new ResourceNotFoundException(message);
    }

    return value;
}


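The next step wraps the action to execute when the value is present. Note that the action is not run here. It is only wrapped in a Some<Action>, so that later steps can decorate it and Execute can run it. The implementation looks like: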

public static IOptional<Action> IfNotNull<T>(this IOptional<T> value, Action action)
    where T : class
{
    return value.Bind(t => new Some<Action>(action));
}



Next we implement the OnException step, which wraps the action in a try/catch for the given exception type. The implementation looks like:

public static IOptional<Action> OnException<E>(this IOptional<Action> value, Action<E> action)
    where E : Exception
{
    return value.Bind(t =>
    {
        return new Some<Action>(() =>
        {
            try
            {
                t();
            }
            catch (E ex)
            {
                action(ex);
            }
        });
    });
}


The last step is Execute, which runs the composed action (if there is one):

public static void Execute(this IOptional<Action> value)
{
    value.Bind(t =>
    {
        t();
        return new None<bool>();
    });
}
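A self-contained sketch can confirm two properties of this design. First, nothing runs until Execute is called. Second, an exception raised inside the action is translated by the matching OnException handler. `RuleBrokenException` and `WrappedValidationException` are hypothetical stand-ins for BusinessRuleException and ValidationException. The extension methods mirror the ones above, with explicit type arguments on Bind.

```csharp
using System;
using System.Collections.Generic;

public interface IOptional<T> { }
public class None<T> : IOptional<T> { }
public class Some<T> : IOptional<T>
{
    public Some(T value) { Value = value; }
    public T Value { get; }
}

// Hypothetical exception types standing in for BusinessRuleException / ValidationException.
public class RuleBrokenException : Exception { public RuleBrokenException(string m) : base(m) { } }
public class WrappedValidationException : Exception { public WrappedValidationException(string m) : base(m) { } }

public static class PipelineExtensions
{
    public static IOptional<T> ToOptional<T>(this T value) where T : class =>
        value == null ? (IOptional<T>)new None<T>() : new Some<T>(value);

    public static IOptional<T2> Bind<T1, T2>(this IOptional<T1> a, Func<T1, IOptional<T2>> func) =>
        a is Some<T1> some ? func(some.Value) : new None<T2>();

    // Captures the action; does not run it.
    public static IOptional<Action> IfNotNull<T>(this IOptional<T> value, Action action) where T : class =>
        value.Bind<T, Action>(t => new Some<Action>(action));

    // Wraps the captured action in a try/catch for exceptions of type E.
    public static IOptional<Action> OnException<E>(this IOptional<Action> value, Action<E> handler) where E : Exception =>
        value.Bind<Action, Action>(inner => new Some<Action>(() =>
        {
            try { inner(); }
            catch (E ex) { handler(ex); }
        }));

    // As in the article, Bind is used only for its side effect; the None<bool> result is discarded.
    public static void Execute(this IOptional<Action> value) =>
        value.Bind<Action, bool>(run => { run(); return new None<bool>(); });
}

public static class PipelineDemo
{
    public static void Main()
    {
        var log = new List<string>();

        IOptional<Action> pipeline = "school".ToOptional()
            .IfNotNull(() => { log.Add("action ran"); throw new RuleBrokenException("rule broken"); })
            .OnException<RuleBrokenException>(ex => { throw new WrappedValidationException(ex.Message); });

        // Nothing has run yet: the steps above only composed an Action.
        Console.WriteLine(log.Count); // 0

        try
        {
            pipeline.Execute();
        }
        catch (WrappedValidationException ex)
        {
            Console.WriteLine(log[0] + ", translated: " + ex.Message); // action ran, translated: rule broken
        }
    }
}
```

Because each OnException wraps the action produced by the previous step, a handler added later sits further out. An exception thrown by an earlier handler can therefore be caught by a later one.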


Once we have finished implementing the extension methods, our code from example 1 looks like:

School school = _schoolRepository.GetSchoolById(schoolId);

school
    .ToOptional()
    .IfNullThrowResourceNotFound("School does not exist")
    .IfNotNull(() =>
        {
            Student student = new Student(studentDto);
            school.AddStudent(student);
        })
    .OnException<BusinessRuleException>(ex => { throw new ValidationException(ex.BrokenRules); })
    .OnException<InvalidStudentException>(ex => { throw new ValidationException(ex.BrokenRules); })
    .Execute();

As you can see, our code now reads as a single pipeline, without the explicit null check and try/catch boilerplate. I hope you find this useful.