Search in sources :

Example 1 with PluginDownloader

use of com.alibaba.alink.common.io.plugin.PluginDownloader in project Alink by alibaba.

the class Chap03 method main.

/**
 * Entry point for the chapter examples: prints the Flink version and the plugin
 * directory before and after it is switched to {@code ALINK_PLUGIN_DIR}, lists every
 * available plugin with its versions, downloads plugins, and then runs the example
 * methods in order with batch parallelism set to 1.
 *
 * @param args command-line arguments (not read by this method)
 * @throws Exception if any configuration, download, or example step fails
 */
public static void main(String[] args) throws Exception {
    // Environment report: Flink version first.
    System.out.print("Flink version: ");
    System.out.println(AlinkGlobalConfiguration.getFlinkVersion());

    // Plugin directory before the override ...
    System.out.print("Default Plugin Dir: ");
    File defaultPluginDir = new File(AlinkGlobalConfiguration.getPluginDir());
    System.out.println(defaultPluginDir.getCanonicalPath());

    // ... and after pointing it at ALINK_PLUGIN_DIR.
    AlinkGlobalConfiguration.setPluginDir(ALINK_PLUGIN_DIR);
    System.out.print("Current Plugin Dir: ");
    File currentPluginDir = new File(AlinkGlobalConfiguration.getPluginDir());
    System.out.println(currentPluginDir.getCanonicalPath());

    // Print one "name => versions" line per available plugin.
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    for (String name : pluginDownloader.listAvailablePlugins()) {
        List<String> availableVersions = pluginDownloader.listAvailablePluginVersions(name);
        System.out.println(name + " => " + ArrayUtils.toString(availableVersions));
    }

    // Fetch one specific plugin version, then everything else.
    pluginDownloader.downloadPlugin("mysql", "5.1.27");
    pluginDownloader.downloadAll();

    BatchOperator.setParallelism(1);

    // Run the examples in their original order.
    c_1_1();
    c_1_2_1();
    c_1_2_2();
    c_1_3_1();
    c_1_3_2();
    c_2_1_1();
    c_2_1_2();
    c_2_2();
    c_2_3_1();
}
Also used : PluginDownloader(com.alibaba.alink.common.io.plugin.PluginDownloader) File(java.io.File)

Example 2 with PluginDownloader

use of com.alibaba.alink.common.io.plugin.PluginDownloader in project Alink by alibaba.

the class BertTextEmbeddingMapperTest method test.

/**
 * Checks the embeddings produced by {@code BertTextEmbeddingMapper} for one English and
 * one Chinese sentence against hard-coded reference vectors.
 *
 * <p>The test first downloads the plugins identified by the BERT resource register key
 * (BASE_CHINESE, SAVED_MODEL) and by {@code TFPredictorClassLoaderFactory}, then builds a
 * mapper over a two-column schema ({@code text_a}, {@code text_b}), maps one row per
 * language, parses field 1 of each result as space-separated doubles, and compares it to
 * the expected vector with an absolute tolerance of 1e-2.
 *
 * @throws Exception if plugin download, mapper setup, or mapping fails
 */
@Category(DLTest.class)
@Test
public void test() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);
    // Download the two plugins the mapper depends on: the BERT saved model and the TF predictor.
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = BertResources.getRegisterKey(ModelName.BASE_CHINESE, ResourceType.SAVED_MODEL);
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    // Mapper configuration: embed column "text_a" into "embed" and keep "text_b".
    Params params = new Params();
    params.set(BertTextEmbeddingParams.BERT_MODEL_NAME, "base-chinese");
    params.set(BertTextEmbeddingParams.SELECTED_COL, "text_a");
    params.set(BertTextEmbeddingParams.OUTPUT_COL, "embed");
    params.set(BertTextEmbeddingParams.RESERVED_COLS, new String[] { "text_b" });
    params.set(BertTextEmbeddingParams.LAYER, -2);
    params.set(BertTextEmbeddingParams.DO_LOWER_CASE, true);
    // NOTE(review): MAX_SEQ_LENGTH is explicitly set to null; presumably this selects a default - confirm.
    params.set(BertTextEmbeddingParams.MAX_SEQ_LENGTH, null);
    TableSchema dataSchema = TableSchema.builder().field("text_a", Types.STRING).field("text_b", Types.STRING).build();
    BertTextEmbeddingMapper mapper = new BertTextEmbeddingMapper(dataSchema, params);
    System.out.println(mapper.getOutputSchema());
    mapper.open();
    // Test English
    {
        Row result = mapper.map(Row.of("An english sentence.", "中文句子"));
        // this result is already compared with HuggingFace result with `np.allclose(a, b, atol=1e-6)`
        double[] expected;
        {
            expected = new double[] { 0.089073084, 0.20056348, 0.3516009, -0.29446188, 0.28716442, -0.8390936, 0.4182136, -0.18312955, 0.45383734, 0.65060294, 0.051138446, 0.22724059, 0.27704486, -0.65300393, 0.98041856, 0.09643227, 0.35706848, 0.6780398, 0.123599894, -0.18899295, -0.26279125, 0.44108105, 0.2103959, -0.38176265, 0.3546249, -0.56568205, -0.06408266, -0.27861258, 0.06763687, 0.1666355, 0.014923916, 0.007977423, -0.021746507, -0.2657324, 0.22584824, -0.7585038, 0.04854399, -0.009563353, 0.50774336, -0.3309519, 0.45124555, 0.33328262, -0.3069663, 0.14376236, -0.83023846, -0.36924392, 0.28073525, 0.007041442, 0.47512805, 0.02919788, -0.17385913, 8.108538, 0.5834048, 0.4401444, -0.23541903, 0.24149127, 0.48492688, -0.05407483, -0.1428802, -0.906088, -0.66015375, -0.07356628, -0.02404631, 0.5503326, -0.0069908244, -0.3908803, 0.6102739, -0.088781185, 0.19956489, 0.033619206, 0.18250264, -0.03669653, 0.5917836, -0.008031259, 0.2687996, 0.21792586, -0.612289, -0.14856814, 0.40163162, -0.06571546, -0.061800003, -0.077387005, -0.20518686, 0.37954232, -0.27198482, -0.047261536, 0.26703018, 0.23476101, -0.035176348, 0.38439688, 0.1560608, 0.08167325, -0.5239277, 0.21489042, -0.07825369, 0.14580987, -0.056497473, 0.68141, -0.020929975, -0.060883746, 0.14015076, 0.38673988, -0.017782483, -0.33661842, -0.17964366, 2.0730247, 0.25984696, -0.06799029, -0.020712927, -0.4332832, -0.406993, 0.48468983, -0.7822486, -0.2704517, -0.009232765, 0.31674927, -0.44500858, 0.29990223, 0.19996054, -0.7969638, 0.4079611, -0.10532191, 0.4765936, 0.78918314, 0.47114816, 0.29029685, -0.058260716, -0.20847711, 0.044183735, 0.3809099, -0.11276442, -0.111197315, 0.30226305, -0.60424334, -0.11220238, 0.034364242, 0.3626102, -0.23099208, -0.67327315, 0.5894861, -0.12862045, 0.16448429, -0.6982243, -0.31022722, 0.44725528, 0.29328576, 0.63399965, 0.8528066, -0.333605, -0.20188892, -0.13903502, -0.2719082, -0.005485626, -0.13167712, 0.12400101, 0.1406203, 0.5087689, 0.0786869, -0.16825227, 
0.016926207, 0.27797514, 0.5736567, 0.41123697, -0.29326963, -0.04540349, -0.4771427, 0.33664274, -0.61130226, 0.010852352, 0.5919538, -0.40366083, -0.6581887, -0.1263123, 0.06560425, -0.27138802, 0.055331457, -0.5859766, -0.4126649, -0.05864937, 0.18987034, -1.323727, -0.17399509, 0.16118336, 0.37771478, -0.14546172, -0.56139463, 0.3935274, -0.09063497, -0.06971432, 0.06910816, 0.40521935, -0.12895046, 0.03743412, -0.21655117, 0.37266493, 0.3083255, 0.1077936, 0.31159502, -0.20470755, 0.29801783, -0.92609096, -0.41326338, -0.4643464, 0.12998559, 0.052346323, -0.010444339, 1.2203202, 0.39612085, -0.11651037, -0.35085708, 0.50413233, -0.05850105, -0.19116706, 0.84110314, -0.07596079, -0.48562497, 0.074693196, 0.39990926, 0.6280287, -0.273503, 0.14537963, 0.23487864, 0.1028335, 0.5833937, -0.36302227, 0.39023992, -0.05183344, 0.4928117, -0.6357919, 0.28944543, 0.15658323, 0.4429412, -0.23182562, 0.5952031, 0.32105538, -0.15936747, -0.22350433, -0.22395854, 0.79166734, 0.1855918, -0.029996136, 0.24539514, -0.21128504, -0.1549417, 0.56421924, 0.19268781, -0.25786325, -0.17128454, -0.058976505, 0.30126545, -0.06038547, 0.36465937, -0.23425841, -0.3668297, -0.43959534, 0.38348344, 0.018327482, 0.4123795, -0.04099132, 0.012244595, 0.49430224, -0.32182, -0.6019442, -0.015638014, 0.51026505, 0.1068835, 0.03532704, -1.1697477, 0.73523736, 0.6450712, -0.94440675, -0.2517156, 0.05954993, 0.3661471, -0.33924416, -0.64593345, -0.16584522, 0.184489, 0.36938062, 0.014731329, -0.7576283, -0.20003356, 0.43362293, 0.31597397, 0.29983208, 0.051825784, -0.070177235, -0.43420777, 0.19157119, -0.72922254, -0.57415557, -0.6044294, -0.16493118, -0.11360306, 3.107155, -0.41438738, 0.07558526, -0.27004382, 0.3984214, 0.40921322, -0.8065814, 0.42056122, -0.6024486, 0.6556587, 0.7679883, -0.24327908, -0.037827346, 0.08392061, 0.19684246, 0.023937171, -0.2301765, -0.017345892, 0.86922956, -0.4563738, -0.07474198, -0.61980855, 0.4433976, -0.41383296, 0.1086806, -0.15101764, -0.15205815, 
-0.08962966, 0.053946137, 0.6667063, 0.53586006, -0.09328938, 0.66806424, -1.0493509, 0.6348454, -0.31382576, -0.025229406, 0.14783347, 0.10988949, 0.67091316, 0.08754399, -0.047084138, 0.19628495, -0.1960672, 0.008753071, -0.4863474, -0.42654783, 0.07900074, -0.5045268, 0.28226477, -0.68395984, -0.03362143, -0.014122, -0.12746842, 0.20408744, -0.23023072, 0.25175712, 0.22371025, -0.58425987, -0.39206004, -0.7467269, -0.2676628, -0.6527664, 0.40715516, 0.2686616, 0.22032455, 0.6499737, 0.025618043, -0.013638815, -0.06677554, 0.5073563, 0.48881236, -0.10977107, 0.27756214, 0.43145332, -0.391, 0.2766744, 0.25780347, 0.19996944, 0.47653675, -0.21553159, 0.42380095, 0.042722337, -0.24715047, 0.6930257, -0.19161859, -0.44895643, 0.07181698, -0.47244114, 0.23078291, 0.5319174, 0.09428276, 0.88944554, -0.7046152, -0.33722076, -0.25829387, -0.6431879, 0.36444882, 0.35961518, 0.6304723, 0.26947102, 0.1360612, 0.43042153, -0.3797312, 0.33820257, -0.49372354, 0.48026037, 0.17663732, 0.31045946, -0.09340149, -0.2514405, 0.13995169, 0.319726, -0.093683735, 0.8692651, -0.39895487, -0.18782426, 0.44306785, 0.08950553, 0.17533414, 0.46815917, -0.19433542, 0.4034581, 0.017003069, 0.380734, -0.14529721, -0.3129647, 0.23608868, 0.2613641, -0.5531271, 0.5769198, -0.23282692, -0.34113243, 0.19667614, -0.40514293, 0.45187485, -0.36073324, -0.56202817, -0.29550225, 0.6506499, -0.7111385, 0.0374949, -0.6610919, -0.1571065, -0.24706063, -0.37649712, 0.042870417, -0.54221517, -0.41191533, 0.08726144, 0.4776909, 0.17190549, -0.06772761, 0.25872043, -0.27678707, 1.0068991, 1.0665413, 0.38176787, -0.5344325, 0.08187985, -0.5407639, -0.014430046, -0.1571247, -0.24232443, 0.32073474, -0.36651748, -0.5973171, -0.38527447, 0.24909297, -0.03317968, 0.53609484, 0.16727896, -0.12671076, -0.60251546, -0.16898218, -0.43730286, -0.07012422, 0.055398405, 0.0055612167, 0.04978605, -0.35225272, 0.79142636, 0.4752555, -4.0751882E-4, -0.49435905, -0.11460393, 0.21916138, -0.08201371, 0.33711022, -0.2675837, 
-0.20002404, 0.020791233, 0.14494887, -0.37943113, -0.114889406, 0.25731784, -0.47965208, 0.33978042, 0.29868034, 0.48542416, -0.23872109, -0.5013446, 0.37800694, -0.183577, 0.34225625, -0.17961891, -0.47093594, -0.036617823, 0.007107933, -1.1118665, -0.80339146, 0.18684658, 0.31432015, 0.19632326, -0.3347409, 0.62178373, 0.0131239975, -0.4607244, 0.12699743, -0.20341435, 0.5638455, 0.0043124957, 0.017100042, 0.01631973, -0.09410723, 0.41816282, -0.25815982, 0.42425865, 0.5635872, -0.15940957, 0.1961887, 0.18360351, 0.23237108, -0.07556974, -0.2736857, 0.23201534, -0.16999784, -0.26305136, 0.03616975, 0.1490356, -0.31567413, -0.43173206, -0.14945477, -0.486137, -0.24025288, 1.024534, 0.56033677, -0.589806, -0.030385733, 0.39098585, -0.15309268, 0.5764666, -0.42667878, 0.039409995, 0.25750253, -0.068650514, -0.0400423, 0.0791655, 0.41124946, 0.030511046, -0.56712395, 0.049367446, 0.33574924, -0.18492404, -0.42992696, 0.06730447, -0.867941, -0.14779219, -0.06794728, -0.14499989, -0.66920936, -0.07477878, -0.1496128, -0.12472385, 0.32638785, 0.16772252, 0.18236694, 0.2806634, -0.43957776, 0.063719206, -0.047652632, -0.023638368, 0.08651493, -0.24735723, 0.14668958, -0.38661274, -0.18371972, 0.13630907, -0.047801457, -0.062571496, -0.06973621, 0.08127009, 0.39519763, -0.06686574, 0.36243415, 0.42805952, -0.26940322, -0.39546525, 0.703899, 0.76100504, -0.11092762, -0.16995934, 0.3018224, 0.02022123, 0.0039080214, 0.05129872, 0.022262864, 0.10892585, -0.6039272, -0.12820105, 0.36593267, 0.2438417, 0.11985212, -0.23406501, -0.45799938, 0.14650963, 0.2928344, -0.18667708, 0.5859505, -0.32387814, -0.31242442, -0.26415035, -0.2947386, 0.43614882, -0.20847061, -0.0026533732, -0.82337356, -0.5180675, 0.47124967, -0.54627514, 0.8065849, 0.066406906, 0.21717988, -0.07597586, 0.2015869, -0.45065638, -0.23080602, -0.3306711, 0.53038013, 0.17238441, -0.05583993, -0.34874824, -0.030980049, 0.3712661, -0.28410155, -0.47641006, 0.2571871, 0.34637967, -0.008641526, -0.22793008, 
-0.17121935, -0.075603165, 0.63212615, -0.2495372, 0.54692096, -0.27013966, 0.65434086, -0.26032814, 4.6927598E-4, -0.2872511, 0.6838579, -0.010238962, 1.172118, -0.21562524, 0.21527323, -0.25726637, -0.059714526, 0.2502196, 0.61774695, -0.2925442, 0.31493646, 0.3539606, -0.21500365, 0.29216832, 0.17060214, 1.1769214, -0.48165116, -0.007944281, 0.31471652, -0.28504306, -0.106352694, 0.3174839, 0.29149395, 0.45813614, 0.09905476, -0.3644543, 0.8294207, -0.28475535, 0.4997395, 0.26549652, 0.3360206, 0.3276545, -0.4240916, -0.26481426, 0.5335953, -0.13540079, -0.42259574, 0.4970337, 0.90958863, -0.34506807, -0.71407956, -12.573184, -0.015176255, 0.014076395, 0.08243999, 0.22496581, 0.7837829, 0.24365745, -0.07991051, -0.32745788, 0.29997948, -0.2765685, 0.09622552, 0.7834622, -0.09733227, 0.513706, 0.32021743, 0.3571033, 0.08759855, -0.8603887, 0.38747987, 0.15559518, -0.47555524, -0.06449654, 0.30913618, -0.3591024, 0.27649412, -0.33506414, 0.6148581, 0.26756653, 0.22131035, 0.455839, -0.2536694, 0.2805701, 0.1510673, 0.06973805, 0.27653408, -0.7637936, 0.56921303, -0.21870066, 0.0030567697, -0.45575276, -0.35519856, 0.6791311, -0.16410932, 0.9783003, -0.7080118, 0.33462662, -0.19515315, 0.008544347, -0.057110723, -0.4179183, -0.057121307, -0.045918223, -0.00953709, 0.4432715, -0.3628051, 0.3457998, 1.2684398, 0.25004348, 0.06620354, -0.64985204, -0.7835815, -0.14532885, -0.10538084, 0.27382025, 0.037735656, -0.46953905, 0.013022469, -0.34265116, -0.62361187, -0.40610304, 0.2656766, 0.22653538 };
        }
        // Field 1 is parsed as a space-separated list of doubles (presumably the "embed" column - confirm).
        double[] actuals = Arrays.stream(((String) result.getField(1)).split(" ")).mapToDouble(Double::parseDouble).toArray();
        // Element-wise comparison with an absolute tolerance of 1e-2.
        Assert.assertArrayEquals(expected, actuals, 1e-2);
    }
    // Test Chinese
    {
        Row result = mapper.map(Row.of("这是一个中文句子", "English sentence"));
        // this result is already compared with HuggingFace result with `np.allclose(a, b, atol=1e-6)`
        double[] expected;
        {
            expected = new double[] { -0.67530066, 0.61594594, -0.056693964, -0.5887352, 0.789959, -1.0262289, 0.17976294, -0.13194971, -0.13077053, 0.72640383, -0.37538335, 0.2723307, 0.23495108, -0.17314787, 0.5977942, -0.031309847, -0.059140697, 0.06078338, 0.6097555, -0.102226615, 0.3175769, -0.13020651, 0.4198398, -0.053274635, 0.29587185, -0.44391736, 0.5419951, -0.70908546, 0.3402457, 0.42862296, -0.091947116, 0.08983829, -0.1262072, -0.46288097, 0.6739168, 0.6550396, -0.44469208, 0.12572655, 0.5656046, 0.014711852, 0.13476004, -0.39801106, 0.24407381, -0.078456834, -0.69440216, -0.9241711, 0.36084417, 0.12529421, 0.2964751, -0.13504946, 0.5846213, 9.069189, 0.6926915, -0.006680295, -0.053333905, 0.20349851, 0.03051044, 0.082488894, 0.20274848, -0.26723158, -0.27928227, -0.32986775, 0.18134634, 0.76559615, -0.08205241, -0.3319447, 0.35764945, 0.44840956, -0.20678492, 0.09583793, 0.07100839, -0.07977803, -0.26061502, 0.40102804, 0.4134229, -0.026218204, -0.18589236, -0.25868884, 0.11865289, 0.07374452, -0.22874112, 0.15793541, -0.30735266, 0.29275736, -0.5740867, -0.4128435, -0.5746707, 0.5775771, 0.24116059, 0.053272687, -0.9665233, 0.019698314, -0.5611463, -0.06428814, -0.07146941, 0.6884945, -0.081167884, 0.13943768, -0.2565966, 0.12531683, -0.8389492, -0.18512477, -0.6811464, 1.0172554, -0.3174257, 1.1349304, 0.037906446, 0.03227273, 0.13568033, 0.19226004, -0.29172787, 0.5762233, -0.3751546, -0.34556678, 0.05894437, -0.23264313, -0.19717728, 0.34370446, 0.10671129, -0.09483788, 0.09413337, 0.0120273605, 0.3687146, 0.016466876, 0.26079795, 0.12096234, 0.06536596, -0.2846961, -0.4254726, -0.086970545, -0.3523813, 0.27028432, 0.33481583, -0.4809917, -0.13440402, 0.10722831, -0.13411504, 0.25703508, -0.7368167, 0.5896459, 0.34396634, -0.37947753, -0.49355638, -0.2119609, 0.4325806, 0.6902696, 0.31320798, 0.56302667, 0.09417565, 0.10611598, -0.13109058, -0.3317111, 0.15148342, 0.49761084, 0.04269329, -0.21372125, 0.07363319, 0.008428436, -0.024836902, 
-0.42088622, -0.08528121, 0.52482665, 0.009306508, 0.084471785, -0.16503625, -0.017703086, -0.040039614, -0.39058322, -0.44110668, 0.03559628, 0.15294297, 0.05129311, -0.1225929, 0.0453198, 0.15047674, 0.38232255, -0.19527261, 0.07444011, 0.23059799, -0.108181946, -0.53018343, -0.06947766, 0.06221337, -0.21450219, -0.26703936, 0.1363355, -0.2144355, 0.1846843, -0.3489314, -0.44297603, 0.16135284, -0.16144319, 0.12989485, -0.4466489, 0.3299668, -0.33318454, 0.005696993, 0.106931396, 0.03200325, 0.71382195, -0.5061703, -0.51759875, 0.19212422, 0.04083005, -0.069712356, -0.27714688, 0.62951106, 0.4787875, 0.07156138, -0.040474653, 0.5792815, 0.094450206, 0.38409925, 0.044219777, 0.44535616, 0.61113906, 0.25654292, -0.23699701, 0.4121236, 0.11646683, 0.1060531, 0.063176244, -0.15068226, -0.36154372, 0.100659974, -0.06296361, 0.47572947, 0.22414996, -0.26523653, -0.30621967, 0.21636419, 0.30424714, -0.3487851, 0.046763465, 0.07107177, 0.18565592, -0.12067094, 0.0013190955, 0.14982918, -0.1791432, 0.2638161, 0.08431108, 0.13931157, -0.30190447, -0.8309533, -0.20725285, -0.15094705, -0.06601327, 0.4640572, 0.14468326, 0.023498995, 0.18870713, -0.34320956, -0.661373, -0.30789602, 0.6710894, 0.23725967, 0.29611212, -0.1190055, 0.1204317, -0.06637057, 0.49064258, -0.71295047, -0.050598703, 0.67767906, 0.19402614, -0.32084003, -0.5663817, 0.45377412, 0.43638042, -0.3744999, -0.13144939, 0.14107968, -0.038628887, 0.4568751, -0.23326007, 0.36349088, -0.51840705, -0.12847212, 0.1377168, -0.19077566, -0.05172722, 0.052513633, -0.41866693, 0.3356778, -0.9116799, 0.022393376, 0.42054072, -0.25010934, -0.26275656, -0.24738365, -0.36184305, -0.008754494, -0.5345826, 1.9404162, -0.15192655, -0.04973651, -0.383245, -0.047098756, 0.36554, -0.92108303, 0.5524133, -0.49321863, 0.40408948, -0.104516745, 0.1490599, 0.2319855, 0.19936046, 0.43532246, 0.2525194, 0.20589466, 0.17501022, 0.09644188, -0.6353579, 0.12638886, -0.26842985, 0.16125616, 0.0662491, 0.36832082, -0.18316424, 0.23531684, 
-0.064759985, 0.06317103, 0.13060418, 0.11981614, -0.46983102, -0.021113481, -0.6405569, 0.48385215, 0.22422825, -0.47484398, -0.43302104, -0.39680895, 0.2969588, 0.71188235, 0.35006645, 0.1453332, 0.38688833, 0.087308414, -0.16717857, 0.071094304, -0.357324, 0.092632696, -0.5404178, -0.15791667, 0.11731309, 0.6964182, -0.5139237, 0.0370673, -0.32018265, -0.3304256, 0.18350527, -0.2327405, -0.2806859, -0.46063206, 0.20929328, -0.9721338, 0.22960168, -0.4957172, 0.1716004, 0.9903108, -0.05619159, 0.3180037, 0.37226647, 0.15360224, -0.09633264, 0.07887778, -0.055395328, 0.43275484, -0.57034624, 0.15261361, -0.21941976, 0.2996632, -0.14595373, 0.46338066, -0.33412176, -0.07854297, -0.2309417, -0.47490707, -0.17874378, -0.71656114, 0.5268188, 0.2117231, -0.14560121, 0.36597407, 0.18142216, 0.25824717, -0.07349595, 0.20754527, -0.101178385, -0.31051034, 0.8227749, -0.47965798, 0.647296, 0.08015231, 0.4955694, 0.13666879, -0.5208422, 0.40139586, -0.97117084, -0.011994401, 0.5233678, -0.056755953, -0.51033676, -0.17177448, -1.5431503, 0.5889413, 0.6228572, 0.63479817, -0.84462154, 0.18693976, 0.3436292, 0.54034567, -3.331292E-4, 0.22284168, -0.13979425, 0.89563626, 0.24064341, 0.37815812, -0.3982351, 0.29792884, 0.46231395, 3.9599428E-4, 0.06392538, -0.039604213, -0.5172013, -0.9095762, -0.055725507, 0.44805524, 0.076555304, 0.18504235, 0.20971671, -0.15783177, 0.4367001, 0.27264124, 0.6719367, -0.13801299, 0.32202727, -0.7289231, 0.49216944, 0.16491476, -0.54049265, -0.349262, 0.15989335, 0.07725942, -0.059543356, 0.20832932, -0.2788309, 0.46147618, 0.377395, 0.6870447, -0.36773545, -0.33286142, -0.00564318, -0.34101346, 0.31051284, 0.48842332, -0.4044629, 0.042303976, -0.17452177, 0.252395, 0.13060617, 0.18446071, 0.3588017, 0.07421272, 0.61265516, 0.21274248, -0.11963665, 0.41415054, -0.54229313, -0.12186932, -0.28478304, 0.10421303, 0.16568889, 0.008618522, -0.027819963, 0.06666305, 0.3182098, -0.9028614, -0.6005104, 0.4919864, -0.1461413, 0.7291196, -0.3996712, 
0.5876187, 0.004493188, 0.025345238, 0.3169623, 0.6749967, 0.39498597, -0.5359254, -0.095082656, 0.48348033, -0.098074004, -0.2801097, 0.06363764, 0.65557885, 0.40710562, 0.12931323, 0.2650215, 0.3964837, -0.25598568, 0.49602374, -1.0051873, -0.34377953, -0.09127153, -0.09675396, 0.1854866, -0.40283024, -0.38420478, -0.047086354, 0.0022745468, -0.3908877, -0.41917852, 0.23628148, -0.045927297, 0.579911, -0.0176501, 0.78772616, 0.68177736, -0.30374396, 0.6266679, 0.5441422, 0.02987818, 0.061886575, 0.77770215, -0.07932908, -0.33103776, -0.28426328, -0.08212303, 0.042760104, -0.22247449, -0.16209176, 0.2573553, 0.082359865, 0.39242637, -0.29822388, -0.1584899, -0.3133868, 0.85202396, 0.06537342, -0.6039086, -0.015240644, 0.026361527, 0.19257236, 0.44367018, -0.43892404, -0.2818918, -0.1333324, 0.30089086, -0.06731736, -0.3561126, 0.30563122, 0.19338936, 0.23238602, -0.5528297, 0.3175812, 0.3169552, 0.16785741, -0.13066062, -0.5738654, 0.0655802, -0.41268015, 0.8220674, -0.8956028, 0.26038638, -0.30397677, 0.36395076, -0.030978065, -0.19425192, -0.42041972, 0.28871268, 0.27627385, -0.15289834, 0.2192859, -0.35178152, 0.050021175, 0.4931296, 0.37743875, -0.85979533, 0.41984963, 0.2929396, -0.20290282, 0.31495363, -0.4516361, 0.05143035, 0.47896478, 0.32632774, 0.42514107, 0.3382379, -0.19836453, 0.06697344, 0.24137944, 0.7640314, -0.3762555, 0.013103037, 0.48038802, 0.45538932, -0.33745858, -0.17517568, -0.3041898, -0.2591084, 0.32268375, 0.07433881, 0.62003034, -0.10828473, -0.1459326, 0.10252112, -0.02565641, 0.33866906, 0.054388985, -0.29714227, 0.3316252, 0.36335745, -0.16429752, -0.37883237, 0.27983367, -0.20368642, 0.027274836, -0.19377503, -0.42317852, -0.32567635, 0.32197624, 0.11617374, 0.49373147, 0.24056932, 0.34718528, -0.19460575, 0.22410084, -0.20902018, -0.548216, -0.25581166, -0.0356154, -0.07507834, 0.24492535, 0.2095812, 0.34545565, 0.5623524, 0.55008215, 0.07757254, -0.368302, 0.1641082, 0.6412437, 0.5775129, 0.0068530664, -0.21834902, 1.0380987, 
-0.56366926, 0.28795698, -0.6723282, 0.3331007, 0.32735845, 0.0013436913, 0.14073966, 0.6418032, -0.07251156, 0.46461236, -0.3980294, -0.26899946, 0.28307268, 0.04513643, 0.46508798, 0.98492813, 0.0033691209, 0.6478122, 0.45498514, 0.03501081, -0.6224758, -0.60526854, 1.4630777, -0.21021096, -0.1361977, 0.35714173, -0.22630489, -0.088193096, 0.038374603, 0.6149385, 0.14970909, 0.24360953, -0.40581468, 0.53937286, 0.01577225, -0.44539908, -0.012677344, 0.08264979, 0.0011214706, 0.45130453, -0.12674686, -0.4381374, -0.26769787, 0.09133074, 0.5127477, 0.40458146, -0.41203475, -0.78492004, -13.184955, -0.12402881, 0.28563613, -0.451989, 0.061025247, -0.29014575, 0.26550668, 0.22826888, -0.008380551, -0.6564349, -0.596356, 0.13780147, 0.398271, 0.02930097, 0.14554422, -0.114068694, 0.49568233, 0.46178737, -0.76289266, 0.2996972, 0.13880418, 0.43003055, -0.25730652, 0.495184, -0.4609631, -0.4219708, -0.12821707, 0.21092743, -0.36644837, 0.7974993, -0.08656667, -0.5314165, -0.07028645, 0.16069822, -0.18608248, 0.61463374, -1.1505986, -0.027690325, -0.13950048, 0.1695784, 0.5994827, -0.17978896, 0.022879254, -0.035342228, 0.7042892, -0.53858364, 0.1631006, -0.24884132, -0.09085574, -0.123774245, 0.026201617, 0.031942204, -0.41603425, -0.14774984, 0.9545462, -0.2746995, -0.118844785, 0.52170795, -0.11141356, -0.0073091295, -0.10394416, -0.73148096, 0.19561234, -0.0794713, 0.32161495, 0.046833698, -0.25482565, -0.30924103, -0.47611946, 0.08269883, 0.3996641, 0.70499754, -0.46837986 };
        }
        // Same parse-and-compare procedure as the English case above.
        double[] actuals = Arrays.stream(((String) result.getField(1)).split(" ")).mapToDouble(Double::parseDouble).toArray();
        Assert.assertArrayEquals(expected, actuals, 1e-2);
    }
    mapper.close();
}
Also used : PluginDownloader(com.alibaba.alink.common.io.plugin.PluginDownloader) TableSchema(org.apache.flink.table.api.TableSchema) BertTextEmbeddingParams(com.alibaba.alink.params.tensorflow.bert.BertTextEmbeddingParams) Params(org.apache.flink.ml.api.misc.param.Params) Row(org.apache.flink.types.Row) RegisterKey(com.alibaba.alink.common.io.plugin.RegisterKey) Category(org.junit.experimental.categories.Category) Test(org.junit.Test) DLTest(com.alibaba.alink.testutil.categories.DLTest)

Example 3 with PluginDownloader

use of com.alibaba.alink.common.io.plugin.PluginDownloader in project Alink by alibaba.

the class BaseTFSavedModelPredictMapperTest method testString.

/**
 * Runs {@code BaseTFSavedModelPredictMapper} over the MNIST CSV sample with a downloaded
 * TensorFlow SavedModel and checks that the output schema appends {@code classes} and
 * {@code probabilities} to the input columns and that the two input fields are passed
 * through unchanged for every row.
 *
 * @throws Exception if plugin download, data/model download, unzip, or prediction fails
 */
@Category(DLTest.class)
@Test
public void testString() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);
    // The TF predictor plugin must be present before the mapper is opened.
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(2);
    String url = "https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/mnist_dense.csv";
    String schema = "label bigint, image string";
    BatchOperator<?> data = new CsvSourceBatchOp().setFilePath(url).setSchemaStr(schema).setFieldDelimiter(";");
    List<Row> rows = data.collect();
    // Download the model archive into a fresh temp directory.
    String modelPath = "http://alink-dataset.oss-cn-zhangjiakou.aliyuncs.com/tf/1551968314.zip";
    String workDir = PythonFileUtils.createTempDir("temp_").toString();
    String fn = FileDownloadUtils.downloadHttpOrOssFile(modelPath, workDir);
    String localModelPath = workDir + File.separator + fn;
    // Label and path are separated by ": ", matching the sibling tests' log output.
    System.out.println("localModelPath: " + localModelPath);
    if (localModelPath.endsWith(".zip")) {
        // Unzip next to the archive; the model directory is the archive path minus ".zip".
        File target = new File(localModelPath).getParentFile();
        ZipFileUtil.unZip(new File(localModelPath), target);
        localModelPath = localModelPath.substring(0, localModelPath.length() - ".zip".length());
        Preconditions.checkArgument(new File(localModelPath).exists(), "problematic zip file.");
    }
    Params params = new Params();
    params.set(HasModelPath.MODEL_PATH, localModelPath);
    params.set(HasSelectedCols.SELECTED_COLS, new String[] { "image" });
    params.set(HasOutputSchemaStr.OUTPUT_SCHEMA_STR, "classes bigint, probabilities string");
    BaseTFSavedModelPredictMapper baseTFSavedModelPredictMapper = new BaseTFSavedModelPredictMapper(data.getSchema(), params);
    baseTFSavedModelPredictMapper.open();
    try {
        Assert.assertEquals(TableSchema.builder().field("label", Types.LONG).field("image", Types.STRING).field("classes", Types.LONG).field("probabilities", Types.STRING).build(), baseTFSavedModelPredictMapper.getOutputSchema());
        for (Row row : rows) {
            Row output = baseTFSavedModelPredictMapper.map(row);
            // The input columns must be carried through untouched.
            Assert.assertEquals(row.getField(0), output.getField(0));
            Assert.assertEquals(row.getField(1), output.getField(1));
        }
    } finally {
        // Close the mapper even when an assertion or map() call fails.
        baseTFSavedModelPredictMapper.close();
    }
}
Also used : PluginDownloader(com.alibaba.alink.common.io.plugin.PluginDownloader) Params(org.apache.flink.ml.api.misc.param.Params) Row(org.apache.flink.types.Row) RegisterKey(com.alibaba.alink.common.io.plugin.RegisterKey) File(java.io.File) CsvSourceBatchOp(com.alibaba.alink.operator.batch.source.CsvSourceBatchOp) Category(org.junit.experimental.categories.Category) Test(org.junit.Test) DLTest(com.alibaba.alink.testutil.categories.DLTest)

Example 4 with PluginDownloader

use of com.alibaba.alink.common.io.plugin.PluginDownloader in project Alink by alibaba.

the class BaseTFSavedModelPredictRowFlatMapperTest method testTensor.

/**
 * Runs {@code BaseTFSavedModelPredictRowFlatMapper} over 1000 in-memory rows holding a
 * 28x28 {@code FloatTensor} and checks the output schema, that exactly one output row is
 * produced per input row, that the input fields are passed through unchanged, and that
 * the {@code probabilities} tensor (field 3) has shape {@code [10]}.
 *
 * @throws Exception if plugin download, model download, unzip, or prediction fails
 */
@Category(DLTest.class)
@Test
public void testTensor() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);
    // The TF predictor plugin must be present before the mapper is opened.
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(2);
    // Synthetic input: 1000 rows of (label, 28x28 tensor).
    List<Row> rows = new ArrayList<>();
    for (int i = 0; i < 1000; i += 1) {
        Row row = Row.of(0, new FloatTensor(new Shape(28, 28)));
        rows.add(row);
    }
    BatchOperator<?> data = new MemSourceBatchOp(rows, "label LONG, image FLOAT_TENSOR");
    // Download the model archive into a fresh temp directory.
    String modelPath = "http://alink-dataset.oss-cn-zhangjiakou.aliyuncs.com/tf/1551968314.zip";
    String workDir = PythonFileUtils.createTempDir("temp_").toString();
    String fn = FileDownloadUtils.downloadHttpOrOssFile(modelPath, workDir);
    String localModelPath = workDir + File.separator + fn;
    System.out.println("localModelPath: " + localModelPath);
    if (localModelPath.endsWith(".zip")) {
        // Unzip next to the archive; the model directory is the archive path minus ".zip".
        File target = new File(localModelPath).getParentFile();
        ZipFileUtil.unZip(new File(localModelPath), target);
        localModelPath = localModelPath.substring(0, localModelPath.length() - ".zip".length());
        Preconditions.checkArgument(new File(localModelPath).exists(), "problematic zip file.");
    }
    Params params = new Params();
    params.set(HasModelPath.MODEL_PATH, localModelPath);
    params.set(HasSelectedCols.SELECTED_COLS, new String[] { "image" });
    params.set(HasOutputSchemaStr.OUTPUT_SCHEMA_STR, "classes LONG, probabilities FLOAT_TENSOR");
    BaseTFSavedModelPredictRowFlatMapper baseTFSavedModelPredictRowFlatMapper = new BaseTFSavedModelPredictRowFlatMapper(data.getSchema(), params);
    baseTFSavedModelPredictRowFlatMapper.open();
    List<Row> outputs = new ArrayList<>();
    try {
        Assert.assertEquals(TableSchema.builder().field("label", Types.LONG).field("image", TensorTypes.FLOAT_TENSOR).field("classes", Types.LONG).field("probabilities", TensorTypes.FLOAT_TENSOR).build(), baseTFSavedModelPredictRowFlatMapper.getOutputSchema());
        ListCollector<Row> collector = new ListCollector<>(outputs);
        for (Row row : rows) {
            baseTFSavedModelPredictRowFlatMapper.flatMap(row, collector);
        }
    } finally {
        // Close the mapper even when an assertion or flatMap() call fails.
        baseTFSavedModelPredictRowFlatMapper.close();
    }
    // Check the row count first so a short output list fails as an assertion,
    // not as an IndexOutOfBoundsException from outputs.get(i) below.
    Assert.assertEquals(rows.size(), outputs.size());
    for (int i = 0; i < rows.size(); i += 1) {
        Row row = rows.get(i);
        Row output = outputs.get(i);
        Assert.assertEquals(row.getField(0), output.getField(0));
        Assert.assertEquals(row.getField(1), output.getField(1));
        // JUnit expects (expected, actual); this order keeps the failure message accurate.
        Assert.assertArrayEquals(new long[] { 10 }, ((FloatTensor) output.getField(3)).shape());
    }
}
Also used : Shape(com.alibaba.alink.common.linalg.tensor.Shape) ArrayList(java.util.ArrayList) Params(org.apache.flink.ml.api.misc.param.Params) MemSourceBatchOp(com.alibaba.alink.operator.batch.source.MemSourceBatchOp) PluginDownloader(com.alibaba.alink.common.io.plugin.PluginDownloader) ListCollector(org.apache.flink.api.common.functions.util.ListCollector) FloatTensor(com.alibaba.alink.common.linalg.tensor.FloatTensor) Row(org.apache.flink.types.Row) RegisterKey(com.alibaba.alink.common.io.plugin.RegisterKey) File(java.io.File) Category(org.junit.experimental.categories.Category) Test(org.junit.Test) DLTest(com.alibaba.alink.testutil.categories.DLTest)

Example 5 with PluginDownloader

use of com.alibaba.alink.common.io.plugin.PluginDownloader in project Alink by alibaba.

the class BaseTFSavedModelPredictRowFlatMapperTest method testString.

/**
 * Runs {@code BaseTFSavedModelPredictRowFlatMapper} over the MNIST CSV sample with a
 * downloaded TensorFlow SavedModel and checks the output schema, that exactly one output
 * row is produced per input row, and that the two input fields are passed through
 * unchanged.
 *
 * @throws Exception if plugin download, data/model download, unzip, or prediction fails
 */
@Category(DLTest.class)
@Test
public void testString() throws Exception {
    AlinkGlobalConfiguration.setPrintProcessInfo(true);
    // The TF predictor plugin must be present before the mapper is opened.
    PluginDownloader pluginDownloader = AlinkGlobalConfiguration.getPluginDownloader();
    RegisterKey registerKey = TFPredictorClassLoaderFactory.getRegisterKey();
    pluginDownloader.downloadPlugin(registerKey.getName(), registerKey.getVersion());
    MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(2);
    String url = "https://alink-test-data.oss-cn-hangzhou.aliyuncs.com/mnist_dense.csv";
    String schema = "label bigint, image string";
    BatchOperator<?> data = new CsvSourceBatchOp().setFilePath(url).setSchemaStr(schema).setFieldDelimiter(";");
    List<Row> rows = data.collect();
    // Download the model archive into a fresh temp directory.
    String modelPath = "http://alink-dataset.oss-cn-zhangjiakou.aliyuncs.com/tf/1551968314.zip";
    String workDir = PythonFileUtils.createTempDir("temp_").toString();
    String fn = FileDownloadUtils.downloadHttpOrOssFile(modelPath, workDir);
    String localModelPath = workDir + File.separator + fn;
    System.out.println("localModelPath: " + localModelPath);
    if (localModelPath.endsWith(".zip")) {
        // Unzip next to the archive; the model directory is the archive path minus ".zip".
        File target = new File(localModelPath).getParentFile();
        ZipFileUtil.unZip(new File(localModelPath), target);
        localModelPath = localModelPath.substring(0, localModelPath.length() - ".zip".length());
        Preconditions.checkArgument(new File(localModelPath).exists(), "problematic zip file.");
    }
    Params params = new Params();
    params.set(HasModelPath.MODEL_PATH, localModelPath);
    params.set(HasSelectedCols.SELECTED_COLS, new String[] { "image" });
    params.set(HasOutputSchemaStr.OUTPUT_SCHEMA_STR, "classes bigint, probabilities string");
    BaseTFSavedModelPredictRowFlatMapper baseTFSavedModelPredictFlatMapper = new BaseTFSavedModelPredictRowFlatMapper(data.getSchema(), params);
    baseTFSavedModelPredictFlatMapper.open();
    List<Row> list = new ArrayList<>();
    try {
        ListCollector<Row> collector = new ListCollector<>(list);
        for (Row row : rows) {
            baseTFSavedModelPredictFlatMapper.flatMap(row, collector);
        }
    } finally {
        // Close the mapper even when a flatMap() call fails.
        baseTFSavedModelPredictFlatMapper.close();
    }
    Assert.assertEquals(TableSchema.builder().field("label", Types.LONG).field("image", Types.STRING).field("classes", Types.LONG).field("probabilities", Types.STRING).build(), baseTFSavedModelPredictFlatMapper.getOutputSchema());
    // Check the row count before indexing so a short output list fails as an
    // assertion, not as an IndexOutOfBoundsException from list.get(i).
    Assert.assertEquals(rows.size(), list.size());
    for (int i = 0; i < rows.size(); i += 1) {
        Assert.assertEquals(rows.get(i).getField(0), list.get(i).getField(0));
        Assert.assertEquals(rows.get(i).getField(1), list.get(i).getField(1));
    }
}
Also used : ArrayList(java.util.ArrayList) Params(org.apache.flink.ml.api.misc.param.Params) CsvSourceBatchOp(com.alibaba.alink.operator.batch.source.CsvSourceBatchOp) PluginDownloader(com.alibaba.alink.common.io.plugin.PluginDownloader) ListCollector(org.apache.flink.api.common.functions.util.ListCollector) Row(org.apache.flink.types.Row) RegisterKey(com.alibaba.alink.common.io.plugin.RegisterKey) File(java.io.File) Category(org.junit.experimental.categories.Category) Test(org.junit.Test) DLTest(com.alibaba.alink.testutil.categories.DLTest)

Aggregations

PluginDownloader (com.alibaba.alink.common.io.plugin.PluginDownloader)21 RegisterKey (com.alibaba.alink.common.io.plugin.RegisterKey)19 Test (org.junit.Test)19 DLTest (com.alibaba.alink.testutil.categories.DLTest)15 Params (org.apache.flink.ml.api.misc.param.Params)12 Row (org.apache.flink.types.Row)12 Category (org.junit.experimental.categories.Category)12 File (java.io.File)10 CsvSourceBatchOp (com.alibaba.alink.operator.batch.source.CsvSourceBatchOp)8 ArrayList (java.util.ArrayList)8 FloatTensor (com.alibaba.alink.common.linalg.tensor.FloatTensor)4 Shape (com.alibaba.alink.common.linalg.tensor.Shape)4 ShuffleBatchOp (com.alibaba.alink.operator.batch.dataproc.ShuffleBatchOp)4 MemSourceBatchOp (com.alibaba.alink.operator.batch.source.MemSourceBatchOp)4 AkSourceBatchOp (com.alibaba.alink.operator.batch.source.AkSourceBatchOp)3 TypeConvertStreamOp (com.alibaba.alink.operator.stream.dataproc.TypeConvertStreamOp)3 RandomTableSourceStreamOp (com.alibaba.alink.operator.stream.source.RandomTableSourceStreamOp)3 InputStream (java.io.InputStream)3 HashMap (java.util.HashMap)3 Random (java.util.Random)3