Skip to content

Proposing a new example: Mnist #61

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
CFLAGS = -Wall -Wshadow -O3 -g -march=native
CFLAGS = -Wall -Wshadow -O3 -I. -g -march=native
LDLIBS = -lm

all: check example1 example2 example3 example4
all: check example1 example2 example3 example4 mnist

sigmoid: CFLAGS += -Dgenann_act=genann_act_sigmoid_cached
sigmoid: all
Expand All @@ -25,10 +25,11 @@ example3: example3.o genann.o

example4: example4.o genann.o

mnist: mnist.o mnist_db.o genann.o

clean:
$(RM) *.o
$(RM) test example1 example2 example3 example4 *.exe
$(RM) test example1 example2 example3 example4 mnist *.exe
$(RM) persist.txt

.PHONY: sigmoid threshold linear clean
93 changes: 93 additions & 0 deletions mnist.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>

#include "genann.h"
#include "mnist_db.h"

#define CLASS_COUNT 10

int main(int argc, char* argv[])
{
size_t i;
int j;
double output[CLASS_COUNT];
MnistDataset training, tests;

if(argc != 5) {
printf("./mnist [NUMBER OF HIDDEN LAYERS] [NEURON PER HIDDEN LAYERS] [TRAINING ITERATION COUNT] [OUTPUT FILE]");
return 1;
}

if(mnist_init(&training,
"mnist_data/train-images-idx3-ubyte",
"mnist_data/train-labels-idx1-ubyte",
0, 0
))
return 1;

if(mnist_load_batch(&training) != training.batch_size) {
mnist_free(&training);
return 1;
}

if(mnist_init(&tests,
"mnist_data/t10k-images-idx3-ubyte",
"mnist_data/t10k-labels-idx1-ubyte",
0, 0
)) {
mnist_free(&training);
return 1;
}

if(mnist_load_batch(&tests) != tests.batch_size) {
mnist_free(&tests);
mnist_free(&training);
return 1;
}

assert(training.width == tests.width);
assert(training.height == tests.height);
assert(training.width != 0);
assert(training.height != 0);

genann *ann = genann_init(training.width * training.height,
atoi(argv[1]),
atoi(argv[2]),
CLASS_COUNT
);
assert(ann != NULL);

memset(output, 0, CLASS_COUNT * sizeof(double));

for(j = 0; j < atoi(argv[3]); j ++) {
for (i = 0; i < training.batch_size; ++i) {
printf("[Training number %d]: %zd%%\r",
j+1,
(100 * (i+1)) / training.batch_size
);

output[training.batch_entries[i].class] = 1;
genann_train(ann, training.batch_entries[i].pixels, output, 0.25);
output[training.batch_entries[i].class] = 0;
}
printf("\n");
}

FILE *output_file = fopen(argv[4], "w");
if(output_file) {
genann_write(ann, output_file);
fclose(output_file);
} else
perror("fopen");

genann_free(ann);

mnist_free(&training);
mnist_free(&tests);

return 0;
}

176 changes: 176 additions & 0 deletions mnist_db.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "mnist_db.h"
#include "utils.h"

int mnist_init(MnistDataset *output,
const char *images_file,
const char *labels_file,
int transpose,
size_t batch_size
)
{
size_t i;
double *buf;

if(!output)
return -1;

memset(output, 0, sizeof(MnistDataset));
output->transpose = transpose;

#ifndef _MSC_VER
output->fimage = fopen(images_file, "r");

if(!output->fimage) {
perror("fopen");
return -1;
}
#else
if(fopen_s(&output->fimage, images_file, "rb"))
return 1;
#endif

#ifndef _MSC_VER
output->flabel = fopen(labels_file, "r");

if(!output->flabel) {
perror("fopen");
return -1;
}
#else
if(fopen_s(&output->flabel, labels_file, "rb"))
return 1;
#endif

fseek(output->fimage, 4, SEEK_SET);

if(!fread(&output->entries_count, 4, 1, output->fimage)) {
perror("fread1");
return -1;
}

if(!fread(&output->width, 4, 1, output->fimage)) {
perror("fread2");
return -1;
}

if(!fread(&output->height, 4, 1, output->fimage)) {
perror("fread3");
return -1;
}

#ifdef LITTLE_ENDIAN
output->entries_count = CHANGE_ENDIANNESS(output->entries_count);
output->width = CHANGE_ENDIANNESS(output->width);
output->height = CHANGE_ENDIANNESS(output->height);
#endif /* LITTLE_ENDIAN */

if(batch_size != 0)
output->batch_size = batch_size;
else
output->batch_size = output->entries_count;

printf("Batch size: %zd; Width: %d; Height: %d\n",
output->batch_size, output->width, output->height);

output->batch_entries = malloc(sizeof(MnistEntry) * output->batch_size);
if(!output->batch_entries) {
perror("malloc");
return -1;
}

buf = malloc(sizeof(double) * output->width * output->height * output->batch_size);
if(!buf) {
perror("malloc");
return -1;
}

for(i = 0; i < output->batch_size; i ++) {
output->batch_entries[i].class = 0;
output->batch_entries[i].pixels = buf + i * (output->width * output->height);
}

return 0;
}

size_t mnist_load_batch(MnistDataset *dt)
{
size_t i, j;
size_t x, y;
double tmp;
MnistEntry *entry;
const size_t MNIST_ENTRY_SIZE = dt->width * dt->height;
unsigned char buf[MNIST_ENTRY_SIZE + 1];

if(dt->entries_read >= dt->entries_count)
dt->entries_read = 0;

for(i = 0; i < dt->batch_size; i ++, dt->entries_read ++) {
entry = &dt->batch_entries[i];

if(dt->entries_read >= dt->entries_count)
break;

fseek(dt->fimage, MNIST_ENTRY_SIZE * dt->entries_read + 16, SEEK_SET);
fseek(dt->flabel, dt->entries_read + 8, SEEK_SET);

/* Read the label */
if(!fread(buf, 1, 1, dt->flabel)) {
perror("fread");
break;
}
entry->class = (int) buf[0];

/* Read the image */
if(!fread(buf, MNIST_ENTRY_SIZE, 1, dt->fimage)) {
perror("fread");
break;
}

for(j = 0; j < MNIST_ENTRY_SIZE; j ++)
entry->pixels[j] = ((double) buf[j]) / 255.;
}

if(dt->transpose) {
for(i = 0; i < dt->batch_size; i ++) {
entry = &dt->batch_entries[i];
for(x = 0; x < dt->width; x ++) {
for(y = x+1; y < dt->height; y ++) {
/* Swap entry->pixels[x + dt->width * y] and entry->pixels[y + dt->width * x] */
tmp = entry->pixels[x + dt->width * y];
entry->pixels[x + dt->width * y] = entry->pixels[y + dt->width * x];
entry->pixels[y + dt->width * x] = tmp;
}
}
}
}

return i;
}

void mnist_free(MnistDataset *dt)
{
if(!dt)
return;

/*
Cette ligne de code fonctionne, car
elle repose sur le fait que les pixels
des différentes images soient sur un
même buffer contigue, et que l'addresse
du début de ce dit buffer correspond
à l'addresse du début de la première
MnistEntry, d'où ce free en particulier.
*/
fclose(dt->fimage);
fclose(dt->flabel);

free(dt->batch_entries[0].pixels);
free(dt->batch_entries);

memset(dt, 0, sizeof(MnistDataset));
}
48 changes: 48 additions & 0 deletions mnist_db.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#ifndef _MNIST_DB_H_
#define _MNIST_DB_H_

#include <stddef.h>
#include <stdio.h>

#define CHANGE_ENDIANNESS(a) \
( \
(((a) >> 0) & 0xFF) << 24 | \
(((a) >> 8) & 0xFF) << 16 | \
(((a) >> 16) & 0xFF) << 8 | \
(((a) >> 24) & 0xFF) << 0 \
)

typedef struct MnistEntry MnistEntry;
struct MnistEntry {
int class;
double *pixels;
};

typedef struct MnistDataset MnistDataset;
struct MnistDataset {
unsigned int width;
unsigned int height;
int transpose;

size_t entries_count;
size_t entries_read;

FILE *fimage;
FILE *flabel;

size_t batch_size;
MnistEntry *batch_entries;
};

#define CLASS_COUNT 10

/* Read a dataset from .
Returns -1 in case of a failure, and 0 otherwise. */
int mnist_init(MnistDataset *output, const char *images_file, const char *labels_file, int transpose, size_t batch_size);

size_t mnist_load_batch(MnistDataset *dt);

/* Libère la la base de donnée de la mémoire. */
void mnist_free(MnistDataset *dt);

#endif /* _MNIST_DB_H_ */