Skip to content

Commit 6d65b1a

Browse files
authored
Merge pull request #1050 from eclipse/ag_tablesaw_csv
Add table saw example
2 parents bc1bac6 + e06611e commit 6d65b1a

File tree

2 files changed

+122
-9
lines changed

2 files changed

+122
-9
lines changed

data-pipeline-examples/pom.xml

+32-9
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ information regarding copyright ownership.
3434
<properties>
3535
<dl4j-master.version>1.0.0-M1.1</dl4j-master.version>
3636
<!-- Change the nd4j.backend property to nd4j-cuda-X-platform to use CUDA GPUs -->
37-
<!-- <nd4j.backend>nd4j-cuda-10.2-platform</nd4j.backend> -->
37+
<!-- <nd4j.backend>nd4j-cuda-10.2-platform</nd4j.backend> -->
3838
<nd4j.backend>nd4j-native</nd4j.backend>
3939
<java.version>1.8</java.version>
4040
<maven-compiler-plugin.version>3.6.1</maven-compiler-plugin.version>
@@ -43,7 +43,10 @@ information regarding copyright ownership.
4343
<maven-shade-plugin.version>2.4.3</maven-shade-plugin.version>
4444
<logback.version>1.1.7</logback.version>
4545
<scala.binary.version>2.11</scala.binary.version>
46-
<spark.version>2.4.3</spark.version>
46+
<spark.version>2.4.8</spark.version>
47+
<tablesaw.version>0.38.5</tablesaw.version>
48+
<!-- Note: tablesaw needs a newer version of guava -->
49+
<guava.version>30.0-jre</guava.version>
4750
</properties>
4851

4952
<repositories>
@@ -62,6 +65,16 @@ information regarding copyright ownership.
6265
</repositories>
6366

6467

68+
<dependencyManagement>
69+
<dependencies>
70+
<dependency>
71+
<groupId>com.google.guava</groupId>
72+
<artifactId>guava</artifactId>
73+
<version>${guava.version}</version>
74+
</dependency>
75+
</dependencies>
76+
</dependencyManagement>
77+
6578
<dependencies>
6679
<dependency>
6780
<groupId>org.nd4j</groupId>
@@ -81,7 +94,7 @@ information regarding copyright ownership.
8194
<dependency>
8295
<groupId>org.datavec</groupId>
8396
<artifactId>datavec-spark_${scala.binary.version}</artifactId>
84-
<version>1.0.0-SNAPSHOT</version>
97+
<version>${dl4j-master.version}</version>
8598
</dependency>
8699
<dependency>
87100
<groupId>org.datavec</groupId>
@@ -102,12 +115,12 @@ information regarding copyright ownership.
102115
<groupId>org.deeplearning4j</groupId>
103116
<artifactId>deeplearning4j-ui</artifactId>
104117
<version>${dl4j-master.version}</version>
105-
<exclusions>
106-
<exclusion>
107-
<groupId>net.jpountz.lz4</groupId>
108-
<artifactId>lz4</artifactId>
109-
</exclusion>
110-
</exclusions>
118+
<exclusions>
119+
<exclusion>
120+
<groupId>net.jpountz.lz4</groupId>
121+
<artifactId>lz4</artifactId>
122+
</exclusion>
123+
</exclusions>
111124
</dependency>
112125
<dependency>
113126
<groupId>org.apache.spark</groupId>
@@ -139,6 +152,16 @@ information regarding copyright ownership.
139152
<artifactId>httpclient</artifactId>
140153
<version>4.3.5</version>
141154
</dependency>
155+
<dependency>
156+
<groupId>tech.tablesaw</groupId>
157+
<artifactId>tablesaw-core</artifactId>
158+
<version>${tablesaw.version}</version>
159+
</dependency>
160+
<dependency>
161+
<groupId>com.google.guava</groupId>
162+
<artifactId>guava</artifactId>
163+
<version>${guava.version}</version>
164+
</dependency>
142165
</dependencies>
143166

144167
<!-- Maven Enforcer: Ensures user has an up to date version of Maven before building -->
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*******************************************************************************
2+
*
3+
*
4+
*
5+
* This program and the accompanying materials are made available under the
6+
* terms of the Apache License, Version 2.0 which is available at
7+
* https://www.apache.org/licenses/LICENSE-2.0.
8+
* See the NOTICE file distributed with this work for additional
9+
* information regarding copyright ownership.
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14+
* License for the specific language governing permissions and limitations
15+
* under the License.
16+
*
17+
* SPDX-License-Identifier: Apache-2.0
18+
******************************************************************************/
19+
package org.deeplearning4j.datapipelineexamples.tablesaw;
20+
21+
import com.google.common.primitives.Doubles;
22+
import com.google.common.primitives.Ints;
23+
import org.deeplearning4j.datapipelineexamples.utils.DownloaderUtility;
24+
import org.nd4j.linalg.api.ndarray.INDArray;
25+
import org.nd4j.linalg.dataset.DataSet;
26+
import org.nd4j.linalg.factory.Nd4j;
27+
import org.nd4j.linalg.util.FeatureUtil;
28+
import tech.tablesaw.api.CategoricalColumn;
29+
import tech.tablesaw.api.DoubleColumn;
30+
import tech.tablesaw.api.Table;
31+
import tech.tablesaw.io.csv.CsvReadOptions;
32+
33+
import java.io.File;
34+
import java.util.Arrays;
35+
import java.util.stream.Collectors;
36+
37+
/**
38+
* This example uses the table saw library to prepare csv data for conversion to a neural network.
39+
* If you would like more information on tablesaw, please look at the table saw quickstart:
40+
* https://jtablesaw.github.io/tablesaw/gettingstarted
41+
*
42+
* This example leverages tablesaw to load a csv and convert it to a dataset object.
43+
*
44+
* @author Adam Gibson
45+
*/
46+
public class TablesawCSVExample {
47+
48+
public static void main(String...args) throws Exception {
49+
//download the data
50+
String directory = DownloaderUtility.IRISDATA.Download();
51+
//note our downloaded csv has no headers, so we want auto generated column names
52+
CsvReadOptions csvReadOptions = CsvReadOptions
53+
.builder(new File(directory, "iris.txt")).header(false).build();
54+
Table table = Table.read().csv(csvReadOptions);
55+
System.out.println(table.columnNames());
56+
//Convert the data without the label column to get just the raw input data out.
57+
Table justLabel = Table.create(table.column(4));
58+
Table withoutLabel = table.removeColumns(table.column(4));
59+
//convert the data to a double array filtering the column without
60+
double[][] data = Arrays.stream(withoutLabel.columnArray())
61+
.map(column -> (DoubleColumn) column)
62+
.map(input -> input.asList())
63+
.map(input -> Doubles.toArray(input))
64+
.collect(Collectors.toList())
65+
.toArray(new double[table.columnNames().size()][]);
66+
67+
//create the data from the array and print the data
68+
INDArray arr = Nd4j.create(data);
69+
System.out.println(arr.toStringFull());
70+
71+
//print the categories
72+
CategoricalColumn<?> objects = justLabel.categoricalColumn(0);
73+
System.out.println("List " + objects.asList());
74+
System.out.println(objects.countByCategory());
75+
76+
77+
//create an ndarray of the outcomes converted to categorical 0 1 labels
78+
int[] outcomes = Ints.toArray(justLabel.longColumn(0).asList());
79+
INDArray labels = FeatureUtil.toOutcomeMatrix(outcomes, 3);
80+
81+
82+
//create a dataset object containing the input and the labels
83+
DataSet dataSet = new DataSet(arr,labels);
84+
85+
86+
87+
}
88+
89+
90+
}

0 commit comments

Comments
 (0)