// Redis: integration-testing example using testcontainers-go.

package main

import (
    "context"
    "fmt"
    "testing"
    "time"

    "github.com/go-redis/redis/v8"
    "github.com/google/uuid"
    "github.com/testcontainers/testcontainers-go"
    "github.com/testcontainers/testcontainers-go/wait"
)

// redisContainer bundles a started testcontainers container with the
// redis:// URI a client should use to reach it.
type redisContainer struct {
    testcontainers.Container
    // URI is the redis://host:port address; setupRedis builds it from the
    // container's host and its mapped Redis port.
    URI string
}

// setupRedis starts a redis:6 container and waits until its log contains
// "* Ready to accept connections". It returns the container together with a
// redis://host:port URI built from the container's host and mapped port.
//
// If the container starts but its mapped port or host cannot be resolved, the
// container is terminated before the error is returned so it is not leaked.
// On success the caller is responsible for terminating the returned container.
func setupRedis(ctx context.Context) (*redisContainer, error) {
    req := testcontainers.ContainerRequest{
        Image:        "redis:6",
        ExposedPorts: []string{"6379/tcp"},
        WaitingFor:   wait.ForLog("* Ready to accept connections"),
    }
    container, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
        ContainerRequest: req,
        Started:          true,
    })
    if err != nil {
        return nil, fmt.Errorf("starting redis container: %w", err)
    }

    mappedPort, err := container.MappedPort(ctx, "6379")
    if err != nil {
        // Best-effort cleanup: the lookup error is the one worth reporting,
        // so a secondary Terminate failure is deliberately ignored.
        _ = container.Terminate(ctx)
        return nil, fmt.Errorf("resolving mapped redis port: %w", err)
    }

    host, err := container.Host(ctx)
    if err != nil {
        // Same best-effort cleanup as above.
        _ = container.Terminate(ctx)
        return nil, fmt.Errorf("resolving redis container host: %w", err)
    }

    uri := fmt.Sprintf("redis://%s:%s", host, mappedPort.Port())

    return &redisContainer{Container: container, URI: uri}, nil
}

// flushRedis issues FLUSHALL through client and returns the command's error,
// or nil if it succeeded.
//
// NOTE(review): client is taken by value, so callers must dereference their
// *redis.Client. A pointer parameter would be more idiomatic, but changing it
// would break existing call sites.
func flushRedis(ctx context.Context, client redis.Client) error {
    cmd := client.FlushAll(ctx)
    return cmd.Err()
}

// TestIntegrationSetGet starts a throwaway Redis container, stores a key with
// a two-hour TTL, reads it back and checks that the value round-trips. It is
// skipped under `go test -short` because it starts a container.
func TestIntegrationSetGet(t *testing.T) {
    if testing.Short() {
        t.Skip("Skipping integration test")
    }

    ctx := context.Background()

    // Named rc (not redisContainer) so the local does not shadow the type.
    rc, err := setupRedis(ctx)
    if err != nil {
        t.Fatal(err)
    }
    defer func() {
        if err := rc.Terminate(ctx); err != nil {
            t.Errorf("terminating redis container: %v", err)
        }
    }()

    // You will likely want to wrap your Redis package of choice in an
    // interface to aid in unit testing and limit lock-in throughout your
    // codebase but that's out of scope for this example
    options, err := redis.ParseURL(rc.URI)
    if err != nil {
        t.Fatal(err)
    }
    client := redis.NewClient(options)
    // Deferred calls run LIFO: flush first, then close the client's
    // connection pool, then terminate the container.
    defer func() {
        if err := client.Close(); err != nil {
            t.Errorf("closing redis client: %v", err)
        }
    }()
    defer func() {
        if err := flushRedis(ctx, *client); err != nil {
            t.Errorf("flushing redis: %v", err)
        }
    }()

    // Set data
    key := fmt.Sprintf("{user.%s}.favoritefood", uuid.NewString())
    value := "Cabbage Biscuits"
    ttl := 2 * time.Hour
    if err := client.Set(ctx, key, value, ttl).Err(); err != nil {
        t.Fatal(err)
    }

    // Get data
    savedValue, err := client.Get(ctx, key).Result()
    if err != nil {
        t.Fatal(err)
    }
    if savedValue != value {
        // Expected value first, then what Redis actually returned.
        t.Fatalf("Expected value %s. Got %s.", value, savedValue)
    }
}